summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--fold.c97
1 files changed, 57 insertions, 40 deletions
diff --git a/fold.c b/fold.c
index 34b6165..b7bc201 100644
--- a/fold.c
+++ b/fold.c
@@ -19,9 +19,14 @@ static Use **usewrk;
 static uint nuse;
 
 static int
-czero(Con *c)
+czero(Con *c, int w)
 {
-	return c->type == CBits && c->bits.i == 0;
+	if (c->type != CBits)
+		return 0;
+	if (w)
+		return !c->bits.i;
+	else
+		return !(c->bits.i & 0xffffffff);
 }
 
 static int
@@ -67,21 +72,25 @@ update(int t, int v, Fn *fn)
 			vgrow(&usewrk, ++nuse);
 			usewrk[nuse-1] = &tmp->use[u];
 		}
+		val[t] = v;
 	}
 }
 
 static void
 visitphi(Phi *p, int n, Fn *fn)
 {
-	int v, dead;
+	int v, m, dead;
 	uint a;
 
 	v = Top;
 	for (a=0; a<p->narg; a++) {
-		if (edge[n][0].dest == p->blk[a]->id)
-			dead = edge[n][0].dead;
+		m = p->blk[a]->id;
+		if (edge[m][0].dest == n)
+			dead = edge[m][0].dead;
+		else if (edge[m][1].dest == n)
+			dead = edge[m][1].dead;
 		else
-			dead = edge[n][1].dead;
+			die("invalid phi argument");
 		if (!dead)
 			v = latmerge(v, latval(p->arg[a]));
 	}
@@ -109,7 +118,7 @@ visitins(Ins *i, Fn *fn)
 			v = opfold(i->op, i->cls, &fn->con[l], &fn->con[r], fn);
 	} else
 		v = Bot;
-	assert(v != Top);
+	/* fprintf(stderr, "\nvisiting %s (%p)", opdesc[i->op].name, (void *)i); */
 	update(i->to.val, v, fn);
 }
 
@@ -126,7 +135,7 @@ visitjmp(Blk *b, int n, Fn *fn)
 			edge[n][0].work = &edge[n][1];
 			flowrk = &edge[n][0];
 		}
-		else if (czero(&fn->con[l])) {
+		else if (czero(&fn->con[l], 0)) {
 			assert(edge[n][0].dead);
 			edge[n][1].work = flowrk;
 			flowrk = &edge[n][1];
@@ -163,7 +172,7 @@ initedge(Edge *e, Blk *s)
 void
 fold(Fn *fn)
 {
-	Edge *e;
+	Edge *e, start;
 	Use *u;
 	Blk *b;
 	Phi *p;
@@ -174,26 +183,16 @@ fold(Fn *fn)
 	edge = emalloc(fn->nblk * sizeof edge[0]);
 	usewrk = vnew(0, sizeof usewrk[0]);
 
-	for (b=fn->start; b; b=b->link) {
-		for (p=b->phi; p; p=p->link)
-			val[p->to.val] = Top;
-		for (i=b->ins; i-b->ins < b->nins; i++)
-			if (rtype(i->to) == RTmp) {
-				if (opdesc[i->op].cfold)
-					val[i->to.val] = Top;
-				else
-					val[i->to.val] = Bot;
-			}
-	}
+	for (n=0; n<fn->ntmp; n++)
+		val[n] = Top;
 	for (n=0; n<fn->nblk; n++) {
 		b = fn->rpo[n];
 		b->visit = 0;
 		initedge(&edge[n][0], b->s1);
 		initedge(&edge[n][1], b->s2);
 	}
-	assert(fn->start->id == 0);
-	edge[0][0].work = &edge[0][1];
-	flowrk = &edge[0][0];
+	initedge(&start, fn->start);
+	flowrk = &start;
 	nuse = 0;
 
 	/* 1. find out constants and dead cfg edges */
@@ -225,15 +224,16 @@ fold(Fn *fn)
 				visitphi(u->u.phi, u->bid, fn);
 				continue;
 			}
+			n = u->bid;
+			b = fn->rpo[n];
 			if (b->visit == 0)
 				continue;
-			n = u->bid;
 			switch (u->type) {
 			case UIns:
 				visitins(u->u.ins, fn);
 				break;
 			case UJmp:
-				visitjmp(fn->rpo[n], n, fn);
+				visitjmp(b, n, fn);
 				break;
 			default:
 				die("unreachable");
@@ -243,6 +243,20 @@ fold(Fn *fn)
 			break;
 	}
 
+	if (debug['F']) {
+		fprintf(stderr, "\n> SCCP findings:\n");
+		for (n=Tmp0; n<fn->ntmp; n++) {
+			fprintf(stderr, "%10s: ", fn->tmp[n].name);
+			if (val[n] == Top)
+				fprintf(stderr, "Top");
+			else if (val[n] == Bot)
+				fprintf(stderr, "Bot");
+			else
+				printref(CON(val[n]), fn, stderr);
+			fprintf(stderr, "\n");
+		}
+	}
+
 	/* 2. trim dead code, replace constants */
 
 	free(val);
@@ -265,10 +279,8 @@ foldint(Con *res, int op, int w, Con *cl, Con *cr)
 
 	lab = 0;
 	l.s = cl->bits.i;
-	r.s = cl->bits.i;
-	switch (op) {
-	case OAdd:
-		x = l.u + r.u;
+	r.s = cr->bits.i;
+	if (op == OAdd) {
 		if (cl->type == CAddr) {
 			if (cr->type == CAddr)
 				err("undefined addition (addr + addr)");
@@ -276,9 +288,8 @@ foldint(Con *res, int op, int w, Con *cl, Con *cr)
 		}
 		else if (cr->type == CAddr)
 			lab = cr->label;
-		break;
-	case OSub:
-		x = l.u - r.u;
+	}
+	else if (op == OSub) {
 		if (cl->type == CAddr) {
 			if (cr->type != CAddr)
 				lab = cl->label;
@@ -287,7 +298,12 @@ foldint(Con *res, int op, int w, Con *cl, Con *cr)
 		}
 		else if (cr->type == CAddr)
 			err("undefined substraction (num - addr)");
-		break;
+	}
+	else if (cl->type == CAddr || cr->type == CAddr)
+		err("invalid address operand for '%s'", opdesc[op].name);
+	switch (op) {
+	case OAdd:  x = l.u + r.u; break;
+	case OSub:  x = l.u - r.u; break;
 	case ODiv:  x = l.s / r.s; break;
 	case ORem:  x = l.s % r.s; break;
 	case OUDiv: x = l.u / r.u; break;
@@ -367,6 +383,7 @@ foldint(Con *res, int op, int w, Con *cl, Con *cr)
 			die("unreachable");
 	}
 	*res = (Con){lab ? CAddr : CBits, .bits={.i=x}};
+	res->bits.i = x;
 	if (lab)
 		strcpy(res->label, lab);
 }
@@ -377,6 +394,8 @@ foldflt(Con *res, int op, int w, Con *cl, Con *cr)
 	float xs, ls, rs;
 	double xd, ld, rd;
 
+	if (cl->type != CBits || cr->type != CBits)
+		err("invalid address operand for '%s'", opdesc[op].name);
 	if (w)  {
 		ld = cl->bits.d;
 		rd = cr->bits.d;
@@ -387,7 +406,7 @@ foldflt(Con *res, int op, int w, Con *cl, Con *cr)
 		case OMul: xd = ld * rd; break;
 		case OSitof: xd = cl->bits.i; break;
 		case OExts: xd = cl->bits.s; break;
-		case OCast: xd = cl->bits.d; break;
+		case OCast: xd = ld; break;
 		default: die("unreachable");
 		}
 		*res = (Con){CBits, .bits={.d=xd}, .flt=2};
@@ -401,7 +420,7 @@ foldflt(Con *res, int op, int w, Con *cl, Con *cr)
 		case OMul: xs = ls * rs; break;
 		case OSitof: xs = cl->bits.i; break;
 		case OTruncd: xs = cl->bits.d; break;
-		case OCast: xs = cl->bits.s; break;
+		case OCast: xs = ls; break;
 		default: die("unreachable");
 		}
 		*res = (Con){CBits, .bits={.s=xs}, .flt=1};
@@ -415,21 +434,19 @@ opfold(int op, int cls, Con *cl, Con *cr, Fn *fn)
 	Con c;
 
 	if ((op == ODiv || op == OUDiv
-	|| op == ORem || op == OURem) && czero(cr))
+	|| op == ORem || op == OURem) && czero(cr, KWIDE(cls)))
 		err("null divisor in '%s'", opdesc[op].name);
 	if (cls == Kw || cls == Kl)
 		foldint(&c, op, cls == Kl, cl, cr);
-	else {
-		if (cl->type != CBits || cr->type != CBits)
-			err("invalid address operand for '%s'", opdesc[op].name);
+	else
 		foldflt(&c, op, cls == Kd, cl, cr);
-	}
 	if (c.type == CBits)
 		nc = getcon(c.bits.i, fn).val;
 	else {
 		nc = fn->ncon;
 		vgrow(&fn->con, ++fn->ncon);
 	}
+	assert(!(cls == Ks || cls == Kd) || c.flt);
 	fn->con[nc] = c;
 	return nc;
 }