rpc-twoparty.c++ (17355B)
1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors 2 // Licensed under the MIT License: 3 // 4 // Permission is hereby granted, free of charge, to any person obtaining a copy 5 // of this software and associated documentation files (the "Software"), to deal 6 // in the Software without restriction, including without limitation the rights 7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 // copies of the Software, and to permit persons to whom the Software is 9 // furnished to do so, subject to the following conditions: 10 // 11 // The above copyright notice and this permission notice shall be included in 12 // all copies or substantial portions of the Software. 13 // 14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 // THE SOFTWARE. 21 22 #include "rpc-twoparty.h" 23 #include "serialize-async.h" 24 #include <kj/debug.h> 25 #include <kj/io.h> 26 27 namespace capnp { 28 29 TwoPartyVatNetwork::TwoPartyVatNetwork( 30 kj::OneOf<MessageStream*, kj::Own<MessageStream>>&& stream, 31 uint maxFdsPerMessage, 32 rpc::twoparty::Side side, 33 ReaderOptions receiveOptions, 34 const kj::MonotonicClock& clock) 35 36 : stream(kj::mv(stream)), 37 maxFdsPerMessage(maxFdsPerMessage), 38 side(side), 39 peerVatId(4), 40 receiveOptions(receiveOptions), 41 previousWrite(kj::READY_NOW), 42 clock(clock), 43 currentOutgoingMessageSendTime(clock.now()) { 44 peerVatId.initRoot<rpc::twoparty::VatId>().setSide( 45 side == rpc::twoparty::Side::CLIENT ? rpc::twoparty::Side::SERVER 46 : rpc::twoparty::Side::CLIENT); 47 48 auto paf = kj::newPromiseAndFulfiller<void>(); 49 disconnectPromise = paf.promise.fork(); 50 disconnectFulfiller.fulfiller = kj::mv(paf.fulfiller); 51 } 52 53 TwoPartyVatNetwork::TwoPartyVatNetwork(capnp::MessageStream& stream, 54 rpc::twoparty::Side side, ReaderOptions receiveOptions, 55 const kj::MonotonicClock& clock) 56 : TwoPartyVatNetwork(stream, 0, side, receiveOptions, clock) {} 57 58 TwoPartyVatNetwork::TwoPartyVatNetwork( 59 capnp::MessageStream& stream, 60 uint maxFdsPerMessage, 61 rpc::twoparty::Side side, 62 ReaderOptions receiveOptions, 63 const kj::MonotonicClock& clock) 64 : TwoPartyVatNetwork(&stream, maxFdsPerMessage, side, receiveOptions, clock) {} 65 66 TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncIoStream& stream, rpc::twoparty::Side side, 67 ReaderOptions receiveOptions, 68 const kj::MonotonicClock& clock) 69 : TwoPartyVatNetwork(kj::Own<MessageStream>(kj::heap<AsyncIoMessageStream>(stream)), 70 0, side, receiveOptions, clock) {} 71 72 TwoPartyVatNetwork::TwoPartyVatNetwork(kj::AsyncCapabilityStream& stream, uint maxFdsPerMessage, 73 rpc::twoparty::Side side, ReaderOptions receiveOptions, 74 const kj::MonotonicClock& clock) 75 : TwoPartyVatNetwork(kj::Own<MessageStream>(kj::heap<AsyncCapabilityMessageStream>(stream)), 76 maxFdsPerMessage, side, receiveOptions, clock) {} 77 78 MessageStream& TwoPartyVatNetwork::getStream() { 79 KJ_SWITCH_ONEOF(stream) { 80 KJ_CASE_ONEOF(s, MessageStream*) { 81 return *s; 82 } 83 KJ_CASE_ONEOF(s, kj::Own<MessageStream>) { 84 return *s; 85 } 86 } 87 KJ_UNREACHABLE; 88 } 89 90 void TwoPartyVatNetwork::FulfillerDisposer::disposeImpl(void* pointer) const { 91 if (--refcount == 0) { 92 fulfiller->fulfill(); 93 } 94 } 95 96 kj::Own<TwoPartyVatNetworkBase::Connection> TwoPartyVatNetwork::asConnection() { 97 ++disconnectFulfiller.refcount; 98 return kj::Own<TwoPartyVatNetworkBase::Connection>(this, disconnectFulfiller); 99 } 100 101 kj::Maybe<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::connect( 102 rpc::twoparty::VatId::Reader ref) { 103 if (ref.getSide() == side) { 104 return nullptr; 105 } else { 106 return asConnection(); 107 } 108 } 109 110 kj::Promise<kj::Own<TwoPartyVatNetworkBase::Connection>> TwoPartyVatNetwork::accept() { 111 if (side == rpc::twoparty::Side::SERVER && !accepted) { 112 accepted = true; 113 return asConnection(); 114 } else { 115 // Create a promise that will never be fulfilled. 116 auto paf = kj::newPromiseAndFulfiller<kj::Own<TwoPartyVatNetworkBase::Connection>>(); 117 acceptFulfiller = kj::mv(paf.fulfiller); 118 return kj::mv(paf.promise); 119 } 120 } 121 122 class TwoPartyVatNetwork::OutgoingMessageImpl final 123 : public OutgoingRpcMessage, public kj::Refcounted { 124 public: 125 OutgoingMessageImpl(TwoPartyVatNetwork& network, uint firstSegmentWordSize) 126 : network(network), 127 message(firstSegmentWordSize == 0 ? SUGGESTED_FIRST_SEGMENT_WORDS : firstSegmentWordSize) {} 128 129 AnyPointer::Builder getBody() override { 130 return message.getRoot<AnyPointer>(); 131 } 132 133 void setFds(kj::Array<int> fds) override { 134 if (network.maxFdsPerMessage > 0) { 135 this->fds = kj::mv(fds); 136 } 137 } 138 139 void send() override { 140 size_t size = 0; 141 for (auto& segment: message.getSegmentsForOutput()) { 142 size += segment.size(); 143 } 144 KJ_REQUIRE(size < network.receiveOptions.traversalLimitInWords, size, 145 "Trying to send Cap'n Proto message larger than our single-message size limit. The " 146 "other side probably won't accept it (assuming its traversalLimitInWords matches " 147 "ours) and would abort the connection, so I won't send it.") { 148 return; 149 } 150 151 network.currentQueueSize += size * sizeof(capnp::word); 152 ++network.currentQueueCount; 153 auto deferredSizeUpdate = kj::defer([&network = network, size]() mutable { 154 network.currentQueueSize -= size * sizeof(capnp::word); 155 --network.currentQueueCount; 156 }); 157 158 auto sendTime = network.clock.now(); 159 network.previousWrite = KJ_ASSERT_NONNULL(network.previousWrite, "already shut down") 160 .then([this, sendTime]() { 161 return kj::evalNow([&]() { 162 network.currentOutgoingMessageSendTime = sendTime; 163 return network.getStream().writeMessage(fds, message); 164 }).catch_([this](kj::Exception&& e) { 165 // Since no one checks write failures, we need to propagate them into read failures, 166 // otherwise we might get stuck sending all messages into a black hole and wondering why 167 // the peer never replies. 168 network.readCancelReason = kj::cp(e); 169 if (!network.readCanceler.isEmpty()) { 170 network.readCanceler.cancel(kj::cp(e)); 171 } 172 kj::throwRecoverableException(kj::mv(e)); 173 }); 174 }).attach(kj::addRef(*this), kj::mv(deferredSizeUpdate)) 175 // Note that it's important that the eagerlyEvaluate() come *after* the attach() because 176 // otherwise the message (and any capabilities in it) will not be released until a new 177 // message is written! (Kenton once spent all afternoon tracking this down...) 178 .eagerlyEvaluate(nullptr); 179 } 180 181 size_t sizeInWords() override { 182 return message.sizeInWords(); 183 } 184 185 private: 186 TwoPartyVatNetwork& network; 187 MallocMessageBuilder message; 188 kj::Array<int> fds; 189 }; 190 191 kj::Duration TwoPartyVatNetwork::getOutgoingMessageWaitTime() { 192 if (currentQueueCount > 0) { 193 return clock.now() - currentOutgoingMessageSendTime; 194 } else { 195 return 0 * kj::SECONDS; 196 } 197 } 198 199 class TwoPartyVatNetwork::IncomingMessageImpl final: public IncomingRpcMessage { 200 public: 201 IncomingMessageImpl(kj::Own<MessageReader> message): message(kj::mv(message)) {} 202 203 IncomingMessageImpl(MessageReaderAndFds init, kj::Array<kj::AutoCloseFd> fdSpace) 204 : message(kj::mv(init.reader)), 205 fdSpace(kj::mv(fdSpace)), 206 fds(init.fds) { 207 KJ_DASSERT(this->fds.begin() == this->fdSpace.begin()); 208 } 209 210 AnyPointer::Reader getBody() override { 211 return message->getRoot<AnyPointer>(); 212 } 213 214 kj::ArrayPtr<kj::AutoCloseFd> getAttachedFds() override { 215 return fds; 216 } 217 218 size_t sizeInWords() override { 219 return message->sizeInWords(); 220 } 221 222 private: 223 kj::Own<MessageReader> message; 224 kj::Array<kj::AutoCloseFd> fdSpace; 225 kj::ArrayPtr<kj::AutoCloseFd> fds; 226 }; 227 228 kj::Own<RpcFlowController> TwoPartyVatNetwork::newStream() { 229 return RpcFlowController::newVariableWindowController(*this); 230 } 231 232 size_t TwoPartyVatNetwork::getWindow() { 233 // The socket's send buffer size -- as returned by getsockopt(SO_SNDBUF) -- tells us how much 234 // data the kernel itself is willing to buffer. The kernel will increase the send buffer size if 235 // needed to fill the connection's congestion window. So we can cheat and use it as our stream 236 // window, too, to make sure we saturate said congestion window. 237 // 238 // TODO(perf): Unfortunately, this hack breaks down in the presence of proxying. What we really 239 // want is the window all the way to the endpoint, which could cross multiple connections. The 240 // first-hop window could be either too big or too small: it's too big if the first hop has 241 // much higher bandwidth than the full path (causing buffering at the bottleneck), and it's 242 // too small if the first hop has much lower latency than the full path (causing not enough 243 // data to be sent to saturate the connection). To handle this, we could either: 244 // 1. Have proxies be aware of streaming, by flagging streaming calls in the RPC protocol. The 245 // proxies would then handle backpressure at each hop. This seems simple to implement but 246 // requires base RPC protocol changes and might require thinking carefully about e-ordering 247 // implications. Also, it only fixes underutilization; it does not fix buffer bloat. 248 // 2. Do our own BBR-like computation, where the client measures the end-to-end latency and 249 // bandwidth based on the observed sends and returns, and then compute the window based on 250 // that. This seems complicated, but avoids the need for any changes to the RPC protocol. 251 // In theory it solves both underutilization and buffer bloat. Note that this approach would 252 // require the RPC system to use a clock, which feels dirty and adds non-determinism. 253 254 if (solSndbufUnimplemented) { 255 return RpcFlowController::DEFAULT_WINDOW_SIZE; 256 } else { 257 KJ_IF_MAYBE(bufSize, getStream().getSendBufferSize()) { 258 return *bufSize; 259 } else { 260 solSndbufUnimplemented = true; 261 return RpcFlowController::DEFAULT_WINDOW_SIZE; 262 } 263 } 264 } 265 266 rpc::twoparty::VatId::Reader TwoPartyVatNetwork::getPeerVatId() { 267 return peerVatId.getRoot<rpc::twoparty::VatId>(); 268 } 269 270 kj::Own<OutgoingRpcMessage> TwoPartyVatNetwork::newOutgoingMessage(uint firstSegmentWordSize) { 271 return kj::refcounted<OutgoingMessageImpl>(*this, firstSegmentWordSize); 272 } 273 274 kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> TwoPartyVatNetwork::receiveIncomingMessage() { 275 return kj::evalLater([this]() -> kj::Promise<kj::Maybe<kj::Own<IncomingRpcMessage>>> { 276 KJ_IF_MAYBE(e, readCancelReason) { 277 // A previous write failed; propagate the failure to reads, too. 278 return kj::cp(*e); 279 } 280 281 kj::Array<kj::AutoCloseFd> fdSpace = nullptr; 282 if(maxFdsPerMessage > 0) { 283 fdSpace = kj::heapArray<kj::AutoCloseFd>(maxFdsPerMessage); 284 } 285 auto promise = readCanceler.wrap(getStream().tryReadMessage(fdSpace, receiveOptions)); 286 return promise.then([fdSpace = kj::mv(fdSpace)] 287 (kj::Maybe<MessageReaderAndFds>&& messageAndFds) mutable 288 -> kj::Maybe<kj::Own<IncomingRpcMessage>> { 289 KJ_IF_MAYBE(m, messageAndFds) { 290 if (m->fds.size() > 0) { 291 return kj::Own<IncomingRpcMessage>( 292 kj::heap<IncomingMessageImpl>(kj::mv(*m), kj::mv(fdSpace))); 293 } else { 294 return kj::Own<IncomingRpcMessage>(kj::heap<IncomingMessageImpl>(kj::mv(m->reader))); 295 } 296 } else { 297 return nullptr; 298 } 299 }); 300 }); 301 } 302 303 kj::Promise<void> TwoPartyVatNetwork::shutdown() { 304 kj::Promise<void> result = KJ_ASSERT_NONNULL(previousWrite, "already shut down").then([this]() { 305 return getStream().end(); 306 }); 307 previousWrite = nullptr; 308 return kj::mv(result); 309 } 310 311 // ======================================================================================= 312 313 TwoPartyServer::TwoPartyServer(Capability::Client bootstrapInterface) 314 : bootstrapInterface(kj::mv(bootstrapInterface)), tasks(*this) {} 315 316 struct TwoPartyServer::AcceptedConnection { 317 kj::Own<kj::AsyncIoStream> connection; 318 TwoPartyVatNetwork network; 319 RpcSystem<rpc::twoparty::VatId> rpcSystem; 320 321 explicit AcceptedConnection(Capability::Client bootstrapInterface, 322 kj::Own<kj::AsyncIoStream>&& connectionParam) 323 : connection(kj::mv(connectionParam)), 324 network(*connection, rpc::twoparty::Side::SERVER), 325 rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {} 326 327 explicit AcceptedConnection(Capability::Client bootstrapInterface, 328 kj::Own<kj::AsyncCapabilityStream>&& connectionParam, 329 uint maxFdsPerMessage) 330 : connection(kj::mv(connectionParam)), 331 network(kj::downcast<kj::AsyncCapabilityStream>(*connection), 332 maxFdsPerMessage, rpc::twoparty::Side::SERVER), 333 rpcSystem(makeRpcServer(network, kj::mv(bootstrapInterface))) {} 334 }; 335 336 void TwoPartyServer::accept(kj::Own<kj::AsyncIoStream>&& connection) { 337 auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface, kj::mv(connection)); 338 339 // Run the connection until disconnect. 340 auto promise = connectionState->network.onDisconnect(); 341 tasks.add(promise.attach(kj::mv(connectionState))); 342 } 343 344 void TwoPartyServer::accept( 345 kj::Own<kj::AsyncCapabilityStream>&& connection, uint maxFdsPerMessage) { 346 auto connectionState = kj::heap<AcceptedConnection>( 347 bootstrapInterface, kj::mv(connection), maxFdsPerMessage); 348 349 // Run the connection until disconnect. 350 auto promise = connectionState->network.onDisconnect(); 351 tasks.add(promise.attach(kj::mv(connectionState))); 352 } 353 354 kj::Promise<void> TwoPartyServer::accept(kj::AsyncIoStream& connection) { 355 auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface, 356 kj::Own<kj::AsyncIoStream>(&connection, kj::NullDisposer::instance)); 357 358 // Run the connection until disconnect. 359 auto promise = connectionState->network.onDisconnect(); 360 return promise.attach(kj::mv(connectionState)); 361 } 362 363 kj::Promise<void> TwoPartyServer::accept( 364 kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage) { 365 auto connectionState = kj::heap<AcceptedConnection>(bootstrapInterface, 366 kj::Own<kj::AsyncCapabilityStream>(&connection, kj::NullDisposer::instance), 367 maxFdsPerMessage); 368 369 // Run the connection until disconnect. 370 auto promise = connectionState->network.onDisconnect(); 371 return promise.attach(kj::mv(connectionState)); 372 } 373 374 kj::Promise<void> TwoPartyServer::listen(kj::ConnectionReceiver& listener) { 375 return listener.accept() 376 .then([this,&listener](kj::Own<kj::AsyncIoStream>&& connection) mutable { 377 accept(kj::mv(connection)); 378 return listen(listener); 379 }); 380 } 381 382 kj::Promise<void> TwoPartyServer::listenCapStreamReceiver( 383 kj::ConnectionReceiver& listener, uint maxFdsPerMessage) { 384 return listener.accept() 385 .then([this,&listener,maxFdsPerMessage](kj::Own<kj::AsyncIoStream>&& connection) mutable { 386 accept(connection.downcast<kj::AsyncCapabilityStream>(), maxFdsPerMessage); 387 return listenCapStreamReceiver(listener, maxFdsPerMessage); 388 }); 389 } 390 391 void TwoPartyServer::taskFailed(kj::Exception&& exception) { 392 KJ_LOG(ERROR, exception); 393 } 394 395 TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection) 396 : network(connection, rpc::twoparty::Side::CLIENT), 397 rpcSystem(makeRpcClient(network)) {} 398 399 400 TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage) 401 : network(connection, maxFdsPerMessage, rpc::twoparty::Side::CLIENT), 402 rpcSystem(makeRpcClient(network)) {} 403 404 TwoPartyClient::TwoPartyClient(kj::AsyncIoStream& connection, 405 Capability::Client bootstrapInterface, 406 rpc::twoparty::Side side) 407 : network(connection, side), 408 rpcSystem(network, bootstrapInterface) {} 409 410 TwoPartyClient::TwoPartyClient(kj::AsyncCapabilityStream& connection, uint maxFdsPerMessage, 411 Capability::Client bootstrapInterface, 412 rpc::twoparty::Side side) 413 : network(connection, maxFdsPerMessage, side), 414 rpcSystem(network, bootstrapInterface) {} 415 416 Capability::Client TwoPartyClient::bootstrap() { 417 capnp::word scratch[4]; 418 memset(&scratch, 0, sizeof(scratch)); 419 capnp::MallocMessageBuilder message(scratch); 420 auto vatId = message.getRoot<rpc::twoparty::VatId>(); 421 vatId.setSide(network.getSide() == rpc::twoparty::Side::CLIENT 422 ? rpc::twoparty::Side::SERVER 423 : rpc::twoparty::Side::CLIENT); 424 return rpcSystem.bootstrap(vatId); 425 } 426 427 void TwoPartyClient::setTraceEncoder(kj::Function<kj::String(const kj::Exception&)> func) { 428 rpcSystem.setTraceEncoder(kj::mv(func)); 429 } 430 431 } // namespace capnp