fast_math-inl.h (8389B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 // Fast SIMD math ops (log2, encoder only, cos, erf for splines) 7 8 #if defined(LIB_JXL_BASE_FAST_MATH_INL_H_) == defined(HWY_TARGET_TOGGLE) 9 #ifdef LIB_JXL_BASE_FAST_MATH_INL_H_ 10 #undef LIB_JXL_BASE_FAST_MATH_INL_H_ 11 #else 12 #define LIB_JXL_BASE_FAST_MATH_INL_H_ 13 #endif 14 15 #include <hwy/highway.h> 16 17 #include "lib/jxl/base/common.h" 18 #include "lib/jxl/base/rational_polynomial-inl.h" 19 HWY_BEFORE_NAMESPACE(); 20 namespace jxl { 21 namespace HWY_NAMESPACE { 22 23 // These templates are not found via ADL. 24 using hwy::HWY_NAMESPACE::Abs; 25 using hwy::HWY_NAMESPACE::Add; 26 using hwy::HWY_NAMESPACE::Eq; 27 using hwy::HWY_NAMESPACE::Floor; 28 using hwy::HWY_NAMESPACE::Ge; 29 using hwy::HWY_NAMESPACE::GetLane; 30 using hwy::HWY_NAMESPACE::IfThenElse; 31 using hwy::HWY_NAMESPACE::IfThenZeroElse; 32 using hwy::HWY_NAMESPACE::Le; 33 using hwy::HWY_NAMESPACE::Min; 34 using hwy::HWY_NAMESPACE::Mul; 35 using hwy::HWY_NAMESPACE::MulAdd; 36 using hwy::HWY_NAMESPACE::NegMulAdd; 37 using hwy::HWY_NAMESPACE::Rebind; 38 using hwy::HWY_NAMESPACE::ShiftLeft; 39 using hwy::HWY_NAMESPACE::ShiftRight; 40 using hwy::HWY_NAMESPACE::Sub; 41 using hwy::HWY_NAMESPACE::Xor; 42 43 // Computes base-2 logarithm like std::log2. Undefined if negative / NaN. 44 // L1 error ~3.9E-6 45 template <class DF, class V> 46 V FastLog2f(const DF df, V x) { 47 // 2,2 rational polynomial approximation of std::log1p(x) / std::log(2). 48 HWY_ALIGN const float p[4 * (2 + 1)] = {HWY_REP4(-1.8503833400518310E-06f), 49 HWY_REP4(1.4287160470083755E+00f), 50 HWY_REP4(7.4245873327820566E-01f)}; 51 HWY_ALIGN const float q[4 * (2 + 1)] = {HWY_REP4(9.9032814277590719E-01f), 52 HWY_REP4(1.0096718572241148E+00f), 53 HWY_REP4(1.7409343003366853E-01f)}; 54 55 const Rebind<int32_t, DF> di; 56 const auto x_bits = BitCast(di, x); 57 58 // Range reduction to [-1/3, 1/3] - 3 integer, 2 float ops 59 const auto exp_bits = Sub(x_bits, Set(di, 0x3f2aaaab)); // = 2/3 60 // Shifted exponent = log2; also used to clear mantissa. 61 const auto exp_shifted = ShiftRight<23>(exp_bits); 62 const auto mantissa = BitCast(df, Sub(x_bits, ShiftLeft<23>(exp_shifted))); 63 const auto exp_val = ConvertTo(df, exp_shifted); 64 return Add(EvalRationalPolynomial(df, Sub(mantissa, Set(df, 1.0f)), p, q), 65 exp_val); 66 } 67 68 // max relative error ~3e-7 69 template <class DF, class V> 70 V FastPow2f(const DF df, V x) { 71 const Rebind<int32_t, DF> di; 72 auto floorx = Floor(x); 73 auto exp = 74 BitCast(df, ShiftLeft<23>(Add(ConvertTo(di, floorx), Set(di, 127)))); 75 auto frac = Sub(x, floorx); 76 auto num = Add(frac, Set(df, 1.01749063e+01)); 77 num = MulAdd(num, frac, Set(df, 4.88687798e+01)); 78 num = MulAdd(num, frac, Set(df, 9.85506591e+01)); 79 num = Mul(num, exp); 80 auto den = MulAdd(frac, Set(df, 2.10242958e-01), Set(df, -2.22328856e-02)); 81 den = MulAdd(den, frac, Set(df, -1.94414990e+01)); 82 den = MulAdd(den, frac, Set(df, 9.85506633e+01)); 83 return Div(num, den); 84 } 85 86 // max relative error ~3e-5 87 template <class DF, class V> 88 V FastPowf(const DF df, V base, V exponent) { 89 return FastPow2f(df, Mul(FastLog2f(df, base), exponent)); 90 } 91 92 // Computes cosine like std::cos. 93 // L1 error 7e-5. 94 template <class DF, class V> 95 V FastCosf(const DF df, V x) { 96 // Step 1: range reduction to [0, 2pi) 97 const auto pi2 = Set(df, kPi * 2.0f); 98 const auto pi2_inv = Set(df, 0.5f / kPi); 99 const auto npi2 = Mul(Floor(Mul(x, pi2_inv)), pi2); 100 const auto xmodpi2 = Sub(x, npi2); 101 // Step 2: range reduction to [0, pi] 102 const auto x_pi = Min(xmodpi2, Sub(pi2, xmodpi2)); 103 // Step 3: range reduction to [0, pi/2] 104 const auto above_pihalf = Ge(x_pi, Set(df, kPi / 2.0f)); 105 const auto x_pihalf = IfThenElse(above_pihalf, Sub(Set(df, kPi), x_pi), x_pi); 106 // Step 4: Taylor-like approximation, scaled by 2**0.75 to make angle 107 // duplication steps faster, on x/4. 108 const auto xs = Mul(x_pihalf, Set(df, 0.25f)); 109 const auto x2 = Mul(xs, xs); 110 const auto x4 = Mul(x2, x2); 111 const auto cosx_prescaling = 112 MulAdd(x4, Set(df, 0.06960438), 113 MulAdd(x2, Set(df, -0.84087373), Set(df, 1.68179268))); 114 // Step 5: angle duplication. 115 const auto cosx_scale1 = 116 MulAdd(cosx_prescaling, cosx_prescaling, Set(df, -1.414213562)); 117 const auto cosx_scale2 = MulAdd(cosx_scale1, cosx_scale1, Set(df, -1)); 118 // Step 6: change sign if needed. 119 const Rebind<uint32_t, DF> du; 120 auto signbit = ShiftLeft<31>(BitCast(du, VecFromMask(df, above_pihalf))); 121 return BitCast(df, Xor(signbit, BitCast(du, cosx_scale2))); 122 } 123 124 // Computes the error function like std::erf. 125 // L1 error 7e-4. 126 template <class DF, class V> 127 V FastErff(const DF df, V x) { 128 // Formula from 129 // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations 130 // but constants have been recomputed. 131 const auto xle0 = Le(x, Zero(df)); 132 const auto absx = Abs(x); 133 // Compute 1 - 1 / ((((x * a + b) * x + c) * x + d) * x + 1)**4 134 const auto denom1 = 135 MulAdd(absx, Set(df, 7.77394369e-02), Set(df, 2.05260015e-04)); 136 const auto denom2 = MulAdd(denom1, absx, Set(df, 2.32120216e-01)); 137 const auto denom3 = MulAdd(denom2, absx, Set(df, 2.77820801e-01)); 138 const auto denom4 = MulAdd(denom3, absx, Set(df, 1.0f)); 139 const auto denom5 = Mul(denom4, denom4); 140 const auto inv_denom5 = Div(Set(df, 1.0f), denom5); 141 const auto result = NegMulAdd(inv_denom5, inv_denom5, Set(df, 1.0f)); 142 // Change sign if needed. 143 const Rebind<uint32_t, DF> du; 144 auto signbit = ShiftLeft<31>(BitCast(du, VecFromMask(df, xle0))); 145 return BitCast(df, Xor(signbit, BitCast(du, result))); 146 } 147 148 inline float FastLog2f(float f) { 149 HWY_CAPPED(float, 1) D; 150 return GetLane(FastLog2f(D, Set(D, f))); 151 } 152 153 inline float FastPow2f(float f) { 154 HWY_CAPPED(float, 1) D; 155 return GetLane(FastPow2f(D, Set(D, f))); 156 } 157 158 inline float FastPowf(float b, float e) { 159 HWY_CAPPED(float, 1) D; 160 return GetLane(FastPowf(D, Set(D, b), Set(D, e))); 161 } 162 163 inline float FastCosf(float f) { 164 HWY_CAPPED(float, 1) D; 165 return GetLane(FastCosf(D, Set(D, f))); 166 } 167 168 inline float FastErff(float f) { 169 HWY_CAPPED(float, 1) D; 170 return GetLane(FastErff(D, Set(D, f))); 171 } 172 173 // Returns cbrt(x) + add with 6 ulp max error. 174 // Modified from vectormath_exp.h, Apache 2 license. 175 // https://www.agner.org/optimize/vectorclass.zip 176 template <class V> 177 V CubeRootAndAdd(const V x, const V add) { 178 const HWY_FULL(float) df; 179 const HWY_FULL(int32_t) di; 180 181 const auto kExpBias = Set(di, 0x54800000); // cast(1.) + cast(1.) / 3 182 const auto kExpMul = Set(di, 0x002AAAAA); // shifted 1/3 183 const auto k1_3 = Set(df, 1.0f / 3); 184 const auto k4_3 = Set(df, 4.0f / 3); 185 186 const auto xa = x; // assume inputs never negative 187 const auto xa_3 = Mul(k1_3, xa); 188 189 // Multiply exponent by -1/3 190 const auto m1 = BitCast(di, xa); 191 // Special case for 0. 0 is represented with an exponent of 0, so the 192 // "kExpBias - 1/3 * exp" below gives the wrong result. The IfThenZeroElse() 193 // sets those values as 0, which prevents having NaNs in the computations 194 // below. 195 // TODO(eustas): use fused op 196 const auto m2 = IfThenZeroElse( 197 Eq(m1, Zero(di)), Sub(kExpBias, Mul((ShiftRight<23>(m1)), kExpMul))); 198 auto r = BitCast(df, m2); 199 200 // Newton-Raphson iterations 201 for (int i = 0; i < 3; i++) { 202 const auto r2 = Mul(r, r); 203 r = NegMulAdd(xa_3, Mul(r2, r2), Mul(k4_3, r)); 204 } 205 // Final iteration 206 auto r2 = Mul(r, r); 207 r = MulAdd(k1_3, NegMulAdd(xa, Mul(r2, r2), r), r); 208 r2 = Mul(r, r); 209 r = MulAdd(r2, x, add); 210 211 return r; 212 } 213 214 // NOLINTNEXTLINE(google-readability-namespace-comments) 215 } // namespace HWY_NAMESPACE 216 } // namespace jxl 217 HWY_AFTER_NAMESPACE(); 218 219 #endif // LIB_JXL_BASE_FAST_MATH_INL_H_ 220 221 #if HWY_ONCE 222 #ifndef LIB_JXL_BASE_FAST_MATH_ONCE 223 #define LIB_JXL_BASE_FAST_MATH_ONCE 224 225 namespace jxl { 226 inline float FastLog2f(float f) { return HWY_STATIC_DISPATCH(FastLog2f)(f); } 227 inline float FastPow2f(float f) { return HWY_STATIC_DISPATCH(FastPow2f)(f); } 228 inline float FastPowf(float b, float e) { 229 return HWY_STATIC_DISPATCH(FastPowf)(b, e); 230 } 231 inline float FastCosf(float f) { return HWY_STATIC_DISPATCH(FastCosf)(f); } 232 inline float FastErff(float f) { return HWY_STATIC_DISPATCH(FastErff)(f); } 233 } // namespace jxl 234 235 #endif // LIB_JXL_BASE_FAST_MATH_ONCE 236 #endif // HWY_ONCE