libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

enc_optimize_test.cc (3084B)


      1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
      2 //
      3 // Use of this source code is governed by a BSD-style
      4 // license that can be found in the LICENSE file.
      5 
#include "lib/jxl/enc_optimize.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

#include "lib/jxl/testing.h"
      9 
     10 namespace jxl {
     11 namespace optimize {
     12 namespace {
     13 
     14 // The maximum number of iterations for the test.
     15 const size_t kMaxTestIter = 100000;
     16 
     17 // F(w) = (w - w_min)^2.
     18 struct SimpleQuadraticFunction {
     19   typedef Array<double, 2> ArrayType;
     20   explicit SimpleQuadraticFunction(const ArrayType& w0) : w_min(w0) {}
     21 
     22   double Compute(const ArrayType& w, ArrayType* df) const {
     23     ArrayType dw = w - w_min;
     24     *df = -2.0 * dw;
     25     return dw * dw;
     26   }
     27 
     28   ArrayType w_min;
     29 };
     30 
     31 // F(alpha, beta, gamma| x,y) = \sum_i(y_i - (alpha x_i ^ gamma + beta))^2.
     32 struct PowerFunction {
     33   explicit PowerFunction(const std::vector<double>& x0,
     34                          const std::vector<double>& y0)
     35       : x(x0), y(y0) {}
     36 
     37   typedef Array<double, 3> ArrayType;
     38   double Compute(const ArrayType& w, ArrayType* df) const {
     39     double loss_function = 0;
     40     (*df)[0] = 0;
     41     (*df)[1] = 0;
     42     (*df)[2] = 0;
     43     for (size_t ind = 0; ind < y.size(); ++ind) {
     44       if (x[ind] != 0) {
     45         double l_f = y[ind] - (w[0] * pow(x[ind], w[1]) + w[2]);
     46         (*df)[0] += 2.0 * l_f * pow(x[ind], w[1]);
     47         (*df)[1] += 2.0 * l_f * w[0] * pow(x[ind], w[1]) * log(x[ind]);
     48         (*df)[2] += 2.0 * l_f * 1;
     49         loss_function += l_f * l_f;
     50       }
     51     }
     52     return loss_function;
     53   }
     54 
     55   std::vector<double> x;
     56   std::vector<double> y;
     57 };
     58 
     59 TEST(OptimizeTest, SimpleQuadraticFunction) {
     60   SimpleQuadraticFunction::ArrayType w_min;
     61   w_min[0] = 1.0;
     62   w_min[1] = 2.0;
     63   SimpleQuadraticFunction f(w_min);
     64   SimpleQuadraticFunction::ArrayType w(0.);
     65   static const double kPrecision = 1e-8;
     66   w = optimize::OptimizeWithScaledConjugateGradientMethod(f, w, kPrecision,
     67                                                           kMaxTestIter);
     68   EXPECT_NEAR(w[0], 1.0, kPrecision);
     69   EXPECT_NEAR(w[1], 2.0, kPrecision);
     70 }
     71 
     72 TEST(OptimizeTest, PowerFunction) {
     73   std::vector<double> x(10);
     74   std::vector<double> y(10);
     75   for (int ind = 0; ind < 10; ++ind) {
     76     x[ind] = 1. * ind;
     77     y[ind] = 2. * pow(x[ind], 3) + 5.;
     78   }
     79   PowerFunction f(x, y);
     80   PowerFunction::ArrayType w(0.);
     81 
     82   static const double kPrecision = 0.01;
     83   w = optimize::OptimizeWithScaledConjugateGradientMethod(f, w, kPrecision,
     84                                                           kMaxTestIter);
     85   EXPECT_NEAR(w[0], 2.0, kPrecision);
     86   EXPECT_NEAR(w[1], 3.0, kPrecision);
     87   EXPECT_NEAR(w[2], 5.0, kPrecision);
     88 }
     89 
     90 TEST(OptimizeTest, SimplexOptTest) {
     91   auto f = [](const std::vector<double>& x) -> double {
     92     double t1 = x[0] - 1.0;
     93     double t2 = x[1] + 1.5;
     94     return 2.0 + t1 * t1 + t2 * t2;
     95   };
     96   auto opt = RunSimplex(2, 0.01, 100, f);
     97   EXPECT_EQ(opt.size(), 3u);
     98 
     99   static const double kPrecision = 0.01;
    100   EXPECT_NEAR(opt[0], 2.0, kPrecision);
    101   EXPECT_NEAR(opt[1], 1.0, kPrecision);
    102   EXPECT_NEAR(opt[2], -1.5, kPrecision);
    103 }
    104 
    105 }  // namespace
    106 }  // namespace optimize
    107 }  // namespace jxl