libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git

adaptive_quantization.cc (21629B)


// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jpegli/adaptive_quantization.h"

#include <jxl/types.h>
#include <stddef.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <string>
#include <vector>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "lib/jpegli/adaptive_quantization.cc"
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

#include "lib/jpegli/encode_internal.h"
#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/base/status.h"
HWY_BEFORE_NAMESPACE();
namespace jpegli {
namespace HWY_NAMESPACE {
namespace {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::AbsDiff;
using hwy::HWY_NAMESPACE::Add;
using hwy::HWY_NAMESPACE::And;
using hwy::HWY_NAMESPACE::Div;
using hwy::HWY_NAMESPACE::Floor;
using hwy::HWY_NAMESPACE::GetLane;
using hwy::HWY_NAMESPACE::Max;
using hwy::HWY_NAMESPACE::Min;
using hwy::HWY_NAMESPACE::Mul;
using hwy::HWY_NAMESPACE::MulAdd;
using hwy::HWY_NAMESPACE::NegMulAdd;
using hwy::HWY_NAMESPACE::Rebind;
using hwy::HWY_NAMESPACE::ShiftLeft;
using hwy::HWY_NAMESPACE::ShiftRight;
using hwy::HWY_NAMESPACE::Sqrt;
using hwy::HWY_NAMESPACE::Sub;
using hwy::HWY_NAMESPACE::ZeroIfNegative;

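// Input samples are in the nominal 0..255 range; kInputScaling rescales the
// jxl/butteraugli-derived constants below to that range.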
constexpr float kInputScaling = 1.0f / 255.0f;

// Primary template: default to actual division.
template <typename T, class V>
struct FastDivision {
  HWY_INLINE V operator()(const V n, const V d) const { return n / d; }
};
// Partial specialization for float vectors.
template <class V>
struct FastDivision<float, V> {
  // One Newton-Raphson iteration.
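  // Given an initial estimate rcp ~= 1/x, one Newton-Raphson step computes
  // rcp * (2 - x * rcp) = 2 * rcp - (x * rcp) * rcp, which is exactly the
  // NegMulAdd below and roughly doubles the number of correct bits.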
  static HWY_INLINE V ReciprocalNR(const V x) {
    const auto rcp = ApproximateReciprocal(x);
    const auto sum = Add(rcp, rcp);
    const auto x_rcp = Mul(x, rcp);
    return NegMulAdd(x_rcp, rcp, sum);
  }

  V operator()(const V n, const V d) const {
#if JXL_TRUE  // Faster on SKX
    return Div(n, d);
#else
    return n * ReciprocalNR(d);
#endif
  }
};

// Approximates smooth functions via rational polynomials (i.e. dividing two
// polynomials). Evaluates polynomials via Horner's scheme, which is faster than
// Clenshaw recurrence for Chebyshev polynomials. LoadDup128 allows us to
// specify constants (replicated 4x) independently of the lane count.
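// For example, with NP == 12 (degree 2) the numerator evaluates lane-wise as
// yp = (p[8]*x + p[4])*x + p[0]; the P and Q updates are interleaved below so
// the two dependency chains can run in parallel.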
template <size_t NP, size_t NQ, class D, class V, typename T>
HWY_INLINE HWY_MAYBE_UNUSED V EvalRationalPolynomial(const D d, const V x,
                                                     const T (&p)[NP],
                                                     const T (&q)[NQ]) {
  constexpr size_t kDegP = NP / 4 - 1;
  constexpr size_t kDegQ = NQ / 4 - 1;
  auto yp = LoadDup128(d, &p[kDegP * 4]);
  auto yq = LoadDup128(d, &q[kDegQ * 4]);
  // We use pointer arithmetic to refer to &p[(kDegP - n) * 4] to avoid a
  // compiler warning that the index is out of bounds since we are already
  // checking that it is not out of bounds with (kDegP >= n) and the access
  // will be optimized away. Similarly with q and kDegQ.
  HWY_FENCE;
  if (kDegP >= 1) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 1) * 4)));
  if (kDegQ >= 1) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 1) * 4)));
  HWY_FENCE;
  if (kDegP >= 2) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 2) * 4)));
  if (kDegQ >= 2) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 2) * 4)));
  HWY_FENCE;
  if (kDegP >= 3) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 3) * 4)));
  if (kDegQ >= 3) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 3) * 4)));
  HWY_FENCE;
  if (kDegP >= 4) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 4) * 4)));
  if (kDegQ >= 4) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 4) * 4)));
  HWY_FENCE;
  if (kDegP >= 5) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 5) * 4)));
  if (kDegQ >= 5) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 5) * 4)));
  HWY_FENCE;
  if (kDegP >= 6) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 6) * 4)));
  if (kDegQ >= 6) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 6) * 4)));
  HWY_FENCE;
  if (kDegP >= 7) yp = MulAdd(yp, x, LoadDup128(d, p + ((kDegP - 7) * 4)));
  if (kDegQ >= 7) yq = MulAdd(yq, x, LoadDup128(d, q + ((kDegQ - 7) * 4)));

  return FastDivision<T, V>()(yp, yq);
}

// Computes base-2 logarithm like std::log2. Undefined if negative / NaN.
// L1 error ~3.9E-6
template <class DF, class V>
V FastLog2f(const DF df, V x) {
  // 2,2 rational polynomial approximation of std::log1p(x) / std::log(2).
  HWY_ALIGN const float p[4 * (2 + 1)] = {HWY_REP4(-1.8503833400518310E-06f),
                                          HWY_REP4(1.4287160470083755E+00f),
                                          HWY_REP4(7.4245873327820566E-01f)};
  HWY_ALIGN const float q[4 * (2 + 1)] = {HWY_REP4(9.9032814277590719E-01f),
                                          HWY_REP4(1.0096718572241148E+00f),
                                          HWY_REP4(1.7409343003366853E-01f)};

  const Rebind<int32_t, DF> di;
  const auto x_bits = BitCast(di, x);

  // Range reduction to [-1/3, 1/3] - 3 integer, 2 float ops
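  // Subtracting the bit pattern of 2/3 and shifting right by 23 yields an
  // exponent e such that x = 2^e * m with m in [2/3, 4/3); then
  // log2(x) = e + log2(m), and the rational polynomial approximates log2(m)
  // around m = 1.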
  const auto exp_bits = Sub(x_bits, Set(di, 0x3f2aaaab));  // = 2/3
  // Shifted exponent = log2; also used to clear mantissa.
  const auto exp_shifted = ShiftRight<23>(exp_bits);
  const auto mantissa = BitCast(df, Sub(x_bits, ShiftLeft<23>(exp_shifted)));
  const auto exp_val = ConvertTo(df, exp_shifted);
  return Add(EvalRationalPolynomial(df, Sub(mantissa, Set(df, 1.0f)), p, q),
             exp_val);
}

// max relative error ~3e-7
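// Splits x into floor(x) + frac: 2^floor(x) is assembled directly in the
// float exponent field, and a rational polynomial in frac approximates
// 2^frac on [0, 1).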
template <class DF, class V>
V FastPow2f(const DF df, V x) {
  const Rebind<int32_t, DF> di;
  auto floorx = Floor(x);
  auto exp =
      BitCast(df, ShiftLeft<23>(Add(ConvertTo(di, floorx), Set(di, 127))));
  auto frac = Sub(x, floorx);
  auto num = Add(frac, Set(df, 1.01749063e+01));
  num = MulAdd(num, frac, Set(df, 4.88687798e+01));
  num = MulAdd(num, frac, Set(df, 9.85506591e+01));
  num = Mul(num, exp);
  auto den = MulAdd(frac, Set(df, 2.10242958e-01), Set(df, -2.22328856e-02));
  den = MulAdd(den, frac, Set(df, -1.94414990e+01));
  den = MulAdd(den, frac, Set(df, 9.85506633e+01));
  return Div(num, den);
}

inline float FastPow2f(float f) {
  HWY_CAPPED(float, 1) D;
  return GetLane(FastPow2f(D, Set(D, f)));
}

// The following functions modulate an exponent (out_val) and return the updated
// value. Their descriptor is limited to 8 lanes for 8x8 blocks.

template <class D, class V>
V ComputeMask(const D d, const V out_val) {
  const auto kBase = Set(d, -0.74174993f);
  const auto kMul4 = Set(d, 3.2353257320940401f);
  const auto kMul2 = Set(d, 12.906028311180409f);
  const auto kOffset2 = Set(d, 305.04035728311436f);
  const auto kMul3 = Set(d, 5.0220313103171232f);
  const auto kOffset3 = Set(d, 2.1925739705298404f);
  const auto kOffset4 = Mul(Set(d, 0.25f), kOffset3);
  const auto kMul0 = Set(d, 0.74760422233706747f);
  const auto k1 = Set(d, 1.0f);

  // Avoid division by zero.
  const auto v1 = Max(Mul(out_val, kMul0), Set(d, 1e-3f));
  const auto v2 = Div(k1, Add(v1, kOffset2));
  const auto v3 = Div(k1, MulAdd(v1, v1, kOffset3));
  const auto v4 = Div(k1, MulAdd(v1, v1, kOffset4));
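  // v2 is a reciprocal-linear response; v3 and v4 are reciprocal-quadratic
  // responses with different offsets. Their weighted sum below forms the
  // masking curve.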
  // TODO(jyrki):
  // A log or two here could make sense. In butteraugli we effectively have
  // log(log(x + C)) for this kind of use: a single log is used in saturating
  // visual masking, and since the modulation values here are exponential,
  // another log would counter that.
  return Add(kBase, MulAdd(kMul4, v4, MulAdd(kMul2, v2, Mul(kMul3, v3))));
}

// mul and mul2 represent a scaling difference between jxl and butteraugli.
const float kSGmul = 226.0480446705883f;
const float kSGmul2 = 1.0f / 73.377132366608819f;
const float kLog2 = 0.693147181f;
// Includes correction factor for std::log -> log2.
const float kSGRetMul = kSGmul2 * 18.6580932135f * kLog2;
const float kSGVOffset = 7.14672470003f;

template <bool invert, typename D, typename V>
V RatioOfDerivativesOfCubicRootToSimpleGamma(const D d, V v) {
  // The opsin space in jxl is the cubic root of photons, i.e., v * v * v
  // is related to the number of photons.
  //
  // SimpleGamma(v * v * v) is the psychovisual space in butteraugli.
  // This ratio allows quantization to move from jxl's opsin space to
  // butteraugli's log-gamma space.
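  // num/den below is, up to the input scaling and the kEpsilon
  // regularization, proportional to d/dv SimpleGamma(v * v * v); invert
  // selects which direction of the space conversion the ratio performs.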
  static const float kEpsilon = 1e-2;
  static const float kNumOffset = kEpsilon / kInputScaling / kInputScaling;
  static const float kNumMul = kSGRetMul * 3 * kSGmul;
  static const float kVOffset = (kSGVOffset * kLog2 + kEpsilon) / kInputScaling;
  static const float kDenMul = kLog2 * kSGmul * kInputScaling * kInputScaling;

  v = ZeroIfNegative(v);
  const auto num_mul = Set(d, kNumMul);
  const auto num_offset = Set(d, kNumOffset);
  const auto den_offset = Set(d, kVOffset);
  const auto den_mul = Set(d, kDenMul);

  const auto v2 = Mul(v, v);

  const auto num = MulAdd(num_mul, v2, num_offset);
  const auto den = MulAdd(Mul(den_mul, v), v2, den_offset);
  return invert ? Div(num, den) : Div(den, num);
}

template <bool invert = false>
float RatioOfDerivativesOfCubicRootToSimpleGamma(float v) {
  using DScalar = HWY_CAPPED(float, 1);
  auto vscalar = Load(DScalar(), &v);
  return GetLane(
      RatioOfDerivativesOfCubicRootToSimpleGamma<invert>(DScalar(), vscalar));
}

// TODO(veluca): this function computes an approximation of the derivative of
// SimpleGamma with (f(x+eps)-f(x))/eps. Consider two-sided approximation or
// exact derivatives. For reference, SimpleGamma was:
/*
template <typename D, typename V>
V SimpleGamma(const D d, V v) {
  // A simple HDR compatible gamma function.
  const auto mul = Set(d, kSGmul);
  const auto kRetMul = Set(d, kSGRetMul);
  const auto kRetAdd = Set(d, kSGmul2 * -20.2789020414f);
  const auto kVOffset = Set(d, kSGVOffset);

  v *= mul;

  // This should happen rarely, but may lead to a NaN, which is rather
  // undesirable. Since negative photons don't exist we solve the NaNs by
  // clamping here.
  // TODO(veluca): with FastLog2f, this no longer leads to NaNs.
  v = ZeroIfNegative(v);
  return kRetMul * FastLog2f(d, v + kVOffset) + kRetAdd;
}
*/

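// Modulates out_val by the log of the scaled average, over the 8x8 block, of
// the derivative ratio above (kScale = kInputScaling / 64 folds in the 1/64
// block average).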
template <class D, class V>
V GammaModulation(const D d, const size_t x, const size_t y,
                  const RowBuffer<float>& input, const V out_val) {
  static const float kBias = 0.16f / kInputScaling;
  static const float kScale = kInputScaling / 64.0f;
  auto overall_ratio = Zero(d);
  const auto bias = Set(d, kBias);
  const auto scale = Set(d, kScale);
  const float* const JXL_RESTRICT block_start = input.Row(y) + x;
  for (size_t dy = 0; dy < 8; ++dy) {
    const float* const JXL_RESTRICT row_in = block_start + dy * input.stride();
    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
      const auto iny = Add(Load(d, row_in + dx), bias);
      const auto ratio_g =
          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/true>(d, iny);
      overall_ratio = Add(overall_ratio, ratio_g);
    }
  }
  overall_ratio = Mul(SumOfLanes(d, overall_ratio), scale);
  // Ideally -1.0, but the likely optimal correction adds some entropy, so
  // slightly less than that.
  // ln(2) is constant-folded in because we want std::log but have FastLog2f.
  const auto kGam = Set(d, -0.15526878023684174f * 0.693147180559945f);
  return MulAdd(kGam, FastLog2f(d, overall_ratio), out_val);
}

// Change precision in 8x8 blocks that have high frequency content.
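// The modulation is proportional to the sum of absolute differences between
// each pixel and its right and bottom neighbors, a cheap proxy for
// high-frequency energy.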
template <class D, class V>
V HfModulation(const D d, const size_t x, const size_t y,
               const RowBuffer<float>& input, const V out_val) {
  // Zero out the invalid differences for the rightmost value per row.
  const Rebind<uint32_t, D> du;
  HWY_ALIGN constexpr uint32_t kMaskRight[8] = {~0u, ~0u, ~0u, ~0u,
                                                ~0u, ~0u, ~0u, 0};

  auto sum = Zero(d);  // sum of absolute differences with right and below
  static const float kSumCoeff = -2.0052193233688884f * kInputScaling / 112.0;
  auto sumcoeff = Set(d, kSumCoeff);

  const float* const JXL_RESTRICT block_start = input.Row(y) + x;
  for (size_t dy = 0; dy < 8; ++dy) {
    const float* JXL_RESTRICT row_in = block_start + dy * input.stride();
    const float* JXL_RESTRICT row_in_next =
        dy == 7 ? row_in : row_in + input.stride();

    for (size_t dx = 0; dx < 8; dx += Lanes(d)) {
      const auto p = Load(d, row_in + dx);
      const auto pr = LoadU(d, row_in + dx + 1);
      const auto mask = BitCast(d, Load(du, kMaskRight + dx));
      sum = Add(sum, And(mask, AbsDiff(p, pr)));
      const auto pd = Load(d, row_in_next + dx);
      sum = Add(sum, AbsDiff(p, pd));
    }
  }

  sum = SumOfLanes(d, sum);
  return MulAdd(sum, sumcoeff, out_val);
}

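// Applies masking, HF and gamma modulations to each 8x8 block and converts
// the accumulated exponent into a multiplicative quantization field. The
// modulation is dampened as the baseline quantization (y_quant_01) coarsens,
// and disabled entirely beyond kDampenRampEnd.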
void PerBlockModulations(const float y_quant_01, const RowBuffer<float>& input,
                         const size_t yb0, const size_t yblen,
                         RowBuffer<float>* aq_map) {
  static const float kAcQuant = 0.841f;
  float base_level = 0.48f * kAcQuant;
  float kDampenRampStart = 9.0f;
  float kDampenRampEnd = 65.0f;
  float dampen = 1.0f;
  if (y_quant_01 >= kDampenRampStart) {
    dampen = 1.0f - ((y_quant_01 - kDampenRampStart) /
                     (kDampenRampEnd - kDampenRampStart));
    if (dampen < 0) {
      dampen = 0;
    }
  }
  const float mul = kAcQuant * dampen;
  const float add = (1.0f - dampen) * base_level;
  for (size_t iy = 0; iy < yblen; iy++) {
    const size_t yb = yb0 + iy;
    const size_t y = yb * 8;
    float* const JXL_RESTRICT row_out = aq_map->Row(yb);
    const HWY_CAPPED(float, 8) df;
    for (size_t ix = 0; ix < aq_map->xsize(); ix++) {
      size_t x = ix * 8;
      auto out_val = Set(df, row_out[ix]);
      out_val = ComputeMask(df, out_val);
      out_val = HfModulation(df, x, y, input, out_val);
      out_val = GammaModulation(df, x, y, input, out_val);
      // We want a multiplicative quantization field, so everything
      // until this point has been modulating the exponent.
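      // 1.442695... = 1 / ln(2): the exponent was accumulated in natural-log
      // units, but FastPow2f wants base 2.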
      row_out[ix] = FastPow2f(GetLane(out_val) * 1.442695041f) * mul + add;
    }
  }
}

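// Compressive response used for the local differences:
// 0.25 * sqrt(sqrt(kMul * 1e8) * v + kLogOffset).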
template <typename D, typename V>
V MaskingSqrt(const D d, V v) {
  static const float kLogOffset = 28;
  static const float kMul = 211.50759899638012f;
  const auto mul_v = Set(d, kMul * 1e8);
  const auto offset_v = Set(d, kLogOffset);
  return Mul(Set(d, 0.25f), Sqrt(MulAdd(v, Sqrt(mul_v), offset_v)));
}

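// Lane-wise 4-element sorting network: on return,
// min0 <= min1 <= min2 <= min3 in every lane.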
template <typename V>
void Sort4(V& min0, V& min1, V& min2, V& min3) {
  const auto tmp0 = Min(min0, min1);
  const auto tmp1 = Max(min0, min1);
  const auto tmp2 = Min(min2, min3);
  const auto tmp3 = Max(min2, min3);
  const auto tmp4 = Max(tmp0, tmp2);
  const auto tmp5 = Min(tmp1, tmp3);
  min0 = Min(tmp0, tmp2);
  min1 = Min(tmp4, tmp5);
  min2 = Max(tmp4, tmp5);
  min3 = Max(tmp1, tmp3);
}

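// Inserts v into the sorted quadruple, keeping the four smallest values and
// preserving the min0 <= min1 <= min2 <= min3 ordering.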
template <typename V>
void UpdateMin4(const V v, V& min0, V& min1, V& min2, V& min3) {
  const auto tmp0 = Max(min0, v);
  const auto tmp1 = Max(min1, tmp0);
  const auto tmp2 = Max(min2, tmp1);
  min0 = Min(min0, v);
  min1 = Min(min1, tmp0);
  min2 = Min(min2, tmp1);
  min3 = Min(min3, tmp2);
}

// Computes a linear combination of the 4 lowest values of the 3x3 neighborhood
// of each pixel. Output is downsampled 2x.
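// The weights (0.125, 0.075, 0.06, 0.05) emphasize the smallest values, so
// this behaves like a softened (fuzzy) min-filter erosion.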
void FuzzyErosion(const RowBuffer<float>& pre_erosion, const size_t yb0,
                  const size_t yblen, RowBuffer<float>* tmp,
                  RowBuffer<float>* aq_map) {
  int xsize_blocks = aq_map->xsize();
  int xsize = pre_erosion.xsize();
  HWY_FULL(float) d;
  const auto mul0 = Set(d, 0.125f);
  const auto mul1 = Set(d, 0.075f);
  const auto mul2 = Set(d, 0.06f);
  const auto mul3 = Set(d, 0.05f);
  for (size_t iy = 0; iy < 2 * yblen; ++iy) {
    size_t y = 2 * yb0 + iy;
    const float* JXL_RESTRICT rowt = pre_erosion.Row(y - 1);
    const float* JXL_RESTRICT rowm = pre_erosion.Row(y);
    const float* JXL_RESTRICT rowb = pre_erosion.Row(y + 1);
    float* row_out = tmp->Row(y);
    for (int x = 0; x < xsize; x += Lanes(d)) {
      int xm1 = x - 1;
      int xp1 = x + 1;
      auto min0 = LoadU(d, rowm + x);
      auto min1 = LoadU(d, rowm + xm1);
      auto min2 = LoadU(d, rowm + xp1);
      auto min3 = LoadU(d, rowt + xm1);
      Sort4(min0, min1, min2, min3);
      UpdateMin4(LoadU(d, rowt + x), min0, min1, min2, min3);
      UpdateMin4(LoadU(d, rowt + xp1), min0, min1, min2, min3);
      UpdateMin4(LoadU(d, rowb + xm1), min0, min1, min2, min3);
      UpdateMin4(LoadU(d, rowb + x), min0, min1, min2, min3);
      UpdateMin4(LoadU(d, rowb + xp1), min0, min1, min2, min3);
      const auto v = Add(Add(Mul(mul0, min0), Mul(mul1, min1)),
                         Add(Mul(mul2, min2), Mul(mul3, min3)));
      Store(v, d, row_out + x);
    }
    if (iy % 2 == 1) {
      const float* JXL_RESTRICT row_out0 = tmp->Row(y - 1);
      float* JXL_RESTRICT aq_out = aq_map->Row(yb0 + iy / 2);
      for (int bx = 0, x = 0; bx < xsize_blocks; ++bx, x += 2) {
        aq_out[bx] =
            (row_out[x] + row_out[x + 1] + row_out0[x] + row_out0[x + 1]);
      }
    }
  }
}

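// Computes the 4x4-subsampled image of local pixel differences (the
// "pre-erosion" input) for rows [y0, y0 + ylen) of the input, emitting one
// output row per four input rows and padding each finished row by `border`
// pixels.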
void ComputePreErosion(const RowBuffer<float>& input, const size_t xsize,
                       const size_t y0, const size_t ylen, int border,
                       float* diff_buffer, RowBuffer<float>* pre_erosion) {
  const size_t xsize_out = xsize / 4;
  const size_t y0_out = y0 / 4;

  // The XYB gamma is 3.0 so that decoding is faster, needing only two muls.
  // Butteraugli's gamma matches the gamma of the human eye, around 2.6.
  // We approximate the gamma difference by adding one cubic root into
  // the adaptive quantization. This gives us a total gamma of 2.6666
  // for quantization uses.
  static const float match_gamma_offset = 0.019 / kInputScaling;

  const HWY_CAPPED(float, 8) df;

  static const float limit = 0.2f;
  // Computes image (padded to multiple of 8x8) of local pixel differences.
  // Subsample both directions by 4.
  for (size_t iy = 0; iy < ylen; ++iy) {
    size_t y = y0 + iy;
    const float* row_in = input.Row(y);
    const float* row_in1 = input.Row(y + 1);
    const float* row_in2 = input.Row(y - 1);
    float* JXL_RESTRICT row_out = diff_buffer;
    const auto match_gamma_offset_v = Set(df, match_gamma_offset);
    const auto quarter = Set(df, 0.25f);
    for (size_t x = 0; x < xsize; x += Lanes(df)) {
      const auto in = LoadU(df, row_in + x);
      const auto in_r = LoadU(df, row_in + x + 1);
      const auto in_l = LoadU(df, row_in + x - 1);
      const auto in_t = LoadU(df, row_in2 + x);
      const auto in_b = LoadU(df, row_in1 + x);
      const auto base = Mul(quarter, Add(Add(in_r, in_l), Add(in_t, in_b)));
      const auto gammacv =
          RatioOfDerivativesOfCubicRootToSimpleGamma</*invert=*/false>(
              df, Add(in, match_gamma_offset_v));
      auto diff = Mul(gammacv, Sub(in, base));
      diff = Mul(diff, diff);
      diff = Min(diff, Set(df, limit));
      diff = MaskingSqrt(df, diff);
      if ((iy & 3) != 0) {
        diff = Add(diff, LoadU(df, row_out + x));
      }
      StoreU(diff, df, row_out + x);
    }
    if (iy % 4 == 3) {
      size_t y_out = y0_out + iy / 4;
      float* row_dout = pre_erosion->Row(y_out);
      for (size_t x = 0; x < xsize_out; x++) {
        row_dout[x] = (row_out[x * 4] + row_out[x * 4 + 1] +
                       row_out[x * 4 + 2] + row_out[x * 4 + 3]) *
                      0.25f;
      }
      pre_erosion->PadRow(y_out, xsize_out, border);
    }
  }
}

}  // namespace

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jpegli
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace jpegli {
HWY_EXPORT(ComputePreErosion);
HWY_EXPORT(FuzzyErosion);
HWY_EXPORT(PerBlockModulations);

namespace {

constexpr int kPreErosionBorder = 1;

}  // namespace

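// Drives the adaptive quantization pipeline for one iMCU row of the luma
// channel: mirrors boundary rows, computes the pre-erosion image (4x
// subsampled), runs FuzzyErosion down to one value per 8x8 block, applies
// PerBlockModulations, and finally maps the multiplicative field into the
// quant-field values used by the encoder.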
    511 void ComputeAdaptiveQuantField(j_compress_ptr cinfo) {
    512   jpeg_comp_master* m = cinfo->master;
    513   if (!m->use_adaptive_quantization) {
    514     return;
    515   }
    516   int y_channel = cinfo->jpeg_color_space == JCS_RGB ? 1 : 0;
    517   jpeg_component_info* y_comp = &cinfo->comp_info[y_channel];
    518   int y_quant_01 = cinfo->quant_tbl_ptrs[y_comp->quant_tbl_no]->quantval[1];
    519   if (m->next_iMCU_row == 0) {
    520     m->input_buffer[y_channel].CopyRow(-1, 0, 1);
    521   }
    522   if (m->next_iMCU_row + 1 == cinfo->total_iMCU_rows) {
    523     size_t last_row = m->ysize_blocks * DCTSIZE - 1;
    524     m->input_buffer[y_channel].CopyRow(last_row + 1, last_row, 1);
    525   }
    526   const RowBuffer<float>& input = m->input_buffer[y_channel];
    527   const size_t xsize_blocks = y_comp->width_in_blocks;
    528   const size_t xsize = xsize_blocks * DCTSIZE;
    529   const size_t yb0 = m->next_iMCU_row * cinfo->max_v_samp_factor;
    530   const size_t yblen = cinfo->max_v_samp_factor;
    531   size_t y0 = yb0 * DCTSIZE;
    532   size_t ylen = cinfo->max_v_samp_factor * DCTSIZE;
    533   if (y0 == 0) {
    534     ylen += 4;
    535   } else {
    536     y0 += 4;
    537   }
    538   if (m->next_iMCU_row + 1 == cinfo->total_iMCU_rows) {
    539     ylen -= 4;
    540   }
    541   HWY_DYNAMIC_DISPATCH(ComputePreErosion)
    542   (input, xsize, y0, ylen, kPreErosionBorder, m->diff_buffer, &m->pre_erosion);
    543   if (y0 == 0) {
    544     m->pre_erosion.CopyRow(-1, 0, kPreErosionBorder);
    545   }
    546   if (m->next_iMCU_row + 1 == cinfo->total_iMCU_rows) {
    547     size_t last_row = m->ysize_blocks * 2 - 1;
    548     m->pre_erosion.CopyRow(last_row + 1, last_row, kPreErosionBorder);
    549   }
    550   HWY_DYNAMIC_DISPATCH(FuzzyErosion)
    551   (m->pre_erosion, yb0, yblen, &m->fuzzy_erosion_tmp, &m->quant_field);
    552   HWY_DYNAMIC_DISPATCH(PerBlockModulations)
    553   (y_quant_01, input, yb0, yblen, &m->quant_field);
    554   for (int y = 0; y < cinfo->max_v_samp_factor; ++y) {
    555     float* row = m->quant_field.Row(yb0 + y);
    556     for (size_t x = 0; x < xsize_blocks; ++x) {
    557       row[x] = std::max(0.0f, (0.6f / row[x]) - 1.0f);
    558     }
    559   }
    560 }
    561 
    562 }  // namespace jpegli
    563 #endif  // HWY_ONCE