about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--instrumentation/SanitizerCoverageLTO.so.cc299
-rw-r--r--src/afl-cc.c2
2 files changed, 185 insertions, 116 deletions
diff --git a/instrumentation/SanitizerCoverageLTO.so.cc b/instrumentation/SanitizerCoverageLTO.so.cc
index f6d60099..31d26ca3 100644
--- a/instrumentation/SanitizerCoverageLTO.so.cc
+++ b/instrumentation/SanitizerCoverageLTO.so.cc
@@ -1394,10 +1394,14 @@ void ModuleSanitizerCoverageLTO::instrumentFunction(
   uint32_t                 inst_save = inst, save_global = afl_global_id;
   uint32_t                 inst_in_this_func = 0;
   Function                *caller = NULL;
+  LoadInst                *PrevCtxLoad = NULL;
 
   CTX_add = NULL;
 
-  if (debug) fprintf(stderr, "Function: %s\n", F.getName().str().c_str());
+  if (debug)
+    fprintf(stderr,
+            "Function: %s (%u %u) XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n",
+            F.getName().str().c_str(), inst, afl_global_id);
 
   if (instrument_ctx) {
 
@@ -1585,7 +1589,8 @@ void ModuleSanitizerCoverageLTO::instrumentFunction(
             Value               *CTX_offset;
             BasicBlock::iterator IP = BB.getFirstInsertionPt();
             IRBuilder<>          IRB(&(*IP));
-            LoadInst            *PrevCtxLoad = IRB.CreateLoad(
+
+            PrevCtxLoad = IRB.CreateLoad(
 #if LLVM_VERSION_MAJOR >= 14
                 IRB.getInt32Ty(),
 #endif
@@ -1608,20 +1613,78 @@ void ModuleSanitizerCoverageLTO::instrumentFunction(
 
           }
 
-        }
+          // bool loaded = false, multicall = false;
+          for (auto &IN : BB) {
+
+            // check all calls and where callee count == 1 instrument
+            // our current caller_id to __afl_ctx
+            if (auto callInst = dyn_cast<CallInst>(&IN)) {
 
-        inst = inst_save;
+              Function *Callee = callInst->getCalledFunction();
+              if (countCallers(Callee) == 1) {
+
+                if (debug)
+                  fprintf(stderr, "DEBUG: %s call to %s with only one caller\n",
+                          F.getName().str().c_str(),
+                          Callee->getName().str().c_str());
+                /* if (loaded == false || multicall == true) { // } */
+                IRBuilder<> Builder(IN.getContext());
+                Builder.SetInsertPoint(callInst);
+                StoreInst *StoreCtx =
+                    Builder.CreateStore(PrevCtxLoad, AFLContext);
+                StoreCtx->setMetadata("nosanitize", N);
+                // multicall = false; loaded = true;
+
+              }  // else { multicall = true; }
+
+            }
+
+          }
+
+        }
 
       }
 
     }
 
+    inst = inst_save;
+
   }
 
+  /*  if (debug)
+      fprintf(stderr, "Next instrumentation (%u-%u=%u %u-%u=%u)\n", inst,
+              inst_save, inst - inst_save, afl_global_id, save_global,
+              afl_global_id - save_global);*/
+
   for (auto &BB : F) {
 
+    skip_next = 0;
+
+    /*
+        uint32_t j = 0;
+        fprintf(stderr, "BB %p ============================================\n",
+                CTX_add);*/
+
     for (auto &IN : BB) {
 
+      /*      j++;
+            uint32_t           i = 1;
+            std::string        errMsg;
+            raw_string_ostream os(errMsg);
+            IN.print(os);
+            fprintf(stderr, "Next instruction, BB size now %zu: %02u %s\n",
+         BB.size(), j, os.str().c_str()); for (auto &IN2 : BB) {
+
+              std::string        errMsg2;
+              raw_string_ostream os2(errMsg2);
+              IN2.print(os2);
+              fprintf(
+                  stderr, "%s %02u: %s\n",
+                  strcmp(os.str().c_str(), os2.str().c_str()) == 0 ? ">>>" : "
+         ", i++, os2.str().c_str());
+
+            }*/
+
       CallInst *callInst = nullptr;
 
       if ((callInst = dyn_cast<CallInst>(&IN))) {
@@ -1665,83 +1728,62 @@ void ModuleSanitizerCoverageLTO::instrumentFunction(
 
       SelectInst *selectInst = nullptr;
 
-      /*
-            std::string errMsg;
-            raw_string_ostream os(errMsg);
-            IN.print(os);
-            fprintf(stderr, "X(%u): %s\n", skip_next, os.str().c_str());
-      */
-      if (!skip_next && (selectInst = dyn_cast<SelectInst>(&IN))) {
-
-        uint32_t    vector_cnt = 0;
-        Value      *condition = selectInst->getCondition();
-        Value      *result;
-        auto        t = condition->getType();
-        IRBuilder<> IRB(selectInst->getNextNode());
+      if ((selectInst = dyn_cast<SelectInst>(&IN))) {
 
-        ++select_cnt;
+        if (!skip_next) {
 
-        if (t->getTypeID() == llvm::Type::IntegerTyID) {
+          // fprintf(stderr, "Select in\n");
 
-          Value *val1 = ConstantInt::get(Int32Ty, ++afl_global_id);
-          Value *val2 = ConstantInt::get(Int32Ty, ++afl_global_id);
-          if (CTX_add) {
-
-            LoadInst *CTX_load = IRB.CreateLoad(
-#if LLVM_VERSION_MAJOR >= 14
-                IRB.getInt32Ty(),
-#endif
-                CTX_add);
-            val1 = IRB.CreateAdd(val1, CTX_load);
-            val2 = IRB.CreateAdd(val2, CTX_load);
+          uint32_t    vector_cnt = 0;
+          Value      *condition = selectInst->getCondition();
+          Value      *result;
+          auto        t = condition->getType();
+          IRBuilder<> IRB(selectInst->getNextNode());
 
-          }
+          ++select_cnt;
 
-          result = IRB.CreateSelect(condition, val1, val2);
-          skip_next = 1;
-          inst += 2;
+          if (t->getTypeID() == llvm::Type::IntegerTyID) {
 
-        } else
+            Value *val1 = ConstantInt::get(Int32Ty, ++afl_global_id);
+            Value *val2 = ConstantInt::get(Int32Ty, ++afl_global_id);
+            if (CTX_add) {
 
+              LoadInst *CTX_load = IRB.CreateLoad(
 #if LLVM_VERSION_MAJOR >= 14
-            if (t->getTypeID() == llvm::Type::FixedVectorTyID) {
-
-          FixedVectorType *tt = dyn_cast<FixedVectorType>(t);
-          if (tt) {
+                  IRB.getInt32Ty(),
+#endif
+                  CTX_add);
+              val1 = IRB.CreateAdd(val1, CTX_load);
+              val2 = IRB.CreateAdd(val2, CTX_load);
 
-            uint32_t elements = tt->getElementCount().getFixedValue();
-            vector_cnt = elements;
-            inst += vector_cnt * 2;
-            if (elements) {
+            }
 
-              FixedVectorType *GuardPtr1 =
-                  FixedVectorType::get(Int32Ty, elements);
-              FixedVectorType *GuardPtr2 =
-                  FixedVectorType::get(Int32Ty, elements);
-              Value *x, *y;
+            result = IRB.CreateSelect(condition, val1, val2);
+            skip_next = 1;
+            inst += 2;
 
-              Value *val1 = ConstantInt::get(Int32Ty, ++afl_global_id);
-              Value *val2 = ConstantInt::get(Int32Ty, ++afl_global_id);
-              if (CTX_add) {
+          } else
 
-                LoadInst *CTX_load = IRB.CreateLoad(
-  #if LLVM_VERSION_MAJOR >= 14
-                    IRB.getInt32Ty(),
-  #endif
-                    CTX_add);
-                val1 = IRB.CreateAdd(val1, CTX_load);
-                val2 = IRB.CreateAdd(val2, CTX_load);
+#if LLVM_VERSION_MAJOR >= 14
+              if (t->getTypeID() == llvm::Type::FixedVectorTyID) {
 
-              }
+            FixedVectorType *tt = dyn_cast<FixedVectorType>(t);
+            if (tt) {
 
-              x = IRB.CreateInsertElement(GuardPtr1, val1, (uint64_t)0);
-              y = IRB.CreateInsertElement(GuardPtr2, val2, (uint64_t)0);
+              uint32_t elements = tt->getElementCount().getFixedValue();
+              vector_cnt = elements;
+              inst += vector_cnt * 2;
+              if (elements) {
 
-              for (uint64_t i = 1; i < elements; i++) {
+                FixedVectorType *GuardPtr1 =
+                    FixedVectorType::get(Int32Ty, elements);
+                FixedVectorType *GuardPtr2 =
+                    FixedVectorType::get(Int32Ty, elements);
+                Value *x, *y;
 
-                val1 = ConstantInt::get(Int32Ty, ++afl_global_id);
-                val2 = ConstantInt::get(Int32Ty, ++afl_global_id);
-                /*if (CTX_add) { // already loaded I guess
+                Value *val1 = ConstantInt::get(Int32Ty, ++afl_global_id);
+                Value *val2 = ConstantInt::get(Int32Ty, ++afl_global_id);
+                if (CTX_add) {
 
                   LoadInst *CTX_load = IRB.CreateLoad(
   #if LLVM_VERSION_MAJOR >= 14
@@ -1751,92 +1793,116 @@ void ModuleSanitizerCoverageLTO::instrumentFunction(
                   val1 = IRB.CreateAdd(val1, CTX_load);
                   val2 = IRB.CreateAdd(val2, CTX_load);
 
-                }*/
+                }
 
-                x = IRB.CreateInsertElement(GuardPtr1, val1, i);
-                y = IRB.CreateInsertElement(GuardPtr2, val2, i);
+                x = IRB.CreateInsertElement(GuardPtr1, val1, (uint64_t)0);
+                y = IRB.CreateInsertElement(GuardPtr2, val2, (uint64_t)0);
 
-              }
+                for (uint64_t i = 1; i < elements; i++) {
 
-              result = IRB.CreateSelect(condition, x, y);
-              skip_next = 1;
+                  val1 = ConstantInt::get(Int32Ty, ++afl_global_id);
+                  val2 = ConstantInt::get(Int32Ty, ++afl_global_id);
+                  /*if (CTX_add) { // already loaded I guess
 
-            }
+                    LoadInst *CTX_load = IRB.CreateLoad(
+    #if LLVM_VERSION_MAJOR >= 14
+                        IRB.getInt32Ty(),
+    #endif
+                        CTX_add);
+                    val1 = IRB.CreateAdd(val1, CTX_load);
+                    val2 = IRB.CreateAdd(val2, CTX_load);
 
-          }
+                  }*/
+
+                  x = IRB.CreateInsertElement(GuardPtr1, val1, i);
+                  y = IRB.CreateInsertElement(GuardPtr2, val2, i);
+
+                }
+
+                result = IRB.CreateSelect(condition, x, y);
+                skip_next = 1;
 
-        } else
+              }
+
+            }
+
+          } else
 
 #endif
-        {
+          {
 
-          ++unhandled;
-          continue;
+            ++unhandled;
+            continue;
 
-        }
+          }
 
-        uint32_t vector_cur = 0;
-        /* Load SHM pointer */
-        LoadInst *MapPtr =
-            IRB.CreateLoad(PointerType::get(Int8Ty, 0), AFLMapPtr);
-        ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(MapPtr);
+          uint32_t vector_cur = 0;
+          /* Load SHM pointer */
+          LoadInst *MapPtr =
+              IRB.CreateLoad(PointerType::get(Int8Ty, 0), AFLMapPtr);
+          ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(MapPtr);
 
-        while (1) {
+          while (1) {
 
-          /* Get CurLoc */
-          Value *MapPtrIdx = nullptr;
+            /* Get CurLoc */
+            Value *MapPtrIdx = nullptr;
 
-          /* Load counter for CurLoc */
-          if (!vector_cnt) {
+            /* Load counter for CurLoc */
+            if (!vector_cnt) {
 
-            MapPtrIdx = IRB.CreateGEP(Int8Ty, MapPtr, result);
+              MapPtrIdx = IRB.CreateGEP(Int8Ty, MapPtr, result);
 
-          } else {
+            } else {
 
-            auto element = IRB.CreateExtractElement(result, vector_cur++);
-            MapPtrIdx = IRB.CreateGEP(Int8Ty, MapPtr, element);
+              auto element = IRB.CreateExtractElement(result, vector_cur++);
+              MapPtrIdx = IRB.CreateGEP(Int8Ty, MapPtr, element);
 
-          }
+            }
 
-          if (use_threadsafe_counters) {
+            if (use_threadsafe_counters) {
 
-            IRB.CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Add, MapPtrIdx, One,
+              IRB.CreateAtomicRMW(llvm::AtomicRMWInst::BinOp::Add, MapPtrIdx,
+                                  One,
 #if LLVM_VERSION_MAJOR >= 13
-                                llvm::MaybeAlign(1),
+                                  llvm::MaybeAlign(1),
 #endif
-                                llvm::AtomicOrdering::Monotonic);
+                                  llvm::AtomicOrdering::Monotonic);
+
+            } else {
 
-          } else {
+              LoadInst *Counter = IRB.CreateLoad(IRB.getInt8Ty(), MapPtrIdx);
+              ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(Counter);
 
-            LoadInst *Counter = IRB.CreateLoad(IRB.getInt8Ty(), MapPtrIdx);
-            ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(Counter);
+              /* Update bitmap */
 
-            /* Update bitmap */
+              Value *Incr = IRB.CreateAdd(Counter, One);
 
-            Value *Incr = IRB.CreateAdd(Counter, One);
+              if (skip_nozero == NULL) {
 
-            if (skip_nozero == NULL) {
+                auto cf = IRB.CreateICmpEQ(Incr, Zero);
+                auto carry = IRB.CreateZExt(cf, Int8Ty);
+                Incr = IRB.CreateAdd(Incr, carry);
 
-              auto cf = IRB.CreateICmpEQ(Incr, Zero);
-              auto carry = IRB.CreateZExt(cf, Int8Ty);
-              Incr = IRB.CreateAdd(Incr, carry);
+              }
+
+              auto nosan = IRB.CreateStore(Incr, MapPtrIdx);
+              ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(nosan);
 
             }
 
-            auto nosan = IRB.CreateStore(Incr, MapPtrIdx);
-            ModuleSanitizerCoverageLTO::SetNoSanitizeMetadata(nosan);
+            if (!vector_cnt || vector_cnt == vector_cur) { break; }
 
           }
 
-          if (!vector_cnt || vector_cnt == vector_cur) { break; }
-
-        }
+          skip_next = 1;
+          // fprintf(stderr, "Select out\n");
 
-        skip_next = 1;
+        } else {
 
-      } else {
+          // fprintf(stderr, "Select skip\n");
+          skip_next = 0;
 
-        skip_next = 0;
+        }
 
       }
 
@@ -1862,6 +1928,11 @@ void ModuleSanitizerCoverageLTO::instrumentFunction(
   InjectCoverage(F, BlocksToInstrument, IsLeafFunc);
   // InjectCoverageForIndirectCalls(F, IndirCalls);
 
+  /*if (debug)
+    fprintf(stderr, "Done instrumentation (%u-%u=%u %u-%u=%u)\n", inst,
+            inst_save, inst - inst_save, afl_global_id, save_global,
+            afl_global_id - save_global);*/
+
   if (inst_in_this_func && call_counter > 1) {
 
     if (inst_in_this_func != afl_global_id - save_global) {
diff --git a/src/afl-cc.c b/src/afl-cc.c
index 4f6745ed..fd466541 100644
--- a/src/afl-cc.c
+++ b/src/afl-cc.c
@@ -1103,8 +1103,6 @@ static void instrument_opt_mode_exclude(aflcc_state_t *aflcc) {
 
   }
 
-  fprintf(stderr, "X %u %u\n", aflcc->compiler_mode, LTO);
-
   if (aflcc->instrument_opt_mode && aflcc->compiler_mode != LLVM &&
       !((aflcc->instrument_opt_mode & INSTRUMENT_OPT_CALLER) &&
         aflcc->compiler_mode == LTO))