quant_weights_test.cc (9211B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 #include "lib/jxl/quant_weights.h" 6 7 #include <stdlib.h> 8 9 #include <algorithm> 10 #include <cmath> 11 #include <hwy/base.h> // HWY_ALIGN_MAX 12 #include <hwy/tests/hwy_gtest.h> 13 #include <numeric> 14 15 #include "lib/jxl/base/random.h" 16 #include "lib/jxl/dct_for_test.h" 17 #include "lib/jxl/dec_transforms_testonly.h" 18 #include "lib/jxl/enc_modular.h" 19 #include "lib/jxl/enc_quant_weights.h" 20 #include "lib/jxl/enc_transforms.h" 21 #include "lib/jxl/testing.h" 22 23 namespace jxl { 24 namespace { 25 26 // This should have been static assert; not compiling though with C++<17. 27 TEST(QuantWeightsTest, Invariant) { 28 size_t sum = 0; 29 ASSERT_EQ(DequantMatrices::required_size_x.size(), 30 DequantMatrices::required_size_y.size()); 31 for (size_t i = 0; i < DequantMatrices::required_size_x.size(); ++i) { 32 sum += DequantMatrices::required_size_x[i] * 33 DequantMatrices::required_size_y[i]; 34 } 35 ASSERT_EQ(DequantMatrices::kSumRequiredXy, sum); 36 } 37 38 template <typename T> 39 void CheckSimilar(T a, T b) { 40 EXPECT_EQ(a, b); 41 } 42 // minimum exponent = -15. 43 template <> 44 void CheckSimilar(float a, float b) { 45 float m = std::max(std::abs(a), std::abs(b)); 46 // 10 bits of precision are used in the format. Relative error should be 47 // below 2^-10. 48 EXPECT_LE(std::abs(a - b), m / 1024.0f) << "a: " << a << " b: " << b; 49 } 50 51 TEST(QuantWeightsTest, DC) { 52 DequantMatrices mat; 53 float dc_quant[3] = {1e+5, 1e+3, 1e+1}; 54 DequantMatricesSetCustomDC(&mat, dc_quant); 55 for (size_t c = 0; c < 3; c++) { 56 CheckSimilar(mat.InvDCQuant(c), dc_quant[c]); 57 } 58 } 59 60 void RoundtripMatrices(const std::vector<QuantEncoding>& encodings) { 61 ASSERT_TRUE(encodings.size() == DequantMatrices::kNum); 62 DequantMatrices mat; 63 CodecMetadata metadata; 64 FrameHeader frame_header(&metadata); 65 ModularFrameEncoder encoder(frame_header, CompressParams{}, false); 66 JXL_CHECK(DequantMatricesSetCustom(&mat, encodings, &encoder)); 67 const std::vector<QuantEncoding>& encodings_dec = mat.encodings(); 68 for (size_t i = 0; i < encodings.size(); i++) { 69 const QuantEncoding& e = encodings[i]; 70 const QuantEncoding& d = encodings_dec[i]; 71 // Check values roundtripped correctly. 72 EXPECT_EQ(e.mode, d.mode); 73 EXPECT_EQ(e.predefined, d.predefined); 74 EXPECT_EQ(e.source, d.source); 75 76 EXPECT_EQ(static_cast<uint64_t>(e.dct_params.num_distance_bands), 77 static_cast<uint64_t>(d.dct_params.num_distance_bands)); 78 for (size_t c = 0; c < 3; c++) { 79 for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) { 80 CheckSimilar(e.dct_params.distance_bands[c][j], 81 d.dct_params.distance_bands[c][j]); 82 } 83 } 84 85 if (e.mode == QuantEncoding::kQuantModeRAW) { 86 EXPECT_FALSE(!e.qraw.qtable); 87 EXPECT_FALSE(!d.qraw.qtable); 88 EXPECT_EQ(e.qraw.qtable->size(), d.qraw.qtable->size()); 89 for (size_t j = 0; j < e.qraw.qtable->size(); j++) { 90 EXPECT_EQ((*e.qraw.qtable)[j], (*d.qraw.qtable)[j]); 91 } 92 EXPECT_NEAR(e.qraw.qtable_den, d.qraw.qtable_den, 1e-7f); 93 } else { 94 // modes different than kQuantModeRAW use one of the other fields used 95 // here, which all happen to be arrays of floats. 96 for (size_t c = 0; c < 3; c++) { 97 for (size_t j = 0; j < 3; j++) { 98 CheckSimilar(e.idweights[c][j], d.idweights[c][j]); 99 } 100 for (size_t j = 0; j < 6; j++) { 101 CheckSimilar(e.dct2weights[c][j], d.dct2weights[c][j]); 102 } 103 for (size_t j = 0; j < 2; j++) { 104 CheckSimilar(e.dct4multipliers[c][j], d.dct4multipliers[c][j]); 105 } 106 CheckSimilar(e.dct4x8multipliers[c], d.dct4x8multipliers[c]); 107 for (size_t j = 0; j < 9; j++) { 108 CheckSimilar(e.afv_weights[c][j], d.afv_weights[c][j]); 109 } 110 for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) { 111 CheckSimilar(e.dct_params_afv_4x4.distance_bands[c][j], 112 d.dct_params_afv_4x4.distance_bands[c][j]); 113 } 114 } 115 } 116 } 117 } 118 119 TEST(QuantWeightsTest, AllDefault) { 120 std::vector<QuantEncoding> encodings(DequantMatrices::kNum, 121 QuantEncoding::Library(0)); 122 RoundtripMatrices(encodings); 123 } 124 125 void TestSingleQuantMatrix(DequantMatrices::QuantTable kind) { 126 std::vector<QuantEncoding> encodings(DequantMatrices::kNum, 127 QuantEncoding::Library(0)); 128 encodings[kind] = DequantMatrices::Library()[kind]; 129 RoundtripMatrices(encodings); 130 } 131 132 // Ensure we can reasonably represent default quant tables. 133 TEST(QuantWeightsTest, DCT) { TestSingleQuantMatrix(DequantMatrices::DCT); } 134 TEST(QuantWeightsTest, IDENTITY) { 135 TestSingleQuantMatrix(DequantMatrices::IDENTITY); 136 } 137 TEST(QuantWeightsTest, DCT2X2) { 138 TestSingleQuantMatrix(DequantMatrices::DCT2X2); 139 } 140 TEST(QuantWeightsTest, DCT4X4) { 141 TestSingleQuantMatrix(DequantMatrices::DCT4X4); 142 } 143 TEST(QuantWeightsTest, DCT16X16) { 144 TestSingleQuantMatrix(DequantMatrices::DCT16X16); 145 } 146 TEST(QuantWeightsTest, DCT32X32) { 147 TestSingleQuantMatrix(DequantMatrices::DCT32X32); 148 } 149 TEST(QuantWeightsTest, DCT8X16) { 150 TestSingleQuantMatrix(DequantMatrices::DCT8X16); 151 } 152 TEST(QuantWeightsTest, DCT8X32) { 153 TestSingleQuantMatrix(DequantMatrices::DCT8X32); 154 } 155 TEST(QuantWeightsTest, DCT16X32) { 156 TestSingleQuantMatrix(DequantMatrices::DCT16X32); 157 } 158 TEST(QuantWeightsTest, DCT4X8) { 159 TestSingleQuantMatrix(DequantMatrices::DCT4X8); 160 } 161 TEST(QuantWeightsTest, AFV0) { TestSingleQuantMatrix(DequantMatrices::AFV0); } 162 TEST(QuantWeightsTest, RAW) { 163 std::vector<QuantEncoding> encodings(DequantMatrices::kNum, 164 QuantEncoding::Library(0)); 165 std::vector<int> matrix(3 * 32 * 32); 166 Rng rng(0); 167 for (size_t i = 0; i < matrix.size(); i++) matrix[i] = rng.UniformI(1, 256); 168 encodings[DequantMatrices::kQuantTable[AcStrategy::DCT32X32]] = 169 QuantEncoding::RAW(matrix, 2); 170 RoundtripMatrices(encodings); 171 } 172 173 class QuantWeightsTargetTest : public hwy::TestWithParamTarget {}; 174 HWY_TARGET_INSTANTIATE_TEST_SUITE_P(QuantWeightsTargetTest); 175 176 TEST_P(QuantWeightsTargetTest, DCTUniform) { 177 constexpr float kUniformQuant = 4; 178 float weights[3][2] = {{1.0f / kUniformQuant, 0}, 179 {1.0f / kUniformQuant, 0}, 180 {1.0f / kUniformQuant, 0}}; 181 DctQuantWeightParams dct_params(weights); 182 std::vector<QuantEncoding> encodings(DequantMatrices::kNum, 183 QuantEncoding::DCT(dct_params)); 184 DequantMatrices dequant_matrices; 185 CodecMetadata metadata; 186 FrameHeader frame_header(&metadata); 187 ModularFrameEncoder encoder(frame_header, CompressParams{}, false); 188 JXL_CHECK(DequantMatricesSetCustom(&dequant_matrices, encodings, &encoder)); 189 JXL_CHECK(dequant_matrices.EnsureComputed(~0u)); 190 191 const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant, 192 1.0f / kUniformQuant}; 193 DequantMatricesSetCustomDC(&dequant_matrices, dc_quant); 194 195 HWY_ALIGN_MAX float scratch_space[16 * 16 * 5]; 196 197 // DCT8 198 { 199 HWY_ALIGN_MAX float pixels[64]; 200 std::iota(std::begin(pixels), std::end(pixels), 0); 201 HWY_ALIGN_MAX float coeffs[64]; 202 const AcStrategy::Type dct = AcStrategy::DCT; 203 TransformFromPixels(dct, pixels, 8, coeffs, scratch_space); 204 HWY_ALIGN_MAX double slow_coeffs[64]; 205 for (size_t i = 0; i < 64; i++) slow_coeffs[i] = pixels[i]; 206 DCTSlow<8>(slow_coeffs); 207 208 for (size_t i = 0; i < 64; i++) { 209 // DCTSlow doesn't multiply/divide by 1/N, so we do it manually. 210 slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant; 211 coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) * 212 dequant_matrices.Matrix(dct, 0)[i]; 213 } 214 IDCTSlow<8>(slow_coeffs); 215 TransformToPixels(dct, coeffs, pixels, 8, scratch_space); 216 for (size_t i = 0; i < 64; i++) { 217 EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4); 218 } 219 } 220 221 // DCT16 222 { 223 HWY_ALIGN_MAX float pixels[64 * 4]; 224 std::iota(std::begin(pixels), std::end(pixels), 0); 225 HWY_ALIGN_MAX float coeffs[64 * 4]; 226 const AcStrategy::Type dct = AcStrategy::DCT16X16; 227 TransformFromPixels(dct, pixels, 16, coeffs, scratch_space); 228 HWY_ALIGN_MAX double slow_coeffs[64 * 4]; 229 for (size_t i = 0; i < 64 * 4; i++) slow_coeffs[i] = pixels[i]; 230 DCTSlow<16>(slow_coeffs); 231 232 for (size_t i = 0; i < 64 * 4; i++) { 233 slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant; 234 coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) * 235 dequant_matrices.Matrix(dct, 0)[i]; 236 } 237 238 IDCTSlow<16>(slow_coeffs); 239 TransformToPixels(dct, coeffs, pixels, 16, scratch_space); 240 for (size_t i = 0; i < 64 * 4; i++) { 241 EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4); 242 } 243 } 244 245 // Check that all matrices have the same DC quantization, i.e. that they all 246 // have the same scaling. 247 for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { 248 EXPECT_NEAR(dequant_matrices.Matrix(i, 0)[0], kUniformQuant, 1e-6); 249 } 250 } 251 252 } // namespace 253 } // namespace jxl