libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

dec_context_map.cc (2796B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/dec_context_map.h"
      7 
      8 #include <algorithm>
      9 #include <cstdint>
     10 #include <vector>
     11 
     12 #include "lib/jxl/ans_params.h"
     13 #include "lib/jxl/base/status.h"
     14 #include "lib/jxl/dec_ans.h"
     15 #include "lib/jxl/entropy_coder.h"
     16 #include "lib/jxl/inverse_mtf-inl.h"
     17 
     18 namespace jxl {
     19 
     20 namespace {
     21 
     22 Status VerifyContextMap(const std::vector<uint8_t>& context_map,
     23                         const size_t num_htrees) {
     24   std::vector<bool> have_htree(num_htrees);
     25   size_t num_found = 0;
     26   for (const uint8_t htree : context_map) {
     27     if (htree >= num_htrees) {
     28       return JXL_FAILURE("Invalid histogram index in context map.");
     29     }
     30     if (!have_htree[htree]) {
     31       have_htree[htree] = true;
     32       ++num_found;
     33     }
     34   }
     35   if (num_found != num_htrees) {
     36     return JXL_FAILURE("Incomplete context map.");
     37   }
     38   return true;
     39 }
     40 
     41 }  // namespace
     42 
     43 Status DecodeContextMap(std::vector<uint8_t>* context_map, size_t* num_htrees,
     44                         BitReader* input) {
     45   bool is_simple = static_cast<bool>(input->ReadFixedBits<1>());
     46   if (is_simple) {
     47     int bits_per_entry = input->ReadFixedBits<2>();
     48     if (bits_per_entry != 0) {
     49       for (uint8_t& entry : *context_map) {
     50         entry = input->ReadBits(bits_per_entry);
     51       }
     52     } else {
     53       std::fill(context_map->begin(), context_map->end(), 0);
     54     }
     55   } else {
     56     bool use_mtf = static_cast<bool>(input->ReadFixedBits<1>());
     57     ANSCode code;
     58     std::vector<uint8_t> sink_ctx_map;
     59     // Usage of LZ77 is disallowed if decoding only two symbols. This doesn't
     60     // make sense in non-malicious bitstreams, and could cause a stack overflow
     61     // in malicious bitstreams by making every context map require its own
     62     // context map.
     63     JXL_RETURN_IF_ERROR(
     64         DecodeHistograms(input, 1, &code, &sink_ctx_map,
     65                          /*disallow_lz77=*/context_map->size() <= 2));
     66     ANSSymbolReader reader(&code, input);
     67     size_t i = 0;
     68     uint32_t maxsym = 0;
     69     while (i < context_map->size()) {
     70       uint32_t sym = reader.ReadHybridUintInlined</*uses_lz77=*/true>(
     71           0, input, sink_ctx_map);
     72       maxsym = sym > maxsym ? sym : maxsym;
     73       (*context_map)[i] = sym;
     74       i++;
     75     }
     76     if (maxsym >= kMaxClusters) {
     77       return JXL_FAILURE("Invalid cluster ID");
     78     }
     79     if (!reader.CheckANSFinalState()) {
     80       return JXL_FAILURE("Invalid context map");
     81     }
     82     if (use_mtf) {
     83       InverseMoveToFrontTransform(context_map->data(), context_map->size());
     84     }
     85   }
     86   *num_htrees = *std::max_element(context_map->begin(), context_map->end()) + 1;
     87   return VerifyContextMap(*context_map, *num_htrees);
     88 }
     89 
     90 }  // namespace jxl