diff options
Diffstat (limited to 'llvm_mode/split-switches-pass.so.cc')
-rw-r--r-- | llvm_mode/split-switches-pass.so.cc | 294 |
1 files changed, 181 insertions, 113 deletions
diff --git a/llvm_mode/split-switches-pass.so.cc b/llvm_mode/split-switches-pass.so.cc index 1ace3185..2743a71a 100644 --- a/llvm_mode/split-switches-pass.so.cc +++ b/llvm_mode/split-switches-pass.so.cc @@ -36,54 +36,65 @@ using namespace llvm; namespace { - class SplitSwitchesTransform : public ModulePass { +class SplitSwitchesTransform : public ModulePass { - public: - static char ID; - SplitSwitchesTransform() : ModulePass(ID) { - } + public: + static char ID; + SplitSwitchesTransform() : ModulePass(ID) { - bool runOnModule(Module &M) override; + } + + bool runOnModule(Module &M) override; #if LLVM_VERSION_MAJOR >= 4 - StringRef getPassName() const override { + StringRef getPassName() const override { + #else - const char * getPassName() const override { + const char *getPassName() const override { + #endif - return "splits switch constructs"; - } - struct CaseExpr { - ConstantInt* Val; - BasicBlock* BB; - - CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr) : - Val(val), BB(bb) { } - }; - - typedef std::vector<CaseExpr> CaseVector; - - private: - bool splitSwitches(Module &M); - bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp); - BasicBlock* switchConvert(CaseVector Cases, std::vector<bool> bytesChecked, - BasicBlock* OrigBlock, BasicBlock* NewDefault, - Value* Val, unsigned level); + return "splits switch constructs"; + + } + + struct CaseExpr { + + ConstantInt *Val; + BasicBlock * BB; + + CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr) + : Val(val), BB(bb) { + + } + }; -} + typedef std::vector<CaseExpr> CaseVector; -char SplitSwitchesTransform::ID = 0; + private: + bool splitSwitches(Module &M); + bool transformCmps(Module &M, const bool processStrcmp, + const bool processMemcmp); + BasicBlock *switchConvert(CaseVector Cases, std::vector<bool> bytesChecked, + BasicBlock *OrigBlock, BasicBlock *NewDefault, + Value *Val, unsigned level); + +}; +} // namespace + +char SplitSwitchesTransform::ID = 0; /* switchConvert - Transform simple list of Cases into list of CaseRange's */ -BasicBlock* SplitSwitchesTransform::switchConvert(CaseVector Cases, std::vector<bool> bytesChecked, - BasicBlock* OrigBlock, BasicBlock* NewDefault, - Value* Val, unsigned level) { - - unsigned ValTypeBitWidth = Cases[0].Val->getBitWidth(); - IntegerType *ValType = IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth); - IntegerType *ByteType = IntegerType::get(OrigBlock->getContext(), 8); - unsigned BytesInValue = bytesChecked.size(); +BasicBlock *SplitSwitchesTransform::switchConvert( + CaseVector Cases, std::vector<bool> bytesChecked, BasicBlock *OrigBlock, + BasicBlock *NewDefault, Value *Val, unsigned level) { + + unsigned ValTypeBitWidth = Cases[0].Val->getBitWidth(); + IntegerType *ValType = + IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth); + IntegerType * ByteType = IntegerType::get(OrigBlock->getContext(), 8); + unsigned BytesInValue = bytesChecked.size(); std::vector<uint8_t> setSizes; std::vector<std::set<uint8_t>> byteSets(BytesInValue, std::set<uint8_t>()); @@ -91,43 +102,54 @@ BasicBlock* SplitSwitchesTransform::switchConvert(CaseVector Cases, std::vector< /* for each of the possible cases we iterate over all bytes of the values * build a set of possible values at each byte position in byteSets */ - for (CaseExpr& Case: Cases) { + for (CaseExpr &Case : Cases) { + for (unsigned i = 0; i < BytesInValue; i++) { - uint8_t byte = (Case.Val->getZExtValue() >> (i*8)) & 0xFF; + uint8_t byte = (Case.Val->getZExtValue() >> (i * 8)) & 0xFF; byteSets[i].insert(byte); + } + } /* find the index of the first byte position that was not yet checked. then * save the number of possible values at that byte position */ unsigned smallestIndex = 0; unsigned smallestSize = 257; - for(unsigned i = 0; i < byteSets.size(); i++) { - if (bytesChecked[i]) - continue; + for (unsigned i = 0; i < byteSets.size(); i++) { + + if (bytesChecked[i]) continue; if (byteSets[i].size() < smallestSize) { + smallestIndex = i; smallestSize = byteSets[i].size(); + } + } + assert(bytesChecked[smallestIndex] == false); /* there are only smallestSize different bytes at index smallestIndex */ - + Instruction *Shift, *Trunc; - Function* F = OrigBlock->getParent(); - BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F); - Shift = BinaryOperator::Create(Instruction::LShr, Val, ConstantInt::get(ValType, smallestIndex * 8)); + Function * F = OrigBlock->getParent(); + BasicBlock * NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F); + Shift = BinaryOperator::Create(Instruction::LShr, Val, + ConstantInt::get(ValType, smallestIndex * 8)); NewNode->getInstList().push_back(Shift); if (ValTypeBitWidth > 8) { + Trunc = new TruncInst(Shift, ByteType); NewNode->getInstList().push_back(Trunc); - } - else { + + } else { + /* not necessary to trunc */ Trunc = Shift; + } /* this is a trivial case, we can directly check for the byte, @@ -135,118 +157,155 @@ BasicBlock* SplitSwitchesTransform::switchConvert(CaseVector Cases, std::vector< * mark the byte as checked. if this was the last byte to check * we can finally execute the block belonging to this case */ - if (smallestSize == 1) { + uint8_t byte = *(byteSets[smallestIndex].begin()); - /* insert instructions to check whether the value we are switching on is equal to byte */ - ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte), "byteMatch"); + /* insert instructions to check whether the value we are switching on is + * equal to byte */ + ICmpInst *Comp = + new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte), + "byteMatch"); NewNode->getInstList().push_back(Comp); bytesChecked[smallestIndex] = true; - if (std::all_of(bytesChecked.begin(), bytesChecked.end(), [](bool b){return b;} )) { + if (std::all_of(bytesChecked.begin(), bytesChecked.end(), + [](bool b) { return b; })) { + assert(Cases.size() == 1); BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode); /* we have to update the phi nodes! */ - for (BasicBlock::iterator I = Cases[0].BB->begin(); I != Cases[0].BB->end(); ++I) { - if (!isa<PHINode>(&*I)) { - continue; - } + for (BasicBlock::iterator I = Cases[0].BB->begin(); + I != Cases[0].BB->end(); ++I) { + + if (!isa<PHINode>(&*I)) { continue; } PHINode *PN = cast<PHINode>(I); /* Only update the first occurrence. */ unsigned Idx = 0, E = PN->getNumIncomingValues(); for (; Idx != E; ++Idx) { + if (PN->getIncomingBlock(Idx) == OrigBlock) { + PN->setIncomingBlock(Idx, NewNode); break; + } + } + } - } - else { - BasicBlock* BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, level + 1); + + } else { + + BasicBlock *BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, + Val, level + 1); BranchInst::Create(BB, NewDefault, Comp, NewNode); + } + } + /* there is no byte which we can directly check on, split the tree */ else { std::vector<uint8_t> byteVector; - std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(), std::back_inserter(byteVector)); + std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(), + std::back_inserter(byteVector)); std::sort(byteVector.begin(), byteVector.end()); uint8_t pivot = byteVector[byteVector.size() / 2]; - /* we already chose to divide the cases based on the value of byte at index smallestIndex - * the pivot value determines the threshold for the decicion; if a case value - * is smaller at this byte index move it to the LHS vector, otherwise to the RHS vector */ + /* we already chose to divide the cases based on the value of byte at index + * smallestIndex the pivot value determines the threshold for the decicion; + * if a case value + * is smaller at this byte index move it to the LHS vector, otherwise to the + * RHS vector */ CaseVector LHSCases, RHSCases; - for (CaseExpr& Case: Cases) { - uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex*8)) & 0xFF; + for (CaseExpr &Case : Cases) { + + uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex * 8)) & 0xFF; if (byte < pivot) { + LHSCases.push_back(Case); - } - else { + + } else { + RHSCases.push_back(Case); + } + } - BasicBlock *LBB, *RBB; - LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val, level + 1); - RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val, level + 1); - /* insert instructions to check whether the value we are switching on is equal to byte */ - ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_ULT, Trunc, ConstantInt::get(ByteType, pivot), "byteMatch"); + BasicBlock *LBB, *RBB; + LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val, + level + 1); + RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val, + level + 1); + + /* insert instructions to check whether the value we are switching on is + * equal to byte */ + ICmpInst *Comp = + new ICmpInst(ICmpInst::ICMP_ULT, Trunc, + ConstantInt::get(ByteType, pivot), "byteMatch"); NewNode->getInstList().push_back(Comp); BranchInst::Create(LBB, RBB, Comp, NewNode); } return NewNode; + } bool SplitSwitchesTransform::splitSwitches(Module &M) { - std::vector<SwitchInst*> switches; + std::vector<SwitchInst *> switches; /* iterate over all functions, bbs and instruction and add * all switches to switches vector for later processing */ for (auto &F : M) { + for (auto &BB : F) { - SwitchInst* switchInst = nullptr; + + SwitchInst *switchInst = nullptr; if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) { - if (switchInst->getNumCases() < 1) - continue; - switches.push_back(switchInst); + + if (switchInst->getNumCases() < 1) continue; + switches.push_back(switchInst); + } + } + } - if (!switches.size()) - return false; - errs() << "Rewriting " << switches.size() << " switch statements " << "\n"; + if (!switches.size()) return false; + errs() << "Rewriting " << switches.size() << " switch statements " + << "\n"; - for (auto &SI: switches) { + for (auto &SI : switches) { BasicBlock *CurBlock = SI->getParent(); BasicBlock *OrigBlock = CurBlock; - Function *F = CurBlock->getParent(); + Function * F = CurBlock->getParent(); /* this is the value we are switching on */ - Value *Val = SI->getCondition(); - BasicBlock* Default = SI->getDefaultDest(); - unsigned bitw = Val->getType()->getIntegerBitWidth(); + Value * Val = SI->getCondition(); + BasicBlock *Default = SI->getDefaultDest(); + unsigned bitw = Val->getType()->getIntegerBitWidth(); errs() << "switch: " << SI->getNumCases() << " cases " << bitw << " bit\n"; - /* If there is only the default destination or the condition checks 8 bit or less, don't bother with the code below. */ + /* If there is only the default destination or the condition checks 8 bit or + * less, don't bother with the code below. */ if (!SI->getNumCases() || bitw <= 8) { - if (getenv("AFL_QUIET") == NULL) - errs() << "skip trivial switch..\n"; + + if (getenv("AFL_QUIET") == NULL) errs() << "skip trivial switch..\n"; continue; + } /* Create a new, empty default block so that the new hierarchy of @@ -258,10 +317,10 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) { NewDefault->insertInto(F, Default); BranchInst::Create(Default, NewDefault); - /* Prepare cases vector. */ CaseVector Cases; - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) + for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; + ++i) #if LLVM_VERSION_MAJOR < 5 Cases.push_back(CaseExpr(i.getCaseValue(), i.getCaseSuccessor())); #else @@ -269,8 +328,10 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) { #endif /* bugfix thanks to pbst * round up bytesChecked (in case getBitWidth() % 8 != 0) */ - std::vector<bool> bytesChecked((7 + Cases[0].Val->getBitWidth()) / 8, false); - BasicBlock* SwitchBlock = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0); + std::vector<bool> bytesChecked((7 + Cases[0].Val->getBitWidth()) / 8, + false); + BasicBlock * SwitchBlock = + switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0); /* Branch to our shiny new if-then stuff... */ BranchInst::Create(SwitchBlock, OrigBlock); @@ -278,41 +339,47 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) { /* We are now done with the switch instruction, delete it. */ CurBlock->getInstList().erase(SI); + /* we have to update the phi nodes! */ + for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) { + + if (!isa<PHINode>(&*I)) { continue; } + PHINode *PN = cast<PHINode>(I); + + /* Only update the first occurrence. */ + unsigned Idx = 0, E = PN->getNumIncomingValues(); + for (; Idx != E; ++Idx) { + + if (PN->getIncomingBlock(Idx) == OrigBlock) { + + PN->setIncomingBlock(Idx, NewDefault); + break; + + } + + } + + } + + } + + verifyModule(M); + return true; - /* we have to update the phi nodes! */ - for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) { - if (!isa<PHINode>(&*I)) { - continue; - } - PHINode *PN = cast<PHINode>(I); - - /* Only update the first occurrence. */ - unsigned Idx = 0, E = PN->getNumIncomingValues(); - for (; Idx != E; ++Idx) { - if (PN->getIncomingBlock(Idx) == OrigBlock) { - PN->setIncomingBlock(Idx, NewDefault); - break; - } - } - } - } - - verifyModule(M); - return true; } bool SplitSwitchesTransform::runOnModule(Module &M) { if (getenv("AFL_QUIET") == NULL) - llvm::errs() << "Running split-switches-pass by laf.intel@gmail.com\n"; + llvm::errs() << "Running split-switches-pass by laf.intel@gmail.com\n"; splitSwitches(M); verifyModule(M); return true; + } static void registerSplitSwitchesTransPass(const PassManagerBuilder &, - legacy::PassManagerBase &PM) { + legacy::PassManagerBase &PM) { auto p = new SplitSwitchesTransform(); PM.add(p); @@ -324,3 +391,4 @@ static RegisterStandardPasses RegisterSplitSwitchesTransPass( static RegisterStandardPasses RegisterSplitSwitchesTransPass0( PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitSwitchesTransPass); + |