libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

render.cc (29746B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jpegli/render.h"
      7 
#include <string.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <hwy/aligned_allocator.h>
#include <vector>
     16 
     17 #include "lib/jpegli/color_quantize.h"
     18 #include "lib/jpegli/color_transform.h"
     19 #include "lib/jpegli/decode_internal.h"
     20 #include "lib/jpegli/error.h"
     21 #include "lib/jpegli/idct.h"
     22 #include "lib/jpegli/upsample.h"
     23 #include "lib/jxl/base/byte_order.h"
     24 #include "lib/jxl/base/compiler_specific.h"
     25 #include "lib/jxl/base/status.h"
     26 
     27 #ifdef MEMORY_SANITIZER
     28 #define JXL_MEMORY_SANITIZER 1
     29 #elif defined(__has_feature)
     30 #if __has_feature(memory_sanitizer)
     31 #define JXL_MEMORY_SANITIZER 1
     32 #else
     33 #define JXL_MEMORY_SANITIZER 0
     34 #endif
     35 #else
     36 #define JXL_MEMORY_SANITIZER 0
     37 #endif
     38 
     39 #if JXL_MEMORY_SANITIZER
     40 #include "sanitizer/msan_interface.h"
     41 #endif
     42 
     43 #undef HWY_TARGET_INCLUDE
     44 #define HWY_TARGET_INCLUDE "lib/jpegli/render.cc"
     45 #include <hwy/foreach_target.h>
     46 #include <hwy/highway.h>
     47 
     48 HWY_BEFORE_NAMESPACE();
     49 namespace jpegli {
     50 namespace HWY_NAMESPACE {
     51 
     52 // These templates are not found via ADL.
     53 using hwy::HWY_NAMESPACE::Abs;
     54 using hwy::HWY_NAMESPACE::Add;
     55 using hwy::HWY_NAMESPACE::Clamp;
     56 using hwy::HWY_NAMESPACE::Gt;
     57 using hwy::HWY_NAMESPACE::IfThenElseZero;
     58 using hwy::HWY_NAMESPACE::Mul;
     59 using hwy::HWY_NAMESPACE::NearestInt;
     60 using hwy::HWY_NAMESPACE::Or;
     61 using hwy::HWY_NAMESPACE::Rebind;
     62 using hwy::HWY_NAMESPACE::ShiftLeftSame;
     63 using hwy::HWY_NAMESPACE::ShiftRightSame;
     64 using hwy::HWY_NAMESPACE::Vec;
     65 using D = HWY_FULL(float);
     66 using DI = HWY_FULL(int32_t);
     67 constexpr D d;
     68 constexpr DI di;
     69 
     70 void GatherBlockStats(const int16_t* JXL_RESTRICT coeffs,
     71                       const size_t coeffs_size, int32_t* JXL_RESTRICT nonzeros,
     72                       int32_t* JXL_RESTRICT sumabs) {
     73   for (size_t i = 0; i < coeffs_size; i += Lanes(d)) {
     74     size_t k = i % DCTSIZE2;
     75     const Rebind<int16_t, DI> di16;
     76     const Vec<DI> coeff = PromoteTo(di, Load(di16, coeffs + i));
     77     const auto abs_coeff = Abs(coeff);
     78     const auto not_0 = Gt(abs_coeff, Zero(di));
     79     const auto nzero = IfThenElseZero(not_0, Set(di, 1));
     80     Store(Add(nzero, Load(di, nonzeros + k)), di, nonzeros + k);
     81     Store(Add(abs_coeff, Load(di, sumabs + k)), di, sumabs + k);
     82   }
     83 }
     84 
     85 void DecenterRow(float* row, size_t xsize) {
     86   const HWY_CAPPED(float, 8) df;
     87   const auto c128 = Set(df, 128.0f / 255);
     88   for (size_t x = 0; x < xsize; x += Lanes(df)) {
     89     Store(Add(Load(df, row + x), c128), df, row + x);
     90   }
     91 }
     92 
     93 void DitherRow(j_decompress_ptr cinfo, float* row, int c, size_t y,
     94                size_t xsize) {
     95   jpeg_decomp_master* m = cinfo->master;
     96   if (!m->dither_[c]) return;
     97   const float* dither_row =
     98       &m->dither_[c][(y & m->dither_mask_) * m->dither_size_];
     99   for (size_t x = 0; x < xsize; ++x) {
    100     row[x] += dither_row[x & m->dither_mask_];
    101   }
    102 }
    103 
    104 template <typename T>
    105 void StoreUnsignedRow(float* JXL_RESTRICT input[], size_t x0, size_t len,
    106                       size_t num_channels, float multiplier, T* output) {
    107   const HWY_CAPPED(float, 8) d;
    108   auto zero = Zero(d);
    109   auto mul = Set(d, multiplier);
    110   const Rebind<T, decltype(d)> du;
    111 #if JXL_MEMORY_SANITIZER
    112   const size_t padding = hwy::RoundUpTo(len, Lanes(d)) - len;
    113   for (size_t c = 0; c < num_channels; ++c) {
    114     __msan_unpoison(input[c] + x0 + len, sizeof(input[c][0]) * padding);
    115   }
    116 #endif
    117   if (num_channels == 1) {
    118     for (size_t i = 0; i < len; i += Lanes(d)) {
    119       auto v0 = Clamp(zero, Mul(LoadU(d, &input[0][x0 + i]), mul), mul);
    120       StoreU(DemoteTo(du, NearestInt(v0)), du, &output[i]);
    121     }
    122   } else if (num_channels == 2) {
    123     for (size_t i = 0; i < len; i += Lanes(d)) {
    124       auto v0 = Clamp(zero, Mul(LoadU(d, &input[0][x0 + i]), mul), mul);
    125       auto v1 = Clamp(zero, Mul(LoadU(d, &input[1][x0 + i]), mul), mul);
    126       StoreInterleaved2(DemoteTo(du, NearestInt(v0)),
    127                         DemoteTo(du, NearestInt(v1)), du, &output[2 * i]);
    128     }
    129   } else if (num_channels == 3) {
    130     for (size_t i = 0; i < len; i += Lanes(d)) {
    131       auto v0 = Clamp(zero, Mul(LoadU(d, &input[0][x0 + i]), mul), mul);
    132       auto v1 = Clamp(zero, Mul(LoadU(d, &input[1][x0 + i]), mul), mul);
    133       auto v2 = Clamp(zero, Mul(LoadU(d, &input[2][x0 + i]), mul), mul);
    134       StoreInterleaved3(DemoteTo(du, NearestInt(v0)),
    135                         DemoteTo(du, NearestInt(v1)),
    136                         DemoteTo(du, NearestInt(v2)), du, &output[3 * i]);
    137     }
    138   } else if (num_channels == 4) {
    139     for (size_t i = 0; i < len; i += Lanes(d)) {
    140       auto v0 = Clamp(zero, Mul(LoadU(d, &input[0][x0 + i]), mul), mul);
    141       auto v1 = Clamp(zero, Mul(LoadU(d, &input[1][x0 + i]), mul), mul);
    142       auto v2 = Clamp(zero, Mul(LoadU(d, &input[2][x0 + i]), mul), mul);
    143       auto v3 = Clamp(zero, Mul(LoadU(d, &input[3][x0 + i]), mul), mul);
    144       StoreInterleaved4(DemoteTo(du, NearestInt(v0)),
    145                         DemoteTo(du, NearestInt(v1)),
    146                         DemoteTo(du, NearestInt(v2)),
    147                         DemoteTo(du, NearestInt(v3)), du, &output[4 * i]);
    148     }
    149   }
    150 #if JXL_MEMORY_SANITIZER
    151   __msan_poison(output + num_channels * len,
    152                 sizeof(output[0]) * num_channels * padding);
    153 #endif
    154 }
    155 
    156 void StoreFloatRow(float* JXL_RESTRICT input[3], size_t x0, size_t len,
    157                    size_t num_channels, float* output) {
    158   const HWY_CAPPED(float, 8) d;
    159   if (num_channels == 1) {
    160     memcpy(output, input[0] + x0, len * sizeof(output[0]));
    161   } else if (num_channels == 2) {
    162     for (size_t i = 0; i < len; i += Lanes(d)) {
    163       StoreInterleaved2(LoadU(d, &input[0][x0 + i]),
    164                         LoadU(d, &input[1][x0 + i]), d, &output[2 * i]);
    165     }
    166   } else if (num_channels == 3) {
    167     for (size_t i = 0; i < len; i += Lanes(d)) {
    168       StoreInterleaved3(LoadU(d, &input[0][x0 + i]),
    169                         LoadU(d, &input[1][x0 + i]),
    170                         LoadU(d, &input[2][x0 + i]), d, &output[3 * i]);
    171     }
    172   } else if (num_channels == 4) {
    173     for (size_t i = 0; i < len; i += Lanes(d)) {
    174       StoreInterleaved4(LoadU(d, &input[0][x0 + i]),
    175                         LoadU(d, &input[1][x0 + i]),
    176                         LoadU(d, &input[2][x0 + i]),
    177                         LoadU(d, &input[3][x0 + i]), d, &output[4 * i]);
    178     }
    179   }
    180 }
    181 
    182 static constexpr float kFSWeightMR = 7.0f / 16.0f;
    183 static constexpr float kFSWeightBL = 3.0f / 16.0f;
    184 static constexpr float kFSWeightBM = 5.0f / 16.0f;
    185 static constexpr float kFSWeightBR = 1.0f / 16.0f;
    186 
    187 float LimitError(float error) {
    188   float abserror = std::abs(error);
    189   if (abserror > 48.0f) {
    190     abserror = 32.0f;
    191   } else if (abserror > 16.0f) {
    192     abserror = 0.5f * abserror + 8.0f;
    193   }
    194   return error > 0.0f ? abserror : -abserror;
    195 }
    196 
    197 void WriteToOutput(j_decompress_ptr cinfo, float* JXL_RESTRICT rows[],
    198                    size_t xoffset, size_t len, size_t num_channels,
    199                    uint8_t* JXL_RESTRICT output) {
    200   jpeg_decomp_master* m = cinfo->master;
    201   uint8_t* JXL_RESTRICT scratch_space = m->output_scratch_;
    202   if (cinfo->quantize_colors && m->quant_pass_ == 1) {
    203     float* error_row[kMaxComponents];
    204     float* next_error_row[kMaxComponents];
    205     J_DITHER_MODE dither_mode = cinfo->dither_mode;
    206     if (dither_mode == JDITHER_ORDERED) {
    207       for (size_t c = 0; c < num_channels; ++c) {
    208         DitherRow(cinfo, &rows[c][xoffset], c, cinfo->output_scanline,
    209                   cinfo->output_width);
    210       }
    211     } else if (dither_mode == JDITHER_FS) {
    212       for (size_t c = 0; c < num_channels; ++c) {
    213         if (cinfo->output_scanline % 2 == 0) {
    214           error_row[c] = m->error_row_[c];
    215           next_error_row[c] = m->error_row_[c + kMaxComponents];
    216         } else {
    217           error_row[c] = m->error_row_[c + kMaxComponents];
    218           next_error_row[c] = m->error_row_[c];
    219         }
    220         memset(next_error_row[c], 0.0, cinfo->output_width * sizeof(float));
    221       }
    222     }
    223     const float mul = 255.0f;
    224     if (dither_mode != JDITHER_FS) {
    225       StoreUnsignedRow(rows, xoffset, len, num_channels, mul, scratch_space);
    226     }
    227     for (size_t i = 0; i < len; ++i) {
    228       uint8_t* pixel = &scratch_space[num_channels * i];
    229       if (dither_mode == JDITHER_FS) {
    230         for (size_t c = 0; c < num_channels; ++c) {
    231           float val = rows[c][i] * mul + LimitError(error_row[c][i]);
    232           pixel[c] = std::round(std::min(255.0f, std::max(0.0f, val)));
    233         }
    234       }
    235       int index = LookupColorIndex(cinfo, pixel);
    236       output[i] = index;
    237       if (dither_mode == JDITHER_FS) {
    238         size_t prev_i = i > 0 ? i - 1 : 0;
    239         size_t next_i = i + 1 < len ? i + 1 : len - 1;
    240         for (size_t c = 0; c < num_channels; ++c) {
    241           float error = pixel[c] - cinfo->colormap[c][index];
    242           error_row[c][next_i] += kFSWeightMR * error;
    243           next_error_row[c][prev_i] += kFSWeightBL * error;
    244           next_error_row[c][i] += kFSWeightBM * error;
    245           next_error_row[c][next_i] += kFSWeightBR * error;
    246         }
    247       }
    248     }
    249   } else if (m->output_data_type_ == JPEGLI_TYPE_UINT8) {
    250     const float mul = 255.0;
    251     StoreUnsignedRow(rows, xoffset, len, num_channels, mul, scratch_space);
    252     memcpy(output, scratch_space, len * num_channels);
    253   } else if (m->output_data_type_ == JPEGLI_TYPE_UINT16) {
    254     const float mul = 65535.0;
    255     uint16_t* tmp = reinterpret_cast<uint16_t*>(scratch_space);
    256     StoreUnsignedRow(rows, xoffset, len, num_channels, mul, tmp);
    257     if (m->swap_endianness_) {
    258       const HWY_CAPPED(uint16_t, 8) du;
    259       size_t output_len = len * num_channels;
    260       for (size_t j = 0; j < output_len; j += Lanes(du)) {
    261         auto v = LoadU(du, tmp + j);
    262         auto vswap = Or(ShiftRightSame(v, 8), ShiftLeftSame(v, 8));
    263         StoreU(vswap, du, tmp + j);
    264       }
    265     }
    266     memcpy(output, tmp, len * num_channels * 2);
    267   } else if (m->output_data_type_ == JPEGLI_TYPE_FLOAT) {
    268     float* tmp = reinterpret_cast<float*>(scratch_space);
    269     StoreFloatRow(rows, xoffset, len, num_channels, tmp);
    270     if (m->swap_endianness_) {
    271       size_t output_len = len * num_channels;
    272       for (size_t j = 0; j < output_len; ++j) {
    273         tmp[j] = BSwapFloat(tmp[j]);
    274       }
    275     }
    276     memcpy(output, tmp, len * num_channels * 4);
    277   }
    278 }
    279 
    280 // NOLINTNEXTLINE(google-readability-namespace-comments)
    281 }  // namespace HWY_NAMESPACE
    282 }  // namespace jpegli
    283 HWY_AFTER_NAMESPACE();
    284 
    285 #if HWY_ONCE
    286 
    287 namespace jpegli {
    288 
// Per-target tables of the SIMD implementations defined in HWY_NAMESPACE
// above. The wrappers below forward to the entry chosen at run time via
// HWY_DYNAMIC_DISPATCH.
HWY_EXPORT(GatherBlockStats);
HWY_EXPORT(WriteToOutput);
HWY_EXPORT(DecenterRow);

// Adds per-position nonzero counts and absolute-value sums of `coeffs_size`
// coefficients into `nonzeros` / `sumabs` (position = index modulo DCTSIZE2).
void GatherBlockStats(const int16_t* JXL_RESTRICT coeffs,
                      const size_t coeffs_size, int32_t* JXL_RESTRICT nonzeros,
                      int32_t* JXL_RESTRICT sumabs) {
  HWY_DYNAMIC_DISPATCH(GatherBlockStats)(coeffs, coeffs_size, nonzeros, sumabs);
}

// Writes `len` pixels of the planar float `rows` (starting at column
// `xoffset`) to `output` in the configured output format.
void WriteToOutput(j_decompress_ptr cinfo, float* JXL_RESTRICT rows[],
                   size_t xoffset, size_t len, size_t num_channels,
                   uint8_t* JXL_RESTRICT output) {
  HWY_DYNAMIC_DISPATCH(WriteToOutput)
  (cinfo, rows, xoffset, len, num_channels, output);
}

// Adds 128/255 to the first `xsize` samples of `row` (vectorized).
void DecenterRow(float* row, size_t xsize) {
  HWY_DYNAMIC_DISPATCH(DecenterRow)(row, xsize);
}
    309 
    310 bool ShouldApplyDequantBiases(j_decompress_ptr cinfo, int ci) {
    311   const auto& compinfo = cinfo->comp_info[ci];
    312   return (compinfo.h_samp_factor == cinfo->max_h_samp_factor &&
    313           compinfo.v_samp_factor == cinfo->max_v_samp_factor);
    314 }
    315 
    316 // See the following article for the details:
    317 // J. R. Price and M. Rabbani, "Dequantization bias for JPEG decompression"
    318 // Proceedings International Conference on Information Technology: Coding and
    319 // Computing (Cat. No.PR00540), 2000, pp. 30-35, doi: 10.1109/ITCC.2000.844179.
    320 void ComputeOptimalLaplacianBiases(const int num_blocks, const int* nonzeros,
    321                                    const int* sumabs, float* biases) {
    322   for (size_t k = 1; k < DCTSIZE2; ++k) {
    323     if (nonzeros[k] == 0) {
    324       biases[k] = 0.5f;
    325       continue;
    326     }
    327     // Notation adapted from the article
    328     float N = num_blocks;
    329     float N1 = nonzeros[k];
    330     float N0 = num_blocks - N1;
    331     float S = sumabs[k];
    332     // Compute gamma from N0, N1, N, S (eq. 11), with A and B being just
    333     // temporary grouping of terms.
    334     float A = 4.0 * S + 2.0 * N;
    335     float B = 4.0 * S - 2.0 * N1;
    336     float gamma = (-1.0 * N0 + std::sqrt(N0 * N0 * 1.0 + A * B)) / A;
    337     float gamma2 = gamma * gamma;
    338     // The bias is computed from gamma with (eq. 5), where the quantization
    339     // multiplier Q can be factored out and thus the bias can be applied
    340     // directly on the quantized coefficient.
    341     biases[k] =
    342         0.5 * (((1.0 + gamma2) / (1.0 - gamma2)) + 1.0 / std::log(gamma));
    343   }
    344 }
    345 
    346 constexpr std::array<int, SAVED_COEFS> Q_POS = {0, 1, 8,  16, 9,
    347                                                 2, 3, 10, 17, 24};
    348 
    349 bool is_nonzero_quantizers(const JQUANT_TBL* qtable) {
    350   return std::all_of(Q_POS.begin(), Q_POS.end(),
    351                      [&](int pos) { return qtable->quantval[pos] != 0; });
    352 }
    353 
    354 // Determine whether smoothing should be applied during decompression
    355 bool do_smoothing(j_decompress_ptr cinfo) {
    356   jpeg_decomp_master* m = cinfo->master;
    357   bool smoothing_useful = false;
    358 
    359   if (!cinfo->progressive_mode || cinfo->coef_bits == nullptr) {
    360     return false;
    361   }
    362   auto* coef_bits_latch = m->coef_bits_latch;
    363   auto* prev_coef_bits_latch = m->prev_coef_bits_latch;
    364 
    365   for (int ci = 0; ci < cinfo->num_components; ci++) {
    366     jpeg_component_info* compptr = &cinfo->comp_info[ci];
    367     JQUANT_TBL* qtable = compptr->quant_table;
    368     int* coef_bits = cinfo->coef_bits[ci];
    369     int* prev_coef_bits = cinfo->coef_bits[ci + cinfo->num_components];
    370 
    371     // Return early if conditions for smoothing are not met
    372     if (qtable == nullptr || !is_nonzero_quantizers(qtable) ||
    373         coef_bits[0] < 0) {
    374       return false;
    375     }
    376 
    377     coef_bits_latch[ci][0] = coef_bits[0];
    378 
    379     for (int coefi = 1; coefi < SAVED_COEFS; coefi++) {
    380       prev_coef_bits_latch[ci][coefi] =
    381           cinfo->input_scan_number > 1 ? prev_coef_bits[coefi] : -1;
    382       if (coef_bits[coefi] != 0) {
    383         smoothing_useful = true;
    384       }
    385       coef_bits_latch[ci][coefi] = coef_bits[coefi];
    386     }
    387   }
    388 
    389   return smoothing_useful;
    390 }
    391 
    392 void PredictSmooth(j_decompress_ptr cinfo, JBLOCKARRAY blocks, int component,
    393                    size_t bx, int iy) {
    394   const size_t imcu_row = cinfo->output_iMCU_row;
    395   int16_t* scratch = cinfo->master->smoothing_scratch_;
    396   std::vector<int> Q_VAL(SAVED_COEFS);
    397   int* coef_bits;
    398 
    399   std::array<std::array<int, 5>, 5> dc_values;
    400   auto& compinfo = cinfo->comp_info[component];
    401   const size_t by0 = imcu_row * compinfo.v_samp_factor;
    402   const size_t by = by0 + iy;
    403 
    404   int prev_iy = by > 0 ? iy - 1 : 0;
    405   int prev_prev_iy = by > 1 ? iy - 2 : prev_iy;
    406   int next_iy = by + 1 < compinfo.height_in_blocks ? iy + 1 : iy;
    407   int next_next_iy = by + 2 < compinfo.height_in_blocks ? iy + 2 : next_iy;
    408 
    409   const int16_t* cur_row = blocks[iy][bx];
    410   const int16_t* prev_row = blocks[prev_iy][bx];
    411   const int16_t* prev_prev_row = blocks[prev_prev_iy][bx];
    412   const int16_t* next_row = blocks[next_iy][bx];
    413   const int16_t* next_next_row = blocks[next_next_iy][bx];
    414 
    415   int prev_block_ind = bx ? -DCTSIZE2 : 0;
    416   int prev_prev_block_ind = bx > 1 ? -2 * DCTSIZE2 : prev_block_ind;
    417   int next_block_ind = bx + 1 < compinfo.width_in_blocks ? DCTSIZE2 : 0;
    418   int next_next_block_ind =
    419       bx + 2 < compinfo.width_in_blocks ? DCTSIZE2 * 2 : next_block_ind;
    420 
    421   std::array<const int16_t*, 5> row_ptrs = {prev_prev_row, prev_row, cur_row,
    422                                             next_row, next_next_row};
    423   std::array<int, 5> block_inds = {prev_prev_block_ind, prev_block_ind, 0,
    424                                    next_block_ind, next_next_block_ind};
    425 
    426   memcpy(scratch, cur_row, DCTSIZE2 * sizeof(cur_row[0]));
    427 
    428   for (int r = 0; r < 5; ++r) {
    429     for (int c = 0; c < 5; ++c) {
    430       dc_values[r][c] = row_ptrs[r][block_inds[c]];
    431     }
    432   }
    433   // Get the correct coef_bits: In case of an incomplete scan, we use the
    434   // prev coeficients.
    435   if (cinfo->output_iMCU_row + 1 > cinfo->input_iMCU_row) {
    436     coef_bits = cinfo->master->prev_coef_bits_latch[component];
    437   } else {
    438     coef_bits = cinfo->master->coef_bits_latch[component];
    439   }
    440 
    441   bool change_dc = true;
    442   for (int i = 1; i < SAVED_COEFS; i++) {
    443     if (coef_bits[i] != -1) {
    444       change_dc = false;
    445       break;
    446     }
    447   }
    448 
    449   JQUANT_TBL* quanttbl = cinfo->quant_tbl_ptrs[compinfo.quant_tbl_no];
    450   for (size_t i = 0; i < 6; ++i) {
    451     Q_VAL[i] = quanttbl->quantval[Q_POS[i]];
    452   }
    453   if (change_dc) {
    454     for (size_t i = 6; i < SAVED_COEFS; ++i) {
    455       Q_VAL[i] = quanttbl->quantval[Q_POS[i]];
    456     }
    457   }
    458   auto calculate_dct_value = [&](int coef_index) {
    459     int64_t num = 0;
    460     int pred;
    461     int Al;
    462     // we use the symmetry of the smoothing matrices by transposing the 5x5 dc
    463     // matrix in that case.
    464     bool swap_indices = coef_index == 2 || coef_index == 5 || coef_index == 8 ||
    465                         coef_index == 9;
    466     auto dc = [&](int i, int j) {
    467       return swap_indices ? dc_values[j][i] : dc_values[i][j];
    468     };
    469     Al = coef_bits[coef_index];
    470     JXL_ASSERT(coef_index >= 0 && coef_index < 10);
    471     switch (coef_index) {
    472       case 0:
    473         // set the DC
    474         num = (-2 * dc(0, 0) - 6 * dc(0, 1) - 8 * dc(0, 2) - 6 * dc(0, 3) -
    475                2 * dc(0, 4) - 6 * dc(1, 0) + 6 * dc(1, 1) + 42 * dc(1, 2) +
    476                6 * dc(1, 3) - 6 * dc(1, 4) - 8 * dc(2, 0) + 42 * dc(2, 1) +
    477                152 * dc(2, 2) + 42 * dc(2, 3) - 8 * dc(2, 4) - 6 * dc(3, 0) +
    478                6 * dc(3, 1) + 42 * dc(3, 2) + 6 * dc(3, 3) - 6 * dc(3, 4) -
    479                2 * dc(4, 0) - 6 * dc(4, 1) - 8 * dc(4, 2) - 6 * dc(4, 3) -
    480                2 * dc(4, 4));
    481         // special case: for the DC the dequantization is different
    482         Al = 0;
    483         break;
    484       case 1:
    485       case 2:
    486         // set Q01 or Q10
    487         num = (change_dc ? (-dc(0, 0) - dc(0, 1) + dc(0, 3) + dc(0, 4) -
    488                             3 * dc(1, 0) + 13 * dc(1, 1) - 13 * dc(1, 3) +
    489                             3 * dc(1, 4) - 3 * dc(2, 0) + 38 * dc(2, 1) -
    490                             38 * dc(2, 3) + 3 * dc(2, 4) - 3 * dc(3, 0) +
    491                             13 * dc(3, 1) - 13 * dc(3, 3) + 3 * dc(3, 4) -
    492                             dc(4, 0) - dc(4, 1) + dc(4, 3) + dc(4, 4))
    493                          : (-7 * dc(2, 0) + 50 * dc(2, 1) - 50 * dc(2, 3) +
    494                             7 * dc(2, 4)));
    495         break;
    496       case 3:
    497       case 5:
    498         // set Q02 or Q20
    499         num = (change_dc
    500                    ? dc(0, 2) + 2 * dc(1, 1) + 7 * dc(1, 2) + 2 * dc(1, 3) -
    501                          5 * dc(2, 1) - 14 * dc(2, 2) - 5 * dc(2, 3) +
    502                          2 * dc(3, 1) + 7 * dc(3, 2) + 2 * dc(3, 3) + dc(4, 2)
    503                    : (-dc(0, 2) + 13 * dc(1, 2) - 24 * dc(2, 2) +
    504                       13 * dc(3, 2) - dc(4, 2)));
    505         break;
    506       case 4:
    507         // set Q11
    508         num =
    509             (change_dc ? -dc(0, 0) + dc(0, 4) + 9 * dc(1, 1) - 9 * dc(1, 3) -
    510                              9 * dc(3, 1) + 9 * dc(3, 3) + dc(4, 0) - dc(4, 4)
    511                        : (dc(1, 4) + dc(3, 0) - 10 * dc(3, 1) + 10 * dc(3, 3) -
    512                           dc(0, 1) - dc(3, 4) + dc(4, 1) - dc(4, 3) + dc(0, 3) -
    513                           dc(1, 0) + 10 * dc(1, 1) - 10 * dc(1, 3)));
    514         break;
    515       case 6:
    516       case 9:
    517         // set Q03 or Q30
    518         num = (dc(1, 1) - dc(1, 3) + 2 * dc(2, 1) - 2 * dc(2, 3) + dc(3, 1) -
    519                dc(3, 3));
    520         break;
    521       case 7:
    522       case 8:
    523       default:
    524         // set Q12 and Q21
    525         num = (dc(1, 1) - 3 * dc(1, 2) + dc(1, 3) - dc(3, 1) + 3 * dc(3, 2) -
    526                dc(3, 3));
    527         break;
    528     }
    529     num = Q_VAL[0] * num;
    530     if (num >= 0) {
    531       pred = ((Q_VAL[coef_index] << 7) + num) / (Q_VAL[coef_index] << 8);
    532       if (Al > 0 && pred >= (1 << Al)) pred = (1 << Al) - 1;
    533     } else {
    534       pred = ((Q_VAL[coef_index] << 7) - num) / (Q_VAL[coef_index] << 8);
    535       if (Al > 0 && pred >= (1 << Al)) pred = (1 << Al) - 1;
    536       pred = -pred;
    537     }
    538     return static_cast<int16_t>(pred);
    539   };
    540 
    541   int loop_end = change_dc ? SAVED_COEFS : 6;
    542   for (int i = 1; i < loop_end; ++i) {
    543     if (coef_bits[i] != 0 && scratch[Q_POS[i]] == 0) {
    544       scratch[Q_POS[i]] = calculate_dct_value(i);
    545     }
    546   }
    547   if (change_dc) {
    548     scratch[0] = calculate_dct_value(0);
    549   }
    550 }
    551 
    552 void PrepareForOutput(j_decompress_ptr cinfo) {
    553   jpeg_decomp_master* m = cinfo->master;
    554   bool smoothing = do_smoothing(cinfo);
    555   m->apply_smoothing = smoothing && FROM_JXL_BOOL(cinfo->do_block_smoothing);
    556   size_t coeffs_per_block = cinfo->num_components * DCTSIZE2;
    557   memset(m->nonzeros_, 0, coeffs_per_block * sizeof(m->nonzeros_[0]));
    558   memset(m->sumabs_, 0, coeffs_per_block * sizeof(m->sumabs_[0]));
    559   memset(m->num_processed_blocks_, 0, sizeof(m->num_processed_blocks_));
    560   memset(m->biases_, 0, coeffs_per_block * sizeof(m->biases_[0]));
    561   cinfo->output_iMCU_row = 0;
    562   cinfo->output_scanline = 0;
    563   const float kDequantScale = 1.0f / (8 * 255);
    564   for (int c = 0; c < cinfo->num_components; c++) {
    565     const auto& comp = cinfo->comp_info[c];
    566     JQUANT_TBL* table = comp.quant_table;
    567     if (table == nullptr) continue;
    568     for (size_t k = 0; k < DCTSIZE2; ++k) {
    569       m->dequant_[c * DCTSIZE2 + k] = table->quantval[k] * kDequantScale;
    570     }
    571   }
    572   ChooseInverseTransform(cinfo);
    573   ChooseColorTransform(cinfo);
    574 }
    575 
    576 void DecodeCurrentiMCURow(j_decompress_ptr cinfo) {
    577   jpeg_decomp_master* m = cinfo->master;
    578   const size_t imcu_row = cinfo->output_iMCU_row;
    579   JBLOCKARRAY ba[kMaxComponents];
    580   for (int c = 0; c < cinfo->num_components; ++c) {
    581     const jpeg_component_info* comp = &cinfo->comp_info[c];
    582     int by0 = imcu_row * comp->v_samp_factor;
    583     int block_rows_left = comp->height_in_blocks - by0;
    584     int max_block_rows = std::min(comp->v_samp_factor, block_rows_left);
    585     int offset = m->streaming_mode_ ? 0 : by0;
    586     ba[c] = (*cinfo->mem->access_virt_barray)(
    587         reinterpret_cast<j_common_ptr>(cinfo), m->coef_arrays[c], offset,
    588         max_block_rows, FALSE);
    589   }
    590   for (int c = 0; c < cinfo->num_components; ++c) {
    591     size_t k0 = c * DCTSIZE2;
    592     auto& compinfo = cinfo->comp_info[c];
    593     size_t block_row = imcu_row * compinfo.v_samp_factor;
    594     if (ShouldApplyDequantBiases(cinfo, c)) {
    595       // Update statistics for this iMCU row.
    596       for (int iy = 0; iy < compinfo.v_samp_factor; ++iy) {
    597         size_t by = block_row + iy;
    598         if (by >= compinfo.height_in_blocks) {
    599           continue;
    600         }
    601         int16_t* JXL_RESTRICT coeffs = &ba[c][iy][0][0];
    602         size_t num = compinfo.width_in_blocks * DCTSIZE2;
    603         GatherBlockStats(coeffs, num, &m->nonzeros_[k0], &m->sumabs_[k0]);
    604         m->num_processed_blocks_[c] += compinfo.width_in_blocks;
    605       }
    606       if (imcu_row % 4 == 3) {
    607         // Re-compute optimal biases every few iMCU-rows.
    608         ComputeOptimalLaplacianBiases(m->num_processed_blocks_[c],
    609                                       &m->nonzeros_[k0], &m->sumabs_[k0],
    610                                       &m->biases_[k0]);
    611       }
    612     }
    613     RowBuffer<float>* raw_out = &m->raw_output_[c];
    614     for (int iy = 0; iy < compinfo.v_samp_factor; ++iy) {
    615       size_t by = block_row + iy;
    616       if (by >= compinfo.height_in_blocks) {
    617         continue;
    618       }
    619       size_t dctsize = m->scaled_dct_size[c];
    620       int16_t* JXL_RESTRICT row_in = &ba[c][iy][0][0];
    621       float* JXL_RESTRICT row_out = raw_out->Row(by * dctsize);
    622       for (size_t bx = 0; bx < compinfo.width_in_blocks; ++bx) {
    623         if (m->apply_smoothing) {
    624           PredictSmooth(cinfo, ba[c], c, bx, iy);
    625           (*m->inverse_transform[c])(m->smoothing_scratch_, &m->dequant_[k0],
    626                                      &m->biases_[k0], m->idct_scratch_,
    627                                      &row_out[bx * dctsize], raw_out->stride(),
    628                                      dctsize);
    629         } else {
    630           (*m->inverse_transform[c])(&row_in[bx * DCTSIZE2], &m->dequant_[k0],
    631                                      &m->biases_[k0], m->idct_scratch_,
    632                                      &row_out[bx * dctsize], raw_out->stride(),
    633                                      dctsize);
    634         }
    635       }
    636       if (m->streaming_mode_) {
    637         memset(row_in, 0, compinfo.width_in_blocks * sizeof(JBLOCK));
    638       }
    639     }
    640   }
    641 }
    642 
    643 void ProcessRawOutput(j_decompress_ptr cinfo, JSAMPIMAGE data) {
    644   jpegli::DecodeCurrentiMCURow(cinfo);
    645   jpeg_decomp_master* m = cinfo->master;
    646   for (int c = 0; c < cinfo->num_components; ++c) {
    647     const auto& compinfo = cinfo->comp_info[c];
    648     size_t comp_width = compinfo.width_in_blocks * DCTSIZE;
    649     size_t comp_height = compinfo.height_in_blocks * DCTSIZE;
    650     size_t comp_nrows = compinfo.v_samp_factor * DCTSIZE;
    651     size_t y0 = cinfo->output_iMCU_row * compinfo.v_samp_factor * DCTSIZE;
    652     size_t y1 = std::min(y0 + comp_nrows, comp_height);
    653     for (size_t y = y0; y < y1; ++y) {
    654       float* rows[1] = {m->raw_output_[c].Row(y)};
    655       uint8_t* output = data[c][y - y0];
    656       DecenterRow(rows[0], comp_width);
    657       WriteToOutput(cinfo, rows, 0, comp_width, 1, output);
    658     }
    659   }
    660   ++cinfo->output_iMCU_row;
    661   cinfo->output_scanline += cinfo->max_v_samp_factor * DCTSIZE;
    662   if (cinfo->output_scanline >= cinfo->output_height) {
    663     ++m->output_passes_done_;
    664   }
    665 }
    666 
    667 void ProcessOutput(j_decompress_ptr cinfo, size_t* num_output_rows,
    668                    JSAMPARRAY scanlines, size_t max_output_rows) {
    669   jpeg_decomp_master* m = cinfo->master;
    670   const int vfactor = cinfo->max_v_samp_factor;
    671   const int hfactor = cinfo->max_h_samp_factor;
    672   const size_t context = m->need_context_rows_ ? 1 : 0;
    673   const size_t imcu_row = cinfo->output_iMCU_row;
    674   const size_t imcu_height = vfactor * m->min_scaled_dct_size;
    675   const size_t imcu_width = hfactor * m->min_scaled_dct_size;
    676   const size_t output_width = m->iMCU_cols_ * imcu_width;
    677   if (imcu_row == cinfo->total_iMCU_rows ||
    678       (imcu_row > context &&
    679        cinfo->output_scanline < (imcu_row - context) * imcu_height)) {
    680     // We are ready to output some scanlines.
    681     size_t ybegin = cinfo->output_scanline;
    682     size_t yend = (imcu_row == cinfo->total_iMCU_rows
    683                        ? cinfo->output_height
    684                        : (imcu_row - context) * imcu_height);
    685     yend = std::min<size_t>(yend, ybegin + max_output_rows - *num_output_rows);
    686     size_t yb = (ybegin / vfactor) * vfactor;
    687     size_t ye = DivCeil(yend, vfactor) * vfactor;
    688     for (size_t y = yb; y < ye; y += vfactor) {
    689       for (int c = 0; c < cinfo->num_components; ++c) {
    690         RowBuffer<float>* raw_out = &m->raw_output_[c];
    691         RowBuffer<float>* render_out = &m->render_output_[c];
    692         int line_groups = vfactor / m->v_factor[c];
    693         int downsampled_width = output_width / m->h_factor[c];
    694         size_t yc = y / m->v_factor[c];
    695         for (int dy = 0; dy < line_groups; ++dy) {
    696           size_t ymid = yc + dy;
    697           const float* JXL_RESTRICT row_mid = raw_out->Row(ymid);
    698           if (cinfo->do_fancy_upsampling && m->v_factor[c] == 2) {
    699             const float* JXL_RESTRICT row_top =
    700                 ymid == 0 ? row_mid : raw_out->Row(ymid - 1);
    701             const float* JXL_RESTRICT row_bot = ymid + 1 == m->raw_height_[c]
    702                                                     ? row_mid
    703                                                     : raw_out->Row(ymid + 1);
    704             Upsample2Vertical(row_top, row_mid, row_bot,
    705                               render_out->Row(2 * dy),
    706                               render_out->Row(2 * dy + 1), downsampled_width);
    707           } else {
    708             for (int yix = 0; yix < m->v_factor[c]; ++yix) {
    709               memcpy(render_out->Row(m->v_factor[c] * dy + yix), row_mid,
    710                      downsampled_width * sizeof(float));
    711             }
    712           }
    713           if (m->h_factor[c] > 1) {
    714             for (int yix = 0; yix < m->v_factor[c]; ++yix) {
    715               int row_ix = m->v_factor[c] * dy + yix;
    716               float* JXL_RESTRICT row = render_out->Row(row_ix);
    717               float* JXL_RESTRICT tmp = m->upsample_scratch_;
    718               if (cinfo->do_fancy_upsampling && m->h_factor[c] == 2) {
    719                 Upsample2Horizontal(row, tmp, output_width);
    720               } else {
    721                 // TODO(szabadka) SIMDify this.
    722                 for (size_t x = 0; x < output_width; ++x) {
    723                   tmp[x] = row[x / m->h_factor[c]];
    724                 }
    725                 memcpy(row, tmp, output_width * sizeof(tmp[0]));
    726               }
    727             }
    728           }
    729         }
    730       }
    731       for (int yix = 0; yix < vfactor; ++yix) {
    732         if (y + yix < ybegin || y + yix >= yend) continue;
    733         float* rows[kMaxComponents];
    734         int num_all_components =
    735             std::max(cinfo->out_color_components, cinfo->num_components);
    736         for (int c = 0; c < num_all_components; ++c) {
    737           rows[c] = m->render_output_[c].Row(yix);
    738         }
    739         (*m->color_transform)(rows, output_width);
    740         for (int c = 0; c < cinfo->out_color_components; ++c) {
    741           // Undo the centering of the sample values around zero.
    742           DecenterRow(rows[c], output_width);
    743         }
    744         if (scanlines) {
    745           uint8_t* output = scanlines[*num_output_rows];
    746           WriteToOutput(cinfo, rows, m->xoffset_, cinfo->output_width,
    747                         cinfo->out_color_components, output);
    748         }
    749         JXL_ASSERT(cinfo->output_scanline == y + yix);
    750         ++cinfo->output_scanline;
    751         ++(*num_output_rows);
    752         if (cinfo->output_scanline == cinfo->output_height) {
    753           ++m->output_passes_done_;
    754         }
    755       }
    756     }
    757   } else {
    758     DecodeCurrentiMCURow(cinfo);
    759     ++cinfo->output_iMCU_row;
    760   }
    761 }
    762 
    763 }  // namespace jpegli
    764 #endif  // HWY_ONCE