aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--instrumentation/split-compares-pass.so.cc1005
-rwxr-xr-xtest/test-llvm.sh23
-rw-r--r--test/test-uint_cases.c30
3 files changed, 485 insertions, 573 deletions
diff --git a/instrumentation/split-compares-pass.so.cc b/instrumentation/split-compares-pass.so.cc
index b02a89fb..6eb9050c 100644
--- a/instrumentation/split-compares-pass.so.cc
+++ b/instrumentation/split-compares-pass.so.cc
@@ -47,50 +47,101 @@
using namespace llvm;
#include "afl-llvm-common.h"
+// uncomment this toggle function verification at each step. horribly slow, but
+// helps to pinpoint a potential problem in the splitting code.
+//#define VERIFY_TOO_MUCH 1
+
namespace {
class SplitComparesTransform : public ModulePass {
-
public:
static char ID;
SplitComparesTransform() : ModulePass(ID), enableFPSplit(0) {
-
initInstrumentList();
-
}
bool runOnModule(Module &M) override;
#if LLVM_VERSION_MAJOR >= 4
StringRef getPassName() const override {
-
#else
const char *getPassName() const override {
#endif
- return "simplifies and splits ICMP instructions";
-
+ return "AFL_SplitComparesTransform";
}
private:
int enableFPSplit;
- size_t splitIntCompares(Module &M, unsigned bitw);
+ unsigned target_bitwidth = 8;
+
+ size_t count = 0;
+
size_t splitFPCompares(Module &M);
- bool simplifyCompares(Module &M);
bool simplifyFPCompares(Module &M);
- bool simplifyIntSignedness(Module &M);
size_t nextPowerOfTwo(size_t in);
+ using CmpWorklist = SmallVector<CmpInst *, 8>;
+
+ /// simplify the comparison and then split the comparison until the
+ /// target_bitwidth is reached.
+ bool simplifyAndSplit(CmpInst *I, Module &M);
+ /// simplify a non-strict comparison (e.g., less than or equals)
+ bool simplifyOrEqualsCompare(CmpInst *IcmpInst, Module &M,
+ CmpWorklist &worklist);
+ /// simplify a signed comparison (signed less or greater than)
+ bool simplifySignedCompare(CmpInst *IcmpInst, Module &M,
+ CmpWorklist &worklist);
+ /// splits an icmp into nested icmps recursivly until target_bitwidth is
+ /// reached
+ bool splitCompare(CmpInst *I, Module &M, CmpWorklist &worklist);
+
+ /// print an error to llvm's errs stream, but only if not ordered to be quiet
+ void reportError(const StringRef msg, Instruction *I, Module &M) {
+ if (!be_quiet) {
+ errs() << "[AFL++ SplitComparesTransform] ERROR: " << msg << "\n";
+ if (debug) {
+ if (I) {
+ errs() << "Instruction = " << *I << "\n";
+ if (auto BB = I->getParent()) {
+ if (auto F = BB->getParent()) {
+ if (F->hasName()) {
+ errs() << "|-> in function " << F->getName() << " ";
+ }
+ }
+ }
+ }
+ auto n = M.getName();
+ if (n.size() > 0) { errs() << "in module " << n << "\n"; }
+ }
+ }
+ }
+
+ bool isSupportedBitWidth(unsigned bitw) {
+ // IDK whether the icmp code works on other bitwidths. I guess not? So we
+ // try to avoid dealing with other weird icmp's that llvm might use (looking
+ // at you `icmp i0`).
+ switch (bitw) {
+ case 8:
+ case 16:
+ case 32:
+ case 64:
+ case 128:
+ case 256:
+ return true;
+ default:
+ return false;
+ }
+ }
};
} // namespace
char SplitComparesTransform::ID = 0;
-/* This function splits FCMP instructions with xGE or xLE predicates into two
- * FCMP instructions with predicate xGT or xLT and EQ */
+/// This function splits FCMP instructions with xGE or xLE predicates into two
+/// FCMP instructions with predicate xGT or xLT and EQ
bool SplitComparesTransform::simplifyFPCompares(Module &M) {
-
LLVMContext & C = M.getContext();
std::vector<Instruction *> fcomps;
IntegerType * Int1Ty = IntegerType::getInt1Ty(C);
@@ -98,23 +149,18 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) {
/* iterate over all functions, bbs and instruction and add
* all integer comparisons with >= and <= predicates to the icomps vector */
for (auto &F : M) {
-
if (!isInInstrumentList(&F)) continue;
for (auto &BB : F) {
-
for (auto &IN : BB) {
-
CmpInst *selectcmpInst = nullptr;
if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
-
if (enableFPSplit &&
(selectcmpInst->getPredicate() == CmpInst::FCMP_OGE ||
selectcmpInst->getPredicate() == CmpInst::FCMP_UGE ||
selectcmpInst->getPredicate() == CmpInst::FCMP_OLE ||
selectcmpInst->getPredicate() == CmpInst::FCMP_ULE)) {
-
auto op0 = selectcmpInst->getOperand(0);
auto op1 = selectcmpInst->getOperand(1);
@@ -127,22 +173,16 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) {
if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; }
fcomps.push_back(selectcmpInst);
-
}
-
}
-
}
-
}
-
}
if (!fcomps.size()) { return false; }
/* transform for floating point */
for (auto &FcmpInst : fcomps) {
-
BasicBlock *bb = FcmpInst->getParent();
auto op0 = FcmpInst->getOperand(0);
@@ -155,7 +195,6 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) {
CmpInst::Predicate new_pred;
switch (pred) {
-
case CmpInst::FCMP_UGE:
new_pred = CmpInst::FCMP_UGT;
break;
@@ -170,7 +209,6 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) {
break;
default: // keep the compiler happy
continue;
-
}
/* split before the fcmp instruction */
@@ -214,305 +252,425 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) {
/* replace the old FcmpInst with our new and shiny PHI inst */
BasicBlock::iterator ii(FcmpInst);
ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
-
}
return true;
-
}
-/* This function splits ICMP instructions with xGE or xLE predicates into two
- * ICMP instructions with predicate xGT or xLT and EQ */
-bool SplitComparesTransform::simplifyCompares(Module &M) {
-
- LLVMContext & C = M.getContext();
- std::vector<Instruction *> icomps;
- IntegerType * Int1Ty = IntegerType::getInt1Ty(C);
-
- /* iterate over all functions, bbs and instruction and add
- * all integer comparisons with >= and <= predicates to the icomps vector */
- for (auto &F : M) {
-
- if (!isInInstrumentList(&F)) continue;
+/// This function splits ICMP instructions with xGE or xLE predicates into two
+/// ICMP instructions with predicate xGT or xLT and EQ
+bool SplitComparesTransform::simplifyOrEqualsCompare(CmpInst * IcmpInst,
+ Module & M,
+ CmpWorklist &worklist) {
+ LLVMContext &C = M.getContext();
+ IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
- for (auto &BB : F) {
+ /* find out what the new predicate is going to be */
+ auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
+ if (!cmp_inst) { return false; }
- for (auto &IN : BB) {
+ BasicBlock *bb = IcmpInst->getParent();
- CmpInst *selectcmpInst = nullptr;
+ auto op0 = IcmpInst->getOperand(0);
+ auto op1 = IcmpInst->getOperand(1);
- if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
+ CmpInst::Predicate pred = cmp_inst->getPredicate();
+ CmpInst::Predicate new_pred;
- if (selectcmpInst->getPredicate() == CmpInst::ICMP_UGE ||
- selectcmpInst->getPredicate() == CmpInst::ICMP_SGE ||
- selectcmpInst->getPredicate() == CmpInst::ICMP_ULE ||
- selectcmpInst->getPredicate() == CmpInst::ICMP_SLE) {
+ switch (pred) {
+ case CmpInst::ICMP_UGE:
+ new_pred = CmpInst::ICMP_UGT;
+ break;
+ case CmpInst::ICMP_SGE:
+ new_pred = CmpInst::ICMP_SGT;
+ break;
+ case CmpInst::ICMP_ULE:
+ new_pred = CmpInst::ICMP_ULT;
+ break;
+ case CmpInst::ICMP_SLE:
+ new_pred = CmpInst::ICMP_SLT;
+ break;
+ default: // keep the compiler happy
+ return false;
+ }
- auto op0 = selectcmpInst->getOperand(0);
- auto op1 = selectcmpInst->getOperand(1);
+ /* split before the icmp instruction */
+ BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
+
+ /* the old bb now contains a unconditional jump to the new one (end_bb)
+ * we need to delete it later */
+
+ /* create the ICMP instruction with new_pred and add it to the old basic
+ * block bb it is now at the position where the old IcmpInst was */
+ CmpInst *icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
+ bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), icmp_np);
+
+ /* create a new basic block which holds the new EQ icmp */
+ CmpInst *icmp_eq;
+ /* insert middle_bb before end_bb */
+ BasicBlock *middle_bb =
+ BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
+ icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1);
+ middle_bb->getInstList().push_back(icmp_eq);
+ /* add an unconditional branch to the end of middle_bb with destination
+ * end_bb */
+ BranchInst::Create(end_bb, middle_bb);
+
+ /* replace the uncond branch with a conditional one, which depends on the
+ * new_pred icmp. True goes to end, false to the middle (injected) bb */
+ auto term = bb->getTerminator();
+ BranchInst::Create(end_bb, middle_bb, icmp_np, bb);
+ term->eraseFromParent();
+
+ /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI
+ * inst to wire up the loose ends */
+ PHINode *PN = PHINode::Create(Int1Ty, 2, "");
+ /* the first result depends on the outcome of icmp_eq */
+ PN->addIncoming(icmp_eq, middle_bb);
+ /* if the source was the original bb we know that the icmp_np yielded true
+ * hence we can hardcode this value */
+ PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
+ /* replace the old IcmpInst with our new and shiny PHI inst */
+ BasicBlock::iterator ii(IcmpInst);
+ ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
+
+ worklist.push_back(icmp_np);
+ worklist.push_back(icmp_eq);
- IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
- IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+ return true;
+}
- /* this is probably not needed but we do it anyway */
- if (!intTyOp0 || !intTyOp1) { continue; }
+/// Simplify a signed comparison operator by splitting it into a unsigned and
+/// bit comparison. add all resulting comparisons to
+/// the worklist passed as a reference.
+bool SplitComparesTransform::simplifySignedCompare(CmpInst *IcmpInst, Module &M,
+ CmpWorklist &worklist) {
+ LLVMContext &C = M.getContext();
+ IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
- icomps.push_back(selectcmpInst);
+ BasicBlock *bb = IcmpInst->getParent();
- }
+ auto op0 = IcmpInst->getOperand(0);
+ auto op1 = IcmpInst->getOperand(1);
- }
+ IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+ if (!intTyOp0) { return false; }
+ unsigned bitw = intTyOp0->getBitWidth();
+ IntegerType *IntType = IntegerType::get(C, bitw);
- }
+ /* get the new predicate */
+ auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
+ if (!cmp_inst) { return false; }
+ auto pred = cmp_inst->getPredicate();
+ CmpInst::Predicate new_pred;
- }
+ if (pred == CmpInst::ICMP_SGT) {
+ new_pred = CmpInst::ICMP_UGT;
+ } else {
+ new_pred = CmpInst::ICMP_ULT;
}
- if (!icomps.size()) { return false; }
-
- for (auto &IcmpInst : icomps) {
-
- BasicBlock *bb = IcmpInst->getParent();
-
- auto op0 = IcmpInst->getOperand(0);
- auto op1 = IcmpInst->getOperand(1);
-
- /* find out what the new predicate is going to be */
- auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
- if (!cmp_inst) { continue; }
- auto pred = cmp_inst->getPredicate();
- CmpInst::Predicate new_pred;
+ BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
+
+ /* create a 1 bit compare for the sign bit. to do this shift and trunc
+ * the original operands so only the first bit remains.*/
+ Value *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit;
+
+ IRBuilder<> IRB(bb->getTerminator());
+ s_op0 = IRB.CreateLShr(op0, ConstantInt::get(IntType, bitw - 1));
+ t_op0 = IRB.CreateTruncOrBitCast(s_op0, Int1Ty);
+ s_op1 = IRB.CreateLShr(op1, ConstantInt::get(IntType, bitw - 1));
+ t_op1 = IRB.CreateTruncOrBitCast(s_op1, Int1Ty);
+ /* compare of the sign bits */
+ icmp_sign_bit = IRB.CreateCmp(CmpInst::ICMP_EQ, t_op0, t_op1);
+
+ /* create a new basic block which is executed if the signedness bit is
+ * different */
+ CmpInst * icmp_inv_sig_cmp;
+ BasicBlock *sign_bb =
+ BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb);
+ if (pred == CmpInst::ICMP_SGT) {
+ /* if we check for > and the op0 positive and op1 negative then the final
+ * result is true. if op0 negative and op1 pos, the cmp must result
+ * in false
+ */
+ icmp_inv_sig_cmp =
+ CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1);
- switch (pred) {
+ } else {
+ /* just the inverse of the above statement */
+ icmp_inv_sig_cmp =
+ CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1);
+ }
- case CmpInst::ICMP_UGE:
- new_pred = CmpInst::ICMP_UGT;
- break;
- case CmpInst::ICMP_SGE:
- new_pred = CmpInst::ICMP_SGT;
- break;
- case CmpInst::ICMP_ULE:
- new_pred = CmpInst::ICMP_ULT;
- break;
- case CmpInst::ICMP_SLE:
- new_pred = CmpInst::ICMP_SLT;
- break;
- default: // keep the compiler happy
- continue;
+ sign_bb->getInstList().push_back(icmp_inv_sig_cmp);
+ BranchInst::Create(end_bb, sign_bb);
- }
+ /* create a new bb which is executed if signedness is equal */
+ CmpInst * icmp_usign_cmp;
+ BasicBlock *middle_bb =
+ BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
+ /* we can do a normal unsigned compare now */
+ icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
- /* split before the icmp instruction */
- BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
+ middle_bb->getInstList().push_back(icmp_usign_cmp);
+ BranchInst::Create(end_bb, middle_bb);
- /* the old bb now contains a unconditional jump to the new one (end_bb)
- * we need to delete it later */
+ auto term = bb->getTerminator();
+ /* if the sign is eq do a normal unsigned cmp, else we have to check the
+ * signedness bit */
+ BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb);
+ term->eraseFromParent();
- /* create the ICMP instruction with new_pred and add it to the old basic
- * block bb it is now at the position where the old IcmpInst was */
- Instruction *icmp_np;
- icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
- icmp_np);
+ PHINode *PN = PHINode::Create(Int1Ty, 2, "");
- /* create a new basic block which holds the new EQ icmp */
- Instruction *icmp_eq;
- /* insert middle_bb before end_bb */
- BasicBlock *middle_bb =
- BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
- icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1);
- middle_bb->getInstList().push_back(icmp_eq);
- /* add an unconditional branch to the end of middle_bb with destination
- * end_bb */
- BranchInst::Create(end_bb, middle_bb);
+ PN->addIncoming(icmp_usign_cmp, middle_bb);
+ PN->addIncoming(icmp_inv_sig_cmp, sign_bb);
- /* replace the uncond branch with a conditional one, which depends on the
- * new_pred icmp. True goes to end, false to the middle (injected) bb */
- auto term = bb->getTerminator();
- BranchInst::Create(end_bb, middle_bb, icmp_np, bb);
- term->eraseFromParent();
+ BasicBlock::iterator ii(IcmpInst);
+ ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
- /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI
- * inst to wire up the loose ends */
- PHINode *PN = PHINode::Create(Int1Ty, 2, "");
- /* the first result depends on the outcome of icmp_eq */
- PN->addIncoming(icmp_eq, middle_bb);
- /* if the source was the original bb we know that the icmp_np yielded true
- * hence we can hardcode this value */
- PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
- /* replace the old IcmpInst with our new and shiny PHI inst */
- BasicBlock::iterator ii(IcmpInst);
- ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
+ // save for later
+ worklist.push_back(icmp_usign_cmp);
- }
+ // signed comparisons are not supported by the splitting code, so we must not
+ // add it to the worklist.
+ // worklist.push_back(icmp_inv_sig_cmp);
return true;
-
}
-/* this function transforms signed compares to equivalent unsigned compares */
-bool SplitComparesTransform::simplifyIntSignedness(Module &M) {
-
- LLVMContext & C = M.getContext();
- std::vector<Instruction *> icomps;
- IntegerType * Int1Ty = IntegerType::getInt1Ty(C);
-
- /* iterate over all functions, bbs and instructions and add
- * all signed compares to icomps vector */
- for (auto &F : M) {
-
- if (!isInInstrumentList(&F)) continue;
+bool SplitComparesTransform::splitCompare(CmpInst *cmp_inst, Module &M,
+ CmpWorklist &worklist) {
+ auto pred = cmp_inst->getPredicate();
+ switch (pred) {
+ case CmpInst::ICMP_EQ:
+ case CmpInst::ICMP_NE:
+ case CmpInst::ICMP_UGT:
+ case CmpInst::ICMP_ULT:
+ break;
+ default:
+ // unsupported predicate!
+ return false;
+ }
- for (auto &BB : F) {
+ auto op0 = cmp_inst->getOperand(0);
+ auto op1 = cmp_inst->getOperand(1);
- for (auto &IN : BB) {
+ // get bitwidth by checking the bitwidth of the first operator
+ IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+ if (!intTyOp0) {
+ // not an integer type
+ return false;
+ }
- CmpInst *selectcmpInst = nullptr;
+ unsigned bitw = intTyOp0->getBitWidth();
+ if (bitw == target_bitwidth) {
+ // already the target bitwidth so we have to do nothing here.
+ return true;
+ }
- if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
+ LLVMContext &C = M.getContext();
+ IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
+ BasicBlock * bb = cmp_inst->getParent();
+ IntegerType *OldIntType = IntegerType::get(C, bitw);
+ IntegerType *NewIntType = IntegerType::get(C, bitw / 2);
+ BasicBlock * end_bb = bb->splitBasicBlock(BasicBlock::iterator(cmp_inst));
+ CmpInst * icmp_high, *icmp_low;
- if (selectcmpInst->getPredicate() == CmpInst::ICMP_SGT ||
- selectcmpInst->getPredicate() == CmpInst::ICMP_SLT) {
+ /* create the comparison of the top halves of the original operands */
+ Value *s_op0, *op0_high, *s_op1, *op1_high;
- auto op0 = selectcmpInst->getOperand(0);
- auto op1 = selectcmpInst->getOperand(1);
+ IRBuilder<> IRB(bb->getTerminator());
- IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
- IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+ s_op0 = IRB.CreateBinOp(Instruction::LShr, op0,
+ ConstantInt::get(OldIntType, bitw / 2));
+ op0_high = IRB.CreateTruncOrBitCast(s_op0, NewIntType);
- /* see above */
- if (!intTyOp0 || !intTyOp1) { continue; }
+ s_op1 = IRB.CreateBinOp(Instruction::LShr, op1,
+ ConstantInt::get(OldIntType, bitw / 2));
+ op1_high = IRB.CreateTruncOrBitCast(s_op1, NewIntType);
+ icmp_high = cast<CmpInst>(IRB.CreateICmp(pred, op0_high, op1_high));
- /* i think this is not possible but to lazy to look it up */
- if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) {
+ PHINode *PN = nullptr;
- continue;
+ /* now we have to destinguish between == != and > < */
+ switch (pred) {
+ case CmpInst::ICMP_EQ:
+ case CmpInst::ICMP_NE: {
+ /* transformation for == and != icmps */
- }
+ /* create a compare for the lower half of the original operands */
+ BasicBlock *cmp_low_bb =
+ BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb);
- icomps.push_back(selectcmpInst);
+ Value *op0_low, *op1_low;
+ IRBuilder<> Builder(cmp_low_bb);
- }
+ op0_low = Builder.CreateTrunc(op0, NewIntType);
+ op1_low = Builder.CreateTrunc(op1, NewIntType);
+ icmp_low = cast<CmpInst>(Builder.CreateICmp(pred, op0_low, op1_low));
- }
+ BranchInst::Create(end_bb, cmp_low_bb);
+ /* dependent on the cmp of the high parts go to the end or go on with
+ * the comparison */
+ auto term = bb->getTerminator();
+ BranchInst *br = nullptr;
+ if (pred == CmpInst::ICMP_EQ) {
+ br = BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb);
+ } else {
+ /* CmpInst::ICMP_NE */
+ br = BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb);
}
+ term->eraseFromParent();
+ /* create the PHI and connect the edges accordingly */
+ PN = PHINode::Create(Int1Ty, 2, "");
+ PN->addIncoming(icmp_low, cmp_low_bb);
+ Value *val = nullptr;
+ if (pred == CmpInst::ICMP_EQ) {
+ val = ConstantInt::get(Int1Ty, 0);
+ } else {
+ /* CmpInst::ICMP_NE */
+ val = ConstantInt::get(Int1Ty, 1);
+ }
+ PN->addIncoming(val, icmp_high->getParent());
+ break;
}
+ case CmpInst::ICMP_UGT:
+ case CmpInst::ICMP_ULT: {
+ /* transformations for < and > */
- }
-
- if (!icomps.size()) { return false; }
-
- for (auto &IcmpInst : icomps) {
-
- BasicBlock *bb = IcmpInst->getParent();
-
- auto op0 = IcmpInst->getOperand(0);
- auto op1 = IcmpInst->getOperand(1);
+ /* create a basic block which checks for the inverse predicate.
+ * if this is true we can go to the end if not we have to go to the
+ * bb which checks the lower half of the operands */
+ Instruction *op0_low, *op1_low;
+ CmpInst *icmp_inv_cmp = nullptr;
+ BasicBlock * inv_cmp_bb =
+ BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb);
+ if (pred == CmpInst::ICMP_UGT) {
+ icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
+ op0_high, op1_high);
- IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
- if (!intTyOp0) { continue; }
- unsigned bitw = intTyOp0->getBitWidth();
- IntegerType *IntType = IntegerType::get(C, bitw);
+ } else {
+ icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT,
+ op0_high, op1_high);
+ }
- /* get the new predicate */
- auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
- if (!cmp_inst) { continue; }
- auto pred = cmp_inst->getPredicate();
- CmpInst::Predicate new_pred;
+ inv_cmp_bb->getInstList().push_back(icmp_inv_cmp);
+ worklist.push_back(icmp_inv_cmp);
- if (pred == CmpInst::ICMP_SGT) {
+ auto term = bb->getTerminator();
+ term->eraseFromParent();
+ BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb);
- new_pred = CmpInst::ICMP_UGT;
+ /* create a bb which handles the cmp of the lower halves */
+ BasicBlock *cmp_low_bb =
+ BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb);
+ op0_low = new TruncInst(op0, NewIntType);
+ cmp_low_bb->getInstList().push_back(op0_low);
+ op1_low = new TruncInst(op1, NewIntType);
+ cmp_low_bb->getInstList().push_back(op1_low);
- } else {
+ icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
+ cmp_low_bb->getInstList().push_back(icmp_low);
+ BranchInst::Create(end_bb, cmp_low_bb);
- new_pred = CmpInst::ICMP_ULT;
+ BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb);
+ PN = PHINode::Create(Int1Ty, 3);
+ PN->addIncoming(icmp_low, cmp_low_bb);
+ PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
+ PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb);
+ break;
}
+ default:
+ return false;
+ }
- BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
-
- /* create a 1 bit compare for the sign bit. to do this shift and trunc
- * the original operands so only the first bit remains.*/
- Instruction *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit;
-
- s_op0 = BinaryOperator::Create(Instruction::LShr, op0,
- ConstantInt::get(IntType, bitw - 1));
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0);
- t_op0 = new TruncInst(s_op0, Int1Ty);
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op0);
+ BasicBlock::iterator ii(cmp_inst);
+ ReplaceInstWithInst(cmp_inst->getParent()->getInstList(), ii, PN);
- s_op1 = BinaryOperator::Create(Instruction::LShr, op1,
- ConstantInt::get(IntType, bitw - 1));
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1);
- t_op1 = new TruncInst(s_op1, Int1Ty);
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op1);
+ // We split the comparison into low and high. If this isn't our target
+ // bitwidth we recursivly split the low and high parts again until we have
+ // target bitwidth.
+ if ((bitw / 2) > target_bitwidth) {
+ worklist.push_back(icmp_high);
+ worklist.push_back(icmp_low);
+ }
- /* compare of the sign bits */
- icmp_sign_bit =
- CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1);
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
- icmp_sign_bit);
+ return true;
+}
- /* create a new basic block which is executed if the signedness bit is
- * different */
- Instruction *icmp_inv_sig_cmp;
- BasicBlock * sign_bb =
- BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb);
- if (pred == CmpInst::ICMP_SGT) {
+bool SplitComparesTransform::simplifyAndSplit(CmpInst *I, Module &M) {
+ CmpWorklist worklist;
- /* if we check for > and the op0 positive and op1 negative then the final
- * result is true. if op0 negative and op1 pos, the cmp must result
- * in false
- */
- icmp_inv_sig_cmp =
- CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1);
+ auto op0 = I->getOperand(0);
+ auto op1 = I->getOperand(1);
+ if (!op0 || !op1) { return false; }
+ auto op0Ty = dyn_cast<IntegerType>(op0->getType());
+ if (!op0Ty || !isa<IntegerType>(op1->getType())) { return true; }
- } else {
+ unsigned bitw = op0Ty->getBitWidth();
- /* just the inverse of the above statement */
- icmp_inv_sig_cmp =
- CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1);
+#ifdef VERIFY_TOO_MUCH
+ auto F = I->getParent()->getParent();
+#endif
+ // we run the comparison simplification on all compares regardless of their
+ // bitwidth.
+ if (I->getPredicate() == CmpInst::ICMP_UGE ||
+ I->getPredicate() == CmpInst::ICMP_SGE ||
+ I->getPredicate() == CmpInst::ICMP_ULE ||
+ I->getPredicate() == CmpInst::ICMP_SLE) {
+ if (!simplifyOrEqualsCompare(I, M, worklist)) {
+ reportError(
+ "Failed to simplify inequality or equals comparison "
+ "(UGE,SGE,ULE,SLE)",
+ I, M);
}
+ } else if (I->getPredicate() == CmpInst::ICMP_SGT ||
+ I->getPredicate() == CmpInst::ICMP_SLT) {
+ if (!simplifySignedCompare(I, M, worklist)) {
+ reportError("Failed to simplify signed comparison (SGT,SLT)", I, M);
+ }
+ }
- sign_bb->getInstList().push_back(icmp_inv_sig_cmp);
- BranchInst::Create(end_bb, sign_bb);
-
- /* create a new bb which is executed if signedness is equal */
- Instruction *icmp_usign_cmp;
- BasicBlock * middle_bb =
- BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
- /* we can do a normal unsigned compare now */
- icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1);
- middle_bb->getInstList().push_back(icmp_usign_cmp);
- BranchInst::Create(end_bb, middle_bb);
-
- auto term = bb->getTerminator();
- /* if the sign is eq do a normal unsigned cmp, else we have to check the
- * signedness bit */
- BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb);
- term->eraseFromParent();
-
- PHINode *PN = PHINode::Create(Int1Ty, 2, "");
-
- PN->addIncoming(icmp_usign_cmp, middle_bb);
- PN->addIncoming(icmp_inv_sig_cmp, sign_bb);
+#ifdef VERIFY_TOO_MUCH
+ if (verifyFunction(*F, &errs())) {
+ reportError("simpliyfing compare lead to broken function", nullptr, M);
+ }
+#endif
- BasicBlock::iterator ii(IcmpInst);
- ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
+ // the simplification methods replace the original CmpInst and push the
+ // resulting new CmpInst into the worklist. If the worklist is empty then
+ // we only have to split the original CmpInst.
+ if (worklist.size() == 0) { worklist.push_back(I); }
+
+ while (!worklist.empty()) {
+ CmpInst *cmp = worklist.pop_back_val();
+ // we split the simplified compares into comparisons with smaller bitwidths
+ // if they are larger than our target_bitwidth.
+ if (bitw > target_bitwidth) {
+ if (!splitCompare(cmp, M, worklist)) {
+ reportError("Failed to split comparison", cmp, M);
+ }
+#ifdef VERIFY_TOO_MUCH
+ if (verifyFunction(*F, &errs())) {
+ reportError("splitting compare lead to broken function", nullptr, M);
+ }
+#endif
+ }
}
+ count++;
return true;
-
}
size_t SplitComparesTransform::nextPowerOfTwo(size_t in) {
-
--in;
in |= in >> 1;
in |= in >> 2;
@@ -520,12 +678,10 @@ size_t SplitComparesTransform::nextPowerOfTwo(size_t in) {
// in |= in >> 8;
// in |= in >> 16;
return in + 1;
-
}
/* splits fcmps into two nested fcmps with sign compare and the rest */
size_t SplitComparesTransform::splitFPCompares(Module &M) {
-
size_t count = 0;
LLVMContext &C = M.getContext();
@@ -537,13 +693,9 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
/* define unions with floating point and (sign, exponent, mantissa) triples
*/
if (dl.isLittleEndian()) {
-
} else if (dl.isBigEndian()) {
-
} else {
-
return count;
-
}
#endif
@@ -553,17 +705,13 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
/* get all EQ, NE, GT, and LT fcmps. if the other two
* functions were executed only these four predicates should exist */
for (auto &F : M) {
-
if (!isInInstrumentList(&F)) continue;
for (auto &BB : F) {
-
for (auto &IN : BB) {
-
CmpInst *selectcmpInst = nullptr;
if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
-
if (selectcmpInst->getPredicate() == CmpInst::FCMP_OEQ ||
selectcmpInst->getPredicate() == CmpInst::FCMP_UEQ ||
selectcmpInst->getPredicate() == CmpInst::FCMP_ONE ||
@@ -572,7 +720,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
selectcmpInst->getPredicate() == CmpInst::FCMP_OGT ||
selectcmpInst->getPredicate() == CmpInst::FCMP_ULT ||
selectcmpInst->getPredicate() == CmpInst::FCMP_OLT) {
-
auto op0 = selectcmpInst->getOperand(0);
auto op1 = selectcmpInst->getOperand(1);
@@ -584,15 +731,10 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; }
fcomps.push_back(selectcmpInst);
-
}
-
}
-
}
-
}
-
}
if (!fcomps.size()) { return count; }
@@ -600,7 +742,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
for (auto &FcmpInst : fcomps) {
-
BasicBlock *bb = FcmpInst->getParent();
auto op0 = FcmpInst->getOperand(0);
@@ -725,7 +866,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
BasicBlock::iterator(signequal_bb->getTerminator()), t_e1);
if (sizeInBits - precision < exTySizeBytes * 8) {
-
m_e0 = BinaryOperator::Create(
Instruction::And, t_e0,
ConstantInt::get(t_e0->getType(), mask_exponent));
@@ -738,10 +878,8 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
BasicBlock::iterator(signequal_bb->getTerminator()), m_e1);
} else {
-
m_e0 = t_e0;
m_e1 = t_e1;
-
}
/* compare the exponents of the operands */
@@ -749,7 +887,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
Instruction *icmp_exponent_result;
BasicBlock * signequal2_bb = signequal_bb;
switch (FcmpInst->getPredicate()) {
-
case CmpInst::FCMP_UEQ:
case CmpInst::FCMP_OEQ:
icmp_exponent_result =
@@ -819,7 +956,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
break;
default:
continue;
-
}
signequal2_bb->getInstList().insert(
@@ -827,11 +963,9 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
icmp_exponent_result);
{
-
term = signequal2_bb->getTerminator();
switch (FcmpInst->getPredicate()) {
-
case CmpInst::FCMP_UEQ:
case CmpInst::FCMP_OEQ:
/* if the exponents are satifying the compare do a fraction cmp in
@@ -854,11 +988,9 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
break;
default:
continue;
-
}
term->eraseFromParent();
-
}
/* isolate the mantissa aka fraction */
@@ -866,7 +998,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
bool needTrunc = IntFractionTy->getPrimitiveSizeInBits() < op_size;
if (precision - 1 < frTySizeBytes * 8) {
-
Instruction *m_f0, *m_f1;
m_f0 = BinaryOperator::Create(
Instruction::And, b_op0,
@@ -880,7 +1011,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
BasicBlock::iterator(middle_bb->getTerminator()), m_f1);
if (needTrunc) {
-
t_f0 = new TruncInst(m_f0, IntFractionTy);
t_f1 = new TruncInst(m_f1, IntFractionTy);
middle_bb->getInstList().insert(
@@ -889,16 +1019,12 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
BasicBlock::iterator(middle_bb->getTerminator()), t_f1);
} else {
-
t_f0 = m_f0;
t_f1 = m_f1;
-
}
} else {
-
if (needTrunc) {
-
t_f0 = new TruncInst(b_op0, IntFractionTy);
t_f1 = new TruncInst(b_op1, IntFractionTy);
middle_bb->getInstList().insert(
@@ -907,12 +1033,9 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
BasicBlock::iterator(middle_bb->getTerminator()), t_f1);
} else {
-
t_f0 = b_op0;
t_f1 = b_op1;
-
}
-
}
/* compare the fractions of the operands */
@@ -920,7 +1043,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
BasicBlock * middle2_bb = middle_bb;
PHINode * PN2 = nullptr;
switch (FcmpInst->getPredicate()) {
-
case CmpInst::FCMP_UEQ:
case CmpInst::FCMP_OEQ:
icmp_fraction_result =
@@ -943,7 +1065,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
case CmpInst::FCMP_UGT:
case CmpInst::FCMP_OLT:
case CmpInst::FCMP_ULT: {
-
Instruction *icmp_fraction_result2;
middle2_bb = middle_bb->splitBasicBlock(
@@ -956,7 +1077,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
if (FcmpInst->getPredicate() == CmpInst::FCMP_OGT ||
FcmpInst->getPredicate() == CmpInst::FCMP_UGT) {
-
negative_bb->getInstList().push_back(
icmp_fraction_result = CmpInst::Create(
Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1));
@@ -965,14 +1085,12 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1));
} else {
-
negative_bb->getInstList().push_back(
icmp_fraction_result = CmpInst::Create(
Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1));
positive_bb->getInstList().push_back(
icmp_fraction_result2 = CmpInst::Create(
Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1));
-
}
BranchInst::Create(middle2_bb, negative_bb);
@@ -992,13 +1110,11 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
default:
continue;
-
}
PHINode *PN = PHINode::Create(Int1Ty, 3, "");
switch (FcmpInst->getPredicate()) {
-
case CmpInst::FCMP_UEQ:
case CmpInst::FCMP_OEQ:
/* unequal signs cannot be equal values */
@@ -1037,262 +1153,36 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
break;
default:
continue;
-
}
BasicBlock::iterator ii(FcmpInst);
ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
++count;
-
}
return count;
-
-}
-
-/* splits icmps of size bitw into two nested icmps with bitw/2 size each */
-size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) {
-
- size_t count = 0;
-
- LLVMContext &C = M.getContext();
-
- IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
- IntegerType *OldIntType = IntegerType::get(C, bitw);
- IntegerType *NewIntType = IntegerType::get(C, bitw / 2);
-
- std::vector<Instruction *> icomps;
-
- if (bitw % 2) { return 0; }
-
- /* not supported yet */
- if (bitw > 64) { return 0; }
-
- /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the
- * functions simplifyCompares() and simplifyIntSignedness()
- * were executed only these four predicates should exist */
- for (auto &F : M) {
-
- if (!isInInstrumentList(&F)) continue;
-
- for (auto &BB : F) {
-
- for (auto &IN : BB) {
-
- CmpInst *selectcmpInst = nullptr;
-
- if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
-
- if (selectcmpInst->getPredicate() == CmpInst::ICMP_EQ ||
- selectcmpInst->getPredicate() == CmpInst::ICMP_NE ||
- selectcmpInst->getPredicate() == CmpInst::ICMP_UGT ||
- selectcmpInst->getPredicate() == CmpInst::ICMP_ULT) {
-
- auto op0 = selectcmpInst->getOperand(0);
- auto op1 = selectcmpInst->getOperand(1);
-
- IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
- IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
-
- if (!intTyOp0 || !intTyOp1) { continue; }
-
- /* check if the bitwidths are the one we are looking for */
- if (intTyOp0->getBitWidth() != bitw ||
- intTyOp1->getBitWidth() != bitw) {
-
- continue;
-
- }
-
- icomps.push_back(selectcmpInst);
-
- }
-
- }
-
- }
-
- }
-
- }
-
- if (!icomps.size()) { return 0; }
-
- for (auto &IcmpInst : icomps) {
-
- BasicBlock *bb = IcmpInst->getParent();
-
- auto op0 = IcmpInst->getOperand(0);
- auto op1 = IcmpInst->getOperand(1);
-
- auto cmp_inst = dyn_cast<CmpInst>(IcmpInst);
- if (!cmp_inst) { continue; }
- auto pred = cmp_inst->getPredicate();
-
- BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst));
-
- /* create the comparison of the top halves of the original operands */
- Instruction *s_op0, *op0_high, *s_op1, *op1_high, *icmp_high;
-
- s_op0 = BinaryOperator::Create(Instruction::LShr, op0,
- ConstantInt::get(OldIntType, bitw / 2));
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0);
- op0_high = new TruncInst(s_op0, NewIntType);
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
- op0_high);
-
- s_op1 = BinaryOperator::Create(Instruction::LShr, op1,
- ConstantInt::get(OldIntType, bitw / 2));
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1);
- op1_high = new TruncInst(s_op1, NewIntType);
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
- op1_high);
-
- icmp_high = CmpInst::Create(Instruction::ICmp, pred, op0_high, op1_high);
- bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()),
- icmp_high);
-
- /* now we have to destinguish between == != and > < */
- if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) {
-
- /* transformation for == and != icmps */
-
- /* create a compare for the lower half of the original operands */
- Instruction *op0_low, *op1_low, *icmp_low;
- BasicBlock * cmp_low_bb =
- BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
-
- op0_low = new TruncInst(op0, NewIntType);
- cmp_low_bb->getInstList().push_back(op0_low);
-
- op1_low = new TruncInst(op1, NewIntType);
- cmp_low_bb->getInstList().push_back(op1_low);
-
- icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
- cmp_low_bb->getInstList().push_back(icmp_low);
- BranchInst::Create(end_bb, cmp_low_bb);
-
- /* dependent on the cmp of the high parts go to the end or go on with
- * the comparison */
- auto term = bb->getTerminator();
- if (pred == CmpInst::ICMP_EQ) {
-
- BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb);
-
- } else {
-
- /* CmpInst::ICMP_NE */
- BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb);
-
- }
-
- term->eraseFromParent();
-
- /* create the PHI and connect the edges accordingly */
- PHINode *PN = PHINode::Create(Int1Ty, 2, "");
- PN->addIncoming(icmp_low, cmp_low_bb);
- if (pred == CmpInst::ICMP_EQ) {
-
- PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb);
-
- } else {
-
- /* CmpInst::ICMP_NE */
- PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
-
- }
-
- /* replace the old icmp with the new PHI */
- BasicBlock::iterator ii(IcmpInst);
- ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
-
- } else {
-
- /* CmpInst::ICMP_UGT and CmpInst::ICMP_ULT */
- /* transformations for < and > */
-
- /* create a basic block which checks for the inverse predicate.
- * if this is true we can go to the end if not we have to go to the
- * bb which checks the lower half of the operands */
- Instruction *icmp_inv_cmp, *op0_low, *op1_low, *icmp_low;
- BasicBlock * inv_cmp_bb =
- BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb);
- if (pred == CmpInst::ICMP_UGT) {
-
- icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
- op0_high, op1_high);
-
- } else {
-
- icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT,
- op0_high, op1_high);
-
- }
-
- inv_cmp_bb->getInstList().push_back(icmp_inv_cmp);
-
- auto term = bb->getTerminator();
- term->eraseFromParent();
- BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb);
-
- /* create a bb which handles the cmp of the lower halves */
- BasicBlock *cmp_low_bb =
- BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
- op0_low = new TruncInst(op0, NewIntType);
- cmp_low_bb->getInstList().push_back(op0_low);
- op1_low = new TruncInst(op1, NewIntType);
- cmp_low_bb->getInstList().push_back(op1_low);
-
- icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low);
- cmp_low_bb->getInstList().push_back(icmp_low);
- BranchInst::Create(end_bb, cmp_low_bb);
-
- BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb);
-
- PHINode *PN = PHINode::Create(Int1Ty, 3);
- PN->addIncoming(icmp_low, cmp_low_bb);
- PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
- PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb);
-
- BasicBlock::iterator ii(IcmpInst);
- ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
-
- }
-
- ++count;
-
- }
-
- return count;
-
}
bool SplitComparesTransform::runOnModule(Module &M) {
-
- int bitw = 64;
- size_t count = 0;
-
char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW");
if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW");
- if (bitw_env) { bitw = atoi(bitw_env); }
+ if (bitw_env) { target_bitwidth = atoi(bitw_env); }
enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL;
if ((isatty(2) && getenv("AFL_QUIET") == NULL) ||
getenv("AFL_DEBUG") != NULL) {
+ errs() << "Split-compare-pass by laf.intel@gmail.com, extended by "
+ "heiko@hexco.de (splitting icmp to "
+ << target_bitwidth << " bit)\n";
- printf(
- "Split-compare-pass by laf.intel@gmail.com, extended by "
- "heiko@hexco.de\n");
+ if (getenv("AFL_DEBUG") != NULL && !debug) { debug = 1; }
} else {
-
be_quiet = 1;
-
}
if (enableFPSplit) {
-
count = splitFPCompares(M);
/*
@@ -1305,60 +1195,55 @@ bool SplitComparesTransform::runOnModule(Module &M) {
*/
simplifyFPCompares(M);
-
}
- simplifyCompares(M);
-
- simplifyIntSignedness(M);
-
- switch (bitw) {
+ std::vector<CmpInst *> worklist;
+ /* iterate over all functions, bbs and instruction search for all integer
+ * compare instructions. Save them into the worklist for later. */
+ for (auto &F : M) {
+ if (!isInInstrumentList(&F)) continue;
- case 64:
- count += splitIntCompares(M, bitw);
- if (debug)
- errs() << "Split-integer-compare-pass " << bitw << "bit: " << count
- << " split\n";
- bitw >>= 1;
-#if LLVM_VERSION_MAJOR > 3 || \
- (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7)
- [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */
-#endif
- case 32:
- count += splitIntCompares(M, bitw);
- if (debug)
- errs() << "Split-integer-compare-pass " << bitw << "bit: " << count
- << " split\n";
- bitw >>= 1;
-#if LLVM_VERSION_MAJOR > 3 || \
- (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7)
- [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */
-#endif
- case 16:
- count += splitIntCompares(M, bitw);
- if (debug)
- errs() << "Split-integer-compare-pass " << bitw << "bit: " << count
- << " split\n";
- // bitw >>= 1;
- break;
+ for (auto &BB : F) {
+ for (auto &IN : BB) {
+ if (auto CI = dyn_cast<CmpInst>(&IN)) {
+ auto op0 = CI->getOperand(0);
+ auto op1 = CI->getOperand(1);
+ if (!op0 || !op1) { return false; }
+ auto iTy1 = dyn_cast<IntegerType>(op0->getType());
+ if (iTy1 && isa<IntegerType>(op1->getType())) {
+ unsigned bitw = iTy1->getBitWidth();
+ if (isSupportedBitWidth(bitw)) { worklist.push_back(CI); }
+ }
+ }
+ }
+ }
+ }
- default:
- // if (!be_quiet) errs() << "NOT Running split-compare-pass \n";
- return false;
- break;
+ // now that we have a list of all integer comparisons we can start replacing
+ // them with the splitted alternatives.
+ for (auto CI : worklist) {
+ simplifyAndSplit(CI, M);
+ }
+ bool brokenDebug = false;
+ if (verifyModule(M, &errs(), &brokenDebug)) {
+ reportError(
+ "Module Verifier failed! Consider reporting a bug with the AFL++ "
+ "project.",
+ nullptr, M);
}
- verifyModule(M);
+ if (brokenDebug) {
+ reportError("Module Verifier reported broken Debug Infos - Stripping!",
+ nullptr, M);
+ StripDebugInfo(M);
+ }
return true;
-
}
static void registerSplitComparesPass(const PassManagerBuilder &,
legacy::PassManagerBase &PM) {
-
PM.add(new SplitComparesTransform());
-
}
static RegisterStandardPasses RegisterSplitComparesPass(
@@ -1373,3 +1258,7 @@ static RegisterStandardPasses RegisterSplitComparesTransPassLTO(
registerSplitComparesPass);
#endif
+static RegisterPass<SplitComparesTransform> X("splitcompares",
+ "AFL++ split compares",
+ true /* Only looks at CFG */,
+ true /* Analysis Pass */);
diff --git a/test/test-llvm.sh b/test/test-llvm.sh
index f902ffc5..8090e176 100755
--- a/test/test-llvm.sh
+++ b/test/test-llvm.sh
@@ -186,6 +186,29 @@ test -e ../afl-clang-fast -a -e ../split-switches-pass.so && {
}
rm -f test-instr.plain
+ $ECHO "$GREY[*] llvm_mode laf-intel/compcov testing splitting integer types (this might take some time)"
+ for testcase in ./test-int_cases.c ./test-uint_cases.c; do
+ for I in char short int long "long long"; do
+ for BITS in 8 16 32 64; do
+ bin="$testcase-split-$I-$BITS.compcov"
+ AFL_LLVM_INSTRUMENT=AFL AFL_DEBUG=1 AFL_LLVM_LAF_SPLIT_COMPARES_BITW=$BITS AFL_LLVM_LAF_SPLIT_COMPARES=1 ../afl-clang-fast -DINT_TYPE="$I" -o "$bin" "$testcase" > test.out 2>&1;
+ if ! test -e "$bin"; then
+ cat test.out
+ $ECHO "$RED[!] llvm_mode laf-intel/compcov integer splitting failed! ($testcase with type $I split to $BITS)!";
+ CODE=1
+ break
+ fi
+ if ! "$bin"; then
+ $ECHO "$RED[!] llvm_mode laf-intel/compcov integer splitting resulted in miscompilation (type $I split to $BITS)!";
+ CODE=1
+ break
+ fi
+ rm -f "$bin" test.out || true
+ done
+ done
+ done
+ rm -f test-int-split*.compcov test.out
+
AFL_LLVM_INSTRUMENT=AFL AFL_DEBUG=1 AFL_LLVM_LAF_SPLIT_SWITCHES=1 AFL_LLVM_LAF_TRANSFORM_COMPARES=1 AFL_LLVM_LAF_SPLIT_COMPARES=1 ../afl-clang-fast -o test-compcov.compcov test-compcov.c > test.out 2>&1
test -e test-compcov.compcov && test_compcov_binary_functionality ./test-compcov.compcov && {
grep --binary-files=text -Eq " [ 123][0-9][0-9] location| [3-9][0-9] location" test.out && {
diff --git a/test/test-uint_cases.c b/test/test-uint_cases.c
index 8496cffe..a277e28a 100644
--- a/test/test-uint_cases.c
+++ b/test/test-uint_cases.c
@@ -1,16 +1,16 @@
/*
- * compile with -DUINT_TYPE="unsigned char"
- * or -DUINT_TYPE="unsigned short"
- * or -DUINT_TYPE="unsigned int"
- * or -DUINT_TYPE="unsigned long"
- * or -DUINT_TYPE="unsigned long long"
+ * compile with -DINT_TYPE="char"
+ * or -DINT_TYPE="short"
+ * or -DINT_TYPE="int"
+ * or -DINT_TYPE="long"
+ * or -DINT_TYPE="long long"
*/
#include <assert.h>
int main() {
- volatile UINT_TYPE a, b;
+ volatile unsigned INT_TYPE a, b;
a = 1;
b = 8;
@@ -21,7 +21,7 @@ int main() {
assert((a != b));
assert(!(a == b));
- if ((UINT_TYPE)(~0) > 255) {
+ if ((INT_TYPE)(~0) > 255) {
volatile unsigned short a, b;
a = 256+2;
b = 256+21;
@@ -41,7 +41,7 @@ int main() {
assert((a != b));
assert(!(a == b));
- if ((UINT_TYPE)(~0) > 65535) {
+ if ((INT_TYPE)(~0) > 65535) {
volatile unsigned int a, b;
a = 65536+2;
b = 65536+21;
@@ -62,7 +62,7 @@ int main() {
assert(!(a == b));
}
- if ((UINT_TYPE)(~0) > 4294967295) {
+ if ((INT_TYPE)(~0) > 4294967295) {
volatile unsigned long a, b;
a = 4294967296+2;
b = 4294967296+21;
@@ -93,7 +93,7 @@ int main() {
assert((a != b));
assert(!(a == b));
- if ((UINT_TYPE)(~0) > 255) {
+ if ((INT_TYPE)(~0) > 255) {
volatile unsigned short a, b;
a = 256+2;
b = 256+1;
@@ -113,7 +113,7 @@ int main() {
assert((a != b));
assert(!(a == b));
- if ((UINT_TYPE)(~0) > 65535) {
+ if ((INT_TYPE)(~0) > 65535) {
volatile unsigned int a, b;
a = 65536+2;
b = 65536+1;
@@ -133,7 +133,7 @@ int main() {
assert((a != b));
assert(!(a == b));
- if ((UINT_TYPE)(~0) > 4294967295) {
+ if ((INT_TYPE)(~0) > 4294967295) {
volatile unsigned long a, b;
a = 4294967296+2;
b = 4294967296+1;
@@ -176,7 +176,7 @@ int main() {
assert(!(a != b));
assert((a == b));
- if ((UINT_TYPE)(~0) > 255) {
+ if ((INT_TYPE)(~0) > 255) {
volatile unsigned short a, b;
a = 256+5;
b = 256+5;
@@ -187,7 +187,7 @@ int main() {
assert(!(a != b));
assert((a == b));
- if ((UINT_TYPE)(~0) > 65535) {
+ if ((INT_TYPE)(~0) > 65535) {
volatile unsigned int a, b;
a = 65536+5;
b = 65536+5;
@@ -198,7 +198,7 @@ int main() {
assert(!(a != b));
assert((a == b));
- if ((UINT_TYPE)(~0) > 4294967295) {
+ if ((INT_TYPE)(~0) > 4294967295) {
volatile unsigned long a, b;
a = 4294967296+5;
b = 4294967296+5;