libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_adaptive_quantization.cc (48877B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_adaptive_quantization.h"
      7 
      8 #include <stddef.h>
      9 #include <stdlib.h>
     10 
     11 #include <algorithm>
     12 #include <atomic>
     13 #include <cmath>
     14 #include <string>
     15 #include <vector>
     16 
     17 #undef HWY_TARGET_INCLUDE
     18 #define HWY_TARGET_INCLUDE "lib/jxl/enc_adaptive_quantization.cc"
     19 #include <hwy/foreach_target.h>
     20 #include <hwy/highway.h>
     21 
     22 #include "lib/jxl/ac_strategy.h"
     23 #include "lib/jxl/base/common.h"
     24 #include "lib/jxl/base/compiler_specific.h"
     25 #include "lib/jxl/base/data_parallel.h"
     26 #include "lib/jxl/base/fast_math-inl.h"
     27 #include "lib/jxl/base/status.h"
     28 #include "lib/jxl/butteraugli/butteraugli.h"
     29 #include "lib/jxl/cms/opsin_params.h"
     30 #include "lib/jxl/convolve.h"
     31 #include "lib/jxl/dec_cache.h"
     32 #include "lib/jxl/dec_group.h"
     33 #include "lib/jxl/enc_aux_out.h"
     34 #include "lib/jxl/enc_butteraugli_comparator.h"
     35 #include "lib/jxl/enc_cache.h"
     36 #include "lib/jxl/enc_debug_image.h"
     37 #include "lib/jxl/enc_group.h"
     38 #include "lib/jxl/enc_modular.h"
     39 #include "lib/jxl/enc_params.h"
     40 #include "lib/jxl/enc_transforms-inl.h"
     41 #include "lib/jxl/epf.h"
     42 #include "lib/jxl/frame_dimensions.h"
     43 #include "lib/jxl/image.h"
     44 #include "lib/jxl/image_bundle.h"
     45 #include "lib/jxl/image_ops.h"
     46 #include "lib/jxl/quant_weights.h"
     47 
     48 // Set JXL_DEBUG_ADAPTIVE_QUANTIZATION to 1 to enable debugging.
     49 #ifndef JXL_DEBUG_ADAPTIVE_QUANTIZATION
     50 #define JXL_DEBUG_ADAPTIVE_QUANTIZATION 0
     51 #endif
     52 
     53 HWY_BEFORE_NAMESPACE();
     54 namespace jxl {
     55 namespace HWY_NAMESPACE {
     56 namespace {
     57 
     58 // These templates are not found via ADL.
     59 using hwy::HWY_NAMESPACE::AbsDiff;
     60 using hwy::HWY_NAMESPACE::Add;
     61 using hwy::HWY_NAMESPACE::And;
     62 using hwy::HWY_NAMESPACE::Max;
     63 using hwy::HWY_NAMESPACE::Rebind;
     64 using hwy::HWY_NAMESPACE::Sqrt;
     65 using hwy::HWY_NAMESPACE::ZeroIfNegative;
     66 
     67 // The following functions modulate an exponent (out_val) and return the updated
     68 // value. Their descriptor is limited to 8 lanes for 8x8 blocks.
     69 
     70 // Hack for mask estimation. Eventually replace this code with butteraugli's
     71 // masking.
     72 float ComputeMaskForAcStrategyUse(const float out_val) {
     73   const float kMul = 1.0f;
     74   const float kOffset = 0.001f;
     75   return kMul / (out_val + kOffset);
     76 }
     77 
// Maps the accumulated modulation value `out_val` to a masking exponent as a
// weighted sum of three rational terms kBase + kMul2/(v+c2) + kMul3/(v^2+c3)
// + kMul4/(v^2+c4); larger inputs move the result toward kBase.
template <class D, class V>
V ComputeMask(const D d, const V out_val) {
  const auto kBase = Set(d, -0.7647f);
  const auto kMul4 = Set(d, 9.4708735624378946f);
  const auto kMul2 = Set(d, 17.35036561631863f);
  const auto kOffset2 = Set(d, 302.59587815579727f);
  const auto kMul3 = Set(d, 6.7943250517376494f);
  const auto kOffset3 = Set(d, 3.7179635626140772f);
  const auto kOffset4 = Mul(Set(d, 0.25f), kOffset3);
  const auto kMul0 = Set(d, 0.80061762862741759f);
  const auto k1 = Set(d, 1.0f);

  // Avoid division by zero.
  const auto v1 = Max(Mul(out_val, kMul0), Set(d, 1e-3f));
  const auto v2 = Div(k1, Add(v1, kOffset2));
  const auto v3 = Div(k1, MulAdd(v1, v1, kOffset3));
  const auto v4 = Div(k1, MulAdd(v1, v1, kOffset4));
  // TODO(jyrki):
  // A log or two here could make sense. In butteraugli we have effectively
  // log(log(x + C)) for this kind of use, as a single log is used in
  // saturating visual masking and here the modulation values are exponential,
  // another log would counter that.
  return Add(kBase, MulAdd(kMul4, v4, MulAdd(kMul2, v2, Mul(kMul3, v3))));
}
    102 
// mul and mul2 represent a scaling difference between jxl and butteraugli.
const float kSGmul = 226.77216153508914f;
const float kSGmul2 = 1.0f / 73.377132366608819f;
// ln(2); used below to fold the std::log -> log2 base change into constants.
const float kLog2 = 0.693147181f;
// Includes correction factor for std::log -> log2.
const float kSGRetMul = kSGmul2 * 18.6580932135f * kLog2;
const float kSGVOffset = 7.7825991679894591f;
    110 
// Returns the ratio between the derivative of jxl's cubic-root opsin
// transfer and the derivative of butteraugli's SimpleGamma (see the
// commented-out reference implementation below). With invert=true the
// ratio num/den is returned, otherwise its reciprocal den/num.
template <bool invert, typename D, typename V>
V RatioOfDerivativesOfCubicRootToSimpleGamma(const D d, V v) {
  // The opsin space in jxl is the cubic root of photons, i.e., v * v * v
  // is related to the number of photons.
  //
  // SimpleGamma(v * v * v) is the psychovisual space in butteraugli.
  // This ratio allows quantization to move from jxl's opsin space to
  // butteraugli's log-gamma space.
  float kEpsilon = 1e-2;
  // Negative "photons" do not exist; clamping also protects the division.
  v = ZeroIfNegative(v);
  const auto kNumMul = Set(d, kSGRetMul * 3 * kSGmul);
  const auto kVOffset = Set(d, kSGVOffset * kLog2 + kEpsilon);
  const auto kDenMul = Set(d, kLog2 * kSGmul);

  const auto v2 = Mul(v, v);

  // num = kNumMul * v^2 + eps, den = kDenMul * v^3 + kVOffset; the epsilon
  // terms keep both strictly positive at v == 0.
  const auto num = MulAdd(kNumMul, v2, Set(d, kEpsilon));
  const auto den = MulAdd(Mul(kDenMul, v), v2, kVOffset);
  return invert ? Div(num, den) : Div(den, num);
}
    131 
// Scalar convenience wrapper: evaluates the SIMD implementation above on a
// single-lane vector and extracts the result.
template <bool invert = false>
float RatioOfDerivativesOfCubicRootToSimpleGamma(float v) {
  using DScalar = HWY_CAPPED(float, 1);
  auto vscalar = Load(DScalar(), &v);
  return GetLane(
      RatioOfDerivativesOfCubicRootToSimpleGamma<invert>(DScalar(), vscalar));
}
    139 
    140 // TODO(veluca): this function computes an approximation of the derivative of
    141 // SimpleGamma with (f(x+eps)-f(x))/eps. Consider two-sided approximation or
    142 // exact derivatives. For reference, SimpleGamma was:
    143 /*
    144 template <typename D, typename V>
    145 V SimpleGamma(const D d, V v) {
    146   // A simple HDR compatible gamma function.
    147   const auto mul = Set(d, kSGmul);
    148   const auto kRetMul = Set(d, kSGRetMul);
    149   const auto kRetAdd = Set(d, kSGmul2 * -20.2789020414f);
    150   const auto kVOffset = Set(d, kSGVOffset);
    151 
    152   v *= mul;
    153 
    154   // This should happen rarely, but may lead to a NaN, which is rather
    155   // undesirable. Since negative photons don't exist we solve the NaNs by
    156   // clamping here.
    157   // TODO(veluca): with FastLog2f, this no longer leads to NaNs.
    158   v = ZeroIfNegative(v);
    159   return kRetMul * FastLog2f(d, v + kVOffset) + kRetAdd;
    160 }
    161 */
    162 
// Adjusts the exponent `out_val` for the 8x8 block at (x, y) by the log of
// the block-average derivative ratio between jxl's cubic-root space and
// butteraugli's gamma, evaluated on two channels derived from the X and Y
// opsin planes.
template <class D, class V>
V GammaModulation(const D d, const size_t x, const size_t y,
                  const ImageF& xyb_x, const ImageF& xyb_y, const Rect& rect,
                  const V out_val) {
  const float kBias = 0.16f;
  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[0]);
  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[1]);
  JXL_DASSERT(kBias > jxl::cms::kOpsinAbsorbanceBias[2]);
  auto overall_ratio = Zero(d);
  auto bias = Set(d, kBias);
  auto half = Set(d, 0.5f);
  for (size_t dy = 0; dy < 8; ++dy) {
    const float* const JXL_RESTRICT row_in_x = rect.ConstRow(xyb_x, y + dy);
    const float* const JXL_RESTRICT row_in_y = rect.ConstRow(xyb_y, y + dy);
    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
      const auto iny = Add(Load(d, row_in_y + x + dx), bias);
      const auto inx = Load(d, row_in_x + x + dx);
      // NOTE(review): Y -/+ X appears to undo the X=(difference)/Y=(sum)
      // opsin channel encoding — confirm against the opsin definition.
      const auto r = Sub(iny, inx);
      const auto g = Add(iny, inx);
      const auto ratio_r =
          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, r);
      const auto ratio_g =
          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, g);
      const auto avg_ratio = Mul(half, Add(ratio_r, ratio_g));

      overall_ratio = Add(overall_ratio, avg_ratio);
    }
  }
  // Average over the 64 pixels of the block.
  overall_ratio = Mul(SumOfLanes(d, overall_ratio), Set(d, 1.0f / 64));
  // ideally -1.0, but likely optimal correction adds some entropy, so slightly
  // less than that.
  // ln(2) constant folded in because we want std::log but have FastLog2f.
  static const float v = 0.14507933746197058f;
  const auto kGam = Set(d, v * 0.693147180559945f);
  return MulAdd(kGam, FastLog2f(d, overall_ratio), out_val);
}
    199 
// Change precision in 8x8 blocks that have high frequency content.
// Accumulates the absolute differences of each pixel with its right and
// bottom neighbours (each difference clamped to at most `valmin`, so strong
// edges saturate), then maps the block sum affinely onto the exponent.
template <class D, class V>
V HfModulation(const D d, const size_t x, const size_t y, const ImageF& xyb,
               const Rect& rect, const V out_val) {
  // Zero out the invalid differences for the rightmost value per row.
  const Rebind<uint32_t, D> du;
  HWY_ALIGN constexpr uint32_t kMaskRight[kBlockDim] = {~0u, ~0u, ~0u, ~0u,
                                                        ~0u, ~0u, ~0u, 0};

  auto sum = Zero(d);  // sum of absolute differences with right and below

  static const float valmin = 0.020602694503245016f;
  auto valminv = Set(d, valmin);
  for (size_t dy = 0; dy < 8; ++dy) {
    const float* JXL_RESTRICT row_in = rect.ConstRow(xyb, y + dy) + x;
    // Bottom row of the block compares against itself (zero difference).
    const float* JXL_RESTRICT row_in_next =
        dy == 7 ? row_in : rect.ConstRow(xyb, y + dy + 1) + x;

    // In SCALAR, there is no guarantee of having extra row padding.
    // Hence, we need to ensure we don't access pixels outside the row itself.
    // In SIMD modes, however, rows are padded, so it's safe to access one
    // garbage value after the row. The vector then gets masked with kMaskRight
    // to remove the influence of that value.
#if HWY_TARGET != HWY_SCALAR
    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
#else
    for (size_t dx = 0; dx < 7; dx += Lanes(d)) {
#endif
      const auto p = Load(d, row_in + dx);
      const auto pr = LoadU(d, row_in + dx + 1);
      const auto mask = BitCast(d, Load(du, kMaskRight + dx));
      sum = Add(sum, And(mask, Min(valminv, AbsDiff(p, pr))));

      const auto pd = Load(d, row_in_next + dx);
      sum = Add(sum, Min(valminv, AbsDiff(p, pd)));
    }
#if HWY_TARGET == HWY_SCALAR
    // Last column: only the vertical difference is valid.
    const auto p = Load(d, row_in + 7);
    const auto pd = Load(d, row_in_next + 7);
    sum = Add(sum, Min(valminv, AbsDiff(p, pd)));
#endif
  }
  // more negative value gives more bpp
  static const float kOffset = -1.110929106987477;
  static const float kMul = -0.38078920620238305;
  sum = SumOfLanes(d, sum);
  float scalar_sum = GetLane(sum);
  scalar_sum += kOffset;
  scalar_sum *= kMul;
  // Broadcast the scalar adjustment back across the lanes of out_val.
  return Add(Set(d, scalar_sum), out_val);
}
    251 
// Converts the per-block values in `out` (filled previously from the erosion
// pass) into the final multiplicative quantization field: the mask, HF and
// gamma modulations are applied in the exponent domain and then
// exponentiated. Note: xyb_b is accepted but not read in this function.
void PerBlockModulations(const float butteraugli_target, const ImageF& xyb_x,
                         const ImageF& xyb_y, const ImageF& xyb_b,
                         const Rect& rect_in, const float scale,
                         const Rect& rect_out, ImageF* out) {
  float base_level = 0.48f * scale;
  float kDampenRampStart = 2.0f;
  float kDampenRampEnd = 14.0f;
  float dampen = 1.0f;
  // dampen ramps linearly from 1 at butteraugli_target == 2 down to 0 at 14,
  // fading the modulations toward the flat base_level at low quality.
  if (butteraugli_target >= kDampenRampStart) {
    dampen = 1.0f - ((butteraugli_target - kDampenRampStart) /
                     (kDampenRampEnd - kDampenRampStart));
    if (dampen < 0) {
      dampen = 0;
    }
  }
  const float mul = scale * dampen;
  const float add = (1.0f - dampen) * base_level;
  for (size_t iy = rect_out.y0(); iy < rect_out.y1(); iy++) {
    const size_t y = iy * 8;  // block -> pixel coordinates
    float* const JXL_RESTRICT row_out = out->Row(iy);
    const HWY_CAPPED(float, kBlockDim) df;
    for (size_t ix = rect_out.x0(); ix < rect_out.x1(); ix++) {
      size_t x = ix * 8;
      auto out_val = Set(df, row_out[ix]);
      out_val = ComputeMask(df, out_val);
      out_val = HfModulation(df, x, y, xyb_y, rect_in, out_val);
      out_val = GammaModulation(df, x, y, xyb_x, xyb_y, rect_in, out_val);
      // We want multiplicative quantization field, so everything
      // until this point has been modulating the exponent.
      // 1.442695041 = 1/ln(2): converts the natural-log-domain exponent to
      // base 2 for FastPow2f.
      row_out[ix] = FastPow2f(GetLane(out_val) * 1.442695041f) * mul + add;
    }
  }
}
    285 
// Compressive nonlinearity for local differences:
// 0.25 * sqrt(sqrt(kMul * 1e8) * v + kLogOffset).
template <typename D, typename V>
V MaskingSqrt(const D d, V v) {
  static const float kLogOffset = 27.97044946785558f;
  static const float kMul = 211.53333281566171f;
  const auto mul_v = Set(d, kMul * 1e8);
  const auto offset_v = Set(d, kLogOffset);
  return Mul(Set(d, 0.25f), Sqrt(MulAdd(v, Sqrt(mul_v), offset_v)));
}
    294 
// Scalar wrapper: evaluates the SIMD MaskingSqrt on a single lane.
float MaskingSqrt(const float v) {
  using DScalar = HWY_CAPPED(float, 1);
  auto vscalar = Load(DScalar(), &v);
  return GetLane(MaskingSqrt(DScalar(), vscalar));
}
    300 
    301 void StoreMin4(const float v, float& min0, float& min1, float& min2,
    302                float& min3) {
    303   if (v < min3) {
    304     if (v < min0) {
    305       min3 = min2;
    306       min2 = min1;
    307       min1 = min0;
    308       min0 = v;
    309     } else if (v < min1) {
    310       min3 = min2;
    311       min2 = min1;
    312       min1 = v;
    313     } else if (v < min2) {
    314       min3 = min2;
    315       min2 = v;
    316     } else {
    317       min3 = v;
    318     }
    319   }
    320 }
    321 
// Look for smooth areas near the area of degradation.
// If the areas are generally smooth, don't do masking.
// Output is downsampled 2x.
// Each source pixel is replaced by a weighted sum of the four smallest values
// in its 3x3 neighbourhood; each output pixel then accumulates the four
// source pixels of its 2x2 cell.
void FuzzyErosion(const float butteraugli_target, const Rect& from_rect,
                  const ImageF& from, const Rect& to_rect, ImageF* to) {
  const size_t xsize = from.xsize();
  const size_t ysize = from.ysize();
  constexpr int kStep = 1;
  static_assert(kStep == 1, "Step must be 1");
  JXL_ASSERT(to_rect.xsize() * 2 == from_rect.xsize());
  JXL_ASSERT(to_rect.ysize() * 2 == from_rect.ysize());
  static const float kMulBase0 = 0.125;
  static const float kMulBase1 = 0.10;
  static const float kMulBase2 = 0.09;
  static const float kMulBase3 = 0.06;
  static const float kMulAdd0 = 0.0;
  static const float kMulAdd1 = -0.10;
  static const float kMulAdd2 = -0.09;
  static const float kMulAdd3 = -0.06;

  // mul interpolates the per-rank weights toward kMulBase + kMulAdd as the
  // quality target approaches butteraugli distance 0.
  float mul = 0.0;
  if (butteraugli_target < 2.0f) {
    mul = (2.0f - butteraugli_target) * (1.0f / 2.0f);
  }
  float kMul0 = kMulBase0 + mul * kMulAdd0;
  float kMul1 = kMulBase1 + mul * kMulAdd1;
  float kMul2 = kMulBase2 + mul * kMulAdd2;
  float kMul3 = kMulBase3 + mul * kMulAdd3;
  // Renormalize so the four weights always sum to kTotal.
  static const float kTotal = 0.29959705784054957;
  float norm = kTotal / (kMul0 + kMul1 + kMul2 + kMul3);
  kMul0 *= norm;
  kMul1 *= norm;
  kMul2 *= norm;
  kMul3 *= norm;

  for (size_t fy = 0; fy < from_rect.ysize(); ++fy) {
    size_t y = fy + from_rect.y0();
    // Clamp neighbour rows at the image borders.
    size_t ym1 = y >= kStep ? y - kStep : y;
    size_t yp1 = y + kStep < ysize ? y + kStep : y;
    const float* rowt = from.Row(ym1);
    const float* row = from.Row(y);
    const float* rowb = from.Row(yp1);
    float* row_out = to_rect.Row(to, fy / 2);
    for (size_t fx = 0; fx < from_rect.xsize(); ++fx) {
      size_t x = fx + from_rect.x0();
      size_t xm1 = x >= kStep ? x - kStep : x;
      size_t xp1 = x + kStep < xsize ? x + kStep : x;
      float min0 = row[x];
      float min1 = row[xm1];
      float min2 = row[xp1];
      float min3 = rowt[xm1];
      // Sort the first four values.
      if (min0 > min1) std::swap(min0, min1);
      if (min0 > min2) std::swap(min0, min2);
      if (min0 > min3) std::swap(min0, min3);
      if (min1 > min2) std::swap(min1, min2);
      if (min1 > min3) std::swap(min1, min3);
      if (min2 > min3) std::swap(min2, min3);
      // The remaining five values of a 3x3 neighbourhood.
      // After these, min0..min3 are the four smallest of the nine, ascending.
      StoreMin4(rowt[x], min0, min1, min2, min3);
      StoreMin4(rowt[xp1], min0, min1, min2, min3);
      StoreMin4(rowb[xm1], min0, min1, min2, min3);
      StoreMin4(rowb[x], min0, min1, min2, min3);
      StoreMin4(rowb[xp1], min0, min1, min2, min3);

      float v = kMul0 * min0 + kMul1 * min1 + kMul2 * min2 + kMul3 * min3;
      // 2x downsampling: the top-left pixel of each 2x2 cell initializes the
      // output, the other three accumulate into it.
      if (fx % 2 == 0 && fy % 2 == 0) {
        row_out[fx / 2] = v;
      } else {
        row_out[fx / 2] += v;
      }
    }
  }
}
    396 
// Per-frame state for computing the adaptive quantization map tile by tile.
struct AdaptiveQuantizationImpl {
  // Allocates the per-thread scratch images: one shared row buffer of local
  // differences (one row per thread) and one pre-erosion image per thread.
  // Only grows pre_erosion; existing entries are kept.
  Status PrepareBuffers(size_t num_threads) {
    JXL_ASSIGN_OR_RETURN(diff_buffer,
                         ImageF::Create(kEncTileDim + 8, num_threads));
    for (size_t i = pre_erosion.size(); i < num_threads; i++) {
      JXL_ASSIGN_OR_RETURN(ImageF tmp,
                           ImageF::Create(kEncTileDimInBlocks * 2 + 2,
                                          kEncTileDimInBlocks * 2 + 2));
      pre_erosion.emplace_back(std::move(tmp));
    }
    return true;
  }

  // Computes, for one tile, the 1x1 masking image (mask1x1), the per-block
  // AC-strategy mask (mask) and the per-block quantization field (aq_map).
  // rect_in is the pixel rect of the frame area, rect_out the block rect of
  // the tile inside the per-block output images.
  void ComputeTile(float butteraugli_target, float scale, const Image3F& xyb,
                   const Rect& rect_in, const Rect& rect_out, const int thread,
                   ImageF* mask, ImageF* mask1x1) {
    JXL_ASSERT(rect_in.x0() % 8 == 0);
    JXL_ASSERT(rect_in.y0() % 8 == 0);
    const size_t xsize = xyb.xsize();
    const size_t ysize = xyb.ysize();

    // The XYB gamma is 3.0 to be able to decode faster with two muls.
    // Butteraugli's gamma is matching the gamma of human eye, around 2.6.
    // We approximate the gamma difference by adding one cubic root into
    // the adaptive quantization. This gives us a total gamma of 2.6666
    // for quantization uses.
    const float match_gamma_offset = 0.019;

    const HWY_FULL(float) df;

    // Pixel bounds of this tile inside the frame.
    size_t y_start_1x1 = rect_in.y0() + rect_out.y0() * 8;
    size_t y_end_1x1 = y_start_1x1 + rect_out.ysize() * 8;

    size_t x_start_1x1 = rect_in.x0() + rect_out.x0() * 8;
    size_t x_end_1x1 = x_start_1x1 + rect_out.xsize() * 8;

    // Tiles on the interior edge of the frame extend 2 pixels so the later
    // blur has valid data near tile seams.
    if (rect_in.x0() != 0 && rect_out.x0() == 0) x_start_1x1 -= 2;
    if (rect_in.x1() < xsize && rect_out.x1() * 8 == rect_in.xsize()) {
      x_end_1x1 += 2;
    }
    if (rect_in.y0() != 0 && rect_out.y0() == 0) y_start_1x1 -= 2;
    if (rect_in.y1() < ysize && rect_out.y1() * 8 == rect_in.ysize()) {
      y_end_1x1 += 2;
    }

    // Computes image (padded to multiple of 8x8) of local pixel differences.
    // Subsample both directions by 4.
    // 1x1 Laplacian of intensity.
    for (size_t y = y_start_1x1; y < y_end_1x1; ++y) {
      // Neighbour rows clamped at the image border.
      const size_t y2 = y + 1 < ysize ? y + 1 : y;
      const size_t y1 = y > 0 ? y - 1 : y;
      const float* row_in = xyb.ConstPlaneRow(1, y);
      const float* row_in1 = xyb.ConstPlaneRow(1, y1);
      const float* row_in2 = xyb.ConstPlaneRow(1, y2);
      float* mask1x1_out = mask1x1->Row(y);
      auto scalar_pixel1x1 = [&](size_t x) {
        const size_t x2 = x + 1 < xsize ? x + 1 : x;
        const size_t x1 = x > 0 ? x - 1 : x;
        // 4-neighbour average; row_in[x] - base is a Laplacian-style
        // difference of the intensity (Y) channel.
        const float base =
            0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]);
        const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma(
            row_in[x] + match_gamma_offset);
        float diff = fabs(gammac * (row_in[x] - base));
        static const double kScaler = 1.0;
        diff *= kScaler;
        diff = log1p(diff);
        static const float kMul = 1.0;
        static const float kOffset = 0.01;
        // Reciprocal response: flat areas get large mask values.
        mask1x1_out[x] = kMul / (diff + kOffset);
      };
      for (size_t x = x_start_1x1; x < x_end_1x1; ++x) {
        scalar_pixel1x1(x);
      }
    }

    size_t y_start = rect_in.y0() + rect_out.y0() * 8;
    size_t y_end = y_start + rect_out.ysize() * 8;

    size_t x_start = rect_in.x0() + rect_out.x0() * 8;
    size_t x_end = x_start + rect_out.xsize() * 8;

    // Extend by 4 pixels (one pre-erosion sample) where neighbours exist.
    if (x_start != 0) x_start -= 4;
    if (x_end != xsize) x_end += 4;
    if (y_start != 0) y_start -= 4;
    if (y_end != ysize) y_end += 4;
    pre_erosion[thread].ShrinkTo((x_end - x_start) / 4, (y_end - y_start) / 4);

    static const float limit = 0.2f;
    for (size_t y = y_start; y < y_end; ++y) {
      size_t y2 = y + 1 < ysize ? y + 1 : y;
      size_t y1 = y > 0 ? y - 1 : y;

      const float* row_in = xyb.ConstPlaneRow(1, y);
      const float* row_in1 = xyb.ConstPlaneRow(1, y1);
      const float* row_in2 = xyb.ConstPlaneRow(1, y2);
      float* JXL_RESTRICT row_out = diff_buffer.Row(thread);

      // Scalar fallback, also used for the row edges; must match the SIMD
      // path below.
      auto scalar_pixel = [&](size_t x) {
        const size_t x2 = x + 1 < xsize ? x + 1 : x;
        const size_t x1 = x > 0 ? x - 1 : x;
        const float base =
            0.25f * (row_in2[x] + row_in1[x] + row_in[x1] + row_in[x2]);
        const float gammac = RatioOfDerivativesOfCubicRootToSimpleGamma(
            row_in[x] + match_gamma_offset);
        float diff = gammac * (row_in[x] - base);
        diff *= diff;
        if (diff >= limit) {
          diff = limit;
        }
        diff = MaskingSqrt(diff);
        // Rows are accumulated in groups of 4 for the 4x subsampling.
        if ((y % 4) != 0) {
          row_out[x - x_start] += diff;
        } else {
          row_out[x - x_start] = diff;
        }
      };

      size_t x = x_start;
      // First pixel of the row.
      if (x_start == 0) {
        scalar_pixel(x_start);
        ++x;
      }
      // SIMD
      const auto match_gamma_offset_v = Set(df, match_gamma_offset);
      const auto quarter = Set(df, 0.25f);
      // x + 1 + Lanes(df) <= x_end guarantees the unaligned x+1 load stays
      // in range.
      for (; x + 1 + Lanes(df) < x_end; x += Lanes(df)) {
        const auto in = LoadU(df, row_in + x);
        const auto in_r = LoadU(df, row_in + x + 1);
        const auto in_l = LoadU(df, row_in + x - 1);
        const auto in_t = LoadU(df, row_in2 + x);
        const auto in_b = LoadU(df, row_in1 + x);
        auto base = Mul(quarter, Add(Add(in_r, in_l), Add(in_t, in_b)));
        auto gammacv =
            RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/false>(
                df, Add(in, match_gamma_offset_v));
        auto diff = Mul(gammacv, Sub(in, base));
        diff = Mul(diff, diff);
        diff = Min(diff, Set(df, limit));
        diff = MaskingSqrt(df, diff);
        if ((y & 3) != 0) {
          diff = Add(diff, LoadU(df, row_out + x - x_start));
        }
        StoreU(diff, df, row_out + x - x_start);
      }
      // Scalar
      for (; x < x_end; ++x) {
        scalar_pixel(x);
      }
      // After every 4th row, box-downsample the accumulated row by 4 in x
      // into the pre-erosion image.
      if (y % 4 == 3) {
        float* row_dout = pre_erosion[thread].Row((y - y_start) / 4);
        for (size_t x = 0; x < (x_end - x_start) / 4; x++) {
          row_dout[x] = (row_out[x * 4] + row_out[x * 4 + 1] +
                         row_out[x * 4 + 2] + row_out[x * 4 + 3]) *
                        0.25f;
        }
      }
    }
    // Skip the 1-sample border of pre_erosion when the tile was extended.
    Rect from_rect(x_start % 8 == 0 ? 0 : 1, y_start % 8 == 0 ? 0 : 1,
                   rect_out.xsize() * 2, rect_out.ysize() * 2);
    FuzzyErosion(butteraugli_target, from_rect, pre_erosion[thread], rect_out,
                 &aq_map);
    for (size_t y = 0; y < rect_out.ysize(); ++y) {
      const float* aq_map_row = rect_out.ConstRow(aq_map, y);
      float* mask_row = rect_out.Row(mask, y);
      for (size_t x = 0; x < rect_out.xsize(); ++x) {
        mask_row[x] = ComputeMaskForAcStrategyUse(aq_map_row[x]);
      }
    }
    PerBlockModulations(butteraugli_target, xyb.Plane(0), xyb.Plane(1),
                        xyb.Plane(2), rect_in, scale, rect_out, &aq_map);
  }
  std::vector<ImageF> pre_erosion;  // per-thread 4x-downsampled differences
  ImageF aq_map;                    // per-block quantization field (output)
  ImageF diff_buffer;               // one accumulation row per thread
};
    573 
Status Blur1x1Masking(ThreadPool* pool, ImageF* mask1x1, const Rect& rect) {
  // Blur the mask1x1 to obtain the masking image.
  // Before blurring it contains an image of absolute value of the
  // Laplacian of the intensity channel.
  // The unique taps of a symmetric 5x5 kernel; the center tap is 1 before
  // normalization.
  static const float kFilterMask1x1[5] = {
      static_cast<float>(0.25647067633737227),
      static_cast<float>(0.2050056912354399075),
      static_cast<float>(0.154082048668497307),
      static_cast<float>(0.08149576591362004441),
      static_cast<float>(0.0512750104812308467),
  };
  // Total kernel mass: center + 4 occurrences of each tap, with tap [3]
  // occurring 8 times (hence the extra factor 2).
  double sum =
      1.0 + 4 * (kFilterMask1x1[0] + kFilterMask1x1[1] + kFilterMask1x1[2] +
                 kFilterMask1x1[4] + 2 * kFilterMask1x1[3]);
  if (sum < 1e-5) {
    sum = 1e-5;
  }
  const float normalize = static_cast<float>(1.0 / sum);
  const float normalize_mul = normalize;
  // NOTE(review): taps are assigned in the order [0], [2], [1], [4], [3];
  // this is presumably matched to the WeightsSymmetric5 field layout
  // (center/ring slots) — confirm against the struct definition in convolve.h.
  WeightsSymmetric5 weights =
      WeightsSymmetric5{{HWY_REP4(normalize)},
                        {HWY_REP4(normalize_mul * kFilterMask1x1[0])},
                        {HWY_REP4(normalize_mul * kFilterMask1x1[2])},
                        {HWY_REP4(normalize_mul * kFilterMask1x1[1])},
                        {HWY_REP4(normalize_mul * kFilterMask1x1[4])},
                        {HWY_REP4(normalize_mul * kFilterMask1x1[3])}};
  JXL_ASSIGN_OR_RETURN(ImageF temp, ImageF::Create(rect.xsize(), rect.ysize()));
  Symmetric5(*mask1x1, rect, weights, pool, &temp);
  *mask1x1 = std::move(temp);
  return true;
}
    605 
// Computes the per-block adaptive quantization map for the frame area `rect`
// of `xyb`, filling `mask` (per-block AC-strategy mask) and `mask1x1`
// (per-pixel masking image, blurred at the end). Tiles of
// kEncTileDimInBlocks^2 blocks are processed in parallel on `pool`.
StatusOr<ImageF> AdaptiveQuantizationMap(const float butteraugli_target,
                                         const Image3F& xyb, const Rect& rect,
                                         float scale, ThreadPool* pool,
                                         ImageF* mask, ImageF* mask1x1) {
  JXL_DASSERT(rect.xsize() % kBlockDim == 0);
  JXL_DASSERT(rect.ysize() % kBlockDim == 0);
  AdaptiveQuantizationImpl impl;
  const size_t xsize_blocks = rect.xsize() / kBlockDim;
  const size_t ysize_blocks = rect.ysize() / kBlockDim;
  JXL_ASSIGN_OR_RETURN(impl.aq_map, ImageF::Create(xsize_blocks, ysize_blocks));
  JXL_ASSIGN_OR_RETURN(*mask, ImageF::Create(xsize_blocks, ysize_blocks));
  JXL_ASSIGN_OR_RETURN(*mask1x1, ImageF::Create(xyb.xsize(), xyb.ysize()));
  JXL_CHECK(RunOnPool(
      pool, 0,
      DivCeil(xsize_blocks, kEncTileDimInBlocks) *
          DivCeil(ysize_blocks, kEncTileDimInBlocks),
      [&](const size_t num_threads) {
        return !!impl.PrepareBuffers(num_threads);
      },
      [&](const uint32_t tid, const size_t thread) {
        // Map the linear task id to a tile and clamp it to the block grid.
        size_t n_enc_tiles = DivCeil(xsize_blocks, kEncTileDimInBlocks);
        size_t tx = tid % n_enc_tiles;
        size_t ty = tid / n_enc_tiles;
        size_t by0 = ty * kEncTileDimInBlocks;
        size_t by1 = std::min((ty + 1) * kEncTileDimInBlocks, ysize_blocks);
        size_t bx0 = tx * kEncTileDimInBlocks;
        size_t bx1 = std::min((tx + 1) * kEncTileDimInBlocks, xsize_blocks);
        Rect rect_out(bx0, by0, bx1 - bx0, by1 - by0);
        impl.ComputeTile(butteraugli_target, scale, xyb, rect, rect_out, thread,
                         mask, mask1x1);
      },
      "AQ DiffPrecompute"));

  JXL_RETURN_IF_ERROR(Blur1x1Masking(pool, mask1x1, rect));
  return std::move(impl).aq_map;
}
    642 
    643 }  // namespace
    644 
    645 // NOLINTNEXTLINE(google-readability-namespace-comments)
    646 }  // namespace HWY_NAMESPACE
    647 }  // namespace jxl
    648 HWY_AFTER_NAMESPACE();
    649 
    650 #if HWY_ONCE
    651 namespace jxl {
    652 HWY_EXPORT(AdaptiveQuantizationMap);
    653 
    654 namespace {
    655 
    656 // If true, prints the quantization maps at each iteration.
    657 constexpr bool FLAGS_dump_quant_state = false;
    658 
// Debug-only helper: renders `image` as a heat map and dumps it under
// "<label><iteration>". Compiled to a no-op unless
// JXL_DEBUG_ADAPTIVE_QUANTIZATION is set.
Status DumpHeatmap(const CompressParams& cparams, const AuxOut* aux_out,
                   const std::string& label, const ImageF& image,
                   float good_threshold, float bad_threshold) {
  if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
    JXL_ASSIGN_OR_RETURN(
        Image3F heatmap,
        CreateHeatMapImage(image, good_threshold, bad_threshold));
    char filename[200];
    // Filename encodes the butteraugli iteration number for this dump.
    snprintf(filename, sizeof(filename), "%s%05d", label.c_str(),
             aux_out->num_butteraugli_iters);
    JXL_RETURN_IF_ERROR(DumpImage(cparams, filename, heatmap));
  }
  return true;
}
    673 
// Debug-only helper: dumps heat maps of the inverse quantization field, the
// per-tile distance map and the butteraugli diff map. No-op unless
// JXL_DEBUG_ADAPTIVE_QUANTIZATION is set and debug output is requested.
Status DumpHeatmaps(const CompressParams& cparams, const AuxOut* aux_out,
                    float ba_target, const ImageF& quant_field,
                    const ImageF& tile_heatmap, const ImageF& bt_diffmap) {
  if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
    if (!WantDebugOutput(cparams)) return true;
    JXL_ASSIGN_OR_RETURN(ImageF inv_qmap, ImageF::Create(quant_field.xsize(),
                                                         quant_field.ysize()));
    for (size_t y = 0; y < quant_field.ysize(); ++y) {
      const float* JXL_RESTRICT row_q = quant_field.ConstRow(y);
      float* JXL_RESTRICT row_inv_q = inv_qmap.Row(y);
      for (size_t x = 0; x < quant_field.xsize(); ++x) {
        row_inv_q[x] = 1.0f / row_q[x];  // never zero
      }
    }
    JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "quant_heatmap", inv_qmap,
                                    4.0f * ba_target, 6.0f * ba_target));
    JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "tile_heatmap",
                                    tile_heatmap, ba_target, 1.5f * ba_target));
    // matches heat maps produced by the command line tool.
    JXL_RETURN_IF_ERROR(DumpHeatmap(cparams, aux_out, "bt_diffmap", bt_diffmap,
                                    ButteraugliFuzzyInverse(1.5),
                                    ButteraugliFuzzyInverse(0.5)));
  }
  return true;
}
    699 
// Aggregates the per-pixel distance map `distmap` into one value per varblock
// of `ac_strategy`, using a 16th-power mean over the pixels the varblock
// covers (plus `margin` pixels of down-weighted border), and broadcasts that
// value to every `tile_size`-sized tile of the varblock.
StatusOr<ImageF> TileDistMap(const ImageF& distmap, int tile_size, int margin,
                             const AcStrategyImage& ac_strategy) {
  const int tile_xsize = (distmap.xsize() + tile_size - 1) / tile_size;
  const int tile_ysize = (distmap.ysize() + tile_size - 1) / tile_size;
  JXL_ASSIGN_OR_RETURN(ImageF tile_distmap,
                       ImageF::Create(tile_xsize, tile_ysize));
  size_t distmap_stride = tile_distmap.PixelsPerRow();
  for (int tile_y = 0; tile_y < tile_ysize; ++tile_y) {
    AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(tile_y);
    float* JXL_RESTRICT dist_row = tile_distmap.Row(tile_y);
    for (int tile_x = 0; tile_x < tile_xsize; ++tile_x) {
      AcStrategy acs = ac_strategy_row[tile_x];
      // Process each varblock exactly once, from its top-left block.
      if (!acs.IsFirstBlock()) continue;
      int this_tile_xsize = acs.covered_blocks_x() * tile_size;
      int this_tile_ysize = acs.covered_blocks_y() * tile_size;
      // Pixel bounds of the varblock, extended by `margin` and clamped to the
      // distmap dimensions.
      int y_begin = std::max<int>(0, tile_size * tile_y - margin);
      int y_end = std::min<int>(distmap.ysize(),
                                tile_size * tile_y + this_tile_ysize + margin);
      int x_begin = std::max<int>(0, tile_size * tile_x - margin);
      int x_end = std::min<int>(distmap.xsize(),
                                tile_size * tile_x + this_tile_xsize + margin);
      float dist_norm = 0.0;
      double pixels = 0;
      for (int y = y_begin; y < y_end; ++y) {
        float ymul = 1.0;
        // Edge pixels of the margin-extended region get reduced weight;
        // pixels on both a horizontal and a vertical edge (corners) even less.
        constexpr float kBorderMul = 0.98f;
        constexpr float kCornerMul = 0.7f;
        if (margin != 0 && (y == y_begin || y == y_end - 1)) {
          ymul = kBorderMul;
        }
        const float* const JXL_RESTRICT row = distmap.Row(y);
        for (int x = x_begin; x < x_end; ++x) {
          float xmul = ymul;
          if (margin != 0 && (x == x_begin || x == x_end - 1)) {
            if (xmul == 1.0) {
              xmul = kBorderMul;
            } else {
              xmul = kCornerMul;
            }
          }
          float v = row[x];
          // Four successive squarings: v^16 for the 16th-power mean below.
          v *= v;
          v *= v;
          v *= v;
          v *= v;
          dist_norm += xmul * v;
          pixels += xmul;
        }
      }
      // Guard against an empty region (division by zero below).
      if (pixels == 0) pixels = 1;
      // 16th norm is less than the max norm, we reduce the difference
      // with this normalization factor.
      constexpr float kTileNorm = 1.2f;
      const float tile_dist =
          kTileNorm * std::pow(dist_norm / pixels, 1.0f / 16.0f);
      dist_row[tile_x] = tile_dist;
      // Broadcast the value to every tile the varblock covers; rows below the
      // current one are reached through the row stride.
      for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
        for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
          dist_row[tile_x + distmap_stride * iy + ix] = tile_dist;
        }
      }
    }
  }
  return tile_distmap;
}
    765 
// Exponent used by InitialQuantDC to map the butteraugli target to the
// non-linear part of the DC quantization value.
const float kDcQuantPow = 0.83f;
// Numerator of the DC quant value returned by InitialQuantDC.
const float kDcQuant = 1.095924047623553f;
// Numerator of the base AC quant value used by InitialQuantField.
const float kAcQuant = 0.7381485255235064f;
    769 
// Computes the decoded image for a given set of compression parameters.
// Encodes `opsin` with the current state in `enc_state` and decodes it again
// group by group, so callers can measure the distortion the current
// quantization field actually produces. Any special frames appended during
// the encode are dropped again before returning.
StatusOr<ImageBundle> RoundtripImage(const FrameHeader& frame_header,
                                     const Image3F& opsin,
                                     PassesEncoderState* enc_state,
                                     const JxlCmsInterface& cms,
                                     ThreadPool* pool) {
  std::unique_ptr<PassesDecoderState> dec_state =
      jxl::make_unique<PassesDecoderState>();
  JXL_CHECK(dec_state->output_encoding_info.SetFromMetadata(
      *enc_state->shared.metadata));
  dec_state->shared = &enc_state->shared;
  JXL_ASSERT(opsin.ysize() % kBlockDim == 0);

  const size_t xsize_groups = DivCeil(opsin.xsize(), kGroupDim);
  const size_t ysize_groups = DivCeil(opsin.ysize(), kGroupDim);
  const size_t num_groups = xsize_groups * ysize_groups;

  // Remember the count so frames appended below can be truncated at the end.
  size_t num_special_frames = enc_state->special_frames.size();
  size_t num_passes = enc_state->progressive_splitter.GetNumPasses();
  ModularFrameEncoder modular_frame_encoder(frame_header, enc_state->cparams,
                                            false);
  JXL_CHECK(InitializePassesEncoder(frame_header, opsin, Rect(opsin), cms, pool,
                                    enc_state, &modular_frame_encoder,
                                    nullptr));
  JXL_CHECK(dec_state->Init(frame_header));
  JXL_CHECK(dec_state->InitForAC(num_passes, pool));

  ImageBundle decoded(&enc_state->shared.metadata->m);
  decoded.origin = frame_header.frame_origin;
  JXL_ASSIGN_OR_RETURN(Image3F tmp,
                       Image3F::Create(opsin.xsize(), opsin.ysize()));
  decoded.SetFromImage(std::move(tmp),
                       dec_state->output_encoding_info.color_encoding);

  // Minimal pipeline configuration for the roundtrip measurement.
  PassesDecoderState::PipelineOptions options;
  options.use_slow_render_pipeline = false;
  options.coalescing = false;
  options.render_spotcolors = false;
  options.render_noise = false;

  // Same as frame_header.nonserialized_metadata->m
  const ImageMetadata& metadata = *decoded.metadata();

  JXL_CHECK(dec_state->PreparePipeline(frame_header, &decoded, options));

  hwy::AlignedUniquePtr<GroupDecCache[]> group_dec_caches;
  // Runs once before the group loop: prepares the render pipeline and the
  // per-thread decode caches for `num_threads` workers.
  const auto allocate_storage = [&](const size_t num_threads) -> Status {
    JXL_RETURN_IF_ERROR(
        dec_state->render_pipeline->PrepareForThreads(num_threads,
                                                      /*use_group_ids=*/false));
    group_dec_caches = hwy::MakeUniqueAlignedArray<GroupDecCache>(num_threads);
    return true;
  };
  std::atomic<bool> has_error{false};
  // Decodes one group into `decoded`; extra channels are simply zero-filled.
  const auto process_group = [&](const uint32_t group_index,
                                 const size_t thread) {
    if (has_error) return;
    if (frame_header.loop_filter.epf_iters > 0) {
      ComputeSigma(frame_header.loop_filter,
                   dec_state->shared->frame_dim.BlockGroupRect(group_index),
                   dec_state.get());
    }
    RenderPipelineInput input =
        dec_state->render_pipeline->GetInputBuffers(group_index, thread);
    JXL_CHECK(DecodeGroupForRoundtrip(
        frame_header, enc_state->coeffs, group_index, dec_state.get(),
        &group_dec_caches[thread], thread, input, &decoded, nullptr));
    for (size_t c = 0; c < metadata.num_extra_channels; c++) {
      std::pair<ImageF*, Rect> ri = input.GetBuffer(3 + c);
      FillPlane(0.0f, ri.first, ri.second);
    }
    if (!input.Done()) {
      has_error = true;
      return;
    }
  };
  JXL_CHECK(RunOnPool(pool, 0, num_groups, allocate_storage, process_group,
                      "AQ loop"));
  if (has_error) return JXL_FAILURE("AQ loop failure");

  // Ensure we don't create any new special frames.
  enc_state->special_frames.resize(num_special_frames);

  return decoded;
}
    855 
// Iteration cap for the quantization-search loops below; FindBestQuantization
// lowers it to 2 for any speed tier other than kTortoise.
constexpr int kMaxButteraugliIters = 4;
    857 
// Iteratively refines `quant_field` toward the butteraugli target: each
// iteration commits the field, round-trips the image, builds a per-tile
// distance map against the linear reference, and scales the field up wherever
// a tile's distance exceeds the target (optionally softening tiles that are
// better than the target). The final field is committed to the quantizer.
Status FindBestQuantization(const FrameHeader& frame_header,
                            const Image3F& linear, const Image3F& opsin,
                            ImageF& quant_field, PassesEncoderState* enc_state,
                            const JxlCmsInterface& cms, ThreadPool* pool,
                            AuxOut* aux_out) {
  const CompressParams& cparams = enc_state->cparams;
  if (cparams.resampling > 1 &&
      cparams.original_butteraugli_distance <= 4.0 * cparams.resampling) {
    // For downsampled opsin image, the butteraugli based adaptive quantization
    // loop would only make the size bigger without improving the distance much,
    // so in this case we enable it only for very high butteraugli targets.
    return true;
  }
  Quantizer& quantizer = enc_state->shared.quantizer;
  ImageI& raw_quant_field = enc_state->shared.raw_quant_field;

  const float butteraugli_target = cparams.butteraugli_distance;
  const float original_butteraugli = cparams.original_butteraugli_distance;
  ButteraugliParams params;
  params.intensity_target = 80.f;
  JxlButteraugliComparator comparator(params, cms);
  JXL_CHECK(comparator.SetLinearReferenceImage(linear));
  // Orientation of the comparator's score axis; flipped below so that lower
  // scores always mean better quality.
  bool lower_is_better =
      (comparator.GoodQualityScore() < comparator.BadQualityScore());
  const float initial_quant_dc = InitialQuantDC(butteraugli_target);
  AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field),
                   original_butteraugli, &quant_field);
  ImageF tile_distmap;
  JXL_ASSIGN_OR_RETURN(
      ImageF initial_quant_field,
      ImageF::Create(quant_field.xsize(), quant_field.ysize()));
  CopyImageTo(quant_field, &initial_quant_field);

  // Derive lower/upper clamps for the field from the initial field's range,
  // keeping the overall qf ratio below 253 (checked by the assert below).
  float initial_qf_min;
  float initial_qf_max;
  ImageMinMax(initial_quant_field, &initial_qf_min, &initial_qf_max);
  float initial_qf_ratio = initial_qf_max / initial_qf_min;
  float qf_max_deviation_low = std::sqrt(250 / initial_qf_ratio);
  // Allow more headroom above the initial maximum than below the minimum.
  float asymmetry = 2;
  if (qf_max_deviation_low < asymmetry) asymmetry = qf_max_deviation_low;
  float qf_lower = initial_qf_min / (asymmetry * qf_max_deviation_low);
  float qf_higher = initial_qf_max * (qf_max_deviation_low / asymmetry);

  JXL_ASSERT(qf_higher / qf_lower < 253);

  constexpr int kOriginalComparisonRound = 1;
  // One extra pass: the final iteration measures only and breaks out before
  // modifying the field again.
  int iters = kMaxButteraugliIters;
  if (cparams.speed_tier != SpeedTier::kTortoise) {
    iters = 2;
  }
  for (int i = 0; i < iters + 1; ++i) {
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
      printf("\nQuantization field:\n");
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          printf(" %.5f", quant_field.Row(y)[x]);
        }
        printf("\n");
      }
    }
    // Commit the current field, round-trip and measure the result.
    quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
    JXL_ASSIGN_OR_RETURN(
        ImageBundle dec_linear,
        RoundtripImage(frame_header, opsin, enc_state, cms, pool));
    float score;
    ImageF diffmap;
    JXL_CHECK(comparator.CompareWith(dec_linear, &diffmap, &score));
    if (!lower_is_better) {
      score = -score;
      ScaleImage(-1.0f, &diffmap);
    }
    JXL_ASSIGN_OR_RETURN(tile_distmap,
                         TileDistMap(diffmap, 8 * cparams.resampling, 0,
                                     enc_state->shared.ac_strategy));
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && WantDebugOutput(cparams)) {
      JXL_RETURN_IF_ERROR(DumpImage(cparams, ("dec" + ToString(i)).c_str(),
                                    *dec_linear.color()));
      JXL_RETURN_IF_ERROR(DumpHeatmaps(cparams, aux_out, butteraugli_target,
                                       quant_field, tile_distmap, diffmap));
    }
    if (aux_out != nullptr) ++aux_out->num_butteraugli_iters;
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION) {
      float minval;
      float maxval;
      ImageMinMax(quant_field, &minval, &maxval);
      printf("\nButteraugli iter: %d/%d\n", i, kMaxButteraugliIters);
      printf("Butteraugli distance: %f  (target = %f)\n", score,
             original_butteraugli);
      printf("quant range: %f ... %f  DC quant: %f\n", minval, maxval,
             initial_quant_dc);
      if (FLAGS_dump_quant_state) {
        quantizer.DumpQuantizationMap(raw_quant_field);
      }
    }

    if (i == iters) break;

    // Per-iteration exponent for softening already-good tiles; kPowMod would
    // add a distance-dependent term (all zeros at present).
    double kPow[8] = {
        0.2, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    };
    double kPowMod[8] = {
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    };
    if (i == kOriginalComparisonRound) {
      // Don't allow optimization to make the quant field a lot worse than
      // what the initial guess was. This allows the AC field to have enough
      // precision to reduce the oscillations due to the dc reconstruction.
      double kInitMul = 0.6;
      const double kOneMinusInitMul = 1.0 - kInitMul;
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        float* const JXL_RESTRICT row_q = quant_field.Row(y);
        const float* const JXL_RESTRICT row_init = initial_quant_field.Row(y);
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          double clamp = kOneMinusInitMul * row_q[x] + kInitMul * row_init[x];
          if (row_q[x] < clamp) {
            row_q[x] = clamp;
            if (row_q[x] > qf_higher) row_q[x] = qf_higher;
            if (row_q[x] < qf_lower) row_q[x] = qf_lower;
          }
        }
      }
    }

    double cur_pow = 0.0;
    if (i < 7) {
      cur_pow = kPow[i] + (original_butteraugli - 1.0) * kPowMod[i];
      if (cur_pow < 0) {
        cur_pow = 0;
      }
    }
    if (cur_pow == 0.0) {
      // Sharpen only: scale the field up where a tile exceeds the target,
      // leaving tiles at or below the target untouched.
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y);
        float* const JXL_RESTRICT row_q = quant_field.Row(y);
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          const float diff = row_dist[x] / original_butteraugli;
          if (diff > 1.0f) {
            float old = row_q[x];
            row_q[x] *= diff;
            // If scaling did not change the rounded integer quant value,
            // nudge the field by one quantizer scale step to make progress.
            int qf_old =
                static_cast<int>(std::lround(old * quantizer.InvGlobalScale()));
            int qf_new = static_cast<int>(
                std::lround(row_q[x] * quantizer.InvGlobalScale()));
            if (qf_old == qf_new) {
              row_q[x] = old + quantizer.Scale();
            }
          }
          if (row_q[x] > qf_higher) row_q[x] = qf_higher;
          if (row_q[x] < qf_lower) row_q[x] = qf_lower;
        }
      }
    } else {
      // Sharpen tiles above the target as above, and additionally soften
      // tiles below the target by pow(diff, cur_pow).
      for (size_t y = 0; y < quant_field.ysize(); ++y) {
        const float* const JXL_RESTRICT row_dist = tile_distmap.Row(y);
        float* const JXL_RESTRICT row_q = quant_field.Row(y);
        for (size_t x = 0; x < quant_field.xsize(); ++x) {
          const float diff = row_dist[x] / original_butteraugli;
          if (diff <= 1.0f) {
            row_q[x] *= std::pow(diff, cur_pow);
          } else {
            float old = row_q[x];
            row_q[x] *= diff;
            int qf_old =
                static_cast<int>(std::lround(old * quantizer.InvGlobalScale()));
            int qf_new = static_cast<int>(
                std::lround(row_q[x] * quantizer.InvGlobalScale()));
            if (qf_old == qf_new) {
              row_q[x] = old + quantizer.Scale();
            }
          }
          if (row_q[x] > qf_higher) row_q[x] = qf_higher;
          if (row_q[x] < qf_lower) row_q[x] = qf_lower;
        }
      }
    }
  }
  quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
  return true;
}
   1037 
// Adjusts `quant_field` so each varblock's maximum per-channel reconstruction
// error stays within cparams.max_error: the image is repeatedly round-tripped
// and each varblock's field value is scaled by the ratio of its observed
// maximum error to the allowed budget.
Status FindBestQuantizationMaxError(const FrameHeader& frame_header,
                                    const Image3F& opsin, ImageF& quant_field,
                                    PassesEncoderState* enc_state,
                                    const JxlCmsInterface& cms,
                                    ThreadPool* pool, AuxOut* aux_out) {
  // TODO(szabadka): Make this work for non-opsin color spaces.
  const CompressParams& cparams = enc_state->cparams;
  Quantizer& quantizer = enc_state->shared.quantizer;
  ImageI& raw_quant_field = enc_state->shared.raw_quant_field;

  // TODO(veluca): better choice of this value.
  const float initial_quant_dc =
      16 * std::sqrt(0.1f / cparams.butteraugli_distance);
  AdjustQuantField(enc_state->shared.ac_strategy, Rect(quant_field),
                   cparams.original_butteraugli_distance, &quant_field);

  // Precomputed reciprocals of the per-channel error budgets.
  const float inv_max_err[3] = {1.0f / enc_state->cparams.max_error[0],
                                1.0f / enc_state->cparams.max_error[1],
                                1.0f / enc_state->cparams.max_error[2]};

  for (int i = 0; i < kMaxButteraugliIters + 1; ++i) {
    quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
      JXL_RETURN_IF_ERROR(
          DumpXybImage(cparams, ("ops" + ToString(i)).c_str(), opsin));
    }
    JXL_ASSIGN_OR_RETURN(
        ImageBundle decoded,
        RoundtripImage(frame_header, opsin, enc_state, cms, pool));
    if (JXL_DEBUG_ADAPTIVE_QUANTIZATION && aux_out) {
      JXL_RETURN_IF_ERROR(DumpXybImage(cparams, ("dec" + ToString(i)).c_str(),
                                       *decoded.color()));
    }
    for (size_t by = 0; by < enc_state->shared.frame_dim.ysize_blocks; by++) {
      AcStrategyRow ac_strategy_row =
          enc_state->shared.ac_strategy.ConstRow(by);
      for (size_t bx = 0; bx < enc_state->shared.frame_dim.xsize_blocks; bx++) {
        AcStrategy acs = ac_strategy_row[bx];
        if (!acs.IsFirstBlock()) continue;
        // Largest reconstruction error in this varblock, normalized by the
        // per-channel budget (1.0 == exactly at the allowed maximum).
        float max_error = 0;
        for (size_t c = 0; c < 3; c++) {
          for (size_t y = by * kBlockDim;
               y < (by + acs.covered_blocks_y()) * kBlockDim; y++) {
            if (y >= decoded.ysize()) continue;
            const float* JXL_RESTRICT in_row = opsin.ConstPlaneRow(c, y);
            const float* JXL_RESTRICT dec_row =
                decoded.color()->ConstPlaneRow(c, y);
            for (size_t x = bx * kBlockDim;
                 x < (bx + acs.covered_blocks_x()) * kBlockDim; x++) {
              if (x >= decoded.xsize()) continue;
              max_error = std::max(
                  std::abs(in_row[x] - dec_row[x]) * inv_max_err[c], max_error);
            }
          }
        }
        // Target an error between max_error/2 and max_error.
        // If the error in the varblock is above the target, increase the qf to
        // compensate. If the error is below the target, decrease the qf.
        // However, to avoid an excessive increase of the qf, only do so if the
        // error is less than half the maximum allowed error.
        const float qf_mul = (max_error < 0.5f)   ? max_error * 2.0f
                             : (max_error > 1.0f) ? max_error
                                                  : 1.0f;
        for (size_t qy = by; qy < by + acs.covered_blocks_y(); qy++) {
          float* JXL_RESTRICT quant_field_row = quant_field.Row(qy);
          for (size_t qx = bx; qx < bx + acs.covered_blocks_x(); qx++) {
            quant_field_row[qx] *= qf_mul;
          }
        }
      }
    }
  }
  quantizer.SetQuantField(initial_quant_dc, quant_field, &raw_quant_field);
  return true;
}
   1113 
   1114 }  // namespace
   1115 
   1116 void AdjustQuantField(const AcStrategyImage& ac_strategy, const Rect& rect,
   1117                       float butteraugli_target, ImageF* quant_field) {
   1118   // Replace the whole quant_field in non-8x8 blocks with the maximum of each
   1119   // 8x8 block.
   1120   size_t stride = quant_field->PixelsPerRow();
   1121 
   1122   // At low distances it is great to use max, but mean works better
   1123   // at high distances. We interpolate between them for a distance
   1124   // range.
   1125   float mean_max_mixer = 1.0f;
   1126   {
   1127     static const float kLimit = 1.54138f;
   1128     static const float kMul = 0.56391f;
   1129     static const float kMin = 0.0f;
   1130     if (butteraugli_target > kLimit) {
   1131       mean_max_mixer -= (butteraugli_target - kLimit) * kMul;
   1132       if (mean_max_mixer < kMin) {
   1133         mean_max_mixer = kMin;
   1134       }
   1135     }
   1136   }
   1137   for (size_t y = 0; y < rect.ysize(); ++y) {
   1138     AcStrategyRow ac_strategy_row = ac_strategy.ConstRow(rect, y);
   1139     float* JXL_RESTRICT quant_row = rect.Row(quant_field, y);
   1140     for (size_t x = 0; x < rect.xsize(); ++x) {
   1141       AcStrategy acs = ac_strategy_row[x];
   1142       if (!acs.IsFirstBlock()) continue;
   1143       JXL_ASSERT(x + acs.covered_blocks_x() <= quant_field->xsize());
   1144       JXL_ASSERT(y + acs.covered_blocks_y() <= quant_field->ysize());
   1145       float max = quant_row[x];
   1146       float mean = 0.0;
   1147       for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
   1148         for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
   1149           mean += quant_row[x + ix + iy * stride];
   1150           max = std::max(quant_row[x + ix + iy * stride], max);
   1151         }
   1152       }
   1153       mean /= acs.covered_blocks_y() * acs.covered_blocks_x();
   1154       if (acs.covered_blocks_y() * acs.covered_blocks_x() >= 4) {
   1155         max *= mean_max_mixer;
   1156         max += (1.0f - mean_max_mixer) * mean;
   1157       }
   1158       for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
   1159         for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
   1160           quant_row[x + ix + iy * stride] = max;
   1161         }
   1162       }
   1163     }
   1164   }
   1165 }
   1166 
   1167 float InitialQuantDC(float butteraugli_target) {
   1168   const float kDcMul = 0.3;  // Butteraugli target where non-linearity kicks in.
   1169   const float butteraugli_target_dc = std::max<float>(
   1170       0.5f * butteraugli_target,
   1171       std::min<float>(butteraugli_target,
   1172                       kDcMul * std::pow((1.0f / kDcMul) * butteraugli_target,
   1173                                         kDcQuantPow)));
   1174   // We want the maximum DC value to be at most 2**15 * kInvDCQuant / quant_dc.
   1175   // The maximum DC value might not be in the kXybRange because of inverse
   1176   // gaborish, so we add some slack to the maximum theoretical quant obtained
   1177   // this way (64).
   1178   return std::min(kDcQuant / butteraugli_target_dc, 50.f);
   1179 }
   1180 
   1181 StatusOr<ImageF> InitialQuantField(const float butteraugli_target,
   1182                                    const Image3F& opsin, const Rect& rect,
   1183                                    ThreadPool* pool, float rescale,
   1184                                    ImageF* mask, ImageF* mask1x1) {
   1185   const float quant_ac = kAcQuant / butteraugli_target;
   1186   return HWY_DYNAMIC_DISPATCH(AdaptiveQuantizationMap)(
   1187       butteraugli_target, opsin, rect, quant_ac * rescale, pool, mask, mask1x1);
   1188 }
   1189 
   1190 Status FindBestQuantizer(const FrameHeader& frame_header, const Image3F* linear,
   1191                          const Image3F& opsin, ImageF& quant_field,
   1192                          PassesEncoderState* enc_state,
   1193                          const JxlCmsInterface& cms, ThreadPool* pool,
   1194                          AuxOut* aux_out, double rescale) {
   1195   const CompressParams& cparams = enc_state->cparams;
   1196   if (cparams.max_error_mode) {
   1197     JXL_RETURN_IF_ERROR(FindBestQuantizationMaxError(
   1198         frame_header, opsin, quant_field, enc_state, cms, pool, aux_out));
   1199   } else if (linear && cparams.speed_tier <= SpeedTier::kKitten) {
   1200     // Normal encoding to a butteraugli score.
   1201     JXL_RETURN_IF_ERROR(FindBestQuantization(frame_header, *linear, opsin,
   1202                                              quant_field, enc_state, cms, pool,
   1203                                              aux_out));
   1204   }
   1205   return true;
   1206 }
   1207 
   1208 }  // namespace jxl
   1209 #endif  // HWY_ONCE