capability.c++ (40116B)
1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors 2 // Licensed under the MIT License: 3 // 4 // Permission is hereby granted, free of charge, to any person obtaining a copy 5 // of this software and associated documentation files (the "Software"), to deal 6 // in the Software without restriction, including without limitation the rights 7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 // copies of the Software, and to permit persons to whom the Software is 9 // furnished to do so, subject to the following conditions: 10 // 11 // The above copyright notice and this permission notice shall be included in 12 // all copies or substantial portions of the Software. 13 // 14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 // THE SOFTWARE. 21 22 #define CAPNP_PRIVATE 23 24 #include "capability.h" 25 #include "message.h" 26 #include "arena.h" 27 #include <kj/refcount.h> 28 #include <kj/debug.h> 29 #include <kj/vector.h> 30 #include <map> 31 #include "generated-header-support.h" 32 33 namespace capnp { 34 35 namespace _ { 36 37 void setGlobalBrokenCapFactoryForLayoutCpp(BrokenCapFactory& factory); 38 // Defined in layout.c++. 39 40 } // namespace _ 41 42 namespace { 43 44 static kj::Own<ClientHook> newNullCap(); 45 46 class BrokenCapFactoryImpl: public _::BrokenCapFactory { 47 public: 48 kj::Own<ClientHook> newBrokenCap(kj::StringPtr description) override { 49 return capnp::newBrokenCap(description); 50 } 51 kj::Own<ClientHook> newNullCap() override { 52 return capnp::newNullCap(); 53 } 54 }; 55 56 static BrokenCapFactoryImpl brokenCapFactory; 57 58 } // namespace 59 60 ClientHook::ClientHook() { 61 setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory); 62 } 63 64 // ======================================================================================= 65 66 Capability::Client::Client(decltype(nullptr)) 67 : hook(newNullCap()) {} 68 69 Capability::Client::Client(kj::Exception&& exception) 70 : hook(newBrokenCap(kj::mv(exception))) {} 71 72 kj::Promise<kj::Maybe<int>> Capability::Client::getFd() { 73 auto fd = hook->getFd(); 74 if (fd != nullptr) { 75 return fd; 76 } else KJ_IF_MAYBE(promise, hook->whenMoreResolved()) { 77 return promise->attach(hook->addRef()).then([](kj::Own<ClientHook> newHook) { 78 return Client(kj::mv(newHook)).getFd(); 79 }); 80 } else { 81 return kj::Maybe<int>(nullptr); 82 } 83 } 84 85 kj::Maybe<kj::Promise<Capability::Client>> Capability::Server::shortenPath() { 86 return nullptr; 87 } 88 89 Capability::Server::DispatchCallResult Capability::Server::internalUnimplemented( 90 const char* actualInterfaceName, uint64_t requestedTypeId) { 91 return { 92 KJ_EXCEPTION(UNIMPLEMENTED, "Requested interface not implemented.", 93 actualInterfaceName, requestedTypeId), 94 false 95 }; 96 } 97 98 Capability::Server::DispatchCallResult Capability::Server::internalUnimplemented( 99 const char* interfaceName, uint64_t typeId, uint16_t methodId) { 100 return { 101 KJ_EXCEPTION(UNIMPLEMENTED, "Method not implemented.", interfaceName, typeId, methodId), 102 false 103 }; 104 } 105 106 kj::Promise<void> Capability::Server::internalUnimplemented( 107 const char* interfaceName, const char* methodName, uint64_t typeId, uint16_t methodId) { 108 return KJ_EXCEPTION(UNIMPLEMENTED, "Method not implemented.", interfaceName, 109 typeId, methodName, methodId); 110 } 111 112 ResponseHook::~ResponseHook() noexcept(false) {} 113 114 kj::Promise<void> ClientHook::whenResolved() { 115 KJ_IF_MAYBE(promise, whenMoreResolved()) { 116 return promise->then([](kj::Own<ClientHook>&& resolution) { 117 return resolution->whenResolved(); 118 }); 119 } else { 120 return kj::READY_NOW; 121 } 122 } 123 124 kj::Promise<void> Capability::Client::whenResolved() { 125 return hook->whenResolved().attach(hook->addRef()); 126 } 127 128 // ======================================================================================= 129 130 static inline uint firstSegmentSize(kj::Maybe<MessageSize> sizeHint) { 131 KJ_IF_MAYBE(s, sizeHint) { 132 return s->wordCount; 133 } else { 134 return SUGGESTED_FIRST_SEGMENT_WORDS; 135 } 136 } 137 138 class LocalResponse final: public ResponseHook, public kj::Refcounted { 139 public: 140 LocalResponse(kj::Maybe<MessageSize> sizeHint) 141 : message(firstSegmentSize(sizeHint)) {} 142 143 MallocMessageBuilder message; 144 }; 145 146 class LocalCallContext final: public CallContextHook, public ResponseHook, public kj::Refcounted { 147 public: 148 LocalCallContext(kj::Own<MallocMessageBuilder>&& request, kj::Own<ClientHook> clientRef, 149 kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller) 150 : request(kj::mv(request)), clientRef(kj::mv(clientRef)), 151 cancelAllowedFulfiller(kj::mv(cancelAllowedFulfiller)) {} 152 153 AnyPointer::Reader getParams() override { 154 KJ_IF_MAYBE(r, request) { 155 return r->get()->getRoot<AnyPointer>(); 156 } else { 157 KJ_FAIL_REQUIRE("Can't call getParams() after releaseParams()."); 158 } 159 } 160 void releaseParams() override { 161 request = nullptr; 162 } 163 AnyPointer::Builder getResults(kj::Maybe<MessageSize> sizeHint) override { 164 if (response == nullptr) { 165 auto localResponse = kj::refcounted<LocalResponse>(sizeHint); 166 responseBuilder = localResponse->message.getRoot<AnyPointer>(); 167 response = Response<AnyPointer>(responseBuilder.asReader(), kj::mv(localResponse)); 168 } 169 return responseBuilder; 170 } 171 void setPipeline(kj::Own<PipelineHook>&& pipeline) override { 172 KJ_IF_MAYBE(f, tailCallPipelineFulfiller) { 173 f->get()->fulfill(AnyPointer::Pipeline(kj::mv(pipeline))); 174 } 175 } 176 kj::Promise<void> tailCall(kj::Own<RequestHook>&& request) override { 177 auto result = directTailCall(kj::mv(request)); 178 KJ_IF_MAYBE(f, tailCallPipelineFulfiller) { 179 f->get()->fulfill(AnyPointer::Pipeline(kj::mv(result.pipeline))); 180 } 181 return kj::mv(result.promise); 182 } 183 ClientHook::VoidPromiseAndPipeline directTailCall(kj::Own<RequestHook>&& request) override { 184 KJ_REQUIRE(response == nullptr, "Can't call tailCall() after initializing the results struct."); 185 186 auto promise = request->send(); 187 188 auto voidPromise = promise.then([this](Response<AnyPointer>&& tailResponse) { 189 response = kj::mv(tailResponse); 190 }); 191 192 return { kj::mv(voidPromise), PipelineHook::from(kj::mv(promise)) }; 193 } 194 kj::Promise<AnyPointer::Pipeline> onTailCall() override { 195 auto paf = kj::newPromiseAndFulfiller<AnyPointer::Pipeline>(); 196 tailCallPipelineFulfiller = kj::mv(paf.fulfiller); 197 return kj::mv(paf.promise); 198 } 199 void allowCancellation() override { 200 cancelAllowedFulfiller->fulfill(); 201 } 202 kj::Own<CallContextHook> addRef() override { 203 return kj::addRef(*this); 204 } 205 206 kj::Maybe<kj::Own<MallocMessageBuilder>> request; 207 kj::Maybe<Response<AnyPointer>> response; 208 AnyPointer::Builder responseBuilder = nullptr; // only valid if `response` is non-null 209 kj::Own<ClientHook> clientRef; 210 kj::Maybe<kj::Own<kj::PromiseFulfiller<AnyPointer::Pipeline>>> tailCallPipelineFulfiller; 211 kj::Own<kj::PromiseFulfiller<void>> cancelAllowedFulfiller; 212 }; 213 214 class LocalRequest final: public RequestHook { 215 public: 216 inline LocalRequest(uint64_t interfaceId, uint16_t methodId, 217 kj::Maybe<MessageSize> sizeHint, kj::Own<ClientHook> client) 218 : message(kj::heap<MallocMessageBuilder>(firstSegmentSize(sizeHint))), 219 interfaceId(interfaceId), methodId(methodId), client(kj::mv(client)) {} 220 221 RemotePromise<AnyPointer> send() override { 222 KJ_REQUIRE(message.get() != nullptr, "Already called send() on this request."); 223 224 auto cancelPaf = kj::newPromiseAndFulfiller<void>(); 225 226 auto context = kj::refcounted<LocalCallContext>( 227 kj::mv(message), client->addRef(), kj::mv(cancelPaf.fulfiller)); 228 auto promiseAndPipeline = client->call(interfaceId, methodId, kj::addRef(*context)); 229 230 // We have to make sure the call is not canceled unless permitted. We need to fork the promise 231 // so that if the client drops their copy, the promise isn't necessarily canceled. 232 auto forked = promiseAndPipeline.promise.fork(); 233 234 // We daemonize one branch, but only after joining it with the promise that fires if 235 // cancellation is allowed. 236 forked.addBranch() 237 .attach(kj::addRef(*context)) 238 .exclusiveJoin(kj::mv(cancelPaf.promise)) 239 .detach([](kj::Exception&&) {}); // ignore exceptions 240 241 // Now the other branch returns the response from the context. 242 auto promise = forked.addBranch().then(kj::mvCapture(context, 243 [](kj::Own<LocalCallContext>&& context) { 244 // force response allocation 245 auto reader = context->getResults(MessageSize { 0, 0 }).asReader(); 246 247 if (context->isShared()) { 248 // We can't just move away context->response as `context` itself is still referenced by 249 // something -- probably a Pipeline object. As a bit of a hack, LocalCallContext itself 250 // implements ResponseHook so that we can just return a ref on it. 251 // 252 // TODO(cleanup): Maybe ResponseHook should be refcounted? Note that context->response 253 // might not necessarily contain a LocalResponse if it was resolved by a tail call, so 254 // we'd have to add refcounting to all ResponseHook implementations. 255 context->releaseParams(); // The call is done so params can definitely be dropped. 256 context->clientRef = nullptr; // Definitely not using the client cap anymore either. 257 return Response<AnyPointer>(reader, kj::mv(context)); 258 } else { 259 return kj::mv(KJ_ASSERT_NONNULL(context->response)); 260 } 261 })); 262 263 // We return the other branch. 264 return RemotePromise<AnyPointer>( 265 kj::mv(promise), AnyPointer::Pipeline(kj::mv(promiseAndPipeline.pipeline))); 266 } 267 268 kj::Promise<void> sendStreaming() override { 269 // We don't do any special handling of streaming in RequestHook for local requests, because 270 // there is no latency to compensate for between the client and server in this case. 271 return send().ignoreResult(); 272 } 273 274 const void* getBrand() override { 275 return nullptr; 276 } 277 278 kj::Own<MallocMessageBuilder> message; 279 280 private: 281 uint64_t interfaceId; 282 uint16_t methodId; 283 kj::Own<ClientHook> client; 284 }; 285 286 // ======================================================================================= 287 // Call queues 288 // 289 // These classes handle pipelining in the case where calls need to be queued in-memory until some 290 // local operation completes. 291 292 class QueuedPipeline final: public PipelineHook, public kj::Refcounted { 293 // A PipelineHook which simply queues calls while waiting for a PipelineHook to which to forward 294 // them. 295 296 public: 297 QueuedPipeline(kj::Promise<kj::Own<PipelineHook>>&& promiseParam) 298 : promise(promiseParam.fork()), 299 selfResolutionOp(promise.addBranch().then([this](kj::Own<PipelineHook>&& inner) { 300 redirect = kj::mv(inner); 301 }, [this](kj::Exception&& exception) { 302 redirect = newBrokenPipeline(kj::mv(exception)); 303 }).eagerlyEvaluate(nullptr)) {} 304 305 kj::Own<PipelineHook> addRef() override { 306 return kj::addRef(*this); 307 } 308 309 kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) override { 310 auto copy = kj::heapArrayBuilder<PipelineOp>(ops.size()); 311 for (auto& op: ops) { 312 copy.add(op); 313 } 314 return getPipelinedCap(copy.finish()); 315 } 316 317 kj::Own<ClientHook> getPipelinedCap(kj::Array<PipelineOp>&& ops) override; 318 319 private: 320 kj::ForkedPromise<kj::Own<PipelineHook>> promise; 321 322 kj::Maybe<kj::Own<PipelineHook>> redirect; 323 // Once the promise resolves, this will become non-null and point to the underlying object. 324 325 kj::Promise<void> selfResolutionOp; 326 // Represents the operation which will set `redirect` when possible. 327 328 kj::HashMap<kj::Array<PipelineOp>, kj::Own<ClientHook>> clientMap; 329 // If the same pipelined cap is requested twice, we have to return the same object. This is 330 // necessary because each ClientHook we create is a QueuedClient which queues up calls. If we 331 // return a new one each time, there will be several queues, and ordering of calls will be lost 332 // between the queues. 333 // 334 // One case where this is particularly problematic is with promises resolved over RPC. Consider 335 // this case: 336 // 337 // * Alice holds a promise capability P pointing towards Bob. 338 // * Bob makes a call Q on an object hosted by Alice. 339 // * Without waiting for Q to complete, Bob obtains a pipelined-promise capability for Q's 340 // eventual result, P2. 341 // * Alice invokes a method M on P. The call is sent to Bob. 342 // * Bob resolves Alice's original promise P to P2. 343 // * Alice receives a Resolve message from Bob resolving P to Q's eventual result. 344 // * As a result, Alice calls getPipelinedCap() on the QueuedPipeline for Q's result, which 345 // returns a QueuedClient for that result, which we'll call QR1. 346 // * Alice also sends a Disembargo to Bob. 347 // * Alice calls a method M2 on P. This call is blocked locally waiting for the disembargo to 348 // complete. 349 // * Bob receives Alice's first method call, M. Since it's addressed to P, which later resolved 350 // to Q's result, Bob reflects the call back to Alice. 351 // * Alice receives the reflected call, which is addressed to Q's result. 352 // * Alice calls getPipelinedCap() on the QueuedPipeline for Q's result, which returns a 353 // QueuedClient for that result, which we'll call QR2. 354 // * Alice enqueues the call M on QR2. 355 // * Bob receives Alice's Disembargo message, and reflects it back. 356 // * Alices receives the Disembrago. 357 // * Alice unblocks the method cgall M2, which had been blocked on the embargo. 358 // * The call M2 is then equeued onto QR1. 359 // * Finally, the call Q completes. 360 // * This causes QR1 and QR2 to resolve to their final destinations. But if QR1 and QR2 are 361 // separate objects, then one of them must resolve first. QR1 was created first, so naturally 362 // it resolves first, followed by QR2. 363 // * Because QR1 resolves first, method call M2 is delivered first. 364 // * QR2 resolves second, so method call M1 is delivered next. 365 // * THIS IS THE WRONG ORDER! 366 // 367 // In order to avoid this problem, it's necessary for QR1 and QR2 to be the same object, so that 368 // they share the same call queue. In this case, M2 is correctly enqueued onto QR2 *after* M1 was 369 // enqueued on QR1, and so the method calls are delivered in the correct order. 370 }; 371 372 class QueuedClient final: public ClientHook, public kj::Refcounted { 373 // A ClientHook which simply queues calls while waiting for a ClientHook to which to forward 374 // them. 375 376 public: 377 QueuedClient(kj::Promise<kj::Own<ClientHook>>&& promiseParam) 378 : promise(promiseParam.fork()), 379 selfResolutionOp(promise.addBranch().then([this](kj::Own<ClientHook>&& inner) { 380 redirect = kj::mv(inner); 381 }, [this](kj::Exception&& exception) { 382 redirect = newBrokenCap(kj::mv(exception)); 383 }).eagerlyEvaluate(nullptr)), 384 promiseForCallForwarding(promise.addBranch().fork()), 385 promiseForClientResolution(promise.addBranch().fork()) {} 386 387 Request<AnyPointer, AnyPointer> newCall( 388 uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { 389 auto hook = kj::heap<LocalRequest>( 390 interfaceId, methodId, sizeHint, kj::addRef(*this)); 391 auto root = hook->message->getRoot<AnyPointer>(); 392 return Request<AnyPointer, AnyPointer>(root, kj::mv(hook)); 393 } 394 395 VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, 396 kj::Own<CallContextHook>&& context) override { 397 // This is a bit complicated. We need to initiate this call later on. When we initiate the 398 // call, we'll get a void promise for its completion and a pipeline object. Right now, we have 399 // to produce a similar void promise and pipeline that will eventually be chained to those. 400 // The problem is, these are two independent objects, but they both depend on the result of 401 // one future call. 402 // 403 // So, we need to set up a continuation that will initiate the call later, then we need to 404 // fork the promise for that continuation in order to send the completion promise and the 405 // pipeline to their respective places. 406 // 407 // TODO(perf): Too much reference counting? Can we do better? Maybe a way to fork 408 // Promise<Tuple<T, U>> into Tuple<Promise<T>, Promise<U>>? 409 410 struct CallResultHolder: public kj::Refcounted { 411 // Essentially acts as a refcounted \VoidPromiseAndPipeline, so that we can create a promise 412 // for it and fork that promise. 413 414 VoidPromiseAndPipeline content; 415 // One branch of the fork will use content.promise, the other branch will use 416 // content.pipeline. Neither branch will touch the other's piece. 417 418 inline CallResultHolder(VoidPromiseAndPipeline&& content): content(kj::mv(content)) {} 419 420 kj::Own<CallResultHolder> addRef() { return kj::addRef(*this); } 421 }; 422 423 // Create a promise for the call initiation. 424 kj::ForkedPromise<kj::Own<CallResultHolder>> callResultPromise = 425 promiseForCallForwarding.addBranch().then(kj::mvCapture(context, 426 [=](kj::Own<CallContextHook>&& context, kj::Own<ClientHook>&& client){ 427 return kj::refcounted<CallResultHolder>( 428 client->call(interfaceId, methodId, kj::mv(context))); 429 })).fork(); 430 431 // Create a promise that extracts the pipeline from the call initiation, and construct our 432 // QueuedPipeline to chain to it. 433 auto pipelinePromise = callResultPromise.addBranch().then( 434 [](kj::Own<CallResultHolder>&& callResult){ 435 return kj::mv(callResult->content.pipeline); 436 }); 437 auto pipeline = kj::refcounted<QueuedPipeline>(kj::mv(pipelinePromise)); 438 439 // Create a promise that simply chains to the void promise produced by the call initiation. 440 auto completionPromise = callResultPromise.addBranch().then( 441 [](kj::Own<CallResultHolder>&& callResult){ 442 return kj::mv(callResult->content.promise); 443 }); 444 445 // OK, now we can actually return our thing. 446 return VoidPromiseAndPipeline { kj::mv(completionPromise), kj::mv(pipeline) }; 447 } 448 449 kj::Maybe<ClientHook&> getResolved() override { 450 KJ_IF_MAYBE(inner, redirect) { 451 return **inner; 452 } else { 453 return nullptr; 454 } 455 } 456 457 kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override { 458 return promiseForClientResolution.addBranch(); 459 } 460 461 kj::Own<ClientHook> addRef() override { 462 return kj::addRef(*this); 463 } 464 465 const void* getBrand() override { 466 return nullptr; 467 } 468 469 kj::Maybe<int> getFd() override { 470 KJ_IF_MAYBE(r, redirect) { 471 return r->get()->getFd(); 472 } else { 473 return nullptr; 474 } 475 } 476 477 private: 478 typedef kj::ForkedPromise<kj::Own<ClientHook>> ClientHookPromiseFork; 479 480 kj::Maybe<kj::Own<ClientHook>> redirect; 481 // Once the promise resolves, this will become non-null and point to the underlying object. 482 483 ClientHookPromiseFork promise; 484 // Promise that resolves when we have a new ClientHook to forward to. 485 // 486 // This fork shall only have three branches: `selfResolutionOp`, `promiseForCallForwarding`, and 487 // `promiseForClientResolution`, in that order. 488 489 kj::Promise<void> selfResolutionOp; 490 // Represents the operation which will set `redirect` when possible. 491 492 ClientHookPromiseFork promiseForCallForwarding; 493 // When this promise resolves, each queued call will be forwarded to the real client. This needs 494 // to occur *before* any 'whenMoreResolved()' promises resolve, because we want to make sure 495 // previously-queued calls are delivered before any new calls made in response to the resolution. 496 497 ClientHookPromiseFork promiseForClientResolution; 498 // whenMoreResolved() returns forks of this promise. These must resolve *after* queued calls 499 // have been initiated (so that any calls made in the whenMoreResolved() handler are correctly 500 // delivered after calls made earlier), but *before* any queued calls return (because it might 501 // confuse the application if a queued call returns before the capability on which it was made 502 // resolves). Luckily, we know that queued calls will involve, at the very least, an 503 // eventLoop.evalLater. 504 }; 505 506 kj::Own<ClientHook> QueuedPipeline::getPipelinedCap(kj::Array<PipelineOp>&& ops) { 507 KJ_IF_MAYBE(r, redirect) { 508 return r->get()->getPipelinedCap(kj::mv(ops)); 509 } else { 510 return clientMap.findOrCreate(ops.asPtr(), [&]() { 511 auto clientPromise = promise.addBranch() 512 .then([ops = KJ_MAP(op, ops) { return op; }](kj::Own<PipelineHook> pipeline) { 513 return pipeline->getPipelinedCap(kj::mv(ops)); 514 }); 515 return kj::HashMap<kj::Array<PipelineOp>, kj::Own<ClientHook>>::Entry { 516 kj::mv(ops), kj::refcounted<QueuedClient>(kj::mv(clientPromise)) 517 }; 518 })->addRef(); 519 } 520 } 521 522 // ======================================================================================= 523 524 class LocalPipeline final: public PipelineHook, public kj::Refcounted { 525 public: 526 inline LocalPipeline(kj::Own<CallContextHook>&& contextParam) 527 : context(kj::mv(contextParam)), 528 results(context->getResults(MessageSize { 0, 0 })) {} 529 530 kj::Own<PipelineHook> addRef() { 531 return kj::addRef(*this); 532 } 533 534 kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) { 535 return results.getPipelinedCap(ops); 536 } 537 538 private: 539 kj::Own<CallContextHook> context; 540 AnyPointer::Reader results; 541 }; 542 543 class LocalClient final: public ClientHook, public kj::Refcounted { 544 public: 545 LocalClient(kj::Own<Capability::Server>&& serverParam) 546 : server(kj::mv(serverParam)) { 547 server->thisHook = this; 548 startResolveTask(); 549 } 550 LocalClient(kj::Own<Capability::Server>&& serverParam, 551 _::CapabilityServerSetBase& capServerSet, void* ptr) 552 : server(kj::mv(serverParam)), capServerSet(&capServerSet), ptr(ptr) { 553 server->thisHook = this; 554 startResolveTask(); 555 } 556 557 ~LocalClient() noexcept(false) { 558 server->thisHook = nullptr; 559 } 560 561 Request<AnyPointer, AnyPointer> newCall( 562 uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { 563 KJ_IF_MAYBE(r, resolved) { 564 // We resolved to a shortened path. New calls MUST go directly to the replacement capability 565 // so that their ordering is consistent with callers who call getResolved() to get direct 566 // access to the new capability. In particular it's important that we don't place these calls 567 // in our streaming queue. 568 return r->get()->newCall(interfaceId, methodId, sizeHint); 569 } 570 571 auto hook = kj::heap<LocalRequest>( 572 interfaceId, methodId, sizeHint, kj::addRef(*this)); 573 auto root = hook->message->getRoot<AnyPointer>(); 574 return Request<AnyPointer, AnyPointer>(root, kj::mv(hook)); 575 } 576 577 VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, 578 kj::Own<CallContextHook>&& context) override { 579 KJ_IF_MAYBE(r, resolved) { 580 // We resolved to a shortened path. New calls MUST go directly to the replacement capability 581 // so that their ordering is consistent with callers who call getResolved() to get direct 582 // access to the new capability. In particular it's important that we don't place these calls 583 // in our streaming queue. 584 return r->get()->call(interfaceId, methodId, kj::mv(context)); 585 } 586 587 auto contextPtr = context.get(); 588 589 // We don't want to actually dispatch the call synchronously, because we don't want the callee 590 // to have any side effects before the promise is returned to the caller. This helps avoid 591 // race conditions. 592 // 593 // So, we do an evalLater() here. 594 // 595 // Note also that QueuedClient depends on this evalLater() to ensure that pipelined calls don't 596 // complete before 'whenMoreResolved()' promises resolve. 597 auto promise = kj::evalLater([this,interfaceId,methodId,contextPtr]() { 598 if (blocked) { 599 return kj::newAdaptedPromise<kj::Promise<void>, BlockedCall>( 600 *this, interfaceId, methodId, *contextPtr); 601 } else { 602 return callInternal(interfaceId, methodId, *contextPtr); 603 } 604 }).attach(kj::addRef(*this)); 605 606 // We have to fork this promise for the pipeline to receive a copy of the answer. 607 auto forked = promise.fork(); 608 609 auto pipelinePromise = forked.addBranch().then(kj::mvCapture(context->addRef(), 610 [=](kj::Own<CallContextHook>&& context) -> kj::Own<PipelineHook> { 611 context->releaseParams(); 612 return kj::refcounted<LocalPipeline>(kj::mv(context)); 613 })); 614 615 auto tailPipelinePromise = context->onTailCall().then([](AnyPointer::Pipeline&& pipeline) { 616 return kj::mv(pipeline.hook); 617 }); 618 619 pipelinePromise = pipelinePromise.exclusiveJoin(kj::mv(tailPipelinePromise)); 620 621 auto completionPromise = forked.addBranch().attach(kj::mv(context)); 622 623 return VoidPromiseAndPipeline { kj::mv(completionPromise), 624 kj::refcounted<QueuedPipeline>(kj::mv(pipelinePromise)) }; 625 } 626 627 kj::Maybe<ClientHook&> getResolved() override { 628 return resolved.map([](kj::Own<ClientHook>& hook) -> ClientHook& { return *hook; }); 629 } 630 631 kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override { 632 KJ_IF_MAYBE(r, resolved) { 633 return kj::Promise<kj::Own<ClientHook>>(r->get()->addRef()); 634 } else KJ_IF_MAYBE(t, resolveTask) { 635 return t->addBranch().then([this]() { 636 return KJ_ASSERT_NONNULL(resolved)->addRef(); 637 }); 638 } else { 639 return nullptr; 640 } 641 } 642 643 kj::Own<ClientHook> addRef() override { 644 return kj::addRef(*this); 645 } 646 647 static const uint BRAND; 648 // Value is irrelevant; used for pointer. 649 650 const void* getBrand() override { 651 return &BRAND; 652 } 653 654 kj::Maybe<kj::Promise<void*>> getLocalServer(_::CapabilityServerSetBase& capServerSet) { 655 // If this is a local capability created through `capServerSet`, return the underlying Server. 656 // Otherwise, return nullptr. Default implementation (which everyone except LocalClient should 657 // use) always returns nullptr. 658 659 if (this->capServerSet == &capServerSet) { 660 if (blocked) { 661 // If streaming calls are in-flight, it could be the case that they were originally sent 662 // over RPC and reflected back, before the capability had resolved to a local object. In 663 // that case, the client may already perceive these calls as "done" because the RPC 664 // implementation caused the client promise to resolve early. However, the capability is 665 // now local, and the app is trying to break through the LocalClient wrapper and access 666 // the server directly, bypassing the stream queue. Since the app thinks that all 667 // previous calls already completed, it may then try to queue a new call directly on the 668 // server, jumping the queue. 669 // 670 // We can solve this by delaying getLocalServer() until all current streaming calls have 671 // finished. Note that if a new streaming call is started *after* this point, we need not 672 // worry about that, because in this case it is presumably a local call and the caller 673 // won't be informed of completion until the call actually does complete. Thus the caller 674 // is well-aware that this call is still in-flight. 675 // 676 // However, the app still cannot assume that there aren't multiple clients, perhaps even 677 // a malicious client that tries to send stream requests that overlap with the app's 678 // direct use of the server... so it's up to the app to check for and guard against 679 // concurrent calls after using getLocalServer(). 680 return kj::newAdaptedPromise<kj::Promise<void>, BlockedCall>(*this) 681 .then([this]() { return ptr; }); 682 } else { 683 return kj::Promise<void*>(ptr); 684 } 685 } else { 686 return nullptr; 687 } 688 } 689 690 kj::Maybe<int> getFd() override { 691 return server->getFd(); 692 } 693 694 private: 695 kj::Own<Capability::Server> server; 696 _::CapabilityServerSetBase* capServerSet = nullptr; 697 void* ptr = nullptr; 698 699 kj::Maybe<kj::ForkedPromise<void>> resolveTask; 700 kj::Maybe<kj::Own<ClientHook>> resolved; 701 702 void startResolveTask() { 703 resolveTask = server->shortenPath().map([this](kj::Promise<Capability::Client> promise) { 704 return promise.then([this](Capability::Client&& cap) { 705 auto hook = ClientHook::from(kj::mv(cap)); 706 707 if (blocked) { 708 // This is a streaming interface and we have some calls queued up as a result. We cannot 709 // resolve directly to the new shorter path because this may allow new calls to hop 710 // the queue -- we need to embargo new calls until the queue clears out. 711 auto promise = kj::newAdaptedPromise<kj::Promise<void>, BlockedCall>(*this) 712 .then([hook = kj::mv(hook)]() mutable { return kj::mv(hook); }); 713 hook = newLocalPromiseClient(kj::mv(promise)); 714 } 715 716 resolved = kj::mv(hook); 717 }).fork(); 718 }); 719 } 720 721 class BlockedCall { 722 public: 723 BlockedCall(kj::PromiseFulfiller<kj::Promise<void>>& fulfiller, LocalClient& client, 724 uint64_t interfaceId, uint16_t methodId, CallContextHook& context) 725 : fulfiller(fulfiller), client(client), 726 interfaceId(interfaceId), methodId(methodId), context(context), 727 prev(client.blockedCallsEnd) { 728 *prev = *this; 729 client.blockedCallsEnd = &next; 730 } 731 732 BlockedCall(kj::PromiseFulfiller<kj::Promise<void>>& fulfiller, LocalClient& client) 733 : fulfiller(fulfiller), client(client), prev(client.blockedCallsEnd) { 734 *prev = *this; 735 client.blockedCallsEnd = &next; 736 } 737 738 ~BlockedCall() noexcept(false) { 739 unlink(); 740 } 741 742 void unblock() { 743 unlink(); 744 KJ_IF_MAYBE(c, context) { 745 fulfiller.fulfill(kj::evalNow([&]() { 746 return client.callInternal(interfaceId, methodId, *c); 747 })); 748 } else { 749 // This is just a barrier. 750 fulfiller.fulfill(kj::READY_NOW); 751 } 752 } 753 754 private: 755 kj::PromiseFulfiller<kj::Promise<void>>& fulfiller; 756 LocalClient& client; 757 uint64_t interfaceId; 758 uint16_t methodId; 759 kj::Maybe<CallContextHook&> context; 760 761 kj::Maybe<BlockedCall&> next; 762 kj::Maybe<BlockedCall&>* prev; 763 764 void unlink() { 765 if (prev != nullptr) { 766 *prev = next; 767 KJ_IF_MAYBE(n, next) { 768 n->prev = prev; 769 } else { 770 client.blockedCallsEnd = prev; 771 } 772 prev = nullptr; 773 } 774 } 775 }; 776 777 class BlockingScope { 778 public: 779 BlockingScope(LocalClient& client): client(client) { client.blocked = true; } 780 BlockingScope(): client(nullptr) {} 781 BlockingScope(BlockingScope&& other): client(other.client) { other.client = nullptr; } 782 KJ_DISALLOW_COPY(BlockingScope); 783 784 ~BlockingScope() noexcept(false) { 785 KJ_IF_MAYBE(c, client) { 786 c->unblock(); 787 } 788 } 789 790 private: 791 kj::Maybe<LocalClient&> client; 792 }; 793 794 bool blocked = false; 795 kj::Maybe<kj::Exception> brokenException; 796 kj::Maybe<BlockedCall&> blockedCalls; 797 kj::Maybe<BlockedCall&>* blockedCallsEnd = &blockedCalls; 798 799 void unblock() { 800 blocked = false; 801 while (!blocked) { 802 KJ_IF_MAYBE(t, blockedCalls) { 803 t->unblock(); 804 } else { 805 break; 806 } 807 } 808 } 809 810 kj::Promise<void> callInternal(uint64_t interfaceId, uint16_t methodId, 811 CallContextHook& context) { 812 KJ_ASSERT(!blocked); 813 814 KJ_IF_MAYBE(e, brokenException) { 815 // Previous streaming call threw, so everything fails from now on. 816 return kj::cp(*e); 817 } 818 819 auto result = server->dispatchCall(interfaceId, methodId, 820 CallContext<AnyPointer, AnyPointer>(context)); 821 if (result.isStreaming) { 822 return result.promise 823 .catch_([this](kj::Exception&& e) { 824 brokenException = kj::cp(e); 825 kj::throwRecoverableException(kj::mv(e)); 826 }).attach(BlockingScope(*this)); 827 } else { 828 return kj::mv(result.promise); 829 } 830 } 831 }; 832 833 const uint LocalClient::BRAND = 0; 834 835 kj::Own<ClientHook> Capability::Client::makeLocalClient(kj::Own<Capability::Server>&& server) { 836 return kj::refcounted<LocalClient>(kj::mv(server)); 837 } 838 839 kj::Own<ClientHook> newLocalPromiseClient(kj::Promise<kj::Own<ClientHook>>&& promise) { 840 return kj::refcounted<QueuedClient>(kj::mv(promise)); 841 } 842 843 kj::Own<PipelineHook> newLocalPromisePipeline(kj::Promise<kj::Own<PipelineHook>>&& promise) { 844 return kj::refcounted<QueuedPipeline>(kj::mv(promise)); 845 } 846 847 // ======================================================================================= 848 849 namespace _ { // private 850 851 class PipelineBuilderHook final: public PipelineHook, public kj::Refcounted { 852 public: 853 PipelineBuilderHook(uint firstSegmentWords) 854 : message(firstSegmentWords), 855 root(message.getRoot<AnyPointer>()) {} 856 857 kj::Own<PipelineHook> addRef() override { 858 return kj::addRef(*this); 859 } 860 861 kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) override { 862 return root.asReader().getPipelinedCap(ops); 863 } 864 865 MallocMessageBuilder message; 866 AnyPointer::Builder root; 867 }; 868 869 PipelineBuilderPair newPipelineBuilder(uint firstSegmentWords) { 870 auto hook = kj::refcounted<PipelineBuilderHook>(firstSegmentWords); 871 auto root = hook->root; 872 return { root, kj::mv(hook) }; 873 } 874 875 } // namespace _ (private) 876 877 // ======================================================================================= 878 879 namespace { 880 881 class BrokenPipeline final: public PipelineHook, public kj::Refcounted { 882 public: 883 BrokenPipeline(const kj::Exception& exception): exception(exception) {} 884 885 kj::Own<PipelineHook> addRef() override { 886 return kj::addRef(*this); 887 } 888 889 kj::Own<ClientHook> getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) override; 890 891 private: 892 kj::Exception exception; 893 }; 894 895 class BrokenRequest final: public RequestHook { 896 public: 897 BrokenRequest(const kj::Exception& exception, kj::Maybe<MessageSize> sizeHint) 898 : exception(exception), message(firstSegmentSize(sizeHint)) {} 899 900 RemotePromise<AnyPointer> send() override { 901 return RemotePromise<AnyPointer>(kj::cp(exception), 902 AnyPointer::Pipeline(kj::refcounted<BrokenPipeline>(exception))); 903 } 904 905 kj::Promise<void> sendStreaming() override { 906 return kj::cp(exception); 907 } 908 909 const void* getBrand() override { 910 return nullptr; 911 } 912 913 kj::Exception exception; 914 MallocMessageBuilder message; 915 }; 916 917 class BrokenClient final: public ClientHook, public kj::Refcounted { 918 public: 919 BrokenClient(const kj::Exception& exception, bool resolved, const void* brand) 920 : exception(exception), resolved(resolved), brand(brand) {} 921 BrokenClient(const kj::StringPtr description, bool resolved, const void* brand) 922 : exception(kj::Exception::Type::FAILED, "", 0, kj::str(description)), 923 resolved(resolved), brand(brand) {} 924 925 Request<AnyPointer, AnyPointer> newCall( 926 uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { 927 return newBrokenRequest(kj::cp(exception), sizeHint); 928 } 929 930 VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, 931 kj::Own<CallContextHook>&& context) override { 932 return VoidPromiseAndPipeline { kj::cp(exception), kj::refcounted<BrokenPipeline>(exception) }; 933 } 934 935 kj::Maybe<ClientHook&> getResolved() override { 936 return nullptr; 937 } 938 939 kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override { 940 if (resolved) { 941 return nullptr; 942 } else { 943 return kj::Promise<kj::Own<ClientHook>>(kj::cp(exception)); 944 } 945 } 946 947 kj::Own<ClientHook> addRef() override { 948 return kj::addRef(*this); 949 } 950 951 const void* getBrand() override { 952 return brand; 953 } 954 955 kj::Maybe<int> getFd() override { 956 return nullptr; 957 } 958 959 private: 960 kj::Exception exception; 961 bool resolved; 962 const void* brand; 963 }; 964 965 kj::Own<ClientHook> BrokenPipeline::getPipelinedCap(kj::ArrayPtr<const PipelineOp> ops) { 966 return kj::refcounted<BrokenClient>(exception, false, &ClientHook::BROKEN_CAPABILITY_BRAND); 967 } 968 969 kj::Own<ClientHook> newNullCap() { 970 // A null capability, unlike other broken capabilities, is considered resolved. 971 return kj::refcounted<BrokenClient>("Called null capability.", true, 972 &ClientHook::NULL_CAPABILITY_BRAND); 973 } 974 975 } // namespace 976 977 kj::Own<ClientHook> newBrokenCap(kj::StringPtr reason) { 978 return kj::refcounted<BrokenClient>(reason, false, &ClientHook::BROKEN_CAPABILITY_BRAND); 979 } 980 981 kj::Own<ClientHook> newBrokenCap(kj::Exception&& reason) { 982 return kj::refcounted<BrokenClient>(kj::mv(reason), false, &ClientHook::BROKEN_CAPABILITY_BRAND); 983 } 984 985 kj::Own<PipelineHook> newBrokenPipeline(kj::Exception&& reason) { 986 return kj::refcounted<BrokenPipeline>(kj::mv(reason)); 987 } 988 989 Request<AnyPointer, AnyPointer> newBrokenRequest( 990 kj::Exception&& reason, kj::Maybe<MessageSize> sizeHint) { 991 auto hook = kj::heap<BrokenRequest>(kj::mv(reason), sizeHint); 992 auto root = hook->message.getRoot<AnyPointer>(); 993 return Request<AnyPointer, AnyPointer>(root, kj::mv(hook)); 994 } 995 996 // ======================================================================================= 997 998 ReaderCapabilityTable::ReaderCapabilityTable( 999 kj::Array<kj::Maybe<kj::Own<ClientHook>>> table) 1000 : table(kj::mv(table)) { 1001 setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory); 1002 } 1003 1004 kj::Maybe<kj::Own<ClientHook>> ReaderCapabilityTable::extractCap(uint index) { 1005 if (index < table.size()) { 1006 return table[index].map([](kj::Own<ClientHook>& cap) { return cap->addRef(); }); 1007 } else { 1008 return nullptr; 1009 } 1010 } 1011 1012 BuilderCapabilityTable::BuilderCapabilityTable() { 1013 setGlobalBrokenCapFactoryForLayoutCpp(brokenCapFactory); 1014 } 1015 1016 kj::Maybe<kj::Own<ClientHook>> BuilderCapabilityTable::extractCap(uint index) { 1017 if (index < table.size()) { 1018 return table[index].map([](kj::Own<ClientHook>& cap) { return cap->addRef(); }); 1019 } else { 1020 return nullptr; 1021 } 1022 } 1023 1024 uint BuilderCapabilityTable::injectCap(kj::Own<ClientHook>&& cap) { 1025 uint result = table.size(); 1026 table.add(kj::mv(cap)); 1027 return result; 1028 } 1029 1030 void BuilderCapabilityTable::dropCap(uint index) { 1031 KJ_ASSERT(index < table.size(), "Invalid capability descriptor in message.") { 1032 return; 1033 } 1034 table[index] = nullptr; 1035 } 1036 1037 // ======================================================================================= 1038 // CapabilityServerSet 1039 1040 namespace _ { // private 1041 1042 Capability::Client CapabilityServerSetBase::addInternal( 1043 kj::Own<Capability::Server>&& server, void* ptr) { 1044 return Capability::Client(kj::refcounted<LocalClient>(kj::mv(server), *this, ptr)); 1045 } 1046 1047 kj::Promise<void*> CapabilityServerSetBase::getLocalServerInternal(Capability::Client& client) { 1048 ClientHook* hook = client.hook.get(); 1049 1050 // Get the most-resolved-so-far version of the hook. 1051 for (;;) { 1052 KJ_IF_MAYBE(h, hook->getResolved()) { 1053 hook = h; 1054 } else { 1055 break; 1056 } 1057 } 1058 1059 // Try to unwrap that. 1060 if (hook->getBrand() == &LocalClient::BRAND) { 1061 KJ_IF_MAYBE(promise, kj::downcast<LocalClient>(*hook).getLocalServer(*this)) { 1062 // This is definitely a member of our set and will resolve to non-null. We just have to wait 1063 // for any existing streaming calls to complete. 1064 return kj::mv(*promise); 1065 } 1066 } 1067 1068 // OK, the capability isn't part of this set. 1069 KJ_IF_MAYBE(p, hook->whenMoreResolved()) { 1070 // This hook is an unresolved promise. It might resolve eventually to a local server, so wait 1071 // for it. 1072 return p->attach(hook->addRef()) 1073 .then([this](kj::Own<ClientHook>&& resolved) { 1074 Capability::Client client(kj::mv(resolved)); 1075 return getLocalServerInternal(client); 1076 }); 1077 } else { 1078 // Cap is settled, so it definitely will never resolve to a member of this set. 1079 return kj::implicitCast<void*>(nullptr); 1080 } 1081 } 1082 1083 } // namespace _ (private) 1084 1085 } // namespace capnp