dct-inl.h
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Fast SIMD floating-point (I)DCT, any power of two.

#if defined(LIB_JXL_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE)
#ifdef LIB_JXL_DCT_INL_H_
#undef LIB_JXL_DCT_INL_H_
#else
#define LIB_JXL_DCT_INL_H_
#endif

#include <stddef.h>

#include <hwy/highway.h>

#include "lib/jxl/dct_block-inl.h"
#include "lib/jxl/dct_scales.h"
#include "lib/jxl/transpose-inl.h"
HWY_BEFORE_NAMESPACE();
namespace jxl {
namespace HWY_NAMESPACE {
namespace {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::Add;
using hwy::HWY_NAMESPACE::Mul;
using hwy::HWY_NAMESPACE::MulAdd;
using hwy::HWY_NAMESPACE::NegMulAdd;
using hwy::HWY_NAMESPACE::Sub;

template <size_t SZ>
struct FVImpl {
  using type = HWY_CAPPED(float, SZ);
};

template <>
struct FVImpl<0> {
  using type = HWY_FULL(float);
};

template <size_t SZ>
using FV = typename FVImpl<SZ>::type;

// Implementation of Lowest Complexity Self Recursive Radix-2 DCT II/III
// Algorithms, by Sirani M. Perera and Jianhua Liu.

template <size_t N, size_t SZ>
struct CoeffBundle {
  static void AddReverse(const float* JXL_RESTRICT ain1,
                         const float* JXL_RESTRICT ain2,
                         float* JXL_RESTRICT aout) {
    for (size_t i = 0; i < N; i++) {
      auto in1 = Load(FV<SZ>(), ain1 + i * SZ);
      auto in2 = Load(FV<SZ>(), ain2 + (N - i - 1) * SZ);
      Store(Add(in1, in2), FV<SZ>(), aout + i * SZ);
    }
  }
  static void SubReverse(const float* JXL_RESTRICT ain1,
                         const float* JXL_RESTRICT ain2,
                         float* JXL_RESTRICT aout) {
    for (size_t i = 0; i < N; i++) {
      auto in1 = Load(FV<SZ>(), ain1 + i * SZ);
      auto in2 = Load(FV<SZ>(), ain2 + (N - i - 1) * SZ);
      Store(Sub(in1, in2), FV<SZ>(), aout + i * SZ);
    }
  }
  static void B(float* JXL_RESTRICT coeff) {
    auto sqrt2 = Set(FV<SZ>(), kSqrt2);
    auto in1 = Load(FV<SZ>(), coeff);
    auto in2 = Load(FV<SZ>(), coeff + SZ);
    Store(MulAdd(in1, sqrt2, in2), FV<SZ>(), coeff);
    for (size_t i = 1; i + 1 < N; i++) {
      auto in1 = Load(FV<SZ>(), coeff + i * SZ);
      auto in2 = Load(FV<SZ>(), coeff + (i + 1) * SZ);
      Store(Add(in1, in2), FV<SZ>(), coeff + i * SZ);
    }
  }
  static void BTranspose(float* JXL_RESTRICT coeff) {
    for (size_t i = N - 1; i > 0; i--) {
      auto in1 = Load(FV<SZ>(), coeff + i * SZ);
      auto in2 = Load(FV<SZ>(), coeff + (i - 1) * SZ);
      Store(Add(in1, in2), FV<SZ>(), coeff + i * SZ);
    }
    auto sqrt2 = Set(FV<SZ>(), kSqrt2);
    auto in1 = Load(FV<SZ>(), coeff);
    Store(Mul(in1, sqrt2), FV<SZ>(), coeff);
  }
  // Ideally optimized away by compiler (except the multiply).
  static void InverseEvenOdd(const float* JXL_RESTRICT ain,
                             float* JXL_RESTRICT aout) {
    for (size_t i = 0; i < N / 2; i++) {
      auto in1 = Load(FV<SZ>(), ain + i * SZ);
      Store(in1, FV<SZ>(), aout + 2 * i * SZ);
    }
    for (size_t i = N / 2; i < N; i++) {
      auto in1 = Load(FV<SZ>(), ain + i * SZ);
      Store(in1, FV<SZ>(), aout + (2 * (i - N / 2) + 1) * SZ);
    }
  }
  // Ideally optimized away by compiler.
  static void ForwardEvenOdd(const float* JXL_RESTRICT ain, size_t ain_stride,
                             float* JXL_RESTRICT aout) {
    for (size_t i = 0; i < N / 2; i++) {
      auto in1 = LoadU(FV<SZ>(), ain + 2 * i * ain_stride);
      Store(in1, FV<SZ>(), aout + i * SZ);
    }
    for (size_t i = N / 2; i < N; i++) {
      auto in1 = LoadU(FV<SZ>(), ain + (2 * (i - N / 2) + 1) * ain_stride);
      Store(in1, FV<SZ>(), aout + i * SZ);
    }
  }
  // Invoked on full vector.
  static void Multiply(float* JXL_RESTRICT coeff) {
    for (size_t i = 0; i < N / 2; i++) {
      auto in1 = Load(FV<SZ>(), coeff + (N / 2 + i) * SZ);
      auto mul = Set(FV<SZ>(), WcMultipliers<N>::kMultipliers[i]);
      Store(Mul(in1, mul), FV<SZ>(), coeff + (N / 2 + i) * SZ);
    }
  }
  static void MultiplyAndAdd(const float* JXL_RESTRICT coeff,
                             float* JXL_RESTRICT out, size_t out_stride) {
    for (size_t i = 0; i < N / 2; i++) {
      auto mul = Set(FV<SZ>(), WcMultipliers<N>::kMultipliers[i]);
      auto in1 = Load(FV<SZ>(), coeff + i * SZ);
      auto in2 = Load(FV<SZ>(), coeff + (N / 2 + i) * SZ);
      auto out1 = MulAdd(mul, in2, in1);
      auto out2 = NegMulAdd(mul, in2, in1);
      StoreU(out1, FV<SZ>(), out + i * out_stride);
      StoreU(out2, FV<SZ>(), out + (N - i - 1) * out_stride);
    }
  }
  template <typename Block>
  static void LoadFromBlock(const Block& in, size_t off,
                            float* JXL_RESTRICT coeff) {
    for (size_t i = 0; i < N; i++) {
      Store(in.LoadPart(FV<SZ>(), i, off), FV<SZ>(), coeff + i * SZ);
    }
  }
  template <typename Block>
  static void StoreToBlockAndScale(const float* JXL_RESTRICT coeff,
                                   const Block& out, size_t off) {
    auto mul = Set(FV<SZ>(), 1.0f / N);
    for (size_t i = 0; i < N; i++) {
      out.StorePart(FV<SZ>(), Mul(mul, Load(FV<SZ>(), coeff + i * SZ)), i,
                    off);
    }
  }
};

template <size_t N, size_t SZ>
struct DCT1DImpl;

template <size_t SZ>
struct DCT1DImpl<1, SZ> {
  JXL_INLINE void operator()(float* JXL_RESTRICT mem, float* /* tmp */) {}
};

template <size_t SZ>
struct DCT1DImpl<2, SZ> {
  JXL_INLINE void operator()(float* JXL_RESTRICT mem, float* /* tmp */) {
    auto in1 = Load(FV<SZ>(), mem);
    auto in2 = Load(FV<SZ>(), mem + SZ);
    Store(Add(in1, in2), FV<SZ>(), mem);
    Store(Sub(in1, in2), FV<SZ>(), mem + SZ);
  }
};

template <size_t N, size_t SZ>
struct DCT1DImpl {
  void operator()(float* JXL_RESTRICT mem, float* JXL_RESTRICT tmp) {
    CoeffBundle<N / 2, SZ>::AddReverse(mem, mem + N / 2 * SZ, tmp);
    DCT1DImpl<N / 2, SZ>()(tmp, tmp + N * SZ);
    CoeffBundle<N / 2, SZ>::SubReverse(mem, mem + N / 2 * SZ, tmp + N / 2 * SZ);
    CoeffBundle<N, SZ>::Multiply(tmp);
    DCT1DImpl<N / 2, SZ>()(tmp + N / 2 * SZ, tmp + N * SZ);
    CoeffBundle<N / 2, SZ>::B(tmp + N / 2 * SZ);
    CoeffBundle<N, SZ>::InverseEvenOdd(tmp, mem);
  }
};
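// The following is an illustrative paraphrase (added for exposition, not
// part of the original header) of the radix-2 recursion above, for one
// column of SZ lanes and transform size N:
//
//   t[i]       = x[i] + x[N-1-i];                      // AddReverse
//   t[N/2 + i] = (x[i] - x[N-1-i]) * kMultipliers[i];  // SubReverse+Multiply
//   DCT<N/2>(t);                    // becomes the even output coefficients
//   DCT<N/2>(t + N/2); B(t + N/2);  // becomes the odd output coefficients
//   x = interleave(t, t + N/2);                        // InverseEvenOdd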
template <size_t N, size_t SZ>
struct IDCT1DImpl;

template <size_t SZ>
struct IDCT1DImpl<1, SZ> {
  JXL_INLINE void operator()(const float* from, size_t from_stride, float* to,
                             size_t to_stride, float* JXL_RESTRICT /* tmp */) {
    StoreU(LoadU(FV<SZ>(), from), FV<SZ>(), to);
  }
};

template <size_t SZ>
struct IDCT1DImpl<2, SZ> {
  JXL_INLINE void operator()(const float* from, size_t from_stride, float* to,
                             size_t to_stride, float* JXL_RESTRICT /* tmp */) {
    JXL_DASSERT(from_stride >= SZ);
    JXL_DASSERT(to_stride >= SZ);
    auto in1 = LoadU(FV<SZ>(), from);
    auto in2 = LoadU(FV<SZ>(), from + from_stride);
    StoreU(Add(in1, in2), FV<SZ>(), to);
    StoreU(Sub(in1, in2), FV<SZ>(), to + to_stride);
  }
};

template <size_t N, size_t SZ>
struct IDCT1DImpl {
  void operator()(const float* from, size_t from_stride, float* to,
                  size_t to_stride, float* JXL_RESTRICT tmp) {
    JXL_DASSERT(from_stride >= SZ);
    JXL_DASSERT(to_stride >= SZ);
    CoeffBundle<N, SZ>::ForwardEvenOdd(from, from_stride, tmp);
    IDCT1DImpl<N / 2, SZ>()(tmp, SZ, tmp, SZ, tmp + N * SZ);
    CoeffBundle<N / 2, SZ>::BTranspose(tmp + N / 2 * SZ);
    IDCT1DImpl<N / 2, SZ>()(tmp + N / 2 * SZ, SZ, tmp + N / 2 * SZ, SZ,
                            tmp + N * SZ);
    CoeffBundle<N, SZ>::MultiplyAndAdd(tmp, to, to_stride);
  }
};
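// Note on scratch sizing, derived from the recursion above (the original
// header does not state this as a contract): an [I]DCT1DImpl of size N
// writes N vectors of SZ lanes at the start of `tmp` and recurses into
// tmp + N * SZ, and the size-1 and size-2 base cases use no scratch, so the
// tmp area needs
//   (N + N/2 + ... + 4) * SZ = (2 * N - 4) * SZ
// floats. DCT1DWrapper below additionally keeps its N * SZ working copy at
// the start of tmp, so there (3 * N - 4) * SZ floats are needed in total.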
template <size_t N, size_t M_or_0, typename FromBlock, typename ToBlock>
void DCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp,
                  float* JXL_RESTRICT tmp) {
  size_t M = M_or_0 != 0 ? M_or_0 : Mp;
  constexpr size_t SZ = MaxLanes(FV<M_or_0>());
  for (size_t i = 0; i < M; i += Lanes(FV<M_or_0>())) {
    // TODO(veluca): consider removing the temporary memory here (as is done in
    // IDCT), if it turns out that some compilers don't optimize away the loads
    // and this is performance-critical.
    CoeffBundle<N, SZ>::LoadFromBlock(from, i, tmp);
    DCT1DImpl<N, SZ>()(tmp, tmp + N * SZ);
    CoeffBundle<N, SZ>::StoreToBlockAndScale(tmp, to, i);
  }
}

template <size_t N, size_t M_or_0, typename FromBlock, typename ToBlock>
void IDCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp,
                   float* JXL_RESTRICT tmp) {
  size_t M = M_or_0 != 0 ? M_or_0 : Mp;
  constexpr size_t SZ = MaxLanes(FV<M_or_0>());
  for (size_t i = 0; i < M; i += Lanes(FV<M_or_0>())) {
    IDCT1DImpl<N, SZ>()(from.Address(0, i), from.Stride(), to.Address(0, i),
                        to.Stride(), tmp);
  }
}

template <size_t N, size_t M, typename = void>
struct DCT1D {
  template <typename FromBlock, typename ToBlock>
  void operator()(const FromBlock& from, const ToBlock& to,
                  float* JXL_RESTRICT tmp) {
    return DCT1DWrapper<N, M>(from, to, M, tmp);
  }
};

template <size_t N, size_t M>
struct DCT1D<N, M, typename std::enable_if<(M > MaxLanes(FV<0>()))>::type> {
  template <typename FromBlock, typename ToBlock>
  void operator()(const FromBlock& from, const ToBlock& to,
                  float* JXL_RESTRICT tmp) {
    return NoInlineWrapper(DCT1DWrapper<N, 0, FromBlock, ToBlock>, from, to, M,
                           tmp);
  }
};

template <size_t N, size_t M, typename = void>
struct IDCT1D {
  template <typename FromBlock, typename ToBlock>
  void operator()(const FromBlock& from, const ToBlock& to,
                  float* JXL_RESTRICT tmp) {
    return IDCT1DWrapper<N, M>(from, to, M, tmp);
  }
};

template <size_t N, size_t M>
struct IDCT1D<N, M, typename std::enable_if<(M > MaxLanes(FV<0>()))>::type> {
  template <typename FromBlock, typename ToBlock>
  void operator()(const FromBlock& from, const ToBlock& to,
                  float* JXL_RESTRICT tmp) {
    return NoInlineWrapper(IDCT1DWrapper<N, 0, FromBlock, ToBlock>, from, to, M,
                           tmp);
  }
};

// Computes the maybe-transposed, scaled DCT of a block that needs to be
// HWY_ALIGN'ed.
template <size_t ROWS, size_t COLS>
struct ComputeScaledDCT {
  // scratch_space must be aligned, and should have space for the ROWS*COLS
  // float block plus the 1-D transform tmp area (see the scratch-sizing note
  // above).
  template <class From>
  HWY_MAYBE_UNUSED void operator()(const From& from, float* to,
                                   float* JXL_RESTRICT scratch_space) {
    float* JXL_RESTRICT block = scratch_space;
    float* JXL_RESTRICT tmp = scratch_space + ROWS * COLS;
    if (ROWS < COLS) {
      DCT1D<ROWS, COLS>()(from, DCTTo(block, COLS), tmp);
      Transpose<ROWS, COLS>::Run(DCTFrom(block, COLS), DCTTo(to, ROWS));
      DCT1D<COLS, ROWS>()(DCTFrom(to, ROWS), DCTTo(block, ROWS), tmp);
      Transpose<COLS, ROWS>::Run(DCTFrom(block, ROWS), DCTTo(to, COLS));
    } else {
      DCT1D<ROWS, COLS>()(from, DCTTo(to, COLS), tmp);
      Transpose<ROWS, COLS>::Run(DCTFrom(to, COLS), DCTTo(block, ROWS));
      DCT1D<COLS, ROWS>()(DCTFrom(block, ROWS), DCTTo(to, ROWS), tmp);
    }
  }
};
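// Illustrative usage sketch (added for exposition; the scratch size below
// follows the scratch-sizing note above and is deliberately generous, not a
// documented contract):
//
//   HWY_ALIGN float pixels[8 * 8] = {...};  // input block
//   HWY_ALIGN float coeffs[8 * 8];          // DCT coefficients
//   HWY_ALIGN float scratch[4 * 8 * 8];     // ROWS*COLS block + 1-D tmp
//   ComputeScaledDCT<8, 8>()(DCTFrom(pixels, 8), coeffs, scratch);
//   // ComputeScaledIDCT (below) reverses the steps done in the forward
//   // transform; note that it uses `coeffs` as working memory and
//   // overwrites it:
//   ComputeScaledIDCT<8, 8>()(coeffs, DCTTo(pixels, 8), scratch);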
// Computes the maybe-transposed, scaled IDCT of a block that needs to be
// HWY_ALIGN'ed.
template <size_t ROWS, size_t COLS>
struct ComputeScaledIDCT {
  // scratch_space must be aligned, and should have space for the ROWS*COLS
  // float block plus the 1-D transform tmp area (see the scratch-sizing note
  // above).
  template <class To>
  HWY_MAYBE_UNUSED void operator()(float* JXL_RESTRICT from, const To& to,
                                   float* JXL_RESTRICT scratch_space) {
    float* JXL_RESTRICT block = scratch_space;
    float* JXL_RESTRICT tmp = scratch_space + ROWS * COLS;
    // Reverse the steps done in ComputeScaledDCT.
    if (ROWS < COLS) {
      Transpose<ROWS, COLS>::Run(DCTFrom(from, COLS), DCTTo(block, ROWS));
      IDCT1D<COLS, ROWS>()(DCTFrom(block, ROWS), DCTTo(from, ROWS), tmp);
      Transpose<COLS, ROWS>::Run(DCTFrom(from, ROWS), DCTTo(block, COLS));
      IDCT1D<ROWS, COLS>()(DCTFrom(block, COLS), to, tmp);
    } else {
      IDCT1D<COLS, ROWS>()(DCTFrom(from, ROWS), DCTTo(block, ROWS), tmp);
      Transpose<COLS, ROWS>::Run(DCTFrom(block, ROWS), DCTTo(from, COLS));
      IDCT1D<ROWS, COLS>()(DCTFrom(from, COLS), to, tmp);
    }
  }
};

}  // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jxl
HWY_AFTER_NAMESPACE();
#endif  // LIB_JXL_DCT_INL_H_