#include "all.h"

/* Lattice values are plain ints.  A non-negative value is used as an
 * index into fn->con (see visitins and visitjmp); the two negative
 * values below are the extremes of the lattice.
 */
enum {
	Bot = -2, /* lattice bottom */
	Top = -1, /* lattice top */
};

typedef struct Edge Edge;

/* An out-edge of a block; doubles as a node of the flow worklist. */
struct Edge {
	int dest;   /* id of the destination block, -1 if there is none */
	int dead;   /* set by initedge(), cleared when fold() follows the edge */
	Edge *work; /* next edge in the flow worklist */
};

static int *val;                 /* lattice value of every temporary */
static Edge *flowrk, (*edge)[2]; /* flow worklist head; per-block out-edges (s1, s2) */
static Use **usewrk;             /* stack of uses that must be revisited */
static uint nuse;                /* number of entries in usewrk */

/* Return non-zero when c is a CBits constant whose integer bits are 0. */
static int
czero(Con *c)
{
	return c->type == CBits && c->bits.i == 0;
}

/* Return the lattice value of a reference: the recorded value for a
 * temporary, the constant index for a constant, Bot for a type
 * reference, and CON_Z.val when rtype() reports -1.
 */
static int
latval(Ref r)
{
	switch (rtype(r)) {
	case RTmp:
		return val[r.val];
	case RCon:
		return r.val;
	case RType:
		return Bot;
	case -1:
		return CON_Z.val;
	default:
		die("unreachable");
	}
}

/* Combine two lattice values: Bot absorbs everything, Top is neutral,
 * equal values are kept, and two different values give Bot.
 */
static int
latmerge(int l, int r)
{
	if (l == Bot || r == Bot)
		return Bot;
	if (l == Top)
		return r;
	if (r == Top)
		return l;
	if (l == r)
		return l;
	return Bot;
}

/* Record v as the lattice value of temporary t.  When the value
 * changes, every use of t is pushed on usewrk so that fold()
 * re-evaluates it.
 */
static void
update(int t, int v, Fn *fn)
{
	Tmp *tmp;
	uint u;

	if (val[t] == v)
		return;
	tmp = &fn->tmp[t];
	for (u=0; u<tmp->nuse; u++) {
		vgrow(&usewrk, ++nuse);
		usewrk[nuse-1] = &tmp->use[u];
	}
	/* Store the new value: without this the lattice never moves and
	 * the same uses would be queued again on every visit. */
	val[t] = v;
}

/* Return non-zero unless block s has an out-edge to block d that has
 * already been followed.
 */
static int
deadedge(int s, int d)
{
	Edge *e;

	e = edge[s];
	if (e[0].dest == d && !e[0].dead)
		return 0;
	if (e[1].dest == d && !e[1].dead)
		return 0;
	return 1;
}

/* Re-evaluate phi p, which lives in block n: merge the arguments that
 * arrive through live edges and record the result.
 */
static void
visitphi(Phi *p, int n, Fn *fn)
{
	int v;
	uint a;

	v = Top;
	for (a=0; a<p->narg; a++)
		/* liveness is a property of the predecessor's edge into n */
		if (!deadedge(p->blk[a]->id, n))
			v = latmerge(v, latval(p->arg[a]));
	assert(v != Top);
	update(p->to.val, v, fn);
}

static int opfold(int, int, Con *, Con *, Fn *);

/* Re-evaluate instruction i.  Instructions that do not define a
 * temporary are ignored; operations without the cfold flag always
 * yield Bot; otherwise the result is Bot/Top if an argument is, or the
 * folded constant when both arguments are constants.
 */
static void
visitins(Ins *i, Fn *fn)
{
	int v, l, r;

	if (rtype(i->to) != RTmp)
		return;
	if (opdesc[i->op].cfold) {
		l = latval(i->arg[0]);
		r = latval(i->arg[1]);
		if (l == Bot || r == Bot)
			v = Bot;
		else if (l == Top || r == Top)
			v = Top;
		else
			v = opfold(i->op, i->cls, &fn->con[l], &fn->con[r], fn);
	} else
		v = Bot;
	assert(v != Top);
	update(i->to.val, v, fn);
}

/* Push on the flow worklist the out-edges of block b (number n) that
 * its jump may take given the current lattice value of the jump
 * argument.
 */
static void
visitjmp(Blk *b, int n, Fn *fn)
{
	int l;

	switch (b->jmp.type) {
	case JJnz:
		l = latval(b->jmp.arg);
		/* Top (-1) must never be used to index fn->con below */
		assert(l != Top);
		if (l == Bot) {
			/* unknown condition: both successors are reachable */
			edge[n][1].work = flowrk;
			edge[n][0].work = &edge[n][1];
			flowrk = &edge[n][0];
		}
		else if (czero(&fn->con[l])) {
			/* constant zero: only the second successor */
			assert(edge[n][0].dead);
			edge[n][1].work = flowrk;
			flowrk = &edge[n][1];
		}
		else {
			/* any other constant: only the first successor */
			assert(edge[n][1].dead);
			edge[n][0].work = flowrk;
			flowrk = &edge[n][0];
		}
		break;
	case JJmp:
		edge[n][0].work = flowrk;
		flowrk = &edge[n][0];
		break;
	default:
		if (isret(b->jmp.type))
			break;
		die("unreachable");
	}
}

/* Initialise e as a dead, unqueued edge to block s (s may be null). */
static void
initedge(Edge *e, Blk *s)
{
	if (s)
		e->dest = s->id;
	else
		e->dest = -1;
	e->dead = 1;
	e->work = 0;
}

/* require rpo, use, pred */
/* Compute, for every temporary of fn, a lattice value (Top, a constant
 * index, or Bot) and, for every block, which out-edges can be taken.
 * The results live in val and edge, which are released before
 * returning; the rewriting step is not implemented in this function.
 */
void
fold(Fn *fn)
{
	Edge *e, start;
	Use *u;
	Blk *b;
	Phi *p;
	Ins *i;
	int n;

	val = emalloc(fn->ntmp * sizeof val[0]);
	edge = emalloc(fn->nblk * sizeof edge[0]);
	usewrk = vnew(0, sizeof usewrk[0]);

	/* Temporaries that no phi or instruction below defines get Bot
	 * instead of whatever the allocation happened to contain. */
	for (n=0; n<fn->ntmp; n++)
		val[n] = Bot;
	for (b=fn->start; b; b=b->link) {
		for (p=b->phi; p; p=p->link)
			val[p->to.val] = Top;
		for (i=b->ins; i-b->ins < b->nins; i++)
			if (rtype(i->to) == RTmp) {
				if (opdesc[i->op].cfold)
					val[i->to.val] = Top;
				else
					val[i->to.val] = Bot;
			}
	}
	for (n=0; n<fn->nblk; n++) {
		b = fn->rpo[n];
		b->visit = 0;
		initedge(&edge[n][0], b->s1);
		initedge(&edge[n][1], b->s2);
	}
	/* Seed the flow worklist with a synthetic edge into the start
	 * block so that the start block itself gets visited. */
	assert(fn->start->id == 0);
	initedge(&start, fn->start);
	flowrk = &start;
	nuse = 0;

	/* 1. find out constants and dead cfg edges */
	for (;;) {
		e = flowrk;
		if (e) {
			flowrk = e->work;
			e->work = 0;
			if (e->dest == -1 || !e->dead)
				continue;
			e->dead = 0;
			n = e->dest;
			b = fn->rpo[n];
			/* a new live in-edge can change every phi of the block */
			for (p=b->phi; p; p=p->link)
				visitphi(p, n, fn);
			if (b->visit == 0) {
				for (i=b->ins; i-b->ins < b->nins; i++)
					visitins(i, fn);
				visitjmp(b, n, fn);
			}
			b->visit++;
			/* the edge of an unconditional jump is either live
			 * already or next in line on the worklist */
			assert(b->jmp.type != JJmp
				|| !edge[n][0].dead
				|| flowrk == &edge[n][0]);
		}
		else if (nuse) {
			u = usewrk[--nuse];
			n = u->bid;
			b = fn->rpo[n];
			/* uses in blocks not visited yet are evaluated when
			 * the block is first reached */
			if (b->visit == 0)
				continue;
			switch (u->type) {
			case UPhi:
				visitphi(u->u.phi, n, fn);
				break;
			case UIns:
				visitins(u->u.ins, fn);
				break;
			case UJmp:
				visitjmp(b, n, fn);
				break;
			default:
				die("unreachable");
			}
		}
		else
			break;
	}

	/* 2. trim dead code, replace constants */

	free(val);
	free(edge);
}

/* boring folding code */

/* Fold an operation with an integer result.  cl and cr are the left
 * and right constant operands; w is only consulted by OFtosi, where it
 * selects between the double and the single view of cl.  The result is
 * stored in *res; it is a CAddr constant when the folded value still
 * carries a label, a CBits constant otherwise.
 */
static void
foldint(Con *res, int op, int w, Con *cl, Con *cr)
{
	union {
		int64_t s;
		uint64_t u;
		float fs;
		double fd;
	} l, r;
	uint64_t x;
	char *lab;

	lab = 0;
	l.s = cl->bits.i;
	r.s = cr->bits.i; /* right operand comes from cr, not cl */
	switch (op) {
	case OAdd:
		x = l.u + r.u;
		if (cl->type == CAddr) {
			if (cr->type == CAddr)
				err("undefined addition (addr + addr)");
			lab = cl->label;
		}
		else if (cr->type == CAddr)
			lab = cr->label;
		break;
	case OSub:
		x = l.u - r.u;
		if (cl->type == CAddr) {
			if (cr->type != CAddr)
				lab = cl->label;
			else if (strcmp(cl->label, cr->label) != 0)
				err("undefined substraction (addr1 - addr2)");
		}
		else if (cr->type == CAddr)
			err("undefined substraction (num - addr)");
		break;
	case ODiv:   x = l.s / r.s; break;
	case ORem:   x = l.s % r.s; break;
	case OUDiv:  x = l.u / r.u; break;
	case OURem:  x = l.u % r.u; break;
	case OMul:   x = l.u * r.u; break;
	case OAnd:   x = l.u & r.u; break;
	case OOr:    x = l.u | r.u; break;
	case OXor:   x = l.u ^ r.u; break;
	case OSar:   x = l.s >> (r.u & 63); break;
	case OShr:   x = l.u >> (r.u & 63); break;
	case OShl:   x = l.u << (r.u & 63); break;
	case OExtsb: x = (int8_t)l.u;   break;
	case OExtub: x = (uint8_t)l.u;  break;
	case OExtsh: x = (int16_t)l.u;  break;
	case OExtuh: x = (uint16_t)l.u; break;
	case OExtsw: x = (int32_t)l.u;  break;
	case OExtuw: x = (uint32_t)l.u; break;
	case OFtosi:
		if (w)
			x = (int64_t)cl->bits.d;
		else
			x = (int32_t)cl->bits.s;
		break;
	case OCast:
		x = l.u;
		if (cl->type == CAddr)
			lab = cl->label;
		break;
	default:
		if (OCmpw <= op && op <= OCmpl1) {
			if (op <= OCmpw1) {
				/* word comparison: keep the low 32 bits only */
				l.u = (uint32_t)l.u;
				r.u = (uint32_t)r.u;
			} else
				op -= OCmpl - OCmpw;
			switch (op - OCmpw) {
			case ICule: x = l.u <= r.u; break;
			case ICult: x = l.u < r.u;  break;
			case ICsle: x = l.s <= r.s; break;
			case ICslt: x = l.s < r.s;  break;
			case ICsgt: x = l.s > r.s;  break;
			case ICsge: x = l.s >= r.s; break;
			case ICugt: x = l.u > r.u;  break;
			case ICuge: x = l.u >= r.u; break;
			case ICeq:  x = l.u == r.u; break;
			case ICne:  x = l.u != r.u; break;
			default:
				die("unreachable");
			}
		}
		else if (OCmps <= op && op <= OCmps1) {
			switch (op - OCmps) {
			case FCle: x = l.fs <= r.fs; break;
			case FClt: x = l.fs < r.fs;  break;
			case FCgt: x = l.fs > r.fs;  break;
			case FCge: x = l.fs >= r.fs; break;
			case FCne: x = l.fs != r.fs; break;
			case FCeq: x = l.fs == r.fs; break;
			case FCo:  x = l.fs < r.fs || l.fs >= r.fs; break;
			case FCuo: x = !(l.fs < r.fs || l.fs >= r.fs); break;
			default:
				die("unreachable");
			}
		}
		else if (OCmpd <= op && op <= OCmpd1) {
			switch (op - OCmpd) {
			case FCle: x = l.fd <= r.fd; break;
			case FClt: x = l.fd < r.fd;  break;
			case FCgt: x = l.fd > r.fd;  break;
			case FCge: x = l.fd >= r.fd; break;
			case FCne: x = l.fd != r.fd; break;
			case FCeq: x = l.fd == r.fd; break;
			case FCo:  x = l.fd < r.fd || l.fd >= r.fd; break;
			case FCuo: x = !(l.fd < r.fd || l.fd >= r.fd); break;
			default:
				die("unreachable");
			}
		}
		else
			die("unreachable");
	}
	*res = (Con){lab ? CAddr : CBits, .bits={.i=x}};
	if (lab)
		strcpy(res->label, lab);
}

/* Fold an operation with a floating-point result.  A non-zero w
 * computes in double precision and tags the result with .flt=2;
 * otherwise the computation is in single precision with .flt=1.
 */
static void
foldflt(Con *res, int op, int w, Con *cl, Con *cr)
{
	float xs, ls, rs;
	double xd, ld, rd;

	if (w) {
		ld = cl->bits.d;
		rd = cr->bits.d;
		switch (op) {
		case OAdd:   xd = ld + rd; break;
		case OSub:   xd = ld - rd; break;
		case ODiv:   xd = ld / rd; break;
		case OMul:   xd = ld * rd; break;
		case OSitof: xd = cl->bits.i; break;
		case OExts:  xd = cl->bits.s; break;
		case OCast:  xd = cl->bits.d; break;
		default:
			die("unreachable");
		}
		*res = (Con){CBits, .bits={.d=xd}, .flt=2};
	} else {
		ls = cl->bits.s;
		rs = cr->bits.s;
		switch (op) {
		case OAdd:    xs = ls + rs; break;
		case OSub:    xs = ls - rs; break;
		case ODiv:    xs = ls / rs; break;
		case OMul:    xs = ls * rs; break;
		case OSitof:  xs = cl->bits.i; break;
		case OTruncd: xs = cl->bits.d; break;
		case OCast:   xs = cl->bits.s; break;
		default:
			die("unreachable");
		}
		*res = (Con){CBits, .bits={.s=xs}, .flt=1};
	}
}

/* Fold operation op of class cls applied to constants cl and cr and
 * return the index in fn->con of the resulting constant.  Reports an
 * error for a division or remainder whose right operand is a zero
 * constant, and for a floating-point fold with a non-CBits operand.
 */
static int
opfold(int op, int cls, Con *cl, Con *cr, Fn *fn)
{
	int nc;
	Con c;

	if ((op == ODiv || op == OUDiv
	|| op == ORem || op == OURem) && czero(cr))
		err("null divisor in '%s'", opdesc[op].name);
	if (cls == Kw || cls == Kl)
		foldint(&c, op, cls == Kl, cl, cr);
	else {
		if (cl->type != CBits || cr->type != CBits)
			err("invalid address operand for '%s'", opdesc[op].name);
		foldflt(&c, op, cls == Kd, cl, cr);
	}
	if (c.type == CBits)
		nc = getcon(c.bits.i, fn).val;
	else {
		/* address constants get a fresh slot at the end of fn->con */
		nc = fn->ncon;
		vgrow(&fn->con, ++fn->ncon);
	}
	fn->con[nc] = c;
	return nc;
}