capnproto

FORK: Cap'n Proto serialization/RPC system - core tools and C++ library
git clone https://git.neptards.moe/neptards/capnproto.git
Log | Files | Refs | README | LICENSE

serialize-packed.c++ (16005B)


      1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
      2 // Licensed under the MIT License:
      3 //
      4 // Permission is hereby granted, free of charge, to any person obtaining a copy
      5 // of this software and associated documentation files (the "Software"), to deal
      6 // in the Software without restriction, including without limitation the rights
      7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
      8 // copies of the Software, and to permit persons to whom the Software is
      9 // furnished to do so, subject to the following conditions:
     10 //
     11 // The above copyright notice and this permission notice shall be included in
     12 // all copies or substantial portions of the Software.
     13 //
     14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
     17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
     20 // THE SOFTWARE.
     21 
     22 #include "serialize-packed.h"
     23 #include <kj/debug.h>
     24 #include "layout.h"
     25 #include <vector>
     26 
     27 namespace capnp {
     28 
     29 namespace _ {  // private
     30 
        // Wraps `inner`: the constructor only stores the reference and reads nothing from it.
     31 PackedInputStream::PackedInputStream(kj::BufferedInputStream& inner): inner(inner) {}
        // Empty destructor body; explicitly declared noexcept(false).
     32 PackedInputStream::~PackedInputStream() noexcept(false) {}
     33 
     34 size_t PackedInputStream::tryRead(void* dst, size_t minBytes, size_t maxBytes) {
          // Decodes packed bytes from `inner` into `dst` and returns how many bytes were written to
          // `dst`.
          //
          // The input is a sequence of tag bytes.  Each tag describes eight output bytes: for bit i
          // set, output byte i is copied from the input; for bit i clear, output byte i is zero.  A
          // tag of 0 is followed by a count byte N and produces N * sizeof(word) further zero bytes.
          // A tag of 0xff is followed by a count byte N and then N * sizeof(word) bytes that are
          // copied to the output unchanged.
          //
          // Returns 0 right away when `maxBytes` is 0 or when tryGetReadBuffer() yields an empty
          // buffer.  Otherwise decoding continues until `maxBytes` bytes were produced, or until at
          // least `minBytes` were produced and fewer than 10 input bytes remain buffered.
     35   if (maxBytes == 0) {
     36     return 0;
     37   }
     38
     39   KJ_DREQUIRE(minBytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned.");
     40   KJ_DREQUIRE(maxBytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned.");
     41
          // `out` is the write cursor into `dst`; `outEnd` marks `maxBytes` and `outMin` marks
          // `minBytes`.
     42   uint8_t* __restrict__ out = reinterpret_cast<uint8_t*>(dst);
     43   uint8_t* const outEnd = reinterpret_cast<uint8_t*>(dst) + maxBytes;
     44   uint8_t* const outMin = reinterpret_cast<uint8_t*>(dst) + minBytes;
     45
     46   kj::ArrayPtr<const byte> buffer = inner.tryGetReadBuffer();
     47   if (buffer.size() == 0) {
            // No input is available at all, so nothing was read.
     48     return 0;
     49   }
          // `in` is the read cursor inside `buffer`.  Consumed bytes are only reported to `inner`
          // (via inner.skip()) when the buffer is refreshed or right before returning.
     50   const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin());
     51
          // REFRESH_BUFFER(): marks the whole current buffer as consumed, fetches the next one and
          // points `in` at its start.  If the new buffer is empty the KJ_REQUIRE fails and the
          // attached block returns the number of bytes produced so far.
     52 #define REFRESH_BUFFER() \
     53   inner.skip(buffer.size()); \
     54   buffer = inner.getReadBuffer(); \
     55   KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { \
     56     return out - reinterpret_cast<uint8_t*>(dst); \
     57   } \
     58   in = reinterpret_cast<const uint8_t*>(buffer.begin())
     59
          // NOTE: BUFFER_END and BUFFER_REMAINING are not #undef'd at the end of this function;
          // skip() below uses BUFFER_REMAINING with its own `buffer` and `in` locals.
     60 #define BUFFER_END (reinterpret_cast<const uint8_t*>(buffer.end()))
     61 #define BUFFER_REMAINING ((size_t)(BUFFER_END - in))
     62
     63   for (;;) {
     64     uint8_t tag;
     65
     66     KJ_DASSERT((out - reinterpret_cast<uint8_t*>(dst)) % sizeof(word) == 0,
     67            "Output pointer should always be aligned here.");
     68
            // Ten bytes cover a tag, up to eight data bytes and one run-count byte, which is the
            // most the unchecked fast path below can touch.
     69     if (BUFFER_REMAINING < 10) {
     70       if (out >= outMin) {
     71         // We read at least the minimum amount, so go ahead and return.
     72         inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
     73         return out - reinterpret_cast<uint8_t*>(dst);
     74       }
     75
     76       if (BUFFER_REMAINING == 0) {
     77         REFRESH_BUFFER();
     78         continue;
     79       }
     80
     81       // We have at least 1, but not 10, bytes available.  We need to read slowly, doing a bounds
     82       // check on each byte.
     83
     84       tag = *in++;
     85
     86       for (uint i = 0; i < 8; i++) {
     87         if (tag & (1u << i)) {
     88           if (BUFFER_REMAINING == 0) {
     89             REFRESH_BUFFER();
     90           }
     91           *out++ = *in++;
     92         } else {
     93           *out++ = 0;
     94         }
     95       }
     96
              // Tags 0 and 0xff are followed by a count byte that is read unconditionally below, so
              // make sure at least one more input byte is buffered.
     97       if (BUFFER_REMAINING == 0 && (tag == 0 || tag == 0xffu)) {
     98         REFRESH_BUFFER();
     99       }
    100     } else {
    101       tag = *in++;
    102
              // HANDLE_BYTE(n) is branch-free: it always stores one output byte, masked to zero when
              // tag bit n is clear, and advances `in` only when the bit is set.  It reads `*in` even
              // when the bit is clear.
    103 #define HANDLE_BYTE(n) \
    104       { \
    105          bool isNonzero = (tag & (1u << n)) != 0; \
    106          *out++ = *in & (-(int8_t)isNonzero); \
    107          in += isNonzero; \
    108       }
    109
    110       HANDLE_BYTE(0);
    111       HANDLE_BYTE(1);
    112       HANDLE_BYTE(2);
    113       HANDLE_BYTE(3);
    114       HANDLE_BYTE(4);
    115       HANDLE_BYTE(5);
    116       HANDLE_BYTE(6);
    117       HANDLE_BYTE(7);
    118 #undef HANDLE_BYTE
    119     }
    120
    121     if (tag == 0) {
              // Zero run: the next input byte counts additional all-zero words to emit.
    122       KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");
    123
    124       uint runLength = *in++ * sizeof(word);
    125
              // The run must fit in the space left in `dst`; otherwise return what was produced.
    126       KJ_REQUIRE(runLength <= outEnd - out,
    127                  "Packed input did not end cleanly on a segment boundary.") {
    128         return out - reinterpret_cast<uint8_t*>(dst);
    129       }
    130       memset(out, 0, runLength);
    131       out += runLength;
    132
    133     } else if (tag == 0xffu) {
              // Literal run: the next input byte counts additional words copied through unchanged.
    134       KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");
    135
    136       uint runLength = *in++ * sizeof(word);
    137
    138       KJ_REQUIRE(runLength <= outEnd - out,
    139                  "Packed input did not end cleanly on a segment boundary.") {
    140         return out - reinterpret_cast<uint8_t*>(dst);
    141       }
    142
    143       size_t inRemaining = BUFFER_REMAINING;
    144       if (inRemaining >= runLength) {
    145         // Fast path.
    146         memcpy(out, in, runLength);
    147         out += runLength;
    148         in += runLength;
    149       } else {
    150         // Copy over the first buffer, then do one big read for the rest.
    151         memcpy(out, in, inRemaining);
    152         out += inRemaining;
    153         runLength -= inRemaining;
    154
    155         inner.skip(buffer.size());
    156         inner.read(out, runLength);
    157         out += runLength;
    158
    159         if (out == outEnd) {
                  // Everything consumed so far was already reported to `inner` by the skip() and
                  // read() calls above, so no further skip is needed before returning.
    160           return maxBytes;
    161         } else {
    162           buffer = inner.getReadBuffer();
    163           in = reinterpret_cast<const uint8_t*>(buffer.begin());
    164
    165           // Skip the bounds check below since we just did the same check above.
    166           continue;
    167         }
    168       }
    169     }
    170
    171     if (out == outEnd) {
              // Output is full: report the consumed part of the current buffer and finish.
    172       inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
    173       return maxBytes;
    174     }
    175   }
    176
    177   KJ_FAIL_ASSERT("Can't get here.");
    178   return 0;  // GCC knows KJ_FAIL_ASSERT doesn't return, but Eclipse CDT still warns...
    179
    180 #undef REFRESH_BUFFER
    181 }
    182 
    183 void PackedInputStream::skip(size_t bytes) {
          // Advances the stream past `bytes` bytes of *unpacked* data.  The packed input is decoded
          // tag by tag exactly as in tryRead(), but no output is stored; only the count of remaining
          // bytes is tracked.  Returns early (after a failed KJ_REQUIRE) if the input ends
          // prematurely or a run extends past the requested amount.
          //
    184   // We can't just read into buffers because buffers must end on block boundaries.
    185
    186   if (bytes == 0) {
    187     return;
    188   }
    189
    190   KJ_DREQUIRE(bytes % sizeof(word) == 0, "PackedInputStream reads must be word-aligned.");
    191
    192   kj::ArrayPtr<const byte> buffer = inner.getReadBuffer();
    193   const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(buffer.begin());
    194
          // REFRESH_BUFFER(): marks the whole current buffer as consumed, fetches the next one and
          // points `in` at its start; returns from skip() if the new buffer is empty.
          // BUFFER_REMAINING (used below) is the macro defined in tryRead() above.
    195 #define REFRESH_BUFFER() \
    196   inner.skip(buffer.size()); \
    197   buffer = inner.getReadBuffer(); \
    198   KJ_REQUIRE(buffer.size() > 0, "Premature end of packed input.") { return; } \
    199   in = reinterpret_cast<const uint8_t*>(buffer.begin())
    200
    201   for (;;) {
    202     uint8_t tag;
    203
    204     if (BUFFER_REMAINING < 10) {
    205       if (BUFFER_REMAINING == 0) {
    206         REFRESH_BUFFER();
    207         continue;
    208       }
    209
    210       // We have at least 1, but not 10, bytes available.  We need to read slowly, doing a bounds
    211       // check on each byte.
    212
    213       tag = *in++;
    214
    215       for (uint i = 0; i < 8; i++) {
    216         if (tag & (1u << i)) {
    217           if (BUFFER_REMAINING == 0) {
    218             REFRESH_BUFFER();
    219           }
    220           in++;
    221         }
    222       }
              // One tag always stands for eight unpacked bytes.
    223       bytes -= 8;
    224
              // Tags 0 and 0xff are followed by a count byte that is read unconditionally below, so
              // make sure at least one more input byte is buffered.
    225       if (BUFFER_REMAINING == 0 && (tag == 0 || tag == 0xffu)) {
    226         REFRESH_BUFFER();
    227       }
    228     } else {
    229       tag = *in++;
    230
              // Branch-free: step over one input byte for every set tag bit.
    231 #define HANDLE_BYTE(n) \
    232       in += (tag & (1u << n)) != 0
    233
    234       HANDLE_BYTE(0);
    235       HANDLE_BYTE(1);
    236       HANDLE_BYTE(2);
    237       HANDLE_BYTE(3);
    238       HANDLE_BYTE(4);
    239       HANDLE_BYTE(5);
    240       HANDLE_BYTE(6);
    241       HANDLE_BYTE(7);
    242 #undef HANDLE_BYTE
    243
    244       bytes -= 8;
    245     }
    246
    247     if (tag == 0) {
              // Zero run: the count byte stands for additional all-zero words; it consumes no further
              // input.
    248       KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");
    249
    250       uint runLength = *in++ * sizeof(word);
    251
    252       KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary.") {
    253         return;
    254       }
    255
    256       bytes -= runLength;
    257
    258     } else if (tag == 0xffu) {
              // Literal run: the count byte is followed by that many unpacked words in the input.
    259       KJ_DASSERT(BUFFER_REMAINING > 0, "Should always have non-empty buffer here.");
    260
    261       uint runLength = *in++ * sizeof(word);
    262
    263       KJ_REQUIRE(runLength <= bytes, "Packed input did not end cleanly on a segment boundary.") {
    264         return;
    265       }
    266
    267       bytes -= runLength;
    268
    269       size_t inRemaining = BUFFER_REMAINING;
    270       if (inRemaining > runLength) {
    271         // Fast path.
    272         in += runLength;
    273       } else {
    274         // Forward skip to the underlying stream.
    275         runLength -= inRemaining;
    276         inner.skip(buffer.size() + runLength);
    277
    278         if (bytes == 0) {
    279           return;
    280         } else {
    281           buffer = inner.getReadBuffer();
    282           in = reinterpret_cast<const uint8_t*>(buffer.begin());
    283
    284           // Skip the bounds check below since we just did the same check above.
    285           continue;
    286         }
    287       }
    288     }
    289
    290     if (bytes == 0) {
              // Done: report the consumed part of the current buffer to the underlying stream.
    291       inner.skip(in - reinterpret_cast<const uint8_t*>(buffer.begin()));
    292       return;
    293     }
    294   }
    295
    296   KJ_FAIL_ASSERT("Can't get here.");
        
          // This is the last user of the buffer helper macros in this file; do not let them leak into
          // the rest of the translation unit.
        #undef REFRESH_BUFFER
        #undef BUFFER_REMAINING
        #undef BUFFER_END
    297 }
    298 
    299 // -------------------------------------------------------------------
    300 
        // Wraps `inner`: the constructor only stores the reference and writes nothing to it.
    301 PackedOutputStream::PackedOutputStream(kj::BufferedOutputStream& inner)
    302     : inner(inner) {}
        // Empty destructor body; explicitly declared noexcept(false).
    303 PackedOutputStream::~PackedOutputStream() noexcept(false) {}
    304 
    305 void PackedOutputStream::write(const void* src, size_t size) {
          // Packs `size` bytes starting at `src` and writes the packed form to `inner`.
          //
          // Input is consumed eight bytes at a time.  For each group a tag byte is emitted whose bit i
          // is set when input byte i is non-zero, followed by only the non-zero bytes.  An all-zero
          // group (tag 0) is followed by a count of further all-zero words, which are dropped from the
          // output.  An all-non-zero group (tag 0xff) is followed by a count of further words, which
          // are then copied through unpacked.
          //
          // `size` must be a multiple of sizeof(word): the loop below always reads eight input bytes
          // per iteration and only compares against `inEnd` between groups.
          KJ_DREQUIRE(size % sizeof(word) == 0, "PackedOutputStream writes must be word-aligned.");
        
    306   kj::ArrayPtr<byte> buffer = inner.getWriteBuffer();
    307   byte slowBuffer[20];
    308
          // `out` is the write cursor inside `buffer`; bytes are handed to `inner` via inner.write().
    309   uint8_t* __restrict__ out = reinterpret_cast<uint8_t*>(buffer.begin());
    310
    311   const uint8_t* __restrict__ in = reinterpret_cast<const uint8_t*>(src);
    312   const uint8_t* const inEnd = reinterpret_cast<const uint8_t*>(src) + size;
    313
    314   while (in < inEnd) {
    315     if (reinterpret_cast<uint8_t*>(buffer.end()) - out < 10) {
    316       // Oops, we're out of space.  We need at least 10 bytes for the fast path, since we don't
    317       // bounds-check on every byte.
    318
    319       // Write what we have so far.
    320       inner.write(buffer.begin(), out - reinterpret_cast<uint8_t*>(buffer.begin()));
    321
    322       // Use a slow buffer into which we'll encode 10 to 20 bytes.  This should get us past the
    323       // output stream's buffer boundary.
              // NOTE(review): once `buffer` points at slowBuffer it keeps doing so on later
              // iterations; only the literal-run overflow path below calls inner.getWriteBuffer()
              // again.
    324       buffer = kj::arrayPtr(slowBuffer, sizeof(slowBuffer));
    325       out = reinterpret_cast<uint8_t*>(buffer.begin());
    326     }
    327
            // Reserve the tag position; it is filled in once all eight bytes were inspected.
    328     uint8_t* tagPos = out++;
    329
    330 #define HANDLE_BYTE(n) \
    331     uint8_t bit##n = *in != 0; \
    332     *out = *in; \
    333     out += bit##n; /* out only advances if the byte was non-zero */ \
    334     ++in
    335
    336     HANDLE_BYTE(0);
    337     HANDLE_BYTE(1);
    338     HANDLE_BYTE(2);
    339     HANDLE_BYTE(3);
    340     HANDLE_BYTE(4);
    341     HANDLE_BYTE(5);
    342     HANDLE_BYTE(6);
    343     HANDLE_BYTE(7);
    344 #undef HANDLE_BYTE
    345
    346     uint8_t tag = (bit0 << 0) | (bit1 << 1) | (bit2 << 2) | (bit3 << 3)
    347                 | (bit4 << 4) | (bit5 << 5) | (bit6 << 6) | (bit7 << 7);
    348     *tagPos = tag;
    349
    350     if (tag == 0) {
    351       // An all-zero word is followed by a count of consecutive zero words (not including the
    352       // first one).
    353
    354       // We can check a whole word at a time. (Here is where we use the assumption that
    355       // `src` is word-aligned.)
    356       const uint64_t* inWord = reinterpret_cast<const uint64_t*>(in);
    357
    358       // The count must fit in 1 byte, so limit to 255 words.
    359       const uint64_t* limit = reinterpret_cast<const uint64_t*>(inEnd);
    360       if (limit - inWord > 255) {
    361         limit = inWord + 255;
    362       }
    363
    364       while (inWord < limit && *inWord == 0) {
    365         ++inWord;
    366       }
    367
    368       // Write the count.
    369       *out++ = inWord - reinterpret_cast<const uint64_t*>(in);
    370
    371       // Advance input.
    372       in = reinterpret_cast<const uint8_t*>(inWord);
    373
    374     } else if (tag == 0xffu) {
    375       // An all-nonzero word is followed by a count of consecutive uncompressed words, followed
    376       // by the uncompressed words themselves.
    377
    378       // Count the number of consecutive words in the input which have no more than a single
    379       // zero-byte.  We look for at least two zeros because that's the point where our compression
    380       // scheme becomes a net win.
    381       // TODO(perf):  Maybe look for three zeros?  Compressing a two-zero word is a loss if the
    382       //   following word has no zeros.
    383       const uint8_t* runStart = in;
    384
              // The count must fit in 1 byte, so the run is limited to 255 words here as well.
    385       const uint8_t* limit = inEnd;
    386       if ((size_t)(limit - in) > 255 * sizeof(word)) {
    387         limit = in + 255 * sizeof(word);
    388       }
    389
    390       while (in < limit) {
    391         // Check eight input bytes for zeros.
    392         uint c = *in++ == 0;
    393         c += *in++ == 0;
    394         c += *in++ == 0;
    395         c += *in++ == 0;
    396         c += *in++ == 0;
    397         c += *in++ == 0;
    398         c += *in++ == 0;
    399         c += *in++ == 0;
    400
    401         if (c >= 2) {
    402           // Un-read the word with multiple zeros, since we'll want to compress that one.
    403           in -= 8;
    404           break;
    405         }
    406       }
    407
    408       // Write the count.
    409       uint count = in - runStart;
    410       *out++ = count / sizeof(word);
    411
    412       if (count <= reinterpret_cast<uint8_t*>(buffer.end()) - out) {
    413         // There's enough space to memcpy.
    414         memcpy(out, runStart, count);
    415         out += count;
    416       } else {
    417         // Input overruns the output buffer.  We'll give it to the output stream in one chunk
    418         // and let it decide what to do.
    419         inner.write(buffer.begin(), reinterpret_cast<byte*>(out) - buffer.begin());
    420         inner.write(runStart, in - runStart);
    421         buffer = inner.getWriteBuffer();
    422         out = reinterpret_cast<uint8_t*>(buffer.begin());
    423       }
    424     }
    425   }
    426
    427   // Write whatever is left.
    428   inner.write(buffer.begin(), reinterpret_cast<byte*>(out) - buffer.begin());
    429 }
    430 
    431 }  // namespace _ (private)
    432 
    433 // =======================================================================================
    434 
        // Builds a message reader over packed input: the PackedInputStream base is constructed over
        // `inputStream`, and the InputStreamMessageReader base is given this object's
        // PackedInputStream base together with `options` and `scratchSpace`.
    435 PackedMessageReader::PackedMessageReader(
    436     kj::BufferedInputStream& inputStream, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
    437     : PackedInputStream(inputStream),
    438       InputStreamMessageReader(static_cast<PackedInputStream&>(*this), options, scratchSpace) {}
    439
        // Empty destructor body; explicitly declared noexcept(false).
    440 PackedMessageReader::~PackedMessageReader() noexcept(false) {}
    441 
        // Builds a packed message reader on top of a raw file descriptor.  The bases are chained:
        // FdInputStream is constructed from `fd`, BufferedInputStreamWrapper is constructed over that
        // FdInputStream base, and PackedMessageReader is constructed over the wrapper base.
    442 PackedFdMessageReader::PackedFdMessageReader(
    443     int fd, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
    444     : FdInputStream(fd),
    445       BufferedInputStreamWrapper(static_cast<FdInputStream&>(*this)),
    446       PackedMessageReader(static_cast<BufferedInputStreamWrapper&>(*this),
    447                           options, scratchSpace) {}
    448
        // Same chaining as the overload above, except that `fd` is a kj::AutoCloseFd which is moved
        // into the FdInputStream base (presumably handing over ownership of the descriptor -- confirm
        // against FdInputStream).
    449 PackedFdMessageReader::PackedFdMessageReader(
    450     kj::AutoCloseFd fd, ReaderOptions options, kj::ArrayPtr<word> scratchSpace)
    451     : FdInputStream(kj::mv(fd)),
    452       BufferedInputStreamWrapper(static_cast<FdInputStream&>(*this)),
    453       PackedMessageReader(static_cast<BufferedInputStreamWrapper&>(*this),
    454                           options, scratchSpace) {}
    455
        // Empty destructor body; explicitly declared noexcept(false).
    456 PackedFdMessageReader::~PackedFdMessageReader() noexcept(false) {}
    457 
    458 void writePackedMessage(kj::BufferedOutputStream& output,
    459                         kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
          // Run the regular message writer through a packing stream layered on top of `output`.
    460   _::PackedOutputStream packer(output);
    461   writeMessage(packer, segments);
    462 }
    463 
    464 void writePackedMessage(kj::OutputStream& output,
    465                         kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
          // The packer needs a buffered stream.  Use `output` directly when the downcast to
          // kj::BufferedOutputStream yields a value; otherwise wrap it with a local 8192-byte buffer.
    466   KJ_IF_MAYBE(alreadyBuffered, kj::dynamicDowncastIfAvailable<kj::BufferedOutputStream>(output)) {
    467     writePackedMessage(*alreadyBuffered, segments);
    468   } else {
    469     byte scratch[8192];
    470     kj::BufferedOutputStreamWrapper wrapped(output, kj::arrayPtr(scratch, sizeof(scratch)));
    471     writePackedMessage(wrapped, segments);
    472   }
    473 }
    474 
    475 void writePackedMessageToFd(int fd, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) {
          // Wrap the descriptor in an output stream and defer to the kj::OutputStream overload.
    476   kj::FdOutputStream fdStream(fd);
    477   writePackedMessage(fdStream, segments);
    478 }
    479 
    480 size_t computeUnpackedSizeInWords(kj::ArrayPtr<const byte> packedBytes) {
          // Walks over packed data without decoding it and returns the number of words it unpacks to.
          // A KJ_REQUIRE with the message "invalid packed data" fails when the input stops in the
          // middle of a tag group, before a run count, or inside a literal run.
    481   const byte* ptr = packedBytes.begin();
    482   const byte* end = packedBytes.end();
    483
    484   size_t total = 0;
    485   while (ptr < end) {
    486     uint tag = *ptr;
            // Each set tag bit stands for one data byte that follows the tag in the input.
    487     size_t count = kj::popCount(tag);
    488     total += 1;
            // The tag byte itself plus its `count` data bytes must all be present, because `ptr` is
            // advanced by count + 1 below.  Checking only `count` would accept input that is cut off
            // one byte short and leave `ptr` pointing past `end`.
    489     KJ_REQUIRE(end - ptr >= count + 1, "invalid packed data");
    490     ptr += count + 1;
    491
    492     if (tag == 0) {
              // Zero run: the count byte adds that many all-zero words and no further input bytes.
    493       KJ_REQUIRE(ptr < end, "invalid packed data");
    494       total += *ptr++;
    495     } else if (tag == 0xff) {
              // Literal run: the count byte is followed by that many unpacked words.
    496       KJ_REQUIRE(ptr < end, "invalid packed data");
    497       size_t words = *ptr++;
    498       total += words;
    499       size_t bytes = words * sizeof(word);
    500       KJ_REQUIRE(end - ptr >= bytes, "invalid packed data");
    501       ptr += bytes;
    502     }
    503   }
    504
    505   return total;
    506 }
    507 
    508 }  // namespace capnp