ljx

FORK: LuaJIT with native 5.2 and 5.3 support
git clone https://git.neptards.moe/neptards/ljx.git
Log | Files | Refs | README

lj_carith.c (12498B)


      1 /*
      2 ** C data arithmetic.
      3 ** Copyright (C) 2005-2016 Mike Pall. See Copyright Notice in luajit.h
      4 */
      5 
      6 #include "lj_obj.h"
      7 
      8 #if LJ_HASFFI
      9 
     10 #include "lj_gc.h"
     11 #include "lj_err.h"
     12 #include "lj_tab.h"
     13 #include "lj_meta.h"
     14 #include "lj_ir.h"
     15 #include "lj_ctype.h"
     16 #include "lj_cconv.h"
     17 #include "lj_cdata.h"
     18 #include "lj_carith.h"
     19 #include "lj_strscan.h"
     20 
     21 /* -- C data arithmetic --------------------------------------------------- */
     22 
     23 /* Binary operands of an operator converted to ctypes. */
     24 typedef struct CDArith {
     25   uint8_t *p[2];
     26   CType *ct[2];
     27 } CDArith;
     28 
/* Check arguments for arithmetic metamethods.
**
** Converts the two operands at L->base[0] and L->base[1] into a
** (ctype, data pointer) pair each, stored in ca->ct[i] / ca->p[i]:
** - cdata: its raw ctype and data. For pointer ctypes p holds the pointer
**   value itself; references are additionally resolved to their child type.
**   Functions become a pointer-to-function. Enums are replaced by their
**   child ctype.
** - integer / number: int32_t / double, pointing at the TValue payload.
** - nil: a void pointer with value NULL.
** - string: only convertible if the other operand is an enum cdata and the
**   string names one of its constants.
** - anything else: ct = NULL and p = (void *)1.
**
** Returns 1 if both operands could be converted, 0 otherwise. Raises an
** argument error if fewer than two values were passed.
*/
static int carith_checkarg(lua_State *L, CTState *cts, CDArith *ca)
{
  TValue *o = L->base;
  int ok = 1;  /* Cleared as soon as one operand cannot be converted. */
  MSize i;
  if (o+1 >= L->top)
    lj_err_argt(L, 1, LUA_TCDATA);
  for (i = 0; i < 2; i++, o++) {
    if (tviscdata(o)) {
      GCcdata *cd = cdataV(o);
      CTypeID id = (CTypeID)cd->ctypeid;
      CType *ct = ctype_raw(cts, id);
      uint8_t *p = (uint8_t *)cdataptr(cd);
      if (ctype_isptr(ct->info)) {
	/* Load the pointer value; for a reference, continue with the
	** referenced type so p points at the referenced data.
	*/
	p = (uint8_t *)cdata_getptr(p, ct->size);
	if (ctype_isref(ct->info)) ct = ctype_rawchild(cts, ct);
      } else if (ctype_isfunc(ct->info)) {
	/* Treat a function as a pointer to that function type. */
	p = (uint8_t *)*(void **)p;
	ct = ctype_get(cts,
	  lj_ctype_intern(cts, CTINFO(CT_PTR, CTALIGN_PTR|id), CTSIZE_PTR));
      }
      if (ctype_isenum(ct->info)) ct = ctype_child(cts, ct);
      ca->ct[i] = ct;
      ca->p[i] = p;
    } else if (tvisint(o)) {
      ca->ct[i] = ctype_get(cts, CTID_INT32);
      ca->p[i] = (uint8_t *)&o->i;
    } else if (tvisnum(o)) {
      ca->ct[i] = ctype_get(cts, CTID_DOUBLE);
      ca->p[i] = (uint8_t *)&o->n;
    } else if (tvisnil(o)) {
      ca->ct[i] = ctype_get(cts, CTID_P_VOID);
      ca->p[i] = (uint8_t *)0;
    } else if (tvisstr(o)) {
      /* NOTE(review): assumes the other operand is a cdata. Presumably this
      ** is guaranteed because these metamethods only run with at least one
      ** cdata operand -- confirm against the callers.
      */
      TValue *o2 = i == 0 ? o+1 : o-1;
      CType *ct = ctype_raw(cts, cdataV(o2)->ctypeid);
      ca->ct[i] = NULL;
      ca->p[i] = (uint8_t *)strVdata(o);
      ok = 0;
      if (ctype_isenum(ct->info)) {
	/* Look the string up as a constant of the enum. */
	CTSize ofs;
	CType *cct = lj_ctype_getfield(cts, ct, strV(o), &ofs);
	if (cct && ctype_isconstval(cct->info)) {
	  /* The constant's value lives in cct->size. */
	  ca->ct[i] = ctype_child(cts, cct);
	  ca->p[i] = (uint8_t *)&cct->size;  /* Assumes ct does not grow. */
	  ok = 1;
	} else {
	  /* Unknown constant: report the enum type in the error and stop. */
	  ca->ct[1-i] = ct;  /* Use enum to improve error message. */
	  ca->p[1-i] = NULL;
	  break;
	}
      }
    } else {
      ca->ct[i] = NULL;
      ca->p[i] = (void *)(intptr_t)1;  /* To make it unequal. */
      ok = 0;
    }
  }
  return ok;
}
     90 
     91 /* Pointer arithmetic. */
     92 static int carith_ptr(lua_State *L, CTState *cts, CDArith *ca, MMS mm)
     93 {
     94   CType *ctp = ca->ct[0];
     95   uint8_t *pp = ca->p[0];
     96   ptrdiff_t idx;
     97   CTSize sz;
     98   CTypeID id;
     99   GCcdata *cd;
    100   if (ctype_isptr(ctp->info) || ctype_isrefarray(ctp->info)) {
    101     if ((mm == MM_sub || mm == MM_eq || mm == MM_lt || mm == MM_le) &&
    102 	(ctype_isptr(ca->ct[1]->info) || ctype_isrefarray(ca->ct[1]->info))) {
    103       uint8_t *pp2 = ca->p[1];
    104       if (mm == MM_eq) {  /* Pointer equality. Incompatible pointers are ok. */
    105 	setboolV(L->top-1, (pp == pp2));
    106 	return 1;
    107       }
    108       if (!lj_cconv_compatptr(cts, ctp, ca->ct[1], CCF_IGNQUAL))
    109 	return 0;
    110       if (mm == MM_sub) {  /* Pointer difference. */
    111 	intptr_t diff;
    112 	sz = lj_ctype_size(cts, ctype_cid(ctp->info));  /* Element size. */
    113 	if (sz == 0 || sz == CTSIZE_INVALID)
    114 	  return 0;
    115 	diff = ((intptr_t)pp - (intptr_t)pp2) / (int32_t)sz;
    116 	/* All valid pointer differences on x64 are in (-2^47, +2^47),
    117 	** which fits into a double without loss of precision.
    118 	*/
    119 	setintptrV(L->top-1, (int32_t)diff);
    120 	return 1;
    121       } else if (mm == MM_lt) {  /* Pointer comparison (unsigned). */
    122 	setboolV(L->top-1, ((uintptr_t)pp < (uintptr_t)pp2));
    123 	return 1;
    124       } else {
    125 	lua_assert(mm == MM_le);
    126 	setboolV(L->top-1, ((uintptr_t)pp <= (uintptr_t)pp2));
    127 	return 1;
    128       }
    129     }
    130     if (!((mm == MM_add || mm == MM_sub) && ctype_isnum(ca->ct[1]->info)))
    131       return 0;
    132     lj_cconv_ct_ct(cts, ctype_get(cts, CTID_INT_PSZ), ca->ct[1],
    133 		   (uint8_t *)&idx, ca->p[1], 0);
    134     if (mm == MM_sub) idx = -idx;
    135   } else if (mm == MM_add && ctype_isnum(ctp->info) &&
    136       (ctype_isptr(ca->ct[1]->info) || ctype_isrefarray(ca->ct[1]->info))) {
    137     /* Swap pointer and index. */
    138     ctp = ca->ct[1]; pp = ca->p[1];
    139     lj_cconv_ct_ct(cts, ctype_get(cts, CTID_INT_PSZ), ca->ct[0],
    140 		   (uint8_t *)&idx, ca->p[0], 0);
    141   } else {
    142     return 0;
    143   }
    144   sz = lj_ctype_size(cts, ctype_cid(ctp->info));  /* Element size. */
    145   if (sz == CTSIZE_INVALID)
    146     return 0;
    147   pp += idx*(int32_t)sz;  /* Compute pointer + index. */
    148   id = lj_ctype_intern(cts, CTINFO(CT_PTR, CTALIGN_PTR|ctype_cid(ctp->info)),
    149 		       CTSIZE_PTR);
    150   cd = lj_cdata_new(cts, id, CTSIZE_PTR);
    151   *(uint8_t **)cdataptr(cd) = pp;
    152   setcdataV(L, L->top-1, cd);
    153   lj_gc_check(L);
    154   return 1;
    155 }
    156 
    157 /* 64 bit integer arithmetic. */
    158 static int carith_int64(lua_State *L, CTState *cts, CDArith *ca, MMS mm)
    159 {
    160   if (ctype_isnum(ca->ct[0]->info) && ca->ct[0]->size <= 8 &&
    161       ctype_isnum(ca->ct[1]->info) && ca->ct[1]->size <= 8) {
    162     CTypeID id = (((ca->ct[0]->info & CTF_UNSIGNED) && ca->ct[0]->size == 8) ||
    163 		  ((ca->ct[1]->info & CTF_UNSIGNED) && ca->ct[1]->size == 8)) ?
    164 		 CTID_UINT64 : CTID_INT64;
    165     CType *ct = ctype_get(cts, id);
    166     GCcdata *cd;
    167     uint64_t u0, u1, *up;
    168     lj_cconv_ct_ct(cts, ct, ca->ct[0], (uint8_t *)&u0, ca->p[0], 0);
    169     if (mm != MM_unm)
    170       lj_cconv_ct_ct(cts, ct, ca->ct[1], (uint8_t *)&u1, ca->p[1], 0);
    171     switch (mm) {
    172     case MM_eq:
    173       setboolV(L->top-1, (u0 == u1));
    174       return 1;
    175     case MM_lt:
    176       setboolV(L->top-1,
    177 	       id == CTID_INT64 ? ((int64_t)u0 < (int64_t)u1) : (u0 < u1));
    178       return 1;
    179     case MM_le:
    180       setboolV(L->top-1,
    181 	       id == CTID_INT64 ? ((int64_t)u0 <= (int64_t)u1) : (u0 <= u1));
    182       return 1;
    183     default: break;
    184     }
    185     cd = lj_cdata_new(cts, id, 8);
    186     up = (uint64_t *)cdataptr(cd);
    187     setcdataV(L, L->top-1, cd);
    188     switch (mm) {
    189     case MM_add: *up = u0 + u1; break;
    190     case MM_sub: *up = u0 - u1; break;
    191     case MM_mul: *up = u0 * u1; break;
    192     case MM_div:
    193       if (id == CTID_INT64)
    194 	*up = (uint64_t)lj_carith_divi64((int64_t)u0, (int64_t)u1);
    195       else
    196 	*up = lj_carith_divu64(u0, u1);
    197       break;
    198     case MM_mod:
    199       if (id == CTID_INT64)
    200 	*up = (uint64_t)lj_carith_modi64((int64_t)u0, (int64_t)u1);
    201       else
    202 	*up = lj_carith_modu64(u0, u1);
    203       break;
    204     case MM_pow:
    205       if (id == CTID_INT64)
    206 	*up = (uint64_t)lj_carith_powi64((int64_t)u0, (int64_t)u1);
    207       else
    208 	*up = lj_carith_powu64(u0, u1);
    209       break;
    210     case MM_unm: *up = (uint64_t)-(int64_t)u0; break;
    211     default: lua_assert(0); break;
    212     }
    213     lj_gc_check(L);
    214     return 1;
    215   }
    216   return 0;
    217 }
    218 
    219 /* Handle ctype arithmetic metamethods. */
    220 static int lj_carith_meta(lua_State *L, CTState *cts, CDArith *ca, MMS mm)
    221 {
    222   cTValue *tv = NULL;
    223   if (tviscdata(L->base)) {
    224     CTypeID id = cdataV(L->base)->ctypeid;
    225     CType *ct = ctype_raw(cts, id);
    226     if (ctype_isptr(ct->info)) id = ctype_cid(ct->info);
    227     tv = lj_ctype_meta(cts, id, mm);
    228   }
    229   if (!tv && L->base+1 < L->top && tviscdata(L->base+1)) {
    230     CTypeID id = cdataV(L->base+1)->ctypeid;
    231     CType *ct = ctype_raw(cts, id);
    232     if (ctype_isptr(ct->info)) id = ctype_cid(ct->info);
    233     tv = lj_ctype_meta(cts, id, mm);
    234   }
    235   if (!tv) {
    236     const char *repr[2];
    237     int i, isenum = -1, isstr = -1;
    238     if (mm == MM_eq) {  /* Equality checks never raise an error. */
    239       int eq = ca->p[0] == ca->p[1];
    240       setboolV(L->top-1, eq);
    241       setboolV(&G(L)->tmptv2, eq);  /* Remember for trace recorder. */
    242       return 1;
    243     }
    244     for (i = 0; i < 2; i++) {
    245       if (ca->ct[i] && tviscdata(L->base+i)) {
    246 	if (ctype_isenum(ca->ct[i]->info)) isenum = i;
    247 	repr[i] = strdata(lj_ctype_repr(L, ctype_typeid(cts, ca->ct[i]), NULL));
    248       } else {
    249 	if (tvisstr(&L->base[i])) isstr = i;
    250 	repr[i] = lj_typename(&L->base[i]);
    251       }
    252     }
    253     if ((isenum ^ isstr) == 1)
    254       lj_err_callerv(L, LJ_ERR_FFI_BADCONV, repr[isstr], repr[isenum]);
    255     lj_err_callerv(L, mm == MM_len ? LJ_ERR_FFI_BADLEN :
    256 		      mm == MM_concat ? LJ_ERR_FFI_BADCONCAT :
    257 		      mm < MM_add ? LJ_ERR_FFI_BADCOMP : LJ_ERR_FFI_BADARITH,
    258 		   repr[0], repr[1]);
    259   }
    260   return lj_meta_tailcall(L, tv);
    261 }
    262 
    263 /* Arithmetic operators for cdata. */
    264 int lj_carith_op(lua_State *L, MMS mm)
    265 {
    266   CTState *cts = ctype_cts(L);
    267   CDArith ca;
    268   if (carith_checkarg(L, cts, &ca)) {
    269     if (carith_int64(L, cts, &ca, mm) || carith_ptr(L, cts, &ca, mm)) {
    270       copyTV(L, &G(L)->tmptv2, L->top-1);  /* Remember for trace recorder. */
    271       return 1;
    272     }
    273   }
    274   return lj_carith_meta(L, cts, &ca, mm);
    275 }
    276 
    277 /* -- 64 bit bit operations helpers --------------------------------------- */
    278 
    279 #if LJ_64
    280 #define B64DEF(name) \
    281   static LJ_AINLINE uint64_t lj_carith_##name(uint64_t x, int32_t sh)
    282 #else
    283 /* Not inlined on 32 bit archs, since some of these are quite lengthy. */
    284 #define B64DEF(name) \
    285   uint64_t LJ_NOINLINE lj_carith_##name(uint64_t x, int32_t sh)
    286 #endif
    287 
    288 B64DEF(shl64) { return x << (sh&63); }
    289 B64DEF(shr64) { return x >> (sh&63); }
    290 B64DEF(sar64) { return (uint64_t)((int64_t)x >> (sh&63)); }
    291 B64DEF(rol64) { return lj_rol(x, (sh&63)); }
    292 B64DEF(ror64) { return lj_ror(x, (sh&63)); }
    293 
    294 #undef B64DEF
    295 
    296 uint64_t lj_carith_shift64(uint64_t x, int32_t sh, int op)
    297 {
    298   switch (op) {
    299   case IR_BSHL-IR_BSHL: x = lj_carith_shl64(x, sh); break;
    300   case IR_BSHR-IR_BSHL: x = lj_carith_shr64(x, sh); break;
    301   case IR_BSAR-IR_BSHL: x = lj_carith_sar64(x, sh); break;
    302   case IR_BROL-IR_BSHL: x = lj_carith_rol64(x, sh); break;
    303   case IR_BROR-IR_BSHL: x = lj_carith_ror64(x, sh); break;
    304   default: lua_assert(0); break;
    305   }
    306   return x;
    307 }
    308 
/* Equivalent to lj_lib_checkbit(), but handles cdata.
**
** Fetches argument 'narg' (1-based) as a 64 bit integer:
** - cdata: references are followed and enums replaced by their child type.
**   The value is then converted to the ctype in *id. *id is set to
**   CTID_UINT64 for an unsigned 8 byte integer argument. Otherwise it is set
**   to CTID_INT64 if still 0; a previously chosen id is kept.
** - numbers, and strings accepted by lj_strscan_number(): reduced to a
**   32 bit value, which is returned zero-extended. *id is left unchanged.
** A missing argument or any other type raises an argument error.
*/
uint64_t lj_carith_check64(lua_State *L, int narg, CTypeID *id)
{
  TValue *o = L->base + narg-1;
  if (o >= L->top) {
  err:  /* Also the target for unconvertible values further below. */
    lj_err_argt(L, narg, LUA_TNUMBER);
  } else if (LJ_LIKELY(tvisnumber(o))) {
    /* Handled below. */
  } else if (tviscdata(o)) {
    CTState *cts = ctype_cts(L);
    uint8_t *sp = (uint8_t *)cdataptr(cdataV(o));
    CTypeID sid = cdataV(o)->ctypeid;
    CType *s = ctype_get(cts, sid);
    uint64_t x;
    if (ctype_isref(s->info)) {  /* Follow a reference to its target. */
      sp = *(void **)sp;
      sid = ctype_cid(s->info);
    }
    s = ctype_raw(cts, sid);
    if (ctype_isenum(s->info)) s = ctype_child(cts, s);
    /* Plain unsigned (non-bool, non-FP) integer of 8 bytes, i.e. uint64_t. */
    if ((s->info & (CTMASK_NUM|CTF_BOOL|CTF_FP|CTF_UNSIGNED)) ==
	CTINFO(CT_NUM, CTF_UNSIGNED) && s->size == 8)
      *id = CTID_UINT64;  /* Use uint64_t, since it has the highest rank. */
    else if (!*id)
      *id = CTID_INT64;  /* Use int64_t, unless already set. */
    lj_cconv_ct_ct(cts, ctype_get(cts, *id), s,
		   (uint8_t *)&x, sp, CCF_ARG(narg));
    return x;
  } else if (!(tvisstr(o) && lj_strscan_number(strV(o), o))) {
    /* Presumably lj_strscan_number() stores the parsed number into the
    ** stack slot 'o' on success -- the code below reads it from there.
    */
    goto err;
  }
  if (LJ_LIKELY(tvisint(o))) {
    return (uint32_t)intV(o);
  } else {
    int32_t i = lj_num2bit(numV(o));
    if (LJ_DUALNUM) setintV(o, i);  /* Dual-number builds: rewrite the slot. */
    return (uint32_t)i;
  }
}
    349 
    350 
    351 /* -- 64 bit integer arithmetic helpers ----------------------------------- */
    352 
    353 #if LJ_32 && LJ_HASJIT
    354 /* Signed/unsigned 64 bit multiplication. */
    355 int64_t lj_carith_mul64(int64_t a, int64_t b)
    356 {
    357   return a * b;
    358 }
    359 #endif
    360 
    361 /* Unsigned 64 bit division. */
    362 uint64_t lj_carith_divu64(uint64_t a, uint64_t b)
    363 {
    364   if (b == 0) return U64x(80000000,00000000);
    365   return a / b;
    366 }
    367 
    368 /* Signed 64 bit division. */
    369 int64_t lj_carith_divi64(int64_t a, int64_t b)
    370 {
    371   if (b == 0 || (a == (int64_t)U64x(80000000,00000000) && b == -1))
    372     return U64x(80000000,00000000);
    373   return a / b;
    374 }
    375 
    376 /* Unsigned 64 bit modulo. */
    377 uint64_t lj_carith_modu64(uint64_t a, uint64_t b)
    378 {
    379   if (b == 0) return U64x(80000000,00000000);
    380   return a % b;
    381 }
    382 
    383 /* Signed 64 bit modulo. */
    384 int64_t lj_carith_modi64(int64_t a, int64_t b)
    385 {
    386   if (b == 0) return U64x(80000000,00000000);
    387   if (a == (int64_t)U64x(80000000,00000000) && b == -1) return 0;
    388   return a % b;
    389 }
    390 
    391 /* Unsigned 64 bit x^k. */
    392 uint64_t lj_carith_powu64(uint64_t x, uint64_t k)
    393 {
    394   uint64_t y;
    395   if (k == 0)
    396     return 1;
    397   for (; (k & 1) == 0; k >>= 1) x *= x;
    398   y = x;
    399   if ((k >>= 1) != 0) {
    400     for (;;) {
    401       x *= x;
    402       if (k == 1) break;
    403       if (k & 1) y *= x;
    404       k >>= 1;
    405     }
    406     y *= x;
    407   }
    408   return y;
    409 }
    410 
    411 /* Signed 64 bit x^k. */
    412 int64_t lj_carith_powi64(int64_t x, int64_t k)
    413 {
    414   if (k == 0)
    415     return 1;
    416   if (k < 0) {
    417     if (x == 0)
    418       return U64x(7fffffff,ffffffff);
    419     else if (x == 1)
    420       return 1;
    421     else if (x == -1)
    422       return (k & 1) ? -1 : 1;
    423     else
    424       return 0;
    425   }
    426   return (int64_t)lj_carith_powu64((uint64_t)x, (uint64_t)k);
    427 }
    428 
    429 #endif