libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

convolve_separable5.cc (9572B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/convolve.h"
      7 
      8 #undef HWY_TARGET_INCLUDE
      9 #define HWY_TARGET_INCLUDE "lib/jxl/convolve_separable5.cc"
     10 #include <hwy/foreach_target.h>
     11 #include <hwy/highway.h>
     12 
     13 #include "lib/jxl/convolve-inl.h"
     14 
     15 HWY_BEFORE_NAMESPACE();
     16 namespace jxl {
     17 namespace HWY_NAMESPACE {
     18 
     19 // These templates are not found via ADL.
     20 using hwy::HWY_NAMESPACE::Add;
     21 using hwy::HWY_NAMESPACE::Mul;
     22 using hwy::HWY_NAMESPACE::MulAdd;
     23 using hwy::HWY_NAMESPACE::Vec;
     24 
// 5x5 convolution by separable kernel with a single scan through the input.
// This is more cache-efficient than separate horizontal/vertical passes, and
// possibly faster (given enough registers) than tiling and/or transposing.
//
// Overview: imagine a 5x5 window around a central pixel. First convolve the
// rows by multiplying the pixels with the corresponding weights from
// WeightsSeparable5.horz[abs(x_offset) * 4]. Then multiply each of these
// intermediate results by the corresponding vertical weight, i.e.
// vert[abs(y_offset) * 4]. Finally, store the sum of these values as the
// convolution result at the position of the central pixel in the output.
//
// Each of these operations uses SIMD vectors. The central pixel and, most
// importantly, the output are aligned, so neighboring pixels (e.g. x_offset=1)
// require unaligned loads. Because weights are supplied in identical groups of
// 4, we can use LoadDup128 to load them (slightly faster).
//
// Uses mirrored boundary handling. Until x >= kRadius, the horizontal
// convolution uses the Neighbors class to shuffle vectors as if each of their
// lanes had been loaded from the mirrored offset. Similarly, the last full
// vector to write uses mirroring. For scalar targets, Neighbors is not usable
// and each value is loaded directly from its mirrored index. Otherwise, the
// number of valid pixels modulo the vector size enables a small optimization:
// for smaller offsets, a non-mirrored load is sufficient.
class Separable5Strategy {
  // Descriptor for float vectors of at most 16 lanes, and the matching vector
  // type used throughout this strategy.
  using D = HWY_CAPPED(float, 16);
  using V = Vec<D>;

 public:
  // Kernel half-width: taps span offsets -kRadius..kRadius (5 per axis).
  static constexpr int64_t kRadius = 2;

  // Computes one output row of the 5x5 separable convolution.
  //
  // kSizeModN: xsize modulo the vector length (see the class comment); selects
  //   how the last full vector and the remaining tail pixels are handled.
  // WrapRow: callable invoked as wrap_row(ptr, stride) returning the row
  //   pointer to read for a neighbor row that may lie outside the image
  //   (presumably mirroring it back inside -- TODO confirm against caller).
  // row_m: center input row. Rows at offsets +-1 and +-2 are reached by adding
  //   multiples of `stride` (in floats) and are passed through wrap_row.
  // xsize: number of pixels in the row.
  // weights: horz/vert hold the weights for |offset| = 0, 1, 2 at indices
  //   0, 4 and 8, each supplied as a group of 4 identical values.
  // row_out: output row. It is written with aligned Store, so it must be
  //   vector-aligned.
  template <size_t kSizeModN, class WrapRow>
  static JXL_MAYBE_INLINE void ConvolveRow(
      const float* const JXL_RESTRICT row_m, const size_t xsize,
      const int64_t stride, const WrapRow& wrap_row,
      const WeightsSeparable5& weights, float* const JXL_RESTRICT row_out) {
    const D d;
    const int64_t neg_stride = -stride;  // allows LEA addressing.
    // Input rows two/one above (t) and one/two below (b) the center row.
    const float* const JXL_RESTRICT row_t2 =
        wrap_row(row_m + 2 * neg_stride, stride);
    const float* const JXL_RESTRICT row_t1 =
        wrap_row(row_m + 1 * neg_stride, stride);
    const float* const JXL_RESTRICT row_b1 =
        wrap_row(row_m + 1 * stride, stride);
    const float* const JXL_RESTRICT row_b2 =
        wrap_row(row_m + 2 * stride, stride);

    // Horizontal (wh*) and vertical (wv*) weights for |offset| = 0, 1, 2,
    // each 128-bit group duplicated across the whole vector.
    const V wh0 = LoadDup128(d, weights.horz + 0 * 4);
    const V wh1 = LoadDup128(d, weights.horz + 1 * 4);
    const V wh2 = LoadDup128(d, weights.horz + 2 * 4);
    const V wv0 = LoadDup128(d, weights.vert + 0 * 4);
    const V wv1 = LoadDup128(d, weights.vert + 1 * 4);
    const V wv2 = LoadDup128(d, weights.vert + 2 * 4);

    size_t x = 0;

    // Left border: pixels whose left neighbors must be mirrored.
    // More than one iteration for scalars.
    for (; x < kRadius; x += Lanes(d)) {
      // Vertical combination: center row * wv0, plus the symmetric row pairs
      // (t1 + b1) * wv1 and (t2 + b2) * wv2. The same pattern is repeated in
      // the loops below.
      const V conv0 =
          Mul(HorzConvolveFirst(row_m, x, xsize, wh0, wh1, wh2), wv0);

      const V conv1t = HorzConvolveFirst(row_t1, x, xsize, wh0, wh1, wh2);
      const V conv1b = HorzConvolveFirst(row_b1, x, xsize, wh0, wh1, wh2);
      const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0);

      const V conv2t = HorzConvolveFirst(row_t2, x, xsize, wh0, wh1, wh2);
      const V conv2b = HorzConvolveFirst(row_b2, x, xsize, wh0, wh1, wh2);
      const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1);
      Store(conv2, d, row_out + x);
    }

    // Main loop: load inputs without padding. It runs only while the whole
    // vector plus its kRadius right neighbors lies inside the row.
    for (; x + Lanes(d) + kRadius <= xsize; x += Lanes(d)) {
      const V conv0 = Mul(HorzConvolve(row_m + x, wh0, wh1, wh2), wv0);

      const V conv1t = HorzConvolve(row_t1 + x, wh0, wh1, wh2);
      const V conv1b = HorzConvolve(row_b1 + x, wh0, wh1, wh2);
      const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0);

      const V conv2t = HorzConvolve(row_t2 + x, wh0, wh1, wh2);
      const V conv2b = HorzConvolve(row_b2 + x, wh0, wh1, wh2);
      const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1);
      Store(conv2, d, row_out + x);
    }

    // Last full vector to write (the above loop handled mod >= kRadius).
    // On SIMD targets this runs at most once, only for kSizeModN of 0 or 1.
    // On the scalar target it processes every remaining pixel.
#if HWY_TARGET == HWY_SCALAR
    while (x < xsize) {
#else
    if (kSizeModN < kRadius) {
#endif
      const V conv0 =
          Mul(HorzConvolveLast<kSizeModN>(row_m, x, xsize, wh0, wh1, wh2), wv0);

      const V conv1t =
          HorzConvolveLast<kSizeModN>(row_t1, x, xsize, wh0, wh1, wh2);
      const V conv1b =
          HorzConvolveLast<kSizeModN>(row_b1, x, xsize, wh0, wh1, wh2);
      const V conv1 = MulAdd(Add(conv1t, conv1b), wv1, conv0);

      const V conv2t =
          HorzConvolveLast<kSizeModN>(row_t2, x, xsize, wh0, wh1, wh2);
      const V conv2b =
          HorzConvolveLast<kSizeModN>(row_b2, x, xsize, wh0, wh1, wh2);
      const V conv2 = MulAdd(Add(conv2t, conv2b), wv2, conv1);
      Store(conv2, d, row_out + x);
      x += Lanes(d);
    }

    // If mod = 0, the above vector was the last.
    // Otherwise, finish the remaining pixels one at a time. Columns are
    // mirrored with Mirror(); neighbor rows go through wrap_row.
    if (kSizeModN != 0) {
      for (; x < xsize; ++x) {
        float mul = 0.0f;
        for (int64_t dy = -kRadius; dy <= kRadius; ++dy) {
          const float wy = weights.vert[std::abs(dy) * 4];
          const float* clamped_row = wrap_row(row_m + dy * stride, stride);
          for (int64_t dx = -kRadius; dx <= kRadius; ++dx) {
            const float wx = weights.horz[std::abs(dx) * 4];
            const int64_t clamped_x = Mirror(x + dx, xsize);
            mul += clamped_row[clamped_x] * wx * wy;
          }
        }
        row_out[x] = mul;
      }
    }
  }

 private:
  // Same as HorzConvolve for the first/last vector in a row.

  // Horizontal 5-tap convolution of the vector starting at row + x, where
  // x < kRadius. Left neighbors are mirrored: via Mirror() on the scalar
  // target, or via Neighbors shuffles of c otherwise (which assumes c is the
  // row's first vector). Right neighbors are loaded directly, so at least
  // Lanes(d) + kRadius pixels must be readable from row + x.
  static JXL_MAYBE_INLINE V HorzConvolveFirst(
      const float* const JXL_RESTRICT row, const int64_t x, const int64_t xsize,
      const V wh0, const V wh1, const V wh2) {
    const D d;
    const V c = LoadU(d, row + x);
    const V mul0 = Mul(c, wh0);

#if HWY_TARGET == HWY_SCALAR
    const V l1 = LoadU(d, row + Mirror(x - 1, xsize));
    const V l2 = LoadU(d, row + Mirror(x - 2, xsize));
#else
    (void)xsize;
    const V l1 = Neighbors::FirstL1(c);
    const V l2 = Neighbors::FirstL2(c);
#endif

    const V r1 = LoadU(d, row + x + 1);
    const V r2 = LoadU(d, row + x + 2);

    const V mul1 = MulAdd(Add(l1, r1), wh1, mul0);
    const V mul2 = MulAdd(Add(l2, r2), wh2, mul1);
    return mul2;
  }

  // Horizontal 5-tap convolution of the last full vector starting at row + x.
  // Left neighbors are loaded directly (requires x >= kRadius). Right
  // neighbors past xsize are mirrored. On SIMD targets kSizeModN must be 0 or
  // 1: the `else` branch assumes exactly one pixel follows this vector.
  template <size_t kSizeModN>
  static JXL_MAYBE_INLINE V
  HorzConvolveLast(const float* const JXL_RESTRICT row, const int64_t x,
                   const int64_t xsize, const V wh0, const V wh1, const V wh2) {
    const D d;
    const V c = LoadU(d, row + x);
    const V mul0 = Mul(c, wh0);

    const V l1 = LoadU(d, row + x - 1);
    const V l2 = LoadU(d, row + x - 2);

    V r1;
    V r2;
#if HWY_TARGET == HWY_SCALAR
    r1 = LoadU(d, row + Mirror(x + 1, xsize));
    r2 = LoadU(d, row + Mirror(x + 2, xsize));
#else
    const size_t N = Lanes(d);
    if (kSizeModN == 0) {
      // c ends exactly at xsize: build the right neighbors by shuffling c
      // with mirrored lane indices.
      r2 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 2)));
      r1 = TableLookupLanes(c, SetTableIndices(d, MirrorLanes(N - 1)));
    } else {  // == 1
      // One pixel follows c: the last N pixels of the row are exactly r1,
      // and r2 is r1 shifted with its final lane mirrored.
      const auto last = LoadU(d, row + xsize - N);
      r2 = TableLookupLanes(last, SetTableIndices(d, MirrorLanes(N - 1)));
      r1 = last;
    }
#endif

    // Sum of pixels with Manhattan distance i, multiplied by weights[i].
    const V sum1 = Add(l1, r1);
    const V mul1 = MulAdd(sum1, wh1, mul0);
    const V sum2 = Add(l2, r2);
    const V mul2 = MulAdd(sum2, wh2, mul1);
    return mul2;
  }

  // Horizontal 5-tap convolution of the vector at pos using only unaligned
  // loads (no mirroring).
  // Requires kRadius valid pixels before/after pos.
  static JXL_MAYBE_INLINE V HorzConvolve(const float* const JXL_RESTRICT pos,
                                         const V wh0, const V wh1,
                                         const V wh2) {
    const D d;
    const V c = LoadU(d, pos);
    const V mul0 = Mul(c, wh0);

    // Loading anew is faster than combining vectors.
    const V l1 = LoadU(d, pos - 1);
    const V r1 = LoadU(d, pos + 1);
    const V l2 = LoadU(d, pos - 2);
    const V r2 = LoadU(d, pos + 2);
    // Sum of pixels with Manhattan distance i, multiplied by weights[i].
    const V sum1 = Add(l1, r1);
    const V mul1 = MulAdd(sum1, wh1, mul0);
    const V sum2 = Add(l2, r2);
    const V mul2 = MulAdd(sum2, wh2, mul1);
    return mul2;
  }
};
    234 
    235 void Separable5(const ImageF& in, const Rect& rect,
    236                 const WeightsSeparable5& weights, ThreadPool* pool,
    237                 ImageF* out) {
    238   using Conv = ConvolveT<Separable5Strategy>;
    239   if (rect.xsize() >= Conv::MinWidth()) {
    240     Conv::Run(in, rect, weights, pool, out);
    241     return;
    242   }
    243 
    244   SlowSeparable5(in, rect, weights, pool, out, Rect(*out));
    245 }
    246 
    247 // NOLINTNEXTLINE(google-readability-namespace-comments)
    248 }  // namespace HWY_NAMESPACE
    249 }  // namespace jxl
    250 HWY_AFTER_NAMESPACE();
    251 
    252 #if HWY_ONCE
    253 namespace jxl {
    254 
    255 HWY_EXPORT(Separable5);
    256 void Separable5(const ImageF& in, const Rect& rect,
    257                 const WeightsSeparable5& weights, ThreadPool* pool,
    258                 ImageF* out) {
    259   HWY_DYNAMIC_DISPATCH(Separable5)(in, rect, weights, pool, out);
    260 }
    261 
    262 }  // namespace jxl
    263 #endif  // HWY_ONCE