libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

ssimulacra2.cc (20028B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 /*
      7 SSIMULACRA 2
      8 Structural SIMilarity Unveiling Local And Compression Related Artifacts
      9 
     10 Perceptual metric developed by Jon Sneyers (Cloudinary) in July 2022,
     11 updated in April 2023.
     12 Design:
     13 - XYB color space (rescaled to a 0..1 range and with B-Y)
     14 - SSIM map (with correction: no double gamma correction)
     15 - 'blockiness/ringing' map (distorted has edges where original is smooth)
     16 - 'smoothing' map (distorted is smooth where original has edges)
     17 - error maps are computed at 6 scales (1:1 to 1:32) for each component (X,Y,B)
     18 - downscaling is done in linear RGB
     19 - for all 6*3*3=54 maps, two norms are computed: 1-norm (mean) and 4-norm
     20 - a weighted sum of these 54*2=108 norms leads to the final score
     21 - weights were tuned based on a large set of subjective scores
     22   (CID22, TID2013, Kadid10k, KonFiG-IQA).
     23 */
     24 
     25 #include "tools/ssimulacra2.h"
     26 
     27 #include <jxl/cms.h>
     28 #include <stdio.h>
     29 
     30 #include <algorithm>
     31 #include <cmath>
     32 #include <hwy/aligned_allocator.h>
     33 #include <utility>
     34 
     35 #include "lib/jxl/base/compiler_specific.h"
     36 #include "lib/jxl/base/printf_macros.h"
     37 #include "lib/jxl/base/status.h"
     38 #include "lib/jxl/color_encoding_internal.h"
     39 #include "lib/jxl/enc_xyb.h"
     40 #include "lib/jxl/image.h"
     41 #include "lib/jxl/image_bundle.h"
     42 #include "tools/gauss_blur.h"
     43 
     44 namespace {
     45 
     46 using jxl::Image3F;
     47 using jxl::ImageBundle;
     48 using jxl::ImageF;
     49 using jxl::StatusOr;
     50 
     51 const float kC2 = 0.0009f;
     52 const int kNumScales = 6;
     53 
     54 StatusOr<Image3F> Downsample(const Image3F& in, size_t fx, size_t fy) {
     55   const size_t out_xsize = (in.xsize() + fx - 1) / fx;
     56   const size_t out_ysize = (in.ysize() + fy - 1) / fy;
     57   JXL_ASSIGN_OR_RETURN(Image3F out, Image3F::Create(out_xsize, out_ysize));
     58   const float normalize = 1.0f / (fx * fy);
     59   for (size_t c = 0; c < 3; ++c) {
     60     for (size_t oy = 0; oy < out_ysize; ++oy) {
     61       float* JXL_RESTRICT row_out = out.PlaneRow(c, oy);
     62       for (size_t ox = 0; ox < out_xsize; ++ox) {
     63         float sum = 0.0f;
     64         for (size_t iy = 0; iy < fy; ++iy) {
     65           for (size_t ix = 0; ix < fx; ++ix) {
     66             const size_t x = std::min(ox * fx + ix, in.xsize() - 1);
     67             const size_t y = std::min(oy * fy + iy, in.ysize() - 1);
     68             sum += in.PlaneRow(c, y)[x];
     69           }
     70         }
     71         row_out[ox] = sum * normalize;
     72       }
     73     }
     74   }
     75   return out;
     76 }
     77 
     78 void Multiply(const Image3F& a, const Image3F& b, Image3F* mul) {
     79   for (size_t c = 0; c < 3; ++c) {
     80     for (size_t y = 0; y < a.ysize(); ++y) {
     81       const float* JXL_RESTRICT in1 = a.PlaneRow(c, y);
     82       const float* JXL_RESTRICT in2 = b.PlaneRow(c, y);
     83       float* JXL_RESTRICT out = mul->PlaneRow(c, y);
     84       for (size_t x = 0; x < a.xsize(); ++x) {
     85         out[x] = in1[x] * in2[x];
     86       }
     87     }
     88   }
     89 }
     90 
     91 // Temporary storage for Gaussian blur, reused for multiple images.
     92 class Blur {
     93  public:
     94   static StatusOr<Blur> Create(const size_t xsize, const size_t ysize) {
     95     Blur result;
     96     JXL_ASSIGN_OR_RETURN(result.temp_, ImageF::Create(xsize, ysize));
     97     return result;
     98   }
     99 
    100   void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) {
    101     FastGaussian(
    102         rg_, in.xsize(), in.ysize(), [&](size_t y) { return in.ConstRow(y); },
    103         [&](size_t y) { return temp_.Row(y); },
    104         [&](size_t y) { return out->Row(y); });
    105   }
    106 
    107   StatusOr<Image3F> operator()(const Image3F& in) {
    108     JXL_ASSIGN_OR_RETURN(Image3F out, Image3F::Create(in.xsize(), in.ysize()));
    109     operator()(in.Plane(0), &out.Plane(0));
    110     operator()(in.Plane(1), &out.Plane(1));
    111     operator()(in.Plane(2), &out.Plane(2));
    112     return out;
    113   }
    114 
    115   // Allows reusing across scales.
    116   void ShrinkTo(const size_t xsize, const size_t ysize) {
    117     temp_.ShrinkTo(xsize, ysize);
    118   }
    119 
    120  private:
    121   Blur() : rg_(jxl::CreateRecursiveGaussian(1.5)) {}
    122   hwy::AlignedUniquePtr<jxl::RecursiveGaussian> rg_;
    123   ImageF temp_;
    124 };
    125 
    126 double quartic(double x) {
    127   x *= x;
    128   x *= x;
    129   return x;
    130 }
    131 void SSIMMap(const Image3F& m1, const Image3F& m2, const Image3F& s11,
    132              const Image3F& s22, const Image3F& s12, double* plane_averages) {
    133   const double onePerPixels = 1.0 / (m1.ysize() * m1.xsize());
    134   for (size_t c = 0; c < 3; ++c) {
    135     double sum1[2] = {0.0};
    136     for (size_t y = 0; y < m1.ysize(); ++y) {
    137       const float* JXL_RESTRICT row_m1 = m1.PlaneRow(c, y);
    138       const float* JXL_RESTRICT row_m2 = m2.PlaneRow(c, y);
    139       const float* JXL_RESTRICT row_s11 = s11.PlaneRow(c, y);
    140       const float* JXL_RESTRICT row_s22 = s22.PlaneRow(c, y);
    141       const float* JXL_RESTRICT row_s12 = s12.PlaneRow(c, y);
    142       for (size_t x = 0; x < m1.xsize(); ++x) {
    143         float mu1 = row_m1[x];
    144         float mu2 = row_m2[x];
    145         float mu11 = mu1 * mu1;
    146         float mu22 = mu2 * mu2;
    147         float mu12 = mu1 * mu2;
    148         /* Correction applied compared to the original SSIM formula, which has:
    149 
    150              luma_err = 2 * mu1 * mu2 / (mu1^2 + mu2^2)
    151                       = 1 - (mu1 - mu2)^2 / (mu1^2 + mu2^2)
    152 
    153            The denominator causes error in the darks (low mu1 and mu2) to weigh
    154            more than error in the brights (high mu1 and mu2). This would make
    155            sense if values correspond to linear luma. However, the actual values
    156            are either gamma-compressed luma (which supposedly is already
    157            perceptually uniform) or chroma (where weighing green more than red
    158            or blue more than yellow does not make any sense at all). So it is
    159            better to simply drop this denominator.
    160         */
    161         float num_m = 1.0 - (mu1 - mu2) * (mu1 - mu2);
    162         float num_s = 2 * (row_s12[x] - mu12) + kC2;
    163         float denom_s = (row_s11[x] - mu11) + (row_s22[x] - mu22) + kC2;
    164 
    165         // Use 1 - SSIM' so it becomes an error score instead of a quality
    166         // index. This makes it make sense to compute an L_4 norm.
    167         double d = 1.0 - (num_m * num_s / denom_s);
    168         d = std::max(d, 0.0);
    169         sum1[0] += d;
    170         sum1[1] += quartic(d);
    171       }
    172     }
    173     plane_averages[c * 2] = onePerPixels * sum1[0];
    174     plane_averages[c * 2 + 1] = sqrt(sqrt(onePerPixels * sum1[1]));
    175   }
    176 }
    177 
    178 void EdgeDiffMap(const Image3F& img1, const Image3F& mu1, const Image3F& img2,
    179                  const Image3F& mu2, double* plane_averages) {
    180   const double onePerPixels = 1.0 / (img1.ysize() * img1.xsize());
    181   for (size_t c = 0; c < 3; ++c) {
    182     double sum1[4] = {0.0};
    183     for (size_t y = 0; y < img1.ysize(); ++y) {
    184       const float* JXL_RESTRICT row1 = img1.PlaneRow(c, y);
    185       const float* JXL_RESTRICT row2 = img2.PlaneRow(c, y);
    186       const float* JXL_RESTRICT rowm1 = mu1.PlaneRow(c, y);
    187       const float* JXL_RESTRICT rowm2 = mu2.PlaneRow(c, y);
    188       for (size_t x = 0; x < img1.xsize(); ++x) {
    189         double d1 = (1.0 + std::abs(row2[x] - rowm2[x])) /
    190                         (1.0 + std::abs(row1[x] - rowm1[x])) -
    191                     1.0;
    192 
    193         // d1 > 0: distorted has an edge where original is smooth
    194         //         (indicating ringing, color banding, blockiness, etc)
    195         double artifact = std::max(d1, 0.0);
    196         sum1[0] += artifact;
    197         sum1[1] += quartic(artifact);
    198 
    199         // d1 < 0: original has an edge where distorted is smooth
    200         //         (indicating smoothing, blurring, smearing, etc)
    201         double detail_lost = std::max(-d1, 0.0);
    202         sum1[2] += detail_lost;
    203         sum1[3] += quartic(detail_lost);
    204       }
    205     }
    206     plane_averages[c * 4] = onePerPixels * sum1[0];
    207     plane_averages[c * 4 + 1] = sqrt(sqrt(onePerPixels * sum1[1]));
    208     plane_averages[c * 4 + 2] = onePerPixels * sum1[2];
    209     plane_averages[c * 4 + 3] = sqrt(sqrt(onePerPixels * sum1[3]));
    210   }
    211 }
    212 
    213 /* Get all components in more or less 0..1 range
    214    Range of Rec2020 with these adjustments:
    215     X: 0.017223..0.998838
    216     Y: 0.010000..0.855303
    217     B: 0.048759..0.989551
    218    Range of sRGB:
    219     X: 0.204594..0.813402
    220     Y: 0.010000..0.855308
    221     B: 0.272295..0.938012
    222    The maximum pixel-wise difference has to be <= 1 for the ssim formula to make
    223    sense.
    224 */
    225 void MakePositiveXYB(Image3F& img) {
    226   for (size_t y = 0; y < img.ysize(); ++y) {
    227     float* JXL_RESTRICT rowY = img.PlaneRow(1, y);
    228     float* JXL_RESTRICT rowB = img.PlaneRow(2, y);
    229     float* JXL_RESTRICT rowX = img.PlaneRow(0, y);
    230     for (size_t x = 0; x < img.xsize(); ++x) {
    231       rowB[x] = (rowB[x] - rowY[x]) + 0.55f;
    232       rowX[x] = rowX[x] * 14.f + 0.42f;
    233       rowY[x] += 0.01f;
    234     }
    235   }
    236 }
    237 
    238 void AlphaBlend(ImageBundle& img, float bg) {
    239   for (size_t y = 0; y < img.ysize(); ++y) {
    240     float* JXL_RESTRICT r = img.color()->PlaneRow(0, y);
    241     float* JXL_RESTRICT g = img.color()->PlaneRow(1, y);
    242     float* JXL_RESTRICT b = img.color()->PlaneRow(2, y);
    243     const float* JXL_RESTRICT a = img.alpha()->Row(y);
    244     for (size_t x = 0; x < img.xsize(); ++x) {
    245       r[x] = a[x] * r[x] + (1.f - a[x]) * bg;
    246       g[x] = a[x] * g[x] + (1.f - a[x]) * bg;
    247       b[x] = a[x] * b[x] + (1.f - a[x]) * bg;
    248     }
    249   }
    250 }
    251 
    252 }  // namespace
    253 
    254 /*
    255 The final score is based on a weighted sum of 108 sub-scores:
    256 - for 6 scales (1:1 to 1:32, downsampled in linear RGB)
    257 - for 3 components (X, Y, B-Y, rescaled to 0..1 range)
    258 - using 2 norms (the 1-norm and the 4-norm)
    259 - over 3 error maps:
    260     - SSIM' (SSIM without the spurious gamma correction term)
    261     - "ringing" (distorted edges where there are no orig edges)
    262     - "blurring" (orig edges where there are no distorted edges)
    263 
    264 The weights were obtained by running Nelder-Mead simplex search,
    265 optimizing to minimize MSE for the CID22 training set and to
    266 maximize Kendall rank correlation (and with a lower weight,
    267 also Pearson correlation) with the CID22 training set and the
    268 TID2013, Kadid10k and KonFiG-IQA datasets.
    269 Validation was done on the CID22 validation set.
    270 
    271 Final results after tuning (Kendall | Spearman | Pearson):
    272    CID22:     0.6903 | 0.8805 | 0.8583
    273    TID2013:   0.6590 | 0.8445 | 0.8471
    274    KADID-10k: 0.6175 | 0.8133 | 0.8030
    275    KonFiG(F): 0.7668 | 0.9194 | 0.9136
    276 */
    277 double Msssim::Score() const {
    278   double ssim = 0.0;
    279   constexpr double weight[108] = {0.0,
    280                                   0.0007376606707406586,
    281                                   0.0,
    282                                   0.0,
    283                                   0.0007793481682867309,
    284                                   0.0,
    285                                   0.0,
    286                                   0.0004371155730107379,
    287                                   0.0,
    288                                   1.1041726426657346,
    289                                   0.00066284834129271,
    290                                   0.00015231632783718752,
    291                                   0.0,
    292                                   0.0016406437456599754,
    293                                   0.0,
    294                                   1.8422455520539298,
    295                                   11.441172603757666,
    296                                   0.0,
    297                                   0.0007989109436015163,
    298                                   0.000176816438078653,
    299                                   0.0,
    300                                   1.8787594979546387,
    301                                   10.94906990605142,
    302                                   0.0,
    303                                   0.0007289346991508072,
    304                                   0.9677937080626833,
    305                                   0.0,
    306                                   0.00014003424285435884,
    307                                   0.9981766977854967,
    308                                   0.00031949755934435053,
    309                                   0.0004550992113792063,
    310                                   0.0,
    311                                   0.0,
    312                                   0.0013648766163243398,
    313                                   0.0,
    314                                   0.0,
    315                                   0.0,
    316                                   0.0,
    317                                   0.0,
    318                                   7.466890328078848,
    319                                   0.0,
    320                                   17.445833984131262,
    321                                   0.0006235601634041466,
    322                                   0.0,
    323                                   0.0,
    324                                   6.683678146179332,
    325                                   0.00037724407979611296,
    326                                   1.027889937768264,
    327                                   225.20515300849274,
    328                                   0.0,
    329                                   0.0,
    330                                   19.213238186143016,
    331                                   0.0011401524586618361,
    332                                   0.001237755635509985,
    333                                   176.39317598450694,
    334                                   0.0,
    335                                   0.0,
    336                                   24.43300999870476,
    337                                   0.28520802612117757,
    338                                   0.0004485436923833408,
    339                                   0.0,
    340                                   0.0,
    341                                   0.0,
    342                                   34.77906344483772,
    343                                   44.835625328877896,
    344                                   0.0,
    345                                   0.0,
    346                                   0.0,
    347                                   0.0,
    348                                   0.0,
    349                                   0.0,
    350                                   0.0,
    351                                   0.0,
    352                                   0.0008680556573291698,
    353                                   0.0,
    354                                   0.0,
    355                                   0.0,
    356                                   0.0,
    357                                   0.0,
    358                                   0.0005313191874358747,
    359                                   0.0,
    360                                   0.00016533814161379112,
    361                                   0.0,
    362                                   0.0,
    363                                   0.0,
    364                                   0.0,
    365                                   0.0,
    366                                   0.0004179171803251336,
    367                                   0.0017290828234722833,
    368                                   0.0,
    369                                   0.0020827005846636437,
    370                                   0.0,
    371                                   0.0,
    372                                   8.826982764996862,
    373                                   23.19243343998926,
    374                                   0.0,
    375                                   95.1080498811086,
    376                                   0.9863978034400682,
    377                                   0.9834382792465353,
    378                                   0.0012286405048278493,
    379                                   171.2667255897307,
    380                                   0.9807858872435379,
    381                                   0.0,
    382                                   0.0,
    383                                   0.0,
    384                                   0.0005130064588990679,
    385                                   0.0,
    386                                   0.00010854057858411537};
    387 
    388   size_t i = 0;
    389   char ch[] = "XYB";
    390   const bool verbose = false;
    391   for (size_t c = 0; c < 3; ++c) {
    392     for (size_t scale = 0; scale < scales.size(); ++scale) {
    393       for (size_t n = 0; n < 2; n++) {
    394 #ifdef SSIMULACRA2_OUTPUT_RAW_SCORES_FOR_WEIGHT_TUNING
    395         printf("%.12f,%.12f,%.12f,", scales[scale].avg_ssim[c * 2 + n],
    396                scales[scale].avg_edgediff[c * 4 + n],
    397                scales[scale].avg_edgediff[c * 4 + 2 + n]);
    398 #endif
    399         if (verbose) {
    400           printf("%f from channel %c ssim, scale 1:%i, %" PRIuS
    401                  "-norm (weight %f)\n",
    402                  weight[i] * std::abs(scales[scale].avg_ssim[c * 2 + n]), ch[c],
    403                  1 << scale, n * 3 + 1, weight[i]);
    404         }
    405         ssim += weight[i++] * std::abs(scales[scale].avg_ssim[c * 2 + n]);
    406         if (verbose) {
    407           printf("%f from channel %c ringing, scale 1:%i, %" PRIuS
    408                  "-norm (weight %f)\n",
    409                  weight[i] * std::abs(scales[scale].avg_edgediff[c * 4 + n]),
    410                  ch[c], 1 << scale, n * 3 + 1, weight[i]);
    411         }
    412         ssim += weight[i++] * std::abs(scales[scale].avg_edgediff[c * 4 + n]);
    413         if (verbose) {
    414           printf(
    415               "%f from channel %c blur, scale 1:%i, %" PRIuS
    416               "-norm (weight %f)\n",
    417               weight[i] * std::abs(scales[scale].avg_edgediff[c * 4 + n + 2]),
    418               ch[c], 1 << scale, n * 3 + 1, weight[i]);
    419         }
    420         ssim +=
    421             weight[i++] * std::abs(scales[scale].avg_edgediff[c * 4 + n + 2]);
    422       }
    423     }
    424   }
    425 
    426   ssim = ssim * 0.9562382616834844;
    427   ssim = 2.326765642916932 * ssim - 0.020884521182843837 * ssim * ssim +
    428          6.248496625763138e-05 * ssim * ssim * ssim;
    429   if (ssim > 0) {
    430     ssim = 100.0 - 10.0 * pow(ssim, 0.6276336467831387);
    431   } else {
    432     ssim = 100.0;
    433   }
    434   return ssim;
    435 }
    436 
    437 StatusOr<Msssim> ComputeSSIMULACRA2(const ImageBundle& orig,
    438                                     const ImageBundle& dist, float bg) {
    439   Msssim msssim;
    440 
    441   JXL_ASSIGN_OR_RETURN(Image3F img1,
    442                        Image3F::Create(orig.xsize(), orig.ysize()));
    443   JXL_ASSIGN_OR_RETURN(Image3F img2,
    444                        Image3F::Create(img1.xsize(), img1.ysize()));
    445 
    446   JXL_ASSIGN_OR_RETURN(ImageBundle orig2, orig.Copy());
    447   JXL_ASSIGN_OR_RETURN(ImageBundle dist2, dist.Copy());
    448 
    449   if (orig.HasAlpha()) AlphaBlend(orig2, bg);
    450   if (dist.HasAlpha()) AlphaBlend(dist2, bg);
    451   orig2.ClearExtraChannels();
    452   dist2.ClearExtraChannels();
    453 
    454   JXL_CHECK(orig2.TransformTo(jxl::ColorEncoding::LinearSRGB(orig2.IsGray()),
    455                               *JxlGetDefaultCms()));
    456   JXL_CHECK(dist2.TransformTo(jxl::ColorEncoding::LinearSRGB(dist2.IsGray()),
    457                               *JxlGetDefaultCms()));
    458 
    459   JXL_RETURN_IF_ERROR(
    460       jxl::ToXYB(orig2, nullptr, &img1, *JxlGetDefaultCms(), nullptr));
    461   JXL_RETURN_IF_ERROR(
    462       jxl::ToXYB(dist2, nullptr, &img2, *JxlGetDefaultCms(), nullptr));
    463   MakePositiveXYB(img1);
    464   MakePositiveXYB(img2);
    465 
    466   JXL_ASSIGN_OR_RETURN(Image3F mul,
    467                        Image3F::Create(img1.xsize(), img1.ysize()));
    468   JXL_ASSIGN_OR_RETURN(Blur blur, Blur::Create(img1.xsize(), img1.ysize()));
    469 
    470   for (int scale = 0; scale < kNumScales; scale++) {
    471     if (img1.xsize() < 8 || img1.ysize() < 8) {
    472       break;
    473     }
    474     if (scale) {
    475       JXL_ASSIGN_OR_RETURN(Image3F tmp, Downsample(*orig2.color(), 2, 2));
    476       orig2.SetFromImage(std::move(tmp),
    477                          jxl::ColorEncoding::LinearSRGB(orig2.IsGray()));
    478       JXL_ASSIGN_OR_RETURN(tmp, Downsample(*dist2.color(), 2, 2));
    479       dist2.SetFromImage(std::move(tmp),
    480                          jxl::ColorEncoding::LinearSRGB(dist2.IsGray()));
    481       img1.ShrinkTo(orig2.xsize(), orig2.ysize());
    482       img2.ShrinkTo(orig2.xsize(), orig2.ysize());
    483       JXL_RETURN_IF_ERROR(
    484           jxl::ToXYB(orig2, nullptr, &img1, *JxlGetDefaultCms(), nullptr));
    485       JXL_RETURN_IF_ERROR(
    486           jxl::ToXYB(dist2, nullptr, &img2, *JxlGetDefaultCms(), nullptr));
    487       MakePositiveXYB(img1);
    488       MakePositiveXYB(img2);
    489     }
    490     mul.ShrinkTo(img1.xsize(), img1.ysize());
    491     blur.ShrinkTo(img1.xsize(), img1.ysize());
    492 
    493     Multiply(img1, img1, &mul);
    494     JXL_ASSIGN_OR_RETURN(Image3F sigma1_sq, blur(mul));
    495 
    496     Multiply(img2, img2, &mul);
    497     JXL_ASSIGN_OR_RETURN(Image3F sigma2_sq, blur(mul));
    498 
    499     Multiply(img1, img2, &mul);
    500     JXL_ASSIGN_OR_RETURN(Image3F sigma12, blur(mul));
    501 
    502     JXL_ASSIGN_OR_RETURN(Image3F mu1, blur(img1));
    503     JXL_ASSIGN_OR_RETURN(Image3F mu2, blur(img2));
    504 
    505     MsssimScale sscale;
    506     SSIMMap(mu1, mu2, sigma1_sq, sigma2_sq, sigma12, sscale.avg_ssim);
    507     EdgeDiffMap(img1, mu1, img2, mu2, sscale.avg_edgediff);
    508     msssim.scales.push_back(sscale);
    509   }
    510   return msssim;
    511 }
    512 
    513 StatusOr<Msssim> ComputeSSIMULACRA2(const ImageBundle& orig,
    514                                     const ImageBundle& distorted) {
    515   return ComputeSSIMULACRA2(orig, distorted, 0.5f);
    516 }