libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

encode_finish.cc (7575B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jpegli/encode_finish.h"
      7 
      8 #include <cmath>
      9 #include <limits>
     10 
     11 #include "lib/jpegli/error.h"
     12 #include "lib/jpegli/memory_manager.h"
     13 #include "lib/jpegli/quant.h"
     14 
     15 #undef HWY_TARGET_INCLUDE
     16 #define HWY_TARGET_INCLUDE "lib/jpegli/encode_finish.cc"
     17 #include <hwy/foreach_target.h>
     18 #include <hwy/highway.h>
     19 
     20 #include "lib/jpegli/dct-inl.h"
     21 
     22 HWY_BEFORE_NAMESPACE();
     23 namespace jpegli {
     24 namespace HWY_NAMESPACE {
     25 
     26 // These templates are not found via ADL.
     27 using hwy::HWY_NAMESPACE::GetLane;
     28 
     29 using D = HWY_FULL(float);
     30 using DI = HWY_FULL(int32_t);
     31 using DI16 = Rebind<int16_t, HWY_FULL(int32_t)>;
     32 
     33 void ReQuantizeBlock(int16_t* block, const float* qmc, float aq_strength,
     34                      const float* zero_bias_offset,
     35                      const float* zero_bias_mul) {
     36   D d;
     37   DI di;
     38   DI16 di16;
     39   const auto aq_mul = Set(d, aq_strength);
     40   for (size_t k = 0; k < DCTSIZE2; k += Lanes(d)) {
     41     const auto in = Load(di16, block + k);
     42     const auto val = ConvertTo(d, PromoteTo(di, in));
     43     const auto q = Load(d, qmc + k);
     44     const auto qval = Mul(val, q);
     45     const auto zb_offset = Load(d, zero_bias_offset + k);
     46     const auto zb_mul = Load(d, zero_bias_mul + k);
     47     const auto threshold = Add(zb_offset, Mul(zb_mul, aq_mul));
     48     const auto nzero_mask = Ge(Abs(qval), threshold);
     49     const auto iqval = IfThenElseZero(nzero_mask, Round(qval));
     50     Store(DemoteTo(di16, ConvertTo(di, iqval)), di16, block + k);
     51   }
     52 }
     53 
     54 float BlockError(const int16_t* block, const float* qmc, const float* iqmc,
     55                  const float aq_strength, const float* zero_bias_offset,
     56                  const float* zero_bias_mul) {
     57   D d;
     58   DI di;
     59   DI16 di16;
     60   auto err = Zero(d);
     61   const auto scale = Set(d, 1.0 / 16);
     62   const auto aq_mul = Set(d, aq_strength);
     63   for (size_t k = 0; k < DCTSIZE2; k += Lanes(d)) {
     64     const auto in = Load(di16, block + k);
     65     const auto val = ConvertTo(d, PromoteTo(di, in));
     66     const auto q = Load(d, qmc + k);
     67     const auto qval = Mul(val, q);
     68     const auto zb_offset = Load(d, zero_bias_offset + k);
     69     const auto zb_mul = Load(d, zero_bias_mul + k);
     70     const auto threshold = Add(zb_offset, Mul(zb_mul, aq_mul));
     71     const auto nzero_mask = Ge(Abs(qval), threshold);
     72     const auto iqval = IfThenElseZero(nzero_mask, Round(qval));
     73     const auto invq = Load(d, iqmc + k);
     74     const auto rval = Mul(iqval, invq);
     75     const auto diff = Mul(Sub(val, rval), scale);
     76     err = Add(err, Mul(diff, diff));
     77   }
     78   return GetLane(SumOfLanes(d, err));
     79 }
     80 
     81 void ComputeInverseWeights(const float* qmc, float* iqmc) {
     82   for (int k = 0; k < 64; ++k) {
     83     iqmc[k] = 1.0f / qmc[k];
     84   }
     85 }
     86 
     87 float ComputePSNR(j_compress_ptr cinfo, int sampling) {
     88   jpeg_comp_master* m = cinfo->master;
     89   InitQuantizer(cinfo, QuantPass::SEARCH_SECOND_PASS);
     90   double error = 0.0;
     91   size_t num = 0;
     92   for (int c = 0; c < cinfo->num_components; ++c) {
     93     jpeg_component_info* comp = &cinfo->comp_info[c];
     94     const float* qmc = m->quant_mul[c];
     95     const int h_factor = m->h_factor[c];
     96     const int v_factor = m->v_factor[c];
     97     const float* zero_bias_offset = m->zero_bias_offset[c];
     98     const float* zero_bias_mul = m->zero_bias_mul[c];
     99     HWY_ALIGN float iqmc[64];
    100     ComputeInverseWeights(qmc, iqmc);
    101     for (JDIMENSION by = 0; by < comp->height_in_blocks; by += sampling) {
    102       JBLOCKARRAY ba = GetBlockRow(cinfo, c, by);
    103       const float* qf = m->quant_field.Row(by * v_factor);
    104       for (JDIMENSION bx = 0; bx < comp->width_in_blocks; bx += sampling) {
    105         error += BlockError(&ba[0][bx][0], qmc, iqmc, qf[bx * h_factor],
    106                             zero_bias_offset, zero_bias_mul);
    107         num += DCTSIZE2;
    108       }
    109     }
    110   }
    111   return 4.3429448f * log(num / (error / 255. / 255.));
    112 }
    113 
    114 void ReQuantizeCoeffs(j_compress_ptr cinfo) {
    115   jpeg_comp_master* m = cinfo->master;
    116   InitQuantizer(cinfo, QuantPass::SEARCH_SECOND_PASS);
    117   for (int c = 0; c < cinfo->num_components; ++c) {
    118     jpeg_component_info* comp = &cinfo->comp_info[c];
    119     const float* qmc = m->quant_mul[c];
    120     const int h_factor = m->h_factor[c];
    121     const int v_factor = m->v_factor[c];
    122     const float* zero_bias_offset = m->zero_bias_offset[c];
    123     const float* zero_bias_mul = m->zero_bias_mul[c];
    124     for (JDIMENSION by = 0; by < comp->height_in_blocks; ++by) {
    125       JBLOCKARRAY ba = GetBlockRow(cinfo, c, by);
    126       const float* qf = m->quant_field.Row(by * v_factor);
    127       for (JDIMENSION bx = 0; bx < comp->width_in_blocks; ++bx) {
    128         ReQuantizeBlock(&ba[0][bx][0], qmc, qf[bx * h_factor], zero_bias_offset,
    129                         zero_bias_mul);
    130       }
    131     }
    132   }
    133 }
    134 
    135 // NOLINTNEXTLINE(google-readability-namespace-comments)
    136 }  // namespace HWY_NAMESPACE
    137 }  // namespace jpegli
    138 HWY_AFTER_NAMESPACE();
    139 
    140 #if HWY_ONCE
    141 namespace jpegli {
    142 namespace {
    143 HWY_EXPORT(ComputePSNR);
    144 HWY_EXPORT(ReQuantizeCoeffs);
    145 
    146 void ReQuantizeCoeffs(j_compress_ptr cinfo) {
    147   HWY_DYNAMIC_DISPATCH(ReQuantizeCoeffs)(cinfo);
    148 }
    149 
    150 float ComputePSNR(j_compress_ptr cinfo, int sampling) {
    151   return HWY_DYNAMIC_DISPATCH(ComputePSNR)(cinfo, sampling);
    152 }
    153 
    154 void UpdateDistance(j_compress_ptr cinfo, float distance) {
    155   float distances[NUM_QUANT_TBLS] = {distance, distance, distance};
    156   SetQuantMatrices(cinfo, distances, /*add_two_chroma_tables=*/true);
    157 }
    158 
    159 float Clamp(float val, float minval, float maxval) {
    160   return std::max(minval, std::min(maxval, val));
    161 }
    162 
    163 #define PSNR_SEARCH_DBG 0
    164 
    165 float FindDistanceForPSNR(j_compress_ptr cinfo) {
    166   constexpr int kMaxIters = 20;
    167   const float psnr_target = cinfo->master->psnr_target;
    168   const float tolerance = cinfo->master->psnr_tolerance;
    169   const float min_dist = cinfo->master->min_distance;
    170   const float max_dist = cinfo->master->max_distance;
    171   float d = Clamp(1.0f, min_dist, max_dist);
    172   for (int sampling : {4, 1}) {
    173     float best_diff = std::numeric_limits<float>::max();
    174     float best_distance = 0.0f;
    175     float best_psnr = 0.0;
    176     float dmin = min_dist;
    177     float dmax = max_dist;
    178     bool found_lower_bound = false;
    179     bool found_upper_bound = false;
    180     for (int i = 0; i < kMaxIters; ++i) {
    181       UpdateDistance(cinfo, d);
    182       float psnr = ComputePSNR(cinfo, sampling);
    183       if (psnr > psnr_target) {
    184         dmin = d;
    185         found_lower_bound = true;
    186       } else {
    187         dmax = d;
    188         found_upper_bound = true;
    189       }
    190 #if (PSNR_SEARCH_DBG > 1)
    191       printf("sampling %d iter %2d d %7.4f psnr %.2f", sampling, i, d, psnr);
    192       if (found_upper_bound && found_lower_bound) {
    193         printf("    d-interval: [ %7.4f .. %7.4f ]", dmin, dmax);
    194       }
    195       printf("\n");
    196 #endif
    197       float diff = std::abs(psnr - psnr_target);
    198       if (diff < best_diff) {
    199         best_diff = diff;
    200         best_distance = d;
    201         best_psnr = psnr;
    202       }
    203       if (diff < tolerance * psnr_target || dmin == dmax) {
    204         break;
    205       }
    206       if (!found_lower_bound || !found_upper_bound) {
    207         d *= std::exp(0.15f * (psnr - psnr_target));
    208       } else {
    209         d = 0.5f * (dmin + dmax);
    210       }
    211       d = Clamp(d, min_dist, max_dist);
    212     }
    213     d = best_distance;
    214     if (sampling == 1 && PSNR_SEARCH_DBG) {
    215       printf("Final PSNR %.2f at distance %.4f\n", best_psnr, d);
    216     } else {
    217       (void)best_psnr;
    218     }
    219   }
    220   return d;
    221 }
    222 
    223 }  // namespace
    224 
    225 void QuantizetoPSNR(j_compress_ptr cinfo) {
    226   float distance = FindDistanceForPSNR(cinfo);
    227   UpdateDistance(cinfo, distance);
    228   ReQuantizeCoeffs(cinfo);
    229 }
    230 
    231 }  // namespace jpegli
    232 #endif  // HWY_ONCE