libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_encoding.cc (28988B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include <stdint.h>
      7 #include <stdlib.h>
      8 
      9 #include <limits>
     10 #include <queue>
     11 
     12 #include "lib/jxl/base/common.h"
     13 #include "lib/jxl/base/printf_macros.h"
     14 #include "lib/jxl/base/status.h"
     15 #include "lib/jxl/enc_ans.h"
     16 #include "lib/jxl/enc_aux_out.h"
     17 #include "lib/jxl/enc_bit_writer.h"
     18 #include "lib/jxl/enc_fields.h"
     19 #include "lib/jxl/fields.h"
     20 #include "lib/jxl/image_ops.h"
     21 #include "lib/jxl/modular/encoding/context_predict.h"
     22 #include "lib/jxl/modular/encoding/enc_ma.h"
     23 #include "lib/jxl/modular/encoding/encoding.h"
     24 #include "lib/jxl/modular/encoding/ma_common.h"
     25 #include "lib/jxl/modular/options.h"
     26 #include "lib/jxl/pack_signed.h"
     27 
     28 namespace jxl {
     29 
     30 namespace {
     31 // Plot tree (if enabled) and predictor usage map.
     32 constexpr bool kWantDebug = true;
     33 // constexpr bool kPrintTree = false;
     34 
     35 inline std::array<uint8_t, 3> PredictorColor(Predictor p) {
     36   switch (p) {
     37     case Predictor::Zero:
     38       return {{0, 0, 0}};
     39     case Predictor::Left:
     40       return {{255, 0, 0}};
     41     case Predictor::Top:
     42       return {{0, 255, 0}};
     43     case Predictor::Average0:
     44       return {{0, 0, 255}};
     45     case Predictor::Average4:
     46       return {{192, 128, 128}};
     47     case Predictor::Select:
     48       return {{255, 255, 0}};
     49     case Predictor::Gradient:
     50       return {{255, 0, 255}};
     51     case Predictor::Weighted:
     52       return {{0, 255, 255}};
     53       // TODO(jon)
     54     default:
     55       return {{255, 255, 255}};
     56   };
     57 }
     58 
     59 // `cutoffs` must be sorted.
     60 Tree MakeFixedTree(int property, const std::vector<int32_t> &cutoffs,
     61                    Predictor pred, size_t num_pixels) {
     62   size_t log_px = CeilLog2Nonzero(num_pixels);
     63   size_t min_gap = 0;
     64   // Reduce fixed tree height when encoding small images.
     65   if (log_px < 14) {
     66     min_gap = 8 * (14 - log_px);
     67   }
     68   Tree tree;
     69   struct NodeInfo {
     70     size_t begin, end, pos;
     71   };
     72   std::queue<NodeInfo> q;
     73   // Leaf IDs will be set by roundtrip decoding the tree.
     74   tree.push_back(PropertyDecisionNode::Leaf(pred));
     75   q.push(NodeInfo{0, cutoffs.size(), 0});
     76   while (!q.empty()) {
     77     NodeInfo info = q.front();
     78     q.pop();
     79     if (info.begin + min_gap >= info.end) continue;
     80     uint32_t split = (info.begin + info.end) / 2;
     81     tree[info.pos] =
     82         PropertyDecisionNode::Split(property, cutoffs[split], tree.size());
     83     q.push(NodeInfo{split + 1, info.end, tree.size()});
     84     tree.push_back(PropertyDecisionNode::Leaf(pred));
     85     q.push(NodeInfo{info.begin, split, tree.size()});
     86     tree.push_back(PropertyDecisionNode::Leaf(pred));
     87   }
     88   return tree;
     89 }
     90 
     91 }  // namespace
     92 
     93 Status GatherTreeData(const Image &image, pixel_type chan, size_t group_id,
     94                       const weighted::Header &wp_header,
     95                       const ModularOptions &options, TreeSamples &tree_samples,
     96                       size_t *total_pixels) {
     97   const Channel &channel = image.channel[chan];
     98 
     99   JXL_DEBUG_V(7, "Learning %" PRIuS "x%" PRIuS " channel %d", channel.w,
    100               channel.h, chan);
    101 
    102   std::array<pixel_type, kNumStaticProperties> static_props = {
    103       {chan, static_cast<int>(group_id)}};
    104   Properties properties(kNumNonrefProperties +
    105                         kExtraPropsPerChannel * options.max_properties);
    106   double pixel_fraction = std::min(1.0f, options.nb_repeats);
    107   // a fraction of 0 is used to disable learning entirely.
    108   if (pixel_fraction > 0) {
    109     pixel_fraction = std::max(pixel_fraction,
    110                               std::min(1.0, 1024.0 / (channel.w * channel.h)));
    111   }
    112   uint64_t threshold =
    113       (std::numeric_limits<uint64_t>::max() >> 32) * pixel_fraction;
    114   uint64_t s[2] = {static_cast<uint64_t>(0x94D049BB133111EBull),
    115                    static_cast<uint64_t>(0xBF58476D1CE4E5B9ull)};
    116   // Xorshift128+ adapted from xorshift128+-inl.h
    117   auto use_sample = [&]() {
    118     auto s1 = s[0];
    119     const auto s0 = s[1];
    120     const auto bits = s1 + s0;  // b, c
    121     s[0] = s0;
    122     s1 ^= s1 << 23;
    123     s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5);
    124     s[1] = s1;
    125     return (bits >> 32) <= threshold;
    126   };
    127 
    128   const intptr_t onerow = channel.plane.PixelsPerRow();
    129   JXL_ASSIGN_OR_RETURN(
    130       Channel references,
    131       Channel::Create(properties.size() - kNumNonrefProperties, channel.w));
    132   weighted::State wp_state(wp_header, channel.w, channel.h);
    133   tree_samples.PrepareForSamples(pixel_fraction * channel.h * channel.w + 64);
    134   const bool multiple_predictors = tree_samples.NumPredictors() != 1;
    135   auto compute_sample = [&](const pixel_type *p, size_t x, size_t y) {
    136     pixel_type_w pred[kNumModularPredictors];
    137     if (multiple_predictors) {
    138       PredictLearnAll(&properties, channel.w, p + x, onerow, x, y, references,
    139                       &wp_state, pred);
    140     } else {
    141       pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] =
    142           PredictLearn(&properties, channel.w, p + x, onerow, x, y,
    143                        tree_samples.PredictorFromIndex(0), references,
    144                        &wp_state)
    145               .guess;
    146     }
    147     (*total_pixels)++;
    148     if (use_sample()) {
    149       tree_samples.AddSample(p[x], properties, pred);
    150     }
    151     wp_state.UpdateErrors(p[x], x, y, channel.w);
    152   };
    153 
    154   for (size_t y = 0; y < channel.h; y++) {
    155     const pixel_type *JXL_RESTRICT p = channel.Row(y);
    156     PrecomputeReferences(channel, y, image, chan, &references);
    157     InitPropsRow(&properties, static_props, y);
    158 
    159     // TODO(veluca): avoid computing WP if we don't use its property or
    160     // predictions.
    161     if (y > 1 && channel.w > 8 && references.w == 0) {
    162       for (size_t x = 0; x < 2; x++) {
    163         compute_sample(p, x, y);
    164       }
    165       for (size_t x = 2; x < channel.w - 2; x++) {
    166         pixel_type_w pred[kNumModularPredictors];
    167         if (multiple_predictors) {
    168           PredictLearnAllNEC(&properties, channel.w, p + x, onerow, x, y,
    169                              references, &wp_state, pred);
    170         } else {
    171           pred[static_cast<int>(tree_samples.PredictorFromIndex(0))] =
    172               PredictLearnNEC(&properties, channel.w, p + x, onerow, x, y,
    173                               tree_samples.PredictorFromIndex(0), references,
    174                               &wp_state)
    175                   .guess;
    176         }
    177         (*total_pixels)++;
    178         if (use_sample()) {
    179           tree_samples.AddSample(p[x], properties, pred);
    180         }
    181         wp_state.UpdateErrors(p[x], x, y, channel.w);
    182       }
    183       for (size_t x = channel.w - 2; x < channel.w; x++) {
    184         compute_sample(p, x, y);
    185       }
    186     } else {
    187       for (size_t x = 0; x < channel.w; x++) {
    188         compute_sample(p, x, y);
    189       }
    190     }
    191   }
    192   return true;
    193 }
    194 
    195 Tree PredefinedTree(ModularOptions::TreeKind tree_kind, size_t total_pixels) {
    196   if (tree_kind == ModularOptions::TreeKind::kJpegTranscodeACMeta ||
    197       tree_kind == ModularOptions::TreeKind::kTrivialTreeNoPredictor) {
    198     // All the data is 0, so no need for a fancy tree.
    199     return {PropertyDecisionNode::Leaf(Predictor::Zero)};
    200   }
    201   if (tree_kind == ModularOptions::TreeKind::kFalconACMeta) {
    202     // All the data is 0 except the quant field. TODO(veluca): make that 0 too.
    203     return {PropertyDecisionNode::Leaf(Predictor::Left)};
    204   }
    205   if (tree_kind == ModularOptions::TreeKind::kACMeta) {
    206     // Small image.
    207     if (total_pixels < 1024) {
    208       return {PropertyDecisionNode::Leaf(Predictor::Left)};
    209     }
    210     Tree tree;
    211     // 0: c > 1
    212     tree.push_back(PropertyDecisionNode::Split(0, 1, 1));
    213     // 1: c > 2
    214     tree.push_back(PropertyDecisionNode::Split(0, 2, 3));
    215     // 2: c > 0
    216     tree.push_back(PropertyDecisionNode::Split(0, 0, 5));
    217     // 3: EPF control field (all 0 or 4), top > 0
    218     tree.push_back(PropertyDecisionNode::Split(6, 0, 21));
    219     // 4: ACS+QF, y > 0
    220     tree.push_back(PropertyDecisionNode::Split(2, 0, 7));
    221     // 5: CfL x
    222     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient));
    223     // 6: CfL b
    224     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient));
    225     // 7: QF: split according to the left quant value.
    226     tree.push_back(PropertyDecisionNode::Split(7, 5, 9));
    227     // 8: ACS: split in 4 segments (8x8 from 0 to 3, large square 4-5, large
    228     // rectangular 6-11, 8x8 12+), according to previous ACS value.
    229     tree.push_back(PropertyDecisionNode::Split(7, 5, 15));
    230     // QF
    231     tree.push_back(PropertyDecisionNode::Split(7, 11, 11));
    232     tree.push_back(PropertyDecisionNode::Split(7, 3, 13));
    233     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    234     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    235     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    236     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    237     // ACS
    238     tree.push_back(PropertyDecisionNode::Split(7, 11, 17));
    239     tree.push_back(PropertyDecisionNode::Split(7, 3, 19));
    240     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    241     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    242     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    243     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    244     // EPF, left > 0
    245     tree.push_back(PropertyDecisionNode::Split(7, 0, 23));
    246     tree.push_back(PropertyDecisionNode::Split(7, 0, 25));
    247     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    248     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    249     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    250     tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    251     return tree;
    252   }
    253   if (tree_kind == ModularOptions::TreeKind::kWPFixedDC) {
    254     std::vector<int32_t> cutoffs = {
    255         -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15,
    256         -11,  -7,   -4,   -3,   -1,   0,   1,   3,   5,   7,   11,
    257         15,   23,   31,   47,   63,   95,  127, 191, 255, 392, 500};
    258     return MakeFixedTree(kWPProp, cutoffs, Predictor::Weighted, total_pixels);
    259   }
    260   if (tree_kind == ModularOptions::TreeKind::kGradientFixedDC) {
    261     std::vector<int32_t> cutoffs = {
    262         -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15,
    263         -11,  -7,   -4,   -3,   -1,   0,   1,   3,   5,   7,   11,
    264         15,   23,   31,   47,   63,   95,  127, 191, 255, 392, 500};
    265     return MakeFixedTree(kGradientProp, cutoffs, Predictor::Gradient,
    266                          total_pixels);
    267   }
    268   JXL_UNREACHABLE("Unreachable");
    269   return {};
    270 }
    271 
    272 Tree LearnTree(TreeSamples &&tree_samples, size_t total_pixels,
    273                const ModularOptions &options,
    274                const std::vector<ModularMultiplierInfo> &multiplier_info = {},
    275                StaticPropRange static_prop_range = {}) {
    276   for (size_t i = 0; i < kNumStaticProperties; i++) {
    277     if (static_prop_range[i][1] == 0) {
    278       static_prop_range[i][1] = std::numeric_limits<uint32_t>::max();
    279     }
    280   }
    281   if (!tree_samples.HasSamples()) {
    282     Tree tree;
    283     tree.emplace_back();
    284     tree.back().predictor = tree_samples.PredictorFromIndex(0);
    285     tree.back().property = -1;
    286     tree.back().predictor_offset = 0;
    287     tree.back().multiplier = 1;
    288     return tree;
    289   }
    290   float pixel_fraction = tree_samples.NumSamples() * 1.0f / total_pixels;
    291   float required_cost = pixel_fraction * 0.9 + 0.1;
    292   tree_samples.AllSamplesDone();
    293   Tree tree;
    294   ComputeBestTree(tree_samples,
    295                   options.splitting_heuristics_node_threshold * required_cost,
    296                   multiplier_info, static_prop_range,
    297                   options.fast_decode_multiplier, &tree);
    298   return tree;
    299 }
    300 
    301 Status EncodeModularChannelMAANS(const Image &image, pixel_type chan,
    302                                  const weighted::Header &wp_header,
    303                                  const Tree &global_tree, Token **tokenpp,
    304                                  AuxOut *aux_out, size_t group_id,
    305                                  bool skip_encoder_fast_path) {
    306   const Channel &channel = image.channel[chan];
    307   Token *tokenp = *tokenpp;
    308   JXL_ASSERT(channel.w != 0 && channel.h != 0);
    309 
    310   Image3F predictor_img;
    311   if (kWantDebug) {
    312     JXL_ASSIGN_OR_RETURN(predictor_img, Image3F::Create(channel.w, channel.h));
    313   }
    314 
    315   JXL_DEBUG_V(6,
    316               "Encoding %" PRIuS "x%" PRIuS
    317               " channel %d, "
    318               "(shift=%i,%i)",
    319               channel.w, channel.h, chan, channel.hshift, channel.vshift);
    320 
    321   std::array<pixel_type, kNumStaticProperties> static_props = {
    322       {chan, static_cast<int>(group_id)}};
    323   bool use_wp;
    324   bool is_wp_only;
    325   bool is_gradient_only;
    326   size_t num_props;
    327   FlatTree tree = FilterTree(global_tree, static_props, &num_props, &use_wp,
    328                              &is_wp_only, &is_gradient_only);
    329   Properties properties(num_props);
    330   MATreeLookup tree_lookup(tree);
    331   JXL_DEBUG_V(3, "Encoding using a MA tree with %" PRIuS " nodes", tree.size());
    332 
    333   // Check if this tree is a WP-only tree with a small enough property value
    334   // range.
    335   // Initialized to avoid clang-tidy complaining.
    336   auto tree_lut = jxl::make_unique<TreeLut<uint16_t, false>>();
    337   if (is_wp_only) {
    338     is_wp_only = TreeToLookupTable(tree, *tree_lut);
    339   }
    340   if (is_gradient_only) {
    341     is_gradient_only = TreeToLookupTable(tree, *tree_lut);
    342   }
    343 
    344   if (is_wp_only && !skip_encoder_fast_path) {
    345     for (size_t c = 0; c < 3; c++) {
    346       FillImage(static_cast<float>(PredictorColor(Predictor::Weighted)[c]),
    347                 &predictor_img.Plane(c));
    348     }
    349     const intptr_t onerow = channel.plane.PixelsPerRow();
    350     weighted::State wp_state(wp_header, channel.w, channel.h);
    351     Properties properties(1);
    352     for (size_t y = 0; y < channel.h; y++) {
    353       const pixel_type *JXL_RESTRICT r = channel.Row(y);
    354       for (size_t x = 0; x < channel.w; x++) {
    355         size_t offset = 0;
    356         pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
    357         pixel_type_w top = (y ? *(r + x - onerow) : left);
    358         pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
    359         pixel_type_w topright =
    360             (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top);
    361         pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top);
    362         int32_t guess = wp_state.Predict</*compute_properties=*/true>(
    363             x, y, channel.w, top, left, topright, topleft, toptop, &properties,
    364             offset);
    365         uint32_t pos =
    366             kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
    367                                       kPropRangeFast - 1);
    368         uint32_t ctx_id = tree_lut->context_lookup[pos];
    369         int32_t residual = r[x] - guess - tree_lut->offsets[pos];
    370         *tokenp++ = Token(ctx_id, PackSigned(residual));
    371         wp_state.UpdateErrors(r[x], x, y, channel.w);
    372       }
    373     }
    374   } else if (tree.size() == 1 && tree[0].predictor == Predictor::Gradient &&
    375              tree[0].multiplier == 1 && tree[0].predictor_offset == 0 &&
    376              !skip_encoder_fast_path) {
    377     for (size_t c = 0; c < 3; c++) {
    378       FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]),
    379                 &predictor_img.Plane(c));
    380     }
    381     const intptr_t onerow = channel.plane.PixelsPerRow();
    382     for (size_t y = 0; y < channel.h; y++) {
    383       const pixel_type *JXL_RESTRICT r = channel.Row(y);
    384       for (size_t x = 0; x < channel.w; x++) {
    385         pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
    386         pixel_type_w top = (y ? *(r + x - onerow) : left);
    387         pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
    388         int32_t guess = ClampedGradient(top, left, topleft);
    389         int32_t residual = r[x] - guess;
    390         *tokenp++ = Token(tree[0].childID, PackSigned(residual));
    391       }
    392     }
    393   } else if (is_gradient_only && !skip_encoder_fast_path) {
    394     for (size_t c = 0; c < 3; c++) {
    395       FillImage(static_cast<float>(PredictorColor(Predictor::Gradient)[c]),
    396                 &predictor_img.Plane(c));
    397     }
    398     const intptr_t onerow = channel.plane.PixelsPerRow();
    399     for (size_t y = 0; y < channel.h; y++) {
    400       const pixel_type *JXL_RESTRICT r = channel.Row(y);
    401       for (size_t x = 0; x < channel.w; x++) {
    402         pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
    403         pixel_type_w top = (y ? *(r + x - onerow) : left);
    404         pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
    405         int32_t guess = ClampedGradient(top, left, topleft);
    406         uint32_t pos =
    407             kPropRangeFast +
    408             std::min<pixel_type_w>(
    409                 std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft),
    410                 kPropRangeFast - 1);
    411         uint32_t ctx_id = tree_lut->context_lookup[pos];
    412         int32_t residual = r[x] - guess - tree_lut->offsets[pos];
    413         *tokenp++ = Token(ctx_id, PackSigned(residual));
    414       }
    415     }
    416   } else if (tree.size() == 1 && tree[0].predictor == Predictor::Zero &&
    417              tree[0].multiplier == 1 && tree[0].predictor_offset == 0 &&
    418              !skip_encoder_fast_path) {
    419     for (size_t c = 0; c < 3; c++) {
    420       FillImage(static_cast<float>(PredictorColor(Predictor::Zero)[c]),
    421                 &predictor_img.Plane(c));
    422     }
    423     for (size_t y = 0; y < channel.h; y++) {
    424       const pixel_type *JXL_RESTRICT p = channel.Row(y);
    425       for (size_t x = 0; x < channel.w; x++) {
    426         *tokenp++ = Token(tree[0].childID, PackSigned(p[x]));
    427       }
    428     }
    429   } else if (tree.size() == 1 && tree[0].predictor != Predictor::Weighted &&
    430              (tree[0].multiplier & (tree[0].multiplier - 1)) == 0 &&
    431              tree[0].predictor_offset == 0 && !skip_encoder_fast_path) {
    432     // multiplier is a power of 2.
    433     for (size_t c = 0; c < 3; c++) {
    434       FillImage(static_cast<float>(PredictorColor(tree[0].predictor)[c]),
    435                 &predictor_img.Plane(c));
    436     }
    437     uint32_t mul_shift =
    438         FloorLog2Nonzero(static_cast<uint32_t>(tree[0].multiplier));
    439     const intptr_t onerow = channel.plane.PixelsPerRow();
    440     for (size_t y = 0; y < channel.h; y++) {
    441       const pixel_type *JXL_RESTRICT r = channel.Row(y);
    442       for (size_t x = 0; x < channel.w; x++) {
    443         PredictionResult pred = PredictNoTreeNoWP(channel.w, r + x, onerow, x,
    444                                                   y, tree[0].predictor);
    445         pixel_type_w residual = r[x] - pred.guess;
    446         JXL_DASSERT((residual >> mul_shift) * tree[0].multiplier == residual);
    447         *tokenp++ = Token(tree[0].childID, PackSigned(residual >> mul_shift));
    448       }
    449     }
    450 
    451   } else if (!use_wp && !skip_encoder_fast_path) {
    452     const intptr_t onerow = channel.plane.PixelsPerRow();
    453     JXL_ASSIGN_OR_RETURN(
    454         Channel references,
    455         Channel::Create(properties.size() - kNumNonrefProperties, channel.w));
    456     for (size_t y = 0; y < channel.h; y++) {
    457       const pixel_type *JXL_RESTRICT p = channel.Row(y);
    458       PrecomputeReferences(channel, y, image, chan, &references);
    459       float *pred_img_row[3];
    460       if (kWantDebug) {
    461         for (size_t c = 0; c < 3; c++) {
    462           pred_img_row[c] = predictor_img.PlaneRow(c, y);
    463         }
    464       }
    465       InitPropsRow(&properties, static_props, y);
    466       for (size_t x = 0; x < channel.w; x++) {
    467         PredictionResult res =
    468             PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
    469                             tree_lookup, references);
    470         if (kWantDebug) {
    471           for (size_t i = 0; i < 3; i++) {
    472             pred_img_row[i][x] = PredictorColor(res.predictor)[i];
    473           }
    474         }
    475         pixel_type_w residual = p[x] - res.guess;
    476         JXL_DASSERT(residual % res.multiplier == 0);
    477         *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier));
    478       }
    479     }
    480   } else {
    481     const intptr_t onerow = channel.plane.PixelsPerRow();
    482     JXL_ASSIGN_OR_RETURN(
    483         Channel references,
    484         Channel::Create(properties.size() - kNumNonrefProperties, channel.w));
    485     weighted::State wp_state(wp_header, channel.w, channel.h);
    486     for (size_t y = 0; y < channel.h; y++) {
    487       const pixel_type *JXL_RESTRICT p = channel.Row(y);
    488       PrecomputeReferences(channel, y, image, chan, &references);
    489       float *pred_img_row[3];
    490       if (kWantDebug) {
    491         for (size_t c = 0; c < 3; c++) {
    492           pred_img_row[c] = predictor_img.PlaneRow(c, y);
    493         }
    494       }
    495       InitPropsRow(&properties, static_props, y);
    496       for (size_t x = 0; x < channel.w; x++) {
    497         PredictionResult res =
    498             PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
    499                           tree_lookup, references, &wp_state);
    500         if (kWantDebug) {
    501           for (size_t i = 0; i < 3; i++) {
    502             pred_img_row[i][x] = PredictorColor(res.predictor)[i];
    503           }
    504         }
    505         pixel_type_w residual = p[x] - res.guess;
    506         JXL_DASSERT(residual % res.multiplier == 0);
    507         *tokenp++ = Token(res.context, PackSigned(residual / res.multiplier));
    508         wp_state.UpdateErrors(p[x], x, y, channel.w);
    509       }
    510     }
    511   }
    512   /* TODO(szabadka): Add cparams to the call stack here.
    513   if (kWantDebug && WantDebugOutput(cparams)) {
    514     DumpImage(
    515         cparams,
    516         ("pred_" + ToString(group_id) + "_" + ToString(chan)).c_str(),
    517         predictor_img);
    518   }
    519   */
    520   *tokenpp = tokenp;
    521   return true;
    522 }
    523 
    524 Status ModularEncode(const Image &image, const ModularOptions &options,
    525                      BitWriter *writer, AuxOut *aux_out, size_t layer,
    526                      size_t group_id, TreeSamples *tree_samples,
    527                      size_t *total_pixels, const Tree *tree,
    528                      GroupHeader *header, std::vector<Token> *tokens,
    529                      size_t *width) {
    530   if (image.error) return JXL_FAILURE("Invalid image");
    531   size_t nb_channels = image.channel.size();
    532   JXL_DEBUG_V(
    533       2, "Encoding %" PRIuS "-channel, %i-bit, %" PRIuS "x%" PRIuS " image.",
    534       nb_channels, image.bitdepth, image.w, image.h);
    535 
    536   if (nb_channels < 1) {
    537     return true;  // is there any use for a zero-channel image?
    538   }
    539 
    540   // encode transforms
    541   GroupHeader header_storage;
    542   if (header == nullptr) header = &header_storage;
    543   Bundle::Init(header);
    544   if (options.predictor == Predictor::Weighted) {
    545     weighted::PredictorMode(options.wp_mode, &header->wp_header);
    546   }
    547   header->transforms = image.transform;
    548   // This doesn't actually work
    549   if (tree != nullptr) {
    550     header->use_global_tree = true;
    551   }
    552   if (tree_samples == nullptr && tree == nullptr) {
    553     JXL_RETURN_IF_ERROR(Bundle::Write(*header, writer, layer, aux_out));
    554   }
    555 
    556   TreeSamples tree_samples_storage;
    557   size_t total_pixels_storage = 0;
    558   if (!total_pixels) total_pixels = &total_pixels_storage;
    559   if (*total_pixels == 0) {
    560     for (size_t i = 0; i < nb_channels; i++) {
    561       if (i >= image.nb_meta_channels &&
    562           (image.channel[i].w > options.max_chan_size ||
    563            image.channel[i].h > options.max_chan_size)) {
    564         break;
    565       }
    566       *total_pixels += image.channel[i].w * image.channel[i].h;
    567     }
    568     *total_pixels = std::max<size_t>(*total_pixels, 1);
    569   }
    570   // If there's no tree, compute one (or gather data to).
    571   if (tree == nullptr &&
    572       options.tree_kind == ModularOptions::TreeKind::kLearn) {
    573     bool gather_data = tree_samples != nullptr;
    574     if (tree_samples == nullptr) {
    575       JXL_RETURN_IF_ERROR(tree_samples_storage.SetPredictor(
    576           options.predictor, options.wp_tree_mode));
    577       JXL_RETURN_IF_ERROR(tree_samples_storage.SetProperties(
    578           options.splitting_heuristics_properties, options.wp_tree_mode));
    579       std::vector<pixel_type> pixel_samples;
    580       std::vector<pixel_type> diff_samples;
    581       std::vector<uint32_t> group_pixel_count;
    582       std::vector<uint32_t> channel_pixel_count;
    583       CollectPixelSamples(image, options, 0, group_pixel_count,
    584                           channel_pixel_count, pixel_samples, diff_samples);
    585       std::vector<ModularMultiplierInfo> placeholder_multiplier_info;
    586       StaticPropRange range;
    587       tree_samples_storage.PreQuantizeProperties(
    588           range, placeholder_multiplier_info, group_pixel_count,
    589           channel_pixel_count, pixel_samples, diff_samples,
    590           options.max_property_values);
    591     }
    592     for (size_t i = 0; i < nb_channels; i++) {
    593       if (!image.channel[i].w || !image.channel[i].h) {
    594         continue;  // skip empty channels
    595       }
    596       if (i >= image.nb_meta_channels &&
    597           (image.channel[i].w > options.max_chan_size ||
    598            image.channel[i].h > options.max_chan_size)) {
    599         break;
    600       }
    601       JXL_RETURN_IF_ERROR(GatherTreeData(
    602           image, i, group_id, header->wp_header, options,
    603           gather_data ? *tree_samples : tree_samples_storage, total_pixels));
    604     }
    605     if (gather_data) return true;
    606   }
    607 
    608   JXL_ASSERT((tree == nullptr) == (tokens == nullptr));
    609 
    610   Tree tree_storage;
    611   std::vector<std::vector<Token>> tokens_storage(1);
    612   // Compute tree.
    613   if (tree == nullptr) {
    614     EntropyEncodingData code;
    615     std::vector<uint8_t> context_map;
    616 
    617     std::vector<std::vector<Token>> tree_tokens(1);
    618 
    619     tree_storage =
    620         options.tree_kind == ModularOptions::TreeKind::kLearn
    621             ? LearnTree(std::move(tree_samples_storage), *total_pixels, options)
    622             : PredefinedTree(options.tree_kind, *total_pixels);
    623     tree = &tree_storage;
    624     tokens = tokens_storage.data();
    625 
    626     Tree decoded_tree;
    627     TokenizeTree(*tree, tree_tokens.data(), &decoded_tree);
    628     JXL_ASSERT(tree->size() == decoded_tree.size());
    629     tree_storage = std::move(decoded_tree);
    630 
    631     /* TODO(szabadka) Add text output callback
    632     if (kWantDebug && kPrintTree && WantDebugOutput(aux_out)) {
    633       PrintTree(*tree, aux_out->debug_prefix + "/tree_" + ToString(group_id));
    634     } */
    635 
    636     // Write tree
    637     BuildAndEncodeHistograms(options.histogram_params, kNumTreeContexts,
    638                              tree_tokens, &code, &context_map, writer,
    639                              kLayerModularTree, aux_out);
    640     WriteTokens(tree_tokens[0], code, context_map, 0, writer, kLayerModularTree,
    641                 aux_out);
    642   }
    643 
    644   size_t image_width = 0;
    645   size_t total_tokens = 0;
    646   for (size_t i = 0; i < nb_channels; i++) {
    647     if (i >= image.nb_meta_channels &&
    648         (image.channel[i].w > options.max_chan_size ||
    649          image.channel[i].h > options.max_chan_size)) {
    650       break;
    651     }
    652     if (image.channel[i].w > image_width) image_width = image.channel[i].w;
    653     total_tokens += image.channel[i].w * image.channel[i].h;
    654   }
    655   if (options.zero_tokens) {
    656     tokens->resize(tokens->size() + total_tokens, {0, 0});
    657   } else {
    658     // Do one big allocation for all the tokens we'll need,
    659     // to avoid reallocs that might require copying.
    660     size_t pos = tokens->size();
    661     tokens->resize(pos + total_tokens);
    662     Token *tokenp = tokens->data() + pos;
    663     for (size_t i = 0; i < nb_channels; i++) {
    664       if (!image.channel[i].w || !image.channel[i].h) {
    665         continue;  // skip empty channels
    666       }
    667       if (i >= image.nb_meta_channels &&
    668           (image.channel[i].w > options.max_chan_size ||
    669            image.channel[i].h > options.max_chan_size)) {
    670         break;
    671       }
    672       JXL_RETURN_IF_ERROR(EncodeModularChannelMAANS(
    673           image, i, header->wp_header, *tree, &tokenp, aux_out, group_id,
    674           options.skip_encoder_fast_path));
    675     }
    676     // Make sure we actually wrote all tokens
    677     JXL_CHECK(tokenp == tokens->data() + tokens->size());
    678   }
    679 
    680   // Write data if not using a global tree/ANS stream.
    681   if (!header->use_global_tree) {
    682     EntropyEncodingData code;
    683     std::vector<uint8_t> context_map;
    684     HistogramParams histo_params = options.histogram_params;
    685     histo_params.image_widths.push_back(image_width);
    686     BuildAndEncodeHistograms(histo_params, (tree->size() + 1) / 2,
    687                              tokens_storage, &code, &context_map, writer, layer,
    688                              aux_out);
    689     WriteTokens(tokens_storage[0], code, context_map, 0, writer, layer,
    690                 aux_out);
    691   } else {
    692     *width = image_width;
    693   }
    694   return true;
    695 }
    696 
    697 Status ModularGenericCompress(Image &image, const ModularOptions &opts,
    698                               BitWriter *writer, AuxOut *aux_out, size_t layer,
    699                               size_t group_id, TreeSamples *tree_samples,
    700                               size_t *total_pixels, const Tree *tree,
    701                               GroupHeader *header, std::vector<Token> *tokens,
    702                               size_t *width) {
    703   if (image.w == 0 || image.h == 0) return true;
    704   ModularOptions options = opts;  // Make a copy to modify it.
    705 
    706   if (options.predictor == kUndefinedPredictor) {
    707     options.predictor = Predictor::Gradient;
    708   }
    709 
    710   size_t bits = writer ? writer->BitsWritten() : 0;
    711   JXL_RETURN_IF_ERROR(ModularEncode(image, options, writer, aux_out, layer,
    712                                     group_id, tree_samples, total_pixels, tree,
    713                                     header, tokens, width));
    714   bits = writer ? writer->BitsWritten() - bits : 0;
    715   if (writer) {
    716     JXL_DEBUG_V(4,
    717                 "Modular-encoded a %" PRIuS "x%" PRIuS
    718                 " bitdepth=%i nbchans=%" PRIuS " image in %" PRIuS " bytes",
    719                 image.w, image.h, image.bitdepth, image.channel.size(),
    720                 bits / 8);
    721   }
    722   (void)bits;
    723   return true;
    724 }
    725 
    726 }  // namespace jxl