libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

ssimulacra.cc (11852B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 // Re-implementation of //tools/ssimulacra.tct using jxl's
      7 // ImageF library instead of opencv.
      8 
#include "tools/ssimulacra.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

#include "lib/jxl/base/status.h"
#include "lib/jxl/image.h"
#include "lib/jxl/image_ops.h"
#include "tools/gauss_blur.h"
     17 
     18 namespace ssimulacra {
     19 namespace {
     20 
     21 using jxl::Image3F;
     22 using jxl::ImageF;
     23 using jxl::StatusOr;
     24 
     25 const float kC1 = 0.0001f;
     26 const float kC2 = 0.0004f;
     27 const int kNumScales = 6;
     28 // Premultiplied by chroma weight 0.2
     29 const double kScaleWeights[kNumScales][3] = {
     30     {0.04480, 0.00300, 0.00300}, {0.28560, 0.00896, 0.00896},
     31     {0.30010, 0.05712, 0.05712}, {0.23630, 0.06002, 0.06002},
     32     {0.13330, 0.06726, 0.06726}, {0.10000, 0.05000, 0.05000},
     33 };
     34 // Premultiplied by min weights 0.1, 0.005, 0.005
     35 const double kMinScaleWeights[kNumScales][3] = {
     36     {0.02000, 0.00005, 0.00005}, {0.03000, 0.00025, 0.00025},
     37     {0.02500, 0.00100, 0.00100}, {0.02000, 0.00150, 0.00150},
     38     {0.01200, 0.00175, 0.00175}, {0.00500, 0.00175, 0.00175},
     39 };
     40 const double kEdgeWeight[3] = {1.5, 0.1, 0.1};
     41 const double kGridWeight[3] = {1.0, 0.1, 0.1};
     42 
     43 inline void Rgb2Lab(float r, float g, float b, float* L, float* A, float* B) {
     44   const float epsilon = 0.00885645167903563081f;
     45   const float s = 0.13793103448275862068f;
     46   const float k = 7.78703703703703703703f;
     47   float fx = (r * 0.43393624408206207259f + g * 0.37619779063650710152f +
     48               b * 0.18983429773803261441f);
     49   float fy = (r * 0.2126729f + g * 0.7151522f + b * 0.0721750f);
     50   float fz = (r * 0.01775381083562901744f + g * 0.10945087235996326905f +
     51               b * 0.87263921028466483011f);
     52   const float gamma = 1.0f / 3.0f;
     53   float X = (fx > epsilon) ? powf(fx, gamma) - s : k * fx;
     54   float Y = (fy > epsilon) ? powf(fy, gamma) - s : k * fy;
     55   float Z = (fz > epsilon) ? powf(fz, gamma) - s : k * fz;
     56   *L = Y * 1.16f;
     57   *A = (0.39181818181818181818f + 2.27272727272727272727f * (X - Y));
     58   *B = (0.49045454545454545454f + 0.90909090909090909090f * (Y - Z));
     59 }
     60 
     61 StatusOr<Image3F> Rgb2Lab(const Image3F& in) {
     62   JXL_ASSIGN_OR_RETURN(Image3F out, Image3F::Create(in.xsize(), in.ysize()));
     63   for (size_t y = 0; y < in.ysize(); ++y) {
     64     const float* JXL_RESTRICT row_in0 = in.PlaneRow(0, y);
     65     const float* JXL_RESTRICT row_in1 = in.PlaneRow(1, y);
     66     const float* JXL_RESTRICT row_in2 = in.PlaneRow(2, y);
     67     float* JXL_RESTRICT row_out0 = out.PlaneRow(0, y);
     68     float* JXL_RESTRICT row_out1 = out.PlaneRow(1, y);
     69     float* JXL_RESTRICT row_out2 = out.PlaneRow(2, y);
     70 
     71     for (size_t x = 0; x < in.xsize(); ++x) {
     72       Rgb2Lab(row_in0[x], row_in1[x], row_in2[x], &row_out0[x], &row_out1[x],
     73               &row_out2[x]);
     74     }
     75   }
     76   return out;
     77 }
     78 
     79 StatusOr<Image3F> Downsample(const Image3F& in, size_t fx, size_t fy) {
     80   const size_t out_xsize = (in.xsize() + fx - 1) / fx;
     81   const size_t out_ysize = (in.ysize() + fy - 1) / fy;
     82   JXL_ASSIGN_OR_RETURN(Image3F out, Image3F::Create(out_xsize, out_ysize));
     83   const float normalize = 1.0f / (fx * fy);
     84   for (size_t c = 0; c < 3; ++c) {
     85     for (size_t oy = 0; oy < out_ysize; ++oy) {
     86       float* JXL_RESTRICT row_out = out.PlaneRow(c, oy);
     87       for (size_t ox = 0; ox < out_xsize; ++ox) {
     88         float sum = 0.0f;
     89         for (size_t iy = 0; iy < fy; ++iy) {
     90           for (size_t ix = 0; ix < fx; ++ix) {
     91             const size_t x = std::min(ox * fx + ix, in.xsize() - 1);
     92             const size_t y = std::min(oy * fy + iy, in.ysize() - 1);
     93             sum += in.PlaneRow(c, y)[x];
     94           }
     95         }
     96         row_out[ox] = sum * normalize;
     97       }
     98     }
     99   }
    100   return out;
    101 }
    102 
    103 void Multiply(const Image3F& a, const Image3F& b, Image3F* mul) {
    104   for (size_t c = 0; c < 3; ++c) {
    105     for (size_t y = 0; y < a.ysize(); ++y) {
    106       const float* JXL_RESTRICT in1 = a.PlaneRow(c, y);
    107       const float* JXL_RESTRICT in2 = b.PlaneRow(c, y);
    108       float* JXL_RESTRICT out = mul->PlaneRow(c, y);
    109       for (size_t x = 0; x < a.xsize(); ++x) {
    110         out[x] = in1[x] * in2[x];
    111       }
    112     }
    113   }
    114 }
    115 
    116 void RowColAvgP2(const ImageF& in, double* rp2, double* cp2) {
    117   std::vector<double> ravg(in.ysize());
    118   std::vector<double> cavg(in.xsize());
    119   for (size_t y = 0; y < in.ysize(); ++y) {
    120     const auto* row = in.Row(y);
    121     for (size_t x = 0; x < in.xsize(); ++x) {
    122       const float val = row[x];
    123       ravg[y] += val;
    124       cavg[x] += val;
    125     }
    126   }
    127   std::sort(ravg.begin(), ravg.end());
    128   std::sort(cavg.begin(), cavg.end());
    129   *rp2 = ravg[ravg.size() / 50] / in.xsize();
    130   *cp2 = cavg[cavg.size() / 50] / in.ysize();
    131 }
    132 
    133 class StreamingAverage {
    134  public:
    135   void Add(const float v) {
    136     // Numerically stable method.
    137     double delta = v - result_;
    138     n_ += 1;
    139     result_ += delta / n_;
    140   }
    141 
    142   double Get() const { return result_; }
    143 
    144  private:
    145   double result_ = 0.0;
    146   size_t n_ = 0;
    147 };
    148 
    149 void EdgeDiffMap(const Image3F& img1, const Image3F& mu1, const Image3F& img2,
    150                  const Image3F& mu2, Image3F* out, double* plane_avg) {
    151   for (size_t c = 0; c < 3; ++c) {
    152     StreamingAverage avg;
    153     for (size_t y = 0; y < img1.ysize(); ++y) {
    154       const float* JXL_RESTRICT row1 = img1.PlaneRow(c, y);
    155       const float* JXL_RESTRICT row2 = img2.PlaneRow(c, y);
    156       const float* JXL_RESTRICT rowm1 = mu1.PlaneRow(c, y);
    157       const float* JXL_RESTRICT rowm2 = mu2.PlaneRow(c, y);
    158       float* JXL_RESTRICT row_out = out->PlaneRow(c, y);
    159       for (size_t x = 0; x < img1.xsize(); ++x) {
    160         float edgediff = std::max(
    161             std::abs(row2[x] - rowm2[x]) - std::abs(row1[x] - rowm1[x]), 0.0f);
    162         row_out[x] = 1.0f - edgediff;
    163         avg.Add(row_out[x]);
    164       }
    165     }
    166     plane_avg[c] = avg.Get();
    167   }
    168 }
    169 
    170 // Temporary storage for Gaussian blur, reused for multiple images.
    171 class Blur {
    172  public:
    173   static StatusOr<Blur> Create(const size_t xsize, const size_t ysize) {
    174     Blur result;
    175     JXL_ASSIGN_OR_RETURN(result.temp_, ImageF::Create(xsize, ysize));
    176     return result;
    177   }
    178 
    179   void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) {
    180     FastGaussian(
    181         rg_, in.xsize(), in.ysize(), [&](size_t y) { return in.ConstRow(y); },
    182         [&](size_t y) { return temp_.Row(y); },
    183         [&](size_t y) { return out->Row(y); });
    184   }
    185 
    186   StatusOr<Image3F> operator()(const Image3F& in) {
    187     JXL_ASSIGN_OR_RETURN(Image3F out, Image3F::Create(in.xsize(), in.ysize()));
    188     operator()(in.Plane(0), &out.Plane(0));
    189     operator()(in.Plane(1), &out.Plane(1));
    190     operator()(in.Plane(2), &out.Plane(2));
    191     return out;
    192   }
    193 
    194   // Allows reusing across scales.
    195   void ShrinkTo(const size_t xsize, const size_t ysize) {
    196     temp_.ShrinkTo(xsize, ysize);
    197   }
    198 
    199  private:
    200   Blur() : rg_(jxl::CreateRecursiveGaussian(1.5)) {}
    201   hwy::AlignedUniquePtr<jxl::RecursiveGaussian> rg_;
    202   ImageF temp_;
    203 };
    204 
    205 void SSIMMap(const Image3F& m1, const Image3F& m2, const Image3F& s11,
    206              const Image3F& s22, const Image3F& s12, Image3F* out,
    207              double* plane_averages) {
    208   for (size_t c = 0; c < 3; ++c) {
    209     StreamingAverage avg;
    210     for (size_t y = 0; y < out->ysize(); ++y) {
    211       const float* JXL_RESTRICT row_m1 = m1.PlaneRow(c, y);
    212       const float* JXL_RESTRICT row_m2 = m2.PlaneRow(c, y);
    213       const float* JXL_RESTRICT row_s11 = s11.PlaneRow(c, y);
    214       const float* JXL_RESTRICT row_s22 = s22.PlaneRow(c, y);
    215       const float* JXL_RESTRICT row_s12 = s12.PlaneRow(c, y);
    216       float* JXL_RESTRICT row_out = out->PlaneRow(c, y);
    217       for (size_t x = 0; x < out->xsize(); ++x) {
    218         float mu1 = row_m1[x];
    219         float mu2 = row_m2[x];
    220         float mu11 = mu1 * mu1;
    221         float mu22 = mu2 * mu2;
    222         float mu12 = mu1 * mu2;
    223         float nom_m = 2 * mu12 + kC1;
    224         float nom_s = 2 * (row_s12[x] - mu12) + kC2;
    225         float denom_m = mu11 + mu22 + kC1;
    226         float denom_s = (row_s11[x] - mu11) + (row_s22[x] - mu22) + kC2;
    227         row_out[x] = (nom_m * nom_s) / (denom_m * denom_s);
    228         avg.Add(row_out[x]);
    229       }
    230     }
    231     plane_averages[c] = avg.Get();
    232   }
    233 }
    234 
    235 }  // namespace
    236 
    237 double Ssimulacra::Score() const {
    238   double ssim = 0.0;
    239   double ssim_max = 0.0;
    240   for (size_t c = 0; c < 3; ++c) {
    241     for (size_t scale = 0; scale < scales.size(); ++scale) {
    242       ssim += kScaleWeights[scale][c] * scales[scale].avg_ssim[c];
    243       ssim_max += kScaleWeights[scale][c];
    244       ssim += kMinScaleWeights[scale][c] * scales[scale].min_ssim[c];
    245       ssim_max += kMinScaleWeights[scale][c];
    246     }
    247     if (!simple) {
    248       ssim += kEdgeWeight[c] * avg_edgediff[c];
    249       ssim_max += kEdgeWeight[c];
    250       ssim += kGridWeight[c] *
    251               (row_p2[0][c] + row_p2[1][c] + col_p2[0][c] + col_p2[1][c]);
    252       ssim_max += 4.0 * kGridWeight[c];
    253     }
    254   }
    255   double dssim = ssim_max / ssim - 1.0;
    256   return std::min(1.0, std::max(0.0, dssim));
    257 }
    258 
    259 inline void PrintItem(const char* name, int scale, const double* vals,
    260                       const double* w) {
    261   printf("scale %d %s = [%.10f %.10f %.10f]  w = [%.5f %.5f %.5f]\n", scale,
    262          name, vals[0], vals[1], vals[2], w[0], w[1], w[2]);
    263 }
    264 
    265 void Ssimulacra::PrintDetails() const {
    266   for (size_t s = 0; s < scales.size(); ++s) {
    267     if (s < kNumScales) {
    268       PrintItem("avg ssim", s, scales[s].avg_ssim, kScaleWeights[s]);
    269       PrintItem("min ssim", s, scales[s].min_ssim, kMinScaleWeights[s]);
    270     }
    271     if (s == 0 && !simple) {
    272       PrintItem("avg edif", s, avg_edgediff, kEdgeWeight);
    273       PrintItem("rp2 ssim", s, &row_p2[0][0], kGridWeight);
    274       PrintItem("cp2 ssim", s, &col_p2[0][0], kGridWeight);
    275       PrintItem("rp2 edif", s, &row_p2[1][0], kGridWeight);
    276       PrintItem("cp2 edif", s, &col_p2[1][0], kGridWeight);
    277     }
    278   }
    279 }
    280 
    281 StatusOr<Ssimulacra> ComputeDiff(const Image3F& orig, const Image3F& distorted,
    282                                  bool simple) {
    283   Ssimulacra ssimulacra;
    284 
    285   ssimulacra.simple = simple;
    286   JXL_ASSIGN_OR_RETURN(Image3F img1, Rgb2Lab(orig));
    287   JXL_ASSIGN_OR_RETURN(Image3F img2, Rgb2Lab(distorted));
    288 
    289   JXL_ASSIGN_OR_RETURN(Image3F mul,
    290                        Image3F::Create(orig.xsize(), orig.ysize()));
    291   JXL_ASSIGN_OR_RETURN(Blur blur, Blur::Create(img1.xsize(), img1.ysize()));
    292 
    293   for (int scale = 0; scale < kNumScales; scale++) {
    294     if (img1.xsize() < 8 || img1.ysize() < 8) {
    295       break;
    296     }
    297     if (scale) {
    298       JXL_ASSIGN_OR_RETURN(img1, Downsample(img1, 2, 2));
    299       JXL_ASSIGN_OR_RETURN(img2, Downsample(img2, 2, 2));
    300     }
    301     mul.ShrinkTo(img1.xsize(), img2.ysize());
    302     blur.ShrinkTo(img1.xsize(), img2.ysize());
    303 
    304     Multiply(img1, img1, &mul);
    305     JXL_ASSIGN_OR_RETURN(Image3F sigma1_sq, blur(mul));
    306 
    307     Multiply(img2, img2, &mul);
    308     JXL_ASSIGN_OR_RETURN(Image3F sigma2_sq, blur(mul));
    309 
    310     Multiply(img1, img2, &mul);
    311     JXL_ASSIGN_OR_RETURN(Image3F sigma12, blur(mul));
    312 
    313     JXL_ASSIGN_OR_RETURN(Image3F mu1, blur(img1));
    314     JXL_ASSIGN_OR_RETURN(Image3F mu2, blur(img2));
    315     // Reuse mul as "ssim_map".
    316     SsimulacraScale sscale;
    317     SSIMMap(mu1, mu2, sigma1_sq, sigma2_sq, sigma12, &mul, sscale.avg_ssim);
    318 
    319     JXL_ASSIGN_OR_RETURN(const Image3F ssim_map, Downsample(mul, 4, 4));
    320     for (size_t c = 0; c < 3; c++) {
    321       float minval;
    322       float maxval;
    323       ImageMinMax(ssim_map.Plane(c), &minval, &maxval);
    324       sscale.min_ssim[c] = static_cast<double>(minval);
    325     }
    326     ssimulacra.scales.push_back(sscale);
    327 
    328     if (scale == 0 && !simple) {
    329       Image3F* edgediff = &sigma1_sq;  // reuse
    330       EdgeDiffMap(img1, mu1, img2, mu2, edgediff, ssimulacra.avg_edgediff);
    331       for (size_t c = 0; c < 3; c++) {
    332         RowColAvgP2(ssim_map.Plane(c), &ssimulacra.row_p2[0][c],
    333                     &ssimulacra.col_p2[0][c]);
    334         RowColAvgP2(edgediff->Plane(c), &ssimulacra.row_p2[1][c],
    335                     &ssimulacra.col_p2[1][c]);
    336       }
    337     }
    338   }
    339   return ssimulacra;
    340 }
    341 
    342 }  // namespace ssimulacra