libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

fast_math-inl.h (8389B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 // Fast SIMD math ops: log2 (encoder only); cos and erf (for splines); also pow2/pow and a cube-root helper.
      7 
      8 #if defined(LIB_JXL_BASE_FAST_MATH_INL_H_) == defined(HWY_TARGET_TOGGLE)
      9 #ifdef LIB_JXL_BASE_FAST_MATH_INL_H_
     10 #undef LIB_JXL_BASE_FAST_MATH_INL_H_
     11 #else
     12 #define LIB_JXL_BASE_FAST_MATH_INL_H_
     13 #endif
     14 
     15 #include <hwy/highway.h>
     16 
     17 #include "lib/jxl/base/common.h"
     18 #include "lib/jxl/base/rational_polynomial-inl.h"
     19 HWY_BEFORE_NAMESPACE();
     20 namespace jxl {
     21 namespace HWY_NAMESPACE {
     22 
     23 // These templates are not found via ADL.
     24 using hwy::HWY_NAMESPACE::Abs;
     25 using hwy::HWY_NAMESPACE::Add;
     26 using hwy::HWY_NAMESPACE::Eq;
     27 using hwy::HWY_NAMESPACE::Floor;
     28 using hwy::HWY_NAMESPACE::Ge;
     29 using hwy::HWY_NAMESPACE::GetLane;
     30 using hwy::HWY_NAMESPACE::IfThenElse;
     31 using hwy::HWY_NAMESPACE::IfThenZeroElse;
     32 using hwy::HWY_NAMESPACE::Le;
     33 using hwy::HWY_NAMESPACE::Min;
     34 using hwy::HWY_NAMESPACE::Mul;
     35 using hwy::HWY_NAMESPACE::MulAdd;
     36 using hwy::HWY_NAMESPACE::NegMulAdd;
     37 using hwy::HWY_NAMESPACE::Rebind;
     38 using hwy::HWY_NAMESPACE::ShiftLeft;
     39 using hwy::HWY_NAMESPACE::ShiftRight;
     40 using hwy::HWY_NAMESPACE::Sub;
     41 using hwy::HWY_NAMESPACE::Xor;
     42 
     43 // Computes base-2 logarithm like std::log2. Undefined if negative / NaN.
     44 // L1 error ~3.9E-6
     45 template <class DF, class V>
     46 V FastLog2f(const DF df, V x) {
     47   // 2,2 rational polynomial approximation of std::log1p(x) / std::log(2).
     48   HWY_ALIGN const float p[4 * (2 + 1)] = {HWY_REP4(-1.8503833400518310E-06f),
     49                                           HWY_REP4(1.4287160470083755E+00f),
     50                                           HWY_REP4(7.4245873327820566E-01f)};
     51   HWY_ALIGN const float q[4 * (2 + 1)] = {HWY_REP4(9.9032814277590719E-01f),
     52                                           HWY_REP4(1.0096718572241148E+00f),
     53                                           HWY_REP4(1.7409343003366853E-01f)};
     54 
     55   const Rebind<int32_t, DF> di;
     56   const auto x_bits = BitCast(di, x);
     57 
     58   // Range reduction to [-1/3, 1/3] - 3 integer, 2 float ops
     59   const auto exp_bits = Sub(x_bits, Set(di, 0x3f2aaaab));  // = 2/3
     60   // Shifted exponent = log2; also used to clear mantissa.
     61   const auto exp_shifted = ShiftRight<23>(exp_bits);
     62   const auto mantissa = BitCast(df, Sub(x_bits, ShiftLeft<23>(exp_shifted)));
     63   const auto exp_val = ConvertTo(df, exp_shifted);
     64   return Add(EvalRationalPolynomial(df, Sub(mantissa, Set(df, 1.0f)), p, q),
     65              exp_val);
     66 }
     67 
     68 // max relative error ~3e-7
     69 template <class DF, class V>
     70 V FastPow2f(const DF df, V x) {
     71   const Rebind<int32_t, DF> di;
     72   auto floorx = Floor(x);
     73   auto exp =
     74       BitCast(df, ShiftLeft<23>(Add(ConvertTo(di, floorx), Set(di, 127))));
     75   auto frac = Sub(x, floorx);
     76   auto num = Add(frac, Set(df, 1.01749063e+01));
     77   num = MulAdd(num, frac, Set(df, 4.88687798e+01));
     78   num = MulAdd(num, frac, Set(df, 9.85506591e+01));
     79   num = Mul(num, exp);
     80   auto den = MulAdd(frac, Set(df, 2.10242958e-01), Set(df, -2.22328856e-02));
     81   den = MulAdd(den, frac, Set(df, -1.94414990e+01));
     82   den = MulAdd(den, frac, Set(df, 9.85506633e+01));
     83   return Div(num, den);
     84 }
     85 
     86 // max relative error ~3e-5
     87 template <class DF, class V>
     88 V FastPowf(const DF df, V base, V exponent) {
     89   return FastPow2f(df, Mul(FastLog2f(df, base), exponent));
     90 }
     91 
     92 // Computes cosine like std::cos.
     93 // L1 error 7e-5.
     94 template <class DF, class V>
     95 V FastCosf(const DF df, V x) {
     96   // Step 1: range reduction to [0, 2pi)
     97   const auto pi2 = Set(df, kPi * 2.0f);
     98   const auto pi2_inv = Set(df, 0.5f / kPi);
     99   const auto npi2 = Mul(Floor(Mul(x, pi2_inv)), pi2);
    100   const auto xmodpi2 = Sub(x, npi2);
    101   // Step 2: range reduction to [0, pi]
    102   const auto x_pi = Min(xmodpi2, Sub(pi2, xmodpi2));
    103   // Step 3: range reduction to [0, pi/2]
    104   const auto above_pihalf = Ge(x_pi, Set(df, kPi / 2.0f));
    105   const auto x_pihalf = IfThenElse(above_pihalf, Sub(Set(df, kPi), x_pi), x_pi);
    106   // Step 4: Taylor-like approximation, scaled by 2**0.75 to make angle
    107   // duplication steps faster, on x/4.
    108   const auto xs = Mul(x_pihalf, Set(df, 0.25f));
    109   const auto x2 = Mul(xs, xs);
    110   const auto x4 = Mul(x2, x2);
    111   const auto cosx_prescaling =
    112       MulAdd(x4, Set(df, 0.06960438),
    113              MulAdd(x2, Set(df, -0.84087373), Set(df, 1.68179268)));
    114   // Step 5: angle duplication.
    115   const auto cosx_scale1 =
    116       MulAdd(cosx_prescaling, cosx_prescaling, Set(df, -1.414213562));
    117   const auto cosx_scale2 = MulAdd(cosx_scale1, cosx_scale1, Set(df, -1));
    118   // Step 6: change sign if needed.
    119   const Rebind<uint32_t, DF> du;
    120   auto signbit = ShiftLeft<31>(BitCast(du, VecFromMask(df, above_pihalf)));
    121   return BitCast(df, Xor(signbit, BitCast(du, cosx_scale2)));
    122 }
    123 
    124 // Computes the error function like std::erf.
    125 // L1 error 7e-4.
    126 template <class DF, class V>
    127 V FastErff(const DF df, V x) {
    128   // Formula from
    129   // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
    130   // but constants have been recomputed.
    131   const auto xle0 = Le(x, Zero(df));
    132   const auto absx = Abs(x);
    133   // Compute 1 - 1 / ((((x * a + b) * x + c) * x + d) * x + 1)**4
    134   const auto denom1 =
    135       MulAdd(absx, Set(df, 7.77394369e-02), Set(df, 2.05260015e-04));
    136   const auto denom2 = MulAdd(denom1, absx, Set(df, 2.32120216e-01));
    137   const auto denom3 = MulAdd(denom2, absx, Set(df, 2.77820801e-01));
    138   const auto denom4 = MulAdd(denom3, absx, Set(df, 1.0f));
    139   const auto denom5 = Mul(denom4, denom4);
    140   const auto inv_denom5 = Div(Set(df, 1.0f), denom5);
    141   const auto result = NegMulAdd(inv_denom5, inv_denom5, Set(df, 1.0f));
    142   // Change sign if needed.
    143   const Rebind<uint32_t, DF> du;
    144   auto signbit = ShiftLeft<31>(BitCast(du, VecFromMask(df, xle0)));
    145   return BitCast(df, Xor(signbit, BitCast(du, result)));
    146 }
    147 
    148 inline float FastLog2f(float f) {
    149   HWY_CAPPED(float, 1) D;
    150   return GetLane(FastLog2f(D, Set(D, f)));
    151 }
    152 
    153 inline float FastPow2f(float f) {
    154   HWY_CAPPED(float, 1) D;
    155   return GetLane(FastPow2f(D, Set(D, f)));
    156 }
    157 
    158 inline float FastPowf(float b, float e) {
    159   HWY_CAPPED(float, 1) D;
    160   return GetLane(FastPowf(D, Set(D, b), Set(D, e)));
    161 }
    162 
    163 inline float FastCosf(float f) {
    164   HWY_CAPPED(float, 1) D;
    165   return GetLane(FastCosf(D, Set(D, f)));
    166 }
    167 
    168 inline float FastErff(float f) {
    169   HWY_CAPPED(float, 1) D;
    170   return GetLane(FastErff(D, Set(D, f)));
    171 }
    172 
    173 // Returns cbrt(x) + add with 6 ulp max error.
    174 // Modified from vectormath_exp.h, Apache 2 license.
    175 // https://www.agner.org/optimize/vectorclass.zip
    176 template <class V>
    177 V CubeRootAndAdd(const V x, const V add) {
    178   const HWY_FULL(float) df;
    179   const HWY_FULL(int32_t) di;
    180 
    181   const auto kExpBias = Set(di, 0x54800000);  // cast(1.) + cast(1.) / 3
    182   const auto kExpMul = Set(di, 0x002AAAAA);   // shifted 1/3
    183   const auto k1_3 = Set(df, 1.0f / 3);
    184   const auto k4_3 = Set(df, 4.0f / 3);
    185 
    186   const auto xa = x;  // assume inputs never negative
    187   const auto xa_3 = Mul(k1_3, xa);
    188 
    189   // Multiply exponent by -1/3
    190   const auto m1 = BitCast(di, xa);
    191   // Special case for 0. 0 is represented with an exponent of 0, so the
    192   // "kExpBias - 1/3 * exp" below gives the wrong result. The IfThenZeroElse()
    193   // sets those values as 0, which prevents having NaNs in the computations
    194   // below.
    195   // TODO(eustas): use fused op
    196   const auto m2 = IfThenZeroElse(
    197       Eq(m1, Zero(di)), Sub(kExpBias, Mul((ShiftRight<23>(m1)), kExpMul)));
    198   auto r = BitCast(df, m2);
    199 
    200   // Newton-Raphson iterations
    201   for (int i = 0; i < 3; i++) {
    202     const auto r2 = Mul(r, r);
    203     r = NegMulAdd(xa_3, Mul(r2, r2), Mul(k4_3, r));
    204   }
    205   // Final iteration
    206   auto r2 = Mul(r, r);
    207   r = MulAdd(k1_3, NegMulAdd(xa, Mul(r2, r2), r), r);
    208   r2 = Mul(r, r);
    209   r = MulAdd(r2, x, add);
    210 
    211   return r;
    212 }
    213 
    214 // NOLINTNEXTLINE(google-readability-namespace-comments)
    215 }  // namespace HWY_NAMESPACE
    216 }  // namespace jxl
    217 HWY_AFTER_NAMESPACE();
    218 
    219 #endif  // LIB_JXL_BASE_FAST_MATH_INL_H_
    220 
    221 #if HWY_ONCE
    222 #ifndef LIB_JXL_BASE_FAST_MATH_ONCE
    223 #define LIB_JXL_BASE_FAST_MATH_ONCE
    224 
    225 namespace jxl {
    226 inline float FastLog2f(float f) { return HWY_STATIC_DISPATCH(FastLog2f)(f); }
    227 inline float FastPow2f(float f) { return HWY_STATIC_DISPATCH(FastPow2f)(f); }
    228 inline float FastPowf(float b, float e) {
    229   return HWY_STATIC_DISPATCH(FastPowf)(b, e);
    230 }
    231 inline float FastCosf(float f) { return HWY_STATIC_DISPATCH(FastCosf)(f); }
    232 inline float FastErff(float f) { return HWY_STATIC_DISPATCH(FastErff)(f); }
    233 }  // namespace jxl
    234 
    235 #endif  // LIB_JXL_BASE_FAST_MATH_ONCE
    236 #endif  // HWY_ONCE