quant_weights.h (15658B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #ifndef LIB_JXL_QUANT_WEIGHTS_H_ 7 #define LIB_JXL_QUANT_WEIGHTS_H_ 8 9 #include <stdint.h> 10 #include <string.h> 11 12 #include <array> 13 #include <hwy/aligned_allocator.h> 14 #include <utility> 15 #include <vector> 16 17 #include "lib/jxl/ac_strategy.h" 18 #include "lib/jxl/base/common.h" 19 #include "lib/jxl/base/compiler_specific.h" 20 #include "lib/jxl/base/span.h" 21 #include "lib/jxl/base/status.h" 22 #include "lib/jxl/dec_bit_reader.h" 23 #include "lib/jxl/image.h" 24 25 namespace jxl { 26 27 static constexpr size_t kMaxQuantTableSize = AcStrategy::kMaxCoeffArea; 28 static constexpr size_t kNumPredefinedTables = 1; 29 static constexpr size_t kCeilLog2NumPredefinedTables = 0; 30 static constexpr size_t kLog2NumQuantModes = 3; 31 32 struct DctQuantWeightParams { 33 static constexpr size_t kLog2MaxDistanceBands = 4; 34 static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands); 35 typedef std::array<std::array<float, kMaxDistanceBands>, 3> 36 DistanceBandsArray; 37 38 size_t num_distance_bands = 0; 39 DistanceBandsArray distance_bands = {}; 40 41 constexpr DctQuantWeightParams() : num_distance_bands(0) {} 42 43 constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands, 44 size_t num_dist_bands) 45 : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {} 46 47 template <size_t num_dist_bands> 48 explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) { 49 num_distance_bands = num_dist_bands; 50 for (size_t c = 0; c < 3; c++) { 51 memcpy(distance_bands[c].data(), dist_bands[c], 52 sizeof(float) * num_dist_bands); 53 } 54 } 55 }; 56 57 // NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) 58 struct QuantEncodingInternal { 59 enum Mode { 60 kQuantModeLibrary, 61 kQuantModeID, 62 kQuantModeDCT2, 63 kQuantModeDCT4, 64 kQuantModeDCT4X8, 65 kQuantModeAFV, 66 kQuantModeDCT, 67 kQuantModeRAW, 68 }; 69 70 template <Mode mode> 71 struct Tag {}; 72 73 typedef std::array<std::array<float, 3>, 3> IdWeights; 74 typedef std::array<std::array<float, 6>, 3> DCT2Weights; 75 typedef std::array<std::array<float, 2>, 3> DCT4Multipliers; 76 typedef std::array<std::array<float, 9>, 3> AFVWeights; 77 typedef std::array<float, 3> DCT4x8Multipliers; 78 79 static constexpr QuantEncodingInternal Library(uint8_t predefined) { 80 return ((predefined < kNumPredefinedTables) || 81 JXL_ABORT("Assert predefined < kNumPredefinedTables")), 82 QuantEncodingInternal(Tag<kQuantModeLibrary>(), predefined); 83 } 84 constexpr QuantEncodingInternal(Tag<kQuantModeLibrary> /* tag */, 85 uint8_t predefined) 86 : mode(kQuantModeLibrary), predefined(predefined) {} 87 88 // Identity 89 // xybweights is an array of {xweights, yweights, bweights}. 90 static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) { 91 return QuantEncodingInternal(Tag<kQuantModeID>(), xybweights); 92 } 93 constexpr QuantEncodingInternal(Tag<kQuantModeID> /* tag */, 94 const IdWeights& xybweights) 95 : mode(kQuantModeID), idweights(xybweights) {} 96 97 // DCT2 98 static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) { 99 return QuantEncodingInternal(Tag<kQuantModeDCT2>(), xybweights); 100 } 101 constexpr QuantEncodingInternal(Tag<kQuantModeDCT2> /* tag */, 102 const DCT2Weights& xybweights) 103 : mode(kQuantModeDCT2), dct2weights(xybweights) {} 104 105 // DCT4 106 static constexpr QuantEncodingInternal DCT4( 107 const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) { 108 return QuantEncodingInternal(Tag<kQuantModeDCT4>(), params, xybmul); 109 } 110 constexpr QuantEncodingInternal(Tag<kQuantModeDCT4> /* tag */, 111 const DctQuantWeightParams& params, 112 const DCT4Multipliers& xybmul) 113 : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {} 114 115 // DCT4x8 116 static constexpr QuantEncodingInternal DCT4X8( 117 const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) { 118 return QuantEncodingInternal(Tag<kQuantModeDCT4X8>(), params, xybmul); 119 } 120 constexpr QuantEncodingInternal(Tag<kQuantModeDCT4X8> /* tag */, 121 const DctQuantWeightParams& params, 122 const DCT4x8Multipliers& xybmul) 123 : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {} 124 125 // DCT 126 static constexpr QuantEncodingInternal DCT( 127 const DctQuantWeightParams& params) { 128 return QuantEncodingInternal(Tag<kQuantModeDCT>(), params); 129 } 130 constexpr QuantEncodingInternal(Tag<kQuantModeDCT> /* tag */, 131 const DctQuantWeightParams& params) 132 : mode(kQuantModeDCT), dct_params(params) {} 133 134 // AFV 135 static constexpr QuantEncodingInternal AFV( 136 const DctQuantWeightParams& params4x8, 137 const DctQuantWeightParams& params4x4, const AFVWeights& weights) { 138 return QuantEncodingInternal(Tag<kQuantModeAFV>(), params4x8, params4x4, 139 weights); 140 } 141 constexpr QuantEncodingInternal(Tag<kQuantModeAFV> /* tag */, 142 const DctQuantWeightParams& params4x8, 143 const DctQuantWeightParams& params4x4, 144 const AFVWeights& weights) 145 : mode(kQuantModeAFV), 146 dct_params(params4x8), 147 afv_weights(weights), 148 dct_params_afv_4x4(params4x4) {} 149 150 // This constructor is not constexpr so it can't be used in any of the 151 // constexpr cases above. 152 explicit QuantEncodingInternal(Mode mode) : mode(mode) {} 153 154 Mode mode; 155 156 // Weights for DCT4+ tables. 157 DctQuantWeightParams dct_params; 158 159 union { 160 // Weights for identity. 161 IdWeights idweights; 162 163 // Weights for DCT2. 164 DCT2Weights dct2weights; 165 166 // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV. 167 DCT4Multipliers dct4multipliers; 168 169 // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1, 170 // 0); {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) + 171 // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated 172 // as in GetQuantWeights for DC and are used for other coefficients. 173 AFVWeights afv_weights = {}; 174 175 // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4. 176 DCT4x8Multipliers dct4x8multipliers; 177 178 // Only used in kQuantModeRAW mode. 179 struct { 180 // explicit quantization table (like in JPEG) 181 std::vector<int>* qtable = nullptr; 182 float qtable_den = 1.f / (8 * 255); 183 } qraw; 184 }; 185 186 // Weights for 4x4 sub-block in AFV. 187 DctQuantWeightParams dct_params_afv_4x4; 188 189 union { 190 // Which predefined table to use. Only used if mode is kQuantModeLibrary. 191 uint8_t predefined = 0; 192 193 // Which other quant table to copy; must copy from a table that comes before 194 // the current one. Only used if mode is kQuantModeCopy. 195 uint8_t source; 196 }; 197 }; 198 199 class QuantEncoding final : public QuantEncodingInternal { 200 public: 201 QuantEncoding(const QuantEncoding& other) 202 : QuantEncodingInternal( 203 static_cast<const QuantEncodingInternal&>(other)) { 204 if (mode == kQuantModeRAW && qraw.qtable) { 205 // Need to make a copy of the passed *qtable. 206 qraw.qtable = new std::vector<int>(*other.qraw.qtable); 207 } 208 } 209 QuantEncoding(QuantEncoding&& other) noexcept 210 : QuantEncodingInternal( 211 static_cast<const QuantEncodingInternal&>(other)) { 212 // Steal the qtable from the other object if any. 213 if (mode == kQuantModeRAW) { 214 other.qraw.qtable = nullptr; 215 } 216 } 217 QuantEncoding& operator=(const QuantEncoding& other) { 218 if (mode == kQuantModeRAW && qraw.qtable) { 219 delete qraw.qtable; 220 } 221 *static_cast<QuantEncodingInternal*>(this) = 222 QuantEncodingInternal(static_cast<const QuantEncodingInternal&>(other)); 223 if (mode == kQuantModeRAW && qraw.qtable) { 224 // Need to make a copy of the passed *qtable. 225 qraw.qtable = new std::vector<int>(*other.qraw.qtable); 226 } 227 return *this; 228 } 229 230 ~QuantEncoding() { 231 if (mode == kQuantModeRAW && qraw.qtable) { 232 delete qraw.qtable; 233 } 234 } 235 236 // Wrappers of the QuantEncodingInternal:: static functions that return a 237 // QuantEncoding instead. This is using the explicit and private cast from 238 // QuantEncodingInternal to QuantEncoding, which would be inlined anyway. 239 // In general, you should use this wrappers. The only reason to directly 240 // create a QuantEncodingInternal instance is if you need a constexpr version 241 // of this class. Note that RAW() is not supported in that case since it uses 242 // a std::vector. 243 static QuantEncoding Library(uint8_t predefined_arg) { 244 return QuantEncoding(QuantEncodingInternal::Library(predefined_arg)); 245 } 246 static QuantEncoding Identity(const IdWeights& xybweights) { 247 return QuantEncoding(QuantEncodingInternal::Identity(xybweights)); 248 } 249 static QuantEncoding DCT2(const DCT2Weights& xybweights) { 250 return QuantEncoding(QuantEncodingInternal::DCT2(xybweights)); 251 } 252 static QuantEncoding DCT4(const DctQuantWeightParams& params, 253 const DCT4Multipliers& xybmul) { 254 return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul)); 255 } 256 static QuantEncoding DCT4X8(const DctQuantWeightParams& params, 257 const DCT4x8Multipliers& xybmul) { 258 return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul)); 259 } 260 static QuantEncoding DCT(const DctQuantWeightParams& params) { 261 return QuantEncoding(QuantEncodingInternal::DCT(params)); 262 } 263 static QuantEncoding AFV(const DctQuantWeightParams& params4x8, 264 const DctQuantWeightParams& params4x4, 265 const AFVWeights& weights) { 266 return QuantEncoding( 267 QuantEncodingInternal::AFV(params4x8, params4x4, weights)); 268 } 269 270 // RAW, note that this one is not a constexpr one. 271 static QuantEncoding RAW(const std::vector<int>& qtable, int shift = 0) { 272 QuantEncoding encoding(kQuantModeRAW); 273 encoding.qraw.qtable = new std::vector<int>(); 274 *encoding.qraw.qtable = qtable; 275 encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255)); 276 return encoding; 277 } 278 279 private: 280 explicit QuantEncoding(const QuantEncodingInternal& other) 281 : QuantEncodingInternal(other) {} 282 283 explicit QuantEncoding(QuantEncodingInternal::Mode mode_arg) 284 : QuantEncodingInternal(mode_arg) {} 285 }; 286 287 // A constexpr QuantEncodingInternal instance is often downcasted to the 288 // QuantEncoding subclass even if the instance wasn't an instance of the 289 // subclass. This is safe because user will upcast to QuantEncodingInternal to 290 // access any of its members. 291 static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal), 292 "Don't add any members to QuantEncoding"); 293 294 // Let's try to keep these 2**N for possible future simplicity. 295 const float kInvDCQuant[3] = { 296 4096.0f, 297 512.0f, 298 256.0f, 299 }; 300 301 const float kDCQuant[3] = { 302 1.0f / kInvDCQuant[0], 303 1.0f / kInvDCQuant[1], 304 1.0f / kInvDCQuant[2], 305 }; 306 307 class ModularFrameEncoder; 308 class ModularFrameDecoder; 309 310 class DequantMatrices { 311 public: 312 enum QuantTable : size_t { 313 DCT = 0, 314 IDENTITY, 315 DCT2X2, 316 DCT4X4, 317 DCT16X16, 318 DCT32X32, 319 // DCT16X8 320 DCT8X16, 321 // DCT32X8 322 DCT8X32, 323 // DCT32X16 324 DCT16X32, 325 DCT4X8, 326 // DCT8X4 327 AFV0, 328 // AFV1 329 // AFV2 330 // AFV3 331 DCT64X64, 332 // DCT64X32, 333 DCT32X64, 334 DCT128X128, 335 // DCT128X64, 336 DCT64X128, 337 DCT256X256, 338 // DCT256X128, 339 DCT128X256, 340 kNum 341 }; 342 343 static constexpr QuantTable kQuantTable[] = { 344 QuantTable::DCT, QuantTable::IDENTITY, QuantTable::DCT2X2, 345 QuantTable::DCT4X4, QuantTable::DCT16X16, QuantTable::DCT32X32, 346 QuantTable::DCT8X16, QuantTable::DCT8X16, QuantTable::DCT8X32, 347 QuantTable::DCT8X32, QuantTable::DCT16X32, QuantTable::DCT16X32, 348 QuantTable::DCT4X8, QuantTable::DCT4X8, QuantTable::AFV0, 349 QuantTable::AFV0, QuantTable::AFV0, QuantTable::AFV0, 350 QuantTable::DCT64X64, QuantTable::DCT32X64, QuantTable::DCT32X64, 351 QuantTable::DCT128X128, QuantTable::DCT64X128, QuantTable::DCT64X128, 352 QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256, 353 }; 354 static_assert(AcStrategy::kNumValidStrategies == 355 sizeof(kQuantTable) / sizeof *kQuantTable, 356 "Update this array when adding or removing AC strategies."); 357 358 DequantMatrices(); 359 360 static const QuantEncoding* Library(); 361 362 typedef std::array<QuantEncodingInternal, kNumPredefinedTables * kNum> 363 DequantLibraryInternal; 364 // Return the array of library kNumPredefinedTables QuantEncoding entries as 365 // a constexpr array. Use Library() to obtain a pointer to the copy in the 366 // .cc file. 367 static DequantLibraryInternal LibraryInit(); 368 369 // Returns aligned memory. 370 JXL_INLINE const float* Matrix(size_t quant_kind, size_t c) const { 371 JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); 372 JXL_DASSERT((1 << quant_kind) & computed_mask_); 373 return &table_[table_offsets_[quant_kind * 3 + c]]; 374 } 375 376 JXL_INLINE const float* InvMatrix(size_t quant_kind, size_t c) const { 377 JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); 378 JXL_DASSERT((1 << quant_kind) & computed_mask_); 379 return &inv_table_[table_offsets_[quant_kind * 3 + c]]; 380 } 381 382 // DC quants are used in modular mode for XYB multipliers. 383 JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; } 384 JXL_INLINE const float* DCQuants() const { return dc_quant_; } 385 386 JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; } 387 388 // For encoder. 389 void SetEncodings(const std::vector<QuantEncoding>& encodings) { 390 encodings_ = encodings; 391 computed_mask_ = 0; 392 } 393 394 // For encoder. 395 void SetDCQuant(const float dc[3]) { 396 for (size_t c = 0; c < 3; c++) { 397 dc_quant_[c] = 1.0f / dc[c]; 398 inv_dc_quant_[c] = dc[c]; 399 } 400 } 401 402 Status Decode(BitReader* br, 403 ModularFrameDecoder* modular_frame_decoder = nullptr); 404 Status DecodeDC(BitReader* br); 405 406 const std::vector<QuantEncoding>& encodings() const { return encodings_; } 407 408 static constexpr auto required_size_x = 409 to_array<int>({1, 1, 1, 1, 2, 4, 1, 1, 2, 1, 1, 8, 4, 16, 8, 32, 16}); 410 static_assert(kNum == required_size_x.size(), 411 "Update this array when adding or removing quant tables."); 412 413 static constexpr auto required_size_y = 414 to_array<int>({1, 1, 1, 1, 2, 4, 2, 4, 4, 1, 1, 8, 8, 16, 16, 32, 32}); 415 static_assert(kNum == required_size_y.size(), 416 "Update this array when adding or removing quant tables."); 417 418 // MUST be equal `sum(dot(required_size_x, required_size_y))`. 419 static constexpr size_t kSumRequiredXy = 2056; 420 421 Status EnsureComputed(uint32_t acs_mask); 422 423 private: 424 static constexpr size_t kTotalTableSize = kSumRequiredXy * kDCTBlockSize * 3; 425 426 uint32_t computed_mask_ = 0; 427 // kTotalTableSize entries followed by kTotalTableSize for inv_table 428 hwy::AlignedFreeUniquePtr<float[]> table_storage_; 429 const float* table_; 430 const float* inv_table_; 431 float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]}; 432 float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]}; 433 size_t table_offsets_[AcStrategy::kNumValidStrategies * 3]; 434 std::vector<QuantEncoding> encodings_; 435 }; 436 437 } // namespace jxl 438 439 #endif // LIB_JXL_QUANT_WEIGHTS_H_