about summary refs log tree commit diff
path: root/instrumentation
diff options
context:
space:
mode:
Diffstat (limited to 'instrumentation')
-rw-r--r--instrumentation/SanitizerCoveragePCGUARD.so.cc6
-rw-r--r--instrumentation/afl-llvm-common.cc49
-rw-r--r--instrumentation/afl-llvm-common.h1
-rw-r--r--instrumentation/compare-transform-pass.so.cc43
-rw-r--r--instrumentation/split-compares-pass.so.cc11
5 files changed, 99 insertions, 11 deletions
diff --git a/instrumentation/SanitizerCoveragePCGUARD.so.cc b/instrumentation/SanitizerCoveragePCGUARD.so.cc
index f88ce126..01881f28 100644
--- a/instrumentation/SanitizerCoveragePCGUARD.so.cc
+++ b/instrumentation/SanitizerCoveragePCGUARD.so.cc
@@ -195,7 +195,7 @@ class ModuleSanitizerCoverageAFL
 
   SanitizerCoverageOptions Options;
 
-  uint32_t        instr = 0, selects = 0, unhandled = 0;
+  uint32_t        instr = 0, selects = 0, unhandled = 0, dump_cc = 0;
   GlobalVariable *AFLMapPtr = NULL;
   ConstantInt    *One = NULL;
   ConstantInt    *Zero = NULL;
@@ -330,6 +330,8 @@ bool ModuleSanitizerCoverageAFL::instrumentModule(
 
   if (getenv("AFL_DEBUG")) { debug = 1; }
 
+  if (getenv("AFL_DUMP_CYCLOMATIC_COMPLEXITY")) { dump_cc = 1; }
+
   if ((isatty(2) && !getenv("AFL_QUIET")) || debug) {
 
     SAYF(cCYA "SanitizerCoveragePCGUARD" VERSION cRST "\n");
@@ -638,6 +640,8 @@ void ModuleSanitizerCoverageAFL::instrumentFunction(
   // InjectTraceForCmp(F, CmpTraceTargets);
   // InjectTraceForSwitch(F, SwitchTraceTargets);
 
+  if (dump_cc) { calcCyclomaticComplexity(&F); }
+
 }
 
 GlobalVariable *ModuleSanitizerCoverageAFL::CreateFunctionLocalArrayInSection(
diff --git a/instrumentation/afl-llvm-common.cc b/instrumentation/afl-llvm-common.cc
index 8e9e7800..50954324 100644
--- a/instrumentation/afl-llvm-common.cc
+++ b/instrumentation/afl-llvm-common.cc
@@ -26,6 +26,51 @@ static std::list<std::string> allowListFunctions;
 static std::list<std::string> denyListFiles;
 static std::list<std::string> denyListFunctions;
 
+unsigned int calcCyclomaticComplexity(llvm::Function *F) {
+
+  unsigned int numBlocks = 0;
+  unsigned int numEdges = 0;
+  unsigned int numCalls = 0;
+
+  // Iterate through each basic block in the function
+  for (BasicBlock &BB : *F) {
+
+    // count all nodes == basic blocks
+    numBlocks++;
+    // Count the number of successors (outgoing edges)
+    for (BasicBlock *Succ : successors(&BB)) {
+
+      // count edges for CC
+      numEdges++;
+      (void)(Succ);
+
+    }
+
+    for (Instruction &I : BB) {
+
+      // every call is also an edge, so we need to count the calls too
+      if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) { numCalls++; }
+
+    }
+
+  }
+
+  // Cyclomatic Complexity V(G) = E - N + 2P
+  // For a single function, P (number of connected components) is 1
+  // Calls are considered to be an edge
+  unsigned int CC = 2 + numCalls + numEdges - numBlocks;
+
+  // if (debug) {
+
+  fprintf(stderr, "CyclomaticComplexity for %s: %u\n",
+          F->getName().str().c_str(), CC);
+
+  //}
+
+  return CC;
+
+}
+
 char *getBBName(const llvm::BasicBlock *BB) {
 
   static char *name;
@@ -91,7 +136,11 @@ bool isIgnoreFunction(const llvm::Function *F) {
 
   for (auto const &ignoreListFunc : ignoreList) {
 
+#if LLVM_VERSION_MAJOR >= 19
+    if (F->getName().starts_with(ignoreListFunc)) { return true; }
+#else
     if (F->getName().startswith(ignoreListFunc)) { return true; }
+#endif
 
   }
 
diff --git a/instrumentation/afl-llvm-common.h b/instrumentation/afl-llvm-common.h
index 23f67179..6b628d64 100644
--- a/instrumentation/afl-llvm-common.h
+++ b/instrumentation/afl-llvm-common.h
@@ -55,6 +55,7 @@ void  initInstrumentList();
 bool  isInInstrumentList(llvm::Function *F, std::string Filename);
 unsigned long long int calculateCollisions(uint32_t edges);
 void                   scanForDangerousFunctions(llvm::Module *M);
+unsigned int           calcCyclomaticComplexity(llvm::Function *F);
 
 #ifndef IS_EXTERN
   #define IS_EXTERN
diff --git a/instrumentation/compare-transform-pass.so.cc b/instrumentation/compare-transform-pass.so.cc
index f8ba9de5..36149f35 100644
--- a/instrumentation/compare-transform-pass.so.cc
+++ b/instrumentation/compare-transform-pass.so.cc
@@ -54,6 +54,12 @@
   #define nullptr 0
 #endif
 
+#if LLVM_MAJOR >= 19
+  #define STARTSWITH starts_with
+#else
+  #define STARTSWITH startswith
+#endif
+
 #include <set>
 #include "afl-llvm-common.h"
 
@@ -230,38 +236,38 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
           if (callInst->getCallingConv() != llvm::CallingConv::C) continue;
           StringRef FuncName = Callee->getName();
           isStrcmp &=
-              (!FuncName.compare("strcmp") || !FuncName.compare("xmlStrcmp") ||
+              (!FuncName.compare("strcmp") /*|| !FuncName.compare("xmlStrcmp") ||
                !FuncName.compare("xmlStrEqual") ||
                !FuncName.compare("curl_strequal") ||
                !FuncName.compare("strcsequal") ||
-               !FuncName.compare("g_strcmp0"));
+               !FuncName.compare("g_strcmp0")*/);
           isMemcmp &=
               (!FuncName.compare("memcmp") || !FuncName.compare("bcmp") ||
                !FuncName.compare("CRYPTO_memcmp") ||
                !FuncName.compare("OPENSSL_memcmp") ||
                !FuncName.compare("memcmp_const_time") ||
                !FuncName.compare("memcmpct"));
-          isStrncmp &= (!FuncName.compare("strncmp") ||
+          isStrncmp &= (!FuncName.compare("strncmp")/* ||
                         !FuncName.compare("curl_strnequal") ||
-                        !FuncName.compare("xmlStrncmp"));
+                        !FuncName.compare("xmlStrncmp")*/);
           isStrcasecmp &= (!FuncName.compare("strcasecmp") ||
                            !FuncName.compare("stricmp") ||
                            !FuncName.compare("ap_cstr_casecmp") ||
                            !FuncName.compare("OPENSSL_strcasecmp") ||
-                           !FuncName.compare("xmlStrcasecmp") ||
+                           /*!FuncName.compare("xmlStrcasecmp") ||
                            !FuncName.compare("g_strcasecmp") ||
                            !FuncName.compare("g_ascii_strcasecmp") ||
                            !FuncName.compare("Curl_strcasecompare") ||
-                           !FuncName.compare("Curl_safe_strcasecompare") ||
+                           !FuncName.compare("Curl_safe_strcasecompare") ||*/
                            !FuncName.compare("cmsstrcasecmp"));
           isStrncasecmp &= (!FuncName.compare("strncasecmp") ||
                             !FuncName.compare("strnicmp") ||
                             !FuncName.compare("ap_cstr_casecmpn") ||
-                            !FuncName.compare("OPENSSL_strncasecmp") ||
+                            !FuncName.compare("OPENSSL_strncasecmp") /*||
                             !FuncName.compare("xmlStrncasecmp") ||
                             !FuncName.compare("g_ascii_strncasecmp") ||
                             !FuncName.compare("Curl_strncasecompare") ||
-                            !FuncName.compare("g_strncasecmp"));
+                            !FuncName.compare("g_strncasecmp")*/);
           isIntMemcpy &= !FuncName.compare("llvm.memcpy.p0i8.p0i8.i64");
 
           if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
@@ -465,8 +471,20 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
     bool        isCaseInsensitive = false;
     bool        needs_null = false;
     bool        success_is_one = false;
+    bool        nullCheck = false;
     Function   *Callee = callInst->getCalledFunction();
 
+    /*
+    fprintf(stderr, "%s - %s - %s\n",
+            callInst->getParent()
+                ->getParent()
+                ->getParent()
+                ->getName()
+                .str()
+                .c_str(),
+            callInst->getParent()->getParent()->getName().str().c_str(),
+            Callee ? Callee->getName().str().c_str() : "NULL");*/
+
     if (Callee) {
 
       if (!Callee->getName().compare("memcmp") ||
@@ -520,6 +538,11 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
     }
 
     if (!isSizedcmp) needs_null = true;
+    if (Callee->getName().STARTSWITH("g_") ||
+        Callee->getName().STARTSWITH("curl_") ||
+        Callee->getName().STARTSWITH("Curl_") ||
+        Callee->getName().STARTSWITH("xml"))
+      nullCheck = true;
 
     Value *sizedValue = isSizedcmp ? callInst->getArgOperand(2) : NULL;
     bool   isConstSized = sizedValue && isa<ConstantInt>(sizedValue);
@@ -604,8 +627,10 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
     /* split before the call instruction */
     BasicBlock *bb = callInst->getParent();
     BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
-
     BasicBlock *next_lenchk_bb = NULL;
+
+    if (nullCheck) { fprintf(stderr, "TODO: null check\n"); }
+
     if (isSizedcmp && !isConstSized) {
 
       next_lenchk_bb =
diff --git a/instrumentation/split-compares-pass.so.cc b/instrumentation/split-compares-pass.so.cc
index 728ebc22..effafe50 100644
--- a/instrumentation/split-compares-pass.so.cc
+++ b/instrumentation/split-compares-pass.so.cc
@@ -266,8 +266,11 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) {
 
             /* this is probably not needed but we do it anyway */
             if (TyOp0 != TyOp1) { continue; }
-
             if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; }
+            int constants = 0;
+            if (llvm::isa<llvm::Constant>(op0)) { ++constants; }
+            if (llvm::isa<llvm::Constant>(op1)) { ++constants; }
+            if (constants != 1) { continue; }
 
             fcomps.push_back(selectcmpInst);
 
@@ -1778,7 +1781,13 @@ bool SplitComparesTransform::runOnModule(Module &M) {
 
             auto op0 = CI->getOperand(0);
             auto op1 = CI->getOperand(1);
+            // has to valid operands
             if (!op0 || !op1) { continue; }
+            // has exactly one constant and one variable
+            int constants = 0;
+            if (dyn_cast<ConstantInt>(op0)) { ++constants; }
+            if (dyn_cast<ConstantInt>(op1)) { ++constants; }
+            if (constants != 1) { continue; }
 
             auto iTy1 = dyn_cast<IntegerType>(op0->getType());
             if (iTy1 && isa<IntegerType>(op1->getType())) {