libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

compressed_dc.cc (11019B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/compressed_dc.h"
      7 
      8 #include <stdint.h>
      9 #include <stdlib.h>
     10 #include <string.h>
     11 
     12 #include <algorithm>
     13 #include <vector>
     14 
     15 #undef HWY_TARGET_INCLUDE
     16 #define HWY_TARGET_INCLUDE "lib/jxl/compressed_dc.cc"
     17 #include <hwy/aligned_allocator.h>
     18 #include <hwy/foreach_target.h>
     19 #include <hwy/highway.h>
     20 
     21 #include "lib/jxl/base/compiler_specific.h"
     22 #include "lib/jxl/base/data_parallel.h"
     23 #include "lib/jxl/base/status.h"
     24 #include "lib/jxl/image.h"
     25 HWY_BEFORE_NAMESPACE();
     26 namespace jxl {
     27 namespace HWY_NAMESPACE {
     28 
     29 using D = HWY_FULL(float);
     30 using DScalar = HWY_CAPPED(float, 1);
     31 
     32 // These templates are not found via ADL.
     33 using hwy::HWY_NAMESPACE::Abs;
     34 using hwy::HWY_NAMESPACE::Add;
     35 using hwy::HWY_NAMESPACE::Div;
     36 using hwy::HWY_NAMESPACE::Max;
     37 using hwy::HWY_NAMESPACE::Mul;
     38 using hwy::HWY_NAMESPACE::MulAdd;
     39 using hwy::HWY_NAMESPACE::Rebind;
     40 using hwy::HWY_NAMESPACE::Sub;
     41 using hwy::HWY_NAMESPACE::Vec;
     42 using hwy::HWY_NAMESPACE::ZeroIfNegative;
     43 
     44 // TODO(veluca): optimize constants.
     45 const float w1 = 0.20345139757231578f;
     46 const float w2 = 0.0334829185968739f;
     47 const float w0 = 1.0f - 4.0f * (w1 + w2);
     48 
     49 template <class V>
     50 V MaxWorkaround(V a, V b) {
     51 #if (HWY_TARGET == HWY_AVX3) && HWY_COMPILER_CLANG <= 800
     52   // Prevents "Do not know how to split the result of this operator" error
     53   return IfThenElse(a > b, a, b);
     54 #else
     55   return Max(a, b);
     56 #endif
     57 }
     58 
     59 template <typename D>
     60 JXL_INLINE void ComputePixelChannel(const D d, const float dc_factor,
     61                                     const float* JXL_RESTRICT row_top,
     62                                     const float* JXL_RESTRICT row,
     63                                     const float* JXL_RESTRICT row_bottom,
     64                                     Vec<D>* JXL_RESTRICT mc,
     65                                     Vec<D>* JXL_RESTRICT sm,
     66                                     Vec<D>* JXL_RESTRICT gap, size_t x) {
     67   const auto tl = LoadU(d, row_top + x - 1);
     68   const auto tc = Load(d, row_top + x);
     69   const auto tr = LoadU(d, row_top + x + 1);
     70 
     71   const auto ml = LoadU(d, row + x - 1);
     72   *mc = Load(d, row + x);
     73   const auto mr = LoadU(d, row + x + 1);
     74 
     75   const auto bl = LoadU(d, row_bottom + x - 1);
     76   const auto bc = Load(d, row_bottom + x);
     77   const auto br = LoadU(d, row_bottom + x + 1);
     78 
     79   const auto w_center = Set(d, w0);
     80   const auto w_side = Set(d, w1);
     81   const auto w_corner = Set(d, w2);
     82 
     83   const auto corner = Add(Add(tl, tr), Add(bl, br));
     84   const auto side = Add(Add(ml, mr), Add(tc, bc));
     85   *sm = MulAdd(corner, w_corner, MulAdd(side, w_side, Mul(*mc, w_center)));
     86 
     87   const auto dc_quant = Set(d, dc_factor);
     88   *gap = MaxWorkaround(*gap, Abs(Div(Sub(*mc, *sm), dc_quant)));
     89 }
     90 
     91 template <typename D>
     92 JXL_INLINE void ComputePixel(
     93     const float* JXL_RESTRICT dc_factors,
     94     const float* JXL_RESTRICT* JXL_RESTRICT rows_top,
     95     const float* JXL_RESTRICT* JXL_RESTRICT rows,
     96     const float* JXL_RESTRICT* JXL_RESTRICT rows_bottom,
     97     float* JXL_RESTRICT* JXL_RESTRICT out_rows, size_t x) {
     98   const D d;
     99   auto mc_x = Undefined(d);
    100   auto mc_y = Undefined(d);
    101   auto mc_b = Undefined(d);
    102   auto sm_x = Undefined(d);
    103   auto sm_y = Undefined(d);
    104   auto sm_b = Undefined(d);
    105   auto gap = Set(d, 0.5f);
    106   ComputePixelChannel(d, dc_factors[0], rows_top[0], rows[0], rows_bottom[0],
    107                       &mc_x, &sm_x, &gap, x);
    108   ComputePixelChannel(d, dc_factors[1], rows_top[1], rows[1], rows_bottom[1],
    109                       &mc_y, &sm_y, &gap, x);
    110   ComputePixelChannel(d, dc_factors[2], rows_top[2], rows[2], rows_bottom[2],
    111                       &mc_b, &sm_b, &gap, x);
    112   auto factor = MulAdd(Set(d, -4.0f), gap, Set(d, 3.0f));
    113   factor = ZeroIfNegative(factor);
    114 
    115   auto out = MulAdd(Sub(sm_x, mc_x), factor, mc_x);
    116   Store(out, d, out_rows[0] + x);
    117   out = MulAdd(Sub(sm_y, mc_y), factor, mc_y);
    118   Store(out, d, out_rows[1] + x);
    119   out = MulAdd(Sub(sm_b, mc_b), factor, mc_b);
    120   Store(out, d, out_rows[2] + x);
    121 }
    122 
    123 Status AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc,
    124                            ThreadPool* pool) {
    125   const size_t xsize = dc->xsize();
    126   const size_t ysize = dc->ysize();
    127   if (ysize <= 2 || xsize <= 2) return true;
    128 
    129   // TODO(veluca): use tile-based processing?
    130   // TODO(veluca): decide if changes to the y channel should be propagated to
    131   // the x and b channels through color correlation.
    132   JXL_ASSERT(w1 + w2 < 0.25f);
    133 
    134   JXL_ASSIGN_OR_RETURN(Image3F smoothed, Image3F::Create(xsize, ysize));
    135   // Fill in borders that the loop below will not. First and last are unused.
    136   for (size_t c = 0; c < 3; c++) {
    137     for (size_t y : {static_cast<size_t>(0), ysize - 1}) {
    138       memcpy(smoothed.PlaneRow(c, y), dc->PlaneRow(c, y),
    139              xsize * sizeof(float));
    140     }
    141   }
    142   auto process_row = [&](const uint32_t y, size_t /*thread*/) {
    143     const float* JXL_RESTRICT rows_top[3]{
    144         dc->ConstPlaneRow(0, y - 1),
    145         dc->ConstPlaneRow(1, y - 1),
    146         dc->ConstPlaneRow(2, y - 1),
    147     };
    148     const float* JXL_RESTRICT rows[3] = {
    149         dc->ConstPlaneRow(0, y),
    150         dc->ConstPlaneRow(1, y),
    151         dc->ConstPlaneRow(2, y),
    152     };
    153     const float* JXL_RESTRICT rows_bottom[3] = {
    154         dc->ConstPlaneRow(0, y + 1),
    155         dc->ConstPlaneRow(1, y + 1),
    156         dc->ConstPlaneRow(2, y + 1),
    157     };
    158     float* JXL_RESTRICT rows_out[3] = {
    159         smoothed.PlaneRow(0, y),
    160         smoothed.PlaneRow(1, y),
    161         smoothed.PlaneRow(2, y),
    162     };
    163     for (size_t x : {static_cast<size_t>(0), xsize - 1}) {
    164       for (size_t c = 0; c < 3; c++) {
    165         rows_out[c][x] = rows[c][x];
    166       }
    167     }
    168 
    169     size_t x = 1;
    170     // First pixels
    171     const size_t N = Lanes(D());
    172     for (; x < std::min(N, xsize - 1); x++) {
    173       ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out,
    174                             x);
    175     }
    176     // Full vectors.
    177     for (; x + N <= xsize - 1; x += N) {
    178       ComputePixel<D>(dc_factors, rows_top, rows, rows_bottom, rows_out, x);
    179     }
    180     // Last pixels.
    181     for (; x < xsize - 1; x++) {
    182       ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out,
    183                             x);
    184     }
    185   };
    186   JXL_CHECK(RunOnPool(pool, 1, ysize - 1, ThreadPool::NoInit, process_row,
    187                       "DCSmoothingRow"));
    188   dc->Swap(smoothed);
    189   return true;
    190 }
    191 
    192 // DC dequantization.
    193 void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in,
    194                const float* dc_factors, float mul, const float* cfl_factors,
    195                const YCbCrChromaSubsampling& chroma_subsampling,
    196                const BlockCtxMap& bctx) {
    197   const HWY_FULL(float) df;
    198   const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
    199   if (chroma_subsampling.Is444()) {
    200     const auto fac_x = Set(df, dc_factors[0] * mul);
    201     const auto fac_y = Set(df, dc_factors[1] * mul);
    202     const auto fac_b = Set(df, dc_factors[2] * mul);
    203     const auto cfl_fac_x = Set(df, cfl_factors[0]);
    204     const auto cfl_fac_b = Set(df, cfl_factors[2]);
    205     for (size_t y = 0; y < r.ysize(); y++) {
    206       float* dec_row_x = r.PlaneRow(dc, 0, y);
    207       float* dec_row_y = r.PlaneRow(dc, 1, y);
    208       float* dec_row_b = r.PlaneRow(dc, 2, y);
    209       const int32_t* quant_row_x = in.channel[1].plane.Row(y);
    210       const int32_t* quant_row_y = in.channel[0].plane.Row(y);
    211       const int32_t* quant_row_b = in.channel[2].plane.Row(y);
    212       for (size_t x = 0; x < r.xsize(); x += Lanes(di)) {
    213         const auto in_q_x = Load(di, quant_row_x + x);
    214         const auto in_q_y = Load(di, quant_row_y + x);
    215         const auto in_q_b = Load(di, quant_row_b + x);
    216         const auto in_x = Mul(ConvertTo(df, in_q_x), fac_x);
    217         const auto in_y = Mul(ConvertTo(df, in_q_y), fac_y);
    218         const auto in_b = Mul(ConvertTo(df, in_q_b), fac_b);
    219         Store(in_y, df, dec_row_y + x);
    220         Store(MulAdd(in_y, cfl_fac_x, in_x), df, dec_row_x + x);
    221         Store(MulAdd(in_y, cfl_fac_b, in_b), df, dec_row_b + x);
    222       }
    223     }
    224   } else {
    225     for (size_t c : {1, 0, 2}) {
    226       Rect rect(r.x0() >> chroma_subsampling.HShift(c),
    227                 r.y0() >> chroma_subsampling.VShift(c),
    228                 r.xsize() >> chroma_subsampling.HShift(c),
    229                 r.ysize() >> chroma_subsampling.VShift(c));
    230       const auto fac = Set(df, dc_factors[c] * mul);
    231       const Channel& ch = in.channel[c < 2 ? c ^ 1 : c];
    232       for (size_t y = 0; y < rect.ysize(); y++) {
    233         const int32_t* quant_row = ch.plane.Row(y);
    234         float* row = rect.PlaneRow(dc, c, y);
    235         for (size_t x = 0; x < rect.xsize(); x += Lanes(di)) {
    236           const auto in_q = Load(di, quant_row + x);
    237           const auto in = Mul(ConvertTo(df, in_q), fac);
    238           Store(in, df, row + x);
    239         }
    240       }
    241     }
    242   }
    243   if (bctx.num_dc_ctxs <= 1) {
    244     for (size_t y = 0; y < r.ysize(); y++) {
    245       uint8_t* qdc_row = r.Row(quant_dc, y);
    246       memset(qdc_row, 0, sizeof(*qdc_row) * r.xsize());
    247     }
    248   } else {
    249     for (size_t y = 0; y < r.ysize(); y++) {
    250       uint8_t* qdc_row_val = r.Row(quant_dc, y);
    251       const int32_t* quant_row_x =
    252           in.channel[1].plane.Row(y >> chroma_subsampling.VShift(0));
    253       const int32_t* quant_row_y =
    254           in.channel[0].plane.Row(y >> chroma_subsampling.VShift(1));
    255       const int32_t* quant_row_b =
    256           in.channel[2].plane.Row(y >> chroma_subsampling.VShift(2));
    257       for (size_t x = 0; x < r.xsize(); x++) {
    258         int bucket_x = 0;
    259         int bucket_y = 0;
    260         int bucket_b = 0;
    261         for (int t : bctx.dc_thresholds[0]) {
    262           if (quant_row_x[x >> chroma_subsampling.HShift(0)] > t) bucket_x++;
    263         }
    264         for (int t : bctx.dc_thresholds[1]) {
    265           if (quant_row_y[x >> chroma_subsampling.HShift(1)] > t) bucket_y++;
    266         }
    267         for (int t : bctx.dc_thresholds[2]) {
    268           if (quant_row_b[x >> chroma_subsampling.HShift(2)] > t) bucket_b++;
    269         }
    270         int bucket = bucket_x;
    271         bucket *= bctx.dc_thresholds[2].size() + 1;
    272         bucket += bucket_b;
    273         bucket *= bctx.dc_thresholds[1].size() + 1;
    274         bucket += bucket_y;
    275         qdc_row_val[x] = bucket;
    276       }
    277     }
    278   }
    279 }
    280 
    281 // NOLINTNEXTLINE(google-readability-namespace-comments)
    282 }  // namespace HWY_NAMESPACE
    283 }  // namespace jxl
    284 HWY_AFTER_NAMESPACE();
    285 
    286 #if HWY_ONCE
    287 namespace jxl {
    288 
    289 HWY_EXPORT(DequantDC);
    290 HWY_EXPORT(AdaptiveDCSmoothing);
    291 Status AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc,
    292                            ThreadPool* pool) {
    293   return HWY_DYNAMIC_DISPATCH(AdaptiveDCSmoothing)(dc_factors, dc, pool);
    294 }
    295 
    296 void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in,
    297                const float* dc_factors, float mul, const float* cfl_factors,
    298                const YCbCrChromaSubsampling& chroma_subsampling,
    299                const BlockCtxMap& bctx) {
    300   HWY_DYNAMIC_DISPATCH(DequantDC)
    301   (r, dc, quant_dc, in, dc_factors, mul, cfl_factors, chroma_subsampling, bctx);
    302 }
    303 
    304 }  // namespace jxl
    305 #endif  // HWY_ONCE