about summary refs log tree commit diff
path: root/instrumentation/compare-transform-pass.so.cc
diff options
context:
space:
mode:
Diffstat (limited to 'instrumentation/compare-transform-pass.so.cc')
-rw-r--r--instrumentation/compare-transform-pass.so.cc588
1 files changed, 588 insertions, 0 deletions
diff --git a/instrumentation/compare-transform-pass.so.cc b/instrumentation/compare-transform-pass.so.cc
new file mode 100644
index 00000000..9d2f4a92
--- /dev/null
+++ b/instrumentation/compare-transform-pass.so.cc
@@ -0,0 +1,588 @@
+/*
+ * Copyright 2016 laf-intel
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <list>
+#include <string>
+#include <fstream>
+#include <sys/time.h>
+#include "llvm/Config/llvm-config.h"
+
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Pass.h"
+#include "llvm/Analysis/ValueTracking.h"
+
+#if LLVM_VERSION_MAJOR > 3 || \
+    (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
+  #include "llvm/IR/Verifier.h"
+  #include "llvm/IR/DebugInfo.h"
+#else
+  #include "llvm/Analysis/Verifier.h"
+  #include "llvm/DebugInfo.h"
+  #define nullptr 0
+#endif
+
+#include <set>
+#include "afl-llvm-common.h"
+
+using namespace llvm;
+
+namespace {
+
+class CompareTransform : public ModulePass {
+
+ public:
+  static char ID;
+  CompareTransform() : ModulePass(ID) {
+
+    initInstrumentList();
+
+  }
+
+  bool runOnModule(Module &M) override;
+
+#if LLVM_VERSION_MAJOR < 4
+  const char *getPassName() const override {
+
+#else
+  StringRef getPassName() const override {
+
+#endif
+    return "transforms compare functions";
+
+  }
+
+ private:
+  bool transformCmps(Module &M, const bool processStrcmp,
+                     const bool processMemcmp, const bool processStrncmp,
+                     const bool processStrcasecmp,
+                     const bool processStrncasecmp);
+
+};
+
+}  // namespace
+
+char CompareTransform::ID = 0;
+
+bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
+                                     const bool processMemcmp,
+                                     const bool processStrncmp,
+                                     const bool processStrcasecmp,
+                                     const bool processStrncasecmp) {
+
+  DenseMap<Value *, std::string *> valueMap;
+  std::vector<CallInst *>          calls;
+  LLVMContext &                    C = M.getContext();
+  IntegerType *                    Int8Ty = IntegerType::getInt8Ty(C);
+  IntegerType *                    Int32Ty = IntegerType::getInt32Ty(C);
+  IntegerType *                    Int64Ty = IntegerType::getInt64Ty(C);
+
+#if LLVM_VERSION_MAJOR < 9
+  Constant *
+#else
+  FunctionCallee
+#endif
+      c = M.getOrInsertFunction("tolower", Int32Ty, Int32Ty
+#if LLVM_VERSION_MAJOR < 5
+                                ,
+                                NULL
+#endif
+      );
+#if LLVM_VERSION_MAJOR < 9
+  Function *tolowerFn = cast<Function>(c);
+#else
+  FunctionCallee tolowerFn = c;
+#endif
+
+  /* iterate over all functions, bbs and instruction and add suitable calls to
+   * strcmp/memcmp/strncmp/strcasecmp/strncasecmp */
+  for (auto &F : M) {
+
+    if (!isInInstrumentList(&F)) continue;
+
+    for (auto &BB : F) {
+
+      for (auto &IN : BB) {
+
+        CallInst *callInst = nullptr;
+
+        if ((callInst = dyn_cast<CallInst>(&IN))) {
+
+          bool isStrcmp = processStrcmp;
+          bool isMemcmp = processMemcmp;
+          bool isStrncmp = processStrncmp;
+          bool isStrcasecmp = processStrcasecmp;
+          bool isStrncasecmp = processStrncasecmp;
+          bool isIntMemcpy = true;
+
+          Function *Callee = callInst->getCalledFunction();
+          if (!Callee) continue;
+          if (callInst->getCallingConv() != llvm::CallingConv::C) continue;
+          StringRef FuncName = Callee->getName();
+          isStrcmp &= !FuncName.compare(StringRef("strcmp"));
+          isMemcmp &= (!FuncName.compare(StringRef("memcmp")) ||
+                       !FuncName.compare(StringRef("bcmp")));
+          isStrncmp &= !FuncName.compare(StringRef("strncmp"));
+          isStrcasecmp &= !FuncName.compare(StringRef("strcasecmp"));
+          isStrncasecmp &= !FuncName.compare(StringRef("strncasecmp"));
+          isIntMemcpy &= !FuncName.compare("llvm.memcpy.p0i8.p0i8.i64");
+
+          if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
+              !isStrncasecmp && !isIntMemcpy)
+            continue;
+
+          /* Verify the strcmp/memcmp/strncmp/strcasecmp/strncasecmp function
+           * prototype */
+          FunctionType *FT = Callee->getFunctionType();
+
+          isStrcmp &=
+              FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
+              FT->getParamType(0) == FT->getParamType(1) &&
+              FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
+          isStrcasecmp &=
+              FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
+              FT->getParamType(0) == FT->getParamType(1) &&
+              FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
+          isMemcmp &= FT->getNumParams() == 3 &&
+                      FT->getReturnType()->isIntegerTy(32) &&
+                      FT->getParamType(0)->isPointerTy() &&
+                      FT->getParamType(1)->isPointerTy() &&
+                      FT->getParamType(2)->isIntegerTy();
+          isStrncmp &= FT->getNumParams() == 3 &&
+                       FT->getReturnType()->isIntegerTy(32) &&
+                       FT->getParamType(0) == FT->getParamType(1) &&
+                       FT->getParamType(0) ==
+                           IntegerType::getInt8PtrTy(M.getContext()) &&
+                       FT->getParamType(2)->isIntegerTy();
+          isStrncasecmp &= FT->getNumParams() == 3 &&
+                           FT->getReturnType()->isIntegerTy(32) &&
+                           FT->getParamType(0) == FT->getParamType(1) &&
+                           FT->getParamType(0) ==
+                               IntegerType::getInt8PtrTy(M.getContext()) &&
+                           FT->getParamType(2)->isIntegerTy();
+
+          if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
+              !isStrncasecmp && !isIntMemcpy)
+            continue;
+
+          /* is a str{n,}{case,}cmp/memcmp, check if we have
+           * str{case,}cmp(x, "const") or str{case,}cmp("const", x)
+           * strn{case,}cmp(x, "const", ..) or strn{case,}cmp("const", x, ..)
+           * memcmp(x, "const", ..) or memcmp("const", x, ..) */
+          Value *Str1P = callInst->getArgOperand(0),
+                *Str2P = callInst->getArgOperand(1);
+          StringRef Str1, Str2;
+          bool      HasStr1 = getConstantStringInfo(Str1P, Str1);
+          bool      HasStr2 = getConstantStringInfo(Str2P, Str2);
+
+          if (isIntMemcpy && HasStr2) {
+
+            valueMap[Str1P] = new std::string(Str2.str());
+            // fprintf(stderr, "saved %s for %p\n", Str2.str().c_str(), Str1P);
+            continue;
+
+          }
+
+          // not literal? maybe global or local variable
+          if (!(HasStr1 || HasStr2)) {
+
+            auto *Ptr = dyn_cast<ConstantExpr>(Str2P);
+            if (Ptr && Ptr->isGEPWithNoNotionalOverIndexing()) {
+
+              if (auto *Var = dyn_cast<GlobalVariable>(Ptr->getOperand(0))) {
+
+                if (Var->hasInitializer()) {
+
+                  if (auto *Array =
+                          dyn_cast<ConstantDataArray>(Var->getInitializer())) {
+
+                    HasStr2 = true;
+                    Str2 = Array->getAsString();
+                    valueMap[Str2P] = new std::string(Str2.str());
+                    fprintf(stderr, "glo2 %s\n", Str2.str().c_str());
+
+                  }
+
+                }
+
+              }
+
+            }
+
+            if (!HasStr2) {
+
+              auto *Ptr = dyn_cast<ConstantExpr>(Str1P);
+              if (Ptr && Ptr->isGEPWithNoNotionalOverIndexing()) {
+
+                if (auto *Var = dyn_cast<GlobalVariable>(Ptr->getOperand(0))) {
+
+                  if (Var->hasInitializer()) {
+
+                    if (auto *Array = dyn_cast<ConstantDataArray>(
+                            Var->getInitializer())) {
+
+                      HasStr1 = true;
+                      Str1 = Array->getAsString();
+                      valueMap[Str1P] = new std::string(Str1.str());
+                      // fprintf(stderr, "glo1 %s\n", Str1.str().c_str());
+
+                    }
+
+                  }
+
+                }
+
+              }
+
+            } else if (isIntMemcpy) {
+
+              valueMap[Str1P] = new std::string(Str2.str());
+              // fprintf(stderr, "saved\n");
+
+            }
+
+          }
+
+          if (isIntMemcpy) continue;
+
+          if (!(HasStr1 || HasStr2)) {
+
+            // do we have a saved local variable initialization?
+            std::string *val = valueMap[Str1P];
+            if (val && !val->empty()) {
+
+              Str1 = StringRef(*val);
+              HasStr1 = true;
+              // fprintf(stderr, "loaded1 %s\n", Str1.str().c_str());
+
+            } else {
+
+              val = valueMap[Str2P];
+              if (val && !val->empty()) {
+
+                Str2 = StringRef(*val);
+                HasStr2 = true;
+                // fprintf(stderr, "loaded2 %s\n", Str2.str().c_str());
+
+              }
+
+            }
+
+          }
+
+          /* handle cases of one string is const, one string is variable */
+          if (!(HasStr1 || HasStr2)) continue;
+
+          if (isMemcmp || isStrncmp || isStrncasecmp) {
+
+            /* check if third operand is a constant integer
+             * strlen("constStr") and sizeof() are treated as constant */
+            Value *      op2 = callInst->getArgOperand(2);
+            ConstantInt *ilen = dyn_cast<ConstantInt>(op2);
+            if (ilen) {
+
+              uint64_t len = ilen->getZExtValue();
+              // if len is zero this is a pointless call but allow real
+              // implementation to worry about that
+              if (!len) continue;
+
+              if (isMemcmp) {
+
+                // if size of compare is larger than constant string this is
+                // likely a bug but allow real implementation to worry about
+                // that
+                uint64_t literalLength = HasStr1 ? Str1.size() : Str2.size();
+                if (literalLength + 1 < ilen->getZExtValue()) continue;
+
+              }
+
+            } else if (isMemcmp)
+
+              // this *may* supply a len greater than the constant string at
+              // runtime so similarly we don't want to have to handle that
+              continue;
+
+          }
+
+          calls.push_back(callInst);
+
+        }
+
+      }
+
+    }
+
+  }
+
+  if (!calls.size()) return false;
+  if (!be_quiet)
+    errs() << "Replacing " << calls.size()
+           << " calls to strcmp/memcmp/strncmp/strcasecmp/strncasecmp\n";
+
+  for (auto &callInst : calls) {
+
+    Value *Str1P = callInst->getArgOperand(0),
+          *Str2P = callInst->getArgOperand(1);
+    StringRef   Str1, Str2, ConstStr;
+    std::string TmpConstStr;
+    Value *     VarStr;
+    bool        HasStr1 = getConstantStringInfo(Str1P, Str1);
+    bool        HasStr2 = getConstantStringInfo(Str2P, Str2);
+    uint64_t    constStrLen, unrollLen, constSizedLen = 0;
+    bool        isMemcmp =
+        !callInst->getCalledFunction()->getName().compare(StringRef("memcmp"));
+    bool isSizedcmp = isMemcmp ||
+                      !callInst->getCalledFunction()->getName().compare(
+                          StringRef("strncmp")) ||
+                      !callInst->getCalledFunction()->getName().compare(
+                          StringRef("strncasecmp"));
+    Value *sizedValue = isSizedcmp ? callInst->getArgOperand(2) : NULL;
+    bool   isConstSized = sizedValue && isa<ConstantInt>(sizedValue);
+    bool isCaseInsensitive = !callInst->getCalledFunction()->getName().compare(
+                                 StringRef("strcasecmp")) ||
+                             !callInst->getCalledFunction()->getName().compare(
+                                 StringRef("strncasecmp"));
+
+    if (!(HasStr1 || HasStr2)) {
+
+      // do we have a saved local or global variable initialization?
+      std::string *val = valueMap[Str1P];
+      if (val && !val->empty()) {
+
+        Str1 = StringRef(*val);
+        HasStr1 = true;
+
+      } else {
+
+        val = valueMap[Str2P];
+        if (val && !val->empty()) {
+
+          Str2 = StringRef(*val);
+          HasStr2 = true;
+
+        }
+
+      }
+
+    }
+
+    if (isConstSized) {
+
+      constSizedLen = dyn_cast<ConstantInt>(sizedValue)->getZExtValue();
+
+    }
+
+    if (HasStr1) {
+
+      TmpConstStr = Str1.str();
+      VarStr = Str2P;
+
+    } else {
+
+      TmpConstStr = Str2.str();
+      VarStr = Str1P;
+
+    }
+
+    // add null termination character implicit in c strings
+    TmpConstStr.append("\0", 1);
+
+    // in the unusual case the const str has embedded null
+    // characters, the string comparison functions should terminate
+    // at the first null
+    if (!isMemcmp)
+      TmpConstStr.assign(TmpConstStr, 0, TmpConstStr.find('\0') + 1);
+
+    constStrLen = TmpConstStr.length();
+    // prefer use of StringRef (in comparison to std::string a StringRef has
+    // built-in runtime bounds checking, which makes debugging easier)
+    ConstStr = StringRef(TmpConstStr);
+
+    if (isConstSized)
+      unrollLen = constSizedLen < constStrLen ? constSizedLen : constStrLen;
+    else
+      unrollLen = constStrLen;
+
+    if (!be_quiet)
+      errs() << callInst->getCalledFunction()->getName() << ": unroll len "
+             << unrollLen
+             << ((isSizedcmp && !isConstSized) ? ", variable n" : "") << ": "
+             << ConstStr << "\n";
+
+    /* split before the call instruction */
+    BasicBlock *bb = callInst->getParent();
+    BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
+
+    BasicBlock *next_lenchk_bb = NULL;
+    if (isSizedcmp && !isConstSized) {
+
+      next_lenchk_bb =
+          BasicBlock::Create(C, "len_check", end_bb->getParent(), end_bb);
+      BranchInst::Create(end_bb, next_lenchk_bb);
+
+    }
+
+    BasicBlock *next_cmp_bb =
+        BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
+    BranchInst::Create(end_bb, next_cmp_bb);
+    PHINode *PN = PHINode::Create(
+        Int32Ty, (next_lenchk_bb ? 2 : 1) * unrollLen + 1, "cmp_phi");
+
+#if LLVM_VERSION_MAJOR < 8
+    TerminatorInst *term = bb->getTerminator();
+#else
+    Instruction *term = bb->getTerminator();
+#endif
+    BranchInst::Create(next_lenchk_bb ? next_lenchk_bb : next_cmp_bb, bb);
+    term->eraseFromParent();
+
+    for (uint64_t i = 0; i < unrollLen; i++) {
+
+      BasicBlock *  cur_cmp_bb = next_cmp_bb, *cur_lenchk_bb = next_lenchk_bb;
+      unsigned char c;
+
+      if (cur_lenchk_bb) {
+
+        IRBuilder<> cur_lenchk_IRB(&*(cur_lenchk_bb->getFirstInsertionPt()));
+        Value *     icmp = cur_lenchk_IRB.CreateICmpEQ(
+            sizedValue, ConstantInt::get(sizedValue->getType(), i));
+        cur_lenchk_IRB.CreateCondBr(icmp, end_bb, cur_cmp_bb);
+        cur_lenchk_bb->getTerminator()->eraseFromParent();
+
+        PN->addIncoming(ConstantInt::get(Int32Ty, 0), cur_lenchk_bb);
+
+      }
+
+      if (isCaseInsensitive)
+        c = (unsigned char)(tolower((int)ConstStr[i]) & 0xff);
+      else
+        c = (unsigned char)ConstStr[i];
+
+      IRBuilder<> cur_cmp_IRB(&*(cur_cmp_bb->getFirstInsertionPt()));
+
+      Value *v = ConstantInt::get(Int64Ty, i);
+      Value *ele = cur_cmp_IRB.CreateInBoundsGEP(VarStr, v, "empty");
+      Value *load = cur_cmp_IRB.CreateLoad(ele);
+
+      if (isCaseInsensitive) {
+
+        // load >= 'A' && load <= 'Z' ? load | 0x020 : load
+        load = cur_cmp_IRB.CreateZExt(load, Int32Ty);
+        std::vector<Value *> args;
+        args.push_back(load);
+        load = cur_cmp_IRB.CreateCall(tolowerFn, args);
+        load = cur_cmp_IRB.CreateTrunc(load, Int8Ty);
+
+      }
+
+      Value *isub;
+      if (HasStr1)
+        isub = cur_cmp_IRB.CreateSub(ConstantInt::get(Int8Ty, c), load);
+      else
+        isub = cur_cmp_IRB.CreateSub(load, ConstantInt::get(Int8Ty, c));
+
+      Value *sext = cur_cmp_IRB.CreateSExt(isub, Int32Ty);
+      PN->addIncoming(sext, cur_cmp_bb);
+
+      if (i < unrollLen - 1) {
+
+        if (cur_lenchk_bb) {
+
+          next_lenchk_bb =
+              BasicBlock::Create(C, "len_check", end_bb->getParent(), end_bb);
+          BranchInst::Create(end_bb, next_lenchk_bb);
+
+        }
+
+        next_cmp_bb =
+            BasicBlock::Create(C, "cmp_added", end_bb->getParent(), end_bb);
+        BranchInst::Create(end_bb, next_cmp_bb);
+
+        Value *icmp =
+            cur_cmp_IRB.CreateICmpEQ(isub, ConstantInt::get(Int8Ty, 0));
+        cur_cmp_IRB.CreateCondBr(
+            icmp, next_lenchk_bb ? next_lenchk_bb : next_cmp_bb, end_bb);
+        cur_cmp_bb->getTerminator()->eraseFromParent();
+
+      } else {
+
+        // IRB.CreateBr(end_bb);
+
+      }
+
+      // add offset to varstr
+      // create load
+      // create signed isub
+      // create icmp
+      // create jcc
+      // create next_bb
+
+    }
+
+    /* since the call is the first instruction of the bb it is safe to
+     * replace it with a phi instruction */
+    BasicBlock::iterator ii(callInst);
+    ReplaceInstWithInst(callInst->getParent()->getInstList(), ii, PN);
+
+  }
+
+  return true;
+
+}
+
+bool CompareTransform::runOnModule(Module &M) {
+
+  if ((isatty(2) && getenv("AFL_QUIET") == NULL) || getenv("AFL_DEBUG") != NULL)
+    llvm::errs() << "Running compare-transform-pass by laf.intel@gmail.com, "
+                    "extended by heiko@hexco.de\n";
+  else
+    be_quiet = 1;
+  transformCmps(M, true, true, true, true, true);
+  verifyModule(M);
+
+  return true;
+
+}
+
+static void registerCompTransPass(const PassManagerBuilder &,
+                                  legacy::PassManagerBase &PM) {
+
+  auto p = new CompareTransform();
+  PM.add(p);
+
+}
+
+static RegisterStandardPasses RegisterCompTransPass(
+    PassManagerBuilder::EP_OptimizerLast, registerCompTransPass);
+
+static RegisterStandardPasses RegisterCompTransPass0(
+    PassManagerBuilder::EP_EnabledOnOptLevel0, registerCompTransPass);
+
+#if LLVM_VERSION_MAJOR >= 11
+static RegisterStandardPasses RegisterCompTransPassLTO(
+    PassManagerBuilder::EP_FullLinkTimeOptimizationLast, registerCompTransPass);
+#endif
+