about summary refs log tree commit diff
path: root/llvm_mode/compare-transform-pass.so.cc
diff options
context:
space:
mode:
Diffstat (limited to 'llvm_mode/compare-transform-pass.so.cc')
-rw-r--r--llvm_mode/compare-transform-pass.so.cc306
1 files changed, 306 insertions, 0 deletions
diff --git a/llvm_mode/compare-transform-pass.so.cc b/llvm_mode/compare-transform-pass.so.cc
new file mode 100644
index 00000000..acca3ff0
--- /dev/null
+++ b/llvm_mode/compare-transform-pass.so.cc
@@ -0,0 +1,306 @@
+/*
+ * Copyright 2016 laf-intel
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Pass.h"
+#include "llvm/Analysis/ValueTracking.h"
+
+#include <set>
+
+using namespace llvm;
+
+namespace {
+
+  class CompareTransform : public ModulePass {
+
+    public:
+      static char ID;
+      CompareTransform() : ModulePass(ID) {
+      } 
+
+      bool runOnModule(Module &M) override;
+
+#if __clang_major__ < 4
+      const char * getPassName() const override {
+#else
+      StringRef getPassName() const override {
+#endif
+        return "transforms compare functions";
+      }
+    private:
+      bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
+        ,const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp);
+  };
+}
+
+
+char CompareTransform::ID = 0;
+
+bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
+  , const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) {
+
+  std::vector<CallInst*> calls;
+  LLVMContext &C = M.getContext();
+  IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
+  IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
+  IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
+  Constant* c = M.getOrInsertFunction("tolower",
+                                         Int32Ty,
+                                         Int32Ty
+#if __clang_major__ < 7
+					 , nullptr
+#endif
+					 );
+  Function* tolowerFn = cast<Function>(c);
+
+  /* iterate over all functions, bbs and instruction and add suitable calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
+  for (auto &F : M) {
+    for (auto &BB : F) {
+      for(auto &IN: BB) {
+        CallInst* callInst = nullptr;
+
+        if ((callInst = dyn_cast<CallInst>(&IN))) {
+
+          bool isStrcmp      = processStrcmp;
+          bool isMemcmp      = processMemcmp;
+          bool isStrncmp     = processStrncmp;
+          bool isStrcasecmp  = processStrcasecmp;
+          bool isStrncasecmp = processStrncasecmp;
+
+          Function *Callee = callInst->getCalledFunction();
+          if (!Callee)
+            continue;
+          if (callInst->getCallingConv() != llvm::CallingConv::C)
+            continue;
+          StringRef FuncName = Callee->getName();
+          isStrcmp      &= !FuncName.compare(StringRef("strcmp"));
+          isMemcmp      &= !FuncName.compare(StringRef("memcmp"));
+          isStrncmp     &= !FuncName.compare(StringRef("strncmp"));
+          isStrcasecmp  &= !FuncName.compare(StringRef("strcasecmp"));
+          isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp"));
+
+          if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
+            continue;
+
+          /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function prototype */
+          FunctionType *FT = Callee->getFunctionType();
+
+
+          isStrcmp      &= FT->getNumParams() == 2 &&
+                      FT->getReturnType()->isIntegerTy(32) &&
+                      FT->getParamType(0) == FT->getParamType(1) &&
+                      FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
+          isStrcasecmp  &= FT->getNumParams() == 2 &&
+                      FT->getReturnType()->isIntegerTy(32) &&
+                      FT->getParamType(0) == FT->getParamType(1) &&
+                      FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
+          isMemcmp      &= FT->getNumParams() == 3 &&
+                      FT->getReturnType()->isIntegerTy(32) &&
+                      FT->getParamType(0)->isPointerTy() &&
+                      FT->getParamType(1)->isPointerTy() &&
+                      FT->getParamType(2)->isIntegerTy();
+          isStrncmp     &= FT->getNumParams() == 3 &&
+                      FT->getReturnType()->isIntegerTy(32) &&
+                      FT->getParamType(0) == FT->getParamType(1) &&
+                      FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
+                      FT->getParamType(2)->isIntegerTy();
+          isStrncasecmp &= FT->getNumParams() == 3 &&
+                      FT->getReturnType()->isIntegerTy(32) &&
+                      FT->getParamType(0) == FT->getParamType(1) &&
+                      FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
+                      FT->getParamType(2)->isIntegerTy();
+
+          if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
+            continue;
+
+          /* is a str{n,}{case,}cmp/memcmp, check is we have
+           * str{case,}cmp(x, "const") or str{case,}cmp("const", x)
+           * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
+           * memcmp(x, "const", ..) or memcmp("const", x, ..) */
+          Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
+          StringRef Str1, Str2;
+          bool HasStr1 = getConstantStringInfo(Str1P, Str1);
+          bool HasStr2 = getConstantStringInfo(Str2P, Str2);
+
+          /* handle cases of one string is const, one string is variable */
+          if (!(HasStr1 ^ HasStr2))
+            continue;
+
+          if (isMemcmp || isStrncmp || isStrncasecmp) {
+            /* check if third operand is a constant integer
+             * strlen("constStr") and sizeof() are treated as constant */
+            Value *op2 = callInst->getArgOperand(2);
+            ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
+            if (!ilen)
+              continue;
+            /* final precaution: if size of compare is larger than constant string skip it*/
+            uint64_t literalLength = HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
+            if (literalLength < ilen->getZExtValue())
+              continue;
+          }
+
+          calls.push_back(callInst);
+        }
+      }
+    }
+  }
+
+  if (!calls.size())
+    return false;
+  errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
+
+  for (auto &callInst: calls) {
+
+    Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
+    StringRef Str1, Str2, ConstStr;
+    Value *VarStr;
+    bool HasStr1 = getConstantStringInfo(Str1P, Str1);
+    getConstantStringInfo(Str2P, Str2);
+    uint64_t constLen, sizedLen;
+    bool isMemcmp          = !callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
+    bool isSizedcmp        = isMemcmp
+		          || !callInst->getCalledFunction()->getName().compare(StringRef("strncmp"))
+		          || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
+    bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(StringRef("strcasecmp"))
+		          || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
+
+    if (isSizedcmp) {
+      Value *op2 = callInst->getArgOperand(2);
+      ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
+      sizedLen = ilen->getZExtValue();
+    }
+
+    if (HasStr1) {
+      ConstStr = Str1;
+      VarStr = Str2P;
+      constLen = isMemcmp ? sizedLen : GetStringLength(Str1P);
+    }
+    else {
+      ConstStr = Str2;
+      VarStr = Str1P;
+      constLen = isMemcmp ? sizedLen : GetStringLength(Str2P);
+    }
+    if (isSizedcmp && constLen > sizedLen) {
+      constLen = sizedLen;
+    }
+
+    errs() << callInst->getCalledFunction()->getName() << ": len " << constLen << ": " << ConstStr << "\n";
+
+    /* split before the call instruction */
+    BasicBlock *bb = callInst->getParent();
+    BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
+    BasicBlock *next_bb =  BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
+    BranchInst::Create(end_bb, next_bb);
+    PHINode *PN = PHINode::Create(Int32Ty, constLen + 1, "cmp_phi");
+
+    TerminatorInst *term = bb->getTerminator();
+    BranchInst::Create(next_bb, bb);
+    term->eraseFromParent();
+
+    for (uint64_t i = 0; i < constLen; i++) {
+
+      BasicBlock *cur_bb = next_bb;
+
+      char c = isCaseInsensitive ? tolower(ConstStr[i]) : ConstStr[i];
+
+
+      BasicBlock::iterator IP = next_bb->getFirstInsertionPt();
+      IRBuilder<> IRB(&*IP);
+
+      Value* v = ConstantInt::get(Int64Ty, i);
+      Value *ele  = IRB.CreateInBoundsGEP(VarStr, v, "empty");
+      Value *load = IRB.CreateLoad(ele);
+      if (isCaseInsensitive) {
+        // load >= 'A' && load <= 'Z' ? load | 0x020 : load
+        std::vector<Value *> args;
+        args.push_back(load);
+        load = IRB.CreateCall(tolowerFn, args, "tmp");
+      }
+      Value *isub;
+      if (HasStr1)
+        isub = IRB.CreateSub(ConstantInt::get(Int8Ty, c), load);
+      else
+        isub = IRB.CreateSub(load, ConstantInt::get(Int8Ty, c));
+
+      Value *sext = IRB.CreateSExt(isub, Int32Ty); 
+      PN->addIncoming(sext, cur_bb);
+
+
+      if (i < constLen - 1) {
+        next_bb =  BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
+        BranchInst::Create(end_bb, next_bb);
+
+        TerminatorInst *term = cur_bb->getTerminator();
+        Value *icmp = IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0));
+        IRB.CreateCondBr(icmp, next_bb, end_bb);
+        term->eraseFromParent();
+      } else {
+        //IRB.CreateBr(end_bb);
+      }
+
+      //add offset to varstr
+      //create load
+      //create signed isub
+      //create icmp
+      //create jcc
+      //create next_bb
+    }
+
+    /* since the call is the first instruction of the bb it is safe to
+     * replace it with a phi instruction */
+    BasicBlock::iterator ii(callInst);
+    ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN);
+  }
+
+
+  return true;
+}
+
+bool CompareTransform::runOnModule(Module &M) {
+
+  llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n";
+  transformCmps(M, true, true, true, true, true);
+  verifyModule(M);
+
+  return true;
+}
+
+static void registerCompTransPass(const PassManagerBuilder &,
+                            legacy::PassManagerBase &PM) {
+
+  auto p = new CompareTransform();
+  PM.add(p);
+
+}
+
+static RegisterStandardPasses RegisterCompTransPass(
+    PassManagerBuilder::EP_OptimizerLast, registerCompTransPass);
+
+static RegisterStandardPasses RegisterCompTransPass0(
+    PassManagerBuilder::EP_EnabledOnOptLevel0, registerCompTransPass);
+