dec_modular.cc (32464B)
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/dec_modular.h"

#include <stdint.h>

#include <atomic>
#include <vector>

#include "lib/jxl/frame_header.h"

// Highway multi-target compilation: foreach_target.h re-includes this file
// (HWY_TARGET_INCLUDE) once per enabled SIMD target, so everything inside
// HWY_NAMESPACE below is compiled per target; the code under `#if HWY_ONCE`
// is compiled only once.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "lib/jxl/dec_modular.cc"
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/base/printf_macros.h"
#include "lib/jxl/base/status.h"
#include "lib/jxl/compressed_dc.h"
#include "lib/jxl/epf.h"
#include "lib/jxl/modular/encoding/encoding.h"
#include "lib/jxl/modular/modular_image.h"
#include "lib/jxl/modular/transform/transform.h"

HWY_BEFORE_NAMESPACE();
namespace jxl {
namespace HWY_NAMESPACE {

// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::Add;
using hwy::HWY_NAMESPACE::Mul;
using hwy::HWY_NAMESPACE::Rebind;

// SIMD kernel: row_out[x] = float(row_in[x] + row_in_Y[x]) * factor.
// Used for the third XYB component, which is stored relative to Y (see
// ModularImageToDecodedRect), so the Y row is added back before scaling.
// The loop advances by whole vectors, so when xsize is not a multiple of
// Lanes(di) it reads and writes up to the next vector boundary.
// NOTE(review): assumes all rows are padded to a vector multiple — confirm.
void MultiplySum(const size_t xsize,
                 const pixel_type* const JXL_RESTRICT row_in,
                 const pixel_type* const JXL_RESTRICT row_in_Y,
                 const float factor, float* const JXL_RESTRICT row_out) {
  const HWY_FULL(float) df;
  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
  const auto factor_v = Set(df, factor);
  for (size_t x = 0; x < xsize; x += Lanes(di)) {
    const auto in = Add(Load(di, row_in + x), Load(di, row_in_Y + x));
    const auto out = Mul(ConvertTo(df, in), factor_v);
    Store(out, df, row_out + x);
  }
}

// SIMD kernel: scales one integer row by `factor` and stores the identical
// float result into three output rows (used to expand grayscale into RGB).
// Like MultiplySum, it processes whole vectors past xsize.
void RgbFromSingle(const size_t xsize,
                   const pixel_type* const JXL_RESTRICT row_in,
                   const float factor, float* out_r, float* out_g,
                   float* out_b) {
  const HWY_FULL(float) df;
  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float

  const auto factor_v = Set(df, factor);
  for (size_t x = 0; x < xsize; x += Lanes(di)) {
    const auto in = Load(di, row_in + x);
    const auto out = Mul(ConvertTo(df, in), factor_v);
    Store(out, df, out_r + x);
    Store(out, df, out_g + x);
    Store(out, df, out_b + x);
  }
}

// SIMD kernel: row_out[x] = float(row_in[x]) * factor, in single precision.
// Like MultiplySum, it processes whole vectors past xsize.
void SingleFromSingle(const size_t xsize,
                      const pixel_type* const JXL_RESTRICT row_in,
                      const float factor, float* row_out) {
  const HWY_FULL(float) df;
  const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float

  const auto factor_v = Set(df, factor);
  for (size_t x = 0; x < xsize; x += Lanes(di)) {
    const auto in = Load(di, row_in + x);
    const auto out = Mul(ConvertTo(df, in), factor_v);
    Store(out, df, row_out + x);
  }
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jxl
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace jxl {
// Per-target dispatch tables, used below through HWY_DYNAMIC_DISPATCH.
HWY_EXPORT(MultiplySum);       // Local function
HWY_EXPORT(RgbFromSingle);     // Local function
HWY_EXPORT(SingleFromSingle);  // Local function

// Slow conversion using double precision multiplication, only
// needed when the bit depth is too high for single precision.
// Scalar loop; the product is computed in double and narrowed to float on
// store.
void SingleFromSingleAccurate(const size_t xsize,
                              const pixel_type* const JXL_RESTRICT row_in,
                              const double factor, float* row_out) {
  for (size_t x = 0; x < xsize; x++) {
    row_out[x] = row_in[x] * factor;
  }
}

// convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as int
// back to binary32 float
//
// Each input sample holds a sign bit at position bits-1, then exp_bits
// exponent bits (bias 2^(exp_bits-1)-1), then mant_bits = bits-exp_bits-1
// mantissa bits. The exponent is rebiased to 127 and the mantissa
// left-aligned to 23 bits. With exp_bits < 8, subnormal inputs are
// renormalized; with exp_bits == 8 they map directly onto binary32
// subnormals. bits == 32 (which requires exp_bits == 8) is a plain bit copy.
// NOTE(review): assumes mant_bits <= 23 so that mant_shift >= 0; presumably
// guaranteed by bit-depth validation elsewhere — confirm.
void int_to_float(const pixel_type* const JXL_RESTRICT row_in,
                  float* const JXL_RESTRICT row_out, const size_t xsize,
                  const int bits, const int exp_bits) {
  if (bits == 32) {
    JXL_ASSERT(sizeof(pixel_type) == sizeof(float));
    JXL_ASSERT(exp_bits == 8);
    memcpy(row_out, row_in, xsize * sizeof(float));
    return;
  }
  int exp_bias = (1 << (exp_bits - 1)) - 1;
  int sign_shift = bits - 1;
  int mant_bits = bits - exp_bits - 1;
  int mant_shift = 23 - mant_bits;
  for (size_t x = 0; x < xsize; ++x) {
    uint32_t f;
    memcpy(&f, &row_in[x], 4);
    int signbit = (f >> sign_shift);
    f &= (1 << sign_shift) - 1;
    // Zero magnitude: emit a correctly signed zero.
    if (f == 0) {
      row_out[x] = (signbit ? -0.f : 0.f);
      continue;
    }
    int exp = (f >> mant_bits);
    int mantissa = (f & ((1 << mant_bits) - 1));
    mantissa <<= mant_shift;
    // Try to normalize only if there is space for maneuver.
    if (exp == 0 && exp_bits < 8) {
      // subnormal number
      while ((mantissa & 0x800000) == 0) {
        mantissa <<= 1;
        exp--;
      }
      exp++;
      // remove leading 1 because it is implicit now
      mantissa &= 0x7fffff;
    }
    exp -= exp_bias;
    // broke up the arbitrary float into its parts, now reassemble into
    // binary32
    exp += 127;
    JXL_ASSERT(exp >= 0);
    f = (signbit ? 0x80000000 : 0);
    f |= (exp << 23);
    f |= mantissa;
    memcpy(&row_out[x], &f, 4);
  }
}

#if JXL_DEBUG_V_LEVEL >= 1
// Human-readable stream description: the stream kind, plus the group id
// (group-level kinds), pass id (ModularAC) or quant table id (QuantTable).
// Only compiled when verbose debug output is enabled.
std::string ModularStreamId::DebugString() const {
  std::ostringstream os;
  os << (kind == kGlobalData   ? "ModularGlobal"
         : kind == kVarDCTDC   ? "VarDCTDC"
         : kind == kModularDC  ? "ModularDC"
         : kind == kACMetadata ? "ACMeta"
         : kind == kQuantTable ? "QuantTable"
         : kind == kModularAC  ? "ModularAC"
                               : "");
  if (kind == kVarDCTDC || kind == kModularDC || kind == kACMetadata ||
      kind == kModularAC) {
    os << " group " << group_id;
  }
  if (kind == kModularAC) {
    os << " pass " << pass_id;
  }
  if (kind == kQuantTable) {
    os << " " << quant_table_id;
  }
  return os.str();
}
#endif

// Decodes the global modular section of a frame.
//
// Bitstream order:
//  1. A 1-bit flag. If it is set, a global MA tree and its histograms follow.
//     When allow_truncated_group is set and the reader has already consumed
//     all input, the tree is not decoded.
//  2. The global modular stream, decoded into `full_image` with
//     undo_transforms=false and max_chan_size = group_dim.
//
// Also sets per-channel subsampling shifts, `all_same_shift` and
// `have_something`. When possible it moves a lone global RCT into
// `global_transform`, so that DecodeGroup can undo it per group.
//
// With allow_truncated_group, a non-fatal decode error is not returned
// immediately; it is stored and returned as the final status (fatal errors
// still fail).
Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
                                             const FrameHeader& frame_header,
                                             bool allow_truncated_group) {
  bool decode_color = frame_header.encoding == FrameEncoding::kModular;
  const auto& metadata = frame_header.nonserialized_metadata->m;
  bool is_gray = metadata.color_encoding.IsGray();
  // Grayscale without a color transform uses a single color channel.
  size_t nb_chans = 3;
  if (is_gray && frame_header.color_transform == ColorTransform::kNone) {
    nb_chans = 1;
  }
  do_color = decode_color;
  size_t nb_extra = metadata.extra_channel_info.size();
  bool has_tree = static_cast<bool>(reader->ReadBits(1));
  if (!allow_truncated_group ||
      reader->TotalBitsConsumed() < reader->TotalBytes() * kBitsPerByte) {
    if (has_tree) {
      // Tree node limit: min(2^22, 1024 + pixels * channels / 16).
      size_t tree_size_limit =
          std::min(static_cast<size_t>(1 << 22),
                   1024 + frame_dim.xsize * frame_dim.ysize *
                              (nb_chans + nb_extra) / 16);
      JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit));
      // (tree.size() + 1) / 2 contexts: presumably one per leaf of the binary
      // tree.
      JXL_RETURN_IF_ERROR(
          DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map));
    }
  }
  // Non-modular (VarDCT) frames carry only extra channels here.
  if (!do_color) nb_chans = 0;

  bool fp = metadata.bit_depth.floating_point_sample;

  // bits_per_sample is just metadata for XYB images.
  if (metadata.bit_depth.bits_per_sample >= 32 && do_color &&
      frame_header.color_transform != ColorTransform::kXYB) {
    if (metadata.bit_depth.bits_per_sample == 32 && fp == false) {
      return JXL_FAILURE("uint32_t not supported in dec_modular");
    } else if (metadata.bit_depth.bits_per_sample > 32) {
      return JXL_FAILURE("bits_per_sample > 32 not supported");
    }
  }

  JXL_ASSIGN_OR_RETURN(
      Image gi,
      Image::Create(frame_dim.xsize, frame_dim.ysize,
                    metadata.bit_depth.bits_per_sample, nb_chans + nb_extra));

  // all_same_shift: whether every channel has the same shifts as channel 0.
  all_same_shift = true;
  if (frame_header.color_transform == ColorTransform::kYCbCr) {
    for (size_t c = 0; c < nb_chans; c++) {
      gi.channel[c].hshift = frame_header.chroma_subsampling.HShift(c);
      gi.channel[c].vshift = frame_header.chroma_subsampling.VShift(c);
      size_t xsize_shifted =
          DivCeil(frame_dim.xsize, 1 << gi.channel[c].hshift);
      size_t ysize_shifted =
          DivCeil(frame_dim.ysize, 1 << gi.channel[c].vshift);
      JXL_RETURN_IF_ERROR(gi.channel[c].shrink(xsize_shifted, ysize_shifted));
      if (gi.channel[c].hshift != gi.channel[0].hshift ||
          gi.channel[c].vshift != gi.channel[0].vshift)
        all_same_shift = false;
    }
  }

  // Extra channels are sized from the upsampled frame size and their own
  // upsampling factor. Their shift is expressed relative to the frame's
  // color upsampling.
  for (size_t ec = 0, c = nb_chans; ec < nb_extra; ec++, c++) {
    size_t ecups = frame_header.extra_channel_upsampling[ec];
    JXL_RETURN_IF_ERROR(
        gi.channel[c].shrink(DivCeil(frame_dim.xsize_upsampled, ecups),
                             DivCeil(frame_dim.ysize_upsampled, ecups)));
    gi.channel[c].hshift = gi.channel[c].vshift =
        CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling);
    if (gi.channel[c].hshift != gi.channel[0].hshift ||
        gi.channel[c].vshift != gi.channel[0].vshift)
      all_same_shift = false;
  }

  JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (w/o transforms) %s",
              gi.DebugString().c_str());
  ModularOptions options;
  options.max_chan_size = frame_dim.group_dim;
  options.group_dim = frame_dim.group_dim;
  Status dec_status = ModularGenericDecompress(
      reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim),
      &options,
      /*undo_transforms=*/false, &tree, &code, &context_map,
      allow_truncated_group);
  if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
  if (dec_status.IsFatalError()) {
    return JXL_FAILURE("Failed to decode global modular info");
  }

  // TODO(eustas): are we sure this can be done after partial decode?
  // have_something: some non-meta channel fits in group_dim x group_dim.
  have_something = false;
  for (size_t c = 0; c < gi.channel.size(); c++) {
    Channel& gic = gi.channel[c];
    if (c >= gi.nb_meta_channels && gic.w <= frame_dim.group_dim &&
        gic.h <= frame_dim.group_dim)
      have_something = true;
  }
  // move global transforms to groups if possible
  if (!have_something && all_same_shift) {
    if (gi.transform.size() == 1 && gi.transform[0].id == TransformId::kRCT) {
      global_transform = gi.transform;
      gi.transform.clear();
      // TODO(jon): also move no-delta-palette out (trickier though)
    }
  }
  full_image = std::move(gi);
  JXL_DEBUG_V(6, "DecodeGlobalInfo: full_image (with transforms) %s",
              full_image.DebugString().c_str());
  return dec_status;
}

// Releases the pixel planes of `full_image` and clears `use_full_image` when
// all of these hold:
//  - no transforms remain on it;
//  - no non-meta channel fits within a single group (`have_something`);
//  - all channel shifts are equal.
// Channel metadata is kept. DecodeGroup then converts each group directly
// into the render pipeline instead of copying into full_image.
void ModularFrameDecoder::MaybeDropFullImage() {
  if (full_image.transform.empty() && !have_something && all_same_shift) {
    use_full_image = false;
    JXL_DEBUG_V(6, "Dropping full image");
    for (auto& ch : full_image.channel) {
      // keep metadata on channels around, but dealloc their planes
      ch.plane = Plane<pixel_type>();
    }
  }
}

// Decodes (or zero-fills) one modular group stream for `rect`, given in
// full-image coordinates.
//
// Channel selection:
//  - Only channels from the first non-meta channel larger than group_dim
//    onward are handled.
//  - Of those, only channels with min(hshift, vshift) in
//    [minShift, maxShift] are handled.
//  - Each channel covers `rect` scaled down by its own shifts.
//
// Parameters:
//  - zerofill: no bitstream is read; samples are set to zero.
//  - use_full_image (member): decoded samples are copied into `full_image`.
//    Otherwise the deferred `global_transform` is undone here, and the group
//    is converted straight into `*render_pipeline_input`.
//  - allow_truncated: only fatal decode errors are returned.
//  - should_run_pipeline: set to false when this group has no channel to
//    decode although color or extra channels exist.
Status ModularFrameDecoder::DecodeGroup(
    const FrameHeader& frame_header, const Rect& rect, BitReader* reader,
    int minShift, int maxShift, const ModularStreamId& stream, bool zerofill,
    PassesDecoderState* dec_state, RenderPipelineInput* render_pipeline_input,
    bool allow_truncated, bool* should_run_pipeline) {
  JXL_DEBUG_V(6, "Decoding %s with rect %s and shift bracket %d..%d %s",
              stream.DebugString().c_str(), Description(rect).c_str(), minShift,
              maxShift, zerofill ? "using zerofill" : "");
  JXL_DASSERT(stream.kind == ModularStreamId::kModularDC ||
              stream.kind == ModularStreamId::kModularAC);
  const size_t xsize = rect.xsize();
  const size_t ysize = rect.ysize();
  JXL_ASSIGN_OR_RETURN(Image gi,
                       Image::Create(xsize, ysize, full_image.bitdepth, 0));
  // start at the first bigger-than-groupsize non-metachannel
  size_t c = full_image.nb_meta_channels;
  for (; c < full_image.channel.size(); c++) {
    Channel& fc = full_image.channel[c];
    if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break;
  }
  size_t beginc = c;
  for (; c < full_image.channel.size(); c++) {
    Channel& fc = full_image.channel[c];
    int shift = std::min(fc.hshift, fc.vshift);
    if (shift > maxShift) continue;
    if (shift < minShift) continue;
    Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
           rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
    if (r.xsize() == 0 || r.ysize() == 0) continue;
    if (zerofill && use_full_image) {
      // Zero the target region of full_image directly; nothing to decode.
      for (size_t y = 0; y < r.ysize(); ++y) {
        pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y);
        memset(row_out, 0, r.xsize() * sizeof(*row_out));
      }
    } else {
      // Allocate a group-local channel with the same shifts as the source.
      JXL_ASSIGN_OR_RETURN(Channel gc, Channel::Create(r.xsize(), r.ysize()));
      if (zerofill) ZeroFillImage(&gc.plane);
      gc.hshift = fc.hshift;
      gc.vshift = fc.vshift;
      gi.channel.emplace_back(std::move(gc));
    }
  }
  if (zerofill && use_full_image) return true;
  // Return early if there's nothing to decode. Otherwise there might be
  // problems later (in ModularImageToDecodedRect).
  if (gi.channel.empty()) {
    if (dec_state && should_run_pipeline) {
      const auto* metadata = frame_header.nonserialized_metadata;
      if (do_color || metadata->m.num_extra_channels > 0) {
        // Signal to FrameDecoder that we do not have some of the required input
        // for the render pipeline.
        *should_run_pipeline = false;
      }
    }
    JXL_DEBUG_V(6, "Nothing to decode, returning early.");
    return true;
  }
  ModularOptions options;
  if (!zerofill) {
    auto status = ModularGenericDecompress(
        reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options,
        /*undo_transforms=*/true, &tree, &code, &context_map, allow_truncated);
    if (!allow_truncated) JXL_RETURN_IF_ERROR(status);
    if (status.IsFatalError()) return status;
  }
  // Undo global transforms that have been pushed to the group level
  if (!use_full_image) {
    JXL_ASSERT(render_pipeline_input);
    for (auto t : global_transform) {
      JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header));
    }
    JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(
        frame_header, gi, dec_state, nullptr, *render_pipeline_input,
        Rect(0, 0, gi.w, gi.h)));
    return true;
  }
  // Copy the decoded group channels back into full_image. This repeats the
  // selection above, so gi.channel[gic] matches full_image.channel[c].
  int gic = 0;
  for (c = beginc; c < full_image.channel.size(); c++) {
    Channel& fc = full_image.channel[c];
    int shift = std::min(fc.hshift, fc.vshift);
    if (shift > maxShift) continue;
    if (shift < minShift) continue;
    Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
           rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
    if (r.xsize() == 0 || r.ysize() == 0) continue;
    JXL_ASSERT(use_full_image);
    CopyImageTo(/*rect_from=*/Rect(0, 0, r.xsize(), r.ysize()),
                /*from=*/gi.channel[gic].plane,
                /*rect_to=*/r, /*to=*/&fc.plane);
    gic++;
  }
  return true;
}

// Decodes the modular sub-stream with the quantized DC of one VarDCT DC
// group, then hands it to DequantDC, which writes into dec_state's DC
// storage.
//
// Bitstream order:
//  1. 2 bits of extra_precision; mul = 2^-extra_precision is forwarded to
//     DequantDC.
//  2. A 3-channel modular image.
//
// Components 0 and 1 are stored in swapped order: component c uses image
// channel c ^ 1 for c < 2. Each channel is shrunk by the chroma subsampling
// shifts of its component.
Status ModularFrameDecoder::DecodeVarDCTDC(const FrameHeader& frame_header,
                                           size_t group_id, BitReader* reader,
                                           PassesDecoderState* dec_state) {
  const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id);
  JXL_DEBUG_V(6, "Decoding VarDCT DC with rect %s", Description(r).c_str());
  // TODO(eustas): investigate if we could reduce the impact of
  //               EvalRationalPolynomial; generally speaking, the limit is
  //               2**(128/(3*magic)), where 128 comes from IEEE 754 exponent,
  //               3 comes from XybToRgb that cubes the values, and "magic" is
  //               the sum of all other contributions. 2**18 is known to lead
  //               to NaN on input found by fuzzing (see commit message).
  JXL_ASSIGN_OR_RETURN(
      Image image, Image::Create(r.xsize(), r.ysize(), full_image.bitdepth, 3));
  size_t stream_id = ModularStreamId::VarDCTDC(group_id).ID(frame_dim);
  reader->Refill();
  size_t extra_precision = reader->ReadFixedBits<2>();
  float mul = 1.0f / (1 << extra_precision);
  ModularOptions options;
  for (size_t c = 0; c < 3; c++) {
    Channel& ch = image.channel[c < 2 ? c ^ 1 : c];
    ch.w >>= frame_header.chroma_subsampling.HShift(c);
    ch.h >>= frame_header.chroma_subsampling.VShift(c);
    JXL_RETURN_IF_ERROR(ch.shrink());
  }
  if (!ModularGenericDecompress(
          reader, image, /*header=*/nullptr, stream_id, &options,
          /*undo_transforms=*/true, &tree, &code, &context_map)) {
    return JXL_FAILURE("Failed to decode VarDCT DC group (DC group id %d)",
                       static_cast<int>(group_id));
  }
  DequantDC(r, &dec_state->shared_storage.dc_storage,
            &dec_state->shared_storage.quant_dc, image,
            dec_state->shared->quantizer.MulDC(), mul,
            dec_state->shared->cmap.DCFactors(),
            frame_header.chroma_subsampling, dec_state->shared->block_ctx_map);
  return true;
}

// Decodes the AC metadata of one DC group into dec_state.
//
// Bitstream order:
//  1. count - 1, in CeilLog2Nonzero(r.xsize() * r.ysize()) bits.
//  2. A 4-channel modular image:
//     - channels 0/1: YToX / YToB maps over color tiles of
//       kColorTileDimInBlocks (= 8) blocks, clamped into cmap;
//     - channel 2: `count` columns, with AC strategy in row 0 and quant field
//       in row 1;
//     - channel 3: EPF sharpness for every block in r.
//
// Strategies are consumed in raster order at each position where
// ac_strategy.IsValid(x, y) is false. Presumably those are positions not
// already covered by an earlier multi-block strategy.
//
// Each strategy must be compatible with chroma subsampling. Its covered
// blocks may not cross an AC-group boundary (kGroupDimInBlocks) or the
// strategy image edge.
//
// Also ORs the strategies seen into dec_state->used_acs, and runs
// ComputeSigma when EPF is enabled.
Status ModularFrameDecoder::DecodeAcMetadata(const FrameHeader& frame_header,
                                             size_t group_id, BitReader* reader,
                                             PassesDecoderState* dec_state) {
  const Rect r = dec_state->shared->frame_dim.DCGroupRect(group_id);
  JXL_DEBUG_V(6, "Decoding AcMetadata with rect %s", Description(r).c_str());
  size_t upper_bound = r.xsize() * r.ysize();
  reader->Refill();
  size_t count = reader->ReadBits(CeilLog2Nonzero(upper_bound)) + 1;
  size_t stream_id = ModularStreamId::ACMetadata(group_id).ID(frame_dim);
  // YToX, YToB, ACS + QF, EPF
  JXL_ASSIGN_OR_RETURN(
      Image image, Image::Create(r.xsize(), r.ysize(), full_image.bitdepth, 4));
  static_assert(kColorTileDimInBlocks == 8, "Color tile size changed");
  Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3);
  JXL_ASSIGN_OR_RETURN(image.channel[0],
                       Channel::Create(cr.xsize(), cr.ysize(), 3, 3));
  JXL_ASSIGN_OR_RETURN(image.channel[1],
                       Channel::Create(cr.xsize(), cr.ysize(), 3, 3));
  JXL_ASSIGN_OR_RETURN(image.channel[2], Channel::Create(count, 2, 0, 0));
  ModularOptions options;
  if (!ModularGenericDecompress(
          reader, image, /*header=*/nullptr, stream_id, &options,
          /*undo_transforms=*/true, &tree, &code, &context_map)) {
    return JXL_FAILURE("Failed to decode AC metadata");
  }
  ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane, cr,
                       &dec_state->shared_storage.cmap.ytox_map);
  ConvertPlaneAndClamp(Rect(image.channel[1].plane), image.channel[1].plane, cr,
                       &dec_state->shared_storage.cmap.ytob_map);
  // num: index of the next unread (strategy, qf) pair in channel 2.
  size_t num = 0;
  bool is444 = frame_header.chroma_subsampling.Is444();
  auto& ac_strategy = dec_state->shared_storage.ac_strategy;
  size_t xlim = std::min(ac_strategy.xsize(), r.x0() + r.xsize());
  size_t ylim = std::min(ac_strategy.ysize(), r.y0() + r.ysize());
  uint32_t local_used_acs = 0;
  for (size_t iy = 0; iy < r.ysize(); iy++) {
    size_t y = r.y0() + iy;
    int32_t* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy);
    uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy);
    int32_t* row_in_1 = image.channel[2].plane.Row(0);
    int32_t* row_in_2 = image.channel[2].plane.Row(1);
    int32_t* row_in_3 = image.channel[3].plane.Row(iy);
    for (size_t ix = 0; ix < r.xsize(); ix++) {
      size_t x = r.x0() + ix;
      int sharpness = row_in_3[ix];
      if (sharpness < 0 || sharpness >= LoopFilter::kEpfSharpEntries) {
        return JXL_FAILURE("Corrupted sharpness field");
      }
      row_epf[ix] = sharpness;
      if (ac_strategy.IsValid(x, y)) {
        continue;
      }

      if (num >= count) return JXL_FAILURE("Corrupted stream");

      if (!AcStrategy::IsRawStrategyValid(row_in_1[num])) {
        return JXL_FAILURE("Invalid AC strategy");
      }
      local_used_acs |= 1u << row_in_1[num];
      AcStrategy acs = AcStrategy::FromRawStrategy(row_in_1[num]);
      if ((acs.covered_blocks_x() > 1 || acs.covered_blocks_y() > 1) &&
          !is444) {
        return JXL_FAILURE(
            "AC strategy not compatible with chroma subsampling");
      }
      // Ensure that blocks do not overflow *AC* groups.
      size_t next_x_ac_block = (x / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
      size_t next_y_ac_block = (y / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
      size_t next_x_dct_block = x + acs.covered_blocks_x();
      size_t next_y_dct_block = y + acs.covered_blocks_y();
      if (next_x_dct_block > next_x_ac_block || next_x_dct_block > xlim) {
        return JXL_FAILURE("Invalid AC strategy, x overflow");
      }
      if (next_y_dct_block > next_y_ac_block || next_y_dct_block > ylim) {
        return JXL_FAILURE("Invalid AC strategy, y overflow");
      }
      JXL_RETURN_IF_ERROR(
          ac_strategy.SetNoBoundsCheck(x, y, AcStrategy::Type(row_in_1[num])));
      // Stored quant field is 1 + clamp(value, 0, kQuantMax - 1).
      row_qf[ix] = 1 + std::max<int32_t>(0, std::min(Quantizer::kQuantMax - 1,
                                                     row_in_2[num]));
      num++;
    }
  }
  dec_state->used_acs |= local_used_acs;
  if (frame_header.loop_filter.epf_iters > 0) {
    ComputeSigma(frame_header.loop_filter, r, dec_state);
  }
  return true;
}

// Converts the integer samples of `gi` inside `modular_rect` into float rows
// of `render_pipeline_input`. `gi` must have no pending transforms.
//
// Color channels (only when do_color):
//  - XYB: each component is scaled by its DC quant factor. Component 2 is
//    stored relative to Y and is rebuilt with MultiplySum.
//  - Floating-point samples: reinterpreted via int_to_float.
//  - Integer samples: scaled by 1 / (2^bitdepth - 1), using the
//    double-precision path when bitdepth >= 23.
//  - Gray without a color transform: the single channel is written to all
//    three color buffers.
//
// Extra channel ec is written to buffer 3 + ec and is scaled by its own bit
// depth.
//
// Fails if a color channel is empty, or if a cropped channel does not match
// its destination buffer size.
Status ModularFrameDecoder::ModularImageToDecodedRect(
    const FrameHeader& frame_header, Image& gi, PassesDecoderState* dec_state,
    jxl::ThreadPool* pool, RenderPipelineInput& render_pipeline_input,
    Rect modular_rect) const {
  const auto* metadata = frame_header.nonserialized_metadata;
  JXL_CHECK(gi.transform.empty());

  // Returns row y of render-pipeline buffer c (within that buffer's rect).
  auto get_row = [&](size_t c, size_t y) {
    const auto& buffer = render_pipeline_input.GetBuffer(c);
    return buffer.second.Row(buffer.first, y);
  };

  // c: index of the next input channel of gi to convert.
  size_t c = 0;
  if (do_color) {
    const bool rgb_from_gray =
        metadata->m.color_encoding.IsGray() &&
        frame_header.color_transform == ColorTransform::kNone;
    const bool fp = metadata->m.bit_depth.floating_point_sample &&
                    frame_header.color_transform != ColorTransform::kXYB;
    for (; c < 3; c++) {
      double factor = full_image.bitdepth < 32
                          ? 1.0 / ((1u << full_image.bitdepth) - 1)
                          : 0;
      size_t c_in = c;
      if (frame_header.color_transform == ColorTransform::kXYB) {
        factor = dec_state->shared->matrices.DCQuants()[c];
        // XYB is encoded as YX(B-Y)
        if (c < 2) c_in = 1 - c;
      } else if (rgb_from_gray) {
        c_in = 0;
      }
      JXL_ASSERT(c_in < gi.channel.size());
      Channel& ch_in = gi.channel[c_in];
      // TODO(eustas): could we detect it on earlier stage?
      if (ch_in.w == 0 || ch_in.h == 0) {
        return JXL_FAILURE("Empty image");
      }
      JXL_CHECK(ch_in.hshift <= 3 && ch_in.vshift <= 3);
      // r: destination rect; mr: the matching, shift-scaled source rect.
      Rect r = render_pipeline_input.GetBuffer(c).second;
      Rect mr(modular_rect.x0() >> ch_in.hshift,
              modular_rect.y0() >> ch_in.vshift,
              DivCeil(modular_rect.xsize(), 1 << ch_in.hshift),
              DivCeil(modular_rect.ysize(), 1 << ch_in.vshift));
      mr = mr.Crop(ch_in.plane);
      size_t xsize_shifted = r.xsize();
      size_t ysize_shifted = r.ysize();
      if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
        return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
                           "x%" PRIuS
                           " modular channel into "
                           "a %" PRIuS "x%" PRIuS " rect",
                           mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
      }
      if (frame_header.color_transform == ColorTransform::kXYB && c == 2) {
        JXL_ASSERT(!fp);
        JXL_RETURN_IF_ERROR(RunOnPool(
            pool, 0, ysize_shifted, ThreadPool::NoInit,
            [&](const uint32_t task, size_t /* thread */) {
              const size_t y = task;
              const pixel_type* const JXL_RESTRICT row_in =
                  mr.Row(&ch_in.plane, y);
              const pixel_type* const JXL_RESTRICT row_in_Y =
                  mr.Row(&gi.channel[0].plane, y);
              float* const JXL_RESTRICT row_out = get_row(c, y);
              HWY_DYNAMIC_DISPATCH(MultiplySum)
              (xsize_shifted, row_in, row_in_Y, factor, row_out);
            },
            "ModularIntToFloat"));
      } else if (fp) {
        int bits = metadata->m.bit_depth.bits_per_sample;
        int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample;
        JXL_RETURN_IF_ERROR(RunOnPool(
            pool, 0, ysize_shifted, ThreadPool::NoInit,
            [&](const uint32_t task, size_t /* thread */) {
              const size_t y = task;
              const pixel_type* const JXL_RESTRICT row_in =
                  mr.Row(&ch_in.plane, y);
              if (rgb_from_gray) {
                for (size_t cc = 0; cc < 3; cc++) {
                  float* const JXL_RESTRICT row_out = get_row(cc, y);
                  int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits);
                }
              } else {
                float* const JXL_RESTRICT row_out = get_row(c, y);
                int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits);
              }
            },
            "ModularIntToFloat_losslessfloat"));
      } else {
        JXL_RETURN_IF_ERROR(RunOnPool(
            pool, 0, ysize_shifted, ThreadPool::NoInit,
            [&](const uint32_t task, size_t /* thread */) {
              const size_t y = task;
              const pixel_type* const JXL_RESTRICT row_in =
                  mr.Row(&ch_in.plane, y);
              if (rgb_from_gray) {
                if (full_image.bitdepth < 23) {
                  HWY_DYNAMIC_DISPATCH(RgbFromSingle)
                  (xsize_shifted, row_in, factor, get_row(0, y), get_row(1, y),
                   get_row(2, y));
                } else {
                  SingleFromSingleAccurate(xsize_shifted, row_in, factor,
                                           get_row(0, y));
                  SingleFromSingleAccurate(xsize_shifted, row_in, factor,
                                           get_row(1, y));
                  SingleFromSingleAccurate(xsize_shifted, row_in, factor,
                                           get_row(2, y));
                }
              } else {
                float* const JXL_RESTRICT row_out = get_row(c, y);
                if (full_image.bitdepth < 23) {
                  HWY_DYNAMIC_DISPATCH(SingleFromSingle)
                  (xsize_shifted, row_in, factor, row_out);
                } else {
                  SingleFromSingleAccurate(xsize_shifted, row_in, factor,
                                           row_out);
                }
              }
            },
            "ModularIntToFloat"));
      }
      // Gray input already filled all three color buffers.
      if (rgb_from_gray) {
        break;
      }
    }
    // Gray input has a single color channel, so extra channels start at 1.
    if (rgb_from_gray) {
      c = 1;
    }
  }
  size_t num_extra_channels = metadata->m.num_extra_channels;
  for (size_t ec = 0; ec < num_extra_channels; ec++, c++) {
    const ExtraChannelInfo& eci = metadata->m.extra_channel_info[ec];
    int bits = eci.bit_depth.bits_per_sample;
    int exp_bits = eci.bit_depth.exponent_bits_per_sample;
    bool fp = eci.bit_depth.floating_point_sample;
    JXL_ASSERT(fp || bits < 32);
    const double factor = fp ? 0 : (1.0 / ((1u << bits) - 1));
    JXL_ASSERT(c < gi.channel.size());
    Channel& ch_in = gi.channel[c];
    Rect r = render_pipeline_input.GetBuffer(3 + ec).second;
    Rect mr(modular_rect.x0() >> ch_in.hshift,
            modular_rect.y0() >> ch_in.vshift,
            DivCeil(modular_rect.xsize(), 1 << ch_in.hshift),
            DivCeil(modular_rect.ysize(), 1 << ch_in.vshift));
    mr = mr.Crop(ch_in.plane);
    if (r.ysize() != mr.ysize() || r.xsize() != mr.xsize()) {
      return JXL_FAILURE("Dimension mismatch: trying to fit a %" PRIuS
                         "x%" PRIuS
                         " modular channel into "
                         "a %" PRIuS "x%" PRIuS " rect",
                         mr.xsize(), mr.ysize(), r.xsize(), r.ysize());
    }
    for (size_t y = 0; y < r.ysize(); ++y) {
      float* const JXL_RESTRICT row_out =
          r.Row(render_pipeline_input.GetBuffer(3 + ec).first, y);
      const pixel_type* const JXL_RESTRICT row_in = mr.Row(&ch_in.plane, y);
      if (fp) {
        int_to_float(row_in, row_out, r.xsize(), bits, exp_bits);
      } else {
        // NOTE(review): the single- vs double-precision choice uses the main
        // image's bitdepth, not this extra channel's `bits`. A deep extra
        // channel (>= 23 bits) on a shallow image thus takes the float path.
        // Confirm whether this is intended.
        if (full_image.bitdepth < 23) {
          HWY_DYNAMIC_DISPATCH(SingleFromSingle)
          (r.xsize(), row_in, factor, row_out);
        } else {
          SingleFromSingleAccurate(r.xsize(), row_in, factor, row_out);
        }
      }
    }
  }
  return true;
}

// Finishes a frame decoded into `full_image`. It is a no-op unless
// use_full_image is set.
//  1. Undoes the remaining global transforms. With `inplace`, full_image is
//     moved-from; otherwise a clone is used.
//  2. Feeds every group rect of the result through ModularImageToDecodedRect
//     into the render pipeline via RunOnPool. The pool is dropped (nullptr)
//     for images smaller than one group.
//
// Errors in per-group tasks are collected in an atomic flag and reported once
// all tasks have run.
Status ModularFrameDecoder::FinalizeDecoding(const FrameHeader& frame_header,
                                             PassesDecoderState* dec_state,
                                             jxl::ThreadPool* pool,
                                             bool inplace) {
  if (!use_full_image) return true;
  Image gi;
  if (inplace) {
    gi = std::move(full_image);
  } else {
    JXL_ASSIGN_OR_RETURN(gi, Image::Clone(full_image));
  }
  size_t xsize = gi.w;
  size_t ysize = gi.h;

  JXL_DEBUG_V(3, "Finalizing decoding for modular image: %s",
              gi.DebugString().c_str());

  // Don't use threads if total image size is smaller than a group
  if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr;

  // Undo the global transforms
  gi.undo_transforms(global_header.wp_header, pool);
  JXL_DASSERT(global_transform.empty());
  if (gi.error) return JXL_FAILURE("Undoing transforms failed");

  for (size_t i = 0; i < dec_state->shared->frame_dim.num_groups; i++) {
    dec_state->render_pipeline->ClearDone(i);
  }
  std::atomic<bool> has_error{false};
  JXL_RETURN_IF_ERROR(RunOnPool(
      pool, 0, dec_state->shared->frame_dim.num_groups,
      [&](size_t num_threads) {
        bool use_group_ids = (frame_header.encoding == FrameEncoding::kVarDCT ||
                              (frame_header.flags & FrameHeader::kNoise));
        return dec_state->render_pipeline->PrepareForThreads(num_threads,
                                                             use_group_ids);
      },
      [&](const uint32_t group, size_t thread_id) {
        // Skip remaining work once any task has failed.
        if (has_error) return;
        RenderPipelineInput input =
            dec_state->render_pipeline->GetInputBuffers(group, thread_id);
        if (!ModularImageToDecodedRect(
                frame_header, gi, dec_state, nullptr, input,
                dec_state->shared->frame_dim.GroupRect(group))) {
          has_error = true;
          return;
        }
        if (!input.Done()) {
          has_error = true;
          return;
        }
      },
      "ModularToRect"));
  if (has_error) return JXL_FAILURE("Error producing input to render pipeline");
  return true;
}

// Smallest accepted raw quant table denominator (see DecodeQuantTable).
static constexpr const float kAlmostZero = 1e-8f;

// Reads a raw quantization table.
//
// Bitstream order:
//  1. An F16 denominator, which must be >= kAlmostZero.
//  2. A 3-channel modular image of required_size_x x required_size_y
//     (bitdepth hint 8).
//
// Samples are stored into encoding->qraw.qtable, indexed as
// c * xsize * ysize + y * xsize + x. The vector is allocated if still null.
// Every entry must be positive.
//
// If modular_frame_decoder is non-null, its tree, histograms and frame_dim
// stream id are used. Otherwise the image is decoded as stream 0 without a
// shared tree.
Status ModularFrameDecoder::DecodeQuantTable(
    size_t required_size_x, size_t required_size_y, BitReader* br,
    QuantEncoding* encoding, size_t idx,
    ModularFrameDecoder* modular_frame_decoder) {
  JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->qraw.qtable_den));
  if (encoding->qraw.qtable_den < kAlmostZero) {
    // qtable[] values are already checked for <= 0 so the denominator may not
    // be negative.
    return JXL_FAILURE("Invalid qtable_den: value too small");
  }
  JXL_ASSIGN_OR_RETURN(Image image,
                       Image::Create(required_size_x, required_size_y, 8, 3));
  ModularOptions options;
  if (modular_frame_decoder) {
    JXL_RETURN_IF_ERROR(ModularGenericDecompress(
        br, image, /*header=*/nullptr,
        ModularStreamId::QuantTable(idx).ID(modular_frame_decoder->frame_dim),
        &options, /*undo_transforms=*/true, &modular_frame_decoder->tree,
        &modular_frame_decoder->code, &modular_frame_decoder->context_map));
  } else {
    JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr,
                                                 0, &options,
                                                 /*undo_transforms=*/true));
  }
  if (!encoding->qraw.qtable) {
    encoding->qraw.qtable = new std::vector<int>();
  }
  encoding->qraw.qtable->resize(required_size_x * required_size_y * 3);
  for (size_t c = 0; c < 3; c++) {
    for (size_t y = 0; y < required_size_y; y++) {
      int32_t* JXL_RESTRICT row = image.channel[c].Row(y);
      for (size_t x = 0; x < required_size_x; x++) {
        (*encoding->qraw.qtable)[c * required_size_x * required_size_y +
                                 y * required_size_x + x] = row[x];
        if (row[x] <= 0) {
          return JXL_FAILURE("Invalid raw quantization table");
        }
      }
    }
  }
  return true;
}

}  // namespace jxl
#endif  // HWY_ONCE