dec_context_map.cc (2796B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/dec_context_map.h" 7 8 #include <algorithm> 9 #include <cstdint> 10 #include <vector> 11 12 #include "lib/jxl/ans_params.h" 13 #include "lib/jxl/base/status.h" 14 #include "lib/jxl/dec_ans.h" 15 #include "lib/jxl/entropy_coder.h" 16 #include "lib/jxl/inverse_mtf-inl.h" 17 18 namespace jxl { 19 20 namespace { 21 22 Status VerifyContextMap(const std::vector<uint8_t>& context_map, 23 const size_t num_htrees) { 24 std::vector<bool> have_htree(num_htrees); 25 size_t num_found = 0; 26 for (const uint8_t htree : context_map) { 27 if (htree >= num_htrees) { 28 return JXL_FAILURE("Invalid histogram index in context map."); 29 } 30 if (!have_htree[htree]) { 31 have_htree[htree] = true; 32 ++num_found; 33 } 34 } 35 if (num_found != num_htrees) { 36 return JXL_FAILURE("Incomplete context map."); 37 } 38 return true; 39 } 40 41 } // namespace 42 43 Status DecodeContextMap(std::vector<uint8_t>* context_map, size_t* num_htrees, 44 BitReader* input) { 45 bool is_simple = static_cast<bool>(input->ReadFixedBits<1>()); 46 if (is_simple) { 47 int bits_per_entry = input->ReadFixedBits<2>(); 48 if (bits_per_entry != 0) { 49 for (uint8_t& entry : *context_map) { 50 entry = input->ReadBits(bits_per_entry); 51 } 52 } else { 53 std::fill(context_map->begin(), context_map->end(), 0); 54 } 55 } else { 56 bool use_mtf = static_cast<bool>(input->ReadFixedBits<1>()); 57 ANSCode code; 58 std::vector<uint8_t> sink_ctx_map; 59 // Usage of LZ77 is disallowed if decoding only two symbols. This doesn't 60 // make sense in non-malicious bitstreams, and could cause a stack overflow 61 // in malicious bitstreams by making every context map require its own 62 // context map. 63 JXL_RETURN_IF_ERROR( 64 DecodeHistograms(input, 1, &code, &sink_ctx_map, 65 /*disallow_lz77=*/context_map->size() <= 2)); 66 ANSSymbolReader reader(&code, input); 67 size_t i = 0; 68 uint32_t maxsym = 0; 69 while (i < context_map->size()) { 70 uint32_t sym = reader.ReadHybridUintInlined</*uses_lz77=*/true>( 71 0, input, sink_ctx_map); 72 maxsym = sym > maxsym ? sym : maxsym; 73 (*context_map)[i] = sym; 74 i++; 75 } 76 if (maxsym >= kMaxClusters) { 77 return JXL_FAILURE("Invalid cluster ID"); 78 } 79 if (!reader.CheckANSFinalState()) { 80 return JXL_FAILURE("Invalid context map"); 81 } 82 if (use_mtf) { 83 InverseMoveToFrontTransform(context_map->data(), context_map->size()); 84 } 85 } 86 *num_htrees = *std::max_element(context_map->begin(), context_map->end()) + 1; 87 return VerifyContextMap(*context_map, *num_htrees); 88 } 89 90 } // namespace jxl