capnproto

FORK: Cap'n Proto serialization/RPC system - core tools and C++ library
git clone https://git.neptards.moe/neptards/capnproto.git
Log | Files | Refs | README | LICENSE

byte-stream.c++ (40034B)


      1 // Copyright (c) 2019 Cloudflare, Inc. and contributors
      2 // Licensed under the MIT License:
      3 //
      4 // Permission is hereby granted, free of charge, to any person obtaining a copy
      5 // of this software and associated documentation files (the "Software"), to deal
      6 // in the Software without restriction, including without limitation the rights
      7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      8 // copies of the Software, and to permit persons to whom the Software is
      9 // furnished to do so, subject to the following conditions:
     10 //
     11 // The above copyright notice and this permission notice shall be included in
     12 // all copies or substantial portions of the Software.
     13 //
     14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
     17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
     20 // THE SOFTWARE.
     21 
     22 #include "byte-stream.h"
     23 #include <kj/one-of.h>
     24 #include <kj/debug.h>
     25 
     26 namespace capnp {
     27 
     28 const uint MAX_BYTES_PER_WRITE = 1 << 16;
     29 
     30 class ByteStreamFactory::StreamServerBase: public capnp::ByteStream::Server {
     31 public:
     32   virtual void returnStream(uint64_t written) = 0;
     33   // Called after the StreamServerBase's internal kj::AsyncOutputStream has been borrowed, to
     34   // indicate that the borrower is done.
     35   //
     36   // A stream becomes borrowed either when getShortestPath() returns a BorrowedStream, or when
     37   // a SubstreamImpl is constructed wrapping an existing stream.
     38 
     39   struct BorrowedStream {
     40     // Represents permission to use the StreamServerBase's inner AsyncOutputStream directly, up
     41     // to some limit of bytes written.
     42 
     43     StreamServerBase& lender;
     44     kj::AsyncOutputStream& stream;
     45     uint64_t limit;
     46   };
     47 
     48   typedef kj::OneOf<kj::Promise<void>, capnp::ByteStream::Client*, BorrowedStream> ShortestPath;
     49 
     50   virtual ShortestPath getShortestPath() = 0;
     51   // Called by KjToCapnpStreamAdapter when it has determined that its inner ByteStream::Client
     52   // actually points back to a StreamServerBase in the same process created by the same
     53   // ByteStreamFactory. Returns the best shortened path to use, or a promise that resolves when the
     54   // shortest path is known.
     55 
     56   virtual void directEnd() = 0;
     57   // Called by KjToCapnpStreamAdapter's destructor when it has determined that its inner
     58   // ByteStream::Client actually points back to a StreamServerBase in the same process created by
     59   // the same ByteStreamFactory. Since destruction of a KJ stream signals EOF, we need to propagate
     60   // that by destroying our underlying stream.
     61   // TODO(cleanup): When KJ streams evolve an end() method, this can go away.
     62 };
     63 
     64 class ByteStreamFactory::SubstreamImpl final: public StreamServerBase {
          // A ByteStream::Server for a byte-limited "substream" created by getSubstream(). It writes
          // directly to a kj::AsyncOutputStream lent by `parent` until `limit` bytes have been
          // written. At that point it calls the callback's reachedLimit(), returns the stream to the
          // parent, and forwards every later call to the `next` stream the callback produced. If the
          // substream is ended first, the callback's ended() receives the byte count instead.
          //
          // `state` transitions:
          //   Streaming  -> writing directly to the lent stream.
          //   Borrowed   -> the lent stream has in turn been lent out (getShortestPath() or a nested
          //                 getSubstream()); other calls fail until returnStream() is called.
          //   Redirected -> the limit was reached; calls are forwarded to `replacement`.
          //   Ended      -> end() or directEnd() was called.
     65 public:
     66   SubstreamImpl(ByteStreamFactory& factory,
     67                 StreamServerBase& parent,
     68                 capnp::ByteStream::Client ownParent,
     69                 kj::AsyncOutputStream& stream,
     70                 capnp::ByteStream::SubstreamCallback::Client callback,
     71                 uint64_t limit,
     72                 kj::PromiseFulfillerPair<void> paf = kj::newPromiseAndFulfiller<void>())
     73       : factory(factory),
     74         state(Streaming {parent, kj::mv(ownParent), stream, kj::mv(callback)}),
     75         limit(limit),
     76         resolveFulfiller(kj::mv(paf.fulfiller)),
     77         resolvePromise(paf.promise.fork()) {}
          // `parent` is the server lending us `stream`. `ownParent` is a capability to that server,
          // held while Streaming (presumably to keep the lender alive). `paf` exists only so that
          // resolveFulfiller and resolvePromise can be initialized from one promise/fulfiller pair.
     78
     79   // ---------------------------------------------------------------------------
     80   // implements StreamServerBase
     81
     82   void returnStream(uint64_t written) override {
            // A borrower (getShortestPath() caller or a nested SubstreamImpl) is done and wrote
            // `written` bytes to our stream. Restore Streaming, then redirect if our budget is spent.
     83     completed += written;
     84     KJ_ASSERT(completed <= limit);
     85     auto borrowed = kj::mv(state.get<Borrowed>());
     86     state = kj::mv(borrowed.originalState);
     87
     88     if (completed == limit) {
     89       limitReached();
     90     }
     91   }
     92
     93   ShortestPath getShortestPath() override {
     94     KJ_SWITCH_ONEOF(state) {
     95       KJ_CASE_ONEOF(redirected, Redirected) {
     96         return &redirected.replacement;
     97       }
     98       KJ_CASE_ONEOF(e, Ended) {
     99         KJ_FAIL_REQUIRE("already called end()");
    100       }
    101       KJ_CASE_ONEOF(b, Borrowed) {
    102         KJ_FAIL_REQUIRE("can't call other methods while substream is active");
    103       }
    104       KJ_CASE_ONEOF(streaming, Streaming) {
                  // Lend our stream directly, capped at the bytes we have left. `streaming` refers into
                  // `state`, so the Streaming value is moved into a local before `state` is overwritten.
    105         auto& stream = streaming.stream;
    106         auto oldState = kj::mv(streaming);
    107         state = Borrowed { kj::mv(oldState) };
    108         return BorrowedStream { *this, stream, limit - completed };
    109       }
    110     }
    111     KJ_UNREACHABLE;
    112   }
    113
    114   void directEnd() override {
            // Local stand-in for an end() RPC, used when a KjToCapnpStreamAdapter writing to us is
            // destroyed. Failures are ignored because there is no caller to report them to.
    115     KJ_SWITCH_ONEOF(state) {
    116       KJ_CASE_ONEOF(redirected, Redirected) {
    117         // Ugh I guess we need to send a real end() request here.
    118         redirected.replacement.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){});
    119       }
    120       KJ_CASE_ONEOF(e, Ended) {
    121         // whatever
    122       }
    123       KJ_CASE_ONEOF(b, Borrowed) {
    124         // ... whatever.
    125       }
    126       KJ_CASE_ONEOF(streaming, Streaming) {
                  // Report the byte count to the callback (fire-and-forget), then give the stream back
                  // to the parent.
    127         auto req = streaming.callback.endedRequest(MessageSize {4, 0});
    128         req.setByteCount(completed);
    129         req.send().detach([](kj::Exception&&){});
    130         streaming.parent.returnStream(completed);
    131         state = Ended();
    132       }
    133     }
    134   }
    135
    136   // ---------------------------------------------------------------------------
    137   // implements ByteStream::Server RPC interface
    138
    139   kj::Maybe<kj::Promise<Capability::Client>> shortenPath() override {
            // Resolves once limitReached() has switched us to Redirected, so the RPC system can route
            // later calls straight to the replacement stream.
    140     return resolvePromise.addBranch()
    141         .then([this]() -> Capability::Client {
    142       return state.get<Redirected>().replacement;
    143     });
    144   }
    145
    146   kj::Promise<void> write(WriteContext context) override {
    147     auto params = context.getParams();
    148     auto data = params.getBytes();
    149
    150     KJ_SWITCH_ONEOF(state) {
    151       KJ_CASE_ONEOF(redirected, Redirected) {
    152         auto req = redirected.replacement.writeRequest(params.totalSize());
    153         req.setBytes(data);
    154         return req.send();
    155       }
    156       KJ_CASE_ONEOF(e, Ended) {
    157         KJ_FAIL_REQUIRE("already called end()");
    158       }
    159       KJ_CASE_ONEOF(b, Borrowed) {
    160         KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed");
    161       }
    162       KJ_CASE_ONEOF(streaming, Streaming) {
                  // Strict `<`: a write that lands exactly on the limit takes the else-branch so that
                  // limitReached() is triggered.
    163         if (completed + data.size() < limit) {
    164           completed += data.size();
    165           return streaming.stream.write(data.begin(), data.size());
    166         } else {
    167           // This write passes the limit.
                    // Only `remainder` bytes go to our stream. `leftover` (a view into the request
                    // params) is forwarded to the next stream once reachedLimit() has produced it.
    168           uint64_t remainder = limit - completed;
    169           auto leftover = data.slice(remainder, data.size());
    170           return streaming.stream.write(data.begin(), remainder)
    171               .then([this, leftover]() -> kj::Promise<void> {
    172             completed = limit;
    173             limitReached();
    174
    175             if (leftover.size() > 0) {
    176               // Need to forward the leftover bytes to the next stream.
                      // The MessageSize is only a size hint; the division rounds down to whole words.
    177               auto req = state.get<Redirected>().replacement.writeRequest(
    178                   MessageSize { 4 + leftover.size() / sizeof(capnp::word), 0 });
    179               req.setBytes(leftover);
    180               return req.send();
    181             } else {
    182               return kj::READY_NOW;
    183             }
    184           });
    185         }
    186       }
    187     }
    188     KJ_UNREACHABLE;
    189   }
    190
    191   kj::Promise<void> end(EndContext context) override {
    192     KJ_SWITCH_ONEOF(state) {
    193       KJ_CASE_ONEOF(redirected, Redirected) {
    194         return context.tailCall(redirected.replacement.endRequest(MessageSize {2,0}));
    195       }
    196       KJ_CASE_ONEOF(e, Ended) {
    197         KJ_FAIL_REQUIRE("already called end()");
    198       }
    199       KJ_CASE_ONEOF(b, Borrowed) {
    200         KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed");
    201       }
    202       KJ_CASE_ONEOF(streaming, Streaming) {
                  // Tell the callback how many bytes were written, hand the stream back to the parent,
                  // and complete when the callback's ended() call completes.
    203         auto req = streaming.callback.endedRequest(MessageSize {4, 0});
    204         req.setByteCount(completed);
    205         auto result = req.send().ignoreResult();
    206         streaming.parent.returnStream(completed);
    207         state = Ended();
    208         return result;
    209       }
    210     }
    211     KJ_UNREACHABLE;
    212   }
    213
    214   kj::Promise<void> getSubstream(GetSubstreamContext context) override {
    215     KJ_SWITCH_ONEOF(state) {
    216       KJ_CASE_ONEOF(redirected, Redirected) {
    217         auto params = context.getParams();
    218         auto req = redirected.replacement.getSubstreamRequest(params.totalSize());
    219         req.setCallback(params.getCallback());
    220         req.setLimit(params.getLimit());
    221         return context.tailCall(kj::mv(req));
    222       }
    223       KJ_CASE_ONEOF(e, Ended) {
    224         KJ_FAIL_REQUIRE("already called end()");
    225       }
    226       KJ_CASE_ONEOF(b, Borrowed) {
    227         KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed");
    228       }
    229       KJ_CASE_ONEOF(streaming, Streaming) {
                  // Lend our stream to a nested SubstreamImpl; we stay Borrowed until it calls
                  // returnStream().
                  // NOTE(review): the local `limit` below shadows the member of the same name, and the
                  // nested substream's limit is not clamped to our remaining `limit - completed`. If the
                  // nested substream writes past our budget, returnStream()'s
                  // KJ_ASSERT(completed <= limit) fails -- confirm callers never request more than what
                  // remains.
    230         auto params = context.getParams();
    231         auto callback = params.getCallback();
    232         auto limit = params.getLimit();
    233         context.releaseParams();
    234         auto results = context.getResults(MessageSize { 2, 1 });
    235         results.setSubstream(factory.streamSet.add(kj::heap<SubstreamImpl>(
    236             factory, *this, thisCap(), streaming.stream, kj::mv(callback), kj::mv(limit))));
    237         state = Borrowed { kj::mv(streaming) };
    238         return kj::READY_NOW;
    239       }
    240     }
    241     KJ_UNREACHABLE;
    242   }
    243
    244 private:
    245   ByteStreamFactory& factory;
    246
    247   struct Streaming {
              // Normal state: writing directly to `stream`, which `parent` has lent us.
    248     StreamServerBase& parent;
    249     capnp::ByteStream::Client ownParent;
    250     kj::AsyncOutputStream& stream;
    251     capnp::ByteStream::SubstreamCallback::Client callback;
    252   };
    253   struct Borrowed {
              // `stream` is lent out; the Streaming state is parked here until returnStream().
    254     Streaming originalState;
    255   };
    256   struct Redirected {
              // The limit was reached; every call is forwarded to `replacement`.
    257     capnp::ByteStream::Client replacement;
    258   };
    259   struct Ended {};
    260
    261   kj::OneOf<Streaming, Borrowed, Redirected, Ended> state;
    262
    263   uint64_t limit;
          // Maximum number of bytes written to `stream` before redirecting.
    264   uint64_t completed = 0;
          // Bytes counted so far against `limit`, including bytes reported back by borrowers.
    265
    266   kj::Own<kj::PromiseFulfiller<void>> resolveFulfiller;
    267   kj::ForkedPromise<void> resolvePromise;
          // Fulfilled by limitReached(); shortenPath() waits on branches of resolvePromise.
    268
    269   void limitReached() {
            // Switches from Streaming to Redirected. It asks the callback for the next stream (as a
            // pipelined reachedLimit() call), returns the lent stream to the parent while reporting
            // `limit` bytes, and resolves shortenPath(). Requires `state` to be Streaming.
    270     auto& streaming = state.get<Streaming>();
    271     auto next = streaming.callback.reachedLimitRequest(capnp::MessageSize {2,0})
    272         .send().getNext();
    273
    274     // Set the next stream as our replacement.
    275     streaming.parent.returnStream(limit);
    276     state = Redirected { kj::mv(next) };
    277     resolveFulfiller->fulfill();
    278   }
    279 };
    280 
    281 // =======================================================================================
    282 
    283 class ByteStreamFactory::CapnpToKjStreamAdapter final: public StreamServerBase {
    284   // Implements Cap'n Proto ByteStream as a wrapper around a KJ stream.
          //
          // `state` transitions:
          //   Own<PathProber>        -> probing the KJ stream for a shorter path; incoming calls wait
          //                             on PathProber::whenReady() and are then retried.
          //   Own<AsyncOutputStream> -> no shorter path was found; writes go straight to the KJ stream.
          //   ByteStream::Client     -> a shorter path was found; calls are forwarded to it.
          //   Borrowed               -> the KJ stream is lent out (getShortestPath() or a substream)
          //                             until returnStream() is called.
          //   Ended                  -> end() or directEnd() dropped the stream (also set briefly
          //                             inside PathProber::pumpToShorterPath()).
    285
    286   class SubstreamCallbackImpl;
    287
    288 public:
    289   class PathProber;
    290
    291   CapnpToKjStreamAdapter(ByteStreamFactory& factory,
    292                          kj::Own<kj::AsyncOutputStream> inner)
    293       : factory(factory),
    294         state(kj::heap<PathProber>(*this, kj::mv(inner))) {
    295     state.get<kj::Own<PathProber>>()->startProbing();
    296   }
          // Wraps `inner`, starting in the probing state. Probing starts only after `state` holds the
          // prober, because the probe may call back into pumpToShorterPath() synchronously.
    297
    298   CapnpToKjStreamAdapter(ByteStreamFactory& factory,
    299                          kj::Own<PathProber> pathProber)
    300       : factory(factory),
    301         state(kj::mv(pathProber)) {
    302     state.get<kj::Own<PathProber>>()->setNewParent(*this);
    303   }
          // Adopts a PathProber handed off by a previous adapter (see
          // SubstreamCallbackImpl::reachedLimit()) and re-targets it at this adapter. The prober's
          // existing task is reused; startProbing() is not called again.
    304
    305   // ---------------------------------------------------------------------------
    306   // implements StreamServerBase
    307
    308   void returnStream(uint64_t written) override {
            // `written` is unused: this adapter does not track a byte budget of its own (it lends its
            // stream with a limit of kj::maxValue).
    309     auto stream = kj::mv(state.get<Borrowed>().stream);
    310     state = kj::mv(stream);
    311   }
    312
    313   ShortestPath getShortestPath() override {
    314     // Called by KjToCapnpStreamAdapter when it has determined that its inner ByteStream::Client
    315     // actually points back to a CapnpToKjStreamAdapter in the same process. Returns the best
    316     // shortened path to use, or a promise that resolves when the shortest path is known.
    317
    318     KJ_SWITCH_ONEOF(state) {
    319       KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
    320         return prober->whenReady();
    321       }
    322       KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
    323         auto& streamRef = *kjStream;
    324         state = Borrowed { kj::mv(kjStream) };
    325         return StreamServerBase::BorrowedStream { *this, streamRef, kj::maxValue };
    326       }
    327       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
    328         return &capnpStream;
    329       }
    330       KJ_CASE_ONEOF(b, Borrowed) {
                  // The `{ break; }` recovery block, and the return after it, are reached only if
                  // KJ_FAIL_REQUIRE does not throw.
    331         KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
    332         return kj::Promise<void>(kj::READY_NOW);
    333       }
    334       KJ_CASE_ONEOF(e, Ended) {
    335         KJ_FAIL_REQUIRE("already ended") { break; }
    336         return kj::Promise<void>(kj::READY_NOW);
    337       }
    338     }
    339     KJ_UNREACHABLE;
    340   }
    341
    342   void directEnd() override {
            // Local stand-in for an end() RPC, used when a KjToCapnpStreamAdapter writing to us is
            // destroyed. Dropping the KJ stream (or the prober that owns it) is what closes it.
    343     KJ_SWITCH_ONEOF(state) {
    344       KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
    345         state = Ended();
    346       }
    347       KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
    348         state = Ended();
    349       }
    350       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
    351         // Ugh I guess we need to send a real end() request here.
    352         capnpStream.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){});
    353       }
    354       KJ_CASE_ONEOF(b, Borrowed) {
    355         // Fine, ignore.
    356       }
    357       KJ_CASE_ONEOF(e, Ended) {
    358         // Fine, ignore.
    359       }
    360     }
    361   }
    362
    363   // ---------------------------------------------------------------------------
    364   // PathProber
    365
    366   class PathProber final: public kj::AsyncInputStream {
            // A stand-in input stream that is pumped into `inner` only to discover, through
            // `inner->tryPumpFrom()`, whether the data ultimately reaches a KjToCapnpStreamAdapter. If
            // it does, that adapter calls pumpToShorterPath() and the parent adapter is spliced
            // directly onto a substream of the far end. Otherwise the probe pretends to hit EOF at
            // once, and the parent falls back to writing to `inner` directly.
    367   public:
    368     PathProber(CapnpToKjStreamAdapter& parent, kj::Own<kj::AsyncOutputStream> inner,
    369                kj::PromiseFulfillerPair<void> paf = kj::newPromiseAndFulfiller<void>())
    370         : parent(parent), inner(kj::mv(inner)),
    371           readyPromise(paf.promise.fork()),
    372           readyFulfiller(kj::mv(paf.fulfiller)),
    373           task(nullptr) {}
    374
    375     void startProbing() {
              // Starts the probing pump; see probeForShorterPath().
    376       task = probeForShorterPath();
    377     }
    378
    379     void setNewParent(CapnpToKjStreamAdapter& newParent) {
              // Attaches a prober that pumpToShorterPath() detached from its previous parent, and
              // gives the new parent a fresh ready promise to wait on.
    380       KJ_ASSERT(parent == nullptr);
    381       parent = newParent;
    382       auto paf = kj::newPromiseAndFulfiller<void>();
    383       readyPromise = paf.promise.fork();
    384       readyFulfiller = kj::mv(paf.fulfiller);
    385     }
    386
    387     kj::Promise<void> whenReady() {
              // Resolves once the parent's `state` has left the probing state, or rejects if probing
              // threw.
    388       return readyPromise.addBranch();
    389     }
    390
    391     kj::Promise<uint64_t> pumpToShorterPath(capnp::ByteStream::Client target, uint64_t limit) {
    392       // If our probe succeeds in finding a KjToCapnpStreamAdapter somewhere down the stack, that
    393       // will call this method to provide the shortened path.
              //
              // Returns the byte count that the probing pump should report once the substream
              // finishes.
    394
    395       KJ_IF_MAYBE(currentParent, parent) {
    396         parent = nullptr;
    397
                // Take ownership of ourselves out of the parent; SubstreamCallbackImpl holds it from
                // here on.
    398         auto self = kj::mv(currentParent->state.get<kj::Own<PathProber>>());
    399         currentParent->state = Ended();  // temporary, we'll set this properly below
    400         KJ_ASSERT(self.get() == this);
    401
    402         // Open a substream on the target stream.
    403         auto req = target.getSubstreamRequest();
    404         req.setLimit(limit);
    405         auto paf = kj::newPromiseAndFulfiller<uint64_t>();
    406         req.setCallback(kj::heap<SubstreamCallbackImpl>(currentParent->factory,
    407             kj::mv(self), kj::mv(paf.fulfiller), limit));
    408
    409         // Now we hook up the incoming stream adapter to point directly to this substream, yay.
    410         currentParent->state = req.send().getSubstream();
    411
    412         // Let the original CapnpToKjStreamAdapter know that it's safe to handle incoming requests.
    413         readyFulfiller->fulfill();
    414
    415         // It's now up to the SubstreamCallbackImpl to signal when the pump is done.
    416         return kj::mv(paf.promise);
    417       } else {
    418         // We already completed a path-shortening. Probably SubstreamCallbackImpl::ended() was
    419         // eventually called, meaning the substream was ended without redirecting back to us. So,
    420         // we're at EOF.
    421         return uint64_t(0);
    422       }
    423     }
    424
    425     kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    426       // If this is called, it means the tryPumpFrom() in probeForShorterPath() eventually invoked
    427       // code that tries to read manually from the source. We don't know what this code is doing
    428       // exactly, but we do know for sure that the endpoint is not a KjToCapnpStreamAdapter, so
    429       // we can't optimize. Instead, we pretend that we immediately hit EOF, ending the pump. This
    430       // works because pumps do not propagate EOF -- the destination can still receive further
    431       // writes and pumps. Basically our probing pump becomes a no-op, and then we revert to having
    432       // each write() RPC directly call write() on the inner stream.
    433       return size_t(0);
    434     }
    435
    436     kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
    437       // Call the stream's `tryPumpFrom()` as a way to discover where the data will eventually go,
    438       // in hopes that we find we can shorten the path.
    439       KJ_IF_MAYBE(promise, output.tryPumpFrom(*this, amount)) {
    440         // tryPumpFrom() returned non-null. Either it called `tryRead()` or `pumpTo()` (see
    441         // below), or it plans to do so in the future.
    442         return kj::mv(*promise);
    443       } else {
    444         // There is no shorter path. As with tryRead(), we pretend we get immediate EOF.
    445         return uint64_t(0);
    446       }
    447     }
    448
    449   private:
    450     kj::Maybe<CapnpToKjStreamAdapter&> parent;
            // The adapter whose `state` owns us. Null while detached, i.e. after pumpToShorterPath()
            // or once probing has fallen back to `inner`.
    451     kj::Own<kj::AsyncOutputStream> inner;
            // The real KJ output stream; moved into the parent's `state` if no shorter path is found.
    452     kj::ForkedPromise<void> readyPromise;
    453     kj::Own<kj::PromiseFulfiller<void>> readyFulfiller;
    454     kj::Promise<void> task;
            // The probing pump started by startProbing().
    455
    456     friend class SubstreamCallbackImpl;
    457
    458     kj::Promise<void> probeForShorterPath() {
              // Pumps this stand-in input stream into `inner` to trigger discovery. When that pump
              // finishes (its byte count `actual` is unused) and we are still owned by a parent in the
              // probing state, the parent switches to writing to `inner` directly.
    459       return kj::evalNow([&]() -> kj::Promise<uint64_t> {
    460         return pumpTo(*inner, kj::maxValue);
    461       }).then([this](uint64_t actual) {
    462         KJ_IF_MAYBE(currentParent, parent) {
    463           KJ_IF_MAYBE(prober, currentParent->state.tryGet<kj::Own<PathProber>>()) {
    464             // Either we didn't find any shorter path at all during probing and faked an EOF
    465             // to get out of the probe (see comments in tryRead(), or we DID find a shorter path,
    466             // completed a pumpTo() using a substream, and that substream redirected back to us,
    467             // and THEN we couldn't find any further shorter paths for subsequent pumps.
    468
    469             // HACK: If we overwrite the Probing state now, we'll delete ourselves and delete
    470             //   this task promise, which is an error... let the event loop do it later by
    471             //   detaching.
    472             task.attach(kj::mv(*prober)).detach([](kj::Exception&&){});
    473             parent = nullptr;
    474
    475             // OK, now we can change the parent state and signal it to proceed.
    476             currentParent->state = kj::mv(inner);
    477             readyFulfiller->fulfill();
    478           }
    479         }
    480       }).eagerlyEvaluate([this](kj::Exception&& exception) mutable {
    481         // Something threw, so propagate the exception to break the parent.
    482         readyFulfiller->reject(kj::mv(exception));
    483       });
    484     }
    485   };
    486
    487 protected:
    488   // ---------------------------------------------------------------------------
    489   // implements ByteStream::Server RPC interface
    490
    491   kj::Maybe<kj::Promise<Capability::Client>> shortenPath() override {
    492     return shortenPathImpl();
    493   }
    494   kj::Promise<Capability::Client> shortenPathImpl() {
    495     // Called by RPC implementation to find out if a shorter path presents itself.
    496     KJ_SWITCH_ONEOF(state) {
    497       KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
    498         return prober->whenReady().then([this]() {
    499           KJ_ASSERT(!state.is<kj::Own<PathProber>>());
    500           return shortenPathImpl();
    501         });
    502       }
    503       KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
    504         // No shortening possible. Pretend we never resolve so that calls continue to be routed
    505         // to us forever.
    506         return kj::NEVER_DONE;
    507       }
    508       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
    509         return Capability::Client(capnpStream);
    510       }
    511       KJ_CASE_ONEOF(b, Borrowed) {
    512         KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
    513         return kj::NEVER_DONE;
    514       }
    515       KJ_CASE_ONEOF(e, Ended) {
    516         // No shortening possible. Pretend we never resolve so that calls continue to be routed
    517         // to us forever.
    518         return kj::NEVER_DONE;
    519       }
    520     }
    521     KJ_UNREACHABLE;
    522   }
    523
    524   kj::Promise<void> write(WriteContext context) override {
            // Like end() and getSubstream() below: while probing, wait for the prober to settle and
            // retry; otherwise dispatch on the current state.
    525     KJ_SWITCH_ONEOF(state) {
    526       KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
    527         return prober->whenReady().then([this, context]() mutable {
    528           KJ_ASSERT(!state.is<kj::Own<PathProber>>());
    529           return write(context);
    530         });
    531       }
    532       KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
    533         auto data = context.getParams().getBytes();
    534         return kjStream->write(data.begin(), data.size());
    535       }
    536       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
    537         auto params = context.getParams();
    538         auto req = capnpStream.writeRequest(params.totalSize());
    539         req.setBytes(params.getBytes());
    540         return req.send();
    541       }
    542       KJ_CASE_ONEOF(b, Borrowed) {
    543         KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
    544         return kj::READY_NOW;
    545       }
    546       KJ_CASE_ONEOF(e, Ended) {
    547         KJ_FAIL_REQUIRE("already called end()") { break; }
    548         return kj::READY_NOW;
    549       }
    550     }
    551     KJ_UNREACHABLE;
    552   }
    553
    554   kj::Promise<void> end(EndContext context) override {
    555     KJ_SWITCH_ONEOF(state) {
    556       KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
    557         return prober->whenReady().then([this, context]() mutable {
    558           KJ_ASSERT(!state.is<kj::Own<PathProber>>());
    559           return end(context);
    560         });
    561       }
    562       KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
    563         // TODO(someday): When KJ adds a proper .end() call, use it here. For now, we must
    564         //   drop the stream to close it.
    565         state = Ended();
    566         return kj::READY_NOW;
    567       }
    568       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
    569         auto params = context.getParams();
    570         auto req = capnpStream.endRequest(params.totalSize());
    571         return context.tailCall(kj::mv(req));
    572       }
    573       KJ_CASE_ONEOF(b, Borrowed) {
    574         KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
    575         return kj::READY_NOW;
    576       }
    577       KJ_CASE_ONEOF(e, Ended) {
    578         KJ_FAIL_REQUIRE("already called end()") { break; }
    579         return kj::READY_NOW;
    580       }
    581     }
    582     KJ_UNREACHABLE;
    583   }
    584
    585   kj::Promise<void> getSubstream(GetSubstreamContext context) override {
    586     KJ_SWITCH_ONEOF(state) {
    587       KJ_CASE_ONEOF(prober, kj::Own<PathProber>) {
    588         return prober->whenReady().then([this, context]() mutable {
    589           KJ_ASSERT(!state.is<kj::Own<PathProber>>());
    590           return getSubstream(context);
    591         });
    592       }
    593       KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) {
                  // Lend the KJ stream to a new SubstreamImpl; we stay Borrowed until it calls
                  // returnStream().
    594         auto params = context.getParams();
    595         auto callback = params.getCallback();
    596         uint64_t limit = params.getLimit();
    597         context.releaseParams();
    598
    599         auto results = context.initResults(MessageSize {2, 1});
    600         results.setSubstream(factory.streamSet.add(kj::heap<SubstreamImpl>(
    601             factory, *this, thisCap(), *kjStream, kj::mv(callback), kj::mv(limit))));
    602         state = Borrowed { kj::mv(kjStream) };
    603         return kj::READY_NOW;
    604       }
    605       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) {
    606         auto params = context.getParams();
    607         auto req = capnpStream.getSubstreamRequest(params.totalSize());
    608         req.setCallback(params.getCallback());
    609         req.setLimit(params.getLimit());
    610         return context.tailCall(kj::mv(req));
    611       }
    612       KJ_CASE_ONEOF(b, Borrowed) {
    613         KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; }
    614         return kj::READY_NOW;
    615       }
    616       KJ_CASE_ONEOF(e, Ended) {
    617         KJ_FAIL_REQUIRE("already called end()") { break; }
    618         return kj::READY_NOW;
    619       }
    620     }
    621     KJ_UNREACHABLE;
    622   }
    623
    624 private:
    625   ByteStreamFactory& factory;
    626
    627   struct Borrowed { kj::Own<kj::AsyncOutputStream> stream; };
          // The KJ stream is lent out; it is parked here until returnStream().
    628   struct Ended {};
    629
    630   kj::OneOf<kj::Own<PathProber>, kj::Own<kj::AsyncOutputStream>,
    631             capnp::ByteStream::Client, Borrowed, Ended> state;
    632
    633   class SubstreamCallbackImpl final: public capnp::ByteStream::SubstreamCallback::Server {
            // Callback that pumpToShorterPath() passes to the target's getSubstream(). It owns the
            // PathProber while the substream is in use. It completes the probe's pending pump
            // (`originalPumpfulfiller`) when the substream either ends or reaches its limit.
    634   public:
    635     SubstreamCallbackImpl(ByteStreamFactory& factory,
    636                           kj::Own<PathProber> pathProber,
    637                           kj::Own<kj::PromiseFulfiller<uint64_t>> originalPumpfulfiller,
    638                           uint64_t originalPumpLimit)
    639         : factory(factory),
    640           pathProber(kj::mv(pathProber)),
    641           originalPumpfulfiller(kj::mv(originalPumpfulfiller)),
    642           originalPumpLimit(originalPumpLimit) {}
    643
    644     ~SubstreamCallbackImpl() noexcept(false) {
              // If dropped before either callback arrives, fail the pending pump rather than leave it
              // hanging.
    645       if (!done) {
    646         originalPumpfulfiller->reject(KJ_EXCEPTION(DISCONNECTED,
    647             "stream disconnected because SubstreamCallbackImpl was never called back"));
    648       }
    649     }
    650
    651     kj::Promise<void> ended(EndedContext context) override {
    652       KJ_REQUIRE(!done);
    653       uint64_t actual = context.getParams().getByteCount();
    654       KJ_REQUIRE(actual <= originalPumpLimit);
    655
    656       done = true;
    657
    658       // EOF before pump completed. Signal a short pump.
              // (The byte count re-read here is the same value as `actual`.)
    659       originalPumpfulfiller->fulfill(context.getParams().getByteCount());
    660
    661       // Give the original pump task a chance to finish up.
    662       return pathProber->task.attach(kj::mv(pathProber));
    663     }
    664
    665     kj::Promise<void> reachedLimit(ReachedLimitContext context) override {
    666       KJ_REQUIRE(!done);
    667       done = true;
    668
    669       // Allow the shortened stream to redirect back to our original underlying stream.
    670       auto results = context.getResults(capnp::MessageSize { 4, 1 });
    671       results.setNext(factory.streamSet.add(
    672           kj::heap<CapnpToKjStreamAdapter>(factory, kj::mv(pathProber))));
    673
    674       // The full pump completed. Note that it's important that we fulfill this after the
    675       // PathProber has been attached to the new CapnpToKjStreamAdapter, which will have happened
    676       // in CapnpToKjStreamAdapter's constructor, which calls pathProber->setNewParent().
    677       originalPumpfulfiller->fulfill(kj::cp(originalPumpLimit));
    678
    679       return kj::READY_NOW;
    680     }
    681
    682   private:
    683     ByteStreamFactory& factory;
    684     kj::Own<PathProber> pathProber;
    685     kj::Own<kj::PromiseFulfiller<uint64_t>> originalPumpfulfiller;
    686     uint64_t originalPumpLimit;
    687     bool done = false;
            // Set once ended() or reachedLimit() has run; each may run at most once, and not both.
    688   };
    689 };
    690 
    691 // =======================================================================================
    692 
    693 class ByteStreamFactory::KjToCapnpStreamAdapter final: public kj::AsyncOutputStream {
    694 public:
    695   KjToCapnpStreamAdapter(ByteStreamFactory& factory, capnp::ByteStream::Client innerParam)
    696       : factory(factory),
    697         inner(kj::mv(innerParam)),
    698         findShorterPathTask(findShorterPath(inner).fork()) {}
    699 
    700   ~KjToCapnpStreamAdapter() noexcept(false) {
    701     // HACK: KJ streams are implicitly ended on destruction, but the RPC stream needs a call. We
    702     //   use a detached promise for now, which is probably OK since capabilities are refcounted and
    703     //   asynchronously destroyed anyway.
    704     // TODO(cleanup): Fix this when KJ streads add an explicit end() method.
    705     KJ_IF_MAYBE(o, optimized) {
    706       o->directEnd();
    707     } else {
    708       inner.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){});
    709     }
    710   }
    711 
    712   kj::Promise<void> write(const void* buffer, size_t size) override {
    713     KJ_SWITCH_ONEOF(getShortestPath()) {
    714       KJ_CASE_ONEOF(promise, kj::Promise<void>) {
    715         return promise.then([this,buffer,size]() {
    716           return write(buffer, size);
    717         });
    718       }
    719       KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
    720         auto limit = kj::min(kjStream.limit, MAX_BYTES_PER_WRITE);
    721         if (size <= limit) {
    722           auto promise = kjStream.stream.write(buffer, size);
    723           return promise.then([kjStream,size]() mutable {
    724             kjStream.lender.returnStream(size);
    725           });
    726         } else {
    727           auto promise = kjStream.stream.write(buffer, limit);
    728           return promise.then([this,kjStream,buffer,size,limit]() mutable {
    729             kjStream.lender.returnStream(limit);
    730             return write(reinterpret_cast<const byte*>(buffer) + limit,
    731                          size - limit);
    732           });
    733         }
    734       }
    735       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
    736         if (size <= MAX_BYTES_PER_WRITE) {
    737           auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word), 0 });
    738           req.setBytes(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size));
    739           return req.send();
    740         } else {
    741           auto req = capnpStream->writeRequest(
    742               MessageSize { 8 + MAX_BYTES_PER_WRITE / sizeof(word), 0 });
    743           req.setBytes(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), MAX_BYTES_PER_WRITE));
    744           return req.send().then([this,buffer,size]() mutable {
    745             return write(reinterpret_cast<const byte*>(buffer) + MAX_BYTES_PER_WRITE,
    746                          size - MAX_BYTES_PER_WRITE);
    747           });
    748         }
    749       }
    750     }
    751     KJ_UNREACHABLE;
    752   }
    753 
    754   kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override {
    755     KJ_SWITCH_ONEOF(getShortestPath()) {
    756       KJ_CASE_ONEOF(promise, kj::Promise<void>) {
    757         return promise.then([this,pieces]() {
    758           return write(pieces);
    759         });
    760       }
    761       KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
    762         size_t size = 0;
    763         for (auto& piece: pieces) { size += piece.size(); }
    764         auto limit = kj::min(kjStream.limit, MAX_BYTES_PER_WRITE);
    765         if (size <= limit) {
    766           auto promise = kjStream.stream.write(pieces);
    767           return promise.then([kjStream,size]() mutable {
    768             kjStream.lender.returnStream(size);
    769           });
    770         } else {
    771           // ughhhhhhhhhh, we need to split the pieces.
    772           return splitAndWrite(pieces, kjStream.limit,
    773               [kjStream,limit](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) mutable {
    774             return kjStream.stream.write(pieces).then([kjStream,limit]() mutable {
    775               kjStream.lender.returnStream(limit);
    776             });
    777           });
    778         }
    779       }
    780       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
    781         auto writePieces = [capnpStream](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
    782           size_t size = 0;
    783           for (auto& piece: pieces) size += piece.size();
    784           auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word), 0 });
    785           auto out = req.initBytes(size);
    786           byte* ptr = out.begin();
    787           for (auto& piece: pieces) {
    788             memcpy(ptr, piece.begin(), piece.size());
    789             ptr += piece.size();
    790           }
    791           KJ_ASSERT(ptr == out.end());
    792           return req.send();
    793         };
    794 
    795         size_t size = 0;
    796         for (auto& piece: pieces) size += piece.size();
    797         if (size <= MAX_BYTES_PER_WRITE) {
    798           return writePieces(pieces);
    799         } else {
    800           // ughhhhhhhhhh, we need to split the pieces.
    801           return splitAndWrite(pieces, MAX_BYTES_PER_WRITE, writePieces);
    802         }
    803       }
    804     }
    805     KJ_UNREACHABLE;
    806   }
    807 
    808   kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom(
    809       kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override {
    810     KJ_IF_MAYBE(rpc, kj::dynamicDowncastIfAvailable<CapnpToKjStreamAdapter::PathProber>(input)) {
    811       // Oh interesting, it turns we're hosting an incoming ByteStream which is pumping to this
    812       // outgoing ByteStream. We can let the Cap'n Proto RPC layer know that it can shorten the
    813       // path from one to the other.
    814       return rpc->pumpToShorterPath(inner, amount);
    815     } else {
    816       return pumpLoop(input, 0, amount);
    817     }
    818   }
    819 
    820   kj::Promise<void> whenWriteDisconnected() override {
    821     return findShorterPathTask.addBranch();
    822   }
    823 
    824 private:
    825   ByteStreamFactory& factory;
    826   capnp::ByteStream::Client inner;
    827   kj::Maybe<StreamServerBase&> optimized;
    828 
    829   kj::ForkedPromise<void> findShorterPathTask;
    830   // This serves two purposes:
    831   // 1. Waits for the capability to resolve (if it is a promise), and then shortens the path if
    832   //    possible.
    833   // 2. Implements whenWriteDisconnected().
    834 
    835   kj::Promise<void> findShorterPath(capnp::ByteStream::Client& capnpClient) {
    836     // If the capnp stream turns out to resolve back to this process, shorten the path.
    837     // Also, implement whenWriteDisconnected() based on this.
    838     return factory.streamSet.getLocalServer(capnpClient)
    839         .then([this](kj::Maybe<capnp::ByteStream::Server&> server) -> kj::Promise<void> {
    840       KJ_IF_MAYBE(s, server) {
    841         // Yay, we discovered that the ByteStream actually points back to a local KJ stream.
    842         // We can use this to shorten the path by skipping the RPC machinery.
    843         return findShorterPath(kj::downcast<StreamServerBase>(*s));
    844       } else {
    845         // The capability is fully-resolved. This suggests that the remote implementation is
    846         // NOT a CapnpToKjStreamAdapter at all, because CapnpToKjStreamAdapter is designed to
    847         // always look like a promise. It's some other implementation that doesn't present
    848         // itself as a promise. We have no way to detect when it is disconnected.
    849         return kj::NEVER_DONE;
    850       }
    851     }, [](kj::Exception&& e) -> kj::Promise<void> {
    852       // getLocalServer() thrown when the capability is a promise cap that rejects. We can
    853       // use this to implement whenWriteDisconnected().
    854       //
    855       // (Note that because this exception handler is passed to the .then(), it does NOT catch
    856       // eoxceptions thrown by the success handler immediately above it. This handler will ONLY
    857       // catch exceptions from getLocalServer() itself.)
    858       return kj::READY_NOW;
    859     });
    860   }
    861 
    862   kj::Promise<void> findShorterPath(StreamServerBase& capnpServer) {
    863     // We found a shorter path back to this process. Record it.
    864     optimized = capnpServer;
    865 
    866     KJ_SWITCH_ONEOF(capnpServer.getShortestPath()) {
    867       KJ_CASE_ONEOF(promise, kj::Promise<void>) {
    868         return promise.then([this,&capnpServer]() {
    869           return findShorterPath(capnpServer);
    870         });
    871       }
    872       KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
    873         // The ByteStream::Server wraps a regular KJ stream that does not wrap another capnp
    874         // stream.
    875         if (kjStream.limit < (uint64_t)kj::maxValue / 2) {
    876           // But it isn't wrapping that stream forever. Eventually it plans to redirect back to
    877           // some other stream. So, let's wait for that, and possibly shorten again.
    878           kjStream.lender.returnStream(0);
    879           return KJ_ASSERT_NONNULL(capnpServer.shortenPath())
    880               .then([this, &capnpServer](auto&&) {
    881             return findShorterPath(capnpServer);
    882           });
    883         } else {
    884           // This KJ stream is (effectively) the permanent endpoint. We can't get any shorter
    885           // from here. All we want to do now is watch for disconnect.
    886           auto promise = kjStream.stream.whenWriteDisconnected();
    887           kjStream.lender.returnStream(0);
    888           return promise;
    889         }
    890       }
    891       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
    892         return findShorterPath(*capnpStream);
    893       }
    894     }
    895     KJ_UNREACHABLE;
    896   }
    897 
    898   StreamServerBase::ShortestPath getShortestPath() {
    899     KJ_IF_MAYBE(o, optimized) {
    900       return o->getShortestPath();
    901     } else {
    902       return &inner;
    903     }
    904   }
    905 
    906   kj::Promise<uint64_t> pumpLoop(kj::AsyncInputStream& input,
    907                                  uint64_t completed, uint64_t remaining) {
    908     if (remaining == 0) return completed;
    909 
    910     KJ_SWITCH_ONEOF(getShortestPath()) {
    911       KJ_CASE_ONEOF(promise, kj::Promise<void>) {
    912         return promise.then([this,&input,completed,remaining]() {
    913           return pumpLoop(input,completed,remaining);
    914         });
    915       }
    916       KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) {
    917         // Oh hell yes, this capability actually points back to a stream in our own thread. We can
    918         // stop sending RPCs and just pump directly.
    919 
    920         if (remaining <= kjStream.limit) {
    921           return input.pumpTo(kjStream.stream, remaining)
    922               .then([kjStream,completed](uint64_t actual) {
    923             kjStream.lender.returnStream(actual);
    924             return actual + completed;
    925           });
    926         } else {
    927           auto promise = input.pumpTo(kjStream.stream, kjStream.limit);
    928           return promise.then([this,&input,completed,remaining,kjStream]
    929                               (uint64_t actual) mutable -> kj::Promise<uint64_t> {
    930             kjStream.lender.returnStream(actual);
    931             if (actual < kjStream.limit) {
    932               // EOF reached.
    933               return completed + actual;
    934             } else {
    935               return pumpLoop(input, completed + actual, remaining - actual);
    936             }
    937           });
    938         }
    939       }
    940       KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) {
    941         // Pumping from some other kind of steram. Optimize the pump by reading from the input
    942         // directly into outgoing RPC messages.
    943         size_t size = kj::min(remaining, 8192);
    944         auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word) });
    945 
    946         auto orphanage = Orphanage::getForMessageContaining(
    947             capnp::ByteStream::WriteParams::Builder(req));
    948 
    949         auto buffer = orphanage.newOrphan<Data>(size);
    950 
    951         struct WriteRequestAndBuffer {
    952           // The order of construction/destruction of lambda captures is unspecified, but we care
    953           // about ordering between these two things that we want to capture, so... we need a
    954           // struct.
    955           StreamingRequest<capnp::ByteStream::WriteParams> request;
    956           Orphan<Data> buffer;  // points into `request`...
    957         };
    958 
    959         WriteRequestAndBuffer wrab = { kj::mv(req), kj::mv(buffer) };
    960 
    961         return input.tryRead(wrab.buffer.get().begin(), 1, size)
    962             .then([this, &input, completed, remaining, size, wrab = kj::mv(wrab)]
    963                   (size_t actual) mutable -> kj::Promise<uint64_t> {
    964           if (actual == 0) {
    965             return completed;
    966           } if (actual < size) {
    967             wrab.buffer.truncate(actual);
    968           }
    969 
    970           wrab.request.adoptBytes(kj::mv(wrab.buffer));
    971           return wrab.request.send()
    972               .then([this, &input, completed, remaining, actual]() {
    973             return pumpLoop(input, completed + actual, remaining - actual);
    974           });
    975         });
    976       }
    977     }
    978     KJ_UNREACHABLE;
    979   }
    980 
    981   template <typename WritePieces>
    982   kj::Promise<void> splitAndWrite(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces,
    983                                   size_t limit, WritePieces&& writeFirstPieces) {
    984     size_t splitByte = limit;
    985     size_t splitPiece = 0;
    986     while (pieces[splitPiece].size() <= splitByte) {
    987       splitByte -= pieces[splitPiece].size();
    988       ++splitPiece;
    989     }
    990 
    991     if (splitByte == 0) {
    992       // Oh thank god, the split is between two pieces.
    993       auto rest = pieces.slice(splitPiece, pieces.size());
    994       return writeFirstPieces(pieces.slice(0, splitPiece))
    995           .then([this,rest]() mutable {
    996         return write(rest);
    997       });
    998     } else {
    999       // FUUUUUUUU---- we need to split one of the pieces in two.
   1000       auto left = kj::heapArray<kj::ArrayPtr<const byte>>(splitPiece + 1);
   1001       auto right = kj::heapArray<kj::ArrayPtr<const byte>>(pieces.size() - splitPiece);
   1002       for (auto i: kj::zeroTo(splitPiece)) {
   1003         left[i] = pieces[i];
   1004       }
   1005       for (auto i: kj::zeroTo(right.size())) {
   1006         right[i] = pieces[splitPiece + i];
   1007       }
   1008       left.back() = pieces[splitPiece].slice(0, splitByte);
   1009       right.front() = pieces[splitPiece].slice(splitByte, pieces[splitPiece].size());
   1010 
   1011       return writeFirstPieces(left).attach(kj::mv(left))
   1012           .then([this,right=kj::mv(right)]() mutable {
   1013         return write(right).attach(kj::mv(right));
   1014       });
   1015     }
   1016   }
   1017 };
   1018 
   1019 // =======================================================================================
   1020 
   1021 capnp::ByteStream::Client ByteStreamFactory::kjToCapnp(kj::Own<kj::AsyncOutputStream> kjStream) {
   1022   return streamSet.add(kj::heap<CapnpToKjStreamAdapter>(*this, kj::mv(kjStream)));
   1023 }
   1024 
   1025 kj::Own<kj::AsyncOutputStream> ByteStreamFactory::capnpToKj(capnp::ByteStream::Client capnpStream) {
   1026   return kj::heap<KjToCapnpStreamAdapter>(*this, kj::mv(capnpStream));
   1027 }
   1028 
   1029 }  // namespace capnp