diff options
Diffstat (limited to 'llvm_mode/split-compares-pass.so.cc')
-rw-r--r-- | llvm_mode/split-compares-pass.so.cc | 278 |
1 files changed, 181 insertions, 97 deletions
diff --git a/llvm_mode/split-compares-pass.so.cc b/llvm_mode/split-compares-pass.so.cc index e97c2e7b..0595c682 100644 --- a/llvm_mode/split-compares-pass.so.cc +++ b/llvm_mode/split-compares-pass.so.cc @@ -52,8 +52,8 @@ class SplitComparesTransform : public ModulePass { private: size_t splitIntCompares(Module &M, unsigned bitw); size_t splitFPCompares(Module &M); - bool simplifyCompares(Module &M); - bool simplifyIntSignedness(Module &M); + bool simplifyCompares(Module &M); + bool simplifyIntSignedness(Module &M); size_t nextPowerOfTwo(size_t in); }; @@ -294,7 +294,11 @@ bool SplitComparesTransform::simplifyIntSignedness(Module &M) { if (!intTyOp0 || !intTyOp1) { continue; } /* i think this is not possible but to lazy to look it up */ - if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { continue; } + if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { + + continue; + + } icomps.push_back(selectcmpInst); @@ -412,30 +416,36 @@ bool SplitComparesTransform::simplifyIntSignedness(Module &M) { } size_t SplitComparesTransform::nextPowerOfTwo(size_t in) { + --in; in |= in >> 1; in |= in >> 2; in |= in >> 4; -// in |= in >> 8; -// in |= in >> 16; + // in |= in >> 8; + // in |= in >> 16; return in + 1; + } /* splits fcmps into two nested fcmps with sign compare and the rest */ size_t SplitComparesTransform::splitFPCompares(Module &M) { + size_t count = 0; LLVMContext &C = M.getContext(); const DataLayout &dl = M.getDataLayout(); - /* define unions with floating point and (sign, exponent, mantissa) triples */ + /* define unions with floating point and (sign, exponent, mantissa) triples + */ if (dl.isLittleEndian()) { - } - else if (dl.isBigEndian()) { - } - else { + + } else if (dl.isBigEndian()) { + + } else { + return count; + } std::vector<CmpInst *> fcomps; @@ -477,6 +487,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { } } + if (!fcomps.size()) { return count; } IntegerType *Int1Ty = IntegerType::getInt1Ty(C); @@ -492,37 +503,42 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { op0_size = op0->getType()->getPrimitiveSizeInBits(); op1_size = op1->getType()->getPrimitiveSizeInBits(); - if (op0_size != op1_size) { - continue; - } + if (op0_size != op1_size) { continue; } const unsigned int sizeInBits = op0->getType()->getPrimitiveSizeInBits(); - const unsigned int precision = sizeInBits == 32 ? 24 : - sizeInBits == 64 ? 53 : - sizeInBits == 128 ? 113 : - sizeInBits == 16 ? 11 : - 65; - - const unsigned shiftR_exponent = precision - 1; - const unsigned long long mask_fraction = ((1 << (precision - 2))) | ((1 << (precision - 2)) - 1); - const unsigned long long mask_exponent = (1 << (sizeInBits - precision)) - 1; + const unsigned int precision = + sizeInBits == 32 + ? 24 + : sizeInBits == 64 + ? 53 + : sizeInBits == 128 ? 113 : sizeInBits == 16 ? 11 : 65; + + const unsigned shiftR_exponent = precision - 1; + const unsigned long long mask_fraction = + ((1 << (precision - 2))) | ((1 << (precision - 2)) - 1); + const unsigned long long mask_exponent = + (1 << (sizeInBits - precision)) - 1; // round up sizes to the next power of two // this should help with integer compare splitting size_t exTySizeBytes = ((sizeInBits - precision + 7) >> 3); - size_t frTySizeBytes = ((precision - 1 + 7) >> 3); + size_t frTySizeBytes = ((precision - 1 + 7) >> 3); - IntegerType *IntExponentTy = IntegerType::get(C, nextPowerOfTwo(exTySizeBytes) << 3); - IntegerType *IntFractionTy = IntegerType::get(C, nextPowerOfTwo(frTySizeBytes) << 3); + IntegerType *IntExponentTy = + IntegerType::get(C, nextPowerOfTwo(exTySizeBytes) << 3); + IntegerType *IntFractionTy = + IntegerType::get(C, nextPowerOfTwo(frTySizeBytes) << 3); BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(FcmpInst)); /* create the integers from floats directly */ Instruction *b_op0, *b_op1; - b_op0 = CastInst::Create(Instruction::BitCast, op0, IntegerType::get(C, op0_size)); + b_op0 = CastInst::Create(Instruction::BitCast, op0, + IntegerType::get(C, op0_size)); bb->getInstList().insert(bb->getTerminator()->getIterator(), b_op0); - b_op1 = CastInst::Create(Instruction::BitCast, op1, IntegerType::get(C, op1_size)); + b_op1 = CastInst::Create(Instruction::BitCast, op1, + IntegerType::get(C, op1_size)); bb->getInstList().insert(bb->getTerminator()->getIterator(), b_op1); /* isolate signs of value of floating point type */ @@ -531,31 +547,34 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { * the original operands so only the first bit remains.*/ Instruction *s_s0, *t_s0, *s_s1, *t_s1, *icmp_sign_bit; - s_s0 = BinaryOperator::Create(Instruction::LShr, b_op0, - ConstantInt::get(b_op0->getType(), op0_size - 1)); + s_s0 = BinaryOperator::Create( + Instruction::LShr, b_op0, + ConstantInt::get(b_op0->getType(), op0_size - 1)); bb->getInstList().insert(bb->getTerminator()->getIterator(), s_s0); t_s0 = new TruncInst(s_s0, Int1Ty); bb->getInstList().insert(bb->getTerminator()->getIterator(), t_s0); - s_s1 = BinaryOperator::Create(Instruction::LShr, b_op1, - ConstantInt::get(b_op1->getType(), op1_size - 1)); + s_s1 = BinaryOperator::Create( + Instruction::LShr, b_op1, + ConstantInt::get(b_op1->getType(), op1_size - 1)); bb->getInstList().insert(bb->getTerminator()->getIterator(), s_s1); t_s1 = new TruncInst(s_s1, Int1Ty); bb->getInstList().insert(bb->getTerminator()->getIterator(), t_s1); /* compare of the sign bits */ - icmp_sign_bit = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_s0, t_s1); + icmp_sign_bit = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_s0, t_s1); bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_sign_bit); /* create a new basic block which is executed if the signedness bits are * equal */ - BasicBlock * signequal_bb = + BasicBlock *signequal_bb = BasicBlock::Create(C, "signequal", end_bb->getParent(), end_bb); BranchInst::Create(end_bb, signequal_bb); /* create a new bb which is executed if exponents are equal */ - BasicBlock * middle_bb = + BasicBlock *middle_bb = BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); BranchInst::Create(end_bb, middle_bb); @@ -570,128 +589,187 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { /* isolate the exponents */ Instruction *s_e0, *m_e0, *t_e0, *s_e1, *m_e1, *t_e1; - s_e0 = BinaryOperator::Create(Instruction::LShr, b_op0, ConstantInt::get(b_op0->getType(), shiftR_exponent)); - s_e1 = BinaryOperator::Create(Instruction::LShr, b_op1, ConstantInt::get(b_op1->getType(), shiftR_exponent)); - signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), s_e0); - signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), s_e1); + s_e0 = BinaryOperator::Create( + Instruction::LShr, b_op0, + ConstantInt::get(b_op0->getType(), shiftR_exponent)); + s_e1 = BinaryOperator::Create( + Instruction::LShr, b_op1, + ConstantInt::get(b_op1->getType(), shiftR_exponent)); + signequal_bb->getInstList().insert( + signequal_bb->getTerminator()->getIterator(), s_e0); + signequal_bb->getInstList().insert( + signequal_bb->getTerminator()->getIterator(), s_e1); t_e0 = new TruncInst(s_e0, IntExponentTy); t_e1 = new TruncInst(s_e1, IntExponentTy); - signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), t_e0); - signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), t_e1); + signequal_bb->getInstList().insert( + signequal_bb->getTerminator()->getIterator(), t_e0); + signequal_bb->getInstList().insert( + signequal_bb->getTerminator()->getIterator(), t_e1); if (sizeInBits - precision < exTySizeBytes * 8) { - m_e0 = BinaryOperator::Create(Instruction::And, t_e0, ConstantInt::get(t_e0->getType(), mask_exponent)); - m_e1 = BinaryOperator::Create(Instruction::And, t_e1, ConstantInt::get(t_e1->getType(), mask_exponent)); - signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), m_e0); - signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), m_e1); + + m_e0 = BinaryOperator::Create( + Instruction::And, t_e0, + ConstantInt::get(t_e0->getType(), mask_exponent)); + m_e1 = BinaryOperator::Create( + Instruction::And, t_e1, + ConstantInt::get(t_e1->getType(), mask_exponent)); + signequal_bb->getInstList().insert( + signequal_bb->getTerminator()->getIterator(), m_e0); + signequal_bb->getInstList().insert( + signequal_bb->getTerminator()->getIterator(), m_e1); + } else { + m_e0 = t_e0; m_e1 = t_e1; + } + /* compare the exponents of the operands */ Instruction *icmp_exponent_result; switch (FcmpInst->getPredicate()) { + case CmpInst::FCMP_OEQ: icmp_exponent_result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, m_e0, m_e1); - break; + break; case CmpInst::FCMP_ONE: case CmpInst::FCMP_UNE: icmp_exponent_result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, m_e0, m_e1); - break; + break; case CmpInst::FCMP_OGT: Instruction *icmp_exponent; icmp_exponent = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, m_e0, m_e1); - signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent); - icmp_exponent_result = BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0); - break; + signequal_bb->getInstList().insert( + signequal_bb->getTerminator()->getIterator(), icmp_exponent); + icmp_exponent_result = + BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0); + break; case CmpInst::FCMP_OLT: icmp_exponent = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, m_e0, m_e1); - signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent); - icmp_exponent_result = BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0); - break; - default: - continue; + signequal_bb->getInstList().insert( + signequal_bb->getTerminator()->getIterator(), icmp_exponent); + icmp_exponent_result = + BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0); + break; + default: continue; + } - signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent_result); + + signequal_bb->getInstList().insert( + signequal_bb->getTerminator()->getIterator(), icmp_exponent_result); { - auto term = signequal_bb->getTerminator(); - /* if the exponents are different do a fraction cmp */ - BranchInst::Create(middle_bb, end_bb, icmp_exponent_result, signequal_bb); - term->eraseFromParent(); - } + auto term = signequal_bb->getTerminator(); + /* if the exponents are different do a fraction cmp */ + BranchInst::Create(middle_bb, end_bb, icmp_exponent_result, signequal_bb); + term->eraseFromParent(); + + } /* isolate the mantissa aka fraction */ Instruction *t_f0, *t_f1; - bool needTrunc = IntFractionTy->getPrimitiveSizeInBits() < op0_size; -//errs() << "Fractions: IntFractionTy size " << IntFractionTy->getPrimitiveSizeInBits() << ", op0_size " << op0_size << ", needTrunc " << needTrunc << "\n"; + bool needTrunc = IntFractionTy->getPrimitiveSizeInBits() < op0_size; + // errs() << "Fractions: IntFractionTy size " << + // IntFractionTy->getPrimitiveSizeInBits() << ", op0_size " << op0_size << ", + // needTrunc " << needTrunc << "\n"; if (precision - 1 < frTySizeBytes * 8) { + Instruction *m_f0, *m_f1; - m_f0 = BinaryOperator::Create(Instruction::And, b_op0, ConstantInt::get(b_op0->getType(), mask_fraction)); - m_f1 = BinaryOperator::Create(Instruction::And, b_op1, ConstantInt::get(b_op1->getType(), mask_fraction)); - middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), m_f0); - middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), m_f1); + m_f0 = BinaryOperator::Create( + Instruction::And, b_op0, + ConstantInt::get(b_op0->getType(), mask_fraction)); + m_f1 = BinaryOperator::Create( + Instruction::And, b_op1, + ConstantInt::get(b_op1->getType(), mask_fraction)); + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), + m_f0); + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), + m_f1); if (needTrunc) { + t_f0 = new TruncInst(m_f0, IntFractionTy); t_f1 = new TruncInst(m_f1, IntFractionTy); - middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f0); - middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f1); + middle_bb->getInstList().insert( + middle_bb->getTerminator()->getIterator(), t_f0); + middle_bb->getInstList().insert( + middle_bb->getTerminator()->getIterator(), t_f1); + } else { + t_f0 = m_f0; t_f1 = m_f1; + } + } else { + if (needTrunc) { + t_f0 = new TruncInst(b_op0, IntFractionTy); t_f1 = new TruncInst(b_op1, IntFractionTy); - middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f0); - middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f1); + middle_bb->getInstList().insert( + middle_bb->getTerminator()->getIterator(), t_f0); + middle_bb->getInstList().insert( + middle_bb->getTerminator()->getIterator(), t_f1); + } else { + t_f0 = b_op0; t_f1 = b_op1; + } + } /* compare the fractions of the operands */ Instruction *icmp_fraction_result; switch (FcmpInst->getPredicate()) { + case CmpInst::FCMP_OEQ: icmp_fraction_result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_f0, t_f1); - break; + break; case CmpInst::FCMP_UNE: case CmpInst::FCMP_ONE: icmp_fraction_result = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, t_f0, t_f1); - break; + break; case CmpInst::FCMP_OGT: Instruction *icmp_fraction; icmp_fraction = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1); - middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction); - icmp_fraction_result = BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0); - break; + middle_bb->getInstList().insert( + middle_bb->getTerminator()->getIterator(), icmp_fraction); + icmp_fraction_result = + BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0); + break; case CmpInst::FCMP_OLT: icmp_fraction = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1); - middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction); - icmp_fraction_result = BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0); - break; - default: - continue; + middle_bb->getInstList().insert( + middle_bb->getTerminator()->getIterator(), icmp_fraction); + icmp_fraction_result = + BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0); + break; + default: continue; + } - middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction_result); + + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), + icmp_fraction_result); PHINode *PN = PHINode::Create(Int1Ty, 3, ""); switch (FcmpInst->getPredicate()) { + case CmpInst::FCMP_OEQ: /* unequal signs cannot be equal values */ /* goto false branch */ @@ -700,7 +778,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { PN->addIncoming(ConstantInt::get(Int1Ty, 0), signequal_bb); /* fractions comparison */ PN->addIncoming(icmp_fraction_result, middle_bb); - break; + break; case CmpInst::FCMP_ONE: case CmpInst::FCMP_UNE: /* unequal signs are unequal values */ @@ -710,28 +788,29 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { PN->addIncoming(ConstantInt::get(Int1Ty, 1), signequal_bb); /* fractions comparison */ PN->addIncoming(icmp_fraction_result, middle_bb); - break; + break; case CmpInst::FCMP_OGT: - /* if op1 is negative goto true branch, + /* if op1 is negative goto true branch, else go on comparing */ PN->addIncoming(t_s1, bb); PN->addIncoming(icmp_exponent_result, signequal_bb); PN->addIncoming(icmp_fraction_result, middle_bb); - break; + break; case CmpInst::FCMP_OLT: /* if op0 is negative goto true branch, else go on comparing */ PN->addIncoming(t_s0, bb); PN->addIncoming(icmp_exponent_result, signequal_bb); PN->addIncoming(icmp_fraction_result, middle_bb); - break; - default: - continue; + break; + default: continue; + } BasicBlock::iterator ii(FcmpInst); ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN); ++count; + } return count; @@ -740,6 +819,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { /* splits icmps of size bitw into two nested icmps with bitw/2 size each */ size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) { + size_t count = 0; LLVMContext &C = M.getContext(); @@ -755,7 +835,7 @@ size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) { /* not supported yet */ if (bitw > 64) { return 0; } - /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the + /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the * functions simplifyCompares() and simplifyIntSignedness() * were executed only these four predicates should exist */ for (auto &F : M) { @@ -938,7 +1018,9 @@ size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) { ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); } + ++count; + } return count; @@ -958,27 +1040,29 @@ bool SplitComparesTransform::runOnModule(Module &M) { simplifyIntSignedness(M); if (getenv("AFL_QUIET") == NULL) - errs() << "Split-compare-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n"; + errs() << "Split-compare-pass by laf.intel@gmail.com, extended by " + "heiko@hexco.de\n"; - errs() << "Split-floatingpoint-compare-pass: " << splitFPCompares(M) << " FP comparisons splitted\n"; + errs() << "Split-floatingpoint-compare-pass: " << splitFPCompares(M) + << " FP comparisons splitted\n"; switch (bitw) { case 64: - errs() << "Split-integer-compare-pass " << bitw << "bit: " - << splitIntCompares(M, bitw) << " splitted\n"; + errs() << "Split-integer-compare-pass " << bitw + << "bit: " << splitIntCompares(M, bitw) << " splitted\n"; bitw >>= 1; [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ case 32: - errs() << "Split-integer-compare-pass " << bitw << "bit: " - << splitIntCompares(M, bitw) << " splitted\n"; + errs() << "Split-integer-compare-pass " << bitw + << "bit: " << splitIntCompares(M, bitw) << " splitted\n"; bitw >>= 1; [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ case 16: - errs() << "Split-integer-compare-pass " << bitw << "bit: " - << splitIntCompares(M, bitw) << " splitted\n"; + errs() << "Split-integer-compare-pass " << bitw + << "bit: " << splitIntCompares(M, bitw) << " splitted\n"; bitw >>= 1; break; |