fast_dct-inl.h (9118B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #if defined(LIB_JXL_FAST_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE) 7 #ifdef LIB_JXL_FAST_DCT_INL_H_ 8 #undef LIB_JXL_FAST_DCT_INL_H_ 9 #else 10 #define LIB_JXL_FAST_DCT_INL_H_ 11 #endif 12 13 #include <cmath> 14 #include <hwy/aligned_allocator.h> 15 #include <hwy/highway.h> 16 17 #include "lib/jxl/base/status.h" 18 19 HWY_BEFORE_NAMESPACE(); 20 namespace jxl { 21 namespace HWY_NAMESPACE { 22 namespace { 23 24 #if HWY_TARGET == HWY_NEON 25 HWY_NOINLINE void FastTransposeBlock(const int16_t* JXL_RESTRICT data_in, 26 size_t stride_in, size_t N, size_t M, 27 int16_t* JXL_RESTRICT data_out, 28 size_t stride_out) { 29 JXL_DASSERT(N % 8 == 0); 30 JXL_DASSERT(M % 8 == 0); 31 for (size_t i = 0; i < N; i += 8) { 32 for (size_t j = 0; j < M; j += 8) { 33 // TODO(veluca): one could optimize the M==8, stride_in==8 case further 34 // with vld4. 35 // This code is about 40% faster for N == M == stride_in == 36 // stride_out == 8 37 // Using loads + stores to reshuffle things to be able to 38 // use vld4 doesn't help. 39 /* 40 auto a0 = vld4q_s16(data_in); auto a1 = vld4q_s16(data_in + 32); 41 int16x8x4_t out0; 42 int16x8x4_t out1; 43 out0.val[0] = vuzp1q_s16(a0.val[0], a1.val[0]); 44 out0.val[1] = vuzp1q_s16(a0.val[1], a1.val[1]); 45 out0.val[2] = vuzp1q_s16(a0.val[2], a1.val[2]); 46 out0.val[3] = vuzp1q_s16(a0.val[3], a1.val[3]); 47 out1.val[0] = vuzp2q_s16(a0.val[0], a1.val[0]); 48 out1.val[1] = vuzp2q_s16(a0.val[1], a1.val[1]); 49 out1.val[2] = vuzp2q_s16(a0.val[2], a1.val[2]); 50 out1.val[3] = vuzp2q_s16(a0.val[3], a1.val[3]); 51 vst1q_s16_x4(data_out, out0); 52 vst1q_s16_x4(data_out + 32, out1); 53 */ 54 auto a0 = vld1q_s16(data_in + i * stride_in + j); 55 auto a1 = vld1q_s16(data_in + (i + 1) * stride_in + j); 56 auto a2 = vld1q_s16(data_in + (i + 2) * stride_in + j); 57 auto a3 = vld1q_s16(data_in + (i + 3) * stride_in + j); 58 59 auto a01 = vtrnq_s16(a0, a1); 60 auto a23 = vtrnq_s16(a2, a3); 61 62 auto four0 = vtrnq_s32(vreinterpretq_s32_s16(a01.val[0]), 63 vreinterpretq_s32_s16(a23.val[0])); 64 auto four1 = vtrnq_s32(vreinterpretq_s32_s16(a01.val[1]), 65 vreinterpretq_s32_s16(a23.val[1])); 66 67 auto a4 = vld1q_s16(data_in + (i + 4) * stride_in + j); 68 auto a5 = vld1q_s16(data_in + (i + 5) * stride_in + j); 69 auto a6 = vld1q_s16(data_in + (i + 6) * stride_in + j); 70 auto a7 = vld1q_s16(data_in + (i + 7) * stride_in + j); 71 72 auto a45 = vtrnq_s16(a4, a5); 73 auto a67 = vtrnq_s16(a6, a7); 74 75 auto four2 = vtrnq_s32(vreinterpretq_s32_s16(a45.val[0]), 76 vreinterpretq_s32_s16(a67.val[0])); 77 auto four3 = vtrnq_s32(vreinterpretq_s32_s16(a45.val[1]), 78 vreinterpretq_s32_s16(a67.val[1])); 79 80 auto out0 = 81 vcombine_s32(vget_low_s32(four0.val[0]), vget_low_s32(four2.val[0])); 82 auto out1 = 83 vcombine_s32(vget_low_s32(four1.val[0]), vget_low_s32(four3.val[0])); 84 auto out2 = 85 vcombine_s32(vget_low_s32(four0.val[1]), vget_low_s32(four2.val[1])); 86 auto out3 = 87 vcombine_s32(vget_low_s32(four1.val[1]), vget_low_s32(four3.val[1])); 88 auto out4 = vcombine_s32(vget_high_s32(four0.val[0]), 89 vget_high_s32(four2.val[0])); 90 auto out5 = vcombine_s32(vget_high_s32(four1.val[0]), 91 vget_high_s32(four3.val[0])); 92 auto out6 = vcombine_s32(vget_high_s32(four0.val[1]), 93 vget_high_s32(four2.val[1])); 94 auto out7 = vcombine_s32(vget_high_s32(four1.val[1]), 95 vget_high_s32(four3.val[1])); 96 vst1q_s16(data_out + j * stride_out + i, vreinterpretq_s16_s32(out0)); 97 vst1q_s16(data_out + (j + 1) * stride_out + i, 98 vreinterpretq_s16_s32(out1)); 99 vst1q_s16(data_out + (j + 2) * stride_out + i, 100 vreinterpretq_s16_s32(out2)); 101 vst1q_s16(data_out + (j + 3) * stride_out + i, 102 vreinterpretq_s16_s32(out3)); 103 vst1q_s16(data_out + (j + 4) * stride_out + i, 104 vreinterpretq_s16_s32(out4)); 105 vst1q_s16(data_out + (j + 5) * stride_out + i, 106 vreinterpretq_s16_s32(out5)); 107 vst1q_s16(data_out + (j + 6) * stride_out + i, 108 vreinterpretq_s16_s32(out6)); 109 vst1q_s16(data_out + (j + 7) * stride_out + i, 110 vreinterpretq_s16_s32(out7)); 111 } 112 } 113 } 114 115 template <size_t N> 116 struct FastDCTTag {}; 117 118 #include "lib/jxl/fast_dct128-inl.h" 119 #include "lib/jxl/fast_dct16-inl.h" 120 #include "lib/jxl/fast_dct256-inl.h" 121 #include "lib/jxl/fast_dct32-inl.h" 122 #include "lib/jxl/fast_dct64-inl.h" 123 #include "lib/jxl/fast_dct8-inl.h" 124 125 template <size_t ROWS, size_t COLS> 126 struct ComputeFastScaledIDCT { 127 // scratch_space must be aligned, and should have space for ROWS*COLS 128 // int16_ts. 129 HWY_MAYBE_UNUSED void operator()(int16_t* JXL_RESTRICT from, int16_t* to, 130 size_t to_stride, 131 int16_t* JXL_RESTRICT scratch_space) { 132 // Reverse the steps done in ComputeScaledDCT. 133 if (ROWS < COLS) { 134 FastTransposeBlock(from, COLS, ROWS, COLS, scratch_space, ROWS); 135 FastIDCT(FastDCTTag<COLS>(), scratch_space, ROWS, from, ROWS, ROWS); 136 FastTransposeBlock(from, ROWS, COLS, ROWS, scratch_space, COLS); 137 FastIDCT(FastDCTTag<ROWS>(), scratch_space, COLS, to, to_stride, COLS); 138 } else { 139 FastIDCT(FastDCTTag<COLS>(), from, ROWS, scratch_space, ROWS, ROWS); 140 FastTransposeBlock(scratch_space, ROWS, COLS, ROWS, from, COLS); 141 FastIDCT(FastDCTTag<ROWS>(), from, COLS, to, to_stride, COLS); 142 } 143 } 144 }; 145 #endif 146 147 template <size_t N, size_t M> 148 HWY_NOINLINE void TestFastIDCT() { 149 #if HWY_TARGET == HWY_NEON 150 auto pixels_mem = hwy::AllocateAligned<float>(N * M); 151 float* pixels = pixels_mem.get(); 152 auto dct_mem = hwy::AllocateAligned<float>(N * M); 153 float* dct = dct_mem.get(); 154 auto dct_i_mem = hwy::AllocateAligned<int16_t>(N * M); 155 int16_t* dct_i = dct_i_mem.get(); 156 auto dct_in_mem = hwy::AllocateAligned<int16_t>(N * M); 157 int16_t* dct_in = dct_in_mem.get(); 158 auto idct_mem = hwy::AllocateAligned<int16_t>(N * M); 159 int16_t* idct = idct_mem.get(); 160 161 const HWY_FULL(float) df; 162 auto scratch_space_mem = hwy::AllocateAligned<float>( 163 N * M * 2 + 3 * std::max(N, M) * MaxLanes(df)); 164 float* scratch_space = scratch_space_mem.get(); 165 auto scratch_space_i_mem = hwy::AllocateAligned<int16_t>(N * M * 2); 166 int16_t* scratch_space_i = scratch_space_i_mem.get(); 167 168 Rng rng(0); 169 for (size_t i = 0; i < N * M; i++) { 170 pixels[i] = rng.UniformF(-1, 1); 171 } 172 ComputeScaledDCT<M, N>()(DCTFrom(pixels, N), dct, scratch_space); 173 size_t integer_bits = std::max(FastIDCTIntegerBits(FastDCTTag<N>()), 174 FastIDCTIntegerBits(FastDCTTag<M>())); 175 // Enough range for [-2, 2] output values. 176 JXL_ASSERT(integer_bits <= 14); 177 float scale = (1 << (14 - integer_bits)); 178 for (size_t i = 0; i < N * M; i++) { 179 dct_i[i] = std::round(dct[i] * scale); 180 } 181 182 for (size_t j = 0; j < 40000000 / (M * N); j++) { 183 memcpy(dct_in, dct_i, sizeof(*dct_i) * N * M); 184 ComputeFastScaledIDCT<M, N>()(dct_in, idct, N, scratch_space_i); 185 } 186 float max_error = 0; 187 for (size_t i = 0; i < M * N; i++) { 188 float err = std::abs(idct[i] * (1.0f / scale) - pixels[i]); 189 if (std::abs(err) > max_error) { 190 max_error = std::abs(err); 191 } 192 } 193 printf("max error: %f mantissa bits: %d\n", max_error, 194 14 - static_cast<int>(integer_bits)); 195 #endif 196 } 197 198 template <size_t N, size_t M> 199 HWY_NOINLINE void TestFloatIDCT() { 200 auto pixels_mem = hwy::AllocateAligned<float>(N * M); 201 float* pixels = pixels_mem.get(); 202 auto dct_mem = hwy::AllocateAligned<float>(N * M); 203 float* dct = dct_mem.get(); 204 auto idct_mem = hwy::AllocateAligned<float>(N * M); 205 float* idct = idct_mem.get(); 206 207 auto dct_in_mem = hwy::AllocateAligned<float>(N * M); 208 float* dct_in = dct_mem.get(); 209 210 auto scratch_space_mem = hwy::AllocateAligned<float>(N * M * 5); 211 float* scratch_space = scratch_space_mem.get(); 212 213 Rng rng(0); 214 for (size_t i = 0; i < N * M; i++) { 215 pixels[i] = rng.UniformF(-1, 1); 216 } 217 ComputeScaledDCT<M, N>()(DCTFrom(pixels, N), dct, scratch_space); 218 219 for (size_t j = 0; j < 40000000 / (M * N); j++) { 220 memcpy(dct_in, dct, sizeof(*dct) * N * M); 221 ComputeScaledIDCT<M, N>()(dct_in, DCTTo(idct, N), scratch_space); 222 } 223 float max_error = 0; 224 for (size_t i = 0; i < M * N; i++) { 225 float err = std::abs(idct[i] - pixels[i]); 226 if (std::abs(err) > max_error) { 227 max_error = std::abs(err); 228 } 229 } 230 printf("max error: %e\n", max_error); 231 } 232 233 } // namespace 234 // NOLINTNEXTLINE(google-readability-namespace-comments) 235 } // namespace HWY_NAMESPACE 236 } // namespace jxl 237 HWY_AFTER_NAMESPACE(); 238 239 #endif // LIB_JXL_FAST_DCT_INL_H_