dec_group.cc (32464B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/dec_group.h" 7 8 #include <stdint.h> 9 #include <string.h> 10 11 #include <algorithm> 12 #include <memory> 13 #include <utility> 14 15 #include "lib/jxl/frame_header.h" 16 17 #undef HWY_TARGET_INCLUDE 18 #define HWY_TARGET_INCLUDE "lib/jxl/dec_group.cc" 19 #include <hwy/foreach_target.h> 20 #include <hwy/highway.h> 21 22 #include "lib/jxl/ac_context.h" 23 #include "lib/jxl/ac_strategy.h" 24 #include "lib/jxl/base/bits.h" 25 #include "lib/jxl/base/common.h" 26 #include "lib/jxl/base/printf_macros.h" 27 #include "lib/jxl/base/status.h" 28 #include "lib/jxl/coeff_order.h" 29 #include "lib/jxl/common.h" // kMaxNumPasses 30 #include "lib/jxl/dec_cache.h" 31 #include "lib/jxl/dec_transforms-inl.h" 32 #include "lib/jxl/dec_xyb.h" 33 #include "lib/jxl/entropy_coder.h" 34 #include "lib/jxl/quant_weights.h" 35 #include "lib/jxl/quantizer-inl.h" 36 #include "lib/jxl/quantizer.h" 37 38 #ifndef LIB_JXL_DEC_GROUP_CC 39 #define LIB_JXL_DEC_GROUP_CC 40 namespace jxl { 41 42 struct AuxOut; 43 44 // Interface for reading groups for DecodeGroupImpl. 45 class GetBlock { 46 public: 47 virtual void StartRow(size_t by) = 0; 48 virtual Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, 49 size_t size, size_t log2_covered_blocks, 50 ACPtr block[3], ACType ac_type) = 0; 51 virtual ~GetBlock() {} 52 }; 53 54 // Controls whether DecodeGroupImpl renders to pixels or not. 55 enum DrawMode { 56 // Render to pixels. 57 kDraw = 0, 58 // Don't render to pixels. 59 kDontDraw = 1, 60 }; 61 62 } // namespace jxl 63 #endif // LIB_JXL_DEC_GROUP_CC 64 65 HWY_BEFORE_NAMESPACE(); 66 namespace jxl { 67 namespace HWY_NAMESPACE { 68 69 // These templates are not found via ADL. 70 using hwy::HWY_NAMESPACE::AllFalse; 71 using hwy::HWY_NAMESPACE::Gt; 72 using hwy::HWY_NAMESPACE::Le; 73 using hwy::HWY_NAMESPACE::MaskFromVec; 74 using hwy::HWY_NAMESPACE::Or; 75 using hwy::HWY_NAMESPACE::Rebind; 76 using hwy::HWY_NAMESPACE::ShiftRight; 77 78 using D = HWY_FULL(float); 79 using DU = HWY_FULL(uint32_t); 80 using DI = HWY_FULL(int32_t); 81 using DI16 = Rebind<int16_t, DI>; 82 using DI16_FULL = HWY_CAPPED(int16_t, kDCTBlockSize); 83 constexpr D d; 84 constexpr DI di; 85 constexpr DI16 di16; 86 constexpr DI16_FULL di16_full; 87 88 // TODO(veluca): consider SIMDfying. 89 void Transpose8x8InPlace(int32_t* JXL_RESTRICT block) { 90 for (size_t x = 0; x < 8; x++) { 91 for (size_t y = x + 1; y < 8; y++) { 92 std::swap(block[y * 8 + x], block[x * 8 + y]); 93 } 94 } 95 } 96 97 template <ACType ac_type> 98 void DequantLane(Vec<D> scaled_dequant_x, Vec<D> scaled_dequant_y, 99 Vec<D> scaled_dequant_b, 100 const float* JXL_RESTRICT dequant_matrices, size_t size, 101 size_t k, Vec<D> x_cc_mul, Vec<D> b_cc_mul, 102 const float* JXL_RESTRICT biases, ACPtr qblock[3], 103 float* JXL_RESTRICT block) { 104 const auto x_mul = Mul(Load(d, dequant_matrices + k), scaled_dequant_x); 105 const auto y_mul = 106 Mul(Load(d, dequant_matrices + size + k), scaled_dequant_y); 107 const auto b_mul = 108 Mul(Load(d, dequant_matrices + 2 * size + k), scaled_dequant_b); 109 110 Vec<DI> quantized_x_int; 111 Vec<DI> quantized_y_int; 112 Vec<DI> quantized_b_int; 113 if (ac_type == ACType::k16) { 114 Rebind<int16_t, DI> di16; 115 quantized_x_int = PromoteTo(di, Load(di16, qblock[0].ptr16 + k)); 116 quantized_y_int = PromoteTo(di, Load(di16, qblock[1].ptr16 + k)); 117 quantized_b_int = PromoteTo(di, Load(di16, qblock[2].ptr16 + k)); 118 } else { 119 quantized_x_int = Load(di, qblock[0].ptr32 + k); 120 quantized_y_int = Load(di, qblock[1].ptr32 + k); 121 quantized_b_int = Load(di, qblock[2].ptr32 + k); 122 } 123 124 const auto dequant_x_cc = 125 Mul(AdjustQuantBias(di, 0, quantized_x_int, biases), x_mul); 126 const auto dequant_y = 127 Mul(AdjustQuantBias(di, 1, quantized_y_int, biases), y_mul); 128 const auto dequant_b_cc = 129 Mul(AdjustQuantBias(di, 2, quantized_b_int, biases), b_mul); 130 131 const auto dequant_x = MulAdd(x_cc_mul, dequant_y, dequant_x_cc); 132 const auto dequant_b = MulAdd(b_cc_mul, dequant_y, dequant_b_cc); 133 Store(dequant_x, d, block + k); 134 Store(dequant_y, d, block + size + k); 135 Store(dequant_b, d, block + 2 * size + k); 136 } 137 138 template <ACType ac_type> 139 void DequantBlock(const AcStrategy& acs, float inv_global_scale, int quant, 140 float x_dm_multiplier, float b_dm_multiplier, Vec<D> x_cc_mul, 141 Vec<D> b_cc_mul, size_t kind, size_t size, 142 const Quantizer& quantizer, size_t covered_blocks, 143 const size_t* sbx, 144 const float* JXL_RESTRICT* JXL_RESTRICT dc_row, 145 size_t dc_stride, const float* JXL_RESTRICT biases, 146 ACPtr qblock[3], float* JXL_RESTRICT block, 147 float* JXL_RESTRICT scratch) { 148 const auto scaled_dequant_s = inv_global_scale / quant; 149 150 const auto scaled_dequant_x = Set(d, scaled_dequant_s * x_dm_multiplier); 151 const auto scaled_dequant_y = Set(d, scaled_dequant_s); 152 const auto scaled_dequant_b = Set(d, scaled_dequant_s * b_dm_multiplier); 153 154 const float* dequant_matrices = quantizer.DequantMatrix(kind, 0); 155 156 for (size_t k = 0; k < covered_blocks * kDCTBlockSize; k += Lanes(d)) { 157 DequantLane<ac_type>(scaled_dequant_x, scaled_dequant_y, scaled_dequant_b, 158 dequant_matrices, size, k, x_cc_mul, b_cc_mul, biases, 159 qblock, block); 160 } 161 for (size_t c = 0; c < 3; c++) { 162 LowestFrequenciesFromDC(acs.Strategy(), dc_row[c] + sbx[c], dc_stride, 163 block + c * size, scratch); 164 } 165 } 166 167 Status DecodeGroupImpl(const FrameHeader& frame_header, 168 GetBlock* JXL_RESTRICT get_block, 169 GroupDecCache* JXL_RESTRICT group_dec_cache, 170 PassesDecoderState* JXL_RESTRICT dec_state, 171 size_t thread, size_t group_idx, 172 RenderPipelineInput& render_pipeline_input, 173 ImageBundle* decoded, DrawMode draw) { 174 // TODO(veluca): investigate cache usage in this function. 175 const Rect block_rect = 176 dec_state->shared->frame_dim.BlockGroupRect(group_idx); 177 const AcStrategyImage& ac_strategy = dec_state->shared->ac_strategy; 178 179 const size_t xsize_blocks = block_rect.xsize(); 180 const size_t ysize_blocks = block_rect.ysize(); 181 182 const size_t dc_stride = dec_state->shared->dc->PixelsPerRow(); 183 184 const float inv_global_scale = dec_state->shared->quantizer.InvGlobalScale(); 185 186 const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling; 187 188 const auto kJpegDctMin = Set(di16_full, -4095); 189 const auto kJpegDctMax = Set(di16_full, 4095); 190 191 size_t idct_stride[3]; 192 for (size_t c = 0; c < 3; c++) { 193 idct_stride[c] = render_pipeline_input.GetBuffer(c).first->PixelsPerRow(); 194 } 195 196 HWY_ALIGN int32_t scaled_qtable[64 * 3]; 197 198 ACType ac_type = dec_state->coefficients->Type(); 199 auto dequant_block = ac_type == ACType::k16 ? DequantBlock<ACType::k16> 200 : DequantBlock<ACType::k32>; 201 // Whether or not coefficients should be stored for future usage, and/or read 202 // from past usage. 203 bool accumulate = !dec_state->coefficients->IsEmpty(); 204 // Offset of the current block in the group. 205 size_t offset = 0; 206 207 std::array<int, 3> jpeg_c_map; 208 bool jpeg_is_gray = false; 209 std::array<int, 3> dcoff = {}; 210 211 // TODO(veluca): all of this should be done only once per image. 212 if (decoded->IsJPEG()) { 213 if (!dec_state->shared->cmap.IsJPEGCompatible()) { 214 return JXL_FAILURE("The CfL map is not JPEG-compatible"); 215 } 216 jpeg_is_gray = (decoded->jpeg_data->components.size() == 1); 217 jpeg_c_map = JpegOrder(frame_header.color_transform, jpeg_is_gray); 218 const std::vector<QuantEncoding>& qe = 219 dec_state->shared->matrices.encodings(); 220 if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW || 221 std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) { 222 return JXL_FAILURE( 223 "Quantization table is not a JPEG quantization table."); 224 } 225 for (size_t c = 0; c < 3; c++) { 226 if (frame_header.color_transform == ColorTransform::kNone) { 227 dcoff[c] = 1024 / (*qe[0].qraw.qtable)[64 * c]; 228 } 229 for (size_t i = 0; i < 64; i++) { 230 // Transpose the matrix, as it will be used on the transposed block. 231 int n = qe[0].qraw.qtable->at(64 + i); 232 int d = qe[0].qraw.qtable->at(64 * c + i); 233 if (n <= 0 || d <= 0 || n >= 65536 || d >= 65536) { 234 return JXL_FAILURE("Invalid JPEG quantization table"); 235 } 236 scaled_qtable[64 * c + (i % 8) * 8 + (i / 8)] = 237 (1 << kCFLFixedPointPrecision) * n / d; 238 } 239 } 240 } 241 242 size_t hshift[3] = {cs.HShift(0), cs.HShift(1), cs.HShift(2)}; 243 size_t vshift[3] = {cs.VShift(0), cs.VShift(1), cs.VShift(2)}; 244 Rect r[3]; 245 for (size_t i = 0; i < 3; i++) { 246 r[i] = 247 Rect(block_rect.x0() >> hshift[i], block_rect.y0() >> vshift[i], 248 block_rect.xsize() >> hshift[i], block_rect.ysize() >> vshift[i]); 249 if (!r[i].IsInside({0, 0, dec_state->shared->dc->Plane(i).xsize(), 250 dec_state->shared->dc->Plane(i).ysize()})) { 251 return JXL_FAILURE("Frame dimensions are too big for the image."); 252 } 253 } 254 255 for (size_t by = 0; by < ysize_blocks; ++by) { 256 get_block->StartRow(by); 257 size_t sby[3] = {by >> vshift[0], by >> vshift[1], by >> vshift[2]}; 258 259 const int32_t* JXL_RESTRICT row_quant = 260 block_rect.ConstRow(dec_state->shared->raw_quant_field, by); 261 262 const float* JXL_RESTRICT dc_rows[3] = { 263 r[0].ConstPlaneRow(*dec_state->shared->dc, 0, sby[0]), 264 r[1].ConstPlaneRow(*dec_state->shared->dc, 1, sby[1]), 265 r[2].ConstPlaneRow(*dec_state->shared->dc, 2, sby[2]), 266 }; 267 268 const size_t ty = (block_rect.y0() + by) / kColorTileDimInBlocks; 269 AcStrategyRow acs_row = ac_strategy.ConstRow(block_rect, by); 270 271 const int8_t* JXL_RESTRICT row_cmap[3] = { 272 dec_state->shared->cmap.ytox_map.ConstRow(ty), 273 nullptr, 274 dec_state->shared->cmap.ytob_map.ConstRow(ty), 275 }; 276 277 float* JXL_RESTRICT idct_row[3]; 278 int16_t* JXL_RESTRICT jpeg_row[3]; 279 for (size_t c = 0; c < 3; c++) { 280 idct_row[c] = render_pipeline_input.GetBuffer(c).second.Row( 281 render_pipeline_input.GetBuffer(c).first, sby[c] * kBlockDim); 282 if (decoded->IsJPEG()) { 283 auto& component = decoded->jpeg_data->components[jpeg_c_map[c]]; 284 jpeg_row[c] = 285 component.coeffs.data() + 286 (component.width_in_blocks * (r[c].y0() + sby[c]) + r[c].x0()) * 287 kDCTBlockSize; 288 } 289 } 290 291 size_t bx = 0; 292 for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks); 293 tx++) { 294 size_t abs_tx = tx + block_rect.x0() / kColorTileDimInBlocks; 295 auto x_cc_mul = 296 Set(d, dec_state->shared->cmap.YtoXRatio(row_cmap[0][abs_tx])); 297 auto b_cc_mul = 298 Set(d, dec_state->shared->cmap.YtoBRatio(row_cmap[2][abs_tx])); 299 // Increment bx by llf_x because those iterations would otherwise 300 // immediately continue (!IsFirstBlock). Reduces mispredictions. 301 for (; bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks;) { 302 size_t sbx[3] = {bx >> hshift[0], bx >> hshift[1], bx >> hshift[2]}; 303 AcStrategy acs = acs_row[bx]; 304 const size_t llf_x = acs.covered_blocks_x(); 305 306 // Can only happen in the second or lower rows of a varblock. 307 if (JXL_UNLIKELY(!acs.IsFirstBlock())) { 308 bx += llf_x; 309 continue; 310 } 311 const size_t log2_covered_blocks = acs.log2_covered_blocks(); 312 313 const size_t covered_blocks = 1 << log2_covered_blocks; 314 const size_t size = covered_blocks * kDCTBlockSize; 315 316 ACPtr qblock[3]; 317 if (accumulate) { 318 for (size_t c = 0; c < 3; c++) { 319 qblock[c] = dec_state->coefficients->PlaneRow(c, group_idx, offset); 320 } 321 } else { 322 // No point in reading from bitstream without accumulating and not 323 // drawing. 324 JXL_ASSERT(draw == kDraw); 325 if (ac_type == ACType::k16) { 326 memset(group_dec_cache->dec_group_qblock16, 0, 327 size * 3 * sizeof(int16_t)); 328 for (size_t c = 0; c < 3; c++) { 329 qblock[c].ptr16 = group_dec_cache->dec_group_qblock16 + c * size; 330 } 331 } else { 332 memset(group_dec_cache->dec_group_qblock, 0, 333 size * 3 * sizeof(int32_t)); 334 for (size_t c = 0; c < 3; c++) { 335 qblock[c].ptr32 = group_dec_cache->dec_group_qblock + c * size; 336 } 337 } 338 } 339 JXL_RETURN_IF_ERROR(get_block->LoadBlock( 340 bx, by, acs, size, log2_covered_blocks, qblock, ac_type)); 341 offset += size; 342 if (draw == kDontDraw) { 343 bx += llf_x; 344 continue; 345 } 346 347 if (JXL_UNLIKELY(decoded->IsJPEG())) { 348 if (acs.Strategy() != AcStrategy::Type::DCT) { 349 return JXL_FAILURE( 350 "Can only decode to JPEG if only DCT-8 is used."); 351 } 352 353 HWY_ALIGN int32_t transposed_dct_y[64]; 354 for (size_t c : {1, 0, 2}) { 355 // Propagate only Y for grayscale. 356 if (jpeg_is_gray && c != 1) { 357 continue; 358 } 359 if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) { 360 continue; 361 } 362 int16_t* JXL_RESTRICT jpeg_pos = 363 jpeg_row[c] + sbx[c] * kDCTBlockSize; 364 // JPEG XL is transposed, JPEG is not. 365 auto* transposed_dct = qblock[c].ptr32; 366 Transpose8x8InPlace(transposed_dct); 367 // No CfL - no need to store the y block converted to integers. 368 if (!cs.Is444() || 369 (row_cmap[0][abs_tx] == 0 && row_cmap[2][abs_tx] == 0)) { 370 for (size_t i = 0; i < 64; i += Lanes(d)) { 371 const auto ini = Load(di, transposed_dct + i); 372 const auto ini16 = DemoteTo(di16, ini); 373 StoreU(ini16, di16, jpeg_pos + i); 374 } 375 } else if (c == 1) { 376 // Y channel: save for restoring X/B, but nothing else to do. 377 for (size_t i = 0; i < 64; i += Lanes(d)) { 378 const auto ini = Load(di, transposed_dct + i); 379 Store(ini, di, transposed_dct_y + i); 380 const auto ini16 = DemoteTo(di16, ini); 381 StoreU(ini16, di16, jpeg_pos + i); 382 } 383 } else { 384 // transposed_dct_y contains the y channel block, transposed. 385 const auto scale = Set( 386 di, dec_state->shared->cmap.RatioJPEG(row_cmap[c][abs_tx])); 387 const auto round = Set(di, 1 << (kCFLFixedPointPrecision - 1)); 388 for (int i = 0; i < 64; i += Lanes(d)) { 389 auto in = Load(di, transposed_dct + i); 390 auto in_y = Load(di, transposed_dct_y + i); 391 auto qt = Load(di, scaled_qtable + c * size + i); 392 auto coeff_scale = ShiftRight<kCFLFixedPointPrecision>( 393 Add(Mul(qt, scale), round)); 394 auto cfl_factor = ShiftRight<kCFLFixedPointPrecision>( 395 Add(Mul(in_y, coeff_scale), round)); 396 StoreU(DemoteTo(di16, Add(in, cfl_factor)), di16, jpeg_pos + i); 397 } 398 } 399 jpeg_pos[0] = 400 Clamp1<float>(dc_rows[c][sbx[c]] - dcoff[c], -2047, 2047); 401 auto overflow = MaskFromVec(Set(di16_full, 0)); 402 auto underflow = MaskFromVec(Set(di16_full, 0)); 403 for (int i = 0; i < 64; i += Lanes(di16_full)) { 404 auto in = LoadU(di16_full, jpeg_pos + i); 405 overflow = Or(overflow, Gt(in, kJpegDctMax)); 406 underflow = Or(underflow, Lt(in, kJpegDctMin)); 407 } 408 if (!AllFalse(di16_full, Or(overflow, underflow))) { 409 return JXL_FAILURE("JPEG DCT coefficients out of range"); 410 } 411 } 412 } else { 413 HWY_ALIGN float* const block = group_dec_cache->dec_group_block; 414 // Dequantize and add predictions. 415 dequant_block( 416 acs, inv_global_scale, row_quant[bx], dec_state->x_dm_multiplier, 417 dec_state->b_dm_multiplier, x_cc_mul, b_cc_mul, acs.RawStrategy(), 418 size, dec_state->shared->quantizer, 419 acs.covered_blocks_y() * acs.covered_blocks_x(), sbx, dc_rows, 420 dc_stride, 421 dec_state->output_encoding_info.opsin_params.quant_biases, qblock, 422 block, group_dec_cache->scratch_space); 423 424 for (size_t c : {1, 0, 2}) { 425 if ((sbx[c] << hshift[c] != bx) || (sby[c] << vshift[c] != by)) { 426 continue; 427 } 428 // IDCT 429 float* JXL_RESTRICT idct_pos = idct_row[c] + sbx[c] * kBlockDim; 430 TransformToPixels(acs.Strategy(), block + c * size, idct_pos, 431 idct_stride[c], group_dec_cache->scratch_space); 432 } 433 } 434 bx += llf_x; 435 } 436 } 437 } 438 return true; 439 } 440 441 // NOLINTNEXTLINE(google-readability-namespace-comments) 442 } // namespace HWY_NAMESPACE 443 } // namespace jxl 444 HWY_AFTER_NAMESPACE(); 445 446 #if HWY_ONCE 447 namespace jxl { 448 namespace { 449 // Decode quantized AC coefficients of DCT blocks. 450 // LLF components in the output block will not be modified. 451 template <ACType ac_type, bool uses_lz77> 452 Status DecodeACVarBlock(size_t ctx_offset, size_t log2_covered_blocks, 453 int32_t* JXL_RESTRICT row_nzeros, 454 const int32_t* JXL_RESTRICT row_nzeros_top, 455 size_t nzeros_stride, size_t c, size_t bx, size_t by, 456 size_t lbx, AcStrategy acs, 457 const coeff_order_t* JXL_RESTRICT coeff_order, 458 BitReader* JXL_RESTRICT br, 459 ANSSymbolReader* JXL_RESTRICT decoder, 460 const std::vector<uint8_t>& context_map, 461 const uint8_t* qdc_row, const int32_t* qf_row, 462 const BlockCtxMap& block_ctx_map, ACPtr block, 463 size_t shift = 0) { 464 // Equal to number of LLF coefficients. 465 const size_t covered_blocks = 1 << log2_covered_blocks; 466 const size_t size = covered_blocks * kDCTBlockSize; 467 int32_t predicted_nzeros = 468 PredictFromTopAndLeft(row_nzeros_top, row_nzeros, bx, 32); 469 470 size_t ord = kStrategyOrder[acs.RawStrategy()]; 471 const coeff_order_t* JXL_RESTRICT order = 472 &coeff_order[CoeffOrderOffset(ord, c)]; 473 474 size_t block_ctx = block_ctx_map.Context(qdc_row[lbx], qf_row[bx], ord, c); 475 const int32_t nzero_ctx = 476 block_ctx_map.NonZeroContext(predicted_nzeros, block_ctx) + ctx_offset; 477 478 size_t nzeros = 479 decoder->ReadHybridUintInlined<uses_lz77>(nzero_ctx, br, context_map); 480 if (nzeros > size - covered_blocks) { 481 return JXL_FAILURE("Invalid AC: nzeros %" PRIuS " too large for %" PRIuS 482 " 8x8 blocks", 483 nzeros, covered_blocks); 484 } 485 for (size_t y = 0; y < acs.covered_blocks_y(); y++) { 486 for (size_t x = 0; x < acs.covered_blocks_x(); x++) { 487 row_nzeros[bx + x + y * nzeros_stride] = 488 (nzeros + covered_blocks - 1) >> log2_covered_blocks; 489 } 490 } 491 492 const size_t histo_offset = 493 ctx_offset + block_ctx_map.ZeroDensityContextsOffset(block_ctx); 494 495 size_t prev = (nzeros > size / 16 ? 0 : 1); 496 for (size_t k = covered_blocks; k < size && nzeros != 0; ++k) { 497 const size_t ctx = 498 histo_offset + ZeroDensityContext(nzeros, k, covered_blocks, 499 log2_covered_blocks, prev); 500 const size_t u_coeff = 501 decoder->ReadHybridUintInlined<uses_lz77>(ctx, br, context_map); 502 // Hand-rolled version of UnpackSigned, shifting before the conversion to 503 // signed integer to avoid undefined behavior of shifting negative numbers. 504 const size_t magnitude = u_coeff >> 1; 505 const size_t neg_sign = (~u_coeff) & 1; 506 const intptr_t coeff = 507 static_cast<intptr_t>((magnitude ^ (neg_sign - 1)) << shift); 508 if (ac_type == ACType::k16) { 509 block.ptr16[order[k]] += coeff; 510 } else { 511 block.ptr32[order[k]] += coeff; 512 } 513 prev = static_cast<size_t>(u_coeff != 0); 514 nzeros -= prev; 515 } 516 if (JXL_UNLIKELY(nzeros != 0)) { 517 return JXL_FAILURE("Invalid AC: nzeros at end of block is %" PRIuS 518 ", should be 0. Block (%" PRIuS ", %" PRIuS 519 "), channel %" PRIuS, 520 nzeros, bx, by, c); 521 } 522 523 return true; 524 } 525 526 // Structs used by DecodeGroupImpl to get a quantized block. 527 // GetBlockFromBitstream uses ANS decoding (and thus keeps track of row 528 // pointers in row_nzeros), GetBlockFromEncoder simply reads the coefficient 529 // image provided by the encoder. 530 531 struct GetBlockFromBitstream : public GetBlock { 532 void StartRow(size_t by) override { 533 qf_row = rect.ConstRow(*qf, by); 534 for (size_t c = 0; c < 3; c++) { 535 size_t sby = by >> vshift[c]; 536 quant_dc_row = quant_dc->ConstRow(rect.y0() + by) + rect.x0(); 537 for (size_t i = 0; i < num_passes; i++) { 538 row_nzeros[i][c] = group_dec_cache->num_nzeroes[i].PlaneRow(c, sby); 539 row_nzeros_top[i][c] = 540 sby == 0 541 ? nullptr 542 : group_dec_cache->num_nzeroes[i].ConstPlaneRow(c, sby - 1); 543 } 544 } 545 } 546 547 Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size, 548 size_t log2_covered_blocks, ACPtr block[3], 549 ACType ac_type) override { 550 ; 551 for (size_t c : {1, 0, 2}) { 552 size_t sbx = bx >> hshift[c]; 553 size_t sby = by >> vshift[c]; 554 if (JXL_UNLIKELY((sbx << hshift[c] != bx) || (sby << vshift[c] != by))) { 555 continue; 556 } 557 558 for (size_t pass = 0; JXL_UNLIKELY(pass < num_passes); pass++) { 559 auto decode_ac_varblock = 560 decoders[pass].UsesLZ77() 561 ? (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 1> 562 : DecodeACVarBlock<ACType::k32, 1>) 563 : (ac_type == ACType::k16 ? DecodeACVarBlock<ACType::k16, 0> 564 : DecodeACVarBlock<ACType::k32, 0>); 565 JXL_RETURN_IF_ERROR(decode_ac_varblock( 566 ctx_offset[pass], log2_covered_blocks, row_nzeros[pass][c], 567 row_nzeros_top[pass][c], nzeros_stride, c, sbx, sby, bx, acs, 568 &coeff_orders[pass * coeff_order_size], readers[pass], 569 &decoders[pass], context_map[pass], quant_dc_row, qf_row, 570 *block_ctx_map, block[c], shift_for_pass[pass])); 571 } 572 } 573 return true; 574 } 575 576 Status Init(const FrameHeader& frame_header, 577 BitReader* JXL_RESTRICT* JXL_RESTRICT readers, size_t num_passes, 578 size_t group_idx, size_t histo_selector_bits, const Rect& rect, 579 GroupDecCache* JXL_RESTRICT group_dec_cache, 580 PassesDecoderState* dec_state, size_t first_pass) { 581 for (size_t i = 0; i < 3; i++) { 582 hshift[i] = frame_header.chroma_subsampling.HShift(i); 583 vshift[i] = frame_header.chroma_subsampling.VShift(i); 584 } 585 this->coeff_order_size = dec_state->shared->coeff_order_size; 586 this->coeff_orders = 587 dec_state->shared->coeff_orders.data() + first_pass * coeff_order_size; 588 this->context_map = dec_state->context_map.data() + first_pass; 589 this->readers = readers; 590 this->num_passes = num_passes; 591 this->shift_for_pass = frame_header.passes.shift + first_pass; 592 this->group_dec_cache = group_dec_cache; 593 this->rect = rect; 594 block_ctx_map = &dec_state->shared->block_ctx_map; 595 qf = &dec_state->shared->raw_quant_field; 596 quant_dc = &dec_state->shared->quant_dc; 597 598 for (size_t pass = 0; pass < num_passes; pass++) { 599 // Select which histogram set to use among those of the current pass. 600 size_t cur_histogram = 0; 601 if (histo_selector_bits != 0) { 602 cur_histogram = readers[pass]->ReadBits(histo_selector_bits); 603 } 604 if (cur_histogram >= dec_state->shared->num_histograms) { 605 return JXL_FAILURE("Invalid histogram selector"); 606 } 607 ctx_offset[pass] = cur_histogram * block_ctx_map->NumACContexts(); 608 609 decoders[pass] = 610 ANSSymbolReader(&dec_state->code[pass + first_pass], readers[pass]); 611 } 612 nzeros_stride = group_dec_cache->num_nzeroes[0].PixelsPerRow(); 613 for (size_t i = 0; i < num_passes; i++) { 614 JXL_ASSERT( 615 nzeros_stride == 616 static_cast<size_t>(group_dec_cache->num_nzeroes[i].PixelsPerRow())); 617 } 618 return true; 619 } 620 621 const uint32_t* shift_for_pass = nullptr; // not owned 622 const coeff_order_t* JXL_RESTRICT coeff_orders; 623 size_t coeff_order_size; 624 const std::vector<uint8_t>* JXL_RESTRICT context_map; 625 ANSSymbolReader decoders[kMaxNumPasses]; 626 BitReader* JXL_RESTRICT* JXL_RESTRICT readers; 627 size_t num_passes; 628 size_t ctx_offset[kMaxNumPasses]; 629 size_t nzeros_stride; 630 int32_t* JXL_RESTRICT row_nzeros[kMaxNumPasses][3]; 631 const int32_t* JXL_RESTRICT row_nzeros_top[kMaxNumPasses][3]; 632 GroupDecCache* JXL_RESTRICT group_dec_cache; 633 const BlockCtxMap* block_ctx_map; 634 const ImageI* qf; 635 const ImageB* quant_dc; 636 const int32_t* qf_row; 637 const uint8_t* quant_dc_row; 638 Rect rect; 639 size_t hshift[3], vshift[3]; 640 }; 641 642 struct GetBlockFromEncoder : public GetBlock { 643 void StartRow(size_t by) override {} 644 645 Status LoadBlock(size_t bx, size_t by, const AcStrategy& acs, size_t size, 646 size_t log2_covered_blocks, ACPtr block[3], 647 ACType ac_type) override { 648 JXL_DASSERT(ac_type == ACType::k32); 649 for (size_t c = 0; c < 3; c++) { 650 // for each pass 651 for (size_t i = 0; i < quantized_ac->size(); i++) { 652 for (size_t k = 0; k < size; k++) { 653 // TODO(veluca): SIMD. 654 block[c].ptr32[k] += 655 rows[i][c][offset + k] * (1 << shift_for_pass[i]); 656 } 657 } 658 } 659 offset += size; 660 return true; 661 } 662 663 GetBlockFromEncoder(const std::vector<std::unique_ptr<ACImage>>& ac, 664 size_t group_idx, const uint32_t* shift_for_pass) 665 : quantized_ac(&ac), shift_for_pass(shift_for_pass) { 666 // TODO(veluca): not supported with chroma subsampling. 667 for (size_t i = 0; i < quantized_ac->size(); i++) { 668 JXL_CHECK((*quantized_ac)[i]->Type() == ACType::k32); 669 for (size_t c = 0; c < 3; c++) { 670 rows[i][c] = (*quantized_ac)[i]->PlaneRow(c, group_idx, 0).ptr32; 671 } 672 } 673 } 674 675 const std::vector<std::unique_ptr<ACImage>>* JXL_RESTRICT quantized_ac; 676 size_t offset = 0; 677 const int32_t* JXL_RESTRICT rows[kMaxNumPasses][3]; 678 const uint32_t* shift_for_pass = nullptr; // not owned 679 }; 680 681 HWY_EXPORT(DecodeGroupImpl); 682 683 } // namespace 684 685 Status DecodeGroup(const FrameHeader& frame_header, 686 BitReader* JXL_RESTRICT* JXL_RESTRICT readers, 687 size_t num_passes, size_t group_idx, 688 PassesDecoderState* JXL_RESTRICT dec_state, 689 GroupDecCache* JXL_RESTRICT group_dec_cache, size_t thread, 690 RenderPipelineInput& render_pipeline_input, 691 ImageBundle* JXL_RESTRICT decoded, size_t first_pass, 692 bool force_draw, bool dc_only, bool* should_run_pipeline) { 693 DrawMode draw = 694 (num_passes + first_pass == frame_header.passes.num_passes) || force_draw 695 ? kDraw 696 : kDontDraw; 697 698 if (should_run_pipeline) { 699 *should_run_pipeline = draw != kDontDraw; 700 } 701 702 if (draw == kDraw && num_passes == 0 && first_pass == 0) { 703 JXL_RETURN_IF_ERROR(group_dec_cache->InitDCBufferOnce()); 704 const YCbCrChromaSubsampling& cs = frame_header.chroma_subsampling; 705 for (size_t c : {0, 1, 2}) { 706 size_t hs = cs.HShift(c); 707 size_t vs = cs.VShift(c); 708 // We reuse filter_input_storage here as it is not currently in use. 709 const Rect src_rect_precs = 710 dec_state->shared->frame_dim.BlockGroupRect(group_idx); 711 const Rect src_rect = 712 Rect(src_rect_precs.x0() >> hs, src_rect_precs.y0() >> vs, 713 src_rect_precs.xsize() >> hs, src_rect_precs.ysize() >> vs); 714 const Rect copy_rect(kRenderPipelineXOffset, 2, src_rect.xsize(), 715 src_rect.ysize()); 716 CopyImageToWithPadding(src_rect, dec_state->shared->dc->Plane(c), 2, 717 copy_rect, &group_dec_cache->dc_buffer); 718 // Mirrorpad. Interleaving left and right padding ensures that padding 719 // works out correctly even for images with DC size of 1. 720 for (size_t y = 0; y < src_rect.ysize() + 4; y++) { 721 size_t xend = kRenderPipelineXOffset + 722 (dec_state->shared->dc->Plane(c).xsize() >> hs) - 723 src_rect.x0(); 724 for (size_t ix = 0; ix < 2; ix++) { 725 if (src_rect.x0() == 0) { 726 group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset - ix - 1] = 727 group_dec_cache->dc_buffer.Row(y)[kRenderPipelineXOffset + ix]; 728 } 729 if (src_rect.x0() + src_rect.xsize() + 2 >= 730 (dec_state->shared->dc->xsize() >> hs)) { 731 group_dec_cache->dc_buffer.Row(y)[xend + ix] = 732 group_dec_cache->dc_buffer.Row(y)[xend - ix - 1]; 733 } 734 } 735 } 736 Rect dst_rect = render_pipeline_input.GetBuffer(c).second; 737 ImageF* upsampling_dst = render_pipeline_input.GetBuffer(c).first; 738 JXL_ASSERT(dst_rect.IsInside(*upsampling_dst)); 739 740 RenderPipelineStage::RowInfo input_rows(1, std::vector<float*>(5)); 741 RenderPipelineStage::RowInfo output_rows(1, std::vector<float*>(8)); 742 for (size_t y = src_rect.y0(); y < src_rect.y0() + src_rect.ysize(); 743 y++) { 744 for (ssize_t iy = 0; iy < 5; iy++) { 745 input_rows[0][iy] = group_dec_cache->dc_buffer.Row( 746 Mirror(static_cast<ssize_t>(y) + iy - 2, 747 dec_state->shared->dc->Plane(c).ysize() >> vs) + 748 2 - src_rect.y0()); 749 } 750 for (size_t iy = 0; iy < 8; iy++) { 751 output_rows[0][iy] = 752 dst_rect.Row(upsampling_dst, ((y - src_rect.y0()) << 3) + iy) - 753 kRenderPipelineXOffset; 754 } 755 // Arguments set to 0/nullptr are not used. 756 JXL_RETURN_IF_ERROR(dec_state->upsampler8x->ProcessRow( 757 input_rows, output_rows, 758 /*xextra=*/0, src_rect.xsize(), 0, 0, thread)); 759 } 760 } 761 return true; 762 } 763 764 size_t histo_selector_bits = 0; 765 if (dc_only) { 766 JXL_ASSERT(num_passes == 0); 767 } else { 768 JXL_ASSERT(dec_state->shared->num_histograms > 0); 769 histo_selector_bits = CeilLog2Nonzero(dec_state->shared->num_histograms); 770 } 771 772 auto get_block = jxl::make_unique<GetBlockFromBitstream>(); 773 JXL_RETURN_IF_ERROR(get_block->Init( 774 frame_header, readers, num_passes, group_idx, histo_selector_bits, 775 dec_state->shared->frame_dim.BlockGroupRect(group_idx), group_dec_cache, 776 dec_state, first_pass)); 777 778 JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)( 779 frame_header, get_block.get(), group_dec_cache, dec_state, thread, 780 group_idx, render_pipeline_input, decoded, draw)); 781 782 for (size_t pass = 0; pass < num_passes; pass++) { 783 if (!get_block->decoders[pass].CheckANSFinalState()) { 784 return JXL_FAILURE("ANS checksum failure."); 785 } 786 } 787 return true; 788 } 789 790 Status DecodeGroupForRoundtrip(const FrameHeader& frame_header, 791 const std::vector<std::unique_ptr<ACImage>>& ac, 792 size_t group_idx, 793 PassesDecoderState* JXL_RESTRICT dec_state, 794 GroupDecCache* JXL_RESTRICT group_dec_cache, 795 size_t thread, 796 RenderPipelineInput& render_pipeline_input, 797 ImageBundle* JXL_RESTRICT decoded, 798 AuxOut* aux_out) { 799 GetBlockFromEncoder get_block(ac, group_idx, frame_header.passes.shift); 800 JXL_RETURN_IF_ERROR(group_dec_cache->InitOnce( 801 /*num_passes=*/0, 802 /*used_acs=*/(1u << AcStrategy::kNumValidStrategies) - 1)); 803 804 return HWY_DYNAMIC_DISPATCH(DecodeGroupImpl)( 805 frame_header, &get_block, group_dec_cache, dec_state, thread, group_idx, 806 render_pipeline_input, decoded, kDraw); 807 } 808 809 } // namespace jxl 810 #endif // HWY_ONCE