libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git

gauss_blur.cc (19990B)


// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "tools/gauss_blur.h"

#include <algorithm>
#include <cmath>
#include <cstdint>

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "tools/gauss_blur.cc"
#include <hwy/aligned_allocator.h>
#include <hwy/cache_control.h>  // Prefetch
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

#include "lib/jxl/base/common.h"             // RoundUpTo
#include "lib/jxl/base/compiler_specific.h"  // JXL_RESTRICT
#include "lib/jxl/base/matrix_ops.h"         // Inv3x3Matrix
HWY_BEFORE_NAMESPACE();
namespace jxl {
namespace HWY_NAMESPACE {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::Add;
using hwy::HWY_NAMESPACE::Broadcast;
using hwy::HWY_NAMESPACE::GetLane;
using hwy::HWY_NAMESPACE::Mul;
using hwy::HWY_NAMESPACE::MulAdd;
using hwy::HWY_NAMESPACE::NegMulSub;
#if HWY_TARGET != HWY_SCALAR
using hwy::HWY_NAMESPACE::ShiftLeftLanes;
#endif
using hwy::HWY_NAMESPACE::Vec;

void FastGaussian1D(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
                    const intptr_t xsize, const float* JXL_RESTRICT in,
                    float* JXL_RESTRICT out) {
  // Although the current output depends on the previous output, we can unroll
  // up to 4x by precomputing up to fourth powers of the constants. Beyond that,
  // numerical precision might become a problem. Macro because this is tested
  // in #if alongside HWY_TARGET.
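  // Per kernel k in {1, 3, 5}, the underlying recurrence (35) is
  //   y_k[n] = n2_k * (in[n - N - 1] + in[n + N - 1])
  //            - d1_k * y_k[n - 1] - y_k[n - 2]
  // and out[n] = y_1[n] + y_3[n] + y_5[n]. The mul_* tables loaded below hold
  // this recurrence expanded for up to 4 consecutive outputs (see
  // CreateRecursiveGaussian).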
#define JXL_GAUSS_MAX_LANES 4
  using D = HWY_CAPPED(float, JXL_GAUSS_MAX_LANES);
  using V = Vec<D>;
  const D d;
  const V mul_in_1 = Load(d, rg->mul_in + 0 * 4);
  const V mul_in_3 = Load(d, rg->mul_in + 1 * 4);
  const V mul_in_5 = Load(d, rg->mul_in + 2 * 4);
  const V mul_prev_1 = Load(d, rg->mul_prev + 0 * 4);
  const V mul_prev_3 = Load(d, rg->mul_prev + 1 * 4);
  const V mul_prev_5 = Load(d, rg->mul_prev + 2 * 4);
  const V mul_prev2_1 = Load(d, rg->mul_prev2 + 0 * 4);
  const V mul_prev2_3 = Load(d, rg->mul_prev2 + 1 * 4);
  const V mul_prev2_5 = Load(d, rg->mul_prev2 + 2 * 4);
  V prev_1 = Zero(d);
  V prev_3 = Zero(d);
  V prev_5 = Zero(d);
  V prev2_1 = Zero(d);
  V prev2_3 = Zero(d);
  V prev2_5 = Zero(d);

  const intptr_t N = static_cast<intptr_t>(rg->radius);

  intptr_t n = -N + 1;
  // Left side with bounds checks; only write output once n >= 0.
  const intptr_t first_aligned = RoundUpTo(N + 1, Lanes(d));
  for (; n < std::min(first_aligned, xsize); ++n) {
    const intptr_t left = n - N - 1;
    const intptr_t right = n + N - 1;
    const float left_val = left >= 0 ? in[left] : 0.0f;
    const float right_val = (right < xsize) ? in[right] : 0.0f;
    const V sum = Set(d, left_val + right_val);

    // (Only processing a single lane here, no need to broadcast)
    V out_1 = Mul(sum, mul_in_1);
    V out_3 = Mul(sum, mul_in_3);
    V out_5 = Mul(sum, mul_in_5);

    out_1 = MulAdd(mul_prev2_1, prev2_1, out_1);
    out_3 = MulAdd(mul_prev2_3, prev2_3, out_3);
    out_5 = MulAdd(mul_prev2_5, prev2_5, out_5);
    prev2_1 = prev_1;
    prev2_3 = prev_3;
    prev2_5 = prev_5;

    out_1 = MulAdd(mul_prev_1, prev_1, out_1);
    out_3 = MulAdd(mul_prev_3, prev_3, out_3);
    out_5 = MulAdd(mul_prev_5, prev_5, out_5);
    prev_1 = out_1;
    prev_3 = out_3;
    prev_5 = out_5;

    if (n >= 0) {
      out[n] = GetLane(Add(out_1, Add(out_3, out_5)));
    }
  }

  // The above loop is effectively scalar but it is convenient to use the same
  // prev/prev2 variables, so broadcast to each lane before the unrolled loop.
#if HWY_TARGET != HWY_SCALAR && JXL_GAUSS_MAX_LANES > 1
  prev2_1 = Broadcast<0>(prev2_1);
  prev2_3 = Broadcast<0>(prev2_3);
  prev2_5 = Broadcast<0>(prev2_5);
  prev_1 = Broadcast<0>(prev_1);
  prev_3 = Broadcast<0>(prev_3);
  prev_5 = Broadcast<0>(prev_5);
#endif

  // Unrolled, no bounds checking needed.
  for (; n < xsize - N + 1 - (JXL_GAUSS_MAX_LANES - 1); n += Lanes(d)) {
    const V sum = Add(LoadU(d, in + n - N - 1), LoadU(d, in + n + N - 1));

    // To get a vector of output(s), we multiply broadcasted vectors (of each
    // input plus the two previous outputs) and add them all together.
    // Incremental broadcasting and shifting is expected to be cheaper than
    // horizontal adds or transposing 4x4 values because they run on a different
    // port, concurrently with the FMA.
    const V in0 = Broadcast<0>(sum);
    V out_1 = Mul(in0, mul_in_1);
    V out_3 = Mul(in0, mul_in_3);
    V out_5 = Mul(in0, mul_in_5);

#if HWY_TARGET != HWY_SCALAR && JXL_GAUSS_MAX_LANES >= 2
    const V in1 = Broadcast<1>(sum);
    out_1 = MulAdd(ShiftLeftLanes<1>(mul_in_1), in1, out_1);
    out_3 = MulAdd(ShiftLeftLanes<1>(mul_in_3), in1, out_3);
    out_5 = MulAdd(ShiftLeftLanes<1>(mul_in_5), in1, out_5);

#if JXL_GAUSS_MAX_LANES >= 4
    const V in2 = Broadcast<2>(sum);
    out_1 = MulAdd(ShiftLeftLanes<2>(mul_in_1), in2, out_1);
    out_3 = MulAdd(ShiftLeftLanes<2>(mul_in_3), in2, out_3);
    out_5 = MulAdd(ShiftLeftLanes<2>(mul_in_5), in2, out_5);

    const V in3 = Broadcast<3>(sum);
    out_1 = MulAdd(ShiftLeftLanes<3>(mul_in_1), in3, out_1);
    out_3 = MulAdd(ShiftLeftLanes<3>(mul_in_3), in3, out_3);
    out_5 = MulAdd(ShiftLeftLanes<3>(mul_in_5), in3, out_5);
#endif
#endif

    out_1 = MulAdd(mul_prev2_1, prev2_1, out_1);
    out_3 = MulAdd(mul_prev2_3, prev2_3, out_3);
    out_5 = MulAdd(mul_prev2_5, prev2_5, out_5);

    out_1 = MulAdd(mul_prev_1, prev_1, out_1);
    out_3 = MulAdd(mul_prev_3, prev_3, out_3);
    out_5 = MulAdd(mul_prev_5, prev_5, out_5);
#if HWY_TARGET == HWY_SCALAR || JXL_GAUSS_MAX_LANES == 1
    prev2_1 = prev_1;
    prev2_3 = prev_3;
    prev2_5 = prev_5;
    prev_1 = out_1;
    prev_3 = out_3;
    prev_5 = out_5;
#else
    prev2_1 = Broadcast<JXL_GAUSS_MAX_LANES - 2>(out_1);
    prev2_3 = Broadcast<JXL_GAUSS_MAX_LANES - 2>(out_3);
    prev2_5 = Broadcast<JXL_GAUSS_MAX_LANES - 2>(out_5);
    prev_1 = Broadcast<JXL_GAUSS_MAX_LANES - 1>(out_1);
    prev_3 = Broadcast<JXL_GAUSS_MAX_LANES - 1>(out_3);
    prev_5 = Broadcast<JXL_GAUSS_MAX_LANES - 1>(out_5);
#endif

    Store(Add(out_1, Add(out_3, out_5)), d, out + n);
  }

  // Remainder handling with bounds checks
  for (; n < xsize; ++n) {
    const intptr_t left = n - N - 1;
    const intptr_t right = n + N - 1;
    const float left_val = left >= 0 ? in[left] : 0.0f;
    const float right_val = (right < xsize) ? in[right] : 0.0f;
    const V sum = Set(d, left_val + right_val);

    // (Only processing a single lane here, no need to broadcast)
    V out_1 = Mul(sum, mul_in_1);
    V out_3 = Mul(sum, mul_in_3);
    V out_5 = Mul(sum, mul_in_5);

    out_1 = MulAdd(mul_prev2_1, prev2_1, out_1);
    out_3 = MulAdd(mul_prev2_3, prev2_3, out_3);
    out_5 = MulAdd(mul_prev2_5, prev2_5, out_5);
    prev2_1 = prev_1;
    prev2_3 = prev_3;
    prev2_5 = prev_5;

    out_1 = MulAdd(mul_prev_1, prev_1, out_1);
    out_3 = MulAdd(mul_prev_3, prev_3, out_3);
    out_5 = MulAdd(mul_prev_5, prev_5, out_5);
    prev_1 = out_1;
    prev_3 = out_3;
    prev_5 = out_5;

    out[n] = GetLane(Add(out_1, Add(out_3, out_5)));
  }
}

// Ring buffer is for n, n-1, n-2; round up to 4 for faster modulo.
constexpr size_t kMod = 4;
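// kMod is a power of two, so the % kMod below compiles to a single AND.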

// Avoids an unnecessary store during warmup.
struct OutputNone {
  template <class V>
  void operator()(const V& /*unused*/, float* JXL_RESTRICT /*pos*/,
                  ptrdiff_t /*offset*/) const {}
};

// Common case: write output vectors in every VerticalBlock call except warmup.
struct OutputStore {
  template <class V>
  void operator()(const V& out, float* JXL_RESTRICT pos,
                  ptrdiff_t offset) const {
    // Stream helps for large images but is slower for images that fit in cache.
    const HWY_FULL(float) df;
    Store(out, df, pos + offset);
  }
};

// At top/bottom borders, we don't have two inputs to load, so avoid addition.
// pos may even point to all zeros if the row is outside the input image.
class SingleInput {
 public:
  explicit SingleInput(const float* pos) : pos_(pos) {}
  Vec<HWY_FULL(float)> operator()(const size_t offset) const {
    const HWY_FULL(float) df;
    return Load(df, pos_ + offset);
  }
  const float* pos_;
};

// In the middle of the image, we need to load from a row above and below, and
// return the sum.
class TwoInputs {
 public:
  TwoInputs(const float* pos1, const float* pos2) : pos1_(pos1), pos2_(pos2) {}
  Vec<HWY_FULL(float)> operator()(const size_t offset) const {
    const HWY_FULL(float) df;
    const auto in1 = Load(df, pos1_ + offset);
    const auto in2 = Load(df, pos2_ + offset);
    return Add(in1, in2);
  }

 private:
  const float* pos1_;
  const float* pos2_;
};

// Block := kVectors consecutive full vectors (one cache line except on the
// right boundary, where we can only rely on having one vector). Unrolling to
// the cache line size improves cache utilization.
template <size_t kVectors, class V, class Input, class Output>
void VerticalBlock(const V& d1_1, const V& d1_3, const V& d1_5, const V& n2_1,
                   const V& n2_3, const V& n2_5, const Input& input,
                   size_t& ctr, float* ring_buffer, const Output output,
                   float* JXL_RESTRICT out_pos) {
  const HWY_FULL(float) d;
  constexpr size_t kVN = MaxLanes(d);
  // More cache-friendly to process an entire cache line at a time.
  constexpr size_t kLanes = kVectors * kVN;

  float* JXL_RESTRICT y_1 = ring_buffer + 0 * kLanes * kMod;
  float* JXL_RESTRICT y_3 = ring_buffer + 1 * kLanes * kMod;
  float* JXL_RESTRICT y_5 = ring_buffer + 2 * kLanes * kMod;

  const size_t n_0 = (++ctr) % kMod;
  const size_t n_1 = (ctr - 1) % kMod;
  const size_t n_2 = (ctr - 2) % kMod;
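  // On the first call ctr becomes 1, so (ctr - 2) wraps around (unsigned) to
  // slot 3; the ring buffer is zero-initialized, so both history slots read
  // zeros, which is the required warmup boundary condition.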

  for (size_t idx_vec = 0; idx_vec < kVectors; ++idx_vec) {
    const V sum = input(idx_vec * kVN);

    const V y_n1_1 = Load(d, y_1 + kLanes * n_1 + idx_vec * kVN);
    const V y_n1_3 = Load(d, y_3 + kLanes * n_1 + idx_vec * kVN);
    const V y_n1_5 = Load(d, y_5 + kLanes * n_1 + idx_vec * kVN);
    const V y_n2_1 = Load(d, y_1 + kLanes * n_2 + idx_vec * kVN);
    const V y_n2_3 = Load(d, y_3 + kLanes * n_2 + idx_vec * kVN);
    const V y_n2_5 = Load(d, y_5 + kLanes * n_2 + idx_vec * kVN);
    // (35)
    const V y1 = MulAdd(n2_1, sum, NegMulSub(d1_1, y_n1_1, y_n2_1));
    const V y3 = MulAdd(n2_3, sum, NegMulSub(d1_3, y_n1_3, y_n2_3));
    const V y5 = MulAdd(n2_5, sum, NegMulSub(d1_5, y_n1_5, y_n2_5));
    Store(y1, d, y_1 + kLanes * n_0 + idx_vec * kVN);
    Store(y3, d, y_3 + kLanes * n_0 + idx_vec * kVN);
    Store(y5, d, y_5 + kLanes * n_0 + idx_vec * kVN);
    output(Add(y1, Add(y3, y5)), out_pos, idx_vec * kVN);
  }
  // NOTE: flushing the cache line at out_pos hurts performance - less so with
  // clflushopt than clflush, but still a significant slowdown.
}

// Reads/writes one block (kVectors full vectors) in each row.
template <size_t kVectors>
void VerticalStrip(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
                   const size_t x, const size_t ysize, const GetConstRow& in,
                   const GetRow& out) {
  // We're iterating vertically, so use multiple full-length vectors (each lane
  // is one column of row n).
  using D = HWY_FULL(float);
  using V = Vec<D>;
  const D d;
  constexpr size_t kVN = MaxLanes(d);
  // More cache-friendly to process an entire cache line at a time.
  constexpr size_t kLanes = kVectors * kVN;
#if HWY_TARGET == HWY_SCALAR
  const V d1_1 = Set(d, rg->d1[0 * 4]);
  const V d1_3 = Set(d, rg->d1[1 * 4]);
  const V d1_5 = Set(d, rg->d1[2 * 4]);
  const V n2_1 = Set(d, rg->n2[0 * 4]);
  const V n2_3 = Set(d, rg->n2[1 * 4]);
  const V n2_5 = Set(d, rg->n2[2 * 4]);
#else
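  // CreateRecursiveGaussian replicates each coefficient across 4 consecutive
  // floats, so LoadDup128 broadcasts one coefficient to all lanes of the
  // (possibly wider) vector.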
  const V d1_1 = LoadDup128(d, rg->d1 + 0 * 4);
  const V d1_3 = LoadDup128(d, rg->d1 + 1 * 4);
  const V d1_5 = LoadDup128(d, rg->d1 + 2 * 4);
  const V n2_1 = LoadDup128(d, rg->n2 + 0 * 4);
  const V n2_3 = LoadDup128(d, rg->n2 + 1 * 4);
  const V n2_5 = LoadDup128(d, rg->n2 + 2 * 4);
#endif

  const size_t N = rg->radius;

  size_t ctr = 0;
  HWY_ALIGN float ring_buffer[3 * kLanes * kMod] = {0};
  HWY_ALIGN static constexpr float zero[kLanes] = {0};

  // Warmup: top is out of bounds (zero padded), bottom is usually in-bounds.
  ssize_t n = -static_cast<ssize_t>(N) + 1;
  for (; n < 0; ++n) {
    // bottom is always non-negative since n is initialized to -N + 1.
    const size_t bottom = n + N - 1;
    VerticalBlock<kVectors>(d1_1, d1_3, d1_5, n2_1, n2_3, n2_5,
                            SingleInput(bottom < ysize ? in(bottom) + x : zero),
                            ctr, ring_buffer, OutputNone(), nullptr);
  }
  JXL_DASSERT(n >= 0);

  // Start producing output; top is still out of bounds.
  for (; static_cast<size_t>(n) < std::min(N + 1, ysize); ++n) {
    const size_t bottom = n + N - 1;
    VerticalBlock<kVectors>(d1_1, d1_3, d1_5, n2_1, n2_3, n2_5,
                            SingleInput(bottom < ysize ? in(bottom) + x : zero),
                            ctr, ring_buffer, OutputStore(), out(n) + x);
  }

  // Interior outputs with prefetching and without bounds checks.
  constexpr size_t kPrefetchRows = 8;
  for (; n < static_cast<ssize_t>(ysize - N + 1 - kPrefetchRows); ++n) {
    const size_t top = n - N - 1;
    const size_t bottom = n + N - 1;
    VerticalBlock<kVectors>(d1_1, d1_3, d1_5, n2_1, n2_3, n2_5,
                            TwoInputs(in(top) + x, in(bottom) + x), ctr,
                            ring_buffer, OutputStore(), out(n) + x);
    hwy::Prefetch(in(top + kPrefetchRows) + x);
    hwy::Prefetch(in(bottom + kPrefetchRows) + x);
  }

  // Bottom border without prefetching and with bounds checks.
  for (; static_cast<size_t>(n) < ysize; ++n) {
    const size_t top = n - N - 1;
    const size_t bottom = n + N - 1;
    VerticalBlock<kVectors>(
        d1_1, d1_3, d1_5, n2_1, n2_3, n2_5,
        TwoInputs(in(top) + x, bottom < ysize ? in(bottom) + x : zero), ctr,
        ring_buffer, OutputStore(), out(n) + x);
  }
}

// Apply 1D vertical scan to multiple columns (one per vector lane).
// Not yet parallelized.
void FastGaussianVertical(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
                          const size_t xsize, const size_t ysize,
                          const GetConstRow& in, const GetRow& out,
                          ThreadPool* /* pool */) {
  const HWY_FULL(float) df;
  constexpr size_t kCacheLineLanes = 64 / sizeof(float);
  constexpr size_t kVN = MaxLanes(df);
  constexpr size_t kCacheLineVectors =
      (kVN < kCacheLineLanes) ? (kCacheLineLanes / kVN) : 4;
  constexpr size_t kFastPace = kCacheLineVectors * kVN;
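  // Example: with 64-byte cache lines and 8-lane vectors (e.g. AVX2),
  // kCacheLineLanes = 16, kCacheLineVectors = 2 and kFastPace = 16 columns
  // per strip.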

  // TODO(eustas): why is pool unused?
  size_t x = 0;
  for (; x + kFastPace <= xsize; x += kFastPace) {
    VerticalStrip<kCacheLineVectors>(rg, x, ysize, in, out);
  }
  for (; x < xsize; x += kVN) {
    VerticalStrip<1>(rg, x, ysize, in, out);
  }
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jxl
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace jxl {

HWY_EXPORT(FastGaussian1D);
void FastGaussian1D(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
                    const size_t xsize, const float* JXL_RESTRICT in,
                    float* JXL_RESTRICT out) {
  HWY_DYNAMIC_DISPATCH(FastGaussian1D)
  (rg, static_cast<intptr_t>(xsize), in, out);
}

HWY_EXPORT(FastGaussianVertical);  // Local function.

// Implements "Recursive Implementation of the Gaussian Filter Using Truncated
// Cosine Functions" by Charalampidis [2016].
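// The Gaussian is approximated as a sum of three cosine terms (k = 1, 3, 5),
// each realized as a second-order recursive (IIR) filter; equation numbers in
// the comments below refer to that paper.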
hwy::AlignedUniquePtr<RecursiveGaussian> CreateRecursiveGaussian(double sigma) {
  auto rg = hwy::MakeUniqueAligned<RecursiveGaussian>();
  constexpr double kPi = 3.141592653589793238;

  const double radius = roundf(3.2795 * sigma + 0.2546);  // (57), "N"
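  // Example: sigma = 1.5 gives radius = round(5.174) = 5.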

  // Table I, first row
  const double pi_div_2r = kPi / (2.0 * radius);
  const double omega[3] = {pi_div_2r, 3.0 * pi_div_2r, 5.0 * pi_div_2r};

  // (37), k={1,3,5}
  const double p_1 = +1.0 / std::tan(0.5 * omega[0]);
  const double p_3 = -1.0 / std::tan(0.5 * omega[1]);
  const double p_5 = +1.0 / std::tan(0.5 * omega[2]);

  // (44), k={1,3,5}
  const double r_1 = +p_1 * p_1 / std::sin(omega[0]);
  const double r_3 = -p_3 * p_3 / std::sin(omega[1]);
  const double r_5 = +p_5 * p_5 / std::sin(omega[2]);

  // (50), k={1,3,5}
  const double neg_half_sigma2 = -0.5 * sigma * sigma;
  const double recip_radius = 1.0 / radius;
  double rho[3];
  for (size_t i = 0; i < 3; ++i) {
    rho[i] = std::exp(neg_half_sigma2 * omega[i] * omega[i]) * recip_radius;
  }

  // second part of (52), k1,k2 = 1,3; 3,5; 5,1
  const double D_13 = p_1 * r_3 - r_1 * p_3;
  const double D_35 = p_3 * r_5 - r_3 * p_5;
  const double D_51 = p_5 * r_1 - r_5 * p_1;

  // (52), k=5
  const double recip_d13 = 1.0 / D_13;
  const double zeta_15 = D_35 * recip_d13;
  const double zeta_35 = D_51 * recip_d13;

  double A[9] = {p_1,     p_3,     p_5,  //
                 r_1,     r_3,     r_5,  //  (56)
                 zeta_15, zeta_35, 1};
  JXL_CHECK(Inv3x3Matrix(A));
  const double gamma[3] = {1, radius * radius - sigma * sigma,  // (55)
                           zeta_15 * rho[0] + zeta_35 * rho[1] + rho[2]};
  double beta[3];
  Mul3x3Vector(A, gamma, beta);  // (53)

  // Sanity check: correctly solved for beta (IIR filter weights are normalized)
  const double sum = beta[0] * p_1 + beta[1] * p_3 + beta[2] * p_5;  // (39)
  JXL_ASSERT(std::abs(sum - 1) < 1E-12);
  (void)sum;

  rg->radius = static_cast<int>(radius);

  double n2[3];
  double d1[3];
  for (size_t i = 0; i < 3; ++i) {
    n2[i] = -beta[i] * std::cos(omega[i] * (radius + 1.0));  // (33)
    d1[i] = -2.0 * std::cos(omega[i]);                       // (33)

    for (size_t lane = 0; lane < 4; ++lane) {
      rg->n2[4 * i + lane] = static_cast<float>(n2[i]);
      rg->d1[4 * i + lane] = static_cast<float>(d1[i]);
    }

    const double d_2 = d1[i] * d1[i];

    // Obtained by expanding (35) for four consecutive outputs via sympy:
    // n, d, p, pp = symbols('n d p pp')
    // i0, i1, i2, i3 = symbols('i0 i1 i2 i3')
    // o0, o1, o2, o3 = symbols('o0 o1 o2 o3')
    // o0 = n*i0 - d*p - pp
    // o1 = n*i1 - d*o0 - p
    // o2 = n*i2 - d*o1 - o0
    // o3 = n*i3 - d*o2 - o1
    // Then expand(o3) and gather terms for p(prev), pp(prev2) etc.
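    // Lane j of each table holds the coefficient appearing in o_j: mul_prev
    // multiplies p (the previous output), mul_prev2 multiplies pp (the one
    // before that), and mul_in multiplies i0. The coefficients of i1..i3 are
    // the same sequences shifted by one lane each, which FastGaussian1D
    // applies via ShiftLeftLanes.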
    rg->mul_prev[4 * i + 0] = -d1[i];
    rg->mul_prev[4 * i + 1] = d_2 - 1.0;
    rg->mul_prev[4 * i + 2] = -d_2 * d1[i] + 2.0 * d1[i];
    rg->mul_prev[4 * i + 3] = d_2 * d_2 - 3.0 * d_2 + 1.0;
    rg->mul_prev2[4 * i + 0] = -1.0;
    rg->mul_prev2[4 * i + 1] = d1[i];
    rg->mul_prev2[4 * i + 2] = -d_2 + 1.0;
    rg->mul_prev2[4 * i + 3] = d_2 * d1[i] - 2.0 * d1[i];
    rg->mul_in[4 * i + 0] = n2[i];
    rg->mul_in[4 * i + 1] = -d1[i] * n2[i];
    rg->mul_in[4 * i + 2] = d_2 * n2[i] - n2[i];
    rg->mul_in[4 * i + 3] = -d_2 * d1[i] * n2[i] + 2.0 * d1[i] * n2[i];
  }
  return rg;
}

namespace {

// Apply 1D horizontal scan to each row.
void FastGaussianHorizontal(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
                            const size_t xsize, const size_t ysize,
                            const GetConstRow& in, const GetRow& out,
                            ThreadPool* pool) {
  const auto process_line = [&](const uint32_t task, size_t /*thread*/) {
    const size_t y = task;
    FastGaussian1D(rg, static_cast<intptr_t>(xsize), in(y), out(y));
  };

  JXL_CHECK(RunOnPool(pool, 0, ysize, ThreadPool::NoInit, process_line,
                      "FastGaussianHorizontal"));
}

}  // namespace

void FastGaussian(const hwy::AlignedUniquePtr<RecursiveGaussian>& rg,
                  const size_t xsize, const size_t ysize, const GetConstRow& in,
                  const GetRow& temp, const GetRow& out, ThreadPool* pool) {
  FastGaussianHorizontal(rg, xsize, ysize, in, temp, pool);
  GetConstRow temp_in = [&](size_t y) { return temp(y); };
  HWY_DYNAMIC_DISPATCH(FastGaussianVertical)
  (rg, xsize, ysize, temp_in, out, pool);
}
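
// Typical usage (sketch): create the kernel once per sigma, then filter with
// a scratch buffer for the horizontal pass:
//   auto rg = CreateRecursiveGaussian(sigma);
//   FastGaussian(rg, xsize, ysize, in, temp, out, pool);
// where in, temp and out are row accessors (GetConstRow/GetRow) returning a
// pointer to the start of row y.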

}  // namespace jxl
#endif  // HWY_ONCE