libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_ar_control_field.cc (13388B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_ar_control_field.h"
      7 
      8 #include <stdint.h>
      9 #include <stdlib.h>
     10 
     11 #include <algorithm>
     12 
     13 #undef HWY_TARGET_INCLUDE
     14 #define HWY_TARGET_INCLUDE "lib/jxl/enc_ar_control_field.cc"
     15 #include <hwy/foreach_target.h>
     16 #include <hwy/highway.h>
     17 
     18 #include "lib/jxl/ac_strategy.h"
     19 #include "lib/jxl/base/compiler_specific.h"
     20 #include "lib/jxl/base/status.h"
     21 #include "lib/jxl/enc_params.h"
     22 #include "lib/jxl/image.h"
     23 #include "lib/jxl/image_ops.h"
     24 
     25 HWY_BEFORE_NAMESPACE();
     26 namespace jxl {
     27 namespace HWY_NAMESPACE {
     28 namespace {
     29
     30 // These templates are not found via ADL.
     31 using hwy::HWY_NAMESPACE::Add;
     32 using hwy::HWY_NAMESPACE::GetLane;
     33 using hwy::HWY_NAMESPACE::Mul;
     34 using hwy::HWY_NAMESPACE::MulAdd;
     35 using hwy::HWY_NAMESPACE::Sqrt;
     36
        // Computes the EPF sharpness value of every 8x8 block in `rect` and
        // writes it to `epf_sharpness`.
        //
        // `rect` is in block units relative to `opsin_rect`: the absolute block
        // position inside `opsin` is opsin_rect.{x0,y0}() / 8 + rect.{x0,y0}().
        // `quant_field`, `ac_strategy` and `epf_sharpness` are indexed directly
        // with the block coordinates of `rect` (without that offset).
        // `temp_image` is scratch storage that is reused between calls;
        // ArControlFieldHeuristics::RunRect passes `&temp_images[thread]`.
        //
        // If the butteraugli distance is below kMinButteraugliForDynamicAR, the
        // speed tier is greater than SpeedTier::kWombat, or EPF is disabled
        // (epf_iters == 0), every value in `rect` is set to 4. Otherwise each
        // block gets 0 or 4, depending on how the 3x3 Laplacian energy of its
        // whole transform compares with the weakest local Laplacian energy
        // around the block. Both are offset by a bias proportional to the
        // inverse of the initial quant field.
        //
        // NOTE: the exact order of the floating-point operations below
        // (pairwise vector sums, min chains, the double-precision
        // `kDeltaLimit + 0.05`) determines the output bits; keep it when editing.
        //
        // Returns an error only if temp_image->InitOnce() fails; true otherwise.
     37 Status ProcessTile(const CompressParams& cparams,
     38                    const FrameHeader& frame_header, const Image3F& opsin,
     39                    const Rect& opsin_rect, const ImageF& quant_field,
     40                    const AcStrategyImage& ac_strategy, ImageB* epf_sharpness,
     41                    const Rect& rect,
     42                    ArControlFieldHeuristics::TempImages* temp_image) {
          // The opsin rect must be aligned to the 8x8 block grid.
     43   JXL_ASSERT(opsin_rect.x0() % 8 == 0);
     44   JXL_ASSERT(opsin_rect.y0() % 8 == 0);
     45   JXL_ASSERT(opsin_rect.xsize() % 8 == 0);
     46   JXL_ASSERT(opsin_rect.ysize() % 8 == 0);
     47   constexpr size_t N = kBlockDim;
          // Dynamic control disabled: constant value 4 for the whole rect.
     48   if (cparams.butteraugli_distance < kMinButteraugliForDynamicAR ||
     49       cparams.speed_tier > SpeedTier::kWombat ||
     50       frame_header.loop_filter.epf_iters == 0) {
     51     FillPlane(static_cast<uint8_t>(4), epf_sharpness, rect);
     52     return true;
     53   }
     54
     55   // Likely better to have a higher X weight, like:
     56   // const float kChannelWeights[3] = {47.0f, 4.35f, 0.287f};
     57   const float kChannelWeights[3] = {4.35f, 4.35f, 0.287f};
          // Each of the 8 neighbours is weighted -w/8, so the 3x3 kernel
          // (center w, neighbours -w/8) sums to zero.
     58   const float kChannelWeightsLapNeg[3] = {-0.125f * kChannelWeights[0],
     59                                           -0.125f * kChannelWeights[1],
     60                                           -0.125f * kChannelWeights[2]};
     61   const size_t sharpness_stride =
     62       static_cast<size_t>(epf_sharpness->PixelsPerRow());
     63
          // Absolute block range [bx0, bx1) x [by0, by1) of this tile in `opsin`.
     64   size_t by0 = opsin_rect.y0() / 8 + rect.y0();
     65   size_t by1 = by0 + rect.ysize();
     66   size_t bx0 = opsin_rect.x0() / 8 + rect.x0();
     67   size_t bx1 = bx0 + rect.xsize();
     68   JXL_RETURN_IF_ERROR(temp_image->InitOnce());
     69   ImageF& laplacian_sqrsum = temp_image->laplacian_sqrsum;
     70   // Calculate the L2 of the 3x3 Laplacian in an integral transform
     71   // (for example 32x32 dct). This relates to transforms ability
     72   // to propagate artefacts.
          // This first pass stores, for every pixel of the tile plus a 2-pixel
          // border, the channel-weighted sum of squared Laplacian responses.
          // laplacian_sqrsum(0, 0) corresponds to opsin pixel
          // (bx0 * N - 2, by0 * N - 2). The border is omitted at the image
          // top/left, and where the tile ends exactly at the image bottom/right.
     73   size_t y0 = by0 == 0 ? 0 : by0 * N - 2;
     74   size_t y1 = by1 * N == opsin.ysize() ? by1 * N : by1 * N + 2;
     75   size_t x0 = bx0 == 0 ? 0 : bx0 * N - 2;
     76   size_t x1 = bx1 * N == opsin.xsize() ? bx1 * N : bx1 * N + 2;
     77   HWY_FULL(float) df;
     78   for (size_t y = y0; y < y1; y++) {
     79     float* JXL_RESTRICT laplacian_sqrsum_row =
     80         laplacian_sqrsum.Row(y + 2 - by0 * N);
     81     const float* JXL_RESTRICT in_row_t[3];
     82     const float* JXL_RESTRICT in_row[3];
     83     const float* JXL_RESTRICT in_row_b[3];
            // Rows above/below; at the image top/bottom the center row is reused.
     84     for (size_t c = 0; c < 3; c++) {
     85       in_row_t[c] = opsin.ConstPlaneRow(c, y > 0 ? y - 1 : y);
     86       in_row[c] = opsin.ConstPlaneRow(c, y);
     87       in_row_b[c] = opsin.ConstPlaneRow(c, y + 1 < opsin.ysize() ? y + 1 : y);
     88     }
            // Scalar Laplacian for one column; out-of-image columns are clamped
            // to the nearest edge column.
     89     auto compute_laplacian_scalar = [&](size_t x) {
     90       const size_t prevX = x >= 1 ? x - 1 : x;
     91       const size_t nextX = x + 1 < opsin.xsize() ? x + 1 : x;
     92       float sumsqr = 0;
     93       for (size_t c = 0; c < 3; c++) {
     94         float laplacian =
     95             kChannelWeights[c] * in_row[c][x] +
     96             kChannelWeightsLapNeg[c] *
     97                 (in_row[c][prevX] + in_row[c][nextX] + in_row_b[c][prevX] +
     98                  in_row_b[c][x] + in_row_b[c][nextX] + in_row_t[c][prevX] +
     99                  in_row_t[c][x] + in_row_t[c][nextX]);
    100         sumsqr += laplacian * laplacian;
    101       }
    102       laplacian_sqrsum_row[x + 2 - bx0 * N] = sumsqr;
    103     };
    104     size_t x = x0;
            // Column 0 has no left neighbour, so it always takes the scalar path.
    105     for (; x < 1; x++) {
    106       compute_laplacian_scalar(x);
    107     }
    108     // Interior. One extra pixel of border as the last pixel is special.
            // The vector loads read columns x - 1 .. x + Lanes(df), so this loop
            // needs x >= 1 (guaranteed above) and x + Lanes(df) < opsin.xsize().
    109     for (; x + Lanes(df) <= x1 && x + Lanes(df) + 1 <= opsin.xsize();
    110          x += Lanes(df)) {
    111       auto sumsqr = Zero(df);
    112       for (size_t c = 0; c < 3; c++) {
    113         auto laplacian =
    114             Mul(LoadU(df, in_row[c] + x), Set(df, kChannelWeights[c]));
                // The 8 neighbours are summed pairwise through four accumulators.
    115         auto sum_oth0 = LoadU(df, in_row[c] + x - 1);
    116         auto sum_oth1 = LoadU(df, in_row[c] + x + 1);
    117         auto sum_oth2 = LoadU(df, in_row_t[c] + x - 1);
    118         auto sum_oth3 = LoadU(df, in_row_t[c] + x);
    119         sum_oth0 = Add(sum_oth0, LoadU(df, in_row_t[c] + x + 1));
    120         sum_oth1 = Add(sum_oth1, LoadU(df, in_row_b[c] + x - 1));
    121         sum_oth2 = Add(sum_oth2, LoadU(df, in_row_b[c] + x));
    122         sum_oth3 = Add(sum_oth3, LoadU(df, in_row_b[c] + x + 1));
    123         sum_oth0 = Add(sum_oth0, sum_oth1);
    124         sum_oth2 = Add(sum_oth2, sum_oth3);
    125         sum_oth0 = Add(sum_oth0, sum_oth2);
    126         laplacian =
    127             MulAdd(Set(df, kChannelWeightsLapNeg[c]), sum_oth0, laplacian);
    128         sumsqr = MulAdd(laplacian, laplacian, sumsqr);
    129       }
    130       StoreU(sumsqr, df, laplacian_sqrsum_row + x + 2 - bx0 * N);
    131     }
            // Remaining columns (right edge and vector tail) use the scalar path.
    132     for (; x < x1; x++) {
    133       compute_laplacian_scalar(x);
    134     }
    135   }
    136   HWY_CAPPED(float, 4) df4;
    137   // Calculate the L2 of the 3x3 Laplacian in 4x4 blocks within the area
    138   // of the integral transform. Sample them within the integral transform
    139   // with two offsets (0,0) and (-2, -2) pixels (sqrsum_00 and sqrsum_22,
    140   //  respectively).
          // Each stored value is sqrt(sum of the 16 laplacian_sqrsum values) / 4,
          // i.e. the RMS over the 4x4 cell. sqrsum_00 has rect.ysize() * 2 rows
          // and rect.xsize() * 2 columns (two cells per block per axis).
    141   ImageF& sqrsum_00 = temp_image->sqrsum_00;
    142   size_t sqrsum_00_stride = sqrsum_00.PixelsPerRow();
    143   float* JXL_RESTRICT sqrsum_00_row = sqrsum_00.Row(0);
    144   for (size_t y = 0; y < rect.ysize() * 2; y++) {
    145     const float* JXL_RESTRICT rows_in[4];
            // "+ 2" skips the 2-pixel Laplacian border (also for columns below),
            // aligning these cells with the 8x8 block grid.
    146     for (size_t iy = 0; iy < 4; iy++) {
    147       rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy + 2);
    148     }
    149     float* JXL_RESTRICT row_out = sqrsum_00_row + y * sqrsum_00_stride;
    150     for (size_t x = 0; x < rect.xsize() * 2; x++) {
    151       auto sum = Zero(df4);
    152       for (size_t iy = 0; iy < 4; iy++) {
    153         for (size_t ix = 0; ix < 4; ix += Lanes(df4)) {
    154           sum = Add(sum, LoadU(df4, rows_in[iy] + x * 4 + ix + 2));
    155         }
    156       }
    157       row_out[x] = GetLane(Sqrt(SumOfLanes(df4, sum))) * (1.0f / 4.0f);
    158     }
    159   }
    160   // Indexing iy and ix is a bit tricky as we include a 2 pixel border
    161   // around the block for evenness calculations. This is similar to what
    162   // we did in guetzli for the observability of artefacts, except there
    163   // the element is a sliding 5x5, not sparsely sampled 4x4 box like here.
          // sqrsum_22 holds (rect.ysize() * 2 + 1) x (rect.xsize() * 2 + 1) cells
          // shifted by -2 pixels. Cells that extend outside the image are
          // averaged over their in-image pixels only (scalar branch below).
    164   ImageF& sqrsum_22 = temp_image->sqrsum_22;
    165   size_t sqrsum_22_stride = sqrsum_22.PixelsPerRow();
    166   float* JXL_RESTRICT sqrsum_22_row = sqrsum_22.Row(0);
    167   for (size_t y = 0; y < rect.ysize() * 2 + 1; y++) {
    168     const float* JXL_RESTRICT rows_in[4];
    169     for (size_t iy = 0; iy < 4; iy++) {
    170       rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy);
    171     }
    172     float* JXL_RESTRICT row_out = sqrsum_22_row + y * sqrsum_22_stride;
    173     // ignore pixels outside the image.
    174     // Y coordinates are relative to by0*8+y*4.
            // rows_in[iy] holds opsin row by0*8 + y*4 + iy - 2; [sy, ey) keeps
            // only the rows that lie inside the image.
    175     size_t sy = y * 4 + by0 * 8 > 0 ? 0 : 2;
    176     size_t ey = y * 4 + by0 * 8 + 2 <= opsin.ysize()
    177                     ? 4
    178                     : opsin.ysize() - y * 4 - by0 * 8 + 2;
    179     for (size_t x = 0; x < rect.xsize() * 2 + 1; x++) {
    180       // ignore pixels outside the image.
    181       // X coordinates are relative to bx0*8.
              // Column ix holds opsin column bx0*8 + ix - 2; [sx, ex) keeps the
              // in-image columns of cell x.
    182       size_t sx = x * 4 + bx0 * 8 > 0 ? x * 4 : x * 4 + 2;
    183       size_t ex = x * 4 + bx0 * 8 + 2 <= opsin.xsize()
    184                       ? x * 4 + 4
    185                       : opsin.xsize() - bx0 * 8 + 2;
              // Full 4x4 cell: vector sum. Load (not LoadU) is used; ix == x * 4
              // here, which presumably meets Load's alignment requirement given
              // aligned image rows.
    186       if (ex - sx == 4 && ey - sy == 4) {
    187         auto sum = Zero(df4);
    188         for (size_t iy = sy; iy < ey; iy++) {
    189           for (size_t ix = sx; ix < ex; ix += Lanes(df4)) {
    190             sum = Add(sum, Load(df4, rows_in[iy] + ix));
    191           }
    192         }
    193         row_out[x] = GetLane(Sqrt(SumOfLanes(df4, sum))) * (1.0f / 4.0f);
    194       } else {
                // Partial cell: sqrt of the mean over the in-image pixels, so the
                // value is again an RMS.
    195         float sum = 0;
    196         for (size_t iy = sy; iy < ey; iy++) {
    197           for (size_t ix = sx; ix < ex; ix++) {
    198             sum += rows_in[iy][ix];
    199           }
    200         }
    201         row_out[x] = std::sqrt(sum / ((ex - sx) * (ey - sy)));
    202       }
    203     }
    204   }
          // Decide the sharpness value of every 8x8 block from the cell statistics.
    205   for (size_t by = rect.y0(); by < rect.y1(); by++) {
    206     AcStrategyRow acs_row = ac_strategy.ConstRow(by);
    207     uint8_t* JXL_RESTRICT out_row = epf_sharpness->Row(by);
    208     const float* JXL_RESTRICT quant_row = quant_field.Row(by);
    209     for (size_t bx = rect.x0(); bx < rect.x1(); bx++) {
    210       AcStrategy acs = acs_row[bx];
              // Only the first block of each transform is processed. It writes
              // the values for all covered_blocks_x() x covered_blocks_y() blocks.
    211       if (!acs.IsFirstBlock()) continue;
    212       // The errors are going to be linear to the quantization value in this
    213       // locality. We only have access to the initial quant field here.
              // NOTE(review): assumes quant_row[bx] > 0; zero would make this inf.
    214       float quant_val = 1.0f / quant_row[bx];
    215
              // Cell accessors relative to this transform's top-left block
              // (two cells per block per axis).
    216       const auto sq00 = [&](size_t y, size_t x) {
    217         return sqrsum_00_row[((by - rect.y0()) * 2 + y) * sqrsum_00_stride +
    218                              (bx - rect.x0()) * 2 + x];
    219       };
    220       const auto sq22 = [&](size_t y, size_t x) {
    221         return sqrsum_22_row[((by - rect.y0()) * 2 + y) * sqrsum_22_stride +
    222                              (bx - rect.x0()) * 2 + x];
    223       };
              // RMS of the sqrsum_00 cell values over the whole transform.
    224       float sqrsum_integral_transform = 0;
    225       for (size_t iy = 0; iy < acs.covered_blocks_y() * 2; iy++) {
    226         for (size_t ix = 0; ix < acs.covered_blocks_x() * 2; ix++) {
    227           sqrsum_integral_transform += sq00(iy, ix) * sq00(iy, ix);
    228         }
    229       }
    230       sqrsum_integral_transform /=
    231           4 * acs.covered_blocks_x() * acs.covered_blocks_y();
    232       sqrsum_integral_transform = std::sqrt(sqrsum_integral_transform);
    233       // If masking is high or amplitude of the artefacts is low, then no
    234       // smoothing is needed.
    235       for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
    236         for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
    237           // Five 4x4 blocks for masking estimation, all within the
    238           // 8x8 area.
    239           float minval_1 = std::min(sq00(2 * iy + 0, 2 * ix + 0),
    240                                     sq00(2 * iy + 0, 2 * ix + 1));
    241           float minval_2 = std::min(sq00(2 * iy + 1, 2 * ix + 0),
    242                                     sq00(2 * iy + 1, 2 * ix + 1));
    243           float minval = std::min(minval_1, minval_2);
    244           minval = std::min(minval, sq22(2 * iy + 1, 2 * ix + 1));
    245           // Nine more 4x4 blocks for masking estimation, includes
    246           // the 2 pixel area around the 8x8 block being controlled.
    247           float minval2_1 = std::min(sq22(2 * iy + 0, 2 * ix + 0),
    248                                      sq22(2 * iy + 0, 2 * ix + 1));
    249           float minval2_2 = std::min(sq22(2 * iy + 0, 2 * ix + 2),
    250                                      sq22(2 * iy + 1, 2 * ix + 0));
    251           float minval2_3 = std::min(sq22(2 * iy + 1, 2 * ix + 1),
    252                                      sq22(2 * iy + 1, 2 * ix + 2));
    253           float minval2_4 = std::min(sq22(2 * iy + 2, 2 * ix + 0),
    254                                      sq22(2 * iy + 2, 2 * ix + 1));
    255           float minval2_5 = std::min(minval2_1, minval2_2);
    256           float minval2_6 = std::min(minval2_3, minval2_4);
    257           float minval2 = std::min(minval2_5, minval2_6);
    258           minval2 = std::min(minval2, sq22(2 * iy + 2, 2 * ix + 2));
    259           float minval3 = std::min(minval, minval2);
                  // Blend (weights sum to 1):
                  //   0.125 * five-cell min + 0.625 * overall min
                  //   + 0.125 * min(1.5 * overall min, center sq22 cell)
                  //   + 0.125 * nine-cell min.
    260           minval *= 0.125f;
    261           minval += 0.625f * minval3;
    262           minval +=
    263               0.125f * std::min(1.5f * minval3, sq22(2 * iy + 1, 2 * ix + 1));
    264           minval += 0.125f * minval2;
    265           // Larger kBias, less smoothing for low intensity changes.
                  // NOTE(review): there is no kBias; presumably this means `bias`.
                  // As bias grows, delta tends to kDeltaLimit + 0.05 > kDeltaLimit,
                  // which selects out = 4 (marked "smooth" below). Confirm the
                  // intended direction of this comment.
    266           float kDeltaLimit = 3.2;
    267           float bias = 0.0625f * quant_val;
                  // 0.05 is a double literal, so this quotient is evaluated in
                  // double precision and then narrowed to float.
    268           float delta =
    269               (sqrsum_integral_transform + (kDeltaLimit + 0.05) * bias) /
    270               (minval + bias);
                  // The initial 4 is always overwritten by the if/else below.
    271           int out = 4;
    272           if (delta > kDeltaLimit) {
    273             out = 4;  // smooth
    274           } else {
    275             out = 0;
    276           }
    277           // 'threshold' is separate from 'bias' for easier tuning of these
    278           // heuristics.
    279           float threshold = 0.0625f * quant_val;
    280           const float kSmoothLimit = 0.085f;
                  // Low local Laplacian energy relative to the threshold forces
                  // out = 4 regardless of delta.
    281           float smooth = 0.20f * (sq00(2 * iy + 0, 2 * ix + 0) +
    282                                   sq00(2 * iy + 0, 2 * ix + 1) +
    283                                   sq00(2 * iy + 1, 2 * ix + 0) +
    284                                   sq00(2 * iy + 1, 2 * ix + 1) + minval);
    285           if (smooth < kSmoothLimit * threshold) {
    286             out = 4;
    287           }
                  // Writes row by + iy, column bx + ix of epf_sharpness.
    288           out_row[bx + sharpness_stride * iy + ix] = out;
    289         }
    290       }
    291     }
    292   }
    293   return true;
    294 }
    295
    296 }  // namespace
    297 // NOLINTNEXTLINE(google-readability-namespace-comments)
    298 }  // namespace HWY_NAMESPACE
    299 }  // namespace jxl
    300 HWY_AFTER_NAMESPACE();
    301 
    302 #if HWY_ONCE
    303 namespace jxl {
    304 HWY_EXPORT(ProcessTile);
    305 
    306 Status ArControlFieldHeuristics::RunRect(
    307     const CompressParams& cparams, const FrameHeader& frame_header,
    308     const Rect& block_rect, const Image3F& opsin, const Rect& opsin_rect,
    309     const ImageF& quant_field, const AcStrategyImage& ac_strategy,
    310     ImageB* epf_sharpness, size_t thread) {
    311   return HWY_DYNAMIC_DISPATCH(ProcessTile)(
    312       cparams, frame_header, opsin, opsin_rect, quant_field, ac_strategy,
    313       epf_sharpness, block_rect, &temp_images[thread]);
    314 }
    315 
    316 }  // namespace jxl
    317 
    318 #endif