libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

dec_ans.cc (13487B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/dec_ans.h"
      7 
      8 #include <stdint.h>
      9 
     10 #include <vector>
     11 
     12 #include "lib/jxl/ans_common.h"
     13 #include "lib/jxl/ans_params.h"
     14 #include "lib/jxl/base/bits.h"
     15 #include "lib/jxl/base/printf_macros.h"
     16 #include "lib/jxl/base/status.h"
     17 #include "lib/jxl/dec_context_map.h"
     18 #include "lib/jxl/fields.h"
     19 
     20 namespace jxl {
     21 namespace {
     22 
     23 // Decodes a number in the range [0..255], by reading 1 - 11 bits.
     24 inline int DecodeVarLenUint8(BitReader* input) {
     25   if (input->ReadFixedBits<1>()) {
     26     int nbits = static_cast<int>(input->ReadFixedBits<3>());
     27     if (nbits == 0) {
     28       return 1;
     29     } else {
     30       return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits);
     31     }
     32   }
     33   return 0;
     34 }
     35 
     36 // Decodes a number in the range [0..65535], by reading 1 - 21 bits.
     37 inline int DecodeVarLenUint16(BitReader* input) {
     38   if (input->ReadFixedBits<1>()) {
     39     int nbits = static_cast<int>(input->ReadFixedBits<4>());
     40     if (nbits == 0) {
     41       return 1;
     42     } else {
     43       return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits);
     44     }
     45   }
     46   return 0;
     47 }
     48 
     49 Status ReadHistogram(int precision_bits, std::vector<int32_t>* counts,
     50                      BitReader* input) {
     51   int simple_code = input->ReadBits(1);
     52   if (simple_code == 1) {
     53     int i;
     54     int symbols[2] = {0};
     55     int max_symbol = 0;
     56     const int num_symbols = input->ReadBits(1) + 1;
     57     for (i = 0; i < num_symbols; ++i) {
     58       symbols[i] = DecodeVarLenUint8(input);
     59       if (symbols[i] > max_symbol) max_symbol = symbols[i];
     60     }
     61     counts->resize(max_symbol + 1);
     62     if (num_symbols == 1) {
     63       (*counts)[symbols[0]] = 1 << precision_bits;
     64     } else {
     65       if (symbols[0] == symbols[1]) {  // corrupt data
     66         return false;
     67       }
     68       (*counts)[symbols[0]] = input->ReadBits(precision_bits);
     69       (*counts)[symbols[1]] = (1 << precision_bits) - (*counts)[symbols[0]];
     70     }
     71   } else {
     72     int is_flat = input->ReadBits(1);
     73     if (is_flat == 1) {
     74       int alphabet_size = DecodeVarLenUint8(input) + 1;
     75       *counts = CreateFlatHistogram(alphabet_size, 1 << precision_bits);
     76       return true;
     77     }
     78 
     79     uint32_t shift;
     80     {
     81       // TODO(veluca): speed up reading with table lookups.
     82       int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
     83       int log = 0;
     84       for (; log < upper_bound_log; log++) {
     85         if (input->ReadFixedBits<1>() == 0) break;
     86       }
     87       shift = (input->ReadBits(log) | (1 << log)) - 1;
     88       if (shift > ANS_LOG_TAB_SIZE + 1) {
     89         return JXL_FAILURE("Invalid shift value");
     90       }
     91     }
     92 
     93     int length = DecodeVarLenUint8(input) + 3;
     94     counts->resize(length);
     95     int total_count = 0;
     96 
     97     static const uint8_t huff[128][2] = {
     98         {3, 10}, {7, 12}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
     99         {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    100         {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    101         {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    102         {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    103         {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    104         {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    105         {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    106         {3, 10}, {7, 13}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    107         {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    108         {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    109         {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    110         {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    111         {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    112         {3, 10}, {5, 0},  {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5},
    113         {3, 10}, {4, 4},  {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2},
    114     };
    115 
    116     std::vector<int> logcounts(counts->size());
    117     int omit_log = -1;
    118     int omit_pos = -1;
    119     // This array remembers which symbols have an RLE length.
    120     std::vector<int> same(counts->size(), 0);
    121     for (size_t i = 0; i < logcounts.size(); ++i) {
    122       input->Refill();  // for PeekFixedBits + Advance
    123       int idx = input->PeekFixedBits<7>();
    124       input->Consume(huff[idx][0]);
    125       logcounts[i] = huff[idx][1];
    126       // The RLE symbol.
    127       if (logcounts[i] == ANS_LOG_TAB_SIZE + 1) {
    128         int rle_length = DecodeVarLenUint8(input);
    129         same[i] = rle_length + 5;
    130         i += rle_length + 3;
    131         continue;
    132       }
    133       if (logcounts[i] > omit_log) {
    134         omit_log = logcounts[i];
    135         omit_pos = i;
    136       }
    137     }
    138     // Invalid input, e.g. due to invalid usage of RLE.
    139     if (omit_pos < 0) return JXL_FAILURE("Invalid histogram.");
    140     if (static_cast<size_t>(omit_pos) + 1 < logcounts.size() &&
    141         logcounts[omit_pos + 1] == ANS_TAB_SIZE + 1) {
    142       return JXL_FAILURE("Invalid histogram.");
    143     }
    144     int prev = 0;
    145     int numsame = 0;
    146     for (size_t i = 0; i < logcounts.size(); ++i) {
    147       if (same[i]) {
    148         // RLE sequence, let this loop output the same count for the next
    149         // iterations.
    150         numsame = same[i] - 1;
    151         prev = i > 0 ? (*counts)[i - 1] : 0;
    152       }
    153       if (numsame > 0) {
    154         (*counts)[i] = prev;
    155         numsame--;
    156       } else {
    157         unsigned int code = logcounts[i];
    158         // omit_pos may not be negative at this point (checked before).
    159         if (i == static_cast<size_t>(omit_pos)) {
    160           continue;
    161         } else if (code == 0) {
    162           continue;
    163         } else if (code == 1) {
    164           (*counts)[i] = 1;
    165         } else {
    166           int bitcount = GetPopulationCountPrecision(code - 1, shift);
    167           (*counts)[i] = (1u << (code - 1)) +
    168                          (input->ReadBits(bitcount) << (code - 1 - bitcount));
    169         }
    170       }
    171       total_count += (*counts)[i];
    172     }
    173     (*counts)[omit_pos] = (1 << precision_bits) - total_count;
    174     if ((*counts)[omit_pos] <= 0) {
    175       // The histogram we've read sums to more than total_count (including at
    176       // least 1 for the omitted value).
    177       return JXL_FAILURE("Invalid histogram count.");
    178     }
    179   }
    180   return true;
    181 }
    182 
    183 }  // namespace
    184 
    185 Status DecodeANSCodes(const size_t num_histograms,
    186                       const size_t max_alphabet_size, BitReader* in,
    187                       ANSCode* result) {
    188   result->degenerate_symbols.resize(num_histograms, -1);
    189   if (result->use_prefix_code) {
    190     JXL_ASSERT(max_alphabet_size <= 1 << PREFIX_MAX_BITS);
    191     result->huffman_data.resize(num_histograms);
    192     std::vector<uint16_t> alphabet_sizes(num_histograms);
    193     for (size_t c = 0; c < num_histograms; c++) {
    194       alphabet_sizes[c] = DecodeVarLenUint16(in) + 1;
    195       if (alphabet_sizes[c] > max_alphabet_size) {
    196         return JXL_FAILURE("Alphabet size is too long: %u", alphabet_sizes[c]);
    197       }
    198     }
    199     for (size_t c = 0; c < num_histograms; c++) {
    200       if (alphabet_sizes[c] > 1) {
    201         if (!result->huffman_data[c].ReadFromBitStream(alphabet_sizes[c], in)) {
    202           if (!in->AllReadsWithinBounds()) {
    203             return JXL_STATUS(StatusCode::kNotEnoughBytes,
    204                               "Not enough bytes for huffman code");
    205           }
    206           return JXL_FAILURE("Invalid huffman tree number %" PRIuS
    207                              ", alphabet size %u",
    208                              c, alphabet_sizes[c]);
    209         }
    210       } else {
    211         // 0-bit codes does not require extension tables.
    212         result->huffman_data[c].table_.clear();
    213         result->huffman_data[c].table_.resize(1u << kHuffmanTableBits);
    214       }
    215       for (const auto& h : result->huffman_data[c].table_) {
    216         if (h.bits <= kHuffmanTableBits) {
    217           result->UpdateMaxNumBits(c, h.value);
    218         }
    219       }
    220     }
    221   } else {
    222     JXL_ASSERT(max_alphabet_size <= ANS_MAX_ALPHABET_SIZE);
    223     result->alias_tables =
    224         AllocateArray(num_histograms * (1 << result->log_alpha_size) *
    225                       sizeof(AliasTable::Entry));
    226     AliasTable::Entry* alias_tables =
    227         reinterpret_cast<AliasTable::Entry*>(result->alias_tables.get());
    228     for (size_t c = 0; c < num_histograms; ++c) {
    229       std::vector<int32_t> counts;
    230       if (!ReadHistogram(ANS_LOG_TAB_SIZE, &counts, in)) {
    231         return JXL_FAILURE("Invalid histogram bitstream.");
    232       }
    233       if (counts.size() > max_alphabet_size) {
    234         return JXL_FAILURE("Alphabet size is too long: %" PRIuS, counts.size());
    235       }
    236       while (!counts.empty() && counts.back() == 0) {
    237         counts.pop_back();
    238       }
    239       for (size_t s = 0; s < counts.size(); s++) {
    240         if (counts[s] != 0) {
    241           result->UpdateMaxNumBits(c, s);
    242         }
    243       }
    244       // InitAliasTable "fixes" empty counts to contain degenerate "0" symbol.
    245       int degenerate_symbol = counts.empty() ? 0 : (counts.size() - 1);
    246       for (int s = 0; s < degenerate_symbol; ++s) {
    247         if (counts[s] != 0) {
    248           degenerate_symbol = -1;
    249           break;
    250         }
    251       }
    252       result->degenerate_symbols[c] = degenerate_symbol;
    253       InitAliasTable(counts, ANS_TAB_SIZE, result->log_alpha_size,
    254                      alias_tables + c * (1 << result->log_alpha_size));
    255     }
    256   }
    257   return true;
    258 }
    259 Status DecodeUintConfig(size_t log_alpha_size, HybridUintConfig* uint_config,
    260                         BitReader* br) {
    261   br->Refill();
    262   size_t split_exponent = br->ReadBits(CeilLog2Nonzero(log_alpha_size + 1));
    263   size_t msb_in_token = 0;
    264   size_t lsb_in_token = 0;
    265   if (split_exponent != log_alpha_size) {
    266     // otherwise, msb/lsb don't matter.
    267     size_t nbits = CeilLog2Nonzero(split_exponent + 1);
    268     msb_in_token = br->ReadBits(nbits);
    269     if (msb_in_token > split_exponent) {
    270       // This could be invalid here already and we need to check this before
    271       // we use its value to read more bits.
    272       return JXL_FAILURE("Invalid HybridUintConfig");
    273     }
    274     nbits = CeilLog2Nonzero(split_exponent - msb_in_token + 1);
    275     lsb_in_token = br->ReadBits(nbits);
    276   }
    277   if (lsb_in_token + msb_in_token > split_exponent) {
    278     return JXL_FAILURE("Invalid HybridUintConfig");
    279   }
    280   *uint_config = HybridUintConfig(split_exponent, msb_in_token, lsb_in_token);
    281   return true;
    282 }
    283 
    284 Status DecodeUintConfigs(size_t log_alpha_size,
    285                          std::vector<HybridUintConfig>* uint_config,
    286                          BitReader* br) {
    287   // TODO(veluca): RLE?
    288   for (auto& cfg : *uint_config) {
    289     JXL_RETURN_IF_ERROR(DecodeUintConfig(log_alpha_size, &cfg, br));
    290   }
    291   return true;
    292 }
    293 
    294 LZ77Params::LZ77Params() { Bundle::Init(this); }
    295 Status LZ77Params::VisitFields(Visitor* JXL_RESTRICT visitor) {
    296   JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &enabled));
    297   if (!visitor->Conditional(enabled)) return true;
    298   JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(224), Val(512), Val(4096),
    299                                          BitsOffset(15, 8), 224, &min_symbol));
    300   JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(3), Val(4), BitsOffset(2, 5),
    301                                          BitsOffset(8, 9), 3, &min_length));
    302   return true;
    303 }
    304 
    305 void ANSCode::UpdateMaxNumBits(size_t ctx, size_t symbol) {
    306   HybridUintConfig* cfg = &uint_config[ctx];
    307   // LZ77 symbols use a different uint config.
    308   if (lz77.enabled && lz77.nonserialized_distance_context != ctx &&
    309       symbol >= lz77.min_symbol) {
    310     symbol -= lz77.min_symbol;
    311     cfg = &lz77.length_uint_config;
    312   }
    313   size_t split_token = cfg->split_token;
    314   size_t msb_in_token = cfg->msb_in_token;
    315   size_t lsb_in_token = cfg->lsb_in_token;
    316   size_t split_exponent = cfg->split_exponent;
    317   if (symbol < split_token) {
    318     max_num_bits = std::max(max_num_bits, split_exponent);
    319     return;
    320   }
    321   uint32_t n_extra_bits =
    322       split_exponent - (msb_in_token + lsb_in_token) +
    323       ((symbol - split_token) >> (msb_in_token + lsb_in_token));
    324   size_t total_bits = msb_in_token + lsb_in_token + n_extra_bits + 1;
    325   max_num_bits = std::max(max_num_bits, total_bits);
    326 }
    327 
    328 Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code,
    329                         std::vector<uint8_t>* context_map, bool disallow_lz77) {
    330   JXL_RETURN_IF_ERROR(Bundle::Read(br, &code->lz77));
    331   if (code->lz77.enabled) {
    332     num_contexts++;
    333     JXL_RETURN_IF_ERROR(DecodeUintConfig(/*log_alpha_size=*/8,
    334                                          &code->lz77.length_uint_config, br));
    335   }
    336   if (code->lz77.enabled && disallow_lz77) {
    337     return JXL_FAILURE("Using LZ77 when explicitly disallowed");
    338   }
    339   size_t num_histograms = 1;
    340   context_map->resize(num_contexts);
    341   if (num_contexts > 1) {
    342     JXL_RETURN_IF_ERROR(DecodeContextMap(context_map, &num_histograms, br));
    343   }
    344   JXL_DEBUG_V(
    345       4, "Decoded context map of size %" PRIuS " and %" PRIuS " histograms",
    346       num_contexts, num_histograms);
    347   code->lz77.nonserialized_distance_context = context_map->back();
    348   code->use_prefix_code = static_cast<bool>(br->ReadFixedBits<1>());
    349   if (code->use_prefix_code) {
    350     code->log_alpha_size = PREFIX_MAX_BITS;
    351   } else {
    352     code->log_alpha_size = br->ReadFixedBits<2>() + 5;
    353   }
    354   code->uint_config.resize(num_histograms);
    355   JXL_RETURN_IF_ERROR(
    356       DecodeUintConfigs(code->log_alpha_size, &code->uint_config, br));
    357   const size_t max_alphabet_size = 1 << code->log_alpha_size;
    358   JXL_RETURN_IF_ERROR(
    359       DecodeANSCodes(num_histograms, max_alphabet_size, br, code));
    360   return true;
    361 }
    362 
    363 }  // namespace jxl