libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

fast_dct-inl.h (9118B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #if defined(LIB_JXL_FAST_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE)
      7 #ifdef LIB_JXL_FAST_DCT_INL_H_
      8 #undef LIB_JXL_FAST_DCT_INL_H_
      9 #else
     10 #define LIB_JXL_FAST_DCT_INL_H_
     11 #endif
     12 
     13 #include <cmath>
     14 #include <hwy/aligned_allocator.h>
     15 #include <hwy/highway.h>
     16 
     17 #include "lib/jxl/base/status.h"
     18 
     19 HWY_BEFORE_NAMESPACE();
     20 namespace jxl {
     21 namespace HWY_NAMESPACE {
     22 namespace {
     23 
     24 #if HWY_TARGET == HWY_NEON
// Transposes an N x M block of int16_t values using NEON intrinsics.
//
// Reads N rows of M elements from `data_in` (row stride `stride_in`) and
// writes M rows of N elements to `data_out` (row stride `stride_out`), so that
//   data_out[c * stride_out + r] == data_in[r * stride_in + c].
//
// The block is processed as 8x8 tiles, so N and M must both be multiples of 8
// (checked with JXL_DASSERT); otherwise the last tile would read and write past
// the block. `data_in` and `data_out` are JXL_RESTRICT and must not overlap.
HWY_NOINLINE void FastTransposeBlock(const int16_t* JXL_RESTRICT data_in,
                                     size_t stride_in, size_t N, size_t M,
                                     int16_t* JXL_RESTRICT data_out,
                                     size_t stride_out) {
  JXL_DASSERT(N % 8 == 0);
  JXL_DASSERT(M % 8 == 0);
  // (i, j) is the top-left corner of the current 8x8 input tile. Its
  // transpose is written to the output tile whose top-left corner is (j, i).
  for (size_t i = 0; i < N; i += 8) {
    for (size_t j = 0; j < M; j += 8) {
      // TODO(veluca): one could optimize the M==8, stride_in==8 case further
      // with vld4.
      // The commented-out vld4-based code below is reportedly about 40% faster
      // for N == M == stride_in == stride_out == 8. Note that it ignores i, j
      // and the strides, so it only applies to a single contiguous 8x8 block.
      // Using loads + stores to reshuffle things to be able to
      // use vld4 doesn't help.
      /*
      auto a0 = vld4q_s16(data_in); auto a1 = vld4q_s16(data_in + 32);
      int16x8x4_t out0;
      int16x8x4_t out1;
      out0.val[0] = vuzp1q_s16(a0.val[0], a1.val[0]);
      out0.val[1] = vuzp1q_s16(a0.val[1], a1.val[1]);
      out0.val[2] = vuzp1q_s16(a0.val[2], a1.val[2]);
      out0.val[3] = vuzp1q_s16(a0.val[3], a1.val[3]);
      out1.val[0] = vuzp2q_s16(a0.val[0], a1.val[0]);
      out1.val[1] = vuzp2q_s16(a0.val[1], a1.val[1]);
      out1.val[2] = vuzp2q_s16(a0.val[2], a1.val[2]);
      out1.val[3] = vuzp2q_s16(a0.val[3], a1.val[3]);
      vst1q_s16_x4(data_out, out0);
      vst1q_s16_x4(data_out + 32, out1);
      */
      // Load input rows i..i+3 of the tile (8 lanes each).
      auto a0 = vld1q_s16(data_in + i * stride_in + j);
      auto a1 = vld1q_s16(data_in + (i + 1) * stride_in + j);
      auto a2 = vld1q_s16(data_in + (i + 2) * stride_in + j);
      auto a3 = vld1q_s16(data_in + (i + 3) * stride_in + j);

      // vtrnq_s16 interleaves 16-bit lanes of a row pair:
      //   val[0] = {a0[0], a1[0], a0[2], a1[2], a0[4], a1[4], a0[6], a1[6]}
      //   val[1] = {a0[1], a1[1], a0[3], a1[3], a0[5], a1[5], a0[7], a1[7]}
      auto a01 = vtrnq_s16(a0, a1);
      auto a23 = vtrnq_s16(a2, a3);

      // Treating each (row k, row k+1) 16-bit pair as one 32-bit lane,
      // vtrnq_s32 interleaves those pairs. Afterwards, for rows i..i+3:
      //   four0.val[0]: column 0 in the low half, column 4 in the high half
      //   four0.val[1]: column 2 (low), column 6 (high)
      //   four1.val[0]: column 1 (low), column 5 (high)
      //   four1.val[1]: column 3 (low), column 7 (high)
      auto four0 = vtrnq_s32(vreinterpretq_s32_s16(a01.val[0]),
                             vreinterpretq_s32_s16(a23.val[0]));
      auto four1 = vtrnq_s32(vreinterpretq_s32_s16(a01.val[1]),
                             vreinterpretq_s32_s16(a23.val[1]));

      // Same procedure for input rows i+4..i+7.
      auto a4 = vld1q_s16(data_in + (i + 4) * stride_in + j);
      auto a5 = vld1q_s16(data_in + (i + 5) * stride_in + j);
      auto a6 = vld1q_s16(data_in + (i + 6) * stride_in + j);
      auto a7 = vld1q_s16(data_in + (i + 7) * stride_in + j);

      auto a45 = vtrnq_s16(a4, a5);
      auto a67 = vtrnq_s16(a6, a7);

      // four2/four3 have the same column layout as four0/four1, for rows
      // i+4..i+7.
      auto four2 = vtrnq_s32(vreinterpretq_s32_s16(a45.val[0]),
                             vreinterpretq_s32_s16(a67.val[0]));
      auto four3 = vtrnq_s32(vreinterpretq_s32_s16(a45.val[1]),
                             vreinterpretq_s32_s16(a67.val[1]));

      // Join the rows i..i+3 half with the rows i+4..i+7 half of the same
      // column: outK holds tile column K (all 8 rows), i.e. output row j+K.
      auto out0 =
          vcombine_s32(vget_low_s32(four0.val[0]), vget_low_s32(four2.val[0]));
      auto out1 =
          vcombine_s32(vget_low_s32(four1.val[0]), vget_low_s32(four3.val[0]));
      auto out2 =
          vcombine_s32(vget_low_s32(four0.val[1]), vget_low_s32(four2.val[1]));
      auto out3 =
          vcombine_s32(vget_low_s32(four1.val[1]), vget_low_s32(four3.val[1]));
      auto out4 = vcombine_s32(vget_high_s32(four0.val[0]),
                               vget_high_s32(four2.val[0]));
      auto out5 = vcombine_s32(vget_high_s32(four1.val[0]),
                               vget_high_s32(four3.val[0]));
      auto out6 = vcombine_s32(vget_high_s32(four0.val[1]),
                               vget_high_s32(four2.val[1]));
      auto out7 = vcombine_s32(vget_high_s32(four1.val[1]),
                               vget_high_s32(four3.val[1]));
      // Store output rows j..j+7, columns i..i+7.
      vst1q_s16(data_out + j * stride_out + i, vreinterpretq_s16_s32(out0));
      vst1q_s16(data_out + (j + 1) * stride_out + i,
                vreinterpretq_s16_s32(out1));
      vst1q_s16(data_out + (j + 2) * stride_out + i,
                vreinterpretq_s16_s32(out2));
      vst1q_s16(data_out + (j + 3) * stride_out + i,
                vreinterpretq_s16_s32(out3));
      vst1q_s16(data_out + (j + 4) * stride_out + i,
                vreinterpretq_s16_s32(out4));
      vst1q_s16(data_out + (j + 5) * stride_out + i,
                vreinterpretq_s16_s32(out5));
      vst1q_s16(data_out + (j + 6) * stride_out + i,
                vreinterpretq_s16_s32(out6));
      vst1q_s16(data_out + (j + 7) * stride_out + i,
                vreinterpretq_s16_s32(out7));
    }
  }
}
    114 
    115 template <size_t N>
    116 struct FastDCTTag {};
    117 
    118 #include "lib/jxl/fast_dct128-inl.h"
    119 #include "lib/jxl/fast_dct16-inl.h"
    120 #include "lib/jxl/fast_dct256-inl.h"
    121 #include "lib/jxl/fast_dct32-inl.h"
    122 #include "lib/jxl/fast_dct64-inl.h"
    123 #include "lib/jxl/fast_dct8-inl.h"
    124 
    125 template <size_t ROWS, size_t COLS>
    126 struct ComputeFastScaledIDCT {
    127   // scratch_space must be aligned, and should have space for ROWS*COLS
    128   // int16_ts.
    129   HWY_MAYBE_UNUSED void operator()(int16_t* JXL_RESTRICT from, int16_t* to,
    130                                    size_t to_stride,
    131                                    int16_t* JXL_RESTRICT scratch_space) {
    132     // Reverse the steps done in ComputeScaledDCT.
    133     if (ROWS < COLS) {
    134       FastTransposeBlock(from, COLS, ROWS, COLS, scratch_space, ROWS);
    135       FastIDCT(FastDCTTag<COLS>(), scratch_space, ROWS, from, ROWS, ROWS);
    136       FastTransposeBlock(from, ROWS, COLS, ROWS, scratch_space, COLS);
    137       FastIDCT(FastDCTTag<ROWS>(), scratch_space, COLS, to, to_stride, COLS);
    138     } else {
    139       FastIDCT(FastDCTTag<COLS>(), from, ROWS, scratch_space, ROWS, ROWS);
    140       FastTransposeBlock(scratch_space, ROWS, COLS, ROWS, from, COLS);
    141       FastIDCT(FastDCTTag<ROWS>(), from, COLS, to, to_stride, COLS);
    142     }
    143   }
    144 };
    145 #endif
    146 
    147 template <size_t N, size_t M>
    148 HWY_NOINLINE void TestFastIDCT() {
    149 #if HWY_TARGET == HWY_NEON
    150   auto pixels_mem = hwy::AllocateAligned<float>(N * M);
    151   float* pixels = pixels_mem.get();
    152   auto dct_mem = hwy::AllocateAligned<float>(N * M);
    153   float* dct = dct_mem.get();
    154   auto dct_i_mem = hwy::AllocateAligned<int16_t>(N * M);
    155   int16_t* dct_i = dct_i_mem.get();
    156   auto dct_in_mem = hwy::AllocateAligned<int16_t>(N * M);
    157   int16_t* dct_in = dct_in_mem.get();
    158   auto idct_mem = hwy::AllocateAligned<int16_t>(N * M);
    159   int16_t* idct = idct_mem.get();
    160 
    161   const HWY_FULL(float) df;
    162   auto scratch_space_mem = hwy::AllocateAligned<float>(
    163       N * M * 2 + 3 * std::max(N, M) * MaxLanes(df));
    164   float* scratch_space = scratch_space_mem.get();
    165   auto scratch_space_i_mem = hwy::AllocateAligned<int16_t>(N * M * 2);
    166   int16_t* scratch_space_i = scratch_space_i_mem.get();
    167 
    168   Rng rng(0);
    169   for (size_t i = 0; i < N * M; i++) {
    170     pixels[i] = rng.UniformF(-1, 1);
    171   }
    172   ComputeScaledDCT<M, N>()(DCTFrom(pixels, N), dct, scratch_space);
    173   size_t integer_bits = std::max(FastIDCTIntegerBits(FastDCTTag<N>()),
    174                                  FastIDCTIntegerBits(FastDCTTag<M>()));
    175   // Enough range for [-2, 2] output values.
    176   JXL_ASSERT(integer_bits <= 14);
    177   float scale = (1 << (14 - integer_bits));
    178   for (size_t i = 0; i < N * M; i++) {
    179     dct_i[i] = std::round(dct[i] * scale);
    180   }
    181 
    182   for (size_t j = 0; j < 40000000 / (M * N); j++) {
    183     memcpy(dct_in, dct_i, sizeof(*dct_i) * N * M);
    184     ComputeFastScaledIDCT<M, N>()(dct_in, idct, N, scratch_space_i);
    185   }
    186   float max_error = 0;
    187   for (size_t i = 0; i < M * N; i++) {
    188     float err = std::abs(idct[i] * (1.0f / scale) - pixels[i]);
    189     if (std::abs(err) > max_error) {
    190       max_error = std::abs(err);
    191     }
    192   }
    193   printf("max error: %f mantissa bits: %d\n", max_error,
    194          14 - static_cast<int>(integer_bits));
    195 #endif
    196 }
    197 
    198 template <size_t N, size_t M>
    199 HWY_NOINLINE void TestFloatIDCT() {
    200   auto pixels_mem = hwy::AllocateAligned<float>(N * M);
    201   float* pixels = pixels_mem.get();
    202   auto dct_mem = hwy::AllocateAligned<float>(N * M);
    203   float* dct = dct_mem.get();
    204   auto idct_mem = hwy::AllocateAligned<float>(N * M);
    205   float* idct = idct_mem.get();
    206 
    207   auto dct_in_mem = hwy::AllocateAligned<float>(N * M);
    208   float* dct_in = dct_mem.get();
    209 
    210   auto scratch_space_mem = hwy::AllocateAligned<float>(N * M * 5);
    211   float* scratch_space = scratch_space_mem.get();
    212 
    213   Rng rng(0);
    214   for (size_t i = 0; i < N * M; i++) {
    215     pixels[i] = rng.UniformF(-1, 1);
    216   }
    217   ComputeScaledDCT<M, N>()(DCTFrom(pixels, N), dct, scratch_space);
    218 
    219   for (size_t j = 0; j < 40000000 / (M * N); j++) {
    220     memcpy(dct_in, dct, sizeof(*dct) * N * M);
    221     ComputeScaledIDCT<M, N>()(dct_in, DCTTo(idct, N), scratch_space);
    222   }
    223   float max_error = 0;
    224   for (size_t i = 0; i < M * N; i++) {
    225     float err = std::abs(idct[i] - pixels[i]);
    226     if (std::abs(err) > max_error) {
    227       max_error = std::abs(err);
    228     }
    229   }
    230   printf("max error: %e\n", max_error);
    231 }
    232 
    233 }  // namespace
    234 // NOLINTNEXTLINE(google-readability-namespace-comments)
    235 }  // namespace HWY_NAMESPACE
    236 }  // namespace jxl
    237 HWY_AFTER_NAMESPACE();
    238 
    239 #endif  // LIB_JXL_FAST_DCT_INL_H_