enc_ans.cc (68365B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/enc_ans.h" 7 8 #include <jxl/types.h> 9 #include <stdint.h> 10 11 #include <algorithm> 12 #include <array> 13 #include <cmath> 14 #include <cstdint> 15 #include <limits> 16 #include <numeric> 17 #include <type_traits> 18 #include <unordered_map> 19 #include <utility> 20 #include <vector> 21 22 #include "lib/jxl/ans_common.h" 23 #include "lib/jxl/base/bits.h" 24 #include "lib/jxl/base/fast_math-inl.h" 25 #include "lib/jxl/base/status.h" 26 #include "lib/jxl/dec_ans.h" 27 #include "lib/jxl/enc_ans_params.h" 28 #include "lib/jxl/enc_aux_out.h" 29 #include "lib/jxl/enc_cluster.h" 30 #include "lib/jxl/enc_context_map.h" 31 #include "lib/jxl/enc_fields.h" 32 #include "lib/jxl/enc_huffman.h" 33 #include "lib/jxl/enc_params.h" 34 #include "lib/jxl/fields.h" 35 36 namespace jxl { 37 38 namespace { 39 40 #if !JXL_IS_DEBUG_BUILD 41 constexpr 42 #endif 43 bool ans_fuzzer_friendly_ = false; 44 45 const int kMaxNumSymbolsForSmallCode = 4; 46 47 void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table, 48 size_t alphabet_size, size_t log_alpha_size, 49 ANSEncSymbolInfo* info) { 50 size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size; 51 size_t entry_size_minus_1 = (1 << log_entry_size) - 1; 52 // create valid alias table for empty streams. 53 for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) { 54 const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s]; 55 info[s].freq_ = static_cast<uint16_t>(freq); 56 #ifdef USE_MULT_BY_RECIPROCAL 57 if (freq != 0) { 58 info[s].ifreq_ = 59 ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_; 60 } else { 61 info[s].ifreq_ = 1; // shouldn't matter (symbol shouldn't occur), but... 
62 } 63 #endif 64 info[s].reverse_map_.resize(freq); 65 } 66 for (int i = 0; i < ANS_TAB_SIZE; i++) { 67 AliasTable::Symbol s = 68 AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1); 69 info[s.value].reverse_map_[s.offset] = i; 70 } 71 } 72 73 float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts, 74 size_t len) { 75 float sum = 0.0f; 76 int total_histogram = 0; 77 int total_counts = 0; 78 for (size_t i = 0; i < len; ++i) { 79 total_histogram += histogram[i]; 80 total_counts += counts[i]; 81 if (histogram[i] > 0) { 82 JXL_ASSERT(counts[i] > 0); 83 // += histogram[i] * -log(counts[i]/total_counts) 84 sum += histogram[i] * 85 std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i])); 86 } 87 } 88 if (total_histogram > 0) { 89 // Used only in assert. 90 (void)total_counts; 91 JXL_ASSERT(total_counts == ANS_TAB_SIZE); 92 } 93 return sum; 94 } 95 96 float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) { 97 const float flat_bits = std::max(FastLog2f(len), 0.0f); 98 float total_histogram = 0; 99 for (size_t i = 0; i < len; ++i) { 100 total_histogram += histogram[i]; 101 } 102 return total_histogram * flat_bits; 103 } 104 105 // Static Huffman code for encoding logcounts. The last symbol is used as RLE 106 // sequence. 107 const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = { 108 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7, 109 }; 110 const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = { 111 17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65, 112 }; 113 114 // Returns the difference between largest count that can be represented and is 115 // smaller than "count" and smallest representable count larger than "count". 116 int SmallestIncrement(uint32_t count, uint32_t shift) { 117 int bits = count == 0 ? -1 : FloorLog2Nonzero(count); 118 int drop_bits = bits - GetPopulationCountPrecision(bits, shift); 119 return drop_bits < 0 ? 
1 : (1 << drop_bits); 120 } 121 122 template <bool minimize_error_of_sum> 123 bool RebalanceHistogram(const float* targets, int max_symbol, int table_size, 124 uint32_t shift, int* omit_pos, ANSHistBin* counts) { 125 int sum = 0; 126 float sum_nonrounded = 0.0; 127 int remainder_pos = 0; // if all of them are handled in first loop 128 int remainder_log = -1; 129 for (int n = 0; n < max_symbol; ++n) { 130 if (targets[n] > 0 && targets[n] < 1.0f) { 131 counts[n] = 1; 132 sum_nonrounded += targets[n]; 133 sum += counts[n]; 134 } 135 } 136 const float discount_ratio = 137 (table_size - sum) / (table_size - sum_nonrounded); 138 JXL_ASSERT(discount_ratio > 0); 139 JXL_ASSERT(discount_ratio <= 1.0f); 140 // Invariant for minimize_error_of_sum == true: 141 // abs(sum - sum_nonrounded) 142 // <= SmallestIncrement(max(targets[])) + max_symbol 143 for (int n = 0; n < max_symbol; ++n) { 144 if (targets[n] >= 1.0f) { 145 sum_nonrounded += targets[n]; 146 counts[n] = 147 static_cast<ANSHistBin>(targets[n] * discount_ratio); // truncate 148 if (counts[n] == 0) counts[n] = 1; 149 if (counts[n] == table_size) counts[n] = table_size - 1; 150 // Round the count to the closest nonzero multiple of SmallestIncrement 151 // (when minimize_error_of_sum is false) or one of two closest so as to 152 // keep the sum as close as possible to sum_nonrounded. 153 int inc = SmallestIncrement(counts[n], shift); 154 counts[n] -= counts[n] & (inc - 1); 155 // TODO(robryk): Should we rescale targets[n]? 156 const int target = minimize_error_of_sum 157 ? 
(static_cast<int>(sum_nonrounded) - sum) 158 : static_cast<int>(targets[n]); 159 if (counts[n] == 0 || 160 (target >= counts[n] + inc / 2 && counts[n] + inc < table_size)) { 161 counts[n] += inc; 162 } 163 sum += counts[n]; 164 const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n])); 165 if (count_log > remainder_log) { 166 remainder_pos = n; 167 remainder_log = count_log; 168 } 169 } 170 } 171 JXL_ASSERT(remainder_pos != -1); 172 // NOTE: This is the only place where counts could go negative. We could 173 // detect that, return false and make ANSHistBin uint32_t. 174 counts[remainder_pos] -= sum - table_size; 175 *omit_pos = remainder_pos; 176 return counts[remainder_pos] > 0; 177 } 178 179 Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length, 180 const int precision_bits, uint32_t shift, 181 int* num_symbols, int* symbols) { 182 const int32_t table_size = 1 << precision_bits; // target sum / table size 183 uint64_t total = 0; 184 int max_symbol = 0; 185 int symbol_count = 0; 186 for (int n = 0; n < length; ++n) { 187 total += counts[n]; 188 if (counts[n] > 0) { 189 if (symbol_count < kMaxNumSymbolsForSmallCode) { 190 symbols[symbol_count] = n; 191 } 192 ++symbol_count; 193 max_symbol = n + 1; 194 } 195 } 196 *num_symbols = symbol_count; 197 if (symbol_count == 0) { 198 return true; 199 } 200 if (symbol_count == 1) { 201 counts[symbols[0]] = table_size; 202 return true; 203 } 204 if (symbol_count > table_size) 205 return JXL_FAILURE("Too many entries in an ANS histogram"); 206 207 const float norm = 1.f * table_size / total; 208 std::vector<float> targets(max_symbol); 209 for (size_t n = 0; n < targets.size(); ++n) { 210 targets[n] = norm * counts[n]; 211 } 212 if (!RebalanceHistogram<false>(targets.data(), max_symbol, table_size, shift, 213 omit_pos, counts)) { 214 // Use an alternative rebalancing mechanism if the one above failed 215 // to create a histogram that is positive wherever the original one was. 
216 if (!RebalanceHistogram<true>(targets.data(), max_symbol, table_size, shift, 217 omit_pos, counts)) { 218 return JXL_FAILURE("Logic error: couldn't rebalance a histogram"); 219 } 220 } 221 return true; 222 } 223 224 struct SizeWriter { 225 size_t size = 0; 226 void Write(size_t num, size_t bits) { size += num; } 227 }; 228 229 template <typename Writer> 230 void StoreVarLenUint8(size_t n, Writer* writer) { 231 JXL_DASSERT(n <= 255); 232 if (n == 0) { 233 writer->Write(1, 0); 234 } else { 235 writer->Write(1, 1); 236 size_t nbits = FloorLog2Nonzero(n); 237 writer->Write(3, nbits); 238 writer->Write(nbits, n - (1ULL << nbits)); 239 } 240 } 241 242 template <typename Writer> 243 void StoreVarLenUint16(size_t n, Writer* writer) { 244 JXL_DASSERT(n <= 65535); 245 if (n == 0) { 246 writer->Write(1, 0); 247 } else { 248 writer->Write(1, 1); 249 size_t nbits = FloorLog2Nonzero(n); 250 writer->Write(4, nbits); 251 writer->Write(nbits, n - (1ULL << nbits)); 252 } 253 } 254 255 template <typename Writer> 256 bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size, 257 const int omit_pos, const int num_symbols, uint32_t shift, 258 const int* symbols, Writer* writer) { 259 bool ok = true; 260 if (num_symbols <= 2) { 261 // Small tree marker to encode 1-2 symbols. 262 writer->Write(1, 1); 263 if (num_symbols == 0) { 264 writer->Write(1, 0); 265 StoreVarLenUint8(0, writer); 266 } else { 267 writer->Write(1, num_symbols - 1); 268 for (int i = 0; i < num_symbols; ++i) { 269 StoreVarLenUint8(symbols[i], writer); 270 } 271 } 272 if (num_symbols == 2) { 273 writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]); 274 } 275 } else { 276 // Mark non-small tree. 277 writer->Write(1, 0); 278 // Mark non-flat histogram. 279 writer->Write(1, 0); 280 281 // Precompute sequences for RLE encoding. Contains the number of identical 282 // values starting at a given index. Only contains the value at the first 283 // element of the series. 
284 std::vector<uint32_t> same(alphabet_size, 0); 285 int last = 0; 286 for (int i = 1; i < alphabet_size; i++) { 287 // Store the sequence length once different symbol reached, or we're at 288 // the end, or the length is longer than we can encode, or we are at 289 // the omit_pos. We don't support including the omit_pos in an RLE 290 // sequence because this value may use a different amount of log2 bits 291 // than standard, it is too complex to handle in the decoder. 292 if (counts[i] != counts[last] || i + 1 == alphabet_size || 293 (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) { 294 same[last] = (i - last); 295 last = i + 1; 296 } 297 } 298 299 int length = 0; 300 std::vector<int> logcounts(alphabet_size); 301 int omit_log = 0; 302 for (int i = 0; i < alphabet_size; ++i) { 303 JXL_ASSERT(counts[i] <= ANS_TAB_SIZE); 304 JXL_ASSERT(counts[i] >= 0); 305 if (i == omit_pos) { 306 length = i + 1; 307 } else if (counts[i] > 0) { 308 logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1; 309 length = i + 1; 310 if (i < omit_pos) { 311 omit_log = std::max(omit_log, logcounts[i] + 1); 312 } else { 313 omit_log = std::max(omit_log, logcounts[i]); 314 } 315 } 316 } 317 logcounts[omit_pos] = omit_log; 318 319 // Elias gamma-like code for shift. Only difference is that if the number 320 // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip 321 // the terminating 0 in unary coding. 322 int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1); 323 int log = FloorLog2Nonzero(shift + 1); 324 writer->Write(log, (1 << log) - 1); 325 if (log != upper_bound_log) writer->Write(1, 0); 326 writer->Write(log, ((1 << log) - 1) & (shift + 1)); 327 328 // Since num_symbols >= 3, we know that length >= 3, therefore we encode 329 // length - 3. 330 if (length - 3 > 255) { 331 // Pretend that everything is OK, but complain about correctness later. 
332 StoreVarLenUint8(255, writer); 333 ok = false; 334 } else { 335 StoreVarLenUint8(length - 3, writer); 336 } 337 338 // The logcount values are encoded with a static Huffman code. 339 static const size_t kMinReps = 4; 340 size_t rep = ANS_LOG_TAB_SIZE + 1; 341 for (int i = 0; i < length; ++i) { 342 if (i > 0 && same[i - 1] > kMinReps) { 343 // Encode the RLE symbol and skip the repeated ones. 344 writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]); 345 StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer); 346 i += same[i - 1] - 2; 347 continue; 348 } 349 writer->Write(kLogCountBitLengths[logcounts[i]], 350 kLogCountSymbols[logcounts[i]]); 351 } 352 for (int i = 0; i < length; ++i) { 353 if (i > 0 && same[i - 1] > kMinReps) { 354 // Skip symbols encoded by RLE. 355 i += same[i - 1] - 2; 356 continue; 357 } 358 if (logcounts[i] > 1 && i != omit_pos) { 359 int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift); 360 int drop_bits = logcounts[i] - 1 - bitcount; 361 JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0); 362 writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount)); 363 } 364 } 365 } 366 return ok; 367 } 368 369 void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) { 370 // Mark non-small tree. 371 writer->Write(1, 0); 372 // Mark uniform histogram. 373 writer->Write(1, 1); 374 JXL_ASSERT(alphabet_size > 0); 375 // Encode alphabet size. 376 StoreVarLenUint8(alphabet_size - 1, writer); 377 } 378 379 float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size, 380 uint32_t method) { 381 if (method == 0) { // Flat code 382 return ANS_LOG_TAB_SIZE + 2 + 383 EstimateDataBitsFlat(histogram, alphabet_size); 384 } 385 // Non-flat: shift = method-1. 
386 uint32_t shift = method - 1; 387 std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size); 388 int omit_pos = 0; 389 int num_symbols; 390 int symbols[kMaxNumSymbolsForSmallCode] = {}; 391 JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size, 392 ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols)); 393 SizeWriter writer; 394 // Ignore the correctness, no real encoding happens at this stage. 395 (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift, 396 symbols, &writer); 397 return writer.size + 398 EstimateDataBits(histogram, counts.data(), alphabet_size); 399 } 400 401 uint32_t ComputeBestMethod( 402 const ANSHistBin* histogram, size_t alphabet_size, float* cost, 403 HistogramParams::ANSHistogramStrategy ans_histogram_strategy) { 404 size_t method = 0; 405 float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0); 406 auto try_shift = [&](size_t shift) { 407 float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1); 408 if (c < fcost) { 409 method = shift + 1; 410 fcost = c; 411 } 412 }; 413 switch (ans_histogram_strategy) { 414 case HistogramParams::ANSHistogramStrategy::kPrecise: { 415 for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) { 416 try_shift(shift); 417 } 418 break; 419 } 420 case HistogramParams::ANSHistogramStrategy::kApproximate: { 421 for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) { 422 try_shift(shift); 423 } 424 break; 425 } 426 case HistogramParams::ANSHistogramStrategy::kFast: { 427 try_shift(0); 428 try_shift(ANS_LOG_TAB_SIZE / 2); 429 try_shift(ANS_LOG_TAB_SIZE); 430 break; 431 } 432 }; 433 *cost = fcost; 434 return method; 435 } 436 437 } // namespace 438 439 // Returns an estimate of the cost of encoding this histogram and the 440 // corresponding data. 
441 size_t BuildAndStoreANSEncodingData( 442 HistogramParams::ANSHistogramStrategy ans_histogram_strategy, 443 const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size, 444 bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) { 445 if (use_prefix_code) { 446 if (alphabet_size <= 1) return 0; 447 std::vector<uint32_t> histo(alphabet_size); 448 for (size_t i = 0; i < alphabet_size; i++) { 449 histo[i] = histogram[i]; 450 JXL_CHECK(histogram[i] >= 0); 451 } 452 size_t cost = 0; 453 { 454 std::vector<uint8_t> depths(alphabet_size); 455 std::vector<uint16_t> bits(alphabet_size); 456 if (writer == nullptr) { 457 BitWriter tmp_writer; 458 BitWriter::Allotment allotment( 459 &tmp_writer, 8 * alphabet_size + 8); // safe upper bound 460 BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(), 461 bits.data(), &tmp_writer); 462 allotment.ReclaimAndCharge(&tmp_writer, 0, /*aux_out=*/nullptr); 463 cost = tmp_writer.BitsWritten(); 464 } else { 465 size_t start = writer->BitsWritten(); 466 BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(), 467 bits.data(), writer); 468 cost = writer->BitsWritten() - start; 469 } 470 for (size_t i = 0; i < alphabet_size; i++) { 471 info[i].bits = depths[i] == 0 ? 0 : bits[i]; 472 info[i].depth = depths[i]; 473 } 474 } 475 // Estimate data cost. 
476 for (size_t i = 0; i < alphabet_size; i++) { 477 cost += histogram[i] * info[i].depth; 478 } 479 return cost; 480 } 481 JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE); 482 float cost; 483 uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost, 484 ans_histogram_strategy); 485 JXL_ASSERT(cost >= 0); 486 int num_symbols; 487 int symbols[kMaxNumSymbolsForSmallCode] = {}; 488 std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size); 489 if (!counts.empty()) { 490 size_t sum = 0; 491 for (int count : counts) { 492 sum += count; 493 } 494 if (sum == 0) { 495 counts[0] = ANS_TAB_SIZE; 496 } 497 } 498 int omit_pos = 0; 499 uint32_t shift = method - 1; 500 if (method == 0) { 501 counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE); 502 } else { 503 JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size, 504 ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols)); 505 } 506 AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE]; 507 InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a); 508 ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info); 509 if (writer != nullptr) { 510 if (method == 0) { 511 EncodeFlatHistogram(alphabet_size, writer); 512 } else { 513 bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, 514 num_symbols, method - 1, symbols, writer); 515 (void)ok; 516 JXL_DASSERT(ok); 517 } 518 } 519 return cost; 520 } 521 522 float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) { 523 float c; 524 ComputeBestMethod(data, alphabet_size, &c, 525 HistogramParams::ANSHistogramStrategy::kFast); 526 return c; 527 } 528 529 template <typename Writer> 530 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer, 531 size_t log_alpha_size) { 532 writer->Write(CeilLog2Nonzero(log_alpha_size + 1), 533 uint_config.split_exponent); 534 if (uint_config.split_exponent == log_alpha_size) { 535 return; // msb/lsb don't matter. 
536 } 537 size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1); 538 writer->Write(nbits, uint_config.msb_in_token); 539 nbits = CeilLog2Nonzero(uint_config.split_exponent - 540 uint_config.msb_in_token + 1); 541 writer->Write(nbits, uint_config.lsb_in_token); 542 } 543 template <typename Writer> 544 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config, 545 Writer* writer, size_t log_alpha_size) { 546 // TODO(veluca): RLE? 547 for (const auto& cfg : uint_config) { 548 EncodeUintConfig(cfg, writer, log_alpha_size); 549 } 550 } 551 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&, 552 BitWriter*, size_t); 553 554 namespace { 555 556 void ChooseUintConfigs(const HistogramParams& params, 557 const std::vector<std::vector<Token>>& tokens, 558 const std::vector<uint8_t>& context_map, 559 std::vector<Histogram>* clustered_histograms, 560 EntropyEncodingData* codes, size_t* log_alpha_size) { 561 codes->uint_config.resize(clustered_histograms->size()); 562 if (params.uint_method == HistogramParams::HybridUintMethod::kNone) { 563 return; 564 } 565 if (params.uint_method == HistogramParams::HybridUintMethod::k000) { 566 codes->uint_config.clear(); 567 codes->uint_config.resize(clustered_histograms->size(), 568 HybridUintConfig(0, 0, 0)); 569 return; 570 } 571 if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) { 572 codes->uint_config.clear(); 573 codes->uint_config.resize(clustered_histograms->size(), 574 HybridUintConfig(2, 0, 1)); 575 return; 576 } 577 578 // If the uint config is adaptive, just stick with the default in streaming 579 // mode. 580 if (params.streaming_mode) { 581 return; 582 } 583 584 // Brute-force method that tries a few options. 
585 std::vector<HybridUintConfig> configs; 586 if (params.uint_method == HistogramParams::HybridUintMethod::kBest) { 587 configs = { 588 HybridUintConfig(4, 2, 0), // default 589 HybridUintConfig(4, 1, 0), // less precise 590 HybridUintConfig(4, 2, 1), // add sign 591 HybridUintConfig(4, 2, 2), // add sign+parity 592 HybridUintConfig(4, 1, 2), // add parity but less msb 593 // Same as above, but more direct coding. 594 HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0), 595 HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2), 596 HybridUintConfig(5, 1, 2), 597 // Same as above, but less direct coding. 598 HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0), 599 HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2), 600 // For near-lossless. 601 HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4), 602 HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5), 603 HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0), 604 // Other 605 HybridUintConfig(0, 0, 0), // varlenuint 606 HybridUintConfig(2, 0, 1), // works well for ctx map 607 HybridUintConfig(7, 0, 0), // direct coding 608 HybridUintConfig(8, 0, 0), // direct coding 609 HybridUintConfig(9, 0, 0), // direct coding 610 HybridUintConfig(10, 0, 0), // direct coding 611 HybridUintConfig(11, 0, 0), // direct coding 612 HybridUintConfig(12, 0, 0), // direct coding 613 }; 614 } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) { 615 configs = { 616 HybridUintConfig(4, 2, 0), // default 617 HybridUintConfig(4, 1, 2), // add parity but less msb 618 HybridUintConfig(0, 0, 0), // smallest histograms 619 HybridUintConfig(2, 0, 1), // works well for ctx map 620 }; 621 } 622 623 std::vector<float> costs(clustered_histograms->size(), 624 std::numeric_limits<float>::max()); 625 std::vector<uint32_t> extra_bits(clustered_histograms->size()); 626 std::vector<uint8_t> is_valid(clustered_histograms->size()); 627 size_t max_alpha = 628 codes->use_prefix_code ? 
PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE; 629 for (HybridUintConfig cfg : configs) { 630 std::fill(is_valid.begin(), is_valid.end(), true); 631 std::fill(extra_bits.begin(), extra_bits.end(), 0); 632 633 for (auto& histo : *clustered_histograms) { 634 histo.Clear(); 635 } 636 for (const auto& stream : tokens) { 637 for (const auto& token : stream) { 638 // TODO(veluca): do not ignore lz77 commands. 639 if (token.is_lz77_length) continue; 640 size_t histo = context_map[token.context]; 641 uint32_t tok, nbits, bits; 642 cfg.Encode(token.value, &tok, &nbits, &bits); 643 if (tok >= max_alpha || 644 (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) { 645 is_valid[histo] = JXL_FALSE; 646 continue; 647 } 648 extra_bits[histo] += nbits; 649 (*clustered_histograms)[histo].Add(tok); 650 } 651 } 652 653 for (size_t i = 0; i < clustered_histograms->size(); i++) { 654 if (!is_valid[i]) continue; 655 float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i]; 656 // add signaling cost of the hybriduintconfig itself 657 cost += CeilLog2Nonzero(cfg.split_exponent + 1); 658 cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1); 659 if (cost < costs[i]) { 660 codes->uint_config[i] = cfg; 661 costs[i] = cost; 662 } 663 } 664 } 665 666 // Rebuild histograms. 667 for (auto& histo : *clustered_histograms) { 668 histo.Clear(); 669 } 670 *log_alpha_size = 4; 671 for (const auto& stream : tokens) { 672 for (const auto& token : stream) { 673 uint32_t tok, nbits, bits; 674 size_t histo = context_map[token.context]; 675 (token.is_lz77_length ? codes->lz77.length_uint_config 676 : codes->uint_config[histo]) 677 .Encode(token.value, &tok, &nbits, &bits); 678 tok += token.is_lz77_length ? codes->lz77.min_symbol : 0; 679 (*clustered_histograms)[histo].Add(tok); 680 while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++; 681 } 682 } 683 #if JXL_ENABLE_ASSERT 684 size_t max_log_alpha_size = codes->use_prefix_code ? 
PREFIX_MAX_BITS : 8; 685 JXL_ASSERT(*log_alpha_size <= max_log_alpha_size); 686 #endif 687 } 688 689 Histogram HistogramFromSymbolInfo( 690 const std::vector<ANSEncSymbolInfo>& encoding_info, bool use_prefix_code) { 691 Histogram histo; 692 histo.data_.resize(DivCeil(encoding_info.size(), Histogram::kRounding) * 693 Histogram::kRounding); 694 histo.total_count_ = 0; 695 for (size_t i = 0; i < encoding_info.size(); ++i) { 696 const ANSEncSymbolInfo& info = encoding_info[i]; 697 int count = use_prefix_code 698 ? (info.depth ? (1u << (PREFIX_MAX_BITS - info.depth)) : 0) 699 : info.freq_; 700 histo.data_[i] = count; 701 histo.total_count_ += count; 702 } 703 return histo; 704 } 705 706 class HistogramBuilder { 707 public: 708 explicit HistogramBuilder(const size_t num_contexts) 709 : histograms_(num_contexts) {} 710 711 void VisitSymbol(int symbol, size_t histo_idx) { 712 JXL_DASSERT(histo_idx < histograms_.size()); 713 histograms_[histo_idx].Add(symbol); 714 } 715 716 // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge. 
717 size_t BuildAndStoreEntropyCodes( 718 const HistogramParams& params, 719 const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes, 720 std::vector<uint8_t>* context_map, BitWriter* writer, size_t layer, 721 AuxOut* aux_out) const { 722 const size_t prev_histograms = codes->encoding_info.size(); 723 size_t cost = 0; 724 std::vector<Histogram> clustered_histograms; 725 for (size_t i = 0; i < prev_histograms; ++i) { 726 clustered_histograms.push_back(HistogramFromSymbolInfo( 727 codes->encoding_info[i], codes->use_prefix_code)); 728 } 729 size_t context_offset = context_map->size(); 730 context_map->resize(context_offset + histograms_.size()); 731 if (histograms_.size() > 1) { 732 if (!ans_fuzzer_friendly_) { 733 std::vector<uint32_t> histogram_symbols; 734 ClusterHistograms(params, histograms_, kClustersLimit, 735 &clustered_histograms, &histogram_symbols); 736 for (size_t c = 0; c < histograms_.size(); ++c) { 737 (*context_map)[context_offset + c] = 738 static_cast<uint8_t>(histogram_symbols[c]); 739 } 740 } else { 741 JXL_ASSERT(codes->encoding_info.empty()); 742 fill(context_map->begin(), context_map->end(), 0); 743 size_t max_symbol = 0; 744 for (const Histogram& h : histograms_) { 745 max_symbol = std::max(h.data_.size(), max_symbol); 746 } 747 size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1); 748 clustered_histograms.resize(1); 749 clustered_histograms[0].Clear(); 750 for (size_t i = 0; i < num_symbols; i++) { 751 clustered_histograms[0].Add(i); 752 } 753 } 754 if (writer != nullptr) { 755 EncodeContextMap(*context_map, clustered_histograms.size(), writer, 756 layer, aux_out); 757 } 758 } else { 759 JXL_ASSERT(codes->encoding_info.empty()); 760 clustered_histograms.push_back(histograms_[0]); 761 } 762 if (aux_out != nullptr) { 763 for (size_t i = prev_histograms; i < clustered_histograms.size(); ++i) { 764 aux_out->layers[layer].clustered_entropy += 765 clustered_histograms[i].ShannonEntropy(); 766 } 767 } 768 size_t log_alpha_size 
= codes->lz77.enabled ? 8 : 7; // Sane default. 769 if (ans_fuzzer_friendly_) { 770 codes->uint_config.clear(); 771 codes->uint_config.resize(1, HybridUintConfig(7, 0, 0)); 772 } else { 773 ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms, 774 codes, &log_alpha_size); 775 } 776 if (log_alpha_size < 5) log_alpha_size = 5; 777 if (params.streaming_mode) { 778 // TODO(szabadka) Figure out if we can use lower values here. 779 log_alpha_size = 8; 780 } 781 SizeWriter size_writer; // Used if writer == nullptr to estimate costs. 782 cost += 1; 783 if (writer) writer->Write(1, TO_JXL_BOOL(codes->use_prefix_code)); 784 785 if (codes->use_prefix_code) { 786 log_alpha_size = PREFIX_MAX_BITS; 787 } else { 788 cost += 2; 789 } 790 if (writer == nullptr) { 791 EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size); 792 } else { 793 if (!codes->use_prefix_code) writer->Write(2, log_alpha_size - 5); 794 EncodeUintConfigs(codes->uint_config, writer, log_alpha_size); 795 } 796 if (codes->use_prefix_code) { 797 for (const auto& histo : clustered_histograms) { 798 size_t alphabet_size = histo.alphabet_size(); 799 if (writer) { 800 StoreVarLenUint16(alphabet_size - 1, writer); 801 } else { 802 StoreVarLenUint16(alphabet_size - 1, &size_writer); 803 } 804 } 805 } 806 cost += size_writer.size; 807 for (size_t c = prev_histograms; c < clustered_histograms.size(); ++c) { 808 size_t alphabet_size = clustered_histograms[c].alphabet_size(); 809 codes->encoding_info.emplace_back(); 810 codes->encoding_info.back().resize(alphabet_size); 811 BitWriter* histo_writer = writer; 812 if (params.streaming_mode) { 813 codes->encoded_histograms.emplace_back(); 814 histo_writer = &codes->encoded_histograms.back(); 815 } 816 BitWriter::Allotment allotment(histo_writer, 256 + alphabet_size * 24); 817 cost += BuildAndStoreANSEncodingData( 818 params.ans_histogram_strategy, clustered_histograms[c].data_.data(), 819 alphabet_size, log_alpha_size, codes->use_prefix_code, 820 
codes->encoding_info.back().data(), histo_writer); 821 allotment.FinishedHistogram(histo_writer); 822 allotment.ReclaimAndCharge(histo_writer, layer, aux_out); 823 if (params.streaming_mode) { 824 writer->AppendUnaligned(*histo_writer); 825 } 826 } 827 return cost; 828 } 829 830 const Histogram& Histo(size_t i) const { return histograms_[i]; } 831 832 private: 833 std::vector<Histogram> histograms_; 834 }; 835 836 class SymbolCostEstimator { 837 public: 838 SymbolCostEstimator(size_t num_contexts, bool force_huffman, 839 const std::vector<std::vector<Token>>& tokens, 840 const LZ77Params& lz77) { 841 HistogramBuilder builder(num_contexts); 842 // Build histograms for estimating lz77 savings. 843 HybridUintConfig uint_config; 844 for (const auto& stream : tokens) { 845 for (const auto& token : stream) { 846 uint32_t tok, nbits, bits; 847 (token.is_lz77_length ? lz77.length_uint_config : uint_config) 848 .Encode(token.value, &tok, &nbits, &bits); 849 tok += token.is_lz77_length ? lz77.min_symbol : 0; 850 builder.VisitSymbol(tok, token.context); 851 } 852 } 853 max_alphabet_size_ = 0; 854 for (size_t i = 0; i < num_contexts; i++) { 855 max_alphabet_size_ = 856 std::max(max_alphabet_size_, builder.Histo(i).data_.size()); 857 } 858 bits_.resize(num_contexts * max_alphabet_size_); 859 // TODO(veluca): SIMD? 860 add_symbol_cost_.resize(num_contexts); 861 for (size_t i = 0; i < num_contexts; i++) { 862 float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f); 863 float total_cost = 0; 864 for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) { 865 size_t cnt = builder.Histo(i).data_[j]; 866 float cost = 0; 867 if (cnt != 0 && cnt != builder.Histo(i).total_count_) { 868 cost = -FastLog2f(cnt * inv_total); 869 if (force_huffman) cost = std::ceil(cost); 870 } else if (cnt == 0) { 871 cost = ANS_LOG_TAB_SIZE; // Highest possible cost. 
872 } 873 bits_[i * max_alphabet_size_ + j] = cost; 874 total_cost += cost * builder.Histo(i).data_[j]; 875 } 876 // Penalty for adding a lz77 symbol to this contest (only used for static 877 // cost model). Higher penalty for contexts that have a very low 878 // per-symbol entropy. 879 add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total); 880 } 881 } 882 float Bits(size_t ctx, size_t sym) const { 883 return bits_[ctx * max_alphabet_size_ + sym]; 884 } 885 float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const { 886 uint32_t nbits, bits, tok; 887 lz77.length_uint_config.Encode(len, &tok, &nbits, &bits); 888 tok += lz77.min_symbol; 889 return nbits + Bits(ctx, tok); 890 } 891 float DistCost(size_t len, const LZ77Params& lz77) const { 892 uint32_t nbits, bits, tok; 893 HybridUintConfig().Encode(len, &tok, &nbits, &bits); 894 return nbits + Bits(lz77.nonserialized_distance_context, tok); 895 } 896 float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; } 897 898 private: 899 size_t max_alphabet_size_; 900 std::vector<float> bits_; 901 std::vector<float> add_symbol_cost_; 902 }; 903 904 void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts, 905 const std::vector<std::vector<Token>>& tokens, 906 LZ77Params& lz77, 907 std::vector<std::vector<Token>>& tokens_lz77) { 908 // TODO(veluca): tune heuristics here. 909 SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77); 910 float bit_decrease = 0; 911 size_t total_symbols = 0; 912 tokens_lz77.resize(tokens.size()); 913 std::vector<float> sym_cost; 914 HybridUintConfig uint_config; 915 for (size_t stream = 0; stream < tokens.size(); stream++) { 916 size_t distance_multiplier = 917 params.image_widths.size() > stream ? params.image_widths[stream] : 0; 918 const auto& in = tokens[stream]; 919 auto& out = tokens_lz77[stream]; 920 total_symbols += in.size(); 921 // Cumulative sum of bit costs. 
922 sym_cost.resize(in.size() + 1); 923 for (size_t i = 0; i < in.size(); i++) { 924 uint32_t tok, nbits, unused_bits; 925 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits); 926 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i]; 927 } 928 out.reserve(in.size()); 929 for (size_t i = 0; i < in.size(); i++) { 930 size_t num_to_copy = 0; 931 size_t distance_symbol = 0; // 1 for RLE. 932 if (distance_multiplier != 0) { 933 distance_symbol = 1; // Special distance 1 if enabled. 934 JXL_DASSERT(kSpecialDistances[1][0] == 1); 935 JXL_DASSERT(kSpecialDistances[1][1] == 0); 936 } 937 if (i > 0) { 938 for (; i + num_to_copy < in.size(); num_to_copy++) { 939 if (in[i + num_to_copy].value != in[i - 1].value) { 940 break; 941 } 942 } 943 } 944 if (num_to_copy == 0) { 945 out.push_back(in[i]); 946 continue; 947 } 948 float cost = sym_cost[i + num_to_copy] - sym_cost[i]; 949 // This subtraction might overflow, but that's OK. 950 size_t lz77_len = num_to_copy - lz77.min_length; 951 float lz77_cost = num_to_copy >= lz77.min_length 952 ? CeilLog2Nonzero(lz77_len + 1) + 1 953 : 0; 954 if (num_to_copy < lz77.min_length || cost <= lz77_cost) { 955 for (size_t j = 0; j < num_to_copy; j++) { 956 out.push_back(in[i + j]); 957 } 958 i += num_to_copy - 1; 959 continue; 960 } 961 // Output the LZ77 length 962 out.emplace_back(in[i].context, lz77_len); 963 out.back().is_lz77_length = true; 964 i += num_to_copy - 1; 965 bit_decrease += cost - lz77_cost; 966 // Output the LZ77 copy distance. 
967 out.emplace_back(lz77.nonserialized_distance_context, distance_symbol); 968 } 969 } 970 971 if (bit_decrease > total_symbols * 0.2 + 16) { 972 lz77.enabled = true; 973 } 974 } 975 976 // Hash chain for LZ77 matching 977 struct HashChain { 978 size_t size_; 979 std::vector<uint32_t> data_; 980 981 unsigned hash_num_values_ = 32768; 982 unsigned hash_mask_ = hash_num_values_ - 1; 983 unsigned hash_shift_ = 5; 984 985 std::vector<int> head; 986 std::vector<uint32_t> chain; 987 std::vector<int> val; 988 989 // Speed up repetitions of zero 990 std::vector<int> headz; 991 std::vector<uint32_t> chainz; 992 std::vector<uint32_t> zeros; 993 uint32_t numzeros = 0; 994 995 size_t window_size_; 996 size_t window_mask_; 997 size_t min_length_; 998 size_t max_length_; 999 1000 // Map of special distance codes. 1001 std::unordered_map<int, int> special_dist_table_; 1002 size_t num_special_distances_ = 0; 1003 1004 uint32_t maxchainlength = 256; // window_size_ to allow all 1005 1006 HashChain(const Token* data, size_t size, size_t window_size, 1007 size_t min_length, size_t max_length, size_t distance_multiplier) 1008 : size_(size), 1009 window_size_(window_size), 1010 window_mask_(window_size - 1), 1011 min_length_(min_length), 1012 max_length_(max_length) { 1013 data_.resize(size); 1014 for (size_t i = 0; i < size; i++) { 1015 data_[i] = data[i].value; 1016 } 1017 1018 head.resize(hash_num_values_, -1); 1019 val.resize(window_size_, -1); 1020 chain.resize(window_size_); 1021 for (uint32_t i = 0; i < window_size_; ++i) { 1022 chain[i] = i; // same value as index indicates uninitialized 1023 } 1024 1025 zeros.resize(window_size_); 1026 headz.resize(window_size_ + 1, -1); 1027 chainz.resize(window_size_); 1028 for (uint32_t i = 0; i < window_size_; ++i) { 1029 chainz[i] = i; 1030 } 1031 // Translate distance to special distance code. 
1032 if (distance_multiplier) { 1033 // Count down, so if due to small distance multiplier multiple distances 1034 // map to the same code, the smallest code will be used in the end. 1035 for (int i = kNumSpecialDistances - 1; i >= 0; --i) { 1036 special_dist_table_[SpecialDistance(i, distance_multiplier)] = i; 1037 } 1038 num_special_distances_ = kNumSpecialDistances; 1039 } 1040 } 1041 1042 uint32_t GetHash(size_t pos) const { 1043 uint32_t result = 0; 1044 if (pos + 2 < size_) { 1045 // TODO(lode): take the MSB's of the uint32_t values into account as well, 1046 // given that the hash code itself is less than 32 bits. 1047 result ^= static_cast<uint32_t>(data_[pos + 0] << 0u); 1048 result ^= static_cast<uint32_t>(data_[pos + 1] << hash_shift_); 1049 result ^= static_cast<uint32_t>(data_[pos + 2] << (hash_shift_ * 2)); 1050 } else { 1051 // No need to compute hash of last 2 bytes, the length 2 is too short. 1052 return 0; 1053 } 1054 return result & hash_mask_; 1055 } 1056 1057 uint32_t CountZeros(size_t pos, uint32_t prevzeros) const { 1058 size_t end = pos + window_size_; 1059 if (end > size_) end = size_; 1060 if (prevzeros > 0) { 1061 if (prevzeros >= window_mask_ && data_[end - 1] == 0 && 1062 end == pos + window_size_) { 1063 return prevzeros; 1064 } else { 1065 return prevzeros - 1; 1066 } 1067 } 1068 uint32_t num = 0; 1069 while (pos + num < end && data_[pos + num] == 0) num++; 1070 return num; 1071 } 1072 1073 void Update(size_t pos) { 1074 uint32_t hashval = GetHash(pos); 1075 uint32_t wpos = pos & window_mask_; 1076 1077 val[wpos] = static_cast<int>(hashval); 1078 if (head[hashval] != -1) chain[wpos] = head[hashval]; 1079 head[hashval] = wpos; 1080 1081 if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0; 1082 numzeros = CountZeros(pos, numzeros); 1083 1084 zeros[wpos] = numzeros; 1085 if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros]; 1086 headz[numzeros] = wpos; 1087 } 1088 1089 void Update(size_t pos, size_t len) { 1090 for (size_t i = 
0; i < len; i++) { 1091 Update(pos + i); 1092 } 1093 } 1094 1095 template <typename CB> 1096 void FindMatches(size_t pos, int max_dist, const CB& found_match) const { 1097 uint32_t wpos = pos & window_mask_; 1098 uint32_t hashval = GetHash(pos); 1099 uint32_t hashpos = chain[wpos]; 1100 1101 int prev_dist = 0; 1102 int end = std::min<int>(pos + max_length_, size_); 1103 uint32_t chainlength = 0; 1104 uint32_t best_len = 0; 1105 for (;;) { 1106 int dist = (hashpos <= wpos) ? (wpos - hashpos) 1107 : (wpos - hashpos + window_mask_ + 1); 1108 if (dist < prev_dist) break; 1109 prev_dist = dist; 1110 uint32_t len = 0; 1111 if (dist > 0) { 1112 int i = pos; 1113 int j = pos - dist; 1114 if (numzeros > 3) { 1115 int r = std::min<int>(numzeros - 1, zeros[hashpos]); 1116 if (i + r >= end) r = end - i - 1; 1117 i += r; 1118 j += r; 1119 } 1120 while (i < end && data_[i] == data_[j]) { 1121 i++; 1122 j++; 1123 } 1124 len = i - pos; 1125 // This can trigger even if the new length is slightly smaller than the 1126 // best length, because it is possible for a slightly cheaper distance 1127 // symbol to occur. 1128 if (len >= min_length_ && len + 2 >= best_len) { 1129 auto it = special_dist_table_.find(dist); 1130 int dist_symbol = (it == special_dist_table_.end()) 1131 ? 
(num_special_distances_ + dist - 1) 1132 : it->second; 1133 found_match(len, dist_symbol); 1134 if (len > best_len) best_len = len; 1135 } 1136 } 1137 1138 chainlength++; 1139 if (chainlength >= maxchainlength) break; 1140 1141 if (numzeros >= 3 && len > numzeros) { 1142 if (hashpos == chainz[hashpos]) break; 1143 hashpos = chainz[hashpos]; 1144 if (zeros[hashpos] != numzeros) break; 1145 } else { 1146 if (hashpos == chain[hashpos]) break; 1147 hashpos = chain[hashpos]; 1148 if (val[hashpos] != static_cast<int>(hashval)) { 1149 // outdated hash value 1150 break; 1151 } 1152 } 1153 } 1154 } 1155 void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol, 1156 size_t* result_len) const { 1157 *result_dist_symbol = 0; 1158 *result_len = 1; 1159 FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) { 1160 if (len > *result_len || 1161 (len == *result_len && *result_dist_symbol > dist_symbol)) { 1162 *result_len = len; 1163 *result_dist_symbol = dist_symbol; 1164 } 1165 }); 1166 } 1167 }; 1168 1169 float LenCost(size_t len) { 1170 uint32_t nbits, bits, tok; 1171 HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits); 1172 constexpr float kCostTable[] = { 1173 2.797667318563126, 3.213177690381199, 2.5706009246743737, 1174 2.408392498667534, 2.829649191872326, 3.3923087753324577, 1175 4.029267451554331, 4.415576699706408, 4.509357574741465, 1176 9.21481543803004, 10.020590190114898, 11.858671627804766, 1177 12.45853300490526, 11.713105831990857, 12.561996324849314, 1178 13.775477692278367, 13.174027068768641, 1179 }; 1180 size_t table_size = sizeof kCostTable / sizeof *kCostTable; 1181 if (tok >= table_size) tok = table_size - 1; 1182 return kCostTable[tok] + nbits; 1183 } 1184 1185 // TODO(veluca): this does not take into account usage or non-usage of distance 1186 // multipliers. 
// Static cost model (in bits) for an LZ77 distance symbol: a fixed per-token
// table entry plus the number of extra bits of a HybridUintConfig(7, 0, 0)
// encoding. Tokens beyond the table use its last entry.
float DistCost(size_t dist) {
  uint32_t nbits, bits, tok;
  HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
  constexpr float kCostTable[] = {
      6.368282626312716,  5.680793277090298,  8.347404197105247,
      7.641619201599141,  6.914328374119438,  7.959808291537444,
      8.70023120759855,   8.71378518934703,   9.379132523982769,
      9.110472749092708,  9.159029569270908,  9.430936766731973,
      7.278284055315169,  7.8278514904267755, 10.026641158289236,
      9.976049229827066,  9.64351607048908,   9.563403863480442,
      10.171474111762747, 10.45950155077234,  9.994813912104219,
      10.322524683741156, 8.465808729388186,  8.756254166066853,
      10.160930174662234, 10.247329273413435, 10.04090403724809,
      10.129398517544082, 9.342311691539546,  9.07608009102374,
      10.104799540677513, 10.378079384990906, 10.165828974075072,
      10.337595322341553, 7.940557464567944,  10.575665823319431,
      11.023344321751955, 10.736144698831827, 11.118277044595054,
      7.468468230648442,  10.738305230932939, 10.906980780216568,
      10.163468216353817, 10.17805759656433,  11.167283670483565,
      11.147050200274544, 10.517921919244333, 10.651764778156886,
      10.17074446448919,  11.217636876224745, 11.261630721139484,
      11.403140815247259, 10.892472096873417, 11.1859607804481,
      8.017346947551262,  7.895143720278828,  11.036577113822025,
      11.170562110315794, 10.326988722591086, 10.40872184751056,
      11.213498225466386, 11.30580635516863,  10.672272515665442,
      10.768069466228063, 11.145257364153565, 11.64668307145549,
      10.593156194627339, 11.207499484844943, 10.767517766396908,
      10.826629811407042, 10.737764794499988, 10.6200448518045,
      10.191315385198092, 8.468384171390085,  11.731295299170432,
      11.824619886654398, 10.41518844301179,  10.16310536548649,
      10.539423685097576, 10.495136599328031, 10.469112847728267,
      11.72057686174922,  10.910326337834674, 11.378921834673758,
      11.847759036098536, 11.92071647623854,  10.810628276345282,
      11.008601085273893, 11.910326337834674, 11.949212023423133,
      11.298614839104337, 11.611603659010392, 10.472930394619985,
      11.835564720850282, 11.523267392285337, 12.01055816679611,
      8.413029688994023,  11.895784139536406, 11.984679534970505,
      11.220654278717394, 11.716311684833672, 10.61036646226114,
      10.89849965960364,  10.203762898863669, 10.997560826267238,
      11.484217379438984, 11.792836176993665, 12.24310468755171,
      11.464858097919262, 12.212747017409377, 11.425595666074955,
      11.572048533398757, 12.742093965163013, 11.381874288645637,
      12.191870445817015, 11.683156920035426, 11.152442115262197,
      11.90303691580457,  11.653292787169159, 11.938615382266098,
      16.970641701570223, 16.853602280380002, 17.26240782594733,
      16.644655390108507, 17.14310889757499,  16.910935455445955,
      17.505678976959697, 17.213498225466388, 2.4162310293553024,
      3.494587244462329,  3.5258600986408344, 3.4959806589517095,
      3.098390886949687,  3.343454654302911,  3.588847442290287,
      4.14614790111827,   5.152948641990529,  7.433696808092598,
      9.716311684833672,
  };
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
  if (tok >= table_size) tok = table_size - 1;
  return kCostTable[tok] + nbits;
}

// Greedy LZ77 with one-step lazy matching. For every stream, finds matches
// with a HashChain and replaces a run of literals by a (length, distance)
// pair whenever the static cost model (LenCost/DistCost plus the per-context
// penalty) is not more expensive than the literals. The result is written to
// `tokens_lz77`; `lz77.enabled` is set if the estimated total saving is large
// enough, and is never cleared here.
void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
                    const std::vector<std::vector<Token>>& tokens,
                    LZ77Params& lz77,
                    std::vector<std::vector<Token>>& tokens_lz77) {
  // TODO(veluca): tune heuristics here.
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
  float bit_decrease = 0;
  size_t total_symbols = 0;
  tokens_lz77.resize(tokens.size());
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    // A zero multiplier disables the special distance codes in HashChain.
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    total_symbols += in.size();
    // Cumulative sum of bit costs.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_ASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size.
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);
    size_t len;
    size_t dist_symbol;

    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching

    // Whether the next symbol was already updated (to test lazy matching)
    bool already_updated = false;
    for (size_t i = 0; i < in.size(); i++) {
      // The current token is always pushed first; if a match is accepted it
      // is rewritten in place into the LZ77 length token.
      out.push_back(in[i]);
      if (!already_updated) chain.Update(i);
      already_updated = false;
      chain.FindMatch(i, max_distance, &dist_symbol, &len);
      if (len >= min_length) {
        if (len < max_lazy_match_len && i + 1 < in.size()) {
          // Try length at next symbol lazy matching
          chain.Update(i + 1);
          already_updated = true;
          size_t len2, dist_symbol2;
          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
          if (len2 > len) {
            // Use the lazy match. Add literal, and use the next length starting
            // from the next byte.
            ++i;
            already_updated = false;
            len = len2;
            dist_symbol = dist_symbol2;
            out.push_back(in[i]);
          }
        }

        // Cost of the matched tokens as literals vs. as an LZ77 pair.
        float cost = sym_cost[i + len] - sym_cost[i];
        size_t lz77_len = len - lz77.min_length;
        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
                          sce.AddSymbolCost(out.back().context);

        if (lz77_cost <= cost) {
          out.back().value = len - min_length;
          out.back().is_lz77_length = true;
          out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
          bit_decrease += cost - lz77_cost;
        } else {
          // LZ77 match ignored, and symbol already pushed. Push all other
          // symbols and skip.
          for (size_t j = 1; j < len; j++) {
            out.push_back(in[i + j]);
          }
        }

        // Insert the remaining positions covered by the match into the chain;
        // position i + 1 was already inserted if the lazy probe ran and was
        // not taken.
        if (already_updated) {
          chain.Update(i + 2, len - 2);
          already_updated = false;
        } else {
          chain.Update(i + 1, len - 1);
        }
        i += len - 1;
      } else {
        // Literal, already pushed
      }
    }
  }

  // Only enable LZ77 if the estimated saving outweighs a per-symbol and a
  // fixed overhead threshold.
  if (bit_decrease > total_symbols * 0.2 + 16) {
    lz77.enabled = true;
  }
}

// Optimal-parse LZ77. First runs the greedy ApplyLZ77_LZ77 to decide whether
// LZ77 is worthwhile and to obtain tokens for a cost model; if it is, runs a
// shortest-path dynamic program over each stream (cost of encoding each
// prefix) and writes the cheapest parse to `tokens_lz77`. If the greedy pass
// leaves `lz77.enabled` false, returns without touching `tokens_lz77`.
void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
                       const std::vector<std::vector<Token>>& tokens,
                       LZ77Params& lz77,
                       std::vector<std::vector<Token>>& tokens_lz77) {
  std::vector<std::vector<Token>> tokens_for_cost_estimate;
  ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
  // If greedy-LZ77 does not give better compression than no-lz77, no reason to
  // run the optimal matching.
  if (!lz77.enabled) return;
  // One extra context for the LZ77 distance symbols.
  SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
                          tokens_for_cost_estimate, lz77);
  tokens_lz77.resize(tokens.size());
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  std::vector<uint32_t> dist_symbols;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    // Cumulative sum of bit costs.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_ASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size.
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);

    // Best known way to reach a given prefix length: the last step taken
    // (literal if dist_symbol == 0, otherwise a match with distance symbol
    // dist_symbol - 1), its length, its context and the accumulated cost.
    struct MatchInfo {
      uint32_t len;
      uint32_t dist_symbol;
      uint32_t ctx;
      float total_cost = std::numeric_limits<float>::max();
    };
    // Total cost to encode the first N symbols.
    std::vector<MatchInfo> prefix_costs(in.size() + 1);
    prefix_costs[0].total_cost = 0;

    size_t rle_length = 0;
    size_t skip_lz77 = 0;
    for (size_t i = 0; i < in.size(); i++) {
      chain.Update(i);
      // Relax the literal edge i -> i + 1.
      float lit_cost =
          prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
      if (prefix_costs[i + 1].total_cost > lit_cost) {
        prefix_costs[i + 1].dist_symbol = 0;
        prefix_costs[i + 1].len = 1;
        prefix_costs[i + 1].ctx = in[i].context;
        prefix_costs[i + 1].total_cost = lit_cost;
      }
      if (skip_lz77 > 0) {
        skip_lz77--;
        continue;
      }
      // dist_symbols[len]: smallest distance symbol reported for that length.
      dist_symbols.clear();
      chain.FindMatches(i, max_distance,
                        [&dist_symbols](size_t len, size_t dist_symbol) {
                          if (dist_symbols.size() <= len) {
                            dist_symbols.resize(len + 1, dist_symbol);
                          }
                          if (dist_symbol < dist_symbols[len]) {
                            dist_symbols[len] = dist_symbol;
                          }
                        });
      if (dist_symbols.size() <= min_length) continue;
      {
        // A match of length L can also be used at any shorter length, so
        // propagate the smallest distance symbol from longer to shorter
        // lengths.
        size_t best_cost = dist_symbols.back();
        for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
          if (dist_symbols[j] < best_cost) {
            best_cost = dist_symbols[j];
          }
          dist_symbols[j] = best_cost;
        }
      }
      // Relax the match edges i -> i + j for every usable length j.
      for (size_t j = min_length; j < dist_symbols.size(); j++) {
        // Cost model that uses results from lazy LZ77.
        float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
                          sce.DistCost(dist_symbols[j], lz77);
        float cost = prefix_costs[i].total_cost + lz77_cost;
        if (prefix_costs[i + j].total_cost > cost) {
          prefix_costs[i + j].len = j;
          prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
          prefix_costs[i + j].ctx = in[i].context;
          prefix_costs[i + j].total_cost = cost;
        }
      }
      // We are in a RLE sequence: skip all the symbols except the first 8 and
      // the last 8. This avoids quadratic costs for sequences with long runs
      // of the same symbol.
      if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
          (dist_symbols.back() == 1 && distance_multiplier != 0)) {
        rle_length++;
      } else {
        rle_length = 0;
      }
      if (rle_length >= 8 && dist_symbols.size() > 9) {
        skip_lz77 = dist_symbols.size() - 10;
        rle_length = 0;
      }
    }
    // Walk the best path backwards from the end; tokens are emitted in
    // reverse (distance before length) and the whole output is reversed
    // afterwards.
    size_t pos = in.size();
    while (pos > 0) {
      bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
      if (is_lz77_length) {
        size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
        out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
      }
      size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
                                  : in[pos - 1].value;
      out.emplace_back(prefix_costs[pos].ctx, val);
      out.back().is_lz77_length = is_lz77_length;
      pos -= prefix_costs[pos].len;
    }
    std::reverse(out.begin(), out.end());
  }
}

// Dispatches to the LZ77 variant selected by `params.lz77_method`, after
// setting `lz77.min_symbol` (and resetting `lz77.enabled` when
// `params.initialize_global_state` is set). With LZ77Method::kNone,
// `tokens_lz77` is left untouched.
void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
               const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
               std::vector<std::vector<Token>>& tokens_lz77) {
  if (params.initialize_global_state) {
    lz77.enabled = false;
  }
  if (params.force_huffman) {
    lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
  } else {
    lz77.min_symbol = 224;
  }
  if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
    return;
  } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
    ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
  } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
    ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
  } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
    ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
  } else {
    JXL_UNREACHABLE("Not implemented");
  }
}
}  // namespace

// Writes already-built entropy coding data to `writer`: the LZ77 parameters,
// the context map, the prefix-code flag, the hybrid-uint configs, the
// alphabet sizes (prefix codes only) and finally the histograms that were
// previously serialized into `codes.encoded_histograms`.
void EncodeHistograms(const std::vector<uint8_t>& context_map,
                      const EntropyEncodingData& codes, BitWriter* writer,
                      size_t layer, AuxOut* aux_out) {
  BitWriter::Allotment allotment(writer, 128 + kClustersLimit * 136);
  JXL_CHECK(Bundle::Write(codes.lz77, writer, layer, aux_out));
  if (codes.lz77.enabled) {
    EncodeUintConfig(codes.lz77.length_uint_config, writer,
                     /*log_alpha_size=*/8);
  }
  EncodeContextMap(context_map, codes.encoding_info.size(), writer, layer,
                   aux_out);
  writer->Write(1, TO_JXL_BOOL(codes.use_prefix_code));
  size_t log_alpha_size = 8;
  if (codes.use_prefix_code) {
    log_alpha_size = PREFIX_MAX_BITS;
  } else {
    log_alpha_size = 8;  // streaming_mode
    // The ANS log alphabet size is stored in 2 bits as (value - 5).
    writer->Write(2, log_alpha_size - 5);
  }
  EncodeUintConfigs(codes.uint_config, writer, log_alpha_size);
  if (codes.use_prefix_code) {
    for (const auto& info : codes.encoding_info) {
      StoreVarLenUint16(info.size() - 1, writer);
    }
  }
  for (const auto& histo_writer : codes.encoded_histograms) {
    writer->AppendUnaligned(histo_writer);
  }
  allotment.FinishedHistogram(writer);
  allotment.ReclaimAndCharge(writer, layer, aux_out);
}

// Applies LZ77 (replacing `tokens` with the transformed streams if LZ77 ends
// up enabled), builds per-context histograms, decides between prefix codes
// and ANS when `params.initialize_global_state` is set, and builds/stores the
// entropy codes into `codes` and `context_map`. `writer` may be null, in
// which case header sizes are only counted. Returns the accumulated bit count
// (`total_bits`).
size_t BuildAndEncodeHistograms(const HistogramParams& params,
                                size_t num_contexts,
                                std::vector<std::vector<Token>>& tokens,
                                EntropyEncodingData* codes,
                                std::vector<uint8_t>* context_map,
                                BitWriter* writer, size_t layer,
                                AuxOut* aux_out) {
  size_t total_bits = 0;
  // LZ77 distances get their own context, appended after the regular ones.
  codes->lz77.nonserialized_distance_context = num_contexts;
  std::vector<std::vector<Token>> tokens_lz77;
  ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
  if (ans_fuzzer_friendly_) {
    codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
    codes->lz77.min_symbol = 2048;
  }

  const size_t max_contexts = std::min(num_contexts, kClustersLimit);
  BitWriter::Allotment allotment(writer,
                                 128 + num_contexts * 40 + max_contexts * 96);
  if (writer) {
    JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
  } else {
    // No writer: only account for the size of the LZ77 bundle.
    size_t ebits, bits;
    JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
    total_bits += bits;
  }
  if (codes->lz77.enabled) {
    if (writer) {
      size_t b = writer->BitsWritten();
      EncodeUintConfig(codes->lz77.length_uint_config, writer,
                       /*log_alpha_size=*/8);
      total_bits += writer->BitsWritten() - b;
    } else {
      SizeWriter size_writer;
      EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
                       /*log_alpha_size=*/8);
      total_bits += size_writer.size;
    }
    // Account for the distance context and switch to the LZ77 token streams.
    num_contexts += 1;
    tokens = std::move(tokens_lz77);
  }
  size_t total_tokens = 0;
  // Build histograms.
  HistogramBuilder builder(num_contexts);
  HybridUintConfig uint_config;  //  Default config for clustering.
  // Unless we are using the kContextMap histogram option.
  if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
    uint_config = HybridUintConfig(2, 0, 1);
  }
  if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
    uint_config = HybridUintConfig(0, 0, 0);
  }
  if (ans_fuzzer_friendly_) {
    uint_config = HybridUintConfig(10, 0, 0);
  }
  for (const auto& stream : tokens) {
    if (codes->lz77.enabled) {
      for (const auto& token : stream) {
        total_tokens++;
        uint32_t tok, nbits, bits;
        (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
            .Encode(token.value, &tok, &nbits, &bits);
        tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
        builder.VisitSymbol(tok, token.context);
      }
    } else if (num_contexts == 1) {
      // Single context: every token is counted in context 0.
      for (const auto& token : stream) {
        total_tokens++;
        uint32_t tok, nbits, bits;
        uint_config.Encode(token.value, &tok, &nbits, &bits);
        builder.VisitSymbol(tok, /*token.context=*/0);
      }
    } else {
      for (const auto& token : stream) {
        total_tokens++;
        uint32_t tok, nbits, bits;
        uint_config.Encode(token.value, &tok, &nbits, &bits);
        builder.VisitSymbol(tok, token.context);
      }
    }
  }

  if (params.add_missing_symbols) {
    // Count every symbol once in every context so none has a zero count.
    for (size_t c = 0; c < num_contexts; ++c) {
      for (int symbol = 0; symbol < ANS_MAX_ALPHABET_SIZE; ++symbol) {
        builder.VisitSymbol(symbol, c);
      }
    }
  }

  if (params.initialize_global_state) {
    bool use_prefix_code =
        params.force_huffman || total_tokens < 100 ||
        params.clustering == HistogramParams::ClusteringType::kFastest ||
        ans_fuzzer_friendly_;
    if (!use_prefix_code) {
      // If every context has (near-)zero entropy, use prefix codes as well.
      bool all_singleton = true;
      for (size_t i = 0; i < num_contexts; i++) {
        if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
          all_singleton = false;
        }
      }
      if (all_singleton) {
        use_prefix_code = true;
      }
    }
    codes->use_prefix_code = use_prefix_code;
  }

  if (params.add_fixed_histograms) {
    // TODO(szabadka) Add more fixed histograms.
    // TODO(szabadka) Reduce alphabet size by choosing a non-default
    // uint_config.
    const size_t alphabet_size = ANS_MAX_ALPHABET_SIZE;
    const size_t log_alpha_size = 8;
    JXL_ASSERT(alphabet_size == 1u << log_alpha_size);
    std::vector<int32_t> counts =
        CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
    codes->encoding_info.emplace_back();
    codes->encoding_info.back().resize(alphabet_size);
    codes->encoded_histograms.emplace_back();
    BitWriter* histo_writer = &codes->encoded_histograms.back();
    // Note: this allotment intentionally shadows the outer one; it covers
    // only the fixed histogram written to `histo_writer`.
    BitWriter::Allotment allotment(histo_writer, 256 + alphabet_size * 24);
    BuildAndStoreANSEncodingData(
        params.ans_histogram_strategy, counts.data(), alphabet_size,
        log_alpha_size, codes->use_prefix_code,
        codes->encoding_info.back().data(), histo_writer);
    allotment.ReclaimAndCharge(histo_writer, 0, nullptr);
  }

  // Encode histograms.
  total_bits += builder.BuildAndStoreEntropyCodes(
      params, tokens, codes, context_map, writer, layer, aux_out);
  allotment.FinishedHistogram(writer);
  allotment.ReclaimAndCharge(writer, layer, aux_out);

  if (aux_out != nullptr) {
    aux_out->layers[layer].num_clustered_histograms +=
        codes->encoding_info.size();
  }
  return total_bits;
}

// Entropy-codes `tokens` to `writer` using `codes`, looking up each token's
// histogram via `context_map[context_offset + token.context]`. Uses prefix
// codes when `codes.use_prefix_code` is set and ANS otherwise. The caller is
// responsible for the bit allotment. Returns the number of raw extra bits
// written (excluding entropy-coded bits).
size_t WriteTokens(const std::vector<Token>& tokens,
                   const EntropyEncodingData& codes,
                   const std::vector<uint8_t>& context_map,
                   size_t context_offset, BitWriter* writer) {
  size_t num_extra_bits = 0;
  if (codes.use_prefix_code) {
    for (const auto& token : tokens) {
      uint32_t tok, nbits, bits;
      size_t histo = context_map[context_offset + token.context];
      (token.is_lz77_length ? codes.lz77.length_uint_config
                            : codes.uint_config[histo])
          .Encode(token.value, &tok, &nbits, &bits);
      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
      // Combine two calls to the BitWriter. Equivalent to:
      // writer->Write(codes.encoding_info[histo][tok].depth,
      //               codes.encoding_info[histo][tok].bits);
      // writer->Write(nbits, bits);
      uint64_t data = codes.encoding_info[histo][tok].bits;
      data |= static_cast<uint64_t>(bits)
              << codes.encoding_info[histo][tok].depth;
      writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
      num_extra_bits += nbits;
    }
    return num_extra_bits;
  }
  // ANS path: tokens are processed back to front and the produced bits are
  // buffered in chunks, then written out in reverse chunk order after the
  // final ANS state.
  std::vector<uint64_t> out;
  std::vector<uint8_t> out_nbits;
  out.reserve(tokens.size());
  out_nbits.reserve(tokens.size());
  uint64_t allbits = 0;
  size_t numallbits = 0;
  // Writes in *reversed* order.
  auto addbits = [&](size_t bits, size_t nbits) {
    if (JXL_UNLIKELY(nbits)) {
      JXL_DASSERT(bits >> nbits == 0);
      // Flush the accumulator before it would exceed what a single
      // BitWriter::Write call can take.
      if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
        out.push_back(allbits);
        out_nbits.push_back(numallbits);
        numallbits = allbits = 0;
      }
      allbits <<= nbits;
      allbits |= bits;
      numallbits += nbits;
    }
  };
  const int end = tokens.size();
  ANSCoder ans;
  if (codes.lz77.enabled || context_map.size() > 1) {
    for (int i = end - 1; i >= 0; --i) {
      const Token token = tokens[i];
      const uint8_t histo = context_map[context_offset + token.context];
      uint32_t tok, nbits, bits;
      (token.is_lz77_length ? codes.lz77.length_uint_config
                            : codes.uint_config[histo])
          .Encode(tokens[i].value, &tok, &nbits, &bits);
      tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
      const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
      JXL_DASSERT(info.freq_ > 0);
      // Extra bits first as this is reversed.
      addbits(bits, nbits);
      num_extra_bits += nbits;
      uint8_t ans_nbits = 0;
      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
      addbits(ans_bits, ans_nbits);
    }
  } else {
    // Single histogram and no LZ77: skip the context map lookup.
    for (int i = end - 1; i >= 0; --i) {
      uint32_t tok, nbits, bits;
      codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits);
      const ANSEncSymbolInfo& info = codes.encoding_info[0][tok];
      // Extra bits first as this is reversed.
      addbits(bits, nbits);
      num_extra_bits += nbits;
      uint8_t ans_nbits = 0;
      uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
      addbits(ans_bits, ans_nbits);
    }
  }
  const uint32_t state = ans.GetState();
  writer->Write(32, state);
  writer->Write(numallbits, allbits);
  for (int i = out.size(); i > 0; --i) {
    writer->Write(out_nbits[i - 1], out[i - 1]);
  }
  return num_extra_bits;
}

// Convenience overload: reserves a bit allotment on `writer`, writes the
// tokens, charges the used bits to `layer` and records the number of extra
// bits in `aux_out` (if non-null).
void WriteTokens(const std::vector<Token>& tokens,
                 const EntropyEncodingData& codes,
                 const std::vector<uint8_t>& context_map, size_t context_offset,
                 BitWriter* writer, size_t layer, AuxOut* aux_out) {
  // Theoretically, we could have 15 prefix code bits + 31 extra bits.
  BitWriter::Allotment allotment(writer, 46 * tokens.size() + 32 * 1024 * 4);
  size_t num_extra_bits =
      WriteTokens(tokens, codes, context_map, context_offset, writer);
  allotment.ReclaimAndCharge(writer, layer, aux_out);
  if (aux_out != nullptr) {
    aux_out->layers[layer].extra_bits += num_extra_bits;
  }
}

// Sets the file-local fuzzer-friendly flag. Only has an effect when
// JXL_IS_DEBUG_BUILD is set; otherwise the call is a no-op.
void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
#if JXL_IS_DEBUG_BUILD  // Guard against accidental / malicious changes.
  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
#endif
}

// Derives histogram/entropy-coding parameters for modular encoding from the
// compression parameters: clustering type, ANS histogram strategy, LZ77
// method and hybrid-uint method depend on the encoder speed tier, with
// further overrides based on the decoding speed tier and `responsive`.
HistogramParams HistogramParams::ForModular(
    const CompressParams& cparams,
    const std::vector<uint8_t>& extra_dc_precision, bool streaming_mode) {
  HistogramParams params;
  params.streaming_mode = streaming_mode;
  if (cparams.speed_tier > SpeedTier::kKitten) {
    params.clustering = HistogramParams::ClusteringType::kFast;
    params.ans_histogram_strategy =
        cparams.speed_tier > SpeedTier::kThunder
            ? HistogramParams::ANSHistogramStrategy::kFast
            : HistogramParams::ANSHistogramStrategy::kApproximate;
    params.lz77_method =
        cparams.decoding_speed_tier >= 3 && cparams.modular_mode
            ? (cparams.speed_tier >= SpeedTier::kFalcon
                   ? HistogramParams::LZ77Method::kRLE
                   : HistogramParams::LZ77Method::kLZ77)
            : HistogramParams::LZ77Method::kNone;
    // Near-lossless DC, as well as modular mode, require choosing hybrid uint
    // more carefully.
    if ((!extra_dc_precision.empty() && extra_dc_precision[0] != 0) ||
        (cparams.modular_mode && cparams.speed_tier < SpeedTier::kCheetah)) {
      params.uint_method = HistogramParams::HybridUintMethod::kFast;
    } else {
      params.uint_method = HistogramParams::HybridUintMethod::kNone;
    }
  } else if (cparams.speed_tier <= SpeedTier::kTortoise) {
    params.lz77_method = HistogramParams::LZ77Method::kOptimal;
  } else {
    params.lz77_method = HistogramParams::LZ77Method::kLZ77;
  }
  if (cparams.decoding_speed_tier >= 1) {
    params.max_histograms = 12;
  }
  // With a decoding speed tier >= 1 and `responsive` set, the LZ77 method is
  // chosen from the speed tier alone, overriding the choice above.
  if (cparams.decoding_speed_tier >= 1 && cparams.responsive) {
    params.lz77_method = cparams.speed_tier >= SpeedTier::kCheetah
                             ? HistogramParams::LZ77Method::kRLE
                         : cparams.speed_tier >= SpeedTier::kKitten
                             ? HistogramParams::LZ77Method::kLZ77
                             : HistogramParams::LZ77Method::kOptimal;
  }
  if (cparams.decoding_speed_tier >= 2 && cparams.responsive) {
    params.uint_method = HistogramParams::HybridUintMethod::k000;
    params.force_huffman = true;
  }
  return params;
}
}  // namespace jxl