serialize.c++ (11196B)
1 // Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors 2 // Licensed under the MIT License: 3 // 4 // Permission is hereby granted, free of charge, to any person obtaining a copy 5 // of this software and associated documentation files (the "Software"), to deal 6 // in the Software without restriction, including without limitation the rights 7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 // copies of the Software, and to permit persons to whom the Software is 9 // furnished to do so, subject to the following conditions: 10 // 11 // The above copyright notice and this permission notice shall be included in 12 // all copies or substantial portions of the Software. 13 // 14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 // THE SOFTWARE. 21 22 #include "serialize.h" 23 #include "layout.h" 24 #include <kj/debug.h> 25 #include <exception> 26 #ifdef _WIN32 27 #include <io.h> 28 #include <fcntl.h> 29 #endif 30 31 namespace capnp { 32 33 FlatArrayMessageReader::FlatArrayMessageReader( 34 kj::ArrayPtr<const word> array, ReaderOptions options) 35 : MessageReader(options), end(array.end()) { 36 if (array.size() < 1) { 37 // Assume empty message. 38 return; 39 } 40 41 const _::WireValue<uint32_t>* table = 42 reinterpret_cast<const _::WireValue<uint32_t>*>(array.begin()); 43 44 uint segmentCount = table[0].get() + 1; 45 size_t offset = segmentCount / 2u + 1u; 46 47 KJ_REQUIRE(array.size() >= offset, "Message ends prematurely in segment table.") { 48 return; 49 } 50 51 { 52 uint segmentSize = table[1].get(); 53 54 KJ_REQUIRE(array.size() >= offset + segmentSize, 55 "Message ends prematurely in first segment.") { 56 return; 57 } 58 59 segment0 = array.slice(offset, offset + segmentSize); 60 offset += segmentSize; 61 } 62 63 if (segmentCount > 1) { 64 moreSegments = kj::heapArray<kj::ArrayPtr<const word>>(segmentCount - 1); 65 66 for (uint i = 1; i < segmentCount; i++) { 67 uint segmentSize = table[i + 1].get(); 68 69 KJ_REQUIRE(array.size() >= offset + segmentSize, "Message ends prematurely.") { 70 moreSegments = nullptr; 71 return; 72 } 73 74 moreSegments[i - 1] = array.slice(offset, offset + segmentSize); 75 offset += segmentSize; 76 } 77 } 78 79 end = array.begin() + offset; 80 } 81 82 size_t expectedSizeInWordsFromPrefix(kj::ArrayPtr<const word> array) { 83 if (array.size() < 1) { 84 // All messages are at least one word. 85 return 1; 86 } 87 88 const _::WireValue<uint32_t>* table = 89 reinterpret_cast<const _::WireValue<uint32_t>*>(array.begin()); 90 91 uint segmentCount = table[0].get() + 1; 92 size_t offset = segmentCount / 2u + 1u; 93 94 // If the array is too small to contain the full segment table, truncate segmentCount to just 95 // what is available. 96 segmentCount = kj::min(segmentCount, array.size() * 2 - 1u); 97 98 size_t totalSize = offset; 99 for (uint i = 0; i < segmentCount; i++) { 100 totalSize += table[i + 1].get(); 101 } 102 return totalSize; 103 } 104 105 kj::ArrayPtr<const word> FlatArrayMessageReader::getSegment(uint id) { 106 if (id == 0) { 107 return segment0; 108 } else if (id <= moreSegments.size()) { 109 return moreSegments[id - 1]; 110 } else { 111 return nullptr; 112 } 113 } 114 115 kj::ArrayPtr<const word> initMessageBuilderFromFlatArrayCopy( 116 kj::ArrayPtr<const word> array, MessageBuilder& target, ReaderOptions options) { 117 FlatArrayMessageReader reader(array, options); 118 target.setRoot(reader.getRoot<AnyPointer>()); 119 return kj::arrayPtr(reader.getEnd(), array.end()); 120 } 121 122 kj::Array<word> messageToFlatArray(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) { 123 kj::Array<word> result = kj::heapArray<word>(computeSerializedSizeInWords(segments)); 124 125 _::WireValue<uint32_t>* table = 126 reinterpret_cast<_::WireValue<uint32_t>*>(result.begin()); 127 128 // We write the segment count - 1 because this makes the first word zero for single-segment 129 // messages, improving compression. We don't bother doing this with segment sizes because 130 // one-word segments are rare anyway. 131 table[0].set(segments.size() - 1); 132 133 for (uint i = 0; i < segments.size(); i++) { 134 table[i + 1].set(segments[i].size()); 135 } 136 137 if (segments.size() % 2 == 0) { 138 // Set padding byte. 139 table[segments.size() + 1].set(0); 140 } 141 142 word* dst = result.begin() + segments.size() / 2 + 1; 143 144 for (auto& segment: segments) { 145 memcpy(dst, segment.begin(), segment.size() * sizeof(word)); 146 dst += segment.size(); 147 } 148 149 KJ_DASSERT(dst == result.end(), "Buffer overrun/underrun bug in code above."); 150 151 return kj::mv(result); 152 } 153 154 size_t computeSerializedSizeInWords(kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) { 155 KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message."); 156 157 size_t totalSize = segments.size() / 2 + 1; 158 159 for (auto& segment: segments) { 160 totalSize += segment.size(); 161 } 162 163 return totalSize; 164 } 165 166 // ======================================================================================= 167 168 InputStreamMessageReader::InputStreamMessageReader( 169 kj::InputStream& inputStream, ReaderOptions options, kj::ArrayPtr<word> scratchSpace) 170 : MessageReader(options), inputStream(inputStream), readPos(nullptr) { 171 _::WireValue<uint32_t> firstWord[2]; 172 173 inputStream.read(firstWord, sizeof(firstWord)); 174 175 uint segmentCount = firstWord[0].get() + 1; 176 uint segment0Size = segmentCount == 0 ? 0 : firstWord[1].get(); 177 178 size_t totalWords = segment0Size; 179 180 // Reject messages with too many segments for security reasons. 181 KJ_REQUIRE(segmentCount < 512, "Message has too many segments.") { 182 segmentCount = 1; 183 segment0Size = 1; 184 break; 185 } 186 187 // Read sizes for all segments except the first. Include padding if necessary. 188 KJ_STACK_ARRAY(_::WireValue<uint32_t>, moreSizes, segmentCount & ~1, 16, 64); 189 if (segmentCount > 1) { 190 inputStream.read(moreSizes.begin(), moreSizes.size() * sizeof(moreSizes[0])); 191 for (uint i = 0; i < segmentCount - 1; i++) { 192 totalWords += moreSizes[i].get(); 193 } 194 } 195 196 // Don't accept a message which the receiver couldn't possibly traverse without hitting the 197 // traversal limit. Without this check, a malicious client could transmit a very large segment 198 // size to make the receiver allocate excessive space and possibly crash. 199 KJ_REQUIRE(totalWords <= options.traversalLimitInWords, 200 "Message is too large. To increase the limit on the receiving end, see " 201 "capnp::ReaderOptions.") { 202 segmentCount = 1; 203 segment0Size = kj::min(segment0Size, options.traversalLimitInWords); 204 totalWords = segment0Size; 205 break; 206 } 207 208 if (scratchSpace.size() < totalWords) { 209 // TODO(perf): Consider allocating each segment as a separate chunk to reduce memory 210 // fragmentation. 211 ownedSpace = kj::heapArray<word>(totalWords); 212 scratchSpace = ownedSpace; 213 } 214 215 segment0 = scratchSpace.slice(0, segment0Size); 216 217 if (segmentCount > 1) { 218 moreSegments = kj::heapArray<kj::ArrayPtr<const word>>(segmentCount - 1); 219 size_t offset = segment0Size; 220 221 for (uint i = 0; i < segmentCount - 1; i++) { 222 uint segmentSize = moreSizes[i].get(); 223 moreSegments[i] = scratchSpace.slice(offset, offset + segmentSize); 224 offset += segmentSize; 225 } 226 } 227 228 if (segmentCount == 1) { 229 inputStream.read(scratchSpace.begin(), totalWords * sizeof(word)); 230 } else if (segmentCount > 1) { 231 readPos = scratchSpace.asBytes().begin(); 232 readPos += inputStream.read(readPos, segment0Size * sizeof(word), totalWords * sizeof(word)); 233 } 234 } 235 236 InputStreamMessageReader::~InputStreamMessageReader() noexcept(false) { 237 if (readPos != nullptr) { 238 unwindDetector.catchExceptionsIfUnwinding([&]() { 239 // Note that lazy reads only happen when we have multiple segments, so moreSegments.back() is 240 // valid. 241 const byte* allEnd = reinterpret_cast<const byte*>(moreSegments.back().end()); 242 inputStream.skip(allEnd - readPos); 243 }); 244 } 245 } 246 247 kj::ArrayPtr<const word> InputStreamMessageReader::getSegment(uint id) { 248 if (id > moreSegments.size()) { 249 return nullptr; 250 } 251 252 kj::ArrayPtr<const word> segment = id == 0 ? segment0 : moreSegments[id - 1]; 253 254 if (readPos != nullptr) { 255 // May need to lazily read more data. 256 const byte* segmentEnd = reinterpret_cast<const byte*>(segment.end()); 257 if (readPos < segmentEnd) { 258 // Note that lazy reads only happen when we have multiple segments, so moreSegments.back() is 259 // valid. 260 const byte* allEnd = reinterpret_cast<const byte*>(moreSegments.back().end()); 261 readPos += inputStream.read(readPos, segmentEnd - readPos, allEnd - readPos); 262 } 263 } 264 265 return segment; 266 } 267 268 void readMessageCopy(kj::InputStream& input, MessageBuilder& target, 269 ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { 270 InputStreamMessageReader message(input, options, scratchSpace); 271 target.setRoot(message.getRoot<AnyPointer>()); 272 } 273 274 // ------------------------------------------------------------------- 275 276 void writeMessage(kj::OutputStream& output, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) { 277 KJ_REQUIRE(segments.size() > 0, "Tried to serialize uninitialized message."); 278 279 KJ_STACK_ARRAY(_::WireValue<uint32_t>, table, (segments.size() + 2) & ~size_t(1), 16, 64); 280 281 // We write the segment count - 1 because this makes the first word zero for single-segment 282 // messages, improving compression. We don't bother doing this with segment sizes because 283 // one-word segments are rare anyway. 284 table[0].set(segments.size() - 1); 285 for (uint i = 0; i < segments.size(); i++) { 286 table[i + 1].set(segments[i].size()); 287 } 288 if (segments.size() % 2 == 0) { 289 // Set padding byte. 290 table[segments.size() + 1].set(0); 291 } 292 293 KJ_STACK_ARRAY(kj::ArrayPtr<const byte>, pieces, segments.size() + 1, 4, 32); 294 pieces[0] = table.asBytes(); 295 296 for (uint i = 0; i < segments.size(); i++) { 297 pieces[i + 1] = segments[i].asBytes(); 298 } 299 300 output.write(pieces); 301 } 302 303 // ======================================================================================= 304 305 StreamFdMessageReader::~StreamFdMessageReader() noexcept(false) {} 306 307 void writeMessageToFd(int fd, kj::ArrayPtr<const kj::ArrayPtr<const word>> segments) { 308 #ifdef _WIN32 309 auto oldMode = _setmode(fd, _O_BINARY); 310 if (oldMode != _O_BINARY) { 311 _setmode(fd, oldMode); 312 KJ_FAIL_REQUIRE("Tried to write a message to a file descriptor that is in text mode. Set the " 313 "file descriptor to binary mode by calling the _setmode Windows CRT function, or passing " 314 "_O_BINARY to _open()."); 315 } 316 #endif 317 kj::FdOutputStream stream(fd); 318 writeMessage(stream, segments); 319 } 320 321 void readMessageCopyFromFd(int fd, MessageBuilder& target, 322 ReaderOptions options, kj::ArrayPtr<word> scratchSpace) { 323 kj::FdInputStream stream(fd); 324 readMessageCopy(stream, target, options, scratchSpace); 325 } 326 327 } // namespace capnp