libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

splines.cc (28517B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/splines.h"
      7 
      8 #include <algorithm>
      9 #include <cinttypes>
     10 #include <cmath>
     11 #include <limits>
     12 
     13 #include "lib/jxl/base/common.h"
     14 #include "lib/jxl/base/printf_macros.h"
     15 #include "lib/jxl/base/status.h"
     16 #include "lib/jxl/chroma_from_luma.h"
     17 #include "lib/jxl/common.h"  // JXL_HIGH_PRECISION
     18 #include "lib/jxl/dct_scales.h"
     19 #include "lib/jxl/dec_ans.h"
     20 #include "lib/jxl/dec_bit_reader.h"
     21 #include "lib/jxl/pack_signed.h"
     22 
     23 #undef HWY_TARGET_INCLUDE
     24 #define HWY_TARGET_INCLUDE "lib/jxl/splines.cc"
     25 #include <hwy/foreach_target.h>
     26 #include <hwy/highway.h>
     27 
     28 #include "lib/jxl/base/fast_math-inl.h"
     29 HWY_BEFORE_NAMESPACE();
     30 namespace jxl {
     31 namespace HWY_NAMESPACE {
     32 namespace {
     33 
     34 // These templates are not found via ADL.
     35 using hwy::HWY_NAMESPACE::Mul;
     36 using hwy::HWY_NAMESPACE::MulAdd;
     37 using hwy::HWY_NAMESPACE::MulSub;
     38 using hwy::HWY_NAMESPACE::Sqrt;
     39 using hwy::HWY_NAMESPACE::Sub;
     40 
     41 // Given a set of DCT coefficients, this returns the result of performing cosine
     42 // interpolation on the original samples.
     43 float ContinuousIDCT(const float dct[32], const float t) {
     44   // We compute here the DCT-3 of the `dct` vector, rescaled by a factor of
     45   // sqrt(32). This is such that an input vector vector {x, 0, ..., 0} produces
     46   // a constant result of x. dct[0] was scaled in Dequantize() to allow uniform
     47   // treatment of all the coefficients.
     48   constexpr float kMultipliers[32] = {
     49       kPi / 32 * 0,  kPi / 32 * 1,  kPi / 32 * 2,  kPi / 32 * 3,  kPi / 32 * 4,
     50       kPi / 32 * 5,  kPi / 32 * 6,  kPi / 32 * 7,  kPi / 32 * 8,  kPi / 32 * 9,
     51       kPi / 32 * 10, kPi / 32 * 11, kPi / 32 * 12, kPi / 32 * 13, kPi / 32 * 14,
     52       kPi / 32 * 15, kPi / 32 * 16, kPi / 32 * 17, kPi / 32 * 18, kPi / 32 * 19,
     53       kPi / 32 * 20, kPi / 32 * 21, kPi / 32 * 22, kPi / 32 * 23, kPi / 32 * 24,
     54       kPi / 32 * 25, kPi / 32 * 26, kPi / 32 * 27, kPi / 32 * 28, kPi / 32 * 29,
     55       kPi / 32 * 30, kPi / 32 * 31,
     56   };
     57   HWY_CAPPED(float, 32) df;
     58   auto result = Zero(df);
     59   const auto tandhalf = Set(df, t + 0.5f);
     60   for (int i = 0; i < 32; i += Lanes(df)) {
     61     auto cos_arg = Mul(LoadU(df, kMultipliers + i), tandhalf);
     62     auto cos = FastCosf(df, cos_arg);
     63     auto local_res = Mul(LoadU(df, dct + i), cos);
     64     result = MulAdd(Set(df, kSqrt2), local_res, result);
     65   }
     66   return GetLane(SumOfLanes(df, result));
     67 }
     68 
     69 template <typename DF>
     70 void DrawSegment(DF df, const SplineSegment& segment, const bool add,
     71                  const size_t y, const size_t x, float* JXL_RESTRICT rows[3]) {
     72   Rebind<int32_t, DF> di;
     73   const auto inv_sigma = Set(df, segment.inv_sigma);
     74   const auto half = Set(df, 0.5f);
     75   const auto one_over_2s2 = Set(df, 0.353553391f);
     76   const auto sigma_over_4_times_intensity =
     77       Set(df, segment.sigma_over_4_times_intensity);
     78   const auto dx = Sub(ConvertTo(df, Iota(di, x)), Set(df, segment.center_x));
     79   const auto dy = Set(df, y - segment.center_y);
     80   const auto sqd = MulAdd(dx, dx, Mul(dy, dy));
     81   const auto distance = Sqrt(sqd);
     82   const auto one_dimensional_factor =
     83       Sub(FastErff(df, Mul(MulAdd(distance, half, one_over_2s2), inv_sigma)),
     84           FastErff(df, Mul(MulSub(distance, half, one_over_2s2), inv_sigma)));
     85   auto local_intensity =
     86       Mul(sigma_over_4_times_intensity,
     87           Mul(one_dimensional_factor, one_dimensional_factor));
     88   for (size_t c = 0; c < 3; ++c) {
     89     const auto cm = Set(df, add ? segment.color[c] : -segment.color[c]);
     90     const auto in = LoadU(df, rows[c] + x);
     91     StoreU(MulAdd(cm, local_intensity, in), df, rows[c] + x);
     92   }
     93 }
     94 
     95 void DrawSegment(const SplineSegment& segment, const bool add, const size_t y,
     96                  const ssize_t x0, ssize_t x1, float* JXL_RESTRICT rows[3]) {
     97   ssize_t x =
     98       std::max<ssize_t>(x0, segment.center_x - segment.maximum_distance + 0.5f);
     99   // one-past-the-end
    100   x1 =
    101       std::min<ssize_t>(x1, segment.center_x + segment.maximum_distance + 1.5f);
    102   HWY_FULL(float) df;
    103   for (; x + static_cast<ssize_t>(Lanes(df)) <= x1; x += Lanes(df)) {
    104     DrawSegment(df, segment, add, y, x, rows);
    105   }
    106   for (; x < x1; ++x) {
    107     DrawSegment(HWY_CAPPED(float, 1)(), segment, add, y, x, rows);
    108   }
    109 }
    110 
    111 void ComputeSegments(const Spline::Point& center, const float intensity,
    112                      const float color[3], const float sigma,
    113                      std::vector<SplineSegment>& segments,
    114                      std::vector<std::pair<size_t, size_t>>& segments_by_y) {
    115   // Sanity check sigma, inverse sigma and intensity
    116   if (!(std::isfinite(sigma) && sigma != 0.0f && std::isfinite(1.0f / sigma) &&
    117         std::isfinite(intensity))) {
    118     return;
    119   }
    120 #if JXL_HIGH_PRECISION
    121   constexpr float kDistanceExp = 5;
    122 #else
    123   // About 30% faster.
    124   constexpr float kDistanceExp = 3;
    125 #endif
    126   // We cap from below colors to at least 0.01.
    127   float max_color = 0.01f;
    128   for (size_t c = 0; c < 3; c++) {
    129     max_color = std::max(max_color, std::abs(color[c] * intensity));
    130   }
    131   // Distance beyond which max_color*intensity*exp(-d^2 / (2 * sigma^2)) drops
    132   // below 10^-kDistanceExp.
    133   const float maximum_distance =
    134       std::sqrt(-2 * sigma * sigma *
    135                 (std::log(0.1) * kDistanceExp - std::log(max_color)));
    136   SplineSegment segment;
    137   segment.center_y = center.y;
    138   segment.center_x = center.x;
    139   memcpy(segment.color, color, sizeof(segment.color));
    140   segment.inv_sigma = 1.0f / sigma;
    141   segment.sigma_over_4_times_intensity = .25f * sigma * intensity;
    142   segment.maximum_distance = maximum_distance;
    143   ssize_t y0 = std::llround(center.y - maximum_distance);
    144   ssize_t y1 =
    145       std::llround(center.y + maximum_distance) + 1;  // one-past-the-end
    146   for (ssize_t y = std::max<ssize_t>(y0, 0); y < y1; y++) {
    147     segments_by_y.emplace_back(y, segments.size());
    148   }
    149   segments.push_back(segment);
    150 }
    151 
    152 void DrawSegments(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y,
    153                   float* JXL_RESTRICT row_b, const Rect& image_rect,
    154                   const bool add, const SplineSegment* segments,
    155                   const size_t* segment_indices,
    156                   const size_t* segment_y_start) {
    157   JXL_ASSERT(image_rect.ysize() == 1);
    158   float* JXL_RESTRICT rows[3] = {row_x - image_rect.x0(),
    159                                  row_y - image_rect.x0(),
    160                                  row_b - image_rect.x0()};
    161   size_t y = image_rect.y0();
    162   for (size_t i = segment_y_start[y]; i < segment_y_start[y + 1]; i++) {
    163     DrawSegment(segments[segment_indices[i]], add, y, image_rect.x0(),
    164                 image_rect.x0() + image_rect.xsize(), rows);
    165   }
    166 }
    167 
    168 void SegmentsFromPoints(
    169     const Spline& spline,
    170     const std::vector<std::pair<Spline::Point, float>>& points_to_draw,
    171     const float arc_length, std::vector<SplineSegment>& segments,
    172     std::vector<std::pair<size_t, size_t>>& segments_by_y) {
    173   const float inv_arc_length = 1.0f / arc_length;
    174   int k = 0;
    175   for (const auto& point_to_draw : points_to_draw) {
    176     const Spline::Point& point = point_to_draw.first;
    177     const float multiplier = point_to_draw.second;
    178     const float progress_along_arc =
    179         std::min(1.f, (k * kDesiredRenderingDistance) * inv_arc_length);
    180     ++k;
    181     float color[3];
    182     for (size_t c = 0; c < 3; ++c) {
    183       color[c] =
    184           ContinuousIDCT(spline.color_dct[c], (32 - 1) * progress_along_arc);
    185     }
    186     const float sigma =
    187         ContinuousIDCT(spline.sigma_dct, (32 - 1) * progress_along_arc);
    188     ComputeSegments(point, multiplier, color, sigma, segments, segments_by_y);
    189   }
    190 }
    191 }  // namespace
    192 // NOLINTNEXTLINE(google-readability-namespace-comments)
    193 }  // namespace HWY_NAMESPACE
    194 }  // namespace jxl
    195 HWY_AFTER_NAMESPACE();
    196 
    197 #if HWY_ONCE
    198 namespace jxl {
    199 HWY_EXPORT(SegmentsFromPoints);
    200 HWY_EXPORT(DrawSegments);
    201 
    202 namespace {
    203 
    204 // It is not in spec, but reasonable limit to avoid overflows.
    205 template <typename T>
    206 Status ValidateSplinePointPos(const T& x, const T& y) {
    207   constexpr T kSplinePosLimit = 1u << 23;
    208   if ((x >= kSplinePosLimit) || (x <= -kSplinePosLimit) ||
    209       (y >= kSplinePosLimit) || (y <= -kSplinePosLimit)) {
    210     return JXL_FAILURE("Spline coordinates out of bounds");
    211   }
    212   return true;
    213 }
    214 
    215 // Maximum number of spline control points per frame is
    216 //   std::min(kMaxNumControlPoints, xsize * ysize / 2)
    217 constexpr size_t kMaxNumControlPoints = 1u << 20u;
    218 constexpr size_t kMaxNumControlPointsPerPixelRatio = 2;
    219 
    220 float AdjustedQuant(const int32_t adjustment) {
    221   return (adjustment >= 0) ? (1.f + .125f * adjustment)
    222                            : 1.f / (1.f - .125f * adjustment);
    223 }
    224 
    225 float InvAdjustedQuant(const int32_t adjustment) {
    226   return (adjustment >= 0) ? 1.f / (1.f + .125f * adjustment)
    227                            : (1.f - .125f * adjustment);
    228 }
    229 
    230 // X, Y, B, sigma.
    231 constexpr float kChannelWeight[] = {0.0042f, 0.075f, 0.07f, .3333f};
    232 
    233 Status DecodeAllStartingPoints(std::vector<Spline::Point>* const points,
    234                                BitReader* const br, ANSSymbolReader* reader,
    235                                const std::vector<uint8_t>& context_map,
    236                                const size_t num_splines) {
    237   points->clear();
    238   points->reserve(num_splines);
    239   int64_t last_x = 0;
    240   int64_t last_y = 0;
    241   for (size_t i = 0; i < num_splines; i++) {
    242     int64_t x =
    243         reader->ReadHybridUint(kStartingPositionContext, br, context_map);
    244     int64_t y =
    245         reader->ReadHybridUint(kStartingPositionContext, br, context_map);
    246     if (i != 0) {
    247       x = UnpackSigned(x) + last_x;
    248       y = UnpackSigned(y) + last_y;
    249     }
    250     JXL_RETURN_IF_ERROR(ValidateSplinePointPos(x, y));
    251     points->emplace_back(static_cast<float>(x), static_cast<float>(y));
    252     last_x = x;
    253     last_y = y;
    254   }
    255   return true;
    256 }
    257 
    258 struct Vector {
    259   float x, y;
    260   Vector operator-() const { return {-x, -y}; }
    261   Vector operator+(const Vector& other) const {
    262     return {x + other.x, y + other.y};
    263   }
    264   float SquaredNorm() const { return x * x + y * y; }
    265 };
    266 Vector operator*(const float k, const Vector& vec) {
    267   return {k * vec.x, k * vec.y};
    268 }
    269 
    270 Spline::Point operator+(const Spline::Point& p, const Vector& vec) {
    271   return {p.x + vec.x, p.y + vec.y};
    272 }
    273 Vector operator-(const Spline::Point& a, const Spline::Point& b) {
    274   return {a.x - b.x, a.y - b.y};
    275 }
    276 
    277 // TODO(eustas): avoid making a copy of "points".
    278 void DrawCentripetalCatmullRomSpline(std::vector<Spline::Point> points,
    279                                      std::vector<Spline::Point>& result) {
    280   if (points.empty()) return;
    281   if (points.size() == 1) {
    282     result.push_back(points[0]);
    283     return;
    284   }
    285   // Number of points to compute between each control point.
    286   static constexpr int kNumPoints = 16;
    287   result.reserve((points.size() - 1) * kNumPoints + 1);
    288   points.insert(points.begin(), points[0] + (points[0] - points[1]));
    289   points.push_back(points[points.size() - 1] +
    290                    (points[points.size() - 1] - points[points.size() - 2]));
    291   // points has at least 4 elements at this point.
    292   for (size_t start = 0; start < points.size() - 3; ++start) {
    293     // 4 of them are used, and we draw from p[1] to p[2].
    294     const Spline::Point* const p = &points[start];
    295     result.push_back(p[1]);
    296     float d[3];
    297     float t[4];
    298     t[0] = 0;
    299     for (int k = 0; k < 3; ++k) {
    300       // TODO(eustas): for each segment delta is calculated 3 times...
    301       // TODO(eustas): restrict d[k] with reasonable limit and spec it.
    302       d[k] = std::sqrt(hypotf(p[k + 1].x - p[k].x, p[k + 1].y - p[k].y));
    303       t[k + 1] = t[k] + d[k];
    304     }
    305     for (int i = 1; i < kNumPoints; ++i) {
    306       const float tt = d[0] + (static_cast<float>(i) / kNumPoints) * d[1];
    307       Spline::Point a[3];
    308       for (int k = 0; k < 3; ++k) {
    309         // TODO(eustas): reciprocal multiplication would be faster.
    310         a[k] = p[k] + ((tt - t[k]) / d[k]) * (p[k + 1] - p[k]);
    311       }
    312       Spline::Point b[2];
    313       for (int k = 0; k < 2; ++k) {
    314         b[k] = a[k] + ((tt - t[k]) / (d[k] + d[k + 1])) * (a[k + 1] - a[k]);
    315       }
    316       result.push_back(b[0] + ((tt - t[1]) / d[1]) * (b[1] - b[0]));
    317     }
    318   }
    319   result.push_back(points[points.size() - 2]);
    320 }
    321 
    322 // Move along the line segments defined by `points`, `kDesiredRenderingDistance`
    323 // pixels at a time, and call `functor` with each point and the actual distance
    324 // to the previous point (which will always be kDesiredRenderingDistance except
    325 // possibly for the very last point).
    326 // TODO(eustas): this method always adds the last point, but never the first
    327 //               (unless those are one); I believe both ends matter.
    328 template <typename Points, typename Functor>
    329 void ForEachEquallySpacedPoint(const Points& points, const Functor& functor) {
    330   JXL_ASSERT(!points.empty());
    331   Spline::Point current = points.front();
    332   functor(current, kDesiredRenderingDistance);
    333   auto next = points.begin();
    334   while (next != points.end()) {
    335     const Spline::Point* previous = &current;
    336     float arclength_from_previous = 0.f;
    337     for (;;) {
    338       if (next == points.end()) {
    339         functor(*previous, arclength_from_previous);
    340         return;
    341       }
    342       const float arclength_to_next =
    343           std::sqrt((*next - *previous).SquaredNorm());
    344       if (arclength_from_previous + arclength_to_next >=
    345           kDesiredRenderingDistance) {
    346         current =
    347             *previous + ((kDesiredRenderingDistance - arclength_from_previous) /
    348                          arclength_to_next) *
    349                             (*next - *previous);
    350         functor(current, kDesiredRenderingDistance);
    351         break;
    352       }
    353       arclength_from_previous += arclength_to_next;
    354       previous = &*next;
    355       ++next;
    356     }
    357   }
    358 }
    359 
    360 }  // namespace
    361 
    362 QuantizedSpline::QuantizedSpline(const Spline& original,
    363                                  const int32_t quantization_adjustment,
    364                                  const float y_to_x, const float y_to_b) {
    365   JXL_ASSERT(!original.control_points.empty());
    366   control_points_.reserve(original.control_points.size() - 1);
    367   const Spline::Point& starting_point = original.control_points.front();
    368   int previous_x = static_cast<int>(std::roundf(starting_point.x));
    369   int previous_y = static_cast<int>(std::roundf(starting_point.y));
    370   int previous_delta_x = 0;
    371   int previous_delta_y = 0;
    372   for (auto it = original.control_points.begin() + 1;
    373        it != original.control_points.end(); ++it) {
    374     const int new_x = static_cast<int>(std::roundf(it->x));
    375     const int new_y = static_cast<int>(std::roundf(it->y));
    376     const int new_delta_x = new_x - previous_x;
    377     const int new_delta_y = new_y - previous_y;
    378     control_points_.emplace_back(new_delta_x - previous_delta_x,
    379                                  new_delta_y - previous_delta_y);
    380     previous_delta_x = new_delta_x;
    381     previous_delta_y = new_delta_y;
    382     previous_x = new_x;
    383     previous_y = new_y;
    384   }
    385 
    386   const auto to_int = [](float v) -> int {
    387     // Maximal int representable with float.
    388     constexpr float kMax = std::numeric_limits<int>::max() - 127;
    389     constexpr float kMin = -kMax;
    390     return static_cast<int>(std::roundf(Clamp1(v, kMin, kMax)));
    391   };
    392 
    393   const auto quant = AdjustedQuant(quantization_adjustment);
    394   const auto inv_quant = InvAdjustedQuant(quantization_adjustment);
    395   for (int c : {1, 0, 2}) {
    396     float factor = (c == 0) ? y_to_x : (c == 1) ? 0 : y_to_b;
    397     for (int i = 0; i < 32; ++i) {
    398       const float dct_factor = (i == 0) ? kSqrt2 : 1.0f;
    399       const float inv_dct_factor = (i == 0) ? kSqrt0_5 : 1.0f;
    400       auto restored_y =
    401           color_dct_[1][i] * inv_dct_factor * kChannelWeight[1] * inv_quant;
    402       auto decorellated = original.color_dct[c][i] - factor * restored_y;
    403       color_dct_[c][i] =
    404           to_int(decorellated * dct_factor * quant / kChannelWeight[c]);
    405     }
    406   }
    407   for (int i = 0; i < 32; ++i) {
    408     const float dct_factor = (i == 0) ? kSqrt2 : 1.0f;
    409     sigma_dct_[i] =
    410         to_int(original.sigma_dct[i] * dct_factor * quant / kChannelWeight[3]);
    411   }
    412 }
    413 
    414 Status QuantizedSpline::Dequantize(const Spline::Point& starting_point,
    415                                    const int32_t quantization_adjustment,
    416                                    const float y_to_x, const float y_to_b,
    417                                    const uint64_t image_size,
    418                                    uint64_t* total_estimated_area_reached,
    419                                    Spline& result) const {
    420   constexpr uint64_t kOne = static_cast<uint64_t>(1);
    421   const uint64_t area_limit =
    422       std::min(1024 * image_size + (kOne << 32), kOne << 42);
    423 
    424   result.control_points.clear();
    425   result.control_points.reserve(control_points_.size() + 1);
    426   float px = std::roundf(starting_point.x);
    427   float py = std::roundf(starting_point.y);
    428   JXL_RETURN_IF_ERROR(ValidateSplinePointPos(px, py));
    429   int current_x = static_cast<int>(px);
    430   int current_y = static_cast<int>(py);
    431   result.control_points.emplace_back(static_cast<float>(current_x),
    432                                      static_cast<float>(current_y));
    433   int current_delta_x = 0;
    434   int current_delta_y = 0;
    435   uint64_t manhattan_distance = 0;
    436   for (const auto& point : control_points_) {
    437     current_delta_x += point.first;
    438     current_delta_y += point.second;
    439     manhattan_distance += std::abs(current_delta_x) + std::abs(current_delta_y);
    440     if (manhattan_distance > area_limit) {
    441       return JXL_FAILURE("Too large manhattan_distance reached: %" PRIu64,
    442                          manhattan_distance);
    443     }
    444     JXL_RETURN_IF_ERROR(
    445         ValidateSplinePointPos(current_delta_x, current_delta_y));
    446     current_x += current_delta_x;
    447     current_y += current_delta_y;
    448     JXL_RETURN_IF_ERROR(ValidateSplinePointPos(current_x, current_y));
    449     result.control_points.emplace_back(static_cast<float>(current_x),
    450                                        static_cast<float>(current_y));
    451   }
    452 
    453   const auto inv_quant = InvAdjustedQuant(quantization_adjustment);
    454   for (int c = 0; c < 3; ++c) {
    455     for (int i = 0; i < 32; ++i) {
    456       const float inv_dct_factor = (i == 0) ? kSqrt0_5 : 1.0f;
    457       result.color_dct[c][i] =
    458           color_dct_[c][i] * inv_dct_factor * kChannelWeight[c] * inv_quant;
    459     }
    460   }
    461   for (int i = 0; i < 32; ++i) {
    462     result.color_dct[0][i] += y_to_x * result.color_dct[1][i];
    463     result.color_dct[2][i] += y_to_b * result.color_dct[1][i];
    464   }
    465   uint64_t width_estimate = 0;
    466 
    467   uint64_t color[3] = {};
    468   for (int c = 0; c < 3; ++c) {
    469     for (int i = 0; i < 32; ++i) {
    470       color[c] += static_cast<uint64_t>(
    471           std::ceil(inv_quant * std::abs(color_dct_[c][i])));
    472     }
    473   }
    474   color[0] += static_cast<uint64_t>(std::ceil(std::abs(y_to_x))) * color[1];
    475   color[2] += static_cast<uint64_t>(std::ceil(std::abs(y_to_b))) * color[1];
    476   // This is not taking kChannelWeight into account, but up to constant factors
    477   // it gives an indication of the influence of the color values on the area
    478   // that will need to be rendered.
    479   const uint64_t max_color = std::max({color[1], color[0], color[2]});
    480   uint64_t logcolor =
    481       std::max(kOne, static_cast<uint64_t>(CeilLog2Nonzero(kOne + max_color)));
    482 
    483   const float weight_limit =
    484       std::ceil(std::sqrt((static_cast<float>(area_limit) / logcolor) /
    485                           std::max<size_t>(1, manhattan_distance)));
    486 
    487   for (int i = 0; i < 32; ++i) {
    488     const float inv_dct_factor = (i == 0) ? kSqrt0_5 : 1.0f;
    489     result.sigma_dct[i] =
    490         sigma_dct_[i] * inv_dct_factor * kChannelWeight[3] * inv_quant;
    491     // If we include the factor kChannelWeight[3]=.3333f here, we get a
    492     // realistic area estimate. We leave it out to simplify the calculations,
    493     // and understand that this way we underestimate the area by a factor of
    494     // 1/(0.3333*0.3333). This is taken into account in the limits below.
    495     float weight_f = std::ceil(inv_quant * std::abs(sigma_dct_[i]));
    496     uint64_t weight =
    497         static_cast<uint64_t>(std::min(weight_limit, std::max(1.0f, weight_f)));
    498     width_estimate += weight * weight * logcolor;
    499   }
    500   *total_estimated_area_reached += (width_estimate * manhattan_distance);
    501   if (*total_estimated_area_reached > area_limit) {
    502     return JXL_FAILURE("Too large total_estimated_area eached: %" PRIu64,
    503                        *total_estimated_area_reached);
    504   }
    505 
    506   return true;
    507 }
    508 
    509 Status QuantizedSpline::Decode(const std::vector<uint8_t>& context_map,
    510                                ANSSymbolReader* const decoder,
    511                                BitReader* const br,
    512                                const size_t max_control_points,
    513                                size_t* total_num_control_points) {
    514   const size_t num_control_points =
    515       decoder->ReadHybridUint(kNumControlPointsContext, br, context_map);
    516   if (num_control_points > max_control_points) {
    517     return JXL_FAILURE("Too many control points: %" PRIuS, num_control_points);
    518   }
    519   *total_num_control_points += num_control_points;
    520   if (*total_num_control_points > max_control_points) {
    521     return JXL_FAILURE("Too many control points: %" PRIuS,
    522                        *total_num_control_points);
    523   }
    524   control_points_.resize(num_control_points);
    525   // Maximal image dimension.
    526   constexpr int64_t kDeltaLimit = 1u << 30;
    527   for (std::pair<int64_t, int64_t>& control_point : control_points_) {
    528     control_point.first = UnpackSigned(
    529         decoder->ReadHybridUint(kControlPointsContext, br, context_map));
    530     control_point.second = UnpackSigned(
    531         decoder->ReadHybridUint(kControlPointsContext, br, context_map));
    532     // Check delta-deltas are not outrageous; it is not in spec, but there is
    533     // no reason to allow larger values.
    534     if ((control_point.first >= kDeltaLimit) ||
    535         (control_point.first <= -kDeltaLimit) ||
    536         (control_point.second >= kDeltaLimit) ||
    537         (control_point.second <= -kDeltaLimit)) {
    538       return JXL_FAILURE("Spline delta-delta is out of bounds");
    539     }
    540   }
    541 
    542   const auto decode_dct = [decoder, br, &context_map](int dct[32]) -> Status {
    543     constexpr int kWeirdNumber = std::numeric_limits<int>::min();
    544     for (int i = 0; i < 32; ++i) {
    545       dct[i] =
    546           UnpackSigned(decoder->ReadHybridUint(kDCTContext, br, context_map));
    547       if (dct[i] == kWeirdNumber) {
    548         return JXL_FAILURE("The weird number in spline DCT");
    549       }
    550     }
    551     return true;
    552   };
    553   for (int c = 0; c < 3; ++c) {
    554     JXL_RETURN_IF_ERROR(decode_dct(color_dct_[c]));
    555   }
    556   JXL_RETURN_IF_ERROR(decode_dct(sigma_dct_));
    557   return true;
    558 }
    559 
    560 void Splines::Clear() {
    561   quantization_adjustment_ = 0;
    562   splines_.clear();
    563   starting_points_.clear();
    564   segments_.clear();
    565   segment_indices_.clear();
    566   segment_y_start_.clear();
    567 }
    568 
    569 Status Splines::Decode(jxl::BitReader* br, const size_t num_pixels) {
    570   std::vector<uint8_t> context_map;
    571   ANSCode code;
    572   JXL_RETURN_IF_ERROR(
    573       DecodeHistograms(br, kNumSplineContexts, &code, &context_map));
    574   ANSSymbolReader decoder(&code, br);
    575   size_t num_splines =
    576       decoder.ReadHybridUint(kNumSplinesContext, br, context_map);
    577   size_t max_control_points = std::min(
    578       kMaxNumControlPoints, num_pixels / kMaxNumControlPointsPerPixelRatio);
    579   if (num_splines > max_control_points ||
    580       num_splines + 1 > max_control_points) {
    581     return JXL_FAILURE("Too many splines: %" PRIuS, num_splines);
    582   }
    583   num_splines++;
    584   JXL_RETURN_IF_ERROR(DecodeAllStartingPoints(&starting_points_, br, &decoder,
    585                                               context_map, num_splines));
    586 
    587   quantization_adjustment_ = UnpackSigned(
    588       decoder.ReadHybridUint(kQuantizationAdjustmentContext, br, context_map));
    589 
    590   splines_.clear();
    591   splines_.reserve(num_splines);
    592   size_t num_control_points = num_splines;
    593   for (size_t i = 0; i < num_splines; ++i) {
    594     QuantizedSpline spline;
    595     JXL_RETURN_IF_ERROR(spline.Decode(context_map, &decoder, br,
    596                                       max_control_points, &num_control_points));
    597     splines_.push_back(std::move(spline));
    598   }
    599 
    600   JXL_RETURN_IF_ERROR(decoder.CheckANSFinalState());
    601 
    602   if (!HasAny()) {
    603     return JXL_FAILURE("Decoded splines but got none");
    604   }
    605 
    606   return true;
    607 }
    608 
    609 void Splines::AddTo(Image3F* const opsin, const Rect& opsin_rect,
    610                     const Rect& image_rect) const {
    611   Apply</*add=*/true>(opsin, opsin_rect, image_rect);
    612 }
    613 void Splines::AddToRow(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y,
    614                        float* JXL_RESTRICT row_b, const Rect& image_row) const {
    615   ApplyToRow</*add=*/true>(row_x, row_y, row_b, image_row);
    616 }
    617 
    618 void Splines::SubtractFrom(Image3F* const opsin) const {
    619   Apply</*add=*/false>(opsin, Rect(*opsin), Rect(*opsin));
    620 }
    621 
    622 Status Splines::InitializeDrawCache(const size_t image_xsize,
    623                                     const size_t image_ysize,
    624                                     const ColorCorrelationMap& cmap) {
    625   // TODO(veluca): avoid storing segments that are entirely outside image
    626   // boundaries.
    627   segments_.clear();
    628   segment_indices_.clear();
    629   segment_y_start_.clear();
    630   std::vector<std::pair<size_t, size_t>> segments_by_y;
    631   std::vector<Spline::Point> intermediate_points;
    632   uint64_t total_estimated_area_reached = 0;
    633   std::vector<Spline> splines;
    634   for (size_t i = 0; i < splines_.size(); ++i) {
    635     Spline spline;
    636     JXL_RETURN_IF_ERROR(splines_[i].Dequantize(
    637         starting_points_[i], quantization_adjustment_, cmap.YtoXRatio(0),
    638         cmap.YtoBRatio(0), image_xsize * image_ysize,
    639         &total_estimated_area_reached, spline));
    640     if (std::adjacent_find(spline.control_points.begin(),
    641                            spline.control_points.end()) !=
    642         spline.control_points.end()) {
    643       // Otherwise division by zero might occur. Once control points coincide,
    644       // the direction of curve is undefined...
    645       return JXL_FAILURE(
    646           "identical successive control points in spline %" PRIuS, i);
    647     }
    648     splines.push_back(spline);
    649   }
    650   // TODO(firsching) Change this into a JXL_FAILURE for level 5 codestreams.
    651   if (total_estimated_area_reached >
    652       std::min(
    653           (8 * image_xsize * image_ysize + (static_cast<uint64_t>(1) << 25)),
    654           (static_cast<uint64_t>(1) << 30))) {
    655     JXL_WARNING(
    656         "Large total_estimated_area_reached, expect slower decoding: %" PRIu64,
    657         total_estimated_area_reached);
    658 #ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
    659     return JXL_FAILURE("Total spline area is too large");
    660 #endif
    661   }
    662 
    663   for (Spline& spline : splines) {
    664     std::vector<std::pair<Spline::Point, float>> points_to_draw;
    665     auto add_point = [&](const Spline::Point& point, const float multiplier) {
    666       points_to_draw.emplace_back(point, multiplier);
    667     };
    668     intermediate_points.clear();
    669     DrawCentripetalCatmullRomSpline(spline.control_points, intermediate_points);
    670     ForEachEquallySpacedPoint(intermediate_points, add_point);
    671     const float arc_length =
    672         (points_to_draw.size() - 2) * kDesiredRenderingDistance +
    673         points_to_draw.back().second;
    674     if (arc_length <= 0.f) {
    675       // This spline wouldn't have any effect.
    676       continue;
    677     }
    678     HWY_DYNAMIC_DISPATCH(SegmentsFromPoints)
    679     (spline, points_to_draw, arc_length, segments_, segments_by_y);
    680   }
    681 
    682   // TODO(eustas): consider linear sorting here.
    683   std::sort(segments_by_y.begin(), segments_by_y.end());
    684   segment_indices_.resize(segments_by_y.size());
    685   segment_y_start_.resize(image_ysize + 1);
    686   for (size_t i = 0; i < segments_by_y.size(); i++) {
    687     segment_indices_[i] = segments_by_y[i].second;
    688     size_t y = segments_by_y[i].first;
    689     if (y < image_ysize) {
    690       segment_y_start_[y + 1]++;
    691     }
    692   }
    693   for (size_t y = 0; y < image_ysize; y++) {
    694     segment_y_start_[y + 1] += segment_y_start_[y];
    695   }
    696   return true;
    697 }
    698 
    699 template <bool add>
    700 void Splines::ApplyToRow(float* JXL_RESTRICT row_x, float* JXL_RESTRICT row_y,
    701                          float* JXL_RESTRICT row_b,
    702                          const Rect& image_row) const {
    703   if (segments_.empty()) return;
    704   JXL_ASSERT(image_row.ysize() == 1);
    705   for (size_t iy = 0; iy < image_row.ysize(); iy++) {
    706     HWY_DYNAMIC_DISPATCH(DrawSegments)
    707     (row_x, row_y, row_b, image_row.Line(iy), add, segments_.data(),
    708      segment_indices_.data(), segment_y_start_.data());
    709   }
    710 }
    711 
    712 template <bool add>
    713 void Splines::Apply(Image3F* const opsin, const Rect& opsin_rect,
    714                     const Rect& image_rect) const {
    715   if (segments_.empty()) return;
    716   for (size_t iy = 0; iy < image_rect.ysize(); iy++) {
    717     const size_t y0 = opsin_rect.Line(iy).y0();
    718     const size_t x0 = opsin_rect.x0();
    719     ApplyToRow<add>(opsin->PlaneRow(0, y0) + x0, opsin->PlaneRow(1, y0) + x0,
    720                     opsin->PlaneRow(2, y0) + x0, image_rect.Line(iy));
    721   }
    722 }
    723 
    724 }  // namespace jxl
    725 #endif  // HWY_ONCE