author     Cristian Cadar <c.cadar@imperial.ac.uk>    2017-07-20 14:16:39 +0100
committer  GitHub <noreply@github.com>                2017-07-20 14:16:39 +0100
commit     051e44e59cc5260324667d3fd8a8d46d54c3e72e (patch)
tree       34fdea2badf338c2f6d78185cc5e03cea19cab9b /lib
parent     99a1e3a25cd1405a15112a85de1ff5fc714e7be1 (diff)
parent     ffe695c29915cf8605b2fb807cd083cdcc769d47 (diff)
download   klee-051e44e59cc5260324667d3fd8a8d46d54c3e72e.tar.gz
Merge pull request #657 from delcypher/vectorized_instructions
Implement basic support for vectorized instructions.
Diffstat (limited to 'lib')
-rw-r--r--  lib/Core/Executor.cpp                            109
-rw-r--r--  lib/Core/Executor.h                                1
-rw-r--r--  lib/Module/CMakeLists.txt                          2
-rw-r--r--  lib/Module/InstructionOperandTypeCheckPass.cpp   184
-rw-r--r--  lib/Module/KModule.cpp                            19
-rw-r--r--  lib/Module/Passes.h                               26
-rw-r--r--  lib/Module/Scalarizer.cpp                        651
7 files changed, 980 insertions, 12 deletions
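
To make the change concrete, here is a minimal sketch of guest code that exercises the new support (a hypothetical example, not part of the patch; it assumes clang's GNU vector extensions and KLEE's klee_make_symbolic() intrinsic):

    // vec.cpp -- hypothetical test case, compiled to bitcode with clang.
    #include <klee/klee.h>

    typedef int v4si __attribute__((vector_size(16))); // <4 x i32> in the IR

    int main() {
      v4si a, b;
      klee_make_symbolic(&a, sizeof(a), "a");
      klee_make_symbolic(&b, sizeof(b), "b");
      v4si c = a + b;      // vector add; split into scalar adds by Scalarizer
      return c[0] + c[3];  // lowers to extractelement in the IR
    }

Before this merge the Executor terminated such states with "XXX vector instructions unhandled"; with the Scalarizer pass plus the new InsertElement/ExtractElement cases below, the state runs to completion.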
diff --git a/lib/Core/Executor.cpp b/lib/Core/Executor.cpp
index 60ed1bd9..ff842fd1 100644
--- a/lib/Core/Executor.cpp
+++ b/lib/Core/Executor.cpp
@@ -279,6 +279,7 @@ namespace {
cl::values(
clEnumValN(Executor::Abort, "Abort", "The program crashed"),
clEnumValN(Executor::Assert, "Assert", "An assertion was hit"),
+ clEnumValN(Executor::BadVectorAccess, "BadVectorAccess", "Vector accessed out of bounds"),
clEnumValN(Executor::Exec, "Exec", "Trying to execute an unexpected instruction"),
clEnumValN(Executor::External, "External", "External objects referenced"),
clEnumValN(Executor::Free, "Free", "Freeing invalid memory"),
@@ -333,6 +334,7 @@ namespace klee {
const char *Executor::TerminateReasonNames[] = {
[ Abort ] = "abort",
[ Assert ] = "assert",
+ [ BadVectorAccess ] = "bad_vector_access",
[ Exec ] = "exec",
[ External ] = "external",
[ Free ] = "free",
@@ -1101,8 +1103,18 @@ ref<klee::ConstantExpr> Executor::evalConstant(const Constant *c) {
}
ref<Expr> res = ConcatExpr::createN(kids.size(), kids.data());
return cast<ConstantExpr>(res);
+ } else if (const ConstantVector *cv = dyn_cast<ConstantVector>(c)) {
+ llvm::SmallVector<ref<Expr>, 8> kids;
+ const size_t numOperands = cv->getNumOperands();
+ kids.reserve(numOperands);
+ for (unsigned i = 0; i < numOperands; ++i) {
+ kids.push_back(evalConstant(cv->getOperand(i)));
+ }
+ ref<Expr> res = ConcatExpr::createN(numOperands, kids.data());
+ assert(isa<ConstantExpr>(res) &&
+ "result of constant vector built is not a constant");
+ return cast<ConstantExpr>(res);
} else {
- // Constant{Vector}
llvm::report_fatal_error("invalid argument to evalConstant()");
}
}
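
The concatenation order chosen here fixes the bit layout that the later vector cases in this patch rely on: element 0 of the vector occupies the most significant bits of the resulting expression. A standalone sanity check of that layout (a hypothetical sketch, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Model of evalConstant() on the LLVM constant <4 x i8> <1, 2, 3, 4>:
      // ConcatExpr::createN() places element 0 leftmost, i.e. in the most
      // significant bits of the packed value.
      const uint8_t elems[4] = {1, 2, 3, 4};
      uint32_t packed = 0;
      for (unsigned i = 0; i < 4; ++i)
        packed = (packed << 8) | elems[i];
      assert(packed == 0x01020304u); // element 0 ends up in the top byte
      return 0;
    }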
@@ -1909,6 +1921,7 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) {
// Special instructions
case Instruction::Select: {
+ // NOTE: It is not required that operands 1 and 2 be of scalar type.
ref<Expr> cond = eval(ki, 0, state).value;
ref<Expr> tExpr = eval(ki, 1, state).value;
ref<Expr> fExpr = eval(ki, 2, state).value;
@@ -1967,7 +1980,7 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) {
bindLocal(ki, state, result);
break;
}
-
+
case Instruction::SRem: {
ref<Expr> left = eval(ki, 0, state).value;
ref<Expr> right = eval(ki, 1, state).value;
@@ -2029,7 +2042,7 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) {
case Instruction::ICmp: {
CmpInst *ci = cast<CmpInst>(i);
ICmpInst *ii = cast<ICmpInst>(ci);
-
+
switch(ii->getPredicate()) {
case ICmpInst::ICMP_EQ: {
ref<Expr> left = eval(ki, 0, state).value;
@@ -2194,7 +2207,7 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) {
ref<Expr> arg = eval(ki, 0, state).value;
bindLocal(ki, state, ZExtExpr::create(arg, pType));
break;
- }
+ }
case Instruction::PtrToInt: {
CastInst *ci = cast<CastInst>(i);
Expr::Width iType = getWidthForLLVMType(ci->getType());
@@ -2249,7 +2262,7 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) {
bindLocal(ki, state, ConstantExpr::alloc(Res.bitcastToAPInt()));
break;
}
-
+
case Instruction::FMul: {
ref<ConstantExpr> left = toConstant(state, eval(ki, 0, state).value,
"floating point");
@@ -2565,16 +2578,88 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) {
break;
}
#endif
+ case Instruction::InsertElement: {
+ InsertElementInst *iei = cast<InsertElementInst>(i);
+ ref<Expr> vec = eval(ki, 0, state).value;
+ ref<Expr> newElt = eval(ki, 1, state).value;
+ ref<Expr> idx = eval(ki, 2, state).value;
+
+ ConstantExpr *cIdx = dyn_cast<ConstantExpr>(idx);
+ if (cIdx == NULL) {
+ terminateStateOnError(
+ state, "InsertElement, support for symbolic index not implemented",
+ Unhandled);
+ return;
+ }
+ uint64_t iIdx = cIdx->getZExtValue();
+ const llvm::VectorType *vt = iei->getType();
+ unsigned EltBits = getWidthForLLVMType(vt->getElementType());
- // Other instructions...
- // Unhandled
- case Instruction::ExtractElement:
- case Instruction::InsertElement:
+ if (iIdx >= vt->getNumElements()) {
+ // Out of bounds write
+ terminateStateOnError(state, "Out of bounds write when inserting element",
+ BadVectorAccess);
+ return;
+ }
+
+ const unsigned elementCount = vt->getNumElements();
+ llvm::SmallVector<ref<Expr>, 8> elems;
+ elems.reserve(elementCount);
+ for (unsigned i = 0; i < elementCount; ++i) {
+ // evalConstant() will use ConcatExpr to build vectors with the
+ // zero-th element leftmost (most significant bits), followed
+ // by the next element (second leftmost) and so on. This means
+ // that we have to adjust the index so we read left to right
+ // rather than right to left.
+ unsigned bitOffset = EltBits * (elementCount - i - 1);
+ elems.push_back(i == iIdx ? newElt
+ : ExtractExpr::create(vec, bitOffset, EltBits));
+ }
+
+ ref<Expr> Result = ConcatExpr::createN(elementCount, elems.data());
+ bindLocal(ki, state, Result);
+ break;
+ }
+ case Instruction::ExtractElement: {
+ ExtractElementInst *eei = cast<ExtractElementInst>(i);
+ ref<Expr> vec = eval(ki, 0, state).value;
+ ref<Expr> idx = eval(ki, 1, state).value;
+
+ ConstantExpr *cIdx = dyn_cast<ConstantExpr>(idx);
+ if (cIdx == NULL) {
+ terminateStateOnError(
+ state, "ExtractElement, support for symbolic index not implemented",
+ Unhandled);
+ return;
+ }
+ uint64_t iIdx = cIdx->getZExtValue();
+ const llvm::VectorType *vt = eei->getVectorOperandType();
+ unsigned EltBits = getWidthForLLVMType(vt->getElementType());
+
+ if (iIdx >= vt->getNumElements()) {
+ // Out of bounds read
+ terminateStateOnError(state, "Out of bounds read when extracting element",
+ BadVectorAccess);
+ return;
+ }
+
+ // evalConstant() will use ConcatExpr to build vectors with the
+ // zero-th element leftmost (most significant bits), followed
+ // by the next element (second leftmost) and so on. This means
+ // that we have to adjust the index so we read left to right
+ // rather than right to left.
+ unsigned bitOffset = EltBits * (vt->getNumElements() - iIdx - 1);
+ ref<Expr> Result = ExtractExpr::create(vec, bitOffset, EltBits);
+ bindLocal(ki, state, Result);
+ break;
+ }
case Instruction::ShuffleVector:
- terminateStateOnError(state, "XXX vector instructions unhandled",
- Unhandled);
+ // Should never happen due to Scalarizer pass removing ShuffleVector
+ // instructions.
+ terminateStateOnExecError(state, "Unexpected ShuffleVector instruction");
break;
-
+ // Other instructions...
+ // Unhandled
default:
terminateStateOnExecError(state, "illegal instruction");
break;
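
Both vector cases above perform the same index adjustment; a small hypothetical helper (not in the patch) makes the mirroring explicit:

    // Given the MSB-first layout produced by ConcatExpr, element iIdx of an
    // N-element vector whose elements are eltBits wide starts at this offset.
    static unsigned vectorElementBitOffset(unsigned iIdx, unsigned numElements,
                                           unsigned eltBits) {
      return eltBits * (numElements - iIdx - 1);
    }
    // For <4 x i8>: element 0 starts at bit 24 (top byte), element 3 at bit 0.

InsertElement uses this offset to decide which slice of the old vector to keep, and ExtractElement uses it to pick the slice to read.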
diff --git a/lib/Core/Executor.h b/lib/Core/Executor.h
index c3f6e705..b3bb6864 100644
--- a/lib/Core/Executor.h
+++ b/lib/Core/Executor.h
@@ -105,6 +105,7 @@ public:
enum TerminateReason {
Abort,
Assert,
+ BadVectorAccess,
Exec,
External,
Free,
diff --git a/lib/Module/CMakeLists.txt b/lib/Module/CMakeLists.txt
index 006443a9..c1a5d809 100644
--- a/lib/Module/CMakeLists.txt
+++ b/lib/Module/CMakeLists.txt
@@ -9,6 +9,7 @@
klee_add_component(kleeModule
Checks.cpp
InstructionInfoTable.cpp
+ InstructionOperandTypeCheckPass.cpp
IntrinsicCleaner.cpp
KInstruction.cpp
KModule.cpp
@@ -17,6 +18,7 @@ klee_add_component(kleeModule
Optimize.cpp
PhiCleaner.cpp
RaiseAsm.cpp
+ Scalarizer.cpp
)
set(LLVM_COMPONENTS
diff --git a/lib/Module/InstructionOperandTypeCheckPass.cpp b/lib/Module/InstructionOperandTypeCheckPass.cpp
new file mode 100644
index 00000000..449eea48
--- /dev/null
+++ b/lib/Module/InstructionOperandTypeCheckPass.cpp
@@ -0,0 +1,184 @@
+//===-- InstructionOperandTypeCheckPass.cpp ---------------------*- C++ -*-===//
+//
+// The KLEE Symbolic Virtual Machine
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+#include "Passes.h"
+#include "klee/Config/Version.h"
+#include "klee/Internal/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+namespace {
+
+void printOperandWarning(const char *expected, const Instruction *i,
+ LLVM_TYPE_Q Type *ty, unsigned opNum) {
+ std::string msg;
+ llvm::raw_string_ostream ss(msg);
+ ss << "Found unexpected type (" << *ty << ") at operand " << opNum
+ << ". Expected " << expected << " in " << *i;
+ i->print(ss);
+ ss.flush();
+ klee::klee_warning("%s", msg.c_str());
+}
+
+bool checkOperandTypeIsScalarInt(const Instruction *i, unsigned opNum) {
+ assert(opNum < i->getNumOperands());
+ LLVM_TYPE_Q llvm::Type *ty = i->getOperand(opNum)->getType();
+ if (!(ty->isIntegerTy())) {
+ printOperandWarning("scalar integer", i, ty, opNum);
+ return false;
+ }
+ return true;
+}
+
+bool checkOperandTypeIsScalarIntOrPointer(const Instruction *i,
+ unsigned opNum) {
+ assert(opNum < i->getNumOperands());
+ LLVM_TYPE_Q llvm::Type *ty = i->getOperand(opNum)->getType();
+ if (!(ty->isIntegerTy() || ty->isPointerTy())) {
+ printOperandWarning("scalar integer or pointer", i, ty, opNum);
+ return false;
+ }
+ return true;
+}
+
+bool checkOperandTypeIsScalarPointer(const Instruction *i, unsigned opNum) {
+ assert(opNum < i->getNumOperands());
+ LLVM_TYPE_Q llvm::Type *ty = i->getOperand(opNum)->getType();
+ if (!(ty->isPointerTy())) {
+ printOperandWarning("scalar pointer", i, ty, opNum);
+ return false;
+ }
+ return true;
+}
+
+bool checkOperandTypeIsScalarFloat(const Instruction *i, unsigned opNum) {
+ assert(opNum < i->getNumOperands());
+ LLVM_TYPE_Q llvm::Type *ty = i->getOperand(opNum)->getType();
+ if (!(ty->isFloatingPointTy())) {
+ printOperandWarning("scalar float", i, ty, opNum);
+ return false;
+ }
+ return true;
+}
+
+bool checkOperandsHaveSameType(const Instruction *i, unsigned opNum0,
+ unsigned opNum1) {
+ assert(opNum0 < i->getNumOperands());
+ assert(opNum1 < i->getNumOperands());
+ LLVM_TYPE_Q llvm::Type *ty0 = i->getOperand(opNum0)->getType();
+ LLVM_TYPE_Q llvm::Type *ty1 = i->getOperand(opNum1)->getType();
+ if (!(ty0 == ty1)) {
+ std::string msg;
+ llvm::raw_string_ostream ss(msg);
+ ss << "Found mismatched type (" << *ty0 << " != " << *ty1
+ << ") for operands" << opNum0 << " and " << opNum1
+ << ". Expected operand types to match in " << *i;
+ i->print(ss);
+ ss.flush();
+ klee::klee_warning("%s", msg.c_str());
+ return false;
+ }
+ return true;
+}
+
+bool checkInstruction(const Instruction *i) {
+ switch (i->getOpcode()) {
+ case Instruction::Select: {
+ // Note we do not enforce that operands 1 and 2 are scalar because the
+ // scalarizer pass might not remove these. This could be selecting which
+ // vector operand to feed to another instruction. The Executor can handle
+ // this case, so it is not a problem.
+ return checkOperandTypeIsScalarInt(i, 0) &
+ checkOperandsHaveSameType(i, 1, 2);
+ }
+ // Integer arithmetic, logical and shifting
+ // TODO: When we upgrade to newer LLVM use LLVM_FALLTHROUGH
+ case Instruction::Add:
+ case Instruction::Sub:
+ case Instruction::Mul:
+ case Instruction::UDiv:
+ case Instruction::SDiv:
+ case Instruction::URem:
+ case Instruction::SRem:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Xor:
+ case Instruction::Shl:
+ case Instruction::LShr:
+ case Instruction::AShr: {
+ return checkOperandTypeIsScalarInt(i, 0) &
+ checkOperandTypeIsScalarInt(i, 1);
+ }
+ // Integer comparison
+ case Instruction::ICmp: {
+ return checkOperandTypeIsScalarIntOrPointer(i, 0) &
+ checkOperandTypeIsScalarIntOrPointer(i, 1);
+ }
+ // Integer Conversion
+ case Instruction::Trunc:
+ case Instruction::ZExt:
+ case Instruction::SExt:
+ case Instruction::IntToPtr: {
+ return checkOperandTypeIsScalarInt(i, 0);
+ }
+ case Instruction::PtrToInt: {
+ return checkOperandTypeIsScalarPointer(i, 0);
+ }
+ // TODO: Figure out if Instruction::BitCast needs checking
+ // Floating point arithmetic
+ case Instruction::FAdd:
+ case Instruction::FSub:
+ case Instruction::FMul:
+ case Instruction::FDiv:
+ case Instruction::FRem: {
+ return checkOperandTypeIsScalarFloat(i, 0) &
+ checkOperandTypeIsScalarFloat(i, 1);
+ }
+ // Floating point conversion
+ case Instruction::FPTrunc:
+ case Instruction::FPExt:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI: {
+ return checkOperandTypeIsScalarFloat(i, 0);
+ }
+ case Instruction::UIToFP:
+ case Instruction::SIToFP: {
+ return checkOperandTypeIsScalarInt(i, 0);
+ }
+ // Floating point comparison
+ case Instruction::FCmp: {
+ return checkOperandTypeIsScalarFloat(i, 0) &
+ checkOperandTypeIsScalarFloat(i, 1);
+ }
+ default:
+ // Treat all other instructions as conforming
+ return true;
+ }
+}
+}
+
+namespace klee {
+
+char InstructionOperandTypeCheckPass::ID = 0;
+
+bool InstructionOperandTypeCheckPass::runOnModule(Module &M) {
+ instructionOperandsConform = true;
+ for (Module::iterator fi = M.begin(), fe = M.end(); fi != fe; ++fi) {
+ for (Function::iterator bi = fi->begin(), be = fi->end(); bi != be; ++bi) {
+ for (BasicBlock::iterator ii = bi->begin(), ie = bi->end(); ii != ie;
+ ++ii) {
+ Instruction *i = static_cast<Instruction *>(ii);
+ instructionOperandsConform &= checkInstruction(i);
+ }
+ }
+ }
+
+ return false;
+}
+}
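
Note that checkInstruction() combines its operand checks with bitwise & rather than the short-circuiting &&, so a failure on operand 0 does not suppress the warning for operand 1. A contrived, self-contained sketch of the difference (hypothetical helper, not part of the pass):

    #include <cstdio>

    // Stand-in for the checkOperandType* helpers above: each failing check
    // prints its own warning and returns false.
    static bool failingCheck(const char *which) {
      std::printf("warning: unexpected type at operand %s\n", which);
      return false;
    }

    int main() {
      // Bitwise & evaluates both calls, so both warnings appear; && would
      // stop after the first failure and hide the second diagnostic.
      bool conforms = failingCheck("0") & failingCheck("1");
      return conforms ? 0 : 1;
    }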
diff --git a/lib/Module/KModule.cpp b/lib/Module/KModule.cpp
index aafabacc..ec9972eb 100644
--- a/lib/Module/KModule.cpp
+++ b/lib/Module/KModule.cpp
@@ -307,6 +307,15 @@ void KModule::prepare(const Interpreter::ModuleOptions &opts,
// module.
LegacyLLVMPassManagerTy pm;
pm.add(new RaiseAsmPass());
+#if LLVM_VERSION_CODE >= LLVM_VERSION(3,4)
+ // This pass will scalarize as much code as possible so that the Executor
+ // does not need to handle operands of vector type for most instructions
+ // other than InsertElementInst and ExtractElementInst.
+ //
+ // NOTE: Must come before division/overshift checks because those passes
+ // don't know how to handle vector instructions.
+ pm.add(createScalarizerPass());
+#endif
if (opts.CheckDivZero) pm.add(new DivCheckPass());
if (opts.CheckOvershift) pm.add(new OvershiftCheckPass());
// FIXME: This false here is to work around a bug in
@@ -381,9 +390,19 @@ void KModule::prepare(const Interpreter::ModuleOptions &opts,
case eSwitchTypeLLVM: pm3.add(createLowerSwitchPass()); break;
default: klee_error("invalid --switch-type");
}
+ InstructionOperandTypeCheckPass *operandTypeCheckPass =
+ new InstructionOperandTypeCheckPass();
pm3.add(new IntrinsicCleanerPass(*targetData));
pm3.add(new PhiCleanerPass());
+ pm3.add(operandTypeCheckPass);
pm3.run(*module);
+
+ // Enforce the operand type invariants that the Executor expects. This
+ // implicitly depends on the "Scalarizer" pass having been run; without
+ // it, vector instructions would fail the check.
+ if (!operandTypeCheckPass->checkPassed()) {
+ klee_error("Unexpected instruction operand types detected");
+ }
#if LLVM_VERSION_CODE < LLVM_VERSION(3, 3)
// For cleanliness see if we can discard any of the functions we
// forced to import.
diff --git a/lib/Module/Passes.h b/lib/Module/Passes.h
index 4f1a1453..2ac57b9b 100644
--- a/lib/Module/Passes.h
+++ b/lib/Module/Passes.h
@@ -180,6 +180,32 @@ private:
llvm::BasicBlock *defaultBlock);
};
+// This is the interface to a back-ported LLVM pass.
+// Newer versions of LLVM already have this pass in-tree,
+// and we do not support vector instructions for LLVM 2.9,
+// so this interface is only needed for LLVM 3.4.
+#if LLVM_VERSION_CODE == LLVM_VERSION(3,4)
+llvm::FunctionPass *createScalarizerPass();
+#endif
+
+/// InstructionOperandTypeCheckPass - Type checks the types of instruction
+/// operands to check that they conform to invariants expected by the Executor.
+///
+/// This is a ModulePass because other pass types are not meant to maintain
+/// state between calls.
+class InstructionOperandTypeCheckPass : public llvm::ModulePass {
+private:
+ bool instructionOperandsConform;
+
+public:
+ static char ID;
+ InstructionOperandTypeCheckPass()
+ : llvm::ModulePass(ID), instructionOperandsConform(true) {}
+ // TODO: Add `override` when we switch to C++11
+ bool runOnModule(llvm::Module &M);
+ bool checkPassed() const { return instructionOperandsConform; }
+};
}
#endif
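
A minimal usage sketch for the declaration above (assuming LLVM 3.4; the driver itself is hypothetical, the real call site being KModule::prepare() in the KModule.cpp hunk above):

    #include "Passes.h"
    #include "llvm/IR/Module.h"
    #include "llvm/PassManager.h"

    // Run the back-ported scalarizer over every function of a module.
    static void scalarize(llvm::Module &M) {
      llvm::FunctionPassManager fpm(&M); // legacy pass manager in LLVM 3.4
      fpm.add(klee::createScalarizerPass());
      fpm.doInitialization();
      for (llvm::Module::iterator fi = M.begin(), fe = M.end(); fi != fe; ++fi)
        fpm.run(*fi);
      fpm.doFinalization();
    }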
diff --git a/lib/Module/Scalarizer.cpp b/lib/Module/Scalarizer.cpp
new file mode 100644
index 00000000..0d8e1f48
--- /dev/null
+++ b/lib/Module/Scalarizer.cpp
@@ -0,0 +1,651 @@
+//===--- Scalarizer.cpp - Scalarize vector operations ---------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass converts vector operations into scalar operations, in order
+// to expose optimization opportunities on the individual scalar operations.
+// It is mainly intended for targets that do not have vector units, but it
+// may also be useful for revectorizing code to different vector widths.
+//
+//===----------------------------------------------------------------------===//
+#include "klee/Config/Version.h"
+
+// This is taken from r195471 in LLVM. It was unfortunately introduced just
+// after LLVM branched for 3.4, so it has been copied into KLEE's source tree.
+// We only use this for LLVM 3.4 because newer LLVMs have this pass in-tree.
+#if LLVM_VERSION_CODE == LLVM_VERSION(3,4)
+
+#define DEBUG_TYPE "scalarizer"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/InstVisitor.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+namespace {
+// Used to store the scattered form of a vector.
+typedef SmallVector<Value *, 8> ValueVector;
+
+// Used to map a vector Value to its scattered form. We use std::map
+// because we want iterators to persist across insertion and because the
+// values are relatively large.
+typedef std::map<Value *, ValueVector> ScatterMap;
+
+// Lists Instructions that have been replaced with scalar implementations,
+// along with a pointer to their scattered forms.
+typedef SmallVector<std::pair<Instruction *, ValueVector *>, 16> GatherList;
+
+// Provides a very limited vector-like interface for lazily accessing one
+// component of a scattered vector or vector pointer.
+class Scatterer {
+public:
+ // Scatter V into Size components. If new instructions are needed,
+ // insert them before BBI in BB. If CachePtr is nonnull, use it to cache
+ // the results.
+ Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
+ ValueVector *cachePtr = 0);
+
+ // Return component I, creating a new Value for it if necessary.
+ Value *operator[](unsigned I);
+
+ // Return the number of components.
+ unsigned size() const { return Size; }
+
+private:
+ BasicBlock *BB;
+ BasicBlock::iterator BBI;
+ Value *V;
+ ValueVector *CachePtr;
+ PointerType *PtrTy;
+ ValueVector Tmp;
+ unsigned Size;
+};
+
+// FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
+// called Name that compares X and Y in the same way as FCI.
+struct FCmpSplitter {
+ FCmpSplitter(FCmpInst &fci) : FCI(fci) {}
+ Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
+ const Twine &Name) const {
+ return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
+ }
+ FCmpInst &FCI;
+};
+
+// ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
+// called Name that compares X and Y in the same way as ICI.
+struct ICmpSplitter {
+ ICmpSplitter(ICmpInst &ici) : ICI(ici) {}
+ Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
+ const Twine &Name) const {
+ return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
+ }
+ ICmpInst &ICI;
+};
+
+// BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create
+// a binary operator like BO called Name with operands X and Y.
+struct BinarySplitter {
+ BinarySplitter(BinaryOperator &bo) : BO(bo) {}
+ Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
+ const Twine &Name) const {
+ return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
+ }
+ BinaryOperator &BO;
+};
+
+// GEPSplitter()(Builder, X, Y, Name) uses Builder to create
+// a single GEP called Name with operands X and Y.
+struct GEPSplitter {
+ GEPSplitter() {}
+ Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
+ const Twine &Name) const {
+ return Builder.CreateGEP(Op0, Op1, Name);
+ }
+};
+
+// Information about a load or store that we're scalarizing.
+struct VectorLayout {
+ VectorLayout() : VecTy(0), ElemTy(0), VecAlign(0), ElemSize(0) {}
+
+ // Return the alignment of element I.
+ uint64_t getElemAlign(unsigned I) {
+ return MinAlign(VecAlign, I * ElemSize);
+ }
+
+ // The type of the vector.
+ VectorType *VecTy;
+
+ // The type of each element.
+ Type *ElemTy;
+
+ // The alignment of the vector.
+ uint64_t VecAlign;
+
+ // The size of each element.
+ uint64_t ElemSize;
+};
+
+class Scalarizer : public FunctionPass,
+ public InstVisitor<Scalarizer, bool> {
+public:
+ static char ID;
+
+ Scalarizer() :
+ FunctionPass(ID) {
+ // HACK:
+ //initializeScalarizerPass(*PassRegistry::getPassRegistry());
+ }
+
+ virtual bool doInitialization(Module &M);
+ virtual bool runOnFunction(Function &F);
+
+ // InstVisitor methods. They return true if the instruction was scalarized,
+ // false if nothing changed.
+ bool visitInstruction(Instruction &) { return false; }
+ bool visitSelectInst(SelectInst &SI);
+ bool visitICmpInst(ICmpInst &);
+ bool visitFCmpInst(FCmpInst &);
+ bool visitBinaryOperator(BinaryOperator &);
+ bool visitGetElementPtrInst(GetElementPtrInst &);
+ bool visitCastInst(CastInst &);
+ bool visitBitCastInst(BitCastInst &);
+ bool visitShuffleVectorInst(ShuffleVectorInst &);
+ bool visitPHINode(PHINode &);
+ bool visitLoadInst(LoadInst &);
+ bool visitStoreInst(StoreInst &);
+
+private:
+ Scatterer scatter(Instruction *, Value *);
+ void gather(Instruction *, const ValueVector &);
+ bool canTransferMetadata(unsigned Kind);
+ void transferMetadata(Instruction *, const ValueVector &);
+ bool getVectorLayout(Type *, unsigned, VectorLayout &);
+ bool finish();
+
+ template<typename T> bool splitBinary(Instruction &, const T &);
+
+ ScatterMap Scattered;
+ GatherList Gathered;
+ unsigned ParallelLoopAccessMDKind;
+ const DataLayout *TDL;
+};
+
+char Scalarizer::ID = 0;
+} // end anonymous namespace
+
+bool ScalarizeLoadStore = true; // HACK
+/*
+// This is disabled by default because having separate loads and stores makes
+// it more likely that the -combiner-alias-analysis limits will be reached.
+static cl::opt<bool> ScalarizeLoadStore
+ ("scalarize-load-store", cl::Hidden, cl::init(false),
+ cl::desc("Allow the scalarizer pass to scalarize loads and store"));
+
+INITIALIZE_PASS(Scalarizer, "scalarizer", "Scalarize vector operations",
+ false, false)
+*/
+
+Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
+ ValueVector *cachePtr)
+ : BB(bb), BBI(bbi), V(v), CachePtr(cachePtr) {
+ Type *Ty = V->getType();
+ PtrTy = dyn_cast<PointerType>(Ty);
+ if (PtrTy)
+ Ty = PtrTy->getElementType();
+ Size = Ty->getVectorNumElements();
+ if (!CachePtr)
+ Tmp.resize(Size, 0);
+ else if (CachePtr->empty())
+ CachePtr->resize(Size, 0);
+ else
+ assert(Size == CachePtr->size() && "Inconsistent vector sizes");
+}
+
+// Return component I, creating a new Value for it if necessary.
+Value *Scatterer::operator[](unsigned I) {
+ ValueVector &CV = (CachePtr ? *CachePtr : Tmp);
+ // Try to reuse a previous value.
+ if (CV[I])
+ return CV[I];
+ IRBuilder<> Builder(BB, BBI);
+ if (PtrTy) {
+ if (!CV[0]) {
+ Type *Ty =
+ PointerType::get(PtrTy->getElementType()->getVectorElementType(),
+ PtrTy->getAddressSpace());
+ CV[0] = Builder.CreateBitCast(V, Ty, V->getName() + ".i0");
+ }
+ if (I != 0)
+ CV[I] = Builder.CreateConstGEP1_32(CV[0], I,
+ V->getName() + ".i" + Twine(I));
+ } else {
+ // Search through a chain of InsertElementInsts looking for element I.
+ // Record other elements in the cache. The new V is still suitable
+ // for all uncached indices.
+ for (;;) {
+ InsertElementInst *Insert = dyn_cast<InsertElementInst>(V);
+ if (!Insert)
+ break;
+ ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2));
+ if (!Idx)
+ break;
+ unsigned J = Idx->getZExtValue();
+ CV[J] = Insert->getOperand(1);
+ V = Insert->getOperand(0);
+ if (I == J)
+ return CV[J];
+ }
+ CV[I] = Builder.CreateExtractElement(V, Builder.getInt32(I),
+ V->getName() + ".i" + Twine(I));
+ }
+ return CV[I];
+}
+
+bool Scalarizer::doInitialization(Module &M) {
+ ParallelLoopAccessMDKind =
+ M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
+ return false;
+}
+
+bool Scalarizer::runOnFunction(Function &F) {
+ TDL = getAnalysisIfAvailable<DataLayout>();
+ for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
+ BasicBlock *BB = BBI;
+ for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
+ Instruction *I = II;
+ bool Done = visit(I);
+ ++II;
+ if (Done && I->getType()->isVoidTy())
+ I->eraseFromParent();
+ }
+ }
+ return finish();
+}
+
+// Return a scattered form of V that can be accessed by Point. V must be a
+// vector or a pointer to a vector.
+Scatterer Scalarizer::scatter(Instruction *Point, Value *V) {
+ if (Argument *VArg = dyn_cast<Argument>(V)) {
+ // Put the scattered form of arguments in the entry block,
+ // so that it can be used everywhere.
+ Function *F = VArg->getParent();
+ BasicBlock *BB = &F->getEntryBlock();
+ return Scatterer(BB, BB->begin(), V, &Scattered[V]);
+ }
+ if (Instruction *VOp = dyn_cast<Instruction>(V)) {
+ // Put the scattered form of an instruction directly after the
+ // instruction.
+ BasicBlock *BB = VOp->getParent();
+ return Scatterer(BB, llvm::next(BasicBlock::iterator(VOp)),
+ V, &Scattered[V]);
+ }
+ // In the fallback case, just put the scattered form before Point and
+ // keep the result local to Point.
+ return Scatterer(Point->getParent(), Point, V);
+}
+
+// Replace Op with the gathered form of the components in CV. Defer the
+// deletion of Op and creation of the gathered form to the end of the pass,
+// so that we can avoid creating the gathered form if all uses of Op are
+// replaced with uses of CV.
+void Scalarizer::gather(Instruction *Op, const ValueVector &CV) {
+ // Since we're not deleting Op yet, stub out its operands, so that it
+ // doesn't make anything live unnecessarily.
+ for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I)
+ Op->setOperand(I, UndefValue::get(Op->getOperand(I)->getType()));
+
+ transferMetadata(Op, CV);
+
+ // If we already have a scattered form of Op (created from ExtractElements
+ // of Op itself), replace them with the new form.
+ ValueVector &SV = Scattered[Op];
+ if (!SV.empty()) {
+ for (unsigned I = 0, E = SV.size(); I != E; ++I) {
+ Instruction *Old = cast<Instruction>(SV[I]);
+ CV[I]->takeName(Old);
+ Old->replaceAllUsesWith(CV[I]);
+ Old->eraseFromParent();
+ }
+ }
+ SV = CV;
+ Gathered.push_back(GatherList::value_type(Op, &SV));
+}
+
+// Return true if it is safe to transfer the given metadata tag from
+// vector to scalar instructions.
+bool Scalarizer::canTransferMetadata(unsigned Tag) {
+ return (Tag == LLVMContext::MD_tbaa
+ || Tag == LLVMContext::MD_fpmath
+ || Tag == LLVMContext::MD_tbaa_struct
+ || Tag == LLVMContext::MD_invariant_load
+ || Tag == ParallelLoopAccessMDKind);
+}
+
+// Transfer metadata from Op to the instructions in CV if it is known
+// to be safe to do so.
+void Scalarizer::transferMetadata(Instruction *Op, const ValueVector &CV) {
+ SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
+ Op->getAllMetadataOtherThanDebugLoc(MDs);
+ for (unsigned I = 0, E = CV.size(); I != E; ++I) {
+ if (Instruction *New = dyn_cast<Instruction>(CV[I])) {
+ for (SmallVectorImpl<std::pair<unsigned, MDNode *> >::iterator
+ MI = MDs.begin(), ME = MDs.end(); MI != ME; ++MI)
+ if (canTransferMetadata(MI->first))
+ New->setMetadata(MI->first, MI->second);
+ New->setDebugLoc(Op->getDebugLoc());
+ }
+ }
+}
+
+// Try to fill in Layout from Ty, returning true on success. Alignment is
+// the alignment of the vector, or 0 if the ABI default should be used.
+bool Scalarizer::getVectorLayout(Type *Ty, unsigned Alignment,
+ VectorLayout &Layout) {
+ if (!TDL)
+ return false;
+
+ // Make sure we're dealing with a vector.
+ Layout.VecTy = dyn_cast<VectorType>(Ty);
+ if (!Layout.VecTy)
+ return false;
+
+ // Check that we're dealing with full-byte elements.
+ Layout.ElemTy = Layout.VecTy->getElementType();
+ if (TDL->getTypeSizeInBits(Layout.ElemTy) !=
+ TDL->getTypeStoreSizeInBits(Layout.ElemTy))
+ return false;
+
+ if (Alignment)
+ Layout.VecAlign = Alignment;
+ else
+ Layout.VecAlign = TDL->getABITypeAlignment(Layout.VecTy);
+ Layout.ElemSize = TDL->getTypeStoreSize(Layout.ElemTy);
+ return true;
+}
+
+// Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
+// to create an instruction like I with operands X and Y and name Name.
+template<typename Splitter>
+bool Scalarizer::splitBinary(Instruction &I, const Splitter &Split) {
+ VectorType *VT = dyn_cast<VectorType>(I.getType());
+ if (!VT)
+ return false;
+
+ unsigned NumElems = VT->getNumElements();
+ IRBuilder<> Builder(I.getParent(), &I);
+ Scatterer Op0 = scatter(&I, I.getOperand(0));
+ Scatterer Op1 = scatter(&I, I.getOperand(1));
+ assert(Op0.size() == NumElems && "Mismatched binary operation");
+ assert(Op1.size() == NumElems && "Mismatched binary operation");
+ ValueVector Res;
+ Res.resize(NumElems);
+ for (unsigned Elem = 0; Elem < NumElems; ++Elem)
+ Res[Elem] = Split(Builder, Op0[Elem], Op1[Elem],
+ I.getName() + ".i" + Twine(Elem));
+ gather(&I, Res);
+ return true;
+}
+
+bool Scalarizer::visitSelectInst(SelectInst &SI) {
+ VectorType *VT = dyn_cast<VectorType>(SI.getType());
+ if (!VT)
+ return false;
+
+ unsigned NumElems = VT->getNumElements();
+ IRBuilder<> Builder(SI.getParent(), &SI);
+ Scatterer Op1 = scatter(&SI, SI.getOperand(1));
+ Scatterer Op2 = scatter(&SI, SI.getOperand(2));
+ assert(Op1.size() == NumElems && "Mismatched select");
+ assert(Op2.size() == NumElems && "Mismatched select");
+ ValueVector Res;
+ Res.resize(NumElems);
+
+ if (SI.getOperand(0)->getType()->isVectorTy()) {
+ Scatterer Op0 = scatter(&SI, SI.getOperand(0));
+ assert(Op0.size() == NumElems && "Mismatched select");
+ for (unsigned I = 0; I < NumElems; ++I)
+ Res[I] = Builder.CreateSelect(Op0[I], Op1[I], Op2[I],
+ SI.getName() + ".i" + Twine(I));
+ } else {
+ Value *Op0 = SI.getOperand(0);
+ for (unsigned I = 0; I < NumElems; ++I)
+ Res[I] = Builder.CreateSelect(Op0, Op1[I], Op2[I],
+ SI.getName() + ".i" + Twine(I));
+ }
+ gather(&SI, Res);
+ return true;
+}
+
+bool Scalarizer::visitICmpInst(ICmpInst &ICI) {
+ return splitBinary(ICI, ICmpSplitter(ICI));
+}
+
+bool Scalarizer::visitFCmpInst(FCmpInst &FCI) {
+ return splitBinary(FCI, FCmpSplitter(FCI));
+}
+
+bool Scalarizer::visitBinaryOperator(BinaryOperator &BO) {
+ return splitBinary(BO, BinarySplitter(BO));
+}
+
+bool Scalarizer::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
+ return splitBinary(GEPI, GEPSplitter());
+}
+
+bool Scalarizer::visitCastInst(CastInst &CI) {
+ VectorType *VT = dyn_cast<VectorType>(CI.getDestTy());
+ if (!VT)
+ return false;
+
+ unsigned NumElems = VT->getNumElements();
+ IRBuilder<> Builder(CI.getParent(), &CI);
+ Scatterer Op0 = scatter(&CI, CI.getOperand(0));
+ assert(Op0.size() == NumElems && "Mismatched cast");
+ ValueVector Res;
+ Res.resize(NumElems);
+ for (unsigned I = 0; I < NumElems; ++I)
+ Res[I] = Builder.CreateCast(CI.getOpcode(), Op0[I], VT->getElementType(),
+ CI.getName() + ".i" + Twine(I));
+ gather(&CI, Res);
+ return true;
+}
+
+bool Scalarizer::visitBitCastInst(BitCastInst &BCI) {
+ VectorType *DstVT = dyn_cast<VectorType>(BCI.getDestTy());
+ VectorType *SrcVT = dyn_cast<VectorType>(BCI.getSrcTy());
+ if (!DstVT || !SrcVT)
+ return false;
+
+ unsigned DstNumElems = DstVT->getNumElements();
+ unsigned SrcNumElems = SrcVT->getNumElements();
+ IRBuilder<> Builder(BCI.getParent(), &BCI);
+ Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));
+ ValueVector Res;
+ Res.resize(DstNumElems);
+
+ if (DstNumElems == SrcNumElems) {
+ for (unsigned I = 0; I < DstNumElems; ++I)
+ Res[I] = Builder.CreateBitCast(Op0[I], DstVT->getElementType(),
+ BCI.getName() + ".i" + Twine(I));
+ } else if (DstNumElems > SrcNumElems) {
+ // <M x t1> -> <N*M x t2>. Convert each t1 to <N x t2> and copy the
+ // individual elements to the destination.
+ unsigned FanOut = DstNumElems / SrcNumElems;
+ Type *MidTy = VectorType::get(DstVT->getElementType(), FanOut);
+ unsigned ResI = 0;
+ for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) {
+ Value *V = Op0[Op0I];
+ Instruction *VI;
+ // Look through any existing bitcasts before converting to <N x t2>.
+ // In the best case, the resulting conversion might be a no-op.
+ while ((VI = dyn_cast<Instruction>(V)) &&
+ VI->getOpcode() == Instruction::BitCast)
+ V = VI->getOperand(0);
+ V = Builder.CreateBitCast(V, MidTy, V->getName() + ".cast");
+ Scatterer Mid = scatter(&BCI, V);
+ for (unsigned MidI = 0; MidI < FanOut; ++MidI)
+ Res[ResI++] = Mid[MidI];
+ }
+ } else {
+ // <N*M x t1> -> <M x t2>. Convert each group of <N x t1> into a t2.
+ unsigned FanIn = SrcNumElems / DstNumElems;
+ Type *MidTy = VectorType::get(SrcVT->getElementType(), FanIn);
+ unsigned Op0I = 0;
+ for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) {
+ Value *V = UndefValue::get(MidTy);
+ for (unsigned MidI = 0; MidI < FanIn; ++MidI)
+ V = Builder.CreateInsertElement(V, Op0[Op0I++], Builder.getInt32(MidI),
+ BCI.getName() + ".i" + Twine(ResI)
+ + ".upto" + Twine(MidI));
+ Res[ResI] = Builder.CreateBitCast(V, DstVT->getElementType(),
+ BCI.getName() + ".i" + Twine(ResI));
+ }
+ }
+ gather(&BCI, Res);
+ return true;
+}
+
+bool Scalarizer::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
+ VectorType *VT = dyn_cast<VectorType>(SVI.getType());
+ if (!VT)
+ return false;
+
+ unsigned NumElems = VT->getNumElements();
+ Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));
+ Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));
+ ValueVector Res;
+ Res.resize(NumElems);
+
+ for (unsigned I = 0; I < NumElems; ++I) {
+ int Selector = SVI.getMaskValue(I);
+ if (Selector < 0)
+ Res[I] = UndefValue::get(VT->getElementType());
+ else if (unsigned(Selector) < Op0.size())
+ Res[I] = Op0[Selector];
+ else
+ Res[I] = Op1[Selector - Op0.size()];
+ }
+ gather(&SVI, Res);
+ return true;
+}
+
+bool Scalarizer::visitPHINode(PHINode &PHI) {
+ VectorType *VT = dyn_cast<VectorType>(PHI.getType());
+ if (!VT)
+ return false;
+
+ unsigned NumElems = VT->getNumElements();
+ IRBuilder<> Builder(PHI.getParent(), &PHI);
+ ValueVector Res;
+ Res.resize(NumElems);
+
+ unsigned NumOps = PHI.getNumOperands();
+ for (unsigned I = 0; I < NumElems; ++I)
+ Res[I] = Builder.CreatePHI(VT->getElementType(), NumOps,
+ PHI.getName() + ".i" + Twine(I));
+
+ for (unsigned I = 0; I < NumOps; ++I) {
+ Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I));
+ BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
+ for (unsigned J = 0; J < NumElems; ++J)
+ cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
+ }
+ gather(&PHI, Res);
+ return true;
+}
+
+bool Scalarizer::visitLoadInst(LoadInst &LI) {
+ if (!ScalarizeLoadStore)
+ return false;
+ if (!LI.isSimple())
+ return false;
+
+ VectorLayout Layout;
+ if (!getVectorLayout(LI.getType(), LI.getAlignment(), Layout))
+ return false;
+
+ unsigned NumElems = Layout.VecTy->getNumElements();
+ IRBuilder<> Builder(LI.getParent(), &LI);
+ Scatterer Ptr = scatter(&LI, LI.getPointerOperand());
+ ValueVector Res;
+ Res.resize(NumElems);
+
+ for (unsigned I = 0; I < NumElems; ++I)
+ Res[I] = Builder.CreateAlignedLoad(Ptr[I], Layout.getElemAlign(I),
+ LI.getName() + ".i" + Twine(I));
+ gather(&LI, Res);
+ return true;
+}
+
+bool Scalarizer::visitStoreInst(StoreInst &SI) {
+ if (!ScalarizeLoadStore)
+ return false;
+ if (!SI.isSimple())
+ return false;
+
+ VectorLayout Layout;
+ Value *FullValue = SI.getValueOperand();
+ if (!getVectorLayout(FullValue->getType(), SI.getAlignment(), Layout))
+ return false;
+
+ unsigned NumElems = Layout.VecTy->getNumElements();
+ IRBuilder<> Builder(SI.getParent(), &SI);
+ Scatterer Ptr = scatter(&SI, SI.getPointerOperand());
+ Scatterer Val = scatter(&SI, FullValue);
+
+ ValueVector Stores;
+ Stores.resize(NumElems);
+ for (unsigned I = 0; I < NumElems; ++I) {
+ unsigned Align = Layout.getElemAlign(I);
+ Stores[I] = Builder.CreateAlignedStore(Val[I], Ptr[I], Align);
+ }
+ transferMetadata(&SI, Stores);
+ return true;
+}
+
+// Delete the instructions that we scalarized. If a full vector result
+// is still needed, recreate it using InsertElements.
+bool Scalarizer::finish() {
+ if (Gathered.empty())
+ return false;
+ for (GatherList::iterator GMI = Gathered.begin(), GME = Gathered.end();
+ GMI != GME; ++GMI) {
+ Instruction *Op = GMI->first;
+ ValueVector &CV = *GMI->second;
+ if (!Op->use_empty()) {
+ // The value is still needed, so recreate it using a series of
+ // InsertElements.
+ Type *Ty = Op->getType();
+ Value *Res = UndefValue::get(Ty);
+ unsigned Count = Ty->getVectorNumElements();
+ IRBuilder<> Builder(Op->getParent(), Op);
+ for (unsigned I = 0; I < Count; ++I)
+ Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I),
+ Op->getName() + ".upto" + Twine(I));
+ Res->takeName(Op);
+ Op->replaceAllUsesWith(Res);
+ }
+ Op->eraseFromParent();
+ }
+ Gathered.clear();
+ Scattered.clear();
+ return true;
+}
+
+namespace klee {
+ llvm::FunctionPass *createScalarizerPass() {
+ return new Scalarizer();
+ }
+}
+
+#endif