enc_ma.cc
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/modular/encoding/enc_ma.h"

#include <algorithm>
#include <limits>
#include <numeric>
#include <queue>
#include <unordered_map>
#include <unordered_set>

#include "lib/jxl/modular/encoding/ma_common.h"

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc"
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

#include "lib/jxl/base/fast_math-inl.h"
#include "lib/jxl/base/random.h"
#include "lib/jxl/enc_ans.h"
#include "lib/jxl/modular/encoding/context_predict.h"
#include "lib/jxl/modular/options.h"
#include "lib/jxl/pack_signed.h"
HWY_BEFORE_NAMESPACE();
namespace jxl {
namespace HWY_NAMESPACE {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::Eq;
using hwy::HWY_NAMESPACE::IfThenElse;
using hwy::HWY_NAMESPACE::Lt;
using hwy::HWY_NAMESPACE::Max;

const HWY_FULL(float) df;
const HWY_FULL(int32_t) di;
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }

// Compute entropy of the histogram, taking into account the minimum
// probability for symbols with non-zero counts.
float EstimateBits(const int32_t *counts, size_t num_symbols) {
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
  const auto zero = Zero(df);
  const auto minprob = Set(df, 1.0f / ANS_TAB_SIZE);
  const auto inv_total = Set(df, 1.0f / total);
  auto bits_lanes = Zero(df);
  auto total_v = Set(di, total);
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
    const auto counts_iv = LoadU(di, &counts[i]);
    const auto counts_fv = ConvertTo(df, counts_iv);
    const auto probs = Mul(counts_fv, inv_total);
    const auto mprobs = Max(probs, minprob);
    const auto nbps = IfThenElse(Eq(counts_iv, total_v), BitCast(di, zero),
                                 BitCast(di, FastLog2f(df, mprobs)));
    bits_lanes = Sub(bits_lanes, Mul(counts_fv, BitCast(df, nbps)));
  }
  return GetLane(SumOfLanes(df, bits_lanes));
}

void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred,
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
  // Note that the tree splits on *strictly greater*.
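  // For example, with splitval == 5 a sample whose property value is 6
  // follows lchild, while a value of exactly 5 follows rchild; FindBestSplit
  // below pushes the "> splitval" half of the samples to lchild accordingly.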
66 (*tree)[pos].lchild = tree->size(); 67 (*tree)[pos].rchild = tree->size() + 1; 68 (*tree)[pos].splitval = splitval; 69 (*tree)[pos].property = property; 70 tree->emplace_back(); 71 tree->back().property = -1; 72 tree->back().predictor = rpred; 73 tree->back().predictor_offset = roff; 74 tree->back().multiplier = 1; 75 tree->emplace_back(); 76 tree->back().property = -1; 77 tree->back().predictor = lpred; 78 tree->back().predictor_offset = loff; 79 tree->back().multiplier = 1; 80 } 81 82 enum class IntersectionType { kNone, kPartial, kInside }; 83 IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack, 84 uint32_t &partial_axis, uint32_t &partial_val) { 85 bool partial = false; 86 for (size_t i = 0; i < kNumStaticProperties; i++) { 87 if (haystack[i][0] >= needle[i][1]) { 88 return IntersectionType::kNone; 89 } 90 if (haystack[i][1] <= needle[i][0]) { 91 return IntersectionType::kNone; 92 } 93 if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) { 94 continue; 95 } 96 partial = true; 97 partial_axis = i; 98 if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) { 99 partial_val = haystack[i][0] - 1; 100 } else { 101 JXL_DASSERT(haystack[i][1] > needle[i][0] && 102 haystack[i][1] < needle[i][1]); 103 partial_val = haystack[i][1] - 1; 104 } 105 } 106 return partial ? IntersectionType::kPartial : IntersectionType::kInside; 107 } 108 109 void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos, 110 size_t end, size_t prop) { 111 auto cmp = [&](size_t a, size_t b) { 112 return static_cast<int32_t>(tree_samples.Property(prop, a)) - 113 static_cast<int32_t>(tree_samples.Property(prop, b)); 114 }; 115 Rng rng(0); 116 while (end > begin + 1) { 117 { 118 size_t pivot = rng.UniformU(begin, end); 119 tree_samples.Swap(begin, pivot); 120 } 121 size_t pivot_begin = begin; 122 size_t pivot_end = pivot_begin + 1; 123 for (size_t i = begin + 1; i < end; i++) { 124 JXL_DASSERT(i >= pivot_end); 125 JXL_DASSERT(pivot_end > pivot_begin); 126 int32_t cmp_result = cmp(i, pivot_begin); 127 if (cmp_result < 0) { // i < pivot, move pivot forward and put i before 128 // the pivot. 129 tree_samples.ThreeShuffle(pivot_begin, pivot_end, i); 130 pivot_begin++; 131 pivot_end++; 132 } else if (cmp_result == 0) { 133 tree_samples.Swap(pivot_end, i); 134 pivot_end++; 135 } 136 } 137 JXL_DASSERT(pivot_begin >= begin); 138 JXL_DASSERT(pivot_end > pivot_begin); 139 JXL_DASSERT(pivot_end <= end); 140 for (size_t i = begin; i < pivot_begin; i++) { 141 JXL_DASSERT(cmp(i, pivot_begin) < 0); 142 } 143 for (size_t i = pivot_end; i < end; i++) { 144 JXL_DASSERT(cmp(i, pivot_begin) > 0); 145 } 146 for (size_t i = pivot_begin; i < pivot_end; i++) { 147 JXL_DASSERT(cmp(i, pivot_begin) == 0); 148 } 149 // We now have that [begin, pivot_begin) is < pivot, [pivot_begin, 150 // pivot_end) is = pivot, and [pivot_end, end) is > pivot. 151 // If pos falls in the first or the last interval, we continue in that 152 // interval; otherwise, we are done. 
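    // (In other words, this is a quickselect with three-way partitioning:
    // each iteration narrows to the sub-range that still straddles pos, so
    // the expected total work is linear in end - begin.)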
    if (pivot_begin > pos) {
      end = pivot_begin;
    } else if (pivot_end < pos) {
      begin = pivot_end;
    } else {
      break;
    }
  }
}

void FindBestSplit(TreeSamples &tree_samples, float threshold,
                   const std::vector<ModularMultiplierInfo> &mul_info,
                   StaticPropRange initial_static_prop_range,
                   float fast_decode_multiplier, Tree *tree) {
  struct NodeInfo {
    size_t pos;
    size_t begin;
    size_t end;
    uint64_t used_properties;
    StaticPropRange static_prop_range;
  };
  std::vector<NodeInfo> nodes;
  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
                           initial_static_prop_range});

  size_t num_predictors = tree_samples.NumPredictors();
  size_t num_properties = tree_samples.NumProperties();

  // TODO(veluca): consider parallelizing the search (processing multiple nodes
  // at a time).
  while (!nodes.empty()) {
    size_t pos = nodes.back().pos;
    size_t begin = nodes.back().begin;
    size_t end = nodes.back().end;
    uint64_t used_properties = nodes.back().used_properties;
    StaticPropRange static_prop_range = nodes.back().static_prop_range;
    nodes.pop_back();
    if (begin == end) continue;

    struct SplitInfo {
      size_t prop = 0;
      uint32_t val = 0;
      size_t pos = 0;
      float lcost = std::numeric_limits<float>::max();
      float rcost = std::numeric_limits<float>::max();
      Predictor lpred = Predictor::Zero;
      Predictor rpred = Predictor::Zero;
      float Cost() { return lcost + rcost; }
    };

    SplitInfo best_split_static_constant;
    SplitInfo best_split_static;
    SplitInfo best_split_nonstatic;
    SplitInfo best_split_nowp;

    JXL_DASSERT(begin <= end);
    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());

    // Compute the maximum token in the range.
    size_t max_symbols = 0;
    for (size_t pred = 0; pred < num_predictors; pred++) {
      for (size_t i = begin; i < end; i++) {
        uint32_t tok = tree_samples.Token(pred, i);
        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
      }
    }
    max_symbols = Padded(max_symbols);
    std::vector<int32_t> counts(max_symbols * num_predictors);
    std::vector<uint32_t> tot_extra_bits(num_predictors);
    for (size_t pred = 0; pred < num_predictors; pred++) {
      for (size_t i = begin; i < end; i++) {
        counts[pred * max_symbols + tree_samples.Token(pred, i)] +=
            tree_samples.Count(i);
        tot_extra_bits[pred] +=
            tree_samples.NBits(pred, i) * tree_samples.Count(i);
      }
    }

    float base_bits;
    {
      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
      base_bits =
          EstimateBits(counts.data() + pred * max_symbols, max_symbols) +
          tot_extra_bits[pred];
    }

    SplitInfo *best = &best_split_nonstatic;

    SplitInfo forced_split;
    // The multiplier ranges cut halfway through the current ranges of static
    // properties. We do this even if the current node is not a leaf, to
    // minimize the number of nodes in the resulting tree.
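    // The intersection test below has three outcomes: kNone means this
    // multiplier does not apply to the current node; kInside means the node
    // is fully covered, so the multiplier is simply recorded; kPartial forces
    // a split at the range boundary so that all descendants end up fully
    // inside or fully outside the multiplier range.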
    for (size_t i = 0; i < mul_info.size(); i++) {
      uint32_t axis;
      uint32_t val;
      IntersectionType t =
          BoxIntersects(static_prop_range, mul_info[i].range, axis, val);
      if (t == IntersectionType::kNone) continue;
      if (t == IntersectionType::kInside) {
        (*tree)[pos].multiplier = mul_info[i].multiplier;
        break;
      }
      if (t == IntersectionType::kPartial) {
        forced_split.val = tree_samples.QuantizeProperty(axis, val);
        forced_split.prop = axis;
        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
        best = &forced_split;
        best->pos = begin;
        JXL_ASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
        for (size_t x = begin; x < end; x++) {
          if (tree_samples.Property(best->prop, x) <= best->val) {
            best->pos++;
          }
        }
        break;
      }
    }

    if (best != &forced_split) {
      std::vector<int> prop_value_used_count;
      std::vector<int> count_increase;
      std::vector<size_t> extra_bits_increase;
      // For each property, compute which of its values are used, and what
      // tokens correspond to those usages. Then, iterate through the values,
      // and compute the entropy of each side of the split (of the form `prop >
      // threshold`). Finally, find the split that minimizes the cost.
      struct CostInfo {
        float cost = std::numeric_limits<float>::max();
        float extra_cost = 0;
        float Cost() const { return cost + extra_cost; }
        Predictor pred;  // will be uninitialized in some cases, but never used.
      };
      std::vector<CostInfo> costs_l;
      std::vector<CostInfo> costs_r;

      std::vector<int32_t> counts_above(max_symbols);
      std::vector<int32_t> counts_below(max_symbols);

      // The lower the threshold, the higher the expected noisiness of the
      // estimate. Thus, discourage changing predictors.
      float change_pred_penalty = 800.0f / (100.0f + threshold);
      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
           prop++) {
        costs_l.clear();
        costs_r.clear();
        size_t prop_size = tree_samples.NumPropertyValues(prop);
        if (extra_bits_increase.size() < prop_size) {
          count_increase.resize(prop_size * max_symbols);
          extra_bits_increase.resize(prop_size);
        }
        // Clear prop_value_used_count (which cannot be cleared "on the go")
        prop_value_used_count.clear();
        prop_value_used_count.resize(prop_size);

        size_t first_used = prop_size;
        size_t last_used = 0;

        // TODO(veluca): consider finding multiple splits along a single
        // property at the same time, possibly with a bottom-up approach.
        for (size_t i = begin; i < end; i++) {
          size_t p = tree_samples.Property(prop, i);
          prop_value_used_count[p]++;
          last_used = std::max(last_used, p);
          first_used = std::min(first_used, p);
        }
        costs_l.resize(last_used - first_used);
        costs_r.resize(last_used - first_used);
        // For all predictors, compute the right and left costs of each split.
        for (size_t pred = 0; pred < num_predictors; pred++) {
          // Compute cost and histogram increments for each property value.
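          // Note: the sweep further below moves one property value at a time
          // from counts_above into counts_below, so every candidate threshold
          // is evaluated with an incremental histogram update instead of a
          // full rebuild of both histograms per threshold.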
          for (size_t i = begin; i < end; i++) {
            size_t p = tree_samples.Property(prop, i);
            size_t cnt = tree_samples.Count(i);
            size_t sym = tree_samples.Token(pred, i);
            count_increase[p * max_symbols + sym] += cnt;
            extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt;
          }
          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
                 max_symbols * sizeof counts_above[0]);
          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
          size_t extra_bits_below = 0;
          // Exclude last used: this ensures neither counts_above nor
          // counts_below is empty.
          for (size_t i = first_used; i < last_used; i++) {
            if (!prop_value_used_count[i]) continue;
            extra_bits_below += extra_bits_increase[i];
            // The increase for this property value has been used, and will not
            // be used again: clear it. Also below.
            extra_bits_increase[i] = 0;
            for (size_t sym = 0; sym < max_symbols; sym++) {
              counts_above[sym] -= count_increase[i * max_symbols + sym];
              counts_below[sym] += count_increase[i * max_symbols + sym];
              count_increase[i * max_symbols + sym] = 0;
            }
            float rcost = EstimateBits(counts_above.data(), max_symbols) +
                          tot_extra_bits[pred] - extra_bits_below;
            float lcost = EstimateBits(counts_below.data(), max_symbols) +
                          extra_bits_below;
            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
            float penalty = 0;
            // Never discourage moving away from the Weighted predictor.
            if (tree_samples.PredictorFromIndex(pred) !=
                    (*tree)[pos].predictor &&
                (*tree)[pos].predictor != Predictor::Weighted) {
              penalty = change_pred_penalty;
            }
            // If everything else is equal, disfavour Weighted (slower) and
            // favour Zero (faster if it's the only predictor used in a
            // group+channel combination)
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
              penalty += 1e-8;
            }
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
              penalty -= 1e-8;
            }
            if (rcost + penalty < costs_r[i - first_used].Cost()) {
              costs_r[i - first_used].cost = rcost;
              costs_r[i - first_used].extra_cost = penalty;
              costs_r[i - first_used].pred =
                  tree_samples.PredictorFromIndex(pred);
            }
            if (lcost + penalty < costs_l[i - first_used].Cost()) {
              costs_l[i - first_used].cost = lcost;
              costs_l[i - first_used].extra_cost = penalty;
              costs_l[i - first_used].pred =
                  tree_samples.PredictorFromIndex(pred);
            }
          }
        }
        // Iterate through the possible splits and find the one with minimum
        // sum of costs of the two sides.
        size_t split = begin;
        for (size_t i = first_used; i < last_used; i++) {
          if (!prop_value_used_count[i]) continue;
          split += prop_value_used_count[i];
          float rcost = costs_r[i - first_used].cost;
          float lcost = costs_l[i - first_used].cost;
          // WP was not used + we would use the WP property or predictor
          bool adds_wp =
              (tree_samples.PropertyFromIndex(prop) == kWPProp &&
               (used_properties & (1LU << prop)) == 0) ||
              ((costs_l[i - first_used].pred == Predictor::Weighted ||
                costs_r[i - first_used].pred == Predictor::Weighted) &&
               (*tree)[pos].predictor != Predictor::Weighted);
          bool zero_entropy_side = rcost == 0 || lcost == 0;

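          // Candidate splits are tracked in four buckets: static-property
          // splits where one side has zero entropy (it becomes a constant
          // leaf), other static-property splits, and non-static splits with
          // and without newly introducing the Weighted predictor/property.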
          SplitInfo &best =
              prop < kNumStaticProperties
                  ? (zero_entropy_side ? best_split_static_constant
                                       : best_split_static)
                  : (adds_wp ? best_split_nonstatic : best_split_nowp);
          if (lcost + rcost < best.Cost()) {
            best.prop = prop;
            best.val = i;
            best.pos = split;
            best.lcost = lcost;
            best.lpred = costs_l[i - first_used].pred;
            best.rcost = rcost;
            best.rpred = costs_r[i - first_used].pred;
          }
        }
        // Clear extra_bits_increase and cost_increase for last_used.
        extra_bits_increase[last_used] = 0;
        for (size_t sym = 0; sym < max_symbols; sym++) {
          count_increase[last_used * max_symbols + sym] = 0;
        }
      }

      // Try to avoid introducing WP.
      if (best_split_nowp.Cost() + threshold < base_bits &&
          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
        best = &best_split_nowp;
      }
      // Split along static props if possible and not significantly more
      // expensive.
      if (best_split_static.Cost() + threshold < base_bits &&
          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
        best = &best_split_static;
      }
      // Split along static props to create constant nodes if possible.
      if (best_split_static_constant.Cost() + threshold < base_bits) {
        best = &best_split_static_constant;
      }
    }

    if (best->Cost() + threshold < base_bits) {
      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
      pixel_type dequant =
          tree_samples.UnquantizeProperty(best->prop, best->val);
      // Split node and try to split children.
      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
      // "Sort" according to winning property
      SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop);
      if (p >= kNumStaticProperties) {
        used_properties |= 1LU << best->prop;
      }
      auto new_sp_range = static_prop_range;
      if (p < kNumStaticProperties) {
        JXL_ASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
        new_sp_range[p][1] = dequant + 1;
        JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
      }
      nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
                               used_properties, new_sp_range});
      new_sp_range = static_prop_range;
      if (p < kNumStaticProperties) {
        JXL_ASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
        new_sp_range[p][0] = dequant + 1;
        JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
      }
      nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
                               used_properties, new_sp_range});
    }
  }
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jxl
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace jxl {

HWY_EXPORT(FindBestSplit);  // Local function.

void ComputeBestTree(TreeSamples &tree_samples, float threshold,
                     const std::vector<ModularMultiplierInfo> &mul_info,
                     StaticPropRange static_prop_range,
                     float fast_decode_multiplier, Tree *tree) {
  // TODO(veluca): take into account that different contexts can have different
  // uint configs.
  //
  // Initialize tree.
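  // The tree starts out as a single leaf using the first configured
  // predictor; FindBestSplit then grows it by repeatedly replacing leaves
  // with decision nodes while the estimated bit savings exceed `threshold`.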
  tree->emplace_back();
  tree->back().property = -1;
  tree->back().predictor = tree_samples.PredictorFromIndex(0);
  tree->back().predictor_offset = 0;
  tree->back().multiplier = 1;
  JXL_ASSERT(tree_samples.NumProperties() < 64);

  JXL_ASSERT(tree_samples.NumDistinctSamples() <=
             std::numeric_limits<uint32_t>::max());
  HWY_DYNAMIC_DISPATCH(FindBestSplit)
  (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier,
   tree);
}

constexpr int32_t TreeSamples::kPropertyRange;
constexpr uint32_t TreeSamples::kDedupEntryUnused;

Status TreeSamples::SetPredictor(Predictor predictor,
                                 ModularOptions::TreeMode wp_tree_mode) {
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    predictors = {Predictor::Weighted};
    residuals.resize(1);
    return true;
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP &&
      predictor == Predictor::Weighted) {
    return JXL_FAILURE("Invalid predictor settings");
  }
  if (predictor == Predictor::Variable) {
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictors.push_back(static_cast<Predictor>(i));
    }
    std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]);
    std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]);
  } else if (predictor == Predictor::Best) {
    predictors = {Predictor::Weighted, Predictor::Gradient};
  } else {
    predictors = {predictor};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    auto wp_it =
        std::find(predictors.begin(), predictors.end(), Predictor::Weighted);
    if (wp_it != predictors.end()) {
      predictors.erase(wp_it);
    }
  }
  residuals.resize(predictors.size());
  return true;
}

Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties,
                                  ModularOptions::TreeMode wp_tree_mode) {
  props_to_use = properties;
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    props_to_use = {static_cast<uint32_t>(kWPProp)};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) {
    props_to_use = {static_cast<uint32_t>(kGradientProp)};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    auto it = std::find(props_to_use.begin(), props_to_use.end(), kWPProp);
    if (it != props_to_use.end()) {
      props_to_use.erase(it);
    }
  }
  if (props_to_use.empty()) {
    return JXL_FAILURE("Invalid property set configuration");
  }
  props.resize(props_to_use.size());
  return true;
}

void TreeSamples::InitTable(size_t size) {
  JXL_DASSERT((size & (size - 1)) == 0);
  if (dedup_table_.size() == size) return;
  dedup_table_.resize(size, kDedupEntryUnused);
  for (size_t i = 0; i < NumDistinctSamples(); i++) {
    if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) {
      AddToTable(i);
    }
  }
}

bool TreeSamples::AddToTableAndMerge(size_t a) {
  size_t pos1 = Hash1(a);
  size_t pos2 = Hash2(a);
  if (dedup_table_[pos1] != kDedupEntryUnused &&
      IsSameSample(a, dedup_table_[pos1])) {
    JXL_DASSERT(sample_counts[a] == 1);
    sample_counts[dedup_table_[pos1]]++;
    // Remove from hash table samples that are saturated.
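    // (Counts saturate at the uint16_t maximum: once a sample reaches 65535
    // occurrences its table entry is dropped, so the next identical sample
    // starts a fresh entry instead of overflowing this one; InitTable
    // likewise skips saturated samples when rebuilding the table.)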
    if (sample_counts[dedup_table_[pos1]] ==
        std::numeric_limits<uint16_t>::max()) {
      dedup_table_[pos1] = kDedupEntryUnused;
    }
    return true;
  }
  if (dedup_table_[pos2] != kDedupEntryUnused &&
      IsSameSample(a, dedup_table_[pos2])) {
    JXL_DASSERT(sample_counts[a] == 1);
    sample_counts[dedup_table_[pos2]]++;
    // Remove from hash table samples that are saturated.
    if (sample_counts[dedup_table_[pos2]] ==
        std::numeric_limits<uint16_t>::max()) {
      dedup_table_[pos2] = kDedupEntryUnused;
    }
    return true;
  }
  AddToTable(a);
  return false;
}

void TreeSamples::AddToTable(size_t a) {
  size_t pos1 = Hash1(a);
  size_t pos2 = Hash2(a);
  if (dedup_table_[pos1] == kDedupEntryUnused) {
    dedup_table_[pos1] = a;
  } else if (dedup_table_[pos2] == kDedupEntryUnused) {
    dedup_table_[pos2] = a;
  }
}

void TreeSamples::PrepareForSamples(size_t num_samples) {
  for (auto &res : residuals) {
    res.reserve(res.size() + num_samples);
  }
  for (auto &p : props) {
    p.reserve(p.size() + num_samples);
  }
  size_t total_num_samples = num_samples + sample_counts.size();
  size_t next_pow2 = 1LLU << CeilLog2Nonzero(total_num_samples * 3 / 2);
  InitTable(next_pow2);
}

size_t TreeSamples::Hash1(size_t a) const {
  constexpr uint64_t constant = 0x1e35a7bd;
  uint64_t h = constant;
  for (const auto &r : residuals) {
    h = h * constant + r[a].tok;
    h = h * constant + r[a].nbits;
  }
  for (const auto &p : props) {
    h = h * constant + p[a];
  }
  return (h >> 16) & (dedup_table_.size() - 1);
}
size_t TreeSamples::Hash2(size_t a) const {
  constexpr uint64_t constant = 0x1e35a7bd1e35a7bd;
  uint64_t h = constant;
  for (const auto &p : props) {
    h = h * constant ^ p[a];
  }
  for (const auto &r : residuals) {
    h = h * constant ^ r[a].tok;
    h = h * constant ^ r[a].nbits;
  }
  return (h >> 16) & (dedup_table_.size() - 1);
}

bool TreeSamples::IsSameSample(size_t a, size_t b) const {
  bool ret = true;
  for (const auto &r : residuals) {
    if (r[a].tok != r[b].tok) {
      ret = false;
    }
    if (r[a].nbits != r[b].nbits) {
      ret = false;
    }
  }
  for (const auto &p : props) {
    if (p[a] != p[b]) {
      ret = false;
    }
  }
  return ret;
}

void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
                            const pixel_type_w *predictions) {
  for (size_t i = 0; i < predictors.size(); i++) {
    pixel_type v = pixel - predictions[static_cast<int>(predictors[i])];
    uint32_t tok, nbits, bits;
    HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits);
    JXL_DASSERT(tok < 256);
    JXL_DASSERT(nbits < 256);
    residuals[i].emplace_back(
        ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)});
  }
  for (size_t i = 0; i < props_to_use.size(); i++) {
    props[i].push_back(QuantizeProperty(i, properties[props_to_use[i]]));
  }
  sample_counts.push_back(1);
  num_samples++;
  if (AddToTableAndMerge(sample_counts.size() - 1)) {
    for (auto &r : residuals) r.pop_back();
    for (auto &p : props) p.pop_back();
    sample_counts.pop_back();
  }
}

void TreeSamples::Swap(size_t a, size_t b) {
  if (a == b) return;
  for (auto &r : residuals) {
    std::swap(r[a], r[b]);
  }
  for (auto &p : props) {
    std::swap(p[a], p[b]);
  }
  std::swap(sample_counts[a], sample_counts[b]);
}

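// Cyclic rotation of three samples: the old value of a moves to b, b moves to
// c, and c moves to a. This is the primitive the three-way partition in
// SplitTreeSamples uses to move a smaller element in front of the
// pivot-equal block while shifting that block one slot to the right.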
void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) {
  if (b == c) {
    Swap(a, b);
    return;
  }

  for (auto &r : residuals) {
    auto tmp = r[a];
    r[a] = r[c];
    r[c] = r[b];
    r[b] = tmp;
  }
  for (auto &p : props) {
    auto tmp = p[a];
    p[a] = p[c];
    p[c] = p[b];
    p[b] = tmp;
  }
  auto tmp = sample_counts[a];
  sample_counts[a] = sample_counts[c];
  sample_counts[c] = sample_counts[b];
  sample_counts[b] = tmp;
}

namespace {
std::vector<int32_t> QuantizeHistogram(const std::vector<uint32_t> &histogram,
                                       size_t num_chunks) {
  if (histogram.empty()) return {};
  // TODO(veluca): selecting distinct quantiles is likely not the best
  // way to go about this.
  std::vector<int32_t> thresholds;
  uint64_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU);
  uint64_t cumsum = 0;
  uint64_t threshold = 1;
  for (size_t i = 0; i + 1 < histogram.size(); i++) {
    cumsum += histogram[i];
    if (cumsum >= threshold * sum / num_chunks) {
      thresholds.push_back(i);
      while (cumsum > threshold * sum / num_chunks) threshold++;
    }
  }
  return thresholds;
}

std::vector<int32_t> QuantizeSamples(const std::vector<int32_t> &samples,
                                     size_t num_chunks) {
  if (samples.empty()) return {};
  int min = *std::min_element(samples.begin(), samples.end());
  constexpr int kRange = 512;
  min = std::min(std::max(min, -kRange), kRange);
  std::vector<uint32_t> counts(2 * kRange + 1);
  for (int s : samples) {
    uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min;
    counts[sample_offset]++;
  }
  std::vector<int32_t> thresholds = QuantizeHistogram(counts, num_chunks);
  for (auto &v : thresholds) v += min;
  return thresholds;
}
}  // namespace

void TreeSamples::PreQuantizeProperties(
    const StaticPropRange &range,
    const std::vector<ModularMultiplierInfo> &multiplier_info,
    const std::vector<uint32_t> &group_pixel_count,
    const std::vector<uint32_t> &channel_pixel_count,
    std::vector<pixel_type> &pixel_samples,
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
  // If we have forced splits because of multipliers, choose channel and group
  // thresholds accordingly.
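  // Each multiplier range boundary that differs from the enclosing static
  // range becomes a candidate threshold; the -1 below accounts for decision
  // nodes splitting on *strictly greater*, so a boundary value B is reachable
  // as the threshold B - 1.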
  std::vector<int32_t> group_multiplier_thresholds;
  std::vector<int32_t> channel_multiplier_thresholds;
  for (const auto &v : multiplier_info) {
    if (v.range[0][0] != range[0][0]) {
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
    }
    if (v.range[0][1] != range[0][1]) {
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
    }
    if (v.range[1][0] != range[1][0]) {
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
    }
    if (v.range[1][1] != range[1][1]) {
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
    }
  }
  std::sort(channel_multiplier_thresholds.begin(),
            channel_multiplier_thresholds.end());
  channel_multiplier_thresholds.resize(
      std::unique(channel_multiplier_thresholds.begin(),
                  channel_multiplier_thresholds.end()) -
      channel_multiplier_thresholds.begin());
  std::sort(group_multiplier_thresholds.begin(),
            group_multiplier_thresholds.end());
  group_multiplier_thresholds.resize(
      std::unique(group_multiplier_thresholds.begin(),
                  group_multiplier_thresholds.end()) -
      group_multiplier_thresholds.begin());

  compact_properties.resize(props_to_use.size());
  auto quantize_channel = [&]() {
    if (!channel_multiplier_thresholds.empty()) {
      return channel_multiplier_thresholds;
    }
    return QuantizeHistogram(channel_pixel_count, max_property_values);
  };
  auto quantize_group_id = [&]() {
    if (!group_multiplier_thresholds.empty()) {
      return group_multiplier_thresholds;
    }
    return QuantizeHistogram(group_pixel_count, max_property_values);
  };
  auto quantize_coordinate = [&]() {
    std::vector<int32_t> quantized;
    quantized.reserve(max_property_values - 1);
    for (size_t i = 0; i + 1 < max_property_values; i++) {
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
    }
    return quantized;
  };
  std::vector<int32_t> abs_pixel_thr;
  std::vector<int32_t> pixel_thr;
  auto quantize_pixel_property = [&]() {
    if (pixel_thr.empty()) {
      pixel_thr = QuantizeSamples(pixel_samples, max_property_values);
    }
    return pixel_thr;
  };
  auto quantize_abs_pixel_property = [&]() {
    if (abs_pixel_thr.empty()) {
      quantize_pixel_property();  // Compute the non-abs thresholds.
      for (auto &v : pixel_samples) v = std::abs(v);
      abs_pixel_thr = QuantizeSamples(pixel_samples, max_property_values);
    }
    return abs_pixel_thr;
  };
  std::vector<int32_t> abs_diff_thr;
  std::vector<int32_t> diff_thr;
  auto quantize_diff_property = [&]() {
    if (diff_thr.empty()) {
      diff_thr = QuantizeSamples(diff_samples, max_property_values);
    }
    return diff_thr;
  };
  auto quantize_abs_diff_property = [&]() {
    if (abs_diff_thr.empty()) {
      quantize_diff_property();  // Compute the non-abs thresholds.
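      // Taking absolute values in place is safe here: the signed thresholds
      // have already been computed and cached in diff_thr by the call above.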
      for (auto &v : diff_samples) v = std::abs(v);
      abs_diff_thr = QuantizeSamples(diff_samples, max_property_values);
    }
    return abs_diff_thr;
  };
  auto quantize_wp = [&]() {
    if (max_property_values < 32) {
      return std::vector<int32_t>{-127, -63, -31, -15, -7, -3, -1, 0,
                                  1,    3,   7,   15,  31, 63, 127};
    }
    if (max_property_values < 64) {
      return std::vector<int32_t>{-255, -191, -127, -95, -63, -47, -31, -23,
                                  -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
                                  3,    5,    7,    11,  15,  23,  31,  47,
                                  63,   95,   127,  191, 255};
    }
    return std::vector<int32_t>{
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
  };

  property_mapping.resize(props_to_use.size());
  for (size_t i = 0; i < props_to_use.size(); i++) {
    if (props_to_use[i] == 0) {
      compact_properties[i] = quantize_channel();
    } else if (props_to_use[i] == 1) {
      compact_properties[i] = quantize_group_id();
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
      compact_properties[i] = quantize_coordinate();
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
               props_to_use[i] == 8 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
      compact_properties[i] = quantize_pixel_property();
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
      compact_properties[i] = quantize_abs_pixel_property();
    } else if (props_to_use[i] >= kNumNonrefProperties &&
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
      compact_properties[i] = quantize_abs_diff_property();
    } else if (props_to_use[i] == kWPProp) {
      compact_properties[i] = quantize_wp();
    } else {
      compact_properties[i] = quantize_diff_property();
    }
    property_mapping[i].resize(kPropertyRange * 2 + 1);
    size_t mapped = 0;
    for (size_t j = 0; j < property_mapping[i].size(); j++) {
      while (mapped < compact_properties[i].size() &&
             static_cast<int>(j) - kPropertyRange >
                 compact_properties[i][mapped]) {
        mapped++;
      }
      // property_mapping[i] of a value V is `mapped` if
      // compact_properties[i][mapped] <= j and
      // compact_properties[i][mapped-1] > j
      // This is because the decision node in the tree splits on
      // (property) > j, hence everything that is not > of a threshold
      // should be clustered together.
      property_mapping[i][j] = mapped;
    }
  }
}

void CollectPixelSamples(const Image &image, const ModularOptions &options,
                         size_t group_id,
                         std::vector<uint32_t> &group_pixel_count,
                         std::vector<uint32_t> &channel_pixel_count,
                         std::vector<pixel_type> &pixel_samples,
                         std::vector<pixel_type> &diff_samples) {
  if (options.nb_repeats == 0) return;
  if (group_pixel_count.size() <= group_id) {
    group_pixel_count.resize(group_id + 1);
  }
  if (channel_pixel_count.size() < image.channel.size()) {
    channel_pixel_count.resize(image.channel.size());
  }
  Rng rng(group_id);
  // Sample 10% of the final number of samples for property quantization.
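  // The sampling loop below draws skip distances from a geometric
  // distribution rather than visiting every pixel, giving an expected
  // sampling rate of roughly `fraction` in a single pass over the channels.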
  float fraction = std::min(options.nb_repeats * 0.1, 0.99);
  Rng::GeometricDistribution dist = Rng::MakeGeometric(fraction);
  size_t total_pixels = 0;
  std::vector<size_t> channel_ids;
  for (size_t i = 0; i < image.channel.size(); i++) {
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
      continue;  // skip empty or width-1 channels.
    }
    if (i >= image.nb_meta_channels &&
        (image.channel[i].w > options.max_chan_size ||
         image.channel[i].h > options.max_chan_size)) {
      break;
    }
    channel_ids.push_back(i);
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
    total_pixels += image.channel[i].w * image.channel[i].h;
  }
  if (channel_ids.empty()) return;
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
  size_t i = 0;
  size_t y = 0;
  size_t x = 0;
  auto advance = [&](size_t amount) {
    x += amount;
    // Detect row overflow (rare).
    while (x >= image.channel[channel_ids[i]].w) {
      x -= image.channel[channel_ids[i]].w;
      y++;
      // Detect end-of-channel (even rarer).
      if (y == image.channel[channel_ids[i]].h) {
        i++;
        y = 0;
        if (i >= channel_ids.size()) {
          return;
        }
      }
    }
  };
  advance(rng.Geometric(dist));
  for (; i < channel_ids.size(); advance(rng.Geometric(dist) + 1)) {
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
    pixel_samples.push_back(row[x]);
    size_t xp = x == 0 ? 1 : x - 1;
    diff_samples.push_back(static_cast<int64_t>(row[x]) - row[xp]);
  }
}

// TODO(veluca): very simple encoding scheme. This should be improved.
void TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
                  Tree *decoder_tree) {
  JXL_ASSERT(tree.size() <= kMaxTreeSize);
  std::queue<int> q;
  q.push(0);
  size_t leaf_id = 0;
  decoder_tree->clear();
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    JXL_ASSERT(tree[cur].property >= -1);
    tokens->emplace_back(kPropertyContext, tree[cur].property + 1);
    if (tree[cur].property == -1) {
      tokens->emplace_back(kPredictorContext,
                           static_cast<int>(tree[cur].predictor));
      tokens->emplace_back(kOffsetContext,
                           PackSigned(tree[cur].predictor_offset));
      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier);
      uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1;
      tokens->emplace_back(kMultiplierLogContext, mul_log);
      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
      JXL_ASSERT(tree[cur].predictor < Predictor::Best);
      decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor,
                                 tree[cur].predictor_offset,
                                 tree[cur].multiplier);
      continue;
    }
    decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval,
                               decoder_tree->size() + q.size() + 1,
                               decoder_tree->size() + q.size() + 2,
                               Predictor::Zero, 0, 1);
    q.push(tree[cur].lchild);
    q.push(tree[cur].rchild);
    tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval));
  }
}

}  // namespace jxl
#endif  // HWY_ONCE