qemu

FORK: QEMU emulator
git clone https://git.neptards.moe/neptards/qemu.git
Log | Files | Refs | Submodules | LICENSE

test-avx.py (10596B)


      1 #! /usr/bin/env python3
      2 
      3 # Generate test-avx.h from x86.csv
      4 
      5 import csv
      6 import sys
      7 from fnmatch import fnmatch
      8 
      9 archs = [
     10     "SSE", "SSE2", "SSE3", "SSSE3", "SSE4_1", "SSE4_2",
     11     "AES", "AVX", "AVX2", "AES+AVX", "VAES+AVX",
     12     "F16C", "FMA",
     13 ]
     14 
     15 ignore = set(["FISTTP",
     16     "LDMXCSR", "VLDMXCSR", "STMXCSR", "VSTMXCSR"])
     17 
     18 imask = {
     19     'vBLENDPD': 0xff,
     20     'vBLENDPS': 0x0f,
     21     'CMP[PS][SD]': 0x07,
     22     'VCMP[PS][SD]': 0x1f,
     23     'vCVTPS2PH': 0x7,
     24     'vDPPD': 0x33,
     25     'vDPPS': 0xff,
     26     'vEXTRACTPS': 0x03,
     27     'vINSERTPS': 0xff,
     28     'MPSADBW': 0x7,
     29     'VMPSADBW': 0x3f,
     30     'vPALIGNR': 0x3f,
     31     'vPBLENDW': 0xff,
     32     'vPCMP[EI]STR*': 0x0f,
     33     'vPEXTRB': 0x0f,
     34     'vPEXTRW': 0x07,
     35     'vPEXTRD': 0x03,
     36     'vPEXTRQ': 0x01,
     37     'vPINSRB': 0x0f,
     38     'vPINSRW': 0x07,
     39     'vPINSRD': 0x03,
     40     'vPINSRQ': 0x01,
     41     'vPSHUF[DW]': 0xff,
     42     'vPSHUF[LH]W': 0xff,
     43     'vPS[LR][AL][WDQ]': 0x3f,
     44     'vPS[RL]LDQ': 0x1f,
     45     'vROUND[PS][SD]': 0x7,
     46     'vSHUFPD': 0x0f,
     47     'vSHUFPS': 0xff,
     48     'vAESKEYGENASSIST': 0xff,
     49     'VEXTRACT[FI]128': 0x01,
     50     'VINSERT[FI]128': 0x01,
     51     'VPBLENDD': 0xff,
     52     'VPERM2[FI]128': 0x33,
     53     'VPERMPD': 0xff,
     54     'VPERMQ': 0xff,
     55     'VPERMILPS': 0xff,
     56     'VPERMILPD': 0x0f,
     57     }
     58 
     59 def strip_comments(x):
     60     for l in x:
     61         if l != '' and l[0] != '#':
     62             yield l
     63 
     64 def reg_w(w):
     65     if w == 8:
     66         return 'al'
     67     elif w == 16:
     68         return 'ax'
     69     elif w == 32:
     70         return 'eax'
     71     elif w == 64:
     72         return 'rax'
     73     raise Exception("bad reg_w %d" % w)
     74 
     75 def mem_w(w):
     76     if w == 8:
     77         t = "BYTE"
     78     elif w == 16:
     79         t = "WORD"
     80     elif w == 32:
     81         t = "DWORD"
     82     elif w == 64:
     83         t = "QWORD"
     84     elif w == 128:
     85         t = "XMMWORD"
     86     elif w == 256:
     87         t = "YMMWORD"
     88     else:
     89         raise Exception()
     90 
     91     return t + " PTR 32[rdx]"
     92 
     93 class XMMArg():
     94     isxmm = True
     95     def __init__(self, reg, mw):
     96         if mw not in [0, 8, 16, 32, 64, 128, 256]:
     97             raise Exception("Bad /m width: %s" % w)
     98         self.reg = reg
     99         self.mw = mw
    100         self.ismem = mw != 0
    101     def regstr(self, n):
    102         if n < 0:
    103             return mem_w(self.mw)
    104         else:
    105             return "%smm%d" % (self.reg, n)
    106 
    107 class MMArg():
    108     isxmm = True
    109     def __init__(self, mw):
    110         if mw not in [0, 32, 64]:
    111             raise Exception("Bad mem width: %s" % mw)
    112         self.mw = mw
    113         self.ismem = mw != 0
    114     def regstr(self, n):
    115         return "mm%d" % (n & 7)
    116 
    117 def match(op, pattern):
    118     if pattern[0] == 'v':
    119         return fnmatch(op, pattern[1:]) or fnmatch(op, 'V'+pattern[1:])
    120     return fnmatch(op, pattern)
    121 
    122 class ArgVSIB():
    123     isxmm = True
    124     ismem = False
    125     def __init__(self, reg, w):
    126         if w not in [32, 64]:
    127             raise Exception("Bad vsib width: %s" % w)
    128         self.w = w
    129         self.reg = reg
    130     def regstr(self, n):
    131         reg = "%smm%d" % (self.reg, n >> 2)
    132         return "[rsi + %s * %d]" % (reg, 1 << (n & 3))
    133 
    134 class ArgImm8u():
    135     isxmm = False
    136     ismem = False
    137     def __init__(self, op):
    138         for k, v in imask.items():
    139             if match(op, k):
    140                 self.mask = imask[k];
    141                 return
    142         raise Exception("Unknown immediate")
    143     def vals(self):
    144         mask = self.mask
    145         yield 0
    146         n = 0
    147         while n != mask:
    148             n += 1
    149             while (n & ~mask) != 0:
    150                 n += (n & ~mask)
    151             yield n
    152 
    153 class ArgRM():
    154     isxmm = False
    155     def __init__(self, rw, mw):
    156         if rw not in [8, 16, 32, 64]:
    157             raise Exception("Bad r/w width: %s" % w)
    158         if mw not in [0, 8, 16, 32, 64]:
    159             raise Exception("Bad r/w width: %s" % w)
    160         self.rw = rw
    161         self.mw = mw
    162         self.ismem = mw != 0
    163     def regstr(self, n):
    164         if n < 0:
    165             return mem_w(self.mw)
    166         else:
    167             return reg_w(self.rw)
    168 
    169 class ArgMem():
    170     isxmm = False
    171     ismem = True
    172     def __init__(self, w):
    173         if w not in [8, 16, 32, 64, 128, 256]:
    174             raise Exception("Bad mem width: %s" % w)
    175         self.w = w
    176     def regstr(self, n):
    177         return mem_w(self.w)
    178 
    179 class SkipInstruction(Exception):
    180     pass
    181 
    182 def ArgGenerator(arg, op):
    183     if arg[:3] == 'xmm' or arg[:3] == "ymm":
    184         if "/" in arg:
    185             r, m = arg.split('/')
    186             if (m[0] != 'm'):
    187                 raise Exception("Expected /m: %s", arg)
    188             return XMMArg(arg[0], int(m[1:]));
    189         else:
    190             return XMMArg(arg[0], 0);
    191     elif arg[:2] == 'mm':
    192         if "/" in arg:
    193             r, m = arg.split('/')
    194             if (m[0] != 'm'):
    195                 raise Exception("Expected /m: %s", arg)
    196             return MMArg(int(m[1:]));
    197         else:
    198             return MMArg(0);
    199     elif arg[:4] == 'imm8':
    200         return ArgImm8u(op);
    201     elif arg == '<XMM0>':
    202         return None
    203     elif arg[0] == 'r':
    204         if '/m' in arg:
    205             r, m = arg.split('/')
    206             if (m[0] != 'm'):
    207                 raise Exception("Expected /m: %s", arg)
    208             mw = int(m[1:])
    209             if r == 'r':
    210                 rw = mw
    211             else:
    212                 rw = int(r[1:])
    213             return ArgRM(rw, mw)
    214 
    215         return ArgRM(int(arg[1:]), 0);
    216     elif arg[0] == 'm':
    217         return ArgMem(int(arg[1:]))
    218     elif arg[:2] == 'vm':
    219         return ArgVSIB(arg[-1], int(arg[2:-1]))
    220     else:
    221         raise Exception("Unrecognised arg: %s", arg)
    222 
    223 class InsnGenerator:
    224     def __init__(self, op, args):
    225         self.op = op
    226         if op[-2:] in ["PH", "PS", "PD", "SS", "SD"]:
    227             if op[-1] == 'H':
    228                 self.optype = 'F16'
    229             elif op[-1] == 'S':
    230                 self.optype = 'F32'
    231             else:
    232                 self.optype = 'F64'
    233         else:
    234             self.optype = 'I'
    235 
    236         try:
    237             self.args = list(ArgGenerator(a, op) for a in args)
    238             if not any((x.isxmm for x in self.args)):
    239                 raise SkipInstruction
    240             if len(self.args) > 0 and self.args[-1] is None:
    241                 self.args = self.args[:-1]
    242         except SkipInstruction:
    243             raise
    244         except Exception as e:
    245             raise Exception("Bad arg %s: %s" % (op, e))
    246 
    247     def gen(self):
    248         regs = (10, 11, 12)
    249         dest = 9
    250 
    251         nreg = len(self.args)
    252         if nreg == 0:
    253             yield self.op
    254             return
    255         if isinstance(self.args[-1], ArgImm8u):
    256             nreg -= 1
    257             immarg = self.args[-1]
    258         else:
    259             immarg = None
    260         memarg = -1
    261         for n, arg in enumerate(self.args):
    262             if arg.ismem:
    263                 memarg = n
    264 
    265         if (self.op.startswith("VGATHER") or self.op.startswith("VPGATHER")):
    266             if "GATHERD" in self.op:
    267                 ireg = 13 << 2
    268             else:
    269                 ireg = 14 << 2
    270             regset = [
    271                 (dest, ireg | 0, regs[0]),
    272                 (dest, ireg | 1, regs[0]),
    273                 (dest, ireg | 2, regs[0]),
    274                 (dest, ireg | 3, regs[0]),
    275                 ]
    276             if memarg >= 0:
    277                 raise Exception("vsib with memory: %s" % self.op)
    278         elif nreg == 1:
    279             regset = [(regs[0],)]
    280             if memarg == 0:
    281                 regset += [(-1,)]
    282         elif nreg == 2:
    283             regset = [
    284                 (regs[0], regs[1]),
    285                 (regs[0], regs[0]),
    286                 ]
    287             if memarg == 0:
    288                 regset += [(-1, regs[0])]
    289             elif memarg == 1:
    290                 regset += [(dest, -1)]
    291         elif nreg == 3:
    292             regset = [
    293                 (dest, regs[0], regs[1]),
    294                 (dest, regs[0], regs[0]),
    295                 (regs[0], regs[0], regs[1]),
    296                 (regs[0], regs[1], regs[0]),
    297                 (regs[0], regs[0], regs[0]),
    298                 ]
    299             if memarg == 2:
    300                 regset += [
    301                     (dest, regs[0], -1),
    302                     (regs[0], regs[0], -1),
    303                     ]
    304             elif memarg > 0:
    305                 raise Exception("Memarg %d" % memarg)
    306         elif nreg == 4:
    307             regset = [
    308                 (dest, regs[0], regs[1], regs[2]),
    309                 (dest, regs[0], regs[0], regs[1]),
    310                 (dest, regs[0], regs[1], regs[0]),
    311                 (dest, regs[1], regs[0], regs[0]),
    312                 (dest, regs[0], regs[0], regs[0]),
    313                 (regs[0], regs[0], regs[1], regs[2]),
    314                 (regs[0], regs[1], regs[0], regs[2]),
    315                 (regs[0], regs[1], regs[2], regs[0]),
    316                 (regs[0], regs[0], regs[0], regs[1]),
    317                 (regs[0], regs[0], regs[1], regs[0]),
    318                 (regs[0], regs[1], regs[0], regs[0]),
    319                 (regs[0], regs[0], regs[0], regs[0]),
    320                 ]
    321             if memarg == 2:
    322                 regset += [
    323                     (dest, regs[0], -1, regs[1]),
    324                     (dest, regs[0], -1, regs[0]),
    325                     (regs[0], regs[0], -1, regs[1]),
    326                     (regs[0], regs[1], -1, regs[0]),
    327                     (regs[0], regs[0], -1, regs[0]),
    328                     ]
    329             elif memarg > 0:
    330                 raise Exception("Memarg4 %d" % memarg)
    331         else:
    332             raise Exception("Too many regs: %s(%d)" % (self.op, nreg))
    333 
    334         for regv in regset:
    335             argstr = []
    336             for i in range(nreg):
    337                 arg = self.args[i]
    338                 argstr.append(arg.regstr(regv[i]))
    339             if immarg is None:
    340                 yield self.op + ' ' + ','.join(argstr)
    341             else:
    342                 for immval in immarg.vals():
    343                     yield self.op + ' ' + ','.join(argstr) + ',' + str(immval)
    344 
    345 def split0(s):
    346     if s == '':
    347         return []
    348     return s.split(',')
    349 
    350 def main():
    351     n = 0
    352     if len(sys.argv) != 3:
    353         print("Usage: test-avx.py x86.csv test-avx.h")
    354         exit(1)
    355     csvfile = open(sys.argv[1], 'r', newline='')
    356     with open(sys.argv[2], "w") as outf:
    357         outf.write("// Generated by test-avx.py. Do not edit.\n")
    358         for row in csv.reader(strip_comments(csvfile)):
    359             insn = row[0].replace(',', '').split()
    360             if insn[0] in ignore:
    361                 continue
    362             cpuid = row[6]
    363             if cpuid in archs:
    364                 try:
    365                     g = InsnGenerator(insn[0], insn[1:])
    366                     for insn in g.gen():
    367                         outf.write('TEST(%d, "%s", %s)\n' % (n, insn, g.optype))
    368                         n += 1
    369                 except SkipInstruction:
    370                     pass
    371         outf.write("#undef TEST\n")
    372         csvfile.close()
    373 
    374 if __name__ == "__main__":
    375     main()