simplex_fork.py (7443B)
#!/usr/bin/python
# Copyright (c) the JPEG XL Project Authors. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.

"""Implementation of simplex search for an external process.

The external process gets the input vector through environment variables.
Input of vector as setenv("VAR%dimension", val)
The objective function is computed from the result rows (lines starting
with "jxl") that the forked process prints to stdout.

https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method

start as ./simplex_fork.py binary dimensions amount
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
try:
  from six.moves import range
except ImportError:
  # six is optional: without it the built-in range is used instead.
  pass
import copy
import os
import random
import re
import subprocess
import sys

# Images handed to the benchmark binary through --input. The pattern is
# passed verbatim (no shell is involved when spawning the process).
_INPUT_GLOB = '/usr/local/google/home/jyrki/newcorpus/split/*.png'

# Number of VAR<i> environment variables reset to "0" before every
# evaluation, so that values from an earlier, longer vector never leak.
_NUM_ENV_SLOTS = 300


def Midpoint(simplex):
  """Nelder-Mead-like simplex midpoint calculation.

  Sorts `simplex` in place and averages every coordinate over the first
  `len(simplex) - 1` points, i.e. over all points except the last one
  after sorting.

  Args:
    simplex: list of vectors; [0] of each vector is the objective value,
      [1:] is the position.

  Returns:
    A new vector with [0] set to None and [1:] holding the averages.
  """
  simplex.sort()
  dim = len(simplex) - 1
  kept = simplex[:dim]
  return [None] + [
      sum((point[i] for point in kept), 0.0) / dim
      for i in range(1, dim + 1)
  ]


def Subtract(a, b):
  """Vector arithmetic, with [0] being ignored."""
  return [None if k == 0 else a[k] - b[k] for k in range(len(a))]


def Add(a, b):
  """Vector arithmetic, with [0] being ignored."""
  return [None if k == 0 else a[k] + b[k] for k in range(len(a))]


def Average(a, b):
  """Vector arithmetic, with [0] being ignored."""
  return [None if k == 0 else 0.5 * (a[k] + b[k]) for k in range(len(a))]


# Maps the cache key built from a position (see Eval) to its objective value.
eval_hash = {}
# Lowest objective value seen so far, None until the first scored evaluation.
g_best_val = None
# Name of the binary to evaluate; set by main() from the command line.
g_binary = None


def EvalCacheForget():
  """Drops all cached evaluations.

  Cached values are only comparable for one codec list, so this has to be
  called whenever a new set of evaluations starts.
  """
  # Clear in place so that every reference to the dict sees the reset.
  eval_hash.clear()


def RandomizedJxlCodecs():
  """Builds a comma-separated list of jittered "jxl:d<distance>" codecs.

  13 distances are spread geometrically over [0.2, 9.3], 12 more are placed
  half a step between them; each is multiplied by a random factor in
  [0.99, 1.04).

  Returns:
    The 25 codec specifications joined with ",".
  """
  retval = []
  minval = 0.2
  maxval = 9.3
  rangeval = maxval / minval
  steps = 13
  for i in range(steps):
    mul = minval * rangeval**(float(i) / (steps - 1))
    mul *= 0.99 + 0.05 * random.random()
    retval.append("jxl:d%.4f" % mul)
  for i in range(steps - 1):
    mul = minval * rangeval**(float(i + 0.5) / (steps - 1))
    mul *= 0.99 + 0.05 * random.random()
    retval.append("jxl:d%.4f" % mul)
  return ",".join(retval)


g_codecs = RandomizedJxlCodecs()


def _ApplyRowScore(score, line):
  """Folds one "jxl..." result row (bytes) into the running score product."""
  fields = line.split()
  bpp = float(fields[3])
  dist_pnorm = float(fields[9])
  score *= dist_pnorm * bpp / 16.0
  # The first column ends in "d<distance>"; rows whose bpp deviates from
  # the value expected for that distance get an additional penalty.
  distance = float(fields[0].split(b'd')[-1])
  faultybpp = 1.0 + 0.43 * ((bpp * distance ** 0.69) - 1.64) ** 2
  score *= faultybpp
  return score


def Eval(vec, binary_name, cached=True):
  """Evaluates the objective function by forking a process.

  The position is exported as VAR0, VAR1, ... in os.environ, the binary is
  run on the corpus with the current codec list, and the objective is the
  product of the per-row scores of all "jxl" rows on its stdout. Every
  stdout line is echoed with a "BE" prefix.

  A non-positive product is replaced by 1e30. When the output contains no
  "jxl" row at all, vec[0] is set to 1e33 and nothing is cached. Each new
  best value is also written to best_simplex.txt.

  Args:
    vec: [0] will be set to the objective function, [1:] will
      contain the vector position for the objective function.
    binary_name: the name of the binary that evaluates the value.
    cached: when True, reuse the stored value of a position that has
      already been evaluated since the last EvalCacheForget().
  """
  global g_best_val
  # os.environ["BUTTERAUGLI_OPTIMIZE"] = "1"
  for i in range(_NUM_ENV_SLOTS):
    os.environ["VAR%d" % i] = "0"
  position = vec[1:]
  for i, value in enumerate(position):
    os.environ["VAR%d" % i] = str(value)
  key = "".join(str(value) + ":" for value in position)
  if cached and key in eval_hash:
    vec[0] = eval_hash[key]
    return

  process = subprocess.Popen(
      (binary_name,
       '--input',
       _INPUT_GLOB,
       '--error_pnorm=3.0',
       '--more_columns',
       '--codec', g_codecs),
      stdout=subprocess.PIPE,
      stderr=subprocess.PIPE,
      env=dict(os.environ))

  found_score = False
  score = 1.0
  for line in process.communicate(input=None)[0].splitlines():
    print("BE", line)
    sys.stdout.flush()
    if line[0:3] == b'jxl':
      score = _ApplyRowScore(score, line)
      found_score = True
  vec[0] = score

  print("eval: ", vec)
  if vec[0] <= 0.0:
    vec[0] = 1e30
  if not found_score:
    # No result row at all: mark the point as worse than any real result.
    vec[0] = 1e33
    return
  eval_hash[key] = vec[0]
  if g_best_val is None or vec[0] < g_best_val:
    g_best_val = vec[0]
    print("\nSaving best simplex\n")
    with open("best_simplex.txt", "w") as f:
      print(vec, file=f)


def Reflect(simplex, binary):
  """Main iteration step of Nelder-Mead optimization. Modifies `simplex`.

  The worst point is mirrored through the midpoint of the others. Depending
  on the value at the mirrored point, the worst point is replaced by the
  average of itself and the best point, by the mirrored point, or by a
  point twice as far out.

  Args:
    simplex: list of evaluated vectors; sorted in place.
    binary: the name of the binary passed on to Eval.
  """
  simplex.sort()
  last = simplex[-1]
  mid = Midpoint(simplex)
  diff = Subtract(mid, last)
  mirrored = Add(mid, diff)
  Eval(mirrored, binary)
  if mirrored[0] > simplex[-2][0]:
    print("\nStill worst\n\n")
    # Still the worst, shrink towards the best.
    shrinking = Average(simplex[-1], simplex[0])
    Eval(shrinking, binary)
    print("\nshrinking...\n\n")
    simplex[-1] = shrinking
    return
  if mirrored[0] < simplex[0][0]:
    # New best: try to extend one more step in the same direction.
    print("\nNew Best\n\n")
    even_further = Add(mirrored, diff)
    Eval(even_further, binary)
    if even_further[0] < mirrored[0]:
      print("\nEven Further\n\n")
      mirrored = even_further
    simplex[-1] = mirrored
    return
  # Not a best, not a worst point.
  simplex[-1] = mirrored


def OneDimensionalSearch(simplex, shrink, index):
  """Tries to improve the last point of `simplex` along one coordinate.

  If the last point is better than simplex[0], its distance from
  simplex[0] along `index` is scaled by `shrink`. Otherwise the point is
  mirrored to the other side of simplex[0] along `index`. The move is kept
  only if it lowers the objective of the last point; otherwise the last
  point is restored. Evaluations use the global `g_binary`.

  Args:
    simplex: list of evaluated vectors; simplex[0] is the reference point
      and simplex[-1] the point being moved.
    shrink: scale factor for the distance (values above 1 grow it).
    index: the coordinate (1-based position in the vector) to change.

  Returns:
    True if the last point was improved, False if it was left unchanged.
  """
  last_attempt = simplex[-1][:]
  best = simplex[0]
  if last_attempt[0] < best[0]:
    # Last appended was better than the best so far: rescale the step.
    diff = simplex[-1][index] - simplex[0][index]
    simplex[-1][index] = simplex[0][index] + shrink * diff
    Eval(simplex[-1], g_binary)
    if simplex[-1][0] < last_attempt[0]:
      # it got better
      return True
  elif last_attempt[0] >= 0:
    # Not better than the best: try the opposite side of the best point.
    diff = simplex[-1][index] - simplex[0][index]
    simplex[-1][index] = simplex[0][index] - diff
    Eval(simplex[-1], g_binary)
    if simplex[-1][0] < last_attempt[0]:
      # it got better
      return True
  simplex[-1] = last_attempt
  return False


def InitialSimplex(vec, dim, amount):
  """Builds an evaluated simplex around `vec`.

  Forgets the evaluation cache, evaluates `vec`, then for every coordinate
  (in random order) adds a point offset by `amount` along that coordinate
  and refines it with one-dimensional searches. Evaluations use the global
  `g_binary`.

  Args:
    vec: starting vector ([0] is overwritten by a copy's evaluation; the
      argument itself is not modified).
    dim: number of coordinates, i.e. len(vec) - 1.
    amount: initial offset applied along each coordinate.

  Returns:
    The sorted list of dim + 1 evaluated vectors.
  """
  EvalCacheForget()
  best = vec[:]
  Eval(best, g_binary)
  retval = [best]
  comp_order = list(range(1, dim + 1))
  random.shuffle(comp_order)

  for index in comp_order:
    best = retval[0][:]
    best[index] += amount
    Eval(best, g_binary)
    retval.append(best)
    do_shrink = True
    while OneDimensionalSearch(retval, 2.0, index):
      print("OneDimensionalSearch-Grow")
    while OneDimensionalSearch(retval, 1.1, index):
      print("OneDimensionalSearch-SlowGrow")
      do_shrink = False
    if do_shrink:
      while OneDimensionalSearch(retval, 0.9, index):
        print("OneDimensionalSearch-SlowShrinking")
  retval.sort()
  return retval


def main():
  """Runs the optimization for the binary named on the command line.

  Expects exactly three arguments: binary name, number of dimensions and
  simplex size. Exits with status 1 after printing the usage otherwise.
  """
  global g_binary
  global g_codecs
  if len(sys.argv) != 4:
    print("usage: ", sys.argv[0], "binary-name number-of-dimensions simplex-size")
    sys.exit(1)

  g_binary = sys.argv[1]
  dim = int(sys.argv[2])
  amount = float(sys.argv[3])

  # Coarse-to-fine start: rebuild the simplex around the current best
  # point with progressively smaller offsets.
  simplex = InitialSimplex([None] + [0.0] * dim, dim, 7.0 * amount)
  best = simplex[0][:]
  g_codecs = RandomizedJxlCodecs()
  simplex = InitialSimplex(best, dim, amount * 2.47)
  best = simplex[0][:]
  simplex = InitialSimplex(best, dim, amount)
  best = simplex[0][:]
  simplex = InitialSimplex(best, dim, amount * 0.33)

  for restarts in range(99999):
    for ii in range(dim * 5):
      simplex.sort()
      print("reflect", ii, simplex[0])
      Reflect(simplex, g_binary)

    # Restart around the best point with a random simplex size and a
    # freshly randomized codec list.
    mulli = 0.1 + 15 * random.random()**2.0
    g_codecs = RandomizedJxlCodecs()
    print("\n\n\nRestart", restarts, "mulli", mulli)
    simplex.sort()
    best = simplex[0][:]
    simplex = InitialSimplex(best, dim, amount * mulli)


if __name__ == "__main__":
  main()