Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
 llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 828 ++++++++++-----
 1 file changed, 537 insertions(+), 291 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index abadf54a9676..5a4791870ac7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -12,7 +12,6 @@
#include "InstCombineInternal.h"
#include "llvm/ADT/APSInt.h"
-#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CaptureTracking.h"
@@ -312,7 +311,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
DL.getTypeAllocSize(Init->getType()->getArrayElementType());
auto MaskIdx = [&](Value *Idx) {
if (!GEP->isInBounds() && llvm::countr_zero(ElementSize) != 0) {
- Value *Mask = ConstantInt::get(Idx->getType(), -1);
+ Value *Mask = Constant::getAllOnesValue(Idx->getType());
Mask = Builder.CreateLShr(Mask, llvm::countr_zero(ElementSize));
Idx = Builder.CreateAnd(Idx, Mask);
}
@@ -423,7 +422,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
/// Returns true if we can rewrite Start as a GEP with pointer Base
/// and some integer offset. The nodes that need to be re-written
/// for this transformation will be added to Explored.
-static bool canRewriteGEPAsOffset(Value *Start, Value *Base,
+static bool canRewriteGEPAsOffset(Value *Start, Value *Base, GEPNoWrapFlags &NW,
const DataLayout &DL,
SetVector<Value *> &Explored) {
SmallVector<Value *, 16> WorkList(1, Start);
@@ -462,6 +461,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base,
if (!GEP->isInBounds() || count_if(GEP->indices(), IsNonConst) > 1)
return false;
+ NW = NW.intersectForOffsetAdd(GEP->getNoWrapFlags());
if (!Explored.contains(GEP->getOperand(0)))
WorkList.push_back(GEP->getOperand(0));
}
@@ -536,7 +536,7 @@ static void setInsertionPoint(IRBuilder<> &Builder, Value *V,
/// Returns a re-written value of Start as an indexed GEP using Base as a
/// pointer.
-static Value *rewriteGEPAsOffset(Value *Start, Value *Base,
+static Value *rewriteGEPAsOffset(Value *Start, Value *Base, GEPNoWrapFlags NW,
const DataLayout &DL,
SetVector<Value *> &Explored,
InstCombiner &IC) {
@@ -578,8 +578,10 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base,
if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero())
NewInsts[GEP] = OffsetV;
else
- NewInsts[GEP] = Builder.CreateNSWAdd(
- Op, OffsetV, GEP->getOperand(0)->getName() + ".add");
+ NewInsts[GEP] = Builder.CreateAdd(
+ Op, OffsetV, GEP->getOperand(0)->getName() + ".add",
+ /*NUW=*/NW.hasNoUnsignedWrap(),
+ /*NSW=*/NW.hasNoUnsignedSignedWrap());
continue;
}
if (isa<PHINode>(Val))
@@ -599,8 +601,9 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base,
for (unsigned I = 0, E = PHI->getNumIncomingValues(); I < E; ++I) {
Value *NewIncoming = PHI->getIncomingValue(I);
- if (NewInsts.contains(NewIncoming))
- NewIncoming = NewInsts[NewIncoming];
+ auto It = NewInsts.find(NewIncoming);
+ if (It != NewInsts.end())
+ NewIncoming = It->second;
NewPhi->addIncoming(NewIncoming, PHI->getIncomingBlock(I));
}
@@ -613,8 +616,8 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base,
setInsertionPoint(Builder, Val, false);
// Create GEP for external users.
- Value *NewVal = Builder.CreateInBoundsGEP(
- Builder.getInt8Ty(), Base, NewInsts[Val], Val->getName() + ".ptr");
+ Value *NewVal = Builder.CreateGEP(Builder.getInt8Ty(), Base, NewInsts[Val],
+ Val->getName() + ".ptr", NW);
IC.replaceInstUsesWith(*cast<Instruction>(Val), NewVal);
// Add old instruction to worklist for DCE. We don't directly remove it
// here because the original compare is one of the users.
@@ -628,7 +631,7 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base,
/// We can look through PHIs, GEPs and casts in order to determine a common base
/// between GEPLHS and RHS.
static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
- ICmpInst::Predicate Cond,
+ CmpPredicate Cond,
const DataLayout &DL,
InstCombiner &IC) {
// FIXME: Support vector of pointers.
@@ -649,8 +652,8 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
// The set of nodes that will take part in this transformation.
SetVector<Value *> Nodes;
-
- if (!canRewriteGEPAsOffset(RHS, PtrBase, DL, Nodes))
+ GEPNoWrapFlags NW = GEPLHS->getNoWrapFlags();
+ if (!canRewriteGEPAsOffset(RHS, PtrBase, NW, DL, Nodes))
return nullptr;
// We know we can re-write this as
@@ -659,7 +662,7 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
// can't have overflow on either side. We can therefore re-write
// this as:
// OFFSET1 cmp OFFSET2
- Value *NewRHS = rewriteGEPAsOffset(RHS, PtrBase, DL, Nodes, IC);
+ Value *NewRHS = rewriteGEPAsOffset(RHS, PtrBase, NW, DL, Nodes, IC);
// RewriteGEPAsOffset has replaced RHS and all of its uses with a re-written
// GEP having PtrBase as the pointer base, and has returned in NewRHS the
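As an illustration of the rewrite (hypothetical IR, value names invented; assumes the inbounds and one-use restrictions checked above hold):
  %p = getelementptr inbounds i8, ptr %base, i64 %i
  %q = getelementptr inbounds i8, ptr %base, i64 %j
  %c = icmp eq ptr %p, %q
  ; rewritten as: %c = icmp eq i64 %i, %j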
@@ -672,8 +675,7 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
/// Fold comparisons between a GEP instruction and something else. At this point
/// we know that the GEP is on the LHS of the comparison.
Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
- ICmpInst::Predicate Cond,
- Instruction &I) {
+ CmpPredicate Cond, Instruction &I) {
// Don't transform signed compares of GEPs into index compares. Even if the
// GEP is inbounds, the final add of the base pointer can have signed overflow
// and would change the result of the icmp.
@@ -687,12 +689,32 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
if (!isa<GetElementPtrInst>(RHS))
RHS = RHS->stripPointerCasts();
+ auto CanFold = [Cond](GEPNoWrapFlags NW) {
+ if (ICmpInst::isEquality(Cond))
+ return true;
+
+ // Unsigned predicates can be folded if the GEPs have *any* nowrap flags.
+ assert(ICmpInst::isUnsigned(Cond));
+ return NW != GEPNoWrapFlags::none();
+ };
+
+ auto NewICmp = [Cond](GEPNoWrapFlags NW, Value *Op1, Value *Op2) {
+ if (!NW.hasNoUnsignedWrap()) {
+ // Convert signed to unsigned comparison.
+ return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Op1, Op2);
+ }
+
+ auto *I = new ICmpInst(Cond, Op1, Op2);
+ I->setSameSign(NW.hasNoUnsignedSignedWrap());
+ return I;
+ };
+
Value *PtrBase = GEPLHS->getOperand(0);
- if (PtrBase == RHS && (GEPLHS->isInBounds() || ICmpInst::isEquality(Cond))) {
+ if (PtrBase == RHS && CanFold(GEPLHS->getNoWrapFlags())) {
// ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0).
Value *Offset = EmitGEPOffset(GEPLHS);
- return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset,
- Constant::getNullValue(Offset->getType()));
+ return NewICmp(GEPLHS->getNoWrapFlags(), Offset,
+ Constant::getNullValue(Offset->getType()));
}
if (GEPLHS->isInBounds() && ICmpInst::isEquality(Cond) &&
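A hypothetical instance of the fold above (names invented; the GEP is inbounds but not nuw, so NewICmp switches to the signed predicate):
  %p = getelementptr inbounds i8, ptr %q, i64 %n
  %c = icmp ult ptr %p, %q
  ; folds to: %c = icmp slt i64 %n, 0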
@@ -725,6 +747,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
ConstantExpr::getPointerBitCastOrAddrSpaceCast(
cast<Constant>(RHS), Base->getType()));
} else if (GEPOperator *GEPRHS = dyn_cast<GEPOperator>(RHS)) {
+ GEPNoWrapFlags NW = GEPLHS->getNoWrapFlags() & GEPRHS->getNoWrapFlags();
+
// If the base pointers are different, but the indices are the same, just
// compare the base pointer.
if (PtrBase != GEPRHS->getOperand(0)) {
@@ -742,7 +766,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
// If all indices are the same, just compare the base pointers.
Type *BaseType = GEPLHS->getOperand(0)->getType();
- if (IndicesTheSame && CmpInst::makeCmpResultType(BaseType) == I.getType())
+ if (IndicesTheSame &&
+ CmpInst::makeCmpResultType(BaseType) == I.getType() && CanFold(NW))
return new ICmpInst(Cond, GEPLHS->getOperand(0), GEPRHS->getOperand(0));
// If we're comparing GEPs with two base pointers that only differ in type
@@ -782,7 +807,6 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this);
}
- bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds();
if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands() &&
GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType()) {
// If the GEPs only differ by one index, compare it.
@@ -810,19 +834,18 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
return replaceInstUsesWith(I, // No comparison is needed here.
ConstantInt::get(I.getType(), ICmpInst::isTrueWhenEqual(Cond)));
- else if (NumDifferences == 1 && GEPsInBounds) {
+ else if (NumDifferences == 1 && CanFold(NW)) {
Value *LHSV = GEPLHS->getOperand(DiffOperand);
Value *RHSV = GEPRHS->getOperand(DiffOperand);
- // Make sure we do a signed comparison here.
- return new ICmpInst(ICmpInst::getSignedPredicate(Cond), LHSV, RHSV);
+ return NewICmp(NW, LHSV, RHSV);
}
}
- if (GEPsInBounds || CmpInst::isEquality(Cond)) {
+ if (CanFold(NW)) {
// ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2)
Value *L = EmitGEPOffset(GEPLHS, /*RewriteGEP=*/true);
Value *R = EmitGEPOffset(GEPRHS, /*RewriteGEP=*/true);
- return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R);
+ return NewICmp(NW, L, R);
}
}
@@ -866,8 +889,7 @@ bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
if (ICmp && ICmp->isEquality() && getUnderlyingObject(*U) == Alloca) {
// Collect equality icmps of the alloca, and don't treat them as
// captures.
- auto Res = ICmps.insert({ICmp, 0});
- Res.first->second |= 1u << U->getOperandNo();
+ ICmps[ICmp] |= 1u << U->getOperandNo();
return false;
}
@@ -910,7 +932,7 @@ bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
/// Fold "icmp pred (X+C), X".
Instruction *InstCombinerImpl::foldICmpAddOpConst(Value *X, const APInt &C,
- ICmpInst::Predicate Pred) {
+ CmpPredicate Pred) {
// From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0,
// so the values can never be equal. Similarly for all other "or equals"
// operators.
@@ -1121,7 +1143,7 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
// use the sadd_with_overflow intrinsic to efficiently compute both the
// result and the overflow bit.
Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
- Function *F = Intrinsic::getDeclaration(
+ Function *F = Intrinsic::getOrInsertDeclaration(
I.getModule(), Intrinsic::sadd_with_overflow, NewType);
InstCombiner::BuilderTy &Builder = IC.Builder;
@@ -1153,7 +1175,7 @@ Instruction *InstCombinerImpl::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
// This fold is only valid for equality predicates.
if (!I.isEquality())
return nullptr;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *X, *Y, *Zero;
if (!match(&I, m_ICmp(Pred, m_OneUse(m_IRem(m_Value(X), m_Value(Y))),
m_CombineAnd(m_Zero(), m_Value(Zero)))))
@@ -1170,7 +1192,7 @@ Instruction *InstCombinerImpl::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
/// by one-less-than-bitwidth into a sign test on the original value.
Instruction *InstCombinerImpl::foldSignBitTest(ICmpInst &I) {
Instruction *Val;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero())))
return nullptr;
@@ -1384,7 +1406,7 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
};
for (BranchInst *BI : DC.conditionsFor(X)) {
- ICmpInst::Predicate DomPred;
+ CmpPredicate DomPred;
const APInt *DomC;
if (!match(BI->getCondition(),
m_ICmp(DomPred, m_Specific(X), m_APInt(DomC))))
@@ -1497,7 +1519,7 @@ Instruction *
InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp,
const SimplifyQuery &Q) {
Value *X, *Y;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
bool YIsSExt = false;
// Try to match icmp (trunc X), (trunc Y)
if (match(&Cmp, m_ICmp(Pred, m_Trunc(m_Value(X)), m_Trunc(m_Value(Y))))) {
@@ -1725,7 +1747,8 @@ Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp,
// preferable because it allows the C2 << Y expression to be hoisted out of a
// loop if Y is invariant and X is not.
if (Shift->hasOneUse() && C1.isZero() && Cmp.isEquality() &&
- !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) {
+ !Shift->isArithmeticShift() &&
+ ((!IsShl && C2.isOne()) || !isa<Constant>(Shift->getOperand(0)))) {
// Compute C2 << Y.
Value *NewShift =
IsShl ? Builder.CreateLShr(And->getOperand(1), Shift->getOperand(1))
@@ -1733,7 +1756,7 @@ Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp,
// Compute X & (C2 << Y).
Value *NewAnd = Builder.CreateAnd(Shift->getOperand(0), NewShift);
- return replaceOperand(Cmp, 0, NewAnd);
+ return new ICmpInst(Cmp.getPredicate(), NewAnd, Cmp.getOperand(1));
}
return nullptr;
@@ -1757,6 +1780,17 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
if (!match(And, m_And(m_Value(X), m_APInt(C2))))
return nullptr;
+ // (and X, highmask) s> [0, ~highmask] --> X s> ~highmask
+ if (Cmp.getPredicate() == ICmpInst::ICMP_SGT && C1.ule(~*C2) &&
+ C2->isNegatedPowerOf2())
+ return new ICmpInst(ICmpInst::ICMP_SGT, X,
+ ConstantInt::get(X->getType(), ~*C2));
+ // (and X, highmask) s< [1, -highmask] --> X s< -highmask
+ if (Cmp.getPredicate() == ICmpInst::ICMP_SLT && !C1.isSignMask() &&
+ (C1 - 1).ule(~*C2) && C2->isNegatedPowerOf2() && !C2->isSignMask())
+ return new ICmpInst(ICmpInst::ICMP_SLT, X,
+ ConstantInt::get(X->getType(), -*C2));
+
// Don't perform the following transforms if the AND has multiple uses
if (!And->hasOneUse())
return nullptr;
@@ -1839,7 +1873,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
/*HasNUW=*/true),
One, Or->getName());
Value *NewAnd = Builder.CreateAnd(A, NewOr, And->getName());
- return replaceOperand(Cmp, 0, NewAnd);
+ return new ICmpInst(Cmp.getPredicate(), NewAnd, Cmp.getOperand(1));
}
}
}
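A sketch of the new highmask fold on i8 (invented names; -16 is a negated power of two and 5 u<= ~(-16) = 15):
  %m = and i8 %x, -16
  %c = icmp sgt i8 %m, 5
  ; folds to: %c = icmp sgt i8 %x, 15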
@@ -1972,6 +2006,22 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
return new ICmpInst(Pred, LShr, Constant::getNullValue(LShr->getType()));
}
+ // (icmp eq/ne (and (add A, Addend), Msk), C)
+ // -> (icmp eq/ne (and A, Msk), (and (sub C, Addend), Msk))
+ {
+ Value *A;
+ const APInt *Addend, *Msk;
+ if (match(And, m_And(m_OneUse(m_Add(m_Value(A), m_APInt(Addend))),
+ m_APInt(Msk))) &&
+ Msk->isMask() && C.ule(*Msk)) {
+ APInt NewComparand = (C - *Addend) & *Msk;
+ Value *MaskA = Builder.CreateAnd(A, ConstantInt::get(A->getType(), *Msk));
+ return new ICmpInst(
+ Pred, MaskA,
+ Constant::getIntegerValue(MaskA->getType(), NewComparand));
+ }
+ }
+
return nullptr;
}
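A hypothetical instance of the add-under-mask fold (invented names; Msk = 15 is a low-bit mask, C = 3 u<= Msk, and the new comparand is (3 - 5) & 15 = 14):
  %s = add i8 %a, 5
  %m = and i8 %s, 15
  %c = icmp eq i8 %m, 3
  ; folds to:
  ; %m2 = and i8 %a, 15
  ; %c  = icmp eq i8 %m2, 14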
@@ -2226,18 +2276,24 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
return NewC ? new ICmpInst(Pred, X, NewC) : nullptr;
}
-/// Fold icmp (shl 1, Y), C.
-static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
- const APInt &C) {
+/// Fold icmp (shl nuw C2, Y), C.
+static Instruction *foldICmpShlLHSC(ICmpInst &Cmp, Instruction *Shl,
+ const APInt &C) {
Value *Y;
- if (!match(Shl, m_Shl(m_One(), m_Value(Y))))
+ const APInt *C2;
+ if (!match(Shl, m_NUWShl(m_APInt(C2), m_Value(Y))))
return nullptr;
Type *ShiftType = Shl->getType();
unsigned TypeBits = C.getBitWidth();
- bool CIsPowerOf2 = C.isPowerOf2();
ICmpInst::Predicate Pred = Cmp.getPredicate();
if (Cmp.isUnsigned()) {
+ if (C2->isZero() || C2->ugt(C))
+ return nullptr;
+ APInt Div, Rem;
+ APInt::udivrem(C, *C2, Div, Rem);
+ bool CIsPowerOf2 = Rem.isZero() && Div.isPowerOf2();
+
// (1 << Y) pred C -> Y pred Log2(C)
if (!CIsPowerOf2) {
// (1 << Y) < 30 -> Y <= 4
@@ -2250,9 +2306,9 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
Pred = ICmpInst::ICMP_UGT;
}
- unsigned CLog2 = C.logBase2();
+ unsigned CLog2 = Div.logBase2();
return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2));
- } else if (Cmp.isSigned()) {
+ } else if (Cmp.isSigned() && C2->isOne()) {
Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1);
// (1 << Y) > 0 -> Y != 31
// (1 << Y) > C -> Y != 31 if C is negative.
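For example (hypothetical IR; 64 udiv 4 = 16 with zero remainder, and 16 is a power of two, so the predicate is kept and Log2(16) = 4):
  %s = shl nuw i32 4, %y
  %c = icmp ult i32 %s, 64
  ; folds to: %c = icmp ult i32 %y, 4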
@@ -2306,7 +2362,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
const APInt *ShiftAmt;
if (!match(Shl->getOperand(1), m_APInt(ShiftAmt)))
- return foldICmpShlOne(Cmp, Shl, C);
+ return foldICmpShlLHSC(Cmp, Shl, C);
// Check that the shift amount is in range. If not, don't perform undefined
// shifts. When the shift is visited, it will be simplified.
@@ -2429,9 +2485,8 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
// icmp ule i64 (shl X, 32), 8589934592 ->
// icmp ule i32 (trunc X, i32), 2 ->
// icmp ult i32 (trunc X, i32), 3
- if (auto FlippedStrictness =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(
- Pred, ConstantInt::get(ShType->getContext(), C))) {
+ if (auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(
+ Pred, ConstantInt::get(ShType->getContext(), C))) {
CmpPred = FlippedStrictness->first;
RHSC = cast<ConstantInt>(FlippedStrictness->second)->getValue();
}
@@ -2619,10 +2674,41 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
BinaryOperator *SRem,
const APInt &C) {
+ const ICmpInst::Predicate Pred = Cmp.getPredicate();
+ if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT) {
+ // Canonicalize unsigned predicates to signed:
+ // (X s% DivisorC) u> C -> (X s% DivisorC) s< 0
+ // iff (C s< 0 ? ~C : C) u>= abs(DivisorC)-1
+ // (X s% DivisorC) u< C+1 -> (X s% DivisorC) s> -1
+ // iff (C+1 s< 0 ? ~C : C) u>= abs(DivisorC)-1
+
+ const APInt *DivisorC;
+ if (!match(SRem->getOperand(1), m_APInt(DivisorC)))
+ return nullptr;
+
+ APInt NormalizedC = C;
+ if (Pred == ICmpInst::ICMP_ULT) {
+ assert(!NormalizedC.isZero() &&
+ "ult X, 0 should have been simplified already.");
+ --NormalizedC;
+ }
+ if (C.isNegative())
+ NormalizedC.flipAllBits();
+ assert(!DivisorC->isZero() &&
+ "srem X, 0 should have been simplified already.");
+ if (!NormalizedC.uge(DivisorC->abs() - 1))
+ return nullptr;
+
+ Type *Ty = SRem->getType();
+ if (Pred == ICmpInst::ICMP_UGT)
+ return new ICmpInst(ICmpInst::ICMP_SLT, SRem,
+ ConstantInt::getNullValue(Ty));
+ return new ICmpInst(ICmpInst::ICMP_SGT, SRem,
+ ConstantInt::getAllOnesValue(Ty));
+ }
// Match an 'is positive' or 'is negative' comparison of remainder by a
// constant power-of-2 value:
// (X % pow2C) sgt/slt 0
- const ICmpInst::Predicate Pred = Cmp.getPredicate();
if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_SLT &&
Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)
return nullptr;
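A sketch of the canonicalization on i8 (invented names; the remainder lies in [-2, 2], and C = 2 u>= abs(3) - 1, so only the negative remainders satisfy u> 2):
  %r = srem i8 %x, 3
  %c = icmp ugt i8 %r, 2
  ; canonicalized to: %c = icmp slt i8 %r, 0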
@@ -3035,12 +3121,12 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
unsigned BW = C.getBitWidth();
std::bitset<4> Table;
auto ComputeTable = [&](bool Op0Val, bool Op1Val) {
- int Res = 0;
+ APInt Res(BW, 0);
if (Op0Val)
- Res += isa<ZExtInst>(Ext0) ? 1 : -1;
+ Res += APInt(BW, isa<ZExtInst>(Ext0) ? 1 : -1, /*isSigned=*/true);
if (Op1Val)
- Res += isa<ZExtInst>(Ext1) ? 1 : -1;
- return ICmpInst::compare(APInt(BW, Res, true), C, Pred);
+ Res += APInt(BW, isa<ZExtInst>(Ext1) ? 1 : -1, /*isSigned=*/true);
+ return ICmpInst::compare(Res, C, Pred);
};
Table[0] = ComputeTable(false, false);
@@ -3076,6 +3162,12 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC));
}
+ if (ICmpInst::isUnsigned(Pred) && Add->hasNoSignedWrap() &&
+ C.isNonNegative() && (C - *C2).isNonNegative() &&
+ computeConstantRange(X, /*ForSigned=*/true).add(*C2).isAllNonNegative())
+ return new ICmpInst(ICmpInst::getSignedPredicate(Pred), X,
+ ConstantInt::get(Ty, C - *C2));
+
auto CR = ConstantRange::makeExactICmpRegion(Pred, C).subtract(*C2);
const APInt &Upper = CR.getUpper();
const APInt &Lower = CR.getLower();
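A hypothetical instance of the new nsw-add fold (invented names; the mask makes X + 5 provably non-negative, so the unsigned compare can become signed):
  %x = and i32 %a, 127
  %s = add nsw i32 %x, 5
  %c = icmp ult i32 %s, 20
  ; folds to: %c = icmp slt i32 %x, 15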
@@ -3152,6 +3244,29 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
Builder.CreateAdd(X, ConstantInt::get(Ty, *C2 - C - 1)),
ConstantInt::get(Ty, ~C));
+ // zext(V) + C2 pred C -> V + C3 pred' C4
+ Value *V;
+ if (match(X, m_ZExt(m_Value(V)))) {
+ Type *NewCmpTy = V->getType();
+ unsigned NewCmpBW = NewCmpTy->getScalarSizeInBits();
+ if (shouldChangeType(Ty, NewCmpTy)) {
+ if (CR.getActiveBits() <= NewCmpBW) {
+ ConstantRange SrcCR = CR.truncate(NewCmpBW);
+ CmpInst::Predicate EquivPred;
+ APInt EquivInt;
+ APInt EquivOffset;
+
+ SrcCR.getEquivalentICmp(EquivPred, EquivInt, EquivOffset);
+ return new ICmpInst(
+ EquivPred,
+ EquivOffset.isZero()
+ ? V
+ : Builder.CreateAdd(V, ConstantInt::get(NewCmpTy, EquivOffset)),
+ ConstantInt::get(NewCmpTy, EquivInt));
+ }
+ }
+ }
+
return nullptr;
}
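A sketch of the narrowing (invented names; assumes shouldChangeType accepts i32 -> i8; the exact range [10, 11) fits in 8 bits and needs no offset):
  %z = zext i8 %v to i32
  %s = add i32 %z, 10
  %c = icmp eq i32 %s, 20
  ; folds to: %c = icmp eq i8 %v, 10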
@@ -3166,7 +3281,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
// i32 Equal,
// i32 (select i1 (a < b), i32 Less, i32 Greater)
// where Equal, Less and Greater are placeholders for any three constants.
- ICmpInst::Predicate PredA;
+ CmpPredicate PredA;
if (!match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) ||
!ICmpInst::isEquality(PredA))
return false;
@@ -3177,7 +3292,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
std::swap(EqualVal, UnequalVal);
if (!match(EqualVal, m_ConstantInt(Equal)))
return false;
- ICmpInst::Predicate PredB;
+ CmpPredicate PredB;
Value *LHS2, *RHS2;
if (!match(UnequalVal, m_Select(m_ICmp(PredB, m_Value(LHS2), m_Value(RHS2)),
m_ConstantInt(Less), m_ConstantInt(Greater))))
@@ -3195,8 +3310,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) {
// x sgt C-1 <--> x sge C <--> not(x slt C)
auto FlippedStrictness =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(
- PredB, cast<Constant>(RHS2));
+ getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2));
if (!FlippedStrictness)
return false;
assert(FlippedStrictness->first == ICmpInst::ICMP_SGE &&
@@ -3529,6 +3643,53 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant(
Value *And = Builder.CreateAnd(BOp0, NotBOC);
return new ICmpInst(Pred, And, NotBOC);
}
+ // (icmp eq (or (select cond, 0, NonZero), Other), 0)
+ // -> (and cond, (icmp eq Other, 0))
+ // (icmp ne (or (select cond, NonZero, 0), Other), 0)
+ // -> (or cond, (icmp ne Other, 0))
+ Value *Cond, *TV, *FV, *Other, *Sel;
+ if (C.isZero() &&
+ match(BO,
+ m_OneUse(m_c_Or(m_CombineAnd(m_Value(Sel),
+ m_Select(m_Value(Cond), m_Value(TV),
+ m_Value(FV))),
+ m_Value(Other)))) &&
+ Cond->getType() == Cmp.getType()) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&Cmp);
+ // Easy case is if eq/ne matches whether 0 is trueval/falseval.
+ if (Pred == ICmpInst::ICMP_EQ
+ ? (match(TV, m_Zero()) && isKnownNonZero(FV, Q))
+ : (match(FV, m_Zero()) && isKnownNonZero(TV, Q))) {
+ Value *Cmp = Builder.CreateICmp(
+ Pred, Other, Constant::getNullValue(Other->getType()));
+ return BinaryOperator::Create(
+ Pred == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or, Cmp,
+ Cond);
+ }
+ // Harder case is if eq/ne matches whether 0 is falseval/trueval. In this
+ // case we need to invert the select condition so we need to be careful to
+ // avoid creating extra instructions.
+ // (icmp ne (or (select cond, 0, NonZero), Other), 0)
+ // -> (or (not cond), (icmp ne Other, 0))
+ // (icmp eq (or (select cond, NonZero, 0), Other), 0)
+ // -> (and (not cond), (icmp eq Other, 0))
+ //
+ // Only do this if the inner select has one use, in which case we are
+ // replacing `select` with `(not cond)`. Otherwise, we will create more
+ // uses. NB: Trying to freely invert cond doesn't make sense here, as if
+ // cond was freely invertible, the select arms would have been inverted.
+ if (Sel->hasOneUse() &&
+ (Pred == ICmpInst::ICMP_EQ
+ ? (match(FV, m_Zero()) && isKnownNonZero(TV, Q))
+ : (match(TV, m_Zero()) && isKnownNonZero(FV, Q)))) {
+ Value *NotCond = Builder.CreateNot(Cond);
+ Value *Cmp = Builder.CreateICmp(
+ Pred, Other, Constant::getNullValue(Other->getType()));
+ return BinaryOperator::Create(
+ Pred == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or, Cmp,
+ NotCond);
+ }
+ }
break;
}
case Instruction::UDiv:
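A hypothetical instance of the easy case (invented names; 1 is known non-zero, so the or is zero iff %other is zero and the select took its zero arm):
  %sel = select i1 %cond, i32 0, i32 1
  %or  = or i32 %sel, %other
  %c   = icmp eq i32 %or, 0
  ; folds to:
  ; %z = icmp eq i32 %other, 0
  ; %c = and i1 %z, %cond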
@@ -3849,8 +4010,8 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
}
static Instruction *
-foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred,
- SaturatingInst *II, const APInt &C,
+foldICmpUSubSatOrUAddSatWithConstant(CmpPredicate Pred, SaturatingInst *II,
+ const APInt &C,
InstCombiner::BuilderTy &Builder) {
// This transform may end up producing more than one instruction for the
// intrinsic, so limit it to one user of the intrinsic.
@@ -3934,7 +4095,7 @@ foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred,
}
static Instruction *
-foldICmpOfCmpIntrinsicWithConstant(ICmpInst::Predicate Pred, IntrinsicInst *I,
+foldICmpOfCmpIntrinsicWithConstant(CmpPredicate Pred, IntrinsicInst *I,
const APInt &C,
InstCombiner::BuilderTy &Builder) {
std::optional<ICmpInst::Predicate> NewPredicate = std::nullopt;
@@ -3965,6 +4126,16 @@ foldICmpOfCmpIntrinsicWithConstant(ICmpInst::Predicate Pred, IntrinsicInst *I,
NewPredicate = ICmpInst::ICMP_ULE;
break;
+ case ICmpInst::ICMP_ULT:
+ if (C.ugt(1))
+ NewPredicate = ICmpInst::ICMP_UGE;
+ break;
+
+ case ICmpInst::ICMP_UGT:
+ if (!C.isZero() && !C.isAllOnes())
+ NewPredicate = ICmpInst::ICMP_ULT;
+ break;
+
default:
break;
}
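For the new ult case, a sketch (invented names; ucmp returns -1, 0 or 1, and since C u> 1 the results 0 and 1 pass the test while only -1, i.e. all-ones, fails):
  %u = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
  %c = icmp ult i8 %u, 7
  ; folds to: %c = icmp uge i32 %a, %b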
@@ -4123,9 +4294,8 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
return nullptr;
}
-Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
- SelectInst *SI, Value *RHS,
- const ICmpInst &I) {
+Instruction *InstCombinerImpl::foldSelectICmp(CmpPredicate Pred, SelectInst *SI,
+ Value *RHS, const ICmpInst &I) {
// Try to fold the comparison into the select arms, which will cause the
// select to be converted into a logical and/or.
auto SimplifyOp = [&](Value *Op, bool SelectCondIsTrue) -> Value * {
@@ -4146,6 +4316,14 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
if (Op2)
CI = dyn_cast<ConstantInt>(Op2);
+ auto Simplifies = [&](Value *Op, unsigned Idx) {
+ // A comparison of ucmp/scmp with a constant will fold into an icmp.
+ const APInt *Dummy;
+ return Op ||
+ (isa<CmpIntrinsic>(SI->getOperand(Idx)) &&
+ SI->getOperand(Idx)->hasOneUse() && match(RHS, m_APInt(Dummy)));
+ };
+
// We only want to perform this transformation if it will not lead to
// additional code. This is true if either both sides of the select
// fold to a constant (in which case the icmp is replaced with a select
@@ -4156,7 +4334,7 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
bool Transform = false;
if (Op1 && Op2)
Transform = true;
- else if (Op1 || Op2) {
+ else if (Simplifies(Op1, 1) || Simplifies(Op2, 2)) {
// Local case
if (SI->hasOneUse())
Transform = true;
@@ -4286,7 +4464,7 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
/// The Mask can be a constant, too.
/// For some predicates, the operands are commutative.
/// For others, x can only be on a specific side.
-static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
+static Value *foldICmpWithLowBitMaskedVal(CmpPredicate Pred, Value *Op0,
Value *Op1, const SimplifyQuery &Q,
InstCombiner &IC) {
@@ -4418,7 +4596,7 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
static Value *
foldICmpWithTruncSignExtendedVal(ICmpInst &I,
InstCombiner::BuilderTy &Builder) {
- ICmpInst::Predicate SrcPred;
+ CmpPredicate SrcPred;
Value *X;
const APInt *C0, *C1; // FIXME: non-splats, potentially with undef.
// We are ok with 'shl' having multiple uses, but 'ashr' must be one-use.
@@ -4664,7 +4842,7 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
/// Note that the comparison is commutative, while inverted (u>=, ==) predicate
/// will mean that we are looking for the opposite answer.
Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *X, *Y;
Instruction *Mul;
Instruction *Div;
@@ -4709,12 +4887,10 @@ Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
if (MulHadOtherUses)
Builder.SetInsertPoint(Mul);
- Function *F = Intrinsic::getDeclaration(I.getModule(),
- Div->getOpcode() == Instruction::UDiv
- ? Intrinsic::umul_with_overflow
- : Intrinsic::smul_with_overflow,
- X->getType());
- CallInst *Call = Builder.CreateCall(F, {X, Y}, "mul");
+ CallInst *Call = Builder.CreateIntrinsic(
+ Div->getOpcode() == Instruction::UDiv ? Intrinsic::umul_with_overflow
+ : Intrinsic::smul_with_overflow,
+ X->getType(), {X, Y}, /*FMFSource=*/nullptr, "mul");
// If the multiplication was used elsewhere, to ensure that we don't leave
// "duplicate" instructions, replace uses of that original multiplication
@@ -4736,7 +4912,7 @@ Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
static Instruction *foldICmpXNegX(ICmpInst &I,
InstCombiner::BuilderTy &Builder) {
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *X;
if (match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X)))) {
@@ -4935,6 +5111,18 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
}
}
+ // (icmp eq/ne (and X, -P2), INT_MIN)
+ // -> (icmp slt/sge X, INT_MIN + P2)
+ if (ICmpInst::isEquality(Pred) && BO0 &&
+ match(I.getOperand(1), m_SignMask()) &&
+ match(BO0, m_And(m_Value(), m_NegatedPower2OrZero()))) {
+ // Will constant-fold.
+ Value *NewC = Builder.CreateSub(I.getOperand(1), BO0->getOperand(1));
+ return new ICmpInst(Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_SLT
+ : ICmpInst::ICMP_SGE,
+ BO0->getOperand(0), NewC);
+ }
+
{
// Similar to above: an unsigned overflow comparison may use offset + mask:
// ((Op1 + C) & C) u< Op1 --> Op1 != 0
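A sketch on i8 (invented names; -32 is a negated power of two, and -128 - (-32) constant-folds to -96):
  %m = and i8 %x, -32
  %c = icmp eq i8 %m, -128
  ; folds to: %c = icmp slt i8 %x, -96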
@@ -5196,32 +5384,55 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
{
// Try to remove shared multiplier from comparison:
- // X * Z u{lt/le/gt/ge}/eq/ne Y * Z
+ // X * Z pred Y * Z
Value *X, *Y, *Z;
- if (Pred == ICmpInst::getUnsignedPredicate(Pred) &&
- ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) &&
- match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) ||
- (match(Op0, m_Mul(m_Value(Z), m_Value(X))) &&
- match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))))) {
- bool NonZero;
- if (ICmpInst::isEquality(Pred)) {
- KnownBits ZKnown = computeKnownBits(Z, 0, &I);
- // if Z % 2 != 0
- // X * Z eq/ne Y * Z -> X eq/ne Y
- if (ZKnown.countMaxTrailingZeros() == 0)
- return new ICmpInst(Pred, X, Y);
- NonZero = !ZKnown.One.isZero() || isKnownNonZero(Z, Q);
- // if Z != 0 and nsw(X * Z) and nsw(Y * Z)
- // X * Z eq/ne Y * Z -> X eq/ne Y
- if (NonZero && BO0 && BO1 && Op0HasNSW && Op1HasNSW)
+ if ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) &&
+ match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) ||
+ (match(Op0, m_Mul(m_Value(Z), m_Value(X))) &&
+ match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y))))) {
+ if (ICmpInst::isSigned(Pred)) {
+ if (Op0HasNSW && Op1HasNSW) {
+ KnownBits ZKnown = computeKnownBits(Z, 0, &I);
+ if (ZKnown.isStrictlyPositive())
+ return new ICmpInst(Pred, X, Y);
+ if (ZKnown.isNegative())
+ return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), X, Y);
+ Value *LessThan = simplifyICmpInst(ICmpInst::ICMP_SLT, X, Y,
+ SQ.getWithInstruction(&I));
+ if (LessThan && match(LessThan, m_One()))
+ return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), Z,
+ Constant::getNullValue(Z->getType()));
+ Value *GreaterThan = simplifyICmpInst(ICmpInst::ICMP_SGT, X, Y,
+ SQ.getWithInstruction(&I));
+ if (GreaterThan && match(GreaterThan, m_One()))
+ return new ICmpInst(Pred, Z, Constant::getNullValue(Z->getType()));
+ }
+ } else {
+ bool NonZero;
+ if (ICmpInst::isEquality(Pred)) {
+ // If X != Y, fold (X *nw Z) eq/ne (Y *nw Z) -> Z eq/ne 0
+ if (((Op0HasNSW && Op1HasNSW) || (Op0HasNUW && Op1HasNUW)) &&
+ isKnownNonEqual(X, Y, DL, &AC, &I, &DT))
+ return new ICmpInst(Pred, Z, Constant::getNullValue(Z->getType()));
+
+ KnownBits ZKnown = computeKnownBits(Z, 0, &I);
+ // if Z % 2 != 0
+ // X * Z eq/ne Y * Z -> X eq/ne Y
+ if (ZKnown.countMaxTrailingZeros() == 0)
+ return new ICmpInst(Pred, X, Y);
+ NonZero = !ZKnown.One.isZero() || isKnownNonZero(Z, Q);
+ // if Z != 0 and nsw(X * Z) and nsw(Y * Z)
+ // X * Z eq/ne Y * Z -> X eq/ne Y
+ if (NonZero && BO0 && BO1 && Op0HasNSW && Op1HasNSW)
+ return new ICmpInst(Pred, X, Y);
+ } else
+ NonZero = isKnownNonZero(Z, Q);
+
+ // If Z != 0 and nuw(X * Z) and nuw(Y * Z)
+ // X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y
+ if (NonZero && BO0 && BO1 && Op0HasNUW && Op1HasNUW)
return new ICmpInst(Pred, X, Y);
- } else
- NonZero = isKnownNonZero(Z, Q);
-
- // If Z != 0 and nuw(X * Z) and nuw(Y * Z)
- // X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y
- if (NonZero && BO0 && BO1 && Op0HasNUW && Op1HasNUW)
- return new ICmpInst(Pred, X, Y);
+ }
}
}
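A hypothetical instance of the new signed case (invented names; assumes %z can be proven strictly positive, e.g. from known bits):
  %m1 = mul nsw i32 %x, %z
  %m2 = mul nsw i32 %y, %z
  %c  = icmp sgt i32 %m1, %m2
  ; folds to: %c = icmp sgt i32 %x, %y
  ; (a known-negative %z would give the swapped predicate instead)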
@@ -5373,8 +5584,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
/// Fold icmp Pred min|max(X, Y), Z.
Instruction *InstCombinerImpl::foldICmpWithMinMax(Instruction &I,
MinMaxIntrinsic *MinMax,
- Value *Z,
- ICmpInst::Predicate Pred) {
+ Value *Z, CmpPredicate Pred) {
Value *X = MinMax->getLHS();
Value *Y = MinMax->getRHS();
if (ICmpInst::isSigned(Pred) && !MinMax->isSigned())
@@ -5772,8 +5982,7 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
// -> icmp eq/ne X, rotate-left(X)
// We generally try to convert rotate-right -> rotate-left, this just
// canonicalizes another case.
- CmpInst::Predicate PredUnused = Pred;
- if (match(&I, m_c_ICmp(PredUnused, m_Value(A),
+ if (match(&I, m_c_ICmp(m_Value(A),
m_OneUse(m_Intrinsic<Intrinsic::fshr>(
m_Deferred(A), m_Deferred(A), m_Value(B))))))
return new ICmpInst(
@@ -5783,8 +5992,7 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
// Canonicalize:
// icmp eq/ne OneUse(A ^ Cst), B --> icmp eq/ne (A ^ B), Cst
Constant *Cst;
- if (match(&I, m_c_ICmp(PredUnused,
- m_OneUse(m_Xor(m_Value(A), m_ImmConstant(Cst))),
+ if (match(&I, m_c_ICmp(m_OneUse(m_Xor(m_Value(A), m_ImmConstant(Cst))),
m_CombineAnd(m_Value(B), m_Unless(m_ImmConstant())))))
return new ICmpInst(Pred, Builder.CreateXor(A, B), Cst);
@@ -5795,13 +6003,12 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
m_c_Xor(m_Value(B), m_Deferred(A))),
m_Sub(m_Value(B), m_Deferred(A)));
std::optional<bool> IsZero = std::nullopt;
- if (match(&I, m_c_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)),
+ if (match(&I, m_c_ICmp(m_OneUse(m_c_And(m_Value(A), m_Matcher)),
m_Deferred(A))))
IsZero = false;
// (icmp eq/ne (and (add/sub/xor X, P2), P2), 0)
else if (match(&I,
- m_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)),
- m_Zero())))
+ m_ICmp(m_OneUse(m_c_And(m_Value(A), m_Matcher)), m_Zero())))
IsZero = true;
if (IsZero && isKnownToBeAPowerOfTwo(A, /* OrZero */ true, /*Depth*/ 0, &I))
@@ -5829,32 +6036,15 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) {
return nullptr;
// This matches patterns corresponding to tests of the signbit as well as:
- // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?)
- // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?)
- APInt Mask;
- if (decomposeBitTestICmp(Op0, Op1, Pred, X, Mask, true /* WithTrunc */)) {
- Value *And = Builder.CreateAnd(X, Mask);
- Constant *Zero = ConstantInt::getNullValue(X->getType());
- return new ICmpInst(Pred, And, Zero);
+ // (trunc X) pred C2 --> (X & Mask) == C
+ if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true,
+ /*AllowNonZeroC=*/true)) {
+ Value *And = Builder.CreateAnd(Res->X, Res->Mask);
+ Constant *C = ConstantInt::get(Res->X->getType(), Res->C);
+ return new ICmpInst(Res->Pred, And, C);
}
unsigned SrcBits = X->getType()->getScalarSizeInBits();
- if (Pred == ICmpInst::ICMP_ULT && C->isNegatedPowerOf2()) {
- // If C is a negative power-of-2 (high-bit mask):
- // (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?)
- Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits));
- Value *And = Builder.CreateAnd(X, MaskC);
- return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC);
- }
-
- if (Pred == ICmpInst::ICMP_UGT && (~*C).isPowerOf2()) {
- // If C is not-of-power-of-2 (one clear bit):
- // (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?)
- Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits));
- Value *And = Builder.CreateAnd(X, MaskC);
- return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC);
- }
-
if (auto *II = dyn_cast<IntrinsicInst>(X)) {
if (II->getIntrinsicID() == Intrinsic::cttz ||
II->getIntrinsicID() == Intrinsic::ctlz) {
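A sketch of the generalized bit-test decomposition (invented names; the low 8 bits are u< 8 iff bits 3..7 are clear, i.e. X & 248 == 0):
  %t = trunc i32 %x to i8
  %c = icmp ult i8 %t, 8
  ; decomposes to:
  ; %m = and i32 %x, 248
  ; %c = icmp eq i32 %m, 0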
@@ -6013,12 +6203,12 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
// Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
// integer type is the same size as the pointer type.
- auto CompatibleSizes = [&](Type *SrcTy, Type *DestTy) {
- if (isa<VectorType>(SrcTy)) {
- SrcTy = cast<VectorType>(SrcTy)->getElementType();
- DestTy = cast<VectorType>(DestTy)->getElementType();
+ auto CompatibleSizes = [&](Type *PtrTy, Type *IntTy) {
+ if (isa<VectorType>(PtrTy)) {
+ PtrTy = cast<VectorType>(PtrTy)->getElementType();
+ IntTy = cast<VectorType>(IntTy)->getElementType();
}
- return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth();
+ return DL.getPointerTypeSizeInBits(PtrTy) == IntTy->getIntegerBitWidth();
};
if (CastOp0->getOpcode() == Instruction::PtrToInt &&
CompatibleSizes(SrcTy, DestTy)) {
@@ -6035,6 +6225,22 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1);
}
+ // Do the same in the other direction for icmp (inttoptr x), (inttoptr/c).
+ if (CastOp0->getOpcode() == Instruction::IntToPtr &&
+ CompatibleSizes(DestTy, SrcTy)) {
+ Value *NewOp1 = nullptr;
+ if (auto *IntToPtrOp1 = dyn_cast<IntToPtrInst>(ICmp.getOperand(1))) {
+ Value *IntSrc = IntToPtrOp1->getOperand(0);
+ if (IntSrc->getType() == Op0Src->getType())
+ NewOp1 = IntToPtrOp1->getOperand(0);
+ } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) {
+ NewOp1 = ConstantFoldConstant(ConstantExpr::getPtrToInt(RHSC, SrcTy), DL);
+ }
+
+ if (NewOp1)
+ return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1);
+ }
+
if (Instruction *R = foldICmpWithTrunc(ICmp))
return R;
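A hypothetical instance of the new inttoptr direction (invented names; assumes 64-bit pointers so the integer and pointer widths match):
  %p0 = inttoptr i64 %x to ptr
  %p1 = inttoptr i64 %y to ptr
  %c  = icmp eq ptr %p0, %p1
  ; folds to: %c = icmp eq i64 %x, %y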
@@ -6243,9 +6449,9 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
MulA = Builder.CreateZExt(A, MulType);
if (WidthB < MulWidth)
MulB = Builder.CreateZExt(B, MulType);
- Function *F = Intrinsic::getDeclaration(
- I.getModule(), Intrinsic::umul_with_overflow, MulType);
- CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul");
+ CallInst *Call =
+ Builder.CreateIntrinsic(Intrinsic::umul_with_overflow, MulType,
+ {MulA, MulB}, /*FMFSource=*/nullptr, "umul");
IC.addToWorklist(MulInstr);
// If there are uses of mul result other than the comparison, we know that
@@ -6461,6 +6667,16 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
return &I;
}
+ if (!isa<Constant>(Op0) && Op0Known.isConstant())
+ return new ICmpInst(
+ Pred, ConstantExpr::getIntegerValue(Ty, Op0Known.getConstant()), Op1);
+ if (!isa<Constant>(Op1) && Op1Known.isConstant())
+ return new ICmpInst(
+ Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Known.getConstant()));
+
+ if (std::optional<bool> Res = ICmpInst::compare(Op0Known, Op1Known, Pred))
+ return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *Res));
+
// Given the known and unknown bits, compute a range that the LHS could be
// in. Compute the Min, Max and RHS values based on the known bits. For the
// EQ and NE we use unsigned values.
@@ -6478,14 +6694,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
Op1Max = Op1Known.getMaxValue();
}
- // If Min and Max are known to be the same, then SimplifyDemandedBits figured
- // out that the LHS or RHS is a constant. Constant fold this now, so that
- // code below can assume that Min != Max.
- if (!isa<Constant>(Op0) && Op0Min == Op0Max)
- return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1);
- if (!isa<Constant>(Op1) && Op1Min == Op1Max)
- return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
-
// Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a
// min/max canonical compare with some other compare. That could lead to
// conflict with select canonicalization and infinite looping.
@@ -6567,13 +6775,9 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
// simplify this comparison. For example, (x&4) < 8 is always true.
switch (Pred) {
default:
- llvm_unreachable("Unknown icmp opcode!");
+ break;
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_NE: {
- if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
- return replaceInstUsesWith(
- I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE));
-
// If all bits are known zero except for one, then we know at most one bit
// is set. If the comparison is against zero, then this is a check to see if
// *that* bit is set.
@@ -6613,78 +6817,34 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
ConstantInt::getNullValue(Op1->getType()));
break;
}
- case ICmpInst::ICMP_ULT: {
- if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- }
- case ICmpInst::ICMP_UGT: {
- if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- }
- case ICmpInst::ICMP_SLT: {
- if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- }
- case ICmpInst::ICMP_SGT: {
- if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- break;
- }
case ICmpInst::ICMP_SGE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
- if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_SLE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
- if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_UGE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
- if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_ULE:
- assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
- if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
- return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
- return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
}
// Turn a signed comparison into an unsigned one if both operands are known to
- // have the same sign.
- if (I.isSigned() &&
+ // have the same sign. Set samesign if possible (except for equality
+ // predicates).
+ if ((I.isSigned() || (I.isUnsigned() && !I.hasSameSign())) &&
((Op0Known.Zero.isNegative() && Op1Known.Zero.isNegative()) ||
- (Op0Known.One.isNegative() && Op1Known.One.isNegative())))
- return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
+ (Op0Known.One.isNegative() && Op1Known.One.isNegative()))) {
+ I.setPredicate(I.getUnsignedPredicate());
+ I.setSameSign();
+ return &I;
+ }
return nullptr;
}
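A sketch of the samesign variant (invented names; both shifts clear the sign bit, so the operands are known non-negative):
  %a = lshr i8 %x, 1
  %b = lshr i8 %y, 1
  %c = icmp slt i8 %a, %b
  ; becomes: %c = icmp samesign ult i8 %a, %b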
@@ -6693,7 +6853,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
/// then try to reduce patterns based on that limit.
Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
Value *X, *Y;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
// X must be 0 and bool must be true for "ULT":
// X <u (zext i1 Y) --> (X == 0) & Y
@@ -6708,7 +6868,7 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y);
// icmp eq/ne X, (zext/sext (icmp eq/ne X, C))
- ICmpInst::Predicate Pred1, Pred2;
+ CmpPredicate Pred1, Pred2;
const APInt *C;
Instruction *ExtI;
if (match(&I, m_c_ICmp(Pred1, m_Value(X),
@@ -6777,79 +6937,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
return nullptr;
}
-std::optional<std::pair<CmpInst::Predicate, Constant *>>
-InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
- Constant *C) {
- assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
- "Only for relational integer predicates.");
-
- Type *Type = C->getType();
- bool IsSigned = ICmpInst::isSigned(Pred);
-
- CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
- bool WillIncrement =
- UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
-
- // Check if the constant operand can be safely incremented/decremented
- // without overflowing/underflowing.
- auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
- return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
- };
-
- Constant *SafeReplacementConstant = nullptr;
- if (auto *CI = dyn_cast<ConstantInt>(C)) {
- // Bail out if the constant can't be safely incremented/decremented.
- if (!ConstantIsOk(CI))
- return std::nullopt;
- } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
- unsigned NumElts = FVTy->getNumElements();
- for (unsigned i = 0; i != NumElts; ++i) {
- Constant *Elt = C->getAggregateElement(i);
- if (!Elt)
- return std::nullopt;
-
- if (isa<UndefValue>(Elt))
- continue;
-
- // Bail out if we can't determine if this constant is min/max or if we
- // know that this constant is min/max.
- auto *CI = dyn_cast<ConstantInt>(Elt);
- if (!CI || !ConstantIsOk(CI))
- return std::nullopt;
-
- if (!SafeReplacementConstant)
- SafeReplacementConstant = CI;
- }
- } else if (isa<VectorType>(C->getType())) {
- // Handle scalable splat
- Value *SplatC = C->getSplatValue();
- auto *CI = dyn_cast_or_null<ConstantInt>(SplatC);
- // Bail out if the constant can't be safely incremented/decremented.
- if (!CI || !ConstantIsOk(CI))
- return std::nullopt;
- } else {
- // ConstantExpr?
- return std::nullopt;
- }
-
- // It may not be safe to change a compare predicate in the presence of
- // undefined elements, so replace those elements with the first safe constant
- // that we found.
- // TODO: in case of poison, it is safe; let's replace undefs only.
- if (C->containsUndefOrPoisonElement()) {
- assert(SafeReplacementConstant && "Replacement constant not set");
- C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
- }
-
- CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred);
-
- // Increment or decrement the constant.
- Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true);
- Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne);
-
- return std::make_pair(NewPred, NewC);
-}
-
/// If we have an icmp le or icmp ge instruction with a constant operand, turn
/// it into the appropriate icmp lt or icmp gt instruction. This transform
/// allows them to be folded in visitICmpInst.
@@ -6865,8 +6952,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
if (!Op1C)
return nullptr;
- auto FlippedStrictness =
- InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
+ auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
if (!FlippedStrictness)
return nullptr;
@@ -6978,7 +7064,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I,
// (X l>> Y) == 0
static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp,
InstCombiner::BuilderTy &Builder) {
- ICmpInst::Predicate Pred, NewPred;
+ CmpPredicate Pred, NewPred;
Value *X, *Y;
if (match(&Cmp,
m_c_ICmp(Pred, m_OneUse(m_Shl(m_One(), m_Value(Y))), m_Value(X)))) {
@@ -7030,8 +7116,8 @@ static Instruction *foldVectorCmp(CmpInst &Cmp,
if (auto *I = dyn_cast<Instruction>(V))
I->copyIRFlags(&Cmp);
Module *M = Cmp.getModule();
- Function *F =
- Intrinsic::getDeclaration(M, Intrinsic::vector_reverse, V->getType());
+ Function *F = Intrinsic::getOrInsertDeclaration(
+ M, Intrinsic::vector_reverse, V->getType());
return CallInst::Create(F, V);
};
@@ -7143,7 +7229,7 @@ static Instruction *foldReductionIdiom(ICmpInst &I,
const DataLayout &DL) {
if (I.getType()->isVectorTy())
return nullptr;
- ICmpInst::Predicate OuterPred, InnerPred;
+ CmpPredicate OuterPred, InnerPred;
Value *LHS, *RHS;
// Match lowering of @llvm.vector.reduce.and. Turn
@@ -7184,7 +7270,7 @@ static Instruction *foldReductionIdiom(ICmpInst &I,
}
// This helper will be called with icmp operands in both orders.
-Instruction *InstCombinerImpl::foldICmpCommutative(ICmpInst::Predicate Pred,
+Instruction *InstCombinerImpl::foldICmpCommutative(CmpPredicate Pred,
Value *Op0, Value *Op1,
ICmpInst &CxtI) {
// Try to optimize 'icmp GEP, P' or 'icmp P, GEP'.
@@ -7312,7 +7398,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
Changed = true;
}
- if (Value *V = simplifyICmpInst(I.getPredicate(), Op0, Op1, Q))
+ if (Value *V = simplifyICmpInst(I.getCmpPredicate(), Op0, Op1, Q))
return replaceInstUsesWith(I, V);
// Comparing -val or val with non-zero is the same as just comparing val
@@ -7419,10 +7505,10 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = foldICmpInstWithConstantNotInt(I))
return Res;
- if (Instruction *Res = foldICmpCommutative(I.getPredicate(), Op0, Op1, I))
+ if (Instruction *Res = foldICmpCommutative(I.getCmpPredicate(), Op0, Op1, I))
return Res;
if (Instruction *Res =
- foldICmpCommutative(I.getSwappedPredicate(), Op1, Op0, I))
+ foldICmpCommutative(I.getSwappedCmpPredicate(), Op1, Op0, I))
return Res;
if (I.isCommutative()) {
@@ -7596,6 +7682,32 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = foldReductionIdiom(I, Builder, DL))
return Res;
+ {
+ Value *A;
+ const APInt *C1, *C2;
+ ICmpInst::Predicate Pred = I.getPredicate();
+ if (ICmpInst::isEquality(Pred)) {
+ // sext(a) & c1 == c2 --> a & c3 == trunc(c2)
+ // sext(a) & c1 != c2 --> a & c3 != trunc(c2)
+ if (match(Op0, m_And(m_SExt(m_Value(A)), m_APInt(C1))) &&
+ match(Op1, m_APInt(C2))) {
+ Type *InputTy = A->getType();
+ unsigned InputBitWidth = InputTy->getScalarSizeInBits();
+ // c2 must be non-negative at the bitwidth of a.
+ if (C2->getActiveBits() < InputBitWidth) {
+ APInt TruncC1 = C1->trunc(InputBitWidth);
+ // Check whether C1 has any bits set at or above bit InputBitWidth.
+ if (C1->uge(APInt::getOneBitSet(C1->getBitWidth(), InputBitWidth)))
+ TruncC1.setBit(InputBitWidth - 1);
+ Value *AndInst = Builder.CreateAnd(A, TruncC1);
+ return new ICmpInst(
+ Pred, AndInst,
+ ConstantInt::get(InputTy, C2->trunc(InputBitWidth)));
+ }
+ }
+ }
+ }
+
return Changed ? &I : nullptr;
}
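A hypothetical instance of the sext-and narrowing (invented names; C1 = 12 has no bits above bit 7 and C2 = 4 is non-negative in 8 bits):
  %s = sext i8 %a to i32
  %m = and i32 %s, 12
  %c = icmp eq i32 %m, 4
  ; folds to:
  ; %m8 = and i8 %a, 12
  ; %c  = icmp eq i8 %m8, 4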
@@ -7983,6 +8095,67 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
}
}
+/// Optimize sqrt(X) compared with zero.
+static Instruction *foldSqrtWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
+ Value *X;
+ if (!match(I.getOperand(0), m_Sqrt(m_Value(X))))
+ return nullptr;
+
+ if (!match(I.getOperand(1), m_PosZeroFP()))
+ return nullptr;
+
+ auto ReplacePredAndOp0 = [&](FCmpInst::Predicate P) {
+ I.setPredicate(P);
+ return IC.replaceOperand(I, 0, X);
+ };
+
+ // Clear ninf flag if sqrt doesn't have it.
+ if (!cast<Instruction>(I.getOperand(0))->hasNoInfs())
+ I.setHasNoInfs(false);
+
+ switch (I.getPredicate()) {
+ case FCmpInst::FCMP_OLT:
+ case FCmpInst::FCMP_UGE:
+ // sqrt(X) < 0.0 --> false
+ // sqrt(X) u>= 0.0 --> true
+ llvm_unreachable("fcmp should have simplified");
+ case FCmpInst::FCMP_ULT:
+ case FCmpInst::FCMP_ULE:
+ case FCmpInst::FCMP_OGT:
+ case FCmpInst::FCMP_OGE:
+ case FCmpInst::FCMP_OEQ:
+ case FCmpInst::FCMP_UNE:
+ // sqrt(X) u< 0.0 --> X u< 0.0
+ // sqrt(X) u<= 0.0 --> X u<= 0.0
+ // sqrt(X) > 0.0 --> X > 0.0
+ // sqrt(X) >= 0.0 --> X >= 0.0
+ // sqrt(X) == 0.0 --> X == 0.0
+ // sqrt(X) u!= 0.0 --> X u!= 0.0
+ return IC.replaceOperand(I, 0, X);
+
+ case FCmpInst::FCMP_OLE:
+ // sqrt(X) <= 0.0 --> X == 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_OEQ);
+ case FCmpInst::FCMP_UGT:
+ // sqrt(X) u> 0.0 --> X u!= 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_UNE);
+ case FCmpInst::FCMP_UEQ:
+ // sqrt(X) u== 0.0 --> X u<= 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_ULE);
+ case FCmpInst::FCMP_ONE:
+ // sqrt(X) != 0.0 --> X > 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_OGT);
+ case FCmpInst::FCMP_ORD:
+ // !isnan(sqrt(X)) --> X >= 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_OGE);
+ case FCmpInst::FCMP_UNO:
+ // isnan(sqrt(X)) --> X u< 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_ULT);
+ default:
+ llvm_unreachable("Unexpected predicate!");
+ }
+}
+
static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) {
CmpInst::Predicate Pred = I.getPredicate();
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
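For example (hypothetical IR; sqrt is non-negative whenever it is not NaN, so ole against +0.0 degenerates to an equality on X):
  %s = call double @llvm.sqrt.f64(double %x)
  %c = fcmp ole double %s, 0.0
  ; folds to: %c = fcmp oeq double %x, 0.0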
@@ -8049,6 +8222,75 @@ static Instruction *foldFCmpFSubIntoFCmp(FCmpInst &I, Instruction *LHSI,
return nullptr;
}
+static Instruction *foldFCmpWithFloorAndCeil(FCmpInst &I,
+ InstCombinerImpl &IC) {
+ Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+ Type *OpType = LHS->getType();
+ CmpInst::Predicate Pred = I.getPredicate();
+
+ bool FloorX = match(LHS, m_Intrinsic<Intrinsic::floor>(m_Specific(RHS)));
+ bool CeilX = match(LHS, m_Intrinsic<Intrinsic::ceil>(m_Specific(RHS)));
+
+ if (!FloorX && !CeilX) {
+ if ((FloorX = match(RHS, m_Intrinsic<Intrinsic::floor>(m_Specific(LHS)))) ||
+ (CeilX = match(RHS, m_Intrinsic<Intrinsic::ceil>(m_Specific(LHS))))) {
+ std::swap(LHS, RHS);
+ Pred = I.getSwappedPredicate();
+ }
+ }
+
+ switch (Pred) {
+ case FCmpInst::FCMP_OLE:
+ // fcmp ole floor(x), x => fcmp ord x, 0
+ if (FloorX)
+ return new FCmpInst(FCmpInst::FCMP_ORD, RHS, ConstantFP::getZero(OpType),
+ "", &I);
+ break;
+ case FCmpInst::FCMP_OGT:
+ // fcmp ogt floor(x), x => false
+ if (FloorX)
+ return IC.replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ break;
+ case FCmpInst::FCMP_OGE:
+ // fcmp oge ceil(x), x => fcmp ord x, 0
+ if (CeilX)
+ return new FCmpInst(FCmpInst::FCMP_ORD, RHS, ConstantFP::getZero(OpType),
+ "", &I);
+ break;
+ case FCmpInst::FCMP_OLT:
+ // fcmp olt ceil(x), x => false
+ if (CeilX)
+ return IC.replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ break;
+ case FCmpInst::FCMP_ULE:
+ // fcmp ule floor(x), x => true
+ if (FloorX)
+ return IC.replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+ break;
+ case FCmpInst::FCMP_UGT:
+ // fcmp ugt floor(x), x => fcmp uno x, 0
+ if (FloorX)
+ return new FCmpInst(FCmpInst::FCMP_UNO, RHS, ConstantFP::getZero(OpType),
+ "", &I);
+ break;
+ case FCmpInst::FCMP_UGE:
+ // fcmp uge ceil(x), x => true
+ if (CeilX)
+ return IC.replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
+ break;
+ case FCmpInst::FCMP_ULT:
+ // fcmp ult ceil(x), x => fcmp uno x, 0
+ if (CeilX)
+ return new FCmpInst(FCmpInst::FCMP_UNO, RHS, ConstantFP::getZero(OpType),
+ "", &I);
+ break;
+ default:
+ break;
+ }
+
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
bool Changed = false;
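A sketch of one of the floor cases (invented names; floor(x) <= x holds for every non-NaN x, so the compare reduces to an ordered check):
  %f = call double @llvm.floor.f64(double %x)
  %c = fcmp ole double %f, %x
  ; folds to: %c = fcmp ord double %x, 0.0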
@@ -8212,9 +8454,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
case Instruction::Select:
// fcmp eq (cond ? x : -x), 0 --> fcmp eq x, 0
if (FCmpInst::isEquality(Pred) && match(RHSC, m_AnyZeroFP()) &&
- (match(LHSI,
- m_Select(m_Value(), m_Value(X), m_FNeg(m_Deferred(X)))) ||
- match(LHSI, m_Select(m_Value(), m_FNeg(m_Value(X)), m_Deferred(X)))))
+ match(LHSI, m_c_Select(m_FNeg(m_Value(X)), m_Deferred(X))))
return replaceOperand(I, 0, X);
if (Instruction *NV = FoldOpIntoSelect(I, cast<SelectInst>(LHSI)))
return NV;
@@ -8250,6 +8490,12 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
if (Instruction *R = foldFabsWithFcmpZero(I, *this))
return R;
+ if (Instruction *R = foldSqrtWithFcmpZero(I, *this))
+ return R;
+
+ if (Instruction *R = foldFCmpWithFloorAndCeil(I, *this))
+ return R;
+
if (match(Op0, m_FNeg(m_Value(X)))) {
// fcmp pred (fneg X), C --> fcmp swap(pred) X, -C
Constant *C;