libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

quant_weights.h (15658B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #ifndef LIB_JXL_QUANT_WEIGHTS_H_
      7 #define LIB_JXL_QUANT_WEIGHTS_H_
      8 
      9 #include <stdint.h>
     10 #include <string.h>
     11 
     12 #include <array>
     13 #include <hwy/aligned_allocator.h>
     14 #include <utility>
     15 #include <vector>
     16 
     17 #include "lib/jxl/ac_strategy.h"
     18 #include "lib/jxl/base/common.h"
     19 #include "lib/jxl/base/compiler_specific.h"
     20 #include "lib/jxl/base/span.h"
     21 #include "lib/jxl/base/status.h"
     22 #include "lib/jxl/dec_bit_reader.h"
     23 #include "lib/jxl/image.h"
     24 
     25 namespace jxl {
     26 
     27 static constexpr size_t kMaxQuantTableSize = AcStrategy::kMaxCoeffArea;
     28 static constexpr size_t kNumPredefinedTables = 1;
     29 static constexpr size_t kCeilLog2NumPredefinedTables = 0;
     30 static constexpr size_t kLog2NumQuantModes = 3;
     31 
     32 struct DctQuantWeightParams {
     33   static constexpr size_t kLog2MaxDistanceBands = 4;
     34   static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands);
     35   typedef std::array<std::array<float, kMaxDistanceBands>, 3>
     36       DistanceBandsArray;
     37 
     38   size_t num_distance_bands = 0;
     39   DistanceBandsArray distance_bands = {};
     40 
     41   constexpr DctQuantWeightParams() : num_distance_bands(0) {}
     42 
     43   constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands,
     44                                  size_t num_dist_bands)
     45       : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {}
     46 
     47   template <size_t num_dist_bands>
     48   explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) {
     49     num_distance_bands = num_dist_bands;
     50     for (size_t c = 0; c < 3; c++) {
     51       memcpy(distance_bands[c].data(), dist_bands[c],
     52              sizeof(float) * num_dist_bands);
     53     }
     54   }
     55 };
     56 
     57 // NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding)
     58 struct QuantEncodingInternal {
     59   enum Mode {
     60     kQuantModeLibrary,
     61     kQuantModeID,
     62     kQuantModeDCT2,
     63     kQuantModeDCT4,
     64     kQuantModeDCT4X8,
     65     kQuantModeAFV,
     66     kQuantModeDCT,
     67     kQuantModeRAW,
     68   };
     69 
     70   template <Mode mode>
     71   struct Tag {};
     72 
     73   typedef std::array<std::array<float, 3>, 3> IdWeights;
     74   typedef std::array<std::array<float, 6>, 3> DCT2Weights;
     75   typedef std::array<std::array<float, 2>, 3> DCT4Multipliers;
     76   typedef std::array<std::array<float, 9>, 3> AFVWeights;
     77   typedef std::array<float, 3> DCT4x8Multipliers;
     78 
     79   static constexpr QuantEncodingInternal Library(uint8_t predefined) {
     80     return ((predefined < kNumPredefinedTables) ||
     81             JXL_ABORT("Assert predefined < kNumPredefinedTables")),
     82            QuantEncodingInternal(Tag<kQuantModeLibrary>(), predefined);
     83   }
     84   constexpr QuantEncodingInternal(Tag<kQuantModeLibrary> /* tag */,
     85                                   uint8_t predefined)
     86       : mode(kQuantModeLibrary), predefined(predefined) {}
     87 
     88   // Identity
     89   // xybweights is an array of {xweights, yweights, bweights}.
     90   static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) {
     91     return QuantEncodingInternal(Tag<kQuantModeID>(), xybweights);
     92   }
     93   constexpr QuantEncodingInternal(Tag<kQuantModeID> /* tag */,
     94                                   const IdWeights& xybweights)
     95       : mode(kQuantModeID), idweights(xybweights) {}
     96 
     97   // DCT2
     98   static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) {
     99     return QuantEncodingInternal(Tag<kQuantModeDCT2>(), xybweights);
    100   }
    101   constexpr QuantEncodingInternal(Tag<kQuantModeDCT2> /* tag */,
    102                                   const DCT2Weights& xybweights)
    103       : mode(kQuantModeDCT2), dct2weights(xybweights) {}
    104 
    105   // DCT4
    106   static constexpr QuantEncodingInternal DCT4(
    107       const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) {
    108     return QuantEncodingInternal(Tag<kQuantModeDCT4>(), params, xybmul);
    109   }
    110   constexpr QuantEncodingInternal(Tag<kQuantModeDCT4> /* tag */,
    111                                   const DctQuantWeightParams& params,
    112                                   const DCT4Multipliers& xybmul)
    113       : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {}
    114 
    115   // DCT4x8
    116   static constexpr QuantEncodingInternal DCT4X8(
    117       const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) {
    118     return QuantEncodingInternal(Tag<kQuantModeDCT4X8>(), params, xybmul);
    119   }
    120   constexpr QuantEncodingInternal(Tag<kQuantModeDCT4X8> /* tag */,
    121                                   const DctQuantWeightParams& params,
    122                                   const DCT4x8Multipliers& xybmul)
    123       : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {}
    124 
    125   // DCT
    126   static constexpr QuantEncodingInternal DCT(
    127       const DctQuantWeightParams& params) {
    128     return QuantEncodingInternal(Tag<kQuantModeDCT>(), params);
    129   }
    130   constexpr QuantEncodingInternal(Tag<kQuantModeDCT> /* tag */,
    131                                   const DctQuantWeightParams& params)
    132       : mode(kQuantModeDCT), dct_params(params) {}
    133 
    134   // AFV
    135   static constexpr QuantEncodingInternal AFV(
    136       const DctQuantWeightParams& params4x8,
    137       const DctQuantWeightParams& params4x4, const AFVWeights& weights) {
    138     return QuantEncodingInternal(Tag<kQuantModeAFV>(), params4x8, params4x4,
    139                                  weights);
    140   }
    141   constexpr QuantEncodingInternal(Tag<kQuantModeAFV> /* tag */,
    142                                   const DctQuantWeightParams& params4x8,
    143                                   const DctQuantWeightParams& params4x4,
    144                                   const AFVWeights& weights)
    145       : mode(kQuantModeAFV),
    146         dct_params(params4x8),
    147         afv_weights(weights),
    148         dct_params_afv_4x4(params4x4) {}
    149 
    150   // This constructor is not constexpr so it can't be used in any of the
    151   // constexpr cases above.
    152   explicit QuantEncodingInternal(Mode mode) : mode(mode) {}
    153 
    154   Mode mode;
    155 
    156   // Weights for DCT4+ tables.
    157   DctQuantWeightParams dct_params;
    158 
    159   union {
    160     // Weights for identity.
    161     IdWeights idweights;
    162 
    163     // Weights for DCT2.
    164     DCT2Weights dct2weights;
    165 
    166     // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV.
    167     DCT4Multipliers dct4multipliers;
    168 
    169     // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1,
    170     // 0);  {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) +
    171     // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated
    172     // as in GetQuantWeights for DC and are used for other coefficients.
    173     AFVWeights afv_weights = {};
    174 
    175     // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4.
    176     DCT4x8Multipliers dct4x8multipliers;
    177 
    178     // Only used in kQuantModeRAW mode.
    179     struct {
    180       // explicit quantization table (like in JPEG)
    181       std::vector<int>* qtable = nullptr;
    182       float qtable_den = 1.f / (8 * 255);
    183     } qraw;
    184   };
    185 
    186   // Weights for 4x4 sub-block in AFV.
    187   DctQuantWeightParams dct_params_afv_4x4;
    188 
    189   union {
    190     // Which predefined table to use. Only used if mode is kQuantModeLibrary.
    191     uint8_t predefined = 0;
    192 
    193     // Which other quant table to copy; must copy from a table that comes before
    194     // the current one. Only used if mode is kQuantModeCopy.
    195     uint8_t source;
    196   };
    197 };
    198 
    199 class QuantEncoding final : public QuantEncodingInternal {
    200  public:
    201   QuantEncoding(const QuantEncoding& other)
    202       : QuantEncodingInternal(
    203             static_cast<const QuantEncodingInternal&>(other)) {
    204     if (mode == kQuantModeRAW && qraw.qtable) {
    205       // Need to make a copy of the passed *qtable.
    206       qraw.qtable = new std::vector<int>(*other.qraw.qtable);
    207     }
    208   }
    209   QuantEncoding(QuantEncoding&& other) noexcept
    210       : QuantEncodingInternal(
    211             static_cast<const QuantEncodingInternal&>(other)) {
    212     // Steal the qtable from the other object if any.
    213     if (mode == kQuantModeRAW) {
    214       other.qraw.qtable = nullptr;
    215     }
    216   }
    217   QuantEncoding& operator=(const QuantEncoding& other) {
    218     if (mode == kQuantModeRAW && qraw.qtable) {
    219       delete qraw.qtable;
    220     }
    221     *static_cast<QuantEncodingInternal*>(this) =
    222         QuantEncodingInternal(static_cast<const QuantEncodingInternal&>(other));
    223     if (mode == kQuantModeRAW && qraw.qtable) {
    224       // Need to make a copy of the passed *qtable.
    225       qraw.qtable = new std::vector<int>(*other.qraw.qtable);
    226     }
    227     return *this;
    228   }
    229 
    230   ~QuantEncoding() {
    231     if (mode == kQuantModeRAW && qraw.qtable) {
    232       delete qraw.qtable;
    233     }
    234   }
    235 
    236   // Wrappers of the QuantEncodingInternal:: static functions that return a
    237   // QuantEncoding instead. This is using the explicit and private cast from
    238   // QuantEncodingInternal to QuantEncoding, which would be inlined anyway.
    239   // In general, you should use this wrappers. The only reason to directly
    240   // create a QuantEncodingInternal instance is if you need a constexpr version
    241   // of this class. Note that RAW() is not supported in that case since it uses
    242   // a std::vector.
    243   static QuantEncoding Library(uint8_t predefined_arg) {
    244     return QuantEncoding(QuantEncodingInternal::Library(predefined_arg));
    245   }
    246   static QuantEncoding Identity(const IdWeights& xybweights) {
    247     return QuantEncoding(QuantEncodingInternal::Identity(xybweights));
    248   }
    249   static QuantEncoding DCT2(const DCT2Weights& xybweights) {
    250     return QuantEncoding(QuantEncodingInternal::DCT2(xybweights));
    251   }
    252   static QuantEncoding DCT4(const DctQuantWeightParams& params,
    253                             const DCT4Multipliers& xybmul) {
    254     return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul));
    255   }
    256   static QuantEncoding DCT4X8(const DctQuantWeightParams& params,
    257                               const DCT4x8Multipliers& xybmul) {
    258     return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul));
    259   }
    260   static QuantEncoding DCT(const DctQuantWeightParams& params) {
    261     return QuantEncoding(QuantEncodingInternal::DCT(params));
    262   }
    263   static QuantEncoding AFV(const DctQuantWeightParams& params4x8,
    264                            const DctQuantWeightParams& params4x4,
    265                            const AFVWeights& weights) {
    266     return QuantEncoding(
    267         QuantEncodingInternal::AFV(params4x8, params4x4, weights));
    268   }
    269 
    270   // RAW, note that this one is not a constexpr one.
    271   static QuantEncoding RAW(const std::vector<int>& qtable, int shift = 0) {
    272     QuantEncoding encoding(kQuantModeRAW);
    273     encoding.qraw.qtable = new std::vector<int>();
    274     *encoding.qraw.qtable = qtable;
    275     encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255));
    276     return encoding;
    277   }
    278 
    279  private:
    280   explicit QuantEncoding(const QuantEncodingInternal& other)
    281       : QuantEncodingInternal(other) {}
    282 
    283   explicit QuantEncoding(QuantEncodingInternal::Mode mode_arg)
    284       : QuantEncodingInternal(mode_arg) {}
    285 };
    286 
    287 // A constexpr QuantEncodingInternal instance is often downcasted to the
    288 // QuantEncoding subclass even if the instance wasn't an instance of the
    289 // subclass. This is safe because user will upcast to QuantEncodingInternal to
    290 // access any of its members.
    291 static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal),
    292               "Don't add any members to QuantEncoding");
    293 
    294 // Let's try to keep these 2**N for possible future simplicity.
    295 const float kInvDCQuant[3] = {
    296     4096.0f,
    297     512.0f,
    298     256.0f,
    299 };
    300 
    301 const float kDCQuant[3] = {
    302     1.0f / kInvDCQuant[0],
    303     1.0f / kInvDCQuant[1],
    304     1.0f / kInvDCQuant[2],
    305 };
    306 
    307 class ModularFrameEncoder;
    308 class ModularFrameDecoder;
    309 
    310 class DequantMatrices {
    311  public:
    312   enum QuantTable : size_t {
    313     DCT = 0,
    314     IDENTITY,
    315     DCT2X2,
    316     DCT4X4,
    317     DCT16X16,
    318     DCT32X32,
    319     // DCT16X8
    320     DCT8X16,
    321     // DCT32X8
    322     DCT8X32,
    323     // DCT32X16
    324     DCT16X32,
    325     DCT4X8,
    326     // DCT8X4
    327     AFV0,
    328     // AFV1
    329     // AFV2
    330     // AFV3
    331     DCT64X64,
    332     // DCT64X32,
    333     DCT32X64,
    334     DCT128X128,
    335     // DCT128X64,
    336     DCT64X128,
    337     DCT256X256,
    338     // DCT256X128,
    339     DCT128X256,
    340     kNum
    341   };
    342 
    343   static constexpr QuantTable kQuantTable[] = {
    344       QuantTable::DCT,        QuantTable::IDENTITY,   QuantTable::DCT2X2,
    345       QuantTable::DCT4X4,     QuantTable::DCT16X16,   QuantTable::DCT32X32,
    346       QuantTable::DCT8X16,    QuantTable::DCT8X16,    QuantTable::DCT8X32,
    347       QuantTable::DCT8X32,    QuantTable::DCT16X32,   QuantTable::DCT16X32,
    348       QuantTable::DCT4X8,     QuantTable::DCT4X8,     QuantTable::AFV0,
    349       QuantTable::AFV0,       QuantTable::AFV0,       QuantTable::AFV0,
    350       QuantTable::DCT64X64,   QuantTable::DCT32X64,   QuantTable::DCT32X64,
    351       QuantTable::DCT128X128, QuantTable::DCT64X128,  QuantTable::DCT64X128,
    352       QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256,
    353   };
    354   static_assert(AcStrategy::kNumValidStrategies ==
    355                     sizeof(kQuantTable) / sizeof *kQuantTable,
    356                 "Update this array when adding or removing AC strategies.");
    357 
    358   DequantMatrices();
    359 
    360   static const QuantEncoding* Library();
    361 
    362   typedef std::array<QuantEncodingInternal, kNumPredefinedTables * kNum>
    363       DequantLibraryInternal;
    364   // Return the array of library kNumPredefinedTables QuantEncoding entries as
    365   // a constexpr array. Use Library() to obtain a pointer to the copy in the
    366   // .cc file.
    367   static DequantLibraryInternal LibraryInit();
    368 
    369   // Returns aligned memory.
    370   JXL_INLINE const float* Matrix(size_t quant_kind, size_t c) const {
    371     JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies);
    372     JXL_DASSERT((1 << quant_kind) & computed_mask_);
    373     return &table_[table_offsets_[quant_kind * 3 + c]];
    374   }
    375 
    376   JXL_INLINE const float* InvMatrix(size_t quant_kind, size_t c) const {
    377     JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies);
    378     JXL_DASSERT((1 << quant_kind) & computed_mask_);
    379     return &inv_table_[table_offsets_[quant_kind * 3 + c]];
    380   }
    381 
    382   // DC quants are used in modular mode for XYB multipliers.
    383   JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; }
    384   JXL_INLINE const float* DCQuants() const { return dc_quant_; }
    385 
    386   JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; }
    387 
    388   // For encoder.
    389   void SetEncodings(const std::vector<QuantEncoding>& encodings) {
    390     encodings_ = encodings;
    391     computed_mask_ = 0;
    392   }
    393 
    394   // For encoder.
    395   void SetDCQuant(const float dc[3]) {
    396     for (size_t c = 0; c < 3; c++) {
    397       dc_quant_[c] = 1.0f / dc[c];
    398       inv_dc_quant_[c] = dc[c];
    399     }
    400   }
    401 
    402   Status Decode(BitReader* br,
    403                 ModularFrameDecoder* modular_frame_decoder = nullptr);
    404   Status DecodeDC(BitReader* br);
    405 
    406   const std::vector<QuantEncoding>& encodings() const { return encodings_; }
    407 
    408   static constexpr auto required_size_x =
    409       to_array<int>({1, 1, 1, 1, 2, 4, 1, 1, 2, 1, 1, 8, 4, 16, 8, 32, 16});
    410   static_assert(kNum == required_size_x.size(),
    411                 "Update this array when adding or removing quant tables.");
    412 
    413   static constexpr auto required_size_y =
    414       to_array<int>({1, 1, 1, 1, 2, 4, 2, 4, 4, 1, 1, 8, 8, 16, 16, 32, 32});
    415   static_assert(kNum == required_size_y.size(),
    416                 "Update this array when adding or removing quant tables.");
    417 
    418   // MUST be equal `sum(dot(required_size_x, required_size_y))`.
    419   static constexpr size_t kSumRequiredXy = 2056;
    420 
    421   Status EnsureComputed(uint32_t acs_mask);
    422 
    423  private:
    424   static constexpr size_t kTotalTableSize = kSumRequiredXy * kDCTBlockSize * 3;
    425 
    426   uint32_t computed_mask_ = 0;
    427   // kTotalTableSize entries followed by kTotalTableSize for inv_table
    428   hwy::AlignedFreeUniquePtr<float[]> table_storage_;
    429   const float* table_;
    430   const float* inv_table_;
    431   float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]};
    432   float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]};
    433   size_t table_offsets_[AcStrategy::kNumValidStrategies * 3];
    434   std::vector<QuantEncoding> encodings_;
    435 };
    436 
    437 }  // namespace jxl
    438 
    439 #endif  // LIB_JXL_QUANT_WEIGHTS_H_