about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r--lib/Module/Checks.cpp94
1 files changed, 44 insertions, 50 deletions
diff --git a/lib/Module/Checks.cpp b/lib/Module/Checks.cpp
index 2489b55c..bd9a8bd3 100644
--- a/lib/Module/Checks.cpp
+++ b/lib/Module/Checks.cpp
@@ -91,59 +91,53 @@ bool DivCheckPass::runOnModule(Module &M) {
 char OvershiftCheckPass::ID;
 
 bool OvershiftCheckPass::runOnModule(Module &M) {
-  Function *overshiftCheckFunction = 0;
-  LLVMContext &ctx = M.getContext();
+  std::vector<llvm::BinaryOperator *> shiftInstructions;
+  for (auto &F : M) {
+    for (auto &BB : F) {
+      for (auto &I : BB) {
+        auto binOp = dyn_cast<BinaryOperator>(&I);
+        if (!binOp)
+          continue;
 
-  bool moduleChanged = false;
-
-  for (Module::iterator f = M.begin(), fe = M.end(); f != fe; ++f) {
-    for (Function::iterator b = f->begin(), be = f->end(); b != be; ++b) {
-      for (BasicBlock::iterator i = b->begin(), ie = b->end(); i != ie; ++i) {
-          if (BinaryOperator* binOp = dyn_cast<BinaryOperator>(i)) {
-          // find all shift instructions
-          Instruction::BinaryOps opcode = binOp->getOpcode();
-
-          if (opcode == Instruction::Shl ||
-              opcode == Instruction::LShr ||
-              opcode == Instruction::AShr ) {
-            std::vector<llvm::Value*> args;
-
-            // Determine bit width of first operand
-            uint64_t bitWidth=i->getOperand(0)->getType()->getScalarSizeInBits();
-
-            ConstantInt *bitWidthC = ConstantInt::get(Type::getInt64Ty(ctx),
-		bitWidth, false);
-            args.push_back(bitWidthC);
-
-            CastInst *shift =
-              CastInst::CreateIntegerCast(i->getOperand(1),
-                                          Type::getInt64Ty(ctx),
-                                          false,  /* sign doesn't matter */
-                                          "int_cast_to_i64",
-                                          &*i);
-            args.push_back(shift);
-
-
-            // Lazily bind the function to avoid always importing it.
-            if (!overshiftCheckFunction) {
-              Constant *fc = M.getOrInsertFunction("klee_overshift_check",
-                                                   Type::getVoidTy(ctx),
-                                                   Type::getInt64Ty(ctx),
-                                                   Type::getInt64Ty(ctx),
-                                                   NULL);
-              overshiftCheckFunction = cast<Function>(fc);
-            }
-
-            // Inject CallInstr to check if overshifting possible
-            CallInst *ci =
-                CallInst::Create(overshiftCheckFunction, args, "", &*i);
-            // set debug information from binary operand to preserve it
-            ci->setDebugLoc(binOp->getDebugLoc());
-            moduleChanged = true;
-          }
+        // find all shift instructions
+        auto opcode = binOp->getOpcode();
+        if (opcode != Instruction::Shl && opcode != Instruction::LShr &&
+            opcode != Instruction::AShr)
+          continue;
         }
+
+        shiftInstructions.push_back(binOp);
       }
     }
   }
-  return moduleChanged;
+
+  if (shiftInstructions.empty())
+    return false;
+
+  // Retrieve the checker function
+  auto &ctx = M.getContext();
+  auto overshiftCheckFunction = cast<Function>(M.getOrInsertFunction(
+      "klee_overshift_check", Type::getVoidTy(ctx), Type::getInt64Ty(ctx),
+      Type::getInt64Ty(ctx), NULL));
+
+  for (auto &shiftInst : shiftInstructions) {
+    llvm::IRBuilder<> Builder(shiftInst);
+
+    std::vector<llvm::Value *> args;
+
+    // Determine bit width of first operand
+    uint64_t bitWidth = shiftInst->getOperand(0)->getType()->getScalarSizeInBits();
+    auto bitWidthC = ConstantInt::get(Type::getInt64Ty(ctx), bitWidth, false);
+    args.push_back(bitWidthC);
+
+    auto shiftValue =
+        Builder.CreateIntCast(shiftInst->getOperand(1), Type::getInt64Ty(ctx),
+                              false, /* sign doesn't matter */
+                              "int_cast_to_i64");
+    args.push_back(shiftValue);
+
+    Builder.CreateCall(overshiftCheckFunction, args);
+  }
+
+  return true;
 }