libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

idct.cc (26685B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jpegli/idct.h"
      7 
      8 #include <cmath>
      9 
     10 #include "lib/jpegli/decode_internal.h"
     11 #include "lib/jxl/base/status.h"
     12 
     13 #undef HWY_TARGET_INCLUDE
     14 #define HWY_TARGET_INCLUDE "lib/jpegli/idct.cc"
     15 #include <hwy/foreach_target.h>
     16 #include <hwy/highway.h>
     17 
     18 #include "lib/jpegli/transpose-inl.h"
     19 
     20 HWY_BEFORE_NAMESPACE();
     21 namespace jpegli {
     22 namespace HWY_NAMESPACE {
     23 
     24 // These templates are not found via ADL.
     25 using hwy::HWY_NAMESPACE::Abs;
     26 using hwy::HWY_NAMESPACE::Add;
     27 using hwy::HWY_NAMESPACE::Gt;
     28 using hwy::HWY_NAMESPACE::IfThenElseZero;
     29 using hwy::HWY_NAMESPACE::Mul;
     30 using hwy::HWY_NAMESPACE::MulAdd;
     31 using hwy::HWY_NAMESPACE::NegMulAdd;
     32 using hwy::HWY_NAMESPACE::Rebind;
     33 using hwy::HWY_NAMESPACE::Sub;
     34 using hwy::HWY_NAMESPACE::Vec;
     35 using hwy::HWY_NAMESPACE::Xor;
     36 
     37 using D = HWY_FULL(float);
     38 using DI = HWY_FULL(int32_t);
     39 constexpr D d;
     40 constexpr DI di;
     41 
     42 using D8 = HWY_CAPPED(float, 8);
     43 constexpr D8 d8;
     44 
     45 void DequantBlock(const int16_t* JXL_RESTRICT qblock,
     46                   const float* JXL_RESTRICT dequant,
     47                   const float* JXL_RESTRICT biases, float* JXL_RESTRICT block) {
     48   for (size_t k = 0; k < 64; k += Lanes(d)) {
     49     const auto mul = Load(d, dequant + k);
     50     const auto bias = Load(d, biases + k);
     51     const Rebind<int16_t, DI> di16;
     52     const Vec<DI> quant_i = PromoteTo(di, Load(di16, qblock + k));
     53     const Rebind<float, DI> df;
     54     const auto quant = ConvertTo(df, quant_i);
     55     const auto abs_quant = Abs(quant);
     56     const auto not_0 = Gt(abs_quant, Zero(df));
     57     const auto sign_quant = Xor(quant, abs_quant);
     58     const auto biased_quant = Sub(quant, Xor(bias, sign_quant));
     59     const auto dequant = IfThenElseZero(not_0, Mul(biased_quant, mul));
     60     Store(dequant, d, block + k);
     61   }
     62 }
     63 
     64 template <size_t N>
     65 void ForwardEvenOdd(const float* JXL_RESTRICT ain, size_t ain_stride,
     66                     float* JXL_RESTRICT aout) {
     67   for (size_t i = 0; i < N / 2; i++) {
     68     auto in1 = LoadU(d8, ain + 2 * i * ain_stride);
     69     Store(in1, d8, aout + i * 8);
     70   }
     71   for (size_t i = N / 2; i < N; i++) {
     72     auto in1 = LoadU(d8, ain + (2 * (i - N / 2) + 1) * ain_stride);
     73     Store(in1, d8, aout + i * 8);
     74   }
     75 }
     76 
     77 template <size_t N>
     78 void BTranspose(float* JXL_RESTRICT coeff) {
     79   for (size_t i = N - 1; i > 0; i--) {
     80     auto in1 = Load(d8, coeff + i * 8);
     81     auto in2 = Load(d8, coeff + (i - 1) * 8);
     82     Store(Add(in1, in2), d8, coeff + i * 8);
     83   }
     84   constexpr float kSqrt2 = 1.41421356237f;
     85   auto sqrt2 = Set(d8, kSqrt2);
     86   auto in1 = Load(d8, coeff);
     87   Store(Mul(in1, sqrt2), d8, coeff);
     88 }
     89 
     90 // Constants for DCT implementation. Generated by the following snippet:
     91 // for i in range(N // 2):
     92 //    print(1.0 / (2 * math.cos((i + 0.5) * math.pi / N)), end=", ")
     93 template <size_t N>
     94 struct WcMultipliers;
     95 
     96 template <>
     97 struct WcMultipliers<4> {
     98   static constexpr float kMultipliers[] = {
     99       0.541196100146197,
    100       1.3065629648763764,
    101   };
    102 };
    103 
    104 template <>
    105 struct WcMultipliers<8> {
    106   static constexpr float kMultipliers[] = {
    107       0.5097955791041592,
    108       0.6013448869350453,
    109       0.8999762231364156,
    110       2.5629154477415055,
    111   };
    112 };
    113 
    114 constexpr float WcMultipliers<4>::kMultipliers[];
    115 constexpr float WcMultipliers<8>::kMultipliers[];
    116 
    117 template <size_t N>
    118 void MultiplyAndAdd(const float* JXL_RESTRICT coeff, float* JXL_RESTRICT out,
    119                     size_t out_stride) {
    120   for (size_t i = 0; i < N / 2; i++) {
    121     auto mul = Set(d8, WcMultipliers<N>::kMultipliers[i]);
    122     auto in1 = Load(d8, coeff + i * 8);
    123     auto in2 = Load(d8, coeff + (N / 2 + i) * 8);
    124     auto out1 = MulAdd(mul, in2, in1);
    125     auto out2 = NegMulAdd(mul, in2, in1);
    126     StoreU(out1, d8, out + i * out_stride);
    127     StoreU(out2, d8, out + (N - i - 1) * out_stride);
    128   }
    129 }
    130 
    131 template <size_t N>
    132 struct IDCT1DImpl;
    133 
    134 template <>
    135 struct IDCT1DImpl<1> {
    136   JXL_INLINE void operator()(const float* from, size_t from_stride, float* to,
    137                              size_t to_stride) {
    138     StoreU(LoadU(d8, from), d8, to);
    139   }
    140 };
    141 
    142 template <>
    143 struct IDCT1DImpl<2> {
    144   JXL_INLINE void operator()(const float* from, size_t from_stride, float* to,
    145                              size_t to_stride) {
    146     JXL_DASSERT(from_stride >= 8);
    147     JXL_DASSERT(to_stride >= 8);
    148     auto in1 = LoadU(d8, from);
    149     auto in2 = LoadU(d8, from + from_stride);
    150     StoreU(Add(in1, in2), d8, to);
    151     StoreU(Sub(in1, in2), d8, to + to_stride);
    152   }
    153 };
    154 
    155 template <size_t N>
    156 struct IDCT1DImpl {
    157   void operator()(const float* from, size_t from_stride, float* to,
    158                   size_t to_stride) {
    159     JXL_DASSERT(from_stride >= 8);
    160     JXL_DASSERT(to_stride >= 8);
    161     HWY_ALIGN float tmp[64];
    162     ForwardEvenOdd<N>(from, from_stride, tmp);
    163     IDCT1DImpl<N / 2>()(tmp, 8, tmp, 8);
    164     BTranspose<N / 2>(tmp + N * 4);
    165     IDCT1DImpl<N / 2>()(tmp + N * 4, 8, tmp + N * 4, 8);
    166     MultiplyAndAdd<N>(tmp, to, to_stride);
    167   }
    168 };
    169 
    170 template <size_t N>
    171 void IDCT1D(float* JXL_RESTRICT from, float* JXL_RESTRICT output,
    172             size_t output_stride) {
    173   for (size_t i = 0; i < 8; i += Lanes(d8)) {
    174     IDCT1DImpl<N>()(from + i, 8, output + i, output_stride);
    175   }
    176 }
    177 
    178 void ComputeScaledIDCT(float* JXL_RESTRICT block0, float* JXL_RESTRICT block1,
    179                        float* JXL_RESTRICT output, size_t output_stride) {
    180   Transpose8x8Block(block0, block1);
    181   IDCT1D<8>(block1, block0, 8);
    182   Transpose8x8Block(block0, block1);
    183   IDCT1D<8>(block1, output, output_stride);
    184 }
    185 
    186 void InverseTransformBlock8x8(const int16_t* JXL_RESTRICT qblock,
    187                               const float* JXL_RESTRICT dequant,
    188                               const float* JXL_RESTRICT biases,
    189                               float* JXL_RESTRICT scratch_space,
    190                               float* JXL_RESTRICT output, size_t output_stride,
    191                               size_t dctsize) {
    192   float* JXL_RESTRICT block0 = scratch_space;
    193   float* JXL_RESTRICT block1 = scratch_space + DCTSIZE2;
    194   DequantBlock(qblock, dequant, biases, block0);
    195   ComputeScaledIDCT(block0, block1, output, output_stride);
    196 }
    197 
    198 // Computes the N-point IDCT of in[], and stores the result in out[]. The in[]
    199 // array is at most 8 values long, values in[8:N-1] are assumed to be 0.
    200 void Compute1dIDCT(const float* in, float* out, size_t N) {
    201   switch (N) {
    202     case 3: {
    203       static constexpr float kC3[3] = {
    204           1.414213562373,
    205           1.224744871392,
    206           0.707106781187,
    207       };
    208       float even0 = in[0] + kC3[2] * in[2];
    209       float even1 = in[0] - kC3[0] * in[2];
    210       float odd0 = kC3[1] * in[1];
    211       out[0] = even0 + odd0;
    212       out[2] = even0 - odd0;
    213       out[1] = even1;
    214       break;
    215     }
    216     case 5: {
    217       static constexpr float kC5[5] = {
    218           1.414213562373, 1.344997023928, 1.144122805635,
    219           0.831253875555, 0.437016024449,
    220       };
    221       float even0 = in[0] + kC5[2] * in[2] + kC5[4] * in[4];
    222       float even1 = in[0] - kC5[4] * in[2] - kC5[2] * in[4];
    223       float even2 = in[0] - kC5[0] * in[2] + kC5[0] * in[4];
    224       float odd0 = kC5[1] * in[1] + kC5[3] * in[3];
    225       float odd1 = kC5[3] * in[1] - kC5[1] * in[3];
    226       out[0] = even0 + odd0;
    227       out[4] = even0 - odd0;
    228       out[1] = even1 + odd1;
    229       out[3] = even1 - odd1;
    230       out[2] = even2;
    231       break;
    232     }
    233     case 6: {
    234       static constexpr float kC6[6] = {
    235           1.414213562373, 1.366025403784, 1.224744871392,
    236           1.000000000000, 0.707106781187, 0.366025403784,
    237       };
    238       float even0 = in[0] + kC6[2] * in[2] + kC6[4] * in[4];
    239       float even1 = in[0] - kC6[0] * in[4];
    240       float even2 = in[0] - kC6[2] * in[2] + kC6[4] * in[4];
    241       float odd0 = kC6[1] * in[1] + kC6[3] * in[3] + kC6[5] * in[5];
    242       float odd1 = kC6[3] * in[1] - kC6[3] * in[3] - kC6[3] * in[5];
    243       float odd2 = kC6[5] * in[1] - kC6[3] * in[3] + kC6[1] * in[5];
    244       out[0] = even0 + odd0;
    245       out[5] = even0 - odd0;
    246       out[1] = even1 + odd1;
    247       out[4] = even1 - odd1;
    248       out[2] = even2 + odd2;
    249       out[3] = even2 - odd2;
    250       break;
    251     }
    252     case 7: {
    253       static constexpr float kC7[7] = {
    254           1.414213562373, 1.378756275744, 1.274162392264, 1.105676685997,
    255           0.881747733790, 0.613604268353, 0.314692122713,
    256       };
    257       float even0 = in[0] + kC7[2] * in[2] + kC7[4] * in[4] + kC7[6] * in[6];
    258       float even1 = in[0] + kC7[6] * in[2] - kC7[2] * in[4] - kC7[4] * in[6];
    259       float even2 = in[0] - kC7[4] * in[2] - kC7[6] * in[4] + kC7[2] * in[6];
    260       float even3 = in[0] - kC7[0] * in[2] + kC7[0] * in[4] - kC7[0] * in[6];
    261       float odd0 = kC7[1] * in[1] + kC7[3] * in[3] + kC7[5] * in[5];
    262       float odd1 = kC7[3] * in[1] - kC7[5] * in[3] - kC7[1] * in[5];
    263       float odd2 = kC7[5] * in[1] - kC7[1] * in[3] + kC7[3] * in[5];
    264       out[0] = even0 + odd0;
    265       out[6] = even0 - odd0;
    266       out[1] = even1 + odd1;
    267       out[5] = even1 - odd1;
    268       out[2] = even2 + odd2;
    269       out[4] = even2 - odd2;
    270       out[3] = even3;
    271       break;
    272     }
    273     case 9: {
    274       static constexpr float kC9[9] = {
    275           1.414213562373, 1.392728480640, 1.328926048777,
    276           1.224744871392, 1.083350440839, 0.909038955344,
    277           0.707106781187, 0.483689525296, 0.245575607938,
    278       };
    279       float even0 = in[0] + kC9[2] * in[2] + kC9[4] * in[4] + kC9[6] * in[6];
    280       float even1 = in[0] + kC9[6] * in[2] - kC9[6] * in[4] - kC9[0] * in[6];
    281       float even2 = in[0] - kC9[8] * in[2] - kC9[2] * in[4] + kC9[6] * in[6];
    282       float even3 = in[0] - kC9[4] * in[2] + kC9[8] * in[4] + kC9[6] * in[6];
    283       float even4 = in[0] - kC9[0] * in[2] + kC9[0] * in[4] - kC9[0] * in[6];
    284       float odd0 =
    285           kC9[1] * in[1] + kC9[3] * in[3] + kC9[5] * in[5] + kC9[7] * in[7];
    286       float odd1 = kC9[3] * in[1] - kC9[3] * in[5] - kC9[3] * in[7];
    287       float odd2 =
    288           kC9[5] * in[1] - kC9[3] * in[3] - kC9[7] * in[5] + kC9[1] * in[7];
    289       float odd3 =
    290           kC9[7] * in[1] - kC9[3] * in[3] + kC9[1] * in[5] - kC9[5] * in[7];
    291       out[0] = even0 + odd0;
    292       out[8] = even0 - odd0;
    293       out[1] = even1 + odd1;
    294       out[7] = even1 - odd1;
    295       out[2] = even2 + odd2;
    296       out[6] = even2 - odd2;
    297       out[3] = even3 + odd3;
    298       out[5] = even3 - odd3;
    299       out[4] = even4;
    300       break;
    301     }
    302     case 10: {
    303       static constexpr float kC10[10] = {
    304           1.414213562373, 1.396802246667, 1.344997023928, 1.260073510670,
    305           1.144122805635, 1.000000000000, 0.831253875555, 0.642039521920,
    306           0.437016024449, 0.221231742082,
    307       };
    308       float even0 = in[0] + kC10[2] * in[2] + kC10[4] * in[4] + kC10[6] * in[6];
    309       float even1 = in[0] + kC10[6] * in[2] - kC10[8] * in[4] - kC10[2] * in[6];
    310       float even2 = in[0] - kC10[0] * in[4];
    311       float even3 = in[0] - kC10[6] * in[2] - kC10[8] * in[4] + kC10[2] * in[6];
    312       float even4 = in[0] - kC10[2] * in[2] + kC10[4] * in[4] - kC10[6] * in[6];
    313       float odd0 =
    314           kC10[1] * in[1] + kC10[3] * in[3] + kC10[5] * in[5] + kC10[7] * in[7];
    315       float odd1 =
    316           kC10[3] * in[1] + kC10[9] * in[3] - kC10[5] * in[5] - kC10[1] * in[7];
    317       float odd2 =
    318           kC10[5] * in[1] - kC10[5] * in[3] - kC10[5] * in[5] + kC10[5] * in[7];
    319       float odd3 =
    320           kC10[7] * in[1] - kC10[1] * in[3] + kC10[5] * in[5] + kC10[9] * in[7];
    321       float odd4 =
    322           kC10[9] * in[1] - kC10[7] * in[3] + kC10[5] * in[5] - kC10[3] * in[7];
    323       out[0] = even0 + odd0;
    324       out[9] = even0 - odd0;
    325       out[1] = even1 + odd1;
    326       out[8] = even1 - odd1;
    327       out[2] = even2 + odd2;
    328       out[7] = even2 - odd2;
    329       out[3] = even3 + odd3;
    330       out[6] = even3 - odd3;
    331       out[4] = even4 + odd4;
    332       out[5] = even4 - odd4;
    333       break;
    334     }
    335     case 11: {
    336       static constexpr float kC11[11] = {
    337           1.414213562373, 1.399818907436, 1.356927976287, 1.286413904599,
    338           1.189712155524, 1.068791297809, 0.926112931411, 0.764581576418,
    339           0.587485545401, 0.398430002847, 0.201263574413,
    340       };
    341       float even0 = in[0] + kC11[2] * in[2] + kC11[4] * in[4] + kC11[6] * in[6];
    342       float even1 =
    343           in[0] + kC11[6] * in[2] - kC11[10] * in[4] - kC11[4] * in[6];
    344       float even2 =
    345           in[0] + kC11[10] * in[2] - kC11[2] * in[4] - kC11[8] * in[6];
    346       float even3 = in[0] - kC11[8] * in[2] - kC11[6] * in[4] + kC11[2] * in[6];
    347       float even4 =
    348           in[0] - kC11[4] * in[2] + kC11[8] * in[4] + kC11[10] * in[6];
    349       float even5 = in[0] - kC11[0] * in[2] + kC11[0] * in[4] - kC11[0] * in[6];
    350       float odd0 =
    351           kC11[1] * in[1] + kC11[3] * in[3] + kC11[5] * in[5] + kC11[7] * in[7];
    352       float odd1 =
    353           kC11[3] * in[1] + kC11[9] * in[3] - kC11[7] * in[5] - kC11[1] * in[7];
    354       float odd2 =
    355           kC11[5] * in[1] - kC11[7] * in[3] - kC11[3] * in[5] + kC11[9] * in[7];
    356       float odd3 =
    357           kC11[7] * in[1] - kC11[1] * in[3] + kC11[9] * in[5] + kC11[5] * in[7];
    358       float odd4 =
    359           kC11[9] * in[1] - kC11[5] * in[3] + kC11[1] * in[5] - kC11[3] * in[7];
    360       out[0] = even0 + odd0;
    361       out[10] = even0 - odd0;
    362       out[1] = even1 + odd1;
    363       out[9] = even1 - odd1;
    364       out[2] = even2 + odd2;
    365       out[8] = even2 - odd2;
    366       out[3] = even3 + odd3;
    367       out[7] = even3 - odd3;
    368       out[4] = even4 + odd4;
    369       out[6] = even4 - odd4;
    370       out[5] = even5;
    371       break;
    372     }
    373     case 12: {
    374       static constexpr float kC12[12] = {
    375           1.414213562373, 1.402114769300, 1.366025403784, 1.306562964876,
    376           1.224744871392, 1.121971053594, 1.000000000000, 0.860918669154,
    377           0.707106781187, 0.541196100146, 0.366025403784, 0.184591911283,
    378       };
    379       float even0 = in[0] + kC12[2] * in[2] + kC12[4] * in[4] + kC12[6] * in[6];
    380       float even1 = in[0] + kC12[6] * in[2] - kC12[6] * in[6];
    381       float even2 =
    382           in[0] + kC12[10] * in[2] - kC12[4] * in[4] - kC12[6] * in[6];
    383       float even3 =
    384           in[0] - kC12[10] * in[2] - kC12[4] * in[4] + kC12[6] * in[6];
    385       float even4 = in[0] - kC12[6] * in[2] + kC12[6] * in[6];
    386       float even5 = in[0] - kC12[2] * in[2] + kC12[4] * in[4] - kC12[6] * in[6];
    387       float odd0 =
    388           kC12[1] * in[1] + kC12[3] * in[3] + kC12[5] * in[5] + kC12[7] * in[7];
    389       float odd1 =
    390           kC12[3] * in[1] + kC12[9] * in[3] - kC12[9] * in[5] - kC12[3] * in[7];
    391       float odd2 = kC12[5] * in[1] - kC12[9] * in[3] - kC12[1] * in[5] -
    392                    kC12[11] * in[7];
    393       float odd3 = kC12[7] * in[1] - kC12[3] * in[3] - kC12[11] * in[5] +
    394                    kC12[1] * in[7];
    395       float odd4 =
    396           kC12[9] * in[1] - kC12[3] * in[3] + kC12[3] * in[5] - kC12[9] * in[7];
    397       float odd5 = kC12[11] * in[1] - kC12[9] * in[3] + kC12[7] * in[5] -
    398                    kC12[5] * in[7];
    399       out[0] = even0 + odd0;
    400       out[11] = even0 - odd0;
    401       out[1] = even1 + odd1;
    402       out[10] = even1 - odd1;
    403       out[2] = even2 + odd2;
    404       out[9] = even2 - odd2;
    405       out[3] = even3 + odd3;
    406       out[8] = even3 - odd3;
    407       out[4] = even4 + odd4;
    408       out[7] = even4 - odd4;
    409       out[5] = even5 + odd5;
    410       out[6] = even5 - odd5;
    411       break;
    412     }
    413     case 13: {
    414       static constexpr float kC13[13] = {
    415           1.414213562373, 1.403902353238, 1.373119086479, 1.322312651445,
    416           1.252223920364, 1.163874944761, 1.058554051646, 0.937797056801,
    417           0.803364869133, 0.657217812653, 0.501487040539, 0.338443458124,
    418           0.170464607981,
    419       };
    420       float even0 = in[0] + kC13[2] * in[2] + kC13[4] * in[4] + kC13[6] * in[6];
    421       float even1 =
    422           in[0] + kC13[6] * in[2] + kC13[12] * in[4] - kC13[8] * in[6];
    423       float even2 =
    424           in[0] + kC13[10] * in[2] - kC13[6] * in[4] - kC13[4] * in[6];
    425       float even3 =
    426           in[0] - kC13[12] * in[2] - kC13[2] * in[4] + kC13[10] * in[6];
    427       float even4 =
    428           in[0] - kC13[8] * in[2] - kC13[10] * in[4] + kC13[2] * in[6];
    429       float even5 =
    430           in[0] - kC13[4] * in[2] + kC13[8] * in[4] - kC13[12] * in[6];
    431       float even6 = in[0] - kC13[0] * in[2] + kC13[0] * in[4] - kC13[0] * in[6];
    432       float odd0 =
    433           kC13[1] * in[1] + kC13[3] * in[3] + kC13[5] * in[5] + kC13[7] * in[7];
    434       float odd1 = kC13[3] * in[1] + kC13[9] * in[3] - kC13[11] * in[5] -
    435                    kC13[5] * in[7];
    436       float odd2 = kC13[5] * in[1] - kC13[11] * in[3] - kC13[1] * in[5] -
    437                    kC13[9] * in[7];
    438       float odd3 =
    439           kC13[7] * in[1] - kC13[5] * in[3] - kC13[9] * in[5] + kC13[3] * in[7];
    440       float odd4 = kC13[9] * in[1] - kC13[1] * in[3] + kC13[7] * in[5] +
    441                    kC13[11] * in[7];
    442       float odd5 = kC13[11] * in[1] - kC13[7] * in[3] + kC13[3] * in[5] -
    443                    kC13[1] * in[7];
    444       out[0] = even0 + odd0;
    445       out[12] = even0 - odd0;
    446       out[1] = even1 + odd1;
    447       out[11] = even1 - odd1;
    448       out[2] = even2 + odd2;
    449       out[10] = even2 - odd2;
    450       out[3] = even3 + odd3;
    451       out[9] = even3 - odd3;
    452       out[4] = even4 + odd4;
    453       out[8] = even4 - odd4;
    454       out[5] = even5 + odd5;
    455       out[7] = even5 - odd5;
    456       out[6] = even6;
    457       break;
    458     }
    459     case 14: {
    460       static constexpr float kC14[14] = {
    461           1.414213562373, 1.405321284327, 1.378756275744, 1.334852607020,
    462           1.274162392264, 1.197448846138, 1.105676685997, 1.000000000000,
    463           0.881747733790, 0.752406978226, 0.613604268353, 0.467085128785,
    464           0.314692122713, 0.158341680609,
    465       };
    466       float even0 = in[0] + kC14[2] * in[2] + kC14[4] * in[4] + kC14[6] * in[6];
    467       float even1 =
    468           in[0] + kC14[6] * in[2] + kC14[12] * in[4] - kC14[10] * in[6];
    469       float even2 =
    470           in[0] + kC14[10] * in[2] - kC14[8] * in[4] - kC14[2] * in[6];
    471       float even3 = in[0] - kC14[0] * in[4];
    472       float even4 =
    473           in[0] - kC14[10] * in[2] - kC14[8] * in[4] + kC14[2] * in[6];
    474       float even5 =
    475           in[0] - kC14[6] * in[2] + kC14[12] * in[4] + kC14[10] * in[6];
    476       float even6 = in[0] - kC14[2] * in[2] + kC14[4] * in[4] - kC14[6] * in[6];
    477       float odd0 =
    478           kC14[1] * in[1] + kC14[3] * in[3] + kC14[5] * in[5] + kC14[7] * in[7];
    479       float odd1 = kC14[3] * in[1] + kC14[9] * in[3] - kC14[13] * in[5] -
    480                    kC14[7] * in[7];
    481       float odd2 = kC14[5] * in[1] - kC14[13] * in[3] - kC14[3] * in[5] -
    482                    kC14[7] * in[7];
    483       float odd3 =
    484           kC14[7] * in[1] - kC14[7] * in[3] - kC14[7] * in[5] + kC14[7] * in[7];
    485       float odd4 = kC14[9] * in[1] - kC14[1] * in[3] + kC14[11] * in[5] +
    486                    kC14[7] * in[7];
    487       float odd5 = kC14[11] * in[1] - kC14[5] * in[3] + kC14[1] * in[5] -
    488                    kC14[7] * in[7];
    489       float odd6 = kC14[13] * in[1] - kC14[11] * in[3] + kC14[9] * in[5] -
    490                    kC14[7] * in[7];
    491       out[0] = even0 + odd0;
    492       out[13] = even0 - odd0;
    493       out[1] = even1 + odd1;
    494       out[12] = even1 - odd1;
    495       out[2] = even2 + odd2;
    496       out[11] = even2 - odd2;
    497       out[3] = even3 + odd3;
    498       out[10] = even3 - odd3;
    499       out[4] = even4 + odd4;
    500       out[9] = even4 - odd4;
    501       out[5] = even5 + odd5;
    502       out[8] = even5 - odd5;
    503       out[6] = even6 + odd6;
    504       out[7] = even6 - odd6;
    505       break;
    506     }
    507     case 15: {
    508       static constexpr float kC15[15] = {
    509           1.414213562373, 1.406466352507, 1.383309602960, 1.344997023928,
    510           1.291948376043, 1.224744871392, 1.144122805635, 1.050965490998,
    511           0.946293578512, 0.831253875555, 0.707106781187, 0.575212476952,
    512           0.437016024449, 0.294031532930, 0.147825570407,
    513       };
    514       float even0 = in[0] + kC15[2] * in[2] + kC15[4] * in[4] + kC15[6] * in[6];
    515       float even1 =
    516           in[0] + kC15[6] * in[2] + kC15[12] * in[4] - kC15[12] * in[6];
    517       float even2 =
    518           in[0] + kC15[10] * in[2] - kC15[10] * in[4] - kC15[0] * in[6];
    519       float even3 =
    520           in[0] + kC15[14] * in[2] - kC15[2] * in[4] - kC15[12] * in[6];
    521       float even4 =
    522           in[0] - kC15[12] * in[2] - kC15[6] * in[4] + kC15[6] * in[6];
    523       float even5 =
    524           in[0] - kC15[8] * in[2] - kC15[14] * in[4] + kC15[6] * in[6];
    525       float even6 =
    526           in[0] - kC15[4] * in[2] + kC15[8] * in[4] - kC15[12] * in[6];
    527       float even7 = in[0] - kC15[0] * in[2] + kC15[0] * in[4] - kC15[0] * in[6];
    528       float odd0 =
    529           kC15[1] * in[1] + kC15[3] * in[3] + kC15[5] * in[5] + kC15[7] * in[7];
    530       float odd1 = kC15[3] * in[1] + kC15[9] * in[3] - kC15[9] * in[7];
    531       float odd2 = kC15[5] * in[1] - kC15[5] * in[5] - kC15[5] * in[7];
    532       float odd3 = kC15[7] * in[1] - kC15[9] * in[3] - kC15[5] * in[5] +
    533                    kC15[11] * in[7];
    534       float odd4 = kC15[9] * in[1] - kC15[3] * in[3] + kC15[3] * in[7];
    535       float odd5 = kC15[11] * in[1] - kC15[3] * in[3] + kC15[5] * in[5] -
    536                    kC15[13] * in[7];
    537       float odd6 = kC15[13] * in[1] - kC15[9] * in[3] + kC15[5] * in[5] -
    538                    kC15[1] * in[7];
    539       out[0] = even0 + odd0;
    540       out[14] = even0 - odd0;
    541       out[1] = even1 + odd1;
    542       out[13] = even1 - odd1;
    543       out[2] = even2 + odd2;
    544       out[12] = even2 - odd2;
    545       out[3] = even3 + odd3;
    546       out[11] = even3 - odd3;
    547       out[4] = even4 + odd4;
    548       out[10] = even4 - odd4;
    549       out[5] = even5 + odd5;
    550       out[9] = even5 - odd5;
    551       out[6] = even6 + odd6;
    552       out[8] = even6 - odd6;
    553       out[7] = even7;
    554       break;
    555     }
    556     case 16: {
    557       static constexpr float kC16[16] = {
    558           1.414213562373, 1.407403737526, 1.387039845322, 1.353318001174,
    559           1.306562964876, 1.247225012987, 1.175875602419, 1.093201867002,
    560           1.000000000000, 0.897167586343, 0.785694958387, 0.666655658478,
    561           0.541196100146, 0.410524527522, 0.275899379283, 0.138617169199,
    562       };
    563       float even0 = in[0] + kC16[2] * in[2] + kC16[4] * in[4] + kC16[6] * in[6];
    564       float even1 =
    565           in[0] + kC16[6] * in[2] + kC16[12] * in[4] - kC16[14] * in[6];
    566       float even2 =
    567           in[0] + kC16[10] * in[2] - kC16[12] * in[4] - kC16[2] * in[6];
    568       float even3 =
    569           in[0] + kC16[14] * in[2] - kC16[4] * in[4] - kC16[10] * in[6];
    570       float even4 =
    571           in[0] - kC16[14] * in[2] - kC16[4] * in[4] + kC16[10] * in[6];
    572       float even5 =
    573           in[0] - kC16[10] * in[2] - kC16[12] * in[4] + kC16[2] * in[6];
    574       float even6 =
    575           in[0] - kC16[6] * in[2] + kC16[12] * in[4] + kC16[14] * in[6];
    576       float even7 = in[0] - kC16[2] * in[2] + kC16[4] * in[4] - kC16[6] * in[6];
    577       float odd0 = (kC16[1] * in[1] + kC16[3] * in[3] + kC16[5] * in[5] +
    578                     kC16[7] * in[7]);
    579       float odd1 = (kC16[3] * in[1] + kC16[9] * in[3] + kC16[15] * in[5] -
    580                     kC16[11] * in[7]);
    581       float odd2 = (kC16[5] * in[1] + kC16[15] * in[3] - kC16[7] * in[5] -
    582                     kC16[3] * in[7]);
    583       float odd3 = (kC16[7] * in[1] - kC16[11] * in[3] - kC16[3] * in[5] +
    584                     kC16[15] * in[7]);
    585       float odd4 = (kC16[9] * in[1] - kC16[5] * in[3] - kC16[13] * in[5] +
    586                     kC16[1] * in[7]);
    587       float odd5 = (kC16[11] * in[1] - kC16[1] * in[3] + kC16[9] * in[5] +
    588                     kC16[13] * in[7]);
    589       float odd6 = (kC16[13] * in[1] - kC16[7] * in[3] + kC16[1] * in[5] -
    590                     kC16[5] * in[7]);
    591       float odd7 = (kC16[15] * in[1] - kC16[13] * in[3] + kC16[11] * in[5] -
    592                     kC16[9] * in[7]);
    593       out[0] = even0 + odd0;
    594       out[15] = even0 - odd0;
    595       out[1] = even1 + odd1;
    596       out[14] = even1 - odd1;
    597       out[2] = even2 + odd2;
    598       out[13] = even2 - odd2;
    599       out[3] = even3 + odd3;
    600       out[12] = even3 - odd3;
    601       out[4] = even4 + odd4;
    602       out[11] = even4 - odd4;
    603       out[5] = even5 + odd5;
    604       out[10] = even5 - odd5;
    605       out[6] = even6 + odd6;
    606       out[9] = even6 - odd6;
    607       out[7] = even7 + odd7;
    608       out[8] = even7 - odd7;
    609       break;
    610     }
    611     default:
    612       JXL_ABORT("Compute1dIDCT does not support N=%d", static_cast<int>(N));
    613       break;
    614   }
    615 }
    616 
    617 void InverseTransformBlockGeneric(const int16_t* JXL_RESTRICT qblock,
    618                                   const float* JXL_RESTRICT dequant,
    619                                   const float* JXL_RESTRICT biases,
    620                                   float* JXL_RESTRICT scratch_space,
    621                                   float* JXL_RESTRICT output,
    622                                   size_t output_stride, size_t dctsize) {
    623   float* JXL_RESTRICT block0 = scratch_space;
    624   float* JXL_RESTRICT block1 = scratch_space + DCTSIZE2;
    625   DequantBlock(qblock, dequant, biases, block0);
    626   if (dctsize == 1) {
    627     *output = *block0;
    628   } else if (dctsize == 2 || dctsize == 4) {
    629     float* JXL_RESTRICT block2 = scratch_space + 2 * DCTSIZE2;
    630     ComputeScaledIDCT(block0, block1, block2, 8);
    631     if (dctsize == 4) {
    632       for (size_t iy = 0; iy < 4; ++iy) {
    633         for (size_t ix = 0; ix < 4; ++ix) {
    634           float* block = &block2[16 * iy + 2 * ix];
    635           output[iy * output_stride + ix] =
    636               0.25f * (block[0] + block[1] + block[8] + block[9]);
    637         }
    638       }
    639     } else {
    640       for (size_t iy = 0; iy < 2; ++iy) {
    641         for (size_t ix = 0; ix < 2; ++ix) {
    642           float* block = &block2[32 * iy + 4 * ix];
    643           output[iy * output_stride + ix] =
    644               0.0625f *
    645               (block[0] + block[1] + block[2] + block[3] + block[8] + block[9] +
    646                block[10] + block[11] + block[16] + block[17] + block[18] +
    647                block[19] + block[24] + block[25] + block[26] + block[27]);
    648         }
    649       }
    650     }
    651   } else {
    652     float dctin[DCTSIZE];
    653     float dctout[DCTSIZE * 2];
    654     size_t insize = std::min<size_t>(dctsize, DCTSIZE);
    655     for (size_t ix = 0; ix < insize; ++ix) {
    656       for (size_t iy = 0; iy < insize; ++iy) {
    657         dctin[iy] = block0[iy * DCTSIZE + ix];
    658       }
    659       Compute1dIDCT(dctin, dctout, dctsize);
    660       for (size_t iy = 0; iy < dctsize; ++iy) {
    661         block1[iy * dctsize + ix] = dctout[iy];
    662       }
    663     }
    664     for (size_t iy = 0; iy < dctsize; ++iy) {
    665       Compute1dIDCT(block1 + iy * dctsize, output + iy * output_stride,
    666                     dctsize);
    667     }
    668   }
    669 }
    670 
    671 // NOLINTNEXTLINE(google-readability-namespace-comments)
    672 }  // namespace HWY_NAMESPACE
    673 }  // namespace jpegli
    674 HWY_AFTER_NAMESPACE();
    675 
    676 #if HWY_ONCE
    677 namespace jpegli {
    678 
    679 HWY_EXPORT(InverseTransformBlock8x8);
    680 HWY_EXPORT(InverseTransformBlockGeneric);
    681 
    682 void ChooseInverseTransform(j_decompress_ptr cinfo) {
    683   jpeg_decomp_master* m = cinfo->master;
    684   for (int c = 0; c < cinfo->num_components; ++c) {
    685     if (m->scaled_dct_size[c] == DCTSIZE) {
    686       m->inverse_transform[c] = HWY_DYNAMIC_DISPATCH(InverseTransformBlock8x8);
    687     } else {
    688       m->inverse_transform[c] =
    689           HWY_DYNAMIC_DISPATCH(InverseTransformBlockGeneric);
    690     }
    691   }
    692 }
    693 
    694 }  // namespace jpegli
    695 #endif  // HWY_ONCE