byte-stream.c++ (40034B)
1 // Copyright (c) 2019 Cloudflare, Inc. and contributors 2 // Licensed under the MIT License: 3 // 4 // Permission is hereby granted, free of charge, to any person obtaining a copy 5 // of this software and associated documentation files (the "Software"), to deal 6 // in the Software without restriction, including without limitation the rights 7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 // copies of the Software, and to permit persons to whom the Software is 9 // furnished to do so, subject to the following conditions: 10 // 11 // The above copyright notice and this permission notice shall be included in 12 // all copies or substantial portions of the Software. 13 // 14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 // THE SOFTWARE. 21 22 #include "byte-stream.h" 23 #include <kj/one-of.h> 24 #include <kj/debug.h> 25 26 namespace capnp { 27 28 const uint MAX_BYTES_PER_WRITE = 1 << 16; 29 30 class ByteStreamFactory::StreamServerBase: public capnp::ByteStream::Server { 31 public: 32 virtual void returnStream(uint64_t written) = 0; 33 // Called after the StreamServerBase's internal kj::AsyncOutputStream has been borrowed, to 34 // indicate that the borrower is done. 35 // 36 // A stream becomes borrowed either when getShortestPath() returns a BorrowedStream, or when 37 // a SubstreamImpl is constructed wrapping an existing stream. 38 39 struct BorrowedStream { 40 // Represents permission to use the StreamServerBase's inner AsyncOutputStream directly, up 41 // to some limit of bytes written. 
42 43 StreamServerBase& lender; 44 kj::AsyncOutputStream& stream; 45 uint64_t limit; 46 }; 47 48 typedef kj::OneOf<kj::Promise<void>, capnp::ByteStream::Client*, BorrowedStream> ShortestPath; 49 50 virtual ShortestPath getShortestPath() = 0; 51 // Called by KjToCapnpStreamAdapter when it has determined that its inner ByteStream::Client 52 // actually points back to a StreamServerBase in the same process created by the same 53 // ByteStreamFactory. Returns the best shortened path to use, or a promise that resolves when the 54 // shortest path is known. 55 56 virtual void directEnd() = 0; 57 // Called by KjToCapnpStreamAdapter's destructor when it has determined that its inner 58 // ByteStream::Client actually points back to a StreamServerBase in the same process created by 59 // the same ByteStreamFactory. Since destruction of a KJ stream signals EOF, we need to propagate 60 // that by destroying our underlying stream. 61 // TODO(cleanup): When KJ streams evolve an end() method, this can go away. 
62 }; 63 64 class ByteStreamFactory::SubstreamImpl final: public StreamServerBase { 65 public: 66 SubstreamImpl(ByteStreamFactory& factory, 67 StreamServerBase& parent, 68 capnp::ByteStream::Client ownParent, 69 kj::AsyncOutputStream& stream, 70 capnp::ByteStream::SubstreamCallback::Client callback, 71 uint64_t limit, 72 kj::PromiseFulfillerPair<void> paf = kj::newPromiseAndFulfiller<void>()) 73 : factory(factory), 74 state(Streaming {parent, kj::mv(ownParent), stream, kj::mv(callback)}), 75 limit(limit), 76 resolveFulfiller(kj::mv(paf.fulfiller)), 77 resolvePromise(paf.promise.fork()) {} 78 79 // --------------------------------------------------------------------------- 80 // implements StreamServerBase 81 82 void returnStream(uint64_t written) override { 83 completed += written; 84 KJ_ASSERT(completed <= limit); 85 auto borrowed = kj::mv(state.get<Borrowed>()); 86 state = kj::mv(borrowed.originalState); 87 88 if (completed == limit) { 89 limitReached(); 90 } 91 } 92 93 ShortestPath getShortestPath() override { 94 KJ_SWITCH_ONEOF(state) { 95 KJ_CASE_ONEOF(redirected, Redirected) { 96 return &redirected.replacement; 97 } 98 KJ_CASE_ONEOF(e, Ended) { 99 KJ_FAIL_REQUIRE("already called end()"); 100 } 101 KJ_CASE_ONEOF(b, Borrowed) { 102 KJ_FAIL_REQUIRE("can't call other methods while substream is active"); 103 } 104 KJ_CASE_ONEOF(streaming, Streaming) { 105 auto& stream = streaming.stream; 106 auto oldState = kj::mv(streaming); 107 state = Borrowed { kj::mv(oldState) }; 108 return BorrowedStream { *this, stream, limit - completed }; 109 } 110 } 111 KJ_UNREACHABLE; 112 } 113 114 void directEnd() override { 115 KJ_SWITCH_ONEOF(state) { 116 KJ_CASE_ONEOF(redirected, Redirected) { 117 // Ugh I guess we need to send a real end() request here. 118 redirected.replacement.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){}); 119 } 120 KJ_CASE_ONEOF(e, Ended) { 121 // whatever 122 } 123 KJ_CASE_ONEOF(b, Borrowed) { 124 // ... whatever. 
125 } 126 KJ_CASE_ONEOF(streaming, Streaming) { 127 auto req = streaming.callback.endedRequest(MessageSize {4, 0}); 128 req.setByteCount(completed); 129 req.send().detach([](kj::Exception&&){}); 130 streaming.parent.returnStream(completed); 131 state = Ended(); 132 } 133 } 134 } 135 136 // --------------------------------------------------------------------------- 137 // implements ByteStream::Server RPC interface 138 139 kj::Maybe<kj::Promise<Capability::Client>> shortenPath() override { 140 return resolvePromise.addBranch() 141 .then([this]() -> Capability::Client { 142 return state.get<Redirected>().replacement; 143 }); 144 } 145 146 kj::Promise<void> write(WriteContext context) override { 147 auto params = context.getParams(); 148 auto data = params.getBytes(); 149 150 KJ_SWITCH_ONEOF(state) { 151 KJ_CASE_ONEOF(redirected, Redirected) { 152 auto req = redirected.replacement.writeRequest(params.totalSize()); 153 req.setBytes(data); 154 return req.send(); 155 } 156 KJ_CASE_ONEOF(e, Ended) { 157 KJ_FAIL_REQUIRE("already called end()"); 158 } 159 KJ_CASE_ONEOF(b, Borrowed) { 160 KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed"); 161 } 162 KJ_CASE_ONEOF(streaming, Streaming) { 163 if (completed + data.size() < limit) { 164 completed += data.size(); 165 return streaming.stream.write(data.begin(), data.size()); 166 } else { 167 // This write passes the limit. 168 uint64_t remainder = limit - completed; 169 auto leftover = data.slice(remainder, data.size()); 170 return streaming.stream.write(data.begin(), remainder) 171 .then([this, leftover]() -> kj::Promise<void> { 172 completed = limit; 173 limitReached(); 174 175 if (leftover.size() > 0) { 176 // Need to forward the leftover bytes to the next stream. 
177 auto req = state.get<Redirected>().replacement.writeRequest( 178 MessageSize { 4 + leftover.size() / sizeof(capnp::word), 0 }); 179 req.setBytes(leftover); 180 return req.send(); 181 } else { 182 return kj::READY_NOW; 183 } 184 }); 185 } 186 } 187 } 188 KJ_UNREACHABLE; 189 } 190 191 kj::Promise<void> end(EndContext context) override { 192 KJ_SWITCH_ONEOF(state) { 193 KJ_CASE_ONEOF(redirected, Redirected) { 194 return context.tailCall(redirected.replacement.endRequest(MessageSize {2,0})); 195 } 196 KJ_CASE_ONEOF(e, Ended) { 197 KJ_FAIL_REQUIRE("already called end()"); 198 } 199 KJ_CASE_ONEOF(b, Borrowed) { 200 KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed"); 201 } 202 KJ_CASE_ONEOF(streaming, Streaming) { 203 auto req = streaming.callback.endedRequest(MessageSize {4, 0}); 204 req.setByteCount(completed); 205 auto result = req.send().ignoreResult(); 206 streaming.parent.returnStream(completed); 207 state = Ended(); 208 return result; 209 } 210 } 211 KJ_UNREACHABLE; 212 } 213 214 kj::Promise<void> getSubstream(GetSubstreamContext context) override { 215 KJ_SWITCH_ONEOF(state) { 216 KJ_CASE_ONEOF(redirected, Redirected) { 217 auto params = context.getParams(); 218 auto req = redirected.replacement.getSubstreamRequest(params.totalSize()); 219 req.setCallback(params.getCallback()); 220 req.setLimit(params.getLimit()); 221 return context.tailCall(kj::mv(req)); 222 } 223 KJ_CASE_ONEOF(e, Ended) { 224 KJ_FAIL_REQUIRE("already called end()"); 225 } 226 KJ_CASE_ONEOF(b, Borrowed) { 227 KJ_FAIL_REQUIRE("can't call other methods while stream is borrowed"); 228 } 229 KJ_CASE_ONEOF(streaming, Streaming) { 230 auto params = context.getParams(); 231 auto callback = params.getCallback(); 232 auto limit = params.getLimit(); 233 context.releaseParams(); 234 auto results = context.getResults(MessageSize { 2, 1 }); 235 results.setSubstream(factory.streamSet.add(kj::heap<SubstreamImpl>( 236 factory, *this, thisCap(), streaming.stream, kj::mv(callback), 
kj::mv(limit)))); 237 state = Borrowed { kj::mv(streaming) }; 238 return kj::READY_NOW; 239 } 240 } 241 KJ_UNREACHABLE; 242 } 243 244 private: 245 ByteStreamFactory& factory; 246 247 struct Streaming { 248 StreamServerBase& parent; 249 capnp::ByteStream::Client ownParent; 250 kj::AsyncOutputStream& stream; 251 capnp::ByteStream::SubstreamCallback::Client callback; 252 }; 253 struct Borrowed { 254 Streaming originalState; 255 }; 256 struct Redirected { 257 capnp::ByteStream::Client replacement; 258 }; 259 struct Ended {}; 260 261 kj::OneOf<Streaming, Borrowed, Redirected, Ended> state; 262 263 uint64_t limit; 264 uint64_t completed = 0; 265 266 kj::Own<kj::PromiseFulfiller<void>> resolveFulfiller; 267 kj::ForkedPromise<void> resolvePromise; 268 269 void limitReached() { 270 auto& streaming = state.get<Streaming>(); 271 auto next = streaming.callback.reachedLimitRequest(capnp::MessageSize {2,0}) 272 .send().getNext(); 273 274 // Set the next stream as our replacement. 275 streaming.parent.returnStream(limit); 276 state = Redirected { kj::mv(next) }; 277 resolveFulfiller->fulfill(); 278 } 279 }; 280 281 // ======================================================================================= 282 283 class ByteStreamFactory::CapnpToKjStreamAdapter final: public StreamServerBase { 284 // Implements Cap'n Proto ByteStream as a wrapper around a KJ stream. 
285 286 class SubstreamCallbackImpl; 287 288 public: 289 class PathProber; 290 291 CapnpToKjStreamAdapter(ByteStreamFactory& factory, 292 kj::Own<kj::AsyncOutputStream> inner) 293 : factory(factory), 294 state(kj::heap<PathProber>(*this, kj::mv(inner))) { 295 state.get<kj::Own<PathProber>>()->startProbing(); 296 } 297 298 CapnpToKjStreamAdapter(ByteStreamFactory& factory, 299 kj::Own<PathProber> pathProber) 300 : factory(factory), 301 state(kj::mv(pathProber)) { 302 state.get<kj::Own<PathProber>>()->setNewParent(*this); 303 } 304 305 // --------------------------------------------------------------------------- 306 // implements StreamServerBase 307 308 void returnStream(uint64_t written) override { 309 auto stream = kj::mv(state.get<Borrowed>().stream); 310 state = kj::mv(stream); 311 } 312 313 ShortestPath getShortestPath() override { 314 // Called by KjToCapnpStreamAdapter when it has determined that its inner ByteStream::Client 315 // actually points back to a CapnpToKjStreamAdapter in the same process. Returns the best 316 // shortened path to use, or a promise that resolves when the shortest path is known. 
317 318 KJ_SWITCH_ONEOF(state) { 319 KJ_CASE_ONEOF(prober, kj::Own<PathProber>) { 320 return prober->whenReady(); 321 } 322 KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) { 323 auto& streamRef = *kjStream; 324 state = Borrowed { kj::mv(kjStream) }; 325 return StreamServerBase::BorrowedStream { *this, streamRef, kj::maxValue }; 326 } 327 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { 328 return &capnpStream; 329 } 330 KJ_CASE_ONEOF(b, Borrowed) { 331 KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; } 332 return kj::Promise<void>(kj::READY_NOW); 333 } 334 KJ_CASE_ONEOF(e, Ended) { 335 KJ_FAIL_REQUIRE("already ended") { break; } 336 return kj::Promise<void>(kj::READY_NOW); 337 } 338 } 339 KJ_UNREACHABLE; 340 } 341 342 void directEnd() override { 343 KJ_SWITCH_ONEOF(state) { 344 KJ_CASE_ONEOF(prober, kj::Own<PathProber>) { 345 state = Ended(); 346 } 347 KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) { 348 state = Ended(); 349 } 350 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { 351 // Ugh I guess we need to send a real end() request here. 352 capnpStream.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){}); 353 } 354 KJ_CASE_ONEOF(b, Borrowed) { 355 // Fine, ignore. 356 } 357 KJ_CASE_ONEOF(e, Ended) { 358 // Fine, ignore. 
359 } 360 } 361 } 362 363 // --------------------------------------------------------------------------- 364 // PathProber 365 366 class PathProber final: public kj::AsyncInputStream { 367 public: 368 PathProber(CapnpToKjStreamAdapter& parent, kj::Own<kj::AsyncOutputStream> inner, 369 kj::PromiseFulfillerPair<void> paf = kj::newPromiseAndFulfiller<void>()) 370 : parent(parent), inner(kj::mv(inner)), 371 readyPromise(paf.promise.fork()), 372 readyFulfiller(kj::mv(paf.fulfiller)), 373 task(nullptr) {} 374 375 void startProbing() { 376 task = probeForShorterPath(); 377 } 378 379 void setNewParent(CapnpToKjStreamAdapter& newParent) { 380 KJ_ASSERT(parent == nullptr); 381 parent = newParent; 382 auto paf = kj::newPromiseAndFulfiller<void>(); 383 readyPromise = paf.promise.fork(); 384 readyFulfiller = kj::mv(paf.fulfiller); 385 } 386 387 kj::Promise<void> whenReady() { 388 return readyPromise.addBranch(); 389 } 390 391 kj::Promise<uint64_t> pumpToShorterPath(capnp::ByteStream::Client target, uint64_t limit) { 392 // If our probe succeeds in finding a KjToCapnpStreamAdapter somewhere down the stack, that 393 // will call this method to provide the shortened path. 394 395 KJ_IF_MAYBE(currentParent, parent) { 396 parent = nullptr; 397 398 auto self = kj::mv(currentParent->state.get<kj::Own<PathProber>>()); 399 currentParent->state = Ended(); // temporary, we'll set this properly below 400 KJ_ASSERT(self.get() == this); 401 402 // Open a substream on the target stream. 403 auto req = target.getSubstreamRequest(); 404 req.setLimit(limit); 405 auto paf = kj::newPromiseAndFulfiller<uint64_t>(); 406 req.setCallback(kj::heap<SubstreamCallbackImpl>(currentParent->factory, 407 kj::mv(self), kj::mv(paf.fulfiller), limit)); 408 409 // Now we hook up the incoming stream adapter to point directly to this substream, yay. 410 currentParent->state = req.send().getSubstream(); 411 412 // Let the original CapnpToKjStreamAdapter know that it's safe to handle incoming requests. 
413 readyFulfiller->fulfill(); 414 415 // It's now up to the SubstreamCallbackImpl to signal when the pump is done. 416 return kj::mv(paf.promise); 417 } else { 418 // We already completed a path-shortening. Probably SubstreamCallbackImpl::ended() was 419 // eventually called, meaning the substream was ended without redirecting back to us. So, 420 // we're at EOF. 421 return uint64_t(0); 422 } 423 } 424 425 kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { 426 // If this is called, it means the tryPumpFrom() in probeForShorterPath() eventually invoked 427 // code that tries to read manually from the source. We don't know what this code is doing 428 // exactly, but we do know for sure that the endpoint is not a KjToCapnpStreamAdapter, so 429 // we can't optimize. Instead, we pretend that we immediately hit EOF, ending the pump. This 430 // works because pumps do not propagate EOF -- the destination can still receive further 431 // writes and pumps. Basically our probing pump becomes a no-op, and then we revert to having 432 // each write() RPC directly call write() on the inner stream. 433 return size_t(0); 434 } 435 436 kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { 437 // Call the stream's `tryPumpFrom()` as a way to discover where the data will eventually go, 438 // in hopes that we find we can shorten the path. 439 KJ_IF_MAYBE(promise, output.tryPumpFrom(*this, amount)) { 440 // tryPumpFrom() returned non-null. Either it called `tryRead()` or `pumpTo()` (see 441 // below), or it plans to do so in the future. 442 return kj::mv(*promise); 443 } else { 444 // There is no shorter path. As with tryRead(), we pretend we get immediate EOF. 
445 return uint64_t(0); 446 } 447 } 448 449 private: 450 kj::Maybe<CapnpToKjStreamAdapter&> parent; 451 kj::Own<kj::AsyncOutputStream> inner; 452 kj::ForkedPromise<void> readyPromise; 453 kj::Own<kj::PromiseFulfiller<void>> readyFulfiller; 454 kj::Promise<void> task; 455 456 friend class SubstreamCallbackImpl; 457 458 kj::Promise<void> probeForShorterPath() { 459 return kj::evalNow([&]() -> kj::Promise<uint64_t> { 460 return pumpTo(*inner, kj::maxValue); 461 }).then([this](uint64_t actual) { 462 KJ_IF_MAYBE(currentParent, parent) { 463 KJ_IF_MAYBE(prober, currentParent->state.tryGet<kj::Own<PathProber>>()) { 464 // Either we didn't find any shorter path at all during probing and faked an EOF 465 // to get out of the probe (see comments in tryRead(), or we DID find a shorter path, 466 // completed a pumpTo() using a substream, and that substream redirected back to us, 467 // and THEN we couldn't find any further shorter paths for subsequent pumps. 468 469 // HACK: If we overwrite the Probing state now, we'll delete ourselves and delete 470 // this task promise, which is an error... let the event loop do it later by 471 // detaching. 472 task.attach(kj::mv(*prober)).detach([](kj::Exception&&){}); 473 parent = nullptr; 474 475 // OK, now we can change the parent state and signal it to proceed. 476 currentParent->state = kj::mv(inner); 477 readyFulfiller->fulfill(); 478 } 479 } 480 }).eagerlyEvaluate([this](kj::Exception&& exception) mutable { 481 // Something threw, so propagate the exception to break the parent. 
482 readyFulfiller->reject(kj::mv(exception)); 483 }); 484 } 485 }; 486 487 protected: 488 // --------------------------------------------------------------------------- 489 // implements ByteStream::Server RPC interface 490 491 kj::Maybe<kj::Promise<Capability::Client>> shortenPath() override { 492 return shortenPathImpl(); 493 } 494 kj::Promise<Capability::Client> shortenPathImpl() { 495 // Called by RPC implementation to find out if a shorter path presents itself. 496 KJ_SWITCH_ONEOF(state) { 497 KJ_CASE_ONEOF(prober, kj::Own<PathProber>) { 498 return prober->whenReady().then([this]() { 499 KJ_ASSERT(!state.is<kj::Own<PathProber>>()); 500 return shortenPathImpl(); 501 }); 502 } 503 KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) { 504 // No shortening possible. Pretend we never resolve so that calls continue to be routed 505 // to us forever. 506 return kj::NEVER_DONE; 507 } 508 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { 509 return Capability::Client(capnpStream); 510 } 511 KJ_CASE_ONEOF(b, Borrowed) { 512 KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; } 513 return kj::NEVER_DONE; 514 } 515 KJ_CASE_ONEOF(e, Ended) { 516 // No shortening possible. Pretend we never resolve so that calls continue to be routed 517 // to us forever. 
518 return kj::NEVER_DONE; 519 } 520 } 521 KJ_UNREACHABLE; 522 } 523 524 kj::Promise<void> write(WriteContext context) override { 525 KJ_SWITCH_ONEOF(state) { 526 KJ_CASE_ONEOF(prober, kj::Own<PathProber>) { 527 return prober->whenReady().then([this, context]() mutable { 528 KJ_ASSERT(!state.is<kj::Own<PathProber>>()); 529 return write(context); 530 }); 531 } 532 KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) { 533 auto data = context.getParams().getBytes(); 534 return kjStream->write(data.begin(), data.size()); 535 } 536 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { 537 auto params = context.getParams(); 538 auto req = capnpStream.writeRequest(params.totalSize()); 539 req.setBytes(params.getBytes()); 540 return req.send(); 541 } 542 KJ_CASE_ONEOF(b, Borrowed) { 543 KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; } 544 return kj::READY_NOW; 545 } 546 KJ_CASE_ONEOF(e, Ended) { 547 KJ_FAIL_REQUIRE("already called end()") { break; } 548 return kj::READY_NOW; 549 } 550 } 551 KJ_UNREACHABLE; 552 } 553 554 kj::Promise<void> end(EndContext context) override { 555 KJ_SWITCH_ONEOF(state) { 556 KJ_CASE_ONEOF(prober, kj::Own<PathProber>) { 557 return prober->whenReady().then([this, context]() mutable { 558 KJ_ASSERT(!state.is<kj::Own<PathProber>>()); 559 return end(context); 560 }); 561 } 562 KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) { 563 // TODO(someday): When KJ adds a proper .end() call, use it here. For now, we must 564 // drop the stream to close it. 
565 state = Ended(); 566 return kj::READY_NOW; 567 } 568 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { 569 auto params = context.getParams(); 570 auto req = capnpStream.endRequest(params.totalSize()); 571 return context.tailCall(kj::mv(req)); 572 } 573 KJ_CASE_ONEOF(b, Borrowed) { 574 KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; } 575 return kj::READY_NOW; 576 } 577 KJ_CASE_ONEOF(e, Ended) { 578 KJ_FAIL_REQUIRE("already called end()") { break; } 579 return kj::READY_NOW; 580 } 581 } 582 KJ_UNREACHABLE; 583 } 584 585 kj::Promise<void> getSubstream(GetSubstreamContext context) override { 586 KJ_SWITCH_ONEOF(state) { 587 KJ_CASE_ONEOF(prober, kj::Own<PathProber>) { 588 return prober->whenReady().then([this, context]() mutable { 589 KJ_ASSERT(!state.is<kj::Own<PathProber>>()); 590 return getSubstream(context); 591 }); 592 } 593 KJ_CASE_ONEOF(kjStream, kj::Own<kj::AsyncOutputStream>) { 594 auto params = context.getParams(); 595 auto callback = params.getCallback(); 596 uint64_t limit = params.getLimit(); 597 context.releaseParams(); 598 599 auto results = context.initResults(MessageSize {2, 1}); 600 results.setSubstream(factory.streamSet.add(kj::heap<SubstreamImpl>( 601 factory, *this, thisCap(), *kjStream, kj::mv(callback), kj::mv(limit)))); 602 state = Borrowed { kj::mv(kjStream) }; 603 return kj::READY_NOW; 604 } 605 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client) { 606 auto params = context.getParams(); 607 auto req = capnpStream.getSubstreamRequest(params.totalSize()); 608 req.setCallback(params.getCallback()); 609 req.setLimit(params.getLimit()); 610 return context.tailCall(kj::mv(req)); 611 } 612 KJ_CASE_ONEOF(b, Borrowed) { 613 KJ_FAIL_REQUIRE("concurrent streaming calls disallowed") { break; } 614 return kj::READY_NOW; 615 } 616 KJ_CASE_ONEOF(e, Ended) { 617 KJ_FAIL_REQUIRE("already called end()") { break; } 618 return kj::READY_NOW; 619 } 620 } 621 KJ_UNREACHABLE; 622 } 623 624 private: 625 ByteStreamFactory& factory; 626 
627 struct Borrowed { kj::Own<kj::AsyncOutputStream> stream; }; 628 struct Ended {}; 629 630 kj::OneOf<kj::Own<PathProber>, kj::Own<kj::AsyncOutputStream>, 631 capnp::ByteStream::Client, Borrowed, Ended> state; 632 633 class SubstreamCallbackImpl final: public capnp::ByteStream::SubstreamCallback::Server { 634 public: 635 SubstreamCallbackImpl(ByteStreamFactory& factory, 636 kj::Own<PathProber> pathProber, 637 kj::Own<kj::PromiseFulfiller<uint64_t>> originalPumpfulfiller, 638 uint64_t originalPumpLimit) 639 : factory(factory), 640 pathProber(kj::mv(pathProber)), 641 originalPumpfulfiller(kj::mv(originalPumpfulfiller)), 642 originalPumpLimit(originalPumpLimit) {} 643 644 ~SubstreamCallbackImpl() noexcept(false) { 645 if (!done) { 646 originalPumpfulfiller->reject(KJ_EXCEPTION(DISCONNECTED, 647 "stream disconnected because SubstreamCallbackImpl was never called back")); 648 } 649 } 650 651 kj::Promise<void> ended(EndedContext context) override { 652 KJ_REQUIRE(!done); 653 uint64_t actual = context.getParams().getByteCount(); 654 KJ_REQUIRE(actual <= originalPumpLimit); 655 656 done = true; 657 658 // EOF before pump completed. Signal a short pump. 659 originalPumpfulfiller->fulfill(context.getParams().getByteCount()); 660 661 // Give the original pump task a chance to finish up. 662 return pathProber->task.attach(kj::mv(pathProber)); 663 } 664 665 kj::Promise<void> reachedLimit(ReachedLimitContext context) override { 666 KJ_REQUIRE(!done); 667 done = true; 668 669 // Allow the shortened stream to redirect back to our original underlying stream. 670 auto results = context.getResults(capnp::MessageSize { 4, 1 }); 671 results.setNext(factory.streamSet.add( 672 kj::heap<CapnpToKjStreamAdapter>(factory, kj::mv(pathProber)))); 673 674 // The full pump completed. 
Note that it's important that we fulfill this after the 675 // PathProber has been attached to the new CapnpToKjStreamAdapter, which will have happened 676 // in CapnpToKjStreamAdapter's constructor, which calls pathProber->setNewParent(). 677 originalPumpfulfiller->fulfill(kj::cp(originalPumpLimit)); 678 679 return kj::READY_NOW; 680 } 681 682 private: 683 ByteStreamFactory& factory; 684 kj::Own<PathProber> pathProber; 685 kj::Own<kj::PromiseFulfiller<uint64_t>> originalPumpfulfiller; 686 uint64_t originalPumpLimit; 687 bool done = false; 688 }; 689 }; 690 691 // ======================================================================================= 692 693 class ByteStreamFactory::KjToCapnpStreamAdapter final: public kj::AsyncOutputStream { 694 public: 695 KjToCapnpStreamAdapter(ByteStreamFactory& factory, capnp::ByteStream::Client innerParam) 696 : factory(factory), 697 inner(kj::mv(innerParam)), 698 findShorterPathTask(findShorterPath(inner).fork()) {} 699 700 ~KjToCapnpStreamAdapter() noexcept(false) { 701 // HACK: KJ streams are implicitly ended on destruction, but the RPC stream needs a call. We 702 // use a detached promise for now, which is probably OK since capabilities are refcounted and 703 // asynchronously destroyed anyway. 704 // TODO(cleanup): Fix this when KJ streads add an explicit end() method. 
705 KJ_IF_MAYBE(o, optimized) { 706 o->directEnd(); 707 } else { 708 inner.endRequest(MessageSize {2, 0}).send().detach([](kj::Exception&&){}); 709 } 710 } 711 712 kj::Promise<void> write(const void* buffer, size_t size) override { 713 KJ_SWITCH_ONEOF(getShortestPath()) { 714 KJ_CASE_ONEOF(promise, kj::Promise<void>) { 715 return promise.then([this,buffer,size]() { 716 return write(buffer, size); 717 }); 718 } 719 KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) { 720 auto limit = kj::min(kjStream.limit, MAX_BYTES_PER_WRITE); 721 if (size <= limit) { 722 auto promise = kjStream.stream.write(buffer, size); 723 return promise.then([kjStream,size]() mutable { 724 kjStream.lender.returnStream(size); 725 }); 726 } else { 727 auto promise = kjStream.stream.write(buffer, limit); 728 return promise.then([this,kjStream,buffer,size,limit]() mutable { 729 kjStream.lender.returnStream(limit); 730 return write(reinterpret_cast<const byte*>(buffer) + limit, 731 size - limit); 732 }); 733 } 734 } 735 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) { 736 if (size <= MAX_BYTES_PER_WRITE) { 737 auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word), 0 }); 738 req.setBytes(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size)); 739 return req.send(); 740 } else { 741 auto req = capnpStream->writeRequest( 742 MessageSize { 8 + MAX_BYTES_PER_WRITE / sizeof(word), 0 }); 743 req.setBytes(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), MAX_BYTES_PER_WRITE)); 744 return req.send().then([this,buffer,size]() mutable { 745 return write(reinterpret_cast<const byte*>(buffer) + MAX_BYTES_PER_WRITE, 746 size - MAX_BYTES_PER_WRITE); 747 }); 748 } 749 } 750 } 751 KJ_UNREACHABLE; 752 } 753 754 kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override { 755 KJ_SWITCH_ONEOF(getShortestPath()) { 756 KJ_CASE_ONEOF(promise, kj::Promise<void>) { 757 return promise.then([this,pieces]() { 758 return write(pieces); 759 }); 760 } 761 
KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) { 762 size_t size = 0; 763 for (auto& piece: pieces) { size += piece.size(); } 764 auto limit = kj::min(kjStream.limit, MAX_BYTES_PER_WRITE); 765 if (size <= limit) { 766 auto promise = kjStream.stream.write(pieces); 767 return promise.then([kjStream,size]() mutable { 768 kjStream.lender.returnStream(size); 769 }); 770 } else { 771 // ughhhhhhhhhh, we need to split the pieces. 772 return splitAndWrite(pieces, kjStream.limit, 773 [kjStream,limit](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) mutable { 774 return kjStream.stream.write(pieces).then([kjStream,limit]() mutable { 775 kjStream.lender.returnStream(limit); 776 }); 777 }); 778 } 779 } 780 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) { 781 auto writePieces = [capnpStream](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) { 782 size_t size = 0; 783 for (auto& piece: pieces) size += piece.size(); 784 auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word), 0 }); 785 auto out = req.initBytes(size); 786 byte* ptr = out.begin(); 787 for (auto& piece: pieces) { 788 memcpy(ptr, piece.begin(), piece.size()); 789 ptr += piece.size(); 790 } 791 KJ_ASSERT(ptr == out.end()); 792 return req.send(); 793 }; 794 795 size_t size = 0; 796 for (auto& piece: pieces) size += piece.size(); 797 if (size <= MAX_BYTES_PER_WRITE) { 798 return writePieces(pieces); 799 } else { 800 // ughhhhhhhhhh, we need to split the pieces. 801 return splitAndWrite(pieces, MAX_BYTES_PER_WRITE, writePieces); 802 } 803 } 804 } 805 KJ_UNREACHABLE; 806 } 807 808 kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom( 809 kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { 810 KJ_IF_MAYBE(rpc, kj::dynamicDowncastIfAvailable<CapnpToKjStreamAdapter::PathProber>(input)) { 811 // Oh interesting, it turns we're hosting an incoming ByteStream which is pumping to this 812 // outgoing ByteStream. 
We can let the Cap'n Proto RPC layer know that it can shorten the 813 // path from one to the other. 814 return rpc->pumpToShorterPath(inner, amount); 815 } else { 816 return pumpLoop(input, 0, amount); 817 } 818 } 819 820 kj::Promise<void> whenWriteDisconnected() override { // Resolves when findShorterPathTask does; see that member's comment below. 821 return findShorterPathTask.addBranch(); 822 } 823 824 private: 825 ByteStreamFactory& factory; 826 capnp::ByteStream::Client inner; // The capability this adapter writes to (used when no local shortcut is known). 827 kj::Maybe<StreamServerBase&> optimized; // Set by findShorterPath() once `inner` is discovered to point at a local StreamServerBase. 828 829 kj::ForkedPromise<void> findShorterPathTask; 830 // This serves two purposes: 831 // 1. Waits for the capability to resolve (if it is a promise), and then shortens the path if 832 // possible. 833 // 2. Implements whenWriteDisconnected(). 834 835 kj::Promise<void> findShorterPath(capnp::ByteStream::Client& capnpClient) { 836 // If the capnp stream turns out to resolve back to this process, shorten the path. 837 // Also, implement whenWriteDisconnected() based on this. 838 return factory.streamSet.getLocalServer(capnpClient) 839 .then([this](kj::Maybe<capnp::ByteStream::Server&> server) -> kj::Promise<void> { 840 KJ_IF_MAYBE(s, server) { 841 // We discovered that the ByteStream actually points back to a local KJ stream. 842 // We can use this to shorten the path by skipping the RPC machinery. 843 return findShorterPath(kj::downcast<StreamServerBase>(*s)); 844 } else { 845 // The capability is fully-resolved. This suggests that the remote implementation is 846 // NOT a CapnpToKjStreamAdapter at all, because CapnpToKjStreamAdapter is designed to 847 // always look like a promise. It's some other implementation that doesn't present 848 // itself as a promise. We have no way to detect when it is disconnected. 849 return kj::NEVER_DONE; 850 } 851 }, [](kj::Exception&& e) -> kj::Promise<void> { 852 // getLocalServer() throws when the capability is a promise cap that rejects. We can 853 // use this to implement whenWriteDisconnected(). 
854 // 855 // (Note that because this exception handler is passed to the .then(), it does NOT catch 856 // exceptions thrown by the success handler immediately above it. This handler will ONLY 857 // catch exceptions from getLocalServer() itself.) 858 return kj::READY_NOW; 859 }); 860 } 861 862 kj::Promise<void> findShorterPath(StreamServerBase& capnpServer) { 863 // We found a shorter path back to this process. Record it. 864 optimized = capnpServer; 865 866 KJ_SWITCH_ONEOF(capnpServer.getShortestPath()) { 867 KJ_CASE_ONEOF(promise, kj::Promise<void>) { // Try again once this promise resolves. 868 return promise.then([this,&capnpServer]() { 869 return findShorterPath(capnpServer); 870 }); 871 } 872 KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) { 873 // The ByteStream::Server wraps a regular KJ stream that does not wrap another capnp 874 // stream. 875 if (kjStream.limit < (uint64_t)kj::maxValue / 2) { 876 // But it isn't wrapping that stream forever. Eventually it plans to redirect back to 877 // some other stream. So, let's wait for that, and possibly shorten again. 878 kjStream.lender.returnStream(0); // Nothing was written through the borrowed stream. 879 return KJ_ASSERT_NONNULL(capnpServer.shortenPath()) 880 .then([this, &capnpServer](auto&&) { 881 return findShorterPath(capnpServer); 882 }); 883 } else { 884 // This KJ stream is (effectively) the permanent endpoint: a limit of at least half of kj::maxValue is treated as unbounded. We can't get any shorter 885 // from here. All we want to do now is watch for disconnect. 
886 auto promise = kjStream.stream.whenWriteDisconnected(); 887 kjStream.lender.returnStream(0); 888 return promise; 889 } 890 } 891 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) { 892 return findShorterPath(*capnpStream); 893 } 894 } 895 KJ_UNREACHABLE; 896 } 897 898 StreamServerBase::ShortestPath getShortestPath() { 899 KJ_IF_MAYBE(o, optimized) { 900 return o->getShortestPath(); 901 } else { 902 return &inner; 903 } 904 } 905 906 kj::Promise<uint64_t> pumpLoop(kj::AsyncInputStream& input, 907 uint64_t completed, uint64_t remaining) { 908 if (remaining == 0) return completed; 909 910 KJ_SWITCH_ONEOF(getShortestPath()) { 911 KJ_CASE_ONEOF(promise, kj::Promise<void>) { 912 return promise.then([this,&input,completed,remaining]() { 913 return pumpLoop(input,completed,remaining); 914 }); 915 } 916 KJ_CASE_ONEOF(kjStream, StreamServerBase::BorrowedStream) { 917 // This capability actually points back to a stream in our own thread. We can 918 // stop sending RPCs and just pump directly, up to the borrowed stream's limit. 919 920 if (remaining <= kjStream.limit) { // Everything left fits within the borrow limit: one pump finishes the job. 921 return input.pumpTo(kjStream.stream, remaining) 922 .then([kjStream,completed](uint64_t actual) { 923 kjStream.lender.returnStream(actual); 924 return actual + completed; 925 }); 926 } else { // Pump only up to the limit, return the stream, then loop for the rest. 927 auto promise = input.pumpTo(kjStream.stream, kjStream.limit); 928 return promise.then([this,&input,completed,remaining,kjStream] 929 (uint64_t actual) mutable -> kj::Promise<uint64_t> { 930 kjStream.lender.returnStream(actual); 931 if (actual < kjStream.limit) { 932 // EOF reached. 933 return completed + actual; 934 } else { 935 return pumpLoop(input, completed + actual, remaining - actual); 936 } 937 }); 938 } 939 } 940 KJ_CASE_ONEOF(capnpStream, capnp::ByteStream::Client*) { 941 // Pumping to another capnp ByteStream capability. Optimize the pump by reading from the input 942 // directly into outgoing RPC messages. 
943 size_t size = kj::min(remaining, 8192); 944 auto req = capnpStream->writeRequest(MessageSize { 8 + size / sizeof(word) }); 945 946 auto orphanage = Orphanage::getForMessageContaining( 947 capnp::ByteStream::WriteParams::Builder(req)); 948 949 auto buffer = orphanage.newOrphan<Data>(size); 950 951 struct WriteRequestAndBuffer { 952 // The order of construction/destruction of lambda captures is unspecified, but we care 953 // about ordering between these two things that we want to capture, so... we need a 954 // struct. Struct members are destroyed in reverse declaration order, so `buffer` is destroyed before the `request` it points into. 955 StreamingRequest<capnp::ByteStream::WriteParams> request; 956 Orphan<Data> buffer; // points into `request`... 957 }; 958 959 WriteRequestAndBuffer wrab = { kj::mv(req), kj::mv(buffer) }; 960 961 return input.tryRead(wrab.buffer.get().begin(), 1, size) // Read 1..size bytes straight into the request's buffer. 962 .then([this, &input, completed, remaining, size, wrab = kj::mv(wrab)] 963 (size_t actual) mutable -> kj::Promise<uint64_t> { 964 if (actual == 0) { // EOF: nothing more to pump. 965 return completed; 966 } if (actual < size) { // Effectively `else if`, since the branch above returns. Shrink the buffer to the bytes actually read. 967 wrab.buffer.truncate(actual); 968 } 969 970 wrab.request.adoptBytes(kj::mv(wrab.buffer)); 971 return wrab.request.send() 972 .then([this, &input, completed, remaining, actual]() { 973 return pumpLoop(input, completed + actual, remaining - actual); 974 }); 975 }); 976 } 977 } 978 KJ_UNREACHABLE; 979 } 980 981 template <typename WritePieces> 982 kj::Promise<void> splitAndWrite(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces, 983 size_t limit, WritePieces&& writeFirstPieces) { // Writes the first `limit` bytes of `pieces` via `writeFirstPieces`, then writes the remainder via write(). NOTE(review): the scan below assumes `limit` is less than the total size of `pieces`; otherwise it indexes past the end -- confirm with callers. 984 size_t splitByte = limit; 985 size_t splitPiece = 0; 986 while (pieces[splitPiece].size() <= splitByte) { 987 splitByte -= pieces[splitPiece].size(); 988 ++splitPiece; 989 } 990 991 if (splitByte == 0) { 992 // The split falls exactly between two pieces, so no piece needs to be divided. 993 auto rest = pieces.slice(splitPiece, pieces.size()); 994 return writeFirstPieces(pieces.slice(0, splitPiece)) 995 .then([this,rest]() mutable { 996 return write(rest); 997 }); 998 } else { 999 // The split falls inside a piece, so we need to split that piece in two: its head ends `left` and its tail starts `right`. 
1000 auto left = kj::heapArray<kj::ArrayPtr<const byte>>(splitPiece + 1); 1001 auto right = kj::heapArray<kj::ArrayPtr<const byte>>(pieces.size() - splitPiece); 1002 for (auto i: kj::zeroTo(splitPiece)) { 1003 left[i] = pieces[i]; 1004 } 1005 for (auto i: kj::zeroTo(right.size())) { 1006 right[i] = pieces[splitPiece + i]; 1007 } 1008 left.back() = pieces[splitPiece].slice(0, splitByte); 1009 right.front() = pieces[splitPiece].slice(splitByte, pieces[splitPiece].size()); 1010 1011 return writeFirstPieces(left).attach(kj::mv(left)) 1012 .then([this,right=kj::mv(right)]() mutable { 1013 return write(right).attach(kj::mv(right)); 1014 }); 1015 } 1016 } 1017 }; 1018 1019 // ======================================================================================= 1020 1021 capnp::ByteStream::Client ByteStreamFactory::kjToCapnp(kj::Own<kj::AsyncOutputStream> kjStream) { // Wraps `kjStream` in a CapnpToKjStreamAdapter and registers it with streamSet, returning the resulting capability. 1022 return streamSet.add(kj::heap<CapnpToKjStreamAdapter>(*this, kj::mv(kjStream))); 1023 } 1024 1025 kj::Own<kj::AsyncOutputStream> ByteStreamFactory::capnpToKj(capnp::ByteStream::Client capnpStream) { // Wraps the `capnpStream` capability in a KjToCapnpStreamAdapter, exposing it as a kj::AsyncOutputStream. 1026 return kj::heap<KjToCapnpStreamAdapter>(*this, kj::mv(capnpStream)); 1027 } 1028 1029 } // namespace capnp