about summary refs log tree commit diff homepage
path: root/lib/Solver
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Solver')
-rw-r--r--lib/Solver/STPBuilder.cpp92
-rw-r--r--lib/Solver/STPBuilder.h2
2 files changed, 57 insertions, 37 deletions
diff --git a/lib/Solver/STPBuilder.cpp b/lib/Solver/STPBuilder.cpp
index 307ae0fb..8ec35762 100644
--- a/lib/Solver/STPBuilder.cpp
+++ b/lib/Solver/STPBuilder.cpp
@@ -152,6 +152,7 @@ ExprHandle STPBuilder::bvExtract(ExprHandle expr, unsigned top, unsigned bottom)
   return vc_bvExtract(vc, expr, top, bottom);
 }
 ExprHandle STPBuilder::eqExpr(ExprHandle a, ExprHandle b) {
+  assert((vc_getBVLength(vc, a) == vc_getBVLength(vc, b)) && "a and b should be same type");
   return vc_eqExpr(vc, a, b);
 }
 
@@ -187,76 +188,93 @@ ExprHandle STPBuilder::bvLeftShift(ExprHandle expr, unsigned shift) {
   }
 }
 
+ExprHandle STPBuilder::extractPartialShiftValue(ExprHandle shift,
+                                                unsigned width,
+                                                unsigned &shiftBits) {
+  // Assuming width is power of 2
+  llvm::APInt sw(32, width);
+  shiftBits = sw.getActiveBits();
+
+  // get the shift amount (looking only at the bits appropriate for the given
+  // width)
+  return vc_bvExtract(vc, shift, shiftBits - 1, 0);
+}
+
 // left shift by a variable amount on an expression of the specified width
 ExprHandle STPBuilder::bvVarLeftShift(ExprHandle expr, ExprHandle shift) {
   unsigned width = vc_getBVLength(vc, expr);
   ExprHandle res = bvZero(width);
 
-  //construct a big if-then-elif-elif-... with one case per possible shift amount
-  for( int i=width-1; i>=0; i-- ) {
-    res = vc_iteExpr(vc,
-                     eqExpr(shift, bvConst32(width, i)),
-                     bvLeftShift(expr, i),
-                     res);
+  unsigned shiftBits = 0;
+  ExprHandle shift_ext = extractPartialShiftValue(shift, width, shiftBits);
+
+  // construct a big if-then-elif-elif-... with one case per possible shift
+  // amount
+  for (int i = width - 1; i >= 0; i--) {
+    res = vc_iteExpr(vc, eqExpr(shift_ext, bvConst32(shiftBits, i)),
+                     bvLeftShift(expr, i), res);
   }
 
   // If overshifting, shift to zero
-  ExprHandle ex = vc_bvLtExpr(vc, shift, bvConst32(vc_getBVLength(vc,shift), width));
+  ExprHandle ex =
+      vc_bvLtExpr(vc, shift, bvConst32(vc_getBVLength(vc, shift), width));
 
   res = vc_iteExpr(vc, ex, res, bvZero(width));
   return res;
 }
 
-// logical right shift by a variable amount on an expression of the specified width
+// logical right shift by a variable amount on an expression of the specified
+// width
 ExprHandle STPBuilder::bvVarRightShift(ExprHandle expr, ExprHandle shift) {
   unsigned width = vc_getBVLength(vc, expr);
   ExprHandle res = bvZero(width);
 
-  //construct a big if-then-elif-elif-... with one case per possible shift amount
-  for( int i=width-1; i>=0; i-- ) {
-    res = vc_iteExpr(vc,
-                     eqExpr(shift, bvConst32(width, i)),
-                     bvRightShift(expr, i),
-                     res);
+  unsigned shiftBits = 0;
+  ExprHandle shift_ext = extractPartialShiftValue(shift, width, shiftBits);
+
+  // construct a big if-then-elif-elif-... with one case per possible shift
+  // amount
+  for (int i = width - 1; i >= 0; i--) {
+    res = vc_iteExpr(vc, eqExpr(shift_ext, bvConst32(shiftBits, i)),
+                     bvRightShift(expr, i), res);
   }
 
   // If overshifting, shift to zero
-  ExprHandle ex = vc_bvLtExpr(vc, shift, bvConst32(vc_getBVLength(vc,shift), width));
-  res = vc_iteExpr(vc,
-                   ex,
-                   res,
-                   bvZero(width));
+  // If overshifting, shift to zero
+  ExprHandle ex =
+      vc_bvLtExpr(vc, shift, bvConst32(vc_getBVLength(vc, shift), width));
+  res = vc_iteExpr(vc, ex, res, bvZero(width));
+
   return res;
 }
 
-// arithmetic right shift by a variable amount on an expression of the specified width
+// arithmetic right shift by a variable amount on an expression of the specified
+// width
 ExprHandle STPBuilder::bvVarArithRightShift(ExprHandle expr, ExprHandle shift) {
   unsigned width = vc_getBVLength(vc, expr);
 
-  //get the sign bit to fill with
-  ExprHandle signedBool = bvBoolExtract(expr, width-1);
+  unsigned shiftBits = 0;
+  ExprHandle shift_ext = extractPartialShiftValue(shift, width, shiftBits);
 
-  //start with the result if shifting by width-1
-  ExprHandle res = constructAShrByConstant(expr, width-1, signedBool);
+  // get the sign bit to fill with
+  ExprHandle signedBool = bvBoolExtract(expr, width - 1);
 
-  //construct a big if-then-elif-elif-... with one case per possible shift amount
+  // start with the result if shifting by width-1
+  ExprHandle res = constructAShrByConstant(expr, width - 1, signedBool);
+
+  // construct a big if-then-elif-elif-... with one case per possible shift
+  // amount
   // XXX more efficient to move the ite on the sign outside all exprs?
   // XXX more efficient to sign extend, right shift, then extract lower bits?
-  for( int i=width-2; i>=0; i-- ) {
-    res = vc_iteExpr(vc,
-                     eqExpr(shift, bvConst32(width,i)),
-                     constructAShrByConstant(expr, 
-                                             i, 
-                                             signedBool),
-                     res);
+  for (int i = width - 2; i >= 0; i--) {
+    res = vc_iteExpr(vc, eqExpr(shift_ext, bvConst32(shiftBits, i)),
+                     constructAShrByConstant(expr, i, signedBool), res);
   }
 
   // If overshifting, shift to zero
-  ExprHandle ex = vc_bvLtExpr(vc, shift, bvConst32(vc_getBVLength(vc,shift), width));
-  res = vc_iteExpr(vc,
-                   ex,
-                   res,
-                   bvZero(width));
+  ExprHandle ex =
+      vc_bvLtExpr(vc, shift, bvConst32(vc_getBVLength(vc, shift), width));
+  res = vc_iteExpr(vc, ex, res, bvZero(width));
   return res;
 }
 
diff --git a/lib/Solver/STPBuilder.h b/lib/Solver/STPBuilder.h
index 3b17ccf1..5be34029 100644
--- a/lib/Solver/STPBuilder.h
+++ b/lib/Solver/STPBuilder.h
@@ -93,6 +93,8 @@ private:
   ExprHandle bvVarLeftShift(ExprHandle expr, ExprHandle shift);
   ExprHandle bvVarRightShift(ExprHandle expr, ExprHandle shift);
   ExprHandle bvVarArithRightShift(ExprHandle expr, ExprHandle shift);
+  ExprHandle extractPartialShiftValue(ExprHandle shift, unsigned width,
+                                      unsigned &shiftBits);
 
   ExprHandle constructAShrByConstant(ExprHandle expr, unsigned shift, 
                                      ExprHandle isSigned);