libjxl

FORK: libjxl patches used on the blog
git clone https://git.neptards.moe/blog/libjxl.git

dec_group.cc (32464B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/dec_group.h"
      7 
      8 #include <stdint.h>
      9 #include <string.h>
     10 
     11 #include <algorithm>
     12 #include <memory>
     13 #include <utility>
     14 
     15 #include "lib/jxl/frame_header.h"
     16 
     17 #undef HWY_TARGET_INCLUDE
     18 #define HWY_TARGET_INCLUDE "lib/jxl/dec_group.cc"
     19 #include <hwy/foreach_target.h>
     20 #include <hwy/highway.h>
     21 
     22 #include "lib/jxl/ac_context.h"
     23 #include "lib/jxl/ac_strategy.h"
     24 #include "lib/jxl/base/bits.h"
     25 #include "lib/jxl/base/common.h"
     26 #include "lib/jxl/base/printf_macros.h"
     27 #include "lib/jxl/base/status.h"
     28 #include "lib/jxl/coeff_order.h"
     29 #include "lib/jxl/common.h"  // kMaxNumPasses
     30 #include "lib/jxl/dec_cache.h"
     31 #include "lib/jxl/dec_transforms-inl.h"
     32 #include "lib/jxl/dec_xyb.h"
     33 #include "lib/jxl/entropy_coder.h"
     34 #include "lib/jxl/quant_weights.h"
     35 #include "lib/jxl/quantizer-inl.h"
     36 #include "lib/jxl/quantizer.h"
     37 
     38 #ifndef LIB_JXL_DEC_GROUP_CC
     39 #define LIB_JXL_DEC_GROUP_CC
     40 namespace jxl {
     41 
     42 struct AuxOut;
     43 
     44 // Interface for reading groups for DecodeGroupImpl.
     45 class GetBlock {
     46  public:
     47   virtual void StartRow(size_t by) = 0;
     48   virtual Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs,
     49                            size_t size, size_t log2_covered_blocks,
     50                            ACPtr block[3], ACType ac_type) = 0;
     51   virtual ~GetBlock() {}
     52 };
     53 
     54 // Controls whether DecodeGroupImpl renders to pixels or not.
     55 enum DrawMode {
     56   // Render to pixels.
     57   kDraw = 0,
     58   // Don't render to pixels.
     59   kDontDraw = 1,
     60 };
     61 
     62 }  // namespace jxl
     63 #endif  // LIB_JXL_DEC_GROUP_CC
     64 
     65 HWY_BEFORE_NAMESPACE();
     66 namespace jxl {
     67 namespace HWY_NAMESPACE {
     68 
     69 // These templates are not found via ADL.
     70 using hwy::HWY_NAMESPACE::AllFalse;
     71 using hwy::HWY_NAMESPACE::Gt;
     72 using hwy::HWY_NAMESPACE::Lt;
     73 using hwy::HWY_NAMESPACE::MaskFromVec;
     74 using hwy::HWY_NAMESPACE::Or;
     75 using hwy::HWY_NAMESPACE::Rebind;
     76 using hwy::HWY_NAMESPACE::ShiftRight;
     77 
     78 using D = HWY_FULL(float);
     79 using DU = HWY_FULL(uint32_t);
     80 using DI = HWY_FULL(int32_t);
     81 using DI16 = Rebind<int16_t, DI>;
     82 using DI16_FULL = HWY_CAPPED(int16_t, kDCTBlockSize);
     83 constexpr D d;
     84 constexpr DI di;
     85 constexpr DI16 di16;
     86 constexpr DI16_FULL di16_full;
     87 
     88 // TODO(veluca): consider SIMDfying.
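        // Transposes an 8x8 block of coefficients in place by swapping mirrored
        // entries; used in the JPEG path below, since libjxl stores DCT
        // coefficients transposed relative to JPEG.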
     89 void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) {
     90   for (size_t x = 0; x < 8; x++) {
     91     for (size_t y = x + 1; y < 8; y++) {
     92       std::swap(block[y * 8 + x], block[x * 8 + y]);
     93     }
     94   }
     95 }
     96 
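        // Dequantizes Lanes(d) coefficients at offset k for all three channels:
        // loads the quantized values (16- or 32-bit), applies the quantization
        // bias, scales by the dequant matrix and the per-channel multipliers,
        // then adds the chroma-from-luma contribution (x_cc_mul / b_cc_mul times
        // the dequantized Y value) to the X and B channels.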
     97 template <ACType ac_type>
     98 void DequantLane(Vec<D> scaled_dequant_x, Vec<D> scaled_dequant_y,
     99                  Vec<D> scaled_dequant_b,
    100                  const float* JXL_RESTRICT dequant_matrices, size_t size,
    101                  size_t k, Vec<D> x_cc_mul, Vec<D> b_cc_mul,
    102                  const float* JXL_RESTRICT biases, ACPtr qblock[3],
    103                  float* JXL_RESTRICT block) {
    104   const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x);
    105   const auto y_mul =
    106       Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y);
    107   const auto b_mul =
    108       Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b);
    109 
    110   Vec<DI> quantized_x_int;
    111   Vec<DI> quantized_y_int;
    112   Vec<DI> quantized_b_int;
    113   if (ac_type == ACType::k16) {
    114     Rebind<int16_t, DI> di16;
    115     quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k));
    116     quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k));
    117     quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k));
    118   } else {
    119     quantized_x_int = Load(di, qblock[0].ptr32 + k);
    120     quantized_y_int = Load(di, qblock[1].ptr32 + k);
    121     quantized_b_int = Load(di, qblock[2].ptr32 + k);
    122   }
    123 
    124   const auto dequant_x_cc =
    125       Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul);
    126   const auto dequant_y =
    127       Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul);
    128   const auto dequant_b_cc =
    129       Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul);
    130 
    131   const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc);
    132   const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc);
    133   Store(dequant_x, d, block + k);
    134   Store(dequant_y, d, block + size + k);
    135   Store(dequant_b, d, block + 2 * size + k);
    136 }
    137 
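        // Dequantizes a whole varblock (covered_blocks * kDCTBlockSize
        // coefficients per channel) and then overwrites its lowest-frequency
        // coefficients with values derived from the already-decoded DC image.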
    138 template <ACType ac_type>
    139 void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant,
    140                   float x_dm_multiplier, float b_dm_multiplier, Vec<D> x_cc_mul,
    141                   Vec<D> b_cc_mul, size_t kind, size_t size,
    142                   const Quantizer& quantizer, size_t covered_blocks,
    143                   const size_t* sbx,
    144                   const float* JXL_RESTRICT* JXL_RESTRICT dc_row,
    145                   size_t dc_stride, const float* JXL_RESTRICT biases,
    146                   ACPtr qblock[3], float* JXL_RESTRICT block,
    147                   float* JXL_RESTRICT scratch) {
    148   const auto scaled_dequant_s = inv_global_scale / quant;
    149 
    150   const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier);
    151   const auto scaled_dequant_y = Set(d, scaled_dequant_s);
    152   const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier);
    153 
    154   const float* dequant_matrices = quantizer.DequantMatrix(kind, 0);
    155 
    156   for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) {
    157     DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b,
    158                          dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases,
    159                          qblock, block);
    160   }
    161   for (size_t c = 0; c < 3; c++) {
    162     LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride,
    163                             block + c * size, scratch);
    164   }
    165 }
    166 
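        // Decodes the varblocks of one group: coefficients are obtained through
        // `get_block` (and accumulated into dec_state->coefficients when those
        // are kept), then either dequantized and transformed into the render
        // pipeline buffers, converted back to JPEG coefficients when
        // transcoding, or skipped entirely when draw == kDontDraw.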
    167 Status DecodeGroupImpl(const FrameHeader& frame_header,
    168                        GetBlock* JXL_RESTRICT get_block,
    169                        GroupDecCache* JXL_RESTRICT group_dec_cache,
    170                        PassesDecoderState* JXL_RESTRICT dec_state,
    171                        size_t thread, size_t group_idx,
    172                        RenderPipelineInput& render_pipeline_input,
    173                        ImageBundle* decoded, DrawMode draw) {
    174   // TODO(veluca): investigate cache usage in this function.
    175   const Rect block_rect =
    176       dec_state->shared->frame_dim.BlockGroupRect(group_idx);
    177   const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy;
    178 
    179   const size_t xsize_blocks = block_rect.xsize();
    180   const size_t ysize_blocks = block_rect.ysize();
    181 
    182   const size_t dc_stride = dec_state->shared->dc->PixelsPerRow();
    183 
    184   const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale();
    185 
    186   const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
    187 
    188   const auto kJpegDctMin = Set(di16_full, -4095);
    189   const auto kJpegDctMax = Set(di16_full, 4095);
    190 
    191   size_t idct_stride[3];
    192   for (size_t c = 0; c < 3; c++) {
    193     idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow();
    194   }
    195 
    196   HWY_ALIGN int32_t scaled_qtable[64 * 3];
    197 
    198   ACType ac_type = dec_state->coefficients->Type();
    199   auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16>
    200                                               : DequantBlock<ACType::k32>;
    201   // Whether or not coefficients should be stored for future usage, and/or read
    202   // from past usage.
    203   bool accumulate = !dec_state->coefficients->IsEmpty();
    204   // Offset of the current block in the group.
    205   size_t offset = 0;
    206 
    207   std::array<int, 3> jpeg_c_map;
    208   bool jpeg_is_gray = false;
    209   std::array<int, 3> dcoff = {};
    210 
    211   // TODO(veluca): all of this should be done only once per image.
    212   if (decoded->IsJPEG()) {
    213     if (!dec_state->shared->cmap.IsJPEGCompatible()) {
    214       return JXL_FAILURE("The CfL map is not JPEG-compatible");
    215     }
    216     jpeg_is_gray = (decoded->jpeg_data->components.size() == 1);
    217     jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray);
    218     const std::vector<QuantEncoding>& qe =
    219         dec_state->shared->matrices.encodings();
    220     if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
    221         std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
    222       return JXL_FAILURE(
    223           "Quantization table is not a JPEG quantization table.");
    224     }
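            // scaled_qtable[64 * c + ..] will hold the ratio between channel 1's
            // (luma) quantization step and channel c's, in
            // kCFLFixedPointPrecision fixed point; the JPEG path below uses it
            // to apply the chroma-from-luma correction directly in the
            // quantized JPEG coefficient domain.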
    225     for (size_t c = 0; c < 3; c++) {
    226       if (frame_header.color_transform == ColorTransform::kNone) {
    227         dcoff[c] = 1024 / (*qe[0].qraw.qtable)[64 * c];
    228       }
    229       for (size_t i = 0; i < 64; i++) {
    230         // Transpose the matrix, as it will be used on the transposed block.
    231         int n = qe[0].qraw.qtable->at(64 + i);
    232         int d = qe[0].qraw.qtable->at(64 * c + i);
    233         if (n <= 0 || d <= 0 || n >= 65536 || d >= 65536) {
    234           return JXL_FAILURE("Invalid JPEG quantization table");
    235         }
    236         scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] =
    237             (1 << kCFLFixedPointPrecision) * n / d;
    238       }
    239     }
    240   }
    241 
    242   size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)};
    243   size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)};
    244   Rect r[3];
    245   for (size_t i = 0; i < 3; i++) {
    246     r[i] =
    247         Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i],
    248              block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]);
    249     if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(),
    250                         dec_state->shared->dc->Plane(i).ysize()})) {
    251       return JXL_FAILURE("Frame dimensions are too big for the image.");
    252     }
    253   }
    254 
    255   for (size_t by = 0; by < ysize_blocks; ++by) {
    256     get_block->StartRow(by);
    257     size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]};
    258 
    259     const int32_t* JXL_RESTRICT row_quant =
    260         block_rect.ConstRow(dec_state->shared->raw_quant_field, by);
    261 
    262     const float* JXL_RESTRICT dc_rows[3] = {
    263         r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]),
    264         r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]),
    265         r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]),
    266     };
    267 
    268     const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks;
    269     AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by);
    270 
    271     const int8_t* JXL_RESTRICT row_cmap[3] = {
    272         dec_state->shared->cmap.ytox_map.ConstRow(ty),
    273         nullptr,
    274         dec_state->shared->cmap.ytob_map.ConstRow(ty),
    275     };
    276 
    277     float* JXL_RESTRICT idct_row[3];
    278     int16_t* JXL_RESTRICT jpeg_row[3];
    279     for (size_t c = 0; c < 3; c++) {
    280       idct_row[c] = render_pipeline_input.GetBuffer(c).second.Row(
    281           render_pipeline_input.GetBuffer(c).first, sby[c] * kBlockDim);
    282       if (decoded->IsJPEG()) {
    283         auto& component = decoded->jpeg_data->components[jpeg_c_map[c]];
    284         jpeg_row[c] =
    285             component.coeffs.data() +
    286             (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) *
    287                 kDCTBlockSize;
    288       }
    289     }
    290 
    291     size_t bx = 0;
    292     for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
    293          tx++) {
    294       size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks;
    295       auto x_cc_mul =
    296           Set(d, dec_state->shared->cmap.YtoXRatio(row_cmap[0][abs_tx]));
    297       auto b_cc_mul =
    298           Set(d, dec_state->shared->cmap.YtoBRatio(row_cmap[2][abs_tx]));
    299       // Increment bx by llf_x because those iterations would otherwise
    300       // immediately continue (!IsFirstBlock). Reduces mispredictions.
    301       for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) {
    302         size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]};
    303         AcStrategy acs = acs_row[bx];
    304         const size_t llf_x = acs.covered_blocks_x();
    305 
    306         // Can only happen in the second or lower rows of a varblock.
    307         if (JXL_UNLIKELY(!acs.IsFirstBlock())) {
    308           bx += llf_x;
    309           continue;
    310         }
    311         const size_t log2_covered_blocks = acs.log2_covered_blocks();
    312 
    313         const size_t covered_blocks = 1 << log2_covered_blocks;
    314         const size_t size = covered_blocks * kDCTBlockSize;
    315 
    316         ACPtr qblock[3];
    317         if (accumulate) {
    318           for (size_t c = 0; c < 3; c++) {
    319             qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset);
    320           }
    321         } else {
    322           // Reading from the bitstream when neither accumulating nor
    323           // drawing would be pointless.
    324           JXL_ASSERT(draw == kDraw);
    325           if (ac_type == ACType::k16) {
    326             memset(group_dec_cache->dec_group_qblock16, 0,
    327                    size * 3 * sizeof(int16_t));
    328             for (size_t c = 0; c < 3; c++) {
    329               qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size;
    330             }
    331           } else {
    332             memset(group_dec_cache->dec_group_qblock, 0,
    333                    size * 3 * sizeof(int32_t));
    334             for (size_t c = 0; c < 3; c++) {
    335               qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size;
    336             }
    337           }
    338         }
    339         JXL_RETURN_IF_ERROR(get_block->LoadBlock(
    340             bx, by, acs, size, log2_covered_blocks, qblock, ac_type));
    341         offset += size;
    342         if (draw == kDontDraw) {
    343           bx += llf_x;
    344           continue;
    345         }
    346 
    347         if (JXL_UNLIKELY(decoded->IsJPEG())) {
    348           if (acs.Strategy() != AcStrategy::Type::DCT) {
    349             return JXL_FAILURE(
    350                 "Can only decode to JPEG if only DCT-8 is used.");
    351           }
    352 
    353           HWY_ALIGN int32_t transposed_dct_y[64];
    354           for (size_t c : {1, 0, 2}) {
    355             // Propagate only Y for grayscale.
    356             if (jpeg_is_gray && c != 1) {
    357               continue;
    358             }
    359             if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
    360               continue;
    361             }
    362             int16_t* JXL_RESTRICT jpeg_pos =
    363                 jpeg_row[c] + sbx[c] * kDCTBlockSize;
    364             // JPEG XL is transposed, JPEG is not.
    365             auto* transposed_dct = qblock[c].ptr32;
    366             Transpose8x8InPlace(transposed_dct);
    367             // No CfL - no need to store the y block converted to integers.
    368             if (!cs.Is444() ||
    369                 (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) {
    370               for (size_t i = 0; i < 64; i += Lanes(d)) {
    371                 const auto ini = Load(di, transposed_dct + i);
    372                 const auto ini16 = DemoteTo(di16, ini);
    373                 StoreU(ini16, di16, jpeg_pos + i);
    374               }
    375             } else if (c == 1) {
    376               // Y channel: save for restoring X/B, but nothing else to do.
    377               for (size_t i = 0; i < 64; i += Lanes(d)) {
    378                 const auto ini = Load(di, transposed_dct + i);
    379                 Store(ini, di, transposed_dct_y + i);
    380                 const auto ini16 = DemoteTo(di16, ini);
    381                 StoreU(ini16, di16, jpeg_pos + i);
    382               }
    383             } else {
    384               // transposed_dct_y contains the y channel block, transposed.
    385               const auto scale = Set(
    386                   di, dec_state->shared->cmap.RatioJPEG(row_cmap[c][abs_tx]));
    387               const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1));
    388               for (int i = 0; i < 64; i += Lanes(d)) {
    389                 auto in = Load(di, transposed_dct + i);
    390                 auto in_y = Load(di, transposed_dct_y + i);
    391                 auto qt = Load(di, scaled_qtable + c * size + i);
    392                 auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>(
    393                     Add(Mul(qt, scale), round));
    394                 auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>(
    395                     Add(Mul(in_y, coeff_scale), round));
    396                 StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i);
    397               }
    398             }
    399             jpeg_pos[0] =
    400                 Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047);
    401             auto overflow = MaskFromVec(Set(di16_full, 0));
    402             auto underflow = MaskFromVec(Set(di16_full, 0));
    403             for (int i = 0; i < 64; i += Lanes(di16_full)) {
    404               auto in = LoadU(di16_full, jpeg_pos + i);
    405               overflow = Or(overflow, Gt(in, kJpegDctMax));
    406               underflow = Or(underflow, Lt(in, kJpegDctMin));
    407             }
    408             if (!AllFalse(di16_full, Or(overflow, underflow))) {
    409               return JXL_FAILURE("JPEG DCT coefficients out of range");
    410             }
    411           }
    412         } else {
    413           HWY_ALIGN float* const block = group_dec_cache->dec_group_block;
    414           // Dequantize and add predictions.
    415           dequant_block(
    416               acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier,
    417               dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(),
    418               size, dec_state->shared->quantizer,
    419               acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows,
    420               dc_stride,
    421               dec_state->output_encoding_info.opsin_params.quant_biases, qblock,
    422               block, group_dec_cache->scratch_space);
    423 
    424           for (size_t c : {1, 0, 2}) {
    425             if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) {
    426               continue;
    427             }
    428             // IDCT
    429             float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim;
    430             TransformToPixels(acs.Strategy(), block + c * size, idct_pos,
    431                               idct_stride[c], group_dec_cache->scratch_space);
    432           }
    433         }
    434         bx += llf_x;
    435       }
    436     }
    437   }
    438   return true;
    439 }
    440 
    441 // NOLINTNEXTLINE(google-readability-namespace-comments)
    442 }  // namespace HWY_NAMESPACE
    443 }  // namespace jxl
    444 HWY_AFTER_NAMESPACE();
    445 
    446 #if HWY_ONCE
    447 namespace jxl {
    448 namespace {
    449 // Decode quantized AC coefficients of DCT blocks.
    450 // LLF components in the output block will not be modified.
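        // Decoding proceeds in two steps: first the non-zero count, read from a
        // context predicted from the top and left neighbours, then the
        // coefficients themselves in the per-strategy coefficient order, each
        // from a zero-density context that depends on the position, the number
        // of non-zeros still expected, and whether the previous coefficient was
        // non-zero.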
    451 template <ACType ac_type, bool uses_lz77>
    452 Status DecodeACVarBlock(size_t ctx_offset, size_t log2_covered_blocks,
    453                         int32_t* JXL_RESTRICT row_nzeros,
    454                         const int32_t* JXL_RESTRICT row_nzeros_top,
    455                         size_t nzeros_stride, size_t c, size_t bx, size_t by,
    456                         size_t lbx, AcStrategy acs,
    457                         const coeff_order_t* JXL_RESTRICT coeff_order,
    458                         BitReader* JXL_RESTRICT br,
    459                         ANSSymbolReader* JXL_RESTRICT decoder,
    460                         const std::vector<uint8_t>& context_map,
    461                         const uint8_t* qdc_row, const int32_t* qf_row,
    462                         const BlockCtxMap& block_ctx_map, ACPtr block,
    463                         size_t shift = 0) {
    464   // Equal to number of LLF coefficients.
    465   const size_t covered_blocks = 1 << log2_covered_blocks;
    466   const size_t size = covered_blocks * kDCTBlockSize;
    467   int32_t predicted_nzeros =
    468       PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32);
    469 
    470   size_t ord = kStrategyOrder[acs.RawStrategy()];
    471   const coeff_order_t* JXL_RESTRICT order =
    472       &coeff_order[CoeffOrderOffset(ord, c)];
    473 
    474   size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c);
    475   const int32_t nzero_ctx =
    476       block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset;
    477 
    478   size_t nzeros =
    479       decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map);
    480   if (nzeros > size - covered_blocks) {
    481     return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS
    482                        " 8x8 blocks",
    483                        nzeros, covered_blocks);
    484   }
    485   for (size_t y = 0; y < acs.covered_blocks_y(); y++) {
    486     for (size_t x = 0; x < acs.covered_blocks_x(); x++) {
    487       row_nzeros[bx + x + y * nzeros_stride] =
    488           (nzeros + covered_blocks - 1) >> log2_covered_blocks;
    489     }
    490   }
    491 
    492   const size_t histo_offset =
    493       ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx);
    494 
    495   size_t prev = (nzeros > size / 16 ? 0 : 1);
    496   for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) {
    497     const size_t ctx =
    498         histo_offset + ZeroDensityContext(nzeros, k, covered_blocks,
    499                                           log2_covered_blocks, prev);
    500     const size_t u_coeff =
    501         decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map);
    502     // Hand-rolled version of UnpackSigned, shifting before the conversion to
    503     // signed integer to avoid undefined behavior of shifting negative numbers.
    504     const size_t magnitude = u_coeff >> 1;
    505     const size_t neg_sign = (~u_coeff) & 1;
    506     const intptr_t coeff =
    507         static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift);
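            // With shift == 0 this maps u_coeff 0, 1, 2, 3, 4, ... to
            // 0, -1, +1, -2, +2, ...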
    508     if (ac_type == ACType::k16) {
    509       block.ptr16[order[k]] += coeff;
    510     } else {
    511       block.ptr32[order[k]] += coeff;
    512     }
    513     prev = static_cast<size_t>(u_coeff != 0);
    514     nzeros -= prev;
    515   }
    516   if (JXL_UNLIKELY(nzeros != 0)) {
    517     return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS
    518                        ", should be 0. Block (%" PRIuS ", %" PRIuS
    519                        "), channel %" PRIuS,
    520                        nzeros, bx, by, c);
    521   }
    522 
    523   return true;
    524 }
    525 
    526 // Structs used by DecodeGroupImpl to get a quantized block.
    527 // GetBlockFromBitstream uses ANS decoding (and thus keeps track of row
    528 // pointers in row_nzeros), GetBlockFromEncoder simply reads the coefficient
    529 // image provided by the encoder.
    530 
    531 struct GetBlockFromBitstream : public GetBlock {
    532   void StartRow(size_t by) override {
    533     qf_row = rect.ConstRow(*qf, by);
    534     for (size_t c = 0; c < 3; c++) {
    535       size_t sby = by >> vshift[c];
    536       quant_dc_row = quant_dc->ConstRow(rect.y0() + by) + rect.x0();
    537       for (size_t i = 0; i < num_passes; i++) {
    538         row_nzeros[i][c] = group_dec_cache->num_nzeroes[i].PlaneRow(c, sby);
    539         row_nzeros_top[i][c] =
    540             sby == 0
    541                 ? nullptr
    542                 : group_dec_cache->num_nzeroes[i].ConstPlaneRow(c, sby - 1);
    543       }
    544     }
    545   }
    546 
    547   Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
    548                    size_t log2_covered_blocks, ACPtr block[3],
    549                    ACType ac_type) override {
    550
    551     for (size_t c : {1, 0, 2}) {
    552       size_t sbx = bx >> hshift[c];
    553       size_t sby = by >> vshift[c];
    554       if (JXL_UNLIKELY((sbx << hshift[c] != bx) || (sby << vshift[c] != by))) {
    555         continue;
    556       }
    557 
    558       for (size_t pass = 0; JXL_UNLIKELY(pass < num_passes); pass++) {
    559         auto decode_ac_varblock =
    560             decoders[pass].UsesLZ77()
    561                 ? (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 1>
    562                                           : DecodeACVarBlock<ACType::k32, 1>)
    563                 : (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 0>
    564                                           : DecodeACVarBlock<ACType::k32, 0>);
    565         JXL_RETURN_IF_ERROR(decode_ac_varblock(
    566             ctx_offset[pass], log2_covered_blocks, row_nzeros[pass][c],
    567             row_nzeros_top[pass][c], nzeros_stride, c, sbx, sby, bx, acs,
    568             &coeff_orders[pass * coeff_order_size], readers[pass],
    569             &decoders[pass], context_map[pass], quant_dc_row, qf_row,
    570             *block_ctx_map, block[c], shift_for_pass[pass]));
    571       }
    572     }
    573     return true;
    574   }
    575 
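          // Reads the per-pass histogram selector and sets up one ANSSymbolReader
          // per pass; the remaining members are pointers into shared decoder
          // state (coefficient orders, context maps, quant fields).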
    576   Status Init(const FrameHeader& frame_header,
    577               BitReader* JXL_RESTRICT* JXL_RESTRICT readers, size_t num_passes,
    578               size_t group_idx, size_t histo_selector_bits, const Rect& rect,
    579               GroupDecCache* JXL_RESTRICT group_dec_cache,
    580               PassesDecoderState* dec_state, size_t first_pass) {
    581     for (size_t i = 0; i < 3; i++) {
    582       hshift[i] = frame_header.chroma_subsampling.HShift(i);
    583       vshift[i] = frame_header.chroma_subsampling.VShift(i);
    584     }
    585     this->coeff_order_size = dec_state->shared->coeff_order_size;
    586     this->coeff_orders =
    587         dec_state->shared->coeff_orders.data() + first_pass * coeff_order_size;
    588     this->context_map = dec_state->context_map.data() + first_pass;
    589     this->readers = readers;
    590     this->num_passes = num_passes;
    591     this->shift_for_pass = frame_header.passes.shift + first_pass;
    592     this->group_dec_cache = group_dec_cache;
    593     this->rect = rect;
    594     block_ctx_map = &dec_state->shared->block_ctx_map;
    595     qf = &dec_state->shared->raw_quant_field;
    596     quant_dc = &dec_state->shared->quant_dc;
    597 
    598     for (size_t pass = 0; pass < num_passes; pass++) {
    599       // Select which histogram set to use among those of the current pass.
    600       size_t cur_histogram = 0;
    601       if (histo_selector_bits != 0) {
    602         cur_histogram = readers[pass]->ReadBits(histo_selector_bits);
    603       }
    604       if (cur_histogram >= dec_state->shared->num_histograms) {
    605         return JXL_FAILURE("Invalid histogram selector");
    606       }
    607       ctx_offset[pass] = cur_histogram * block_ctx_map->NumACContexts();
    608 
    609       decoders[pass] =
    610           ANSSymbolReader(&dec_state->code[pass + first_pass], readers[pass]);
    611     }
    612     nzeros_stride = group_dec_cache->num_nzeroes[0].PixelsPerRow();
    613     for (size_t i = 0; i < num_passes; i++) {
    614       JXL_ASSERT(
    615           nzeros_stride ==
    616           static_cast<size_t>(group_dec_cache->num_nzeroes[i].PixelsPerRow()));
    617     }
    618     return true;
    619   }
    620 
    621   const uint32_t* shift_for_pass = nullptr;  // not owned
    622   const coeff_order_t* JXL_RESTRICT coeff_orders;
    623   size_t coeff_order_size;
    624   const std::vector<uint8_t>* JXL_RESTRICT context_map;
    625   ANSSymbolReader decoders[kMaxNumPasses];
    626   BitReader* JXL_RESTRICT* JXL_RESTRICT readers;
    627   size_t num_passes;
    628   size_t ctx_offset[kMaxNumPasses];
    629   size_t nzeros_stride;
    630   int32_t* JXL_RESTRICT row_nzeros[kMaxNumPasses][3];
    631   const int32_t* JXL_RESTRICT row_nzeros_top[kMaxNumPasses][3];
    632   GroupDecCache* JXL_RESTRICT group_dec_cache;
    633   const BlockCtxMap* block_ctx_map;
    634   const ImageI* qf;
    635   const ImageB* quant_dc;
    636   const int32_t* qf_row;
    637   const uint8_t* quant_dc_row;
    638   Rect rect;
    639   size_t hshift[3], vshift[3];
    640 };
    641 
    642 struct GetBlockFromEncoder : public GetBlock {
    643   void StartRow(size_t by) override {}
    644 
    645   Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size,
    646                    size_t log2_covered_blocks, ACPtr block[3],
    647                    ACType ac_type) override {
    648     JXL_DASSERT(ac_type == ACType::k32);
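            // Accumulate the stored coefficients of every pass, scaling each one
            // by its pass shift so that all passes contribute at the right
            // magnitude.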
    649     for (size_t c = 0; c < 3; c++) {
    650       // for each pass
    651       for (size_t i = 0; i < quantized_ac->size(); i++) {
    652         for (size_t k = 0; k < size; k++) {
    653           // TODO(veluca): SIMD.
    654           block[c].ptr32[k] +=
    655               rows[i][c][offset + k] * (1 << shift_for_pass[i]);
    656         }
    657       }
    658     }
    659     offset += size;
    660     return true;
    661   }
    662 
    663   GetBlockFromEncoder(const std::vector<std::unique_ptr<ACImage>>& ac,
    664                       size_t group_idx, const uint32_t* shift_for_pass)
    665       : quantized_ac(&ac), shift_for_pass(shift_for_pass) {
    666     // TODO(veluca): not supported with chroma subsampling.
    667     for (size_t i = 0; i < quantized_ac->size(); i++) {
    668       JXL_CHECK((*quantized_ac)[i]->Type() == ACType::k32);
    669       for (size_t c = 0; c < 3; c++) {
    670         rows[i][c] = (*quantized_ac)[i]->PlaneRow(c, group_idx, 0).ptr32;
    671       }
    672     }
    673   }
    674 
    675   const std::vector<std::unique_ptr<ACImage>>* JXL_RESTRICT quantized_ac;
    676   size_t offset = 0;
    677   const int32_t* JXL_RESTRICT rows[kMaxNumPasses][3];
    678   const uint32_t* shift_for_pass = nullptr;  // not owned
    679 };
    680 
    681 HWY_EXPORT(DecodeGroupImpl);
    682 
    683 }  // namespace
    684 
    685 Status DecodeGroup(const FrameHeader& frame_header,
    686                    BitReader* JXL_RESTRICT* JXL_RESTRICT readers,
    687                    size_t num_passes, size_t group_idx,
    688                    PassesDecoderState* JXL_RESTRICT dec_state,
    689                    GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread,
    690                    RenderPipelineInput& render_pipeline_input,
    691                    ImageBundle* JXL_RESTRICT decoded, size_t first_pass,
    692                    bool force_draw, bool dc_only, bool* should_run_pipeline) {
    693   DrawMode draw =
    694       (num_passes + first_pass == frame_header.passes.num_passes) || force_draw
    695           ? kDraw
    696           : kDontDraw;
    697 
    698   if (should_run_pipeline) {
    699     *should_run_pipeline = draw != kDontDraw;
    700   }
    701 
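          // Nothing but DC is available for this group (e.g. a DC frame or a
          // progressive preview): upsample the decoded DC by 8x directly into
          // the render pipeline buffers instead of decoding AC coefficients.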
    702   if (draw == kDraw && num_passes == 0 && first_pass == 0) {
    703     JXL_RETURN_IF_ERROR(group_dec_cache->InitDCBufferOnce());
    704     const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling;
    705     for (size_t c : {0, 1, 2}) {
    706       size_t hs = cs.HShift(c);
    707       size_t vs = cs.VShift(c);
    708       // group_dec_cache->dc_buffer is reused here as temporary storage.
    709       const Rect src_rect_precs =
    710           dec_state->shared->frame_dim.BlockGroupRect(group_idx);
    711       const Rect src_rect =
    712           Rect(src_rect_precs.x0() >> hs, src_rect_precs.y0() >> vs,
    713                src_rect_precs.xsize() >> hs, src_rect_precs.ysize() >> vs);
    714       const Rect copy_rect(kRenderPipelineXOffset, 2, src_rect.xsize(),
    715                            src_rect.ysize());
    716       CopyImageToWithPadding(src_rect, dec_state->shared->dc->Plane(c), 2,
    717                              copy_rect, &group_dec_cache->dc_buffer);
    718       // Mirrorpad. Interleaving left and right padding ensures that padding
    719       // works out correctly even for images with DC size of 1.
    720       for (size_t y = 0; y < src_rect.ysize() + 4; y++) {
    721         size_t xend = kRenderPipelineXOffset +
    722                       (dec_state->shared->dc->Plane(c).xsize() >> hs) -
    723                       src_rect.x0();
    724         for (size_t ix = 0; ix < 2; ix++) {
    725           if (src_rect.x0() == 0) {
    726             group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset - ix - 1] =
    727                 group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset + ix];
    728           }
    729           if (src_rect.x0() + src_rect.xsize() + 2 >=
    730               (dec_state->shared->dc->xsize() >> hs)) {
    731             group_dec_cache->dc_buffer.Row(y)[xend + ix] =
    732                 group_dec_cache->dc_buffer.Row(y)[xend - ix - 1];
    733           }
    734         }
    735       }
    736       Rect dst_rect = render_pipeline_input.GetBuffer(c).second;
    737       ImageF* upsampling_dst = render_pipeline_input.GetBuffer(c).first;
    738       JXL_ASSERT(dst_rect.IsInside(*upsampling_dst));
    739 
    740       RenderPipelineStage::RowInfo input_rows(1, std::vector<float*>(5));
    741       RenderPipelineStage::RowInfo output_rows(1, std::vector<float*>(8));
    742       for (size_t y = src_rect.y0(); y < src_rect.y0() + src_rect.ysize();
    743            y++) {
    744         for (ssize_t iy = 0; iy < 5; iy++) {
    745           input_rows[0][iy] = group_dec_cache->dc_buffer.Row(
    746               Mirror(static_cast<ssize_t>(y) + iy - 2,
    747                      dec_state->shared->dc->Plane(c).ysize() >> vs) +
    748               2 - src_rect.y0());
    749         }
    750         for (size_t iy = 0; iy < 8; iy++) {
    751           output_rows[0][iy] =
    752               dst_rect.Row(upsampling_dst, ((y - src_rect.y0()) << 3) + iy) -
    753               kRenderPipelineXOffset;
    754         }
    755         // Arguments set to 0/nullptr are not used.
    756         JXL_RETURN_IF_ERROR(dec_state->upsampler8x->ProcessRow(
    757             input_rows, output_rows,
    758             /*xextra=*/0, src_rect.xsize(), 0, 0, thread));
    759       }
    760     }
    761     return true;
    762   }
    763 
    764   size_t histo_selector_bits = 0;
    765   if (dc_only) {
    766     JXL_ASSERT(num_passes == 0);
    767   } else {
    768     JXL_ASSERT(dec_state->shared->num_histograms > 0);
    769     histo_selector_bits = CeilLog2Nonzero(dec_state->shared->num_histograms);
    770   }
    771 
    772   auto get_block = jxl::make_unique<GetBlockFromBitstream>();
    773   JXL_RETURN_IF_ERROR(get_block->Init(
    774       frame_header, readers, num_passes, group_idx, histo_selector_bits,
    775       dec_state->shared->frame_dim.BlockGroupRect(group_idx), group_dec_cache,
    776       dec_state, first_pass));
    777 
    778   JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
    779       frame_header, get_block.get(), group_dec_cache, dec_state, thread,
    780       group_idx, render_pipeline_input, decoded, draw));
    781 
    782   for (size_t pass = 0; pass < num_passes; pass++) {
    783     if (!get_block->decoders[pass].CheckANSFinalState()) {
    784       return JXL_FAILURE("ANS checksum failure.");
    785     }
    786   }
    787   return true;
    788 }
    789 
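        // Variant of DecodeGroup used for encoder round-trips: reconstructs a
        // group from coefficient images already held in memory
        // (GetBlockFromEncoder) instead of parsing a bitstream, so the encoder
        // can reproduce the decoder's output.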
    790 Status DecodeGroupForRoundtrip(const FrameHeader& frame_header,
    791                                const std::vector<std::unique_ptr<ACImage>>& ac,
    792                                size_t group_idx,
    793                                PassesDecoderState* JXL_RESTRICT dec_state,
    794                                GroupDecCache* JXL_RESTRICT group_dec_cache,
    795                                size_t thread,
    796                                RenderPipelineInput& render_pipeline_input,
    797                                ImageBundle* JXL_RESTRICT decoded,
    798                                AuxOut* aux_out) {
    799   GetBlockFromEncoder get_block(ac, group_idx, frame_header.passes.shift);
    800   JXL_RETURN_IF_ERROR(group_dec_cache->InitOnce(
    801       /*num_passes=*/0,
    802       /*used_acs=*/(1u << AcStrategy::kNumValidStrategies) - 1));
    803 
    804   return HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)(
    805       frame_header, &get_block, group_dec_cache, dec_state, thread, group_idx,
    806       render_pipeline_input, decoded, kDraw);
    807 }
    808 
    809 }  // namespace jxl
    810 #endif  // HWY_ONCE