libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

metrics.cc (6956B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/extras/metrics.h"
      7 
      8 #include <math.h>
      9 #include <stdlib.h>
     10 
     11 #include <atomic>
     12 
     13 #undef HWY_TARGET_INCLUDE
     14 #define HWY_TARGET_INCLUDE "lib/extras/metrics.cc"
     15 #include <hwy/foreach_target.h>
     16 #include <hwy/highway.h>
     17 
     18 #include "lib/jxl/base/compiler_specific.h"
     19 #include "lib/jxl/base/status.h"
     20 #include "lib/jxl/color_encoding_internal.h"
     21 HWY_BEFORE_NAMESPACE();
     22 namespace jxl {
     23 namespace HWY_NAMESPACE {
     24 
     25 // These templates are not found via ADL.
     26 using hwy::HWY_NAMESPACE::Add;
     27 using hwy::HWY_NAMESPACE::GetLane;
     28 using hwy::HWY_NAMESPACE::Mul;
     29 using hwy::HWY_NAMESPACE::Rebind;
     30 
     31 double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params,
     32                         double p) {
     33   const double onePerPixels = 1.0 / (distmap.ysize() * distmap.xsize());
     34   if (std::abs(p - 3.0) < 1E-6) {
     35     double sum1[3] = {0.0};
     36 
     37 // Prefer double if possible, but otherwise use float rather than scalar.
     38 #if HWY_CAP_FLOAT64
     39     using T = double;
     40     const Rebind<float, HWY_FULL(double)> df;
     41 #else
     42     using T = float;
     43 #endif
     44     const HWY_FULL(T) d;
     45     constexpr size_t N = MaxLanes(d);
     46     // Manually aligned storage to avoid asan crash on clang-7 due to
     47     // unaligned spill.
     48     HWY_ALIGN T sum_totals0[N] = {0};
     49     HWY_ALIGN T sum_totals1[N] = {0};
     50     HWY_ALIGN T sum_totals2[N] = {0};
     51 
     52     for (size_t y = 0; y < distmap.ysize(); ++y) {
     53       const float* JXL_RESTRICT row = distmap.ConstRow(y);
     54 
     55       auto sums0 = Zero(d);
     56       auto sums1 = Zero(d);
     57       auto sums2 = Zero(d);
     58 
     59       size_t x = 0;
     60       for (; x + Lanes(d) <= distmap.xsize(); x += Lanes(d)) {
     61 #if HWY_CAP_FLOAT64
     62         const auto d1 = PromoteTo(d, Load(df, row + x));
     63 #else
     64         const auto d1 = Load(d, row + x);
     65 #endif
     66         const auto d2 = Mul(d1, Mul(d1, d1));
     67         sums0 = Add(sums0, d2);
     68         const auto d3 = Mul(d2, d2);
     69         sums1 = Add(sums1, d3);
     70         const auto d4 = Mul(d3, d3);
     71         sums2 = Add(sums2, d4);
     72       }
     73 
     74       Store(Add(sums0, Load(d, sum_totals0)), d, sum_totals0);
     75       Store(Add(sums1, Load(d, sum_totals1)), d, sum_totals1);
     76       Store(Add(sums2, Load(d, sum_totals2)), d, sum_totals2);
     77 
     78       for (; x < distmap.xsize(); ++x) {
     79         const double d1 = row[x];
     80         double d2 = d1 * d1 * d1;
     81         sum1[0] += d2;
     82         d2 *= d2;
     83         sum1[1] += d2;
     84         d2 *= d2;
     85         sum1[2] += d2;
     86       }
     87     }
     88     double v = 0;
     89     v += pow(
     90         onePerPixels * (sum1[0] + GetLane(SumOfLanes(d, Load(d, sum_totals0)))),
     91         1.0 / (p * 1.0));
     92     v += pow(
     93         onePerPixels * (sum1[1] + GetLane(SumOfLanes(d, Load(d, sum_totals1)))),
     94         1.0 / (p * 2.0));
     95     v += pow(
     96         onePerPixels * (sum1[2] + GetLane(SumOfLanes(d, Load(d, sum_totals2)))),
     97         1.0 / (p * 4.0));
     98     v /= 3.0;
     99     return v;
    100   } else {
    101     static std::atomic<int> once{0};
    102     if (once.fetch_add(1, std::memory_order_relaxed) == 0) {
    103       JXL_WARNING("WARNING: using slow ComputeDistanceP");
    104     }
    105     double sum1[3] = {0.0};
    106     for (size_t y = 0; y < distmap.ysize(); ++y) {
    107       const float* JXL_RESTRICT row = distmap.ConstRow(y);
    108       for (size_t x = 0; x < distmap.xsize(); ++x) {
    109         double d2 = std::pow(row[x], p);
    110         sum1[0] += d2;
    111         d2 *= d2;
    112         sum1[1] += d2;
    113         d2 *= d2;
    114         sum1[2] += d2;
    115       }
    116     }
    117     double v = 0;
    118     for (int i = 0; i < 3; ++i) {
    119       v += pow(onePerPixels * (sum1[i]), 1.0 / (p * (1 << i)));
    120     }
    121     v /= 3.0;
    122     return v;
    123   }
    124 }
    125 
    126 void ComputeSumOfSquares(const ImageBundle& ib1, const ImageBundle& ib2,
    127                          const JxlCmsInterface& cms, double sum_of_squares[3]) {
    128   // Convert to sRGB - closer to perception than linear.
    129   const Image3F* srgb1 = &ib1.color();
    130   Image3F copy1;
    131   if (!ib1.IsSRGB()) {
    132     JXL_CHECK(
    133         ib1.CopyTo(Rect(ib1), ColorEncoding::SRGB(ib1.IsGray()), cms, &copy1));
    134     srgb1 = &copy1;
    135   }
    136   const Image3F* srgb2 = &ib2.color();
    137   Image3F copy2;
    138   if (!ib2.IsSRGB()) {
    139     JXL_CHECK(
    140         ib2.CopyTo(Rect(ib2), ColorEncoding::SRGB(ib2.IsGray()), cms, &copy2));
    141     srgb2 = &copy2;
    142   }
    143 
    144   JXL_CHECK(SameSize(*srgb1, *srgb2));
    145 
    146   // TODO(veluca): SIMD.
    147   float yuvmatrix[3][3] = {{0.299, 0.587, 0.114},
    148                            {-0.14713, -0.28886, 0.436},
    149                            {0.615, -0.51499, -0.10001}};
    150   for (size_t y = 0; y < srgb1->ysize(); ++y) {
    151     const float* JXL_RESTRICT row1[3];
    152     const float* JXL_RESTRICT row2[3];
    153     for (size_t j = 0; j < 3; j++) {
    154       row1[j] = srgb1->ConstPlaneRow(j, y);
    155       row2[j] = srgb2->ConstPlaneRow(j, y);
    156     }
    157     for (size_t x = 0; x < srgb1->xsize(); ++x) {
    158       float cdiff[3] = {};
    159       // YUV conversion is linear, so we can run it on the difference.
    160       for (size_t j = 0; j < 3; j++) {
    161         cdiff[j] = row1[j][x] - row2[j][x];
    162       }
    163       float yuvdiff[3] = {};
    164       for (size_t j = 0; j < 3; j++) {
    165         for (size_t k = 0; k < 3; k++) {
    166           yuvdiff[j] += yuvmatrix[j][k] * cdiff[k];
    167         }
    168       }
    169       for (size_t j = 0; j < 3; j++) {
    170         sum_of_squares[j] += yuvdiff[j] * yuvdiff[j];
    171       }
    172     }
    173   }
    174 }
    175 
    176 // NOLINTNEXTLINE(google-readability-namespace-comments)
    177 }  // namespace HWY_NAMESPACE
    178 }  // namespace jxl
    179 HWY_AFTER_NAMESPACE();
    180 
    181 #if HWY_ONCE
    182 namespace jxl {
    183 HWY_EXPORT(ComputeDistanceP);
    184 double ComputeDistanceP(const ImageF& distmap, const ButteraugliParams& params,
    185                         double p) {
    186   return HWY_DYNAMIC_DISPATCH(ComputeDistanceP)(distmap, params, p);
    187 }
    188 
    189 HWY_EXPORT(ComputeSumOfSquares);
    190 
    191 double ComputeDistance2(const ImageBundle& ib1, const ImageBundle& ib2,
    192                         const JxlCmsInterface& cms) {
    193   double sum_of_squares[3] = {};
    194   HWY_DYNAMIC_DISPATCH(ComputeSumOfSquares)(ib1, ib2, cms, sum_of_squares);
    195   // Weighted PSNR as in JPEG-XL: chroma counts 1/8.
    196   const float weights[3] = {6.0f / 8, 1.0f / 8, 1.0f / 8};
    197   // Avoid squaring the weight - 1/64 is too extreme.
    198   double norm = 0;
    199   for (size_t i = 0; i < 3; i++) {
    200     norm += std::sqrt(sum_of_squares[i]) * weights[i];
    201   }
    202   // This function returns distance *squared*.
    203   return norm * norm;
    204 }
    205 
    206 double ComputePSNR(const ImageBundle& ib1, const ImageBundle& ib2,
    207                    const JxlCmsInterface& cms) {
    208   if (!SameSize(ib1, ib2)) return 0.0;
    209   double sum_of_squares[3] = {};
    210   HWY_DYNAMIC_DISPATCH(ComputeSumOfSquares)(ib1, ib2, cms, sum_of_squares);
    211   constexpr double kChannelWeights[3] = {6.0 / 8, 1.0 / 8, 1.0 / 8};
    212   double avg_psnr = 0;
    213   const size_t input_pixels = ib1.xsize() * ib1.ysize();
    214   for (int i = 0; i < 3; ++i) {
    215     const double rmse = std::sqrt(sum_of_squares[i] / input_pixels);
    216     const double psnr =
    217         sum_of_squares[i] == 0 ? 99.99 : (20 * std::log10(1 / rmse));
    218     avg_psnr += kChannelWeights[i] * psnr;
    219   }
    220   return avg_psnr;
    221 }
    222 
    223 }  // namespace jxl
    224 #endif