capnproto

FORK: Cap'n Proto serialization/RPC system - core tools and C++ library
git clone https://git.neptards.moe/neptards/capnproto.git
Log | Files | Refs | README | LICENSE

http-over-capnp.c++ (29422B)


      1 // Copyright (c) 2019 Cloudflare, Inc. and contributors
      2 // Licensed under the MIT License:
      3 //
      4 // Permission is hereby granted, free of charge, to any person obtaining a copy
      5 // of this software and associated documentation files (the "Software"), to deal
      6 // in the Software without restriction, including without limitation the rights
      7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      8 // copies of the Software, and to permit persons to whom the Software is
      9 // furnished to do so, subject to the following conditions:
     10 //
     11 // The above copyright notice and this permission notice shall be included in
     12 // all copies or substantial portions of the Software.
     13 //
     14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
     17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
     20 // THE SOFTWARE.
     21 
     22 #include "http-over-capnp.h"
     23 #include <kj/debug.h>
     24 #include <capnp/schema.h>
     25 
     26 namespace capnp {
     27 
     28 using kj::uint;
     29 using kj::byte;
     30 
     31 class HttpOverCapnpFactory::RequestState final
     32     : public kj::Refcounted, public kj::TaskSet::ErrorHandler {
     33 public:
     34   RequestState() {
     35     tasks.emplace(*this);
     36   }
     37 
     38   template <typename Func>
     39   auto wrap(Func&& func) -> decltype(func()) {
     40     if (tasks == nullptr) {
     41       return KJ_EXCEPTION(DISCONNECTED, "client canceled HTTP request");
     42     } else {
     43       return canceler.wrap(func());
     44     }
     45   }
     46 
     47   void cancel() {
     48     if (tasks != nullptr) {
     49       if (!canceler.isEmpty()) {
     50         canceler.cancel(KJ_EXCEPTION(DISCONNECTED, "request canceled"));
     51       }
     52       tasks = nullptr;
     53       webSocket = nullptr;
     54     }
     55   }
     56 
     57   void assertNotCanceled() {
     58     if (tasks == nullptr) {
     59       kj::throwFatalException(KJ_EXCEPTION(DISCONNECTED, "client canceled HTTP request"));
     60     }
     61   }
     62 
     63   void addTask(kj::Promise<void> task) {
     64     KJ_IF_MAYBE(t, tasks) {
     65       t->add(kj::mv(task));
     66     } else {
     67       // Just drop the task.
     68     }
     69   }
     70 
     71   kj::Promise<void> finishTasks() {
     72     // This is merged into the final promise, so we don't need to worry about wrapping it for
     73     // cancellation.
     74     return KJ_REQUIRE_NONNULL(tasks).onEmpty()
     75         .then([this]() {
     76       KJ_IF_MAYBE(e, error) {
     77         kj::throwRecoverableException(kj::mv(*e));
     78       }
     79     });
     80   }
     81 
     82   void taskFailed(kj::Exception&& exception) override {
     83     if (error == nullptr) {
     84       error = kj::mv(exception);
     85     }
     86   }
     87 
     88   void holdWebSocket(kj::Own<kj::WebSocket> webSocket) {
     89     // Hold on to this WebSocket until cancellation.
     90     KJ_REQUIRE(this->webSocket == nullptr);
     91     KJ_REQUIRE(tasks != nullptr);
     92     this->webSocket = kj::mv(webSocket);
     93   }
     94 
     95   void disconnectWebSocket() {
     96     KJ_IF_MAYBE(t, tasks) {
     97       t->add(kj::evalNow([&]() { return KJ_ASSERT_NONNULL(webSocket)->disconnect(); }));
     98     }
     99   }
    100 
    101 private:
    102   kj::Maybe<kj::Exception> error;
    103   kj::Maybe<kj::Own<kj::WebSocket>> webSocket;
    104   kj::Canceler canceler;
    105   kj::Maybe<kj::TaskSet> tasks;
    106 };
    107 
    108 // =======================================================================================
    109 
    110 class HttpOverCapnpFactory::CapnpToKjWebSocketAdapter final: public capnp::WebSocket::Server {
    111 public:
    112   CapnpToKjWebSocketAdapter(kj::Own<RequestState> state, kj::WebSocket& webSocket,
    113                             kj::Promise<Capability::Client> shorteningPromise)
    114       : state(kj::mv(state)), webSocket(webSocket),
    115         shorteningPromise(kj::mv(shorteningPromise)) {}
    116 
    117   ~CapnpToKjWebSocketAdapter() noexcept(false) {
    118     state->disconnectWebSocket();
    119   }
    120 
    121   kj::Maybe<kj::Promise<Capability::Client>> shortenPath() override {
    122     return kj::mv(shorteningPromise);
    123   }
    124 
    125   kj::Promise<void> sendText(SendTextContext context) override {
    126     return state->wrap([&]() { return webSocket.send(context.getParams().getText()); });
    127   }
    128   kj::Promise<void> sendData(SendDataContext context) override {
    129     return state->wrap([&]() { return webSocket.send(context.getParams().getData()); });
    130   }
    131   kj::Promise<void> close(CloseContext context) override {
    132     auto params = context.getParams();
    133     return state->wrap([&]() { return webSocket.close(params.getCode(), params.getReason()); });
    134   }
    135 
    136 private:
    137   kj::Own<RequestState> state;
    138   kj::WebSocket& webSocket;
    139   kj::Promise<Capability::Client> shorteningPromise;
    140 };
    141 
    142 class HttpOverCapnpFactory::KjToCapnpWebSocketAdapter final: public kj::WebSocket {
    143 public:
    144   KjToCapnpWebSocketAdapter(
    145       kj::Maybe<kj::Own<kj::WebSocket>> in, capnp::WebSocket::Client out,
    146       kj::Own<kj::PromiseFulfiller<kj::Promise<Capability::Client>>> shorteningFulfiller)
    147       : in(kj::mv(in)), out(kj::mv(out)), shorteningFulfiller(kj::mv(shorteningFulfiller)) {}
    148   ~KjToCapnpWebSocketAdapter() noexcept(false) {
    149     if (shorteningFulfiller->isWaiting()) {
    150       // We want to make sure the fulfiller is not rejected with a bogus "PromiseFulfiller
    151       // destroyed" error, so fulfill it with never-done.
    152       shorteningFulfiller->fulfill(kj::NEVER_DONE);
    153     }
    154   }
    155 
    156   kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
    157     auto req = KJ_REQUIRE_NONNULL(out, "already called disconnect()").sendDataRequest(
    158         MessageSize { 8 + message.size() / sizeof(word), 0 });
    159     req.setData(message);
    160     sentBytes += message.size();
    161     return req.send();
    162   }
    163 
    164   kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
    165     auto req = KJ_REQUIRE_NONNULL(out, "already called disconnect()").sendTextRequest(
    166         MessageSize { 8 + message.size() / sizeof(word), 0 });
    167     memcpy(req.initText(message.size()).begin(), message.begin(), message.size());
    168     sentBytes += message.size();
    169     return req.send();
    170   }
    171 
    172   kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
    173     auto req = KJ_REQUIRE_NONNULL(out, "already called disconnect()").closeRequest();
    174     req.setCode(code);
    175     req.setReason(reason);
    176     sentBytes += reason.size() + 2;
    177     return req.send().ignoreResult();
    178   }
    179 
    180   kj::Promise<void> disconnect() override {
    181     out = nullptr;
    182     return kj::READY_NOW;
    183   }
    184 
    185   void abort() override {
    186     KJ_ASSERT_NONNULL(in)->abort();
    187   }
    188 
    189   kj::Promise<void> whenAborted() override {
    190     return KJ_ASSERT_NONNULL(out).whenResolved()
    191         .then([]() -> kj::Promise<void> {
    192       // It would seem this capability resolved to an implementation of the WebSocket RPC interface
    193       // that does not support further path-shortening (so, it's not the implementation found in
    194       // this file). Since the path-shortening facility is also how we discover disconnects, we
    195       // apparently have no way to be alerted on disconnect. We have to assume the other end
    196       // never aborts.
    197       return kj::NEVER_DONE;
    198     }, [](kj::Exception&& e) -> kj::Promise<void> {
    199       if (e.getType() == kj::Exception::Type::DISCONNECTED) {
    200         // Looks like we were aborted!
    201         return kj::READY_NOW;
    202       } else {
    203         // Some other error... propagate it.
    204         return kj::mv(e);
    205       }
    206     });
    207   }
    208 
    209   kj::Promise<Message> receive(size_t maxSize) override {
    210     return KJ_ASSERT_NONNULL(in)->receive(maxSize);
    211   }
    212 
    213   kj::Promise<void> pumpTo(WebSocket& other) override {
    214     KJ_IF_MAYBE(optimized, kj::dynamicDowncastIfAvailable<KjToCapnpWebSocketAdapter>(other)) {
    215       shorteningFulfiller->fulfill(
    216           kj::cp(KJ_REQUIRE_NONNULL(optimized->out, "already called disconnect()")));
    217 
    218       // We expect the `in` pipe will stop receiving messages after the redirect, but we need to
    219       // pump anything already in-flight.
    220       return KJ_ASSERT_NONNULL(in)->pumpTo(other);
    221     } else KJ_IF_MAYBE(promise, other.tryPumpFrom(*this)) {
    222       // We may have unwrapped some layers around `other` leading to a shorter path.
    223       return kj::mv(*promise);
    224     } else {
    225       return KJ_ASSERT_NONNULL(in)->pumpTo(other);
    226     }
    227   }
    228 
    229   uint64_t sentByteCount() override { return sentBytes; }
    230   uint64_t receivedByteCount() override { return KJ_ASSERT_NONNULL(in)->receivedByteCount(); }
    231 
    232 private:
    233   kj::Maybe<kj::Own<kj::WebSocket>> in;   // One end of a WebSocketPipe, used only for receiving.
    234   kj::Maybe<capnp::WebSocket::Client> out;  // Used only for sending.
    235   kj::Own<kj::PromiseFulfiller<kj::Promise<Capability::Client>>> shorteningFulfiller;
    236   uint64_t sentBytes = 0;
    237 };
    238 
    239 // =======================================================================================
    240 
    241 class HttpOverCapnpFactory::ClientRequestContextImpl final
    242     : public capnp::HttpService::ClientRequestContext::Server {
    243 public:
    244   ClientRequestContextImpl(HttpOverCapnpFactory& factory,
    245                            kj::Own<RequestState> state,
    246                            kj::HttpService::Response& kjResponse)
    247       : factory(factory), state(kj::mv(state)), kjResponse(kjResponse) {}
    248 
    249   ~ClientRequestContextImpl() noexcept(false) {
    250     // Note this implicitly cancels the upstream pump task.
    251   }
    252 
    253   kj::Promise<void> startResponse(StartResponseContext context) override {
    254     KJ_REQUIRE(!sent, "already called startResponse() or startWebSocket()");
    255     sent = true;
    256     state->assertNotCanceled();
    257 
    258     auto params = context.getParams();
    259     auto rpcResponse = params.getResponse();
    260 
    261     auto bodySize = rpcResponse.getBodySize();
    262     kj::Maybe<uint64_t> expectedSize;
    263     bool hasBody = true;
    264     if (bodySize.isFixed()) {
    265       auto size = bodySize.getFixed();
    266       expectedSize = bodySize.getFixed();
    267       hasBody = size > 0;
    268     }
    269 
    270     auto bodyStream = kjResponse.send(rpcResponse.getStatusCode(), rpcResponse.getStatusText(),
    271         factory.headersToKj(rpcResponse.getHeaders()), expectedSize);
    272 
    273     auto results = context.getResults(MessageSize { 16, 1 });
    274     if (hasBody) {
    275       auto pipe = kj::newOneWayPipe();
    276       results.setBody(factory.streamFactory.kjToCapnp(kj::mv(pipe.out)));
    277       state->addTask(pipe.in->pumpTo(*bodyStream)
    278           .ignoreResult()
    279           .attach(kj::mv(bodyStream), kj::mv(pipe.in)));
    280     }
    281     return kj::READY_NOW;
    282   }
    283 
    284   kj::Promise<void> startWebSocket(StartWebSocketContext context) override {
    285     KJ_REQUIRE(!sent, "already called startResponse() or startWebSocket()");
    286     sent = true;
    287     state->assertNotCanceled();
    288 
    289     auto params = context.getParams();
    290 
    291     auto shorteningPaf = kj::newPromiseAndFulfiller<kj::Promise<Capability::Client>>();
    292 
    293     auto ownWebSocket = kjResponse.acceptWebSocket(factory.headersToKj(params.getHeaders()));
    294     auto& webSocket = *ownWebSocket;
    295     state->holdWebSocket(kj::mv(ownWebSocket));
    296 
    297     auto upWrapper = kj::heap<KjToCapnpWebSocketAdapter>(
    298         nullptr, params.getUpSocket(), kj::mv(shorteningPaf.fulfiller));
    299     state->addTask(webSocket.pumpTo(*upWrapper).attach(kj::mv(upWrapper))
    300         .catch_([&webSocket=webSocket](kj::Exception&& e) -> kj::Promise<void> {
    301       // The pump in the client -> server direction failed. The error may have originated from
    302       // either the client or the server. In case it came from the server, we want to call .abort()
    303       // to propagate the problem back to the client. If the error came from the client, then
    304       // .abort() probably is a noop.
    305       webSocket.abort();
    306       return kj::mv(e);
    307     }));
    308 
    309     auto results = context.getResults(MessageSize { 16, 1 });
    310     results.setDownSocket(kj::heap<CapnpToKjWebSocketAdapter>(
    311         kj::addRef(*state), webSocket, kj::mv(shorteningPaf.promise)));
    312 
    313     return kj::READY_NOW;
    314   }
    315 
    316 private:
    317   HttpOverCapnpFactory& factory;
    318   kj::Own<RequestState> state;
    319   bool sent = false;
    320 
    321   kj::HttpService::Response& kjResponse;
    322   // Must check state->assertNotCanceled() before using this.
    323 };
    324 
    325 class HttpOverCapnpFactory::KjToCapnpHttpServiceAdapter final: public kj::HttpService {
    326 public:
    327   KjToCapnpHttpServiceAdapter(HttpOverCapnpFactory& factory, capnp::HttpService::Client inner)
    328       : factory(factory), inner(kj::mv(inner)) {}
    329 
    330   kj::Promise<void> request(
    331       kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers,
    332       kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse) override {
    333     auto rpcRequest = inner.startRequestRequest();
    334 
    335     auto metadata = rpcRequest.initRequest();
    336     metadata.setMethod(static_cast<capnp::HttpMethod>(method));
    337     metadata.setUrl(url);
    338     metadata.adoptHeaders(factory.headersToCapnp(
    339         headers, Orphanage::getForMessageContaining(metadata)));
    340 
    341     kj::Maybe<kj::AsyncInputStream&> maybeRequestBody;
    342 
    343     KJ_IF_MAYBE(s, requestBody.tryGetLength()) {
    344       metadata.getBodySize().setFixed(*s);
    345       if (*s == 0) {
    346         maybeRequestBody = nullptr;
    347       } else {
    348         maybeRequestBody = requestBody;
    349       }
    350     } else if ((method == kj::HttpMethod::GET || method == kj::HttpMethod::HEAD) &&
    351                headers.get(kj::HttpHeaderId::TRANSFER_ENCODING) == nullptr) {
    352       maybeRequestBody = nullptr;
    353       metadata.getBodySize().setFixed(0);
    354     } else {
    355       metadata.getBodySize().setUnknown();
    356       maybeRequestBody = requestBody;
    357     }
    358 
    359     auto state = kj::refcounted<RequestState>();
    360     auto deferredCancel = kj::defer([state = kj::addRef(*state)]() mutable {
    361       state->cancel();
    362     });
    363 
    364     rpcRequest.setContext(
    365         kj::heap<ClientRequestContextImpl>(factory, kj::addRef(*state), kjResponse));
    366 
    367     auto pipeline = rpcRequest.send();
    368 
    369     // Pump upstream -- unless we don't expect a request body.
    370     kj::Maybe<kj::Promise<void>> pumpRequestTask;
    371     KJ_IF_MAYBE(rb, maybeRequestBody) {
    372       auto bodyOut = factory.streamFactory.capnpToKj(pipeline.getRequestBody());
    373       pumpRequestTask = rb->pumpTo(*bodyOut).attach(kj::mv(bodyOut)).ignoreResult()
    374           .eagerlyEvaluate([state = kj::addRef(*state)](kj::Exception&& e) mutable {
    375         // A DISCONNECTED exception probably means the server decided not to read the whole request
    376         // before responding. In that case we simply want the pump to end, so that on this end it
    377         // also appears that the service simply didn't read everything. So we don't propagate the
    378         // exception in that case. For any other exception, we want to merge the exception with
    379         // the final result.
    380         if (e.getType() != kj::Exception::Type::DISCONNECTED) {
    381           state->taskFailed(kj::mv(e));
    382         }
    383       });
    384     }
    385 
    386     // Wait for the ServerRequestContext to resolve, which indicates completion. Meanwhile, if the
    387     // promise is canceled from the client side, we drop the ServerRequestContext naturally, and we
    388     // also call state->cancel().
    389     return pipeline.getContext().whenResolved()
    390         // Once the server indicates it is done, then we can cancel pumping the request, because
    391         // obviously the server won't use it. We should not cancel pumping the response since there
    392         // could be data in-flight still.
    393         .attach(kj::mv(pumpRequestTask))
    394         // finishTasks() will wait for the respones to complete.
    395         .then([state = kj::mv(state)]() mutable { return state->finishTasks(); })
    396         .attach(kj::mv(deferredCancel));
    397   }
    398 
    399 private:
    400   HttpOverCapnpFactory& factory;
    401   capnp::HttpService::Client inner;
    402 };
    403 
    404 kj::Own<kj::HttpService> HttpOverCapnpFactory::capnpToKj(capnp::HttpService::Client rpcService) {
    405   return kj::heap<KjToCapnpHttpServiceAdapter>(*this, kj::mv(rpcService));
    406 }
    407 
    408 // =======================================================================================
    409 
    410 namespace {
    411 
    412 class NullInputStream final: public kj::AsyncInputStream {
    413   // TODO(cleanup): This class has been replicated in a bunch of places now, make it public
    414   //   somewhere.
    415 
    416 public:
    417   kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    418     return size_t(0);
    419   }
    420 
    421   kj::Maybe<uint64_t> tryGetLength() override {
    422     return uint64_t(0);
    423   }
    424 
    425   kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
    426     return uint64_t(0);
    427   }
    428 };
    429 
    430 class NullOutputStream final: public kj::AsyncOutputStream {
    431   // TODO(cleanup): This class has been replicated in a bunch of places now, make it public
    432   //   somewhere.
    433 
    434 public:
    435   kj::Promise<void> write(const void* buffer, size_t size) override {
    436     return kj::READY_NOW;
    437   }
    438   kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
    439     return kj::READY_NOW;
    440   }
    441   kj::Promise<void> whenWriteDisconnected() override {
    442     return kj::NEVER_DONE;
    443   }
    444 
    445   // We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method.
    446 };
    447 
    448 class ResolvedServerRequestContext final: public capnp::HttpService::ServerRequestContext::Server {
    449 public:
    450   // Nothing! It's done.
    451 };
    452 
    453 }  // namespace
    454 
    455 class HttpOverCapnpFactory::ServerRequestContextImpl final
    456     : public capnp::HttpService::ServerRequestContext::Server,
    457       public kj::HttpService::Response {
    458 public:
    459   ServerRequestContextImpl(HttpOverCapnpFactory& factory,
    460                            HttpService::Client serviceCap,
    461                            capnp::HttpRequest::Reader request,
    462                            capnp::HttpService::ClientRequestContext::Client clientContext,
    463                            kj::Own<kj::AsyncInputStream> requestBodyIn,
    464                            kj::HttpService& kjService)
    465       : factory(factory), serviceCap(kj::mv(serviceCap)),
    466         method(validateMethod(request.getMethod())),
    467         url(kj::str(request.getUrl())),
    468         headers(factory.headersToKj(request.getHeaders()).clone()),
    469         clientContext(kj::mv(clientContext)),
    470         // Note we attach `requestBodyIn` to `task` so that we will implicitly cancel reading
    471         // the request body as soon as the service returns. This is important in particular when
    472         // the request body is not fully consumed, in order to propagate cancellation.
    473         task(kjService.request(method, url, headers, *requestBodyIn, *this)
    474                       .attach(kj::mv(requestBodyIn))) {}
    475 
    476   KJ_DISALLOW_COPY(ServerRequestContextImpl);
    477 
    478   kj::Maybe<kj::Promise<Capability::Client>> shortenPath() override {
    479     return task.then([]() -> Capability::Client {
    480       // If all went well, resolve to a settled capability.
    481       // TODO(perf): Could save a message by resolving to a capability hosted by the client, or
    482       //     some special "null" capability that isn't an error but is still transmitted by value.
    483       //     Otherwise we need a Release message from client -> server just to drop this...
    484       return kj::heap<ResolvedServerRequestContext>();
    485     });
    486   }
    487 
    488   kj::Own<kj::AsyncOutputStream> send(
    489       uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers,
    490       kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
    491     KJ_REQUIRE(replyTask == nullptr, "already called send() or acceptWebSocket()");
    492 
    493     auto req = clientContext.startResponseRequest();
    494 
    495     if (method == kj::HttpMethod::HEAD ||
    496         statusCode == 204 || statusCode == 304) {
    497       expectedBodySize = uint64_t(0);
    498     }
    499 
    500     auto rpcResponse = req.initResponse();
    501     rpcResponse.setStatusCode(statusCode);
    502     rpcResponse.setStatusText(statusText);
    503     rpcResponse.adoptHeaders(factory.headersToCapnp(
    504         headers, Orphanage::getForMessageContaining(rpcResponse)));
    505     bool hasBody = true;
    506     KJ_IF_MAYBE(s, expectedBodySize) {
    507       rpcResponse.getBodySize().setFixed(*s);
    508       hasBody = *s > 0;
    509     }
    510 
    511     if (hasBody) {
    512       auto pipeline = req.send();
    513       auto result = factory.streamFactory.capnpToKj(pipeline.getBody());
    514       replyTask = pipeline.ignoreResult()
    515           .eagerlyEvaluate([](kj::Exception&& e) {
    516         KJ_LOG(ERROR, "HTTP-over-RPC startResponse() failed", e);
    517       });
    518       return result;
    519     } else {
    520       replyTask = req.send().ignoreResult()
    521           .eagerlyEvaluate([](kj::Exception&& e) {
    522         KJ_LOG(ERROR, "HTTP-over-RPC startResponse() failed", e);
    523       });
    524       return kj::heap<NullOutputStream>();
    525     }
    526 
    527     // We don't actually wait for replyTask anywhere, because we may be all done with this HTTP
    528     // message before the client gets a chance to respond, and we don't want to force an extra
    529     // network round trip. If the client fails this call that's the client's problem, really.
    530   }
    531 
    532   kj::Own<kj::WebSocket> acceptWebSocket(const kj::HttpHeaders& headers) override {
    533     KJ_REQUIRE(replyTask == nullptr, "already called send() or acceptWebSocket()");
    534 
    535     auto req = clientContext.startWebSocketRequest();
    536 
    537     req.adoptHeaders(factory.headersToCapnp(
    538         headers, Orphanage::getForMessageContaining(
    539             capnp::HttpService::ClientRequestContext::StartWebSocketParams::Builder(req))));
    540 
    541     auto pipe = kj::newWebSocketPipe();
    542     auto shorteningPaf = kj::newPromiseAndFulfiller<kj::Promise<Capability::Client>>();
    543 
    544     // We don't need the RequestState mechanism on the server side because
    545     // CapnpToKjWebSocketAdapter wraps a pipe end, and that pipe end can continue to exist beyond
    546     // the lifetime of the request, because the other end will have been dropped. We only create
    547     // a RequestState here so that we can reuse the implementation of CapnpToKjWebSocketAdapter
    548     // that needs this for the client side.
    549     auto dummyState = kj::refcounted<RequestState>();
    550     auto& pipeEnd0Ref = *pipe.ends[0];
    551     dummyState->holdWebSocket(kj::mv(pipe.ends[0]));
    552     req.setUpSocket(kj::heap<CapnpToKjWebSocketAdapter>(
    553         kj::mv(dummyState), pipeEnd0Ref, kj::mv(shorteningPaf.promise)));
    554 
    555     auto pipeline = req.send();
    556     auto result = kj::heap<KjToCapnpWebSocketAdapter>(
    557         kj::mv(pipe.ends[1]), pipeline.getDownSocket(), kj::mv(shorteningPaf.fulfiller));
    558 
    559     // Note we need eagerlyEvaluate() here to force proactively discarding the response object,
    560     // since it holds a reference to `downSocket`.
    561     replyTask = pipeline.ignoreResult()
    562         .eagerlyEvaluate([](kj::Exception&& e) {
    563       KJ_LOG(ERROR, "HTTP-over-RPC startWebSocketRequest() failed", e);
    564     });
    565 
    566     return result;
    567   }
    568 
    569 private:
    570   HttpOverCapnpFactory& factory;
    571   HttpService::Client serviceCap;  // ensures the inner kj::HttpService isn't destroyed
    572   kj::HttpMethod method;
    573   kj::String url;
    574   kj::HttpHeaders headers;
    575   capnp::HttpService::ClientRequestContext::Client clientContext;
    576   kj::Maybe<kj::Promise<void>> replyTask;
    577   kj::Promise<void> task;
    578 
    579   static kj::HttpMethod validateMethod(capnp::HttpMethod method) {
    580     KJ_REQUIRE(method <= capnp::HttpMethod::UNSUBSCRIBE, "unknown method", method);
    581     return static_cast<kj::HttpMethod>(method);
    582   }
    583 };
    584 
    585 class HttpOverCapnpFactory::CapnpToKjHttpServiceAdapter final: public capnp::HttpService::Server {
    586 public:
    587   CapnpToKjHttpServiceAdapter(HttpOverCapnpFactory& factory, kj::Own<kj::HttpService> inner)
    588       : factory(factory), inner(kj::mv(inner)) {}
    589 
    590   kj::Promise<void> startRequest(StartRequestContext context) override {
    591     auto params = context.getParams();
    592     auto metadata = params.getRequest();
    593 
    594     auto bodySize = metadata.getBodySize();
    595     kj::Maybe<uint64_t> expectedSize;
    596     bool hasBody = true;
    597     if (bodySize.isFixed()) {
    598       auto size = bodySize.getFixed();
    599       expectedSize = bodySize.getFixed();
    600       hasBody = size > 0;
    601     }
    602 
    603     auto results = context.getResults(MessageSize {8, 2});
    604     kj::Own<kj::AsyncInputStream> requestBody;
    605     if (hasBody) {
    606       auto pipe = kj::newOneWayPipe(expectedSize);
    607       results.setRequestBody(factory.streamFactory.kjToCapnp(kj::mv(pipe.out)));
    608       requestBody = kj::mv(pipe.in);
    609     } else {
    610       requestBody = kj::heap<NullInputStream>();
    611     }
    612     results.setContext(kj::heap<ServerRequestContextImpl>(
    613         factory, thisCap(), metadata, params.getContext(), kj::mv(requestBody), *inner));
    614 
    615     return kj::READY_NOW;
    616   }
    617 
    618 private:
    619   HttpOverCapnpFactory& factory;
    620   kj::Own<kj::HttpService> inner;
    621 };
    622 
    623 capnp::HttpService::Client HttpOverCapnpFactory::kjToCapnp(kj::Own<kj::HttpService> service) {
    624   return kj::heap<CapnpToKjHttpServiceAdapter>(*this, kj::mv(service));
    625 }
    626 
    627 // =======================================================================================
    628 
    629 static constexpr uint64_t COMMON_TEXT_ANNOTATION = 0x857745131db6fc83ull;
    630 // Type ID of `commonText` from `http.capnp`.
    631 // TODO(cleanup): Cap'n Proto should auto-generate constants for these.
    632 
    633 HttpOverCapnpFactory::HeaderIdBundle::HeaderIdBundle(kj::HttpHeaderTable::Builder& builder)
    634     : table(builder.getFutureTable()) {
    635   auto commonHeaderNames = Schema::from<capnp::CommonHeaderName>().getEnumerants();
    636   nameCapnpToKj = kj::heapArray<kj::HttpHeaderId>(commonHeaderNames.size());
    637   for (size_t i = 1; i < commonHeaderNames.size(); i++) {
    638     kj::StringPtr nameText;
    639     for (auto ann: commonHeaderNames[i].getProto().getAnnotations()) {
    640       if (ann.getId() == COMMON_TEXT_ANNOTATION) {
    641         nameText = ann.getValue().getText();
    642         break;
    643       }
    644     }
    645     KJ_ASSERT(nameText != nullptr);
    646     kj::HttpHeaderId headerId = builder.add(nameText);
    647     nameCapnpToKj[i] = headerId;
    648     maxHeaderId = kj::max(maxHeaderId, headerId.hashCode());
    649   }
    650 }
    651 
    652 HttpOverCapnpFactory::HeaderIdBundle::HeaderIdBundle(
    653     const kj::HttpHeaderTable& table, kj::Array<kj::HttpHeaderId> nameCapnpToKj, size_t maxHeaderId)
    654     : table(table), nameCapnpToKj(kj::mv(nameCapnpToKj)), maxHeaderId(maxHeaderId) {}
    655 
    656 HttpOverCapnpFactory::HeaderIdBundle HttpOverCapnpFactory::HeaderIdBundle::clone() const {
    657   return HeaderIdBundle(table, kj::heapArray<kj::HttpHeaderId>(nameCapnpToKj), maxHeaderId);
    658 }
    659 
    660 HttpOverCapnpFactory::HttpOverCapnpFactory(ByteStreamFactory& streamFactory,
    661                                            HeaderIdBundle headerIds)
    662     : streamFactory(streamFactory), headerTable(headerIds.table),
    663       nameCapnpToKj(kj::mv(headerIds.nameCapnpToKj)) {
    664   auto commonHeaderNames = Schema::from<capnp::CommonHeaderName>().getEnumerants();
    665   nameKjToCapnp = kj::heapArray<capnp::CommonHeaderName>(headerIds.maxHeaderId + 1);
    666   for (auto& slot: nameKjToCapnp) slot = capnp::CommonHeaderName::INVALID;
    667 
    668   for (size_t i = 1; i < commonHeaderNames.size(); i++) {
    669     auto& slot = nameKjToCapnp[nameCapnpToKj[i].hashCode()];
    670     KJ_ASSERT(slot == capnp::CommonHeaderName::INVALID);
    671     slot = static_cast<capnp::CommonHeaderName>(i);
    672   }
    673 
    674   auto commonHeaderValues = Schema::from<capnp::CommonHeaderValue>().getEnumerants();
    675   valueCapnpToKj = kj::heapArray<kj::StringPtr>(commonHeaderValues.size());
    676   for (size_t i = 1; i < commonHeaderValues.size(); i++) {
    677     kj::StringPtr valueText;
    678     for (auto ann: commonHeaderValues[i].getProto().getAnnotations()) {
    679       if (ann.getId() == COMMON_TEXT_ANNOTATION) {
    680         valueText = ann.getValue().getText();
    681         break;
    682       }
    683     }
    684     KJ_ASSERT(valueText != nullptr);
    685     valueCapnpToKj[i] = valueText;
    686     valueKjToCapnp.insert(valueText, static_cast<capnp::CommonHeaderValue>(i));
    687   }
    688 }
    689 
    690 Orphan<List<capnp::HttpHeader>> HttpOverCapnpFactory::headersToCapnp(
    691     const kj::HttpHeaders& headers, Orphanage orphanage) {
    692   auto result = orphanage.newOrphan<List<capnp::HttpHeader>>(headers.size());
    693   auto rpcHeaders = result.get();
    694   uint i = 0;
    695   headers.forEach([&](kj::HttpHeaderId id, kj::StringPtr value) {
    696     auto capnpName = id.hashCode() < nameKjToCapnp.size()
    697         ? nameKjToCapnp[id.hashCode()]
    698         : capnp::CommonHeaderName::INVALID;
    699     if (capnpName == capnp::CommonHeaderName::INVALID) {
    700       auto header = rpcHeaders[i++].initUncommon();
    701       header.setName(id.toString());
    702       header.setValue(value);
    703     } else {
    704       auto header = rpcHeaders[i++].initCommon();
    705       header.setName(capnpName);
    706       header.setValue(value);
    707     }
    708   }, [&](kj::StringPtr name, kj::StringPtr value) {
    709     auto header = rpcHeaders[i++].initUncommon();
    710     header.setName(name);
    711     header.setValue(value);
    712   });
    713   KJ_ASSERT(i == rpcHeaders.size());
    714   return result;
    715 }
    716 
    717 kj::HttpHeaders HttpOverCapnpFactory::headersToKj(
    718     List<capnp::HttpHeader>::Reader capnpHeaders) const {
    719   kj::HttpHeaders result(headerTable);
    720 
    721   for (auto header: capnpHeaders) {
    722     switch (header.which()) {
    723       case capnp::HttpHeader::COMMON: {
    724         auto nv = header.getCommon();
    725         auto nameInt = static_cast<uint>(nv.getName());
    726         KJ_REQUIRE(nameInt < nameCapnpToKj.size(), "unknown common header name", nv.getName());
    727 
    728         switch (nv.which()) {
    729           case capnp::HttpHeader::Common::COMMON_VALUE: {
    730             auto cvInt = static_cast<uint>(nv.getCommonValue());
    731             KJ_REQUIRE(nameInt < valueCapnpToKj.size(),
    732                 "unknown common header value", nv.getCommonValue());
    733             result.set(nameCapnpToKj[nameInt], valueCapnpToKj[cvInt]);
    734             break;
    735           }
    736           case capnp::HttpHeader::Common::VALUE: {
    737             auto headerId = nameCapnpToKj[nameInt];
    738             if (result.get(headerId) == nullptr) {
    739               result.set(headerId, nv.getValue());
    740             } else {
    741               // Unusual: This is a duplicate header, so fall back to add(), which may trigger
    742               //   comma-concatenation, except in certain cases where comma-concatentaion would
    743               //   be problematic.
    744               result.add(headerId.toString(), nv.getValue());
    745             }
    746             break;
    747           }
    748         }
    749         break;
    750       }
    751       case capnp::HttpHeader::UNCOMMON: {
    752         auto nv = header.getUncommon();
    753         result.add(nv.getName(), nv.getValue());
    754       }
    755     }
    756   }
    757 
    758   return result;
    759 }
    760 
    761 }  // namespace capnp