qemu

FORK: QEMU emulator
git clone https://git.neptards.moe/neptards/qemu.git
Log | Files | Refs | Submodules | LICENSE

minimize_qtest_trace.py (11613B)


      1 #!/usr/bin/env python3
      2 # -*- coding: utf-8 -*-
      3 
"""
This takes a crashing qtest trace and tries to remove superfluous operations.
"""
      7 
      8 import sys
      9 import os
     10 import subprocess
     11 import time
     12 import struct
     13 
     14 QEMU_ARGS = None
     15 QEMU_PATH = None
     16 TIMEOUT = 5
     17 CRASH_TOKEN = None
     18 
     19 # Minimization levels
     20 M1 = False # try removing IO commands iteratively
     21 M2 = False # try setting bits in operand of write/out to zero
     22 
     23 write_suffix_lookup = {"b": (1, "B"),
     24                        "w": (2, "H"),
     25                        "l": (4, "L"),
     26                        "q": (8, "Q")}
     27 
     28 def usage():
     29     sys.exit("""\
     30 Usage:
     31 
     32 QEMU_PATH="/path/to/qemu" QEMU_ARGS="args" {} [Options] input_trace output_trace
     33 
     34 By default, will try to use the second-to-last line in the output to identify
     35 whether the crash occred. Optionally, manually set a string that idenitifes the
     36 crash by setting CRASH_TOKEN=
     37 
     38 Options:
     39 
     40 -M1: enable a loop around the remove minimizer, which may help decrease some
     41      timing dependant instructions. Off by default.
     42 -M2: try setting bits in operand of write/out to zero. Off by default.
     43 
     44 """.format((sys.argv[0])))
     45 
     46 deduplication_note = """\n\
     47 Note: While trimming the input, sometimes the mutated trace triggers a different
     48 type crash but indicates the same bug. Under this situation, our minimizer is
     49 incapable of recognizing and stopped from removing it. In the future, we may
     50 use a more sophisticated crash case deduplication method.
     51 \n"""
     52 
     53 def check_if_trace_crashes(trace, path):
     54     with open(path, "w") as tracefile:
     55         tracefile.write("".join(trace))
     56 
     57     rc = subprocess.Popen("timeout -s 9 {timeout}s {qemu_path} {qemu_args} 2>&1\
     58     < {trace_path}".format(timeout=TIMEOUT,
     59                            qemu_path=QEMU_PATH,
     60                            qemu_args=QEMU_ARGS,
     61                            trace_path=path),
     62                           shell=True,
     63                           stdin=subprocess.PIPE,
     64                           stdout=subprocess.PIPE,
     65                           encoding="utf-8")
     66     global CRASH_TOKEN
     67     if CRASH_TOKEN is None:
     68         try:
     69             outs, _ = rc.communicate(timeout=5)
     70             CRASH_TOKEN = " ".join(outs.splitlines()[-2].split()[0:3])
     71         except subprocess.TimeoutExpired:
     72             print("subprocess.TimeoutExpired")
     73             return False
     74         print("Identifying Crashes by this string: {}".format(CRASH_TOKEN))
     75         global deduplication_note
     76         print(deduplication_note)
     77         return True
     78 
     79     for line in iter(rc.stdout.readline, ""):
     80         if "CLOSED" in line:
     81             return False
     82         if CRASH_TOKEN in line:
     83             return True
     84 
     85     print("\nWarning:")
     86     print("  There is no 'CLOSED'or CRASH_TOKEN in the stdout of subprocess.")
     87     print("  Usually this indicates a different type of crash.\n")
     88     return False
     89 
     90 
     91 # If previous write commands write the same length of data at the same
     92 # interval, we view it as a hint.
     93 def split_write_hint(newtrace, i):
     94     HINT_LEN = 3 # > 2
     95     if i <=(HINT_LEN-1):
     96         return None
     97 
     98     #find previous continuous write traces
     99     k = 0
    100     l = i-1
    101     writes = []
    102     while (k != HINT_LEN and l >= 0):
    103         if newtrace[l].startswith("write "):
    104             writes.append(newtrace[l])
    105             k += 1
    106             l -= 1
    107         elif newtrace[l] == "":
    108             l -= 1
    109         else:
    110             return None
    111     if k != HINT_LEN:
    112         return None
    113 
    114     length = int(writes[0].split()[2], 16)
    115     for j in range(1, HINT_LEN):
    116         if length != int(writes[j].split()[2], 16):
    117             return None
    118 
    119     step = int(writes[0].split()[1], 16) - int(writes[1].split()[1], 16)
    120     for j in range(1, HINT_LEN-1):
    121         if step != int(writes[j].split()[1], 16) - \
    122             int(writes[j+1].split()[1], 16):
    123             return None
    124 
    125     return (int(writes[0].split()[1], 16)+step, length)
    126 
    127 
    128 def remove_lines(newtrace, outpath):
    129     remove_step = 1
    130     i = 0
    131     while i < len(newtrace):
    132         # 1.) Try to remove lines completely and reproduce the crash.
    133         # If it works, we're done.
    134         if (i+remove_step) >= len(newtrace):
    135             remove_step = 1
    136         prior = newtrace[i:i+remove_step]
    137         for j in range(i, i+remove_step):
    138             newtrace[j] = ""
    139         print("Removing {lines} ...\n".format(lines=prior))
    140         if check_if_trace_crashes(newtrace, outpath):
    141             i += remove_step
    142             # Double the number of lines to remove for next round
    143             remove_step *= 2
    144             continue
    145         # Failed to remove multiple IOs, fast recovery
    146         if remove_step > 1:
    147             for j in range(i, i+remove_step):
    148                 newtrace[j] = prior[j-i]
    149             remove_step = 1
    150             continue
    151         newtrace[i] = prior[0] # remove_step = 1
    152 
    153         # 2.) Try to replace write{bwlq} commands with a write addr, len
    154         # command. Since this can require swapping endianness, try both LE and
    155         # BE options. We do this, so we can "trim" the writes in (3)
    156 
    157         if (newtrace[i].startswith("write") and not
    158             newtrace[i].startswith("write ")):
    159             suffix = newtrace[i].split()[0][-1]
    160             assert(suffix in write_suffix_lookup)
    161             addr = int(newtrace[i].split()[1], 16)
    162             value = int(newtrace[i].split()[2], 16)
    163             for endianness in ['<', '>']:
    164                 data = struct.pack("{end}{size}".format(end=endianness,
    165                                    size=write_suffix_lookup[suffix][1]),
    166                                    value)
    167                 newtrace[i] = "write {addr} {size} 0x{data}\n".format(
    168                     addr=hex(addr),
    169                     size=hex(write_suffix_lookup[suffix][0]),
    170                     data=data.hex())
    171                 if(check_if_trace_crashes(newtrace, outpath)):
    172                     break
    173             else:
    174                 newtrace[i] = prior[0]
    175 
    176         # 3.) If it is a qtest write command: write addr len data, try to split
    177         # it into two separate write commands. If splitting the data operand
    178         # from length/2^n bytes to the left does not work, try to move the pivot
    179         # to the right side, then add one to n, until length/2^n == 0. The idea
    180         # is to prune unneccessary bytes from long writes, while accommodating
    181         # arbitrary MemoryRegion access sizes and alignments.
    182 
    183         # This algorithm will fail under some rare situations.
    184         # e.g., xxxxxxxxxuxxxxxx (u is the unnecessary byte)
    185 
    186         if newtrace[i].startswith("write "):
    187             addr = int(newtrace[i].split()[1], 16)
    188             length = int(newtrace[i].split()[2], 16)
    189             data = newtrace[i].split()[3][2:]
    190             if length > 1:
    191 
    192                 # Can we get a hint from previous writes?
    193                 hint = split_write_hint(newtrace, i)
    194                 if hint is not None:
    195                     hint_addr = hint[0]
    196                     hint_len = hint[1]
    197                     if hint_addr >= addr and hint_addr+hint_len <= addr+length:
    198                         newtrace[i] = "write {addr} {size} 0x{data}\n".format(
    199                             addr=hex(hint_addr),
    200                             size=hex(hint_len),
    201                             data=data[(hint_addr-addr)*2:\
    202                                 (hint_addr-addr)*2+hint_len*2])
    203                         if check_if_trace_crashes(newtrace, outpath):
    204                             # next round
    205                             i += 1
    206                             continue
    207                         newtrace[i] = prior[0]
    208 
    209                 # Try splitting it using a binary approach
    210                 leftlength = int(length/2)
    211                 rightlength = length - leftlength
    212                 newtrace.insert(i+1, "")
    213                 power = 1
    214                 while leftlength > 0:
    215                     newtrace[i] = "write {addr} {size} 0x{data}\n".format(
    216                             addr=hex(addr),
    217                             size=hex(leftlength),
    218                             data=data[:leftlength*2])
    219                     newtrace[i+1] = "write {addr} {size} 0x{data}\n".format(
    220                             addr=hex(addr+leftlength),
    221                             size=hex(rightlength),
    222                             data=data[leftlength*2:])
    223                     if check_if_trace_crashes(newtrace, outpath):
    224                         break
    225                     # move the pivot to right side
    226                     if leftlength < rightlength:
    227                         rightlength, leftlength = leftlength, rightlength
    228                         continue
    229                     power += 1
    230                     leftlength = int(length/pow(2, power))
    231                     rightlength = length - leftlength
    232                 if check_if_trace_crashes(newtrace, outpath):
    233                     i -= 1
    234                 else:
    235                     newtrace[i] = prior[0]
    236                     del newtrace[i+1]
    237         i += 1
    238 
    239 
    240 def clear_bits(newtrace, outpath):
    241     # try setting bits in operands of out/write to zero
    242     i = 0
    243     while i < len(newtrace):
    244         if (not newtrace[i].startswith("write ") and not
    245            newtrace[i].startswith("out")):
    246            i += 1
    247            continue
    248         # write ADDR SIZE DATA
    249         # outx ADDR VALUE
    250         print("\nzero setting bits: {}".format(newtrace[i]))
    251 
    252         prefix = " ".join(newtrace[i].split()[:-1])
    253         data = newtrace[i].split()[-1]
    254         data_bin = bin(int(data, 16))
    255         data_bin_list = list(data_bin)
    256 
    257         for j in range(2, len(data_bin_list)):
    258             prior = newtrace[i]
    259             if (data_bin_list[j] == '1'):
    260                 data_bin_list[j] = '0'
    261                 data_try = hex(int("".join(data_bin_list), 2))
    262                 # It seems qtest only accepts padded hex-values.
    263                 if len(data_try) % 2 == 1:
    264                     data_try = data_try[:2] + "0" + data_try[2:]
    265 
    266                 newtrace[i] = "{prefix} {data_try}\n".format(
    267                         prefix=prefix,
    268                         data_try=data_try)
    269 
    270                 if not check_if_trace_crashes(newtrace, outpath):
    271                     data_bin_list[j] = '1'
    272                     newtrace[i] = prior
    273         i += 1
    274 
    275 
    276 def minimize_trace(inpath, outpath):
    277     global TIMEOUT
    278     with open(inpath) as f:
    279         trace = f.readlines()
    280     start = time.time()
    281     if not check_if_trace_crashes(trace, outpath):
    282         sys.exit("The input qtest trace didn't cause a crash...")
    283     end = time.time()
    284     print("Crashed in {} seconds".format(end-start))
    285     TIMEOUT = (end-start)*5
    286     print("Setting the timeout for {} seconds".format(TIMEOUT))
    287 
    288     newtrace = trace[:]
    289     global M1, M2
    290 
    291     # remove lines
    292     old_len = len(newtrace) + 1
    293     while(old_len > len(newtrace)):
    294         old_len = len(newtrace)
    295         print("trace lenth = ", old_len)
    296         remove_lines(newtrace, outpath)
    297         if not M1 and not M2:
    298             break
    299         newtrace = list(filter(lambda s: s != "", newtrace))
    300     assert(check_if_trace_crashes(newtrace, outpath))
    301 
    302     # set bits to zero
    303     if M2:
    304         clear_bits(newtrace, outpath)
    305     assert(check_if_trace_crashes(newtrace, outpath))
    306 
    307 
    308 if __name__ == '__main__':
    309     if len(sys.argv) < 3:
    310         usage()
    311     if "-M1" in sys.argv:
    312         M1 = True
    313     if "-M2" in sys.argv:
    314         M2 = True
    315     QEMU_PATH = os.getenv("QEMU_PATH")
    316     QEMU_ARGS = os.getenv("QEMU_ARGS")
    317     if QEMU_PATH is None or QEMU_ARGS is None:
    318         usage()
    319     # if "accel" not in QEMU_ARGS:
    320     #     QEMU_ARGS += " -accel qtest"
    321     CRASH_TOKEN = os.getenv("CRASH_TOKEN")
    322     QEMU_ARGS += " -qtest stdio -monitor none -serial none "
    323     minimize_trace(sys.argv[-2], sys.argv[-1])