about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--instrumentation/SanitizerCoverageLTO.so.cc55
-rw-r--r--src/afl-cc.c10
2 files changed, 61 insertions, 4 deletions
diff --git a/instrumentation/SanitizerCoverageLTO.so.cc b/instrumentation/SanitizerCoverageLTO.so.cc
index 68423029..c74069e1 100644
--- a/instrumentation/SanitizerCoverageLTO.so.cc
+++ b/instrumentation/SanitizerCoverageLTO.so.cc
@@ -247,6 +247,7 @@ class ModuleSanitizerCoverageLTO
   uint32_t                         afl_global_id = 0;
   uint32_t                         unhandled = 0;
   uint32_t                         select_cnt = 0;
+  uint32_t                         instrument_ctx = 0;
   uint64_t                         map_addr = 0;
   const char                      *skip_nozero = NULL;
   const char                      *use_threadsafe_counters = nullptr;
@@ -261,6 +262,7 @@ class ModuleSanitizerCoverageLTO
   LLVMContext                     *Ct = NULL;
   Module                          *Mo = NULL;
   GlobalVariable                  *AFLMapPtr = NULL;
+  GlobalVariable                  *AFLContext = NULL;
   Value                           *MapPtrFixed = NULL;
   std::ofstream                    dFile;
   size_t                           found = 0;
@@ -420,11 +422,13 @@ bool ModuleSanitizerCoverageLTO::instrumentModule(
   setvbuf(stdout, NULL, _IONBF, 0);
   if (getenv("AFL_DEBUG")) { debug = 1; }
   if (getenv("AFL_LLVM_DICT2FILE_NO_MAIN")) { autodictionary_no_main = 1; }
+  if (getenv("AFL_LLVM_CALLER")) { instrument_ctx = 1; }
 
   if ((isatty(2) && !getenv("AFL_QUIET")) || debug) {
 
     SAYF(cCYA "afl-llvm-lto" VERSION cRST
-              " by Marc \"vanHauser\" Heuse <mh@mh-sec.de>\n");
+              "%s by Marc \"vanHauser\" Heuse <mh@mh-sec.de>\n",
+         instrument_ctx ? " (CTX mode)" : "");
 
   } else
 
@@ -500,6 +504,10 @@ bool ModuleSanitizerCoverageLTO::instrumentModule(
 
   }
 
+  AFLContext = new GlobalVariable(
+      M, Int32Ty, false, GlobalValue::ExternalLinkage, 0, "__afl_prev_ctx", 0,
+      GlobalVariable::GeneralDynamicTLSModel, 0, false);
+
   Zero = ConstantInt::get(Int8Tyi, 0);
   One = ConstantInt::get(Int8Tyi, 1);
 
@@ -1284,7 +1292,50 @@ void ModuleSanitizerCoverageLTO::instrumentFunction(
   const DominatorTree     *DT = DTCallback(F);
   const PostDominatorTree *PDT = PDTCallback(F);
   bool                     IsLeafFunc = true;
-  uint32_t                 skip_next = 0;
+  uint32_t                 skip_next = 0, call_counter = 0;
+  Value                   *PrevCtx = NULL;
+
+  MDNode *N =
+      MDNode::get(F.getContext(), MDString::get(F.getContext(), "nosanitize"));
+
+  for (auto &BB : F) {
+
+    if (/*F.size() > 1 &&*/ instrument_ctx && &BB == &F.getEntryBlock()) {
+
+      // we insert a CTX value in all our callers:
+      LLVMContext &Context = F.getContext();
+      IRBuilder<>  Builder(Context);
+      for (auto *U : F.users()) {
+
+        if (auto *CI = dyn_cast<CallInst>(U)) {
+
+          fprintf(stderr, "Insert %s [%u] -> %s\n",
+                  CI->getParent()->getParent()->getName().str().c_str(),
+                  call_counter, F.getName().str().c_str());
+          Builder.SetInsertPoint(CI);
+          StoreInst *StoreCtx = Builder.CreateStore(
+              ConstantInt::get(Type::getInt32Ty(Context), call_counter++),
+              AFLContext);
+          StoreCtx->setMetadata("nosanitize", N);
+
+        }
+
+      }
+
+      // We read the CTX for this call
+      BasicBlock::iterator IP = BB.getFirstInsertionPt();
+      IRBuilder<>          IRB(&(*IP));
+      LoadInst            *PrevCtxLoad = IRB.CreateLoad(
+#if LLVM_VERSION_MAJOR >= 14
+          Builder.getInt32Ty(),
+#endif
+          AFLContext);
+      PrevCtxLoad->setMetadata("nosanitize", N);
+      PrevCtx = PrevCtxLoad;
+
+    }
+
+  }
 
   for (auto &BB : F) {
 
diff --git a/src/afl-cc.c b/src/afl-cc.c
index 174b3783..4f6745ed 100644
--- a/src/afl-cc.c
+++ b/src/afl-cc.c
@@ -1103,12 +1103,18 @@ static void instrument_opt_mode_exclude(aflcc_state_t *aflcc) {
 
   }
 
-  if (aflcc->instrument_opt_mode && aflcc->compiler_mode != LLVM)
+  fprintf(stderr, "X %u %u\n", aflcc->compiler_mode, LTO);
+
+  if (aflcc->instrument_opt_mode && aflcc->compiler_mode != LLVM &&
+      !((aflcc->instrument_opt_mode & INSTRUMENT_OPT_CALLER) &&
+        aflcc->compiler_mode == LTO))
     FATAL("CTX, CALLER and NGRAM can only be used in LLVM mode");
 
   if (aflcc->instrument_opt_mode &&
       aflcc->instrument_opt_mode != INSTRUMENT_OPT_CODECOV &&
-      aflcc->instrument_mode != INSTRUMENT_CLASSIC)
+      aflcc->instrument_mode != INSTRUMENT_CLASSIC &&
+      !(aflcc->instrument_opt_mode & INSTRUMENT_OPT_CALLER &&
+        aflcc->compiler_mode == LTO))
     FATAL(
         "CALLER, CTX and NGRAM instrumentation options can only be used with "
         "the LLVM CLASSIC instrumentation mode.");