diff options
Diffstat (limited to 'instrumentation/split-compares-pass.so.cc')
-rw-r--r-- | instrumentation/split-compares-pass.so.cc | 987 |
1 files changed, 531 insertions, 456 deletions
diff --git a/instrumentation/split-compares-pass.so.cc b/instrumentation/split-compares-pass.so.cc index b02a89fb..13f45b69 100644 --- a/instrumentation/split-compares-pass.so.cc +++ b/instrumentation/split-compares-pass.so.cc @@ -47,6 +47,10 @@ using namespace llvm; #include "afl-llvm-common.h" +// uncomment this toggle function verification at each step. horribly slow, but +// helps to pinpoint a potential problem in the splitting code. +//#define VERIFY_TOO_MUCH 1 + namespace { class SplitComparesTransform : public ModulePass { @@ -67,28 +71,101 @@ class SplitComparesTransform : public ModulePass { const char *getPassName() const override { #endif - return "simplifies and splits ICMP instructions"; + return "AFL_SplitComparesTransform"; } private: int enableFPSplit; - size_t splitIntCompares(Module &M, unsigned bitw); + unsigned target_bitwidth = 8; + + size_t count = 0; + size_t splitFPCompares(Module &M); - bool simplifyCompares(Module &M); bool simplifyFPCompares(Module &M); - bool simplifyIntSignedness(Module &M); size_t nextPowerOfTwo(size_t in); + using CmpWorklist = SmallVector<CmpInst *, 8>; + + /// simplify the comparison and then split the comparison until the + /// target_bitwidth is reached. + bool simplifyAndSplit(CmpInst *I, Module &M); + /// simplify a non-strict comparison (e.g., less than or equals) + bool simplifyOrEqualsCompare(CmpInst *IcmpInst, Module &M, + CmpWorklist &worklist); + /// simplify a signed comparison (signed less or greater than) + bool simplifySignedCompare(CmpInst *IcmpInst, Module &M, + CmpWorklist &worklist); + /// splits an icmp into nested icmps recursivly until target_bitwidth is + /// reached + bool splitCompare(CmpInst *I, Module &M, CmpWorklist &worklist); + + /// print an error to llvm's errs stream, but only if not ordered to be quiet + void reportError(const StringRef msg, Instruction *I, Module &M) { + + if (!be_quiet) { + + errs() << "[AFL++ SplitComparesTransform] ERROR: " << msg << "\n"; + if (debug) { + + if (I) { + + errs() << "Instruction = " << *I << "\n"; + if (auto BB = I->getParent()) { + + if (auto F = BB->getParent()) { + + if (F->hasName()) { + + errs() << "|-> in function " << F->getName() << " "; + + } + + } + + } + + } + + auto n = M.getName(); + if (n.size() > 0) { errs() << "in module " << n << "\n"; } + + } + + } + + } + + bool isSupportedBitWidth(unsigned bitw) { + + // IDK whether the icmp code works on other bitwidths. I guess not? So we + // try to avoid dealing with other weird icmp's that llvm might use (looking + // at you `icmp i0`). + switch (bitw) { + + case 8: + case 16: + case 32: + case 64: + case 128: + case 256: + return true; + default: + return false; + + } + + } + }; } // namespace char SplitComparesTransform::ID = 0; -/* This function splits FCMP instructions with xGE or xLE predicates into two - * FCMP instructions with predicate xGT or xLT and EQ */ +/// This function splits FCMP instructions with xGE or xLE predicates into two +/// FCMP instructions with predicate xGT or xLT and EQ bool SplitComparesTransform::simplifyFPCompares(Module &M) { LLVMContext & C = M.getContext(); @@ -221,292 +298,481 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) { } -/* This function splits ICMP instructions with xGE or xLE predicates into two - * ICMP instructions with predicate xGT or xLT and EQ */ -bool SplitComparesTransform::simplifyCompares(Module &M) { +/// This function splits ICMP instructions with xGE or xLE predicates into two +/// ICMP instructions with predicate xGT or xLT and EQ +bool SplitComparesTransform::simplifyOrEqualsCompare(CmpInst * IcmpInst, + Module & M, + CmpWorklist &worklist) { - LLVMContext & C = M.getContext(); - std::vector<Instruction *> icomps; - IntegerType * Int1Ty = IntegerType::getInt1Ty(C); + LLVMContext &C = M.getContext(); + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); - /* iterate over all functions, bbs and instruction and add - * all integer comparisons with >= and <= predicates to the icomps vector */ - for (auto &F : M) { + /* find out what the new predicate is going to be */ + auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); + if (!cmp_inst) { return false; } - if (!isInInstrumentList(&F)) continue; + BasicBlock *bb = IcmpInst->getParent(); - for (auto &BB : F) { + auto op0 = IcmpInst->getOperand(0); + auto op1 = IcmpInst->getOperand(1); - for (auto &IN : BB) { + CmpInst::Predicate pred = cmp_inst->getPredicate(); + CmpInst::Predicate new_pred; - CmpInst *selectcmpInst = nullptr; + switch (pred) { - if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + case CmpInst::ICMP_UGE: + new_pred = CmpInst::ICMP_UGT; + break; + case CmpInst::ICMP_SGE: + new_pred = CmpInst::ICMP_SGT; + break; + case CmpInst::ICMP_ULE: + new_pred = CmpInst::ICMP_ULT; + break; + case CmpInst::ICMP_SLE: + new_pred = CmpInst::ICMP_SLT; + break; + default: // keep the compiler happy + return false; - if (selectcmpInst->getPredicate() == CmpInst::ICMP_UGE || - selectcmpInst->getPredicate() == CmpInst::ICMP_SGE || - selectcmpInst->getPredicate() == CmpInst::ICMP_ULE || - selectcmpInst->getPredicate() == CmpInst::ICMP_SLE) { + } - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); + /* split before the icmp instruction */ + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + + /* the old bb now contains a unconditional jump to the new one (end_bb) + * we need to delete it later */ + + /* create the ICMP instruction with new_pred and add it to the old basic + * block bb it is now at the position where the old IcmpInst was */ + CmpInst *icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); + bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), icmp_np); + + /* create a new basic block which holds the new EQ icmp */ + CmpInst *icmp_eq; + /* insert middle_bb before end_bb */ + BasicBlock *middle_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1); + middle_bb->getInstList().push_back(icmp_eq); + /* add an unconditional branch to the end of middle_bb with destination + * end_bb */ + BranchInst::Create(end_bb, middle_bb); + + /* replace the uncond branch with a conditional one, which depends on the + * new_pred icmp. True goes to end, false to the middle (injected) bb */ + auto term = bb->getTerminator(); + BranchInst::Create(end_bb, middle_bb, icmp_np, bb); + term->eraseFromParent(); + + /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI + * inst to wire up the loose ends */ + PHINode *PN = PHINode::Create(Int1Ty, 2, ""); + /* the first result depends on the outcome of icmp_eq */ + PN->addIncoming(icmp_eq, middle_bb); + /* if the source was the original bb we know that the icmp_np yielded true + * hence we can hardcode this value */ + PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + /* replace the old IcmpInst with our new and shiny PHI inst */ + BasicBlock::iterator ii(IcmpInst); + ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + + worklist.push_back(icmp_np); + worklist.push_back(icmp_eq); - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + return true; - /* this is probably not needed but we do it anyway */ - if (!intTyOp0 || !intTyOp1) { continue; } +} - icomps.push_back(selectcmpInst); +/// Simplify a signed comparison operator by splitting it into a unsigned and +/// bit comparison. add all resulting comparisons to +/// the worklist passed as a reference. +bool SplitComparesTransform::simplifySignedCompare(CmpInst *IcmpInst, Module &M, + CmpWorklist &worklist) { - } + LLVMContext &C = M.getContext(); + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); - } + BasicBlock *bb = IcmpInst->getParent(); - } + auto op0 = IcmpInst->getOperand(0); + auto op1 = IcmpInst->getOperand(1); - } + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + if (!intTyOp0) { return false; } + unsigned bitw = intTyOp0->getBitWidth(); + IntegerType *IntType = IntegerType::get(C, bitw); - } + /* get the new predicate */ + auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); + if (!cmp_inst) { return false; } + auto pred = cmp_inst->getPredicate(); + CmpInst::Predicate new_pred; - if (!icomps.size()) { return false; } + if (pred == CmpInst::ICMP_SGT) { - for (auto &IcmpInst : icomps) { + new_pred = CmpInst::ICMP_UGT; - BasicBlock *bb = IcmpInst->getParent(); + } else { - auto op0 = IcmpInst->getOperand(0); - auto op1 = IcmpInst->getOperand(1); + new_pred = CmpInst::ICMP_ULT; - /* find out what the new predicate is going to be */ - auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); - if (!cmp_inst) { continue; } - auto pred = cmp_inst->getPredicate(); - CmpInst::Predicate new_pred; + } - switch (pred) { + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + + /* create a 1 bit compare for the sign bit. to do this shift and trunc + * the original operands so only the first bit remains.*/ + Value *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit; + + IRBuilder<> IRB(bb->getTerminator()); + s_op0 = IRB.CreateLShr(op0, ConstantInt::get(IntType, bitw - 1)); + t_op0 = IRB.CreateTruncOrBitCast(s_op0, Int1Ty); + s_op1 = IRB.CreateLShr(op1, ConstantInt::get(IntType, bitw - 1)); + t_op1 = IRB.CreateTruncOrBitCast(s_op1, Int1Ty); + /* compare of the sign bits */ + icmp_sign_bit = IRB.CreateICmp(CmpInst::ICMP_EQ, t_op0, t_op1); + + /* create a new basic block which is executed if the signedness bit is + * different */ + CmpInst * icmp_inv_sig_cmp; + BasicBlock *sign_bb = + BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb); + if (pred == CmpInst::ICMP_SGT) { + + /* if we check for > and the op0 positive and op1 negative then the final + * result is true. if op0 negative and op1 pos, the cmp must result + * in false + */ + icmp_inv_sig_cmp = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1); - case CmpInst::ICMP_UGE: - new_pred = CmpInst::ICMP_UGT; - break; - case CmpInst::ICMP_SGE: - new_pred = CmpInst::ICMP_SGT; - break; - case CmpInst::ICMP_ULE: - new_pred = CmpInst::ICMP_ULT; - break; - case CmpInst::ICMP_SLE: - new_pred = CmpInst::ICMP_SLT; - break; - default: // keep the compiler happy - continue; + } else { - } + /* just the inverse of the above statement */ + icmp_inv_sig_cmp = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1); - /* split before the icmp instruction */ - BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + } - /* the old bb now contains a unconditional jump to the new one (end_bb) - * we need to delete it later */ + sign_bb->getInstList().push_back(icmp_inv_sig_cmp); + BranchInst::Create(end_bb, sign_bb); - /* create the ICMP instruction with new_pred and add it to the old basic - * block bb it is now at the position where the old IcmpInst was */ - Instruction *icmp_np; - icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - icmp_np); + /* create a new bb which is executed if signedness is equal */ + CmpInst * icmp_usign_cmp; + BasicBlock *middle_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + /* we can do a normal unsigned compare now */ + icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); - /* create a new basic block which holds the new EQ icmp */ - Instruction *icmp_eq; - /* insert middle_bb before end_bb */ - BasicBlock *middle_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1); - middle_bb->getInstList().push_back(icmp_eq); - /* add an unconditional branch to the end of middle_bb with destination - * end_bb */ - BranchInst::Create(end_bb, middle_bb); + middle_bb->getInstList().push_back(icmp_usign_cmp); + BranchInst::Create(end_bb, middle_bb); - /* replace the uncond branch with a conditional one, which depends on the - * new_pred icmp. True goes to end, false to the middle (injected) bb */ - auto term = bb->getTerminator(); - BranchInst::Create(end_bb, middle_bb, icmp_np, bb); - term->eraseFromParent(); + auto term = bb->getTerminator(); + /* if the sign is eq do a normal unsigned cmp, else we have to check the + * signedness bit */ + BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb); + term->eraseFromParent(); - /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI - * inst to wire up the loose ends */ - PHINode *PN = PHINode::Create(Int1Ty, 2, ""); - /* the first result depends on the outcome of icmp_eq */ - PN->addIncoming(icmp_eq, middle_bb); - /* if the source was the original bb we know that the icmp_np yielded true - * hence we can hardcode this value */ - PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); - /* replace the old IcmpInst with our new and shiny PHI inst */ - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + PHINode *PN = PHINode::Create(Int1Ty, 2, ""); - } + PN->addIncoming(icmp_usign_cmp, middle_bb); + PN->addIncoming(icmp_inv_sig_cmp, sign_bb); + + BasicBlock::iterator ii(IcmpInst); + ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + + // save for later + worklist.push_back(icmp_usign_cmp); + + // signed comparisons are not supported by the splitting code, so we must not + // add it to the worklist. + // worklist.push_back(icmp_inv_sig_cmp); return true; } -/* this function transforms signed compares to equivalent unsigned compares */ -bool SplitComparesTransform::simplifyIntSignedness(Module &M) { +bool SplitComparesTransform::splitCompare(CmpInst *cmp_inst, Module &M, + CmpWorklist &worklist) { - LLVMContext & C = M.getContext(); - std::vector<Instruction *> icomps; - IntegerType * Int1Ty = IntegerType::getInt1Ty(C); + auto pred = cmp_inst->getPredicate(); + switch (pred) { - /* iterate over all functions, bbs and instructions and add - * all signed compares to icomps vector */ - for (auto &F : M) { + case CmpInst::ICMP_EQ: + case CmpInst::ICMP_NE: + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_ULT: + break; + default: + // unsupported predicate! + return false; - if (!isInInstrumentList(&F)) continue; + } - for (auto &BB : F) { + auto op0 = cmp_inst->getOperand(0); + auto op1 = cmp_inst->getOperand(1); - for (auto &IN : BB) { + // get bitwidth by checking the bitwidth of the first operator + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + if (!intTyOp0) { - CmpInst *selectcmpInst = nullptr; + // not an integer type + return false; - if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + } - if (selectcmpInst->getPredicate() == CmpInst::ICMP_SGT || - selectcmpInst->getPredicate() == CmpInst::ICMP_SLT) { + unsigned bitw = intTyOp0->getBitWidth(); + if (bitw == target_bitwidth) { - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); + // already the target bitwidth so we have to do nothing here. + return true; + + } + + LLVMContext &C = M.getContext(); + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); + BasicBlock * bb = cmp_inst->getParent(); + IntegerType *OldIntType = IntegerType::get(C, bitw); + IntegerType *NewIntType = IntegerType::get(C, bitw / 2); + BasicBlock * end_bb = bb->splitBasicBlock(BasicBlock::iterator(cmp_inst)); + CmpInst * icmp_high, *icmp_low; - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + /* create the comparison of the top halves of the original operands */ + Value *s_op0, *op0_high, *s_op1, *op1_high; - /* see above */ - if (!intTyOp0 || !intTyOp1) { continue; } + IRBuilder<> IRB(bb->getTerminator()); - /* i think this is not possible but to lazy to look it up */ - if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { + s_op0 = IRB.CreateBinOp(Instruction::LShr, op0, + ConstantInt::get(OldIntType, bitw / 2)); + op0_high = IRB.CreateTruncOrBitCast(s_op0, NewIntType); - continue; + s_op1 = IRB.CreateBinOp(Instruction::LShr, op1, + ConstantInt::get(OldIntType, bitw / 2)); + op1_high = IRB.CreateTruncOrBitCast(s_op1, NewIntType); + icmp_high = cast<CmpInst>(IRB.CreateICmp(pred, op0_high, op1_high)); - } + PHINode *PN = nullptr; - icomps.push_back(selectcmpInst); + /* now we have to destinguish between == != and > < */ + switch (pred) { - } + case CmpInst::ICMP_EQ: + case CmpInst::ICMP_NE: { - } + /* transformation for == and != icmps */ + + /* create a compare for the lower half of the original operands */ + BasicBlock *cmp_low_bb = + BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb); + + Value * op0_low, *op1_low; + IRBuilder<> Builder(cmp_low_bb); + + op0_low = Builder.CreateTrunc(op0, NewIntType); + op1_low = Builder.CreateTrunc(op1, NewIntType); + icmp_low = cast<CmpInst>(Builder.CreateICmp(pred, op0_low, op1_low)); + + BranchInst::Create(end_bb, cmp_low_bb); + + /* dependent on the cmp of the high parts go to the end or go on with + * the comparison */ + auto term = bb->getTerminator(); + BranchInst *br = nullptr; + if (pred == CmpInst::ICMP_EQ) { + + br = BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb); + + } else { + + /* CmpInst::ICMP_NE */ + br = BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb); + + } + + term->eraseFromParent(); + + /* create the PHI and connect the edges accordingly */ + PN = PHINode::Create(Int1Ty, 2, ""); + PN->addIncoming(icmp_low, cmp_low_bb); + Value *val = nullptr; + if (pred == CmpInst::ICMP_EQ) { + + val = ConstantInt::get(Int1Ty, 0); + + } else { + + /* CmpInst::ICMP_NE */ + val = ConstantInt::get(Int1Ty, 1); } + PN->addIncoming(val, icmp_high->getParent()); + break; + } - } + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_ULT: { - if (!icomps.size()) { return false; } + /* transformations for < and > */ - for (auto &IcmpInst : icomps) { + /* create a basic block which checks for the inverse predicate. + * if this is true we can go to the end if not we have to go to the + * bb which checks the lower half of the operands */ + Instruction *op0_low, *op1_low; + CmpInst * icmp_inv_cmp = nullptr; + BasicBlock * inv_cmp_bb = + BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb); + if (pred == CmpInst::ICMP_UGT) { - BasicBlock *bb = IcmpInst->getParent(); + icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, + op0_high, op1_high); - auto op0 = IcmpInst->getOperand(0); - auto op1 = IcmpInst->getOperand(1); + } else { - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - if (!intTyOp0) { continue; } - unsigned bitw = intTyOp0->getBitWidth(); - IntegerType *IntType = IntegerType::get(C, bitw); + icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, + op0_high, op1_high); - /* get the new predicate */ - auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); - if (!cmp_inst) { continue; } - auto pred = cmp_inst->getPredicate(); - CmpInst::Predicate new_pred; + } - if (pred == CmpInst::ICMP_SGT) { + inv_cmp_bb->getInstList().push_back(icmp_inv_cmp); + worklist.push_back(icmp_inv_cmp); - new_pred = CmpInst::ICMP_UGT; + auto term = bb->getTerminator(); + term->eraseFromParent(); + BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb); - } else { + /* create a bb which handles the cmp of the lower halves */ + BasicBlock *cmp_low_bb = + BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb); + op0_low = new TruncInst(op0, NewIntType); + cmp_low_bb->getInstList().push_back(op0_low); + op1_low = new TruncInst(op1, NewIntType); + cmp_low_bb->getInstList().push_back(op1_low); - new_pred = CmpInst::ICMP_ULT; + icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low); + cmp_low_bb->getInstList().push_back(icmp_low); + BranchInst::Create(end_bb, cmp_low_bb); + + BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb); + + PN = PHINode::Create(Int1Ty, 3); + PN->addIncoming(icmp_low, cmp_low_bb); + PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb); + break; } - BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + default: + return false; - /* create a 1 bit compare for the sign bit. to do this shift and trunc - * the original operands so only the first bit remains.*/ - Instruction *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit; + } - s_op0 = BinaryOperator::Create(Instruction::LShr, op0, - ConstantInt::get(IntType, bitw - 1)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0); - t_op0 = new TruncInst(s_op0, Int1Ty); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op0); + BasicBlock::iterator ii(cmp_inst); + ReplaceInstWithInst(cmp_inst->getParent()->getInstList(), ii, PN); - s_op1 = BinaryOperator::Create(Instruction::LShr, op1, - ConstantInt::get(IntType, bitw - 1)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1); - t_op1 = new TruncInst(s_op1, Int1Ty); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op1); + // We split the comparison into low and high. If this isn't our target + // bitwidth we recursivly split the low and high parts again until we have + // target bitwidth. + if ((bitw / 2) > target_bitwidth) { - /* compare of the sign bits */ - icmp_sign_bit = - CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - icmp_sign_bit); + worklist.push_back(icmp_high); + worklist.push_back(icmp_low); - /* create a new basic block which is executed if the signedness bit is - * different */ - Instruction *icmp_inv_sig_cmp; - BasicBlock * sign_bb = - BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb); - if (pred == CmpInst::ICMP_SGT) { + } - /* if we check for > and the op0 positive and op1 negative then the final - * result is true. if op0 negative and op1 pos, the cmp must result - * in false - */ - icmp_inv_sig_cmp = - CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1); + return true; - } else { +} + +bool SplitComparesTransform::simplifyAndSplit(CmpInst *I, Module &M) { + + CmpWorklist worklist; + + auto op0 = I->getOperand(0); + auto op1 = I->getOperand(1); + if (!op0 || !op1) { return false; } + auto op0Ty = dyn_cast<IntegerType>(op0->getType()); + if (!op0Ty || !isa<IntegerType>(op1->getType())) { return true; } + + unsigned bitw = op0Ty->getBitWidth(); + +#ifdef VERIFY_TOO_MUCH + auto F = I->getParent()->getParent(); +#endif - /* just the inverse of the above statement */ - icmp_inv_sig_cmp = - CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1); + // we run the comparison simplification on all compares regardless of their + // bitwidth. + if (I->getPredicate() == CmpInst::ICMP_UGE || + I->getPredicate() == CmpInst::ICMP_SGE || + I->getPredicate() == CmpInst::ICMP_ULE || + I->getPredicate() == CmpInst::ICMP_SLE) { + + if (!simplifyOrEqualsCompare(I, M, worklist)) { + + reportError( + "Failed to simplify inequality or equals comparison " + "(UGE,SGE,ULE,SLE)", + I, M); } - sign_bb->getInstList().push_back(icmp_inv_sig_cmp); - BranchInst::Create(end_bb, sign_bb); + } else if (I->getPredicate() == CmpInst::ICMP_SGT || - /* create a new bb which is executed if signedness is equal */ - Instruction *icmp_usign_cmp; - BasicBlock * middle_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - /* we can do a normal unsigned compare now */ - icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); - middle_bb->getInstList().push_back(icmp_usign_cmp); - BranchInst::Create(end_bb, middle_bb); + I->getPredicate() == CmpInst::ICMP_SLT) { - auto term = bb->getTerminator(); - /* if the sign is eq do a normal unsigned cmp, else we have to check the - * signedness bit */ - BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb); - term->eraseFromParent(); + if (!simplifySignedCompare(I, M, worklist)) { - PHINode *PN = PHINode::Create(Int1Ty, 2, ""); + reportError("Failed to simplify signed comparison (SGT,SLT)", I, M); + + } + + } + +#ifdef VERIFY_TOO_MUCH + if (verifyFunction(*F, &errs())) { + + reportError("simpliyfing compare lead to broken function", nullptr, M); + + } + +#endif + + // the simplification methods replace the original CmpInst and push the + // resulting new CmpInst into the worklist. If the worklist is empty then + // we only have to split the original CmpInst. + if (worklist.size() == 0) { worklist.push_back(I); } + + while (!worklist.empty()) { + + CmpInst *cmp = worklist.pop_back_val(); + // we split the simplified compares into comparisons with smaller bitwidths + // if they are larger than our target_bitwidth. + if (bitw > target_bitwidth) { + + if (!splitCompare(cmp, M, worklist)) { + + reportError("Failed to split comparison", cmp, M); + + } + +#ifdef VERIFY_TOO_MUCH + if (verifyFunction(*F, &errs())) { + + reportError("splitting compare lead to broken function", nullptr, M); + + } - PN->addIncoming(icmp_usign_cmp, middle_bb); - PN->addIncoming(icmp_inv_sig_cmp, sign_bb); +#endif - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + } } + count++; return true; } @@ -1050,306 +1316,110 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { } -/* splits icmps of size bitw into two nested icmps with bitw/2 size each */ -size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) { - - size_t count = 0; - - LLVMContext &C = M.getContext(); - - IntegerType *Int1Ty = IntegerType::getInt1Ty(C); - IntegerType *OldIntType = IntegerType::get(C, bitw); - IntegerType *NewIntType = IntegerType::get(C, bitw / 2); - - std::vector<Instruction *> icomps; - - if (bitw % 2) { return 0; } - - /* not supported yet */ - if (bitw > 64) { return 0; } - - /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the - * functions simplifyCompares() and simplifyIntSignedness() - * were executed only these four predicates should exist */ - for (auto &F : M) { - - if (!isInInstrumentList(&F)) continue; +bool SplitComparesTransform::runOnModule(Module &M) { - for (auto &BB : F) { + char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW"); + if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW"); + if (bitw_env) { target_bitwidth = atoi(bitw_env); } - for (auto &IN : BB) { + enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL; - CmpInst *selectcmpInst = nullptr; + if ((isatty(2) && getenv("AFL_QUIET") == NULL) || + getenv("AFL_DEBUG") != NULL) { - if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + errs() << "Split-compare-pass by laf.intel@gmail.com, extended by " + "heiko@hexco.de (splitting icmp to " + << target_bitwidth << " bit)\n"; - if (selectcmpInst->getPredicate() == CmpInst::ICMP_EQ || - selectcmpInst->getPredicate() == CmpInst::ICMP_NE || - selectcmpInst->getPredicate() == CmpInst::ICMP_UGT || - selectcmpInst->getPredicate() == CmpInst::ICMP_ULT) { + if (getenv("AFL_DEBUG") != NULL && !debug) { debug = 1; } - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); - - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + } else { - if (!intTyOp0 || !intTyOp1) { continue; } + be_quiet = 1; - /* check if the bitwidths are the one we are looking for */ - if (intTyOp0->getBitWidth() != bitw || - intTyOp1->getBitWidth() != bitw) { + } - continue; + if (enableFPSplit) { - } + count = splitFPCompares(M); - icomps.push_back(selectcmpInst); + /* + if (!be_quiet) { - } + errs() << "Split-floatingpoint-compare-pass: " << count + << " FP comparisons split\n"; } - } - - } + */ + simplifyFPCompares(M); } - if (!icomps.size()) { return 0; } - - for (auto &IcmpInst : icomps) { - - BasicBlock *bb = IcmpInst->getParent(); - - auto op0 = IcmpInst->getOperand(0); - auto op1 = IcmpInst->getOperand(1); - - auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); - if (!cmp_inst) { continue; } - auto pred = cmp_inst->getPredicate(); - - BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); - - /* create the comparison of the top halves of the original operands */ - Instruction *s_op0, *op0_high, *s_op1, *op1_high, *icmp_high; - - s_op0 = BinaryOperator::Create(Instruction::LShr, op0, - ConstantInt::get(OldIntType, bitw / 2)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0); - op0_high = new TruncInst(s_op0, NewIntType); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - op0_high); - - s_op1 = BinaryOperator::Create(Instruction::LShr, op1, - ConstantInt::get(OldIntType, bitw / 2)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1); - op1_high = new TruncInst(s_op1, NewIntType); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - op1_high); - - icmp_high = CmpInst::Create(Instruction::ICmp, pred, op0_high, op1_high); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - icmp_high); - - /* now we have to destinguish between == != and > < */ - if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) { - - /* transformation for == and != icmps */ - - /* create a compare for the lower half of the original operands */ - Instruction *op0_low, *op1_low, *icmp_low; - BasicBlock * cmp_low_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - - op0_low = new TruncInst(op0, NewIntType); - cmp_low_bb->getInstList().push_back(op0_low); - - op1_low = new TruncInst(op1, NewIntType); - cmp_low_bb->getInstList().push_back(op1_low); - - icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low); - cmp_low_bb->getInstList().push_back(icmp_low); - BranchInst::Create(end_bb, cmp_low_bb); - - /* dependent on the cmp of the high parts go to the end or go on with - * the comparison */ - auto term = bb->getTerminator(); - if (pred == CmpInst::ICMP_EQ) { - - BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb); - - } else { - - /* CmpInst::ICMP_NE */ - BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb); - - } - - term->eraseFromParent(); - - /* create the PHI and connect the edges accordingly */ - PHINode *PN = PHINode::Create(Int1Ty, 2, ""); - PN->addIncoming(icmp_low, cmp_low_bb); - if (pred == CmpInst::ICMP_EQ) { - - PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb); - - } else { + std::vector<CmpInst *> worklist; + /* iterate over all functions, bbs and instruction search for all integer + * compare instructions. Save them into the worklist for later. */ + for (auto &F : M) { - /* CmpInst::ICMP_NE */ - PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + if (!isInInstrumentList(&F)) continue; - } + for (auto &BB : F) { - /* replace the old icmp with the new PHI */ - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + for (auto &IN : BB) { - } else { + if (auto CI = dyn_cast<CmpInst>(&IN)) { - /* CmpInst::ICMP_UGT and CmpInst::ICMP_ULT */ - /* transformations for < and > */ + auto op0 = CI->getOperand(0); + auto op1 = CI->getOperand(1); + if (!op0 || !op1) { return false; } + auto iTy1 = dyn_cast<IntegerType>(op0->getType()); + if (iTy1 && isa<IntegerType>(op1->getType())) { - /* create a basic block which checks for the inverse predicate. - * if this is true we can go to the end if not we have to go to the - * bb which checks the lower half of the operands */ - Instruction *icmp_inv_cmp, *op0_low, *op1_low, *icmp_low; - BasicBlock * inv_cmp_bb = - BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb); - if (pred == CmpInst::ICMP_UGT) { + unsigned bitw = iTy1->getBitWidth(); + if (isSupportedBitWidth(bitw)) { worklist.push_back(CI); } - icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, - op0_high, op1_high); - - } else { + } - icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, - op0_high, op1_high); + } } - inv_cmp_bb->getInstList().push_back(icmp_inv_cmp); - - auto term = bb->getTerminator(); - term->eraseFromParent(); - BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb); - - /* create a bb which handles the cmp of the lower halves */ - BasicBlock *cmp_low_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - op0_low = new TruncInst(op0, NewIntType); - cmp_low_bb->getInstList().push_back(op0_low); - op1_low = new TruncInst(op1, NewIntType); - cmp_low_bb->getInstList().push_back(op1_low); - - icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low); - cmp_low_bb->getInstList().push_back(icmp_low); - BranchInst::Create(end_bb, cmp_low_bb); - - BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb); - - PHINode *PN = PHINode::Create(Int1Ty, 3); - PN->addIncoming(icmp_low, cmp_low_bb); - PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); - PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb); - - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); - } - ++count; - } - return count; - -} - -bool SplitComparesTransform::runOnModule(Module &M) { - - int bitw = 64; - size_t count = 0; - - char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW"); - if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW"); - if (bitw_env) { bitw = atoi(bitw_env); } - - enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL; - - if ((isatty(2) && getenv("AFL_QUIET") == NULL) || - getenv("AFL_DEBUG") != NULL) { - - printf( - "Split-compare-pass by laf.intel@gmail.com, extended by " - "heiko@hexco.de\n"); + // now that we have a list of all integer comparisons we can start replacing + // them with the splitted alternatives. + for (auto CI : worklist) { - } else { - - be_quiet = 1; + simplifyAndSplit(CI, M); } - if (enableFPSplit) { - - count = splitFPCompares(M); - - /* - if (!be_quiet) { - - errs() << "Split-floatingpoint-compare-pass: " << count - << " FP comparisons split\n"; + bool brokenDebug = false; + if (verifyModule(M, &errs() +#if LLVM_VERSION_MAJOR > 3 || \ + (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR >= 9) + , + &brokenDebug // 9th May 2016 +#endif + )) { - } - - */ - simplifyFPCompares(M); + reportError( + "Module Verifier failed! Consider reporting a bug with the AFL++ " + "project.", + nullptr, M); } - simplifyCompares(M); - - simplifyIntSignedness(M); + if (brokenDebug) { - switch (bitw) { - - case 64: - count += splitIntCompares(M, bitw); - if (debug) - errs() << "Split-integer-compare-pass " << bitw << "bit: " << count - << " split\n"; - bitw >>= 1; -#if LLVM_VERSION_MAJOR > 3 || \ - (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7) - [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ -#endif - case 32: - count += splitIntCompares(M, bitw); - if (debug) - errs() << "Split-integer-compare-pass " << bitw << "bit: " << count - << " split\n"; - bitw >>= 1; -#if LLVM_VERSION_MAJOR > 3 || \ - (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7) - [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ -#endif - case 16: - count += splitIntCompares(M, bitw); - if (debug) - errs() << "Split-integer-compare-pass " << bitw << "bit: " << count - << " split\n"; - // bitw >>= 1; - break; - - default: - // if (!be_quiet) errs() << "NOT Running split-compare-pass \n"; - return false; - break; + reportError("Module Verifier reported broken Debug Infos - Stripping!", + nullptr, M); + StripDebugInfo(M); } - verifyModule(M); return true; } @@ -1373,3 +1443,8 @@ static RegisterStandardPasses RegisterSplitComparesTransPassLTO( registerSplitComparesPass); #endif +static RegisterPass<SplitComparesTransform> X("splitcompares", + "AFL++ split compares", + true /* Only looks at CFG */, + true /* Analysis Pass */); + |