libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_optimize.cc (4923B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
      6 #include "lib/jxl/enc_optimize.h"
      7 
      8 #include <algorithm>
      9 
     10 #include "lib/jxl/base/status.h"
     11 
     12 namespace jxl {
     13 
     14 namespace optimize {
     15 
     16 namespace {
     17 
     18 // simplex vector must be sorted by first element of its elements
     19 std::vector<double> Midpoint(const std::vector<std::vector<double>>& simplex) {
     20   JXL_CHECK(!simplex.empty());
     21   JXL_CHECK(simplex.size() == simplex[0].size());
     22   int dim = simplex.size() - 1;
     23   std::vector<double> result(dim + 1, 0);
     24   for (int i = 0; i < dim; i++) {
     25     for (int k = 0; k < dim; k++) {
     26       result[i + 1] += simplex[k][i + 1];
     27     }
     28     result[i + 1] /= dim;
     29   }
     30   return result;
     31 }
     32 
     33 // first element ignored
     34 std::vector<double> Subtract(const std::vector<double>& a,
     35                              const std::vector<double>& b) {
     36   JXL_CHECK(a.size() == b.size());
     37   std::vector<double> result(a.size());
     38   result[0] = 0;
     39   for (size_t i = 1; i < result.size(); i++) {
     40     result[i] = a[i] - b[i];
     41   }
     42   return result;
     43 }
     44 
     45 // first element ignored
     46 std::vector<double> Add(const std::vector<double>& a,
     47                         const std::vector<double>& b) {
     48   JXL_CHECK(a.size() == b.size());
     49   std::vector<double> result(a.size());
     50   result[0] = 0;
     51   for (size_t i = 1; i < result.size(); i++) {
     52     result[i] = a[i] + b[i];
     53   }
     54   return result;
     55 }
     56 
     57 // first element ignored
     58 std::vector<double> Average(const std::vector<double>& a,
     59                             const std::vector<double>& b) {
     60   JXL_CHECK(a.size() == b.size());
     61   std::vector<double> result(a.size());
     62   result[0] = 0;
     63   for (size_t i = 1; i < result.size(); i++) {
     64     result[i] = 0.5 * (a[i] + b[i]);
     65   }
     66   return result;
     67 }
     68 
     69 // vec: [0] will contain the objective function, [1:] will
     70 //   contain the vector position for the objective function.
     71 // fun: the function evaluates the value.
     72 void Eval(std::vector<double>* vec,
     73           const std::function<double(const std::vector<double>&)>& fun) {
     74   std::vector<double> args(vec->begin() + 1, vec->end());
     75   (*vec)[0] = fun(args);
     76 }
     77 
     78 void Sort(std::vector<std::vector<double>>* simplex) {
     79   std::sort(simplex->begin(), simplex->end());
     80 }
     81 
     82 // Main iteration step of Nelder-Mead like optimization.
     83 void Reflect(std::vector<std::vector<double>>* simplex,
     84              const std::function<double(const std::vector<double>&)>& fun) {
     85   Sort(simplex);
     86   const std::vector<double>& last = simplex->back();
     87   std::vector<double> mid = Midpoint(*simplex);
     88   std::vector<double> diff = Subtract(mid, last);
     89   std::vector<double> mirrored = Add(mid, diff);
     90   Eval(&mirrored, fun);
     91   if (mirrored[0] > (*simplex)[simplex->size() - 2][0]) {
     92     // Still the worst, shrink towards the best.
     93     std::vector<double> shrinking = Average(simplex->back(), (*simplex)[0]);
     94     Eval(&shrinking, fun);
     95     simplex->back() = shrinking;
     96   } else if (mirrored[0] < (*simplex)[0][0]) {
     97     // new best
     98     std::vector<double> even_further = Add(mirrored, diff);
     99     Eval(&even_further, fun);
    100     if (even_further[0] < mirrored[0]) {
    101       mirrored = even_further;
    102     }
    103     simplex->back() = mirrored;
    104   } else {
    105     // not a best, not a worst point
    106     simplex->back() = mirrored;
    107   }
    108 }
    109 
    110 // Initialize the simplex at origin.
    111 std::vector<std::vector<double>> InitialSimplex(
    112     int dim, double amount, const std::vector<double>& init,
    113     const std::function<double(const std::vector<double>&)>& fun) {
    114   std::vector<double> best(1 + dim, 0);
    115   std::copy(init.begin(), init.end(), best.begin() + 1);
    116   Eval(&best, fun);
    117   std::vector<std::vector<double>> result{best};
    118   for (int i = 0; i < dim; i++) {
    119     best = result[0];
    120     best[i + 1] += amount;
    121     Eval(&best, fun);
    122     result.push_back(best);
    123     Sort(&result);
    124   }
    125   return result;
    126 }
    127 
    128 // For comparing the same with the python tool
    129 /*void RunSimplexExternal(
    130     int dim, double amount, int max_iterations,
    131     const std::function<double((const vector<double>&))>& fun) {
    132   vector<double> vars;
    133   for (int i = 0; i < dim; i++) {
    134     vars.push_back(atof(getenv(StrCat("VAR", i).c_str())));
    135   }
    136   double result = fun(vars);
    137   std::cout << "Result=" << result;
    138 }*/
    139 
    140 }  // namespace
    141 
    142 std::vector<double> RunSimplex(
    143     int dim, double amount, int max_iterations, const std::vector<double>& init,
    144     const std::function<double(const std::vector<double>&)>& fun) {
    145   std::vector<std::vector<double>> simplex =
    146       InitialSimplex(dim, amount, init, fun);
    147   for (int i = 0; i < max_iterations; i++) {
    148     Sort(&simplex);
    149     Reflect(&simplex, fun);
    150   }
    151   return simplex[0];
    152 }
    153 
    154 std::vector<double> RunSimplex(
    155     int dim, double amount, int max_iterations,
    156     const std::function<double(const std::vector<double>&)>& fun) {
    157   std::vector<double> init(dim, 0.0);
    158   return RunSimplex(dim, amount, max_iterations, init, fun);
    159 }
    160 
    161 }  // namespace optimize
    162 
    163 }  // namespace jxl