diff options
Diffstat (limited to 'llvm_mode/split-compares-pass.so.cc')
-rw-r--r-- | llvm_mode/split-compares-pass.so.cc | 527 |
1 files changed, 527 insertions, 0 deletions
diff --git a/llvm_mode/split-compares-pass.so.cc b/llvm_mode/split-compares-pass.so.cc new file mode 100644 index 00000000..5bd01d62 --- /dev/null +++ b/llvm_mode/split-compares-pass.so.cc @@ -0,0 +1,527 @@ +/* + * Copyright 2016 laf-intel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "llvm/Pass.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IR/Module.h" + +#include "llvm/IR/IRBuilder.h" + +using namespace llvm; + +namespace { + class SplitComparesTransform : public ModulePass { + public: + static char ID; + SplitComparesTransform() : ModulePass(ID) {} + + bool runOnModule(Module &M) override; +#if __clang_major__ >= 4 + StringRef getPassName() const override { +#else + const char * getPassName() const override { +#endif + return "simplifies and splits ICMP instructions"; + } + private: + bool splitCompares(Module &M, unsigned bitw); + bool simplifyCompares(Module &M); + bool simplifySignedness(Module &M); + + }; +} + +char SplitComparesTransform::ID = 0; + +/* This function splits ICMP instructions with xGE or xLE predicates into two + * ICMP instructions with predicate xGT or xLT and EQ */ +bool SplitComparesTransform::simplifyCompares(Module &M) { + LLVMContext &C = M.getContext(); + std::vector<Instruction*> icomps; + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); + + /* iterate over all functions, bbs and instruction and add + * all integer comparisons with >= and <= predicates to the icomps vector */ + for (auto &F : M) { + for (auto &BB : F) { + for (auto &IN: BB) { + CmpInst* selectcmpInst = nullptr; + + if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + + if (selectcmpInst->getPredicate() != CmpInst::ICMP_UGE && + selectcmpInst->getPredicate() != CmpInst::ICMP_SGE && + selectcmpInst->getPredicate() != CmpInst::ICMP_ULE && + selectcmpInst->getPredicate() != CmpInst::ICMP_SLE ) { + continue; + } + + auto op0 = selectcmpInst->getOperand(0); + auto op1 = selectcmpInst->getOperand(1); + + IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + IntegerType* intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + + /* this is probably not needed but we do it anyway */ + if (!intTyOp0 || !intTyOp1) { + continue; + } + + icomps.push_back(selectcmpInst); + } + } + } + } + + if (!icomps.size()) { + return false; + } + + + for (auto &IcmpInst: icomps) { + BasicBlock* bb = IcmpInst->getParent(); + + auto op0 = IcmpInst->getOperand(0); + auto op1 = IcmpInst->getOperand(1); + + /* find out what the new predicate is going to be */ + auto pred = dyn_cast<CmpInst>(IcmpInst)->getPredicate(); + CmpInst::Predicate new_pred; + switch(pred) { + case CmpInst::ICMP_UGE: + new_pred = CmpInst::ICMP_UGT; + break; + case CmpInst::ICMP_SGE: + new_pred = CmpInst::ICMP_SGT; + break; + case CmpInst::ICMP_ULE: + new_pred = CmpInst::ICMP_ULT; + break; + case CmpInst::ICMP_SLE: + new_pred = CmpInst::ICMP_SLT; + break; + default: // keep the compiler happy + continue; + } + + /* split before the icmp instruction */ + BasicBlock* end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + + /* the old bb now contains a unconditional jump to the new one (end_bb) + * we need to delete it later */ + + /* create the ICMP instruction with new_pred and add it to the old basic + * block bb it is now at the position where the old IcmpInst was */ + Instruction* icmp_np; + icmp_np = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); + bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_np); + + /* create a new basic block which holds the new EQ icmp */ + Instruction *icmp_eq; + /* insert middle_bb before end_bb */ + BasicBlock* middle_bb = BasicBlock::Create(C, "injected", + end_bb->getParent(), end_bb); + icmp_eq = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, op0, op1); + middle_bb->getInstList().push_back(icmp_eq); + /* add an unconditional branch to the end of middle_bb with destination + * end_bb */ + BranchInst::Create(end_bb, middle_bb); + + /* replace the uncond branch with a conditional one, which depends on the + * new_pred icmp. True goes to end, false to the middle (injected) bb */ + auto term = bb->getTerminator(); + BranchInst::Create(end_bb, middle_bb, icmp_np, bb); + term->eraseFromParent(); + + + /* replace the old IcmpInst (which is the first inst in end_bb) with a PHI + * inst to wire up the loose ends */ + PHINode *PN = PHINode::Create(Int1Ty, 2, ""); + /* the first result depends on the outcome of icmp_eq */ + PN->addIncoming(icmp_eq, middle_bb); + /* if the source was the original bb we know that the icmp_np yielded true + * hence we can hardcode this value */ + PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + /* replace the old IcmpInst with our new and shiny PHI inst */ + BasicBlock::iterator ii(IcmpInst); + ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + } + + return true; +} + +/* this function transforms signed compares to equivalent unsigned compares */ +bool SplitComparesTransform::simplifySignedness(Module &M) { + LLVMContext &C = M.getContext(); + std::vector<Instruction*> icomps; + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); + + /* iterate over all functions, bbs and instruction and add + * all signed compares to icomps vector */ + for (auto &F : M) { + for (auto &BB : F) { + for(auto &IN: BB) { + CmpInst* selectcmpInst = nullptr; + + if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + + if (selectcmpInst->getPredicate() != CmpInst::ICMP_SGT && + selectcmpInst->getPredicate() != CmpInst::ICMP_SLT + ) { + continue; + } + + auto op0 = selectcmpInst->getOperand(0); + auto op1 = selectcmpInst->getOperand(1); + + IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + IntegerType* intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + + /* see above */ + if (!intTyOp0 || !intTyOp1) { + continue; + } + + /* i think this is not possible but to lazy to look it up */ + if (intTyOp0->getBitWidth() != intTyOp1->getBitWidth()) { + continue; + } + + icomps.push_back(selectcmpInst); + } + } + } + } + + if (!icomps.size()) { + return false; + } + + for (auto &IcmpInst: icomps) { + BasicBlock* bb = IcmpInst->getParent(); + + auto op0 = IcmpInst->getOperand(0); + auto op1 = IcmpInst->getOperand(1); + + IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + unsigned bitw = intTyOp0->getBitWidth(); + IntegerType *IntType = IntegerType::get(C, bitw); + + + /* get the new predicate */ + auto pred = dyn_cast<CmpInst>(IcmpInst)->getPredicate(); + CmpInst::Predicate new_pred; + if (pred == CmpInst::ICMP_SGT) { + new_pred = CmpInst::ICMP_UGT; + } else { + new_pred = CmpInst::ICMP_ULT; + } + + BasicBlock* end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + + /* create a 1 bit compare for the sign bit. to do this shift and trunc + * the original operands so only the first bit remains.*/ + Instruction *s_op0, *t_op0, *s_op1, *t_op1, *icmp_sign_bit; + + s_op0 = BinaryOperator::Create(Instruction::LShr, op0, ConstantInt::get(IntType, bitw - 1)); + bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op0); + t_op0 = new TruncInst(s_op0, Int1Ty); + bb->getInstList().insert(bb->getTerminator()->getIterator(), t_op0); + + s_op1 = BinaryOperator::Create(Instruction::LShr, op1, ConstantInt::get(IntType, bitw - 1)); + bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op1); + t_op1 = new TruncInst(s_op1, Int1Ty); + bb->getInstList().insert(bb->getTerminator()->getIterator(), t_op1); + + /* compare of the sign bits */ + icmp_sign_bit = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_op0, t_op1); + bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_sign_bit); + + /* create a new basic block which is executed if the signedness bit is + * different */ + Instruction *icmp_inv_sig_cmp; + BasicBlock* sign_bb = BasicBlock::Create(C, "sign", end_bb->getParent(), end_bb); + if (pred == CmpInst::ICMP_SGT) { + /* if we check for > and the op0 positiv and op1 negative then the final + * result is true. if op0 negative and op1 pos, the cmp must result + * in false + */ + icmp_inv_sig_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, t_op0, t_op1); + } else { + /* just the inverse of the above statement */ + icmp_inv_sig_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, t_op0, t_op1); + } + sign_bb->getInstList().push_back(icmp_inv_sig_cmp); + BranchInst::Create(end_bb, sign_bb); + + /* create a new bb which is executed if signedness is equal */ + Instruction *icmp_usign_cmp; + BasicBlock* middle_bb = BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + /* we can do a normal unsigned compare now */ + icmp_usign_cmp = CmpInst::Create(Instruction::ICmp, new_pred, op0, op1); + middle_bb->getInstList().push_back(icmp_usign_cmp); + BranchInst::Create(end_bb, middle_bb); + + auto term = bb->getTerminator(); + /* if the sign is eq do a normal unsigned cmp, else we have to check the + * signedness bit */ + BranchInst::Create(middle_bb, sign_bb, icmp_sign_bit, bb); + term->eraseFromParent(); + + + PHINode *PN = PHINode::Create(Int1Ty, 2, ""); + + PN->addIncoming(icmp_usign_cmp, middle_bb); + PN->addIncoming(icmp_inv_sig_cmp, sign_bb); + + BasicBlock::iterator ii(IcmpInst); + ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + } + + return true; +} + +/* splits icmps of size bitw into two nested icmps with bitw/2 size each */ +bool SplitComparesTransform::splitCompares(Module &M, unsigned bitw) { + LLVMContext &C = M.getContext(); + + IntegerType *Int1Ty = IntegerType::getInt1Ty(C); + IntegerType *OldIntType = IntegerType::get(C, bitw); + IntegerType *NewIntType = IntegerType::get(C, bitw / 2); + + std::vector<Instruction*> icomps; + + if (bitw % 2) { + return false; + } + + /* not supported yet */ + if (bitw > 64) { + return false; + } + + /* get all EQ, NE, UGT, and ULT icmps of width bitw. if the other two + * unctions were executed only these four predicates should exist */ + for (auto &F : M) { + for (auto &BB : F) { + for(auto &IN: BB) { + CmpInst* selectcmpInst = nullptr; + + if ((selectcmpInst = dyn_cast<CmpInst>(&IN))) { + + if(selectcmpInst->getPredicate() != CmpInst::ICMP_EQ && + selectcmpInst->getPredicate() != CmpInst::ICMP_NE && + selectcmpInst->getPredicate() != CmpInst::ICMP_UGT && + selectcmpInst->getPredicate() != CmpInst::ICMP_ULT + ) { + continue; + } + + auto op0 = selectcmpInst->getOperand(0); + auto op1 = selectcmpInst->getOperand(1); + + IntegerType* intTyOp0 = dyn_cast<IntegerType>(op0->getType()); + IntegerType* intTyOp1 = dyn_cast<IntegerType>(op1->getType()); + + if (!intTyOp0 || !intTyOp1) { + continue; + } + + /* check if the bitwidths are the one we are looking for */ + if (intTyOp0->getBitWidth() != bitw || intTyOp1->getBitWidth() != bitw) { + continue; + } + + icomps.push_back(selectcmpInst); + } + } + } + } + + if (!icomps.size()) { + return false; + } + + for (auto &IcmpInst: icomps) { + BasicBlock* bb = IcmpInst->getParent(); + + auto op0 = IcmpInst->getOperand(0); + auto op1 = IcmpInst->getOperand(1); + + auto pred = dyn_cast<CmpInst>(IcmpInst)->getPredicate(); + + BasicBlock* end_bb = bb->splitBasicBlock(BasicBlock::iterator(IcmpInst)); + + /* create the comparison of the top halfs of the original operands */ + Instruction *s_op0, *op0_high, *s_op1, *op1_high, *icmp_high; + + s_op0 = BinaryOperator::Create(Instruction::LShr, op0, ConstantInt::get(OldIntType, bitw / 2)); + bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op0); + op0_high = new TruncInst(s_op0, NewIntType); + bb->getInstList().insert(bb->getTerminator()->getIterator(), op0_high); + + s_op1 = BinaryOperator::Create(Instruction::LShr, op1, ConstantInt::get(OldIntType, bitw / 2)); + bb->getInstList().insert(bb->getTerminator()->getIterator(), s_op1); + op1_high = new TruncInst(s_op1, NewIntType); + bb->getInstList().insert(bb->getTerminator()->getIterator(), op1_high); + + icmp_high = CmpInst::Create(Instruction::ICmp, pred, op0_high, op1_high); + bb->getInstList().insert(bb->getTerminator()->getIterator(), icmp_high); + + /* now we have to destinguish between == != and > < */ + if (pred == CmpInst::ICMP_EQ || pred == CmpInst::ICMP_NE) { + /* transformation for == and != icmps */ + + /* create a compare for the lower half of the original operands */ + Instruction *op0_low, *op1_low, *icmp_low; + BasicBlock* cmp_low_bb = BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + + op0_low = new TruncInst(op0, NewIntType); + cmp_low_bb->getInstList().push_back(op0_low); + + op1_low = new TruncInst(op1, NewIntType); + cmp_low_bb->getInstList().push_back(op1_low); + + icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low); + cmp_low_bb->getInstList().push_back(icmp_low); + BranchInst::Create(end_bb, cmp_low_bb); + + /* dependant on the cmp of the high parts go to the end or go on with + * the comparison */ + auto term = bb->getTerminator(); + if (pred == CmpInst::ICMP_EQ) { + BranchInst::Create(cmp_low_bb, end_bb, icmp_high, bb); + } else { + /* CmpInst::ICMP_NE */ + BranchInst::Create(end_bb, cmp_low_bb, icmp_high, bb); + } + term->eraseFromParent(); + + /* create the PHI and connect the edges accordingly */ + PHINode *PN = PHINode::Create(Int1Ty, 2, ""); + PN->addIncoming(icmp_low, cmp_low_bb); + if (pred == CmpInst::ICMP_EQ) { + PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb); + } else { + /* CmpInst::ICMP_NE */ + PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + } + + /* replace the old icmp with the new PHI */ + BasicBlock::iterator ii(IcmpInst); + ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + + } else { + /* CmpInst::ICMP_UGT and CmpInst::ICMP_ULT */ + /* transformations for < and > */ + + /* create a basic block which checks for the inverse predicate. + * if this is true we can go to the end if not we have to got to the + * bb which checks the lower half of the operands */ + Instruction *icmp_inv_cmp, *op0_low, *op1_low, *icmp_low; + BasicBlock* inv_cmp_bb = BasicBlock::Create(C, "inv_cmp", end_bb->getParent(), end_bb); + if (pred == CmpInst::ICMP_UGT) { + icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULT, op0_high, op1_high); + } else { + icmp_inv_cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, op0_high, op1_high); + } + inv_cmp_bb->getInstList().push_back(icmp_inv_cmp); + + auto term = bb->getTerminator(); + term->eraseFromParent(); + BranchInst::Create(end_bb, inv_cmp_bb, icmp_high, bb); + + /* create a bb which handles the cmp of the lower halfs */ + BasicBlock* cmp_low_bb = BasicBlock::Create(C, "injected", end_bb->getParent(), end_bb); + op0_low = new TruncInst(op0, NewIntType); + cmp_low_bb->getInstList().push_back(op0_low); + op1_low = new TruncInst(op1, NewIntType); + cmp_low_bb->getInstList().push_back(op1_low); + + icmp_low = CmpInst::Create(Instruction::ICmp, pred, op0_low, op1_low); + cmp_low_bb->getInstList().push_back(icmp_low); + BranchInst::Create(end_bb, cmp_low_bb); + + BranchInst::Create(end_bb, cmp_low_bb, icmp_inv_cmp, inv_cmp_bb); + + PHINode *PN = PHINode::Create(Int1Ty, 3); + PN->addIncoming(icmp_low, cmp_low_bb); + PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + PN->addIncoming(ConstantInt::get(Int1Ty, 0), inv_cmp_bb); + + BasicBlock::iterator ii(IcmpInst); + ReplaceInstWithInst(IcmpInst->getParent()->getInstList(), ii, PN); + } + } + return true; +} + +bool SplitComparesTransform::runOnModule(Module &M) { + int bitw = 64; + + char* bitw_env = getenv("LAF_SPLIT_COMPARES_BITW"); + if (bitw_env) { + bitw = atoi(bitw_env); + } + + simplifyCompares(M); + + simplifySignedness(M); + + errs() << "Split-compare-pass by laf.intel@gmail.com\n"; + + switch (bitw) { + case 64: + errs() << "Running split-compare-pass " << 64 << "\n"; + splitCompares(M, 64); + + [[clang::fallthrough]]; + /* fallthrough */ + case 32: + errs() << "Running split-compare-pass " << 32 << "\n"; + splitCompares(M, 32); + + [[clang::fallthrough]]; + /* fallthrough */ + case 16: + errs() << "Running split-compare-pass " << 16 << "\n"; + splitCompares(M, 16); + break; + + default: + errs() << "NOT Running split-compare-pass \n"; + return false; + break; + } + + verifyModule(M); + return true; +} + +static void registerSplitComparesPass(const PassManagerBuilder &, + legacy::PassManagerBase &PM) { + PM.add(new SplitComparesTransform()); +} + +static RegisterStandardPasses RegisterSplitComparesPass( + PassManagerBuilder::EP_OptimizerLast, registerSplitComparesPass); + +static RegisterStandardPasses RegisterSplitComparesTransPass0( + PassManagerBuilder::EP_EnabledOnOptLevel0, registerSplitComparesPass); |