about summary refs log tree commit diff
path: root/llvm_mode/split-switches-pass.so.cc
diff options
context:
space:
mode:
Diffstat (limited to 'llvm_mode/split-switches-pass.so.cc')
-rw-r--r--llvm_mode/split-switches-pass.so.cc315
1 files changed, 315 insertions, 0 deletions
diff --git a/llvm_mode/split-switches-pass.so.cc b/llvm_mode/split-switches-pass.so.cc
new file mode 100644
index 00000000..1341c7f9
--- /dev/null
+++ b/llvm_mode/split-switches-pass.so.cc
@@ -0,0 +1,315 @@
+/*
+ * Copyright 2016 laf-intel
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Pass.h"
+#include "llvm/Analysis/ValueTracking.h"
+
+#include <set>
+
+using namespace llvm;
+
+namespace {
+
+  class SplitSwitchesTransform : public ModulePass {
+
+    public:
+      static char ID;
+      SplitSwitchesTransform() : ModulePass(ID) {
+      } 
+
+      bool runOnModule(Module &M) override;
+
+#if __clang_major__ >= 4
+      StringRef getPassName() const override {
+#else
+      const char * getPassName() const override {
+#endif
+        return "splits switch constructs";
+      }
+      struct CaseExpr {
+        ConstantInt* Val;
+        BasicBlock* BB;
+
+        CaseExpr(ConstantInt *val = nullptr, BasicBlock *bb = nullptr) :
+          Val(val), BB(bb) { }
+      };
+
+    typedef std::vector<CaseExpr> CaseVector;
+
+    private:
+      bool splitSwitches(Module &M);
+      bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp);
+      BasicBlock* switchConvert(CaseVector Cases, std::vector<bool> bytesChecked,
+                                BasicBlock* OrigBlock, BasicBlock* NewDefault,
+                                Value* Val, unsigned level);
+  };
+
+}
+
+char SplitSwitchesTransform::ID = 0;
+
+
+/* switchConvert - Transform simple list of Cases into list of CaseRange's */
+BasicBlock* SplitSwitchesTransform::switchConvert(CaseVector Cases, std::vector<bool> bytesChecked, 
+                                            BasicBlock* OrigBlock, BasicBlock* NewDefault,
+                                            Value* Val, unsigned level) {
+
+  unsigned ValTypeBitWidth = Cases[0].Val->getBitWidth();
+  IntegerType *ValType  = IntegerType::get(OrigBlock->getContext(), ValTypeBitWidth);
+  IntegerType *ByteType = IntegerType::get(OrigBlock->getContext(), 8);
+  unsigned BytesInValue = bytesChecked.size();
+  std::vector<uint8_t> setSizes;
+  std::vector<std::set<uint8_t>> byteSets(BytesInValue, std::set<uint8_t>());
+
+
+  /* for each of the possible cases we iterate over all bytes of the values
+   * build a set of possible values at each byte position in byteSets */
+  for (CaseExpr& Case: Cases) {
+    for (unsigned i = 0; i < BytesInValue; i++) {
+
+      uint8_t byte = (Case.Val->getZExtValue() >> (i*8)) & 0xFF;
+      byteSets[i].insert(byte);
+    }
+  }
+
+  unsigned smallestIndex = 0;
+  unsigned smallestSize = 257;
+  for(unsigned i = 0; i < byteSets.size(); i++) {
+    if (bytesChecked[i])
+      continue;
+    if (byteSets[i].size() < smallestSize) {
+      smallestIndex = i;
+      smallestSize = byteSets[i].size();
+    }
+  }
+  assert(bytesChecked[smallestIndex] == false);
+
+  /* there are only smallestSize different bytes at index smallestIndex */
+ 
+  Instruction *Shift, *Trunc;
+  Function* F = OrigBlock->getParent();
+  BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock", F);
+  Shift = BinaryOperator::Create(Instruction::LShr, Val, ConstantInt::get(ValType, smallestIndex * 8));
+  NewNode->getInstList().push_back(Shift);
+
+  if (ValTypeBitWidth > 8) {
+    Trunc = new TruncInst(Shift, ByteType);
+    NewNode->getInstList().push_back(Trunc);
+  }
+  else {
+    /* not necessary to trunc */
+    Trunc = Shift;
+  }
+
+  /* this is a trivial case, we can directly check for the byte,
+   * if the byte is not found go to default. if the byte was found
+   * mark the byte as checked. if this was the last byte to check
+   * we can finally execute the block belonging to this case */
+
+
+  if (smallestSize == 1) {
+    uint8_t byte = *(byteSets[smallestIndex].begin());
+
+    /* insert instructions to check whether the value we are switching on is equal to byte */
+    ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_EQ, Trunc, ConstantInt::get(ByteType, byte), "byteMatch");
+    NewNode->getInstList().push_back(Comp);
+
+    bytesChecked[smallestIndex] = true;
+    if (std::all_of(bytesChecked.begin(), bytesChecked.end(), [](bool b){return b;} )) {
+      assert(Cases.size() == 1);
+      BranchInst::Create(Cases[0].BB, NewDefault, Comp, NewNode);
+
+      /* we have to update the phi nodes! */
+      for (BasicBlock::iterator I = Cases[0].BB->begin(); I != Cases[0].BB->end(); ++I) {
+        if (!isa<PHINode>(&*I)) {
+          continue;
+        }
+        PHINode *PN = cast<PHINode>(I);
+
+        /* Only update the first occurence. */
+        unsigned Idx = 0, E = PN->getNumIncomingValues();
+        for (; Idx != E; ++Idx) {
+          if (PN->getIncomingBlock(Idx) == OrigBlock) {
+            PN->setIncomingBlock(Idx, NewNode);
+            break;
+          }
+        }
+      }
+    }
+    else {
+      BasicBlock* BB = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, level + 1);
+      BranchInst::Create(BB, NewDefault, Comp, NewNode);
+    }
+  }
+  /* there is no byte which we can directly check on, split the tree */
+  else {
+
+    std::vector<uint8_t> byteVector;
+    std::copy(byteSets[smallestIndex].begin(), byteSets[smallestIndex].end(), std::back_inserter(byteVector));
+    std::sort(byteVector.begin(), byteVector.end());
+    uint8_t pivot = byteVector[byteVector.size() / 2];
+
+    /* we already chose to divide the cases based on the value of byte at index smallestIndex
+     * the pivot value determines the threshold for the decicion; if a case value
+     * is smaller at this byte index move it to the LHS vector, otherwise to the RHS vector */
+
+    CaseVector LHSCases, RHSCases;
+
+    for (CaseExpr& Case: Cases) {
+      uint8_t byte = (Case.Val->getZExtValue() >> (smallestIndex*8)) & 0xFF;
+
+      if (byte < pivot) {
+        LHSCases.push_back(Case);
+      }
+      else {
+        RHSCases.push_back(Case);
+      }
+    }
+    BasicBlock *LBB, *RBB;
+    LBB = switchConvert(LHSCases, bytesChecked, OrigBlock, NewDefault, Val, level + 1);
+    RBB = switchConvert(RHSCases, bytesChecked, OrigBlock, NewDefault, Val, level + 1);
+
+    /* insert instructions to check whether the value we are switching on is equal to byte */
+    ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_ULT, Trunc, ConstantInt::get(ByteType, pivot), "byteMatch");
+    NewNode->getInstList().push_back(Comp);
+    BranchInst::Create(LBB, RBB, Comp, NewNode);
+
+  }
+
+  return NewNode;
+}
+
+bool SplitSwitchesTransform::splitSwitches(Module &M) {
+
+  std::vector<SwitchInst*> switches;
+
+  /* iterate over all functions, bbs and instruction and add
+   * all switches to switches vector for later processing */
+  for (auto &F : M) {
+    for (auto &BB : F) {
+      SwitchInst* switchInst = nullptr;
+
+      if ((switchInst = dyn_cast<SwitchInst>(BB.getTerminator()))) {
+        if (switchInst->getNumCases() < 1)
+            continue;
+          switches.push_back(switchInst);
+      }
+    }
+  }
+
+  if (!switches.size())
+    return false;
+  errs() << "Rewriting " << switches.size() << " switch statements " << "\n";
+
+  for (auto &SI: switches) {
+
+    BasicBlock *CurBlock = SI->getParent();
+    BasicBlock *OrigBlock = CurBlock;
+    Function *F = CurBlock->getParent();
+    /* this is the value we are switching on */
+    Value *Val = SI->getCondition();
+    BasicBlock* Default = SI->getDefaultDest();
+
+    /* If there is only the default destination, don't bother with the code below. */
+    if (!SI->getNumCases()) {
+      continue;
+    }
+
+    /* Create a new, empty default block so that the new hierarchy of
+     * if-then statements go to this and the PHI nodes are happy.
+     * if the default block is set as an unreachable we avoid creating one
+     * because will never be a valid target.*/
+    BasicBlock *NewDefault = nullptr;
+    NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault");
+    NewDefault->insertInto(F, Default);
+    BranchInst::Create(Default, NewDefault);
+
+
+    /* Prepare cases vector. */
+    CaseVector Cases;
+    for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i)
+#if __clang_major__ < 7
+      Cases.push_back(CaseExpr(i.getCaseValue(), i.getCaseSuccessor()));
+#else
+      Cases.push_back(CaseExpr(i->getCaseValue(), i->getCaseSuccessor()));
+#endif
+    std::vector<bool> bytesChecked(Cases[0].Val->getBitWidth() / 8, false);
+    BasicBlock* SwitchBlock = switchConvert(Cases, bytesChecked, OrigBlock, NewDefault, Val, 0);
+
+    /* Branch to our shiny new if-then stuff... */
+    BranchInst::Create(SwitchBlock, OrigBlock);
+
+    /* We are now done with the switch instruction, delete it. */
+    CurBlock->getInstList().erase(SI);
+
+
+   /* we have to update the phi nodes! */
+   for (BasicBlock::iterator I = Default->begin(); I != Default->end(); ++I) {
+     if (!isa<PHINode>(&*I)) {
+      continue;
+     }
+     PHINode *PN = cast<PHINode>(I);
+
+     /* Only update the first occurence. */
+     unsigned Idx = 0, E = PN->getNumIncomingValues();
+     for (; Idx != E; ++Idx) {
+       if (PN->getIncomingBlock(Idx) == OrigBlock) {
+         PN->setIncomingBlock(Idx, NewDefault);
+         break;
+       }
+     }
+   }
+ }
+
+ verifyModule(M);
+ return true;
+}
+
+bool SplitSwitchesTransform::runOnModule(Module &M) {
+
+  llvm::errs() << "Running split-switches-pass by laf.intel@gmail.com\n"; 
+  splitSwitches(M);
+  verifyModule(M);
+
+  return true;
+}
+
+static void registerSplitSwitchesTransPass(const PassManagerBuilder &,
+                            legacy::PassManagerBase &PM) {
+
+  auto p = new SplitSwitchesTransform();
+  PM.add(p);
+
+}
+
+static RegisterStandardPasses RegisterSplitSwitchesTransPass(
+    PassManagerBuilder::EP_OptimizerLast, registerSplitSwitchesTransPass);
+
+static RegisterStandardPasses RegisterSplitSwitchesTransPass0(
+    PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitSwitchesTransPass);