enc_chroma_from_luma.cc (15890B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/enc_chroma_from_luma.h" 7 8 #include <float.h> 9 #include <stdlib.h> 10 11 #include <algorithm> 12 #include <cmath> 13 14 #undef HWY_TARGET_INCLUDE 15 #define HWY_TARGET_INCLUDE "lib/jxl/enc_chroma_from_luma.cc" 16 #include <hwy/aligned_allocator.h> 17 #include <hwy/foreach_target.h> 18 #include <hwy/highway.h> 19 20 #include "lib/jxl/base/common.h" 21 #include "lib/jxl/base/status.h" 22 #include "lib/jxl/cms/opsin_params.h" 23 #include "lib/jxl/dec_transforms-inl.h" 24 #include "lib/jxl/enc_aux_out.h" 25 #include "lib/jxl/enc_params.h" 26 #include "lib/jxl/enc_transforms-inl.h" 27 #include "lib/jxl/quantizer.h" 28 #include "lib/jxl/simd_util.h" 29 HWY_BEFORE_NAMESPACE(); 30 namespace jxl { 31 namespace HWY_NAMESPACE { 32 33 // These templates are not found via ADL. 34 using hwy::HWY_NAMESPACE::Abs; 35 using hwy::HWY_NAMESPACE::Ge; 36 using hwy::HWY_NAMESPACE::GetLane; 37 using hwy::HWY_NAMESPACE::IfThenElse; 38 using hwy::HWY_NAMESPACE::Lt; 39 40 static HWY_FULL(float) df; 41 42 struct CFLFunction { 43 static constexpr float kCoeff = 1.f / 3; 44 static constexpr float kThres = 100.0f; 45 static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor; 46 CFLFunction(const float* values_m, const float* values_s, size_t num, 47 float base, float distance_mul) 48 : values_m(values_m), 49 values_s(values_s), 50 num(num), 51 base(base), 52 distance_mul(distance_mul) {} 53 54 // Returns f'(x), where f is 1/3 * sum ((|color residual| + 1)^2-1) + 55 // distance_mul * x^2 * num. 56 float Compute(float x, float eps, float* fpeps, float* fmeps) const { 57 float first_derivative = 2 * distance_mul * num * x; 58 float first_derivative_peps = 2 * distance_mul * num * (x + eps); 59 float first_derivative_meps = 2 * distance_mul * num * (x - eps); 60 61 const auto inv_color_factor = Set(df, kInvColorFactor); 62 const auto thres = Set(df, kThres); 63 const auto coeffx2 = Set(df, kCoeff * 2.0f); 64 const auto one = Set(df, 1.0f); 65 const auto zero = Set(df, 0.0f); 66 const auto base_v = Set(df, base); 67 const auto x_v = Set(df, x); 68 const auto xpe_v = Set(df, x + eps); 69 const auto xme_v = Set(df, x - eps); 70 auto fd_v = Zero(df); 71 auto fdpe_v = Zero(df); 72 auto fdme_v = Zero(df); 73 JXL_ASSERT(num % Lanes(df) == 0); 74 75 for (size_t i = 0; i < num; i += Lanes(df)) { 76 // color residual = ax + b 77 const auto a = Mul(inv_color_factor, Load(df, values_m + i)); 78 const auto b = 79 Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i)); 80 const auto v = MulAdd(a, x_v, b); 81 const auto vpe = MulAdd(a, xpe_v, b); 82 const auto vme = MulAdd(a, xme_v, b); 83 const auto av = Abs(v); 84 const auto avpe = Abs(vpe); 85 const auto avme = Abs(vme); 86 const auto acoeffx2 = Mul(coeffx2, a); 87 auto d = Mul(acoeffx2, Add(av, one)); 88 auto dpe = Mul(acoeffx2, Add(avpe, one)); 89 auto dme = Mul(acoeffx2, Add(avme, one)); 90 d = IfThenElse(Lt(v, zero), Sub(zero, d), d); 91 dpe = IfThenElse(Lt(vpe, zero), Sub(zero, dpe), dpe); 92 dme = IfThenElse(Lt(vme, zero), Sub(zero, dme), dme); 93 const auto above = Ge(av, thres); 94 // TODO(eustas): use IfThenElseZero 95 fd_v = Add(fd_v, IfThenElse(above, zero, d)); 96 fdpe_v = Add(fdpe_v, IfThenElse(above, zero, dpe)); 97 fdme_v = Add(fdme_v, IfThenElse(above, zero, dme)); 98 } 99 100 *fpeps = first_derivative_peps + GetLane(SumOfLanes(df, fdpe_v)); 101 *fmeps = first_derivative_meps + GetLane(SumOfLanes(df, fdme_v)); 102 return first_derivative + GetLane(SumOfLanes(df, fd_v)); 103 } 104 105 const float* JXL_RESTRICT values_m; 106 const float* JXL_RESTRICT values_s; 107 size_t num; 108 float base; 109 float distance_mul; 110 }; 111 112 // Chroma-from-luma search, values_m will have luma -- and values_s chroma. 113 int32_t FindBestMultiplier(const float* values_m, const float* values_s, 114 size_t num, float base, float distance_mul, 115 bool fast) { 116 if (num == 0) { 117 return 0; 118 } 119 float x; 120 if (fast) { 121 static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor; 122 auto ca = Zero(df); 123 auto cb = Zero(df); 124 const auto inv_color_factor = Set(df, kInvColorFactor); 125 const auto base_v = Set(df, base); 126 for (size_t i = 0; i < num; i += Lanes(df)) { 127 // color residual = ax + b 128 const auto a = Mul(inv_color_factor, Load(df, values_m + i)); 129 const auto b = 130 Sub(Mul(base_v, Load(df, values_m + i)), Load(df, values_s + i)); 131 ca = MulAdd(a, a, ca); 132 cb = MulAdd(a, b, cb); 133 } 134 // + distance_mul * x^2 * num 135 x = -GetLane(SumOfLanes(df, cb)) / 136 (GetLane(SumOfLanes(df, ca)) + num * distance_mul * 0.5f); 137 } else { 138 constexpr float eps = 100; 139 constexpr float kClamp = 20.0f; 140 CFLFunction fn(values_m, values_s, num, base, distance_mul); 141 x = 0; 142 // Up to 20 Newton iterations, with approximate derivatives. 143 // Derivatives are approximate due to the high amount of noise in the exact 144 // derivatives. 145 for (size_t i = 0; i < 20; i++) { 146 float dfpeps; 147 float dfmeps; 148 float df = fn.Compute(x, eps, &dfpeps, &dfmeps); 149 float ddf = (dfpeps - dfmeps) / (2 * eps); 150 float kExperimentalInsignificantStabilizer = 0.85; 151 float step = df / (ddf + kExperimentalInsignificantStabilizer); 152 x -= std::min(kClamp, std::max(-kClamp, step)); 153 if (std::abs(step) < 3e-3) break; 154 } 155 } 156 // CFL seems to be tricky for larger transforms for HF components 157 // close to zero. This heuristic brings the solutions closer to zero 158 // and reduces red-green oscillations. A better approach would 159 // look into variance of the multiplier within separate (e.g. 8x8) 160 // areas and only apply this heuristic where there is a high variance. 161 // This would give about 1 % more compression density. 162 float towards_zero = 2.6; 163 if (x >= towards_zero) { 164 x -= towards_zero; 165 } else if (x <= -towards_zero) { 166 x += towards_zero; 167 } else { 168 x = 0; 169 } 170 return std::max(-128.0f, std::min(127.0f, roundf(x))); 171 } 172 173 Status InitDCStorage(size_t num_blocks, ImageF* dc_values) { 174 // First row: Y channel 175 // Second row: X channel 176 // Third row: Y channel 177 // Fourth row: B channel 178 JXL_ASSIGN_OR_RETURN(*dc_values, 179 ImageF::Create(RoundUpTo(num_blocks, Lanes(df)), 4)); 180 181 JXL_ASSERT(dc_values->xsize() != 0); 182 // Zero-fill the last lanes 183 for (size_t y = 0; y < 4; y++) { 184 for (size_t x = dc_values->xsize() - Lanes(df); x < dc_values->xsize(); 185 x++) { 186 dc_values->Row(y)[x] = 0; 187 } 188 } 189 return true; 190 } 191 192 void ComputeTile(const Image3F& opsin, const Rect& opsin_rect, 193 const DequantMatrices& dequant, 194 const AcStrategyImage* ac_strategy, 195 const ImageI* raw_quant_field, const Quantizer* quantizer, 196 const Rect& rect, bool fast, bool use_dct8, ImageSB* map_x, 197 ImageSB* map_b, ImageF* dc_values, float* mem) { 198 static_assert(kEncTileDimInBlocks == kColorTileDimInBlocks, 199 "Invalid color tile dim"); 200 size_t xsize_blocks = opsin_rect.xsize() / kBlockDim; 201 constexpr float kDistanceMultiplierAC = 1e-9f; 202 const size_t dct_scratch_size = 203 3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim; 204 205 const size_t y0 = rect.y0(); 206 const size_t x0 = rect.x0(); 207 const size_t x1 = rect.x0() + rect.xsize(); 208 const size_t y1 = rect.y0() + rect.ysize(); 209 210 int ty = y0 / kColorTileDimInBlocks; 211 int tx = x0 / kColorTileDimInBlocks; 212 213 int8_t* JXL_RESTRICT row_out_x = map_x->Row(ty); 214 int8_t* JXL_RESTRICT row_out_b = map_b->Row(ty); 215 216 float* JXL_RESTRICT dc_values_yx = dc_values->Row(0); 217 float* JXL_RESTRICT dc_values_x = dc_values->Row(1); 218 float* JXL_RESTRICT dc_values_yb = dc_values->Row(2); 219 float* JXL_RESTRICT dc_values_b = dc_values->Row(3); 220 221 // All are aligned. 222 float* HWY_RESTRICT block_y = mem; 223 float* HWY_RESTRICT block_x = block_y + AcStrategy::kMaxCoeffArea; 224 float* HWY_RESTRICT block_b = block_x + AcStrategy::kMaxCoeffArea; 225 float* HWY_RESTRICT coeffs_yx = block_b + AcStrategy::kMaxCoeffArea; 226 float* HWY_RESTRICT coeffs_x = coeffs_yx + kColorTileDim * kColorTileDim; 227 float* HWY_RESTRICT coeffs_yb = coeffs_x + kColorTileDim * kColorTileDim; 228 float* HWY_RESTRICT coeffs_b = coeffs_yb + kColorTileDim * kColorTileDim; 229 float* HWY_RESTRICT scratch_space = coeffs_b + kColorTileDim * kColorTileDim; 230 float* scratch_space_end = 231 scratch_space + 2 * AcStrategy::kMaxCoeffArea + dct_scratch_size; 232 JXL_DASSERT(scratch_space_end == block_y + CfLHeuristics::ItemsPerThread()); 233 (void)scratch_space_end; 234 235 // Small (~256 bytes each) 236 HWY_ALIGN_MAX float 237 dc_y[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; 238 HWY_ALIGN_MAX float 239 dc_x[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; 240 HWY_ALIGN_MAX float 241 dc_b[AcStrategy::kMaxCoeffBlocks * AcStrategy::kMaxCoeffBlocks] = {}; 242 size_t num_ac = 0; 243 244 for (size_t y = y0; y < y1; ++y) { 245 const float* JXL_RESTRICT row_y = 246 opsin_rect.ConstPlaneRow(opsin, 1, y * kBlockDim); 247 const float* JXL_RESTRICT row_x = 248 opsin_rect.ConstPlaneRow(opsin, 0, y * kBlockDim); 249 const float* JXL_RESTRICT row_b = 250 opsin_rect.ConstPlaneRow(opsin, 2, y * kBlockDim); 251 size_t stride = opsin.PixelsPerRow(); 252 253 for (size_t x = x0; x < x1; x++) { 254 AcStrategy acs = use_dct8 255 ? AcStrategy::FromRawStrategy(AcStrategy::Type::DCT) 256 : ac_strategy->ConstRow(y)[x]; 257 if (!acs.IsFirstBlock()) continue; 258 size_t xs = acs.covered_blocks_x(); 259 TransformFromPixels(acs.Strategy(), row_y + x * kBlockDim, stride, 260 block_y, scratch_space); 261 DCFromLowestFrequencies(acs.Strategy(), block_y, dc_y, xs); 262 TransformFromPixels(acs.Strategy(), row_x + x * kBlockDim, stride, 263 block_x, scratch_space); 264 DCFromLowestFrequencies(acs.Strategy(), block_x, dc_x, xs); 265 TransformFromPixels(acs.Strategy(), row_b + x * kBlockDim, stride, 266 block_b, scratch_space); 267 DCFromLowestFrequencies(acs.Strategy(), block_b, dc_b, xs); 268 const float* const JXL_RESTRICT qm_x = 269 dequant.InvMatrix(acs.Strategy(), 0); 270 const float* const JXL_RESTRICT qm_b = 271 dequant.InvMatrix(acs.Strategy(), 2); 272 float q_dc_x = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(0); 273 float q_dc_b = use_dct8 ? 1 : 1.0f / quantizer->GetInvDcStep(2); 274 275 // Copy DCs in dc_values. 276 for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) { 277 for (size_t ix = 0; ix < xs; ix++) { 278 dc_values_yx[(iy + y) * xsize_blocks + ix + x] = 279 dc_y[iy * xs + ix] * q_dc_x; 280 dc_values_x[(iy + y) * xsize_blocks + ix + x] = 281 dc_x[iy * xs + ix] * q_dc_x; 282 dc_values_yb[(iy + y) * xsize_blocks + ix + x] = 283 dc_y[iy * xs + ix] * q_dc_b; 284 dc_values_b[(iy + y) * xsize_blocks + ix + x] = 285 dc_b[iy * xs + ix] * q_dc_b; 286 } 287 } 288 289 // Do not use this block for computing AC CfL. 290 if (acs.covered_blocks_x() + x0 > x1 || 291 acs.covered_blocks_y() + y0 > y1) { 292 continue; 293 } 294 295 // Copy AC coefficients in the local block. The order in which 296 // coefficients get stored does not matter. 297 size_t cx = acs.covered_blocks_x(); 298 size_t cy = acs.covered_blocks_y(); 299 CoefficientLayout(&cy, &cx); 300 // Zero out LFs. This introduces terms in the optimization loop that 301 // don't affect the result, as they are all 0, but allow for simpler 302 // SIMDfication. 303 for (size_t iy = 0; iy < cy; iy++) { 304 for (size_t ix = 0; ix < cx; ix++) { 305 block_y[cx * kBlockDim * iy + ix] = 0; 306 block_x[cx * kBlockDim * iy + ix] = 0; 307 block_b[cx * kBlockDim * iy + ix] = 0; 308 } 309 } 310 // Unclear why this is like it is. (This works slightly better 311 // than the previous approach which was also a hack.) 312 const float qq = 313 (raw_quant_field == nullptr) ? 1.0f : raw_quant_field->Row(y)[x]; 314 // Experimentally values 128-130 seem best -- I don't know why we 315 // need this multiplier. 316 const float kStrangeMultiplier = 128; 317 float q = use_dct8 ? 1 : quantizer->Scale() * kStrangeMultiplier * qq; 318 const auto qv = Set(df, q); 319 for (size_t i = 0; i < cx * cy * 64; i += Lanes(df)) { 320 const auto b_y = Load(df, block_y + i); 321 const auto b_x = Load(df, block_x + i); 322 const auto b_b = Load(df, block_b + i); 323 const auto qqm_x = Mul(qv, Load(df, qm_x + i)); 324 const auto qqm_b = Mul(qv, Load(df, qm_b + i)); 325 Store(Mul(b_y, qqm_x), df, coeffs_yx + num_ac); 326 Store(Mul(b_x, qqm_x), df, coeffs_x + num_ac); 327 Store(Mul(b_y, qqm_b), df, coeffs_yb + num_ac); 328 Store(Mul(b_b, qqm_b), df, coeffs_b + num_ac); 329 num_ac += Lanes(df); 330 } 331 } 332 } 333 JXL_CHECK(num_ac % Lanes(df) == 0); 334 row_out_x[tx] = FindBestMultiplier(coeffs_yx, coeffs_x, num_ac, 0.0f, 335 kDistanceMultiplierAC, fast); 336 row_out_b[tx] = 337 FindBestMultiplier(coeffs_yb, coeffs_b, num_ac, jxl::cms::kYToBRatio, 338 kDistanceMultiplierAC, fast); 339 } 340 341 // NOLINTNEXTLINE(google-readability-namespace-comments) 342 } // namespace HWY_NAMESPACE 343 } // namespace jxl 344 HWY_AFTER_NAMESPACE(); 345 346 #if HWY_ONCE 347 namespace jxl { 348 349 HWY_EXPORT(InitDCStorage); 350 HWY_EXPORT(ComputeTile); 351 352 Status CfLHeuristics::Init(const Rect& rect) { 353 size_t xsize_blocks = rect.xsize() / kBlockDim; 354 size_t ysize_blocks = rect.ysize() / kBlockDim; 355 return HWY_DYNAMIC_DISPATCH(InitDCStorage)(xsize_blocks * ysize_blocks, 356 &dc_values); 357 } 358 359 void CfLHeuristics::ComputeTile(const Rect& r, const Image3F& opsin, 360 const Rect& opsin_rect, 361 const DequantMatrices& dequant, 362 const AcStrategyImage* ac_strategy, 363 const ImageI* raw_quant_field, 364 const Quantizer* quantizer, bool fast, 365 size_t thread, ColorCorrelationMap* cmap) { 366 bool use_dct8 = ac_strategy == nullptr; 367 HWY_DYNAMIC_DISPATCH(ComputeTile) 368 (opsin, opsin_rect, dequant, ac_strategy, raw_quant_field, quantizer, r, fast, 369 use_dct8, &cmap->ytox_map, &cmap->ytob_map, &dc_values, 370 mem.get() + thread * ItemsPerThread()); 371 } 372 373 void ColorCorrelationMapEncodeDC(const ColorCorrelationMap& map, 374 BitWriter* writer, size_t layer, 375 AuxOut* aux_out) { 376 float color_factor = map.GetColorFactor(); 377 float base_correlation_x = map.GetBaseCorrelationX(); 378 float base_correlation_b = map.GetBaseCorrelationB(); 379 int32_t ytox_dc = map.GetYToXDC(); 380 int32_t ytob_dc = map.GetYToBDC(); 381 382 BitWriter::Allotment allotment(writer, 1 + 2 * kBitsPerByte + 12 + 32); 383 if (ytox_dc == 0 && ytob_dc == 0 && color_factor == kDefaultColorFactor && 384 base_correlation_x == 0.0f && 385 base_correlation_b == jxl::cms::kYToBRatio) { 386 writer->Write(1, 1); 387 allotment.ReclaimAndCharge(writer, layer, aux_out); 388 return; 389 } 390 writer->Write(1, 0); 391 JXL_CHECK(U32Coder::Write(kColorFactorDist, color_factor, writer)); 392 JXL_CHECK(F16Coder::Write(base_correlation_x, writer)); 393 JXL_CHECK(F16Coder::Write(base_correlation_b, writer)); 394 writer->Write(kBitsPerByte, ytox_dc - std::numeric_limits<int8_t>::min()); 395 writer->Write(kBitsPerByte, ytob_dc - std::numeric_limits<int8_t>::min()); 396 allotment.ReclaimAndCharge(writer, layer, aux_out); 397 } 398 399 } // namespace jxl 400 #endif // HWY_ONCE