about summary refs log tree commit diff
path: root/llvm_mode/split-compares-pass.so.cc
diff options
context:
space:
mode:
Diffstat (limited to 'llvm_mode/split-compares-pass.so.cc')
-rw-r--r--llvm_mode/split-compares-pass.so.cc544
1 files changed, 479 insertions, 65 deletions
diff --git a/llvm_mode/split-compares-pass.so.cc b/llvm_mode/split-compares-pass.so.cc
index 1e9d6542..c5da42c0 100644
--- a/llvm_mode/split-compares-pass.so.cc
+++ b/llvm_mode/split-compares-pass.so.cc
@@ -1,5 +1,6 @@
 /*
  * Copyright 2016 laf-intel
+ * extended for floating point by Heiko Eißfeldt
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,6 +22,7 @@
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/IR/Module.h"
+#include "llvm/ADT/APFloat.h"
 
 #include "llvm/IR/IRBuilder.h"
 
@@ -49,9 +51,11 @@ class SplitComparesTransform : public ModulePass {
   }
 
  private:
-  bool splitCompares(Module &M, unsigned bitw);
+  size_t splitIntCompares(Module &M, unsigned bitw);
+  size_t splitFPCompares(Module &M);
   bool simplifyCompares(Module &M);
-  bool simplifySignedness(Module &M);
+  bool simplifyIntSignedness(Module &M);
+  size_t nextPowerOfTwo(size_t in);
 
 };
 
@@ -65,6 +69,7 @@ bool SplitComparesTransform::simplifyCompares(Module &M) {
 
   LLVMContext &              C = M.getContext();
   std::vector<Instruction *> icomps;
+  std::vector<Instruction *> fcomps;
   IntegerType *              Int1Ty = IntegerType::getInt1Ty(C);
 
   /* iterate over all functions, bbs and instruction and add
@@ -79,25 +84,41 @@ bool SplitComparesTransform::simplifyCompares(Module &M) {
 
         if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
 
-          if (selectcmpInst->getPredicate() != CmpInst::ICMP_UGE &&
-              selectcmpInst->getPredicate() != CmpInst::ICMP_SGE &&
-              selectcmpInst->getPredicate() != CmpInst::ICMP_ULE &&
-              selectcmpInst->getPredicate() != CmpInst::ICMP_SLE) {
+          if (selectcmpInst->getPredicate() == CmpInst::ICMP_UGE ||
+              selectcmpInst->getPredicate() == CmpInst::ICMP_SGE ||
+              selectcmpInst->getPredicate() == CmpInst::ICMP_ULE ||
+              selectcmpInst->getPredicate() == CmpInst::ICMP_SLE) {
 
-            continue;
+            auto op0 = selectcmpInst->getOperand(0);
+            auto op1 = selectcmpInst->getOperand(1);
+
+            IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+            IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+
+            /* this is probably not needed but we do it anyway */
+            if (!intTyOp0 || !intTyOp1) { continue; }
+
+            icomps.push_back(selectcmpInst);
 
           }
 
-          auto op0 = selectcmpInst->getOperand(0);
-          auto op1 = selectcmpInst->getOperand(1);
+          if (selectcmpInst->getPredicate() == CmpInst::FCMP_OGE ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_UGE ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_OLE ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_ULE) {
 
-          IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
-          IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+            auto op0 = selectcmpInst->getOperand(0);
+            auto op1 = selectcmpInst->getOperand(1);
 
-          /* this is probably not needed but we do it anyway */
-          if (!intTyOp0 || !intTyOp1) { continue; }
+            Type *TyOp0 = op0->getType();
+            Type *TyOp1 = op1->getType();
 
-          icomps.push_back(selectcmpInst);
+            /* this is probably not needed but we do it anyway */
+            if (TyOp0 != TyOp1) { continue; }
+
+            fcomps.push_back(selectcmpInst);
+
+          }
 
         }
 
@@ -107,7 +128,7 @@ bool SplitComparesTransform::simplifyCompares(Module &M) {
 
   }
 
-  if (!icomps.size()) { return false; }
+  if (!icomps.size() && !fcomps.size()) { return false; }
 
   for (auto &IcmpInst : icomps) {
 
@@ -173,18 +194,83 @@ bool SplitComparesTransform::simplifyCompares(Module &M) {
 
   }
 
+  /* now for floating point */
+  for (auto &FcmpInst : fcomps) {
+
+    BasicBlock *bb = FcmpInst->getParent();
+
+    auto op0 = FcmpInst->getOperand(0);
+    auto op1 = FcmpInst->getOperand(1);
+
+    /* find out what the new predicate is going to be */
+    auto               pred = dyn_cast<CmpInst>(FcmpInst)->getPredicate();
+    CmpInst::Predicate new_pred;
+    switch (pred) {
+
+      case CmpInst::FCMP_UGE: new_pred = CmpInst::FCMP_UGT; break;
+      case CmpInst::FCMP_OGE: new_pred = CmpInst::FCMP_OGT; break;
+      case CmpInst::FCMP_ULE: new_pred = CmpInst::FCMP_ULT; break;
+      case CmpInst::FCMP_OLE: new_pred = CmpInst::FCMP_OLT; break;
+      default:  // keep the compiler happy
+        continue;
+
+    }
+
+    /* split before the icmp instruction */
+    BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(FcmpInst));
+
+    /* the old bb now contains a unconditional jump to the new one (end_bb)
+     * we need to delete it later */
+
+    /* create the ICMP instruction with new_pred and add it to the old basic
+     * block bb it is now at the position where the old IcmpInst was */
+    Instruction *fcmp_np;
+    fcmp_np = CmpInst::Create(Instruction::FCmp, new_pred, op0, op1);
+    bb->getInstList().insert(bb->getTerminator()->getIterator(), fcmp_np);
+
+    /* create a new basic block which holds the new EQ fcmp */
+    Instruction *fcmp_eq;
+    /* insert middle_bb before end_bb */
+    BasicBlock *middle_bb =
+        BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
+    fcmp_eq = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, op0, op1);
+    middle_bb->getInstList().push_back(fcmp_eq);
+    /* add an unconditional branch to the end of middle_bb with destination
+     * end_bb */
+    BranchInst::Create(end_bb, middle_bb);
+
+    /* replace the uncond branch with a conditional one, which depends on the
+     * new_pred icmp. True goes to end, false to the middle (injected) bb */
+    auto term = bb->getTerminator();
+    BranchInst::Create(end_bb, middle_bb, fcmp_np, bb);
+    term->eraseFromParent();
+
+    /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI
+     * inst to wire up the loose ends */
+    PHINode *PN = PHINode::Create(Int1Ty, 2, "");
+    /* the first result depends on the outcome of icmp_eq */
+    PN->addIncoming(fcmp_eq, middle_bb);
+    /* if the source was the original bb we know that the icmp_np yielded true
+     * hence we can hardcode this value */
+    PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
+    /* replace the old IcmpInst with our new and shiny PHI inst */
+    BasicBlock::iterator ii(FcmpInst);
+    ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
+
+  }
+
   return true;
 
 }
 
 /* this function transforms signed compares to equivalent unsigned compares */
-bool SplitComparesTransform::simplifySignedness(Module &M) {
+bool SplitComparesTransform::simplifyIntSignedness(Module &M) {
 
   LLVMContext &              C = M.getContext();
   std::vector<Instruction *> icomps;
   IntegerType *              Int1Ty = IntegerType::getInt1Ty(C);
 
-  /* iterate over all functions, bbs and instruction and add
+  /* iterate over all functions, bbs and instructions and add
    * all signed compares to icomps vector */
   for (auto &F : M) {
 
@@ -196,26 +282,24 @@ bool SplitComparesTransform::simplifySignedness(Module &M) {
 
         if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
 
-          if (selectcmpInst->getPredicate() != CmpInst::ICMP_SGT &&
-              selectcmpInst->getPredicate() != CmpInst::ICMP_SLT) {
+          if (selectcmpInst->getPredicate() == CmpInst::ICMP_SGT ||
+              selectcmpInst->getPredicate() == CmpInst::ICMP_SLT) {
 
-            continue;
+            auto op0 = selectcmpInst->getOperand(0);
+            auto op1 = selectcmpInst->getOperand(1);
 
-          }
-
-          auto op0 = selectcmpInst->getOperand(0);
-          auto op1 = selectcmpInst->getOperand(1);
+            IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+            IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
 
-          IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
-          IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+            /* see above */
+            if (!intTyOp0 || !intTyOp1) { continue; }
 
-          /* see above */
-          if (!intTyOp0 || !intTyOp1) { continue; }
+            /* i think this is not possible but to lazy to look it up */
+            if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { continue; }
 
-          /* i think this is not possible but to lazy to look it up */
-          if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { continue; }
+            icomps.push_back(selectcmpInst);
 
-          icomps.push_back(selectcmpInst);
+          }
 
         }
 
@@ -328,8 +412,333 @@ bool SplitComparesTransform::simplifySignedness(Module &M) {
 
 }
 
+size_t SplitComparesTransform::nextPowerOfTwo(size_t in) {
+  --in;
+  in |= in >> 1;
+  in |= in >> 2;
+  in |= in >> 4;
+//  in |= in >> 8;
+//  in |= in >> 16;
+  return in + 1;
+}
+
+/* splits fcmps into two nested fcmps with sign compare and the rest */
+size_t SplitComparesTransform::splitFPCompares(Module &M) {
+  size_t count = 0;
+
+  LLVMContext &C = M.getContext();
+
+  const DataLayout &dl = M.getDataLayout();
+
+  /* define unions with floating point and (sign, exponent, mantissa)  triples */
+  if (dl.isLittleEndian()) {
+  }
+  else if (dl.isBigEndian()) {
+  }
+  else {
+    return count;
+  }
+
+  std::vector<CmpInst *> fcomps;
+
+  /* get all EQ, NE, GT, and LT fcmps. if the other two
+   * functions were executed only these four predicates should exist */
+  for (auto &F : M) {
+
+    for (auto &BB : F) {
+
+      for (auto &IN : BB) {
+
+        CmpInst *selectcmpInst = nullptr;
+
+        if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
+
+          if (selectcmpInst->getPredicate() == CmpInst::FCMP_OEQ ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_ONE ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_UNE ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_OGT ||
+              selectcmpInst->getPredicate() == CmpInst::FCMP_OLT) {
+
+            auto op0 = selectcmpInst->getOperand(0);
+            auto op1 = selectcmpInst->getOperand(1);
+
+            Type *TyOp0 = op0->getType();
+            Type *TyOp1 = op1->getType();
+
+            if (TyOp0 != TyOp1) { continue; }
+
+            fcomps.push_back(selectcmpInst);
+
+          }
+
+        }
+
+      }
+
+    }
+
+  }
+  if (!fcomps.size()) { return count; }
+
+  IntegerType *Int1Ty = IntegerType::getInt1Ty(C);
+
+  for (auto &FcmpInst : fcomps) {
+
+    BasicBlock *bb = FcmpInst->getParent();
+
+    auto op0 = FcmpInst->getOperand(0);
+    auto op1 = FcmpInst->getOperand(1);
+
+    unsigned op0_size, op1_size;
+    op0_size = op0->getType()->getPrimitiveSizeInBits();
+    op1_size = op1->getType()->getPrimitiveSizeInBits();
+
+    if (op0_size != op1_size) {
+      continue;
+    }
+
+    const unsigned int precision  = llvm::APFloatBase::semanticsPrecision(op0->getType()->getFltSemantics());
+    const unsigned int sizeInBits = llvm::APFloatBase::semanticsSizeInBits(op0->getType()->getFltSemantics());
+
+
+    const unsigned shiftR_exponent = precision - 1;
+    const unsigned long long mask_fraction = ((1 << (precision - 2))) | ((1 << (precision - 2)) - 1);
+    const unsigned long long mask_exponent = (1 << (sizeInBits - precision)) - 1;
+
+    // round up sizes to the next power of two
+    // this should help with integer compare splitting
+    size_t exTySizeBytes = ((sizeInBits - precision + 7) >> 3);
+    size_t frTySizeBytes = ((precision - 1          + 7) >> 3);
+
+    IntegerType *IntExponentTy = IntegerType::get(C, nextPowerOfTwo(exTySizeBytes) << 3);
+    IntegerType *IntFractionTy = IntegerType::get(C, nextPowerOfTwo(frTySizeBytes) << 3);
+
+    BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(FcmpInst));
+
+    /* create the integers from floats directly */
+    Instruction *b_op0, *b_op1;
+    b_op0 = CastInst::Create(Instruction::BitCast, op0, IntegerType::get(C, op0_size));
+    bb->getInstList().insert(bb->getTerminator()->getIterator(), b_op0);
+
+    b_op1 = CastInst::Create(Instruction::BitCast, op1, IntegerType::get(C, op1_size));
+    bb->getInstList().insert(bb->getTerminator()->getIterator(), b_op1);
+
+    /* isolate signs of value of floating point type */
+
+    /* create a 1 bit compare for the sign bit. to do this shift and trunc
+     * the original operands so only the first bit remains.*/
+    Instruction *s_s0, *t_s0, *s_s1, *t_s1, *icmp_sign_bit;
+
+    s_s0 = BinaryOperator::Create(Instruction::LShr, b_op0,
+                                   ConstantInt::get(b_op0->getType(), op0_size - 1));
+    bb->getInstList().insert(bb->getTerminator()->getIterator(), s_s0);
+    t_s0 = new TruncInst(s_s0, Int1Ty);
+    bb->getInstList().insert(bb->getTerminator()->getIterator(), t_s0);
+
+    s_s1 = BinaryOperator::Create(Instruction::LShr, b_op1,
+                                   ConstantInt::get(b_op1->getType(), op1_size - 1));
+    bb->getInstList().insert(bb->getTerminator()->getIterator(), s_s1);
+    t_s1 = new TruncInst(s_s1, Int1Ty);
+    bb->getInstList().insert(bb->getTerminator()->getIterator(), t_s1);
+
+    /* compare of the sign bits */
+    icmp_sign_bit = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_s0, t_s1);
+    bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_sign_bit);
+
+    /* create a new basic block which is executed if the signedness bits are
+     * equal */
+    BasicBlock * signequal_bb =
+        BasicBlock::Create(C, "signequal", end_bb->getParent(), end_bb);
+
+    BranchInst::Create(end_bb, signequal_bb);
+
+    /* create a new bb which is executed if exponents are equal */
+    BasicBlock * middle_bb =
+        BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb);
+
+    BranchInst::Create(end_bb, middle_bb);
+
+    auto term = bb->getTerminator();
+    /* if the signs are different goto end_bb else to signequal_bb */
+    BranchInst::Create(signequal_bb, end_bb, icmp_sign_bit, bb);
+    term->eraseFromParent();
+
+    /* insert code for equal signs */
+
+    /* isolate the exponents */
+    Instruction *s_e0, *m_e0, *t_e0, *s_e1, *m_e1, *t_e1;
+
+    s_e0 = BinaryOperator::Create(Instruction::LShr, b_op0, ConstantInt::get(b_op0->getType(), shiftR_exponent));
+    s_e1 = BinaryOperator::Create(Instruction::LShr, b_op1, ConstantInt::get(b_op1->getType(), shiftR_exponent));
+    signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), s_e0);
+    signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), s_e1);
+
+    if (sizeInBits - precision < exTySizeBytes * 8) {
+      m_e0 = BinaryOperator::Create(Instruction::And, s_e0, ConstantInt::get(s_e0->getType(), mask_exponent));
+      m_e1 = BinaryOperator::Create(Instruction::And, s_e1, ConstantInt::get(s_e1->getType(), mask_exponent));
+      signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), m_e0);
+      signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), m_e1);
+
+      t_e0 = new TruncInst(m_e0, IntExponentTy);
+      t_e1 = new TruncInst(m_e1, IntExponentTy);
+    } else {
+      t_e0 = new TruncInst(s_e0, IntExponentTy);
+      t_e1 = new TruncInst(s_e1, IntExponentTy);
+    }
+    signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), t_e0);
+    signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), t_e1);
+    /* compare the exponents of the operands */
+    Instruction *icmp_exponent_result;
+    switch (FcmpInst->getPredicate()) {
+      case CmpInst::FCMP_OEQ:
+        icmp_exponent_result =
+            CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_e0, t_e1);
+      break;
+      case CmpInst::FCMP_ONE:
+      case CmpInst::FCMP_UNE:
+        icmp_exponent_result =
+            CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, t_e0, t_e1);
+      break;
+      case CmpInst::FCMP_OGT:
+        Instruction *icmp_exponent;
+        icmp_exponent =
+            CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_e0, t_e1);
+        signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent);
+        icmp_exponent_result = BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
+      break;
+      case CmpInst::FCMP_OLT:
+        icmp_exponent =
+            CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_e0, t_e1);
+        signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent);
+        icmp_exponent_result = BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0);
+      break;
+      default:
+        continue;
+    }
+    signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent_result);
+
+    {
+    auto term = signequal_bb->getTerminator();
+    /* if the exponents are different do a fraction cmp */
+    BranchInst::Create(middle_bb, end_bb, icmp_exponent_result, signequal_bb);
+    term->eraseFromParent();
+    }
+
+
+    /* isolate the mantissa aka fraction */
+    Instruction *t_f0, *t_f1;
+    bool needTrunc = IntFractionTy->getPrimitiveSizeInBits() < op0_size;
+//errs() << "Fractions: IntFractionTy size " << IntFractionTy->getPrimitiveSizeInBits() << ", op0_size " << op0_size << ", needTrunc " << needTrunc << "\n";
+    if (precision - 1 < frTySizeBytes * 8) {
+      Instruction *m_f0, *m_f1;
+      m_f0 = BinaryOperator::Create(Instruction::And, b_op0, ConstantInt::get(b_op0->getType(), mask_fraction));
+      m_f1 = BinaryOperator::Create(Instruction::And, b_op1, ConstantInt::get(b_op1->getType(), mask_fraction));
+      middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), m_f0);
+      middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), m_f1);
+
+      if (needTrunc) {
+        t_f0 = new TruncInst(m_f0, IntFractionTy);
+        t_f1 = new TruncInst(m_f1, IntFractionTy);
+        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f0);
+        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f1);
+      } else {
+        t_f0 = m_f0;
+        t_f1 = m_f1;
+      }
+    } else {
+      if (needTrunc) {
+        t_f0 = new TruncInst(b_op0, IntFractionTy);
+        t_f1 = new TruncInst(b_op1, IntFractionTy);
+        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f0);
+        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f1);
+      } else {
+        t_f0 = b_op0;
+        t_f1 = b_op1;
+      }
+    }
+
+    /* compare the fractions of the operands */
+    Instruction *icmp_fraction_result;
+    switch (FcmpInst->getPredicate()) {
+      case CmpInst::FCMP_OEQ:
+        icmp_fraction_result =
+            CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_f0, t_f1);
+      break;
+      case CmpInst::FCMP_UNE:
+      case CmpInst::FCMP_ONE:
+        icmp_fraction_result =
+            CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, t_f0, t_f1);
+      break;
+      case CmpInst::FCMP_OGT:
+        Instruction *icmp_fraction;
+        icmp_fraction =
+            CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1);
+        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction);
+        icmp_fraction_result = BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0);
+      break;
+      case CmpInst::FCMP_OLT:
+        icmp_fraction =
+            CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1);
+        middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction);
+        icmp_fraction_result = BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0);
+      break;
+      default:
+        continue;
+    }
+    middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction_result);
+
+    PHINode *PN = PHINode::Create(Int1Ty, 3, "");
+
+    switch (FcmpInst->getPredicate()) {
+      case CmpInst::FCMP_OEQ:
+        /* unequal signs cannot be equal values */
+        /* goto false branch */
+        PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb);
+        /* unequal exponents cannot be equal values, too */
+        PN->addIncoming(ConstantInt::get(Int1Ty, 0), signequal_bb);
+        /* fractions comparison */
+        PN->addIncoming(icmp_fraction_result, middle_bb);
+      break;
+      case CmpInst::FCMP_ONE:
+      case CmpInst::FCMP_UNE:
+        /* unequal signs are unequal values */
+        /* goto true branch */
+        PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb);
+        /* unequal exponents are unequal values, too */
+        PN->addIncoming(ConstantInt::get(Int1Ty, 1), signequal_bb);
+        /* fractions comparison */
+        PN->addIncoming(icmp_fraction_result, middle_bb);
+      break;
+      case CmpInst::FCMP_OGT:
+        /* if op1 is negative goto true branch, 
+           else go on comparing */
+        PN->addIncoming(t_s1, bb);
+        PN->addIncoming(icmp_exponent_result, signequal_bb);
+        PN->addIncoming(icmp_fraction_result, middle_bb);
+      break;
+      case CmpInst::FCMP_OLT:
+        /* if op0 is negative goto true branch,
+           else go on comparing */
+        PN->addIncoming(t_s0, bb);
+        PN->addIncoming(icmp_exponent_result, signequal_bb);
+        PN->addIncoming(icmp_fraction_result, middle_bb);
+      break;
+      default:
+        continue;
+    }
+
+    BasicBlock::iterator ii(FcmpInst);
+    ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN);
+    ++count;
+  }
+
+  return count;
+
+}
+
 /* splits icmps of size bitw into two nested icmps with bitw/2 size each */
-bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) {
+size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) {
+  size_t count = 0;
 
   LLVMContext &C = M.getContext();
 
@@ -339,13 +748,14 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) {
 
   std::vector<Instruction *> icomps;
 
-  if (bitw % 2) { return false; }
+  if (bitw % 2) { return 0; }
 
   /* not supported yet */
-  if (bitw > 64) { return false; }
+  if (bitw > 64) { return 0; }
 
-  /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the other two
-   * unctions were executed only these four predicates should exist */
+  /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the 
+   * functions simplifyCompares() and simplifyIntSignedness()
+   * were executed only these four predicates should exist */
   for (auto &F : M) {
 
     for (auto &BB : F) {
@@ -356,33 +766,31 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) {
 
         if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) {
 
-          if (selectcmpInst->getPredicate() != CmpInst::ICMP_EQ &&
-              selectcmpInst->getPredicate() != CmpInst::ICMP_NE &&
-              selectcmpInst->getPredicate() != CmpInst::ICMP_UGT &&
-              selectcmpInst->getPredicate() != CmpInst::ICMP_ULT) {
+          if (selectcmpInst->getPredicate() == CmpInst::ICMP_EQ ||
+              selectcmpInst->getPredicate() == CmpInst::ICMP_NE ||
+              selectcmpInst->getPredicate() == CmpInst::ICMP_UGT ||
+              selectcmpInst->getPredicate() == CmpInst::ICMP_ULT) {
 
-            continue;
+            auto op0 = selectcmpInst->getOperand(0);
+            auto op1 = selectcmpInst->getOperand(1);
 
-          }
+            IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
+            IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
 
-          auto op0 = selectcmpInst->getOperand(0);
-          auto op1 = selectcmpInst->getOperand(1);
+            if (!intTyOp0 || !intTyOp1) { continue; }
 
-          IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType());
-          IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType());
+            /* check if the bitwidths are the one we are looking for */
+            if (intTyOp0->getBitWidth() != bitw ||
+                intTyOp1->getBitWidth() != bitw) {
 
-          if (!intTyOp0 || !intTyOp1) { continue; }
+              continue;
 
-          /* check if the bitwidths are the one we are looking for */
-          if (intTyOp0->getBitWidth() != bitw ||
-              intTyOp1->getBitWidth() != bitw) {
+            }
 
-            continue;
+            icomps.push_back(selectcmpInst);
 
           }
 
-          icomps.push_back(selectcmpInst);
-
         }
 
       }
@@ -391,7 +799,7 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) {
 
   }
 
-  if (!icomps.size()) { return false; }
+  if (!icomps.size()) { return 0; }
 
   for (auto &IcmpInst : icomps) {
 
@@ -482,7 +890,7 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) {
       /* transformations for < and > */
 
       /* create a basic block which checks for the inverse predicate.
-       * if this is true we can go to the end if not we have to got to the
+       * if this is true we can go to the end if not we have to go to the
        * bb which checks the lower half of the operands */
       Instruction *icmp_inv_cmp, *op0_low, *op1_low, *icmp_low;
       BasicBlock * inv_cmp_bb =
@@ -528,10 +936,10 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) {
       ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN);
 
     }
-
+    ++count;
   }
 
-  return true;
+  return count;
 
 }
 
@@ -545,26 +953,32 @@ bool SplitComparesTransform::runOnModule(Module &M) {
 
   simplifyCompares(M);
 
-  simplifySignedness(M);
+  simplifyIntSignedness(M);
 
   if (getenv("AFL_QUIET") == NULL)
-    errs() << "Split-compare-pass by laf.intel@gmail.com\n";
+    errs() << "Split-compare-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n";
+
+  errs() << "Split-floatingpoint-compare-pass: " << splitFPCompares(M) << " FP comparisons splitted\n";
 
   switch (bitw) {
 
     case 64:
-      errs() << "Running split-compare-pass " << 64 << "\n";
-      splitCompares(M, 64);
+      errs() << "Split-integer-compare-pass " << bitw << "bit: " 
+        << splitIntCompares(M, bitw) << " splitted\n";
 
+      bitw >>= 1;
       [[clang::fallthrough]]; /*FALLTHRU*/                   /* FALLTHROUGH */
     case 32:
-      errs() << "Running split-compare-pass " << 32 << "\n";
-      splitCompares(M, 32);
+      errs() << "Split-integer-compare-pass " << bitw << "bit: " 
+        << splitIntCompares(M, bitw) << " splitted\n";
 
+      bitw >>= 1;
       [[clang::fallthrough]]; /*FALLTHRU*/                   /* FALLTHROUGH */
     case 16:
-      errs() << "Running split-compare-pass " << 16 << "\n";
-      splitCompares(M, 16);
+      errs() << "Split-integer-compare-pass " << bitw << "bit: " 
+        << splitIntCompares(M, bitw) << " splitted\n";
+
+      bitw >>= 1;
       break;
 
     default: