libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

simplex_fork.py (7443B)


      1 #!/usr/bin/python
      2 # Copyright (c) the JPEG XL Project Authors. All rights reserved.
      3 #
      4 # Use of this source code is governed by a BSD-style
      5 # license that can be found in the LICENSE file.
      6 
"""Implementation of simplex search for an external process.

The external process gets the input vector through environment variables:
component i of the vector (i = 0, 1, ...) is passed as VAR<i>, i.e. the
equivalent of setenv("VAR%d" % i, val).  The value of the optimized
function is parsed from the stdout of the forked process.

https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method

Usage: ./simplex_fork.py binary dimensions amount
"""
     18 
     19 from __future__ import absolute_import
     20 from __future__ import division
     21 from __future__ import print_function
     22 from six.moves import range
     23 import copy
     24 import os
     25 import random
     26 import re
     27 import subprocess
     28 import sys
     29 
     30 def Midpoint(simplex):
     31   """Nelder-Mead-like simplex midpoint calculation."""
     32   simplex.sort()
     33   dim = len(simplex) - 1
     34   retval = [None] + [0.0] * dim
     35   for i in range(1, dim + 1):
     36     for k in range(dim):
     37       retval[i] += simplex[k][i]
     38     retval[i] /= dim
     39   return retval
     40 
     41 
     42 def Subtract(a, b):
     43   """Vector arithmetic, with [0] being ignored."""
     44   return [None if k == 0 else a[k] - b[k] for k in range(len(a))]
     45 
     46 def Add(a, b):
     47   """Vector arithmetic, with [0] being ignored."""
     48   return [None if k == 0 else a[k] + b[k] for k in range(len(a))]
     49 
     50 def Average(a, b):
     51   """Vector arithmetic, with [0] being ignored."""
     52   return [None if k == 0 else 0.5 * (a[k] + b[k]) for k in range(len(a))]
     53 
     54 
     55 eval_hash = {}
     56 g_best_val = None
     57 
     58 def EvalCacheForget():
     59   global eval_hash
     60   eval_hash = {}
     61 
     62 def RandomizedJxlCodecs():
     63   retval = []
     64   minval = 0.2
     65   maxval = 9.3
     66   rangeval = maxval/minval
     67   steps = 13
     68   for i in range(steps):
     69     mul = minval * rangeval**(float(i)/(steps - 1))
     70     mul *= 0.99 + 0.05 * random.random()
     71     retval.append("jxl:d%.4f" % mul)
     72   for i in range(steps - 1):
     73     mul = minval * rangeval**(float(i+0.5)/(steps - 1))
     74     mul *= 0.99 + 0.05 * random.random()
     75     retval.append("jxl:d%.4f" % mul)
     76   return ",".join(retval)
     77 
     78 g_codecs = RandomizedJxlCodecs()
     79 
     80 def Eval(vec, binary_name, cached=True):
     81   """Evaluates the objective function by forking a process.
     82 
     83   Args:
     84     vec: [0] will be set to the objective function, [1:] will
     85       contain the vector position for the objective function.
     86     binary_name: the name of the binary that evaluates the value.
     87   """
     88   global eval_hash
     89   global g_codecs
     90   global g_best_val
     91   key = ""
     92   # os.environ["BUTTERAUGLI_OPTIMIZE"] = "1"
     93   for i in range(300):
     94     os.environ["VAR%d" % i] = "0"
     95   for i in range(len(vec) - 1):
     96     os.environ["VAR%d" % i] = str(vec[i + 1])
     97     key += str(vec[i + 1]) + ":"
     98   if cached and (key in eval_hash):
     99     vec[0] = eval_hash[key]
    100     return
    101 
    102   process = subprocess.Popen(
    103       (binary_name,
    104        '--input',
    105        '/usr/local/google/home/jyrki/newcorpus/split/*.png',
    106        '--error_pnorm=3.0',
    107        '--more_columns',
    108        '--codec', g_codecs),
    109       stdout=subprocess.PIPE,
    110       stderr=subprocess.PIPE,
    111       env=dict(os.environ))
    112 
    113   # process.wait()
    114   found_score = False
    115   vec[0] = 1.0
    116   dct2 = 0.0
    117   dct4 = 0.0
    118   dct16 = 0.0
    119   dct32 = 0.0
    120   n = 0
    121   for line in process.communicate(input=None)[0].splitlines():
    122     print("BE", line)
    123     sys.stdout.flush()
    124     if line[0:3] == b'jxl':
    125       bpp = line.split()[3]
    126       dist_pnorm = line.split()[9]
    127       dist_max = line.split()[6]
    128       vec[0] *= float(dist_pnorm) * float(bpp) / 16.0
    129       #vec[0] *= (float(dist_max) * float(bpp) / 16.0) ** 0.01
    130       n += 1
    131       found_score = True
    132       distance = float(line.split()[0].split(b'd')[-1])
    133       faultybpp = 1.0 + 0.43 * ((float(bpp) * distance ** 0.69) - 1.64) ** 2
    134       vec[0] *= faultybpp
    135 
    136   print("eval: ", vec)
    137   if (vec[0] <= 0.0):
    138     vec[0] = 1e30
    139   if found_score:
    140     eval_hash[key] = vec[0]
    141     if not g_best_val or vec[0] < g_best_val:
    142       g_best_val = vec[0]
    143       print("\nSaving best simplex\n")
    144       with open("best_simplex.txt", "w") as f:
    145         print(vec, file=f)
    146     return
    147   vec[0] = 1e33
    148   return
    149   # sys.exit("awful things happened")
    150 
    151 def Reflect(simplex, binary):
    152   """Main iteration step of Nelder-Mead optimization. Modifies `simplex`."""
    153   simplex.sort()
    154   last = simplex[-1]
    155   mid = Midpoint(simplex)
    156   diff = Subtract(mid, last)
    157   mirrored = Add(mid, diff)
    158   Eval(mirrored, binary)
    159   if mirrored[0] > simplex[-2][0]:
    160     print("\nStill worst\n\n")
    161     # Still the worst, shrink towards the best.
    162     shrinking = Average(simplex[-1], simplex[0])
    163     Eval(shrinking, binary)
    164     print("\nshrinking...\n\n")
    165     simplex[-1] = shrinking
    166     return
    167   if mirrored[0] < simplex[0][0]:
    168     # new best
    169     print("\nNew Best\n\n")
    170     even_further = Add(mirrored, diff)
    171     Eval(even_further, binary)
    172     if even_further[0] < mirrored[0]:
    173       print("\nEven Further\n\n")
    174       mirrored = even_further
    175     simplex[-1] = mirrored
    176     # try to extend
    177     return
    178   else:
    179     # not a best, not a worst point
    180     simplex[-1] = mirrored
    181 
    182 
    183 def OneDimensionalSearch(simplex, shrink, index):
    184   # last appended was better than the best so far, try to replace it
    185   last_attempt = simplex[-1][:]
    186   best = simplex[0]
    187   if last_attempt[0] < best[0]:
    188     # try expansion of the amount
    189     diff = simplex[-1][index] - simplex[0][index]
    190     simplex[-1][index] = simplex[0][index] + shrink * diff
    191     Eval(simplex[-1], g_binary)
    192     if simplex[-1][0] < last_attempt[0]:
    193       # it got better
    194       return True
    195   elif last_attempt[0] >= 0:
    196     diff = simplex[-1][index] - simplex[0][index]
    197     simplex[-1][index] = simplex[0][index] - diff
    198     Eval(simplex[-1], g_binary)
    199     if simplex[-1][0] < last_attempt[0]:
    200       # it got better
    201       return True
    202   simplex[-1] = last_attempt
    203   return False
    204 
    205 def InitialSimplex(vec, dim, amount):
    206   """Initialize the simplex at origin."""
    207   EvalCacheForget()
    208   best = vec[:]
    209   Eval(best, g_binary)
    210   retval = [best]
    211   comp_order = list(range(1, dim + 1))
    212   random.shuffle(comp_order)
    213 
    214   for i in range(dim):
    215     index = comp_order[i]
    216     best = retval[0][:]
    217     best[index] += amount
    218     Eval(best, g_binary)
    219     retval.append(best)
    220     do_shrink = True
    221     while OneDimensionalSearch(retval, 2.0, index):
    222       print("OneDimensionalSearch-Grow")
    223     while OneDimensionalSearch(retval, 1.1, index):
    224       print("OneDimensionalSearch-SlowGrow")
    225       do_shrink = False
    226     if do_shrink:
    227       while OneDimensionalSearch(retval, 0.9, index):
    228         print("OneDimensionalSearch-SlowShrinking")
    229     retval.sort()
    230   return retval
    231 
    232 
    233 if len(sys.argv) != 4:
    234   print("usage: ", sys.argv[0], "binary-name number-of-dimensions simplex-size")
    235   exit(1)
    236 
    237 g_dim = int(sys.argv[2])
    238 g_amount = float(sys.argv[3])
    239 g_binary = sys.argv[1]
    240 g_simplex = InitialSimplex([None] + [0.0] * g_dim,
    241                            g_dim, 7.0 * g_amount)
    242 best = g_simplex[0][:]
    243 g_codecs = RandomizedJxlCodecs()
    244 g_simplex = InitialSimplex(best, g_dim, g_amount * 2.47)
    245 best = g_simplex[0][:]
    246 g_simplex = InitialSimplex(best, g_dim, g_amount)
    247 best = g_simplex[0][:]
    248 g_simplex = InitialSimplex(best, g_dim, g_amount * 0.33)
    249 best = g_simplex[0][:]
    250 
    251 for restarts in range(99999):
    252   for ii in range(g_dim * 5):
    253     g_simplex.sort()
    254     print("reflect", ii, g_simplex[0])
    255     Reflect(g_simplex, g_binary)
    256 
    257   mulli = 0.1 + 15 * random.random()**2.0
    258   g_codecs = RandomizedJxlCodecs()
    259   print("\n\n\nRestart", restarts, "mulli", mulli)
    260   g_simplex.sort()
    261   best = g_simplex[0][:]
    262   g_simplex = InitialSimplex(best, g_dim, g_amount * mulli)