summary refs log tree commit diff
diff options
context:
space:
mode:
authorQuentin Carbonneaux <quentin.carbonneaux@yale.edu>2016-04-07 13:08:31 -0400
committerQuentin Carbonneaux <quentin.carbonneaux@yale.edu>2016-04-07 13:21:15 -0400
commitebcc123e4dc0d5497e816fe824c1849685e295af (patch)
tree3c2c2e02f5de29a39b6a0cae4af57b6e375a517f
parent45f3493777488b05c28746670f585c8e41a76681 (diff)
downloadroux-ebcc123e4dc0d5497e816fe824c1849685e295af.tar.gz
add boring folding code
-rw-r--r--fold.c220
1 files changed, 212 insertions, 8 deletions
diff --git a/fold.c b/fold.c
index 0ca68c3..34b6165 100644
--- a/fold.c
+++ b/fold.c
@@ -13,8 +13,6 @@ struct Edge {
 	Edge *work;
 };
 
-int evalop(int, int, int, int);
-
 static int *val;
 static Edge *flowrk, (*edge)[2];
 static Use **usewrk;
@@ -36,6 +34,8 @@ latval(Ref r)
 		return r.val;
 	case RType:
 		return Bot;
+	case -1:
+		return CON_Z.val;
 	default:
 		die("unreachable");
 	}
@@ -64,7 +64,7 @@ update(int t, int v, Fn *fn)
 	if (val[t] != v) {
 		tmp = &fn->tmp[t];
 		for (u=0; u<tmp->nuse; u++) {
-			vgrow(usewrk, ++nuse);
+			vgrow(&usewrk, ++nuse);
 			usewrk[nuse-1] = &tmp->use[u];
 		}
 	}
@@ -89,6 +89,8 @@ visitphi(Phi *p, int n, Fn *fn)
 	update(p->to.val, v, fn);
 }
 
+static int opfold(int, int, Con *, Con *, Fn *);
+
 static void
 visitins(Ins *i, Fn *fn)
 {
@@ -96,9 +98,17 @@ visitins(Ins *i, Fn *fn)
 
 	if (rtype(i->to) != RTmp)
 		return;
-	l = latval(i->arg[0]);
-	r = latval(i->arg[1]);
-	v = evalop(i->op, i->cls, l, r);
+	if (opdesc[i->op].cfold) {
+		l = latval(i->arg[0]);
+		r = latval(i->arg[1]);
+		if (l == Bot || r == Bot)
+			v = Bot;
+		else if (l == Top || r == Top)
+			v = Top;
+		else
+			v = opfold(i->op, i->cls, &fn->con[l], &fn->con[r], fn);
+	} else
+		v = Bot;
 	assert(v != Top);
 	update(i->to.val, v, fn);
 }
@@ -164,8 +174,17 @@ fold(Fn *fn)
 	edge = emalloc(fn->nblk * sizeof edge[0]);
 	usewrk = vnew(0, sizeof usewrk[0]);
 
-	for (n=0; n<fn->ntmp; n++)
-		val[n] = Bot;
+	for (b=fn->start; b; b=b->link) {
+		for (p=b->phi; p; p=p->link)
+			val[p->to.val] = Top;
+		for (i=b->ins; i-b->ins < b->nins; i++)
+			if (rtype(i->to) == RTmp) {
+				if (opdesc[i->op].cfold)
+					val[i->to.val] = Top;
+				else
+					val[i->to.val] = Bot;
+			}
+	}
 	for (n=0; n<fn->nblk; n++) {
 		b = fn->rpo[n];
 		b->visit = 0;
@@ -229,3 +248,188 @@ fold(Fn *fn)
 	free(val);
 	free(edge);
 }
+
+/* boring folding code */
+
+static void
+foldint(Con *res, int op, int w, Con *cl, Con *cr)
+{
+	union {
+		int64_t s;
+		uint64_t u;
+		float fs;
+		double fd;
+	} l, r;
+	uint64_t x;
+	char *lab;
+
+	lab = 0;
+	l.s = cl->bits.i;
+	r.s = cl->bits.i;
+	switch (op) {
+	case OAdd:
+		x = l.u + r.u;
+		if (cl->type == CAddr) {
+			if (cr->type == CAddr)
+				err("undefined addition (addr + addr)");
+			lab = cl->label;
+		}
+		else if (cr->type == CAddr)
+			lab = cr->label;
+		break;
+	case OSub:
+		x = l.u - r.u;
+		if (cl->type == CAddr) {
+			if (cr->type != CAddr)
+				lab = cl->label;
+			else if (strcmp(cl->label, cr->label) != 0)
+				err("undefined substraction (addr1 - addr2)");
+		}
+		else if (cr->type == CAddr)
+			err("undefined substraction (num - addr)");
+		break;
+	case ODiv:  x = l.s / r.s; break;
+	case ORem:  x = l.s % r.s; break;
+	case OUDiv: x = l.u / r.u; break;
+	case OURem: x = l.u % r.u; break;
+	case OMul:  x = l.u * r.u; break;
+	case OAnd:  x = l.u & r.u; break;
+	case OOr:   x = l.u | r.u; break;
+	case OXor:  x = l.u ^ r.u; break;
+	case OSar:  x = l.s >> (r.u & 63); break;
+	case OShr:  x = l.u >> (r.u & 63); break;
+	case OShl:  x = l.u << (r.u & 63); break;
+	case OExtsb: x = (int8_t)l.u;   break;
+	case OExtub: x = (uint8_t)l.u;  break;
+	case OExtsh: x = (int16_t)l.u;  break;
+	case OExtuh: x = (uint16_t)l.u; break;
+	case OExtsw: x = (int32_t)l.u;  break;
+	case OExtuw: x = (uint32_t)l.u; break;
+	case OFtosi:
+		if (w)
+			x = (int64_t)cl->bits.d;
+		else
+			x = (int32_t)cl->bits.s;
+		break;
+	case OCast:
+		x = l.u;
+		if (cl->type == CAddr)
+			lab = cl->label;
+		break;
+	default:
+		if (OCmpw <= op && op <= OCmpl1) {
+			if (op <= OCmpw1) {
+				l.u = (uint32_t)l.u;
+				r.u = (uint32_t)r.u;
+			} else
+				op -= OCmpl - OCmpw;
+			switch (op - OCmpw) {
+			case ICule: x = l.u <= r.u; break;
+			case ICult: x = l.u < r.u;  break;
+			case ICsle: x = l.s <= r.s; break;
+			case ICslt: x = l.s < r.s;  break;
+			case ICsgt: x = l.s > r.s;  break;
+			case ICsge: x = l.s >= r.s; break;
+			case ICugt: x = l.u > r.u;  break;
+			case ICuge: x = l.u >= r.u; break;
+			case ICeq:  x = l.u == r.u; break;
+			case ICne:  x = l.u != r.u; break;
+			default: die("unreachable");
+			}
+		}
+		else if (OCmps <= op && op <= OCmps1) {
+			switch (op - OCmps) {
+			case FCle: x = l.fs <= r.fs; break;
+			case FClt: x = l.fs < r.fs;  break;
+			case FCgt: x = l.fs > r.fs;  break;
+			case FCge: x = l.fs >= r.fs; break;
+			case FCne: x = l.fs != r.fs; break;
+			case FCeq: x = l.fs == r.fs; break;
+			case FCo: x = l.fs < r.fs || l.fs >= r.fs; break;
+			case FCuo: x = !(l.fs < r.fs || l.fs >= r.fs); break;
+			default: die("unreachable");
+			}
+		}
+		else if (OCmpd <= op && op <= OCmpd1) {
+			switch (op - OCmpd) {
+			case FCle: x = l.fd <= r.fd; break;
+			case FClt: x = l.fd < r.fd;  break;
+			case FCgt: x = l.fd > r.fd;  break;
+			case FCge: x = l.fd >= r.fd; break;
+			case FCne: x = l.fd != r.fd; break;
+			case FCeq: x = l.fd == r.fd; break;
+			case FCo: x = l.fd < r.fd || l.fd >= r.fd; break;
+			case FCuo: x = !(l.fd < r.fd || l.fd >= r.fd); break;
+			default: die("unreachable");
+			}
+		}
+		else
+			die("unreachable");
+	}
+	*res = (Con){lab ? CAddr : CBits, .bits={.i=x}};
+	if (lab)
+		strcpy(res->label, lab);
+}
+
+static void
+foldflt(Con *res, int op, int w, Con *cl, Con *cr)
+{
+	float xs, ls, rs;
+	double xd, ld, rd;
+
+	if (w)  {
+		ld = cl->bits.d;
+		rd = cr->bits.d;
+		switch (op) {
+		case OAdd: xd = ld + rd; break;
+		case OSub: xd = ld - rd; break;
+		case ODiv: xd = ld / rd; break;
+		case OMul: xd = ld * rd; break;
+		case OSitof: xd = cl->bits.i; break;
+		case OExts: xd = cl->bits.s; break;
+		case OCast: xd = cl->bits.d; break;
+		default: die("unreachable");
+		}
+		*res = (Con){CBits, .bits={.d=xd}, .flt=2};
+	} else {
+		ls = cl->bits.s;
+		rs = cr->bits.s;
+		switch (op) {
+		case OAdd: xs = ls + rs; break;
+		case OSub: xs = ls - rs; break;
+		case ODiv: xs = ls / rs; break;
+		case OMul: xs = ls * rs; break;
+		case OSitof: xs = cl->bits.i; break;
+		case OTruncd: xs = cl->bits.d; break;
+		case OCast: xs = cl->bits.s; break;
+		default: die("unreachable");
+		}
+		*res = (Con){CBits, .bits={.s=xs}, .flt=1};
+	}
+}
+
+static int
+opfold(int op, int cls, Con *cl, Con *cr, Fn *fn)
+{
+	int nc;
+	Con c;
+
+	if ((op == ODiv || op == OUDiv
+	|| op == ORem || op == OURem) && czero(cr))
+		err("null divisor in '%s'", opdesc[op].name);
+	if (cls == Kw || cls == Kl)
+		foldint(&c, op, cls == Kl, cl, cr);
+	else {
+		if (cl->type != CBits || cr->type != CBits)
+			err("invalid address operand for '%s'", opdesc[op].name);
+		foldflt(&c, op, cls == Kd, cl, cr);
+	}
+	if (c.type == CBits)
+		nc = getcon(c.bits.i, fn).val;
+	else {
+		nc = fn->ncon;
+		vgrow(&fn->con, ++fn->ncon);
+	}
+	fn->con[nc] = c;
+	return nc;
+}