context_predict.h (25614B)
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#ifndef LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_
#define LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_

#include <utility>
#include <vector>

#include "lib/jxl/fields.h"
#include "lib/jxl/image_ops.h"
#include "lib/jxl/modular/modular_image.h"
#include "lib/jxl/modular/options.h"

namespace jxl {

// Weighted ("self-correcting") predictor: four sub-predictions are blended
// with weights derived from each sub-predictor's recent errors.
namespace weighted {
// Number of sub-predictors that are blended together.
constexpr static size_t kNumPredictors = 4;
// Internal predictions carry this many extra fractional bits (see AddBits).
constexpr static int64_t kPredExtraBits = 3;
// Rounding bias used when the kPredExtraBits extra bits are dropped (== 3).
constexpr static int64_t kPredictionRound = ((1 << kPredExtraBits) >> 1) - 1;
// Number of context properties this predictor contributes (written by
// State::Predict when compute_properties is set).
constexpr static size_t kNumProperties = 1;

// Serialized parameters of the weighted predictor.
struct Header : public Fields {
  JXL_FIELDS_NAME(WeightedPredictorHeader)
  // TODO(janwas): move to cc file, avoid including fields.h.
  Header() { Bundle::Init(this); }

  // Visits the serialized fields. Unless AllDefault() reports that all fields
  // are default, this visits seven 5-bit coefficients (p1C..p3Ce) and then
  // four 4-bit maximum weights (w[0..3]). The default values equal preset 0
  // of PredictorMode() below.
  Status VisitFields(Visitor *JXL_RESTRICT visitor) override {
    if (visitor->AllDefault(*this, &all_default)) {
      // Overwrite all serialized fields, but not any nonserialized_*.
      visitor->SetDefault(this);
      return true;
    }
    // Visits a pixel_type field through an unsigned 5-bit temporary, using
    // `val` as the default.
    auto visit_p = [visitor](pixel_type val, pixel_type *p) {
      uint32_t up = *p;
      JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(5, val, &up));
      *p = up;
      return Status(true);
    };
    JXL_QUIET_RETURN_IF_ERROR(visit_p(16, &p1C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(10, &p2C));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Ca));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cb));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(7, &p3Cc));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Cd));
    JXL_QUIET_RETURN_IF_ERROR(visit_p(0, &p3Ce));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xd, &w[0]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[1]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[2]));
    JXL_QUIET_RETURN_IF_ERROR(visitor->Bits(4, 0xc, &w[3]));
    return true;
  }

  bool all_default;
  // Coefficients of sub-predictors 1-3. Each weighted sum they form in
  // State::Predict is shifted right by 5, i.e. they are in units of 1/32.
  pixel_type p1C = 0, p2C = 0, p3Ca = 0, p3Cb = 0, p3Cc = 0, p3Cd = 0, p3Ce = 0;
  // Per-sub-predictor maximum weight, passed to State::ErrorWeight().
  uint32_t w[kNumPredictors] = {};
};

// Running state of the weighted predictor for one channel. Error buffers hold
// two rows of (xsize + 2) entries; the row used for the current line
// alternates with the parity of y.
struct State {
  // Sub-predictions for the current pixel, with kPredExtraBits extra bits.
  pixel_type_w prediction[kNumPredictors] = {};
  pixel_type_w pred = 0;  // *before* removing the added bits.
  // Per-sub-predictor absolute errors (in pixel units, rounded).
  std::vector<uint32_t> pred_errors[kNumPredictors];
  // Signed error (pred - actual, with the extra bits) of the blended
  // prediction.
  std::vector<int32_t> error;
  const Header header;

  // Allows to approximate division by a number from 1 to 64.
  // for (int i = 0; i < 64; i++) divlookup[i] = (1 << 24) / (i + 1);

  const uint32_t divlookup[64] = {
      16777216, 8388608, 5592405, 4194304, 3355443, 2796202, 2396745, 2097152,
      1864135,  1677721, 1525201, 1398101, 1290555, 1198372, 1118481, 1048576,
      986895,   932067,  883011,  838860,  798915,  762600,  729444,  699050,
      671088,   645277,  621378,  599186,  578524,  559240,  541200,  524288,
      508400,   493447,  479349,  466033,  453438,  441505,  430185,  419430,
      409200,   399457,  390167,  381300,  372827,  364722,  356962,  349525,
      342392,   335544,  328965,  322638,  316551,  310689,  305040,  299593,
      294337,   289262,  284359,  279620,  275036,  270600,  266305,  262144};

  // Shifts x left by kPredExtraBits (done in uint64_t so that negative values
  // are shifted without signed-shift UB).
  constexpr static pixel_type_w AddBits(pixel_type_w x) {
    return static_cast<uint64_t>(x) << kPredExtraBits;
  }

  // Allocates the two-row error buffers for rows of `xsize` pixels.
  // `ysize` is not used.
  State(Header header, size_t xsize, size_t ysize) : header(header) {
    // Extra margin to avoid out-of-bounds writes.
    // All have space for two rows of data.
    for (auto &pred_error : pred_errors) {
      pred_error.resize((xsize + 2) * 2);
    }
    error.resize((xsize + 2) * 2);
  }

  // Approximates 4+(maxweight<<24)/(x+1), avoiding division
  JXL_INLINE uint32_t ErrorWeight(uint64_t x, uint32_t maxweight) const {
    // Scale x down so that the table index x >> shift is at most 63.
    int shift = static_cast<int>(FloorLog2Nonzero(x + 1)) - 5;
    if (shift < 0) shift = 0;
    return 4 + ((maxweight * divlookup[x >> shift]) >> shift);
  }

  // Approximates the weighted average of the input values with the given
  // weights, avoiding division. Weights must sum to at least 16.
  JXL_INLINE pixel_type_w
  WeightedAverage(const pixel_type_w *JXL_RESTRICT p,
                  std::array<uint32_t, kNumPredictors> w) const {
    uint32_t weight_sum = 0;
    for (size_t i = 0; i < kNumPredictors; i++) {
      weight_sum += w[i];
    }
    JXL_DASSERT(weight_sum > 15);
    uint32_t log_weight = FloorLog2Nonzero(weight_sum);  // at least 4.
    // Rescale the weights by a power of two so that their new sum is below
    // 32 and can be used as an index into divlookup.
    weight_sum = 0;
    for (size_t i = 0; i < kNumPredictors; i++) {
      w[i] >>= log_weight - 4;
      weight_sum += w[i];
    }
    // for rounding.
    pixel_type_w sum = (weight_sum >> 1) - 1;
    for (size_t i = 0; i < kNumPredictors; i++) {
      sum += p[i] * w[i];
    }
    return (sum * divlookup[weight_sum - 1]) >> 24;
  }

  // Predicts pixel (x, y) of a row of `xsize` pixels from its neighbours
  // N, W, NE, NW, NN (given in pixel units). If compute_properties is set,
  // stores the neighbouring error with the largest magnitude in
  // (*properties)[offset]. Leaves the sub-predictions in `prediction` and the
  // blended value in `pred` for UpdateErrors(), and returns the prediction
  // rounded back to pixel units.
  template <bool compute_properties>
  JXL_INLINE pixel_type_w Predict(size_t x, size_t y, size_t xsize,
                                  pixel_type_w N, pixel_type_w W,
                                  pixel_type_w NE, pixel_type_w NW,
                                  pixel_type_w NN, Properties *properties,
                                  size_t offset) {
    size_t cur_row = y & 1 ? 0 : (xsize + 2);
    size_t prev_row = y & 1 ? (xsize + 2) : 0;
    size_t pos_N = prev_row + x;
    // At the row edges, NE / NW fall back to N.
    size_t pos_NE = x < xsize - 1 ? pos_N + 1 : pos_N;
    size_t pos_NW = x > 0 ? pos_N - 1 : pos_N;
    std::array<uint32_t, kNumPredictors> weights;
    for (size_t i = 0; i < kNumPredictors; i++) {
      // pred_errors[pos_N] also contains the error of pixel W.
      // pred_errors[pos_NW] also contains the error of pixel WW.
      weights[i] = pred_errors[i][pos_N] + pred_errors[i][pos_NE] +
                   pred_errors[i][pos_NW];
      weights[i] = ErrorWeight(weights[i], header.w[i]);
    }

    N = AddBits(N);
    W = AddBits(W);
    NE = AddBits(NE);
    NW = AddBits(NW);
    NN = AddBits(NN);

    // Signed errors of the blended prediction at the neighbouring pixels.
    pixel_type_w teW = x == 0 ? 0 : error[cur_row + x - 1];
    pixel_type_w teN = error[pos_N];
    pixel_type_w teNW = error[pos_NW];
    pixel_type_w sumWN = teN + teW;
    pixel_type_w teNE = error[pos_NE];

    if (compute_properties) {
      pixel_type_w p = teW;
      if (std::abs(teN) > std::abs(p)) p = teN;
      if (std::abs(teNW) > std::abs(p)) p = teNW;
      if (std::abs(teNE) > std::abs(p)) p = teNE;
      (*properties)[offset++] = p;
    }

    prediction[0] = W + NE - N;
    prediction[1] = N - (((sumWN + teNE) * header.p1C) >> 5);
    prediction[2] = W - (((sumWN + teNW) * header.p2C) >> 5);
    prediction[3] =
        N - ((teNW * header.p3Ca + teN * header.p3Cb + teNE * header.p3Cc +
              (NN - N) * header.p3Cd + (NW - W) * header.p3Ce) >>
             5);

    pred = WeightedAverage(prediction, weights);

    // If all three have the same sign, skip clamping.
    if (((teN ^ teW) | (teN ^ teNW)) > 0) {
      return (pred + kPredictionRound) >> kPredExtraBits;
    }

    // Otherwise, clamp to min/max of neighbouring pixels (just W, NE, N).
    pixel_type_w mx = std::max(W, std::max(NE, N));
    pixel_type_w mn = std::min(W, std::min(NE, N));
    pred = std::max(mn, std::min(mx, pred));
    return (pred + kPredictionRound) >> kPredExtraBits;
  }

  // Records the errors for pixel (x, y) once its actual value `val` is known.
  // Uses `pred` and `prediction` as set by the preceding Predict() call.
  JXL_INLINE void UpdateErrors(pixel_type_w val, size_t x, size_t y,
                               size_t xsize) {
    size_t cur_row = y & 1 ? 0 : (xsize + 2);
    size_t prev_row = y & 1 ? (xsize + 2) : 0;
    val = AddBits(val);
    error[cur_row + x] = pred - val;
    for (size_t i = 0; i < kNumPredictors; i++) {
      pixel_type_w err =
          (std::abs(prediction[i] - val) + kPredictionRound) >> kPredExtraBits;
      // For predicting in the next row.
      pred_errors[i][cur_row + x] = err;
      // Add the error on this pixel to the error on the NE pixel. This has the
      // effect of adding the error on this pixel to the E and EE pixels.
      pred_errors[i][prev_row + x + 1] += err;
    }
  }
};

// Encoder helper function to set the parameters to some presets.
// Presets 0-3 are selected by `i`; any other value selects preset 4.
// Preset 0 equals the Header defaults.
inline void PredictorMode(int i, Header *header) {
  switch (i) {
    case 0:
      // ~ lossless16 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xc;
      header->w[3] = 0xc;
      header->p1C = 16;
      header->p2C = 10;
      header->p3Ca = 7;
      header->p3Cb = 7;
      header->p3Cc = 7;
      header->p3Cd = 0;
      header->p3Ce = 0;
      break;
    case 1:
      // ~ default lossless8 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xc;
      header->w[3] = 0xb;
      header->p1C = 8;
      header->p2C = 8;
      header->p3Ca = 4;
      header->p3Cb = 0;
      header->p3Cc = 3;
      header->p3Cd = 23;
      header->p3Ce = 2;
      break;
    case 2:
      // ~ west lossless8 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xd;
      header->w[3] = 0xc;
      header->p1C = 10;
      header->p2C = 9;
      header->p3Ca = 7;
      header->p3Cb = 0;
      header->p3Cc = 0;
      header->p3Cd = 16;
      header->p3Ce = 9;
      break;
    case 3:
      // ~ north lossless8 predictor
      header->w[0] = 0xd;
      header->w[1] = 0xd;
      header->w[2] = 0xc;
      header->w[3] = 0xc;
      header->p1C = 16;
      header->p2C = 8;
      header->p3Ca = 0;
      header->p3Cb = 16;
      header->p3Cc = 0;
      header->p3Cd = 23;
      header->p3Ce = 0;
      break;
    case 4:
    default:
      // something else, because why not
      header->w[0] = 0xd;
      header->w[1] = 0xc;
      header->w[2] = 0xc;
      header->w[3] = 0xc;
      header->p1C = 10;
      header->p2C = 10;
      header->p3Ca = 5;
      header->p3Cb = 5;
      header->p3Cc = 5;
      header->p3Cd = 12;
      header->p3Ce = 4;
      break;
  }
}
}  // namespace weighted

// Stores a node and its two children at the same time. This significantly
// reduces the number of branches needed during decoding.
285 struct FlatDecisionNode { 286 // Property + splitval of the top node. 287 int32_t property0; // -1 if leaf. 288 union { 289 PropertyVal splitval0; 290 Predictor predictor; 291 }; 292 // Property+splitval of the two child nodes. 293 union { 294 PropertyVal splitvals[2]; 295 int32_t multiplier; 296 }; 297 uint32_t childID; // childID is ctx id if leaf. 298 union { 299 int16_t properties[2]; 300 int32_t predictor_offset; 301 }; 302 }; 303 using FlatTree = std::vector<FlatDecisionNode>; 304 305 class MATreeLookup { 306 public: 307 explicit MATreeLookup(const FlatTree &tree) : nodes_(tree) {} 308 struct LookupResult { 309 uint32_t context; 310 Predictor predictor; 311 int32_t offset; 312 int32_t multiplier; 313 }; 314 JXL_INLINE LookupResult Lookup(const Properties &properties) const { 315 uint32_t pos = 0; 316 while (true) { 317 #define TRAVERSE_THE_TREE \ 318 { \ 319 const FlatDecisionNode &node = nodes_[pos]; \ 320 if (node.property0 < 0) { \ 321 return {node.childID, node.predictor, node.predictor_offset, \ 322 node.multiplier}; \ 323 } \ 324 bool p0 = properties[node.property0] <= node.splitval0; \ 325 uint32_t off0 = properties[node.properties[0]] <= node.splitvals[0]; \ 326 uint32_t off1 = 2 | (properties[node.properties[1]] <= node.splitvals[1]); \ 327 pos = node.childID + (p0 ? off1 : off0); \ 328 } 329 330 TRAVERSE_THE_TREE; 331 TRAVERSE_THE_TREE; 332 } 333 } 334 335 private: 336 const FlatTree &nodes_; 337 }; 338 339 static constexpr size_t kExtraPropsPerChannel = 4; 340 static constexpr size_t kNumNonrefProperties = 341 kNumStaticProperties + 13 + weighted::kNumProperties; 342 343 constexpr size_t kWPProp = kNumNonrefProperties - weighted::kNumProperties; 344 constexpr size_t kGradientProp = 9; 345 346 // Clamps gradient to the min/max of n, w (and l, implicitly). 
347 static JXL_INLINE int32_t ClampedGradient(const int32_t n, const int32_t w, 348 const int32_t l) { 349 const int32_t m = std::min(n, w); 350 const int32_t M = std::max(n, w); 351 // The end result of this operation doesn't overflow or underflow if the 352 // result is between m and M, but the intermediate value may overflow, so we 353 // do the intermediate operations in uint32_t and check later if we had an 354 // overflow or underflow condition comparing m, M and l directly. 355 // grad = M + m - l = n + w - l 356 const int32_t grad = 357 static_cast<int32_t>(static_cast<uint32_t>(n) + static_cast<uint32_t>(w) - 358 static_cast<uint32_t>(l)); 359 // We use two sets of ternary operators to force the evaluation of them in 360 // any case, allowing the compiler to avoid branches and use cmovl/cmovg in 361 // x86. 362 const int32_t grad_clamp_M = (l < m) ? M : grad; 363 return (l > M) ? m : grad_clamp_M; 364 } 365 366 inline pixel_type_w Select(pixel_type_w a, pixel_type_w b, pixel_type_w c) { 367 pixel_type_w p = a + b - c; 368 pixel_type_w pa = std::abs(p - a); 369 pixel_type_w pb = std::abs(p - b); 370 return pa < pb ? 
a : b; 371 } 372 373 inline void PrecomputeReferences(const Channel &ch, size_t y, 374 const Image &image, uint32_t i, 375 Channel *references) { 376 ZeroFillImage(&references->plane); 377 uint32_t offset = 0; 378 size_t num_extra_props = references->w; 379 intptr_t onerow = references->plane.PixelsPerRow(); 380 for (int32_t j = static_cast<int32_t>(i) - 1; 381 j >= 0 && offset < num_extra_props; j--) { 382 if (image.channel[j].w != image.channel[i].w || 383 image.channel[j].h != image.channel[i].h) { 384 continue; 385 } 386 if (image.channel[j].hshift != image.channel[i].hshift) continue; 387 if (image.channel[j].vshift != image.channel[i].vshift) continue; 388 pixel_type *JXL_RESTRICT rp = references->Row(0) + offset; 389 const pixel_type *JXL_RESTRICT rpp = image.channel[j].Row(y); 390 const pixel_type *JXL_RESTRICT rpprev = image.channel[j].Row(y ? y - 1 : 0); 391 for (size_t x = 0; x < ch.w; x++, rp += onerow) { 392 pixel_type_w v = rpp[x]; 393 rp[0] = std::abs(v); 394 rp[1] = v; 395 pixel_type_w vleft = (x ? rpp[x - 1] : 0); 396 pixel_type_w vtop = (y ? rpprev[x] : vleft); 397 pixel_type_w vtopleft = (x && y ? rpprev[x - 1] : vleft); 398 pixel_type_w vpredicted = ClampedGradient(vleft, vtop, vtopleft); 399 rp[2] = std::abs(v - vpredicted); 400 rp[3] = v - vpredicted; 401 } 402 403 offset += kExtraPropsPerChannel; 404 } 405 } 406 407 struct PredictionResult { 408 int context = 0; 409 pixel_type_w guess = 0; 410 Predictor predictor; 411 int32_t multiplier; 412 }; 413 414 inline void InitPropsRow( 415 Properties *p, 416 const std::array<pixel_type, kNumStaticProperties> &static_props, 417 const int y) { 418 for (size_t i = 0; i < kNumStaticProperties; i++) { 419 (*p)[i] = static_props[i]; 420 } 421 (*p)[2] = y; 422 (*p)[9] = 0; // local gradient. 
423 } 424 425 namespace detail { 426 enum PredictorMode { 427 kUseTree = 1, 428 kUseWP = 2, 429 kForceComputeProperties = 4, 430 kAllPredictions = 8, 431 kNoEdgeCases = 16 432 }; 433 434 JXL_INLINE pixel_type_w PredictOne(Predictor p, pixel_type_w left, 435 pixel_type_w top, pixel_type_w toptop, 436 pixel_type_w topleft, pixel_type_w topright, 437 pixel_type_w leftleft, 438 pixel_type_w toprightright, 439 pixel_type_w wp_pred) { 440 switch (p) { 441 case Predictor::Zero: 442 return pixel_type_w{0}; 443 case Predictor::Left: 444 return left; 445 case Predictor::Top: 446 return top; 447 case Predictor::Select: 448 return Select(left, top, topleft); 449 case Predictor::Weighted: 450 return wp_pred; 451 case Predictor::Gradient: 452 return pixel_type_w{ClampedGradient(left, top, topleft)}; 453 case Predictor::TopLeft: 454 return topleft; 455 case Predictor::TopRight: 456 return topright; 457 case Predictor::LeftLeft: 458 return leftleft; 459 case Predictor::Average0: 460 return (left + top) / 2; 461 case Predictor::Average1: 462 return (left + topleft) / 2; 463 case Predictor::Average2: 464 return (topleft + top) / 2; 465 case Predictor::Average3: 466 return (top + topright) / 2; 467 case Predictor::Average4: 468 return (6 * top - 2 * toptop + 7 * left + 1 * leftleft + 469 1 * toprightright + 3 * topright + 8) / 470 16; 471 default: 472 return pixel_type_w{0}; 473 } 474 } 475 476 template <int mode> 477 JXL_INLINE PredictionResult Predict( 478 Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp, 479 const intptr_t onerow, const size_t x, const size_t y, Predictor predictor, 480 const MATreeLookup *lookup, const Channel *references, 481 weighted::State *wp_state, pixel_type_w *predictions) { 482 // We start in position 3 because of 2 static properties + y. 483 size_t offset = 3; 484 constexpr bool compute_properties = 485 mode & kUseTree || mode & kForceComputeProperties; 486 constexpr bool nec = mode & kNoEdgeCases; 487 pixel_type_w left = (nec || x ? 
pp[-1] : (y ? pp[-onerow] : 0)); 488 pixel_type_w top = (nec || y ? pp[-onerow] : left); 489 pixel_type_w topleft = (nec || (x && y) ? pp[-1 - onerow] : left); 490 pixel_type_w topright = (nec || (x + 1 < w && y) ? pp[1 - onerow] : top); 491 pixel_type_w leftleft = (nec || x > 1 ? pp[-2] : left); 492 pixel_type_w toptop = (nec || y > 1 ? pp[-onerow - onerow] : top); 493 pixel_type_w toprightright = 494 (nec || (x + 2 < w && y) ? pp[2 - onerow] : topright); 495 496 if (compute_properties) { 497 // location 498 (*p)[offset++] = x; 499 // neighbors 500 (*p)[offset++] = top > 0 ? top : -top; 501 (*p)[offset++] = left > 0 ? left : -left; 502 (*p)[offset++] = top; 503 (*p)[offset++] = left; 504 505 // local gradient 506 (*p)[offset] = left - (*p)[offset + 1]; 507 offset++; 508 // local gradient 509 (*p)[offset++] = left + top - topleft; 510 511 // FFV1 context properties 512 (*p)[offset++] = left - topleft; 513 (*p)[offset++] = topleft - top; 514 (*p)[offset++] = top - topright; 515 (*p)[offset++] = top - toptop; 516 (*p)[offset++] = left - leftleft; 517 } 518 519 pixel_type_w wp_pred = 0; 520 if (mode & kUseWP) { 521 wp_pred = wp_state->Predict<compute_properties>( 522 x, y, w, top, left, topright, topleft, toptop, p, offset); 523 } 524 if (!nec && compute_properties) { 525 offset += weighted::kNumProperties; 526 // Extra properties. 
527 const pixel_type *JXL_RESTRICT rp = references->Row(x); 528 for (size_t i = 0; i < references->w; i++) { 529 (*p)[offset++] = rp[i]; 530 } 531 } 532 PredictionResult result; 533 if (mode & kUseTree) { 534 MATreeLookup::LookupResult lr = lookup->Lookup(*p); 535 result.context = lr.context; 536 result.guess = lr.offset; 537 result.multiplier = lr.multiplier; 538 predictor = lr.predictor; 539 } 540 if (mode & kAllPredictions) { 541 for (size_t i = 0; i < kNumModularPredictors; i++) { 542 predictions[i] = 543 PredictOne(static_cast<Predictor>(i), left, top, toptop, topleft, 544 topright, leftleft, toprightright, wp_pred); 545 } 546 } 547 result.guess += PredictOne(predictor, left, top, toptop, topleft, topright, 548 leftleft, toprightright, wp_pred); 549 result.predictor = predictor; 550 551 return result; 552 } 553 } // namespace detail 554 555 inline PredictionResult PredictNoTreeNoWP(size_t w, 556 const pixel_type *JXL_RESTRICT pp, 557 const intptr_t onerow, const int x, 558 const int y, Predictor predictor) { 559 return detail::Predict</*mode=*/0>( 560 /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, 561 /*references=*/nullptr, /*wp_state=*/nullptr, /*predictions=*/nullptr); 562 } 563 564 inline PredictionResult PredictNoTreeWP(size_t w, 565 const pixel_type *JXL_RESTRICT pp, 566 const intptr_t onerow, const int x, 567 const int y, Predictor predictor, 568 weighted::State *wp_state) { 569 return detail::Predict<detail::kUseWP>( 570 /*p=*/nullptr, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, 571 /*references=*/nullptr, wp_state, /*predictions=*/nullptr); 572 } 573 574 inline PredictionResult PredictTreeNoWP(Properties *p, size_t w, 575 const pixel_type *JXL_RESTRICT pp, 576 const intptr_t onerow, const int x, 577 const int y, 578 const MATreeLookup &tree_lookup, 579 const Channel &references) { 580 return detail::Predict<detail::kUseTree>( 581 p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, 582 /*wp_state=*/nullptr, 
/*predictions=*/nullptr); 583 } 584 // Only use for y > 1, x > 1, x < w-2, and empty references 585 JXL_INLINE PredictionResult 586 PredictTreeNoWPNEC(Properties *p, size_t w, const pixel_type *JXL_RESTRICT pp, 587 const intptr_t onerow, const int x, const int y, 588 const MATreeLookup &tree_lookup, const Channel &references) { 589 return detail::Predict<detail::kUseTree | detail::kNoEdgeCases>( 590 p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, 591 /*wp_state=*/nullptr, /*predictions=*/nullptr); 592 } 593 594 inline PredictionResult PredictTreeWP(Properties *p, size_t w, 595 const pixel_type *JXL_RESTRICT pp, 596 const intptr_t onerow, const int x, 597 const int y, 598 const MATreeLookup &tree_lookup, 599 const Channel &references, 600 weighted::State *wp_state) { 601 return detail::Predict<detail::kUseTree | detail::kUseWP>( 602 p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, 603 wp_state, /*predictions=*/nullptr); 604 } 605 JXL_INLINE PredictionResult PredictTreeWPNEC(Properties *p, size_t w, 606 const pixel_type *JXL_RESTRICT pp, 607 const intptr_t onerow, const int x, 608 const int y, 609 const MATreeLookup &tree_lookup, 610 const Channel &references, 611 weighted::State *wp_state) { 612 return detail::Predict<detail::kUseTree | detail::kUseWP | 613 detail::kNoEdgeCases>( 614 p, w, pp, onerow, x, y, Predictor::Zero, &tree_lookup, &references, 615 wp_state, /*predictions=*/nullptr); 616 } 617 618 inline PredictionResult PredictLearn(Properties *p, size_t w, 619 const pixel_type *JXL_RESTRICT pp, 620 const intptr_t onerow, const int x, 621 const int y, Predictor predictor, 622 const Channel &references, 623 weighted::State *wp_state) { 624 return detail::Predict<detail::kForceComputeProperties | detail::kUseWP>( 625 p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references, 626 wp_state, /*predictions=*/nullptr); 627 } 628 629 inline void PredictLearnAll(Properties *p, size_t w, 630 const pixel_type *JXL_RESTRICT 
pp, 631 const intptr_t onerow, const int x, const int y, 632 const Channel &references, 633 weighted::State *wp_state, 634 pixel_type_w *predictions) { 635 detail::Predict<detail::kForceComputeProperties | detail::kUseWP | 636 detail::kAllPredictions>( 637 p, w, pp, onerow, x, y, Predictor::Zero, 638 /*lookup=*/nullptr, &references, wp_state, predictions); 639 } 640 inline PredictionResult PredictLearnNEC(Properties *p, size_t w, 641 const pixel_type *JXL_RESTRICT pp, 642 const intptr_t onerow, const int x, 643 const int y, Predictor predictor, 644 const Channel &references, 645 weighted::State *wp_state) { 646 return detail::Predict<detail::kForceComputeProperties | detail::kUseWP | 647 detail::kNoEdgeCases>( 648 p, w, pp, onerow, x, y, predictor, /*lookup=*/nullptr, &references, 649 wp_state, /*predictions=*/nullptr); 650 } 651 652 inline void PredictLearnAllNEC(Properties *p, size_t w, 653 const pixel_type *JXL_RESTRICT pp, 654 const intptr_t onerow, const int x, const int y, 655 const Channel &references, 656 weighted::State *wp_state, 657 pixel_type_w *predictions) { 658 detail::Predict<detail::kForceComputeProperties | detail::kUseWP | 659 detail::kAllPredictions | detail::kNoEdgeCases>( 660 p, w, pp, onerow, x, y, Predictor::Zero, 661 /*lookup=*/nullptr, &references, wp_state, predictions); 662 } 663 664 inline void PredictAllNoWP(size_t w, const pixel_type *JXL_RESTRICT pp, 665 const intptr_t onerow, const int x, const int y, 666 pixel_type_w *predictions) { 667 detail::Predict<detail::kAllPredictions>( 668 /*p=*/nullptr, w, pp, onerow, x, y, Predictor::Zero, 669 /*lookup=*/nullptr, 670 /*references=*/nullptr, /*wp_state=*/nullptr, predictions); 671 } 672 } // namespace jxl 673 674 #endif // LIB_JXL_MODULAR_ENCODING_CONTEXT_PREDICT_H_