about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorPeter Collingbourne <peter@pcc.me.uk>2011-08-01 02:26:27 +0000
committerPeter Collingbourne <peter@pcc.me.uk>2011-08-01 02:26:27 +0000
commit3c2bacc7a9fb95caef07d87d9b6220712b7d613f (patch)
tree7aee0c72179d88a3a7adf6f6984d962d9f9fba48
parent179a8930253e7e81dda77fda1db11a6d11b22f14 (diff)
downloadklee-3c2bacc7a9fb95caef07d87d9b6220712b7d613f.tar.gz
Make the Executor capable of handling bitcasts of aliases, by rewriting the
direct function call logic.

git-svn-id: https://llvm.org/svn/llvm-project/klee/trunk@136605 91177308-0d34-0410-b5e6-96231b3b80d8
-rw-r--r--lib/Core/Executor.cpp79
-rw-r--r--lib/Core/Executor.h3
-rw-r--r--test/Feature/BitcastAlias.ll33
3 files changed, 83 insertions, 32 deletions
diff --git a/lib/Core/Executor.cpp b/lib/Core/Executor.cpp
index 1a37498f..87e70a4c 100644
--- a/lib/Core/Executor.cpp
+++ b/lib/Core/Executor.cpp
@@ -54,6 +54,7 @@
 #include "llvm/LLVMContext.h"
 #endif
 #include "llvm/Module.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/CallSite.h"
 #include "llvm/Support/CommandLine.h"
@@ -1287,25 +1288,46 @@ void Executor::printFileLine(ExecutionState &state, KInstruction *ki) {
     std::cerr << "     [no debug info]:";
 }
 
+/// Compute the true target of a function call, resolving LLVM and KLEE aliases
+/// and bitcasts.
+Function* Executor::getTargetFunction(Value *calledVal, ExecutionState &state) {
+  SmallPtrSet<const GlobalValue*, 3> Visited;
 
-Function* Executor::getCalledFunction(CallSite &cs, ExecutionState &state) {
-  Function *f = cs.getCalledFunction();
-  
-  if (f) {
-    std::string alias = state.getFnAlias(f->getName());
-    if (alias != "") {
-      llvm::Module* currModule = kmodule->module;
-      Function* old_f = f;
-      f = currModule->getFunction(alias);
-      if (!f) {
-	llvm::errs() << "Function " << alias << "(), alias for " 
-                     << old_f->getName() << " not found!\n";
-	assert(f && "function alias not found");
+  Constant *c = dyn_cast<Constant>(calledVal);
+  if (!c)
+    return 0;
+
+  while (true) {
+    if (GlobalValue *gv = dyn_cast<GlobalValue>(c)) {
+      if (!Visited.insert(gv))
+        return 0;
+
+      std::string alias = state.getFnAlias(gv->getName());
+      if (alias != "") {
+        llvm::Module* currModule = kmodule->module;
+        GlobalValue *old_gv = gv;
+        gv = currModule->getNamedValue(alias);
+        if (!gv) {
+          llvm::errs() << "Function " << alias << "(), alias for " 
+                       << old_gv->getName() << " not found!\n";
+          assert(0 && "function alias not found");
+        }
       }
-    }
+     
+      if (Function *f = dyn_cast<Function>(gv))
+        return f;
+      else if (GlobalAlias *ga = dyn_cast<GlobalAlias>(gv))
+        c = ga->getAliasee();
+      else
+        return 0;
+    } else if (llvm::ConstantExpr *ce = dyn_cast<llvm::ConstantExpr>(c)) {
+      if (ce->getOpcode()==Instruction::BitCast)
+        c = ce->getOperand(0);
+      else
+        return 0;
+    } else
+      return 0;
   }
-  
-  return f;
 }
 
 static bool isDebugIntrinsic(const Function *f, KModule *KM) {
@@ -1526,7 +1548,8 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) {
     CallSite cs(i);
 
     unsigned numArgs = cs.arg_size();
-    Function *f = getCalledFunction(cs, state);
+    Value *fp = cs.getCalledValue();
+    Function *f = getTargetFunction(fp, state);
 
     // Skip debug intrinsics, we can't evaluate their metadata arguments.
     if (f && isDebugIntrinsic(f, kmodule))
@@ -1539,19 +1562,15 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) {
     for (unsigned j=0; j<numArgs; ++j)
       arguments.push_back(eval(ki, j+1, state).value);
 
-    if (!f) {
+    if (f) {
+      const FunctionType *fType = 
+        dyn_cast<FunctionType>(cast<PointerType>(f->getType())->getElementType());
+      const FunctionType *fpType =
+        dyn_cast<FunctionType>(cast<PointerType>(fp->getType())->getElementType());
+
       // special case the call with a bitcast case
-      Value *fp = cs.getCalledValue();
-      llvm::ConstantExpr *ce = dyn_cast<llvm::ConstantExpr>(fp);
-        
-      if (ce && ce->getOpcode()==Instruction::BitCast) {
-        f = dyn_cast<Function>(ce->getOperand(0));
-        assert(f && "XXX unrecognized constant expression in call");
-        const FunctionType *fType = 
-          dyn_cast<FunctionType>(cast<PointerType>(f->getType())->getElementType());
-        const FunctionType *ceType =
-          dyn_cast<FunctionType>(cast<PointerType>(ce->getType())->getElementType());
-        assert(fType && ceType && "unable to get function type");
+      if (fType != fpType) {
+        assert(fType && fpType && "unable to get function type");
 
         // XXX check result coercion
 
@@ -1581,9 +1600,7 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) {
         terminateStateOnExecError(state, "inline assembly is unsupported");
         break;
       }
-    }
 
-    if (f) {
       executeCall(state, ki, f, arguments);
     } else {
       ref<Expr> v = eval(ki, 0, state).value;
diff --git a/lib/Core/Executor.h b/lib/Core/Executor.h
index 3df4cc88..a296051b 100644
--- a/lib/Core/Executor.h
+++ b/lib/Core/Executor.h
@@ -170,7 +170,8 @@ private:
   /// The maximum time to allow for a single stp query.
   double stpTimeout;  
 
-  llvm::Function* getCalledFunction(llvm::CallSite &cs, ExecutionState &state);
+  llvm::Function* getTargetFunction(llvm::Value *calledVal,
+                                    ExecutionState &state);
   
   void executeInstruction(ExecutionState &state, KInstruction *ki);
 
diff --git a/test/Feature/BitcastAlias.ll b/test/Feature/BitcastAlias.ll
new file mode 100644
index 00000000..711ed7ad
--- /dev/null
+++ b/test/Feature/BitcastAlias.ll
@@ -0,0 +1,33 @@
+; RUN: llvm-as %s -f -o %t1.bc
+; RUN: %klee -disable-opt %t1.bc > %t2
+; RUN: grep PASS %t2
+
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+@foo = alias i32 (i32)* @__foo
+
+define i32 @__foo(i32 %i) nounwind {
+entry:
+  ret i32 %i
+}
+
+declare i32 @puts(i8*)
+
+@.passstr = private constant [5 x i8] c"PASS\00", align 1
+@.failstr = private constant [5 x i8] c"FAIL\00", align 1
+
+define i32 @main(i32 %argc, i8** nocapture %argv) nounwind readnone {
+entry:
+  %call = call i32 (i64)* bitcast (i32 (i32)* @foo to i32 (i64)*)(i64 52)
+  %r = icmp eq i32 %call, 52
+  br i1 %r, label %bbtrue, label %bbfalse
+
+bbtrue:
+  %0 = call i32 @puts(i8* getelementptr inbounds ([5 x i8]* @.passstr, i64 0, i64 0)) nounwind
+  ret i32 0
+
+bbfalse:
+  %1 = call i32 @puts(i8* getelementptr inbounds ([5 x i8]* @.failstr, i64 0, i64 0)) nounwind
+  ret i32 0
+}