libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

convolve-inl.h (11252B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #if defined(LIB_JXL_CONVOLVE_INL_H_) == defined(HWY_TARGET_TOGGLE)
      7 #ifdef LIB_JXL_CONVOLVE_INL_H_
      8 #undef LIB_JXL_CONVOLVE_INL_H_
      9 #else
     10 #define LIB_JXL_CONVOLVE_INL_H_
     11 #endif
     12 
     13 #include <hwy/highway.h>
     14 
     15 #include "lib/jxl/base/status.h"
     16 #include "lib/jxl/image_ops.h"
     17 
     18 HWY_BEFORE_NAMESPACE();
     19 namespace jxl {
     20 namespace HWY_NAMESPACE {
     21 namespace {
     22 
     23 // These templates are not found via ADL.
     24 using hwy::HWY_NAMESPACE::Broadcast;
     25 #if HWY_TARGET != HWY_SCALAR
     26 using hwy::HWY_NAMESPACE::CombineShiftRightBytes;
     27 #endif
     28 using hwy::HWY_NAMESPACE::TableLookupLanes;
     29 using hwy::HWY_NAMESPACE::Vec;
     30 
     31 // Synthesizes left/right neighbors from a vector of center pixels.
     32 class Neighbors {
     33  public:
     34   using D = HWY_CAPPED(float, 16);
     35   using V = Vec<D>;
     36 
     37   // Returns l[i] == c[Mirror(i - 1)].
     38   HWY_INLINE HWY_MAYBE_UNUSED static V FirstL1(const V c) {
     39 #if HWY_CAP_GE256
     40     const D d;
     41     HWY_ALIGN constexpr int32_t lanes[16] = {0, 0, 1, 2,  3,  4,  5,  6,
     42                                              7, 8, 9, 10, 11, 12, 13, 14};
     43     const auto indices = SetTableIndices(d, lanes);
     44     // c = PONM'LKJI
     45     return TableLookupLanes(c, indices);  // ONML'KJII
     46 #elif HWY_TARGET == HWY_SCALAR
     47     return c;  // Same (the first mirrored value is the last valid one)
     48 #else  // 128 bit
     49     // c = LKJI
     50 #if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86)
     51     return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(2, 1, 0, 0))};  // KJII
     52 #else
     53     const D d;
     54     // TODO(deymo): Figure out if this can be optimized using a single vsri
     55     // instruction to convert LKJI to KJII.
     56     HWY_ALIGN constexpr int lanes[4] = {0, 0, 1, 2};  // KJII
     57     const auto indices = SetTableIndices(d, lanes);
     58     return TableLookupLanes(c, indices);
     59 #endif
     60 #endif
     61   }
     62 
     63   // Returns l[i] == c[Mirror(i - 2)].
     64   HWY_INLINE HWY_MAYBE_UNUSED static V FirstL2(const V c) {
     65 #if HWY_CAP_GE256
     66     const D d;
     67     HWY_ALIGN constexpr int32_t lanes[16] = {1, 0, 0, 1, 2,  3,  4,  5,
     68                                              6, 7, 8, 9, 10, 11, 12, 13};
     69     const auto indices = SetTableIndices(d, lanes);
     70     // c = PONM'LKJI
     71     return TableLookupLanes(c, indices);  // NMLK'JIIJ
     72 #elif HWY_TARGET == HWY_SCALAR
     73     const D d;
     74     JXL_ASSERT(false);  // unsupported, avoid calling this.
     75     return Zero(d);
     76 #else  // 128 bit
     77     // c = LKJI
     78 #if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86)
     79     return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(1, 0, 0, 1))};  // JIIJ
     80 #else
     81     const D d;
     82     HWY_ALIGN constexpr int lanes[4] = {1, 0, 0, 1};  // JIIJ
     83     const auto indices = SetTableIndices(d, lanes);
     84     return TableLookupLanes(c, indices);
     85 #endif
     86 #endif
     87   }
     88 
     89   // Returns l[i] == c[Mirror(i - 3)].
     90   HWY_INLINE HWY_MAYBE_UNUSED static V FirstL3(const V c) {
     91 #if HWY_CAP_GE256
     92     const D d;
     93     HWY_ALIGN constexpr int32_t lanes[16] = {2, 1, 0, 0, 1, 2,  3,  4,
     94                                              5, 6, 7, 8, 9, 10, 11, 12};
     95     const auto indices = SetTableIndices(d, lanes);
     96     // c = PONM'LKJI
     97     return TableLookupLanes(c, indices);  // MLKJ'IIJK
     98 #elif HWY_TARGET == HWY_SCALAR
     99     const D d;
    100     JXL_ASSERT(false);  // unsupported, avoid calling this.
    101     return Zero(d);
    102 #else  // 128 bit
    103     // c = LKJI
    104 #if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86)
    105     return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(0, 0, 1, 2))};  // IIJK
    106 #else
    107     const D d;
    108     HWY_ALIGN constexpr int lanes[4] = {2, 1, 0, 0};  // IIJK
    109     const auto indices = SetTableIndices(d, lanes);
    110     return TableLookupLanes(c, indices);
    111 #endif
    112 #endif
    113   }
    114 };
    115 
    116 #if HWY_TARGET != HWY_SCALAR
    117 
    118 // Returns indices for SetTableIndices such that TableLookupLanes on the
    119 // rightmost unaligned vector (rightmost sample in its most-significant lane)
    120 // returns the mirrored values, with the mirror outside the last valid sample.
    121 inline const int32_t* MirrorLanes(const size_t mod) {
    122   const HWY_CAPPED(float, 16) d;
    123   constexpr size_t kN = MaxLanes(d);
    124 
    125   // For mod = `image width mod 16` 0..15:
    126   // last full vec     mirrored (mem order)  loadedVec  mirrorVec  idxVec
    127   // 0123456789abcdef| fedcba9876543210      fed..210   012..def   012..def
    128   // 0123456789abcdef|0 0fedcba98765432      0fe..321   234..f00   123..eff
    129   // 0123456789abcdef|01 10fedcba987654      10f..432   456..110   234..ffe
    130   // 0123456789abcdef|012 210fedcba9876      210..543   67..2210   34..ffed
    131   // 0123456789abcdef|0123 3210fedcba98      321..654   8..33210   4..ffedc
    132   // 0123456789abcdef|01234 43210fedcba
    133   // 0123456789abcdef|012345 543210fedc
    134   // 0123456789abcdef|0123456 6543210fe
    135   // 0123456789abcdef|01234567 76543210
    136   // 0123456789abcdef|012345678 8765432
    137   // 0123456789abcdef|0123456789 987654
    138   // 0123456789abcdef|0123456789A A9876
    139   // 0123456789abcdef|0123456789AB BA98
    140   // 0123456789abcdef|0123456789ABC CBA
    141   // 0123456789abcdef|0123456789ABCD DC
    142   // 0123456789abcdef|0123456789ABCDE E      EDC..10f   EED..210   ffe..321
    143 #if HWY_CAP_GE512
    144   HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = {
    145       1,  2,  3,  4,  5,  6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15,  //
    146       14, 13, 12, 11, 10, 9, 8, 7, 6, 5,  4,  3,  2,  1,  0};
    147 #elif HWY_CAP_GE256
    148   HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = {
    149       1, 2, 3, 4, 5, 6, 7, 7,  //
    150       6, 5, 4, 3, 2, 1, 0};
    151 #else  // 128-bit
    152   HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = {1, 2, 3, 3,  //
    153                                                               2, 1, 0};
    154 #endif
    155   return idx_lanes + kN - 1 - mod;
    156 }
    157 
    158 #endif  // HWY_TARGET != HWY_SCALAR
    159 
    160 // Single entry point for convolution.
    161 // "Strategy" (Direct*/Separable*) decides kernel size and how to evaluate it.
    162 template <class Strategy>
    163 class ConvolveT {
    164   static constexpr int64_t kRadius = Strategy::kRadius;
    165   using Simd = HWY_CAPPED(float, 16);
    166 
    167  public:
    168   static size_t MinWidth() {
    169 #if HWY_TARGET == HWY_SCALAR
    170     // First/Last use mirrored loads of up to +/- kRadius.
    171     return 2 * kRadius;
    172 #else
    173     return Lanes(Simd()) + kRadius;
    174 #endif
    175   }
    176 
    177   // "Image" is ImageF or Image3F.
    178   template <class Image, class Weights>
    179   static void Run(const Image& in, const Rect& rect, const Weights& weights,
    180                   ThreadPool* pool, Image* out) {
    181     JXL_CHECK(SameSize(rect, *out));
    182     JXL_CHECK(rect.xsize() >= MinWidth());
    183 
    184     static_assert(static_cast<int64_t>(kRadius) <= 3,
    185                   "Must handle [0, kRadius) and >= kRadius");
    186     switch (rect.xsize() % Lanes(Simd())) {
    187       case 0:
    188         return RunRows<0>(in, rect, weights, pool, out);
    189       case 1:
    190         return RunRows<1>(in, rect, weights, pool, out);
    191       case 2:
    192         return RunRows<2>(in, rect, weights, pool, out);
    193       default:
    194         return RunRows<3>(in, rect, weights, pool, out);
    195     }
    196   }
    197 
    198  private:
    199   template <size_t kSizeModN, class WrapRow, class Weights>
    200   static JXL_INLINE void RunRow(const float* JXL_RESTRICT in,
    201                                 const size_t xsize, const int64_t stride,
    202                                 const WrapRow& wrap_row, const Weights& weights,
    203                                 float* JXL_RESTRICT out) {
    204     Strategy::template ConvolveRow<kSizeModN>(in, xsize, stride, wrap_row,
    205                                               weights, out);
    206   }
    207 
    208   template <size_t kSizeModN, class Weights>
    209   static JXL_INLINE void RunBorderRows(const ImageF& in, const Rect& rect,
    210                                        const int64_t ybegin, const int64_t yend,
    211                                        const Weights& weights, ImageF* out) {
    212     const int64_t stride = in.PixelsPerRow();
    213     const WrapRowMirror wrap_row(in, rect.ysize());
    214     for (int64_t y = ybegin; y < yend; ++y) {
    215       RunRow<kSizeModN>(rect.ConstRow(in, y), rect.xsize(), stride, wrap_row,
    216                         weights, out->Row(y));
    217     }
    218   }
    219 
    220   // Image3F.
    221   template <size_t kSizeModN, class Weights>
    222   static JXL_INLINE void RunBorderRows(const Image3F& in, const Rect& rect,
    223                                        const int64_t ybegin, const int64_t yend,
    224                                        const Weights& weights, Image3F* out) {
    225     const int64_t stride = in.PixelsPerRow();
    226     for (int64_t y = ybegin; y < yend; ++y) {
    227       for (size_t c = 0; c < 3; ++c) {
    228         const WrapRowMirror wrap_row(in.Plane(c), rect.ysize());
    229         RunRow<kSizeModN>(rect.ConstPlaneRow(in, c, y), rect.xsize(), stride,
    230                           wrap_row, weights, out->PlaneRow(c, y));
    231       }
    232     }
    233   }
    234 
    235   template <size_t kSizeModN, class Weights>
    236   static JXL_INLINE void RunInteriorRows(const ImageF& in, const Rect& rect,
    237                                          const int64_t ybegin,
    238                                          const int64_t yend,
    239                                          const Weights& weights,
    240                                          ThreadPool* pool, ImageF* out) {
    241     const int64_t stride = in.PixelsPerRow();
    242     JXL_CHECK(RunOnPool(
    243         pool, ybegin, yend, ThreadPool::NoInit,
    244         [&](const uint32_t y, size_t /*thread*/) HWY_ATTR {
    245           RunRow<kSizeModN>(rect.ConstRow(in, y), rect.xsize(), stride,
    246                             WrapRowUnchanged(), weights, out->Row(y));
    247         },
    248         "Convolve"));
    249   }
    250 
    251   // Image3F.
    252   template <size_t kSizeModN, class Weights>
    253   static JXL_INLINE void RunInteriorRows(const Image3F& in, const Rect& rect,
    254                                          const int64_t ybegin,
    255                                          const int64_t yend,
    256                                          const Weights& weights,
    257                                          ThreadPool* pool, Image3F* out) {
    258     const int64_t stride = in.PixelsPerRow();
    259     JXL_CHECK(RunOnPool(
    260         pool, ybegin, yend, ThreadPool::NoInit,
    261         [&](const uint32_t y, size_t /*thread*/) HWY_ATTR {
    262           for (size_t c = 0; c < 3; ++c) {
    263             RunRow<kSizeModN>(rect.ConstPlaneRow(in, c, y), rect.xsize(),
    264                               stride, WrapRowUnchanged(), weights,
    265                               out->PlaneRow(c, y));
    266           }
    267         },
    268         "Convolve3"));
    269   }
    270 
    271   template <size_t kSizeModN, class Image, class Weights>
    272   static JXL_INLINE void RunRows(const Image& in, const Rect& rect,
    273                                  const Weights& weights, ThreadPool* pool,
    274                                  Image* out) {
    275     const int64_t ysize = rect.ysize();
    276     RunBorderRows<kSizeModN>(in, rect, 0,
    277                              std::min(static_cast<int64_t>(kRadius), ysize),
    278                              weights, out);
    279     if (ysize > 2 * static_cast<int64_t>(kRadius)) {
    280       RunInteriorRows<kSizeModN>(in, rect, static_cast<int64_t>(kRadius),
    281                                  ysize - static_cast<int64_t>(kRadius), weights,
    282                                  pool, out);
    283     }
    284     if (ysize > static_cast<int64_t>(kRadius)) {
    285       RunBorderRows<kSizeModN>(in, rect, ysize - static_cast<int64_t>(kRadius),
    286                                ysize, weights, out);
    287     }
    288   }
    289 };
    290 
    291 }  // namespace
    292 // NOLINTNEXTLINE(google-readability-namespace-comments)
    293 }  // namespace HWY_NAMESPACE
    294 }  // namespace jxl
    295 HWY_AFTER_NAMESPACE();
    296 
    297 #endif  // LIB_JXL_CONVOLVE_INL_H_