about summary refs log tree commit diff
path: root/llvm_mode/compare-transform-pass.so.cc
diff options
context:
space:
mode:
Diffstat (limited to 'llvm_mode/compare-transform-pass.so.cc')
-rw-r--r--llvm_mode/compare-transform-pass.so.cc286
1 files changed, 165 insertions, 121 deletions
diff --git a/llvm_mode/compare-transform-pass.so.cc b/llvm_mode/compare-transform-pass.so.cc
index e7886db1..e1b6e671 100644
--- a/llvm_mode/compare-transform-pass.so.cc
+++ b/llvm_mode/compare-transform-pass.so.cc
@@ -36,202 +36,236 @@ using namespace llvm;
 
 namespace {
 
-  class CompareTransform : public ModulePass {
+class CompareTransform : public ModulePass {
 
-    public:
-      static char ID;
-      CompareTransform() : ModulePass(ID) {
-      } 
+ public:
+  static char ID;
+  CompareTransform() : ModulePass(ID) {
 
-      bool runOnModule(Module &M) override;
+  }
+
+  bool runOnModule(Module &M) override;
 
 #if LLVM_VERSION_MAJOR < 4
-      const char * getPassName() const override {
+  const char *getPassName() const override {
+
 #else
-      StringRef getPassName() const override {
+  StringRef getPassName() const override {
+
 #endif
-        return "transforms compare functions";
-      }
-    private:
-      bool transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
-        ,const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp);
-  };
-}
+    return "transforms compare functions";
 
+  }
+
+ private:
+  bool transformCmps(Module &M, const bool processStrcmp,
+                     const bool processMemcmp, const bool processStrncmp,
+                     const bool processStrcasecmp,
+                     const bool processStrncasecmp);
+
+};
+
+}  // namespace
 
 char CompareTransform::ID = 0;
 
-bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const bool processMemcmp
-  , const bool processStrncmp, const bool processStrcasecmp, const bool processStrncasecmp) {
+bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
+                                     const bool processMemcmp,
+                                     const bool processStrncmp,
+                                     const bool processStrcasecmp,
+                                     const bool processStrncasecmp) {
 
-  std::vector<CallInst*> calls;
-  LLVMContext &C = M.getContext();
-  IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
-  IntegerType *Int32Ty = IntegerType::getInt32Ty(C);
-  IntegerType *Int64Ty = IntegerType::getInt64Ty(C);
+  std::vector<CallInst *> calls;
+  LLVMContext &           C = M.getContext();
+  IntegerType *           Int8Ty = IntegerType::getInt8Ty(C);
+  IntegerType *           Int32Ty = IntegerType::getInt32Ty(C);
+  IntegerType *           Int64Ty = IntegerType::getInt64Ty(C);
 
 #if LLVM_VERSION_MAJOR < 9
-  Constant* 
+  Constant *
 #else
   FunctionCallee
 #endif
-               c = M.getOrInsertFunction("tolower",
-                                         Int32Ty,
-                                         Int32Ty
+      c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty
 #if LLVM_VERSION_MAJOR < 5
-					 , nullptr
+                                ,
+                                nullptr
 #endif
-					 );
+      );
 #if LLVM_VERSION_MAJOR < 9
-  Function* tolowerFn = cast<Function>(c);
+  Function *tolowerFn = cast<Function>(c);
 #else
   FunctionCallee tolowerFn = c;
 #endif
 
-  /* iterate over all functions, bbs and instruction and add suitable calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
+  /* iterate over all functions, bbs and instruction and add suitable calls to
+   * strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
   for (auto &F : M) {
+
     for (auto &BB : F) {
-      for(auto &IN: BB) {
-        CallInst* callInst = nullptr;
+
+      for (auto &IN : BB) {
+
+        CallInst *callInst = nullptr;
 
         if ((callInst = dyn_cast<CallInst>(&IN))) {
 
-          bool isStrcmp      = processStrcmp;
-          bool isMemcmp      = processMemcmp;
-          bool isStrncmp     = processStrncmp;
-          bool isStrcasecmp  = processStrcasecmp;
+          bool isStrcmp = processStrcmp;
+          bool isMemcmp = processMemcmp;
+          bool isStrncmp = processStrncmp;
+          bool isStrcasecmp = processStrcasecmp;
           bool isStrncasecmp = processStrncasecmp;
 
           Function *Callee = callInst->getCalledFunction();
-          if (!Callee)
-            continue;
-          if (callInst->getCallingConv() != llvm::CallingConv::C)
-            continue;
+          if (!Callee) continue;
+          if (callInst->getCallingConv() != llvm::CallingConv::C) continue;
           StringRef FuncName = Callee->getName();
-          isStrcmp      &= !FuncName.compare(StringRef("strcmp"));
-          isMemcmp      &= !FuncName.compare(StringRef("memcmp"));
-          isStrncmp     &= !FuncName.compare(StringRef("strncmp"));
-          isStrcasecmp  &= !FuncName.compare(StringRef("strcasecmp"));
+          isStrcmp &= !FuncName.compare(StringRef("strcmp"));
+          isMemcmp &= !FuncName.compare(StringRef("memcmp"));
+          isStrncmp &= !FuncName.compare(StringRef("strncmp"));
+          isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp"));
           isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp"));
 
-          if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
+          if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
+              !isStrncasecmp)
             continue;
 
-          /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function prototype */
+          /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function
+           * prototype */
           FunctionType *FT = Callee->getFunctionType();
 
-
-          isStrcmp      &= FT->getNumParams() == 2 &&
-                      FT->getReturnType()->isIntegerTy(32) &&
-                      FT->getParamType(0) == FT->getParamType(1) &&
-                      FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
-          isStrcasecmp  &= FT->getNumParams() == 2 &&
-                      FT->getReturnType()->isIntegerTy(32) &&
-                      FT->getParamType(0) == FT->getParamType(1) &&
-                      FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
-          isMemcmp      &= FT->getNumParams() == 3 &&
+          isStrcmp &=
+              FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
+              FT->getParamType(0) == FT->getParamType(1) &&
+              FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
+          isStrcasecmp &=
+              FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
+              FT->getParamType(0) == FT->getParamType(1) &&
+              FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
+          isMemcmp &= FT->getNumParams() == 3 &&
                       FT->getReturnType()->isIntegerTy(32) &&
                       FT->getParamType(0)->isPointerTy() &&
                       FT->getParamType(1)->isPointerTy() &&
                       FT->getParamType(2)->isIntegerTy();
-          isStrncmp     &= FT->getNumParams() == 3 &&
-                      FT->getReturnType()->isIntegerTy(32) &&
-                      FT->getParamType(0) == FT->getParamType(1) &&
-                      FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
-                      FT->getParamType(2)->isIntegerTy();
+          isStrncmp &= FT->getNumParams() == 3 &&
+                       FT->getReturnType()->isIntegerTy(32) &&
+                       FT->getParamType(0) == FT->getParamType(1) &&
+                       FT->getParamType(0) ==
+                           IntegerType::getInt8PtrTy(M.getContext()) &&
+                       FT->getParamType(2)->isIntegerTy();
           isStrncasecmp &= FT->getNumParams() == 3 &&
-                      FT->getReturnType()->isIntegerTy(32) &&
-                      FT->getParamType(0) == FT->getParamType(1) &&
-                      FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext()) &&
-                      FT->getParamType(2)->isIntegerTy();
-
-          if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp && !isStrncasecmp)
+                           FT->getReturnType()->isIntegerTy(32) &&
+                           FT->getParamType(0) == FT->getParamType(1) &&
+                           FT->getParamType(0) ==
+                               IntegerType::getInt8PtrTy(M.getContext()) &&
+                           FT->getParamType(2)->isIntegerTy();
+
+          if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
+              !isStrncasecmp)
             continue;
 
           /* is a str{n,}{case,}cmp/memcmp, check if we have
            * str{case,}cmp(x, "const") or str{case,}cmp("const", x)
            * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
            * memcmp(x, "const", ..) or memcmp("const", x, ..) */
-          Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
+          Value *Str1P = callInst->getArgOperand(0),
+                *Str2P = callInst->getArgOperand(1);
           StringRef Str1, Str2;
-          bool HasStr1 = getConstantStringInfo(Str1P, Str1);
-          bool HasStr2 = getConstantStringInfo(Str2P, Str2);
+          bool      HasStr1 = getConstantStringInfo(Str1P, Str1);
+          bool      HasStr2 = getConstantStringInfo(Str2P, Str2);
 
           /* handle cases of one string is const, one string is variable */
-          if (!(HasStr1 ^ HasStr2))
-            continue;
+          if (!(HasStr1 ^ HasStr2)) continue;
 
           if (isMemcmp || isStrncmp || isStrncasecmp) {
+
             /* check if third operand is a constant integer
              * strlen("constStr") and sizeof() are treated as constant */
-            Value *op2 = callInst->getArgOperand(2);
-            ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
-            if (!ilen)
-              continue;
-            /* final precaution: if size of compare is larger than constant string skip it*/
-            uint64_t literalLength = HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
-            if (literalLength < ilen->getZExtValue())
-              continue;
+            Value *      op2 = callInst->getArgOperand(2);
+            ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
+            if (!ilen) continue;
+            /* final precaution: if size of compare is larger than constant
+             * string skip it*/
+            uint64_t literalLength =
+                HasStr1 ? GetStringLength(Str1P) : GetStringLength(Str2P);
+            if (literalLength < ilen->getZExtValue()) continue;
+
           }
 
           calls.push_back(callInst);
+
         }
+
       }
+
     }
+
   }
 
-  if (!calls.size())
-    return false;
-  errs() << "Replacing " << calls.size() << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
+  if (!calls.size()) return false;
+  errs() << "Replacing " << calls.size()
+         << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
 
-  for (auto &callInst: calls) {
+  for (auto &callInst : calls) {
 
-    Value *Str1P = callInst->getArgOperand(0), *Str2P = callInst->getArgOperand(1);
-    StringRef Str1, Str2, ConstStr;
+    Value *Str1P = callInst->getArgOperand(0),
+          *Str2P = callInst->getArgOperand(1);
+    StringRef   Str1, Str2, ConstStr;
     std::string TmpConstStr;
-    Value *VarStr;
-    bool HasStr1 = getConstantStringInfo(Str1P, Str1);
+    Value *     VarStr;
+    bool        HasStr1 = getConstantStringInfo(Str1P, Str1);
     getConstantStringInfo(Str2P, Str2);
     uint64_t constLen, sizedLen;
-    bool isMemcmp          = !callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
-    bool isSizedcmp        = isMemcmp
-		          || !callInst->getCalledFunction()->getName().compare(StringRef("strncmp"))
-		          || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
-    bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(StringRef("strcasecmp"))
-		          || !callInst->getCalledFunction()->getName().compare(StringRef("strncasecmp"));
+    bool     isMemcmp =
+        !callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
+    bool isSizedcmp = isMemcmp ||
+                      !callInst->getCalledFunction()->getName().compare(
+                          StringRef("strncmp")) ||
+                      !callInst->getCalledFunction()->getName().compare(
+                          StringRef("strncasecmp"));
+    bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(
+                                 StringRef("strcasecmp")) ||
+                             !callInst->getCalledFunction()->getName().compare(
+                                 StringRef("strncasecmp"));
 
     if (isSizedcmp) {
-      Value *op2 = callInst->getArgOperand(2);
-      ConstantInt* ilen = dyn_cast<ConstantInt>(op2);
+
+      Value *      op2 = callInst->getArgOperand(2);
+      ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
       sizedLen = ilen->getZExtValue();
+
     }
 
     if (HasStr1) {
+
       TmpConstStr = Str1.str();
       VarStr = Str2P;
       constLen = isMemcmp ? sizedLen : GetStringLength(Str1P);
-    }
-    else {
+
+    } else {
+
       TmpConstStr = Str2.str();
       VarStr = Str1P;
       constLen = isMemcmp ? sizedLen : GetStringLength(Str2P);
+
     }
 
     /* properly handle zero terminated C strings by adding the terminating 0 to
      * the StringRef (in comparison to std::string a StringRef has built-in
      * runtime bounds checking, which makes debugging easier) */
-    TmpConstStr.append("\0", 1); ConstStr = StringRef(TmpConstStr);
+    TmpConstStr.append("\0", 1);
+    ConstStr = StringRef(TmpConstStr);
 
-    if (isSizedcmp && constLen > sizedLen) {
-      constLen = sizedLen;
-    }
+    if (isSizedcmp && constLen > sizedLen) { constLen = sizedLen; }
 
-    errs() << callInst->getCalledFunction()->getName() << ": len " << constLen << ": " << ConstStr << "\n";
+    errs() << callInst->getCalledFunction()->getName() << ": len " << constLen
+           << ": " << ConstStr << "\n";
 
     /* split before the call instruction */
     BasicBlock *bb = callInst->getParent();
     BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
-    BasicBlock *next_bb =  BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
+    BasicBlock *next_bb =
+        BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
     BranchInst::Create(end_bb, next_bb);
     PHINode *PN = PHINode::Create(Int32Ty, constLen + 1, "cmp_phi");
 
@@ -249,71 +283,81 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp, const
 
       char c = isCaseInsensitive ? tolower(ConstStr[i]) : ConstStr[i];
 
-
       BasicBlock::iterator IP = next_bb->getFirstInsertionPt();
-      IRBuilder<> IRB(&*IP);
+      IRBuilder<>          IRB(&*IP);
 
-      Value* v = ConstantInt::get(Int64Ty, i);
-      Value *ele  = IRB.CreateInBoundsGEP(VarStr, v, "empty");
+      Value *v = ConstantInt::get(Int64Ty, i);
+      Value *ele = IRB.CreateInBoundsGEP(VarStr, v, "empty");
       Value *load = IRB.CreateLoad(ele);
       if (isCaseInsensitive) {
+
         // load >= 'A' && load <= 'Z' ? load | 0x020 : load
         std::vector<Value *> args;
         args.push_back(load);
         load = IRB.CreateCall(tolowerFn, args, "tmp");
         load = IRB.CreateTrunc(load, Int8Ty);
+
       }
+
       Value *isub;
       if (HasStr1)
         isub = IRB.CreateSub(ConstantInt::get(Int8Ty, c), load);
       else
         isub = IRB.CreateSub(load, ConstantInt::get(Int8Ty, c));
 
-      Value *sext = IRB.CreateSExt(isub, Int32Ty); 
+      Value *sext = IRB.CreateSExt(isub, Int32Ty);
       PN->addIncoming(sext, cur_bb);
 
-
       if (i < constLen - 1) {
-        next_bb =  BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
+
+        next_bb =
+            BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
         BranchInst::Create(end_bb, next_bb);
 
         Value *icmp = IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0));
         IRB.CreateCondBr(icmp, next_bb, end_bb);
         cur_bb->getTerminator()->eraseFromParent();
+
       } else {
-        //IRB.CreateBr(end_bb);
+
+        // IRB.CreateBr(end_bb);
+
       }
 
-      //add offset to varstr
-      //create load
-      //create signed isub
-      //create icmp
-      //create jcc
-      //create next_bb
+      // add offset to varstr
+      // create load
+      // create signed isub
+      // create icmp
+      // create jcc
+      // create next_bb
+
     }
 
     /* since the call is the first instruction of the bb it is safe to
      * replace it with a phi instruction */
     BasicBlock::iterator ii(callInst);
     ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN);
-  }
 
+  }
 
   return true;
+
 }
 
 bool CompareTransform::runOnModule(Module &M) {
 
   if (getenv("AFL_QUIET") == NULL)
-    llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n";
+    llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, "
+                    "extended by heiko@hexco.de\n";
   transformCmps(M, true, true, true, true, true);
   verifyModule(M);
 
   return true;
+
 }
 
 static void registerCompTransPass(const PassManagerBuilder &,
-                            legacy::PassManagerBase &PM) {
+                                  legacy::PassManagerBase &PM) {
 
   auto p = new CompareTransform();
   PM.add(p);