diff options
| author | Dan Liew <daniel.liew@imperial.ac.uk> | 2017-05-24 16:32:47 +0100 | 
|---|---|---|
| committer | Dan Liew <daniel.liew@imperial.ac.uk> | 2017-07-19 18:15:20 +0100 | 
| commit | 8e8732d482e42e363f0f4c51794ed966701371e7 (patch) | |
| tree | 35ae38f61a0dbe865c931841a559b0d4553a838a /lib/Core/Executor.cpp | |
| parent | 7c4fdd012317eb92352fc7ded53a553ed762719f (diff) | |
| download | klee-8e8732d482e42e363f0f4c51794ed966701371e7.tar.gz | |
Implement basic support for vectorized instructions.
We use LLVM's Scalarizer pass to remove most vectorized code so that the Executor only needs to support the InsertElement and ExtractElement instructions. This pass was not available in LLVM 3.4 so to support that LLVM version the pass has been back ported. To check that the Executor is not receiving vector operand types that it can't handle assertions have been added. There are a few limitations to this implementation. * The InsertElement and ExtractElement index cannot be symbolic. * There is no support for LLVM < 3.4.
Diffstat (limited to 'lib/Core/Executor.cpp')
| -rw-r--r-- | lib/Core/Executor.cpp | 218 | 
1 files changed, 206 insertions, 12 deletions
| diff --git a/lib/Core/Executor.cpp b/lib/Core/Executor.cpp index c087b79b..96181a3d 100644 --- a/lib/Core/Executor.cpp +++ b/lib/Core/Executor.cpp @@ -279,6 +279,7 @@ namespace { cl::values( clEnumValN(Executor::Abort, "Abort", "The program crashed"), clEnumValN(Executor::Assert, "Assert", "An assertion was hit"), + clEnumValN(Executor::BadVectorAccess, "BadVectorAccess", "Vector accessed out of bounds"), clEnumValN(Executor::Exec, "Exec", "Trying to execute an unexpected instruction"), clEnumValN(Executor::External, "External", "External objects referenced"), clEnumValN(Executor::Free, "Free", "Freeing invalid memory"), @@ -333,6 +334,7 @@ namespace klee { const char *Executor::TerminateReasonNames[] = { [ Abort ] = "abort", [ Assert ] = "assert", + [ BadVectorAccess ] = "bad_vector_access", [ Exec ] = "exec", [ External ] = "external", [ Free ] = "free", @@ -1102,8 +1104,18 @@ ref<klee::ConstantExpr> Executor::evalConstant(const Constant *c) { } ref<Expr> res = ConcatExpr::createN(kids.size(), kids.data()); return cast<ConstantExpr>(res); + } else if (const ConstantVector *cv = dyn_cast<ConstantVector>(c)) { + llvm::SmallVector<ref<Expr>, 8> kids; + const size_t numOperands = cv->getNumOperands(); + kids.reserve(numOperands); + for (unsigned i = 0; i < numOperands; ++i) { + kids.push_back(evalConstant(cv->getOperand(i))); + } + ref<Expr> res = ConcatExpr::createN(numOperands, kids.data()); + assert(isa<ConstantExpr>(res) && + "result of constant vector built is not a constant"); + return cast<ConstantExpr>(res); } else { - // Constant{Vector} llvm::report_fatal_error("invalid argument to evalConstant()"); } } @@ -1910,6 +1922,12 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { // Special instructions case Instruction::Select: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert((i->getOperand(1)->getType() == i->getOperand(2)->getType()) && + "true and false operands to Select have mismatching types"); + // NOTE: We are do not check that operands 1 and 2 are not of vector type + // because the scalarizer pass might not remove these. ref<Expr> cond = eval(ki, 0, state).value; ref<Expr> tExpr = eval(ki, 1, state).value; ref<Expr> fExpr = eval(ki, 2, state).value; @@ -1925,6 +1943,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { // Arithmetic / logical case Instruction::Add: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; bindLocal(ki, state, AddExpr::create(left, right)); @@ -1932,6 +1954,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::Sub: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; bindLocal(ki, state, SubExpr::create(left, right)); @@ -1939,6 +1965,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::Mul: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; bindLocal(ki, state, MulExpr::create(left, right)); @@ -1946,6 +1976,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::UDiv: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; ref<Expr> result = UDivExpr::create(left, right); @@ -1954,6 +1988,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::SDiv: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; ref<Expr> result = SDivExpr::create(left, right); @@ -1962,14 +2000,22 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::URem: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; ref<Expr> result = URemExpr::create(left, right); bindLocal(ki, state, result); break; } - + case Instruction::SRem: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; ref<Expr> result = SRemExpr::create(left, right); @@ -1978,6 +2024,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::And: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; ref<Expr> result = AndExpr::create(left, right); @@ -1986,6 +2036,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::Or: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; ref<Expr> result = OrExpr::create(left, right); @@ -1994,6 +2048,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::Xor: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; ref<Expr> result = XorExpr::create(left, right); @@ -2002,6 +2060,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::Shl: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; ref<Expr> result = ShlExpr::create(left, right); @@ -2010,6 +2072,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::LShr: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; ref<Expr> result = LShrExpr::create(left, right); @@ -2018,6 +2084,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::AShr: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isIntegerTy() && + "Invalid operand type"); ref<Expr> left = eval(ki, 0, state).value; ref<Expr> right = eval(ki, 1, state).value; ref<Expr> result = AShrExpr::create(left, right); @@ -2028,9 +2098,15 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { // Compare case Instruction::ICmp: { + assert((i->getOperand(0)->getType()->isIntegerTy() || + i->getOperand(0)->getType()->isPointerTy()) && + "Invalid operand type"); + assert((i->getOperand(1)->getType()->isIntegerTy() || + i->getOperand(1)->getType()->isPointerTy()) && + "Invalid operand type"); CmpInst *ci = cast<CmpInst>(i); ICmpInst *ii = cast<ICmpInst>(ci); - + switch(ii->getPredicate()) { case ICmpInst::ICMP_EQ: { ref<Expr> left = eval(ki, 0, state).value; @@ -2167,6 +2243,8 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { // Conversion case Instruction::Trunc: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); CastInst *ci = cast<CastInst>(i); ref<Expr> result = ExtractExpr::create(eval(ki, 0, state).value, 0, @@ -2175,6 +2253,8 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { break; } case Instruction::ZExt: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); CastInst *ci = cast<CastInst>(i); ref<Expr> result = ZExtExpr::create(eval(ki, 0, state).value, getWidthForLLVMType(ci->getType())); @@ -2182,6 +2262,8 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { break; } case Instruction::SExt: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); CastInst *ci = cast<CastInst>(i); ref<Expr> result = SExtExpr::create(eval(ki, 0, state).value, getWidthForLLVMType(ci->getType())); @@ -2190,13 +2272,17 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::IntToPtr: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); CastInst *ci = cast<CastInst>(i); Expr::Width pType = getWidthForLLVMType(ci->getType()); ref<Expr> arg = eval(ki, 0, state).value; bindLocal(ki, state, ZExtExpr::create(arg, pType)); break; - } + } case Instruction::PtrToInt: { + assert(i->getOperand(0)->getType()->isPointerTy() && + "Invalid operand type"); CastInst *ci = cast<CastInst>(i); Expr::Width iType = getWidthForLLVMType(ci->getType()); ref<Expr> arg = eval(ki, 0, state).value; @@ -2213,6 +2299,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { // Floating point instructions case Instruction::FAdd: { + assert(i->getOperand(0)->getType()->isFloatingPointTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isFloatingPointTy() && + "Invalid operand type"); ref<ConstantExpr> left = toConstant(state, eval(ki, 0, state).value, "floating point"); ref<ConstantExpr> right = toConstant(state, eval(ki, 1, state).value, @@ -2233,6 +2323,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::FSub: { + assert(i->getOperand(0)->getType()->isFloatingPointTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isFloatingPointTy() && + "Invalid operand type"); ref<ConstantExpr> left = toConstant(state, eval(ki, 0, state).value, "floating point"); ref<ConstantExpr> right = toConstant(state, eval(ki, 1, state).value, @@ -2250,8 +2344,12 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { bindLocal(ki, state, ConstantExpr::alloc(Res.bitcastToAPInt())); break; } - + case Instruction::FMul: { + assert(i->getOperand(0)->getType()->isFloatingPointTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isFloatingPointTy() && + "Invalid operand type"); ref<ConstantExpr> left = toConstant(state, eval(ki, 0, state).value, "floating point"); ref<ConstantExpr> right = toConstant(state, eval(ki, 1, state).value, @@ -2272,6 +2370,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::FDiv: { + assert(i->getOperand(0)->getType()->isFloatingPointTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isFloatingPointTy() && + "Invalid operand type"); ref<ConstantExpr> left = toConstant(state, eval(ki, 0, state).value, "floating point"); ref<ConstantExpr> right = toConstant(state, eval(ki, 1, state).value, @@ -2292,6 +2394,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::FRem: { + assert(i->getOperand(0)->getType()->isFloatingPointTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isFloatingPointTy() && + "Invalid operand type"); ref<ConstantExpr> left = toConstant(state, eval(ki, 0, state).value, "floating point"); ref<ConstantExpr> right = toConstant(state, eval(ki, 1, state).value, @@ -2312,6 +2418,8 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::FPTrunc: { + assert(i->getOperand(0)->getType()->isFloatingPointTy() && + "Invalid operand type"); FPTruncInst *fi = cast<FPTruncInst>(i); Expr::Width resultType = getWidthForLLVMType(fi->getType()); ref<ConstantExpr> arg = toConstant(state, eval(ki, 0, state).value, @@ -2333,6 +2441,8 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::FPExt: { + assert(i->getOperand(0)->getType()->isFloatingPointTy() && + "Invalid operand type"); FPExtInst *fi = cast<FPExtInst>(i); Expr::Width resultType = getWidthForLLVMType(fi->getType()); ref<ConstantExpr> arg = toConstant(state, eval(ki, 0, state).value, @@ -2353,6 +2463,8 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::FPToUI: { + assert(i->getOperand(0)->getType()->isFloatingPointTy() && + "Invalid operand type"); FPToUIInst *fi = cast<FPToUIInst>(i); Expr::Width resultType = getWidthForLLVMType(fi->getType()); ref<ConstantExpr> arg = toConstant(state, eval(ki, 0, state).value, @@ -2374,6 +2486,8 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::FPToSI: { + assert(i->getOperand(0)->getType()->isFloatingPointTy() && + "Invalid operand type"); FPToSIInst *fi = cast<FPToSIInst>(i); Expr::Width resultType = getWidthForLLVMType(fi->getType()); ref<ConstantExpr> arg = toConstant(state, eval(ki, 0, state).value, @@ -2395,6 +2509,8 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::UIToFP: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); UIToFPInst *fi = cast<UIToFPInst>(i); Expr::Width resultType = getWidthForLLVMType(fi->getType()); ref<ConstantExpr> arg = toConstant(state, eval(ki, 0, state).value, @@ -2411,6 +2527,8 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::SIToFP: { + assert(i->getOperand(0)->getType()->isIntegerTy() && + "Invalid operand type"); SIToFPInst *fi = cast<SIToFPInst>(i); Expr::Width resultType = getWidthForLLVMType(fi->getType()); ref<ConstantExpr> arg = toConstant(state, eval(ki, 0, state).value, @@ -2427,6 +2545,10 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { } case Instruction::FCmp: { + assert(i->getOperand(0)->getType()->isFloatingPointTy() && + "Invalid operand type"); + assert(i->getOperand(1)->getType()->isFloatingPointTy() && + "Invalid operand type"); FCmpInst *fi = cast<FCmpInst>(i); ref<ConstantExpr> left = toConstant(state, eval(ki, 0, state).value, "floating point"); @@ -2566,16 +2688,88 @@ void Executor::executeInstruction(ExecutionState &state, KInstruction *ki) { break; } #endif + case Instruction::InsertElement: { + InsertElementInst *iei = cast<InsertElementInst>(i); + ref<Expr> vec = eval(ki, 0, state).value; + ref<Expr> newElt = eval(ki, 1, state).value; + ref<Expr> idx = eval(ki, 2, state).value; + + ConstantExpr *cIdx = dyn_cast<ConstantExpr>(idx); + if (cIdx == NULL) { + terminateStateOnError( + state, "InsertElement, support for symbolic index not implemented", + Unhandled); + return; + } + uint64_t iIdx = cIdx->getZExtValue(); + const llvm::VectorType *vt = iei->getType(); + unsigned EltBits = getWidthForLLVMType(vt->getElementType()); - // Other instructions... - // Unhandled - case Instruction::ExtractElement: - case Instruction::InsertElement: + if (iIdx >= vt->getNumElements()) { + // Out of bounds write + terminateStateOnError(state, "Out of bounds write when inserting element", + BadVectorAccess); + return; + } + + const unsigned elementCount = vt->getNumElements(); + llvm::SmallVector<ref<Expr>, 8> elems; + elems.reserve(elementCount); + for (unsigned i = 0; i < elementCount; ++i) { + // evalConstant() will use ConcatExpr to build vectors with the + // zero-th element leftmost (most significant bits), followed + // by the next element (second leftmost) and so on. This means + // that we have to adjust the index so we read left to right + // rather than right to left. + unsigned bitOffset = EltBits * (elementCount - i - 1); + elems.push_back(i == iIdx ? newElt + : ExtractExpr::create(vec, bitOffset, EltBits)); + } + + ref<Expr> Result = ConcatExpr::createN(elementCount, elems.data()); + bindLocal(ki, state, Result); + break; + } + case Instruction::ExtractElement: { + ExtractElementInst *eei = cast<ExtractElementInst>(i); + ref<Expr> vec = eval(ki, 0, state).value; + ref<Expr> idx = eval(ki, 1, state).value; + + ConstantExpr *cIdx = dyn_cast<ConstantExpr>(idx); + if (cIdx == NULL) { + terminateStateOnError( + state, "ExtractElement, support for symbolic index not implemented", + Unhandled); + return; + } + uint64_t iIdx = cIdx->getZExtValue(); + const llvm::VectorType *vt = eei->getVectorOperandType(); + unsigned EltBits = getWidthForLLVMType(vt->getElementType()); + + if (iIdx >= vt->getNumElements()) { + // Out of bounds read + terminateStateOnError(state, "Out of bounds read when extracting element", + BadVectorAccess); + return; + } + + // evalConstant() will use ConcatExpr to build vectors with the + // zero-th element left most (most significant bits), followed + // by the next element (second left most) and so on. This means + // that we have to adjust the index so we read left to right + // rather than right to left. + unsigned bitOffset = EltBits*(vt->getNumElements() - iIdx -1); + ref<Expr> Result = ExtractExpr::create(vec, bitOffset, EltBits); + bindLocal(ki, state, Result); + break; + } case Instruction::ShuffleVector: - terminateStateOnError(state, "XXX vector instructions unhandled", - Unhandled); + // Should never happen due to Scalarizer pass removing ShuffleVector + // instructions. + terminateStateOnExecError(state, "Unexpected ShuffleVector instruction"); break; - + // Other instructions... + // Unhandled default: terminateStateOnExecError(state, "illegal instruction"); break; | 
