capnproto

FORK: Cap'n Proto serialization/RPC system - core tools and C++ library
git clone https://git.neptards.moe/neptards/capnproto.git
Log | Files | Refs | README | LICENSE

tls.c++ (33847B)


      1 // Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
      2 // Licensed under the MIT License:
      3 //
      4 // Permission is hereby granted, free of charge, to any person obtaining a copy
      5 // of this software and associated documentation files (the "Software"), to deal
      6 // in the Software without restriction, including without limitation the rights
      7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      8 // copies of the Software, and to permit persons to whom the Software is
      9 // furnished to do so, subject to the following conditions:
     10 //
     11 // The above copyright notice and this permission notice shall be included in
     12 // all copies or substantial portions of the Software.
     13 //
     14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
     17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
     20 // THE SOFTWARE.
     21 
     22 #if KJ_HAS_OPENSSL
     23 
#include "tls.h"

#include "readiness-io.h"

#include <climits>

#include <openssl/bio.h>
#include <openssl/conf.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <openssl/tls1.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>

#include <kj/async-queue.h>
#include <kj/debug.h>
#include <kj/vector.h>
     40 
     41 #if OPENSSL_VERSION_NUMBER < 0x10100000L
     42 #define BIO_set_init(x,v)          (x->init=v)
     43 #define BIO_get_data(x)            (x->ptr)
     44 #define BIO_set_data(x,v)          (x->ptr=v)
     45 #endif
     46 
     47 namespace kj {
     48 
     49 // =======================================================================================
     50 // misc helpers
     51 
     52 namespace {
     53 
     54 KJ_NORETURN(void throwOpensslError());
     55 void throwOpensslError() {
     56   // Call when an OpenSSL function returns an error code to convert that into an exception and
     57   // throw it.
     58 
     59   kj::Vector<kj::String> lines;
     60   while (unsigned long long error = ERR_get_error()) {
     61     char message[1024];
     62     ERR_error_string_n(error, message, sizeof(message));
     63     lines.add(kj::heapString(message));
     64   }
     65   kj::String message = kj::strArray(lines, "\n");
     66   KJ_FAIL_ASSERT("OpenSSL error", message);
     67 }
     68 
     69 #if OPENSSL_VERSION_NUMBER < 0x10100000L && !defined(OPENSSL_IS_BORINGSSL)
     70 // Older versions of OpenSSL don't define _up_ref() functions.
     71 
     72 void EVP_PKEY_up_ref(EVP_PKEY* pkey) {
     73   CRYPTO_add(&pkey->references, 1, CRYPTO_LOCK_EVP_PKEY);
     74 }
     75 
     76 void X509_up_ref(X509* x509) {
     77   CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509);
     78 }
     79 
     80 #endif
     81 
     82 #if OPENSSL_VERSION_NUMBER < 0x10100000L
     83 class OpenSslInit {
     84   // Initializes the OpenSSL library.
     85 public:
     86   OpenSslInit() {
     87     SSL_library_init();
     88     SSL_load_error_strings();
     89     OPENSSL_config(nullptr);
     90   }
     91 };
     92 
     93 void ensureOpenSslInitialized() {
     94   // Initializes the OpenSSL library the first time it is called.
     95   static OpenSslInit init;
     96 }
     97 #else
     98 inline void ensureOpenSslInitialized() {
     99   // As of 1.1.0, no initialization is needed.
    100 }
    101 #endif
    102 
    103 }  // namespace
    104 
    105 // =======================================================================================
    106 // Implementation of kj::AsyncIoStream that applies TLS on top of some other AsyncIoStream.
    107 //
    108 // TODO(perf): OpenSSL's I/O abstraction layer, "BIO", is readiness-based, but AsyncIoStream is
    109 //   completion-based. This forces us to use an intermediate buffer which wastes memory and incurs
    110 //   redundant copies. We could improve the situation by creating a way to detect if the underlying
    111 //   AsyncIoStream is simply wrapping a file descriptor (or other readiness-based stream?) and use
    112 //   that directly if so.
    113 
    114 class TlsConnection final: public kj::AsyncIoStream {
    115 public:
    116   TlsConnection(kj::Own<kj::AsyncIoStream> stream, SSL_CTX* ctx)
    117       : TlsConnection(*stream, ctx) {
    118     ownInner = kj::mv(stream);
    119   }
    120 
    121   TlsConnection(kj::AsyncIoStream& stream, SSL_CTX* ctx)
    122       : inner(stream), readBuffer(stream), writeBuffer(stream) {
    123     ssl = SSL_new(ctx);
    124     if (ssl == nullptr) {
    125       throwOpensslError();
    126     }
    127 
    128     BIO* bio = BIO_new(const_cast<BIO_METHOD*>(getBioVtable()));
    129     if (bio == nullptr) {
    130       SSL_free(ssl);
    131       throwOpensslError();
    132     }
    133 
    134     BIO_set_data(bio, this);
    135     BIO_set_init(bio, 1);
    136     SSL_set_bio(ssl, bio, bio);
    137   }
    138 
    139   kj::Promise<void> connect(kj::StringPtr expectedServerHostname) {
    140     if (!SSL_set_tlsext_host_name(ssl, expectedServerHostname.cStr())) {
    141       throwOpensslError();
    142     }
    143 
    144     X509_VERIFY_PARAM* verify = SSL_get0_param(ssl);
    145     if (verify == nullptr) {
    146       throwOpensslError();
    147     }
    148 
    149     if (X509_VERIFY_PARAM_set1_host(
    150         verify, expectedServerHostname.cStr(), expectedServerHostname.size()) <= 0) {
    151       throwOpensslError();
    152     }
    153 
    154     return sslCall([this]() { return SSL_connect(ssl); }).then([this](size_t) {
    155       X509* cert = SSL_get_peer_certificate(ssl);
    156       KJ_REQUIRE(cert != nullptr, "TLS peer provided no certificate");
    157       X509_free(cert);
    158 
    159       auto result = SSL_get_verify_result(ssl);
    160       if (result != X509_V_OK) {
    161         const char* reason = X509_verify_cert_error_string(result);
    162         KJ_FAIL_REQUIRE("TLS peer's certificate is not trusted", reason);
    163       }
    164     });
    165   }
    166 
    167   kj::Promise<void> accept() {
    168     // We are the server. Set SSL options to prefer server's cipher choice.
    169     SSL_set_options(ssl, SSL_OP_CIPHER_SERVER_PREFERENCE);
    170 
    171     auto acceptPromise = sslCall([this]() {
    172       return SSL_accept(ssl);
    173     });
    174     return acceptPromise.then([](size_t ret) {
    175       if (ret == 0) {
    176         kj::throwRecoverableException(
    177             KJ_EXCEPTION(DISCONNECTED, "Client disconnected during SSL_accept()"));
    178       }
    179     });
    180   }
    181 
    182   kj::Own<TlsPeerIdentity> getIdentity(kj::Own<kj::PeerIdentity> inner) {
    183     return kj::heap<TlsPeerIdentity>(SSL_get_peer_certificate(ssl), kj::mv(inner),
    184                                      kj::Badge<TlsConnection>());
    185   }
    186 
    187   ~TlsConnection() noexcept(false) {
    188     SSL_free(ssl);
    189   }
    190 
    191   kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    192     return tryReadInternal(buffer, minBytes, maxBytes, 0);
    193   }
    194 
    195   Promise<void> write(const void* buffer, size_t size) override {
    196     return writeInternal(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size), nullptr);
    197   }
    198 
    199   Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    200     auto cork = writeBuffer.cork();
    201     return writeInternal(pieces[0], pieces.slice(1, pieces.size())).attach(kj::mv(cork));
    202   }
    203 
    204   Promise<void> whenWriteDisconnected() override {
    205     return inner.whenWriteDisconnected();
    206   }
    207 
    208   void shutdownWrite() override {
    209     KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");
    210 
    211     // TODO(0.10): shutdownWrite() is problematic because it doesn't return a promise. It was
    212     //   designed to assume that it would only be called after all writes are finished and that
    213     //   there was no reason to block at that point, but SSL sessions don't fit this since they
    214     //   actually have to send a shutdown message.
    215     shutdownTask = sslCall([this]() {
    216       // The first SSL_shutdown() call is expected to return 0 and may flag a misleading error.
    217       int result = SSL_shutdown(ssl);
    218       return result == 0 ? 1 : result;
    219     }).ignoreResult().eagerlyEvaluate([](kj::Exception&& e) {
    220       KJ_LOG(ERROR, e);
    221     });
    222   }
    223 
    224   void abortRead() override {
    225     inner.abortRead();
    226   }
    227 
    228   void getsockopt(int level, int option, void* value, uint* length) override {
    229     inner.getsockopt(level, option, value, length);
    230   }
    231   void setsockopt(int level, int option, const void* value, uint length) override {
    232     inner.setsockopt(level, option, value, length);
    233   }
    234 
    235   void getsockname(struct sockaddr* addr, uint* length) override {
    236     inner.getsockname(addr, length);
    237   }
    238   void getpeername(struct sockaddr* addr, uint* length) override {
    239     inner.getpeername(addr, length);
    240   }
    241 
    242   kj::Maybe<int> getFd() const override {
    243     return inner.getFd();
    244   }
    245 
    246 private:
    247   SSL* ssl;
    248   kj::AsyncIoStream& inner;
    249   kj::Own<kj::AsyncIoStream> ownInner;
    250 
    251   bool disconnected = false;
    252   kj::Maybe<kj::Promise<void>> shutdownTask;
    253 
    254   ReadyInputStreamWrapper readBuffer;
    255   ReadyOutputStreamWrapper writeBuffer;
    256 
    257   kj::Promise<size_t> tryReadInternal(
    258       void* buffer, size_t minBytes, size_t maxBytes, size_t alreadyDone) {
    259     if (disconnected) return alreadyDone;
    260 
    261     return sslCall([this,buffer,maxBytes]() { return SSL_read(ssl, buffer, maxBytes); })
    262         .then([this,buffer,minBytes,maxBytes,alreadyDone](size_t n) -> kj::Promise<size_t> {
    263       if (n >= minBytes || n == 0) {
    264         return alreadyDone + n;
    265       } else {
    266         return tryReadInternal(reinterpret_cast<byte*>(buffer) + n,
    267                                minBytes - n, maxBytes - n, alreadyDone + n);
    268       }
    269     });
    270   }
    271 
    272   Promise<void> writeInternal(kj::ArrayPtr<const byte> first,
    273                               kj::ArrayPtr<const kj::ArrayPtr<const byte>> rest) {
    274     KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");
    275 
    276     // SSL_write() with a zero-sized input returns 0, but a 0 return is documented as indicating
    277     // an error. So, we need to avoid zero-sized writes entirely.
    278     while (first.size() == 0) {
    279       if (rest.size() == 0) {
    280         return kj::READY_NOW;
    281       }
    282       first = rest.front();
    283       rest = rest.slice(1, rest.size());
    284     }
    285 
    286     return sslCall([this,first]() { return SSL_write(ssl, first.begin(), first.size()); })
    287         .then([this,first,rest](size_t n) -> kj::Promise<void> {
    288       if (n == 0) {
    289         return KJ_EXCEPTION(DISCONNECTED, "ssl connection ended during write");
    290       } else if (n < first.size()) {
    291         return writeInternal(first.slice(n, first.size()), rest);
    292       } else if (rest.size() > 0) {
    293         return writeInternal(rest[0], rest.slice(1, rest.size()));
    294       } else {
    295         return kj::READY_NOW;
    296       }
    297     });
    298   }
    299 
    300   template <typename Func>
    301   kj::Promise<size_t> sslCall(Func&& func) {
    302     if (disconnected) return size_t(0);
    303 
    304     auto result = func();
    305 
    306     if (result > 0) {
    307       return result;
    308     } else {
    309       int error = SSL_get_error(ssl, result);
    310       switch (error) {
    311         case SSL_ERROR_ZERO_RETURN:
    312           disconnected = true;
    313           return size_t(0);
    314         case SSL_ERROR_WANT_READ:
    315           return readBuffer.whenReady().then(kj::mvCapture(func,
    316               [this](Func&& func) mutable { return sslCall(kj::fwd<Func>(func)); }));
    317         case SSL_ERROR_WANT_WRITE:
    318           return writeBuffer.whenReady().then(kj::mvCapture(func,
    319               [this](Func&& func) mutable { return sslCall(kj::fwd<Func>(func)); }));
    320         case SSL_ERROR_SSL:
    321           throwOpensslError();
    322         case SSL_ERROR_SYSCALL:
    323           if (result == 0) {
    324             disconnected = true;
    325             return size_t(0);
    326           } else {
    327             // According to documentation we shouldn't get here, because our BIO never returns an
    328             // "error". But in practice we do get here sometimes when the peer disconnects
    329             // prematurely.
    330             return KJ_EXCEPTION(DISCONNECTED, "SSL unable to continue I/O");
    331           }
    332         default:
    333           KJ_FAIL_ASSERT("unexpected SSL error code", error);
    334       }
    335     }
    336   }
    337 
    338   static int bioRead(BIO* b, char* out, int outl) {
    339     BIO_clear_retry_flags(b);
    340     KJ_IF_MAYBE(n, reinterpret_cast<TlsConnection*>(BIO_get_data(b))->readBuffer
    341         .read(kj::arrayPtr(out, outl).asBytes())) {
    342       return *n;
    343     } else {
    344       BIO_set_retry_read(b);
    345       return -1;
    346     }
    347   }
    348 
    349   static int bioWrite(BIO* b, const char* in, int inl) {
    350     BIO_clear_retry_flags(b);
    351     KJ_IF_MAYBE(n, reinterpret_cast<TlsConnection*>(BIO_get_data(b))->writeBuffer
    352         .write(kj::arrayPtr(in, inl).asBytes())) {
    353       return *n;
    354     } else {
    355       BIO_set_retry_write(b);
    356       return -1;
    357     }
    358   }
    359 
    360   static long bioCtrl(BIO* b, int cmd, long num, void* ptr) {
    361     switch (cmd) {
    362       case BIO_CTRL_FLUSH:
    363         return 1;
    364       case BIO_CTRL_PUSH:
    365       case BIO_CTRL_POP:
    366         // Informational?
    367         return 0;
    368       default:
    369         KJ_LOG(WARNING, "unimplemented bio_ctrl", cmd);
    370         return 0;
    371     }
    372   }
    373 
    374   static int bioCreate(BIO* b) {
    375     BIO_set_data(b, nullptr);
    376     return 1;
    377   }
    378 
    379   static int bioDestroy(BIO* b) {
    380     // The BIO does NOT own the TlsConnection.
    381     return 1;
    382   }
    383 
    384 #if OPENSSL_VERSION_NUMBER < 0x10100000L
    385   static const BIO_METHOD* getBioVtable() {
    386     static const BIO_METHOD VTABLE {
    387       BIO_TYPE_SOURCE_SINK,
    388       "KJ stream",
    389       TlsConnection::bioWrite,
    390       TlsConnection::bioRead,
    391       nullptr,  // puts
    392       nullptr,  // gets
    393       TlsConnection::bioCtrl,
    394       TlsConnection::bioCreate,
    395       TlsConnection::bioDestroy,
    396       nullptr
    397     };
    398     return &VTABLE;
    399   }
    400 #else
    401   static const BIO_METHOD* getBioVtable() {
    402     static const BIO_METHOD* const vtable = makeBioVtable();
    403     return vtable;
    404   }
    405   static const BIO_METHOD* makeBioVtable() {
    406     BIO_METHOD* vtable = BIO_meth_new(BIO_TYPE_SOURCE_SINK, "KJ stream");
    407     BIO_meth_set_write(vtable, TlsConnection::bioWrite);
    408     BIO_meth_set_read(vtable, TlsConnection::bioRead);
    409     BIO_meth_set_ctrl(vtable, TlsConnection::bioCtrl);
    410     BIO_meth_set_create(vtable, TlsConnection::bioCreate);
    411     BIO_meth_set_destroy(vtable, TlsConnection::bioDestroy);
    412     return vtable;
    413   }
    414 #endif
    415 };
    416 
    417 // =======================================================================================
    418 // Implementations of ConnectionReceiver, NetworkAddress, and Network as wrappers adding TLS.
    419 
    420 class TlsConnectionReceiver final: public ConnectionReceiver, public TaskSet::ErrorHandler {
    421 public:
    422   TlsConnectionReceiver(TlsContext &tls, Own<ConnectionReceiver> inner)
    423       : tls(tls), inner(kj::mv(inner)),
    424         acceptLoopTask(acceptLoop().eagerlyEvaluate([this](Exception &&e) {
    425           onAcceptFailure(kj::mv(e));
    426         })),
    427         tasks(*this) {}
    428 
    429   void taskFailed(Exception&& e) override {
    430     // TODO(someday): SSL connection failures may be a fact of normal operation but they may also
    431     // be important diagnostic information. We should allow for an error handler to be passed in so
    432     // that network issues that affect TLS can be more discoverable from the server side.
    433     if (e.getType() != Exception::Type::DISCONNECTED) {
    434       KJ_LOG(ERROR, "error accepting tls connection", kj::mv(e));
    435     }
    436   };
    437 
    438   Promise<Own<AsyncIoStream>> accept() override {
    439     return acceptAuthenticated().then([](AuthenticatedStream&& stream) {
    440       return kj::mv(stream.stream);
    441     });
    442   }
    443 
    444   Promise<AuthenticatedStream> acceptAuthenticated() override {
    445     KJ_IF_MAYBE(e, maybeInnerException) {
    446       // We've experienced an exception from the inner receiver, we consider this unrecoverable.
    447       return Exception(*e);
    448     }
    449 
    450     return queue.pop();
    451   }
    452 
    453   uint getPort() override {
    454     return inner->getPort();
    455   }
    456 
    457   void getsockopt(int level, int option, void* value, uint* length) override {
    458     return inner->getsockopt(level, option, value, length);
    459   }
    460 
    461   void setsockopt(int level, int option, const void* value, uint length) override {
    462     return inner->setsockopt(level, option, value, length);
    463   }
    464 
    465 private:
    466   void onAcceptSuccess(AuthenticatedStream&& stream) {
    467     // Queue this stream to go through SSL_accept.
    468 
    469     auto acceptPromise = kj::evalNow([&] {
    470       // Do the SSL acceptance procedure.
    471       return tls.wrapServer(kj::mv(stream));
    472     });
    473 
    474     auto sslPromise = acceptPromise.then([this](auto&& stream) -> Promise<void> {
    475       // This is only attached to the success path, thus the error handler will catch if our
    476       // promise fails.
    477       queue.push(kj::mv(stream));
    478       return kj::READY_NOW;
    479     });
    480     tasks.add(kj::mv(sslPromise));
    481   }
    482 
    483   void onAcceptFailure(Exception&& e) {
    484     // Store this exception to reject all future calls to accept() and reject any unfulfilled
    485     // promises from the queue.
    486     maybeInnerException = kj::mv(e);
    487     queue.rejectAll(Exception(KJ_REQUIRE_NONNULL(maybeInnerException)));
    488   }
    489 
    490   Promise<void> acceptLoop() {
    491     // Accept one connection and queue up the next accept on our TaskSet.
    492 
    493     return inner->acceptAuthenticated().then(
    494         [this](AuthenticatedStream&& stream) {
    495       onAcceptSuccess(kj::mv(stream));
    496 
    497       // Queue up the next accept loop immediately without waiting for SSL_accept()/wrapServer().
    498       return acceptLoop();
    499     });
    500   }
    501 
    502   TlsContext& tls;
    503   Own<ConnectionReceiver> inner;
    504 
    505   Promise<void> acceptLoopTask;
    506   ProducerConsumerQueue<AuthenticatedStream> queue;
    507   TaskSet tasks;
    508 
    509   Maybe<Exception> maybeInnerException;
    510 };
    511 
    512 class TlsNetworkAddress final: public kj::NetworkAddress {
    513 public:
    514   TlsNetworkAddress(TlsContext& tls, kj::String hostname, kj::Own<kj::NetworkAddress>&& inner)
    515       : tls(tls), hostname(kj::mv(hostname)), inner(kj::mv(inner)) {}
    516 
    517   Promise<Own<AsyncIoStream>> connect() override {
    518     // Note: It's unfortunately pretty common for people to assume they can drop the NetworkAddress
    519     //   as soon as connect() returns, and this works with the native network implementation.
    520     //   So, we make some copies here.
    521     auto& tlsRef = tls;
    522     auto hostnameCopy = kj::str(hostname);
    523     return inner->connect().then(kj::mvCapture(hostnameCopy,
    524         [&tlsRef](kj::String&& hostname, Own<AsyncIoStream>&& stream) {
    525       return tlsRef.wrapClient(kj::mv(stream), hostname);
    526     }));
    527   }
    528 
    529   Promise<kj::AuthenticatedStream> connectAuthenticated() override {
    530     // Note: It's unfortunately pretty common for people to assume they can drop the NetworkAddress
    531     //   as soon as connect() returns, and this works with the native network implementation.
    532     //   So, we make some copies here.
    533     auto& tlsRef = tls;
    534     auto hostnameCopy = kj::str(hostname);
    535     return inner->connectAuthenticated().then(
    536         [&tlsRef, hostname = kj::mv(hostnameCopy)](kj::AuthenticatedStream stream) {
    537       return tlsRef.wrapClient(kj::mv(stream), hostname);
    538     });
    539   }
    540 
    541   Own<ConnectionReceiver> listen() override {
    542     return tls.wrapPort(inner->listen());
    543   }
    544 
    545   Own<NetworkAddress> clone() override {
    546     return kj::heap<TlsNetworkAddress>(tls, kj::str(hostname), inner->clone());
    547   }
    548 
    549   String toString() override {
    550     return kj::str("tls:", inner->toString());
    551   }
    552 
    553 private:
    554   TlsContext& tls;
    555   kj::String hostname;
    556   kj::Own<kj::NetworkAddress> inner;
    557 };
    558 
    559 class TlsNetwork final: public kj::Network {
    560 public:
    561   TlsNetwork(TlsContext& tls, kj::Network& inner): tls(tls), inner(inner) {}
    562   TlsNetwork(TlsContext& tls, kj::Own<kj::Network> inner)
    563       : tls(tls), inner(*inner), ownInner(kj::mv(inner)) {}
    564 
    565   Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint) override {
    566     kj::String hostname;
    567     KJ_IF_MAYBE(pos, addr.findFirst(':')) {
    568       hostname = kj::heapString(addr.slice(0, *pos));
    569     } else {
    570       hostname = kj::heapString(addr);
    571     }
    572 
    573     return inner.parseAddress(addr, portHint)
    574         .then(kj::mvCapture(hostname, [this](kj::String&& hostname, kj::Own<NetworkAddress>&& addr)
    575             -> kj::Own<kj::NetworkAddress> {
    576       return kj::heap<TlsNetworkAddress>(tls, kj::mv(hostname), kj::mv(addr));
    577     }));
    578   }
    579 
    580   Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
    581     KJ_UNIMPLEMENTED("TLS does not implement getSockaddr() because it needs to know hostnames");
    582   }
    583 
    584   Own<Network> restrictPeers(
    585       kj::ArrayPtr<const kj::StringPtr> allow,
    586       kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
    587     // TODO(someday): Maybe we could implement the ability to specify CA or hostname restrictions?
    588     //   Or is it better to let people do that via the TlsContext? A neat thing about
    589     //   restrictPeers() is that it's easy to make user-configurable.
    590     return kj::heap<TlsNetwork>(tls, inner.restrictPeers(allow, deny));
    591   }
    592 
    593 private:
    594   TlsContext& tls;
    595   kj::Network& inner;
    596   kj::Own<kj::Network> ownInner;
    597 };
    598 
    599 // =======================================================================================
    600 // class TlsContext
    601 
    602 TlsContext::Options::Options()
    603     : useSystemTrustStore(true),
    604       verifyClients(false),
    605       minVersion(TlsVersion::TLS_1_2),
    606       cipherList("ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305") {}
    607 // Cipher list is Mozilla's "intermediate" list, except with classic DH removed since we don't
    608 // currently support setting dhparams. See:
    609 //     https://mozilla.github.io/server-side-tls/ssl-config-generator/
    610 //
    611 // Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll
    612 // never bother.
    613 
    614 struct TlsContext::SniCallback {
    615   // struct SniCallback exists only so that callback() can be declared in the .c++ file, since it
    616   // references OpenSSL types.
    617 
    618   static int callback(SSL* ssl, int* ad, void* arg);
    619 };
    620 
    621 TlsContext::TlsContext(Options options) {
    622   ensureOpenSslInitialized();
    623 
    624 #if OPENSSL_VERSION_NUMBER >= 0x10100000L || defined(OPENSSL_IS_BORINGSSL)
    625   SSL_CTX* ctx = SSL_CTX_new(TLS_method());
    626 #else
    627   SSL_CTX* ctx = SSL_CTX_new(SSLv23_method());
    628 #endif
    629 
    630   if (ctx == nullptr) {
    631     throwOpensslError();
    632   }
    633   KJ_ON_SCOPE_FAILURE(SSL_CTX_free(ctx));
    634 
    635   // honor options.useSystemTrustStore
    636   if (options.useSystemTrustStore) {
    637     if (!SSL_CTX_set_default_verify_paths(ctx)) {
    638       throwOpensslError();
    639     }
    640   }
    641 
    642   // honor options.trustedCertificates
    643   if (options.trustedCertificates.size() > 0) {
    644     X509_STORE* store = SSL_CTX_get_cert_store(ctx);
    645     if (store == nullptr) {
    646       throwOpensslError();
    647     }
    648     for (auto& cert: options.trustedCertificates) {
    649       if (!X509_STORE_add_cert(store, reinterpret_cast<X509*>(cert.chain[0]))) {
    650         throwOpensslError();
    651       }
    652     }
    653   }
    654 
    655   if (options.verifyClients) {
    656     SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL);
    657   }
    658 
    659   // honor options.minVersion
    660   long optionFlags = 0;
    661   if (options.minVersion > TlsVersion::SSL_3) {
    662     optionFlags |= SSL_OP_NO_SSLv3;
    663   }
    664   if (options.minVersion > TlsVersion::TLS_1_0) {
    665     optionFlags |= SSL_OP_NO_TLSv1;
    666   }
    667   if (options.minVersion > TlsVersion::TLS_1_1) {
    668     optionFlags |= SSL_OP_NO_TLSv1_1;
    669   }
    670   if (options.minVersion > TlsVersion::TLS_1_2) {
    671     optionFlags |= SSL_OP_NO_TLSv1_2;
    672   }
    673   SSL_CTX_set_options(ctx, optionFlags);  // note: never fails; returns new options bitmask
    674 
    675   // honor options.cipherList
    676   if (!SSL_CTX_set_cipher_list(ctx, options.cipherList.cStr())) {
    677     throwOpensslError();
    678   }
    679 
    680   // honor options.defaultKeypair
    681   KJ_IF_MAYBE(kp, options.defaultKeypair) {
    682     if (!SSL_CTX_use_PrivateKey(ctx, reinterpret_cast<EVP_PKEY*>(kp->privateKey.pkey))) {
    683       throwOpensslError();
    684     }
    685 
    686     if (!SSL_CTX_use_certificate(ctx, reinterpret_cast<X509*>(kp->certificate.chain[0]))) {
    687       throwOpensslError();
    688     }
    689 
    690     for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) {
    691       X509* x509 = reinterpret_cast<X509*>(kp->certificate.chain[i]);
    692       if (x509 == nullptr) break;  // end of chain
    693 
    694       if (!SSL_CTX_add_extra_chain_cert(ctx, x509)) {
    695         throwOpensslError();
    696       }
    697 
    698       // SSL_CTX_add_extra_chain_cert() does NOT up the refcount itself.
    699       X509_up_ref(x509);
    700     }
    701   }
    702 
    703   // honor options.sniCallback
    704   KJ_IF_MAYBE(sni, options.sniCallback) {
    705     SSL_CTX_set_tlsext_servername_callback(ctx, &SniCallback::callback);
    706     SSL_CTX_set_tlsext_servername_arg(ctx, sni);
    707   }
    708 
    709   KJ_IF_MAYBE(timeout, options.acceptTimeout) {
    710     this->timer = KJ_REQUIRE_NONNULL(options.timer,
    711         "acceptTimeout option requires that a timer is also provided");
    712     this->acceptTimeout = *timeout;
    713   }
    714 
    715   this->ctx = ctx;
    716 }
    717 
    718 int TlsContext::SniCallback::callback(SSL* ssl, int* ad, void* arg) {
    719   // The third parameter is actually type TlsSniCallback*.
    720 
    721   KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
    722     TlsSniCallback& sni = *reinterpret_cast<TlsSniCallback*>(arg);
    723 
    724     const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
    725     if (name != nullptr) {
    726       KJ_IF_MAYBE(kp, sni.getKey(name)) {
    727         if (!SSL_use_PrivateKey(ssl, reinterpret_cast<EVP_PKEY*>(kp->privateKey.pkey))) {
    728           throwOpensslError();
    729         }
    730 
    731         if (!SSL_use_certificate(ssl, reinterpret_cast<X509*>(kp->certificate.chain[0]))) {
    732           throwOpensslError();
    733         }
    734 
    735         if (!SSL_clear_chain_certs(ssl)) {
    736           throwOpensslError();
    737         }
    738 
    739         for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) {
    740           X509* x509 = reinterpret_cast<X509*>(kp->certificate.chain[i]);
    741           if (x509 == nullptr) break;  // end of chain
    742 
    743           if (!SSL_add0_chain_cert(ssl, x509)) {
    744             throwOpensslError();
    745           }
    746 
    747           // SSL_add0_chain_cert() does NOT up the refcount itself.
    748           X509_up_ref(x509);
    749         }
    750       }
    751     }
    752   })) {
    753     KJ_LOG(ERROR, "exception when invoking SNI callback", *exception);
    754     *ad = SSL_AD_INTERNAL_ERROR;
    755     return SSL_TLSEXT_ERR_ALERT_FATAL;
    756   }
    757 
    758   return SSL_TLSEXT_ERR_OK;
    759 }
    760 
    761 TlsContext::~TlsContext() noexcept(false) {
    762   SSL_CTX_free(reinterpret_cast<SSL_CTX*>(ctx));
    763 }
    764 
    765 kj::Promise<kj::Own<kj::AsyncIoStream>> TlsContext::wrapClient(
    766     kj::Own<kj::AsyncIoStream> stream, kj::StringPtr expectedServerHostname) {
    767   auto conn = kj::heap<TlsConnection>(kj::mv(stream), reinterpret_cast<SSL_CTX*>(ctx));
    768   auto promise = conn->connect(expectedServerHostname);
    769   return promise.then(kj::mvCapture(conn, [](kj::Own<TlsConnection> conn)
    770       -> kj::Own<kj::AsyncIoStream> {
    771     return kj::mv(conn);
    772   }));
    773 }
    774 
    775 kj::Promise<kj::Own<kj::AsyncIoStream>> TlsContext::wrapServer(kj::Own<kj::AsyncIoStream> stream) {
    776   auto conn = kj::heap<TlsConnection>(kj::mv(stream), reinterpret_cast<SSL_CTX*>(ctx));
    777   auto promise = conn->accept();
    778   KJ_IF_MAYBE(timeout, acceptTimeout) {
    779     promise = KJ_REQUIRE_NONNULL(timer).timeoutAfter(*timeout, kj::mv(promise));
    780   }
    781   return promise.then(kj::mvCapture(conn, [](kj::Own<TlsConnection> conn)
    782       -> kj::Own<kj::AsyncIoStream> {
    783     return kj::mv(conn);
    784   }));
    785 }
    786 
    787 kj::Promise<kj::AuthenticatedStream> TlsContext::wrapClient(
    788     kj::AuthenticatedStream stream, kj::StringPtr expectedServerHostname) {
    789   auto conn = kj::heap<TlsConnection>(kj::mv(stream.stream), reinterpret_cast<SSL_CTX*>(ctx));
    790   auto promise = conn->connect(expectedServerHostname);
    791   return promise.then([conn=kj::mv(conn),innerId=kj::mv(stream.peerIdentity)]() mutable {
    792     auto id = conn->getIdentity(kj::mv(innerId));
    793     return kj::AuthenticatedStream { kj::mv(conn), kj::mv(id) };
    794   });
    795 }
    796 
    797 kj::Promise<kj::AuthenticatedStream> TlsContext::wrapServer(kj::AuthenticatedStream stream) {
    798   auto conn = kj::heap<TlsConnection>(kj::mv(stream.stream), reinterpret_cast<SSL_CTX*>(ctx));
    799   auto promise = conn->accept();
    800   KJ_IF_MAYBE(timeout, acceptTimeout) {
    801     promise = KJ_REQUIRE_NONNULL(timer).timeoutAfter(*timeout, kj::mv(promise));
    802   }
    803   return promise.then([conn=kj::mv(conn),innerId=kj::mv(stream.peerIdentity)]() mutable {
    804     auto id = conn->getIdentity(kj::mv(innerId));
    805     return kj::AuthenticatedStream { kj::mv(conn), kj::mv(id) };
    806   });
    807 }
    808 
    809 kj::Own<kj::ConnectionReceiver> TlsContext::wrapPort(kj::Own<kj::ConnectionReceiver> port) {
    810   return kj::heap<TlsConnectionReceiver>(*this, kj::mv(port));
    811 }
    812 
    813 kj::Own<kj::Network> TlsContext::wrapNetwork(kj::Network& network) {
    814   return kj::heap<TlsNetwork>(*this, network);
    815 }
    816 
    817 // =======================================================================================
    818 // class TlsPrivateKey
    819 
    820 TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr<const byte> asn1) {
    821   ensureOpenSslInitialized();
    822 
    823   const byte* ptr = asn1.begin();
    824   pkey = d2i_AutoPrivateKey(nullptr, &ptr, asn1.size());
    825   if (pkey == nullptr) {
    826     throwOpensslError();
    827   }
    828 }
    829 
    830 TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password) {
    831   ensureOpenSslInitialized();
    832 
    833   // const_cast apparently needed for older versions of OpenSSL.
    834   BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
    835   KJ_DEFER(BIO_free(bio));
    836 
    837   pkey = PEM_read_bio_PrivateKey(bio, nullptr, &passwordCallback, &password);
    838   if (pkey == nullptr) {
    839     throwOpensslError();
    840   }
    841 }
    842 
    843 TlsPrivateKey::TlsPrivateKey(const TlsPrivateKey& other)
    844     : pkey(other.pkey) {
    845   if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast<EVP_PKEY*>(pkey));
    846 }
    847 
    848 TlsPrivateKey& TlsPrivateKey::operator=(const TlsPrivateKey& other) {
    849   if (pkey != other.pkey) {
    850     EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
    851     pkey = other.pkey;
    852     if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast<EVP_PKEY*>(pkey));
    853   }
    854   return *this;
    855 }
    856 
    857 TlsPrivateKey::~TlsPrivateKey() noexcept(false) {
    858   EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
    859 }
    860 
    861 int TlsPrivateKey::passwordCallback(char* buf, int size, int rwflag, void* u) {
    862   auto& password = *reinterpret_cast<kj::Maybe<kj::StringPtr>*>(u);
    863 
    864   KJ_IF_MAYBE(p, password) {
    865     int result = kj::min(p->size(), size);
    866     memcpy(buf, p->begin(), result);
    867     return result;
    868   } else {
    869     return 0;
    870   }
    871 }
    872 
    873 // =======================================================================================
    874 // class TlsCertificate
    875 
    876 TlsCertificate::TlsCertificate(kj::ArrayPtr<const kj::ArrayPtr<const byte>> asn1) {
    877   ensureOpenSslInitialized();
    878 
    879   KJ_REQUIRE(asn1.size() > 0, "must provide at least one certificate in chain");
    880   KJ_REQUIRE(asn1.size() <= kj::size(chain),
    881       "exceeded maximum certificate chain length of 10");
    882 
    883   memset(chain, 0, sizeof(chain));
    884 
    885   for (auto i: kj::indices(asn1)) {
    886     auto p = asn1[i].begin();
    887 
    888     // "_AUX" apparently refers to some auxilliary information that can be appended to the
    889     // certificate, but should only be trusted for your own certificate, not the whole chain??
    890     // I don't really know, I'm just cargo-culting.
    891     chain[i] = i == 0 ? d2i_X509_AUX(nullptr, &p, asn1[i].size())
    892                       : d2i_X509(nullptr, &p, asn1[i].size());
    893 
    894     if (chain[i] == nullptr) {
    895       for (size_t j = 0; j < i; j++) {
    896         X509_free(reinterpret_cast<X509*>(chain[j]));
    897       }
    898       throwOpensslError();
    899     }
    900   }
    901 }
    902 
    903 TlsCertificate::TlsCertificate(kj::ArrayPtr<const byte> asn1)
    904     : TlsCertificate(kj::arrayPtr(&asn1, 1)) {}
    905 
    906 TlsCertificate::TlsCertificate(kj::StringPtr pem) {
    907   ensureOpenSslInitialized();
    908 
    909   memset(chain, 0, sizeof(chain));
    910 
    911   // const_cast apparently needed for older versions of OpenSSL.
    912   BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
    913   KJ_DEFER(BIO_free(bio));
    914 
    915   for (auto i: kj::indices(chain)) {
    916     // "_AUX" apparently refers to some auxilliary information that can be appended to the
    917     // certificate, but should only be trusted for your own certificate, not the whole chain??
    918     // I don't really know, I'm just cargo-culting.
    919     chain[i] = i == 0 ? PEM_read_bio_X509_AUX(bio, nullptr, nullptr, nullptr)
    920                       : PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
    921 
    922     if (chain[i] == nullptr) {
    923       auto error = ERR_peek_last_error();
    924       if (i > 0 && ERR_GET_LIB(error) == ERR_LIB_PEM &&
    925           ERR_GET_REASON(error) == PEM_R_NO_START_LINE) {
    926         // EOF; we're done.
    927         ERR_clear_error();
    928         return;
    929       } else {
    930         for (size_t j = 0; j < i; j++) {
    931           X509_free(reinterpret_cast<X509*>(chain[j]));
    932         }
    933         throwOpensslError();
    934       }
    935     }
    936   }
    937 
    938   // We reached the chain length limit. Try to read one more to verify that the chain ends here.
    939   X509* dummy = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
    940   if (dummy != nullptr) {
    941     X509_free(dummy);
    942     for (auto i: kj::indices(chain)) {
    943       X509_free(reinterpret_cast<X509*>(chain[i]));
    944     }
    945     KJ_FAIL_REQUIRE("exceeded maximum certificate chain length of 10");
    946   }
    947 }
    948 
    949 TlsCertificate::TlsCertificate(const TlsCertificate& other) {
    950   memcpy(chain, other.chain, sizeof(chain));
    951   for (void* p: chain) {
    952     if (p == nullptr) break;  // end of chain; quit early
    953     X509_up_ref(reinterpret_cast<X509*>(p));
    954   }
    955 }
    956 
    957 TlsCertificate& TlsCertificate::operator=(const TlsCertificate& other) {
    958   for (auto i: kj::indices(chain)) {
    959     if (chain[i] != other.chain[i]) {
    960       EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(chain[i]));
    961       chain[i] = other.chain[i];
    962       if (chain[i] != nullptr) X509_up_ref(reinterpret_cast<X509*>(chain[i]));
    963     } else if (chain[i] == nullptr) {
    964       // end of both chains; quit early
    965       break;
    966     }
    967   }
    968   return *this;
    969 }
    970 
    971 TlsCertificate::~TlsCertificate() noexcept(false) {
    972   for (void* p: chain) {
    973     if (p == nullptr) break;  // end of chain; quit early
    974     X509_free(reinterpret_cast<X509*>(p));
    975   }
    976 }
    977 
    978 // =======================================================================================
    979 // class TlsPeerIdentity
    980 
    981 TlsPeerIdentity::~TlsPeerIdentity() noexcept(false) {
    982   if (cert != nullptr) {
    983     X509_free(reinterpret_cast<X509*>(cert));
    984   }
    985 }
    986 
    987 kj::String TlsPeerIdentity::toString() {
    988   if (hasCertificate()) {
    989     return getCommonName();
    990   } else {
    991     return kj::str("(anonymous client)");
    992   }
    993 }
    994 
    995 kj::String TlsPeerIdentity::getCommonName() {
    996   if (cert == nullptr) {
    997     KJ_FAIL_REQUIRE("client did not provide a certificate") { return nullptr; }
    998   }
    999 
   1000   X509_NAME* subj = X509_get_subject_name(reinterpret_cast<X509*>(cert));
   1001 
   1002   int index = X509_NAME_get_index_by_NID(subj, NID_commonName, -1);
   1003   KJ_ASSERT(index != -1, "certificate has no common name?");
   1004   X509_NAME_ENTRY* entry = X509_NAME_get_entry(subj, index);
   1005   KJ_ASSERT(entry != nullptr);
   1006   ASN1_STRING* data = X509_NAME_ENTRY_get_data(entry);
   1007   KJ_ASSERT(data != nullptr);
   1008 
   1009   unsigned char* out = nullptr;
   1010   int len = ASN1_STRING_to_UTF8(&out, data);
   1011   KJ_ASSERT(len >= 0);
   1012   KJ_DEFER(OPENSSL_free(out));
   1013 
   1014   return kj::heapString(reinterpret_cast<char*>(out), len);
   1015 }
   1016 
   1017 }  // namespace kj
   1018 
   1019 #endif  // KJ_HAS_OPENSSL