serialize-async.c++ (18853B)
// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
21 22 // Includes just for need SOL_SOCKET and SO_SNDBUF 23 #if _WIN32 24 #include <kj/win32-api-version.h> 25 26 #include <winsock2.h> 27 #include <mswsock.h> 28 #include <kj/windows-sanity.h> 29 #else 30 #include <sys/socket.h> 31 #endif 32 33 #include "serialize-async.h" 34 #include <kj/debug.h> 35 #include <kj/io.h> 36 37 namespace capnp { 38 39 namespace { 40 41 class AsyncMessageReader: public MessageReader { 42 public: 43 inline AsyncMessageReader(ReaderOptions options): MessageReader(options) { 44 memset(firstWord, 0, sizeof(firstWord)); 45 } 46 ~AsyncMessageReader() noexcept(false) {} 47 48 kj::Promise<bool> read(kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace); 49 50 kj::Promise<kj::Maybe<size_t>> readWithFds( 51 kj::AsyncCapabilityStream& inputStream, 52 kj::ArrayPtr<kj::AutoCloseFd> fds, kj::ArrayPtr<word> scratchSpace); 53 54 // implements MessageReader ---------------------------------------- 55 56 kj::ArrayPtr<const word> getSegment(uint id) override { 57 if (id >= segmentCount()) { 58 return nullptr; 59 } else { 60 uint32_t size = id == 0 ? segment0Size() : moreSizes[id - 1].get(); 61 return kj::arrayPtr(segmentStarts[id], size); 62 } 63 } 64 65 private: 66 _::WireValue<uint32_t> firstWord[2]; 67 kj::Array<_::WireValue<uint32_t>> moreSizes; 68 kj::Array<const word*> segmentStarts; 69 70 kj::Array<word> ownedSpace; 71 // Only if scratchSpace wasn't big enough. 
72 73 inline uint segmentCount() { return firstWord[0].get() + 1; } 74 inline uint segment0Size() { return firstWord[1].get(); } 75 76 kj::Promise<void> readAfterFirstWord( 77 kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace); 78 kj::Promise<void> readSegments( 79 kj::AsyncInputStream& inputStream, kj::ArrayPtr<word> scratchSpace); 80 }; 81 82 kj::Promise<bool> AsyncMessageReader::read(kj::AsyncInputStream& inputStream, 83 kj::ArrayPtr<word> scratchSpace) { 84 return inputStream.tryRead(firstWord, sizeof(firstWord), sizeof(firstWord)) 85 .then([this,&inputStream,KJ_CPCAP(scratchSpace)](size_t n) mutable -> kj::Promise<bool> { 86 if (n == 0) { 87 return false; 88 } else if (n < sizeof(firstWord)) { 89 // EOF in first word. 90 kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); 91 return false; 92 } 93 94 return readAfterFirstWord(inputStream, scratchSpace).then([]() { return true; }); 95 }); 96 } 97 98 kj::Promise<kj::Maybe<size_t>> AsyncMessageReader::readWithFds( 99 kj::AsyncCapabilityStream& inputStream, kj::ArrayPtr<kj::AutoCloseFd> fds, 100 kj::ArrayPtr<word> scratchSpace) { 101 return inputStream.tryReadWithFds(firstWord, sizeof(firstWord), sizeof(firstWord), 102 fds.begin(), fds.size()) 103 .then([this,&inputStream,KJ_CPCAP(scratchSpace)] 104 (kj::AsyncCapabilityStream::ReadResult result) mutable 105 -> kj::Promise<kj::Maybe<size_t>> { 106 if (result.byteCount == 0) { 107 return kj::Maybe<size_t>(nullptr); 108 } else if (result.byteCount < sizeof(firstWord)) { 109 // EOF in first word. 
110 kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); 111 return kj::Maybe<size_t>(nullptr); 112 } 113 114 return readAfterFirstWord(inputStream, scratchSpace) 115 .then([result]() -> kj::Maybe<size_t> { return result.capCount; }); 116 }); 117 } 118 119 kj::Promise<void> AsyncMessageReader::readAfterFirstWord(kj::AsyncInputStream& inputStream, 120 kj::ArrayPtr<word> scratchSpace) { 121 if (segmentCount() == 0) { 122 firstWord[1].set(0); 123 } 124 125 // Reject messages with too many segments for security reasons. 126 KJ_REQUIRE(segmentCount() < 512, "Message has too many segments.") { 127 return kj::READY_NOW; // exception will be propagated 128 } 129 130 if (segmentCount() > 1) { 131 // Read sizes for all segments except the first. Include padding if necessary. 132 moreSizes = kj::heapArray<_::WireValue<uint32_t>>(segmentCount() & ~1); 133 return inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0])) 134 .then([this,&inputStream,KJ_CPCAP(scratchSpace)]() mutable { 135 return readSegments(inputStream, scratchSpace); 136 }); 137 } else { 138 return readSegments(inputStream, scratchSpace); 139 } 140 } 141 142 kj::Promise<void> AsyncMessageReader::readSegments(kj::AsyncInputStream& inputStream, 143 kj::ArrayPtr<word> scratchSpace) { 144 size_t totalWords = segment0Size(); 145 146 if (segmentCount() > 1) { 147 for (uint i = 0; i < segmentCount() - 1; i++) { 148 totalWords += moreSizes[i].get(); 149 } 150 } 151 152 // Don't accept a message which the receiver couldn't possibly traverse without hitting the 153 // traversal limit. Without this check, a malicious client could transmit a very large segment 154 // size to make the receiver allocate excessive space and possibly crash. 155 KJ_REQUIRE(totalWords <= getOptions().traversalLimitInWords, 156 "Message is too large. 
To increase the limit on the receiving end, see " 157 "capnp::ReaderOptions.") { 158 return kj::READY_NOW; // exception will be propagated 159 } 160 161 if (scratchSpace.size() < totalWords) { 162 // TODO(perf): Consider allocating each segment as a separate chunk to reduce memory 163 // fragmentation. 164 ownedSpace = kj::heapArray<word>(totalWords); 165 scratchSpace = ownedSpace; 166 } 167 168 segmentStarts = kj::heapArray<const word*>(segmentCount()); 169 170 segmentStarts[0] = scratchSpace.begin(); 171 172 if (segmentCount() > 1) { 173 size_t offset = segment0Size(); 174 175 for (uint i = 1; i < segmentCount(); i++) { 176 segmentStarts[i] = scratchSpace.begin() + offset; 177 offset += moreSizes[i-1].get(); 178 } 179 } 180 181 return inputStream.read(scratchSpace.begin(), totalWords * sizeof(word)); 182 } 183 184 185 } // namespace 186 187 kj::Promise<kj::Own<MessageReader>> readMessage( 188 kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { 189 auto reader = kj::heap<AsyncMessageReader>(options); 190 auto promise = reader->read(input, scratchSpace); 191 return promise.then([reader = kj::mv(reader)](bool success) mutable -> kj::Own<MessageReader> { 192 if (!success) { 193 kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); 194 } 195 return kj::mv(reader); 196 }); 197 } 198 199 kj::Promise<kj::Maybe<kj::Own<MessageReader>>> tryReadMessage( 200 kj::AsyncInputStream& input, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { 201 auto reader = kj::heap<AsyncMessageReader>(options); 202 auto promise = reader->read(input, scratchSpace); 203 return promise.then([reader = kj::mv(reader)](bool success) mutable 204 -> kj::Maybe<kj::Own<MessageReader>> { 205 if (success) { 206 return kj::mv(reader); 207 } else { 208 return nullptr; 209 } 210 }); 211 } 212 213 kj::Promise<MessageReaderAndFds> readMessage( 214 kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace, 215 ReaderOptions options, 
kj::ArrayPtr<word> scratchSpace) { 216 auto reader = kj::heap<AsyncMessageReader>(options); 217 auto promise = reader->readWithFds(input, fdSpace, scratchSpace); 218 return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe<size_t> nfds) mutable 219 -> MessageReaderAndFds { 220 KJ_IF_MAYBE(n, nfds) { 221 return { kj::mv(reader), fdSpace.slice(0, *n) }; 222 } else { 223 kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); 224 return { kj::mv(reader), nullptr }; 225 } 226 }); 227 } 228 229 kj::Promise<kj::Maybe<MessageReaderAndFds>> tryReadMessage( 230 kj::AsyncCapabilityStream& input, kj::ArrayPtr<kj::AutoCloseFd> fdSpace, 231 ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { 232 auto reader = kj::heap<AsyncMessageReader>(options); 233 auto promise = reader->readWithFds(input, fdSpace, scratchSpace); 234 return promise.then([reader = kj::mv(reader), fdSpace](kj::Maybe<size_t> nfds) mutable 235 -> kj::Maybe<MessageReaderAndFds> { 236 KJ_IF_MAYBE(n, nfds) { 237 return MessageReaderAndFds { kj::mv(reader), fdSpace.slice(0, *n) }; 238 } else { 239 return nullptr; 240 } 241 }); 242 } 243 244 // ======================================================================================= 245 246 namespace { 247 248 struct WriteArrays { 249 // Holds arrays that must remain valid until a write completes. 250 251 kj::Array<_::WireValue<uint32_t>> table; 252 kj::Array<kj::ArrayPtr<const byte>> pieces; 253 }; 254 255 inline size_t tableSizeForSegments(size_t segmentsSize) { 256 return (segmentsSize + 2) & ~size_t(1); 257 } 258 259 // Helper function that allocates and fills the pointed-to table with info about the segments and 260 // populates the pieces array with pointers to the segments. 
261 void fillWriteArraysWithMessage(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments, 262 kj::ArrayPtr<_::WireValue<uint32_t>> table, 263 kj::ArrayPtr<kj::ArrayPtr<const byte>> pieces) { 264 KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message."); 265 266 // We write the segment count - 1 because this makes the first word zero for single-segment 267 // messages, improving compression. We don't bother doing this with segment sizes because 268 // one-word segments are rare anyway. 269 table[0].set(segments.size() - 1); 270 for (uint i = 0; i < segments.size(); i++) { 271 table[i + 1].set(segments[i].size()); 272 } 273 if (segments.size() % 2 == 0) { 274 // Set padding byte. 275 table[segments.size() + 1].set(0); 276 } 277 278 KJ_ASSERT(pieces.size() == segments.size() + 1, "incorrectly sized pieces array during write"); 279 pieces[0] = table.asBytes(); 280 for (uint i = 0; i < segments.size(); i++) { 281 pieces[i + 1] = segments[i].asBytes(); 282 } 283 } 284 285 template <typename WriteFunc> 286 kj::Promise<void> writeMessageImpl(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments, 287 WriteFunc&& writeFunc) { 288 KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message."); 289 290 WriteArrays arrays; 291 arrays.table = kj::heapArray<_::WireValue<uint32_t>>(tableSizeForSegments(segments.size())); 292 arrays.pieces = kj::heapArray<kj::ArrayPtr<const byte>>(segments.size() + 1); 293 fillWriteArraysWithMessage(segments, arrays.table, arrays.pieces); 294 295 auto promise = writeFunc(arrays.pieces); 296 297 // Make sure the arrays aren't freed until the write completes. 
298 return promise.then(kj::mvCapture(arrays, [](WriteArrays&&) {})); 299 } 300 301 template <typename WriteFunc> 302 kj::Promise<void> writeMessagesImpl( 303 kj::ArrayPtr<kj::ArrayPtr<const kj::ArrayPtr<const word>>> messages, WriteFunc&& writeFunc) { 304 KJ_REQUIRE(messages.size() > 0, "Tried to serialize zero messages."); 305 306 // Determine how large the shared table and pieces arrays needs to be. 307 size_t tableSize = 0; 308 size_t piecesSize = 0; 309 for (auto& segments : messages) { 310 tableSize += tableSizeForSegments(segments.size()); 311 piecesSize += segments.size() + 1; 312 } 313 auto table = kj::heapArray<_::WireValue<uint32_t>>(tableSize); 314 auto pieces = kj::heapArray<kj::ArrayPtr<const byte>>(piecesSize); 315 316 size_t tableValsWritten = 0; 317 size_t piecesWritten = 0; 318 for (auto i : kj::indices(messages)) { 319 const size_t tableValsToWrite = tableSizeForSegments(messages[i].size()); 320 const size_t piecesToWrite = messages[i].size() + 1; 321 fillWriteArraysWithMessage( 322 messages[i], 323 table.slice(tableValsWritten, tableValsWritten + tableValsToWrite), 324 pieces.slice(piecesWritten, piecesWritten + piecesToWrite)); 325 tableValsWritten += tableValsToWrite; 326 piecesWritten += piecesToWrite; 327 } 328 329 auto promise = writeFunc(pieces); 330 return promise.attach(kj::mv(table), kj::mv(pieces)); 331 } 332 333 } // namespace 334 335 kj::Promise<void> writeMessage(kj::AsyncOutputStream& output, 336 kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) { 337 return writeMessageImpl(segments, 338 [&](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) { 339 return output.write(pieces); 340 }); 341 } 342 343 kj::Promise<void> writeMessage(kj::AsyncCapabilityStream& output, kj::ArrayPtr<const int> fds, 344 kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) { 345 return writeMessageImpl(segments, 346 [&](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) { 347 return output.writeWithFds(pieces[0], pieces.slice(1, pieces.size()), 
fds); 348 }); 349 } 350 351 kj::Promise<void> writeMessages( 352 kj::AsyncOutputStream& output, 353 kj::ArrayPtr<kj::ArrayPtr<const kj::ArrayPtr<const word>>> messages) { 354 return writeMessagesImpl(messages, 355 [&](kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) { 356 return output.write(pieces); 357 }); 358 } 359 360 kj::Promise<void> writeMessages( 361 kj::AsyncOutputStream& output, kj::ArrayPtr<MessageBuilder*> builders) { 362 auto messages = kj::heapArray<kj::ArrayPtr<const kj::ArrayPtr<const word>>>(builders.size()); 363 for (auto i : kj::indices(builders)) { 364 messages[i] = builders[i]->getSegmentsForOutput(); 365 } 366 return writeMessages(output, messages); 367 } 368 369 kj::Promise<void> MessageStream::writeMessages(kj::ArrayPtr<MessageBuilder*> builders) { 370 auto messages = kj::heapArray<kj::ArrayPtr<const kj::ArrayPtr<const word>>>(builders.size()); 371 for (auto i : kj::indices(builders)) { 372 messages[i] = builders[i]->getSegmentsForOutput(); 373 } 374 return writeMessages(messages); 375 } 376 377 AsyncIoMessageStream::AsyncIoMessageStream(kj::AsyncIoStream& stream) 378 : stream(stream) {}; 379 380 kj::Promise<kj::Maybe<MessageReaderAndFds>> AsyncIoMessageStream::tryReadMessage( 381 kj::ArrayPtr<kj::AutoCloseFd> fdSpace, 382 ReaderOptions options, 383 kj::ArrayPtr<word> scratchSpace) { 384 return capnp::tryReadMessage(stream, options, scratchSpace) 385 .then([](kj::Maybe<kj::Own<MessageReader>> maybeReader) -> kj::Maybe<MessageReaderAndFds> { 386 KJ_IF_MAYBE(reader, maybeReader) { 387 return MessageReaderAndFds { kj::mv(*reader), nullptr }; 388 } else { 389 return nullptr; 390 } 391 }); 392 } 393 394 kj::Promise<void> AsyncIoMessageStream::writeMessage( 395 kj::ArrayPtr<const int> fds, 396 kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) { 397 return capnp::writeMessage(stream, segments); 398 } 399 400 kj::Promise<void> AsyncIoMessageStream::writeMessages( 401 kj::ArrayPtr<kj::ArrayPtr<const kj::ArrayPtr<const word>>> messages) { 402 
return capnp::writeMessages(stream, messages); 403 } 404 405 kj::Maybe<int> getSendBufferSize(kj::AsyncIoStream& stream) { 406 // TODO(perf): It might be nice to have a tryGetsockopt() that doesn't require catching 407 // exceptions? 408 int bufSize = 0; 409 KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() { 410 uint len = sizeof(int); 411 stream.getsockopt(SOL_SOCKET, SO_SNDBUF, &bufSize, &len); 412 KJ_ASSERT(len == sizeof(bufSize)) { break; } 413 })) { 414 if (exception->getType() != kj::Exception::Type::UNIMPLEMENTED) { 415 // TODO(someday): Figure out why getting SO_SNDBUF sometimes throws EINVAL. I suspect it 416 // happens when the remote side has closed their read end, meaning we no longer have 417 // a send buffer, but I don't know what is the best way to verify that that was actually 418 // the reason. I'd prefer not to ignore EINVAL errors in general. 419 420 // kj::throwRecoverableException(kj::mv(*exception)); 421 } 422 return nullptr; 423 } 424 return bufSize; 425 } 426 427 kj::Promise<void> AsyncIoMessageStream::end() { 428 stream.shutdownWrite(); 429 return kj::READY_NOW; 430 } 431 432 kj::Maybe<int> AsyncIoMessageStream::getSendBufferSize() { 433 return capnp::getSendBufferSize(stream); 434 } 435 436 AsyncCapabilityMessageStream::AsyncCapabilityMessageStream(kj::AsyncCapabilityStream& stream) 437 : stream(stream) {}; 438 439 kj::Promise<kj::Maybe<MessageReaderAndFds>> AsyncCapabilityMessageStream::tryReadMessage( 440 kj::ArrayPtr<kj::AutoCloseFd> fdSpace, 441 ReaderOptions options, 442 kj::ArrayPtr<word> scratchSpace) { 443 return capnp::tryReadMessage(stream, fdSpace, options, scratchSpace); 444 } 445 446 kj::Promise<void> AsyncCapabilityMessageStream::writeMessage( 447 kj::ArrayPtr<const int> fds, 448 kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) { 449 return capnp::writeMessage(stream, fds, segments); 450 } 451 452 kj::Promise<void> AsyncCapabilityMessageStream::writeMessages( 453 kj::ArrayPtr<kj::ArrayPtr<const kj::ArrayPtr<const 
word>>> messages) { 454 return capnp::writeMessages(stream, messages); 455 } 456 457 kj::Maybe<int> AsyncCapabilityMessageStream::getSendBufferSize() { 458 return capnp::getSendBufferSize(stream); 459 } 460 461 kj::Promise<void> AsyncCapabilityMessageStream::end() { 462 stream.shutdownWrite(); 463 return kj::READY_NOW; 464 } 465 466 kj::Promise<kj::Own<MessageReader>> MessageStream::readMessage( 467 ReaderOptions options, 468 kj::ArrayPtr<word> scratchSpace) { 469 return tryReadMessage(options, scratchSpace).then([](kj::Maybe<kj::Own<MessageReader>> maybeResult) { 470 KJ_IF_MAYBE(result, maybeResult) { 471 return kj::mv(*result); 472 } else { 473 kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); 474 KJ_UNREACHABLE; 475 } 476 }); 477 } 478 479 kj::Promise<kj::Maybe<kj::Own<MessageReader>>> MessageStream::tryReadMessage( 480 ReaderOptions options, 481 kj::ArrayPtr<word> scratchSpace) { 482 return tryReadMessage(nullptr, options, scratchSpace) 483 .then([](auto maybeReaderAndFds) -> kj::Maybe<kj::Own<MessageReader>> { 484 KJ_IF_MAYBE(readerAndFds, maybeReaderAndFds) { 485 return kj::mv(readerAndFds->reader); 486 } else { 487 return nullptr; 488 } 489 }); 490 } 491 492 kj::Promise<MessageReaderAndFds> MessageStream::readMessage( 493 kj::ArrayPtr<kj::AutoCloseFd> fdSpace, 494 ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { 495 return tryReadMessage(fdSpace, options, scratchSpace).then([](auto maybeResult) { 496 KJ_IF_MAYBE(result, maybeResult) { 497 return kj::mv(*result); 498 } else { 499 kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "Premature EOF.")); 500 KJ_UNREACHABLE; 501 } 502 }); 503 } 504 505 } // namespace capnp