Diffstat (limited to 'lib/Analysis/ValueTracking.cpp')
-rw-r--r-- | lib/Analysis/ValueTracking.cpp | 692
1 file changed, 369 insertions, 323 deletions
diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index b79370baad10..d871e83f222a 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -20,6 +20,7 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/ConstantRange.h" @@ -76,6 +77,9 @@ struct Query { AssumptionCache *AC; const Instruction *CxtI; const DominatorTree *DT; + // Unlike the other analyses, this may be a nullptr because not all clients + // provide it currently. + OptimizationRemarkEmitter *ORE; /// Set of assumptions that should be excluded from further queries. /// This is because of the potential for mutual recursion to cause @@ -90,11 +94,12 @@ struct Query { unsigned NumExcluded; Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) - : DL(DL), AC(AC), CxtI(CxtI), DT(DT), NumExcluded(0) {} + const DominatorTree *DT, OptimizationRemarkEmitter *ORE = nullptr) + : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), NumExcluded(0) {} Query(const Query &Q, const Value *NewExcl) - : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), NumExcluded(Q.NumExcluded) { + : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), + NumExcluded(Q.NumExcluded) { Excluded = Q.Excluded; Excluded[NumExcluded++] = NewExcl; assert(NumExcluded <= Excluded.size()); @@ -131,9 +136,10 @@ static void computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, void llvm::computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { + const DominatorTree *DT, + OptimizationRemarkEmitter *ORE) { ::computeKnownBits(V, KnownZero, KnownOne, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT)); + Query(DL, AC, safeCxtI(V, CxtI), DT, ORE)); } bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, @@ -249,30 +255,6 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2, unsigned Depth, const Query &Q) { - if (!Add) { - if (const ConstantInt *CLHS = dyn_cast<ConstantInt>(Op0)) { - // We know that the top bits of C-X are clear if X contains less bits - // than C (i.e. no wrap-around can happen). For example, 20-X is - // positive if we can prove that X is >= 0 and < 16. - if (!CLHS->getValue().isNegative()) { - unsigned BitWidth = KnownZero.getBitWidth(); - unsigned NLZ = (CLHS->getValue()+1).countLeadingZeros(); - // NLZ can't be BitWidth with no sign bit - APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ+1); - computeKnownBits(Op1, KnownZero2, KnownOne2, Depth + 1, Q); - - // If all of the MaskV bits are known to be zero, then we know the - // output top bits are zero, because we now know that the output is - // from [0-C]. - if ((KnownZero2 & MaskV) == MaskV) { - unsigned NLZ2 = CLHS->getValue().countLeadingZeros(); - // Top bits known zero. - KnownZero = APInt::getHighBitsSet(BitWidth, NLZ2); - } - } - } - } - unsigned BitWidth = KnownZero.getBitWidth(); // If an initial sequence of bits in the result is not needed, the @@ -282,11 +264,11 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, computeKnownBits(Op1, KnownZero2, KnownOne2, Depth + 1, Q); // Carry in a 1 for a subtract, rather than a 0. 
- APInt CarryIn(BitWidth, 0); + uint64_t CarryIn = 0; if (!Add) { // Sum = LHS + ~RHS + 1 std::swap(KnownZero2, KnownOne2); - CarryIn.setBit(0); + CarryIn = 1; } APInt PossibleSumZero = ~LHSKnownZero + ~KnownZero2 + CarryIn; @@ -315,11 +297,11 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, // Adding two non-negative numbers, or subtracting a negative number from // a non-negative one, can't wrap into negative. if (LHSKnownZero.isNegative() && KnownZero2.isNegative()) - KnownZero |= APInt::getSignBit(BitWidth); + KnownZero.setSignBit(); // Adding two negative numbers, or subtracting a non-negative number from // a negative one, can't wrap into non-negative. else if (LHSKnownOne.isNegative() && KnownOne2.isNegative()) - KnownOne |= APInt::getSignBit(BitWidth); + KnownOne.setSignBit(); } } } @@ -370,8 +352,9 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, TrailZ = std::min(TrailZ, BitWidth); LeadZ = std::min(LeadZ, BitWidth); - KnownZero = APInt::getLowBitsSet(BitWidth, TrailZ) | - APInt::getHighBitsSet(BitWidth, LeadZ); + KnownZero.clearAllBits(); + KnownZero.setLowBits(TrailZ); + KnownZero.setHighBits(LeadZ); // Only make use of no-wrap flags if we failed to compute the sign bit // directly. This matters if the multiplication always overflows, in @@ -379,9 +362,9 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, // though as the program is invoking undefined behaviour we can choose // whatever we like here. if (isKnownNonNegative && !KnownOne.isNegative()) - KnownZero.setBit(BitWidth - 1); + KnownZero.setSignBit(); else if (isKnownNegative && !KnownZero.isNegative()) - KnownOne.setBit(BitWidth - 1); + KnownOne.setSignBit(); } void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, @@ -553,6 +536,13 @@ static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, KnownOne.setAllBits(); return; } + if (match(Arg, m_Not(m_Specific(V))) && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + assert(BitWidth == 1 && "assume operand is not i1?"); + KnownZero.setAllBits(); + KnownOne.clearAllBits(); + return; + } // The remaining tests are all recursive, so bail out if we hit the limit. if (Depth == MaxDepth) @@ -719,7 +709,7 @@ static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, if (RHSKnownZero.isNegative()) { // We know that the sign bit is zero. - KnownZero |= APInt::getSignBit(BitWidth); + KnownZero.setSignBit(); } // assume(v >_s c) where c is at least -1. } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && @@ -730,7 +720,7 @@ static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, if (RHSKnownOne.isAllOnesValue() || RHSKnownZero.isNegative()) { // We know that the sign bit is zero. - KnownZero |= APInt::getSignBit(BitWidth); + KnownZero.setSignBit(); } // assume(v <=_s c) where c is negative } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && @@ -741,7 +731,7 @@ static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, if (RHSKnownOne.isNegative()) { // We know that the sign bit is one. - KnownOne |= APInt::getSignBit(BitWidth); + KnownOne.setSignBit(); } // assume(v <_s c) where c is non-positive } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && @@ -752,7 +742,7 @@ static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, if (RHSKnownZero.isAllOnesValue() || RHSKnownOne.isNegative()) { // We know that the sign bit is one. 
- KnownOne |= APInt::getSignBit(BitWidth); + KnownOne.setSignBit(); } // assume(v <=_u c) } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && @@ -762,8 +752,7 @@ static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); // Whatever high bits in c are zero are known to be zero. - KnownZero |= - APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()); + KnownZero.setHighBits(RHSKnownZero.countLeadingOnes()); // assume(v <_u c) } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && Pred == ICmpInst::ICMP_ULT && @@ -774,11 +763,27 @@ static void computeKnownBitsFromAssume(const Value *V, APInt &KnownZero, // Whatever high bits in c are zero are known to be zero (if c is a power // of 2, then one more). if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I))) - KnownZero |= - APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()+1); + KnownZero.setHighBits(RHSKnownZero.countLeadingOnes()+1); else - KnownZero |= - APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()); + KnownZero.setHighBits(RHSKnownZero.countLeadingOnes()); + } + } + + // If assumptions conflict with each other or previous known bits, then we + // have a logical fallacy. It's possible that the assumption is not reachable, + // so this isn't a real bug. On the other hand, the program may have undefined + // behavior, or we might have a bug in the compiler. We can't assert/crash, so + // clear out the known bits, try to warn the user, and hope for the best. + if ((KnownZero & KnownOne) != 0) { + KnownZero.clearAllBits(); + KnownOne.clearAllBits(); + + if (Q.ORE) { + auto *CxtI = const_cast<Instruction *>(Q.CxtI); + OptimizationRemarkAnalysis ORA("value-tracking", "BadAssumption", CxtI); + Q.ORE->emit(ORA << "Detected conflicting code assumptions. Program may " + "have undefined behavior, or compiler may have " + "internal error."); } } } @@ -817,6 +822,14 @@ static void computeKnownBitsFromShiftOperator( computeKnownBits(I->getOperand(1), KnownZero, KnownOne, Depth + 1, Q); + // If the shift amount could be greater than or equal to the bit-width of the LHS, the + // value could be undef, so we don't know anything about it. + if ((~KnownZero).uge(BitWidth)) { + KnownZero.clearAllBits(); + KnownOne.clearAllBits(); + return; + } + // Note: We cannot use KnownZero.getLimitedValue() here, because if // BitWidth > 64 and any upper bits are known, we'll end up returning the // limit value (which implies all bits are known). @@ -905,14 +918,15 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // TODO: This could be generalized to clearing any bit set in y where the // following bit is known to be unset in y. 
Value *Y = nullptr; - if (match(I->getOperand(0), m_Add(m_Specific(I->getOperand(1)), - m_Value(Y))) || - match(I->getOperand(1), m_Add(m_Specific(I->getOperand(0)), - m_Value(Y)))) { - APInt KnownZero3(BitWidth, 0), KnownOne3(BitWidth, 0); - computeKnownBits(Y, KnownZero3, KnownOne3, Depth + 1, Q); - if (KnownOne3.countTrailingOnes() > 0) - KnownZero |= APInt::getLowBitsSet(BitWidth, 1); + if (!KnownZero[0] && !KnownOne[0] && + (match(I->getOperand(0), m_Add(m_Specific(I->getOperand(1)), + m_Value(Y))) || + match(I->getOperand(1), m_Add(m_Specific(I->getOperand(0)), + m_Value(Y))))) { + KnownZero2.clearAllBits(); KnownOne2.clearAllBits(); + computeKnownBits(Y, KnownZero2, KnownOne2, Depth + 1, Q); + if (KnownOne2.countTrailingOnes() > 0) + KnownZero.setBit(0); } break; } @@ -934,7 +948,7 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, APInt KnownZeroOut = (KnownZero & KnownZero2) | (KnownOne & KnownOne2); // Output known-1 are known to be set if set in only one of the LHS, RHS. KnownOne = (KnownZero & KnownOne2) | (KnownOne & KnownZero2); - KnownZero = KnownZeroOut; + KnownZero = std::move(KnownZeroOut); break; } case Instruction::Mul: { @@ -958,15 +972,11 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, LeadZ = std::min(BitWidth, LeadZ + BitWidth - RHSUnknownLeadingOnes - 1); - KnownZero = APInt::getHighBitsSet(BitWidth, LeadZ); + KnownZero.setHighBits(LeadZ); break; } case Instruction::Select: { - computeKnownBits(I->getOperand(2), KnownZero, KnownOne, Depth + 1, Q); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, Depth + 1, Q); - - const Value *LHS; - const Value *RHS; + const Value *LHS, *RHS; SelectPatternFlavor SPF = matchSelectPattern(I, LHS, RHS).Flavor; if (SelectPatternResult::isMinOrMax(SPF)) { computeKnownBits(RHS, KnownZero, KnownOne, Depth + 1, Q); @@ -980,23 +990,23 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, unsigned MaxHighZeros = 0; if (SPF == SPF_SMAX) { // If both sides are negative, the result is negative. - if (KnownOne[BitWidth - 1] && KnownOne2[BitWidth - 1]) + if (KnownOne.isNegative() && KnownOne2.isNegative()) // We can derive a lower bound on the result by taking the max of the // leading one bits. MaxHighOnes = std::max(KnownOne.countLeadingOnes(), KnownOne2.countLeadingOnes()); // If either side is non-negative, the result is non-negative. - else if (KnownZero[BitWidth - 1] || KnownZero2[BitWidth - 1]) + else if (KnownZero.isNegative() || KnownZero2.isNegative()) MaxHighZeros = 1; } else if (SPF == SPF_SMIN) { // If both sides are non-negative, the result is non-negative. - if (KnownZero[BitWidth - 1] && KnownZero2[BitWidth - 1]) + if (KnownZero.isNegative() && KnownZero2.isNegative()) // We can derive an upper bound on the result by taking the max of the // leading zero bits. MaxHighZeros = std::max(KnownZero.countLeadingOnes(), KnownZero2.countLeadingOnes()); // If either side is negative, the result is negative. 
- else if (KnownOne[BitWidth - 1] || KnownOne2[BitWidth - 1]) + else if (KnownOne.isNegative() || KnownOne2.isNegative()) MaxHighOnes = 1; } else if (SPF == SPF_UMAX) { // We can derive a lower bound on the result by taking the max of the @@ -1014,9 +1024,9 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, KnownOne &= KnownOne2; KnownZero &= KnownZero2; if (MaxHighOnes > 0) - KnownOne |= APInt::getHighBitsSet(BitWidth, MaxHighOnes); + KnownOne.setHighBits(MaxHighOnes); if (MaxHighZeros > 0) - KnownZero |= APInt::getHighBitsSet(BitWidth, MaxHighZeros); + KnownZero.setHighBits(MaxHighZeros); break; } case Instruction::FPTrunc: @@ -1047,7 +1057,7 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, KnownOne = KnownOne.zextOrTrunc(BitWidth); // Any top bits are known to be zero. if (BitWidth > SrcBitWidth) - KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + KnownZero.setBitsFrom(SrcBitWidth); break; } case Instruction::BitCast: { @@ -1068,35 +1078,29 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, KnownZero = KnownZero.trunc(SrcBitWidth); KnownOne = KnownOne.trunc(SrcBitWidth); computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); - KnownZero = KnownZero.zext(BitWidth); - KnownOne = KnownOne.zext(BitWidth); - // If the sign bit of the input is known set or clear, then we know the // top bits of the result. - if (KnownZero[SrcBitWidth-1]) // Input sign bit known zero - KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); - else if (KnownOne[SrcBitWidth-1]) // Input sign bit known set - KnownOne |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + KnownZero = KnownZero.sext(BitWidth); + KnownOne = KnownOne.sext(BitWidth); break; } case Instruction::Shl: { // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0 bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - auto KZF = [BitWidth, NSW](const APInt &KnownZero, unsigned ShiftAmt) { - APInt KZResult = - (KnownZero << ShiftAmt) | - APInt::getLowBitsSet(BitWidth, ShiftAmt); // Low bits known 0. + auto KZF = [NSW](const APInt &KnownZero, unsigned ShiftAmt) { + APInt KZResult = KnownZero << ShiftAmt; + KZResult.setLowBits(ShiftAmt); // Low bits known 0. // If this shift has "nsw" keyword, then the result is either a poison // value or has the same sign bit as the first operand. if (NSW && KnownZero.isNegative()) - KZResult.setBit(BitWidth - 1); + KZResult.setSignBit(); return KZResult; }; - auto KOF = [BitWidth, NSW](const APInt &KnownOne, unsigned ShiftAmt) { + auto KOF = [NSW](const APInt &KnownOne, unsigned ShiftAmt) { APInt KOResult = KnownOne << ShiftAmt; if (NSW && KnownOne.isNegative()) - KOResult.setBit(BitWidth - 1); + KOResult.setSignBit(); return KOResult; }; @@ -1108,13 +1112,13 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, case Instruction::LShr: { // (ushr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 auto KZF = [BitWidth](const APInt &KnownZero, unsigned ShiftAmt) { - return APIntOps::lshr(KnownZero, ShiftAmt) | + return KnownZero.lshr(ShiftAmt) | // High bits known zero. 
APInt::getHighBitsSet(BitWidth, ShiftAmt); }; - auto KOF = [BitWidth](const APInt &KnownOne, unsigned ShiftAmt) { - return APIntOps::lshr(KnownOne, ShiftAmt); + auto KOF = [](const APInt &KnownOne, unsigned ShiftAmt) { + return KnownOne.lshr(ShiftAmt); }; computeKnownBitsFromShiftOperator(I, KnownZero, KnownOne, @@ -1124,12 +1128,12 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, } case Instruction::AShr: { // (ashr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 - auto KZF = [BitWidth](const APInt &KnownZero, unsigned ShiftAmt) { - return APIntOps::ashr(KnownZero, ShiftAmt); + auto KZF = [](const APInt &KnownZero, unsigned ShiftAmt) { + return KnownZero.ashr(ShiftAmt); }; - auto KOF = [BitWidth](const APInt &KnownOne, unsigned ShiftAmt) { - return APIntOps::ashr(KnownOne, ShiftAmt); + auto KOF = [](const APInt &KnownOne, unsigned ShiftAmt) { + return KnownOne.ashr(ShiftAmt); }; computeKnownBitsFromShiftOperator(I, KnownZero, KnownOne, @@ -1165,12 +1169,12 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // If the first operand is non-negative or has all low bits zero, then // the upper bits are all zero. - if (KnownZero2[BitWidth-1] || ((KnownZero2 & LowBits) == LowBits)) + if (KnownZero2.isNegative() || ((KnownZero2 & LowBits) == LowBits)) KnownZero |= ~LowBits; // If the first operand is negative and not all low bits are zero, then // the upper bits are all one. - if (KnownOne2[BitWidth-1] && ((KnownOne2 & LowBits) != 0)) + if (KnownOne2.isNegative() && ((KnownOne2 & LowBits) != 0)) KnownOne |= ~LowBits; assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); @@ -1185,7 +1189,7 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, Q); // If it's known zero, our sign bit is also zero. 
if (LHSKnownZero.isNegative()) - KnownZero.setBit(BitWidth - 1); + KnownZero.setSignBit(); } break; @@ -1209,7 +1213,8 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, unsigned Leaders = std::max(KnownZero.countLeadingOnes(), KnownZero2.countLeadingOnes()); KnownOne.clearAllBits(); - KnownZero = APInt::getHighBitsSet(BitWidth, Leaders); + KnownZero.clearAllBits(); + KnownZero.setHighBits(Leaders); break; } @@ -1220,7 +1225,7 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, Align = Q.DL.getABITypeAlignment(AI->getAllocatedType()); if (Align > 0) - KnownZero = APInt::getLowBitsSet(BitWidth, countTrailingZeros(Align)); + KnownZero.setLowBits(countTrailingZeros(Align)); break; } case Instruction::GetElementPtr: { @@ -1267,7 +1272,7 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, } } - KnownZero = APInt::getLowBitsSet(BitWidth, TrailZ); + KnownZero.setLowBits(TrailZ); break; } case Instruction::PHI: { @@ -1308,9 +1313,8 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, APInt KnownZero3(KnownZero), KnownOne3(KnownOne); computeKnownBits(L, KnownZero3, KnownOne3, Depth + 1, Q); - KnownZero = APInt::getLowBitsSet( - BitWidth, std::min(KnownZero2.countTrailingOnes(), - KnownZero3.countTrailingOnes())); + KnownZero.setLowBits(std::min(KnownZero2.countTrailingOnes(), + KnownZero3.countTrailingOnes())); if (DontImproveNonNegativePhiBits) break; @@ -1328,24 +1332,24 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // (add negative, negative) --> negative if (Opcode == Instruction::Add) { if (KnownZero2.isNegative() && KnownZero3.isNegative()) - KnownZero.setBit(BitWidth - 1); + KnownZero.setSignBit(); else if (KnownOne2.isNegative() && KnownOne3.isNegative()) - KnownOne.setBit(BitWidth - 1); + KnownOne.setSignBit(); } // (sub nsw non-negative, negative) --> non-negative // (sub nsw negative, non-negative) --> negative else if (Opcode == Instruction::Sub && LL == I) { if (KnownZero2.isNegative() && KnownOne3.isNegative()) - KnownZero.setBit(BitWidth - 1); + KnownZero.setSignBit(); else if (KnownOne2.isNegative() && KnownZero3.isNegative()) - KnownOne.setBit(BitWidth - 1); + KnownOne.setSignBit(); } // (mul nsw non-negative, non-negative) --> non-negative else if (Opcode == Instruction::Mul && KnownZero2.isNegative() && KnownZero3.isNegative()) - KnownZero.setBit(BitWidth - 1); + KnownZero.setSignBit(); } break; @@ -1364,8 +1368,8 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, if (dyn_cast_or_null<UndefValue>(P->hasConstantValue())) break; - KnownZero = APInt::getAllOnesValue(BitWidth); - KnownOne = APInt::getAllOnesValue(BitWidth); + KnownZero.setAllBits(); + KnownOne.setAllBits(); for (Value *IncValue : P->incoming_values()) { // Skip direct self references. 
if (IncValue == P) continue; @@ -1400,6 +1404,11 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { default: break; + case Intrinsic::bitreverse: + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); + KnownZero |= KnownZero2.reverseBits(); + KnownOne |= KnownOne2.reverseBits(); + break; case Intrinsic::bswap: computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); KnownZero |= KnownZero2.byteSwap(); @@ -1411,7 +1420,7 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // If this call is undefined for 0, the result will be less than 2^n. if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext())) LowBits -= 1; - KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - LowBits); + KnownZero.setBitsFrom(LowBits); break; } case Intrinsic::ctpop: { @@ -1419,17 +1428,14 @@ static void computeKnownBitsFromOperator(const Operator *I, APInt &KnownZero, // We can bound the space the count needs. Also, bits known to be zero // can't contribute to the population. unsigned BitsPossiblySet = BitWidth - KnownZero2.countPopulation(); - unsigned LeadingZeros = - APInt(BitWidth, BitsPossiblySet).countLeadingZeros(); - assert(LeadingZeros <= BitWidth); - KnownZero |= APInt::getHighBitsSet(BitWidth, LeadingZeros); - KnownOne &= ~KnownZero; + unsigned LowBits = Log2_32(BitsPossiblySet)+1; + KnownZero.setBitsFrom(LowBits); // TODO: we could bound KnownOne using the lower bound on the number // of bits which might be set provided by popcnt KnownOne2. break; } case Intrinsic::x86_sse42_crc32_64_64: - KnownZero |= APInt::getHighBitsSet(64, 32); + KnownZero.setBitsFrom(32); break; } } @@ -1502,6 +1508,7 @@ void computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, KnownZero.getBitWidth() == BitWidth && KnownOne.getBitWidth() == BitWidth && "V, KnownOne and KnownZero should have same BitWidth"); + (void)BitWidth; const APInt *C; if (match(V, m_APInt(C))) { @@ -1513,7 +1520,7 @@ void computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, // Null and aggregate-zero are all-zeros. if (isa<ConstantPointerNull>(V) || isa<ConstantAggregateZero>(V)) { KnownOne.clearAllBits(); - KnownZero = APInt::getAllOnesValue(BitWidth); + KnownZero.setAllBits(); return; } // Handle a constant vector by taking the intersection of the known bits of @@ -1582,7 +1589,7 @@ void computeKnownBits(const Value *V, APInt &KnownZero, APInt &KnownOne, if (V->getType()->isPointerTy()) { unsigned Align = V->getPointerAlignment(Q.DL); if (Align) - KnownZero |= APInt::getLowBitsSet(BitWidth, countTrailingZeros(Align)); + KnownZero.setLowBits(countTrailingZeros(Align)); } // computeKnownBitsFromAssume strictly refines KnownZero and @@ -1607,8 +1614,8 @@ void ComputeSignBit(const Value *V, bool &KnownZero, bool &KnownOne, APInt ZeroBits(BitWidth, 0); APInt OneBits(BitWidth, 0); computeKnownBits(V, ZeroBits, OneBits, Depth, Q); - KnownOne = OneBits[BitWidth - 1]; - KnownZero = ZeroBits[BitWidth - 1]; + KnownOne = OneBits.isNegative(); + KnownZero = ZeroBits.isNegative(); } /// Return true if the given value is known to have exactly one @@ -1788,10 +1795,12 @@ static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value) return true; } -/// Return true if the given value is known to be non-zero when defined. -/// For vectors return true if every element is known to be non-zero when -/// defined. 
Supports values with integer or pointer type and vectors of -/// integers. +/// Return true if the given value is known to be non-zero when defined. For +/// vectors, return true if every element is known to be non-zero when +/// defined. For pointers, if the context instruction and dominator tree are +/// specified, perform context-sensitive analysis and return true if the +/// pointer couldn't possibly be null at the specified instruction. +/// Supports values with integer or pointer type and vectors of integers. bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { if (auto *C = dyn_cast<Constant>(V)) { if (C->isNullValue()) @@ -1834,7 +1843,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { // Check for pointer simplifications. if (V->getType()->isPointerTy()) { - if (isKnownNonNull(V)) + if (isKnownNonNullAt(V, Q.CxtI, Q.DT)) return true; if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) if (isGEPKnownNonNull(GEP, Depth, Q)) @@ -2075,13 +2084,29 @@ static unsigned computeNumSignBitsVectorConstant(const Value *V, return MinSignBits; } +static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, + const Query &Q); + +static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, + const Query &Q) { + unsigned Result = ComputeNumSignBitsImpl(V, Depth, Q); + assert(Result > 0 && "At least one sign bit needs to be present!"); + return Result; +} + /// Return the number of times the sign bit of the register is replicated into /// the other bits. We know that at least 1 bit is always equal to the sign bit /// (itself), but other cases can give us information. For example, immediately /// after an "ashr X, 2", we know that the top 3 bits are all equal to each /// other, so we return 3. For vectors, return the number of sign bits for the /// vector element with the mininum number of known sign bits. -unsigned ComputeNumSignBits(const Value *V, unsigned Depth, const Query &Q) { +static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, + const Query &Q) { + + // We return the minimum number of sign bits that are guaranteed to be present + // in V, so for undef we have to conservatively return 1. We don't have the + // same behavior for poison though -- that's a FIXME today. + unsigned TyBits = Q.DL.getTypeSizeInBits(V->getType()->getScalarType()); unsigned Tmp, Tmp2; unsigned FirstAnswer = 1; @@ -2157,7 +2182,10 @@ unsigned ComputeNumSignBits(const Value *V, unsigned Depth, const Query &Q) { // ashr X, C -> adds C sign bits. Vectors too. const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { - Tmp += ShAmt->getZExtValue(); + unsigned ShAmtLimited = ShAmt->getZExtValue(); + if (ShAmtLimited >= TyBits) + break; // Bad shift. + Tmp += ShAmtLimited; if (Tmp > TyBits) Tmp = TyBits; } return Tmp; @@ -2436,7 +2464,7 @@ Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS, if (!TLI) return Intrinsic::not_intrinsic; - LibFunc::Func Func; + LibFunc Func; // We're going to make assumptions on the semantics of the functions, check // that the target knows that it's available in this environment and it does // not have local linkage. 
@@ -2451,81 +2479,81 @@ Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS, switch (Func) { default: break; - case LibFunc::sin: - case LibFunc::sinf: - case LibFunc::sinl: + case LibFunc_sin: + case LibFunc_sinf: + case LibFunc_sinl: return Intrinsic::sin; - case LibFunc::cos: - case LibFunc::cosf: - case LibFunc::cosl: + case LibFunc_cos: + case LibFunc_cosf: + case LibFunc_cosl: return Intrinsic::cos; - case LibFunc::exp: - case LibFunc::expf: - case LibFunc::expl: + case LibFunc_exp: + case LibFunc_expf: + case LibFunc_expl: return Intrinsic::exp; - case LibFunc::exp2: - case LibFunc::exp2f: - case LibFunc::exp2l: + case LibFunc_exp2: + case LibFunc_exp2f: + case LibFunc_exp2l: return Intrinsic::exp2; - case LibFunc::log: - case LibFunc::logf: - case LibFunc::logl: + case LibFunc_log: + case LibFunc_logf: + case LibFunc_logl: return Intrinsic::log; - case LibFunc::log10: - case LibFunc::log10f: - case LibFunc::log10l: + case LibFunc_log10: + case LibFunc_log10f: + case LibFunc_log10l: return Intrinsic::log10; - case LibFunc::log2: - case LibFunc::log2f: - case LibFunc::log2l: + case LibFunc_log2: + case LibFunc_log2f: + case LibFunc_log2l: return Intrinsic::log2; - case LibFunc::fabs: - case LibFunc::fabsf: - case LibFunc::fabsl: + case LibFunc_fabs: + case LibFunc_fabsf: + case LibFunc_fabsl: return Intrinsic::fabs; - case LibFunc::fmin: - case LibFunc::fminf: - case LibFunc::fminl: + case LibFunc_fmin: + case LibFunc_fminf: + case LibFunc_fminl: return Intrinsic::minnum; - case LibFunc::fmax: - case LibFunc::fmaxf: - case LibFunc::fmaxl: + case LibFunc_fmax: + case LibFunc_fmaxf: + case LibFunc_fmaxl: return Intrinsic::maxnum; - case LibFunc::copysign: - case LibFunc::copysignf: - case LibFunc::copysignl: + case LibFunc_copysign: + case LibFunc_copysignf: + case LibFunc_copysignl: return Intrinsic::copysign; - case LibFunc::floor: - case LibFunc::floorf: - case LibFunc::floorl: + case LibFunc_floor: + case LibFunc_floorf: + case LibFunc_floorl: return Intrinsic::floor; - case LibFunc::ceil: - case LibFunc::ceilf: - case LibFunc::ceill: + case LibFunc_ceil: + case LibFunc_ceilf: + case LibFunc_ceill: return Intrinsic::ceil; - case LibFunc::trunc: - case LibFunc::truncf: - case LibFunc::truncl: + case LibFunc_trunc: + case LibFunc_truncf: + case LibFunc_truncl: return Intrinsic::trunc; - case LibFunc::rint: - case LibFunc::rintf: - case LibFunc::rintl: + case LibFunc_rint: + case LibFunc_rintf: + case LibFunc_rintl: return Intrinsic::rint; - case LibFunc::nearbyint: - case LibFunc::nearbyintf: - case LibFunc::nearbyintl: + case LibFunc_nearbyint: + case LibFunc_nearbyintf: + case LibFunc_nearbyintl: return Intrinsic::nearbyint; - case LibFunc::round: - case LibFunc::roundf: - case LibFunc::roundl: + case LibFunc_round: + case LibFunc_roundf: + case LibFunc_roundl: return Intrinsic::round; - case LibFunc::pow: - case LibFunc::powf: - case LibFunc::powl: + case LibFunc_pow: + case LibFunc_powf: + case LibFunc_powl: return Intrinsic::pow; - case LibFunc::sqrt: - case LibFunc::sqrtf: - case LibFunc::sqrtl: + case LibFunc_sqrt: + case LibFunc_sqrtf: + case LibFunc_sqrtl: if (ICS->hasNoNaNs()) return Intrinsic::sqrt; return Intrinsic::not_intrinsic; @@ -2590,6 +2618,11 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, const TargetLibraryInfo *TLI, bool SignBitOnly, unsigned Depth) { + // TODO: This function does not do the right thing when SignBitOnly is true + // and we're lowering to a hypothetical IEEE 754-compliant-but-evil platform + // which flips the sign 
bits of NaNs. See + // https://llvm.org/bugs/show_bug.cgi?id=31702. + if (const ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { return !CFP->getValueAPF().isNegative() || (!SignBitOnly && CFP->getValueAPF().isZero()); @@ -2633,7 +2666,8 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, Depth + 1); case Instruction::Call: - Intrinsic::ID IID = getIntrinsicForCallSite(cast<CallInst>(I), TLI); + const auto *CI = cast<CallInst>(I); + Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); switch (IID) { default: break; @@ -2650,16 +2684,37 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, case Intrinsic::exp: case Intrinsic::exp2: case Intrinsic::fabs: - case Intrinsic::sqrt: return true; + + case Intrinsic::sqrt: + // sqrt(x) is always >= -0 or NaN. Moreover, sqrt(x) == -0 iff x == -0. + if (!SignBitOnly) + return true; + return CI->hasNoNaNs() && (CI->hasNoSignedZeros() || + CannotBeNegativeZero(CI->getOperand(0), TLI)); + case Intrinsic::powi: - if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { + if (ConstantInt *Exponent = dyn_cast<ConstantInt>(I->getOperand(1))) { // powi(x,n) is non-negative if n is even. - if (CI->getBitWidth() <= 64 && CI->getSExtValue() % 2u == 0) + if (Exponent->getBitWidth() <= 64 && Exponent->getSExtValue() % 2u == 0) return true; } + // TODO: This is not correct. Given that exp is an integer, here are the + // ways that pow can return a negative value: + // + // pow(x, exp) --> negative if exp is odd and x is negative. + // pow(-0, exp) --> -inf if exp is negative odd. + // pow(-0, exp) --> -0 if exp is positive odd. + // pow(-inf, exp) --> -0 if exp is negative odd. + // pow(-inf, exp) --> -inf if exp is positive odd. + // + // Therefore, if !SignBitOnly, we can return true if x >= +0 or x is NaN, + // but we must return false if x == -0. Unfortunately we do not currently + // have a way of expressing this constraint. See details in + // https://llvm.org/bugs/show_bug.cgi?id=31702. return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, Depth + 1); + case Intrinsic::fma: case Intrinsic::fmuladd: // x*x+y is non-negative if y is non-negative. @@ -3150,6 +3205,9 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, if (GA->isInterposable()) return V; V = GA->getAliasee(); + } else if (isa<AllocaInst>(V)) { + // An alloca can't be further simplified. + return V; } else { if (auto CS = CallSite(V)) if (Value *RV = CS.getReturnedArgOperand()) { @@ -3327,6 +3385,10 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V, case Intrinsic::rint: case Intrinsic::round: return true; + // These intrinsics do not correspond to any libm function, and + // do not set errno. + case Intrinsic::powi: + return true; // TODO: are convert_{from,to}_fp16 safe? // TODO: can we list target-specific intrinsics here? default: break; @@ -3406,6 +3468,16 @@ static bool isKnownNonNullFromDominatingCondition(const Value *V, if (NumUsesExplored >= DomConditionsMaxUses) break; NumUsesExplored++; + + // If the value is used as an argument to a call or invoke, then argument + // attributes may provide an answer about null-ness. 
+ if (auto CS = ImmutableCallSite(U)) + if (auto *CalledFunc = CS.getCalledFunction()) + for (const Argument &Arg : CalledFunc->args()) + if (CS.getArgOperand(Arg.getArgNo()) == V && + Arg.hasNonNullAttr() && DT->dominates(CS.getInstruction(), CtxI)) + return true; + // Consider only compare instructions uniquely controlling a branch CmpInst::Predicate Pred; if (!match(const_cast<User *>(U), @@ -3683,6 +3755,8 @@ bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) { return false; if (isa<ReturnInst>(I)) return false; + if (isa<UnreachableInst>(I)) + return false; // Calls can throw, or contain an infinite loop, or kill the process. if (auto CS = ImmutableCallSite(I)) { @@ -3731,79 +3805,33 @@ bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I, bool llvm::propagatesFullPoison(const Instruction *I) { switch (I->getOpcode()) { - case Instruction::Add: - case Instruction::Sub: - case Instruction::Xor: - case Instruction::Trunc: - case Instruction::BitCast: - case Instruction::AddrSpaceCast: - // These operations all propagate poison unconditionally. Note that poison - // is not any particular value, so xor or subtraction of poison with - // itself still yields poison, not zero. - return true; - - case Instruction::AShr: - case Instruction::SExt: - // For these operations, one bit of the input is replicated across - // multiple output bits. A replicated poison bit is still poison. - return true; - - case Instruction::Shl: { - // Left shift *by* a poison value is poison. The number of - // positions to shift is unsigned, so no negative values are - // possible there. Left shift by zero places preserves poison. So - // it only remains to consider left shift of poison by a positive - // number of places. - // - // A left shift by a positive number of places leaves the lowest order bit - // non-poisoned. However, if such a shift has a no-wrap flag, then we can - // make the poison operand violate that flag, yielding a fresh full-poison - // value. - auto *OBO = cast<OverflowingBinaryOperator>(I); - return OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap(); - } - - case Instruction::Mul: { - // A multiplication by zero yields a non-poison zero result, so we need to - // rule out zero as an operand. Conservatively, multiplication by a - // non-zero constant is not multiplication by zero. - // - // Multiplication by a non-zero constant can leave some bits - // non-poisoned. For example, a multiplication by 2 leaves the lowest - // order bit unpoisoned. So we need to consider that. - // - // Multiplication by 1 preserves poison. If the multiplication has a - // no-wrap flag, then we can make the poison operand violate that flag - // when multiplied by any integer other than 0 and 1. - auto *OBO = cast<OverflowingBinaryOperator>(I); - if (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) { - for (Value *V : OBO->operands()) { - if (auto *CI = dyn_cast<ConstantInt>(V)) { - // A ConstantInt cannot yield poison, so we can assume that it is - // the other operand that is poison. - return !CI->isZero(); - } - } - } - return false; - } + case Instruction::Add: + case Instruction::Sub: + case Instruction::Xor: + case Instruction::Trunc: + case Instruction::BitCast: + case Instruction::AddrSpaceCast: + case Instruction::Mul: + case Instruction::Shl: + case Instruction::GetElementPtr: + // These operations all propagate poison unconditionally. 
Note that poison + // is not any particular value, so xor or subtraction of poison with + // itself still yields poison, not zero. + return true; - case Instruction::ICmp: - // Comparing poison with any value yields poison. This is why, for - // instance, x s< (x +nsw 1) can be folded to true. - return true; + case Instruction::AShr: + case Instruction::SExt: + // For these operations, one bit of the input is replicated across + // multiple output bits. A replicated poison bit is still poison. + return true; - case Instruction::GetElementPtr: - // A GEP implicitly represents a sequence of additions, subtractions, - // truncations, sign extensions and multiplications. The multiplications - // are by the non-zero sizes of some set of types, so we do not have to be - // concerned with multiplication by zero. If the GEP is in-bounds, then - // these operations are implicitly no-signed-wrap so poison is propagated - // by the arguments above for Add, Sub, Trunc, SExt and Mul. - return cast<GEPOperator>(I)->isInBounds(); + case Instruction::ICmp: + // Comparing poison with any value yields poison. This is why, for + // instance, x s< (x +nsw 1) can be folded to true. + return true; - default: - return false; + default: + return false; } } @@ -3906,6 +3934,37 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, Value *CmpLHS, Value *CmpRHS, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS) { + // Assume success. If there's no match, callers should not use these anyway. + LHS = TrueVal; + RHS = FalseVal; + + // Recognize variations of: + // CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? (h) : (v))) + const APInt *C1; + if (CmpRHS == TrueVal && match(CmpRHS, m_APInt(C1))) { + const APInt *C2; + + // (X <s C1) ? C1 : SMIN(X, C2) ==> SMAX(SMIN(X, C2), C1) + if (match(FalseVal, m_SMin(m_Specific(CmpLHS), m_APInt(C2))) && + C1->slt(*C2) && Pred == CmpInst::ICMP_SLT) + return {SPF_SMAX, SPNB_NA, false}; + + // (X >s C1) ? C1 : SMAX(X, C2) ==> SMIN(SMAX(X, C2), C1) + if (match(FalseVal, m_SMax(m_Specific(CmpLHS), m_APInt(C2))) && + C1->sgt(*C2) && Pred == CmpInst::ICMP_SGT) + return {SPF_SMIN, SPNB_NA, false}; + + // (X <u C1) ? C1 : UMIN(X, C2) ==> UMAX(UMIN(X, C2), C1) + if (match(FalseVal, m_UMin(m_Specific(CmpLHS), m_APInt(C2))) && + C1->ult(*C2) && Pred == CmpInst::ICMP_ULT) + return {SPF_UMAX, SPNB_NA, false}; + + // (X >u C1) ? C1 : UMAX(X, C2) ==> UMIN(UMAX(X, C2), C1) + if (match(FalseVal, m_UMax(m_Specific(CmpLHS), m_APInt(C2))) && + C1->ugt(*C2) && Pred == CmpInst::ICMP_UGT) + return {SPF_UMIN, SPNB_NA, false}; + } + if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT) return {SPF_UNKNOWN, SPNB_NA, false}; @@ -3913,23 +3972,16 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, // (X >s Y) ? 0 : Z ==> (Z >s 0) ? 0 : Z ==> SMIN(Z, 0) // (X <s Y) ? 0 : Z ==> (Z <s 0) ? 0 : Z ==> SMAX(Z, 0) if (match(TrueVal, m_Zero()) && - match(FalseVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) { - LHS = TrueVal; - RHS = FalseVal; + match(FalseVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false}; - } // Z = X -nsw Y // (X >s Y) ? Z : 0 ==> (Z >s 0) ? Z : 0 ==> SMAX(Z, 0) // (X <s Y) ? Z : 0 ==> (Z <s 0) ? Z : 0 ==> SMIN(Z, 0) if (match(FalseVal, m_Zero()) && - match(TrueVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) { - LHS = TrueVal; - RHS = FalseVal; + match(TrueVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) return {Pred == CmpInst::ICMP_SGT ? 
SPF_SMAX : SPF_SMIN, SPNB_NA, false}; - } - const APInt *C1; if (!match(CmpRHS, m_APInt(C1))) return {SPF_UNKNOWN, SPNB_NA, false}; @@ -3940,41 +3992,29 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, // Is the sign bit set? // (X <s 0) ? X : MAXVAL ==> (X >u MAXVAL) ? X : MAXVAL ==> UMAX // (X <s 0) ? MAXVAL : X ==> (X >u MAXVAL) ? MAXVAL : X ==> UMIN - if (Pred == CmpInst::ICMP_SLT && *C1 == 0 && C2->isMaxSignedValue()) { - LHS = TrueVal; - RHS = FalseVal; + if (Pred == CmpInst::ICMP_SLT && *C1 == 0 && C2->isMaxSignedValue()) return {CmpLHS == TrueVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false}; - } // Is the sign bit clear? // (X >s -1) ? MINVAL : X ==> (X <u MINVAL) ? MINVAL : X ==> UMAX // (X >s -1) ? X : MINVAL ==> (X <u MINVAL) ? X : MINVAL ==> UMIN if (Pred == CmpInst::ICMP_SGT && C1->isAllOnesValue() && - C2->isMinSignedValue()) { - LHS = TrueVal; - RHS = FalseVal; + C2->isMinSignedValue()) return {CmpLHS == FalseVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false}; - } } // Look through 'not' ops to find disguised signed min/max. // (X >s C) ? ~X : ~C ==> (~X <s ~C) ? ~X : ~C ==> SMIN(~X, ~C) // (X <s C) ? ~X : ~C ==> (~X >s ~C) ? ~X : ~C ==> SMAX(~X, ~C) if (match(TrueVal, m_Not(m_Specific(CmpLHS))) && - match(FalseVal, m_APInt(C2)) && ~(*C1) == *C2) { - LHS = TrueVal; - RHS = FalseVal; + match(FalseVal, m_APInt(C2)) && ~(*C1) == *C2) return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false}; - } // (X >s C) ? ~C : ~X ==> (~X <s ~C) ? ~C : ~X ==> SMAX(~C, ~X) // (X <s C) ? ~C : ~X ==> (~X >s ~C) ? ~C : ~X ==> SMIN(~C, ~X) if (match(FalseVal, m_Not(m_Specific(CmpLHS))) && - match(TrueVal, m_APInt(C2)) && ~(*C1) == *C2) { - LHS = TrueVal; - RHS = FalseVal; + match(TrueVal, m_APInt(C2)) && ~(*C1) == *C2) return {Pred == CmpInst::ICMP_SGT ? SPF_SMAX : SPF_SMIN, SPNB_NA, false}; - } return {SPF_UNKNOWN, SPNB_NA, false}; } @@ -4101,58 +4141,64 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2, Instruction::CastOps *CastOp) { - CastInst *CI = dyn_cast<CastInst>(V1); - Constant *C = dyn_cast<Constant>(V2); - if (!CI) + auto *Cast1 = dyn_cast<CastInst>(V1); + if (!Cast1) return nullptr; - *CastOp = CI->getOpcode(); - - if (auto *CI2 = dyn_cast<CastInst>(V2)) { - // If V1 and V2 are both the same cast from the same type, we can look - // through V1. - if (CI2->getOpcode() == CI->getOpcode() && - CI2->getSrcTy() == CI->getSrcTy()) - return CI2->getOperand(0); - return nullptr; - } else if (!C) { + + *CastOp = Cast1->getOpcode(); + Type *SrcTy = Cast1->getSrcTy(); + if (auto *Cast2 = dyn_cast<CastInst>(V2)) { + // If V1 and V2 are both the same cast from the same type, look through V1. 
+ if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy()) + return Cast2->getOperand(0); return nullptr; } - Constant *CastedTo = nullptr; - - if (isa<ZExtInst>(CI) && CmpI->isUnsigned()) - CastedTo = ConstantExpr::getTrunc(C, CI->getSrcTy()); - - if (isa<SExtInst>(CI) && CmpI->isSigned()) - CastedTo = ConstantExpr::getTrunc(C, CI->getSrcTy(), true); - - if (isa<TruncInst>(CI)) - CastedTo = ConstantExpr::getIntegerCast(C, CI->getSrcTy(), CmpI->isSigned()); - - if (isa<FPTruncInst>(CI)) - CastedTo = ConstantExpr::getFPExtend(C, CI->getSrcTy(), true); - - if (isa<FPExtInst>(CI)) - CastedTo = ConstantExpr::getFPTrunc(C, CI->getSrcTy(), true); - - if (isa<FPToUIInst>(CI)) - CastedTo = ConstantExpr::getUIToFP(C, CI->getSrcTy(), true); - - if (isa<FPToSIInst>(CI)) - CastedTo = ConstantExpr::getSIToFP(C, CI->getSrcTy(), true); - - if (isa<UIToFPInst>(CI)) - CastedTo = ConstantExpr::getFPToUI(C, CI->getSrcTy(), true); + auto *C = dyn_cast<Constant>(V2); + if (!C) + return nullptr; - if (isa<SIToFPInst>(CI)) - CastedTo = ConstantExpr::getFPToSI(C, CI->getSrcTy(), true); + Constant *CastedTo = nullptr; + switch (*CastOp) { + case Instruction::ZExt: + if (CmpI->isUnsigned()) + CastedTo = ConstantExpr::getTrunc(C, SrcTy); + break; + case Instruction::SExt: + if (CmpI->isSigned()) + CastedTo = ConstantExpr::getTrunc(C, SrcTy, true); + break; + case Instruction::Trunc: + CastedTo = ConstantExpr::getIntegerCast(C, SrcTy, CmpI->isSigned()); + break; + case Instruction::FPTrunc: + CastedTo = ConstantExpr::getFPExtend(C, SrcTy, true); + break; + case Instruction::FPExt: + CastedTo = ConstantExpr::getFPTrunc(C, SrcTy, true); + break; + case Instruction::FPToUI: + CastedTo = ConstantExpr::getUIToFP(C, SrcTy, true); + break; + case Instruction::FPToSI: + CastedTo = ConstantExpr::getSIToFP(C, SrcTy, true); + break; + case Instruction::UIToFP: + CastedTo = ConstantExpr::getFPToUI(C, SrcTy, true); + break; + case Instruction::SIToFP: + CastedTo = ConstantExpr::getFPToSI(C, SrcTy, true); + break; + default: + break; + } if (!CastedTo) return nullptr; - Constant *CastedBack = - ConstantExpr::getCast(CI->getOpcode(), CastedTo, C->getType(), true); // Make sure the cast doesn't lose any information. + Constant *CastedBack = + ConstantExpr::getCast(*CastOp, CastedTo, C->getType(), true); if (CastedBack != C) return nullptr; |
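
Most of the mechanical churn in this patch replaces mask-building OR patterns such as KnownZero |= APInt::getHighBitsSet(BitWidth, N) with the in-place APInt mutators (setHighBits, setLowBits, setBitsFrom, setSignBit). A minimal standalone sketch of the before/after idiom, assuming LLVM's llvm/ADT/APInt.h is on the include path; this is illustrative, not part of the patch:

    #include "llvm/ADT/APInt.h"
    #include <cassert>

    int main() {
      unsigned BitWidth = 32;
      llvm::APInt KnownZero(BitWidth, 0);
      KnownZero.setHighBits(4); // old: KnownZero |= APInt::getHighBitsSet(BitWidth, 4);
      KnownZero.setSignBit();   // old: KnownZero |= APInt::getSignBit(BitWidth);
      KnownZero.setBitsFrom(8); // old: KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - 8);
      assert(KnownZero.countLeadingOnes() == BitWidth - 8); // bits [8, 32) now known zero
      return 0;
    }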
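
The new conflict check in computeKnownBitsFromAssume reduces to one invariant: no bit may be simultaneously known zero and known one. When that is violated, the patch clears both sets and, if an OptimizationRemarkEmitter was supplied, emits a BadAssumption remark. A standalone reduction of that logic; resolveConflict is a hypothetical helper name, not LLVM API:

    #include "llvm/ADT/APInt.h"
    #include <cstdio>

    // If assumptions force any bit to be both zero and one, the known bits
    // are contradictory, so drop everything rather than propagate nonsense.
    static bool resolveConflict(llvm::APInt &KnownZero, llvm::APInt &KnownOne) {
      if ((KnownZero & KnownOne) == 0)
        return false;             // consistent; nothing to do
      KnownZero.clearAllBits();
      KnownOne.clearAllBits();
      return true;                // caller may emit a remark via ORE
    }

    int main() {
      llvm::APInt KZ(8, 0x0F), KO(8, 0x01); // bit 0 claimed both zero and one
      if (resolveConflict(KZ, KO))
        std::fprintf(stderr, "conflicting assumptions detected\n");
      return 0;
    }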
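
In the computeKnownBitsFromShiftOperator hunk, ~KnownZero over the shift-amount operand is the mask of bits that may still be one, i.e. an upper bound on the shift amount; if that bound can reach the bit width, the shift may be out of range and the result is treated as completely unknown. A sketch of the guard under the same APInt assumption:

    #include "llvm/ADT/APInt.h"
    #include <cassert>

    int main() {
      unsigned BitWidth = 32;
      llvm::APInt KnownZero(BitWidth, 0); // nothing known about the shift amount
      // ~KnownZero is the largest value the shift amount could take; if it
      // can reach BitWidth, the shift may be out of range (undef result).
      bool MayBeOutOfRange = (~KnownZero).uge(BitWidth);
      assert(MayBeOutOfRange); // with nothing known, give up, as the patch does
      return 0;
    }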
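
The SExt hunk relies on sign-extending the known-bit masks themselves: whatever is known about the source's sign bit is replicated into the new high bits, which makes the removed manual getHighBitsSet fill redundant. A small check of that property:

    #include "llvm/ADT/APInt.h"
    #include <cassert>

    int main() {
      // i8 source whose sign bit is known one (value matches 1??????1).
      llvm::APInt KnownOne(8, 0x81);
      // Sign-extending the known-one mask to i16 marks the new top bits
      // known one, exactly what the deleted manual fill used to do.
      KnownOne = KnownOne.sext(16);
      assert(KnownOne == 0xFF81);
      return 0;
    }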
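
For ctpop, the rewritten bound says that a population count over at most N possibly-set bits needs only floor(log2(N)) + 1 result bits, so everything above can be marked known zero via setBitsFrom. For example, with no input bits known zero in an i32:

    #include "llvm/Support/MathExtras.h" // llvm::Log2_32
    #include <cassert>

    int main() {
      unsigned BitsPossiblySet = 32;                         // no input bits known zero
      unsigned LowBits = llvm::Log2_32(BitsPossiblySet) + 1; // == 6
      assert(LowBits == 6); // popcount <= 32 < 2^6, so bits [6, 32) are known zero
      return 0;
    }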
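
Finally, the matchMinMax additions canonicalize clamp idioms spelled as nested selects. In scalar C++ terms, the recognized shape and its min/max form are:

    #include <cassert>

    // CLAMP(v, l, h) ==> (v < l) ? l : ((v > h) ? h : v), recognized by the
    // patch as SMAX(SMIN(v, h), l); the match requires l < h (C1 slt C2).
    static int clamp(int V, int Lo, int Hi) {
      int Upper = V > Hi ? Hi : V;    // SMIN(V, Hi)
      return Upper < Lo ? Lo : Upper; // SMAX(SMIN(V, Hi), Lo)
    }

    int main() {
      assert(clamp(-5, 0, 10) == 0);
      assert(clamp(99, 0, 10) == 10);
      assert(clamp(7, 0, 10) == 7);
      return 0;
    }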