libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

encoding.cc (28261B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/modular/encoding/encoding.h"
      7 
      8 #include <stdint.h>
      9 #include <stdlib.h>
     10 
     11 #include <queue>
     12 
     13 #include "lib/jxl/base/printf_macros.h"
     14 #include "lib/jxl/base/scope_guard.h"
     15 #include "lib/jxl/dec_ans.h"
     16 #include "lib/jxl/dec_bit_reader.h"
     17 #include "lib/jxl/frame_dimensions.h"
     18 #include "lib/jxl/image_ops.h"
     19 #include "lib/jxl/modular/encoding/context_predict.h"
     20 #include "lib/jxl/modular/options.h"
     21 #include "lib/jxl/pack_signed.h"
     22 
     23 namespace jxl {
     24 
     25 // Removes all nodes that use a static property (i.e. channel or group ID) from
     26 // the tree and collapses each node on even levels with its two children to
     27 // produce a flatter tree. Also computes whether the resulting tree requires
     28 // using the weighted predictor.
        //
        // Inputs:
        //   global_tree  - tree to specialize; node 0 is the root. A node whose
        //                  property is -1 is treated as a leaf.
        //   static_props - values of the static properties, indexed by property id.
        //                  A decision on static property p follows lchild when
        //                  static_props[p] > splitval and rchild otherwise.
        // Outputs:
        //   num_props     - highest property index used by the flat tree plus one,
        //                   raised to at least kNumNonrefProperties and rounded up so
        //                   that the excess over kNumNonrefProperties is a multiple of
        //                   kExtraPropsPerChannel.
        //   use_wp        - true if a visited decision uses kWPProp or a visited leaf
        //                   uses Predictor::Weighted.
        //   wp_only       - true if use_wp holds, no other non-static property is
        //                   used and no visited leaf uses another predictor.
        //   gradient_only - true if the only non-static property used is
        //                   kGradientProp and every visited leaf uses
        //                   Predictor::Gradient.
        // Returns the flattened tree; each inner entry tests property0 and then one
        // of properties[0..1], and its four possible outcomes are stored
        // consecutively starting at childID.
     29 FlatTree FilterTree(const Tree &global_tree,
     30                     std::array<pixel_type, kNumStaticProperties> &static_props,
     31                     size_t *num_props, bool *use_wp, bool *wp_only,
     32                     bool *gradient_only) {
     33   *num_props = 0;
     34   bool has_wp = false;
     35   bool has_non_wp = false;
     36   *gradient_only = true;
        // Records which kinds of non-static properties the flat tree depends on.
     37   const auto mark_property = [&](int32_t p) {
     38     if (p == kWPProp) {
     39       has_wp = true;
     40     } else if (p >= kNumStaticProperties) {
     41       has_non_wp = true;
     42     }
     43     if (p >= kNumStaticProperties && p != kGradientProp) {
     44       *gradient_only = false;
     45     }
     46   };
     47   FlatTree output;
     48   std::queue<size_t> nodes;
     49   nodes.push(0);
     50   // Produces a trimmed and flattened tree by doing a BFS visit of the original
     51   // tree, ignoring branches that are known to be false and proceeding two
     52   // levels at a time to collapse nodes in a flatter tree; if an inner parent
     53   // node has a leaf as a child, the leaf is duplicated and an implicit fake
     54   // node is added. This allows to reduce the number of branches when traversing
     55   // the resulting flat tree.
     56   while (!nodes.empty()) {
     57     size_t cur = nodes.front();
     58     nodes.pop();
     59     // Skip nodes that we can decide now, by jumping directly to their children.
     60     while (global_tree[cur].property < kNumStaticProperties &&
     61            global_tree[cur].property != -1) {
     62       if (static_props[global_tree[cur].property] > global_tree[cur].splitval) {
     63         cur = global_tree[cur].lchild;
     64       } else {
     65         cur = global_tree[cur].rchild;
     66       }
     67     }
     68     FlatDecisionNode flat;
     69     if (global_tree[cur].property == -1) {
            // Leaf: copy the prediction parameters. For a leaf, childID carries the
            // value stored in lchild rather than an index into `output`.
     70       flat.property0 = -1;
     71       flat.childID = global_tree[cur].lchild;
     72       flat.predictor = global_tree[cur].predictor;
     73       flat.predictor_offset = global_tree[cur].predictor_offset;
     74       flat.multiplier = global_tree[cur].multiplier;
     75       *gradient_only &= flat.predictor == Predictor::Gradient;
     76       has_wp |= flat.predictor == Predictor::Weighted;
     77       has_non_wp |= flat.predictor != Predictor::Weighted;
     78       output.push_back(flat);
     79       continue;
     80     }
            // Every queued node yields exactly one output entry, so the four nodes
            // pushed below will be stored right after this entry (index
            // output.size()) and the nodes.size() entries that are already queued.
     81     flat.childID = output.size() + nodes.size() + 1;
     82 
     83     flat.property0 = global_tree[cur].property;
     84     *num_props = std::max<size_t>(flat.property0 + 1, *num_props);
     85     flat.splitval0 = global_tree[cur].splitval;
     86 
     87     for (size_t i = 0; i < 2; i++) {
     88       size_t cur_child =
     89           i == 0 ? global_tree[cur].lchild : global_tree[cur].rchild;
     90       // Skip nodes that we can decide now.
     91       while (global_tree[cur_child].property < kNumStaticProperties &&
     92              global_tree[cur_child].property != -1) {
     93         if (static_props[global_tree[cur_child].property] >
     94             global_tree[cur_child].splitval) {
     95           cur_child = global_tree[cur_child].lchild;
     96         } else {
     97           cur_child = global_tree[cur_child].rchild;
     98         }
     99       }
    100       // We ended up in a leaf, add a placeholder decision and two copies of the
    101       // leaf.
    102       if (global_tree[cur_child].property == -1) {
              // Both outcomes of the placeholder lead to the same leaf, so the
              // result of testing property 0 against 0 does not matter.
    103         flat.properties[i] = 0;
    104         flat.splitvals[i] = 0;
    105         nodes.push(cur_child);
    106         nodes.push(cur_child);
    107       } else {
    108         flat.properties[i] = global_tree[cur_child].property;
    109         flat.splitvals[i] = global_tree[cur_child].splitval;
    110         nodes.push(global_tree[cur_child].lchild);
    111         nodes.push(global_tree[cur_child].rchild);
    112         *num_props = std::max<size_t>(flat.properties[i] + 1, *num_props);
    113       }
    114     }
    115 
    116     for (size_t j = 0; j < 2; j++) mark_property(flat.properties[j]);
    117     mark_property(flat.property0);
    118     output.push_back(flat);
    119   }
          // Round the property count up to a whole number of kExtraPropsPerChannel
          // blocks beyond the first kNumNonrefProperties properties.
    120   if (*num_props > kNumNonrefProperties) {
    121     *num_props =
    122         DivCeil(*num_props - kNumNonrefProperties, kExtraPropsPerChannel) *
    123             kExtraPropsPerChannel +
    124         kNumNonrefProperties;
    125   } else {
    126     *num_props = kNumNonrefProperties;
    127   }
    128   *use_wp = has_wp;
    129   *wp_only = has_wp && !has_non_wp;
    130 
    131   return output;
    132 }
    133 
    134 namespace detail {
        // Decodes every pixel of image->channel[chan] from `reader` / `br`.
        //
        // The MA tree is first specialized for this channel and group with
        // FilterTree(); the decoder then takes the first applicable code path:
        //   1. Single-leaf tree (no meta-adaptation):
        //      - Zero predictor ("Fastest" / "Fast" track);
        //      - Gradient predictor with offset 0 and multiplier 1, read either
        //        through the Huffman+RLE-only reader (only when uses_lz77 and
        //        reader->HuffRleOnly()) or through the generic reader.
        //      Any other single-leaf configuration falls through to the paths below.
        //   2. Gradient-only tree for which TreeToLookupTable() returns true.
        //   3. WP-only tree for which TreeToLookupTable() returns true, without LZ77
        //      and with a channel wider than 8 pixels.
        //   4. Tree that uses neither the weighted predictor nor its property.
        //   5. Generic path that also maintains the weighted-predictor state.
        //
        // uses_lz77 is forwarded to the reader's ReadHybridUint* templates and gates
        // some of the fast paths. tree_lut is scratch storage filled by
        // TreeToLookupTable(). fl_run / fl_v hold the pending run length and value
        // of the Huffman+RLE track; they are references and are not reset here, so
        // that state persists after this call returns.
        //
        // Returns true on success, including for an empty channel. The
        // JXL_ASSIGN_OR_RETURN calls presumably propagate a Channel::Create failure.
    135 template <bool uses_lz77>
    136 Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader,
    137                                  const std::vector<uint8_t> &context_map,
    138                                  const Tree &global_tree,
    139                                  const weighted::Header &wp_header,
    140                                  pixel_type chan, size_t group_id,
    141                                  TreeLut<uint8_t, true> &tree_lut, Image *image,
    142                                  uint32_t &fl_run, uint32_t &fl_v) {
    143   Channel &channel = image->channel[chan];
    144 
    145   std::array<pixel_type, kNumStaticProperties> static_props = {
    146       {chan, static_cast<int>(group_id)}};
    147   // TODO(veluca): filter the tree according to static_props.
    148 
    149   // zero pixel channel? could happen
    150   if (channel.w == 0 || channel.h == 0) return true;
    151 
    152   bool tree_has_wp_prop_or_pred = false;
    153   bool is_wp_only = false;
    154   bool is_gradient_only = false;
    155   size_t num_props;
    156   FlatTree tree =
    157       FilterTree(global_tree, static_props, &num_props,
    158                  &tree_has_wp_prop_or_pred, &is_wp_only, &is_gradient_only);
    159 
    160   // From here on, tree lookup returns a *clustered* context ID.
    161   // This avoids an extra memory lookup after tree traversal.
    162   for (size_t i = 0; i < tree.size(); i++) {
    163     if (tree[i].property0 == -1) {
    164       tree[i].childID = context_map[tree[i].childID];
    165     }
    166   }
    167 
    168   JXL_DEBUG_V(3, "Decoded MA tree with %" PRIuS " nodes", tree.size());
    169 
    170   // MAANS decode
          // Turns a decoded token into a pixel: unpack the sign, scale by the leaf
          // multiplier and add the offset (leaf offset and/or prediction).
    171   const auto make_pixel = [](uint64_t v, pixel_type multiplier,
    172                              pixel_type_w offset) -> pixel_type {
    173     JXL_DASSERT((v & 0xFFFFFFFF) == v);
    174     pixel_type_w val = UnpackSigned(v);
    175     // if it overflows, it overflows, and we have a problem anyway
    176     return val * multiplier + offset;
    177   };
    178 
    179   if (tree.size() == 1) {
    180     // special optimized case: no meta-adaptation, so no need
    181     // to compute properties.
    182     Predictor predictor = tree[0].predictor;
    183     int64_t offset = tree[0].predictor_offset;
    184     int32_t multiplier = tree[0].multiplier;
    185     size_t ctx_id = tree[0].childID;
    186     if (predictor == Predictor::Zero) {
    187       uint32_t value;
    188       if (reader->IsSingleValueAndAdvance(ctx_id, &value,
    189                                           channel.w * channel.h)) {
    190         // Special-case: histogram has a single symbol, with no extra bits, and
    191         // we use ANS mode.
    192         JXL_DEBUG_V(8, "Fastest track.");
    193         pixel_type v = make_pixel(value, multiplier, offset);
    194         for (size_t y = 0; y < channel.h; y++) {
    195           pixel_type *JXL_RESTRICT r = channel.Row(y);
    196           std::fill(r, r + channel.w, v);
    197         }
    198       } else {
    199         JXL_DEBUG_V(8, "Fast track.");
    200         if (multiplier == 1 && offset == 0) {
    201           for (size_t y = 0; y < channel.h; y++) {
    202             pixel_type *JXL_RESTRICT r = channel.Row(y);
    203             for (size_t x = 0; x < channel.w; x++) {
    204               uint32_t v =
    205                   reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br);
    206               r[x] = UnpackSigned(v);
    207             }
    208           }
    209         } else {
    210           for (size_t y = 0; y < channel.h; y++) {
    211             pixel_type *JXL_RESTRICT r = channel.Row(y);
    212             for (size_t x = 0; x < channel.w; x++) {
    213               uint32_t v =
    214                   reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>(ctx_id,
    215                                                                          br);
    216               r[x] = make_pixel(v, multiplier, offset);
    217             }
    218           }
    219         }
    220       }
    221       return true;
    222     } else if (uses_lz77 && predictor == Predictor::Gradient && offset == 0 &&
    223                multiplier == 1 && reader->HuffRleOnly()) {
    224       JXL_DEBUG_V(8, "Gradient RLE (fjxl) very fast track.");
    225       pixel_type_w sv = UnpackSigned(fl_v);
    226       for (size_t y = 0; y < channel.h; y++) {
    227         pixel_type *JXL_RESTRICT r = channel.Row(y);
                // On the first row there is no row above: with both pointers set to
                // r - 1, rtop[x] and rtopleft[x] read r[x - 1] (the left neighbour).
                // They are only indexed with x >= 1 in that case.
    228         const pixel_type *JXL_RESTRICT rtop = (y ? channel.Row(y - 1) : r - 1);
    229         const pixel_type *JXL_RESTRICT rtopleft =
    230             (y ? channel.Row(y - 1) - 1 : r - 1);
    231         pixel_type_w guess = (y ? rtop[0] : 0);
                // A new (value, run) pair is read only once the current run is used
                // up; otherwise the previous residual `sv` is reused.
    232         if (fl_run == 0) {
    233           reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &fl_v,
    234                                                      &fl_run);
    235           sv = UnpackSigned(fl_v);
    236         } else {
    237           fl_run--;
    238         }
    239         r[0] = sv + guess;
    240         for (size_t x = 1; x < channel.w; x++) {
    241           pixel_type left = r[x - 1];
    242           pixel_type top = rtop[x];
    243           pixel_type topleft = rtopleft[x];
    244           pixel_type_w guess = ClampedGradient(top, left, topleft);
    245           if (!fl_run) {
    246             reader->ReadHybridUintClusteredHuffRleOnly(ctx_id, br, &fl_v,
    247                                                        &fl_run);
    248             sv = UnpackSigned(fl_v);
    249           } else {
    250             fl_run--;
    251           }
    252           r[x] = sv + guess;
    253         }
    254       }
    255       return true;
    256     } else if (predictor == Predictor::Gradient && offset == 0 &&
    257                multiplier == 1) {
    258       JXL_DEBUG_V(8, "Gradient very fast track.");
    259       const intptr_t onerow = channel.plane.PixelsPerRow();
    260       for (size_t y = 0; y < channel.h; y++) {
    261         pixel_type *JXL_RESTRICT r = channel.Row(y);
    262         for (size_t x = 0; x < channel.w; x++) {
                  // Missing neighbours fall back to: left -> top -> 0, top -> left,
                  // topleft -> left.
    263           pixel_type left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
    264           pixel_type top = (y ? *(r + x - onerow) : left);
    265           pixel_type topleft = (x && y ? *(r + x - 1 - onerow) : left);
    266           pixel_type guess = ClampedGradient(top, left, topleft);
    267           uint64_t v = reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>(
    268               ctx_id, br);
    269           r[x] = make_pixel(v, 1, guess);
    270         }
    271       }
    272       return true;
    273     }
    274   }
    275 
    276   // Check if this tree is a WP-only tree with a small enough property value
    277   // range.
    278   if (is_wp_only) {
    279     is_wp_only = TreeToLookupTable(tree, tree_lut);
    280   }
    281   if (is_gradient_only) {
    282     is_gradient_only = TreeToLookupTable(tree, tree_lut);
    283   }
    284 
    285   if (is_gradient_only) {
    286     JXL_DEBUG_V(8, "Gradient fast track.");
    287     const intptr_t onerow = channel.plane.PixelsPerRow();
    288     for (size_t y = 0; y < channel.h; y++) {
    289       pixel_type *JXL_RESTRICT r = channel.Row(y);
    290       for (size_t x = 0; x < channel.w; x++) {
    291         pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
    292         pixel_type_w top = (y ? *(r + x - onerow) : left);
    293         pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
    294         int32_t guess = ClampedGradient(top, left, topleft);
                // The table index is top + left - topleft clamped to
                // [-kPropRangeFast, kPropRangeFast - 1] and shifted to be
                // non-negative.
    295         uint32_t pos =
    296             kPropRangeFast +
    297             std::min<pixel_type_w>(
    298                 std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft),
    299                 kPropRangeFast - 1);
    300         uint32_t ctx_id = tree_lut.context_lookup[pos];
    301         uint64_t v =
    302             reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>(ctx_id, br);
    303         r[x] = make_pixel(
    304             v, tree_lut.multipliers[pos],
    305             static_cast<pixel_type_w>(tree_lut.offsets[pos]) + guess);
    306       }
    307     }
    308   } else if (!uses_lz77 && is_wp_only && channel.w > 8) {
    309     JXL_DEBUG_V(8, "WP fast track.");
    310     weighted::State wp_state(wp_header, channel.w, channel.h);
    311     Properties properties(1);
          // Each row is decoded in three steps - first pixel, interior pixels and
          // last pixel - so that the interior loop needs no border checks.
    312     for (size_t y = 0; y < channel.h; y++) {
    313       pixel_type *JXL_RESTRICT r = channel.Row(y);
    314       const pixel_type *JXL_RESTRICT rtop = (y ? channel.Row(y - 1) : r - 1);
    315       const pixel_type *JXL_RESTRICT rtoptop =
    316           (y > 1 ? channel.Row(y - 2) : rtop);
    317       const pixel_type *JXL_RESTRICT rtopleft =
    318           (y ? channel.Row(y - 1) - 1 : r - 1);
    319       const pixel_type *JXL_RESTRICT rtopright =
    320           (y ? channel.Row(y - 1) + 1 : r - 1);
    321       size_t x = 0;
    322       {
              // First pixel of the row (x == 0): there is no left neighbour, so the
              // pixel above (or 0 on the first row) stands in for it.
    323         size_t offset = 0;
    324         pixel_type_w left = y ? rtop[x] : 0;
    325         pixel_type_w toptop = y ? rtoptop[x] : 0;
    326         pixel_type_w topright = (x + 1 < channel.w && y ? rtop[x + 1] : left);
    327         int32_t guess = wp_state.Predict</*compute_properties=*/true>(
    328             x, y, channel.w, left, left, topright, left, toptop, &properties,
    329             offset);
    330         uint32_t pos =
    331             kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
    332                                       kPropRangeFast - 1);
    333         uint32_t ctx_id = tree_lut.context_lookup[pos];
    334         uint64_t v =
    335             reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br);
    336         r[x] = make_pixel(
    337             v, tree_lut.multipliers[pos],
    338             static_cast<pixel_type_w>(tree_lut.offsets[pos]) + guess);
    339         wp_state.UpdateErrors(r[x], x, y, channel.w);
    340       }
    341       for (x = 1; x + 1 < channel.w; x++) {
    342         size_t offset = 0;
    343         int32_t guess = wp_state.Predict</*compute_properties=*/true>(
    344             x, y, channel.w, rtop[x], r[x - 1], rtopright[x], rtopleft[x],
    345             rtoptop[x], &properties, offset);
    346         uint32_t pos =
    347             kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
    348                                       kPropRangeFast - 1);
    349         uint32_t ctx_id = tree_lut.context_lookup[pos];
    350         uint64_t v =
    351             reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br);
    352         r[x] = make_pixel(
    353             v, tree_lut.multipliers[pos],
    354             static_cast<pixel_type_w>(tree_lut.offsets[pos]) + guess);
    355         wp_state.UpdateErrors(r[x], x, y, channel.w);
    356       }
    357       {
              // Last pixel of the row (x == channel.w - 1 after the loop above):
              // there is no top-right neighbour, so the top pixel is passed instead.
    358         size_t offset = 0;
    359         int32_t guess = wp_state.Predict</*compute_properties=*/true>(
    360             x, y, channel.w, rtop[x], r[x - 1], rtop[x], rtopleft[x],
    361             rtoptop[x], &properties, offset);
    362         uint32_t pos =
    363             kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
    364                                       kPropRangeFast - 1);
    365         uint32_t ctx_id = tree_lut.context_lookup[pos];
    366         uint64_t v =
    367             reader->ReadHybridUintClusteredInlined<uses_lz77>(ctx_id, br);
    368         r[x] = make_pixel(
    369             v, tree_lut.multipliers[pos],
    370             static_cast<pixel_type_w>(tree_lut.offsets[pos]) + guess);
    371         wp_state.UpdateErrors(r[x], x, y, channel.w);
    372       }
    373     }
    374   } else if (!tree_has_wp_prop_or_pred) {
    375     // special optimized case: the weighted predictor and its properties are not
    376     // used, so no need to compute weights and properties.
    377     JXL_DEBUG_V(8, "Slow track.");
    378     MATreeLookup tree_lookup(tree);
    379     Properties properties = Properties(num_props);
    380     const intptr_t onerow = channel.plane.PixelsPerRow();
    381     JXL_ASSIGN_OR_RETURN(
    382         Channel references,
    383         Channel::Create(properties.size() - kNumNonrefProperties, channel.w));
    384     for (size_t y = 0; y < channel.h; y++) {
    385       pixel_type *JXL_RESTRICT p = channel.Row(y);
    386       PrecomputeReferences(channel, y, *image, chan, &references);
    387       InitPropsRow(&properties, static_props, y);
            // The *NEC predictor (presumably "no edge case") is only used for
            // pixels at least two columns away from both borders, on rows y > 1 of
            // channels wider than 8 that have no reference properties.
    388       if (y > 1 && channel.w > 8 && references.w == 0) {
    389         for (size_t x = 0; x < 2; x++) {
    390           PredictionResult res =
    391               PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
    392                               tree_lookup, references);
    393           uint64_t v =
    394               reader->ReadHybridUintClustered<uses_lz77>(res.context, br);
    395           p[x] = make_pixel(v, res.multiplier, res.guess);
    396         }
    397         for (size_t x = 2; x < channel.w - 2; x++) {
    398           PredictionResult res =
    399               PredictTreeNoWPNEC(&properties, channel.w, p + x, onerow, x, y,
    400                                  tree_lookup, references);
    401           uint64_t v = reader->ReadHybridUintClusteredInlined<uses_lz77>(
    402               res.context, br);
    403           p[x] = make_pixel(v, res.multiplier, res.guess);
    404         }
    405         for (size_t x = channel.w - 2; x < channel.w; x++) {
    406           PredictionResult res =
    407               PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
    408                               tree_lookup, references);
    409           uint64_t v =
    410               reader->ReadHybridUintClustered<uses_lz77>(res.context, br);
    411           p[x] = make_pixel(v, res.multiplier, res.guess);
    412         }
    413       } else {
    414         for (size_t x = 0; x < channel.w; x++) {
    415           PredictionResult res =
    416               PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
    417                               tree_lookup, references);
    418           uint64_t v = reader->ReadHybridUintClusteredMaybeInlined<uses_lz77>(
    419               res.context, br);
    420           p[x] = make_pixel(v, res.multiplier, res.guess);
    421         }
    422       }
    423     }
    424   } else {
    425     JXL_DEBUG_V(8, "Slowest track.");
    426     MATreeLookup tree_lookup(tree);
    427     Properties properties = Properties(num_props);
    428     const intptr_t onerow = channel.plane.PixelsPerRow();
    429     JXL_ASSIGN_OR_RETURN(
    430         Channel references,
    431         Channel::Create(properties.size() - kNumNonrefProperties, channel.w));
    432     weighted::State wp_state(wp_header, channel.w, channel.h);
    433     for (size_t y = 0; y < channel.h; y++) {
    434       pixel_type *JXL_RESTRICT p = channel.Row(y);
    435       InitPropsRow(&properties, static_props, y);
    436       PrecomputeReferences(channel, y, *image, chan, &references);
            // Same border / interior split as the slow track, additionally
            // restricted to streams without LZ77. The weighted-predictor error
            // state is updated after every decoded pixel.
    437       if (!uses_lz77 && y > 1 && channel.w > 8 && references.w == 0) {
    438         for (size_t x = 0; x < 2; x++) {
    439           PredictionResult res =
    440               PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
    441                             tree_lookup, references, &wp_state);
    442           uint64_t v =
    443               reader->ReadHybridUintClustered<uses_lz77>(res.context, br);
    444           p[x] = make_pixel(v, res.multiplier, res.guess);
    445           wp_state.UpdateErrors(p[x], x, y, channel.w);
    446         }
    447         for (size_t x = 2; x < channel.w - 2; x++) {
    448           PredictionResult res =
    449               PredictTreeWPNEC(&properties, channel.w, p + x, onerow, x, y,
    450                                tree_lookup, references, &wp_state);
    451           uint64_t v = reader->ReadHybridUintClusteredInlined<uses_lz77>(
    452               res.context, br);
    453           p[x] = make_pixel(v, res.multiplier, res.guess);
    454           wp_state.UpdateErrors(p[x], x, y, channel.w);
    455         }
    456         for (size_t x = channel.w - 2; x < channel.w; x++) {
    457           PredictionResult res =
    458               PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
    459                             tree_lookup, references, &wp_state);
    460           uint64_t v =
    461               reader->ReadHybridUintClustered<uses_lz77>(res.context, br);
    462           p[x] = make_pixel(v, res.multiplier, res.guess);
    463           wp_state.UpdateErrors(p[x], x, y, channel.w);
    464         }
    465       } else {
    466         for (size_t x = 0; x < channel.w; x++) {
    467           PredictionResult res =
    468               PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
    469                             tree_lookup, references, &wp_state);
    470           uint64_t v =
    471               reader->ReadHybridUintClustered<uses_lz77>(res.context, br);
    472           p[x] = make_pixel(v, res.multiplier, res.guess);
    473           wp_state.UpdateErrors(p[x], x, y, channel.w);
    474         }
    475       }
    476     }
    477   }
    478   return true;
    479 }
    480 }  // namespace detail
    481 
    482 Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader,
    483                                  const std::vector<uint8_t> &context_map,
    484                                  const Tree &global_tree,
    485                                  const weighted::Header &wp_header,
    486                                  pixel_type chan, size_t group_id,
    487                                  TreeLut<uint8_t, true> &tree_lut, Image *image,
    488                                  uint32_t &fl_run, uint32_t &fl_v) {
    489   if (reader->UsesLZ77()) {
    490     return detail::DecodeModularChannelMAANS</*uses_lz77=*/true>(
    491         br, reader, context_map, global_tree, wp_header, chan, group_id,
    492         tree_lut, image, fl_run, fl_v);
    493   } else {
    494     return detail::DecodeModularChannelMAANS</*uses_lz77=*/false>(
    495         br, reader, context_map, global_tree, wp_header, chan, group_id,
    496         tree_lut, image, fl_run, fl_v);
    497   }
    498 }
    499 
        // Default constructor: delegates the initialization of this header to
        // Bundle::Init.
    500 GroupHeader::GroupHeader() { Bundle::Init(this); }
    501 
    502 Status ValidateChannelDimensions(const Image &image,
    503                                  const ModularOptions &options) {
    504   size_t nb_channels = image.channel.size();
    505   for (bool is_dc : {true, false}) {
    506     size_t group_dim = options.group_dim * (is_dc ? kBlockDim : 1);
    507     size_t c = image.nb_meta_channels;
    508     for (; c < nb_channels; c++) {
    509       const Channel &ch = image.channel[c];
    510       if (ch.w > options.group_dim || ch.h > options.group_dim) break;
    511     }
    512     for (; c < nb_channels; c++) {
    513       const Channel &ch = image.channel[c];
    514       if (ch.w == 0 || ch.h == 0) continue;  // skip empty
    515       bool is_dc_channel = std::min(ch.hshift, ch.vshift) >= 3;
    516       if (is_dc_channel != is_dc) continue;
    517       size_t tile_dim = group_dim >> std::max(ch.hshift, ch.vshift);
    518       if (tile_dim == 0) {
    519         return JXL_FAILURE("Inconsistent transforms");
    520       }
    521     }
    522   }
    523   return true;
    524 }
    525 
    526 Status ModularDecode(BitReader *br, Image &image, GroupHeader &header,
    527                      size_t group_id, ModularOptions *options,
    528                      const Tree *global_tree, const ANSCode *global_code,
    529                      const std::vector<uint8_t> *global_ctx_map,
    530                      const bool allow_truncated_group) {
    531   if (image.channel.empty()) return true;
    532 
    533   // decode transforms
    534   Status status = Bundle::Read(br, &header);
    535   if (!allow_truncated_group) JXL_RETURN_IF_ERROR(status);
    536   if (status.IsFatalError()) return status;
    537   if (!br->AllReadsWithinBounds()) {
    538     // Don't do/undo transforms if header is incomplete.
    539     header.transforms.clear();
    540     image.transform = header.transforms;
    541     for (size_t c = 0; c < image.channel.size(); c++) {
    542       ZeroFillImage(&image.channel[c].plane);
    543     }
    544     return Status(StatusCode::kNotEnoughBytes);
    545   }
    546 
    547   JXL_DEBUG_V(3, "Image data underwent %" PRIuS " transformations: ",
    548               header.transforms.size());
    549   image.transform = header.transforms;
    550   for (Transform &transform : image.transform) {
    551     JXL_RETURN_IF_ERROR(transform.MetaApply(image));
    552   }
    553   if (image.error) {
    554     return JXL_FAILURE("Corrupt file. Aborting.");
    555   }
    556   JXL_RETURN_IF_ERROR(ValidateChannelDimensions(image, *options));
    557 
    558   size_t nb_channels = image.channel.size();
    559 
    560   size_t num_chans = 0;
    561   size_t distance_multiplier = 0;
    562   for (size_t i = 0; i < nb_channels; i++) {
    563     Channel &channel = image.channel[i];
    564     if (!channel.w || !channel.h) {
    565       continue;  // skip empty channels
    566     }
    567     if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
    568                                         channel.h > options->max_chan_size)) {
    569       break;
    570     }
    571     if (channel.w > distance_multiplier) {
    572       distance_multiplier = channel.w;
    573     }
    574     num_chans++;
    575   }
    576   if (num_chans == 0) return true;
    577 
    578   size_t next_channel = 0;
    579   auto scope_guard = MakeScopeGuard([&]() {
    580     for (size_t c = next_channel; c < image.channel.size(); c++) {
    581       ZeroFillImage(&image.channel[c].plane);
    582     }
    583   });
    584   // Do not do anything if truncated groups are not allowed.
    585   if (allow_truncated_group) scope_guard.Disarm();
    586 
    587   // Read tree.
    588   Tree tree_storage;
    589   std::vector<uint8_t> context_map_storage;
    590   ANSCode code_storage;
    591   const Tree *tree = &tree_storage;
    592   const ANSCode *code = &code_storage;
    593   const std::vector<uint8_t> *context_map = &context_map_storage;
    594   if (!header.use_global_tree) {
    595     uint64_t max_tree_size = 1024;
    596     for (size_t i = 0; i < nb_channels; i++) {
    597       Channel &channel = image.channel[i];
    598       if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
    599                                           channel.h > options->max_chan_size)) {
    600         break;
    601       }
    602       uint64_t pixels = channel.w * channel.h;
    603       max_tree_size += pixels;
    604     }
    605     max_tree_size = std::min(static_cast<uint64_t>(1 << 20), max_tree_size);
    606     JXL_RETURN_IF_ERROR(DecodeTree(br, &tree_storage, max_tree_size));
    607     JXL_RETURN_IF_ERROR(DecodeHistograms(br, (tree_storage.size() + 1) / 2,
    608                                          &code_storage, &context_map_storage));
    609   } else {
    610     if (!global_tree || !global_code || !global_ctx_map ||
    611         global_tree->empty()) {
    612       return JXL_FAILURE("No global tree available but one was requested");
    613     }
    614     tree = global_tree;
    615     code = global_code;
    616     context_map = global_ctx_map;
    617   }
    618 
    619   // Read channels
    620   ANSSymbolReader reader(code, br, distance_multiplier);
    621   auto tree_lut = jxl::make_unique<TreeLut<uint8_t, true>>();
    622   uint32_t fl_run = 0;
    623   uint32_t fl_v = 0;
    624   for (; next_channel < nb_channels; next_channel++) {
    625     Channel &channel = image.channel[next_channel];
    626     if (!channel.w || !channel.h) {
    627       continue;  // skip empty channels
    628     }
    629     if (next_channel >= image.nb_meta_channels &&
    630         (channel.w > options->max_chan_size ||
    631          channel.h > options->max_chan_size)) {
    632       break;
    633     }
    634     JXL_RETURN_IF_ERROR(DecodeModularChannelMAANS(
    635         br, &reader, *context_map, *tree, header.wp_header, next_channel,
    636         group_id, *tree_lut, &image, fl_run, fl_v));
    637 
    638     // Truncated group.
    639     if (!br->AllReadsWithinBounds()) {
    640       if (!allow_truncated_group) return JXL_FAILURE("Truncated input");
    641       return Status(StatusCode::kNotEnoughBytes);
    642     }
    643   }
    644 
    645   // Make sure no zero-filling happens even if next_channel < nb_channels.
    646   scope_guard.Disarm();
    647 
    648   if (!reader.CheckANSFinalState()) {
    649     return JXL_FAILURE("ANS decode final state failed");
    650   }
    651   return true;
    652 }
    653 
    654 Status ModularGenericDecompress(BitReader *br, Image &image,
    655                                 GroupHeader *header, size_t group_id,
    656                                 ModularOptions *options, bool undo_transforms,
    657                                 const Tree *tree, const ANSCode *code,
    658                                 const std::vector<uint8_t> *ctx_map,
    659                                 bool allow_truncated_group) {
    660 #ifdef JXL_ENABLE_ASSERT
    661   std::vector<std::pair<uint32_t, uint32_t>> req_sizes(image.channel.size());
    662   for (size_t c = 0; c < req_sizes.size(); c++) {
    663     req_sizes[c] = {image.channel[c].w, image.channel[c].h};
    664   }
    665 #endif
    666   GroupHeader local_header;
    667   if (header == nullptr) header = &local_header;
    668   size_t bit_pos = br->TotalBitsConsumed();
    669   auto dec_status = ModularDecode(br, image, *header, group_id, options, tree,
    670                                   code, ctx_map, allow_truncated_group);
    671   if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
    672   if (dec_status.IsFatalError()) return dec_status;
    673   if (undo_transforms) image.undo_transforms(header->wp_header);
    674   if (image.error) return JXL_FAILURE("Corrupt file. Aborting.");
    675   JXL_DEBUG_V(4,
    676               "Modular-decoded a %" PRIuS "x%" PRIuS " nbchans=%" PRIuS
    677               " image from %" PRIuS " bytes",
    678               image.w, image.h, image.channel.size(),
    679               (br->TotalBitsConsumed() - bit_pos) / 8);
    680   JXL_DEBUG_V(5, "Modular image: %s", image.DebugString().c_str());
    681   (void)bit_pos;
    682 #ifdef JXL_ENABLE_ASSERT
    683   // Check that after applying all transforms we are back to the requested image
    684   // sizes, otherwise there's a programming error with the transformations.
    685   if (undo_transforms) {
    686     JXL_ASSERT(image.channel.size() == req_sizes.size());
    687     for (size_t c = 0; c < req_sizes.size(); c++) {
    688       JXL_ASSERT(req_sizes[c].first == image.channel[c].w);
    689       JXL_ASSERT(req_sizes[c].second == image.channel[c].h);
    690     }
    691   }
    692 #endif
    693   return dec_status;
    694 }
    695 
    696 }  // namespace jxl