enc_quant_weights.cc (6961B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/enc_quant_weights.h" 7 8 #include <jxl/types.h> 9 #include <stdlib.h> 10 11 #include <cmath> 12 13 #include "lib/jxl/base/common.h" 14 #include "lib/jxl/base/status.h" 15 #include "lib/jxl/enc_aux_out.h" 16 #include "lib/jxl/enc_bit_writer.h" 17 #include "lib/jxl/enc_modular.h" 18 #include "lib/jxl/fields.h" 19 #include "lib/jxl/modular/encoding/encoding.h" 20 21 namespace jxl { 22 23 struct AuxOut; 24 25 namespace { 26 27 Status EncodeDctParams(const DctQuantWeightParams& params, BitWriter* writer) { 28 JXL_ASSERT(params.num_distance_bands >= 1); 29 writer->Write(DctQuantWeightParams::kLog2MaxDistanceBands, 30 params.num_distance_bands - 1); 31 for (size_t c = 0; c < 3; c++) { 32 for (size_t i = 0; i < params.num_distance_bands; i++) { 33 JXL_RETURN_IF_ERROR(F16Coder::Write( 34 params.distance_bands[c][i] * (i == 0 ? (1 / 64.0f) : 1.0f), writer)); 35 } 36 } 37 return true; 38 } 39 40 Status EncodeQuant(const QuantEncoding& encoding, size_t idx, size_t size_x, 41 size_t size_y, BitWriter* writer, 42 ModularFrameEncoder* modular_frame_encoder) { 43 writer->Write(kLog2NumQuantModes, encoding.mode); 44 size_x *= kBlockDim; 45 size_y *= kBlockDim; 46 switch (encoding.mode) { 47 case QuantEncoding::kQuantModeLibrary: { 48 writer->Write(kCeilLog2NumPredefinedTables, encoding.predefined); 49 break; 50 } 51 case QuantEncoding::kQuantModeID: { 52 for (size_t c = 0; c < 3; c++) { 53 for (size_t i = 0; i < 3; i++) { 54 JXL_RETURN_IF_ERROR( 55 F16Coder::Write(encoding.idweights[c][i] * (1.0f / 64), writer)); 56 } 57 } 58 break; 59 } 60 case QuantEncoding::kQuantModeDCT2: { 61 for (size_t c = 0; c < 3; c++) { 62 for (size_t i = 0; i < 6; i++) { 63 JXL_RETURN_IF_ERROR(F16Coder::Write( 64 encoding.dct2weights[c][i] * (1.0f / 64), writer)); 65 } 66 } 67 break; 68 } 69 case QuantEncoding::kQuantModeDCT4X8: { 70 for (size_t c = 0; c < 3; c++) { 71 JXL_RETURN_IF_ERROR( 72 F16Coder::Write(encoding.dct4x8multipliers[c], writer)); 73 } 74 JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); 75 break; 76 } 77 case QuantEncoding::kQuantModeDCT4: { 78 for (size_t c = 0; c < 3; c++) { 79 for (size_t i = 0; i < 2; i++) { 80 JXL_RETURN_IF_ERROR( 81 F16Coder::Write(encoding.dct4multipliers[c][i], writer)); 82 } 83 } 84 JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); 85 break; 86 } 87 case QuantEncoding::kQuantModeDCT: { 88 JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); 89 break; 90 } 91 case QuantEncoding::kQuantModeRAW: { 92 JXL_RETURN_IF_ERROR(ModularFrameEncoder::EncodeQuantTable( 93 size_x, size_y, writer, encoding, idx, modular_frame_encoder)); 94 break; 95 } 96 case QuantEncoding::kQuantModeAFV: { 97 for (size_t c = 0; c < 3; c++) { 98 for (size_t i = 0; i < 9; i++) { 99 JXL_RETURN_IF_ERROR(F16Coder::Write( 100 encoding.afv_weights[c][i] * (i < 6 ? 1.0f / 64 : 1.0f), writer)); 101 } 102 } 103 JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params, writer)); 104 JXL_RETURN_IF_ERROR(EncodeDctParams(encoding.dct_params_afv_4x4, writer)); 105 break; 106 } 107 } 108 return true; 109 } 110 111 } // namespace 112 113 Status DequantMatricesEncode(const DequantMatrices& matrices, BitWriter* writer, 114 size_t layer, AuxOut* aux_out, 115 ModularFrameEncoder* modular_frame_encoder) { 116 bool all_default = true; 117 const std::vector<QuantEncoding>& encodings = matrices.encodings(); 118 119 for (size_t i = 0; i < encodings.size(); i++) { 120 if (encodings[i].mode != QuantEncoding::kQuantModeLibrary || 121 encodings[i].predefined != 0) { 122 all_default = false; 123 } 124 } 125 // TODO(janwas): better bound 126 BitWriter::Allotment allotment(writer, 512 * 1024); 127 writer->Write(1, TO_JXL_BOOL(all_default)); 128 if (!all_default) { 129 for (size_t i = 0; i < encodings.size(); i++) { 130 JXL_RETURN_IF_ERROR(EncodeQuant( 131 encodings[i], i, DequantMatrices::required_size_x[i], 132 DequantMatrices::required_size_y[i], writer, modular_frame_encoder)); 133 } 134 } 135 allotment.ReclaimAndCharge(writer, layer, aux_out); 136 return true; 137 } 138 139 Status DequantMatricesEncodeDC(const DequantMatrices& matrices, 140 BitWriter* writer, size_t layer, 141 AuxOut* aux_out) { 142 bool all_default = true; 143 const float* dc_quant = matrices.DCQuants(); 144 for (size_t c = 0; c < 3; c++) { 145 if (dc_quant[c] != kDCQuant[c]) { 146 all_default = false; 147 } 148 } 149 BitWriter::Allotment allotment(writer, 1 + sizeof(float) * kBitsPerByte * 3); 150 writer->Write(1, TO_JXL_BOOL(all_default)); 151 if (!all_default) { 152 for (size_t c = 0; c < 3; c++) { 153 JXL_RETURN_IF_ERROR(F16Coder::Write(dc_quant[c] * 128.0f, writer)); 154 } 155 } 156 allotment.ReclaimAndCharge(writer, layer, aux_out); 157 return true; 158 } 159 160 void DequantMatricesSetCustomDC(DequantMatrices* matrices, const float* dc) { 161 matrices->SetDCQuant(dc); 162 // Roundtrip encode/decode DC to ensure same values as decoder. 163 BitWriter writer; 164 JXL_CHECK(DequantMatricesEncodeDC(*matrices, &writer, 0, nullptr)); 165 writer.ZeroPadToByte(); 166 BitReader br(writer.GetSpan()); 167 // Called only in the encoder: should fail only for programmer errors. 168 JXL_CHECK(matrices->DecodeDC(&br)); 169 JXL_CHECK(br.Close()); 170 } 171 172 void DequantMatricesScaleDC(DequantMatrices* matrices, const float scale) { 173 float dc[3]; 174 for (size_t c = 0; c < 3; ++c) { 175 dc[c] = matrices->InvDCQuant(c) * (1.0f / scale); 176 } 177 DequantMatricesSetCustomDC(matrices, dc); 178 } 179 180 void DequantMatricesRoundtrip(DequantMatrices* matrices) { 181 // Do not pass modular en/decoder, as they only change entropy and not 182 // values. 183 BitWriter writer; 184 JXL_CHECK(DequantMatricesEncode(*matrices, &writer, 0, nullptr)); 185 writer.ZeroPadToByte(); 186 BitReader br(writer.GetSpan()); 187 // Called only in the encoder: should fail only for programmer errors. 188 JXL_CHECK(matrices->Decode(&br)); 189 JXL_CHECK(br.Close()); 190 } 191 192 Status DequantMatricesSetCustom(DequantMatrices* matrices, 193 const std::vector<QuantEncoding>& encodings, 194 ModularFrameEncoder* encoder) { 195 JXL_ASSERT(encodings.size() == DequantMatrices::kNum); 196 matrices->SetEncodings(encodings); 197 for (size_t i = 0; i < encodings.size(); i++) { 198 if (encodings[i].mode == QuantEncodingInternal::kQuantModeRAW) { 199 JXL_RETURN_IF_ERROR(encoder->AddQuantTable( 200 DequantMatrices::required_size_x[i] * kBlockDim, 201 DequantMatrices::required_size_y[i] * kBlockDim, encodings[i], i)); 202 } 203 } 204 DequantMatricesRoundtrip(matrices); 205 return true; 206 } 207 208 } // namespace jxl