about summary refs log tree commit diff
path: root/instrumentation/cmplog-routines-pass.cc
diff options
context:
space:
mode:
authorvanhauser-thc <vh@thc.org>2020-09-05 12:11:48 +0200
committervanhauser-thc <vh@thc.org>2020-09-05 12:11:48 +0200
commit996986bed5f2dd97a3d76f584d8eddc1203f8396 (patch)
tree245d4b208ecb1dcf38c34987aabbd8e44c2703c9 /instrumentation/cmplog-routines-pass.cc
parentfac108476c1cb5326cf4339b2a4c846828698816 (diff)
downloadafl++-996986bed5f2dd97a3d76f584d8eddc1203f8396.tar.gz
first batch of changes
Diffstat (limited to 'instrumentation/cmplog-routines-pass.cc')
-rw-r--r--instrumentation/cmplog-routines-pass.cc212
1 files changed, 212 insertions, 0 deletions
diff --git a/instrumentation/cmplog-routines-pass.cc b/instrumentation/cmplog-routines-pass.cc
new file mode 100644
index 00000000..c44f38c4
--- /dev/null
+++ b/instrumentation/cmplog-routines-pass.cc
@@ -0,0 +1,212 @@
+/*
+   american fuzzy lop++ - LLVM CmpLog instrumentation
+   --------------------------------------------------
+
+   Written by Andrea Fioraldi <andreafioraldi@gmail.com>
+
+   Copyright 2015, 2016 Google Inc. All rights reserved.
+   Copyright 2019-2020 AFLplusplus Project. All rights reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at:
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+*/
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <list>
+#include <string>
+#include <fstream>
+#include <sys/time.h>
+#include "llvm/Config/llvm-config.h"
+
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Pass.h"
+#include "llvm/Analysis/ValueTracking.h"
+
+#if LLVM_VERSION_MAJOR > 3 || \
+    (LLVM_VERSION_MAJOR == 3 && LLVM_VERSION_MINOR > 4)
+  #include "llvm/IR/Verifier.h"
+  #include "llvm/IR/DebugInfo.h"
+#else
+  #include "llvm/Analysis/Verifier.h"
+  #include "llvm/DebugInfo.h"
+  #define nullptr 0
+#endif
+
+#include <set>
+#include "afl-llvm-common.h"
+
+using namespace llvm;
+
+namespace {
+
+class CmpLogRoutines : public ModulePass {
+
+ public:
+  static char ID;
+  CmpLogRoutines() : ModulePass(ID) {
+
+    initInstrumentList();
+
+  }
+
+  bool runOnModule(Module &M) override;
+
+#if LLVM_VERSION_MAJOR < 4
+  const char *getPassName() const override {
+
+#else
+  StringRef getPassName() const override {
+
+#endif
+    return "cmplog routines";
+
+  }
+
+ private:
+  bool hookRtns(Module &M);
+
+};
+
+}  // namespace
+
+char CmpLogRoutines::ID = 0;
+
+bool CmpLogRoutines::hookRtns(Module &M) {
+
+  std::vector<CallInst *> calls;
+  LLVMContext &           C = M.getContext();
+
+  Type *VoidTy = Type::getVoidTy(C);
+  // PointerType *VoidPtrTy = PointerType::get(VoidTy, 0);
+  IntegerType *Int8Ty = IntegerType::getInt8Ty(C);
+  PointerType *i8PtrTy = PointerType::get(Int8Ty, 0);
+
+#if LLVM_VERSION_MAJOR < 9
+  Constant *
+#else
+  FunctionCallee
+#endif
+      c = M.getOrInsertFunction("__cmplog_rtn_hook", VoidTy, i8PtrTy, i8PtrTy
+#if LLVM_VERSION_MAJOR < 5
+                                ,
+                                NULL
+#endif
+      );
+#if LLVM_VERSION_MAJOR < 9
+  Function *cmplogHookFn = cast<Function>(c);
+#else
+  FunctionCallee cmplogHookFn = c;
+#endif
+
+  /* iterate over all functions, bbs and instruction and add suitable calls */
+  for (auto &F : M) {
+
+    if (!isInInstrumentList(&F)) continue;
+
+    for (auto &BB : F) {
+
+      for (auto &IN : BB) {
+
+        CallInst *callInst = nullptr;
+
+        if ((callInst = dyn_cast<CallInst>(&IN))) {
+
+          Function *Callee = callInst->getCalledFunction();
+          if (!Callee) continue;
+          if (callInst->getCallingConv() != llvm::CallingConv::C) continue;
+
+          FunctionType *FT = Callee->getFunctionType();
+
+          bool isPtrRtn = FT->getNumParams() >= 2 &&
+                          !FT->getReturnType()->isVoidTy() &&
+                          FT->getParamType(0) == FT->getParamType(1) &&
+                          FT->getParamType(0)->isPointerTy();
+
+          if (!isPtrRtn) continue;
+
+          calls.push_back(callInst);
+
+        }
+
+      }
+
+    }
+
+  }
+
+  if (!calls.size()) return false;
+  if (!be_quiet)
+    errs() << "Hooking " << calls.size()
+           << " calls with pointers as arguments\n";
+
+  for (auto &callInst : calls) {
+
+    Value *v1P = callInst->getArgOperand(0), *v2P = callInst->getArgOperand(1);
+
+    IRBuilder<> IRB(callInst->getParent());
+    IRB.SetInsertPoint(callInst);
+
+    std::vector<Value *> args;
+    Value *              v1Pcasted = IRB.CreatePointerCast(v1P, i8PtrTy);
+    Value *              v2Pcasted = IRB.CreatePointerCast(v2P, i8PtrTy);
+    args.push_back(v1Pcasted);
+    args.push_back(v2Pcasted);
+
+    IRB.CreateCall(cmplogHookFn, args);
+
+    // errs() << callInst->getCalledFunction()->getName() << "\n";
+
+  }
+
+  return true;
+
+}
+
+bool CmpLogRoutines::runOnModule(Module &M) {
+
+  if (getenv("AFL_QUIET") == NULL)
+    llvm::errs()
+        << "Running cmplog-routines-pass by andreafioraldi@gmail.com\n";
+  else
+    be_quiet = 1;
+  hookRtns(M);
+  verifyModule(M);
+
+  return true;
+
+}
+
+static void registerCmpLogRoutinesPass(const PassManagerBuilder &,
+                                       legacy::PassManagerBase &PM) {
+
+  auto p = new CmpLogRoutines();
+  PM.add(p);
+
+}
+
+static RegisterStandardPasses RegisterCmpLogRoutinesPass(
+    PassManagerBuilder::EP_OptimizerLast, registerCmpLogRoutinesPass);
+
+static RegisterStandardPasses RegisterCmpLogRoutinesPass0(
+    PassManagerBuilder::EP_EnabledOnOptLevel0, registerCmpLogRoutinesPass);
+
+#if LLVM_VERSION_MAJOR >= 11
+static RegisterStandardPasses RegisterCmpLogRoutinesPassLTO(
+    PassManagerBuilder::EP_FullLinkTimeOptimizationLast,
+    registerCmpLogRoutinesPass);
+#endif
+