libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

dct-inl.h (12399B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 // Fast SIMD floating-point (I)DCT, any power of two.
      7 
      8 #if defined(LIB_JXL_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE)
      9 #ifdef LIB_JXL_DCT_INL_H_
     10 #undef LIB_JXL_DCT_INL_H_
     11 #else
     12 #define LIB_JXL_DCT_INL_H_
     13 #endif
     14 
#include <stddef.h>

#include <type_traits>

#include <hwy/highway.h>

#include "lib/jxl/dct_block-inl.h"
#include "lib/jxl/dct_scales.h"
#include "lib/jxl/transpose-inl.h"
     22 HWY_BEFORE_NAMESPACE();
     23 namespace jxl {
     24 namespace HWY_NAMESPACE {
     25 namespace {
     26 
     27 // These templates are not found via ADL.
     28 using hwy::HWY_NAMESPACE::Add;
     29 using hwy::HWY_NAMESPACE::Mul;
     30 using hwy::HWY_NAMESPACE::MulAdd;
     31 using hwy::HWY_NAMESPACE::NegMulAdd;
     32 using hwy::HWY_NAMESPACE::Sub;
     33 
// Maps a lane-count request to a Highway float-vector descriptor: a vector
// capped at SZ lanes.
template <size_t SZ>
struct FVImpl {
  using type = HWY_CAPPED(float, SZ);
};

// SZ == 0 means "no cap": use the full native vector width.
template <>
struct FVImpl<0> {
  using type = HWY_FULL(float);
};

// FV<SZ> is the float vector descriptor used throughout this file.
template <size_t SZ>
using FV = typename FVImpl<SZ>::type;
     46 
     47 // Implementation of Lowest Complexity Self Recursive Radix-2 DCT II/III
     48 // Algorithms, by Siriani M. Perera and Jianhua Liu.
     49 
// Helpers operating on a bundle of N rows of SZ float lanes each (stored
// contiguously, row i at offset i*SZ). Each vector operation processes SZ
// independent 1D transforms in parallel, one per lane.
template <size_t N, size_t SZ>
struct CoeffBundle {
  // aout[i] = ain1[i] + ain2[N-1-i]: element-wise sum with the second
  // operand's rows taken in reverse order.
  static void AddReverse(const float* JXL_RESTRICT ain1,
                         const float* JXL_RESTRICT ain2,
                         float* JXL_RESTRICT aout) {
    for (size_t i = 0; i < N; i++) {
      auto in1 = Load(FV<SZ>(), ain1 + i * SZ);
      auto in2 = Load(FV<SZ>(), ain2 + (N - i - 1) * SZ);
      Store(Add(in1, in2), FV<SZ>(), aout + i * SZ);
    }
  }
  // aout[i] = ain1[i] - ain2[N-1-i]: like AddReverse but subtracting.
  static void SubReverse(const float* JXL_RESTRICT ain1,
                         const float* JXL_RESTRICT ain2,
                         float* JXL_RESTRICT aout) {
    for (size_t i = 0; i < N; i++) {
      auto in1 = Load(FV<SZ>(), ain1 + i * SZ);
      auto in2 = Load(FV<SZ>(), ain2 + (N - i - 1) * SZ);
      Store(Sub(in1, in2), FV<SZ>(), aout + i * SZ);
    }
  }
  // In-place application of the B matrix of the cited algorithm:
  // row 0 becomes sqrt(2)*c[0] + c[1]; row i (0 < i < N-1) becomes
  // c[i] + c[i+1] (ascending order, so c[i+1] is still unmodified when
  // read); the last row is left unchanged.
  static void B(float* JXL_RESTRICT coeff) {
    auto sqrt2 = Set(FV<SZ>(), kSqrt2);
    auto in1 = Load(FV<SZ>(), coeff);
    auto in2 = Load(FV<SZ>(), coeff + SZ);
    Store(MulAdd(in1, sqrt2, in2), FV<SZ>(), coeff);
    for (size_t i = 1; i + 1 < N; i++) {
      auto in1 = Load(FV<SZ>(), coeff + i * SZ);
      auto in2 = Load(FV<SZ>(), coeff + (i + 1) * SZ);
      Store(Add(in1, in2), FV<SZ>(), coeff + i * SZ);
    }
  }
  // In-place application of B^T: row i (descending, i > 0) becomes
  // c[i] + c[i-1] (c[i-1] is still unmodified when read); row 0 is then
  // scaled by sqrt(2).
  static void BTranspose(float* JXL_RESTRICT coeff) {
    for (size_t i = N - 1; i > 0; i--) {
      auto in1 = Load(FV<SZ>(), coeff + i * SZ);
      auto in2 = Load(FV<SZ>(), coeff + (i - 1) * SZ);
      Store(Add(in1, in2), FV<SZ>(), coeff + i * SZ);
    }
    auto sqrt2 = Set(FV<SZ>(), kSqrt2);
    auto in1 = Load(FV<SZ>(), coeff);
    Store(Mul(in1, sqrt2), FV<SZ>(), coeff);
  }
  // Permutation only: rows [0, N/2) go to even output rows, rows [N/2, N)
  // to odd output rows. Ideally optimized away by the compiler.
  static void InverseEvenOdd(const float* JXL_RESTRICT ain,
                             float* JXL_RESTRICT aout) {
    for (size_t i = 0; i < N / 2; i++) {
      auto in1 = Load(FV<SZ>(), ain + i * SZ);
      Store(in1, FV<SZ>(), aout + 2 * i * SZ);
    }
    for (size_t i = N / 2; i < N; i++) {
      auto in1 = Load(FV<SZ>(), ain + i * SZ);
      Store(in1, FV<SZ>(), aout + (2 * (i - N / 2) + 1) * SZ);
    }
  }
  // Inverse permutation of InverseEvenOdd: gathers even-indexed input rows
  // (stride ain_stride, possibly unaligned) into [0, N/2) and odd-indexed
  // rows into [N/2, N). Ideally optimized away by compiler.
  static void ForwardEvenOdd(const float* JXL_RESTRICT ain, size_t ain_stride,
                             float* JXL_RESTRICT aout) {
    for (size_t i = 0; i < N / 2; i++) {
      auto in1 = LoadU(FV<SZ>(), ain + 2 * i * ain_stride);
      Store(in1, FV<SZ>(), aout + i * SZ);
    }
    for (size_t i = N / 2; i < N; i++) {
      auto in1 = LoadU(FV<SZ>(), ain + (2 * (i - N / 2) + 1) * ain_stride);
      Store(in1, FV<SZ>(), aout + i * SZ);
    }
  }
  // Scales rows [N/2, N) by the per-row constants in
  // WcMultipliers<N>::kMultipliers. Invoked on full vector.
  static void Multiply(float* JXL_RESTRICT coeff) {
    for (size_t i = 0; i < N / 2; i++) {
      auto in1 = Load(FV<SZ>(), coeff + (N / 2 + i) * SZ);
      auto mul = Set(FV<SZ>(), WcMultipliers<N>::kMultipliers[i]);
      Store(Mul(in1, mul), FV<SZ>(), coeff + (N / 2 + i) * SZ);
    }
  }
  // Final IDCT butterfly: out[i] = c[i] + m[i]*c[N/2+i] and
  // out[N-1-i] = c[i] - m[i]*c[N/2+i], written (unaligned) with the given
  // row stride.
  static void MultiplyAndAdd(const float* JXL_RESTRICT coeff,
                             float* JXL_RESTRICT out, size_t out_stride) {
    for (size_t i = 0; i < N / 2; i++) {
      auto mul = Set(FV<SZ>(), WcMultipliers<N>::kMultipliers[i]);
      auto in1 = Load(FV<SZ>(), coeff + i * SZ);
      auto in2 = Load(FV<SZ>(), coeff + (N / 2 + i) * SZ);
      auto out1 = MulAdd(mul, in2, in1);
      auto out2 = NegMulAdd(mul, in2, in1);
      StoreU(out1, FV<SZ>(), out + i * out_stride);
      StoreU(out2, FV<SZ>(), out + (N - i - 1) * out_stride);
    }
  }
  // Copies N rows starting at column `off` of `in` into the bundle.
  template <typename Block>
  static void LoadFromBlock(const Block& in, size_t off,
                            float* JXL_RESTRICT coeff) {
    for (size_t i = 0; i < N; i++) {
      Store(in.LoadPart(FV<SZ>(), i, off), FV<SZ>(), coeff + i * SZ);
    }
  }
  // Writes the bundle back to `out` at column `off`, scaling every value by
  // 1/N (the DCT normalization used by this implementation).
  template <typename Block>
  static void StoreToBlockAndScale(const float* JXL_RESTRICT coeff,
                                   const Block& out, size_t off) {
    auto mul = Set(FV<SZ>(), 1.0f / N);
    for (size_t i = 0; i < N; i++) {
      out.StorePart(FV<SZ>(), Mul(mul, Load(FV<SZ>(), coeff + i * SZ)), i, off);
    }
  }
};
    151 
// In-place 1D DCT of N rows; specialized below for the base cases.
template <size_t N, size_t SZ>
struct DCT1DImpl;

// Base case: the 1-point DCT is the identity, so there is nothing to do.
template <size_t SZ>
struct DCT1DImpl<1, SZ> {
  JXL_INLINE void operator()(float* JXL_RESTRICT mem, float* /* tmp */) {}
};
    159 
    160 template <size_t SZ>
    161 struct DCT1DImpl<2, SZ> {
    162   JXL_INLINE void operator()(float* JXL_RESTRICT mem, float* /* tmp */) {
    163     auto in1 = Load(FV<SZ>(), mem);
    164     auto in2 = Load(FV<SZ>(), mem + SZ);
    165     Store(Add(in1, in2), FV<SZ>(), mem);
    166     Store(Sub(in1, in2), FV<SZ>(), mem + SZ);
    167   }
    168 };
    169 
// General case: recursive radix-2 DCT-II. `mem` holds the N input rows and
// receives the N output coefficients; `tmp` is scratch (the range past
// tmp + N*SZ feeds the recursion, so it must be larger than N*SZ floats).
template <size_t N, size_t SZ>
struct DCT1DImpl {
  void operator()(float* JXL_RESTRICT mem, float* JXL_RESTRICT tmp) {
    // Even part: sums of mirrored input pairs, transformed recursively.
    CoeffBundle<N / 2, SZ>::AddReverse(mem, mem + N / 2 * SZ, tmp);
    DCT1DImpl<N / 2, SZ>()(tmp, tmp + N * SZ);
    // Odd part: differences of mirrored pairs, scaled by the twiddle
    // multipliers (Multiply touches rows [N/2, N) only), transformed
    // recursively, then combined through the B matrix.
    CoeffBundle<N / 2, SZ>::SubReverse(mem, mem + N / 2 * SZ, tmp + N / 2 * SZ);
    CoeffBundle<N, SZ>::Multiply(tmp);
    DCT1DImpl<N / 2, SZ>()(tmp + N / 2 * SZ, tmp + N * SZ);
    CoeffBundle<N / 2, SZ>::B(tmp + N / 2 * SZ);
    // Interleave back into mem: even half to even indices, odd half to odd.
    CoeffBundle<N, SZ>::InverseEvenOdd(tmp, mem);
  }
};
    182 
// Strided 1D IDCT of N rows; specialized below for the base cases.
template <size_t N, size_t SZ>
struct IDCT1DImpl;

// Base case: the 1-point IDCT copies the single row; the stride and tmp
// arguments are unused.
template <size_t SZ>
struct IDCT1DImpl<1, SZ> {
  JXL_INLINE void operator()(const float* from, size_t from_stride, float* to,
                             size_t to_stride, float* JXL_RESTRICT /* tmp */) {
    StoreU(LoadU(FV<SZ>(), from), FV<SZ>(), to);
  }
};
    193 
    194 template <size_t SZ>
    195 struct IDCT1DImpl<2, SZ> {
    196   JXL_INLINE void operator()(const float* from, size_t from_stride, float* to,
    197                              size_t to_stride, float* JXL_RESTRICT /* tmp */) {
    198     JXL_DASSERT(from_stride >= SZ);
    199     JXL_DASSERT(to_stride >= SZ);
    200     auto in1 = LoadU(FV<SZ>(), from);
    201     auto in2 = LoadU(FV<SZ>(), from + from_stride);
    202     StoreU(Add(in1, in2), FV<SZ>(), to);
    203     StoreU(Sub(in1, in2), FV<SZ>(), to + to_stride);
    204   }
    205 };
    206 
// General case: recursive IDCT, inverting the steps of DCT1DImpl above.
// Reads N rows from `from` (row stride from_stride) and writes N rows to
// `to` (row stride to_stride); `tmp` is scratch, with the range past
// tmp + N*SZ feeding the recursion.
template <size_t N, size_t SZ>
struct IDCT1DImpl {
  void operator()(const float* from, size_t from_stride, float* to,
                  size_t to_stride, float* JXL_RESTRICT tmp) {
    JXL_DASSERT(from_stride >= SZ);
    JXL_DASSERT(to_stride >= SZ);
    // De-interleave: even-indexed input rows into tmp[0, N/2), odd-indexed
    // rows into tmp[N/2, N).
    CoeffBundle<N, SZ>::ForwardEvenOdd(from, from_stride, tmp);
    // Half-size IDCT of the even part, in place.
    IDCT1DImpl<N / 2, SZ>()(tmp, SZ, tmp, SZ, tmp + N * SZ);
    // Apply B^T to the odd part, then its half-size IDCT, in place.
    CoeffBundle<N / 2, SZ>::BTranspose(tmp + N / 2 * SZ);
    IDCT1DImpl<N / 2, SZ>()(tmp + N / 2 * SZ, SZ, tmp + N / 2 * SZ, SZ,
                            tmp + N * SZ);
    // Scale the odd part by the twiddle multipliers and butterfly the two
    // halves into the mirrored output rows.
    CoeffBundle<N, SZ>::MultiplyAndAdd(tmp, to, to_stride);
  }
};
    221 
// Applies the N-point 1D DCT to every column of `from`, writing the scaled
// result to `to`. M_or_0 is the static column count, or 0 when the count is
// only known at run time (then Mp supplies it); Lanes(FV<M_or_0>()) columns
// are processed per iteration.
template <size_t N, size_t M_or_0, typename FromBlock, typename ToBlock>
void DCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp,
                  float* JXL_RESTRICT tmp) {
  size_t M = M_or_0 != 0 ? M_or_0 : Mp;
  constexpr size_t SZ = MaxLanes(FV<M_or_0>());
  for (size_t i = 0; i < M; i += Lanes(FV<M_or_0>())) {
    // TODO(veluca): consider removing the temporary memory here (as is done in
    // IDCT), if it turns out that some compilers don't optimize away the loads
    // and this is performance-critical.
    CoeffBundle<N, SZ>::LoadFromBlock(from, i, tmp);
    DCT1DImpl<N, SZ>()(tmp, tmp + N * SZ);
    CoeffBundle<N, SZ>::StoreToBlockAndScale(tmp, to, i);
  }
}
    236 
    237 template <size_t N, size_t M_or_0, typename FromBlock, typename ToBlock>
    238 void IDCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp,
    239                    float* JXL_RESTRICT tmp) {
    240   size_t M = M_or_0 != 0 ? M_or_0 : Mp;
    241   constexpr size_t SZ = MaxLanes(FV<M_or_0>());
    242   for (size_t i = 0; i < M; i += Lanes(FV<M_or_0>())) {
    243     IDCT1DImpl<N, SZ>()(from.Address(0, i), from.Stride(), to.Address(0, i),
    244                         to.Stride(), tmp);
    245   }
    246 }
    247 
    248 template <size_t N, size_t M, typename = void>
    249 struct DCT1D {
    250   template <typename FromBlock, typename ToBlock>
    251   void operator()(const FromBlock& from, const ToBlock& to,
    252                   float* JXL_RESTRICT tmp) {
    253     return DCT1DWrapper<N, M>(from, to, M, tmp);
    254   }
    255 };
    256 
    257 template <size_t N, size_t M>
    258 struct DCT1D<N, M, typename std::enable_if<(M > MaxLanes(FV<0>()))>::type> {
    259   template <typename FromBlock, typename ToBlock>
    260   void operator()(const FromBlock& from, const ToBlock& to,
    261                   float* JXL_RESTRICT tmp) {
    262     return NoInlineWrapper(DCT1DWrapper<N, 0, FromBlock, ToBlock>, from, to, M,
    263                            tmp);
    264   }
    265 };
    266 
    267 template <size_t N, size_t M, typename = void>
    268 struct IDCT1D {
    269   template <typename FromBlock, typename ToBlock>
    270   void operator()(const FromBlock& from, const ToBlock& to,
    271                   float* JXL_RESTRICT tmp) {
    272     return IDCT1DWrapper<N, M>(from, to, M, tmp);
    273   }
    274 };
    275 
    276 template <size_t N, size_t M>
    277 struct IDCT1D<N, M, typename std::enable_if<(M > MaxLanes(FV<0>()))>::type> {
    278   template <typename FromBlock, typename ToBlock>
    279   void operator()(const FromBlock& from, const ToBlock& to,
    280                   float* JXL_RESTRICT tmp) {
    281     return NoInlineWrapper(IDCT1DWrapper<N, 0, FromBlock, ToBlock>, from, to, M,
    282                            tmp);
    283   }
    284 };
    285 
// Computes the maybe-transposed, scaled DCT of a block, that needs to be
// HWY_ALIGN'ed. The output is left in transposed orientation when
// ROWS >= COLS (the final transpose is skipped); when ROWS < COLS a second
// transpose restores the natural orientation.
template <size_t ROWS, size_t COLS>
struct ComputeScaledDCT {
  // scratch_space must be aligned, and should have space for ROWS*COLS
  // floats.
  // NOTE(review): `tmp` starts ROWS*COLS floats into scratch_space and is
  // written by the 1D DCTs, so the buffer must actually be larger than
  // ROWS*COLS — looks like 2*ROWS*COLS; confirm against callers.
  template <class From>
  HWY_MAYBE_UNUSED void operator()(const From& from, float* to,
                                   float* JXL_RESTRICT scratch_space) {
    float* JXL_RESTRICT block = scratch_space;
    float* JXL_RESTRICT tmp = scratch_space + ROWS * COLS;
    if (ROWS < COLS) {
      // Length-ROWS column DCT, transpose, length-COLS column DCT,
      // transpose back.
      DCT1D<ROWS, COLS>()(from, DCTTo(block, COLS), tmp);
      Transpose<ROWS, COLS>::Run(DCTFrom(block, COLS), DCTTo(to, ROWS));
      DCT1D<COLS, ROWS>()(DCTFrom(to, ROWS), DCTTo(block, ROWS), tmp);
      Transpose<COLS, ROWS>::Run(DCTFrom(block, ROWS), DCTTo(to, COLS));
    } else {
      // Same two DCT passes, but the final transpose is omitted, leaving
      // the result transposed in `to`.
      DCT1D<ROWS, COLS>()(from, DCTTo(to, COLS), tmp);
      Transpose<ROWS, COLS>::Run(DCTFrom(to, COLS), DCTTo(block, ROWS));
      DCT1D<COLS, ROWS>()(DCTFrom(block, ROWS), DCTTo(to, ROWS), tmp);
    }
  }
};
// Computes the maybe-transposed, scaled IDCT of a block, that needs to be
// HWY_ALIGN'ed. Expects its input in the orientation produced by
// ComputeScaledDCT (transposed when ROWS >= COLS) and undoes those steps in
// reverse order. Note that `from` is used as working storage and is
// clobbered.
template <size_t ROWS, size_t COLS>
struct ComputeScaledIDCT {
  // scratch_space must be aligned, and should have space for ROWS*COLS
  // floats.
  // NOTE(review): `tmp` starts ROWS*COLS floats into scratch_space and is
  // written by the 1D IDCTs, so the buffer must actually be larger than
  // ROWS*COLS — looks like 2*ROWS*COLS; confirm against callers.
  template <class To>
  HWY_MAYBE_UNUSED void operator()(float* JXL_RESTRICT from, const To& to,
                                   float* JXL_RESTRICT scratch_space) {
    float* JXL_RESTRICT block = scratch_space;
    float* JXL_RESTRICT tmp = scratch_space + ROWS * COLS;
    // Reverse the steps done in ComputeScaledDCT.
    if (ROWS < COLS) {
      // Undo the final transpose, invert the second DCT pass, transpose,
      // then invert the first pass straight into `to`.
      Transpose<ROWS, COLS>::Run(DCTFrom(from, COLS), DCTTo(block, ROWS));
      IDCT1D<COLS, ROWS>()(DCTFrom(block, ROWS), DCTTo(from, ROWS), tmp);
      Transpose<COLS, ROWS>::Run(DCTFrom(from, ROWS), DCTTo(block, COLS));
      IDCT1D<ROWS, COLS>()(DCTFrom(block, COLS), to, tmp);
    } else {
      // Input is already transposed: invert the second pass, transpose,
      // then invert the first pass straight into `to`.
      IDCT1D<COLS, ROWS>()(DCTFrom(from, ROWS), DCTTo(block, ROWS), tmp);
      Transpose<COLS, ROWS>::Run(DCTFrom(block, ROWS), DCTTo(from, COLS));
      IDCT1D<ROWS, COLS>()(DCTFrom(from, COLS), to, tmp);
    }
  }
};
    333 
    334 }  // namespace
    335 // NOLINTNEXTLINE(google-readability-namespace-comments)
    336 }  // namespace HWY_NAMESPACE
    337 }  // namespace jxl
    338 HWY_AFTER_NAMESPACE();
    339 #endif  // LIB_JXL_DCT_INL_H_