about summary refs log tree commit diff
path: root/llvm_mode/split-switches-pass.so.cc
diff options
context:
space:
mode:
Diffstat (limited to 'llvm_mode/split-switches-pass.so.cc')
-rw-r--r-- llvm_mode/split-switches-pass.so.cc 294
1 files changed, 181 insertions, 113 deletions
diff --git a/llvm_mode/split-switches-pass.so.cc b/llvm_mode/split-switches-pass.so.cc
index 1ace3185..2743a71a 100644
--- a/llvm_mode/split-switches-pass.so.cc
+++ b/llvm_mode/split-switches-pass.so.cc
@@ -36,54 +36,65 @@ using namespace llvm;
 
 namespace {
 
-  class SplitSwitchesTransform : public ModulePass {
+class SplitSwitchesTransform : public ModulePass {
 
-    public:
-      static char ID;
-      SplitSwitchesTransform() : ModulePass(ID) {
-      } 
+ public:
+  static char ID;
+  SplitSwitchesTransform() : ModulePass(ID) {
 
-      bool runOnModule(Module &M) override;
+  }
+
+  bool runOnModule(Module &M) override;
 
 #if LLVM_VERSION_MAJOR >= 4
-      StringRef getPassName() const override {
+  StringRef getPassName() const override {
+
 #else
-      const char * getPassName() const override {
+  const char *getPassName() const override {
+
 #endif
-        return "splits switch constructs";
-      }
-      struct CaseExpr {
-        ConstantInt* Val;
-        BasicBlock* BB;
-
-        CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr) :
-          Val(val), BB(bb) { }
-      };
-
-    typedef std::vector<CaseExpr> CaseVector;
-
-    private:
-      bool splitSwitches(Module &M);
-      bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp);
-      BasicBlock* switchConvert(CaseVector Cases, std::vector<bool> bytesChecked,
-                                BasicBlock* OrigBlock, BasicBlock* NewDefault,
-                                Value* Val, unsigned level);
+    return "splits switch constructs";
+
+  }
+
+  struct CaseExpr {
+
+    ConstantInt *Val;
+    BasicBlock * BB;
+
+    CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr)
+        : Val(val), BB(bb) {
+
+    }
+
   };
 
-}
+  typedef std::vector<CaseExpr> CaseVector;
 
-char SplitSwitchesTransform::ID = 0;
+ private:
+  bool        splitSwitches(Module &M);
+  bool        transformCmps(Module &M, const bool processStrcmp,
+                            const bool processMemcmp);
+  BasicBlock *switchConvert(CaseVector Cases, std::vector<bool> bytesChecked,
+                            BasicBlock *OrigBlock, BasicBlock *NewDefault,
+                            Value *Val, unsigned level);
+
+};
 
+}  // namespace
+
+char SplitSwitchesTransform::ID = 0;
 
 /* switchConvert - Transform simple list of Cases into list of CaseRange's */
-BasicBlock* SplitSwitchesTransform::switchConvert(CaseVector Cases, std::vector<bool> bytesChecked, 
-                                            BasicBlock* OrigBlock, BasicBlock* NewDefault,
-                                            Value* Val, unsigned level) {
-
-  unsigned ValTypeBitWidth = Cases[0].Val->getBitWidth();
-  IntegerType *ValType  = IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth);
-  IntegerType *ByteType = IntegerType::get(OrigBlock->getContext(), 8);
-  unsigned BytesInValue = bytesChecked.size();
+BasicBlock *SplitSwitchesTransform::switchConvert(
+    CaseVector Cases, std::vector<bool> bytesChecked, BasicBlock *OrigBlock,
+    BasicBlock *NewDefault, Value *Val, unsigned level) {
+
+  unsigned     ValTypeBitWidth = Cases[0].Val->getBitWidth();
+  IntegerType *ValType =
+      IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth);
+  IntegerType *        ByteType = IntegerType::get(OrigBlock->getContext(), 8);
+  unsigned             BytesInValue = bytesChecked.size();
   std::vector<uint8_t> setSizes;
   std::vector<std::set<uint8_t>> byteSets(BytesInValue, std::set<uint8_t>());
 
@@ -91,43 +102,54 @@ BasicBlock* SplitSwitchesTransform::switchConvert(CaseVector Cases, std::vector<
 
   /* for each of the possible cases we iterate over all bytes of the values
    * build a set of possible values at each byte position in byteSets */
-  for (CaseExpr& Case: Cases) {
+  for (CaseExpr &Case : Cases) {
+
     for (unsigned i = 0; i < BytesInValue; i++) {
 
-      uint8_t byte = (Case.Val->getZExtValue() >> (i*8)) & 0xFF;
+      uint8_t byte = (Case.Val->getZExtValue() >> (i * 8)) & 0xFF;
       byteSets[i].insert(byte);
+
     }
+
   }
 
   /* find the index of the first byte position that was not yet checked. then
    * save the number of possible values at that byte position */
   unsigned smallestIndex = 0;
   unsigned smallestSize = 257;
-  for(unsigned i = 0; i < byteSets.size(); i++) {
-    if (bytesChecked[i])
-      continue;
+  for (unsigned i = 0; i < byteSets.size(); i++) {
+
+    if (bytesChecked[i]) continue;
     if (byteSets[i].size() < smallestSize) {
+
       smallestIndex = i;
       smallestSize = byteSets[i].size();
+
     }
+
   }
+
   assert(bytesChecked[smallestIndex] == false);
 
   /* there are only smallestSize different bytes at index smallestIndex */
- 
+
   Instruction *Shift, *Trunc;
-  Function* F = OrigBlock->getParent();
-  BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F);
-  Shift = BinaryOperator::Create(Instruction::LShr, Val, ConstantInt::get(ValType, smallestIndex * 8));
+  Function *   F = OrigBlock->getParent();
+  BasicBlock * NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F);
+  Shift = BinaryOperator::Create(Instruction::LShr, Val,
+                                 ConstantInt::get(ValType, smallestIndex * 8));
   NewNode->getInstList().push_back(Shift);
 
   if (ValTypeBitWidth > 8) {
+
     Trunc = new TruncInst(Shift, ByteType);
     NewNode->getInstList().push_back(Trunc);
-  }
-  else {
+
+  } else {
+
     /* not necessary to trunc */
     Trunc = Shift;
+
   }
 
   /* this is a trivial case, we can directly check for the byte,
@@ -135,118 +157,155 @@ BasicBlock* SplitSwitchesTransform::switchConvert(CaseVector Cases, std::vector<
    * mark the byte as checked. if this was the last byte to check
    * we can finally execute the block belonging to this case */
 
-
   if (smallestSize == 1) {
+
     uint8_t byte = *(byteSets[smallestIndex].begin());
 
-    /* insert instructions to check whether the value we are switching on is equal to byte */
-    ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte), "byteMatch");
+    /* insert instructions to check whether the value we are switching on is
+     * equal to byte */
+    ICmpInst *Comp =
+        new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte),
+                     "byteMatch");
     NewNode->getInstList().push_back(Comp);
 
     bytesChecked[smallestIndex] = true;
-    if (std::all_of(bytesChecked.begin(), bytesChecked.end(), [](bool b){return b;} )) {
+    if (std::all_of(bytesChecked.begin(), bytesChecked.end(),
+                    [](bool b) { return b; })) {
+
       assert(Cases.size() == 1);
       BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode);
 
       /* we have to update the phi nodes! */
-      for (BasicBlock::iterator I = Cases[0].BB->begin(); I != Cases[0].BB->end(); ++I) {
-        if (!isa<PHINode>(&*I)) {
-          continue;
-        }
+      for (BasicBlock::iterator I = Cases[0].BB->begin();
+           I != Cases[0].BB->end(); ++I) {
+
+        if (!isa<PHINode>(&*I)) { continue; }
         PHINode *PN = cast<PHINode>(I);
 
         /* Only update the first occurrence. */
         unsigned Idx = 0, E = PN->getNumIncomingValues();
         for (; Idx != E; ++Idx) {
+
           if (PN->getIncomingBlock(Idx) == OrigBlock) {
+
             PN->setIncomingBlock(Idx, NewNode);
             break;
+
           }
+
         }
+
       }
-    }
-    else {
-      BasicBlock* BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, level + 1);
+
+    } else {
+
+      BasicBlock *BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault,
+                                     Val, level + 1);
       BranchInst::Create(BB, NewDefault, Comp, NewNode);
+
     }
+
   }
+
   /* there is no byte which we can directly check on, split the tree */
   else {
 
     std::vector<uint8_t> byteVector;
-    std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(), std::back_inserter(byteVector));
+    std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(),
+              std::back_inserter(byteVector));
     std::sort(byteVector.begin(), byteVector.end());
     uint8_t pivot = byteVector[byteVector.size() / 2];
 
-    /* we already chose to divide the cases based on the value of byte at index smallestIndex
-     * the pivot value determines the threshold for the decicion; if a case value
-     * is smaller at this byte index move it to the LHS vector, otherwise to the RHS vector */
+    /* we already chose to divide the cases based on the value of the byte at
+     * index smallestIndex. the pivot value determines the threshold for the
+     * decision: if a case value is smaller than the pivot at this byte index,
+     * move it to the LHS vector, otherwise move it to the
+     * RHS vector */
 
     CaseVector LHSCases, RHSCases;
 
-    for (CaseExpr& Case: Cases) {
-      uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex*8)) & 0xFF;
+    for (CaseExpr &Case : Cases) {
+
+      uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex * 8)) & 0xFF;
 
       if (byte < pivot) {
+
         LHSCases.push_back(Case);
-      }
-      else {
+
+      } else {
+
         RHSCases.push_back(Case);
+
       }
+
     }
-    BasicBlock *LBB, *RBB;
-    LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val, level + 1);
-    RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val, level + 1);
 
-    /* insert instructions to check whether the value we are switching on is equal to byte */
-    ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_ULT, Trunc, ConstantInt::get(ByteType, pivot), "byteMatch");
+    BasicBlock *LBB, *RBB;
+    LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val,
+                        level + 1);
+    RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val,
+                        level + 1);
+
+    /* insert instructions to check whether the selected byte of the value we
+     * are switching on is less than the pivot */
+    ICmpInst *Comp =
+        new ICmpInst(ICmpInst::ICMP_ULT, Trunc,
+                     ConstantInt::get(ByteType, pivot), "byteMatch");
     NewNode->getInstList().push_back(Comp);
     BranchInst::Create(LBB, RBB, Comp, NewNode);
 
   }
 
   return NewNode;
+
 }
 
 bool SplitSwitchesTransform::splitSwitches(Module &M) {
 
-  std::vector<SwitchInst*> switches;
+  std::vector<SwitchInst *> switches;
 
   /* iterate over all functions, bbs and instruction and add
    * all switches to switches vector for later processing */
   for (auto &F : M) {
+
     for (auto &BB : F) {
-      SwitchInst* switchInst = nullptr;
+
+      SwitchInst *switchInst = nullptr;
 
       if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) {
-        if (switchInst->getNumCases() < 1)
-            continue;
-          switches.push_back(switchInst);
+
+        if (switchInst->getNumCases() < 1) continue;
+        switches.push_back(switchInst);
+
       }
+
     }
+
   }
 
-  if (!switches.size())
-    return false;
-  errs() << "Rewriting " << switches.size() << " switch statements " << "\n";
+  if (!switches.size()) return false;
+  errs() << "Rewriting " << switches.size() << " switch statements "
+         << "\n";
 
-  for (auto &SI: switches) {
+  for (auto &SI : switches) {
 
     BasicBlock *CurBlock = SI->getParent();
     BasicBlock *OrigBlock = CurBlock;
-    Function *F = CurBlock->getParent();
+    Function *  F = CurBlock->getParent();
     /* this is the value we are switching on */
-    Value *Val = SI->getCondition();
-    BasicBlock* Default = SI->getDefaultDest();
-    unsigned bitw = Val->getType()->getIntegerBitWidth();
+    Value *     Val = SI->getCondition();
+    BasicBlock *Default = SI->getDefaultDest();
+    unsigned    bitw = Val->getType()->getIntegerBitWidth();
 
     errs() << "switch: " << SI->getNumCases() << " cases " << bitw << " bit\n";
 
-    /* If there is only the default destination or the condition checks 8 bit or less, don't bother with the code below. */
+    /* If there is only the default destination or the condition checks 8 bit or
+     * less, don't bother with the code below. */
     if (!SI->getNumCases() || bitw <= 8) {
-      if (getenv("AFL_QUIET") == NULL)
-        errs() << "skip trivial switch..\n";
+
+      if (getenv("AFL_QUIET") == NULL) errs() << "skip trivial switch..\n";
       continue;
+
     }
 
     /* Create a new, empty default block so that the new hierarchy of
@@ -258,10 +317,10 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) {
     NewDefault->insertInto(F, Default);
     BranchInst::Create(Default, NewDefault);
 
-
     /* Prepare cases vector. */
     CaseVector Cases;
-    for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i)
+    for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e;
+         ++i)
 #if LLVM_VERSION_MAJOR < 5
       Cases.push_back(CaseExpr(i.getCaseValue(), i.getCaseSuccessor()));
 #else
@@ -269,8 +328,10 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) {
 #endif
     /* bugfix thanks to pbst
      * round up bytesChecked (in case getBitWidth() % 8 != 0) */
-    std::vector<bool> bytesChecked((7 + Cases[0].Val->getBitWidth()) / 8, false);
-    BasicBlock* SwitchBlock = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0);
+    std::vector<bool> bytesChecked((7 + Cases[0].Val->getBitWidth()) / 8,
+                                   false);
+    BasicBlock *      SwitchBlock =
+        switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0);
 
     /* Branch to our shiny new if-then stuff... */
     BranchInst::Create(SwitchBlock, OrigBlock);
@@ -278,41 +339,47 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) {
     /* We are now done with the switch instruction, delete it. */
     CurBlock->getInstList().erase(SI);
 
+    /* we have to update the phi nodes! */
+    for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) {
+
+      if (!isa<PHINode>(&*I)) { continue; }
+      PHINode *PN = cast<PHINode>(I);
+
+      /* Only update the first occurrence. */
+      unsigned Idx = 0, E = PN->getNumIncomingValues();
+      for (; Idx != E; ++Idx) {
+
+        if (PN->getIncomingBlock(Idx) == OrigBlock) {
+
+          PN->setIncomingBlock(Idx, NewDefault);
+          break;
+
+        }
+
+      }
+
+    }
+
+  }
+
+  verifyModule(M);
+  return true;
 
-   /* we have to update the phi nodes! */
-   for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) {
-     if (!isa<PHINode>(&*I)) {
-      continue;
-     }
-     PHINode *PN = cast<PHINode>(I);
-
-     /* Only update the first occurrence. */
-     unsigned Idx = 0, E = PN->getNumIncomingValues();
-     for (; Idx != E; ++Idx) {
-       if (PN->getIncomingBlock(Idx) == OrigBlock) {
-         PN->setIncomingBlock(Idx, NewDefault);
-         break;
-       }
-     }
-   }
- }
-
- verifyModule(M);
- return true;
 }
 
 bool SplitSwitchesTransform::runOnModule(Module &M) {
 
   if (getenv("AFL_QUIET") == NULL)
-    llvm::errs() << "Running split-switches-pass by laf.intel@gmail.com\n"; 
+    llvm::errs() << "Running split-switches-pass by laf.intel@gmail.com\n";
   splitSwitches(M);
   verifyModule(M);
 
   return true;
+
 }
 
 static void registerSplitSwitchesTransPass(const PassManagerBuilder &,
-                            legacy::PassManagerBase &PM) {
+                                           legacy::PassManagerBase &PM) {
 
   auto p = new SplitSwitchesTransform();
   PM.add(p);
@@ -324,3 +391,4 @@ static RegisterStandardPasses RegisterSplitSwitchesTransPass(
 
 static RegisterStandardPasses RegisterSplitSwitchesTransPass0(
     PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitSwitchesTransPass);
+