ans_test.cc (9187B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include <stddef.h> 7 #include <stdint.h> 8 9 #include <vector> 10 11 #include "lib/jxl/ans_params.h" 12 #include "lib/jxl/base/random.h" 13 #include "lib/jxl/base/span.h" 14 #include "lib/jxl/dec_ans.h" 15 #include "lib/jxl/dec_bit_reader.h" 16 #include "lib/jxl/enc_ans.h" 17 #include "lib/jxl/enc_aux_out.h" 18 #include "lib/jxl/enc_bit_writer.h" 19 #include "lib/jxl/testing.h" 20 21 namespace jxl { 22 namespace { 23 24 void RoundtripTestcase(int n_histograms, int alphabet_size, 25 const std::vector<Token>& input_values) { 26 constexpr uint16_t kMagic1 = 0x9e33; 27 constexpr uint16_t kMagic2 = 0x8b04; 28 29 BitWriter writer; 30 // Space for magic bytes. 31 BitWriter::Allotment allotment_magic1(&writer, 16); 32 writer.Write(16, kMagic1); 33 allotment_magic1.ReclaimAndCharge(&writer, 0, nullptr); 34 35 std::vector<uint8_t> context_map; 36 EntropyEncodingData codes; 37 std::vector<std::vector<Token>> input_values_vec; 38 input_values_vec.push_back(input_values); 39 40 BuildAndEncodeHistograms(HistogramParams(), n_histograms, input_values_vec, 41 &codes, &context_map, &writer, 0, nullptr); 42 WriteTokens(input_values_vec[0], codes, context_map, 0, &writer, 0, nullptr); 43 44 // Magic bytes + padding 45 BitWriter::Allotment allotment_magic2(&writer, 24); 46 writer.Write(16, kMagic2); 47 writer.ZeroPadToByte(); 48 allotment_magic2.ReclaimAndCharge(&writer, 0, nullptr); 49 50 // We do not truncate the output. Reading past the end reads out zeroes 51 // anyway. 52 BitReader br(writer.GetSpan()); 53 54 ASSERT_EQ(br.ReadBits(16), kMagic1); 55 56 std::vector<uint8_t> dec_context_map; 57 ANSCode decoded_codes; 58 ASSERT_TRUE( 59 DecodeHistograms(&br, n_histograms, &decoded_codes, &dec_context_map)); 60 ASSERT_EQ(dec_context_map, context_map); 61 ANSSymbolReader reader(&decoded_codes, &br); 62 63 for (const Token& symbol : input_values) { 64 uint32_t read_symbol = 65 reader.ReadHybridUint(symbol.context, &br, dec_context_map); 66 ASSERT_EQ(read_symbol, symbol.value); 67 } 68 ASSERT_TRUE(reader.CheckANSFinalState()); 69 70 ASSERT_EQ(br.ReadBits(16), kMagic2); 71 EXPECT_TRUE(br.Close()); 72 } 73 74 TEST(ANSTest, EmptyRoundtrip) { 75 RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, std::vector<Token>()); 76 } 77 78 TEST(ANSTest, SingleSymbolRoundtrip) { 79 for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) { 80 RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, {{0, i}}); 81 } 82 for (uint32_t i = 0; i < ANS_MAX_ALPHABET_SIZE; i++) { 83 RoundtripTestcase(2, ANS_MAX_ALPHABET_SIZE, 84 std::vector<Token>(1024, {0, i})); 85 } 86 } 87 88 #if defined(ADDRESS_SANITIZER) || defined(MEMORY_SANITIZER) || \ 89 defined(THREAD_SANITIZER) 90 constexpr size_t kReps = 3; 91 #else 92 constexpr size_t kReps = 10; 93 #endif 94 95 void RoundtripRandomStream(int alphabet_size, size_t reps = kReps, 96 size_t num = 1 << 18) { 97 constexpr int kNumHistograms = 3; 98 Rng rng(0); 99 for (size_t i = 0; i < reps; i++) { 100 std::vector<Token> symbols; 101 for (size_t j = 0; j < num; j++) { 102 int context = rng.UniformI(0, kNumHistograms); 103 int value = rng.UniformU(0, alphabet_size); 104 symbols.emplace_back(context, value); 105 } 106 RoundtripTestcase(kNumHistograms, alphabet_size, symbols); 107 } 108 } 109 110 void RoundtripRandomUnbalancedStream(int alphabet_size) { 111 constexpr int kNumHistograms = 3; 112 constexpr int kPrecision = 1 << 10; 113 Rng rng(0); 114 for (size_t i = 0; i < kReps; i++) { 115 std::vector<int> distributions[kNumHistograms] = {}; 116 for (auto& distr : distributions) { 117 distr.resize(kPrecision); 118 int symbol = 0; 119 int remaining = 1; 120 for (int k = 0; k < kPrecision; k++) { 121 if (remaining == 0) { 122 if (symbol < alphabet_size - 1) symbol++; 123 // There is no meaning behind this distribution: it's anything that 124 // will create a nonuniform distribution and won't have too few 125 // symbols usually. Also we want different distributions we get to be 126 // sufficiently dissimilar. 127 remaining = rng.UniformU(0, kPrecision - k + 1); 128 } 129 distr[k] = symbol; 130 remaining--; 131 } 132 } 133 std::vector<Token> symbols; 134 for (int j = 0; j < 1 << 18; j++) { 135 int context = rng.UniformI(0, kNumHistograms); 136 int value = rng.UniformU(0, kPrecision); 137 symbols.emplace_back(context, value); 138 } 139 RoundtripTestcase(kNumHistograms + 1, alphabet_size, symbols); 140 } 141 } 142 143 TEST(ANSTest, RandomStreamRoundtrip3Small) { RoundtripRandomStream(3, 1, 16); } 144 145 TEST(ANSTest, RandomStreamRoundtrip3) { RoundtripRandomStream(3); } 146 147 TEST(ANSTest, RandomStreamRoundtripBig) { 148 RoundtripRandomStream(ANS_MAX_ALPHABET_SIZE); 149 } 150 151 TEST(ANSTest, RandomUnbalancedStreamRoundtrip3) { 152 RoundtripRandomUnbalancedStream(3); 153 } 154 155 TEST(ANSTest, RandomUnbalancedStreamRoundtripBig) { 156 RoundtripRandomUnbalancedStream(ANS_MAX_ALPHABET_SIZE); 157 } 158 159 TEST(ANSTest, UintConfigRoundtrip) { 160 for (size_t log_alpha_size = 5; log_alpha_size <= 8; log_alpha_size++) { 161 std::vector<HybridUintConfig> uint_config; 162 std::vector<HybridUintConfig> uint_config_dec; 163 for (size_t i = 0; i < log_alpha_size; i++) { 164 for (size_t j = 0; j <= i; j++) { 165 for (size_t k = 0; k <= i - j; k++) { 166 uint_config.emplace_back(i, j, k); 167 } 168 } 169 } 170 uint_config.emplace_back(log_alpha_size, 0, 0); 171 uint_config_dec.resize(uint_config.size()); 172 BitWriter writer; 173 BitWriter::Allotment allotment(&writer, 10 * uint_config.size()); 174 EncodeUintConfigs(uint_config, &writer, log_alpha_size); 175 allotment.ReclaimAndCharge(&writer, 0, nullptr); 176 writer.ZeroPadToByte(); 177 BitReader br(writer.GetSpan()); 178 EXPECT_TRUE(DecodeUintConfigs(log_alpha_size, &uint_config_dec, &br)); 179 EXPECT_TRUE(br.Close()); 180 for (size_t i = 0; i < uint_config.size(); i++) { 181 EXPECT_EQ(uint_config[i].split_token, uint_config_dec[i].split_token); 182 EXPECT_EQ(uint_config[i].msb_in_token, uint_config_dec[i].msb_in_token); 183 EXPECT_EQ(uint_config[i].lsb_in_token, uint_config_dec[i].lsb_in_token); 184 } 185 } 186 } 187 188 void TestCheckpointing(bool ans, bool lz77) { 189 std::vector<std::vector<Token>> input_values(1); 190 for (size_t i = 0; i < 1024; i++) { 191 input_values[0].emplace_back(0, i % 4); 192 } 193 // up to lz77 window size. 194 for (size_t i = 0; i < (1 << 20) - 1022; i++) { 195 input_values[0].emplace_back(0, (i % 5) + 4); 196 } 197 // Ensure that when the window wraps around, new values are different. 198 input_values[0].emplace_back(0, 0); 199 for (size_t i = 0; i < 1024; i++) { 200 input_values[0].emplace_back(0, i % 4); 201 } 202 203 std::vector<uint8_t> context_map; 204 EntropyEncodingData codes; 205 HistogramParams params; 206 params.lz77_method = lz77 ? HistogramParams::LZ77Method::kLZ77 207 : HistogramParams::LZ77Method::kNone; 208 params.force_huffman = !ans; 209 210 BitWriter writer; 211 { 212 auto input_values_copy = input_values; 213 BuildAndEncodeHistograms(params, 1, input_values_copy, &codes, &context_map, 214 &writer, 0, nullptr); 215 WriteTokens(input_values_copy[0], codes, context_map, 0, &writer, 0, 216 nullptr); 217 writer.ZeroPadToByte(); 218 } 219 220 // We do not truncate the output. Reading past the end reads out zeroes 221 // anyway. 222 BitReader br(writer.GetSpan()); 223 Status status = true; 224 { 225 BitReaderScopedCloser bc(&br, &status); 226 227 std::vector<uint8_t> dec_context_map; 228 ANSCode decoded_codes; 229 ASSERT_TRUE(DecodeHistograms(&br, 1, &decoded_codes, &dec_context_map)); 230 ASSERT_EQ(dec_context_map, context_map); 231 ANSSymbolReader reader(&decoded_codes, &br); 232 233 ANSSymbolReader::Checkpoint checkpoint; 234 size_t br_pos = 0; 235 constexpr size_t kInterval = ANSSymbolReader::kMaxCheckpointInterval - 2; 236 for (size_t i = 0; i < input_values[0].size(); i++) { 237 if (i % kInterval == 0 && i > 0) { 238 reader.Restore(checkpoint); 239 ASSERT_TRUE(br.Close()); 240 br = BitReader(writer.GetSpan()); 241 br.SkipBits(br_pos); 242 for (size_t j = i - kInterval; j < i; j++) { 243 Token symbol = input_values[0][j]; 244 uint32_t read_symbol = 245 reader.ReadHybridUint(symbol.context, &br, dec_context_map); 246 ASSERT_EQ(read_symbol, symbol.value) << "j = " << j; 247 } 248 } 249 if (i % kInterval == 0) { 250 reader.Save(&checkpoint); 251 br_pos = br.TotalBitsConsumed(); 252 } 253 Token symbol = input_values[0][i]; 254 uint32_t read_symbol = 255 reader.ReadHybridUint(symbol.context, &br, dec_context_map); 256 ASSERT_EQ(read_symbol, symbol.value) << "i = " << i; 257 } 258 ASSERT_TRUE(reader.CheckANSFinalState()); 259 } 260 EXPECT_TRUE(status); 261 } 262 263 TEST(ANSTest, TestCheckpointingANS) { 264 TestCheckpointing(/*ans=*/true, /*lz77=*/false); 265 } 266 267 TEST(ANSTest, TestCheckpointingPrefix) { 268 TestCheckpointing(/*ans=*/false, /*lz77=*/false); 269 } 270 271 TEST(ANSTest, TestCheckpointingANSLZ77) { 272 TestCheckpointing(/*ans=*/true, /*lz77=*/true); 273 } 274 275 TEST(ANSTest, TestCheckpointingPrefixLZ77) { 276 TestCheckpointing(/*ans=*/false, /*lz77=*/true); 277 } 278 279 } // namespace 280 } // namespace jxl