capnproto

FORK: Cap'n Proto serialization/RPC system - core tools and C++ library
git clone https://git.neptards.moe/neptards/capnproto.git
Log | Files | Refs | README | LICENSE

rpc-twoparty.c++ (17355B)


      1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
      2 // Licensed under the MIT License:
      3 //
      4 // Permission is hereby granted, free of charge, to any person obtaining a copy
      5 // of this software and associated documentation files (the "Software"), to deal
      6 // in the Software without restriction, including without limitation the rights
      7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      8 // copies of the Software, and to permit persons to whom the Software is
      9 // furnished to do so, subject to the following conditions:
     10 //
     11 // The above copyright notice and this permission notice shall be included in
     12 // all copies or substantial portions of the Software.
     13 //
     14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
     17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
     20 // THE SOFTWARE.
     21 
     22 #include "rpc-twoparty.h"
     23 #include "serialize-async.h"
     24 #include <kj/debug.h>
     25 #include <kj/io.h>
     26 
     27 namespace capnp {
     28 
     29 TwoPartyVatNetwork::TwoPartyVatNetwork(
     30     kj::OneOf<MessageStream*, kj::Own<MessageStream>>&& stream,
     31     uint maxFdsPerMessage,
     32     rpc::twoparty::Side side,
     33     ReaderOptions receiveOptions,
     34     const kj::MonotonicClock& clock)
     35 
     36     : stream(kj::mv(stream)),
     37       maxFdsPerMessage(maxFdsPerMessage),
     38       side(side),
     39       peerVatId(4),
     40       receiveOptions(receiveOptions),
     41       previousWrite(kj::READY_NOW),
     42       clock(clock),
     43       currentOutgoingMessageSendTime(clock.now()) {
     44   peerVatId.initRoot<rpc::twoparty::VatId>().setSide(
     45       side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER
     46                                           : rpc::twoparty::Side::CLIENT);
     47 
     48   auto paf = kj::newPromiseAndFulfiller<void>();
     49   disconnectPromise = paf.promise.fork();
     50   disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller);
     51 }
     52 
     53 TwoPartyVatNetwork::TwoPartyVatNetwork(capnp::MessageStream& stream,
     54                    rpc::twoparty::Side side, ReaderOptions receiveOptions,
     55                    const kj::MonotonicClock& clock)
     56   : TwoPartyVatNetwork(stream, 0, side, receiveOptions, clock) {}
     57 
     58 TwoPartyVatNetwork::TwoPartyVatNetwork(
     59     capnp::MessageStream& stream,
     60     uint maxFdsPerMessage,
     61     rpc::twoparty::Side side,
     62     ReaderOptions receiveOptions,
     63     const kj::MonotonicClock& clock)
     64     : TwoPartyVatNetwork(&stream, maxFdsPerMessage, side, receiveOptions, clock) {}
     65 
     66 TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side,
     67                                        ReaderOptions receiveOptions,
     68                                        const kj::MonotonicClock& clock)
     69     : TwoPartyVatNetwork(kj::Own<MessageStream>(kj::heap<AsyncIoMessageStream>(stream)),
     70                          0, side, receiveOptions, clock) {}
     71 
     72 TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMessage,
     73                                        rpc::twoparty::Side side, ReaderOptions receiveOptions,
     74                                        const kj::MonotonicClock& clock)
     75     : TwoPartyVatNetwork(kj::Own<MessageStream>(kj::heap<AsyncCapabilityMessageStream>(stream)),
     76                          maxFdsPerMessage, side, receiveOptions, clock) {}
     77 
     78 MessageStream& TwoPartyVatNetwork::getStream() {
     79   KJ_SWITCH_ONEOF(stream) {
     80     KJ_CASE_ONEOF(s, MessageStream*) {
     81       return *s;
     82     }
     83     KJ_CASE_ONEOF(s, kj::Own<MessageStream>) {
     84       return *s;
     85     }
     86   }
     87   KJ_UNREACHABLE;
     88 }
     89 
     90 void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const {
     91   if (--refcount == 0) {
     92     fulfiller->fulfill();
     93   }
     94 }
     95 
     96 kj::Own<TwoPartyVatNetworkBase::Connection> TwoPartyVatNetwork::asConnection() {
     97   ++disconnectFulfiller.refcount;
     98   return kj::Own<TwoPartyVatNetworkBase::Connection>(this, disconnectFulfiller);
     99 }
    100 
    101 kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connect(
    102     rpc::twoparty::VatId::Reader ref) {
    103   if (ref.getSide() == side) {
    104     return nullptr;
    105   } else {
    106     return asConnection();
    107   }
    108 }
    109 
    110 kj::Promise<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::accept() {
    111   if (side == rpc::twoparty::Side::SERVER && !accepted) {
    112     accepted = true;
    113     return asConnection();
    114   } else {
    115     // Create a promise that will never be fulfilled.
    116     auto paf = kj::newPromiseAndFulfiller<kj::Own<TwoPartyVatNetworkBase::Connection>>();
    117     acceptFulfiller = kj::mv(paf.fulfiller);
    118     return kj::mv(paf.promise);
    119   }
    120 }
    121 
    122 class TwoPartyVatNetwork::OutgoingMessageImpl final
    123     : public OutgoingRpcMessage, public kj::Refcounted {
    124 public:
    125   OutgoingMessageImpl(TwoPartyVatNetwork& network, uint firstSegmentWordSize)
    126       : network(network),
    127         message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize) {}
    128 
    129   AnyPointer::Builder getBody() override {
    130     return message.getRoot<AnyPointer>();
    131   }
    132 
    133   void setFds(kj::Array<int> fds) override {
    134     if (network.maxFdsPerMessage > 0) {
    135       this->fds = kj::mv(fds);
    136     }
    137   }
    138 
    139   void send() override {
    140     size_t size = 0;
    141     for (auto& segment: message.getSegmentsForOutput()) {
    142       size += segment.size();
    143     }
    144     KJ_REQUIRE(size < network.receiveOptions.traversalLimitInWords, size,
    145                "Trying to send Cap'n Proto message larger than our single-message size limit. The "
    146                "other side probably won't accept it (assuming its traversalLimitInWords matches "
    147                "ours) and would abort the connection, so I won't send it.") {
    148       return;
    149     }
    150 
    151     network.currentQueueSize += size * sizeof(capnp::word);
    152     ++network.currentQueueCount;
    153     auto deferredSizeUpdate = kj::defer([&network = network, size]() mutable {
    154       network.currentQueueSize -= size * sizeof(capnp::word);
    155       --network.currentQueueCount;
    156     });
    157 
    158     auto sendTime = network.clock.now();
    159     network.previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down")
    160         .then([this, sendTime]() {
    161       return kj::evalNow([&]() {
    162         network.currentOutgoingMessageSendTime = sendTime;
    163         return network.getStream().writeMessage(fds, message);
    164       }).catch_([this](kj::Exception&& e) {
    165         // Since no one checks write failures, we need to propagate them into read failures,
    166         // otherwise we might get stuck sending all messages into a black hole and wondering why
    167         // the peer never replies.
    168         network.readCancelReason = kj::cp(e);
    169         if (!network.readCanceler.isEmpty()) {
    170           network.readCanceler.cancel(kj::cp(e));
    171         }
    172         kj::throwRecoverableException(kj::mv(e));
    173       });
    174     }).attach(kj::addRef(*this), kj::mv(deferredSizeUpdate))
    175       // Note that it's important that the eagerlyEvaluate() come *after* the attach() because
    176       // otherwise the message (and any capabilities in it) will not be released until a new
    177       // message is written! (Kenton once spent all afternoon tracking this down...)
    178       .eagerlyEvaluate(nullptr);
    179   }
    180 
    181   size_t sizeInWords() override {
    182     return message.sizeInWords();
    183   }
    184 
    185 private:
    186   TwoPartyVatNetwork& network;
    187   MallocMessageBuilder message;
    188   kj::Array<int> fds;
    189 };
    190 
    191 kj::Duration TwoPartyVatNetwork::getOutgoingMessageWaitTime() {
    192   if (currentQueueCount > 0) {
    193     return clock.now() - currentOutgoingMessageSendTime;
    194   } else {
    195     return 0 * kj::SECONDS;
    196   }
    197 }
    198 
    199 class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage {
    200 public:
    201   IncomingMessageImpl(kj::Own<MessageReader> message): message(kj::mv(message)) {}
    202 
    203   IncomingMessageImpl(MessageReaderAndFds init, kj::Array<kj::AutoCloseFd> fdSpace)
    204       : message(kj::mv(init.reader)),
    205         fdSpace(kj::mv(fdSpace)),
    206         fds(init.fds) {
    207     KJ_DASSERT(this->fds.begin() == this->fdSpace.begin());
    208   }
    209 
    210   AnyPointer::Reader getBody() override {
    211     return message->getRoot<AnyPointer>();
    212   }
    213 
    214   kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() override {
    215     return fds;
    216   }
    217 
    218   size_t sizeInWords() override {
    219     return message->sizeInWords();
    220   }
    221 
    222 private:
    223   kj::Own<MessageReader> message;
    224   kj::Array<kj::AutoCloseFd> fdSpace;
    225   kj::ArrayPtr<kj::AutoCloseFd> fds;
    226 };
    227 
    228 kj::Own<RpcFlowController> TwoPartyVatNetwork::newStream() {
    229   return RpcFlowController::newVariableWindowController(*this);
    230 }
    231 
    232 size_t TwoPartyVatNetwork::getWindow() {
    233   // The socket's send buffer size -- as returned by getsockopt(SO_SNDBUF) -- tells us how much
    234   // data the kernel itself is willing to buffer. The kernel will increase the send buffer size if
    235   // needed to fill the connection's congestion window. So we can cheat and use it as our stream
    236   // window, too, to make sure we saturate said congestion window.
    237   //
    238   // TODO(perf): Unfortunately, this hack breaks down in the presence of proxying. What we really
    239   //   want is the window all the way to the endpoint, which could cross multiple connections. The
    240   //   first-hop window could be either too big or too small: it's too big if the first hop has
    241   //   much higher bandwidth than the full path (causing buffering at the bottleneck), and it's
    242   //   too small if the first hop has much lower latency than the full path (causing not enough
    243   //   data to be sent to saturate the connection). To handle this, we could either:
    244   //   1. Have proxies be aware of streaming, by flagging streaming calls in the RPC protocol. The
    245   //      proxies would then handle backpressure at each hop. This seems simple to implement but
    246   //      requires base RPC protocol changes and might require thinking carefully about e-ordering
    247   //      implications. Also, it only fixes underutilization; it does not fix buffer bloat.
    248   //   2. Do our own BBR-like computation, where the client measures the end-to-end latency and
    249   //      bandwidth based on the observed sends and returns, and then compute the window based on
    250   //      that. This seems complicated, but avoids the need for any changes to the RPC protocol.
    251   //      In theory it solves both underutilization and buffer bloat. Note that this approach would
    252   //      require the RPC system to use a clock, which feels dirty and adds non-determinism.
    253 
    254   if (solSndbufUnimplemented) {
    255     return RpcFlowController::DEFAULT_WINDOW_SIZE;
    256   } else {
    257     KJ_IF_MAYBE(bufSize, getStream().getSendBufferSize()) {
    258       return *bufSize;
    259     } else {
    260       solSndbufUnimplemented = true;
    261       return RpcFlowController::DEFAULT_WINDOW_SIZE;
    262     }
    263   }
    264 }
    265 
    266 rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() {
    267   return peerVatId.getRoot<rpc::twoparty::VatId>();
    268 }
    269 
    270 kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSegmentWordSize) {
    271   return kj::refcounted<OutgoingMessageImpl>(*this, firstSegmentWordSize);
    272 }
    273 
    274 kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() {
    275   return kj::evalLater([this]() -> kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> {
    276     KJ_IF_MAYBE(e, readCancelReason) {
    277       // A previous write failed; propagate the failure to reads, too.
    278       return kj::cp(*e);
    279     }
    280 
    281     kj::Array<kj::AutoCloseFd> fdSpace = nullptr;
    282     if(maxFdsPerMessage > 0) {
    283       fdSpace = kj::heapArray<kj::AutoCloseFd>(maxFdsPerMessage);
    284     }
    285     auto promise = readCanceler.wrap(getStream().tryReadMessage(fdSpace, receiveOptions));
    286     return promise.then([fdSpace = kj::mv(fdSpace)]
    287                         (kj::Maybe<MessageReaderAndFds>&& messageAndFds) mutable
    288                       -> kj::Maybe<kj::Own<IncomingRpcMessage>> {
    289       KJ_IF_MAYBE(m, messageAndFds) {
    290         if (m->fds.size() > 0) {
    291           return kj::Own<IncomingRpcMessage>(
    292               kj::heap<IncomingMessageImpl>(kj::mv(*m), kj::mv(fdSpace)));
    293         } else {
    294           return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(m->reader)));
    295         }
    296       } else {
    297         return nullptr;
    298       }
    299     });
    300   });
    301 }
    302 
    303 kj::Promise<void> TwoPartyVatNetwork::shutdown() {
    304   kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() {
    305     return getStream().end();
    306   });
    307   previousWrite = nullptr;
    308   return kj::mv(result);
    309 }
    310 
    311 // =======================================================================================
    312 
    313 TwoPartyServer::TwoPartyServer(Capability::Client bootstrapInterface)
    314     : bootstrapInterface(kj::mv(bootstrapInterface)), tasks(*this) {}
    315 
    316 struct TwoPartyServer::AcceptedConnection {
    317   kj::Own<kj::AsyncIoStream> connection;
    318   TwoPartyVatNetwork network;
    319   RpcSystem<rpc::twoparty::VatId> rpcSystem;
    320 
    321   explicit AcceptedConnection(Capability::Client bootstrapInterface,
    322                               kj::Own<kj::AsyncIoStream>&& connectionParam)
    323       : connection(kj::mv(connectionParam)),
    324         network(*connection, rpc::twoparty::Side::SERVER),
    325         rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
    326 
    327   explicit AcceptedConnection(Capability::Client bootstrapInterface,
    328                               kj::Own<kj::AsyncCapabilityStream>&& connectionParam,
    329                               uint maxFdsPerMessage)
    330       : connection(kj::mv(connectionParam)),
    331         network(kj::downcast<kj::AsyncCapabilityStream>(*connection),
    332                 maxFdsPerMessage, rpc::twoparty::Side::SERVER),
    333         rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {}
    334 };
    335 
    336 void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) {
    337   auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface, kj::mv(connection));
    338 
    339   // Run the connection until disconnect.
    340   auto promise = connectionState->network.onDisconnect();
    341   tasks.add(promise.attach(kj::mv(connectionState)));
    342 }
    343 
    344 void TwoPartyServer::accept(
    345     kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMessage) {
    346   auto connectionState = kj::heap<AcceptedConnection>(
    347       bootstrapInterface, kj::mv(connection), maxFdsPerMessage);
    348 
    349   // Run the connection until disconnect.
    350   auto promise = connectionState->network.onDisconnect();
    351   tasks.add(promise.attach(kj::mv(connectionState)));
    352 }
    353 
    354 kj::Promise<void> TwoPartyServer::accept(kj::AsyncIoStream& connection) {
    355   auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface,
    356       kj::Own<kj::AsyncIoStream>(&connection, kj::NullDisposer::instance));
    357 
    358   // Run the connection until disconnect.
    359   auto promise = connectionState->network.onDisconnect();
    360   return promise.attach(kj::mv(connectionState));
    361 }
    362 
    363 kj::Promise<void> TwoPartyServer::accept(
    364     kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage) {
    365   auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface,
    366       kj::Own<kj::AsyncCapabilityStream>(&connection, kj::NullDisposer::instance),
    367       maxFdsPerMessage);
    368 
    369   // Run the connection until disconnect.
    370   auto promise = connectionState->network.onDisconnect();
    371   return promise.attach(kj::mv(connectionState));
    372 }
    373 
    374 kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) {
    375   return listener.accept()
    376       .then([this,&listener](kj::Own<kj::AsyncIoStream>&& connection) mutable {
    377     accept(kj::mv(connection));
    378     return listen(listener);
    379   });
    380 }
    381 
    382 kj::Promise<void> TwoPartyServer::listenCapStreamReceiver(
    383       kj::ConnectionReceiver& listener, uint maxFdsPerMessage) {
    384   return listener.accept()
    385       .then([this,&listener,maxFdsPerMessage](kj::Own<kj::AsyncIoStream>&& connection) mutable {
    386     accept(connection.downcast<kj::AsyncCapabilityStream>(), maxFdsPerMessage);
    387     return listenCapStreamReceiver(listener, maxFdsPerMessage);
    388   });
    389 }
    390 
    391 void TwoPartyServer::taskFailed(kj::Exception&& exception) {
    392   KJ_LOG(ERROR, exception);
    393 }
    394 
    395 TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection)
    396     : network(connection, rpc::twoparty::Side::CLIENT),
    397       rpcSystem(makeRpcClient(network)) {}
    398 
    399 
    400 TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage)
    401     : network(connection, maxFdsPerMessage, rpc::twoparty::Side::CLIENT),
    402       rpcSystem(makeRpcClient(network)) {}
    403 
    404 TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection,
    405                                Capability::Client bootstrapInterface,
    406                                rpc::twoparty::Side side)
    407     : network(connection, side),
    408       rpcSystem(network, bootstrapInterface) {}
    409 
    410 TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage,
    411                                Capability::Client bootstrapInterface,
    412                                rpc::twoparty::Side side)
    413     : network(connection, maxFdsPerMessage, side),
    414       rpcSystem(network, bootstrapInterface) {}
    415 
    416 Capability::Client TwoPartyClient::bootstrap() {
    417   capnp::word scratch[4];
    418   memset(&scratch, 0, sizeof(scratch));
    419   capnp::MallocMessageBuilder message(scratch);
    420   auto vatId = message.getRoot<rpc::twoparty::VatId>();
    421   vatId.setSide(network.getSide() == rpc::twoparty::Side::CLIENT
    422                 ? rpc::twoparty::Side::SERVER
    423                 : rpc::twoparty::Side::CLIENT);
    424   return rpcSystem.bootstrap(vatId);
    425 }
    426 
    427 void TwoPartyClient::setTraceEncoder(kj::Function<kj::String(const kj::Exception&)> func) {
    428   rpcSystem.setTraceEncoder(kj::mv(func));
    429 }
    430 
    431 }  // namespace capnp