libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

convolve_test.cc (9511B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/convolve.h"
      7 
      8 #include <jxl/types.h>
      9 #include <time.h>
     10 
     11 #undef HWY_TARGET_INCLUDE
     12 #define HWY_TARGET_INCLUDE "lib/jxl/convolve_test.cc"
     13 #include <hwy/foreach_target.h>
     14 #include <hwy/highway.h>
     15 #include <hwy/nanobenchmark.h>
     16 #include <hwy/tests/hwy_gtest.h>
     17 #include <vector>
     18 
     19 #include "lib/jxl/base/compiler_specific.h"
     20 #include "lib/jxl/base/data_parallel.h"
     21 #include "lib/jxl/base/printf_macros.h"
     22 #include "lib/jxl/base/random.h"
     23 #include "lib/jxl/image_ops.h"
     24 #include "lib/jxl/image_test_utils.h"
     25 #include "lib/jxl/test_utils.h"
     26 #include "lib/jxl/testing.h"
     27 
     28 #ifndef JXL_DEBUG_CONVOLVE
     29 #define JXL_DEBUG_CONVOLVE 0
     30 #endif
     31 
     32 #include "lib/jxl/convolve-inl.h"
     33 
     34 HWY_BEFORE_NAMESPACE();
     35 namespace jxl {
     36 namespace HWY_NAMESPACE {
     37 
     38 void TestNeighbors() {
     39   const Neighbors::D d;
     40   const Neighbors::V v = Iota(d, 0);
     41   constexpr size_t kMaxVectorSize = 64;
     42   constexpr size_t M = kMaxVectorSize / sizeof(float);
     43   HWY_ALIGN float actual[M] = {0};
     44 
     45   HWY_ALIGN float first_l1[M] = {0, 0, 1, 2,  3,  4,  5,  6,
     46                                  7, 8, 9, 10, 11, 12, 13, 14};
     47   Store(Neighbors::FirstL1(v), d, actual);
     48   const size_t N = Lanes(d);
     49   ASSERT_LE(N, M);
     50   EXPECT_EQ(std::vector<float>(first_l1, first_l1 + N),
     51             std::vector<float>(actual, actual + N));
     52 
     53 #if HWY_TARGET != HWY_SCALAR
     54   HWY_ALIGN float first_l2[M] = {1, 0, 0, 1, 2,  3,  4,  5,
     55                                  6, 7, 8, 9, 10, 11, 12, 13};
     56   Store(Neighbors::FirstL2(v), d, actual);
     57   EXPECT_EQ(std::vector<float>(first_l2, first_l2 + N),
     58             std::vector<float>(actual, actual + N));
     59 
     60   HWY_ALIGN float first_l3[] = {2, 1, 0, 0, 1, 2,  3,  4,
     61                                 5, 6, 7, 8, 9, 10, 11, 12};
     62   Store(Neighbors::FirstL3(v), d, actual);
     63   EXPECT_EQ(std::vector<float>(first_l3, first_l3 + N),
     64             std::vector<float>(actual, actual + N));
     65 #endif  // HWY_TARGET != HWY_SCALAR
     66 }
     67 
     68 void VerifySymmetric3(const size_t xsize, const size_t ysize, ThreadPool* pool,
     69                       Rng* rng) {
     70   const Rect rect(0, 0, xsize, ysize);
     71 
     72   JXL_ASSIGN_OR_DIE(ImageF in, ImageF::Create(xsize, ysize));
     73   GenerateImage(*rng, &in, 0.0f, 1.0f);
     74 
     75   JXL_ASSIGN_OR_DIE(ImageF out_expected, ImageF::Create(xsize, ysize));
     76   JXL_ASSIGN_OR_DIE(ImageF out_actual, ImageF::Create(xsize, ysize));
     77 
     78   const WeightsSymmetric3& weights = WeightsSymmetric3Lowpass();
     79   Symmetric3(in, rect, weights, pool, &out_expected);
     80   SlowSymmetric3(in, rect, weights, pool, &out_actual);
     81 
     82   JXL_ASSERT_OK(VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f, _));
     83 }
     84 
     85 std::vector<Rect> GenerateTestRectangles(size_t xsize, size_t ysize) {
     86   std::vector<Rect> out;
     87   for (size_t tl : {0, 1, 13}) {
     88     for (size_t br : {0, 1, 13}) {
     89       if (xsize > tl + br && ysize > tl + br) {
     90         out.emplace_back(tl, tl, xsize - tl - br, ysize - tl - br);
     91       }
     92     }
     93   }
     94   return out;
     95 }
     96 
     97 // Ensures Symmetric and Separable give the same result.
     98 void VerifySymmetric5(const size_t xsize, const size_t ysize, ThreadPool* pool,
     99                       Rng* rng) {
    100   JXL_ASSIGN_OR_DIE(ImageF in, ImageF::Create(xsize, ysize));
    101   GenerateImage(*rng, &in, 0.0f, 1.0f);
    102 
    103   for (const Rect& in_rect : GenerateTestRectangles(xsize, ysize)) {
    104     JXL_DEBUG(JXL_DEBUG_CONVOLVE,
    105               "in_rect: %" PRIuS "x%" PRIuS "+%" PRIuS ",%" PRIuS "",
    106               in_rect.xsize(), in_rect.ysize(), in_rect.x0(), in_rect.y0());
    107     {
    108       Rect out_rect = in_rect;
    109       JXL_ASSIGN_OR_DIE(ImageF out_expected, ImageF::Create(xsize, ysize));
    110       JXL_ASSIGN_OR_DIE(ImageF out_actual, ImageF::Create(xsize, ysize));
    111       FillImage(-1.0f, &out_expected);
    112       FillImage(-1.0f, &out_actual);
    113 
    114       SlowSeparable5(in, in_rect, WeightsSeparable5Lowpass(), pool,
    115                      &out_expected, out_rect);
    116       Symmetric5(in, in_rect, WeightsSymmetric5Lowpass(), pool, &out_actual,
    117                  out_rect);
    118 
    119       JXL_ASSERT_OK(
    120           VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f, _));
    121     }
    122     {
    123       Rect out_rect(0, 0, in_rect.xsize(), in_rect.ysize());
    124       JXL_ASSIGN_OR_DIE(ImageF out_expected,
    125                         ImageF::Create(out_rect.xsize(), out_rect.ysize()));
    126       JXL_ASSIGN_OR_DIE(ImageF out_actual,
    127                         ImageF::Create(out_rect.xsize(), out_rect.ysize()));
    128 
    129       SlowSeparable5(in, in_rect, WeightsSeparable5Lowpass(), pool,
    130                      &out_expected, out_rect);
    131       Symmetric5(in, in_rect, WeightsSymmetric5Lowpass(), pool, &out_actual,
    132                  out_rect);
    133 
    134       JXL_ASSERT_OK(
    135           VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f, _));
    136     }
    137   }
    138 }
    139 
    140 void VerifySeparable5(const size_t xsize, const size_t ysize, ThreadPool* pool,
    141                       Rng* rng) {
    142   const Rect rect(0, 0, xsize, ysize);
    143 
    144   JXL_ASSIGN_OR_DIE(ImageF in, ImageF::Create(xsize, ysize));
    145   GenerateImage(*rng, &in, 0.0f, 1.0f);
    146 
    147   JXL_ASSIGN_OR_DIE(ImageF out_expected, ImageF::Create(xsize, ysize));
    148   JXL_ASSIGN_OR_DIE(ImageF out_actual, ImageF::Create(xsize, ysize));
    149 
    150   const WeightsSeparable5& weights = WeightsSeparable5Lowpass();
    151   SlowSeparable5(in, rect, weights, pool, &out_expected, rect);
    152   Separable5(in, rect, weights, pool, &out_actual);
    153 
    154   JXL_ASSERT_OK(VerifyRelativeError(out_expected, out_actual, 1E-5f, 1E-5f, _));
    155 }
    156 
    157 // For all xsize/ysize and kernels:
    158 void TestConvolve() {
    159   TestNeighbors();
    160 
    161   test::ThreadPoolForTests pool(4);
    162   EXPECT_EQ(true,
    163             RunOnPool(
    164                 &pool, kConvolveMaxRadius, 40, ThreadPool::NoInit,
    165                 [](const uint32_t task, size_t /*thread*/) {
    166                   const size_t xsize = task;
    167                   Rng rng(129 + 13 * xsize);
    168 
    169                   ThreadPool* null_pool = nullptr;
    170                   test::ThreadPoolForTests pool3(3);
    171                   for (size_t ysize = kConvolveMaxRadius; ysize < 16; ++ysize) {
    172                     JXL_DEBUG(JXL_DEBUG_CONVOLVE,
    173                               "%" PRIuS " x %" PRIuS " (target %" PRIx64
    174                               ")===============================",
    175                               xsize, ysize, static_cast<int64_t>(HWY_TARGET));
    176 
    177                     JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sym3------------------");
    178                     VerifySymmetric3(xsize, ysize, null_pool, &rng);
    179                     VerifySymmetric3(xsize, ysize, &pool3, &rng);
    180 
    181                     JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sym5------------------");
    182                     VerifySymmetric5(xsize, ysize, null_pool, &rng);
    183                     VerifySymmetric5(xsize, ysize, &pool3, &rng);
    184 
    185                     JXL_DEBUG(JXL_DEBUG_CONVOLVE, "Sep5------------------");
    186                     VerifySeparable5(xsize, ysize, null_pool, &rng);
    187                     VerifySeparable5(xsize, ysize, &pool3, &rng);
    188                   }
    189                 },
    190                 "TestConvolve"));
    191 }
    192 
    193 // Measures durations, verifies results, prints timings. `unpredictable1`
    194 // must have value 1 (unknown to the compiler to prevent elision).
    195 template <class Conv>
    196 void BenchmarkConv(const char* caption, const Conv& conv,
    197                    const hwy::FuncInput unpredictable1) {
    198   const size_t kNumInputs = 1;
    199   const hwy::FuncInput inputs[kNumInputs] = {unpredictable1};
    200   hwy::Result results[kNumInputs];
    201 
    202   const size_t kDim = 160;  // in+out fit in L2
    203   JXL_ASSIGN_OR_DIE(ImageF in, ImageF::Create(kDim, kDim));
    204   ZeroFillImage(&in);
    205   in.Row(kDim / 2)[kDim / 2] = unpredictable1;
    206   JXL_ASSIGN_OR_DIE(ImageF out, ImageF::Create(kDim, kDim));
    207 
    208   hwy::Params p;
    209   p.verbose = false;
    210   p.max_evals = 7;
    211   p.target_rel_mad = 0.002;
    212   const size_t num_results = MeasureClosure(
    213       [&in, &conv, &out](const hwy::FuncInput input) {
    214         conv(in, &out);
    215         return out.Row(input)[0];
    216       },
    217       inputs, kNumInputs, results, p);
    218   if (num_results != kNumInputs) {
    219     fprintf(stderr, "MeasureClosure failed.\n");
    220   }
    221   for (size_t i = 0; i < num_results; ++i) {
    222     const double seconds = static_cast<double>(results[i].ticks) /
    223                            hwy::platform::InvariantTicksPerSecond();
    224     printf("%12s: %7.2f MP/s (MAD=%4.2f%%)\n", caption,
    225            kDim * kDim * 1E-6 / seconds,
    226            static_cast<double>(results[i].variability) * 100.0);
    227   }
    228 }
    229 
    230 struct ConvSymmetric3 {
    231   void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) const {
    232     ThreadPool* null_pool = nullptr;
    233     Symmetric3(in, Rect(in), WeightsSymmetric3Lowpass(), null_pool, out);
    234   }
    235 };
    236 
    237 struct ConvSeparable5 {
    238   void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) const {
    239     ThreadPool* null_pool = nullptr;
    240     Separable5(in, Rect(in), WeightsSeparable5Lowpass(), null_pool, out);
    241   }
    242 };
    243 
    244 void BenchmarkAll() {
    245 #if JXL_FALSE  // disabled to avoid test timeouts, run manually on demand
    246   const hwy::FuncInput unpredictable1 = time(nullptr) != 1234;
    247   BenchmarkConv("Symmetric3", ConvSymmetric3(), unpredictable1);
    248   BenchmarkConv("Separable5", ConvSeparable5(), unpredictable1);
    249 #endif
    250 }
    251 
    252 // NOLINTNEXTLINE(google-readability-namespace-comments)
    253 }  // namespace HWY_NAMESPACE
    254 }  // namespace jxl
    255 HWY_AFTER_NAMESPACE();
    256 
    257 #if HWY_ONCE
    258 namespace jxl {
    259 
// Test fixture for the convolution tests; it adds nothing to
// hwy::TestWithParamTarget, from which it derives.
class ConvolveTest : public hwy::TestWithParamTarget {};
// Instantiates the parameterized suite (presumably once per Highway target;
// see hwy/tests/hwy_gtest.h to confirm).
HWY_TARGET_INSTANTIATE_TEST_SUITE_P(ConvolveTest);

// Exports the per-target TestConvolve defined in HWY_NAMESPACE above and
// registers it as a test of ConvolveTest.
HWY_EXPORT_AND_TEST_P(ConvolveTest, TestConvolve);

// Same for BenchmarkAll, whose body is disabled by `#if JXL_FALSE`.
HWY_EXPORT_AND_TEST_P(ConvolveTest, BenchmarkAll);
    266 
    267 }  // namespace jxl
    268 #endif