about summary refs log tree commit diff
path: root/instrumentation/split-compares-pass.so.cc
diff options
context:
space:
mode:
Diffstat (limited to 'instrumentation/split-compares-pass.so.cc')
-rw-r--r--instrumentation/split-compares-pass.so.cc140
1 files changed, 88 insertions, 52 deletions
diff --git a/instrumentation/split-compares-pass.so.cc b/instrumentation/split-compares-pass.so.cc
index 8a07610c..effafe50 100644
--- a/instrumentation/split-compares-pass.so.cc
+++ b/instrumentation/split-compares-pass.so.cc
@@ -1,7 +1,7 @@
 /*
  * Copyright 2016 laf-intel
- * extended for floating point by Heiko Eißfeldt
- * adapted to new pass manager by Heiko Eißfeldt
+ * extended for floating point by Heiko Eissfeldt
+ * adapted to new pass manager by Heiko Eissfeldt
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -60,7 +60,7 @@ using namespace llvm;
 
 // uncomment this toggle function verification at each step. horribly slow, but
 // helps to pinpoint a potential problem in the splitting code.
-//#define VERIFY_TOO_MUCH 1
+// #define VERIFY_TOO_MUCH 1
 
 namespace {
 
@@ -189,7 +189,11 @@ llvmGetPassPluginInfo() {
     #if LLVM_VERSION_MAJOR <= 13
             using OptimizationLevel = typename PassBuilder::OptimizationLevel;
     #endif
+    #if LLVM_VERSION_MAJOR >= 16
+            PB.registerOptimizerEarlyEPCallback(
+    #else
             PB.registerOptimizerLastEPCallback(
+    #endif
                 [](ModulePassManager &MPM, OptimizationLevel OL) {
 
                   MPM.addPass(SplitComparesTransform());
@@ -262,8 +266,11 @@ bool SplitComparesTransform::simplifyFPCompares(Module &M) {
 
             /* this is probably not needed but we do it anyway */
             if (TyOp0 != TyOp1) { continue; }
-
             if (TyOp0->isArrayTy() || TyOp0->isVectorTy()) { continue; }
+            int constants = 0;
+            if (llvm::isa<llvm::Constant>(op0)) { ++constants; }
+            if (llvm::isa<llvm::Constant>(op1)) { ++constants; }
+            if (constants != 1) { continue; }
 
             fcomps.push_back(selectcmpInst);
 
@@ -463,8 +470,12 @@ bool SplitComparesTransform::simplifyOrEqualsCompare(CmpInst     *IcmpInst,
 #else
   ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
 #endif
+  if (new_pred == CmpInst::ICMP_SGT || new_pred == CmpInst::ICMP_SLT) {
+
+    simplifySignedCompare(icmp_np, M, worklist);
+
+  }
 
-  worklist.push_back(icmp_np);
   worklist.push_back(icmp_eq);
 
   return true;
@@ -740,17 +751,24 @@ bool SplitComparesTransform::splitCompare(CmpInst *cmp_inst, Module &M,
       CmpInst     *icmp_inv_cmp = nullptr;
       BasicBlock  *inv_cmp_bb =
           BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb);
-      if (pred == CmpInst::ICMP_UGT || pred == CmpInst::ICMP_SGT ||
-          pred == CmpInst::ICMP_UGE || pred == CmpInst::ICMP_SGE) {
+      if (pred == CmpInst::ICMP_UGT) {
 
         icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT,
                                        op0_high, op1_high);
 
-      } else {
+      } else if (pred == CmpInst::ICMP_ULT) {
 
         icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT,
                                        op0_high, op1_high);
 
+      } else {
+
+        // Never gonna appen
+        if (!be_quiet)
+          fprintf(stderr,
+                  "Error: split-compare: Equals or signed not removed: %d\n",
+                  pred);
+
       }
 
 #if LLVM_MAJOR >= 16
@@ -924,7 +942,7 @@ size_t SplitComparesTransform::nextPowerOfTwo(size_t in) {
 /* splits fcmps into two nested fcmps with sign compare and the rest */
 size_t SplitComparesTransform::splitFPCompares(Module &M) {
 
-  size_t count = 0;
+  size_t counts = 0;
 
   LLVMContext &C = M.getContext();
 
@@ -940,7 +958,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
 
   } else {
 
-    return count;
+    return counts;
 
   }
 
@@ -993,7 +1011,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
 
   }
 
-  if (!fcomps.size()) { return count; }
+  if (!fcomps.size()) { return counts; }
 
   IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
 
@@ -1573,7 +1591,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
               CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1);
 #if LLVM_MAJOR >= 16
           icmp_fraction_result->insertInto(negative_bb, negative_bb->end());
-          icmp_fraction_result2->insertInto(positive_bb, negative_bb->end());
+          icmp_fraction_result2->insertInto(positive_bb, positive_bb->end());
 #else
           negative_bb->getInstList().push_back(icmp_fraction_result);
           positive_bb->getInstList().push_back(icmp_fraction_result2);
@@ -1587,7 +1605,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
               CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1);
 #if LLVM_MAJOR >= 16
           icmp_fraction_result->insertInto(negative_bb, negative_bb->end());
-          icmp_fraction_result2->insertInto(positive_bb, negative_bb->end());
+          icmp_fraction_result2->insertInto(positive_bb, positive_bb->end());
 #else
           negative_bb->getInstList().push_back(icmp_fraction_result);
           positive_bb->getInstList().push_back(icmp_fraction_result2);
@@ -1679,11 +1697,11 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
 #else
     ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
 #endif
-    ++count;
+    ++counts;
 
   }
 
-  return count;
+  return counts;
 
 }
 
@@ -1696,12 +1714,6 @@ bool SplitComparesTransform::runOnModule(Module &M) {
 
 #endif
 
-  char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW");
-  if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW");
-  if (bitw_env) { target_bitwidth = atoi(bitw_env); }
-
-  enableFPSplit = getenv("AFL_LLVM_LAF_SPLIT_FLOATS") != NULL;
-
   if ((isatty(2) && getenv("AFL_QUIET") == NULL) ||
       getenv("AFL_DEBUG") != NULL) {
 
@@ -1717,10 +1729,27 @@ bool SplitComparesTransform::runOnModule(Module &M) {
 
   }
 
-#if LLVM_MAJOR >= 11
-  auto PA = PreservedAnalyses::all();
+  char *bitw_env = getenv("AFL_LLVM_LAF_SPLIT_COMPARES_BITW");
+  if (!bitw_env) bitw_env = getenv("LAF_SPLIT_COMPARES_BITW");
+  if (bitw_env) { target_bitwidth = atoi(bitw_env); }
+
+  if (getenv("AFL_LLVM_LAF_SPLIT_FLOATS")) { enableFPSplit = true; }
+
+  bool split_comp = false;
+
+  if (getenv("AFL_LLVM_LAF_SPLIT_COMPARES")) {
+
+#if LLVM_MAJOR == 17
+    if (!be_quiet)
+      fprintf(stderr,
+              "WARNING: AFL++ splitting integer comparisons is disabled in "
+              "LLVM 17 due bugs, switch to 16 or 18!\n");
+#else
+    split_comp = true;
 #endif
 
+  }
+
   if (enableFPSplit) {
 
     simplifyFPCompares(M);
@@ -1729,42 +1758,44 @@ bool SplitComparesTransform::runOnModule(Module &M) {
     if (!be_quiet && !debug) {
 
       errs() << "Split-floatingpoint-compare-pass: " << count
-             << " FP comparisons splitted\n";
+             << " FP comparisons split\n";
 
     }
 
   }
 
-  std::vector<CmpInst *> worklist;
-  /* iterate over all functions, bbs and instruction search for all integer
-   * compare instructions. Save them into the worklist for later. */
-  for (auto &F : M) {
+  if (split_comp) {
 
-    if (!isInInstrumentList(&F, MNAME)) continue;
+    std::vector<CmpInst *> worklist;
+    /* iterate over all functions, bbs and instruction search for all integer
+     * compare instructions. Save them into the worklist for later. */
+    for (auto &F : M) {
 
-    for (auto &BB : F) {
+      if (!isInInstrumentList(&F, MNAME)) continue;
 
-      for (auto &IN : BB) {
+      for (auto &BB : F) {
 
-        if (auto CI = dyn_cast<CmpInst>(&IN)) {
+        for (auto &IN : BB) {
 
-          auto op0 = CI->getOperand(0);
-          auto op1 = CI->getOperand(1);
-          if (!op0 || !op1) {
+          if (auto CI = dyn_cast<CmpInst>(&IN)) {
 
-#if LLVM_MAJOR >= 11
-            return PA;
-#else
-            return false;
-#endif
+            auto op0 = CI->getOperand(0);
+            auto op1 = CI->getOperand(1);
+            // has to valid operands
+            if (!op0 || !op1) { continue; }
+            // has exactly one constant and one variable
+            int constants = 0;
+            if (dyn_cast<ConstantInt>(op0)) { ++constants; }
+            if (dyn_cast<ConstantInt>(op1)) { ++constants; }
+            if (constants != 1) { continue; }
 
-          }
+            auto iTy1 = dyn_cast<IntegerType>(op0->getType());
+            if (iTy1 && isa<IntegerType>(op1->getType())) {
 
-          auto iTy1 = dyn_cast<IntegerType>(op0->getType());
-          if (iTy1 && isa<IntegerType>(op1->getType())) {
+              unsigned bitw = iTy1->getBitWidth();
+              if (isSupportedBitWidth(bitw)) { worklist.push_back(CI); }
 
-            unsigned bitw = iTy1->getBitWidth();
-            if (isSupportedBitWidth(bitw)) { worklist.push_back(CI); }
+            }
 
           }
 
@@ -1774,16 +1805,18 @@ bool SplitComparesTransform::runOnModule(Module &M) {
 
     }
 
-  }
+    // now that we have a list of all integer comparisons we can start replacing
+    // them with the splitted alternatives.
+    for (auto CI : worklist) {
 
-  // now that we have a list of all integer comparisons we can start replacing
-  // them with the splitted alternatives.
-  for (auto CI : worklist) {
+      simplifyAndSplit(CI, M);
 
-    simplifyAndSplit(CI, M);
+    }
 
   }
 
+  bool ret = count == 0 ? false : true;
+
   bool brokenDebug = false;
   if (verifyModule(M, &errs()
 #if LLVM_VERSION_MAJOR >= 4 || \
@@ -1822,9 +1855,12 @@ bool SplitComparesTransform::runOnModule(Module &M) {
 
     }*/
 
-  return PA;
+  if (ret == false)
+    return PreservedAnalyses::all();
+  else
+    return PreservedAnalyses();
 #else
-  return true;
+  return ret;
 #endif
 
 }