libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_chroma_from_luma.cc (15890B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_chroma_from_luma.h"
      7 
      8 #include <float.h>
      9 #include <stdlib.h>
     10 
     11 #include <algorithm>
     12 #include <cmath>
     13 
     14 #undef HWY_TARGET_INCLUDE
     15 #define HWY_TARGET_INCLUDE "lib/jxl/enc_chroma_from_luma.cc"
     16 #include <hwy/aligned_allocator.h>
     17 #include <hwy/foreach_target.h>
     18 #include <hwy/highway.h>
     19 
     20 #include "lib/jxl/base/common.h"
     21 #include "lib/jxl/base/status.h"
     22 #include "lib/jxl/cms/opsin_params.h"
     23 #include "lib/jxl/dec_transforms-inl.h"
     24 #include "lib/jxl/enc_aux_out.h"
     25 #include "lib/jxl/enc_params.h"
     26 #include "lib/jxl/enc_transforms-inl.h"
     27 #include "lib/jxl/quantizer.h"
     28 #include "lib/jxl/simd_util.h"
     29 HWY_BEFORE_NAMESPACE();
     30 namespace jxl {
     31 namespace HWY_NAMESPACE {
     32 
     33 // These templates are not found via ADL.
     34 using hwy::HWY_NAMESPACE::Abs;
     35 using hwy::HWY_NAMESPACE::Ge;
     36 using hwy::HWY_NAMESPACE::GetLane;
     37 using hwy::HWY_NAMESPACE::IfThenElse;
     38 using hwy::HWY_NAMESPACE::Lt;
     39 
     40 static HWY_FULL(float) df;
     41 
     42 struct CFLFunction {
     43   static constexpr float kCoeff = 1.f / 3;
     44   static constexpr float kThres = 100.0f;
     45   static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor;
     46   CFLFunction(const float* values_m, const float* values_s, size_t num,
     47               float base, float distance_mul)
     48       : values_m(values_m),
     49         values_s(values_s),
     50         num(num),
     51         base(base),
     52         distance_mul(distance_mul) {}
     53 
     54   // Returns f'(x), where f is 1/3 * sum ((|color residual| + 1)^2-1) +
     55   // distance_mul * x^2 * num.
     56   float Compute(float x, float eps, float* fpeps, float* fmeps) const {
     57     float first_derivative = 2 * distance_mul * num * x;
     58     float first_derivative_peps = 2 * distance_mul * num * (x + eps);
     59     float first_derivative_meps = 2 * distance_mul * num * (x - eps);
     60 
     61     const auto inv_color_factor = Set(df, kInvColorFactor);
     62     const auto thres = Set(df, kThres);
     63     const auto coeffx2 = Set(df, kCoeff * 2.0f);
     64     const auto one = Set(df, 1.0f);
     65     const auto zero = Set(df, 0.0f);
     66     const auto base_v = Set(df, base);
     67     const auto x_v = Set(df, x);
     68     const auto xpe_v = Set(df, x + eps);
     69     const auto xme_v = Set(df, x - eps);
     70     auto fd_v = Zero(df);
     71     auto fdpe_v = Zero(df);
     72     auto fdme_v = Zero(df);
     73     JXL_ASSERT(num % Lanes(df) == 0);
     74 
     75     for (size_t i = 0; i < num; i += Lanes(df)) {
     76       // color residual = ax + b
     77       const auto a = Mul(inv_color_factor, Load(df, values_m + i));
     78       const auto b =
     79           Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i));
     80       const auto v = MulAdd(a, x_v, b);
     81       const auto vpe = MulAdd(a, xpe_v, b);
     82       const auto vme = MulAdd(a, xme_v, b);
     83       const auto av = Abs(v);
     84       const auto avpe = Abs(vpe);
     85       const auto avme = Abs(vme);
     86       const auto acoeffx2 = Mul(coeffx2, a);
     87       auto d = Mul(acoeffx2, Add(av, one));
     88       auto dpe = Mul(acoeffx2, Add(avpe, one));
     89       auto dme = Mul(acoeffx2, Add(avme, one));
     90       d = IfThenElse(Lt(v, zero), Sub(zero, d), d);
     91       dpe = IfThenElse(Lt(vpe, zero), Sub(zero, dpe), dpe);
     92       dme = IfThenElse(Lt(vme, zero), Sub(zero, dme), dme);
     93       const auto above = Ge(av, thres);
     94       // TODO(eustas): use IfThenElseZero
     95       fd_v = Add(fd_v, IfThenElse(above, zero, d));
     96       fdpe_v = Add(fdpe_v, IfThenElse(above, zero, dpe));
     97       fdme_v = Add(fdme_v, IfThenElse(above, zero, dme));
     98     }
     99 
    100     *fpeps = first_derivative_peps + GetLane(SumOfLanes(df, fdpe_v));
    101     *fmeps = first_derivative_meps + GetLane(SumOfLanes(df, fdme_v));
    102     return first_derivative + GetLane(SumOfLanes(df, fd_v));
    103   }
    104 
    105   const float* JXL_RESTRICT values_m;
    106   const float* JXL_RESTRICT values_s;
    107   size_t num;
    108   float base;
    109   float distance_mul;
    110 };
    111 
    112 // Chroma-from-luma search, values_m will have luma -- and values_s chroma.
    113 int32_t FindBestMultiplier(const float* values_m, const float* values_s,
    114                            size_t num, float base, float distance_mul,
    115                            bool fast) {
    116   if (num == 0) {
    117     return 0;
    118   }
    119   float x;
    120   if (fast) {
    121     static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor;
    122     auto ca = Zero(df);
    123     auto cb = Zero(df);
    124     const auto inv_color_factor = Set(df, kInvColorFactor);
    125     const auto base_v = Set(df, base);
    126     for (size_t i = 0; i < num; i += Lanes(df)) {
    127       // color residual = ax + b
    128       const auto a = Mul(inv_color_factor, Load(df, values_m + i));
    129       const auto b =
    130           Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i));
    131       ca = MulAdd(a, a, ca);
    132       cb = MulAdd(a, b, cb);
    133     }
    134     // + distance_mul * x^2 * num
    135     x = -GetLane(SumOfLanes(df, cb)) /
    136         (GetLane(SumOfLanes(df, ca)) + num * distance_mul * 0.5f);
    137   } else {
    138     constexpr float eps = 100;
    139     constexpr float kClamp = 20.0f;
    140     CFLFunction fn(values_m, values_s, num, base, distance_mul);
    141     x = 0;
    142     // Up to 20 Newton iterations, with approximate derivatives.
    143     // Derivatives are approximate due to the high amount of noise in the exact
    144     // derivatives.
    145     for (size_t i = 0; i < 20; i++) {
    146       float dfpeps;
    147       float dfmeps;
    148       float df = fn.Compute(x, eps, &dfpeps, &dfmeps);
    149       float ddf = (dfpeps - dfmeps) / (2 * eps);
    150       float kExperimentalInsignificantStabilizer = 0.85;
    151       float step = df / (ddf + kExperimentalInsignificantStabilizer);
    152       x -= std::min(kClamp, std::max(-kClamp, step));
    153       if (std::abs(step) < 3e-3) break;
    154     }
    155   }
    156   // CFL seems to be tricky for larger transforms for HF components
    157   // close to zero. This heuristic brings the solutions closer to zero
    158   // and reduces red-green oscillations. A better approach would
    159   // look into variance of the multiplier within separate (e.g. 8x8)
    160   // areas and only apply this heuristic where there is a high variance.
    161   // This would give about 1 % more compression density.
    162   float towards_zero = 2.6;
    163   if (x >= towards_zero) {
    164     x -= towards_zero;
    165   } else if (x <= -towards_zero) {
    166     x += towards_zero;
    167   } else {
    168     x = 0;
    169   }
    170   return std::max(-128.0f, std::min(127.0f, roundf(x)));
    171 }
    172 
    173 Status InitDCStorage(size_t num_blocks, ImageF* dc_values) {
    174   // First row: Y channel
    175   // Second row: X channel
    176   // Third row: Y channel
    177   // Fourth row: B channel
    178   JXL_ASSIGN_OR_RETURN(*dc_values,
    179                        ImageF::Create(RoundUpTo(num_blocks, Lanes(df)), 4));
    180 
    181   JXL_ASSERT(dc_values->xsize() != 0);
    182   // Zero-fill the last lanes
    183   for (size_t y = 0; y < 4; y++) {
    184     for (size_t x = dc_values->xsize() - Lanes(df); x < dc_values->xsize();
    185          x++) {
    186       dc_values->Row(y)[x] = 0;
    187     }
    188   }
    189   return true;
    190 }
    191 
    192 void ComputeTile(const Image3F& opsin, const Rect& opsin_rect,
    193                  const DequantMatrices& dequant,
    194                  const AcStrategyImage* ac_strategy,
    195                  const ImageI* raw_quant_field, const Quantizer* quantizer,
    196                  const Rect& rect, bool fast, bool use_dct8, ImageSB* map_x,
    197                  ImageSB* map_b, ImageF* dc_values, float* mem) {
    198   static_assert(kEncTileDimInBlocks == kColorTileDimInBlocks,
    199                 "Invalid color tile dim");
    200   size_t xsize_blocks = opsin_rect.xsize() / kBlockDim;
    201   constexpr float kDistanceMultiplierAC = 1e-9f;
    202   const size_t dct_scratch_size =
    203       3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim;
    204 
    205   const size_t y0 = rect.y0();
    206   const size_t x0 = rect.x0();
    207   const size_t x1 = rect.x0() + rect.xsize();
    208   const size_t y1 = rect.y0() + rect.ysize();
    209 
    210   int ty = y0 / kColorTileDimInBlocks;
    211   int tx = x0 / kColorTileDimInBlocks;
    212 
    213   int8_t* JXL_RESTRICT row_out_x = map_x->Row(ty);
    214   int8_t* JXL_RESTRICT row_out_b = map_b->Row(ty);
    215 
    216   float* JXL_RESTRICT dc_values_yx = dc_values->Row(0);
    217   float* JXL_RESTRICT dc_values_x = dc_values->Row(1);
    218   float* JXL_RESTRICT dc_values_yb = dc_values->Row(2);
    219   float* JXL_RESTRICT dc_values_b = dc_values->Row(3);
    220 
    221   // All are aligned.
    222   float* HWY_RESTRICT block_y = mem;
    223   float* HWY_RESTRICT block_x = block_y + AcStrategy::kMaxCoeffArea;
    224   float* HWY_RESTRICT block_b = block_x + AcStrategy::kMaxCoeffArea;
    225   float* HWY_RESTRICT coeffs_yx = block_b + AcStrategy::kMaxCoeffArea;
    226   float* HWY_RESTRICT coeffs_x = coeffs_yx + kColorTileDim * kColorTileDim;
    227   float* HWY_RESTRICT coeffs_yb = coeffs_x + kColorTileDim * kColorTileDim;
    228   float* HWY_RESTRICT coeffs_b = coeffs_yb + kColorTileDim * kColorTileDim;
    229   float* HWY_RESTRICT scratch_space = coeffs_b + kColorTileDim * kColorTileDim;
    230   float* scratch_space_end =
    231       scratch_space + 2 * AcStrategy::kMaxCoeffArea + dct_scratch_size;
    232   JXL_DASSERT(scratch_space_end == block_y + CfLHeuristics::ItemsPerThread());
    233   (void)scratch_space_end;
    234 
    235   // Small (~256 bytes each)
    236   HWY_ALIGN_MAX float
    237       dc_y[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {};
    238   HWY_ALIGN_MAX float
    239       dc_x[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {};
    240   HWY_ALIGN_MAX float
    241       dc_b[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {};
    242   size_t num_ac = 0;
    243 
    244   for (size_t y = y0; y < y1; ++y) {
    245     const float* JXL_RESTRICT row_y =
    246         opsin_rect.ConstPlaneRow(opsin, 1, y * kBlockDim);
    247     const float* JXL_RESTRICT row_x =
    248         opsin_rect.ConstPlaneRow(opsin, 0, y * kBlockDim);
    249     const float* JXL_RESTRICT row_b =
    250         opsin_rect.ConstPlaneRow(opsin, 2, y * kBlockDim);
    251     size_t stride = opsin.PixelsPerRow();
    252 
    253     for (size_t x = x0; x < x1; x++) {
    254       AcStrategy acs = use_dct8
    255                            ? AcStrategy::FromRawStrategy(AcStrategy::Type::DCT)
    256                            : ac_strategy->ConstRow(y)[x];
    257       if (!acs.IsFirstBlock()) continue;
    258       size_t xs = acs.covered_blocks_x();
    259       TransformFromPixels(acs.Strategy(), row_y + x * kBlockDim, stride,
    260                           block_y, scratch_space);
    261       DCFromLowestFrequencies(acs.Strategy(), block_y, dc_y, xs);
    262       TransformFromPixels(acs.Strategy(), row_x + x * kBlockDim, stride,
    263                           block_x, scratch_space);
    264       DCFromLowestFrequencies(acs.Strategy(), block_x, dc_x, xs);
    265       TransformFromPixels(acs.Strategy(), row_b + x * kBlockDim, stride,
    266                           block_b, scratch_space);
    267       DCFromLowestFrequencies(acs.Strategy(), block_b, dc_b, xs);
    268       const float* const JXL_RESTRICT qm_x =
    269           dequant.InvMatrix(acs.Strategy(), 0);
    270       const float* const JXL_RESTRICT qm_b =
    271           dequant.InvMatrix(acs.Strategy(), 2);
    272       float q_dc_x = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(0);
    273       float q_dc_b = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(2);
    274 
    275       // Copy DCs in dc_values.
    276       for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
    277         for (size_t ix = 0; ix < xs; ix++) {
    278           dc_values_yx[(iy + y) * xsize_blocks + ix + x] =
    279               dc_y[iy * xs + ix] * q_dc_x;
    280           dc_values_x[(iy + y) * xsize_blocks + ix + x] =
    281               dc_x[iy * xs + ix] * q_dc_x;
    282           dc_values_yb[(iy + y) * xsize_blocks + ix + x] =
    283               dc_y[iy * xs + ix] * q_dc_b;
    284           dc_values_b[(iy + y) * xsize_blocks + ix + x] =
    285               dc_b[iy * xs + ix] * q_dc_b;
    286         }
    287       }
    288 
    289       // Do not use this block for computing AC CfL.
    290       if (acs.covered_blocks_x() + x0 > x1 ||
    291           acs.covered_blocks_y() + y0 > y1) {
    292         continue;
    293       }
    294 
    295       // Copy AC coefficients in the local block. The order in which
    296       // coefficients get stored does not matter.
    297       size_t cx = acs.covered_blocks_x();
    298       size_t cy = acs.covered_blocks_y();
    299       CoefficientLayout(&cy, &cx);
    300       // Zero out LFs. This introduces terms in the optimization loop that
    301       // don't affect the result, as they are all 0, but allow for simpler
    302       // SIMDfication.
    303       for (size_t iy = 0; iy < cy; iy++) {
    304         for (size_t ix = 0; ix < cx; ix++) {
    305           block_y[cx * kBlockDim * iy + ix] = 0;
    306           block_x[cx * kBlockDim * iy + ix] = 0;
    307           block_b[cx * kBlockDim * iy + ix] = 0;
    308         }
    309       }
    310       // Unclear why this is like it is. (This works slightly better
    311       // than the previous approach which was also a hack.)
    312       const float qq =
    313           (raw_quant_field == nullptr) ? 1.0f : raw_quant_field->Row(y)[x];
    314       // Experimentally values 128-130 seem best -- I don't know why we
    315       // need this multiplier.
    316       const float kStrangeMultiplier = 128;
    317       float q = use_dct8 ? 1 : quantizer->Scale() * kStrangeMultiplier * qq;
    318       const auto qv = Set(df, q);
    319       for (size_t i = 0; i < cx * cy * 64; i += Lanes(df)) {
    320         const auto b_y = Load(df, block_y + i);
    321         const auto b_x = Load(df, block_x + i);
    322         const auto b_b = Load(df, block_b + i);
    323         const auto qqm_x = Mul(qv, Load(df, qm_x + i));
    324         const auto qqm_b = Mul(qv, Load(df, qm_b + i));
    325         Store(Mul(b_y, qqm_x), df, coeffs_yx + num_ac);
    326         Store(Mul(b_x, qqm_x), df, coeffs_x + num_ac);
    327         Store(Mul(b_y, qqm_b), df, coeffs_yb + num_ac);
    328         Store(Mul(b_b, qqm_b), df, coeffs_b + num_ac);
    329         num_ac += Lanes(df);
    330       }
    331     }
    332   }
    333   JXL_CHECK(num_ac % Lanes(df) == 0);
    334   row_out_x[tx] = FindBestMultiplier(coeffs_yx, coeffs_x, num_ac, 0.0f,
    335                                      kDistanceMultiplierAC, fast);
    336   row_out_b[tx] =
    337       FindBestMultiplier(coeffs_yb, coeffs_b, num_ac, jxl::cms::kYToBRatio,
    338                          kDistanceMultiplierAC, fast);
    339 }
    340 
    341 // NOLINTNEXTLINE(google-readability-namespace-comments)
    342 }  // namespace HWY_NAMESPACE
    343 }  // namespace jxl
    344 HWY_AFTER_NAMESPACE();
    345 
    346 #if HWY_ONCE
    347 namespace jxl {
    348 
    349 HWY_EXPORT(InitDCStorage);
    350 HWY_EXPORT(ComputeTile);
    351 
    352 Status CfLHeuristics::Init(const Rect& rect) {
    353   size_t xsize_blocks = rect.xsize() / kBlockDim;
    354   size_t ysize_blocks = rect.ysize() / kBlockDim;
    355   return HWY_DYNAMIC_DISPATCH(InitDCStorage)(xsize_blocks * ysize_blocks,
    356                                              &dc_values);
    357 }
    358 
    359 void CfLHeuristics::ComputeTile(const Rect& r, const Image3F& opsin,
    360                                 const Rect& opsin_rect,
    361                                 const DequantMatrices& dequant,
    362                                 const AcStrategyImage* ac_strategy,
    363                                 const ImageI* raw_quant_field,
    364                                 const Quantizer* quantizer, bool fast,
    365                                 size_t thread, ColorCorrelationMap* cmap) {
    366   bool use_dct8 = ac_strategy == nullptr;
    367   HWY_DYNAMIC_DISPATCH(ComputeTile)
    368   (opsin, opsin_rect, dequant, ac_strategy, raw_quant_field, quantizer, r, fast,
    369    use_dct8, &cmap->ytox_map, &cmap->ytob_map, &dc_values,
    370    mem.get() + thread * ItemsPerThread());
    371 }
    372 
    373 void ColorCorrelationMapEncodeDC(const ColorCorrelationMap& map,
    374                                  BitWriter* writer, size_t layer,
    375                                  AuxOut* aux_out) {
    376   float color_factor = map.GetColorFactor();
    377   float base_correlation_x = map.GetBaseCorrelationX();
    378   float base_correlation_b = map.GetBaseCorrelationB();
    379   int32_t ytox_dc = map.GetYToXDC();
    380   int32_t ytob_dc = map.GetYToBDC();
    381 
    382   BitWriter::Allotment allotment(writer, 1 + 2 * kBitsPerByte + 12 + 32);
    383   if (ytox_dc == 0 && ytob_dc == 0 && color_factor == kDefaultColorFactor &&
    384       base_correlation_x == 0.0f &&
    385       base_correlation_b == jxl::cms::kYToBRatio) {
    386     writer->Write(1, 1);
    387     allotment.ReclaimAndCharge(writer, layer, aux_out);
    388     return;
    389   }
    390   writer->Write(1, 0);
    391   JXL_CHECK(U32Coder::Write(kColorFactorDist, color_factor, writer));
    392   JXL_CHECK(F16Coder::Write(base_correlation_x, writer));
    393   JXL_CHECK(F16Coder::Write(base_correlation_b, writer));
    394   writer->Write(kBitsPerByte, ytox_dc - std::numeric_limits<int8_t>::min());
    395   writer->Write(kBitsPerByte, ytob_dc - std::numeric_limits<int8_t>::min());
    396   allotment.ReclaimAndCharge(writer, layer, aux_out);
    397 }
    398 
    399 }  // namespace jxl
    400 #endif  // HWY_ONCE