capnproto

FORK: Cap'n Proto serialization/RPC system - core tools and C++ library
git clone https://git.neptards.moe/neptards/capnproto.git
Log | Files | Refs | README | LICENSE

reconnect.c++ (4865B)


      1 // Copyright (c) 2020 Cloudflare, Inc. and contributors
      2 // Licensed under the MIT License:
      3 //
      4 // Permission is hereby granted, free of charge, to any person obtaining a copy
      5 // of this software and associated documentation files (the "Software"), to deal
      6 // in the Software without restriction, including without limitation the rights
      7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      8 // copies of the Software, and to permit persons to whom the Software is
      9 // furnished to do so, subject to the following conditions:
     10 //
     11 // The above copyright notice and this permission notice shall be included in
     12 // all copies or substantial portions of the Software.
     13 //
     14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
     17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
     20 // THE SOFTWARE.
     21 
     22 #include "reconnect.h"
     23 
     24 namespace capnp {
     25 
     26 namespace {
     27 
     28 class ReconnectHook final: public ClientHook, public kj::Refcounted {
          // ClientHook that forwards every call to a `current` hook obtained from the `connect` callback.
          // When a forwarded call fails with a DISCONNECTED exception, `current` is replaced with the
          // result of calling `connect` again (see wrap()), so that later calls use the new connection.
     29 public:
     30   ReconnectHook(kj::Function<Capability::Client()> connectParam, bool lazy = false)
     31       : connect(kj::mv(connectParam)),
     32         current(lazy ? kj::Maybe<kj::Own<ClientHook>>() : ClientHook::from(connect())) {}
            // When `lazy` is false (the default), connect() is invoked right here to populate `current`.
            // When `lazy` is true, `current` starts out empty and getCurrent() connects on first use.
            //
            // The `connect()` in `current`'s initializer refers to the member, not the moved-from
            // parameter `connectParam`; the member is declared before `current`, so it is already
            // initialized at that point.
     33 
     34   Request<AnyPointer, AnyPointer> newCall(
     35       uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override {
            // Starts the request on the current underlying hook, then substitutes a RequestImpl as the
            // request's hook so that the promise produced when the request is sent goes through wrap().
     36     auto result = getCurrent().newCall(interfaceId, methodId, sizeHint);
            // Take the builder out of `result` before `result` is moved from on the next line.
     37     AnyPointer::Builder builder = result;
     38     auto hook = kj::heap<RequestImpl>(kj::addRef(*this), RequestHook::from(kj::mv(result)));
     39     return { builder, kj::mv(hook) };
     40   }
     41 
     42   VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId,
     43                               kj::Own<CallContextHook>&& context) override {
            // Forwards the call to the current underlying hook. Only `result.promise` is passed through
            // wrap(); the rest of `result` is returned as received.
     44     auto result = getCurrent().call(interfaceId, methodId, kj::mv(context));
     45     wrap(result.promise);
     46     return result;
     47   }
     48 
     49   kj::Maybe<ClientHook&> getResolved() override {
     50     // We can't let people resolve to the underlying capability because then we wouldn't be able
     51     // to redirect them later.
     52     return nullptr;
     53   }
     54 
     55   kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override {
            // Always returns nullptr -- presumably for the same reason as getResolved() above.
     56     return nullptr;
     57   }
     58 
     59   kj::Own<ClientHook> addRef() override {
            // Returns a new reference to this same hook object.
     60     return kj::addRef(*this);
     61   }
     62 
     63   const void* getBrand() override {
            // Always returns nullptr.
     64     return nullptr;
     65   }
     66 
     67   kj::Maybe<int> getFd() override {
     68     // It's not safe to return current->getFd() because normally callers wouldn't expect the FD to
     69     // change or go away over time, but this one could whenever we reconnect. If there's a use
     70     // case for being able to access the FD here, we'll need a different interface to do it.
     71     return nullptr;
     72   }
     73 
     74 private:
     75   kj::Function<Capability::Client()> connect;
          // Callback that produces a capability to forward to. Invoked from the constructor (non-lazy
          // mode), from getCurrent() (first use in lazy mode), and from wrap()'s error handler.

     76   kj::Maybe<kj::Own<ClientHook>> current;
          // The hook that calls are currently forwarded to. Empty only in lazy mode, before the first
          // connect.

     77   uint generation = 0;
          // Incremented each time wrap()'s error handler reconnects. A wrapped promise remembers the
          // value from when it was wrapped, so it can tell whether a reconnect has happened since.
     78 
     79   template <typename T>
     80   void wrap(kj::Promise<T>& promise) {
            // Replaces `promise` in place with a version that has an error handler attached.
            //
            // If the promise fails with a DISCONNECTED exception and `generation` still equals the value
            // captured when wrap() was called (i.e. no reconnect has happened in the meantime), the
            // handler bumps `generation` and replaces `current` with the result of connect(). If
            // connect() itself throws, `current` becomes a broken capability carrying that exception.
            //
            // In every case the handler returns the original exception, so the wrapped call still fails;
            // it is not retried here.
     81     promise = promise.catch_(
     82         [self = kj::addRef(*this), startGeneration = generation]
     83         (kj::Exception&& exception) mutable -> kj::Promise<T> {
     84       if (exception.getType() == kj::Exception::Type::DISCONNECTED &&
     85           self->generation == startGeneration) {
                // Bump the generation before reconnecting, so other promises wrapped under the old
                // generation will skip this branch when they fail.
     86         self->generation++;
     87         KJ_IF_MAYBE(e2, kj::runCatchingExceptions([&]() {
     88           self->current = ClientHook::from(self->connect());
     89         })) {
     90           self->current = newBrokenCap(kj::mv(*e2));
     91         }
     92       }
     93       return kj::mv(exception);
     94     });
     95   }
     96 
     97   ClientHook& getCurrent() {
            // Returns the hook that calls should be forwarded to, invoking connect() first if `current`
            // is still empty.
     98     KJ_IF_MAYBE(c, current) {
     99       return **c;
    100     } else {
    101       return *current.emplace(ClientHook::from(connect()));
    102     }
    103   }
    104 
    105   class RequestImpl final: public RequestHook {
            // RequestHook that delegates sending to `inner` and passes the resulting promise through the
            // parent ReconnectHook's wrap().
    106   public:
    107     RequestImpl(kj::Own<ReconnectHook> parent, kj::Own<RequestHook> inner)
    108         : parent(kj::mv(parent)), inner(kj::mv(inner)) {}
    109 
    110     RemotePromise<AnyPointer> send() override {
              // Sends via `inner`, then lets the parent attach its error handler to the result.
    111       auto result = inner->send();
    112       parent->wrap(result);
    113       return result;
    114     }
    115 
    116     kj::Promise<void> sendStreaming() override {
              // Same as send(), but for the streaming variant of `inner`.
    117       auto result = inner->sendStreaming();
    118       parent->wrap(result);
    119       return result;
    120     }
    121 
    122     const void* getBrand() override {
              // Always returns nullptr.
    123       return nullptr;
    124     }
    125 
    126   private:
    127     kj::Own<ReconnectHook> parent;
            // The ReconnectHook that created this request; used to call wrap().

    128     kj::Own<RequestHook> inner;
            // The underlying request hook that actually performs the send.
    129   };
    130 };
    131 
    132 }  // namespace
    133 
    134 Capability::Client autoReconnect(kj::Function<Capability::Client()> connect) {
          // Returns a client backed by a ReconnectHook in non-lazy mode: `connect` is invoked
          // immediately, inside the ReconnectHook constructor, and again whenever a call made through
          // the returned client fails with a DISCONNECTED exception.
    135   return Capability::Client(kj::refcounted<ReconnectHook>(kj::mv(connect)));
    136 }
    137 
    138 Capability::Client lazyAutoReconnect(kj::Function<Capability::Client()> connect) {
          // Returns a client backed by a ReconnectHook in lazy mode (`lazy` = true): `connect` is not
          // invoked here, but on the first call made through the returned client, and again whenever
          // such a call fails with a DISCONNECTED exception.
    139   return Capability::Client(kj::refcounted<ReconnectHook>(kj::mv(connect), true));
    140 }
    141 }  // namespace capnp