about summary refs log tree commit diff
path: root/llvm_mode/split-compares-pass.so.cc
diff options
context:
space:
mode:
Diffstat (limited to 'llvm_mode/split-compares-pass.so.cc')
-rw-r--r--llvm_mode/split-compares-pass.so.cc278
1 files changed, 181 insertions, 97 deletions
diff --git a/llvm_mode/split-compares-pass.so.cc b/llvm_mode/split-compares-pass.so.cc
index e97c2e7b..0595c682 100644
--- a/llvm_mode/split-compares-pass.so.cc
+++ b/llvm_mode/split-compares-pass.so.cc
@@ -52,8 +52,8 @@ class SplitComparesTransform : public ModulePass {
  private:
   size_t splitIntCompares(Module &M, unsigned bitw);
   size_t splitFPCompares(Module &M);
-  bool simplifyCompares(Module &M);
-  bool simplifyIntSignedness(Module &M);
+  bool   simplifyCompares(Module &M);
+  bool   simplifyIntSignedness(Module &M);
   size_t nextPowerOfTwo(size_t in);
 
 };
@@ -294,7 +294,11 @@ bool SplitComparesTransform::simplifyIntSignedness(Module &M) {
             if (!intTyOp0 || !intTyOp1) { continue; }
 
             /* i think this is not possible but to lazy to look it up */
-            if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { continue; }
+            if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) {
+
+              continue;
+
+            }
 
             icomps.push_back(selectcmpInst);
 
@@ -412,30 +416,36 @@ bool SplitComparesTransform::simplifyIntSignedness(Module &M) {
 }
 
 size_t SplitComparesTransform::nextPowerOfTwo(size_t in) {
+
   --in;
   in |= in >> 1;
   in |= in >> 2;
   in |= in >> 4;
-//  in |= in >> 8;
-//  in |= in >> 16;
+  //  in |= in >> 8;
+  //  in |= in >> 16;
   return in + 1;
+
 }
 
 /* splits fcmps into two nested fcmps with sign compare and the rest */
 size_t SplitComparesTransform::splitFPCompares(Module &M) {
+
   size_t count = 0;
 
   LLVMContext &C = M.getContext();
 
   const DataLayout &dl = M.getDataLayout();
 
-  /* define unions with floating point and (sign, exponent, mantissa)  triples */
+  /* define unions with floating point and (sign, exponent, mantissa)  triples
+   */
   if (dl.isLittleEndian()) {
-  }
-  else if (dl.isBigEndian()) {
-  }
-  else {
+
+  } else if (dl.isBigEndian()) {
+
+  } else {
+
     return count;
+
   }
 
   std::vector<CmpInst *> fcomps;
@@ -477,6 +487,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
     }
 
   }
+
   if (!fcomps.size()) { return count; }
 
   IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
@@ -492,37 +503,42 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
     op0_size = op0->getType()->getPrimitiveSizeInBits();
     op1_size = op1->getType()->getPrimitiveSizeInBits();
 
-    if (op0_size != op1_size) {
-      continue;
-    }
+    if (op0_size != op1_size) { continue; }
 
     const unsigned int sizeInBits = op0->getType()->getPrimitiveSizeInBits();
-    const unsigned int precision = sizeInBits == 32 ? 24 :
-                                   sizeInBits == 64 ? 53 :
-                                   sizeInBits == 128 ? 113 :
-                                   sizeInBits == 16 ? 11 :
-                                   65;
-
-    const unsigned shiftR_exponent = precision - 1;
-    const unsigned long long mask_fraction = ((1 << (precision - 2))) | ((1 << (precision - 2)) - 1);
-    const unsigned long long mask_exponent = (1 << (sizeInBits - precision)) - 1;
+    const unsigned int precision =
+        sizeInBits == 32
+            ? 24
+            : sizeInBits == 64
+                  ? 53
+                  : sizeInBits == 128 ? 113 : sizeInBits == 16 ? 11 : 65;
+
+    const unsigned           shiftR_exponent = precision - 1;
+    const unsigned long long mask_fraction =
+        ((1 << (precision - 2))) | ((1 << (precision - 2)) - 1);
+    const unsigned long long mask_exponent =
+        (1 << (sizeInBits - precision)) - 1;
 
     // round up sizes to the next power of two
     // this should help with integer compare splitting
     size_t exTySizeBytes = ((sizeInBits - precision + 7) >> 3);
-    size_t frTySizeBytes = ((precision - 1          + 7) >> 3);
+    size_t frTySizeBytes = ((precision - 1 + 7) >> 3);
 
-    IntegerType *IntExponentTy = IntegerType::get(C, nextPowerOfTwo(exTySizeBytes) << 3);
-    IntegerType *IntFractionTy = IntegerType::get(C, nextPowerOfTwo(frTySizeBytes) << 3);
+    IntegerType *IntExponentTy =
+        IntegerType::get(C, nextPowerOfTwo(exTySizeBytes) << 3);
+    IntegerType *IntFractionTy =
+        IntegerType::get(C, nextPowerOfTwo(frTySizeBytes) << 3);
 
     BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(FcmpInst));
 
     /* create the integers from floats directly */
     Instruction *b_op0, *b_op1;
-    b_op0 = CastInst::Create(Instruction::BitCast, op0, IntegerType::get(C, op0_size));
+    b_op0 = CastInst::Create(Instruction::BitCast, op0,
+                             IntegerType::get(C, op0_size));
     bb->getInstList().insert(bb->getTerminator()->getIterator(), b_op0);
 
-    b_op1 = CastInst::Create(Instruction::BitCast, op1, IntegerType::get(C, op1_size));
+    b_op1 = CastInst::Create(Instruction::BitCast, op1,
+                             IntegerType::get(C, op1_size));
     bb->getInstList().insert(bb->getTerminator()->getIterator(), b_op1);
 
     /* isolate signs of value of floating point type */
@@ -531,31 +547,34 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
      * the original operands so only the first bit remains.*/
     Instruction *s_s0, *t_s0, *s_s1, *t_s1, *icmp_sign_bit;
 
-    s_s0 = BinaryOperator::Create(Instruction::LShr, b_op0,
-                                   ConstantInt::get(b_op0->getType(), op0_size - 1));
+    s_s0 = BinaryOperator::Create(
+        Instruction::LShr, b_op0,
+        ConstantInt::get(b_op0->getType(), op0_size - 1));
     bb->getInstList().insert(bb->getTerminator()->getIterator(), s_s0);
     t_s0 = new TruncInst(s_s0, Int1Ty);
     bb->getInstList().insert(bb->getTerminator()->getIterator(), t_s0);
 
-    s_s1 = BinaryOperator::Create(Instruction::LShr, b_op1,
-                                   ConstantInt::get(b_op1->getType(), op1_size - 1));
+    s_s1 = BinaryOperator::Create(
+        Instruction::LShr, b_op1,
+        ConstantInt::get(b_op1->getType(), op1_size - 1));
     bb->getInstList().insert(bb->getTerminator()->getIterator(), s_s1);
     t_s1 = new TruncInst(s_s1, Int1Ty);
     bb->getInstList().insert(bb->getTerminator()->getIterator(), t_s1);
 
     /* compare of the sign bits */
-    icmp_sign_bit = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_s0, t_s1);
+    icmp_sign_bit =
+        CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_s0, t_s1);
     bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_sign_bit);
 
     /* create a new basic block which is executed if the signedness bits are
      * equal */
-    BasicBlock * signequal_bb =
+    BasicBlock *signequal_bb =
         BasicBlock::Create(C, "signequal", end_bb->getParent(), end_bb);
 
     BranchInst::Create(end_bb, signequal_bb);
 
     /* create a new bb which is executed if exponents are equal */
-    BasicBlock * middle_bb =
+    BasicBlock *middle_bb =
         BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
 
     BranchInst::Create(end_bb, middle_bb);
@@ -570,128 +589,187 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
     /* isolate the exponents */
     Instruction *s_e0, *m_e0, *t_e0, *s_e1, *m_e1, *t_e1;
 
-    s_e0 = BinaryOperator::Create(Instruction::LShr, b_op0, ConstantInt::get(b_op0->getType(), shiftR_exponent));
-    s_e1 = BinaryOperator::Create(Instruction::LShr, b_op1, ConstantInt::get(b_op1->getType(), shiftR_exponent));
-    signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), s_e0);
-    signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), s_e1);
+    s_e0 = BinaryOperator::Create(
+        Instruction::LShr, b_op0,
+        ConstantInt::get(b_op0->getType(), shiftR_exponent));
+    s_e1 = BinaryOperator::Create(
+        Instruction::LShr, b_op1,
+        ConstantInt::get(b_op1->getType(), shiftR_exponent));
+    signequal_bb->getInstList().insert(
+        signequal_bb->getTerminator()->getIterator(), s_e0);
+    signequal_bb->getInstList().insert(
+        signequal_bb->getTerminator()->getIterator(), s_e1);
 
     t_e0 = new TruncInst(s_e0, IntExponentTy);
     t_e1 = new TruncInst(s_e1, IntExponentTy);
-    signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), t_e0);
-    signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), t_e1);
+    signequal_bb->getInstList().insert(
+        signequal_bb->getTerminator()->getIterator(), t_e0);
+    signequal_bb->getInstList().insert(
+        signequal_bb->getTerminator()->getIterator(), t_e1);
 
     if (sizeInBits - precision < exTySizeBytes * 8) {
-      m_e0 = BinaryOperator::Create(Instruction::And, t_e0, ConstantInt::get(t_e0->getType(), mask_exponent));
-      m_e1 = BinaryOperator::Create(Instruction::And, t_e1, ConstantInt::get(t_e1->getType(), mask_exponent));
-      signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), m_e0);
-      signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), m_e1);
+
+      m_e0 = BinaryOperator::Create(
+          Instruction::And, t_e0,
+          ConstantInt::get(t_e0->getType(), mask_exponent));
+      m_e1 = BinaryOperator::Create(
+          Instruction::And, t_e1,
+          ConstantInt::get(t_e1->getType(), mask_exponent));
+      signequal_bb->getInstList().insert(
+          signequal_bb->getTerminator()->getIterator(), m_e0);
+      signequal_bb->getInstList().insert(
+          signequal_bb->getTerminator()->getIterator(), m_e1);
+
     } else {
+
       m_e0 = t_e0;
       m_e1 = t_e1;
+
     }
+
     /* compare the exponents of the operands */
     Instruction *icmp_exponent_result;
     switch (FcmpInst->getPredicate()) {
+
       case CmpInst::FCMP_OEQ:
         icmp_exponent_result =
             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, m_e0, m_e1);
-      break;
+        break;
       case CmpInst::FCMP_ONE:
       case CmpInst::FCMP_UNE:
         icmp_exponent_result =
             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, m_e0, m_e1);
-      break;
+        break;
       case CmpInst::FCMP_OGT:
         Instruction *icmp_exponent;
         icmp_exponent =
             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, m_e0, m_e1);
-        signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent);
-        icmp_exponent_result = BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
-      break;
+        signequal_bb->getInstList().insert(
+            signequal_bb->getTerminator()->getIterator(), icmp_exponent);
+        icmp_exponent_result =
+            BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
+        break;
       case CmpInst::FCMP_OLT:
         icmp_exponent =
             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, m_e0, m_e1);
-        signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent);
-        icmp_exponent_result = BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
-      break;
-      default:
-        continue;
+        signequal_bb->getInstList().insert(
+            signequal_bb->getTerminator()->getIterator(), icmp_exponent);
+        icmp_exponent_result =
+            BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
+        break;
+      default: continue;
+
     }
-    signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent_result);
+
+    signequal_bb->getInstList().insert(
+        signequal_bb->getTerminator()->getIterator(), icmp_exponent_result);
 
     {
-    auto term = signequal_bb->getTerminator();
-    /* if the exponents are different do a fraction cmp */
-    BranchInst::Create(middle_bb, end_bb, icmp_exponent_result, signequal_bb);
-    term->eraseFromParent();
-    }
 
+      auto term = signequal_bb->getTerminator();
+      /* if the exponents are different do a fraction cmp */
+      BranchInst::Create(middle_bb, end_bb, icmp_exponent_result, signequal_bb);
+      term->eraseFromParent();
+
+    }
 
     /* isolate the mantissa aka fraction */
     Instruction *t_f0, *t_f1;
-    bool needTrunc = IntFractionTy->getPrimitiveSizeInBits() < op0_size;
-//errs() << "Fractions: IntFractionTy size " << IntFractionTy->getPrimitiveSizeInBits() << ", op0_size " << op0_size << ", needTrunc " << needTrunc << "\n";
+    bool         needTrunc = IntFractionTy->getPrimitiveSizeInBits() < op0_size;
+    // errs() << "Fractions: IntFractionTy size " <<
+    // IntFractionTy->getPrimitiveSizeInBits() << ", op0_size " << op0_size << ",
+    // needTrunc " << needTrunc << "\n";
     if (precision - 1 < frTySizeBytes * 8) {
+
       Instruction *m_f0, *m_f1;
-      m_f0 = BinaryOperator::Create(Instruction::And, b_op0, ConstantInt::get(b_op0->getType(), mask_fraction));
-      m_f1 = BinaryOperator::Create(Instruction::And, b_op1, ConstantInt::get(b_op1->getType(), mask_fraction));
-      middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), m_f0);
-      middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), m_f1);
+      m_f0 = BinaryOperator::Create(
+          Instruction::And, b_op0,
+          ConstantInt::get(b_op0->getType(), mask_fraction));
+      m_f1 = BinaryOperator::Create(
+          Instruction::And, b_op1,
+          ConstantInt::get(b_op1->getType(), mask_fraction));
+      middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(),
+                                      m_f0);
+      middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(),
+                                      m_f1);
 
       if (needTrunc) {
+
         t_f0 = new TruncInst(m_f0, IntFractionTy);
         t_f1 = new TruncInst(m_f1, IntFractionTy);
-        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f0);
-        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f1);
+        middle_bb->getInstList().insert(
+            middle_bb->getTerminator()->getIterator(), t_f0);
+        middle_bb->getInstList().insert(
+            middle_bb->getTerminator()->getIterator(), t_f1);
+
       } else {
+
         t_f0 = m_f0;
         t_f1 = m_f1;
+
       }
+
     } else {
+
       if (needTrunc) {
+
         t_f0 = new TruncInst(b_op0, IntFractionTy);
         t_f1 = new TruncInst(b_op1, IntFractionTy);
-        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f0);
-        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f1);
+        middle_bb->getInstList().insert(
+            middle_bb->getTerminator()->getIterator(), t_f0);
+        middle_bb->getInstList().insert(
+            middle_bb->getTerminator()->getIterator(), t_f1);
+
       } else {
+
         t_f0 = b_op0;
         t_f1 = b_op1;
+
       }
+
     }
 
     /* compare the fractions of the operands */
     Instruction *icmp_fraction_result;
     switch (FcmpInst->getPredicate()) {
+
       case CmpInst::FCMP_OEQ:
         icmp_fraction_result =
             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_f0, t_f1);
-      break;
+        break;
       case CmpInst::FCMP_UNE:
       case CmpInst::FCMP_ONE:
         icmp_fraction_result =
             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, t_f0, t_f1);
-      break;
+        break;
       case CmpInst::FCMP_OGT:
         Instruction *icmp_fraction;
         icmp_fraction =
             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1);
-        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction);
-        icmp_fraction_result = BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0);
-      break;
+        middle_bb->getInstList().insert(
+            middle_bb->getTerminator()->getIterator(), icmp_fraction);
+        icmp_fraction_result =
+            BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0);
+        break;
       case CmpInst::FCMP_OLT:
         icmp_fraction =
             CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1);
-        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction);
-        icmp_fraction_result = BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0);
-      break;
-      default:
-        continue;
+        middle_bb->getInstList().insert(
+            middle_bb->getTerminator()->getIterator(), icmp_fraction);
+        icmp_fraction_result =
+            BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0);
+        break;
+      default: continue;
+
     }
-    middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction_result);
+
+    middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(),
+                                    icmp_fraction_result);
 
     PHINode *PN = PHINode::Create(Int1Ty, 3, "");
 
     switch (FcmpInst->getPredicate()) {
+
       case CmpInst::FCMP_OEQ:
         /* unequal signs cannot be equal values */
         /* goto false branch */
@@ -700,7 +778,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
         PN->addIncoming(ConstantInt::get(Int1Ty, 0), signequal_bb);
         /* fractions comparison */
         PN->addIncoming(icmp_fraction_result, middle_bb);
-      break;
+        break;
       case CmpInst::FCMP_ONE:
       case CmpInst::FCMP_UNE:
         /* unequal signs are unequal values */
@@ -710,28 +788,29 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
         PN->addIncoming(ConstantInt::get(Int1Ty, 1), signequal_bb);
         /* fractions comparison */
         PN->addIncoming(icmp_fraction_result, middle_bb);
-      break;
+        break;
       case CmpInst::FCMP_OGT:
-        /* if op1 is negative goto true branch, 
+        /* if op1 is negative goto true branch,
            else go on comparing */
         PN->addIncoming(t_s1, bb);
         PN->addIncoming(icmp_exponent_result, signequal_bb);
         PN->addIncoming(icmp_fraction_result, middle_bb);
-      break;
+        break;
       case CmpInst::FCMP_OLT:
         /* if op0 is negative goto true branch,
            else go on comparing */
         PN->addIncoming(t_s0, bb);
         PN->addIncoming(icmp_exponent_result, signequal_bb);
         PN->addIncoming(icmp_fraction_result, middle_bb);
-      break;
-      default:
-        continue;
+        break;
+      default: continue;
+
     }
 
     BasicBlock::iterator ii(FcmpInst);
     ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
     ++count;
+
   }
 
   return count;
@@ -740,6 +819,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) {
 
 /* splits icmps of size bitw into two nested icmps with bitw/2 size each */
 size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) {
+
   size_t count = 0;
 
   LLVMContext &C = M.getContext();
@@ -755,7 +835,7 @@ size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) {
   /* not supported yet */
   if (bitw > 64) { return 0; }
 
-  /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the 
+  /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the
    * functions simplifyCompares() and simplifyIntSignedness()
    * were executed only these four predicates should exist */
   for (auto &F : M) {
@@ -938,7 +1018,9 @@ size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) {
       ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
 
     }
+
     ++count;
+
   }
 
   return count;
@@ -958,27 +1040,29 @@ bool SplitComparesTransform::runOnModule(Module &M) {
   simplifyIntSignedness(M);
 
   if (getenv("AFL_QUIET") == NULL)
-    errs() << "Split-compare-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n";
+    errs() << "Split-compare-pass by laf.intel@gmail.com, extended by "
+              "heiko@hexco.de\n";
 
-  errs() << "Split-floatingpoint-compare-pass: " << splitFPCompares(M) << " FP comparisons splitted\n";
+  errs() << "Split-floatingpoint-compare-pass: " << splitFPCompares(M)
+         << " FP comparisons splitted\n";
 
   switch (bitw) {
 
     case 64:
-      errs() << "Split-integer-compare-pass " << bitw << "bit: " 
-        << splitIntCompares(M, bitw) << " splitted\n";
+      errs() << "Split-integer-compare-pass " << bitw
+             << "bit: " << splitIntCompares(M, bitw) << " splitted\n";
 
       bitw >>= 1;
       [[clang::fallthrough]]; /*FALLTHRU*/                   /* FALLTHROUGH */
     case 32:
-      errs() << "Split-integer-compare-pass " << bitw << "bit: " 
-        << splitIntCompares(M, bitw) << " splitted\n";
+      errs() << "Split-integer-compare-pass " << bitw
+             << "bit: " << splitIntCompares(M, bitw) << " splitted\n";
 
       bitw >>= 1;
       [[clang::fallthrough]]; /*FALLTHRU*/                   /* FALLTHROUGH */
     case 16:
-      errs() << "Split-integer-compare-pass " << bitw << "bit: " 
-        << splitIntCompares(M, bitw) << " splitted\n";
+      errs() << "Split-integer-compare-pass " << bitw
+             << "bit: " << splitIntCompares(M, bitw) << " splitted\n";
 
       bitw >>= 1;
       break;