about summary refs log tree commit diff homepage
path: root/lib/Module
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Module')
-rw-r--r--lib/Module/CMakeLists.txt1
-rw-r--r--lib/Module/InstructionOperandTypeCheckPass.cpp184
-rw-r--r--lib/Module/KModule.cpp10
-rw-r--r--lib/Module/Passes.h17
4 files changed, 212 insertions, 0 deletions
diff --git a/lib/Module/CMakeLists.txt b/lib/Module/CMakeLists.txt
index fc481780..c1a5d809 100644
--- a/lib/Module/CMakeLists.txt
+++ b/lib/Module/CMakeLists.txt
@@ -9,6 +9,7 @@
 klee_add_component(kleeModule
   Checks.cpp
   InstructionInfoTable.cpp
+  InstructionOperandTypeCheckPass.cpp
   IntrinsicCleaner.cpp
   KInstruction.cpp
   KModule.cpp
diff --git a/lib/Module/InstructionOperandTypeCheckPass.cpp b/lib/Module/InstructionOperandTypeCheckPass.cpp
new file mode 100644
index 00000000..449eea48
--- /dev/null
+++ b/lib/Module/InstructionOperandTypeCheckPass.cpp
@@ -0,0 +1,184 @@
+//===-- InstructionOperandTypeCheckPass.cpp ---------------------*- C++ -*-===//
+//
+//                     The KLEE Symbolic Virtual Machine
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+#include "Passes.h"
+#include "klee/Config/Version.h"
+#include "klee/Internal/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+namespace {
+
+void printOperandWarning(const char *expected, const Instruction *i,
+                         LLVM_TYPE_Q Type *ty, unsigned opNum) {
+  std::string msg;
+  llvm::raw_string_ostream ss(msg);
+  ss << "Found unexpected type (" << *ty << ") at operand " << opNum
+     << ". Expected " << expected << " in " << *i;
+  i->print(ss);
+  ss.flush();
+  klee::klee_warning("%s", msg.c_str());
+}
+
+bool checkOperandTypeIsScalarInt(const Instruction *i, unsigned opNum) {
+  assert(opNum < i->getNumOperands());
+  LLVM_TYPE_Q llvm::Type *ty = i->getOperand(opNum)->getType();
+  if (!(ty->isIntegerTy())) {
+    printOperandWarning("scalar integer", i, ty, opNum);
+    return false;
+  }
+  return true;
+}
+
+bool checkOperandTypeIsScalarIntOrPointer(const Instruction *i,
+                                          unsigned opNum) {
+  assert(opNum < i->getNumOperands());
+  LLVM_TYPE_Q llvm::Type *ty = i->getOperand(opNum)->getType();
+  if (!(ty->isIntegerTy() || ty->isPointerTy())) {
+    printOperandWarning("scalar integer or pointer", i, ty, opNum);
+    return false;
+  }
+  return true;
+}
+
+bool checkOperandTypeIsScalarPointer(const Instruction *i, unsigned opNum) {
+  assert(opNum < i->getNumOperands());
+  LLVM_TYPE_Q llvm::Type *ty = i->getOperand(opNum)->getType();
+  if (!(ty->isPointerTy())) {
+    printOperandWarning("scalar pointer", i, ty, opNum);
+    return false;
+  }
+  return true;
+}
+
+bool checkOperandTypeIsScalarFloat(const Instruction *i, unsigned opNum) {
+  assert(opNum < i->getNumOperands());
+  LLVM_TYPE_Q llvm::Type *ty = i->getOperand(opNum)->getType();
+  if (!(ty->isFloatingPointTy())) {
+    printOperandWarning("scalar float", i, ty, opNum);
+    return false;
+  }
+  return true;
+}
+
+bool checkOperandsHaveSameType(const Instruction *i, unsigned opNum0,
+                               unsigned opNum1) {
+  assert(opNum0 < i->getNumOperands());
+  assert(opNum1 < i->getNumOperands());
+  LLVM_TYPE_Q llvm::Type *ty0 = i->getOperand(opNum0)->getType();
+  LLVM_TYPE_Q llvm::Type *ty1 = i->getOperand(opNum1)->getType();
+  if (!(ty0 == ty1)) {
+    std::string msg;
+    llvm::raw_string_ostream ss(msg);
+    ss << "Found mismatched type (" << *ty0 << " != " << *ty1
+       << ") for operands" << opNum0 << " and " << opNum1
+       << ". Expected operand types to match in " << *i;
+    i->print(ss);
+    ss.flush();
+    klee::klee_warning("%s", msg.c_str());
+    return false;
+  }
+  return true;
+}
+
+bool checkInstruction(const Instruction *i) {
+  switch (i->getOpcode()) {
+  case Instruction::Select: {
+    // Note we do not enforce that operand 1 and 2 are scalar because the
+    // scalarizer pass might not remove these. This could be selecting which
+    // vector operand to feed to another instruction. The Executor can handle
+    // this so case so this is not a problem
+    return checkOperandTypeIsScalarInt(i, 0) &
+           checkOperandsHaveSameType(i, 1, 2);
+  }
+  // Integer arithmetic, logical and shifting
+  // TODO: When we upgrade to newer LLVM use LLVM_FALLTHROUGH
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::URem:
+  case Instruction::SRem:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr: {
+    return checkOperandTypeIsScalarInt(i, 0) &
+           checkOperandTypeIsScalarInt(i, 1);
+  }
+  // Integer comparison
+  case Instruction::ICmp: {
+    return checkOperandTypeIsScalarIntOrPointer(i, 0) &
+           checkOperandTypeIsScalarIntOrPointer(i, 1);
+  }
+  // Integer Conversion
+  case Instruction::Trunc:
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::IntToPtr: {
+    return checkOperandTypeIsScalarInt(i, 0);
+  }
+  case Instruction::PtrToInt: {
+    return checkOperandTypeIsScalarPointer(i, 0);
+  }
+  // TODO: Figure out if Instruction::BitCast needs checking
+  // Floating point arithmetic
+  case Instruction::FAdd:
+  case Instruction::FSub:
+  case Instruction::FMul:
+  case Instruction::FDiv:
+  case Instruction::FRem: {
+    return checkOperandTypeIsScalarFloat(i, 0) &
+           checkOperandTypeIsScalarFloat(i, 1);
+  }
+  // Floating point conversion
+  case Instruction::FPTrunc:
+  case Instruction::FPExt:
+  case Instruction::FPToUI:
+  case Instruction::FPToSI: {
+    return checkOperandTypeIsScalarFloat(i, 0);
+  }
+  case Instruction::UIToFP:
+  case Instruction::SIToFP: {
+    return checkOperandTypeIsScalarInt(i, 0);
+  }
+  // Floating point comparison
+  case Instruction::FCmp: {
+    return checkOperandTypeIsScalarFloat(i, 0) &
+           checkOperandTypeIsScalarFloat(i, 1);
+  }
+  default:
+    // Treat all other instructions as conforming
+    return true;
+  }
+}
+}
+
+namespace klee {
+
+char InstructionOperandTypeCheckPass::ID = 0;
+
+bool InstructionOperandTypeCheckPass::runOnModule(Module &M) {
+  instructionOperandsConform = true;
+  for (Module::iterator fi = M.begin(), fe = M.end(); fi != fe; ++fi) {
+    for (Function::iterator bi = fi->begin(), be = fi->end(); bi != be; ++bi) {
+      for (BasicBlock::iterator ii = bi->begin(), ie = bi->end(); ii != ie;
+           ++ii) {
+        Instruction *i = static_cast<Instruction *>(ii);
+        instructionOperandsConform &= checkInstruction(i);
+      }
+    }
+  }
+
+  return false;
+}
+}
diff --git a/lib/Module/KModule.cpp b/lib/Module/KModule.cpp
index bee9a0b4..215b2275 100644
--- a/lib/Module/KModule.cpp
+++ b/lib/Module/KModule.cpp
@@ -390,9 +390,19 @@ void KModule::prepare(const Interpreter::ModuleOptions &opts,
   case eSwitchTypeLLVM:  pm3.add(createLowerSwitchPass()); break;
   default: klee_error("invalid --switch-type");
   }
+  InstructionOperandTypeCheckPass *operandTypeCheckPass =
+      new InstructionOperandTypeCheckPass();
   pm3.add(new IntrinsicCleanerPass(*targetData));
   pm3.add(new PhiCleanerPass());
+  pm3.add(operandTypeCheckPass);
   pm3.run(*module);
+
+  // Enforce the operand type invariants that the Executor expects.  This
+  // implicitly depends on the "Scalarizer" pass to be run in order to succeed
+  // in the presence of vector instructions.
+  if (!operandTypeCheckPass->checkPassed()) {
+    klee_error("Unexpected instruction operand types detected");
+  }
 #if LLVM_VERSION_CODE < LLVM_VERSION(3, 3)
   // For cleanliness see if we can discard any of the functions we
   // forced to import.
diff --git a/lib/Module/Passes.h b/lib/Module/Passes.h
index 293766c1..2ac57b9b 100644
--- a/lib/Module/Passes.h
+++ b/lib/Module/Passes.h
@@ -189,6 +189,23 @@ private:
 llvm::FunctionPass *createScalarizerPass();
 #endif
 
+/// InstructionOperandTypeCheckPass - Type checks the types of instruction
+/// operands to check that they conform to invariants expected by the Executor.
+///
+/// This is a ModulePass because other pass types are not meant to maintain
+/// state between calls.
+class InstructionOperandTypeCheckPass : public llvm::ModulePass {
+private:
+  bool instructionOperandsConform;
+
+public:
+  static char ID;
+  InstructionOperandTypeCheckPass()
+      : llvm::ModulePass(ID), instructionOperandsConform(true) {}
+  // TODO: Add `override` when we switch to C++11
+  bool runOnModule(llvm::Module &M);
+  bool checkPassed() const { return instructionOperandsConform; }
+};
 }
 
 #endif