convolve-inl.h (11252B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #if defined(LIB_JXL_CONVOLVE_INL_H_) == defined(HWY_TARGET_TOGGLE) 7 #ifdef LIB_JXL_CONVOLVE_INL_H_ 8 #undef LIB_JXL_CONVOLVE_INL_H_ 9 #else 10 #define LIB_JXL_CONVOLVE_INL_H_ 11 #endif 12 13 #include <hwy/highway.h> 14 15 #include "lib/jxl/base/status.h" 16 #include "lib/jxl/image_ops.h" 17 18 HWY_BEFORE_NAMESPACE(); 19 namespace jxl { 20 namespace HWY_NAMESPACE { 21 namespace { 22 23 // These templates are not found via ADL. 24 using hwy::HWY_NAMESPACE::Broadcast; 25 #if HWY_TARGET != HWY_SCALAR 26 using hwy::HWY_NAMESPACE::CombineShiftRightBytes; 27 #endif 28 using hwy::HWY_NAMESPACE::TableLookupLanes; 29 using hwy::HWY_NAMESPACE::Vec; 30 31 // Synthesizes left/right neighbors from a vector of center pixels. 32 class Neighbors { 33 public: 34 using D = HWY_CAPPED(float, 16); 35 using V = Vec<D>; 36 37 // Returns l[i] == c[Mirror(i - 1)]. 38 HWY_INLINE HWY_MAYBE_UNUSED static V FirstL1(const V c) { 39 #if HWY_CAP_GE256 40 const D d; 41 HWY_ALIGN constexpr int32_t lanes[16] = {0, 0, 1, 2, 3, 4, 5, 6, 42 7, 8, 9, 10, 11, 12, 13, 14}; 43 const auto indices = SetTableIndices(d, lanes); 44 // c = PONM'LKJI 45 return TableLookupLanes(c, indices); // ONML'KJII 46 #elif HWY_TARGET == HWY_SCALAR 47 return c; // Same (the first mirrored value is the last valid one) 48 #else // 128 bit 49 // c = LKJI 50 #if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) 51 return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(2, 1, 0, 0))}; // KJII 52 #else 53 const D d; 54 // TODO(deymo): Figure out if this can be optimized using a single vsri 55 // instruction to convert LKJI to KJII. 56 HWY_ALIGN constexpr int lanes[4] = {0, 0, 1, 2}; // KJII 57 const auto indices = SetTableIndices(d, lanes); 58 return TableLookupLanes(c, indices); 59 #endif 60 #endif 61 } 62 63 // Returns l[i] == c[Mirror(i - 2)]. 64 HWY_INLINE HWY_MAYBE_UNUSED static V FirstL2(const V c) { 65 #if HWY_CAP_GE256 66 const D d; 67 HWY_ALIGN constexpr int32_t lanes[16] = {1, 0, 0, 1, 2, 3, 4, 5, 68 6, 7, 8, 9, 10, 11, 12, 13}; 69 const auto indices = SetTableIndices(d, lanes); 70 // c = PONM'LKJI 71 return TableLookupLanes(c, indices); // NMLK'JIIJ 72 #elif HWY_TARGET == HWY_SCALAR 73 const D d; 74 JXL_ASSERT(false); // unsupported, avoid calling this. 75 return Zero(d); 76 #else // 128 bit 77 // c = LKJI 78 #if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) 79 return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(1, 0, 0, 1))}; // JIIJ 80 #else 81 const D d; 82 HWY_ALIGN constexpr int lanes[4] = {1, 0, 0, 1}; // JIIJ 83 const auto indices = SetTableIndices(d, lanes); 84 return TableLookupLanes(c, indices); 85 #endif 86 #endif 87 } 88 89 // Returns l[i] == c[Mirror(i - 3)]. 90 HWY_INLINE HWY_MAYBE_UNUSED static V FirstL3(const V c) { 91 #if HWY_CAP_GE256 92 const D d; 93 HWY_ALIGN constexpr int32_t lanes[16] = {2, 1, 0, 0, 1, 2, 3, 4, 94 5, 6, 7, 8, 9, 10, 11, 12}; 95 const auto indices = SetTableIndices(d, lanes); 96 // c = PONM'LKJI 97 return TableLookupLanes(c, indices); // MLKJ'IIJK 98 #elif HWY_TARGET == HWY_SCALAR 99 const D d; 100 JXL_ASSERT(false); // unsupported, avoid calling this. 101 return Zero(d); 102 #else // 128 bit 103 // c = LKJI 104 #if HWY_TARGET <= (1 << HWY_HIGHEST_TARGET_BIT_X86) 105 return V{_mm_shuffle_ps(c.raw, c.raw, _MM_SHUFFLE(0, 0, 1, 2))}; // IIJK 106 #else 107 const D d; 108 HWY_ALIGN constexpr int lanes[4] = {2, 1, 0, 0}; // IIJK 109 const auto indices = SetTableIndices(d, lanes); 110 return TableLookupLanes(c, indices); 111 #endif 112 #endif 113 } 114 }; 115 116 #if HWY_TARGET != HWY_SCALAR 117 118 // Returns indices for SetTableIndices such that TableLookupLanes on the 119 // rightmost unaligned vector (rightmost sample in its most-significant lane) 120 // returns the mirrored values, with the mirror outside the last valid sample. 121 inline const int32_t* MirrorLanes(const size_t mod) { 122 const HWY_CAPPED(float, 16) d; 123 constexpr size_t kN = MaxLanes(d); 124 125 // For mod = `image width mod 16` 0..15: 126 // last full vec mirrored (mem order) loadedVec mirrorVec idxVec 127 // 0123456789abcdef| fedcba9876543210 fed..210 012..def 012..def 128 // 0123456789abcdef|0 0fedcba98765432 0fe..321 234..f00 123..eff 129 // 0123456789abcdef|01 10fedcba987654 10f..432 456..110 234..ffe 130 // 0123456789abcdef|012 210fedcba9876 210..543 67..2210 34..ffed 131 // 0123456789abcdef|0123 3210fedcba98 321..654 8..33210 4..ffedc 132 // 0123456789abcdef|01234 43210fedcba 133 // 0123456789abcdef|012345 543210fedc 134 // 0123456789abcdef|0123456 6543210fe 135 // 0123456789abcdef|01234567 76543210 136 // 0123456789abcdef|012345678 8765432 137 // 0123456789abcdef|0123456789 987654 138 // 0123456789abcdef|0123456789A A9876 139 // 0123456789abcdef|0123456789AB BA98 140 // 0123456789abcdef|0123456789ABC CBA 141 // 0123456789abcdef|0123456789ABCD DC 142 // 0123456789abcdef|0123456789ABCDE E EDC..10f EED..210 ffe..321 143 #if HWY_CAP_GE512 144 HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = { 145 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, // 146 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; 147 #elif HWY_CAP_GE256 148 HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = { 149 1, 2, 3, 4, 5, 6, 7, 7, // 150 6, 5, 4, 3, 2, 1, 0}; 151 #else // 128-bit 152 HWY_ALIGN static constexpr int32_t idx_lanes[2 * kN - 1] = {1, 2, 3, 3, // 153 2, 1, 0}; 154 #endif 155 return idx_lanes + kN - 1 - mod; 156 } 157 158 #endif // HWY_TARGET != HWY_SCALAR 159 160 // Single entry point for convolution. 161 // "Strategy" (Direct*/Separable*) decides kernel size and how to evaluate it. 162 template <class Strategy> 163 class ConvolveT { 164 static constexpr int64_t kRadius = Strategy::kRadius; 165 using Simd = HWY_CAPPED(float, 16); 166 167 public: 168 static size_t MinWidth() { 169 #if HWY_TARGET == HWY_SCALAR 170 // First/Last use mirrored loads of up to +/- kRadius. 171 return 2 * kRadius; 172 #else 173 return Lanes(Simd()) + kRadius; 174 #endif 175 } 176 177 // "Image" is ImageF or Image3F. 178 template <class Image, class Weights> 179 static void Run(const Image& in, const Rect& rect, const Weights& weights, 180 ThreadPool* pool, Image* out) { 181 JXL_CHECK(SameSize(rect, *out)); 182 JXL_CHECK(rect.xsize() >= MinWidth()); 183 184 static_assert(static_cast<int64_t>(kRadius) <= 3, 185 "Must handle [0, kRadius) and >= kRadius"); 186 switch (rect.xsize() % Lanes(Simd())) { 187 case 0: 188 return RunRows<0>(in, rect, weights, pool, out); 189 case 1: 190 return RunRows<1>(in, rect, weights, pool, out); 191 case 2: 192 return RunRows<2>(in, rect, weights, pool, out); 193 default: 194 return RunRows<3>(in, rect, weights, pool, out); 195 } 196 } 197 198 private: 199 template <size_t kSizeModN, class WrapRow, class Weights> 200 static JXL_INLINE void RunRow(const float* JXL_RESTRICT in, 201 const size_t xsize, const int64_t stride, 202 const WrapRow& wrap_row, const Weights& weights, 203 float* JXL_RESTRICT out) { 204 Strategy::template ConvolveRow<kSizeModN>(in, xsize, stride, wrap_row, 205 weights, out); 206 } 207 208 template <size_t kSizeModN, class Weights> 209 static JXL_INLINE void RunBorderRows(const ImageF& in, const Rect& rect, 210 const int64_t ybegin, const int64_t yend, 211 const Weights& weights, ImageF* out) { 212 const int64_t stride = in.PixelsPerRow(); 213 const WrapRowMirror wrap_row(in, rect.ysize()); 214 for (int64_t y = ybegin; y < yend; ++y) { 215 RunRow<kSizeModN>(rect.ConstRow(in, y), rect.xsize(), stride, wrap_row, 216 weights, out->Row(y)); 217 } 218 } 219 220 // Image3F. 221 template <size_t kSizeModN, class Weights> 222 static JXL_INLINE void RunBorderRows(const Image3F& in, const Rect& rect, 223 const int64_t ybegin, const int64_t yend, 224 const Weights& weights, Image3F* out) { 225 const int64_t stride = in.PixelsPerRow(); 226 for (int64_t y = ybegin; y < yend; ++y) { 227 for (size_t c = 0; c < 3; ++c) { 228 const WrapRowMirror wrap_row(in.Plane(c), rect.ysize()); 229 RunRow<kSizeModN>(rect.ConstPlaneRow(in, c, y), rect.xsize(), stride, 230 wrap_row, weights, out->PlaneRow(c, y)); 231 } 232 } 233 } 234 235 template <size_t kSizeModN, class Weights> 236 static JXL_INLINE void RunInteriorRows(const ImageF& in, const Rect& rect, 237 const int64_t ybegin, 238 const int64_t yend, 239 const Weights& weights, 240 ThreadPool* pool, ImageF* out) { 241 const int64_t stride = in.PixelsPerRow(); 242 JXL_CHECK(RunOnPool( 243 pool, ybegin, yend, ThreadPool::NoInit, 244 [&](const uint32_t y, size_t /*thread*/) HWY_ATTR { 245 RunRow<kSizeModN>(rect.ConstRow(in, y), rect.xsize(), stride, 246 WrapRowUnchanged(), weights, out->Row(y)); 247 }, 248 "Convolve")); 249 } 250 251 // Image3F. 252 template <size_t kSizeModN, class Weights> 253 static JXL_INLINE void RunInteriorRows(const Image3F& in, const Rect& rect, 254 const int64_t ybegin, 255 const int64_t yend, 256 const Weights& weights, 257 ThreadPool* pool, Image3F* out) { 258 const int64_t stride = in.PixelsPerRow(); 259 JXL_CHECK(RunOnPool( 260 pool, ybegin, yend, ThreadPool::NoInit, 261 [&](const uint32_t y, size_t /*thread*/) HWY_ATTR { 262 for (size_t c = 0; c < 3; ++c) { 263 RunRow<kSizeModN>(rect.ConstPlaneRow(in, c, y), rect.xsize(), 264 stride, WrapRowUnchanged(), weights, 265 out->PlaneRow(c, y)); 266 } 267 }, 268 "Convolve3")); 269 } 270 271 template <size_t kSizeModN, class Image, class Weights> 272 static JXL_INLINE void RunRows(const Image& in, const Rect& rect, 273 const Weights& weights, ThreadPool* pool, 274 Image* out) { 275 const int64_t ysize = rect.ysize(); 276 RunBorderRows<kSizeModN>(in, rect, 0, 277 std::min(static_cast<int64_t>(kRadius), ysize), 278 weights, out); 279 if (ysize > 2 * static_cast<int64_t>(kRadius)) { 280 RunInteriorRows<kSizeModN>(in, rect, static_cast<int64_t>(kRadius), 281 ysize - static_cast<int64_t>(kRadius), weights, 282 pool, out); 283 } 284 if (ysize > static_cast<int64_t>(kRadius)) { 285 RunBorderRows<kSizeModN>(in, rect, ysize - static_cast<int64_t>(kRadius), 286 ysize, weights, out); 287 } 288 } 289 }; 290 291 } // namespace 292 // NOLINTNEXTLINE(google-readability-namespace-comments) 293 } // namespace HWY_NAMESPACE 294 } // namespace jxl 295 HWY_AFTER_NAMESPACE(); 296 297 #endif // LIB_JXL_CONVOLVE_INL_H_