libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_quant_weights.cc (6961B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_quant_weights.h"
      7 
      8 #include <jxl/types.h>
      9 #include <stdlib.h>
     10 
     11 #include <cmath>
     12 
     13 #include "lib/jxl/base/common.h"
     14 #include "lib/jxl/base/status.h"
     15 #include "lib/jxl/enc_aux_out.h"
     16 #include "lib/jxl/enc_bit_writer.h"
     17 #include "lib/jxl/enc_modular.h"
     18 #include "lib/jxl/fields.h"
     19 #include "lib/jxl/modular/encoding/encoding.h"
     20 
     21 namespace jxl {
     22 
     23 struct AuxOut;
     24 
     25 namespace {
     26 
     27 Status EncodeDctParams(const DctQuantWeightParams& params, BitWriter* writer) {
     28   JXL_ASSERT(params.num_distance_bands >= 1);
     29   writer->Write(DctQuantWeightParams::kLog2MaxDistanceBands,
     30                 params.num_distance_bands - 1);
     31   for (size_t c = 0; c < 3; c++) {
     32     for (size_t i = 0; i < params.num_distance_bands; i++) {
     33       JXL_RETURN_IF_ERROR(F16Coder::Write(
     34           params.distance_bands[c][i] * (i == 0 ? (1 / 64.0f) : 1.0f), writer));
     35     }
     36   }
     37   return true;
     38 }
     39 
     40 Status EncodeQuant(const QuantEncoding& encoding, size_t idx, size_t size_x,
     41                    size_t size_y, BitWriter* writer,
     42                    ModularFrameEncoder* modular_frame_encoder) {
     43   writer->Write(kLog2NumQuantModes, encoding.mode);
     44   size_x *= kBlockDim;
     45   size_y *= kBlockDim;
     46   switch (encoding.mode) {
     47     case QuantEncoding::kQuantModeLibrary: {
     48       writer->Write(kCeilLog2NumPredefinedTables, encoding.predefined);
     49       break;
     50     }
     51     case QuantEncoding::kQuantModeID: {
     52       for (size_t c = 0; c < 3; c++) {
     53         for (size_t i = 0; i < 3; i++) {
     54           JXL_RETURN_IF_ERROR(
     55               F16Coder::Write(encoding.idweights[c][i] * (1.0f / 64), writer));
     56         }
     57       }
     58       break;
     59     }
     60     case QuantEncoding::kQuantModeDCT2: {
     61       for (size_t c = 0; c < 3; c++) {
     62         for (size_t i = 0; i < 6; i++) {
     63           JXL_RETURN_IF_ERROR(F16Coder::Write(
     64               encoding.dct2weights[c][i] * (1.0f / 64), writer));
     65         }
     66       }
     67       break;
     68     }
     69     case QuantEncoding::kQuantModeDCT4X8: {
     70       for (size_t c = 0; c < 3; c++) {
     71         JXL_RETURN_IF_ERROR(
     72             F16Coder::Write(encoding.dct4x8multipliers[c], writer));
     73       }
     74       JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer));
     75       break;
     76     }
     77     case QuantEncoding::kQuantModeDCT4: {
     78       for (size_t c = 0; c < 3; c++) {
     79         for (size_t i = 0; i < 2; i++) {
     80           JXL_RETURN_IF_ERROR(
     81               F16Coder::Write(encoding.dct4multipliers[c][i], writer));
     82         }
     83       }
     84       JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer));
     85       break;
     86     }
     87     case QuantEncoding::kQuantModeDCT: {
     88       JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer));
     89       break;
     90     }
     91     case QuantEncoding::kQuantModeRAW: {
     92       JXL_RETURN_IF_ERROR(ModularFrameEncoder::EncodeQuantTable(
     93           size_x, size_y, writer, encoding, idx, modular_frame_encoder));
     94       break;
     95     }
     96     case QuantEncoding::kQuantModeAFV: {
     97       for (size_t c = 0; c < 3; c++) {
     98         for (size_t i = 0; i < 9; i++) {
     99           JXL_RETURN_IF_ERROR(F16Coder::Write(
    100               encoding.afv_weights[c][i] * (i < 6 ? 1.0f / 64 : 1.0f), writer));
    101         }
    102       }
    103       JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer));
    104       JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params_afv_4x4, writer));
    105       break;
    106     }
    107   }
    108   return true;
    109 }
    110 
    111 }  // namespace
    112 
    113 Status DequantMatricesEncode(const DequantMatrices& matrices, BitWriter* writer,
    114                              size_t layer, AuxOut* aux_out,
    115                              ModularFrameEncoder* modular_frame_encoder) {
    116   bool all_default = true;
    117   const std::vector<QuantEncoding>& encodings = matrices.encodings();
    118 
    119   for (size_t i = 0; i < encodings.size(); i++) {
    120     if (encodings[i].mode != QuantEncoding::kQuantModeLibrary ||
    121         encodings[i].predefined != 0) {
    122       all_default = false;
    123     }
    124   }
    125   // TODO(janwas): better bound
    126   BitWriter::Allotment allotment(writer, 512 * 1024);
    127   writer->Write(1, TO_JXL_BOOL(all_default));
    128   if (!all_default) {
    129     for (size_t i = 0; i < encodings.size(); i++) {
    130       JXL_RETURN_IF_ERROR(EncodeQuant(
    131           encodings[i], i, DequantMatrices::required_size_x[i],
    132           DequantMatrices::required_size_y[i], writer, modular_frame_encoder));
    133     }
    134   }
    135   allotment.ReclaimAndCharge(writer, layer, aux_out);
    136   return true;
    137 }
    138 
    139 Status DequantMatricesEncodeDC(const DequantMatrices& matrices,
    140                                BitWriter* writer, size_t layer,
    141                                AuxOut* aux_out) {
    142   bool all_default = true;
    143   const float* dc_quant = matrices.DCQuants();
    144   for (size_t c = 0; c < 3; c++) {
    145     if (dc_quant[c] != kDCQuant[c]) {
    146       all_default = false;
    147     }
    148   }
    149   BitWriter::Allotment allotment(writer, 1 + sizeof(float) * kBitsPerByte * 3);
    150   writer->Write(1, TO_JXL_BOOL(all_default));
    151   if (!all_default) {
    152     for (size_t c = 0; c < 3; c++) {
    153       JXL_RETURN_IF_ERROR(F16Coder::Write(dc_quant[c] * 128.0f, writer));
    154     }
    155   }
    156   allotment.ReclaimAndCharge(writer, layer, aux_out);
    157   return true;
    158 }
    159 
    160 void DequantMatricesSetCustomDC(DequantMatrices* matrices, const float* dc) {
    161   matrices->SetDCQuant(dc);
    162   // Roundtrip encode/decode DC to ensure same values as decoder.
    163   BitWriter writer;
    164   JXL_CHECK(DequantMatricesEncodeDC(*matrices, &writer, 0, nullptr));
    165   writer.ZeroPadToByte();
    166   BitReader br(writer.GetSpan());
    167   // Called only in the encoder: should fail only for programmer errors.
    168   JXL_CHECK(matrices->DecodeDC(&br));
    169   JXL_CHECK(br.Close());
    170 }
    171 
    172 void DequantMatricesScaleDC(DequantMatrices* matrices, const float scale) {
    173   float dc[3];
    174   for (size_t c = 0; c < 3; ++c) {
    175     dc[c] = matrices->InvDCQuant(c) * (1.0f / scale);
    176   }
    177   DequantMatricesSetCustomDC(matrices, dc);
    178 }
    179 
    180 void DequantMatricesRoundtrip(DequantMatrices* matrices) {
    181   // Do not pass modular en/decoder, as they only change entropy and not
    182   // values.
    183   BitWriter writer;
    184   JXL_CHECK(DequantMatricesEncode(*matrices, &writer, 0, nullptr));
    185   writer.ZeroPadToByte();
    186   BitReader br(writer.GetSpan());
    187   // Called only in the encoder: should fail only for programmer errors.
    188   JXL_CHECK(matrices->Decode(&br));
    189   JXL_CHECK(br.Close());
    190 }
    191 
    192 Status DequantMatricesSetCustom(DequantMatrices* matrices,
    193                                 const std::vector<QuantEncoding>& encodings,
    194                                 ModularFrameEncoder* encoder) {
    195   JXL_ASSERT(encodings.size() == DequantMatrices::kNum);
    196   matrices->SetEncodings(encodings);
    197   for (size_t i = 0; i < encodings.size(); i++) {
    198     if (encodings[i].mode == QuantEncodingInternal::kQuantModeRAW) {
    199       JXL_RETURN_IF_ERROR(encoder->AddQuantTable(
    200           DequantMatrices::required_size_x[i] * kBlockDim,
    201           DequantMatrices::required_size_y[i] * kBlockDim, encodings[i], i));
    202     }
    203   }
    204   DequantMatricesRoundtrip(matrices);
    205   return true;
    206 }
    207 
    208 }  // namespace jxl