libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git

enc_ma.cc (38376B)


// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/modular/encoding/enc_ma.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <numeric>
#include <queue>
#include <unordered_map>
#include <unordered_set>

#include "lib/jxl/modular/encoding/ma_common.h"

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc"
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

#include "lib/jxl/base/fast_math-inl.h"
#include "lib/jxl/base/random.h"
#include "lib/jxl/enc_ans.h"
#include "lib/jxl/modular/encoding/context_predict.h"
#include "lib/jxl/modular/options.h"
#include "lib/jxl/pack_signed.h"
HWY_BEFORE_NAMESPACE();
namespace jxl {
namespace HWY_NAMESPACE {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::Eq;
using hwy::HWY_NAMESPACE::IfThenElse;
using hwy::HWY_NAMESPACE::Lt;
using hwy::HWY_NAMESPACE::Max;

const HWY_FULL(float) df;
const HWY_FULL(int32_t) di;
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }

// Compute entropy of the histogram, taking into account the minimum
// probability for symbols with non-zero counts.
float EstimateBits(const int32_t *counts, size_t num_symbols) {
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
  const auto zero = Zero(df);
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
  const auto inv_total = Set(df, 1.0f / total);
  auto bits_lanes = Zero(df);
  auto total_v = Set(di, total);
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
    const auto counts_iv = LoadU(di, &counts[i]);
    const auto counts_fv = ConvertTo(df, counts_iv);
    const auto probs = Mul(counts_fv, inv_total);
    const auto mprobs = Max(probs, minprob);
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
                                 BitCast(di, FastLog2f(df, mprobs)));
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
  }
  return GetLane(SumOfLanes(df, bits_lanes));
}
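
// Editorial note: a scalar sketch of the SIMD estimate above, for reference
// (an illustrative sketch, not part of the build; EstimateBitsScalarSketch is
// a made-up name and std::log2 would need <cmath>). Each symbol with count c
// contributes -c * log2(max(c / total, 1 / ANS_TAB_SIZE)) bits, and a symbol
// holding the entire mass contributes zero.
#if 0
float EstimateBitsScalarSketch(const int32_t *counts, size_t num_symbols) {
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
  float bits = 0.0f;
  for (size_t i = 0; i < num_symbols; i++) {
    if (counts[i] == 0 || counts[i] == total) continue;
    float prob =
        std::max(static_cast<float>(counts[i]) / total, 1.0f / ANS_TAB_SIZE);
    bits -= counts[i] * std::log2(prob);
  }
  return bits;
}
#endif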

void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred,
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
  // Note that the tree splits on *strictly greater*.
  (*tree)[pos].lchild = tree->size();
  (*tree)[pos].rchild = tree->size() + 1;
  (*tree)[pos].splitval = splitval;
  (*tree)[pos].property = property;
  tree->emplace_back();
  tree->back().property = -1;
  tree->back().predictor = rpred;
  tree->back().predictor_offset = roff;
  tree->back().multiplier = 1;
  tree->emplace_back();
  tree->back().property = -1;
  tree->back().predictor = lpred;
  tree->back().predictor_offset = loff;
  tree->back().multiplier = 1;
}
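
// Editorial note: a node with property == -1 is a leaf. MakeSplitNode turns
// the leaf at `pos` into an interior node (property, splitval, two children)
// and appends its children as fresh leaves, so leaves always carry a
// predictor, a predictor offset, and a multiplier.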

enum class IntersectionType { kNone, kPartial, kInside };
IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack,
                               uint32_t &partial_axis, uint32_t &partial_val) {
  bool partial = false;
  for (size_t i = 0; i < kNumStaticProperties; i++) {
    if (haystack[i][0] >= needle[i][1]) {
      return IntersectionType::kNone;
    }
    if (haystack[i][1] <= needle[i][0]) {
      return IntersectionType::kNone;
    }
    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
      continue;
    }
    partial = true;
    partial_axis = i;
    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
      partial_val = haystack[i][0] - 1;
    } else {
      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
                  haystack[i][1] < needle[i][1]);
      partial_val = haystack[i][1] - 1;
    }
  }
  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
}
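
// Editorial sketch of the intersection test above (illustrative values only,
// not part of the build; assumes StaticPropRange is an array of half-open
// [lo, hi) ranges as used in this file):
#if 0
void BoxIntersectsUsageSketch() {
  StaticPropRange node;  // the node's current static-property ranges
  node[0] = {0, 8};      // channels 0..7
  node[1] = {0, 8};      // groups 0..7
  StaticPropRange mul;   // a multiplier box covering groups 4..7 only
  mul[0] = {0, 8};
  mul[1] = {4, 8};
  uint32_t axis, val;
  IntersectionType t = BoxIntersects(node, mul, axis, val);
  // t == IntersectionType::kPartial, axis == 1, val == 3: the caller then
  // forces a split on "group > 3" so each side is fully inside or outside.
}
#endif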

void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
                      size_t end, size_t prop) {
  auto cmp = [&](size_t a, size_t b) {
    return static_cast<int32_t>(tree_samples.Property(prop, a)) -
           static_cast<int32_t>(tree_samples.Property(prop, b));
  };
  Rng rng(0);
  while (end > begin + 1) {
    {
      size_t pivot = rng.UniformU(begin, end);
      tree_samples.Swap(begin, pivot);
    }
    size_t pivot_begin = begin;
    size_t pivot_end = pivot_begin + 1;
    for (size_t i = begin + 1; i < end; i++) {
      JXL_DASSERT(i >= pivot_end);
      JXL_DASSERT(pivot_end > pivot_begin);
      int32_t cmp_result = cmp(i, pivot_begin);
      if (cmp_result < 0) {  // i < pivot, move pivot forward and put i before
                             // the pivot.
        tree_samples.ThreeShuffle(pivot_begin, pivot_end, i);
        pivot_begin++;
        pivot_end++;
      } else if (cmp_result == 0) {
        tree_samples.Swap(pivot_end, i);
        pivot_end++;
      }
    }
    JXL_DASSERT(pivot_begin >= begin);
    JXL_DASSERT(pivot_end > pivot_begin);
    JXL_DASSERT(pivot_end <= end);
    for (size_t i = begin; i < pivot_begin; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) < 0);
    }
    for (size_t i = pivot_end; i < end; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) > 0);
    }
    for (size_t i = pivot_begin; i < pivot_end; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) == 0);
    }
    // We now have that [begin, pivot_begin) is < pivot, [pivot_begin,
    // pivot_end) is = pivot, and [pivot_end, end) is > pivot.
    // If pos falls in the first or the last interval, we continue in that
    // interval; otherwise, we are done.
    if (pivot_begin > pos) {
      end = pivot_begin;
    } else if (pivot_end < pos) {
      begin = pivot_end;
    } else {
      break;
    }
  }
}
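
// Editorial note: the loop above is a three-way ("Dutch national flag")
// quickselect: it partitions only far enough that `pos` lands on a boundary
// between runs. A minimal sketch of the same scheme on a plain vector,
// assuming only the standard library (<random>, <vector>; not part of the
// build):
#if 0
void ThreeWayQuickselectSketch(std::vector<int> &v, size_t begin, size_t pos,
                               size_t end) {
  std::mt19937 rng(0);
  while (end > begin + 1) {
    std::swap(v[begin], v[begin + rng() % (end - begin)]);  // random pivot
    size_t lo = begin;      // [begin, lo) holds values < pivot
    size_t hi = begin + 1;  // [lo, hi) holds values == pivot
    for (size_t i = begin + 1; i < end; i++) {
      if (v[i] < v[lo]) {
        int tmp = v[i];  // rotate (lo, hi, i), as ThreeShuffle does
        v[i] = v[hi];
        v[hi] = v[lo];
        v[lo] = tmp;
        lo++;
        hi++;
      } else if (v[i] == v[lo]) {
        std::swap(v[hi], v[i]);
        hi++;
      }
    }
    if (lo > pos) {
      end = lo;  // pos lies in the < run
    } else if (hi < pos) {
      begin = hi;  // pos lies in the > run
    } else {
      break;  // pos falls on the boundary of the equal run: done
    }
  }
}
#endif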

void FindBestSplit(TreeSamples &tree_samples, float threshold,
                   const std::vector<ModularMultiplierInfo> &mul_info,
                   StaticPropRange initial_static_prop_range,
                   float fast_decode_multiplier, Tree *tree) {
  struct NodeInfo {
    size_t pos;
    size_t begin;
    size_t end;
    uint64_t used_properties;
    StaticPropRange static_prop_range;
  };
  std::vector<NodeInfo> nodes;
  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
                           initial_static_prop_range});

  size_t num_predictors = tree_samples.NumPredictors();
  size_t num_properties = tree_samples.NumProperties();

  // TODO(veluca): consider parallelizing the search (processing multiple nodes
  // at a time).
  while (!nodes.empty()) {
    size_t pos = nodes.back().pos;
    size_t begin = nodes.back().begin;
    size_t end = nodes.back().end;
    uint64_t used_properties = nodes.back().used_properties;
    StaticPropRange static_prop_range = nodes.back().static_prop_range;
    nodes.pop_back();
    if (begin == end) continue;

    struct SplitInfo {
      size_t prop = 0;
      uint32_t val = 0;
      size_t pos = 0;
      float lcost = std::numeric_limits<float>::max();
      float rcost = std::numeric_limits<float>::max();
      Predictor lpred = Predictor::Zero;
      Predictor rpred = Predictor::Zero;
      float Cost() { return lcost + rcost; }
    };

    SplitInfo best_split_static_constant;
    SplitInfo best_split_static;
    SplitInfo best_split_nonstatic;
    SplitInfo best_split_nowp;

    JXL_DASSERT(begin <= end);
    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());

    // Compute the maximum token in the range.
    size_t max_symbols = 0;
    for (size_t pred = 0; pred < num_predictors; pred++) {
      for (size_t i = begin; i < end; i++) {
        uint32_t tok = tree_samples.Token(pred, i);
        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
      }
    }
    max_symbols = Padded(max_symbols);
    std::vector<int32_t> counts(max_symbols * num_predictors);
    std::vector<uint32_t> tot_extra_bits(num_predictors);
    for (size_t pred = 0; pred < num_predictors; pred++) {
      for (size_t i = begin; i < end; i++) {
        counts[pred * max_symbols + tree_samples.Token(pred, i)] +=
            tree_samples.Count(i);
        tot_extra_bits[pred] +=
            tree_samples.NBits(pred, i) * tree_samples.Count(i);
      }
    }

    float base_bits;
    {
      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
      base_bits =
          EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
          tot_extra_bits[pred];
    }

    SplitInfo *best = &best_split_nonstatic;

    SplitInfo forced_split;
    // The multiplier ranges cut halfway through the current ranges of static
    // properties. We do this even if the current node is not a leaf, to
    // minimize the number of nodes in the resulting tree.
    for (size_t i = 0; i < mul_info.size(); i++) {
      uint32_t axis;
      uint32_t val;
      IntersectionType t =
          BoxIntersects(static_prop_range, mul_info[i].range, axis, val);
      if (t == IntersectionType::kNone) continue;
      if (t == IntersectionType::kInside) {
        (*tree)[pos].multiplier = mul_info[i].multiplier;
        break;
      }
      if (t == IntersectionType::kPartial) {
        forced_split.val = tree_samples.QuantizeProperty(axis, val);
        forced_split.prop = axis;
        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
        best = &forced_split;
        best->pos = begin;
        JXL_ASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
        for (size_t x = begin; x < end; x++) {
          if (tree_samples.Property(best->prop, x) <= best->val) {
            best->pos++;
          }
        }
        break;
      }
    }

    if (best != &forced_split) {
      std::vector<int> prop_value_used_count;
      std::vector<int> count_increase;
      std::vector<size_t> extra_bits_increase;
      // For each property, compute which of its values are used, and what
      // tokens correspond to those usages. Then, iterate through the values,
      // and compute the entropy of each side of the split (of the form `prop >
      // threshold`). Finally, find the split that minimizes the cost.
      struct CostInfo {
        float cost = std::numeric_limits<float>::max();
        float extra_cost = 0;
        float Cost() const { return cost + extra_cost; }
        Predictor pred;  // will be uninitialized in some cases, but never used.
      };
      std::vector<CostInfo> costs_l;
      std::vector<CostInfo> costs_r;

      std::vector<int32_t> counts_above(max_symbols);
      std::vector<int32_t> counts_below(max_symbols);

      // The lower the threshold, the higher the expected noisiness of the
      // estimate. Thus, discourage changing predictors.
      float change_pred_penalty = 800.0f / (100.0f + threshold);
      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
           prop++) {
        costs_l.clear();
        costs_r.clear();
        size_t prop_size = tree_samples.NumPropertyValues(prop);
        if (extra_bits_increase.size() < prop_size) {
          count_increase.resize(prop_size * max_symbols);
          extra_bits_increase.resize(prop_size);
        }
        // Clear prop_value_used_count (which cannot be cleared "on the go")
        prop_value_used_count.clear();
        prop_value_used_count.resize(prop_size);

        size_t first_used = prop_size;
        size_t last_used = 0;

        // TODO(veluca): consider finding multiple splits along a single
        // property at the same time, possibly with a bottom-up approach.
        for (size_t i = begin; i < end; i++) {
          size_t p = tree_samples.Property(prop, i);
          prop_value_used_count[p]++;
          last_used = std::max(last_used, p);
          first_used = std::min(first_used, p);
        }
        costs_l.resize(last_used - first_used);
        costs_r.resize(last_used - first_used);
        // For all predictors, compute the right and left costs of each split.
        for (size_t pred = 0; pred < num_predictors; pred++) {
          // Compute cost and histogram increments for each property value.
          for (size_t i = begin; i < end; i++) {
            size_t p = tree_samples.Property(prop, i);
            size_t cnt = tree_samples.Count(i);
            size_t sym = tree_samples.Token(pred, i);
            count_increase[p * max_symbols + sym] += cnt;
            extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt;
          }
          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
                 max_symbols * sizeof counts_above[0]);
          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
          size_t extra_bits_below = 0;
          // Exclude last used: this ensures neither counts_above nor
          // counts_below is empty.
          for (size_t i = first_used; i < last_used; i++) {
            if (!prop_value_used_count[i]) continue;
            extra_bits_below += extra_bits_increase[i];
            // The increase for this property value has been used, and will not
            // be used again: clear it. Also below.
            extra_bits_increase[i] = 0;
            for (size_t sym = 0; sym < max_symbols; sym++) {
              counts_above[sym] -= count_increase[i * max_symbols + sym];
              counts_below[sym] += count_increase[i * max_symbols + sym];
              count_increase[i * max_symbols + sym] = 0;
            }
            float rcost = EstimateBits(counts_above.data(), max_symbols) +
                          tot_extra_bits[pred] - extra_bits_below;
            float lcost = EstimateBits(counts_below.data(), max_symbols) +
                          extra_bits_below;
            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
            float penalty = 0;
            // Never discourage moving away from the Weighted predictor.
            if (tree_samples.PredictorFromIndex(pred) !=
                    (*tree)[pos].predictor &&
                (*tree)[pos].predictor != Predictor::Weighted) {
              penalty = change_pred_penalty;
            }
            // If everything else is equal, disfavour Weighted (slower) and
            // favour Zero (faster if it's the only predictor used in a
            // group+channel combination)
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
              penalty += 1e-8;
            }
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
              penalty -= 1e-8;
            }
            if (rcost + penalty < costs_r[i - first_used].Cost()) {
              costs_r[i - first_used].cost = rcost;
              costs_r[i - first_used].extra_cost = penalty;
              costs_r[i - first_used].pred =
                  tree_samples.PredictorFromIndex(pred);
            }
            if (lcost + penalty < costs_l[i - first_used].Cost()) {
              costs_l[i - first_used].cost = lcost;
              costs_l[i - first_used].extra_cost = penalty;
              costs_l[i - first_used].pred =
                  tree_samples.PredictorFromIndex(pred);
            }
          }
        }
        // Iterate through the possible splits and find the one with minimum sum
        // of costs of the two sides.
        size_t split = begin;
        for (size_t i = first_used; i < last_used; i++) {
          if (!prop_value_used_count[i]) continue;
          split += prop_value_used_count[i];
          float rcost = costs_r[i - first_used].cost;
          float lcost = costs_l[i - first_used].cost;
          // WP was not used + we would use the WP property or predictor
          bool adds_wp =
              (tree_samples.PropertyFromIndex(prop) == kWPProp &&
               (used_properties & (1LU << prop)) == 0) ||
              ((costs_l[i - first_used].pred == Predictor::Weighted ||
                costs_r[i - first_used].pred == Predictor::Weighted) &&
               (*tree)[pos].predictor != Predictor::Weighted);
          bool zero_entropy_side = rcost == 0 || lcost == 0;

          SplitInfo &best =
              prop < kNumStaticProperties
                  ? (zero_entropy_side ? best_split_static_constant
                                       : best_split_static)
                  : (adds_wp ? best_split_nonstatic : best_split_nowp);
          if (lcost + rcost < best.Cost()) {
            best.prop = prop;
            best.val = i;
            best.pos = split;
            best.lcost = lcost;
            best.lpred = costs_l[i - first_used].pred;
            best.rcost = rcost;
            best.rpred = costs_r[i - first_used].pred;
          }
        }
        // Clear extra_bits_increase and cost_increase for last_used.
        extra_bits_increase[last_used] = 0;
        for (size_t sym = 0; sym < max_symbols; sym++) {
          count_increase[last_used * max_symbols + sym] = 0;
        }
      }

      // Try to avoid introducing WP.
      if (best_split_nowp.Cost() + threshold < base_bits &&
          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
        best = &best_split_nowp;
      }
      // Split along static props if possible and not significantly more
      // expensive.
      if (best_split_static.Cost() + threshold < base_bits &&
          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
        best = &best_split_static;
      }
      // Split along static props to create constant nodes if possible.
      if (best_split_static_constant.Cost() + threshold < base_bits) {
        best = &best_split_static_constant;
      }
    }

    if (best->Cost() + threshold < base_bits) {
      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
      pixel_type dequant =
          tree_samples.UnquantizeProperty(best->prop, best->val);
      // Split node and try to split children.
      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
      // "Sort" according to winning property
      SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop);
      if (p >= kNumStaticProperties) {
        // Use a 64-bit shift: used_properties is a 64-bit mask and prop may
        // exceed 31 (cf. the `1LU << prop` test above).
        used_properties |= 1LU << best->prop;
      }
      auto new_sp_range = static_prop_range;
      if (p < kNumStaticProperties) {
        JXL_ASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
        new_sp_range[p][1] = dequant + 1;
        JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
      }
      nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
                               used_properties, new_sp_range});
      new_sp_range = static_prop_range;
      if (p < kNumStaticProperties) {
        JXL_ASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
        new_sp_range[p][0] = dequant + 1;
        JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
      }
      nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
                               used_properties, new_sp_range});
    }
  }
}
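
// Editorial note on the acceptance rule above: a split is kept only when it
// saves more than `threshold` bits compared with coding the whole range with
// one histogram (base_bits). The no-WP and static-property variants win
// whenever their cost is within `fast_decode_multiplier` of the best split,
// trading a little density for faster decoding.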

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jxl
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace jxl {

HWY_EXPORT(FindBestSplit);  // Local function.

void ComputeBestTree(TreeSamples &tree_samples, float threshold,
                     const std::vector<ModularMultiplierInfo> &mul_info,
                     StaticPropRange static_prop_range,
                     float fast_decode_multiplier, Tree *tree) {
  // TODO(veluca): take into account that different contexts can have different
  // uint configs.
  //
  // Initialize tree.
  tree->emplace_back();
  tree->back().property = -1;
  tree->back().predictor = tree_samples.PredictorFromIndex(0);
  tree->back().predictor_offset = 0;
  tree->back().multiplier = 1;
  JXL_ASSERT(tree_samples.NumProperties() < 64);

  JXL_ASSERT(tree_samples.NumDistinctSamples() <=
             std::numeric_limits<uint32_t>::max());
  HWY_DYNAMIC_DISPATCH(FindBestSplit)
  (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier,
   tree);
}

constexpr int32_t TreeSamples::kPropertyRange;
constexpr uint32_t TreeSamples::kDedupEntryUnused;

Status TreeSamples::SetPredictor(Predictor predictor,
                                 ModularOptions::TreeMode wp_tree_mode) {
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    predictors = {Predictor::Weighted};
    residuals.resize(1);
    return true;
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP &&
      predictor == Predictor::Weighted) {
    return JXL_FAILURE("Invalid predictor settings");
  }
  if (predictor == Predictor::Variable) {
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictors.push_back(static_cast<Predictor>(i));
    }
    std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]);
    std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]);
  } else if (predictor == Predictor::Best) {
    predictors = {Predictor::Weighted, Predictor::Gradient};
  } else {
    predictors = {predictor};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    auto wp_it =
        std::find(predictors.begin(), predictors.end(), Predictor::Weighted);
    if (wp_it != predictors.end()) {
      predictors.erase(wp_it);
    }
  }
  residuals.resize(predictors.size());
  return true;
}

Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties,
                                  ModularOptions::TreeMode wp_tree_mode) {
  props_to_use = properties;
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    props_to_use = {static_cast<uint32_t>(kWPProp)};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) {
    props_to_use = {static_cast<uint32_t>(kGradientProp)};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    auto it = std::find(props_to_use.begin(), props_to_use.end(), kWPProp);
    if (it != props_to_use.end()) {
      props_to_use.erase(it);
    }
  }
  if (props_to_use.empty()) {
    return JXL_FAILURE("Invalid property set configuration");
  }
  props.resize(props_to_use.size());
  return true;
}

void TreeSamples::InitTable(size_t size) {
  JXL_DASSERT((size & (size - 1)) == 0);
  if (dedup_table_.size() == size) return;
  dedup_table_.resize(size, kDedupEntryUnused);
  for (size_t i = 0; i < NumDistinctSamples(); i++) {
    if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) {
      AddToTable(i);
    }
  }
}

bool TreeSamples::AddToTableAndMerge(size_t a) {
  size_t pos1 = Hash1(a);
  size_t pos2 = Hash2(a);
  if (dedup_table_[pos1] != kDedupEntryUnused &&
      IsSameSample(a, dedup_table_[pos1])) {
    JXL_DASSERT(sample_counts[a] == 1);
    sample_counts[dedup_table_[pos1]]++;
    // Remove from hash table samples that are saturated.
    if (sample_counts[dedup_table_[pos1]] ==
        std::numeric_limits<uint16_t>::max()) {
      dedup_table_[pos1] = kDedupEntryUnused;
    }
    return true;
  }
  if (dedup_table_[pos2] != kDedupEntryUnused &&
      IsSameSample(a, dedup_table_[pos2])) {
    JXL_DASSERT(sample_counts[a] == 1);
    sample_counts[dedup_table_[pos2]]++;
    // Remove from hash table samples that are saturated.
    if (sample_counts[dedup_table_[pos2]] ==
        std::numeric_limits<uint16_t>::max()) {
      dedup_table_[pos2] = kDedupEntryUnused;
    }
    return true;
  }
  AddToTable(a);
  return false;
}

void TreeSamples::AddToTable(size_t a) {
  size_t pos1 = Hash1(a);
  size_t pos2 = Hash2(a);
  if (dedup_table_[pos1] == kDedupEntryUnused) {
    dedup_table_[pos1] = a;
  } else if (dedup_table_[pos2] == kDedupEntryUnused) {
    dedup_table_[pos2] = a;
  }
}

void TreeSamples::PrepareForSamples(size_t num_samples) {
  for (auto &res : residuals) {
    res.reserve(res.size() + num_samples);
  }
  for (auto &p : props) {
    p.reserve(p.size() + num_samples);
  }
  size_t total_num_samples = num_samples + sample_counts.size();
  size_t next_pow2 = 1LLU << CeilLog2Nonzero(total_num_samples * 3 / 2);
  InitTable(next_pow2);
}
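
// Editorial note: the table is sized to the next power of two at or above
// 1.5x the projected number of samples (e.g. 1000 samples -> 1500 -> 2048
// entries), which keeps the load factor of the two-slot table below ~2/3.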

size_t TreeSamples::Hash1(size_t a) const {
  constexpr uint64_t constant = 0x1e35a7bd;
  uint64_t h = constant;
  for (const auto &r : residuals) {
    h = h * constant + r[a].tok;
    h = h * constant + r[a].nbits;
  }
  for (const auto &p : props) {
    h = h * constant + p[a];
  }
  return (h >> 16) & (dedup_table_.size() - 1);
}
size_t TreeSamples::Hash2(size_t a) const {
  constexpr uint64_t constant = 0x1e35a7bd1e35a7bd;
  uint64_t h = constant;
  for (const auto &p : props) {
    h = h * constant ^ p[a];
  }
  for (const auto &r : residuals) {
    h = h * constant ^ r[a].tok;
    h = h * constant ^ r[a].nbits;
  }
  return (h >> 16) & (dedup_table_.size() - 1);
}
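
// Editorial note: Hash1 and Hash2 give each sample exactly two candidate
// buckets, and AddToTable simply drops the entry when both are occupied, so
// deduplication is lossy. That is safe here: a missed merge only stores a
// duplicate sample row, it never loses a count.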

bool TreeSamples::IsSameSample(size_t a, size_t b) const {
  bool ret = true;
  for (const auto &r : residuals) {
    if (r[a].tok != r[b].tok) {
      ret = false;
    }
    if (r[a].nbits != r[b].nbits) {
      ret = false;
    }
  }
  for (const auto &p : props) {
    if (p[a] != p[b]) {
      ret = false;
    }
  }
  return ret;
}

void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
                            const pixel_type_w *predictions) {
  for (size_t i = 0; i < predictors.size(); i++) {
    pixel_type v = pixel - predictions[static_cast<int>(predictors[i])];
    uint32_t tok, nbits, bits;
    HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits);
    JXL_DASSERT(tok < 256);
    JXL_DASSERT(nbits < 256);
    residuals[i].emplace_back(
        ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)});
  }
  for (size_t i = 0; i < props_to_use.size(); i++) {
    props[i].push_back(QuantizeProperty(i, properties[props_to_use[i]]));
  }
  sample_counts.push_back(1);
  num_samples++;
  if (AddToTableAndMerge(sample_counts.size() - 1)) {
    for (auto &r : residuals) r.pop_back();
    for (auto &p : props) p.pop_back();
    sample_counts.pop_back();
  }
}

void TreeSamples::Swap(size_t a, size_t b) {
  if (a == b) return;
  for (auto &r : residuals) {
    std::swap(r[a], r[b]);
  }
  for (auto &p : props) {
    std::swap(p[a], p[b]);
  }
  std::swap(sample_counts[a], sample_counts[b]);
}

void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) {
  if (b == c) {
    Swap(a, b);
    return;
  }

  for (auto &r : residuals) {
    auto tmp = r[a];
    r[a] = r[c];
    r[c] = r[b];
    r[b] = tmp;
  }
  for (auto &p : props) {
    auto tmp = p[a];
    p[a] = p[c];
    p[c] = p[b];
    p[b] = tmp;
  }
  auto tmp = sample_counts[a];
  sample_counts[a] = sample_counts[c];
  sample_counts[c] = sample_counts[b];
  sample_counts[b] = tmp;
}

namespace {
std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
                                       size_t num_chunks) {
  if (histogram.empty()) return {};
  // TODO(veluca): selecting distinct quantiles is likely not the best
  // way to go about this.
  std::vector<int32_t> thresholds;
  uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU);
  uint64_t cumsum = 0;
  uint64_t threshold = 1;
  for (size_t i = 0; i + 1 < histogram.size(); i++) {
    cumsum += histogram[i];
    if (cumsum >= threshold * sum / num_chunks) {
      thresholds.push_back(i);
      while (cumsum > threshold * sum / num_chunks) threshold++;
    }
  }
  return thresholds;
}
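
// Editorial worked example (not part of the build): with histogram
// {2, 4, 1, 3} and num_chunks == 2, sum == 10; the cumulative count first
// reaches the half-way mark (5) at index 1, so the single threshold is 1,
// splitting the mass into "value <= 1" (6 samples) and "value > 1" (4).
#if 0
void QuantizeHistogramExampleSketch() {
  std::vector<int32_t> thr = QuantizeHistogram({2, 4, 1, 3}, 2);
  // thr == {1}
}
#endif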

std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples,
                                     size_t num_chunks) {
  if (samples.empty()) return {};
  int min = *std::min_element(samples.begin(), samples.end());
  constexpr int kRange = 512;
  min = std::min(std::max(min, -kRange), kRange);
  std::vector<uint32_t> counts(2 * kRange + 1);
  for (int s : samples) {
    uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min;
    counts[sample_offset]++;
  }
  std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks);
  for (auto &v : thresholds) v += min;
  return thresholds;
}
}  // namespace

void TreeSamples::PreQuantizeProperties(
    const StaticPropRange &range,
    const std::vector<ModularMultiplierInfo> &multiplier_info,
    const std::vector<uint32_t> &group_pixel_count,
    const std::vector<uint32_t> &channel_pixel_count,
    std::vector<pixel_type> &pixel_samples,
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
  // If we have forced splits because of multipliers, choose channel and group
  // thresholds accordingly.
  std::vector<int32_t> group_multiplier_thresholds;
  std::vector<int32_t> channel_multiplier_thresholds;
  for (const auto &v : multiplier_info) {
    if (v.range[0][0] != range[0][0]) {
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
    }
    if (v.range[0][1] != range[0][1]) {
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
    }
    if (v.range[1][0] != range[1][0]) {
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
    }
    if (v.range[1][1] != range[1][1]) {
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
    }
  }
  std::sort(channel_multiplier_thresholds.begin(),
            channel_multiplier_thresholds.end());
  channel_multiplier_thresholds.resize(
      std::unique(channel_multiplier_thresholds.begin(),
                  channel_multiplier_thresholds.end()) -
      channel_multiplier_thresholds.begin());
  std::sort(group_multiplier_thresholds.begin(),
            group_multiplier_thresholds.end());
  group_multiplier_thresholds.resize(
      std::unique(group_multiplier_thresholds.begin(),
                  group_multiplier_thresholds.end()) -
      group_multiplier_thresholds.begin());

  compact_properties.resize(props_to_use.size());
  auto quantize_channel = [&]() {
    if (!channel_multiplier_thresholds.empty()) {
      return channel_multiplier_thresholds;
    }
    return QuantizeHistogram(channel_pixel_count, max_property_values);
  };
  auto quantize_group_id = [&]() {
    if (!group_multiplier_thresholds.empty()) {
      return group_multiplier_thresholds;
    }
    return QuantizeHistogram(group_pixel_count, max_property_values);
  };
  auto quantize_coordinate = [&]() {
    std::vector<int32_t> quantized;
    quantized.reserve(max_property_values - 1);
    for (size_t i = 0; i + 1 < max_property_values; i++) {
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
    }
    return quantized;
  };
  std::vector<int32_t> abs_pixel_thr;
  std::vector<int32_t> pixel_thr;
  auto quantize_pixel_property = [&]() {
    if (pixel_thr.empty()) {
      pixel_thr = QuantizeSamples(pixel_samples, max_property_values);
    }
    return pixel_thr;
  };
  auto quantize_abs_pixel_property = [&]() {
    if (abs_pixel_thr.empty()) {
      quantize_pixel_property();  // Compute the non-abs thresholds.
      for (auto &v : pixel_samples) v = std::abs(v);
      abs_pixel_thr = QuantizeSamples(pixel_samples, max_property_values);
    }
    return abs_pixel_thr;
  };
  std::vector<int32_t> abs_diff_thr;
  std::vector<int32_t> diff_thr;
  auto quantize_diff_property = [&]() {
    if (diff_thr.empty()) {
      diff_thr = QuantizeSamples(diff_samples, max_property_values);
    }
    return diff_thr;
  };
  auto quantize_abs_diff_property = [&]() {
    if (abs_diff_thr.empty()) {
      quantize_diff_property();  // Compute the non-abs thresholds.
      for (auto &v : diff_samples) v = std::abs(v);
      abs_diff_thr = QuantizeSamples(diff_samples, max_property_values);
    }
    return abs_diff_thr;
  };
  auto quantize_wp = [&]() {
    if (max_property_values < 32) {
      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
                                  1,    3,   7,   15,  31, 63, 127};
    }
    if (max_property_values < 64) {
      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
                                  3,    5,    7,    11,  15,  23,  31,  47,
                                  63,   95,   127,  191, 255};
    }
    return std::vector<int32_t>{
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
  };

  property_mapping.resize(props_to_use.size());
  for (size_t i = 0; i < props_to_use.size(); i++) {
    if (props_to_use[i] == 0) {
      compact_properties[i] = quantize_channel();
    } else if (props_to_use[i] == 1) {
      compact_properties[i] = quantize_group_id();
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
      compact_properties[i] = quantize_coordinate();
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
               props_to_use[i] == 8 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
      compact_properties[i] = quantize_pixel_property();
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
      compact_properties[i] = quantize_abs_pixel_property();
    } else if (props_to_use[i] >= kNumNonrefProperties &&
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
      compact_properties[i] = quantize_abs_diff_property();
    } else if (props_to_use[i] == kWPProp) {
      compact_properties[i] = quantize_wp();
    } else {
      compact_properties[i] = quantize_diff_property();
    }
    property_mapping[i].resize(kPropertyRange * 2 + 1);
    size_t mapped = 0;
    for (size_t j = 0; j < property_mapping[i].size(); j++) {
      while (mapped < compact_properties[i].size() &&
             static_cast<int>(j) - kPropertyRange >
                 compact_properties[i][mapped]) {
        mapped++;
      }
      // property_mapping[i][j] is `mapped` such that
      // compact_properties[i][mapped - 1] < j - kPropertyRange <=
      // compact_properties[i][mapped].
      // This is because the decision node in the tree splits on
      // (property) > threshold, hence everything that is not > a threshold
      // should be clustered together.
      property_mapping[i][j] = mapped;
    }
  }
}
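
// Editorial worked example for the mapping above: with thresholds
// compact_properties[i] == {-1, 3}, a property value V maps to bucket 0 for
// V <= -1, bucket 1 for 0 <= V <= 3, and bucket 2 for V >= 4, matching
// decision nodes of the form "property > threshold".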

void CollectPixelSamples(const Image &image, const ModularOptions &options,
                         size_t group_id,
                         std::vector<uint32_t> &group_pixel_count,
                         std::vector<uint32_t> &channel_pixel_count,
                         std::vector<pixel_type> &pixel_samples,
                         std::vector<pixel_type> &diff_samples) {
  if (options.nb_repeats == 0) return;
  if (group_pixel_count.size() <= group_id) {
    group_pixel_count.resize(group_id + 1);
  }
  if (channel_pixel_count.size() < image.channel.size()) {
    channel_pixel_count.resize(image.channel.size());
  }
  Rng rng(group_id);
  // Sample 10% of the final number of samples for property quantization.
  float fraction = std::min(options.nb_repeats * 0.1, 0.99);
  Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction);
  size_t total_pixels = 0;
  std::vector<size_t> channel_ids;
  for (size_t i = 0; i < image.channel.size(); i++) {
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
      continue;  // skip empty or width-1 channels.
    }
    if (i >= image.nb_meta_channels &&
        (image.channel[i].w > options.max_chan_size ||
         image.channel[i].h > options.max_chan_size)) {
      break;
    }
    channel_ids.push_back(i);
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
    total_pixels += image.channel[i].w * image.channel[i].h;
  }
  if (channel_ids.empty()) return;
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
  size_t i = 0;
  size_t y = 0;
  size_t x = 0;
  auto advance = [&](size_t amount) {
    x += amount;
    // Detect row overflow (rare).
    while (x >= image.channel[channel_ids[i]].w) {
      x -= image.channel[channel_ids[i]].w;
      y++;
      // Detect end-of-channel (even rarer).
      if (y == image.channel[channel_ids[i]].h) {
        i++;
        y = 0;
        if (i >= channel_ids.size()) {
          return;
        }
      }
    }
  };
  advance(rng.Geometric(dist));
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
    pixel_samples.push_back(row[x]);
    size_t xp = x == 0 ? 1 : x - 1;
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
  }
}
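
// Editorial note: rather than flipping a coin per pixel, the sampler above
// draws the gap to the next sampled pixel from a geometric distribution with
// success probability `fraction`, so the expected cost is proportional to the
// number of samples taken rather than to the number of pixels scanned.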

// TODO(veluca): very simple encoding scheme. This should be improved.
void TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
                  Tree *decoder_tree) {
  JXL_ASSERT(tree.size() <= kMaxTreeSize);
  std::queue<int> q;
  q.push(0);
  size_t leaf_id = 0;
  decoder_tree->clear();
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    JXL_ASSERT(tree[cur].property >= -1);
    tokens->emplace_back(kPropertyContext, tree[cur].property + 1);
    if (tree[cur].property == -1) {
      tokens->emplace_back(kPredictorContext,
                           static_cast<int>(tree[cur].predictor));
      tokens->emplace_back(kOffsetContext,
                           PackSigned(tree[cur].predictor_offset));
      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier);
      uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1;
      tokens->emplace_back(kMultiplierLogContext, mul_log);
      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
      JXL_ASSERT(tree[cur].predictor < Predictor::Best);
      decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor,
                                 tree[cur].predictor_offset,
                                 tree[cur].multiplier);
      continue;
    }
    decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval,
                               decoder_tree->size() + q.size() + 1,
                               decoder_tree->size() + q.size() + 2,
                               Predictor::Zero, 0, 1);
    q.push(tree[cur].lchild);
    q.push(tree[cur].rchild);
    tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval));
  }
}
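
// Editorial note: the tokenizer above emits the tree in breadth-first order,
// which lets the decoder rebuild child links implicitly: when an interior
// node is emitted, its children are the next two nodes to leave the queue,
// i.e. decoder-tree indices size() + q.size() + 1 and size() + q.size() + 2.
// For a three-node tree (one split, two leaves), the token stream is the
// root's property + 1 and packed splitval, then for each leaf a zero property
// token followed by its predictor, packed offset, and multiplier (power-of-two
// log plus remaining bits), with leaf ids assigned in BFS order.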

}  // namespace jxl
#endif  // HWY_ONCE