diff options
Diffstat (limited to 'llvm_mode')
-rw-r--r-- | llvm_mode/LLVMInsTrim.so.cc | 546 | ||||
-rw-r--r-- | llvm_mode/MarkNodes.cc | 416 | ||||
-rw-r--r-- | llvm_mode/MarkNodes.h | 13 | ||||
-rw-r--r-- | llvm_mode/afl-clang-fast.c | 180 | ||||
-rw-r--r-- | llvm_mode/afl-llvm-pass.so.cc | 294 | ||||
-rw-r--r-- | llvm_mode/afl-llvm-rt.o.c | 55 | ||||
-rw-r--r-- | llvm_mode/compare-transform-pass.so.cc | 286 | ||||
-rw-r--r-- | llvm_mode/split-compares-pass.so.cc | 338 | ||||
-rw-r--r-- | llvm_mode/split-switches-pass.so.cc | 294 |
9 files changed, 1452 insertions, 970 deletions
diff --git a/llvm_mode/LLVMInsTrim.so.cc b/llvm_mode/LLVMInsTrim.so.cc index 95b52d48..4b5597e2 100644 --- a/llvm_mode/LLVMInsTrim.so.cc +++ b/llvm_mode/LLVMInsTrim.so.cc @@ -37,268 +37,349 @@ static cl::opt<bool> LoopHeadOpt("loophead", cl::desc("LoopHead"), cl::init(false)); namespace { - struct InsTrim : public ModulePass { - protected: - std::list<std::string> myWhitelist; +struct InsTrim : public ModulePass { - private: - std::mt19937 generator; - int total_instr = 0; + protected: + std::list<std::string> myWhitelist; - unsigned int genLabel() { - return generator() & (MAP_SIZE - 1); - } + private: + std::mt19937 generator; + int total_instr = 0; + + unsigned int genLabel() { + + return generator() & (MAP_SIZE - 1); + + } + + public: + static char ID; + InsTrim() : ModulePass(ID), generator(0) { - public: - static char ID; - InsTrim() : ModulePass(ID), generator(0) { - char* instWhiteListFilename = getenv("AFL_LLVM_WHITELIST"); - if (instWhiteListFilename) { - std::string line; - std::ifstream fileStream; - fileStream.open(instWhiteListFilename); - if (!fileStream) - report_fatal_error("Unable to open AFL_LLVM_WHITELIST"); + char *instWhiteListFilename = getenv("AFL_LLVM_WHITELIST"); + if (instWhiteListFilename) { + + std::string line; + std::ifstream fileStream; + fileStream.open(instWhiteListFilename); + if (!fileStream) report_fatal_error("Unable to open AFL_LLVM_WHITELIST"); + getline(fileStream, line); + while (fileStream) { + + myWhitelist.push_back(line); getline(fileStream, line); - while (fileStream) { - myWhitelist.push_back(line); - getline(fileStream, line); - } + } - } - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<DominatorTreeWrapperPass>(); } + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + + AU.addRequired<DominatorTreeWrapperPass>(); + + } + #if LLVM_VERSION_MAJOR < 4 - const char * + const char * #else - StringRef + StringRef #endif - getPassName() const override { - return "InstTrim Instrumentation"; - } + getPassName() const override { + + return "InstTrim Instrumentation"; + + } + + bool runOnModule(Module &M) override { + + char be_quiet = 0; + + if (isatty(2) && !getenv("AFL_QUIET")) { + + SAYF(cCYA "LLVMInsTrim" VERSION cRST " by csienslab\n"); + + } else + + be_quiet = 1; - bool runOnModule(Module &M) override { - char be_quiet = 0; - - if (isatty(2) && !getenv("AFL_QUIET")) { - SAYF(cCYA "LLVMInsTrim" VERSION cRST " by csienslab\n"); - } else be_quiet = 1; - #if LLVM_VERSION_MAJOR < 9 - char* neverZero_counters_str; - if ((neverZero_counters_str = getenv("AFL_LLVM_NOT_ZERO")) != NULL) - OKF("LLVM neverZero activated (by hexcoder)\n"); + char *neverZero_counters_str; + if ((neverZero_counters_str = getenv("AFL_LLVM_NOT_ZERO")) != NULL) + OKF("LLVM neverZero activated (by hexcoder)\n"); #endif - - if (getenv("AFL_LLVM_INSTRIM_LOOPHEAD") != NULL || getenv("LOOPHEAD") != NULL) { - LoopHeadOpt = true; - } - // this is our default - MarkSetOpt = true; - -/* // I dont think this makes sense to port into LLVMInsTrim - char* inst_ratio_str = getenv("AFL_INST_RATIO"); - unsigned int inst_ratio = 100; - if (inst_ratio_str) { - if (sscanf(inst_ratio_str, "%u", &inst_ratio) != 1 || !inst_ratio || inst_ratio > 100) - FATAL("Bad value of AFL_INST_RATIO (must be between 1 and 100)"); - } -*/ + if (getenv("AFL_LLVM_INSTRIM_LOOPHEAD") != NULL || + getenv("LOOPHEAD") != NULL) { + + LoopHeadOpt = true; + + } + + // this is our default + MarkSetOpt = true; + + /* // I dont think this makes sense to port into LLVMInsTrim + char* inst_ratio_str = getenv("AFL_INST_RATIO"); + unsigned int inst_ratio = 100; + if (inst_ratio_str) { + + if (sscanf(inst_ratio_str, "%u", &inst_ratio) != 1 || !inst_ratio || + inst_ratio > 100) FATAL("Bad value of AFL_INST_RATIO (must be between 1 + and 100)"); + + } + + */ - LLVMContext &C = M.getContext(); - IntegerType *Int8Ty = IntegerType::getInt8Ty(C); - IntegerType *Int32Ty = IntegerType::getInt32Ty(C); + LLVMContext &C = M.getContext(); + IntegerType *Int8Ty = IntegerType::getInt8Ty(C); + IntegerType *Int32Ty = IntegerType::getInt32Ty(C); - GlobalVariable *CovMapPtr = new GlobalVariable( + GlobalVariable *CovMapPtr = new GlobalVariable( M, PointerType::getUnqual(Int8Ty), false, GlobalValue::ExternalLinkage, nullptr, "__afl_area_ptr"); - GlobalVariable *OldPrev = new GlobalVariable( - M, Int32Ty, false, GlobalValue::ExternalLinkage, 0, "__afl_prev_loc", - 0, GlobalVariable::GeneralDynamicTLSModel, 0, false); + GlobalVariable *OldPrev = new GlobalVariable( + M, Int32Ty, false, GlobalValue::ExternalLinkage, 0, "__afl_prev_loc", 0, + GlobalVariable::GeneralDynamicTLSModel, 0, false); - u64 total_rs = 0; - u64 total_hs = 0; + u64 total_rs = 0; + u64 total_hs = 0; + + for (Function &F : M) { + + if (!F.size()) { continue; } + + if (!myWhitelist.empty()) { + + bool instrumentBlock = false; + DebugLoc Loc; + StringRef instFilename; + + for (auto &BB : F) { + + BasicBlock::iterator IP = BB.getFirstInsertionPt(); + IRBuilder<> IRB(&(*IP)); + if (!Loc) Loc = IP->getDebugLoc(); - for (Function &F : M) { - if (!F.size()) { - continue; } - if (!myWhitelist.empty()) { - bool instrumentBlock = false; - DebugLoc Loc; - StringRef instFilename; + if (Loc) { + + DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode()); + + unsigned int instLine = cDILoc->getLine(); + instFilename = cDILoc->getFilename(); + + if (instFilename.str().empty()) { + + /* If the original location is empty, try using the inlined location + */ + DILocation *oDILoc = cDILoc->getInlinedAt(); + if (oDILoc) { + + instFilename = oDILoc->getFilename(); + instLine = oDILoc->getLine(); + + } - for (auto &BB : F) { - BasicBlock::iterator IP = BB.getFirstInsertionPt(); - IRBuilder<> IRB(&(*IP)); - if (!Loc) - Loc = IP->getDebugLoc(); } - if ( Loc ) { - DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode()); + /* Continue only if we know where we actually are */ + if (!instFilename.str().empty()) { - unsigned int instLine = cDILoc->getLine(); - instFilename = cDILoc->getFilename(); + for (std::list<std::string>::iterator it = myWhitelist.begin(); + it != myWhitelist.end(); ++it) { - if (instFilename.str().empty()) { - /* If the original location is empty, try using the inlined location */ - DILocation *oDILoc = cDILoc->getInlinedAt(); - if (oDILoc) { - instFilename = oDILoc->getFilename(); - instLine = oDILoc->getLine(); - } - } + if (instFilename.str().length() >= it->length()) { + + if (instFilename.str().compare( + instFilename.str().length() - it->length(), + it->length(), *it) == 0) { + + instrumentBlock = true; + break; + + } - /* Continue only if we know where we actually are */ - if (!instFilename.str().empty()) { - for (std::list<std::string>::iterator it = myWhitelist.begin(); it != myWhitelist.end(); ++it) { - if (instFilename.str().length() >= it->length()) { - if (instFilename.str().compare(instFilename.str().length() - it->length(), it->length(), *it) == 0) { - instrumentBlock = true; - break; - } - } - } } - } - /* Either we couldn't figure out our location or the location is - * not whitelisted, so we skip instrumentation. */ - if (!instrumentBlock) { - if (!instFilename.str().empty()) - SAYF(cYEL "[!] " cBRI "Not in whitelist, skipping %s ...\n", instFilename.str().c_str()); - else - SAYF(cYEL "[!] " cBRI "No filename information found, skipping it"); - continue; + } + } + } - std::unordered_set<BasicBlock *> MS; - if (!MarkSetOpt) { - for (auto &BB : F) { - MS.insert(&BB); - } - total_rs += F.size(); + /* Either we couldn't figure out our location or the location is + * not whitelisted, so we skip instrumentation. */ + if (!instrumentBlock) { + + if (!instFilename.str().empty()) + SAYF(cYEL "[!] " cBRI "Not in whitelist, skipping %s ...\n", + instFilename.str().c_str()); + else + SAYF(cYEL "[!] " cBRI "No filename information found, skipping it"); + continue; + + } + + } + + std::unordered_set<BasicBlock *> MS; + if (!MarkSetOpt) { + + for (auto &BB : F) { + + MS.insert(&BB); + + } + + total_rs += F.size(); + + } else { + + auto Result = markNodes(&F); + auto RS = Result.first; + auto HS = Result.second; + + MS.insert(RS.begin(), RS.end()); + if (!LoopHeadOpt) { + + MS.insert(HS.begin(), HS.end()); + total_rs += MS.size(); + } else { - auto Result = markNodes(&F); - auto RS = Result.first; - auto HS = Result.second; - - MS.insert(RS.begin(), RS.end()); - if (!LoopHeadOpt) { - MS.insert(HS.begin(), HS.end()); - total_rs += MS.size(); - } else { - DenseSet<std::pair<BasicBlock *, BasicBlock *>> EdgeSet; - DominatorTreeWrapperPass *DTWP = &getAnalysis<DominatorTreeWrapperPass>(F); - auto DT = &DTWP->getDomTree(); - - total_rs += RS.size(); - total_hs += HS.size(); - - for (BasicBlock *BB : HS) { - bool Inserted = false; - for (auto BI = pred_begin(BB), BE = pred_end(BB); - BI != BE; ++BI - ) { - auto Edge = BasicBlockEdge(*BI, BB); - if (Edge.isSingleEdge() && DT->dominates(Edge, BB)) { - EdgeSet.insert({*BI, BB}); - Inserted = true; - break; - } - } - if (!Inserted) { - MS.insert(BB); - total_rs += 1; - total_hs -= 1; + + DenseSet<std::pair<BasicBlock *, BasicBlock *>> EdgeSet; + DominatorTreeWrapperPass * DTWP = + &getAnalysis<DominatorTreeWrapperPass>(F); + auto DT = &DTWP->getDomTree(); + + total_rs += RS.size(); + total_hs += HS.size(); + + for (BasicBlock *BB : HS) { + + bool Inserted = false; + for (auto BI = pred_begin(BB), BE = pred_end(BB); BI != BE; ++BI) { + + auto Edge = BasicBlockEdge(*BI, BB); + if (Edge.isSingleEdge() && DT->dominates(Edge, BB)) { + + EdgeSet.insert({*BI, BB}); + Inserted = true; + break; + } + } - for (auto I = EdgeSet.begin(), E = EdgeSet.end(); I != E; ++I) { - auto PredBB = I->first; - auto SuccBB = I->second; - auto NewBB = SplitBlockPredecessors(SuccBB, {PredBB}, ".split", - DT, nullptr, -#if LLVM_VERSION_MAJOR >= 8 - nullptr, -#endif - false); - MS.insert(NewBB); + + if (!Inserted) { + + MS.insert(BB); + total_rs += 1; + total_hs -= 1; + } - } - auto *EBB = &F.getEntryBlock(); - if (succ_begin(EBB) == succ_end(EBB)) { - MS.insert(EBB); - total_rs += 1; } - for (BasicBlock &BB : F) { - if (MS.find(&BB) == MS.end()) { - continue; - } - IRBuilder<> IRB(&*BB.getFirstInsertionPt()); - IRB.CreateStore(ConstantInt::get(Int32Ty, genLabel()), OldPrev); + for (auto I = EdgeSet.begin(), E = EdgeSet.end(); I != E; ++I) { + + auto PredBB = I->first; + auto SuccBB = I->second; + auto NewBB = + SplitBlockPredecessors(SuccBB, {PredBB}, ".split", DT, nullptr, +#if LLVM_VERSION_MAJOR >= 8 + nullptr, +#endif + false); + MS.insert(NewBB); + } + + } + + auto *EBB = &F.getEntryBlock(); + if (succ_begin(EBB) == succ_end(EBB)) { + + MS.insert(EBB); + total_rs += 1; + } for (BasicBlock &BB : F) { - auto PI = pred_begin(&BB); - auto PE = pred_end(&BB); - if (MarkSetOpt && MS.find(&BB) == MS.end()) { - continue; - } + if (MS.find(&BB) == MS.end()) { continue; } IRBuilder<> IRB(&*BB.getFirstInsertionPt()); - Value *L = NULL; - if (PI == PE) { - L = ConstantInt::get(Int32Ty, genLabel()); - } else { - auto *PN = PHINode::Create(Int32Ty, 0, "", &*BB.begin()); - DenseMap<BasicBlock *, unsigned> PredMap; - for (auto PI = pred_begin(&BB), PE = pred_end(&BB); - PI != PE; ++PI - ) { - BasicBlock *PBB = *PI; - auto It = PredMap.insert({PBB, genLabel()}); - unsigned Label = It.first->second; - PN->addIncoming(ConstantInt::get(Int32Ty, Label), PBB); - } - L = PN; + IRB.CreateStore(ConstantInt::get(Int32Ty, genLabel()), OldPrev); + + } + + } + + for (BasicBlock &BB : F) { + + auto PI = pred_begin(&BB); + auto PE = pred_end(&BB); + if (MarkSetOpt && MS.find(&BB) == MS.end()) { continue; } + + IRBuilder<> IRB(&*BB.getFirstInsertionPt()); + Value * L = NULL; + if (PI == PE) { + + L = ConstantInt::get(Int32Ty, genLabel()); + + } else { + + auto *PN = PHINode::Create(Int32Ty, 0, "", &*BB.begin()); + DenseMap<BasicBlock *, unsigned> PredMap; + for (auto PI = pred_begin(&BB), PE = pred_end(&BB); PI != PE; ++PI) { + + BasicBlock *PBB = *PI; + auto It = PredMap.insert({PBB, genLabel()}); + unsigned Label = It.first->second; + PN->addIncoming(ConstantInt::get(Int32Ty, Label), PBB); + } - /* Load prev_loc */ - LoadInst *PrevLoc = IRB.CreateLoad(OldPrev); - PrevLoc->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); - Value *PrevLocCasted = IRB.CreateZExt(PrevLoc, IRB.getInt32Ty()); + L = PN; + + } + + /* Load prev_loc */ + LoadInst *PrevLoc = IRB.CreateLoad(OldPrev); + PrevLoc->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); + Value *PrevLocCasted = IRB.CreateZExt(PrevLoc, IRB.getInt32Ty()); + + /* Load SHM pointer */ + LoadInst *MapPtr = IRB.CreateLoad(CovMapPtr); + MapPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); + Value *MapPtrIdx = + IRB.CreateGEP(MapPtr, IRB.CreateXor(PrevLocCasted, L)); - /* Load SHM pointer */ - LoadInst *MapPtr = IRB.CreateLoad(CovMapPtr); - MapPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); - Value *MapPtrIdx = IRB.CreateGEP(MapPtr, IRB.CreateXor(PrevLocCasted, L)); + /* Update bitmap */ + LoadInst *Counter = IRB.CreateLoad(MapPtrIdx); + Counter->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); - /* Update bitmap */ - LoadInst *Counter = IRB.CreateLoad(MapPtrIdx); - Counter->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); - - Value *Incr = IRB.CreateAdd(Counter, ConstantInt::get(Int8Ty, 1)); + Value *Incr = IRB.CreateAdd(Counter, ConstantInt::get(Int8Ty, 1)); #if LLVM_VERSION_MAJOR < 9 - if (neverZero_counters_str != NULL) // with llvm 9 we make this the default as the bug in llvm is then fixed + if (neverZero_counters_str != + NULL) // with llvm 9 we make this the default as the bug in llvm is + // then fixed #else - if (1) // with llvm 9 we make this the default as the bug in llvm is then fixed + if (1) // with llvm 9 we make this the default as the bug in llvm is + // then fixed #endif - { + { + /* hexcoder: Realize a counter that skips zero during overflow. - * Once this counter reaches its maximum value, it next increments to 1 + * Once this counter reaches its maximum value, it next increments to + * 1 * * Instead of * Counter + 1 -> Counter @@ -306,38 +387,52 @@ namespace { * Counter + 1 -> {Counter, OverflowFlag} * Counter + OverflowFlag -> Counter */ - auto cf = IRB.CreateICmpEQ(Incr, ConstantInt::get(Int8Ty, 0)); - auto carry = IRB.CreateZExt(cf, Int8Ty); - Incr = IRB.CreateAdd(Incr, carry); - } - - IRB.CreateStore(Incr, MapPtrIdx)->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); - - /* Set prev_loc to cur_loc >> 1 */ - /* - StoreInst *Store = IRB.CreateStore(ConstantInt::get(Int32Ty, L >> 1), OldPrev); - Store->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); - */ - - total_instr++; + auto cf = IRB.CreateICmpEQ(Incr, ConstantInt::get(Int8Ty, 0)); + auto carry = IRB.CreateZExt(cf, Int8Ty); + Incr = IRB.CreateAdd(Incr, carry); + } + + IRB.CreateStore(Incr, MapPtrIdx) + ->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); + + /* Set prev_loc to cur_loc >> 1 */ + /* + StoreInst *Store = IRB.CreateStore(ConstantInt::get(Int32Ty, L >> 1), + OldPrev); Store->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, + None)); + */ + + total_instr++; + } - OKF("Instrumented %u locations (%llu, %llu) (%s mode)\n"/*", ratio %u%%)."*/, - total_instr, total_rs, total_hs, - getenv("AFL_HARDEN") ? "hardened" : - ((getenv("AFL_USE_ASAN") || getenv("AFL_USE_MSAN")) ? - "ASAN/MSAN" : "non-hardened")/*, inst_ratio*/); - return false; } - }; // end of struct InsTrim + + OKF("Instrumented %u locations (%llu, %llu) (%s mode)\n" /*", ratio + %u%%)."*/ + , + total_instr, total_rs, total_hs, + getenv("AFL_HARDEN") + ? "hardened" + : ((getenv("AFL_USE_ASAN") || getenv("AFL_USE_MSAN")) + ? "ASAN/MSAN" + : "non-hardened") /*, inst_ratio*/); + return false; + + } + +}; // end of struct InsTrim + } // end of anonymous namespace char InsTrim::ID = 0; static void registerAFLPass(const PassManagerBuilder &, legacy::PassManagerBase &PM) { + PM.add(new InsTrim()); + } static RegisterStandardPasses RegisterAFLPass( @@ -345,3 +440,4 @@ static RegisterStandardPasses RegisterAFLPass( static RegisterStandardPasses RegisterAFLPass0( PassManagerBuilder::EP_EnabledOnOptLevel0, registerAFLPass); + diff --git a/llvm_mode/MarkNodes.cc b/llvm_mode/MarkNodes.cc index 348dc264..2aeeda8d 100644 --- a/llvm_mode/MarkNodes.cc +++ b/llvm_mode/MarkNodes.cc @@ -19,207 +19,267 @@ using namespace llvm; -DenseMap<BasicBlock *, uint32_t> LMap; -std::vector<BasicBlock *> Blocks; -std::set<uint32_t> Marked , Markabove; -std::vector< std::vector<uint32_t> > Succs , Preds; +DenseMap<BasicBlock *, uint32_t> LMap; +std::vector<BasicBlock *> Blocks; +std::set<uint32_t> Marked, Markabove; +std::vector<std::vector<uint32_t> > Succs, Preds; + +void reset() { -void reset(){ LMap.clear(); Blocks.clear(); Marked.clear(); Markabove.clear(); + } uint32_t start_point; void labelEachBlock(Function *F) { + // Fake single endpoint; LMap[NULL] = Blocks.size(); Blocks.push_back(NULL); - + // Assign the unique LabelID to each block; for (auto I = F->begin(), E = F->end(); I != E; ++I) { + BasicBlock *BB = &*I; LMap[BB] = Blocks.size(); Blocks.push_back(BB); + } - + start_point = LMap[&F->getEntryBlock()]; + } void buildCFG(Function *F) { - Succs.resize( Blocks.size() ); - Preds.resize( Blocks.size() ); - for( size_t i = 0 ; i < Succs.size() ; i ++ ){ - Succs[ i ].clear(); - Preds[ i ].clear(); + + Succs.resize(Blocks.size()); + Preds.resize(Blocks.size()); + for (size_t i = 0; i < Succs.size(); i++) { + + Succs[i].clear(); + Preds[i].clear(); + } - //uint32_t FakeID = 0; + // uint32_t FakeID = 0; for (auto S = F->begin(), E = F->end(); S != E; ++S) { + BasicBlock *BB = &*S; - uint32_t MyID = LMap[BB]; - //if (succ_begin(BB) == succ_end(BB)) { - //Succs[MyID].push_back(FakeID); - //Marked.insert(MyID); + uint32_t MyID = LMap[BB]; + // if (succ_begin(BB) == succ_end(BB)) { + + // Succs[MyID].push_back(FakeID); + // Marked.insert(MyID); //} for (auto I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { + Succs[MyID].push_back(LMap[*I]); + } + } + } -std::vector< std::vector<uint32_t> > tSuccs; -std::vector<bool> tag , indfs; +std::vector<std::vector<uint32_t> > tSuccs; +std::vector<bool> tag, indfs; void DFStree(size_t now_id) { - if(tag[now_id]) return; - tag[now_id]=true; - indfs[now_id]=true; - for (auto succ: tSuccs[now_id]) { - if(tag[succ] and indfs[succ]) { + + if (tag[now_id]) return; + tag[now_id] = true; + indfs[now_id] = true; + for (auto succ : tSuccs[now_id]) { + + if (tag[succ] and indfs[succ]) { + Marked.insert(succ); Markabove.insert(succ); continue; + } + Succs[now_id].push_back(succ); Preds[succ].push_back(now_id); DFStree(succ); + } - indfs[now_id]=false; + + indfs[now_id] = false; + } + void turnCFGintoDAG(Function *F) { + tSuccs = Succs; tag.resize(Blocks.size()); indfs.resize(Blocks.size()); - for (size_t i = 0; i < Blocks.size(); ++ i) { + for (size_t i = 0; i < Blocks.size(); ++i) { + Succs[i].clear(); - tag[i]=false; - indfs[i]=false; + tag[i] = false; + indfs[i] = false; + } + DFStree(start_point); - for (size_t i = 0; i < Blocks.size(); ++ i) - if( Succs[i].empty() ){ + for (size_t i = 0; i < Blocks.size(); ++i) + if (Succs[i].empty()) { + Succs[i].push_back(0); Preds[0].push_back(i); + } + } uint32_t timeStamp; -namespace DominatorTree{ - std::vector< std::vector<uint32_t> > cov; - std::vector<uint32_t> dfn, nfd, par, sdom, idom, mom, mn; +namespace DominatorTree { + +std::vector<std::vector<uint32_t> > cov; +std::vector<uint32_t> dfn, nfd, par, sdom, idom, mom, mn; + +bool Compare(uint32_t u, uint32_t v) { + + return dfn[u] < dfn[v]; + +} + +uint32_t eval(uint32_t u) { + + if (mom[u] == u) return u; + uint32_t res = eval(mom[u]); + if (Compare(sdom[mn[mom[u]]], sdom[mn[u]])) { mn[u] = mn[mom[u]]; } + return mom[u] = res; + +} + +void DFS(uint32_t now) { + + timeStamp += 1; + dfn[now] = timeStamp; + nfd[timeStamp - 1] = now; + for (auto succ : Succs[now]) { + + if (dfn[succ] == 0) { + + par[succ] = now; + DFS(succ); - bool Compare(uint32_t u, uint32_t v) { - return dfn[u] < dfn[v]; - } - uint32_t eval(uint32_t u) { - if( mom[u] == u ) return u; - uint32_t res = eval( mom[u] ); - if(Compare(sdom[mn[mom[u]]] , sdom[mn[u]])) { - mn[u] = mn[mom[u]]; } - return mom[u] = res; + } - void DFS(uint32_t now) { - timeStamp += 1; - dfn[now] = timeStamp; - nfd[timeStamp - 1] = now; - for( auto succ : Succs[now] ) { - if( dfn[succ] == 0 ) { - par[succ] = now; - DFS(succ); - } - } +} + +void DominatorTree(Function *F) { + + if (Blocks.empty()) return; + uint32_t s = start_point; + + // Initialization + mn.resize(Blocks.size()); + cov.resize(Blocks.size()); + dfn.resize(Blocks.size()); + nfd.resize(Blocks.size()); + par.resize(Blocks.size()); + mom.resize(Blocks.size()); + sdom.resize(Blocks.size()); + idom.resize(Blocks.size()); + + for (uint32_t i = 0; i < Blocks.size(); i++) { + + dfn[i] = 0; + nfd[i] = Blocks.size(); + cov[i].clear(); + idom[i] = mom[i] = mn[i] = sdom[i] = i; + } - void DominatorTree(Function *F) { - if( Blocks.empty() ) return; - uint32_t s = start_point; - - // Initialization - mn.resize(Blocks.size()); - cov.resize(Blocks.size()); - dfn.resize(Blocks.size()); - nfd.resize(Blocks.size()); - par.resize(Blocks.size()); - mom.resize(Blocks.size()); - sdom.resize(Blocks.size()); - idom.resize(Blocks.size()); - - for( uint32_t i = 0 ; i < Blocks.size() ; i ++ ) { - dfn[i] = 0; - nfd[i] = Blocks.size(); - cov[i].clear(); - idom[i] = mom[i] = mn[i] = sdom[i] = i; - } + timeStamp = 0; + DFS(s); - timeStamp = 0; - DFS(s); + for (uint32_t i = Blocks.size() - 1; i >= 1u; i--) { + + uint32_t now = nfd[i]; + if (now == Blocks.size()) { continue; } + for (uint32_t pre : Preds[now]) { + + if (dfn[pre]) { + + eval(pre); + if (Compare(sdom[mn[pre]], sdom[now])) { sdom[now] = sdom[mn[pre]]; } - for( uint32_t i = Blocks.size() - 1 ; i >= 1u ; i -- ) { - uint32_t now = nfd[i]; - if( now == Blocks.size() ) { - continue; - } - for( uint32_t pre : Preds[ now ] ) { - if( dfn[ pre ] ) { - eval(pre); - if( Compare(sdom[mn[pre]], sdom[now]) ) { - sdom[now] = sdom[mn[pre]]; - } - } - } - cov[sdom[now]].push_back(now); - mom[now] = par[now]; - for( uint32_t x : cov[par[now]] ) { - eval(x); - if( Compare(sdom[mn[x]], par[now]) ) { - idom[x] = mn[x]; - } else { - idom[x] = par[now]; - } } + } - for( uint32_t i = 1 ; i < Blocks.size() ; i += 1 ) { - uint32_t now = nfd[i]; - if( now == Blocks.size() ) { - continue; + cov[sdom[now]].push_back(now); + mom[now] = par[now]; + for (uint32_t x : cov[par[now]]) { + + eval(x); + if (Compare(sdom[mn[x]], par[now])) { + + idom[x] = mn[x]; + + } else { + + idom[x] = par[now]; + } - if(idom[now] != sdom[now]) - idom[now] = idom[idom[now]]; + } + } -} // End of DominatorTree -std::vector<uint32_t> Visited, InStack; -std::vector<uint32_t> TopoOrder, InDeg; -std::vector< std::vector<uint32_t> > t_Succ , t_Pred; + for (uint32_t i = 1; i < Blocks.size(); i += 1) { + + uint32_t now = nfd[i]; + if (now == Blocks.size()) { continue; } + if (idom[now] != sdom[now]) idom[now] = idom[idom[now]]; + + } + +} + +} // namespace DominatorTree + +std::vector<uint32_t> Visited, InStack; +std::vector<uint32_t> TopoOrder, InDeg; +std::vector<std::vector<uint32_t> > t_Succ, t_Pred; void Go(uint32_t now, uint32_t tt) { - if( now == tt ) return; + + if (now == tt) return; Visited[now] = InStack[now] = timeStamp; - for(uint32_t nxt : Succs[now]) { - if(Visited[nxt] == timeStamp and InStack[nxt] == timeStamp) { + for (uint32_t nxt : Succs[now]) { + + if (Visited[nxt] == timeStamp and InStack[nxt] == timeStamp) { + Marked.insert(nxt); + } + t_Succ[now].push_back(nxt); t_Pred[nxt].push_back(now); InDeg[nxt] += 1; - if(Visited[nxt] == timeStamp) { - continue; - } + if (Visited[nxt] == timeStamp) { continue; } Go(nxt, tt); + } InStack[now] = 0; + } void TopologicalSort(uint32_t ss, uint32_t tt) { + timeStamp += 1; Go(ss, tt); @@ -227,76 +287,111 @@ void TopologicalSort(uint32_t ss, uint32_t tt) { TopoOrder.clear(); std::queue<uint32_t> wait; wait.push(ss); - while( not wait.empty() ) { - uint32_t now = wait.front(); wait.pop(); + while (not wait.empty()) { + + uint32_t now = wait.front(); + wait.pop(); TopoOrder.push_back(now); - for(uint32_t nxt : t_Succ[now]) { + for (uint32_t nxt : t_Succ[now]) { + InDeg[nxt] -= 1; - if(InDeg[nxt] == 0u) { - wait.push(nxt); - } + if (InDeg[nxt] == 0u) { wait.push(nxt); } + } + } + } -std::vector< std::set<uint32_t> > NextMarked; -bool Indistinguish(uint32_t node1, uint32_t node2) { - if(NextMarked[node1].size() > NextMarked[node2].size()){ +std::vector<std::set<uint32_t> > NextMarked; +bool Indistinguish(uint32_t node1, uint32_t node2) { + + if (NextMarked[node1].size() > NextMarked[node2].size()) { + uint32_t _swap = node1; node1 = node2; node2 = _swap; + } - for(uint32_t x : NextMarked[node1]) { - if( NextMarked[node2].find(x) != NextMarked[node2].end() ) { - return true; - } + + for (uint32_t x : NextMarked[node1]) { + + if (NextMarked[node2].find(x) != NextMarked[node2].end()) { return true; } + } + return false; + } void MakeUniq(uint32_t now) { + bool StopFlag = false; if (Marked.find(now) == Marked.end()) { - for(uint32_t pred1 : t_Pred[now]) { - for(uint32_t pred2 : t_Pred[now]) { - if(pred1 == pred2) continue; - if(Indistinguish(pred1, pred2)) { + + for (uint32_t pred1 : t_Pred[now]) { + + for (uint32_t pred2 : t_Pred[now]) { + + if (pred1 == pred2) continue; + if (Indistinguish(pred1, pred2)) { + Marked.insert(now); StopFlag = true; break; + } + } - if (StopFlag) { - break; - } + + if (StopFlag) { break; } + } + } - if(Marked.find(now) != Marked.end()) { + + if (Marked.find(now) != Marked.end()) { + NextMarked[now].insert(now); + } else { - for(uint32_t pred : t_Pred[now]) { - for(uint32_t x : NextMarked[pred]) { + + for (uint32_t pred : t_Pred[now]) { + + for (uint32_t x : NextMarked[pred]) { + NextMarked[now].insert(x); + } + } + } + } void MarkSubGraph(uint32_t ss, uint32_t tt) { + TopologicalSort(ss, tt); - if(TopoOrder.empty()) return; + if (TopoOrder.empty()) return; + + for (uint32_t i : TopoOrder) { - for(uint32_t i : TopoOrder) { NextMarked[i].clear(); + } NextMarked[TopoOrder[0]].insert(TopoOrder[0]); - for(uint32_t i = 1 ; i < TopoOrder.size() ; i += 1) { + for (uint32_t i = 1; i < TopoOrder.size(); i += 1) { + MakeUniq(TopoOrder[i]); + } + } void MarkVertice(Function *F) { + uint32_t s = start_point; InDeg.resize(Blocks.size()); @@ -306,26 +401,32 @@ void MarkVertice(Function *F) { t_Pred.resize(Blocks.size()); NextMarked.resize(Blocks.size()); - for( uint32_t i = 0 ; i < Blocks.size() ; i += 1 ) { + for (uint32_t i = 0; i < Blocks.size(); i += 1) { + Visited[i] = InStack[i] = InDeg[i] = 0; t_Succ[i].clear(); t_Pred[i].clear(); + } + timeStamp = 0; uint32_t t = 0; - //MarkSubGraph(s, t); - //return; + // MarkSubGraph(s, t); + // return; + + while (s != t) { - while( s != t ) { MarkSubGraph(DominatorTree::idom[t], t); t = DominatorTree::idom[t]; + } } // return {marked nodes} -std::pair<std::vector<BasicBlock *>, - std::vector<BasicBlock *> >markNodes(Function *F) { +std::pair<std::vector<BasicBlock *>, std::vector<BasicBlock *> > markNodes( + Function *F) { + assert(F->size() > 0 && "Function can not be empty"); reset(); @@ -335,21 +436,30 @@ std::pair<std::vector<BasicBlock *>, DominatorTree::DominatorTree(F); MarkVertice(F); - std::vector<BasicBlock *> Result , ResultAbove; - for( uint32_t x : Markabove ) { - auto it = Marked.find( x ); - if( it != Marked.end() ) - Marked.erase( it ); - if( x ) - ResultAbove.push_back(Blocks[x]); + std::vector<BasicBlock *> Result, ResultAbove; + for (uint32_t x : Markabove) { + + auto it = Marked.find(x); + if (it != Marked.end()) Marked.erase(it); + if (x) ResultAbove.push_back(Blocks[x]); + } - for( uint32_t x : Marked ) { + + for (uint32_t x : Marked) { + if (x == 0) { + continue; + } else { + Result.push_back(Blocks[x]); + } + } - return { Result , ResultAbove }; + return {Result, ResultAbove}; + } + diff --git a/llvm_mode/MarkNodes.h b/llvm_mode/MarkNodes.h index e3bf3ce5..23316652 100644 --- a/llvm_mode/MarkNodes.h +++ b/llvm_mode/MarkNodes.h @@ -1,11 +1,12 @@ #ifndef __MARK_NODES__ -#define __MARK_NODES__ +# define __MARK_NODES__ -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Function.h" -#include<vector> +# include "llvm/IR/BasicBlock.h" +# include "llvm/IR/Function.h" +# include <vector> -std::pair<std::vector<llvm::BasicBlock *>, - std::vector<llvm::BasicBlock *>> markNodes(llvm::Function *F); +std::pair<std::vector<llvm::BasicBlock *>, std::vector<llvm::BasicBlock *>> +markNodes(llvm::Function *F); #endif + diff --git a/llvm_mode/afl-clang-fast.c b/llvm_mode/afl-clang-fast.c index 1b810edf..666fd043 100644 --- a/llvm_mode/afl-clang-fast.c +++ b/llvm_mode/afl-clang-fast.c @@ -34,16 +34,15 @@ #include <string.h> #include <assert.h> -static u8* obj_path; /* Path to runtime libraries */ -static u8** cc_params; /* Parameters passed to the real CC */ -static u32 cc_par_cnt = 1; /* Param count, including argv0 */ - +static u8* obj_path; /* Path to runtime libraries */ +static u8** cc_params; /* Parameters passed to the real CC */ +static u32 cc_par_cnt = 1; /* Param count, including argv0 */ /* Try to find the runtime libraries. If that fails, abort. */ static void find_obj(u8* argv0) { - u8 *afl_path = getenv("AFL_PATH"); + u8* afl_path = getenv("AFL_PATH"); u8 *slash, *tmp; if (afl_path) { @@ -51,9 +50,11 @@ static void find_obj(u8* argv0) { tmp = alloc_printf("%s/afl-llvm-rt.o", afl_path); if (!access(tmp, R_OK)) { + obj_path = afl_path; ck_free(tmp); return; + } ck_free(tmp); @@ -64,7 +65,7 @@ static void find_obj(u8* argv0) { if (slash) { - u8 *dir; + u8* dir; *slash = 0; dir = ck_strdup(argv0); @@ -73,9 +74,11 @@ static void find_obj(u8* argv0) { tmp = alloc_printf("%s/afl-llvm-rt.o", dir); if (!access(tmp, R_OK)) { + obj_path = dir; ck_free(tmp); return; + } ck_free(tmp); @@ -84,33 +87,43 @@ static void find_obj(u8* argv0) { } if (!access(AFL_PATH "/afl-llvm-rt.o", R_OK)) { + obj_path = AFL_PATH; return; + } - FATAL("Unable to find 'afl-llvm-rt.o' or 'afl-llvm-pass.so.cc'. Please set AFL_PATH"); - -} + FATAL( + "Unable to find 'afl-llvm-rt.o' or 'afl-llvm-pass.so.cc'. Please set " + "AFL_PATH"); +} /* Copy argv to cc_params, making the necessary edits. */ static void edit_params(u32 argc, char** argv) { - u8 fortify_set = 0, asan_set = 0, x_set = 0, maybe_linking = 1, bit_mode = 0; - u8 *name; + u8 fortify_set = 0, asan_set = 0, x_set = 0, maybe_linking = 1, bit_mode = 0; + u8* name; cc_params = ck_alloc((argc + 128) * sizeof(u8*)); name = strrchr(argv[0], '/'); - if (!name) name = argv[0]; else name++; + if (!name) + name = argv[0]; + else + name++; if (!strcmp(name, "afl-clang-fast++")) { + u8* alt_cxx = getenv("AFL_CXX"); cc_params[0] = alt_cxx ? alt_cxx : (u8*)"clang++"; + } else { + u8* alt_cc = getenv("AFL_CC"); cc_params[0] = alt_cc ? alt_cc : (u8*)"clang"; + } /* There are three ways to compile with afl-clang-fast. In the traditional @@ -118,36 +131,50 @@ static void edit_params(u32 argc, char** argv) { much faster but has less coverage. Finally tere is the experimental 'trace-pc-guard' mode, we use native LLVM instrumentation callbacks instead. For trace-pc-guard see: - http://clang.llvm.org/docs/SanitizerCoverage.html#tracing-pcs-with-guards */ + http://clang.llvm.org/docs/SanitizerCoverage.html#tracing-pcs-with-guards + */ // laf - if (getenv("LAF_SPLIT_SWITCHES")||getenv("AFL_LLVM_LAF_SPLIT_SWITCHES")) { + if (getenv("LAF_SPLIT_SWITCHES") || getenv("AFL_LLVM_LAF_SPLIT_SWITCHES")) { + cc_params[cc_par_cnt++] = "-Xclang"; cc_params[cc_par_cnt++] = "-load"; cc_params[cc_par_cnt++] = "-Xclang"; - cc_params[cc_par_cnt++] = alloc_printf("%s/split-switches-pass.so", obj_path); + cc_params[cc_par_cnt++] = + alloc_printf("%s/split-switches-pass.so", obj_path); + } - if (getenv("LAF_TRANSFORM_COMPARES")||getenv("AFL_LLVM_LAF_TRANSFORM_COMPARES")) { + if (getenv("LAF_TRANSFORM_COMPARES") || + getenv("AFL_LLVM_LAF_TRANSFORM_COMPARES")) { + cc_params[cc_par_cnt++] = "-Xclang"; cc_params[cc_par_cnt++] = "-load"; cc_params[cc_par_cnt++] = "-Xclang"; - cc_params[cc_par_cnt++] = alloc_printf("%s/compare-transform-pass.so", obj_path); + cc_params[cc_par_cnt++] = + alloc_printf("%s/compare-transform-pass.so", obj_path); + } - if (getenv("LAF_SPLIT_COMPARES")||getenv("AFL_LLVM_LAF_SPLIT_COMPARES")) { + if (getenv("LAF_SPLIT_COMPARES") || getenv("AFL_LLVM_LAF_SPLIT_COMPARES")) { + cc_params[cc_par_cnt++] = "-Xclang"; cc_params[cc_par_cnt++] = "-load"; cc_params[cc_par_cnt++] = "-Xclang"; - cc_params[cc_par_cnt++] = alloc_printf("%s/split-compares-pass.so", obj_path); + cc_params[cc_par_cnt++] = + alloc_printf("%s/split-compares-pass.so", obj_path); + } + // /laf #ifdef USE_TRACE_PC - cc_params[cc_par_cnt++] = "-fsanitize-coverage=trace-pc-guard"; // edge coverage by default - //cc_params[cc_par_cnt++] = "-mllvm"; - //cc_params[cc_par_cnt++] = "-fsanitize-coverage=trace-cmp,trace-div,trace-gep"; - //cc_params[cc_par_cnt++] = "-sanitizer-coverage-block-threshold=0"; + cc_params[cc_par_cnt++] = + "-fsanitize-coverage=trace-pc-guard"; // edge coverage by default + // cc_params[cc_par_cnt++] = "-mllvm"; + // cc_params[cc_par_cnt++] = + // "-fsanitize-coverage=trace-cmp,trace-div,trace-gep"; cc_params[cc_par_cnt++] + // = "-sanitizer-coverage-block-threshold=0"; #else cc_params[cc_par_cnt++] = "-Xclang"; cc_params[cc_par_cnt++] = "-load"; @@ -165,6 +192,7 @@ static void edit_params(u32 argc, char** argv) { if (argc == 1 && !strcmp(argv[1], "-v")) maybe_linking = 0; while (--argc) { + u8* cur = *(++argv); if (!strcmp(cur, "-m32")) bit_mode = 32; @@ -175,15 +203,15 @@ static void edit_params(u32 argc, char** argv) { if (!strcmp(cur, "-c") || !strcmp(cur, "-S") || !strcmp(cur, "-E")) maybe_linking = 0; - if (!strcmp(cur, "-fsanitize=address") || - !strcmp(cur, "-fsanitize=memory")) asan_set = 1; + if (!strcmp(cur, "-fsanitize=address") || !strcmp(cur, "-fsanitize=memory")) + asan_set = 1; if (strstr(cur, "FORTIFY_SOURCE")) fortify_set = 1; if (!strcmp(cur, "-shared")) maybe_linking = 0; - if (!strcmp(cur, "-Wl,-z,defs") || - !strcmp(cur, "-Wl,--no-undefined")) continue; + if (!strcmp(cur, "-Wl,-z,defs") || !strcmp(cur, "-Wl,--no-undefined")) + continue; cc_params[cc_par_cnt++] = cur; @@ -193,8 +221,7 @@ static void edit_params(u32 argc, char** argv) { cc_params[cc_par_cnt++] = "-fstack-protector-all"; - if (!fortify_set) - cc_params[cc_par_cnt++] = "-D_FORTIFY_SOURCE=2"; + if (!fortify_set) cc_params[cc_par_cnt++] = "-D_FORTIFY_SOURCE=2"; } @@ -202,8 +229,7 @@ static void edit_params(u32 argc, char** argv) { if (getenv("AFL_USE_ASAN")) { - if (getenv("AFL_USE_MSAN")) - FATAL("ASAN and MSAN are mutually exclusive"); + if (getenv("AFL_USE_MSAN")) FATAL("ASAN and MSAN are mutually exclusive"); if (getenv("AFL_HARDEN")) FATAL("ASAN and AFL_HARDEN are mutually exclusive"); @@ -213,8 +239,7 @@ static void edit_params(u32 argc, char** argv) { } else if (getenv("AFL_USE_MSAN")) { - if (getenv("AFL_USE_ASAN")) - FATAL("ASAN and MSAN are mutually exclusive"); + if (getenv("AFL_USE_ASAN")) FATAL("ASAN and MSAN are mutually exclusive"); if (getenv("AFL_HARDEN")) FATAL("MSAN and AFL_HARDEN are mutually exclusive"); @@ -279,35 +304,41 @@ static void edit_params(u32 argc, char** argv) { */ - cc_params[cc_par_cnt++] = "-D__AFL_LOOP(_A)=" - "({ static volatile char *_B __attribute__((used)); " - " _B = (char*)\"" PERSIST_SIG "\"; " + cc_params[cc_par_cnt++] = + "-D__AFL_LOOP(_A)=" + "({ static volatile char *_B __attribute__((used)); " + " _B = (char*)\"" PERSIST_SIG + "\"; " #ifdef __APPLE__ - "__attribute__((visibility(\"default\"))) " - "int _L(unsigned int) __asm__(\"___afl_persistent_loop\"); " + "__attribute__((visibility(\"default\"))) " + "int _L(unsigned int) __asm__(\"___afl_persistent_loop\"); " #else - "__attribute__((visibility(\"default\"))) " - "int _L(unsigned int) __asm__(\"__afl_persistent_loop\"); " + "__attribute__((visibility(\"default\"))) " + "int _L(unsigned int) __asm__(\"__afl_persistent_loop\"); " #endif /* ^__APPLE__ */ - "_L(_A); })"; + "_L(_A); })"; - cc_params[cc_par_cnt++] = "-D__AFL_INIT()=" - "do { static volatile char *_A __attribute__((used)); " - " _A = (char*)\"" DEFER_SIG "\"; " + cc_params[cc_par_cnt++] = + "-D__AFL_INIT()=" + "do { static volatile char *_A __attribute__((used)); " + " _A = (char*)\"" DEFER_SIG + "\"; " #ifdef __APPLE__ - "__attribute__((visibility(\"default\"))) " - "void _I(void) __asm__(\"___afl_manual_init\"); " + "__attribute__((visibility(\"default\"))) " + "void _I(void) __asm__(\"___afl_manual_init\"); " #else - "__attribute__((visibility(\"default\"))) " - "void _I(void) __asm__(\"__afl_manual_init\"); " + "__attribute__((visibility(\"default\"))) " + "void _I(void) __asm__(\"__afl_manual_init\"); " #endif /* ^__APPLE__ */ - "_I(); } while (0)"; + "_I(); } while (0)"; if (maybe_linking) { if (x_set) { + cc_params[cc_par_cnt++] = "-x"; cc_params[cc_par_cnt++] = "none"; + } switch (bit_mode) { @@ -340,7 +371,6 @@ static void edit_params(u32 argc, char** argv) { } - /* Main entry point */ int main(int argc, char** argv) { @@ -348,46 +378,53 @@ int main(int argc, char** argv) { if (isatty(2) && !getenv("AFL_QUIET")) { #ifdef USE_TRACE_PC - SAYF(cCYA "afl-clang-fast" VERSION cRST " [tpcg] by <lszekeres@google.com>\n"); + SAYF(cCYA "afl-clang-fast" VERSION cRST + " [tpcg] by <lszekeres@google.com>\n"); #else - SAYF(cCYA "afl-clang-fast" VERSION cRST " by <lszekeres@google.com>\n"); + SAYF(cCYA "afl-clang-fast" VERSION cRST " by <lszekeres@google.com>\n"); #endif /* ^USE_TRACE_PC */ } if (argc < 2) { - SAYF("\n" - "This is a helper application for afl-fuzz. It serves as a drop-in replacement\n" - "for clang, letting you recompile third-party code with the required runtime\n" - "instrumentation. A common use pattern would be one of the following:\n\n" + SAYF( + "\n" + "This is a helper application for afl-fuzz. It serves as a drop-in " + "replacement\n" + "for clang, letting you recompile third-party code with the required " + "runtime\n" + "instrumentation. A common use pattern would be one of the " + "following:\n\n" - " CC=%s/afl-clang-fast ./configure\n" - " CXX=%s/afl-clang-fast++ ./configure\n\n" + " CC=%s/afl-clang-fast ./configure\n" + " CXX=%s/afl-clang-fast++ ./configure\n\n" - "In contrast to the traditional afl-clang tool, this version is implemented as\n" - "an LLVM pass and tends to offer improved performance with slow programs.\n\n" + "In contrast to the traditional afl-clang tool, this version is " + "implemented as\n" + "an LLVM pass and tends to offer improved performance with slow " + "programs.\n\n" - "You can specify custom next-stage toolchain via AFL_CC and AFL_CXX. Setting\n" - "AFL_HARDEN enables hardening optimizations in the compiled code.\n\n", - BIN_PATH, BIN_PATH); + "You can specify custom next-stage toolchain via AFL_CC and AFL_CXX. " + "Setting\n" + "AFL_HARDEN enables hardening optimizations in the compiled code.\n\n", + BIN_PATH, BIN_PATH); exit(1); } - find_obj(argv[0]); edit_params(argc, argv); -/* - int i = 0; - printf("EXEC:"); - while (cc_params[i] != NULL) - printf(" %s", cc_params[i++]); - printf("\n"); -*/ + /* + int i = 0; + printf("EXEC:"); + while (cc_params[i] != NULL) + printf(" %s", cc_params[i++]); + printf("\n"); + */ execvp(cc_params[0], (char**)cc_params); @@ -396,3 +433,4 @@ int main(int argc, char** argv) { return 0; } + diff --git a/llvm_mode/afl-llvm-pass.so.cc b/llvm_mode/afl-llvm-pass.so.cc index b242163e..5d531a87 100644 --- a/llvm_mode/afl-llvm-pass.so.cc +++ b/llvm_mode/afl-llvm-pass.so.cc @@ -48,50 +48,52 @@ using namespace llvm; namespace { - class AFLCoverage : public ModulePass { - - public: - - static char ID; - AFLCoverage() : ModulePass(ID) { - char* instWhiteListFilename = getenv("AFL_LLVM_WHITELIST"); - if (instWhiteListFilename) { - std::string line; - std::ifstream fileStream; - fileStream.open(instWhiteListFilename); - if (!fileStream) - report_fatal_error("Unable to open AFL_LLVM_WHITELIST"); - getline(fileStream, line); - while (fileStream) { - myWhitelist.push_back(line); - getline(fileStream, line); - } - } +class AFLCoverage : public ModulePass { + + public: + static char ID; + AFLCoverage() : ModulePass(ID) { + + char *instWhiteListFilename = getenv("AFL_LLVM_WHITELIST"); + if (instWhiteListFilename) { + + std::string line; + std::ifstream fileStream; + fileStream.open(instWhiteListFilename); + if (!fileStream) report_fatal_error("Unable to open AFL_LLVM_WHITELIST"); + getline(fileStream, line); + while (fileStream) { + + myWhitelist.push_back(line); + getline(fileStream, line); + } - bool runOnModule(Module &M) override; + } - // StringRef getPassName() const override { - // return "American Fuzzy Lop Instrumentation"; - // } + } - protected: + bool runOnModule(Module &M) override; - std::list<std::string> myWhitelist; + // StringRef getPassName() const override { - }; + // return "American Fuzzy Lop Instrumentation"; + // } -} + protected: + std::list<std::string> myWhitelist; +}; -char AFLCoverage::ID = 0; +} // namespace +char AFLCoverage::ID = 0; bool AFLCoverage::runOnModule(Module &M) { LLVMContext &C = M.getContext(); - IntegerType *Int8Ty = IntegerType::getInt8Ty(C); + IntegerType *Int8Ty = IntegerType::getInt8Ty(C); IntegerType *Int32Ty = IntegerType::getInt32Ty(C); unsigned int cur_loc = 0; @@ -103,11 +105,13 @@ bool AFLCoverage::runOnModule(Module &M) { SAYF(cCYA "afl-llvm-pass" VERSION cRST " by <lszekeres@google.com>\n"); - } else be_quiet = 1; + } else + + be_quiet = 1; /* Decide instrumentation ratio */ - char* inst_ratio_str = getenv("AFL_INST_RATIO"); + char * inst_ratio_str = getenv("AFL_INST_RATIO"); unsigned int inst_ratio = 100; if (inst_ratio_str) { @@ -119,7 +123,7 @@ bool AFLCoverage::runOnModule(Module &M) { } #if LLVM_VERSION_MAJOR < 9 - char* neverZero_counters_str = getenv("AFL_LLVM_NOT_ZERO"); + char *neverZero_counters_str = getenv("AFL_LLVM_NOT_ZERO"); #endif /* Get globals for the SHM region and the previous location. Note that @@ -134,8 +138,8 @@ bool AFLCoverage::runOnModule(Module &M) { M, Int32Ty, false, GlobalValue::ExternalLinkage, 0, "__afl_prev_loc"); #else GlobalVariable *AFLPrevLoc = new GlobalVariable( - M, Int32Ty, false, GlobalValue::ExternalLinkage, 0, "__afl_prev_loc", - 0, GlobalVariable::GeneralDynamicTLSModel, 0, false); + M, Int32Ty, false, GlobalValue::ExternalLinkage, 0, "__afl_prev_loc", 0, + GlobalVariable::GeneralDynamicTLSModel, 0, false); #endif /* Instrument all the things! */ @@ -146,58 +150,77 @@ bool AFLCoverage::runOnModule(Module &M) { for (auto &BB : F) { BasicBlock::iterator IP = BB.getFirstInsertionPt(); - IRBuilder<> IRB(&(*IP)); - + IRBuilder<> IRB(&(*IP)); + if (!myWhitelist.empty()) { - bool instrumentBlock = false; - - /* Get the current location using debug information. - * For now, just instrument the block if we are not able - * to determine our location. */ - DebugLoc Loc = IP->getDebugLoc(); - if ( Loc ) { - DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode()); - - unsigned int instLine = cDILoc->getLine(); - StringRef instFilename = cDILoc->getFilename(); - - if (instFilename.str().empty()) { - /* If the original location is empty, try using the inlined location */ - DILocation *oDILoc = cDILoc->getInlinedAt(); - if (oDILoc) { - instFilename = oDILoc->getFilename(); - instLine = oDILoc->getLine(); - } - } - /* Continue only if we know where we actually are */ - if (!instFilename.str().empty()) { - for (std::list<std::string>::iterator it = myWhitelist.begin(); it != myWhitelist.end(); ++it) { - /* We don't check for filename equality here because - * filenames might actually be full paths. Instead we - * check that the actual filename ends in the filename - * specified in the list. */ - if (instFilename.str().length() >= it->length()) { - if (instFilename.str().compare(instFilename.str().length() - it->length(), it->length(), *it) == 0) { - instrumentBlock = true; - break; - } - } - } + bool instrumentBlock = false; + + /* Get the current location using debug information. + * For now, just instrument the block if we are not able + * to determine our location. */ + DebugLoc Loc = IP->getDebugLoc(); + if (Loc) { + + DILocation *cDILoc = dyn_cast<DILocation>(Loc.getAsMDNode()); + + unsigned int instLine = cDILoc->getLine(); + StringRef instFilename = cDILoc->getFilename(); + + if (instFilename.str().empty()) { + + /* If the original location is empty, try using the inlined location + */ + DILocation *oDILoc = cDILoc->getInlinedAt(); + if (oDILoc) { + + instFilename = oDILoc->getFilename(); + instLine = oDILoc->getLine(); + + } + + } + + /* Continue only if we know where we actually are */ + if (!instFilename.str().empty()) { + + for (std::list<std::string>::iterator it = myWhitelist.begin(); + it != myWhitelist.end(); ++it) { + + /* We don't check for filename equality here because + * filenames might actually be full paths. Instead we + * check that the actual filename ends in the filename + * specified in the list. */ + if (instFilename.str().length() >= it->length()) { + + if (instFilename.str().compare( + instFilename.str().length() - it->length(), + it->length(), *it) == 0) { + + instrumentBlock = true; + break; + + } + } + + } + } - /* Either we couldn't figure out our location or the location is - * not whitelisted, so we skip instrumentation. */ - if (!instrumentBlock) continue; - } + } + + /* Either we couldn't figure out our location or the location is + * not whitelisted, so we skip instrumentation. */ + if (!instrumentBlock) continue; + } if (AFL_R(100) >= inst_ratio) continue; /* Make up cur_loc */ - //cur_loc++; + // cur_loc++; cur_loc = AFL_R(MAP_SIZE); // only instrument if this basic block is the destination of a previous @@ -205,24 +228,27 @@ bool AFLCoverage::runOnModule(Module &M) { // this gets rid of ~5-10% of instrumentations that are unnecessary // result: a little more speed and less map pollution int more_than_one = -1; - //fprintf(stderr, "BB %u: ", cur_loc); + // fprintf(stderr, "BB %u: ", cur_loc); for (BasicBlock *Pred : predecessors(&BB)) { + int count = 0; - if (more_than_one == -1) - more_than_one = 0; - //fprintf(stderr, " %p=>", Pred); + if (more_than_one == -1) more_than_one = 0; + // fprintf(stderr, " %p=>", Pred); for (BasicBlock *Succ : successors(Pred)) { - //if (count > 0) + + // if (count > 0) // fprintf(stderr, "|"); if (Succ != NULL) count++; - //fprintf(stderr, "%p", Succ); + // fprintf(stderr, "%p", Succ); + } - if (count > 1) - more_than_one = 1; + + if (count > 1) more_than_one = 1; + } - //fprintf(stderr, " == %d\n", more_than_one); - if (more_than_one != 1) - continue; + + // fprintf(stderr, " == %d\n", more_than_one); + if (more_than_one != 1) continue; ConstantInt *CurLoc = ConstantInt::get(Int32Ty, cur_loc); @@ -236,7 +262,8 @@ bool AFLCoverage::runOnModule(Module &M) { LoadInst *MapPtr = IRB.CreateLoad(AFLMapPtr); MapPtr->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); - Value *MapPtrIdx = IRB.CreateGEP(MapPtr, IRB.CreateXor(PrevLocCasted, CurLoc)); + Value *MapPtrIdx = + IRB.CreateGEP(MapPtr, IRB.CreateXor(PrevLocCasted, CurLoc)); /* Update bitmap */ @@ -246,7 +273,9 @@ bool AFLCoverage::runOnModule(Module &M) { Value *Incr = IRB.CreateAdd(Counter, ConstantInt::get(Int8Ty, 1)); #if LLVM_VERSION_MAJOR < 9 - if (neverZero_counters_str != NULL) { // with llvm 9 we make this the default as the bug in llvm is then fixed + if (neverZero_counters_str != + NULL) { // with llvm 9 we make this the default as the bug in llvm is + // then fixed #endif /* hexcoder: Realize a counter that skips zero during overflow. * Once this counter reaches its maximum value, it next increments to 1 @@ -257,48 +286,67 @@ bool AFLCoverage::runOnModule(Module &M) { * Counter + 1 -> {Counter, OverflowFlag} * Counter + OverflowFlag -> Counter */ -/* // we keep the old solutions just in case - // Solution #1 - if (neverZero_counters_str[0] == '1') { - CallInst *AddOv = IRB.CreateBinaryIntrinsic(Intrinsic::uadd_with_overflow, Counter, ConstantInt::get(Int8Ty, 1)); - AddOv->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); - Value *SumWithOverflowBit = AddOv; - Incr = IRB.CreateAdd(IRB.CreateExtractValue(SumWithOverflowBit, 0), // sum - IRB.CreateZExt( // convert from one bit type to 8 bits type - IRB.CreateExtractValue(SumWithOverflowBit, 1), // overflow - Int8Ty)); - // Solution #2 - } else if (neverZero_counters_str[0] == '2') { - auto cf = IRB.CreateICmpEQ(Counter, ConstantInt::get(Int8Ty, 255)); - Value *HowMuch = IRB.CreateAdd(ConstantInt::get(Int8Ty, 1), cf); - Incr = IRB.CreateAdd(Counter, HowMuch); - // Solution #3 - } else if (neverZero_counters_str[0] == '3') { -*/ - // this is the solution we choose because llvm9 should do the right thing here - auto cf = IRB.CreateICmpEQ(Incr, ConstantInt::get(Int8Ty, 0)); - auto carry = IRB.CreateZExt(cf, Int8Ty); - Incr = IRB.CreateAdd(Incr, carry); + /* // we keep the old solutions just in case + // Solution #1 + if (neverZero_counters_str[0] == '1') { + + CallInst *AddOv = + IRB.CreateBinaryIntrinsic(Intrinsic::uadd_with_overflow, Counter, + ConstantInt::get(Int8Ty, 1)); + AddOv->setMetadata(M.getMDKindID("nosanitize"), + MDNode::get(C, None)); Value *SumWithOverflowBit = AddOv; Incr = + IRB.CreateAdd(IRB.CreateExtractValue(SumWithOverflowBit, 0), // sum + IRB.CreateZExt( // convert from one bit + type to 8 bits type IRB.CreateExtractValue(SumWithOverflowBit, 1), // + overflow Int8Ty)); + // Solution #2 + + } else if (neverZero_counters_str[0] == '2') { + + auto cf = IRB.CreateICmpEQ(Counter, + ConstantInt::get(Int8Ty, 255)); Value *HowMuch = + IRB.CreateAdd(ConstantInt::get(Int8Ty, 1), cf); Incr = + IRB.CreateAdd(Counter, HowMuch); + // Solution #3 + + } else if (neverZero_counters_str[0] == '3') { + + */ + // this is the solution we choose because llvm9 should do the right + // thing here + auto cf = IRB.CreateICmpEQ(Incr, ConstantInt::get(Int8Ty, 0)); + auto carry = IRB.CreateZExt(cf, Int8Ty); + Incr = IRB.CreateAdd(Incr, carry); /* // Solution #4 + } else if (neverZero_counters_str[0] == '4') { + auto cf = IRB.CreateICmpULT(Incr, ConstantInt::get(Int8Ty, 1)); auto carry = IRB.CreateZExt(cf, Int8Ty); Incr = IRB.CreateAdd(Incr, carry); + } else { - fprintf(stderr, "Error: unknown value for AFL_NZERO_COUNTS: %s (valid is 1-4)\n", neverZero_counters_str); - exit(-1); + + fprintf(stderr, "Error: unknown value for AFL_NZERO_COUNTS: %s + (valid is 1-4)\n", neverZero_counters_str); exit(-1); + } + */ #if LLVM_VERSION_MAJOR < 9 + } + #endif - IRB.CreateStore(Incr, MapPtrIdx)->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); + IRB.CreateStore(Incr, MapPtrIdx) + ->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); /* Set prev_loc to cur_loc >> 1 */ - StoreInst *Store = IRB.CreateStore(ConstantInt::get(Int32Ty, cur_loc >> 1), AFLPrevLoc); + StoreInst *Store = + IRB.CreateStore(ConstantInt::get(Int32Ty, cur_loc >> 1), AFLPrevLoc); Store->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None)); inst_blocks++; @@ -309,11 +357,16 @@ bool AFLCoverage::runOnModule(Module &M) { if (!be_quiet) { - if (!inst_blocks) WARNF("No instrumentation targets found."); - else OKF("Instrumented %u locations (%s mode, ratio %u%%).", - inst_blocks, getenv("AFL_HARDEN") ? "hardened" : - ((getenv("AFL_USE_ASAN") || getenv("AFL_USE_MSAN")) ? - "ASAN/MSAN" : "non-hardened"), inst_ratio); + if (!inst_blocks) + WARNF("No instrumentation targets found."); + else + OKF("Instrumented %u locations (%s mode, ratio %u%%).", inst_blocks, + getenv("AFL_HARDEN") + ? "hardened" + : ((getenv("AFL_USE_ASAN") || getenv("AFL_USE_MSAN")) + ? "ASAN/MSAN" + : "non-hardened"), + inst_ratio); } @@ -321,7 +374,6 @@ bool AFLCoverage::runOnModule(Module &M) { } - static void registerAFLPass(const PassManagerBuilder &, legacy::PassManagerBase &PM) { @@ -329,9 +381,9 @@ static void registerAFLPass(const PassManagerBuilder &, } - static RegisterStandardPasses RegisterAFLPass( PassManagerBuilder::EP_OptimizerLast, registerAFLPass); static RegisterStandardPasses RegisterAFLPass0( PassManagerBuilder::EP_EnabledOnOptLevel0, registerAFLPass); + diff --git a/llvm_mode/afl-llvm-rt.o.c b/llvm_mode/afl-llvm-rt.o.c index e6d9b993..bc38f1ec 100644 --- a/llvm_mode/afl-llvm-rt.o.c +++ b/llvm_mode/afl-llvm-rt.o.c @@ -20,7 +20,7 @@ */ #ifdef __ANDROID__ - #include "android-ashmem.h" +# include "android-ashmem.h" #endif #include "config.h" #include "types.h" @@ -50,10 +50,9 @@ #include <sys/mman.h> #include <fcntl.h> - /* Globals needed by the injected instrumentation. The __afl_area_initial region - is used for instrumentation output before __afl_map_shm() has a chance to run. - It will end up as .comm, so it shouldn't be too wasteful. */ + is used for instrumentation output before __afl_map_shm() has a chance to + run. It will end up as .comm, so it shouldn't be too wasteful. */ u8 __afl_area_initial[MAP_SIZE]; u8* __afl_area_ptr = __afl_area_initial; @@ -64,43 +63,46 @@ u32 __afl_prev_loc; __thread u32 __afl_prev_loc; #endif - /* Running in persistent mode? */ static u8 is_persistent; - /* SHM setup. */ static void __afl_map_shm(void) { - u8 *id_str = getenv(SHM_ENV_VAR); + u8* id_str = getenv(SHM_ENV_VAR); /* If we're running under AFL, attach to the appropriate region, replacing the early-stage __afl_area_initial region that is needed to allow some really hacky .init code to work correctly in projects such as OpenSSL. */ if (id_str) { + #ifdef USEMMAP - const char *shm_file_path = id_str; - int shm_fd = -1; - unsigned char *shm_base = NULL; + const char* shm_file_path = id_str; + int shm_fd = -1; + unsigned char* shm_base = NULL; /* create the shared memory segment as if it was a file */ shm_fd = shm_open(shm_file_path, O_RDWR, 0600); if (shm_fd == -1) { + printf("shm_open() failed\n"); exit(1); + } /* map the shared memory segment to the address space of the process */ shm_base = mmap(0, MAP_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); if (shm_base == MAP_FAILED) { + close(shm_fd); shm_fd = -1; printf("mmap() failed\n"); exit(2); + } __afl_area_ptr = shm_base; @@ -112,7 +114,7 @@ static void __afl_map_shm(void) { /* Whooooops. */ - if (__afl_area_ptr == (void *)-1) _exit(1); + if (__afl_area_ptr == (void*)-1) _exit(1); /* Write something into the bitmap so that even with low AFL_INST_RATIO, our parent doesn't give up on us. */ @@ -123,16 +125,15 @@ static void __afl_map_shm(void) { } - /* Fork server logic. */ static void __afl_start_forkserver(void) { static u8 tmp[4]; - s32 child_pid; + s32 child_pid; + + u8 child_stopped = 0; - u8 child_stopped = 0; - void (*old_sigchld_handler)(int) = signal(SIGCHLD, SIG_DFL); /* Phone home and tell the parent that we're OK. If parent isn't there, @@ -154,8 +155,10 @@ static void __afl_start_forkserver(void) { process. */ if (child_stopped && was_killed) { + child_stopped = 0; if (waitpid(child_pid, &status, 0) < 0) _exit(1); + } if (!child_stopped) { @@ -168,12 +171,13 @@ static void __afl_start_forkserver(void) { /* In child process: close fds, resume execution. */ if (!child_pid) { + signal(SIGCHLD, old_sigchld_handler); close(FORKSRV_FD); close(FORKSRV_FD + 1); return; - + } } else { @@ -207,7 +211,6 @@ static void __afl_start_forkserver(void) { } - /* A simplified persistent mode handler, used as explained in README.llvm. */ int __afl_persistent_loop(unsigned int max_cnt) { @@ -227,9 +230,10 @@ int __afl_persistent_loop(unsigned int max_cnt) { memset(__afl_area_ptr, 0, MAP_SIZE); __afl_area_ptr[0] = 1; __afl_prev_loc = 0; + } - cycle_cnt = max_cnt; + cycle_cnt = max_cnt; first_pass = 0; return 1; @@ -262,7 +266,6 @@ int __afl_persistent_loop(unsigned int max_cnt) { } - /* This one can be called from user code when deferred forkserver mode is enabled. */ @@ -280,7 +283,6 @@ void __afl_manual_init(void) { } - /* Proper initialization routine. */ __attribute__((constructor(CONST_PRIO))) void __afl_auto_init(void) { @@ -293,7 +295,6 @@ __attribute__((constructor(CONST_PRIO))) void __afl_auto_init(void) { } - /* The following stuff deals with supporting -fsanitize-coverage=trace-pc-guard. It remains non-operational in the traditional, plugin-backed LLVM mode. For more info about 'trace-pc-guard', see README.llvm. @@ -302,9 +303,10 @@ __attribute__((constructor(CONST_PRIO))) void __afl_auto_init(void) { edge (as opposed to every basic block). */ void __sanitizer_cov_trace_pc_guard(uint32_t* guard) { + __afl_area_ptr[*guard]++; -} +} /* Init callback. Populates instrumentation IDs. Note that we're using ID of 0 as a special value to indicate non-instrumented bits. That may @@ -321,8 +323,10 @@ void __sanitizer_cov_trace_pc_guard_init(uint32_t* start, uint32_t* stop) { if (x) inst_ratio = atoi(x); if (!inst_ratio || inst_ratio > 100) { + fprintf(stderr, "[-] ERROR: Invalid AFL_INST_RATIO (must be 1-100).\n"); abort(); + } /* Make sure that the first element in the range is always set - we use that @@ -333,11 +337,14 @@ void __sanitizer_cov_trace_pc_guard_init(uint32_t* start, uint32_t* stop) { while (start < stop) { - if (R(100) < inst_ratio) *start = R(MAP_SIZE - 1) + 1; - else *start = 0; + if (R(100) < inst_ratio) + *start = R(MAP_SIZE - 1) + 1; + else + *start = 0; start++; } } + diff --git a/llvm_mode/compare-transform-pass.so.cc b/llvm_mode/compare-transform-pass.so.cc index e7886db1..e1b6e671 100644 --- a/llvm_mode/compare-transform-pass.so.cc +++ b/llvm_mode/compare-transform-pass.so.cc @@ -36,202 +36,236 @@ using namespace llvm; namespace { - class CompareTransform : public ModulePass { +class CompareTransform : public ModulePass { - public: - static char ID; - CompareTransform() : ModulePass(ID) { - } + public: + static char ID; + CompareTransform() : ModulePass(ID) { - bool runOnModule(Module &M) override; + } + + bool runOnModule(Module &M) override; #if LLVM_VERSION_MAJOR < 4 - const char * getPassName() const override { + const char *getPassName() const override { + #else - StringRef getPassName() const override { + StringRef getPassName() const override { + #endif - return "transforms compare functions"; - } - private: - bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp - ,const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp); - }; -} + return "transforms compare functions"; + } + + private: + bool transformCmps(Module &M, const bool processStrcmp, + const bool processMemcmp, const bool processStrncmp, + const bool processStrcasecmp, + const bool processStrncasecmp); + +}; + +} // namespace char CompareTransform::ID = 0; -bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp - , const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) { +bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, + const bool processMemcmp, + const bool processStrncmp, + const bool processStrcasecmp, + const bool processStrncasecmp) { - std::vector<CallInst*> calls; - LLVMContext &C = M.getContext(); - IntegerType *Int8Ty = IntegerType::getInt8Ty(C); - IntegerType *Int32Ty = IntegerType::getInt32Ty(C); - IntegerType *Int64Ty = IntegerType::getInt64Ty(C); + std::vector<CallInst *> calls; + LLVMContext & C = M.getContext(); + IntegerType * Int8Ty = IntegerType::getInt8Ty(C); + IntegerType * Int32Ty = IntegerType::getInt32Ty(C); + IntegerType * Int64Ty = IntegerType::getInt64Ty(C); #if LLVM_VERSION_MAJOR < 9 - Constant* + Constant * #else FunctionCallee #endif - c = M.getOrInsertFunction("tolower", - Int32Ty, - Int32Ty + c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty #if LLVM_VERSION_MAJOR < 5 - , nullptr + , + nullptr #endif - ); + ); #if LLVM_VERSION_MAJOR < 9 - Function* tolowerFn = cast<Function>(c); + Function *tolowerFn = cast<Function>(c); #else FunctionCallee tolowerFn = c; #endif - /* iterate over all functions, bbs and instruction and add suitable calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp */ + /* iterate over all functions, bbs and instruction and add suitable calls to + * strcmp/memcmp/strncmp/strcasecmp/strncasecmp */ for (auto &F : M) { + for (auto &BB : F) { - for(auto &IN: BB) { - CallInst* callInst = nullptr; + + for (auto &IN : BB) { + + CallInst *callInst = nullptr; if ((callInst = dyn_cast<CallInst>(&IN))) { - bool isStrcmp = processStrcmp; - bool isMemcmp = processMemcmp; - bool isStrncmp = processStrncmp; - bool isStrcasecmp = processStrcasecmp; + bool isStrcmp = processStrcmp; + bool isMemcmp = processMemcmp; + bool isStrncmp = processStrncmp; + bool isStrcasecmp = processStrcasecmp; bool isStrncasecmp = processStrncasecmp; Function *Callee = callInst->getCalledFunction(); - if (!Callee) - continue; - if (callInst->getCallingConv() != llvm::CallingConv::C) - continue; + if (!Callee) continue; + if (callInst->getCallingConv() != llvm::CallingConv::C) continue; StringRef FuncName = Callee->getName(); - isStrcmp &= !FuncName.compare(StringRef("strcmp")); - isMemcmp &= !FuncName.compare(StringRef("memcmp")); - isStrncmp &= !FuncName.compare(StringRef("strncmp")); - isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp")); + isStrcmp &= !FuncName.compare(StringRef("strcmp")); + isMemcmp &= !FuncName.compare(StringRef("memcmp")); + isStrncmp &= !FuncName.compare(StringRef("strncmp")); + isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp")); isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp")); - if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp) + if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && + !isStrncasecmp) continue; - /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function prototype */ + /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function + * prototype */ FunctionType *FT = Callee->getFunctionType(); - - isStrcmp &= FT->getNumParams() == 2 && - FT->getReturnType()->isIntegerTy(32) && - FT->getParamType(0) == FT->getParamType(1) && - FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); - isStrcasecmp &= FT->getNumParams() == 2 && - FT->getReturnType()->isIntegerTy(32) && - FT->getParamType(0) == FT->getParamType(1) && - FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); - isMemcmp &= FT->getNumParams() == 3 && + isStrcmp &= + FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); + isStrcasecmp &= + FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()); + isMemcmp &= FT->getNumParams() == 3 && FT->getReturnType()->isIntegerTy(32) && FT->getParamType(0)->isPointerTy() && FT->getParamType(1)->isPointerTy() && FT->getParamType(2)->isIntegerTy(); - isStrncmp &= FT->getNumParams() == 3 && - FT->getReturnType()->isIntegerTy(32) && - FT->getParamType(0) == FT->getParamType(1) && - FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) && - FT->getParamType(2)->isIntegerTy(); + isStrncmp &= FT->getNumParams() == 3 && + FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == + IntegerType::getInt8PtrTy(M.getContext()) && + FT->getParamType(2)->isIntegerTy(); isStrncasecmp &= FT->getNumParams() == 3 && - FT->getReturnType()->isIntegerTy(32) && - FT->getParamType(0) == FT->getParamType(1) && - FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) && - FT->getParamType(2)->isIntegerTy(); - - if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp) + FT->getReturnType()->isIntegerTy(32) && + FT->getParamType(0) == FT->getParamType(1) && + FT->getParamType(0) == + IntegerType::getInt8PtrTy(M.getContext()) && + FT->getParamType(2)->isIntegerTy(); + + if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && + !isStrncasecmp) continue; /* is a str{n,}{case,}cmp/memcmp, check if we have * str{case,}cmp(x, "const") or str{case,}cmp("const", x) * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..) * memcmp(x, "const", ..) or memcmp("const", x, ..) */ - Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1); + Value *Str1P = callInst->getArgOperand(0), + *Str2P = callInst->getArgOperand(1); StringRef Str1, Str2; - bool HasStr1 = getConstantStringInfo(Str1P, Str1); - bool HasStr2 = getConstantStringInfo(Str2P, Str2); + bool HasStr1 = getConstantStringInfo(Str1P, Str1); + bool HasStr2 = getConstantStringInfo(Str2P, Str2); /* handle cases of one string is const, one string is variable */ - if (!(HasStr1 ^ HasStr2)) - continue; + if (!(HasStr1 ^ HasStr2)) continue; if (isMemcmp || isStrncmp || isStrncasecmp) { + /* check if third operand is a constant integer * strlen("constStr") and sizeof() are treated as constant */ - Value *op2 = callInst->getArgOperand(2); - ConstantInt* ilen = dyn_cast<ConstantInt>(op2); - if (!ilen) - continue; - /* final precaution: if size of compare is larger than constant string skip it*/ - uint64_t literalLength = HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P); - if (literalLength < ilen->getZExtValue()) - continue; + Value * op2 = callInst->getArgOperand(2); + ConstantInt *ilen = dyn_cast<ConstantInt>(op2); + if (!ilen) continue; + /* final precaution: if size of compare is larger than constant + * string skip it*/ + uint64_t literalLength = + HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P); + if (literalLength < ilen->getZExtValue()) continue; + } calls.push_back(callInst); + } + } + } + } - if (!calls.size()) - return false; - errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n"; + if (!calls.size()) return false; + errs() << "Replacing " << calls.size() + << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n"; - for (auto &callInst: calls) { + for (auto &callInst : calls) { - Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1); - StringRef Str1, Str2, ConstStr; + Value *Str1P = callInst->getArgOperand(0), + *Str2P = callInst->getArgOperand(1); + StringRef Str1, Str2, ConstStr; std::string TmpConstStr; - Value *VarStr; - bool HasStr1 = getConstantStringInfo(Str1P, Str1); + Value * VarStr; + bool HasStr1 = getConstantStringInfo(Str1P, Str1); getConstantStringInfo(Str2P, Str2); uint64_t constLen, sizedLen; - bool isMemcmp = !callInst->getCalledFunction()->getName().compare(StringRef("memcmp")); - bool isSizedcmp = isMemcmp - || !callInst->getCalledFunction()->getName().compare(StringRef("strncmp")) - || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp")); - bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(StringRef("strcasecmp")) - || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp")); + bool isMemcmp = + !callInst->getCalledFunction()->getName().compare(StringRef("memcmp")); + bool isSizedcmp = isMemcmp || + !callInst->getCalledFunction()->getName().compare( + StringRef("strncmp")) || + !callInst->getCalledFunction()->getName().compare( + StringRef("strncasecmp")); + bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare( + StringRef("strcasecmp")) || + !callInst->getCalledFunction()->getName().compare( + StringRef("strncasecmp")); if (isSizedcmp) { - Value *op2 = callInst->getArgOperand(2); - ConstantInt* ilen = dyn_cast<ConstantInt>(op2); + + Value * op2 = callInst->getArgOperand(2); + ConstantInt *ilen = dyn_cast<ConstantInt>(op2); sizedLen = ilen->getZExtValue(); + } if (HasStr1) { + TmpConstStr = Str1.str(); VarStr = Str2P; constLen = isMemcmp ? sizedLen : GetStringLength(Str1P); - } - else { + + } else { + TmpConstStr = Str2.str(); VarStr = Str1P; constLen = isMemcmp ? sizedLen : GetStringLength(Str2P); + } /* properly handle zero terminated C strings by adding the terminating 0 to * the StringRef (in comparison to std::string a StringRef has built-in * runtime bounds checking, which makes debugging easier) */ - TmpConstStr.append("\0", 1); ConstStr = StringRef(TmpConstStr); + TmpConstStr.append("\0", 1); + ConstStr = StringRef(TmpConstStr); - if (isSizedcmp && constLen > sizedLen) { - constLen = sizedLen; - } + if (isSizedcmp && constLen > sizedLen) { constLen = sizedLen; } - errs() << callInst->getCalledFunction()->getName() << ": len " << constLen << ": " << ConstStr << "\n"; + errs() << callInst->getCalledFunction()->getName() << ": len " << constLen + << ": " << ConstStr << "\n"; /* split before the call instruction */ BasicBlock *bb = callInst->getParent(); BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst)); - BasicBlock *next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); + BasicBlock *next_bb = + BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); BranchInst::Create(end_bb, next_bb); PHINode *PN = PHINode::Create(Int32Ty, constLen + 1, "cmp_phi"); @@ -249,71 +283,81 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const char c = isCaseInsensitive ? tolower(ConstStr[i]) : ConstStr[i]; - BasicBlock::iterator IP = next_bb->getFirstInsertionPt(); - IRBuilder<> IRB(&*IP); + IRBuilder<> IRB(&*IP); - Value* v = ConstantInt::get(Int64Ty, i); - Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty"); + Value *v = ConstantInt::get(Int64Ty, i); + Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty"); Value *load = IRB.CreateLoad(ele); if (isCaseInsensitive) { + // load >= 'A' && load <= 'Z' ? load | 0x020 : load std::vector<Value *> args; args.push_back(load); load = IRB.CreateCall(tolowerFn, args, "tmp"); load = IRB.CreateTrunc(load, Int8Ty); + } + Value *isub; if (HasStr1) isub = IRB.CreateSub(ConstantInt::get(Int8Ty, c), load); else isub = IRB.CreateSub(load, ConstantInt::get(Int8Ty, c)); - Value *sext = IRB.CreateSExt(isub, Int32Ty); + Value *sext = IRB.CreateSExt(isub, Int32Ty); PN->addIncoming(sext, cur_bb); - if (i < constLen - 1) { - next_bb = BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); + + next_bb = + BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb); BranchInst::Create(end_bb, next_bb); Value *icmp = IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0)); IRB.CreateCondBr(icmp, next_bb, end_bb); cur_bb->getTerminator()->eraseFromParent(); + } else { - //IRB.CreateBr(end_bb); + + // IRB.CreateBr(end_bb); + } - //add offset to varstr - //create load - //create signed isub - //create icmp - //create jcc - //create next_bb + // add offset to varstr + // create load + // create signed isub + // create icmp + // create jcc + // create next_bb + } /* since the call is the first instruction of the bb it is safe to * replace it with a phi instruction */ BasicBlock::iterator ii(callInst); ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN); - } + } return true; + } bool CompareTransform::runOnModule(Module &M) { if (getenv("AFL_QUIET") == NULL) - llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n"; + llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, " + "extended by heiko@hexco.de\n"; transformCmps(M, true, true, true, true, true); verifyModule(M); return true; + } static void registerCompTransPass(const PassManagerBuilder &, - legacy::PassManagerBase &PM) { + legacy::PassManagerBase &PM) { auto p = new CompareTransform(); PM.add(p); diff --git a/llvm_mode/split-compares-pass.so.cc b/llvm_mode/split-compares-pass.so.cc index a74b60fa..1e9d6542 100644 --- a/llvm_mode/split-compares-pass.so.cc +++ b/llvm_mode/split-compares-pass.so.cc @@ -27,117 +27,126 @@ using namespace llvm; namespace { - class SplitComparesTransform : public ModulePass { - public: - static char ID; - SplitComparesTransform() : ModulePass(ID) {} - bool runOnModule(Module &M) override; +class SplitComparesTransform : public ModulePass { + + public: + static char ID; + SplitComparesTransform() : ModulePass(ID) { + + } + + bool runOnModule(Module &M) override; #if LLVM_VERSION_MAJOR >= 4 - StringRef getPassName() const override { + StringRef getPassName() const override { + #else - const char * getPassName() const override { + const char *getPassName() const override { + #endif - return "simplifies and splits ICMP instructions"; - } - private: - bool splitCompares(Module &M, unsigned bitw); - bool simplifyCompares(Module &M); - bool simplifySignedness(Module &M); + return "simplifies and splits ICMP instructions"; - }; -} + } + + private: + bool splitCompares(Module &M, unsigned bitw); + bool simplifyCompares(Module &M); + bool simplifySignedness(Module &M); + +}; + +} // namespace char SplitComparesTransform::ID = 0; -/* This function splits ICMP instructions with xGE or xLE predicates into two +/* This function splits ICMP instructions with xGE or xLE predicates into two * ICMP instructions with predicate xGT or xLT and EQ */ bool SplitComparesTransform::simplifyCompares(Module &M) { - LLVMContext &C = M.getContext(); - std::vector<Instruction*> icomps; - IntegerType *Int1Ty = IntegerType::getInt1Ty(C); + + LLVMContext & C = M.getContext(); + std::vector<Instruction *> icomps; + IntegerType * Int1Ty = IntegerType::getInt1Ty(C); /* iterate over all functions, bbs and instruction and add * all integer comparisons with >= and <= predicates to the icomps vector */ for (auto &F : M) { + for (auto &BB : F) { - for (auto &IN: BB) { - CmpInst* selectcmpInst = nullptr; + + for (auto &IN : BB) { + + CmpInst *selectcmpInst = nullptr; if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { if (selectcmpInst->getPredicate() != CmpInst::ICMP_UGE && selectcmpInst->getPredicate() != CmpInst::ICMP_SGE && selectcmpInst->getPredicate() != CmpInst::ICMP_ULE && - selectcmpInst->getPredicate() != CmpInst::ICMP_SLE ) { + selectcmpInst->getPredicate() != CmpInst::ICMP_SLE) { + continue; + } auto op0 = selectcmpInst->getOperand(0); auto op1 = selectcmpInst->getOperand(1); - IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType* intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); /* this is probably not needed but we do it anyway */ - if (!intTyOp0 || !intTyOp1) { - continue; - } + if (!intTyOp0 || !intTyOp1) { continue; } icomps.push_back(selectcmpInst); + } + } + } - } - if (!icomps.size()) { - return false; } + if (!icomps.size()) { return false; } + + for (auto &IcmpInst : icomps) { - for (auto &IcmpInst: icomps) { - BasicBlock* bb = IcmpInst->getParent(); + BasicBlock *bb = IcmpInst->getParent(); auto op0 = IcmpInst->getOperand(0); auto op1 = IcmpInst->getOperand(1); /* find out what the new predicate is going to be */ - auto pred = dyn_cast<CmpInst>(IcmpInst)->getPredicate(); + auto pred = dyn_cast<CmpInst>(IcmpInst)->getPredicate(); CmpInst::Predicate new_pred; - switch(pred) { - case CmpInst::ICMP_UGE: - new_pred = CmpInst::ICMP_UGT; - break; - case CmpInst::ICMP_SGE: - new_pred = CmpInst::ICMP_SGT; - break; - case CmpInst::ICMP_ULE: - new_pred = CmpInst::ICMP_ULT; - break; - case CmpInst::ICMP_SLE: - new_pred = CmpInst::ICMP_SLT; - break; - default: // keep the compiler happy + switch (pred) { + + case CmpInst::ICMP_UGE: new_pred = CmpInst::ICMP_UGT; break; + case CmpInst::ICMP_SGE: new_pred = CmpInst::ICMP_SGT; break; + case CmpInst::ICMP_ULE: new_pred = CmpInst::ICMP_ULT; break; + case CmpInst::ICMP_SLE: new_pred = CmpInst::ICMP_SLT; break; + default: // keep the compiler happy continue; + } /* split before the icmp instruction */ - BasicBlock* end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); /* the old bb now contains a unconditional jump to the new one (end_bb) * we need to delete it later */ /* create the ICMP instruction with new_pred and add it to the old basic * block bb it is now at the position where the old IcmpInst was */ - Instruction* icmp_np; + Instruction *icmp_np; icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_np); /* create a new basic block which holds the new EQ icmp */ Instruction *icmp_eq; /* insert middle_bb before end_bb */ - BasicBlock* middle_bb = BasicBlock::Create(C, "injected", - end_bb->getParent(), end_bb); + BasicBlock *middle_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1); middle_bb->getInstList().push_back(icmp_eq); /* add an unconditional branch to the end of middle_bb with destination @@ -150,7 +159,6 @@ bool SplitComparesTransform::simplifyCompares(Module &M) { BranchInst::Create(end_bb, middle_bb, icmp_np, bb); term->eraseFromParent(); - /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI * inst to wire up the loose ends */ PHINode *PN = PHINode::Create(Int1Ty, 2, ""); @@ -162,118 +170,139 @@ bool SplitComparesTransform::simplifyCompares(Module &M) { /* replace the old IcmpInst with our new and shiny PHI inst */ BasicBlock::iterator ii(IcmpInst); ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + } return true; + } /* this function transforms signed compares to equivalent unsigned compares */ bool SplitComparesTransform::simplifySignedness(Module &M) { - LLVMContext &C = M.getContext(); - std::vector<Instruction*> icomps; - IntegerType *Int1Ty = IntegerType::getInt1Ty(C); + + LLVMContext & C = M.getContext(); + std::vector<Instruction *> icomps; + IntegerType * Int1Ty = IntegerType::getInt1Ty(C); /* iterate over all functions, bbs and instruction and add * all signed compares to icomps vector */ for (auto &F : M) { + for (auto &BB : F) { - for(auto &IN: BB) { - CmpInst* selectcmpInst = nullptr; + + for (auto &IN : BB) { + + CmpInst *selectcmpInst = nullptr; if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { if (selectcmpInst->getPredicate() != CmpInst::ICMP_SGT && - selectcmpInst->getPredicate() != CmpInst::ICMP_SLT - ) { + selectcmpInst->getPredicate() != CmpInst::ICMP_SLT) { + continue; + } auto op0 = selectcmpInst->getOperand(0); auto op1 = selectcmpInst->getOperand(1); - IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType* intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); /* see above */ - if (!intTyOp0 || !intTyOp1) { - continue; - } + if (!intTyOp0 || !intTyOp1) { continue; } /* i think this is not possible but to lazy to look it up */ - if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { - continue; - } + if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { continue; } icomps.push_back(selectcmpInst); + } + } + } - } - if (!icomps.size()) { - return false; } - for (auto &IcmpInst: icomps) { - BasicBlock* bb = IcmpInst->getParent(); + if (!icomps.size()) { return false; } + + for (auto &IcmpInst : icomps) { + + BasicBlock *bb = IcmpInst->getParent(); auto op0 = IcmpInst->getOperand(0); auto op1 = IcmpInst->getOperand(1); - IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - unsigned bitw = intTyOp0->getBitWidth(); + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + unsigned bitw = intTyOp0->getBitWidth(); IntegerType *IntType = IntegerType::get(C, bitw); - /* get the new predicate */ - auto pred = dyn_cast<CmpInst>(IcmpInst)->getPredicate(); + auto pred = dyn_cast<CmpInst>(IcmpInst)->getPredicate(); CmpInst::Predicate new_pred; if (pred == CmpInst::ICMP_SGT) { + new_pred = CmpInst::ICMP_UGT; + } else { + new_pred = CmpInst::ICMP_ULT; + } - BasicBlock* end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); /* create a 1 bit compare for the sign bit. to do this shift and trunc * the original operands so only the first bit remains.*/ Instruction *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit; - s_op0 = BinaryOperator::Create(Instruction::LShr, op0, ConstantInt::get(IntType, bitw - 1)); + s_op0 = BinaryOperator::Create(Instruction::LShr, op0, + ConstantInt::get(IntType, bitw - 1)); bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op0); t_op0 = new TruncInst(s_op0, Int1Ty); bb->getInstList().insert(bb->getTerminator()->getIterator(), t_op0); - s_op1 = BinaryOperator::Create(Instruction::LShr, op1, ConstantInt::get(IntType, bitw - 1)); + s_op1 = BinaryOperator::Create(Instruction::LShr, op1, + ConstantInt::get(IntType, bitw - 1)); bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op1); t_op1 = new TruncInst(s_op1, Int1Ty); bb->getInstList().insert(bb->getTerminator()->getIterator(), t_op1); /* compare of the sign bits */ - icmp_sign_bit = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1); + icmp_sign_bit = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1); bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_sign_bit); /* create a new basic block which is executed if the signedness bit is - * different */ + * different */ Instruction *icmp_inv_sig_cmp; - BasicBlock* sign_bb = BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb); + BasicBlock * sign_bb = + BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb); if (pred == CmpInst::ICMP_SGT) { + /* if we check for > and the op0 positive and op1 negative then the final * result is true. if op0 negative and op1 pos, the cmp must result * in false */ - icmp_inv_sig_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1); + icmp_inv_sig_cmp = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1); + } else { + /* just the inverse of the above statement */ - icmp_inv_sig_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1); + icmp_inv_sig_cmp = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1); + } + sign_bb->getInstList().push_back(icmp_inv_sig_cmp); BranchInst::Create(end_bb, sign_bb); /* create a new bb which is executed if signedness is equal */ Instruction *icmp_usign_cmp; - BasicBlock* middle_bb = BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + BasicBlock * middle_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); /* we can do a normal unsigned compare now */ icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); middle_bb->getInstList().push_back(icmp_usign_cmp); @@ -285,7 +314,6 @@ bool SplitComparesTransform::simplifySignedness(Module &M) { BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb); term->eraseFromParent(); - PHINode *PN = PHINode::Create(Int1Ty, 2, ""); PN->addIncoming(icmp_usign_cmp, middle_bb); @@ -293,91 +321,100 @@ bool SplitComparesTransform::simplifySignedness(Module &M) { BasicBlock::iterator ii(IcmpInst); ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + } return true; + } /* splits icmps of size bitw into two nested icmps with bitw/2 size each */ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { + LLVMContext &C = M.getContext(); IntegerType *Int1Ty = IntegerType::getInt1Ty(C); IntegerType *OldIntType = IntegerType::get(C, bitw); IntegerType *NewIntType = IntegerType::get(C, bitw / 2); - std::vector<Instruction*> icomps; + std::vector<Instruction *> icomps; - if (bitw % 2) { - return false; - } + if (bitw % 2) { return false; } /* not supported yet */ - if (bitw > 64) { - return false; - } + if (bitw > 64) { return false; } - /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the other two + /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the other two * unctions were executed only these four predicates should exist */ for (auto &F : M) { + for (auto &BB : F) { - for(auto &IN: BB) { - CmpInst* selectcmpInst = nullptr; + + for (auto &IN : BB) { + + CmpInst *selectcmpInst = nullptr; if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { - if(selectcmpInst->getPredicate() != CmpInst::ICMP_EQ && - selectcmpInst->getPredicate() != CmpInst::ICMP_NE && - selectcmpInst->getPredicate() != CmpInst::ICMP_UGT && - selectcmpInst->getPredicate() != CmpInst::ICMP_ULT - ) { + if (selectcmpInst->getPredicate() != CmpInst::ICMP_EQ && + selectcmpInst->getPredicate() != CmpInst::ICMP_NE && + selectcmpInst->getPredicate() != CmpInst::ICMP_UGT && + selectcmpInst->getPredicate() != CmpInst::ICMP_ULT) { + continue; + } auto op0 = selectcmpInst->getOperand(0); auto op1 = selectcmpInst->getOperand(1); - IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType* intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); - if (!intTyOp0 || !intTyOp1) { - continue; - } + if (!intTyOp0 || !intTyOp1) { continue; } /* check if the bitwidths are the one we are looking for */ - if (intTyOp0->getBitWidth() != bitw || intTyOp1->getBitWidth() != bitw) { + if (intTyOp0->getBitWidth() != bitw || + intTyOp1->getBitWidth() != bitw) { + continue; + } icomps.push_back(selectcmpInst); + } + } + } - } - if (!icomps.size()) { - return false; } - for (auto &IcmpInst: icomps) { - BasicBlock* bb = IcmpInst->getParent(); + if (!icomps.size()) { return false; } + + for (auto &IcmpInst : icomps) { + + BasicBlock *bb = IcmpInst->getParent(); auto op0 = IcmpInst->getOperand(0); auto op1 = IcmpInst->getOperand(1); auto pred = dyn_cast<CmpInst>(IcmpInst)->getPredicate(); - BasicBlock* end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); /* create the comparison of the top halves of the original operands */ Instruction *s_op0, *op0_high, *s_op1, *op1_high, *icmp_high; - s_op0 = BinaryOperator::Create(Instruction::LShr, op0, ConstantInt::get(OldIntType, bitw / 2)); + s_op0 = BinaryOperator::Create(Instruction::LShr, op0, + ConstantInt::get(OldIntType, bitw / 2)); bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op0); op0_high = new TruncInst(s_op0, NewIntType); bb->getInstList().insert(bb->getTerminator()->getIterator(), op0_high); - s_op1 = BinaryOperator::Create(Instruction::LShr, op1, ConstantInt::get(OldIntType, bitw / 2)); + s_op1 = BinaryOperator::Create(Instruction::LShr, op1, + ConstantInt::get(OldIntType, bitw / 2)); bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op1); op1_high = new TruncInst(s_op1, NewIntType); bb->getInstList().insert(bb->getTerminator()->getIterator(), op1_high); @@ -387,11 +424,13 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { /* now we have to destinguish between == != and > < */ if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) { + /* transformation for == and != icmps */ /* create a compare for the lower half of the original operands */ Instruction *op0_low, *op1_low, *icmp_low; - BasicBlock* cmp_low_bb = BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + BasicBlock * cmp_low_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); op0_low = new TruncInst(op0, NewIntType); cmp_low_bb->getInstList().push_back(op0_low); @@ -407,21 +446,30 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { * the comparison */ auto term = bb->getTerminator(); if (pred == CmpInst::ICMP_EQ) { + BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb); + } else { + /* CmpInst::ICMP_NE */ BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb); + } + term->eraseFromParent(); /* create the PHI and connect the edges accordingly */ PHINode *PN = PHINode::Create(Int1Ty, 2, ""); PN->addIncoming(icmp_low, cmp_low_bb); if (pred == CmpInst::ICMP_EQ) { + PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb); + } else { + /* CmpInst::ICMP_NE */ PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + } /* replace the old icmp with the new PHI */ @@ -429,19 +477,28 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); } else { + /* CmpInst::ICMP_UGT and CmpInst::ICMP_ULT */ /* transformations for < and > */ - /* create a basic block which checks for the inverse predicate. + /* create a basic block which checks for the inverse predicate. * if this is true we can go to the end if not we have to got to the * bb which checks the lower half of the operands */ Instruction *icmp_inv_cmp, *op0_low, *op1_low, *icmp_low; - BasicBlock* inv_cmp_bb = BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb); + BasicBlock * inv_cmp_bb = + BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb); if (pred == CmpInst::ICMP_UGT) { - icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, op0_high, op1_high); + + icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, + op0_high, op1_high); + } else { - icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, op0_high, op1_high); + + icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, + op0_high, op1_high); + } + inv_cmp_bb->getInstList().push_back(icmp_inv_cmp); auto term = bb->getTerminator(); @@ -449,7 +506,8 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb); /* create a bb which handles the cmp of the lower halves */ - BasicBlock* cmp_low_bb = BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + BasicBlock *cmp_low_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); op0_low = new TruncInst(op0, NewIntType); cmp_low_bb->getInstList().push_back(op0_low); op1_low = new TruncInst(op1, NewIntType); @@ -468,57 +526,64 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { BasicBlock::iterator ii(IcmpInst); ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + } + } - return true; + + return true; + } bool SplitComparesTransform::runOnModule(Module &M) { + int bitw = 64; - char* bitw_env = getenv("LAF_SPLIT_COMPARES_BITW"); - if (!bitw_env) - bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW"); - if (bitw_env) { - bitw = atoi(bitw_env); - } + char *bitw_env = getenv("LAF_SPLIT_COMPARES_BITW"); + if (!bitw_env) bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW"); + if (bitw_env) { bitw = atoi(bitw_env); } simplifyCompares(M); simplifySignedness(M); if (getenv("AFL_QUIET") == NULL) - errs() << "Split-compare-pass by laf.intel@gmail.com\n"; + errs() << "Split-compare-pass by laf.intel@gmail.com\n"; switch (bitw) { + case 64: - errs() << "Running split-compare-pass " << 64 << "\n"; + errs() << "Running split-compare-pass " << 64 << "\n"; splitCompares(M, 64); - [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ + [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ case 32: - errs() << "Running split-compare-pass " << 32 << "\n"; + errs() << "Running split-compare-pass " << 32 << "\n"; splitCompares(M, 32); - [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ + [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ case 16: - errs() << "Running split-compare-pass " << 16 << "\n"; + errs() << "Running split-compare-pass " << 16 << "\n"; splitCompares(M, 16); break; default: - errs() << "NOT Running split-compare-pass \n"; + errs() << "NOT Running split-compare-pass \n"; return false; break; + } verifyModule(M); return true; + } static void registerSplitComparesPass(const PassManagerBuilder &, - legacy::PassManagerBase &PM) { + legacy::PassManagerBase &PM) { + PM.add(new SplitComparesTransform()); + } static RegisterStandardPasses RegisterSplitComparesPass( @@ -526,3 +591,4 @@ static RegisterStandardPasses RegisterSplitComparesPass( static RegisterStandardPasses RegisterSplitComparesTransPass0( PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitComparesPass); + diff --git a/llvm_mode/split-switches-pass.so.cc b/llvm_mode/split-switches-pass.so.cc index 1ace3185..2743a71a 100644 --- a/llvm_mode/split-switches-pass.so.cc +++ b/llvm_mode/split-switches-pass.so.cc @@ -36,54 +36,65 @@ using namespace llvm; namespace { - class SplitSwitchesTransform : public ModulePass { +class SplitSwitchesTransform : public ModulePass { - public: - static char ID; - SplitSwitchesTransform() : ModulePass(ID) { - } + public: + static char ID; + SplitSwitchesTransform() : ModulePass(ID) { - bool runOnModule(Module &M) override; + } + + bool runOnModule(Module &M) override; #if LLVM_VERSION_MAJOR >= 4 - StringRef getPassName() const override { + StringRef getPassName() const override { + #else - const char * getPassName() const override { + const char *getPassName() const override { + #endif - return "splits switch constructs"; - } - struct CaseExpr { - ConstantInt* Val; - BasicBlock* BB; - - CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr) : - Val(val), BB(bb) { } - }; - - typedef std::vector<CaseExpr> CaseVector; - - private: - bool splitSwitches(Module &M); - bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp); - BasicBlock* switchConvert(CaseVector Cases, std::vector<bool> bytesChecked, - BasicBlock* OrigBlock, BasicBlock* NewDefault, - Value* Val, unsigned level); + return "splits switch constructs"; + + } + + struct CaseExpr { + + ConstantInt *Val; + BasicBlock * BB; + + CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr) + : Val(val), BB(bb) { + + } + }; -} + typedef std::vector<CaseExpr> CaseVector; -char SplitSwitchesTransform::ID = 0; + private: + bool splitSwitches(Module &M); + bool transformCmps(Module &M, const bool processStrcmp, + const bool processMemcmp); + BasicBlock *switchConvert(CaseVector Cases, std::vector<bool> bytesChecked, + BasicBlock *OrigBlock, BasicBlock *NewDefault, + Value *Val, unsigned level); + +}; +} // namespace + +char SplitSwitchesTransform::ID = 0; /* switchConvert - Transform simple list of Cases into list of CaseRange's */ -BasicBlock* SplitSwitchesTransform::switchConvert(CaseVector Cases, std::vector<bool> bytesChecked, - BasicBlock* OrigBlock, BasicBlock* NewDefault, - Value* Val, unsigned level) { - - unsigned ValTypeBitWidth = Cases[0].Val->getBitWidth(); - IntegerType *ValType = IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth); - IntegerType *ByteType = IntegerType::get(OrigBlock->getContext(), 8); - unsigned BytesInValue = bytesChecked.size(); +BasicBlock *SplitSwitchesTransform::switchConvert( + CaseVector Cases, std::vector<bool> bytesChecked, BasicBlock *OrigBlock, + BasicBlock *NewDefault, Value *Val, unsigned level) { + + unsigned ValTypeBitWidth = Cases[0].Val->getBitWidth(); + IntegerType *ValType = + IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth); + IntegerType * ByteType = IntegerType::get(OrigBlock->getContext(), 8); + unsigned BytesInValue = bytesChecked.size(); std::vector<uint8_t> setSizes; std::vector<std::set<uint8_t>> byteSets(BytesInValue, std::set<uint8_t>()); @@ -91,43 +102,54 @@ BasicBlock* SplitSwitchesTransform::switchConvert(CaseVector Cases, std::vector< /* for each of the possible cases we iterate over all bytes of the values * build a set of possible values at each byte position in byteSets */ - for (CaseExpr& Case: Cases) { + for (CaseExpr &Case : Cases) { + for (unsigned i = 0; i < BytesInValue; i++) { - uint8_t byte = (Case.Val->getZExtValue() >> (i*8)) & 0xFF; + uint8_t byte = (Case.Val->getZExtValue() >> (i * 8)) & 0xFF; byteSets[i].insert(byte); + } + } /* find the index of the first byte position that was not yet checked. then * save the number of possible values at that byte position */ unsigned smallestIndex = 0; unsigned smallestSize = 257; - for(unsigned i = 0; i < byteSets.size(); i++) { - if (bytesChecked[i]) - continue; + for (unsigned i = 0; i < byteSets.size(); i++) { + + if (bytesChecked[i]) continue; if (byteSets[i].size() < smallestSize) { + smallestIndex = i; smallestSize = byteSets[i].size(); + } + } + assert(bytesChecked[smallestIndex] == false); /* there are only smallestSize different bytes at index smallestIndex */ - + Instruction *Shift, *Trunc; - Function* F = OrigBlock->getParent(); - BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F); - Shift = BinaryOperator::Create(Instruction::LShr, Val, ConstantInt::get(ValType, smallestIndex * 8)); + Function * F = OrigBlock->getParent(); + BasicBlock * NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F); + Shift = BinaryOperator::Create(Instruction::LShr, Val, + ConstantInt::get(ValType, smallestIndex * 8)); NewNode->getInstList().push_back(Shift); if (ValTypeBitWidth > 8) { + Trunc = new TruncInst(Shift, ByteType); NewNode->getInstList().push_back(Trunc); - } - else { + + } else { + /* not necessary to trunc */ Trunc = Shift; + } /* this is a trivial case, we can directly check for the byte, @@ -135,118 +157,155 @@ BasicBlock* SplitSwitchesTransform::switchConvert(CaseVector Cases, std::vector< * mark the byte as checked. if this was the last byte to check * we can finally execute the block belonging to this case */ - if (smallestSize == 1) { + uint8_t byte = *(byteSets[smallestIndex].begin()); - /* insert instructions to check whether the value we are switching on is equal to byte */ - ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte), "byteMatch"); + /* insert instructions to check whether the value we are switching on is + * equal to byte */ + ICmpInst *Comp = + new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte), + "byteMatch"); NewNode->getInstList().push_back(Comp); bytesChecked[smallestIndex] = true; - if (std::all_of(bytesChecked.begin(), bytesChecked.end(), [](bool b){return b;} )) { + if (std::all_of(bytesChecked.begin(), bytesChecked.end(), + [](bool b) { return b; })) { + assert(Cases.size() == 1); BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode); /* we have to update the phi nodes! */ - for (BasicBlock::iterator I = Cases[0].BB->begin(); I != Cases[0].BB->end(); ++I) { - if (!isa<PHINode>(&*I)) { - continue; - } + for (BasicBlock::iterator I = Cases[0].BB->begin(); + I != Cases[0].BB->end(); ++I) { + + if (!isa<PHINode>(&*I)) { continue; } PHINode *PN = cast<PHINode>(I); /* Only update the first occurrence. */ unsigned Idx = 0, E = PN->getNumIncomingValues(); for (; Idx != E; ++Idx) { + if (PN->getIncomingBlock(Idx) == OrigBlock) { + PN->setIncomingBlock(Idx, NewNode); break; + } + } + } - } - else { - BasicBlock* BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, level + 1); + + } else { + + BasicBlock *BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, + Val, level + 1); BranchInst::Create(BB, NewDefault, Comp, NewNode); + } + } + /* there is no byte which we can directly check on, split the tree */ else { std::vector<uint8_t> byteVector; - std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(), std::back_inserter(byteVector)); + std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(), + std::back_inserter(byteVector)); std::sort(byteVector.begin(), byteVector.end()); uint8_t pivot = byteVector[byteVector.size() / 2]; - /* we already chose to divide the cases based on the value of byte at index smallestIndex - * the pivot value determines the threshold for the decicion; if a case value - * is smaller at this byte index move it to the LHS vector, otherwise to the RHS vector */ + /* we already chose to divide the cases based on the value of byte at index + * smallestIndex the pivot value determines the threshold for the decicion; + * if a case value + * is smaller at this byte index move it to the LHS vector, otherwise to the + * RHS vector */ CaseVector LHSCases, RHSCases; - for (CaseExpr& Case: Cases) { - uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex*8)) & 0xFF; + for (CaseExpr &Case : Cases) { + + uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex * 8)) & 0xFF; if (byte < pivot) { + LHSCases.push_back(Case); - } - else { + + } else { + RHSCases.push_back(Case); + } + } - BasicBlock *LBB, *RBB; - LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val, level + 1); - RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val, level + 1); - /* insert instructions to check whether the value we are switching on is equal to byte */ - ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_ULT, Trunc, ConstantInt::get(ByteType, pivot), "byteMatch"); + BasicBlock *LBB, *RBB; + LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val, + level + 1); + RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val, + level + 1); + + /* insert instructions to check whether the value we are switching on is + * equal to byte */ + ICmpInst *Comp = + new ICmpInst(ICmpInst::ICMP_ULT, Trunc, + ConstantInt::get(ByteType, pivot), "byteMatch"); NewNode->getInstList().push_back(Comp); BranchInst::Create(LBB, RBB, Comp, NewNode); } return NewNode; + } bool SplitSwitchesTransform::splitSwitches(Module &M) { - std::vector<SwitchInst*> switches; + std::vector<SwitchInst *> switches; /* iterate over all functions, bbs and instruction and add * all switches to switches vector for later processing */ for (auto &F : M) { + for (auto &BB : F) { - SwitchInst* switchInst = nullptr; + + SwitchInst *switchInst = nullptr; if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) { - if (switchInst->getNumCases() < 1) - continue; - switches.push_back(switchInst); + + if (switchInst->getNumCases() < 1) continue; + switches.push_back(switchInst); + } + } + } - if (!switches.size()) - return false; - errs() << "Rewriting " << switches.size() << " switch statements " << "\n"; + if (!switches.size()) return false; + errs() << "Rewriting " << switches.size() << " switch statements " + << "\n"; - for (auto &SI: switches) { + for (auto &SI : switches) { BasicBlock *CurBlock = SI->getParent(); BasicBlock *OrigBlock = CurBlock; - Function *F = CurBlock->getParent(); + Function * F = CurBlock->getParent(); /* this is the value we are switching on */ - Value *Val = SI->getCondition(); - BasicBlock* Default = SI->getDefaultDest(); - unsigned bitw = Val->getType()->getIntegerBitWidth(); + Value * Val = SI->getCondition(); + BasicBlock *Default = SI->getDefaultDest(); + unsigned bitw = Val->getType()->getIntegerBitWidth(); errs() << "switch: " << SI->getNumCases() << " cases " << bitw << " bit\n"; - /* If there is only the default destination or the condition checks 8 bit or less, don't bother with the code below. */ + /* If there is only the default destination or the condition checks 8 bit or + * less, don't bother with the code below. */ if (!SI->getNumCases() || bitw <= 8) { - if (getenv("AFL_QUIET") == NULL) - errs() << "skip trivial switch..\n"; + + if (getenv("AFL_QUIET") == NULL) errs() << "skip trivial switch..\n"; continue; + } /* Create a new, empty default block so that the new hierarchy of @@ -258,10 +317,10 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) { NewDefault->insertInto(F, Default); BranchInst::Create(Default, NewDefault); - /* Prepare cases vector. */ CaseVector Cases; - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) + for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; + ++i) #if LLVM_VERSION_MAJOR < 5 Cases.push_back(CaseExpr(i.getCaseValue(), i.getCaseSuccessor())); #else @@ -269,8 +328,10 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) { #endif /* bugfix thanks to pbst * round up bytesChecked (in case getBitWidth() % 8 != 0) */ - std::vector<bool> bytesChecked((7 + Cases[0].Val->getBitWidth()) / 8, false); - BasicBlock* SwitchBlock = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0); + std::vector<bool> bytesChecked((7 + Cases[0].Val->getBitWidth()) / 8, + false); + BasicBlock * SwitchBlock = + switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0); /* Branch to our shiny new if-then stuff... */ BranchInst::Create(SwitchBlock, OrigBlock); @@ -278,41 +339,47 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) { /* We are now done with the switch instruction, delete it. */ CurBlock->getInstList().erase(SI); + /* we have to update the phi nodes! */ + for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) { + + if (!isa<PHINode>(&*I)) { continue; } + PHINode *PN = cast<PHINode>(I); + + /* Only update the first occurrence. */ + unsigned Idx = 0, E = PN->getNumIncomingValues(); + for (; Idx != E; ++Idx) { + + if (PN->getIncomingBlock(Idx) == OrigBlock) { + + PN->setIncomingBlock(Idx, NewDefault); + break; + + } + + } + + } + + } + + verifyModule(M); + return true; - /* we have to update the phi nodes! */ - for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) { - if (!isa<PHINode>(&*I)) { - continue; - } - PHINode *PN = cast<PHINode>(I); - - /* Only update the first occurrence. */ - unsigned Idx = 0, E = PN->getNumIncomingValues(); - for (; Idx != E; ++Idx) { - if (PN->getIncomingBlock(Idx) == OrigBlock) { - PN->setIncomingBlock(Idx, NewDefault); - break; - } - } - } - } - - verifyModule(M); - return true; } bool SplitSwitchesTransform::runOnModule(Module &M) { if (getenv("AFL_QUIET") == NULL) - llvm::errs() << "Running split-switches-pass by laf.intel@gmail.com\n"; + llvm::errs() << "Running split-switches-pass by laf.intel@gmail.com\n"; splitSwitches(M); verifyModule(M); return true; + } static void registerSplitSwitchesTransPass(const PassManagerBuilder &, - legacy::PassManagerBase &PM) { + legacy::PassManagerBase &PM) { auto p = new SplitSwitchesTransform(); PM.add(p); @@ -324,3 +391,4 @@ static RegisterStandardPasses RegisterSplitSwitchesTransPass( static RegisterStandardPasses RegisterSplitSwitchesTransPass0( PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitSwitchesTransPass); + |