From ebcc123e4dc0d5497e816fe824c1849685e295af Mon Sep 17 00:00:00 2001 From: Quentin Carbonneaux Date: Thu, 7 Apr 2016 13:08:31 -0400 Subject: add boring folding code --- fold.c | 220 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 212 insertions(+), 8 deletions(-) diff --git a/fold.c b/fold.c index 0ca68c3..34b6165 100644 --- a/fold.c +++ b/fold.c @@ -13,8 +13,6 @@ struct Edge { Edge *work; }; -int evalop(int, int, int, int); - static int *val; static Edge *flowrk, (*edge)[2]; static Use **usewrk; @@ -36,6 +34,8 @@ latval(Ref r) return r.val; case RType: return Bot; + case -1: + return CON_Z.val; default: die("unreachable"); } @@ -64,7 +64,7 @@ update(int t, int v, Fn *fn) if (val[t] != v) { tmp = &fn->tmp[t]; for (u=0; unuse; u++) { - vgrow(usewrk, ++nuse); + vgrow(&usewrk, ++nuse); usewrk[nuse-1] = &tmp->use[u]; } } @@ -89,6 +89,8 @@ visitphi(Phi *p, int n, Fn *fn) update(p->to.val, v, fn); } +static int opfold(int, int, Con *, Con *, Fn *); + static void visitins(Ins *i, Fn *fn) { @@ -96,9 +98,17 @@ visitins(Ins *i, Fn *fn) if (rtype(i->to) != RTmp) return; - l = latval(i->arg[0]); - r = latval(i->arg[1]); - v = evalop(i->op, i->cls, l, r); + if (opdesc[i->op].cfold) { + l = latval(i->arg[0]); + r = latval(i->arg[1]); + if (l == Bot || r == Bot) + v = Bot; + else if (l == Top || r == Top) + v = Top; + else + v = opfold(i->op, i->cls, &fn->con[l], &fn->con[r], fn); + } else + v = Bot; assert(v != Top); update(i->to.val, v, fn); } @@ -164,8 +174,17 @@ fold(Fn *fn) edge = emalloc(fn->nblk * sizeof edge[0]); usewrk = vnew(0, sizeof usewrk[0]); - for (n=0; nntmp; n++) - val[n] = Bot; + for (b=fn->start; b; b=b->link) { + for (p=b->phi; p; p=p->link) + val[p->to.val] = Top; + for (i=b->ins; i-b->ins < b->nins; i++) + if (rtype(i->to) == RTmp) { + if (opdesc[i->op].cfold) + val[i->to.val] = Top; + else + val[i->to.val] = Bot; + } + } for (n=0; nnblk; n++) { b = fn->rpo[n]; b->visit = 0; @@ -229,3 +248,188 @@ fold(Fn *fn) free(val); free(edge); } + +/* boring folding code */ + +static void +foldint(Con *res, int op, int w, Con *cl, Con *cr) +{ + union { + int64_t s; + uint64_t u; + float fs; + double fd; + } l, r; + uint64_t x; + char *lab; + + lab = 0; + l.s = cl->bits.i; + r.s = cl->bits.i; + switch (op) { + case OAdd: + x = l.u + r.u; + if (cl->type == CAddr) { + if (cr->type == CAddr) + err("undefined addition (addr + addr)"); + lab = cl->label; + } + else if (cr->type == CAddr) + lab = cr->label; + break; + case OSub: + x = l.u - r.u; + if (cl->type == CAddr) { + if (cr->type != CAddr) + lab = cl->label; + else if (strcmp(cl->label, cr->label) != 0) + err("undefined substraction (addr1 - addr2)"); + } + else if (cr->type == CAddr) + err("undefined substraction (num - addr)"); + break; + case ODiv: x = l.s / r.s; break; + case ORem: x = l.s % r.s; break; + case OUDiv: x = l.u / r.u; break; + case OURem: x = l.u % r.u; break; + case OMul: x = l.u * r.u; break; + case OAnd: x = l.u & r.u; break; + case OOr: x = l.u | r.u; break; + case OXor: x = l.u ^ r.u; break; + case OSar: x = l.s >> (r.u & 63); break; + case OShr: x = l.u >> (r.u & 63); break; + case OShl: x = l.u << (r.u & 63); break; + case OExtsb: x = (int8_t)l.u; break; + case OExtub: x = (uint8_t)l.u; break; + case OExtsh: x = (int16_t)l.u; break; + case OExtuh: x = (uint16_t)l.u; break; + case OExtsw: x = (int32_t)l.u; break; + case OExtuw: x = (uint32_t)l.u; break; + case OFtosi: + if (w) + x = (int64_t)cl->bits.d; + else + x = (int32_t)cl->bits.s; + break; + case OCast: + x = l.u; + if (cl->type == CAddr) + lab = cl->label; + break; + default: + if (OCmpw <= op && op <= OCmpl1) { + if (op <= OCmpw1) { + l.u = (uint32_t)l.u; + r.u = (uint32_t)r.u; + } else + op -= OCmpl - OCmpw; + switch (op - OCmpw) { + case ICule: x = l.u <= r.u; break; + case ICult: x = l.u < r.u; break; + case ICsle: x = l.s <= r.s; break; + case ICslt: x = l.s < r.s; break; + case ICsgt: x = l.s > r.s; break; + case ICsge: x = l.s >= r.s; break; + case ICugt: x = l.u > r.u; break; + case ICuge: x = l.u >= r.u; break; + case ICeq: x = l.u == r.u; break; + case ICne: x = l.u != r.u; break; + default: die("unreachable"); + } + } + else if (OCmps <= op && op <= OCmps1) { + switch (op - OCmps) { + case FCle: x = l.fs <= r.fs; break; + case FClt: x = l.fs < r.fs; break; + case FCgt: x = l.fs > r.fs; break; + case FCge: x = l.fs >= r.fs; break; + case FCne: x = l.fs != r.fs; break; + case FCeq: x = l.fs == r.fs; break; + case FCo: x = l.fs < r.fs || l.fs >= r.fs; break; + case FCuo: x = !(l.fs < r.fs || l.fs >= r.fs); break; + default: die("unreachable"); + } + } + else if (OCmpd <= op && op <= OCmpd1) { + switch (op - OCmpd) { + case FCle: x = l.fd <= r.fd; break; + case FClt: x = l.fd < r.fd; break; + case FCgt: x = l.fd > r.fd; break; + case FCge: x = l.fd >= r.fd; break; + case FCne: x = l.fd != r.fd; break; + case FCeq: x = l.fd == r.fd; break; + case FCo: x = l.fd < r.fd || l.fd >= r.fd; break; + case FCuo: x = !(l.fd < r.fd || l.fd >= r.fd); break; + default: die("unreachable"); + } + } + else + die("unreachable"); + } + *res = (Con){lab ? CAddr : CBits, .bits={.i=x}}; + if (lab) + strcpy(res->label, lab); +} + +static void +foldflt(Con *res, int op, int w, Con *cl, Con *cr) +{ + float xs, ls, rs; + double xd, ld, rd; + + if (w) { + ld = cl->bits.d; + rd = cr->bits.d; + switch (op) { + case OAdd: xd = ld + rd; break; + case OSub: xd = ld - rd; break; + case ODiv: xd = ld / rd; break; + case OMul: xd = ld * rd; break; + case OSitof: xd = cl->bits.i; break; + case OExts: xd = cl->bits.s; break; + case OCast: xd = cl->bits.d; break; + default: die("unreachable"); + } + *res = (Con){CBits, .bits={.d=xd}, .flt=2}; + } else { + ls = cl->bits.s; + rs = cr->bits.s; + switch (op) { + case OAdd: xs = ls + rs; break; + case OSub: xs = ls - rs; break; + case ODiv: xs = ls / rs; break; + case OMul: xs = ls * rs; break; + case OSitof: xs = cl->bits.i; break; + case OTruncd: xs = cl->bits.d; break; + case OCast: xs = cl->bits.s; break; + default: die("unreachable"); + } + *res = (Con){CBits, .bits={.s=xs}, .flt=1}; + } +} + +static int +opfold(int op, int cls, Con *cl, Con *cr, Fn *fn) +{ + int nc; + Con c; + + if ((op == ODiv || op == OUDiv + || op == ORem || op == OURem) && czero(cr)) + err("null divisor in '%s'", opdesc[op].name); + if (cls == Kw || cls == Kl) + foldint(&c, op, cls == Kl, cl, cr); + else { + if (cl->type != CBits || cr->type != CBits) + err("invalid address operand for '%s'", opdesc[op].name); + foldflt(&c, op, cls == Kd, cl, cr); + } + if (c.type == CBits) + nc = getcon(c.bits.i, fn).val; + else { + nc = fn->ncon; + vgrow(&fn->con, ++fn->ncon); + } + fn->con[nc] = c; + return nc; +} -- cgit 1.4.1