about summary refs log tree commit diff homepage
path: root/lib/Expr/Expr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Expr/Expr.cpp')
-rw-r--r--lib/Expr/Expr.cpp167
1 files changed, 82 insertions, 85 deletions
diff --git a/lib/Expr/Expr.cpp b/lib/Expr/Expr.cpp
index 6cd15be7..0a174179 100644
--- a/lib/Expr/Expr.cpp
+++ b/lib/Expr/Expr.cpp
@@ -551,13 +551,12 @@ ref<Expr> SExtExpr::create(const ref<Expr> &e, Width w) {
 static ref<Expr> AndExpr_create(Expr *l, Expr *r);
 static ref<Expr> XorExpr_create(Expr *l, Expr *r);
 
-static ref<Expr> EqExpr_createPartial(Expr *l, const ref<Expr> &cr);
-static ref<Expr> AndExpr_createPartialR(const ref<Expr> &cl, Expr *r);
-static ref<Expr> SubExpr_createPartialR(const ref<Expr> &cl, Expr *r);
-static ref<Expr> XorExpr_createPartialR(const ref<Expr> &cl, Expr *r);
+static ref<Expr> EqExpr_createPartial(Expr *l, const ref<ConstantExpr> &cr);
+static ref<Expr> AndExpr_createPartialR(const ref<ConstantExpr> &cl, Expr *r);
+static ref<Expr> SubExpr_createPartialR(const ref<ConstantExpr> &cl, Expr *r);
+static ref<Expr> XorExpr_createPartialR(const ref<ConstantExpr> &cl, Expr *r);
 
-static ref<Expr> AddExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
-  assert(cl->isConstant() && "non-constant passed in place of constant");
+static ref<Expr> AddExpr_createPartialR(const ref<ConstantExpr> &cl, Expr *r) {
   uint64_t value = cl->getConstantValue();
   Expr::Width type = cl->getWidth();
 
@@ -567,10 +566,10 @@ static ref<Expr> AddExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
     return r;
   } else {
     Expr::Kind rk = r->getKind();
-    if (rk==Expr::Add && r->getKid(0)->isConstant()) { // A + (B+c) == (A+B) + c
+    if (rk==Expr::Add && isa<ConstantExpr>(r->getKid(0))) { // A + (B+c) == (A+B) + c
       return AddExpr::create(AddExpr::create(cl, r->getKid(0)),
                              r->getKid(1));
-    } else if (rk==Expr::Sub && r->getKid(0)->isConstant()) { // A + (B-c) == (A+B) - c
+    } else if (rk==Expr::Sub && isa<ConstantExpr>(r->getKid(0))) { // A + (B-c) == (A+B) - c
       return SubExpr::create(AddExpr::create(cl, r->getKid(0)),
                              r->getKid(1));
     } else {
@@ -578,7 +577,7 @@ static ref<Expr> AddExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
     }
   }
 }
-static ref<Expr> AddExpr_createPartial(Expr *l, const ref<Expr> &cr) {
+static ref<Expr> AddExpr_createPartial(Expr *l, const ref<ConstantExpr> &cr) {
   return AddExpr_createPartialR(cr, l);
 }
 static ref<Expr> AddExpr_create(Expr *l, Expr *r) {
@@ -588,16 +587,16 @@ static ref<Expr> AddExpr_create(Expr *l, Expr *r) {
     return XorExpr_create(l, r);
   } else {
     Expr::Kind lk = l->getKind(), rk = r->getKind();
-    if (lk==Expr::Add && l->getKid(0)->isConstant()) { // (k+a)+b = k+(a+b)
+    if (lk==Expr::Add && isa<ConstantExpr>(l->getKid(0))) { // (k+a)+b = k+(a+b)
       return AddExpr::create(l->getKid(0),
                              AddExpr::create(l->getKid(1), r));
-    } else if (lk==Expr::Sub && l->getKid(0)->isConstant()) { // (k-a)+b = k+(b-a)
+    } else if (lk==Expr::Sub && isa<ConstantExpr>(l->getKid(0))) { // (k-a)+b = k+(b-a)
       return AddExpr::create(l->getKid(0),
                              SubExpr::create(r, l->getKid(1)));
-    } else if (rk==Expr::Add && r->getKid(0)->isConstant()) { // a + (k+b) = k+(a+b)
+    } else if (rk==Expr::Add && isa<ConstantExpr>(r->getKid(0))) { // a + (k+b) = k+(a+b)
       return AddExpr::create(r->getKid(0),
                              AddExpr::create(l, r->getKid(1)));
-    } else if (rk==Expr::Sub && r->getKid(0)->isConstant()) { // a + (k-b) = k+(a-b)
+    } else if (rk==Expr::Sub && isa<ConstantExpr>(r->getKid(0))) { // a + (k-b) = k+(a-b)
       return AddExpr::create(r->getKid(0),
                              SubExpr::create(l, r->getKid(1)));
     } else {
@@ -606,8 +605,7 @@ static ref<Expr> AddExpr_create(Expr *l, Expr *r) {
   }  
 }
 
-static ref<Expr> SubExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
-  assert(cl->isConstant() && "non-constant passed in place of constant");
+static ref<Expr> SubExpr_createPartialR(const ref<ConstantExpr> &cl, Expr *r) {
   Expr::Width type = cl->getWidth();
 
   if (type==Expr::Bool) {
@@ -625,7 +623,7 @@ static ref<Expr> SubExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
     }
   }
 }
-static ref<Expr> SubExpr_createPartial(Expr *l, const ref<Expr> &cr) {
+static ref<Expr> SubExpr_createPartial(Expr *l, const ref<ConstantExpr> &cr) {
   assert(cr->isConstant() && "non-constant passed in place of constant");
   uint64_t value = cr->getConstantValue();
   Expr::Width width = cr->getWidth();
@@ -642,16 +640,16 @@ static ref<Expr> SubExpr_create(Expr *l, Expr *r) {
     return ConstantExpr::alloc(0, type);
   } else {
     Expr::Kind lk = l->getKind(), rk = r->getKind();
-    if (lk==Expr::Add && l->getKid(0)->isConstant()) { // (k+a)-b = k+(a-b)
+    if (lk==Expr::Add && isa<ConstantExpr>(l->getKid(0))) { // (k+a)-b = k+(a-b)
       return AddExpr::create(l->getKid(0),
                              SubExpr::create(l->getKid(1), r));
-    } else if (lk==Expr::Sub && l->getKid(0)->isConstant()) { // (k-a)-b = k-(a+b)
+    } else if (lk==Expr::Sub && isa<ConstantExpr>(l->getKid(0))) { // (k-a)-b = k-(a+b)
       return SubExpr::create(l->getKid(0),
                              AddExpr::create(l->getKid(1), r));
-    } else if (rk==Expr::Add && r->getKid(0)->isConstant()) { // a - (k+b) = (a-c) - k
+    } else if (rk==Expr::Add && isa<ConstantExpr>(r->getKid(0))) { // a - (k+b) = (a-c) - k
       return SubExpr::create(SubExpr::create(l, r->getKid(1)),
                              r->getKid(0));
-    } else if (rk==Expr::Sub && r->getKid(0)->isConstant()) { // a - (k-b) = (a+b) - k
+    } else if (rk==Expr::Sub && isa<ConstantExpr>(r->getKid(0))) { // a - (k-b) = (a+b) - k
       return SubExpr::create(AddExpr::create(l, r->getKid(1)),
                              r->getKid(0));
     } else {
@@ -660,7 +658,7 @@ static ref<Expr> SubExpr_create(Expr *l, Expr *r) {
   }  
 }
 
-static ref<Expr> MulExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
+static ref<Expr> MulExpr_createPartialR(const ref<ConstantExpr> &cl, Expr *r) {
   assert(cl->isConstant() && "non-constant passed in place of constant");
   uint64_t value = cl->getConstantValue();
   Expr::Width type = cl->getWidth();
@@ -675,7 +673,7 @@ static ref<Expr> MulExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
     return MulExpr::alloc(cl, r);
   }
 }
-static ref<Expr> MulExpr_createPartial(Expr *l, const ref<Expr> &cr) {
+static ref<Expr> MulExpr_createPartial(Expr *l, const ref<ConstantExpr> &cr) {
   return MulExpr_createPartialR(cr, l);
 }
 static ref<Expr> MulExpr_create(Expr *l, Expr *r) {
@@ -688,8 +686,7 @@ static ref<Expr> MulExpr_create(Expr *l, Expr *r) {
   }
 }
 
-static ref<Expr> AndExpr_createPartial(Expr *l, const ref<Expr> &cr) {
-  assert(cr->isConstant() && "non-constant passed in place of constant");
+static ref<Expr> AndExpr_createPartial(Expr *l, const ref<ConstantExpr> &cr) {
   uint64_t value = cr->getConstantValue();
   Expr::Width width = cr->getWidth();
 
@@ -701,15 +698,14 @@ static ref<Expr> AndExpr_createPartial(Expr *l, const ref<Expr> &cr) {
     return AndExpr::alloc(l, cr);
   }
 }
-static ref<Expr> AndExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
+static ref<Expr> AndExpr_createPartialR(const ref<ConstantExpr> &cl, Expr *r) {
   return AndExpr_createPartial(r, cl);
 }
 static ref<Expr> AndExpr_create(Expr *l, Expr *r) {
   return AndExpr::alloc(l, r);
 }
 
-static ref<Expr> OrExpr_createPartial(Expr *l, const ref<Expr> &cr) {
-  assert(cr->isConstant() && "non-constant passed in place of constant");
+static ref<Expr> OrExpr_createPartial(Expr *l, const ref<ConstantExpr> &cr) {
   uint64_t value = cr->getConstantValue();
   Expr::Width width = cr->getWidth();
 
@@ -721,15 +717,14 @@ static ref<Expr> OrExpr_createPartial(Expr *l, const ref<Expr> &cr) {
     return OrExpr::alloc(l, cr);
   }
 }
-static ref<Expr> OrExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
+static ref<Expr> OrExpr_createPartialR(const ref<ConstantExpr> &cl, Expr *r) {
   return OrExpr_createPartial(r, cl);
 }
 static ref<Expr> OrExpr_create(Expr *l, Expr *r) {
   return OrExpr::alloc(l, r);
 }
 
-static ref<Expr> XorExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
-  assert(cl->isConstant() && "non-constant passed in place of constant");
+static ref<Expr> XorExpr_createPartialR(const ref<ConstantExpr> &cl, Expr *r) {
   uint64_t value = cl->getConstantValue();
   Expr::Width type = cl->getWidth();
 
@@ -746,7 +741,7 @@ static ref<Expr> XorExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
   }
 }
 
-static ref<Expr> XorExpr_createPartial(Expr *l, const ref<Expr> &cr) {
+static ref<Expr> XorExpr_createPartial(Expr *l, const ref<ConstantExpr> &cr) {
   return XorExpr_createPartialR(cr, l);
 }
 static ref<Expr> XorExpr_create(Expr *l, Expr *r) {
@@ -811,34 +806,34 @@ static ref<Expr> AShrExpr_create(const ref<Expr> &l, const ref<Expr> &r) {
 
 #define BCREATE_R(_e_op, _op, partialL, partialR) \
 ref<Expr>  _e_op ::create(const ref<Expr> &l, const ref<Expr> &r) { \
-  assert(l->getWidth()==r->getWidth() && "type mismatch"); \
-  if (l->isConstant()) {                                \
-    if (r->isConstant()) {                              \
-      Expr::Width width = l->getWidth(); \
-      uint64_t val = ints::_op(l->getConstantValue(),  \
-                               r->getConstantValue(), width);  \
-      return ConstantExpr::create(val, width); \
-    } else { \
-      return _e_op ## _createPartialR(l, r.get()); \
-    } \
-  } else if (r->isConstant()) {             \
-    return _e_op ## _createPartial(l.get(), r); \
-  } \
-  return _e_op ## _create(l.get(), r.get()); \
+  assert(l->getWidth()==r->getWidth() && "type mismatch");              \
+  if (ConstantExpr *cl = dyn_cast<ConstantExpr>(l)) {                   \
+    if (ConstantExpr *cr = dyn_cast<ConstantExpr>(r)) {                 \
+      Expr::Width width = l->getWidth();                                \
+      uint64_t val = ints::_op(cl->getConstantValue(),                  \
+                               cr->getConstantValue(), width);          \
+      return ConstantExpr::create(val, width);                          \
+    } else {                                                            \
+      return _e_op ## _createPartialR(cl, r.get());                     \
+    }                                                                   \
+  } else if (ConstantExpr *cr = dyn_cast<ConstantExpr>(r)) {            \
+    return _e_op ## _createPartial(l.get(), cr);                        \
+  }                                                                     \
+  return _e_op ## _create(l.get(), r.get());                            \
 }
 
 #define BCREATE(_e_op, _op) \
 ref<Expr>  _e_op ::create(const ref<Expr> &l, const ref<Expr> &r) { \
-  assert(l->getWidth()==r->getWidth() && "type mismatch"); \
-  if (l->isConstant()) {                                \
-    if (r->isConstant()) {                              \
-      Expr::Width width = l->getWidth(); \
-      uint64_t val = ints::_op(l->getConstantValue(), \
-                               r->getConstantValue(), width);  \
-      return ConstantExpr::create(val, width); \
-    } \
-  } \
-  return _e_op ## _create(l, r);                    \
+  assert(l->getWidth()==r->getWidth() && "type mismatch");          \
+  if (ConstantExpr *cl = dyn_cast<ConstantExpr>(l)) {               \
+    if (ConstantExpr *cr = dyn_cast<ConstantExpr>(r)) {             \
+      Expr::Width width = l->getWidth();                            \
+      uint64_t val = ints::_op(cl->getConstantValue(),              \
+                               cr->getConstantValue(), width);      \
+      return ConstantExpr::create(val, width);                      \
+    }                                                               \
+  }                                                                 \
+  return _e_op ## _create(l, r);                                    \
 }
 
 BCREATE_R(AddExpr, add, AddExpr_createPartial, AddExpr_createPartialR)
@@ -857,35 +852,35 @@ BCREATE(AShrExpr, ashr)
 
 #define CMPCREATE(_e_op, _op) \
 ref<Expr>  _e_op ::create(const ref<Expr> &l, const ref<Expr> &r) { \
-  assert(l->getWidth()==r->getWidth() && "type mismatch"); \
-  if (l->isConstant()) {                                \
-    if (r->isConstant()) {                              \
-      Expr::Width width = l->getWidth(); \
-      uint64_t val = ints::_op(l->getConstantValue(), \
-                               r->getConstantValue(), width);  \
-      return ConstantExpr::create(val, Expr::Bool); \
-    } \
-  } \
-  return _e_op ## _create(l, r);                    \
+  assert(l->getWidth()==r->getWidth() && "type mismatch");              \
+  if (ConstantExpr *cl = dyn_cast<ConstantExpr>(l)) {                   \
+    if (ConstantExpr *cr = dyn_cast<ConstantExpr>(r)) {                 \
+      Expr::Width width = cl->getWidth();                               \
+      uint64_t val = ints::_op(cl->getConstantValue(),                  \
+                               cr->getConstantValue(), width);          \
+      return ConstantExpr::create(val, Expr::Bool);                     \
+    }                                                                   \
+  }                                                                     \
+  return _e_op ## _create(l, r);                                        \
 }
 
 #define CMPCREATE_T(_e_op, _op, _reflexive_e_op, partialL, partialR) \
-ref<Expr>  _e_op ::create(const ref<Expr> &l, const ref<Expr> &r) { \
-  assert(l->getWidth()==r->getWidth() && "type mismatch"); \
-  if (l->isConstant()) {                                \
-    if (r->isConstant()) {                              \
-      Expr::Width width = l->getWidth(); \
-      uint64_t val = ints::_op(l->getConstantValue(), \
-                               r->getConstantValue(), width);  \
-      return ConstantExpr::create(val, Expr::Bool); \
-    } else { \
-      return partialR(l, r.get()); \
-    } \
-  } else if (r->isConstant()) {                  \
-    return partialL(l.get(), r); \
-  } else { \
-    return _e_op ## _create(l.get(), r.get()); \
-  } \
+ref<Expr>  _e_op ::create(const ref<Expr> &l, const ref<Expr> &r) {    \
+  assert(l->getWidth()==r->getWidth() && "type mismatch");             \
+  if (ConstantExpr *cl = dyn_cast<ConstantExpr>(l)) {                  \
+    if (ConstantExpr *cr = dyn_cast<ConstantExpr>(r)) {                \
+      Expr::Width width = cl->getWidth();                              \
+      uint64_t val = ints::_op(cl->getConstantValue(),                 \
+                               cr->getConstantValue(), width);         \
+      return ConstantExpr::create(val, Expr::Bool);                    \
+    } else {                                                           \
+      return partialR(cl, r.get());                                    \
+    }                                                                  \
+  } else if (ConstantExpr *cr = dyn_cast<ConstantExpr>(r)) {           \
+    return partialL(l.get(), cr);                                      \
+  } else {                                                             \
+    return _e_op ## _create(l.get(), r.get());                         \
+  }                                                                    \
 }
   
 
@@ -972,7 +967,7 @@ static ref<Expr> TryConstArrayOpt(const ref<Expr> &cl,
 }
 
 
-static ref<Expr> EqExpr_createPartialR(const ref<Expr> &cl, Expr *r) {  
+static ref<Expr> EqExpr_createPartialR(const ref<ConstantExpr> &cl, Expr *r) {  
   assert(cl->isConstant() && "non-constant passed in place of constant");
   uint64_t value = cl->getConstantValue();
   Expr::Width width = cl->getWidth();
@@ -1031,14 +1026,16 @@ static ref<Expr> EqExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
     const AddExpr *ae = cast<AddExpr>(r);
     if (ae->left->isConstant()) {
       // c0 = c1 + b => c0 - c1 = b
-      return EqExpr_createPartialR(SubExpr::create(cl, ae->left),
+      return EqExpr_createPartialR(cast<ConstantExpr>(SubExpr::create(cl, 
+                                                                      ae->left)),
                                    ae->right.get());
     }
   } else if (rk==Expr::Sub) {
     const SubExpr *se = cast<SubExpr>(r);
     if (se->left->isConstant()) {
       // c0 = c1 - b => c1 - c0 = b
-      return EqExpr_createPartialR(SubExpr::create(se->left, cl),
+      return EqExpr_createPartialR(cast<ConstantExpr>(SubExpr::create(se->left, 
+                                                                      cl)),
                                    se->right.get());
     }
   } else if (rk == Expr::Read && ConstArrayOpt) {
@@ -1048,7 +1045,7 @@ static ref<Expr> EqExpr_createPartialR(const ref<Expr> &cl, Expr *r) {
   return EqExpr_create(cl, r);
 }
 
-static ref<Expr> EqExpr_createPartial(Expr *l, const ref<Expr> &cr) {  
+static ref<Expr> EqExpr_createPartial(Expr *l, const ref<ConstantExpr> &cr) {  
   return EqExpr_createPartialR(cr, l);
 }