libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

ans_test.cc (9187B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include <stddef.h>
      7 #include <stdint.h>
      8 
      9 #include <vector>
     10 
     11 #include "lib/jxl/ans_params.h"
     12 #include "lib/jxl/base/random.h"
     13 #include "lib/jxl/base/span.h"
     14 #include "lib/jxl/dec_ans.h"
     15 #include "lib/jxl/dec_bit_reader.h"
     16 #include "lib/jxl/enc_ans.h"
     17 #include "lib/jxl/enc_aux_out.h"
     18 #include "lib/jxl/enc_bit_writer.h"
     19 #include "lib/jxl/testing.h"
     20 
     21 namespace jxl {
     22 namespace {
     23 
     24 void RoundtripTestcase(int n_histograms, int alphabet_size,
     25                        const std::vector<Token>& input_values) {
     26   constexpr uint16_t kMagic1 = 0x9e33;
     27   constexpr uint16_t kMagic2 = 0x8b04;
     28 
     29   BitWriter writer;
     30   // Space for magic bytes.
     31   BitWriter::Allotment allotment_magic1(&writer, 16);
     32   writer.Write(16, kMagic1);
     33   allotment_magic1.ReclaimAndCharge(&writer, 0, nullptr);
     34 
     35   std::vector<uint8_t> context_map;
     36   EntropyEncodingData codes;
     37   std::vector<std::vector<Token>> input_values_vec;
     38   input_values_vec.push_back(input_values);
     39 
     40   BuildAndEncodeHistograms(HistogramParams(), n_histograms, input_values_vec,
     41                            &codes, &context_map, &writer, 0, nullptr);
     42   WriteTokens(input_values_vec[0], codes, context_map, 0, &writer, 0, nullptr);
     43 
     44   // Magic bytes + padding
     45   BitWriter::Allotment allotment_magic2(&writer, 24);
     46   writer.Write(16, kMagic2);
     47   writer.ZeroPadToByte();
     48   allotment_magic2.ReclaimAndCharge(&writer, 0, nullptr);
     49 
     50   // We do not truncate the output. Reading past the end reads out zeroes
     51   // anyway.
     52   BitReader br(writer.GetSpan());
     53 
     54   ASSERT_EQ(br.ReadBits(16), kMagic1);
     55 
     56   std::vector<uint8_t> dec_context_map;
     57   ANSCode decoded_codes;
     58   ASSERT_TRUE(
     59       DecodeHistograms(&br, n_histograms, &decoded_codes, &dec_context_map));
     60   ASSERT_EQ(dec_context_map, context_map);
     61   ANSSymbolReader reader(&decoded_codes, &br);
     62 
     63   for (const Token& symbol : input_values) {
     64     uint32_t read_symbol =
     65         reader.ReadHybridUint(symbol.context, &br, dec_context_map);
     66     ASSERT_EQ(read_symbol, symbol.value);
     67   }
     68   ASSERT_TRUE(reader.CheckANSFinalState());
     69 
     70   ASSERT_EQ(br.ReadBits(16), kMagic2);
     71   EXPECT_TRUE(br.Close());
     72 }
     73 
     74 TEST(ANSTest, EmptyRoundtrip) {
     75   RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, std::vector<Token>());
     76 }
     77 
     78 TEST(ANSTest, SingleSymbolRoundtrip) {
     79   for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) {
     80     RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, {{0, i}});
     81   }
     82   for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) {
     83     RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE,
     84                       std::vector<Token>(1024, {0, i}));
     85   }
     86 }
     87 
     88 #if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \
     89     defined(THREAD_SANITIZER)
     90 constexpr size_t kReps = 3;
     91 #else
     92 constexpr size_t kReps = 10;
     93 #endif
     94 
     95 void RoundtripRandomStream(int alphabet_size, size_t reps = kReps,
     96                            size_t num = 1 << 18) {
     97   constexpr int kNumHistograms = 3;
     98   Rng rng(0);
     99   for (size_t i = 0; i < reps; i++) {
    100     std::vector<Token> symbols;
    101     for (size_t j = 0; j < num; j++) {
    102       int context = rng.UniformI(0, kNumHistograms);
    103       int value = rng.UniformU(0, alphabet_size);
    104       symbols.emplace_back(context, value);
    105     }
    106     RoundtripTestcase(kNumHistograms, alphabet_size, symbols);
    107   }
    108 }
    109 
    110 void RoundtripRandomUnbalancedStream(int alphabet_size) {
    111   constexpr int kNumHistograms = 3;
    112   constexpr int kPrecision = 1 << 10;
    113   Rng rng(0);
    114   for (size_t i = 0; i < kReps; i++) {
    115     std::vector<int> distributions[kNumHistograms] = {};
    116     for (auto& distr : distributions) {
    117       distr.resize(kPrecision);
    118       int symbol = 0;
    119       int remaining = 1;
    120       for (int k = 0; k < kPrecision; k++) {
    121         if (remaining == 0) {
    122           if (symbol < alphabet_size - 1) symbol++;
    123           // There is no meaning behind this distribution: it's anything that
    124           // will create a nonuniform distribution and won't have too few
    125           // symbols usually. Also we want different distributions we get to be
    126           // sufficiently dissimilar.
    127           remaining = rng.UniformU(0, kPrecision - k + 1);
    128         }
    129         distr[k] = symbol;
    130         remaining--;
    131       }
    132     }
    133     std::vector<Token> symbols;
    134     for (int j = 0; j < 1 << 18; j++) {
    135       int context = rng.UniformI(0, kNumHistograms);
    136       int value = rng.UniformU(0, kPrecision);
    137       symbols.emplace_back(context, value);
    138     }
    139     RoundtripTestcase(kNumHistograms + 1, alphabet_size, symbols);
    140   }
    141 }
    142 
    143 TEST(ANSTest, RandomStreamRoundtrip3Small) { RoundtripRandomStream(3, 1, 16); }
    144 
    145 TEST(ANSTest, RandomStreamRoundtrip3) { RoundtripRandomStream(3); }
    146 
    147 TEST(ANSTest, RandomStreamRoundtripBig) {
    148   RoundtripRandomStream(ANS_MAX_ALPHABET_SIZE);
    149 }
    150 
    151 TEST(ANSTest, RandomUnbalancedStreamRoundtrip3) {
    152   RoundtripRandomUnbalancedStream(3);
    153 }
    154 
    155 TEST(ANSTest, RandomUnbalancedStreamRoundtripBig) {
    156   RoundtripRandomUnbalancedStream(ANS_MAX_ALPHABET_SIZE);
    157 }
    158 
    159 TEST(ANSTest, UintConfigRoundtrip) {
    160   for (size_t log_alpha_size = 5; log_alpha_size <= 8; log_alpha_size++) {
    161     std::vector<HybridUintConfig> uint_config;
    162     std::vector<HybridUintConfig> uint_config_dec;
    163     for (size_t i = 0; i < log_alpha_size; i++) {
    164       for (size_t j = 0; j <= i; j++) {
    165         for (size_t k = 0; k <= i - j; k++) {
    166           uint_config.emplace_back(i, j, k);
    167         }
    168       }
    169     }
    170     uint_config.emplace_back(log_alpha_size, 0, 0);
    171     uint_config_dec.resize(uint_config.size());
    172     BitWriter writer;
    173     BitWriter::Allotment allotment(&writer, 10 * uint_config.size());
    174     EncodeUintConfigs(uint_config, &writer, log_alpha_size);
    175     allotment.ReclaimAndCharge(&writer, 0, nullptr);
    176     writer.ZeroPadToByte();
    177     BitReader br(writer.GetSpan());
    178     EXPECT_TRUE(DecodeUintConfigs(log_alpha_size, &uint_config_dec, &br));
    179     EXPECT_TRUE(br.Close());
    180     for (size_t i = 0; i < uint_config.size(); i++) {
    181       EXPECT_EQ(uint_config[i].split_token, uint_config_dec[i].split_token);
    182       EXPECT_EQ(uint_config[i].msb_in_token, uint_config_dec[i].msb_in_token);
    183       EXPECT_EQ(uint_config[i].lsb_in_token, uint_config_dec[i].lsb_in_token);
    184     }
    185   }
    186 }
    187 
// Checks ANSSymbolReader::Save/Restore. A single-context token stream is
// encoded, then decoded front to back. Every kInterval symbols the reader
// state and the bit position are saved. At each later multiple of kInterval
// the reader is rolled back to that checkpoint, the bit reader is reopened at
// the saved position, and the previous kInterval symbols are decoded again
// and compared.
//
// `ans` == false sets HistogramParams::force_huffman (prefix codes instead of
// ANS). `lz77` selects LZ77Method::kLZ77 instead of kNone.
void TestCheckpointing(bool ans, bool lz77) {
  std::vector<std::vector<Token>> input_values(1);
  // 1024 tokens cycling through values 0..3.
  for (size_t i = 0; i < 1024; i++) {
    input_values[0].emplace_back(0, i % 4);
  }
  // up to lz77 window size.
  // (Tokens cycling through values 4..8. NOTE(review): presumably 1 << 20 is
  // the LZ77 window size, so the repeated 0..3 pattern below falls at the
  // window edge — confirm.)
  for (size_t i = 0; i < (1 << 20) - 1022; i++) {
    input_values[0].emplace_back(0, (i % 5) + 4);
  }
  // Ensure that when the window wraps around, new values are different.
  input_values[0].emplace_back(0, 0);
  // Another 1024 tokens repeating the initial 0..3 pattern.
  for (size_t i = 0; i < 1024; i++) {
    input_values[0].emplace_back(0, i % 4);
  }

  std::vector<uint8_t> context_map;
  EntropyEncodingData codes;
  HistogramParams params;
  params.lz77_method = lz77 ? HistogramParams::LZ77Method::kLZ77
                            : HistogramParams::LZ77Method::kNone;
  params.force_huffman = !ans;

  BitWriter writer;
  {
    // The encoder works on a copy; `input_values` stays untouched for the
    // comparisons below. (Presumably BuildAndEncodeHistograms may rewrite its
    // input, e.g. for LZ77 — confirm.)
    auto input_values_copy = input_values;
    BuildAndEncodeHistograms(params, 1, input_values_copy, &codes, &context_map,
                             &writer, 0, nullptr);
    WriteTokens(input_values_copy[0], codes, context_map, 0, &writer, 0,
                nullptr);
    writer.ZeroPadToByte();
  }

  // We do not truncate the output. Reading past the end reads out zeroes
  // anyway.
  BitReader br(writer.GetSpan());
  Status status = true;
  {
    // Presumably closes `br` when this scope exits (including early returns
    // from a failing ASSERT_*) and stores a failed Close() in `status`.
    BitReaderScopedCloser bc(&br, &status);

    std::vector<uint8_t> dec_context_map;
    ANSCode decoded_codes;
    ASSERT_TRUE(DecodeHistograms(&br, 1, &decoded_codes, &dec_context_map));
    ASSERT_EQ(dec_context_map, context_map);
    ANSSymbolReader reader(&decoded_codes, &br);

    ANSSymbolReader::Checkpoint checkpoint;
    // Bits consumed by `br` at the moment `checkpoint` was saved.
    size_t br_pos = 0;
    // Stays 2 below the reader's kMaxCheckpointInterval.
    constexpr size_t kInterval = ANSSymbolReader::kMaxCheckpointInterval - 2;
    for (size_t i = 0; i < input_values[0].size(); i++) {
      if (i % kInterval == 0 && i > 0) {
        // Roll back to the checkpoint saved kInterval symbols ago. Close the
        // current bit reader before replacing it, then rebuild it from the
        // start of the buffer and skip to the saved bit position.
        reader.Restore(checkpoint);
        ASSERT_TRUE(br.Close());
        br = BitReader(writer.GetSpan());
        br.SkipBits(br_pos);
        // Re-decode the interval that was just read and check it again.
        for (size_t j = i - kInterval; j < i; j++) {
          Token symbol = input_values[0][j];
          uint32_t read_symbol =
              reader.ReadHybridUint(symbol.context, &br, dec_context_map);
          ASSERT_EQ(read_symbol, symbol.value) << "j = " << j;
        }
      }
      // Save a fresh checkpoint at every interval boundary, including i == 0.
      if (i % kInterval == 0) {
        reader.Save(&checkpoint);
        br_pos = br.TotalBitsConsumed();
      }
      Token symbol = input_values[0][i];
      uint32_t read_symbol =
          reader.ReadHybridUint(symbol.context, &br, dec_context_map);
      ASSERT_EQ(read_symbol, symbol.value) << "i = " << i;
    }
    ASSERT_TRUE(reader.CheckANSFinalState());
  }
  EXPECT_TRUE(status);
}
    262 
    263 TEST(ANSTest, TestCheckpointingANS) {
    264   TestCheckpointing(/*ans=*/true, /*lz77=*/false);
    265 }
    266 
    267 TEST(ANSTest, TestCheckpointingPrefix) {
    268   TestCheckpointing(/*ans=*/false, /*lz77=*/false);
    269 }
    270 
    271 TEST(ANSTest, TestCheckpointingANSLZ77) {
    272   TestCheckpointing(/*ans=*/true, /*lz77=*/true);
    273 }
    274 
    275 TEST(ANSTest, TestCheckpointingPrefixLZ77) {
    276   TestCheckpointing(/*ans=*/false, /*lz77=*/true);
    277 }
    278 
    279 }  // namespace
    280 }  // namespace jxl