libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_ac_strategy.cc (46632B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_ac_strategy.h"
      7 
      8 #include <stdint.h>
      9 #include <string.h>
     10 
     11 #include <algorithm>
     12 #include <cmath>
     13 #include <cstdio>
     14 
     15 #undef HWY_TARGET_INCLUDE
     16 #define HWY_TARGET_INCLUDE "lib/jxl/enc_ac_strategy.cc"
     17 #include <hwy/foreach_target.h>
     18 #include <hwy/highway.h>
     19 
     20 #include "lib/jxl/ac_strategy.h"
     21 #include "lib/jxl/base/bits.h"
     22 #include "lib/jxl/base/compiler_specific.h"
     23 #include "lib/jxl/base/fast_math-inl.h"
     24 #include "lib/jxl/base/status.h"
     25 #include "lib/jxl/dec_transforms-inl.h"
     26 #include "lib/jxl/enc_aux_out.h"
     27 #include "lib/jxl/enc_debug_image.h"
     28 #include "lib/jxl/enc_params.h"
     29 #include "lib/jxl/enc_transforms-inl.h"
     30 #include "lib/jxl/simd_util.h"
     31 
     32 // Some of the floating point constants in this file and in other
     33 // files in the libjxl project have been obtained using the
     34 // tools/optimizer/simplex_fork.py tool. It is a variation of
     35 // Nelder-Mead optimization, and we generally try to minimize
     36 // BPP * pnorm aggregate as reported by the benchmark_xl tool,
     37 // but occasionally the values are optimized by using additional
     38 // constraints such as maintaining a certain density, or ratio of
     39 // popularity of integral transforms. Jyrki visually reviews all
     40 // such changes and often makes manual changes to maintain good
     41 // visual quality to changes where butteraugli was not sufficiently
     42 // sensitive to some kind of degradation. Unfortunately image quality
     43 // is still more of an art than science.
     44 
     45 // Set JXL_DEBUG_AC_STRATEGY to 1 to enable debugging.
     46 #ifndef JXL_DEBUG_AC_STRATEGY
     47 #define JXL_DEBUG_AC_STRATEGY 0
     48 #endif
     49 
     50 // This must come before the begin/end_target, but HWY_ONCE is only true
     51 // after that, so use an "include guard".
     52 #ifndef LIB_JXL_ENC_AC_STRATEGY_
     53 #define LIB_JXL_ENC_AC_STRATEGY_
     54 // Parameters of the heuristic are marked with an OPTIMIZE comment.
     55 namespace jxl {
     56 namespace {
     57 
     58 // Debugging utilities.
     59 
     60 // Returns a linear sRGB color (as bytes) for each AC strategy.
     61 const uint8_t* TypeColor(const uint8_t& raw_strategy) {
     62   JXL_ASSERT(AcStrategy::IsRawStrategyValid(raw_strategy));
     63   static_assert(AcStrategy::kNumValidStrategies == 27, "Change colors");
     64   static constexpr uint8_t kColors[][3] = {
     65       {0xFF, 0xFF, 0x00},  // DCT8
     66       {0xFF, 0x80, 0x80},  // HORNUSS
     67       {0xFF, 0x80, 0x80},  // DCT2x2
     68       {0xFF, 0x80, 0x80},  // DCT4x4
     69       {0x80, 0xFF, 0x00},  // DCT16x16
     70       {0x00, 0xC0, 0x00},  // DCT32x32
     71       {0xC0, 0xFF, 0x00},  // DCT16x8
     72       {0xC0, 0xFF, 0x00},  // DCT8x16
     73       {0x00, 0xFF, 0x00},  // DCT32x8
     74       {0x00, 0xFF, 0x00},  // DCT8x32
     75       {0x00, 0xFF, 0x00},  // DCT32x16
     76       {0x00, 0xFF, 0x00},  // DCT16x32
     77       {0xFF, 0x80, 0x00},  // DCT4x8
     78       {0xFF, 0x80, 0x00},  // DCT8x4
     79       {0xFF, 0xFF, 0x80},  // AFV0
     80       {0xFF, 0xFF, 0x80},  // AFV1
     81       {0xFF, 0xFF, 0x80},  // AFV2
     82       {0xFF, 0xFF, 0x80},  // AFV3
     83       {0x00, 0xC0, 0xFF},  // DCT64x64
     84       {0x00, 0xFF, 0xFF},  // DCT64x32
     85       {0x00, 0xFF, 0xFF},  // DCT32x64
     86       {0x00, 0x40, 0xFF},  // DCT128x128
     87       {0x00, 0x80, 0xFF},  // DCT128x64
     88       {0x00, 0x80, 0xFF},  // DCT64x128
     89       {0x00, 0x00, 0xC0},  // DCT256x256
     90       {0x00, 0x00, 0xFF},  // DCT256x128
     91       {0x00, 0x00, 0xFF},  // DCT128x256
     92   };
     93   return kColors[raw_strategy];
     94 }
     95 
     96 const uint8_t* TypeMask(const uint8_t& raw_strategy) {
     97   JXL_ASSERT(AcStrategy::IsRawStrategyValid(raw_strategy));
     98   static_assert(AcStrategy::kNumValidStrategies == 27, "Add masks");
     99   // implicitly, first row and column is made dark
    100   static constexpr uint8_t kMask[][64] = {
    101       {
    102           0, 0, 0, 0, 0, 0, 0, 0,  //
    103           0, 0, 0, 0, 0, 0, 0, 0,  //
    104           0, 0, 0, 0, 0, 0, 0, 0,  //
    105           0, 0, 0, 0, 0, 0, 0, 0,  //
    106           0, 0, 0, 0, 0, 0, 0, 0,  //
    107           0, 0, 0, 0, 0, 0, 0, 0,  //
    108           0, 0, 0, 0, 0, 0, 0, 0,  //
    109           0, 0, 0, 0, 0, 0, 0, 0,  //
    110       },                           // DCT8
    111       {
    112           0, 0, 0, 0, 0, 0, 0, 0,  //
    113           0, 0, 0, 0, 0, 0, 0, 0,  //
    114           0, 0, 1, 0, 0, 1, 0, 0,  //
    115           0, 0, 1, 0, 0, 1, 0, 0,  //
    116           0, 0, 1, 1, 1, 1, 0, 0,  //
    117           0, 0, 1, 0, 0, 1, 0, 0,  //
    118           0, 0, 1, 0, 0, 1, 0, 0,  //
    119           0, 0, 0, 0, 0, 0, 0, 0,  //
    120       },                           // HORNUSS
    121       {
    122           1, 1, 1, 1, 1, 1, 1, 1,  //
    123           1, 0, 1, 0, 1, 0, 1, 0,  //
    124           1, 1, 1, 1, 1, 1, 1, 1,  //
    125           1, 0, 1, 0, 1, 0, 1, 0,  //
    126           1, 1, 1, 1, 1, 1, 1, 1,  //
    127           1, 0, 1, 0, 1, 0, 1, 0,  //
    128           1, 1, 1, 1, 1, 1, 1, 1,  //
    129           1, 0, 1, 0, 1, 0, 1, 0,  //
    130       },                           // 2x2
    131       {
    132           0, 0, 0, 0, 1, 0, 0, 0,  //
    133           0, 0, 0, 0, 1, 0, 0, 0,  //
    134           0, 0, 0, 0, 1, 0, 0, 0,  //
    135           0, 0, 0, 0, 1, 0, 0, 0,  //
    136           1, 1, 1, 1, 1, 1, 1, 1,  //
    137           0, 0, 0, 0, 1, 0, 0, 0,  //
    138           0, 0, 0, 0, 1, 0, 0, 0,  //
    139           0, 0, 0, 0, 1, 0, 0, 0,  //
    140       },                           // 4x4
    141       {},                          // DCT16x16 (unused)
    142       {},                          // DCT32x32 (unused)
    143       {},                          // DCT16x8 (unused)
    144       {},                          // DCT8x16 (unused)
    145       {},                          // DCT32x8 (unused)
    146       {},                          // DCT8x32 (unused)
    147       {},                          // DCT32x16 (unused)
    148       {},                          // DCT16x32 (unused)
    149       {
    150           0, 0, 0, 0, 0, 0, 0, 0,  //
    151           0, 0, 0, 0, 0, 0, 0, 0,  //
    152           0, 0, 0, 0, 0, 0, 0, 0,  //
    153           0, 0, 0, 0, 0, 0, 0, 0,  //
    154           1, 1, 1, 1, 1, 1, 1, 1,  //
    155           0, 0, 0, 0, 0, 0, 0, 0,  //
    156           0, 0, 0, 0, 0, 0, 0, 0,  //
    157           0, 0, 0, 0, 0, 0, 0, 0,  //
    158       },                           // DCT4x8
    159       {
    160           0, 0, 0, 0, 1, 0, 0, 0,  //
    161           0, 0, 0, 0, 1, 0, 0, 0,  //
    162           0, 0, 0, 0, 1, 0, 0, 0,  //
    163           0, 0, 0, 0, 1, 0, 0, 0,  //
    164           0, 0, 0, 0, 1, 0, 0, 0,  //
    165           0, 0, 0, 0, 1, 0, 0, 0,  //
    166           0, 0, 0, 0, 1, 0, 0, 0,  //
    167           0, 0, 0, 0, 1, 0, 0, 0,  //
    168       },                           // DCT8x4
    169       {
    170           1, 1, 1, 1, 1, 0, 0, 0,  //
    171           1, 1, 1, 1, 0, 0, 0, 0,  //
    172           1, 1, 1, 0, 0, 0, 0, 0,  //
    173           1, 1, 0, 0, 0, 0, 0, 0,  //
    174           1, 0, 0, 0, 0, 0, 0, 0,  //
    175           0, 0, 0, 0, 0, 0, 0, 0,  //
    176           0, 0, 0, 0, 0, 0, 0, 0,  //
    177           0, 0, 0, 0, 0, 0, 0, 0,  //
    178       },                           // AFV0
    179       {
    180           0, 0, 0, 0, 1, 1, 1, 1,  //
    181           0, 0, 0, 0, 0, 1, 1, 1,  //
    182           0, 0, 0, 0, 0, 0, 1, 1,  //
    183           0, 0, 0, 0, 0, 0, 0, 1,  //
    184           0, 0, 0, 0, 0, 0, 0, 0,  //
    185           0, 0, 0, 0, 0, 0, 0, 0,  //
    186           0, 0, 0, 0, 0, 0, 0, 0,  //
    187           0, 0, 0, 0, 0, 0, 0, 0,  //
    188       },                           // AFV1
    189       {
    190           0, 0, 0, 0, 0, 0, 0, 0,  //
    191           0, 0, 0, 0, 0, 0, 0, 0,  //
    192           0, 0, 0, 0, 0, 0, 0, 0,  //
    193           0, 0, 0, 0, 0, 0, 0, 0,  //
    194           1, 0, 0, 0, 0, 0, 0, 0,  //
    195           1, 1, 0, 0, 0, 0, 0, 0,  //
    196           1, 1, 1, 0, 0, 0, 0, 0,  //
    197           1, 1, 1, 1, 0, 0, 0, 0,  //
    198       },                           // AFV2
    199       {
    200           0, 0, 0, 0, 0, 0, 0, 0,  //
    201           0, 0, 0, 0, 0, 0, 0, 0,  //
    202           0, 0, 0, 0, 0, 0, 0, 0,  //
    203           0, 0, 0, 0, 0, 0, 0, 0,  //
    204           0, 0, 0, 0, 0, 0, 0, 0,  //
    205           0, 0, 0, 0, 0, 0, 0, 1,  //
    206           0, 0, 0, 0, 0, 0, 1, 1,  //
    207           0, 0, 0, 0, 0, 1, 1, 1,  //
    208       },                           // AFV3
    209   };
    210   return kMask[raw_strategy];
    211 }
    212 
    213 Status DumpAcStrategy(const AcStrategyImage& ac_strategy, size_t xsize,
    214                       size_t ysize, const char* tag, AuxOut* aux_out,
    215                       const CompressParams& cparams) {
    216   JXL_ASSIGN_OR_RETURN(Image3F color_acs, Image3F::Create(xsize, ysize));
    217   for (size_t y = 0; y < ysize; y++) {
    218     float* JXL_RESTRICT rows[3] = {
    219         color_acs.PlaneRow(0, y),
    220         color_acs.PlaneRow(1, y),
    221         color_acs.PlaneRow(2, y),
    222     };
    223     const AcStrategyRow acs_row = ac_strategy.ConstRow(y / kBlockDim);
    224     for (size_t x = 0; x < xsize; x++) {
    225       AcStrategy acs = acs_row[x / kBlockDim];
    226       const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy());
    227       for (size_t c = 0; c < 3; c++) {
    228         rows[c][x] = color[c] / 255.f;
    229       }
    230     }
    231   }
    232   size_t stride = color_acs.PixelsPerRow();
    233   for (size_t c = 0; c < 3; c++) {
    234     for (size_t by = 0; by < DivCeil(ysize, kBlockDim); by++) {
    235       float* JXL_RESTRICT row = color_acs.PlaneRow(c, by * kBlockDim);
    236       const AcStrategyRow acs_row = ac_strategy.ConstRow(by);
    237       for (size_t bx = 0; bx < DivCeil(xsize, kBlockDim); bx++) {
    238         AcStrategy acs = acs_row[bx];
    239         if (!acs.IsFirstBlock()) continue;
    240         const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy());
    241         const uint8_t* JXL_RESTRICT mask = TypeMask(acs.RawStrategy());
    242         if (acs.covered_blocks_x() == 1 && acs.covered_blocks_y() == 1) {
    243           for (size_t iy = 0; iy < kBlockDim && by * kBlockDim + iy < ysize;
    244                iy++) {
    245             for (size_t ix = 0; ix < kBlockDim && bx * kBlockDim + ix < xsize;
    246                  ix++) {
    247               if (mask[iy * kBlockDim + ix]) {
    248                 row[iy * stride + bx * kBlockDim + ix] = color[c] / 800.f;
    249               }
    250             }
    251           }
    252         }
    253         // draw block edges
    254         for (size_t ix = 0; ix < kBlockDim * acs.covered_blocks_x() &&
    255                             bx * kBlockDim + ix < xsize;
    256              ix++) {
    257           row[0 * stride + bx * kBlockDim + ix] = color[c] / 350.f;
    258         }
    259         for (size_t iy = 0; iy < kBlockDim * acs.covered_blocks_y() &&
    260                             by * kBlockDim + iy < ysize;
    261              iy++) {
    262           row[iy * stride + bx * kBlockDim + 0] = color[c] / 350.f;
    263         }
    264       }
    265     }
    266   }
    267   return DumpImage(cparams, tag, color_acs);
    268 }
    269 
    270 }  // namespace
    271 }  // namespace jxl
    272 #endif  // LIB_JXL_ENC_AC_STRATEGY_
    273 
    274 HWY_BEFORE_NAMESPACE();
    275 namespace jxl {
    276 namespace HWY_NAMESPACE {
    277 
    278 // These templates are not found via ADL.
    279 using hwy::HWY_NAMESPACE::AbsDiff;
    280 using hwy::HWY_NAMESPACE::Eq;
    281 using hwy::HWY_NAMESPACE::IfThenElseZero;
    282 using hwy::HWY_NAMESPACE::IfThenZeroElse;
    283 using hwy::HWY_NAMESPACE::Round;
    284 using hwy::HWY_NAMESPACE::Sqrt;
    285 
    286 bool MultiBlockTransformCrossesHorizontalBoundary(
    287     const AcStrategyImage& ac_strategy, size_t start_x, size_t y,
    288     size_t end_x) {
    289   if (start_x >= ac_strategy.xsize() || y >= ac_strategy.ysize()) {
    290     return false;
    291   }
    292   if (y % 8 == 0) {
    293     // Nothing crosses 64x64 boundaries, and the memory on the other side
    294     // of the 64x64 block may still uninitialized.
    295     return false;
    296   }
    297   end_x = std::min(end_x, ac_strategy.xsize());
    298   // The first multiblock might be before the start_x, let's adjust it
    299   // to point to the first IsFirstBlock() == true block we find by backward
    300   // tracing.
    301   AcStrategyRow row = ac_strategy.ConstRow(y);
    302   const size_t start_x_limit = start_x & ~7;
    303   while (start_x != start_x_limit && !row[start_x].IsFirstBlock()) {
    304     --start_x;
    305   }
    306   for (size_t x = start_x; x < end_x;) {
    307     if (row[x].IsFirstBlock()) {
    308       x += row[x].covered_blocks_x();
    309     } else {
    310       return true;
    311     }
    312   }
    313   return false;
    314 }
    315 
    316 bool MultiBlockTransformCrossesVerticalBoundary(
    317     const AcStrategyImage& ac_strategy, size_t x, size_t start_y,
    318     size_t end_y) {
    319   if (x >= ac_strategy.xsize() || start_y >= ac_strategy.ysize()) {
    320     return false;
    321   }
    322   if (x % 8 == 0) {
    323     // Nothing crosses 64x64 boundaries, and the memory on the other side
    324     // of the 64x64 block may still uninitialized.
    325     return false;
    326   }
    327   end_y = std::min(end_y, ac_strategy.ysize());
    328   // The first multiblock might be before the start_y, let's adjust it
    329   // to point to the first IsFirstBlock() == true block we find by backward
    330   // tracing.
    331   const size_t start_y_limit = start_y & ~7;
    332   while (start_y != start_y_limit &&
    333          !ac_strategy.ConstRow(start_y)[x].IsFirstBlock()) {
    334     --start_y;
    335   }
    336 
    337   for (size_t y = start_y; y < end_y;) {
    338     AcStrategyRow row = ac_strategy.ConstRow(y);
    339     if (row[x].IsFirstBlock()) {
    340       y += row[x].covered_blocks_y();
    341     } else {
    342       return true;
    343     }
    344   }
    345   return false;
    346 }
    347 
    348 float EstimateEntropy(const AcStrategy& acs, float entropy_mul, size_t x,
    349                       size_t y, const ACSConfig& config,
    350                       const float* JXL_RESTRICT cmap_factors, float* block,
    351                       float* full_scratch_space, uint32_t* quantized) {
    352   float* mem = full_scratch_space;
    353   float* scratch_space = full_scratch_space + AcStrategy::kMaxCoeffArea;
    354   const size_t size = (1 << acs.log2_covered_blocks()) * kDCTBlockSize;
    355 
    356   // Apply transform.
    357   for (size_t c = 0; c < 3; c++) {
    358     float* JXL_RESTRICT block_c = block + size * c;
    359     TransformFromPixels(acs.Strategy(), &config.Pixel(c, x, y),
    360                         config.src_stride, block_c, scratch_space);
    361   }
    362   HWY_FULL(float) df;
    363 
    364   const size_t num_blocks = acs.covered_blocks_x() * acs.covered_blocks_y();
    365   // avoid large blocks when there is a lot going on in red-green.
    366   float quant_norm16 = 0;
    367   if (num_blocks == 1) {
    368     // When it is only one 8x8, we don't need aggregation of values.
    369     quant_norm16 = config.Quant(x / 8, y / 8);
    370   } else if (num_blocks == 2) {
    371     // Taking max instead of 8th norm seems to work
    372     // better for smallest blocks up to 16x8. Jyrki couldn't get
    373     // improvements in trying the same for 16x16 blocks.
    374     if (acs.covered_blocks_y() == 2) {
    375       quant_norm16 =
    376           std::max(config.Quant(x / 8, y / 8), config.Quant(x / 8, y / 8 + 1));
    377     } else {
    378       quant_norm16 =
    379           std::max(config.Quant(x / 8, y / 8), config.Quant(x / 8 + 1, y / 8));
    380     }
    381   } else {
    382     // Load QF value, calculate empirical heuristic on masking field
    383     // for weighting the information loss. Information loss manifests
    384     // itself as ringing, and masking could hide it.
    385     for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
    386       for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
    387         float qval = config.Quant(x / 8 + ix, y / 8 + iy);
    388         qval *= qval;
    389         qval *= qval;
    390         qval *= qval;
    391         quant_norm16 += qval * qval;
    392       }
    393     }
    394     quant_norm16 /= num_blocks;
    395     quant_norm16 = FastPowf(quant_norm16, 1.0f / 16.0f);
    396   }
    397   const auto quant = Set(df, quant_norm16);
    398 
    399   // Compute entropy.
    400   float entropy = 0.0f;
    401   const HWY_CAPPED(float, 8) df8;
    402 
    403   auto loss = Zero(df8);
    404   for (size_t c = 0; c < 3; c++) {
    405     const float* inv_matrix = config.dequant->InvMatrix(acs.RawStrategy(), c);
    406     const float* matrix = config.dequant->Matrix(acs.RawStrategy(), c);
    407     const auto cmap_factor = Set(df, cmap_factors[c]);
    408 
    409     auto entropy_v = Zero(df);
    410     auto nzeros_v = Zero(df);
    411     for (size_t i = 0; i < num_blocks * kDCTBlockSize; i += Lanes(df)) {
    412       const auto in = Load(df, block + c * size + i);
    413       const auto in_y = Mul(Load(df, block + size + i), cmap_factor);
    414       const auto im = Load(df, inv_matrix + i);
    415       const auto val = Mul(Sub(in, in_y), Mul(im, quant));
    416       const auto rval = Round(val);
    417       const auto diff = Sub(val, rval);
    418       const auto m = Load(df, matrix + i);
    419       Store(Mul(m, diff), df, &mem[i]);
    420       const auto q = Abs(rval);
    421       const auto q_is_zero = Eq(q, Zero(df));
    422       // We used to have q * C here, but that cost model seems to
    423       // be punishing large values more than necessary. Sqrt tries
    424       // to avoid large values less aggressively.
    425       entropy_v = Add(Sqrt(q), entropy_v);
    426       nzeros_v = Add(nzeros_v, IfThenZeroElse(q_is_zero, Set(df, 1.0f)));
    427     }
    428 
    429     {
    430       auto lossc = Zero(df8);
    431       TransformToPixels(acs.Strategy(), &mem[0], block,
    432                         acs.covered_blocks_x() * 8, scratch_space);
    433 
    434       for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
    435         for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
    436           for (size_t dy = 0; dy < kBlockDim; ++dy) {
    437             for (size_t dx = 0; dx < kBlockDim; dx += Lanes(df8)) {
    438               auto in = Load(df8, block +
    439                                       (iy * kBlockDim + dy) *
    440                                           (acs.covered_blocks_x() * kBlockDim) +
    441                                       ix * kBlockDim + dx);
    442               auto masku = Abs(Load(
    443                   df8, config.MaskingPtr1x1(x + ix * 8 + dx, y + iy * 8 + dy)));
    444               in = Mul(masku, in);
    445               in = Mul(in, in);
    446               in = Mul(in, in);
    447               in = Mul(in, in);
    448               lossc = Add(lossc, in);
    449             }
    450           }
    451         }
    452       }
    453       static const double kChannelMul[3] = {
    454           10.2,
    455           1.0,
    456           1.03,
    457       };
    458       lossc = Mul(Set(df8, pow(kChannelMul[c], 8.0)), lossc);
    459       loss = Add(loss, lossc);
    460     }
    461     entropy += config.cost_delta * GetLane(SumOfLanes(df, entropy_v));
    462     size_t num_nzeros = GetLane(SumOfLanes(df, nzeros_v));
    463     // Add #bit of num_nonzeros, as an estimate of the cost for encoding the
    464     // number of non-zeros of the block.
    465     size_t nbits = CeilLog2Nonzero(num_nzeros + 1) + 1;
    466     // Also add #bit of #bit of num_nonzeros, to estimate the ANS cost, with a
    467     // bias.
    468     entropy += config.zeros_mul * (CeilLog2Nonzero(nbits + 17) + nbits);
    469   }
    470   float loss_scalar =
    471       pow(GetLane(SumOfLanes(df8, loss)) / (num_blocks * kDCTBlockSize),
    472           1.0 / 8.0) *
    473       (num_blocks * kDCTBlockSize) / quant_norm16;
    474   float ret = entropy * entropy_mul;
    475   ret += config.info_loss_multiplier * loss_scalar;
    476   return ret;
    477 }
    478 
    479 uint8_t FindBest8x8Transform(size_t x, size_t y, int encoding_speed_tier,
    480                              float butteraugli_target, const ACSConfig& config,
    481                              const float* JXL_RESTRICT cmap_factors,
    482                              AcStrategyImage* JXL_RESTRICT ac_strategy,
    483                              float* block, float* scratch_space,
    484                              uint32_t* quantized, float* entropy_out) {
    485   struct TransformTry8x8 {
    486     AcStrategy::Type type;
    487     int encoding_speed_tier_max_limit;
    488     double entropy_mul;
    489   };
    490   static const TransformTry8x8 kTransforms8x8[] = {
    491       {
    492           AcStrategy::Type::DCT,
    493           9,
    494           0.8,
    495       },
    496       {
    497           AcStrategy::Type::DCT4X4,
    498           5,
    499           1.08,
    500       },
    501       {
    502           AcStrategy::Type::DCT2X2,
    503           5,
    504           0.95,
    505       },
    506       {
    507           AcStrategy::Type::DCT4X8,
    508           4,
    509           0.85931637428340035,
    510       },
    511       {
    512           AcStrategy::Type::DCT8X4,
    513           4,
    514           0.85931637428340035,
    515       },
    516       {
    517           AcStrategy::Type::IDENTITY,
    518           5,
    519           1.0427542510634957,
    520       },
    521       {
    522           AcStrategy::Type::AFV0,
    523           4,
    524           0.81779489591359944,
    525       },
    526       {
    527           AcStrategy::Type::AFV1,
    528           4,
    529           0.81779489591359944,
    530       },
    531       {
    532           AcStrategy::Type::AFV2,
    533           4,
    534           0.81779489591359944,
    535       },
    536       {
    537           AcStrategy::Type::AFV3,
    538           4,
    539           0.81779489591359944,
    540       },
    541   };
    542   double best = 1e30;
    543   uint8_t best_tx = kTransforms8x8[0].type;
    544   for (auto tx : kTransforms8x8) {
    545     if (tx.encoding_speed_tier_max_limit < encoding_speed_tier) {
    546       continue;
    547     }
    548     AcStrategy acs = AcStrategy::FromRawStrategy(tx.type);
    549     float entropy_mul = tx.entropy_mul / kTransforms8x8[0].entropy_mul;
    550     if ((tx.type == AcStrategy::Type::DCT2X2 ||
    551          tx.type == AcStrategy::Type::IDENTITY) &&
    552         butteraugli_target < 5.0) {
    553       static const float kFavor2X2AtHighQuality = 0.4;
    554       float weight = pow((5.0f - butteraugli_target) / 5.0f, 2.0);
    555       entropy_mul -= kFavor2X2AtHighQuality * weight;
    556     }
    557     if ((tx.type != AcStrategy::Type::DCT &&
    558          tx.type != AcStrategy::Type::DCT2X2 &&
    559          tx.type != AcStrategy::Type::IDENTITY) &&
    560         butteraugli_target > 4.0) {
    561       static const float kAvoidEntropyOfTransforms = 0.5;
    562       float mul = 1.0;
    563       if (butteraugli_target < 12.0) {
    564         mul *= (12.0 - 4.0) / (butteraugli_target - 4.0);
    565       }
    566       entropy_mul += kAvoidEntropyOfTransforms * mul;
    567     }
    568     float entropy =
    569         EstimateEntropy(acs, entropy_mul, x, y, config, cmap_factors, block,
    570                         scratch_space, quantized);
    571     if (entropy < best) {
    572       best_tx = tx.type;
    573       best = entropy;
    574     }
    575   }
    576   *entropy_out = best;
    577   return best_tx;
    578 }
    579 
    580 // bx, by addresses the 64x64 block at 8x8 subresolution
    581 // cx, cy addresses the left, upper 8x8 block position of the candidate
    582 // transform.
    583 void TryMergeAcs(AcStrategy::Type acs_raw, size_t bx, size_t by, size_t cx,
    584                  size_t cy, const ACSConfig& config,
    585                  const float* JXL_RESTRICT cmap_factors,
    586                  AcStrategyImage* JXL_RESTRICT ac_strategy,
    587                  const float entropy_mul, const uint8_t candidate_priority,
    588                  uint8_t* priority, float* JXL_RESTRICT entropy_estimate,
    589                  float* block, float* scratch_space, uint32_t* quantized) {
    590   AcStrategy acs = AcStrategy::FromRawStrategy(acs_raw);
    591   float entropy_current = 0;
    592   for (size_t iy = 0; iy < acs.covered_blocks_y(); ++iy) {
    593     for (size_t ix = 0; ix < acs.covered_blocks_x(); ++ix) {
    594       if (priority[(cy + iy) * 8 + (cx + ix)] >= candidate_priority) {
    595         // Transform would reuse already allocated blocks and
    596         // lead to invalid overlaps, for example DCT64X32 vs.
    597         // DCT32X64.
    598         return;
    599       }
    600       entropy_current += entropy_estimate[(cy + iy) * 8 + (cx + ix)];
    601     }
    602   }
    603   float entropy_candidate =
    604       EstimateEntropy(acs, entropy_mul, (bx + cx) * 8, (by + cy) * 8, config,
    605                       cmap_factors, block, scratch_space, quantized);
    606   if (entropy_candidate >= entropy_current) return;
    607   // Accept the candidate.
    608   for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
    609     for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
    610       entropy_estimate[(cy + iy) * 8 + cx + ix] = 0;
    611       priority[(cy + iy) * 8 + cx + ix] = candidate_priority;
    612     }
    613   }
    614   ac_strategy->Set(bx + cx, by + cy, acs_raw);
    615   entropy_estimate[cy * 8 + cx] = entropy_candidate;
    616 }
    617 
    618 static void SetEntropyForTransform(size_t cx, size_t cy,
    619                                    const AcStrategy::Type acs_raw,
    620                                    float entropy,
    621                                    float* JXL_RESTRICT entropy_estimate) {
    622   const AcStrategy acs = AcStrategy::FromRawStrategy(acs_raw);
    623   for (size_t dy = 0; dy < acs.covered_blocks_y(); ++dy) {
    624     for (size_t dx = 0; dx < acs.covered_blocks_x(); ++dx) {
    625       entropy_estimate[(cy + dy) * 8 + cx + dx] = 0.0;
    626     }
    627   }
    628   entropy_estimate[cy * 8 + cx] = entropy;
    629 }
    630 
    631 AcStrategy::Type AcsSquare(size_t blocks) {
    632   if (blocks == 2) {
    633     return AcStrategy::Type::DCT16X16;
    634   } else if (blocks == 4) {
    635     return AcStrategy::Type::DCT32X32;
    636   } else {
    637     return AcStrategy::Type::DCT64X64;
    638   }
    639 }
    640 
    641 AcStrategy::Type AcsVerticalSplit(size_t blocks) {
    642   if (blocks == 2) {
    643     return AcStrategy::Type::DCT16X8;
    644   } else if (blocks == 4) {
    645     return AcStrategy::Type::DCT32X16;
    646   } else {
    647     return AcStrategy::Type::DCT64X32;
    648   }
    649 }
    650 
    651 AcStrategy::Type AcsHorizontalSplit(size_t blocks) {
    652   if (blocks == 2) {
    653     return AcStrategy::Type::DCT8X16;
    654   } else if (blocks == 4) {
    655     return AcStrategy::Type::DCT16X32;
    656   } else {
    657     return AcStrategy::Type::DCT32X64;
    658   }
    659 }
    660 
    661 // The following function tries to merge smaller transforms into
    662 // squares and the rectangles originating from a single middle division
    663 // (horizontal or vertical) fairly.
    664 //
    665 // This is now generalized to concern about squares
    666 // of blocks X blocks size, where a block is 8x8 pixels.
    667 void FindBestFirstLevelDivisionForSquare(
    668     size_t blocks, bool allow_square_transform, size_t bx, size_t by, size_t cx,
    669     size_t cy, const ACSConfig& config, const float* JXL_RESTRICT cmap_factors,
    670     AcStrategyImage* JXL_RESTRICT ac_strategy, const float entropy_mul_JXK,
    671     const float entropy_mul_JXJ, float* JXL_RESTRICT entropy_estimate,
    672     float* block, float* scratch_space, uint32_t* quantized) {
    673   // We denote J for the larger dimension here, and K for the smaller.
    674   // For example, for 32x32 block splitting, J would be 32, K 16.
    675   const size_t blocks_half = blocks / 2;
    676   const AcStrategy::Type acs_rawJXK = AcsVerticalSplit(blocks);
    677   const AcStrategy::Type acs_rawKXJ = AcsHorizontalSplit(blocks);
    678   const AcStrategy::Type acs_rawJXJ = AcsSquare(blocks);
    679   const AcStrategy acsJXK = AcStrategy::FromRawStrategy(acs_rawJXK);
    680   const AcStrategy acsKXJ = AcStrategy::FromRawStrategy(acs_rawKXJ);
    681   const AcStrategy acsJXJ = AcStrategy::FromRawStrategy(acs_rawJXJ);
    682   AcStrategyRow row0 = ac_strategy->ConstRow(by + cy + 0);
    683   AcStrategyRow row1 = ac_strategy->ConstRow(by + cy + blocks_half);
    684   // Let's check if we can consider a JXJ block here at all.
    685   // This is not necessary in the basic use of hierarchically merging
    686   // blocks in the simplest possible way, but is needed when we try other
    687   // 'floating' options of merging, possibly after a simple hierarchical
    688   // merge has been explored.
    689   if (MultiBlockTransformCrossesHorizontalBoundary(*ac_strategy, bx + cx,
    690                                                    by + cy, bx + cx + blocks) ||
    691       MultiBlockTransformCrossesHorizontalBoundary(
    692           *ac_strategy, bx + cx, by + cy + blocks, bx + cx + blocks) ||
    693       MultiBlockTransformCrossesVerticalBoundary(*ac_strategy, bx + cx, by + cy,
    694                                                  by + cy + blocks) ||
    695       MultiBlockTransformCrossesVerticalBoundary(*ac_strategy, bx + cx + blocks,
    696                                                  by + cy, by + cy + blocks)) {
    697     return;  // not suitable for JxJ analysis, some transforms leak out.
    698   }
    699   // For floating transforms there may be
    700   // already blocks selected that make either or both JXK and
    701   // KXJ not feasible for this location.
    702   const bool allow_JXK = !MultiBlockTransformCrossesVerticalBoundary(
    703       *ac_strategy, bx + cx + blocks_half, by + cy, by + cy + blocks);
    704   const bool allow_KXJ = !MultiBlockTransformCrossesHorizontalBoundary(
    705       *ac_strategy, bx + cx, by + cy + blocks_half, bx + cx + blocks);
    706   // Current entropies aggregated on NxN resolution.
    707   float entropy[2][2] = {};
    708   for (size_t dy = 0; dy < blocks; ++dy) {
    709     for (size_t dx = 0; dx < blocks; ++dx) {
    710       entropy[dy / blocks_half][dx / blocks_half] +=
    711           entropy_estimate[(cy + dy) * 8 + (cx + dx)];
    712     }
    713   }
    714   float entropy_JXK_left = std::numeric_limits<float>::max();
    715   float entropy_JXK_right = std::numeric_limits<float>::max();
    716   float entropy_KXJ_top = std::numeric_limits<float>::max();
    717   float entropy_KXJ_bottom = std::numeric_limits<float>::max();
    718   float entropy_JXJ = std::numeric_limits<float>::max();
    719   if (allow_JXK) {
    720     if (row0[bx + cx + 0].RawStrategy() != acs_rawJXK) {
    721       entropy_JXK_left = EstimateEntropy(
    722           acsJXK, entropy_mul_JXK, (bx + cx + 0) * 8, (by + cy + 0) * 8, config,
    723           cmap_factors, block, scratch_space, quantized);
    724     }
    725     if (row0[bx + cx + blocks_half].RawStrategy() != acs_rawJXK) {
    726       entropy_JXK_right =
    727           EstimateEntropy(acsJXK, entropy_mul_JXK, (bx + cx + blocks_half) * 8,
    728                           (by + cy + 0) * 8, config, cmap_factors, block,
    729                           scratch_space, quantized);
    730     }
    731   }
    732   if (allow_KXJ) {
    733     if (row0[bx + cx].RawStrategy() != acs_rawKXJ) {
    734       entropy_KXJ_top = EstimateEntropy(
    735           acsKXJ, entropy_mul_JXK, (bx + cx + 0) * 8, (by + cy + 0) * 8, config,
    736           cmap_factors, block, scratch_space, quantized);
    737     }
    738     if (row1[bx + cx].RawStrategy() != acs_rawKXJ) {
    739       entropy_KXJ_bottom =
    740           EstimateEntropy(acsKXJ, entropy_mul_JXK, (bx + cx + 0) * 8,
    741                           (by + cy + blocks_half) * 8, config, cmap_factors,
    742                           block, scratch_space, quantized);
    743     }
    744   }
    745   if (allow_square_transform) {
    746     // We control the exploration of the square transform separately so that
    747     // we can turn it off at high decoding speeds for 32x32, but still allow
    748     // exploring 16x32 and 32x16.
    749     entropy_JXJ = EstimateEntropy(acsJXJ, entropy_mul_JXJ, (bx + cx + 0) * 8,
    750                                   (by + cy + 0) * 8, config, cmap_factors,
    751                                   block, scratch_space, quantized);
    752   }
    753 
    754   // Test if this block should have JXK or KXJ transforms,
    755   // because it can have only one or the other.
    756   float costJxN = std::min(entropy_JXK_left, entropy[0][0] + entropy[1][0]) +
    757                   std::min(entropy_JXK_right, entropy[0][1] + entropy[1][1]);
    758   float costNxJ = std::min(entropy_KXJ_top, entropy[0][0] + entropy[0][1]) +
    759                   std::min(entropy_KXJ_bottom, entropy[1][0] + entropy[1][1]);
    760   if (entropy_JXJ < costJxN && entropy_JXJ < costNxJ) {
    761     ac_strategy->Set(bx + cx, by + cy, acs_rawJXJ);
    762     SetEntropyForTransform(cx, cy, acs_rawJXJ, entropy_JXJ, entropy_estimate);
    763   } else if (costJxN < costNxJ) {
    764     if (entropy_JXK_left < entropy[0][0] + entropy[1][0]) {
    765       ac_strategy->Set(bx + cx, by + cy, acs_rawJXK);
    766       SetEntropyForTransform(cx, cy, acs_rawJXK, entropy_JXK_left,
    767                              entropy_estimate);
    768     }
    769     if (entropy_JXK_right < entropy[0][1] + entropy[1][1]) {
    770       ac_strategy->Set(bx + cx + blocks_half, by + cy, acs_rawJXK);
    771       SetEntropyForTransform(cx + blocks_half, cy, acs_rawJXK,
    772                              entropy_JXK_right, entropy_estimate);
    773     }
    774   } else {
    775     if (entropy_KXJ_top < entropy[0][0] + entropy[0][1]) {
    776       ac_strategy->Set(bx + cx, by + cy, acs_rawKXJ);
    777       SetEntropyForTransform(cx, cy, acs_rawKXJ, entropy_KXJ_top,
    778                              entropy_estimate);
    779     }
    780     if (entropy_KXJ_bottom < entropy[1][0] + entropy[1][1]) {
    781       ac_strategy->Set(bx + cx, by + cy + blocks_half, acs_rawKXJ);
    782       SetEntropyForTransform(cx, cy + blocks_half, acs_rawKXJ,
    783                              entropy_KXJ_bottom, entropy_estimate);
    784     }
    785   }
    786 }
    787 
    788 void ProcessRectACS(const CompressParams& cparams, const ACSConfig& config,
    789                     const Rect& rect, const ColorCorrelationMap& cmap,
    790                     float* JXL_RESTRICT block, uint32_t* JXL_RESTRICT quantized,
    791                     AcStrategyImage* ac_strategy) {
    792   // Main philosophy here:
    793   // 1. First find best 8x8 transform for each area.
    794   // 2. Merging them into larger transforms where possibly, but
    795   // starting from the smallest transforms (16x8 and 8x16).
    796   // Additional complication: 16x8 and 8x16 are considered
    797   // simultaneously and fairly against each other.
    798   // We are looking at 64x64 squares since the YtoX and YtoB
    799   // maps happen to be at that resolution, and having
    800   // integral transforms cross these boundaries leads to
    801   // additional complications.
    802   const float butteraugli_target = cparams.butteraugli_distance;
    803   float* JXL_RESTRICT scratch_space = block + 3 * AcStrategy::kMaxCoeffArea;
    804   size_t bx = rect.x0();
    805   size_t by = rect.y0();
    806   JXL_ASSERT(rect.xsize() <= 8);
    807   JXL_ASSERT(rect.ysize() <= 8);
    808   size_t tx = bx / kColorTileDimInBlocks;
    809   size_t ty = by / kColorTileDimInBlocks;
    810   const float cmap_factors[3] = {
    811       cmap.YtoXRatio(cmap.ytox_map.ConstRow(ty)[tx]),
    812       0.0f,
    813       cmap.YtoBRatio(cmap.ytob_map.ConstRow(ty)[tx]),
    814   };
    815   if (cparams.speed_tier > SpeedTier::kHare) return;
    816   // First compute the best 8x8 transform for each square. Later, we do not
    817   // experiment with different combinations, but only use the best of the 8x8s
    818   // when DCT8X8 is specified in the tree search.
    819   // 8x8 transforms have 10 variants, but every larger transform is just a DCT.
    820   float entropy_estimate[64] = {};
    821   // Favor all 8x8 transforms (against 16x8 and larger transforms)) at
    822   // low butteraugli_target distances.
    823   static const float k8x8mul1 = -0.4;
    824   static const float k8x8mul2 = 1.0;
    825   static const float k8x8base = 1.4;
    826   const float mul8x8 = k8x8mul2 + k8x8mul1 / (butteraugli_target + k8x8base);
    827   for (size_t iy = 0; iy < rect.ysize(); iy++) {
    828     for (size_t ix = 0; ix < rect.xsize(); ix++) {
    829       float entropy = 0.0;
    830       const uint8_t best_of_8x8s = FindBest8x8Transform(
    831           8 * (bx + ix), 8 * (by + iy), static_cast<int>(cparams.speed_tier),
    832           butteraugli_target, config, cmap_factors, ac_strategy, block,
    833           scratch_space, quantized, &entropy);
    834       ac_strategy->Set(bx + ix, by + iy,
    835                        static_cast<AcStrategy::Type>(best_of_8x8s));
    836       entropy_estimate[iy * 8 + ix] = entropy * mul8x8;
    837     }
    838   }
    839   // Merge when a larger transform is better than the previously
    840   // searched best combination of 8x8 transforms.
    841   struct MergeTry {
    842     AcStrategy::Type type;
    843     uint8_t priority;
    844     uint8_t decoding_speed_tier_max_limit;
    845     uint8_t encoding_speed_tier_max_limit;
    846     float entropy_mul;
    847   };
    848   // These numbers need to be figured out manually and looking at
    849   // ringing next to sky etc. Optimization will find larger numbers
    850   // and produce more ringing than is ideal. Larger numbers will
    851   // help stop ringing.
    852   const float entropy_mul16X8 = 1.25;
    853   const float entropy_mul16X16 = 1.35;
    854   const float entropy_mul16X32 = 1.5;
    855   const float entropy_mul32X32 = 1.5;
    856   const float entropy_mul64X32 = 2.26;
    857   const float entropy_mul64X64 = 2.26;
    858   // TODO(jyrki): Consider this feedback in further changes:
    859   // Also effectively when the multipliers for smaller blocks are
    860   // below 1, this raises the bar for the bigger blocks even higher
    861   // in that sense these constants are not independent (e.g. changing
    862   // the constant for DCT16x32 by -5% (making it more likely) also
    863   // means that DCT32x32 becomes harder to do when starting from
    864   // two DCT16x32s). It might be better to make them more independent,
    865   // e.g. by not applying the multiplier when storing the new entropy
    866   // estimates in TryMergeToACSCandidate().
    867   const MergeTry kTransformsForMerge[9] = {
    868       {AcStrategy::Type::DCT16X8, 2, 4, 5, entropy_mul16X8},
    869       {AcStrategy::Type::DCT8X16, 2, 4, 5, entropy_mul16X8},
    870       // FindBestFirstLevelDivisionForSquare looks for DCT16X16 and its
    871       // subdivisions. {AcStrategy::Type::DCT16X16, 3, entropy_mul16X16},
    872       {AcStrategy::Type::DCT16X32, 4, 4, 4, entropy_mul16X32},
    873       {AcStrategy::Type::DCT32X16, 4, 4, 4, entropy_mul16X32},
    874       // FindBestFirstLevelDivisionForSquare looks for DCT32X32 and its
    875       // subdivisions. {AcStrategy::Type::DCT32X32, 5, 1, 5,
    876       // 0.9822994906548809f},
    877       {AcStrategy::Type::DCT64X32, 6, 1, 3, entropy_mul64X32},
    878       {AcStrategy::Type::DCT32X64, 6, 1, 3, entropy_mul64X32},
    879       // {AcStrategy::Type::DCT64X64, 8, 1, 3, 2.0846542128012948f},
    880   };
    881   /*
    882   These sizes not yet included in merge heuristic:
    883   set(AcStrategy::Type::DCT32X8, 0.0f, 2.261390410971102f);
    884   set(AcStrategy::Type::DCT8X32, 0.0f, 2.261390410971102f);
    885   set(AcStrategy::Type::DCT128X128, 0.0f, 1.0f);
    886   set(AcStrategy::Type::DCT128X64, 0.0f, 0.73f);
    887   set(AcStrategy::Type::DCT64X128, 0.0f, 0.73f);
    888   set(AcStrategy::Type::DCT256X256, 0.0f, 1.0f);
    889   set(AcStrategy::Type::DCT256X128, 0.0f, 0.73f);
    890   set(AcStrategy::Type::DCT128X256, 0.0f, 0.73f);
    891   */
    892 
    893   // Priority is a tricky kludge to avoid collisions so that transforms
    894   // don't overlap.
    895   uint8_t priority[64] = {};
    896   bool enable_32x32 = cparams.decoding_speed_tier < 4;
    897   for (auto tx : kTransformsForMerge) {
    898     if (tx.decoding_speed_tier_max_limit < cparams.decoding_speed_tier) {
    899       continue;
    900     }
    901     AcStrategy acs = AcStrategy::FromRawStrategy(tx.type);
    902 
    903     for (size_t cy = 0; cy + acs.covered_blocks_y() - 1 < rect.ysize();
    904          cy += acs.covered_blocks_y()) {
    905       for (size_t cx = 0; cx + acs.covered_blocks_x() - 1 < rect.xsize();
    906            cx += acs.covered_blocks_x()) {
    907         if (cy + 7 < rect.ysize() && cx + 7 < rect.xsize()) {
    908           if (cparams.decoding_speed_tier < 4 &&
    909               tx.type == AcStrategy::Type::DCT32X64) {
    910             // We handle both DCT8X16 and DCT16X8 at the same time.
    911             if ((cy | cx) % 8 == 0) {
    912               FindBestFirstLevelDivisionForSquare(
    913                   8, true, bx, by, cx, cy, config, cmap_factors, ac_strategy,
    914                   tx.entropy_mul, entropy_mul64X64, entropy_estimate, block,
    915                   scratch_space, quantized);
    916             }
    917             continue;
    918           } else if (tx.type == AcStrategy::Type::DCT32X16) {
    919             // We handled both DCT8X16 and DCT16X8 at the same time,
    920             // and that is above. The last column and last row,
    921             // when the last column or last row is odd numbered,
    922             // are still handled by TryMergeAcs.
    923             continue;
    924           }
    925         }
    926         if ((tx.type == AcStrategy::Type::DCT16X32 && cy % 4 != 0) ||
    927             (tx.type == AcStrategy::Type::DCT32X16 && cx % 4 != 0)) {
    928           // already covered by FindBest32X32
    929           continue;
    930         }
    931 
    932         if (cy + 3 < rect.ysize() && cx + 3 < rect.xsize()) {
    933           if (tx.type == AcStrategy::Type::DCT16X32) {
    934             // We handle both DCT8X16 and DCT16X8 at the same time.
    935             if ((cy | cx) % 4 == 0) {
    936               FindBestFirstLevelDivisionForSquare(
    937                   4, enable_32x32, bx, by, cx, cy, config, cmap_factors,
    938                   ac_strategy, tx.entropy_mul, entropy_mul32X32,
    939                   entropy_estimate, block, scratch_space, quantized);
    940             }
    941             continue;
    942           } else if (tx.type == AcStrategy::Type::DCT32X16) {
    943             // We handled both DCT8X16 and DCT16X8 at the same time,
    944             // and that is above. The last column and last row,
    945             // when the last column or last row is odd numbered,
    946             // are still handled by TryMergeAcs.
    947             continue;
    948           }
    949         }
    950         if ((tx.type == AcStrategy::Type::DCT16X32 && cy % 4 != 0) ||
    951             (tx.type == AcStrategy::Type::DCT32X16 && cx % 4 != 0)) {
    952           // already covered by FindBest32X32
    953           continue;
    954         }
    955         if (cy + 1 < rect.ysize() && cx + 1 < rect.xsize()) {
    956           if (tx.type == AcStrategy::Type::DCT8X16) {
    957             // We handle both DCT8X16 and DCT16X8 at the same time.
    958             if ((cy | cx) % 2 == 0) {
    959               FindBestFirstLevelDivisionForSquare(
    960                   2, true, bx, by, cx, cy, config, cmap_factors, ac_strategy,
    961                   tx.entropy_mul, entropy_mul16X16, entropy_estimate, block,
    962                   scratch_space, quantized);
    963             }
    964             continue;
    965           } else if (tx.type == AcStrategy::Type::DCT16X8) {
    966             // We handled both DCT8X16 and DCT16X8 at the same time,
    967             // and that is above. The last column and last row,
    968             // when the last column or last row is odd numbered,
    969             // are still handled by TryMergeAcs.
    970             continue;
    971           }
    972         }
    973         if ((tx.type == AcStrategy::Type::DCT8X16 && cy % 2 == 1) ||
    974             (tx.type == AcStrategy::Type::DCT16X8 && cx % 2 == 1)) {
    975           // already covered by FindBestFirstLevelDivisionForSquare
    976           continue;
    977         }
    978         // All other merge sizes are handled here.
    979         // Some of the DCT16X8s and DCT8X16s will still leak through here
    980         // when there is an odd number of 8x8 blocks, then the last row
    981         // and column will get their DCT16X8s and DCT8X16s through the
    982         // normal integral transform merging process.
    983         TryMergeAcs(tx.type, bx, by, cx, cy, config, cmap_factors, ac_strategy,
    984                     tx.entropy_mul, tx.priority, &priority[0], entropy_estimate,
    985                     block, scratch_space, quantized);
    986       }
    987     }
    988   }
    989   if (cparams.speed_tier >= SpeedTier::kHare) {
    990     return;
    991   }
    992   // Here we still try to do some non-aligned matching, find a few more
    993   // 16X8, 8X16 and 16X16s between the non-2-aligned blocks.
    994   for (size_t cy = 0; cy + 1 < rect.ysize(); ++cy) {
    995     for (size_t cx = 0; cx + 1 < rect.xsize(); ++cx) {
    996       if ((cy | cx) % 2 != 0) {
    997         FindBestFirstLevelDivisionForSquare(
    998             2, true, bx, by, cx, cy, config, cmap_factors, ac_strategy,
    999             entropy_mul16X8, entropy_mul16X16, entropy_estimate, block,
   1000             scratch_space, quantized);
   1001       }
   1002     }
   1003   }
   1004   // Non-aligned matching for 32X32, 16X32 and 32X16.
   1005   size_t step = cparams.speed_tier >= SpeedTier::kTortoise ? 2 : 1;
   1006   for (size_t cy = 0; cy + 3 < rect.ysize(); cy += step) {
   1007     for (size_t cx = 0; cx + 3 < rect.xsize(); cx += step) {
   1008       if ((cy | cx) % 4 == 0) {
   1009         continue;  // Already tried with loop above (DCT16X32 case).
   1010       }
   1011       FindBestFirstLevelDivisionForSquare(
   1012           4, enable_32x32, bx, by, cx, cy, config, cmap_factors, ac_strategy,
   1013           entropy_mul16X32, entropy_mul32X32, entropy_estimate, block,
   1014           scratch_space, quantized);
   1015     }
   1016   }
   1017 }
   1018 
   1019 // NOLINTNEXTLINE(google-readability-namespace-comments)
   1020 }  // namespace HWY_NAMESPACE
   1021 }  // namespace jxl
   1022 HWY_AFTER_NAMESPACE();
   1023 
   1024 #if HWY_ONCE
   1025 namespace jxl {
   1026 HWY_EXPORT(ProcessRectACS);
   1027 
   1028 void AcStrategyHeuristics::Init(const Image3F& src, const Rect& rect_in,
   1029                                 const ImageF& quant_field, const ImageF& mask,
   1030                                 const ImageF& mask1x1,
   1031                                 DequantMatrices* matrices) {
   1032   config.dequant = matrices;
   1033 
   1034   if (cparams.speed_tier >= SpeedTier::kCheetah) {
   1035     JXL_CHECK(matrices->EnsureComputed(1));  // DCT8 only
   1036   } else {
   1037     uint32_t acs_mask = 0;
   1038     // All transforms up to 64x64.
   1039     for (size_t i = 0; i < AcStrategy::DCT128X128; i++) {
   1040       acs_mask |= (1 << i);
   1041     }
   1042     JXL_CHECK(matrices->EnsureComputed(acs_mask));
   1043   }
   1044 
   1045   // Image row pointers and strides.
   1046   config.quant_field_row = quant_field.Row(0);
   1047   config.quant_field_stride = quant_field.PixelsPerRow();
   1048   if (mask.xsize() > 0 && mask.ysize() > 0) {
   1049     config.masking_field_row = mask.Row(0);
   1050     config.masking_field_stride = mask.PixelsPerRow();
   1051   }
   1052   if (mask1x1.xsize() > 0 && mask1x1.ysize() > 0) {
   1053     config.masking1x1_field_row = mask1x1.Row(0);
   1054     config.masking1x1_field_stride = mask1x1.PixelsPerRow();
   1055   }
   1056 
   1057   config.src_rows[0] = rect_in.ConstPlaneRow(src, 0, 0);
   1058   config.src_rows[1] = rect_in.ConstPlaneRow(src, 1, 0);
   1059   config.src_rows[2] = rect_in.ConstPlaneRow(src, 2, 0);
   1060   config.src_stride = src.PixelsPerRow();
   1061 
   1062   // Entropy estimate is composed of two factors:
   1063   //  - estimate of the number of bits that will be used by the block
   1064   //  - information loss due to quantization
   1065   // The following constant controls the relative weights of these components.
   1066   config.info_loss_multiplier = 1.2;
   1067   config.zeros_mul = 9.3089059022677905;
   1068   config.cost_delta = 10.833273317067883;
   1069 
   1070   static const float kBias = 0.13731742964354549;
   1071   const float ratio = (cparams.butteraugli_distance + kBias) / (1.0f + kBias);
   1072 
   1073   static const float kPow1 = 0.33677806662454718;
   1074   static const float kPow2 = 0.50990926717963703;
   1075   static const float kPow3 = 0.36702940662370243;
   1076   config.info_loss_multiplier *= std::pow(ratio, kPow1);
   1077   config.zeros_mul *= std::pow(ratio, kPow2);
   1078   config.cost_delta *= std::pow(ratio, kPow3);
   1079 }
   1080 
   1081 void AcStrategyHeuristics::PrepareForThreads(std::size_t num_threads) {
   1082   const size_t dct_scratch_size =
   1083       3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim;
   1084   mem_per_thread = 6 * AcStrategy::kMaxCoeffArea + dct_scratch_size;
   1085   mem = hwy::AllocateAligned<float>(num_threads * mem_per_thread);
   1086   qmem_per_thread = AcStrategy::kMaxCoeffArea;
   1087   qmem = hwy::AllocateAligned<uint32_t>(num_threads * qmem_per_thread);
   1088 }
   1089 
   1090 void AcStrategyHeuristics::ProcessRect(const Rect& rect,
   1091                                        const ColorCorrelationMap& cmap,
   1092                                        AcStrategyImage* ac_strategy,
   1093                                        size_t thread) {
   1094   // In Falcon mode, use DCT8 everywhere and uniform quantization.
   1095   if (cparams.speed_tier >= SpeedTier::kCheetah) {
   1096     ac_strategy->FillDCT8(rect);
   1097     return;
   1098   }
   1099   HWY_DYNAMIC_DISPATCH(ProcessRectACS)
   1100   (cparams, config, rect, cmap, mem.get() + thread * mem_per_thread,
   1101    qmem.get() + thread * qmem_per_thread, ac_strategy);
   1102 }
   1103 
   1104 Status AcStrategyHeuristics::Finalize(const FrameDimensions& frame_dim,
   1105                                       const AcStrategyImage& ac_strategy,
   1106                                       AuxOut* aux_out) {
   1107   // Accounting and debug output.
   1108   if (aux_out != nullptr) {
   1109     aux_out->num_small_blocks =
   1110         ac_strategy.CountBlocks(AcStrategy::Type::IDENTITY) +
   1111         ac_strategy.CountBlocks(AcStrategy::Type::DCT2X2) +
   1112         ac_strategy.CountBlocks(AcStrategy::Type::DCT4X4);
   1113     aux_out->num_dct4x8_blocks =
   1114         ac_strategy.CountBlocks(AcStrategy::Type::DCT4X8) +
   1115         ac_strategy.CountBlocks(AcStrategy::Type::DCT8X4);
   1116     aux_out->num_afv_blocks = ac_strategy.CountBlocks(AcStrategy::Type::AFV0) +
   1117                               ac_strategy.CountBlocks(AcStrategy::Type::AFV1) +
   1118                               ac_strategy.CountBlocks(AcStrategy::Type::AFV2) +
   1119                               ac_strategy.CountBlocks(AcStrategy::Type::AFV3);
   1120     aux_out->num_dct8_blocks = ac_strategy.CountBlocks(AcStrategy::Type::DCT);
   1121     aux_out->num_dct8x16_blocks =
   1122         ac_strategy.CountBlocks(AcStrategy::Type::DCT8X16) +
   1123         ac_strategy.CountBlocks(AcStrategy::Type::DCT16X8);
   1124     aux_out->num_dct8x32_blocks =
   1125         ac_strategy.CountBlocks(AcStrategy::Type::DCT8X32) +
   1126         ac_strategy.CountBlocks(AcStrategy::Type::DCT32X8);
   1127     aux_out->num_dct16_blocks =
   1128         ac_strategy.CountBlocks(AcStrategy::Type::DCT16X16);
   1129     aux_out->num_dct16x32_blocks =
   1130         ac_strategy.CountBlocks(AcStrategy::Type::DCT16X32) +
   1131         ac_strategy.CountBlocks(AcStrategy::Type::DCT32X16);
   1132     aux_out->num_dct32_blocks =
   1133         ac_strategy.CountBlocks(AcStrategy::Type::DCT32X32);
   1134     aux_out->num_dct32x64_blocks =
   1135         ac_strategy.CountBlocks(AcStrategy::Type::DCT32X64) +
   1136         ac_strategy.CountBlocks(AcStrategy::Type::DCT64X32);
   1137     aux_out->num_dct64_blocks =
   1138         ac_strategy.CountBlocks(AcStrategy::Type::DCT64X64);
   1139   }
   1140 
   1141   // if (JXL_DEBUG_AC_STRATEGY && WantDebugOutput(aux_out)) {
   1142   if (JXL_DEBUG_AC_STRATEGY && WantDebugOutput(cparams)) {
   1143     JXL_RETURN_IF_ERROR(DumpAcStrategy(ac_strategy, frame_dim.xsize,
   1144                                        frame_dim.ysize, "ac_strategy", aux_out,
   1145                                        cparams));
   1146   }
   1147   return true;
   1148 }
   1149 
   1150 }  // namespace jxl
   1151 #endif  // HWY_ONCE