about summary refs log tree commit diff
path: root/instrumentation/compare-transform-pass.so.cc
diff options
context:
space:
mode:
authorAlexander Shvedov <60114847+a-shvedov@users.noreply.github.com>2024-05-30 10:43:01 +0300
committerGitHub <noreply@github.com>2024-05-30 10:43:01 +0300
commitf8a5f1cd9ea907654f42fa06ce6b6bfd4b8c1b13 (patch)
tree7aec2a095a30ed609ce96f85ec3c4e0a8b8eb74c /instrumentation/compare-transform-pass.so.cc
parent629edb1e78d791894ce9ee6d53259f95fe1a29af (diff)
parente7d871c8bf64962a658e447b90a1a3b43aaddc28 (diff)
downloadafl++-f8a5f1cd9ea907654f42fa06ce6b6bfd4b8c1b13.tar.gz
Merge branch 'AFLplusplus:stable' into stable
Diffstat (limited to 'instrumentation/compare-transform-pass.so.cc')
-rw-r--r--instrumentation/compare-transform-pass.so.cc116
1 files changed, 82 insertions, 34 deletions
diff --git a/instrumentation/compare-transform-pass.so.cc b/instrumentation/compare-transform-pass.so.cc
index efc99d20..496d69fc 100644
--- a/instrumentation/compare-transform-pass.so.cc
+++ b/instrumentation/compare-transform-pass.so.cc
@@ -89,7 +89,7 @@ class CompareTransform : public ModulePass {
 
   #endif
 
-    return "cmplog transform";
+    return "compcov transform";
 
   }
 
@@ -123,7 +123,11 @@ llvmGetPassPluginInfo() {
     #if LLVM_VERSION_MAJOR <= 13
             using OptimizationLevel = typename PassBuilder::OptimizationLevel;
     #endif
+    #if LLVM_VERSION_MAJOR >= 16
+            PB.registerOptimizerEarlyEPCallback(
+    #else
             PB.registerOptimizerLastEPCallback(
+    #endif
                 [](ModulePassManager &MPM, OptimizationLevel OL) {
 
                   MPM.addPass(CompareTransform());
@@ -169,6 +173,7 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
   DenseMap<Value *, std::string *> valueMap;
   std::vector<CallInst *>          calls;
   LLVMContext                     &C = M.getContext();
+  IntegerType                     *Int1Ty = IntegerType::getInt1Ty(C);
   IntegerType                     *Int8Ty = IntegerType::getInt8Ty(C);
   IntegerType                     *Int32Ty = IntegerType::getInt32Ty(C);
   IntegerType                     *Int64Ty = IntegerType::getInt64Ty(C);
@@ -225,38 +230,38 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
           if (callInst->getCallingConv() != llvm::CallingConv::C) continue;
           StringRef FuncName = Callee->getName();
           isStrcmp &=
-              (!FuncName.compare("strcmp") || !FuncName.compare("xmlStrcmp") ||
+              (!FuncName.compare("strcmp") /*|| !FuncName.compare("xmlStrcmp") ||
                !FuncName.compare("xmlStrEqual") ||
-               !FuncName.compare("g_strcmp0") ||
                !FuncName.compare("curl_strequal") ||
-               !FuncName.compare("strcsequal"));
+               !FuncName.compare("strcsequal") ||
+               !FuncName.compare("g_strcmp0")*/);
           isMemcmp &=
               (!FuncName.compare("memcmp") || !FuncName.compare("bcmp") ||
                !FuncName.compare("CRYPTO_memcmp") ||
                !FuncName.compare("OPENSSL_memcmp") ||
                !FuncName.compare("memcmp_const_time") ||
                !FuncName.compare("memcmpct"));
-          isStrncmp &= (!FuncName.compare("strncmp") ||
-                        !FuncName.compare("xmlStrncmp") ||
-                        !FuncName.compare("curl_strnequal"));
+          isStrncmp &= (!FuncName.compare("strncmp")/* ||
+                        !FuncName.compare("curl_strnequal") ||
+                        !FuncName.compare("xmlStrncmp")*/);
           isStrcasecmp &= (!FuncName.compare("strcasecmp") ||
                            !FuncName.compare("stricmp") ||
                            !FuncName.compare("ap_cstr_casecmp") ||
                            !FuncName.compare("OPENSSL_strcasecmp") ||
-                           !FuncName.compare("xmlStrcasecmp") ||
+                           /*!FuncName.compare("xmlStrcasecmp") ||
                            !FuncName.compare("g_strcasecmp") ||
                            !FuncName.compare("g_ascii_strcasecmp") ||
                            !FuncName.compare("Curl_strcasecompare") ||
-                           !FuncName.compare("Curl_safe_strcasecompare") ||
+                           !FuncName.compare("Curl_safe_strcasecompare") ||*/
                            !FuncName.compare("cmsstrcasecmp"));
           isStrncasecmp &= (!FuncName.compare("strncasecmp") ||
                             !FuncName.compare("strnicmp") ||
                             !FuncName.compare("ap_cstr_casecmpn") ||
-                            !FuncName.compare("OPENSSL_strncasecmp") ||
+                            !FuncName.compare("OPENSSL_strncasecmp") /*||
                             !FuncName.compare("xmlStrncasecmp") ||
                             !FuncName.compare("g_ascii_strncasecmp") ||
                             !FuncName.compare("Curl_strncasecompare") ||
-                            !FuncName.compare("g_strncasecmp"));
+                            !FuncName.compare("g_strncasecmp")*/);
           isIntMemcpy &= !FuncName.compare("llvm.memcpy.p0i8.p0i8.i64");
 
           if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
@@ -270,28 +275,30 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
           isStrcmp &=
               FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
               FT->getParamType(0) == FT->getParamType(1) &&
-              FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
+              FT->getParamType(0) ==
+                  IntegerType::getInt8Ty(M.getContext())->getPointerTo(0);
           isStrcasecmp &=
               FT->getNumParams() == 2 && FT->getReturnType()->isIntegerTy(32) &&
               FT->getParamType(0) == FT->getParamType(1) &&
-              FT->getParamType(0) == IntegerType::getInt8PtrTy(M.getContext());
+              FT->getParamType(0) ==
+                  IntegerType::getInt8Ty(M.getContext())->getPointerTo(0);
           isMemcmp &= FT->getNumParams() == 3 &&
                       FT->getReturnType()->isIntegerTy(32) &&
                       FT->getParamType(0)->isPointerTy() &&
                       FT->getParamType(1)->isPointerTy() &&
                       FT->getParamType(2)->isIntegerTy();
-          isStrncmp &= FT->getNumParams() == 3 &&
-                       FT->getReturnType()->isIntegerTy(32) &&
-                       FT->getParamType(0) == FT->getParamType(1) &&
-                       FT->getParamType(0) ==
-                           IntegerType::getInt8PtrTy(M.getContext()) &&
-                       FT->getParamType(2)->isIntegerTy();
-          isStrncasecmp &= FT->getNumParams() == 3 &&
-                           FT->getReturnType()->isIntegerTy(32) &&
-                           FT->getParamType(0) == FT->getParamType(1) &&
-                           FT->getParamType(0) ==
-                               IntegerType::getInt8PtrTy(M.getContext()) &&
-                           FT->getParamType(2)->isIntegerTy();
+          isStrncmp &=
+              FT->getNumParams() == 3 && FT->getReturnType()->isIntegerTy(32) &&
+              FT->getParamType(0) == FT->getParamType(1) &&
+              FT->getParamType(0) ==
+                  IntegerType::getInt8Ty(M.getContext())->getPointerTo(0) &&
+              FT->getParamType(2)->isIntegerTy();
+          isStrncasecmp &=
+              FT->getNumParams() == 3 && FT->getReturnType()->isIntegerTy(32) &&
+              FT->getParamType(0) == FT->getParamType(1) &&
+              FT->getParamType(0) ==
+                  IntegerType::getInt8Ty(M.getContext())->getPointerTo(0) &&
+              FT->getParamType(2)->isIntegerTy();
 
           if (!isStrcmp && !isMemcmp && !isStrncmp && !isStrcasecmp &&
               !isStrncasecmp && !isIntMemcpy)
@@ -457,8 +464,21 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
     bool        isSizedcmp = false;
     bool        isCaseInsensitive = false;
     bool        needs_null = false;
+    bool        success_is_one = false;
+    bool        nullCheck = false;
     Function   *Callee = callInst->getCalledFunction();
 
+    /*
+    fprintf(stderr, "%s - %s - %s\n",
+            callInst->getParent()
+                ->getParent()
+                ->getParent()
+                ->getName()
+                .str()
+                .c_str(),
+            callInst->getParent()->getParent()->getName().str().c_str(),
+            Callee ? Callee->getName().str().c_str() : "NULL");*/
+
     if (Callee) {
 
       if (!Callee->getName().compare("memcmp") ||
@@ -503,9 +523,20 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
           !Callee->getName().compare("g_strncasecmp"))
         isCaseInsensitive = true;
 
+      if (!Callee->getName().compare("xmlStrEqual") ||
+          !Callee->getName().compare("curl_strequal") ||
+          !Callee->getName().compare("strcsequal") ||
+          !Callee->getName().compare("curl_strnequal"))
+        success_is_one = true;
+
     }
 
     if (!isSizedcmp) needs_null = true;
+    if (Callee->getName().startswith("g_") ||
+        Callee->getName().startswith("curl_") ||
+        Callee->getName().startswith("Curl_") ||
+        Callee->getName().startswith("xml"))
+      nullCheck = true;
 
     Value *sizedValue = isSizedcmp ? callInst->getArgOperand(2) : NULL;
     bool   isConstSized = sizedValue && isa<ConstantInt>(sizedValue);
@@ -590,8 +621,10 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
     /* split before the call instruction */
     BasicBlock *bb = callInst->getParent();
     BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(callInst));
-
     BasicBlock *next_lenchk_bb = NULL;
+
+    if (nullCheck) { fprintf(stderr, "TODO: null check\n"); }
+
     if (isSizedcmp && !isConstSized) {
 
       next_lenchk_bb =
@@ -623,7 +656,7 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
 
         IRBuilder<> cur_lenchk_IRB(&*(cur_lenchk_bb->getFirstInsertionPt()));
         Value      *icmp = cur_lenchk_IRB.CreateICmpEQ(
-                 sizedValue, ConstantInt::get(sizedValue->getType(), i));
+            sizedValue, ConstantInt::get(sizedValue->getType(), i));
         cur_lenchk_IRB.CreateCondBr(icmp, end_bb, cur_cmp_bb);
         cur_lenchk_bb->getTerminator()->eraseFromParent();
 
@@ -667,6 +700,14 @@ bool CompareTransform::transformCmps(Module &M, const bool processStrcmp,
       else
         isub = cur_cmp_IRB.CreateSub(load, ConstantInt::get(Int8Ty, c));
 
+      if (success_is_one && i == unrollLen - 1) {
+
+        Value *isubsub = cur_cmp_IRB.CreateTrunc(isub, Int1Ty);
+        isub = cur_cmp_IRB.CreateSelect(isubsub, ConstantInt::get(Int8Ty, 0),
+                                        ConstantInt::get(Int8Ty, 1));
+
+      }
+
       Value *sext = cur_cmp_IRB.CreateSExt(isub, Int32Ty);
       PN->addIncoming(sext, cur_cmp_bb);
 
@@ -728,6 +769,8 @@ bool CompareTransform::runOnModule(Module &M) {
 
 #endif
 
+  bool ret = false;
+
   if ((isatty(2) && getenv("AFL_QUIET") == NULL) || getenv("AFL_DEBUG") != NULL)
     printf(
         "Running compare-transform-pass by laf.intel@gmail.com, extended by "
@@ -735,11 +778,7 @@ bool CompareTransform::runOnModule(Module &M) {
   else
     be_quiet = 1;
 
-#if LLVM_MAJOR >= 11                                /* use new pass manager */
-  auto PA = PreservedAnalyses::all();
-#endif
-
-  transformCmps(M, true, true, true, true, true);
+  if (transformCmps(M, true, true, true, true, true) == true) ret = true;
   verifyModule(M);
 
 #if LLVM_MAJOR >= 11                                /* use new pass manager */
@@ -749,9 +788,18 @@ bool CompareTransform::runOnModule(Module &M) {
                    
                        }*/
 
-  return PA;
+  if (ret == true) {
+
+    return PreservedAnalyses();
+
+  } else {
+
+    return PreservedAnalyses::all();
+
+  }
+
 #else
-  return true;
+  return ret;
 #endif
 
 }