enc_optimize.cc (4923B)
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #include "lib/jxl/enc_optimize.h" 7 8 #include <algorithm> 9 10 #include "lib/jxl/base/status.h" 11 12 namespace jxl { 13 14 namespace optimize { 15 16 namespace { 17 18 // simplex vector must be sorted by first element of its elements 19 std::vector<double> Midpoint(const std::vector<std::vector<double>>& simplex) { 20 JXL_CHECK(!simplex.empty()); 21 JXL_CHECK(simplex.size() == simplex[0].size()); 22 int dim = simplex.size() - 1; 23 std::vector<double> result(dim + 1, 0); 24 for (int i = 0; i < dim; i++) { 25 for (int k = 0; k < dim; k++) { 26 result[i + 1] += simplex[k][i + 1]; 27 } 28 result[i + 1] /= dim; 29 } 30 return result; 31 } 32 33 // first element ignored 34 std::vector<double> Subtract(const std::vector<double>& a, 35 const std::vector<double>& b) { 36 JXL_CHECK(a.size() == b.size()); 37 std::vector<double> result(a.size()); 38 result[0] = 0; 39 for (size_t i = 1; i < result.size(); i++) { 40 result[i] = a[i] - b[i]; 41 } 42 return result; 43 } 44 45 // first element ignored 46 std::vector<double> Add(const std::vector<double>& a, 47 const std::vector<double>& b) { 48 JXL_CHECK(a.size() == b.size()); 49 std::vector<double> result(a.size()); 50 result[0] = 0; 51 for (size_t i = 1; i < result.size(); i++) { 52 result[i] = a[i] + b[i]; 53 } 54 return result; 55 } 56 57 // first element ignored 58 std::vector<double> Average(const std::vector<double>& a, 59 const std::vector<double>& b) { 60 JXL_CHECK(a.size() == b.size()); 61 std::vector<double> result(a.size()); 62 result[0] = 0; 63 for (size_t i = 1; i < result.size(); i++) { 64 result[i] = 0.5 * (a[i] + b[i]); 65 } 66 return result; 67 } 68 69 // vec: [0] will contain the objective function, [1:] will 70 // contain the vector position for the objective function. 71 // fun: the function evaluates the value. 72 void Eval(std::vector<double>* vec, 73 const std::function<double(const std::vector<double>&)>& fun) { 74 std::vector<double> args(vec->begin() + 1, vec->end()); 75 (*vec)[0] = fun(args); 76 } 77 78 void Sort(std::vector<std::vector<double>>* simplex) { 79 std::sort(simplex->begin(), simplex->end()); 80 } 81 82 // Main iteration step of Nelder-Mead like optimization. 83 void Reflect(std::vector<std::vector<double>>* simplex, 84 const std::function<double(const std::vector<double>&)>& fun) { 85 Sort(simplex); 86 const std::vector<double>& last = simplex->back(); 87 std::vector<double> mid = Midpoint(*simplex); 88 std::vector<double> diff = Subtract(mid, last); 89 std::vector<double> mirrored = Add(mid, diff); 90 Eval(&mirrored, fun); 91 if (mirrored[0] > (*simplex)[simplex->size() - 2][0]) { 92 // Still the worst, shrink towards the best. 93 std::vector<double> shrinking = Average(simplex->back(), (*simplex)[0]); 94 Eval(&shrinking, fun); 95 simplex->back() = shrinking; 96 } else if (mirrored[0] < (*simplex)[0][0]) { 97 // new best 98 std::vector<double> even_further = Add(mirrored, diff); 99 Eval(&even_further, fun); 100 if (even_further[0] < mirrored[0]) { 101 mirrored = even_further; 102 } 103 simplex->back() = mirrored; 104 } else { 105 // not a best, not a worst point 106 simplex->back() = mirrored; 107 } 108 } 109 110 // Initialize the simplex at origin. 111 std::vector<std::vector<double>> InitialSimplex( 112 int dim, double amount, const std::vector<double>& init, 113 const std::function<double(const std::vector<double>&)>& fun) { 114 std::vector<double> best(1 + dim, 0); 115 std::copy(init.begin(), init.end(), best.begin() + 1); 116 Eval(&best, fun); 117 std::vector<std::vector<double>> result{best}; 118 for (int i = 0; i < dim; i++) { 119 best = result[0]; 120 best[i + 1] += amount; 121 Eval(&best, fun); 122 result.push_back(best); 123 Sort(&result); 124 } 125 return result; 126 } 127 128 // For comparing the same with the python tool 129 /*void RunSimplexExternal( 130 int dim, double amount, int max_iterations, 131 const std::function<double((const vector<double>&))>& fun) { 132 vector<double> vars; 133 for (int i = 0; i < dim; i++) { 134 vars.push_back(atof(getenv(StrCat("VAR", i).c_str()))); 135 } 136 double result = fun(vars); 137 std::cout << "Result=" << result; 138 }*/ 139 140 } // namespace 141 142 std::vector<double> RunSimplex( 143 int dim, double amount, int max_iterations, const std::vector<double>& init, 144 const std::function<double(const std::vector<double>&)>& fun) { 145 std::vector<std::vector<double>> simplex = 146 InitialSimplex(dim, amount, init, fun); 147 for (int i = 0; i < max_iterations; i++) { 148 Sort(&simplex); 149 Reflect(&simplex, fun); 150 } 151 return simplex[0]; 152 } 153 154 std::vector<double> RunSimplex( 155 int dim, double amount, int max_iterations, 156 const std::function<double(const std::vector<double>&)>& fun) { 157 std::vector<double> init(dim, 0.0); 158 return RunSimplex(dim, amount, max_iterations, init, fun); 159 } 160 161 } // namespace optimize 162 163 } // namespace jxl