diff options
-rw-r--r-- | instrumentation/split-compares-pass.so.cc | 1005 | ||||
-rwxr-xr-x | test/test-llvm.sh | 23 | ||||
-rw-r--r-- | test/test-uint_cases.c | 30 |
3 files changed, 485 insertions, 573 deletions
diff --git a/instrumentation/split-compares-pass.so.cc b/instrumentation/split-compares-pass.so.cc index b02a89fb..6eb9050c 100644 --- a/instrumentation/split-compares-pass.so.cc +++ b/instrumentation/split-compares-pass.so.cc @@ -47,50 +47,101 @@ using namespace llvm; #include "afl-llvm-common.h" +// uncomment this toggle function verification at each step. horribly slow, but +// helps to pinpoint a potential problem in the splitting code. +//#define VERIFY_TOO_MUCH 1 + namespace { class SplitComparesTransform : public ModulePass { - public: static char ID; SplitComparesTransform() : ModulePass(ID), enableFPSplit(0) { - initInstrumentList(); - } bool runOnModule(Module &M) override; #if LLVM_VERSION_MAJOR >= 4 StringRef getPassName() const override { - #else const char *getPassName() const override { #endif - return "simplifies and splits ICMP instructions"; - + return "AFL_SplitComparesTransform"; } private: int enableFPSplit; - size_t splitIntCompares(Module &M, unsigned bitw); + unsigned target_bitwidth = 8; + + size_t count = 0; + size_t splitFPCompares(Module &M); - bool simplifyCompares(Module &M); bool simplifyFPCompares(Module &M); - bool simplifyIntSignedness(Module &M); size_t nextPowerOfTwo(size_t in); + using CmpWorklist = SmallVector<CmpInst *, 8>; + + /// simplify the comparison and then split the comparison until the + /// target_bitwidth is reached. + bool simplifyAndSplit(CmpInst *I, Module &M); + /// simplify a non-strict comparison (e.g., less than or equals) + bool simplifyOrEqualsCompare(CmpInst *IcmpInst, Module &M, + CmpWorklist &worklist); + /// simplify a signed comparison (signed less or greater than) + bool simplifySignedCompare(CmpInst *IcmpInst, Module &M, + CmpWorklist &worklist); + /// splits an icmp into nested icmps recursivly until target_bitwidth is + /// reached + bool splitCompare(CmpInst *I, Module &M, CmpWorklist &worklist); + + /// print an error to llvm's errs stream, but only if not ordered to be quiet + void reportError(const StringRef msg, Instruction *I, Module &M) { + if (!be_quiet) { + errs() << "[AFL++ SplitComparesTransform] ERROR: " << msg << "\n"; + if (debug) { + if (I) { + errs() << "Instruction = " << *I << "\n"; + if (auto BB = I->getParent()) { + if (auto F = BB->getParent()) { + if (F->hasName()) { + errs() << "|-> in function " << F->getName() << " "; + } + } + } + } + auto n = M.getName(); + if (n.size() > 0) { errs() << "in module " << n << "\n"; } + } + } + } + + bool isSupportedBitWidth(unsigned bitw) { + // IDK whether the icmp code works on other bitwidths. I guess not? So we + // try to avoid dealing with other weird icmp's that llvm might use (looking + // at you `icmp i0`). + switch (bitw) { + case 8: + case 16: + case 32: + case 64: + case 128: + case 256: + return true; + default: + return false; + } + } }; } // namespace char SplitComparesTransform::ID = 0; -/* This function splits FCMP instructions with xGE or xLE predicates into two - * FCMP instructions with predicate xGT or xLT and EQ */ +/// This function splits FCMP instructions with xGE or xLE predicates into two +/// FCMP instructions with predicate xGT or xLT and EQ bool SplitComparesTransform::simplifyFPCompares(Module &M) { - LLVMContext & C = M.getContext(); std::vector<Instruction *> fcomps; IntegerType * Int1Ty = IntegerType::getInt1Ty(C); @@ -98,23 +149,18 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) { /* iterate over all functions, bbs and instruction and add * all integer comparisons with >= and <= predicates to the icomps vector */ for (auto &F : M) { - if (!isInInstrumentList(&F)) continue; for (auto &BB : F) { - for (auto &IN : BB) { - CmpInst *selectcmpInst = nullptr; if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { - if (enableFPSplit && (selectcmpInst->getPredicate() == CmpInst::FCMP_OGE || selectcmpInst->getPredicate() == CmpInst::FCMP_UGE || selectcmpInst->getPredicate() == CmpInst::FCMP_OLE || selectcmpInst->getPredicate() == CmpInst::FCMP_ULE)) { - auto op0 = selectcmpInst->getOperand(0); auto op1 = selectcmpInst->getOperand(1); @@ -127,22 +173,16 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) { if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; } fcomps.push_back(selectcmpInst); - } - } - } - } - } if (!fcomps.size()) { return false; } /* transform for floating point */ for (auto &FcmpInst : fcomps) { - BasicBlock *bb = FcmpInst->getParent(); auto op0 = FcmpInst->getOperand(0); @@ -155,7 +195,6 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) { CmpInst::Predicate new_pred; switch (pred) { - case CmpInst::FCMP_UGE: new_pred = CmpInst::FCMP_UGT; break; @@ -170,7 +209,6 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) { break; default: // keep the compiler happy continue; - } /* split before the fcmp instruction */ @@ -214,305 +252,425 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) { /* replace the old FcmpInst with our new and shiny PHI inst */ BasicBlock::iterator ii(FcmpInst); ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN); - } return true; - } -/* This function splits ICMP instructions with xGE or xLE predicates into two - * ICMP instructions with predicate xGT or xLT and EQ */ -bool SplitComparesTransform::simplifyCompares(Module &M) { - - LLVMContext & C = M.getContext(); - std::vector<Instruction *> icomps; - IntegerType * Int1Ty = IntegerType::getInt1Ty(C); - - /* iterate over all functions, bbs and instruction and add - * all integer comparisons with >= and <= predicates to the icomps vector */ - for (auto &F : M) { - - if (!isInInstrumentList(&F)) continue; +/// This function splits ICMP instructions with xGE or xLE predicates into two +/// ICMP instructions with predicate xGT or xLT and EQ +bool SplitComparesTransform::simplifyOrEqualsCompare(CmpInst * IcmpInst, + Module & M, + CmpWorklist &worklist) { + LLVMContext &C = M.getContext(); + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); - for (auto &BB : F) { + /* find out what the new predicate is going to be */ + auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); + if (!cmp_inst) { return false; } - for (auto &IN : BB) { + BasicBlock *bb = IcmpInst->getParent(); - CmpInst *selectcmpInst = nullptr; + auto op0 = IcmpInst->getOperand(0); + auto op1 = IcmpInst->getOperand(1); - if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + CmpInst::Predicate pred = cmp_inst->getPredicate(); + CmpInst::Predicate new_pred; - if (selectcmpInst->getPredicate() == CmpInst::ICMP_UGE || - selectcmpInst->getPredicate() == CmpInst::ICMP_SGE || - selectcmpInst->getPredicate() == CmpInst::ICMP_ULE || - selectcmpInst->getPredicate() == CmpInst::ICMP_SLE) { + switch (pred) { + case CmpInst::ICMP_UGE: + new_pred = CmpInst::ICMP_UGT; + break; + case CmpInst::ICMP_SGE: + new_pred = CmpInst::ICMP_SGT; + break; + case CmpInst::ICMP_ULE: + new_pred = CmpInst::ICMP_ULT; + break; + case CmpInst::ICMP_SLE: + new_pred = CmpInst::ICMP_SLT; + break; + default: // keep the compiler happy + return false; + } - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); + /* split before the icmp instruction */ + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + + /* the old bb now contains a unconditional jump to the new one (end_bb) + * we need to delete it later */ + + /* create the ICMP instruction with new_pred and add it to the old basic + * block bb it is now at the position where the old IcmpInst was */ + CmpInst *icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); + bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), icmp_np); + + /* create a new basic block which holds the new EQ icmp */ + CmpInst *icmp_eq; + /* insert middle_bb before end_bb */ + BasicBlock *middle_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1); + middle_bb->getInstList().push_back(icmp_eq); + /* add an unconditional branch to the end of middle_bb with destination + * end_bb */ + BranchInst::Create(end_bb, middle_bb); + + /* replace the uncond branch with a conditional one, which depends on the + * new_pred icmp. True goes to end, false to the middle (injected) bb */ + auto term = bb->getTerminator(); + BranchInst::Create(end_bb, middle_bb, icmp_np, bb); + term->eraseFromParent(); + + /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI + * inst to wire up the loose ends */ + PHINode *PN = PHINode::Create(Int1Ty, 2, ""); + /* the first result depends on the outcome of icmp_eq */ + PN->addIncoming(icmp_eq, middle_bb); + /* if the source was the original bb we know that the icmp_np yielded true + * hence we can hardcode this value */ + PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + /* replace the old IcmpInst with our new and shiny PHI inst */ + BasicBlock::iterator ii(IcmpInst); + ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + + worklist.push_back(icmp_np); + worklist.push_back(icmp_eq); - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + return true; +} - /* this is probably not needed but we do it anyway */ - if (!intTyOp0 || !intTyOp1) { continue; } +/// Simplify a signed comparison operator by splitting it into a unsigned and +/// bit comparison. add all resulting comparisons to +/// the worklist passed as a reference. +bool SplitComparesTransform::simplifySignedCompare(CmpInst *IcmpInst, Module &M, + CmpWorklist &worklist) { + LLVMContext &C = M.getContext(); + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); - icomps.push_back(selectcmpInst); + BasicBlock *bb = IcmpInst->getParent(); - } + auto op0 = IcmpInst->getOperand(0); + auto op1 = IcmpInst->getOperand(1); - } + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + if (!intTyOp0) { return false; } + unsigned bitw = intTyOp0->getBitWidth(); + IntegerType *IntType = IntegerType::get(C, bitw); - } + /* get the new predicate */ + auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); + if (!cmp_inst) { return false; } + auto pred = cmp_inst->getPredicate(); + CmpInst::Predicate new_pred; - } + if (pred == CmpInst::ICMP_SGT) { + new_pred = CmpInst::ICMP_UGT; + } else { + new_pred = CmpInst::ICMP_ULT; } - if (!icomps.size()) { return false; } - - for (auto &IcmpInst : icomps) { - - BasicBlock *bb = IcmpInst->getParent(); - - auto op0 = IcmpInst->getOperand(0); - auto op1 = IcmpInst->getOperand(1); - - /* find out what the new predicate is going to be */ - auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); - if (!cmp_inst) { continue; } - auto pred = cmp_inst->getPredicate(); - CmpInst::Predicate new_pred; + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + + /* create a 1 bit compare for the sign bit. to do this shift and trunc + * the original operands so only the first bit remains.*/ + Value *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit; + + IRBuilder<> IRB(bb->getTerminator()); + s_op0 = IRB.CreateLShr(op0, ConstantInt::get(IntType, bitw - 1)); + t_op0 = IRB.CreateTruncOrBitCast(s_op0, Int1Ty); + s_op1 = IRB.CreateLShr(op1, ConstantInt::get(IntType, bitw - 1)); + t_op1 = IRB.CreateTruncOrBitCast(s_op1, Int1Ty); + /* compare of the sign bits */ + icmp_sign_bit = IRB.CreateCmp(CmpInst::ICMP_EQ, t_op0, t_op1); + + /* create a new basic block which is executed if the signedness bit is + * different */ + CmpInst * icmp_inv_sig_cmp; + BasicBlock *sign_bb = + BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb); + if (pred == CmpInst::ICMP_SGT) { + /* if we check for > and the op0 positive and op1 negative then the final + * result is true. if op0 negative and op1 pos, the cmp must result + * in false + */ + icmp_inv_sig_cmp = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1); - switch (pred) { + } else { + /* just the inverse of the above statement */ + icmp_inv_sig_cmp = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1); + } - case CmpInst::ICMP_UGE: - new_pred = CmpInst::ICMP_UGT; - break; - case CmpInst::ICMP_SGE: - new_pred = CmpInst::ICMP_SGT; - break; - case CmpInst::ICMP_ULE: - new_pred = CmpInst::ICMP_ULT; - break; - case CmpInst::ICMP_SLE: - new_pred = CmpInst::ICMP_SLT; - break; - default: // keep the compiler happy - continue; + sign_bb->getInstList().push_back(icmp_inv_sig_cmp); + BranchInst::Create(end_bb, sign_bb); - } + /* create a new bb which is executed if signedness is equal */ + CmpInst * icmp_usign_cmp; + BasicBlock *middle_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + /* we can do a normal unsigned compare now */ + icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); - /* split before the icmp instruction */ - BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + middle_bb->getInstList().push_back(icmp_usign_cmp); + BranchInst::Create(end_bb, middle_bb); - /* the old bb now contains a unconditional jump to the new one (end_bb) - * we need to delete it later */ + auto term = bb->getTerminator(); + /* if the sign is eq do a normal unsigned cmp, else we have to check the + * signedness bit */ + BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb); + term->eraseFromParent(); - /* create the ICMP instruction with new_pred and add it to the old basic - * block bb it is now at the position where the old IcmpInst was */ - Instruction *icmp_np; - icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - icmp_np); + PHINode *PN = PHINode::Create(Int1Ty, 2, ""); - /* create a new basic block which holds the new EQ icmp */ - Instruction *icmp_eq; - /* insert middle_bb before end_bb */ - BasicBlock *middle_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1); - middle_bb->getInstList().push_back(icmp_eq); - /* add an unconditional branch to the end of middle_bb with destination - * end_bb */ - BranchInst::Create(end_bb, middle_bb); + PN->addIncoming(icmp_usign_cmp, middle_bb); + PN->addIncoming(icmp_inv_sig_cmp, sign_bb); - /* replace the uncond branch with a conditional one, which depends on the - * new_pred icmp. True goes to end, false to the middle (injected) bb */ - auto term = bb->getTerminator(); - BranchInst::Create(end_bb, middle_bb, icmp_np, bb); - term->eraseFromParent(); + BasicBlock::iterator ii(IcmpInst); + ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); - /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI - * inst to wire up the loose ends */ - PHINode *PN = PHINode::Create(Int1Ty, 2, ""); - /* the first result depends on the outcome of icmp_eq */ - PN->addIncoming(icmp_eq, middle_bb); - /* if the source was the original bb we know that the icmp_np yielded true - * hence we can hardcode this value */ - PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); - /* replace the old IcmpInst with our new and shiny PHI inst */ - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + // save for later + worklist.push_back(icmp_usign_cmp); - } + // signed comparisons are not supported by the splitting code, so we must not + // add it to the worklist. + // worklist.push_back(icmp_inv_sig_cmp); return true; - } -/* this function transforms signed compares to equivalent unsigned compares */ -bool SplitComparesTransform::simplifyIntSignedness(Module &M) { - - LLVMContext & C = M.getContext(); - std::vector<Instruction *> icomps; - IntegerType * Int1Ty = IntegerType::getInt1Ty(C); - - /* iterate over all functions, bbs and instructions and add - * all signed compares to icomps vector */ - for (auto &F : M) { - - if (!isInInstrumentList(&F)) continue; +bool SplitComparesTransform::splitCompare(CmpInst *cmp_inst, Module &M, + CmpWorklist &worklist) { + auto pred = cmp_inst->getPredicate(); + switch (pred) { + case CmpInst::ICMP_EQ: + case CmpInst::ICMP_NE: + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_ULT: + break; + default: + // unsupported predicate! + return false; + } - for (auto &BB : F) { + auto op0 = cmp_inst->getOperand(0); + auto op1 = cmp_inst->getOperand(1); - for (auto &IN : BB) { + // get bitwidth by checking the bitwidth of the first operator + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + if (!intTyOp0) { + // not an integer type + return false; + } - CmpInst *selectcmpInst = nullptr; + unsigned bitw = intTyOp0->getBitWidth(); + if (bitw == target_bitwidth) { + // already the target bitwidth so we have to do nothing here. + return true; + } - if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + LLVMContext &C = M.getContext(); + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); + BasicBlock * bb = cmp_inst->getParent(); + IntegerType *OldIntType = IntegerType::get(C, bitw); + IntegerType *NewIntType = IntegerType::get(C, bitw / 2); + BasicBlock * end_bb = bb->splitBasicBlock(BasicBlock::iterator(cmp_inst)); + CmpInst * icmp_high, *icmp_low; - if (selectcmpInst->getPredicate() == CmpInst::ICMP_SGT || - selectcmpInst->getPredicate() == CmpInst::ICMP_SLT) { + /* create the comparison of the top halves of the original operands */ + Value *s_op0, *op0_high, *s_op1, *op1_high; - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); + IRBuilder<> IRB(bb->getTerminator()); - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + s_op0 = IRB.CreateBinOp(Instruction::LShr, op0, + ConstantInt::get(OldIntType, bitw / 2)); + op0_high = IRB.CreateTruncOrBitCast(s_op0, NewIntType); - /* see above */ - if (!intTyOp0 || !intTyOp1) { continue; } + s_op1 = IRB.CreateBinOp(Instruction::LShr, op1, + ConstantInt::get(OldIntType, bitw / 2)); + op1_high = IRB.CreateTruncOrBitCast(s_op1, NewIntType); + icmp_high = cast<CmpInst>(IRB.CreateICmp(pred, op0_high, op1_high)); - /* i think this is not possible but to lazy to look it up */ - if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { + PHINode *PN = nullptr; - continue; + /* now we have to destinguish between == != and > < */ + switch (pred) { + case CmpInst::ICMP_EQ: + case CmpInst::ICMP_NE: { + /* transformation for == and != icmps */ - } + /* create a compare for the lower half of the original operands */ + BasicBlock *cmp_low_bb = + BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb); - icomps.push_back(selectcmpInst); + Value *op0_low, *op1_low; + IRBuilder<> Builder(cmp_low_bb); - } + op0_low = Builder.CreateTrunc(op0, NewIntType); + op1_low = Builder.CreateTrunc(op1, NewIntType); + icmp_low = cast<CmpInst>(Builder.CreateICmp(pred, op0_low, op1_low)); - } + BranchInst::Create(end_bb, cmp_low_bb); + /* dependent on the cmp of the high parts go to the end or go on with + * the comparison */ + auto term = bb->getTerminator(); + BranchInst *br = nullptr; + if (pred == CmpInst::ICMP_EQ) { + br = BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb); + } else { + /* CmpInst::ICMP_NE */ + br = BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb); } + term->eraseFromParent(); + /* create the PHI and connect the edges accordingly */ + PN = PHINode::Create(Int1Ty, 2, ""); + PN->addIncoming(icmp_low, cmp_low_bb); + Value *val = nullptr; + if (pred == CmpInst::ICMP_EQ) { + val = ConstantInt::get(Int1Ty, 0); + } else { + /* CmpInst::ICMP_NE */ + val = ConstantInt::get(Int1Ty, 1); + } + PN->addIncoming(val, icmp_high->getParent()); + break; } + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_ULT: { + /* transformations for < and > */ - } - - if (!icomps.size()) { return false; } - - for (auto &IcmpInst : icomps) { - - BasicBlock *bb = IcmpInst->getParent(); - - auto op0 = IcmpInst->getOperand(0); - auto op1 = IcmpInst->getOperand(1); + /* create a basic block which checks for the inverse predicate. + * if this is true we can go to the end if not we have to go to the + * bb which checks the lower half of the operands */ + Instruction *op0_low, *op1_low; + CmpInst *icmp_inv_cmp = nullptr; + BasicBlock * inv_cmp_bb = + BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb); + if (pred == CmpInst::ICMP_UGT) { + icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, + op0_high, op1_high); - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - if (!intTyOp0) { continue; } - unsigned bitw = intTyOp0->getBitWidth(); - IntegerType *IntType = IntegerType::get(C, bitw); + } else { + icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, + op0_high, op1_high); + } - /* get the new predicate */ - auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); - if (!cmp_inst) { continue; } - auto pred = cmp_inst->getPredicate(); - CmpInst::Predicate new_pred; + inv_cmp_bb->getInstList().push_back(icmp_inv_cmp); + worklist.push_back(icmp_inv_cmp); - if (pred == CmpInst::ICMP_SGT) { + auto term = bb->getTerminator(); + term->eraseFromParent(); + BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb); - new_pred = CmpInst::ICMP_UGT; + /* create a bb which handles the cmp of the lower halves */ + BasicBlock *cmp_low_bb = + BasicBlock::Create(C, "" /*"injected"*/, end_bb->getParent(), end_bb); + op0_low = new TruncInst(op0, NewIntType); + cmp_low_bb->getInstList().push_back(op0_low); + op1_low = new TruncInst(op1, NewIntType); + cmp_low_bb->getInstList().push_back(op1_low); - } else { + icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low); + cmp_low_bb->getInstList().push_back(icmp_low); + BranchInst::Create(end_bb, cmp_low_bb); - new_pred = CmpInst::ICMP_ULT; + BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb); + PN = PHINode::Create(Int1Ty, 3); + PN->addIncoming(icmp_low, cmp_low_bb); + PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb); + break; } + default: + return false; + } - BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); - - /* create a 1 bit compare for the sign bit. to do this shift and trunc - * the original operands so only the first bit remains.*/ - Instruction *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit; - - s_op0 = BinaryOperator::Create(Instruction::LShr, op0, - ConstantInt::get(IntType, bitw - 1)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0); - t_op0 = new TruncInst(s_op0, Int1Ty); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op0); + BasicBlock::iterator ii(cmp_inst); + ReplaceInstWithInst(cmp_inst->getParent()->getInstList(), ii, PN); - s_op1 = BinaryOperator::Create(Instruction::LShr, op1, - ConstantInt::get(IntType, bitw - 1)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1); - t_op1 = new TruncInst(s_op1, Int1Ty); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_op1); + // We split the comparison into low and high. If this isn't our target + // bitwidth we recursivly split the low and high parts again until we have + // target bitwidth. + if ((bitw / 2) > target_bitwidth) { + worklist.push_back(icmp_high); + worklist.push_back(icmp_low); + } - /* compare of the sign bits */ - icmp_sign_bit = - CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - icmp_sign_bit); + return true; +} - /* create a new basic block which is executed if the signedness bit is - * different */ - Instruction *icmp_inv_sig_cmp; - BasicBlock * sign_bb = - BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb); - if (pred == CmpInst::ICMP_SGT) { +bool SplitComparesTransform::simplifyAndSplit(CmpInst *I, Module &M) { + CmpWorklist worklist; - /* if we check for > and the op0 positive and op1 negative then the final - * result is true. if op0 negative and op1 pos, the cmp must result - * in false - */ - icmp_inv_sig_cmp = - CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1); + auto op0 = I->getOperand(0); + auto op1 = I->getOperand(1); + if (!op0 || !op1) { return false; } + auto op0Ty = dyn_cast<IntegerType>(op0->getType()); + if (!op0Ty || !isa<IntegerType>(op1->getType())) { return true; } - } else { + unsigned bitw = op0Ty->getBitWidth(); - /* just the inverse of the above statement */ - icmp_inv_sig_cmp = - CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1); +#ifdef VERIFY_TOO_MUCH + auto F = I->getParent()->getParent(); +#endif + // we run the comparison simplification on all compares regardless of their + // bitwidth. + if (I->getPredicate() == CmpInst::ICMP_UGE || + I->getPredicate() == CmpInst::ICMP_SGE || + I->getPredicate() == CmpInst::ICMP_ULE || + I->getPredicate() == CmpInst::ICMP_SLE) { + if (!simplifyOrEqualsCompare(I, M, worklist)) { + reportError( + "Failed to simplify inequality or equals comparison " + "(UGE,SGE,ULE,SLE)", + I, M); } + } else if (I->getPredicate() == CmpInst::ICMP_SGT || + I->getPredicate() == CmpInst::ICMP_SLT) { + if (!simplifySignedCompare(I, M, worklist)) { + reportError("Failed to simplify signed comparison (SGT,SLT)", I, M); + } + } - sign_bb->getInstList().push_back(icmp_inv_sig_cmp); - BranchInst::Create(end_bb, sign_bb); - - /* create a new bb which is executed if signedness is equal */ - Instruction *icmp_usign_cmp; - BasicBlock * middle_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - /* we can do a normal unsigned compare now */ - icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); - middle_bb->getInstList().push_back(icmp_usign_cmp); - BranchInst::Create(end_bb, middle_bb); - - auto term = bb->getTerminator(); - /* if the sign is eq do a normal unsigned cmp, else we have to check the - * signedness bit */ - BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb); - term->eraseFromParent(); - - PHINode *PN = PHINode::Create(Int1Ty, 2, ""); - - PN->addIncoming(icmp_usign_cmp, middle_bb); - PN->addIncoming(icmp_inv_sig_cmp, sign_bb); +#ifdef VERIFY_TOO_MUCH + if (verifyFunction(*F, &errs())) { + reportError("simpliyfing compare lead to broken function", nullptr, M); + } +#endif - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + // the simplification methods replace the original CmpInst and push the + // resulting new CmpInst into the worklist. If the worklist is empty then + // we only have to split the original CmpInst. + if (worklist.size() == 0) { worklist.push_back(I); } + + while (!worklist.empty()) { + CmpInst *cmp = worklist.pop_back_val(); + // we split the simplified compares into comparisons with smaller bitwidths + // if they are larger than our target_bitwidth. + if (bitw > target_bitwidth) { + if (!splitCompare(cmp, M, worklist)) { + reportError("Failed to split comparison", cmp, M); + } +#ifdef VERIFY_TOO_MUCH + if (verifyFunction(*F, &errs())) { + reportError("splitting compare lead to broken function", nullptr, M); + } +#endif + } } + count++; return true; - } size_t SplitComparesTransform::nextPowerOfTwo(size_t in) { - --in; in |= in >> 1; in |= in >> 2; @@ -520,12 +678,10 @@ size_t SplitComparesTransform::nextPowerOfTwo(size_t in) { // in |= in >> 8; // in |= in >> 16; return in + 1; - } /* splits fcmps into two nested fcmps with sign compare and the rest */ size_t SplitComparesTransform::splitFPCompares(Module &M) { - size_t count = 0; LLVMContext &C = M.getContext(); @@ -537,13 +693,9 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { /* define unions with floating point and (sign, exponent, mantissa) triples */ if (dl.isLittleEndian()) { - } else if (dl.isBigEndian()) { - } else { - return count; - } #endif @@ -553,17 +705,13 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { /* get all EQ, NE, GT, and LT fcmps. if the other two * functions were executed only these four predicates should exist */ for (auto &F : M) { - if (!isInInstrumentList(&F)) continue; for (auto &BB : F) { - for (auto &IN : BB) { - CmpInst *selectcmpInst = nullptr; if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { - if (selectcmpInst->getPredicate() == CmpInst::FCMP_OEQ || selectcmpInst->getPredicate() == CmpInst::FCMP_UEQ || selectcmpInst->getPredicate() == CmpInst::FCMP_ONE || @@ -572,7 +720,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { selectcmpInst->getPredicate() == CmpInst::FCMP_OGT || selectcmpInst->getPredicate() == CmpInst::FCMP_ULT || selectcmpInst->getPredicate() == CmpInst::FCMP_OLT) { - auto op0 = selectcmpInst->getOperand(0); auto op1 = selectcmpInst->getOperand(1); @@ -584,15 +731,10 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; } fcomps.push_back(selectcmpInst); - } - } - } - } - } if (!fcomps.size()) { return count; } @@ -600,7 +742,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { IntegerType *Int1Ty = IntegerType::getInt1Ty(C); for (auto &FcmpInst : fcomps) { - BasicBlock *bb = FcmpInst->getParent(); auto op0 = FcmpInst->getOperand(0); @@ -725,7 +866,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { BasicBlock::iterator(signequal_bb->getTerminator()), t_e1); if (sizeInBits - precision < exTySizeBytes * 8) { - m_e0 = BinaryOperator::Create( Instruction::And, t_e0, ConstantInt::get(t_e0->getType(), mask_exponent)); @@ -738,10 +878,8 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { BasicBlock::iterator(signequal_bb->getTerminator()), m_e1); } else { - m_e0 = t_e0; m_e1 = t_e1; - } /* compare the exponents of the operands */ @@ -749,7 +887,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { Instruction *icmp_exponent_result; BasicBlock * signequal2_bb = signequal_bb; switch (FcmpInst->getPredicate()) { - case CmpInst::FCMP_UEQ: case CmpInst::FCMP_OEQ: icmp_exponent_result = @@ -819,7 +956,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { break; default: continue; - } signequal2_bb->getInstList().insert( @@ -827,11 +963,9 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { icmp_exponent_result); { - term = signequal2_bb->getTerminator(); switch (FcmpInst->getPredicate()) { - case CmpInst::FCMP_UEQ: case CmpInst::FCMP_OEQ: /* if the exponents are satifying the compare do a fraction cmp in @@ -854,11 +988,9 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { break; default: continue; - } term->eraseFromParent(); - } /* isolate the mantissa aka fraction */ @@ -866,7 +998,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { bool needTrunc = IntFractionTy->getPrimitiveSizeInBits() < op_size; if (precision - 1 < frTySizeBytes * 8) { - Instruction *m_f0, *m_f1; m_f0 = BinaryOperator::Create( Instruction::And, b_op0, @@ -880,7 +1011,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { BasicBlock::iterator(middle_bb->getTerminator()), m_f1); if (needTrunc) { - t_f0 = new TruncInst(m_f0, IntFractionTy); t_f1 = new TruncInst(m_f1, IntFractionTy); middle_bb->getInstList().insert( @@ -889,16 +1019,12 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { BasicBlock::iterator(middle_bb->getTerminator()), t_f1); } else { - t_f0 = m_f0; t_f1 = m_f1; - } } else { - if (needTrunc) { - t_f0 = new TruncInst(b_op0, IntFractionTy); t_f1 = new TruncInst(b_op1, IntFractionTy); middle_bb->getInstList().insert( @@ -907,12 +1033,9 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { BasicBlock::iterator(middle_bb->getTerminator()), t_f1); } else { - t_f0 = b_op0; t_f1 = b_op1; - } - } /* compare the fractions of the operands */ @@ -920,7 +1043,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { BasicBlock * middle2_bb = middle_bb; PHINode * PN2 = nullptr; switch (FcmpInst->getPredicate()) { - case CmpInst::FCMP_UEQ: case CmpInst::FCMP_OEQ: icmp_fraction_result = @@ -943,7 +1065,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { case CmpInst::FCMP_UGT: case CmpInst::FCMP_OLT: case CmpInst::FCMP_ULT: { - Instruction *icmp_fraction_result2; middle2_bb = middle_bb->splitBasicBlock( @@ -956,7 +1077,6 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { if (FcmpInst->getPredicate() == CmpInst::FCMP_OGT || FcmpInst->getPredicate() == CmpInst::FCMP_UGT) { - negative_bb->getInstList().push_back( icmp_fraction_result = CmpInst::Create( Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1)); @@ -965,14 +1085,12 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1)); } else { - negative_bb->getInstList().push_back( icmp_fraction_result = CmpInst::Create( Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1)); positive_bb->getInstList().push_back( icmp_fraction_result2 = CmpInst::Create( Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1)); - } BranchInst::Create(middle2_bb, negative_bb); @@ -992,13 +1110,11 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { default: continue; - } PHINode *PN = PHINode::Create(Int1Ty, 3, ""); switch (FcmpInst->getPredicate()) { - case CmpInst::FCMP_UEQ: case CmpInst::FCMP_OEQ: /* unequal signs cannot be equal values */ @@ -1037,262 +1153,36 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { break; default: continue; - } BasicBlock::iterator ii(FcmpInst); ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN); ++count; - } return count; - -} - -/* splits icmps of size bitw into two nested icmps with bitw/2 size each */ -size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) { - - size_t count = 0; - - LLVMContext &C = M.getContext(); - - IntegerType *Int1Ty = IntegerType::getInt1Ty(C); - IntegerType *OldIntType = IntegerType::get(C, bitw); - IntegerType *NewIntType = IntegerType::get(C, bitw / 2); - - std::vector<Instruction *> icomps; - - if (bitw % 2) { return 0; } - - /* not supported yet */ - if (bitw > 64) { return 0; } - - /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the - * functions simplifyCompares() and simplifyIntSignedness() - * were executed only these four predicates should exist */ - for (auto &F : M) { - - if (!isInInstrumentList(&F)) continue; - - for (auto &BB : F) { - - for (auto &IN : BB) { - - CmpInst *selectcmpInst = nullptr; - - if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { - - if (selectcmpInst->getPredicate() == CmpInst::ICMP_EQ || - selectcmpInst->getPredicate() == CmpInst::ICMP_NE || - selectcmpInst->getPredicate() == CmpInst::ICMP_UGT || - selectcmpInst->getPredicate() == CmpInst::ICMP_ULT) { - - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); - - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); - - if (!intTyOp0 || !intTyOp1) { continue; } - - /* check if the bitwidths are the one we are looking for */ - if (intTyOp0->getBitWidth() != bitw || - intTyOp1->getBitWidth() != bitw) { - - continue; - - } - - icomps.push_back(selectcmpInst); - - } - - } - - } - - } - - } - - if (!icomps.size()) { return 0; } - - for (auto &IcmpInst : icomps) { - - BasicBlock *bb = IcmpInst->getParent(); - - auto op0 = IcmpInst->getOperand(0); - auto op1 = IcmpInst->getOperand(1); - - auto cmp_inst = dyn_cast<CmpInst>(IcmpInst); - if (!cmp_inst) { continue; } - auto pred = cmp_inst->getPredicate(); - - BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); - - /* create the comparison of the top halves of the original operands */ - Instruction *s_op0, *op0_high, *s_op1, *op1_high, *icmp_high; - - s_op0 = BinaryOperator::Create(Instruction::LShr, op0, - ConstantInt::get(OldIntType, bitw / 2)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op0); - op0_high = new TruncInst(s_op0, NewIntType); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - op0_high); - - s_op1 = BinaryOperator::Create(Instruction::LShr, op1, - ConstantInt::get(OldIntType, bitw / 2)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_op1); - op1_high = new TruncInst(s_op1, NewIntType); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - op1_high); - - icmp_high = CmpInst::Create(Instruction::ICmp, pred, op0_high, op1_high); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), - icmp_high); - - /* now we have to destinguish between == != and > < */ - if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) { - - /* transformation for == and != icmps */ - - /* create a compare for the lower half of the original operands */ - Instruction *op0_low, *op1_low, *icmp_low; - BasicBlock * cmp_low_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - - op0_low = new TruncInst(op0, NewIntType); - cmp_low_bb->getInstList().push_back(op0_low); - - op1_low = new TruncInst(op1, NewIntType); - cmp_low_bb->getInstList().push_back(op1_low); - - icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low); - cmp_low_bb->getInstList().push_back(icmp_low); - BranchInst::Create(end_bb, cmp_low_bb); - - /* dependent on the cmp of the high parts go to the end or go on with - * the comparison */ - auto term = bb->getTerminator(); - if (pred == CmpInst::ICMP_EQ) { - - BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb); - - } else { - - /* CmpInst::ICMP_NE */ - BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb); - - } - - term->eraseFromParent(); - - /* create the PHI and connect the edges accordingly */ - PHINode *PN = PHINode::Create(Int1Ty, 2, ""); - PN->addIncoming(icmp_low, cmp_low_bb); - if (pred == CmpInst::ICMP_EQ) { - - PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb); - - } else { - - /* CmpInst::ICMP_NE */ - PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); - - } - - /* replace the old icmp with the new PHI */ - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); - - } else { - - /* CmpInst::ICMP_UGT and CmpInst::ICMP_ULT */ - /* transformations for < and > */ - - /* create a basic block which checks for the inverse predicate. - * if this is true we can go to the end if not we have to go to the - * bb which checks the lower half of the operands */ - Instruction *icmp_inv_cmp, *op0_low, *op1_low, *icmp_low; - BasicBlock * inv_cmp_bb = - BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb); - if (pred == CmpInst::ICMP_UGT) { - - icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, - op0_high, op1_high); - - } else { - - icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, - op0_high, op1_high); - - } - - inv_cmp_bb->getInstList().push_back(icmp_inv_cmp); - - auto term = bb->getTerminator(); - term->eraseFromParent(); - BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb); - - /* create a bb which handles the cmp of the lower halves */ - BasicBlock *cmp_low_bb = - BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); - op0_low = new TruncInst(op0, NewIntType); - cmp_low_bb->getInstList().push_back(op0_low); - op1_low = new TruncInst(op1, NewIntType); - cmp_low_bb->getInstList().push_back(op1_low); - - icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low); - cmp_low_bb->getInstList().push_back(icmp_low); - BranchInst::Create(end_bb, cmp_low_bb); - - BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb); - - PHINode *PN = PHINode::Create(Int1Ty, 3); - PN->addIncoming(icmp_low, cmp_low_bb); - PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); - PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb); - - BasicBlock::iterator ii(IcmpInst); - ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); - - } - - ++count; - - } - - return count; - } bool SplitComparesTransform::runOnModule(Module &M) { - - int bitw = 64; - size_t count = 0; - char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW"); if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW"); - if (bitw_env) { bitw = atoi(bitw_env); } + if (bitw_env) { target_bitwidth = atoi(bitw_env); } enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL; if ((isatty(2) && getenv("AFL_QUIET") == NULL) || getenv("AFL_DEBUG") != NULL) { + errs() << "Split-compare-pass by laf.intel@gmail.com, extended by " + "heiko@hexco.de (splitting icmp to " + << target_bitwidth << " bit)\n"; - printf( - "Split-compare-pass by laf.intel@gmail.com, extended by " - "heiko@hexco.de\n"); + if (getenv("AFL_DEBUG") != NULL && !debug) { debug = 1; } } else { - be_quiet = 1; - } if (enableFPSplit) { - count = splitFPCompares(M); /* @@ -1305,60 +1195,55 @@ bool SplitComparesTransform::runOnModule(Module &M) { */ simplifyFPCompares(M); - } - simplifyCompares(M); - - simplifyIntSignedness(M); - - switch (bitw) { + std::vector<CmpInst *> worklist; + /* iterate over all functions, bbs and instruction search for all integer + * compare instructions. Save them into the worklist for later. */ + for (auto &F : M) { + if (!isInInstrumentList(&F)) continue; - case 64: - count += splitIntCompares(M, bitw); - if (debug) - errs() << "Split-integer-compare-pass " << bitw << "bit: " << count - << " split\n"; - bitw >>= 1; -#if LLVM_VERSION_MAJOR > 3 || \ - (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7) - [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ -#endif - case 32: - count += splitIntCompares(M, bitw); - if (debug) - errs() << "Split-integer-compare-pass " << bitw << "bit: " << count - << " split\n"; - bitw >>= 1; -#if LLVM_VERSION_MAJOR > 3 || \ - (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 7) - [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ -#endif - case 16: - count += splitIntCompares(M, bitw); - if (debug) - errs() << "Split-integer-compare-pass " << bitw << "bit: " << count - << " split\n"; - // bitw >>= 1; - break; + for (auto &BB : F) { + for (auto &IN : BB) { + if (auto CI = dyn_cast<CmpInst>(&IN)) { + auto op0 = CI->getOperand(0); + auto op1 = CI->getOperand(1); + if (!op0 || !op1) { return false; } + auto iTy1 = dyn_cast<IntegerType>(op0->getType()); + if (iTy1 && isa<IntegerType>(op1->getType())) { + unsigned bitw = iTy1->getBitWidth(); + if (isSupportedBitWidth(bitw)) { worklist.push_back(CI); } + } + } + } + } + } - default: - // if (!be_quiet) errs() << "NOT Running split-compare-pass \n"; - return false; - break; + // now that we have a list of all integer comparisons we can start replacing + // them with the splitted alternatives. + for (auto CI : worklist) { + simplifyAndSplit(CI, M); + } + bool brokenDebug = false; + if (verifyModule(M, &errs(), &brokenDebug)) { + reportError( + "Module Verifier failed! Consider reporting a bug with the AFL++ " + "project.", + nullptr, M); } - verifyModule(M); + if (brokenDebug) { + reportError("Module Verifier reported broken Debug Infos - Stripping!", + nullptr, M); + StripDebugInfo(M); + } return true; - } static void registerSplitComparesPass(const PassManagerBuilder &, legacy::PassManagerBase &PM) { - PM.add(new SplitComparesTransform()); - } static RegisterStandardPasses RegisterSplitComparesPass( @@ -1373,3 +1258,7 @@ static RegisterStandardPasses RegisterSplitComparesTransPassLTO( registerSplitComparesPass); #endif +static RegisterPass<SplitComparesTransform> X("splitcompares", + "AFL++ split compares", + true /* Only looks at CFG */, + true /* Analysis Pass */); diff --git a/test/test-llvm.sh b/test/test-llvm.sh index f902ffc5..8090e176 100755 --- a/test/test-llvm.sh +++ b/test/test-llvm.sh @@ -186,6 +186,29 @@ test -e ../afl-clang-fast -a -e ../split-switches-pass.so && { } rm -f test-instr.plain + $ECHO "$GREY[*] llvm_mode laf-intel/compcov testing splitting integer types (this might take some time)" + for testcase in ./test-int_cases.c ./test-uint_cases.c; do + for I in char short int long "long long"; do + for BITS in 8 16 32 64; do + bin="$testcase-split-$I-$BITS.compcov" + AFL_LLVM_INSTRUMENT=AFL AFL_DEBUG=1 AFL_LLVM_LAF_SPLIT_COMPARES_BITW=$BITS AFL_LLVM_LAF_SPLIT_COMPARES=1 ../afl-clang-fast -DINT_TYPE="$I" -o "$bin" "$testcase" > test.out 2>&1; + if ! test -e "$bin"; then + cat test.out + $ECHO "$RED[!] llvm_mode laf-intel/compcov integer splitting failed! ($testcase with type $I split to $BITS)!"; + CODE=1 + break + fi + if ! "$bin"; then + $ECHO "$RED[!] llvm_mode laf-intel/compcov integer splitting resulted in miscompilation (type $I split to $BITS)!"; + CODE=1 + break + fi + rm -f "$bin" test.out || true + done + done + done + rm -f test-int-split*.compcov test.out + AFL_LLVM_INSTRUMENT=AFL AFL_DEBUG=1 AFL_LLVM_LAF_SPLIT_SWITCHES=1 AFL_LLVM_LAF_TRANSFORM_COMPARES=1 AFL_LLVM_LAF_SPLIT_COMPARES=1 ../afl-clang-fast -o test-compcov.compcov test-compcov.c > test.out 2>&1 test -e test-compcov.compcov && test_compcov_binary_functionality ./test-compcov.compcov && { grep --binary-files=text -Eq " [ 123][0-9][0-9] location| [3-9][0-9] location" test.out && { diff --git a/test/test-uint_cases.c b/test/test-uint_cases.c index 8496cffe..a277e28a 100644 --- a/test/test-uint_cases.c +++ b/test/test-uint_cases.c @@ -1,16 +1,16 @@ /* - * compile with -DUINT_TYPE="unsigned char" - * or -DUINT_TYPE="unsigned short" - * or -DUINT_TYPE="unsigned int" - * or -DUINT_TYPE="unsigned long" - * or -DUINT_TYPE="unsigned long long" + * compile with -DINT_TYPE="char" + * or -DINT_TYPE="short" + * or -DINT_TYPE="int" + * or -DINT_TYPE="long" + * or -DINT_TYPE="long long" */ #include <assert.h> int main() { - volatile UINT_TYPE a, b; + volatile unsigned INT_TYPE a, b; a = 1; b = 8; @@ -21,7 +21,7 @@ int main() { assert((a != b)); assert(!(a == b)); - if ((UINT_TYPE)(~0) > 255) { + if ((INT_TYPE)(~0) > 255) { volatile unsigned short a, b; a = 256+2; b = 256+21; @@ -41,7 +41,7 @@ int main() { assert((a != b)); assert(!(a == b)); - if ((UINT_TYPE)(~0) > 65535) { + if ((INT_TYPE)(~0) > 65535) { volatile unsigned int a, b; a = 65536+2; b = 65536+21; @@ -62,7 +62,7 @@ int main() { assert(!(a == b)); } - if ((UINT_TYPE)(~0) > 4294967295) { + if ((INT_TYPE)(~0) > 4294967295) { volatile unsigned long a, b; a = 4294967296+2; b = 4294967296+21; @@ -93,7 +93,7 @@ int main() { assert((a != b)); assert(!(a == b)); - if ((UINT_TYPE)(~0) > 255) { + if ((INT_TYPE)(~0) > 255) { volatile unsigned short a, b; a = 256+2; b = 256+1; @@ -113,7 +113,7 @@ int main() { assert((a != b)); assert(!(a == b)); - if ((UINT_TYPE)(~0) > 65535) { + if ((INT_TYPE)(~0) > 65535) { volatile unsigned int a, b; a = 65536+2; b = 65536+1; @@ -133,7 +133,7 @@ int main() { assert((a != b)); assert(!(a == b)); - if ((UINT_TYPE)(~0) > 4294967295) { + if ((INT_TYPE)(~0) > 4294967295) { volatile unsigned long a, b; a = 4294967296+2; b = 4294967296+1; @@ -176,7 +176,7 @@ int main() { assert(!(a != b)); assert((a == b)); - if ((UINT_TYPE)(~0) > 255) { + if ((INT_TYPE)(~0) > 255) { volatile unsigned short a, b; a = 256+5; b = 256+5; @@ -187,7 +187,7 @@ int main() { assert(!(a != b)); assert((a == b)); - if ((UINT_TYPE)(~0) > 65535) { + if ((INT_TYPE)(~0) > 65535) { volatile unsigned int a, b; a = 65536+5; b = 65536+5; @@ -198,7 +198,7 @@ int main() { assert(!(a != b)); assert((a == b)); - if ((UINT_TYPE)(~0) > 4294967295) { + if ((INT_TYPE)(~0) > 4294967295) { volatile unsigned long a, b; a = 4294967296+5; b = 4294967296+5; |