dec_ans.cc (13487B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/dec_ans.h" 7 8 #include <stdint.h> 9 10 #include <vector> 11 12 #include "lib/jxl/ans_common.h" 13 #include "lib/jxl/ans_params.h" 14 #include "lib/jxl/base/bits.h" 15 #include "lib/jxl/base/printf_macros.h" 16 #include "lib/jxl/base/status.h" 17 #include "lib/jxl/dec_context_map.h" 18 #include "lib/jxl/fields.h" 19 20 namespace jxl { 21 namespace { 22 23 // Decodes a number in the range [0..255], by reading 1 - 11 bits. 24 inline int DecodeVarLenUint8(BitReader* input) { 25 if (input->ReadFixedBits<1>()) { 26 int nbits = static_cast<int>(input->ReadFixedBits<3>()); 27 if (nbits == 0) { 28 return 1; 29 } else { 30 return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits); 31 } 32 } 33 return 0; 34 } 35 36 // Decodes a number in the range [0..65535], by reading 1 - 21 bits. 37 inline int DecodeVarLenUint16(BitReader* input) { 38 if (input->ReadFixedBits<1>()) { 39 int nbits = static_cast<int>(input->ReadFixedBits<4>()); 40 if (nbits == 0) { 41 return 1; 42 } else { 43 return static_cast<int>(input->ReadBits(nbits)) + (1 << nbits); 44 } 45 } 46 return 0; 47 } 48 49 Status ReadHistogram(int precision_bits, std::vector<int32_t>* counts, 50 BitReader* input) { 51 int simple_code = input->ReadBits(1); 52 if (simple_code == 1) { 53 int i; 54 int symbols[2] = {0}; 55 int max_symbol = 0; 56 const int num_symbols = input->ReadBits(1) + 1; 57 for (i = 0; i < num_symbols; ++i) { 58 symbols[i] = DecodeVarLenUint8(input); 59 if (symbols[i] > max_symbol) max_symbol = symbols[i]; 60 } 61 counts->resize(max_symbol + 1); 62 if (num_symbols == 1) { 63 (*counts)[symbols[0]] = 1 << precision_bits; 64 } else { 65 if (symbols[0] == symbols[1]) { // corrupt data 66 return false; 67 } 68 (*counts)[symbols[0]] = input->ReadBits(precision_bits); 69 (*counts)[symbols[1]] = (1 << precision_bits) - (*counts)[symbols[0]]; 70 } 71 } else { 72 int is_flat = input->ReadBits(1); 73 if (is_flat == 1) { 74 int alphabet_size = DecodeVarLenUint8(input) + 1; 75 *counts = CreateFlatHistogram(alphabet_size, 1 << precision_bits); 76 return true; 77 } 78 79 uint32_t shift; 80 { 81 // TODO(veluca): speed up reading with table lookups. 82 int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1); 83 int log = 0; 84 for (; log < upper_bound_log; log++) { 85 if (input->ReadFixedBits<1>() == 0) break; 86 } 87 shift = (input->ReadBits(log) | (1 << log)) - 1; 88 if (shift > ANS_LOG_TAB_SIZE + 1) { 89 return JXL_FAILURE("Invalid shift value"); 90 } 91 } 92 93 int length = DecodeVarLenUint8(input) + 3; 94 counts->resize(length); 95 int total_count = 0; 96 97 static const uint8_t huff[128][2] = { 98 {3, 10}, {7, 12}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 99 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 100 {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 101 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 102 {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 103 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 104 {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 105 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 106 {3, 10}, {7, 13}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 107 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 108 {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 109 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 110 {3, 10}, {6, 11}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 111 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 112 {3, 10}, {5, 0}, {3, 7}, {4, 3}, {3, 6}, {3, 8}, {3, 9}, {4, 5}, 113 {3, 10}, {4, 4}, {3, 7}, {4, 1}, {3, 6}, {3, 8}, {3, 9}, {4, 2}, 114 }; 115 116 std::vector<int> logcounts(counts->size()); 117 int omit_log = -1; 118 int omit_pos = -1; 119 // This array remembers which symbols have an RLE length. 120 std::vector<int> same(counts->size(), 0); 121 for (size_t i = 0; i < logcounts.size(); ++i) { 122 input->Refill(); // for PeekFixedBits + Advance 123 int idx = input->PeekFixedBits<7>(); 124 input->Consume(huff[idx][0]); 125 logcounts[i] = huff[idx][1]; 126 // The RLE symbol. 127 if (logcounts[i] == ANS_LOG_TAB_SIZE + 1) { 128 int rle_length = DecodeVarLenUint8(input); 129 same[i] = rle_length + 5; 130 i += rle_length + 3; 131 continue; 132 } 133 if (logcounts[i] > omit_log) { 134 omit_log = logcounts[i]; 135 omit_pos = i; 136 } 137 } 138 // Invalid input, e.g. due to invalid usage of RLE. 139 if (omit_pos < 0) return JXL_FAILURE("Invalid histogram."); 140 if (static_cast<size_t>(omit_pos) + 1 < logcounts.size() && 141 logcounts[omit_pos + 1] == ANS_TAB_SIZE + 1) { 142 return JXL_FAILURE("Invalid histogram."); 143 } 144 int prev = 0; 145 int numsame = 0; 146 for (size_t i = 0; i < logcounts.size(); ++i) { 147 if (same[i]) { 148 // RLE sequence, let this loop output the same count for the next 149 // iterations. 150 numsame = same[i] - 1; 151 prev = i > 0 ? (*counts)[i - 1] : 0; 152 } 153 if (numsame > 0) { 154 (*counts)[i] = prev; 155 numsame--; 156 } else { 157 unsigned int code = logcounts[i]; 158 // omit_pos may not be negative at this point (checked before). 159 if (i == static_cast<size_t>(omit_pos)) { 160 continue; 161 } else if (code == 0) { 162 continue; 163 } else if (code == 1) { 164 (*counts)[i] = 1; 165 } else { 166 int bitcount = GetPopulationCountPrecision(code - 1, shift); 167 (*counts)[i] = (1u << (code - 1)) + 168 (input->ReadBits(bitcount) << (code - 1 - bitcount)); 169 } 170 } 171 total_count += (*counts)[i]; 172 } 173 (*counts)[omit_pos] = (1 << precision_bits) - total_count; 174 if ((*counts)[omit_pos] <= 0) { 175 // The histogram we've read sums to more than total_count (including at 176 // least 1 for the omitted value). 177 return JXL_FAILURE("Invalid histogram count."); 178 } 179 } 180 return true; 181 } 182 183 } // namespace 184 185 Status DecodeANSCodes(const size_t num_histograms, 186 const size_t max_alphabet_size, BitReader* in, 187 ANSCode* result) { 188 result->degenerate_symbols.resize(num_histograms, -1); 189 if (result->use_prefix_code) { 190 JXL_ASSERT(max_alphabet_size <= 1 << PREFIX_MAX_BITS); 191 result->huffman_data.resize(num_histograms); 192 std::vector<uint16_t> alphabet_sizes(num_histograms); 193 for (size_t c = 0; c < num_histograms; c++) { 194 alphabet_sizes[c] = DecodeVarLenUint16(in) + 1; 195 if (alphabet_sizes[c] > max_alphabet_size) { 196 return JXL_FAILURE("Alphabet size is too long: %u", alphabet_sizes[c]); 197 } 198 } 199 for (size_t c = 0; c < num_histograms; c++) { 200 if (alphabet_sizes[c] > 1) { 201 if (!result->huffman_data[c].ReadFromBitStream(alphabet_sizes[c], in)) { 202 if (!in->AllReadsWithinBounds()) { 203 return JXL_STATUS(StatusCode::kNotEnoughBytes, 204 "Not enough bytes for huffman code"); 205 } 206 return JXL_FAILURE("Invalid huffman tree number %" PRIuS 207 ", alphabet size %u", 208 c, alphabet_sizes[c]); 209 } 210 } else { 211 // 0-bit codes does not require extension tables. 212 result->huffman_data[c].table_.clear(); 213 result->huffman_data[c].table_.resize(1u << kHuffmanTableBits); 214 } 215 for (const auto& h : result->huffman_data[c].table_) { 216 if (h.bits <= kHuffmanTableBits) { 217 result->UpdateMaxNumBits(c, h.value); 218 } 219 } 220 } 221 } else { 222 JXL_ASSERT(max_alphabet_size <= ANS_MAX_ALPHABET_SIZE); 223 result->alias_tables = 224 AllocateArray(num_histograms * (1 << result->log_alpha_size) * 225 sizeof(AliasTable::Entry)); 226 AliasTable::Entry* alias_tables = 227 reinterpret_cast<AliasTable::Entry*>(result->alias_tables.get()); 228 for (size_t c = 0; c < num_histograms; ++c) { 229 std::vector<int32_t> counts; 230 if (!ReadHistogram(ANS_LOG_TAB_SIZE, &counts, in)) { 231 return JXL_FAILURE("Invalid histogram bitstream."); 232 } 233 if (counts.size() > max_alphabet_size) { 234 return JXL_FAILURE("Alphabet size is too long: %" PRIuS, counts.size()); 235 } 236 while (!counts.empty() && counts.back() == 0) { 237 counts.pop_back(); 238 } 239 for (size_t s = 0; s < counts.size(); s++) { 240 if (counts[s] != 0) { 241 result->UpdateMaxNumBits(c, s); 242 } 243 } 244 // InitAliasTable "fixes" empty counts to contain degenerate "0" symbol. 245 int degenerate_symbol = counts.empty() ? 0 : (counts.size() - 1); 246 for (int s = 0; s < degenerate_symbol; ++s) { 247 if (counts[s] != 0) { 248 degenerate_symbol = -1; 249 break; 250 } 251 } 252 result->degenerate_symbols[c] = degenerate_symbol; 253 InitAliasTable(counts, ANS_TAB_SIZE, result->log_alpha_size, 254 alias_tables + c * (1 << result->log_alpha_size)); 255 } 256 } 257 return true; 258 } 259 Status DecodeUintConfig(size_t log_alpha_size, HybridUintConfig* uint_config, 260 BitReader* br) { 261 br->Refill(); 262 size_t split_exponent = br->ReadBits(CeilLog2Nonzero(log_alpha_size + 1)); 263 size_t msb_in_token = 0; 264 size_t lsb_in_token = 0; 265 if (split_exponent != log_alpha_size) { 266 // otherwise, msb/lsb don't matter. 267 size_t nbits = CeilLog2Nonzero(split_exponent + 1); 268 msb_in_token = br->ReadBits(nbits); 269 if (msb_in_token > split_exponent) { 270 // This could be invalid here already and we need to check this before 271 // we use its value to read more bits. 272 return JXL_FAILURE("Invalid HybridUintConfig"); 273 } 274 nbits = CeilLog2Nonzero(split_exponent - msb_in_token + 1); 275 lsb_in_token = br->ReadBits(nbits); 276 } 277 if (lsb_in_token + msb_in_token > split_exponent) { 278 return JXL_FAILURE("Invalid HybridUintConfig"); 279 } 280 *uint_config = HybridUintConfig(split_exponent, msb_in_token, lsb_in_token); 281 return true; 282 } 283 284 Status DecodeUintConfigs(size_t log_alpha_size, 285 std::vector<HybridUintConfig>* uint_config, 286 BitReader* br) { 287 // TODO(veluca): RLE? 288 for (auto& cfg : *uint_config) { 289 JXL_RETURN_IF_ERROR(DecodeUintConfig(log_alpha_size, &cfg, br)); 290 } 291 return true; 292 } 293 294 LZ77Params::LZ77Params() { Bundle::Init(this); } 295 Status LZ77Params::VisitFields(Visitor* JXL_RESTRICT visitor) { 296 JXL_QUIET_RETURN_IF_ERROR(visitor->Bool(false, &enabled)); 297 if (!visitor->Conditional(enabled)) return true; 298 JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(224), Val(512), Val(4096), 299 BitsOffset(15, 8), 224, &min_symbol)); 300 JXL_QUIET_RETURN_IF_ERROR(visitor->U32(Val(3), Val(4), BitsOffset(2, 5), 301 BitsOffset(8, 9), 3, &min_length)); 302 return true; 303 } 304 305 void ANSCode::UpdateMaxNumBits(size_t ctx, size_t symbol) { 306 HybridUintConfig* cfg = &uint_config[ctx]; 307 // LZ77 symbols use a different uint config. 308 if (lz77.enabled && lz77.nonserialized_distance_context != ctx && 309 symbol >= lz77.min_symbol) { 310 symbol -= lz77.min_symbol; 311 cfg = &lz77.length_uint_config; 312 } 313 size_t split_token = cfg->split_token; 314 size_t msb_in_token = cfg->msb_in_token; 315 size_t lsb_in_token = cfg->lsb_in_token; 316 size_t split_exponent = cfg->split_exponent; 317 if (symbol < split_token) { 318 max_num_bits = std::max(max_num_bits, split_exponent); 319 return; 320 } 321 uint32_t n_extra_bits = 322 split_exponent - (msb_in_token + lsb_in_token) + 323 ((symbol - split_token) >> (msb_in_token + lsb_in_token)); 324 size_t total_bits = msb_in_token + lsb_in_token + n_extra_bits + 1; 325 max_num_bits = std::max(max_num_bits, total_bits); 326 } 327 328 Status DecodeHistograms(BitReader* br, size_t num_contexts, ANSCode* code, 329 std::vector<uint8_t>* context_map, bool disallow_lz77) { 330 JXL_RETURN_IF_ERROR(Bundle::Read(br, &code->lz77)); 331 if (code->lz77.enabled) { 332 num_contexts++; 333 JXL_RETURN_IF_ERROR(DecodeUintConfig(/*log_alpha_size=*/8, 334 &code->lz77.length_uint_config, br)); 335 } 336 if (code->lz77.enabled && disallow_lz77) { 337 return JXL_FAILURE("Using LZ77 when explicitly disallowed"); 338 } 339 size_t num_histograms = 1; 340 context_map->resize(num_contexts); 341 if (num_contexts > 1) { 342 JXL_RETURN_IF_ERROR(DecodeContextMap(context_map, &num_histograms, br)); 343 } 344 JXL_DEBUG_V( 345 4, "Decoded context map of size %" PRIuS " and %" PRIuS " histograms", 346 num_contexts, num_histograms); 347 code->lz77.nonserialized_distance_context = context_map->back(); 348 code->use_prefix_code = static_cast<bool>(br->ReadFixedBits<1>()); 349 if (code->use_prefix_code) { 350 code->log_alpha_size = PREFIX_MAX_BITS; 351 } else { 352 code->log_alpha_size = br->ReadFixedBits<2>() + 5; 353 } 354 code->uint_config.resize(num_histograms); 355 JXL_RETURN_IF_ERROR( 356 DecodeUintConfigs(code->log_alpha_size, &code->uint_config, br)); 357 const size_t max_alphabet_size = 1 << code->log_alpha_size; 358 JXL_RETURN_IF_ERROR( 359 DecodeANSCodes(num_histograms, max_alphabet_size, br, code)); 360 return true; 361 } 362 363 } // namespace jxl