djxl_fuzzer.cc (21301B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include <jxl/codestream_header.h> 7 #include <jxl/decode.h> 8 #include <jxl/decode_cxx.h> 9 #include <jxl/thread_parallel_runner.h> 10 #include <jxl/thread_parallel_runner_cxx.h> 11 #include <jxl/types.h> 12 #include <limits.h> 13 #include <stdint.h> 14 #include <stdlib.h> 15 #include <string.h> 16 17 #include <algorithm> 18 #include <hwy/targets.h> 19 #include <map> 20 #include <mutex> 21 #include <random> 22 #include <vector> 23 24 namespace { 25 26 // Externally visible value to ensure pixels are used in the fuzzer. 27 int external_code = 0; 28 29 constexpr const size_t kStreamingTargetNumberOfChunks = 128; 30 31 // Options for the fuzzing 32 struct FuzzSpec { 33 JxlDataType output_type; 34 JxlEndianness output_endianness; 35 size_t output_align; 36 bool get_alpha; 37 bool get_grayscale; 38 bool use_streaming; 39 bool jpeg_to_pixels; // decode to pixels even if it is JPEG-reconstructible 40 // Whether to use the callback mechanism for the output image or not. 41 bool use_callback; 42 bool keep_orientation; 43 bool decode_boxes; 44 bool coalescing; 45 // Used for random variation of chunk sizes, extra channels, ... to get 46 uint32_t random_seed; 47 }; 48 49 template <typename It> 50 void Consume(const It& begin, const It& end) { 51 for (auto it = begin; it < end; ++it) { 52 if (*it == 0) { 53 external_code ^= ~0; 54 } else { 55 external_code ^= *it; 56 } 57 } 58 } 59 60 template <typename T> 61 void Consume(const T& entry) { 62 const uint8_t* begin = reinterpret_cast<const uint8_t*>(&entry); 63 Consume(begin, begin + sizeof(T)); 64 } 65 66 // use_streaming: if true, decodes the data in small chunks, if false, decodes 67 // it in one shot. 68 bool DecodeJpegXl(const uint8_t* jxl, size_t size, size_t max_pixels, 69 const FuzzSpec& spec, std::vector<uint8_t>* pixels, 70 std::vector<uint8_t>* jpeg, size_t* xsize, size_t* ysize, 71 std::vector<uint8_t>* icc_profile) { 72 // Multi-threaded parallel runner. Limit to max 2 threads since the fuzzer 73 // itself is already multithreaded. 74 size_t num_threads = 75 std::min<size_t>(2, JxlThreadParallelRunnerDefaultNumWorkerThreads()); 76 auto runner = JxlThreadParallelRunnerMake(nullptr, num_threads); 77 78 std::mt19937 mt(spec.random_seed); 79 std::exponential_distribution<> dis_streaming(kStreamingTargetNumberOfChunks); 80 81 auto dec = JxlDecoderMake(nullptr); 82 if (JXL_DEC_SUCCESS != 83 JxlDecoderSubscribeEvents( 84 dec.get(), JXL_DEC_BASIC_INFO | JXL_DEC_COLOR_ENCODING | 85 JXL_DEC_PREVIEW_IMAGE | JXL_DEC_FRAME | 86 JXL_DEC_FULL_IMAGE | JXL_DEC_JPEG_RECONSTRUCTION | 87 JXL_DEC_BOX)) { 88 return false; 89 } 90 if (JXL_DEC_SUCCESS != JxlDecoderSetParallelRunner(dec.get(), 91 JxlThreadParallelRunner, 92 runner.get())) { 93 return false; 94 } 95 if (JXL_DEC_SUCCESS != JxlDecoderSetKeepOrientation( 96 dec.get(), TO_JXL_BOOL(spec.keep_orientation))) { 97 abort(); 98 } 99 if (JXL_DEC_SUCCESS != 100 JxlDecoderSetCoalescing(dec.get(), TO_JXL_BOOL(spec.coalescing))) { 101 abort(); 102 } 103 JxlBasicInfo info; 104 uint32_t channels = (spec.get_grayscale ? 1 : 3) + (spec.get_alpha ? 1 : 0); 105 JxlPixelFormat format = {channels, spec.output_type, spec.output_endianness, 106 spec.output_align}; 107 108 if (!spec.use_streaming) { 109 // Set all input at once 110 JxlDecoderSetInput(dec.get(), jxl, size); 111 JxlDecoderCloseInput(dec.get()); 112 } 113 114 bool seen_basic_info = false; 115 bool seen_color_encoding = false; 116 bool seen_preview = false; 117 bool seen_need_image_out = false; 118 bool seen_full_image = false; 119 bool seen_frame = false; 120 uint32_t num_frames = 0; 121 bool seen_jpeg_reconstruction = false; 122 bool seen_jpeg_need_more_output = false; 123 // If streaming and seen around half the input, test flushing 124 bool tested_flush = false; 125 126 // Size made available for the streaming input, emulating a subset of the 127 // full input size. 128 size_t streaming_size = 0; 129 size_t leftover = size; 130 size_t preview_xsize = 0; 131 size_t preview_ysize = 0; 132 bool want_preview = false; 133 std::vector<uint8_t> preview_pixels; 134 135 std::vector<uint8_t> extra_channel_pixels; 136 137 // Callback function used when decoding with use_callback. 138 struct DecodeCallbackData { 139 JxlBasicInfo info; 140 size_t xsize = 0; 141 size_t ysize = 0; 142 std::mutex called_rows_mutex; 143 // For each row stores the segments of the row being called. For each row 144 // the sum of all the int values in the map up to [i] (inclusive) tell how 145 // many times a callback included the pixel i of that row. 146 std::vector<std::map<uint32_t, int>> called_rows; 147 148 // Use the pixel values. 149 uint32_t value = 0; 150 }; 151 DecodeCallbackData decode_callback_data; 152 auto decode_callback = +[](void* opaque, size_t x, size_t y, 153 size_t num_pixels, const void* pixels) { 154 DecodeCallbackData* data = static_cast<DecodeCallbackData*>(opaque); 155 if (num_pixels > data->xsize) abort(); 156 if (x + num_pixels > data->xsize) abort(); 157 if (y >= data->ysize) abort(); 158 if (num_pixels && !pixels) abort(); 159 // Keep track of the segments being called by the callback. 160 { 161 const std::lock_guard<std::mutex> lock(data->called_rows_mutex); 162 data->called_rows[y][x]++; 163 data->called_rows[y][x + num_pixels]--; 164 data->value += *static_cast<const uint8_t*>(pixels); 165 } 166 }; 167 168 JxlExtraChannelInfo extra_channel_info; 169 170 std::vector<uint8_t> box_buffer; 171 172 if (spec.decode_boxes && 173 JXL_DEC_SUCCESS != JxlDecoderSetDecompressBoxes(dec.get(), JXL_TRUE)) { 174 // error ignored, can still fuzz if it doesn't brotli-decompress brob boxes. 175 } 176 177 for (;;) { 178 JxlDecoderStatus status = JxlDecoderProcessInput(dec.get()); 179 if (status == JXL_DEC_ERROR) { 180 return false; 181 } else if (status == JXL_DEC_NEED_MORE_INPUT) { 182 if (spec.use_streaming) { 183 size_t remaining = JxlDecoderReleaseInput(dec.get()); 184 // move any remaining bytes to the front if necessary 185 size_t used = streaming_size - remaining; 186 jxl += used; 187 leftover -= used; 188 streaming_size -= used; 189 size_t chunk_size = std::max<size_t>( 190 1, size * std::min<double>(1.0, dis_streaming(mt))); 191 size_t add_size = 192 std::min<size_t>(chunk_size, leftover - streaming_size); 193 if (add_size == 0) { 194 // End of the streaming data reached 195 return false; 196 } 197 streaming_size += add_size; 198 if (JXL_DEC_SUCCESS != 199 JxlDecoderSetInput(dec.get(), jxl, streaming_size)) { 200 return false; 201 } 202 if (leftover == streaming_size) { 203 // All possible input bytes given 204 JxlDecoderCloseInput(dec.get()); 205 } 206 207 if (!tested_flush && seen_frame) { 208 // Test flush max once to avoid too slow fuzzer run 209 tested_flush = true; 210 JxlDecoderFlushImage(dec.get()); 211 } 212 } else { 213 return false; 214 } 215 } else if (status == JXL_DEC_JPEG_NEED_MORE_OUTPUT) { 216 if (want_preview) abort(); // expected preview before frame 217 if (spec.jpeg_to_pixels) abort(); 218 if (!seen_jpeg_reconstruction) abort(); 219 seen_jpeg_need_more_output = true; 220 size_t used_jpeg_output = 221 jpeg->size() - JxlDecoderReleaseJPEGBuffer(dec.get()); 222 jpeg->resize(std::max<size_t>(4096, jpeg->size() * 2)); 223 uint8_t* jpeg_buffer = jpeg->data() + used_jpeg_output; 224 size_t jpeg_buffer_size = jpeg->size() - used_jpeg_output; 225 226 if (JXL_DEC_SUCCESS != 227 JxlDecoderSetJPEGBuffer(dec.get(), jpeg_buffer, jpeg_buffer_size)) { 228 return false; 229 } 230 } else if (status == JXL_DEC_BASIC_INFO) { 231 if (seen_basic_info) abort(); // already seen basic info 232 seen_basic_info = true; 233 234 memset(&info, 0, sizeof(info)); 235 if (JXL_DEC_SUCCESS != JxlDecoderGetBasicInfo(dec.get(), &info)) { 236 return false; 237 } 238 Consume(info); 239 240 *xsize = info.xsize; 241 *ysize = info.ysize; 242 decode_callback_data.info = info; 243 size_t num_pixels = *xsize * *ysize; 244 // num_pixels overflow 245 if (*xsize != 0 && num_pixels / *xsize != *ysize) return false; 246 // limit max memory of this fuzzer test 247 if (num_pixels > max_pixels) return false; 248 249 if (info.have_preview) { 250 want_preview = true; 251 preview_xsize = info.preview.xsize; 252 preview_ysize = info.preview.ysize; 253 size_t preview_num_pixels = preview_xsize * preview_ysize; 254 // num_pixels overflow 255 if (preview_xsize != 0 && 256 preview_num_pixels / preview_xsize != preview_ysize) { 257 return false; 258 } 259 // limit max memory of this fuzzer test 260 if (preview_num_pixels > max_pixels) return false; 261 } 262 263 for (size_t ec = 0; ec < info.num_extra_channels; ++ec) { 264 memset(&extra_channel_info, 0, sizeof(extra_channel_info)); 265 if (JXL_DEC_SUCCESS != 266 JxlDecoderGetExtraChannelInfo(dec.get(), ec, &extra_channel_info)) { 267 abort(); 268 } 269 Consume(extra_channel_info); 270 std::vector<char> ec_name(extra_channel_info.name_length + 1); 271 if (JXL_DEC_SUCCESS != JxlDecoderGetExtraChannelName(dec.get(), ec, 272 ec_name.data(), 273 ec_name.size())) { 274 abort(); 275 } 276 Consume(ec_name.cbegin(), ec_name.cend()); 277 } 278 } else if (status == JXL_DEC_COLOR_ENCODING) { 279 if (!seen_basic_info) abort(); // expected basic info first 280 if (seen_color_encoding) abort(); // already seen color encoding 281 seen_color_encoding = true; 282 283 // Get the ICC color profile of the pixel data 284 size_t icc_size; 285 if (JXL_DEC_SUCCESS != 286 JxlDecoderGetICCProfileSize(dec.get(), JXL_COLOR_PROFILE_TARGET_DATA, 287 &icc_size)) { 288 return false; 289 } 290 icc_profile->resize(icc_size); 291 if (JXL_DEC_SUCCESS != JxlDecoderGetColorAsICCProfile( 292 dec.get(), JXL_COLOR_PROFILE_TARGET_DATA, 293 icc_profile->data(), icc_profile->size())) { 294 return false; 295 } 296 if (want_preview) { 297 size_t preview_size; 298 if (JXL_DEC_SUCCESS != 299 JxlDecoderPreviewOutBufferSize(dec.get(), &format, &preview_size)) { 300 return false; 301 } 302 preview_pixels.resize(preview_size); 303 if (JXL_DEC_SUCCESS != JxlDecoderSetPreviewOutBuffer( 304 dec.get(), &format, preview_pixels.data(), 305 preview_pixels.size())) { 306 abort(); 307 } 308 } 309 } else if (status == JXL_DEC_PREVIEW_IMAGE) { 310 // TODO(eustas): test JXL_DEC_NEED_PREVIEW_OUT_BUFFER 311 if (seen_preview) abort(); 312 if (!want_preview) abort(); 313 if (!seen_color_encoding) abort(); 314 want_preview = false; 315 seen_preview = true; 316 Consume(preview_pixels.cbegin(), preview_pixels.cend()); 317 } else if (status == JXL_DEC_FRAME || 318 status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { 319 if (want_preview) abort(); // expected preview before frame 320 if (!seen_color_encoding) abort(); // expected color encoding first 321 if (status == JXL_DEC_FRAME) { 322 if (seen_frame) abort(); // already seen JXL_DEC_FRAME 323 seen_frame = true; 324 JxlFrameHeader frame_header; 325 memset(&frame_header, 0, sizeof(frame_header)); 326 if (JXL_DEC_SUCCESS != 327 JxlDecoderGetFrameHeader(dec.get(), &frame_header)) { 328 abort(); 329 } 330 decode_callback_data.xsize = frame_header.layer_info.xsize; 331 decode_callback_data.ysize = frame_header.layer_info.ysize; 332 if (!spec.coalescing) { 333 decode_callback_data.called_rows.clear(); 334 } 335 decode_callback_data.called_rows.resize(decode_callback_data.ysize); 336 Consume(frame_header); 337 std::vector<char> frame_name(frame_header.name_length + 1); 338 if (JXL_DEC_SUCCESS != JxlDecoderGetFrameName(dec.get(), 339 frame_name.data(), 340 frame_name.size())) { 341 abort(); 342 } 343 Consume(frame_name.cbegin(), frame_name.cend()); 344 // When not testing streaming, test that JXL_DEC_NEED_IMAGE_OUT_BUFFER 345 // occurs instead, so do not set buffer now. 346 if (!spec.use_streaming) continue; 347 } 348 if (status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { 349 // expected JXL_DEC_FRAME instead 350 if (!seen_frame) abort(); 351 // already should have set buffer if streaming 352 if (spec.use_streaming) abort(); 353 // already seen need image out 354 if (seen_need_image_out) abort(); 355 seen_need_image_out = true; 356 } 357 358 if (info.num_extra_channels > 0) { 359 std::uniform_int_distribution<> dis(0, info.num_extra_channels); 360 size_t ec_index = dis(mt); 361 // There is also a probability no extra channel is chosen 362 if (ec_index < info.num_extra_channels) { 363 size_t ec_index = info.num_extra_channels - 1; 364 size_t ec_size; 365 if (JXL_DEC_SUCCESS != JxlDecoderExtraChannelBufferSize( 366 dec.get(), &format, &ec_size, ec_index)) { 367 return false; 368 } 369 extra_channel_pixels.resize(ec_size); 370 if (JXL_DEC_SUCCESS != 371 JxlDecoderSetExtraChannelBuffer(dec.get(), &format, 372 extra_channel_pixels.data(), 373 ec_size, ec_index)) { 374 return false; 375 } 376 } 377 } 378 379 if (spec.use_callback) { 380 if (JXL_DEC_SUCCESS != 381 JxlDecoderSetImageOutCallback(dec.get(), &format, decode_callback, 382 &decode_callback_data)) { 383 return false; 384 } 385 } else { 386 // Use the pixels output buffer. 387 size_t buffer_size; 388 if (JXL_DEC_SUCCESS != 389 JxlDecoderImageOutBufferSize(dec.get(), &format, &buffer_size)) { 390 return false; 391 } 392 pixels->resize(buffer_size); 393 void* pixels_buffer = static_cast<void*>(pixels->data()); 394 size_t pixels_buffer_size = pixels->size(); 395 if (JXL_DEC_SUCCESS != 396 JxlDecoderSetImageOutBuffer(dec.get(), &format, pixels_buffer, 397 pixels_buffer_size)) { 398 return false; 399 } 400 } 401 } else if (status == JXL_DEC_JPEG_RECONSTRUCTION) { 402 // Do not check preview precedence here, since this event only declares 403 // that JPEG is going to be decoded; though, when first byte of JPEG 404 // arrives (JXL_DEC_JPEG_NEED_MORE_OUTPUT) it is certain that preview 405 // should have been produced already. 406 if (seen_jpeg_reconstruction) abort(); 407 seen_jpeg_reconstruction = true; 408 if (!spec.jpeg_to_pixels) { 409 // Make sure buffer is allocated, but current size is too small to 410 // contain valid JPEG. 411 jpeg->resize(1); 412 uint8_t* jpeg_buffer = jpeg->data(); 413 size_t jpeg_buffer_size = jpeg->size(); 414 if (JXL_DEC_SUCCESS != 415 JxlDecoderSetJPEGBuffer(dec.get(), jpeg_buffer, jpeg_buffer_size)) { 416 return false; 417 } 418 } 419 } else if (status == JXL_DEC_FULL_IMAGE) { 420 if (want_preview) abort(); // expected preview before frame 421 if (!spec.jpeg_to_pixels && seen_jpeg_reconstruction) { 422 if (!seen_jpeg_need_more_output) abort(); 423 jpeg->resize(jpeg->size() - JxlDecoderReleaseJPEGBuffer(dec.get())); 424 } else { 425 // expected need image out or frame first 426 if (!seen_need_image_out && !seen_frame) abort(); 427 } 428 429 seen_full_image = true; // there may be multiple if animated 430 431 // There may be a next animation frame so expect those again: 432 seen_need_image_out = false; 433 seen_frame = false; 434 num_frames++; 435 436 // "Use" all the pixels; MSAN needs a conditional to count as usage. 437 Consume(pixels->cbegin(), pixels->cend()); 438 Consume(jpeg->cbegin(), jpeg->cend()); 439 440 // When not coalescing, check that the whole (possibly cropped) frame was 441 // sent 442 if (seen_need_image_out && spec.use_callback && spec.coalescing) { 443 // Check that the callback sent all the pixels 444 for (uint32_t y = 0; y < decode_callback_data.ysize; y++) { 445 // Check that each row was at least called once. 446 if (decode_callback_data.called_rows[y].empty()) abort(); 447 uint32_t last_idx = 0; 448 int calls = 0; 449 for (auto it : decode_callback_data.called_rows[y]) { 450 if (it.first > last_idx) { 451 if (static_cast<uint32_t>(calls) != 1) abort(); 452 } 453 calls += it.second; 454 last_idx = it.first; 455 } 456 } 457 } 458 // Nothing to do. Do not yet return. If the image is an animation, more 459 // full frames may be decoded. This example only keeps the last one. 460 } else if (status == JXL_DEC_SUCCESS) { 461 if (!seen_full_image) abort(); // expected full image before finishing 462 463 // When decoding we may not get seen_need_image_out unless we were 464 // decoding the image to pixels. 465 if (seen_need_image_out && spec.use_callback && spec.coalescing) { 466 // Check that the callback sent all the pixels 467 for (uint32_t y = 0; y < decode_callback_data.ysize; y++) { 468 // Check that each row was at least called once. 469 if (decode_callback_data.called_rows[y].empty()) abort(); 470 uint32_t last_idx = 0; 471 int calls = 0; 472 for (auto it : decode_callback_data.called_rows[y]) { 473 if (it.first > last_idx) { 474 if (static_cast<uint32_t>(calls) != num_frames) abort(); 475 } 476 calls += it.second; 477 last_idx = it.first; 478 } 479 } 480 } 481 482 // All decoding successfully finished. 483 // It's not required to call JxlDecoderReleaseInput(dec.get()) here since 484 // the decoder will be destroyed. 485 return true; 486 } else if (status == JXL_DEC_BOX) { 487 if (spec.decode_boxes) { 488 if (!box_buffer.empty()) { 489 size_t remaining = JxlDecoderReleaseBoxBuffer(dec.get()); 490 size_t box_size = box_buffer.size() - remaining; 491 if (box_size != 0) { 492 Consume(box_buffer.begin(), box_buffer.begin() + box_size); 493 box_buffer.clear(); 494 } 495 } 496 box_buffer.resize(64); 497 JxlDecoderSetBoxBuffer(dec.get(), box_buffer.data(), box_buffer.size()); 498 } 499 } else if (status == JXL_DEC_BOX_NEED_MORE_OUTPUT) { 500 if (!spec.decode_boxes) { 501 abort(); // Not expected when not setting output buffer 502 } 503 size_t remaining = JxlDecoderReleaseBoxBuffer(dec.get()); 504 size_t box_size = box_buffer.size() - remaining; 505 box_buffer.resize(box_buffer.size() * 2); 506 JxlDecoderSetBoxBuffer(dec.get(), box_buffer.data() + box_size, 507 box_buffer.size() - box_size); 508 } else { 509 return false; 510 } 511 } 512 } 513 514 int TestOneInput(const uint8_t* data, size_t size) { 515 if (size < 4) return 0; 516 uint32_t flags = 0; 517 size_t used_flag_bits = 0; 518 memcpy(&flags, data + size - 4, 4); 519 size -= 4; 520 521 const auto getFlag = [&flags, &used_flag_bits](size_t max_value) { 522 size_t limit = 1; 523 while (limit <= max_value) { 524 limit <<= 1; 525 used_flag_bits++; 526 if (used_flag_bits > 32) abort(); 527 } 528 uint32_t result = flags % limit; 529 flags /= limit; 530 return result % (max_value + 1); 531 }; 532 const auto getBoolFlag = [&getFlag]() -> bool { 533 return static_cast<bool>(getFlag(1)); 534 }; 535 536 FuzzSpec spec; 537 // Allows some different possible variations in the chunk sizes of the 538 // streaming case 539 spec.random_seed = flags ^ size; 540 spec.get_alpha = getBoolFlag(); 541 spec.get_grayscale = getBoolFlag(); 542 spec.use_streaming = getBoolFlag(); 543 spec.jpeg_to_pixels = getBoolFlag(); 544 spec.use_callback = getBoolFlag(); 545 spec.keep_orientation = getBoolFlag(); 546 spec.coalescing = getBoolFlag(); 547 spec.output_type = static_cast<JxlDataType>(getFlag(JXL_TYPE_FLOAT16)); 548 spec.output_endianness = static_cast<JxlEndianness>(getFlag(JXL_BIG_ENDIAN)); 549 spec.output_align = getFlag(16); 550 spec.decode_boxes = getBoolFlag(); 551 552 std::vector<uint8_t> pixels; 553 std::vector<uint8_t> jpeg; 554 std::vector<uint8_t> icc; 555 size_t xsize; 556 size_t ysize; 557 size_t max_pixels = 1 << 21; 558 559 const auto targets = hwy::SupportedAndGeneratedTargets(); 560 hwy::SetSupportedTargetsForTest(targets[getFlag(targets.size() - 1)]); 561 DecodeJpegXl(data, size, max_pixels, spec, &pixels, &jpeg, &xsize, &ysize, 562 &icc); 563 hwy::SetSupportedTargetsForTest(0); 564 565 return 0; 566 } 567 568 } // namespace 569 570 extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { 571 return TestOneInput(data, size); 572 }