about summary refs log tree commit diff homepage
path: root/lib/Module/Checks.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Module/Checks.cpp')
-rw-r--r--lib/Module/Checks.cpp87
1 files changed, 81 insertions, 6 deletions
diff --git a/lib/Module/Checks.cpp b/lib/Module/Checks.cpp
index 18ef398a..79fd4afc 100644
--- a/lib/Module/Checks.cpp
+++ b/lib/Module/Checks.cpp
@@ -11,6 +11,19 @@
 
 #include "klee/Config/Version.h"
 
+#if LLVM_VERSION_CODE >= LLVM_VERSION(3, 3)
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/DataLayout.h"
+#else
 #include "llvm/Constants.h"
 #include "llvm/DerivedTypes.h"
 #include "llvm/Function.h"
@@ -18,19 +31,22 @@
 #include "llvm/Instruction.h"
 #include "llvm/Instructions.h"
 #include "llvm/IntrinsicInst.h"
-#if LLVM_VERSION_CODE >= LLVM_VERSION(2, 7)
-#include "llvm/LLVMContext.h"
-#endif
 #include "llvm/Module.h"
-#include "llvm/Pass.h"
 #include "llvm/Type.h"
-#include "llvm/Transforms/Scalar.h"
-#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+#include "llvm/LLVMContext.h"
+
 #if LLVM_VERSION_CODE <= LLVM_VERSION(3, 1)
 #include "llvm/Target/TargetData.h"
 #else
 #include "llvm/DataLayout.h"
 #endif
+#endif
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Support/CallSite.h"
+#include <iostream>
 
 using namespace llvm;
 using namespace klee;
@@ -76,3 +92,62 @@ bool DivCheckPass::runOnModule(Module &M) {
   }
   return moduleChanged;
 }
+
+char OvershiftCheckPass::ID;
+
+bool OvershiftCheckPass::runOnModule(Module &M) {
+  Function *overshiftCheckFunction = 0;
+
+  bool moduleChanged = false;
+
+  for (Module::iterator f = M.begin(), fe = M.end(); f != fe; ++f) {
+    for (Function::iterator b = f->begin(), be = f->end(); b != be; ++b) {
+      for (BasicBlock::iterator i = b->begin(), ie = b->end(); i != ie; ++i) {
+          if (BinaryOperator* binOp = dyn_cast<BinaryOperator>(i)) {
+          // find all shift instructions
+          Instruction::BinaryOps opcode = binOp->getOpcode();
+
+          if (opcode == Instruction::Shl ||
+              opcode == Instruction::LShr ||
+              opcode == Instruction::AShr ) {
+            std::vector<llvm::Value*> args;
+
+            // Determine bit width of first operand
+            uint64_t bitWidth=i->getOperand(0)->getType()->getScalarSizeInBits();
+
+            ConstantInt *bitWidthC = ConstantInt::get(Type::getInt64Ty(getGlobalContext()),bitWidth,false);
+            args.push_back(bitWidthC);
+
+            CastInst *shift =
+              CastInst::CreateIntegerCast(i->getOperand(1),
+                                          Type::getInt64Ty(getGlobalContext()),
+                                          false,  /* sign doesn't matter */
+                                          "int_cast_to_i64",
+                                          i);
+            args.push_back(shift);
+
+
+            // Lazily bind the function to avoid always importing it.
+            if (!overshiftCheckFunction) {
+              Constant *fc = M.getOrInsertFunction("klee_overshift_check",
+                                                   Type::getVoidTy(getGlobalContext()),
+                                                   Type::getInt64Ty(getGlobalContext()),
+                                                   Type::getInt64Ty(getGlobalContext()),
+                                                   NULL);
+              overshiftCheckFunction = cast<Function>(fc);
+            }
+
+            // Inject CallInstr to check if overshifting possible
+#if LLVM_VERSION_CODE >= LLVM_VERSION(3, 0)
+            CallInst::Create(overshiftCheckFunction, args, "", &*i);
+#else
+            CallInst::Create(overshiftCheckFunction, args.begin(), args.end(), "", &*i);
+#endif
+            moduleChanged = true;
+          }
+        }
+      }
+    }
+  }
+  return moduleChanged;
+}