libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_icc_codec.cc (16435B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_icc_codec.h"
      7 
      8 #include <stdint.h>
      9 
     10 #include <limits>
     11 #include <map>
     12 #include <string>
     13 #include <vector>
     14 
     15 #include "lib/jxl/base/byte_order.h"
     16 #include "lib/jxl/color_encoding_internal.h"
     17 #include "lib/jxl/enc_ans.h"
     18 #include "lib/jxl/enc_aux_out.h"
     19 #include "lib/jxl/fields.h"
     20 #include "lib/jxl/icc_codec_common.h"
     21 #include "lib/jxl/padded_bytes.h"
     22 
     23 namespace jxl {
     24 namespace {
     25 
     26 // Unshuffles or de-interleaves bytes, for example with width 2, turns
     27 // "AaBbCcDc" into "ABCDabcd", this for example de-interleaves UTF-16 bytes into
     28 // first all the high order bytes, then all the low order bytes.
     29 // Transposes a matrix of width columns and ceil(size / width) rows. There are
     30 // size elements, size may be < width * height, if so the
     31 // last elements of the bottom row are missing, the missing spots are
     32 // transposed along with the filled spots, and the result has the missing
     33 // elements at the bottom of the rightmost column. The input is the input matrix
     34 // in scanline order, the output is the result matrix in scanline order, with
     35 // missing elements skipped over (this may occur at multiple positions).
     36 void Unshuffle(uint8_t* data, size_t size, size_t width) {
     37   size_t height = (size + width - 1) / width;  // amount of rows of input
     38   PaddedBytes result(size);
     39   // i = input index, j output index
     40   size_t s = 0;
     41   size_t j = 0;
     42   for (size_t i = 0; i < size; i++) {
     43     result[j] = data[i];
     44     j += height;
     45     if (j >= size) j = ++s;
     46   }
     47 
     48   for (size_t i = 0; i < size; i++) {
     49     data[i] = result[i];
     50   }
     51 }
     52 
     53 // This is performed by the encoder, the encoder must be able to encode any
     54 // random byte stream (not just byte streams that are a valid ICC profile), so
     55 // an error returned by this function is an implementation error.
     56 Status PredictAndShuffle(size_t stride, size_t width, int order, size_t num,
     57                          const uint8_t* data, size_t size, size_t* pos,
     58                          PaddedBytes* result) {
     59   JXL_RETURN_IF_ERROR(CheckOutOfBounds(*pos, num, size));
     60   // Required by the specification, see decoder. stride * 4 must be < *pos.
     61   if (!*pos || ((*pos - 1u) >> 2u) < stride) {
     62     return JXL_FAILURE("Invalid stride");
     63   }
     64   if (*pos < stride * 4) return JXL_FAILURE("Too large stride");
     65   size_t start = result->size();
     66   for (size_t i = 0; i < num; i++) {
     67     uint8_t predicted =
     68         LinearPredictICCValue(data, *pos, i, stride, width, order);
     69     result->push_back(data[*pos + i] - predicted);
     70   }
     71   *pos += num;
     72   if (width > 1) Unshuffle(result->data() + start, num, width);
     73   return true;
     74 }
     75 
     76 inline void EncodeVarInt(uint64_t value, PaddedBytes* data) {
     77   size_t pos = data->size();
     78   data->resize(data->size() + 9);
     79   size_t output_size = data->size();
     80   uint8_t* output = data->data();
     81 
     82   // While more than 7 bits of data are left,
     83   // store 7 bits and set the next byte flag
     84   while (value > 127) {
     85     // TODO(eustas): should it be `<` ?
     86     JXL_CHECK(pos <= output_size);
     87     // |128: Set the next byte flag
     88     output[pos++] = (static_cast<uint8_t>(value & 127)) | 128;
     89     // Remove the seven bits we just wrote
     90     value >>= 7;
     91   }
     92   // TODO(eustas): should it be `<` ?
     93   JXL_CHECK(pos <= output_size);
     94   output[pos++] = static_cast<uint8_t>(value & 127);
     95 
     96   data->resize(pos);
     97 }
     98 
     99 constexpr size_t kSizeLimit = std::numeric_limits<uint32_t>::max() >> 2;
    100 
    101 }  // namespace
    102 
    103 // Outputs a transformed form of the given icc profile. The result itself is
    104 // not particularly smaller than the input data in bytes, but it will be in a
    105 // form that is easier to compress (more zeroes, ...) and will compress better
    106 // with brotli.
    107 Status PredictICC(const uint8_t* icc, size_t size, PaddedBytes* result) {
    108   PaddedBytes commands;
    109   PaddedBytes data;
    110 
    111   static_assert(sizeof(size_t) >= 4, "size_t is too short");
    112   // Fuzzer expects that PredictICC can accept any input,
    113   // but 1GB should be enough for any purpose.
    114   if (size > kSizeLimit) {
    115     return JXL_FAILURE("ICC profile is too large");
    116   }
    117 
    118   EncodeVarInt(size, result);
    119 
    120   // Header
    121   PaddedBytes header;
    122   header.append(ICCInitialHeaderPrediction());
    123   EncodeUint32(0, size, &header);
    124   for (size_t i = 0; i < kICCHeaderSize && i < size; i++) {
    125     ICCPredictHeader(icc, size, header.data(), i);
    126     data.push_back(icc[i] - header[i]);
    127   }
    128   if (size <= kICCHeaderSize) {
    129     EncodeVarInt(0, result);  // 0 commands
    130     for (size_t i = 0; i < data.size(); i++) {
    131       result->push_back(data[i]);
    132     }
    133     return true;
    134   }
    135 
    136   std::vector<Tag> tags;
    137   std::vector<size_t> tagstarts;
    138   std::vector<size_t> tagsizes;
    139   std::map<size_t, size_t> tagmap;
    140 
    141   // Tag list
    142   size_t pos = kICCHeaderSize;
    143   if (pos + 4 <= size) {
    144     uint64_t numtags = DecodeUint32(icc, size, pos);
    145     pos += 4;
    146     EncodeVarInt(numtags + 1, &commands);
    147     uint64_t prevtagstart = kICCHeaderSize + numtags * 12;
    148     uint32_t prevtagsize = 0;
    149     for (size_t i = 0; i < numtags; i++) {
    150       if (pos + 12 > size) break;
    151 
    152       Tag tag = DecodeKeyword(icc, size, pos + 0);
    153       uint32_t tagstart = DecodeUint32(icc, size, pos + 4);
    154       uint32_t tagsize = DecodeUint32(icc, size, pos + 8);
    155       pos += 12;
    156 
    157       tags.push_back(tag);
    158       tagstarts.push_back(tagstart);
    159       tagsizes.push_back(tagsize);
    160       tagmap[tagstart] = tags.size() - 1;
    161 
    162       uint8_t tagcode = kCommandTagUnknown;
    163       for (size_t j = 0; j < kNumTagStrings; j++) {
    164         if (tag == *kTagStrings[j]) {
    165           tagcode = j + kCommandTagStringFirst;
    166           break;
    167         }
    168       }
    169 
    170       if (tag == kRtrcTag && pos + 24 < size) {
    171         bool ok = true;
    172         ok &= DecodeKeyword(icc, size, pos + 0) == kGtrcTag;
    173         ok &= DecodeKeyword(icc, size, pos + 12) == kBtrcTag;
    174         if (ok) {
    175           for (size_t kk = 0; kk < 8; kk++) {
    176             if (icc[pos - 8 + kk] != icc[pos + 4 + kk]) ok = false;
    177             if (icc[pos - 8 + kk] != icc[pos + 16 + kk]) ok = false;
    178           }
    179         }
    180         if (ok) {
    181           tagcode = kCommandTagTRC;
    182           pos += 24;
    183           i += 2;
    184         }
    185       }
    186 
    187       if (tag == kRxyzTag && pos + 24 < size) {
    188         bool ok = true;
    189         ok &= DecodeKeyword(icc, size, pos + 0) == kGxyzTag;
    190         ok &= DecodeKeyword(icc, size, pos + 12) == kBxyzTag;
    191         uint32_t offsetr = tagstart;
    192         uint32_t offsetg = DecodeUint32(icc, size, pos + 4);
    193         uint32_t offsetb = DecodeUint32(icc, size, pos + 16);
    194         uint32_t sizer = tagsize;
    195         uint32_t sizeg = DecodeUint32(icc, size, pos + 8);
    196         uint32_t sizeb = DecodeUint32(icc, size, pos + 20);
    197         ok &= sizer == 20;
    198         ok &= sizeg == 20;
    199         ok &= sizeb == 20;
    200         ok &= (offsetg == offsetr + 20);
    201         ok &= (offsetb == offsetr + 40);
    202         if (ok) {
    203           tagcode = kCommandTagXYZ;
    204           pos += 24;
    205           i += 2;
    206         }
    207       }
    208 
    209       uint8_t command = tagcode;
    210       uint64_t predicted_tagstart = prevtagstart + prevtagsize;
    211       if (predicted_tagstart != tagstart) command |= kFlagBitOffset;
    212       size_t predicted_tagsize = prevtagsize;
    213       if (tag == kRxyzTag || tag == kGxyzTag || tag == kBxyzTag ||
    214           tag == kKxyzTag || tag == kWtptTag || tag == kBkptTag ||
    215           tag == kLumiTag) {
    216         predicted_tagsize = 20;
    217       }
    218       if (predicted_tagsize != tagsize) command |= kFlagBitSize;
    219       commands.push_back(command);
    220       if (tagcode == 1) {
    221         AppendKeyword(tag, &data);
    222       }
    223       if (command & kFlagBitOffset) EncodeVarInt(tagstart, &commands);
    224       if (command & kFlagBitSize) EncodeVarInt(tagsize, &commands);
    225 
    226       prevtagstart = tagstart;
    227       prevtagsize = tagsize;
    228     }
    229   }
    230   // Indicate end of tag list or varint indicating there's none
    231   commands.push_back(0);
    232 
    233   // Main content
    234   // The main content in a valid ICC profile contains tagged elements, with the
    235   // tag types (4 letter names) given by the tag list above, and the tag list
    236   // pointing to the start and indicating the size of each tagged element. It is
    237   // allowed for tagged elements to overlap, e.g. the curve for R, G and B could
    238   // all point to the same one.
    239   Tag tag;
    240   size_t tagstart = 0;
    241   size_t tagsize = 0;
    242   size_t clutstart = 0;
    243 
    244   // Should always check tag_sane before doing math with tagsize.
    245   const auto tag_sane = [&tagsize]() {
    246     return (tagsize > 8) && (tagsize < kSizeLimit);
    247   };
    248 
    249   size_t last0 = pos;
    250   // This loop appends commands to the output, processing some sub-section of a
    251   // current tagged element each time. We need to keep track of the tagtype of
    252   // the current element, and update it when we encounter the boundary of a
    253   // next one.
    254   // It is not required that the input data is a valid ICC profile, if the
    255   // encoder does not recognize the data it will still be able to output bytes
    256   // but will not predict as well.
    257   while (pos <= size) {
    258     size_t last1 = pos;
    259     PaddedBytes commands_add;
    260     PaddedBytes data_add;
    261 
    262     // This means the loop brought the position beyond the tag end.
    263     // If tagsize is nonsensical, any pos looks "ok-ish".
    264     if ((pos > tagstart + tagsize) && (tagsize < kSizeLimit)) {
    265       tag = {{0, 0, 0, 0}};  // nonsensical value
    266     }
    267 
    268     if (commands_add.empty() && data_add.empty() && tagmap.count(pos) &&
    269         pos + 4 <= size) {
    270       size_t index = tagmap[pos];
    271       tag = DecodeKeyword(icc, size, pos);
    272       tagstart = tagstarts[index];
    273       tagsize = tagsizes[index];
    274 
    275       if (tag == kMlucTag && tag_sane() && pos + tagsize <= size &&
    276           icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 &&
    277           icc[pos + 7] == 0) {
    278         size_t num = tagsize - 8;
    279         commands_add.push_back(kCommandTypeStartFirst + 3);
    280         pos += 8;
    281         commands_add.push_back(kCommandShuffle2);
    282         EncodeVarInt(num, &commands_add);
    283         size_t start = data_add.size();
    284         for (size_t i = 0; i < num; i++) {
    285           data_add.push_back(icc[pos]);
    286           pos++;
    287         }
    288         Unshuffle(data_add.data() + start, num, 2);
    289       }
    290 
    291       if (tag == kCurvTag && tag_sane() && pos + tagsize <= size &&
    292           icc[pos + 4] == 0 && icc[pos + 5] == 0 && icc[pos + 6] == 0 &&
    293           icc[pos + 7] == 0) {
    294         size_t num = tagsize - 8;
    295         if (num > 16 && num < (1 << 28) && pos + num <= size && pos > 0) {
    296           commands_add.push_back(kCommandTypeStartFirst + 5);
    297           pos += 8;
    298           commands_add.push_back(kCommandPredict);
    299           int order = 1;
    300           int width = 2;
    301           int stride = width;
    302           commands_add.push_back((order << 2) | (width - 1));
    303           EncodeVarInt(num, &commands_add);
    304           JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
    305                                                 size, &pos, &data_add));
    306         }
    307       }
    308     }
    309 
    310     if (tag == kMab_Tag || tag == kMba_Tag) {
    311       Tag subTag = DecodeKeyword(icc, size, pos);
    312       if (pos + 12 < size && (subTag == kCurvTag || subTag == kVcgtTag) &&
    313           DecodeUint32(icc, size, pos + 4) == 0) {
    314         uint32_t num = DecodeUint32(icc, size, pos + 8) * 2;
    315         if (num > 16 && num < (1 << 28) && pos + 12 + num <= size) {
    316           pos += 12;
    317           last1 = pos;
    318           commands_add.push_back(kCommandPredict);
    319           int order = 1;
    320           int width = 2;
    321           int stride = width;
    322           commands_add.push_back((order << 2) | (width - 1));
    323           EncodeVarInt(num, &commands_add);
    324           JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
    325                                                 size, &pos, &data_add));
    326         }
    327       }
    328 
    329       if (pos == tagstart + 24 && pos + 4 < size) {
    330         // Note that this value can be remembered for next iterations of the
    331         // loop, so the "pos == clutstart" if below can trigger during a later
    332         // iteration.
    333         clutstart = tagstart + DecodeUint32(icc, size, pos);
    334       }
    335 
    336       if (pos == clutstart && clutstart + 16 < size) {
    337         size_t numi = icc[tagstart + 8];
    338         size_t numo = icc[tagstart + 9];
    339         size_t width = icc[clutstart + 16];
    340         size_t stride = width * numo;
    341         size_t num = width * numo;
    342         for (size_t i = 0; i < numi && clutstart + i < size; i++) {
    343           num *= icc[clutstart + i];
    344         }
    345         if ((width == 1 || width == 2) && num > 64 && num < (1 << 28) &&
    346             pos + num <= size && pos > stride * 4) {
    347           commands_add.push_back(kCommandPredict);
    348           int order = 1;
    349           uint8_t flags =
    350               (order << 2) | (width - 1) | (stride == width ? 0 : 16);
    351           commands_add.push_back(flags);
    352           if (flags & 16) EncodeVarInt(stride, &commands_add);
    353           EncodeVarInt(num, &commands_add);
    354           JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
    355                                                 size, &pos, &data_add));
    356         }
    357       }
    358     }
    359 
    360     if (commands_add.empty() && data_add.empty() && tag == kGbd_Tag &&
    361         tag_sane() && pos == tagstart + 8 && pos + tagsize - 8 <= size &&
    362         pos > 16) {
    363       size_t width = 4;
    364       size_t order = 0;
    365       size_t stride = width;
    366       size_t num = tagsize - 8;
    367       uint8_t flags = (order << 2) | (width - 1) | (stride == width ? 0 : 16);
    368       commands_add.push_back(kCommandPredict);
    369       commands_add.push_back(flags);
    370       if (flags & 16) EncodeVarInt(stride, &commands_add);
    371       EncodeVarInt(num, &commands_add);
    372       JXL_RETURN_IF_ERROR(PredictAndShuffle(stride, width, order, num, icc,
    373                                             size, &pos, &data_add));
    374     }
    375 
    376     if (commands_add.empty() && data_add.empty() && pos + 20 <= size) {
    377       Tag subTag = DecodeKeyword(icc, size, pos);
    378       if (subTag == kXyz_Tag && DecodeUint32(icc, size, pos + 4) == 0) {
    379         commands_add.push_back(kCommandXYZ);
    380         pos += 8;
    381         for (size_t j = 0; j < 12; j++) data_add.push_back(icc[pos++]);
    382       }
    383     }
    384 
    385     if (commands_add.empty() && data_add.empty() && pos + 8 <= size) {
    386       if (DecodeUint32(icc, size, pos + 4) == 0) {
    387         Tag subTag = DecodeKeyword(icc, size, pos);
    388         for (size_t i = 0; i < kNumTypeStrings; i++) {
    389           if (subTag == *kTypeStrings[i]) {
    390             commands_add.push_back(kCommandTypeStartFirst + i);
    391             pos += 8;
    392             break;
    393           }
    394         }
    395       }
    396     }
    397 
    398     if (!(commands_add.empty() && data_add.empty()) || pos == size) {
    399       if (last0 < last1) {
    400         commands.push_back(kCommandInsert);
    401         EncodeVarInt(last1 - last0, &commands);
    402         while (last0 < last1) {
    403           data.push_back(icc[last0++]);
    404         }
    405       }
    406       for (size_t i = 0; i < commands_add.size(); i++) {
    407         commands.push_back(commands_add[i]);
    408       }
    409       for (size_t i = 0; i < data_add.size(); i++) {
    410         data.push_back(data_add[i]);
    411       }
    412       last0 = pos;
    413     }
    414     if (commands_add.empty() && data_add.empty()) {
    415       pos++;
    416     }
    417   }
    418 
    419   EncodeVarInt(commands.size(), result);
    420   for (size_t i = 0; i < commands.size(); i++) {
    421     result->push_back(commands[i]);
    422   }
    423   for (size_t i = 0; i < data.size(); i++) {
    424     result->push_back(data[i]);
    425   }
    426 
    427   return true;
    428 }
    429 
    430 Status WriteICC(const IccBytes& icc, BitWriter* JXL_RESTRICT writer,
    431                 size_t layer, AuxOut* JXL_RESTRICT aux_out) {
    432   if (icc.empty()) return JXL_FAILURE("ICC must be non-empty");
    433   PaddedBytes enc;
    434   JXL_RETURN_IF_ERROR(PredictICC(icc.data(), icc.size(), &enc));
    435   std::vector<std::vector<Token>> tokens(1);
    436   BitWriter::Allotment allotment(writer, 128);
    437   JXL_RETURN_IF_ERROR(U64Coder::Write(enc.size(), writer));
    438   allotment.ReclaimAndCharge(writer, layer, aux_out);
    439 
    440   for (size_t i = 0; i < enc.size(); i++) {
    441     tokens[0].emplace_back(
    442         ICCANSContext(i, i > 0 ? enc[i - 1] : 0, i > 1 ? enc[i - 2] : 0),
    443         enc[i]);
    444   }
    445   HistogramParams params;
    446   params.lz77_method = enc.size() < 4096 ? HistogramParams::LZ77Method::kOptimal
    447                                          : HistogramParams::LZ77Method::kLZ77;
    448   EntropyEncodingData code;
    449   std::vector<uint8_t> context_map;
    450   params.force_huffman = true;
    451   BuildAndEncodeHistograms(params, kNumICCContexts, tokens, &code, &context_map,
    452                            writer, layer, aux_out);
    453   WriteTokens(tokens[0], code, context_map, 0, writer, layer, aux_out);
    454   return true;
    455 }
    456 
    457 }  // namespace jxl