libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

dec_modular.cc (32464B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/dec_modular.h"
      7 
      8 #include <stdint.h>
      9 
     10 #include <atomic>
     11 #include <vector>
     12 
     13 #include "lib/jxl/frame_header.h"
     14 
     15 #undef HWY_TARGET_INCLUDE
     16 #define HWY_TARGET_INCLUDE "lib/jxl/dec_modular.cc"
     17 #include <hwy/foreach_target.h>
     18 #include <hwy/highway.h>
     19 
     20 #include "lib/jxl/base/compiler_specific.h"
     21 #include "lib/jxl/base/printf_macros.h"
     22 #include "lib/jxl/base/status.h"
     23 #include "lib/jxl/compressed_dc.h"
     24 #include "lib/jxl/epf.h"
     25 #include "lib/jxl/modular/encoding/encoding.h"
     26 #include "lib/jxl/modular/modular_image.h"
     27 #include "lib/jxl/modular/transform/transform.h"
     28 
     29 HWY_BEFORE_NAMESPACE();
     30 namespace jxl {
     31 namespace HWY_NAMESPACE {
     32 
     33 // These templates are not found via ADL.
     34 using hwy::HWY_NAMESPACE::Add;
     35 using hwy::HWY_NAMESPACE::Mul;
     36 using hwy::HWY_NAMESPACE::Rebind;
     37 
     38 void MultiplySum(const size_t xsize,
     39                  const pixel_type* const JXL_RESTRICT row_in,
     40                  const pixel_type* const JXL_RESTRICT row_in_Y,
     41                  const float factor, float* const JXL_RESTRICT row_out) {
     42   const HWY_FULL(float) df;
     43   const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
     44   const auto factor_v = Set(df, factor);
     45   for (size_t x = 0; x < xsize; x += Lanes(di)) {
     46     const auto in = Add(Load(di, row_in + x), Load(di, row_in_Y + x));
     47     const auto out = Mul(ConvertTo(df, in), factor_v);
     48     Store(out, df, row_out + x);
     49   }
     50 }
     51 
     52 void RgbFromSingle(const size_t xsize,
     53                    const pixel_type* const JXL_RESTRICT row_in,
     54                    const float factor, float* out_r, float* out_g,
     55                    float* out_b) {
     56   const HWY_FULL(float) df;
     57   const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
     58 
     59   const auto factor_v = Set(df, factor);
     60   for (size_t x = 0; x < xsize; x += Lanes(di)) {
     61     const auto in = Load(di, row_in + x);
     62     const auto out = Mul(ConvertTo(df, in), factor_v);
     63     Store(out, df, out_r + x);
     64     Store(out, df, out_g + x);
     65     Store(out, df, out_b + x);
     66   }
     67 }
     68 
     69 void SingleFromSingle(const size_t xsize,
     70                       const pixel_type* const JXL_RESTRICT row_in,
     71                       const float factor, float* row_out) {
     72   const HWY_FULL(float) df;
     73   const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
     74 
     75   const auto factor_v = Set(df, factor);
     76   for (size_t x = 0; x < xsize; x += Lanes(di)) {
     77     const auto in = Load(di, row_in + x);
     78     const auto out = Mul(ConvertTo(df, in), factor_v);
     79     Store(out, df, row_out + x);
     80   }
     81 }
     82 // NOLINTNEXTLINE(google-readability-namespace-comments)
     83 }  // namespace HWY_NAMESPACE
     84 }  // namespace jxl
     85 HWY_AFTER_NAMESPACE();
     86 
     87 #if HWY_ONCE
     88 namespace jxl {
     89 HWY_EXPORT(MultiplySum);       // Local function
     90 HWY_EXPORT(RgbFromSingle);     // Local function
     91 HWY_EXPORT(SingleFromSingle);  // Local function
     92 
     93 // Slow conversion using double precision multiplication, only
     94 // needed when the bit depth is too high for single precision
     95 void SingleFromSingleAccurate(const size_t xsize,
     96                               const pixel_type* const JXL_RESTRICT row_in,
     97                               const double factor, float* row_out) {
     98   for (size_t x = 0; x < xsize; x++) {
     99     row_out[x] = row_in[x] * factor;
    100   }
    101 }
    102 
    103 // convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as int
    104 // back to binary32 float
    105 void int_to_float(const pixel_type* const JXL_RESTRICT row_in,
    106                   float* const JXL_RESTRICT row_out, const size_t xsize,
    107                   const int bits, const int exp_bits) {
    108   if (bits == 32) {
    109     JXL_ASSERT(sizeof(pixel_type) == sizeof(float));
    110     JXL_ASSERT(exp_bits == 8);
    111     memcpy(row_out, row_in, xsize * sizeof(float));
    112     return;
    113   }
    114   int exp_bias = (1 << (exp_bits - 1)) - 1;
    115   int sign_shift = bits - 1;
    116   int mant_bits = bits - exp_bits - 1;
    117   int mant_shift = 23 - mant_bits;
    118   for (size_t x = 0; x < xsize; ++x) {
    119     uint32_t f;
    120     memcpy(&f, &row_in[x], 4);
    121     int signbit = (f >> sign_shift);
    122     f &= (1 << sign_shift) - 1;
    123     if (f == 0) {
    124       row_out[x] = (signbit ? -0.f : 0.f);
    125       continue;
    126     }
    127     int exp = (f >> mant_bits);
    128     int mantissa = (f & ((1 << mant_bits) - 1));
    129     mantissa <<= mant_shift;
    130     // Try to normalize only if there is space for maneuver.
    131     if (exp == 0 && exp_bits < 8) {
    132       // subnormal number
    133       while ((mantissa & 0x800000) == 0) {
    134         mantissa <<= 1;
    135         exp--;
    136       }
    137       exp++;
    138       // remove leading 1 because it is implicit now
    139       mantissa &= 0x7fffff;
    140     }
    141     exp -= exp_bias;
    142     // broke up the arbitrary float into its parts, now reassemble into
    143     // binary32
    144     exp += 127;
    145     JXL_ASSERT(exp >= 0);
    146     f = (signbit ? 0x80000000 : 0);
    147     f |= (exp << 23);
    148     f |= mantissa;
    149     memcpy(&row_out[x], &f, 4);
    150   }
    151 }
    152 
    153 #if JXL_DEBUG_V_LEVEL >= 1
    154 std::string ModularStreamId::DebugString() const {
    155   std::ostringstream os;
    156   os << (kind == kGlobalData   ? "ModularGlobal"
    157          : kind == kVarDCTDC   ? "VarDCTDC"
    158          : kind == kModularDC  ? "ModularDC"
    159          : kind == kACMetadata ? "ACMeta"
    160          : kind == kQuantTable ? "QuantTable"
    161          : kind == kModularAC  ? "ModularAC"
    162                                : "");
    163   if (kind == kVarDCTDC || kind == kModularDC || kind == kACMetadata ||
    164       kind == kModularAC) {
    165     os << " group " << group_id;
    166   }
    167   if (kind == kModularAC) {
    168     os << " pass " << pass_id;
    169   }
    170   if (kind == kQuantTable) {
    171     os << " " << quant_table_id;
    172   }
    173   return os.str();
    174 }
    175 #endif
    176 
    177 Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
    178                                              const FrameHeader& frame_header,
    179                                              bool allow_truncated_group) {
    180   bool decode_color = frame_header.encoding == FrameEncoding::kModular;
    181   const auto& metadata = frame_header.nonserialized_metadata->m;
    182   bool is_gray = metadata.color_encoding.IsGray();
    183   size_t nb_chans = 3;
    184   if (is_gray && frame_header.color_transform == ColorTransform::kNone) {
    185     nb_chans = 1;
    186   }
    187   do_color = decode_color;
    188   size_t nb_extra = metadata.extra_channel_info.size();
    189   bool has_tree = static_cast<bool>(reader->ReadBits(1));
    190   if (!allow_truncated_group ||
    191       reader->TotalBitsConsumed() < reader->TotalBytes() * kBitsPerByte) {
    192     if (has_tree) {
    193       size_t tree_size_limit =
    194           std::min(static_cast<size_t>(1 << 22),
    195                    1024 + frame_dim.xsize * frame_dim.ysize *
    196                               (nb_chans + nb_extra) / 16);
    197       JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit));
    198       JXL_RETURN_IF_ERROR(
    199           DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map));
    200     }
    201   }
    202   if (!do_color) nb_chans = 0;
    203 
    204   bool fp = metadata.bit_depth.floating_point_sample;
    205 
    206   // bits_per_sample is just metadata for XYB images.
    207   if (metadata.bit_depth.bits_per_sample >= 32 && do_color &&
    208       frame_header.color_transform != ColorTransform::kXYB) {
    209     if (metadata.bit_depth.bits_per_sample == 32 && fp == false) {
    210       return JXL_FAILURE("uint32_t not supported in dec_modular");
    211     } else if (metadata.bit_depth.bits_per_sample > 32) {
    212       return JXL_FAILURE("bits_per_sample > 32 not supported");
    213     }
    214   }
    215 
    216   JXL_ASSIGN_OR_RETURN(
    217       Image gi,
    218       Image::Create(frame_dim.xsize, frame_dim.ysize,
    219                     metadata.bit_depth.bits_per_sample, nb_chans + nb_extra));
    220 
    221   all_same_shift = true;
    222   if (frame_header.color_transform == ColorTransform::kYCbCr) {
    223     for (size_t c = 0; c < nb_chans; c++) {
    224       gi.channel[c].hshift = frame_header.chroma_subsampling.HShift(c);
    225       gi.channel[c].vshift = frame_header.chroma_subsampling.VShift(c);
    226       size_t xsize_shifted =
    227           DivCeil(frame_dim.xsize, 1 << gi.channel[c].hshift);
    228       size_t ysize_shifted =
    229           DivCeil(frame_dim.ysize, 1 << gi.channel[c].vshift);
    230       JXL_RETURN_IF_ERROR(gi.channel[c].shrink(xsize_shifted, ysize_shifted));
    231       if (gi.channel[c].hshift != gi.channel[0].hshift ||
    232           gi.channel[c].vshift != gi.channel[0].vshift)
    233         all_same_shift = false;
    234     }
    235   }
    236 
    237   for (size_t ec = 0, c = nb_chans; ec < nb_extra; ec++, c++) {
    238     size_t ecups = frame_header.extra_channel_upsampling[ec];
    239     JXL_RETURN_IF_ERROR(
    240         gi.channel[c].shrink(DivCeil(frame_dim.xsize_upsampled, ecups),
    241                              DivCeil(frame_dim.ysize_upsampled, ecups)));
    242     gi.channel[c].hshift = gi.channel[c].vshift =
    243         CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling);
    244     if (gi.channel[c].hshift != gi.channel[0].hshift ||
    245         gi.channel[c].vshift != gi.channel[0].vshift)
    246       all_same_shift = false;
    247   }
    248 
    249   JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (w/o transforms) %s",
    250               gi.DebugString().c_str());
    251   ModularOptions options;
    252   options.max_chan_size = frame_dim.group_dim;
    253   options.group_dim = frame_dim.group_dim;
    254   Status dec_status = ModularGenericDecompress(
    255       reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim),
    256       &options,
    257       /*undo_transforms=*/false, &tree, &code, &context_map,
    258       allow_truncated_group);
    259   if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
    260   if (dec_status.IsFatalError()) {
    261     return JXL_FAILURE("Failed to decode global modular info");
    262   }
    263 
    264   // TODO(eustas): are we sure this can be done after partial decode?
    265   have_something = false;
    266   for (size_t c = 0; c < gi.channel.size(); c++) {
    267     Channel& gic = gi.channel[c];
    268     if (c >= gi.nb_meta_channels && gic.w <= frame_dim.group_dim &&
    269         gic.h <= frame_dim.group_dim)
    270       have_something = true;
    271   }
    272   // move global transforms to groups if possible
    273   if (!have_something && all_same_shift) {
    274     if (gi.transform.size() == 1 && gi.transform[0].id == TransformId::kRCT) {
    275       global_transform = gi.transform;
    276       gi.transform.clear();
    277       // TODO(jon): also move no-delta-palette out (trickier though)
    278     }
    279   }
    280   full_image = std::move(gi);
    281   JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (with transforms) %s",
    282               full_image.DebugString().c_str());
    283   return dec_status;
    284 }
    285 
    286 void ModularFrameDecoder::MaybeDropFullImage() {
    287   if (full_image.transform.empty() && !have_something && all_same_shift) {
    288     use_full_image = false;
    289     JXL_DEBUG_V(6, "Dropping full image");
    290     for (auto& ch : full_image.channel) {
    291       // keep metadata on channels around, but dealloc their planes
    292       ch.plane = Plane<pixel_type>();
    293     }
    294   }
    295 }
    296 
    297 Status ModularFrameDecoder::DecodeGroup(
    298     const FrameHeader& frame_header, const Rect& rect, BitReader* reader,
    299     int minShift, int maxShift, const ModularStreamId& stream, bool zerofill,
    300     PassesDecoderState* dec_state, RenderPipelineInput* render_pipeline_input,
    301     bool allow_truncated, bool* should_run_pipeline) {
    302   JXL_DEBUG_V(6, "Decoding %s with rect %s and shift bracket %d..%d %s",
    303               stream.DebugString().c_str(), Description(rect).c_str(), minShift,
    304               maxShift, zerofill ? "using zerofill" : "");
    305   JXL_DASSERT(stream.kind == ModularStreamId::kModularDC ||
    306               stream.kind == ModularStreamId::kModularAC);
    307   const size_t xsize = rect.xsize();
    308   const size_t ysize = rect.ysize();
    309   JXL_ASSIGN_OR_RETURN(Image gi,
    310                        Image::Create(xsize, ysize, full_image.bitdepth, 0));
    311   // start at the first bigger-than-groupsize non-metachannel
    312   size_t c = full_image.nb_meta_channels;
    313   for (; c < full_image.channel.size(); c++) {
    314     Channel& fc = full_image.channel[c];
    315     if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break;
    316   }
    317   size_t beginc = c;
    318   for (; c < full_image.channel.size(); c++) {
    319     Channel& fc = full_image.channel[c];
    320     int shift = std::min(fc.hshift, fc.vshift);
    321     if (shift > maxShift) continue;
    322     if (shift < minShift) continue;
    323     Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
    324            rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
    325     if (r.xsize() == 0 || r.ysize() == 0) continue;
    326     if (zerofill && use_full_image) {
    327       for (size_t y = 0; y < r.ysize(); ++y) {
    328         pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y);
    329         memset(row_out, 0, r.xsize() * sizeof(*row_out));
    330       }
    331     } else {
    332       JXL_ASSIGN_OR_RETURN(Channel gc, Channel::Create(r.xsize(), r.ysize()));
    333       if (zerofill) ZeroFillImage(&gc.plane);
    334       gc.hshift = fc.hshift;
    335       gc.vshift = fc.vshift;
    336       gi.channel.emplace_back(std::move(gc));
    337     }
    338   }
    339   if (zerofill && use_full_image) return true;
    340   // Return early if there's nothing to decode. Otherwise there might be
    341   // problems later (in ModularImageToDecodedRect).
    342   if (gi.channel.empty()) {
    343     if (dec_state && should_run_pipeline) {
    344       const auto* metadata = frame_header.nonserialized_metadata;
    345       if (do_color || metadata->m.num_extra_channels > 0) {
    346         // Signal to FrameDecoder that we do not have some of the required input
    347         // for the render pipeline.
    348         *should_run_pipeline = false;
    349       }
    350     }
    351     JXL_DEBUG_V(6, "Nothing to decode, returning early.");
    352     return true;
    353   }
    354   ModularOptions options;
    355   if (!zerofill) {
    356     auto status = ModularGenericDecompress(
    357         reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options,
    358         /*undo_transforms=*/true, &tree, &code, &context_map, allow_truncated);
    359     if (!allow_truncated) JXL_RETURN_IF_ERROR(status);
    360     if (status.IsFatalError()) return status;
    361   }
    362   // Undo global transforms that have been pushed to the group level
    363   if (!use_full_image) {
    364     JXL_ASSERT(render_pipeline_input);
    365     for (auto t : global_transform) {
    366       JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header));
    367     }
    368     JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(
    369         frame_header, gi, dec_state, nullptr, *render_pipeline_input,
    370         Rect(0, 0, gi.w, gi.h)));
    371     return true;
    372   }
    373   int gic = 0;
    374   for (c = beginc; c < full_image.channel.size(); c++) {
    375     Channel& fc = full_image.channel[c];
    376     int shift = std::min(fc.hshift, fc.vshift);
    377     if (shift > maxShift) continue;
    378     if (shift < minShift) continue;
    379     Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
    380            rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
    381     if (r.xsize() == 0 || r.ysize() == 0) continue;
    382     JXL_ASSERT(use_full_image);
    383     CopyImageTo(/*rect_from=*/Rect(0, 0, r.xsize(), r.ysize()),
    384                 /*from=*/gi.channel[gic].plane,
    385                 /*rect_to=*/r, /*to=*/&fc.plane);
    386     gic++;
    387   }
    388   return true;
    389 }
    390 
    391 Status ModularFrameDecoder::DecodeVarDCTDC(const FrameHeader& frame_header,
    392                                            size_t group_id, BitReader* reader,
    393                                            PassesDecoderState* dec_state) {
    394   const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id);
    395   JXL_DEBUG_V(6, "Decoding VarDCT DC with rect %s", Description(r).c_str());
    396   // TODO(eustas): investigate if we could reduce the impact of
    397   //               EvalRationalPolynomial; generally speaking, the limit is
    398   //               2**(128/(3*magic)), where 128 comes from IEEE 754 exponent,
    399   //               3 comes from XybToRgb that cubes the values, and "magic" is
    400   //               the sum of all other contributions. 2**18 is known to lead
    401   //               to NaN on input found by fuzzing (see commit message).
    402   JXL_ASSIGN_OR_RETURN(
    403       Image image, Image::Create(r.xsize(), r.ysize(), full_image.bitdepth, 3));
    404   size_t stream_id = ModularStreamId::VarDCTDC(group_id).ID(frame_dim);
    405   reader->Refill();
    406   size_t extra_precision = reader->ReadFixedBits<2>();
    407   float mul = 1.0f / (1 << extra_precision);
    408   ModularOptions options;
    409   for (size_t c = 0; c < 3; c++) {
    410     Channel& ch = image.channel[c < 2 ? c ^ 1 : c];
    411     ch.w >>= frame_header.chroma_subsampling.HShift(c);
    412     ch.h >>= frame_header.chroma_subsampling.VShift(c);
    413     JXL_RETURN_IF_ERROR(ch.shrink());
    414   }
    415   if (!ModularGenericDecompress(
    416           reader, image, /*header=*/nullptr, stream_id, &options,
    417           /*undo_transforms=*/true, &tree, &code, &context_map)) {
    418     return JXL_FAILURE("Failed to decode VarDCT DC group (DC group id %d)",
    419                        static_cast<int>(group_id));
    420   }
    421   DequantDC(r, &dec_state->shared_storage.dc_storage,
    422             &dec_state->shared_storage.quant_dc, image,
    423             dec_state->shared->quantizer.MulDC(), mul,
    424             dec_state->shared->cmap.DCFactors(),
    425             frame_header.chroma_subsampling, dec_state->shared->block_ctx_map);
    426   return true;
    427 }
    428 
    429 Status ModularFrameDecoder::DecodeAcMetadata(const FrameHeader& frame_header,
    430                                              size_t group_id, BitReader* reader,
    431                                              PassesDecoderState* dec_state) {
    432   const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id);
    433   JXL_DEBUG_V(6, "Decoding AcMetadata with rect %s", Description(r).c_str());
    434   size_t upper_bound = r.xsize() * r.ysize();
    435   reader->Refill();
    436   size_t count = reader->ReadBits(CeilLog2Nonzero(upper_bound)) + 1;
    437   size_t stream_id = ModularStreamId::ACMetadata(group_id).ID(frame_dim);
    438   // YToX, YToB, ACS + QF, EPF
    439   JXL_ASSIGN_OR_RETURN(
    440       Image image, Image::Create(r.xsize(), r.ysize(), full_image.bitdepth, 4));
    441   static_assert(kColorTileDimInBlocks == 8, "Color tile size changed");
    442   Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3);
    443   JXL_ASSIGN_OR_RETURN(image.channel[0],
    444                        Channel::Create(cr.xsize(), cr.ysize(), 3, 3));
    445   JXL_ASSIGN_OR_RETURN(image.channel[1],
    446                        Channel::Create(cr.xsize(), cr.ysize(), 3, 3));
    447   JXL_ASSIGN_OR_RETURN(image.channel[2], Channel::Create(count, 2, 0, 0));
    448   ModularOptions options;
    449   if (!ModularGenericDecompress(
    450           reader, image, /*header=*/nullptr, stream_id, &options,
    451           /*undo_transforms=*/true, &tree, &code, &context_map)) {
    452     return JXL_FAILURE("Failed to decode AC metadata");
    453   }
    454   ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane, cr,
    455                        &dec_state->shared_storage.cmap.ytox_map);
    456   ConvertPlaneAndClamp(Rect(image.channel[1].plane), image.channel[1].plane, cr,
    457                        &dec_state->shared_storage.cmap.ytob_map);
    458   size_t num = 0;
    459   bool is444 = frame_header.chroma_subsampling.Is444();
    460   auto& ac_strategy = dec_state->shared_storage.ac_strategy;
    461   size_t xlim = std::min(ac_strategy.xsize(), r.x0() + r.xsize());
    462   size_t ylim = std::min(ac_strategy.ysize(), r.y0() + r.ysize());
    463   uint32_t local_used_acs = 0;
    464   for (size_t iy = 0; iy < r.ysize(); iy++) {
    465     size_t y = r.y0() + iy;
    466     int32_t* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy);
    467     uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy);
    468     int32_t* row_in_1 = image.channel[2].plane.Row(0);
    469     int32_t* row_in_2 = image.channel[2].plane.Row(1);
    470     int32_t* row_in_3 = image.channel[3].plane.Row(iy);
    471     for (size_t ix = 0; ix < r.xsize(); ix++) {
    472       size_t x = r.x0() + ix;
    473       int sharpness = row_in_3[ix];
    474       if (sharpness < 0 || sharpness >= LoopFilter::kEpfSharpEntries) {
    475         return JXL_FAILURE("Corrupted sharpness field");
    476       }
    477       row_epf[ix] = sharpness;
    478       if (ac_strategy.IsValid(x, y)) {
    479         continue;
    480       }
    481 
    482       if (num >= count) return JXL_FAILURE("Corrupted stream");
    483 
    484       if (!AcStrategy::IsRawStrategyValid(row_in_1[num])) {
    485         return JXL_FAILURE("Invalid AC strategy");
    486       }
    487       local_used_acs |= 1u << row_in_1[num];
    488       AcStrategy acs = AcStrategy::FromRawStrategy(row_in_1[num]);
    489       if ((acs.covered_blocks_x() > 1 || acs.covered_blocks_y() > 1) &&
    490           !is444) {
    491         return JXL_FAILURE(
    492             "AC strategy not compatible with chroma subsampling");
    493       }
    494       // Ensure that blocks do not overflow *AC* groups.
    495       size_t next_x_ac_block = (x / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
    496       size_t next_y_ac_block = (y / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
    497       size_t next_x_dct_block = x + acs.covered_blocks_x();
    498       size_t next_y_dct_block = y + acs.covered_blocks_y();
    499       if (next_x_dct_block > next_x_ac_block || next_x_dct_block > xlim) {
    500         return JXL_FAILURE("Invalid AC strategy, x overflow");
    501       }
    502       if (next_y_dct_block > next_y_ac_block || next_y_dct_block > ylim) {
    503         return JXL_FAILURE("Invalid AC strategy, y overflow");
    504       }
    505       JXL_RETURN_IF_ERROR(
    506           ac_strategy.SetNoBoundsCheck(x, y, AcStrategy::Type(row_in_1[num])));
    507       row_qf[ix] = 1 + std::max<int32_t>(0, std::min(Quantizer::kQuantMax - 1,
    508                                                      row_in_2[num]));
    509       num++;
    510     }
    511   }
    512   dec_state->used_acs |= local_used_acs;
    513   if (frame_header.loop_filter.epf_iters > 0) {
    514     ComputeSigma(frame_header.loop_filter, r, dec_state);
    515   }
    516   return true;
    517 }
    518 
    519 Status ModularFrameDecoder::ModularImageToDecodedRect(
    520     const FrameHeader& frame_header, Image& gi, PassesDecoderState* dec_state,
    521     jxl::ThreadPool* pool, RenderPipelineInput& render_pipeline_input,
    522     Rect modular_rect) const {
    523   const auto* metadata = frame_header.nonserialized_metadata;
    524   JXL_CHECK(gi.transform.empty());
    525 
    526   auto get_row = [&](size_t c, size_t y) {
    527     const auto& buffer = render_pipeline_input.GetBuffer(c);
    528     return buffer.second.Row(buffer.first, y);
    529   };
    530 
    531   size_t c = 0;
    532   if (do_color) {
    533     const bool rgb_from_gray =
    534         metadata->m.color_encoding.IsGray() &&
    535         frame_header.color_transform == ColorTransform::kNone;
    536     const bool fp = metadata->m.bit_depth.floating_point_sample &&
    537                     frame_header.color_transform != ColorTransform::kXYB;
    538     for (; c < 3; c++) {
    539       double factor = full_image.bitdepth < 32
    540                           ? 1.0 / ((1u << full_image.bitdepth) - 1)
    541                           : 0;
    542       size_t c_in = c;
    543       if (frame_header.color_transform == ColorTransform::kXYB) {
    544         factor = dec_state->shared->matrices.DCQuants()[c];
    545         // XYB is encoded as YX(B-Y)
    546         if (c < 2) c_in = 1 - c;
    547       } else if (rgb_from_gray) {
    548         c_in = 0;
    549       }
    550       JXL_ASSERT(c_in < gi.channel.size());
    551       Channel& ch_in = gi.channel[c_in];
    552       // TODO(eustas): could we detect it on earlier stage?
    553       if (ch_in.w == 0 || ch_in.h == 0) {
    554         return JXL_FAILURE("Empty image");
    555       }
    556       JXL_CHECK(ch_in.hshift <= 3 && ch_in.vshift <= 3);
    557       Rect r = render_pipeline_input.GetBuffer(c).second;
    558       Rect mr(modular_rect.x0() >> ch_in.hshift,
    559               modular_rect.y0() >> ch_in.vshift,
    560               DivCeil(modular_rect.xsize(), 1 << ch_in.hshift),
    561               DivCeil(modular_rect.ysize(), 1 << ch_in.vshift));
    562       mr = mr.Crop(ch_in.plane);
    563       size_t xsize_shifted = r.xsize();
    564       size_t ysize_shifted = r.ysize();
    565       if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
    566         return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
    567                            "x%" PRIuS
    568                            " modular channel into "
    569                            "a %" PRIuS "x%" PRIuS " rect",
    570                            mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
    571       }
    572       if (frame_header.color_transform == ColorTransform::kXYB && c == 2) {
    573         JXL_ASSERT(!fp);
    574         JXL_RETURN_IF_ERROR(RunOnPool(
    575             pool, 0, ysize_shifted, ThreadPool::NoInit,
    576             [&](const uint32_t task, size_t /* thread */) {
    577               const size_t y = task;
    578               const pixel_type* const JXL_RESTRICT row_in =
    579                   mr.Row(&ch_in.plane, y);
    580               const pixel_type* const JXL_RESTRICT row_in_Y =
    581                   mr.Row(&gi.channel[0].plane, y);
    582               float* const JXL_RESTRICT row_out = get_row(c, y);
    583               HWY_DYNAMIC_DISPATCH(MultiplySum)
    584               (xsize_shifted, row_in, row_in_Y, factor, row_out);
    585             },
    586             "ModularIntToFloat"));
    587       } else if (fp) {
    588         int bits = metadata->m.bit_depth.bits_per_sample;
    589         int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample;
    590         JXL_RETURN_IF_ERROR(RunOnPool(
    591             pool, 0, ysize_shifted, ThreadPool::NoInit,
    592             [&](const uint32_t task, size_t /* thread */) {
    593               const size_t y = task;
    594               const pixel_type* const JXL_RESTRICT row_in =
    595                   mr.Row(&ch_in.plane, y);
    596               if (rgb_from_gray) {
    597                 for (size_t cc = 0; cc < 3; cc++) {
    598                   float* const JXL_RESTRICT row_out = get_row(cc, y);
    599                   int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits);
    600                 }
    601               } else {
    602                 float* const JXL_RESTRICT row_out = get_row(c, y);
    603                 int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits);
    604               }
    605             },
    606             "ModularIntToFloat_losslessfloat"));
    607       } else {
    608         JXL_RETURN_IF_ERROR(RunOnPool(
    609             pool, 0, ysize_shifted, ThreadPool::NoInit,
    610             [&](const uint32_t task, size_t /* thread */) {
    611               const size_t y = task;
    612               const pixel_type* const JXL_RESTRICT row_in =
    613                   mr.Row(&ch_in.plane, y);
    614               if (rgb_from_gray) {
    615                 if (full_image.bitdepth < 23) {
    616                   HWY_DYNAMIC_DISPATCH(RgbFromSingle)
    617                   (xsize_shifted, row_in, factor, get_row(0, y), get_row(1, y),
    618                    get_row(2, y));
    619                 } else {
    620                   SingleFromSingleAccurate(xsize_shifted, row_in, factor,
    621                                            get_row(0, y));
    622                   SingleFromSingleAccurate(xsize_shifted, row_in, factor,
    623                                            get_row(1, y));
    624                   SingleFromSingleAccurate(xsize_shifted, row_in, factor,
    625                                            get_row(2, y));
    626                 }
    627               } else {
    628                 float* const JXL_RESTRICT row_out = get_row(c, y);
    629                 if (full_image.bitdepth < 23) {
    630                   HWY_DYNAMIC_DISPATCH(SingleFromSingle)
    631                   (xsize_shifted, row_in, factor, row_out);
    632                 } else {
    633                   SingleFromSingleAccurate(xsize_shifted, row_in, factor,
    634                                            row_out);
    635                 }
    636               }
    637             },
    638             "ModularIntToFloat"));
    639       }
    640       if (rgb_from_gray) {
    641         break;
    642       }
    643     }
    644     if (rgb_from_gray) {
    645       c = 1;
    646     }
    647   }
    648   size_t num_extra_channels = metadata->m.num_extra_channels;
    649   for (size_t ec = 0; ec < num_extra_channels; ec++, c++) {
    650     const ExtraChannelInfo& eci = metadata->m.extra_channel_info[ec];
    651     int bits = eci.bit_depth.bits_per_sample;
    652     int exp_bits = eci.bit_depth.exponent_bits_per_sample;
    653     bool fp = eci.bit_depth.floating_point_sample;
    654     JXL_ASSERT(fp || bits < 32);
    655     const double factor = fp ? 0 : (1.0 / ((1u << bits) - 1));
    656     JXL_ASSERT(c < gi.channel.size());
    657     Channel& ch_in = gi.channel[c];
    658     Rect r = render_pipeline_input.GetBuffer(3 + ec).second;
    659     Rect mr(modular_rect.x0() >> ch_in.hshift,
    660             modular_rect.y0() >> ch_in.vshift,
    661             DivCeil(modular_rect.xsize(), 1 << ch_in.hshift),
    662             DivCeil(modular_rect.ysize(), 1 << ch_in.vshift));
    663     mr = mr.Crop(ch_in.plane);
    664     if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
    665       return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
    666                          "x%" PRIuS
    667                          " modular channel into "
    668                          "a %" PRIuS "x%" PRIuS " rect",
    669                          mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
    670     }
    671     for (size_t y = 0; y < r.ysize(); ++y) {
    672       float* const JXL_RESTRICT row_out =
    673           r.Row(render_pipeline_input.GetBuffer(3 + ec).first, y);
    674       const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
    675       if (fp) {
    676         int_to_float(row_in, row_out, r.xsize(), bits, exp_bits);
    677       } else {
    678         if (full_image.bitdepth < 23) {
    679           HWY_DYNAMIC_DISPATCH(SingleFromSingle)
    680           (r.xsize(), row_in, factor, row_out);
    681         } else {
    682           SingleFromSingleAccurate(r.xsize(), row_in, factor, row_out);
    683         }
    684       }
    685     }
    686   }
    687   return true;
    688 }
    689 
    690 Status ModularFrameDecoder::FinalizeDecoding(const FrameHeader& frame_header,
    691                                              PassesDecoderState* dec_state,
    692                                              jxl::ThreadPool* pool,
    693                                              bool inplace) {
    694   if (!use_full_image) return true;
    695   Image gi;
    696   if (inplace) {
    697     gi = std::move(full_image);
    698   } else {
    699     JXL_ASSIGN_OR_RETURN(gi, Image::Clone(full_image));
    700   }
    701   size_t xsize = gi.w;
    702   size_t ysize = gi.h;
    703 
    704   JXL_DEBUG_V(3, "Finalizing decoding for modular image: %s",
    705               gi.DebugString().c_str());
    706 
    707   // Don't use threads if total image size is smaller than a group
    708   if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr;
    709 
    710   // Undo the global transforms
    711   gi.undo_transforms(global_header.wp_header, pool);
    712   JXL_DASSERT(global_transform.empty());
    713   if (gi.error) return JXL_FAILURE("Undoing transforms failed");
    714 
    715   for (size_t i = 0; i < dec_state->shared->frame_dim.num_groups; i++) {
    716     dec_state->render_pipeline->ClearDone(i);
    717   }
    718   std::atomic<bool> has_error{false};
    719   JXL_RETURN_IF_ERROR(RunOnPool(
    720       pool, 0, dec_state->shared->frame_dim.num_groups,
    721       [&](size_t num_threads) {
    722         bool use_group_ids = (frame_header.encoding == FrameEncoding::kVarDCT ||
    723                               (frame_header.flags & FrameHeader::kNoise));
    724         return dec_state->render_pipeline->PrepareForThreads(num_threads,
    725                                                              use_group_ids);
    726       },
    727       [&](const uint32_t group, size_t thread_id) {
    728         if (has_error) return;
    729         RenderPipelineInput input =
    730             dec_state->render_pipeline->GetInputBuffers(group, thread_id);
    731         if (!ModularImageToDecodedRect(
    732                 frame_header, gi, dec_state, nullptr, input,
    733                 dec_state->shared->frame_dim.GroupRect(group))) {
    734           has_error = true;
    735           return;
    736         }
    737         if (!input.Done()) {
    738           has_error = true;
    739           return;
    740         }
    741       },
    742       "ModularToRect"));
    743   if (has_error) return JXL_FAILURE("Error producing input to render pipeline");
    744   return true;
    745 }
    746 
    747 static constexpr const float kAlmostZero = 1e-8f;
    748 
    749 Status ModularFrameDecoder::DecodeQuantTable(
    750     size_t required_size_x, size_t required_size_y, BitReader* br,
    751     QuantEncoding* encoding, size_t idx,
    752     ModularFrameDecoder* modular_frame_decoder) {
    753   JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->qraw.qtable_den));
    754   if (encoding->qraw.qtable_den < kAlmostZero) {
    755     // qtable[] values are already checked for <= 0 so the denominator may not
    756     // be negative.
    757     return JXL_FAILURE("Invalid qtable_den: value too small");
    758   }
    759   JXL_ASSIGN_OR_RETURN(Image image,
    760                        Image::Create(required_size_x, required_size_y, 8, 3));
    761   ModularOptions options;
    762   if (modular_frame_decoder) {
    763     JXL_RETURN_IF_ERROR(ModularGenericDecompress(
    764         br, image, /*header=*/nullptr,
    765         ModularStreamId::QuantTable(idx).ID(modular_frame_decoder->frame_dim),
    766         &options, /*undo_transforms=*/true, &modular_frame_decoder->tree,
    767         &modular_frame_decoder->code, &modular_frame_decoder->context_map));
    768   } else {
    769     JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr,
    770                                                  0, &options,
    771                                                  /*undo_transforms=*/true));
    772   }
    773   if (!encoding->qraw.qtable) {
    774     encoding->qraw.qtable = new std::vector<int>();
    775   }
    776   encoding->qraw.qtable->resize(required_size_x * required_size_y * 3);
    777   for (size_t c = 0; c < 3; c++) {
    778     for (size_t y = 0; y < required_size_y; y++) {
    779       int32_t* JXL_RESTRICT row = image.channel[c].Row(y);
    780       for (size_t x = 0; x < required_size_x; x++) {
    781         (*encoding->qraw.qtable)[c * required_size_x * required_size_y +
    782                                  y * required_size_x + x] = row[x];
    783         if (row[x] <= 0) {
    784           return JXL_FAILURE("Invalid raw quantization table");
    785         }
    786       }
    787     }
    788   }
    789   return true;
    790 }
    791 
    792 }  // namespace jxl
    793 #endif  // HWY_ONCE