libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

icc_codec.cc (14840B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/icc_codec.h"
      7 
      8 #include <stdint.h>
      9 
     10 #include <map>
     11 #include <string>
     12 #include <vector>
     13 
     14 #include "lib/jxl/base/byte_order.h"
     15 #include "lib/jxl/dec_ans.h"
     16 #include "lib/jxl/fields.h"
     17 #include "lib/jxl/icc_codec_common.h"
     18 #include "lib/jxl/padded_bytes.h"
     19 
     20 namespace jxl {
     21 namespace {
     22 
     23 // Shuffles or interleaves bytes, for example with width 2, turns "ABCDabcd"
     24 // into "AaBbCcDc". Transposes a matrix of ceil(size / width) columns and
     25 // width rows. There are size elements, size may be < width * height, if so the
     26 // last elements of the rightmost column are missing, the missing spots are
     27 // transposed along with the filled spots, and the result has the missing
     28 // elements at the end of the bottom row. The input is the input matrix in
     29 // scanline order but with missing elements skipped (which may occur in multiple
     30 // locations), the output is the result matrix in scanline order (with
     31 // no need to skip missing elements as they are past the end of the data).
     32 void Shuffle(uint8_t* data, size_t size, size_t width) {
     33   size_t height = (size + width - 1) / width;  // amount of rows of output
     34   PaddedBytes result(size);
     35   // i = output index, j input index
     36   size_t s = 0;
     37   size_t j = 0;
     38   for (size_t i = 0; i < size; i++) {
     39     result[i] = data[j];
     40     j += height;
     41     if (j >= size) j = ++s;
     42   }
     43 
     44   for (size_t i = 0; i < size; i++) {
     45     data[i] = result[i];
     46   }
     47 }
     48 
     49 // TODO(eustas): should be 20, or even 18, once DecodeVarInt is improved;
     50 //               currently DecodeVarInt does not signal the errors, and marks
     51 //               11 bytes as used even if only 10 are used (and 9 is enough for
     52 //               63-bit values).
     53 constexpr const size_t kPreambleSize = 22;  // enough for reading 2 VarInts
     54 
     55 uint64_t DecodeVarInt(const uint8_t* input, size_t inputSize, size_t* pos) {
     56   size_t i;
     57   uint64_t ret = 0;
     58   for (i = 0; *pos + i < inputSize && i < 10; ++i) {
     59     ret |= static_cast<uint64_t>(input[*pos + i] & 127)
     60            << static_cast<uint64_t>(7 * i);
     61     // If the next-byte flag is not set, stop
     62     if ((input[*pos + i] & 128) == 0) break;
     63   }
     64   // TODO(user): Return a decoding error if i == 10.
     65   *pos += i + 1;
     66   return ret;
     67 }
     68 
     69 }  // namespace
     70 
     71 // Mimics the beginning of UnpredictICC for quick validity check.
     72 // At least kPreambleSize bytes of data should be valid at invocation time.
     73 Status CheckPreamble(const PaddedBytes& data, size_t enc_size,
     74                      size_t output_limit) {
     75   const uint8_t* enc = data.data();
     76   size_t size = data.size();
     77   size_t pos = 0;
     78   uint64_t osize = DecodeVarInt(enc, size, &pos);
     79   JXL_RETURN_IF_ERROR(CheckIs32Bit(osize));
     80   if (pos >= size) return JXL_FAILURE("Out of bounds");
     81   uint64_t csize = DecodeVarInt(enc, size, &pos);
     82   JXL_RETURN_IF_ERROR(CheckIs32Bit(csize));
     83   JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size));
     84   // We expect that UnpredictICC inflates input, not the other way round.
     85   if (osize + 65536 < enc_size) return JXL_FAILURE("Malformed ICC");
     86   if (output_limit && osize > output_limit) {
     87     return JXL_FAILURE("Decoded ICC is too large");
     88   }
     89   return true;
     90 }
     91 
// Decodes the result of PredictICC back to a valid ICC profile.
//
// `enc` points to `size` bytes laid out as: a VarInt output size (osize), a
// VarInt commands size (csize), `csize` bytes of commands, and then the data
// stream, which runs to the end of `enc`. Commands are read at `cpos` and
// consume bytes from the data stream at `pos`. The decoded profile is appended
// to `result`, which must be empty on entry.
//
// Succeeds only if every command byte and every data byte is consumed and the
// output is exactly osize bytes long.
Status UnpredictICC(const uint8_t* enc, size_t size, PaddedBytes* result) {
  if (!result->empty()) return JXL_FAILURE("result must be empty initially");
  size_t pos = 0;
  // TODO(lode): technically speaking we need to check that the entire varint
  // decoding never goes out of bounds, not just the first byte. This requires
  // a DecodeVarInt function that returns an error code. It is safe to use
  // DecodeVarInt with out of bounds values, it silently returns, but the
  // specification requires an error. Idem for all DecodeVarInt below.
  if (pos >= size) return JXL_FAILURE("Out of bounds");
  uint64_t osize = DecodeVarInt(enc, size, &pos);  // Output size
  JXL_RETURN_IF_ERROR(CheckIs32Bit(osize));
  if (pos >= size) return JXL_FAILURE("Out of bounds");
  uint64_t csize = DecodeVarInt(enc, size, &pos);  // Commands size
  // Every command is translated to at least one byte.
  JXL_RETURN_IF_ERROR(CheckIs32Bit(csize));
  size_t cpos = pos;  // pos in commands stream
  JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, csize, size));
  size_t commands_end = cpos + csize;
  pos = commands_end;  // pos in data stream

  // Header
  // The first kICCHeaderSize output bytes are stored in the data stream as
  // byte differences from a predicted header. The prediction starts from
  // ICCInitialHeaderPrediction() with osize encoded at offset 0, and
  // ICCPredictHeader is given the bytes decoded so far before each byte
  // (presumably it updates header[i] in place; it is defined elsewhere).
  PaddedBytes header;
  header.append(ICCInitialHeaderPrediction());
  EncodeUint32(0, osize, &header);
  for (size_t i = 0; i <= kICCHeaderSize; i++) {
    // A profile may end inside (or right at the end of) the header.
    if (result->size() == osize) {
      if (cpos != commands_end) return JXL_FAILURE("Not all commands used");
      if (pos != size) return JXL_FAILURE("Not all data used");
      return true;  // Valid end
    }
    if (i == kICCHeaderSize) break;  // Done
    ICCPredictHeader(result->data(), result->size(), header.data(), i);
    if (pos >= size) return JXL_FAILURE("Out of bounds");
    // uint8_t arithmetic wraps, undoing the encoder-side subtraction.
    result->push_back(enc[pos++] + header[i]);
  }
  if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");

  // Tag list
  // The tag count is stored plus one; a stored 0 means there is no tag list.
  uint64_t numtags = DecodeVarInt(enc, size, &cpos);

  if (numtags != 0) {
    numtags--;
    JXL_RETURN_IF_ERROR(CheckIs32Bit(numtags));
    AppendUint32(numtags, result);
    // Each tag table entry is 12 bytes (keyword, offset, size), so by default
    // the first tag's data starts right after the table.
    uint64_t prevtagstart = kICCHeaderSize + numtags * 12;
    uint64_t prevtagsize = 0;
    for (;;) {
      if (result->size() > osize) return JXL_FAILURE("Invalid result size");
      if (cpos > commands_end) return JXL_FAILURE("Out of bounds");
      if (cpos == commands_end) break;  // Valid end
      // Low 6 bits select the tag; the kFlagBitOffset / kFlagBitSize bits say
      // whether an explicit start / size VarInt follows in the commands.
      uint8_t command = enc[cpos++];
      uint8_t tagcode = command & 63;
      Tag tag;
      if (tagcode == 0) {
        // End of the tag list; continue with the main content.
        break;
      } else if (tagcode == kCommandTagUnknown) {
        // The 4-byte keyword itself is taken from the data stream.
        JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 4, size));
        tag = DecodeKeyword(enc, size, pos);
        pos += 4;
      } else if (tagcode == kCommandTagTRC) {
        tag = kRtrcTag;
      } else if (tagcode == kCommandTagXYZ) {
        tag = kRxyzTag;
      } else {
        if (tagcode - kCommandTagStringFirst >= kNumTagStrings) {
          return JXL_FAILURE("Unknown tagcode");
        }
        tag = *kTagStrings[tagcode - kCommandTagStringFirst];
      }
      AppendKeyword(tag, result);

      // Defaults: size repeats the previous tag's size (20 for XYZ-type
      // tags), start follows directly after the previous tag.
      uint64_t tagstart;
      uint64_t tagsize = prevtagsize;
      if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag ||
          tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag ||
          tag == kLumiTag) {
        tagsize = 20;
      }

      if (command & kFlagBitOffset) {
        if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
        tagstart = DecodeVarInt(enc, size, &cpos);
      } else {
        JXL_RETURN_IF_ERROR(CheckIs32Bit(prevtagstart));
        tagstart = prevtagstart + prevtagsize;
      }
      JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart));
      AppendUint32(tagstart, result);
      if (command & kFlagBitSize) {
        if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
        tagsize = DecodeVarInt(enc, size, &cpos);
      }
      JXL_RETURN_IF_ERROR(CheckIs32Bit(tagsize));
      AppendUint32(tagsize, result);
      prevtagstart = tagstart;
      prevtagsize = tagsize;

      // A TRC command expands to rTRC, gTRC and bTRC entries that all share
      // the same start and size.
      if (tagcode == kCommandTagTRC) {
        AppendKeyword(kGtrcTag, result);
        AppendUint32(tagstart, result);
        AppendUint32(tagsize, result);
        AppendKeyword(kBtrcTag, result);
        AppendUint32(tagstart, result);
        AppendUint32(tagsize, result);
      }

      // An XYZ command expands to rXYZ, gXYZ and bXYZ entries laid out back to
      // back. Note that prevtagstart/prevtagsize still refer to rXYZ.
      if (tagcode == kCommandTagXYZ) {
        JXL_RETURN_IF_ERROR(CheckIs32Bit(tagstart + tagsize * 2));
        AppendKeyword(kGxyzTag, result);
        AppendUint32(tagstart + tagsize, result);
        AppendUint32(tagsize, result);
        AppendKeyword(kBxyzTag, result);
        AppendUint32(tagstart + tagsize * 2, result);
        AppendUint32(tagsize, result);
      }
    }
  }

  // Main Content
  // Each command emits part of the remaining profile. The output must never
  // exceed osize at a command boundary.
  for (;;) {
    if (result->size() > osize) return JXL_FAILURE("Invalid result size");
    if (cpos > commands_end) return JXL_FAILURE("Out of bounds");
    if (cpos == commands_end) break;  // Valid end
    uint8_t command = enc[cpos++];
    if (command == kCommandInsert) {
      // Copy `num` raw bytes from the data stream.
      if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
      uint64_t num = DecodeVarInt(enc, size, &cpos);
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size));
      for (size_t i = 0; i < num; i++) {
        result->push_back(enc[pos++]);
      }
    } else if (command == kCommandShuffle2 || command == kCommandShuffle4) {
      // Copy `num` bytes from the data stream after transposing them with
      // Shuffle (width 2 or 4).
      if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
      uint64_t num = DecodeVarInt(enc, size, &cpos);
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size));
      PaddedBytes shuffled(num);
      for (size_t i = 0; i < num; i++) {
        shuffled[i] = enc[pos + i];
      }
      if (command == kCommandShuffle2) {
        Shuffle(shuffled.data(), num, 2);
      } else if (command == kCommandShuffle4) {
        Shuffle(shuffled.data(), num, 4);
      }
      for (size_t i = 0; i < num; i++) {
        result->push_back(shuffled[i]);
        pos++;
      }
    } else if (command == kCommandPredict) {
      // At least the flags byte and the first byte of `num` must remain.
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(cpos, 2, commands_end));
      // flags: bits 0-1 = width - 1, bits 2-3 = prediction order,
      // bit 4 = an explicit stride VarInt follows.
      uint8_t flags = enc[cpos++];

      size_t width = (flags & 3) + 1;
      if (width == 3) return JXL_FAILURE("Invalid width");

      int order = (flags & 12) >> 2;
      if (order == 3) return JXL_FAILURE("Invalid order");

      uint64_t stride = width;
      if (flags & 16) {
        if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
        stride = DecodeVarInt(enc, size, &cpos);
        if (stride < width) {
          return JXL_FAILURE("Invalid stride");
        }
      }
      // If stride * 4 >= result->size(), return failure. The check
      // "size == 0 || ((size - 1) >> 2) < stride" corresponds to
      // "stride * 4 >= size", but does not suffer from integer overflow.
      // This check is more strict than necessary but follows the specification
      // and the encoder should ensure this is followed.
      if (result->empty() || ((result->size() - 1u) >> 2u) < stride) {
        return JXL_FAILURE("Invalid stride");
      }

      if (cpos >= commands_end) return JXL_FAILURE("Out of bounds");
      uint64_t num = DecodeVarInt(enc, size, &cpos);  // in bytes
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, num, size));

      // The data bytes are residuals; they are transposed first when the
      // predicted values are wider than one byte.
      PaddedBytes shuffled(num);
      for (size_t i = 0; i < num; i++) {
        shuffled[i] = enc[pos + i];
      }
      if (width > 1) Shuffle(shuffled.data(), num, width);

      // Each output byte is a prediction from earlier output plus its residual
      // (uint8_t arithmetic wraps).
      size_t start = result->size();
      for (size_t i = 0; i < num; i++) {
        uint8_t predicted = LinearPredictICCValue(result->data(), start, i,
                                                  stride, width, order);
        result->push_back(predicted + shuffled[i]);
      }
      pos += num;
    } else if (command == kCommandXYZ) {
      // Emits the kXyz_Tag keyword, 4 zero bytes, then 12 raw data bytes.
      AppendKeyword(kXyz_Tag, result);
      for (int i = 0; i < 4; i++) result->push_back(0);
      JXL_RETURN_IF_ERROR(CheckOutOfBounds(pos, 12, size));
      for (size_t i = 0; i < 12; i++) {
        result->push_back(enc[pos++]);
      }
    } else if (command >= kCommandTypeStartFirst &&
               command < kCommandTypeStartFirst + kNumTypeStrings) {
      // Emits a known type keyword followed by 4 zero bytes.
      AppendKeyword(*kTypeStrings[command - kCommandTypeStartFirst], result);
      for (size_t i = 0; i < 4; i++) {
        result->push_back(0);
      }
    } else {
      return JXL_FAILURE("Unknown command");
    }
  }

  if (pos != size) return JXL_FAILURE("Not all data used");
  if (result->size() != osize) return JXL_FAILURE("Invalid result size");

  return true;
}
    308 
    309 Status ICCReader::Init(BitReader* reader, size_t output_limit) {
    310   JXL_RETURN_IF_ERROR(CheckEOI(reader));
    311   used_bits_base_ = reader->TotalBitsConsumed();
    312   if (bits_to_skip_ == 0) {
    313     enc_size_ = U64Coder::Read(reader);
    314     if (enc_size_ > 268435456) {
    315       // Avoid too large memory allocation for invalid file.
    316       return JXL_FAILURE("Too large encoded profile");
    317     }
    318     JXL_RETURN_IF_ERROR(
    319         DecodeHistograms(reader, kNumICCContexts, &code_, &context_map_));
    320     ans_reader_ = ANSSymbolReader(&code_, reader);
    321     i_ = 0;
    322     decompressed_.resize(std::min<size_t>(i_ + 0x400, enc_size_));
    323     for (; i_ < std::min<size_t>(2, enc_size_); i_++) {
    324       decompressed_[i_] = ans_reader_.ReadHybridUint(
    325           ICCANSContext(i_, i_ > 0 ? decompressed_[i_ - 1] : 0,
    326                         i_ > 1 ? decompressed_[i_ - 2] : 0),
    327           reader, context_map_);
    328     }
    329     if (enc_size_ > kPreambleSize) {
    330       for (; i_ < kPreambleSize; i_++) {
    331         decompressed_[i_] = ans_reader_.ReadHybridUint(
    332             ICCANSContext(i_, decompressed_[i_ - 1], decompressed_[i_ - 2]),
    333             reader, context_map_);
    334       }
    335       JXL_RETURN_IF_ERROR(CheckEOI(reader));
    336       JXL_RETURN_IF_ERROR(
    337           CheckPreamble(decompressed_, enc_size_, output_limit));
    338     }
    339     bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_;
    340   } else {
    341     reader->SkipBits(bits_to_skip_);
    342   }
    343   return true;
    344 }
    345 
    346 Status ICCReader::Process(BitReader* reader, PaddedBytes* icc) {
    347   ANSSymbolReader::Checkpoint checkpoint;
    348   size_t saved_i = 0;
    349   auto save = [&]() {
    350     ans_reader_.Save(&checkpoint);
    351     bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_;
    352     saved_i = i_;
    353   };
    354   save();
    355   auto check_and_restore = [&]() {
    356     Status status = CheckEOI(reader);
    357     if (!status) {
    358       // not enough bytes.
    359       ans_reader_.Restore(checkpoint);
    360       i_ = saved_i;
    361       return status;
    362     }
    363     return Status(true);
    364   };
    365   for (; i_ < enc_size_; i_++) {
    366     if (i_ % ANSSymbolReader::kMaxCheckpointInterval == 0 && i_ > 0) {
    367       JXL_RETURN_IF_ERROR(check_and_restore());
    368       save();
    369       if ((i_ > 0) && (((i_ & 0xFFFF) == 0))) {
    370         float used_bytes =
    371             (reader->TotalBitsConsumed() - used_bits_base_) / 8.0f;
    372         if (i_ > used_bytes * 256) return JXL_FAILURE("Corrupted stream");
    373       }
    374       decompressed_.resize(std::min<size_t>(i_ + 0x400, enc_size_));
    375     }
    376     JXL_DASSERT(i_ >= 2);
    377     decompressed_[i_] = ans_reader_.ReadHybridUint(
    378         ICCANSContext(i_, decompressed_[i_ - 1], decompressed_[i_ - 2]), reader,
    379         context_map_);
    380   }
    381   JXL_RETURN_IF_ERROR(check_and_restore());
    382   bits_to_skip_ = reader->TotalBitsConsumed() - used_bits_base_;
    383   if (!ans_reader_.CheckANSFinalState()) {
    384     return JXL_FAILURE("Corrupted ICC profile");
    385   }
    386 
    387   icc->clear();
    388   return UnpredictICC(decompressed_.data(), decompressed_.size(), icc);
    389 }
    390 
    391 Status ICCReader::CheckEOI(BitReader* reader) {
    392   if (reader->AllReadsWithinBounds()) return true;
    393   return JXL_STATUS(StatusCode::kNotEnoughBytes,
    394                     "Not enough bytes for reading ICC profile");
    395 }
    396 
    397 }  // namespace jxl