diff options
author | Quentin Carbonneaux <quentin.carbonneaux@yale.edu> | 2016-09-04 20:22:38 -0400 |
---|---|---|
committer | Quentin Carbonneaux <quentin.carbonneaux@yale.edu> | 2016-12-12 22:17:03 -0500 |
commit | 3f147ed2e078769a71b2935fc36cb08b2b0ddb67 (patch) | |
tree | 3bf4bed9df2904c41d786426ab4cfc8cc7767a87 /load.c | |
parent | 8fdea1dd5236f2693b677fc6bd6e2bb417c0fccd (diff) | |
download | roux-3f147ed2e078769a71b2935fc36cb08b2b0ddb67.tar.gz |
implement a simple load elimination pass
Diffstat (limited to 'load.c')
-rw-r--r-- | load.c | 408 |
1 files changed, 408 insertions, 0 deletions
diff --git a/load.c b/load.c new file mode 100644 index 0000000..a6bd68e --- /dev/null +++ b/load.c @@ -0,0 +1,408 @@ +#include "all.h" + +#define MASK(w) (BIT(8*(w)-1)*2-1) /* must work when w==8 */ + +typedef struct Loc Loc; +typedef struct Slice Slice; +typedef struct Insert Insert; + + +struct Loc { + enum { + LRoot, /* right above the original load */ + LLoad, /* inserting a load is allowed */ + LNoLoad, /* only scalar operations allowed */ + } type; + uint off; + Blk *blk; +}; + +struct Slice { + Ref ref; + short sz; + short cls; /* load class */ +}; + +struct Insert { + uint isphi:1; + uint num:31; + int bid; + uint off; + union { + Ins ins; + struct { + Slice m; + Phi *p; + } phi; + } new; +}; + +static Fn *curf; +static uint inum; /* current insertion number */ +static Insert *ilog; /* global insertion log */ +static uint nlog; /* number of entries in the log */ + +static int +loadsz(Ins *l) +{ + switch (l->op) { + case Oloadsb: case Oloadub: return 1; + case Oloadsh: case Oloaduh: return 2; + case Oloadsw: case Oloaduw: return 4; + case Oload: return KWIDE(l->cls) ? 8 : 4; + } + die("unreachable"); +} + +static int +storesz(Ins *s) +{ + switch (s->op) { + case Ostoreb: return 1; + case Ostoreh: return 2; + case Ostorew: case Ostores: return 4; + case Ostorel: case Ostored: return 8; + } + die("unreachable"); +} + +static Ref +iins(int cls, int op, Ref a0, Ref a1, Loc *l) +{ + Insert *ist; + + vgrow(&ilog, ++nlog); + ist = &ilog[nlog-1]; + ist->isphi = 0; + ist->num = inum++; + ist->bid = l->blk->id; + ist->off = l->off; + ist->new.ins = (Ins){op, R, {a0, a1}, cls}; + return ist->new.ins.to = newtmp("ld", cls, curf); +} + +static void +cast(Ref *r, int cls, Loc *l) +{ + int cls0; + + if (rtype(*r) == RCon) + return; + assert(rtype(*r) == RTmp); + cls0 = curf->tmp[r->val].cls; + if (cls0 == cls || (cls == Kw && cls0 == Kl)) + return; + assert(!KWIDE(cls0) || KWIDE(cls)); + if (KWIDE(cls) == KWIDE(cls0)) + *r = iins(cls, Ocast, *r, R, l); + else { + assert(cls == Kl); + if (cls0 == Ks) + *r = iins(Kw, Ocast, *r, R, l); + *r = iins(Kl, Oextuw, *r, R, l); + } +} + +static inline void +mask(int cls, Ref *r, bits msk, Loc *l) +{ + cast(r, cls, l); + *r = iins(cls, Oand, *r, getcon(msk, curf), l); +} + +static Ref +load(Slice sl, bits msk, Loc *l) +{ + Ref r; + int ld, cls, all; + + ld = (int[]){ + [1] = Oloadub, + [2] = Oloaduh, + [4] = Oloaduw, + [8] = Oload + }[sl.sz]; + all = msk == MASK(sl.sz); + if (all) + cls = sl.cls; + else + cls = sl.sz > 4 ? Kl : Kw; + r = iins(cls, ld, sl.ref, R, l); + if (!all) + mask(cls, &r, msk, l); + return r; +} + +/* returns a ref containing the contents of the slice + * passed as argument, all the bits set to 0 in the + * mask argument are zeroed in the result; + * the returned ref has an integer class when the + * mask does not cover all the bits of the slice, + * otherwise, it has class sl.cls + * the procedure returns R when it fails */ +static Ref +def(Slice sl, bits msk, Blk *b, Ins *i, Loc *il) +{ + Blk *bp; + bits msk1, msks; + int off, cls, cls1, op, sz, ld; + uint np, oldl, oldt; + Ref r, r1; + Phi *p; + Insert *ist; + Loc l; + + /* invariants: + * -1- b dominates il->blk; so we can use + * temporaries of b in il->blk + * -2- if il->type != LNoLoad, then il->blk + * postdominates the original load; so it + * is safe to load in il->blk + * -3- if il->type != LNoLoad, then b + * postdominates il->blk (and by 2, the + * original load) + */ + assert(dom(b, il->blk)); + oldl = nlog; + oldt = curf->ntmp; + if (0) { + Load: + curf->ntmp = oldt; + nlog = oldl; + if (il->type != LLoad) + return R; + return load(sl, msk, il); + } + + if (!i) + i = &b->ins[b->nins]; + cls = sl.sz > 4 ? Kl : Kw; + msks = MASK(sl.sz); + + while (i > b->ins) { + --i; + if (req(i->to, sl.ref) + || (i->op == Ocall && escapes(sl.ref, curf))) + goto Load; + ld = isload(i->op); + if (ld) { + sz = loadsz(i); + r1 = i->arg[0]; + r = i->to; + } else if (isstore(i->op)) { + sz = storesz(i); + r1 = i->arg[1]; + r = i->arg[0]; + } else + continue; + switch (alias(sl.ref, sl.sz, r1, sz, &off, curf)) { + case MustAlias: + if (off < 0) { + off = -off; + msk1 = (MASK(sz) << 8*off) & msks; + op = Oshl; + } else { + msk1 = (MASK(sz) >> 8*off) & msks; + op = Oshr; + } + if ((msk1 & msk) == 0) + break; + if (off) { + cls1 = cls; + if (op == Oshr && off + sl.sz > 4) + cls1 = Kl; + cast(&r, cls1, il); + r1 = getcon(8*off, curf); + r = iins(cls1, op, r, r1, il); + } + if ((msk1 & msk) != msk1 || off + sz < sl.sz) + mask(cls, &r, msk1 & msk, il); + if ((msk & ~msk1) != 0) { + r1 = def(sl, msk & ~msk1, b, i, il); + if (req(r1, R)) + goto Load; + r = iins(cls, Oor, r, r1, il); + } + if (msk == msks) + cast(&r, sl.cls, il); + return r; + case MayAlias: + if (ld) + break; + else + goto Load; + case NoAlias: + break; + default: + die("unreachable"); + } + } + + for (ist=ilog; ist<&ilog[nlog]; ++ist) + if (ist->isphi && ist->bid == b->id) + if (req(ist->new.phi.m.ref, sl.ref)) + if (ist->new.phi.m.sz == sl.sz) { + r = ist->new.phi.p->to; + if (msk != msks) + mask(cls, &r, msk, il); + else + cast(&r, sl.cls, il); + return r; + } + + for (p=b->phi; p; p=p->link) + if (req(p->to, sl.ref)) + /* scanning predecessors in that + * case would be unsafe */ + goto Load; + + if (b->npred == 0) + goto Load; + if (b->npred == 1) { + bp = b->pred[0]; + assert(bp->loop == il->blk->loop); + l = *il; + if (bp->s2) + l.type = LNoLoad; + r1 = def(sl, msk, bp, 0, &l); + if (req(r1, R)) + goto Load; + return r1; + } + + r = newtmp("ld", sl.cls, curf); + p = alloc(sizeof *p); + vgrow(&ilog, ++nlog); + ist = &ilog[nlog-1]; + ist->isphi = 1; + ist->bid = b->id; + ist->new.phi.m = sl; + ist->new.phi.p = p; + p->to = r; + p->cls = sl.cls; + p->narg = b->npred; + for (np=0; np<b->npred; ++np) { + bp = b->pred[np]; + if (!bp->s2 + && il->type != LNoLoad + && bp->loop < il->blk->loop) + l.type = LLoad; + else + l.type = LNoLoad; + l.blk = bp; + l.off = bp->nins; + r1 = def(sl, msks, bp, 0, &l); + if (req(r1, R)) + goto Load; + p->arg[np] = r1; + p->blk[np] = bp; + } + if (msk != msks) + mask(cls, &r, msk, il); + return r; +} + +static int +icmp(const void *pa, const void *pb) +{ + Insert *a, *b; + int c; + + a = (Insert *)pa; + b = (Insert *)pb; + if ((c = a->bid - b->bid)) + return c; + if (a->isphi && b->isphi) + return 0; + if (a->isphi) + return -1; + if (b->isphi) + return +1; + if ((c = a->off - b->off)) + return c; + return a->num - b->num; +} + +/* require rpo ssa alias */ +void +loadopt(Fn *fn) +{ + Ins *i, *ib; + Blk *b; + int n, sz; + uint ni, ext, nt; + Insert *ist; + Slice sl; + Loc l; + + curf = fn; + ilog = vnew(0, sizeof ilog[0], emalloc); + nlog = 0; + inum = 0; + for (b=fn->start; b; b=b->link) + for (i=b->ins; i<&b->ins[b->nins]; ++i) { + if (!isload(i->op)) + continue; + sz = loadsz(i); + sl = (Slice){i->arg[0], sz, i->cls}; + l = (Loc){LRoot, i-b->ins, b}; + i->arg[1] = def(sl, MASK(sz), b, i, &l); + } + qsort(ilog, nlog, sizeof ilog[0], icmp); + vgrow(&ilog, nlog+1); + ilog[nlog].bid = fn->nblk; /* add a sentinel */ + ib = vnew(0, sizeof(Ins), emalloc); + for (ist=ilog, n=0; n<fn->nblk; ++n) { + b = fn->rpo[n]; + for (; ist->bid == n && ist->isphi; ++ist) { + ist->new.phi.p->link = b->phi; + b->phi = ist->new.phi.p; + } + ni = 0; + nt = 0; + for (;;) { + if (ist->bid == n && ist->off == ni) + i = &ist++->new.ins; + else { + if (ni == b->nins) + break; + i = &b->ins[ni++]; + if (isload(i->op) + && !req(i->arg[1], R)) { + ext = Oextsb + i->op - Oloadsb; + switch (i->op) { + default: + die("unreachable"); + case Oloadsb: + case Oloadub: + case Oloadsh: + case Oloaduh: + i->op = ext; + break; + case Oloadsw: + case Oloaduw: + if (i->cls == Kl) { + i->op = ext; + break; + } + case Oload: + i->op = Ocopy; + break; + } + i->arg[0] = i->arg[1]; + i->arg[1] = R; + } + } + vgrow(&ib, ++nt); + ib[nt-1] = *i; + } + b->nins = nt; + idup(&b->ins, ib, nt); + } + vfree(ib); + vfree(ilog); + if (debug['M']) { + fprintf(stderr, "\n> After load elimination:\n"); + printfn(fn, stderr); + } +} |