ssimulacra.cc (11852B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 // Re-implementation of //tools/ssimulacra.tct using jxl's 7 // ImageF library instead of opencv. 8 9 #include "tools/ssimulacra.h" 10 11 #include <cmath> 12 13 #include "lib/jxl/base/status.h" 14 #include "lib/jxl/image.h" 15 #include "lib/jxl/image_ops.h" 16 #include "tools/gauss_blur.h" 17 18 namespace ssimulacra { 19 namespace { 20 21 using jxl::Image3F; 22 using jxl::ImageF; 23 using jxl::StatusOr; 24 25 const float kC1 = 0.0001f; 26 const float kC2 = 0.0004f; 27 const int kNumScales = 6; 28 // Premultiplied by chroma weight 0.2 29 const double kScaleWeights[kNumScales][3] = { 30 {0.04480, 0.00300, 0.00300}, {0.28560, 0.00896, 0.00896}, 31 {0.30010, 0.05712, 0.05712}, {0.23630, 0.06002, 0.06002}, 32 {0.13330, 0.06726, 0.06726}, {0.10000, 0.05000, 0.05000}, 33 }; 34 // Premultiplied by min weights 0.1, 0.005, 0.005 35 const double kMinScaleWeights[kNumScales][3] = { 36 {0.02000, 0.00005, 0.00005}, {0.03000, 0.00025, 0.00025}, 37 {0.02500, 0.00100, 0.00100}, {0.02000, 0.00150, 0.00150}, 38 {0.01200, 0.00175, 0.00175}, {0.00500, 0.00175, 0.00175}, 39 }; 40 const double kEdgeWeight[3] = {1.5, 0.1, 0.1}; 41 const double kGridWeight[3] = {1.0, 0.1, 0.1}; 42 43 inline void Rgb2Lab(float r, float g, float b, float* L, float* A, float* B) { 44 const float epsilon = 0.00885645167903563081f; 45 const float s = 0.13793103448275862068f; 46 const float k = 7.78703703703703703703f; 47 float fx = (r * 0.43393624408206207259f + g * 0.37619779063650710152f + 48 b * 0.18983429773803261441f); 49 float fy = (r * 0.2126729f + g * 0.7151522f + b * 0.0721750f); 50 float fz = (r * 0.01775381083562901744f + g * 0.10945087235996326905f + 51 b * 0.87263921028466483011f); 52 const float gamma = 1.0f / 3.0f; 53 float X = (fx > epsilon) ? powf(fx, gamma) - s : k * fx; 54 float Y = (fy > epsilon) ? powf(fy, gamma) - s : k * fy; 55 float Z = (fz > epsilon) ? powf(fz, gamma) - s : k * fz; 56 *L = Y * 1.16f; 57 *A = (0.39181818181818181818f + 2.27272727272727272727f * (X - Y)); 58 *B = (0.49045454545454545454f + 0.90909090909090909090f * (Y - Z)); 59 } 60 61 StatusOr<Image3F> Rgb2Lab(const Image3F& in) { 62 JXL_ASSIGN_OR_RETURN(Image3F out, Image3F::Create(in.xsize(), in.ysize())); 63 for (size_t y = 0; y < in.ysize(); ++y) { 64 const float* JXL_RESTRICT row_in0 = in.PlaneRow(0, y); 65 const float* JXL_RESTRICT row_in1 = in.PlaneRow(1, y); 66 const float* JXL_RESTRICT row_in2 = in.PlaneRow(2, y); 67 float* JXL_RESTRICT row_out0 = out.PlaneRow(0, y); 68 float* JXL_RESTRICT row_out1 = out.PlaneRow(1, y); 69 float* JXL_RESTRICT row_out2 = out.PlaneRow(2, y); 70 71 for (size_t x = 0; x < in.xsize(); ++x) { 72 Rgb2Lab(row_in0[x], row_in1[x], row_in2[x], &row_out0[x], &row_out1[x], 73 &row_out2[x]); 74 } 75 } 76 return out; 77 } 78 79 StatusOr<Image3F> Downsample(const Image3F& in, size_t fx, size_t fy) { 80 const size_t out_xsize = (in.xsize() + fx - 1) / fx; 81 const size_t out_ysize = (in.ysize() + fy - 1) / fy; 82 JXL_ASSIGN_OR_RETURN(Image3F out, Image3F::Create(out_xsize, out_ysize)); 83 const float normalize = 1.0f / (fx * fy); 84 for (size_t c = 0; c < 3; ++c) { 85 for (size_t oy = 0; oy < out_ysize; ++oy) { 86 float* JXL_RESTRICT row_out = out.PlaneRow(c, oy); 87 for (size_t ox = 0; ox < out_xsize; ++ox) { 88 float sum = 0.0f; 89 for (size_t iy = 0; iy < fy; ++iy) { 90 for (size_t ix = 0; ix < fx; ++ix) { 91 const size_t x = std::min(ox * fx + ix, in.xsize() - 1); 92 const size_t y = std::min(oy * fy + iy, in.ysize() - 1); 93 sum += in.PlaneRow(c, y)[x]; 94 } 95 } 96 row_out[ox] = sum * normalize; 97 } 98 } 99 } 100 return out; 101 } 102 103 void Multiply(const Image3F& a, const Image3F& b, Image3F* mul) { 104 for (size_t c = 0; c < 3; ++c) { 105 for (size_t y = 0; y < a.ysize(); ++y) { 106 const float* JXL_RESTRICT in1 = a.PlaneRow(c, y); 107 const float* JXL_RESTRICT in2 = b.PlaneRow(c, y); 108 float* JXL_RESTRICT out = mul->PlaneRow(c, y); 109 for (size_t x = 0; x < a.xsize(); ++x) { 110 out[x] = in1[x] * in2[x]; 111 } 112 } 113 } 114 } 115 116 void RowColAvgP2(const ImageF& in, double* rp2, double* cp2) { 117 std::vector<double> ravg(in.ysize()); 118 std::vector<double> cavg(in.xsize()); 119 for (size_t y = 0; y < in.ysize(); ++y) { 120 const auto* row = in.Row(y); 121 for (size_t x = 0; x < in.xsize(); ++x) { 122 const float val = row[x]; 123 ravg[y] += val; 124 cavg[x] += val; 125 } 126 } 127 std::sort(ravg.begin(), ravg.end()); 128 std::sort(cavg.begin(), cavg.end()); 129 *rp2 = ravg[ravg.size() / 50] / in.xsize(); 130 *cp2 = cavg[cavg.size() / 50] / in.ysize(); 131 } 132 133 class StreamingAverage { 134 public: 135 void Add(const float v) { 136 // Numerically stable method. 137 double delta = v - result_; 138 n_ += 1; 139 result_ += delta / n_; 140 } 141 142 double Get() const { return result_; } 143 144 private: 145 double result_ = 0.0; 146 size_t n_ = 0; 147 }; 148 149 void EdgeDiffMap(const Image3F& img1, const Image3F& mu1, const Image3F& img2, 150 const Image3F& mu2, Image3F* out, double* plane_avg) { 151 for (size_t c = 0; c < 3; ++c) { 152 StreamingAverage avg; 153 for (size_t y = 0; y < img1.ysize(); ++y) { 154 const float* JXL_RESTRICT row1 = img1.PlaneRow(c, y); 155 const float* JXL_RESTRICT row2 = img2.PlaneRow(c, y); 156 const float* JXL_RESTRICT rowm1 = mu1.PlaneRow(c, y); 157 const float* JXL_RESTRICT rowm2 = mu2.PlaneRow(c, y); 158 float* JXL_RESTRICT row_out = out->PlaneRow(c, y); 159 for (size_t x = 0; x < img1.xsize(); ++x) { 160 float edgediff = std::max( 161 std::abs(row2[x] - rowm2[x]) - std::abs(row1[x] - rowm1[x]), 0.0f); 162 row_out[x] = 1.0f - edgediff; 163 avg.Add(row_out[x]); 164 } 165 } 166 plane_avg[c] = avg.Get(); 167 } 168 } 169 170 // Temporary storage for Gaussian blur, reused for multiple images. 171 class Blur { 172 public: 173 static StatusOr<Blur> Create(const size_t xsize, const size_t ysize) { 174 Blur result; 175 JXL_ASSIGN_OR_RETURN(result.temp_, ImageF::Create(xsize, ysize)); 176 return result; 177 } 178 179 void operator()(const ImageF& in, ImageF* JXL_RESTRICT out) { 180 FastGaussian( 181 rg_, in.xsize(), in.ysize(), [&](size_t y) { return in.ConstRow(y); }, 182 [&](size_t y) { return temp_.Row(y); }, 183 [&](size_t y) { return out->Row(y); }); 184 } 185 186 StatusOr<Image3F> operator()(const Image3F& in) { 187 JXL_ASSIGN_OR_RETURN(Image3F out, Image3F::Create(in.xsize(), in.ysize())); 188 operator()(in.Plane(0), &out.Plane(0)); 189 operator()(in.Plane(1), &out.Plane(1)); 190 operator()(in.Plane(2), &out.Plane(2)); 191 return out; 192 } 193 194 // Allows reusing across scales. 195 void ShrinkTo(const size_t xsize, const size_t ysize) { 196 temp_.ShrinkTo(xsize, ysize); 197 } 198 199 private: 200 Blur() : rg_(jxl::CreateRecursiveGaussian(1.5)) {} 201 hwy::AlignedUniquePtr<jxl::RecursiveGaussian> rg_; 202 ImageF temp_; 203 }; 204 205 void SSIMMap(const Image3F& m1, const Image3F& m2, const Image3F& s11, 206 const Image3F& s22, const Image3F& s12, Image3F* out, 207 double* plane_averages) { 208 for (size_t c = 0; c < 3; ++c) { 209 StreamingAverage avg; 210 for (size_t y = 0; y < out->ysize(); ++y) { 211 const float* JXL_RESTRICT row_m1 = m1.PlaneRow(c, y); 212 const float* JXL_RESTRICT row_m2 = m2.PlaneRow(c, y); 213 const float* JXL_RESTRICT row_s11 = s11.PlaneRow(c, y); 214 const float* JXL_RESTRICT row_s22 = s22.PlaneRow(c, y); 215 const float* JXL_RESTRICT row_s12 = s12.PlaneRow(c, y); 216 float* JXL_RESTRICT row_out = out->PlaneRow(c, y); 217 for (size_t x = 0; x < out->xsize(); ++x) { 218 float mu1 = row_m1[x]; 219 float mu2 = row_m2[x]; 220 float mu11 = mu1 * mu1; 221 float mu22 = mu2 * mu2; 222 float mu12 = mu1 * mu2; 223 float nom_m = 2 * mu12 + kC1; 224 float nom_s = 2 * (row_s12[x] - mu12) + kC2; 225 float denom_m = mu11 + mu22 + kC1; 226 float denom_s = (row_s11[x] - mu11) + (row_s22[x] - mu22) + kC2; 227 row_out[x] = (nom_m * nom_s) / (denom_m * denom_s); 228 avg.Add(row_out[x]); 229 } 230 } 231 plane_averages[c] = avg.Get(); 232 } 233 } 234 235 } // namespace 236 237 double Ssimulacra::Score() const { 238 double ssim = 0.0; 239 double ssim_max = 0.0; 240 for (size_t c = 0; c < 3; ++c) { 241 for (size_t scale = 0; scale < scales.size(); ++scale) { 242 ssim += kScaleWeights[scale][c] * scales[scale].avg_ssim[c]; 243 ssim_max += kScaleWeights[scale][c]; 244 ssim += kMinScaleWeights[scale][c] * scales[scale].min_ssim[c]; 245 ssim_max += kMinScaleWeights[scale][c]; 246 } 247 if (!simple) { 248 ssim += kEdgeWeight[c] * avg_edgediff[c]; 249 ssim_max += kEdgeWeight[c]; 250 ssim += kGridWeight[c] * 251 (row_p2[0][c] + row_p2[1][c] + col_p2[0][c] + col_p2[1][c]); 252 ssim_max += 4.0 * kGridWeight[c]; 253 } 254 } 255 double dssim = ssim_max / ssim - 1.0; 256 return std::min(1.0, std::max(0.0, dssim)); 257 } 258 259 inline void PrintItem(const char* name, int scale, const double* vals, 260 const double* w) { 261 printf("scale %d %s = [%.10f %.10f %.10f] w = [%.5f %.5f %.5f]\n", scale, 262 name, vals[0], vals[1], vals[2], w[0], w[1], w[2]); 263 } 264 265 void Ssimulacra::PrintDetails() const { 266 for (size_t s = 0; s < scales.size(); ++s) { 267 if (s < kNumScales) { 268 PrintItem("avg ssim", s, scales[s].avg_ssim, kScaleWeights[s]); 269 PrintItem("min ssim", s, scales[s].min_ssim, kMinScaleWeights[s]); 270 } 271 if (s == 0 && !simple) { 272 PrintItem("avg edif", s, avg_edgediff, kEdgeWeight); 273 PrintItem("rp2 ssim", s, &row_p2[0][0], kGridWeight); 274 PrintItem("cp2 ssim", s, &col_p2[0][0], kGridWeight); 275 PrintItem("rp2 edif", s, &row_p2[1][0], kGridWeight); 276 PrintItem("cp2 edif", s, &col_p2[1][0], kGridWeight); 277 } 278 } 279 } 280 281 StatusOr<Ssimulacra> ComputeDiff(const Image3F& orig, const Image3F& distorted, 282 bool simple) { 283 Ssimulacra ssimulacra; 284 285 ssimulacra.simple = simple; 286 JXL_ASSIGN_OR_RETURN(Image3F img1, Rgb2Lab(orig)); 287 JXL_ASSIGN_OR_RETURN(Image3F img2, Rgb2Lab(distorted)); 288 289 JXL_ASSIGN_OR_RETURN(Image3F mul, 290 Image3F::Create(orig.xsize(), orig.ysize())); 291 JXL_ASSIGN_OR_RETURN(Blur blur, Blur::Create(img1.xsize(), img1.ysize())); 292 293 for (int scale = 0; scale < kNumScales; scale++) { 294 if (img1.xsize() < 8 || img1.ysize() < 8) { 295 break; 296 } 297 if (scale) { 298 JXL_ASSIGN_OR_RETURN(img1, Downsample(img1, 2, 2)); 299 JXL_ASSIGN_OR_RETURN(img2, Downsample(img2, 2, 2)); 300 } 301 mul.ShrinkTo(img1.xsize(), img2.ysize()); 302 blur.ShrinkTo(img1.xsize(), img2.ysize()); 303 304 Multiply(img1, img1, &mul); 305 JXL_ASSIGN_OR_RETURN(Image3F sigma1_sq, blur(mul)); 306 307 Multiply(img2, img2, &mul); 308 JXL_ASSIGN_OR_RETURN(Image3F sigma2_sq, blur(mul)); 309 310 Multiply(img1, img2, &mul); 311 JXL_ASSIGN_OR_RETURN(Image3F sigma12, blur(mul)); 312 313 JXL_ASSIGN_OR_RETURN(Image3F mu1, blur(img1)); 314 JXL_ASSIGN_OR_RETURN(Image3F mu2, blur(img2)); 315 // Reuse mul as "ssim_map". 316 SsimulacraScale sscale; 317 SSIMMap(mu1, mu2, sigma1_sq, sigma2_sq, sigma12, &mul, sscale.avg_ssim); 318 319 JXL_ASSIGN_OR_RETURN(const Image3F ssim_map, Downsample(mul, 4, 4)); 320 for (size_t c = 0; c < 3; c++) { 321 float minval; 322 float maxval; 323 ImageMinMax(ssim_map.Plane(c), &minval, &maxval); 324 sscale.min_ssim[c] = static_cast<double>(minval); 325 } 326 ssimulacra.scales.push_back(sscale); 327 328 if (scale == 0 && !simple) { 329 Image3F* edgediff = &sigma1_sq; // reuse 330 EdgeDiffMap(img1, mu1, img2, mu2, edgediff, ssimulacra.avg_edgediff); 331 for (size_t c = 0; c < 3; c++) { 332 RowColAvgP2(ssim_map.Plane(c), &ssimulacra.row_p2[0][c], 333 &ssimulacra.col_p2[0][c]); 334 RowColAvgP2(edgediff->Plane(c), &ssimulacra.row_p2[1][c], 335 &ssimulacra.col_p2[1][c]); 336 } 337 } 338 } 339 return ssimulacra; 340 } 341 342 } // namespace ssimulacra