libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

dec_frame.cc (34633B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/dec_frame.h"
      7 
      8 #include <jxl/decode.h>
      9 #include <stddef.h>
     10 #include <stdint.h>
     11 
     12 #include <algorithm>
     13 #include <atomic>
     14 #include <cstdlib>
     15 #include <memory>
     16 #include <utility>
     17 #include <vector>
     18 
     19 #include "lib/jxl/ac_context.h"
     20 #include "lib/jxl/ac_strategy.h"
     21 #include "lib/jxl/base/bits.h"
     22 #include "lib/jxl/base/common.h"
     23 #include "lib/jxl/base/compiler_specific.h"
     24 #include "lib/jxl/base/data_parallel.h"
     25 #include "lib/jxl/base/printf_macros.h"
     26 #include "lib/jxl/base/status.h"
     27 #include "lib/jxl/chroma_from_luma.h"
     28 #include "lib/jxl/coeff_order.h"
     29 #include "lib/jxl/coeff_order_fwd.h"
     30 #include "lib/jxl/common.h"  // kMaxNumPasses
     31 #include "lib/jxl/compressed_dc.h"
     32 #include "lib/jxl/dct_util.h"
     33 #include "lib/jxl/dec_ans.h"
     34 #include "lib/jxl/dec_bit_reader.h"
     35 #include "lib/jxl/dec_cache.h"
     36 #include "lib/jxl/dec_group.h"
     37 #include "lib/jxl/dec_modular.h"
     38 #include "lib/jxl/dec_noise.h"
     39 #include "lib/jxl/dec_patch_dictionary.h"
     40 #include "lib/jxl/entropy_coder.h"
     41 #include "lib/jxl/epf.h"
     42 #include "lib/jxl/fields.h"
     43 #include "lib/jxl/frame_dimensions.h"
     44 #include "lib/jxl/frame_header.h"
     45 #include "lib/jxl/image.h"
     46 #include "lib/jxl/image_bundle.h"
     47 #include "lib/jxl/image_metadata.h"
     48 #include "lib/jxl/image_ops.h"
     49 #include "lib/jxl/jpeg/jpeg_data.h"
     50 #include "lib/jxl/loop_filter.h"
     51 #include "lib/jxl/passes_state.h"
     52 #include "lib/jxl/quant_weights.h"
     53 #include "lib/jxl/quantizer.h"
     54 #include "lib/jxl/render_pipeline/render_pipeline.h"
     55 #include "lib/jxl/splines.h"
     56 #include "lib/jxl/toc.h"
     57 
     58 namespace jxl {
     59 
     60 namespace {
     61 Status DecodeGlobalDCInfo(BitReader* reader, bool is_jpeg,
     62                           PassesDecoderState* state, ThreadPool* pool) {
     63   JXL_RETURN_IF_ERROR(state->shared_storage.quantizer.Decode(reader));
     64 
     65   JXL_RETURN_IF_ERROR(
     66       DecodeBlockCtxMap(reader, &state->shared_storage.block_ctx_map));
     67 
     68   JXL_RETURN_IF_ERROR(state->shared_storage.cmap.DecodeDC(reader));
     69 
     70   // Pre-compute info for decoding a group.
     71   if (is_jpeg) {
     72     state->shared_storage.quantizer.ClearDCMul();  // Don't dequant DC
     73   }
     74 
     75   state->shared_storage.ac_strategy.FillInvalid();
     76   return true;
     77 }
     78 }  // namespace
     79 
     80 Status DecodeFrame(PassesDecoderState* dec_state, ThreadPool* JXL_RESTRICT pool,
     81                    const uint8_t* next_in, size_t avail_in,
     82                    FrameHeader* frame_header, ImageBundle* decoded,
     83                    const CodecMetadata& metadata,
     84                    bool use_slow_rendering_pipeline) {
     85   FrameDecoder frame_decoder(dec_state, metadata, pool,
     86                              use_slow_rendering_pipeline);
     87 
     88   BitReader reader(Bytes(next_in, avail_in));
     89   JXL_RETURN_IF_ERROR(frame_decoder.InitFrame(&reader, decoded,
     90                                               /*is_preview=*/false));
     91   JXL_RETURN_IF_ERROR(frame_decoder.InitFrameOutput());
     92   if (frame_header) {
     93     *frame_header = frame_decoder.GetFrameHeader();
     94   }
     95 
     96   JXL_RETURN_IF_ERROR(reader.AllReadsWithinBounds());
     97   size_t header_bytes = reader.TotalBitsConsumed() / kBitsPerByte;
     98   JXL_RETURN_IF_ERROR(reader.Close());
     99 
    100   size_t processed_bytes = header_bytes;
    101   Status close_ok = true;
    102   std::vector<std::unique_ptr<BitReader>> section_readers;
    103   {
    104     std::vector<std::unique_ptr<BitReaderScopedCloser>> section_closers;
    105     std::vector<FrameDecoder::SectionInfo> section_info;
    106     std::vector<FrameDecoder::SectionStatus> section_status;
    107     size_t pos = header_bytes;
    108     size_t index = 0;
    109     for (auto toc_entry : frame_decoder.Toc()) {
    110       JXL_RETURN_IF_ERROR(pos + toc_entry.size <= avail_in);
    111       auto br = make_unique<BitReader>(Bytes(next_in + pos, toc_entry.size));
    112       section_info.emplace_back(
    113           FrameDecoder::SectionInfo{br.get(), toc_entry.id, index++});
    114       section_closers.emplace_back(
    115           make_unique<BitReaderScopedCloser>(br.get(), &close_ok));
    116       section_readers.emplace_back(std::move(br));
    117       pos += toc_entry.size;
    118     }
    119     section_status.resize(section_info.size());
    120     JXL_RETURN_IF_ERROR(frame_decoder.ProcessSections(
    121         section_info.data(), section_info.size(), section_status.data()));
    122     for (size_t i = 0; i < section_status.size(); i++) {
    123       JXL_RETURN_IF_ERROR(section_status[i] == FrameDecoder::kDone);
    124       processed_bytes += frame_decoder.Toc()[i].size;
    125     }
    126   }
    127   JXL_RETURN_IF_ERROR(close_ok);
    128   JXL_RETURN_IF_ERROR(frame_decoder.FinalizeFrame());
    129   decoded->SetDecodedBytes(processed_bytes);
    130   return true;
    131 }
    132 
    133 Status FrameDecoder::InitFrame(BitReader* JXL_RESTRICT br, ImageBundle* decoded,
    134                                bool is_preview) {
    135   decoded_ = decoded;
    136   JXL_ASSERT(is_finalized_);
    137 
    138   // Reset the dequantization matrices to their default values.
    139   dec_state_->shared_storage.matrices = DequantMatrices();
    140 
    141   frame_header_.nonserialized_is_preview = is_preview;
    142   JXL_ASSERT(frame_header_.nonserialized_metadata != nullptr);
    143   JXL_RETURN_IF_ERROR(ReadFrameHeader(br, &frame_header_));
    144   frame_dim_ = frame_header_.ToFrameDimensions();
    145   JXL_DEBUG_V(2, "FrameHeader: %s", frame_header_.DebugString().c_str());
    146 
    147   const size_t num_passes = frame_header_.passes.num_passes;
    148   const size_t num_groups = frame_dim_.num_groups;
    149 
    150   // If the previous frame was not a kRegularFrame, `decoded` may have different
    151   // dimensions; must reset to avoid errors.
    152   decoded->RemoveColor();
    153   decoded->ClearExtraChannels();
    154 
    155   decoded->duration = frame_header_.animation_frame.duration;
    156 
    157   if (!frame_header_.nonserialized_is_preview &&
    158       (frame_header_.is_last || frame_header_.animation_frame.duration > 0) &&
    159       (frame_header_.frame_type == kRegularFrame ||
    160        frame_header_.frame_type == kSkipProgressive)) {
    161     ++dec_state_->visible_frame_index;
    162     dec_state_->nonvisible_frame_index = 0;
    163   } else {
    164     ++dec_state_->nonvisible_frame_index;
    165   }
    166 
    167   // Read TOC.
    168   const size_t toc_entries =
    169       NumTocEntries(num_groups, frame_dim_.num_dc_groups, num_passes);
    170   std::vector<uint32_t> sizes;
    171   std::vector<coeff_order_t> permutation;
    172   JXL_RETURN_IF_ERROR(ReadToc(toc_entries, br, &sizes, &permutation));
    173   bool have_permutation = !permutation.empty();
    174   toc_.resize(toc_entries);
    175   section_sizes_sum_ = 0;
    176   for (size_t i = 0; i < toc_entries; ++i) {
    177     toc_[i].size = sizes[i];
    178     size_t index = have_permutation ? permutation[i] : i;
    179     toc_[index].id = i;
    180     if (section_sizes_sum_ + toc_[i].size < section_sizes_sum_) {
    181       return JXL_FAILURE("group offset overflow");
    182     }
    183     section_sizes_sum_ += toc_[i].size;
    184   }
    185 
    186   if (JXL_DEBUG_V_LEVEL >= 3) {
    187     for (size_t i = 0; i < toc_entries; ++i) {
    188       JXL_DEBUG_V(3, "TOC entry %" PRIuS " size %" PRIuS " id %" PRIuS "", i,
    189                   toc_[i].size, toc_[i].id);
    190     }
    191   }
    192 
    193   JXL_DASSERT((br->TotalBitsConsumed() % kBitsPerByte) == 0);
    194   const size_t group_codes_begin = br->TotalBitsConsumed() / kBitsPerByte;
    195   JXL_DASSERT(!toc_.empty());
    196 
    197   // Overflow check.
    198   if (group_codes_begin + section_sizes_sum_ < group_codes_begin) {
    199     return JXL_FAILURE("Invalid group codes");
    200   }
    201 
    202   if (!frame_header_.chroma_subsampling.Is444() &&
    203       !(frame_header_.flags & FrameHeader::kSkipAdaptiveDCSmoothing) &&
    204       frame_header_.encoding == FrameEncoding::kVarDCT) {
    205     return JXL_FAILURE(
    206         "Non-444 chroma subsampling is not allowed when adaptive DC "
    207         "smoothing is enabled");
    208   }
    209   return true;
    210 }
    211 
    212 Status FrameDecoder::InitFrameOutput() {
    213   JXL_RETURN_IF_ERROR(
    214       InitializePassesSharedState(frame_header_, &dec_state_->shared_storage));
    215   JXL_RETURN_IF_ERROR(dec_state_->Init(frame_header_));
    216   modular_frame_decoder_.Init(frame_dim_);
    217 
    218   if (decoded_->IsJPEG()) {
    219     if (frame_header_.encoding == FrameEncoding::kModular) {
    220       return JXL_FAILURE("Cannot output JPEG from Modular");
    221     }
    222     jpeg::JPEGData* jpeg_data = decoded_->jpeg_data.get();
    223     size_t num_components = jpeg_data->components.size();
    224     if (num_components != 1 && num_components != 3) {
    225       return JXL_FAILURE("Invalid number of components");
    226     }
    227     if (frame_header_.nonserialized_metadata->m.xyb_encoded) {
    228       return JXL_FAILURE("Cannot decode to JPEG an XYB image");
    229     }
    230     auto jpeg_c_map = JpegOrder(ColorTransform::kYCbCr, num_components == 1);
    231     decoded_->jpeg_data->width = frame_dim_.xsize;
    232     decoded_->jpeg_data->height = frame_dim_.ysize;
    233     for (size_t c = 0; c < num_components; c++) {
    234       auto& component = jpeg_data->components[jpeg_c_map[c]];
    235       component.width_in_blocks =
    236           frame_dim_.xsize_blocks >> frame_header_.chroma_subsampling.HShift(c);
    237       component.height_in_blocks =
    238           frame_dim_.ysize_blocks >> frame_header_.chroma_subsampling.VShift(c);
    239       component.h_samp_factor =
    240           1 << frame_header_.chroma_subsampling.RawHShift(c);
    241       component.v_samp_factor =
    242           1 << frame_header_.chroma_subsampling.RawVShift(c);
    243       component.coeffs.resize(component.width_in_blocks *
    244                               component.height_in_blocks * jxl::kDCTBlockSize);
    245     }
    246   }
    247 
    248   // Clear the state.
    249   decoded_dc_global_ = false;
    250   decoded_ac_global_ = false;
    251   is_finalized_ = false;
    252   finalized_dc_ = false;
    253   num_sections_done_ = 0;
    254   decoded_dc_groups_.clear();
    255   decoded_dc_groups_.resize(frame_dim_.num_dc_groups);
    256   decoded_passes_per_ac_group_.clear();
    257   decoded_passes_per_ac_group_.resize(frame_dim_.num_groups, 0);
    258   processed_section_.clear();
    259   processed_section_.resize(toc_.size());
    260   allocated_ = false;
    261   return true;
    262 }
    263 
    264 Status FrameDecoder::ProcessDCGlobal(BitReader* br) {
    265   PassesSharedState& shared = dec_state_->shared_storage;
    266   if (frame_header_.flags & FrameHeader::kPatches) {
    267     bool uses_extra_channels = false;
    268     JXL_RETURN_IF_ERROR(shared.image_features.patches.Decode(
    269         br, frame_dim_.xsize_padded, frame_dim_.ysize_padded,
    270         &uses_extra_channels));
    271     if (uses_extra_channels && frame_header_.upsampling != 1) {
    272       for (size_t ecups : frame_header_.extra_channel_upsampling) {
    273         if (ecups != frame_header_.upsampling) {
    274           return JXL_FAILURE(
    275               "Cannot use extra channels in patches if color channels are "
    276               "subsampled differently from extra channels");
    277         }
    278       }
    279     }
    280   } else {
    281     shared.image_features.patches.Clear();
    282   }
    283   shared.image_features.splines.Clear();
    284   if (frame_header_.flags & FrameHeader::kSplines) {
    285     JXL_RETURN_IF_ERROR(shared.image_features.splines.Decode(
    286         br, frame_dim_.xsize * frame_dim_.ysize));
    287   }
    288   if (frame_header_.flags & FrameHeader::kNoise) {
    289     JXL_RETURN_IF_ERROR(DecodeNoise(br, &shared.image_features.noise_params));
    290   }
    291   JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.DecodeDC(br));
    292 
    293   if (frame_header_.encoding == FrameEncoding::kVarDCT) {
    294     JXL_RETURN_IF_ERROR(
    295         jxl::DecodeGlobalDCInfo(br, decoded_->IsJPEG(), dec_state_, pool_));
    296   }
    297   // Splines' draw cache uses the color correlation map.
    298   if (frame_header_.flags & FrameHeader::kSplines) {
    299     JXL_RETURN_IF_ERROR(shared.image_features.splines.InitializeDrawCache(
    300         frame_dim_.xsize_upsampled, frame_dim_.ysize_upsampled,
    301         dec_state_->shared->cmap));
    302   }
    303   Status dec_status = modular_frame_decoder_.DecodeGlobalInfo(
    304       br, frame_header_, /*allow_truncated_group=*/false);
    305   if (dec_status.IsFatalError()) return dec_status;
    306   if (dec_status) {
    307     decoded_dc_global_ = true;
    308   }
    309   return dec_status;
    310 }
    311 
    312 Status FrameDecoder::ProcessDCGroup(size_t dc_group_id, BitReader* br) {
    313   const size_t gx = dc_group_id % frame_dim_.xsize_dc_groups;
    314   const size_t gy = dc_group_id / frame_dim_.xsize_dc_groups;
    315   const LoopFilter& lf = frame_header_.loop_filter;
    316   if (frame_header_.encoding == FrameEncoding::kVarDCT &&
    317       !(frame_header_.flags & FrameHeader::kUseDcFrame)) {
    318     JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeVarDCTDC(
    319         frame_header_, dc_group_id, br, dec_state_));
    320   }
    321   const Rect mrect(gx * frame_dim_.dc_group_dim, gy * frame_dim_.dc_group_dim,
    322                    frame_dim_.dc_group_dim, frame_dim_.dc_group_dim);
    323   JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup(
    324       frame_header_, mrect, br, 3, 1000,
    325       ModularStreamId::ModularDC(dc_group_id),
    326       /*zerofill=*/false, nullptr, nullptr,
    327       /*allow_truncated=*/false));
    328   if (frame_header_.encoding == FrameEncoding::kVarDCT) {
    329     JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeAcMetadata(
    330         frame_header_, dc_group_id, br, dec_state_));
    331   } else if (lf.epf_iters > 0) {
    332     FillImage(kInvSigmaNum / lf.epf_sigma_for_modular, &dec_state_->sigma);
    333   }
    334   decoded_dc_groups_[dc_group_id] = JXL_TRUE;
    335   return true;
    336 }
    337 
    338 Status FrameDecoder::FinalizeDC() {
    339   // Do Adaptive DC smoothing if enabled. This *must* happen between all the
    340   // ProcessDCGroup and ProcessACGroup.
    341   if (frame_header_.encoding == FrameEncoding::kVarDCT &&
    342       !(frame_header_.flags & FrameHeader::kSkipAdaptiveDCSmoothing) &&
    343       !(frame_header_.flags & FrameHeader::kUseDcFrame)) {
    344     JXL_RETURN_IF_ERROR(
    345         AdaptiveDCSmoothing(dec_state_->shared->quantizer.MulDC(),
    346                             &dec_state_->shared_storage.dc_storage, pool_));
    347   }
    348 
    349   finalized_dc_ = true;
    350   return true;
    351 }
    352 
    353 Status FrameDecoder::AllocateOutput() {
    354   if (allocated_) return true;
    355   modular_frame_decoder_.MaybeDropFullImage();
    356   decoded_->origin = frame_header_.frame_origin;
    357   JXL_RETURN_IF_ERROR(
    358       dec_state_->InitForAC(frame_header_.passes.num_passes, nullptr));
    359   allocated_ = true;
    360   return true;
    361 }
    362 
    363 Status FrameDecoder::ProcessACGlobal(BitReader* br) {
    364   JXL_CHECK(finalized_dc_);
    365 
    366   // Decode AC group.
    367   if (frame_header_.encoding == FrameEncoding::kVarDCT) {
    368     JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.Decode(
    369         br, &modular_frame_decoder_));
    370     JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.EnsureComputed(
    371         dec_state_->used_acs));
    372 
    373     size_t num_histo_bits =
    374         CeilLog2Nonzero(dec_state_->shared->frame_dim.num_groups);
    375     dec_state_->shared_storage.num_histograms =
    376         1 + br->ReadBits(num_histo_bits);
    377 
    378     JXL_DEBUG_V(3,
    379                 "Processing AC global with %d passes and %" PRIuS
    380                 " sets of histograms",
    381                 frame_header_.passes.num_passes,
    382                 dec_state_->shared_storage.num_histograms);
    383 
    384     dec_state_->code.resize(kMaxNumPasses);
    385     dec_state_->context_map.resize(kMaxNumPasses);
    386     // Read coefficient orders and histograms.
    387     size_t max_num_bits_ac = 0;
    388     for (size_t i = 0; i < frame_header_.passes.num_passes; i++) {
    389       uint16_t used_orders = U32Coder::Read(kOrderEnc, br);
    390       JXL_RETURN_IF_ERROR(DecodeCoeffOrders(
    391           used_orders, dec_state_->used_acs,
    392           &dec_state_->shared_storage
    393                .coeff_orders[i * dec_state_->shared_storage.coeff_order_size],
    394           br));
    395       size_t num_contexts =
    396           dec_state_->shared->num_histograms *
    397           dec_state_->shared_storage.block_ctx_map.NumACContexts();
    398       JXL_RETURN_IF_ERROR(DecodeHistograms(
    399           br, num_contexts, &dec_state_->code[i], &dec_state_->context_map[i]));
    400       // Add extra values to enable the cheat in hot loop of DecodeACVarBlock.
    401       dec_state_->context_map[i].resize(
    402           num_contexts + kZeroDensityContextLimit - kZeroDensityContextCount);
    403       max_num_bits_ac =
    404           std::max(max_num_bits_ac, dec_state_->code[i].max_num_bits);
    405     }
    406     max_num_bits_ac += CeilLog2Nonzero(frame_header_.passes.num_passes);
    407     // 16-bit buffer for decoding to JPEG are not implemented.
    408     // TODO(veluca): figure out the exact limit - 16 should still work with
    409     // 16-bit buffers, but we are excluding it for safety.
    410     bool use_16_bit = max_num_bits_ac < 16 && !decoded_->IsJPEG();
    411     bool store = frame_header_.passes.num_passes > 1;
    412     size_t xs = store ? kGroupDim * kGroupDim : 0;
    413     size_t ys = store ? frame_dim_.num_groups : 0;
    414     if (use_16_bit) {
    415       JXL_ASSIGN_OR_RETURN(dec_state_->coefficients,
    416                            ACImageT<int16_t>::Make(xs, ys));
    417     } else {
    418       JXL_ASSIGN_OR_RETURN(dec_state_->coefficients,
    419                            ACImageT<int32_t>::Make(xs, ys));
    420     }
    421     if (store) {
    422       dec_state_->coefficients->ZeroFill();
    423     }
    424   }
    425 
    426   // Set JPEG decoding data.
    427   if (decoded_->IsJPEG()) {
    428     decoded_->color_transform = frame_header_.color_transform;
    429     decoded_->chroma_subsampling = frame_header_.chroma_subsampling;
    430     const std::vector<QuantEncoding>& qe =
    431         dec_state_->shared_storage.matrices.encodings();
    432     if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
    433         std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
    434       return JXL_FAILURE(
    435           "Quantization table is not a JPEG quantization table.");
    436     }
    437     jpeg::JPEGData* jpeg_data = decoded_->jpeg_data.get();
    438     size_t num_components = jpeg_data->components.size();
    439     bool is_gray = (num_components == 1);
    440     auto jpeg_c_map = JpegOrder(frame_header_.color_transform, is_gray);
    441     size_t qt_set = 0;
    442     for (size_t c = 0; c < num_components; c++) {
    443       // TODO(eustas): why 1-st quant table for gray?
    444       size_t quant_c = is_gray ? 1 : c;
    445       size_t qpos = jpeg_data->components[jpeg_c_map[c]].quant_idx;
    446       JXL_CHECK(qpos != jpeg_data->quant.size());
    447       qt_set |= 1 << qpos;
    448       for (size_t x = 0; x < 8; x++) {
    449         for (size_t y = 0; y < 8; y++) {
    450           jpeg_data->quant[qpos].values[x * 8 + y] =
    451               (*qe[0].qraw.qtable)[quant_c * 64 + y * 8 + x];
    452         }
    453       }
    454     }
    455     for (size_t i = 0; i < jpeg_data->quant.size(); i++) {
    456       if (qt_set & (1 << i)) continue;
    457       if (i == 0) return JXL_FAILURE("First quant table unused.");
    458       // Unused quant table is set to copy of previous quant table
    459       for (size_t j = 0; j < 64; j++) {
    460         jpeg_data->quant[i].values[j] = jpeg_data->quant[i - 1].values[j];
    461       }
    462     }
    463   }
    464   decoded_ac_global_ = true;
    465   return true;
    466 }
    467 
    468 Status FrameDecoder::ProcessACGroup(size_t ac_group_id,
    469                                     BitReader* JXL_RESTRICT* br,
    470                                     size_t num_passes, size_t thread,
    471                                     bool force_draw, bool dc_only) {
    472   size_t group_dim = frame_dim_.group_dim;
    473   const size_t gx = ac_group_id % frame_dim_.xsize_groups;
    474   const size_t gy = ac_group_id / frame_dim_.xsize_groups;
    475   const size_t x = gx * group_dim;
    476   const size_t y = gy * group_dim;
    477   JXL_DEBUG_V(3,
    478               "Processing AC group %" PRIuS "(%" PRIuS ",%" PRIuS
    479               ") group_dim: %" PRIuS " decoded passes: %u new passes: %" PRIuS,
    480               ac_group_id, gx, gy, group_dim,
    481               decoded_passes_per_ac_group_[ac_group_id], num_passes);
    482 
    483   RenderPipelineInput render_pipeline_input =
    484       dec_state_->render_pipeline->GetInputBuffers(ac_group_id, thread);
    485 
    486   bool should_run_pipeline = true;
    487 
    488   if (frame_header_.encoding == FrameEncoding::kVarDCT) {
    489     JXL_RETURN_IF_ERROR(group_dec_caches_[thread].InitOnce(
    490         frame_header_.passes.num_passes, dec_state_->used_acs));
    491     JXL_RETURN_IF_ERROR(DecodeGroup(frame_header_, br, num_passes, ac_group_id,
    492                                     dec_state_, &group_dec_caches_[thread],
    493                                     thread, render_pipeline_input, decoded_,
    494                                     decoded_passes_per_ac_group_[ac_group_id],
    495                                     force_draw, dc_only, &should_run_pipeline));
    496   }
    497 
    498   // don't limit to image dimensions here (is done in DecodeGroup)
    499   const Rect mrect(x, y, group_dim, group_dim);
    500   bool modular_ready = false;
    501   size_t pass0 = decoded_passes_per_ac_group_[ac_group_id];
    502   size_t pass1 =
    503       force_draw ? frame_header_.passes.num_passes : pass0 + num_passes;
    504   for (size_t i = pass0; i < pass1; ++i) {
    505     int minShift;
    506     int maxShift;
    507     frame_header_.passes.GetDownsamplingBracket(i, minShift, maxShift);
    508     bool modular_pass_ready = true;
    509     JXL_DEBUG_V(2, "Decoding modular in group %d pass %d",
    510                 static_cast<int>(ac_group_id), static_cast<int>(i));
    511     if (i < pass0 + num_passes) {
    512       JXL_DEBUG_V(2, "Bit reader position: %" PRIuS " / %" PRIuS,
    513                   br[i - pass0]->TotalBitsConsumed(),
    514                   br[i - pass0]->TotalBytes() * kBitsPerByte);
    515       JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup(
    516           frame_header_, mrect, br[i - pass0], minShift, maxShift,
    517           ModularStreamId::ModularAC(ac_group_id, i),
    518           /*zerofill=*/false, dec_state_, &render_pipeline_input,
    519           /*allow_truncated=*/false, &modular_pass_ready));
    520     } else {
    521       JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup(
    522           frame_header_, mrect, nullptr, minShift, maxShift,
    523           ModularStreamId::ModularAC(ac_group_id, i), /*zerofill=*/true,
    524           dec_state_, &render_pipeline_input,
    525           /*allow_truncated=*/false, &modular_pass_ready));
    526     }
    527     if (modular_pass_ready) modular_ready = true;
    528   }
    529   decoded_passes_per_ac_group_[ac_group_id] += num_passes;
    530 
    531   if ((frame_header_.flags & FrameHeader::kNoise) != 0) {
    532     size_t noise_c_start =
    533         3 + frame_header_.nonserialized_metadata->m.num_extra_channels;
    534     // When the color channels are downsampled, we need to generate more noise
    535     // input for the current group than just the group dimensions.
    536     std::pair<ImageF*, Rect> rects[3];
    537     for (size_t iy = 0; iy < frame_header_.upsampling; iy++) {
    538       for (size_t ix = 0; ix < frame_header_.upsampling; ix++) {
    539         for (size_t c = 0; c < 3; c++) {
    540           auto r = render_pipeline_input.GetBuffer(noise_c_start + c);
    541           rects[c].first = r.first;
    542           size_t x1 = r.second.x0() + r.second.xsize();
    543           size_t y1 = r.second.y0() + r.second.ysize();
    544           rects[c].second = Rect(r.second.x0() + ix * group_dim,
    545                                  r.second.y0() + iy * group_dim, group_dim,
    546                                  group_dim, x1, y1);
    547         }
    548         Random3Planes(dec_state_->visible_frame_index,
    549                       dec_state_->nonvisible_frame_index,
    550                       (gx * frame_header_.upsampling + ix) * group_dim,
    551                       (gy * frame_header_.upsampling + iy) * group_dim,
    552                       rects[0], rects[1], rects[2]);
    553       }
    554     }
    555   }
    556 
    557   if (!modular_frame_decoder_.UsesFullImage() && !decoded_->IsJPEG()) {
    558     if (should_run_pipeline && modular_ready) {
    559       JXL_RETURN_IF_ERROR(render_pipeline_input.Done());
    560     } else if (force_draw) {
    561       return JXL_FAILURE("Modular group decoding failed.");
    562     }
    563   }
    564   return true;
    565 }
    566 
    567 void FrameDecoder::MarkSections(const SectionInfo* sections, size_t num,
    568                                 const SectionStatus* section_status) {
    569   num_sections_done_ += num;
    570   for (size_t i = 0; i < num; i++) {
    571     if (section_status[i] != SectionStatus::kDone) {
    572       processed_section_[sections[i].id] = JXL_FALSE;
    573       num_sections_done_--;
    574     }
    575   }
    576 }
    577 
    578 Status FrameDecoder::ProcessSections(const SectionInfo* sections, size_t num,
    579                                      SectionStatus* section_status) {
    580   if (num == 0) return true;  // Nothing to process
    581   std::fill(section_status, section_status + num, SectionStatus::kSkipped);
    582   size_t dc_global_sec = num;
    583   size_t ac_global_sec = num;
    584   std::vector<size_t> dc_group_sec(frame_dim_.num_dc_groups, num);
    585   std::vector<std::vector<size_t>> ac_group_sec(
    586       frame_dim_.num_groups,
    587       std::vector<size_t>(frame_header_.passes.num_passes, num));
    588   // This keeps track of the number of ac passes we want to process during this
    589   // call of ProcessSections.
    590   std::vector<size_t> desired_num_ac_passes(frame_dim_.num_groups);
    591   bool single_section =
    592       frame_dim_.num_groups == 1 && frame_header_.passes.num_passes == 1;
    593   if (single_section) {
    594     JXL_ASSERT(num == 1);
    595     JXL_ASSERT(sections[0].id == 0);
    596     if (processed_section_[0] == JXL_FALSE) {
    597       processed_section_[0] = JXL_TRUE;
    598       ac_group_sec[0].resize(1);
    599       dc_global_sec = ac_global_sec = dc_group_sec[0] = ac_group_sec[0][0] = 0;
    600       desired_num_ac_passes[0] = 1;
    601     } else {
    602       section_status[0] = SectionStatus::kDuplicate;
    603     }
    604   } else {
    605     size_t ac_global_index = frame_dim_.num_dc_groups + 1;
    606     for (size_t i = 0; i < num; i++) {
    607       JXL_ASSERT(sections[i].id < processed_section_.size());
    608       if (processed_section_[sections[i].id]) {
    609         section_status[i] = SectionStatus::kDuplicate;
    610         continue;
    611       }
    612       if (sections[i].id == 0) {
    613         dc_global_sec = i;
    614       } else if (sections[i].id < ac_global_index) {
    615         dc_group_sec[sections[i].id - 1] = i;
    616       } else if (sections[i].id == ac_global_index) {
    617         ac_global_sec = i;
    618       } else {
    619         size_t ac_idx = sections[i].id - ac_global_index - 1;
    620         size_t acg = ac_idx % frame_dim_.num_groups;
    621         size_t acp = ac_idx / frame_dim_.num_groups;
    622         if (acp >= frame_header_.passes.num_passes) {
    623           return JXL_FAILURE("Invalid section ID");
    624         }
    625         ac_group_sec[acg][acp] = i;
    626       }
    627       processed_section_[sections[i].id] = JXL_TRUE;
    628     }
    629     // Count number of new passes per group.
    630     for (size_t g = 0; g < ac_group_sec.size(); g++) {
    631       size_t j = 0;
    632       for (; j + decoded_passes_per_ac_group_[g] <
    633              frame_header_.passes.num_passes;
    634            j++) {
    635         if (ac_group_sec[g][j + decoded_passes_per_ac_group_[g]] == num) {
    636           break;
    637         }
    638       }
    639       desired_num_ac_passes[g] = j;
    640     }
    641   }
    642   if (dc_global_sec != num) {
    643     Status dc_global_status = ProcessDCGlobal(sections[dc_global_sec].br);
    644     if (dc_global_status.IsFatalError()) return dc_global_status;
    645     if (dc_global_status) {
    646       section_status[dc_global_sec] = SectionStatus::kDone;
    647     } else {
    648       section_status[dc_global_sec] = SectionStatus::kPartial;
    649     }
    650   }
    651 
    652   std::atomic<bool> has_error{false};
    653   if (decoded_dc_global_) {
    654     JXL_RETURN_IF_ERROR(RunOnPool(
    655         pool_, 0, dc_group_sec.size(), ThreadPool::NoInit,
    656         [this, &dc_group_sec, &num, &sections, &section_status, &has_error](
    657             size_t i, size_t thread) {
    658           if (has_error) return;
    659           if (dc_group_sec[i] != num) {
    660             if (!ProcessDCGroup(i, sections[dc_group_sec[i]].br)) {
    661               has_error = true;
    662               return;
    663             } else {
    664               section_status[dc_group_sec[i]] = SectionStatus::kDone;
    665             }
    666           }
    667         },
    668         "DecodeDCGroup"));
    669   }
    670   if (has_error) return JXL_FAILURE("Error in DC group");
    671 
    672   if (!HasDcGroupToDecode() && !finalized_dc_) {
    673     PassesDecoderState::PipelineOptions pipeline_options;
    674     pipeline_options.use_slow_render_pipeline = use_slow_rendering_pipeline_;
    675     pipeline_options.coalescing = coalescing_;
    676     pipeline_options.render_spotcolors = render_spotcolors_;
    677     pipeline_options.render_noise = true;
    678     JXL_RETURN_IF_ERROR(
    679         dec_state_->PreparePipeline(frame_header_, decoded_, pipeline_options));
    680     JXL_RETURN_IF_ERROR(FinalizeDC());
    681     JXL_RETURN_IF_ERROR(AllocateOutput());
    682     if (progressive_detail_ >= JxlProgressiveDetail::kDC) {
    683       MarkSections(sections, num, section_status);
    684       return true;
    685     }
    686   }
    687 
    688   if (finalized_dc_ && ac_global_sec != num && !decoded_ac_global_) {
    689     JXL_RETURN_IF_ERROR(ProcessACGlobal(sections[ac_global_sec].br));
    690     section_status[ac_global_sec] = SectionStatus::kDone;
    691   }
    692 
    693   if (progressive_detail_ >= JxlProgressiveDetail::kLastPasses) {
    694     // Mark that we only want the next progression pass.
    695     size_t target_complete_passes = NextNumPassesToPause();
    696     for (size_t i = 0; i < ac_group_sec.size(); i++) {
    697       desired_num_ac_passes[i] =
    698           std::min(desired_num_ac_passes[i],
    699                    target_complete_passes - decoded_passes_per_ac_group_[i]);
    700     }
    701   }
    702 
    703   if (decoded_ac_global_) {
    704     // Mark all the AC groups that we received as not complete yet.
    705     for (size_t i = 0; i < ac_group_sec.size(); i++) {
    706       if (desired_num_ac_passes[i] != 0) {
    707         dec_state_->render_pipeline->ClearDone(i);
    708       }
    709     }
    710 
    711     JXL_RETURN_IF_ERROR(RunOnPool(
    712         pool_, 0, ac_group_sec.size(),
    713         [this](size_t num_threads) {
    714           return PrepareStorage(num_threads,
    715                                 decoded_passes_per_ac_group_.size());
    716         },
    717         [this, &ac_group_sec, &desired_num_ac_passes, &num, &sections,
    718          &section_status, &has_error](size_t g, size_t thread) {
    719           if (desired_num_ac_passes[g] == 0) {
    720             // no new AC pass, nothing to do
    721             return;
    722           }
    723           (void)num;
    724           size_t first_pass = decoded_passes_per_ac_group_[g];
    725           BitReader* JXL_RESTRICT readers[kMaxNumPasses];
    726           for (size_t i = 0; i < desired_num_ac_passes[g]; i++) {
    727             JXL_ASSERT(ac_group_sec[g][first_pass + i] != num);
    728             readers[i] = sections[ac_group_sec[g][first_pass + i]].br;
    729           }
    730           if (!ProcessACGroup(g, readers, desired_num_ac_passes[g],
    731                               GetStorageLocation(thread, g),
    732                               /*force_draw=*/false, /*dc_only=*/false)) {
    733             has_error = true;
    734           } else {
    735             for (size_t i = 0; i < desired_num_ac_passes[g]; i++) {
    736               section_status[ac_group_sec[g][first_pass + i]] =
    737                   SectionStatus::kDone;
    738             }
    739           }
    740         },
    741         "DecodeGroup"));
    742   }
    743   if (has_error) return JXL_FAILURE("Error in AC group");
    744 
    745   MarkSections(sections, num, section_status);
    746   return true;
    747 }
    748 
    749 Status FrameDecoder::Flush() {
    750   bool has_blending = frame_header_.blending_info.mode != BlendMode::kReplace ||
    751                       frame_header_.custom_size_or_origin;
    752   for (const auto& blending_info_ec :
    753        frame_header_.extra_channel_blending_info) {
    754     if (blending_info_ec.mode != BlendMode::kReplace) has_blending = true;
    755   }
    756   // No early Flush() if blending is enabled.
    757   if (has_blending && !is_finalized_) {
    758     return false;
    759   }
    760   // No early Flush() - nothing to do - if the frame is a kSkipProgressive
    761   // frame.
    762   if (frame_header_.frame_type == FrameType::kSkipProgressive &&
    763       !is_finalized_) {
    764     return true;
    765   }
    766   if (decoded_->IsJPEG()) {
    767     // Nothing to do.
    768     return true;
    769   }
    770   JXL_RETURN_IF_ERROR(AllocateOutput());
    771 
    772   uint32_t completely_decoded_ac_pass = *std::min_element(
    773       decoded_passes_per_ac_group_.begin(), decoded_passes_per_ac_group_.end());
    774   if (completely_decoded_ac_pass < frame_header_.passes.num_passes) {
    775     // We don't have all AC yet: force a draw of all the missing areas.
    776     // Mark all sections as not complete.
    777     for (size_t i = 0; i < decoded_passes_per_ac_group_.size(); i++) {
    778       if (decoded_passes_per_ac_group_[i] < frame_header_.passes.num_passes) {
    779         dec_state_->render_pipeline->ClearDone(i);
    780       }
    781     }
    782     std::atomic<bool> has_error{false};
    783     JXL_RETURN_IF_ERROR(RunOnPool(
    784         pool_, 0, decoded_passes_per_ac_group_.size(),
    785         [this](const size_t num_threads) {
    786           return PrepareStorage(num_threads,
    787                                 decoded_passes_per_ac_group_.size());
    788         },
    789         [this, &has_error](const uint32_t g, size_t thread) {
    790           if (has_error) return;
    791           if (decoded_passes_per_ac_group_[g] ==
    792               frame_header_.passes.num_passes) {
    793             // This group was drawn already, nothing to do.
    794             return;
    795           }
    796           BitReader* JXL_RESTRICT readers[kMaxNumPasses] = {};
    797           if (!ProcessACGroup(
    798                   g, readers, /*num_passes=*/0, GetStorageLocation(thread, g),
    799                   /*force_draw=*/true, /*dc_only=*/!decoded_ac_global_)) {
    800             has_error = true;
    801             return;
    802           }
    803         },
    804         "ForceDrawGroup"));
    805     if (has_error) return JXL_FAILURE("Drawing groups failed");
    806   }
    807 
    808   // undo global modular transforms and copy int pixel buffers to float ones
    809   JXL_RETURN_IF_ERROR(modular_frame_decoder_.FinalizeDecoding(
    810       frame_header_, dec_state_, pool_, is_finalized_));
    811 
    812   return true;
    813 }
    814 
    815 int FrameDecoder::SavedAs(const FrameHeader& header) {
    816   if (header.frame_type == FrameType::kDCFrame) {
    817     // bits 16, 32, 64, 128 for DC level
    818     return 16 << (header.dc_level - 1);
    819   } else if (header.CanBeReferenced()) {
    820     // bits 1, 2, 4 and 8 for the references
    821     return 1 << header.save_as_reference;
    822   }
    823 
    824   return 0;
    825 }
    826 
    827 bool FrameDecoder::HasEverything() const {
    828   if (!decoded_dc_global_) return false;
    829   if (!decoded_ac_global_) return false;
    830   if (HasDcGroupToDecode()) return false;
    831   for (const auto& nb_passes : decoded_passes_per_ac_group_) {
    832     if (nb_passes < frame_header_.passes.num_passes) return false;
    833   }
    834   return true;
    835 }
    836 
    837 int FrameDecoder::References() const {
    838   if (is_finalized_) {
    839     return 0;
    840   }
    841   if (!HasEverything()) return 0;
    842 
    843   int result = 0;
    844 
    845   // Blending
    846   if (frame_header_.frame_type == FrameType::kRegularFrame ||
    847       frame_header_.frame_type == FrameType::kSkipProgressive) {
    848     bool cropped = frame_header_.custom_size_or_origin;
    849     if (cropped || frame_header_.blending_info.mode != BlendMode::kReplace) {
    850       result |= (1 << frame_header_.blending_info.source);
    851     }
    852     const auto& extra = frame_header_.extra_channel_blending_info;
    853     for (const auto& ecbi : extra) {
    854       if (cropped || ecbi.mode != BlendMode::kReplace) {
    855         result |= (1 << ecbi.source);
    856       }
    857     }
    858   }
    859 
    860   // Patches
    861   if (frame_header_.flags & FrameHeader::kPatches) {
    862     result |= dec_state_->shared->image_features.patches.GetReferences();
    863   }
    864 
    865   // DC Level
    866   if (frame_header_.flags & FrameHeader::kUseDcFrame) {
    867     // Reads from the next dc level
    868     int dc_level = frame_header_.dc_level + 1;
    869     // bits 16, 32, 64, 128 for DC level
    870     result |= (16 << (dc_level - 1));
    871   }
    872 
    873   return result;
    874 }
    875 
    876 Status FrameDecoder::FinalizeFrame() {
    877   if (is_finalized_) {
    878     return JXL_FAILURE("FinalizeFrame called multiple times");
    879   }
    880   is_finalized_ = true;
    881   if (decoded_->IsJPEG()) {
    882     // Nothing to do.
    883     return true;
    884   }
    885 
    886   // undo global modular transforms and copy int pixel buffers to float ones
    887   JXL_RETURN_IF_ERROR(
    888       modular_frame_decoder_.FinalizeDecoding(frame_header_, dec_state_, pool_,
    889                                               /*inplace=*/true));
    890 
    891   if (frame_header_.CanBeReferenced()) {
    892     auto& info = dec_state_->shared_storage
    893                      .reference_frames[frame_header_.save_as_reference];
    894     info.frame = std::move(dec_state_->frame_storage_for_referencing);
    895     info.ib_is_in_xyb = frame_header_.save_before_color_transform;
    896   }
    897   return true;
    898 }
    899 
    900 }  // namespace jxl