libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

quant_weights_test.cc (9211B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 #include "lib/jxl/quant_weights.h"
      6 
      7 #include <stdlib.h>
      8 
      9 #include <algorithm>
     10 #include <cmath>
     11 #include <hwy/base.h>  // HWY_ALIGN_MAX
     12 #include <hwy/tests/hwy_gtest.h>
     13 #include <numeric>
     14 
     15 #include "lib/jxl/base/random.h"
     16 #include "lib/jxl/dct_for_test.h"
     17 #include "lib/jxl/dec_transforms_testonly.h"
     18 #include "lib/jxl/enc_modular.h"
     19 #include "lib/jxl/enc_quant_weights.h"
     20 #include "lib/jxl/enc_transforms.h"
     21 #include "lib/jxl/testing.h"
     22 
     23 namespace jxl {
     24 namespace {
     25 
     26 // This should have been static assert; not compiling though with C++<17.
     27 TEST(QuantWeightsTest, Invariant) {
     28   size_t sum = 0;
     29   ASSERT_EQ(DequantMatrices::required_size_x.size(),
     30             DequantMatrices::required_size_y.size());
     31   for (size_t i = 0; i < DequantMatrices::required_size_x.size(); ++i) {
     32     sum += DequantMatrices::required_size_x[i] *
     33            DequantMatrices::required_size_y[i];
     34   }
     35   ASSERT_EQ(DequantMatrices::kSumRequiredXy, sum);
     36 }
     37 
     38 template <typename T>
     39 void CheckSimilar(T a, T b) {
     40   EXPECT_EQ(a, b);
     41 }
     42 // minimum exponent = -15.
     43 template <>
     44 void CheckSimilar(float a, float b) {
     45   float m = std::max(std::abs(a), std::abs(b));
     46   // 10 bits of precision are used in the format. Relative error should be
     47   // below 2^-10.
     48   EXPECT_LE(std::abs(a - b), m / 1024.0f) << "a: " << a << " b: " << b;
     49 }
     50 
     51 TEST(QuantWeightsTest, DC) {
     52   DequantMatrices mat;
     53   float dc_quant[3] = {1e+5, 1e+3, 1e+1};
     54   DequantMatricesSetCustomDC(&mat, dc_quant);
     55   for (size_t c = 0; c < 3; c++) {
     56     CheckSimilar(mat.InvDCQuant(c), dc_quant[c]);
     57   }
     58 }
     59 
     60 void RoundtripMatrices(const std::vector<QuantEncoding>& encodings) {
     61   ASSERT_TRUE(encodings.size() == DequantMatrices::kNum);
     62   DequantMatrices mat;
     63   CodecMetadata metadata;
     64   FrameHeader frame_header(&metadata);
     65   ModularFrameEncoder encoder(frame_header, CompressParams{}, false);
     66   JXL_CHECK(DequantMatricesSetCustom(&mat, encodings, &encoder));
     67   const std::vector<QuantEncoding>& encodings_dec = mat.encodings();
     68   for (size_t i = 0; i < encodings.size(); i++) {
     69     const QuantEncoding& e = encodings[i];
     70     const QuantEncoding& d = encodings_dec[i];
     71     // Check values roundtripped correctly.
     72     EXPECT_EQ(e.mode, d.mode);
     73     EXPECT_EQ(e.predefined, d.predefined);
     74     EXPECT_EQ(e.source, d.source);
     75 
     76     EXPECT_EQ(static_cast<uint64_t>(e.dct_params.num_distance_bands),
     77               static_cast<uint64_t>(d.dct_params.num_distance_bands));
     78     for (size_t c = 0; c < 3; c++) {
     79       for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) {
     80         CheckSimilar(e.dct_params.distance_bands[c][j],
     81                      d.dct_params.distance_bands[c][j]);
     82       }
     83     }
     84 
     85     if (e.mode == QuantEncoding::kQuantModeRAW) {
     86       EXPECT_FALSE(!e.qraw.qtable);
     87       EXPECT_FALSE(!d.qraw.qtable);
     88       EXPECT_EQ(e.qraw.qtable->size(), d.qraw.qtable->size());
     89       for (size_t j = 0; j < e.qraw.qtable->size(); j++) {
     90         EXPECT_EQ((*e.qraw.qtable)[j], (*d.qraw.qtable)[j]);
     91       }
     92       EXPECT_NEAR(e.qraw.qtable_den, d.qraw.qtable_den, 1e-7f);
     93     } else {
     94       // modes different than kQuantModeRAW use one of the other fields used
     95       // here, which all happen to be arrays of floats.
     96       for (size_t c = 0; c < 3; c++) {
     97         for (size_t j = 0; j < 3; j++) {
     98           CheckSimilar(e.idweights[c][j], d.idweights[c][j]);
     99         }
    100         for (size_t j = 0; j < 6; j++) {
    101           CheckSimilar(e.dct2weights[c][j], d.dct2weights[c][j]);
    102         }
    103         for (size_t j = 0; j < 2; j++) {
    104           CheckSimilar(e.dct4multipliers[c][j], d.dct4multipliers[c][j]);
    105         }
    106         CheckSimilar(e.dct4x8multipliers[c], d.dct4x8multipliers[c]);
    107         for (size_t j = 0; j < 9; j++) {
    108           CheckSimilar(e.afv_weights[c][j], d.afv_weights[c][j]);
    109         }
    110         for (size_t j = 0; j < DctQuantWeightParams::kMaxDistanceBands; j++) {
    111           CheckSimilar(e.dct_params_afv_4x4.distance_bands[c][j],
    112                        d.dct_params_afv_4x4.distance_bands[c][j]);
    113         }
    114       }
    115     }
    116   }
    117 }
    118 
    119 TEST(QuantWeightsTest, AllDefault) {
    120   std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
    121                                        QuantEncoding::Library(0));
    122   RoundtripMatrices(encodings);
    123 }
    124 
    125 void TestSingleQuantMatrix(DequantMatrices::QuantTable kind) {
    126   std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
    127                                        QuantEncoding::Library(0));
    128   encodings[kind] = DequantMatrices::Library()[kind];
    129   RoundtripMatrices(encodings);
    130 }
    131 
    132 // Ensure we can reasonably represent default quant tables.
    133 TEST(QuantWeightsTest, DCT) { TestSingleQuantMatrix(DequantMatrices::DCT); }
    134 TEST(QuantWeightsTest, IDENTITY) {
    135   TestSingleQuantMatrix(DequantMatrices::IDENTITY);
    136 }
    137 TEST(QuantWeightsTest, DCT2X2) {
    138   TestSingleQuantMatrix(DequantMatrices::DCT2X2);
    139 }
    140 TEST(QuantWeightsTest, DCT4X4) {
    141   TestSingleQuantMatrix(DequantMatrices::DCT4X4);
    142 }
    143 TEST(QuantWeightsTest, DCT16X16) {
    144   TestSingleQuantMatrix(DequantMatrices::DCT16X16);
    145 }
    146 TEST(QuantWeightsTest, DCT32X32) {
    147   TestSingleQuantMatrix(DequantMatrices::DCT32X32);
    148 }
    149 TEST(QuantWeightsTest, DCT8X16) {
    150   TestSingleQuantMatrix(DequantMatrices::DCT8X16);
    151 }
    152 TEST(QuantWeightsTest, DCT8X32) {
    153   TestSingleQuantMatrix(DequantMatrices::DCT8X32);
    154 }
    155 TEST(QuantWeightsTest, DCT16X32) {
    156   TestSingleQuantMatrix(DequantMatrices::DCT16X32);
    157 }
    158 TEST(QuantWeightsTest, DCT4X8) {
    159   TestSingleQuantMatrix(DequantMatrices::DCT4X8);
    160 }
    161 TEST(QuantWeightsTest, AFV0) { TestSingleQuantMatrix(DequantMatrices::AFV0); }
    162 TEST(QuantWeightsTest, RAW) {
    163   std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
    164                                        QuantEncoding::Library(0));
    165   std::vector<int> matrix(3 * 32 * 32);
    166   Rng rng(0);
    167   for (size_t i = 0; i < matrix.size(); i++) matrix[i] = rng.UniformI(1, 256);
    168   encodings[DequantMatrices::kQuantTable[AcStrategy::DCT32X32]] =
    169       QuantEncoding::RAW(matrix, 2);
    170   RoundtripMatrices(encodings);
    171 }
    172 
    173 class QuantWeightsTargetTest : public hwy::TestWithParamTarget {};
    174 HWY_TARGET_INSTANTIATE_TEST_SUITE_P(QuantWeightsTargetTest);
    175 
    176 TEST_P(QuantWeightsTargetTest, DCTUniform) {
    177   constexpr float kUniformQuant = 4;
    178   float weights[3][2] = {{1.0f / kUniformQuant, 0},
    179                          {1.0f / kUniformQuant, 0},
    180                          {1.0f / kUniformQuant, 0}};
    181   DctQuantWeightParams dct_params(weights);
    182   std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
    183                                        QuantEncoding::DCT(dct_params));
    184   DequantMatrices dequant_matrices;
    185   CodecMetadata metadata;
    186   FrameHeader frame_header(&metadata);
    187   ModularFrameEncoder encoder(frame_header, CompressParams{}, false);
    188   JXL_CHECK(DequantMatricesSetCustom(&dequant_matrices, encodings, &encoder));
    189   JXL_CHECK(dequant_matrices.EnsureComputed(~0u));
    190 
    191   const float dc_quant[3] = {1.0f / kUniformQuant, 1.0f / kUniformQuant,
    192                              1.0f / kUniformQuant};
    193   DequantMatricesSetCustomDC(&dequant_matrices, dc_quant);
    194 
    195   HWY_ALIGN_MAX float scratch_space[16 * 16 * 5];
    196 
    197   // DCT8
    198   {
    199     HWY_ALIGN_MAX float pixels[64];
    200     std::iota(std::begin(pixels), std::end(pixels), 0);
    201     HWY_ALIGN_MAX float coeffs[64];
    202     const AcStrategy::Type dct = AcStrategy::DCT;
    203     TransformFromPixels(dct, pixels, 8, coeffs, scratch_space);
    204     HWY_ALIGN_MAX double slow_coeffs[64];
    205     for (size_t i = 0; i < 64; i++) slow_coeffs[i] = pixels[i];
    206     DCTSlow<8>(slow_coeffs);
    207 
    208     for (size_t i = 0; i < 64; i++) {
    209       // DCTSlow doesn't multiply/divide by 1/N, so we do it manually.
    210       slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant;
    211       coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) *
    212                   dequant_matrices.Matrix(dct, 0)[i];
    213     }
    214     IDCTSlow<8>(slow_coeffs);
    215     TransformToPixels(dct, coeffs, pixels, 8, scratch_space);
    216     for (size_t i = 0; i < 64; i++) {
    217       EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4);
    218     }
    219   }
    220 
    221   // DCT16
    222   {
    223     HWY_ALIGN_MAX float pixels[64 * 4];
    224     std::iota(std::begin(pixels), std::end(pixels), 0);
    225     HWY_ALIGN_MAX float coeffs[64 * 4];
    226     const AcStrategy::Type dct = AcStrategy::DCT16X16;
    227     TransformFromPixels(dct, pixels, 16, coeffs, scratch_space);
    228     HWY_ALIGN_MAX double slow_coeffs[64 * 4];
    229     for (size_t i = 0; i < 64 * 4; i++) slow_coeffs[i] = pixels[i];
    230     DCTSlow<16>(slow_coeffs);
    231 
    232     for (size_t i = 0; i < 64 * 4; i++) {
    233       slow_coeffs[i] = roundf(slow_coeffs[i] / kUniformQuant) * kUniformQuant;
    234       coeffs[i] = roundf(coeffs[i] / dequant_matrices.Matrix(dct, 0)[i]) *
    235                   dequant_matrices.Matrix(dct, 0)[i];
    236     }
    237 
    238     IDCTSlow<16>(slow_coeffs);
    239     TransformToPixels(dct, coeffs, pixels, 16, scratch_space);
    240     for (size_t i = 0; i < 64 * 4; i++) {
    241       EXPECT_NEAR(pixels[i], slow_coeffs[i], 1e-4);
    242     }
    243   }
    244 
    245   // Check that all matrices have the same DC quantization, i.e. that they all
    246   // have the same scaling.
    247   for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) {
    248     EXPECT_NEAR(dequant_matrices.Matrix(i, 0)[0], kUniformQuant, 1e-6);
    249   }
    250 }
    251 
    252 }  // namespace
    253 }  // namespace jxl