ssimulacra2.cc (20028B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 /* 7 SSIMULACRA 2 8 Structural SIMilarity Unveiling Local And Compression Related Artifacts 9 10 Perceptual metric developed by Jon Sneyers (Cloudinary) in July 2022, 11 updated in April 2023. 12 Design: 13 - XYB color space (rescaled to a 0..1 range and with B-Y) 14 - SSIM map (with correction: no double gamma correction) 15 - 'blockiness/ringing' map (distorted has edges where original is smooth) 16 - 'smoothing' map (distorted is smooth where original has edges) 17 - error maps are computed at 6 scales (1:1 to 1:32) for each component (X,Y,B) 18 - downscaling is done in linear RGB 19 - for all 6*3*3=54 maps, two norms are computed: 1-norm (mean) and 4-norm 20 - a weighted sum of these 54*2=108 norms leads to the final score 21 - weights were tuned based on a large set of subjective scores 22 (CID22, TID2013, Kadid10k, KonFiG-IQA). 23 */ 24 25 #include "tools/ssimulacra2.h" 26 27 #include <jxl/cms.h> 28 #include <stdio.h> 29 30 #include <algorithm> 31 #include <cmath> 32 #include <hwy/aligned_allocator.h> 33 #include <utility> 34 35 #include "lib/jxl/base/compiler_specific.h" 36 #include "lib/jxl/base/printf_macros.h" 37 #include "lib/jxl/base/status.h" 38 #include "lib/jxl/color_encoding_internal.h" 39 #include "lib/jxl/enc_xyb.h" 40 #include "lib/jxl/image.h" 41 #include "lib/jxl/image_bundle.h" 42 #include "tools/gauss_blur.h" 43 44 namespace { 45 46 using jxl::Image3F; 47 using jxl::ImageBundle; 48 using jxl::ImageF; 49 using jxl::StatusOr; 50 51 const float kC2 = 0.0009f; 52 const int kNumScales = 6; 53 54 StatusOr<Image3F> Downsample(const Image3F& in, size_t fx, size_t fy) { 55 const size_t out_xsize = (in.xsize() + fx - 1) / fx; 56 const size_t out_ysize = (in.ysize() + fy - 1) / fy; 57 JXL_ASSIGN_OR_RETURN(Image3F out, Image3F::Create(out_xsize, out_ysize)); 58 const float normalize = 1.0f / (fx * fy); 59 for (size_t c = 0; c < 3; ++c) { 60 for (size_t oy = 0; oy < out_ysize; ++oy) { 61 float* JXL_RESTRICT row_out = out.PlaneRow(c, oy); 62 for (size_t ox = 0; ox < out_xsize; ++ox) { 63 float sum = 0.0f; 64 for (size_t iy = 0; iy < fy; ++iy) { 65 for (size_t ix = 0; ix < fx; ++ix) { 66 const size_t x = std::min(ox * fx + ix, in.xsize() - 1); 67 const size_t y = std::min(oy * fy + iy, in.ysize() - 1); 68 sum += in.PlaneRow(c, y)[x]; 69 } 70 } 71 row_out[ox] = sum * normalize; 72 } 73 } 74 } 75 return out; 76 } 77 78 void Multiply(const Image3F& a, const Image3F& b, Image3F* mul) { 79 for (size_t c = 0; c < 3; ++c) { 80 for (size_t y = 0; y < a.ysize(); ++y) { 81 const float* JXL_RESTRICT in1 = a.PlaneRow(c, y); 82 const float* JXL_RESTRICT in2 = b.PlaneRow(c, y); 83 float* JXL_RESTRICT out = mul->PlaneRow(c, y); 84 for (size_t x = 0; x < a.xsize(); ++x) { 85 out[x] = in1[x] * in2[x]; 86 } 87 } 88 } 89 } 90 91 // Temporary storage for Gaussian blur, reused for multiple images. 92 class Blur { 93 public: 94 static StatusOr<Blur> Create(const size_t xsize, const size_t ysize) { 95 Blur result; 96 JXL_ASSIGN_OR_RETURN(result.temp_, ImageF::Create(xsize, ysize)); 97 return result; 98 } 99 100 void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) { 101 FastGaussian( 102 rg_, in.xsize(), in.ysize(), [&](size_t y) { return in.ConstRow(y); }, 103 [&](size_t y) { return temp_.Row(y); }, 104 [&](size_t y) { return out->Row(y); }); 105 } 106 107 StatusOr<Image3F> operator()(const Image3F& in) { 108 JXL_ASSIGN_OR_RETURN(Image3F out, Image3F::Create(in.xsize(), in.ysize())); 109 operator()(in.Plane(0), &out.Plane(0)); 110 operator()(in.Plane(1), &out.Plane(1)); 111 operator()(in.Plane(2), &out.Plane(2)); 112 return out; 113 } 114 115 // Allows reusing across scales. 116 void ShrinkTo(const size_t xsize, const size_t ysize) { 117 temp_.ShrinkTo(xsize, ysize); 118 } 119 120 private: 121 Blur() : rg_(jxl::CreateRecursiveGaussian(1.5)) {} 122 hwy::AlignedUniquePtr<jxl::RecursiveGaussian> rg_; 123 ImageF temp_; 124 }; 125 126 double quartic(double x) { 127 x *= x; 128 x *= x; 129 return x; 130 } 131 void SSIMMap(const Image3F& m1, const Image3F& m2, const Image3F& s11, 132 const Image3F& s22, const Image3F& s12, double* plane_averages) { 133 const double onePerPixels = 1.0 / (m1.ysize() * m1.xsize()); 134 for (size_t c = 0; c < 3; ++c) { 135 double sum1[2] = {0.0}; 136 for (size_t y = 0; y < m1.ysize(); ++y) { 137 const float* JXL_RESTRICT row_m1 = m1.PlaneRow(c, y); 138 const float* JXL_RESTRICT row_m2 = m2.PlaneRow(c, y); 139 const float* JXL_RESTRICT row_s11 = s11.PlaneRow(c, y); 140 const float* JXL_RESTRICT row_s22 = s22.PlaneRow(c, y); 141 const float* JXL_RESTRICT row_s12 = s12.PlaneRow(c, y); 142 for (size_t x = 0; x < m1.xsize(); ++x) { 143 float mu1 = row_m1[x]; 144 float mu2 = row_m2[x]; 145 float mu11 = mu1 * mu1; 146 float mu22 = mu2 * mu2; 147 float mu12 = mu1 * mu2; 148 /* Correction applied compared to the original SSIM formula, which has: 149 150 luma_err = 2 * mu1 * mu2 / (mu1^2 + mu2^2) 151 = 1 - (mu1 - mu2)^2 / (mu1^2 + mu2^2) 152 153 The denominator causes error in the darks (low mu1 and mu2) to weigh 154 more than error in the brights (high mu1 and mu2). This would make 155 sense if values correspond to linear luma. However, the actual values 156 are either gamma-compressed luma (which supposedly is already 157 perceptually uniform) or chroma (where weighing green more than red 158 or blue more than yellow does not make any sense at all). So it is 159 better to simply drop this denominator. 160 */ 161 float num_m = 1.0 - (mu1 - mu2) * (mu1 - mu2); 162 float num_s = 2 * (row_s12[x] - mu12) + kC2; 163 float denom_s = (row_s11[x] - mu11) + (row_s22[x] - mu22) + kC2; 164 165 // Use 1 - SSIM' so it becomes an error score instead of a quality 166 // index. This makes it make sense to compute an L_4 norm. 167 double d = 1.0 - (num_m * num_s / denom_s); 168 d = std::max(d, 0.0); 169 sum1[0] += d; 170 sum1[1] += quartic(d); 171 } 172 } 173 plane_averages[c * 2] = onePerPixels * sum1[0]; 174 plane_averages[c * 2 + 1] = sqrt(sqrt(onePerPixels * sum1[1])); 175 } 176 } 177 178 void EdgeDiffMap(const Image3F& img1, const Image3F& mu1, const Image3F& img2, 179 const Image3F& mu2, double* plane_averages) { 180 const double onePerPixels = 1.0 / (img1.ysize() * img1.xsize()); 181 for (size_t c = 0; c < 3; ++c) { 182 double sum1[4] = {0.0}; 183 for (size_t y = 0; y < img1.ysize(); ++y) { 184 const float* JXL_RESTRICT row1 = img1.PlaneRow(c, y); 185 const float* JXL_RESTRICT row2 = img2.PlaneRow(c, y); 186 const float* JXL_RESTRICT rowm1 = mu1.PlaneRow(c, y); 187 const float* JXL_RESTRICT rowm2 = mu2.PlaneRow(c, y); 188 for (size_t x = 0; x < img1.xsize(); ++x) { 189 double d1 = (1.0 + std::abs(row2[x] - rowm2[x])) / 190 (1.0 + std::abs(row1[x] - rowm1[x])) - 191 1.0; 192 193 // d1 > 0: distorted has an edge where original is smooth 194 // (indicating ringing, color banding, blockiness, etc) 195 double artifact = std::max(d1, 0.0); 196 sum1[0] += artifact; 197 sum1[1] += quartic(artifact); 198 199 // d1 < 0: original has an edge where distorted is smooth 200 // (indicating smoothing, blurring, smearing, etc) 201 double detail_lost = std::max(-d1, 0.0); 202 sum1[2] += detail_lost; 203 sum1[3] += quartic(detail_lost); 204 } 205 } 206 plane_averages[c * 4] = onePerPixels * sum1[0]; 207 plane_averages[c * 4 + 1] = sqrt(sqrt(onePerPixels * sum1[1])); 208 plane_averages[c * 4 + 2] = onePerPixels * sum1[2]; 209 plane_averages[c * 4 + 3] = sqrt(sqrt(onePerPixels * sum1[3])); 210 } 211 } 212 213 /* Get all components in more or less 0..1 range 214 Range of Rec2020 with these adjustments: 215 X: 0.017223..0.998838 216 Y: 0.010000..0.855303 217 B: 0.048759..0.989551 218 Range of sRGB: 219 X: 0.204594..0.813402 220 Y: 0.010000..0.855308 221 B: 0.272295..0.938012 222 The maximum pixel-wise difference has to be <= 1 for the ssim formula to make 223 sense. 224 */ 225 void MakePositiveXYB(Image3F& img) { 226 for (size_t y = 0; y < img.ysize(); ++y) { 227 float* JXL_RESTRICT rowY = img.PlaneRow(1, y); 228 float* JXL_RESTRICT rowB = img.PlaneRow(2, y); 229 float* JXL_RESTRICT rowX = img.PlaneRow(0, y); 230 for (size_t x = 0; x < img.xsize(); ++x) { 231 rowB[x] = (rowB[x] - rowY[x]) + 0.55f; 232 rowX[x] = rowX[x] * 14.f + 0.42f; 233 rowY[x] += 0.01f; 234 } 235 } 236 } 237 238 void AlphaBlend(ImageBundle& img, float bg) { 239 for (size_t y = 0; y < img.ysize(); ++y) { 240 float* JXL_RESTRICT r = img.color()->PlaneRow(0, y); 241 float* JXL_RESTRICT g = img.color()->PlaneRow(1, y); 242 float* JXL_RESTRICT b = img.color()->PlaneRow(2, y); 243 const float* JXL_RESTRICT a = img.alpha()->Row(y); 244 for (size_t x = 0; x < img.xsize(); ++x) { 245 r[x] = a[x] * r[x] + (1.f - a[x]) * bg; 246 g[x] = a[x] * g[x] + (1.f - a[x]) * bg; 247 b[x] = a[x] * b[x] + (1.f - a[x]) * bg; 248 } 249 } 250 } 251 252 } // namespace 253 254 /* 255 The final score is based on a weighted sum of 108 sub-scores: 256 - for 6 scales (1:1 to 1:32, downsampled in linear RGB) 257 - for 3 components (X, Y, B-Y, rescaled to 0..1 range) 258 - using 2 norms (the 1-norm and the 4-norm) 259 - over 3 error maps: 260 - SSIM' (SSIM without the spurious gamma correction term) 261 - "ringing" (distorted edges where there are no orig edges) 262 - "blurring" (orig edges where there are no distorted edges) 263 264 The weights were obtained by running Nelder-Mead simplex search, 265 optimizing to minimize MSE for the CID22 training set and to 266 maximize Kendall rank correlation (and with a lower weight, 267 also Pearson correlation) with the CID22 training set and the 268 TID2013, Kadid10k and KonFiG-IQA datasets. 269 Validation was done on the CID22 validation set. 270 271 Final results after tuning (Kendall | Spearman | Pearson): 272 CID22: 0.6903 | 0.8805 | 0.8583 273 TID2013: 0.6590 | 0.8445 | 0.8471 274 KADID-10k: 0.6175 | 0.8133 | 0.8030 275 KonFiG(F): 0.7668 | 0.9194 | 0.9136 276 */ 277 double Msssim::Score() const { 278 double ssim = 0.0; 279 constexpr double weight[108] = {0.0, 280 0.0007376606707406586, 281 0.0, 282 0.0, 283 0.0007793481682867309, 284 0.0, 285 0.0, 286 0.0004371155730107379, 287 0.0, 288 1.1041726426657346, 289 0.00066284834129271, 290 0.00015231632783718752, 291 0.0, 292 0.0016406437456599754, 293 0.0, 294 1.8422455520539298, 295 11.441172603757666, 296 0.0, 297 0.0007989109436015163, 298 0.000176816438078653, 299 0.0, 300 1.8787594979546387, 301 10.94906990605142, 302 0.0, 303 0.0007289346991508072, 304 0.9677937080626833, 305 0.0, 306 0.00014003424285435884, 307 0.9981766977854967, 308 0.00031949755934435053, 309 0.0004550992113792063, 310 0.0, 311 0.0, 312 0.0013648766163243398, 313 0.0, 314 0.0, 315 0.0, 316 0.0, 317 0.0, 318 7.466890328078848, 319 0.0, 320 17.445833984131262, 321 0.0006235601634041466, 322 0.0, 323 0.0, 324 6.683678146179332, 325 0.00037724407979611296, 326 1.027889937768264, 327 225.20515300849274, 328 0.0, 329 0.0, 330 19.213238186143016, 331 0.0011401524586618361, 332 0.001237755635509985, 333 176.39317598450694, 334 0.0, 335 0.0, 336 24.43300999870476, 337 0.28520802612117757, 338 0.0004485436923833408, 339 0.0, 340 0.0, 341 0.0, 342 34.77906344483772, 343 44.835625328877896, 344 0.0, 345 0.0, 346 0.0, 347 0.0, 348 0.0, 349 0.0, 350 0.0, 351 0.0, 352 0.0008680556573291698, 353 0.0, 354 0.0, 355 0.0, 356 0.0, 357 0.0, 358 0.0005313191874358747, 359 0.0, 360 0.00016533814161379112, 361 0.0, 362 0.0, 363 0.0, 364 0.0, 365 0.0, 366 0.0004179171803251336, 367 0.0017290828234722833, 368 0.0, 369 0.0020827005846636437, 370 0.0, 371 0.0, 372 8.826982764996862, 373 23.19243343998926, 374 0.0, 375 95.1080498811086, 376 0.9863978034400682, 377 0.9834382792465353, 378 0.0012286405048278493, 379 171.2667255897307, 380 0.9807858872435379, 381 0.0, 382 0.0, 383 0.0, 384 0.0005130064588990679, 385 0.0, 386 0.00010854057858411537}; 387 388 size_t i = 0; 389 char ch[] = "XYB"; 390 const bool verbose = false; 391 for (size_t c = 0; c < 3; ++c) { 392 for (size_t scale = 0; scale < scales.size(); ++scale) { 393 for (size_t n = 0; n < 2; n++) { 394 #ifdef SSIMULACRA2_OUTPUT_RAW_SCORES_FOR_WEIGHT_TUNING 395 printf("%.12f,%.12f,%.12f,", scales[scale].avg_ssim[c * 2 + n], 396 scales[scale].avg_edgediff[c * 4 + n], 397 scales[scale].avg_edgediff[c * 4 + 2 + n]); 398 #endif 399 if (verbose) { 400 printf("%f from channel %c ssim, scale 1:%i, %" PRIuS 401 "-norm (weight %f)\n", 402 weight[i] * std::abs(scales[scale].avg_ssim[c * 2 + n]), ch[c], 403 1 << scale, n * 3 + 1, weight[i]); 404 } 405 ssim += weight[i++] * std::abs(scales[scale].avg_ssim[c * 2 + n]); 406 if (verbose) { 407 printf("%f from channel %c ringing, scale 1:%i, %" PRIuS 408 "-norm (weight %f)\n", 409 weight[i] * std::abs(scales[scale].avg_edgediff[c * 4 + n]), 410 ch[c], 1 << scale, n * 3 + 1, weight[i]); 411 } 412 ssim += weight[i++] * std::abs(scales[scale].avg_edgediff[c * 4 + n]); 413 if (verbose) { 414 printf( 415 "%f from channel %c blur, scale 1:%i, %" PRIuS 416 "-norm (weight %f)\n", 417 weight[i] * std::abs(scales[scale].avg_edgediff[c * 4 + n + 2]), 418 ch[c], 1 << scale, n * 3 + 1, weight[i]); 419 } 420 ssim += 421 weight[i++] * std::abs(scales[scale].avg_edgediff[c * 4 + n + 2]); 422 } 423 } 424 } 425 426 ssim = ssim * 0.9562382616834844; 427 ssim = 2.326765642916932 * ssim - 0.020884521182843837 * ssim * ssim + 428 6.248496625763138e-05 * ssim * ssim * ssim; 429 if (ssim > 0) { 430 ssim = 100.0 - 10.0 * pow(ssim, 0.6276336467831387); 431 } else { 432 ssim = 100.0; 433 } 434 return ssim; 435 } 436 437 StatusOr<Msssim> ComputeSSIMULACRA2(const ImageBundle& orig, 438 const ImageBundle& dist, float bg) { 439 Msssim msssim; 440 441 JXL_ASSIGN_OR_RETURN(Image3F img1, 442 Image3F::Create(orig.xsize(), orig.ysize())); 443 JXL_ASSIGN_OR_RETURN(Image3F img2, 444 Image3F::Create(img1.xsize(), img1.ysize())); 445 446 JXL_ASSIGN_OR_RETURN(ImageBundle orig2, orig.Copy()); 447 JXL_ASSIGN_OR_RETURN(ImageBundle dist2, dist.Copy()); 448 449 if (orig.HasAlpha()) AlphaBlend(orig2, bg); 450 if (dist.HasAlpha()) AlphaBlend(dist2, bg); 451 orig2.ClearExtraChannels(); 452 dist2.ClearExtraChannels(); 453 454 JXL_CHECK(orig2.TransformTo(jxl::ColorEncoding::LinearSRGB(orig2.IsGray()), 455 *JxlGetDefaultCms())); 456 JXL_CHECK(dist2.TransformTo(jxl::ColorEncoding::LinearSRGB(dist2.IsGray()), 457 *JxlGetDefaultCms())); 458 459 JXL_RETURN_IF_ERROR( 460 jxl::ToXYB(orig2, nullptr, &img1, *JxlGetDefaultCms(), nullptr)); 461 JXL_RETURN_IF_ERROR( 462 jxl::ToXYB(dist2, nullptr, &img2, *JxlGetDefaultCms(), nullptr)); 463 MakePositiveXYB(img1); 464 MakePositiveXYB(img2); 465 466 JXL_ASSIGN_OR_RETURN(Image3F mul, 467 Image3F::Create(img1.xsize(), img1.ysize())); 468 JXL_ASSIGN_OR_RETURN(Blur blur, Blur::Create(img1.xsize(), img1.ysize())); 469 470 for (int scale = 0; scale < kNumScales; scale++) { 471 if (img1.xsize() < 8 || img1.ysize() < 8) { 472 break; 473 } 474 if (scale) { 475 JXL_ASSIGN_OR_RETURN(Image3F tmp, Downsample(*orig2.color(), 2, 2)); 476 orig2.SetFromImage(std::move(tmp), 477 jxl::ColorEncoding::LinearSRGB(orig2.IsGray())); 478 JXL_ASSIGN_OR_RETURN(tmp, Downsample(*dist2.color(), 2, 2)); 479 dist2.SetFromImage(std::move(tmp), 480 jxl::ColorEncoding::LinearSRGB(dist2.IsGray())); 481 img1.ShrinkTo(orig2.xsize(), orig2.ysize()); 482 img2.ShrinkTo(orig2.xsize(), orig2.ysize()); 483 JXL_RETURN_IF_ERROR( 484 jxl::ToXYB(orig2, nullptr, &img1, *JxlGetDefaultCms(), nullptr)); 485 JXL_RETURN_IF_ERROR( 486 jxl::ToXYB(dist2, nullptr, &img2, *JxlGetDefaultCms(), nullptr)); 487 MakePositiveXYB(img1); 488 MakePositiveXYB(img2); 489 } 490 mul.ShrinkTo(img1.xsize(), img1.ysize()); 491 blur.ShrinkTo(img1.xsize(), img1.ysize()); 492 493 Multiply(img1, img1, &mul); 494 JXL_ASSIGN_OR_RETURN(Image3F sigma1_sq, blur(mul)); 495 496 Multiply(img2, img2, &mul); 497 JXL_ASSIGN_OR_RETURN(Image3F sigma2_sq, blur(mul)); 498 499 Multiply(img1, img2, &mul); 500 JXL_ASSIGN_OR_RETURN(Image3F sigma12, blur(mul)); 501 502 JXL_ASSIGN_OR_RETURN(Image3F mu1, blur(img1)); 503 JXL_ASSIGN_OR_RETURN(Image3F mu2, blur(img2)); 504 505 MsssimScale sscale; 506 SSIMMap(mu1, mu2, sigma1_sq, sigma2_sq, sigma12, sscale.avg_ssim); 507 EdgeDiffMap(img1, mu1, img2, mu2, sscale.avg_edgediff); 508 msssim.scales.push_back(sscale); 509 } 510 return msssim; 511 } 512 513 StatusOr<Msssim> ComputeSSIMULACRA2(const ImageBundle& orig, 514 const ImageBundle& distorted) { 515 return ComputeSSIMULACRA2(orig, distorted, 0.5f); 516 }