about summary refs log tree commit diff
path: root/instrumentation
diff options
context:
space:
mode:
Diffstat (limited to 'instrumentation')
-rw-r--r--instrumentation/SanitizerCoverageLTO.so.cc3
-rw-r--r--instrumentation/SanitizerCoveragePCGUARD.so.cc1
-rw-r--r--instrumentation/afl-compiler-rt.o.c7
-rw-r--r--instrumentation/afl-llvm-common.h16
-rw-r--r--instrumentation/afl-llvm-pass.so.cc13
-rw-r--r--instrumentation/cmplog-instructions-pass.cc33
-rw-r--r--instrumentation/compare-transform-pass.so.cc54
-rw-r--r--instrumentation/split-switches-pass.so.cc6
8 files changed, 87 insertions, 46 deletions
diff --git a/instrumentation/SanitizerCoverageLTO.so.cc b/instrumentation/SanitizerCoverageLTO.so.cc
index 597a24b1..6a4a071f 100644
--- a/instrumentation/SanitizerCoverageLTO.so.cc
+++ b/instrumentation/SanitizerCoverageLTO.so.cc
@@ -1275,7 +1275,7 @@ void ModuleSanitizerCoverage::instrumentFunction(
   const DominatorTree *    DT = DTCallback(F);
   const PostDominatorTree *PDT = PDTCallback(F);
   bool                     IsLeafFunc = true;
-  uint32_t                 skip_next = 0, local_selects = 0;
+  uint32_t                 skip_next = 0;
 
   for (auto &BB : F) {
 
@@ -1385,7 +1385,6 @@ void ModuleSanitizerCoverage::instrumentFunction(
 
         }
 
-        local_selects++;
         uint32_t vector_cur = 0;
         /* Load SHM pointer */
         LoadInst *MapPtr =
diff --git a/instrumentation/SanitizerCoveragePCGUARD.so.cc b/instrumentation/SanitizerCoveragePCGUARD.so.cc
index c422d858..e4ffeb50 100644
--- a/instrumentation/SanitizerCoveragePCGUARD.so.cc
+++ b/instrumentation/SanitizerCoveragePCGUARD.so.cc
@@ -1054,7 +1054,6 @@ bool ModuleSanitizerCoverage::InjectCoverage(Function &             F,
 
         }
 
-        local_selects++;
         uint32_t vector_cur = 0;
 
         /* Load SHM pointer */
diff --git a/instrumentation/afl-compiler-rt.o.c b/instrumentation/afl-compiler-rt.o.c
index 1b9fdee3..a84f31e3 100644
--- a/instrumentation/afl-compiler-rt.o.c
+++ b/instrumentation/afl-compiler-rt.o.c
@@ -1433,9 +1433,12 @@ void __sanitizer_cov_trace_pc_guard_init(uint32_t *start, uint32_t *stop) {
 
     } else {
 
+      static u32 offset = 4;
+
       while (start < stop) {
 
-        *(start++) = 4;
+        *(start++) = offset;
+        if (unlikely(++offset >= __afl_final_loc)) { offset = 4; }
 
       }
 
@@ -1444,7 +1447,7 @@ void __sanitizer_cov_trace_pc_guard_init(uint32_t *start, uint32_t *stop) {
   }
 
   x = getenv("AFL_INST_RATIO");
-  if (x) inst_ratio = (u32)atoi(x);
+  if (x) { inst_ratio = (u32)atoi(x); }
 
   if (!inst_ratio || inst_ratio > 100) {
 
diff --git a/instrumentation/afl-llvm-common.h b/instrumentation/afl-llvm-common.h
index bd424e21..dee5f9fc 100644
--- a/instrumentation/afl-llvm-common.h
+++ b/instrumentation/afl-llvm-common.h
@@ -33,17 +33,17 @@ typedef long double max_align_t;
 #endif
 
 #if LLVM_VERSION_MAJOR >= 11
- #define MNAME M.getSourceFileName()
- #define FMNAME F.getParent()->getSourceFileName()
+  #define MNAME M.getSourceFileName()
+  #define FMNAME F.getParent()->getSourceFileName()
 #else
- #define MNAME std::string("")
- #define FMNAME std::string("")
+  #define MNAME std::string("")
+  #define FMNAME std::string("")
 #endif
 
-char *                 getBBName(const llvm::BasicBlock *BB);
-bool                   isIgnoreFunction(const llvm::Function *F);
-void                   initInstrumentList();
-bool                   isInInstrumentList(llvm::Function *F, std::string Filename);
+char *getBBName(const llvm::BasicBlock *BB);
+bool  isIgnoreFunction(const llvm::Function *F);
+void  initInstrumentList();
+bool  isInInstrumentList(llvm::Function *F, std::string Filename);
 unsigned long long int calculateCollisions(uint32_t edges);
 void                   scanForDangerousFunctions(llvm::Module *M);
 
diff --git a/instrumentation/afl-llvm-pass.so.cc b/instrumentation/afl-llvm-pass.so.cc
index 899734f8..5246ba08 100644
--- a/instrumentation/afl-llvm-pass.so.cc
+++ b/instrumentation/afl-llvm-pass.so.cc
@@ -631,18 +631,23 @@ bool AFLCoverage::runOnModule(Module &M) {
       LoadInst *PrevLoc;
 
       if (ngram_size) {
+
         PrevLoc = IRB.CreateLoad(
 #if LLVM_VERSION_MAJOR >= 14
-          PrevLocTy,
+            PrevLocTy,
 #endif
-          AFLPrevLoc);
+            AFLPrevLoc);
+
       } else {
+
         PrevLoc = IRB.CreateLoad(
 #if LLVM_VERSION_MAJOR >= 14
-          IRB.getInt32Ty(),
+            IRB.getInt32Ty(),
 #endif
-          AFLPrevLoc);
+            AFLPrevLoc);
+
       }
+
       PrevLoc->setMetadata(M.getMDKindID("nosanitize"), MDNode::get(C, None));
       Value *PrevLocTrans;
 
diff --git a/instrumentation/cmplog-instructions-pass.cc b/instrumentation/cmplog-instructions-pass.cc
index a0b386d5..310f5585 100644
--- a/instrumentation/cmplog-instructions-pass.cc
+++ b/instrumentation/cmplog-instructions-pass.cc
@@ -478,27 +478,28 @@ bool CmpLogInstructions::hookInstrs(Module &M) {
           */
           if (is_fp) {
 
-/*
-            ConstantFP *i0 = dyn_cast<ConstantFP>(op0);
-            ConstantFP *i1 = dyn_cast<ConstantFP>(op1);
-            // BUG FIXME TODO: this is null ... but why?
-            // fprintf(stderr, "%p %p\n", i0, i1);
-            if (i0) {
+            /*
+                        ConstantFP *i0 = dyn_cast<ConstantFP>(op0);
+                        ConstantFP *i1 = dyn_cast<ConstantFP>(op1);
+                        // BUG FIXME TODO: this is null ... but why?
+                        // fprintf(stderr, "%p %p\n", i0, i1);
+                        if (i0) {
 
-              cur_val = (uint64_t)i0->getValue().convertToDouble();
-              if (last_val0 && last_val0 == cur_val) { skip = 1; }
-              last_val0 = cur_val;
+                          cur_val = (uint64_t)i0->getValue().convertToDouble();
+                          if (last_val0 && last_val0 == cur_val) { skip = 1; }
+                          last_val0 = cur_val;
 
-            }
+                        }
 
-            if (i1) {
+                        if (i1) {
 
-              cur_val = (uint64_t)i1->getValue().convertToDouble();
-              if (last_val1 && last_val1 == cur_val) { skip = 1; }
-              last_val1 = cur_val;
+                          cur_val = (uint64_t)i1->getValue().convertToDouble();
+                          if (last_val1 && last_val1 == cur_val) { skip = 1; }
+                          last_val1 = cur_val;
 
-            }
-*/
+                        }
+
+            */
 
           } else {
 
diff --git a/instrumentation/compare-transform-pass.so.cc b/instrumentation/compare-transform-pass.so.cc
index 3f6a6763..c3a4ee34 100644
--- a/instrumentation/compare-transform-pass.so.cc
+++ b/instrumentation/compare-transform-pass.so.cc
@@ -383,17 +383,56 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
     bool        isMemcmp = false;
     bool        isSizedcmp = false;
     bool        isCaseInsensitive = false;
+    bool        needs_null = false;
     Function *  Callee = callInst->getCalledFunction();
     if (Callee) {
 
-      isMemcmp = Callee->getName().compare("memcmp") == 0;
-      isSizedcmp = isMemcmp || Callee->getName().compare("strncmp") == 0 ||
-                   Callee->getName().compare("strncasecmp") == 0;
-      isCaseInsensitive = Callee->getName().compare("strcasecmp") == 0 ||
-                          Callee->getName().compare("strncasecmp") == 0;
+      if (!Callee->getName().compare("memcmp") ||
+          !Callee->getName().compare("bcmp") ||
+          !Callee->getName().compare("CRYPTO_memcmp") ||
+          !Callee->getName().compare("OPENSSL_memcmp") ||
+          !Callee->getName().compare("memcmp_const_time") ||
+          !Callee->getName().compare("memcmpct") ||
+          !Callee->getName().compare("llvm.memcpy.p0i8.p0i8.i64"))
+        isMemcmp = true;
+
+      if (isMemcmp || !Callee->getName().compare("strncmp") ||
+          !Callee->getName().compare("xmlStrncmp") ||
+          !Callee->getName().compare("curl_strnequal") ||
+          !Callee->getName().compare("strncasecmp") ||
+          !Callee->getName().compare("strnicmp") ||
+          !Callee->getName().compare("ap_cstr_casecmpn") ||
+          !Callee->getName().compare("OPENSSL_strncasecmp") ||
+          !Callee->getName().compare("xmlStrncasecmp") ||
+          !Callee->getName().compare("g_ascii_strncasecmp") ||
+          !Callee->getName().compare("Curl_strncasecompare") ||
+          !Callee->getName().compare("g_strncasecmp"))
+        isSizedcmp = true;
+
+      if (!Callee->getName().compare("strcasecmp") ||
+          !Callee->getName().compare("stricmp") ||
+          !Callee->getName().compare("ap_cstr_casecmp") ||
+          !Callee->getName().compare("OPENSSL_strcasecmp") ||
+          !Callee->getName().compare("xmlStrcasecmp") ||
+          !Callee->getName().compare("g_strcasecmp") ||
+          !Callee->getName().compare("g_ascii_strcasecmp") ||
+          !Callee->getName().compare("Curl_strcasecompare") ||
+          !Callee->getName().compare("Curl_safe_strcasecompare") ||
+          !Callee->getName().compare("cmsstrcasecmp") ||
+          !Callee->getName().compare("strncasecmp") ||
+          !Callee->getName().compare("strnicmp") ||
+          !Callee->getName().compare("ap_cstr_casecmpn") ||
+          !Callee->getName().compare("OPENSSL_strncasecmp") ||
+          !Callee->getName().compare("xmlStrncasecmp") ||
+          !Callee->getName().compare("g_ascii_strncasecmp") ||
+          !Callee->getName().compare("Curl_strncasecompare") ||
+          !Callee->getName().compare("g_strncasecmp"))
+        isCaseInsensitive = true;
 
     }
 
+    if (!isSizedcmp) needs_null = true;
+
     Value *sizedValue = isSizedcmp ? callInst->getArgOperand(2) : NULL;
     bool   isConstSized = sizedValue && isa<ConstantInt>(sizedValue);
 
@@ -447,17 +486,14 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
 
     // the following is in general OK, but strncmp is sometimes used in binary
     // data structures and this can result in crashes :( so it is commented out
-    /*
 
     // add null termination character implicit in c strings
-    if (!isMemcmp && TmpConstStr[TmpConstStr.length() - 1]) {
+    if (needs_null && TmpConstStr[TmpConstStr.length() - 1] != 0) {
 
       TmpConstStr.append("\0", 1);
 
     }
 
-    */
-
     // in the unusual case the const str has embedded null
     // characters, the string comparison functions should terminate
     // at the first null
diff --git a/instrumentation/split-switches-pass.so.cc b/instrumentation/split-switches-pass.so.cc
index 85a35c2a..9f9e7eca 100644
--- a/instrumentation/split-switches-pass.so.cc
+++ b/instrumentation/split-switches-pass.so.cc
@@ -118,8 +118,6 @@ BasicBlock *SplitSwitchesTransform::switchConvert(
   std::vector<uint8_t> setSizes;
   std::vector<std::set<uint8_t> > byteSets(BytesInValue, std::set<uint8_t>());
 
-  assert(ValTypeBitWidth >= 8 && ValTypeBitWidth <= 64);
-
   /* for each of the possible cases we iterate over all bytes of the values
    * build a set of possible values at each byte position in byteSets */
   for (CaseExpr &Case : Cases) {
@@ -350,9 +348,9 @@ bool SplitSwitchesTransform::splitSwitches(Module &M) {
 
     /* If there is only the default destination or the condition checks 8 bit or
      * less, don't bother with the code below. */
-    if (!SI->getNumCases() || bitw <= 8) {
+    if (SI->getNumCases() < 2 || bitw % 8 || bitw > 64) {
 
-      // if (!be_quiet) errs() << "skip trivial switch..\n";
+      // if (!be_quiet) errs() << "skip switch..\n";
       continue;
 
     }