capnproto

FORK: Cap'n Proto serialization/RPC system - core tools and C++ library
git clone https://git.neptards.moe/neptards/capnproto.git
Log | Files | Refs | README | LICENSE

ez-rpc.c++ (13351B)


      1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
      2 // Licensed under the MIT License:
      3 //
      4 // Permission is hereby granted, free of charge, to any person obtaining a copy
      5 // of this software and associated documentation files (the "Software"), to deal
      6 // in the Software without restriction, including without limitation the rights
      7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      8 // copies of the Software, and to permit persons to whom the Software is
      9 // furnished to do so, subject to the following conditions:
     10 //
     11 // The above copyright notice and this permission notice shall be included in
     12 // all copies or substantial portions of the Software.
     13 //
     14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
     17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
     20 // THE SOFTWARE.
     21 
     22 #include "ez-rpc.h"
     23 #include "rpc-twoparty.h"
     24 #include <capnp/rpc.capnp.h>
     25 #include <kj/async-io.h>
     26 #include <kj/debug.h>
     27 #include <kj/threadlocal.h>
     28 #include <map>
     29 
     30 namespace capnp {
     31 
     32 KJ_THREADLOCAL_PTR(EzRpcContext) threadEzContext = nullptr;
     33 
     34 class EzRpcContext: public kj::Refcounted {
     35 public:
     36   EzRpcContext(): ioContext(kj::setupAsyncIo()) {
     37     threadEzContext = this;
     38   }
     39 
     40   ~EzRpcContext() noexcept(false) {
     41     KJ_REQUIRE(threadEzContext == this,
     42                "EzRpcContext destroyed from different thread than it was created.") {
     43       return;
     44     }
     45     threadEzContext = nullptr;
     46   }
     47 
     48   kj::WaitScope& getWaitScope() {
     49     return ioContext.waitScope;
     50   }
     51 
     52   kj::AsyncIoProvider& getIoProvider() {
     53     return *ioContext.provider;
     54   }
     55 
     56   kj::LowLevelAsyncIoProvider& getLowLevelIoProvider() {
     57     return *ioContext.lowLevelProvider;
     58   }
     59 
     60   static kj::Own<EzRpcContext> getThreadLocal() {
     61     EzRpcContext* existing = threadEzContext;
     62     if (existing != nullptr) {
     63       return kj::addRef(*existing);
     64     } else {
     65       return kj::refcounted<EzRpcContext>();
     66     }
     67   }
     68 
     69 private:
     70   kj::AsyncIoContext ioContext;
     71 };
     72 
     73 // =======================================================================================
     74 
     75 kj::Promise<kj::Own<kj::AsyncIoStream>> connectAttach(kj::Own<kj::NetworkAddress>&& addr) {
     76   return addr->connect().attach(kj::mv(addr));
     77 }
     78 
struct EzRpcClient::Impl {
  // Private state behind EzRpcClient: a reference to this thread's EzRpcContext, plus the
  // connection to the server (stream + two-party network + RPC system). Depending on the
  // constructor, the connection is either established asynchronously or available immediately.

  kj::Own<EzRpcContext> context;
  // Declared first so that `context->...` is already valid in the initializers of the
  // members below.

  struct ClientContext {
    // Everything belonging to one live connection. Member order matters: `network` is built on
    // `*stream`, and `rpcSystem` is built on `network`.

    kj::Own<kj::AsyncIoStream> stream;
    TwoPartyVatNetwork network;
    RpcSystem<rpc::twoparty::VatId> rpcSystem;

    ClientContext(kj::Own<kj::AsyncIoStream>&& stream, ReaderOptions readerOpts)
        : stream(kj::mv(stream)),
          network(*this->stream, rpc::twoparty::Side::CLIENT, readerOpts),
          rpcSystem(makeRpcClient(network)) {}
    // `*this->stream` refers to the already-initialized member, not the moved-from parameter.

    Capability::Client getMain() {
      // Requests the server's bootstrap capability. The VatId naming the SERVER side is built
      // in a message whose builder is given a small zeroed stack buffer as scratch space.
      word scratch[4];
      memset(scratch, 0, sizeof(scratch));
      MallocMessageBuilder message(scratch);
      auto hostId = message.getRoot<rpc::twoparty::VatId>();
      hostId.setSide(rpc::twoparty::Side::SERVER);
      return rpcSystem.bootstrap(hostId);
    }

    Capability::Client restore(kj::StringPtr name) {
      // Requests the capability the server exported under `name`. The object ID sent is `name`
      // stored as Text. The VatId is built as an orphan because the message root holds the
      // object ID. `RpcSystem::restore()` is deprecated, hence the warning suppression below.
      word scratch[64];
      memset(scratch, 0, sizeof(scratch));
      MallocMessageBuilder message(scratch);

      auto hostIdOrphan = message.getOrphanage().newOrphan<rpc::twoparty::VatId>();
      auto hostId = hostIdOrphan.get();
      hostId.setSide(rpc::twoparty::Side::SERVER);

      auto objectId = message.getRoot<AnyPointer>();
      objectId.setAs<Text>(name);
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
      return rpcSystem.restore(hostId, objectId);
#pragma GCC diagnostic pop
    }
  };

  kj::ForkedPromise<void> setupPromise;
  // Resolves once `clientContext` has been populated. It is forked so that getMain() and
  // importCap() can each wait on their own branch.

  kj::Maybe<kj::Own<ClientContext>> clientContext;
  // Filled in before `setupPromise` resolves.

  // Connect by textual address: parse it (using `defaultPort` when none is given), connect,
  // then build the ClientContext. If parsing or connecting fails, `setupPromise` rejects.
  // NOTE(review): the continuation captures `this`. It is owned by setupPromise's fork hub,
  // and branches handed out by getMain()/importCap() may keep that hub alive after this Impl
  // is destroyed. Confirm that such branches never outlive the EzRpcClient.
  Impl(kj::StringPtr serverAddress, uint defaultPort,
       ReaderOptions readerOpts)
      : context(EzRpcContext::getThreadLocal()),
        setupPromise(context->getIoProvider().getNetwork()
            .parseAddress(serverAddress, defaultPort)
            .then([](kj::Own<kj::NetworkAddress>&& addr) {
              return connectAttach(kj::mv(addr));
            }).then([this, readerOpts](kj::Own<kj::AsyncIoStream>&& stream) {
              clientContext = kj::heap<ClientContext>(kj::mv(stream),
                                                      readerOpts);
            }).fork()) {}

  // Connect to a raw sockaddr: no parsing step, otherwise the same as the constructor above.
  Impl(const struct sockaddr* serverAddress, uint addrSize,
       ReaderOptions readerOpts)
      : context(EzRpcContext::getThreadLocal()),
        setupPromise(
            connectAttach(context->getIoProvider().getNetwork()
                .getSockaddr(serverAddress, addrSize))
            .then([this, readerOpts](kj::Own<kj::AsyncIoStream>&& stream) {
              clientContext = kj::heap<ClientContext>(kj::mv(stream),
                                                      readerOpts);
            }).fork()) {}

  // Use an already-connected socket: the ClientContext is created immediately, so
  // `setupPromise` starts out resolved.
  Impl(int socketFd, ReaderOptions readerOpts)
      : context(EzRpcContext::getThreadLocal()),
        setupPromise(kj::Promise<void>(kj::READY_NOW).fork()),
        clientContext(kj::heap<ClientContext>(
            context->getLowLevelIoProvider().wrapSocketFd(socketFd),
            readerOpts)) {}
};
    154 
    155 EzRpcClient::EzRpcClient(kj::StringPtr serverAddress, uint defaultPort, ReaderOptions readerOpts)
    156     : impl(kj::heap<Impl>(serverAddress, defaultPort, readerOpts)) {}
    157 
    158 EzRpcClient::EzRpcClient(const struct sockaddr* serverAddress, uint addrSize, ReaderOptions readerOpts)
    159     : impl(kj::heap<Impl>(serverAddress, addrSize, readerOpts)) {}
    160 
    161 EzRpcClient::EzRpcClient(int socketFd, ReaderOptions readerOpts)
    162     : impl(kj::heap<Impl>(socketFd, readerOpts)) {}
    163 
    164 EzRpcClient::~EzRpcClient() noexcept(false) {}
    165 
    166 Capability::Client EzRpcClient::getMain() {
    167   KJ_IF_MAYBE(client, impl->clientContext) {
    168     return client->get()->getMain();
    169   } else {
    170     return impl->setupPromise.addBranch().then([this]() {
    171       return KJ_ASSERT_NONNULL(impl->clientContext)->getMain();
    172     });
    173   }
    174 }
    175 
    176 Capability::Client EzRpcClient::importCap(kj::StringPtr name) {
    177   KJ_IF_MAYBE(client, impl->clientContext) {
    178     return client->get()->restore(name);
    179   } else {
    180     return impl->setupPromise.addBranch().then(kj::mvCapture(kj::heapString(name),
    181         [this](kj::String&& name) {
    182       return KJ_ASSERT_NONNULL(impl->clientContext)->restore(name);
    183     }));
    184   }
    185 }
    186 
    187 kj::WaitScope& EzRpcClient::getWaitScope() {
    188   return impl->context->getWaitScope();
    189 }
    190 
    191 kj::AsyncIoProvider& EzRpcClient::getIoProvider() {
    192   return impl->context->getIoProvider();
    193 }
    194 
    195 kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() {
    196   return impl->context->getLowLevelIoProvider();
    197 }
    198 
    199 // =======================================================================================
    200 
    201 namespace {
    202 
    203 class DummyFilter: public kj::LowLevelAsyncIoProvider::NetworkFilter {
    204 public:
    205   bool shouldAllow(const struct sockaddr* addr, uint addrlen) override {
    206     return true;
    207   }
    208 };
    209 
    210 static DummyFilter DUMMY_FILTER;
    211 
    212 }  // namespace
    213 
struct EzRpcServer::Impl final: public SturdyRefRestorer<AnyPointer>,
                                public kj::TaskSet::ErrorHandler {
  // Private state behind EzRpcServer. It plays two roles:
  // - It is the SturdyRefRestorer handed to every accepted connection, answering main-interface
  //   and named-export requests via restore().
  // - It is the error handler for `tasks`, which runs the accept loop and keeps each
  //   connection alive.

  Capability::Client mainInterface;
  // Returned by restore() when the requested object ID is null.

  kj::Own<EzRpcContext> context;
  // This thread's shared EzRpcContext (wait scope and I/O providers).

  struct ExportedCap {
    // A capability published under a name via EzRpcServer::exportCap(). The entry owns its
    // name string.
    kj::String name;
    Capability::Client cap = nullptr;

    ExportedCap(kj::StringPtr name, Capability::Client cap)
        : name(kj::heapString(name)), cap(cap) {}

    ExportedCap() = default;
    ExportedCap(const ExportedCap&) = delete;
    ExportedCap(ExportedCap&&) = default;
    ExportedCap& operator=(const ExportedCap&) = delete;
    ExportedCap& operator=(ExportedCap&&) = default;
    // Make std::map happy...
  };

  std::map<kj::StringPtr, ExportedCap> exportMap;
  // Named exports. Keys are StringPtrs that do not own their characters; a key is only valid
  // while the string it points into is alive.
  // NOTE(review): if a key points into its own entry's `name`, replacing the mapped value in
  // place frees that buffer and leaves the key dangling.

  kj::ForkedPromise<uint> portPromise;
  // Resolves to the listening port. It is forked so that getPort() can be called repeatedly.

  kj::TaskSet tasks;
  // Holds the accept loop and one task per live connection. Failures go to taskFailed().

  struct ServerContext {
    // Everything belonging to one accepted connection. Member order matters: `network` is
    // built on `*stream`, and `rpcSystem` is built on `network`.
    kj::Own<kj::AsyncIoStream> stream;
    TwoPartyVatNetwork network;
    RpcSystem<rpc::twoparty::VatId> rpcSystem;

    // makeRpcServer() with a SturdyRefRestorer is deprecated, hence the warning suppression.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
    ServerContext(kj::Own<kj::AsyncIoStream>&& stream, SturdyRefRestorer<AnyPointer>& restorer,
                  ReaderOptions readerOpts)
        : stream(kj::mv(stream)),
          network(*this->stream, rpc::twoparty::Side::SERVER, readerOpts),
          rpcSystem(makeRpcServer(network, restorer)) {}
#pragma GCC diagnostic pop
  };

  // Listen on a textual address. The address is parsed asynchronously, so portPromise starts
  // as a placeholder and is replaced by a fork of a promise/fulfiller pair. The fulfiller is
  // completed with the real port once the listener exists. If parsing fails, the task fails
  // (see taskFailed()).
  Impl(Capability::Client mainInterface, kj::StringPtr bindAddress, uint defaultPort,
       ReaderOptions readerOpts)
      : mainInterface(kj::mv(mainInterface)),
        context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) {
    auto paf = kj::newPromiseAndFulfiller<uint>();
    portPromise = paf.promise.fork();

    tasks.add(context->getIoProvider().getNetwork().parseAddress(bindAddress, defaultPort)
        .then(kj::mvCapture(paf.fulfiller,
          [this, readerOpts](kj::Own<kj::PromiseFulfiller<uint>>&& portFulfiller,
                             kj::Own<kj::NetworkAddress>&& addr) {
      auto listener = addr->listen();
      portFulfiller->fulfill(listener->getPort());
      acceptLoop(kj::mv(listener), readerOpts);
    })));
  }

  // Listen on a raw sockaddr. The listener is created synchronously, so the port is known
  // immediately.
  Impl(Capability::Client mainInterface, struct sockaddr* bindAddress, uint addrSize,
       ReaderOptions readerOpts)
      : mainInterface(kj::mv(mainInterface)),
        context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) {
    auto listener = context->getIoProvider().getNetwork()
        .getSockaddr(bindAddress, addrSize)->listen();
    portPromise = kj::Promise<uint>(listener->getPort()).fork();
    acceptLoop(kj::mv(listener), readerOpts);
  }

  // Accept on a caller-supplied listen socket fd. The port is taken on trust from the caller,
  // and all peers are allowed (DUMMY_FILTER).
  Impl(Capability::Client mainInterface, int socketFd, uint port, ReaderOptions readerOpts)
      : mainInterface(kj::mv(mainInterface)),
        context(EzRpcContext::getThreadLocal()),
        portPromise(kj::Promise<uint>(port).fork()),
        tasks(*this) {
    acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd, DUMMY_FILTER),
               readerOpts);
  }

  // Waits for one connection and, once it arrives:
  // 1. Immediately re-arms itself to accept the next connection. The listener ownership
  //    travels through each iteration.
  // 2. Builds a ServerContext that uses this Impl as its restorer.
  // 3. Keeps that ServerContext alive (inside `tasks`) until the network reports disconnect.
  void acceptLoop(kj::Own<kj::ConnectionReceiver>&& listener, ReaderOptions readerOpts) {
    auto ptr = listener.get();
    tasks.add(ptr->accept().then(kj::mvCapture(kj::mv(listener),
        [this, readerOpts](kj::Own<kj::ConnectionReceiver>&& listener,
                           kj::Own<kj::AsyncIoStream>&& connection) {
      acceptLoop(kj::mv(listener), readerOpts);

      auto server = kj::heap<ServerContext>(kj::mv(connection), *this, readerOpts);

      // Arrange to destroy the server context when all references are gone, or when the
      // EzRpcServer is destroyed (which will destroy the TaskSet).
      tasks.add(server->network.onDisconnect().attach(kj::mv(server)));
    })));
  }

  // Resolves an object ID sent by a client:
  // - A null ID returns the main interface.
  // - Otherwise the ID is read as Text (as EzRpcClient's restore() sends it) and looked up in
  //   exportMap.
  // An unknown name raises a recoverable KJ_FAIL_REQUIRE. The `return nullptr` path is only
  // reached if that failure does not throw.
  Capability::Client restore(AnyPointer::Reader objectId) override {
    if (objectId.isNull()) {
      return mainInterface;
    } else {
      auto name = objectId.getAs<Text>();
      auto iter = exportMap.find(name);
      if (iter == exportMap.end()) {
        KJ_FAIL_REQUIRE("Server exports no such capability.", name) { break; }
        return nullptr;
      } else {
        return iter->second.cap;
      }
    }
  }

  // Any failed task is escalated to a fatal exception. This covers address parsing, accepting,
  // and per-connection tasks.
  void taskFailed(kj::Exception&& exception) override {
    kj::throwFatalException(kj::mv(exception));
  }
};
    325 
    326 EzRpcServer::EzRpcServer(Capability::Client mainInterface, kj::StringPtr bindAddress,
    327                          uint defaultPort, ReaderOptions readerOpts)
    328     : impl(kj::heap<Impl>(kj::mv(mainInterface), bindAddress, defaultPort, readerOpts)) {}
    329 
    330 EzRpcServer::EzRpcServer(Capability::Client mainInterface, struct sockaddr* bindAddress,
    331                          uint addrSize, ReaderOptions readerOpts)
    332     : impl(kj::heap<Impl>(kj::mv(mainInterface), bindAddress, addrSize, readerOpts)) {}
    333 
    334 EzRpcServer::EzRpcServer(Capability::Client mainInterface, int socketFd, uint port,
    335                          ReaderOptions readerOpts)
    336     : impl(kj::heap<Impl>(kj::mv(mainInterface), socketFd, port, readerOpts)) {}
    337 
    338 EzRpcServer::EzRpcServer(kj::StringPtr bindAddress, uint defaultPort,
    339                          ReaderOptions readerOpts)
    340     : EzRpcServer(nullptr, bindAddress, defaultPort, readerOpts) {}
    341 
    342 EzRpcServer::EzRpcServer(struct sockaddr* bindAddress, uint addrSize,
    343                          ReaderOptions readerOpts)
    344     : EzRpcServer(nullptr, bindAddress, addrSize, readerOpts) {}
    345 
    346 EzRpcServer::EzRpcServer(int socketFd, uint port, ReaderOptions readerOpts)
    347     : EzRpcServer(nullptr, socketFd, port, readerOpts) {}
    348 
    349 EzRpcServer::~EzRpcServer() noexcept(false) {}
    350 
    351 void EzRpcServer::exportCap(kj::StringPtr name, Capability::Client cap) {
    352   Impl::ExportedCap entry(kj::heapString(name), cap);
    353   impl->exportMap[entry.name] = kj::mv(entry);
    354 }
    355 
    356 kj::Promise<uint> EzRpcServer::getPort() {
    357   return impl->portPromise.addBranch();
    358 }
    359 
    360 kj::WaitScope& EzRpcServer::getWaitScope() {
    361   return impl->context->getWaitScope();
    362 }
    363 
    364 kj::AsyncIoProvider& EzRpcServer::getIoProvider() {
    365   return impl->context->getIoProvider();
    366 }
    367 
    368 kj::LowLevelAsyncIoProvider& EzRpcServer::getLowLevelIoProvider() {
    369   return impl->context->getLowLevelIoProvider();
    370 }
    371 
    372 }  // namespace capnp