libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git

context_predict.h (25614B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
      7 #define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
      8 
      9 #include <utility>
     10 #include <vector>
     11 
     12 #include "lib/jxl/fields.h"
     13 #include "lib/jxl/image_ops.h"
     14 #include "lib/jxl/modular/modular_image.h"
     15 #include "lib/jxl/modular/options.h"
     16 
     17 namespace jxl {
     18 
     19 namespace weighted {
     20 constexpr static size_t kNumPredictors = 4;
     21 constexpr static int64_t kPredExtraBits = 3;
     22 constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1;
     23 constexpr static size_t kNumProperties = 1;
     24 
     25 struct Header : public Fields {
     26   JXL_FIELDS_NAME(WeightedPredictorHeader)
     27   // TODO(janwas): move to cc file, avoid including fields.h.
     28   Header() { Bundle::Init(this); }
     29 
     30   Status VisitFields(Visitor *JXL_RESTRICT visitor) override {
     31     if (visitor->AllDefault(*this, &all_default)) {
     32       // Overwrite all serialized fields, but not any nonserialized_*.
     33       visitor->SetDefault(this);
     34       return true;
     35     }
     36     auto visit_p = [visitor](pixel_type val, pixel_type *p) {
     37       uint32_t up = *p;
     38       JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up));
     39       *p = up;
     40       return Status(true);
     41     };
     42     JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C));
     43     JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C));
     44     JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca));
     45     JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb));
     46     JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc));
     47     JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd));
     48     JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce));
     49     JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0]));
     50     JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1]));
     51     JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2]));
     52     JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3]));
     53     return true;
     54   }
     55 
     56   bool all_default;
     57   pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0;
     58   uint32_t w[kNumPredictors] = {};
     59 };
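        // Note: the defaults used above (p1C = 16, p2C = 10, p3Ca/b/c = 7,
        // p3Cd/e = 0, w = {0xd, 0xc, 0xc, 0xc}) are the same values that
        // PredictorMode(0) below installs, i.e. the "~ lossless16 predictor" preset.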
     60 
     61 struct State {
     62   pixel_type_w prediction[kNumPredictors] = {};
     63   pixel_type_w pred = 0;  // *before* removing the added bits.
     64   std::vector<uint32_t> pred_errors[kNumPredictors];
     65   std::vector<int32_t> error;
     66   const Header header;
     67 
     68   // Allows approximating division by a number from 1 to 64.
     69   //  for (int i = 0; i < 64; i++) divlookup[i] = (1 << 24) / (i + 1);
     70 
     71   const uint32_t divlookup[64] = {
     72       16777216, 8388608, 5592405, 4194304, 3355443, 2796202, 2396745, 2097152,
     73       1864135,  1677721, 1525201, 1398101, 1290555, 1198372, 1118481, 1048576,
     74       986895,   932067,  883011,  838860,  798915,  762600,  729444,  699050,
     75       671088,   645277,  621378,  599186,  578524,  559240,  541200,  524288,
     76       508400,   493447,  479349,  466033,  453438,  441505,  430185,  419430,
     77       409200,   399457,  390167,  381300,  372827,  364722,  356962,  349525,
     78       342392,   335544,  328965,  322638,  316551,  310689,  305040,  299593,
     79       294337,   289262,  284359,  279620,  275036,  270600,  266305,  262144};
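          // With this table, an integer division x / d for 1 <= d <= 64 can be
          // approximated without a divide as (x * divlookup[d - 1]) >> 24.
          // Worked example: 100 / 7 -> (100 * divlookup[6]) >> 24
          //                          = (100 * 2396745) >> 24 = 14  (exact: 14.28...).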
     80 
     81   constexpr static pixel_type_w AddBits(pixel_type_w x) {
     82     return static_cast<uint64_t>(x) << kPredExtraBits;
     83   }
     84 
     85   State(Header header, size_t xsize, size_t ysize) : header(header) {
     86     // Extra margin to avoid out-of-bounds writes.
     87     // All have space for two rows of data.
     88     for (auto &pred_error : pred_errors) {
     89       pred_error.resize((xsize + 2) * 2);
     90     }
     91     error.resize((xsize + 2) * 2);
     92   }
     93 
     94   // Approximates 4+(maxweight<<24)/(x+1), avoiding division
     95   JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const {
     96     int shift = static_cast<int>(FloorLog2Nonzero(x + 1)) - 5;
     97     if (shift < 0) shift = 0;
     98     return 4 + ((maxweight * divlookup[x >> shift]) >> shift);
     99   }
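          // The shift keeps the divlookup index within range: for large x the
          // index x >> shift is a scaled-down version of x, and shifting the
          // product back down by the same amount restores the scale, so the
          // result stays close to 4 + (maxweight << 24) / (x + 1) over the whole
          // range of x.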
    100 
    101   // Approximates the weighted average of the input values with the given
    102   // weights, avoiding division. Weights must sum to at least 16.
    103   JXL_INLINE pixel_type_w
    104   WeightedAverage(const pixel_type_w *JXL_RESTRICT p,
    105                   std::array<uint32_t, kNumPredictors> w) const {
    106     uint32_t weight_sum = 0;
    107     for (size_t i = 0; i < kNumPredictors; i++) {
    108       weight_sum += w[i];
    109     }
    110     JXL_DASSERT(weight_sum > 15);
    111     uint32_t log_weight = FloorLog2Nonzero(weight_sum);  // at least 4.
    112     weight_sum = 0;
    113     for (size_t i = 0; i < kNumPredictors; i++) {
    114       w[i] >>= log_weight - 4;
    115       weight_sum += w[i];
    116     }
    117     // for rounding.
    118     pixel_type_w sum = (weight_sum >> 1) - 1;
    119     for (size_t i = 0; i < kNumPredictors; i++) {
    120       sum += p[i] * w[i];
    121     }
    122     return (sum * divlookup[weight_sum - 1]) >> 24;
    123   }
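          // Worked example: p = {20, 10, 10, 10}, w = {4, 4, 4, 4}.
          // weight_sum = 16, log_weight = 4, so the weights are left unscaled.
          // sum = (16 >> 1) - 1 + 20*4 + 10*4 + 10*4 + 10*4 = 207, and
          // (207 * divlookup[15]) >> 24 = 207 >> 4 = 12, close to the exact
          // weighted mean 200 / 16 = 12.5.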
    124 
    125   template <bool compute_properties>
    126   JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize,
    127                                   pixel_type_w N, pixel_type_w W,
    128                                   pixel_type_w NE, pixel_type_w NW,
    129                                   pixel_type_w NN, Properties *properties,
    130                                   size_t offset) {
    131     size_t cur_row = y & 1 ? 0 : (xsize + 2);
    132     size_t prev_row = y & 1 ? (xsize + 2) : 0;
    133     size_t pos_N = prev_row + x;
    134     size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N;
    135     size_t pos_NW = x > 0 ? pos_N - 1 : pos_N;
    136     std::array<uint32_t, kNumPredictors> weights;
    137     for (size_t i = 0; i < kNumPredictors; i++) {
    138       // pred_errors[pos_N] also contains the error of pixel W.
    139       // pred_errors[pos_NW] also contains the error of pixel WW.
    140       weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] +
    141                    pred_errors[i][pos_NW];
    142       weights[i] = ErrorWeight(weights[i], header.w[i]);
    143     }
    144 
    145     N = AddBits(N);
    146     W = AddBits(W);
    147     NE = AddBits(NE);
    148     NW = AddBits(NW);
    149     NN = AddBits(NN);
    150 
    151     pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1];
    152     pixel_type_w teN = error[pos_N];
    153     pixel_type_w teNW = error[pos_NW];
    154     pixel_type_w sumWN = teN + teW;
    155     pixel_type_w teNE = error[pos_NE];
    156 
    157     if (compute_properties) {
    158       pixel_type_w p = teW;
    159       if (std::abs(teN) > std::abs(p)) p = teN;
    160       if (std::abs(teNW) > std::abs(p)) p = teNW;
    161       if (std::abs(teNE) > std::abs(p)) p = teNE;
    162       (*properties)[offset++] = p;
    163     }
    164 
    165     prediction[0] = W + NE - N;
    166     prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5);
    167     prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5);
    168     prediction[3] =
    169         N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc +
    170               (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >>
    171              5);
    172 
    173     pred = WeightedAverage(prediction, weights);
    174 
    175     // If all three have the same sign, skip clamping.
    176     if (((teN ^ teW) | (teN ^ teNW)) > 0) {
    177       return (pred + kPredictionRound) >> kPredExtraBits;
    178     }
    179 
    180     // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N).
    181     pixel_type_w mx = std::max(W, std::max(NE, N));
    182     pixel_type_w mn = std::min(W, std::min(NE, N));
    183     pred = std::max(mn, std::min(mx, pred));
    184     return (pred + kPredictionRound) >> kPredExtraBits;
    185   }
    186 
    187   JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y,
    188                                size_t xsize) {
    189     size_t cur_row = y & 1 ? 0 : (xsize + 2);
    190     size_t prev_row = y & 1 ? (xsize + 2) : 0;
    191     val = AddBits(val);
    192     error[cur_row + x] = pred - val;
    193     for (size_t i = 0; i < kNumPredictors; i++) {
    194       pixel_type_w err =
    195           (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits;
    196       // For predicting in the next row.
    197       pred_errors[i][cur_row + x] = err;
    198       // Add the error on this pixel to the error on the NE pixel. This has the
    199       // effect of adding the error on this pixel to the E and EE pixels.
    200       pred_errors[i][prev_row + x + 1] += err;
    201     }
    202   }
    203 };
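        // Rough usage sketch (illustrative, not code from this file): a codec keeps
        // one State per channel and, scanning pixels in raster order, does roughly
        //
        //   weighted::Header header;  // defaults match PredictorMode(0) below
        //   weighted::State wp(header, xsize, ysize);
        //   for (size_t y = 0; y < ysize; y++) {
        //     for (size_t x = 0; x < xsize; x++) {
        //       pixel_type_w guess = wp.Predict</*compute_properties=*/false>(
        //           x, y, xsize, N, W, NE, NW, NN, /*properties=*/nullptr, 0);
        //       // ... encode or decode the residual against `guess` ...
        //       wp.UpdateErrors(actual_value, x, y, xsize);
        //     }
        //   }
        //
        // where N, W, NE, NW and NN stand for the already-known neighbouring pixel
        // values (with the usual edge handling done by the caller). Predict() is
        // driven this way by detail::Predict further down in this file; UpdateErrors()
        // is called by the modular encoder/decoder after each pixel.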
    204 
    205 // Encoder helper function to set the parameters to some presets.
    206 inline void PredictorMode(int i, Header *header) {
    207   switch (i) {
    208     case 0:
    209       // ~ lossless16 predictor
    210       header->w[0] = 0xd;
    211       header->w[1] = 0xc;
    212       header->w[2] = 0xc;
    213       header->w[3] = 0xc;
    214       header->p1C = 16;
    215       header->p2C = 10;
    216       header->p3Ca = 7;
    217       header->p3Cb = 7;
    218       header->p3Cc = 7;
    219       header->p3Cd = 0;
    220       header->p3Ce = 0;
    221       break;
    222     case 1:
    223       // ~ default lossless8 predictor
    224       header->w[0] = 0xd;
    225       header->w[1] = 0xc;
    226       header->w[2] = 0xc;
    227       header->w[3] = 0xb;
    228       header->p1C = 8;
    229       header->p2C = 8;
    230       header->p3Ca = 4;
    231       header->p3Cb = 0;
    232       header->p3Cc = 3;
    233       header->p3Cd = 23;
    234       header->p3Ce = 2;
    235       break;
    236     case 2:
    237       // ~ west lossless8 predictor
    238       header->w[0] = 0xd;
    239       header->w[1] = 0xc;
    240       header->w[2] = 0xd;
    241       header->w[3] = 0xc;
    242       header->p1C = 10;
    243       header->p2C = 9;
    244       header->p3Ca = 7;
    245       header->p3Cb = 0;
    246       header->p3Cc = 0;
    247       header->p3Cd = 16;
    248       header->p3Ce = 9;
    249       break;
    250     case 3:
    251       // ~ north lossless8 predictor
    252       header->w[0] = 0xd;
    253       header->w[1] = 0xd;
    254       header->w[2] = 0xc;
    255       header->w[3] = 0xc;
    256       header->p1C = 16;
    257       header->p2C = 8;
    258       header->p3Ca = 0;
    259       header->p3Cb = 16;
    260       header->p3Cc = 0;
    261       header->p3Cd = 23;
    262       header->p3Ce = 0;
    263       break;
    264     case 4:
    265     default:
    266       // something else, because why not
    267       header->w[0] = 0xd;
    268       header->w[1] = 0xc;
    269       header->w[2] = 0xc;
    270       header->w[3] = 0xc;
    271       header->p1C = 10;
    272       header->p2C = 10;
    273       header->p3Ca = 5;
    274       header->p3Cb = 5;
    275       header->p3Cc = 5;
    276       header->p3Cd = 12;
    277       header->p3Ce = 4;
    278       break;
    279   }
    280 }
    281 }  // namespace weighted
    282 
    283 // Stores a node and its two children at the same time. This significantly
    284 // reduces the number of branches needed during decoding.
    285 struct FlatDecisionNode {
    286   // Property + splitval of the top node.
    287   int32_t property0;  // -1 if leaf.
    288   union {
    289     PropertyVal splitval0;
    290     Predictor predictor;
    291   };
    292   // Property+splitval of the two child nodes.
    293   union {
    294     PropertyVal splitvals[2];
    295     int32_t multiplier;
    296   };
    297   uint32_t childID;  // childID is ctx id if leaf.
    298   union {
    299     int16_t properties[2];
    300     int32_t predictor_offset;
    301   };
    302 };
    303 using FlatTree = std::vector<FlatDecisionNode>;
    304 
    305 class MATreeLookup {
    306  public:
    307   explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {}
    308   struct LookupResult {
    309     uint32_t context;
    310     Predictor predictor;
    311     int32_t offset;
    312     int32_t multiplier;
    313   };
    314   JXL_INLINE LookupResult Lookup(const Properties &properties) const {
    315     uint32_t pos = 0;
    316     while (true) {
    317 #define TRAVERSE_THE_TREE                                                      \
    318   {                                                                            \
    319     const FlatDecisionNode &node = nodes_[pos];                                \
    320     if (node.property0 < 0) {                                                  \
    321       return {node.childID, node.predictor, node.predictor_offset,             \
    322               node.multiplier};                                                \
    323     }                                                                          \
    324     bool p0 = properties[node.property0] <= node.splitval0;                    \
    325     uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0];       \
    326     uint32_t off1 = 2 | (properties[node.properties[1]] <= node.splitvals[1]); \
    327     pos = node.childID + (p0 ? off1 : off0);                                   \
    328   }
    329 
    330       TRAVERSE_THE_TREE;
    331       TRAVERSE_THE_TREE;
    332     }
    333   }
    334 
    335  private:
    336   const FlatTree &nodes_;
    337 };
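        // How the flat layout is walked: each FlatDecisionNode packs the test of one
        // tree node (property0 / splitval0) together with the tests of its two
        // children (properties[] / splitvals[]), and childID is the index of the
        // first of the four grandchildren. A single TRAVERSE_THE_TREE step therefore
        // descends two levels: p0 picks the child, off0/off1 pick that child's
        // child, and pos = childID + {0, 1, 2, 3} selects the grandchild. For leaves
        // (property0 < 0) the same fields are reused to return the context id,
        // predictor, predictor offset and multiplier.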
    338 
    339 static constexpr size_t kExtraPropsPerChannel = 4;
    340 static constexpr size_t kNumNonrefProperties =
    341     kNumStaticProperties + 13 + weighted::kNumProperties;
    342 
    343 constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties;
    344 constexpr size_t kGradientProp = 9;
    345 
    346 // Clamps gradient to the min/max of n, w (and l, implicitly).
    347 static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w,
    348                                           const int32_t l) {
    349   const int32_t m = std::min(n, w);
    350   const int32_t M = std::max(n, w);
    351   // The end result of this operation doesn't overflow or underflow if the
    352   // result is between m and M, but the intermediate value may overflow, so we
    353   // do the intermediate operations in uint32_t and check later if we had an
    354   // overflow or underflow condition comparing m, M and l directly.
    355   // grad = M + m - l = n + w - l
    356   const int32_t grad =
    357       static_cast<int32_t>(static_cast<uint32_t>(n) + static_cast<uint32_t>(w) -
    358                            static_cast<uint32_t>(l));
    359   // We use two ternary operators so that both are always evaluated; this
    360   // lets the compiler avoid branches and use cmovl/cmovg on
    361   // x86.
    362   const int32_t grad_clamp_M = (l < m) ? M : grad;
    363   return (l > M) ? m : grad_clamp_M;
    364 }
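        // Worked example: n = 10, w = 40. For l = 5 the gradient n + w - l = 45
        // overshoots, and since l < min(n, w) the result is clamped to
        // max(n, w) = 40. For l = 60 the gradient is -10, and since l > max(n, w)
        // the result is clamped to min(n, w) = 10. For 10 <= l <= 40 the plain
        // gradient is returned.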
    365 
    366 inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) {
    367   pixel_type_w p = a + b - c;
    368   pixel_type_w pa = std::abs(p - a);
    369   pixel_type_w pb = std::abs(p - b);
    370   return pa < pb ? a : b;
    371 }
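        // Worked example: a = 10, b = 20, c = 12 gives p = 18, pa = 8, pb = 2, so b
        // (the neighbour closer to the gradient estimate p) is chosen.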
    372 
    373 inline void PrecomputeReferences(const Channel &ch, size_t y,
    374                                  const Image &image, uint32_t i,
    375                                  Channel *references) {
    376   ZeroFillImage(&references->plane);
    377   uint32_t offset = 0;
    378   size_t num_extra_props = references->w;
    379   intptr_t onerow = references->plane.PixelsPerRow();
    380   for (int32_t j = static_cast<int32_t>(i) - 1;
    381        j >= 0 && offset < num_extra_props; j--) {
    382     if (image.channel[j].w != image.channel[i].w ||
    383         image.channel[j].h != image.channel[i].h) {
    384       continue;
    385     }
    386     if (image.channel[j].hshift != image.channel[i].hshift) continue;
    387     if (image.channel[j].vshift != image.channel[i].vshift) continue;
    388     pixel_type *JXL_RESTRICT rp = references->Row(0) + offset;
    389     const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y);
    390     const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0);
    391     for (size_t x = 0; x < ch.w; x++, rp += onerow) {
    392       pixel_type_w v = rpp[x];
    393       rp[0] = std::abs(v);
    394       rp[1] = v;
    395       pixel_type_w vleft = (x ? rpp[x - 1] : 0);
    396       pixel_type_w vtop = (y ? rpprev[x] : vleft);
    397       pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft);
    398       pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft);
    399       rp[2] = std::abs(v - vpredicted);
    400       rp[3] = v - vpredicted;
    401     }
    402 
    403     offset += kExtraPropsPerChannel;
    404   }
    405 }
    406 
    407 struct PredictionResult {
    408   int context = 0;
    409   pixel_type_w guess = 0;
    410   Predictor predictor;
    411   int32_t multiplier;
    412 };
    413 
    414 inline void InitPropsRow(
    415     Properties *p,
    416     const std::array<pixel_type, kNumStaticProperties> &static_props,
    417     const int y) {
    418   for (size_t i = 0; i < kNumStaticProperties; i++) {
    419     (*p)[i] = static_props[i];
    420   }
    421   (*p)[2] = y;
    422   (*p)[9] = 0;  // local gradient.
    423 }
    424 
    425 namespace detail {
    426 enum PredictorMode {
    427   kUseTree = 1,
    428   kUseWP = 2,
    429   kForceComputeProperties = 4,
    430   kAllPredictions = 8,
    431   kNoEdgeCases = 16
    432 };
    433 
    434 JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left,
    435                                    pixel_type_w top, pixel_type_w toptop,
    436                                    pixel_type_w topleft, pixel_type_w topright,
    437                                    pixel_type_w leftleft,
    438                                    pixel_type_w toprightright,
    439                                    pixel_type_w wp_pred) {
    440   switch (p) {
    441     case Predictor::Zero:
    442       return pixel_type_w{0};
    443     case Predictor::Left:
    444       return left;
    445     case Predictor::Top:
    446       return top;
    447     case Predictor::Select:
    448       return Select(left, top, topleft);
    449     case Predictor::Weighted:
    450       return wp_pred;
    451     case Predictor::Gradient:
    452       return pixel_type_w{ClampedGradient(left, top, topleft)};
    453     case Predictor::TopLeft:
    454       return topleft;
    455     case Predictor::TopRight:
    456       return topright;
    457     case Predictor::LeftLeft:
    458       return leftleft;
    459     case Predictor::Average0:
    460       return (left + top) / 2;
    461     case Predictor::Average1:
    462       return (left + topleft) / 2;
    463     case Predictor::Average2:
    464       return (topleft + top) / 2;
    465     case Predictor::Average3:
    466       return (top + topright) / 2;
    467     case Predictor::Average4:
    468       return (6 * top - 2 * toptop + 7 * left + 1 * leftleft +
    469               1 * toprightright + 3 * topright + 8) /
    470              16;
    471     default:
    472       return pixel_type_w{0};
    473   }
    474 }
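        // Note on Average4: the coefficients 6 - 2 + 7 + 1 + 1 + 3 = 16 sum to 16, so
        // it is a fixed-point weighted mean of six neighbours, with the +8 giving
        // (approximately) round-to-nearest before the division by 16.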
    475 
    476 template <int mode>
    477 JXL_INLINE PredictionResult Predict(
    478     Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
    479     const intptr_t onerow, const size_t x, const size_t y, Predictor predictor,
    480     const MATreeLookup *lookup, const Channel *references,
    481     weighted::State *wp_state, pixel_type_w *predictions) {
    482   // We start in position 3 because of 2 static properties + y.
    483   size_t offset = 3;
    484   constexpr bool compute_properties =
    485       mode & kUseTree || mode & kForceComputeProperties;
    486   constexpr bool nec = mode & kNoEdgeCases;
    487   pixel_type_w left = (nec || x ? pp[-1] : (y ? pp[-onerow] : 0));
    488   pixel_type_w top = (nec || y ? pp[-onerow] : left);
    489   pixel_type_w topleft = (nec || (x && y) ? pp[-1 - onerow] : left);
    490   pixel_type_w topright = (nec || (x + 1 < w && y) ? pp[1 - onerow] : top);
    491   pixel_type_w leftleft = (nec || x > 1 ? pp[-2] : left);
    492   pixel_type_w toptop = (nec || y > 1 ? pp[-onerow - onerow] : top);
    493   pixel_type_w toprightright =
    494       (nec || (x + 2 < w && y) ? pp[2 - onerow] : topright);
    495 
    496   if (compute_properties) {
    497     // location
    498     (*p)[offset++] = x;
    499     // neighbors
    500     (*p)[offset++] = top > 0 ? top : -top;
    501     (*p)[offset++] = left > 0 ? left : -left;
    502     (*p)[offset++] = top;
    503     (*p)[offset++] = left;
    504 
    505     // local gradient
    506     (*p)[offset] = left - (*p)[offset + 1];
    507     offset++;
    508     // local gradient
    509     (*p)[offset++] = left + top - topleft;
    510 
    511     // FFV1 context properties
    512     (*p)[offset++] = left - topleft;
    513     (*p)[offset++] = topleft - top;
    514     (*p)[offset++] = top - topright;
    515     (*p)[offset++] = top - toptop;
    516     (*p)[offset++] = left - leftleft;
    517   }
    518 
    519   pixel_type_w wp_pred = 0;
    520   if (mode & kUseWP) {
    521     wp_pred = wp_state->Predict<compute_properties>(
    522         x, y, w, top, left, topright, topleft, toptop, p, offset);
    523   }
    524   if (!nec && compute_properties) {
    525     offset += weighted::kNumProperties;
    526     // Extra properties.
    527     const pixel_type *JXL_RESTRICT rp = references->Row(x);
    528     for (size_t i = 0; i < references->w; i++) {
    529       (*p)[offset++] = rp[i];
    530     }
    531   }
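          // Property vector layout when properties are computed (filled by
          // InitPropsRow and the code above): [0..1] static properties, [2] y,
          // [3] x, [4] |N|, [5] |W|, [6] N, [7] W, [8] W minus the previous
          // pixel's entry [9], [9] N + W - NW (kGradientProp), [10..14] the
          // FFV1-style neighbour differences, [15] the weighted predictor's error
          // property (kWPProp, only filled when kUseWP is set), and [16...] the
          // per-channel reference properties.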
    532   PredictionResult result;
    533   if (mode & kUseTree) {
    534     MATreeLookup::LookupResult lr = lookup->Lookup(*p);
    535     result.context = lr.context;
    536     result.guess = lr.offset;
    537     result.multiplier = lr.multiplier;
    538     predictor = lr.predictor;
    539   }
    540   if (mode & kAllPredictions) {
    541     for (size_t i = 0; i < kNumModularPredictors; i++) {
    542       predictions[i] =
    543           PredictOne(static_cast<Predictor>(i), left, top, toptop, topleft,
    544                      topright, leftleft, toprightright, wp_pred);
    545     }
    546   }
    547   result.guess += PredictOne(predictor, left, top, toptop, topleft, topright,
    548                              leftleft, toprightright, wp_pred);
    549   result.predictor = predictor;
    550 
    551   return result;
    552 }
    553 }  // namespace detail
    554 
    555 inline PredictionResult PredictNoTreeNoWP(size_t w,
    556                                           const pixel_type *JXL_RESTRICT pp,
    557                                           const intptr_t onerow, const int x,
    558                                           const int y, Predictor predictor) {
    559   return detail::Predict</*mode=*/0>(
    560       /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
    561       /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr);
    562 }
    563 
    564 inline PredictionResult PredictNoTreeWP(size_t w,
    565                                         const pixel_type *JXL_RESTRICT pp,
    566                                         const intptr_t onerow, const int x,
    567                                         const int y, Predictor predictor,
    568                                         weighted::State *wp_state) {
    569   return detail::Predict<detail::kUseWP>(
    570       /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr,
    571       /*references=*/nullptr, wp_state, /*predictions=*/nullptr);
    572 }
    573 
    574 inline PredictionResult PredictTreeNoWP(Properties *p, size_t w,
    575                                         const pixel_type *JXL_RESTRICT pp,
    576                                         const intptr_t onerow, const int x,
    577                                         const int y,
    578                                         const MATreeLookup &tree_lookup,
    579                                         const Channel &references) {
    580   return detail::Predict<detail::kUseTree>(
    581       p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
    582       /*wp_state=*/nullptr, /*predictions=*/nullptr);
    583 }
    584 // Only use for y > 1, x > 1, x < w-2, and empty references
    585 JXL_INLINE PredictionResult
    586 PredictTreeNoWPNEC(Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp,
    587                    const intptr_t onerow, const int x, const int y,
    588                    const MATreeLookup &tree_lookup, const Channel &references) {
    589   return detail::Predict<detail::kUseTree | detail::kNoEdgeCases>(
    590       p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
    591       /*wp_state=*/nullptr, /*predictions=*/nullptr);
    592 }
    593 
    594 inline PredictionResult PredictTreeWP(Properties *p, size_t w,
    595                                       const pixel_type *JXL_RESTRICT pp,
    596                                       const intptr_t onerow, const int x,
    597                                       const int y,
    598                                       const MATreeLookup &tree_lookup,
    599                                       const Channel &references,
    600                                       weighted::State *wp_state) {
    601   return detail::Predict<detail::kUseTree | detail::kUseWP>(
    602       p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
    603       wp_state, /*predictions=*/nullptr);
    604 }
    605 JXL_INLINE PredictionResult PredictTreeWPNEC(Properties *p, size_t w,
    606                                              const pixel_type *JXL_RESTRICT pp,
    607                                              const intptr_t onerow, const int x,
    608                                              const int y,
    609                                              const MATreeLookup &tree_lookup,
    610                                              const Channel &references,
    611                                              weighted::State *wp_state) {
    612   return detail::Predict<detail::kUseTree | detail::kUseWP |
    613                          detail::kNoEdgeCases>(
    614       p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references,
    615       wp_state, /*predictions=*/nullptr);
    616 }
    617 
    618 inline PredictionResult PredictLearn(Properties *p, size_t w,
    619                                      const pixel_type *JXL_RESTRICT pp,
    620                                      const intptr_t onerow, const int x,
    621                                      const int y, Predictor predictor,
    622                                      const Channel &references,
    623                                      weighted::State *wp_state) {
    624   return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>(
    625       p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
    626       wp_state, /*predictions=*/nullptr);
    627 }
    628 
    629 inline void PredictLearnAll(Properties *p, size_t w,
    630                             const pixel_type *JXL_RESTRICT pp,
    631                             const intptr_t onerow, const int x, const int y,
    632                             const Channel &references,
    633                             weighted::State *wp_state,
    634                             pixel_type_w *predictions) {
    635   detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
    636                   detail::kAllPredictions>(
    637       p, w, pp, onerow, x, y, Predictor::Zero,
    638       /*lookup=*/nullptr, &references, wp_state, predictions);
    639 }
    640 inline PredictionResult PredictLearnNEC(Properties *p, size_t w,
    641                                         const pixel_type *JXL_RESTRICT pp,
    642                                         const intptr_t onerow, const int x,
    643                                         const int y, Predictor predictor,
    644                                         const Channel &references,
    645                                         weighted::State *wp_state) {
    646   return detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
    647                          detail::kNoEdgeCases>(
    648       p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references,
    649       wp_state, /*predictions=*/nullptr);
    650 }
    651 
    652 inline void PredictLearnAllNEC(Properties *p, size_t w,
    653                                const pixel_type *JXL_RESTRICT pp,
    654                                const intptr_t onerow, const int x, const int y,
    655                                const Channel &references,
    656                                weighted::State *wp_state,
    657                                pixel_type_w *predictions) {
    658   detail::Predict<detail::kForceComputeProperties | detail::kUseWP |
    659                   detail::kAllPredictions | detail::kNoEdgeCases>(
    660       p, w, pp, onerow, x, y, Predictor::Zero,
    661       /*lookup=*/nullptr, &references, wp_state, predictions);
    662 }
    663 
    664 inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp,
    665                            const intptr_t onerow, const int x, const int y,
    666                            pixel_type_w *predictions) {
    667   detail::Predict<detail::kAllPredictions>(
    668       /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero,
    669       /*lookup=*/nullptr,
    670       /*references=*/nullptr, /*wp_state=*/nullptr, predictions);
    671 }
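        // Rough sketch of how the helpers above are typically driven (illustrative
        // only; wp_header, tree_lookup, static_props, chan_index and the references
        // channel are assumed to have been set up by the caller, and the real loops
        // live in the modular encoder/decoder sources):
        //
        //   Properties properties(kNumNonrefProperties + references.w);
        //   weighted::State wp_state(wp_header, channel.w, channel.h);
        //   for (size_t y = 0; y < channel.h; y++) {
        //     pixel_type *JXL_RESTRICT row = channel.Row(y);
        //     PrecomputeReferences(channel, y, image, chan_index, &references);
        //     InitPropsRow(&properties, static_props, y);
        //     for (size_t x = 0; x < channel.w; x++) {
        //       PredictionResult res =
        //           PredictTreeWP(&properties, channel.w, row + x,
        //                         channel.plane.PixelsPerRow(), x, y, tree_lookup,
        //                         references, &wp_state);
        //       // decode (or compute) the residual r for context res.context, then:
        //       row[x] = res.guess + res.multiplier * r;
        //       wp_state.UpdateErrors(row[x], x, y, channel.w);
        //     }
        //   }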
    672 }  // namespace jxl
    673 
    674 #endif  // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_