libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_optimize.h (5163B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 // Utility functions for optimizing multi-dimensional nonlinear functions.
      7 
      8 #ifndef LIB_JXL_OPTIMIZE_H_
      9 #define LIB_JXL_OPTIMIZE_H_
     10 
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

#include "lib/jxl/base/status.h"
     17 
     18 namespace jxl {
     19 namespace optimize {
     20 
     21 // An array type of numeric values that supports math operations with operator-,
     22 // operator+, etc.
     23 template <typename T, size_t N>
     24 class Array {
     25  public:
     26   Array() = default;
     27   explicit Array(T v) {
     28     for (size_t i = 0; i < N; i++) v_[i] = v;
     29   }
     30 
     31   size_t size() const { return N; }
     32 
     33   T& operator[](size_t index) {
     34     JXL_DASSERT(index < N);
     35     return v_[index];
     36   }
     37   T operator[](size_t index) const {
     38     JXL_DASSERT(index < N);
     39     return v_[index];
     40   }
     41 
     42  private:
     43   // The values used by this Array.
     44   T v_[N];
     45 };
     46 
     47 template <typename T, size_t N>
     48 Array<T, N> operator+(const Array<T, N>& x, const Array<T, N>& y) {
     49   Array<T, N> z;
     50   for (size_t i = 0; i < N; ++i) {
     51     z[i] = x[i] + y[i];
     52   }
     53   return z;
     54 }
     55 
     56 template <typename T, size_t N>
     57 Array<T, N> operator-(const Array<T, N>& x, const Array<T, N>& y) {
     58   Array<T, N> z;
     59   for (size_t i = 0; i < N; ++i) {
     60     z[i] = x[i] - y[i];
     61   }
     62   return z;
     63 }
     64 
     65 template <typename T, size_t N>
     66 Array<T, N> operator*(T v, const Array<T, N>& x) {
     67   Array<T, N> y;
     68   for (size_t i = 0; i < N; ++i) {
     69     y[i] = v * x[i];
     70   }
     71   return y;
     72 }
     73 
     74 template <typename T, size_t N>
     75 T operator*(const Array<T, N>& x, const Array<T, N>& y) {
     76   T r = 0.0;
     77   for (size_t i = 0; i < N; ++i) {
     78     r += x[i] * y[i];
     79   }
     80   return r;
     81 }
     82 
// Runs Nelder-Mead like optimization. Runs for max_iterations times,
// fun gets called with a vector of size dim as argument, and returns the score
// based on those parameters (lower is better). Returns a vector of dim+1
// dimensions, where the first value is the optimal value of the function and
// the rest is the argmin value. Use init to pass an initial guess of where
// the optimal value is.
//
// Usage example:
//
// RunSimplex(2, 0.1, 100, [](const std::vector<double>& v) {
//   return (v[0] - 5) * (v[0] - 5) + (v[1] - 7) * (v[1] - 7);
// });
//
// Returns (0.0, 5, 7)
std::vector<double> RunSimplex(
    int dim, double amount, int max_iterations,
    const std::function<double(const std::vector<double>&)>& fun);
// Overload that starts the simplex around the caller-provided point `init`
// (size dim) instead of the origin.
std::vector<double> RunSimplex(
    int dim, double amount, int max_iterations, const std::vector<double>& init,
    const std::function<double(const std::vector<double>&)>& fun);
    103 
    104 // Implementation of the Scaled Conjugate Gradient method described in the
    105 // following paper:
    106 //   Moller, M. "A Scaled Conjugate Gradient Algorithm for Fast Supervised
    107 //   Learning", Neural Networks, Vol. 6. pp. 525-533, 1993
    108 //   http://sci2s.ugr.es/keel/pdf/algorithm/articulo/moller1990.pdf
    109 //
    110 // The Function template parameter is a class that has the following method:
    111 //
    112 //   // Returns the value of the function at point w and sets *df to be the
    113 //   // negative gradient vector of the function at point w.
    114 //   double Compute(const optimize::Array<T, N>& w,
    115 //                  optimize::Array<T, N>* df) const;
    116 //
    117 // Returns a vector w, such that |df(w)| < grad_norm_threshold.
    118 template <typename T, size_t N, typename Function>
    119 Array<T, N> OptimizeWithScaledConjugateGradientMethod(
    120     const Function& f, const Array<T, N>& w0, const T grad_norm_threshold,
    121     size_t max_iters) {
    122   const size_t n = w0.size();
    123   const T rsq_threshold = grad_norm_threshold * grad_norm_threshold;
    124   const T sigma0 = static_cast<T>(0.0001);
    125   const T l_min = static_cast<T>(1.0e-15);
    126   const T l_max = static_cast<T>(1.0e15);
    127 
    128   Array<T, N> w = w0;
    129   Array<T, N> wp;
    130   Array<T, N> r;
    131   Array<T, N> rt;
    132   Array<T, N> e;
    133   Array<T, N> p;
    134   T psq;
    135   T fp;
    136   T D;
    137   T d;
    138   T m;
    139   T a;
    140   T b;
    141   T s;
    142   T t;
    143 
    144   T fw = f.Compute(w, &r);
    145   T rsq = r * r;
    146   e = r;
    147   p = r;
    148   T l = static_cast<T>(1.0);
    149   bool success = true;
    150   size_t n_success = 0;
    151   size_t k = 0;
    152 
    153   while (k++ < max_iters) {
    154     if (success) {
    155       m = -(p * r);
    156       if (m >= 0) {
    157         p = r;
    158         m = -(p * r);
    159       }
    160       psq = p * p;
    161       s = sigma0 / std::sqrt(psq);
    162       f.Compute(w + (s * p), &rt);
    163       t = (p * (r - rt)) / s;
    164     }
    165 
    166     d = t + l * psq;
    167     if (d <= 0) {
    168       d = l * psq;
    169       l = l - t / psq;
    170     }
    171 
    172     a = -m / d;
    173     wp = w + a * p;
    174     fp = f.Compute(wp, &rt);
    175 
    176     D = 2.0 * (fp - fw) / (a * m);
    177     if (D >= 0.0) {
    178       success = true;
    179       n_success++;
    180       w = wp;
    181     } else {
    182       success = false;
    183     }
    184 
    185     if (success) {
    186       e = r;
    187       r = rt;
    188       rsq = r * r;
    189       fw = fp;
    190       if (rsq <= rsq_threshold) {
    191         break;
    192       }
    193     }
    194 
    195     if (D < 0.25) {
    196       l = std::min(4.0 * l, l_max);
    197     } else if (D > 0.75) {
    198       l = std::max(0.25 * l, l_min);
    199     }
    200 
    201     if ((n_success % n) == 0) {
    202       p = r;
    203       l = 1.0;
    204     } else if (success) {
    205       b = ((e - r) * r) / m;
    206       p = b * p + r;
    207     }
    208   }
    209 
    210   return w;
    211 }
    212 
    213 }  // namespace optimize
    214 }  // namespace jxl
    215 
    216 #endif  // LIB_JXL_OPTIMIZE_H_