libjxl

FORK: libjxl patches used on blog
git clone https://git.neptards.moe/blog/libjxl.git
Log | Files | Refs | Submodules | README | LICENSE

bisector (10294B)


      1 #!/usr/bin/env python
      2 #
      3 # Copyright (c) the JPEG XL Project Authors. All rights reserved.
      4 #
      5 # Use of this source code is governed by a BSD-style
      6 # license that can be found in the LICENSE file.
      7 
      8 r"""General-purpose bisector
      9 
     10 Prints a space-separated list of values to stdout:
     11 1_if_success_0_otherwise left_x left_f(x) right_x right_f(x)
     12 
     13 Usage examples:
     14 
     15 # Finding the square root of 200 via bisection:
     16 bisector --var=BB --range=0.0,100.0 --target=200 --maxiter=100 \
     17          --atol_val=1e-12 --rtol_val=0 --cmd='echo "$BB * $BB" | bc'
     18 # => 1 14.142135623730923 199.99999999999923 14.142135623731633 200.0000000000193
     19 
     20 # Finding an integer approximation to sqrt(200) via bisection:
     21 bisector --var=BB --range=0,100 --target=200 --maxiter=100 \
     22          --atol_arg=1 --cmd='echo "$BB * $BB" | bc'
     23 # => 1 14 196.0 15 225.0
     24 
     25 # Finding a change-id that broke something via bisection:
     26 bisector --var=BB --range=0,1000000 --target=0.5 --maxiter=100 \
     27          --atol_arg=1 \
     28          --cmd='test $BB -gt 123456 && echo 1 || echo 0' --verbosity=3
     29 # => 1 123456 0.0 123457 1.0
     30 
     31 # Finding settings that compress /usr/share/dict/words to a given target size:
     32 bisector --var=BB --range=1,9 --target=250000 --atol_arg=1 \
     33   --cmd='gzip -$BB </usr/share/dict/words >/tmp/w_$BB.gz; wc -c /tmp/w_$BB.gz' \
     34   --final='mv /tmp/w_$BB.gz /tmp/words.gz; rm /tmp/w_*.gz' \
     35   --verbosity=1
     36 # => 1 3 263170.0 4 240043.0
     37 
     38 # JXL-encoding with bisection-for-size (tolerance 0.5%):
     39 bisector --var=BB --range=0.1,3.0 --target=3500 --rtol_val=0.005 \
     40   --cmd='(build/tools/cjxl --distance=$BB /tmp/baseball.png /tmp/baseball_$BB.jxl && wc -c /tmp/baseball_$BB.jxl)' \
     41   --final='mv /tmp/baseball_$BB.jxl /tmp/baseball.jxl; rm -f /tmp/baseball_*.jxl' \
     42   --verbosity=1
     43 # => 1 1.1875 3573.0 1.278125 3481.0
     44 
     45 # JXL-encoding with bisection-for-bits-per-pixel (tolerance 0.5%), using helper:
     46 bisector --var=BB --range=0.1,3.0 --target=1.2 --rtol_val=0.005 \
     47   --cmd='(build/tools/cjxl --distance=$BB /tmp/baseball.png /tmp/baseball_$BB.jxl && get_bpp /tmp/baseball_$BB.jxl)' \
     48   --final='mv /tmp/baseball_$BB.jxl /tmp/baseball.jxl; rm -f /tmp/baseball_*.jxl' \
     49   --verbosity=1
     50 # => ...
     51 """
     52 
     53 import argparse
     54 import os
     55 import re
     56 import subprocess
     57 import sys
     58 
     59 
     60 def _expandvars(vardef, env,
     61                 max_recursion=100,
     62                 max_length=10**6,
     63                 verbosity=0):
     64   """os.path.expandvars() variant using parameter env rather than os.environ."""
     65   current_expanded = vardef
     66   for num_recursions in range(max_recursion):
     67     if verbosity >= 3:
     68       print(f'_expandvars(): num_recursions={num_recursions}, '
     69             f'len={len(current_expanded)}' +
     70             (', current: ' + current_expanded if verbosity >= 4 else ''))
     71     if len > max_length:
     72         break
     73     current_expanded, num_replacements = re.subn(
     74         r'$\{(\w+)\}|$(\w+)',
     75         lambda m: env.get(m[1] if m[1] is not None else m[2], ''),
     76         current_expanded)
     77     if num_replacements == 0:
     78         break
     79   return current_expanded
     80 
     81 
     82 def _strtod(string):
     83   """Extracts leftmost float from string (like strtod(3))."""
     84   match = re.match(r'[+-]?\d*[.]?\d*(?:[eE][+-]?\d+)?', string)
     85   return float(match[0]) if match[0] else None
     86 
     87   
     88 def run_shell_command(shell_command,
     89                       bisect_var, bisect_val,
     90                       extra_env_defs,
     91                       verbosity=0):
     92   """Runs a shell command with env modifications, fetching return value."""
     93   shell_env = dict(os.environ)
     94   shell_env[bisect_var] = str(bisect_val)
     95   for env_def in extra_env_defs:
     96     varname, vardef = env_def.split('=', 1)
     97     shell_env[varname] = _expandvars(vardev, shell_env,
     98                                      verbosity=verbosity)
     99   shell_ret = subprocess.run(shell_command,
    100                              # We explicitly want subshell semantics!
    101                              shell=True,
    102                              capture_output=True,
    103                              env=shell_env)
    104   stdout = shell_ret.stdout.decode('utf-8')
    105   score = _strtod(stdout)
    106   if verbosity >= 2:
    107     print(f'{bisect_var}={bisect_val} {shell_command} => '
    108           f'{shell_ret.returncode} # {stdout.strip()}')
    109   return (shell_ret.returncode == 0,  # Command was successful?
    110           score)
    111 
    112 
    113 def _bisect(*,
    114             shell_command,
    115             final_shell_command,
    116             target,
    117             int_args,            
    118             bisect_var, bisect_left, bisect_right,
    119             rtol_val, atol_val, rtol_arg, atol_arg,
    120             maxiter,
    121             extra_env_defs,
    122             verbosity=0
    123             ):
    124   """Performs bisection."""
    125   def _get_val(x):
    126     success, val = run_shell_command(shell_command,
    127                                      bisect_var, x,
    128                                      extra_env_defs,
    129                                      verbosity=verbosity)
    130     if not success:
    131       raise RuntimeError(f'Bisection failed for: {bisect_var}={x}: '
    132                          f'success={success}, val={val}, '
    133                          f'cmd={shell_command}, var={bisect_var}')
    134     return val
    135   #
    136   bisect_mid, value_mid = None, None
    137   try:
    138     value_left = _get_val(bisect_left)
    139     value_right = _get_val(bisect_right)
    140     if (value_left < target) != (target <= value_right):
    141       raise RuntimeError(
    142           f'Cannot bisect: target={target}, value_left={value_left}, '
    143           f'value_right={value_right}')
    144     for num_iter in range(maxiter):
    145       bisect_mid_f = 0.5 * (bisect_left + bisect_right)
    146       bisect_mid = round(bisect_mid_f) if int_args else bisect_mid_f
    147       value_mid = _get_val(bisect_mid)
    148       if (value_left < target) == (value_mid < target):
    149         # Relative to target, `value_mid` is on the same side
    150         # as `value_left`.
    151         bisect_left = bisect_mid
    152         value_left = value_mid
    153       else:
    154         # Otherwise, this situation must hold for value_right
    155         # ("tertium non datur").
    156         bisect_right = bisect_mid
    157         value_right = value_mid
    158       if verbosity >= 1:
    159         print(f'bisect target={target}, '
    160               f'left: {value_left} at {bisect_left}, '
    161               f'right: {value_right} at {bisect_right}, '
    162               f'mid: {value_mid} at {bisect_mid}')
    163       delta_val = target - value_mid
    164       if abs(delta_val) <= atol_val + rtol_val * abs(target):
    165         return 1, bisect_left, value_left, bisect_right, value_right
    166       delta_arg = bisect_right - bisect_left
    167       # Also check whether the argument is "within tolerance".
    168       # Here, we have to be careful if bisect_left and bisect_right
    169       # have different signs: Then, their absolute magnitude
    170       # "sets the relevant scale".
    171       if abs(delta_arg) <= atol_arg + (
    172               rtol_arg * 0.5 * (abs(bisect_left) + abs(bisect_right))):
    173         return 1, bisect_left, value_left, bisect_right, value_right
    174     return 0, bisect_left, value_left, bisect_right, value_right
    175   finally:
    176     # If cleanup is specified, always run it
    177     if final_shell_command:
    178         run_shell_command(
    179             final_shell_command,
    180             bisect_var,
    181             bisect_mid if bisect_mid is not None else bisect_left,
    182             extra_env_defs, verbosity=verbosity)
    183 
    184 
    185 def main(args):
    186   """Main entry point."""
    187   parser = argparse.ArgumentParser(description='mhtml_walk args')
    188   parser.add_argument(
    189       '--var',
    190       help='The variable to use for bisection.',
    191       default='BISECT')
    192   parser.add_argument(
    193       '--range',
    194       help=('The argument range for bisecting, as {low},{high}. '
    195             'If no argument has a decimal dot, assume integer parameters.'),
    196       default='0.0,1.0')
    197   parser.add_argument(
    198       '--max',
    199       help='The maximal value for bisecting.',
    200       type=float,
    201       default=0.0)
    202   parser.add_argument(
    203       '--target',
    204       help='The target value to aim for.',
    205       type=float,
    206       default=1.0)
    207   parser.add_argument(
    208       '--maxiter',
    209       help='The maximal number of iterations to perform.',
    210       type=int,
    211       default=40)
    212   parser.add_argument(
    213       '--rtol_val',
    214       help='Relative tolerance to accept for deviations from target value.',
    215       type=float,
    216       default=0.0)
    217   parser.add_argument(
    218       '--atol_val',
    219       help='Absolute tolerance to accept for deviations from target value.',
    220       type=float,
    221       default=0.0)
    222   parser.add_argument(
    223       '--rtol_arg',
    224       help='Relative tolerance to accept for the argument.',
    225       type=float,
    226       default=0.0)
    227   parser.add_argument(
    228       '--atol_arg',
    229       help=('Absolute tolerance to accept for the argument '
    230             '(e.g. for bisecting change-IDs).'),
    231       type=float,
    232       default=0.0)
    233   parser.add_argument(
    234       '--verbosity',
    235       help='The verbosity level.',
    236       type=int,
    237       default=1)
    238   parser.add_argument(
    239       '--env',
    240       help=('Comma-separated list of extra environment variables '
    241             'to incrementally add before executing the shell-command.'),
    242       default='')
    243   parser.add_argument(
    244       '--cmd',
    245       help=('The shell command to execute. Must print a numerical result '
    246             'to stdout.'))
    247   parser.add_argument(
    248       '--final',
    249       help='The cleanup shell command to execute.')
    250   #
    251   parsed = parser.parse_args(args)
    252   extra_env_defs = tuple(filter(None, parsed.env.split(',')))    
    253   try:
    254     low_high = parsed.range.split(',')
    255     if len(low_high) != 2:
    256       raise ValueError('--range must be {low},{high}')
    257     int_args = False
    258     low_val, high_val = map(float, low_high)
    259     low_val_int = round(low_val)
    260     high_val_int = round(high_val)
    261     if low_high == [str(low_val_int), str(high_val_int)]:
    262         int_args = True
    263         low_val = low_val_int
    264         high_val = high_val_int
    265     ret = _bisect(
    266         shell_command=parsed.cmd,
    267         final_shell_command=parsed.final,
    268         target=parsed.target,
    269         int_args=int_args,        
    270         bisect_var=parsed.var,
    271         bisect_left=low_val,
    272         bisect_right=high_val,
    273         rtol_val=parsed.rtol_val,
    274         atol_val=parsed.atol_val,
    275         rtol_arg=parsed.rtol_arg,
    276         atol_arg=parsed.atol_arg,
    277         maxiter=parsed.maxiter,
    278         extra_env_defs=extra_env_defs,
    279         verbosity=parsed.verbosity,
    280     )
    281     print(' '.join(map(str, ret)))
    282   except Exception as exn:
    283     sys.exit(f'Problem: {exn}')
    284 
    285 
    286 if __name__ == '__main__':
    287   main(sys.argv[1:])