capnproto

FORK: Cap'n Proto serialization/RPC system - core tools and C++ library
git clone https://git.neptards.moe/neptards/capnproto.git
Log | Files | Refs | README | LICENSE

serialize.c++ (11196B)


      1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
      2 // Licensed under the MIT License:
      3 //
      4 // Permission is hereby granted, free of charge, to any person obtaining a copy
      5 // of this software and associated documentation files (the "Software"), to deal
      6 // in the Software without restriction, including without limitation the rights
      7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      8 // copies of the Software, and to permit persons to whom the Software is
      9 // furnished to do so, subject to the following conditions:
     10 //
     11 // The above copyright notice and this permission notice shall be included in
     12 // all copies or substantial portions of the Software.
     13 //
     14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
     17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
     20 // THE SOFTWARE.
     21 
     22 #include "serialize.h"
     23 #include "layout.h"
     24 #include <kj/debug.h>
     25 #include <exception>
     26 #ifdef _WIN32
     27 #include <io.h>
     28 #include <fcntl.h>
     29 #endif
     30 
     31 namespace capnp {
     32 
     33 FlatArrayMessageReader::FlatArrayMessageReader(
     34     kj::ArrayPtr<const word> array, ReaderOptions options)
     35     : MessageReader(options), end(array.end()) {
     36   if (array.size() < 1) {
     37     // Assume empty message.
     38     return;
     39   }
     40 
     41   const _::WireValue<uint32_t>* table =
     42       reinterpret_cast<const _::WireValue<uint32_t>*>(array.begin());
     43 
     44   uint segmentCount = table[0].get() + 1;
     45   size_t offset = segmentCount / 2u + 1u;
     46 
     47   KJ_REQUIRE(array.size() >= offset, "Message ends prematurely in segment table.") {
     48     return;
     49   }
     50 
     51   {
     52     uint segmentSize = table[1].get();
     53 
     54     KJ_REQUIRE(array.size() >= offset + segmentSize,
     55                "Message ends prematurely in first segment.") {
     56       return;
     57     }
     58 
     59     segment0 = array.slice(offset, offset + segmentSize);
     60     offset += segmentSize;
     61   }
     62 
     63   if (segmentCount > 1) {
     64     moreSegments = kj::heapArray<kj::ArrayPtr<const word>>(segmentCount - 1);
     65 
     66     for (uint i = 1; i < segmentCount; i++) {
     67       uint segmentSize = table[i + 1].get();
     68 
     69       KJ_REQUIRE(array.size() >= offset + segmentSize, "Message ends prematurely.") {
     70         moreSegments = nullptr;
     71         return;
     72       }
     73 
     74       moreSegments[i - 1] = array.slice(offset, offset + segmentSize);
     75       offset += segmentSize;
     76     }
     77   }
     78 
     79   end = array.begin() + offset;
     80 }
     81 
     82 size_t expectedSizeInWordsFromPrefix(kj::ArrayPtr<const word> array) {
     83   if (array.size() < 1) {
     84     // All messages are at least one word.
     85     return 1;
     86   }
     87 
     88   const _::WireValue<uint32_t>* table =
     89       reinterpret_cast<const _::WireValue<uint32_t>*>(array.begin());
     90 
     91   uint segmentCount = table[0].get() + 1;
     92   size_t offset = segmentCount / 2u + 1u;
     93 
     94   // If the array is too small to contain the full segment table, truncate segmentCount to just
     95   // what is available.
     96   segmentCount = kj::min(segmentCount, array.size() * 2 - 1u);
     97 
     98   size_t totalSize = offset;
     99   for (uint i = 0; i < segmentCount; i++) {
    100     totalSize += table[i + 1].get();
    101   }
    102   return totalSize;
    103 }
    104 
    105 kj::ArrayPtr<const word> FlatArrayMessageReader::getSegment(uint id) {
    106   if (id == 0) {
    107     return segment0;
    108   } else if (id <= moreSegments.size()) {
    109     return moreSegments[id - 1];
    110   } else {
    111     return nullptr;
    112   }
    113 }
    114 
    115 kj::ArrayPtr<const word> initMessageBuilderFromFlatArrayCopy(
    116     kj::ArrayPtr<const word> array, MessageBuilder& target, ReaderOptions options) {
    117   FlatArrayMessageReader reader(array, options);
    118   target.setRoot(reader.getRoot<AnyPointer>());
    119   return kj::arrayPtr(reader.getEnd(), array.end());
    120 }
    121 
    122 kj::Array<word> messageToFlatArray(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
    123   kj::Array<word> result = kj::heapArray<word>(computeSerializedSizeInWords(segments));
    124 
    125   _::WireValue<uint32_t>* table =
    126       reinterpret_cast<_::WireValue<uint32_t>*>(result.begin());
    127 
    128   // We write the segment count - 1 because this makes the first word zero for single-segment
    129   // messages, improving compression.  We don't bother doing this with segment sizes because
    130   // one-word segments are rare anyway.
    131   table[0].set(segments.size() - 1);
    132 
    133   for (uint i = 0; i < segments.size(); i++) {
    134     table[i + 1].set(segments[i].size());
    135   }
    136 
    137   if (segments.size() % 2 == 0) {
    138     // Set padding byte.
    139     table[segments.size() + 1].set(0);
    140   }
    141 
    142   word* dst = result.begin() + segments.size() / 2 + 1;
    143 
    144   for (auto& segment: segments) {
    145     memcpy(dst, segment.begin(), segment.size() * sizeof(word));
    146     dst += segment.size();
    147   }
    148 
    149   KJ_DASSERT(dst == result.end(), "Buffer overrun/underrun bug in code above.");
    150 
    151   return kj::mv(result);
    152 }
    153 
    154 size_t computeSerializedSizeInWords(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
    155   KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message.");
    156 
    157   size_t totalSize = segments.size() / 2 + 1;
    158 
    159   for (auto& segment: segments) {
    160     totalSize += segment.size();
    161   }
    162 
    163   return totalSize;
    164 }
    165 
    166 // =======================================================================================
    167 
    168 InputStreamMessageReader::InputStreamMessageReader(
    169     kj::InputStream& inputStream, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
    170     : MessageReader(options), inputStream(inputStream), readPos(nullptr) {
    171   _::WireValue<uint32_t> firstWord[2];
    172 
    173   inputStream.read(firstWord, sizeof(firstWord));
    174 
    175   uint segmentCount = firstWord[0].get() + 1;
    176   uint segment0Size = segmentCount == 0 ? 0 : firstWord[1].get();
    177 
    178   size_t totalWords = segment0Size;
    179 
    180   // Reject messages with too many segments for security reasons.
    181   KJ_REQUIRE(segmentCount < 512, "Message has too many segments.") {
    182     segmentCount = 1;
    183     segment0Size = 1;
    184     break;
    185   }
    186 
    187   // Read sizes for all segments except the first.  Include padding if necessary.
    188   KJ_STACK_ARRAY(_::WireValue<uint32_t>, moreSizes, segmentCount & ~1, 16, 64);
    189   if (segmentCount > 1) {
    190     inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0]));
    191     for (uint i = 0; i < segmentCount - 1; i++) {
    192       totalWords += moreSizes[i].get();
    193     }
    194   }
    195 
    196   // Don't accept a message which the receiver couldn't possibly traverse without hitting the
    197   // traversal limit.  Without this check, a malicious client could transmit a very large segment
    198   // size to make the receiver allocate excessive space and possibly crash.
    199   KJ_REQUIRE(totalWords <= options.traversalLimitInWords,
    200              "Message is too large.  To increase the limit on the receiving end, see "
    201              "capnp::ReaderOptions.") {
    202     segmentCount = 1;
    203     segment0Size = kj::min(segment0Size, options.traversalLimitInWords);
    204     totalWords = segment0Size;
    205     break;
    206   }
    207 
    208   if (scratchSpace.size() < totalWords) {
    209     // TODO(perf):  Consider allocating each segment as a separate chunk to reduce memory
    210     //   fragmentation.
    211     ownedSpace = kj::heapArray<word>(totalWords);
    212     scratchSpace = ownedSpace;
    213   }
    214 
    215   segment0 = scratchSpace.slice(0, segment0Size);
    216 
    217   if (segmentCount > 1) {
    218     moreSegments = kj::heapArray<kj::ArrayPtr<const word>>(segmentCount - 1);
    219     size_t offset = segment0Size;
    220 
    221     for (uint i = 0; i < segmentCount - 1; i++) {
    222       uint segmentSize = moreSizes[i].get();
    223       moreSegments[i] = scratchSpace.slice(offset, offset + segmentSize);
    224       offset += segmentSize;
    225     }
    226   }
    227 
    228   if (segmentCount == 1) {
    229     inputStream.read(scratchSpace.begin(), totalWords * sizeof(word));
    230   } else if (segmentCount > 1) {
    231     readPos = scratchSpace.asBytes().begin();
    232     readPos += inputStream.read(readPos, segment0Size * sizeof(word), totalWords * sizeof(word));
    233   }
    234 }
    235 
    236 InputStreamMessageReader::~InputStreamMessageReader() noexcept(false) {
    237   if (readPos != nullptr) {
    238     unwindDetector.catchExceptionsIfUnwinding([&]() {
    239       // Note that lazy reads only happen when we have multiple segments, so moreSegments.back() is
    240       // valid.
    241       const byte* allEnd = reinterpret_cast<const byte*>(moreSegments.back().end());
    242       inputStream.skip(allEnd - readPos);
    243     });
    244   }
    245 }
    246 
    247 kj::ArrayPtr<const word> InputStreamMessageReader::getSegment(uint id) {
    248   if (id > moreSegments.size()) {
    249     return nullptr;
    250   }
    251 
    252   kj::ArrayPtr<const word> segment = id == 0 ? segment0 : moreSegments[id - 1];
    253 
    254   if (readPos != nullptr) {
    255     // May need to lazily read more data.
    256     const byte* segmentEnd = reinterpret_cast<const byte*>(segment.end());
    257     if (readPos < segmentEnd) {
    258       // Note that lazy reads only happen when we have multiple segments, so moreSegments.back() is
    259       // valid.
    260       const byte* allEnd = reinterpret_cast<const byte*>(moreSegments.back().end());
    261       readPos += inputStream.read(readPos, segmentEnd - readPos, allEnd - readPos);
    262     }
    263   }
    264 
    265   return segment;
    266 }
    267 
    268 void readMessageCopy(kj::InputStream& input, MessageBuilder& target,
    269                      ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
    270   InputStreamMessageReader message(input, options, scratchSpace);
    271   target.setRoot(message.getRoot<AnyPointer>());
    272 }
    273 
    274 // -------------------------------------------------------------------
    275 
    276 void writeMessage(kj::OutputStream& output, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
    277   KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message.");
    278 
    279   KJ_STACK_ARRAY(_::WireValue<uint32_t>, table, (segments.size() + 2) & ~size_t(1), 16, 64);
    280 
    281   // We write the segment count - 1 because this makes the first word zero for single-segment
    282   // messages, improving compression.  We don't bother doing this with segment sizes because
    283   // one-word segments are rare anyway.
    284   table[0].set(segments.size() - 1);
    285   for (uint i = 0; i < segments.size(); i++) {
    286     table[i + 1].set(segments[i].size());
    287   }
    288   if (segments.size() % 2 == 0) {
    289     // Set padding byte.
    290     table[segments.size() + 1].set(0);
    291   }
    292 
    293   KJ_STACK_ARRAY(kj::ArrayPtr<const byte>, pieces, segments.size() + 1, 4, 32);
    294   pieces[0] = table.asBytes();
    295 
    296   for (uint i = 0; i < segments.size(); i++) {
    297     pieces[i + 1] = segments[i].asBytes();
    298   }
    299 
    300   output.write(pieces);
    301 }
    302 
    303 // =======================================================================================
    304 
    305 StreamFdMessageReader::~StreamFdMessageReader() noexcept(false) {}
    306 
    307 void writeMessageToFd(int fd, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
    308 #ifdef _WIN32
    309     auto oldMode = _setmode(fd, _O_BINARY);
    310     if (oldMode != _O_BINARY) {
    311       _setmode(fd, oldMode);
    312       KJ_FAIL_REQUIRE("Tried to write a message to a file descriptor that is in text mode. Set the "
    313           "file descriptor to binary mode by calling the _setmode Windows CRT function, or passing "
    314           "_O_BINARY to _open().");
    315     }
    316 #endif
    317   kj::FdOutputStream stream(fd);
    318   writeMessage(stream, segments);
    319 }
    320 
    321 void readMessageCopyFromFd(int fd, MessageBuilder& target,
    322                            ReaderOptions options, kj::ArrayPtr<word> scratchSpace) {
    323   kj::FdInputStream stream(fd);
    324   readMessageCopy(stream, target, options, scratchSpace);
    325 }
    326 
    327 }  // namespace capnp