libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

benchmark_xl.cc (38846B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include <jxl/cms.h>
      7 #include <jxl/decode.h>
      8 #include <jxl/types.h>
      9 
     10 #include <algorithm>
     11 #include <cmath>
     12 #include <cstdint>
     13 #include <cstdio>
     14 #include <cstdlib>
     15 #include <cstring>
     16 #include <memory>
     17 #include <mutex>
     18 #include <numeric>
     19 #include <string>
     20 #include <thread>
     21 #include <utility>
     22 #include <vector>
     23 
     24 #include "lib/extras/codec.h"
     25 #include "lib/extras/dec/color_hints.h"
     26 #include "lib/extras/dec/decode.h"
     27 #include "lib/extras/enc/apng.h"
     28 #include "lib/extras/metrics.h"
     29 #include "lib/extras/packed_image.h"
     30 #include "lib/extras/packed_image_convert.h"
     31 #include "lib/jxl/base/compiler_specific.h"
     32 #include "lib/jxl/base/data_parallel.h"
     33 #include "lib/jxl/base/printf_macros.h"
     34 #include "lib/jxl/base/random.h"
     35 #include "lib/jxl/base/span.h"
     36 #include "lib/jxl/base/status.h"
     37 #include "lib/jxl/butteraugli/butteraugli.h"
     38 #include "lib/jxl/cache_aligned.h"
     39 #include "lib/jxl/codec_in_out.h"
     40 #include "lib/jxl/color_encoding_internal.h"
     41 #include "lib/jxl/enc_butteraugli_comparator.h"
     42 #include "lib/jxl/image.h"
     43 #include "lib/jxl/image_bundle.h"
     44 #include "lib/jxl/image_ops.h"
     45 #include "lib/jxl/jpeg/enc_jpeg_data.h"
     46 #include "tools/benchmark/benchmark_args.h"
     47 #include "tools/benchmark/benchmark_codec.h"
     48 #include "tools/benchmark/benchmark_file_io.h"
     49 #include "tools/benchmark/benchmark_stats.h"
     50 #include "tools/benchmark/benchmark_utils.h"
     51 #include "tools/codec_config.h"
     52 #include "tools/file_io.h"
     53 #include "tools/speed_stats.h"
     54 #include "tools/ssimulacra2.h"
     55 #include "tools/thread_pool_internal.h"
     56 
     57 namespace jpegxl {
     58 namespace tools {
     59 namespace {
     60 
     61 using ::jxl::ButteraugliParams;
     62 using ::jxl::Bytes;
     63 using ::jxl::CodecInOut;
     64 using ::jxl::ColorEncoding;
     65 using ::jxl::Image3F;
     66 using ::jxl::ImageBundle;
     67 using ::jxl::ImageF;
     68 using ::jxl::JxlButteraugliComparator;
     69 using ::jxl::Rng;
     70 using ::jxl::Status;
     71 using ::jxl::ThreadPool;
     72 using ::jxl::extras::PackedPixelFile;
     73 
     74 Status WriteImage(const Image3F& image, ThreadPool* pool,
     75                   const std::string& filename) {
     76   JxlPixelFormat format = {3, JXL_TYPE_UINT8, JXL_BIG_ENDIAN, 0};
     77   PackedPixelFile ppf = jxl::extras::ConvertImage3FToPackedPixelFile(
     78       image, ColorEncoding::SRGB(), format, pool);
     79   std::vector<uint8_t> encoded;
     80   return Encode(ppf, filename, &encoded, pool) && WriteFile(filename, encoded);
     81 }
     82 
     83 Status ReadPNG(const std::string& filename, Image3F* image) {
     84   CodecInOut io;
     85   std::vector<uint8_t> encoded;
     86   JXL_CHECK(ReadFile(filename, &encoded));
     87   JXL_CHECK(
     88       jxl::SetFromBytes(jxl::Bytes(encoded), jxl::extras::ColorHints(), &io));
     89   JXL_ASSIGN_OR_DIE(*image, Image3F::Create(io.xsize(), io.ysize()));
     90   CopyImageTo(*io.Main().color(), image);
     91   return true;
     92 }
     93 
     94 Status CreateNonSRGBICCProfile(PackedPixelFile* ppf) {
     95   ColorEncoding color_encoding;
     96   JXL_RETURN_IF_ERROR(color_encoding.FromExternal(ppf->color_encoding));
     97   if (color_encoding.ICC().empty()) {
     98     return JXL_FAILURE("Invalid color encoding.");
     99   }
    100   if (!color_encoding.IsSRGB()) {
    101     ppf->icc.assign(color_encoding.ICC().begin(), color_encoding.ICC().end());
    102   }
    103   return true;
    104 }
    105 
    106 std::string CodecToExtension(const std::string& codec_name, char sep) {
    107   std::string result;
    108   // Add in the parameters of the codec_name in reverse order, so that the
    109   // name of the file format (e.g. jxl) is last.
    110   int pos = static_cast<int>(codec_name.size()) - 1;
    111   while (pos > 0) {
    112     int prev = codec_name.find_last_of(sep, pos);
    113     if (prev > pos) prev = -1;
    114     result += '.' + codec_name.substr(prev + 1, pos - prev);
    115     pos = prev - 1;
    116   }
    117   return result;
    118 }
    119 
    120 void DoCompress(const std::string& filename, const PackedPixelFile& ppf,
    121                 const std::vector<std::string>& extra_metrics_commands,
    122                 ImageCodec* codec, ThreadPool* inner_pool,
    123                 std::vector<uint8_t>* compressed, BenchmarkStats* s) {
    124   ++s->total_input_files;
    125 
    126   if (ppf.frames.size() != 1) {
    127     // Multiple frames not supported.
    128     s->total_errors++;
    129     if (!Args()->silent_errors) {
    130       JXL_WARNING("multiframe input image not supported %s", filename.c_str());
    131     }
    132     return;
    133   }
    134   const size_t xsize = ppf.info.xsize;
    135   const size_t ysize = ppf.info.ysize;
    136   const size_t input_pixels = xsize * ysize;
    137 
    138   jpegxl::tools::SpeedStats speed_stats;
    139   jpegxl::tools::SpeedStats::Summary summary;
    140 
    141   bool valid = true;  // false if roundtrip, encoding or decoding errors occur.
    142 
    143   if (!Args()->decode_only && (xsize == 0 || ysize == 0)) {
    144     // This means the benchmark couldn't load the image, e.g. due to invalid
    145     // ICC profile. Warning message about that was already printed. Continue
    146     // this function to indicate it as error in the stats.
    147     valid = false;
    148   }
    149 
    150   std::string ext = FileExtension(filename);
    151   if (valid && !Args()->decode_only) {
    152     for (size_t i = 0; i < Args()->encode_reps; ++i) {
    153       if (codec->CanRecompressJpeg() && (ext == ".jpg" || ext == ".jpeg")) {
    154         std::vector<uint8_t> data_in;
    155         JXL_CHECK(ReadFile(filename, &data_in));
    156         JXL_CHECK(
    157             codec->RecompressJpeg(filename, data_in, compressed, &speed_stats));
    158       } else {
    159         Status status = codec->Compress(filename, ppf, inner_pool, compressed,
    160                                         &speed_stats);
    161         if (!status) {
    162           valid = false;
    163           if (!Args()->silent_errors) {
    164             std::string message = codec->GetErrorMessage();
    165             if (!message.empty()) {
    166               fprintf(stderr, "Error in %s codec: %s\n",
    167                       codec->description().c_str(), message.c_str());
    168             } else {
    169               fprintf(stderr, "Error in %s codec\n",
    170                       codec->description().c_str());
    171             }
    172           }
    173         }
    174       }
    175     }
    176     JXL_CHECK(speed_stats.GetSummary(&summary));
    177     s->total_time_encode += summary.central_tendency;
    178   }
    179 
    180   if (valid && Args()->decode_only) {
    181     std::vector<uint8_t> data_in;
    182     JXL_CHECK(ReadFile(filename, &data_in));
    183     compressed->insert(compressed->end(), data_in.begin(), data_in.end());
    184   }
    185 
    186   // Decompress
    187   PackedPixelFile ppf2;
    188   if (valid) {
    189     speed_stats = jpegxl::tools::SpeedStats();
    190     for (size_t i = 0; i < Args()->decode_reps; ++i) {
    191       if (!codec->Decompress(filename, Bytes(*compressed), inner_pool, &ppf2,
    192                              &speed_stats)) {
    193         if (!Args()->silent_errors) {
    194           fprintf(stderr,
    195                   "%s failed to decompress encoded image. Original source:"
    196                   " %s\n",
    197                   codec->description().c_str(), filename.c_str());
    198         }
    199         valid = false;
    200       }
    201     }
    202     for (const auto& frame : ppf2.frames) {
    203       s->total_input_pixels += frame.color.xsize * frame.color.ysize;
    204     }
    205     JXL_CHECK(speed_stats.GetSummary(&summary));
    206     s->total_time_decode += summary.central_tendency;
    207   }
    208 
    209   std::string name = FileBaseName(filename);
    210   std::string codec_name = codec->description();
    211 
    212   if (!valid) {
    213     s->total_errors++;
    214   }
    215 
    216   if (ppf.frames.size() != ppf2.frames.size()) {
    217     if (!Args()->silent_errors) {
    218       // Animated gifs not supported yet?
    219       fprintf(stderr,
    220               "Frame sizes not equal, is this an animated gif? %s %s %" PRIuS
    221               " %" PRIuS "\n",
    222               codec_name.c_str(), name.c_str(), ppf.frames.size(),
    223               ppf2.frames.size());
    224     }
    225     valid = false;
    226   }
    227 
    228   bool skip_butteraugli = Args()->skip_butteraugli || Args()->decode_only;
    229   ImageF distmap;
    230   float distance = 1.0f;
    231 
    232   if (valid && !skip_butteraugli) {
    233     CodecInOut ppf_io;
    234     JXL_CHECK(ConvertPackedPixelFileToCodecInOut(ppf, inner_pool, &ppf_io));
    235     CodecInOut ppf2_io;
    236     JXL_CHECK(ConvertPackedPixelFileToCodecInOut(ppf2, inner_pool, &ppf2_io));
    237     const ImageBundle& ib1 = ppf_io.Main();
    238     const ImageBundle& ib2 = ppf2_io.Main();
    239     if (jxl::SameSize(ppf, ppf2)) {
    240       ButteraugliParams params;
    241       // Hack the default intensity target value to be 80.0, the intensity
    242       // target of sRGB images and a more reasonable viewing default than
    243       // JPEG XL file format's default.
    244       // TODO(szabadka) Support different intensity targets as well.
    245       params.intensity_target = 80.0;
    246 
    247       const JxlCmsInterface& cms = *JxlGetDefaultCms();
    248       JxlButteraugliComparator comparator(params, cms);
    249       JXL_CHECK(ComputeScore(ib1, ib2, &comparator, cms, &distance, &distmap,
    250                              inner_pool, codec->IgnoreAlpha()));
    251     } else {
    252       // TODO(veluca): re-upsample and compute proper distance.
    253       distance = 1e+4f;
    254       JXL_ASSIGN_OR_DIE(distmap, ImageF::Create(1, 1));
    255       distmap.Row(0)[0] = distance;
    256     }
    257     // Update stats
    258     s->psnr +=
    259         compressed->empty()
    260             ? 0
    261             : jxl::ComputePSNR(ib1, ib2, *JxlGetDefaultCms()) * input_pixels;
    262     s->distance_p_norm +=
    263         ComputeDistanceP(distmap, ButteraugliParams(), Args()->error_pnorm) *
    264         input_pixels;
    265     JXL_ASSIGN_OR_DIE(Msssim msssim, ComputeSSIMULACRA2(ib1, ib2));
    266     s->ssimulacra2 += msssim.Score() * input_pixels;
    267     s->max_distance = std::max(s->max_distance, distance);
    268     s->distances.push_back(distance);
    269   }
    270 
    271   s->total_compressed_size += compressed->size();
    272   s->total_adj_compressed_size += compressed->size() * std::max(1.0f, distance);
    273   codec->GetMoreStats(s);
    274 
    275   if (Args()->save_compressed || Args()->save_decompressed) {
    276     std::string dir = FileDirName(filename);
    277     std::string outdir =
    278         Args()->output_dir.empty() ? dir + "/out" : Args()->output_dir;
    279     std::string compressed_fn =
    280         outdir + "/" + name + CodecToExtension(codec_name, ':');
    281     std::string decompressed_fn = compressed_fn + Args()->output_extension;
    282     std::string heatmap_fn;
    283     if (jxl::extras::GetAPNGEncoder()) {
    284       heatmap_fn = compressed_fn + ".heatmap.png";
    285     } else {
    286       heatmap_fn = compressed_fn + ".heatmap.ppm";
    287     }
    288     JXL_CHECK(MakeDir(outdir));
    289     if (Args()->save_compressed) {
    290       JXL_CHECK(WriteFile(compressed_fn, *compressed));
    291     }
    292     if (Args()->save_decompressed && valid) {
    293       // TODO(szabadka): Handle Args()->mul_output
    294       std::vector<uint8_t> encoded;
    295       JXL_CHECK(jxl::Encode(ppf2, decompressed_fn, &encoded));
    296       JXL_CHECK(WriteFile(decompressed_fn, encoded));
    297       if (!skip_butteraugli) {
    298         float good = Args()->heatmap_good > 0.0f
    299                          ? Args()->heatmap_good
    300                          : jxl::ButteraugliFuzzyInverse(1.5);
    301         float bad = Args()->heatmap_bad > 0.0f
    302                         ? Args()->heatmap_bad
    303                         : jxl::ButteraugliFuzzyInverse(0.5);
    304         if (Args()->save_heatmap) {
    305           JXL_ASSIGN_OR_DIE(Image3F heatmap,
    306                             CreateHeatMapImage(distmap, good, bad));
    307           JXL_CHECK(WriteImage(heatmap, inner_pool, heatmap_fn));
    308         }
    309       }
    310     }
    311   }
    312   if (!extra_metrics_commands.empty()) {
    313     TemporaryFile tmp_in("original", "pfm");
    314     TemporaryFile tmp_out("decoded", "pfm");
    315     TemporaryFile tmp_res("result", "txt");
    316     std::string tmp_in_fn;
    317     std::string tmp_out_fn;
    318     std::string tmp_res_fn;
    319     JXL_CHECK(tmp_in.GetFileName(&tmp_in_fn));
    320     JXL_CHECK(tmp_out.GetFileName(&tmp_out_fn));
    321     JXL_CHECK(tmp_res.GetFileName(&tmp_res_fn));
    322 
    323     std::vector<uint8_t> encoded;
    324     JXL_CHECK(jxl::Encode(ppf, tmp_in_fn, &encoded));
    325     JXL_CHECK(WriteFile(tmp_in_fn, encoded));
    326     JXL_CHECK(jxl::Encode(ppf2, tmp_out_fn, &encoded));
    327     JXL_CHECK(WriteFile(tmp_out_fn, encoded));
    328     // TODO(szabadka) Handle custom intensity target.
    329     std::string intensity_target = "255";
    330     for (size_t i = 0; i < extra_metrics_commands.size(); i++) {
    331       float res = nanf("");
    332       bool error = false;
    333       if (RunCommand(extra_metrics_commands[i],
    334                      {tmp_in_fn, tmp_out_fn, tmp_res_fn, intensity_target})) {
    335         FILE* f = fopen(tmp_res_fn.c_str(), "r");
    336         if (fscanf(f, "%f", &res) != 1) {
    337           error = true;
    338         }
    339         fclose(f);
    340       } else {
    341         error = true;
    342       }
    343       if (error) {
    344         fprintf(stderr,
    345                 "WARNING: Computation of metric with command %s failed\n",
    346                 extra_metrics_commands[i].c_str());
    347       }
    348       s->extra_metrics.push_back(res);
    349     }
    350   }
    351 
    352   if (Args()->show_progress) {
    353     fprintf(stderr, ".");
    354     fflush(stderr);
    355   }
    356 }
    357 
    358 // Makes a base64 data URI for embedded image in HTML
    359 std::string Base64Image(const std::string& filename) {
    360   std::vector<uint8_t> bytes;
    361   if (!ReadFile(filename, &bytes)) {
    362     return "";
    363   }
    364   static const char* symbols =
    365       "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    366   std::string result;
    367   for (size_t i = 0; i < bytes.size(); i += 3) {
    368     uint8_t o0 = bytes[i + 0];
    369     uint8_t o1 = (i + 1 < bytes.size()) ? bytes[i + 1] : 0;
    370     uint8_t o2 = (i + 2 < bytes.size()) ? bytes[i + 2] : 0;
    371     uint32_t value = (o0 << 16) | (o1 << 8) | o2;
    372     for (size_t j = 0; j < 4; j++) {
    373       result += (i + j <= bytes.size()) ? symbols[(value >> (6 * (3 - j))) & 63]
    374                                         : '=';
    375     }
    376   }
    377   // NOTE: Chrome supports max 2MB of data this way for URLs, but appears to
    378   // support larger images anyway as long as it's embedded in the HTML file
    379   // itself. If more data is needed, use createObjectURL.
    380   return "data:image;base64," + result;
    381 }
    382 
// One benchmark work item: a single (image, method) combination together
// with the statistics gathered for it.
struct Task {
  // Codec instance for this task (presumably the one configured for the
  // method at `idx_method` - confirm where tasks are created).
  ImageCodecPtr codec;
  // Index of the input file in the list of filenames.
  size_t idx_image;
  // Index of the method in the list of methods.
  size_t idx_method;
  // Input image of this task; stored as a plain pointer.
  const PackedPixelFile* image;
  // Statistics of this task; merged per method when the summary is printed.
  BenchmarkStats stats;
};
    390 
    391 void WriteHtmlReport(const std::string& codec_desc,
    392                      const std::vector<std::string>& fnames,
    393                      const std::vector<const Task*>& tasks,
    394                      const std::vector<const PackedPixelFile*>& images,
    395                      bool add_heatmap, bool self_contained) {
    396   std::string toggle_js =
    397       "<script type=\"text/javascript\">\n"
    398       "  var codecname = '" +
    399       codec_desc + "';\n";
    400   if (add_heatmap) {
    401     toggle_js += R"(
    402   var maintitle = codecname + ' - click images to toggle, press space to' +
    403       ' toggle all, h to toggle all heatmaps. Zoom in with CTRL+wheel or' +
    404       ' CTRL+plus.';
    405   document.title = maintitle;
    406   var counter = [];
    407   function setState(i, s) {
    408     var preview = document.getElementById("preview" + i);
    409     var orig = document.getElementById("orig" + i);
    410     var hm = document.getElementById("hm" + i);
    411     if (s == 0) {
    412       preview.style.display = 'none';
    413       orig.style.display = 'block';
    414       hm.style.display = 'none';
    415     } else if (s == 1) {
    416       preview.style.display = 'block';
    417       orig.style.display = 'none';
    418       hm.style.display = 'none';
    419     } else if (s == 2) {
    420       preview.style.display = 'none';
    421       orig.style.display = 'none';
    422       hm.style.display = 'block';
    423     }
    424   }
    425   function toggle(i) {
    426     for (index = counter.length; index <= i; index++) {
    427       counter.push(1);
    428     }
    429     setState(i, counter[i]);
    430     counter[i] = (counter[i] + 1) % 3;
    431     document.title = maintitle;
    432   }
    433   var toggleall_state = 1;
    434   document.body.onkeydown = function(e) {
    435     // space (32) to toggle orig/compr, 'h' (72) to toggle heatmap/compr
    436     if (e.keyCode == 32 || e.keyCode == 72) {
    437       var divs = document.getElementsByTagName('div');
    438       var key_state = (e.keyCode == 32) ? 0 : 2;
    439       toggleall_state = (toggleall_state == key_state) ? 1 : key_state;
    440       document.title = codecname + ' - ' + (toggleall_state == 0 ?
    441           'originals' : (toggleall_state == 1 ? 'compressed' : 'heatmaps'));
    442       for (var i = 0; i < divs.length; i++) {
    443         setState(i, toggleall_state);
    444       }
    445       return false;
    446     }
    447   };
    448 </script>
    449 )";
    450   } else {
    451     toggle_js += R"(
    452   var maintitle = codecname + ' - click images to toggle, press space to' +
    453       ' toggle all. Zoom in with CTRL+wheel or CTRL+plus.';
    454   document.title = maintitle;
    455   var counter = [];
    456   function setState(i, s) {
    457     var preview = document.getElementById("preview" + i);
    458     var orig = document.getElementById("orig" + i);
    459     if (s == 0) {
    460       preview.style.display = 'none';
    461       orig.style.display = 'block';
    462     } else if (s == 1) {
    463       preview.style.display = 'block';
    464       orig.style.display = 'none';
    465     }
    466   }
    467   function toggle(i) {
    468     for (index = counter.length; index <= i; index++) {
    469       counter.push(1);
    470     }
    471     setState(i, counter[i]);
    472     counter[i] = 1 - counter[i];
    473     document.title = maintitle;
    474   }
    475   var toggleall_state = 1;
    476   document.body.onkeydown = function(e) {
    477     // space (32) to toggle orig/compr
    478     if (e.keyCode == 32) {
    479       var divs = document.getElementsByTagName('div');
    480       toggleall_state = 1 - toggleall_state;
    481       document.title = codecname + ' - ' + (toggleall_state == 0 ?
    482           'originals' : 'compressed');
    483       for (var i = 0; i < divs.length; i++) {
    484         setState(i, toggleall_state);
    485       }
    486       return false;
    487     }
    488   };
    489 </script>
    490 )";
    491   }
    492   std::string out_html;
    493   std::string outdir;
    494   out_html += "<body bgcolor=\"#000\">\n";
    495   out_html += "<style>img { image-rendering: pixelated; }</style>\n";
    496   std::string codec_name = codec_desc;
    497   // Make compatible for filename
    498   std::replace(codec_name.begin(), codec_name.end(), ':', '_');
    499   for (size_t i = 0; i < fnames.size(); ++i) {
    500     std::string name = FileBaseName(fnames[i]);
    501     std::string dir = FileDirName(fnames[i]);
    502     outdir = Args()->output_dir.empty() ? dir + "/out" : Args()->output_dir;
    503     std::string name_out = name + CodecToExtension(codec_name, '_');
    504     if (Args()->html_report_use_decompressed) {
    505       name_out += Args()->output_extension;
    506     }
    507     std::string heatmap_out =
    508         name + CodecToExtension(codec_name, '_') + ".heatmap.png";
    509 
    510     const std::string& fname_orig = fnames[i];
    511     std::string fname_out = outdir + "/" + name_out;
    512     std::string fname_heatmap = outdir + "/" + heatmap_out;
    513     std::string url_orig = Args()->originals_url.empty()
    514                                ? ("file://" + fnames[i])
    515                                : (Args()->originals_url + "/" + name);
    516     std::string url_out = name_out;
    517     std::string url_heatmap = heatmap_out;
    518     if (self_contained) {
    519       url_orig = Base64Image(fname_orig);
    520       url_out = Base64Image(fname_out);
    521       url_heatmap = Base64Image(fname_heatmap);
    522     }
    523     std::string number = StringPrintf("%" PRIuS, i);
    524     const PackedPixelFile& image = *images[i];
    525     size_t xsize = image.frames.size() == 1 ? image.info.xsize : 0;
    526     size_t ysize = image.frames.size() == 1 ? image.info.ysize : 0;
    527     std::string html_width = StringPrintf("%" PRIuS "px", xsize);
    528     std::string html_height = StringPrintf("%" PRIuS "px", ysize);
    529     double bpp = tasks[i]->stats.total_compressed_size * 8.0 /
    530                  tasks[i]->stats.total_input_pixels;
    531     double pnorm =
    532         tasks[i]->stats.distance_p_norm / tasks[i]->stats.total_input_pixels;
    533     double max_dist = tasks[i]->stats.max_distance;
    534     std::string compressed_title = StringPrintf(
    535         "compressed. bpp: %f, pnorm: %f, max dist: %f", bpp, pnorm, max_dist);
    536     out_html += "<div onclick=\"toggle(" + number +
    537                 ");\" style=\"display:inline-block;width:" + html_width +
    538                 ";height:" + html_height +
    539                 ";\">\n"
    540                 "  <img title=\"" +
    541                 compressed_title + "\" id=\"preview" + number + "\" src=";
    542     out_html += "\"" + url_out + "\"style=\"display:block;\"/>\n";
    543     out_html += R"(  <img title="original" id="orig)" + number + "\" src=";
    544     out_html += "\"" + url_orig + "\"style=\"display:none;\"/>\n";
    545     if (add_heatmap) {
    546       out_html = R"(  <img title="heatmap" id="hm)" + number + "\" src=";
    547       out_html += "\"" + url_heatmap + "\"style=\"display:none;\"/>\n";
    548     }
    549     out_html += "</div>\n";
    550   }
    551   out_html += "</body>\n";
    552   out_html += toggle_js;
    553   JXL_CHECK(WriteFile(outdir + "/index." + codec_name + ".html", out_html));
    554 }
    555 
    556 // Prints the detailed and aggregate statistics, in the correct order but as
    557 // soon as possible when multithreaded tasks are done.
    558 struct StatPrinter {
    559   StatPrinter(const std::vector<std::string>& methods,
    560               const std::vector<std::string>& extra_metrics_names,
    561               const std::vector<std::string>& fnames,
    562               const std::vector<Task>& tasks)
    563       : methods_(&methods),
    564         extra_metrics_names_(&extra_metrics_names),
    565         fnames_(&fnames),
    566         tasks_(&tasks),
    567         tasks_done_(0),
    568         stats_printed_(0),
    569         details_printed_(0) {
    570     stats_done_.resize(methods.size(), 0);
    571     details_done_.resize(tasks.size(), 0);
    572     max_fname_width_ = 0;
    573     for (const auto& fname : fnames) {
    574       max_fname_width_ = std::max(max_fname_width_, FileBaseName(fname).size());
    575     }
    576     max_method_width_ = 0;
    577     for (const auto& method : methods) {
    578       max_method_width_ =
    579           std::max(max_method_width_, FileBaseName(method).size());
    580     }
    581   }
    582 
    583   void TaskDone(size_t task_index, const Task& t) {
    584     std::lock_guard<std::mutex> guard(mutex);
    585     tasks_done_++;
    586     if (Args()->print_details || Args()->show_progress) {
    587       if (Args()->print_details) {
    588         // Render individual results as soon as they are ready and all previous
    589         // ones in task order are ready.
    590         details_done_[task_index] = 1;
    591         if (task_index == details_printed_) {
    592           while (details_printed_ < tasks_->size() &&
    593                  details_done_[details_printed_]) {
    594             PrintDetails((*tasks_)[details_printed_]);
    595             details_printed_++;
    596           }
    597         }
    598       }
    599       // When using "show_progress" or "print_details", the table must be
    600       // rendered at the very end, else the details or progress would be
    601       // rendered in-between the table rows.
    602       if (tasks_done_ == tasks_->size()) {
    603         PrintStatsHeader();
    604         for (size_t i = 0; i < methods_->size(); i++) {
    605           PrintStats((*methods_)[i], i);
    606         }
    607         PrintStatsFooter();
    608       }
    609     } else {
    610       if (tasks_done_ == 1) {
    611         PrintStatsHeader();
    612       }
    613       // Render lines of the table as soon as it is ready and all previous
    614       // lines have been printed.
    615       stats_done_[t.idx_method]++;
    616       if (stats_done_[t.idx_method] == fnames_->size() &&
    617           t.idx_method == stats_printed_) {
    618         while (stats_printed_ < stats_done_.size() &&
    619                stats_done_[stats_printed_] == fnames_->size()) {
    620           PrintStats((*methods_)[stats_printed_], stats_printed_);
    621           stats_printed_++;
    622         }
    623       }
    624       if (tasks_done_ == tasks_->size()) {
    625         PrintStatsFooter();
    626       }
    627     }
    628   }
    629 
    630   void PrintDetails(const Task& t) const {
    631     double comp_bpp =
    632         t.stats.total_compressed_size * 8.0 / t.stats.total_input_pixels;
    633     double p_norm = t.stats.distance_p_norm / t.stats.total_input_pixels;
    634     double psnr = t.stats.psnr / t.stats.total_input_pixels;
    635     double ssimulacra2 = t.stats.ssimulacra2 / t.stats.total_input_pixels;
    636     double bpp_p_norm = p_norm * comp_bpp;
    637 
    638     const double adj_comp_bpp =
    639         t.stats.total_adj_compressed_size * 8.0 / t.stats.total_input_pixels;
    640 
    641     size_t pixels = t.stats.total_input_pixels;
    642 
    643     const double enc_mps =
    644         t.stats.total_input_pixels / (1000000.0 * t.stats.total_time_encode);
    645     const double dec_mps =
    646         t.stats.total_input_pixels / (1000000.0 * t.stats.total_time_decode);
    647     if (Args()->print_details_csv) {
    648       printf("%s,%s,%" PRIdS ",%" PRIdS ",%" PRIdS
    649              ",%.8f,%.8f,%.8f,%.8f,%.8f,%.8f,%.8f,%.8f",
    650              (*methods_)[t.idx_method].c_str(),
    651              FileBaseName((*fnames_)[t.idx_image]).c_str(),
    652              t.stats.total_errors, t.stats.total_compressed_size, pixels,
    653              enc_mps, dec_mps, comp_bpp, t.stats.max_distance, psnr, p_norm,
    654              bpp_p_norm, adj_comp_bpp);
    655       for (float m : t.stats.extra_metrics) {
    656         printf(",%.8f", m);
    657       }
    658       printf("\n");
    659     } else {
    660       printf("%s", (*methods_)[t.idx_method].c_str());
    661       for (size_t i = (*methods_)[t.idx_method].size(); i <= max_method_width_;
    662            i++) {
    663         printf(" ");
    664       }
    665       printf("%s", FileBaseName((*fnames_)[t.idx_image]).c_str());
    666       for (size_t i = FileBaseName((*fnames_)[t.idx_image]).size();
    667            i <= max_fname_width_; i++) {
    668         printf(" ");
    669       }
    670       printf(
    671           "error:%" PRIdS "    size:%8" PRIdS "    pixels:%9" PRIdS
    672           "    enc_speed:%8.8f    dec_speed:%8.8f    bpp:%10.8f    dist:%10.8f"
    673           "    psnr:%10.8f    ssimulacra2:%.2f   p:%10.8f    bppp:%10.8f    "
    674           "qabpp:%10.8f ",
    675           t.stats.total_errors, t.stats.total_compressed_size, pixels, enc_mps,
    676           dec_mps, comp_bpp, t.stats.max_distance, psnr, ssimulacra2, p_norm,
    677           bpp_p_norm, adj_comp_bpp);
    678       for (size_t i = 0; i < t.stats.extra_metrics.size(); i++) {
    679         printf(" %s:%.8f", (*extra_metrics_names_)[i].c_str(),
    680                t.stats.extra_metrics[i]);
    681       }
    682       printf("\n");
    683     }
    684     fflush(stdout);
    685   }
    686 
    687   void PrintStats(const std::string& method, size_t idx_method) {
    688     // Assimilate all tasks with the same idx_method.
    689     BenchmarkStats method_stats;
    690     std::vector<const PackedPixelFile*> images;
    691     std::vector<const Task*> tasks;
    692     for (const Task& t : *tasks_) {
    693       if (t.idx_method == idx_method) {
    694         method_stats.Assimilate(t.stats);
    695         images.push_back(t.image);
    696         tasks.push_back(&t);
    697       }
    698     }
    699 
    700     std::string out;
    701 
    702     method_stats.PrintMoreStats();  // not concurrent
    703     out += method_stats.PrintLine(method, fnames_->size());
    704 
    705     if (Args()->write_html_report) {
    706       WriteHtmlReport(method, *fnames_, tasks, images,
    707                       Args()->save_heatmap && Args()->html_report_add_heatmap,
    708                       Args()->html_report_self_contained);
    709     }
    710 
    711     stats_aggregate_.push_back(
    712         method_stats.ComputeColumns(method, fnames_->size()));
    713 
    714     printf("%s", out.c_str());
    715     fflush(stdout);
    716   }
    717 
    718   void PrintStatsHeader() const {
    719     if (Args()->markdown) {
    720       if (Args()->show_progress) {
    721         fprintf(stderr, "\n");
    722         fflush(stderr);
    723       }
    724       printf("```\n");
    725     }
    726     if (fnames_->size() == 1) printf("%s\n", (*fnames_)[0].c_str());
    727     printf("%s", PrintHeader(*extra_metrics_names_).c_str());
    728     fflush(stdout);
    729   }
    730 
    731   void PrintStatsFooter() const {
    732     printf(
    733         "%s",
    734         PrintAggregate(extra_metrics_names_->size(), stats_aggregate_).c_str());
    735     if (Args()->markdown) printf("```\n");
    736     printf("\n");
    737     fflush(stdout);
    738   }
    739 
    // Pointers to the benchmark inputs. NOTE(review): presumably set by the
    // constructor from the vectors the caller passes in and not owned by this
    // object -- confirm against the constructor.
    const std::vector<std::string>* methods_;
    // Names of the extra metrics; used to label per-task extra metric values
    // and to build the header/aggregate.
    const std::vector<std::string>* extra_metrics_names_;
    // Input file names; its size is passed on as the number of images when
    // printing a method's summary line.
    const std::vector<std::string>* fnames_;
    // All tasks; PrintStats filters them by idx_method.
    const std::vector<Task>* tasks_;

    // NOTE(review): the bookkeeping fields below are not used in the visible
    // part of the class; presumably progress counters -- confirm.
    size_t tasks_done_;

    size_t stats_printed_;
    std::vector<size_t> stats_done_;

    size_t details_printed_;
    std::vector<size_t> details_done_;

    // NOTE(review): presumably column widths used to align output -- confirm.
    size_t max_fname_width_;
    size_t max_method_width_;

    // One entry appended per PrintStats call; consumed by PrintStatsFooter.
    std::vector<std::vector<ColumnValue>> stats_aggregate_;

    // NOTE(review): presumably serializes the state above when tasks finish
    // concurrently -- confirm where it is locked.
    std::mutex mutex;
    759 };
    760 
    761 class Benchmark {
    762   using StringVec = std::vector<std::string>;
    763 
    764  public:
    765   // Return the exit code of the program.
    766   static int Run() {
    767     int ret = EXIT_SUCCESS;
    768     {
    769       const StringVec methods = GetMethods();
    770       const StringVec extra_metrics_names = GetExtraMetricsNames();
    771       const StringVec extra_metrics_commands = GetExtraMetricsCommands();
    772       const StringVec fnames = GetFilenames();
    773       // (non-const because Task.stats are updated)
    774       std::vector<Task> tasks = CreateTasks(methods, fnames);
    775 
    776       std::unique_ptr<ThreadPoolInternal> pool;
    777       std::vector<std::unique_ptr<ThreadPoolInternal>> inner_pools;
    778       InitThreads(tasks.size(), &pool, &inner_pools);
    779 
    780       std::vector<PackedPixelFile> loaded_images = LoadImages(fnames, &*pool);
    781 
    782       if (RunTasks(methods, extra_metrics_names, extra_metrics_commands, fnames,
    783                    loaded_images, &*pool, inner_pools, &tasks) != 0) {
    784         ret = EXIT_FAILURE;
    785         if (!Args()->silent_errors) {
    786           fprintf(stderr, "There were error(s) in the benchmark.\n");
    787         }
    788       }
    789     }
    790 
    791     jxl::CacheAligned::PrintStats();
    792     return ret;
    793   }
    794 
    795  private:
    796   static size_t NumOuterThreads(const size_t num_hw_threads,
    797                                 const size_t num_tasks) {
    798     // Default to #cores
    799     size_t num_threads = num_hw_threads;
    800     if (Args()->num_threads >= 0) {
    801       num_threads = static_cast<size_t>(Args()->num_threads);
    802     }
    803 
    804     // As a safety precaution, limit the number of threads to 4x the number of
    805     // available CPUs.
    806     num_threads =
    807         std::min<size_t>(num_threads, 4 * std::thread::hardware_concurrency());
    808 
    809     // Don't create more threads than there are tasks (pointless/wasteful).
    810     num_threads = std::min(num_threads, num_tasks);
    811 
    812     // Just one thread is counterproductive.
    813     if (num_threads == 1) num_threads = 0;
    814 
    815     return num_threads;
    816   }
    817 
    818   static int NumInnerThreads(const size_t num_hw_threads,
    819                              const size_t num_threads) {
    820     size_t num_inner;
    821 
    822     // Default: distribute remaining cores among tasks.
    823     if (Args()->inner_threads < 0) {
    824       if (num_threads == 0) {
    825         num_inner = num_hw_threads;
    826       } else if (num_hw_threads <= num_threads) {
    827         num_inner = 1;
    828       } else {
    829         num_inner = (num_hw_threads - num_threads) / num_threads;
    830       }
    831     } else {
    832       num_inner = static_cast<size_t>(Args()->inner_threads);
    833     }
    834 
    835     // Just one thread is counterproductive.
    836     if (num_inner == 1) num_inner = 0;
    837 
    838     return num_inner;
    839   }
    840 
    841   static void InitThreads(
    842       size_t num_tasks, std::unique_ptr<ThreadPoolInternal>* pool,
    843       std::vector<std::unique_ptr<ThreadPoolInternal>>* inner_pools) {
    844     const size_t num_hw_threads = std::thread::hardware_concurrency();
    845     const size_t num_threads = NumOuterThreads(num_hw_threads, num_tasks);
    846     const size_t num_inner = NumInnerThreads(num_hw_threads, num_threads);
    847 
    848     fprintf(stderr,
    849             "%" PRIuS " total threads, %" PRIuS " tasks, %" PRIuS
    850             " threads, %" PRIuS " inner threads\n",
    851             num_hw_threads, num_tasks, num_threads, num_inner);
    852 
    853     pool->reset(new ThreadPoolInternal(num_threads));
    854     // Main thread OR worker threads in pool each get a possibly empty nested
    855     // pool (helps use all available cores when #tasks < #threads)
    856     for (size_t i = 0; i < std::max<size_t>(num_threads, 1); ++i) {
    857       inner_pools->emplace_back(new ThreadPoolInternal(num_inner));
    858     }
    859   }
    860 
    861   static StringVec GetMethods() {
    862     StringVec methods = SplitString(Args()->codec, ',');
    863     for (auto it = methods.begin(); it != methods.end();) {
    864       if (it->empty()) {
    865         it = methods.erase(it);
    866       } else {
    867         ++it;
    868       }
    869     }
    870     return methods;
    871   }
    872 
    873   static StringVec GetExtraMetricsNames() {
    874     StringVec metrics = SplitString(Args()->extra_metrics, ',');
    875     for (auto it = metrics.begin(); it != metrics.end();) {
    876       if (it->empty()) {
    877         it = metrics.erase(it);
    878       } else {
    879         *it = SplitString(*it, ':')[0];
    880         ++it;
    881       }
    882     }
    883     return metrics;
    884   }
    885 
    886   static StringVec GetExtraMetricsCommands() {
    887     StringVec metrics = SplitString(Args()->extra_metrics, ',');
    888     for (auto it = metrics.begin(); it != metrics.end();) {
    889       if (it->empty()) {
    890         it = metrics.erase(it);
    891       } else {
    892         auto s = SplitString(*it, ':');
    893         JXL_CHECK(s.size() == 2);
    894         *it = s[1];
    895         ++it;
    896       }
    897     }
    898     return metrics;
    899   }
    900 
    901   static StringVec SampleFromInput(const StringVec& fnames,
    902                                    const std::string& sample_tmp_dir,
    903                                    int num_samples, size_t size) {
    904     JXL_CHECK(!sample_tmp_dir.empty());
    905     fprintf(stderr, "Creating samples of %" PRIuS "x%" PRIuS " tiles...\n",
    906             size, size);
    907     StringVec fnames_out;
    908     std::vector<Image3F> images;
    909     std::vector<size_t> offsets;
    910     size_t total_num_tiles = 0;
    911     for (const auto& fname : fnames) {
    912       Image3F img;
    913       JXL_CHECK(ReadPNG(fname, &img));
    914       JXL_CHECK(img.xsize() >= size);
    915       JXL_CHECK(img.ysize() >= size);
    916       total_num_tiles += (img.xsize() - size + 1) * (img.ysize() - size + 1);
    917       offsets.push_back(total_num_tiles);
    918       images.emplace_back(std::move(img));
    919     }
    920     JXL_CHECK(MakeDir(sample_tmp_dir));
    921     Rng rng(0);
    922     for (int i = 0; i < num_samples; ++i) {
    923       int val = rng.UniformI(0, offsets.back());
    924       size_t idx = (std::lower_bound(offsets.begin(), offsets.end(), val) -
    925                     offsets.begin());
    926       JXL_CHECK(idx < images.size());
    927       const Image3F& img = images[idx];
    928       int x0 = rng.UniformI(0, img.xsize() - size);
    929       int y0 = rng.UniformI(0, img.ysize() - size);
    930       JXL_ASSIGN_OR_DIE(Image3F sample, Image3F::Create(size, size));
    931       for (size_t c = 0; c < 3; ++c) {
    932         for (size_t y = 0; y < size; ++y) {
    933           const float* JXL_RESTRICT row_in = img.PlaneRow(c, y0 + y);
    934           float* JXL_RESTRICT row_out = sample.PlaneRow(c, y);
    935           memcpy(row_out, &row_in[x0], size * sizeof(row_out[0]));
    936         }
    937       }
    938       std::string fn_output =
    939           StringPrintf("%s/%s.crop_%dx%d+%d+%d.png", sample_tmp_dir.c_str(),
    940                        FileBaseName(fnames[idx]).c_str(), size, size, x0, y0);
    941       ThreadPool* null_pool = nullptr;
    942       JXL_CHECK(WriteImage(sample, null_pool, fn_output));
    943       fnames_out.push_back(fn_output);
    944     }
    945     fprintf(stderr, "Created %d sample tiles\n", num_samples);
    946     return fnames_out;
    947   }
    948 
    949   static StringVec GetFilenames() {
    950     StringVec fnames;
    951     JXL_CHECK(MatchFiles(Args()->input, &fnames));
    952     if (fnames.empty()) {
    953       JXL_ABORT("No input file matches pattern: '%s'", Args()->input.c_str());
    954     }
    955     if (Args()->print_details) {
    956       std::sort(fnames.begin(), fnames.end());
    957     }
    958 
    959     if (Args()->num_samples > 0) {
    960       fnames = SampleFromInput(fnames, Args()->sample_tmp_dir,
    961                                Args()->num_samples, Args()->sample_dimensions);
    962     }
    963     return fnames;
    964   }
    965 
    966   // (Load only once, not for every codec)
    967   static std::vector<PackedPixelFile> LoadImages(const StringVec& fnames,
    968                                                  ThreadPool* pool) {
    969     std::vector<PackedPixelFile> loaded_images;
    970     loaded_images.resize(fnames.size());
    971     const auto process_image = [&](const uint32_t task, size_t /*thread*/) {
    972       const size_t i = static_cast<size_t>(task);
    973       Status ok = true;
    974 
    975       if (!Args()->decode_only) {
    976         std::vector<uint8_t> encoded;
    977         ok = ReadFile(fnames[i], &encoded);
    978         if (ok) {
    979           ok = jxl::extras::DecodeBytes(Bytes(encoded), Args()->color_hints,
    980                                         &loaded_images[i]);
    981         }
    982         if (ok && loaded_images[i].icc.empty()) {
    983           // Add ICC profile if the image is not in sRGB, because not all codecs
    984           // can handle the color_encoding enum.
    985           ok = CreateNonSRGBICCProfile(&loaded_images[i]);
    986         }
    987         if (ok && Args()->intensity_target != 0) {
    988           // TODO(szabadka) Respect Args()->intensity_target
    989         }
    990       }
    991       if (!ok) {
    992         if (!Args()->silent_errors) {
    993           fprintf(stderr, "Failed to load image %s\n", fnames[i].c_str());
    994         }
    995         return;
    996       }
    997 
    998       if (!Args()->decode_only && Args()->override_bitdepth != 0) {
    999         // TODO(szabadla) Respect Args()->override_bitdepth
   1000       }
   1001     };
   1002     JXL_CHECK(jxl::RunOnPool(pool, 0, static_cast<uint32_t>(fnames.size()),
   1003                              ThreadPool::NoInit, process_image, "Load images"));
   1004     return loaded_images;
   1005   }
   1006 
   1007   static std::vector<Task> CreateTasks(const StringVec& methods,
   1008                                        const StringVec& fnames) {
   1009     std::vector<Task> tasks;
   1010     tasks.reserve(methods.size() * fnames.size());
   1011     for (size_t idx_image = 0; idx_image < fnames.size(); ++idx_image) {
   1012       for (size_t idx_method = 0; idx_method < methods.size(); ++idx_method) {
   1013         tasks.emplace_back();
   1014         Task& t = tasks.back();
   1015         t.codec = CreateImageCodec(methods[idx_method]);
   1016         t.idx_image = idx_image;
   1017         t.idx_method = idx_method;
   1018         // t.stats is default-initialized.
   1019       }
   1020     }
   1021     JXL_ASSERT(tasks.size() == tasks.capacity());
   1022     return tasks;
   1023   }
   1024 
   1025   // Return the total number of errors.
   1026   static size_t RunTasks(
   1027       const StringVec& methods, const StringVec& extra_metrics_names,
   1028       const StringVec& extra_metrics_commands, const StringVec& fnames,
   1029       const std::vector<PackedPixelFile>& loaded_images, ThreadPool* pool,
   1030       const std::vector<std::unique_ptr<ThreadPoolInternal>>& inner_pools,
   1031       std::vector<Task>* tasks) {
   1032     StatPrinter printer(methods, extra_metrics_names, fnames, *tasks);
   1033     if (Args()->print_details_csv) {
   1034       // Print CSV header
   1035       printf(
   1036           "method,image,error,size,pixels,enc_speed,dec_speed,"
   1037           "bpp,dist,psnr,p,bppp,qabpp");
   1038       for (const std::string& s : extra_metrics_names) {
   1039         printf(",%s", s.c_str());
   1040       }
   1041       printf("\n");
   1042     }
   1043 
   1044     std::vector<uint64_t> errors_thread;
   1045     JXL_CHECK(jxl::RunOnPool(
   1046         pool, 0, tasks->size(),
   1047         [&](const size_t num_threads) {
   1048           // Reduce false sharing by only writing every 8th slot (64 bytes).
   1049           errors_thread.resize(8 * num_threads);
   1050           return true;
   1051         },
   1052         [&](const uint32_t i, const size_t thread) {
   1053           Task& t = (*tasks)[i];
   1054           const PackedPixelFile& image = loaded_images[t.idx_image];
   1055           t.image = &image;
   1056           std::vector<uint8_t> compressed;
   1057           DoCompress(fnames[t.idx_image], image, extra_metrics_commands,
   1058                      t.codec.get(), &*inner_pools[thread], &compressed,
   1059                      &t.stats);
   1060           printer.TaskDone(i, t);
   1061           errors_thread[8 * thread] += t.stats.total_errors;
   1062         },
   1063         "Benchmark tasks"));
   1064     if (Args()->show_progress) fprintf(stderr, "\n");
   1065     return std::accumulate(errors_thread.begin(), errors_thread.end(),
   1066                            static_cast<size_t>(0));
   1067   }
   1068 };
   1069 
   1070 int BenchmarkMain(int argc, const char** argv) {
   1071   fprintf(stderr, "benchmark_xl %s\n",
   1072           jpegxl::tools::CodecConfigString(JxlDecoderVersion()).c_str());
   1073 
   1074   JXL_CHECK(Args()->AddCommandLineOptions());
   1075 
   1076   if (!Args()->Parse(argc, argv)) {
   1077     fprintf(stderr, "Use '%s -h' for more information\n", argv[0]);
   1078     return 1;
   1079   }
   1080 
   1081   if (Args()->cmdline.HelpFlagPassed()) {
   1082     Args()->PrintHelp();
   1083     return 0;
   1084   }
   1085   if (!Args()->ValidateArgs()) {
   1086     fprintf(stderr, "Use '%s -h' for more information\n", argv[0]);
   1087     return 1;
   1088   }
   1089   return Benchmark::Run();
   1090 }
   1091 
   1092 }  // namespace
   1093 }  // namespace tools
   1094 }  // namespace jpegxl
   1095 
   1096 int main(int argc, const char** argv) {
   1097   return jpegxl::tools::BenchmarkMain(argc, argv);
   1098 }