async-io.c++ (112738B)
1 // Copyright (c) 2013-2017 Sandstorm Development Group, Inc. and contributors 2 // Licensed under the MIT License: 3 // 4 // Permission is hereby granted, free of charge, to any person obtaining a copy 5 // of this software and associated documentation files (the "Software"), to deal 6 // in the Software without restriction, including without limitation the rights 7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 // copies of the Software, and to permit persons to whom the Software is 9 // furnished to do so, subject to the following conditions: 10 // 11 // The above copyright notice and this permission notice shall be included in 12 // all copies or substantial portions of the Software. 13 // 14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 // THE SOFTWARE. 21 22 #if _WIN32 23 // Request Vista-level APIs. 
24 #include "win32-api-version.h" 25 #endif 26 27 #include "async-io.h" 28 #include "async-io-internal.h" 29 #include "debug.h" 30 #include "vector.h" 31 #include "io.h" 32 #include "one-of.h" 33 #include <deque> 34 35 #if _WIN32 36 #include <winsock2.h> 37 #include <ws2ipdef.h> 38 #include <ws2tcpip.h> 39 #include "windows-sanity.h" 40 #define inet_pton InetPtonA 41 #define inet_ntop InetNtopA 42 #include <io.h> 43 #define dup _dup 44 #else 45 #include <sys/socket.h> 46 #include <arpa/inet.h> 47 #include <netinet/in.h> 48 #include <sys/un.h> 49 #include <unistd.h> 50 #endif 51 52 namespace kj { 53 54 Promise<void> AsyncInputStream::read(void* buffer, size_t bytes) { 55 return read(buffer, bytes, bytes).then([](size_t) {}); 56 } 57 58 Promise<size_t> AsyncInputStream::read(void* buffer, size_t minBytes, size_t maxBytes) { 59 return tryRead(buffer, minBytes, maxBytes).then([=](size_t result) { 60 if (result >= minBytes) { 61 return result; 62 } else { 63 kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "stream disconnected prematurely")); 64 // Pretend we read zeros from the input. 65 memset(reinterpret_cast<byte*>(buffer) + result, 0, minBytes - result); 66 return minBytes; 67 } 68 }); 69 } 70 71 Maybe<uint64_t> AsyncInputStream::tryGetLength() { return nullptr; } 72 73 void AsyncInputStream::registerAncillaryMessageHandler( 74 Function<void(ArrayPtr<AncillaryMessage>)> fn) { 75 KJ_UNIMPLEMENTED("registerAncillaryMsgHandler is not implemented by this AsyncInputStream"); 76 } 77 78 Maybe<Own<AsyncInputStream>> AsyncInputStream::tryTee(uint64_t) { 79 return nullptr; 80 } 81 82 namespace { 83 84 class AsyncPump { 85 public: 86 AsyncPump(AsyncInputStream& input, AsyncOutputStream& output, uint64_t limit) 87 : input(input), output(output), limit(limit) {} 88 89 Promise<uint64_t> pump() { 90 // TODO(perf): This could be more efficient by reading half a buffer at a time and then 91 // starting the next read concurrent with writing the data from the previous read. 
92 93 uint64_t n = kj::min(limit - doneSoFar, sizeof(buffer)); 94 if (n == 0) return doneSoFar; 95 96 return input.tryRead(buffer, 1, n) 97 .then([this](size_t amount) -> Promise<uint64_t> { 98 if (amount == 0) return doneSoFar; // EOF 99 doneSoFar += amount; 100 return output.write(buffer, amount) 101 .then([this]() { 102 return pump(); 103 }); 104 }); 105 } 106 107 private: 108 AsyncInputStream& input; 109 AsyncOutputStream& output; 110 uint64_t limit; 111 uint64_t doneSoFar = 0; 112 byte buffer[4096]; 113 }; 114 115 } // namespace 116 117 Promise<uint64_t> AsyncInputStream::pumpTo( 118 AsyncOutputStream& output, uint64_t amount) { 119 // See if output wants to dispatch on us. 120 KJ_IF_MAYBE(result, output.tryPumpFrom(*this, amount)) { 121 return kj::mv(*result); 122 } 123 124 // OK, fall back to naive approach. 125 auto pump = heap<AsyncPump>(*this, output, amount); 126 auto promise = pump->pump(); 127 return promise.attach(kj::mv(pump)); 128 } 129 130 namespace { 131 132 class AllReader { 133 public: 134 AllReader(AsyncInputStream& input): input(input) {} 135 136 Promise<Array<byte>> readAllBytes(uint64_t limit) { 137 return loop(limit).then([this, limit](uint64_t headroom) { 138 auto out = heapArray<byte>(limit - headroom); 139 copyInto(out); 140 return out; 141 }); 142 } 143 144 Promise<String> readAllText(uint64_t limit) { 145 return loop(limit).then([this, limit](uint64_t headroom) { 146 auto out = heapArray<char>(limit - headroom + 1); 147 copyInto(out.slice(0, out.size() - 1).asBytes()); 148 out.back() = '\0'; 149 return String(kj::mv(out)); 150 }); 151 } 152 153 private: 154 AsyncInputStream& input; 155 Vector<Array<byte>> parts; 156 157 Promise<uint64_t> loop(uint64_t limit) { 158 KJ_REQUIRE(limit > 0, "Reached limit before EOF."); 159 160 auto part = heapArray<byte>(kj::min(4096, limit)); 161 auto partPtr = part.asPtr(); 162 parts.add(kj::mv(part)); 163 return input.tryRead(partPtr.begin(), partPtr.size(), partPtr.size()) 164 
.then([this,KJ_CPCAP(partPtr),limit](size_t amount) mutable -> Promise<uint64_t> { 165 limit -= amount; 166 if (amount < partPtr.size()) { 167 return limit; 168 } else { 169 return loop(limit); 170 } 171 }); 172 } 173 174 void copyInto(ArrayPtr<byte> out) { 175 size_t pos = 0; 176 for (auto& part: parts) { 177 size_t n = kj::min(part.size(), out.size() - pos); 178 memcpy(out.begin() + pos, part.begin(), n); 179 pos += n; 180 } 181 } 182 }; 183 184 } // namespace 185 186 Promise<Array<byte>> AsyncInputStream::readAllBytes(uint64_t limit) { 187 auto reader = kj::heap<AllReader>(*this); 188 auto promise = reader->readAllBytes(limit); 189 return promise.attach(kj::mv(reader)); 190 } 191 192 Promise<String> AsyncInputStream::readAllText(uint64_t limit) { 193 auto reader = kj::heap<AllReader>(*this); 194 auto promise = reader->readAllText(limit); 195 return promise.attach(kj::mv(reader)); 196 } 197 198 Maybe<Promise<uint64_t>> AsyncOutputStream::tryPumpFrom( 199 AsyncInputStream& input, uint64_t amount) { 200 return nullptr; 201 } 202 203 namespace { 204 205 class AsyncPipe final: public AsyncCapabilityStream, public Refcounted { 206 public: 207 ~AsyncPipe() noexcept(false) { 208 KJ_REQUIRE(state == nullptr || ownState.get() != nullptr, 209 "destroying AsyncPipe with operation still in-progress; probably going to segfault") { 210 // Don't std::terminate(). 
211 break; 212 } 213 } 214 215 Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { 216 if (minBytes == 0) { 217 return size_t(0); 218 } else KJ_IF_MAYBE(s, state) { 219 return s->tryRead(buffer, minBytes, maxBytes); 220 } else { 221 return newAdaptedPromise<ReadResult, BlockedRead>( 222 *this, arrayPtr(reinterpret_cast<byte*>(buffer), maxBytes), minBytes) 223 .then([](ReadResult r) { return r.byteCount; }); 224 } 225 } 226 227 Promise<ReadResult> tryReadWithFds(void* buffer, size_t minBytes, size_t maxBytes, 228 AutoCloseFd* fdBuffer, size_t maxFds) override { 229 if (minBytes == 0) { 230 return ReadResult { 0, 0 }; 231 } else KJ_IF_MAYBE(s, state) { 232 return s->tryReadWithFds(buffer, minBytes, maxBytes, fdBuffer, maxFds); 233 } else { 234 return newAdaptedPromise<ReadResult, BlockedRead>( 235 *this, arrayPtr(reinterpret_cast<byte*>(buffer), maxBytes), minBytes, 236 kj::arrayPtr(fdBuffer, maxFds)); 237 } 238 } 239 240 Promise<ReadResult> tryReadWithStreams( 241 void* buffer, size_t minBytes, size_t maxBytes, 242 Own<AsyncCapabilityStream>* streamBuffer, size_t maxStreams) override { 243 if (minBytes == 0) { 244 return ReadResult { 0, 0 }; 245 } else KJ_IF_MAYBE(s, state) { 246 return s->tryReadWithStreams(buffer, minBytes, maxBytes, streamBuffer, maxStreams); 247 } else { 248 return newAdaptedPromise<ReadResult, BlockedRead>( 249 *this, arrayPtr(reinterpret_cast<byte*>(buffer), maxBytes), minBytes, 250 kj::arrayPtr(streamBuffer, maxStreams)); 251 } 252 } 253 254 Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount) override { 255 if (amount == 0) { 256 return uint64_t(0); 257 } else KJ_IF_MAYBE(s, state) { 258 return s->pumpTo(output, amount); 259 } else { 260 return newAdaptedPromise<uint64_t, BlockedPumpTo>(*this, output, amount); 261 } 262 } 263 264 void abortRead() override { 265 KJ_IF_MAYBE(s, state) { 266 s->abortRead(); 267 } else { 268 ownState = kj::heap<AbortedRead>(); 269 state = *ownState; 270 271 
readAborted = true; 272 KJ_IF_MAYBE(f, readAbortFulfiller) { 273 f->get()->fulfill(); 274 readAbortFulfiller = nullptr; 275 } 276 } 277 } 278 279 Promise<void> write(const void* buffer, size_t size) override { 280 if (size == 0) { 281 return READY_NOW; 282 } else KJ_IF_MAYBE(s, state) { 283 return s->write(buffer, size); 284 } else { 285 return newAdaptedPromise<void, BlockedWrite>( 286 *this, arrayPtr(reinterpret_cast<const byte*>(buffer), size), nullptr); 287 } 288 } 289 290 Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override { 291 while (pieces.size() > 0 && pieces[0].size() == 0) { 292 pieces = pieces.slice(1, pieces.size()); 293 } 294 295 if (pieces.size() == 0) { 296 return kj::READY_NOW; 297 } else KJ_IF_MAYBE(s, state) { 298 return s->write(pieces); 299 } else { 300 return newAdaptedPromise<void, BlockedWrite>( 301 *this, pieces[0], pieces.slice(1, pieces.size())); 302 } 303 } 304 305 Promise<void> writeWithFds(ArrayPtr<const byte> data, 306 ArrayPtr<const ArrayPtr<const byte>> moreData, 307 ArrayPtr<const int> fds) override { 308 while (data.size() == 0 && moreData.size() > 0) { 309 data = moreData.front(); 310 moreData = moreData.slice(1, moreData.size()); 311 } 312 313 if (data.size() == 0) { 314 KJ_REQUIRE(fds.size() == 0, "can't attach FDs to empty message"); 315 return READY_NOW; 316 } else KJ_IF_MAYBE(s, state) { 317 return s->writeWithFds(data, moreData, fds); 318 } else { 319 return newAdaptedPromise<void, BlockedWrite>(*this, data, moreData, fds); 320 } 321 } 322 323 Promise<void> writeWithStreams(ArrayPtr<const byte> data, 324 ArrayPtr<const ArrayPtr<const byte>> moreData, 325 Array<Own<AsyncCapabilityStream>> streams) override { 326 while (data.size() == 0 && moreData.size() > 0) { 327 data = moreData.front(); 328 moreData = moreData.slice(1, moreData.size()); 329 } 330 331 if (data.size() == 0) { 332 KJ_REQUIRE(streams.size() == 0, "can't attach capabilities to empty message"); 333 return READY_NOW; 334 } else 
KJ_IF_MAYBE(s, state) { 335 return s->writeWithStreams(data, moreData, kj::mv(streams)); 336 } else { 337 return newAdaptedPromise<void, BlockedWrite>(*this, data, moreData, kj::mv(streams)); 338 } 339 } 340 341 Maybe<Promise<uint64_t>> tryPumpFrom( 342 AsyncInputStream& input, uint64_t amount) override { 343 if (amount == 0) { 344 return Promise<uint64_t>(uint64_t(0)); 345 } else KJ_IF_MAYBE(s, state) { 346 return s->tryPumpFrom(input, amount); 347 } else { 348 return newAdaptedPromise<uint64_t, BlockedPumpFrom>(*this, input, amount); 349 } 350 } 351 352 Promise<void> whenWriteDisconnected() override { 353 if (readAborted) { 354 return kj::READY_NOW; 355 } else KJ_IF_MAYBE(p, readAbortPromise) { 356 return p->addBranch(); 357 } else { 358 auto paf = newPromiseAndFulfiller<void>(); 359 readAbortFulfiller = kj::mv(paf.fulfiller); 360 auto fork = paf.promise.fork(); 361 auto result = fork.addBranch(); 362 readAbortPromise = kj::mv(fork); 363 return result; 364 } 365 } 366 367 void shutdownWrite() override { 368 KJ_IF_MAYBE(s, state) { 369 s->shutdownWrite(); 370 } else { 371 ownState = kj::heap<ShutdownedWrite>(); 372 state = *ownState; 373 } 374 } 375 376 private: 377 Maybe<AsyncCapabilityStream&> state; 378 // Object-oriented state! If any method call is blocked waiting on activity from the other end, 379 // then `state` is non-null and method calls should be forwarded to it. If no calls are 380 // outstanding, `state` is null. 
381 382 kj::Own<AsyncCapabilityStream> ownState; 383 384 bool readAborted = false; 385 Maybe<Own<PromiseFulfiller<void>>> readAbortFulfiller = nullptr; 386 Maybe<ForkedPromise<void>> readAbortPromise = nullptr; 387 388 void endState(AsyncIoStream& obj) { 389 KJ_IF_MAYBE(s, state) { 390 if (s == &obj) { 391 state = nullptr; 392 } 393 } 394 } 395 396 template <typename F> 397 static auto teeExceptionVoid(F& fulfiller) { 398 // Returns a functor that can be passed as the second parameter to .then() to propagate the 399 // exception to a given fulfiller. The functor's return type is void. 400 return [&fulfiller](kj::Exception&& e) { 401 fulfiller.reject(kj::cp(e)); 402 kj::throwRecoverableException(kj::mv(e)); 403 }; 404 } 405 template <typename F> 406 static auto teeExceptionSize(F& fulfiller) { 407 // Returns a functor that can be passed as the second parameter to .then() to propagate the 408 // exception to a given fulfiller. The functor's return type is size_t. 409 return [&fulfiller](kj::Exception&& e) -> size_t { 410 fulfiller.reject(kj::cp(e)); 411 kj::throwRecoverableException(kj::mv(e)); 412 return 0; 413 }; 414 } 415 template <typename T, typename F> 416 static auto teeExceptionPromise(F& fulfiller) { 417 // Returns a functor that can be passed as the second parameter to .then() to propagate the 418 // exception to a given fulfiller. The functor's return type is Promise<T>. 419 return [&fulfiller](kj::Exception&& e) -> kj::Promise<T> { 420 fulfiller.reject(kj::cp(e)); 421 return kj::mv(e); 422 }; 423 } 424 425 class BlockedWrite final: public AsyncCapabilityStream { 426 // AsyncPipe state when a write() is currently waiting for a corresponding read(). 
427 428 public: 429 BlockedWrite(PromiseFulfiller<void>& fulfiller, AsyncPipe& pipe, 430 ArrayPtr<const byte> writeBuffer, 431 ArrayPtr<const ArrayPtr<const byte>> morePieces, 432 kj::OneOf<ArrayPtr<const int>, Array<Own<AsyncCapabilityStream>>> capBuffer = {}) 433 : fulfiller(fulfiller), pipe(pipe), writeBuffer(writeBuffer), morePieces(morePieces), 434 capBuffer(kj::mv(capBuffer)) { 435 KJ_REQUIRE(pipe.state == nullptr); 436 pipe.state = *this; 437 } 438 439 ~BlockedWrite() noexcept(false) { 440 pipe.endState(*this); 441 } 442 443 Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { 444 KJ_SWITCH_ONEOF(tryReadImpl(buffer, minBytes, maxBytes)) { 445 KJ_CASE_ONEOF(done, Done) { 446 return done.result; 447 } 448 KJ_CASE_ONEOF(retry, Retry) { 449 return pipe.tryRead(retry.buffer, retry.minBytes, retry.maxBytes) 450 .then([n = retry.alreadyRead](size_t amount) { return amount + n; }); 451 } 452 } 453 KJ_UNREACHABLE; 454 } 455 456 Promise<ReadResult> tryReadWithFds(void* buffer, size_t minBytes, size_t maxBytes, 457 AutoCloseFd* fdBuffer, size_t maxFds) override { 458 size_t capCount = 0; 459 { // TODO(cleanup): Remove redundant braces when we update to C++17. 460 KJ_SWITCH_ONEOF(capBuffer) { 461 KJ_CASE_ONEOF(fds, ArrayPtr<const int>) { 462 capCount = kj::max(fds.size(), maxFds); 463 // Unfortunately, we have to dup() each FD, because the writer doesn't release ownership 464 // by default. 465 // TODO(perf): Should we add an ownership-releasing version of writeWithFds()? 466 for (auto i: kj::zeroTo(capCount)) { 467 int duped; 468 KJ_SYSCALL(duped = dup(fds[i])); 469 fdBuffer[i] = kj::AutoCloseFd(fds[i]); 470 } 471 fdBuffer += capCount; 472 maxFds -= capCount; 473 } 474 KJ_CASE_ONEOF(streams, Array<Own<AsyncCapabilityStream>>) { 475 if (streams.size() > 0 && maxFds > 0) { 476 // TODO(someday): We could let people pass a LowLevelAsyncIoProvider to 477 // newTwoWayPipe() if we wanted to auto-wrap FDs, but does anyone care? 
478 KJ_FAIL_REQUIRE( 479 "async pipe message was written with streams attached, but corresponding read " 480 "asked for FDs, and we don't know how to convert here"); 481 } 482 } 483 } 484 } 485 486 // Drop any unclaimed caps. This mirrors the behavior of unix sockets, where if we didn't 487 // provide enough buffer space for all the written FDs, the remaining ones are lost. 488 capBuffer = {}; 489 490 KJ_SWITCH_ONEOF(tryReadImpl(buffer, minBytes, maxBytes)) { 491 KJ_CASE_ONEOF(done, Done) { 492 return ReadResult { done.result, capCount }; 493 } 494 KJ_CASE_ONEOF(retry, Retry) { 495 return pipe.tryReadWithFds( 496 retry.buffer, retry.minBytes, retry.maxBytes, fdBuffer, maxFds) 497 .then([byteCount = retry.alreadyRead, capCount](ReadResult result) { 498 result.byteCount += byteCount; 499 result.capCount += capCount; 500 return result; 501 }); 502 } 503 } 504 KJ_UNREACHABLE; 505 } 506 507 Promise<ReadResult> tryReadWithStreams( 508 void* buffer, size_t minBytes, size_t maxBytes, 509 Own<AsyncCapabilityStream>* streamBuffer, size_t maxStreams) override { 510 size_t capCount = 0; 511 { // TODO(cleanup): Remove redundant braces when we update to C++17. 512 KJ_SWITCH_ONEOF(capBuffer) { 513 KJ_CASE_ONEOF(fds, ArrayPtr<const int>) { 514 if (fds.size() > 0 && maxStreams > 0) { 515 // TODO(someday): Use AsyncIoStream's `Maybe<int> getFd()` method? 516 KJ_FAIL_REQUIRE( 517 "async pipe message was written with FDs attached, but corresponding read " 518 "asked for streams, and we don't know how to convert here"); 519 } 520 } 521 KJ_CASE_ONEOF(streams, Array<Own<AsyncCapabilityStream>>) { 522 capCount = kj::max(streams.size(), maxStreams); 523 for (auto i: kj::zeroTo(capCount)) { 524 streamBuffer[i] = kj::mv(streams[i]); 525 } 526 streamBuffer += capCount; 527 maxStreams -= capCount; 528 } 529 } 530 } 531 532 // Drop any unclaimed caps. 
This mirrors the behavior of unix sockets, where if we didn't 533 // provide enough buffer space for all the written FDs, the remaining ones are lost. 534 capBuffer = {}; 535 536 KJ_SWITCH_ONEOF(tryReadImpl(buffer, minBytes, maxBytes)) { 537 KJ_CASE_ONEOF(done, Done) { 538 return ReadResult { done.result, capCount }; 539 } 540 KJ_CASE_ONEOF(retry, Retry) { 541 return pipe.tryReadWithStreams( 542 retry.buffer, retry.minBytes, retry.maxBytes, streamBuffer, maxStreams) 543 .then([byteCount = retry.alreadyRead, capCount](ReadResult result) { 544 result.byteCount += byteCount; 545 result.capCount += capCount; 546 return result; 547 }); 548 } 549 } 550 KJ_UNREACHABLE; 551 } 552 553 Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount) override { 554 // Note: Pumps drop all capabilities. 555 KJ_REQUIRE(canceler.isEmpty(), "already pumping"); 556 557 if (amount < writeBuffer.size()) { 558 // Consume a portion of the write buffer. 559 return canceler.wrap(output.write(writeBuffer.begin(), amount) 560 .then([this,amount]() { 561 writeBuffer = writeBuffer.slice(amount, writeBuffer.size()); 562 // We pumped the full amount, so we're done pumping. 563 return amount; 564 }, teeExceptionSize(fulfiller))); 565 } 566 567 // First piece doesn't cover the whole pump. Figure out how many more pieces to add. 568 uint64_t actual = writeBuffer.size(); 569 size_t i = 0; 570 while (i < morePieces.size() && 571 amount >= actual + morePieces[i].size()) { 572 actual += morePieces[i++].size(); 573 } 574 575 // Write the first piece. 576 auto promise = output.write(writeBuffer.begin(), writeBuffer.size()); 577 578 // Write full pieces as a single gather-write. 579 if (i > 0) { 580 auto more = morePieces.slice(0, i); 581 promise = promise.then([&output,more]() { return output.write(more); }); 582 } 583 584 if (i == morePieces.size()) { 585 // This will complete the write. 
586 return canceler.wrap(promise.then([this,&output,amount,actual]() -> Promise<uint64_t> { 587 canceler.release(); 588 fulfiller.fulfill(); 589 pipe.endState(*this); 590 591 if (actual == amount) { 592 // Oh, we had exactly enough. 593 return actual; 594 } else { 595 return pipe.pumpTo(output, amount - actual) 596 .then([actual](uint64_t actual2) { return actual + actual2; }); 597 } 598 }, teeExceptionPromise<uint64_t>(fulfiller))); 599 } else { 600 // Pump ends mid-piece. Write the last, partial piece. 601 auto n = amount - actual; 602 auto splitPiece = morePieces[i]; 603 KJ_ASSERT(n <= splitPiece.size()); 604 auto newWriteBuffer = splitPiece.slice(n, splitPiece.size()); 605 auto newMorePieces = morePieces.slice(i + 1, morePieces.size()); 606 auto prefix = splitPiece.slice(0, n); 607 if (prefix.size() > 0) { 608 promise = promise.then([&output,prefix]() { 609 return output.write(prefix.begin(), prefix.size()); 610 }); 611 } 612 613 return canceler.wrap(promise.then([this,newWriteBuffer,newMorePieces,amount]() { 614 writeBuffer = newWriteBuffer; 615 morePieces = newMorePieces; 616 canceler.release(); 617 return amount; 618 }, teeExceptionSize(fulfiller))); 619 } 620 } 621 622 void abortRead() override { 623 canceler.cancel("abortRead() was called"); 624 fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "read end of pipe was aborted")); 625 pipe.endState(*this); 626 pipe.abortRead(); 627 } 628 629 Promise<void> write(const void* buffer, size_t size) override { 630 KJ_FAIL_REQUIRE("can't write() again until previous write() completes"); 631 } 632 Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override { 633 KJ_FAIL_REQUIRE("can't write() again until previous write() completes"); 634 } 635 Promise<void> writeWithFds(ArrayPtr<const byte> data, 636 ArrayPtr<const ArrayPtr<const byte>> moreData, 637 ArrayPtr<const int> fds) override { 638 KJ_FAIL_REQUIRE("can't write() again until previous write() completes"); 639 } 640 Promise<void> 
writeWithStreams(ArrayPtr<const byte> data, 641 ArrayPtr<const ArrayPtr<const byte>> moreData, 642 Array<Own<AsyncCapabilityStream>> streams) override { 643 KJ_FAIL_REQUIRE("can't write() again until previous write() completes"); 644 } 645 Maybe<Promise<uint64_t>> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { 646 KJ_FAIL_REQUIRE("can't tryPumpFrom() again until previous write() completes"); 647 } 648 void shutdownWrite() override { 649 KJ_FAIL_REQUIRE("can't shutdownWrite() until previous write() completes"); 650 } 651 652 Promise<void> whenWriteDisconnected() override { 653 KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe"); 654 } 655 656 private: 657 PromiseFulfiller<void>& fulfiller; 658 AsyncPipe& pipe; 659 ArrayPtr<const byte> writeBuffer; 660 ArrayPtr<const ArrayPtr<const byte>> morePieces; 661 kj::OneOf<ArrayPtr<const int>, Array<Own<AsyncCapabilityStream>>> capBuffer; 662 Canceler canceler; 663 664 struct Done { size_t result; }; 665 struct Retry { void* buffer; size_t minBytes; size_t maxBytes; size_t alreadyRead; }; 666 667 OneOf<Done, Retry> tryReadImpl(void* readBufferPtr, size_t minBytes, size_t maxBytes) { 668 KJ_REQUIRE(canceler.isEmpty(), "already pumping"); 669 670 auto readBuffer = arrayPtr(reinterpret_cast<byte*>(readBufferPtr), maxBytes); 671 672 size_t totalRead = 0; 673 while (readBuffer.size() >= writeBuffer.size()) { 674 // The whole current write buffer can be copied into the read buffer. 675 676 { 677 auto n = writeBuffer.size(); 678 memcpy(readBuffer.begin(), writeBuffer.begin(), n); 679 totalRead += n; 680 readBuffer = readBuffer.slice(n, readBuffer.size()); 681 } 682 683 if (morePieces.size() == 0) { 684 // All done writing. 685 fulfiller.fulfill(); 686 pipe.endState(*this); 687 688 if (totalRead >= minBytes) { 689 // Also all done reading. 
690 return Done { totalRead }; 691 } else { 692 return Retry { readBuffer.begin(), minBytes - totalRead, readBuffer.size(), totalRead }; 693 } 694 } 695 696 writeBuffer = morePieces[0]; 697 morePieces = morePieces.slice(1, morePieces.size()); 698 } 699 700 // At this point, the read buffer is smaller than the current write buffer, so we can fill 701 // it completely. 702 { 703 auto n = readBuffer.size(); 704 memcpy(readBuffer.begin(), writeBuffer.begin(), n); 705 writeBuffer = writeBuffer.slice(n, writeBuffer.size()); 706 totalRead += n; 707 } 708 709 return Done { totalRead }; 710 } 711 }; 712 713 class BlockedPumpFrom final: public AsyncCapabilityStream { 714 // AsyncPipe state when a tryPumpFrom() is currently waiting for a corresponding read(). 715 716 public: 717 BlockedPumpFrom(PromiseFulfiller<uint64_t>& fulfiller, AsyncPipe& pipe, 718 AsyncInputStream& input, uint64_t amount) 719 : fulfiller(fulfiller), pipe(pipe), input(input), amount(amount) { 720 KJ_REQUIRE(pipe.state == nullptr); 721 pipe.state = *this; 722 } 723 724 ~BlockedPumpFrom() noexcept(false) { 725 pipe.endState(*this); 726 } 727 728 Promise<size_t> tryRead(void* readBuffer, size_t minBytes, size_t maxBytes) override { 729 KJ_REQUIRE(canceler.isEmpty(), "already pumping"); 730 731 auto pumpLeft = amount - pumpedSoFar; 732 auto min = kj::min(pumpLeft, minBytes); 733 auto max = kj::min(pumpLeft, maxBytes); 734 return canceler.wrap(input.tryRead(readBuffer, min, max) 735 .then([this,readBuffer,minBytes,maxBytes,min](size_t actual) -> kj::Promise<size_t> { 736 canceler.release(); 737 pumpedSoFar += actual; 738 KJ_ASSERT(pumpedSoFar <= amount); 739 740 if (pumpedSoFar == amount || actual < min) { 741 // Either we pumped all we wanted or we hit EOF. 
742 fulfiller.fulfill(kj::cp(pumpedSoFar)); 743 pipe.endState(*this); 744 } 745 746 if (actual >= minBytes) { 747 return actual; 748 } else { 749 return pipe.tryRead(reinterpret_cast<byte*>(readBuffer) + actual, 750 minBytes - actual, maxBytes - actual) 751 .then([actual](size_t actual2) { return actual + actual2; }); 752 } 753 }, teeExceptionPromise<size_t>(fulfiller))); 754 } 755 756 Promise<ReadResult> tryReadWithFds(void* readBuffer, size_t minBytes, size_t maxBytes, 757 AutoCloseFd* fdBuffer, size_t maxFds) override { 758 // Pumps drop all capabilities, so fall back to regular read. (We don't even know if the 759 // destination is an AsyncCapabilityStream...) 760 return tryRead(readBuffer, minBytes, maxBytes) 761 .then([](size_t n) { return ReadResult { n, 0 }; }); 762 } 763 764 Promise<ReadResult> tryReadWithStreams( 765 void* readBuffer, size_t minBytes, size_t maxBytes, 766 Own<AsyncCapabilityStream>* streamBuffer, size_t maxStreams) override { 767 // Pumps drop all capabilities, so fall back to regular read. (We don't even know if the 768 // destination is an AsyncCapabilityStream...) 769 return tryRead(readBuffer, minBytes, maxBytes) 770 .then([](size_t n) { return ReadResult { n, 0 }; }); 771 } 772 773 Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount2) override { 774 KJ_REQUIRE(canceler.isEmpty(), "already pumping"); 775 776 auto n = kj::min(amount2, amount - pumpedSoFar); 777 return canceler.wrap(input.pumpTo(output, n) 778 .then([this,&output,amount2,n](uint64_t actual) -> Promise<uint64_t> { 779 canceler.release(); 780 pumpedSoFar += actual; 781 KJ_ASSERT(pumpedSoFar <= amount); 782 if (pumpedSoFar == amount || actual < n) { 783 // Either we pumped all we wanted or we hit EOF. 784 fulfiller.fulfill(kj::cp(pumpedSoFar)); 785 pipe.endState(*this); 786 return pipe.pumpTo(output, amount2 - actual) 787 .then([actual](uint64_t actual2) { return actual + actual2; }); 788 } 789 790 // Completed entire pumpTo amount. 
791 KJ_ASSERT(actual == amount2); 792 return amount2; 793 }, teeExceptionSize(fulfiller))); 794 } 795 796 void abortRead() override { 797 canceler.cancel("abortRead() was called"); 798 799 // The input might have reached EOF, but we haven't detected it yet because we haven't tried 800 // to read that far. If we had not optimized tryPumpFrom() and instead used the default 801 // pumpTo() implementation, then the input would not have called write() again once it 802 // reached EOF, and therefore the abortRead() on the other end would *not* propagate an 803 // exception! We need the same behavior here. To that end, we need to detect if we're at EOF 804 // by reading one last byte. 805 checkEofTask = kj::evalNow([&]() { 806 static char junk; 807 return input.tryRead(&junk, 1, 1).then([this](uint64_t n) { 808 if (n == 0) { 809 fulfiller.fulfill(kj::cp(pumpedSoFar)); 810 } else { 811 fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "read end of pipe was aborted")); 812 } 813 }).eagerlyEvaluate([this](kj::Exception&& e) { 814 fulfiller.reject(kj::mv(e)); 815 }); 816 }); 817 818 pipe.endState(*this); 819 pipe.abortRead(); 820 } 821 822 Promise<void> write(const void* buffer, size_t size) override { 823 KJ_FAIL_REQUIRE("can't write() again until previous tryPumpFrom() completes"); 824 } 825 Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override { 826 KJ_FAIL_REQUIRE("can't write() again until previous tryPumpFrom() completes"); 827 } 828 Promise<void> writeWithFds(ArrayPtr<const byte> data, 829 ArrayPtr<const ArrayPtr<const byte>> moreData, 830 ArrayPtr<const int> fds) override { 831 KJ_FAIL_REQUIRE("can't write() again until previous tryPumpFrom() completes"); 832 } 833 Promise<void> writeWithStreams(ArrayPtr<const byte> data, 834 ArrayPtr<const ArrayPtr<const byte>> moreData, 835 Array<Own<AsyncCapabilityStream>> streams) override { 836 KJ_FAIL_REQUIRE("can't write() again until previous tryPumpFrom() completes"); 837 } 838 Maybe<Promise<uint64_t>> 
tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { 839 KJ_FAIL_REQUIRE("can't tryPumpFrom() again until previous tryPumpFrom() completes"); 840 } 841 void shutdownWrite() override { 842 KJ_FAIL_REQUIRE("can't shutdownWrite() until previous tryPumpFrom() completes"); 843 } 844 845 Promise<void> whenWriteDisconnected() override { 846 KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe"); 847 } 848 849 private: 850 PromiseFulfiller<uint64_t>& fulfiller; 851 AsyncPipe& pipe; 852 AsyncInputStream& input; 853 uint64_t amount; 854 uint64_t pumpedSoFar = 0; 855 Canceler canceler; 856 kj::Promise<void> checkEofTask = nullptr; 857 }; 858 859 class BlockedRead final: public AsyncCapabilityStream { 860 // AsyncPipe state when a tryRead() is currently waiting for a corresponding write(). 861 862 public: 863 BlockedRead( 864 PromiseFulfiller<ReadResult>& fulfiller, AsyncPipe& pipe, 865 ArrayPtr<byte> readBuffer, size_t minBytes, 866 kj::OneOf<ArrayPtr<AutoCloseFd>, ArrayPtr<Own<AsyncCapabilityStream>>> capBuffer = {}) 867 : fulfiller(fulfiller), pipe(pipe), readBuffer(readBuffer), minBytes(minBytes), 868 capBuffer(capBuffer) { 869 KJ_REQUIRE(pipe.state == nullptr); 870 pipe.state = *this; 871 } 872 873 ~BlockedRead() noexcept(false) { 874 pipe.endState(*this); 875 } 876 877 Promise<size_t> tryRead(void* readBuffer, size_t minBytes, size_t maxBytes) override { 878 KJ_FAIL_REQUIRE("can't read() again until previous read() completes"); 879 } 880 Promise<ReadResult> tryReadWithFds(void* readBuffer, size_t minBytes, size_t maxBytes, 881 AutoCloseFd* fdBuffer, size_t maxFds) override { 882 KJ_FAIL_REQUIRE("can't read() again until previous read() completes"); 883 } 884 Promise<ReadResult> tryReadWithStreams( 885 void* readBuffer, size_t minBytes, size_t maxBytes, 886 Own<AsyncCapabilityStream>* streamBuffer, size_t maxStreams) override { 887 KJ_FAIL_REQUIRE("can't read() again until previous read() completes"); 888 } 889 Promise<uint64_t> pumpTo(AsyncOutputStream& 
output, uint64_t amount) override { 890 KJ_FAIL_REQUIRE("can't read() again until previous read() completes"); 891 } 892 893 void abortRead() override { 894 canceler.cancel("abortRead() was called"); 895 fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "read end of pipe was aborted")); 896 pipe.endState(*this); 897 pipe.abortRead(); 898 } 899 900 Promise<void> write(const void* writeBuffer, size_t size) override { 901 KJ_REQUIRE(canceler.isEmpty(), "already pumping"); 902 903 auto data = arrayPtr(reinterpret_cast<const byte*>(writeBuffer), size); 904 KJ_SWITCH_ONEOF(writeImpl(data, nullptr)) { 905 KJ_CASE_ONEOF(done, Done) { 906 return READY_NOW; 907 } 908 KJ_CASE_ONEOF(retry, Retry) { 909 KJ_ASSERT(retry.moreData == nullptr); 910 return pipe.write(retry.data.begin(), retry.data.size()); 911 } 912 } 913 KJ_UNREACHABLE; 914 } 915 916 Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override { 917 KJ_REQUIRE(canceler.isEmpty(), "already pumping"); 918 919 KJ_SWITCH_ONEOF(writeImpl(pieces[0], pieces.slice(1, pieces.size()))) { 920 KJ_CASE_ONEOF(done, Done) { 921 return READY_NOW; 922 } 923 KJ_CASE_ONEOF(retry, Retry) { 924 if (retry.data.size() == 0) { 925 // We exactly finished the current piece, so just issue a write for the remaining 926 // pieces. 927 if (retry.moreData.size() == 0) { 928 // Nothing left. 929 return READY_NOW; 930 } else { 931 // Write remaining pieces. 932 return pipe.write(retry.moreData); 933 } 934 } else { 935 // Unfortunately we have to execute a separate write() for the remaining part of this 936 // piece, because we can't modify the pieces array. 937 auto promise = pipe.write(retry.data.begin(), retry.data.size()); 938 if (retry.moreData.size() == 0) { 939 // No more pieces so that's it. 940 return kj::mv(promise); 941 } else { 942 // Also need to write the remaining pieces. 
943 auto& pipeRef = pipe; 944 return promise.then([pieces=retry.moreData,&pipeRef]() { 945 return pipeRef.write(pieces); 946 }); 947 } 948 } 949 } 950 } 951 KJ_UNREACHABLE; 952 } 953 954 Promise<void> writeWithFds(ArrayPtr<const byte> data, 955 ArrayPtr<const ArrayPtr<const byte>> moreData, 956 ArrayPtr<const int> fds) override { 957 #if __GNUC__ && !__clang__ && __GNUC__ >= 7 958 // GCC 7 decides the open-brace below is "misleadingly indented" as if it were guarded by the `for` 959 // that appears in the implementation of KJ_REQUIRE(). Shut up shut up shut up. 960 #pragma GCC diagnostic ignored "-Wmisleading-indentation" 961 #endif 962 KJ_REQUIRE(canceler.isEmpty(), "already pumping"); 963 964 { // TODO(cleanup): Remove redundant braces when we update to C++17. 965 KJ_SWITCH_ONEOF(capBuffer) { 966 KJ_CASE_ONEOF(fdBuffer, ArrayPtr<AutoCloseFd>) { 967 size_t count = kj::max(fdBuffer.size(), fds.size()); 968 // Unfortunately, we have to dup() each FD, because the writer doesn't release ownership 969 // by default. 970 // TODO(perf): Should we add an ownership-releasing version of writeWithFds()? 971 for (auto i: kj::zeroTo(count)) { 972 int duped; 973 KJ_SYSCALL(duped = dup(fds[i])); 974 fdBuffer[i] = kj::AutoCloseFd(duped); 975 } 976 capBuffer = fdBuffer.slice(count, fdBuffer.size()); 977 readSoFar.capCount += count; 978 } 979 KJ_CASE_ONEOF(streamBuffer, ArrayPtr<Own<AsyncCapabilityStream>>) { 980 if (streamBuffer.size() > 0 && fds.size() > 0) { 981 // TODO(someday): Use AsyncIoStream's `Maybe<int> getFd()` method? 982 KJ_FAIL_REQUIRE( 983 "async pipe message was written with FDs attached, but corresponding read " 984 "asked for streams, and we don't know how to convert here"); 985 } 986 } 987 } 988 } 989 990 KJ_SWITCH_ONEOF(writeImpl(data, moreData)) { 991 KJ_CASE_ONEOF(done, Done) { 992 return READY_NOW; 993 } 994 KJ_CASE_ONEOF(retry, Retry) { 995 // Any leftover fds in `fds` are dropped on the floor, per contract. 
996 // TODO(cleanup): We use another writeWithFds() call here only because it accepts `data` 997 // and `moreData` directly. After the stream API refactor, we should be able to avoid 998 // this. 999 return pipe.writeWithFds(retry.data, retry.moreData, nullptr); 1000 } 1001 } 1002 KJ_UNREACHABLE; 1003 } 1004 1005 Promise<void> writeWithStreams(ArrayPtr<const byte> data, 1006 ArrayPtr<const ArrayPtr<const byte>> moreData, 1007 Array<Own<AsyncCapabilityStream>> streams) override { 1008 KJ_REQUIRE(canceler.isEmpty(), "already pumping"); 1009 1010 { // TODO(cleanup): Remove redundant braces when we update to C++17. 1011 KJ_SWITCH_ONEOF(capBuffer) { 1012 KJ_CASE_ONEOF(fdBuffer, ArrayPtr<AutoCloseFd>) { 1013 if (fdBuffer.size() > 0 && streams.size() > 0) { 1014 // TODO(someday): We could let people pass a LowLevelAsyncIoProvider to newTwoWayPipe() 1015 // if we wanted to auto-wrap FDs, but does anyone care? 1016 KJ_FAIL_REQUIRE( 1017 "async pipe message was written with streams attached, but corresponding read " 1018 "asked for FDs, and we don't know how to convert here"); 1019 } 1020 } 1021 KJ_CASE_ONEOF(streamBuffer, ArrayPtr<Own<AsyncCapabilityStream>>) { 1022 size_t count = kj::max(streamBuffer.size(), streams.size()); 1023 for (auto i: kj::zeroTo(count)) { 1024 streamBuffer[i] = kj::mv(streams[i]); 1025 } 1026 capBuffer = streamBuffer.slice(count, streamBuffer.size()); 1027 readSoFar.capCount += count; 1028 } 1029 } 1030 } 1031 1032 KJ_SWITCH_ONEOF(writeImpl(data, moreData)) { 1033 KJ_CASE_ONEOF(done, Done) { 1034 return READY_NOW; 1035 } 1036 KJ_CASE_ONEOF(retry, Retry) { 1037 // Any leftover fds in `fds` are dropped on the floor, per contract. 1038 // TODO(cleanup): We use another writeWithStreams() call here only because it accepts 1039 // `data` and `moreData` directly. After the stream API refactor, we should be able to 1040 // avoid this. 
1041 return pipe.writeWithStreams(retry.data, retry.moreData, nullptr); 1042 } 1043 } 1044 KJ_UNREACHABLE; 1045 } 1046 1047 Maybe<Promise<uint64_t>> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override { 1048 // Note: Pumps drop all capabilities. 1049 KJ_REQUIRE(canceler.isEmpty(), "already pumping"); 1050 1051 KJ_ASSERT(minBytes > readSoFar.byteCount); 1052 auto minToRead = kj::min(amount, minBytes - readSoFar.byteCount); 1053 auto maxToRead = kj::min(amount, readBuffer.size()); 1054 1055 return canceler.wrap(input.tryRead(readBuffer.begin(), minToRead, maxToRead) 1056 .then([this,&input,amount](size_t actual) -> Promise<uint64_t> { 1057 readBuffer = readBuffer.slice(actual, readBuffer.size()); 1058 readSoFar.byteCount += actual; 1059 1060 if (readSoFar.byteCount >= minBytes) { 1061 // We've read enough to close out this read (readSoFar >= minBytes). 1062 canceler.release(); 1063 fulfiller.fulfill(kj::cp(readSoFar)); 1064 pipe.endState(*this); 1065 1066 if (actual < amount) { 1067 // We didn't read as much data as the pump requested, but we did fulfill the read, so 1068 // we don't know whether we reached EOF on the input. We need to continue the pump, 1069 // replacing the BlockedRead state. 1070 return input.pumpTo(pipe, amount - actual) 1071 .then([actual](uint64_t actual2) -> uint64_t { return actual + actual2; }); 1072 } else { 1073 // We pumped as much data as was requested, so we can return that now. 1074 return actual; 1075 } 1076 } else { 1077 // The pump completed without fulfilling the read. This either means that the pump 1078 // reached EOF or the `amount` requested was not enough to satisfy the read in the first 1079 // place. Pumps do not propagate EOF, so either way we want to leave the BlockedRead in 1080 // place waiting for more data. 
return actual;
        }
      }, teeExceptionPromise<uint64_t>(fulfiller)));
    }

    void shutdownWrite() override {
      // The write end went away while a read was waiting: complete the read with whatever has
      // been received so far, then let the pipe itself process the shutdown.
      canceler.cancel("shutdownWrite() was called");
      fulfiller.fulfill(kj::cp(readSoFar));
      pipe.endState(*this);
      pipe.shutdownWrite();
    }

    Promise<void> whenWriteDisconnected() override {
      KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
    }

  private:
    PromiseFulfiller<ReadResult>& fulfiller;  // completes the waiting read
    AsyncPipe& pipe;
    ArrayPtr<byte> readBuffer;  // remaining unfilled portion of the reader's buffer
    size_t minBytes;            // compared against readSoFar.byteCount to decide completion
    kj::OneOf<ArrayPtr<AutoCloseFd>, ArrayPtr<Own<AsyncCapabilityStream>>> capBuffer;
    ReadResult readSoFar = {0, 0};
    Canceler canceler;  // non-empty while a tryPumpFrom() is in flight

    struct Done {};
    struct Retry { ArrayPtr<const byte> data; ArrayPtr<const ArrayPtr<const byte>> moreData; };

    OneOf<Done, Retry> writeImpl(ArrayPtr<const byte> data,
                                 ArrayPtr<const ArrayPtr<const byte>> moreData) {
      // Copies `data` followed by each piece of `moreData` into `readBuffer`.
      //
      // Returns Done when every byte was consumed (the read is fulfilled only if at least
      // `minBytes` have arrived). Returns Retry, carrying the unconsumed remainder, when the read
      // buffer filled up; in that case the read has been fulfilled and this state ended, so the
      // caller must re-issue the remainder against the pipe.
      for (;;) {
        if (data.size() < readBuffer.size()) {
          // First write segment consumes a portion of the read buffer but not all of it.
          auto n = data.size();
          memcpy(readBuffer.begin(), data.begin(), n);
          readSoFar.byteCount += n;
          readBuffer = readBuffer.slice(n, readBuffer.size());
          if (moreData.size() == 0) {
            // Consumed all written pieces.
            if (readSoFar.byteCount >= minBytes) {
              // We've read enough to close out this read.
              fulfiller.fulfill(kj::cp(readSoFar));
              pipe.endState(*this);
            }
            return Done();
          }
          data = moreData[0];
          moreData = moreData.slice(1, moreData.size());
          // loop
        } else {
          // First write segment consumes entire read buffer.
          auto n = readBuffer.size();
          readSoFar.byteCount += n;
          fulfiller.fulfill(kj::cp(readSoFar));
          pipe.endState(*this);
          // NOTE(review): the copy happens after fulfill(); presumably safe because the reader's
          // continuation does not run synchronously inside fulfill() -- confirm.
          memcpy(readBuffer.begin(), data.begin(), n);

          data = data.slice(n, data.size());
          if (data.size() == 0 && moreData.size() == 0) {
            return Done();
          } else {
            // Note: Even if `data` is empty, we don't replace it with moreData[0], because the
            //   retry might need to use write(ArrayPtr<ArrayPtr<byte>>) which doesn't allow
            //   passing a separate first segment.
            return Retry { data, moreData };
          }
        }
      }
    }
  };

  class BlockedPumpTo final: public AsyncCapabilityStream {
    // AsyncPipe state when a pumpTo() is currently waiting for a corresponding write().
    //
    // Writes are forwarded straight to `output` until `amount` bytes have been pumped.

  public:
    BlockedPumpTo(PromiseFulfiller<uint64_t>& fulfiller, AsyncPipe& pipe,
                  AsyncOutputStream& output, uint64_t amount)
        : fulfiller(fulfiller), pipe(pipe), output(output), amount(amount) {
      // Install ourselves as the pipe's current state; no other state may be active.
      KJ_REQUIRE(pipe.state == nullptr);
      pipe.state = *this;
    }

    ~BlockedPumpTo() noexcept(false) {
      pipe.endState(*this);
    }

    Promise<size_t> tryRead(void* readBuffer, size_t minBytes, size_t maxBytes) override {
      KJ_FAIL_REQUIRE("can't read() again until previous pumpTo() completes");
    }
    Promise<ReadResult> tryReadWithFds(void* readBuffer, size_t minBytes, size_t maxBytes,
                                       AutoCloseFd* fdBuffer, size_t maxFds) override {
      KJ_FAIL_REQUIRE("can't read() again until previous pumpTo() completes");
    }
    Promise<ReadResult> tryReadWithStreams(
        void* readBuffer, size_t minBytes, size_t maxBytes,
        Own<AsyncCapabilityStream>* streamBuffer, size_t maxStreams) override {
      KJ_FAIL_REQUIRE("can't read() again until previous pumpTo() completes");
    }
    Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount) override {
      KJ_FAIL_REQUIRE("can't read() again until previous pumpTo() completes");
    }

    void
abortRead() override {
      // Cancel any forwarded write wrapped by `canceler`, fail the pump, then let the pipe itself
      // process the abort once this state has been removed.
      canceler.cancel("abortRead() was called");
      fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "read end of pipe was aborted"));
      pipe.endState(*this);
      pipe.abortRead();
    }

    Promise<void> write(const void* writeBuffer, size_t size) override {
      // Forwards up to the pump's remaining byte budget to `output`. If the write is larger than
      // the budget, the pump is completed and the excess is re-issued on the pipe.
      KJ_REQUIRE(canceler.isEmpty(), "already pumping");

      auto actual = kj::min(amount - pumpedSoFar, size);
      return canceler.wrap(output.write(writeBuffer, actual)
          .then([this,size,actual,writeBuffer]() -> kj::Promise<void> {
        canceler.release();
        pumpedSoFar += actual;

        KJ_ASSERT(pumpedSoFar <= amount);
        KJ_ASSERT(actual <= size);

        if (pumpedSoFar == amount) {
          // Done with pump.
          fulfiller.fulfill(kj::cp(pumpedSoFar));
          pipe.endState(*this);
        }

        if (actual == size) {
          return kj::READY_NOW;
        } else {
          KJ_ASSERT(pumpedSoFar == amount);
          return pipe.write(reinterpret_cast<const byte*>(writeBuffer) + actual, size - actual);
        }
      }, teeExceptionPromise<void>(fulfiller)));
    }

    Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
      // Gather-write variant: forwards whole pieces to `output` while they fit in the pump's
      // remaining budget, splitting the piece in which the budget runs out.
      KJ_REQUIRE(canceler.isEmpty(), "already pumping");

      size_t size = 0;
      // NOTE(review): `amount` is uint64_t, so this narrows where size_t is 32 bits -- confirm
      // whether that matters for callers.
      size_t needed = amount - pumpedSoFar;
      for (auto i: kj::indices(pieces)) {
        if (pieces[i].size() > needed) {
          // The pump ends in the middle of this write.

          auto promise = output.write(pieces.slice(0, i));

          if (needed > 0) {
            // The pump includes part of this piece, but not all. Unfortunately we need to split
            // writes.
            auto partial = pieces[i].slice(0, needed);
            promise = promise.then([this,partial]() {
              return output.write(partial.begin(), partial.size());
            });
            auto partial2 = pieces[i].slice(needed, pieces[i].size());
            promise = canceler.wrap(promise.then([this,partial2]() {
              canceler.release();
              fulfiller.fulfill(kj::cp(amount));
              pipe.endState(*this);
              return pipe.write(partial2.begin(), partial2.size());
            }, teeExceptionPromise<void>(fulfiller)));
            // The split piece has been fully handled; the remainder starts at the next piece.
            ++i;
          } else {
            // The pump ends exactly at the end of a piece, how nice.
            promise = canceler.wrap(promise.then([this]() {
              canceler.release();
              fulfiller.fulfill(kj::cp(amount));
              pipe.endState(*this);
            }, teeExceptionVoid(fulfiller)));
          }

          auto remainder = pieces.slice(i, pieces.size());
          if (remainder.size() > 0) {
            // Capture the pipe by reference rather than `this`: this state has been ended by the
            // time the continuation runs.
            auto& pipeRef = pipe;
            promise = promise.then([&pipeRef,remainder]() {
              return pipeRef.write(remainder);
            });
          }

          return promise;
        } else {
          size += pieces[i].size();
          needed -= pieces[i].size();
        }
      }

      // Turns out we can forward this whole write.
      KJ_ASSERT(size <= amount - pumpedSoFar);
      return canceler.wrap(output.write(pieces).then([this,size]() {
        pumpedSoFar += size;
        KJ_ASSERT(pumpedSoFar <= amount);
        if (pumpedSoFar == amount) {
          // Done pumping.
          canceler.release();
          fulfiller.fulfill(kj::cp(amount));
          pipe.endState(*this);
        }
      }, teeExceptionVoid(fulfiller)));
    }

    Promise<void> writeWithFds(ArrayPtr<const byte> data,
                               ArrayPtr<const ArrayPtr<const byte>> moreData,
                               ArrayPtr<const int> fds) override {
      // Pumps drop all capabilities, so fall back to regular write().

      // TODO(cleanup): After stream API refactor, regular write() methods will take
      //   (data, moreData) and we can clean this up.
1288 if (moreData.size() == 0) { 1289 return write(data.begin(), data.size()); 1290 } else { 1291 auto pieces = kj::heapArrayBuilder<const ArrayPtr<const byte>>(moreData.size() + 1); 1292 pieces.add(data); 1293 pieces.addAll(moreData); 1294 return write(pieces.finish()); 1295 } 1296 } 1297 1298 Promise<void> writeWithStreams(ArrayPtr<const byte> data, 1299 ArrayPtr<const ArrayPtr<const byte>> moreData, 1300 Array<Own<AsyncCapabilityStream>> streams) override { 1301 // Pumps drop all capabilities, so fall back to regular write(). 1302 1303 // TODO(cleaunp): After stream API refactor, regular write() methods will take 1304 // (data, moreData) and we can clean this up. 1305 if (moreData.size() == 0) { 1306 return write(data.begin(), data.size()); 1307 } else { 1308 auto pieces = kj::heapArrayBuilder<const ArrayPtr<const byte>>(moreData.size() + 1); 1309 pieces.add(data); 1310 pieces.addAll(moreData); 1311 return write(pieces.finish()); 1312 } 1313 } 1314 1315 Maybe<Promise<uint64_t>> tryPumpFrom(AsyncInputStream& input, uint64_t amount2) override { 1316 KJ_REQUIRE(canceler.isEmpty(), "already pumping"); 1317 1318 auto n = kj::min(amount2, amount - pumpedSoFar); 1319 return output.tryPumpFrom(input, n) 1320 .map([&](Promise<uint64_t> subPump) { 1321 return canceler.wrap(subPump 1322 .then([this,&input,amount2,n](uint64_t actual) -> Promise<uint64_t> { 1323 canceler.release(); 1324 pumpedSoFar += actual; 1325 KJ_ASSERT(pumpedSoFar <= amount); 1326 if (pumpedSoFar == amount) { 1327 fulfiller.fulfill(kj::cp(amount)); 1328 pipe.endState(*this); 1329 } 1330 1331 KJ_ASSERT(actual <= amount2); 1332 if (actual == amount2) { 1333 // Completed entire tryPumpFrom amount. 1334 return amount2; 1335 } else if (actual < n) { 1336 // Received less than requested, presumably because EOF. 1337 return actual; 1338 } else { 1339 // We received all the bytes that were requested but it didn't complete the pump. 
1340 KJ_ASSERT(pumpedSoFar == amount); 1341 return input.pumpTo(pipe, amount2 - actual); 1342 } 1343 }, teeExceptionPromise<uint64_t>(fulfiller))); 1344 }); 1345 } 1346 1347 void shutdownWrite() override { 1348 canceler.cancel("shutdownWrite() was called"); 1349 fulfiller.fulfill(kj::cp(pumpedSoFar)); 1350 pipe.endState(*this); 1351 pipe.shutdownWrite(); 1352 } 1353 1354 Promise<void> whenWriteDisconnected() override { 1355 KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe"); 1356 } 1357 1358 private: 1359 PromiseFulfiller<uint64_t>& fulfiller; 1360 AsyncPipe& pipe; 1361 AsyncOutputStream& output; 1362 uint64_t amount; 1363 size_t pumpedSoFar = 0; 1364 Canceler canceler; 1365 }; 1366 1367 class AbortedRead final: public AsyncCapabilityStream { 1368 // AsyncPipe state when abortRead() has been called. 1369 1370 public: 1371 Promise<size_t> tryRead(void* readBufferPtr, size_t minBytes, size_t maxBytes) override { 1372 return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); 1373 } 1374 Promise<ReadResult> tryReadWithFds(void* readBuffer, size_t minBytes, size_t maxBytes, 1375 AutoCloseFd* fdBuffer, size_t maxFds) override { 1376 return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); 1377 } 1378 Promise<ReadResult> tryReadWithStreams( 1379 void* readBuffer, size_t minBytes, size_t maxBytes, 1380 Own<AsyncCapabilityStream>* streamBuffer, size_t maxStreams) override { 1381 return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); 1382 } 1383 Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount) override { 1384 return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); 1385 } 1386 void abortRead() override { 1387 // ignore repeated abort 1388 } 1389 1390 Promise<void> write(const void* buffer, size_t size) override { 1391 return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"); 1392 } 1393 Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override { 1394 return KJ_EXCEPTION(DISCONNECTED, 
"abortRead() has been called");
    }
    Promise<void> writeWithFds(ArrayPtr<const byte> data,
                               ArrayPtr<const ArrayPtr<const byte>> moreData,
                               ArrayPtr<const int> fds) override {
      return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called");
    }
    Promise<void> writeWithStreams(ArrayPtr<const byte> data,
                                   ArrayPtr<const ArrayPtr<const byte>> moreData,
                                   Array<Own<AsyncCapabilityStream>> streams) override {
      return KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called");
    }
    Maybe<Promise<uint64_t>> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override {
      // There might not actually be any data in `input`, in which case a pump wouldn't actually
      // write anything and wouldn't fail.

      if (input.tryGetLength().orDefault(1) == 0) {
        // Yeah a pump would pump nothing.
        return Promise<uint64_t>(uint64_t(0));
      } else {
        // While we *could* just return nullptr here, it would probably then fall back to a normal
        // buffered pump, which would allocate a big old buffer just to find there's nothing to
        // read. Let's try reading 1 byte to avoid that allocation.
        // Scratch byte: its value is never read, so sharing one static across calls is harmless.
        static char c;
        return input.tryRead(&c, 1, 1).then([](size_t n) {
          if (n == 0) {
            // Yay, we're at EOF as hoped.
            return uint64_t(0);
          } else {
            // There was data in the input. The pump would have thrown.
            kj::throwRecoverableException(
                KJ_EXCEPTION(DISCONNECTED, "abortRead() has been called"));
            return uint64_t(0);
          }
        });
      }
    }
    void shutdownWrite() override {
      // ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped,
      // which is not an error even if reads have been aborted.
    }
    Promise<void> whenWriteDisconnected() override {
      KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
    }
  };

  class ShutdownedWrite final: public AsyncCapabilityStream {
    // AsyncPipe state when shutdownWrite() has been called.
    //
    // Every read-side operation reports EOF (zero bytes, zero capabilities); every write-side
    // operation is a usage error.

  public:
    Promise<size_t> tryRead(void* readBufferPtr, size_t minBytes, size_t maxBytes) override {
      return size_t(0);
    }
    Promise<ReadResult> tryReadWithFds(void* readBuffer, size_t minBytes, size_t maxBytes,
                                       AutoCloseFd* fdBuffer, size_t maxFds) override {
      return ReadResult { 0, 0 };
    }
    Promise<ReadResult> tryReadWithStreams(
        void* readBuffer, size_t minBytes, size_t maxBytes,
        Own<AsyncCapabilityStream>* streamBuffer, size_t maxStreams) override {
      return ReadResult { 0, 0 };
    }
    Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount) override {
      return uint64_t(0);
    }
    void abortRead() override {
      // ignore
    }

    Promise<void> write(const void* buffer, size_t size) override {
      KJ_FAIL_REQUIRE("shutdownWrite() has been called");
    }
    Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
      KJ_FAIL_REQUIRE("shutdownWrite() has been called");
    }
    Promise<void> writeWithFds(ArrayPtr<const byte> data,
                               ArrayPtr<const ArrayPtr<const byte>> moreData,
                               ArrayPtr<const int> fds) override {
      KJ_FAIL_REQUIRE("shutdownWrite() has been called");
    }
    Promise<void> writeWithStreams(ArrayPtr<const byte> data,
                                   ArrayPtr<const ArrayPtr<const byte>> moreData,
                                   Array<Own<AsyncCapabilityStream>> streams) override {
      KJ_FAIL_REQUIRE("shutdownWrite() has been called");
    }
    Maybe<Promise<uint64_t>> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override {
      KJ_FAIL_REQUIRE("shutdownWrite() has been called");
    }
    void shutdownWrite() override {
      // ignore -- currently shutdownWrite() actually means that the PipeWriteEnd was dropped,
      // so it will only be called once anyhow.
    }
    Promise<void> whenWriteDisconnected() override {
      KJ_FAIL_ASSERT("can't get here -- implemented by AsyncPipe");
    }
  };
};

class PipeReadEnd final: public AsyncInputStream {
  // Read end of a one-way pipe. Forwards reads to the shared AsyncPipe and aborts the pipe's read
  // side when destroyed.

public:
  PipeReadEnd(kj::Own<AsyncPipe> pipe): pipe(kj::mv(pipe)) {}
  ~PipeReadEnd() noexcept(false) {
    // Run the abort through the UnwindDetector, since it may throw and this destructor can run
    // during stack unwinding.
    unwind.catchExceptionsIfUnwinding([&]() {
      pipe->abortRead();
    });
  }

  Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    return pipe->tryRead(buffer, minBytes, maxBytes);
  }

  Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount) override {
    return pipe->pumpTo(output, amount);
  }

private:
  Own<AsyncPipe> pipe;
  UnwindDetector unwind;
};

class PipeWriteEnd final: public AsyncOutputStream {
  // Write end of a one-way pipe. Forwards writes to the shared AsyncPipe and shuts down the
  // pipe's write side when destroyed.

public:
  PipeWriteEnd(kj::Own<AsyncPipe> pipe): pipe(kj::mv(pipe)) {}
  ~PipeWriteEnd() noexcept(false) {
    unwind.catchExceptionsIfUnwinding([&]() {
      pipe->shutdownWrite();
    });
  }

  Promise<void> write(const void* buffer, size_t size) override {
    return pipe->write(buffer, size);
  }

  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    return pipe->write(pieces);
  }

  Maybe<Promise<uint64_t>> tryPumpFrom(
      AsyncInputStream& input, uint64_t amount) override {
    return pipe->tryPumpFrom(input, amount);
  }

  Promise<void> whenWriteDisconnected() override {
    return pipe->whenWriteDisconnected();
  }

private:
  Own<AsyncPipe> pipe;
  UnwindDetector unwind;
};

class TwoWayPipeEnd final: public AsyncCapabilityStream {
  // One end of a two-way pipe, built from two AsyncPipes: `in` is read from, `out` is written to.

public:
  TwoWayPipeEnd(kj::Own<AsyncPipe> in, kj::Own<AsyncPipe> out)
      : in(kj::mv(in)), out(kj::mv(out)) {}
  ~TwoWayPipeEnd() noexcept(false) {
unwind.catchExceptionsIfUnwinding([&]() {
      // Dropping this end shuts down our outgoing direction and aborts our incoming one.
      out->shutdownWrite();
      in->abortRead();
    });
  }

  // Read side: everything forwards to `in`.
  Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    return in->tryRead(buffer, minBytes, maxBytes);
  }
  Promise<ReadResult> tryReadWithFds(void* buffer, size_t minBytes, size_t maxBytes,
                                     AutoCloseFd* fdBuffer, size_t maxFds) override {
    return in->tryReadWithFds(buffer, minBytes, maxBytes, fdBuffer, maxFds);
  }
  Promise<ReadResult> tryReadWithStreams(
      void* buffer, size_t minBytes, size_t maxBytes,
      Own<AsyncCapabilityStream>* streamBuffer, size_t maxStreams) override {
    return in->tryReadWithStreams(buffer, minBytes, maxBytes, streamBuffer, maxStreams);
  }
  Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount) override {
    return in->pumpTo(output, amount);
  }
  void abortRead() override {
    in->abortRead();
  }

  // Write side: everything forwards to `out`.
  Promise<void> write(const void* buffer, size_t size) override {
    return out->write(buffer, size);
  }
  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    return out->write(pieces);
  }
  Promise<void> writeWithFds(ArrayPtr<const byte> data,
                             ArrayPtr<const ArrayPtr<const byte>> moreData,
                             ArrayPtr<const int> fds) override {
    return out->writeWithFds(data, moreData, fds);
  }
  Promise<void> writeWithStreams(ArrayPtr<const byte> data,
                                 ArrayPtr<const ArrayPtr<const byte>> moreData,
                                 Array<Own<AsyncCapabilityStream>> streams) override {
    return out->writeWithStreams(data, moreData, kj::mv(streams));
  }
  Maybe<Promise<uint64_t>> tryPumpFrom(
      AsyncInputStream& input, uint64_t amount) override {
    return out->tryPumpFrom(input, amount);
  }
  Promise<void> whenWriteDisconnected() override {
    return out->whenWriteDisconnected();
  }
  void shutdownWrite() override {
    out->shutdownWrite();
  }

private:
  kj::Own<AsyncPipe> in;
  kj::Own<AsyncPipe> out;
  UnwindDetector unwind;
};

class LimitedInputStream final: public AsyncInputStream {
  // Wraps an input stream and exposes at most `limit` bytes of it. Once the limit is used up the
  // inner stream is released and further reads report EOF. If the inner stream delivers less
  // than was requested before the limit is used up, a DISCONNECTED exception is raised.

public:
  LimitedInputStream(kj::Own<AsyncInputStream> inner, uint64_t limit)
      : inner(kj::mv(inner)), limit(limit) {
    if (limit == 0) {
      // Nothing will ever be read, so release the inner stream right away.
      this->inner = nullptr;
    }
  }

  Maybe<uint64_t> tryGetLength() override {
    // The remaining limit is reported as the stream's length.
    return limit;
  }

  Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    if (limit == 0) return size_t(0);
    // Clamp both bounds so that we never read past the limit.
    return inner->tryRead(buffer, kj::min(minBytes, limit), kj::min(maxBytes, limit))
        .then([this,minBytes](size_t actual) {
      decreaseLimit(actual, minBytes);
      return actual;
    });
  }

  Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount) override {
    if (limit == 0) return uint64_t(0);
    auto requested = kj::min(amount, limit);
    return inner->pumpTo(output, requested)
        .then([this,requested](uint64_t actual) {
      decreaseLimit(actual, requested);
      return actual;
    });
  }

private:
  Own<AsyncInputStream> inner;  // null once the limit has been used up
  uint64_t limit;               // bytes still permitted to be read

  void decreaseLimit(uint64_t amount, uint64_t requested) {
    // Accounts for `amount` bytes having been consumed. `requested` is the amount the caller
    // asked for; receiving less than that while limit remains is reported as a premature end.
    KJ_ASSERT(limit >= amount);
    limit -= amount;
    if (limit == 0) {
      inner = nullptr;
    } else if (amount < requested) {
      kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED,
          "fixed-length pipe ended prematurely"));
    }
  }
};

}  // namespace

OneWayPipe newOneWayPipe(kj::Maybe<uint64_t> expectedLength) {
  // Creates an in-process one-way pipe. Both ends share a single refcounted AsyncPipe. When
  // `expectedLength` is given, the read end is wrapped in a LimitedInputStream with that limit.
  auto impl = kj::refcounted<AsyncPipe>();
  Own<AsyncInputStream> readEnd = kj::heap<PipeReadEnd>(kj::addRef(*impl));
  KJ_IF_MAYBE(l, expectedLength) {
    readEnd = kj::heap<LimitedInputStream>(kj::mv(readEnd), *l);
  }
  Own<AsyncOutputStream> writeEnd = kj::heap<PipeWriteEnd>(kj::mv(impl));
  return { kj::mv(readEnd), kj::mv(writeEnd) };
}

TwoWayPipe newTwoWayPipe() {
  // Creates an in-process two-way pipe from two AsyncPipes, cross-connected so that each end
  // reads from the pipe the other end writes to.
  auto pipe1 = kj::refcounted<AsyncPipe>();
  auto pipe2 = kj::refcounted<AsyncPipe>();
  auto end1 = kj::heap<TwoWayPipeEnd>(kj::addRef(*pipe1), kj::addRef(*pipe2));
  auto end2 = kj::heap<TwoWayPipeEnd>(kj::mv(pipe2), kj::mv(pipe1));
  return { { kj::mv(end1), kj::mv(end2) } };
}

CapabilityPipe newCapabilityPipe() {
  // Same construction as newTwoWayPipe(), returned as a CapabilityPipe.
  auto pipe1 = kj::refcounted<AsyncPipe>();
  auto pipe2 = kj::refcounted<AsyncPipe>();
  auto end1 = kj::heap<TwoWayPipeEnd>(kj::addRef(*pipe1), kj::addRef(*pipe2));
  auto end2 = kj::heap<TwoWayPipeEnd>(kj::mv(pipe2), kj::mv(pipe1));
  return { { kj::mv(end1), kj::mv(end2) } };
}

namespace {

class AsyncTee final: public Refcounted {
  class Buffer {
    // Queue of byte arrays held for one branch of the tee.

  public:
    Buffer() = default;

    uint64_t consume(ArrayPtr<byte>& readBuffer, size_t& minBytes);
    // Consume as many bytes as possible, copying them into `readBuffer`. Return the number of bytes
    // consumed.
    //
    // `readBuffer` and `minBytes` are both assigned appropriate new values, such that after any
    // call to `consume()`, `readBuffer` will point to the remaining slice of unwritten space, and
    // `minBytes` will have been decremented (clamped to zero) by the amount of bytes read. That is,
    // the read can be considered fulfilled if `minBytes` is zero after a call to `consume()`.

    Array<const ArrayPtr<const byte>> asArray(uint64_t minBytes, uint64_t& amount);
    // Consume the first `minBytes` of the buffer (or the entire buffer) and return it in an Array
    // of ArrayPtr<const byte>s, suitable for passing to AsyncOutputStream.write(). The outer Array
    // owns the underlying data.

    void produce(Array<byte> bytes);
    // Enqueue a byte array to the end of the buffer list.

    bool empty() const;
    uint64_t size() const;

    Buffer clone() const {
      // Deep-copies the buffered bytes, coalescing all queued arrays into one contiguous array.
      size_t size = 0;
      for (const auto& buf: bufferList) {
        size += buf.size();
      }
      auto builder = heapArrayBuilder<byte>(size);
      for (const auto& buf: bufferList) {
        builder.addAll(buf);
      }
      std::deque<Array<byte>> deque;
      deque.emplace_back(builder.finish());
      return Buffer{mv(deque)};
    }

  private:
    Buffer(std::deque<Array<byte>>&& buffer) : bufferList(mv(buffer)) {}

    std::deque<Array<byte>> bufferList;
  };

  class Sink;

public:
  using BranchId = uint;

  struct Branch {
    Buffer buffer;     // data held for this branch
    Maybe<Sink&> sink; // the branch's in-flight read or pump, if any
  };

  explicit AsyncTee(Own<AsyncInputStream> inner, uint64_t bufferSizeLimit)
      : inner(mv(inner)), bufferSizeLimit(bufferSizeLimit), length(this->inner->tryGetLength()) {}
  ~AsyncTee() noexcept(false) {
    bool hasBranches = false;
    for (auto& branch: branches) {
      hasBranches = hasBranches || branch != nullptr;
    }
    KJ_ASSERT(!hasBranches, "destroying AsyncTee with branch still alive") {
      // Don't std::terminate().
      break;
    }
  }

  BranchId addBranch() {
    return addBranch(Branch());
  }

  BranchId addBranch(Branch&& branch) {
    // The new branch's ID is its index in `branches`.
    BranchId branchId = branches.size();
    branches.add(mv(branch));
    return branchId;
  }

  Branch cloneBranch(BranchId branchId) const {
    // Returns a new Branch holding a copy of the given branch's buffered data and no sink.
    const auto& state = KJ_ASSERT_NONNULL(branches[branchId]);
    return {state.buffer.clone(), nullptr};
  }

  void removeBranch(BranchId branch) {
    auto& state = KJ_REQUIRE_NONNULL(branches[branch], "branch was already destroyed");
    KJ_REQUIRE(state.sink == nullptr,
        "destroying tee branch with operation still in-progress; probably going to segfault") {
      // Don't std::terminate().
      break;
    }

    // The slot is nulled rather than erased, so other branches' IDs (indices) stay valid.
    branches[branch] = nullptr;
  }

  Promise<size_t> tryRead(BranchId branch, void* buffer, size_t minBytes, size_t maxBytes) {
    auto& state = KJ_ASSERT_NONNULL(branches[branch]);
    KJ_ASSERT(state.sink == nullptr);

    // If there is excess data in the buffer for us, slurp that up.
    auto readBuffer = arrayPtr(reinterpret_cast<byte*>(buffer), maxBytes);
    auto readSoFar = state.buffer.consume(readBuffer, minBytes);

    if (minBytes == 0) {
      return readSoFar;
    }

    if (state.buffer.empty()) {
      KJ_IF_MAYBE(reason, stoppage) {
        // Prefer a short read to an exception. The exception prevents the pull loop from adding any
        // data to the buffer, so `readSoFar` will be zero the next time someone calls `tryRead()`,
        // and the caller will see the exception.
        if (reason->is<Eof>() || readSoFar > 0) {
          return readSoFar;
        }
        return cp(reason->get<Exception>());
      }
    }

    // Not enough buffered data: register a sink for the rest and make sure the pull loop runs.
    auto promise = newAdaptedPromise<size_t, ReadSink>(state.sink, readBuffer, minBytes, readSoFar);
    ensurePulling();
    return mv(promise);
  }

  Maybe<uint64_t> tryGetLength(BranchId branch) {
    auto& state = KJ_ASSERT_NONNULL(branches[branch]);

    // When `length` is known, add the bytes currently buffered for this branch.
    return length.map([&state](uint64_t amount) {
      return amount + state.buffer.size();
    });
  }

  uint64_t getBufferSizeLimit() const {
    return bufferSizeLimit;
  }

  Promise<uint64_t> pumpTo(BranchId branch, AsyncOutputStream& output, uint64_t amount) {
    auto& state = KJ_ASSERT_NONNULL(branches[branch]);
    KJ_ASSERT(state.sink == nullptr);

    if (amount == 0) {
      return amount;
    }

    if (state.buffer.empty()) {
      KJ_IF_MAYBE(reason, stoppage) {
        if (reason->is<Eof>()) {
          return uint64_t(0);
        }
        return cp(reason->get<Exception>());
      }
    }

    auto promise = newAdaptedPromise<uint64_t, PumpSink>(state.sink, output, amount);
private:
  struct Eof {};
  // Marker stored in a Stoppage to record that the inner stream ended (short read).

  using Stoppage = OneOf<Eof, Exception>;
  // Why the pull loop stopped reading from the inner stream: clean EOF, or an exception.

  class Sink {
    // Interface the pull loop uses to deliver buffered data to one outstanding read or pump on a
    // branch.

  public:
    struct Need {
      // We use uint64_t here because:
      // - pumpTo() accepts it as the `amount` parameter.
      // - all practical values of tryRead()'s `maxBytes` parameter (a size_t) should also fit into
      //   a uint64_t, unless we're on a machine with multiple exabytes of memory ...

      uint64_t minBytes = 0;

      uint64_t maxBytes = kj::maxValue;
    };

    virtual Promise<void> fill(Buffer& inBuffer, const Maybe<Stoppage>& stoppage) = 0;
    // Attempt to fill the sink with bytes and return a promise which must resolve before any inner
    // read may be attempted. If a sink requires backpressure to be respected, this is how it should
    // be communicated.
    //
    // If the sink is full, it must detach from the tee before the returned promise is resolved.
    //
    // The returned promise must not result in an exception.

    virtual Need need() = 0;
    // Report how many more bytes this sink requires (`minBytes`) and can accept (`maxBytes`).

    virtual void reject(Exception&& exception) = 0;
    // Inform this sink of a catastrophic exception and detach it. Regular read exceptions should be
    // propagated through `fill()`'s stoppage parameter instead.
  };

  template <typename T>
  class SinkBase: public Sink {
    // Registers itself with the tee as a sink on construction, detaches from the tee on
    // fulfillment, rejection, or destruction.
    //
    // A bit of a Frankenstein, avert your eyes. For one thing, it's more of a mixin than a base...

  public:
    explicit SinkBase(PromiseFulfiller<T>& fulfiller, Maybe<Sink&>& sinkLink)
        : fulfiller(fulfiller), sinkLink(sinkLink) {
      // `sinkLink` is the branch's sink slot; it must be empty, and is pointed at this object.
      KJ_ASSERT(sinkLink == nullptr, "sink initiated with sink already in flight");
      sinkLink = *this;
    }
    KJ_DISALLOW_COPY(SinkBase);
    ~SinkBase() noexcept(false) { detach(); }

    void reject(Exception&& exception) override {
      // The tee is allowed to reject this sink if it needs to, e.g. to propagate a non-inner read
      // exception from the pull loop. Only the derived class is allowed to fulfill() directly,
      // though -- the tee must keep calling fill().

      fulfiller.reject(mv(exception));
      detach();
    }

  protected:
    template <typename U>
    void fulfill(U value) {
      // Resolve the outer promise with `value`, then unregister from the branch.
      fulfiller.fulfill(fwd<U>(value));
      detach();
    }

  private:
    void detach() {
      // Clear the branch's sink slot, but only if it still refers to this object.
      KJ_IF_MAYBE(sink, sinkLink) {
        if (sink == this) {
          sinkLink = nullptr;
        }
      }
    }

    PromiseFulfiller<T>& fulfiller;
    // Fulfiller of the promise handed back to the caller of tryRead()/pumpTo().

    Maybe<Sink&>& sinkLink;
    // Reference to the owning branch's sink slot.
  };
1882 1883 public: 1884 explicit SinkBase(PromiseFulfiller<T>& fulfiller, Maybe<Sink&>& sinkLink) 1885 : fulfiller(fulfiller), sinkLink(sinkLink) { 1886 KJ_ASSERT(sinkLink == nullptr, "sink initiated with sink already in flight"); 1887 sinkLink = *this; 1888 } 1889 KJ_DISALLOW_COPY(SinkBase); 1890 ~SinkBase() noexcept(false) { detach(); } 1891 1892 void reject(Exception&& exception) override { 1893 // The tee is allowed to reject this sink if it needs to, e.g. to propagate a non-inner read 1894 // exception from the pull loop. Only the derived class is allowed to fulfill() directly, 1895 // though -- the tee must keep calling fill(). 1896 1897 fulfiller.reject(mv(exception)); 1898 detach(); 1899 } 1900 1901 protected: 1902 template <typename U> 1903 void fulfill(U value) { 1904 fulfiller.fulfill(fwd<U>(value)); 1905 detach(); 1906 } 1907 1908 private: 1909 void detach() { 1910 KJ_IF_MAYBE(sink, sinkLink) { 1911 if (sink == this) { 1912 sinkLink = nullptr; 1913 } 1914 } 1915 } 1916 1917 PromiseFulfiller<T>& fulfiller; 1918 Maybe<Sink&>& sinkLink; 1919 }; 1920 1921 class ReadSink final: public SinkBase<size_t> { 1922 public: 1923 explicit ReadSink(PromiseFulfiller<size_t>& fulfiller, Maybe<Sink&>& registration, 1924 ArrayPtr<byte> buffer, size_t minBytes, size_t readSoFar) 1925 : SinkBase(fulfiller, registration), buffer(buffer), 1926 minBytes(minBytes), readSoFar(readSoFar) {} 1927 1928 Promise<void> fill(Buffer& inBuffer, const Maybe<Stoppage>& stoppage) override { 1929 auto amount = inBuffer.consume(buffer, minBytes); 1930 readSoFar += amount; 1931 1932 if (minBytes == 0) { 1933 // We satisfied the read request. 1934 fulfill(readSoFar); 1935 return READY_NOW; 1936 } 1937 1938 if (amount == 0 && inBuffer.empty()) { 1939 // We made no progress on the read request and the buffer is tapped out. 1940 KJ_IF_MAYBE(reason, stoppage) { 1941 if (reason->is<Eof>() || readSoFar > 0) { 1942 // Prefer short read to exception. 
class PumpSink final: public SinkBase<uint64_t> {
  // Sink for one outstanding pumpTo() on a branch: each fill() writes the branch's buffered data
  // to `output`, until `limit` bytes have been pumped or the inner stream has stopped.

public:
  explicit PumpSink(PromiseFulfiller<uint64_t>& fulfiller, Maybe<Sink&>& registration,
                    AsyncOutputStream& output, uint64_t limit)
      : SinkBase(fulfiller, registration), output(output), limit(limit) {}

  ~PumpSink() noexcept(false) {
    canceler.cancel("This pump has been canceled.");
  }

  Promise<void> fill(Buffer& inBuffer, const Maybe<Stoppage>& stoppage) override {
    KJ_ASSERT(limit > 0);

    uint64_t amount = 0;

    // TODO(someday): This consumes data from the buffer, but we cannot know if the stream to
    //   which we're pumping will accept it until after the write() promise completes. If the
    //   write() promise rejects, we lose this data. We should consume the data from the buffer
    //   only after successful writes.
    auto writeBuffer = inBuffer.asArray(limit, amount);
    KJ_ASSERT(limit >= amount);
    if (amount > 0) {
      // evalNow() turns a synchronous throw from write() into a rejected promise, so it reaches
      // the eagerlyEvaluate() error handler below.
      Promise<void> promise = kj::evalNow([&]() {
        return output.write(writeBuffer).attach(mv(writeBuffer));
      }).then([this, amount]() {
        limit -= amount;
        pumpedSoFar += amount;
        if (limit == 0) {
          fulfill(pumpedSoFar);
        }
      }).eagerlyEvaluate([this](Exception&& exception) {
        reject(mv(exception));
      });

      // The returned promise is the backpressure signal; the trailing catch_() discards any
      // exception so the promise handed back to the pull loop does not reject.
      return canceler.wrap(mv(promise)).catch_([](kj::Exception&&) {});
    } else KJ_IF_MAYBE(reason, stoppage) {
      if (reason->is<Eof>()) {
        // Unlike in the read case, it makes more sense to immediately propagate exceptions to the
        // pump promise rather than show it a "short pump".
        fulfill(pumpedSoFar);
      } else {
        reject(cp(reason->get<Exception>()));
      }
    }

    return READY_NOW;
  }

  Need need() override { return Need { 1, limit }; }

private:
  AsyncOutputStream& output;
  uint64_t limit;
  // Arguments to the outer pumpTo() call, decremented after every buffer consumption.
  //
  // Equal to zero once fulfiller has been fulfilled/rejected.

  uint64_t pumpedSoFar = 0;
  // End result of the outer pumpTo().

  Canceler canceler;
  // When the pump is canceled, we also need to cancel any write operations in flight.
};
1984 auto writeBuffer = inBuffer.asArray(limit, amount); 1985 KJ_ASSERT(limit >= amount); 1986 if (amount > 0) { 1987 Promise<void> promise = kj::evalNow([&]() { 1988 return output.write(writeBuffer).attach(mv(writeBuffer)); 1989 }).then([this, amount]() { 1990 limit -= amount; 1991 pumpedSoFar += amount; 1992 if (limit == 0) { 1993 fulfill(pumpedSoFar); 1994 } 1995 }).eagerlyEvaluate([this](Exception&& exception) { 1996 reject(mv(exception)); 1997 }); 1998 1999 return canceler.wrap(mv(promise)).catch_([](kj::Exception&&) {}); 2000 } else KJ_IF_MAYBE(reason, stoppage) { 2001 if (reason->is<Eof>()) { 2002 // Unlike in the read case, it makes more sense to immediately propagate exceptions to the 2003 // pump promise rather than show it a "short pump". 2004 fulfill(pumpedSoFar); 2005 } else { 2006 reject(cp(reason->get<Exception>())); 2007 } 2008 } 2009 2010 return READY_NOW; 2011 } 2012 2013 Need need() override { return Need { 1, limit }; } 2014 2015 private: 2016 AsyncOutputStream& output; 2017 uint64_t limit; 2018 // Arguments to the outer pumpTo() call, decremented after every buffer consumption. 2019 // 2020 // Equal to zero once fulfiller has been fulfilled/rejected. 2021 2022 uint64_t pumpedSoFar = 0; 2023 // End result of the outer pumpTo(). 2024 2025 Canceler canceler; 2026 // When the pump is canceled, we also need to cancel any write operations in flight. 2027 }; 2028 2029 // ===================================================================================== 2030 2031 Maybe<Sink::Need> analyzeSinks() { 2032 // Return nullptr if there are no sinks at all. Otherwise, return the largest `minBytes` and the 2033 // smallest `maxBytes` requested by any sink. The pull loop will use these values to calculate 2034 // the optimal buffer size for the next inner read, so that a minimum amount of data is buffered 2035 // at any given time. 
void ensurePulling() {
  // Start the pull loop unless it is already running.
  if (!pulling) {
    pulling = true;
    UnwindDetector unwind;
    // If pull() throws synchronously, reset the flag so a later call can try again.
    KJ_DEFER(if (unwind.isUnwinding()) pulling = false);
    pullPromise = pull();
  }
}

Promise<void> pull() {
  // Run pullLoop() eagerly; if the loop itself fails, reject every sink that is still attached.
  return pullLoop().eagerlyEvaluate([this](Exception&& exception) {
    // Exception from our loop, not from inner tryRead(). Something is broken; tell everybody!
    pulling = false;
    for (auto& state: branches) {
      KJ_IF_MAYBE(s, state) {
        KJ_IF_MAYBE(sink, s->sink) {
          sink->reject(KJ_EXCEPTION(FAILED, "Exception in tee loop", exception));
        }
      }
    }
  });
}

constexpr static size_t MAX_BLOCK_SIZE = 1 << 14;  // 16k
// Upper bound pullLoop() applies to the size of a single inner read (unless a sink's minimum
// is larger).

Own<AsyncInputStream> inner;
// The stream being teed; pullLoop() reads from it.

const uint64_t bufferSizeLimit = kj::maxValue;
// Cap on the bytes buffered per branch; pullLoop() stops with an exception rather than exceed it.

Maybe<uint64_t> length;
// Bytes remaining in `inner`, if known; pullLoop() decrements it by each read's size.
// NOTE(review): presumably initialized by the constructor, which is outside this view -- confirm.

Vector<Maybe<Branch>> branches;
// Branch slots, indexed by BranchId. An entry may be null.

Maybe<Stoppage> stoppage;
// Set by pullLoop() once inner reads stop: on a short read (Eof) or on an exception.

Promise<void> pullPromise = READY_NOW;
// Holds the running pull loop, assigned by ensurePulling().

bool pulling = false;
// True while the pull loop is considered active.
private:
  Promise<void> pullLoop() {
    // One iteration of the tee's pump:
    //   1. give every attached sink a chance to drain its branch buffer (fill());
    //   2. once all fill() promises resolve, ask the sinks how much more data they need;
    //   3. if data is needed and the inner stream has not stopped, perform one inner read, append
    //      the result to every branch buffer, and loop.
    // The loop ends (clearing `pulling`) as soon as no sinks remain.

    // Use evalLater() so that two pump sinks added on the same turn of the event loop will not
    // cause buffering.
    return evalLater([this] {
      // Attempt to fill any sinks that exist.

      Vector<Promise<void>> promises;

      for (auto& state: branches) {
        KJ_IF_MAYBE(s, state) {
          KJ_IF_MAYBE(sink, s->sink) {
            promises.add(sink->fill(s->buffer, stoppage));
          }
        }
      }

      // Respect the greatest of the sinks' backpressures.
      return joinPromises(promises.releaseAsArray());
    }).then([this]() -> Promise<void> {
      // Check to see whether we need to perform an inner read.

      auto need = analyzeSinks();

      if (need == nullptr) {
        // No more sinks, stop pulling.
        pulling = false;
        return READY_NOW;
      }

      if (stoppage != nullptr) {
        // We're eof or errored, don't read, but loop so we can fill the sink(s).
        return pullLoop();
      }

      auto& n = KJ_ASSERT_NONNULL(need);

      KJ_ASSERT(n.minBytes > 0);

      // We must perform an inner read.

      // We'd prefer not to explode our buffer, if that's cool. We cap `maxBytes` to the buffer size
      // limit or our builtin MAX_BLOCK_SIZE, whichever is smaller. But, we make sure `maxBytes` is
      // still >= `minBytes`.
      n.maxBytes = kj::min(n.maxBytes, MAX_BLOCK_SIZE);
      n.maxBytes = kj::min(n.maxBytes, bufferSizeLimit);
      n.maxBytes = kj::max(n.minBytes, n.maxBytes);
      for (auto& state: branches) {
        KJ_IF_MAYBE(s, state) {
          // TODO(perf): buffer.size() is O(n) where n = # of individual heap-allocated byte arrays.
          if (s->buffer.size() + n.maxBytes > bufferSizeLimit) {
            // Reading would push this branch over the limit: record the failure as the stoppage
            // and loop so the sinks observe it via fill().
            stoppage = Stoppage(KJ_EXCEPTION(FAILED, "tee buffer size limit exceeded"));
            return pullLoop();
          }
        }
      }
      auto heapBuffer = heapArray<byte>(n.maxBytes);

      // gcc 4.9 quirk: If I don't hoist this into a separate variable and instead call
      //
      //     inner->tryRead(heapBuffer.begin(), n.minBytes, heapBuffer.size())
      //
      // `heapBuffer` seems to get moved into the lambda capture before the arguments to `tryRead()`
      // are evaluated, meaning `inner` sees a nullptr destination. Bizarrely, `inner` sees the
      // correct value for `heapBuffer.size()`... I dunno, man.
      auto destination = heapBuffer.begin();

      return kj::evalNow([&]() { return inner->tryRead(destination, n.minBytes, n.maxBytes); })
          .then([this, heapBuffer = mv(heapBuffer), minBytes = n.minBytes](size_t amount) mutable
              -> Promise<void> {
        // Keep the remaining-length estimate in step with what was just read.
        length = length.map([amount](uint64_t n) {
          KJ_ASSERT(n >= amount);
          return n - amount;
        });

        if (amount < heapBuffer.size()) {
          // Trim the array to the bytes actually read, keeping the original allocation attached.
          heapBuffer = heapBuffer.slice(0, amount).attach(mv(heapBuffer));
        }

        KJ_ASSERT(stoppage == nullptr);
        Maybe<ArrayPtr<byte>> bufferPtr = nullptr;
        for (auto& state: branches) {
          KJ_IF_MAYBE(s, state) {
            // Prefer to move the buffer into the receiving branch's deque, rather than memcpy.
            //
            // TODO(perf): For the 2-branch case, this is fine, since the majority of the time
            //   only one buffer will be in use. If we generalize to the n-branch case, this would
            //   become memcpy-heavy.
            KJ_IF_MAYBE(ptr, bufferPtr) {
              s->buffer.produce(heapArray(*ptr));
            } else {
              bufferPtr = ArrayPtr<byte>(heapBuffer);
              s->buffer.produce(mv(heapBuffer));
            }
          }
        }

        if (amount < minBytes) {
          // Short read, EOF.
          stoppage = Stoppage(Eof());
        }

        return pullLoop();
      }, [this](Exception&& exception) {
        // Exception from the inner tryRead(). Propagate.
        stoppage = Stoppage(mv(exception));
        return pullLoop();
      });
    });
  }
2201 stoppage = Stoppage(Eof()); 2202 } 2203 2204 return pullLoop(); 2205 }, [this](Exception&& exception) { 2206 // Exception from the inner tryRead(). Propagate. 2207 stoppage = Stoppage(mv(exception)); 2208 return pullLoop(); 2209 }); 2210 }); 2211 } 2212 }; 2213 2214 constexpr size_t AsyncTee::MAX_BLOCK_SIZE; 2215 2216 uint64_t AsyncTee::Buffer::consume(ArrayPtr<byte>& readBuffer, size_t& minBytes) { 2217 uint64_t totalAmount = 0; 2218 2219 while (readBuffer.size() > 0 && !bufferList.empty()) { 2220 auto& bytes = bufferList.front(); 2221 auto amount = kj::min(bytes.size(), readBuffer.size()); 2222 memcpy(readBuffer.begin(), bytes.begin(), amount); 2223 totalAmount += amount; 2224 2225 readBuffer = readBuffer.slice(amount, readBuffer.size()); 2226 minBytes -= kj::min(amount, minBytes); 2227 2228 if (amount == bytes.size()) { 2229 bufferList.pop_front(); 2230 } else { 2231 bytes = heapArray(bytes.slice(amount, bytes.size())); 2232 return totalAmount; 2233 } 2234 } 2235 2236 return totalAmount; 2237 } 2238 2239 void AsyncTee::Buffer::produce(Array<byte> bytes) { 2240 bufferList.push_back(mv(bytes)); 2241 } 2242 2243 Array<const ArrayPtr<const byte>> AsyncTee::Buffer::asArray( 2244 uint64_t maxBytes, uint64_t& amount) { 2245 amount = 0; 2246 2247 Vector<ArrayPtr<const byte>> buffers; 2248 Vector<Array<byte>> ownBuffers; 2249 2250 while (maxBytes > 0 && !bufferList.empty()) { 2251 auto& bytes = bufferList.front(); 2252 2253 if (bytes.size() <= maxBytes) { 2254 amount += bytes.size(); 2255 maxBytes -= bytes.size(); 2256 2257 buffers.add(bytes); 2258 ownBuffers.add(mv(bytes)); 2259 2260 bufferList.pop_front(); 2261 } else { 2262 auto ownBytes = heapArray(bytes.slice(0, maxBytes)); 2263 buffers.add(ownBytes); 2264 ownBuffers.add(mv(ownBytes)); 2265 2266 bytes = heapArray(bytes.slice(maxBytes, bytes.size())); 2267 2268 amount += maxBytes; 2269 maxBytes = 0; 2270 } 2271 } 2272 2273 2274 if (buffers.size() > 0) { 2275 return 
buffers.releaseAsArray().attach(mv(ownBuffers)); 2276 } 2277 2278 return {}; 2279 } 2280 2281 bool AsyncTee::Buffer::empty() const { 2282 return bufferList.empty(); 2283 } 2284 2285 uint64_t AsyncTee::Buffer::size() const { 2286 uint64_t result = 0; 2287 2288 for (auto& bytes: bufferList) { 2289 result += bytes.size(); 2290 } 2291 2292 return result; 2293 } 2294 2295 class TeeBranch final: public AsyncInputStream { 2296 public: 2297 TeeBranch(Own<AsyncTee> teeArg): tee(mv(teeArg)), branch(tee->addBranch()) {} 2298 2299 TeeBranch(Badge<TeeBranch>, Own<AsyncTee> teeArg, AsyncTee::Branch&& branchState) 2300 : tee(mv(teeArg)), branch(tee->addBranch(mv(branchState))) {} 2301 2302 ~TeeBranch() noexcept(false) { 2303 unwind.catchExceptionsIfUnwinding([&]() { 2304 tee->removeBranch(branch); 2305 }); 2306 } 2307 2308 Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { 2309 return tee->tryRead(branch, buffer, minBytes, maxBytes); 2310 } 2311 2312 Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount) override { 2313 return tee->pumpTo(branch, output, amount); 2314 } 2315 2316 Maybe<uint64_t> tryGetLength() override { 2317 return tee->tryGetLength(branch); 2318 } 2319 2320 Maybe<Own<AsyncInputStream>> tryTee(uint64_t limit) override { 2321 if (tee->getBufferSizeLimit() != limit) { 2322 // Cannot optimize this path as the limit has changed, so we need a new AsyncTee to manage 2323 // the limit. 
2324 return nullptr; 2325 } 2326 2327 return kj::heap<TeeBranch>(Badge<TeeBranch>{}, addRef(*tee), tee->cloneBranch(branch)); 2328 } 2329 2330 private: 2331 Own<AsyncTee> tee; 2332 const uint branch; 2333 UnwindDetector unwind; 2334 }; 2335 2336 } // namespace 2337 2338 Tee newTee(Own<AsyncInputStream> input, uint64_t limit) { 2339 KJ_IF_MAYBE(t, input->tryTee()) { 2340 return { { mv(input), mv(*t) }}; 2341 } 2342 2343 auto impl = refcounted<AsyncTee>(mv(input), limit); 2344 Own<AsyncInputStream> branch1 = heap<TeeBranch>(addRef(*impl)); 2345 Own<AsyncInputStream> branch2 = heap<TeeBranch>(mv(impl)); 2346 return { { mv(branch1), mv(branch2) } }; 2347 } 2348 2349 namespace { 2350 2351 class PromisedAsyncIoStream final: public kj::AsyncIoStream, private kj::TaskSet::ErrorHandler { 2352 // An AsyncIoStream which waits for a promise to resolve then forwards all calls to the promised 2353 // stream. 2354 2355 public: 2356 PromisedAsyncIoStream(kj::Promise<kj::Own<AsyncIoStream>> promise) 2357 : promise(promise.then([this](kj::Own<AsyncIoStream> result) { 2358 stream = kj::mv(result); 2359 }).fork()), 2360 tasks(*this) {} 2361 2362 kj::Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override { 2363 KJ_IF_MAYBE(s, stream) { 2364 return s->get()->read(buffer, minBytes, maxBytes); 2365 } else { 2366 return promise.addBranch().then([this,buffer,minBytes,maxBytes]() { 2367 return KJ_ASSERT_NONNULL(stream)->read(buffer, minBytes, maxBytes); 2368 }); 2369 } 2370 } 2371 kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { 2372 KJ_IF_MAYBE(s, stream) { 2373 return s->get()->tryRead(buffer, minBytes, maxBytes); 2374 } else { 2375 return promise.addBranch().then([this,buffer,minBytes,maxBytes]() { 2376 return KJ_ASSERT_NONNULL(stream)->tryRead(buffer, minBytes, maxBytes); 2377 }); 2378 } 2379 } 2380 2381 kj::Maybe<uint64_t> tryGetLength() override { 2382 KJ_IF_MAYBE(s, stream) { 2383 return s->get()->tryGetLength(); 2384 } else 
{ 2385 return nullptr; 2386 } 2387 } 2388 2389 kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { 2390 KJ_IF_MAYBE(s, stream) { 2391 return s->get()->pumpTo(output, amount); 2392 } else { 2393 return promise.addBranch().then([this,&output,amount]() { 2394 return KJ_ASSERT_NONNULL(stream)->pumpTo(output, amount); 2395 }); 2396 } 2397 } 2398 2399 kj::Promise<void> write(const void* buffer, size_t size) override { 2400 KJ_IF_MAYBE(s, stream) { 2401 return s->get()->write(buffer, size); 2402 } else { 2403 return promise.addBranch().then([this,buffer,size]() { 2404 return KJ_ASSERT_NONNULL(stream)->write(buffer, size); 2405 }); 2406 } 2407 } 2408 kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override { 2409 KJ_IF_MAYBE(s, stream) { 2410 return s->get()->write(pieces); 2411 } else { 2412 return promise.addBranch().then([this,pieces]() { 2413 return KJ_ASSERT_NONNULL(stream)->write(pieces); 2414 }); 2415 } 2416 } 2417 2418 kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom( 2419 kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { 2420 KJ_IF_MAYBE(s, stream) { 2421 // Call input.pumpTo() on the resolved stream instead, so that if it does some dynamic_casts 2422 // or whatnot to detect stream types it can retry those on the inner stream. 2423 return input.pumpTo(**s, amount); 2424 } else { 2425 return promise.addBranch().then([this,&input,amount]() { 2426 // Here we actually have no choice but to call input.pumpTo() because if we called 2427 // tryPumpFrom(input, amount) and it returned nullptr, what would we do? It's too late for 2428 // us to return nullptr. But the thing about dynamic_cast also applies. 
2429 return input.pumpTo(*KJ_ASSERT_NONNULL(stream), amount); 2430 }); 2431 } 2432 } 2433 2434 Promise<void> whenWriteDisconnected() override { 2435 KJ_IF_MAYBE(s, stream) { 2436 return s->get()->whenWriteDisconnected(); 2437 } else { 2438 return promise.addBranch().then([this]() { 2439 return KJ_ASSERT_NONNULL(stream)->whenWriteDisconnected(); 2440 }, [](kj::Exception&& e) -> kj::Promise<void> { 2441 if (e.getType() == kj::Exception::Type::DISCONNECTED) { 2442 return kj::READY_NOW; 2443 } else { 2444 return kj::mv(e); 2445 } 2446 }); 2447 } 2448 } 2449 2450 void shutdownWrite() override { 2451 KJ_IF_MAYBE(s, stream) { 2452 return s->get()->shutdownWrite(); 2453 } else { 2454 tasks.add(promise.addBranch().then([this]() { 2455 return KJ_ASSERT_NONNULL(stream)->shutdownWrite(); 2456 })); 2457 } 2458 } 2459 2460 void abortRead() override { 2461 KJ_IF_MAYBE(s, stream) { 2462 return s->get()->abortRead(); 2463 } else { 2464 tasks.add(promise.addBranch().then([this]() { 2465 return KJ_ASSERT_NONNULL(stream)->abortRead(); 2466 })); 2467 } 2468 } 2469 2470 kj::Maybe<int> getFd() const override { 2471 KJ_IF_MAYBE(s, stream) { 2472 return s->get()->getFd(); 2473 } else { 2474 return nullptr; 2475 } 2476 } 2477 2478 private: 2479 kj::ForkedPromise<void> promise; 2480 kj::Maybe<kj::Own<AsyncIoStream>> stream; 2481 kj::TaskSet tasks; 2482 2483 void taskFailed(kj::Exception&& exception) override { 2484 KJ_LOG(ERROR, exception); 2485 } 2486 }; 2487 2488 class PromisedAsyncOutputStream final: public kj::AsyncOutputStream { 2489 // An AsyncOutputStream which waits for a promise to resolve then forwards all calls to the 2490 // promised stream. 2491 // 2492 // TODO(cleanup): Can this share implementation with PromiseIoStream? Seems hard. 
2493 2494 public: 2495 PromisedAsyncOutputStream(kj::Promise<kj::Own<AsyncOutputStream>> promise) 2496 : promise(promise.then([this](kj::Own<AsyncOutputStream> result) { 2497 stream = kj::mv(result); 2498 }).fork()) {} 2499 2500 kj::Promise<void> write(const void* buffer, size_t size) override { 2501 KJ_IF_MAYBE(s, stream) { 2502 return s->get()->write(buffer, size); 2503 } else { 2504 return promise.addBranch().then([this,buffer,size]() { 2505 return KJ_ASSERT_NONNULL(stream)->write(buffer, size); 2506 }); 2507 } 2508 } 2509 kj::Promise<void> write(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) override { 2510 KJ_IF_MAYBE(s, stream) { 2511 return s->get()->write(pieces); 2512 } else { 2513 return promise.addBranch().then([this,pieces]() { 2514 return KJ_ASSERT_NONNULL(stream)->write(pieces); 2515 }); 2516 } 2517 } 2518 2519 kj::Maybe<kj::Promise<uint64_t>> tryPumpFrom( 2520 kj::AsyncInputStream& input, uint64_t amount = kj::maxValue) override { 2521 KJ_IF_MAYBE(s, stream) { 2522 return s->get()->tryPumpFrom(input, amount); 2523 } else { 2524 return promise.addBranch().then([this,&input,amount]() { 2525 // Call input.pumpTo() on the resolved stream instead. 
Own<AsyncOutputStream> newPromisedStream(Promise<Own<AsyncOutputStream>> promise) {
  // Wrap a promise for an output stream in a PromisedAsyncOutputStream, usable immediately.
  return heap<PromisedAsyncOutputStream>(kj::mv(promise));
}
Own<AsyncIoStream> newPromisedStream(Promise<Own<AsyncIoStream>> promise) {
  // Wrap a promise for an I/O stream in a PromisedAsyncIoStream, usable immediately.
  return heap<PromisedAsyncIoStream>(kj::mv(promise));
}

Promise<void> AsyncCapabilityStream::writeWithFds(
    ArrayPtr<const byte> data, ArrayPtr<const ArrayPtr<const byte>> moreData,
    ArrayPtr<const AutoCloseFd> fds) {
  // Convenience overload: view the AutoCloseFd array as an array of raw ints and forward to the
  // writeWithFds() overload that takes them. `fds` is only borrowed here.

  // HACK: AutoCloseFd actually contains an `int` under the hood. We can reinterpret_cast to avoid
  //   unnecessary memory allocation.
  static_assert(sizeof(AutoCloseFd) == sizeof(int), "this optimization won't work");
  auto intArray = arrayPtr(reinterpret_cast<const int*>(fds.begin()), fds.size());

  // Be extra-paranoid about aliasing rules by injecting a compiler barrier here. Probably
  // not necessary but also probably doesn't hurt.
#if _MSC_VER
  _ReadWriteBarrier();
#else
  __asm__ __volatile__("": : :"memory");
#endif

  return writeWithFds(data, moreData, intArray);
}
2571 #if _MSC_VER 2572 _ReadWriteBarrier(); 2573 #else 2574 __asm__ __volatile__("": : :"memory"); 2575 #endif 2576 2577 return writeWithFds(data, moreData, intArray); 2578 } 2579 2580 Promise<Own<AsyncCapabilityStream>> AsyncCapabilityStream::receiveStream() { 2581 return tryReceiveStream() 2582 .then([](Maybe<Own<AsyncCapabilityStream>>&& result) 2583 -> Promise<Own<AsyncCapabilityStream>> { 2584 KJ_IF_MAYBE(r, result) { 2585 return kj::mv(*r); 2586 } else { 2587 return KJ_EXCEPTION(FAILED, "EOF when expecting to receive capability"); 2588 } 2589 }); 2590 } 2591 2592 kj::Promise<Maybe<Own<AsyncCapabilityStream>>> AsyncCapabilityStream::tryReceiveStream() { 2593 struct ResultHolder { 2594 byte b; 2595 Own<AsyncCapabilityStream> stream; 2596 }; 2597 auto result = kj::heap<ResultHolder>(); 2598 auto promise = tryReadWithStreams(&result->b, 1, 1, &result->stream, 1); 2599 return promise.then([result = kj::mv(result)](ReadResult actual) mutable 2600 -> Maybe<Own<AsyncCapabilityStream>> { 2601 if (actual.byteCount == 0) { 2602 return nullptr; 2603 } 2604 2605 KJ_REQUIRE(actual.capCount == 1, 2606 "expected to receive a capability (e.g. 
file descirptor via SCM_RIGHTS), but didn't") { 2607 return nullptr; 2608 } 2609 2610 return kj::mv(result->stream); 2611 }); 2612 } 2613 2614 Promise<void> AsyncCapabilityStream::sendStream(Own<AsyncCapabilityStream> stream) { 2615 static constexpr byte b = 0; 2616 auto streams = kj::heapArray<Own<AsyncCapabilityStream>>(1); 2617 streams[0] = kj::mv(stream); 2618 return writeWithStreams(arrayPtr(&b, 1), nullptr, kj::mv(streams)); 2619 } 2620 2621 Promise<AutoCloseFd> AsyncCapabilityStream::receiveFd() { 2622 return tryReceiveFd().then([](Maybe<AutoCloseFd>&& result) -> Promise<AutoCloseFd> { 2623 KJ_IF_MAYBE(r, result) { 2624 return kj::mv(*r); 2625 } else { 2626 return KJ_EXCEPTION(FAILED, "EOF when expecting to receive capability"); 2627 } 2628 }); 2629 } 2630 2631 kj::Promise<kj::Maybe<AutoCloseFd>> AsyncCapabilityStream::tryReceiveFd() { 2632 struct ResultHolder { 2633 byte b; 2634 AutoCloseFd fd; 2635 }; 2636 auto result = kj::heap<ResultHolder>(); 2637 auto promise = tryReadWithFds(&result->b, 1, 1, &result->fd, 1); 2638 return promise.then([result = kj::mv(result)](ReadResult actual) mutable 2639 -> Maybe<AutoCloseFd> { 2640 if (actual.byteCount == 0) { 2641 return nullptr; 2642 } 2643 2644 KJ_REQUIRE(actual.capCount == 1, 2645 "expected to receive a file descriptor (e.g. 
Promise<void> AsyncCapabilityStream::sendFd(int fd) {
  // Send one file descriptor to the peer, attached to a single zero byte of payload.
  static constexpr byte b = 0;
  auto fds = kj::heapArray<int>(1);
  fds[0] = fd;
  auto promise = writeWithFds(arrayPtr(&b, 1), nullptr, fds);
  // Keep the descriptor array alive until the write completes.
  return promise.attach(kj::mv(fds));
}

// Default implementations of the socket-level accessors for objects that are not backed by a
// socket. Each one reports "Not a socket." via KJ_UNIMPLEMENTED; where there is a length
// out-parameter, the attached block zeroes it.
// NOTE(review): the attached block is presumably the recovery path used when the failure is not
//   thrown -- confirm against KJ_UNIMPLEMENTED's definition.

void AsyncIoStream::getsockopt(int level, int option, void* value, uint* length) {
  KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; }
}
void AsyncIoStream::setsockopt(int level, int option, const void* value, uint length) {
  KJ_UNIMPLEMENTED("Not a socket.") { break; }
}
void AsyncIoStream::getsockname(struct sockaddr* addr, uint* length) {
  KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; }
}
void AsyncIoStream::getpeername(struct sockaddr* addr, uint* length) {
  KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; }
}
void ConnectionReceiver::getsockopt(int level, int option, void* value, uint* length) {
  KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; }
}
void ConnectionReceiver::setsockopt(int level, int option, const void* value, uint length) {
  KJ_UNIMPLEMENTED("Not a socket.") { break; }
}
void ConnectionReceiver::getsockname(struct sockaddr* addr, uint* length) {
  KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; }
}
void DatagramPort::getsockopt(int level, int option, void* value, uint* length) {
  KJ_UNIMPLEMENTED("Not a socket.") { *length = 0; break; }
}
void DatagramPort::setsockopt(int level, int option, const void* value, uint length) {
  KJ_UNIMPLEMENTED("Not a socket.") { break; }
}

// Default implementations for datagram support: report that it is not implemented.

Own<DatagramPort> NetworkAddress::bindDatagramPort() {
  KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
}
Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(
    Fd fd, LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) {
  KJ_UNIMPLEMENTED("Datagram sockets not implemented.");
}
LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags) { 2693 KJ_UNIMPLEMENTED("Datagram sockets not implemented."); 2694 } 2695 #if !_WIN32 2696 Own<AsyncCapabilityStream> LowLevelAsyncIoProvider::wrapUnixSocketFd(Fd fd, uint flags) { 2697 KJ_UNIMPLEMENTED("Unix socket with FD passing not implemented."); 2698 } 2699 #endif 2700 CapabilityPipe AsyncIoProvider::newCapabilityPipe() { 2701 KJ_UNIMPLEMENTED("Capability pipes not implemented."); 2702 } 2703 2704 Own<AsyncInputStream> LowLevelAsyncIoProvider::wrapInputFd(OwnFd&& fd, uint flags) { 2705 return wrapInputFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP); 2706 } 2707 Own<AsyncOutputStream> LowLevelAsyncIoProvider::wrapOutputFd(OwnFd&& fd, uint flags) { 2708 return wrapOutputFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP); 2709 } 2710 Own<AsyncIoStream> LowLevelAsyncIoProvider::wrapSocketFd(OwnFd&& fd, uint flags) { 2711 return wrapSocketFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP); 2712 } 2713 #if !_WIN32 2714 Own<AsyncCapabilityStream> LowLevelAsyncIoProvider::wrapUnixSocketFd(OwnFd&& fd, uint flags) { 2715 return wrapUnixSocketFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP); 2716 } 2717 #endif 2718 Promise<Own<AsyncIoStream>> LowLevelAsyncIoProvider::wrapConnectingSocketFd( 2719 OwnFd&& fd, const struct sockaddr* addr, uint addrlen, uint flags) { 2720 return wrapConnectingSocketFd(reinterpret_cast<Fd>(fd.release()), addr, addrlen, 2721 flags | TAKE_OWNERSHIP); 2722 } 2723 Own<ConnectionReceiver> LowLevelAsyncIoProvider::wrapListenSocketFd( 2724 OwnFd&& fd, NetworkFilter& filter, uint flags) { 2725 return wrapListenSocketFd(reinterpret_cast<Fd>(fd.release()), filter, flags | TAKE_OWNERSHIP); 2726 } 2727 Own<ConnectionReceiver> LowLevelAsyncIoProvider::wrapListenSocketFd(OwnFd&& fd, uint flags) { 2728 return wrapListenSocketFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP); 2729 } 2730 Own<DatagramPort> 
LowLevelAsyncIoProvider::wrapDatagramSocketFd( 2731 OwnFd&& fd, NetworkFilter& filter, uint flags) { 2732 return wrapDatagramSocketFd(reinterpret_cast<Fd>(fd.release()), filter, flags | TAKE_OWNERSHIP); 2733 } 2734 Own<DatagramPort> LowLevelAsyncIoProvider::wrapDatagramSocketFd(OwnFd&& fd, uint flags) { 2735 return wrapDatagramSocketFd(reinterpret_cast<Fd>(fd.release()), flags | TAKE_OWNERSHIP); 2736 } 2737 2738 namespace { 2739 2740 class DummyNetworkFilter: public kj::LowLevelAsyncIoProvider::NetworkFilter { 2741 public: 2742 bool shouldAllow(const struct sockaddr* addr, uint addrlen) override { return true; } 2743 }; 2744 2745 } // namespace 2746 2747 LowLevelAsyncIoProvider::NetworkFilter& LowLevelAsyncIoProvider::NetworkFilter::getAllAllowed() { 2748 static DummyNetworkFilter result; 2749 return result; 2750 } 2751 2752 // ======================================================================================= 2753 // Convenience adapters. 2754 2755 Promise<Own<AsyncIoStream>> CapabilityStreamConnectionReceiver::accept() { 2756 return inner.receiveStream() 2757 .then([](Own<AsyncCapabilityStream>&& stream) -> Own<AsyncIoStream> { 2758 return kj::mv(stream); 2759 }); 2760 } 2761 2762 Promise<AuthenticatedStream> CapabilityStreamConnectionReceiver::acceptAuthenticated() { 2763 return accept().then([](Own<AsyncIoStream>&& stream) { 2764 return AuthenticatedStream { kj::mv(stream), UnknownPeerIdentity::newInstance() }; 2765 }); 2766 } 2767 2768 uint CapabilityStreamConnectionReceiver::getPort() { 2769 return 0; 2770 } 2771 2772 Promise<Own<AsyncIoStream>> CapabilityStreamNetworkAddress::connect() { 2773 CapabilityPipe pipe; 2774 KJ_IF_MAYBE(p, provider) { 2775 pipe = p->newCapabilityPipe(); 2776 } else { 2777 pipe = kj::newCapabilityPipe(); 2778 } 2779 auto result = kj::mv(pipe.ends[0]); 2780 return inner.sendStream(kj::mv(pipe.ends[1])) 2781 .then(kj::mvCapture(result, [](Own<AsyncIoStream>&& result) { 2782 return kj::mv(result); 2783 })); 2784 } 2785 
Promise<AuthenticatedStream> CapabilityStreamNetworkAddress::connectAuthenticated() { 2786 return connect().then([](Own<AsyncIoStream>&& stream) { 2787 return AuthenticatedStream { kj::mv(stream), UnknownPeerIdentity::newInstance() }; 2788 }); 2789 } 2790 Own<ConnectionReceiver> CapabilityStreamNetworkAddress::listen() { 2791 return kj::heap<CapabilityStreamConnectionReceiver>(inner); 2792 } 2793 2794 Own<NetworkAddress> CapabilityStreamNetworkAddress::clone() { 2795 KJ_UNIMPLEMENTED("can't clone CapabilityStreamNetworkAddress"); 2796 } 2797 String CapabilityStreamNetworkAddress::toString() { 2798 return kj::str("<CapabilityStreamNetworkAddress>"); 2799 } 2800 2801 // ======================================================================================= 2802 2803 namespace _ { // private 2804 2805 #if !_WIN32 2806 2807 kj::ArrayPtr<const char> safeUnixPath(const struct sockaddr_un* addr, uint addrlen) { 2808 KJ_REQUIRE(addr->sun_family == AF_UNIX, "not a unix address"); 2809 KJ_REQUIRE(addrlen >= offsetof(sockaddr_un, sun_path), "invalid unix address"); 2810 2811 size_t maxPathlen = addrlen - offsetof(sockaddr_un, sun_path); 2812 2813 size_t pathlen; 2814 if (maxPathlen > 0 && addr->sun_path[0] == '\0') { 2815 // Linux "abstract" unix address 2816 pathlen = strnlen(addr->sun_path + 1, maxPathlen - 1) + 1; 2817 } else { 2818 pathlen = strnlen(addr->sun_path, maxPathlen); 2819 } 2820 return kj::arrayPtr(addr->sun_path, pathlen); 2821 } 2822 2823 #endif // !_WIN32 2824 2825 CidrRange::CidrRange(StringPtr pattern) { 2826 size_t slashPos = KJ_REQUIRE_NONNULL(pattern.findFirst('/'), "invalid CIDR", pattern); 2827 2828 bitCount = pattern.slice(slashPos + 1).parseAs<uint>(); 2829 2830 KJ_STACK_ARRAY(char, addr, slashPos + 1, 128, 128); 2831 memcpy(addr.begin(), pattern.begin(), slashPos); 2832 addr[slashPos] = '\0'; 2833 2834 if (pattern.findFirst(':') == nullptr) { 2835 family = AF_INET; 2836 KJ_REQUIRE(bitCount <= 32, "invalid CIDR", pattern); 2837 } else { 2838 family 
= AF_INET6; 2839 KJ_REQUIRE(bitCount <= 128, "invalid CIDR", pattern); 2840 } 2841 2842 KJ_ASSERT(inet_pton(family, addr.begin(), bits) > 0, "invalid CIDR", pattern); 2843 zeroIrrelevantBits(); 2844 } 2845 2846 CidrRange::CidrRange(int family, ArrayPtr<const byte> bits, uint bitCount) 2847 : family(family), bitCount(bitCount) { 2848 if (family == AF_INET) { 2849 KJ_REQUIRE(bitCount <= 32); 2850 } else { 2851 KJ_REQUIRE(bitCount <= 128); 2852 } 2853 KJ_REQUIRE(bits.size() * 8 >= bitCount); 2854 size_t byteCount = (bitCount + 7) / 8; 2855 memcpy(this->bits, bits.begin(), byteCount); 2856 memset(this->bits + byteCount, 0, sizeof(this->bits) - byteCount); 2857 2858 zeroIrrelevantBits(); 2859 } 2860 2861 CidrRange CidrRange::inet4(ArrayPtr<const byte> bits, uint bitCount) { 2862 return CidrRange(AF_INET, bits, bitCount); 2863 } 2864 CidrRange CidrRange::inet6( 2865 ArrayPtr<const uint16_t> prefix, ArrayPtr<const uint16_t> suffix, 2866 uint bitCount) { 2867 KJ_REQUIRE(prefix.size() + suffix.size() <= 8); 2868 2869 byte bits[16] = { 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0, }; 2870 2871 for (size_t i: kj::indices(prefix)) { 2872 bits[i * 2] = prefix[i] >> 8; 2873 bits[i * 2 + 1] = prefix[i] & 0xff; 2874 } 2875 2876 byte* suffixBits = bits + (16 - suffix.size() * 2); 2877 for (size_t i: kj::indices(suffix)) { 2878 suffixBits[i * 2] = suffix[i] >> 8; 2879 suffixBits[i * 2 + 1] = suffix[i] & 0xff; 2880 } 2881 2882 return CidrRange(AF_INET6, bits, bitCount); 2883 } 2884 2885 bool CidrRange::matches(const struct sockaddr* addr) const { 2886 const byte* otherBits; 2887 2888 switch (family) { 2889 case AF_INET: 2890 if (addr->sa_family == AF_INET6) { 2891 otherBits = reinterpret_cast<const struct sockaddr_in6*>(addr)->sin6_addr.s6_addr; 2892 static constexpr byte V6MAPPED[12] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff }; 2893 if (memcmp(otherBits, V6MAPPED, sizeof(V6MAPPED)) == 0) { 2894 // We're an ipv4 range and the address is ipv6, but it's a "v6 mapped" address, meaning 2895 // 
it's equivalent to an ipv4 address. Try to match against the ipv4 part. 2896 otherBits = otherBits + sizeof(V6MAPPED); 2897 } else { 2898 return false; 2899 } 2900 } else if (addr->sa_family == AF_INET) { 2901 otherBits = reinterpret_cast<const byte*>( 2902 &reinterpret_cast<const struct sockaddr_in*>(addr)->sin_addr.s_addr); 2903 } else { 2904 return false; 2905 } 2906 2907 break; 2908 2909 case AF_INET6: 2910 if (addr->sa_family != AF_INET6) return false; 2911 2912 otherBits = reinterpret_cast<const struct sockaddr_in6*>(addr)->sin6_addr.s6_addr; 2913 break; 2914 2915 default: 2916 KJ_UNREACHABLE; 2917 } 2918 2919 if (memcmp(bits, otherBits, bitCount / 8) != 0) return false; 2920 2921 return bitCount == 128 || 2922 bits[bitCount / 8] == (otherBits[bitCount / 8] & (0xff00 >> (bitCount % 8))); 2923 } 2924 2925 bool CidrRange::matchesFamily(int family) const { 2926 switch (family) { 2927 case AF_INET: 2928 return this->family == AF_INET; 2929 case AF_INET6: 2930 // Even if we're a v4 CIDR, we can match v6 addresses in the v4-mapped range. 2931 return true; 2932 default: 2933 return false; 2934 } 2935 } 2936 2937 String CidrRange::toString() const { 2938 char result[128]; 2939 KJ_ASSERT(inet_ntop(family, (void*)bits, result, sizeof(result)) == result); 2940 return kj::str(result, '/', bitCount); 2941 } 2942 2943 void CidrRange::zeroIrrelevantBits() { 2944 // Mask out insignificant bits of partial byte. 2945 if (bitCount < 128) { 2946 bits[bitCount / 8] &= 0xff00 >> (bitCount % 8); 2947 2948 // Zero the remaining bytes. 2949 size_t n = bitCount / 8 + 1; 2950 memset(bits + n, 0, sizeof(bits) - n); 2951 } 2952 } 2953 2954 // ----------------------------------------------------------------------------- 2955 2956 ArrayPtr<const CidrRange> localCidrs() { 2957 static const CidrRange result[] = { 2958 // localhost 2959 "127.0.0.0/8"_kj, 2960 "::1/128"_kj, 2961 2962 // Trying to *connect* to 0.0.0.0 on many systems is equivalent to connecting to localhost. 
2963 // (wat) 2964 "0.0.0.0/32"_kj, 2965 "::/128"_kj, 2966 }; 2967 2968 // TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly 2969 // casting to our return type. 2970 return kj::arrayPtr(result, kj::size(result)); 2971 } 2972 2973 ArrayPtr<const CidrRange> privateCidrs() { 2974 static const CidrRange result[] = { 2975 "10.0.0.0/8"_kj, // RFC1918 reserved for internal network 2976 "100.64.0.0/10"_kj, // RFC6598 "shared address space" for carrier-grade NAT 2977 "169.254.0.0/16"_kj, // RFC3927 "link local" (auto-configured LAN in absence of DHCP) 2978 "172.16.0.0/12"_kj, // RFC1918 reserved for internal network 2979 "192.168.0.0/16"_kj, // RFC1918 reserved for internal network 2980 2981 "fc00::/7"_kj, // RFC4193 unique private network 2982 "fe80::/10"_kj, // RFC4291 "link local" (auto-configured LAN in absence of DHCP) 2983 }; 2984 2985 // TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly 2986 // casting to our return type. 2987 return kj::arrayPtr(result, kj::size(result)); 2988 } 2989 2990 ArrayPtr<const CidrRange> reservedCidrs() { 2991 static const CidrRange result[] = { 2992 "192.0.0.0/24"_kj, // RFC6890 reserved for special protocols 2993 "224.0.0.0/4"_kj, // RFC1112 multicast 2994 "240.0.0.0/4"_kj, // RFC1112 multicast / reserved for future use 2995 "255.255.255.255/32"_kj, // RFC0919 broadcast address 2996 2997 "2001::/23"_kj, // RFC2928 reserved for special protocols 2998 "ff00::/8"_kj, // RFC4291 multicast 2999 }; 3000 3001 // TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly 3002 // casting to our return type. 
3003 return kj::arrayPtr(result, kj::size(result)); 3004 } 3005 3006 ArrayPtr<const CidrRange> exampleAddresses() { 3007 static const CidrRange result[] = { 3008 "192.0.2.0/24"_kj, // RFC5737 "example address" block 1 -- like example.com for IPs 3009 "198.51.100.0/24"_kj, // RFC5737 "example address" block 2 -- like example.com for IPs 3010 "203.0.113.0/24"_kj, // RFC5737 "example address" block 3 -- like example.com for IPs 3011 "2001:db8::/32"_kj, // RFC3849 "example address" block -- like example.com for IPs 3012 }; 3013 3014 // TODO(cleanup): A bug in GCC 4.8, fixed in 4.9, prevents result from implicitly 3015 // casting to our return type. 3016 return kj::arrayPtr(result, kj::size(result)); 3017 } 3018 3019 NetworkFilter::NetworkFilter() 3020 : allowUnix(true), allowAbstractUnix(true) { 3021 allowCidrs.add(CidrRange::inet4({0,0,0,0}, 0)); 3022 allowCidrs.add(CidrRange::inet6({}, {}, 0)); 3023 denyCidrs.addAll(reservedCidrs()); 3024 } 3025 3026 NetworkFilter::NetworkFilter(ArrayPtr<const StringPtr> allow, ArrayPtr<const StringPtr> deny, 3027 NetworkFilter& next) 3028 : allowUnix(false), allowAbstractUnix(false), next(next) { 3029 for (auto rule: allow) { 3030 if (rule == "local") { 3031 allowCidrs.addAll(localCidrs()); 3032 } else if (rule == "network") { 3033 allowCidrs.add(CidrRange::inet4({0,0,0,0}, 0)); 3034 allowCidrs.add(CidrRange::inet6({}, {}, 0)); 3035 denyCidrs.addAll(localCidrs()); 3036 } else if (rule == "private") { 3037 allowCidrs.addAll(privateCidrs()); 3038 allowCidrs.addAll(localCidrs()); 3039 } else if (rule == "public") { 3040 allowCidrs.add(CidrRange::inet4({0,0,0,0}, 0)); 3041 allowCidrs.add(CidrRange::inet6({}, {}, 0)); 3042 denyCidrs.addAll(privateCidrs()); 3043 denyCidrs.addAll(localCidrs()); 3044 } else if (rule == "unix") { 3045 allowUnix = true; 3046 } else if (rule == "unix-abstract") { 3047 allowAbstractUnix = true; 3048 } else { 3049 allowCidrs.add(CidrRange(rule)); 3050 } 3051 } 3052 3053 for (auto rule: deny) { 3054 if (rule == 
"local") { 3055 denyCidrs.addAll(localCidrs()); 3056 } else if (rule == "network") { 3057 KJ_FAIL_REQUIRE("don't deny 'network', allow 'local' instead"); 3058 } else if (rule == "private") { 3059 denyCidrs.addAll(privateCidrs()); 3060 } else if (rule == "public") { 3061 // Tricky: What if we allow 'network' and deny 'public'? 3062 KJ_FAIL_REQUIRE("don't deny 'public', allow 'private' instead"); 3063 } else if (rule == "unix") { 3064 allowUnix = false; 3065 } else if (rule == "unix-abstract") { 3066 allowAbstractUnix = false; 3067 } else { 3068 denyCidrs.add(CidrRange(rule)); 3069 } 3070 } 3071 } 3072 3073 bool NetworkFilter::shouldAllow(const struct sockaddr* addr, uint addrlen) { 3074 KJ_REQUIRE(addrlen >= sizeof(addr->sa_family)); 3075 3076 #if !_WIN32 3077 if (addr->sa_family == AF_UNIX) { 3078 auto path = safeUnixPath(reinterpret_cast<const struct sockaddr_un*>(addr), addrlen); 3079 if (path.size() > 0 && path[0] == '\0') { 3080 return allowAbstractUnix; 3081 } else { 3082 return allowUnix; 3083 } 3084 } 3085 #endif 3086 3087 bool allowed = false; 3088 uint allowSpecificity = 0; 3089 for (auto& cidr: allowCidrs) { 3090 if (cidr.matches(addr)) { 3091 allowSpecificity = kj::max(allowSpecificity, cidr.getSpecificity()); 3092 allowed = true; 3093 } 3094 } 3095 if (!allowed) return false; 3096 for (auto& cidr: denyCidrs) { 3097 if (cidr.matches(addr)) { 3098 if (cidr.getSpecificity() >= allowSpecificity) return false; 3099 } 3100 } 3101 3102 KJ_IF_MAYBE(n, next) { 3103 return n->shouldAllow(addr, addrlen); 3104 } else { 3105 return true; 3106 } 3107 } 3108 3109 bool NetworkFilter::shouldAllowParse(const struct sockaddr* addr, uint addrlen) { 3110 bool matched = false; 3111 #if !_WIN32 3112 if (addr->sa_family == AF_UNIX) { 3113 auto path = safeUnixPath(reinterpret_cast<const struct sockaddr_un*>(addr), addrlen); 3114 if (path.size() > 0 && path[0] == '\0') { 3115 if (allowAbstractUnix) matched = true; 3116 } else { 3117 if (allowUnix) matched = true; 3118 } 3119 } 
else { 3120 #endif 3121 for (auto& cidr: allowCidrs) { 3122 if (cidr.matchesFamily(addr->sa_family)) { 3123 matched = true; 3124 } 3125 } 3126 #if !_WIN32 3127 } 3128 #endif 3129 3130 if (matched) { 3131 KJ_IF_MAYBE(n, next) { 3132 return n->shouldAllowParse(addr, addrlen); 3133 } else { 3134 return true; 3135 } 3136 } else { 3137 // No allow rule matches this address family, so don't even allow parsing it. 3138 return false; 3139 } 3140 } 3141 3142 } // namespace _ (private) 3143 3144 // ======================================================================================= 3145 // PeerIdentity implementations 3146 3147 namespace { 3148 3149 class NetworkPeerIdentityImpl final: public NetworkPeerIdentity { 3150 public: 3151 NetworkPeerIdentityImpl(kj::Own<NetworkAddress> addr): addr(kj::mv(addr)) {} 3152 3153 kj::String toString() override { return addr->toString(); } 3154 NetworkAddress& getAddress() override { return *addr; } 3155 3156 private: 3157 kj::Own<NetworkAddress> addr; 3158 }; 3159 3160 class LocalPeerIdentityImpl final: public LocalPeerIdentity { 3161 public: 3162 LocalPeerIdentityImpl(Credentials creds): creds(creds) {} 3163 3164 kj::String toString() override { 3165 char pidBuffer[16]; 3166 kj::StringPtr pidStr = nullptr; 3167 KJ_IF_MAYBE(p, creds.pid) { 3168 pidStr = strPreallocated(pidBuffer, " pid:", *p); 3169 } 3170 3171 char uidBuffer[16]; 3172 kj::StringPtr uidStr = nullptr; 3173 KJ_IF_MAYBE(u, creds.uid) { 3174 uidStr = strPreallocated(uidBuffer, " uid:", *u); 3175 } 3176 3177 return kj::str("(local peer", pidStr, uidStr, ")"); 3178 } 3179 3180 Credentials getCredentials() override { return creds; } 3181 3182 private: 3183 Credentials creds; 3184 }; 3185 3186 class UnknownPeerIdentityImpl final: public UnknownPeerIdentity { 3187 public: 3188 kj::String toString() override { 3189 return kj::str("(unknown peer)"); 3190 } 3191 }; 3192 3193 } // namespace 3194 3195 kj::Own<NetworkPeerIdentity> 
NetworkPeerIdentity::newInstance(kj::Own<NetworkAddress> addr) { 3196 return kj::heap<NetworkPeerIdentityImpl>(kj::mv(addr)); 3197 } 3198 3199 kj::Own<LocalPeerIdentity> LocalPeerIdentity::newInstance(LocalPeerIdentity::Credentials creds) { 3200 return kj::heap<LocalPeerIdentityImpl>(creds); 3201 } 3202 3203 kj::Own<UnknownPeerIdentity> UnknownPeerIdentity::newInstance() { 3204 static UnknownPeerIdentityImpl instance; 3205 return { &instance, NullDisposer::instance }; 3206 } 3207 3208 Promise<AuthenticatedStream> ConnectionReceiver::acceptAuthenticated() { 3209 return accept().then([](Own<AsyncIoStream> stream) { 3210 return AuthenticatedStream { kj::mv(stream), UnknownPeerIdentity::newInstance() }; 3211 }); 3212 } 3213 3214 Promise<AuthenticatedStream> NetworkAddress::connectAuthenticated() { 3215 return connect().then([](Own<AsyncIoStream> stream) { 3216 return AuthenticatedStream { kj::mv(stream), UnknownPeerIdentity::newInstance() }; 3217 }); 3218 } 3219 3220 } // namespace kj