libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_heuristics.cc (37055B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_heuristics.h"
      7 
      8 #include <jxl/cms_interface.h>
      9 #include <stddef.h>
     10 #include <stdint.h>
     11 
     12 #include <algorithm>
     13 #include <cstdlib>
     14 #include <limits>
     15 #include <memory>
     16 #include <numeric>
     17 #include <string>
     18 #include <utility>
     19 #include <vector>
     20 
     21 #include "lib/jxl/ac_context.h"
     22 #include "lib/jxl/ac_strategy.h"
     23 #include "lib/jxl/base/common.h"
     24 #include "lib/jxl/base/compiler_specific.h"
     25 #include "lib/jxl/base/data_parallel.h"
     26 #include "lib/jxl/base/override.h"
     27 #include "lib/jxl/base/status.h"
     28 #include "lib/jxl/butteraugli/butteraugli.h"
     29 #include "lib/jxl/chroma_from_luma.h"
     30 #include "lib/jxl/coeff_order.h"
     31 #include "lib/jxl/coeff_order_fwd.h"
     32 #include "lib/jxl/dec_xyb.h"
     33 #include "lib/jxl/enc_ac_strategy.h"
     34 #include "lib/jxl/enc_adaptive_quantization.h"
     35 #include "lib/jxl/enc_ar_control_field.h"
     36 #include "lib/jxl/enc_cache.h"
     37 #include "lib/jxl/enc_chroma_from_luma.h"
     38 #include "lib/jxl/enc_gaborish.h"
     39 #include "lib/jxl/enc_modular.h"
     40 #include "lib/jxl/enc_noise.h"
     41 #include "lib/jxl/enc_params.h"
     42 #include "lib/jxl/enc_patch_dictionary.h"
     43 #include "lib/jxl/enc_quant_weights.h"
     44 #include "lib/jxl/enc_splines.h"
     45 #include "lib/jxl/frame_dimensions.h"
     46 #include "lib/jxl/frame_header.h"
     47 #include "lib/jxl/image.h"
     48 #include "lib/jxl/image_ops.h"
     49 #include "lib/jxl/passes_state.h"
     50 #include "lib/jxl/quant_weights.h"
     51 
     52 namespace jxl {
     53 
     54 struct AuxOut;
     55 
     56 void FindBestBlockEntropyModel(const CompressParams& cparams, const ImageI& rqf,
     57                                const AcStrategyImage& ac_strategy,
     58                                BlockCtxMap* block_ctx_map) {
     59   if (cparams.decoding_speed_tier >= 1) {
     60     static constexpr uint8_t kSimpleCtxMap[] = {
     61         // Cluster all blocks together
     62         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  //
     63         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  //
     64         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  //
     65     };
     66     static_assert(
     67         3 * kNumOrders == sizeof(kSimpleCtxMap) / sizeof *kSimpleCtxMap,
     68         "Update simple context map");
     69 
     70     auto bcm = *block_ctx_map;
     71     bcm.ctx_map.assign(std::begin(kSimpleCtxMap), std::end(kSimpleCtxMap));
     72     bcm.num_ctxs = 2;
     73     bcm.num_dc_ctxs = 1;
     74     return;
     75   }
     76   if (cparams.speed_tier >= SpeedTier::kFalcon) {
     77     return;
     78   }
     79   // No need to change context modeling for small images.
     80   size_t tot = rqf.xsize() * rqf.ysize();
     81   size_t size_for_ctx_model = (1 << 10) * cparams.butteraugli_distance;
     82   if (tot < size_for_ctx_model) return;
     83 
     84   struct OccCounters {
     85     // count the occurrences of each qf value and each strategy type.
     86     OccCounters(const ImageI& rqf, const AcStrategyImage& ac_strategy) {
     87       for (size_t y = 0; y < rqf.ysize(); y++) {
     88         const int32_t* qf_row = rqf.Row(y);
     89         AcStrategyRow acs_row = ac_strategy.ConstRow(y);
     90         for (size_t x = 0; x < rqf.xsize(); x++) {
     91           int ord = kStrategyOrder[acs_row[x].RawStrategy()];
     92           int qf = qf_row[x] - 1;
     93           qf_counts[qf]++;
     94           qf_ord_counts[ord][qf]++;
     95           ord_counts[ord]++;
     96         }
     97       }
     98     }
     99 
    100     size_t qf_counts[256] = {};
    101     size_t qf_ord_counts[kNumOrders][256] = {};
    102     size_t ord_counts[kNumOrders] = {};
    103   };
    104   // The OccCounters struct is too big to allocate on the stack.
    105   std::unique_ptr<OccCounters> counters(new OccCounters(rqf, ac_strategy));
    106 
    107   // Splitting the context model according to the quantization field seems to
    108   // mostly benefit only large images.
    109   size_t size_for_qf_split = (1 << 13) * cparams.butteraugli_distance;
    110   size_t num_qf_segments = tot < size_for_qf_split ? 1 : 2;
    111   std::vector<uint32_t>& qft = block_ctx_map->qf_thresholds;
    112   qft.clear();
    113   // Divide the quant field in up to num_qf_segments segments.
    114   size_t cumsum = 0;
    115   size_t next = 1;
    116   size_t last_cut = 256;
    117   size_t cut = tot * next / num_qf_segments;
    118   for (uint32_t j = 0; j < 256; j++) {
    119     cumsum += counters->qf_counts[j];
    120     if (cumsum > cut) {
    121       if (j != 0) {
    122         qft.push_back(j);
    123       }
    124       last_cut = j;
    125       while (cumsum > cut) {
    126         next++;
    127         cut = tot * next / num_qf_segments;
    128       }
    129     } else if (next > qft.size() + 1) {
    130       if (j - 1 == last_cut && j != 0) {
    131         qft.push_back(j);
    132       }
    133     }
    134   }
    135 
    136   // Count the occurrences of each segment.
    137   std::vector<size_t> counts(kNumOrders * (qft.size() + 1));
    138   size_t qft_pos = 0;
    139   for (size_t j = 0; j < 256; j++) {
    140     if (qft_pos < qft.size() && j == qft[qft_pos]) {
    141       qft_pos++;
    142     }
    143     for (size_t i = 0; i < kNumOrders; i++) {
    144       counts[qft_pos + i * (qft.size() + 1)] += counters->qf_ord_counts[i][j];
    145     }
    146   }
    147 
    148   // Repeatedly merge the lowest-count pair.
    149   std::vector<uint8_t> remap((qft.size() + 1) * kNumOrders);
    150   std::iota(remap.begin(), remap.end(), 0);
    151   std::vector<uint8_t> clusters(remap);
    152   size_t nb_clusters =
    153       Clamp1(static_cast<int>(tot / size_for_ctx_model / 2), 2, 9);
    154   size_t nb_clusters_chroma =
    155       Clamp1(static_cast<int>(tot / size_for_ctx_model / 3), 1, 5);
    156   // This is O(n^2 log n), but n is small.
    157   while (clusters.size() > nb_clusters) {
    158     std::sort(clusters.begin(), clusters.end(),
    159               [&](int a, int b) { return counts[a] > counts[b]; });
    160     counts[clusters[clusters.size() - 2]] += counts[clusters.back()];
    161     counts[clusters.back()] = 0;
    162     remap[clusters.back()] = clusters[clusters.size() - 2];
    163     clusters.pop_back();
    164   }
    165   for (size_t i = 0; i < remap.size(); i++) {
    166     while (remap[remap[i]] != remap[i]) {
    167       remap[i] = remap[remap[i]];
    168     }
    169   }
    170   // Relabel starting from 0.
    171   std::vector<uint8_t> remap_remap(remap.size(), remap.size());
    172   size_t num = 0;
    173   for (size_t i = 0; i < remap.size(); i++) {
    174     if (remap_remap[remap[i]] == remap.size()) {
    175       remap_remap[remap[i]] = num++;
    176     }
    177     remap[i] = remap_remap[remap[i]];
    178   }
    179   // Write the block context map.
    180   auto& ctx_map = block_ctx_map->ctx_map;
    181   ctx_map = remap;
    182   ctx_map.resize(remap.size() * 3);
    183   // for chroma, only use up to nb_clusters_chroma separate block contexts
    184   // (those for the biggest clusters)
    185   for (size_t i = remap.size(); i < remap.size() * 3; i++) {
    186     ctx_map[i] = num + Clamp1(static_cast<int>(remap[i % remap.size()]), 0,
    187                               static_cast<int>(nb_clusters_chroma) - 1);
    188   }
    189   block_ctx_map->num_ctxs =
    190       *std::max_element(ctx_map.begin(), ctx_map.end()) + 1;
    191 }
    192 
    193 namespace {
    194 
    195 Status FindBestDequantMatrices(const CompressParams& cparams,
    196                                ModularFrameEncoder* modular_frame_encoder,
    197                                DequantMatrices* dequant_matrices) {
    198   // TODO(veluca): quant matrices for no-gaborish.
    199   // TODO(veluca): heuristics for in-bitstream quant tables.
    200   *dequant_matrices = DequantMatrices();
    201   if (cparams.max_error_mode) {
    202     // Set numerators of all quantization matrices to constant values.
    203     float weights[3][1] = {{1.0f / cparams.max_error[0]},
    204                            {1.0f / cparams.max_error[1]},
    205                            {1.0f / cparams.max_error[2]}};
    206     DctQuantWeightParams dct_params(weights);
    207     std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
    208                                          QuantEncoding::DCT(dct_params));
    209     JXL_RETURN_IF_ERROR(DequantMatricesSetCustom(dequant_matrices, encodings,
    210                                                  modular_frame_encoder));
    211     float dc_weights[3] = {1.0f / cparams.max_error[0],
    212                            1.0f / cparams.max_error[1],
    213                            1.0f / cparams.max_error[2]};
    214     DequantMatricesSetCustomDC(dequant_matrices, dc_weights);
    215   }
    216   return true;
    217 }
    218 
    219 void StoreMin2(const float v, float& min1, float& min2) {
    220   if (v < min2) {
    221     if (v < min1) {
    222       min2 = min1;
    223       min1 = v;
    224     } else {
    225       min2 = v;
    226     }
    227   }
    228 }
    229 
    230 void CreateMask(const ImageF& image, ImageF& mask) {
    231   for (size_t y = 0; y < image.ysize(); y++) {
    232     const auto* row_n = y > 0 ? image.Row(y - 1) : image.Row(y);
    233     const auto* row_in = image.Row(y);
    234     const auto* row_s = y + 1 < image.ysize() ? image.Row(y + 1) : image.Row(y);
    235     auto* row_out = mask.Row(y);
    236     for (size_t x = 0; x < image.xsize(); x++) {
    237       // Center, west, east, north, south values and their absolute difference
    238       float c = row_in[x];
    239       float w = x > 0 ? row_in[x - 1] : row_in[x];
    240       float e = x + 1 < image.xsize() ? row_in[x + 1] : row_in[x];
    241       float n = row_n[x];
    242       float s = row_s[x];
    243       float dw = std::abs(c - w);
    244       float de = std::abs(c - e);
    245       float dn = std::abs(c - n);
    246       float ds = std::abs(c - s);
    247       float min = std::numeric_limits<float>::max();
    248       float min2 = std::numeric_limits<float>::max();
    249       StoreMin2(dw, min, min2);
    250       StoreMin2(de, min, min2);
    251       StoreMin2(dn, min, min2);
    252       StoreMin2(ds, min, min2);
    253       row_out[x] = min2;
    254     }
    255   }
    256 }
    257 
    258 // Downsamples the image by a factor of 2 with a kernel that's sharper than
    259 // the standard 2x2 box kernel used by DownsampleImage.
    260 // The kernel is optimized against the result of the 2x2 upsampling kernel used
    261 // by the decoder. Ringing is slightly reduced by clamping the values of the
    262 // resulting pixels within certain bounds of a small region in the original
    263 // image.
    264 Status DownsampleImage2_Sharper(const ImageF& input, ImageF* output) {
    265   const int64_t kernelx = 12;
    266   const int64_t kernely = 12;
    267 
    268   static const float kernel[144] = {
    269       -0.000314256996835, -0.000314256996835, -0.000897597057705,
    270       -0.000562751488849, -0.000176807273646, 0.001864627368902,
    271       0.001864627368902,  -0.000176807273646, -0.000562751488849,
    272       -0.000897597057705, -0.000314256996835, -0.000314256996835,
    273       -0.000314256996835, -0.001527942804748, -0.000121760530512,
    274       0.000191123989093,  0.010193185932466,  0.058637519197110,
    275       0.058637519197110,  0.010193185932466,  0.000191123989093,
    276       -0.000121760530512, -0.001527942804748, -0.000314256996835,
    277       -0.000897597057705, -0.000121760530512, 0.000946363683751,
    278       0.007113577630288,  0.000437956841058,  -0.000372823835211,
    279       -0.000372823835211, 0.000437956841058,  0.007113577630288,
    280       0.000946363683751,  -0.000121760530512, -0.000897597057705,
    281       -0.000562751488849, 0.000191123989093,  0.007113577630288,
    282       0.044592622228814,  0.000222278879007,  -0.162864473015945,
    283       -0.162864473015945, 0.000222278879007,  0.044592622228814,
    284       0.007113577630288,  0.000191123989093,  -0.000562751488849,
    285       -0.000176807273646, 0.010193185932466,  0.000437956841058,
    286       0.000222278879007,  -0.000913092543974, -0.017071696107902,
    287       -0.017071696107902, -0.000913092543974, 0.000222278879007,
    288       0.000437956841058,  0.010193185932466,  -0.000176807273646,
    289       0.001864627368902,  0.058637519197110,  -0.000372823835211,
    290       -0.162864473015945, -0.017071696107902, 0.414660099370354,
    291       0.414660099370354,  -0.017071696107902, -0.162864473015945,
    292       -0.000372823835211, 0.058637519197110,  0.001864627368902,
    293       0.001864627368902,  0.058637519197110,  -0.000372823835211,
    294       -0.162864473015945, -0.017071696107902, 0.414660099370354,
    295       0.414660099370354,  -0.017071696107902, -0.162864473015945,
    296       -0.000372823835211, 0.058637519197110,  0.001864627368902,
    297       -0.000176807273646, 0.010193185932466,  0.000437956841058,
    298       0.000222278879007,  -0.000913092543974, -0.017071696107902,
    299       -0.017071696107902, -0.000913092543974, 0.000222278879007,
    300       0.000437956841058,  0.010193185932466,  -0.000176807273646,
    301       -0.000562751488849, 0.000191123989093,  0.007113577630288,
    302       0.044592622228814,  0.000222278879007,  -0.162864473015945,
    303       -0.162864473015945, 0.000222278879007,  0.044592622228814,
    304       0.007113577630288,  0.000191123989093,  -0.000562751488849,
    305       -0.000897597057705, -0.000121760530512, 0.000946363683751,
    306       0.007113577630288,  0.000437956841058,  -0.000372823835211,
    307       -0.000372823835211, 0.000437956841058,  0.007113577630288,
    308       0.000946363683751,  -0.000121760530512, -0.000897597057705,
    309       -0.000314256996835, -0.001527942804748, -0.000121760530512,
    310       0.000191123989093,  0.010193185932466,  0.058637519197110,
    311       0.058637519197110,  0.010193185932466,  0.000191123989093,
    312       -0.000121760530512, -0.001527942804748, -0.000314256996835,
    313       -0.000314256996835, -0.000314256996835, -0.000897597057705,
    314       -0.000562751488849, -0.000176807273646, 0.001864627368902,
    315       0.001864627368902,  -0.000176807273646, -0.000562751488849,
    316       -0.000897597057705, -0.000314256996835, -0.000314256996835};
    317 
    318   int64_t xsize = input.xsize();
    319   int64_t ysize = input.ysize();
    320 
    321   JXL_ASSIGN_OR_RETURN(ImageF box_downsample, ImageF::Create(xsize, ysize));
    322   CopyImageTo(input, &box_downsample);
    323   JXL_ASSIGN_OR_RETURN(box_downsample, DownsampleImage(box_downsample, 2));
    324 
    325   JXL_ASSIGN_OR_RETURN(ImageF mask, ImageF::Create(box_downsample.xsize(),
    326                                                    box_downsample.ysize()));
    327   CreateMask(box_downsample, mask);
    328 
    329   for (size_t y = 0; y < output->ysize(); y++) {
    330     float* row_out = output->Row(y);
    331     const float* row_in[kernely];
    332     const float* row_mask = mask.Row(y);
    333     // get the rows in the support
    334     for (size_t ky = 0; ky < kernely; ky++) {
    335       int64_t iy = y * 2 + ky - (kernely - 1) / 2;
    336       if (iy < 0) iy = 0;
    337       if (iy >= ysize) iy = ysize - 1;
    338       row_in[ky] = input.Row(iy);
    339     }
    340 
    341     for (size_t x = 0; x < output->xsize(); x++) {
    342       // get min and max values of the original image in the support
    343       float min = std::numeric_limits<float>::max();
    344       float max = std::numeric_limits<float>::min();
    345       // kernelx - R and kernely - R are the radius of a rectangular region in
    346       // which the values of a pixel are bounded to reduce ringing.
    347       static constexpr int64_t R = 5;
    348       for (int64_t ky = R; ky + R < kernely; ky++) {
    349         for (int64_t kx = R; kx + R < kernelx; kx++) {
    350           int64_t ix = x * 2 + kx - (kernelx - 1) / 2;
    351           if (ix < 0) ix = 0;
    352           if (ix >= xsize) ix = xsize - 1;
    353           min = std::min<float>(min, row_in[ky][ix]);
    354           max = std::max<float>(max, row_in[ky][ix]);
    355         }
    356       }
    357 
    358       float sum = 0;
    359       for (int64_t ky = 0; ky < kernely; ky++) {
    360         for (int64_t kx = 0; kx < kernelx; kx++) {
    361           int64_t ix = x * 2 + kx - (kernelx - 1) / 2;
    362           if (ix < 0) ix = 0;
    363           if (ix >= xsize) ix = xsize - 1;
    364           sum += row_in[ky][ix] * kernel[ky * kernelx + kx];
    365         }
    366       }
    367 
    368       row_out[x] = sum;
    369 
    370       // Clamp the pixel within the value  of a small area to prevent ringning.
    371       // The mask determines how much to clamp, clamp more to reduce more
    372       // ringing in smooth areas, clamp less in noisy areas to get more
    373       // sharpness. Higher mask_multiplier gives less clamping, so less
    374       // ringing reduction.
    375       const constexpr float mask_multiplier = 1;
    376       float a = row_mask[x] * mask_multiplier;
    377       float clip_min = min - a;
    378       float clip_max = max + a;
    379       if (row_out[x] < clip_min) {
    380         row_out[x] = clip_min;
    381       } else if (row_out[x] > clip_max) {
    382         row_out[x] = clip_max;
    383       }
    384     }
    385   }
    386   return true;
    387 }
    388 
    389 }  // namespace
    390 
    391 Status DownsampleImage2_Sharper(Image3F* opsin) {
    392   // Allocate extra space to avoid a reallocation when padding.
    393   JXL_ASSIGN_OR_RETURN(Image3F downsampled,
    394                        Image3F::Create(DivCeil(opsin->xsize(), 2) + kBlockDim,
    395                                        DivCeil(opsin->ysize(), 2) + kBlockDim));
    396   downsampled.ShrinkTo(downsampled.xsize() - kBlockDim,
    397                        downsampled.ysize() - kBlockDim);
    398 
    399   for (size_t c = 0; c < 3; c++) {
    400     JXL_RETURN_IF_ERROR(
    401         DownsampleImage2_Sharper(opsin->Plane(c), &downsampled.Plane(c)));
    402   }
    403   *opsin = std::move(downsampled);
    404   return true;
    405 }
    406 
    407 namespace {
    408 
// The default upsampling kernels used by Upsampler in the decoder.
// Each kernel is a kSize x kSize (5x5) tap set for one phase of the 2x2
// upsampling; kernel01/kernel10/kernel11 are mirrored variants of kernel00
// (flipped vertically, horizontally, and both, respectively).
const constexpr int64_t kSize = 5;

// Phase (even x, even y).
const float kernel00[25] = {
    -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f,
    -0.03452303f, 0.14111091f,  0.28896755f,  0.00278718f,  -0.01610267f,
    -0.04022174f, 0.28896755f,  0.56661550f,  0.03777607f,  -0.01986694f,
    -0.02921014f, 0.00278718f,  0.03777607f,  -0.03144731f, -0.01185068f,
    -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f,
};
// Phase (even x, odd y).
const float kernel01[25] = {
    -0.00624645f, -0.01610267f, -0.01986694f, -0.01185068f, -0.00213539f,
    -0.02921014f, 0.00278718f,  0.03777607f,  -0.03144731f, -0.01185068f,
    -0.04022174f, 0.28896755f,  0.56661550f,  0.03777607f,  -0.01986694f,
    -0.03452303f, 0.14111091f,  0.28896755f,  0.00278718f,  -0.01610267f,
    -0.01716200f, -0.03452303f, -0.04022174f, -0.02921014f, -0.00624645f,
};
// Phase (odd x, even y).
const float kernel10[25] = {
    -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f,
    -0.01610267f, 0.00278718f,  0.28896755f,  0.14111091f,  -0.03452303f,
    -0.01986694f, 0.03777607f,  0.56661550f,  0.28896755f,  -0.04022174f,
    -0.01185068f, -0.03144731f, 0.03777607f,  0.00278718f,  -0.02921014f,
    -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f,
};
// Phase (odd x, odd y).
const float kernel11[25] = {
    -0.00213539f, -0.01185068f, -0.01986694f, -0.01610267f, -0.00624645f,
    -0.01185068f, -0.03144731f, 0.03777607f,  0.00278718f,  -0.02921014f,
    -0.01986694f, 0.03777607f,  0.56661550f,  0.28896755f,  -0.04022174f,
    -0.01610267f, 0.00278718f,  0.28896755f,  0.14111091f,  -0.03452303f,
    -0.00624645f, -0.02921014f, -0.04022174f, -0.03452303f, -0.01716200f,
};
    440 
    441 // Does exactly the same as the Upsampler in dec_upsampler for 2x2 pixels, with
    442 // default CustomTransformData.
    443 // TODO(lode): use Upsampler instead. However, it requires pre-initialization
    444 // and padding on the left side of the image which requires refactoring the
    445 // other code using this.
    446 void UpsampleImage(const ImageF& input, ImageF* output) {
    447   int64_t xsize = input.xsize();
    448   int64_t ysize = input.ysize();
    449   int64_t xsize2 = output->xsize();
    450   int64_t ysize2 = output->ysize();
    451   for (int64_t y = 0; y < ysize2; y++) {
    452     for (int64_t x = 0; x < xsize2; x++) {
    453       const auto* kernel = kernel00;
    454       if ((x & 1) && (y & 1)) {
    455         kernel = kernel11;
    456       } else if (x & 1) {
    457         kernel = kernel10;
    458       } else if (y & 1) {
    459         kernel = kernel01;
    460       }
    461       float sum = 0;
    462       int64_t x2 = x / 2;
    463       int64_t y2 = y / 2;
    464 
    465       // get min and max values of the original image in the support
    466       float min = std::numeric_limits<float>::max();
    467       float max = std::numeric_limits<float>::min();
    468 
    469       for (int64_t ky = 0; ky < kSize; ky++) {
    470         for (int64_t kx = 0; kx < kSize; kx++) {
    471           int64_t xi = x2 - kSize / 2 + kx;
    472           int64_t yi = y2 - kSize / 2 + ky;
    473           if (xi < 0) xi = 0;
    474           if (xi >= xsize) xi = input.xsize() - 1;
    475           if (yi < 0) yi = 0;
    476           if (yi >= ysize) yi = input.ysize() - 1;
    477           min = std::min<float>(min, input.Row(yi)[xi]);
    478           max = std::max<float>(max, input.Row(yi)[xi]);
    479         }
    480       }
    481 
    482       for (int64_t ky = 0; ky < kSize; ky++) {
    483         for (int64_t kx = 0; kx < kSize; kx++) {
    484           int64_t xi = x2 - kSize / 2 + kx;
    485           int64_t yi = y2 - kSize / 2 + ky;
    486           if (xi < 0) xi = 0;
    487           if (xi >= xsize) xi = input.xsize() - 1;
    488           if (yi < 0) yi = 0;
    489           if (yi >= ysize) yi = input.ysize() - 1;
    490           sum += input.Row(yi)[xi] * kernel[ky * kSize + kx];
    491         }
    492       }
    493       output->Row(y)[x] = sum;
    494       if (output->Row(y)[x] < min) output->Row(y)[x] = min;
    495       if (output->Row(y)[x] > max) output->Row(y)[x] = max;
    496     }
    497   }
    498 }
    499 
    500 // Returns the derivative of Upsampler, with respect to input pixel x2, y2, to
    501 // output pixel x, y (ignoring the clamping).
    502 float UpsamplerDeriv(int64_t x2, int64_t y2, int64_t x, int64_t y) {
    503   const auto* kernel = kernel00;
    504   if ((x & 1) && (y & 1)) {
    505     kernel = kernel11;
    506   } else if (x & 1) {
    507     kernel = kernel10;
    508   } else if (y & 1) {
    509     kernel = kernel01;
    510   }
    511 
    512   int64_t ix = x / 2;
    513   int64_t iy = y / 2;
    514   int64_t kx = x2 - ix + kSize / 2;
    515   int64_t ky = y2 - iy + kSize / 2;
    516 
    517   // This should not happen.
    518   if (kx < 0 || kx >= kSize || ky < 0 || ky >= kSize) return 0;
    519 
    520   return kernel[ky * kSize + kx];
    521 }
    522 
    523 // Apply the derivative of the Upsampler to the input, reversing the effect of
    524 // its coefficients. The output image is 2x2 times smaller than the input.
    525 void AntiUpsample(const ImageF& input, ImageF* d) {
    526   int64_t xsize = input.xsize();
    527   int64_t ysize = input.ysize();
    528   int64_t xsize2 = d->xsize();
    529   int64_t ysize2 = d->ysize();
    530   int64_t k0 = kSize - 1;
    531   int64_t k1 = kSize;
    532   for (int64_t y2 = 0; y2 < ysize2; ++y2) {
    533     auto* row = d->Row(y2);
    534     for (int64_t x2 = 0; x2 < xsize2; ++x2) {
    535       int64_t x0 = x2 * 2 - k0;
    536       if (x0 < 0) x0 = 0;
    537       int64_t x1 = x2 * 2 + k1 + 1;
    538       if (x1 > xsize) x1 = xsize;
    539       int64_t y0 = y2 * 2 - k0;
    540       if (y0 < 0) y0 = 0;
    541       int64_t y1 = y2 * 2 + k1 + 1;
    542       if (y1 > ysize) y1 = ysize;
    543 
    544       float sum = 0;
    545       for (int64_t y = y0; y < y1; ++y) {
    546         const auto* row_in = input.Row(y);
    547         for (int64_t x = x0; x < x1; ++x) {
    548           double deriv = UpsamplerDeriv(x2, y2, x, y);
    549           sum += deriv * row_in[x];
    550         }
    551       }
    552       row[x2] = sum;
    553     }
    554   }
    555 }
    556 
    557 // Element-wise multiplies two images.
    558 template <typename T>
    559 void ElwiseMul(const Plane<T>& image1, const Plane<T>& image2, Plane<T>* out) {
    560   const size_t xsize = image1.xsize();
    561   const size_t ysize = image1.ysize();
    562   JXL_CHECK(xsize == image2.xsize());
    563   JXL_CHECK(ysize == image2.ysize());
    564   JXL_CHECK(xsize == out->xsize());
    565   JXL_CHECK(ysize == out->ysize());
    566   for (size_t y = 0; y < ysize; ++y) {
    567     const T* const JXL_RESTRICT row1 = image1.Row(y);
    568     const T* const JXL_RESTRICT row2 = image2.Row(y);
    569     T* const JXL_RESTRICT row_out = out->Row(y);
    570     for (size_t x = 0; x < xsize; ++x) {
    571       row_out[x] = row1[x] * row2[x];
    572     }
    573   }
    574 }
    575 
    576 // Element-wise divides two images.
    577 template <typename T>
    578 void ElwiseDiv(const Plane<T>& image1, const Plane<T>& image2, Plane<T>* out) {
    579   const size_t xsize = image1.xsize();
    580   const size_t ysize = image1.ysize();
    581   JXL_CHECK(xsize == image2.xsize());
    582   JXL_CHECK(ysize == image2.ysize());
    583   JXL_CHECK(xsize == out->xsize());
    584   JXL_CHECK(ysize == out->ysize());
    585   for (size_t y = 0; y < ysize; ++y) {
    586     const T* const JXL_RESTRICT row1 = image1.Row(y);
    587     const T* const JXL_RESTRICT row2 = image2.Row(y);
    588     T* const JXL_RESTRICT row_out = out->Row(y);
    589     for (size_t x = 0; x < xsize; ++x) {
    590       row_out[x] = row1[x] / row2[x];
    591     }
    592   }
    593 }
    594 
    595 void ReduceRinging(const ImageF& initial, const ImageF& mask, ImageF& down) {
    596   int64_t xsize2 = down.xsize();
    597   int64_t ysize2 = down.ysize();
    598 
    599   for (size_t y = 0; y < down.ysize(); y++) {
    600     const float* row_mask = mask.Row(y);
    601     float* row_out = down.Row(y);
    602     for (size_t x = 0; x < down.xsize(); x++) {
    603       float v = down.Row(y)[x];
    604       float min = initial.Row(y)[x];
    605       float max = initial.Row(y)[x];
    606       for (int64_t yi = -1; yi < 2; yi++) {
    607         for (int64_t xi = -1; xi < 2; xi++) {
    608           int64_t x2 = static_cast<int64_t>(x) + xi;
    609           int64_t y2 = static_cast<int64_t>(y) + yi;
    610           if (x2 < 0 || y2 < 0 || x2 >= xsize2 || y2 >= ysize2) continue;
    611           min = std::min<float>(min, initial.Row(y2)[x2]);
    612           max = std::max<float>(max, initial.Row(y2)[x2]);
    613         }
    614       }
    615 
    616       row_out[x] = v;
    617 
    618       // Clamp the pixel within the value  of a small area to prevent ringning.
    619       // The mask determines how much to clamp, clamp more to reduce more
    620       // ringing in smooth areas, clamp less in noisy areas to get more
    621       // sharpness. Higher mask_multiplier gives less clamping, so less
    622       // ringing reduction.
    623       const constexpr float mask_multiplier = 2;
    624       float a = row_mask[x] * mask_multiplier;
    625       float clip_min = min - a;
    626       float clip_max = max + a;
    627       if (row_out[x] < clip_min) row_out[x] = clip_min;
    628       if (row_out[x] > clip_max) row_out[x] = clip_max;
    629     }
    630   }
    631 }
    632 
// TODO(lode): move this to a separate file enc_downsample.cc
// 2x downsampling by iterative refinement: start from the sharper
// downsampling, then repeatedly (a) upsample the current estimate with the
// decoder's kernel, (b) compute the residual against the original, (c) push
// the residual back to low resolution via AntiUpsample, and (d) add it to
// the estimate. Finally clamp against the initial estimate with
// ReduceRinging. `output` must be at least ceil(xsize/2) x ceil(ysize/2);
// it is shrunk to exactly that size.
Status DownsampleImage2_Iterative(const ImageF& orig, ImageF* output) {
  int64_t xsize = orig.xsize();
  int64_t ysize = orig.ysize();
  int64_t xsize2 = DivCeil(orig.xsize(), 2);
  int64_t ysize2 = DivCeil(orig.ysize(), 2);

  // Box-downsample a copy to build the ringing-control mask at the output
  // resolution.
  JXL_ASSIGN_OR_RETURN(ImageF box_downsample, ImageF::Create(xsize, ysize));
  CopyImageTo(orig, &box_downsample);
  JXL_ASSIGN_OR_RETURN(box_downsample, DownsampleImage(box_downsample, 2));
  JXL_ASSIGN_OR_RETURN(ImageF mask, ImageF::Create(box_downsample.xsize(),
                                                   box_downsample.ysize()));
  CreateMask(box_downsample, mask);

  output->ShrinkTo(xsize2, ysize2);

  // Initial result image using the sharper downsampling.
  // Allocate extra space to avoid a reallocation when padding.
  JXL_ASSIGN_OR_RETURN(ImageF initial,
                       ImageF::Create(DivCeil(orig.xsize(), 2) + kBlockDim,
                                      DivCeil(orig.ysize(), 2) + kBlockDim));
  initial.ShrinkTo(initial.xsize() - kBlockDim, initial.ysize() - kBlockDim);
  JXL_RETURN_IF_ERROR(DownsampleImage2_Sharper(orig, &initial));

  // Working estimate (full-res scratch images for the refinement loop).
  JXL_ASSIGN_OR_RETURN(ImageF down,
                       ImageF::Create(initial.xsize(), initial.ysize()));
  CopyImageTo(initial, &down);
  JXL_ASSIGN_OR_RETURN(ImageF up, ImageF::Create(xsize, ysize));
  JXL_ASSIGN_OR_RETURN(ImageF corr, ImageF::Create(xsize, ysize));
  JXL_ASSIGN_OR_RETURN(ImageF corr2, ImageF::Create(xsize2, ysize2));

  // In the weights map, relatively higher values will allow less ringing but
  // also less sharpness. With all constant values, it optimizes equally
  // everywhere. Even in this case, the weights2 computed from
  // this is still used and differs at the borders of the image.
  // TODO(lode): Make use of the weights field for anti-ringing and clamping,
  // the values are all set to 1 for now, but it is intended to be used for
  // reducing ringing based on the mask, and taking clamping into account.
  JXL_ASSIGN_OR_RETURN(ImageF weights, ImageF::Create(xsize, ysize));
  for (size_t y = 0; y < weights.ysize(); y++) {
    auto* row = weights.Row(y);
    for (size_t x = 0; x < weights.xsize(); x++) {
      row[x] = 1;
    }
  }
  // Low-resolution normalization for the residual (sum of kernel weights
  // contributing to each output pixel; varies near the image borders).
  JXL_ASSIGN_OR_RETURN(ImageF weights2, ImageF::Create(xsize2, ysize2));
  AntiUpsample(weights, &weights2);

  const size_t num_it = 3;
  for (size_t it = 0; it < num_it; ++it) {
    // corr = orig - Upsample(down): the full-resolution reconstruction error.
    UpsampleImage(down, &up);
    JXL_ASSIGN_OR_RETURN(corr, LinComb<float>(1, orig, -1, up));
    ElwiseMul(corr, weights, &corr);
    // Push the weighted residual down to low resolution and normalize.
    AntiUpsample(corr, &corr2);
    ElwiseDiv(corr2, weights2, &corr2);

    // down += corr2: refine the estimate.
    JXL_ASSIGN_OR_RETURN(down, LinComb<float>(1, down, 1, corr2));
  }

  ReduceRinging(initial, mask, down);

  // can't just use CopyImage, because the output image was prepared with
  // padding.
  for (size_t y = 0; y < down.ysize(); y++) {
    for (size_t x = 0; x < down.xsize(); x++) {
      float v = down.Row(y)[x];
      output->Row(y)[x] = v;
    }
  }
  return true;
}
    704 
    705 }  // namespace
    706 
    707 Status DownsampleImage2_Iterative(Image3F* opsin) {
    708   // Allocate extra space to avoid a reallocation when padding.
    709   JXL_ASSIGN_OR_RETURN(Image3F downsampled,
    710                        Image3F::Create(DivCeil(opsin->xsize(), 2) + kBlockDim,
    711                                        DivCeil(opsin->ysize(), 2) + kBlockDim));
    712   downsampled.ShrinkTo(downsampled.xsize() - kBlockDim,
    713                        downsampled.ysize() - kBlockDim);
    714 
    715   JXL_ASSIGN_OR_RETURN(Image3F rgb,
    716                        Image3F::Create(opsin->xsize(), opsin->ysize()));
    717   OpsinParams opsin_params;  // TODO(user): use the ones that are actually used
    718   opsin_params.Init(kDefaultIntensityTarget);
    719   OpsinToLinear(*opsin, Rect(rgb), nullptr, &rgb, opsin_params);
    720 
    721   JXL_ASSIGN_OR_RETURN(ImageF mask,
    722                        ImageF::Create(opsin->xsize(), opsin->ysize()));
    723   ButteraugliParams butter_params;
    724   JXL_ASSIGN_OR_RETURN(std::unique_ptr<ButteraugliComparator> butter,
    725                        ButteraugliComparator::Make(rgb, butter_params));
    726   JXL_RETURN_IF_ERROR(butter->Mask(&mask));
    727   JXL_ASSIGN_OR_RETURN(ImageF mask_fuzzy,
    728                        ImageF::Create(opsin->xsize(), opsin->ysize()));
    729 
    730   for (size_t c = 0; c < 3; c++) {
    731     JXL_RETURN_IF_ERROR(
    732         DownsampleImage2_Iterative(opsin->Plane(c), &downsampled.Plane(c)));
    733   }
    734   *opsin = std::move(downsampled);
    735   return true;
    736 }
    737 
// Runs the per-frame encoder heuristics for lossy (VarDCT) encoding:
// spline/patch extraction, initial quantization-field estimation, inverse
// gaborish, dequantization matrices, AC strategy selection, EPF sharpness,
// the color-correlation (CfL) map, and finally the refined quantizer and
// AC block context model.
//
// Parameters:
//  - frame_header: read-only frame settings (only loop_filter.gab is read
//    here).
//  - enc_state: provides cparams/streaming flags and the shared state that
//    this function fills in (quantizer, raw quant field, ac_strategy, cmap,
//    epf_sharpness, block_ctx_map, image_features).
//  - modular_frame_encoder: passed to FindBestDequantMatrices.
//  - original_pixels: used only by FindBestQuantizer.
//  - opsin: XYB image; modified in place (splines/patches subtracted,
//    inverse gaborish applied).
//  - rect: region of `opsin` to process.
// Returns true on success, or propagates the first failure.
Status LossyFrameHeuristics(const FrameHeader& frame_header,
                            PassesEncoderState* enc_state,
                            ModularFrameEncoder* modular_frame_encoder,
                            const Image3F* original_pixels, Image3F* opsin,
                            const Rect& rect, const JxlCmsInterface& cms,
                            ThreadPool* pool, AuxOut* aux_out) {
  // Local aliases into the shared encoder state; everything below mutates
  // these in place.
  const CompressParams& cparams = enc_state->cparams;
  const bool streaming_mode = enc_state->streaming_mode;
  const bool initialize_global_state = enc_state->initialize_global_state;
  PassesSharedState& shared = enc_state->shared;
  const FrameDimensions& frame_dim = shared.frame_dim;
  ImageFeatures& image_features = shared.image_features;
  DequantMatrices& matrices = shared.matrices;
  Quantizer& quantizer = shared.quantizer;
  ImageI& raw_quant_field = shared.raw_quant_field;
  ColorCorrelationMap& cmap = shared.cmap;
  AcStrategyImage& ac_strategy = shared.ac_strategy;
  ImageB& epf_sharpness = shared.epf_sharpness;
  BlockCtxMap& block_ctx_map = shared.block_ctx_map;

  // Find and subtract splines. User-supplied splines are always honored,
  // but the search/draw/subtract step only runs at kSquirrel or slower and
  // never in streaming mode (it needs the whole image).
  if (cparams.custom_splines.HasAny()) {
    image_features.splines = cparams.custom_splines;
  }
  if (!streaming_mode && cparams.speed_tier <= SpeedTier::kSquirrel) {
    if (!cparams.custom_splines.HasAny()) {
      image_features.splines = FindSplines(*opsin);
    }
    JXL_RETURN_IF_ERROR(image_features.splines.InitializeDrawCache(
        opsin->xsize(), opsin->ysize(), cmap));
    image_features.splines.SubtractFrom(opsin);
  }

  // Find and subtract patches/dots (again whole-image only; defaults on at
  // kSquirrel or slower unless overridden by cparams.patches).
  if (!streaming_mode &&
      ApplyOverride(cparams.patches,
                    cparams.speed_tier <= SpeedTier::kSquirrel)) {
    JXL_RETURN_IF_ERROR(
        FindBestPatchDictionary(*opsin, enc_state, cms, pool, aux_out));
    PatchDictionaryEncoder::SubtractFrom(image_features.patches, opsin);
  }

  const float quant_dc = InitialQuantDC(cparams.butteraugli_distance);

  // TODO(veluca): we can now run all the code from here to FindBestQuantizer
  // (excluded) one rect at a time. Do that.

  // Dependency graph:
  //
  // input: either XYB or input image
  //
  // input image -> XYB [optional]
  // XYB -> initial quant field
  // XYB -> Gaborished XYB
  // Gaborished XYB -> CfL1
  // initial quant field, Gaborished XYB, CfL1 -> ACS
  // initial quant field, ACS, Gaborished XYB -> EPF control field
  // initial quant field -> adjusted initial quant field
  // adjusted initial quant field, ACS -> raw quant field
  // raw quant field, ACS, Gaborished XYB -> CfL2
  //
  // output: Gaborished XYB, CfL, ACS, raw quant field, EPF control field.

  ArControlFieldHeuristics ar_heuristics;
  AcStrategyHeuristics acs_heuristics(cparams);
  CfLHeuristics cfl_heuristics;
  ImageF initial_quant_field;
  ImageF initial_quant_masking;
  ImageF initial_quant_masking1x1;

  // Compute an initial estimate of the quantization field.
  // Call InitialQuantField only in Hare mode or slower. Otherwise, rely
  // on simple heuristics in FindBestAcStrategy, or set a constant for Falcon
  // mode.
  if (cparams.speed_tier > SpeedTier::kHare) {
    // Fast tiers: constant quant/masking fields derived from the target
    // distance (constants are empirically tuned).
    JXL_ASSIGN_OR_RETURN(
        initial_quant_field,
        ImageF::Create(frame_dim.xsize_blocks, frame_dim.ysize_blocks));
    JXL_ASSIGN_OR_RETURN(
        initial_quant_masking,
        ImageF::Create(frame_dim.xsize_blocks, frame_dim.ysize_blocks));
    float q = 0.79 / cparams.butteraugli_distance;
    FillImage(q, &initial_quant_field);
    FillImage(1.0f / (q + 0.001f), &initial_quant_masking);
    quantizer.ComputeGlobalScaleAndQuant(quant_dc, q, 0);
  } else {
    // Call this here, as it relies on pre-gaborish values.
    float butteraugli_distance_for_iqf = cparams.butteraugli_distance;
    if (!frame_header.loop_filter.gab) {
      // No gaborish: compensate the effective distance (tuned factor).
      butteraugli_distance_for_iqf *= 0.73f;
    }
    JXL_ASSIGN_OR_RETURN(
        initial_quant_field,
        InitialQuantField(butteraugli_distance_for_iqf, *opsin, rect, pool,
                          1.0f, &initial_quant_masking,
                          &initial_quant_masking1x1));
    float q = 0.39 / cparams.butteraugli_distance;
    quantizer.ComputeGlobalScaleAndQuant(quant_dc, q, 0);
  }

  // TODO(veluca): do something about animations.

  // Apply inverse-gaborish.
  if (frame_header.loop_filter.gab) {
    // Unsure why better to do some more gaborish on X and B than Y.
    float weight[3] = {
        1.0036278514398933f,
        0.99406123118127299f,
        0.99719338015886894f,
    };
    JXL_RETURN_IF_ERROR(GaborishInverse(opsin, rect, weight, pool));
  }

  // Dequant matrices are global per-image state; only built once even when
  // encoding in chunks.
  if (initialize_global_state) {
    JXL_RETURN_IF_ERROR(
        FindBestDequantMatrices(cparams, modular_frame_encoder, &matrices));
  }

  JXL_RETURN_IF_ERROR(cfl_heuristics.Init(rect));
  acs_heuristics.Init(*opsin, rect, initial_quant_field, initial_quant_masking,
                      initial_quant_masking1x1, &matrices);

  // Tile workers cannot return a Status from the pool callback, so failures
  // are signaled through this flag and turned into a Status afterwards.
  std::atomic<bool> has_error{false};
  auto process_tile = [&](const uint32_t tid, const size_t thread) {
    if (has_error) return;
    // Map the linear tile id to a (tx, ty) tile and its block rect, clamped
    // to the frame dimensions.
    size_t n_enc_tiles = DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks);
    size_t tx = tid % n_enc_tiles;
    size_t ty = tid / n_enc_tiles;
    size_t by0 = ty * kEncTileDimInBlocks;
    size_t by1 =
        std::min((ty + 1) * kEncTileDimInBlocks, frame_dim.ysize_blocks);
    size_t bx0 = tx * kEncTileDimInBlocks;
    size_t bx1 =
        std::min((tx + 1) * kEncTileDimInBlocks, frame_dim.xsize_blocks);
    Rect r(bx0, by0, bx1 - bx0, by1 - by0);

    // For speeds up to Wombat, we only compute the color correlation map
    // once we know the transform type and the quantization map.
    if (cparams.speed_tier <= SpeedTier::kSquirrel) {
      cfl_heuristics.ComputeTile(r, *opsin, rect, matrices,
                                 /*ac_strategy=*/nullptr,
                                 /*raw_quant_field=*/nullptr,
                                 /*quantizer=*/nullptr, /*fast=*/false, thread,
                                 &cmap);
    }

    // Choose block sizes.
    acs_heuristics.ProcessRect(r, cmap, &ac_strategy, thread);

    // Choose amount of post-processing smoothing.
    // TODO(veluca): should this go *after* AdjustQuantField?
    if (!ar_heuristics.RunRect(cparams, frame_header, r, *opsin, rect,
                               initial_quant_field, ac_strategy, &epf_sharpness,
                               thread)) {
      has_error = true;
      return;
    }

    // Always set the initial quant field, so we can compute the CfL map with
    // more accuracy. The initial quant field might change in slower modes, but
    // adjusting the quant field with butteraugli when all the other encoding
    // parameters are fixed is likely a more reliable choice anyway.
    AdjustQuantField(ac_strategy, r, cparams.butteraugli_distance,
                     &initial_quant_field);
    quantizer.SetQuantFieldRect(initial_quant_field, r, &raw_quant_field);

    // Compute a non-default CfL map if we are at Hare speed, or slower.
    if (cparams.speed_tier <= SpeedTier::kHare) {
      cfl_heuristics.ComputeTile(
          r, *opsin, rect, matrices, &ac_strategy, &raw_quant_field, &quantizer,
          /*fast=*/cparams.speed_tier >= SpeedTier::kWombat, thread, &cmap);
    }
  };
  // One task per kEncTileDimInBlocks x kEncTileDimInBlocks tile; the init
  // callback sizes the per-thread scratch state.
  JXL_RETURN_IF_ERROR(RunOnPool(
      pool, 0,
      DivCeil(frame_dim.xsize_blocks, kEncTileDimInBlocks) *
          DivCeil(frame_dim.ysize_blocks, kEncTileDimInBlocks),
      [&](const size_t num_threads) {
        acs_heuristics.PrepareForThreads(num_threads);
        ar_heuristics.PrepareForThreads(num_threads);
        cfl_heuristics.PrepareForThreads(num_threads);
        return true;
      },
      process_tile, "Enc Heuristics"));
  if (has_error) return JXL_FAILURE("Enc Heuristics failed");

  JXL_RETURN_IF_ERROR(acs_heuristics.Finalize(frame_dim, ac_strategy, aux_out));

  // Refine quantization levels (whole-image butteraugli search; skipped in
  // streaming mode).
  if (!streaming_mode) {
    JXL_RETURN_IF_ERROR(FindBestQuantizer(frame_header, original_pixels, *opsin,
                                          initial_quant_field, enc_state, cms,
                                          pool, aux_out));
  }

  // Choose a context model that depends on the amount of quantization for AC.
  if (cparams.speed_tier < SpeedTier::kFalcon && initialize_global_state) {
    FindBestBlockEntropyModel(cparams, raw_quant_field, ac_strategy,
                              &block_ctx_map);
  }
  return true;
}
    940 
    941 }  // namespace jxl