diff options
Diffstat (limited to 'llvm_mode/split-compares-pass.so.cc')
-rw-r--r-- | llvm_mode/split-compares-pass.so.cc | 544 |
1 files changed, 479 insertions, 65 deletions
diff --git a/llvm_mode/split-compares-pass.so.cc b/llvm_mode/split-compares-pass.so.cc index 1e9d6542..c5da42c0 100644 --- a/llvm_mode/split-compares-pass.so.cc +++ b/llvm_mode/split-compares-pass.so.cc @@ -1,5 +1,6 @@ /* * Copyright 2016 laf-intel + * extended for floating point by Heiko Eißfeldt * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +22,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/IR/Verifier.h" #include "llvm/IR/Module.h" +#include "llvm/ADT/APFloat.h" #include "llvm/IR/IRBuilder.h" @@ -49,9 +51,11 @@ class SplitComparesTransform : public ModulePass { } private: - bool splitCompares(Module &M, unsigned bitw); + size_t splitIntCompares(Module &M, unsigned bitw); + size_t splitFPCompares(Module &M); bool simplifyCompares(Module &M); - bool simplifySignedness(Module &M); + bool simplifyIntSignedness(Module &M); + size_t nextPowerOfTwo(size_t in); }; @@ -65,6 +69,7 @@ bool SplitComparesTransform::simplifyCompares(Module &M) { LLVMContext & C = M.getContext(); std::vector<Instruction *> icomps; + std::vector<Instruction *> fcomps; IntegerType * Int1Ty = IntegerType::getInt1Ty(C); /* iterate over all functions, bbs and instruction and add @@ -79,25 +84,41 @@ bool SplitComparesTransform::simplifyCompares(Module &M) { if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { - if (selectcmpInst->getPredicate() != CmpInst::ICMP_UGE && - selectcmpInst->getPredicate() != CmpInst::ICMP_SGE && - selectcmpInst->getPredicate() != CmpInst::ICMP_ULE && - selectcmpInst->getPredicate() != CmpInst::ICMP_SLE) { + if (selectcmpInst->getPredicate() == CmpInst::ICMP_UGE || + selectcmpInst->getPredicate() == CmpInst::ICMP_SGE || + selectcmpInst->getPredicate() == CmpInst::ICMP_ULE || + selectcmpInst->getPredicate() == CmpInst::ICMP_SLE) { - continue; + auto op0 = selectcmpInst->getOperand(0); + auto op1 = selectcmpInst->getOperand(1); + + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + + /* this is probably not needed but we do it anyway */ + if (!intTyOp0 || !intTyOp1) { continue; } + + icomps.push_back(selectcmpInst); } - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); + if (selectcmpInst->getPredicate() == CmpInst::FCMP_OGE || + selectcmpInst->getPredicate() == CmpInst::FCMP_UGE || + selectcmpInst->getPredicate() == CmpInst::FCMP_OLE || + selectcmpInst->getPredicate() == CmpInst::FCMP_ULE) { - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + auto op0 = selectcmpInst->getOperand(0); + auto op1 = selectcmpInst->getOperand(1); - /* this is probably not needed but we do it anyway */ - if (!intTyOp0 || !intTyOp1) { continue; } + Type *TyOp0 = op0->getType(); + Type *TyOp1 = op1->getType(); - icomps.push_back(selectcmpInst); + /* this is probably not needed but we do it anyway */ + if (TyOp0 != TyOp1) { continue; } + + fcomps.push_back(selectcmpInst); + + } } @@ -107,7 +128,7 @@ bool SplitComparesTransform::simplifyCompares(Module &M) { } - if (!icomps.size()) { return false; } + if (!icomps.size() && !fcomps.size()) { return false; } for (auto &IcmpInst : icomps) { @@ -173,18 +194,83 @@ bool SplitComparesTransform::simplifyCompares(Module &M) { } + /* now for floating point */ + for (auto &FcmpInst : fcomps) { + + BasicBlock *bb = FcmpInst->getParent(); + + auto op0 = FcmpInst->getOperand(0); + auto op1 = FcmpInst->getOperand(1); + + /* find out what the new predicate is going to be */ + auto pred = dyn_cast<CmpInst>(FcmpInst)->getPredicate(); + CmpInst::Predicate new_pred; + switch (pred) { + + case CmpInst::FCMP_UGE: new_pred = CmpInst::FCMP_UGT; break; + case CmpInst::FCMP_OGE: new_pred = CmpInst::FCMP_OGT; break; + case CmpInst::FCMP_ULE: new_pred = CmpInst::FCMP_ULT; break; + case CmpInst::FCMP_OLE: new_pred = CmpInst::FCMP_OLT; break; + default: // keep the compiler happy + continue; + + } + + /* split before the icmp instruction */ + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(FcmpInst)); + + /* the old bb now contains a unconditional jump to the new one (end_bb) + * we need to delete it later */ + + /* create the ICMP instruction with new_pred and add it to the old basic + * block bb it is now at the position where the old IcmpInst was */ + Instruction *fcmp_np; + fcmp_np = CmpInst::Create(Instruction::FCmp, new_pred, op0, op1); + bb->getInstList().insert(bb->getTerminator()->getIterator(), fcmp_np); + + /* create a new basic block which holds the new EQ fcmp */ + Instruction *fcmp_eq; + /* insert middle_bb before end_bb */ + BasicBlock *middle_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + fcmp_eq = CmpInst::Create(Instruction::FCmp, CmpInst::FCMP_OEQ, op0, op1); + middle_bb->getInstList().push_back(fcmp_eq); + /* add an unconditional branch to the end of middle_bb with destination + * end_bb */ + BranchInst::Create(end_bb, middle_bb); + + /* replace the uncond branch with a conditional one, which depends on the + * new_pred icmp. True goes to end, false to the middle (injected) bb */ + auto term = bb->getTerminator(); + BranchInst::Create(end_bb, middle_bb, fcmp_np, bb); + term->eraseFromParent(); + + /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI + * inst to wire up the loose ends */ + PHINode *PN = PHINode::Create(Int1Ty, 2, ""); + /* the first result depends on the outcome of icmp_eq */ + PN->addIncoming(fcmp_eq, middle_bb); + /* if the source was the original bb we know that the icmp_np yielded true + * hence we can hardcode this value */ + PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + /* replace the old IcmpInst with our new and shiny PHI inst */ + BasicBlock::iterator ii(FcmpInst); + ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN); + + } + return true; } /* this function transforms signed compares to equivalent unsigned compares */ -bool SplitComparesTransform::simplifySignedness(Module &M) { +bool SplitComparesTransform::simplifyIntSignedness(Module &M) { LLVMContext & C = M.getContext(); std::vector<Instruction *> icomps; IntegerType * Int1Ty = IntegerType::getInt1Ty(C); - /* iterate over all functions, bbs and instruction and add + /* iterate over all functions, bbs and instructions and add * all signed compares to icomps vector */ for (auto &F : M) { @@ -196,26 +282,24 @@ bool SplitComparesTransform::simplifySignedness(Module &M) { if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { - if (selectcmpInst->getPredicate() != CmpInst::ICMP_SGT && - selectcmpInst->getPredicate() != CmpInst::ICMP_SLT) { + if (selectcmpInst->getPredicate() == CmpInst::ICMP_SGT || + selectcmpInst->getPredicate() == CmpInst::ICMP_SLT) { - continue; + auto op0 = selectcmpInst->getOperand(0); + auto op1 = selectcmpInst->getOperand(1); - } - - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + /* see above */ + if (!intTyOp0 || !intTyOp1) { continue; } - /* see above */ - if (!intTyOp0 || !intTyOp1) { continue; } + /* i think this is not possible but to lazy to look it up */ + if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { continue; } - /* i think this is not possible but to lazy to look it up */ - if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { continue; } + icomps.push_back(selectcmpInst); - icomps.push_back(selectcmpInst); + } } @@ -328,8 +412,333 @@ bool SplitComparesTransform::simplifySignedness(Module &M) { } +size_t SplitComparesTransform::nextPowerOfTwo(size_t in) { + --in; + in |= in >> 1; + in |= in >> 2; + in |= in >> 4; +// in |= in >> 8; +// in |= in >> 16; + return in + 1; +} + +/* splits fcmps into two nested fcmps with sign compare and the rest */ +size_t SplitComparesTransform::splitFPCompares(Module &M) { + size_t count = 0; + + LLVMContext &C = M.getContext(); + + const DataLayout &dl = M.getDataLayout(); + + /* define unions with floating point and (sign, exponent, mantissa) triples */ + if (dl.isLittleEndian()) { + } + else if (dl.isBigEndian()) { + } + else { + return count; + } + + std::vector<CmpInst *> fcomps; + + /* get all EQ, NE, GT, and LT fcmps. if the other two + * functions were executed only these four predicates should exist */ + for (auto &F : M) { + + for (auto &BB : F) { + + for (auto &IN : BB) { + + CmpInst *selectcmpInst = nullptr; + + if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + + if (selectcmpInst->getPredicate() == CmpInst::FCMP_OEQ || + selectcmpInst->getPredicate() == CmpInst::FCMP_ONE || + selectcmpInst->getPredicate() == CmpInst::FCMP_UNE || + selectcmpInst->getPredicate() == CmpInst::FCMP_OGT || + selectcmpInst->getPredicate() == CmpInst::FCMP_OLT) { + + auto op0 = selectcmpInst->getOperand(0); + auto op1 = selectcmpInst->getOperand(1); + + Type *TyOp0 = op0->getType(); + Type *TyOp1 = op1->getType(); + + if (TyOp0 != TyOp1) { continue; } + + fcomps.push_back(selectcmpInst); + + } + + } + + } + + } + + } + if (!fcomps.size()) { return count; } + + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); + + for (auto &FcmpInst : fcomps) { + + BasicBlock *bb = FcmpInst->getParent(); + + auto op0 = FcmpInst->getOperand(0); + auto op1 = FcmpInst->getOperand(1); + + unsigned op0_size, op1_size; + op0_size = op0->getType()->getPrimitiveSizeInBits(); + op1_size = op1->getType()->getPrimitiveSizeInBits(); + + if (op0_size != op1_size) { + continue; + } + + const unsigned int precision = llvm::APFloatBase::semanticsPrecision(op0->getType()->getFltSemantics()); + const unsigned int sizeInBits = llvm::APFloatBase::semanticsSizeInBits(op0->getType()->getFltSemantics()); + + + const unsigned shiftR_exponent = precision - 1; + const unsigned long long mask_fraction = ((1 << (precision - 2))) | ((1 << (precision - 2)) - 1); + const unsigned long long mask_exponent = (1 << (sizeInBits - precision)) - 1; + + // round up sizes to the next power of two + // this should help with integer compare splitting + size_t exTySizeBytes = ((sizeInBits - precision + 7) >> 3); + size_t frTySizeBytes = ((precision - 1 + 7) >> 3); + + IntegerType *IntExponentTy = IntegerType::get(C, nextPowerOfTwo(exTySizeBytes) << 3); + IntegerType *IntFractionTy = IntegerType::get(C, nextPowerOfTwo(frTySizeBytes) << 3); + + BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(FcmpInst)); + + /* create the integers from floats directly */ + Instruction *b_op0, *b_op1; + b_op0 = CastInst::Create(Instruction::BitCast, op0, IntegerType::get(C, op0_size)); + bb->getInstList().insert(bb->getTerminator()->getIterator(), b_op0); + + b_op1 = CastInst::Create(Instruction::BitCast, op1, IntegerType::get(C, op1_size)); + bb->getInstList().insert(bb->getTerminator()->getIterator(), b_op1); + + /* isolate signs of value of floating point type */ + + /* create a 1 bit compare for the sign bit. to do this shift and trunc + * the original operands so only the first bit remains.*/ + Instruction *s_s0, *t_s0, *s_s1, *t_s1, *icmp_sign_bit; + + s_s0 = BinaryOperator::Create(Instruction::LShr, b_op0, + ConstantInt::get(b_op0->getType(), op0_size - 1)); + bb->getInstList().insert(bb->getTerminator()->getIterator(), s_s0); + t_s0 = new TruncInst(s_s0, Int1Ty); + bb->getInstList().insert(bb->getTerminator()->getIterator(), t_s0); + + s_s1 = BinaryOperator::Create(Instruction::LShr, b_op1, + ConstantInt::get(b_op1->getType(), op1_size - 1)); + bb->getInstList().insert(bb->getTerminator()->getIterator(), s_s1); + t_s1 = new TruncInst(s_s1, Int1Ty); + bb->getInstList().insert(bb->getTerminator()->getIterator(), t_s1); + + /* compare of the sign bits */ + icmp_sign_bit = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_s0, t_s1); + bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_sign_bit); + + /* create a new basic block which is executed if the signedness bits are + * equal */ + BasicBlock * signequal_bb = + BasicBlock::Create(C, "signequal", end_bb->getParent(), end_bb); + + BranchInst::Create(end_bb, signequal_bb); + + /* create a new bb which is executed if exponents are equal */ + BasicBlock * middle_bb = + BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + + BranchInst::Create(end_bb, middle_bb); + + auto term = bb->getTerminator(); + /* if the signs are different goto end_bb else to signequal_bb */ + BranchInst::Create(signequal_bb, end_bb, icmp_sign_bit, bb); + term->eraseFromParent(); + + /* insert code for equal signs */ + + /* isolate the exponents */ + Instruction *s_e0, *m_e0, *t_e0, *s_e1, *m_e1, *t_e1; + + s_e0 = BinaryOperator::Create(Instruction::LShr, b_op0, ConstantInt::get(b_op0->getType(), shiftR_exponent)); + s_e1 = BinaryOperator::Create(Instruction::LShr, b_op1, ConstantInt::get(b_op1->getType(), shiftR_exponent)); + signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), s_e0); + signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), s_e1); + + if (sizeInBits - precision < exTySizeBytes * 8) { + m_e0 = BinaryOperator::Create(Instruction::And, s_e0, ConstantInt::get(s_e0->getType(), mask_exponent)); + m_e1 = BinaryOperator::Create(Instruction::And, s_e1, ConstantInt::get(s_e1->getType(), mask_exponent)); + signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), m_e0); + signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), m_e1); + + t_e0 = new TruncInst(m_e0, IntExponentTy); + t_e1 = new TruncInst(m_e1, IntExponentTy); + } else { + t_e0 = new TruncInst(s_e0, IntExponentTy); + t_e1 = new TruncInst(s_e1, IntExponentTy); + } + signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), t_e0); + signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), t_e1); + /* compare the exponents of the operands */ + Instruction *icmp_exponent_result; + switch (FcmpInst->getPredicate()) { + case CmpInst::FCMP_OEQ: + icmp_exponent_result = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_e0, t_e1); + break; + case CmpInst::FCMP_ONE: + case CmpInst::FCMP_UNE: + icmp_exponent_result = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, t_e0, t_e1); + break; + case CmpInst::FCMP_OGT: + Instruction *icmp_exponent; + icmp_exponent = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_e0, t_e1); + signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent); + icmp_exponent_result = BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0); + break; + case CmpInst::FCMP_OLT: + icmp_exponent = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_e0, t_e1); + signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent); + icmp_exponent_result = BinaryOperator::Create(Instruction::Xor, icmp_exponent, t_s0); + break; + default: + continue; + } + signequal_bb->getInstList().insert(signequal_bb->getTerminator()->getIterator(), icmp_exponent_result); + + { + auto term = signequal_bb->getTerminator(); + /* if the exponents are different do a fraction cmp */ + BranchInst::Create(middle_bb, end_bb, icmp_exponent_result, signequal_bb); + term->eraseFromParent(); + } + + + /* isolate the mantissa aka fraction */ + Instruction *t_f0, *t_f1; + bool needTrunc = IntFractionTy->getPrimitiveSizeInBits() < op0_size; +//errs() << "Fractions: IntFractionTy size " << IntFractionTy->getPrimitiveSizeInBits() << ", op0_size " << op0_size << ", needTrunc " << needTrunc << "\n"; + if (precision - 1 < frTySizeBytes * 8) { + Instruction *m_f0, *m_f1; + m_f0 = BinaryOperator::Create(Instruction::And, b_op0, ConstantInt::get(b_op0->getType(), mask_fraction)); + m_f1 = BinaryOperator::Create(Instruction::And, b_op1, ConstantInt::get(b_op1->getType(), mask_fraction)); + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), m_f0); + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), m_f1); + + if (needTrunc) { + t_f0 = new TruncInst(m_f0, IntFractionTy); + t_f1 = new TruncInst(m_f1, IntFractionTy); + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f0); + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f1); + } else { + t_f0 = m_f0; + t_f1 = m_f1; + } + } else { + if (needTrunc) { + t_f0 = new TruncInst(b_op0, IntFractionTy); + t_f1 = new TruncInst(b_op1, IntFractionTy); + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f0); + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), t_f1); + } else { + t_f0 = b_op0; + t_f1 = b_op1; + } + } + + /* compare the fractions of the operands */ + Instruction *icmp_fraction_result; + switch (FcmpInst->getPredicate()) { + case CmpInst::FCMP_OEQ: + icmp_fraction_result = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_f0, t_f1); + break; + case CmpInst::FCMP_UNE: + case CmpInst::FCMP_ONE: + icmp_fraction_result = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_NE, t_f0, t_f1); + break; + case CmpInst::FCMP_OGT: + Instruction *icmp_fraction; + icmp_fraction = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_f0, t_f1); + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction); + icmp_fraction_result = BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0); + break; + case CmpInst::FCMP_OLT: + icmp_fraction = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_f0, t_f1); + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction); + icmp_fraction_result = BinaryOperator::Create(Instruction::Xor, icmp_fraction, t_s0); + break; + default: + continue; + } + middle_bb->getInstList().insert(middle_bb->getTerminator()->getIterator(), icmp_fraction_result); + + PHINode *PN = PHINode::Create(Int1Ty, 3, ""); + + switch (FcmpInst->getPredicate()) { + case CmpInst::FCMP_OEQ: + /* unequal signs cannot be equal values */ + /* goto false branch */ + PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb); + /* unequal exponents cannot be equal values, too */ + PN->addIncoming(ConstantInt::get(Int1Ty, 0), signequal_bb); + /* fractions comparison */ + PN->addIncoming(icmp_fraction_result, middle_bb); + break; + case CmpInst::FCMP_ONE: + case CmpInst::FCMP_UNE: + /* unequal signs are unequal values */ + /* goto true branch */ + PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + /* unequal exponents are unequal values, too */ + PN->addIncoming(ConstantInt::get(Int1Ty, 1), signequal_bb); + /* fractions comparison */ + PN->addIncoming(icmp_fraction_result, middle_bb); + break; + case CmpInst::FCMP_OGT: + /* if op1 is negative goto true branch, + else go on comparing */ + PN->addIncoming(t_s1, bb); + PN->addIncoming(icmp_exponent_result, signequal_bb); + PN->addIncoming(icmp_fraction_result, middle_bb); + break; + case CmpInst::FCMP_OLT: + /* if op0 is negative goto true branch, + else go on comparing */ + PN->addIncoming(t_s0, bb); + PN->addIncoming(icmp_exponent_result, signequal_bb); + PN->addIncoming(icmp_fraction_result, middle_bb); + break; + default: + continue; + } + + BasicBlock::iterator ii(FcmpInst); + ReplaceInstWithInst(FcmpInst->getParent()->getInstList(), ii, PN); + ++count; + } + + return count; + +} + /* splits icmps of size bitw into two nested icmps with bitw/2 size each */ -bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { +size_t SplitComparesTransform::splitIntCompares(Module &M, unsigned bitw) { + size_t count = 0; LLVMContext &C = M.getContext(); @@ -339,13 +748,14 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { std::vector<Instruction *> icomps; - if (bitw % 2) { return false; } + if (bitw % 2) { return 0; } /* not supported yet */ - if (bitw > 64) { return false; } + if (bitw > 64) { return 0; } - /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the other two - * unctions were executed only these four predicates should exist */ + /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the + * functions simplifyCompares() and simplifyIntSignedness() + * were executed only these four predicates should exist */ for (auto &F : M) { for (auto &BB : F) { @@ -356,33 +766,31 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { - if (selectcmpInst->getPredicate() != CmpInst::ICMP_EQ && - selectcmpInst->getPredicate() != CmpInst::ICMP_NE && - selectcmpInst->getPredicate() != CmpInst::ICMP_UGT && - selectcmpInst->getPredicate() != CmpInst::ICMP_ULT) { + if (selectcmpInst->getPredicate() == CmpInst::ICMP_EQ || + selectcmpInst->getPredicate() == CmpInst::ICMP_NE || + selectcmpInst->getPredicate() == CmpInst::ICMP_UGT || + selectcmpInst->getPredicate() == CmpInst::ICMP_ULT) { - continue; + auto op0 = selectcmpInst->getOperand(0); + auto op1 = selectcmpInst->getOperand(1); - } + IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); - auto op0 = selectcmpInst->getOperand(0); - auto op1 = selectcmpInst->getOperand(1); + if (!intTyOp0 || !intTyOp1) { continue; } - IntegerType *intTyOp0 = dyn_cast<IntegerType>(op0->getType()); - IntegerType *intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + /* check if the bitwidths are the one we are looking for */ + if (intTyOp0->getBitWidth() != bitw || + intTyOp1->getBitWidth() != bitw) { - if (!intTyOp0 || !intTyOp1) { continue; } + continue; - /* check if the bitwidths are the one we are looking for */ - if (intTyOp0->getBitWidth() != bitw || - intTyOp1->getBitWidth() != bitw) { + } - continue; + icomps.push_back(selectcmpInst); } - icomps.push_back(selectcmpInst); - } } @@ -391,7 +799,7 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { } - if (!icomps.size()) { return false; } + if (!icomps.size()) { return 0; } for (auto &IcmpInst : icomps) { @@ -482,7 +890,7 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { /* transformations for < and > */ /* create a basic block which checks for the inverse predicate. - * if this is true we can go to the end if not we have to got to the + * if this is true we can go to the end if not we have to go to the * bb which checks the lower half of the operands */ Instruction *icmp_inv_cmp, *op0_low, *op1_low, *icmp_low; BasicBlock * inv_cmp_bb = @@ -528,10 +936,10 @@ bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); } - + ++count; } - return true; + return count; } @@ -545,26 +953,32 @@ bool SplitComparesTransform::runOnModule(Module &M) { simplifyCompares(M); - simplifySignedness(M); + simplifyIntSignedness(M); if (getenv("AFL_QUIET") == NULL) - errs() << "Split-compare-pass by laf.intel@gmail.com\n"; + errs() << "Split-compare-pass by laf.intel@gmail.com, extended by heiko@hexco.de\n"; + + errs() << "Split-floatingpoint-compare-pass: " << splitFPCompares(M) << " FP comparisons splitted\n"; switch (bitw) { case 64: - errs() << "Running split-compare-pass " << 64 << "\n"; - splitCompares(M, 64); + errs() << "Split-integer-compare-pass " << bitw << "bit: " + << splitIntCompares(M, bitw) << " splitted\n"; + bitw >>= 1; [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ case 32: - errs() << "Running split-compare-pass " << 32 << "\n"; - splitCompares(M, 32); + errs() << "Split-integer-compare-pass " << bitw << "bit: " + << splitIntCompares(M, bitw) << " splitted\n"; + bitw >>= 1; [[clang::fallthrough]]; /*FALLTHRU*/ /* FALLTHROUGH */ case 16: - errs() << "Running split-compare-pass " << 16 << "\n"; - splitCompares(M, 16); + errs() << "Split-integer-compare-pass " << bitw << "bit: " + << splitIntCompares(M, bitw) << " splitted\n"; + + bitw >>= 1; break; default: |