libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_ans.cc (68365B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_ans.h"
      7 
      8 #include <jxl/types.h>
      9 #include <stdint.h>
     10 
     11 #include <algorithm>
     12 #include <array>
     13 #include <cmath>
     14 #include <cstdint>
     15 #include <limits>
     16 #include <numeric>
     17 #include <type_traits>
     18 #include <unordered_map>
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "lib/jxl/ans_common.h"
     23 #include "lib/jxl/base/bits.h"
     24 #include "lib/jxl/base/fast_math-inl.h"
     25 #include "lib/jxl/base/status.h"
     26 #include "lib/jxl/dec_ans.h"
     27 #include "lib/jxl/enc_ans_params.h"
     28 #include "lib/jxl/enc_aux_out.h"
     29 #include "lib/jxl/enc_cluster.h"
     30 #include "lib/jxl/enc_context_map.h"
     31 #include "lib/jxl/enc_fields.h"
     32 #include "lib/jxl/enc_huffman.h"
     33 #include "lib/jxl/enc_params.h"
     34 #include "lib/jxl/fields.h"
     35 
     36 namespace jxl {
     37 
     38 namespace {
     39 
     40 #if !JXL_IS_DEBUG_BUILD
     41 constexpr
     42 #endif
     43     bool ans_fuzzer_friendly_ = false;
     44 
     45 const int kMaxNumSymbolsForSmallCode = 4;
     46 
     47 void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
     48                        size_t alphabet_size, size_t log_alpha_size,
     49                        ANSEncSymbolInfo* info) {
     50   size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
     51   size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
     52   // create valid alias table for empty streams.
     53   for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
     54     const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
     55     info[s].freq_ = static_cast<uint16_t>(freq);
     56 #ifdef USE_MULT_BY_RECIPROCAL
     57     if (freq != 0) {
     58       info[s].ifreq_ =
     59           ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
     60     } else {
     61       info[s].ifreq_ = 1;  // shouldn't matter (symbol shouldn't occur), but...
     62     }
     63 #endif
     64     info[s].reverse_map_.resize(freq);
     65   }
     66   for (int i = 0; i < ANS_TAB_SIZE; i++) {
     67     AliasTable::Symbol s =
     68         AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
     69     info[s.value].reverse_map_[s.offset] = i;
     70   }
     71 }
     72 
     73 float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
     74                        size_t len) {
     75   float sum = 0.0f;
     76   int total_histogram = 0;
     77   int total_counts = 0;
     78   for (size_t i = 0; i < len; ++i) {
     79     total_histogram += histogram[i];
     80     total_counts += counts[i];
     81     if (histogram[i] > 0) {
     82       JXL_ASSERT(counts[i] > 0);
     83       // += histogram[i] * -log(counts[i]/total_counts)
     84       sum += histogram[i] *
     85              std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
     86     }
     87   }
     88   if (total_histogram > 0) {
     89     // Used only in assert.
     90     (void)total_counts;
     91     JXL_ASSERT(total_counts == ANS_TAB_SIZE);
     92   }
     93   return sum;
     94 }
     95 
     96 float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
     97   const float flat_bits = std::max(FastLog2f(len), 0.0f);
     98   float total_histogram = 0;
     99   for (size_t i = 0; i < len; ++i) {
    100     total_histogram += histogram[i];
    101   }
    102   return total_histogram * flat_bits;
    103 }
    104 
    105 // Static Huffman code for encoding logcounts. The last symbol is used as RLE
    106 // sequence.
    107 const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
    108     5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
    109 };
    110 const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
    111     17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
    112 };
    113 
    114 // Returns the difference between largest count that can be represented and is
    115 // smaller than "count" and smallest representable count larger than "count".
    116 int SmallestIncrement(uint32_t count, uint32_t shift) {
    117   int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
    118   int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
    119   return drop_bits < 0 ? 1 : (1 << drop_bits);
    120 }
    121 
    122 template <bool minimize_error_of_sum>
    123 bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
    124                         uint32_t shift, int* omit_pos, ANSHistBin* counts) {
    125   int sum = 0;
    126   float sum_nonrounded = 0.0;
    127   int remainder_pos = 0;  // if all of them are handled in first loop
    128   int remainder_log = -1;
    129   for (int n = 0; n < max_symbol; ++n) {
    130     if (targets[n] > 0 && targets[n] < 1.0f) {
    131       counts[n] = 1;
    132       sum_nonrounded += targets[n];
    133       sum += counts[n];
    134     }
    135   }
    136   const float discount_ratio =
    137       (table_size - sum) / (table_size - sum_nonrounded);
    138   JXL_ASSERT(discount_ratio > 0);
    139   JXL_ASSERT(discount_ratio <= 1.0f);
    140   // Invariant for minimize_error_of_sum == true:
    141   // abs(sum - sum_nonrounded)
    142   //   <= SmallestIncrement(max(targets[])) + max_symbol
    143   for (int n = 0; n < max_symbol; ++n) {
    144     if (targets[n] >= 1.0f) {
    145       sum_nonrounded += targets[n];
    146       counts[n] =
    147           static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
    148       if (counts[n] == 0) counts[n] = 1;
    149       if (counts[n] == table_size) counts[n] = table_size - 1;
    150       // Round the count to the closest nonzero multiple of SmallestIncrement
    151       // (when minimize_error_of_sum is false) or one of two closest so as to
    152       // keep the sum as close as possible to sum_nonrounded.
    153       int inc = SmallestIncrement(counts[n], shift);
    154       counts[n] -= counts[n] & (inc - 1);
    155       // TODO(robryk): Should we rescale targets[n]?
    156       const int target = minimize_error_of_sum
    157                              ? (static_cast<int>(sum_nonrounded) - sum)
    158                              : static_cast<int>(targets[n]);
    159       if (counts[n] == 0 ||
    160           (target >= counts[n] + inc / 2 && counts[n] + inc < table_size)) {
    161         counts[n] += inc;
    162       }
    163       sum += counts[n];
    164       const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
    165       if (count_log > remainder_log) {
    166         remainder_pos = n;
    167         remainder_log = count_log;
    168       }
    169     }
    170   }
    171   JXL_ASSERT(remainder_pos != -1);
    172   // NOTE: This is the only place where counts could go negative. We could
    173   // detect that, return false and make ANSHistBin uint32_t.
    174   counts[remainder_pos] -= sum - table_size;
    175   *omit_pos = remainder_pos;
    176   return counts[remainder_pos] > 0;
    177 }
    178 
    179 Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
    180                        const int precision_bits, uint32_t shift,
    181                        int* num_symbols, int* symbols) {
    182   const int32_t table_size = 1 << precision_bits;  // target sum / table size
    183   uint64_t total = 0;
    184   int max_symbol = 0;
    185   int symbol_count = 0;
    186   for (int n = 0; n < length; ++n) {
    187     total += counts[n];
    188     if (counts[n] > 0) {
    189       if (symbol_count < kMaxNumSymbolsForSmallCode) {
    190         symbols[symbol_count] = n;
    191       }
    192       ++symbol_count;
    193       max_symbol = n + 1;
    194     }
    195   }
    196   *num_symbols = symbol_count;
    197   if (symbol_count == 0) {
    198     return true;
    199   }
    200   if (symbol_count == 1) {
    201     counts[symbols[0]] = table_size;
    202     return true;
    203   }
    204   if (symbol_count > table_size)
    205     return JXL_FAILURE("Too many entries in an ANS histogram");
    206 
    207   const float norm = 1.f * table_size / total;
    208   std::vector<float> targets(max_symbol);
    209   for (size_t n = 0; n < targets.size(); ++n) {
    210     targets[n] = norm * counts[n];
    211   }
    212   if (!RebalanceHistogram<false>(targets.data(), max_symbol, table_size, shift,
    213                                  omit_pos, counts)) {
    214     // Use an alternative rebalancing mechanism if the one above failed
    215     // to create a histogram that is positive wherever the original one was.
    216     if (!RebalanceHistogram<true>(targets.data(), max_symbol, table_size, shift,
    217                                   omit_pos, counts)) {
    218       return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
    219     }
    220   }
    221   return true;
    222 }
    223 
    224 struct SizeWriter {
    225   size_t size = 0;
    226   void Write(size_t num, size_t bits) { size += num; }
    227 };
    228 
    229 template <typename Writer>
    230 void StoreVarLenUint8(size_t n, Writer* writer) {
    231   JXL_DASSERT(n <= 255);
    232   if (n == 0) {
    233     writer->Write(1, 0);
    234   } else {
    235     writer->Write(1, 1);
    236     size_t nbits = FloorLog2Nonzero(n);
    237     writer->Write(3, nbits);
    238     writer->Write(nbits, n - (1ULL << nbits));
    239   }
    240 }
    241 
    242 template <typename Writer>
    243 void StoreVarLenUint16(size_t n, Writer* writer) {
    244   JXL_DASSERT(n <= 65535);
    245   if (n == 0) {
    246     writer->Write(1, 0);
    247   } else {
    248     writer->Write(1, 1);
    249     size_t nbits = FloorLog2Nonzero(n);
    250     writer->Write(4, nbits);
    251     writer->Write(nbits, n - (1ULL << nbits));
    252   }
    253 }
    254 
    255 template <typename Writer>
    256 bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
    257                   const int omit_pos, const int num_symbols, uint32_t shift,
    258                   const int* symbols, Writer* writer) {
    259   bool ok = true;
    260   if (num_symbols <= 2) {
    261     // Small tree marker to encode 1-2 symbols.
    262     writer->Write(1, 1);
    263     if (num_symbols == 0) {
    264       writer->Write(1, 0);
    265       StoreVarLenUint8(0, writer);
    266     } else {
    267       writer->Write(1, num_symbols - 1);
    268       for (int i = 0; i < num_symbols; ++i) {
    269         StoreVarLenUint8(symbols[i], writer);
    270       }
    271     }
    272     if (num_symbols == 2) {
    273       writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
    274     }
    275   } else {
    276     // Mark non-small tree.
    277     writer->Write(1, 0);
    278     // Mark non-flat histogram.
    279     writer->Write(1, 0);
    280 
    281     // Precompute sequences for RLE encoding. Contains the number of identical
    282     // values starting at a given index. Only contains the value at the first
    283     // element of the series.
    284     std::vector<uint32_t> same(alphabet_size, 0);
    285     int last = 0;
    286     for (int i = 1; i < alphabet_size; i++) {
    287       // Store the sequence length once different symbol reached, or we're at
    288       // the end, or the length is longer than we can encode, or we are at
    289       // the omit_pos. We don't support including the omit_pos in an RLE
    290       // sequence because this value may use a different amount of log2 bits
    291       // than standard, it is too complex to handle in the decoder.
    292       if (counts[i] != counts[last] || i + 1 == alphabet_size ||
    293           (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
    294         same[last] = (i - last);
    295         last = i + 1;
    296       }
    297     }
    298 
    299     int length = 0;
    300     std::vector<int> logcounts(alphabet_size);
    301     int omit_log = 0;
    302     for (int i = 0; i < alphabet_size; ++i) {
    303       JXL_ASSERT(counts[i] <= ANS_TAB_SIZE);
    304       JXL_ASSERT(counts[i] >= 0);
    305       if (i == omit_pos) {
    306         length = i + 1;
    307       } else if (counts[i] > 0) {
    308         logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
    309         length = i + 1;
    310         if (i < omit_pos) {
    311           omit_log = std::max(omit_log, logcounts[i] + 1);
    312         } else {
    313           omit_log = std::max(omit_log, logcounts[i]);
    314         }
    315       }
    316     }
    317     logcounts[omit_pos] = omit_log;
    318 
    319     // Elias gamma-like code for shift. Only difference is that if the number
    320     // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
    321     // the terminating 0 in unary coding.
    322     int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
    323     int log = FloorLog2Nonzero(shift + 1);
    324     writer->Write(log, (1 << log) - 1);
    325     if (log != upper_bound_log) writer->Write(1, 0);
    326     writer->Write(log, ((1 << log) - 1) & (shift + 1));
    327 
    328     // Since num_symbols >= 3, we know that length >= 3, therefore we encode
    329     // length - 3.
    330     if (length - 3 > 255) {
    331       // Pretend that everything is OK, but complain about correctness later.
    332       StoreVarLenUint8(255, writer);
    333       ok = false;
    334     } else {
    335       StoreVarLenUint8(length - 3, writer);
    336     }
    337 
    338     // The logcount values are encoded with a static Huffman code.
    339     static const size_t kMinReps = 4;
    340     size_t rep = ANS_LOG_TAB_SIZE + 1;
    341     for (int i = 0; i < length; ++i) {
    342       if (i > 0 && same[i - 1] > kMinReps) {
    343         // Encode the RLE symbol and skip the repeated ones.
    344         writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
    345         StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
    346         i += same[i - 1] - 2;
    347         continue;
    348       }
    349       writer->Write(kLogCountBitLengths[logcounts[i]],
    350                     kLogCountSymbols[logcounts[i]]);
    351     }
    352     for (int i = 0; i < length; ++i) {
    353       if (i > 0 && same[i - 1] > kMinReps) {
    354         // Skip symbols encoded by RLE.
    355         i += same[i - 1] - 2;
    356         continue;
    357       }
    358       if (logcounts[i] > 1 && i != omit_pos) {
    359         int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
    360         int drop_bits = logcounts[i] - 1 - bitcount;
    361         JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0);
    362         writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
    363       }
    364     }
    365   }
    366   return ok;
    367 }
    368 
    369 void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
    370   // Mark non-small tree.
    371   writer->Write(1, 0);
    372   // Mark uniform histogram.
    373   writer->Write(1, 1);
    374   JXL_ASSERT(alphabet_size > 0);
    375   // Encode alphabet size.
    376   StoreVarLenUint8(alphabet_size - 1, writer);
    377 }
    378 
    379 float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size,
    380                               uint32_t method) {
    381   if (method == 0) {  // Flat code
    382     return ANS_LOG_TAB_SIZE + 2 +
    383            EstimateDataBitsFlat(histogram, alphabet_size);
    384   }
    385   // Non-flat: shift = method-1.
    386   uint32_t shift = method - 1;
    387   std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
    388   int omit_pos = 0;
    389   int num_symbols;
    390   int symbols[kMaxNumSymbolsForSmallCode] = {};
    391   JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
    392                             ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
    393   SizeWriter writer;
    394   // Ignore the correctness, no real encoding happens at this stage.
    395   (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
    396                      symbols, &writer);
    397   return writer.size +
    398          EstimateDataBits(histogram, counts.data(), alphabet_size);
    399 }
    400 
    401 uint32_t ComputeBestMethod(
    402     const ANSHistBin* histogram, size_t alphabet_size, float* cost,
    403     HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
    404   size_t method = 0;
    405   float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0);
    406   auto try_shift = [&](size_t shift) {
    407     float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1);
    408     if (c < fcost) {
    409       method = shift + 1;
    410       fcost = c;
    411     }
    412   };
    413   switch (ans_histogram_strategy) {
    414     case HistogramParams::ANSHistogramStrategy::kPrecise: {
    415       for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) {
    416         try_shift(shift);
    417       }
    418       break;
    419     }
    420     case HistogramParams::ANSHistogramStrategy::kApproximate: {
    421       for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) {
    422         try_shift(shift);
    423       }
    424       break;
    425     }
    426     case HistogramParams::ANSHistogramStrategy::kFast: {
    427       try_shift(0);
    428       try_shift(ANS_LOG_TAB_SIZE / 2);
    429       try_shift(ANS_LOG_TAB_SIZE);
    430       break;
    431     }
    432   };
    433   *cost = fcost;
    434   return method;
    435 }
    436 
    437 }  // namespace
    438 
    439 // Returns an estimate of the cost of encoding this histogram and the
    440 // corresponding data.
    441 size_t BuildAndStoreANSEncodingData(
    442     HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
    443     const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
    444     bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
    445   if (use_prefix_code) {
    446     if (alphabet_size <= 1) return 0;
    447     std::vector<uint32_t> histo(alphabet_size);
    448     for (size_t i = 0; i < alphabet_size; i++) {
    449       histo[i] = histogram[i];
    450       JXL_CHECK(histogram[i] >= 0);
    451     }
    452     size_t cost = 0;
    453     {
    454       std::vector<uint8_t> depths(alphabet_size);
    455       std::vector<uint16_t> bits(alphabet_size);
    456       if (writer == nullptr) {
    457         BitWriter tmp_writer;
    458         BitWriter::Allotment allotment(
    459             &tmp_writer, 8 * alphabet_size + 8);  // safe upper bound
    460         BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
    461                                  bits.data(), &tmp_writer);
    462         allotment.ReclaimAndCharge(&tmp_writer, 0, /*aux_out=*/nullptr);
    463         cost = tmp_writer.BitsWritten();
    464       } else {
    465         size_t start = writer->BitsWritten();
    466         BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
    467                                  bits.data(), writer);
    468         cost = writer->BitsWritten() - start;
    469       }
    470       for (size_t i = 0; i < alphabet_size; i++) {
    471         info[i].bits = depths[i] == 0 ? 0 : bits[i];
    472         info[i].depth = depths[i];
    473       }
    474     }
    475     // Estimate data cost.
    476     for (size_t i = 0; i < alphabet_size; i++) {
    477       cost += histogram[i] * info[i].depth;
    478     }
    479     return cost;
    480   }
    481   JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE);
    482   float cost;
    483   uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost,
    484                                       ans_histogram_strategy);
    485   JXL_ASSERT(cost >= 0);
    486   int num_symbols;
    487   int symbols[kMaxNumSymbolsForSmallCode] = {};
    488   std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
    489   if (!counts.empty()) {
    490     size_t sum = 0;
    491     for (int count : counts) {
    492       sum += count;
    493     }
    494     if (sum == 0) {
    495       counts[0] = ANS_TAB_SIZE;
    496     }
    497   }
    498   int omit_pos = 0;
    499   uint32_t shift = method - 1;
    500   if (method == 0) {
    501     counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
    502   } else {
    503     JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
    504                               ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
    505   }
    506   AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
    507   InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
    508   ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
    509   if (writer != nullptr) {
    510     if (method == 0) {
    511       EncodeFlatHistogram(alphabet_size, writer);
    512     } else {
    513       bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos,
    514                              num_symbols, method - 1, symbols, writer);
    515       (void)ok;
    516       JXL_DASSERT(ok);
    517     }
    518   }
    519   return cost;
    520 }
    521 
    522 float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) {
    523   float c;
    524   ComputeBestMethod(data, alphabet_size, &c,
    525                     HistogramParams::ANSHistogramStrategy::kFast);
    526   return c;
    527 }
    528 
    529 template <typename Writer>
    530 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
    531                       size_t log_alpha_size) {
    532   writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
    533                 uint_config.split_exponent);
    534   if (uint_config.split_exponent == log_alpha_size) {
    535     return;  // msb/lsb don't matter.
    536   }
    537   size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
    538   writer->Write(nbits, uint_config.msb_in_token);
    539   nbits = CeilLog2Nonzero(uint_config.split_exponent -
    540                           uint_config.msb_in_token + 1);
    541   writer->Write(nbits, uint_config.lsb_in_token);
    542 }
    543 template <typename Writer>
    544 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
    545                        Writer* writer, size_t log_alpha_size) {
    546   // TODO(veluca): RLE?
    547   for (const auto& cfg : uint_config) {
    548     EncodeUintConfig(cfg, writer, log_alpha_size);
    549   }
    550 }
    551 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
    552                                 BitWriter*, size_t);
    553 
    554 namespace {
    555 
    556 void ChooseUintConfigs(const HistogramParams& params,
    557                        const std::vector<std::vector<Token>>& tokens,
    558                        const std::vector<uint8_t>& context_map,
    559                        std::vector<Histogram>* clustered_histograms,
    560                        EntropyEncodingData* codes, size_t* log_alpha_size) {
    561   codes->uint_config.resize(clustered_histograms->size());
    562   if (params.uint_method == HistogramParams::HybridUintMethod::kNone) {
    563     return;
    564   }
    565   if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
    566     codes->uint_config.clear();
    567     codes->uint_config.resize(clustered_histograms->size(),
    568                               HybridUintConfig(0, 0, 0));
    569     return;
    570   }
    571   if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
    572     codes->uint_config.clear();
    573     codes->uint_config.resize(clustered_histograms->size(),
    574                               HybridUintConfig(2, 0, 1));
    575     return;
    576   }
    577 
    578   // If the uint config is adaptive, just stick with the default in streaming
    579   // mode.
    580   if (params.streaming_mode) {
    581     return;
    582   }
    583 
    584   // Brute-force method that tries a few options.
    585   std::vector<HybridUintConfig> configs;
    586   if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
    587     configs = {
    588         HybridUintConfig(4, 2, 0),  // default
    589         HybridUintConfig(4, 1, 0),  // less precise
    590         HybridUintConfig(4, 2, 1),  // add sign
    591         HybridUintConfig(4, 2, 2),  // add sign+parity
    592         HybridUintConfig(4, 1, 2),  // add parity but less msb
    593         // Same as above, but more direct coding.
    594         HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
    595         HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
    596         HybridUintConfig(5, 1, 2),
    597         // Same as above, but less direct coding.
    598         HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
    599         HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
    600         // For near-lossless.
    601         HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
    602         HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
    603         HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
    604         // Other
    605         HybridUintConfig(0, 0, 0),   // varlenuint
    606         HybridUintConfig(2, 0, 1),   // works well for ctx map
    607         HybridUintConfig(7, 0, 0),   // direct coding
    608         HybridUintConfig(8, 0, 0),   // direct coding
    609         HybridUintConfig(9, 0, 0),   // direct coding
    610         HybridUintConfig(10, 0, 0),  // direct coding
    611         HybridUintConfig(11, 0, 0),  // direct coding
    612         HybridUintConfig(12, 0, 0),  // direct coding
    613     };
    614   } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
    615     configs = {
    616         HybridUintConfig(4, 2, 0),  // default
    617         HybridUintConfig(4, 1, 2),  // add parity but less msb
    618         HybridUintConfig(0, 0, 0),  // smallest histograms
    619         HybridUintConfig(2, 0, 1),  // works well for ctx map
    620     };
    621   }
    622 
    623   std::vector<float> costs(clustered_histograms->size(),
    624                            std::numeric_limits<float>::max());
    625   std::vector<uint32_t> extra_bits(clustered_histograms->size());
    626   std::vector<uint8_t> is_valid(clustered_histograms->size());
    627   size_t max_alpha =
    628       codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
    629   for (HybridUintConfig cfg : configs) {
    630     std::fill(is_valid.begin(), is_valid.end(), true);
    631     std::fill(extra_bits.begin(), extra_bits.end(), 0);
    632 
    633     for (auto& histo : *clustered_histograms) {
    634       histo.Clear();
    635     }
    636     for (const auto& stream : tokens) {
    637       for (const auto& token : stream) {
    638         // TODO(veluca): do not ignore lz77 commands.
    639         if (token.is_lz77_length) continue;
    640         size_t histo = context_map[token.context];
    641         uint32_t tok, nbits, bits;
    642         cfg.Encode(token.value, &tok, &nbits, &bits);
    643         if (tok >= max_alpha ||
    644             (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
    645           is_valid[histo] = JXL_FALSE;
    646           continue;
    647         }
    648         extra_bits[histo] += nbits;
    649         (*clustered_histograms)[histo].Add(tok);
    650       }
    651     }
    652 
    653     for (size_t i = 0; i < clustered_histograms->size(); i++) {
    654       if (!is_valid[i]) continue;
    655       float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i];
    656       // add signaling cost of the hybriduintconfig itself
    657       cost += CeilLog2Nonzero(cfg.split_exponent + 1);
    658       cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1);
    659       if (cost < costs[i]) {
    660         codes->uint_config[i] = cfg;
    661         costs[i] = cost;
    662       }
    663     }
    664   }
    665 
    666   // Rebuild histograms.
    667   for (auto& histo : *clustered_histograms) {
    668     histo.Clear();
    669   }
    670   *log_alpha_size = 4;
    671   for (const auto& stream : tokens) {
    672     for (const auto& token : stream) {
    673       uint32_t tok, nbits, bits;
    674       size_t histo = context_map[token.context];
    675       (token.is_lz77_length ? codes->lz77.length_uint_config
    676                             : codes->uint_config[histo])
    677           .Encode(token.value, &tok, &nbits, &bits);
    678       tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
    679       (*clustered_histograms)[histo].Add(tok);
    680       while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
    681     }
    682   }
    683 #if JXL_ENABLE_ASSERT
    684   size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
    685   JXL_ASSERT(*log_alpha_size <= max_log_alpha_size);
    686 #endif
    687 }
    688 
    689 Histogram HistogramFromSymbolInfo(
    690     const std::vector<ANSEncSymbolInfo>& encoding_info, bool use_prefix_code) {
    691   Histogram histo;
    692   histo.data_.resize(DivCeil(encoding_info.size(), Histogram::kRounding) *
    693                      Histogram::kRounding);
    694   histo.total_count_ = 0;
    695   for (size_t i = 0; i < encoding_info.size(); ++i) {
    696     const ANSEncSymbolInfo& info = encoding_info[i];
    697     int count = use_prefix_code
    698                     ? (info.depth ? (1u << (PREFIX_MAX_BITS - info.depth)) : 0)
    699                     : info.freq_;
    700     histo.data_[i] = count;
    701     histo.total_count_ += count;
    702   }
    703   return histo;
    704 }
    705 
    706 class HistogramBuilder {
    707  public:
    708   explicit HistogramBuilder(const size_t num_contexts)
    709       : histograms_(num_contexts) {}
    710 
    711   void VisitSymbol(int symbol, size_t histo_idx) {
    712     JXL_DASSERT(histo_idx < histograms_.size());
    713     histograms_[histo_idx].Add(symbol);
    714   }
    715 
    716   // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
    717   size_t BuildAndStoreEntropyCodes(
    718       const HistogramParams& params,
    719       const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
    720       std::vector<uint8_t>* context_map, BitWriter* writer, size_t layer,
    721       AuxOut* aux_out) const {
    722     const size_t prev_histograms = codes->encoding_info.size();
    723     size_t cost = 0;
    724     std::vector<Histogram> clustered_histograms;
    725     for (size_t i = 0; i < prev_histograms; ++i) {
    726       clustered_histograms.push_back(HistogramFromSymbolInfo(
    727           codes->encoding_info[i], codes->use_prefix_code));
    728     }
    729     size_t context_offset = context_map->size();
    730     context_map->resize(context_offset + histograms_.size());
    731     if (histograms_.size() > 1) {
    732       if (!ans_fuzzer_friendly_) {
    733         std::vector<uint32_t> histogram_symbols;
    734         ClusterHistograms(params, histograms_, kClustersLimit,
    735                           &clustered_histograms, &histogram_symbols);
    736         for (size_t c = 0; c < histograms_.size(); ++c) {
    737           (*context_map)[context_offset + c] =
    738               static_cast<uint8_t>(histogram_symbols[c]);
    739         }
    740       } else {
    741         JXL_ASSERT(codes->encoding_info.empty());
    742         fill(context_map->begin(), context_map->end(), 0);
    743         size_t max_symbol = 0;
    744         for (const Histogram& h : histograms_) {
    745           max_symbol = std::max(h.data_.size(), max_symbol);
    746         }
    747         size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
    748         clustered_histograms.resize(1);
    749         clustered_histograms[0].Clear();
    750         for (size_t i = 0; i < num_symbols; i++) {
    751           clustered_histograms[0].Add(i);
    752         }
    753       }
    754       if (writer != nullptr) {
    755         EncodeContextMap(*context_map, clustered_histograms.size(), writer,
    756                          layer, aux_out);
    757       }
    758     } else {
    759       JXL_ASSERT(codes->encoding_info.empty());
    760       clustered_histograms.push_back(histograms_[0]);
    761     }
    762     if (aux_out != nullptr) {
    763       for (size_t i = prev_histograms; i < clustered_histograms.size(); ++i) {
    764         aux_out->layers[layer].clustered_entropy +=
    765             clustered_histograms[i].ShannonEntropy();
    766       }
    767     }
    768     size_t log_alpha_size = codes->lz77.enabled ? 8 : 7;  // Sane default.
    769     if (ans_fuzzer_friendly_) {
    770       codes->uint_config.clear();
    771       codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
    772     } else {
    773       ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms,
    774                         codes, &log_alpha_size);
    775     }
    776     if (log_alpha_size < 5) log_alpha_size = 5;
    777     if (params.streaming_mode) {
    778       // TODO(szabadka) Figure out if we can use lower values here.
    779       log_alpha_size = 8;
    780     }
    781     SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
    782     cost += 1;
    783     if (writer) writer->Write(1, TO_JXL_BOOL(codes->use_prefix_code));
    784 
    785     if (codes->use_prefix_code) {
    786       log_alpha_size = PREFIX_MAX_BITS;
    787     } else {
    788       cost += 2;
    789     }
    790     if (writer == nullptr) {
    791       EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
    792     } else {
    793       if (!codes->use_prefix_code) writer->Write(2, log_alpha_size - 5);
    794       EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
    795     }
    796     if (codes->use_prefix_code) {
    797       for (const auto& histo : clustered_histograms) {
    798         size_t alphabet_size = histo.alphabet_size();
    799         if (writer) {
    800           StoreVarLenUint16(alphabet_size - 1, writer);
    801         } else {
    802           StoreVarLenUint16(alphabet_size - 1, &size_writer);
    803         }
    804       }
    805     }
    806     cost += size_writer.size;
    807     for (size_t c = prev_histograms; c < clustered_histograms.size(); ++c) {
    808       size_t alphabet_size = clustered_histograms[c].alphabet_size();
    809       codes->encoding_info.emplace_back();
    810       codes->encoding_info.back().resize(alphabet_size);
    811       BitWriter* histo_writer = writer;
    812       if (params.streaming_mode) {
    813         codes->encoded_histograms.emplace_back();
    814         histo_writer = &codes->encoded_histograms.back();
    815       }
    816       BitWriter::Allotment allotment(histo_writer, 256 + alphabet_size * 24);
    817       cost += BuildAndStoreANSEncodingData(
    818           params.ans_histogram_strategy, clustered_histograms[c].data_.data(),
    819           alphabet_size, log_alpha_size, codes->use_prefix_code,
    820           codes->encoding_info.back().data(), histo_writer);
    821       allotment.FinishedHistogram(histo_writer);
    822       allotment.ReclaimAndCharge(histo_writer, layer, aux_out);
    823       if (params.streaming_mode) {
    824         writer->AppendUnaligned(*histo_writer);
    825       }
    826     }
    827     return cost;
    828   }
    829 
    830   const Histogram& Histo(size_t i) const { return histograms_[i]; }
    831 
    832  private:
    833   std::vector<Histogram> histograms_;
    834 };
    835 
    836 class SymbolCostEstimator {
    837  public:
    838   SymbolCostEstimator(size_t num_contexts, bool force_huffman,
    839                       const std::vector<std::vector<Token>>& tokens,
    840                       const LZ77Params& lz77) {
    841     HistogramBuilder builder(num_contexts);
    842     // Build histograms for estimating lz77 savings.
    843     HybridUintConfig uint_config;
    844     for (const auto& stream : tokens) {
    845       for (const auto& token : stream) {
    846         uint32_t tok, nbits, bits;
    847         (token.is_lz77_length ? lz77.length_uint_config : uint_config)
    848             .Encode(token.value, &tok, &nbits, &bits);
    849         tok += token.is_lz77_length ? lz77.min_symbol : 0;
    850         builder.VisitSymbol(tok, token.context);
    851       }
    852     }
    853     max_alphabet_size_ = 0;
    854     for (size_t i = 0; i < num_contexts; i++) {
    855       max_alphabet_size_ =
    856           std::max(max_alphabet_size_, builder.Histo(i).data_.size());
    857     }
    858     bits_.resize(num_contexts * max_alphabet_size_);
    859     // TODO(veluca): SIMD?
    860     add_symbol_cost_.resize(num_contexts);
    861     for (size_t i = 0; i < num_contexts; i++) {
    862       float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
    863       float total_cost = 0;
    864       for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
    865         size_t cnt = builder.Histo(i).data_[j];
    866         float cost = 0;
    867         if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
    868           cost = -FastLog2f(cnt * inv_total);
    869           if (force_huffman) cost = std::ceil(cost);
    870         } else if (cnt == 0) {
    871           cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
    872         }
    873         bits_[i * max_alphabet_size_ + j] = cost;
    874         total_cost += cost * builder.Histo(i).data_[j];
    875       }
    876       // Penalty for adding a lz77 symbol to this contest (only used for static
    877       // cost model). Higher penalty for contexts that have a very low
    878       // per-symbol entropy.
    879       add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
    880     }
    881   }
    882   float Bits(size_t ctx, size_t sym) const {
    883     return bits_[ctx * max_alphabet_size_ + sym];
    884   }
    885   float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
    886     uint32_t nbits, bits, tok;
    887     lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
    888     tok += lz77.min_symbol;
    889     return nbits + Bits(ctx, tok);
    890   }
    891   float DistCost(size_t len, const LZ77Params& lz77) const {
    892     uint32_t nbits, bits, tok;
    893     HybridUintConfig().Encode(len, &tok, &nbits, &bits);
    894     return nbits + Bits(lz77.nonserialized_distance_context, tok);
    895   }
    896   float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
    897 
    898  private:
    899   size_t max_alphabet_size_;
    900   std::vector<float> bits_;
    901   std::vector<float> add_symbol_cost_;
    902 };
    903 
    904 void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
    905                    const std::vector<std::vector<Token>>& tokens,
    906                    LZ77Params& lz77,
    907                    std::vector<std::vector<Token>>& tokens_lz77) {
    908   // TODO(veluca): tune heuristics here.
    909   SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
    910   float bit_decrease = 0;
    911   size_t total_symbols = 0;
    912   tokens_lz77.resize(tokens.size());
    913   std::vector<float> sym_cost;
    914   HybridUintConfig uint_config;
    915   for (size_t stream = 0; stream < tokens.size(); stream++) {
    916     size_t distance_multiplier =
    917         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    918     const auto& in = tokens[stream];
    919     auto& out = tokens_lz77[stream];
    920     total_symbols += in.size();
    921     // Cumulative sum of bit costs.
    922     sym_cost.resize(in.size() + 1);
    923     for (size_t i = 0; i < in.size(); i++) {
    924       uint32_t tok, nbits, unused_bits;
    925       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
    926       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    927     }
    928     out.reserve(in.size());
    929     for (size_t i = 0; i < in.size(); i++) {
    930       size_t num_to_copy = 0;
    931       size_t distance_symbol = 0;  // 1 for RLE.
    932       if (distance_multiplier != 0) {
    933         distance_symbol = 1;  // Special distance 1 if enabled.
    934         JXL_DASSERT(kSpecialDistances[1][0] == 1);
    935         JXL_DASSERT(kSpecialDistances[1][1] == 0);
    936       }
    937       if (i > 0) {
    938         for (; i + num_to_copy < in.size(); num_to_copy++) {
    939           if (in[i + num_to_copy].value != in[i - 1].value) {
    940             break;
    941           }
    942         }
    943       }
    944       if (num_to_copy == 0) {
    945         out.push_back(in[i]);
    946         continue;
    947       }
    948       float cost = sym_cost[i + num_to_copy] - sym_cost[i];
    949       // This subtraction might overflow, but that's OK.
    950       size_t lz77_len = num_to_copy - lz77.min_length;
    951       float lz77_cost = num_to_copy >= lz77.min_length
    952                             ? CeilLog2Nonzero(lz77_len + 1) + 1
    953                             : 0;
    954       if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
    955         for (size_t j = 0; j < num_to_copy; j++) {
    956           out.push_back(in[i + j]);
    957         }
    958         i += num_to_copy - 1;
    959         continue;
    960       }
    961       // Output the LZ77 length
    962       out.emplace_back(in[i].context, lz77_len);
    963       out.back().is_lz77_length = true;
    964       i += num_to_copy - 1;
    965       bit_decrease += cost - lz77_cost;
    966       // Output the LZ77 copy distance.
    967       out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
    968     }
    969   }
    970 
    971   if (bit_decrease > total_symbols * 0.2 + 16) {
    972     lz77.enabled = true;
    973   }
    974 }
    975 
// Hash chain for LZ77 matching
//
// Keeps a private copy of the token values plus, for a sliding window of
// `window_size` positions, linked lists of positions that share the same hash
// of three consecutive values. A second set of lists groups positions by the
// zero-run counter maintained in Update() to speed up long repetitions.
struct HashChain {
  // Number of values in `data_`.
  size_t size_;
  // Copy of the `value` field of every input token.
  std::vector<uint32_t> data_;

  // Hash parameters: GetHash() results are masked with `hash_mask_`, and the
  // three hashed values are shifted by 0, `hash_shift_` and 2 * `hash_shift_`.
  unsigned hash_num_values_ = 32768;
  unsigned hash_mask_ = hash_num_values_ - 1;
  unsigned hash_shift_ = 5;

  // head[h]: window position most recently inserted with hash h (-1 if none).
  std::vector<int> head;
  // chain[wpos]: window position that was the head for the same hash when
  // wpos was inserted; equal to wpos itself when there is no such link.
  std::vector<uint32_t> chain;
  // val[wpos]: hash value stored for window position wpos (-1 if unused).
  std::vector<int> val;

  // Speed up repetitions of zero
  // Same layout as head/chain, but keyed by the `numzeros` value of a position.
  std::vector<int> headz;
  std::vector<uint32_t> chainz;
  // zeros[wpos]: `numzeros` recorded for window position wpos.
  std::vector<uint32_t> zeros;
  // Zero-run counter of the most recently updated position.
  uint32_t numzeros = 0;

  size_t window_size_;
  // window_size_ - 1; used as a bit mask, so the window size is presumably
  // expected to be a power of two (the visible callers pass one).
  size_t window_mask_;
  size_t min_length_;
  size_t max_length_;

  // Map of special distance codes.
  std::unordered_map<int, int> special_dist_table_;
  // Offset added to plain distances in FindMatches(); 0 when no special
  // distances are in use.
  size_t num_special_distances_ = 0;

  uint32_t maxchainlength = 256;  // window_size_ to allow all

  // Copies the values of `data[0..size)` and initializes all tables as empty.
  // `min_length` / `max_length` bound the match lengths reported by
  // FindMatches(). A non-zero `distance_multiplier` enables the special
  // distance codes.
  HashChain(const Token* data, size_t size, size_t window_size,
            size_t min_length, size_t max_length, size_t distance_multiplier)
      : size_(size),
        window_size_(window_size),
        window_mask_(window_size - 1),
        min_length_(min_length),
        max_length_(max_length) {
    data_.resize(size);
    for (size_t i = 0; i < size; i++) {
      data_[i] = data[i].value;
    }

    head.resize(hash_num_values_, -1);
    val.resize(window_size_, -1);
    chain.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chain[i] = i;  // same value as index indicates uninitialized
    }

    zeros.resize(window_size_);
    // One extra entry because numzeros can be as large as window_size_.
    headz.resize(window_size_ + 1, -1);
    chainz.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chainz[i] = i;
    }
    // Translate distance to special distance code.
    if (distance_multiplier) {
      // Count down, so if due to small distance multiplier multiple distances
      // map to the same code, the smallest code will be used in the end.
      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
        special_dist_table_[SpecialDistance(i, distance_multiplier)] = i;
      }
      num_special_distances_ = kNumSpecialDistances;
    }
  }

  // Hash of the three values starting at `pos`, masked with `hash_mask_`.
  // Returns 0 when fewer than three values remain.
  uint32_t GetHash(size_t pos) const {
    uint32_t result = 0;
    if (pos + 2 < size_) {
      // TODO(lode): take the MSB's of the uint32_t values into account as well,
      // given that the hash code itself is less than 32 bits.
      result ^= static_cast<uint32_t>(data_[pos + 0] << 0u);
      result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_);
      result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2));
    } else {
      // No need to compute hash of last 2 bytes, the length 2 is too short.
      return 0;
    }
    return result & hash_mask_;
  }

  // Zero-run counter for `pos`, given the counter `prevzeros` of the previous
  // position. With prevzeros == 0 the zero values starting at `pos` are
  // counted directly, limited to window_size_ values and the end of the data.
  // Otherwise the previous counter is reused: kept as is when it is at least
  // window_mask_ and the full window starting at `pos` lies inside the data
  // and ends in a zero, decremented by one in every other case.
  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
    size_t end = pos + window_size_;
    if (end > size_) end = size_;
    if (prevzeros > 0) {
      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
          end == pos + window_size_) {
        return prevzeros;
      } else {
        return prevzeros - 1;
      }
    }
    uint32_t num = 0;
    while (pos + num < end && data_[pos + num] == 0) num++;
    return num;
  }

  // Inserts position `pos` into the hash chain and into the zero-run chain.
  // Updates `numzeros` to the counter of `pos`.
  void Update(size_t pos) {
    uint32_t hashval = GetHash(pos);
    uint32_t wpos = pos & window_mask_;

    val[wpos] = static_cast<int>(hashval);
    // Link to the previous head of this hash, then become the new head.
    if (head[hashval] != -1) chain[wpos] = head[hashval];
    head[hashval] = wpos;

    // A change of value restarts the counter before it is recomputed.
    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
    numzeros = CountZeros(pos, numzeros);

    zeros[wpos] = numzeros;
    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
    headz[numzeros] = wpos;
  }

  // Calls Update() for the `len` consecutive positions starting at `pos`.
  void Update(size_t pos, size_t len) {
    for (size_t i = 0; i < len; i++) {
      Update(pos + i);
    }
  }

  // Walks the chain of earlier positions linked from `pos` and calls
  // `found_match(len, dist_symbol)` for candidate matches of at least
  // `min_length_` values. `dist_symbol` is the special distance code when the
  // distance has one, otherwise `num_special_distances_ + dist - 1`.
  // At most `maxchainlength` chain entries are visited.
  // NOTE(review): `max_dist` is not referenced in the body.
  template <typename CB>
  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
    uint32_t wpos = pos & window_mask_;
    uint32_t hashval = GetHash(pos);
    uint32_t hashpos = chain[wpos];

    int prev_dist = 0;
    // Matches never extend past `pos + max_length_` or the end of the data.
    int end = std::min<int>(pos + max_length_, size_);
    uint32_t chainlength = 0;
    uint32_t best_len = 0;
    for (;;) {
      // Distance inside the circular window from the candidate to `pos`.
      int dist = (hashpos <= wpos) ? (wpos - hashpos)
                                   : (wpos - hashpos + window_mask_ + 1);
      // Stop as soon as the distance decreases (presumably the chain has
      // reached an entry that was overwritten by a newer position).
      if (dist < prev_dist) break;
      prev_dist = dist;
      uint32_t len = 0;
      if (dist > 0) {
        int i = pos;
        int j = pos - dist;
        if (numzeros > 3) {
          // Skip ahead in both sequences by the shorter of the two recorded
          // zero-run counters (presumably those values are known to match).
          int r = std::min<int>(numzeros - 1, zeros[hashpos]);
          if (i + r >= end) r = end - i - 1;
          i += r;
          j += r;
        }
        while (i < end && data_[i] == data_[j]) {
          i++;
          j++;
        }
        len = i - pos;
        // This can trigger even if the new length is slightly smaller than the
        // best length, because it is possible for a slightly cheaper distance
        // symbol to occur.
        if (len >= min_length_ && len + 2 >= best_len) {
          auto it = special_dist_table_.find(dist);
          int dist_symbol = (it == special_dist_table_.end())
                                ? (num_special_distances_ + dist - 1)
                                : it->second;
          found_match(len, dist_symbol);
          if (len > best_len) best_len = len;
        }
      }

      chainlength++;
      if (chainlength >= maxchainlength) break;

      if (numzeros >= 3 && len > numzeros) {
        // Follow the zero-run chain while it stays on the same counter value.
        if (hashpos == chainz[hashpos]) break;
        hashpos = chainz[hashpos];
        if (zeros[hashpos] != numzeros) break;
      } else {
        // Follow the regular hash chain.
        if (hashpos == chain[hashpos]) break;
        hashpos = chain[hashpos];
        if (val[hashpos] != static_cast<int>(hashval)) {
          // outdated hash value
          break;
        }
      }
    }
  }

  // Reports the longest candidate found by FindMatches(); among candidates of
  // equal length the one with the smaller distance symbol wins. When nothing
  // is found the outputs keep their initial values: length 1, symbol 0.
  void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
                 size_t* result_len) const {
    *result_dist_symbol = 0;
    *result_len = 1;
    FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
      if (len > *result_len ||
          (len == *result_len && *result_dist_symbol > dist_symbol)) {
        *result_len = len;
        *result_dist_symbol = dist_symbol;
      }
    });
  }
};
   1168 
   1169 float LenCost(size_t len) {
   1170   uint32_t nbits, bits, tok;
   1171   HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
   1172   constexpr float kCostTable[] = {
   1173       2.797667318563126,  3.213177690381199,  2.5706009246743737,
   1174       2.408392498667534,  2.829649191872326,  3.3923087753324577,
   1175       4.029267451554331,  4.415576699706408,  4.509357574741465,
   1176       9.21481543803004,   10.020590190114898, 11.858671627804766,
   1177       12.45853300490526,  11.713105831990857, 12.561996324849314,
   1178       13.775477692278367, 13.174027068768641,
   1179   };
   1180   size_t table_size = sizeof kCostTable / sizeof *kCostTable;
   1181   if (tok >= table_size) tok = table_size - 1;
   1182   return kCostTable[tok] + nbits;
   1183 }
   1184 
   1185 // TODO(veluca): this does not take into account usage or non-usage of distance
   1186 // multipliers.
   1187 float DistCost(size_t dist) {
   1188   uint32_t nbits, bits, tok;
   1189   HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
   1190   constexpr float kCostTable[] = {
   1191       6.368282626312716,  5.680793277090298,  8.347404197105247,
   1192       7.641619201599141,  6.914328374119438,  7.959808291537444,
   1193       8.70023120759855,   8.71378518934703,   9.379132523982769,
   1194       9.110472749092708,  9.159029569270908,  9.430936766731973,
   1195       7.278284055315169,  7.8278514904267755, 10.026641158289236,
   1196       9.976049229827066,  9.64351607048908,   9.563403863480442,
   1197       10.171474111762747, 10.45950155077234,  9.994813912104219,
   1198       10.322524683741156, 8.465808729388186,  8.756254166066853,
   1199       10.160930174662234, 10.247329273413435, 10.04090403724809,
   1200       10.129398517544082, 9.342311691539546,  9.07608009102374,
   1201       10.104799540677513, 10.378079384990906, 10.165828974075072,
   1202       10.337595322341553, 7.940557464567944,  10.575665823319431,
   1203       11.023344321751955, 10.736144698831827, 11.118277044595054,
   1204       7.468468230648442,  10.738305230932939, 10.906980780216568,
   1205       10.163468216353817, 10.17805759656433,  11.167283670483565,
   1206       11.147050200274544, 10.517921919244333, 10.651764778156886,
   1207       10.17074446448919,  11.217636876224745, 11.261630721139484,
   1208       11.403140815247259, 10.892472096873417, 11.1859607804481,
   1209       8.017346947551262,  7.895143720278828,  11.036577113822025,
   1210       11.170562110315794, 10.326988722591086, 10.40872184751056,
   1211       11.213498225466386, 11.30580635516863,  10.672272515665442,
   1212       10.768069466228063, 11.145257364153565, 11.64668307145549,
   1213       10.593156194627339, 11.207499484844943, 10.767517766396908,
   1214       10.826629811407042, 10.737764794499988, 10.6200448518045,
   1215       10.191315385198092, 8.468384171390085,  11.731295299170432,
   1216       11.824619886654398, 10.41518844301179,  10.16310536548649,
   1217       10.539423685097576, 10.495136599328031, 10.469112847728267,
   1218       11.72057686174922,  10.910326337834674, 11.378921834673758,
   1219       11.847759036098536, 11.92071647623854,  10.810628276345282,
   1220       11.008601085273893, 11.910326337834674, 11.949212023423133,
   1221       11.298614839104337, 11.611603659010392, 10.472930394619985,
   1222       11.835564720850282, 11.523267392285337, 12.01055816679611,
   1223       8.413029688994023,  11.895784139536406, 11.984679534970505,
   1224       11.220654278717394, 11.716311684833672, 10.61036646226114,
   1225       10.89849965960364,  10.203762898863669, 10.997560826267238,
   1226       11.484217379438984, 11.792836176993665, 12.24310468755171,
   1227       11.464858097919262, 12.212747017409377, 11.425595666074955,
   1228       11.572048533398757, 12.742093965163013, 11.381874288645637,
   1229       12.191870445817015, 11.683156920035426, 11.152442115262197,
   1230       11.90303691580457,  11.653292787169159, 11.938615382266098,
   1231       16.970641701570223, 16.853602280380002, 17.26240782594733,
   1232       16.644655390108507, 17.14310889757499,  16.910935455445955,
   1233       17.505678976959697, 17.213498225466388, 2.4162310293553024,
   1234       3.494587244462329,  3.5258600986408344, 3.4959806589517095,
   1235       3.098390886949687,  3.343454654302911,  3.588847442290287,
   1236       4.14614790111827,   5.152948641990529,  7.433696808092598,
   1237       9.716311684833672,
   1238   };
   1239   size_t table_size = sizeof kCostTable / sizeof *kCostTable;
   1240   if (tok >= table_size) tok = table_size - 1;
   1241   return kCostTable[tok] + nbits;
   1242 }
   1243 
// Greedy LZ77 transform with one step of lazy matching.
//
// For every stream of `tokens`, matches are searched with a HashChain. A match
// of at least `lz77.min_length` values is turned into an LZ77 (length,
// distance) token pair when its estimated cost (static LenCost + DistCost
// tables plus the per-context penalty of the SymbolCostEstimator) does not
// exceed the estimated cost of the literals it replaces. The result is written
// to `tokens_lz77`. `lz77.enabled` is set to true when the summed estimated
// saving exceeds `total_symbols * 0.2 + 16`.
void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
                    const std::vector<std::vector<Token>>& tokens,
                    LZ77Params& lz77,
                    std::vector<std::vector<Token>>& tokens_lz77) {
  // TODO(veluca): tune heuristics here.
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
  float bit_decrease = 0;
  size_t total_symbols = 0;
  tokens_lz77.resize(tokens.size());
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    // 0 when no image width is available for this stream.
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    total_symbols += in.size();
    // Cumulative sum of bit costs.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_ASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size.
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);
    size_t len;
    size_t dist_symbol;

    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching

    // Whether the next symbol was already updated (to test lazy matching)
    bool already_updated = false;
    for (size_t i = 0; i < in.size(); i++) {
      // The current token is always pushed first; if a match is accepted it is
      // rewritten in place into the LZ77 length token.
      out.push_back(in[i]);
      if (!already_updated) chain.Update(i);
      already_updated = false;
      chain.FindMatch(i, max_distance, &dist_symbol, &len);
      if (len >= min_length) {
        if (len < max_lazy_match_len && i + 1 < in.size()) {
          // Try length at next symbol lazy matching
          chain.Update(i + 1);
          already_updated = true;
          size_t len2, dist_symbol2;
          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
          if (len2 > len) {
            // Use the lazy match. Add literal, and use the next length starting
            // from the next byte.
            ++i;
            already_updated = false;
            len = len2;
            dist_symbol = dist_symbol2;
            out.push_back(in[i]);
          }
        }

        // Estimated cost of the `len` literals versus the LZ77 pair.
        float cost = sym_cost[i + len] - sym_cost[i];
        size_t lz77_len = len - lz77.min_length;
        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
                          sce.AddSymbolCost(out.back().context);

        if (lz77_cost <= cost) {
          out.back().value = len - min_length;
          out.back().is_lz77_length = true;
          out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
          bit_decrease += cost - lz77_cost;
        } else {
          // LZ77 match ignored, and symbol already pushed. Push all other
          // symbols and skip.
          for (size_t j = 1; j < len; j++) {
            out.push_back(in[i + j]);
          }
        }

        // Insert the remaining positions covered by the match into the hash
        // chain; position i + 1 is skipped if the lazy probe already added it.
        if (already_updated) {
          chain.Update(i + 2, len - 2);
          already_updated = false;
        } else {
          chain.Update(i + 1, len - 1);
        }
        i += len - 1;
      } else {
        // Literal, already pushed
      }
    }
  }

  if (bit_decrease > total_symbols * 0.2 + 16) {
    lz77.enabled = true;
  }
}
   1348 
   1349 void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
   1350                        const std::vector<std::vector<Token>>& tokens,
   1351                        LZ77Params& lz77,
   1352                        std::vector<std::vector<Token>>& tokens_lz77) {
   1353   std::vector<std::vector<Token>> tokens_for_cost_estimate;
   1354   ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
   1355   // If greedy-LZ77 does not give better compression than no-lz77, no reason to
   1356   // run the optimal matching.
   1357   if (!lz77.enabled) return;
   1358   SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
   1359                           tokens_for_cost_estimate, lz77);
   1360   tokens_lz77.resize(tokens.size());
   1361   HybridUintConfig uint_config;
   1362   std::vector<float> sym_cost;
   1363   std::vector<uint32_t> dist_symbols;
   1364   for (size_t stream = 0; stream < tokens.size(); stream++) {
   1365     size_t distance_multiplier =
   1366         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
   1367     const auto& in = tokens[stream];
   1368     auto& out = tokens_lz77[stream];
   1369     // Cumulative sum of bit costs.
   1370     sym_cost.resize(in.size() + 1);
   1371     for (size_t i = 0; i < in.size(); i++) {
   1372       uint32_t tok, nbits, unused_bits;
   1373       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
   1374       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
   1375     }
   1376 
   1377     out.reserve(in.size());
   1378     size_t max_distance = in.size();
   1379     size_t min_length = lz77.min_length;
   1380     JXL_ASSERT(min_length >= 3);
   1381     size_t max_length = in.size();
   1382 
   1383     // Use next power of two as window size.
   1384     size_t window_size = 1;
   1385     while (window_size < max_distance && window_size < kWindowSize) {
   1386       window_size <<= 1;
   1387     }
   1388 
   1389     HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
   1390                     distance_multiplier);
   1391 
   1392     struct MatchInfo {
   1393       uint32_t len;
   1394       uint32_t dist_symbol;
   1395       uint32_t ctx;
   1396       float total_cost = std::numeric_limits<float>::max();
   1397     };
   1398     // Total cost to encode the first N symbols.
   1399     std::vector<MatchInfo> prefix_costs(in.size() + 1);
   1400     prefix_costs[0].total_cost = 0;
   1401 
   1402     size_t rle_length = 0;
   1403     size_t skip_lz77 = 0;
   1404     for (size_t i = 0; i < in.size(); i++) {
   1405       chain.Update(i);
   1406       float lit_cost =
   1407           prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
   1408       if (prefix_costs[i + 1].total_cost > lit_cost) {
   1409         prefix_costs[i + 1].dist_symbol = 0;
   1410         prefix_costs[i + 1].len = 1;
   1411         prefix_costs[i + 1].ctx = in[i].context;
   1412         prefix_costs[i + 1].total_cost = lit_cost;
   1413       }
   1414       if (skip_lz77 > 0) {
   1415         skip_lz77--;
   1416         continue;
   1417       }
   1418       dist_symbols.clear();
   1419       chain.FindMatches(i, max_distance,
   1420                         [&dist_symbols](size_t len, size_t dist_symbol) {
   1421                           if (dist_symbols.size() <= len) {
   1422                             dist_symbols.resize(len + 1, dist_symbol);
   1423                           }
   1424                           if (dist_symbol < dist_symbols[len]) {
   1425                             dist_symbols[len] = dist_symbol;
   1426                           }
   1427                         });
   1428       if (dist_symbols.size() <= min_length) continue;
   1429       {
   1430         size_t best_cost = dist_symbols.back();
   1431         for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
   1432           if (dist_symbols[j] < best_cost) {
   1433             best_cost = dist_symbols[j];
   1434           }
   1435           dist_symbols[j] = best_cost;
   1436         }
   1437       }
   1438       for (size_t j = min_length; j < dist_symbols.size(); j++) {
   1439         // Cost model that uses results from lazy LZ77.
   1440         float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
   1441                           sce.DistCost(dist_symbols[j], lz77);
   1442         float cost = prefix_costs[i].total_cost + lz77_cost;
   1443         if (prefix_costs[i + j].total_cost > cost) {
   1444           prefix_costs[i + j].len = j;
   1445           prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
   1446           prefix_costs[i + j].ctx = in[i].context;
   1447           prefix_costs[i + j].total_cost = cost;
   1448         }
   1449       }
   1450       // We are in a RLE sequence: skip all the symbols except the first 8 and
   1451       // the last 8. This avoids quadratic costs for sequences with long runs of
   1452       // the same symbol.
   1453       if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
   1454           (dist_symbols.back() == 1 && distance_multiplier != 0)) {
   1455         rle_length++;
   1456       } else {
   1457         rle_length = 0;
   1458       }
   1459       if (rle_length >= 8 && dist_symbols.size() > 9) {
   1460         skip_lz77 = dist_symbols.size() - 10;
   1461         rle_length = 0;
   1462       }
   1463     }
   1464     size_t pos = in.size();
   1465     while (pos > 0) {
   1466       bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
   1467       if (is_lz77_length) {
   1468         size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
   1469         out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
   1470       }
   1471       size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
   1472                                   : in[pos - 1].value;
   1473       out.emplace_back(prefix_costs[pos].ctx, val);
   1474       out.back().is_lz77_length = is_lz77_length;
   1475       pos -= prefix_costs[pos].len;
   1476     }
   1477     std::reverse(out.begin(), out.end());
   1478   }
   1479 }
   1480 
   1481 void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
   1482                const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
   1483                std::vector<std::vector<Token>>& tokens_lz77) {
   1484   if (params.initialize_global_state) {
   1485     lz77.enabled = false;
   1486   }
   1487   if (params.force_huffman) {
   1488     lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
   1489   } else {
   1490     lz77.min_symbol = 224;
   1491   }
   1492   if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
   1493     return;
   1494   } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
   1495     ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
   1496   } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
   1497     ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
   1498   } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
   1499     ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
   1500   } else {
   1501     JXL_UNREACHABLE("Not implemented");
   1502   }
   1503 }
   1504 }  // namespace
   1505 
   1506 void EncodeHistograms(const std::vector<uint8_t>& context_map,
   1507                       const EntropyEncodingData& codes, BitWriter* writer,
   1508                       size_t layer, AuxOut* aux_out) {
   1509   BitWriter::Allotment allotment(writer, 128 + kClustersLimit * 136);
   1510   JXL_CHECK(Bundle::Write(codes.lz77, writer, layer, aux_out));
   1511   if (codes.lz77.enabled) {
   1512     EncodeUintConfig(codes.lz77.length_uint_config, writer,
   1513                      /*log_alpha_size=*/8);
   1514   }
   1515   EncodeContextMap(context_map, codes.encoding_info.size(), writer, layer,
   1516                    aux_out);
   1517   writer->Write(1, TO_JXL_BOOL(codes.use_prefix_code));
   1518   size_t log_alpha_size = 8;
   1519   if (codes.use_prefix_code) {
   1520     log_alpha_size = PREFIX_MAX_BITS;
   1521   } else {
   1522     log_alpha_size = 8;  // streaming_mode
   1523     writer->Write(2, log_alpha_size - 5);
   1524   }
   1525   EncodeUintConfigs(codes.uint_config, writer, log_alpha_size);
   1526   if (codes.use_prefix_code) {
   1527     for (const auto& info : codes.encoding_info) {
   1528       StoreVarLenUint16(info.size() - 1, writer);
   1529     }
   1530   }
   1531   for (const auto& histo_writer : codes.encoded_histograms) {
   1532     writer->AppendUnaligned(histo_writer);
   1533   }
   1534   allotment.FinishedHistogram(writer);
   1535   allotment.ReclaimAndCharge(writer, layer, aux_out);
   1536 }
   1537 
   1538 size_t BuildAndEncodeHistograms(const HistogramParams& params,
   1539                                 size_t num_contexts,
   1540                                 std::vector<std::vector<Token>>& tokens,
   1541                                 EntropyEncodingData* codes,
   1542                                 std::vector<uint8_t>* context_map,
   1543                                 BitWriter* writer, size_t layer,
   1544                                 AuxOut* aux_out) {
   1545   size_t total_bits = 0;
   1546   codes->lz77.nonserialized_distance_context = num_contexts;
   1547   std::vector<std::vector<Token>> tokens_lz77;
   1548   ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
   1549   if (ans_fuzzer_friendly_) {
   1550     codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
   1551     codes->lz77.min_symbol = 2048;
   1552   }
   1553 
   1554   const size_t max_contexts = std::min(num_contexts, kClustersLimit);
   1555   BitWriter::Allotment allotment(writer,
   1556                                  128 + num_contexts * 40 + max_contexts * 96);
   1557   if (writer) {
   1558     JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
   1559   } else {
   1560     size_t ebits, bits;
   1561     JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
   1562     total_bits += bits;
   1563   }
   1564   if (codes->lz77.enabled) {
   1565     if (writer) {
   1566       size_t b = writer->BitsWritten();
   1567       EncodeUintConfig(codes->lz77.length_uint_config, writer,
   1568                        /*log_alpha_size=*/8);
   1569       total_bits += writer->BitsWritten() - b;
   1570     } else {
   1571       SizeWriter size_writer;
   1572       EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
   1573                        /*log_alpha_size=*/8);
   1574       total_bits += size_writer.size;
   1575     }
   1576     num_contexts += 1;
   1577     tokens = std::move(tokens_lz77);
   1578   }
   1579   size_t total_tokens = 0;
   1580   // Build histograms.
   1581   HistogramBuilder builder(num_contexts);
   1582   HybridUintConfig uint_config;  //  Default config for clustering.
   1583   // Unless we are using the kContextMap histogram option.
   1584   if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
   1585     uint_config = HybridUintConfig(2, 0, 1);
   1586   }
   1587   if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
   1588     uint_config = HybridUintConfig(0, 0, 0);
   1589   }
   1590   if (ans_fuzzer_friendly_) {
   1591     uint_config = HybridUintConfig(10, 0, 0);
   1592   }
   1593   for (const auto& stream : tokens) {
   1594     if (codes->lz77.enabled) {
   1595       for (const auto& token : stream) {
   1596         total_tokens++;
   1597         uint32_t tok, nbits, bits;
   1598         (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
   1599             .Encode(token.value, &tok, &nbits, &bits);
   1600         tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
   1601         builder.VisitSymbol(tok, token.context);
   1602       }
   1603     } else if (num_contexts == 1) {
   1604       for (const auto& token : stream) {
   1605         total_tokens++;
   1606         uint32_t tok, nbits, bits;
   1607         uint_config.Encode(token.value, &tok, &nbits, &bits);
   1608         builder.VisitSymbol(tok, /*token.context=*/0);
   1609       }
   1610     } else {
   1611       for (const auto& token : stream) {
   1612         total_tokens++;
   1613         uint32_t tok, nbits, bits;
   1614         uint_config.Encode(token.value, &tok, &nbits, &bits);
   1615         builder.VisitSymbol(tok, token.context);
   1616       }
   1617     }
   1618   }
   1619 
   1620   if (params.add_missing_symbols) {
   1621     for (size_t c = 0; c < num_contexts; ++c) {
   1622       for (int symbol = 0; symbol < ANS_MAX_ALPHABET_SIZE; ++symbol) {
   1623         builder.VisitSymbol(symbol, c);
   1624       }
   1625     }
   1626   }
   1627 
   1628   if (params.initialize_global_state) {
   1629     bool use_prefix_code =
   1630         params.force_huffman || total_tokens < 100 ||
   1631         params.clustering == HistogramParams::ClusteringType::kFastest ||
   1632         ans_fuzzer_friendly_;
   1633     if (!use_prefix_code) {
   1634       bool all_singleton = true;
   1635       for (size_t i = 0; i < num_contexts; i++) {
   1636         if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
   1637           all_singleton = false;
   1638         }
   1639       }
   1640       if (all_singleton) {
   1641         use_prefix_code = true;
   1642       }
   1643     }
   1644     codes->use_prefix_code = use_prefix_code;
   1645   }
   1646 
   1647   if (params.add_fixed_histograms) {
   1648     // TODO(szabadka) Add more fixed histograms.
   1649     // TODO(szabadka) Reduce alphabet size by choosing a non-default
   1650     // uint_config.
   1651     const size_t alphabet_size = ANS_MAX_ALPHABET_SIZE;
   1652     const size_t log_alpha_size = 8;
   1653     JXL_ASSERT(alphabet_size == 1u << log_alpha_size);
   1654     std::vector<int32_t> counts =
   1655         CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
   1656     codes->encoding_info.emplace_back();
   1657     codes->encoding_info.back().resize(alphabet_size);
   1658     codes->encoded_histograms.emplace_back();
   1659     BitWriter* histo_writer = &codes->encoded_histograms.back();
   1660     BitWriter::Allotment allotment(histo_writer, 256 + alphabet_size * 24);
   1661     BuildAndStoreANSEncodingData(
   1662         params.ans_histogram_strategy, counts.data(), alphabet_size,
   1663         log_alpha_size, codes->use_prefix_code,
   1664         codes->encoding_info.back().data(), histo_writer);
   1665     allotment.ReclaimAndCharge(histo_writer, 0, nullptr);
   1666   }
   1667 
   1668   // Encode histograms.
   1669   total_bits += builder.BuildAndStoreEntropyCodes(
   1670       params, tokens, codes, context_map, writer, layer, aux_out);
   1671   allotment.FinishedHistogram(writer);
   1672   allotment.ReclaimAndCharge(writer, layer, aux_out);
   1673 
   1674   if (aux_out != nullptr) {
   1675     aux_out->layers[layer].num_clustered_histograms +=
   1676         codes->encoding_info.size();
   1677   }
   1678   return total_bits;
   1679 }
   1680 
   1681 size_t WriteTokens(const std::vector<Token>& tokens,
   1682                    const EntropyEncodingData& codes,
   1683                    const std::vector<uint8_t>& context_map,
   1684                    size_t context_offset, BitWriter* writer) {
   1685   size_t num_extra_bits = 0;
   1686   if (codes.use_prefix_code) {
   1687     for (const auto& token : tokens) {
   1688       uint32_t tok, nbits, bits;
   1689       size_t histo = context_map[context_offset + token.context];
   1690       (token.is_lz77_length ? codes.lz77.length_uint_config
   1691                             : codes.uint_config[histo])
   1692           .Encode(token.value, &tok, &nbits, &bits);
   1693       tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
   1694       // Combine two calls to the BitWriter. Equivalent to:
   1695       // writer->Write(codes.encoding_info[histo][tok].depth,
   1696       //               codes.encoding_info[histo][tok].bits);
   1697       // writer->Write(nbits, bits);
   1698       uint64_t data = codes.encoding_info[histo][tok].bits;
   1699       data |= static_cast<uint64_t>(bits)
   1700               << codes.encoding_info[histo][tok].depth;
   1701       writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
   1702       num_extra_bits += nbits;
   1703     }
   1704     return num_extra_bits;
   1705   }
   1706   std::vector<uint64_t> out;
   1707   std::vector<uint8_t> out_nbits;
   1708   out.reserve(tokens.size());
   1709   out_nbits.reserve(tokens.size());
   1710   uint64_t allbits = 0;
   1711   size_t numallbits = 0;
   1712   // Writes in *reversed* order.
   1713   auto addbits = [&](size_t bits, size_t nbits) {
   1714     if (JXL_UNLIKELY(nbits)) {
   1715       JXL_DASSERT(bits >> nbits == 0);
   1716       if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
   1717         out.push_back(allbits);
   1718         out_nbits.push_back(numallbits);
   1719         numallbits = allbits = 0;
   1720       }
   1721       allbits <<= nbits;
   1722       allbits |= bits;
   1723       numallbits += nbits;
   1724     }
   1725   };
   1726   const int end = tokens.size();
   1727   ANSCoder ans;
   1728   if (codes.lz77.enabled || context_map.size() > 1) {
   1729     for (int i = end - 1; i >= 0; --i) {
   1730       const Token token = tokens[i];
   1731       const uint8_t histo = context_map[context_offset + token.context];
   1732       uint32_t tok, nbits, bits;
   1733       (token.is_lz77_length ? codes.lz77.length_uint_config
   1734                             : codes.uint_config[histo])
   1735           .Encode(tokens[i].value, &tok, &nbits, &bits);
   1736       tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
   1737       const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
   1738       JXL_DASSERT(info.freq_ > 0);
   1739       // Extra bits first as this is reversed.
   1740       addbits(bits, nbits);
   1741       num_extra_bits += nbits;
   1742       uint8_t ans_nbits = 0;
   1743       uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
   1744       addbits(ans_bits, ans_nbits);
   1745     }
   1746   } else {
   1747     for (int i = end - 1; i >= 0; --i) {
   1748       uint32_t tok, nbits, bits;
   1749       codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits);
   1750       const ANSEncSymbolInfo& info = codes.encoding_info[0][tok];
   1751       // Extra bits first as this is reversed.
   1752       addbits(bits, nbits);
   1753       num_extra_bits += nbits;
   1754       uint8_t ans_nbits = 0;
   1755       uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
   1756       addbits(ans_bits, ans_nbits);
   1757     }
   1758   }
   1759   const uint32_t state = ans.GetState();
   1760   writer->Write(32, state);
   1761   writer->Write(numallbits, allbits);
   1762   for (int i = out.size(); i > 0; --i) {
   1763     writer->Write(out_nbits[i - 1], out[i - 1]);
   1764   }
   1765   return num_extra_bits;
   1766 }
   1767 
   1768 void WriteTokens(const std::vector<Token>& tokens,
   1769                  const EntropyEncodingData& codes,
   1770                  const std::vector<uint8_t>& context_map, size_t context_offset,
   1771                  BitWriter* writer, size_t layer, AuxOut* aux_out) {
   1772   // Theoretically, we could have 15 prefix code bits + 31 extra bits.
   1773   BitWriter::Allotment allotment(writer, 46 * tokens.size() + 32 * 1024 * 4);
   1774   size_t num_extra_bits =
   1775       WriteTokens(tokens, codes, context_map, context_offset, writer);
   1776   allotment.ReclaimAndCharge(writer, layer, aux_out);
   1777   if (aux_out != nullptr) {
   1778     aux_out->layers[layer].extra_bits += num_extra_bits;
   1779   }
   1780 }
   1781 
   1782 void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
   1783 #if JXL_IS_DEBUG_BUILD  // Guard against accidental / malicious changes.
   1784   ans_fuzzer_friendly_ = ans_fuzzer_friendly;
   1785 #endif
   1786 }
   1787 
   1788 HistogramParams HistogramParams::ForModular(
   1789     const CompressParams& cparams,
   1790     const std::vector<uint8_t>& extra_dc_precision, bool streaming_mode) {
   1791   HistogramParams params;
   1792   params.streaming_mode = streaming_mode;
   1793   if (cparams.speed_tier > SpeedTier::kKitten) {
   1794     params.clustering = HistogramParams::ClusteringType::kFast;
   1795     params.ans_histogram_strategy =
   1796         cparams.speed_tier > SpeedTier::kThunder
   1797             ? HistogramParams::ANSHistogramStrategy::kFast
   1798             : HistogramParams::ANSHistogramStrategy::kApproximate;
   1799     params.lz77_method =
   1800         cparams.decoding_speed_tier >= 3 && cparams.modular_mode
   1801             ? (cparams.speed_tier >= SpeedTier::kFalcon
   1802                    ? HistogramParams::LZ77Method::kRLE
   1803                    : HistogramParams::LZ77Method::kLZ77)
   1804             : HistogramParams::LZ77Method::kNone;
   1805     // Near-lossless DC, as well as modular mode, require choosing hybrid uint
   1806     // more carefully.
   1807     if ((!extra_dc_precision.empty() && extra_dc_precision[0] != 0) ||
   1808         (cparams.modular_mode && cparams.speed_tier < SpeedTier::kCheetah)) {
   1809       params.uint_method = HistogramParams::HybridUintMethod::kFast;
   1810     } else {
   1811       params.uint_method = HistogramParams::HybridUintMethod::kNone;
   1812     }
   1813   } else if (cparams.speed_tier <= SpeedTier::kTortoise) {
   1814     params.lz77_method = HistogramParams::LZ77Method::kOptimal;
   1815   } else {
   1816     params.lz77_method = HistogramParams::LZ77Method::kLZ77;
   1817   }
   1818   if (cparams.decoding_speed_tier >= 1) {
   1819     params.max_histograms = 12;
   1820   }
   1821   if (cparams.decoding_speed_tier >= 1 && cparams.responsive) {
   1822     params.lz77_method = cparams.speed_tier >= SpeedTier::kCheetah
   1823                              ? HistogramParams::LZ77Method::kRLE
   1824                          : cparams.speed_tier >= SpeedTier::kKitten
   1825                              ? HistogramParams::LZ77Method::kLZ77
   1826                              : HistogramParams::LZ77Method::kOptimal;
   1827   }
   1828   if (cparams.decoding_speed_tier >= 2 && cparams.responsive) {
   1829     params.uint_method = HistogramParams::HybridUintMethod::k000;
   1830     params.force_huffman = true;
   1831   }
   1832   return params;
   1833 }
   1834 }  // namespace jxl