enc_optimize.h
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Utility functions for optimizing multi-dimensional nonlinear functions.

#ifndef LIB_JXL_OPTIMIZE_H_
#define LIB_JXL_OPTIMIZE_H_

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

#include "lib/jxl/base/status.h"

namespace jxl {
namespace optimize {

// An array type of numeric values that supports math operations such as
// operator+, operator-, etc.
template <typename T, size_t N>
class Array {
 public:
  Array() = default;
  explicit Array(T v) {
    for (size_t i = 0; i < N; i++) v_[i] = v;
  }

  size_t size() const { return N; }

  T& operator[](size_t index) {
    JXL_DASSERT(index < N);
    return v_[index];
  }
  T operator[](size_t index) const {
    JXL_DASSERT(index < N);
    return v_[index];
  }

 private:
  // The values used by this Array.
  T v_[N];
};

template <typename T, size_t N>
Array<T, N> operator+(const Array<T, N>& x, const Array<T, N>& y) {
  Array<T, N> z;
  for (size_t i = 0; i < N; ++i) {
    z[i] = x[i] + y[i];
  }
  return z;
}

template <typename T, size_t N>
Array<T, N> operator-(const Array<T, N>& x, const Array<T, N>& y) {
  Array<T, N> z;
  for (size_t i = 0; i < N; ++i) {
    z[i] = x[i] - y[i];
  }
  return z;
}

template <typename T, size_t N>
Array<T, N> operator*(T v, const Array<T, N>& x) {
  Array<T, N> y;
  for (size_t i = 0; i < N; ++i) {
    y[i] = v * x[i];
  }
  return y;
}

// Dot product of two Arrays.
template <typename T, size_t N>
T operator*(const Array<T, N>& x, const Array<T, N>& y) {
  T r = 0.0;
  for (size_t i = 0; i < N; ++i) {
    r += x[i] * y[i];
  }
  return r;
}

// Runs a Nelder-Mead-like optimization for up to max_iterations iterations.
// fun is called with a vector of size dim as argument and returns the score
// for those parameters (lower is better). Returns a vector of dim+1 elements,
// where the first value is the minimum value of the function and the rest is
// the argmin. Use init to pass an initial guess of where the optimum lies.
//
// Usage example:
//
//   RunSimplex(2, 0.1, 100, [](const std::vector<double>& v) {
//     return (v[0] - 5) * (v[0] - 5) + (v[1] - 7) * (v[1] - 7);
//   });
//
// Returns approximately (0.0, 5, 7).
std::vector<double> RunSimplex(
    int dim, double amount, int max_iterations,
    const std::function<double(const std::vector<double>&)>& fun);
std::vector<double> RunSimplex(
    int dim, double amount, int max_iterations, const std::vector<double>& init,
    const std::function<double(const std::vector<double>&)>& fun);

// Implementation of the Scaled Conjugate Gradient method described in the
// following paper:
//   Moller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised
//   Learning", Neural Networks, Vol. 6, pp. 525-533, 1993
//   http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf
//
// The Function template parameter is a class that has the following method:
//
//   // Returns the value of the function at point w and sets *df to be the
//   // negative gradient vector of the function at point w.
//   double Compute(const optimize::Array<T, N>& w,
//                  optimize::Array<T, N>* df) const;
//
// Returns a vector w such that |df(w)| < grad_norm_threshold, or the current
// iterate if max_iters iterations are reached first.
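//
// Usage sketch (illustrative only; QuadraticLoss and its target point (1, -2)
// are hypothetical examples, not part of this header):
//
//   struct QuadraticLoss {
//     // f(w) = (w[0] - 1)^2 + (w[1] + 2)^2, minimized at w = (1, -2).
//     // *df receives the negative gradient, as required by the interface.
//     double Compute(const optimize::Array<double, 2>& w,
//                    optimize::Array<double, 2>* df) const {
//       (*df)[0] = -2.0 * (w[0] - 1.0);
//       (*df)[1] = -2.0 * (w[1] + 2.0);
//       return (w[0] - 1.0) * (w[0] - 1.0) + (w[1] + 2.0) * (w[1] + 2.0);
//     }
//   };
//
//   optimize::Array<double, 2> w0(0.0);
//   optimize::Array<double, 2> w =
//       optimize::OptimizeWithScaledConjugateGradientMethod(
//           QuadraticLoss(), w0, 1e-8, 100);
//   // w is expected to end up close to (1.0, -2.0).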
template <typename T, size_t N, typename Function>
Array<T, N> OptimizeWithScaledConjugateGradientMethod(
    const Function& f, const Array<T, N>& w0, const T grad_norm_threshold,
    size_t max_iters) {
  const size_t n = w0.size();
  const T rsq_threshold = grad_norm_threshold * grad_norm_threshold;
  const T sigma0 = static_cast<T>(0.0001);
  const T l_min = static_cast<T>(1.0e-15);
  const T l_max = static_cast<T>(1.0e15);

  Array<T, N> w = w0;
  Array<T, N> wp;
  Array<T, N> r;
  Array<T, N> rt;
  Array<T, N> e;
  Array<T, N> p;
  T psq;
  T fp;
  T D;
  T d;
  T m;
  T a;
  T b;
  T s;
  T t;

  // fw is the function value at w, r the negative gradient at w, p the current
  // search direction and e the previous negative gradient (used for the
  // conjugate direction update).
  T fw = f.Compute(w, &r);
  T rsq = r * r;
  e = r;
  p = r;
  T l = static_cast<T>(1.0);
  bool success = true;
  size_t n_success = 0;
  size_t k = 0;

  while (k++ < max_iters) {
    if (success) {
      m = -(p * r);
      // If p is not a descent direction, restart from steepest descent.
      if (m >= 0) {
        p = r;
        m = -(p * r);
      }
      psq = p * p;
      s = sigma0 / std::sqrt(psq);
      // Estimate the curvature along p with a finite difference of the
      // gradient.
      f.Compute(w + (s * p), &rt);
      t = (p * (r - rt)) / s;
    }

    // Scale the curvature with l so that the denominator stays positive.
    d = t + l * psq;
    if (d <= 0) {
      d = l * psq;
      l = l - t / psq;
    }

    // Step size along p and candidate point.
    a = -m / d;
    wp = w + a * p;
    fp = f.Compute(wp, &rt);

    // Comparison parameter: D >= 0 means the step reduced the function value.
    D = 2.0 * (fp - fw) / (a * m);
    if (D >= 0.0) {
      success = true;
      n_success++;
      w = wp;
    } else {
      success = false;
    }

    if (success) {
      e = r;
      r = rt;
      rsq = r * r;
      fw = fp;
      if (rsq <= rsq_threshold) {
        break;
      }
    }

    // Adjust the scale parameter l based on how well the quadratic model
    // predicted the actual reduction.
    if (D < 0.25) {
      l = std::min(4.0 * l, l_max);
    } else if (D > 0.75) {
      l = std::max(0.25 * l, l_min);
    }

    // Restart with the steepest descent direction every n successful steps;
    // otherwise apply the conjugate direction update.
    if ((n_success % n) == 0) {
      p = r;
      l = 1.0;
    } else if (success) {
      b = ((e - r) * r) / m;
      p = b * p + r;
    }
  }

  return w;
}

}  // namespace optimize
}  // namespace jxl

#endif  // LIB_JXL_OPTIMIZE_H_