enc_adaptive_quantization.cc (48877B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/enc_adaptive_quantization.h" 7 8 #include <stddef.h> 9 #include <stdlib.h> 10 11 #include <algorithm> 12 #include <atomic> 13 #include <cmath> 14 #include <string> 15 #include <vector> 16 17 #undef HWY_TARGET_INCLUDE 18 #define HWY_TARGET_INCLUDE "lib/jxl/enc_adaptive_quantization.cc" 19 #include <hwy/foreach_target.h> 20 #include <hwy/highway.h> 21 22 #include "lib/jxl/ac_strategy.h" 23 #include "lib/jxl/base/common.h" 24 #include "lib/jxl/base/compiler_specific.h" 25 #include "lib/jxl/base/data_parallel.h" 26 #include "lib/jxl/base/fast_math-inl.h" 27 #include "lib/jxl/base/status.h" 28 #include "lib/jxl/butteraugli/butteraugli.h" 29 #include "lib/jxl/cms/opsin_params.h" 30 #include "lib/jxl/convolve.h" 31 #include "lib/jxl/dec_cache.h" 32 #include "lib/jxl/dec_group.h" 33 #include "lib/jxl/enc_aux_out.h" 34 #include "lib/jxl/enc_butteraugli_comparator.h" 35 #include "lib/jxl/enc_cache.h" 36 #include "lib/jxl/enc_debug_image.h" 37 #include "lib/jxl/enc_group.h" 38 #include "lib/jxl/enc_modular.h" 39 #include "lib/jxl/enc_params.h" 40 #include "lib/jxl/enc_transforms-inl.h" 41 #include "lib/jxl/epf.h" 42 #include "lib/jxl/frame_dimensions.h" 43 #include "lib/jxl/image.h" 44 #include "lib/jxl/image_bundle.h" 45 #include "lib/jxl/image_ops.h" 46 #include "lib/jxl/quant_weights.h" 47 48 // Set JXL_DEBUG_ADAPTIVE_QUANTIZATION to 1 to enable debugging. 49 #ifndef JXL_DEBUG_ADAPTIVE_QUANTIZATION 50 #define JXL_DEBUG_ADAPTIVE_QUANTIZATION 0 51 #endif 52 53 HWY_BEFORE_NAMESPACE(); 54 namespace jxl { 55 namespace HWY_NAMESPACE { 56 namespace { 57 58 // These templates are not found via ADL. 
59 using hwy::HWY_NAMESPACE::AbsDiff; 60 using hwy::HWY_NAMESPACE::Add; 61 using hwy::HWY_NAMESPACE::And; 62 using hwy::HWY_NAMESPACE::Max; 63 using hwy::HWY_NAMESPACE::Rebind; 64 using hwy::HWY_NAMESPACE::Sqrt; 65 using hwy::HWY_NAMESPACE::ZeroIfNegative; 66 67 // The following functions modulate an exponent (out_val) and return the updated 68 // value. Their descriptor is limited to 8 lanes for 8x8 blocks. 69 70 // Hack for mask estimation. Eventually replace this code with butteraugli's 71 // masking. 72 float ComputeMaskForAcStrategyUse(const float out_val) { 73 const float kMul = 1.0f; 74 const float kOffset = 0.001f; 75 return kMul / (out_val + kOffset); 76 } 77 78 template <class D, class V> 79 V ComputeMask(const D d, const V out_val) { 80 const auto kBase = Set(d, -0.7647f); 81 const auto kMul4 = Set(d, 9.4708735624378946f); 82 const auto kMul2 = Set(d, 17.35036561631863f); 83 const auto kOffset2 = Set(d, 302.59587815579727f); 84 const auto kMul3 = Set(d, 6.7943250517376494f); 85 const auto kOffset3 = Set(d, 3.7179635626140772f); 86 const auto kOffset4 = Mul(Set(d, 0.25f), kOffset3); 87 const auto kMul0 = Set(d, 0.80061762862741759f); 88 const auto k1 = Set(d, 1.0f); 89 90 // Avoid division by zero. 91 const auto v1 = Max(Mul(out_val, kMul0), Set(d, 1e-3f)); 92 const auto v2 = Div(k1, Add(v1, kOffset2)); 93 const auto v3 = Div(k1, MulAdd(v1, v1, kOffset3)); 94 const auto v4 = Div(k1, MulAdd(v1, v1, kOffset4)); 95 // TODO(jyrki): 96 // A log or two here could make sense. In butteraugli we have effectively 97 // log(log(x + C)) for this kind of use, as a single log is used in 98 // saturating visual masking and here the modulation values are exponential, 99 // another log would counter that. 100 return Add(kBase, MulAdd(kMul4, v4, MulAdd(kMul2, v2, Mul(kMul3, v3)))); 101 } 102 103 // mul and mul2 represent a scaling difference between jxl and butteraugli. 
104 const float kSGmul = 226.77216153508914f; 105 const float kSGmul2 = 1.0f / 73.377132366608819f; 106 const float kLog2 = 0.693147181f; 107 // Includes correction factor for std::log -> log2. 108 const float kSGRetMul = kSGmul2 * 18.6580932135f * kLog2; 109 const float kSGVOffset = 7.7825991679894591f; 110 111 template <bool invert, typename D, typename V> 112 V RatioOfDerivativesOfCubicRootToSimpleGamma(const D d, V v) { 113 // The opsin space in jxl is the cubic root of photons, i.e., v * v * v 114 // is related to the number of photons. 115 // 116 // SimpleGamma(v * v * v) is the psychovisual space in butteraugli. 117 // This ratio allows quantization to move from jxl's opsin space to 118 // butteraugli's log-gamma space. 119 float kEpsilon = 1e-2; 120 v = ZeroIfNegative(v); 121 const auto kNumMul = Set(d, kSGRetMul * 3 * kSGmul); 122 const auto kVOffset = Set(d, kSGVOffset * kLog2 + kEpsilon); 123 const auto kDenMul = Set(d, kLog2 * kSGmul); 124 125 const auto v2 = Mul(v, v); 126 127 const auto num = MulAdd(kNumMul, v2, Set(d, kEpsilon)); 128 const auto den = MulAdd(Mul(kDenMul, v), v2, kVOffset); 129 return invert ? Div(num, den) : Div(den, num); 130 } 131 132 template <bool invert = false> 133 float RatioOfDerivativesOfCubicRootToSimpleGamma(float v) { 134 using DScalar = HWY_CAPPED(float, 1); 135 auto vscalar = Load(DScalar(), &v); 136 return GetLane( 137 RatioOfDerivativesOfCubicRootToSimpleGamma<invert>(DScalar(), vscalar)); 138 } 139 140 // TODO(veluca): this function computes an approximation of the derivative of 141 // SimpleGamma with (f(x+eps)-f(x))/eps. Consider two-sided approximation or 142 // exact derivatives. For reference, SimpleGamma was: 143 /* 144 template <typename D, typename V> 145 V SimpleGamma(const D d, V v) { 146 // A simple HDR compatible gamma function. 
147 const auto mul = Set(d, kSGmul); 148 const auto kRetMul = Set(d, kSGRetMul); 149 const auto kRetAdd = Set(d, kSGmul2 * -20.2789020414f); 150 const auto kVOffset = Set(d, kSGVOffset); 151 152 v *= mul; 153 154 // This should happen rarely, but may lead to a NaN, which is rather 155 // undesirable. Since negative photons don't exist we solve the NaNs by 156 // clamping here. 157 // TODO(veluca): with FastLog2f, this no longer leads to NaNs. 158 v = ZeroIfNegative(v); 159 return kRetMul * FastLog2f(d, v + kVOffset) + kRetAdd; 160 } 161 */ 162 163 template <class D, class V> 164 V GammaModulation(const D d, const size_t x, const size_t y, 165 const ImageF& xyb_x, const ImageF& xyb_y, const Rect& rect, 166 const V out_val) { 167 const float kBias = 0.16f; 168 JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[0]); 169 JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[1]); 170 JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[2]); 171 auto overall_ratio = Zero(d); 172 auto bias = Set(d, kBias); 173 auto half = Set(d, 0.5f); 174 for (size_t dy = 0; dy < 8; ++dy) { 175 const float* const JXL_RESTRICT row_in_x = rect.ConstRow(xyb_x, y + dy); 176 const float* const JXL_RESTRICT row_in_y = rect.ConstRow(xyb_y, y + dy); 177 for (size_t dx = 0; dx < 8; dx += Lanes(d)) { 178 const auto iny = Add(Load(d, row_in_y + x + dx), bias); 179 const auto inx = Load(d, row_in_x + x + dx); 180 const auto r = Sub(iny, inx); 181 const auto g = Add(iny, inx); 182 const auto ratio_r = 183 RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, r); 184 const auto ratio_g = 185 RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, g); 186 const auto avg_ratio = Mul(half, Add(ratio_r, ratio_g)); 187 188 overall_ratio = Add(overall_ratio, avg_ratio); 189 } 190 } 191 overall_ratio = Mul(SumOfLanes(d, overall_ratio), Set(d, 1.0f / 64)); 192 // ideally -1.0, but likely optimal correction adds some entropy, so slightly 193 // less than that. 
194 // ln(2) constant folded in because we want std::log but have FastLog2f. 195 static const float v = 0.14507933746197058f; 196 const auto kGam = Set(d, v * 0.693147180559945f); 197 return MulAdd(kGam, FastLog2f(d, overall_ratio), out_val); 198 } 199 200 // Change precision in 8x8 blocks that have high frequency content. 201 template <class D, class V> 202 V HfModulation(const D d, const size_t x, const size_t y, const ImageF& xyb, 203 const Rect& rect, const V out_val) { 204 // Zero out the invalid differences for the rightmost value per row. 205 const Rebind<uint32_t, D> du; 206 HWY_ALIGN constexpr uint32_t kMaskRight[kBlockDim] = {~0u, ~0u, ~0u, ~0u, 207 ~0u, ~0u, ~0u, 0}; 208 209 auto sum = Zero(d); // sum of absolute differences with right and below 210 211 static const float valmin = 0.020602694503245016f; 212 auto valminv = Set(d, valmin); 213 for (size_t dy = 0; dy < 8; ++dy) { 214 const float* JXL_RESTRICT row_in = rect.ConstRow(xyb, y + dy) + x; 215 const float* JXL_RESTRICT row_in_next = 216 dy == 7 ? row_in : rect.ConstRow(xyb, y + dy + 1) + x; 217 218 // In SCALAR, there is no guarantee of having extra row padding. 219 // Hence, we need to ensure we don't access pixels outside the row itself. 220 // In SIMD modes, however, rows are padded, so it's safe to access one 221 // garbage value after the row. The vector then gets masked with kMaskRight 222 // to remove the influence of that value. 
223 #if HWY_TARGET != HWY_SCALAR 224 for (size_t dx = 0; dx < 8; dx += Lanes(d)) { 225 #else 226 for (size_t dx = 0; dx < 7; dx += Lanes(d)) { 227 #endif 228 const auto p = Load(d, row_in + dx); 229 const auto pr = LoadU(d, row_in + dx + 1); 230 const auto mask = BitCast(d, Load(du, kMaskRight + dx)); 231 sum = Add(sum, And(mask, Min(valminv, AbsDiff(p, pr)))); 232 233 const auto pd = Load(d, row_in_next + dx); 234 sum = Add(sum, Min(valminv, AbsDiff(p, pd))); 235 } 236 #if HWY_TARGET == HWY_SCALAR 237 const auto p = Load(d, row_in + 7); 238 const auto pd = Load(d, row_in_next + 7); 239 sum = Add(sum, Min(valminv, AbsDiff(p, pd))); 240 #endif 241 } 242 // more negative value gives more bpp 243 static const float kOffset = -1.110929106987477; 244 static const float kMul = -0.38078920620238305; 245 sum = SumOfLanes(d, sum); 246 float scalar_sum = GetLane(sum); 247 scalar_sum += kOffset; 248 scalar_sum *= kMul; 249 return Add(Set(d, scalar_sum), out_val); 250 } 251 252 void PerBlockModulations(const float butteraugli_target, const ImageF& xyb_x, 253 const ImageF& xyb_y, const ImageF& xyb_b, 254 const Rect& rect_in, const float scale, 255 const Rect& rect_out, ImageF* out) { 256 float base_level = 0.48f * scale; 257 float kDampenRampStart = 2.0f; 258 float kDampenRampEnd = 14.0f; 259 float dampen = 1.0f; 260 if (butteraugli_target >= kDampenRampStart) { 261 dampen = 1.0f - ((butteraugli_target - kDampenRampStart) / 262 (kDampenRampEnd - kDampenRampStart)); 263 if (dampen < 0) { 264 dampen = 0; 265 } 266 } 267 const float mul = scale * dampen; 268 const float add = (1.0f - dampen) * base_level; 269 for (size_t iy = rect_out.y0(); iy < rect_out.y1(); iy++) { 270 const size_t y = iy * 8; 271 float* const JXL_RESTRICT row_out = out->Row(iy); 272 const HWY_CAPPED(float, kBlockDim) df; 273 for (size_t ix = rect_out.x0(); ix < rect_out.x1(); ix++) { 274 size_t x = ix * 8; 275 auto out_val = Set(df, row_out[ix]); 276 out_val = ComputeMask(df, out_val); 277 out_val = 
HfModulation(df, x, y, xyb_y, rect_in, out_val); 278 out_val = GammaModulation(df, x, y, xyb_x, xyb_y, rect_in, out_val); 279 // We want multiplicative quantization field, so everything 280 // until this point has been modulating the exponent. 281 row_out[ix] = FastPow2f(GetLane(out_val) * 1.442695041f) * mul + add; 282 } 283 } 284 } 285 286 template <typename D, typename V> 287 V MaskingSqrt(const D d, V v) { 288 static const float kLogOffset = 27.97044946785558f; 289 static const float kMul = 211.53333281566171f; 290 const auto mul_v = Set(d, kMul * 1e8); 291 const auto offset_v = Set(d, kLogOffset); 292 return Mul(Set(d, 0.25f), Sqrt(MulAdd(v, Sqrt(mul_v), offset_v))); 293 } 294 295 float MaskingSqrt(const float v) { 296 using DScalar = HWY_CAPPED(float, 1); 297 auto vscalar = Load(DScalar(), &v); 298 return GetLane(MaskingSqrt(DScalar(), vscalar)); 299 } 300 301 void StoreMin4(const float v, float& min0, float& min1, float& min2, 302 float& min3) { 303 if (v < min3) { 304 if (v < min0) { 305 min3 = min2; 306 min2 = min1; 307 min1 = min0; 308 min0 = v; 309 } else if (v < min1) { 310 min3 = min2; 311 min2 = min1; 312 min1 = v; 313 } else if (v < min2) { 314 min3 = min2; 315 min2 = v; 316 } else { 317 min3 = v; 318 } 319 } 320 } 321 322 // Look for smooth areas near the area of degradation. 323 // If the areas are generally smooth, don't do masking. 324 // Output is downsampled 2x. 
325 void FuzzyErosion(const float butteraugli_target, const Rect& from_rect, 326 const ImageF& from, const Rect& to_rect, ImageF* to) { 327 const size_t xsize = from.xsize(); 328 const size_t ysize = from.ysize(); 329 constexpr int kStep = 1; 330 static_assert(kStep == 1, "Step must be 1"); 331 JXL_ASSERT(to_rect.xsize() * 2 == from_rect.xsize()); 332 JXL_ASSERT(to_rect.ysize() * 2 == from_rect.ysize()); 333 static const float kMulBase0 = 0.125; 334 static const float kMulBase1 = 0.10; 335 static const float kMulBase2 = 0.09; 336 static const float kMulBase3 = 0.06; 337 static const float kMulAdd0 = 0.0; 338 static const float kMulAdd1 = -0.10; 339 static const float kMulAdd2 = -0.09; 340 static const float kMulAdd3 = -0.06; 341 342 float mul = 0.0; 343 if (butteraugli_target < 2.0f) { 344 mul = (2.0f - butteraugli_target) * (1.0f / 2.0f); 345 } 346 float kMul0 = kMulBase0 + mul * kMulAdd0; 347 float kMul1 = kMulBase1 + mul * kMulAdd1; 348 float kMul2 = kMulBase2 + mul * kMulAdd2; 349 float kMul3 = kMulBase3 + mul * kMulAdd3; 350 static const float kTotal = 0.29959705784054957; 351 float norm = kTotal / (kMul0 + kMul1 + kMul2 + kMul3); 352 kMul0 *= norm; 353 kMul1 *= norm; 354 kMul2 *= norm; 355 kMul3 *= norm; 356 357 for (size_t fy = 0; fy < from_rect.ysize(); ++fy) { 358 size_t y = fy + from_rect.y0(); 359 size_t ym1 = y >= kStep ? y - kStep : y; 360 size_t yp1 = y + kStep < ysize ? y + kStep : y; 361 const float* rowt = from.Row(ym1); 362 const float* row = from.Row(y); 363 const float* rowb = from.Row(yp1); 364 float* row_out = to_rect.Row(to, fy / 2); 365 for (size_t fx = 0; fx < from_rect.xsize(); ++fx) { 366 size_t x = fx + from_rect.x0(); 367 size_t xm1 = x >= kStep ? x - kStep : x; 368 size_t xp1 = x + kStep < xsize ? x + kStep : x; 369 float min0 = row[x]; 370 float min1 = row[xm1]; 371 float min2 = row[xp1]; 372 float min3 = rowt[xm1]; 373 // Sort the first four values. 
374 if (min0 > min1) std::swap(min0, min1); 375 if (min0 > min2) std::swap(min0, min2); 376 if (min0 > min3) std::swap(min0, min3); 377 if (min1 > min2) std::swap(min1, min2); 378 if (min1 > min3) std::swap(min1, min3); 379 if (min2 > min3) std::swap(min2, min3); 380 // The remaining five values of a 3x3 neighbourhood. 381 StoreMin4(rowt[x], min0, min1, min2, min3); 382 StoreMin4(rowt[xp1], min0, min1, min2, min3); 383 StoreMin4(rowb[xm1], min0, min1, min2, min3); 384 StoreMin4(rowb[x], min0, min1, min2, min3); 385 StoreMin4(rowb[xp1], min0, min1, min2, min3); 386 387 float v = kMul0 * min0 + kMul1 * min1 + kMul2 * min2 + kMul3 * min3; 388 if (fx % 2 == 0 && fy % 2 == 0) { 389 row_out[fx / 2] = v; 390 } else { 391 row_out[fx / 2] += v; 392 } 393 } 394 } 395 } 396 397 struct AdaptiveQuantizationImpl { 398 Status PrepareBuffers(size_t num_threads) { 399 JXL_ASSIGN_OR_RETURN(diff_buffer, 400 ImageF::Create(kEncTileDim + 8, num_threads)); 401 for (size_t i = pre_erosion.size(); i < num_threads; i++) { 402 JXL_ASSIGN_OR_RETURN(ImageF tmp, 403 ImageF::Create(kEncTileDimInBlocks * 2 + 2, 404 kEncTileDimInBlocks * 2 + 2)); 405 pre_erosion.emplace_back(std::move(tmp)); 406 } 407 return true; 408 } 409 410 void ComputeTile(float butteraugli_target, float scale, const Image3F& xyb, 411 const Rect& rect_in, const Rect& rect_out, const int thread, 412 ImageF* mask, ImageF* mask1x1) { 413 JXL_ASSERT(rect_in.x0() % 8 == 0); 414 JXL_ASSERT(rect_in.y0() % 8 == 0); 415 const size_t xsize = xyb.xsize(); 416 const size_t ysize = xyb.ysize(); 417 418 // The XYB gamma is 3.0 to be able to decode faster with two muls. 419 // Butteraugli's gamma is matching the gamma of human eye, around 2.6. 420 // We approximate the gamma difference by adding one cubic root into 421 // the adaptive quantization. This gives us a total gamma of 2.6666 422 // for quantization uses. 
423 const float match_gamma_offset = 0.019; 424 425 const HWY_FULL(float) df; 426 427 size_t y_start_1x1 = rect_in.y0() + rect_out.y0() * 8; 428 size_t y_end_1x1 = y_start_1x1 + rect_out.ysize() * 8; 429 430 size_t x_start_1x1 = rect_in.x0() + rect_out.x0() * 8; 431 size_t x_end_1x1 = x_start_1x1 + rect_out.xsize() * 8; 432 433 if (rect_in.x0() != 0 && rect_out.x0() == 0) x_start_1x1 -= 2; 434 if (rect_in.x1() < xsize && rect_out.x1() * 8 == rect_in.xsize()) { 435 x_end_1x1 += 2; 436 } 437 if (rect_in.y0() != 0 && rect_out.y0() == 0) y_start_1x1 -= 2; 438 if (rect_in.y1() < ysize && rect_out.y1() * 8 == rect_in.ysize()) { 439 y_end_1x1 += 2; 440 } 441 442 // Computes image (padded to multiple of 8x8) of local pixel differences. 443 // Subsample both directions by 4. 444 // 1x1 Laplacian of intensity. 445 for (size_t y = y_start_1x1; y < y_end_1x1; ++y) { 446 const size_t y2 = y + 1 < ysize ? y + 1 : y; 447 const size_t y1 = y > 0 ? y - 1 : y; 448 const float* row_in = xyb.ConstPlaneRow(1, y); 449 const float* row_in1 = xyb.ConstPlaneRow(1, y1); 450 const float* row_in2 = xyb.ConstPlaneRow(1, y2); 451 float* mask1x1_out = mask1x1->Row(y); 452 auto scalar_pixel1x1 = [&](size_t x) { 453 const size_t x2 = x + 1 < xsize ? x + 1 : x; 454 const size_t x1 = x > 0 ? 
x - 1 : x; 455 const float base = 456 0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]); 457 const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma( 458 row_in[x] + match_gamma_offset); 459 float diff = fabs(gammac * (row_in[x] - base)); 460 static const double kScaler = 1.0; 461 diff *= kScaler; 462 diff = log1p(diff); 463 static const float kMul = 1.0; 464 static const float kOffset = 0.01; 465 mask1x1_out[x] = kMul / (diff + kOffset); 466 }; 467 for (size_t x = x_start_1x1; x < x_end_1x1; ++x) { 468 scalar_pixel1x1(x); 469 } 470 } 471 472 size_t y_start = rect_in.y0() + rect_out.y0() * 8; 473 size_t y_end = y_start + rect_out.ysize() * 8; 474 475 size_t x_start = rect_in.x0() + rect_out.x0() * 8; 476 size_t x_end = x_start + rect_out.xsize() * 8; 477 478 if (x_start != 0) x_start -= 4; 479 if (x_end != xsize) x_end += 4; 480 if (y_start != 0) y_start -= 4; 481 if (y_end != ysize) y_end += 4; 482 pre_erosion[thread].ShrinkTo((x_end - x_start) / 4, (y_end - y_start) / 4); 483 484 static const float limit = 0.2f; 485 for (size_t y = y_start; y < y_end; ++y) { 486 size_t y2 = y + 1 < ysize ? y + 1 : y; 487 size_t y1 = y > 0 ? y - 1 : y; 488 489 const float* row_in = xyb.ConstPlaneRow(1, y); 490 const float* row_in1 = xyb.ConstPlaneRow(1, y1); 491 const float* row_in2 = xyb.ConstPlaneRow(1, y2); 492 float* JXL_RESTRICT row_out = diff_buffer.Row(thread); 493 494 auto scalar_pixel = [&](size_t x) { 495 const size_t x2 = x + 1 < xsize ? x + 1 : x; 496 const size_t x1 = x > 0 ? 
x - 1 : x; 497 const float base = 498 0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]); 499 const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma( 500 row_in[x] + match_gamma_offset); 501 float diff = gammac * (row_in[x] - base); 502 diff *= diff; 503 if (diff >= limit) { 504 diff = limit; 505 } 506 diff = MaskingSqrt(diff); 507 if ((y % 4) != 0) { 508 row_out[x - x_start] += diff; 509 } else { 510 row_out[x - x_start] = diff; 511 } 512 }; 513 514 size_t x = x_start; 515 // First pixel of the row. 516 if (x_start == 0) { 517 scalar_pixel(x_start); 518 ++x; 519 } 520 // SIMD 521 const auto match_gamma_offset_v = Set(df, match_gamma_offset); 522 const auto quarter = Set(df, 0.25f); 523 for (; x + 1 + Lanes(df) < x_end; x += Lanes(df)) { 524 const auto in = LoadU(df, row_in + x); 525 const auto in_r = LoadU(df, row_in + x + 1); 526 const auto in_l = LoadU(df, row_in + x - 1); 527 const auto in_t = LoadU(df, row_in2 + x); 528 const auto in_b = LoadU(df, row_in1 + x); 529 auto base = Mul(quarter, Add(Add(in_r, in_l), Add(in_t, in_b))); 530 auto gammacv = 531 RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/false>( 532 df, Add(in, match_gamma_offset_v)); 533 auto diff = Mul(gammacv, Sub(in, base)); 534 diff = Mul(diff, diff); 535 diff = Min(diff, Set(df, limit)); 536 diff = MaskingSqrt(df, diff); 537 if ((y & 3) != 0) { 538 diff = Add(diff, LoadU(df, row_out + x - x_start)); 539 } 540 StoreU(diff, df, row_out + x - x_start); 541 } 542 // Scalar 543 for (; x < x_end; ++x) { 544 scalar_pixel(x); 545 } 546 if (y % 4 == 3) { 547 float* row_dout = pre_erosion[thread].Row((y - y_start) / 4); 548 for (size_t x = 0; x < (x_end - x_start) / 4; x++) { 549 row_dout[x] = (row_out[x * 4] + row_out[x * 4 + 1] + 550 row_out[x * 4 + 2] + row_out[x * 4 + 3]) * 551 0.25f; 552 } 553 } 554 } 555 Rect from_rect(x_start % 8 == 0 ? 0 : 1, y_start % 8 == 0 ? 
0 : 1, 556 rect_out.xsize() * 2, rect_out.ysize() * 2); 557 FuzzyErosion(butteraugli_target, from_rect, pre_erosion[thread], rect_out, 558 &aq_map); 559 for (size_t y = 0; y < rect_out.ysize(); ++y) { 560 const float* aq_map_row = rect_out.ConstRow(aq_map, y); 561 float* mask_row = rect_out.Row(mask, y); 562 for (size_t x = 0; x < rect_out.xsize(); ++x) { 563 mask_row[x] = ComputeMaskForAcStrategyUse(aq_map_row[x]); 564 } 565 } 566 PerBlockModulations(butteraugli_target, xyb.Plane(0), xyb.Plane(1), 567 xyb.Plane(2), rect_in, scale, rect_out, &aq_map); 568 } 569 std::vector<ImageF> pre_erosion; 570 ImageF aq_map; 571 ImageF diff_buffer; 572 }; 573 574 Status Blur1x1Masking(ThreadPool* pool, ImageF* mask1x1, const Rect& rect) { 575 // Blur the mask1x1 to obtain the masking image. 576 // Before blurring it contains an image of absolute value of the 577 // Laplacian of the intensity channel. 578 static const float kFilterMask1x1[5] = { 579 static_cast<float>(0.25647067633737227), 580 static_cast<float>(0.2050056912354399075), 581 static_cast<float>(0.154082048668497307), 582 static_cast<float>(0.08149576591362004441), 583 static_cast<float>(0.0512750104812308467), 584 }; 585 double sum = 586 1.0 + 4 * (kFilterMask1x1[0] + kFilterMask1x1[1] + kFilterMask1x1[2] + 587 kFilterMask1x1[4] + 2 * kFilterMask1x1[3]); 588 if (sum < 1e-5) { 589 sum = 1e-5; 590 } 591 const float normalize = static_cast<float>(1.0 / sum); 592 const float normalize_mul = normalize; 593 WeightsSymmetric5 weights = 594 WeightsSymmetric5{{HWY_REP4(normalize)}, 595 {HWY_REP4(normalize_mul * kFilterMask1x1[0])}, 596 {HWY_REP4(normalize_mul * kFilterMask1x1[2])}, 597 {HWY_REP4(normalize_mul * kFilterMask1x1[1])}, 598 {HWY_REP4(normalize_mul * kFilterMask1x1[4])}, 599 {HWY_REP4(normalize_mul * kFilterMask1x1[3])}}; 600 JXL_ASSIGN_OR_RETURN(ImageF temp, ImageF::Create(rect.xsize(), rect.ysize())); 601 Symmetric5(*mask1x1, rect, weights, pool, &temp); 602 *mask1x1 = std::move(temp); 603 return true; 604 } 
605 606 StatusOr<ImageF> AdaptiveQuantizationMap(const float butteraugli_target, 607 const Image3F& xyb, const Rect& rect, 608 float scale, ThreadPool* pool, 609 ImageF* mask, ImageF* mask1x1) { 610 JXL_DASSERT(rect.xsize() % kBlockDim == 0); 611 JXL_DASSERT(rect.ysize() % kBlockDim == 0); 612 AdaptiveQuantizationImpl impl; 613 const size_t xsize_blocks = rect.xsize() / kBlockDim; 614 const size_t ysize_blocks = rect.ysize() / kBlockDim; 615 JXL_ASSIGN_OR_RETURN(impl.aq_map, ImageF::Create(xsize_blocks, ysize_blocks)); 616 JXL_ASSIGN_OR_RETURN(*mask, ImageF::Create(xsize_blocks, ysize_blocks)); 617 JXL_ASSIGN_OR_RETURN(*mask1x1, ImageF::Create(xyb.xsize(), xyb.ysize())); 618 JXL_CHECK(RunOnPool( 619 pool, 0, 620 DivCeil(xsize_blocks, kEncTileDimInBlocks) * 621 DivCeil(ysize_blocks, kEncTileDimInBlocks), 622 [&](const size_t num_threads) { 623 return !!impl.PrepareBuffers(num_threads); 624 }, 625 [&](const uint32_t tid, const size_t thread) { 626 size_t n_enc_tiles = DivCeil(xsize_blocks, kEncTileDimInBlocks); 627 size_t tx = tid % n_enc_tiles; 628 size_t ty = tid / n_enc_tiles; 629 size_t by0 = ty * kEncTileDimInBlocks; 630 size_t by1 = std::min((ty + 1) * kEncTileDimInBlocks, ysize_blocks); 631 size_t bx0 = tx * kEncTileDimInBlocks; 632 size_t bx1 = std::min((tx + 1) * kEncTileDimInBlocks, xsize_blocks); 633 Rect rect_out(bx0, by0, bx1 - bx0, by1 - by0); 634 impl.ComputeTile(butteraugli_target, scale, xyb, rect, rect_out, thread, 635 mask, mask1x1); 636 }, 637 "AQ DiffPrecompute")); 638 639 JXL_RETURN_IF_ERROR(Blur1x1Masking(pool, mask1x1, rect)); 640 return std::move(impl).aq_map; 641 } 642 643 } // namespace 644 645 // NOLINTNEXTLINE(google-readability-namespace-comments) 646 } // namespace HWY_NAMESPACE 647 } // namespace jxl 648 HWY_AFTER_NAMESPACE(); 649 650 #if HWY_ONCE 651 namespace jxl { 652 HWY_EXPORT(AdaptiveQuantizationMap); 653 654 namespace { 655 656 // If true, prints the quantization maps at each iteration. 
657 constexpr bool FLAGS_dump_quant_state = false; 658 659 Status DumpHeatmap(const CompressParams& cparams, const AuxOut* aux_out, 660 const std::string& label, const ImageF& image, 661 float good_threshold, float bad_threshold) { 662 if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) { 663 JXL_ASSIGN_OR_RETURN( 664 Image3F heatmap, 665 CreateHeatMapImage(image, good_threshold, bad_threshold)); 666 char filename[200]; 667 snprintf(filename, sizeof(filename), "%s%05d", label.c_str(), 668 aux_out->num_butteraugli_iters); 669 JXL_RETURN_IF_ERROR(DumpImage(cparams, filename, heatmap)); 670 } 671 return true; 672 } 673 674 Status DumpHeatmaps(const CompressParams& cparams, const AuxOut* aux_out, 675 float ba_target, const ImageF& quant_field, 676 const ImageF& tile_heatmap, const ImageF& bt_diffmap) { 677 if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) { 678 if (!WantDebugOutput(cparams)) return true; 679 JXL_ASSIGN_OR_RETURN(ImageF inv_qmap, ImageF::Create(quant_field.xsize(), 680 quant_field.ysize())); 681 for (size_t y = 0; y < quant_field.ysize(); ++y) { 682 const float* JXL_RESTRICT row_q = quant_field.ConstRow(y); 683 float* JXL_RESTRICT row_inv_q = inv_qmap.Row(y); 684 for (size_t x = 0; x < quant_field.xsize(); ++x) { 685 row_inv_q[x] = 1.0f / row_q[x]; // never zero 686 } 687 } 688 JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "quant_heatmap", inv_qmap, 689 4.0f * ba_target, 6.0f * ba_target)); 690 JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "tile_heatmap", 691 tile_heatmap, ba_target, 1.5f * ba_target)); 692 // matches heat maps produced by the command line tool. 
693 JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "bt_diffmap", bt_diffmap, 694 ButteraugliFuzzyInverse(1.5), 695 ButteraugliFuzzyInverse(0.5))); 696 } 697 return true; 698 } 699 700 StatusOr<ImageF> TileDistMap(const ImageF& distmap, int tile_size, int margin, 701 const AcStrategyImage& ac_strategy) { 702 const int tile_xsize = (distmap.xsize() + tile_size - 1) / tile_size; 703 const int tile_ysize = (distmap.ysize() + tile_size - 1) / tile_size; 704 JXL_ASSIGN_OR_RETURN(ImageF tile_distmap, 705 ImageF::Create(tile_xsize, tile_ysize)); 706 size_t distmap_stride = tile_distmap.PixelsPerRow(); 707 for (int tile_y = 0; tile_y < tile_ysize; ++tile_y) { 708 AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(tile_y); 709 float* JXL_RESTRICT dist_row = tile_distmap.Row(tile_y); 710 for (int tile_x = 0; tile_x < tile_xsize; ++tile_x) { 711 AcStrategy acs = ac_strategy_row[tile_x]; 712 if (!acs.IsFirstBlock()) continue; 713 int this_tile_xsize = acs.covered_blocks_x() * tile_size; 714 int this_tile_ysize = acs.covered_blocks_y() * tile_size; 715 int y_begin = std::max<int>(0, tile_size * tile_y - margin); 716 int y_end = std::min<int>(distmap.ysize(), 717 tile_size * tile_y + this_tile_ysize + margin); 718 int x_begin = std::max<int>(0, tile_size * tile_x - margin); 719 int x_end = std::min<int>(distmap.xsize(), 720 tile_size * tile_x + this_tile_xsize + margin); 721 float dist_norm = 0.0; 722 double pixels = 0; 723 for (int y = y_begin; y < y_end; ++y) { 724 float ymul = 1.0; 725 constexpr float kBorderMul = 0.98f; 726 constexpr float kCornerMul = 0.7f; 727 if (margin != 0 && (y == y_begin || y == y_end - 1)) { 728 ymul = kBorderMul; 729 } 730 const float* const JXL_RESTRICT row = distmap.Row(y); 731 for (int x = x_begin; x < x_end; ++x) { 732 float xmul = ymul; 733 if (margin != 0 && (x == x_begin || x == x_end - 1)) { 734 if (xmul == 1.0) { 735 xmul = kBorderMul; 736 } else { 737 xmul = kCornerMul; 738 } 739 } 740 float v = row[x]; 741 v *= v; 742 v *= v; 743 v *= 
v;  // completes a statement begun above this chunk
          v *= v;
          // Accumulate a high-order norm of the per-pixel difference,
          // weighted by xmul; normalized by the 1/16-th root below.
          dist_norm += xmul * v;
          pixels += xmul;
        }
      }
      if (pixels == 0) pixels = 1;
      // 16th norm is less than the max norm, we reduce the difference
      // with this normalization factor.
      constexpr float kTileNorm = 1.2f;
      const float tile_dist =
          kTileNorm * std::pow(dist_norm / pixels, 1.0f / 16.0f);
      dist_row[tile_x] = tile_dist;
      // Replicate the tile distance into every 8x8 entry covered by this AC
      // strategy block, so the whole varblock shares one distance value.
      for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
        for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
          dist_row[tile_x + distmap_stride * iy + ix] = tile_dist;
        }
      }
    }
  }
  return tile_distmap;
}

// Tuning constants used by InitialQuantDC / InitialQuantField below.
const float kDcQuantPow = 0.83f;
const float kDcQuant = 1.095924047623553f;
const float kAcQuant = 0.7381485255235064f;

// Computes the decoded image for a given set of compression parameters.
// Encodes `opsin` with the quantization state currently held in `enc_state`,
// then decodes it again through the render pipeline, so callers (the AQ
// loops below) can measure the quality actually achieved.
StatusOr<ImageBundle> RoundtripImage(const FrameHeader& frame_header,
                                     const Image3F& opsin,
                                     PassesEncoderState* enc_state,
                                     const JxlCmsInterface& cms,
                                     ThreadPool* pool) {
  std::unique_ptr<PassesDecoderState> dec_state =
      jxl::make_unique<PassesDecoderState>();
  JXL_CHECK(dec_state->output_encoding_info.SetFromMetadata(
      *enc_state->shared.metadata));
  // The decoder state borrows the encoder's shared state; no copy is made.
  dec_state->shared = &enc_state->shared;
  JXL_ASSERT(opsin.ysize() % kBlockDim == 0);

  const size_t xsize_groups = DivCeil(opsin.xsize(), kGroupDim);
  const size_t ysize_groups = DivCeil(opsin.ysize(), kGroupDim);
  const size_t num_groups = xsize_groups * ysize_groups;

  // Remember the count so any special frames added by this roundtrip can be
  // discarded again at the end (see resize below).
  size_t num_special_frames = enc_state->special_frames.size();
  size_t num_passes = enc_state->progressive_splitter.GetNumPasses();
  ModularFrameEncoder modular_frame_encoder(frame_header, enc_state->cparams,
                                            false);
  JXL_CHECK(InitializePassesEncoder(frame_header, opsin, Rect(opsin), cms, pool,
                                    enc_state, &modular_frame_encoder,
                                    nullptr));
  JXL_CHECK(dec_state->Init(frame_header));
  JXL_CHECK(dec_state->InitForAC(num_passes, pool));

  ImageBundle decoded(&enc_state->shared.metadata->m);
  decoded.origin = frame_header.frame_origin;
  JXL_ASSIGN_OR_RETURN(Image3F tmp,
                       Image3F::Create(opsin.xsize(), opsin.ysize()));
  decoded.SetFromImage(std::move(tmp),
                       dec_state->output_encoding_info.color_encoding);

  // Minimal pipeline: this roundtrip only needs pixel data for the distance
  // metric, so spot colors / noise rendering are disabled.
  PassesDecoderState::PipelineOptions options;
  options.use_slow_render_pipeline = false;
  options.coalescing = false;
  options.render_spotcolors = false;
  options.render_noise = false;

  // Same as frame_header.nonserialized_metadata->m
  const ImageMetadata& metadata = *decoded.metadata();

  JXL_CHECK(dec_state->PreparePipeline(frame_header, &decoded, options));

  hwy::AlignedUniquePtr<GroupDecCache[]> group_dec_caches;
  // Runs once per worker pool with the final thread count; allocates
  // per-thread decode caches and pipeline buffers.
  const auto allocate_storage = [&](const size_t num_threads) -> Status {
    JXL_RETURN_IF_ERROR(
        dec_state->render_pipeline->PrepareForThreads(num_threads,
                                                      /*use_group_ids=*/false));
    group_dec_caches = hwy::MakeUniqueAlignedArray<GroupDecCache>(num_threads);
    return true;
  };
  std::atomic<bool> has_error{false};
  // Decodes one group on one worker thread; sets has_error instead of
  // returning a status because RunOnPool callbacks cannot fail directly.
  const auto process_group = [&](const uint32_t group_index,
                                 const size_t thread) {
    if (has_error) return;
    if (frame_header.loop_filter.epf_iters > 0) {
      ComputeSigma(frame_header.loop_filter,
                   dec_state->shared->frame_dim.BlockGroupRect(group_index),
                   dec_state.get());
    }
    RenderPipelineInput input =
        dec_state->render_pipeline->GetInputBuffers(group_index, thread);
    JXL_CHECK(DecodeGroupForRoundtrip(
        frame_header, enc_state->coeffs, group_index, dec_state.get(),
        &group_dec_caches[thread], thread, input, &decoded, nullptr));
    // Extra channels are irrelevant for the distance metric; zero-fill them.
    for (size_t c = 0; c < metadata.num_extra_channels; c++) {
      std::pair<ImageF*, Rect> ri = input.GetBuffer(3 + c);
      FillPlane(0.0f, ri.first, ri.second);
    }
    if (!input.Done()) {
      has_error = true;
      return;
    }
  };
  JXL_CHECK(RunOnPool(pool, 0, num_groups, allocate_storage, process_group,
                      "AQ loop"));
  if (has_error) return JXL_FAILURE("AQ loop failure");

  // Ensure we don't create any new special frames.
  enc_state->special_frames.resize(num_special_frames);

  return decoded;
}

constexpr int kMaxButteraugliIters = 4;

// Iteratively rescales `quant_field` so that the per-tile butteraugli
// distance of a full encode/decode roundtrip approaches the target distance.
// The search loop itself continues below.
Status FindBestQuantization(const FrameHeader& frame_header,
                            const Image3F& linear, const Image3F& opsin,
                            ImageF& quant_field, PassesEncoderState* enc_state,
                            const JxlCmsInterface& cms, ThreadPool* pool,
                            AuxOut* aux_out) {
  const CompressParams& cparams = enc_state->cparams;
  if (cparams.resampling > 1 &&
      cparams.original_butteraugli_distance <= 4.0 * cparams.resampling) {
    // For downsampled opsin image, the butteraugli based adaptive quantization
    // loop would only make the size bigger without improving the distance much,
    // so in this case we enable it only for very high butteraugli targets.
    return true;
  }
  Quantizer& quantizer = enc_state->shared.quantizer;
  ImageI& raw_quant_field = enc_state->shared.raw_quant_field;

  const float butteraugli_target = cparams.butteraugli_distance;
  const float original_butteraugli = cparams.original_butteraugli_distance;
  ButteraugliParams params;
  params.intensity_target = 80.f;
  JxlButteraugliComparator comparator(params, cms);
  JXL_CHECK(comparator.SetLinearReferenceImage(linear));
  // Allows the score sign to be normalized below regardless of whether the
  // comparator reports "good" as low or high.
  bool lower_is_better =
      (comparator.GoodQualityScore() < comparator.BadQualityScore());
  const float initial_quant_dc = InitialQuantDC(butteraugli_target);
  AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field),
                   original_butteraugli, &quant_field);
  ImageF tile_distmap;
  // Keep a copy of the starting field; a later iteration clamps the field
  // against it so optimization cannot drift too far from the initial guess.
  JXL_ASSIGN_OR_RETURN(
      ImageF initial_quant_field,
      ImageF::Create(quant_field.xsize(), quant_field.ysize()));
  CopyImageTo(quant_field, &initial_quant_field);

  float initial_qf_min;
  float initial_qf_max;
  ImageMinMax(initial_quant_field, &initial_qf_min, &initial_qf_max);
  float initial_qf_ratio =
initial_qf_max / initial_qf_min;
  // Derive hard lower/upper bounds for the quant field such that the final
  // ratio qf_higher / qf_lower stays below 253 (asserted below).
  float qf_max_deviation_low = std::sqrt(250 / initial_qf_ratio);
  float asymmetry = 2;
  if (qf_max_deviation_low < asymmetry) asymmetry = qf_max_deviation_low;
  float qf_lower = initial_qf_min / (asymmetry * qf_max_deviation_low);
  float qf_higher = initial_qf_max * (qf_max_deviation_low / asymmetry);

  JXL_ASSERT(qf_higher / qf_lower < 253);

  constexpr int kOriginalComparisonRound = 1;
  // Butteraugli delta approximately represents two steps of quantization.
  int iters = kMaxButteraugliIters;
  if (cparams.speed_tier != SpeedTier::kTortoise) {
    iters = 2;
  }
  // iters + 1 passes: the final pass only evaluates, it does not adjust
  // (see the `break` after the comparison below).
  for (int i = 0; i < iters + 1; ++i) {
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
      printf("\nQuantization field:\n");
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          printf(" %.5f", quant_field.Row(y)[x]);
        }
        printf("\n");
      }
    }
    quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
    JXL_ASSIGN_OR_RETURN(
        ImageBundle dec_linear,
        RoundtripImage(frame_header, opsin, enc_state, cms, pool));
    float score;
    ImageF diffmap;
    JXL_CHECK(comparator.CompareWith(dec_linear, &diffmap, &score));
    // Normalize so that from here on, smaller score / diffmap is better.
    if (!lower_is_better) {
      score = -score;
      ScaleImage(-1.0f, &diffmap);
    }
    JXL_ASSIGN_OR_RETURN(tile_distmap,
                         TileDistMap(diffmap, 8 * cparams.resampling, 0,
                                     enc_state->shared.ac_strategy));
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && WantDebugOutput(cparams)) {
      JXL_RETURN_IF_ERROR(DumpImage(cparams, ("dec" + ToString(i)).c_str(),
                                    *dec_linear.color()));
      JXL_RETURN_IF_ERROR(DumpHeatmaps(cparams, aux_out, butteraugli_target,
                                       quant_field, tile_distmap, diffmap));
    }
    if (aux_out != nullptr) ++aux_out->num_butteraugli_iters;
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
      float minval;
      float maxval;
      ImageMinMax(quant_field, &minval, &maxval);
      // NOTE(review): the denominator printed here is kMaxButteraugliIters
      // even when `iters` was lowered to 2 above — debug output only, but
      // the "i/max" display can overstate the remaining iterations.
      printf("\nButteraugli iter: %d/%d\n", i, kMaxButteraugliIters);
      printf("Butteraugli distance: %f (target = %f)\n", score,
             original_butteraugli);
      printf("quant range: %f ... %f DC quant: %f\n", minval, maxval,
             initial_quant_dc);
      if (FLAGS_dump_quant_state) {
        quantizer.DumpQuantizationMap(raw_quant_field);
      }
    }

    // Last pass is evaluation-only.
    if (i == iters) break;

    // Per-iteration exponent tables for the relaxation step below; only the
    // first two rounds use a nonzero exponent.
    double kPow[8] = {
        0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    };
    double kPowMod[8] = {
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    };
    if (i == kOriginalComparisonRound) {
      // Don't allow optimization to make the quant field a lot worse than
      // what the initial guess was. This allows the AC field to have enough
      // precision to reduce the oscillations due to the dc reconstruction.
      double kInitMul = 0.6;
      const double kOneMinusInitMul = 1.0 - kInitMul;
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        float* const JXL_RESTRICT row_q = quant_field.Row(y);
        const float* const JXL_RESTRICT row_init = initial_quant_field.Row(y);
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          double clamp = kOneMinusInitMul * row_q[x] + kInitMul * row_init[x];
          if (row_q[x] < clamp) {
            row_q[x] = clamp;
            if (row_q[x] > qf_higher) row_q[x] = qf_higher;
            if (row_q[x] < qf_lower) row_q[x] = qf_lower;
          }
        }
      }
    }

    double cur_pow = 0.0;
    if (i < 7) {
      cur_pow = kPow[i] + (original_butteraugli - 1.0) * kPowMod[i];
      if (cur_pow < 0) {
        cur_pow = 0;
      }
    }
    if (cur_pow == 0.0) {
      // Tighten-only step: raise qf where the tile distance exceeds target;
      // ensure at least one raw quantizer step of progress when rounding
      // would otherwise leave the quantized field unchanged.
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y);
        float* const JXL_RESTRICT row_q = quant_field.Row(y);
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          const float diff = row_dist[x] / original_butteraugli;
          if (diff > 1.0f) {
            float old = row_q[x];
            row_q[x] *= diff;
            int qf_old =
                static_cast<int>(std::lround(old * quantizer.InvGlobalScale()));
            int qf_new = static_cast<int>(
                std::lround(row_q[x] * quantizer.InvGlobalScale()));
            if (qf_old == qf_new) {
              // Ensure at least one full integer quantizer step.
              row_q[x] = old + quantizer.Scale();
            }
          }
          if (row_q[x] > qf_higher) row_q[x] = qf_higher;
          if (row_q[x] < qf_lower) row_q[x] = qf_lower;
        }
      }
    } else {
      // Two-way step: additionally relax (lower) qf where the tile distance
      // is below target, scaled by diff^cur_pow.
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y);
        float* const JXL_RESTRICT row_q = quant_field.Row(y);
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          const float diff = row_dist[x] / original_butteraugli;
          if (diff <= 1.0f) {
            row_q[x] *= std::pow(diff, cur_pow);
          } else {
            float old = row_q[x];
            row_q[x] *= diff;
            int qf_old =
                static_cast<int>(std::lround(old * quantizer.InvGlobalScale()));
            int qf_new = static_cast<int>(
                std::lround(row_q[x] * quantizer.InvGlobalScale()));
            if (qf_old == qf_new) {
              // Ensure at least one full integer quantizer step.
              row_q[x] = old + quantizer.Scale();
            }
          }
          if (row_q[x] > qf_higher) row_q[x] = qf_higher;
          if (row_q[x] < qf_lower) row_q[x] = qf_lower;
        }
      }
    }
  }
  quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
  return true;
}

// Like FindBestQuantization, but drives the quant field from a per-channel
// maximum absolute error budget (cparams.max_error) instead of a butteraugli
// score. Body continues below.
Status FindBestQuantizationMaxError(const FrameHeader& frame_header,
                                    const Image3F& opsin, ImageF& quant_field,
                                    PassesEncoderState* enc_state,
                                    const JxlCmsInterface& cms,
                                    ThreadPool* pool, AuxOut* aux_out) {
  // TODO(szabadka): Make this work for non-opsin color spaces.
  const CompressParams& cparams = enc_state->cparams;
  Quantizer& quantizer = enc_state->shared.quantizer;
  ImageI& raw_quant_field = enc_state->shared.raw_quant_field;

  // TODO(veluca): better choice of this value.
  const float initial_quant_dc =
      16 * std::sqrt(0.1f / cparams.butteraugli_distance);
  AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field),
                   cparams.original_butteraugli_distance, &quant_field);

  // Reciprocals of the per-channel error budgets, so per-pixel errors can be
  // normalized with a multiply in the inner loop below.
  const float inv_max_err[3] = {1.0f / enc_state->cparams.max_error[0],
                                1.0f / enc_state->cparams.max_error[1],
                                1.0f / enc_state->cparams.max_error[2]};

  for (int i = 0; i < kMaxButteraugliIters + 1; ++i) {
    quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
      JXL_RETURN_IF_ERROR(
          DumpXybImage(cparams, ("ops" + ToString(i)).c_str(), opsin));
    }
    JXL_ASSIGN_OR_RETURN(
        ImageBundle decoded,
        RoundtripImage(frame_header, opsin, enc_state, cms, pool));
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
      JXL_RETURN_IF_ERROR(DumpXybImage(cparams, ("dec" + ToString(i)).c_str(),
                                       *decoded.color()));
    }
    // Per varblock: find the worst normalized error across all three channels
    // and all covered pixels, then scale that varblock's quant field entries.
    for (size_t by = 0; by < enc_state->shared.frame_dim.ysize_blocks; by++) {
      AcStrategyRow ac_strategy_row =
          enc_state->shared.ac_strategy.ConstRow(by);
      for (size_t bx = 0; bx < enc_state->shared.frame_dim.xsize_blocks; bx++) {
        AcStrategy acs = ac_strategy_row[bx];
        if (!acs.IsFirstBlock()) continue;
        float max_error = 0;
        for (size_t c = 0; c < 3; c++) {
          for (size_t y = by * kBlockDim;
               y < (by + acs.covered_blocks_y()) * kBlockDim; y++) {
            // Varblocks may extend past the image edge; skip those pixels.
            if (y >= decoded.ysize()) continue;
            const float* JXL_RESTRICT in_row = opsin.ConstPlaneRow(c, y);
            const float* JXL_RESTRICT dec_row =
                decoded.color()->ConstPlaneRow(c, y);
            for (size_t x = bx * kBlockDim;
                 x < (bx + acs.covered_blocks_x()) * kBlockDim; x++) {
              if (x >= decoded.xsize()) continue;
              max_error = std::max(
                  std::abs(in_row[x] - dec_row[x]) * inv_max_err[c], max_error);
            }
          }
        }
        // Target an error between max_error/2 and max_error.
        // If the error in the varblock is above the target, increase the qf to
        // compensate. If the error is below the target, decrease the qf.
        // However, to avoid an excessive increase of the qf, only do so if the
        // error is less than half the maximum allowed error.
        const float qf_mul = (max_error < 0.5f) ? max_error * 2.0f
                             : (max_error > 1.0f) ? max_error
                                                  : 1.0f;
        for (size_t qy = by; qy < by + acs.covered_blocks_y(); qy++) {
          float* JXL_RESTRICT quant_field_row = quant_field.Row(qy);
          for (size_t qx = bx; qx < bx + acs.covered_blocks_x(); qx++) {
            quant_field_row[qx] *= qf_mul;
          }
        }
      }
    }
  }
  quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
  return true;
}

}  // namespace

void AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect,
                      float butteraugli_target, ImageF* quant_field) {
  // Replace the whole quant_field in non-8x8 blocks with the maximum of each
  // 8x8 block.
  size_t stride = quant_field->PixelsPerRow();

  // At low distances it is great to use max, but mean works better
  // at high distances. We interpolate between them for a distance
  // range.
1125 float mean_max_mixer = 1.0f; 1126 { 1127 static const float kLimit = 1.54138f; 1128 static const float kMul = 0.56391f; 1129 static const float kMin = 0.0f; 1130 if (butteraugli_target > kLimit) { 1131 mean_max_mixer -= (butteraugli_target - kLimit) * kMul; 1132 if (mean_max_mixer < kMin) { 1133 mean_max_mixer = kMin; 1134 } 1135 } 1136 } 1137 for (size_t y = 0; y < rect.ysize(); ++y) { 1138 AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(rect, y); 1139 float* JXL_RESTRICT quant_row = rect.Row(quant_field, y); 1140 for (size_t x = 0; x < rect.xsize(); ++x) { 1141 AcStrategy acs = ac_strategy_row[x]; 1142 if (!acs.IsFirstBlock()) continue; 1143 JXL_ASSERT(x + acs.covered_blocks_x() <= quant_field->xsize()); 1144 JXL_ASSERT(y + acs.covered_blocks_y() <= quant_field->ysize()); 1145 float max = quant_row[x]; 1146 float mean = 0.0; 1147 for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { 1148 for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { 1149 mean += quant_row[x + ix + iy * stride]; 1150 max = std::max(quant_row[x + ix + iy * stride], max); 1151 } 1152 } 1153 mean /= acs.covered_blocks_y() * acs.covered_blocks_x(); 1154 if (acs.covered_blocks_y() * acs.covered_blocks_x() >= 4) { 1155 max *= mean_max_mixer; 1156 max += (1.0f - mean_max_mixer) * mean; 1157 } 1158 for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { 1159 for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) { 1160 quant_row[x + ix + iy * stride] = max; 1161 } 1162 } 1163 } 1164 } 1165 } 1166 1167 float InitialQuantDC(float butteraugli_target) { 1168 const float kDcMul = 0.3; // Butteraugli target where non-linearity kicks in. 1169 const float butteraugli_target_dc = std::max<float>( 1170 0.5f * butteraugli_target, 1171 std::min<float>(butteraugli_target, 1172 kDcMul * std::pow((1.0f / kDcMul) * butteraugli_target, 1173 kDcQuantPow))); 1174 // We want the maximum DC value to be at most 2**15 * kInvDCQuant / quant_dc. 
1175 // The maximum DC value might not be in the kXybRange because of inverse 1176 // gaborish, so we add some slack to the maximum theoretical quant obtained 1177 // this way (64). 1178 return std::min(kDcQuant / butteraugli_target_dc, 50.f); 1179 } 1180 1181 StatusOr<ImageF> InitialQuantField(const float butteraugli_target, 1182 const Image3F& opsin, const Rect& rect, 1183 ThreadPool* pool, float rescale, 1184 ImageF* mask, ImageF* mask1x1) { 1185 const float quant_ac = kAcQuant / butteraugli_target; 1186 return HWY_DYNAMIC_DISPATCH(AdaptiveQuantizationMap)( 1187 butteraugli_target, opsin, rect, quant_ac * rescale, pool, mask, mask1x1); 1188 } 1189 1190 Status FindBestQuantizer(const FrameHeader& frame_header, const Image3F* linear, 1191 const Image3F& opsin, ImageF& quant_field, 1192 PassesEncoderState* enc_state, 1193 const JxlCmsInterface& cms, ThreadPool* pool, 1194 AuxOut* aux_out, double rescale) { 1195 const CompressParams& cparams = enc_state->cparams; 1196 if (cparams.max_error_mode) { 1197 JXL_RETURN_IF_ERROR(FindBestQuantizationMaxError( 1198 frame_header, opsin, quant_field, enc_state, cms, pool, aux_out)); 1199 } else if (linear && cparams.speed_tier <= SpeedTier::kKitten) { 1200 // Normal encoding to a butteraugli score. 1201 JXL_RETURN_IF_ERROR(FindBestQuantization(frame_header, *linear, opsin, 1202 quant_field, enc_state, cms, pool, 1203 aux_out)); 1204 } 1205 return true; 1206 } 1207 1208 } // namespace jxl 1209 #endif // HWY_ONCE