tls.c++ (33847B)
1 // Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors 2 // Licensed under the MIT License: 3 // 4 // Permission is hereby granted, free of charge, to any person obtaining a copy 5 // of this software and associated documentation files (the "Software"), to deal 6 // in the Software without restriction, including without limitation the rights 7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 // copies of the Software, and to permit persons to whom the Software is 9 // furnished to do so, subject to the following conditions: 10 // 11 // The above copyright notice and this permission notice shall be included in 12 // all copies or substantial portions of the Software. 13 // 14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 // THE SOFTWARE. 

#if KJ_HAS_OPENSSL

#include "tls.h"

#include "readiness-io.h"

#include <openssl/bio.h>
#include <openssl/conf.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <openssl/tls1.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>

#include <kj/async-queue.h>
#include <kj/debug.h>
#include <kj/vector.h>

#include <climits>

#if OPENSSL_VERSION_NUMBER < 0x10100000L
#define BIO_set_init(x,v) (x->init=v)
#define BIO_get_data(x) (x->ptr)
#define BIO_set_data(x,v) (x->ptr=v)
#endif

namespace kj {

// =======================================================================================
// misc helpers

namespace {

KJ_NORETURN(void throwOpensslError());
void throwOpensslError() {
  // Call when an OpenSSL function returns an error code to convert that into an exception and
  // throw it.
  //
  // Drains the entire OpenSSL error queue (one line per queued error) into the exception
  // message, so the queue is left empty afterwards.

  kj::Vector<kj::String> lines;
  // ERR_get_error() returns `unsigned long`; 0 means the queue is empty.
  while (unsigned long error = ERR_get_error()) {
    char message[1024];
    ERR_error_string_n(error, message, sizeof(message));
    lines.add(kj::heapString(message));
  }
  kj::String message = kj::strArray(lines, "\n");
  KJ_FAIL_ASSERT("OpenSSL error", message);
}

#if OPENSSL_VERSION_NUMBER < 0x10100000L && !defined(OPENSSL_IS_BORINGSSL)
// Older versions of OpenSSL don't define _up_ref() functions.

void EVP_PKEY_up_ref(EVP_PKEY* pkey) {
  CRYPTO_add(&pkey->references, 1, CRYPTO_LOCK_EVP_PKEY);
}

void X509_up_ref(X509* x509) {
  CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509);
}

#endif

#if OPENSSL_VERSION_NUMBER < 0x10100000L
class OpenSslInit {
  // Initializes the OpenSSL library.
public:
  OpenSslInit() {
    SSL_library_init();
    SSL_load_error_strings();
    OPENSSL_config(nullptr);
  }
};

void ensureOpenSslInitialized() {
  // Initializes the OpenSSL library the first time it is called.
  static OpenSslInit init;
}
#else
inline void ensureOpenSslInitialized() {
  // As of 1.1.0, no initialization is needed.
}
#endif

}  // namespace

// =======================================================================================
// Implementation of kj::AsyncIoStream that applies TLS on top of some other AsyncIoStream.
//
// TODO(perf): OpenSSL's I/O abstraction layer, "BIO", is readiness-based, but AsyncIoStream is
//   completion-based. This forces us to use an intermediate buffer which wastes memory and incurs
//   redundant copies. We could improve the situation by creating a way to detect if the underlying
//   AsyncIoStream is simply wrapping a file descriptor (or other readiness-based stream?) and use
//   that directly if so.

class TlsConnection final: public kj::AsyncIoStream {
  // Owns one OpenSSL `SSL` object. OpenSSL's raw I/O goes through a custom BIO (see bioRead() /
  // bioWrite() below) whose data pointer is this object; the BIO reads from `readBuffer` and
  // writes to `writeBuffer`, both of which wrap `inner`.
public:
  TlsConnection(kj::Own<kj::AsyncIoStream> stream, SSL_CTX* ctx)
      : TlsConnection(*stream, ctx) {
    // Same as the reference overload, but also takes ownership of the underlying stream.
    ownInner = kj::mv(stream);
  }

  TlsConnection(kj::AsyncIoStream& stream, SSL_CTX* ctx)
      : inner(stream), readBuffer(stream), writeBuffer(stream) {
    ssl = SSL_new(ctx);
    if (ssl == nullptr) {
      throwOpensslError();
    }

    BIO* bio = BIO_new(const_cast<BIO_METHOD*>(getBioVtable()));
    if (bio == nullptr) {
      SSL_free(ssl);
      throwOpensslError();
    }

    BIO_set_data(bio, this);
    BIO_set_init(bio, 1);
    // The same BIO serves both directions; SSL_free() in the destructor releases it.
    SSL_set_bio(ssl, bio, bio);
  }

  kj::Promise<void> connect(kj::StringPtr expectedServerHostname) {
    // Client-side handshake. Sends `expectedServerHostname` as SNI and configures it as the
    // hostname the peer certificate must match. After the handshake, requires the peer to have
    // presented a certificate whose verification result is X509_V_OK.

    if (!SSL_set_tlsext_host_name(ssl, expectedServerHostname.cStr())) {
      throwOpensslError();
    }

    X509_VERIFY_PARAM* verify = SSL_get0_param(ssl);
    if (verify == nullptr) {
      throwOpensslError();
    }

    if (X509_VERIFY_PARAM_set1_host(
        verify, expectedServerHostname.cStr(), expectedServerHostname.size()) <= 0) {
      throwOpensslError();
    }

    return sslCall([this]() { return SSL_connect(ssl); }).then([this](size_t) {
      // SSL_get_peer_certificate() returns a new reference, so release it after checking.
      X509* cert = SSL_get_peer_certificate(ssl);
      KJ_REQUIRE(cert != nullptr, "TLS peer provided no certificate");
      X509_free(cert);

      auto result = SSL_get_verify_result(ssl);
      if (result != X509_V_OK) {
        const char* reason = X509_verify_cert_error_string(result);
        KJ_FAIL_REQUIRE("TLS peer's certificate is not trusted", reason);
      }
    });
  }

  kj::Promise<void> accept() {
    // Server-side handshake. Rejects with DISCONNECTED if the client goes away before the
    // handshake completes (sslCall() resolves to 0 in that case).

    // We are the server. Set SSL options to prefer server's cipher choice.
    SSL_set_options(ssl, SSL_OP_CIPHER_SERVER_PREFERENCE);

    auto acceptPromise = sslCall([this]() {
      return SSL_accept(ssl);
    });
    return acceptPromise.then([](size_t ret) {
      if (ret == 0) {
        kj::throwRecoverableException(
            KJ_EXCEPTION(DISCONNECTED, "Client disconnected during SSL_accept()"));
      }
    });
  }

  kj::Own<TlsPeerIdentity> getIdentity(kj::Own<kj::PeerIdentity> inner) {
    // Wraps the peer's certificate (possibly null if none was presented) together with the
    // identity reported by the underlying transport. SSL_get_peer_certificate() returns a new
    // reference, which TlsPeerIdentity releases in its destructor.
    return kj::heap<TlsPeerIdentity>(SSL_get_peer_certificate(ssl), kj::mv(inner),
                                     kj::Badge<TlsConnection>());
  }

  ~TlsConnection() noexcept(false) {
    SSL_free(ssl);
  }

  kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    return tryReadInternal(buffer, minBytes, maxBytes, 0);
  }

  Promise<void> write(const void* buffer, size_t size) override {
    return writeInternal(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size), nullptr);
  }

  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    if (pieces.size() == 0) {
      // Nothing to write. Still route through writeInternal() so a write after shutdownWrite()
      // is rejected consistently, but never index pieces[0] on an empty array.
      return writeInternal(nullptr, nullptr);
    }

    // Cork the output buffer so the pieces are coalesced rather than flushed one at a time.
    auto cork = writeBuffer.cork();
    return writeInternal(pieces[0], pieces.slice(1, pieces.size())).attach(kj::mv(cork));
  }

  Promise<void> whenWriteDisconnected() override {
    return inner.whenWriteDisconnected();
  }

  void shutdownWrite() override {
    KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");

    // TODO(0.10): shutdownWrite() is problematic because it doesn't return a promise. It was
    //   designed to assume that it would only be called after all writes are finished and that
    //   there was no reason to block at that point, but SSL sessions don't fit this since they
    //   actually have to send a shutdown message.
    shutdownTask = sslCall([this]() {
      // The first SSL_shutdown() call is expected to return 0 and may flag a misleading error.
      int result = SSL_shutdown(ssl);
      return result == 0 ? 1 : result;
    }).ignoreResult().eagerlyEvaluate([](kj::Exception&& e) {
      KJ_LOG(ERROR, e);
    });
  }

  void abortRead() override {
    inner.abortRead();
  }

  void getsockopt(int level, int option, void* value, uint* length) override {
    inner.getsockopt(level, option, value, length);
  }
  void setsockopt(int level, int option, const void* value, uint length) override {
    inner.setsockopt(level, option, value, length);
  }

  void getsockname(struct sockaddr* addr, uint* length) override {
    inner.getsockname(addr, length);
  }
  void getpeername(struct sockaddr* addr, uint* length) override {
    inner.getpeername(addr, length);
  }

  kj::Maybe<int> getFd() const override {
    return inner.getFd();
  }

private:
  SSL* ssl;
  kj::AsyncIoStream& inner;
  kj::Own<kj::AsyncIoStream> ownInner;

  bool disconnected = false;
  // Set once the peer has cleanly closed (or the transport reported EOF); all further
  // sslCall()s then resolve to 0 without touching OpenSSL.

  kj::Maybe<kj::Promise<void>> shutdownTask;

  ReadyInputStreamWrapper readBuffer;
  ReadyOutputStreamWrapper writeBuffer;

  static int clampToInt(size_t n) {
    // SSL_read() / SSL_write() take an `int` length. Clamp so that a huge buffer cannot wrap to
    // a negative length; callers already loop on short reads/writes.
    return static_cast<int>(kj::min(n, static_cast<size_t>(INT_MAX)));
  }

  kj::Promise<size_t> tryReadInternal(
      void* buffer, size_t minBytes, size_t maxBytes, size_t alreadyDone) {
    // Calls SSL_read() repeatedly until at least `minBytes` have been read or the connection
    // reports EOF (a read of 0). Resolves to the total count including `alreadyDone`.

    if (disconnected) return alreadyDone;

    return sslCall([this,buffer,maxBytes]() {
      return SSL_read(ssl, buffer, clampToInt(maxBytes));
    }).then([this,buffer,minBytes,maxBytes,alreadyDone](size_t n) -> kj::Promise<size_t> {
      if (n >= minBytes || n == 0) {
        return alreadyDone + n;
      } else {
        return tryReadInternal(reinterpret_cast<byte*>(buffer) + n,
            minBytes - n, maxBytes - n, alreadyDone + n);
      }
    });
  }

  Promise<void> writeInternal(kj::ArrayPtr<const byte> first,
                              kj::ArrayPtr<const kj::ArrayPtr<const byte>> rest) {
    // Writes `first` and then each element of `rest` in order, continuing after short writes.
    // Rejects with DISCONNECTED if SSL_write() reports the connection has ended.

    KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");

    // SSL_write() with a zero-sized input returns 0, but a 0 return is documented as indicating
    // an error. So, we need to avoid zero-sized writes entirely.
    while (first.size() == 0) {
      if (rest.size() == 0) {
        return kj::READY_NOW;
      }
      first = rest.front();
      rest = rest.slice(1, rest.size());
    }

    // Note: on WANT_READ/WANT_WRITE, sslCall() re-invokes this same lambda, so the retried
    // SSL_write() sees identical arguments as OpenSSL requires.
    return sslCall([this,first]() {
      return SSL_write(ssl, first.begin(), clampToInt(first.size()));
    }).then([this,first,rest](size_t n) -> kj::Promise<void> {
      if (n == 0) {
        return KJ_EXCEPTION(DISCONNECTED, "ssl connection ended during write");
      } else if (n < first.size()) {
        return writeInternal(first.slice(n, first.size()), rest);
      } else if (rest.size() > 0) {
        return writeInternal(rest[0], rest.slice(1, rest.size()));
      } else {
        return kj::READY_NOW;
      }
    });
  }

  template <typename Func>
  kj::Promise<size_t> sslCall(Func&& func) {
    // Runs `func` (an SSL_* call returning int) and maps the outcome onto a promise:
    // - positive result: resolves to that value;
    // - SSL_ERROR_ZERO_RETURN, or SSL_ERROR_SYSCALL with a 0 result: marks the connection
    //   disconnected and resolves to 0;
    // - SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE: waits for the corresponding buffer wrapper
    //   to report readiness, then calls `func` again;
    // - SSL_ERROR_SSL: throws with the OpenSSL error queue contents;
    // - SSL_ERROR_SYSCALL with a nonzero result: rejects with DISCONNECTED.

    if (disconnected) return size_t(0);

    auto result = func();

    if (result > 0) {
      return result;
    } else {
      int error = SSL_get_error(ssl, result);
      switch (error) {
        case SSL_ERROR_ZERO_RETURN:
          disconnected = true;
          return size_t(0);
        case SSL_ERROR_WANT_READ:
          return readBuffer.whenReady().then(kj::mvCapture(func,
              [this](Func&& func) mutable { return sslCall(kj::fwd<Func>(func)); }));
        case SSL_ERROR_WANT_WRITE:
          return writeBuffer.whenReady().then(kj::mvCapture(func,
              [this](Func&& func) mutable { return sslCall(kj::fwd<Func>(func)); }));
        case SSL_ERROR_SSL:
          throwOpensslError();  // noreturn
        case SSL_ERROR_SYSCALL:
          if (result == 0) {
            disconnected = true;
            return size_t(0);
          } else {
            // According to documentation we shouldn't get here, because our BIO never returns an
            // "error". But in practice we do get here sometimes when the peer disconnects
            // prematurely.
            return KJ_EXCEPTION(DISCONNECTED, "SSL unable to continue I/O");
          }
        default:
          KJ_FAIL_ASSERT("unexpected SSL error code", error);
      }
    }
  }

  static int bioRead(BIO* b, char* out, int outl) {
    // BIO read hook: copies whatever readBuffer can supply right now. If read() returns null,
    // flags the BIO as "retry read" and returns -1 so OpenSSL reports SSL_ERROR_WANT_READ.
    BIO_clear_retry_flags(b);
    KJ_IF_MAYBE(n, reinterpret_cast<TlsConnection*>(BIO_get_data(b))->readBuffer
        .read(kj::arrayPtr(out, outl).asBytes())) {
      return *n;
    } else {
      BIO_set_retry_read(b);
      return -1;
    }
  }

  static int bioWrite(BIO* b, const char* in, int inl) {
    // BIO write hook: hands bytes to writeBuffer. If write() returns null, flags the BIO as
    // "retry write" and returns -1 so OpenSSL reports SSL_ERROR_WANT_WRITE.
    BIO_clear_retry_flags(b);
    KJ_IF_MAYBE(n, reinterpret_cast<TlsConnection*>(BIO_get_data(b))->writeBuffer
        .write(kj::arrayPtr(in, inl).asBytes())) {
      return *n;
    } else {
      BIO_set_retry_write(b);
      return -1;
    }
  }

  static long bioCtrl(BIO* b, int cmd, long num, void* ptr) {
    // BIO control hook. Only the handful of commands OpenSSL is known to issue are recognized;
    // anything else logs a warning and reports 0.
    switch (cmd) {
      case BIO_CTRL_FLUSH:
        return 1;
      case BIO_CTRL_PUSH:
      case BIO_CTRL_POP:
        // Informational?
        return 0;
#if defined(BIO_CTRL_GET_KTLS_SEND) && defined(BIO_CTRL_GET_KTLS_RECV)
      case BIO_CTRL_GET_KTLS_SEND:
      case BIO_CTRL_GET_KTLS_RECV:
        // Newer OpenSSL asks whether kernel TLS offload is active. It never is on top of an
        // arbitrary AsyncIoStream; answer "no" without logging a warning each time.
        return 0;
#endif
      default:
        KJ_LOG(WARNING, "unimplemented bio_ctrl", cmd);
        return 0;
    }
  }

  static int bioCreate(BIO* b) {
    BIO_set_data(b, nullptr);
    return 1;
  }

  static int bioDestroy(BIO* b) {
    // The BIO does NOT own the TlsConnection.
    return 1;
  }

#if OPENSSL_VERSION_NUMBER < 0x10100000L
  static const BIO_METHOD* getBioVtable() {
    // Pre-1.1.0 OpenSSL exposes BIO_METHOD as a plain struct, so a static table suffices.
    static const BIO_METHOD VTABLE {
      BIO_TYPE_SOURCE_SINK,
      "KJ stream",
      TlsConnection::bioWrite,
      TlsConnection::bioRead,
      nullptr,  // puts
      nullptr,  // gets
      TlsConnection::bioCtrl,
      TlsConnection::bioCreate,
      TlsConnection::bioDestroy,
      nullptr
    };
    return &VTABLE;
  }
#else
  static const BIO_METHOD* getBioVtable() {
    // BIO_METHOD is opaque from 1.1.0 on; build it once via the BIO_meth_* API.
    static const BIO_METHOD* const vtable = makeBioVtable();
    return vtable;
  }
  static const BIO_METHOD* makeBioVtable() {
    BIO_METHOD* vtable = BIO_meth_new(BIO_TYPE_SOURCE_SINK, "KJ stream");
    BIO_meth_set_write(vtable, TlsConnection::bioWrite);
    BIO_meth_set_read(vtable, TlsConnection::bioRead);
    BIO_meth_set_ctrl(vtable, TlsConnection::bioCtrl);
    BIO_meth_set_create(vtable, TlsConnection::bioCreate);
    BIO_meth_set_destroy(vtable, TlsConnection::bioDestroy);
    return vtable;
  }
#endif
};

// =======================================================================================
// Implementations of ConnectionReceiver, NetworkAddress, and Network as wrappers adding TLS.

class TlsConnectionReceiver final: public ConnectionReceiver, public TaskSet::ErrorHandler {
  // Accepts raw connections from `inner` in a continuous loop, runs the server handshake
  // (TlsContext::wrapServer()) on each one, and queues successfully handshaken streams for
  // accept() / acceptAuthenticated(). A failure of the inner receiver is sticky: it rejects all
  // pending and future accepts.
public:
  TlsConnectionReceiver(TlsContext &tls, Own<ConnectionReceiver> inner)
      : tls(tls), inner(kj::mv(inner)),
        acceptLoopTask(acceptLoop().eagerlyEvaluate([this](Exception &&e) {
          onAcceptFailure(kj::mv(e));
        })),
        tasks(*this) {}

  void taskFailed(Exception&& e) override {
    // TODO(someday): SSL connection failures may be a fact of normal operation but they may also
    //   be important diagnostic information. We should allow for an error handler to be passed in
    //   so that network issues that affect TLS can be more discoverable from the server side.
    if (e.getType() != Exception::Type::DISCONNECTED) {
      KJ_LOG(ERROR, "error accepting tls connection", kj::mv(e));
    }
  }

  Promise<Own<AsyncIoStream>> accept() override {
    return acceptAuthenticated().then([](AuthenticatedStream&& stream) {
      return kj::mv(stream.stream);
    });
  }

  Promise<AuthenticatedStream> acceptAuthenticated() override {
    KJ_IF_MAYBE(e, maybeInnerException) {
      // We've experienced an exception from the inner receiver, we consider this unrecoverable.
      return Exception(*e);
    }

    return queue.pop();
  }

  uint getPort() override {
    return inner->getPort();
  }

  void getsockopt(int level, int option, void* value, uint* length) override {
    return inner->getsockopt(level, option, value, length);
  }

  void setsockopt(int level, int option, const void* value, uint length) override {
    return inner->setsockopt(level, option, value, length);
  }

private:
  void onAcceptSuccess(AuthenticatedStream&& stream) {
    // Queue this stream to go through SSL_accept.

    auto acceptPromise = kj::evalNow([&] {
      // Do the SSL acceptance procedure.
      return tls.wrapServer(kj::mv(stream));
    });

    auto sslPromise = acceptPromise.then([this](auto&& stream) -> Promise<void> {
      // This is only attached to the success path, thus the error handler will catch if our
      // promise fails.
      queue.push(kj::mv(stream));
      return kj::READY_NOW;
    });
    tasks.add(kj::mv(sslPromise));
  }

  void onAcceptFailure(Exception&& e) {
    // Store this exception to reject all future calls to accept() and reject any unfulfilled
    // promises from the queue.
    maybeInnerException = kj::mv(e);
    queue.rejectAll(Exception(KJ_REQUIRE_NONNULL(maybeInnerException)));
  }

  Promise<void> acceptLoop() {
    // Accept one connection and queue up the next accept on our TaskSet.

    return inner->acceptAuthenticated().then(
        [this](AuthenticatedStream&& stream) {
      onAcceptSuccess(kj::mv(stream));

      // Queue up the next accept loop immediately without waiting for SSL_accept()/wrapServer().
      return acceptLoop();
    });
  }

  TlsContext& tls;
  Own<ConnectionReceiver> inner;

  Promise<void> acceptLoopTask;
  ProducerConsumerQueue<AuthenticatedStream> queue;
  TaskSet tasks;

  Maybe<Exception> maybeInnerException;
};

class TlsNetworkAddress final: public kj::NetworkAddress {
  // A NetworkAddress whose outgoing connections are wrapped as TLS clients that expect the
  // server to present a certificate for `hostname`.
public:
  TlsNetworkAddress(TlsContext& tls, kj::String hostname, kj::Own<kj::NetworkAddress>&& inner)
      : tls(tls), hostname(kj::mv(hostname)), inner(kj::mv(inner)) {}

  Promise<Own<AsyncIoStream>> connect() override {
    // Note: It's unfortunately pretty common for people to assume they can drop the NetworkAddress
    //   as soon as connect() returns, and this works with the native network implementation.
    //   So, we make some copies here.
    auto& tlsRef = tls;
    auto hostnameCopy = kj::str(hostname);
    return inner->connect().then(kj::mvCapture(hostnameCopy,
        [&tlsRef](kj::String&& hostname, Own<AsyncIoStream>&& stream) {
      return tlsRef.wrapClient(kj::mv(stream), hostname);
    }));
  }

  Promise<kj::AuthenticatedStream> connectAuthenticated() override {
    // Note: It's unfortunately pretty common for people to assume they can drop the NetworkAddress
    //   as soon as connect() returns, and this works with the native network implementation.
    //   So, we make some copies here.
    auto& tlsRef = tls;
    auto hostnameCopy = kj::str(hostname);
    return inner->connectAuthenticated().then(
        [&tlsRef, hostname = kj::mv(hostnameCopy)](kj::AuthenticatedStream stream) {
      return tlsRef.wrapClient(kj::mv(stream), hostname);
    });
  }

  Own<ConnectionReceiver> listen() override {
    return tls.wrapPort(inner->listen());
  }

  Own<NetworkAddress> clone() override {
    return kj::heap<TlsNetworkAddress>(tls, kj::str(hostname), inner->clone());
  }

  String toString() override {
    return kj::str("tls:", inner->toString());
  }

private:
  TlsContext& tls;
  kj::String hostname;
  kj::Own<kj::NetworkAddress> inner;
};

class TlsNetwork final: public kj::Network {
  // A Network whose addresses are TlsNetworkAddresses. The expected server hostname is taken
  // from the address string: everything before the first ':' (or the whole string if there is
  // none).
public:
  TlsNetwork(TlsContext& tls, kj::Network& inner): tls(tls), inner(inner) {}
  TlsNetwork(TlsContext& tls, kj::Own<kj::Network> inner)
      : tls(tls), inner(*inner), ownInner(kj::mv(inner)) {}

  Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint) override {
    kj::String hostname;
    KJ_IF_MAYBE(pos, addr.findFirst(':')) {
      hostname = kj::heapString(addr.slice(0, *pos));
    } else {
      hostname = kj::heapString(addr);
    }

    return inner.parseAddress(addr, portHint)
        .then(kj::mvCapture(hostname, [this](kj::String&& hostname, kj::Own<NetworkAddress>&& addr)
            -> kj::Own<kj::NetworkAddress> {
      return kj::heap<TlsNetworkAddress>(tls, kj::mv(hostname), kj::mv(addr));
    }));
  }

  Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
    KJ_UNIMPLEMENTED("TLS does not implement getSockaddr() because it needs to know hostnames");
  }

  Own<Network> restrictPeers(
      kj::ArrayPtr<const kj::StringPtr> allow,
      kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
    // TODO(someday): Maybe we could implement the ability to specify CA or hostname restrictions?
    //   Or is it better to let people do that via the TlsContext? A neat thing about
    //   restrictPeers() is that it's easy to make user-configurable.
    return kj::heap<TlsNetwork>(tls, inner.restrictPeers(allow, deny));
  }

private:
  TlsContext& tls;
  kj::Network& inner;
  kj::Own<kj::Network> ownInner;
};

// =======================================================================================
// class TlsContext

TlsContext::Options::Options()
    : useSystemTrustStore(true),
      verifyClients(false),
      minVersion(TlsVersion::TLS_1_2),
      cipherList("ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305") {}
// Cipher list is Mozilla's "intermediate" list, except with classic DH removed since we don't
// currently support setting dhparams. See:
//     https://mozilla.github.io/server-side-tls/ssl-config-generator/
//
// Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll
// never bother.

struct TlsContext::SniCallback {
  // struct SniCallback exists only so that callback() can be declared in the .c++ file, since it
  // references OpenSSL types.

  static int callback(SSL* ssl, int* ad, void* arg);
};

TlsContext::TlsContext(Options options) {
  // Builds the SSL_CTX shared by all connections created from this context, applying each
  // field of `options`. If any step fails, the partially configured SSL_CTX is freed.

  ensureOpenSslInitialized();

#if OPENSSL_VERSION_NUMBER >= 0x10100000L || defined(OPENSSL_IS_BORINGSSL)
  SSL_CTX* ctx = SSL_CTX_new(TLS_method());
#else
  SSL_CTX* ctx = SSL_CTX_new(SSLv23_method());
#endif

  if (ctx == nullptr) {
    throwOpensslError();
  }
  KJ_ON_SCOPE_FAILURE(SSL_CTX_free(ctx));

  // honor options.useSystemTrustStore
  if (options.useSystemTrustStore) {
    if (!SSL_CTX_set_default_verify_paths(ctx)) {
      throwOpensslError();
    }
  }

  // honor options.trustedCertificates
  if (options.trustedCertificates.size() > 0) {
    X509_STORE* store = SSL_CTX_get_cert_store(ctx);
    if (store == nullptr) {
      throwOpensslError();
    }
    for (auto& cert: options.trustedCertificates) {
      // Only the leaf (first) certificate of each entry is added as a trust anchor.
      if (!X509_STORE_add_cert(store, reinterpret_cast<X509*>(cert.chain[0]))) {
        throwOpensslError();
      }
    }
  }

  if (options.verifyClients) {
    SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
  }

  // honor options.minVersion
  long optionFlags = 0;
  if (options.minVersion > TlsVersion::SSL_3) {
    optionFlags |= SSL_OP_NO_SSLv3;
  }
  if (options.minVersion > TlsVersion::TLS_1_0) {
    optionFlags |= SSL_OP_NO_TLSv1;
  }
  if (options.minVersion > TlsVersion::TLS_1_1) {
    optionFlags |= SSL_OP_NO_TLSv1_1;
  }
  if (options.minVersion > TlsVersion::TLS_1_2) {
    optionFlags |= SSL_OP_NO_TLSv1_2;
  }
  SSL_CTX_set_options(ctx, optionFlags);  // note: never fails; returns new options bitmask

  // honor options.cipherList
  if (!SSL_CTX_set_cipher_list(ctx, options.cipherList.cStr())) {
    throwOpensslError();
  }

  // honor options.defaultKeypair
  KJ_IF_MAYBE(kp, options.defaultKeypair) {
    if (!SSL_CTX_use_PrivateKey(ctx, reinterpret_cast<EVP_PKEY*>(kp->privateKey.pkey))) {
      throwOpensslError();
    }

    if (!SSL_CTX_use_certificate(ctx, reinterpret_cast<X509*>(kp->certificate.chain[0]))) {
      throwOpensslError();
    }

    for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) {
      X509* x509 = reinterpret_cast<X509*>(kp->certificate.chain[i]);
      if (x509 == nullptr) break;  // end of chain

      if (!SSL_CTX_add_extra_chain_cert(ctx, x509)) {
        throwOpensslError();
      }

      // SSL_CTX_add_extra_chain_cert() does NOT up the refcount itself.
      X509_up_ref(x509);
    }
  }

  // honor options.sniCallback
  KJ_IF_MAYBE(sni, options.sniCallback) {
    SSL_CTX_set_tlsext_servername_callback(ctx, &SniCallback::callback);
    SSL_CTX_set_tlsext_servername_arg(ctx, sni);
  }

  KJ_IF_MAYBE(timeout, options.acceptTimeout) {
    this->timer = KJ_REQUIRE_NONNULL(options.timer,
        "acceptTimeout option requires that a timer is also provided");
    this->acceptTimeout = *timeout;
  }

  this->ctx = ctx;
}

int TlsContext::SniCallback::callback(SSL* ssl, int* ad, void* arg) {
  // The third parameter is actually type TlsSniCallback*.
  //
  // If the client sent a server name and the callback returns a keypair for it, installs that
  // key, certificate and chain on this SSL. Any exception is logged and turned into a fatal
  // internal_error alert.

  KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
    TlsSniCallback& sni = *reinterpret_cast<TlsSniCallback*>(arg);

    const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
    if (name != nullptr) {
      // getKey() returns its Maybe by value. Keep it in a named local: KJ_IF_MAYBE on the
      // temporary directly would leave `kp` pointing into an object destroyed at the end of the
      // if-condition.
      auto maybeKeypair = sni.getKey(name);
      KJ_IF_MAYBE(kp, maybeKeypair) {
        if (!SSL_use_PrivateKey(ssl, reinterpret_cast<EVP_PKEY*>(kp->privateKey.pkey))) {
          throwOpensslError();
        }

        if (!SSL_use_certificate(ssl, reinterpret_cast<X509*>(kp->certificate.chain[0]))) {
          throwOpensslError();
        }

        if (!SSL_clear_chain_certs(ssl)) {
          throwOpensslError();
        }

        for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) {
          X509* x509 = reinterpret_cast<X509*>(kp->certificate.chain[i]);
          if (x509 == nullptr) break;  // end of chain

          if (!SSL_add0_chain_cert(ssl, x509)) {
            throwOpensslError();
          }

          // SSL_add0_chain_cert() does NOT up the refcount itself.
          X509_up_ref(x509);
        }
      }
    }
  })) {
    KJ_LOG(ERROR, "exception when invoking SNI callback", *exception);
    *ad = SSL_AD_INTERNAL_ERROR;
    return SSL_TLSEXT_ERR_ALERT_FATAL;
  }

  return SSL_TLSEXT_ERR_OK;
}

TlsContext::~TlsContext() noexcept(false) {
  SSL_CTX_free(reinterpret_cast<SSL_CTX*>(ctx));
}

kj::Promise<kj::Own<kj::AsyncIoStream>> TlsContext::wrapClient(
    kj::Own<kj::AsyncIoStream> stream, kj::StringPtr expectedServerHostname) {
  // Runs the client handshake on `stream`; resolves to the TLS stream once it succeeds.
  auto conn = kj::heap<TlsConnection>(kj::mv(stream), reinterpret_cast<SSL_CTX*>(ctx));
  auto promise = conn->connect(expectedServerHostname);
  return promise.then(kj::mvCapture(conn, [](kj::Own<TlsConnection> conn)
      -> kj::Own<kj::AsyncIoStream> {
    return kj::mv(conn);
  }));
}

kj::Promise<kj::Own<kj::AsyncIoStream>> TlsContext::wrapServer(kj::Own<kj::AsyncIoStream> stream) {
  // Runs the server handshake on `stream`, bounded by acceptTimeout if one was configured.
  auto conn = kj::heap<TlsConnection>(kj::mv(stream), reinterpret_cast<SSL_CTX*>(ctx));
  auto promise = conn->accept();
  KJ_IF_MAYBE(timeout, acceptTimeout) {
    promise = KJ_REQUIRE_NONNULL(timer).timeoutAfter(*timeout, kj::mv(promise));
  }
  return promise.then(kj::mvCapture(conn, [](kj::Own<TlsConnection> conn)
      -> kj::Own<kj::AsyncIoStream> {
    return kj::mv(conn);
  }));
}

kj::Promise<kj::AuthenticatedStream> TlsContext::wrapClient(
    kj::AuthenticatedStream stream, kj::StringPtr expectedServerHostname) {
  // Like the Own<AsyncIoStream> overload, but also wraps the transport's peer identity in a
  // TlsPeerIdentity carrying the server's certificate.
  auto conn = kj::heap<TlsConnection>(kj::mv(stream.stream), reinterpret_cast<SSL_CTX*>(ctx));
  auto promise = conn->connect(expectedServerHostname);
  return promise.then([conn=kj::mv(conn),innerId=kj::mv(stream.peerIdentity)]() mutable {
    auto id = conn->getIdentity(kj::mv(innerId));
    return kj::AuthenticatedStream { kj::mv(conn), kj::mv(id) };
  });
}

kj::Promise<kj::AuthenticatedStream> TlsContext::wrapServer(kj::AuthenticatedStream stream) {
  // Like the Own<AsyncIoStream> overload, but also wraps the transport's peer identity in a
  // TlsPeerIdentity carrying the client's certificate (if any).
  auto conn = kj::heap<TlsConnection>(kj::mv(stream.stream), reinterpret_cast<SSL_CTX*>(ctx));
  auto promise = conn->accept();
  KJ_IF_MAYBE(timeout, acceptTimeout) {
    promise = KJ_REQUIRE_NONNULL(timer).timeoutAfter(*timeout, kj::mv(promise));
  }
  return promise.then([conn=kj::mv(conn),innerId=kj::mv(stream.peerIdentity)]() mutable {
    auto id = conn->getIdentity(kj::mv(innerId));
    return kj::AuthenticatedStream { kj::mv(conn), kj::mv(id) };
  });
}

kj::Own<kj::ConnectionReceiver> TlsContext::wrapPort(kj::Own<kj::ConnectionReceiver> port) {
  return kj::heap<TlsConnectionReceiver>(*this, kj::mv(port));
}

kj::Own<kj::Network> TlsContext::wrapNetwork(kj::Network& network) {
  return kj::heap<TlsNetwork>(*this, network);
}

// =======================================================================================
// class TlsPrivateKey

TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr<const byte> asn1) {
  // Parses a DER-encoded private key; d2i_AutoPrivateKey() detects the key type.
  ensureOpenSslInitialized();

  const byte* ptr = asn1.begin();
  pkey = d2i_AutoPrivateKey(nullptr, &ptr, asn1.size());
  if (pkey == nullptr) {
    throwOpensslError();
  }
}

TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password) {
  // Parses a PEM-encoded private key, decrypting it with `password` if needed.
  ensureOpenSslInitialized();

  // const_cast apparently needed for older versions of OpenSSL.
  BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
  KJ_DEFER(BIO_free(bio));

  pkey = PEM_read_bio_PrivateKey(bio, nullptr, &passwordCallback, &password);
  if (pkey == nullptr) {
    throwOpensslError();
  }
}

TlsPrivateKey::TlsPrivateKey(const TlsPrivateKey& other)
    : pkey(other.pkey) {
  if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast<EVP_PKEY*>(pkey));
}

TlsPrivateKey& TlsPrivateKey::operator=(const TlsPrivateKey& other) {
  if (pkey != other.pkey) {
    EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
    pkey = other.pkey;
    if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast<EVP_PKEY*>(pkey));
  }
  return *this;
}

TlsPrivateKey::~TlsPrivateKey() noexcept(false) {
  EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
}

int TlsPrivateKey::passwordCallback(char* buf, int size, int rwflag, void* u) {
  // OpenSSL password callback. `u` points at the constructor's Maybe<StringPtr> password.
  // Copies the password (truncated to `size` bytes) into `buf` and returns the number of bytes
  // copied, or 0 if no password was supplied.
  auto& password = *reinterpret_cast<kj::Maybe<kj::StringPtr>*>(u);

  KJ_IF_MAYBE(p, password) {
    int result = kj::min(p->size(), size);
    memcpy(buf, p->begin(), result);
    return result;
  } else {
    return 0;
  }
}

// =======================================================================================
// class TlsCertificate

TlsCertificate::TlsCertificate(kj::ArrayPtr<const kj::ArrayPtr<const byte>> asn1) {
  // Parses a chain of DER-encoded certificates, leaf first. Unused slots of `chain` stay null.
  ensureOpenSslInitialized();

  KJ_REQUIRE(asn1.size() > 0, "must provide at least one certificate in chain");
  KJ_REQUIRE(asn1.size() <= kj::size(chain),
      "exceeded maximum certificate chain length of 10");

  memset(chain, 0, sizeof(chain));

  for (auto i: kj::indices(asn1)) {
    auto p = asn1[i].begin();

    // "_AUX" apparently refers to some auxiliary information that can be appended to the
    // certificate, but should only be trusted for your own certificate, not the whole chain??
    // I don't really know, I'm just cargo-culting.
    chain[i] = i == 0 ? d2i_X509_AUX(nullptr, &p, asn1[i].size())
                      : d2i_X509(nullptr, &p, asn1[i].size());

    if (chain[i] == nullptr) {
      for (size_t j = 0; j < i; j++) {
        X509_free(reinterpret_cast<X509*>(chain[j]));
      }
      throwOpensslError();
    }
  }
}

TlsCertificate::TlsCertificate(kj::ArrayPtr<const byte> asn1)
    : TlsCertificate(kj::arrayPtr(&asn1, 1)) {}

TlsCertificate::TlsCertificate(kj::StringPtr pem) {
  // Parses one or more concatenated PEM certificates, leaf first, up to the size of `chain`.
  ensureOpenSslInitialized();

  memset(chain, 0, sizeof(chain));

  // const_cast apparently needed for older versions of OpenSSL.
  BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
  KJ_DEFER(BIO_free(bio));

  for (auto i: kj::indices(chain)) {
    // "_AUX" apparently refers to some auxiliary information that can be appended to the
    // certificate, but should only be trusted for your own certificate, not the whole chain??
    // I don't really know, I'm just cargo-culting.
    chain[i] = i == 0 ? PEM_read_bio_X509_AUX(bio, nullptr, nullptr, nullptr)
                      : PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);

    if (chain[i] == nullptr) {
      auto error = ERR_peek_last_error();
      if (i > 0 && ERR_GET_LIB(error) == ERR_LIB_PEM &&
                   ERR_GET_REASON(error) == PEM_R_NO_START_LINE) {
        // EOF; we're done.
        ERR_clear_error();
        return;
      } else {
        for (size_t j = 0; j < i; j++) {
          X509_free(reinterpret_cast<X509*>(chain[j]));
        }
        throwOpensslError();
      }
    }
  }

  // We reached the chain length limit. Try to read one more to verify that the chain ends here.
  X509* dummy = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
  if (dummy != nullptr) {
    X509_free(dummy);
    for (auto i: kj::indices(chain)) {
      X509_free(reinterpret_cast<X509*>(chain[i]));
    }
    // Forget the freed pointers so the destructor cannot free them again if this failure is
    // recoverable (i.e. KJ_FAIL_REQUIRE does not throw).
    memset(chain, 0, sizeof(chain));
    KJ_FAIL_REQUIRE("exceeded maximum certificate chain length of 10");
  }

  // The failed probe read leaves an error (normally PEM_R_NO_START_LINE) on the thread's OpenSSL
  // error queue. Discard it so that it isn't reported by some unrelated throwOpensslError() later.
  ERR_clear_error();
}

TlsCertificate::TlsCertificate(const TlsCertificate& other) {
  memcpy(chain, other.chain, sizeof(chain));
  for (void* p: chain) {
    if (p == nullptr) break;  // end of chain; quit early
    X509_up_ref(reinterpret_cast<X509*>(p));
  }
}

TlsCertificate& TlsCertificate::operator=(const TlsCertificate& other) {
  // Replaces this chain with `other`'s, sharing the X509 objects by reference count.
  for (auto i: kj::indices(chain)) {
    if (chain[i] != other.chain[i]) {
      // Chain entries are X509 objects, so they must be released with X509_free() (releasing
      // them with EVP_PKEY_free() would corrupt memory). X509_free(nullptr) is a no-op.
      X509_free(reinterpret_cast<X509*>(chain[i]));
      chain[i] = other.chain[i];
      if (chain[i] != nullptr) X509_up_ref(reinterpret_cast<X509*>(chain[i]));
    } else if (chain[i] == nullptr) {
      // end of both chains; quit early
      break;
    }
  }
  return *this;
}

TlsCertificate::~TlsCertificate() noexcept(false) {
  for (void* p: chain) {
    if (p == nullptr) break;  // end of chain; quit early
    X509_free(reinterpret_cast<X509*>(p));
  }
}

// =======================================================================================
// class TlsPeerIdentity

TlsPeerIdentity::~TlsPeerIdentity() noexcept(false) {
  if (cert != nullptr) {
    X509_free(reinterpret_cast<X509*>(cert));
  }
}

kj::String TlsPeerIdentity::toString() {
  if (hasCertificate()) {
    return getCommonName();
  } else {
    return kj::str("(anonymous client)");
  }
}

kj::String TlsPeerIdentity::getCommonName() {
  // Returns the first commonName entry of the certificate's subject, converted to UTF-8.
  if (cert == nullptr) {
    KJ_FAIL_REQUIRE("client did not provide a certificate") { return nullptr; }
  }

  X509_NAME* subj = X509_get_subject_name(reinterpret_cast<X509*>(cert));

  int index = X509_NAME_get_index_by_NID(subj, NID_commonName, -1);
  KJ_ASSERT(index != -1, "certificate has no common name?");
  X509_NAME_ENTRY* entry = X509_NAME_get_entry(subj, index);
  KJ_ASSERT(entry != nullptr);
  ASN1_STRING* data = X509_NAME_ENTRY_get_data(entry);
  KJ_ASSERT(data != nullptr);

  unsigned char* out = nullptr;
  int len = ASN1_STRING_to_UTF8(&out, data);
  KJ_ASSERT(len >= 0);
  KJ_DEFER(OPENSSL_free(out));

  return kj::heapString(reinterpret_cast<char*>(out), len);
}

}  // namespace kj

#endif  // KJ_HAS_OPENSSL