http.c++ (190709B)
1 // Copyright (c) 2017 Sandstorm Development Group, Inc. and contributors 2 // Licensed under the MIT License: 3 // 4 // Permission is hereby granted, free of charge, to any person obtaining a copy 5 // of this software and associated documentation files (the "Software"), to deal 6 // in the Software without restriction, including without limitation the rights 7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 // copies of the Software, and to permit persons to whom the Software is 9 // furnished to do so, subject to the following conditions: 10 // 11 // The above copyright notice and this permission notice shall be included in 12 // all copies or substantial portions of the Software. 13 // 14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 // THE SOFTWARE. 21 22 #include "http.h" 23 #include "url.h" 24 #include <kj/debug.h> 25 #include <kj/parse/char.h> 26 #include <unordered_map> 27 #include <stdlib.h> 28 #include <kj/encoding.h> 29 #include <deque> 30 #include <queue> 31 #include <map> 32 33 namespace kj { 34 35 // ======================================================================================= 36 // SHA-1 implementation from https://github.com/clibs/sha1 37 // 38 // The WebSocket standard depends on SHA-1. ARRRGGGHHHHH. 39 // 40 // Any old checksum would have served the purpose, or hell, even just returning the header 41 // verbatim. But NO, they decided to throw a whole complicated hash algorithm in there, AND 42 // THEY CHOSE A BROKEN ONE THAT WE OTHERWISE WOULDN'T NEED ANYMORE. 
//
// TODO(cleanup): Move this to a shared hashing library. Maybe. Or maybe don't, because no one
//   should be using SHA-1 anymore.
//
// THIS USAGE IS NOT SECURITY SENSITIVE. IF YOU REPORT A SECURITY ISSUE BECAUSE YOU SAW SHA1 IN THE
// SOURCE CODE I WILL MAKE FUN OF YOU.

/*
SHA-1 in C
By Steve Reid <steve@edmweb.com>
100% Public Domain
Test Vectors (from FIPS PUB 180-1)
"abc"
  A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D
"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
  84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1
A million repetitions of "a"
  34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F
*/

/* #define LITTLE_ENDIAN * This should be #define'd already, if true. */
/* #define SHA1HANDSOFF * Copies data before messing with it. */

#define SHA1HANDSOFF

/* Streaming SHA-1 state shared by SHA1Init(), SHA1Update() and SHA1Final(). */
typedef struct
{
    uint32_t state[5];          /* running hash value; seeded by SHA1Init() */
    uint32_t count[2];          /* message length in bits: [0] = low word, [1] = high word */
    unsigned char buffer[64];   /* bytes of the current 64-byte block not yet hashed */
} SHA1_CTX;

/* Rotate the 32-bit `value` left by `bits`. */
#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits))))

/* blk0() and blk() perform the initial expand. */
/* I got the idea of expanding during the round function from SSLeay */
/* On little-endian hosts blk0() byte-swaps word i in place; on big-endian hosts it is read as-is. */
#if BYTE_ORDER == LITTLE_ENDIAN
#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \
    |(rol(block->l[i],8)&0x00FF00FF))
#elif BYTE_ORDER == BIG_ENDIAN
#define blk0(i) block->l[i]
#else
#error "Endianness not defined!"
#endif
/* blk() computes schedule word i in place, treating block->l as a 16-word circular buffer. */
#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \
    ^block->l[(i+2)&15]^block->l[i&15],1))

/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */
/* Each one adds into z and rotates w; v, x and y are only read. */
#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30);
#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30);
#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30);


/* Hash a single 512-bit block. This is the core of the algorithm. */
/* `state` is updated in place; `buffer` supplies the 64 input bytes.
 * With SHA1HANDSOFF defined (as it is above) the input is copied into a local
 * block first, so the expansion macros never write into the caller's data.
 */

void SHA1Transform(
    uint32_t state[5],
    const unsigned char buffer[64]
)
{
    uint32_t a, b, c, d, e;

    typedef union
    {
        unsigned char c[64];
        uint32_t l[16];
    } CHAR64LONG16;

#ifdef SHA1HANDSOFF
    CHAR64LONG16 block[1];      /* use array to appear as a pointer */

    memcpy(block, buffer, 64);
#else
    /* The following had better never be used because it causes the
     * pointer-to-const buffer to be cast into a pointer to non-const.
     * And the result is written through.  I threw a "const" in, hoping
     * this will cause a diagnostic.
     */
    CHAR64LONG16 *block = (const CHAR64LONG16 *) buffer;
#endif
    /* Copy context->state[] to working vars */
    a = state[0];
    b = state[1];
    c = state[2];
    d = state[3];
    e = state[4];
    /* 4 rounds of 20 operations each. Loop unrolled. */
    R0(a, b, c, d, e, 0);
    R0(e, a, b, c, d, 1);
    R0(d, e, a, b, c, 2);
    R0(c, d, e, a, b, 3);
    R0(b, c, d, e, a, 4);
    R0(a, b, c, d, e, 5);
    R0(e, a, b, c, d, 6);
    R0(d, e, a, b, c, 7);
    R0(c, d, e, a, b, 8);
    R0(b, c, d, e, a, 9);
    R0(a, b, c, d, e, 10);
    R0(e, a, b, c, d, 11);
    R0(d, e, a, b, c, 12);
    R0(c, d, e, a, b, 13);
    R0(b, c, d, e, a, 14);
    R0(a, b, c, d, e, 15);
    R1(e, a, b, c, d, 16);
    R1(d, e, a, b, c, 17);
    R1(c, d, e, a, b, 18);
    R1(b, c, d, e, a, 19);
    R2(a, b, c, d, e, 20);
    R2(e, a, b, c, d, 21);
    R2(d, e, a, b, c, 22);
    R2(c, d, e, a, b, 23);
    R2(b, c, d, e, a, 24);
    R2(a, b, c, d, e, 25);
    R2(e, a, b, c, d, 26);
    R2(d, e, a, b, c, 27);
    R2(c, d, e, a, b, 28);
    R2(b, c, d, e, a, 29);
    R2(a, b, c, d, e, 30);
    R2(e, a, b, c, d, 31);
    R2(d, e, a, b, c, 32);
    R2(c, d, e, a, b, 33);
    R2(b, c, d, e, a, 34);
    R2(a, b, c, d, e, 35);
    R2(e, a, b, c, d, 36);
    R2(d, e, a, b, c, 37);
    R2(c, d, e, a, b, 38);
    R2(b, c, d, e, a, 39);
    R3(a, b, c, d, e, 40);
    R3(e, a, b, c, d, 41);
    R3(d, e, a, b, c, 42);
    R3(c, d, e, a, b, 43);
    R3(b, c, d, e, a, 44);
    R3(a, b, c, d, e, 45);
    R3(e, a, b, c, d, 46);
    R3(d, e, a, b, c, 47);
    R3(c, d, e, a, b, 48);
    R3(b, c, d, e, a, 49);
    R3(a, b, c, d, e, 50);
    R3(e, a, b, c, d, 51);
    R3(d, e, a, b, c, 52);
    R3(c, d, e, a, b, 53);
    R3(b, c, d, e, a, 54);
    R3(a, b, c, d, e, 55);
    R3(e, a, b, c, d, 56);
    R3(d, e, a, b, c, 57);
    R3(c, d, e, a, b, 58);
    R3(b, c, d, e, a, 59);
    R4(a, b, c, d, e, 60);
    R4(e, a, b, c, d, 61);
    R4(d, e, a, b, c, 62);
    R4(c, d, e, a, b, 63);
    R4(b, c, d, e, a, 64);
    R4(a, b, c, d, e, 65);
    R4(e, a, b, c, d, 66);
    R4(d, e, a, b, c, 67);
    R4(c, d, e, a, b, 68);
    R4(b, c, d, e, a, 69);
    R4(a, b, c, d, e, 70);
    R4(e, a, b, c, d, 71);
    R4(d, e, a, b, c, 72);
    R4(c, d, e, a, b, 73);
    R4(b, c, d, e, a, 74);
    R4(a, b, c, d, e, 75);
    R4(e, a, b, c, d, 76);
    R4(d, e, a, b, c, 77);
    R4(c, d, e, a, b, 78);
    R4(b, c, d, e, a, 79);
    /* Add the working vars back into context.state[] */
    state[0] += a;
    state[1] += b;
    state[2] += c;
    state[3] += d;
    state[4] += e;
    /* Wipe variables */
    a = b = c = d = e = 0;
#ifdef SHA1HANDSOFF
    memset(block, '\0', sizeof(block));
#endif
}


/* SHA1Init - Initialize new context */
/* Sets the five state words to the SHA-1 constants below and zeroes the bit counter. */

void SHA1Init(
    SHA1_CTX * context
)
{
    /* SHA1 initialization constants */
    context->state[0] = 0x67452301;
    context->state[1] = 0xEFCDAB89;
    context->state[2] = 0x98BADCFE;
    context->state[3] = 0x10325476;
    context->state[4] = 0xC3D2E1F0;
    context->count[0] = context->count[1] = 0;
}


/* Run your data through this. */
/* Appends `len` bytes from `data` to the message. Every completed 64-byte block
 * is hashed immediately; any remainder is kept in context->buffer for the next
 * call (or for SHA1Final()).
 */

void SHA1Update(
    SHA1_CTX * context,
    const unsigned char *data,
    uint32_t len
)
{
    uint32_t i;

    uint32_t j;

    /* Advance the 64-bit bit counter by len*8, carrying from the low word into the high word. */
    j = context->count[0];
    if ((context->count[0] += len << 3) < j)
        context->count[1]++;
    context->count[1] += (len >> 29);
    /* j = number of bytes that were already waiting in context->buffer. */
    j = (j >> 3) & 63;
    if ((j + len) > 63)
    {
        /* Top up the partial block and hash it, then hash whole blocks straight from `data`. */
        memcpy(&context->buffer[j], data, (i = 64 - j));
        SHA1Transform(context->state, context->buffer);
        for (; i + 63 < len; i += 64)
        {
            SHA1Transform(context->state, &data[i]);
        }
        j = 0;
    }
    else
        i = 0;
    /* Stash whatever is left (fewer than 64 bytes) for later. */
    memcpy(&context->buffer[j], &data[i], len - i);
}


/* Add padding and return the message digest. */
/* Writes the 20-byte digest to `digest` and then zeroes `*context`, so the
 * context must be re-initialised with SHA1Init() before it is used again.
 */

void SHA1Final(
    unsigned char digest[20],
    SHA1_CTX * context
)
{
    unsigned i;

    unsigned char finalcount[8];

    unsigned char c;

#if 0    /* untested "improvement" by DHR */
    /* Convert context->count to a sequence of bytes
     * in finalcount.  Second element first, but
     * big-endian order within element.
     * But we do it all backwards.
     */
    unsigned char *fcp = &finalcount[8];
    for (i = 0; i < 2; i++)
    {
        uint32_t t = context->count[i];
        int j;
        for (j = 0; j < 4; t >>= 8, j++)
            *--fcp = (unsigned char) t}
#else
    /* Capture the message length (high word first, big-endian bytes) BEFORE the
     * padding below changes context->count.
     */
    for (i = 0; i < 8; i++)
    {
        finalcount[i] = (unsigned char) ((context->count[(i >= 4 ? 0 : 1)] >> ((3 - (i & 3)) * 8)) & 255);      /* Endian independent */
    }
#endif
    /* Padding: one 0x80 byte (0200 octal), then zero bytes until the length is 448 bits mod 512. */
    c = 0200;
    SHA1Update(context, &c, 1);
    while ((context->count[0] & 504) != 448)
    {
        c = 0000;
        SHA1Update(context, &c, 1);
    }
    SHA1Update(context, finalcount, 8); /* Should cause a SHA1Transform() */
    /* Emit the five state words most-significant byte first. */
    for (i = 0; i < 20; i++)
    {
        digest[i] = (unsigned char)
            ((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255);
    }
    /* Wipe variables */
    memset(context, '\0', sizeof(*context));
    memset(&finalcount, '\0', sizeof(finalcount));
}

// End SHA-1 implementation.
// =======================================================================================

// One string per entry of KJ_HTTP_FOR_EACH_METHOD, produced by stringifying the method id.
// KJ_STRINGIFY() below indexes this table with the numeric value of an HttpMethod, so the macro's
// order presumably matches the enum declaration -- NOTE(review): confirm against http.h.
static const char* METHOD_NAMES[] = {
#define METHOD_NAME(id) #id,
KJ_HTTP_FOR_EACH_METHOD(METHOD_NAME)
#undef METHOD_NAME
};

kj::StringPtr KJ_STRINGIFY(HttpMethod method) {
  // Returns the name of `method` from METHOD_NAMES. The index is not bounds-checked.
  return METHOD_NAMES[static_cast<uint>(method)];
}

static kj::Maybe<HttpMethod> consumeHttpMethod(char*& ptr) {
  // Tries to recognize an HTTP method name at the start of the text at `ptr`.
  //
  // On success, returns the method and advances `ptr` just past the name. On failure, returns
  // nullptr and leaves `ptr` untouched. The character FOLLOWING the name is not examined, so
  // callers must check for themselves that the name is properly terminated.
  //
  // The text is only read, never written. The switches below walk the leading characters one at
  // a time; once the prefix is unambiguous, EXPECT_REST() compares the rest with strncmp().

  char* p = ptr;

  // EXPECT_REST(prefix, suffix): `prefix` has already been consumed via `p`. If the text at `p`
  // continues with `suffix`, commit by advancing `ptr` and return HttpMethod::<prefix><suffix>;
  // otherwise return nullptr. Either way the macro returns, so no case falls through.
#define EXPECT_REST(prefix, suffix) \
  if (strncmp(p, #suffix, sizeof(#suffix)-1) == 0) { \
    ptr = p + (sizeof(#suffix)-1); \
    return HttpMethod::prefix##suffix; \
  } else { \
    return nullptr; \
  }

  switch (*p++) {
    case 'A': EXPECT_REST(A,CL)
    case 'C':
      switch (*p++) {
        case 'H': EXPECT_REST(CH,ECKOUT)
        case 'O': EXPECT_REST(CO,PY)
        default: return nullptr;
      }
    case 'D': EXPECT_REST(D,ELETE)
    case 'G': EXPECT_REST(G,ET)
    case 'H': EXPECT_REST(H,EAD)
    case 'L': EXPECT_REST(L,OCK)
    case 'M':
      switch (*p++) {
        case 'E': EXPECT_REST(ME,RGE)
        case 'K':
          switch (*p++) {
            case 'A': EXPECT_REST(MKA,CTIVITY)
            case 'C': EXPECT_REST(MKC,OL)
            default: return nullptr;
          }
        case 'O': EXPECT_REST(MO,VE)
        case 'S': EXPECT_REST(MS,EARCH)
        default: return nullptr;
      }
    case 'N': EXPECT_REST(N,OTIFY)
    case 'O': EXPECT_REST(O,PTIONS)
    case 'P':
      switch (*p++) {
        case 'A': EXPECT_REST(PA,TCH)
        case 'O': EXPECT_REST(PO,ST)
        case 'R':
          // PROPFIND and PROPPATCH share the prefix "PROP"; match "OP" by hand first.
          if (*p++ != 'O' || *p++ != 'P') return nullptr;
          switch (*p++) {
            case 'F': EXPECT_REST(PROPF,IND)
            case 'P': EXPECT_REST(PROPP,ATCH)
            default: return nullptr;
          }
        case 'U':
          switch (*p++) {
            case 'R': EXPECT_REST(PUR,GE)
            case 'T': EXPECT_REST(PUT,)
            default: return nullptr;
          }
        default: return nullptr;
      }
    case 'R': EXPECT_REST(R,EPORT)
    case 'S':
      switch (*p++) {
        case 'E': EXPECT_REST(SE,ARCH)
        case 'U': EXPECT_REST(SU,BSCRIBE)
        default: return nullptr;
      }
    case 'T': EXPECT_REST(T,RACE)
    case 'U':
      // UNLOCK and UNSUBSCRIBE share the prefix "UN".
      if (*p++ != 'N') return nullptr;
      switch (*p++) {
        case 'L': EXPECT_REST(UNL,OCK)
        case 'S': EXPECT_REST(UNS,UBSCRIBE)
        default: return nullptr;
      }
    default: return nullptr;
  }
#undef EXPECT_REST
}

kj::Maybe<HttpMethod> tryParseHttpMethod(kj::StringPtr name) {
  // Returns the HttpMethod named by `name`, or nullptr unless the ENTIRE string is a method name.

  // const_cast OK because consumeHttpMethod() only reads through the pointer; it takes a
  // non-const pointer because it is also called by some code later that explicitly needs one.
  char* ptr = const_cast<char*>(name.begin());
  auto result = consumeHttpMethod(ptr);
  // On a failed match `ptr` has not moved, so this also returns nullptr (via `result`) for "".
  if (*ptr == '\0') {
    return result;
  } else {
    return nullptr;
  }
}

// =======================================================================================

namespace {

constexpr char WEBSOCKET_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
// From RFC6455.

static kj::String generateWebSocketAccept(kj::StringPtr key) {
  // WebSocket demands we do a SHA-1 here.
ARRGHH WHY SHA-1 WHYYYYYY? 436 SHA1_CTX ctx; 437 byte digest[20]; 438 SHA1Init(&ctx); 439 SHA1Update(&ctx, key.asBytes().begin(), key.size()); 440 SHA1Update(&ctx, reinterpret_cast<const byte*>(WEBSOCKET_GUID), strlen(WEBSOCKET_GUID)); 441 SHA1Final(digest, &ctx); 442 return kj::encodeBase64(digest); 443 } 444 445 constexpr auto HTTP_SEPARATOR_CHARS = kj::parse::anyOfChars("()<>@,;:\\\"/[]?={} \t"); 446 // RFC2616 section 2.2: https://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 447 448 constexpr auto HTTP_TOKEN_CHARS = 449 kj::parse::controlChar.orChar('\x7f') 450 .orGroup(kj::parse::whitespaceChar) 451 .orGroup(HTTP_SEPARATOR_CHARS) 452 .invert(); 453 // RFC2616 section 2.2: https://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 454 455 constexpr auto HTTP_HEADER_NAME_CHARS = HTTP_TOKEN_CHARS; 456 // RFC2616 section 4.2: https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 457 458 static void requireValidHeaderName(kj::StringPtr name) { 459 for (char c: name) { 460 KJ_REQUIRE(HTTP_HEADER_NAME_CHARS.contains(c), "invalid header name", name); 461 } 462 } 463 464 static void requireValidHeaderValue(kj::StringPtr value) { 465 KJ_REQUIRE(HttpHeaders::isValidHeaderValue(value), "invalid header value", 466 kj::encodeCEscape(value)); 467 } 468 469 static const char* BUILTIN_HEADER_NAMES[] = { 470 // Indexed by header ID, which includes connection headers, so we include those names too. 
471 #define HEADER_NAME(id, name) name, 472 KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_NAME) 473 #undef HEADER_NAME 474 }; 475 476 } // namespace 477 478 #define HEADER_ID(id, name) constexpr uint HttpHeaders::BuiltinIndices::id; 479 KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID) 480 #undef HEADER_ID 481 482 #define DEFINE_HEADER(id, name) \ 483 const HttpHeaderId HttpHeaderId::id(nullptr, HttpHeaders::BuiltinIndices::id); 484 KJ_HTTP_FOR_EACH_BUILTIN_HEADER(DEFINE_HEADER) 485 #undef DEFINE_HEADER 486 487 kj::StringPtr HttpHeaderId::toString() const { 488 if (table == nullptr) { 489 KJ_ASSERT(id < kj::size(BUILTIN_HEADER_NAMES)); 490 return BUILTIN_HEADER_NAMES[id]; 491 } else { 492 return table->idToString(*this); 493 } 494 } 495 496 namespace { 497 498 struct HeaderNameHash { 499 size_t operator()(kj::StringPtr s) const { 500 size_t result = 5381; 501 for (byte b: s.asBytes()) { 502 // Masking bit 0x20 makes our hash case-insensitive while conveniently avoiding any 503 // collisions that would matter for header names. 504 result = ((result << 5) + result) ^ (b & ~0x20); 505 } 506 return result; 507 } 508 509 bool operator()(kj::StringPtr a, kj::StringPtr b) const { 510 // TODO(perf): I wonder if we can beat strcasecmp() by masking bit 0x20 from each byte. We'd 511 // need to prohibit one of the technically-legal characters '^' or '~' from header names 512 // since they'd otherwise be ambiguous, but otherwise there is no ambiguity. 513 #if _MSC_VER 514 return _stricmp(a.cStr(), b.cStr()) == 0; 515 #else 516 return strcasecmp(a.cStr(), b.cStr()) == 0; 517 #endif 518 } 519 }; 520 521 } // namespace 522 523 struct HttpHeaderTable::IdsByNameMap { 524 // TODO(perf): If we were cool we could maybe use a perfect hash here, since our hashtable is 525 // static once built. 
526 527 std::unordered_map<kj::StringPtr, uint, HeaderNameHash, HeaderNameHash> map; 528 }; 529 530 HttpHeaderTable::Builder::Builder() 531 : table(kj::heap<HttpHeaderTable>()) {} 532 533 HttpHeaderId HttpHeaderTable::Builder::add(kj::StringPtr name) { 534 requireValidHeaderName(name); 535 536 auto insertResult = table->idsByName->map.insert(std::make_pair(name, table->namesById.size())); 537 if (insertResult.second) { 538 table->namesById.add(name); 539 } 540 return HttpHeaderId(table, insertResult.first->second); 541 } 542 543 HttpHeaderTable::HttpHeaderTable() 544 : idsByName(kj::heap<IdsByNameMap>()) { 545 #define ADD_HEADER(id, name) \ 546 namesById.add(name); \ 547 idsByName->map.insert(std::make_pair(name, HttpHeaders::BuiltinIndices::id)); 548 KJ_HTTP_FOR_EACH_BUILTIN_HEADER(ADD_HEADER); 549 #undef ADD_HEADER 550 } 551 HttpHeaderTable::~HttpHeaderTable() noexcept(false) {} 552 553 kj::Maybe<HttpHeaderId> HttpHeaderTable::stringToId(kj::StringPtr name) const { 554 auto iter = idsByName->map.find(name); 555 if (iter == idsByName->map.end()) { 556 return nullptr; 557 } else { 558 return HttpHeaderId(this, iter->second); 559 } 560 } 561 562 // ======================================================================================= 563 564 bool HttpHeaders::isValidHeaderValue(kj::StringPtr value) { 565 for (char c: value) { 566 // While the HTTP spec suggests that only printable ASCII characters are allowed in header 567 // values, reality has a different opinion. See: https://github.com/httpwg/http11bis/issues/19 568 // We follow the browsers' lead. 
569 if (c == '\0' || c == '\r' || c == '\n') { 570 return false; 571 } 572 } 573 574 return true; 575 } 576 577 HttpHeaders::HttpHeaders(const HttpHeaderTable& table) 578 : table(&table), 579 indexedHeaders(kj::heapArray<kj::StringPtr>(table.idCount())) {} 580 581 void HttpHeaders::clear() { 582 for (auto& header: indexedHeaders) { 583 header = nullptr; 584 } 585 586 unindexedHeaders.clear(); 587 } 588 589 size_t HttpHeaders::size() const { 590 size_t result = unindexedHeaders.size(); 591 for (auto i: kj::indices(indexedHeaders)) { 592 if (indexedHeaders[i] != nullptr) { 593 ++result; 594 } 595 } 596 return result; 597 } 598 599 HttpHeaders HttpHeaders::clone() const { 600 HttpHeaders result(*table); 601 602 for (auto i: kj::indices(indexedHeaders)) { 603 if (indexedHeaders[i] != nullptr) { 604 result.indexedHeaders[i] = result.cloneToOwn(indexedHeaders[i]); 605 } 606 } 607 608 result.unindexedHeaders.resize(unindexedHeaders.size()); 609 for (auto i: kj::indices(unindexedHeaders)) { 610 result.unindexedHeaders[i].name = result.cloneToOwn(unindexedHeaders[i].name); 611 result.unindexedHeaders[i].value = result.cloneToOwn(unindexedHeaders[i].value); 612 } 613 614 return result; 615 } 616 617 HttpHeaders HttpHeaders::cloneShallow() const { 618 HttpHeaders result(*table); 619 620 for (auto i: kj::indices(indexedHeaders)) { 621 if (indexedHeaders[i] != nullptr) { 622 result.indexedHeaders[i] = indexedHeaders[i]; 623 } 624 } 625 626 result.unindexedHeaders.resize(unindexedHeaders.size()); 627 for (auto i: kj::indices(unindexedHeaders)) { 628 result.unindexedHeaders[i] = unindexedHeaders[i]; 629 } 630 631 return result; 632 } 633 634 kj::StringPtr HttpHeaders::cloneToOwn(kj::StringPtr str) { 635 auto copy = kj::heapString(str); 636 kj::StringPtr result = copy; 637 ownedStrings.add(copy.releaseArray()); 638 return result; 639 } 640 641 642 namespace { 643 644 template <char... 
chars> 645 constexpr bool fastCaseCmp(const char* actual); 646 647 } // namespace 648 649 bool HttpHeaders::isWebSocket() const { 650 return fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>( 651 get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr()); 652 } 653 654 void HttpHeaders::set(HttpHeaderId id, kj::StringPtr value) { 655 id.requireFrom(*table); 656 requireValidHeaderValue(value); 657 658 indexedHeaders[id.id] = value; 659 } 660 661 void HttpHeaders::set(HttpHeaderId id, kj::String&& value) { 662 set(id, kj::StringPtr(value)); 663 takeOwnership(kj::mv(value)); 664 } 665 666 void HttpHeaders::add(kj::StringPtr name, kj::StringPtr value) { 667 requireValidHeaderName(name); 668 requireValidHeaderValue(value); 669 670 addNoCheck(name, value); 671 } 672 673 void HttpHeaders::add(kj::StringPtr name, kj::String&& value) { 674 add(name, kj::StringPtr(value)); 675 takeOwnership(kj::mv(value)); 676 } 677 678 void HttpHeaders::add(kj::String&& name, kj::String&& value) { 679 add(kj::StringPtr(name), kj::StringPtr(value)); 680 takeOwnership(kj::mv(name)); 681 takeOwnership(kj::mv(value)); 682 } 683 684 void HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value) { 685 KJ_IF_MAYBE(id, table->stringToId(name)) { 686 if (indexedHeaders[id->id] == nullptr) { 687 indexedHeaders[id->id] = value; 688 } else { 689 // Duplicate HTTP headers are equivalent to the values being separated by a comma. 690 691 #if _MSC_VER 692 if (_stricmp(name.cStr(), "set-cookie") == 0) { 693 #else 694 if (strcasecmp(name.cStr(), "set-cookie") == 0) { 695 #endif 696 // Uh-oh, Set-Cookie will be corrupted if we try to concatenate it. We'll make it an 697 // unindexed header, which is weird, but the alternative is guaranteed corruption, so... 698 // TODO(cleanup): Maybe HttpHeaders should just special-case set-cookie in general? 
699 unindexedHeaders.add(Header {name, value}); 700 } else { 701 auto concat = kj::str(indexedHeaders[id->id], ", ", value); 702 indexedHeaders[id->id] = concat; 703 ownedStrings.add(concat.releaseArray()); 704 } 705 } 706 } else { 707 unindexedHeaders.add(Header {name, value}); 708 } 709 } 710 711 void HttpHeaders::takeOwnership(kj::String&& string) { 712 ownedStrings.add(string.releaseArray()); 713 } 714 void HttpHeaders::takeOwnership(kj::Array<char>&& chars) { 715 ownedStrings.add(kj::mv(chars)); 716 } 717 void HttpHeaders::takeOwnership(HttpHeaders&& otherHeaders) { 718 for (auto& str: otherHeaders.ownedStrings) { 719 ownedStrings.add(kj::mv(str)); 720 } 721 otherHeaders.ownedStrings.clear(); 722 } 723 724 // ----------------------------------------------------------------------------- 725 726 static inline char* skipSpace(char* p) { 727 for (;;) { 728 switch (*p) { 729 case '\t': 730 case ' ': 731 ++p; 732 break; 733 default: 734 return p; 735 } 736 } 737 } 738 739 static kj::Maybe<kj::StringPtr> consumeWord(char*& ptr) { 740 char* start = skipSpace(ptr); 741 char* p = start; 742 743 for (;;) { 744 switch (*p) { 745 case '\0': 746 ptr = p; 747 return kj::StringPtr(start, p); 748 749 case '\t': 750 case ' ': { 751 char* end = p++; 752 ptr = p; 753 *end = '\0'; 754 return kj::StringPtr(start, end); 755 } 756 757 case '\n': 758 case '\r': 759 // Not expecting EOL! 
760 return nullptr; 761 762 default: 763 ++p; 764 break; 765 } 766 } 767 } 768 769 static kj::Maybe<uint> consumeNumber(char*& ptr) { 770 char* start = skipSpace(ptr); 771 char* p = start; 772 773 uint result = 0; 774 775 for (;;) { 776 char c = *p; 777 if ('0' <= c && c <= '9') { 778 result = result * 10 + (c - '0'); 779 ++p; 780 } else { 781 if (p == start) return nullptr; 782 ptr = p; 783 return result; 784 } 785 } 786 } 787 788 static kj::StringPtr consumeLine(char*& ptr) { 789 char* start = skipSpace(ptr); 790 char* p = start; 791 792 for (;;) { 793 switch (*p) { 794 case '\0': 795 ptr = p; 796 return kj::StringPtr(start, p); 797 798 case '\r': { 799 char* end = p++; 800 if (*p == '\n') ++p; 801 802 if (*p == ' ' || *p == '\t') { 803 // Whoa, continuation line. These are deprecated, but historically a line starting with 804 // a space was treated as a continuation of the previous line. The behavior should be 805 // the same as if the \r\n were replaced with spaces, so let's do that here to prevent 806 // confusion later. 807 *end = ' '; 808 p[-1] = ' '; 809 break; 810 } 811 812 ptr = p; 813 *end = '\0'; 814 return kj::StringPtr(start, end); 815 } 816 817 case '\n': { 818 char* end = p++; 819 820 if (*p == ' ' || *p == '\t') { 821 // Whoa, continuation line. These are deprecated, but historically a line starting with 822 // a space was treated as a continuation of the previous line. The behavior should be 823 // the same as if the \n were replaced with spaces, so let's do that here to prevent 824 // confusion later. 825 *end = ' '; 826 break; 827 } 828 829 ptr = p; 830 *end = '\0'; 831 return kj::StringPtr(start, end); 832 } 833 834 default: 835 ++p; 836 break; 837 } 838 } 839 } 840 841 static kj::Maybe<kj::StringPtr> consumeHeaderName(char*& ptr) { 842 // Do NOT skip spaces before the header name. Leading spaces indicate a continuation line; they 843 // should have been handled in consumeLine(). 
844 char* p = ptr; 845 846 char* start = p; 847 while (HTTP_HEADER_NAME_CHARS.contains(*p)) ++p; 848 char* end = p; 849 850 p = skipSpace(p); 851 852 if (end == start || *p != ':') return nullptr; 853 ++p; 854 855 p = skipSpace(p); 856 857 *end = '\0'; 858 ptr = p; 859 return kj::StringPtr(start, end); 860 } 861 862 static char* trimHeaderEnding(kj::ArrayPtr<char> content) { 863 // Trim off the trailing \r\n from a header blob. 864 865 if (content.size() < 2) return nullptr; 866 867 // Remove trailing \r\n\r\n and replace with \0 sentinel char. 868 char* end = content.end(); 869 870 if (end[-1] != '\n') return nullptr; 871 --end; 872 if (end[-1] == '\r') --end; 873 *end = '\0'; 874 875 return end; 876 } 877 878 HttpHeaders::RequestOrProtocolError HttpHeaders::tryParseRequest(kj::ArrayPtr<char> content) { 879 char* end = trimHeaderEnding(content); 880 if (end == nullptr) { 881 return ProtocolError { 400, "Bad Request", 882 "Request headers have no terminal newline.", content }; 883 } 884 885 char* ptr = content.begin(); 886 887 HttpHeaders::Request request; 888 889 KJ_IF_MAYBE(method, consumeHttpMethod(ptr)) { 890 request.method = *method; 891 if (*ptr != ' ' && *ptr != '\t') { 892 return ProtocolError { 501, "Not Implemented", 893 "Unrecognized request method.", content }; 894 } 895 ++ptr; 896 } else { 897 return ProtocolError { 501, "Not Implemented", 898 "Unrecognized request method.", content }; 899 } 900 901 KJ_IF_MAYBE(path, consumeWord(ptr)) { 902 request.url = *path; 903 } else { 904 return ProtocolError { 400, "Bad Request", 905 "Invalid request line.", content }; 906 } 907 908 // Ignore rest of line. Don't care about "HTTP/1.1" or whatever. 
909 consumeLine(ptr); 910 911 if (!parseHeaders(ptr, end)) { 912 return ProtocolError { 400, "Bad Request", 913 "The headers sent by your client are not valid.", content }; 914 } 915 916 return request; 917 } 918 919 HttpHeaders::ResponseOrProtocolError HttpHeaders::tryParseResponse(kj::ArrayPtr<char> content) { 920 char* end = trimHeaderEnding(content); 921 if (end == nullptr) { 922 return ProtocolError { 502, "Bad Gateway", 923 "Response headers have no terminal newline.", content }; 924 } 925 926 char* ptr = content.begin(); 927 928 HttpHeaders::Response response; 929 930 KJ_IF_MAYBE(version, consumeWord(ptr)) { 931 if (!version->startsWith("HTTP/")) { 932 return ProtocolError { 502, "Bad Gateway", 933 "Invalid response status line (invalid protocol).", content }; 934 } 935 } else { 936 return ProtocolError { 502, "Bad Gateway", 937 "Invalid response status line (no spaces).", content }; 938 } 939 940 KJ_IF_MAYBE(code, consumeNumber(ptr)) { 941 response.statusCode = *code; 942 } else { 943 return ProtocolError { 502, "Bad Gateway", 944 "Invalid response status line (invalid status code).", content }; 945 } 946 947 response.statusText = consumeLine(ptr); 948 949 if (!parseHeaders(ptr, end)) { 950 return ProtocolError { 502, "Bad Gateway", 951 "The headers sent by the server are not valid.", content }; 952 } 953 954 return response; 955 } 956 957 bool HttpHeaders::tryParse(kj::ArrayPtr<char> content) { 958 char* end = trimHeaderEnding(content); 959 if (end == nullptr) return false; 960 961 char* ptr = content.begin(); 962 return parseHeaders(ptr, end); 963 } 964 965 bool HttpHeaders::parseHeaders(char* ptr, char* end) { 966 while (*ptr != '\0') { 967 KJ_IF_MAYBE(name, consumeHeaderName(ptr)) { 968 kj::StringPtr line = consumeLine(ptr); 969 addNoCheck(*name, line); 970 } else { 971 return false; 972 } 973 } 974 975 return ptr == end; 976 } 977 978 // ----------------------------------------------------------------------------- 979 980 kj::String 
HttpHeaders::serializeRequest( 981 HttpMethod method, kj::StringPtr url, 982 kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const { 983 return serialize(kj::toCharSequence(method), url, kj::StringPtr("HTTP/1.1"), connectionHeaders); 984 } 985 986 kj::String HttpHeaders::serializeResponse( 987 uint statusCode, kj::StringPtr statusText, 988 kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const { 989 auto statusCodeStr = kj::toCharSequence(statusCode); 990 991 return serialize(kj::StringPtr("HTTP/1.1"), statusCodeStr, statusText, connectionHeaders); 992 } 993 994 kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1, 995 kj::ArrayPtr<const char> word2, 996 kj::ArrayPtr<const char> word3, 997 kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const { 998 const kj::StringPtr space = " "; 999 const kj::StringPtr newline = "\r\n"; 1000 const kj::StringPtr colon = ": "; 1001 1002 size_t size = 2; // final \r\n 1003 if (word1 != nullptr) { 1004 size += word1.size() + word2.size() + word3.size() + 4; 1005 } 1006 KJ_ASSERT(connectionHeaders.size() <= indexedHeaders.size()); 1007 for (auto i: kj::indices(indexedHeaders)) { 1008 kj::StringPtr value = i < connectionHeaders.size() ? connectionHeaders[i] : indexedHeaders[i]; 1009 if (value != nullptr) { 1010 size += table->idToString(HttpHeaderId(table, i)).size() + value.size() + 4; 1011 } 1012 } 1013 for (auto& header: unindexedHeaders) { 1014 size += header.name.size() + header.value.size() + 4; 1015 } 1016 1017 String result = heapString(size); 1018 char* ptr = result.begin(); 1019 1020 if (word1 != nullptr) { 1021 ptr = kj::_::fill(ptr, word1, space, word2, space, word3, newline); 1022 } 1023 for (auto i: kj::indices(indexedHeaders)) { 1024 kj::StringPtr value = i < connectionHeaders.size() ? 
connectionHeaders[i] : indexedHeaders[i]; 1025 if (value != nullptr) { 1026 ptr = kj::_::fill(ptr, table->idToString(HttpHeaderId(table, i)), colon, value, newline); 1027 } 1028 } 1029 for (auto& header: unindexedHeaders) { 1030 ptr = kj::_::fill(ptr, header.name, colon, header.value, newline); 1031 } 1032 ptr = kj::_::fill(ptr, newline); 1033 1034 KJ_ASSERT(ptr == result.end()); 1035 return result; 1036 } 1037 1038 kj::String HttpHeaders::toString() const { 1039 return serialize(nullptr, nullptr, nullptr, nullptr); 1040 } 1041 1042 // ======================================================================================= 1043 1044 namespace { 1045 1046 static constexpr size_t MIN_BUFFER = 4096; 1047 static constexpr size_t MAX_BUFFER = 128 * 1024; 1048 static constexpr size_t MAX_CHUNK_HEADER_SIZE = 32; 1049 1050 class HttpInputStreamImpl final: public HttpInputStream { 1051 public: 1052 explicit HttpInputStreamImpl(AsyncInputStream& inner, const HttpHeaderTable& table) 1053 : inner(inner), headerBuffer(kj::heapArray<char>(MIN_BUFFER)), headers(table) { 1054 } 1055 1056 bool canReuse() { 1057 return !broken && pendingMessageCount == 0; 1058 } 1059 1060 // --------------------------------------------------------------------------- 1061 // public interface 1062 1063 kj::Promise<Request> readRequest() override { 1064 return readRequestHeaders() 1065 .then([this](HttpHeaders::RequestOrProtocolError&& requestOrProtocolError) 1066 -> HttpInputStream::Request { 1067 auto request = KJ_REQUIRE_NONNULL( 1068 requestOrProtocolError.tryGet<HttpHeaders::Request>(), "bad request"); 1069 auto body = getEntityBody(HttpInputStreamImpl::REQUEST, request.method, 0, headers); 1070 1071 return { request.method, request.url, headers, kj::mv(body) }; 1072 }); 1073 } 1074 1075 kj::Promise<Response> readResponse(HttpMethod requestMethod) override { 1076 return readResponseHeaders() 1077 .then([this,requestMethod](HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError) 1078 -> 
HttpInputStream::Response { 1079 auto response = KJ_REQUIRE_NONNULL( 1080 responseOrProtocolError.tryGet<HttpHeaders::Response>(), "bad response"); 1081 auto body = getEntityBody(HttpInputStreamImpl::RESPONSE, requestMethod, 1082 response.statusCode, headers); 1083 1084 return { response.statusCode, response.statusText, headers, kj::mv(body) }; 1085 }); 1086 } 1087 1088 kj::Promise<Message> readMessage() override { 1089 return readMessageHeaders() 1090 .then([this](kj::ArrayPtr<char> text) -> HttpInputStream::Message { 1091 headers.clear(); 1092 KJ_REQUIRE(headers.tryParse(text), "bad message"); 1093 auto body = getEntityBody(HttpInputStreamImpl::RESPONSE, HttpMethod::GET, 0, headers); 1094 1095 return { headers, kj::mv(body) }; 1096 }); 1097 } 1098 1099 // --------------------------------------------------------------------------- 1100 // Stream locking: While an entity-body is being read, the body stream "locks" the underlying 1101 // HTTP stream. Once the entity-body is complete, we can read the next pipelined message. 1102 1103 void finishRead() { 1104 // Called when entire request has been read. 1105 1106 KJ_REQUIRE_NONNULL(onMessageDone)->fulfill(); 1107 onMessageDone = nullptr; 1108 --pendingMessageCount; 1109 } 1110 1111 void abortRead() { 1112 // Called when a body input stream was destroyed without reading to the end. 1113 1114 KJ_REQUIRE_NONNULL(onMessageDone)->reject(KJ_EXCEPTION(FAILED, 1115 "application did not finish reading previous HTTP response body", 1116 "can't read next pipelined request/response")); 1117 onMessageDone = nullptr; 1118 broken = true; 1119 } 1120 1121 // --------------------------------------------------------------------------- 1122 1123 kj::Promise<bool> awaitNextMessage() override { 1124 // Waits until more data is available, but doesn't consume it. Returns false on EOF. 1125 // 1126 // Used on the server after a request is handled, to check for pipelined requests. 
1127 // 1128 // Used on the client to detect when idle connections are closed from the server end. (In this 1129 // case, the promise always returns false or is canceled.) 1130 1131 if (onMessageDone != nullptr) { 1132 // We're still working on reading the previous body. 1133 auto fork = messageReadQueue.fork(); 1134 messageReadQueue = fork.addBranch(); 1135 return fork.addBranch().then([this]() { 1136 return awaitNextMessage(); 1137 }); 1138 } 1139 1140 snarfBufferedLineBreak(); 1141 1142 if (!lineBreakBeforeNextHeader && leftover != nullptr) { 1143 return true; 1144 } 1145 1146 return inner.tryRead(headerBuffer.begin(), 1, headerBuffer.size()) 1147 .then([this](size_t amount) -> kj::Promise<bool> { 1148 if (amount > 0) { 1149 leftover = headerBuffer.slice(0, amount); 1150 return awaitNextMessage(); 1151 } else { 1152 return false; 1153 } 1154 }); 1155 } 1156 1157 bool isCleanDrain() { 1158 // Returns whether we can cleanly drain the stream at this point. 1159 if (onMessageDone != nullptr) return false; 1160 snarfBufferedLineBreak(); 1161 return !lineBreakBeforeNextHeader && leftover == nullptr; 1162 } 1163 1164 kj::Promise<kj::ArrayPtr<char>> readMessageHeaders() { 1165 ++pendingMessageCount; 1166 auto paf = kj::newPromiseAndFulfiller<void>(); 1167 1168 auto promise = messageReadQueue 1169 .then(kj::mvCapture(paf.fulfiller, [this](kj::Own<kj::PromiseFulfiller<void>> fulfiller) { 1170 onMessageDone = kj::mv(fulfiller); 1171 return readHeader(HeaderType::MESSAGE, 0, 0); 1172 })); 1173 1174 messageReadQueue = kj::mv(paf.promise); 1175 1176 return promise; 1177 } 1178 1179 kj::Promise<uint64_t> readChunkHeader() { 1180 KJ_REQUIRE(onMessageDone != nullptr); 1181 1182 // We use the portion of the header after the end of message headers. 
1183 return readHeader(HeaderType::CHUNK, messageHeaderEnd, messageHeaderEnd) 1184 .then([](kj::ArrayPtr<char> text) -> uint64_t { 1185 KJ_REQUIRE(text.size() > 0) { break; } 1186 1187 uint64_t value = 0; 1188 for (char c: text) { 1189 if ('0' <= c && c <= '9') { 1190 value = value * 16 + (c - '0'); 1191 } else if ('a' <= c && c <= 'f') { 1192 value = value * 16 + (c - 'a' + 10); 1193 } else if ('A' <= c && c <= 'F') { 1194 value = value * 16 + (c - 'A' + 10); 1195 } else { 1196 KJ_FAIL_REQUIRE("invalid HTTP chunk size", text, text.asBytes()) { break; } 1197 return value; 1198 } 1199 } 1200 1201 return value; 1202 }); 1203 } 1204 1205 inline kj::Promise<HttpHeaders::RequestOrProtocolError> readRequestHeaders() { 1206 return readMessageHeaders().then([this](kj::ArrayPtr<char> text) { 1207 headers.clear(); 1208 return headers.tryParseRequest(text); 1209 }); 1210 } 1211 1212 inline kj::Promise<HttpHeaders::ResponseOrProtocolError> readResponseHeaders() { 1213 // Note: readResponseHeaders() could be called multiple times concurrently when pipelining 1214 // requests. readMessageHeaders() will serialize these, but it's important not to mess with 1215 // state (like calling headers.clear()) before said serialization has taken place. 1216 return readMessageHeaders().then([this](kj::ArrayPtr<char> text) { 1217 headers.clear(); 1218 return headers.tryParseResponse(text); 1219 }); 1220 } 1221 1222 inline const HttpHeaders& getHeaders() const { return headers; } 1223 1224 Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) { 1225 // Read message body data. 1226 1227 KJ_REQUIRE(onMessageDone != nullptr); 1228 1229 if (leftover == nullptr) { 1230 // No leftovers. Forward directly to inner stream. 1231 return inner.tryRead(buffer, minBytes, maxBytes); 1232 } else if (leftover.size() >= maxBytes) { 1233 // Didn't even read the entire leftover buffer. 
1234 memcpy(buffer, leftover.begin(), maxBytes); 1235 leftover = leftover.slice(maxBytes, leftover.size()); 1236 return maxBytes; 1237 } else { 1238 // Read the entire leftover buffer, plus some. 1239 memcpy(buffer, leftover.begin(), leftover.size()); 1240 size_t copied = leftover.size(); 1241 leftover = nullptr; 1242 if (copied >= minBytes) { 1243 // Got enough to stop here. 1244 return copied; 1245 } else { 1246 // Read the rest from the underlying stream. 1247 return inner.tryRead(reinterpret_cast<byte*>(buffer) + copied, 1248 minBytes - copied, maxBytes - copied) 1249 .then([copied](size_t n) { return n + copied; }); 1250 } 1251 } 1252 } 1253 1254 enum RequestOrResponse { 1255 REQUEST, 1256 RESPONSE 1257 }; 1258 1259 kj::Own<kj::AsyncInputStream> getEntityBody( 1260 RequestOrResponse type, HttpMethod method, uint statusCode, 1261 const kj::HttpHeaders& headers); 1262 1263 struct ReleasedBuffer { 1264 kj::Array<byte> buffer; 1265 kj::ArrayPtr<byte> leftover; 1266 }; 1267 1268 ReleasedBuffer releaseBuffer() { 1269 return { headerBuffer.releaseAsBytes(), leftover.asBytes() }; 1270 } 1271 1272 private: 1273 AsyncInputStream& inner; 1274 kj::Array<char> headerBuffer; 1275 1276 size_t messageHeaderEnd = 0; 1277 // Position in headerBuffer where the message headers end -- further buffer space can 1278 // be used for chunk headers. 1279 1280 kj::ArrayPtr<char> leftover; 1281 // Data in headerBuffer that comes immediately after the header content, if any. 1282 1283 HttpHeaders headers; 1284 // Parsed headers, after a call to parseAwaited*(). 1285 1286 bool lineBreakBeforeNextHeader = false; 1287 // If true, the next await should expect to start with a spurious '\n' or '\r\n'. This happens 1288 // as a side-effect of HTTP chunked encoding, where such a newline is added to the end of each 1289 // chunk, for no good reason. 1290 1291 bool broken = false; 1292 // Becomes true if the caller failed to read the whole entity-body before closing the stream. 
1293 1294 uint pendingMessageCount = 0; 1295 // Number of reads we have queued up. 1296 1297 kj::Promise<void> messageReadQueue = kj::READY_NOW; 1298 1299 kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> onMessageDone; 1300 // Fulfill once the current message has been completely read. Unblocks reading of the next 1301 // message headers. 1302 1303 enum class HeaderType { 1304 MESSAGE, 1305 CHUNK 1306 }; 1307 1308 kj::Promise<kj::ArrayPtr<char>> readHeader( 1309 HeaderType type, size_t bufferStart, size_t bufferEnd) { 1310 // Reads the HTTP message header or a chunk header (as in transfer-encoding chunked) and 1311 // returns the buffer slice containing it. 1312 // 1313 // The main source of complication here is that we want to end up with one continuous buffer 1314 // containing the result, and that the input is delimited by newlines rather than by an upfront 1315 // length. 1316 1317 kj::Promise<size_t> readPromise = nullptr; 1318 1319 // Figure out where we're reading from. 1320 if (leftover != nullptr) { 1321 // Some data is still left over from the previous message, so start with that. 1322 1323 // This can only happen if this is the initial call to readHeader() (not recursive). 1324 KJ_ASSERT(bufferStart == bufferEnd); 1325 1326 // OK, set bufferStart and bufferEnd to both point to the start of the leftover, and then 1327 // fake a read promise as if we read the bytes from the leftover. 1328 bufferStart = leftover.begin() - headerBuffer.begin(); 1329 bufferEnd = bufferStart; 1330 readPromise = leftover.size(); 1331 leftover = nullptr; 1332 } else { 1333 // Need to read more data from the underlying stream. 1334 1335 if (bufferEnd == headerBuffer.size()) { 1336 // Out of buffer space. 1337 1338 // Maybe we can move bufferStart backwards to make more space at the end? 1339 size_t minStart = type == HeaderType::MESSAGE ? 0 : messageHeaderEnd; 1340 1341 if (bufferStart > minStart) { 1342 // Move to make space. 
1343 memmove(headerBuffer.begin() + minStart, headerBuffer.begin() + bufferStart, 1344 bufferEnd - bufferStart); 1345 bufferEnd = bufferEnd - bufferStart + minStart; 1346 bufferStart = minStart; 1347 } else { 1348 // Really out of buffer space. Grow the buffer. 1349 if (type != HeaderType::MESSAGE) { 1350 // Can't grow because we'd invalidate the HTTP headers. 1351 return KJ_EXCEPTION(FAILED, "invalid HTTP chunk size"); 1352 } 1353 KJ_REQUIRE(headerBuffer.size() < MAX_BUFFER, "request headers too large"); 1354 auto newBuffer = kj::heapArray<char>(headerBuffer.size() * 2); 1355 memcpy(newBuffer.begin(), headerBuffer.begin(), headerBuffer.size()); 1356 headerBuffer = kj::mv(newBuffer); 1357 } 1358 } 1359 1360 // How many bytes will we read? 1361 size_t maxBytes = headerBuffer.size() - bufferEnd; 1362 1363 if (type == HeaderType::CHUNK) { 1364 // Roughly limit the amount of data we read to MAX_CHUNK_HEADER_SIZE. 1365 // TODO(perf): This is mainly to avoid copying a lot of body data into our buffer just to 1366 // copy it again when it is read. But maybe the copy would be cheaper than overhead of 1367 // extra event loop turns? 1368 KJ_REQUIRE(bufferEnd - bufferStart <= MAX_CHUNK_HEADER_SIZE, "invalid HTTP chunk size"); 1369 maxBytes = kj::min(maxBytes, MAX_CHUNK_HEADER_SIZE); 1370 } 1371 1372 readPromise = inner.read(headerBuffer.begin() + bufferEnd, 1, maxBytes); 1373 } 1374 1375 return readPromise.then([this,type,bufferStart,bufferEnd](size_t amount) mutable 1376 -> kj::Promise<kj::ArrayPtr<char>> { 1377 if (lineBreakBeforeNextHeader) { 1378 // Hackily deal with expected leading line break. 1379 if (bufferEnd == bufferStart && headerBuffer[bufferEnd] == '\r') { 1380 ++bufferEnd; 1381 --amount; 1382 } 1383 1384 if (amount > 0 && headerBuffer[bufferEnd] == '\n') { 1385 lineBreakBeforeNextHeader = false; 1386 ++bufferEnd; 1387 --amount; 1388 1389 // Cut the leading line break out of the buffer entirely. 
1390 bufferStart = bufferEnd; 1391 } 1392 1393 if (amount == 0) { 1394 return readHeader(type, bufferStart, bufferEnd); 1395 } 1396 } 1397 1398 size_t pos = bufferEnd; 1399 size_t newEnd = pos + amount; 1400 1401 for (;;) { 1402 // Search for next newline. 1403 char* nl = reinterpret_cast<char*>( 1404 memchr(headerBuffer.begin() + pos, '\n', newEnd - pos)); 1405 if (nl == nullptr) { 1406 // No newline found. Wait for more data. 1407 return readHeader(type, bufferStart, newEnd); 1408 } 1409 1410 // Is this newline which we found the last of the header? For a chunk header, always. For 1411 // a message header, we search for two newlines in a row. We accept either "\r\n" or just 1412 // "\n" as a newline sequence (though the standard requires "\r\n"). 1413 if (type == HeaderType::CHUNK || 1414 (nl - headerBuffer.begin() >= 4 && 1415 ((nl[-1] == '\r' && nl[-2] == '\n') || (nl[-1] == '\n')))) { 1416 // OK, we've got all the data! 1417 1418 size_t endIndex = nl + 1 - headerBuffer.begin(); 1419 size_t leftoverStart = endIndex; 1420 1421 // Strip off the last newline from end. 1422 endIndex -= 1 + (nl[-1] == '\r'); 1423 1424 if (type == HeaderType::MESSAGE) { 1425 if (headerBuffer.size() - newEnd < MAX_CHUNK_HEADER_SIZE) { 1426 // Ugh, there's not enough space for the secondary await buffer. Grow once more. 1427 auto newBuffer = kj::heapArray<char>(headerBuffer.size() * 2); 1428 memcpy(newBuffer.begin(), headerBuffer.begin(), headerBuffer.size()); 1429 headerBuffer = kj::mv(newBuffer); 1430 } 1431 messageHeaderEnd = endIndex; 1432 } else { 1433 // For some reason, HTTP specifies that there will be a line break after each chunk. 
1434 lineBreakBeforeNextHeader = true; 1435 } 1436 1437 auto result = headerBuffer.slice(bufferStart, endIndex); 1438 leftover = headerBuffer.slice(leftoverStart, newEnd); 1439 return result; 1440 } else { 1441 pos = nl - headerBuffer.begin() + 1; 1442 } 1443 } 1444 }); 1445 } 1446 1447 void snarfBufferedLineBreak() { 1448 // Slightly-crappy code to snarf the expected line break. This will actually eat the leading 1449 // regex /\r*\n?/. 1450 while (lineBreakBeforeNextHeader && leftover.size() > 0) { 1451 if (leftover[0] == '\r') { 1452 leftover = leftover.slice(1, leftover.size()); 1453 } else if (leftover[0] == '\n') { 1454 leftover = leftover.slice(1, leftover.size()); 1455 lineBreakBeforeNextHeader = false; 1456 } else { 1457 // Err, missing line break, whatever. 1458 lineBreakBeforeNextHeader = false; 1459 } 1460 } 1461 } 1462 }; 1463 1464 // ----------------------------------------------------------------------------- 1465 1466 class HttpEntityBodyReader: public kj::AsyncInputStream { 1467 public: 1468 HttpEntityBodyReader(HttpInputStreamImpl& inner): inner(inner) {} 1469 ~HttpEntityBodyReader() noexcept(false) { 1470 if (!finished) { 1471 inner.abortRead(); 1472 } 1473 } 1474 1475 protected: 1476 HttpInputStreamImpl& inner; 1477 1478 void doneReading() { 1479 KJ_REQUIRE(!finished); 1480 finished = true; 1481 inner.finishRead(); 1482 } 1483 1484 inline bool alreadyDone() { return finished; } 1485 1486 private: 1487 bool finished = false; 1488 }; 1489 1490 class HttpNullEntityReader final: public HttpEntityBodyReader { 1491 // Stream for an entity-body which is not present. Always returns EOF on read, but tryGetLength() 1492 // may indicate non-zero in the special case of a response to a HEAD request. 1493 1494 public: 1495 HttpNullEntityReader(HttpInputStreamImpl& inner, kj::Maybe<uint64_t> length) 1496 : HttpEntityBodyReader(inner), length(length) { 1497 // `length` is what to return from tryGetLength(). 
For a response to a HEAD request, this may 1498 // be non-zero. 1499 doneReading(); 1500 } 1501 1502 Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { 1503 return size_t(0); 1504 } 1505 1506 Maybe<uint64_t> tryGetLength() override { 1507 return length; 1508 } 1509 1510 private: 1511 kj::Maybe<uint64_t> length; 1512 }; 1513 1514 class HttpConnectionCloseEntityReader final: public HttpEntityBodyReader { 1515 // Stream which reads until EOF. 1516 1517 public: 1518 HttpConnectionCloseEntityReader(HttpInputStreamImpl& inner) 1519 : HttpEntityBodyReader(inner) {} 1520 1521 Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { 1522 if (alreadyDone()) return size_t(0); 1523 1524 return inner.tryRead(buffer, minBytes, maxBytes) 1525 .then([=](size_t amount) { 1526 if (amount < minBytes) { 1527 doneReading(); 1528 } 1529 return amount; 1530 }); 1531 } 1532 }; 1533 1534 class HttpFixedLengthEntityReader final: public HttpEntityBodyReader { 1535 // Stream which reads only up to a fixed length from the underlying stream, then emulates EOF. 1536 1537 public: 1538 HttpFixedLengthEntityReader(HttpInputStreamImpl& inner, size_t length) 1539 : HttpEntityBodyReader(inner), length(length) { 1540 if (length == 0) doneReading(); 1541 } 1542 1543 Maybe<uint64_t> tryGetLength() override { 1544 return length; 1545 } 1546 1547 Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { 1548 return tryReadInternal(buffer, minBytes, maxBytes, 0); 1549 } 1550 1551 private: 1552 size_t length; 1553 1554 Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes, 1555 size_t alreadyRead) { 1556 if (length == 0) return size_t(0); 1557 1558 // We have to set minBytes to 1 here so that if we read any data at all, we update our 1559 // counter immediately, so that we still know where we are in case of cancellation. 
1560 return inner.tryRead(buffer, 1, kj::min(maxBytes, length)) 1561 .then([=](size_t amount) -> kj::Promise<size_t> { 1562 length -= amount; 1563 if (length > 0) { 1564 // We haven't reached the end of the entity body yet. 1565 if (amount == 0) { 1566 kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, 1567 "premature EOF in HTTP entity body; did not reach Content-Length")); 1568 } else if (amount < minBytes) { 1569 // We requested a minimum 1 byte above, but our own caller actually set a larger minimum 1570 // which has not yet been reached. Keep trying until we reach it. 1571 return tryReadInternal(reinterpret_cast<byte*>(buffer) + amount, 1572 minBytes - amount, maxBytes - amount, alreadyRead + amount); 1573 } 1574 } else if (length == 0) { 1575 doneReading(); 1576 } 1577 return amount + alreadyRead; 1578 }); 1579 } 1580 }; 1581 1582 class HttpChunkedEntityReader final: public HttpEntityBodyReader { 1583 // Stream which reads a Transfer-Encoding: Chunked stream. 1584 1585 public: 1586 HttpChunkedEntityReader(HttpInputStreamImpl& inner) 1587 : HttpEntityBodyReader(inner) {} 1588 1589 Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { 1590 return tryReadInternal(buffer, minBytes, maxBytes, 0); 1591 } 1592 1593 private: 1594 size_t chunkSize = 0; 1595 1596 Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes, 1597 size_t alreadyRead) { 1598 if (alreadyDone()) { 1599 return alreadyRead; 1600 } else if (chunkSize == 0) { 1601 // Read next chunk header. 1602 return inner.readChunkHeader().then([=](uint64_t nextChunkSize) { 1603 if (nextChunkSize == 0) { 1604 doneReading(); 1605 } 1606 1607 chunkSize = nextChunkSize; 1608 return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead); 1609 }); 1610 } else { 1611 // Read current chunk. 
1612 // We have to set minBytes to 1 here so that if we read any data at all, we update our 1613 // counter immediately, so that we still know where we are in case of cancellation. 1614 return inner.tryRead(buffer, 1, kj::min(maxBytes, chunkSize)) 1615 .then([=](size_t amount) -> kj::Promise<size_t> { 1616 chunkSize -= amount; 1617 if (amount == 0) { 1618 kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "premature EOF in HTTP chunk")); 1619 } else if (amount < minBytes) { 1620 // We requested a minimum 1 byte above, but our own caller actually set a larger minimum 1621 // which has not yet been reached. Keep trying until we reach it. 1622 return tryReadInternal(reinterpret_cast<byte*>(buffer) + amount, 1623 minBytes - amount, maxBytes - amount, alreadyRead + amount); 1624 } 1625 return alreadyRead + amount; 1626 }); 1627 } 1628 } 1629 }; 1630 1631 template <char...> 1632 struct FastCaseCmp; 1633 1634 template <char first, char... rest> 1635 struct FastCaseCmp<first, rest...> { 1636 static constexpr bool apply(const char* actual) { 1637 return 1638 ('a' <= first && first <= 'z') || ('A' <= first && first <= 'Z') 1639 ? (*actual | 0x20) == (first | 0x20) && FastCaseCmp<rest...>::apply(actual + 1) 1640 : *actual == first && FastCaseCmp<rest...>::apply(actual + 1); 1641 } 1642 }; 1643 1644 template <> 1645 struct FastCaseCmp<> { 1646 static constexpr bool apply(const char* actual) { 1647 return *actual == '\0'; 1648 } 1649 }; 1650 1651 template <char... 
chars> 1652 constexpr bool fastCaseCmp(const char* actual) { 1653 return FastCaseCmp<chars...>::apply(actual); 1654 } 1655 1656 // Tests 1657 static_assert(fastCaseCmp<'f','O','o','B','1'>("FooB1"), ""); 1658 static_assert(!fastCaseCmp<'f','O','o','B','2'>("FooB1"), ""); 1659 static_assert(!fastCaseCmp<'n','O','o','B','1'>("FooB1"), ""); 1660 static_assert(!fastCaseCmp<'f','O','o','B'>("FooB1"), ""); 1661 static_assert(!fastCaseCmp<'f','O','o','B','1','a'>("FooB1"), ""); 1662 1663 kj::Own<kj::AsyncInputStream> HttpInputStreamImpl::getEntityBody( 1664 RequestOrResponse type, HttpMethod method, uint statusCode, 1665 const kj::HttpHeaders& headers) { 1666 // Rules to determine how HTTP entity-body is delimited: 1667 // https://tools.ietf.org/html/rfc7230#section-3.3.3 1668 1669 // #1 1670 if (type == RESPONSE) { 1671 if (method == HttpMethod::HEAD) { 1672 // Body elided. 1673 kj::Maybe<uint64_t> length; 1674 KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) { 1675 length = strtoull(cl->cStr(), nullptr, 10); 1676 } else if (headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr) { 1677 // HACK: Neither Content-Length nor Transfer-Encoding header in response to HEAD request. 1678 // Propagate this fact with a 0 expected body length. 1679 length = uint64_t(0); 1680 } 1681 return kj::heap<HttpNullEntityReader>(*this, length); 1682 } else if (statusCode == 204 || statusCode == 304) { 1683 // No body. 1684 return kj::heap<HttpNullEntityReader>(*this, uint64_t(0)); 1685 } 1686 } 1687 1688 // #2 deals with the CONNECT method which is handled separately. 1689 1690 // #3 1691 KJ_IF_MAYBE(te, headers.get(HttpHeaderId::TRANSFER_ENCODING)) { 1692 // TODO(someday): Support plugable transfer encodings? Or at least gzip? 1693 // TODO(someday): Support stacked transfer encodings, e.g. "gzip, chunked". 1694 1695 // NOTE: #3¶3 is ambiguous about what should happen if Transfer-Encoding and Content-Length are 1696 // both present. 
It says that Transfer-Encoding takes precedence, but also that the request 1697 // "ought to be handled as an error", and that proxies "MUST" drop the Content-Length before 1698 // forwarding. We ignore the vague "ought to" part and implement the other two. (The 1699 // dropping of Content-Length will happen naturally if/when the message is sent back out to 1700 // the network.) 1701 if (fastCaseCmp<'c','h','u','n','k','e','d'>(te->cStr())) { 1702 // #3¶1 1703 return kj::heap<HttpChunkedEntityReader>(*this); 1704 } else if (fastCaseCmp<'i','d','e','n','t','i','t','y'>(te->cStr())) { 1705 // #3¶2 1706 KJ_REQUIRE(type != REQUEST, "request body cannot have Transfer-Encoding other than chunked"); 1707 return kj::heap<HttpConnectionCloseEntityReader>(*this); 1708 } else { 1709 KJ_FAIL_REQUIRE("unknown transfer encoding", *te) { break; } 1710 } 1711 } 1712 1713 // #4 and #5 1714 KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) { 1715 // NOTE: By spec, multiple Content-Length values are allowed as long as they are the same, e.g. 1716 // "Content-Length: 5, 5, 5". Hopefully no one actually does that... 1717 char* end; 1718 uint64_t length = strtoull(cl->cStr(), &end, 10); 1719 if (end > cl->begin() && *end == '\0') { 1720 // #5 1721 return kj::heap<HttpFixedLengthEntityReader>(*this, length); 1722 } else { 1723 // #4 (bad content-length) 1724 KJ_FAIL_REQUIRE("invalid Content-Length header value", *cl); 1725 } 1726 } 1727 1728 // #6 1729 if (type == REQUEST) { 1730 // Lack of a Content-Length or Transfer-Encoding means no body for requests. 
1731 return kj::heap<HttpNullEntityReader>(*this, uint64_t(0)); 1732 } 1733 1734 // RFC 2616 permitted "multipart/byteranges" responses to be self-delimiting, but this was 1735 // mercifully removed in RFC 7230, and new exceptions of this type are disallowed: 1736 // https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.4 1737 // https://tools.ietf.org/html/rfc7230#page-81 1738 // To be extra-safe, we'll reject a multipart/byteranges response that lacks transfer-encoding 1739 // and content-length. 1740 KJ_IF_MAYBE(type, headers.get(HttpHeaderId::CONTENT_TYPE)) { 1741 if (type->startsWith("multipart/byteranges")) { 1742 KJ_FAIL_REQUIRE( 1743 "refusing to handle multipart/byteranges response without transfer-encoding nor " 1744 "content-length due to ambiguity between RFC 2616 vs RFC 7230."); 1745 } 1746 } 1747 1748 // #7 1749 return kj::heap<HttpConnectionCloseEntityReader>(*this); 1750 } 1751 1752 } // namespace 1753 1754 kj::Own<HttpInputStream> newHttpInputStream( 1755 kj::AsyncInputStream& input, const HttpHeaderTable& table) { 1756 return kj::heap<HttpInputStreamImpl>(input, table); 1757 } 1758 1759 // ======================================================================================= 1760 1761 namespace { 1762 1763 class HttpOutputStream { 1764 public: 1765 HttpOutputStream(AsyncOutputStream& inner): inner(inner) {} 1766 1767 bool isInBody() { 1768 return inBody; 1769 } 1770 1771 bool canReuse() { 1772 return !inBody && !broken && !writeInProgress; 1773 } 1774 1775 bool canWriteBodyData() { 1776 return !writeInProgress && inBody; 1777 } 1778 1779 bool isBroken() { 1780 return broken; 1781 } 1782 1783 void writeHeaders(String content) { 1784 // Writes some header content and begins a new entity body. 
1785 1786 KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return; } 1787 KJ_REQUIRE(!inBody, "previous HTTP message body incomplete; can't write more messages"); 1788 inBody = true; 1789 1790 queueWrite(kj::mv(content)); 1791 } 1792 1793 void writeBodyData(kj::String content) { 1794 KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return; } 1795 KJ_REQUIRE(inBody) { return; } 1796 1797 queueWrite(kj::mv(content)); 1798 } 1799 1800 kj::Promise<void> writeBodyData(const void* buffer, size_t size) { 1801 KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return kj::READY_NOW; } 1802 KJ_REQUIRE(inBody) { return kj::READY_NOW; } 1803 1804 writeInProgress = true; 1805 auto fork = writeQueue.fork(); 1806 writeQueue = fork.addBranch(); 1807 1808 return fork.addBranch().then([this,buffer,size]() { 1809 return inner.write(buffer, size); 1810 }).then([this]() { 1811 writeInProgress = false; 1812 }); 1813 } 1814 1815 kj::Promise<void> writeBodyData(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) { 1816 KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return kj::READY_NOW; } 1817 KJ_REQUIRE(inBody) { return kj::READY_NOW; } 1818 1819 writeInProgress = true; 1820 auto fork = writeQueue.fork(); 1821 writeQueue = fork.addBranch(); 1822 1823 return fork.addBranch().then([this,pieces]() { 1824 return inner.write(pieces); 1825 }).then([this]() { 1826 writeInProgress = false; 1827 }); 1828 } 1829 1830 Promise<uint64_t> pumpBodyFrom(AsyncInputStream& input, uint64_t amount) { 1831 KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return uint64_t(0); } 1832 KJ_REQUIRE(inBody) { return uint64_t(0); } 1833 1834 writeInProgress = true; 1835 auto fork = writeQueue.fork(); 1836 writeQueue = fork.addBranch(); 1837 1838 return fork.addBranch().then([this,&input,amount]() { 1839 return input.pumpTo(inner, amount); 1840 }).then([this](uint64_t actual) { 1841 writeInProgress = false; 1842 return actual; 1843 
}); 1844 } 1845 1846 void finishBody() { 1847 // Called when entire body was written. 1848 1849 KJ_REQUIRE(inBody) { return; } 1850 inBody = false; 1851 1852 if (writeInProgress) { 1853 // It looks like the last write never completed -- possibly because it was canceled or threw 1854 // an exception. We must treat this equivalent to abortBody(). 1855 broken = true; 1856 1857 // Cancel any writes that are still queued. 1858 writeQueue = KJ_EXCEPTION(FAILED, 1859 "previous HTTP message body incomplete; can't write more messages"); 1860 } 1861 } 1862 1863 void abortBody() { 1864 // Called if the application failed to write all expected body bytes. 1865 KJ_REQUIRE(inBody) { return; } 1866 inBody = false; 1867 broken = true; 1868 1869 // Cancel any writes that are still queued. 1870 writeQueue = KJ_EXCEPTION(FAILED, 1871 "previous HTTP message body incomplete; can't write more messages"); 1872 } 1873 1874 kj::Promise<void> flush() { 1875 auto fork = writeQueue.fork(); 1876 writeQueue = fork.addBranch(); 1877 return fork.addBranch(); 1878 } 1879 1880 Promise<void> whenWriteDisconnected() { 1881 return inner.whenWriteDisconnected(); 1882 } 1883 1884 bool isWriteInProgress() { return writeInProgress; } 1885 1886 private: 1887 AsyncOutputStream& inner; 1888 kj::Promise<void> writeQueue = kj::READY_NOW; 1889 bool inBody = false; 1890 bool broken = false; 1891 1892 bool writeInProgress = false; 1893 // True if a write method has been called and has not completed successfully. In the case that 1894 // a write throws an exception or is canceled, this remains true forever. In these cases, the 1895 // underlying stream is in an inconsistent state and cannot be reused. 1896 1897 void queueWrite(kj::String content) { 1898 // We only use queueWrite() in cases where we can take ownership of the write buffer, and where 1899 // it is convenient if we can return `void` rather than a promise. In particular, this is used 1900 // to write headers and chunk boundaries. 
// Writes of application data do not go into
    // `writeQueue` because this would prevent cancellation. Instead, they wait until `writeQueue`
    // is empty, then they make the write directly, using `writeInProgress` to detect and block
    // concurrent writes.

    writeQueue = writeQueue.then(kj::mvCapture(content, [this](kj::String&& content) {
      auto promise = inner.write(content.begin(), content.size());
      return promise.attach(kj::mv(content));
    }));
  }
};

class HttpNullEntityWriter final: public kj::AsyncOutputStream {
  // Output stream for a message that has no entity-body: every write() returns a FAILED
  // exception, and whenWriteDisconnected() never resolves.

public:
  Promise<void> write(const void* buffer, size_t size) override {
    return KJ_EXCEPTION(FAILED, "HTTP message has no entity-body; can't write()");
  }
  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    return KJ_EXCEPTION(FAILED, "HTTP message has no entity-body; can't write()");
  }
  Promise<void> whenWriteDisconnected() override {
    return kj::NEVER_DONE;
  }
};

class HttpDiscardingEntityWriter final: public kj::AsyncOutputStream {
  // Output stream that accepts any write and drops the data: write() completes immediately
  // without touching any underlying stream, and whenWriteDisconnected() never resolves.

public:
  Promise<void> write(const void* buffer, size_t size) override {
    return kj::READY_NOW;
  }
  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    return kj::READY_NOW;
  }
  Promise<void> whenWriteDisconnected() override {
    return kj::NEVER_DONE;
  }
};

class HttpFixedLengthEntityWriter final: public kj::AsyncOutputStream {
  // Writes an entity-body of a pre-declared size through an HttpOutputStream.
  //
  // `length` counts the bytes still owed. Each write subtracts from it and fails with
  // "overwrote Content-Length" if the caller tries to write more than remains. When it reaches
  // zero, inner.finishBody() is called; if the writer is destroyed while bytes are still owed or
  // a write is in progress, inner.abortBody() is called instead.

public:
  HttpFixedLengthEntityWriter(HttpOutputStream& inner, uint64_t length)
      : inner(inner), length(length) {
    // A zero-length body is complete before anything is written.
    if (length == 0) inner.finishBody();
  }
  ~HttpFixedLengthEntityWriter() noexcept(false) {
    if (length > 0 || inner.isWriteInProgress()) {
      inner.abortBody();
    }
  }

  Promise<void> write(const void* buffer, size_t size) override {
    if (size == 0) return kj::READY_NOW;
    KJ_REQUIRE(size <= length, "overwrote Content-Length");
    // Account for the bytes up front, before the asynchronous write completes.
    length -= size;

    return maybeFinishAfter(inner.writeBodyData(buffer, size));
  }
  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    // Total size of the gather-write, used for the same accounting as the single-buffer overload.
    uint64_t size = 0;
    for (auto& piece: pieces) size += piece.size();

    if (size == 0) return kj::READY_NOW;
    KJ_REQUIRE(size <= length, "overwrote Content-Length");
    length -= size;

    return maybeFinishAfter(inner.writeBodyData(pieces));
  }

  Maybe<Promise<uint64_t>> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override {
    // Pumps up to `amount` bytes from `input` into the body, clamped to the bytes still owed.
    // The returned promise resolves to the number of bytes actually transferred. If the caller
    // asked for more than remains, the input is additionally checked for excess data.

    if (amount == 0) return Promise<uint64_t>(uint64_t(0));

    bool overshot = amount > length;
    if (overshot) {
      // Hmm, the requested amount was too large, but it's common to specify kj::max as the amount
      // to pump, in which case we pump to EOF. Let's try to verify whether EOF is where we
      // expect it to be.
      KJ_IF_MAYBE(available, input.tryGetLength()) {
        // Great, the stream knows how large it is. If it's indeed larger than the space available
        // then let's abort.
        KJ_REQUIRE(*available <= length, "overwrote Content-Length");
      } else {
        // OK, we have no idea how large the input is, so we'll have to check later.
      }
    }

    amount = kj::min(amount, length);
    // Optimistically charge the full amount now; the continuation below refunds any shortfall.
    length -= amount;

    auto promise = amount == 0
        ? kj::Promise<uint64_t>(amount)
        : inner.pumpBodyFrom(input, amount).then([this,amount](uint64_t actual) {
      // Adjust for bytes not written.
      length += amount - actual;
      if (length == 0) inner.finishBody();
      return actual;
    });

    if (overshot) {
      promise = promise.then([amount,&input](uint64_t actual) -> kj::Promise<uint64_t> {
        if (actual == amount) {
          // We read exactly the amount expected. In order to detect an overshoot, we have to
          // try reading one more byte. Ugh.
          //
          // `junk` is only a scratch target for the probe read; its value is never inspected.
          // Presumably it is static so that it stays valid while the read is pending.
          static byte junk;
          return input.tryRead(&junk, 1, 1).then([actual](size_t extra) {
            KJ_REQUIRE(extra == 0, "overwrote Content-Length");
            return actual;
          });
        } else {
          // We actually read less data than requested so we couldn't have overshot. In fact, we
          // undershot.
          return actual;
        }
      });
    }

    return kj::mv(promise);
  }

  Promise<void> whenWriteDisconnected() override {
    return inner.whenWriteDisconnected();
  }

private:
  HttpOutputStream& inner;
  uint64_t length;
  // Number of body bytes that still have to be written.

  kj::Promise<void> maybeFinishAfter(kj::Promise<void> promise) {
    // If the write just issued consumed the last owed byte, chain inner.finishBody() onto its
    // completion; otherwise pass the promise through unchanged.
    if (length == 0) {
      return promise.then([this]() { inner.finishBody(); });
    } else {
      return kj::mv(promise);
    }
  }
};

class HttpChunkedEntityWriter final: public kj::AsyncOutputStream {
  // Writes an entity-body through an HttpOutputStream as a sequence of chunks, each framed as
  // "<hex size>\r\n<data>\r\n". The destructor emits the zero-size terminating chunk.

public:
  HttpChunkedEntityWriter(HttpOutputStream& inner)
      : inner(inner) {}
  ~HttpChunkedEntityWriter() noexcept(false) {
    if (inner.canWriteBodyData()) {
      // Terminate the body with the zero-size chunk.
      inner.writeBodyData(kj::str("0\r\n\r\n"));
      inner.finishBody();
    } else {
      inner.abortBody();
    }
  }

  Promise<void> write(const void* buffer, size_t size) override {
    if (size == 0) return kj::READY_NOW;  // can't encode zero-size chunk since it indicates EOF.
auto header = kj::str(kj::hex(size), "\r\n");
    // One chunk = hex size line, the caller's data, trailing CRLF, sent as one gather-write.
    auto parts = kj::heapArray<ArrayPtr<const byte>>(3);
    parts[0] = header.asBytes();
    parts[1] = kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size);
    parts[2] = kj::StringPtr("\r\n").asBytes();

    auto promise = inner.writeBodyData(parts.asPtr());
    // `parts` points into `header`, so both are attached to the promise.
    return promise.attach(kj::mv(header), kj::mv(parts));
  }

  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    // Gather-write variant: all pieces are sent together as a single chunk.
    uint64_t size = 0;
    for (auto& piece: pieces) size += piece.size();

    if (size == 0) return kj::READY_NOW;  // can't encode zero-size chunk since it indicates EOF.

    auto header = kj::str(kj::hex(size), "\r\n");
    // Room for the chunk header, every caller piece, and the trailing CRLF.
    auto partsBuilder = kj::heapArrayBuilder<ArrayPtr<const byte>>(pieces.size() + 2);
    partsBuilder.add(header.asBytes());
    for (auto& piece: pieces) {
      partsBuilder.add(piece);
    }
    partsBuilder.add(kj::StringPtr("\r\n").asBytes());

    auto parts = partsBuilder.finish();
    auto promise = inner.writeBodyData(parts.asPtr());
    return promise.attach(kj::mv(header), kj::mv(parts));
  }

  Maybe<Promise<uint64_t>> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override {
    // Optimized pump, used only when `input` reports its length. Returns nullptr otherwise, so
    // the caller falls back to a generic read/write loop.
    KJ_IF_MAYBE(l, input.tryGetLength()) {
      // Hey, we know exactly how large the input is, so we can write just one chunk.

      uint64_t length = kj::min(amount, *l);
      // NOTE(review): if `length` is 0 this emits "0\r\n", which write() above avoids because a
      // zero-size chunk indicates EOF -- confirm callers never pump an empty known-length input.
      inner.writeBodyData(kj::str(kj::hex(length), "\r\n"));
      return inner.pumpBodyFrom(input, length)
          .then([this,length](uint64_t actual) {
        if (actual < length) {
          // The chunk header promised more bytes than were delivered, so the body is corrupt.
          inner.abortBody();
          KJ_FAIL_REQUIRE(
              "value returned by input.tryGetLength() was greater than actual bytes transferred") {
            break;
          }
        }

        // CRLF that closes the chunk.
        inner.writeBodyData(kj::str("\r\n"));
        return actual;
      });
    } else {
      // Need to use naive read/write loop.
      return nullptr;
    }
  }

  Promise<void> whenWriteDisconnected() override {
    return inner.whenWriteDisconnected();
  }

private:
  HttpOutputStream& inner;
};

// =======================================================================================

class WebSocketImpl final: public WebSocket {
  // WebSocket implemented directly on top of a byte stream: frames are composed and parsed by
  // this class. A non-null `maskKeyGenerator` causes outgoing frames to be masked.

public:
  WebSocketImpl(kj::Own<kj::AsyncIoStream> stream,
                kj::Maybe<EntropySource&> maskKeyGenerator,
                kj::Array<byte> buffer = kj::heapArray<byte>(4096),
                kj::ArrayPtr<byte> leftover = nullptr,
                kj::Maybe<kj::Promise<void>> waitBeforeSend = nullptr)
      : stream(kj::mv(stream)), maskKeyGenerator(maskKeyGenerator),
        sendingPong(kj::mv(waitBeforeSend)),
        recvBuffer(kj::mv(buffer)), recvData(leftover) {}
  // `buffer` becomes the receive buffer and `leftover` the already-received bytes awaiting
  // parsing (presumably a slice of `buffer` -- confirm against callers). `waitBeforeSend`, if
  // given, is stored in `sendingPong`, so the first send waits for it.

  kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
    return sendImpl(OPCODE_BINARY, message);
  }

  kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
    return sendImpl(OPCODE_TEXT, message.asBytes());
  }

  kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
    // Sends a Close frame. The payload is the status code as two big-endian bytes followed by
    // the reason text; code 1005 is sent with an empty payload and must not carry a reason.
    kj::Array<byte> payload;
    if (code == 1005) {
      KJ_REQUIRE(reason.size() == 0, "WebSocket close code 1005 cannot have a reason");

      // code 1005 -- leave payload empty
    } else {
      payload = heapArray<byte>(reason.size() + 2);
      payload[0] = code >> 8;
      payload[1] = code;
      memcpy(payload.begin() + 2, reason.begin(), reason.size());
    }

    auto promise = sendImpl(OPCODE_CLOSE, payload);
    // Keep the payload alive until the write has finished.
    return promise.attach(kj::mv(payload));
  }

  kj::Promise<void> disconnect() override {
    // Shuts down the write side of the stream without sending a Close frame, first waiting for
    // any pending pong write to finish.
    KJ_REQUIRE(!currentlySending, "another message send is already in progress");

    KJ_IF_MAYBE(p, sendingPong) {
      // We recently sent a pong, make sure it's finished before proceeding.
currentlySending = true;
      // Block other sends while waiting, then retry the disconnect once the pong is out.
      auto promise = p->then([this]() {
        currentlySending = false;
        return disconnect();
      });
      sendingPong = nullptr;
      return promise;
    }

    disconnected = true;

    stream->shutdownWrite();
    return kj::READY_NOW;
  }

  void abort() override {
    // Drops any queued or in-flight pong, marks the socket disconnected, and shuts down both
    // directions of the stream.
    queuedPong = nullptr;
    sendingPong = nullptr;
    disconnected = true;
    stream->abortRead();
    stream->shutdownWrite();
  }

  kj::Promise<void> whenAborted() override {
    return stream->whenWriteDisconnected();
  }

  kj::Promise<Message> receive(size_t maxSize) override {
    // Reads frames until a complete application-visible message (text, binary, or close) is
    // available. Pings are answered via queuePong() and pongs are ignored; in both cases the
    // method loops. Non-final fragments are stored in `fragments` and the method recurses with
    // `maxSize` reduced by the fragment size. A frame whose payload length is not below
    // `maxSize` is rejected.

    // Header size implied by the bytes buffered so far (at least 2 when fewer than 2 are known).
    size_t headerSize = Header::headerSize(recvData.begin(), recvData.size());

    if (headerSize > recvData.size()) {
      // Header is incomplete; read more bytes and start over.
      if (recvData.begin() != recvBuffer.begin()) {
        // Move existing data to front of buffer.
        if (recvData.size() > 0) {
          memmove(recvBuffer.begin(), recvData.begin(), recvData.size());
        }
        recvData = recvBuffer.slice(0, recvData.size());
      }

      return stream->tryRead(recvData.end(), 1, recvBuffer.end() - recvData.end())
          .then([this,maxSize](size_t actual) -> kj::Promise<Message> {
        receivedBytes += actual;
        if (actual == 0) {
          if (recvData.size() > 0) {
            return KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in frame header");
          } else {
            // It's incorrect for the WebSocket to disconnect without sending `Close`.
            return KJ_EXCEPTION(DISCONNECTED,
                "WebSocket disconnected between frames without sending `Close`.");
          }
        }

        recvData = recvBuffer.slice(0, recvData.size() + actual);
        return receive(maxSize);
      });
    }

    // The complete header is buffered; interpret it in place, then advance past it.
    auto& recvHeader = *reinterpret_cast<Header*>(recvData.begin());

    recvData = recvData.slice(headerSize, recvData.size());

    size_t payloadLen = recvHeader.getPayloadLen();

    KJ_REQUIRE(payloadLen < maxSize, "WebSocket message is too large");

    auto opcode = recvHeader.getOpcode();
    bool isData = opcode < OPCODE_FIRST_CONTROL;
    if (opcode == OPCODE_CONTINUATION) {
      KJ_REQUIRE(!fragments.empty(), "unexpected continuation frame in WebSocket");

      // A continuation frame takes the opcode of the message's first fragment.
      opcode = fragmentOpcode;
    } else if (isData) {
      KJ_REQUIRE(fragments.empty(), "expected continuation frame in WebSocket");
    }

    bool isFin = recvHeader.isFin();

    kj::Array<byte> message;           // space to allocate
    byte* payloadTarget;               // location into which to read payload (size is payloadLen)
    if (isFin) {
      // Add space for NUL terminator when allocating text message.
      size_t amountToAllocate = payloadLen + (opcode == OPCODE_TEXT && isFin);

      if (isData && !fragments.empty()) {
        // Final frame of a fragmented message. Gather the fragments.
        size_t offset = 0;
        for (auto& fragment: fragments) offset += fragment.size();
        message = kj::heapArray<byte>(offset + amountToAllocate);

        offset = 0;
        for (auto& fragment: fragments) {
          memcpy(message.begin() + offset, fragment.begin(), fragment.size());
          offset += fragment.size();
        }
        // This frame's payload goes after the concatenated earlier fragments.
        payloadTarget = message.begin() + offset;

        fragments.clear();
        fragmentOpcode = 0;
      } else {
        // Single-frame message.
        message = kj::heapArray<byte>(amountToAllocate);
        payloadTarget = message.begin();
      }
    } else {
      // Fragmented message, and this isn't the final fragment.
      KJ_REQUIRE(isData, "WebSocket control frame cannot be fragmented");

      message = kj::heapArray<byte>(payloadLen);
      payloadTarget = message.begin();
      if (fragments.empty()) {
        // This is the first fragment, so set the opcode.
        fragmentOpcode = opcode;
      }
    }

    // Copy the mask out of the header now; `recvHeader` refers into the receive buffer.
    Mask mask = recvHeader.getMask();

    // Continuation that runs once the whole payload is at `payloadTarget`: unmask it, then
    // either stash the fragment or build the Message for this opcode.
    auto handleMessage = kj::mvCapture(message,
        [this,opcode,payloadTarget,payloadLen,mask,isFin,maxSize]
        (kj::Array<byte>&& message) -> kj::Promise<Message> {
      if (!mask.isZero()) {
        mask.apply(kj::arrayPtr(payloadTarget, payloadLen));
      }

      if (!isFin) {
        // Add fragment to the list and loop.
        auto newMax = maxSize - message.size();
        fragments.add(kj::mv(message));
        return receive(newMax);
      }

      switch (opcode) {
        case OPCODE_CONTINUATION:
          // Shouldn't get here; handled above.
          KJ_UNREACHABLE;
        case OPCODE_TEXT:
          // Fill in the NUL terminator for which the extra byte was allocated.
          message.back() = '\0';
          return Message(kj::String(message.releaseAsChars()));
        case OPCODE_BINARY:
          return Message(message.releaseAsBytes());
        case OPCODE_CLOSE:
          if (message.size() < 2) {
            // No status code in the payload; report 1005.
            return Message(Close { 1005, nullptr });
          } else {
            // First two payload bytes are the big-endian status code; the rest is the reason.
            uint16_t status = (static_cast<uint16_t>(message[0]) << 8)
                            | (static_cast<uint16_t>(message[1])     );
            return Message(Close {
              status, kj::heapString(message.slice(2, message.size()).asChars())
            });
          }
        case OPCODE_PING:
          // Send back a pong.
          queuePong(kj::mv(message));
          return receive(maxSize);
        case OPCODE_PONG:
          // Unsolicited pong. Ignore.
          return receive(maxSize);
        default:
          KJ_FAIL_REQUIRE("unknown WebSocket opcode", opcode);
      }
    });

    if (payloadLen <= recvData.size()) {
      // All data already received.
      memcpy(payloadTarget, recvData.begin(), payloadLen);
      recvData = recvData.slice(payloadLen, recvData.size());
      return handleMessage();
    } else {
      // Need to read more data.
      memcpy(payloadTarget, recvData.begin(), recvData.size());
      size_t remaining = payloadLen - recvData.size();
      // Read the rest of the payload directly into the message, bypassing recvBuffer.
      auto promise = stream->tryRead(payloadTarget + recvData.size(), remaining, remaining)
          .then([this, remaining](size_t amount) {
        receivedBytes += amount;
        if (amount < remaining) {
          kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in message"));
        }
      });
      recvData = nullptr;
      return promise.then(kj::mv(handleMessage));
    }
  }

  kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
    // Optimized pump from `other` into this socket. Applies only when `other` is also a
    // WebSocketImpl and exactly one of the two sockets masks its frames; otherwise returns
    // nullptr.
    KJ_IF_MAYBE(optOther, kj::dynamicDowncastIfAvailable<WebSocketImpl>(other)) {
      // Both WebSockets are raw WebSockets, so we can pump the streams directly rather than read
      // whole messages.

      if ((maskKeyGenerator == nullptr) == (optOther->maskKeyGenerator == nullptr)) {
        // Oops, it appears that we either believe we are the client side of both sockets, or we
        // are the server side of both sockets. Since clients must "mask" their outgoing frames but
        // servers must *not* do so, we can't direct-pump. Sad.
        return nullptr;
      }

      // Check same error conditions as with sendImpl().
      KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()");
      KJ_REQUIRE(!currentlySending, "another message send is already in progress");
      currentlySending = true;

      // If the application chooses to pump messages out, but receives incoming messages normally
      // with `receive()`, then we will receive pings and attempt to send pongs. But we can't
      // safely insert a pong in the middle of a pumped stream. We kind of don't have a choice
      // except to drop them on the floor, which is what will happen if we set `hasSentClose` true.
      // Hopefully most apps that set up a pump do so in both directions at once, and so pings will
      // flow through and pongs will flow back.
2362 hasSentClose = true; 2363 2364 return optOther->optimizedPumpTo(*this); 2365 } 2366 2367 return nullptr; 2368 } 2369 2370 uint64_t sentByteCount() override { return sentBytes; } 2371 2372 uint64_t receivedByteCount() override { return receivedBytes; } 2373 2374 private: 2375 class Mask { 2376 public: 2377 Mask(): maskBytes { 0, 0, 0, 0 } {} 2378 Mask(const byte* ptr) { memcpy(maskBytes, ptr, 4); } 2379 2380 Mask(kj::Maybe<EntropySource&> generator) { 2381 KJ_IF_MAYBE(g, generator) { 2382 g->generate(maskBytes); 2383 } else { 2384 memset(maskBytes, 0, 4); 2385 } 2386 } 2387 2388 void apply(kj::ArrayPtr<byte> bytes) const { 2389 apply(bytes.begin(), bytes.size()); 2390 } 2391 2392 void copyTo(byte* output) const { 2393 memcpy(output, maskBytes, 4); 2394 } 2395 2396 bool isZero() const { 2397 return (maskBytes[0] | maskBytes[1] | maskBytes[2] | maskBytes[3]) == 0; 2398 } 2399 2400 private: 2401 byte maskBytes[4]; 2402 2403 void apply(byte* __restrict__ bytes, size_t size) const { 2404 for (size_t i = 0; i < size; i++) { 2405 bytes[i] ^= maskBytes[i % 4]; 2406 } 2407 } 2408 }; 2409 2410 class Header { 2411 public: 2412 kj::ArrayPtr<const byte> compose(bool fin, byte opcode, uint64_t payloadLen, Mask mask) { 2413 bytes[0] = (fin ? FIN_MASK : 0) | opcode; 2414 bool hasMask = !mask.isZero(); 2415 2416 size_t fill; 2417 2418 if (payloadLen < 126) { 2419 bytes[1] = (hasMask ? USE_MASK_MASK : 0) | payloadLen; 2420 if (hasMask) { 2421 mask.copyTo(bytes + 2); 2422 fill = 6; 2423 } else { 2424 fill = 2; 2425 } 2426 } else if (payloadLen < 65536) { 2427 bytes[1] = (hasMask ? USE_MASK_MASK : 0) | 126; 2428 bytes[2] = static_cast<byte>(payloadLen >> 8); 2429 bytes[3] = static_cast<byte>(payloadLen ); 2430 if (hasMask) { 2431 mask.copyTo(bytes + 4); 2432 fill = 8; 2433 } else { 2434 fill = 4; 2435 } 2436 } else { 2437 bytes[1] = (hasMask ? 
USE_MASK_MASK : 0) | 127; 2438 bytes[2] = static_cast<byte>(payloadLen >> 56); 2439 bytes[3] = static_cast<byte>(payloadLen >> 48); 2440 bytes[4] = static_cast<byte>(payloadLen >> 40); 2441 bytes[5] = static_cast<byte>(payloadLen >> 42); 2442 bytes[6] = static_cast<byte>(payloadLen >> 24); 2443 bytes[7] = static_cast<byte>(payloadLen >> 16); 2444 bytes[8] = static_cast<byte>(payloadLen >> 8); 2445 bytes[9] = static_cast<byte>(payloadLen ); 2446 if (hasMask) { 2447 mask.copyTo(bytes + 10); 2448 fill = 14; 2449 } else { 2450 fill = 10; 2451 } 2452 } 2453 2454 return arrayPtr(bytes, fill); 2455 } 2456 2457 bool isFin() const { 2458 return bytes[0] & FIN_MASK; 2459 } 2460 2461 bool hasRsv() const { 2462 return bytes[0] & RSV_MASK; 2463 } 2464 2465 byte getOpcode() const { 2466 return bytes[0] & OPCODE_MASK; 2467 } 2468 2469 uint64_t getPayloadLen() const { 2470 byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK; 2471 if (payloadLen == 127) { 2472 return (static_cast<uint64_t>(bytes[2]) << 56) 2473 | (static_cast<uint64_t>(bytes[3]) << 48) 2474 | (static_cast<uint64_t>(bytes[4]) << 40) 2475 | (static_cast<uint64_t>(bytes[5]) << 32) 2476 | (static_cast<uint64_t>(bytes[6]) << 24) 2477 | (static_cast<uint64_t>(bytes[7]) << 16) 2478 | (static_cast<uint64_t>(bytes[8]) << 8) 2479 | (static_cast<uint64_t>(bytes[9]) ); 2480 } else if (payloadLen == 126) { 2481 return (static_cast<uint64_t>(bytes[2]) << 8) 2482 | (static_cast<uint64_t>(bytes[3]) ); 2483 } else { 2484 return payloadLen; 2485 } 2486 } 2487 2488 Mask getMask() const { 2489 if (bytes[1] & USE_MASK_MASK) { 2490 byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK; 2491 if (payloadLen == 127) { 2492 return Mask(bytes + 10); 2493 } else if (payloadLen == 126) { 2494 return Mask(bytes + 4); 2495 } else { 2496 return Mask(bytes + 2); 2497 } 2498 } else { 2499 return Mask(); 2500 } 2501 } 2502 2503 static size_t headerSize(byte const* bytes, size_t sizeSoFar) { 2504 if (sizeSoFar < 2) return 2; 2505 2506 size_t required = 2; 2507 
2508 if (bytes[1] & USE_MASK_MASK) { 2509 required += 4; 2510 } 2511 2512 byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK; 2513 if (payloadLen == 127) { 2514 required += 8; 2515 } else if (payloadLen == 126) { 2516 required += 2; 2517 } 2518 2519 return required; 2520 } 2521 2522 private: 2523 byte bytes[14]; 2524 2525 static constexpr byte FIN_MASK = 0x80; 2526 static constexpr byte RSV_MASK = 0x70; 2527 static constexpr byte OPCODE_MASK = 0x0f; 2528 2529 static constexpr byte USE_MASK_MASK = 0x80; 2530 static constexpr byte PAYLOAD_LEN_MASK = 0x7f; 2531 }; 2532 2533 static constexpr byte OPCODE_CONTINUATION = 0; 2534 static constexpr byte OPCODE_TEXT = 1; 2535 static constexpr byte OPCODE_BINARY = 2; 2536 static constexpr byte OPCODE_CLOSE = 8; 2537 static constexpr byte OPCODE_PING = 9; 2538 static constexpr byte OPCODE_PONG = 10; 2539 2540 static constexpr byte OPCODE_FIRST_CONTROL = 8; 2541 2542 // --------------------------------------------------------------------------- 2543 2544 kj::Own<kj::AsyncIoStream> stream; 2545 kj::Maybe<EntropySource&> maskKeyGenerator; 2546 2547 bool hasSentClose = false; 2548 bool disconnected = false; 2549 bool currentlySending = false; 2550 Header sendHeader; 2551 kj::ArrayPtr<const byte> sendParts[2]; 2552 2553 kj::Maybe<kj::Array<byte>> queuedPong; 2554 // If a Ping is received while currentlySending is true, then queuedPong is set to the body of 2555 // a pong message that should be sent once the current send is complete. 2556 2557 kj::Maybe<kj::Promise<void>> sendingPong; 2558 // If a Pong is being sent asynchronously in response to a Ping, this is a promise for the 2559 // completion of that send. 2560 // 2561 // Additionally, this member is used if we need to block our first send on WebSocket startup, 2562 // e.g. because we need to wait for HTTP handshake writes to flush before we can start sending 2563 // WebSocket data. `sendingPong` was overloaded for this use case because the logic is the same. 
// Perhaps it should be renamed to `blockSend` or `writeQueue`.

  uint fragmentOpcode = 0;
  kj::Vector<kj::Array<byte>> fragments;
  // If `fragments` is non-empty, we've already received some fragments of a message.
  // `fragmentOpcode` is the original opcode.

  kj::Array<byte> recvBuffer;
  kj::ArrayPtr<byte> recvData;
  // `recvBuffer` is the receive buffer; `recvData` is the received-but-unparsed portion.

  uint64_t sentBytes = 0;
  uint64_t receivedBytes = 0;

  kj::Promise<void> sendImpl(byte opcode, kj::ArrayPtr<const byte> message) {
    // Sends `message` as a single final frame with the given opcode. Only one send may be in
    // progress at a time. The frame is masked (on a private copy of the message) when the mask
    // drawn from `maskKeyGenerator` is non-zero.

    KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()");
    KJ_REQUIRE(!currentlySending, "another message send is already in progress");

    currentlySending = true;

    KJ_IF_MAYBE(p, sendingPong) {
      // We recently sent a pong, make sure it's finished before proceeding.
      auto promise = p->then([this, opcode, message]() {
        // Clear the flag so the re-entrant call passes the check above.
        currentlySending = false;
        return sendImpl(opcode, message);
      });
      sendingPong = nullptr;
      return promise;
    }

    // We don't stop the application from sending further messages after close() -- this is the
    // application's error to make. But, we do want to make sure we don't send any PONGs after a
    // close, since that would be our error. So we track whether we closed for that reason.
    hasSentClose = hasSentClose || opcode == OPCODE_CLOSE;

    Mask mask(maskKeyGenerator);

    kj::Array<byte> ownMessage;
    if (!mask.isZero()) {
      // Sadness, we have to make a copy to apply the mask.
      ownMessage = kj::heapArray(message);
      mask.apply(ownMessage);
      message = ownMessage;
    }

    sendParts[0] = sendHeader.compose(true, opcode, message.size(), mask);
    sendParts[1] = message;

    auto promise = stream->write(sendParts);
    if (!mask.isZero()) {
      // The masked copy must outlive the write.
      promise = promise.attach(kj::mv(ownMessage));
    }
    return promise.then([this, size = sendParts[0].size() + sendParts[1].size()]() {
      currentlySending = false;

      // Send queued pong if needed.
      KJ_IF_MAYBE(q, queuedPong) {
        kj::Array<byte> payload = kj::mv(*q);
        queuedPong = nullptr;
        queuePong(kj::mv(payload));
      }
      // Count header and payload bytes of the frame just written.
      sentBytes += size;
    });
  }

  void queuePong(kj::Array<byte> payload) {
    // Arranges for a Pong carrying `payload` to be sent as soon as the stream is free: deferred
    // behind the current message send, chained after a pong already being sent, or started
    // immediately.
    if (currentlySending) {
      // There is a message-send in progress, so we cannot write to the stream now.
      //
      // Note: According to spec, if the server receives a second ping before responding to the
      // previous one, it can opt to respond only to the last ping. So we don't have to check if
      // queuedPong is already non-null.
      queuedPong = kj::mv(payload);
    } else KJ_IF_MAYBE(promise, sendingPong) {
      // We're still sending a previous pong. Wait for it to finish before sending ours.
      sendingPong = promise->then(kj::mvCapture(payload, [this](kj::Array<byte> payload) mutable {
        return sendPong(kj::mv(payload));
      }));
    } else {
      // We're not sending any pong currently.
      sendingPong = sendPong(kj::mv(payload));
    }
  }

  kj::Promise<void> sendPong(kj::Array<byte> payload) {
    // Writes a Pong frame carrying `payload`. Does nothing if a Close has been sent or the
    // socket is disconnected.
    if (hasSentClose || disconnected) {
      return kj::READY_NOW;
    }

    sendParts[0] = sendHeader.compose(true, OPCODE_PONG, payload.size(), Mask(maskKeyGenerator));
    sendParts[1] = payload;
    return stream->write(sendParts).attach(kj::mv(payload));
  }

  kj::Promise<void> optimizedPumpTo(WebSocketImpl& other) {
    // Pumps raw bytes from this socket's stream into `other`'s stream, skipping frame parsing.
    // Waits for `other`'s pending pong first, then flushes any bytes already buffered here, then
    // pumps stream-to-stream until EOF, at which point `other`'s write side is shut down.

    KJ_IF_MAYBE(p, other.sendingPong) {
      // We recently sent a pong, make sure it's finished before proceeding.
      auto promise = p->then([this, &other]() {
        return optimizedPumpTo(other);
      });
      other.sendingPong = nullptr;
      return promise;
    }

    if (recvData.size() > 0) {
      // We have some data buffered. Write it first.
      return other.stream->write(recvData.begin(), recvData.size())
          .then([this, &other, size = recvData.size()]() {
        recvData = nullptr;
        other.sentBytes += size;
        return optimizedPumpTo(other);
      });
    }

    // If the destination's write end disconnects, abort this socket and fail the pump.
    auto cancelPromise = other.stream->whenWriteDisconnected()
        .then([this]() -> kj::Promise<void> {
      this->abort();
      return KJ_EXCEPTION(DISCONNECTED,
          "destination of WebSocket pump disconnected prematurely");
    });

    // There's no buffered incoming data, so start pumping stream now.
    return stream->pumpTo(*other.stream).then([this, &other](size_t s) -> kj::Promise<void> {
      // WebSocket pumps are expected to include end-of-stream.
      other.disconnected = true;
      other.stream->shutdownWrite();
      receivedBytes += s;
      other.sentBytes += s;
      return kj::READY_NOW;
    }, [&other](kj::Exception&& e) -> kj::Promise<void> {
      // We don't know if it was a read or a write that threw. If it was a read that threw, we need
      // to send a disconnect on the destination. If it was the destination that threw, it
      // shouldn't hurt to disconnect() it again, but we'll catch and squelch any exceptions.
      other.disconnected = true;
      kj::runCatchingExceptions([&other]() { other.stream->shutdownWrite(); });
      return kj::mv(e);
    }).exclusiveJoin(kj::mv(cancelPromise));
  }
};

kj::Own<WebSocket> upgradeToWebSocket(
    kj::Own<kj::AsyncIoStream> stream, HttpInputStreamImpl& httpInput, HttpOutputStream& httpOutput,
    kj::Maybe<EntropySource&> maskKeyGenerator) {
  // Create a WebSocket upgraded from an HTTP stream.
  //
  // The HTTP input's buffer and any leftover bytes in it are handed to the WebSocket, and the
  // WebSocket's first send waits on httpOutput.flush().
  auto releasedBuffer = httpInput.releaseBuffer();
  return kj::heap<WebSocketImpl>(kj::mv(stream), maskKeyGenerator,
                                 kj::mv(releasedBuffer.buffer), releasedBuffer.leftover,
                                 httpOutput.flush());
}

}  // namespace

kj::Own<WebSocket> newWebSocket(kj::Own<kj::AsyncIoStream> stream,
                                kj::Maybe<EntropySource&> maskKeyGenerator) {
  // Wraps `stream` in a WebSocketImpl with the default receive buffer and no leftover data.
  return kj::heap<WebSocketImpl>(kj::mv(stream), maskKeyGenerator);
}

static kj::Promise<void> pumpWebSocketLoop(WebSocket& from, WebSocket& to) {
  // Generic message-by-message pump: receives one message from `from`, forwards it to `to`, and
  // repeats until a Close message has been forwarded.
  return from.receive().then([&from,&to](WebSocket::Message&& message) {
    KJ_SWITCH_ONEOF(message) {
      KJ_CASE_ONEOF(text, kj::String) {
        return to.send(text)
            .attach(kj::mv(text))
            .then([&from,&to]() { return pumpWebSocketLoop(from, to); });
      }
      KJ_CASE_ONEOF(data, kj::Array<byte>) {
        return to.send(data)
            .attach(kj::mv(data))
            .then([&from,&to]() { return pumpWebSocketLoop(from, to); });
      }
      KJ_CASE_ONEOF(close, WebSocket::Close) {
        // Once a close has passed through, the pump is complete.
return to.close(close.code, close.reason)
            .attach(kj::mv(close));
      }
    }
    KJ_UNREACHABLE;
  }, [&to](kj::Exception&& e) {
    // A DISCONNECTED failure is propagated to the destination as a disconnect; any other
    // failure closes the destination with code 1002 and the exception's description as reason.
    if (e.getType() == kj::Exception::Type::DISCONNECTED) {
      return to.disconnect();
    } else {
      return to.close(1002, e.getDescription());
    }
  });
}

kj::Promise<void> WebSocket::pumpTo(WebSocket& other) {
  // Pumps all messages from this WebSocket into `other`. The destination is first offered the
  // chance to provide an optimized implementation via tryPumpFrom(); otherwise the generic
  // receive/send loop is used, raced against the destination being aborted.
  KJ_IF_MAYBE(p, other.tryPumpFrom(*this)) {
    // Yay, optimized pump!
    return kj::mv(*p);
  } else {
    // Fall back to default implementation.
    return kj::evalNow([&]() {
      // If the destination is aborted, abort this socket and fail the pump.
      auto cancelPromise = other.whenAborted().then([this]() -> kj::Promise<void> {
        this->abort();
        return KJ_EXCEPTION(DISCONNECTED,
            "destination of WebSocket pump disconnected prematurely");
      });
      return pumpWebSocketLoop(*this, other).exclusiveJoin(kj::mv(cancelPromise));
    });
  }
}

kj::Maybe<kj::Promise<void>> WebSocket::tryPumpFrom(WebSocket& other) {
  // Default: no optimized pump is available, so pumpTo() uses its generic loop.
  return nullptr;
}

namespace {

class WebSocketPipeImpl final: public WebSocket, public kj::Refcounted {
  // Represents one direction of a WebSocket pipe.
  //
  // This class behaves as a "loopback" WebSocket: a message sent using send() is received using
  // receive(), on the same object. This is *not* how WebSocket implementations usually behave.
  // But, this object is actually used to implement only one direction of a bidirectional pipe. At
  // another layer above this, the pipe is actually composed of two WebSocketPipeEnd instances,
  // which layer on top of two WebSocketPipeImpl instances representing the two directions. So,
  // send() calls on a WebSocketPipeImpl instance always come from one of the two WebSocketPipeEnds
  // while receive() calls come from the other end.
2782 2783 public: 2784 ~WebSocketPipeImpl() noexcept(false) { 2785 KJ_REQUIRE(state == nullptr || ownState.get() != nullptr, 2786 "destroying WebSocketPipe with operation still in-progress; probably going to segfault") { 2787 // Don't std::terminate(). 2788 break; 2789 } 2790 } 2791 2792 void abort() override { 2793 KJ_IF_MAYBE(s, state) { 2794 s->abort(); 2795 } else { 2796 ownState = heap<Aborted>(); 2797 state = *ownState; 2798 2799 aborted = true; 2800 KJ_IF_MAYBE(f, abortedFulfiller) { 2801 f->get()->fulfill(); 2802 abortedFulfiller = nullptr; 2803 } 2804 } 2805 } 2806 2807 kj::Promise<void> send(kj::ArrayPtr<const byte> message) override { 2808 KJ_IF_MAYBE(s, state) { 2809 return s->send(message).then([&, size = message.size()]() { transferredBytes += size; }); 2810 } else { 2811 return newAdaptedPromise<void, BlockedSend>(*this, MessagePtr(message)) 2812 .then([&, size = message.size()]() { transferredBytes += size; }); 2813 } 2814 } 2815 kj::Promise<void> send(kj::ArrayPtr<const char> message) override { 2816 KJ_IF_MAYBE(s, state) { 2817 return s->send(message).then([&, size = message.size()]() { transferredBytes += size; }); 2818 } else { 2819 return newAdaptedPromise<void, BlockedSend>(*this, MessagePtr(message)) 2820 .then([&, size = message.size()]() { transferredBytes += size; }); 2821 } 2822 } 2823 kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override { 2824 KJ_IF_MAYBE(s, state) { 2825 return s->close(code, reason) 2826 .then([&, size = reason.size()]() { transferredBytes += (2 +size); }); 2827 } else { 2828 return newAdaptedPromise<void, BlockedSend>(*this, MessagePtr(ClosePtr { code, reason })) 2829 .then([&, size = reason.size()]() { transferredBytes += (2 +size); }); 2830 } 2831 } 2832 kj::Promise<void> disconnect() override { 2833 KJ_IF_MAYBE(s, state) { 2834 return s->disconnect(); 2835 } else { 2836 ownState = heap<Disconnected>(); 2837 state = *ownState; 2838 return kj::READY_NOW; 2839 } 2840 } 2841 kj::Promise<void> 
whenAborted() override { 2842 if (aborted) { 2843 return kj::READY_NOW; 2844 } else KJ_IF_MAYBE(p, abortedPromise) { 2845 return p->addBranch(); 2846 } else { 2847 auto paf = newPromiseAndFulfiller<void>(); 2848 abortedFulfiller = kj::mv(paf.fulfiller); 2849 auto fork = paf.promise.fork(); 2850 auto result = fork.addBranch(); 2851 abortedPromise = kj::mv(fork); 2852 return result; 2853 } 2854 } 2855 kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override { 2856 KJ_IF_MAYBE(s, state) { 2857 return s->tryPumpFrom(other); 2858 } else { 2859 return newAdaptedPromise<void, BlockedPumpFrom>(*this, other); 2860 } 2861 } 2862 2863 kj::Promise<Message> receive(size_t maxSize) override { 2864 KJ_IF_MAYBE(s, state) { 2865 return s->receive(maxSize); 2866 } else { 2867 return newAdaptedPromise<Message, BlockedReceive>(*this, maxSize); 2868 } 2869 } 2870 kj::Promise<void> pumpTo(WebSocket& other) override { 2871 KJ_IF_MAYBE(s, state) { 2872 auto before = other.receivedByteCount(); 2873 return s->pumpTo(other).attach(kj::defer([this, &other, before]() { 2874 transferredBytes += other.receivedByteCount() - before; 2875 })); 2876 } else { 2877 return newAdaptedPromise<void, BlockedPumpTo>(*this, other); 2878 } 2879 } 2880 2881 uint64_t sentByteCount() override { 2882 return transferredBytes; 2883 } 2884 uint64_t receivedByteCount() override { 2885 return transferredBytes; 2886 } 2887 2888 private: 2889 kj::Maybe<WebSocket&> state; 2890 // Object-oriented state! If any method call is blocked waiting on activity from the other end, 2891 // then `state` is non-null and method calls should be forwarded to it. If no calls are 2892 // outstanding, `state` is null. 
// Interior of the pipe implementation class (its header is above this span): remaining data
// members, the endState() helper, and the nested per-operation state classes.
  kj::Own<WebSocket> ownState;
  // NOTE(review): presumably owns a heap-allocated state object; the code that assigns it is not
  // visible in this span -- confirm.

  uint64_t transferredBytes = 0;

  bool aborted = false;
  Maybe<Own<PromiseFulfiller<void>>> abortedFulfiller = nullptr;
  Maybe<ForkedPromise<void>> abortedPromise = nullptr;
  // NOTE(review): these three presumably back abort()/whenAborted() on the pipe itself, which the
  // nested states below describe as "implemented by WebSocketPipeImpl" -- confirm.

  void endState(WebSocket& obj) {
    // Clears `state`, but only if it currently refers to `obj`. Each state object calls this on
    // itself when it finishes (and from its destructor), so a state that has already been
    // replaced never un-registers its successor.
    KJ_IF_MAYBE(s, state) {
      if (s == &obj) {
        state = nullptr;
      }
    }
  }

  struct ClosePtr {
    // A Close message described by reference: status code plus a StringPtr reason. BlockedSend
    // copies the reason with kj::str() when it hands the message to a receiver.
    uint16_t code;
    kj::StringPtr reason;
  };
  typedef kj::OneOf<kj::ArrayPtr<const char>, kj::ArrayPtr<const byte>, ClosePtr> MessagePtr;
  // A pending outgoing message held by pointer: chars (delivered as a string), bytes (delivered
  // as a byte array), or a close.

  class BlockedSend final: public WebSocket {
    // State installed while one outgoing message is waiting to be consumed. The constructor
    // registers this object as `pipe.state` (the pipe must have no state at that point).
    // `fulfiller` is fulfilled once the message has been handed over, and rejected on abort or
    // on a failed pump.

  public:
    BlockedSend(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe, MessagePtr message)
        : fulfiller(fulfiller), pipe(pipe), message(kj::mv(message)) {
      KJ_REQUIRE(pipe.state == nullptr);
      pipe.state = *this;
    }
    ~BlockedSend() noexcept(false) {
      pipe.endState(*this);
    }

    void abort() override {
      // Cancel any pump wrapped by `canceler`, fail the waiting sender with DISCONNECTED,
      // un-register, then let the pipe perform its own abort handling.
      canceler.cancel("other end of WebSocketPipe was destroyed");
      fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"));
      pipe.endState(*this);
      pipe.abort();
    }
    kj::Promise<void> whenAborted() override {
      KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
    }

    // All send-side operations fail while in this state: a message is already pending.
    kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
      KJ_FAIL_ASSERT("another message send is already in progress");
    }
    kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
      KJ_FAIL_ASSERT("another message send is already in progress");
    }
    kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
      KJ_FAIL_ASSERT("another message send is already in progress");
    }
    kj::Promise<void> disconnect() override {
      KJ_FAIL_ASSERT("another message send is already in progress");
    }
    kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
      KJ_FAIL_ASSERT("another message send is already in progress");
    }

    kj::Promise<Message> receive(size_t maxSize) override {
      // Delivers the pending message: fulfill the sender, un-register this state, and return an
      // owned copy of the message. `maxSize` is not consulted on this path.
      KJ_REQUIRE(canceler.isEmpty(), "already pumping");
      fulfiller.fulfill();
      pipe.endState(*this);
      KJ_SWITCH_ONEOF(message) {
        KJ_CASE_ONEOF(arr, kj::ArrayPtr<const char>) {
          return Message(kj::str(arr));
        }
        KJ_CASE_ONEOF(arr, kj::ArrayPtr<const byte>) {
          auto copy = kj::heapArray<byte>(arr.size());
          memcpy(copy.begin(), arr.begin(), arr.size());
          return Message(kj::mv(copy));
        }
        KJ_CASE_ONEOF(close, ClosePtr) {
          return Message(Close { close.code, kj::str(close.reason) });
        }
      }
      KJ_UNREACHABLE;
    }
    kj::Promise<void> pumpTo(WebSocket& other) override {
      // Forwards the pending message straight to `other` (no copy). On success: fulfill the
      // sender, un-register, and continue the pump through the pipe. On failure: reject the
      // sender with a copy of the exception and propagate the original to the pump's caller.
      KJ_REQUIRE(canceler.isEmpty(), "already pumping");
      kj::Promise<void> promise = nullptr;
      KJ_SWITCH_ONEOF(message) {
        KJ_CASE_ONEOF(arr, kj::ArrayPtr<const char>) {
          promise = other.send(arr);
        }
        KJ_CASE_ONEOF(arr, kj::ArrayPtr<const byte>) {
          promise = other.send(arr);
        }
        KJ_CASE_ONEOF(close, ClosePtr) {
          promise = other.close(close.code, close.reason);
        }
      }
      return canceler.wrap(promise.then([this,&other]() {
        canceler.release();
        fulfiller.fulfill();
        pipe.endState(*this);
        return pipe.pumpTo(other);
      }, [this](kj::Exception&& e) -> kj::Promise<void> {
        canceler.release();
        fulfiller.reject(kj::cp(e));
        pipe.endState(*this);
        return kj::mv(e);
      }));
    }

    uint64_t sentByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }
    uint64_t receivedByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }

  private:
    kj::PromiseFulfiller<void>& fulfiller;
    WebSocketPipeImpl& pipe;
    MessagePtr message;
    Canceler canceler;
  };

  class BlockedPumpFrom final: public WebSocket {
    // State installed while messages are being pumped from `input` into the pipe. Receive-side
    // calls are served directly from `input`. `fulfiller` is fulfilled when the pump ends and
    // rejected on abort or error.

  public:
    BlockedPumpFrom(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe,
                    WebSocket& input)
        : fulfiller(fulfiller), pipe(pipe), input(input) {
      KJ_REQUIRE(pipe.state == nullptr);
      pipe.state = *this;
    }
    ~BlockedPumpFrom() noexcept(false) {
      pipe.endState(*this);
    }

    void abort() override {
      // Same sequence as BlockedSend::abort(): cancel, reject with DISCONNECTED, un-register,
      // then abort the pipe.
      canceler.cancel("other end of WebSocketPipe was destroyed");
      fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"));
      pipe.endState(*this);
      pipe.abort();
    }
    kj::Promise<void> whenAborted() override {
      KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
    }

    // Send-side operations fail while a pump into the pipe is active.
    kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
      KJ_FAIL_ASSERT("another message send is already in progress");
    }
    kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
      KJ_FAIL_ASSERT("another message send is already in progress");
    }
    kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
      KJ_FAIL_ASSERT("another message send is already in progress");
    }
    kj::Promise<void> disconnect() override {
      KJ_FAIL_ASSERT("another message send is already in progress");
    }
    kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
      KJ_FAIL_ASSERT("another message send is already in progress");
    }

    kj::Promise<Message> receive(size_t maxSize) override {
      // Pulls one message directly from `input`. A Close message ends the pump: the fulfiller is
      // fulfilled and this state un-registers. Any other message leaves the state in place.
      KJ_REQUIRE(canceler.isEmpty(), "another message receive is already in progress");
      return canceler.wrap(input.receive(maxSize)
          .then([this](Message message) {
        if (message.is<Close>()) {
          canceler.release();
          fulfiller.fulfill();
          pipe.endState(*this);
        }
        return kj::mv(message);
      }, [this](kj::Exception&& e) -> Message {
        canceler.release();
        fulfiller.reject(kj::cp(e));
        pipe.endState(*this);
        kj::throwRecoverableException(kj::mv(e));
        // Reached only if throwRecoverableException() returns.
        return Message(kj::String());
      }));
    }
    kj::Promise<void> pumpTo(WebSocket& other) override {
      // Connects `input` directly to `other`; the pump into the pipe completes (or fails) with
      // the result of that direct pump.
      KJ_REQUIRE(canceler.isEmpty(), "another message receive is already in progress");
      return canceler.wrap(input.pumpTo(other)
          .then([this]() {
        canceler.release();
        fulfiller.fulfill();
        pipe.endState(*this);
      }, [this](kj::Exception&& e) {
        canceler.release();
        fulfiller.reject(kj::cp(e));
        pipe.endState(*this);
        kj::throwRecoverableException(kj::mv(e));
      }));
    }

    uint64_t sentByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }
    uint64_t receivedByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }

  private:
    kj::PromiseFulfiller<void>& fulfiller;
    WebSocketPipeImpl& pipe;
    WebSocket& input;
    Canceler canceler;
  };

  class BlockedReceive final: public WebSocket {
    // State installed while a receive is waiting for a message. Send-side calls complete that
    // receive by fulfilling `fulfiller` with an owned copy of the message.

  public:
    BlockedReceive(kj::PromiseFulfiller<Message>& fulfiller, WebSocketPipeImpl& pipe,
                   size_t maxSize)
        : fulfiller(fulfiller), pipe(pipe), maxSize(maxSize) {
      KJ_REQUIRE(pipe.state == nullptr);
      pipe.state = *this;
    }
    ~BlockedReceive() noexcept(false) {
      pipe.endState(*this);
    }

    void abort() override {
      // Cancel any in-progress pump, fail the waiting receive with DISCONNECTED, un-register,
      // then abort the pipe.
      canceler.cancel("other end of WebSocketPipe was destroyed");
      fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"));
      pipe.endState(*this);
      pipe.abort();
    }
    kj::Promise<void> whenAborted() override {
      KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
    }

    kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
      // Copy the bytes, hand them to the waiting receive, and complete immediately.
      KJ_REQUIRE(canceler.isEmpty(), "already pumping");
      auto copy = kj::heapArray<byte>(message.size());
      memcpy(copy.begin(), message.begin(), message.size());
      fulfiller.fulfill(Message(kj::mv(copy)));
      pipe.endState(*this);
      return kj::READY_NOW;
    }
    kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
      KJ_REQUIRE(canceler.isEmpty(), "already pumping");
      fulfiller.fulfill(Message(kj::str(message)));
      pipe.endState(*this);
      return kj::READY_NOW;
    }
    kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
      KJ_REQUIRE(canceler.isEmpty(), "already pumping");
      fulfiller.fulfill(Message(Close { code, kj::str(reason) }));
      pipe.endState(*this);
      return kj::READY_NOW;
    }
    kj::Promise<void> disconnect() override {
      // The waiting receive fails with DISCONNECTED; after un-registering, the disconnect is
      // forwarded to the pipe.
      KJ_REQUIRE(canceler.isEmpty(), "already pumping");
      fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "WebSocket disconnected"));
      pipe.endState(*this);
      return pipe.disconnect();
    }
    kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
      // Satisfy the waiting receive with the first message read from `other` (limited to the
      // receiver's maxSize), then continue by pumping `other` into the pipe.
      KJ_REQUIRE(canceler.isEmpty(), "already pumping");
      return canceler.wrap(other.receive(maxSize).then([this,&other](Message message) {
        canceler.release();
        fulfiller.fulfill(kj::mv(message));
        pipe.endState(*this);
        return other.pumpTo(pipe);
      }, [this](kj::Exception&& e) -> kj::Promise<void> {
        canceler.release();
        fulfiller.reject(kj::cp(e));
        pipe.endState(*this);
        return kj::mv(e);
      }));
    }

    // Receive-side operations fail while a receive is already waiting.
    kj::Promise<Message> receive(size_t maxSize) override {
      KJ_FAIL_ASSERT("another message receive is already in progress");
    }
    kj::Promise<void> pumpTo(WebSocket& other) override {
      KJ_FAIL_ASSERT("another message receive is already in progress");
    }

    uint64_t sentByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }
    uint64_t receivedByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }

// Data members of BlockedReceive (its public interface precedes this span), followed by
// BlockedPumpTo and the first members of Disconnected.
  private:
    kj::PromiseFulfiller<Message>& fulfiller;
    WebSocketPipeImpl& pipe;
    size_t maxSize;
    Canceler canceler;
  };

  class BlockedPumpTo final: public WebSocket {
    // State installed while the pipe is being pumped to `output`. Send-side calls are forwarded
    // directly to `output`. `fulfiller` is fulfilled when the pump ends.

  public:
    BlockedPumpTo(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe, WebSocket& output)
        : fulfiller(fulfiller), pipe(pipe), output(output) {
      KJ_REQUIRE(pipe.state == nullptr);
      pipe.state = *this;
    }
    ~BlockedPumpTo() noexcept(false) {
      pipe.endState(*this);
    }

    void abort() override {
      canceler.cancel("other end of WebSocketPipe was destroyed");

      // abort() is called when the pipe end is dropped. This should be treated as disconnecting,
      // so pumpTo() should complete normally.
      fulfiller.fulfill();

      pipe.endState(*this);
      pipe.abort();
    }
    kj::Promise<void> whenAborted() override {
      KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
    }

    // Plain sends are forwarded to `output` and leave this state installed. The canceler both
    // enforces one-operation-at-a-time and lets abort() cancel the forwarded call.
    kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
      KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
      return canceler.wrap(output.send(message));
    }
    kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
      KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
      return canceler.wrap(output.send(message));
    }
    kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
      KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
      return canceler.wrap(output.close(code, reason).then([this]() {
        // A pump is expected to end upon seeing a Close message.
        canceler.release();
        pipe.endState(*this);
        fulfiller.fulfill();
      }));
    }
    kj::Promise<void> disconnect() override {
      // Forward the disconnect to `output`, end the pump, then forward the disconnect to the
      // pipe as well.
      KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
      return canceler.wrap(output.disconnect().then([this]() {
        canceler.release();
        pipe.endState(*this);
        fulfiller.fulfill();
        return pipe.disconnect();
      }));
    }
    kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
      // Pump `other` directly into `output`; when that finishes, this pump is finished too.
      KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
      return canceler.wrap(other.pumpTo(output).then([this]() {
        canceler.release();
        pipe.endState(*this);
        fulfiller.fulfill();
      }));
    }

    // Receive-side operations fail while a pump out of the pipe is active.
    kj::Promise<Message> receive(size_t maxSize) override {
      KJ_FAIL_ASSERT("another message receive is already in progress");
    }
    kj::Promise<void> pumpTo(WebSocket& other) override {
      KJ_FAIL_ASSERT("another message receive is already in progress");
    }

    uint64_t sentByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }
    uint64_t receivedByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }

  private:
    kj::PromiseFulfiller<void>& fulfiller;
    WebSocketPipeImpl& pipe;
    WebSocket& output;
    Canceler canceler;
  };

  class Disconnected final: public WebSocket {
    // State with no data members, named for the condition after disconnect(): further sends are
    // usage errors (KJ_FAIL_REQUIRE).

  public:
    void abort() override {
      // can ignore
    }
    kj::Promise<void> whenAborted() override {
      KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
    }

    kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
      KJ_FAIL_REQUIRE("can't send() after disconnect()");
    }
    kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
      KJ_FAIL_REQUIRE("can't send() after disconnect()");
    }
// Remaining members of Disconnected (its first members precede this span), followed by the
// first members of Aborted.
    kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
      KJ_FAIL_REQUIRE("can't close() after disconnect()");
    }
    kj::Promise<void> disconnect() override {
      // Already disconnected; repeating the call is a no-op.
      return kj::READY_NOW;
    }
    kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
      KJ_FAIL_REQUIRE("can't tryPumpFrom() after disconnect()");
    }

    kj::Promise<Message> receive(size_t maxSize) override {
      // A receive in this state fails with DISCONNECTED...
      return KJ_EXCEPTION(DISCONNECTED, "WebSocket disconnected");
    }
    kj::Promise<void> pumpTo(WebSocket& other) override {
      // ...whereas a pump in this state simply completes.
      return kj::READY_NOW;
    }

    uint64_t sentByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }
    uint64_t receivedByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }

  };

  class Aborted final: public WebSocket {
    // State with no data members in which every send-side operation returns a DISCONNECTED
    // exception ("other end of WebSocketPipe was destroyed") instead of asserting.

  public:
    void abort() override {
      // can ignore
    }
    kj::Promise<void> whenAborted() override {
      KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
    }

    kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
      return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed");
    }
    kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
      return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed");
    }
    kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
      return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed");
    }
    kj::Promise<void> disconnect() override {
      return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed");
    }
    kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
      // The exception is wrapped in an explicit Promise so that the Maybe returned is non-null
      // and carries the failure.
      return kj::Promise<void>(KJ_EXCEPTION(DISCONNECTED,
          "other end of WebSocketPipe was destroyed"));
    }

// Remaining members of Aborted (its first members precede this span), the end of the enclosing
// pipe class, and the public interface of WebSocketPipeEnd.
    kj::Promise<Message> receive(size_t maxSize) override {
      return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed");
    }
    kj::Promise<void> pumpTo(WebSocket& other) override {
      return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed");
    }

    uint64_t sentByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }
    uint64_t receivedByteCount() override {
      KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
    }
  };
};

class WebSocketPipeEnd final: public WebSocket {
  // One end of a two-way WebSocket pipe, built from two one-way pipes: everything received comes
  // from `in`, everything sent goes to `out`. Destroying or aborting the end aborts both.

public:
  WebSocketPipeEnd(kj::Own<WebSocketPipeImpl> in, kj::Own<WebSocketPipeImpl> out)
      : in(kj::mv(in)), out(kj::mv(out)) {}
  ~WebSocketPipeEnd() noexcept(false) {
    in->abort();
    out->abort();
  }

  // Send side: forwarded to `out`.
  kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
    return out->send(message);
  }
  kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
    return out->send(message);
  }
  kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
    return out->close(code, reason);
  }
  kj::Promise<void> disconnect() override {
    return out->disconnect();
  }
  void abort() override {
    in->abort();
    out->abort();
  }
  kj::Promise<void> whenAborted() override {
    return out->whenAborted();
  }
  kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
    return out->tryPumpFrom(other);
  }

  // Receive side: forwarded to `in`.
  kj::Promise<Message> receive(size_t maxSize) override {
    return in->receive(maxSize);
  }
  kj::Promise<void> pumpTo(WebSocket& other) override {
    return in->pumpTo(other);
  }

  uint64_t sentByteCount() override { return out->sentByteCount(); }
  // What this end has received is what was sent into its inbound pipe, hence sentByteCount()
  // on `in`.
  uint64_t receivedByteCount() override { return in->sentByteCount(); }

private:
kj::Own<WebSocketPipeImpl> in; 3390 kj::Own<WebSocketPipeImpl> out; 3391 }; 3392 3393 } // namespace 3394 3395 WebSocketPipe newWebSocketPipe() { 3396 auto pipe1 = kj::refcounted<WebSocketPipeImpl>(); 3397 auto pipe2 = kj::refcounted<WebSocketPipeImpl>(); 3398 3399 auto end1 = kj::heap<WebSocketPipeEnd>(kj::addRef(*pipe1), kj::addRef(*pipe2)); 3400 auto end2 = kj::heap<WebSocketPipeEnd>(kj::mv(pipe2), kj::mv(pipe1)); 3401 3402 return { { kj::mv(end1), kj::mv(end2) } }; 3403 } 3404 3405 // ======================================================================================= 3406 3407 namespace { 3408 3409 class HttpClientImpl final: public HttpClient, 3410 private HttpClientErrorHandler { 3411 public: 3412 HttpClientImpl(const HttpHeaderTable& responseHeaderTable, kj::Own<kj::AsyncIoStream> rawStream, 3413 HttpClientSettings settings) 3414 : httpInput(*rawStream, responseHeaderTable), 3415 httpOutput(*rawStream), 3416 ownStream(kj::mv(rawStream)), 3417 settings(kj::mv(settings)) {} 3418 3419 bool canReuse() { 3420 // Returns true if we can immediately reuse this HttpClient for another message (so all 3421 // previous messages have been fully read). 
3422 3423 return !upgraded && !closed && httpInput.canReuse() && httpOutput.canReuse(); 3424 } 3425 3426 Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, 3427 kj::Maybe<uint64_t> expectedBodySize = nullptr) override { 3428 KJ_REQUIRE(!upgraded, 3429 "can't make further requests on this HttpClient because it has been or is in the process " 3430 "of being upgraded"); 3431 KJ_REQUIRE(!closed, 3432 "this HttpClient's connection has been closed by the server or due to an error"); 3433 KJ_REQUIRE(httpOutput.canReuse(), 3434 "can't start new request until previous request body has been fully written"); 3435 closeWatcherTask = nullptr; 3436 3437 kj::StringPtr connectionHeaders[HttpHeaders::CONNECTION_HEADERS_COUNT]; 3438 kj::String lengthStr; 3439 3440 bool isGet = method == HttpMethod::GET || method == HttpMethod::HEAD; 3441 bool hasBody; 3442 3443 KJ_IF_MAYBE(s, expectedBodySize) { 3444 if (isGet && *s == 0) { 3445 // GET with empty body; don't send any Content-Length. 3446 hasBody = false; 3447 } else { 3448 lengthStr = kj::str(*s); 3449 connectionHeaders[HttpHeaders::BuiltinIndices::CONTENT_LENGTH] = lengthStr; 3450 hasBody = true; 3451 } 3452 } else { 3453 if (isGet && headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr) { 3454 // GET with empty body; don't send any Transfer-Encoding. 3455 hasBody = false; 3456 } else { 3457 // HACK: Normally GET requests shouldn't have bodies. But, if the caller set a 3458 // Transfer-Encoding header on a GET, we use this as a special signal that it might 3459 // actually want to send a body. This allows pass-through of a GET request with a chunked 3460 // body to "just work". We strongly discourage writing any new code that sends 3461 // full-bodied GETs. 
3462 connectionHeaders[HttpHeaders::BuiltinIndices::TRANSFER_ENCODING] = "chunked"; 3463 hasBody = true; 3464 } 3465 } 3466 3467 httpOutput.writeHeaders(headers.serializeRequest(method, url, connectionHeaders)); 3468 3469 kj::Own<kj::AsyncOutputStream> bodyStream; 3470 if (!hasBody) { 3471 // No entity-body. 3472 httpOutput.finishBody(); 3473 bodyStream = heap<HttpNullEntityWriter>(); 3474 } else KJ_IF_MAYBE(s, expectedBodySize) { 3475 bodyStream = heap<HttpFixedLengthEntityWriter>(httpOutput, *s); 3476 } else { 3477 bodyStream = heap<HttpChunkedEntityWriter>(httpOutput); 3478 } 3479 3480 auto id = ++counter; 3481 3482 auto responsePromise = httpInput.readResponseHeaders().then( 3483 [this,method,id](HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError) 3484 -> HttpClient::Response { 3485 KJ_SWITCH_ONEOF(responseOrProtocolError) { 3486 KJ_CASE_ONEOF(response, HttpHeaders::Response) { 3487 auto& responseHeaders = httpInput.getHeaders(); 3488 HttpClient::Response result { 3489 response.statusCode, 3490 response.statusText, 3491 &responseHeaders, 3492 httpInput.getEntityBody( 3493 HttpInputStreamImpl::RESPONSE, method, response.statusCode, responseHeaders) 3494 }; 3495 3496 if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>( 3497 responseHeaders.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) { 3498 closed = true; 3499 } else if (counter == id) { 3500 watchForClose(); 3501 } else { 3502 // Another request was already queued after this one, so we don't want to watch for 3503 // stream closure because we're fully expecting another response. 
3504 } 3505 return result; 3506 } 3507 KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { 3508 closed = true; 3509 return settings.errorHandler.orDefault(*this).handleProtocolError( 3510 kj::mv(protocolError)); 3511 } 3512 } 3513 3514 KJ_UNREACHABLE; 3515 }); 3516 3517 return { kj::mv(bodyStream), kj::mv(responsePromise) }; 3518 } 3519 3520 kj::Promise<WebSocketResponse> openWebSocket( 3521 kj::StringPtr url, const HttpHeaders& headers) override { 3522 KJ_REQUIRE(!upgraded, 3523 "can't make further requests on this HttpClient because it has been or is in the process " 3524 "of being upgraded"); 3525 KJ_REQUIRE(!closed, 3526 "this HttpClient's connection has been closed by the server or due to an error"); 3527 closeWatcherTask = nullptr; 3528 3529 // Mark upgraded for now, even though the upgrade could fail, because we can't allow pipelined 3530 // requests in the meantime. 3531 upgraded = true; 3532 3533 byte keyBytes[16]; 3534 KJ_ASSERT_NONNULL(settings.entropySource, 3535 "can't use openWebSocket() because no EntropySource was provided when creating the " 3536 "HttpClient").generate(keyBytes); 3537 auto keyBase64 = kj::encodeBase64(keyBytes); 3538 3539 kj::StringPtr connectionHeaders[HttpHeaders::WEBSOCKET_CONNECTION_HEADERS_COUNT]; 3540 connectionHeaders[HttpHeaders::BuiltinIndices::CONNECTION] = "Upgrade"; 3541 connectionHeaders[HttpHeaders::BuiltinIndices::UPGRADE] = "websocket"; 3542 connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_VERSION] = "13"; 3543 connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_KEY] = keyBase64; 3544 3545 httpOutput.writeHeaders(headers.serializeRequest(HttpMethod::GET, url, connectionHeaders)); 3546 3547 // No entity-body. 
3548 httpOutput.finishBody(); 3549 3550 auto id = ++counter; 3551 3552 return httpInput.readResponseHeaders() 3553 .then([this,id,keyBase64 = kj::mv(keyBase64)]( 3554 HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError) 3555 -> HttpClient::WebSocketResponse { 3556 KJ_SWITCH_ONEOF(responseOrProtocolError) { 3557 KJ_CASE_ONEOF(response, HttpHeaders::Response) { 3558 auto& responseHeaders = httpInput.getHeaders(); 3559 if (response.statusCode == 101) { 3560 if (!fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>( 3561 responseHeaders.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) { 3562 kj::String ownMessage; 3563 kj::StringPtr message; 3564 KJ_IF_MAYBE(actual, responseHeaders.get(HttpHeaderId::UPGRADE)) { 3565 ownMessage = kj::str( 3566 "Server failed WebSocket handshake: incorrect Upgrade header: " 3567 "expected 'websocket', got '", *actual, "'."); 3568 message = ownMessage; 3569 } else { 3570 message = "Server failed WebSocket handshake: missing Upgrade header."; 3571 } 3572 return settings.errorHandler.orDefault(*this).handleWebSocketProtocolError({ 3573 502, "Bad Gateway", message, nullptr 3574 }); 3575 } 3576 3577 auto expectedAccept = generateWebSocketAccept(keyBase64); 3578 if (responseHeaders.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT).orDefault(nullptr) 3579 != expectedAccept) { 3580 kj::String ownMessage; 3581 kj::StringPtr message; 3582 KJ_IF_MAYBE(actual, responseHeaders.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT)) { 3583 ownMessage = kj::str( 3584 "Server failed WebSocket handshake: incorrect Sec-WebSocket-Accept header: " 3585 "expected '", expectedAccept, "', got '", *actual, "'."); 3586 message = ownMessage; 3587 } else { 3588 message = "Server failed WebSocket handshake: missing Upgrade header."; 3589 } 3590 return settings.errorHandler.orDefault(*this).handleWebSocketProtocolError({ 3591 502, "Bad Gateway", message, nullptr 3592 }); 3593 } 3594 3595 return { 3596 response.statusCode, 3597 response.statusText, 3598 
&httpInput.getHeaders(), 3599 upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput, settings.entropySource), 3600 }; 3601 } else { 3602 upgraded = false; 3603 HttpClient::WebSocketResponse result { 3604 response.statusCode, 3605 response.statusText, 3606 &responseHeaders, 3607 httpInput.getEntityBody(HttpInputStreamImpl::RESPONSE, HttpMethod::GET, 3608 response.statusCode, responseHeaders) 3609 }; 3610 if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>( 3611 responseHeaders.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) { 3612 closed = true; 3613 } else if (counter == id) { 3614 watchForClose(); 3615 } else { 3616 // Another request was already queued after this one, so we don't want to watch for 3617 // stream closure because we're fully expecting another response. 3618 } 3619 return result; 3620 } 3621 } 3622 KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { 3623 return settings.errorHandler.orDefault(*this).handleWebSocketProtocolError( 3624 kj::mv(protocolError)); 3625 } 3626 } 3627 3628 KJ_UNREACHABLE; 3629 }); 3630 } 3631 3632 private: 3633 HttpInputStreamImpl httpInput; 3634 HttpOutputStream httpOutput; 3635 kj::Own<AsyncIoStream> ownStream; 3636 HttpClientSettings settings; 3637 kj::Maybe<kj::Promise<void>> closeWatcherTask; 3638 bool upgraded = false; 3639 bool closed = false; 3640 3641 uint counter = 0; 3642 // Counts requests for the sole purpose of detecting if more requests have been made after some 3643 // point in history. 3644 3645 void watchForClose() { 3646 closeWatcherTask = httpInput.awaitNextMessage() 3647 .then([this](bool hasData) -> kj::Promise<void> { 3648 if (hasData) { 3649 // Uhh... The server sent some data before we asked for anything. Perhaps due to properties 3650 // of this application, the server somehow already knows what the next request will be, and 3651 // it is trying to optimize. Or maybe this is some sort of test and the server is just 3652 // replaying a script. 
In any case, we will humor it -- leave the data in the buffer and 3653 // let it become the response to the next request. 3654 return kj::READY_NOW; 3655 } else { 3656 // EOF -- server disconnected. 3657 closed = true; 3658 if (httpOutput.isInBody()) { 3659 // Huh, the application is still sending a request. We should let it finish. We do not 3660 // need to proactively free the socket in this case because we know that we're not 3661 // sitting in a reusable connection pool, because we know the application is still 3662 // actively using the connection. 3663 return kj::READY_NOW; 3664 } else { 3665 return httpOutput.flush().then([this]() { 3666 // We might be sitting in NetworkAddressHttpClient's `availableClients` pool. We don't 3667 // have a way to notify it to remove this client from the pool; instead, when it tries 3668 // to pull this client from the pool later, it will notice the client is dead and will 3669 // discard it then. But, we would like to avoid holding on to a socket forever. So, 3670 // destroy the socket now. 3671 // TODO(cleanup): Maybe we should arrange to proactively remove ourselves? Seems 3672 // like the code will be awkward. 
3673 ownStream = nullptr; 3674 }); 3675 } 3676 } 3677 }).eagerlyEvaluate(nullptr); 3678 } 3679 }; 3680 3681 } // namespace 3682 3683 kj::Promise<HttpClient::WebSocketResponse> HttpClient::openWebSocket( 3684 kj::StringPtr url, const HttpHeaders& headers) { 3685 return request(HttpMethod::GET, url, headers, nullptr) 3686 .response.then([](HttpClient::Response&& response) -> WebSocketResponse { 3687 kj::OneOf<kj::Own<kj::AsyncInputStream>, kj::Own<WebSocket>> body; 3688 body.init<kj::Own<kj::AsyncInputStream>>(kj::mv(response.body)); 3689 3690 return { 3691 response.statusCode, 3692 response.statusText, 3693 response.headers, 3694 kj::mv(body) 3695 }; 3696 }); 3697 } 3698 3699 kj::Promise<kj::Own<kj::AsyncIoStream>> HttpClient::connect(kj::StringPtr host) { 3700 KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpClient"); 3701 } 3702 3703 kj::Own<HttpClient> newHttpClient( 3704 const HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream, 3705 HttpClientSettings settings) { 3706 return kj::heap<HttpClientImpl>(responseHeaderTable, 3707 kj::Own<kj::AsyncIoStream>(&stream, kj::NullDisposer::instance), 3708 kj::mv(settings)); 3709 } 3710 3711 HttpClient::Response HttpClientErrorHandler::handleProtocolError( 3712 HttpHeaders::ProtocolError protocolError) { 3713 KJ_FAIL_REQUIRE(protocolError.description) { break; } 3714 return HttpClient::Response(); 3715 } 3716 3717 HttpClient::WebSocketResponse HttpClientErrorHandler::handleWebSocketProtocolError( 3718 HttpHeaders::ProtocolError protocolError) { 3719 auto response = handleProtocolError(protocolError); 3720 return HttpClient::WebSocketResponse { 3721 response.statusCode, response.statusText, response.headers, kj::mv(response.body) 3722 }; 3723 } 3724 3725 // ======================================================================================= 3726 3727 namespace { 3728 3729 class NetworkAddressHttpClient final: public HttpClient { 3730 public: 3731 NetworkAddressHttpClient(kj::Timer& timer, const 
HttpHeaderTable& responseHeaderTable, 3732 kj::Own<kj::NetworkAddress> address, HttpClientSettings settings) 3733 : timer(timer), 3734 responseHeaderTable(responseHeaderTable), 3735 address(kj::mv(address)), 3736 settings(kj::mv(settings)) {} 3737 3738 bool isDrained() { 3739 // Returns true if there are no open connections. 3740 return activeConnectionCount == 0 && availableClients.empty(); 3741 } 3742 3743 kj::Promise<void> onDrained() { 3744 // Returns a promise which resolves the next time isDrained() transitions from false to true. 3745 auto paf = kj::newPromiseAndFulfiller<void>(); 3746 drainedFulfiller = kj::mv(paf.fulfiller); 3747 return kj::mv(paf.promise); 3748 } 3749 3750 Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, 3751 kj::Maybe<uint64_t> expectedBodySize = nullptr) override { 3752 auto refcounted = getClient(); 3753 auto result = refcounted->client->request(method, url, headers, expectedBodySize); 3754 result.body = result.body.attach(kj::addRef(*refcounted)); 3755 result.response = result.response.then(kj::mvCapture(refcounted, 3756 [](kj::Own<RefcountedClient>&& refcounted, Response&& response) { 3757 response.body = response.body.attach(kj::mv(refcounted)); 3758 return kj::mv(response); 3759 })); 3760 return result; 3761 } 3762 3763 kj::Promise<WebSocketResponse> openWebSocket( 3764 kj::StringPtr url, const HttpHeaders& headers) override { 3765 auto refcounted = getClient(); 3766 auto result = refcounted->client->openWebSocket(url, headers); 3767 return result.then(kj::mvCapture(refcounted, 3768 [](kj::Own<RefcountedClient>&& refcounted, WebSocketResponse&& response) { 3769 KJ_SWITCH_ONEOF(response.webSocketOrBody) { 3770 KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) { 3771 response.webSocketOrBody = body.attach(kj::mv(refcounted)); 3772 } 3773 KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) { 3774 // The only reason we need to attach the client to the WebSocket is because otherwise 3775 // the response headers will 
be deleted prematurely. Otherwise, the WebSocket has taken 3776 // ownership of the connection. 3777 // 3778 // TODO(perf): Maybe we could transfer ownership of the response headers specifically? 3779 response.webSocketOrBody = ws.attach(kj::mv(refcounted)); 3780 } 3781 } 3782 return kj::mv(response); 3783 })); 3784 } 3785 3786 private: 3787 kj::Timer& timer; 3788 const HttpHeaderTable& responseHeaderTable; 3789 kj::Own<kj::NetworkAddress> address; 3790 HttpClientSettings settings; 3791 3792 kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> drainedFulfiller; 3793 uint activeConnectionCount = 0; 3794 3795 bool timeoutsScheduled = false; 3796 kj::Promise<void> timeoutTask = nullptr; 3797 3798 struct AvailableClient { 3799 kj::Own<HttpClientImpl> client; 3800 kj::TimePoint expires; 3801 }; 3802 3803 std::deque<AvailableClient> availableClients; 3804 3805 struct RefcountedClient final: public kj::Refcounted { 3806 RefcountedClient(NetworkAddressHttpClient& parent, kj::Own<HttpClientImpl> client) 3807 : parent(parent), client(kj::mv(client)) { 3808 ++parent.activeConnectionCount; 3809 } 3810 ~RefcountedClient() noexcept(false) { 3811 --parent.activeConnectionCount; 3812 KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { 3813 parent.returnClientToAvailable(kj::mv(client)); 3814 })) { 3815 KJ_LOG(ERROR, *exception); 3816 } 3817 } 3818 3819 NetworkAddressHttpClient& parent; 3820 kj::Own<HttpClientImpl> client; 3821 }; 3822 3823 kj::Own<RefcountedClient> getClient() { 3824 for (;;) { 3825 if (availableClients.empty()) { 3826 auto stream = newPromisedStream(address->connect()); 3827 return kj::refcounted<RefcountedClient>(*this, 3828 kj::heap<HttpClientImpl>(responseHeaderTable, kj::mv(stream), settings)); 3829 } else { 3830 auto client = kj::mv(availableClients.back().client); 3831 availableClients.pop_back(); 3832 if (client->canReuse()) { 3833 return kj::refcounted<RefcountedClient>(*this, kj::mv(client)); 3834 } 3835 // Whoops, this client's connection was closed by 
the server at some point. Discard. 3836 } 3837 } 3838 } 3839 3840 void returnClientToAvailable(kj::Own<HttpClientImpl> client) { 3841 // Only return the connection to the pool if it is reusable and if our settings indicate we 3842 // should reuse connections. 3843 if (client->canReuse() && settings.idleTimeout > 0 * kj::SECONDS) { 3844 availableClients.push_back(AvailableClient { 3845 kj::mv(client), timer.now() + settings.idleTimeout 3846 }); 3847 } 3848 3849 // Call this either way because it also signals onDrained(). 3850 if (!timeoutsScheduled) { 3851 timeoutsScheduled = true; 3852 timeoutTask = applyTimeouts(); 3853 } 3854 } 3855 3856 kj::Promise<void> applyTimeouts() { 3857 if (availableClients.empty()) { 3858 timeoutsScheduled = false; 3859 if (activeConnectionCount == 0) { 3860 KJ_IF_MAYBE(f, drainedFulfiller) { 3861 f->get()->fulfill(); 3862 drainedFulfiller = nullptr; 3863 } 3864 } 3865 return kj::READY_NOW; 3866 } else { 3867 auto time = availableClients.front().expires; 3868 return timer.atTime(time).then([this,time]() { 3869 while (!availableClients.empty() && availableClients.front().expires <= time) { 3870 availableClients.pop_front(); 3871 } 3872 return applyTimeouts(); 3873 }); 3874 } 3875 } 3876 }; 3877 3878 class PromiseNetworkAddressHttpClient final: public HttpClient { 3879 // An HttpClient which waits for a promise to resolve then forwards all calls to the promised 3880 // client. 

public:
  // Takes a promise for the real client. When it resolves, the client is stored in `client`;
  // the promise is forked so that calls arriving earlier can each wait on their own branch.
  PromiseNetworkAddressHttpClient(kj::Promise<kj::Own<NetworkAddressHttpClient>> promise)
      : promise(promise.then([this](kj::Own<NetworkAddressHttpClient>&& client) {
          this->client = kj::mv(client);
        }).fork()) {}

  bool isDrained() {
    // Forwards to the real client once it exists. Before that, reports drained only if
    // `failed` has been set, which happens in the error branch of onDrained().
    KJ_IF_MAYBE(c, client) {
      return c->get()->isDrained();
    } else {
      return failed;
    }
  }

  kj::Promise<void> onDrained() {
    // Forwards to the real client's onDrained(), waiting for the client to arrive first if
    // necessary.
    KJ_IF_MAYBE(c, client) {
      return c->get()->onDrained();
    } else {
      return promise.addBranch().then([this]() {
        return KJ_ASSERT_NONNULL(client)->onDrained();
      }, [this](kj::Exception&& e) {
        // Connecting failed. Treat as immediately drained.
        failed = true;
        return kj::READY_NOW;
      });
    }
  }

  Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
                  kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
    KJ_IF_MAYBE(c, client) {
      return c->get()->request(method, url, headers, expectedBodySize);
    } else {
      // This gets complicated since request() returns a pair of a stream and a promise.
      //
      // The URL and headers are copied because the inner request() is only issued later, when
      // the promise resolves. The resulting (body, response) tuple is split so that the body can
      // be handed out now as a promised stream and the response as a promise.
      auto urlCopy = kj::str(url);
      auto headersCopy = headers.clone();
      auto combined = promise.addBranch().then(kj::mvCapture(urlCopy, kj::mvCapture(headersCopy,
          [this,method,expectedBodySize](HttpHeaders&& headers, kj::String&& url)
          -> kj::Tuple<kj::Own<kj::AsyncOutputStream>, kj::Promise<Response>> {
        auto req = KJ_ASSERT_NONNULL(client)->request(method, url, headers, expectedBodySize);
        return kj::tuple(kj::mv(req.body), kj::mv(req.response));
      })));

      auto split = combined.split();
      return {
        newPromisedStream(kj::mv(kj::get<0>(split))),
        kj::mv(kj::get<1>(split))
      };
    }
  }

  kj::Promise<WebSocketResponse> openWebSocket(
      kj::StringPtr url, const HttpHeaders& headers) override {
    KJ_IF_MAYBE(c, client) {
      return c->get()->openWebSocket(url, headers);
    } else {
      // Same approach as request(): copy the arguments and forward once the client exists.
      auto urlCopy = kj::str(url);
      auto headersCopy = headers.clone();
      return promise.addBranch().then(kj::mvCapture(urlCopy, kj::mvCapture(headersCopy,
          [this](HttpHeaders&& headers, kj::String&& url) {
        return KJ_ASSERT_NONNULL(client)->openWebSocket(url, headers);
      })));
    }
  }

private:
  kj::ForkedPromise<void> promise;
  // Resolves after `client` has been filled in (see constructor).
  kj::Maybe<kj::Own<NetworkAddressHttpClient>> client;
  // Null until the constructor's promise resolves.
  bool failed = false;
  // Set by onDrained() when waiting for the client ended in an exception.
};

class NetworkHttpClient final: public HttpClient, private kj::TaskSet::ErrorHandler {
  // An HttpClient which accepts proxy-style URLs, rewrites each request to host-style, and
  // forwards it to a per-host client. Per-host clients are created on demand and kept in
  // `httpHosts` / `httpsHosts` until they report themselves drained.

public:
  NetworkHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable,
                    kj::Network& network, kj::Maybe<kj::Network&> tlsNetwork,
                    HttpClientSettings settings)
      : timer(timer),
        responseHeaderTable(responseHeaderTable),
        network(network),
        tlsNetwork(tlsNetwork),
        settings(kj::mv(settings)),
        tasks(*this) {}

  Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
                  kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
    // We need to parse the proxy-style URL to convert it to host-style.
    // Use URL parsing options that avoid unnecessary rewrites.
    Url::Options urlOptions;
    urlOptions.allowEmpty = true;
    urlOptions.percentDecode = false;

    auto parsed = Url::parse(url, Url::HTTP_PROXY_REQUEST, urlOptions);
    auto path = parsed.toString(Url::HTTP_REQUEST);
    auto headersCopy = headers.clone();
    // The Host header is set from the parsed URL before getClient() is called, since
    // getClient() may move `parsed.host` away.
    headersCopy.set(HttpHeaderId::HOST, parsed.host);
    return getClient(parsed).request(method, path, headersCopy, expectedBodySize);
  }

  kj::Promise<WebSocketResponse> openWebSocket(
      kj::StringPtr url, const HttpHeaders& headers) override {
    // We need to parse the proxy-style URL to convert it to host-style.
    // Use URL parsing options that avoid unnecessary rewrites.
    Url::Options urlOptions;
    urlOptions.allowEmpty = true;
    urlOptions.percentDecode = false;

    auto parsed = Url::parse(url, Url::HTTP_PROXY_REQUEST, urlOptions);
    auto path = parsed.toString(Url::HTTP_REQUEST);
    auto headersCopy = headers.clone();
    headersCopy.set(HttpHeaderId::HOST, parsed.host);
    return getClient(parsed).openWebSocket(path, headersCopy);
  }

private:
  kj::Timer& timer;
  const HttpHeaderTable& responseHeaderTable;
  kj::Network& network;
  kj::Maybe<kj::Network&> tlsNetwork;
  HttpClientSettings settings;

  struct Host {
    kj::String name;  // including port, if non-default
    kj::Own<PromiseNetworkAddressHttpClient> client;
  };

  // Per-scheme tables of known hosts. Each key is a StringPtr taken from the corresponding
  // Host::name (see getClient()).
  std::map<kj::StringPtr, Host> httpHosts;
  std::map<kj::StringPtr, Host> httpsHosts;

  // NOTE(review): RequestInfo is not referenced anywhere in the visible part of this class --
  // confirm whether it is still needed.
  struct RequestInfo {
    HttpMethod method;
    kj::String hostname;
    kj::String path;
    HttpHeaders headers;
    kj::Maybe<uint64_t> expectedBodySize;
  };

  kj::TaskSet tasks;
  // Holds the handleCleanup() task of every live host entry.

  HttpClient& getClient(kj::Url& parsed) {
    // Returns the client for `parsed.host`, creating and registering one if none is cached.
    // Requires the scheme to be "http" or "https". When a new entry is created, `parsed.host`
    // is moved into it.
    bool isHttps = parsed.scheme == "https";
    bool isHttp = parsed.scheme == "http";
    KJ_REQUIRE(isHttp || isHttps);

    auto& hosts = isHttps ? httpsHosts : httpHosts;

    // Look for a cached client for this host.
    // TODO(perf): It would be nice to recognize when different hosts have the same address and
    //   reuse the same connection pool, but:
    //   - We'd need a reliable way to compare NetworkAddresses, e.g. .equals() and .hashCode().
    //     It's very Java... ick.
    //   - Correctly handling TLS would be tricky: we'd need to verify that the new hostname is
    //     on the certificate. When SNI is in use we might have to request an additional
    //     certificate (is that possible?).
    auto iter = hosts.find(parsed.host);

    if (iter == hosts.end()) {
      // Need to open a new connection.
      kj::Network* networkToUse = &network;
      if (isHttps) {
        networkToUse = &KJ_REQUIRE_NONNULL(tlsNetwork, "this HttpClient doesn't support HTTPS");
      }

      auto promise = networkToUse->parseAddress(parsed.host, isHttps ? 443 : 80)
          .then([this](kj::Own<kj::NetworkAddress> addr) {
        return kj::heap<NetworkAddressHttpClient>(
            timer, responseHeaderTable, kj::mv(addr), settings);
      });

      Host host {
        kj::mv(parsed.host),
        kj::heap<PromiseNetworkAddressHttpClient>(kj::mv(promise))
      };
      // The map key refers to the name stored inside the Host that is moved into the map.
      kj::StringPtr nameRef = host.name;

      auto insertResult = hosts.insert(std::make_pair(nameRef, kj::mv(host)));
      KJ_ASSERT(insertResult.second);
      iter = insertResult.first;

      // Arrange for the entry to be erased once its client is drained.
      tasks.add(handleCleanup(hosts, iter));
    }

    return *iter->second.client;
  }

  kj::Promise<void> handleCleanup(std::map<kj::StringPtr, Host>& hosts,
                                  std::map<kj::StringPtr, Host>::iterator iter) {
    // Waits for the host's client to signal onDrained(), then erases the entry from `hosts` if
    // it is still drained; otherwise waits again.
    return iter->second.client->onDrained()
        .then([this,&hosts,iter]() -> kj::Promise<void> {
      // Double-check that it's really drained to avoid race conditions.
      if (iter->second.client->isDrained()) {
        hosts.erase(iter);
        return kj::READY_NOW;
      } else {
        return handleCleanup(hosts, iter);
      }
    });
  }

  void taskFailed(kj::Exception&& exception) override {
    // TaskSet::ErrorHandler callback: failures of cleanup tasks are only logged.
    KJ_LOG(ERROR, exception);
  }
};

}  // namespace

// Creates an HttpClient for a single network address. The address is wrapped in a non-owning
// Own (NullDisposer), so this function does not take ownership of `addr`.
kj::Own<HttpClient> newHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable,
                                  kj::NetworkAddress& addr, HttpClientSettings settings) {
  return kj::heap<NetworkAddressHttpClient>(timer, responseHeaderTable,
      kj::Own<kj::NetworkAddress>(&addr, kj::NullDisposer::instance), kj::mv(settings));
}

// Creates an HttpClient that resolves hosts through `network`, using `tlsNetwork` (if provided)
// for "https" URLs.
kj::Own<HttpClient> newHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable,
                                  kj::Network& network, kj::Maybe<kj::Network&> tlsNetwork,
                                  HttpClientSettings settings) {
  return kj::heap<NetworkHttpClient>(
      timer, responseHeaderTable, network, tlsNetwork, kj::mv(settings));
}

// =======================================================================================

namespace {

class ConcurrencyLimitingHttpClient final: public HttpClient {
  // An HttpClient wrapper that forwards to `inner` while fewer than `maxConcurrentRequests`
  // requests are counted as running, and queues any further requests until a slot is released.
  // `countChangedCallback` is invoked with (running, pending) whenever the counts change.

public:
  ConcurrencyLimitingHttpClient(
      kj::HttpClient& inner, uint maxConcurrentRequests,
      kj::Function<void(uint runningCount, uint pendingCount)> countChangedCallback)
      : inner(inner),
        maxConcurrentRequests(maxConcurrentRequests),
        countChangedCallback(kj::mv(countChangedCallback)) {}

  Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
                  kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
    if (concurrentRequests < maxConcurrentRequests) {
      // A slot is free: count this request and forward it right away. The counter is attached
      // to the response so that the slot stays occupied while the response is in use.
      auto counter = ConnectionCounter(*this);
      auto request = inner.request(method, url, headers, expectedBodySize);
      fireCountChanged();
      auto promise = attachCounter(kj::mv(request.response), kj::mv(counter));
      return { kj::mv(request.body), kj::mv(promise) };
    }

    // No slot is free: queue a fulfiller. The request is issued to `inner` only once a
    // ConnectionCounter is delivered through it, so the URL and headers must be copied.
    auto paf = kj::newPromiseAndFulfiller<ConnectionCounter>();
    auto urlCopy = kj::str(url);
    auto headersCopy = headers.clone();

    auto combined = paf.promise
        .then([this,
               method,
               urlCopy = kj::mv(urlCopy),
               headersCopy = kj::mv(headersCopy),
               expectedBodySize](ConnectionCounter&& counter) mutable {
      auto req = inner.request(method, urlCopy, headersCopy, expectedBodySize);
      return kj::tuple(kj::mv(req.body), attachCounter(kj::mv(req.response), kj::mv(counter)));
    });
    auto split = combined.split();
    pendingRequests.push(kj::mv(paf.fulfiller));
    fireCountChanged();
    return { newPromisedStream(kj::mv(kj::get<0>(split))), kj::mv(kj::get<1>(split)) };
  }

  kj::Promise<WebSocketResponse> openWebSocket(
      kj::StringPtr url, const kj::HttpHeaders& headers) override {
    if (concurrentRequests < maxConcurrentRequests) {
      // A slot is free: forward immediately, as in request().
      auto counter = ConnectionCounter(*this);
      auto response = inner.openWebSocket(url, headers);
      fireCountChanged();
      return attachCounter(kj::mv(response), kj::mv(counter));
    }

    // No slot is free: queue, as in request().
    auto paf = kj::newPromiseAndFulfiller<ConnectionCounter>();
    auto urlCopy = kj::str(url);
    auto headersCopy = headers.clone();

    auto promise = paf.promise
        .then([this,
               urlCopy = kj::mv(urlCopy),
               headersCopy = kj::mv(headersCopy)](ConnectionCounter&& counter) mutable {
      return attachCounter(inner.openWebSocket(urlCopy, headersCopy), kj::mv(counter));
    });

    pendingRequests.push(kj::mv(paf.fulfiller));
    fireCountChanged();
    return kj::mv(promise);
  }

private:
  struct ConnectionCounter;

  kj::HttpClient& inner;
  uint maxConcurrentRequests;
  uint concurrentRequests = 0;
  kj::Function<void(uint runningCount, uint pendingCount)> countChangedCallback;

  std::queue<kj::Own<kj::PromiseFulfiller<ConnectionCounter>>> pendingRequests;
  // TODO(someday): want maximum cap on queue size?
4177 4178 struct ConnectionCounter final { 4179 ConnectionCounter(ConcurrencyLimitingHttpClient& client) : parent(&client) { 4180 ++parent->concurrentRequests; 4181 } 4182 KJ_DISALLOW_COPY(ConnectionCounter); 4183 ~ConnectionCounter() noexcept(false) { 4184 if (parent != nullptr) { 4185 --parent->concurrentRequests; 4186 parent->serviceQueue(); 4187 parent->fireCountChanged(); 4188 } 4189 } 4190 ConnectionCounter(ConnectionCounter&& other) : parent(other.parent) { 4191 other.parent = nullptr; 4192 } 4193 ConnectionCounter& operator=(ConnectionCounter&& other) { 4194 if (this != &other) { 4195 this->parent = other.parent; 4196 other.parent = nullptr; 4197 } 4198 return *this; 4199 } 4200 4201 ConcurrencyLimitingHttpClient* parent; 4202 }; 4203 4204 void serviceQueue() { 4205 if (concurrentRequests >= maxConcurrentRequests) { return; } 4206 if (pendingRequests.empty()) { return; } 4207 4208 auto fulfiller = kj::mv(pendingRequests.front()); 4209 pendingRequests.pop(); 4210 fulfiller->fulfill(ConnectionCounter(*this)); 4211 } 4212 4213 void fireCountChanged() { 4214 countChangedCallback(concurrentRequests, pendingRequests.size()); 4215 } 4216 4217 using WebSocketOrBody = kj::OneOf<kj::Own<kj::AsyncInputStream>, kj::Own<WebSocket>>; 4218 static WebSocketOrBody attachCounter(WebSocketOrBody&& webSocketOrBody, 4219 ConnectionCounter&& counter) { 4220 KJ_SWITCH_ONEOF(webSocketOrBody) { 4221 KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) { 4222 return ws.attach(kj::mv(counter)); 4223 } 4224 KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) { 4225 return body.attach(kj::mv(counter)); 4226 } 4227 } 4228 KJ_UNREACHABLE; 4229 } 4230 4231 static kj::Promise<WebSocketResponse> attachCounter(kj::Promise<WebSocketResponse>&& promise, 4232 ConnectionCounter&& counter) { 4233 return promise.then([counter = kj::mv(counter)](WebSocketResponse&& response) mutable { 4234 return WebSocketResponse { 4235 response.statusCode, 4236 response.statusText, 4237 response.headers, 4238 
attachCounter(kj::mv(response.webSocketOrBody), kj::mv(counter)) 4239 }; 4240 }); 4241 } 4242 4243 static kj::Promise<Response> attachCounter(kj::Promise<Response>&& promise, 4244 ConnectionCounter&& counter) { 4245 return promise.then([counter = kj::mv(counter)](Response&& response) mutable { 4246 return Response { 4247 response.statusCode, 4248 response.statusText, 4249 response.headers, 4250 response.body.attach(kj::mv(counter)) 4251 }; 4252 }); 4253 } 4254 }; 4255 4256 } 4257 4258 kj::Own<HttpClient> newConcurrencyLimitingHttpClient( 4259 HttpClient& inner, uint maxConcurrentRequests, 4260 kj::Function<void(uint runningCount, uint pendingCount)> countChangedCallback) { 4261 return kj::heap<ConcurrencyLimitingHttpClient>(inner, maxConcurrentRequests, 4262 kj::mv(countChangedCallback)); 4263 } 4264 4265 // ======================================================================================= 4266 4267 namespace { 4268 4269 class NullInputStream final: public kj::AsyncInputStream { 4270 public: 4271 NullInputStream(kj::Maybe<size_t> expectedLength = size_t(0)) 4272 : expectedLength(expectedLength) {} 4273 4274 kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { 4275 return size_t(0); 4276 } 4277 4278 kj::Maybe<uint64_t> tryGetLength() override { 4279 return expectedLength; 4280 } 4281 4282 kj::Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount) override { 4283 return uint64_t(0); 4284 } 4285 4286 private: 4287 kj::Maybe<size_t> expectedLength; 4288 }; 4289 4290 class NullOutputStream final: public kj::AsyncOutputStream { 4291 public: 4292 Promise<void> write(const void* buffer, size_t size) override { 4293 return kj::READY_NOW; 4294 } 4295 Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override { 4296 return kj::READY_NOW; 4297 } 4298 Promise<void> whenWriteDisconnected() override { 4299 return kj::NEVER_DONE; 4300 } 4301 4302 // We can't really optimize tryPumpFrom() unless AsyncInputStream 
grows a skip() method. 4303 }; 4304 4305 class HttpClientAdapter final: public HttpClient { 4306 public: 4307 HttpClientAdapter(HttpService& service): service(service) {} 4308 4309 Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, 4310 kj::Maybe<uint64_t> expectedBodySize = nullptr) override { 4311 // We have to clone the URL and headers because HttpService implementation are allowed to 4312 // assume that they remain valid until the service handler completes whereas HttpClient callers 4313 // are allowed to destroy them immediately after the call. 4314 auto urlCopy = kj::str(url); 4315 auto headersCopy = kj::heap(headers.clone()); 4316 4317 auto pipe = newOneWayPipe(expectedBodySize); 4318 4319 // TODO(cleanup): The ownership relationships here are a mess. Can we do something better 4320 // involving a PromiseAdapter, maybe? 4321 auto paf = kj::newPromiseAndFulfiller<Response>(); 4322 auto responder = kj::refcounted<ResponseImpl>(method, kj::mv(paf.fulfiller)); 4323 4324 auto requestPaf = kj::newPromiseAndFulfiller<kj::Promise<void>>(); 4325 responder->setPromise(kj::mv(requestPaf.promise)); 4326 4327 auto promise = service.request(method, urlCopy, *headersCopy, *pipe.in, *responder) 4328 .attach(kj::mv(pipe.in), kj::mv(urlCopy), kj::mv(headersCopy)); 4329 requestPaf.fulfiller->fulfill(kj::mv(promise)); 4330 4331 return { 4332 kj::mv(pipe.out), 4333 paf.promise.attach(kj::mv(responder)) 4334 }; 4335 } 4336 4337 kj::Promise<WebSocketResponse> openWebSocket( 4338 kj::StringPtr url, const HttpHeaders& headers) override { 4339 // We have to clone the URL and headers because HttpService implementation are allowed to 4340 // assume that they remain valid until the service handler completes whereas HttpClient callers 4341 // are allowed to destroy them immediately after the call. Also we need to add 4342 // `Upgrade: websocket` so that headers.isWebSocket() returns true on the service side. 
4343 auto urlCopy = kj::str(url); 4344 auto headersCopy = kj::heap(headers.clone()); 4345 headersCopy->set(HttpHeaderId::UPGRADE, "websocket"); 4346 KJ_DASSERT(headersCopy->isWebSocket()); 4347 4348 auto paf = kj::newPromiseAndFulfiller<WebSocketResponse>(); 4349 auto responder = kj::refcounted<WebSocketResponseImpl>(kj::mv(paf.fulfiller)); 4350 4351 auto requestPaf = kj::newPromiseAndFulfiller<kj::Promise<void>>(); 4352 responder->setPromise(kj::mv(requestPaf.promise)); 4353 4354 auto in = kj::heap<NullInputStream>(); 4355 auto promise = service.request(HttpMethod::GET, urlCopy, *headersCopy, *in, *responder) 4356 .attach(kj::mv(in), kj::mv(urlCopy), kj::mv(headersCopy)); 4357 requestPaf.fulfiller->fulfill(kj::mv(promise)); 4358 4359 return paf.promise.attach(kj::mv(responder)); 4360 } 4361 4362 kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host) override { 4363 return service.connect(kj::mv(host)); 4364 } 4365 4366 private: 4367 HttpService& service; 4368 4369 class DelayedEofInputStream final: public kj::AsyncInputStream { 4370 // An AsyncInputStream wrapper that, when it reaches EOF, delays the final read until some 4371 // promise completes. 

  public:
    // `completionTask` is the promise that the final (EOF-signaling) read waits for.
    DelayedEofInputStream(kj::Own<kj::AsyncInputStream> inner, kj::Promise<void> completionTask)
        : inner(kj::mv(inner)), completionTask(kj::mv(completionTask)) {}

    kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
      return wrap(minBytes, inner->tryRead(buffer, minBytes, maxBytes));
    }

    kj::Maybe<uint64_t> tryGetLength() override {
      return inner->tryGetLength();
    }

    kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
      return wrap(amount, inner->pumpTo(output, amount));
    }

  private:
    kj::Own<kj::AsyncInputStream> inner;
    kj::Maybe<kj::Promise<void>> completionTask;
    // Becomes null once it has been consumed by wrap().

    template <typename T>
    kj::Promise<T> wrap(T requested, kj::Promise<T> innerPromise) {
      // Passes the result of a read or pump through unchanged, except that a short result
      // (less than `requested`) or an exception is held back until `completionTask` finishes.
      // The task is consumed the first time this happens.
      return innerPromise.then([this,requested](T actual) -> kj::Promise<T> {
        if (actual < requested) {
          // Must have reached EOF.
          KJ_IF_MAYBE(t, completionTask) {
            // Delay until completion.
            auto result = t->then([actual]() { return actual; });
            completionTask = nullptr;
            return result;
          } else {
            // Must have called tryRead() again after we already signaled EOF. Fine.
            return actual;
          }
        } else {
          return actual;
        }
      }, [this](kj::Exception&& e) -> kj::Promise<T> {
        // The stream threw an exception, but this exception is almost certainly just complaining
        // that the other end of the stream was dropped. In all likelihood, the HttpService
        // request() call itself will throw a much more interesting error -- we'd rather propagate
        // that one, if so.
        KJ_IF_MAYBE(t, completionTask) {
          auto result = t->then([e = kj::mv(e)]() mutable -> kj::Promise<T> {
            // Looks like the service didn't throw. I guess we should propagate the stream error
            // after all.
            return kj::mv(e);
          });
          completionTask = nullptr;
          return result;
        } else {
          // Must have called tryRead() again after we already signaled EOF or threw. Fine.
          return kj::mv(e);
        }
      });
    }
  };

  class ResponseImpl final: public HttpService::Response, public kj::Refcounted {
    // The HttpService::Response handed to the service for a regular (non-WebSocket) request.
    // Whatever the service sends is delivered to the client side through `fulfiller`.

  public:
    ResponseImpl(kj::HttpMethod method,
                 kj::Own<kj::PromiseFulfiller<HttpClient::Response>> fulfiller)
        : method(method), fulfiller(kj::mv(fulfiller)) {}

    void setPromise(kj::Promise<void> promise) {
      // Installs the service's request promise as `task`. If it fails before a response has
      // been delivered, the failure rejects the response promise; otherwise it is rethrown so
      // that it stays in `task`.
      task = promise.eagerlyEvaluate([this](kj::Exception&& exception) {
        if (fulfiller->isWaiting()) {
          fulfiller->reject(kj::mv(exception));
        } else {
          // We need to cause the response stream's read() to throw this, so we should propagate it.
          kj::throwRecoverableException(kj::mv(exception));
        }
      });
    }

    kj::Own<kj::AsyncOutputStream> send(
        uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
        kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
      // The caller of HttpClient is allowed to assume that the statusText and headers remain
      // valid until the body stream is dropped, but the HttpService implementation is allowed to
      // send values that are only valid until send() returns, so we have to copy.
      auto statusTextCopy = kj::str(statusText);
      auto headersCopy = kj::heap(headers.clone());

      if (method == kj::HttpMethod::HEAD || expectedBodySize.orDefault(1) == 0) {
        // We're not expecting any body. We need to delay reporting completion to the client until
        // the server side has actually returned from the service method, otherwise we may
        // prematurely cancel it.

        task = task.then([this,statusCode,statusTextCopy=kj::mv(statusTextCopy),
                          headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable {
          // The copies are attached to the (empty) body stream so that they live as long as it
          // does.
          fulfiller->fulfill({
            statusCode, statusTextCopy, headersCopy.get(),
            kj::heap<NullInputStream>(expectedBodySize)
                .attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
          });
        }).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
        // Anything the service writes to the returned stream is discarded.
        return kj::heap<NullOutputStream>();
      } else {
        auto pipe = newOneWayPipe(expectedBodySize);

        // Wrap the stream in a wrapper that delays the last read (the one that signals EOF) until
        // the service's request promise has finished.
        auto wrapper = kj::heap<DelayedEofInputStream>(
            kj::mv(pipe.in), task.attach(kj::addRef(*this)));

        fulfiller->fulfill({
          statusCode, statusTextCopy, headersCopy.get(),
          wrapper.attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
        });
        return kj::mv(pipe.out);
      }
    }

    kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
      // Always fails: this responder is only used for non-WebSocket requests.
      KJ_FAIL_REQUIRE("a WebSocket was not requested");
    }

  private:
    kj::HttpMethod method;
    kj::Own<kj::PromiseFulfiller<HttpClient::Response>> fulfiller;
    kj::Promise<void> task = nullptr;
    // Set by setPromise(); moved out or replaced by send().
  };

  class DelayedCloseWebSocket final: public WebSocket {
    // A WebSocket wrapper that, when it reaches Close (in both directions), delays the final close
    // operation until some promise completes.

  public:
    // `completionTask` is the promise awaited once Close has been seen in both directions.
    DelayedCloseWebSocket(kj::Own<kj::WebSocket> inner, kj::Promise<void> completionTask)
        : inner(kj::mv(inner)), completionTask(kj::mv(completionTask)) {}

    kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
      return inner->send(message);
    }
    kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
      return inner->send(message);
    }
    kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
      return inner->close(code, reason)
          .then([this]() {
        return afterSendClosed();
      });
    }
    kj::Promise<void> disconnect() override {
      return inner->disconnect();
    }
    void abort() override {
      // Don't need to worry about completion task in this case -- cancelling it is reasonable.
      inner->abort();
    }
    kj::Promise<void> whenAborted() override {
      return inner->whenAborted();
    }
    kj::Promise<Message> receive(size_t maxSize) override {
      // A received Close message is held back until afterReceiveClosed() completes.
      return inner->receive(maxSize).then([this](Message&& message) -> kj::Promise<Message> {
        if (message.is<WebSocket::Close>()) {
          return afterReceiveClosed()
              .then([message = kj::mv(message)]() mutable { return kj::mv(message); });
        }
        return kj::mv(message);
      });
    }
    kj::Promise<void> pumpTo(WebSocket& other) override {
      // Completion of the pump is treated as the receive direction having closed.
      return inner->pumpTo(other).then([this]() {
        return afterReceiveClosed();
      });
    }
    kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
      // Completion of the pump is treated as the send direction having closed.
      return other.pumpTo(*inner).then([this]() {
        return afterSendClosed();
      });
    }

    uint64_t sentByteCount() override { return inner->sentByteCount(); }
    uint64_t receivedByteCount() override { return inner->receivedByteCount(); }

  private:
    kj::Own<kj::WebSocket> inner;
    kj::Maybe<kj::Promise<void>> completionTask;
    // Becomes null once it has been handed out by one of the after*Closed() methods.

    bool sentClose = false;
    bool receivedClose = false;

    kj::Promise<void> afterSendClosed() {
      // Records that the send direction is closed. If the receive direction is closed too and
      // the completion task has not been consumed yet, returns that task; otherwise resolves
      // immediately.
      sentClose = true;
      if (receivedClose) {
        KJ_IF_MAYBE(t, completionTask) {
          auto result = kj::mv(*t);
          completionTask = nullptr;
          return result;
        }
      }
      return kj::READY_NOW;
    }

    kj::Promise<void> afterReceiveClosed() {
      // Mirror image of afterSendClosed() for the receive direction.
      receivedClose = true;
      if (sentClose) {
        KJ_IF_MAYBE(t, completionTask) {
          auto result = kj::mv(*t);
          completionTask = nullptr;
          return result;
        }
      }
      return kj::READY_NOW;
    }
  };

  class WebSocketResponseImpl final: public HttpService::Response, public kj::Refcounted {
    // The HttpService::Response handed to the service for a WebSocket request. The service may
    // either accept the WebSocket or answer with a regular response; either way the result is
    // delivered to the client side through `fulfiller`.

  public:
    WebSocketResponseImpl(kj::Own<kj::PromiseFulfiller<HttpClient::WebSocketResponse>> fulfiller)
        : fulfiller(kj::mv(fulfiller)) {}

    void setPromise(kj::Promise<void> promise) {
      // Installs the service's request promise as `task`. If it fails before a response has
      // been delivered, the failure rejects the response promise; otherwise it is rethrown so
      // that it stays in `task`.
      task = promise.eagerlyEvaluate([this](kj::Exception&& exception) {
        if (fulfiller->isWaiting()) {
          fulfiller->reject(kj::mv(exception));
        } else {
          // We need to cause the client-side WebSocket to throw on close, so propagate the
          // exception.
          kj::throwRecoverableException(kj::mv(exception));
        }
      });
    }

    kj::Own<kj::AsyncOutputStream> send(
        uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
        kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
      // The caller of HttpClient is allowed to assume that the statusText and headers remain
      // valid until the body stream is dropped, but the HttpService implementation is allowed to
      // send values that are only valid until send() returns, so we have to copy.
      auto statusTextCopy = kj::str(statusText);
      auto headersCopy = kj::heap(headers.clone());

      if (expectedBodySize.orDefault(1) == 0) {
        // We're not expecting any body. We need to delay reporting completion to the client until
        // the server side has actually returned from the service method, otherwise we may
        // prematurely cancel it.

        task = task.then([this,statusCode,statusTextCopy=kj::mv(statusTextCopy),
                          headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable {
          fulfiller->fulfill({
            statusCode, statusTextCopy, headersCopy.get(),
            kj::Own<AsyncInputStream>(kj::heap<NullInputStream>(expectedBodySize)
                .attach(kj::mv(statusTextCopy), kj::mv(headersCopy)))
          });
        }).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
        // Anything the service writes to the returned stream is discarded.
        return kj::heap<NullOutputStream>();
      } else {
        auto pipe = newOneWayPipe(expectedBodySize);

        // Wrap the stream in a wrapper that delays the last read (the one that signals EOF) until
        // the service's request promise has finished.
        kj::Own<AsyncInputStream> wrapper =
            kj::heap<DelayedEofInputStream>(kj::mv(pipe.in), task.attach(kj::addRef(*this)));

        fulfiller->fulfill({
          statusCode, statusTextCopy, headersCopy.get(),
          wrapper.attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
        });
        return kj::mv(pipe.out);
      }
    }

    kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
      // The caller of HttpClient is allowed to assume that the headers remain valid until the body
      // stream is dropped, but the HttpService implementation is allowed to send headers that are
      // only valid until acceptWebSocket() returns, so we have to copy.
      auto headersCopy = kj::heap(headers.clone());

      // One end of the pipe goes to the client (wrapped below), the other to the service.
      auto pipe = newWebSocketPipe();

      // Wrap the client-side WebSocket in a wrapper that delays clean close of the WebSocket until
      // the service's request promise has finished.
      kj::Own<WebSocket> wrapper =
          kj::heap<DelayedCloseWebSocket>(kj::mv(pipe.ends[0]), task.attach(kj::addRef(*this)));
      fulfiller->fulfill({
        101, "Switching Protocols", headersCopy.get(),
        wrapper.attach(kj::mv(headersCopy))
      });
      return kj::mv(pipe.ends[1]);
    }

  private:
    kj::Own<kj::PromiseFulfiller<HttpClient::WebSocketResponse>> fulfiller;
    kj::Promise<void> task = nullptr;
    // Set by setPromise(); moved out or replaced by send() / acceptWebSocket().
  };
};

}  // namespace

// Creates an HttpClient whose requests are served by calling `service` directly.
kj::Own<HttpClient> newHttpClient(HttpService& service) {
  return kj::heap<HttpClientAdapter>(service);
}

// =======================================================================================

namespace {

class HttpServiceAdapter final: public HttpService {
  // Implements the HttpService interface by forwarding every request to an HttpClient.

public:
  HttpServiceAdapter(HttpClient& client): client(client) {}

  kj::Promise<void> request(
      HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
      kj::AsyncInputStream& requestBody, Response& response) override {
    if (!headers.isWebSocket()) {
      auto innerReq = client.request(method, url, headers, requestBody.tryGetLength());

      // Two tasks are joined: pumping the request body to the client, and relaying the
      // client's response (status, headers, body) back through `response`.
      auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
      promises.add(requestBody.pumpTo(*innerReq.body).ignoreResult()
          .attach(kj::mv(innerReq.body)).eagerlyEvaluate(nullptr));

      promises.add(innerReq.response
          .then([&response](HttpClient::Response&& innerResponse) {
        auto out = response.send(
            innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
            innerResponse.body->tryGetLength());
        auto promise = innerResponse.body->pumpTo(*out);
        return promise.ignoreResult().attach(kj::mv(out), kj::mv(innerResponse.body));
      }));

      return kj::joinPromises(promises.finish());
    } else {
      return client.openWebSocket(url, headers)
          .then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise<void> {
        KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) {
          KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) {
            // The client produced a WebSocket: accept on our side and pump in both directions.
            auto ws2 = response.acceptWebSocket(*innerResponse.headers);
            auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
            promises.add(ws->pumpTo(*ws2));
            promises.add(ws2->pumpTo(*ws));
            return kj::joinPromises(promises.finish()).attach(kj::mv(ws), kj::mv(ws2));
          }
          KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) {
            // The client produced a regular response instead: relay it as-is.
            auto out = response.send(
                innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
                body->tryGetLength());
            auto promise = body->pumpTo(*out);
            return promise.ignoreResult().attach(kj::mv(out), kj::mv(body));
          }
        }
        KJ_UNREACHABLE;
      });
    }
  }

  kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host) override {
    // Forwards directly to the client's connect().
    return client.connect(kj::mv(host));
  }

private:
  HttpClient& client;
};

}  // namespace

// Creates an HttpService which serves requests by forwarding them to `client`.
kj::Own<HttpService> newHttpService(HttpClient& client) {
  return kj::heap<HttpServiceAdapter>(client);
}

// =======================================================================================

// Sends a response with the given status and headers whose body is the status text itself.
kj::Promise<void> HttpService::Response::sendError(
    uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers) {
  auto stream = send(statusCode, statusText, headers, statusText.size());
  auto promise = stream->write(statusText.begin(), statusText.size());
  return promise.attach(kj::mv(stream));
}

// Convenience overload: same as above, with an empty header set built from `headerTable`.
kj::Promise<void> HttpService::Response::sendError(
    uint statusCode, kj::StringPtr statusText, const HttpHeaderTable& headerTable) {
  return sendError(statusCode, statusText, HttpHeaders(headerTable));
}

// Default implementation of connect(): always fails as unimplemented.
kj::Promise<kj::Own<kj::AsyncIoStream>> HttpService::connect(kj::StringPtr host) {
  KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService");
}

class HttpServer::Connection final: private
HttpService::Response, 4754 private HttpServerErrorHandler { 4755 public: 4756 Connection(HttpServer& server, kj::AsyncIoStream& stream, 4757 HttpService& service) 4758 : server(server), 4759 stream(stream), 4760 service(service), 4761 httpInput(stream, server.requestHeaderTable), 4762 httpOutput(stream) { 4763 ++server.connectionCount; 4764 } 4765 ~Connection() noexcept(false) { 4766 if (--server.connectionCount == 0) { 4767 KJ_IF_MAYBE(f, server.zeroConnectionsFulfiller) { 4768 f->get()->fulfill(); 4769 } 4770 } 4771 } 4772 4773 public: 4774 kj::Promise<bool> startLoop(bool firstRequest) { 4775 return loop(firstRequest).catch_([this](kj::Exception&& e) -> kj::Promise<bool> { 4776 // Exception; report 5xx. 4777 4778 KJ_IF_MAYBE(p, webSocketError) { 4779 // sendWebSocketError() was called. Finish sending and close the connection. Don't log 4780 // the exception because it's probably a side-effect of this. 4781 auto promise = kj::mv(*p); 4782 webSocketError = nullptr; 4783 return kj::mv(promise); 4784 } 4785 4786 return sendError(kj::mv(e)); 4787 }); 4788 } 4789 4790 private: 4791 HttpServer& server; 4792 kj::AsyncIoStream& stream; 4793 HttpService& service; 4794 HttpInputStreamImpl httpInput; 4795 HttpOutputStream httpOutput; 4796 kj::Maybe<HttpMethod> currentMethod; 4797 bool timedOut = false; 4798 bool closed = false; 4799 bool upgraded = false; 4800 bool webSocketClosed = false; 4801 bool closeAfterSend = false; // True if send() should set Connection: close. 4802 kj::Maybe<kj::Promise<bool>> webSocketError; 4803 4804 kj::Promise<bool> loop(bool firstRequest) { 4805 if (!firstRequest && server.draining && httpInput.isCleanDrain()) { 4806 // Don't call awaitNextMessage() in this case because that will initiate a read() which will 4807 // immediately be canceled, losing data. 
4808 return true; 4809 } 4810 4811 auto firstByte = httpInput.awaitNextMessage(); 4812 4813 if (!firstRequest) { 4814 // For requests after the first, require that the first byte arrive before the pipeline 4815 // timeout, otherwise treat it like the connection was simply closed. 4816 auto timeoutPromise = server.timer.afterDelay(server.settings.pipelineTimeout); 4817 4818 if (httpInput.isCleanDrain()) { 4819 // If we haven't buffered any data, then we can safely drain here, so allow the wait to 4820 // be canceled by the onDrain promise. 4821 timeoutPromise = timeoutPromise.exclusiveJoin(server.onDrain.addBranch()); 4822 } 4823 4824 firstByte = firstByte.exclusiveJoin(timeoutPromise.then([this]() -> bool { 4825 timedOut = true; 4826 return false; 4827 })); 4828 } 4829 4830 auto receivedHeaders = firstByte 4831 .then([this,firstRequest](bool hasData) 4832 -> kj::Promise<HttpHeaders::RequestOrProtocolError> { 4833 if (hasData) { 4834 auto readHeaders = httpInput.readRequestHeaders(); 4835 if (!firstRequest) { 4836 // On requests other than the first, the header timeout starts ticking when we receive 4837 // the first byte of a pipeline response. 4838 readHeaders = readHeaders.exclusiveJoin( 4839 server.timer.afterDelay(server.settings.headerTimeout) 4840 .then([this]() -> HttpHeaders::RequestOrProtocolError { 4841 timedOut = true; 4842 return HttpHeaders::ProtocolError { 4843 408, "Request Timeout", 4844 "Timed out waiting for next request headers.", nullptr 4845 }; 4846 })); 4847 } 4848 return kj::mv(readHeaders); 4849 } else { 4850 // Client closed connection or pipeline timed out with no bytes received. This is not an 4851 // error, so don't report one. 
4852 this->closed = true; 4853 return HttpHeaders::RequestOrProtocolError(HttpHeaders::ProtocolError { 4854 408, "Request Timeout", 4855 "Client closed connection or connection timeout " 4856 "while waiting for request headers.", nullptr 4857 }); 4858 } 4859 }); 4860 4861 if (firstRequest) { 4862 // On the first request, the header timeout starts ticking immediately upon request opening. 4863 auto timeoutPromise = server.timer.afterDelay(server.settings.headerTimeout) 4864 .exclusiveJoin(server.onDrain.addBranch()) 4865 .then([this]() -> HttpHeaders::RequestOrProtocolError { 4866 timedOut = true; 4867 return HttpHeaders::ProtocolError { 4868 408, "Request Timeout", 4869 "Timed out waiting for initial request headers.", nullptr 4870 }; 4871 }); 4872 receivedHeaders = receivedHeaders.exclusiveJoin(kj::mv(timeoutPromise)); 4873 } 4874 4875 return receivedHeaders 4876 .then([this](HttpHeaders::RequestOrProtocolError&& requestOrProtocolError) 4877 -> kj::Promise<bool> { 4878 if (timedOut) { 4879 // Client took too long to send anything, so we're going to close the connection. In 4880 // theory, we should send back an HTTP 408 error -- it is designed exactly for this 4881 // purpose. Alas, in practice, Google Chrome does not have any special handling for 408 4882 // errors -- it will assume the error is a response to the next request it tries to send, 4883 // and will happily serve the error to the user. OTOH, if we simply close the connection, 4884 // Chrome does the "right thing", apparently. (Though I'm not sure what happens if a 4885 // request is in-flight when we close... if it's a GET, the browser should retry. But if 4886 // it's a POST, retrying may be dangerous. This is why 408 exists -- it unambiguously 4887 // tells the client that it should retry.) 4888 // 4889 // Also note that if we ever decide to send 408 again, we might want to send some other 4890 // error in the case that the server is draining, which also sets timedOut = true; see 4891 // above. 
4892 4893 return httpOutput.flush().then([this]() { 4894 return server.draining && httpInput.isCleanDrain(); 4895 }); 4896 } 4897 4898 if (closed) { 4899 // Client closed connection. Close our end too. 4900 return httpOutput.flush().then([]() { return false; }); 4901 } 4902 4903 KJ_SWITCH_ONEOF(requestOrProtocolError) { 4904 KJ_CASE_ONEOF(request, HttpHeaders::Request) { 4905 auto& headers = httpInput.getHeaders(); 4906 4907 currentMethod = request.method; 4908 auto body = httpInput.getEntityBody( 4909 HttpInputStreamImpl::REQUEST, request.method, 0, headers); 4910 4911 // TODO(perf): If the client disconnects, should we cancel the response? Probably, to 4912 // prevent permanent deadlock. It's slightly weird in that arguably the client should 4913 // be able to shutdown the upstream but still wait on the downstream, but I believe many 4914 // other HTTP servers do similar things. 4915 4916 auto promise = service.request( 4917 request.method, request.url, headers, *body, *this); 4918 return promise.then([this, body = kj::mv(body)]() mutable -> kj::Promise<bool> { 4919 // Response done. Await next request. 4920 4921 KJ_IF_MAYBE(p, webSocketError) { 4922 // sendWebSocketError() was called. Finish sending and close the connection. 4923 auto promise = kj::mv(*p); 4924 webSocketError = nullptr; 4925 return kj::mv(promise); 4926 } 4927 4928 if (upgraded) { 4929 // We've upgraded to WebSocket, and by now we should have closed the WebSocket. 4930 if (!webSocketClosed) { 4931 // This is gonna segfault later so abort now instead. 4932 KJ_LOG(FATAL, "Accepted WebSocket object must be destroyed before HttpService " 4933 "request handler completes."); 4934 abort(); 4935 } 4936 4937 // Once we start a WebSocket there's no going back to HTTP. 4938 return false; 4939 } 4940 4941 if (currentMethod != nullptr) { 4942 return sendError(); 4943 } 4944 4945 if (httpOutput.isBroken()) { 4946 // We started a response but didn't finish it. But HttpService returns success? 
4947 // Perhaps it decided that it doesn't want to finish this response. We'll have to 4948 // disconnect here. If the response body is not complete (e.g. Content-Length not 4949 // reached), the client should notice. We don't want to log an error because this 4950 // condition might be intentional on the service's part. 4951 return false; 4952 } 4953 4954 return httpOutput.flush().then( 4955 [this, body = kj::mv(body)]() mutable -> kj::Promise<bool> { 4956 if (httpInput.canReuse()) { 4957 // Things look clean. Go ahead and accept the next request. 4958 4959 // Note that we don't have to handle server.draining here because we'll take care of 4960 // it the next time around the loop. 4961 return loop(false); 4962 } else { 4963 // Apparently, the application did not read the request body. Maybe this is a bug, 4964 // or maybe not: maybe the client tried to upload too much data and the application 4965 // legitimately wants to cancel the upload without reading all it it. 4966 // 4967 // We have a problem, though: We did send a response, and we didn't send 4968 // `Connection: close`, so the client may expect that it can send another request. 4969 // Perhaps the client has even finished sending the previous request's body, in 4970 // which case the moment it finishes receiving the response, it could be completely 4971 // within its rights to start a new request. If we close the socket now, we might 4972 // interrupt that new request. 4973 // 4974 // There's no way we can get out of this perfectly cleanly. HTTP just isn't good 4975 // enough at connection management. The best we can do is give the client some grace 4976 // period and then abort the connection. 4977 4978 auto dummy = kj::heap<HttpDiscardingEntityWriter>(); 4979 auto lengthGrace = body->pumpTo(*dummy, server.settings.canceledUploadGraceBytes) 4980 .then([this](size_t amount) { 4981 if (httpInput.canReuse()) { 4982 // Success, we can continue. 4983 return true; 4984 } else { 4985 // Still more data. 
Give up. 4986 return false; 4987 } 4988 }); 4989 lengthGrace = lengthGrace.attach(kj::mv(dummy), kj::mv(body)); 4990 4991 auto timeGrace = server.timer.afterDelay(server.settings.canceledUploadGracePeriod) 4992 .then([]() { return false; }); 4993 4994 return lengthGrace.exclusiveJoin(kj::mv(timeGrace)) 4995 .then([this](bool clean) -> kj::Promise<bool> { 4996 if (clean) { 4997 // We recovered. Continue loop. 4998 return loop(false); 4999 } else { 5000 // Client still not done. Return broken. 5001 return false; 5002 } 5003 }); 5004 } 5005 }); 5006 }); 5007 } 5008 KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { 5009 // Bad request. 5010 5011 // sendError() uses Response::send(), which requires that we have a currentMethod, but we 5012 // never read one. GET seems like the correct choice here. 5013 currentMethod = HttpMethod::GET; 5014 return sendError(kj::mv(protocolError)); 5015 } 5016 } 5017 5018 KJ_UNREACHABLE; 5019 }); 5020 } 5021 5022 kj::Own<kj::AsyncOutputStream> send( 5023 uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers, 5024 kj::Maybe<uint64_t> expectedBodySize) override { 5025 auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()"); 5026 currentMethod = nullptr; 5027 5028 kj::StringPtr connectionHeaders[HttpHeaders::CONNECTION_HEADERS_COUNT]; 5029 kj::String lengthStr; 5030 5031 if (!closeAfterSend) { 5032 // Check if application wants us to close connections. 5033 KJ_IF_MAYBE(c, server.settings.callbacks) { 5034 if (c->shouldClose()) { 5035 closeAfterSend = true; 5036 } 5037 } 5038 } 5039 5040 // TODO(0.10): If `server.draining`, we should probably set `closeAfterSend` -- UNLESS the 5041 // connection was created using listenHttpCleanDrain(), in which case the application may 5042 // intend to continue using the connection. 
5043 5044 if (closeAfterSend) { 5045 connectionHeaders[HttpHeaders::BuiltinIndices::CONNECTION] = "close"; 5046 } 5047 5048 if (statusCode == 204 || statusCode == 304) { 5049 // No entity-body. 5050 } else if (statusCode == 205) { 5051 // Status code 205 also has no body, but unlike 204 and 304, it must explicitly encode an 5052 // empty body, e.g. using content-length: 0. I'm guessing this is one of those things, where 5053 // some early clients expected an explicit body while others assumed an empty body, and so 5054 // the standard had to choose the common denominator. 5055 // 5056 // Spec: https://tools.ietf.org/html/rfc7231#section-6.3.6 5057 connectionHeaders[HttpHeaders::BuiltinIndices::CONTENT_LENGTH] = "0"; 5058 } else KJ_IF_MAYBE(s, expectedBodySize) { 5059 // HACK: We interpret a zero-length expected body length on responses to HEAD requests to mean 5060 // "don't set a Content-Length header at all." This provides a way to omit a body header on 5061 // HEAD responses with non-null-body status codes. This is a hack that *only* makes sense 5062 // for HEAD responses. 5063 if (method != HttpMethod::HEAD || *s > 0) { 5064 lengthStr = kj::str(*s); 5065 connectionHeaders[HttpHeaders::BuiltinIndices::CONTENT_LENGTH] = lengthStr; 5066 } 5067 } else { 5068 connectionHeaders[HttpHeaders::BuiltinIndices::TRANSFER_ENCODING] = "chunked"; 5069 } 5070 5071 // For HEAD requests, if the application specified a Content-Length or Transfer-Encoding 5072 // header, use that instead of whatever we decided above. 
5073 kj::ArrayPtr<kj::StringPtr> connectionHeadersArray = connectionHeaders; 5074 if (method == HttpMethod::HEAD) { 5075 if (headers.get(HttpHeaderId::CONTENT_LENGTH) != nullptr || 5076 headers.get(HttpHeaderId::TRANSFER_ENCODING) != nullptr) { 5077 connectionHeadersArray = connectionHeadersArray 5078 .slice(0, HttpHeaders::HEAD_RESPONSE_CONNECTION_HEADERS_COUNT); 5079 } 5080 } 5081 5082 httpOutput.writeHeaders(headers.serializeResponse( 5083 statusCode, statusText, connectionHeadersArray)); 5084 5085 kj::Own<kj::AsyncOutputStream> bodyStream; 5086 if (method == HttpMethod::HEAD) { 5087 // Ignore entity-body. 5088 httpOutput.finishBody(); 5089 return heap<HttpDiscardingEntityWriter>(); 5090 } else if (statusCode == 204 || statusCode == 205 || statusCode == 304) { 5091 // No entity-body. 5092 httpOutput.finishBody(); 5093 return heap<HttpNullEntityWriter>(); 5094 } else KJ_IF_MAYBE(s, expectedBodySize) { 5095 return heap<HttpFixedLengthEntityWriter>(httpOutput, *s); 5096 } else { 5097 return heap<HttpChunkedEntityWriter>(httpOutput); 5098 } 5099 } 5100 5101 kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override { 5102 auto& requestHeaders = httpInput.getHeaders(); 5103 KJ_REQUIRE(requestHeaders.isWebSocket(), 5104 "can't call acceptWebSocket() if the request headers didn't have Upgrade: WebSocket"); 5105 5106 auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()"); 5107 // Unlike send(), we neither need nor want to null out currentMethod. The error cases below 5108 // depend on it being non-null to allow error responses to be sent, and the happy path expects 5109 // it to be GET. 
5110 5111 if (method != HttpMethod::GET) { 5112 return sendWebSocketError("WebSocket must be initiated with a GET request."); 5113 } 5114 5115 if (requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") { 5116 return sendWebSocketError("The requested WebSocket version is not supported."); 5117 } 5118 5119 kj::String key; 5120 KJ_IF_MAYBE(k, requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) { 5121 key = kj::str(*k); 5122 } else { 5123 return sendWebSocketError("Missing Sec-WebSocket-Key"); 5124 } 5125 5126 auto websocketAccept = generateWebSocketAccept(key); 5127 5128 kj::StringPtr connectionHeaders[HttpHeaders::WEBSOCKET_CONNECTION_HEADERS_COUNT]; 5129 connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_ACCEPT] = websocketAccept; 5130 connectionHeaders[HttpHeaders::BuiltinIndices::UPGRADE] = "websocket"; 5131 connectionHeaders[HttpHeaders::BuiltinIndices::CONNECTION] = "Upgrade"; 5132 5133 httpOutput.writeHeaders(headers.serializeResponse( 5134 101, "Switching Protocols", connectionHeaders)); 5135 5136 upgraded = true; 5137 // We need to give the WebSocket an Own<AsyncIoStream>, but we only have a reference. This is 5138 // safe because the application is expected to drop the WebSocket object before returning 5139 // from the request handler. For some extra safety, we check that webSocketClosed has been 5140 // set true when the handler returns. 5141 auto deferNoteClosed = kj::defer([this]() { webSocketClosed = true; }); 5142 kj::Own<kj::AsyncIoStream> ownStream(&stream, kj::NullDisposer::instance); 5143 return upgradeToWebSocket(ownStream.attach(kj::mv(deferNoteClosed)), 5144 httpInput, httpOutput, nullptr); 5145 } 5146 5147 kj::Promise<bool> sendError(HttpHeaders::ProtocolError protocolError) { 5148 closeAfterSend = true; 5149 5150 // Client protocol errors always happen on request headers parsing, before we call into the 5151 // HttpService, meaning no response has been sent and we can provide a Response object. 
5152 auto promise = server.settings.errorHandler.orDefault(*this).handleClientProtocolError( 5153 kj::mv(protocolError), *this); 5154 5155 return promise.then([this]() { return httpOutput.flush(); }) 5156 .then([]() { return false; }); // loop ends after flush 5157 } 5158 5159 kj::Promise<bool> sendError(kj::Exception&& exception) { 5160 closeAfterSend = true; 5161 5162 // We only provide the Response object if we know we haven't already sent a response. 5163 auto promise = server.settings.errorHandler.orDefault(*this).handleApplicationError( 5164 kj::mv(exception), currentMethod.map([this](auto&&) -> Response& { return *this; })); 5165 5166 return promise.then([this]() { return httpOutput.flush(); }) 5167 .then([]() { return false; }); // loop ends after flush 5168 } 5169 5170 kj::Promise<bool> sendError() { 5171 closeAfterSend = true; 5172 5173 // We can provide a Response object, since none has already been sent. 5174 auto promise = server.settings.errorHandler.orDefault(*this).handleNoResponse(*this); 5175 5176 return promise.then([this]() { return httpOutput.flush(); }) 5177 .then([]() { return false; }); // loop ends after flush 5178 } 5179 5180 kj::Own<WebSocket> sendWebSocketError(StringPtr errorMessage) { 5181 kj::Exception exception = KJ_EXCEPTION(FAILED, 5182 "received bad WebSocket handshake", errorMessage); 5183 webSocketError = sendError( 5184 HttpHeaders::ProtocolError { 400, "Bad Request", errorMessage, nullptr }); 5185 kj::throwRecoverableException(kj::mv(exception)); 5186 5187 // Fallback path when exceptions are disabled. 
5188 class BrokenWebSocket final: public WebSocket { 5189 public: 5190 BrokenWebSocket(kj::Exception exception): exception(kj::mv(exception)) {} 5191 5192 kj::Promise<void> send(kj::ArrayPtr<const byte> message) override { 5193 return kj::cp(exception); 5194 } 5195 kj::Promise<void> send(kj::ArrayPtr<const char> message) override { 5196 return kj::cp(exception); 5197 } 5198 kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override { 5199 return kj::cp(exception); 5200 } 5201 kj::Promise<void> disconnect() override { 5202 return kj::cp(exception); 5203 } 5204 void abort() override { 5205 kj::throwRecoverableException(kj::cp(exception)); 5206 } 5207 kj::Promise<void> whenAborted() override { 5208 return kj::cp(exception); 5209 } 5210 kj::Promise<Message> receive(size_t maxSize) override { 5211 return kj::cp(exception); 5212 } 5213 5214 uint64_t sentByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } 5215 uint64_t receivedByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } 5216 5217 private: 5218 kj::Exception exception; 5219 }; 5220 5221 return kj::heap<BrokenWebSocket>(KJ_EXCEPTION(FAILED, 5222 "received bad WebSocket handshake", errorMessage)); 5223 } 5224 }; 5225 5226 HttpServer::HttpServer(kj::Timer& timer, const HttpHeaderTable& requestHeaderTable, 5227 HttpService& service, Settings settings) 5228 : HttpServer(timer, requestHeaderTable, &service, settings, 5229 kj::newPromiseAndFulfiller<void>()) {} 5230 5231 HttpServer::HttpServer(kj::Timer& timer, const HttpHeaderTable& requestHeaderTable, 5232 HttpServiceFactory serviceFactory, Settings settings) 5233 : HttpServer(timer, requestHeaderTable, kj::mv(serviceFactory), settings, 5234 kj::newPromiseAndFulfiller<void>()) {} 5235 5236 HttpServer::HttpServer(kj::Timer& timer, const HttpHeaderTable& requestHeaderTable, 5237 kj::OneOf<HttpService*, HttpServiceFactory> service, 5238 Settings settings, kj::PromiseFulfillerPair<void> paf) 5239 : timer(timer), 
requestHeaderTable(requestHeaderTable), service(kj::mv(service)), 5240 settings(settings), onDrain(paf.promise.fork()), drainFulfiller(kj::mv(paf.fulfiller)), 5241 tasks(*this) {} 5242 5243 kj::Promise<void> HttpServer::drain() { 5244 KJ_REQUIRE(!draining, "you can only call drain() once"); 5245 5246 draining = true; 5247 drainFulfiller->fulfill(); 5248 5249 if (connectionCount == 0) { 5250 return kj::READY_NOW; 5251 } else { 5252 auto paf = kj::newPromiseAndFulfiller<void>(); 5253 zeroConnectionsFulfiller = kj::mv(paf.fulfiller); 5254 return kj::mv(paf.promise); 5255 } 5256 } 5257 5258 kj::Promise<void> HttpServer::listenHttp(kj::ConnectionReceiver& port) { 5259 return listenLoop(port).exclusiveJoin(onDrain.addBranch()); 5260 } 5261 5262 kj::Promise<void> HttpServer::listenLoop(kj::ConnectionReceiver& port) { 5263 return port.accept() 5264 .then([this,&port](kj::Own<kj::AsyncIoStream>&& connection) -> kj::Promise<void> { 5265 if (draining) { 5266 // Can get here if we *just* started draining. 5267 return kj::READY_NOW; 5268 } 5269 5270 tasks.add(listenHttp(kj::mv(connection))); 5271 return listenLoop(port); 5272 }); 5273 } 5274 5275 kj::Promise<void> HttpServer::listenHttp(kj::Own<kj::AsyncIoStream> connection) { 5276 auto promise = listenHttpCleanDrain(*connection).ignoreResult(); 5277 5278 // eagerlyEvaluate() to maintain historical guarantee that this method eagerly closes the 5279 // connection when done. 
5280 return promise.attach(kj::mv(connection)).eagerlyEvaluate(nullptr); 5281 } 5282 5283 kj::Promise<bool> HttpServer::listenHttpCleanDrain(kj::AsyncIoStream& connection) { 5284 kj::Own<Connection> obj; 5285 5286 KJ_SWITCH_ONEOF(service) { 5287 KJ_CASE_ONEOF(ptr, HttpService*) { 5288 obj = heap<Connection>(*this, connection, *ptr); 5289 } 5290 KJ_CASE_ONEOF(func, HttpServiceFactory) { 5291 auto srv = func(connection); 5292 obj = heap<Connection>(*this, connection, *srv); 5293 obj = obj.attach(kj::mv(srv)); 5294 } 5295 } 5296 5297 // Start reading requests and responding to them, but immediately cancel processing if the client 5298 // disconnects. 5299 auto promise = obj->startLoop(true) 5300 .exclusiveJoin(connection.whenWriteDisconnected().then([]() {return false;})); 5301 5302 // Eagerly evaluate so that we drop the connection when the promise resolves, even if the caller 5303 // doesn't eagerly evaluate. 5304 return promise.attach(kj::mv(obj)).eagerlyEvaluate(nullptr); 5305 } 5306 5307 void HttpServer::taskFailed(kj::Exception&& exception) { 5308 KJ_LOG(ERROR, "unhandled exception in HTTP server", exception); 5309 } 5310 5311 kj::Promise<void> HttpServerErrorHandler::handleClientProtocolError( 5312 HttpHeaders::ProtocolError protocolError, kj::HttpService::Response& response) { 5313 // Default error handler implementation. 
5314 5315 HttpHeaderTable headerTable {}; 5316 HttpHeaders headers(headerTable); 5317 headers.set(HttpHeaderId::CONTENT_TYPE, "text/plain"); 5318 5319 auto errorMessage = kj::str("ERROR: ", protocolError.description); 5320 auto body = response.send(protocolError.statusCode, protocolError.statusMessage, 5321 headers, errorMessage.size()); 5322 5323 return body->write(errorMessage.begin(), errorMessage.size()) 5324 .attach(kj::mv(errorMessage), kj::mv(body)); 5325 } 5326 5327 kj::Promise<void> HttpServerErrorHandler::handleApplicationError( 5328 kj::Exception exception, kj::Maybe<kj::HttpService::Response&> response) { 5329 // Default error handler implementation. 5330 5331 if (exception.getType() == kj::Exception::Type::DISCONNECTED) { 5332 // How do we tell an HTTP client that there was a transient network error, and it should 5333 // try again immediately? There's no HTTP status code for this (503 is meant for "try 5334 // again later, not now"). Here's an idea: Don't send any response; just close the 5335 // connection, so that it looks like the connection between the HTTP client and server 5336 // was dropped. A good client should treat this exactly the way we want. 5337 // 5338 // We also bail here to avoid logging the disconnection, which isn't very interesting. 5339 return kj::READY_NOW; 5340 } 5341 5342 KJ_IF_MAYBE(r, response) { 5343 HttpHeaderTable headerTable {}; 5344 HttpHeaders headers(headerTable); 5345 headers.set(HttpHeaderId::CONTENT_TYPE, "text/plain"); 5346 5347 kj::String errorMessage; 5348 kj::Own<AsyncOutputStream> body; 5349 5350 if (exception.getType() == kj::Exception::Type::OVERLOADED) { 5351 errorMessage = kj::str( 5352 "ERROR: The server is temporarily unable to handle your request. 
Details:\n\n", exception); 5353 body = r->send(503, "Service Unavailable", headers, errorMessage.size()); 5354 } else if (exception.getType() == kj::Exception::Type::UNIMPLEMENTED) { 5355 errorMessage = kj::str( 5356 "ERROR: The server does not implement this operation. Details:\n\n", exception); 5357 body = r->send(501, "Not Implemented", headers, errorMessage.size()); 5358 } else { 5359 errorMessage = kj::str( 5360 "ERROR: The server threw an exception. Details:\n\n", exception); 5361 body = r->send(500, "Internal Server Error", headers, errorMessage.size()); 5362 } 5363 5364 return body->write(errorMessage.begin(), errorMessage.size()) 5365 .attach(kj::mv(errorMessage), kj::mv(body)); 5366 } 5367 5368 KJ_LOG(ERROR, "HttpService threw exception after generating a partial response", 5369 "too late to report error to client", exception); 5370 return kj::READY_NOW; 5371 } 5372 5373 kj::Promise<void> HttpServerErrorHandler::handleNoResponse(kj::HttpService::Response& response) { 5374 HttpHeaderTable headerTable {}; 5375 HttpHeaders headers(headerTable); 5376 headers.set(HttpHeaderId::CONTENT_TYPE, "text/plain"); 5377 5378 constexpr auto errorMessage = "ERROR: The HttpService did not generate a response."_kj; 5379 auto body = response.send(500, "Internal Server Error", headers, errorMessage.size()); 5380 5381 return body->write(errorMessage.begin(), errorMessage.size()).attach(kj::mv(body)); 5382 } 5383 5384 } // namespace kj