quant_weights.cc (51714B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 #include "lib/jxl/quant_weights.h" 6 7 #include <stdio.h> 8 #include <stdlib.h> 9 10 #include <algorithm> 11 #include <cmath> 12 #include <limits> 13 #include <utility> 14 15 #include "lib/jxl/base/bits.h" 16 #include "lib/jxl/base/status.h" 17 #include "lib/jxl/dct_scales.h" 18 #include "lib/jxl/dec_modular.h" 19 #include "lib/jxl/fields.h" 20 #include "lib/jxl/image.h" 21 22 #undef HWY_TARGET_INCLUDE 23 #define HWY_TARGET_INCLUDE "lib/jxl/quant_weights.cc" 24 #include <hwy/foreach_target.h> 25 #include <hwy/highway.h> 26 27 #include "lib/jxl/base/fast_math-inl.h" 28 29 HWY_BEFORE_NAMESPACE(); 30 namespace jxl { 31 namespace HWY_NAMESPACE { 32 33 // These templates are not found via ADL. 34 using hwy::HWY_NAMESPACE::Lt; 35 using hwy::HWY_NAMESPACE::MulAdd; 36 using hwy::HWY_NAMESPACE::Sqrt; 37 38 // kQuantWeights[N * N * c + N * y + x] is the relative weight of the (x, y) 39 // coefficient in component c. Higher weights correspond to finer quantization 40 // intervals and more bits spent in encoding. 41 42 static constexpr const float kAlmostZero = 1e-8f; 43 44 void GetQuantWeightsDCT2(const QuantEncoding::DCT2Weights& dct2weights, 45 float* weights) { 46 for (size_t c = 0; c < 3; c++) { 47 size_t start = c * 64; 48 weights[start] = 0xBAD; 49 weights[start + 1] = weights[start + 8] = dct2weights[c][0]; 50 weights[start + 9] = dct2weights[c][1]; 51 for (size_t y = 0; y < 2; y++) { 52 for (size_t x = 0; x < 2; x++) { 53 weights[start + y * 8 + x + 2] = dct2weights[c][2]; 54 weights[start + (y + 2) * 8 + x] = dct2weights[c][2]; 55 } 56 } 57 for (size_t y = 0; y < 2; y++) { 58 for (size_t x = 0; x < 2; x++) { 59 weights[start + (y + 2) * 8 + x + 2] = dct2weights[c][3]; 60 } 61 } 62 for (size_t y = 0; y < 4; y++) { 63 for (size_t x = 0; x < 4; x++) { 64 weights[start + y * 8 + x + 4] = dct2weights[c][4]; 65 weights[start + (y + 4) * 8 + x] = dct2weights[c][4]; 66 } 67 } 68 for (size_t y = 0; y < 4; y++) { 69 for (size_t x = 0; x < 4; x++) { 70 weights[start + (y + 4) * 8 + x + 4] = dct2weights[c][5]; 71 } 72 } 73 } 74 } 75 76 void GetQuantWeightsIdentity(const QuantEncoding::IdWeights& idweights, 77 float* weights) { 78 for (size_t c = 0; c < 3; c++) { 79 for (int i = 0; i < 64; i++) { 80 weights[64 * c + i] = idweights[c][0]; 81 } 82 weights[64 * c + 1] = idweights[c][1]; 83 weights[64 * c + 8] = idweights[c][1]; 84 weights[64 * c + 9] = idweights[c][2]; 85 } 86 } 87 88 float Interpolate(float pos, float max, const float* array, size_t len) { 89 float scaled_pos = pos * (len - 1) / max; 90 size_t idx = scaled_pos; 91 JXL_DASSERT(idx + 1 < len); 92 float a = array[idx]; 93 float b = array[idx + 1]; 94 return a * FastPowf(b / a, scaled_pos - idx); 95 } 96 97 float Mult(float v) { 98 if (v > 0.0f) return 1.0f + v; 99 return 1.0f / (1.0f - v); 100 } 101 102 using DF4 = HWY_CAPPED(float, 4); 103 104 hwy::HWY_NAMESPACE::Vec<DF4> InterpolateVec( 105 hwy::HWY_NAMESPACE::Vec<DF4> scaled_pos, const float* array) { 106 HWY_CAPPED(int32_t, 4) di; 107 108 auto idx = ConvertTo(di, scaled_pos); 109 110 auto frac = Sub(scaled_pos, ConvertTo(DF4(), idx)); 111 112 // TODO(veluca): in theory, this could be done with 8 TableLookupBytes, but 113 // it's probably slower. 114 auto a = GatherIndex(DF4(), array, idx); 115 auto b = GatherIndex(DF4(), array + 1, idx); 116 117 return Mul(a, FastPowf(DF4(), Div(b, a), frac)); 118 } 119 120 // Computes quant weights for a COLS*ROWS-sized transform, using num_bands 121 // eccentricity bands and num_ebands eccentricity bands. If print_mode is 1, 122 // prints the resulting matrix; if print_mode is 2, prints the matrix in a 123 // format suitable for a 3d plot with gnuplot. 124 Status GetQuantWeights( 125 size_t ROWS, size_t COLS, 126 const DctQuantWeightParams::DistanceBandsArray& distance_bands, 127 size_t num_bands, float* out) { 128 for (size_t c = 0; c < 3; c++) { 129 float bands[DctQuantWeightParams::kMaxDistanceBands] = { 130 distance_bands[c][0]}; 131 if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); 132 for (size_t i = 1; i < num_bands; i++) { 133 bands[i] = bands[i - 1] * Mult(distance_bands[c][i]); 134 if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands"); 135 } 136 float scale = (num_bands - 1) / (kSqrt2 + 1e-6f); 137 float rcpcol = scale / (COLS - 1); 138 float rcprow = scale / (ROWS - 1); 139 JXL_ASSERT(COLS >= Lanes(DF4())); 140 HWY_ALIGN float l0123[4] = {0, 1, 2, 3}; 141 for (uint32_t y = 0; y < ROWS; y++) { 142 float dy = y * rcprow; 143 float dy2 = dy * dy; 144 for (uint32_t x = 0; x < COLS; x += Lanes(DF4())) { 145 auto dx = 146 Mul(Add(Set(DF4(), x), Load(DF4(), l0123)), Set(DF4(), rcpcol)); 147 auto scaled_distance = Sqrt(MulAdd(dx, dx, Set(DF4(), dy2))); 148 auto weight = num_bands == 1 ? Set(DF4(), bands[0]) 149 : InterpolateVec(scaled_distance, bands); 150 StoreU(weight, DF4(), out + c * COLS * ROWS + y * COLS + x); 151 } 152 } 153 } 154 return true; 155 } 156 157 // TODO(veluca): SIMD-fy. With 256x256, this is actually slow. 158 Status ComputeQuantTable(const QuantEncoding& encoding, 159 float* JXL_RESTRICT table, 160 float* JXL_RESTRICT inv_table, size_t table_num, 161 DequantMatrices::QuantTable kind, size_t* pos) { 162 constexpr size_t N = kBlockDim; 163 size_t wrows = 8 * DequantMatrices::required_size_x[kind]; 164 size_t wcols = 8 * DequantMatrices::required_size_y[kind]; 165 size_t num = wrows * wcols; 166 167 std::vector<float> weights(3 * num); 168 169 switch (encoding.mode) { 170 case QuantEncoding::kQuantModeLibrary: { 171 // Library and copy quant encoding should get replaced by the actual 172 // parameters by the caller. 173 JXL_ASSERT(false); 174 break; 175 } 176 case QuantEncoding::kQuantModeID: { 177 JXL_ASSERT(num == kDCTBlockSize); 178 GetQuantWeightsIdentity(encoding.idweights, weights.data()); 179 break; 180 } 181 case QuantEncoding::kQuantModeDCT2: { 182 JXL_ASSERT(num == kDCTBlockSize); 183 GetQuantWeightsDCT2(encoding.dct2weights, weights.data()); 184 break; 185 } 186 case QuantEncoding::kQuantModeDCT4: { 187 JXL_ASSERT(num == kDCTBlockSize); 188 float weights4x4[3 * 4 * 4]; 189 // Always use 4x4 GetQuantWeights for DCT4 quantization tables. 190 JXL_RETURN_IF_ERROR( 191 GetQuantWeights(4, 4, encoding.dct_params.distance_bands, 192 encoding.dct_params.num_distance_bands, weights4x4)); 193 for (size_t c = 0; c < 3; c++) { 194 for (size_t y = 0; y < kBlockDim; y++) { 195 for (size_t x = 0; x < kBlockDim; x++) { 196 weights[c * num + y * kBlockDim + x] = 197 weights4x4[c * 16 + (y / 2) * 4 + (x / 2)]; 198 } 199 } 200 weights[c * num + 1] /= encoding.dct4multipliers[c][0]; 201 weights[c * num + N] /= encoding.dct4multipliers[c][0]; 202 weights[c * num + N + 1] /= encoding.dct4multipliers[c][1]; 203 } 204 break; 205 } 206 case QuantEncoding::kQuantModeDCT4X8: { 207 JXL_ASSERT(num == kDCTBlockSize); 208 float weights4x8[3 * 4 * 8]; 209 // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables. 210 JXL_RETURN_IF_ERROR( 211 GetQuantWeights(4, 8, encoding.dct_params.distance_bands, 212 encoding.dct_params.num_distance_bands, weights4x8)); 213 for (size_t c = 0; c < 3; c++) { 214 for (size_t y = 0; y < kBlockDim; y++) { 215 for (size_t x = 0; x < kBlockDim; x++) { 216 weights[c * num + y * kBlockDim + x] = 217 weights4x8[c * 32 + (y / 2) * 8 + x]; 218 } 219 } 220 weights[c * num + N] /= encoding.dct4x8multipliers[c]; 221 } 222 break; 223 } 224 case QuantEncoding::kQuantModeDCT: { 225 JXL_RETURN_IF_ERROR(GetQuantWeights( 226 wrows, wcols, encoding.dct_params.distance_bands, 227 encoding.dct_params.num_distance_bands, weights.data())); 228 break; 229 } 230 case QuantEncoding::kQuantModeRAW: { 231 if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) { 232 return JXL_FAILURE("Invalid table encoding"); 233 } 234 for (size_t i = 0; i < 3 * num; i++) { 235 weights[i] = 236 1.f / (encoding.qraw.qtable_den * (*encoding.qraw.qtable)[i]); 237 } 238 break; 239 } 240 case QuantEncoding::kQuantModeAFV: { 241 constexpr float kFreqs[] = { 242 0xBAD, 243 0xBAD, 244 0.8517778890324296, 245 5.37778436506804, 246 0xBAD, 247 0xBAD, 248 4.734747904497923, 249 5.449245381693219, 250 1.6598270267479331, 251 4, 252 7.275749096817861, 253 10.423227632456525, 254 2.662932286148962, 255 7.630657783650829, 256 8.962388608184032, 257 12.97166202570235, 258 }; 259 260 float weights4x8[3 * 4 * 8]; 261 JXL_RETURN_IF_ERROR(( 262 GetQuantWeights(4, 8, encoding.dct_params.distance_bands, 263 encoding.dct_params.num_distance_bands, weights4x8))); 264 float weights4x4[3 * 4 * 4]; 265 JXL_RETURN_IF_ERROR((GetQuantWeights( 266 4, 4, encoding.dct_params_afv_4x4.distance_bands, 267 encoding.dct_params_afv_4x4.num_distance_bands, weights4x4))); 268 269 constexpr float lo = 0.8517778890324296; 270 constexpr float hi = 12.97166202570235f - lo + 1e-6f; 271 for (size_t c = 0; c < 3; c++) { 272 float bands[4]; 273 bands[0] = encoding.afv_weights[c][5]; 274 if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); 275 for (size_t i = 1; i < 4; i++) { 276 bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]); 277 if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands"); 278 } 279 size_t start = c * 64; 280 auto set_weight = [&start, &weights](size_t x, size_t y, float val) { 281 weights[start + y * 8 + x] = val; 282 }; 283 weights[start] = 1; // Not used, but causes MSAN error otherwise. 284 // Weights for (0, 1) and (1, 0). 285 set_weight(0, 1, encoding.afv_weights[c][0]); 286 set_weight(1, 0, encoding.afv_weights[c][1]); 287 // AFV special weights for 3-pixel corner. 288 set_weight(0, 2, encoding.afv_weights[c][2]); 289 set_weight(2, 0, encoding.afv_weights[c][3]); 290 set_weight(2, 2, encoding.afv_weights[c][4]); 291 292 // All other AFV weights. 293 for (size_t y = 0; y < 4; y++) { 294 for (size_t x = 0; x < 4; x++) { 295 if (x < 2 && y < 2) continue; 296 float val = Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4); 297 set_weight(2 * x, 2 * y, val); 298 } 299 } 300 301 // Put 4x8 weights in odd rows, except (1, 0). 302 for (size_t y = 0; y < kBlockDim / 2; y++) { 303 for (size_t x = 0; x < kBlockDim; x++) { 304 if (x == 0 && y == 0) continue; 305 weights[c * num + (2 * y + 1) * kBlockDim + x] = 306 weights4x8[c * 32 + y * 8 + x]; 307 } 308 } 309 // Put 4x4 weights in even rows / odd columns, except (0, 1). 310 for (size_t y = 0; y < kBlockDim / 2; y++) { 311 for (size_t x = 0; x < kBlockDim / 2; x++) { 312 if (x == 0 && y == 0) continue; 313 weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] = 314 weights4x4[c * 16 + y * 4 + x]; 315 } 316 } 317 } 318 break; 319 } 320 } 321 size_t prev_pos = *pos; 322 HWY_CAPPED(float, 64) d; 323 for (size_t i = 0; i < num * 3; i += Lanes(d)) { 324 auto inv_val = LoadU(d, weights.data() + i); 325 if (JXL_UNLIKELY(!AllFalse(d, Ge(inv_val, Set(d, 1.0f / kAlmostZero))) || 326 !AllFalse(d, Lt(inv_val, Set(d, kAlmostZero))))) { 327 return JXL_FAILURE("Invalid quantization table"); 328 } 329 auto val = Div(Set(d, 1.0f), inv_val); 330 StoreU(val, d, table + *pos + i); 331 StoreU(inv_val, d, inv_table + *pos + i); 332 } 333 (*pos) += 3 * num; 334 335 // Ensure that the lowest frequencies have a 0 inverse table. 336 // This does not affect en/decoding, but allows AC strategy selection to be 337 // slightly simpler. 338 size_t xs = DequantMatrices::required_size_x[kind]; 339 size_t ys = DequantMatrices::required_size_y[kind]; 340 CoefficientLayout(&ys, &xs); 341 for (size_t c = 0; c < 3; c++) { 342 for (size_t y = 0; y < ys; y++) { 343 for (size_t x = 0; x < xs; x++) { 344 inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs + 345 x] = 0; 346 } 347 } 348 } 349 return true; 350 } 351 352 // NOLINTNEXTLINE(google-readability-namespace-comments) 353 } // namespace HWY_NAMESPACE 354 } // namespace jxl 355 HWY_AFTER_NAMESPACE(); 356 357 #if HWY_ONCE 358 359 namespace jxl { 360 namespace { 361 362 HWY_EXPORT(ComputeQuantTable); 363 364 constexpr const float kAlmostZero = 1e-8f; 365 366 Status DecodeDctParams(BitReader* br, DctQuantWeightParams* params) { 367 params->num_distance_bands = 368 br->ReadFixedBits<DctQuantWeightParams::kLog2MaxDistanceBands>() + 1; 369 for (size_t c = 0; c < 3; c++) { 370 for (size_t i = 0; i < params->num_distance_bands; i++) { 371 JXL_RETURN_IF_ERROR(F16Coder::Read(br, ¶ms->distance_bands[c][i])); 372 } 373 if (params->distance_bands[c][0] < kAlmostZero) { 374 return JXL_FAILURE("Distance band seed is too small"); 375 } 376 params->distance_bands[c][0] *= 64.0f; 377 } 378 return true; 379 } 380 381 Status Decode(BitReader* br, QuantEncoding* encoding, size_t required_size_x, 382 size_t required_size_y, size_t idx, 383 ModularFrameDecoder* modular_frame_decoder) { 384 size_t required_size = required_size_x * required_size_y; 385 required_size_x *= kBlockDim; 386 required_size_y *= kBlockDim; 387 int mode = br->ReadFixedBits<kLog2NumQuantModes>(); 388 switch (mode) { 389 case QuantEncoding::kQuantModeLibrary: { 390 encoding->predefined = br->ReadFixedBits<kCeilLog2NumPredefinedTables>(); 391 if (encoding->predefined >= kNumPredefinedTables) { 392 return JXL_FAILURE("Invalid predefined table"); 393 } 394 break; 395 } 396 case QuantEncoding::kQuantModeID: { 397 if (required_size != 1) return JXL_FAILURE("Invalid mode"); 398 for (size_t c = 0; c < 3; c++) { 399 for (size_t i = 0; i < 3; i++) { 400 JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->idweights[c][i])); 401 if (std::abs(encoding->idweights[c][i]) < kAlmostZero) { 402 return JXL_FAILURE("ID Quantizer is too small"); 403 } 404 encoding->idweights[c][i] *= 64; 405 } 406 } 407 break; 408 } 409 case QuantEncoding::kQuantModeDCT2: { 410 if (required_size != 1) return JXL_FAILURE("Invalid mode"); 411 for (size_t c = 0; c < 3; c++) { 412 for (size_t i = 0; i < 6; i++) { 413 JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->dct2weights[c][i])); 414 if (std::abs(encoding->dct2weights[c][i]) < kAlmostZero) { 415 return JXL_FAILURE("Quantizer is too small"); 416 } 417 encoding->dct2weights[c][i] *= 64; 418 } 419 } 420 break; 421 } 422 case QuantEncoding::kQuantModeDCT4X8: { 423 if (required_size != 1) return JXL_FAILURE("Invalid mode"); 424 for (size_t c = 0; c < 3; c++) { 425 JXL_RETURN_IF_ERROR( 426 F16Coder::Read(br, &encoding->dct4x8multipliers[c])); 427 if (std::abs(encoding->dct4x8multipliers[c]) < kAlmostZero) { 428 return JXL_FAILURE("DCT4X8 multiplier is too small"); 429 } 430 } 431 JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); 432 break; 433 } 434 case QuantEncoding::kQuantModeDCT4: { 435 if (required_size != 1) return JXL_FAILURE("Invalid mode"); 436 for (size_t c = 0; c < 3; c++) { 437 for (size_t i = 0; i < 2; i++) { 438 JXL_RETURN_IF_ERROR( 439 F16Coder::Read(br, &encoding->dct4multipliers[c][i])); 440 if (std::abs(encoding->dct4multipliers[c][i]) < kAlmostZero) { 441 return JXL_FAILURE("DCT4 multiplier is too small"); 442 } 443 } 444 } 445 JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); 446 break; 447 } 448 case QuantEncoding::kQuantModeAFV: { 449 if (required_size != 1) return JXL_FAILURE("Invalid mode"); 450 for (size_t c = 0; c < 3; c++) { 451 for (size_t i = 0; i < 9; i++) { 452 JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->afv_weights[c][i])); 453 } 454 for (size_t i = 0; i < 6; i++) { 455 encoding->afv_weights[c][i] *= 64; 456 } 457 } 458 JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); 459 JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params_afv_4x4)); 460 break; 461 } 462 case QuantEncoding::kQuantModeDCT: { 463 JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params)); 464 break; 465 } 466 case QuantEncoding::kQuantModeRAW: { 467 // Set mode early, to avoid mem-leak. 468 encoding->mode = QuantEncoding::kQuantModeRAW; 469 JXL_RETURN_IF_ERROR(ModularFrameDecoder::DecodeQuantTable( 470 required_size_x, required_size_y, br, encoding, idx, 471 modular_frame_decoder)); 472 break; 473 } 474 default: 475 return JXL_FAILURE("Invalid quantization table encoding"); 476 } 477 encoding->mode = static_cast<QuantEncoding::Mode>(mode); 478 return true; 479 } 480 481 } // namespace 482 483 // These definitions are needed before C++17. 484 constexpr const std::array<int, 17> DequantMatrices::required_size_x; 485 constexpr const std::array<int, 17> DequantMatrices::required_size_y; 486 constexpr const size_t DequantMatrices::kSumRequiredXy; 487 constexpr DequantMatrices::QuantTable DequantMatrices::kQuantTable[]; 488 489 Status DequantMatrices::Decode(BitReader* br, 490 ModularFrameDecoder* modular_frame_decoder) { 491 size_t all_default = br->ReadBits(1); 492 size_t num_tables = all_default ? 0 : static_cast<size_t>(kNum); 493 encodings_.clear(); 494 encodings_.resize(kNum, QuantEncoding::Library(0)); 495 for (size_t i = 0; i < num_tables; i++) { 496 JXL_RETURN_IF_ERROR( 497 jxl::Decode(br, &encodings_[i], required_size_x[i % kNum], 498 required_size_y[i % kNum], i, modular_frame_decoder)); 499 } 500 computed_mask_ = 0; 501 return true; 502 } 503 504 Status DequantMatrices::DecodeDC(BitReader* br) { 505 bool all_default = static_cast<bool>(br->ReadBits(1)); 506 if (!br->AllReadsWithinBounds()) return JXL_FAILURE("EOS during DecodeDC"); 507 if (!all_default) { 508 for (size_t c = 0; c < 3; c++) { 509 JXL_RETURN_IF_ERROR(F16Coder::Read(br, &dc_quant_[c])); 510 dc_quant_[c] *= 1.0f / 128.0f; 511 // Negative values and nearly zero are invalid values. 512 if (dc_quant_[c] < kAlmostZero) { 513 return JXL_FAILURE("Invalid dc_quant: coefficient is too small."); 514 } 515 inv_dc_quant_[c] = 1.0f / dc_quant_[c]; 516 } 517 } 518 return true; 519 } 520 521 constexpr float V(float v) { return static_cast<float>(v); } 522 523 namespace { 524 struct DequantMatricesLibraryDef { 525 // DCT8 526 static constexpr QuantEncodingInternal DCT() { 527 return QuantEncodingInternal::DCT(DctQuantWeightParams({{{{ 528 V(3150.0), 529 V(0.0), 530 V(-0.4), 531 V(-0.4), 532 V(-0.4), 533 V(-2.0), 534 }}, 535 {{ 536 V(560.0), 537 V(0.0), 538 V(-0.3), 539 V(-0.3), 540 V(-0.3), 541 V(-0.3), 542 }}, 543 {{ 544 V(512.0), 545 V(-2.0), 546 V(-1.0), 547 V(0.0), 548 V(-1.0), 549 V(-2.0), 550 }}}}, 551 6)); 552 } 553 554 // Identity 555 static constexpr QuantEncodingInternal IDENTITY() { 556 return QuantEncodingInternal::Identity({{{{ 557 V(280.0), 558 V(3160.0), 559 V(3160.0), 560 }}, 561 {{ 562 V(60.0), 563 V(864.0), 564 V(864.0), 565 }}, 566 {{ 567 V(18.0), 568 V(200.0), 569 V(200.0), 570 }}}}); 571 } 572 573 // DCT2 574 static constexpr QuantEncodingInternal DCT2X2() { 575 return QuantEncodingInternal::DCT2({{{{ 576 V(3840.0), 577 V(2560.0), 578 V(1280.0), 579 V(640.0), 580 V(480.0), 581 V(300.0), 582 }}, 583 {{ 584 V(960.0), 585 V(640.0), 586 V(320.0), 587 V(180.0), 588 V(140.0), 589 V(120.0), 590 }}, 591 {{ 592 V(640.0), 593 V(320.0), 594 V(128.0), 595 V(64.0), 596 V(32.0), 597 V(16.0), 598 }}}}); 599 } 600 601 // DCT4 (quant_kind 3) 602 static constexpr QuantEncodingInternal DCT4X4() { 603 return QuantEncodingInternal::DCT4(DctQuantWeightParams({{{{ 604 V(2200.0), 605 V(0.0), 606 V(0.0), 607 V(0.0), 608 }}, 609 {{ 610 V(392.0), 611 V(0.0), 612 V(0.0), 613 V(0.0), 614 }}, 615 {{ 616 V(112.0), 617 V(-0.25), 618 V(-0.25), 619 V(-0.5), 620 }}}}, 621 4), 622 /* kMul */ 623 {{{{ 624 V(1.0), 625 V(1.0), 626 }}, 627 {{ 628 V(1.0), 629 V(1.0), 630 }}, 631 {{ 632 V(1.0), 633 V(1.0), 634 }}}}); 635 } 636 637 // DCT16 638 static constexpr QuantEncodingInternal DCT16X16() { 639 return QuantEncodingInternal::DCT( 640 DctQuantWeightParams({{{{ 641 V(8996.8725711814115328), 642 V(-1.3000777393353804), 643 V(-0.49424529824571225), 644 V(-0.439093774457103443), 645 V(-0.6350101832695744), 646 V(-0.90177264050827612), 647 V(-1.6162099239887414), 648 }}, 649 {{ 650 V(3191.48366296844234752), 651 V(-0.67424582104194355), 652 V(-0.80745813428471001), 653 V(-0.44925837484843441), 654 V(-0.35865440981033403), 655 V(-0.31322389111877305), 656 V(-0.37615025315725483), 657 }}, 658 {{ 659 V(1157.50408145487200256), 660 V(-2.0531423165804414), 661 V(-1.4), 662 V(-0.50687130033378396), 663 V(-0.42708730624733904), 664 V(-1.4856834539296244), 665 V(-4.9209142884401604), 666 }}}}, 667 7)); 668 } 669 670 // DCT32 671 static constexpr QuantEncodingInternal DCT32X32() { 672 return QuantEncodingInternal::DCT( 673 DctQuantWeightParams({{{{ 674 V(15718.40830982518931456), 675 V(-1.025), 676 V(-0.98), 677 V(-0.9012), 678 V(-0.4), 679 V(-0.48819395464), 680 V(-0.421064), 681 V(-0.27), 682 }}, 683 {{ 684 V(7305.7636810695983104), 685 V(-0.8041958212306401), 686 V(-0.7633036457487539), 687 V(-0.55660379990111464), 688 V(-0.49785304658857626), 689 V(-0.43699592683512467), 690 V(-0.40180866526242109), 691 V(-0.27321683125358037), 692 }}, 693 {{ 694 V(3803.53173721215041536), 695 V(-3.060733579805728), 696 V(-2.0413270132490346), 697 V(-2.0235650159727417), 698 V(-0.5495389509954993), 699 V(-0.4), 700 V(-0.4), 701 V(-0.3), 702 }}}}, 703 8)); 704 } 705 706 // DCT16X8 707 static constexpr QuantEncodingInternal DCT8X16() { 708 return QuantEncodingInternal::DCT( 709 DctQuantWeightParams({{{{ 710 V(7240.7734393502), 711 V(-0.7), 712 V(-0.7), 713 V(-0.2), 714 V(-0.2), 715 V(-0.2), 716 V(-0.5), 717 }}, 718 {{ 719 V(1448.15468787004), 720 V(-0.5), 721 V(-0.5), 722 V(-0.5), 723 V(-0.2), 724 V(-0.2), 725 V(-0.2), 726 }}, 727 {{ 728 V(506.854140754517), 729 V(-1.4), 730 V(-0.2), 731 V(-0.5), 732 V(-0.5), 733 V(-1.5), 734 V(-3.6), 735 }}}}, 736 7)); 737 } 738 739 // DCT32X8 740 static constexpr QuantEncodingInternal DCT8X32() { 741 return QuantEncodingInternal::DCT( 742 DctQuantWeightParams({{{{ 743 V(16283.2494710648897), 744 V(-1.7812845336559429), 745 V(-1.6309059012653515), 746 V(-1.0382179034313539), 747 V(-0.85), 748 V(-0.7), 749 V(-0.9), 750 V(-1.2360638576849587), 751 }}, 752 {{ 753 V(5089.15750884921511936), 754 V(-0.320049391452786891), 755 V(-0.35362849922161446), 756 V(-0.30340000000000003), 757 V(-0.61), 758 V(-0.5), 759 V(-0.5), 760 V(-0.6), 761 }}, 762 {{ 763 V(3397.77603275308720128), 764 V(-0.321327362693153371), 765 V(-0.34507619223117997), 766 V(-0.70340000000000003), 767 V(-0.9), 768 V(-1.0), 769 V(-1.0), 770 V(-1.1754605576265209), 771 }}}}, 772 8)); 773 } 774 775 // DCT32X16 776 static constexpr QuantEncodingInternal DCT16X32() { 777 return QuantEncodingInternal::DCT( 778 DctQuantWeightParams({{{{ 779 V(13844.97076442300573), 780 V(-0.97113799999999995), 781 V(-0.658), 782 V(-0.42026), 783 V(-0.22712), 784 V(-0.2206), 785 V(-0.226), 786 V(-0.6), 787 }}, 788 {{ 789 V(4798.964084220744293), 790 V(-0.61125308982767057), 791 V(-0.83770786552491361), 792 V(-0.79014862079498627), 793 V(-0.2692727459704829), 794 V(-0.38272769465388551), 795 V(-0.22924222653091453), 796 V(-0.20719098826199578), 797 }}, 798 {{ 799 V(1807.236946760964614), 800 V(-1.2), 801 V(-1.2), 802 V(-0.7), 803 V(-0.7), 804 V(-0.7), 805 V(-0.4), 806 V(-0.5), 807 }}}}, 808 8)); 809 } 810 811 // DCT4X8 and 8x4 812 static constexpr QuantEncodingInternal DCT4X8() { 813 return QuantEncodingInternal::DCT4X8( 814 DctQuantWeightParams({{ 815 {{ 816 V(2198.050556016380522), 817 V(-0.96269623020744692), 818 V(-0.76194253026666783), 819 V(-0.6551140670773547), 820 }}, 821 {{ 822 V(764.3655248643528689), 823 V(-0.92630200888366945), 824 V(-0.9675229603596517), 825 V(-0.27845290869168118), 826 }}, 827 {{ 828 V(527.107573587542228), 829 V(-1.4594385811273854), 830 V(-1.450082094097871593), 831 V(-1.5843722511996204), 832 }}, 833 }}, 834 4), 835 /* kMuls */ 836 {{ 837 V(1.0), 838 V(1.0), 839 V(1.0), 840 }}); 841 } 842 // AFV 843 static QuantEncodingInternal AFV0() { 844 return QuantEncodingInternal::AFV(DCT4X8().dct_params, DCT4X4().dct_params, 845 {{{{ 846 // 4x4/4x8 DC tendency. 847 V(3072.0), 848 V(3072.0), 849 // AFV corner. 850 V(256.0), 851 V(256.0), 852 V(256.0), 853 // AFV high freqs. 854 V(414.0), 855 V(0.0), 856 V(0.0), 857 V(0.0), 858 }}, 859 {{ 860 // 4x4/4x8 DC tendency. 861 V(1024.0), 862 V(1024.0), 863 // AFV corner. 864 V(50), 865 V(50), 866 V(50), 867 // AFV high freqs. 868 V(58.0), 869 V(0.0), 870 V(0.0), 871 V(0.0), 872 }}, 873 {{ 874 // 4x4/4x8 DC tendency. 875 V(384.0), 876 V(384.0), 877 // AFV corner. 878 V(12.0), 879 V(12.0), 880 V(12.0), 881 // AFV high freqs. 882 V(22.0), 883 V(-0.25), 884 V(-0.25), 885 V(-0.25), 886 }}}}); 887 } 888 889 // DCT64 890 static QuantEncodingInternal DCT64X64() { 891 return QuantEncodingInternal::DCT( 892 DctQuantWeightParams({{{{ 893 V(0.9 * 26629.073922049845), 894 V(-1.025), 895 V(-0.78), 896 V(-0.65012), 897 V(-0.19041574084286472), 898 V(-0.20819395464), 899 V(-0.421064), 900 V(-0.32733845535848671), 901 }}, 902 {{ 903 V(0.9 * 9311.3238710010046), 904 V(-0.3041958212306401), 905 V(-0.3633036457487539), 906 V(-0.35660379990111464), 907 V(-0.3443074455424403), 908 V(-0.33699592683512467), 909 V(-0.30180866526242109), 910 V(-0.27321683125358037), 911 }}, 912 {{ 913 V(0.9 * 4992.2486445538634), 914 V(-1.2), 915 V(-1.2), 916 V(-0.8), 917 V(-0.7), 918 V(-0.7), 919 V(-0.4), 920 V(-0.5), 921 }}}}, 922 8)); 923 } 924 925 // DCT64X32 926 static QuantEncodingInternal DCT32X64() { 927 return QuantEncodingInternal::DCT( 928 DctQuantWeightParams({{{{ 929 V(0.65 * 23629.073922049845), 930 V(-1.025), 931 V(-0.78), 932 V(-0.65012), 933 V(-0.19041574084286472), 934 V(-0.20819395464), 935 V(-0.421064), 936 V(-0.32733845535848671), 937 }}, 938 {{ 939 V(0.65 * 8611.3238710010046), 940 V(-0.3041958212306401), 941 V(-0.3633036457487539), 942 V(-0.35660379990111464), 943 V(-0.3443074455424403), 944 V(-0.33699592683512467), 945 V(-0.30180866526242109), 946 V(-0.27321683125358037), 947 }}, 948 {{ 949 V(0.65 * 4492.2486445538634), 950 V(-1.2), 951 V(-1.2), 952 V(-0.8), 953 V(-0.7), 954 V(-0.7), 955 V(-0.4), 956 V(-0.5), 957 }}}}, 958 8)); 959 } 960 // DCT128X128 961 static QuantEncodingInternal DCT128X128() { 962 return QuantEncodingInternal::DCT( 963 DctQuantWeightParams({{{{ 964 V(1.8 * 26629.073922049845), 965 V(-1.025), 966 V(-0.78), 967 V(-0.65012), 968 V(-0.19041574084286472), 969 V(-0.20819395464), 970 V(-0.421064), 971 V(-0.32733845535848671), 972 }}, 973 {{ 974 V(1.8 * 9311.3238710010046), 975 V(-0.3041958212306401), 976 V(-0.3633036457487539), 977 V(-0.35660379990111464), 978 V(-0.3443074455424403), 979 V(-0.33699592683512467), 980 V(-0.30180866526242109), 981 V(-0.27321683125358037), 982 }}, 983 {{ 984 V(1.8 * 4992.2486445538634), 985 V(-1.2), 986 V(-1.2), 987 V(-0.8), 988 V(-0.7), 989 V(-0.7), 990 V(-0.4), 991 V(-0.5), 992 }}}}, 993 8)); 994 } 995 996 // DCT128X64 997 static QuantEncodingInternal DCT64X128() { 998 return QuantEncodingInternal::DCT( 999 DctQuantWeightParams({{{{ 1000 V(1.3 * 23629.073922049845), 1001 V(-1.025), 1002 V(-0.78), 1003 V(-0.65012), 1004 V(-0.19041574084286472), 1005 V(-0.20819395464), 1006 V(-0.421064), 1007 V(-0.32733845535848671), 1008 }}, 1009 {{ 1010 V(1.3 * 8611.3238710010046), 1011 V(-0.3041958212306401), 1012 V(-0.3633036457487539), 1013 V(-0.35660379990111464), 1014 V(-0.3443074455424403), 1015 V(-0.33699592683512467), 1016 V(-0.30180866526242109), 1017 V(-0.27321683125358037), 1018 }}, 1019 {{ 1020 V(1.3 * 4492.2486445538634), 1021 V(-1.2), 1022 V(-1.2), 1023 V(-0.8), 1024 V(-0.7), 1025 V(-0.7), 1026 V(-0.4), 1027 V(-0.5), 1028 }}}}, 1029 8)); 1030 } 1031 // DCT256X256 1032 static QuantEncodingInternal DCT256X256() { 1033 return QuantEncodingInternal::DCT( 1034 DctQuantWeightParams({{{{ 1035 V(3.6 * 26629.073922049845), 1036 V(-1.025), 1037 V(-0.78), 1038 V(-0.65012), 1039 V(-0.19041574084286472), 1040 V(-0.20819395464), 1041 V(-0.421064), 1042 V(-0.32733845535848671), 1043 }}, 1044 {{ 1045 V(3.6 * 9311.3238710010046), 1046 V(-0.3041958212306401), 1047 V(-0.3633036457487539), 1048 V(-0.35660379990111464), 1049 V(-0.3443074455424403), 1050 V(-0.33699592683512467), 1051 V(-0.30180866526242109), 1052 V(-0.27321683125358037), 1053 }}, 1054 {{ 1055 V(3.6 * 4992.2486445538634), 1056 V(-1.2), 1057 V(-1.2), 1058 V(-0.8), 1059 V(-0.7), 1060 V(-0.7), 1061 V(-0.4), 1062 V(-0.5), 1063 }}}}, 1064 8)); 1065 } 1066 1067 // DCT256X128 1068 static QuantEncodingInternal DCT128X256() { 1069 return QuantEncodingInternal::DCT( 1070 DctQuantWeightParams({{{{ 1071 V(2.6 * 23629.073922049845), 1072 V(-1.025), 1073 V(-0.78), 1074 V(-0.65012), 1075 V(-0.19041574084286472), 1076 V(-0.20819395464), 1077 V(-0.421064), 1078 V(-0.32733845535848671), 1079 }}, 1080 {{ 1081 V(2.6 * 8611.3238710010046), 1082 V(-0.3041958212306401), 1083 V(-0.3633036457487539), 1084 V(-0.35660379990111464), 1085 V(-0.3443074455424403), 1086 V(-0.33699592683512467), 1087 V(-0.30180866526242109), 1088 V(-0.27321683125358037), 1089 }}, 1090 {{ 1091 V(2.6 * 4492.2486445538634), 1092 V(-1.2), 1093 V(-1.2), 1094 V(-0.8), 1095 V(-0.7), 1096 V(-0.7), 1097 V(-0.4), 1098 V(-0.5), 1099 }}}}, 1100 8)); 1101 } 1102 }; 1103 } // namespace 1104 1105 DequantMatrices::DequantLibraryInternal DequantMatrices::LibraryInit() { 1106 static_assert(kNum == 17, 1107 "Update this function when adding new quantization kinds."); 1108 static_assert(kNumPredefinedTables == 1, 1109 "Update this function when adding new quantization matrices to " 1110 "the library."); 1111 1112 // The library and the indices need to be kept in sync manually. 1113 static_assert(0 == DCT, "Update the DequantLibrary array below."); 1114 static_assert(1 == IDENTITY, "Update the DequantLibrary array below."); 1115 static_assert(2 == DCT2X2, "Update the DequantLibrary array below."); 1116 static_assert(3 == DCT4X4, "Update the DequantLibrary array below."); 1117 static_assert(4 == DCT16X16, "Update the DequantLibrary array below."); 1118 static_assert(5 == DCT32X32, "Update the DequantLibrary array below."); 1119 static_assert(6 == DCT8X16, "Update the DequantLibrary array below."); 1120 static_assert(7 == DCT8X32, "Update the DequantLibrary array below."); 1121 static_assert(8 == DCT16X32, "Update the DequantLibrary array below."); 1122 static_assert(9 == DCT4X8, "Update the DequantLibrary array below."); 1123 static_assert(10 == AFV0, "Update the DequantLibrary array below."); 1124 static_assert(11 == DCT64X64, "Update the DequantLibrary array below."); 1125 static_assert(12 == DCT32X64, "Update the DequantLibrary array below."); 1126 static_assert(13 == DCT128X128, "Update the DequantLibrary array below."); 1127 static_assert(14 == DCT64X128, "Update the DequantLibrary array below."); 1128 static_assert(15 == DCT256X256, "Update the DequantLibrary array below."); 1129 static_assert(16 == DCT128X256, "Update the DequantLibrary array below."); 1130 return DequantMatrices::DequantLibraryInternal{{ 1131 DequantMatricesLibraryDef::DCT(), 1132 DequantMatricesLibraryDef::IDENTITY(), 1133 DequantMatricesLibraryDef::DCT2X2(), 1134 DequantMatricesLibraryDef::DCT4X4(), 1135 DequantMatricesLibraryDef::DCT16X16(), 1136 DequantMatricesLibraryDef::DCT32X32(), 1137 DequantMatricesLibraryDef::DCT8X16(), 1138 DequantMatricesLibraryDef::DCT8X32(), 1139 DequantMatricesLibraryDef::DCT16X32(), 1140 DequantMatricesLibraryDef::DCT4X8(), 1141 DequantMatricesLibraryDef::AFV0(), 1142 DequantMatricesLibraryDef::DCT64X64(), 1143 DequantMatricesLibraryDef::DCT32X64(), 1144 // Same default for large transforms (128+) as for 64x* transforms. 1145 DequantMatricesLibraryDef::DCT128X128(), 1146 DequantMatricesLibraryDef::DCT64X128(), 1147 DequantMatricesLibraryDef::DCT256X256(), 1148 DequantMatricesLibraryDef::DCT128X256(), 1149 }}; 1150 } 1151 1152 const QuantEncoding* DequantMatrices::Library() { 1153 static const DequantMatrices::DequantLibraryInternal kDequantLibrary = 1154 DequantMatrices::LibraryInit(); 1155 // Downcast the result to a const QuantEncoding* from QuantEncodingInternal* 1156 // since the subclass (QuantEncoding) doesn't add any new members and users 1157 // will need to upcast to QuantEncodingInternal to access the members of that 1158 // class. This allows to have kDequantLibrary as a constexpr value while still 1159 // allowing to create QuantEncoding::RAW() instances that use std::vector in 1160 // C++11. 1161 return reinterpret_cast<const QuantEncoding*>(kDequantLibrary.data()); 1162 } 1163 1164 DequantMatrices::DequantMatrices() { 1165 encodings_.resize(static_cast<size_t>(QuantTable::kNum), 1166 QuantEncoding::Library(0)); 1167 size_t pos = 0; 1168 size_t offsets[kNum * 3]; 1169 for (size_t i = 0; i < static_cast<size_t>(QuantTable::kNum); i++) { 1170 size_t num = required_size_x[i] * required_size_y[i] * kDCTBlockSize; 1171 for (size_t c = 0; c < 3; c++) { 1172 offsets[3 * i + c] = pos + c * num; 1173 } 1174 pos += 3 * num; 1175 } 1176 for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { 1177 for (size_t c = 0; c < 3; c++) { 1178 table_offsets_[i * 3 + c] = offsets[kQuantTable[i] * 3 + c]; 1179 } 1180 } 1181 } 1182 1183 Status DequantMatrices::EnsureComputed(uint32_t acs_mask) { 1184 const QuantEncoding* library = Library(); 1185 1186 if (!table_storage_) { 1187 table_storage_ = hwy::AllocateAligned<float>(2 * kTotalTableSize); 1188 table_ = table_storage_.get(); 1189 inv_table_ = table_storage_.get() + kTotalTableSize; 1190 } 1191 1192 size_t offsets[kNum * 3 + 1]; 1193 size_t pos = 0; 1194 for (size_t i = 0; i < kNum; i++) { 1195 size_t num = required_size_x[i] * required_size_y[i] * kDCTBlockSize; 1196 for (size_t c = 0; c < 3; c++) { 1197 offsets[3 * i + c] = pos + c * num; 1198 } 1199 pos += 3 * num; 1200 } 1201 offsets[kNum * 3] = pos; 1202 JXL_ASSERT(pos == kTotalTableSize); 1203 1204 uint32_t kind_mask = 0; 1205 for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { 1206 if (acs_mask & (1u << i)) { 1207 kind_mask |= 1u << kQuantTable[i]; 1208 } 1209 } 1210 uint32_t computed_kind_mask = 0; 1211 for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) { 1212 if (computed_mask_ & (1u << i)) { 1213 computed_kind_mask |= 1u << kQuantTable[i]; 1214 } 1215 } 1216 for (size_t table = 0; table < kNum; table++) { 1217 if ((1 << table) & computed_kind_mask) continue; 1218 if ((1 << table) & ~kind_mask) continue; 1219 size_t pos = offsets[table * 3]; 1220 if (encodings_[table].mode == QuantEncoding::kQuantModeLibrary) { 1221 JXL_CHECK(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( 1222 library[table], table_storage_.get(), 1223 table_storage_.get() + kTotalTableSize, table, QuantTable(table), 1224 &pos)); 1225 } else { 1226 JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)( 1227 encodings_[table], table_storage_.get(), 1228 table_storage_.get() + kTotalTableSize, table, QuantTable(table), 1229 &pos)); 1230 } 1231 JXL_ASSERT(pos == offsets[table * 3 + 3]); 1232 } 1233 computed_mask_ |= acs_mask; 1234 1235 return true; 1236 } 1237 1238 } // namespace jxl 1239 #endif