reconnect.c++ (4865B)
1 // Copyright (c) 2020 Cloudflare, Inc. and contributors 2 // Licensed under the MIT License: 3 // 4 // Permission is hereby granted, free of charge, to any person obtaining a copy 5 // of this software and associated documentation files (the "Software"), to deal 6 // in the Software without restriction, including without limitation the rights 7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 // copies of the Software, and to permit persons to whom the Software is 9 // furnished to do so, subject to the following conditions: 10 // 11 // The above copyright notice and this permission notice shall be included in 12 // all copies or substantial portions of the Software. 13 // 14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 // THE SOFTWARE. 21 22 #include "reconnect.h" 23 24 namespace capnp { 25 26 namespace { 27 28 class ReconnectHook final: public ClientHook, public kj::Refcounted { 29 public: 30 ReconnectHook(kj::Function<Capability::Client()> connectParam, bool lazy = false) 31 : connect(kj::mv(connectParam)), 32 current(lazy ? kj::Maybe<kj::Own<ClientHook>>() : ClientHook::from(connect())) {} 33 34 Request<AnyPointer, AnyPointer> newCall( 35 uint64_t interfaceId, uint16_t methodId, kj::Maybe<MessageSize> sizeHint) override { 36 auto result = getCurrent().newCall(interfaceId, methodId, sizeHint); 37 AnyPointer::Builder builder = result; 38 auto hook = kj::heap<RequestImpl>(kj::addRef(*this), RequestHook::from(kj::mv(result))); 39 return { builder, kj::mv(hook) }; 40 } 41 42 VoidPromiseAndPipeline call(uint64_t interfaceId, uint16_t methodId, 43 kj::Own<CallContextHook>&& context) override { 44 auto result = getCurrent().call(interfaceId, methodId, kj::mv(context)); 45 wrap(result.promise); 46 return result; 47 } 48 49 kj::Maybe<ClientHook&> getResolved() override { 50 // We can't let people resolve to the underlying capability because then we wouldn't be able 51 // to redirect them later. 52 return nullptr; 53 } 54 55 kj::Maybe<kj::Promise<kj::Own<ClientHook>>> whenMoreResolved() override { 56 return nullptr; 57 } 58 59 kj::Own<ClientHook> addRef() override { 60 return kj::addRef(*this); 61 } 62 63 const void* getBrand() override { 64 return nullptr; 65 } 66 67 kj::Maybe<int> getFd() override { 68 // It's not safe to return current->getFd() because normally callers wouldn't expect the FD to 69 // change or go away over time, but this one could whenever we reconnect. If there's a use 70 // case for being able to access the FD here, we'll need a different interface to do it. 71 return nullptr; 72 } 73 74 private: 75 kj::Function<Capability::Client()> connect; 76 kj::Maybe<kj::Own<ClientHook>> current; 77 uint generation = 0; 78 79 template <typename T> 80 void wrap(kj::Promise<T>& promise) { 81 promise = promise.catch_( 82 [self = kj::addRef(*this), startGeneration = generation] 83 (kj::Exception&& exception) mutable -> kj::Promise<T> { 84 if (exception.getType() == kj::Exception::Type::DISCONNECTED && 85 self->generation == startGeneration) { 86 self->generation++; 87 KJ_IF_MAYBE(e2, kj::runCatchingExceptions([&]() { 88 self->current = ClientHook::from(self->connect()); 89 })) { 90 self->current = newBrokenCap(kj::mv(*e2)); 91 } 92 } 93 return kj::mv(exception); 94 }); 95 } 96 97 ClientHook& getCurrent() { 98 KJ_IF_MAYBE(c, current) { 99 return **c; 100 } else { 101 return *current.emplace(ClientHook::from(connect())); 102 } 103 } 104 105 class RequestImpl final: public RequestHook { 106 public: 107 RequestImpl(kj::Own<ReconnectHook> parent, kj::Own<RequestHook> inner) 108 : parent(kj::mv(parent)), inner(kj::mv(inner)) {} 109 110 RemotePromise<AnyPointer> send() override { 111 auto result = inner->send(); 112 parent->wrap(result); 113 return result; 114 } 115 116 kj::Promise<void> sendStreaming() override { 117 auto result = inner->sendStreaming(); 118 parent->wrap(result); 119 return result; 120 } 121 122 const void* getBrand() override { 123 return nullptr; 124 } 125 126 private: 127 kj::Own<ReconnectHook> parent; 128 kj::Own<RequestHook> inner; 129 }; 130 }; 131 132 } // namespace 133 134 Capability::Client autoReconnect(kj::Function<Capability::Client()> connect) { 135 return Capability::Client(kj::refcounted<ReconnectHook>(kj::mv(connect))); 136 } 137 138 Capability::Client lazyAutoReconnect(kj::Function<Capability::Client()> connect) { 139 return Capability::Client(kj::refcounted<ReconnectHook>(kj::mv(connect), true)); 140 } 141 } // namespace capnp