diff options
Diffstat (limited to 'instrumentation/split-compares-pass.so.cc')
-rw-r--r-- | instrumentation/split-compares-pass.so.cc | 139 |
1 files changed, 115 insertions, 24 deletions
diff --git a/instrumentation/split-compares-pass.so.cc b/instrumentation/split-compares-pass.so.cc index 95485be9..c06118c0 100644 --- a/instrumentation/split-compares-pass.so.cc +++ b/instrumentation/split-compares-pass.so.cc @@ -882,6 +882,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { // BUG FIXME TODO: u64 does not work for > 64 bit ... e.g. 80 and 128 bit if (sizeInBits > 64) { continue; } + IntegerType *intType = IntegerType::get(C, op_size); const unsigned int precision = sizeInBits == 32 ? 24 : sizeInBits == 64 ? 53 : sizeInBits == 128 ? 113 @@ -913,14 +914,99 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { BasicBlock *end_bb = bb->splitBasicBlock(BasicBlock::iterator(FcmpInst)); /* create the integers from floats directly */ - Instruction *b_op0, *b_op1; - b_op0 = CastInst::Create(Instruction::BitCast, op0, + Instruction *bpre_op0, *bpre_op1; + bpre_op0 = CastInst::Create(Instruction::BitCast, op0, IntegerType::get(C, op_size)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), b_op0); + bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), bpre_op0); - b_op1 = CastInst::Create(Instruction::BitCast, op1, + bpre_op1 = CastInst::Create(Instruction::BitCast, op1, IntegerType::get(C, op_size)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), b_op1); + bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), bpre_op1); + + + /* Check if any operand is NaN. + * If so, all comparisons except unequal (which yields true) yield false */ + + /* build mask for NaN */ + const unsigned long long NaN_lowend = mask_exponent << precision; + // errs() << "Fractions: IntFractionTy size " << + // IntFractionTy->getPrimitiveSizeInBits() << ", op_size " << op_size << + // ", mask_fraction 0x"; + // errs().write_hex(mask_fraction); + // errs() << ", precision " << precision << + // ", NaN_lowend 0x"; + // errs().write_hex(NaN_lowend); errs() << "\n"; + + /* Check op0 for NaN */ + /* Shift left 1 Bit, ignore sign bit */ + Instruction *nan_op0, *nan_op1; + nan_op0 = BinaryOperator::Create( + Instruction::Shl, bpre_op0, + ConstantInt::get(bpre_op0->getType(), 1)); + bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), nan_op0); + + /* compare to NaN interval */ + Instruction *is_op0_nan = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, nan_op0, ConstantInt::get(intType, NaN_lowend) ); + bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), + is_op0_nan); + + /* Check op1 for NaN */ + /* Shift right 1 Bit, ignore sign bit */ + nan_op1 = BinaryOperator::Create( + Instruction::Shl, bpre_op1, + ConstantInt::get(bpre_op1->getType(), 1)); + bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), nan_op1); + + /* compare to NaN interval */ + Instruction *is_op1_nan = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_UGT, nan_op1, ConstantInt::get(intType, NaN_lowend) ); + bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), + is_op1_nan); + + /* combine checks */ + Instruction *is_nan = BinaryOperator::Create( + Instruction::Or, is_op0_nan, is_op1_nan); + bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), + is_nan); + + /* the result of the comparison, when at least one op is NaN + is true only for the "NOT EQUAL" predicates. */ + bool NaNcmp_result = + FcmpInst->getPredicate() == CmpInst::FCMP_ONE || + FcmpInst->getPredicate() == CmpInst::FCMP_UNE; + + BasicBlock *nonan_bb = + BasicBlock::Create(C, "noNaN", end_bb->getParent(), end_bb); + + BranchInst::Create(end_bb, nonan_bb); + + auto term = bb->getTerminator(); + /* if no operand is NaN goto nonan_bb else to handleNaN_bb */ + BranchInst::Create(end_bb, nonan_bb, is_nan, bb); + term->eraseFromParent(); + + /*** now working in nonan_bb ***/ + + /* Treat -0.0 as equal to +0.0, that is for -0.0 make it +0.0 */ + Instruction *b_op0, *b_op1; + Instruction *isMzero_op0, *isMzero_op1; + const unsigned long long MinusZero = 1UL << (sizeInBits - 1U); + const unsigned long long PlusZero = 0; + + isMzero_op0 = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, bpre_op0, ConstantInt::get(intType, MinusZero)); + nonan_bb->getInstList().insert(BasicBlock::iterator(nonan_bb->getTerminator()), isMzero_op0); + + isMzero_op1 = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, bpre_op1, ConstantInt::get(intType, MinusZero)); + nonan_bb->getInstList().insert(BasicBlock::iterator(nonan_bb->getTerminator()), isMzero_op1); + + b_op0 = SelectInst::Create(isMzero_op0, ConstantInt::get(intType, PlusZero), bpre_op0); + nonan_bb->getInstList().insert(BasicBlock::iterator(nonan_bb->getTerminator()), b_op0); + + b_op1 = SelectInst::Create(isMzero_op1, ConstantInt::get(intType, PlusZero), bpre_op1); + nonan_bb->getInstList().insert(BasicBlock::iterator(nonan_bb->getTerminator()), b_op1); /* isolate signs of value of floating point type */ @@ -931,21 +1017,21 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { s_s0 = BinaryOperator::Create(Instruction::LShr, b_op0, ConstantInt::get(b_op0->getType(), op_size - 1)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_s0); + nonan_bb->getInstList().insert(BasicBlock::iterator(nonan_bb->getTerminator()), s_s0); t_s0 = new TruncInst(s_s0, Int1Ty); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_s0); + nonan_bb->getInstList().insert(BasicBlock::iterator(nonan_bb->getTerminator()), t_s0); s_s1 = BinaryOperator::Create(Instruction::LShr, b_op1, ConstantInt::get(b_op1->getType(), op_size - 1)); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), s_s1); + nonan_bb->getInstList().insert(BasicBlock::iterator(nonan_bb->getTerminator()), s_s1); t_s1 = new TruncInst(s_s1, Int1Ty); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), t_s1); + nonan_bb->getInstList().insert(BasicBlock::iterator(nonan_bb->getTerminator()), t_s1); /* compare of the sign bits */ icmp_sign_bit = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, t_s0, t_s1); - bb->getInstList().insert(BasicBlock::iterator(bb->getTerminator()), + nonan_bb->getInstList().insert(BasicBlock::iterator(nonan_bb->getTerminator()), icmp_sign_bit); /* create a new basic block which is executed if the signedness bits are @@ -962,9 +1048,9 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { BranchInst::Create(end_bb, middle_bb); - auto term = bb->getTerminator(); + term = nonan_bb->getTerminator(); /* if the signs are different goto end_bb else to signequal_bb */ - BranchInst::Create(signequal_bb, end_bb, icmp_sign_bit, bb); + BranchInst::Create(signequal_bb, end_bb, icmp_sign_bit, nonan_bb); term->eraseFromParent(); /* insert code for equal signs */ @@ -1261,7 +1347,7 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { } - PHINode *PN = PHINode::Create(Int1Ty, 3, ""); + PHINode *PN = PHINode::Create(Int1Ty, 4, ""); switch (FcmpInst->getPredicate()) { @@ -1269,37 +1355,45 @@ size_t SplitComparesTransform::splitFPCompares(Module &M) { case CmpInst::FCMP_OEQ: /* unequal signs cannot be equal values */ /* goto false branch */ - PN->addIncoming(ConstantInt::get(Int1Ty, 0), bb); + PN->addIncoming(ConstantInt::get(Int1Ty, 0), nonan_bb); /* unequal exponents cannot be equal values, too */ PN->addIncoming(ConstantInt::get(Int1Ty, 0), signequal_bb); /* fractions comparison */ PN->addIncoming(icmp_fraction_result, middle2_bb); + /* NaNs */ + PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb); break; case CmpInst::FCMP_ONE: case CmpInst::FCMP_UNE: /* unequal signs are unequal values */ /* goto true branch */ - PN->addIncoming(ConstantInt::get(Int1Ty, 1), bb); + PN->addIncoming(ConstantInt::get(Int1Ty, 1), nonan_bb); /* unequal exponents are unequal values, too */ PN->addIncoming(icmp_exponent_result, signequal_bb); /* fractions comparison */ PN->addIncoming(icmp_fraction_result, middle2_bb); + /* NaNs */ + PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb); break; case CmpInst::FCMP_OGT: case CmpInst::FCMP_UGT: /* if op1 is negative goto true branch, else go on comparing */ - PN->addIncoming(t_s1, bb); + PN->addIncoming(t_s1, nonan_bb); PN->addIncoming(icmp_exponent_result, signequal2_bb); PN->addIncoming(PN2, middle2_bb); + /* NaNs */ + PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb); break; case CmpInst::FCMP_OLT: case CmpInst::FCMP_ULT: /* if op0 is negative goto true branch, else go on comparing */ - PN->addIncoming(t_s0, bb); + PN->addIncoming(t_s0, nonan_bb); PN->addIncoming(icmp_exponent_result, signequal2_bb); PN->addIncoming(PN2, middle2_bb); + /* NaNs */ + PN->addIncoming(ConstantInt::get(Int1Ty, NaNcmp_result), bb); break; default: continue; @@ -1341,18 +1435,15 @@ bool SplitComparesTransform::runOnModule(Module &M) { if (enableFPSplit) { + simplifyFPCompares(M); count = splitFPCompares(M); - /* - if (!be_quiet) { + if (!be_quiet && !debug) { errs() << "Split-floatingpoint-compare-pass: " << count - << " FP comparisons split\n"; + << " FP comparisons splitted\n"; - } - - */ - simplifyFPCompares(M); + } } |