libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

fast_math_test.cc (8552B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #undef HWY_TARGET_INCLUDE
      7 #define HWY_TARGET_INCLUDE "lib/jxl/fast_math_test.cc"
      8 #include <jxl/cms.h>
      9 
     10 #include <hwy/foreach_target.h>
     11 
     12 #include "lib/jxl/base/random.h"
     13 #include "lib/jxl/cms/transfer_functions-inl.h"
     14 #include "lib/jxl/dec_xyb-inl.h"
     15 #include "lib/jxl/enc_xyb.h"
     16 #include "lib/jxl/testing.h"
     17 
     18 // Test utils
     19 #include <hwy/highway.h>
     20 #include <hwy/tests/hwy_gtest.h>
     21 HWY_BEFORE_NAMESPACE();
     22 namespace jxl {
     23 namespace HWY_NAMESPACE {
     24 namespace {
     25 
     26 HWY_NOINLINE void TestFastLog2() {
     27   constexpr size_t kNumTrials = 1 << 23;
     28   Rng rng(1);
     29   float max_abs_err = 0;
     30   HWY_FULL(float) d;
     31   for (size_t i = 0; i < kNumTrials; i++) {
     32     const float f = rng.UniformF(1e-7f, 1e3f);
     33     const auto actual_v = FastLog2f(d, Set(d, f));
     34     const float actual = GetLane(actual_v);
     35     const float abs_err = std::abs(std::log2(f) - actual);
     36     EXPECT_LT(abs_err, 3.1E-6) << "f = " << f;
     37     max_abs_err = std::max(max_abs_err, abs_err);
     38   }
     39   printf("max abs err %e\n", static_cast<double>(max_abs_err));
     40 }
     41 
     42 HWY_NOINLINE void TestFastPow2() {
     43   constexpr size_t kNumTrials = 1 << 23;
     44   Rng rng(1);
     45   float max_rel_err = 0;
     46   HWY_FULL(float) d;
     47   for (size_t i = 0; i < kNumTrials; i++) {
     48     const float f = rng.UniformF(-100, 100);
     49     const auto actual_v = FastPow2f(d, Set(d, f));
     50     const float actual = GetLane(actual_v);
     51     const float expected = std::pow(2, f);
     52     const float rel_err = std::abs(expected - actual) / expected;
     53     EXPECT_LT(rel_err, 3.1E-6) << "f = " << f;
     54     max_rel_err = std::max(max_rel_err, rel_err);
     55   }
     56   printf("max rel err %e\n", static_cast<double>(max_rel_err));
     57 }
     58 
     59 HWY_NOINLINE void TestFastPow() {
     60   constexpr size_t kNumTrials = 1 << 23;
     61   Rng rng(1);
     62   float max_rel_err = 0;
     63   HWY_FULL(float) d;
     64   for (size_t i = 0; i < kNumTrials; i++) {
     65     const float b = rng.UniformF(1e-3f, 1e3f);
     66     const float e = rng.UniformF(-10, 10);
     67     const auto actual_v = FastPowf(d, Set(d, b), Set(d, e));
     68     const float actual = GetLane(actual_v);
     69     const float expected = std::pow(b, e);
     70     const float rel_err = std::abs(expected - actual) / expected;
     71     EXPECT_LT(rel_err, 3E-5) << "b = " << b << " e = " << e;
     72     max_rel_err = std::max(max_rel_err, rel_err);
     73   }
     74   printf("max rel err %e\n", static_cast<double>(max_rel_err));
     75 }
     76 
     77 HWY_NOINLINE void TestFastCos() {
     78   constexpr size_t kNumTrials = 1 << 23;
     79   Rng rng(1);
     80   float max_abs_err = 0;
     81   HWY_FULL(float) d;
     82   for (size_t i = 0; i < kNumTrials; i++) {
     83     const float f = rng.UniformF(-1e3f, 1e3f);
     84     const auto actual_v = FastCosf(d, Set(d, f));
     85     const float actual = GetLane(actual_v);
     86     const float abs_err = std::abs(std::cos(f) - actual);
     87     EXPECT_LT(abs_err, 7E-5) << "f = " << f;
     88     max_abs_err = std::max(max_abs_err, abs_err);
     89   }
     90   printf("max abs err %e\n", static_cast<double>(max_abs_err));
     91 }
     92 
     93 HWY_NOINLINE void TestFastErf() {
     94   constexpr size_t kNumTrials = 1 << 23;
     95   Rng rng(1);
     96   float max_abs_err = 0;
     97   HWY_FULL(float) d;
     98   for (size_t i = 0; i < kNumTrials; i++) {
     99     const float f = rng.UniformF(-5.f, 5.f);
    100     const auto actual_v = FastErff(d, Set(d, f));
    101     const float actual = GetLane(actual_v);
    102     const float abs_err = std::abs(std::erf(f) - actual);
    103     EXPECT_LT(abs_err, 7E-4) << "f = " << f;
    104     max_abs_err = std::max(max_abs_err, abs_err);
    105   }
    106   printf("max abs err %e\n", static_cast<double>(max_abs_err));
    107 }
    108 
    109 HWY_NOINLINE void TestCubeRoot() {
    110   const HWY_FULL(float) d;
    111   for (uint64_t x5 = 0; x5 < 2000000; x5++) {
    112     const float x = x5 * 1E-5f;
    113     const float expected = cbrtf(x);
    114     HWY_ALIGN float approx[MaxLanes(d)];
    115     Store(CubeRootAndAdd(Set(d, x), Zero(d)), d, approx);
    116 
    117     // All lanes are same
    118     for (size_t i = 1; i < Lanes(d); ++i) {
    119       EXPECT_NEAR(approx[0], approx[i], 5E-7f);
    120     }
    121     EXPECT_NEAR(approx[0], expected, 8E-7f);
    122   }
    123 }
    124 
    125 HWY_NOINLINE void TestFastSRGB() {
    126   constexpr size_t kNumTrials = 1 << 23;
    127   Rng rng(1);
    128   float max_abs_err = 0;
    129   HWY_FULL(float) d;
    130   for (size_t i = 0; i < kNumTrials; i++) {
    131     const float f = rng.UniformF(0.0f, 1.0f);
    132     const auto actual_v = FastLinearToSRGB(d, Set(d, f));
    133     const float actual = GetLane(actual_v);
    134     const float expected = GetLane(TF_SRGB().EncodedFromDisplay(d, Set(d, f)));
    135     const float abs_err = std::abs(expected - actual);
    136     EXPECT_LT(abs_err, 1.2E-4) << "f = " << f;
    137     max_abs_err = std::max(max_abs_err, abs_err);
    138   }
    139   printf("max abs err %e\n", static_cast<double>(max_abs_err));
    140 }
    141 
    142 HWY_NOINLINE void TestFast709EFD() {
    143   constexpr size_t kNumTrials = 1 << 23;
    144   Rng rng(1);
    145   float max_abs_err = 0;
    146   HWY_FULL(float) d;
    147   for (size_t i = 0; i < kNumTrials; i++) {
    148     const float f = rng.UniformF(0.0f, 1.0f);
    149     const float actual = GetLane(TF_709().EncodedFromDisplay(d, Set(d, f)));
    150     const float expected = TF_709().EncodedFromDisplay(f);
    151     const float abs_err = std::abs(expected - actual);
    152     EXPECT_LT(abs_err, 2e-6) << "f = " << f;
    153     max_abs_err = std::max(max_abs_err, abs_err);
    154   }
    155   printf("max abs err %e\n", static_cast<double>(max_abs_err));
    156 }
    157 
    158 HWY_NOINLINE void TestFastXYB() {
    159   if (!HasFastXYBTosRGB8()) return;
    160   ImageMetadata metadata;
    161   ImageBundle ib(&metadata);
    162   int scaling = 1;
    163   int n = 256 * scaling;
    164   float inv_scaling = 1.0f / scaling;
    165   int kChunk = 32;
    166   // The image is divided in chunks to reduce total memory usage.
    167   for (int cr = 0; cr < n; cr += kChunk) {
    168     for (int cg = 0; cg < n; cg += kChunk) {
    169       for (int cb = 0; cb < n; cb += kChunk) {
    170         JXL_ASSIGN_OR_DIE(Image3F chunk,
    171                           Image3F::Create(kChunk * kChunk, kChunk));
    172         for (int ir = 0; ir < kChunk; ir++) {
    173           for (int ig = 0; ig < kChunk; ig++) {
    174             for (int ib = 0; ib < kChunk; ib++) {
    175               float r = (cr + ir) * inv_scaling;
    176               float g = (cg + ig) * inv_scaling;
    177               float b = (cb + ib) * inv_scaling;
    178               chunk.PlaneRow(0, ir)[ig * kChunk + ib] = r * (1.0f / 255);
    179               chunk.PlaneRow(1, ir)[ig * kChunk + ib] = g * (1.0f / 255);
    180               chunk.PlaneRow(2, ir)[ig * kChunk + ib] = b * (1.0f / 255);
    181             }
    182           }
    183         }
    184         ib.SetFromImage(std::move(chunk), ColorEncoding::SRGB());
    185         JXL_ASSIGN_OR_DIE(Image3F xyb,
    186                           Image3F::Create(kChunk * kChunk, kChunk));
    187         std::vector<uint8_t> roundtrip(kChunk * kChunk * kChunk * 3);
    188         JXL_CHECK(ToXYB(ib, nullptr, &xyb, *JxlGetDefaultCms()));
    189         for (int y = 0; y < kChunk; y++) {
    190           const float* xyba[4] = {xyb.PlaneRow(0, y), xyb.PlaneRow(1, y),
    191                                   xyb.PlaneRow(2, y), nullptr};
    192           jxl::HWY_NAMESPACE::FastXYBTosRGB8(
    193               xyba, roundtrip.data() + 3 * xyb.xsize() * y, false, xyb.xsize());
    194         }
    195         for (int ir = 0; ir < kChunk; ir++) {
    196           for (int ig = 0; ig < kChunk; ig++) {
    197             for (int ib = 0; ib < kChunk; ib++) {
    198               float r = (cr + ir) * inv_scaling;
    199               float g = (cg + ig) * inv_scaling;
    200               float b = (cb + ib) * inv_scaling;
    201               size_t idx = ir * kChunk * kChunk + ig * kChunk + ib;
    202               int rr = roundtrip[3 * idx];
    203               int rg = roundtrip[3 * idx + 1];
    204               int rb = roundtrip[3 * idx + 2];
    205               EXPECT_LT(abs(r - rr), 2) << "expected " << r << " got " << rr;
    206               EXPECT_LT(abs(g - rg), 2) << "expected " << g << " got " << rg;
    207               EXPECT_LT(abs(b - rb), 2) << "expected " << b << " got " << rb;
    208             }
    209           }
    210         }
    211       }
    212     }
    213   }
    214 }
    215 
    216 }  // namespace
    217 // NOLINTNEXTLINE(google-readability-namespace-comments)
    218 }  // namespace HWY_NAMESPACE
    219 }  // namespace jxl
    220 HWY_AFTER_NAMESPACE();
    221 
    222 #if HWY_ONCE
    223 namespace jxl {
    224 
    225 class FastMathTargetTest : public hwy::TestWithParamTarget {};
    226 HWY_TARGET_INSTANTIATE_TEST_SUITE_P(FastMathTargetTest);
    227 
    228 HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastLog2);
    229 HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPow2);
    230 HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastPow);
    231 HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastCos);
    232 HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastErf);
    233 HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestCubeRoot);
    234 HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastSRGB);
    235 HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFast709EFD);
    236 HWY_EXPORT_AND_TEST_P(FastMathTargetTest, TestFastXYB);
    237 
    238 }  // namespace jxl
    239 #endif  // HWY_ONCE