ez-rpc.c++ (13351B)
1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors 2 // Licensed under the MIT License: 3 // 4 // Permission is hereby granted, free of charge, to any person obtaining a copy 5 // of this software and associated documentation files (the "Software"), to deal 6 // in the Software without restriction, including without limitation the rights 7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 // copies of the Software, and to permit persons to whom the Software is 9 // furnished to do so, subject to the following conditions: 10 // 11 // The above copyright notice and this permission notice shall be included in 12 // all copies or substantial portions of the Software. 13 // 14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 // THE SOFTWARE. 21 22 #include "ez-rpc.h" 23 #include "rpc-twoparty.h" 24 #include <capnp/rpc.capnp.h> 25 #include <kj/async-io.h> 26 #include <kj/debug.h> 27 #include <kj/threadlocal.h> 28 #include <map> 29 30 namespace capnp { 31 32 KJ_THREADLOCAL_PTR(EzRpcContext) threadEzContext = nullptr; 33 34 class EzRpcContext: public kj::Refcounted { 35 public: 36 EzRpcContext(): ioContext(kj::setupAsyncIo()) { 37 threadEzContext = this; 38 } 39 40 ~EzRpcContext() noexcept(false) { 41 KJ_REQUIRE(threadEzContext == this, 42 "EzRpcContext destroyed from different thread than it was created.") { 43 return; 44 } 45 threadEzContext = nullptr; 46 } 47 48 kj::WaitScope& getWaitScope() { 49 return ioContext.waitScope; 50 } 51 52 kj::AsyncIoProvider& getIoProvider() { 53 return *ioContext.provider; 54 } 55 56 kj::LowLevelAsyncIoProvider& getLowLevelIoProvider() { 57 return *ioContext.lowLevelProvider; 58 } 59 60 static kj::Own<EzRpcContext> getThreadLocal() { 61 EzRpcContext* existing = threadEzContext; 62 if (existing != nullptr) { 63 return kj::addRef(*existing); 64 } else { 65 return kj::refcounted<EzRpcContext>(); 66 } 67 } 68 69 private: 70 kj::AsyncIoContext ioContext; 71 }; 72 73 // ======================================================================================= 74 75 kj::Promise<kj::Own<kj::AsyncIoStream>> connectAttach(kj::Own<kj::NetworkAddress>&& addr) { 76 return addr->connect().attach(kj::mv(addr)); 77 } 78 79 struct EzRpcClient::Impl { 80 kj::Own<EzRpcContext> context; 81 82 struct ClientContext { 83 kj::Own<kj::AsyncIoStream> stream; 84 TwoPartyVatNetwork network; 85 RpcSystem<rpc::twoparty::VatId> rpcSystem; 86 87 ClientContext(kj::Own<kj::AsyncIoStream>&& stream, ReaderOptions readerOpts) 88 : stream(kj::mv(stream)), 89 network(*this->stream, rpc::twoparty::Side::CLIENT, readerOpts), 90 rpcSystem(makeRpcClient(network)) {} 91 92 Capability::Client getMain() { 93 word scratch[4]; 94 memset(scratch, 0, sizeof(scratch)); 95 MallocMessageBuilder message(scratch); 96 auto hostId = message.getRoot<rpc::twoparty::VatId>(); 97 hostId.setSide(rpc::twoparty::Side::SERVER); 98 return rpcSystem.bootstrap(hostId); 99 } 100 101 Capability::Client restore(kj::StringPtr name) { 102 word scratch[64]; 103 memset(scratch, 0, sizeof(scratch)); 104 MallocMessageBuilder message(scratch); 105 106 auto hostIdOrphan = message.getOrphanage().newOrphan<rpc::twoparty::VatId>(); 107 auto hostId = hostIdOrphan.get(); 108 hostId.setSide(rpc::twoparty::Side::SERVER); 109 110 auto objectId = message.getRoot<AnyPointer>(); 111 objectId.setAs<Text>(name); 112 #pragma GCC diagnostic push 113 #pragma GCC diagnostic ignored "-Wdeprecated-declarations" 114 return rpcSystem.restore(hostId, objectId); 115 #pragma GCC diagnostic pop 116 } 117 }; 118 119 kj::ForkedPromise<void> setupPromise; 120 121 kj::Maybe<kj::Own<ClientContext>> clientContext; 122 // Filled in before `setupPromise` resolves. 123 124 Impl(kj::StringPtr serverAddress, uint defaultPort, 125 ReaderOptions readerOpts) 126 : context(EzRpcContext::getThreadLocal()), 127 setupPromise(context->getIoProvider().getNetwork() 128 .parseAddress(serverAddress, defaultPort) 129 .then([](kj::Own<kj::NetworkAddress>&& addr) { 130 return connectAttach(kj::mv(addr)); 131 }).then([this, readerOpts](kj::Own<kj::AsyncIoStream>&& stream) { 132 clientContext = kj::heap<ClientContext>(kj::mv(stream), 133 readerOpts); 134 }).fork()) {} 135 136 Impl(const struct sockaddr* serverAddress, uint addrSize, 137 ReaderOptions readerOpts) 138 : context(EzRpcContext::getThreadLocal()), 139 setupPromise( 140 connectAttach(context->getIoProvider().getNetwork() 141 .getSockaddr(serverAddress, addrSize)) 142 .then([this, readerOpts](kj::Own<kj::AsyncIoStream>&& stream) { 143 clientContext = kj::heap<ClientContext>(kj::mv(stream), 144 readerOpts); 145 }).fork()) {} 146 147 Impl(int socketFd, ReaderOptions readerOpts) 148 : context(EzRpcContext::getThreadLocal()), 149 setupPromise(kj::Promise<void>(kj::READY_NOW).fork()), 150 clientContext(kj::heap<ClientContext>( 151 context->getLowLevelIoProvider().wrapSocketFd(socketFd), 152 readerOpts)) {} 153 }; 154 155 EzRpcClient::EzRpcClient(kj::StringPtr serverAddress, uint defaultPort, ReaderOptions readerOpts) 156 : impl(kj::heap<Impl>(serverAddress, defaultPort, readerOpts)) {} 157 158 EzRpcClient::EzRpcClient(const struct sockaddr* serverAddress, uint addrSize, ReaderOptions readerOpts) 159 : impl(kj::heap<Impl>(serverAddress, addrSize, readerOpts)) {} 160 161 EzRpcClient::EzRpcClient(int socketFd, ReaderOptions readerOpts) 162 : impl(kj::heap<Impl>(socketFd, readerOpts)) {} 163 164 EzRpcClient::~EzRpcClient() noexcept(false) {} 165 166 Capability::Client EzRpcClient::getMain() { 167 KJ_IF_MAYBE(client, impl->clientContext) { 168 return client->get()->getMain(); 169 } else { 170 return impl->setupPromise.addBranch().then([this]() { 171 return KJ_ASSERT_NONNULL(impl->clientContext)->getMain(); 172 }); 173 } 174 } 175 176 Capability::Client EzRpcClient::importCap(kj::StringPtr name) { 177 KJ_IF_MAYBE(client, impl->clientContext) { 178 return client->get()->restore(name); 179 } else { 180 return impl->setupPromise.addBranch().then(kj::mvCapture(kj::heapString(name), 181 [this](kj::String&& name) { 182 return KJ_ASSERT_NONNULL(impl->clientContext)->restore(name); 183 })); 184 } 185 } 186 187 kj::WaitScope& EzRpcClient::getWaitScope() { 188 return impl->context->getWaitScope(); 189 } 190 191 kj::AsyncIoProvider& EzRpcClient::getIoProvider() { 192 return impl->context->getIoProvider(); 193 } 194 195 kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() { 196 return impl->context->getLowLevelIoProvider(); 197 } 198 199 // ======================================================================================= 200 201 namespace { 202 203 class DummyFilter: public kj::LowLevelAsyncIoProvider::NetworkFilter { 204 public: 205 bool shouldAllow(const struct sockaddr* addr, uint addrlen) override { 206 return true; 207 } 208 }; 209 210 static DummyFilter DUMMY_FILTER; 211 212 } // namespace 213 214 struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>, 215 public kj::TaskSet::ErrorHandler { 216 Capability::Client mainInterface; 217 kj::Own<EzRpcContext> context; 218 219 struct ExportedCap { 220 kj::String name; 221 Capability::Client cap = nullptr; 222 223 ExportedCap(kj::StringPtr name, Capability::Client cap) 224 : name(kj::heapString(name)), cap(cap) {} 225 226 ExportedCap() = default; 227 ExportedCap(const ExportedCap&) = delete; 228 ExportedCap(ExportedCap&&) = default; 229 ExportedCap& operator=(const ExportedCap&) = delete; 230 ExportedCap& operator=(ExportedCap&&) = default; 231 // Make std::map happy... 232 }; 233 234 std::map<kj::StringPtr, ExportedCap> exportMap; 235 236 kj::ForkedPromise<uint> portPromise; 237 238 kj::TaskSet tasks; 239 240 struct ServerContext { 241 kj::Own<kj::AsyncIoStream> stream; 242 TwoPartyVatNetwork network; 243 RpcSystem<rpc::twoparty::VatId> rpcSystem; 244 245 #pragma GCC diagnostic push 246 #pragma GCC diagnostic ignored "-Wdeprecated-declarations" 247 ServerContext(kj::Own<kj::AsyncIoStream>&& stream, SturdyRefRestorer<AnyPointer>& restorer, 248 ReaderOptions readerOpts) 249 : stream(kj::mv(stream)), 250 network(*this->stream, rpc::twoparty::Side::SERVER, readerOpts), 251 rpcSystem(makeRpcServer(network, restorer)) {} 252 #pragma GCC diagnostic pop 253 }; 254 255 Impl(Capability::Client mainInterface, kj::StringPtr bindAddress, uint defaultPort, 256 ReaderOptions readerOpts) 257 : mainInterface(kj::mv(mainInterface)), 258 context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) { 259 auto paf = kj::newPromiseAndFulfiller<uint>(); 260 portPromise = paf.promise.fork(); 261 262 tasks.add(context->getIoProvider().getNetwork().parseAddress(bindAddress, defaultPort) 263 .then(kj::mvCapture(paf.fulfiller, 264 [this, readerOpts](kj::Own<kj::PromiseFulfiller<uint>>&& portFulfiller, 265 kj::Own<kj::NetworkAddress>&& addr) { 266 auto listener = addr->listen(); 267 portFulfiller->fulfill(listener->getPort()); 268 acceptLoop(kj::mv(listener), readerOpts); 269 }))); 270 } 271 272 Impl(Capability::Client mainInterface, struct sockaddr* bindAddress, uint addrSize, 273 ReaderOptions readerOpts) 274 : mainInterface(kj::mv(mainInterface)), 275 context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) { 276 auto listener = context->getIoProvider().getNetwork() 277 .getSockaddr(bindAddress, addrSize)->listen(); 278 portPromise = kj::Promise<uint>(listener->getPort()).fork(); 279 acceptLoop(kj::mv(listener), readerOpts); 280 } 281 282 Impl(Capability::Client mainInterface, int socketFd, uint port, ReaderOptions readerOpts) 283 : mainInterface(kj::mv(mainInterface)), 284 context(EzRpcContext::getThreadLocal()), 285 portPromise(kj::Promise<uint>(port).fork()), 286 tasks(*this) { 287 acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd, DUMMY_FILTER), 288 readerOpts); 289 } 290 291 void acceptLoop(kj::Own<kj::ConnectionReceiver>&& listener, ReaderOptions readerOpts) { 292 auto ptr = listener.get(); 293 tasks.add(ptr->accept().then(kj::mvCapture(kj::mv(listener), 294 [this, readerOpts](kj::Own<kj::ConnectionReceiver>&& listener, 295 kj::Own<kj::AsyncIoStream>&& connection) { 296 acceptLoop(kj::mv(listener), readerOpts); 297 298 auto server = kj::heap<ServerContext>(kj::mv(connection), *this, readerOpts); 299 300 // Arrange to destroy the server context when all references are gone, or when the 301 // EzRpcServer is destroyed (which will destroy the TaskSet). 302 tasks.add(server->network.onDisconnect().attach(kj::mv(server))); 303 }))); 304 } 305 306 Capability::Client restore(AnyPointer::Reader objectId) override { 307 if (objectId.isNull()) { 308 return mainInterface; 309 } else { 310 auto name = objectId.getAs<Text>(); 311 auto iter = exportMap.find(name); 312 if (iter == exportMap.end()) { 313 KJ_FAIL_REQUIRE("Server exports no such capability.", name) { break; } 314 return nullptr; 315 } else { 316 return iter->second.cap; 317 } 318 } 319 } 320 321 void taskFailed(kj::Exception&& exception) override { 322 kj::throwFatalException(kj::mv(exception)); 323 } 324 }; 325 326 EzRpcServer::EzRpcServer(Capability::Client mainInterface, kj::StringPtr bindAddress, 327 uint defaultPort, ReaderOptions readerOpts) 328 : impl(kj::heap<Impl>(kj::mv(mainInterface), bindAddress, defaultPort, readerOpts)) {} 329 330 EzRpcServer::EzRpcServer(Capability::Client mainInterface, struct sockaddr* bindAddress, 331 uint addrSize, ReaderOptions readerOpts) 332 : impl(kj::heap<Impl>(kj::mv(mainInterface), bindAddress, addrSize, readerOpts)) {} 333 334 EzRpcServer::EzRpcServer(Capability::Client mainInterface, int socketFd, uint port, 335 ReaderOptions readerOpts) 336 : impl(kj::heap<Impl>(kj::mv(mainInterface), socketFd, port, readerOpts)) {} 337 338 EzRpcServer::EzRpcServer(kj::StringPtr bindAddress, uint defaultPort, 339 ReaderOptions readerOpts) 340 : EzRpcServer(nullptr, bindAddress, defaultPort, readerOpts) {} 341 342 EzRpcServer::EzRpcServer(struct sockaddr* bindAddress, uint addrSize, 343 ReaderOptions readerOpts) 344 : EzRpcServer(nullptr, bindAddress, addrSize, readerOpts) {} 345 346 EzRpcServer::EzRpcServer(int socketFd, uint port, ReaderOptions readerOpts) 347 : EzRpcServer(nullptr, socketFd, port, readerOpts) {} 348 349 EzRpcServer::~EzRpcServer() noexcept(false) {} 350 351 void EzRpcServer::exportCap(kj::StringPtr name, Capability::Client cap) { 352 Impl::ExportedCap entry(kj::heapString(name), cap); 353 impl->exportMap[entry.name] = kj::mv(entry); 354 } 355 356 kj::Promise<uint> EzRpcServer::getPort() { 357 return impl->portPromise.addBranch(); 358 } 359 360 kj::WaitScope& EzRpcServer::getWaitScope() { 361 return impl->context->getWaitScope(); 362 } 363 364 kj::AsyncIoProvider& EzRpcServer::getIoProvider() { 365 return impl->context->getIoProvider(); 366 } 367 368 kj::LowLevelAsyncIoProvider& EzRpcServer::getLowLevelIoProvider() { 369 return impl->context->getLowLevelIoProvider(); 370 } 371 372 } // namespace capnp