path: root/llvm/lib/Transforms/InstCombine
author     Dimitry Andric <dim@FreeBSD.org>    2021-02-16 20:13:02 +0000
committer  Dimitry Andric <dim@FreeBSD.org>    2021-02-16 20:13:02 +0000
commit     b60736ec1405bb0a8dd40989f67ef4c93da068ab (patch)
tree       5c43fbb7c9fc45f0f87e0e6795a86267dbd12f9d /llvm/lib/Transforms/InstCombine
parent     cfca06d7963fa0909f90483b42a6d7d194d01e08 (diff)
Vendor import of llvm-project main 8e464dd76bef, the last commit before
the upstream release/12.x branch was created.
(tag: vendor/llvm-project/llvmorg-12-init-17869-g8e464dd76bef)
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp             335
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp           760
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp           18
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp              3319
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp               375
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp            502
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineInternal.h              409
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp     232
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp           205
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp             149
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp                 245
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp              491
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp              208
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp    597
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineTables.td                11
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp           434
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstructionCombining.cpp           417
17 files changed, 3267 insertions, 5440 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index a7f5e0a7774d..bacb8689892a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -29,6 +29,7 @@
#include "llvm/Support/AlignOf.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/KnownBits.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <cassert>
#include <utility>
@@ -81,11 +82,11 @@ namespace {
private:
bool insaneIntVal(int V) { return V > 4 || V < -4; }
- APFloat *getFpValPtr()
- { return reinterpret_cast<APFloat *>(&FpValBuf.buffer[0]); }
+ APFloat *getFpValPtr() { return reinterpret_cast<APFloat *>(&FpValBuf); }
- const APFloat *getFpValPtr() const
- { return reinterpret_cast<const APFloat *>(&FpValBuf.buffer[0]); }
+ const APFloat *getFpValPtr() const {
+ return reinterpret_cast<const APFloat *>(&FpValBuf);
+ }
const APFloat &getFpVal() const {
assert(IsFp && BufHasFpVal && "Incorret state");
@@ -860,7 +861,7 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add,
return nullptr;
}
-Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) {
+Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1);
Constant *Op1C;
if (!match(Op1, m_Constant(Op1C)))
@@ -886,15 +887,15 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) {
// zext(bool) + C -> bool ? C + 1 : C
if (match(Op0, m_ZExt(m_Value(X))) &&
X->getType()->getScalarSizeInBits() == 1)
- return SelectInst::Create(X, AddOne(Op1C), Op1);
+ return SelectInst::Create(X, InstCombiner::AddOne(Op1C), Op1);
// sext(bool) + C -> bool ? C - 1 : C
if (match(Op0, m_SExt(m_Value(X))) &&
X->getType()->getScalarSizeInBits() == 1)
- return SelectInst::Create(X, SubOne(Op1C), Op1);
+ return SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1);
// ~X + C --> (C-1) - X
if (match(Op0, m_Not(m_Value(X))))
- return BinaryOperator::CreateSub(SubOne(Op1C), X);
+ return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X);
const APInt *C;
if (!match(Op1, m_APInt(C)))
@@ -923,6 +924,39 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) {
C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C)
return CastInst::Create(Instruction::SExt, X, Ty);
+ if (match(Op0, m_Xor(m_Value(X), m_APInt(C2)))) {
+ // (X ^ signmask) + C --> (X + (signmask ^ C))
+ if (C2->isSignMask())
+ return BinaryOperator::CreateAdd(X, ConstantInt::get(Ty, *C2 ^ *C));
+
+ // If X has no high-bits set above an xor mask:
+ // add (xor X, LowMaskC), C --> sub (LowMaskC + C), X
+ if (C2->isMask()) {
+ KnownBits LHSKnown = computeKnownBits(X, 0, &Add);
+ if ((*C2 | LHSKnown.Zero).isAllOnesValue())
+ return BinaryOperator::CreateSub(ConstantInt::get(Ty, *C2 + *C), X);
+ }
+
+ // Look for a math+logic pattern that corresponds to sext-in-register of a
+ // value with cleared high bits. Convert that into a pair of shifts:
+ // add (xor X, 0x80), 0xF..F80 --> (X << ShAmtC) >>s ShAmtC
+ // add (xor X, 0xF..F80), 0x80 --> (X << ShAmtC) >>s ShAmtC
+ if (Op0->hasOneUse() && *C2 == -(*C)) {
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ unsigned ShAmt = 0;
+ if (C->isPowerOf2())
+ ShAmt = BitWidth - C->logBase2() - 1;
+ else if (C2->isPowerOf2())
+ ShAmt = BitWidth - C2->logBase2() - 1;
+ if (ShAmt && MaskedValueIsZero(X, APInt::getHighBitsSet(BitWidth, ShAmt),
+ 0, &Add)) {
+ Constant *ShAmtC = ConstantInt::get(Ty, ShAmt);
+ Value *NewShl = Builder.CreateShl(X, ShAmtC, "sext");
+ return BinaryOperator::CreateAShr(NewShl, ShAmtC);
+ }
+ }
+ }
+
if (C->isOneValue() && Op0->hasOneUse()) {
// add (sext i1 X), 1 --> zext (not X)
// TODO: The smallest IR representation is (select X, 0, 1), and that would
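The folds added in this hunk rest on plain modular-arithmetic identities. A minimal standalone C++ check of them over 8-bit values (an illustrative sketch, not part of this patch or of the LLVM sources; the constants 0x80 and 0x0F are arbitrary example masks):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X < 256; ++X) {
    for (uint32_t C = 0; C < 256; ++C) {
      // (X ^ signmask) + C --> X + (signmask ^ C), with 0x80 as the i8 sign mask.
      assert(uint8_t((X ^ 0x80u) + C) == uint8_t(X + (0x80u ^ C)));
      // add (xor X, LowMaskC), C --> sub (LowMaskC + C), X, valid when X has no
      // bits above the low mask (here LowMaskC = 0x0F).
      if ((X & 0xF0u) == 0)
        assert(uint8_t((X ^ 0x0Fu) + C) == uint8_t((0x0Fu + C) - X));
    }
    // add (xor X, 0x80), 0xFFFFFF80 --> sign-extension of the low 8 bits of X,
    // provided the high bits of X are already clear.
    uint32_t Sext = (X ^ 0x80u) + 0xFFFFFF80u;
    uint32_t Expected = (X & 0x80u) ? (0xFFFFFF00u | X) : X;
    assert(Sext == Expected);
  }
  return 0;
}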
@@ -943,6 +977,15 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) {
}
}
+ // If all bits affected by the add are included in a high-bit-mask, do the
+ // add before the mask op:
+ // (X & 0xFF00) + xx00 --> (X + xx00) & 0xFF00
+ if (match(Op0, m_OneUse(m_And(m_Value(X), m_APInt(C2)))) &&
+ C2->isNegative() && C2->isShiftedMask() && *C == (*C & *C2)) {
+ Value *NewAdd = Builder.CreateAdd(X, ConstantInt::get(Ty, *C));
+ return BinaryOperator::CreateAnd(NewAdd, ConstantInt::get(Ty, *C2));
+ }
+
return nullptr;
}
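The high-bit-mask fold just above also reduces to an arithmetic fact: if the mask is a contiguous run of bits reaching the sign bit and the added constant lies entirely inside the mask, adding before or after the mask gives the same result. A small exhaustive i16 check (illustrative sketch only; the mask 0xFF00 is an arbitrary example):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Mask = 0xFF00u;
  for (uint32_t X = 0; X <= 0xFFFFu; ++X)
    for (uint32_t C = 0; C <= 0xFFFFu; C += 0x100u) // every constant inside the mask
      // (X & 0xFF00) + xx00 --> (X + xx00) & 0xFF00
      assert(uint16_t((X & Mask) + C) == uint16_t((X + C) & Mask));
  return 0;
}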
@@ -1021,7 +1064,7 @@ static bool MulWillOverflow(APInt &C0, APInt &C1, bool IsSigned) {
// Simplifies X % C0 + (( X / C0 ) % C1) * C0 to X % (C0 * C1), where (C0 * C1)
// does not overflow.
-Value *InstCombiner::SimplifyAddWithRemainder(BinaryOperator &I) {
+Value *InstCombinerImpl::SimplifyAddWithRemainder(BinaryOperator &I) {
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
Value *X, *MulOpV;
APInt C0, MulOpC;
@@ -1097,9 +1140,9 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
return nullptr;
}
-Instruction *
-InstCombiner::canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
- BinaryOperator &I) {
+Instruction *InstCombinerImpl::
+ canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
+ BinaryOperator &I) {
assert((I.getOpcode() == Instruction::Add ||
I.getOpcode() == Instruction::Or ||
I.getOpcode() == Instruction::Sub) &&
@@ -1198,7 +1241,44 @@ InstCombiner::canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
return TruncInst::CreateTruncOrBitCast(NewAShr, I.getType());
}
-Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
+/// This is a specialization of a more general transform from
+/// SimplifyUsingDistributiveLaws. If that code can be made to work optimally
+/// for multi-use cases or propagating nsw/nuw, then we would not need this.
+static Instruction *factorizeMathWithShlOps(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ // TODO: Also handle mul by doubling the shift amount?
+ assert((I.getOpcode() == Instruction::Add ||
+ I.getOpcode() == Instruction::Sub) &&
+ "Expected add/sub");
+ auto *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0));
+ auto *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1));
+ if (!Op0 || !Op1 || !(Op0->hasOneUse() || Op1->hasOneUse()))
+ return nullptr;
+
+ Value *X, *Y, *ShAmt;
+ if (!match(Op0, m_Shl(m_Value(X), m_Value(ShAmt))) ||
+ !match(Op1, m_Shl(m_Value(Y), m_Specific(ShAmt))))
+ return nullptr;
+
+ // No-wrap propagates only when all ops have no-wrap.
+ bool HasNSW = I.hasNoSignedWrap() && Op0->hasNoSignedWrap() &&
+ Op1->hasNoSignedWrap();
+ bool HasNUW = I.hasNoUnsignedWrap() && Op0->hasNoUnsignedWrap() &&
+ Op1->hasNoUnsignedWrap();
+
+ // add/sub (X << ShAmt), (Y << ShAmt) --> (add/sub X, Y) << ShAmt
+ Value *NewMath = Builder.CreateBinOp(I.getOpcode(), X, Y);
+ if (auto *NewI = dyn_cast<BinaryOperator>(NewMath)) {
+ NewI->setHasNoSignedWrap(HasNSW);
+ NewI->setHasNoUnsignedWrap(HasNUW);
+ }
+ auto *NewShl = BinaryOperator::CreateShl(NewMath, ShAmt);
+ NewShl->setHasNoSignedWrap(HasNSW);
+ NewShl->setHasNoUnsignedWrap(HasNUW);
+ return NewShl;
+}
+
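factorizeMathWithShlOps depends on shifts distributing over add/sub in modular arithmetic; the extra care in the helper is only about when the nsw/nuw flags may be kept. A quick exhaustive i8 check of the underlying identity (illustrative sketch, not LLVM code):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X < 256; ++X)
    for (uint32_t Y = 0; Y < 256; ++Y)
      for (uint32_t S = 0; S < 8; ++S) {
        // add/sub (X << ShAmt), (Y << ShAmt) --> (add/sub X, Y) << ShAmt
        assert(uint8_t((X << S) + (Y << S)) == uint8_t((X + Y) << S));
        assert(uint8_t((X << S) - (Y << S)) == uint8_t((X - Y) << S));
      }
  return 0;
}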
+Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1),
I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
SQ.getWithInstruction(&I)))
@@ -1214,59 +1294,17 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
if (Value *V = SimplifyUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
+ if (Instruction *R = factorizeMathWithShlOps(I, Builder))
+ return R;
+
if (Instruction *X = foldAddWithConstant(I))
return X;
if (Instruction *X = foldNoWrapAdd(I, Builder))
return X;
- // FIXME: This should be moved into the above helper function to allow these
- // transforms for general constant or constant splat vectors.
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
Type *Ty = I.getType();
- if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
- Value *XorLHS = nullptr; ConstantInt *XorRHS = nullptr;
- if (match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) {
- unsigned TySizeBits = Ty->getScalarSizeInBits();
- const APInt &RHSVal = CI->getValue();
- unsigned ExtendAmt = 0;
- // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext.
- // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext.
- if (XorRHS->getValue() == -RHSVal) {
- if (RHSVal.isPowerOf2())
- ExtendAmt = TySizeBits - RHSVal.logBase2() - 1;
- else if (XorRHS->getValue().isPowerOf2())
- ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1;
- }
-
- if (ExtendAmt) {
- APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt);
- if (!MaskedValueIsZero(XorLHS, Mask, 0, &I))
- ExtendAmt = 0;
- }
-
- if (ExtendAmt) {
- Constant *ShAmt = ConstantInt::get(Ty, ExtendAmt);
- Value *NewShl = Builder.CreateShl(XorLHS, ShAmt, "sext");
- return BinaryOperator::CreateAShr(NewShl, ShAmt);
- }
-
- // If this is a xor that was canonicalized from a sub, turn it back into
- // a sub and fuse this add with it.
- if (LHS->hasOneUse() && (XorRHS->getValue()+1).isPowerOf2()) {
- KnownBits LHSKnown = computeKnownBits(XorLHS, 0, &I);
- if ((XorRHS->getValue() | LHSKnown.Zero).isAllOnesValue())
- return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI),
- XorLHS);
- }
- // (X + signmask) + C could have gotten canonicalized to (X^signmask) + C,
- // transform them into (X + (signmask ^ C))
- if (XorRHS->getValue().isSignMask())
- return BinaryOperator::CreateAdd(XorLHS,
- ConstantExpr::getXor(XorRHS, CI));
- }
- }
-
if (Ty->isIntOrIntVectorTy(1))
return BinaryOperator::CreateXor(LHS, RHS);
@@ -1329,34 +1367,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
return BinaryOperator::CreateOr(LHS, RHS);
- // FIXME: We already did a check for ConstantInt RHS above this.
- // FIXME: Is this pattern covered by another fold? No regression tests fail on
- // removal.
- if (ConstantInt *CRHS = dyn_cast<ConstantInt>(RHS)) {
- // (X & FF00) + xx00 -> (X+xx00) & FF00
- Value *X;
- ConstantInt *C2;
- if (LHS->hasOneUse() &&
- match(LHS, m_And(m_Value(X), m_ConstantInt(C2))) &&
- CRHS->getValue() == (CRHS->getValue() & C2->getValue())) {
- // See if all bits from the first bit set in the Add RHS up are included
- // in the mask. First, get the rightmost bit.
- const APInt &AddRHSV = CRHS->getValue();
-
- // Form a mask of all bits from the lowest bit added through the top.
- APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1));
-
- // See if the and mask includes all of these bits.
- APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue());
-
- if (AddRHSHighBits == AddRHSHighBitsAnd) {
- // Okay, the xform is safe. Insert the new add pronto.
- Value *NewAdd = Builder.CreateAdd(X, CRHS, LHS->getName());
- return BinaryOperator::CreateAnd(NewAdd, C2);
- }
- }
- }
-
// add (select X 0 (sub n A)) A --> select X A n
{
SelectInst *SI = dyn_cast<SelectInst>(LHS);
@@ -1424,6 +1434,14 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
if (Instruction *SatAdd = foldToUnsignedSaturatedAdd(I))
return SatAdd;
+ // usub.sat(A, B) + B => umax(A, B)
+ if (match(&I, m_c_BinOp(
+ m_OneUse(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A), m_Value(B))),
+ m_Deferred(B)))) {
+ return replaceInstUsesWith(I,
+ Builder.CreateIntrinsic(Intrinsic::umax, {I.getType()}, {A, B}));
+ }
+
return Changed ? &I : nullptr;
}
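The usub.sat fold follows from the definition of the saturating subtraction: it returns A - B when A >= B and 0 otherwise, so adding B back yields max(A, B). An exhaustive i8 check (illustrative sketch; usub.sat is modeled directly rather than calling the intrinsic):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t A = 0; A < 256; ++A)
    for (uint32_t B = 0; B < 256; ++B) {
      uint8_t Sat = A >= B ? uint8_t(A - B) : uint8_t(0); // usub.sat(A, B)
      // usub.sat(A, B) + B => umax(A, B)
      assert(uint8_t(Sat + B) == uint8_t(A > B ? A : B));
    }
  return 0;
}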
@@ -1486,7 +1504,7 @@ static Instruction *factorizeFAddFSub(BinaryOperator &I,
: BinaryOperator::CreateFDivFMF(XY, Z, &I);
}
-Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
if (Value *V = SimplifyFAddInst(I.getOperand(0), I.getOperand(1),
I.getFastMathFlags(),
SQ.getWithInstruction(&I)))
@@ -1600,49 +1618,33 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
/// Optimize pointer differences into the same array into a size. Consider:
/// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer
/// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
-Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
- Type *Ty, bool IsNUW) {
+Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
+ Type *Ty, bool IsNUW) {
// If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
// this.
bool Swapped = false;
GEPOperator *GEP1 = nullptr, *GEP2 = nullptr;
+ if (!isa<GEPOperator>(LHS) && isa<GEPOperator>(RHS)) {
+ std::swap(LHS, RHS);
+ Swapped = true;
+ }
- // For now we require one side to be the base pointer "A" or a constant
- // GEP derived from it.
- if (GEPOperator *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
+ // Require at least one GEP with a common base pointer on both sides.
+ if (auto *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
// (gep X, ...) - X
if (LHSGEP->getOperand(0) == RHS) {
GEP1 = LHSGEP;
- Swapped = false;
- } else if (GEPOperator *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
+ } else if (auto *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
// (gep X, ...) - (gep X, ...)
if (LHSGEP->getOperand(0)->stripPointerCasts() ==
- RHSGEP->getOperand(0)->stripPointerCasts()) {
- GEP2 = RHSGEP;
+ RHSGEP->getOperand(0)->stripPointerCasts()) {
GEP1 = LHSGEP;
- Swapped = false;
- }
- }
- }
-
- if (GEPOperator *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
- // X - (gep X, ...)
- if (RHSGEP->getOperand(0) == LHS) {
- GEP1 = RHSGEP;
- Swapped = true;
- } else if (GEPOperator *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
- // (gep X, ...) - (gep X, ...)
- if (RHSGEP->getOperand(0)->stripPointerCasts() ==
- LHSGEP->getOperand(0)->stripPointerCasts()) {
- GEP2 = LHSGEP;
- GEP1 = RHSGEP;
- Swapped = true;
+ GEP2 = RHSGEP;
}
}
}
if (!GEP1)
- // No GEP found.
return nullptr;
if (GEP2) {
@@ -1670,19 +1672,18 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
Value *Result = EmitGEPOffset(GEP1);
// If this is a single inbounds GEP and the original sub was nuw,
- // then the final multiplication is also nuw. We match an extra add zero
- // here, because that's what EmitGEPOffset() generates.
- Instruction *I;
- if (IsNUW && !GEP2 && !Swapped && GEP1->isInBounds() &&
- match(Result, m_Add(m_Instruction(I), m_Zero())) &&
- I->getOpcode() == Instruction::Mul)
- I->setHasNoUnsignedWrap();
-
- // If we had a constant expression GEP on the other side offsetting the
- // pointer, subtract it from the offset we have.
+ // then the final multiplication is also nuw.
+ if (auto *I = dyn_cast<Instruction>(Result))
+ if (IsNUW && !GEP2 && !Swapped && GEP1->isInBounds() &&
+ I->getOpcode() == Instruction::Mul)
+ I->setHasNoUnsignedWrap();
+
+ // If we have a 2nd GEP of the same base pointer, subtract the offsets.
+ // If both GEPs are inbounds, then the subtract does not have signed overflow.
if (GEP2) {
Value *Offset = EmitGEPOffset(GEP2);
- Result = Builder.CreateSub(Result, Offset);
+ Result = Builder.CreateSub(Result, Offset, "gepdiff", /* NUW */ false,
+ GEP1->isInBounds() && GEP2->isInBounds());
}
// If we have p - gep(p, ...) then we have to negate the result.
@@ -1692,7 +1693,7 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
return Builder.CreateIntCast(Result, Ty, true);
}
-Instruction *InstCombiner::visitSub(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
if (Value *V = SimplifySubInst(I.getOperand(0), I.getOperand(1),
I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
SQ.getWithInstruction(&I)))
@@ -1721,6 +1722,19 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
return Res;
}
+ // Try this before Negator to preserve NSW flag.
+ if (Instruction *R = factorizeMathWithShlOps(I, Builder))
+ return R;
+
+ if (Constant *C = dyn_cast<Constant>(Op0)) {
+ Value *X;
+ Constant *C2;
+
+ // C-(X+C2) --> (C-C2)-X
+ if (match(Op1, m_Add(m_Value(X), m_Constant(C2))))
+ return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
+ }
+
auto TryToNarrowDeduceFlags = [this, &I, &Op0, &Op1]() -> Instruction * {
if (Instruction *Ext = narrowMathIfNoOverflow(I))
return Ext;
@@ -1788,8 +1802,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
}
auto m_AddRdx = [](Value *&Vec) {
- return m_OneUse(
- m_Intrinsic<Intrinsic::experimental_vector_reduce_add>(m_Value(Vec)));
+ return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec)));
};
Value *V0, *V1;
if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) &&
@@ -1797,8 +1810,8 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
// Difference of sums is sum of differences:
// add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1)
Value *Sub = Builder.CreateSub(V0, V1);
- Value *Rdx = Builder.CreateIntrinsic(
- Intrinsic::experimental_vector_reduce_add, {Sub->getType()}, {Sub});
+ Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add,
+ {Sub->getType()}, {Sub});
return replaceInstUsesWith(I, Rdx);
}
@@ -1806,14 +1819,14 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
Value *X;
if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
// C - (zext bool) --> bool ? C - 1 : C
- return SelectInst::Create(X, SubOne(C), C);
+ return SelectInst::Create(X, InstCombiner::SubOne(C), C);
if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
// C - (sext bool) --> bool ? C + 1 : C
- return SelectInst::Create(X, AddOne(C), C);
+ return SelectInst::Create(X, InstCombiner::AddOne(C), C);
// C - ~X == X + (1+C)
if (match(Op1, m_Not(m_Value(X))))
- return BinaryOperator::CreateAdd(X, AddOne(C));
+ return BinaryOperator::CreateAdd(X, InstCombiner::AddOne(C));
// Try to fold constant sub into select arguments.
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
@@ -1828,12 +1841,8 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
Constant *C2;
// C-(C2-X) --> X+(C-C2)
- if (match(Op1, m_Sub(m_Constant(C2), m_Value(X))) && !isa<ConstantExpr>(C2))
+ if (match(Op1, m_Sub(m_ImmConstant(C2), m_Value(X))))
return BinaryOperator::CreateAdd(X, ConstantExpr::getSub(C, C2));
-
- // C-(X+C2) --> (C-C2)-X
- if (match(Op1, m_Add(m_Value(X), m_Constant(C2))))
- return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
}
const APInt *Op0C;
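The constant-minus folds touched in these hunks (C-(X+C2), C-(C2-X), and C-~X) are all modular-arithmetic rewrites. An exhaustive i8 check (illustrative sketch only):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t C = 0; C < 256; ++C)
    for (uint32_t X = 0; X < 256; ++X) {
      for (uint32_t C2 = 0; C2 < 256; ++C2) {
        // C - (X + C2) --> (C - C2) - X
        assert(uint8_t(C - (X + C2)) == uint8_t((C - C2) - X));
        // C - (C2 - X) --> X + (C - C2)
        assert(uint8_t(C - (C2 - X)) == uint8_t(X + (C - C2)));
      }
      // C - ~X --> X + (1 + C)
      assert(uint8_t(C - (~X & 0xFFu)) == uint8_t(X + 1 + C));
    }
  return 0;
}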
@@ -1864,6 +1873,22 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
return BinaryOperator::CreateXor(A, B);
}
+ // (sub (add A, B) (or A, B)) --> (and A, B)
+ {
+ Value *A, *B;
+ if (match(Op0, m_Add(m_Value(A), m_Value(B))) &&
+ match(Op1, m_c_Or(m_Specific(A), m_Specific(B))))
+ return BinaryOperator::CreateAnd(A, B);
+ }
+
+ // (sub (add A, B) (and A, B)) --> (or A, B)
+ {
+ Value *A, *B;
+ if (match(Op0, m_Add(m_Value(A), m_Value(B))) &&
+ match(Op1, m_c_And(m_Specific(A), m_Specific(B))))
+ return BinaryOperator::CreateOr(A, B);
+ }
+
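Both new sub folds are consequences of the carry decomposition A + B == (A | B) + (A & B). An exhaustive i8 check (illustrative sketch only):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t A = 0; A < 256; ++A)
    for (uint32_t B = 0; B < 256; ++B) {
      assert(uint8_t(A + B) == uint8_t((A | B) + (A & B)));
      // (sub (add A, B) (or A, B)) --> (and A, B)
      assert(uint8_t((A + B) - (A | B)) == uint8_t(A & B));
      // (sub (add A, B) (and A, B)) --> (or A, B)
      assert(uint8_t((A + B) - (A & B)) == uint8_t(A | B));
    }
  return 0;
}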
// (sub (and A, B) (or A, B)) --> neg (xor A, B)
{
Value *A, *B;
@@ -2042,6 +2067,20 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
return SelectInst::Create(Cmp, Neg, A);
}
+ // If we are subtracting a low-bit masked subset of some value from an add
+ // of that same value with no low bits changed, that is clearing some low bits
+ // of the sum:
+ // sub (X + AddC), (X & AndC) --> and (X + AddC), ~AndC
+ const APInt *AddC, *AndC;
+ if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) &&
+ match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) {
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ unsigned Cttz = AddC->countTrailingZeros();
+ APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz));
+ if ((HighMask & *AndC).isNullValue())
+ return BinaryOperator::CreateAnd(Op0, ConstantInt::get(Ty, ~(*AndC)));
+ }
+
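The low-bit-clearing fold above needs AndC to lie entirely below the lowest set bit of AddC, so the add cannot change any of the masked bits. A check over all i8 values with example constants AddC = 8 and AndC = 5 (illustrative sketch; the constants are arbitrary):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t AddC = 8u, AndC = 5u; // every bit of AndC is below bit 3 of AddC
  for (uint32_t X = 0; X < 256; ++X)
    // sub (X + AddC), (X & AndC) --> and (X + AddC), ~AndC
    assert(uint8_t((X + AddC) - (X & AndC)) == uint8_t((X + AddC) & ~AndC));
  return 0;
}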
if (Instruction *V =
canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
return V;
@@ -2094,11 +2133,11 @@ static Instruction *hoistFNegAboveFMulFDiv(Instruction &I,
return nullptr;
}
-Instruction *InstCombiner::visitFNeg(UnaryOperator &I) {
+Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
Value *Op = I.getOperand(0);
if (Value *V = SimplifyFNegInst(Op, I.getFastMathFlags(),
- SQ.getWithInstruction(&I)))
+ getSimplifyQuery().getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
if (Instruction *X = foldFNegIntoConstant(I))
@@ -2117,10 +2156,10 @@ Instruction *InstCombiner::visitFNeg(UnaryOperator &I) {
return nullptr;
}
-Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1),
I.getFastMathFlags(),
- SQ.getWithInstruction(&I)))
+ getSimplifyQuery().getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
if (Instruction *X = foldVectorBinop(I))
@@ -2175,7 +2214,7 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
// X - C --> X + (-C)
// But don't transform constant expressions because there's an inverse fold
// for X + (-Y) --> X - Y.
- if (match(Op1, m_Constant(C)) && !isa<ConstantExpr>(Op1))
+ if (match(Op1, m_ImmConstant(C)))
return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I);
// X - (-Y) --> X + Y
@@ -2244,9 +2283,8 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
}
auto m_FaddRdx = [](Value *&Sum, Value *&Vec) {
- return m_OneUse(
- m_Intrinsic<Intrinsic::experimental_vector_reduce_v2_fadd>(
- m_Value(Sum), m_Value(Vec)));
+ return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_fadd>(m_Value(Sum),
+ m_Value(Vec)));
};
Value *A0, *A1, *V0, *V1;
if (match(Op0, m_FaddRdx(A0, V0)) && match(Op1, m_FaddRdx(A1, V1)) &&
@@ -2254,9 +2292,8 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
// Difference of sums is sum of differences:
// add_rdx(A0, V0) - add_rdx(A1, V1) --> add_rdx(A0, V0 - V1) - A1
Value *Sub = Builder.CreateFSubFMF(V0, V1, &I);
- Value *Rdx = Builder.CreateIntrinsic(
- Intrinsic::experimental_vector_reduce_v2_fadd,
- {A0->getType(), Sub->getType()}, {A0, Sub}, &I);
+ Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_fadd,
+ {Sub->getType()}, {A0, Sub}, &I);
return BinaryOperator::CreateFSubFMF(Rdx, A1, &I);
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index d3c718a919c0..68c4156af2c4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -13,10 +13,12 @@
#include "InstCombineInternal.h"
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
-#include "llvm/Transforms/Utils/Local.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
+#include "llvm/Transforms/Utils/Local.h"
+
using namespace llvm;
using namespace PatternMatch;
@@ -112,57 +114,12 @@ static Value *SimplifyBSwap(BinaryOperator &I,
return Builder.CreateCall(F, BinOp);
}
-/// This handles expressions of the form ((val OP C1) & C2). Where
-/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'.
-Instruction *InstCombiner::OptAndOp(BinaryOperator *Op,
- ConstantInt *OpRHS,
- ConstantInt *AndRHS,
- BinaryOperator &TheAnd) {
- Value *X = Op->getOperand(0);
-
- switch (Op->getOpcode()) {
- default: break;
- case Instruction::Add:
- if (Op->hasOneUse()) {
- // Adding a one to a single bit bit-field should be turned into an XOR
- // of the bit. First thing to check is to see if this AND is with a
- // single bit constant.
- const APInt &AndRHSV = AndRHS->getValue();
-
- // If there is only one bit set.
- if (AndRHSV.isPowerOf2()) {
- // Ok, at this point, we know that we are masking the result of the
- // ADD down to exactly one bit. If the constant we are adding has
- // no bits set below this bit, then we can eliminate the ADD.
- const APInt& AddRHS = OpRHS->getValue();
-
- // Check to see if any bits below the one bit set in AndRHSV are set.
- if ((AddRHS & (AndRHSV - 1)).isNullValue()) {
- // If not, the only thing that can effect the output of the AND is
- // the bit specified by AndRHSV. If that bit is set, the effect of
- // the XOR is to toggle the bit. If it is clear, then the ADD has
- // no effect.
- if ((AddRHS & AndRHSV).isNullValue()) { // Bit is not set, noop
- return replaceOperand(TheAnd, 0, X);
- } else {
- // Pull the XOR out of the AND.
- Value *NewAnd = Builder.CreateAnd(X, AndRHS);
- NewAnd->takeName(Op);
- return BinaryOperator::CreateXor(NewAnd, AndRHS);
- }
- }
- }
- }
- break;
- }
- return nullptr;
-}
-
/// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise
/// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates
/// whether to treat V, Lo, and Hi as signed or not.
-Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
- bool isSigned, bool Inside) {
+Value *InstCombinerImpl::insertRangeTest(Value *V, const APInt &Lo,
+ const APInt &Hi, bool isSigned,
+ bool Inside) {
assert((isSigned ? Lo.slt(Hi) : Lo.ult(Hi)) &&
"Lo is not < Hi in range emission code!");
@@ -437,11 +394,10 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
/// (icmp(A & X) ==/!= Y), where the left-hand side is of type Mask_NotAllZeros
/// and the right hand side is of type BMask_Mixed. For example,
/// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
-static Value * foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
- Value *A, Value *B, Value *C, Value *D, Value *E,
- ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
- llvm::InstCombiner::BuilderTy &Builder) {
+static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
+ ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
+ Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+ InstCombiner::BuilderTy &Builder) {
// We are given the canonical form:
// (icmp ne (A & B), 0) & (icmp eq (A & D), E).
// where D & E == E.
@@ -452,17 +408,9 @@ static Value * foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
//
// We currently handle the case of B, C, D, E are constant.
//
- ConstantInt *BCst = dyn_cast<ConstantInt>(B);
- if (!BCst)
- return nullptr;
- ConstantInt *CCst = dyn_cast<ConstantInt>(C);
- if (!CCst)
- return nullptr;
- ConstantInt *DCst = dyn_cast<ConstantInt>(D);
- if (!DCst)
- return nullptr;
- ConstantInt *ECst = dyn_cast<ConstantInt>(E);
- if (!ECst)
+ ConstantInt *BCst, *CCst, *DCst, *ECst;
+ if (!match(B, m_ConstantInt(BCst)) || !match(C, m_ConstantInt(CCst)) ||
+ !match(D, m_ConstantInt(DCst)) || !match(E, m_ConstantInt(ECst)))
return nullptr;
ICmpInst::Predicate NewCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
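The worked example from the function comment can be checked directly: A & 15 == 8 already forces A & 12 == 8, which is nonzero, so the left-hand comparison adds nothing. An exhaustive i8 check (illustrative sketch only):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t A = 0; A < 256; ++A) {
    // (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8)
    bool Combined = ((A & 12u) != 0) && ((A & 15u) == 8u);
    bool Folded = (A & 15u) == 8u;
    assert(Combined == Folded);
  }
  return 0;
}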
@@ -568,11 +516,9 @@ static Value * foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
/// (icmp(A & X) ==/!= Y), where the left-hand side and the right hand side
/// aren't of the common mask pattern type.
static Value *foldLogOpOfMaskedICmpsAsymmetric(
- ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
- Value *A, Value *B, Value *C, Value *D, Value *E,
- ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
- unsigned LHSMask, unsigned RHSMask,
- llvm::InstCombiner::BuilderTy &Builder) {
+ ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C,
+ Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+ unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) {
assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
"Expected equality predicates for masked type of icmps.");
// Handle Mask_NotAllZeros-BMask_Mixed cases.
@@ -603,7 +549,7 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric(
/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
/// into a single (icmp(A & X) ==/!= Y).
static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
- llvm::InstCombiner::BuilderTy &Builder) {
+ InstCombiner::BuilderTy &Builder) {
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
Optional<std::pair<unsigned, unsigned>> MaskPair =
@@ -673,11 +619,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
// Remaining cases assume at least that B and D are constant, and depend on
// their actual values. This isn't strictly necessary, just a "handle the
// easy cases for now" decision.
- ConstantInt *BCst = dyn_cast<ConstantInt>(B);
- if (!BCst)
- return nullptr;
- ConstantInt *DCst = dyn_cast<ConstantInt>(D);
- if (!DCst)
+ ConstantInt *BCst, *DCst;
+ if (!match(B, m_ConstantInt(BCst)) || !match(D, m_ConstantInt(DCst)))
return nullptr;
if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) {
@@ -718,11 +661,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
// We can't simply use C and E because we might actually handle
// (icmp ne (A & B), B) & (icmp eq (A & D), D)
// with B and D, having a single bit set.
- ConstantInt *CCst = dyn_cast<ConstantInt>(C);
- if (!CCst)
- return nullptr;
- ConstantInt *ECst = dyn_cast<ConstantInt>(E);
- if (!ECst)
+ ConstantInt *CCst, *ECst;
+ if (!match(C, m_ConstantInt(CCst)) || !match(E, m_ConstantInt(ECst)))
return nullptr;
if (PredL != NewCC)
CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst));
@@ -748,8 +688,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
/// Example: (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n
/// If \p Inverted is true then the check is for the inverted range, e.g.
/// (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n
-Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,
- bool Inverted) {
+Value *InstCombinerImpl::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,
+ bool Inverted) {
// Check the lower range comparison, e.g. x >= 0
// InstCombine already ensured that if there is a constant it's on the RHS.
ConstantInt *RangeStart = dyn_cast<ConstantInt>(Cmp0->getOperand(1));
@@ -856,8 +796,9 @@ foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS,
// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
-Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
- BinaryOperator &Logic) {
+Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS,
+ ICmpInst *RHS,
+ BinaryOperator &Logic) {
bool JoinedByAnd = Logic.getOpcode() == Instruction::And;
assert((JoinedByAnd || Logic.getOpcode() == Instruction::Or) &&
"Wrong opcode");
@@ -869,10 +810,8 @@ Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ)
return nullptr;
- // TODO support vector splats
- ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
- ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
- if (!LHSC || !RHSC || !LHSC->isZero() || !RHSC->isZero())
+ if (!match(LHS->getOperand(1), m_Zero()) ||
+ !match(RHS->getOperand(1), m_Zero()))
return nullptr;
Value *A, *B, *C, *D;
@@ -1148,11 +1087,12 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
assert((IsAnd || Logic.getOpcode() == Instruction::Or) && "Wrong logic op");
// Match an equality compare with a non-poison constant as Cmp0.
+ // Also, give up if the compare can be constant-folded to avoid looping.
ICmpInst::Predicate Pred0;
Value *X;
Constant *C;
if (!match(Cmp0, m_ICmp(Pred0, m_Value(X), m_Constant(C))) ||
- !isGuaranteedNotToBeUndefOrPoison(C))
+ !isGuaranteedNotToBeUndefOrPoison(C) || isa<Constant>(X))
return nullptr;
if ((IsAnd && Pred0 != ICmpInst::ICMP_EQ) ||
(!IsAnd && Pred0 != ICmpInst::ICMP_NE))
@@ -1183,8 +1123,8 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
}
/// Fold (icmp)&(icmp) if possible.
-Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
- BinaryOperator &And) {
+Value *InstCombinerImpl::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
+ BinaryOperator &And) {
const SimplifyQuery Q = SQ.getWithInstruction(&And);
// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
@@ -1243,9 +1183,10 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
// This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2).
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
- ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
- ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
- if (!LHSC || !RHSC)
+
+ ConstantInt *LHSC, *RHSC;
+ if (!match(LHS->getOperand(1), m_ConstantInt(LHSC)) ||
+ !match(RHS->getOperand(1), m_ConstantInt(RHSC)))
return nullptr;
if (LHSC == RHSC && PredL == PredR) {
@@ -1403,7 +1344,8 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
return nullptr;
}
-Value *InstCombiner::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) {
+Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS,
+ bool IsAnd) {
Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1);
Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1);
FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
@@ -1513,8 +1455,8 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,
Value *A, *B;
if (match(I.getOperand(0), m_OneUse(m_Not(m_Value(A)))) &&
match(I.getOperand(1), m_OneUse(m_Not(m_Value(B)))) &&
- !isFreeToInvert(A, A->hasOneUse()) &&
- !isFreeToInvert(B, B->hasOneUse())) {
+ !InstCombiner::isFreeToInvert(A, A->hasOneUse()) &&
+ !InstCombiner::isFreeToInvert(B, B->hasOneUse())) {
Value *AndOr = Builder.CreateBinOp(Opcode, A, B, I.getName() + ".demorgan");
return BinaryOperator::CreateNot(AndOr);
}
@@ -1522,7 +1464,7 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,
return nullptr;
}
-bool InstCombiner::shouldOptimizeCast(CastInst *CI) {
+bool InstCombinerImpl::shouldOptimizeCast(CastInst *CI) {
Value *CastSrc = CI->getOperand(0);
// Noop casts and casts of constants should be eliminated trivially.
@@ -1578,7 +1520,7 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
}
/// Fold {and,or,xor} (cast X), Y.
-Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) {
+Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) {
auto LogicOpc = I.getOpcode();
assert(I.isBitwiseLogicOp() && "Unexpected opcode for bitwise logic folding");
@@ -1685,6 +1627,14 @@ static Instruction *foldOrToXor(BinaryOperator &I,
match(Op1, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
return BinaryOperator::CreateNot(Builder.CreateXor(A, B));
+ // Operand complexity canonicalization guarantees that the 'xor' is Op0.
+ // (A ^ B) | ~(A | B) --> ~(A & B)
+ // (A ^ B) | ~(B | A) --> ~(A & B)
+ if (Op0->hasOneUse() || Op1->hasOneUse())
+ if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
+ match(Op1, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
+ return BinaryOperator::CreateNot(Builder.CreateAnd(A, B));
+
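The new or-to-not-and fold is a per-bit truth-table fact: "exactly one of the bits is set, or neither is" is the complement of "both are set". An exhaustive i8 check (illustrative sketch only):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t A = 0; A < 256; ++A)
    for (uint32_t B = 0; B < 256; ++B)
      // (A ^ B) | ~(A | B) --> ~(A & B)
      assert(uint8_t((A ^ B) | ~(A | B)) == uint8_t(~(A & B)));
  return 0;
}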
// (A & ~B) | (~A & B) --> A ^ B
// (A & ~B) | (B & ~A) --> A ^ B
// (~B & A) | (~A & B) --> A ^ B
@@ -1699,32 +1649,13 @@ static Instruction *foldOrToXor(BinaryOperator &I,
/// Return true if a constant shift amount is always less than the specified
/// bit-width. If not, the shift could create poison in the narrower type.
static bool canNarrowShiftAmt(Constant *C, unsigned BitWidth) {
- if (auto *ScalarC = dyn_cast<ConstantInt>(C))
- return ScalarC->getZExtValue() < BitWidth;
-
- if (C->getType()->isVectorTy()) {
- // Check each element of a constant vector.
- unsigned NumElts = cast<VectorType>(C->getType())->getNumElements();
- for (unsigned i = 0; i != NumElts; ++i) {
- Constant *Elt = C->getAggregateElement(i);
- if (!Elt)
- return false;
- if (isa<UndefValue>(Elt))
- continue;
- auto *CI = dyn_cast<ConstantInt>(Elt);
- if (!CI || CI->getZExtValue() >= BitWidth)
- return false;
- }
- return true;
- }
-
- // The constant is a constant expression or unknown.
- return false;
+ APInt Threshold(C->getType()->getScalarSizeInBits(), BitWidth);
+ return match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold));
}
/// Try to use narrower ops (sink zext ops) for an 'and' with binop operand and
/// a common zext operand: and (binop (zext X), C), (zext X).
-Instruction *InstCombiner::narrowMaskedBinOp(BinaryOperator &And) {
+Instruction *InstCombinerImpl::narrowMaskedBinOp(BinaryOperator &And) {
// This transform could also apply to {or, and, xor}, but there are better
// folds for those cases, so we don't expect those patterns here. AShr is not
// handled because it should always be transformed to LShr in this sequence.
@@ -1766,7 +1697,9 @@ Instruction *InstCombiner::narrowMaskedBinOp(BinaryOperator &And) {
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
-Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
+ Type *Ty = I.getType();
+
if (Value *V = SimplifyAndInst(I.getOperand(0), I.getOperand(1),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -1794,21 +1727,22 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
return replaceInstUsesWith(I, V);
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+
+ Value *X, *Y;
+ if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X)))) &&
+ match(Op1, m_One())) {
+ // (1 << X) & 1 --> zext(X == 0)
+ // (1 >> X) & 1 --> zext(X == 0)
+ Value *IsZero = Builder.CreateICmpEQ(X, ConstantInt::get(Ty, 0));
+ return new ZExtInst(IsZero, Ty);
+ }
+
const APInt *C;
if (match(Op1, m_APInt(C))) {
- Value *X, *Y;
- if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X)))) &&
- C->isOneValue()) {
- // (1 << X) & 1 --> zext(X == 0)
- // (1 >> X) & 1 --> zext(X == 0)
- Value *IsZero = Builder.CreateICmpEQ(X, ConstantInt::get(I.getType(), 0));
- return new ZExtInst(IsZero, I.getType());
- }
-
const APInt *XorC;
if (match(Op0, m_OneUse(m_Xor(m_Value(X), m_APInt(XorC))))) {
// (X ^ C1) & C2 --> (X & C2) ^ (C1&C2)
- Constant *NewC = ConstantInt::get(I.getType(), *C & *XorC);
+ Constant *NewC = ConstantInt::get(Ty, *C & *XorC);
Value *And = Builder.CreateAnd(X, Op1);
And->takeName(Op0);
return BinaryOperator::CreateXor(And, NewC);
@@ -1823,11 +1757,9 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
// that aren't set in C2. Meaning we can replace (C1&C2) with C1 in
// above, but this feels safer.
APInt Together = *C & *OrC;
- Value *And = Builder.CreateAnd(X, ConstantInt::get(I.getType(),
- Together ^ *C));
+ Value *And = Builder.CreateAnd(X, ConstantInt::get(Ty, Together ^ *C));
And->takeName(Op0);
- return BinaryOperator::CreateOr(And, ConstantInt::get(I.getType(),
- Together));
+ return BinaryOperator::CreateOr(And, ConstantInt::get(Ty, Together));
}
// If the mask is only needed on one incoming arm, push the 'and' op up.
@@ -1848,27 +1780,49 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
return BinaryOperator::Create(BinOp, NewLHS, Y);
}
}
+
+ unsigned Width = Ty->getScalarSizeInBits();
const APInt *ShiftC;
if (match(Op0, m_OneUse(m_SExt(m_AShr(m_Value(X), m_APInt(ShiftC)))))) {
- unsigned Width = I.getType()->getScalarSizeInBits();
if (*C == APInt::getLowBitsSet(Width, Width - ShiftC->getZExtValue())) {
// We are clearing high bits that were potentially set by sext+ashr:
// and (sext (ashr X, ShiftC)), C --> lshr (sext X), ShiftC
- Value *Sext = Builder.CreateSExt(X, I.getType());
- Constant *ShAmtC = ConstantInt::get(I.getType(), ShiftC->zext(Width));
+ Value *Sext = Builder.CreateSExt(X, Ty);
+ Constant *ShAmtC = ConstantInt::get(Ty, ShiftC->zext(Width));
return BinaryOperator::CreateLShr(Sext, ShAmtC);
}
}
+
+ const APInt *AddC;
+ if (match(Op0, m_Add(m_Value(X), m_APInt(AddC)))) {
+ // If we add zeros to every bit below a mask, the add has no effect:
+ // (X + AddC) & LowMaskC --> X & LowMaskC
+ unsigned Ctlz = C->countLeadingZeros();
+ APInt LowMask(APInt::getLowBitsSet(Width, Width - Ctlz));
+ if ((*AddC & LowMask).isNullValue())
+ return BinaryOperator::CreateAnd(X, Op1);
+
+ // If we are masking the result of the add down to exactly one bit and
+ // the constant we are adding has no bits set below that bit, then the
+ // add is flipping a single bit. Example:
+ // (X + 4) & 4 --> (X & 4) ^ 4
+ if (Op0->hasOneUse() && C->isPowerOf2() && (*AddC & (*C - 1)) == 0) {
+ assert((*C & *AddC) != 0 && "Expected common bit");
+ Value *NewAnd = Builder.CreateAnd(X, Op1);
+ return BinaryOperator::CreateXor(NewAnd, Op1);
+ }
+ }
}
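The new 'and'-with-constant folds above are again small bit-arithmetic facts; a check over i8 with example constants (illustrative sketch; the constants 0x30, 0x0F and 4 are arbitrary choices satisfying the stated preconditions):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X < 8; ++X)
    // (1 << X) & 1 --> zext(X == 0)
    assert((((1u << X) & 1u) != 0) == (X == 0));
  for (uint32_t X = 0; X < 256; ++X) {
    // (X + AddC) & LowMaskC --> X & LowMaskC, when AddC has no bits at or below
    // the highest set bit of the mask (here AddC = 0x30, mask = 0x0F).
    assert(uint8_t((X + 0x30u) & 0x0Fu) == uint8_t(X & 0x0Fu));
    // (X + 4) & 4 --> (X & 4) ^ 4: adding the single mask bit just flips it.
    assert(uint8_t((X + 4u) & 4u) == uint8_t((X & 4u) ^ 4u));
  }
  return 0;
}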
- if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) {
+ ConstantInt *AndRHS;
+ if (match(Op1, m_ConstantInt(AndRHS))) {
const APInt &AndRHSMask = AndRHS->getValue();
// Optimize a variety of ((val OP C1) & C2) combinations...
if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) {
// ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth
// of X and OP behaves well when given trunc(C1) and X.
- // TODO: Do this for vectors by using m_APInt isntead of m_ConstantInt.
+ // TODO: Do this for vectors by using m_APInt instead of m_ConstantInt.
switch (Op0I->getOpcode()) {
default:
break;
@@ -1893,31 +1847,30 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
BinOp = Builder.CreateBinOp(Op0I->getOpcode(), TruncC1, X);
auto *TruncC2 = ConstantExpr::getTrunc(AndRHS, X->getType());
auto *And = Builder.CreateAnd(BinOp, TruncC2);
- return new ZExtInst(And, I.getType());
+ return new ZExtInst(And, Ty);
}
}
}
-
- if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1)))
- if (Instruction *Res = OptAndOp(Op0I, Op0CI, AndRHS, I))
- return Res;
}
+ }
- // If this is an integer truncation, and if the source is an 'and' with
- // immediate, transform it. This frequently occurs for bitfield accesses.
- {
- Value *X = nullptr; ConstantInt *YC = nullptr;
- if (match(Op0, m_Trunc(m_And(m_Value(X), m_ConstantInt(YC))))) {
- // Change: and (trunc (and X, YC) to T), C2
- // into : and (trunc X to T), trunc(YC) & C2
- // This will fold the two constants together, which may allow
- // other simplifications.
- Value *NewCast = Builder.CreateTrunc(X, I.getType(), "and.shrunk");
- Constant *C3 = ConstantExpr::getTrunc(YC, I.getType());
- C3 = ConstantExpr::getAnd(C3, AndRHS);
- return BinaryOperator::CreateAnd(NewCast, C3);
- }
- }
+ if (match(&I, m_And(m_OneUse(m_Shl(m_ZExt(m_Value(X)), m_Value(Y))),
+ m_SignMask())) &&
+ match(Y, m_SpecificInt_ICMP(
+ ICmpInst::Predicate::ICMP_EQ,
+ APInt(Ty->getScalarSizeInBits(),
+ Ty->getScalarSizeInBits() -
+ X->getType()->getScalarSizeInBits())))) {
+ auto *SExt = Builder.CreateSExt(X, Ty, X->getName() + ".signext");
+ auto *SanitizedSignMask = cast<Constant>(Op1);
+ // We must be careful with the undef elements of the sign bit mask, however:
+ // the mask elt can be undef iff the shift amount for that lane was undef,
+ // otherwise we need to sanitize undef masks to zero.
+ SanitizedSignMask = Constant::replaceUndefsWith(
+ SanitizedSignMask, ConstantInt::getNullValue(Ty->getScalarType()));
+ SanitizedSignMask =
+ Constant::mergeUndefsWith(SanitizedSignMask, cast<Constant>(Y));
+ return BinaryOperator::CreateAnd(SExt, SanitizedSignMask);
}
if (Instruction *Z = narrowMaskedBinOp(I))
@@ -1938,6 +1891,13 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
if (match(Op0, m_OneUse(m_c_Xor(m_Specific(Op1), m_Value(B)))))
return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(B));
+ // A & ~(A ^ B) --> A & B
+ if (match(Op1, m_Not(m_c_Xor(m_Specific(Op0), m_Value(B)))))
+ return BinaryOperator::CreateAnd(Op0, B);
+ // ~(A ^ B) & A --> A & B
+ if (match(Op0, m_Not(m_c_Xor(m_Specific(Op1), m_Value(B)))))
+ return BinaryOperator::CreateAnd(Op1, B);
+
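~(A ^ B) keeps exactly the bit positions where A and B agree, so ANDing it with A leaves the bits where both are 1. An exhaustive i8 check of the identity shared by the two new folds (illustrative sketch only):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t A = 0; A < 256; ++A)
    for (uint32_t B = 0; B < 256; ++B)
      // A & ~(A ^ B) --> A & B
      assert(uint8_t(A & ~(A ^ B)) == uint8_t(A & B));
  return 0;
}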
// (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C
if (match(Op0, m_Xor(m_Value(A), m_Value(B))))
if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A))))
@@ -1976,7 +1936,6 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
// TODO: Make this recursive; it's a little tricky because an arbitrary
// number of 'and' instructions might have to be created.
- Value *X, *Y;
if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) {
if (auto *Cmp = dyn_cast<ICmpInst>(X))
if (Value *Res = foldAndOfICmps(LHS, Cmp, I))
@@ -2010,29 +1969,30 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
Value *A;
if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) &&
A->getType()->isIntOrIntVectorTy(1))
- return SelectInst::Create(A, Op1, Constant::getNullValue(I.getType()));
+ return SelectInst::Create(A, Op1, Constant::getNullValue(Ty));
if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) &&
A->getType()->isIntOrIntVectorTy(1))
- return SelectInst::Create(A, Op0, Constant::getNullValue(I.getType()));
+ return SelectInst::Create(A, Op0, Constant::getNullValue(Ty));
// and(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? X : 0.
- {
- Value *X, *Y;
- const APInt *ShAmt;
- Type *Ty = I.getType();
- if (match(&I, m_c_And(m_OneUse(m_AShr(m_NSWSub(m_Value(Y), m_Value(X)),
- m_APInt(ShAmt))),
- m_Deferred(X))) &&
- *ShAmt == Ty->getScalarSizeInBits() - 1) {
- Value *NewICmpInst = Builder.CreateICmpSGT(X, Y);
- return SelectInst::Create(NewICmpInst, X, ConstantInt::getNullValue(Ty));
- }
+ if (match(&I, m_c_And(m_OneUse(m_AShr(
+ m_NSWSub(m_Value(Y), m_Value(X)),
+ m_SpecificInt(Ty->getScalarSizeInBits() - 1))),
+ m_Deferred(X)))) {
+ Value *NewICmpInst = Builder.CreateICmpSGT(X, Y);
+ return SelectInst::Create(NewICmpInst, X, ConstantInt::getNullValue(Ty));
}
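The select fold above works because, when Y - X cannot overflow (the nsw guarantee), shifting the difference right by BitWidth-1 arithmetically yields all-ones exactly when Y < X. A check over i8 pairs whose difference fits (illustrative sketch; it assumes the usual arithmetic behavior of >> on signed values):

#include <cassert>
#include <cstdint>

int main() {
  for (int X = -128; X <= 127; ++X)
    for (int Y = -128; Y <= 127; ++Y) {
      int D = Y - X;
      if (D < -128 || D > 127)
        continue; // outside the nsw precondition
      uint8_t Mask = uint8_t(int8_t(D) >> 7); // 0xFF iff Y < X
      // and(ashr(subNSW(Y, X), 7), X) --> X s> Y ? X : 0
      assert(uint8_t(Mask & uint8_t(X)) == uint8_t(X > Y ? X : 0));
    }
  return 0;
}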
+ // (~x) & y --> ~(x | (~y)) iff that gets rid of inversions
+ if (sinkNotIntoOtherHandOfAndOrOr(I))
+ return &I;
+
return nullptr;
}
-Instruction *InstCombiner::matchBSwap(BinaryOperator &Or) {
+Instruction *InstCombinerImpl::matchBSwapOrBitReverse(BinaryOperator &Or,
+ bool MatchBSwaps,
+ bool MatchBitReversals) {
assert(Or.getOpcode() == Instruction::Or && "bswap requires an 'or'");
Value *Op0 = Or.getOperand(0), *Op1 = Or.getOperand(1);
@@ -2044,33 +2004,32 @@ Instruction *InstCombiner::matchBSwap(BinaryOperator &Or) {
Op1 = Ext->getOperand(0);
// (A | B) | C and A | (B | C) -> bswap if possible.
- bool OrOfOrs = match(Op0, m_Or(m_Value(), m_Value())) ||
- match(Op1, m_Or(m_Value(), m_Value()));
+ bool OrWithOrs = match(Op0, m_Or(m_Value(), m_Value())) ||
+ match(Op1, m_Or(m_Value(), m_Value()));
- // (A >> B) | (C << D) and (A << B) | (B >> C) -> bswap if possible.
- bool OrOfShifts = match(Op0, m_LogicalShift(m_Value(), m_Value())) &&
- match(Op1, m_LogicalShift(m_Value(), m_Value()));
+ // (A >> B) | C and (A << B) | C -> bswap if possible.
+ bool OrWithShifts = match(Op0, m_LogicalShift(m_Value(), m_Value())) ||
+ match(Op1, m_LogicalShift(m_Value(), m_Value()));
- // (A & B) | (C & D) -> bswap if possible.
- bool OrOfAnds = match(Op0, m_And(m_Value(), m_Value())) &&
- match(Op1, m_And(m_Value(), m_Value()));
+ // (A & B) | C and A | (B & C) -> bswap if possible.
+ bool OrWithAnds = match(Op0, m_And(m_Value(), m_Value())) ||
+ match(Op1, m_And(m_Value(), m_Value()));
- // (A << B) | (C & D) -> bswap if possible.
- // The bigger pattern here is ((A & C1) << C2) | ((B >> C2) & C1), which is a
- // part of the bswap idiom for specific values of C1, C2 (e.g. C1 = 16711935,
- // C2 = 8 for i32).
- // This pattern can occur when the operands of the 'or' are not canonicalized
- // for some reason (not having only one use, for example).
- bool OrOfAndAndSh = (match(Op0, m_LogicalShift(m_Value(), m_Value())) &&
- match(Op1, m_And(m_Value(), m_Value()))) ||
- (match(Op0, m_And(m_Value(), m_Value())) &&
- match(Op1, m_LogicalShift(m_Value(), m_Value())));
+ // fshl(A,B,C) | D and A | fshl(B,C,D) -> bswap if possible.
+ // fshr(A,B,C) | D and A | fshr(B,C,D) -> bswap if possible.
+  bool OrWithFunnels = match(Op0, m_FShl(m_Value(), m_Value(), m_Value())) ||
+                       match(Op0, m_FShr(m_Value(), m_Value(), m_Value())) ||
+                       match(Op1, m_FShl(m_Value(), m_Value(), m_Value())) ||
+                       match(Op1, m_FShr(m_Value(), m_Value(), m_Value()));

- if (!OrOfOrs && !OrOfShifts && !OrOfAnds && !OrOfAndAndSh)
+ // TODO: Do we need all these filtering checks or should we just rely on
+ // recognizeBSwapOrBitReverseIdiom + collectBitParts to reject them quickly?
+ if (!OrWithOrs && !OrWithShifts && !OrWithAnds && !OrWithFunnels)
return nullptr;
- SmallVector<Instruction*, 4> Insts;
- if (!recognizeBSwapOrBitReverseIdiom(&Or, true, false, Insts))
+ SmallVector<Instruction *, 4> Insts;
+ if (!recognizeBSwapOrBitReverseIdiom(&Or, MatchBSwaps, MatchBitReversals,
+ Insts))
return nullptr;
Instruction *LastInst = Insts.pop_back_val();
LastInst->removeFromParent();
@@ -2080,34 +2039,72 @@ Instruction *InstCombiner::matchBSwap(BinaryOperator &Or) {
return LastInst;
}
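One shape this matcher is meant to feed into recognizeBSwapOrBitReverseIdiom is the classic open-coded 32-bit byte swap, an or of shifts and masked shifts. A small reference check of that expression (illustrative sketch, not LLVM code):

#include <cassert>
#include <cstdint>

// Open-coded byte swap of the kind the or-matching above is filtering for.
static uint32_t bswap32(uint32_t X) {
  return (X >> 24) | ((X >> 8) & 0x0000FF00u) | ((X << 8) & 0x00FF0000u) |
         (X << 24);
}

int main() {
  const uint32_t Tests[] = {0x00000000u, 0x12345678u, 0xDEADBEEFu, 0xFF000001u};
  for (uint32_t X : Tests) {
    uint32_t B0 = X & 0xFFu, B1 = (X >> 8) & 0xFFu, B2 = (X >> 16) & 0xFFu,
             B3 = (X >> 24) & 0xFFu;
    assert(bswap32(X) == ((B0 << 24) | (B1 << 16) | (B2 << 8) | B3));
  }
  return 0;
}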
-/// Transform UB-safe variants of bitwise rotate to the funnel shift intrinsic.
-static Instruction *matchRotate(Instruction &Or) {
+/// Match UB-safe variants of the funnel shift intrinsic.
+static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) {
// TODO: Can we reduce the code duplication between this and the related
// rotate matching code under visitSelect and visitTrunc?
unsigned Width = Or.getType()->getScalarSizeInBits();
- if (!isPowerOf2_32(Width))
- return nullptr;
- // First, find an or'd pair of opposite shifts with the same shifted operand:
- // or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1)
+ // First, find an or'd pair of opposite shifts:
+ // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
BinaryOperator *Or0, *Or1;
if (!match(Or.getOperand(0), m_BinOp(Or0)) ||
!match(Or.getOperand(1), m_BinOp(Or1)))
return nullptr;
- Value *ShVal, *ShAmt0, *ShAmt1;
- if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) ||
- !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1)))))
- return nullptr;
+ Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
+ if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
+ !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
+ Or0->getOpcode() == Or1->getOpcode())
+ return nullptr;
+
+ // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
+ if (Or0->getOpcode() == BinaryOperator::LShr) {
+ std::swap(Or0, Or1);
+ std::swap(ShVal0, ShVal1);
+ std::swap(ShAmt0, ShAmt1);
+ }
+ assert(Or0->getOpcode() == BinaryOperator::Shl &&
+ Or1->getOpcode() == BinaryOperator::LShr &&
+ "Illegal or(shift,shift) pair");
+
+ // Match the shift amount operands for a funnel shift pattern. This always
+ // matches a subtraction on the R operand.
+ auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
+ // Check for constant shift amounts that sum to the bitwidth.
+ const APInt *LI, *RI;
+ if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
+ if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
+ return ConstantInt::get(L->getType(), *LI);
+
+ Constant *LC, *RC;
+ if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
+ match(L, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+ match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+ match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
+ return ConstantExpr::mergeUndefsWith(LC, RC);
+
+ // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
+ // We limit this to X < Width in case the backend re-expands the intrinsic,
+ // and has to reintroduce a shift modulo operation (InstCombine might remove
+ // it after this fold). This still doesn't guarantee that the final codegen
+ // will match this original pattern.
+ if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
+ KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
+ return KnownL.getMaxValue().ult(Width) ? L : nullptr;
+ }
- BinaryOperator::BinaryOps ShiftOpcode0 = Or0->getOpcode();
- BinaryOperator::BinaryOps ShiftOpcode1 = Or1->getOpcode();
- if (ShiftOpcode0 == ShiftOpcode1)
- return nullptr;
+ // For non-constant cases, the following patterns currently only work for
+ // rotation patterns.
+ // TODO: Add general funnel-shift compatible patterns.
+ if (ShVal0 != ShVal1)
+ return nullptr;
+
+ // For non-constant cases we don't support non-pow2 shift masks.
+ // TODO: Is it worth matching urem as well?
+ if (!isPowerOf2_32(Width))
+ return nullptr;
- // Match the shift amount operands for a rotate pattern. This always matches
- // a subtraction on the R operand.
- auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * {
// The shift amount may be masked with negation:
// (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
Value *X;
@@ -2123,23 +2120,25 @@ static Instruction *matchRotate(Instruction &Or) {
m_SpecificInt(Mask))))
return L;
+ if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+ match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
+ return L;
+
return nullptr;
};
Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
- bool SubIsOnLHS = false;
+ bool IsFshl = true; // Sub on LSHR.
if (!ShAmt) {
ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
- SubIsOnLHS = true;
+ IsFshl = false; // Sub on SHL.
}
if (!ShAmt)
return nullptr;
- bool IsFshl = (!SubIsOnLHS && ShiftOpcode0 == BinaryOperator::Shl) ||
- (SubIsOnLHS && ShiftOpcode1 == BinaryOperator::Shl);
Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
- return IntrinsicInst::Create(F, { ShVal, ShVal, ShAmt });
+ return IntrinsicInst::Create(F, {ShVal0, ShVal1, ShAmt});
}
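The masked-negation rotate shape this matcher accepts stays in range for every shift amount, including zero, which is what makes it UB-safe and equal to fshl(X, X, S). A reference check for 32-bit values (illustrative sketch only):

#include <cassert>
#include <cstdint>

static uint32_t rotlRef(uint32_t X, uint32_t S) {
  S &= 31u;
  return S ? (X << S) | (X >> (32u - S)) : X;
}

int main() {
  const uint32_t Tests[] = {0x00000001u, 0x80000000u, 0x12345678u, 0xFFFF0000u};
  for (uint32_t X : Tests)
    for (uint32_t S = 0; S < 64; ++S) {
      // (shl X, (S & 31)) | (lshr X, ((-S) & 31)) == fshl(X, X, S)
      uint32_t Pattern = (X << (S & 31u)) | (X >> ((0u - S) & 31u));
      assert(Pattern == rotlRef(X, S));
    }
  return 0;
}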
/// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns.
@@ -2197,7 +2196,7 @@ static Instruction *matchOrConcat(Instruction &Or,
/// If all elements of two constant vectors are 0/-1 and inverses, return true.
static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) {
- unsigned NumElts = cast<VectorType>(C1->getType())->getNumElements();
+ unsigned NumElts = cast<FixedVectorType>(C1->getType())->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *EltC1 = C1->getAggregateElement(i);
Constant *EltC2 = C2->getAggregateElement(i);
@@ -2215,7 +2214,7 @@ static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) {
/// We have an expression of the form (A & C) | (B & D). If A is a scalar or
/// vector composed of all-zeros or all-ones values and is the bitwise 'not' of
/// B, it can be used as the condition operand of a select instruction.
-Value *InstCombiner::getSelectCondition(Value *A, Value *B) {
+Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
// Step 1: We may have peeked through bitcasts in the caller.
// Exit immediately if we don't have (vector) integer types.
Type *Ty = A->getType();
@@ -2272,8 +2271,8 @@ Value *InstCombiner::getSelectCondition(Value *A, Value *B) {
/// We have an expression of the form (A & C) | (B & D). Try to simplify this
/// to "A' ? C : D", where A' is a boolean or vector of booleans.
-Value *InstCombiner::matchSelectFromAndOr(Value *A, Value *C, Value *B,
- Value *D) {
+Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B,
+ Value *D) {
// The potential condition of the select may be bitcasted. In that case, look
// through its bitcast and the corresponding bitcast of the 'not' condition.
Type *OrigType = A->getType();
@@ -2293,8 +2292,8 @@ Value *InstCombiner::matchSelectFromAndOr(Value *A, Value *C, Value *B,
}
/// Fold (icmp)|(icmp) if possible.
-Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
- BinaryOperator &Or) {
+Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
+ BinaryOperator &Or) {
const SimplifyQuery Q = SQ.getWithInstruction(&Or);
// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
@@ -2303,9 +2302,10 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
return V;
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
-
- ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
- ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
+ Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
+ Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1);
+ auto *LHSC = dyn_cast<ConstantInt>(LHS1);
+ auto *RHSC = dyn_cast<ConstantInt>(RHS1);
// Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3)
// --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3)
@@ -2317,24 +2317,20 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
// 3) C1 ^ C2 is one-bit mask.
// 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask.
// This implies all values in the two ranges differ by exactly one bit.
-
if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) &&
PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() &&
LHSC->getType() == RHSC->getType() &&
LHSC->getValue() == (RHSC->getValue())) {
- Value *LAdd = LHS->getOperand(0);
- Value *RAdd = RHS->getOperand(0);
-
- Value *LAddOpnd, *RAddOpnd;
+ Value *AddOpnd;
ConstantInt *LAddC, *RAddC;
- if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddC))) &&
- match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddC))) &&
+ if (match(LHS0, m_Add(m_Value(AddOpnd), m_ConstantInt(LAddC))) &&
+ match(RHS0, m_Add(m_Specific(AddOpnd), m_ConstantInt(RAddC))) &&
LAddC->getValue().ugt(LHSC->getValue()) &&
RAddC->getValue().ugt(LHSC->getValue())) {
APInt DiffC = LAddC->getValue() ^ RAddC->getValue();
- if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2()) {
+ if (DiffC.isPowerOf2()) {
ConstantInt *MaxAddC = nullptr;
if (LAddC->getValue().ult(RAddC->getValue()))
MaxAddC = RAddC;
@@ -2354,7 +2350,7 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
RangeDiff.ugt(LHSC->getValue())) {
Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC);
- Value *NewAnd = Builder.CreateAnd(LAddOpnd, MaskC);
+ Value *NewAnd = Builder.CreateAnd(AddOpnd, MaskC);
Value *NewAdd = Builder.CreateAdd(NewAnd, MaxAddC);
return Builder.CreateICmp(LHS->getPredicate(), NewAdd, LHSC);
}
@@ -2364,15 +2360,12 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
// (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B)
if (predicatesFoldable(PredL, PredR)) {
- if (LHS->getOperand(0) == RHS->getOperand(1) &&
- LHS->getOperand(1) == RHS->getOperand(0))
+ if (LHS0 == RHS1 && LHS1 == RHS0)
LHS->swapOperands();
- if (LHS->getOperand(0) == RHS->getOperand(0) &&
- LHS->getOperand(1) == RHS->getOperand(1)) {
- Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1);
+ if (LHS0 == RHS0 && LHS1 == RHS1) {
unsigned Code = getICmpCode(LHS) | getICmpCode(RHS);
bool IsSigned = LHS->isSigned() || RHS->isSigned();
- return getNewICmpValue(Code, IsSigned, Op0, Op1, Builder);
+ return getNewICmpValue(Code, IsSigned, LHS0, LHS1, Builder);
}
}
@@ -2381,31 +2374,30 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, false, Builder))
return V;
- Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
if (LHS->hasOneUse() || RHS->hasOneUse()) {
// (icmp eq B, 0) | (icmp ult A, B) -> (icmp ule A, B-1)
// (icmp eq B, 0) | (icmp ugt B, A) -> (icmp ule A, B-1)
Value *A = nullptr, *B = nullptr;
- if (PredL == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero()) {
+ if (PredL == ICmpInst::ICMP_EQ && match(LHS1, m_Zero())) {
B = LHS0;
- if (PredR == ICmpInst::ICMP_ULT && LHS0 == RHS->getOperand(1))
+ if (PredR == ICmpInst::ICMP_ULT && LHS0 == RHS1)
A = RHS0;
else if (PredR == ICmpInst::ICMP_UGT && LHS0 == RHS0)
- A = RHS->getOperand(1);
+ A = RHS1;
}
// (icmp ult A, B) | (icmp eq B, 0) -> (icmp ule A, B-1)
// (icmp ugt B, A) | (icmp eq B, 0) -> (icmp ule A, B-1)
- else if (PredR == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) {
+ else if (PredR == ICmpInst::ICMP_EQ && match(RHS1, m_Zero())) {
B = RHS0;
- if (PredL == ICmpInst::ICMP_ULT && RHS0 == LHS->getOperand(1))
+ if (PredL == ICmpInst::ICMP_ULT && RHS0 == LHS1)
A = LHS0;
- else if (PredL == ICmpInst::ICMP_UGT && LHS0 == RHS0)
- A = LHS->getOperand(1);
+ else if (PredL == ICmpInst::ICMP_UGT && RHS0 == LHS0)
+ A = LHS1;
}
- if (A && B)
+ if (A && B && B->getType()->isIntOrIntVectorTy())
return Builder.CreateICmp(
ICmpInst::ICMP_UGE,
- Builder.CreateAdd(B, ConstantInt::getSigned(B->getType(), -1)), A);
+ Builder.CreateAdd(B, Constant::getAllOnesValue(B->getType())), A);
}
if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, Or, Builder, Q))
@@ -2434,18 +2426,21 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/false, Q, Builder))
return X;
+ // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0)
+ // TODO: Remove this when foldLogOpOfMaskedICmps can handle vectors.
+ if (PredL == ICmpInst::ICMP_NE && match(LHS1, m_Zero()) &&
+ PredR == ICmpInst::ICMP_NE && match(RHS1, m_Zero()) &&
+ LHS0->getType()->isIntOrIntVectorTy() &&
+ LHS0->getType() == RHS0->getType()) {
+ Value *NewOr = Builder.CreateOr(LHS0, RHS0);
+ return Builder.CreateICmp(PredL, NewOr,
+ Constant::getNullValue(NewOr->getType()));
+ }
+
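
A quick standalone check of the (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) fold added above, exhaustive over 8-bit values (illustration only, not part of the patch):

#include <cassert>

int main() {
  for (unsigned A = 0; A != 256; ++A)
    for (unsigned B = 0; B != 256; ++B) {
      bool Split = (A != 0) || (B != 0);   // (icmp ne A, 0) | (icmp ne B, 0)
      bool Fused = (A | B) != 0;           // icmp ne (A|B), 0
      assert(Split == Fused);
    }
  return 0;
}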
// This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
if (!LHSC || !RHSC)
return nullptr;
- if (LHSC == RHSC && PredL == PredR) {
- // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0)
- if (PredL == ICmpInst::ICMP_NE && LHSC->isZero()) {
- Value *NewOr = Builder.CreateOr(LHS0, RHS0);
- return Builder.CreateICmp(PredL, NewOr, LHSC);
- }
- }
-
// (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1)
// iff C2 + CA == C1.
if (PredL == ICmpInst::ICMP_ULT && PredR == ICmpInst::ICMP_EQ) {
@@ -2559,7 +2554,7 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
-Instruction *InstCombiner::visitOr(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Value *V = SimplifyOrInst(I.getOperand(0), I.getOperand(1),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -2589,11 +2584,12 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I))
return FoldedLogic;
- if (Instruction *BSwap = matchBSwap(I))
+ if (Instruction *BSwap = matchBSwapOrBitReverse(I, /*MatchBSwaps*/ true,
+ /*MatchBitReversals*/ false))
return BSwap;
- if (Instruction *Rotate = matchRotate(I))
- return Rotate;
+ if (Instruction *Funnel = matchFunnelShift(I, *this))
+ return Funnel;
if (Instruction *Concat = matchOrConcat(I, Builder))
return replaceInstUsesWith(I, Concat);
@@ -2613,9 +2609,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
Value *A, *B, *C, *D;
if (match(Op0, m_And(m_Value(A), m_Value(C))) &&
match(Op1, m_And(m_Value(B), m_Value(D)))) {
- ConstantInt *C1 = dyn_cast<ConstantInt>(C);
- ConstantInt *C2 = dyn_cast<ConstantInt>(D);
- if (C1 && C2) { // (A & C1)|(B & C2)
+ // (A & C1)|(B & C2)
+ ConstantInt *C1, *C2;
+ if (match(C, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2))) {
Value *V1 = nullptr, *V2 = nullptr;
if ((C1->getValue() & C2->getValue()).isNullValue()) {
// ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2)
@@ -2806,7 +2802,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
// ORs in the hopes that we'll be able to simplify it this way.
// (X|C) | V --> (X|V) | C
ConstantInt *CI;
- if (Op0->hasOneUse() && !isa<ConstantInt>(Op1) &&
+ if (Op0->hasOneUse() && !match(Op1, m_ConstantInt()) &&
match(Op0, m_Or(m_Value(A), m_ConstantInt(CI)))) {
Value *Inner = Builder.CreateOr(A, Op1);
Inner->takeName(Op0);
@@ -2827,18 +2823,17 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
}
}
- // or(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? -1 : X.
+ // or(ashr(subNSW(Y, X), ScalarSizeInBits(Y) - 1), X) --> X s> Y ? -1 : X.
{
Value *X, *Y;
- const APInt *ShAmt;
Type *Ty = I.getType();
- if (match(&I, m_c_Or(m_OneUse(m_AShr(m_NSWSub(m_Value(Y), m_Value(X)),
- m_APInt(ShAmt))),
- m_Deferred(X))) &&
- *ShAmt == Ty->getScalarSizeInBits() - 1) {
+ if (match(&I, m_c_Or(m_OneUse(m_AShr(
+ m_NSWSub(m_Value(Y), m_Value(X)),
+ m_SpecificInt(Ty->getScalarSizeInBits() - 1))),
+ m_Deferred(X)))) {
Value *NewICmpInst = Builder.CreateICmpSGT(X, Y);
- return SelectInst::Create(NewICmpInst, ConstantInt::getAllOnesValue(Ty),
- X);
+ Value *AllOnes = ConstantInt::getAllOnesValue(Ty);
+ return SelectInst::Create(NewICmpInst, AllOnes, X);
}
}
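
The or(ashr(subNSW(Y, X), bitwidth - 1), X) --> X s> Y ? -1 : X rewrite above relies on the ashr producing all-ones exactly when Y - X is negative. A standalone C++ sketch, illustration only; it assumes arithmetic right shift on signed int (guaranteed since C++20 and the universal behaviour in practice), and the pairs are chosen so Y - X cannot overflow, matching the nsw precondition:

#include <cassert>
#include <cstdint>

int main() {
  const int32_t Pairs[][2] = {{5, 7}, {7, 5}, {-3, 4}, {0, 0}, {100, -100}};
  for (auto &P : Pairs) {
    int32_t X = P[0], Y = P[1];
    int32_t OrForm = ((Y - X) >> 31) | X;  // or(ashr(sub nsw Y, X, 31), X)
    int32_t SelForm = (X > Y) ? -1 : X;    // select (X s> Y), -1, X
    assert(OrForm == SelForm);
  }
  return 0;
}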
@@ -2872,6 +2867,10 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
}
}
+ // (~x) | y --> ~(x & (~y)) iff that gets rid of inversions
+ if (sinkNotIntoOtherHandOfAndOrOr(I))
+ return &I;
+
return nullptr;
}
@@ -2928,8 +2927,8 @@ static Instruction *foldXorToXor(BinaryOperator &I,
return nullptr;
}
-Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS,
- BinaryOperator &I) {
+Value *InstCombinerImpl::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS,
+ BinaryOperator &I) {
assert(I.getOpcode() == Instruction::Xor && I.getOperand(0) == LHS &&
I.getOperand(1) == RHS && "Should be 'xor' with these operands");
@@ -3087,9 +3086,9 @@ static Instruction *sinkNotIntoXor(BinaryOperator &I,
return nullptr;
// We only want to do the transform if it is free to do.
- if (isFreeToInvert(X, X->hasOneUse())) {
+ if (InstCombiner::isFreeToInvert(X, X->hasOneUse())) {
// Ok, good.
- } else if (isFreeToInvert(Y, Y->hasOneUse())) {
+ } else if (InstCombiner::isFreeToInvert(Y, Y->hasOneUse())) {
std::swap(X, Y);
} else
return nullptr;
@@ -3098,10 +3097,52 @@ static Instruction *sinkNotIntoXor(BinaryOperator &I,
return BinaryOperator::CreateXor(NotX, Y, I.getName() + ".demorgan");
}
+// Transform
+// z = (~x) &/| y
+// into:
+// z = ~(x |/& (~y))
+// iff y is free to invert and all uses of z can be freely updated.
+bool InstCombinerImpl::sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I) {
+ Instruction::BinaryOps NewOpc;
+ switch (I.getOpcode()) {
+ case Instruction::And:
+ NewOpc = Instruction::Or;
+ break;
+ case Instruction::Or:
+ NewOpc = Instruction::And;
+ break;
+ default:
+ return false;
+ };
+
+ Value *X, *Y;
+ if (!match(&I, m_c_BinOp(m_Not(m_Value(X)), m_Value(Y))))
+ return false;
+
+ // Will we be able to fold the `not` into Y eventually?
+ if (!InstCombiner::isFreeToInvert(Y, Y->hasOneUse()))
+ return false;
+
+ // And can our users be adapted?
+ if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr))
+ return false;
+
+ Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not");
+ Value *NewBinOp =
+ BinaryOperator::Create(NewOpc, X, NotY, I.getName() + ".not");
+ Builder.Insert(NewBinOp);
+ replaceInstUsesWith(I, NewBinOp);
+ // We can not just create an outer `not`, it will most likely be immediately
+ // folded back, reconstructing our initial pattern, and causing an
+ // infinite combine loop, so immediately manually fold it away.
+ freelyInvertAllUsersOf(NewBinOp);
+ return true;
+}
+
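
sinkNotIntoOtherHandOfAndOrOr is plain De Morgan with the outer not folded into the users. The underlying identities can be spot-checked with a standalone C++ sketch (illustration only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Vals[] = {0u, 0xffffffffu, 0x12345678u, 0x0f0f0f0fu};
  for (uint32_t X : Vals)
    for (uint32_t Y : Vals) {
      assert((~X & Y) == ~(X | ~Y));  // (~x) & y == ~(x | (~y))
      assert((~X | Y) == ~(X & ~Y));  // (~x) | y == ~(x & (~y))
    }
  return 0;
}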
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
-Instruction *InstCombiner::visitXor(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
if (Value *V = SimplifyXorInst(I.getOperand(0), I.getOperand(1),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -3128,6 +3169,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
return replaceInstUsesWith(I, V);
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ Type *Ty = I.getType();
// Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M)
// This it a special case in haveNoCommonBitsSet, but the computeKnownBits
@@ -3199,7 +3241,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
match(C, m_Negative())) {
// We matched a negative constant, so propagating undef is unsafe.
// Clamp undef elements to -1.
- Type *EltTy = C->getType()->getScalarType();
+ Type *EltTy = Ty->getScalarType();
C = Constant::replaceUndefsWith(C, ConstantInt::getAllOnesValue(EltTy));
return BinaryOperator::CreateLShr(ConstantExpr::getNot(C), Y);
}
@@ -3209,7 +3251,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
match(C, m_NonNegative())) {
// We matched a non-negative constant, so propagating undef is unsafe.
// Clamp undef elements to 0.
- Type *EltTy = C->getType()->getScalarType();
+ Type *EltTy = Ty->getScalarType();
C = Constant::replaceUndefsWith(C, ConstantInt::getNullValue(EltTy));
return BinaryOperator::CreateAShr(ConstantExpr::getNot(C), Y);
}
@@ -3217,6 +3259,11 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
// ~(X + C) --> -(C + 1) - X
if (match(Op0, m_Add(m_Value(X), m_Constant(C))))
return BinaryOperator::CreateSub(ConstantExpr::getNeg(AddOne(C)), X);
+
+ // ~(~X + Y) --> X - Y
+ if (match(NotVal, m_c_Add(m_Not(m_Value(X)), m_Value(Y))))
+ return BinaryOperator::CreateWithCopiedFlags(Instruction::Sub, X, Y,
+ NotVal);
}
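
The ~(~X + Y) --> X - Y fold above is the two's-complement identity ~a == -a - 1 applied twice. A standalone C++ sketch over unsigned (i.e. modular) 32-bit values, illustration only:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Vals[] = {0u, 1u, 0xffffffffu, 0x80000000u, 12345u};
  for (uint32_t X : Vals)
    for (uint32_t Y : Vals)
      // ~(~X + Y) = -((-X - 1) + Y) - 1 = X - Y  (mod 2^32)
      assert(uint32_t(~(~X + Y)) == uint32_t(X - Y));
  return 0;
}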
// Use DeMorgan and reassociation to eliminate a 'not' op.
@@ -3247,52 +3294,56 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
if (match(Op1, m_APInt(RHSC))) {
Value *X;
const APInt *C;
- if (RHSC->isSignMask() && match(Op0, m_Sub(m_APInt(C), m_Value(X)))) {
- // (C - X) ^ signmask -> (C + signmask - X)
- Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC);
- return BinaryOperator::CreateSub(NewC, X);
- }
- if (RHSC->isSignMask() && match(Op0, m_Add(m_Value(X), m_APInt(C)))) {
- // (X + C) ^ signmask -> (X + C + signmask)
- Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC);
- return BinaryOperator::CreateAdd(X, NewC);
- }
+ // (C - X) ^ signmaskC --> (C + signmaskC) - X
+ if (RHSC->isSignMask() && match(Op0, m_Sub(m_APInt(C), m_Value(X))))
+ return BinaryOperator::CreateSub(ConstantInt::get(Ty, *C + *RHSC), X);
+
+ // (X + C) ^ signmaskC --> X + (C + signmaskC)
+ if (RHSC->isSignMask() && match(Op0, m_Add(m_Value(X), m_APInt(C))))
+ return BinaryOperator::CreateAdd(X, ConstantInt::get(Ty, *C + *RHSC));
- // (X|C1)^C2 -> X^(C1^C2) iff X&~C1 == 0
+ // (X | C) ^ RHSC --> X ^ (C ^ RHSC) iff X & C == 0
if (match(Op0, m_Or(m_Value(X), m_APInt(C))) &&
- MaskedValueIsZero(X, *C, 0, &I)) {
- Constant *NewC = ConstantInt::get(I.getType(), *C ^ *RHSC);
- return BinaryOperator::CreateXor(X, NewC);
+ MaskedValueIsZero(X, *C, 0, &I))
+ return BinaryOperator::CreateXor(X, ConstantInt::get(Ty, *C ^ *RHSC));
+
+ // If RHSC is inverting the remaining bits of shifted X,
+ // canonicalize to a 'not' before the shift to help SCEV and codegen:
+ // (X << C) ^ RHSC --> ~X << C
+ if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_APInt(C)))) &&
+ *RHSC == APInt::getAllOnesValue(Ty->getScalarSizeInBits()).shl(*C)) {
+ Value *NotX = Builder.CreateNot(X);
+ return BinaryOperator::CreateShl(NotX, ConstantInt::get(Ty, *C));
}
+ // (X >>u C) ^ RHSC --> ~X >>u C
+ if (match(Op0, m_OneUse(m_LShr(m_Value(X), m_APInt(C)))) &&
+ *RHSC == APInt::getAllOnesValue(Ty->getScalarSizeInBits()).lshr(*C)) {
+ Value *NotX = Builder.CreateNot(X);
+ return BinaryOperator::CreateLShr(NotX, ConstantInt::get(Ty, *C));
+ }
+ // TODO: We could handle 'ashr' here as well. That would be matching
+ // a 'not' op and moving it before the shift. Doing that requires
+ // preventing the inverse fold in canShiftBinOpWithConstantRHS().
}
}
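
The shift canonicalizations above rely on xor with an all-ones value, shifted the same way as X, behaving like a not applied before the shift. A standalone C++ spot check, illustration only:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Ones = 0xffffffffu;
  const uint32_t Vals[] = {0u, 1u, 0xdeadbeefu, 0x80000001u};
  for (uint32_t X : Vals)
    for (unsigned C = 0; C != 32; ++C) {
      assert(((X << C) ^ (Ones << C)) == (~X << C));  // (X << C) ^ RHSC --> ~X << C
      assert(((X >> C) ^ (Ones >> C)) == (~X >> C));  // (X >>u C) ^ RHSC --> ~X >>u C
    }
  return 0;
}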
- if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) {
- if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) {
- if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) {
- if (Op0I->getOpcode() == Instruction::LShr) {
- // ((X^C1) >> C2) ^ C3 -> (X>>C2) ^ ((C1>>C2)^C3)
- // E1 = "X ^ C1"
- BinaryOperator *E1;
- ConstantInt *C1;
- if (Op0I->hasOneUse() &&
- (E1 = dyn_cast<BinaryOperator>(Op0I->getOperand(0))) &&
- E1->getOpcode() == Instruction::Xor &&
- (C1 = dyn_cast<ConstantInt>(E1->getOperand(1)))) {
- // fold (C1 >> C2) ^ C3
- ConstantInt *C2 = Op0CI, *C3 = RHSC;
- APInt FoldConst = C1->getValue().lshr(C2->getValue());
- FoldConst ^= C3->getValue();
- // Prepare the two operands.
- Value *Opnd0 = Builder.CreateLShr(E1->getOperand(0), C2);
- Opnd0->takeName(Op0I);
- cast<Instruction>(Opnd0)->setDebugLoc(I.getDebugLoc());
- Value *FoldVal = ConstantInt::get(Opnd0->getType(), FoldConst);
-
- return BinaryOperator::CreateXor(Opnd0, FoldVal);
- }
- }
- }
+ // FIXME: This should not be limited to scalar (pull into APInt match above).
+ {
+ Value *X;
+ ConstantInt *C1, *C2, *C3;
+ // ((X^C1) >> C2) ^ C3 -> (X>>C2) ^ ((C1>>C2)^C3)
+ if (match(Op1, m_ConstantInt(C3)) &&
+ match(Op0, m_LShr(m_Xor(m_Value(X), m_ConstantInt(C1)),
+ m_ConstantInt(C2))) &&
+ Op0->hasOneUse()) {
+ // fold (C1 >> C2) ^ C3
+ APInt FoldConst = C1->getValue().lshr(C2->getValue());
+ FoldConst ^= C3->getValue();
+ // Prepare the two operands.
+ auto *Opnd0 = cast<Instruction>(Builder.CreateLShr(X, C2));
+ Opnd0->takeName(cast<Instruction>(Op0));
+ Opnd0->setDebugLoc(I.getDebugLoc());
+ return BinaryOperator::CreateXor(Opnd0, ConstantInt::get(Ty, FoldConst));
}
}
@@ -3349,6 +3400,25 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
match(Op1, m_Not(m_Specific(A))))
return BinaryOperator::CreateNot(Builder.CreateAnd(A, B));
+ // (~A & B) ^ A --> A | B -- There are 4 commuted variants.
+ if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(A)), m_Value(B)), m_Deferred(A))))
+ return BinaryOperator::CreateOr(A, B);
+
+ // (A | B) ^ (A | C) --> (B ^ C) & ~A -- There are 4 commuted variants.
+ // TODO: Loosen one-use restriction if common operand is a constant.
+ Value *D;
+ if (match(Op0, m_OneUse(m_Or(m_Value(A), m_Value(B)))) &&
+ match(Op1, m_OneUse(m_Or(m_Value(C), m_Value(D))))) {
+ if (B == C || B == D)
+ std::swap(A, B);
+ if (A == C)
+ std::swap(C, D);
+ if (A == D) {
+ Value *NotA = Builder.CreateNot(A);
+ return BinaryOperator::CreateAnd(Builder.CreateXor(B, C), NotA);
+ }
+ }
+
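
The (A | B) ^ (A | C) --> (B ^ C) & ~A fold can be confirmed bit by bit: wherever A is set both or results agree and the xor is zero, and wherever A is clear the xor reduces to B ^ C. A standalone C++ spot check, illustration only:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Vals[] = {0u, 0xffffffffu, 0xa5a5a5a5u, 0x12345678u};
  for (uint32_t A : Vals)
    for (uint32_t B : Vals)
      for (uint32_t C : Vals)
        assert(((A | B) ^ (A | C)) == ((B ^ C) & ~A));
  return 0;
}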
if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
if (Value *V = foldXorOfICmps(LHS, RHS, I))
@@ -3366,7 +3436,6 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
std::swap(Op0, Op1);
const APInt *ShAmt;
- Type *Ty = I.getType();
if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) &&
Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 &&
match(Op0, m_OneUse(m_c_Add(m_Specific(A), m_Specific(Op1))))) {
@@ -3425,19 +3494,30 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
}
}
- // Pull 'not' into operands of select if both operands are one-use compares.
+ // Pull 'not' into operands of select if both operands are one-use compares
+ // or one is one-use compare and the other one is a constant.
// Inverting the predicates eliminates the 'not' operation.
// Example:
- // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?) -->
+ // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?) -->
// select ?, (cmp InvTPred, ?, ?), (cmp InvFPred, ?, ?)
- // TODO: Canonicalize by hoisting 'not' into an arm of the select if only
- // 1 select operand is a cmp?
+ // not (select ?, (cmp TPred, ?, ?), true -->
+ // select ?, (cmp InvTPred, ?, ?), false
if (auto *Sel = dyn_cast<SelectInst>(Op0)) {
- auto *CmpT = dyn_cast<CmpInst>(Sel->getTrueValue());
- auto *CmpF = dyn_cast<CmpInst>(Sel->getFalseValue());
- if (CmpT && CmpF && CmpT->hasOneUse() && CmpF->hasOneUse()) {
- CmpT->setPredicate(CmpT->getInversePredicate());
- CmpF->setPredicate(CmpF->getInversePredicate());
+ Value *TV = Sel->getTrueValue();
+ Value *FV = Sel->getFalseValue();
+ auto *CmpT = dyn_cast<CmpInst>(TV);
+ auto *CmpF = dyn_cast<CmpInst>(FV);
+ bool InvertibleT = (CmpT && CmpT->hasOneUse()) || isa<Constant>(TV);
+ bool InvertibleF = (CmpF && CmpF->hasOneUse()) || isa<Constant>(FV);
+ if (InvertibleT && InvertibleF) {
+ if (CmpT)
+ CmpT->setPredicate(CmpT->getInversePredicate());
+ else
+ Sel->setTrueValue(ConstantExpr::getNot(cast<Constant>(TV)));
+ if (CmpF)
+ CmpF->setPredicate(CmpF->getInversePredicate());
+ else
+ Sel->setFalseValue(ConstantExpr::getNot(cast<Constant>(FV)));
return replaceInstUsesWith(I, Sel);
}
}
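
The extended select fold above only inverts things that are free to invert: a one-use compare flips its predicate and a constant arm is simply complemented, so not(select C, (A < B), true) becomes select C, (A >= B), false. A standalone boolean sketch, illustration only:

#include <cassert>

int main() {
  const int Vals[] = {-2, 0, 3};
  const bool Conds[] = {false, true};
  for (int A : Vals)
    for (int B : Vals)
      for (bool C : Conds) {
        bool NotOfSelect = !(C ? (A < B) : true);  // not (select C, (A < B), true)
        bool Folded = C ? (A >= B) : false;        // select C, (A >= B), false
        assert(NotOfSelect == Folded);
      }
  return 0;
}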
@@ -3446,5 +3526,15 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
if (Instruction *NewXor = sinkNotIntoXor(I, Builder))
return NewXor;
+ // Otherwise, if all else failed, try to hoist the xor-by-constant:
+ // (X ^ C) ^ Y --> (X ^ Y) ^ C
+ // Just like we do in other places, we completely avoid the fold
+ // for constantexprs, at least to avoid endless combine loop.
+ if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_CombineAnd(m_Value(X),
+ m_Unless(m_ConstantExpr())),
+ m_ImmConstant(C1))),
+ m_Value(Y))))
+ return BinaryOperator::CreateXor(Builder.CreateXor(X, Y), C1);
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
index ba1cf982229d..495493aab4b5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
@@ -9,8 +9,10 @@
// This file implements the visit functions for atomic rmw instructions.
//
//===----------------------------------------------------------------------===//
+
#include "InstCombineInternal.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
using namespace llvm;
@@ -30,7 +32,7 @@ bool isIdempotentRMW(AtomicRMWInst& RMWI) {
default:
return false;
};
-
+
auto C = dyn_cast<ConstantInt>(RMWI.getValOperand());
if(!C)
return false;
@@ -91,13 +93,13 @@ bool isSaturating(AtomicRMWInst& RMWI) {
return C->isMaxValue(false);
};
}
-}
+} // namespace
-Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
+Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
// Volatile RMWs perform a load and a store, we cannot replace this by just a
// load or just a store. We chose not to canonicalize out of general paranoia
- // about user expectations around volatile.
+ // about user expectations around volatile.
if (RMWI.isVolatile())
return nullptr;
@@ -115,7 +117,7 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
"AtomicRMWs don't make sense with Unordered or NotAtomic");
// Any atomicrmw xchg with no uses can be converted to a atomic store if the
- // ordering is compatible.
+ // ordering is compatible.
if (RMWI.getOperation() == AtomicRMWInst::Xchg &&
RMWI.use_empty()) {
if (Ordering != AtomicOrdering::Release &&
@@ -127,14 +129,14 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
SI->setAlignment(DL.getABITypeAlign(RMWI.getType()));
return eraseInstFromFunction(RMWI);
}
-
+
if (!isIdempotentRMW(RMWI))
return nullptr;
// We chose to canonicalize all idempotent operations to an single
// operation code and constant. This makes it easier for the rest of the
// optimizer to match easily. The choices of or w/0 and fadd w/-0.0 are
- // arbitrary.
+ // arbitrary.
if (RMWI.getType()->isIntegerTy() &&
RMWI.getOperation() != AtomicRMWInst::Or) {
RMWI.setOperation(AtomicRMWInst::Or);
@@ -149,7 +151,7 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
if (Ordering != AtomicOrdering::Acquire &&
Ordering != AtomicOrdering::Monotonic)
return nullptr;
-
+
LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand(), "",
false, DL.getABITypeAlign(RMWI.getType()),
Ordering, RMWI.getSyncScopeID());
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index c734c9a68fb2..5482b944e347 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -28,6 +28,7 @@
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Attributes.h"
@@ -47,9 +48,6 @@
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IntrinsicsHexagon.h"
-#include "llvm/IR/IntrinsicsNVPTX.h"
-#include "llvm/IR/IntrinsicsPowerPC.h"
-#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PatternMatch.h"
@@ -68,6 +66,7 @@
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SimplifyLibCalls.h"
#include <algorithm>
@@ -100,24 +99,7 @@ static Type *getPromotedType(Type *Ty) {
return Ty;
}
-/// Return a constant boolean vector that has true elements in all positions
-/// where the input constant data vector has an element with the sign bit set.
-static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) {
- SmallVector<Constant *, 32> BoolVec;
- IntegerType *BoolTy = Type::getInt1Ty(V->getContext());
- for (unsigned I = 0, E = V->getNumElements(); I != E; ++I) {
- Constant *Elt = V->getElementAsConstant(I);
- assert((isa<ConstantInt>(Elt) || isa<ConstantFP>(Elt)) &&
- "Unexpected constant data vector element type");
- bool Sign = V->getElementType()->isIntegerTy()
- ? cast<ConstantInt>(Elt)->isNegative()
- : cast<ConstantFP>(Elt)->isNegative();
- BoolVec.push_back(ConstantInt::get(BoolTy, Sign));
- }
- return ConstantVector::get(BoolVec);
-}
-
-Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
+Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
Align DstAlign = getKnownAlignment(MI->getRawDest(), DL, MI, &AC, &DT);
MaybeAlign CopyDstAlign = MI->getDestAlign();
if (!CopyDstAlign || *CopyDstAlign < DstAlign) {
@@ -232,7 +214,7 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
return MI;
}
-Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) {
+Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {
const Align KnownAlignment =
getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT);
MaybeAlign MemSetAlign = MI->getDestAlign();
@@ -292,820 +274,9 @@ Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) {
return nullptr;
}
-static Value *simplifyX86immShift(const IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder) {
- bool LogicalShift = false;
- bool ShiftLeft = false;
- bool IsImm = false;
-
- switch (II.getIntrinsicID()) {
- default: llvm_unreachable("Unexpected intrinsic!");
- case Intrinsic::x86_sse2_psrai_d:
- case Intrinsic::x86_sse2_psrai_w:
- case Intrinsic::x86_avx2_psrai_d:
- case Intrinsic::x86_avx2_psrai_w:
- case Intrinsic::x86_avx512_psrai_q_128:
- case Intrinsic::x86_avx512_psrai_q_256:
- case Intrinsic::x86_avx512_psrai_d_512:
- case Intrinsic::x86_avx512_psrai_q_512:
- case Intrinsic::x86_avx512_psrai_w_512:
- IsImm = true;
- LLVM_FALLTHROUGH;
- case Intrinsic::x86_sse2_psra_d:
- case Intrinsic::x86_sse2_psra_w:
- case Intrinsic::x86_avx2_psra_d:
- case Intrinsic::x86_avx2_psra_w:
- case Intrinsic::x86_avx512_psra_q_128:
- case Intrinsic::x86_avx512_psra_q_256:
- case Intrinsic::x86_avx512_psra_d_512:
- case Intrinsic::x86_avx512_psra_q_512:
- case Intrinsic::x86_avx512_psra_w_512:
- LogicalShift = false;
- ShiftLeft = false;
- break;
- case Intrinsic::x86_sse2_psrli_d:
- case Intrinsic::x86_sse2_psrli_q:
- case Intrinsic::x86_sse2_psrli_w:
- case Intrinsic::x86_avx2_psrli_d:
- case Intrinsic::x86_avx2_psrli_q:
- case Intrinsic::x86_avx2_psrli_w:
- case Intrinsic::x86_avx512_psrli_d_512:
- case Intrinsic::x86_avx512_psrli_q_512:
- case Intrinsic::x86_avx512_psrli_w_512:
- IsImm = true;
- LLVM_FALLTHROUGH;
- case Intrinsic::x86_sse2_psrl_d:
- case Intrinsic::x86_sse2_psrl_q:
- case Intrinsic::x86_sse2_psrl_w:
- case Intrinsic::x86_avx2_psrl_d:
- case Intrinsic::x86_avx2_psrl_q:
- case Intrinsic::x86_avx2_psrl_w:
- case Intrinsic::x86_avx512_psrl_d_512:
- case Intrinsic::x86_avx512_psrl_q_512:
- case Intrinsic::x86_avx512_psrl_w_512:
- LogicalShift = true;
- ShiftLeft = false;
- break;
- case Intrinsic::x86_sse2_pslli_d:
- case Intrinsic::x86_sse2_pslli_q:
- case Intrinsic::x86_sse2_pslli_w:
- case Intrinsic::x86_avx2_pslli_d:
- case Intrinsic::x86_avx2_pslli_q:
- case Intrinsic::x86_avx2_pslli_w:
- case Intrinsic::x86_avx512_pslli_d_512:
- case Intrinsic::x86_avx512_pslli_q_512:
- case Intrinsic::x86_avx512_pslli_w_512:
- IsImm = true;
- LLVM_FALLTHROUGH;
- case Intrinsic::x86_sse2_psll_d:
- case Intrinsic::x86_sse2_psll_q:
- case Intrinsic::x86_sse2_psll_w:
- case Intrinsic::x86_avx2_psll_d:
- case Intrinsic::x86_avx2_psll_q:
- case Intrinsic::x86_avx2_psll_w:
- case Intrinsic::x86_avx512_psll_d_512:
- case Intrinsic::x86_avx512_psll_q_512:
- case Intrinsic::x86_avx512_psll_w_512:
- LogicalShift = true;
- ShiftLeft = true;
- break;
- }
- assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left");
-
- auto Vec = II.getArgOperand(0);
- auto Amt = II.getArgOperand(1);
- auto VT = cast<VectorType>(Vec->getType());
- auto SVT = VT->getElementType();
- auto AmtVT = Amt->getType();
- unsigned VWidth = VT->getNumElements();
- unsigned BitWidth = SVT->getPrimitiveSizeInBits();
-
- // If the shift amount is guaranteed to be in-range we can replace it with a
- // generic shift. If its guaranteed to be out of range, logical shifts combine to
- // zero and arithmetic shifts are clamped to (BitWidth - 1).
- if (IsImm) {
- assert(AmtVT ->isIntegerTy(32) &&
- "Unexpected shift-by-immediate type");
- KnownBits KnownAmtBits =
- llvm::computeKnownBits(Amt, II.getModule()->getDataLayout());
- if (KnownAmtBits.getMaxValue().ult(BitWidth)) {
- Amt = Builder.CreateZExtOrTrunc(Amt, SVT);
- Amt = Builder.CreateVectorSplat(VWidth, Amt);
- return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt)
- : Builder.CreateLShr(Vec, Amt))
- : Builder.CreateAShr(Vec, Amt));
- }
- if (KnownAmtBits.getMinValue().uge(BitWidth)) {
- if (LogicalShift)
- return ConstantAggregateZero::get(VT);
- Amt = ConstantInt::get(SVT, BitWidth - 1);
- return Builder.CreateAShr(Vec, Builder.CreateVectorSplat(VWidth, Amt));
- }
- } else {
- // Ensure the first element has an in-range value and the rest of the
- // elements in the bottom 64 bits are zero.
- assert(AmtVT->isVectorTy() && AmtVT->getPrimitiveSizeInBits() == 128 &&
- cast<VectorType>(AmtVT)->getElementType() == SVT &&
- "Unexpected shift-by-scalar type");
- unsigned NumAmtElts = cast<VectorType>(AmtVT)->getNumElements();
- APInt DemandedLower = APInt::getOneBitSet(NumAmtElts, 0);
- APInt DemandedUpper = APInt::getBitsSet(NumAmtElts, 1, NumAmtElts / 2);
- KnownBits KnownLowerBits = llvm::computeKnownBits(
- Amt, DemandedLower, II.getModule()->getDataLayout());
- KnownBits KnownUpperBits = llvm::computeKnownBits(
- Amt, DemandedUpper, II.getModule()->getDataLayout());
- if (KnownLowerBits.getMaxValue().ult(BitWidth) &&
- (DemandedUpper.isNullValue() || KnownUpperBits.isZero())) {
- SmallVector<int, 16> ZeroSplat(VWidth, 0);
- Amt = Builder.CreateShuffleVector(Amt, Amt, ZeroSplat);
- return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt)
- : Builder.CreateLShr(Vec, Amt))
- : Builder.CreateAShr(Vec, Amt));
- }
- }
-
- // Simplify if count is constant vector.
- auto CDV = dyn_cast<ConstantDataVector>(Amt);
- if (!CDV)
- return nullptr;
-
- // SSE2/AVX2 uses all the first 64-bits of the 128-bit vector
- // operand to compute the shift amount.
- assert(AmtVT->isVectorTy() && AmtVT->getPrimitiveSizeInBits() == 128 &&
- cast<VectorType>(AmtVT)->getElementType() == SVT &&
- "Unexpected shift-by-scalar type");
-
- // Concatenate the sub-elements to create the 64-bit value.
- APInt Count(64, 0);
- for (unsigned i = 0, NumSubElts = 64 / BitWidth; i != NumSubElts; ++i) {
- unsigned SubEltIdx = (NumSubElts - 1) - i;
- auto SubElt = cast<ConstantInt>(CDV->getElementAsConstant(SubEltIdx));
- Count <<= BitWidth;
- Count |= SubElt->getValue().zextOrTrunc(64);
- }
-
- // If shift-by-zero then just return the original value.
- if (Count.isNullValue())
- return Vec;
-
- // Handle cases when Shift >= BitWidth.
- if (Count.uge(BitWidth)) {
- // If LogicalShift - just return zero.
- if (LogicalShift)
- return ConstantAggregateZero::get(VT);
-
- // If ArithmeticShift - clamp Shift to (BitWidth - 1).
- Count = APInt(64, BitWidth - 1);
- }
-
- // Get a constant vector of the same type as the first operand.
- auto ShiftAmt = ConstantInt::get(SVT, Count.zextOrTrunc(BitWidth));
- auto ShiftVec = Builder.CreateVectorSplat(VWidth, ShiftAmt);
-
- if (ShiftLeft)
- return Builder.CreateShl(Vec, ShiftVec);
-
- if (LogicalShift)
- return Builder.CreateLShr(Vec, ShiftVec);
-
- return Builder.CreateAShr(Vec, ShiftVec);
-}
-
-// Attempt to simplify AVX2 per-element shift intrinsics to a generic IR shift.
-// Unlike the generic IR shifts, the intrinsics have defined behaviour for out
-// of range shift amounts (logical - set to zero, arithmetic - splat sign bit).
-static Value *simplifyX86varShift(const IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder) {
- bool LogicalShift = false;
- bool ShiftLeft = false;
-
- switch (II.getIntrinsicID()) {
- default: llvm_unreachable("Unexpected intrinsic!");
- case Intrinsic::x86_avx2_psrav_d:
- case Intrinsic::x86_avx2_psrav_d_256:
- case Intrinsic::x86_avx512_psrav_q_128:
- case Intrinsic::x86_avx512_psrav_q_256:
- case Intrinsic::x86_avx512_psrav_d_512:
- case Intrinsic::x86_avx512_psrav_q_512:
- case Intrinsic::x86_avx512_psrav_w_128:
- case Intrinsic::x86_avx512_psrav_w_256:
- case Intrinsic::x86_avx512_psrav_w_512:
- LogicalShift = false;
- ShiftLeft = false;
- break;
- case Intrinsic::x86_avx2_psrlv_d:
- case Intrinsic::x86_avx2_psrlv_d_256:
- case Intrinsic::x86_avx2_psrlv_q:
- case Intrinsic::x86_avx2_psrlv_q_256:
- case Intrinsic::x86_avx512_psrlv_d_512:
- case Intrinsic::x86_avx512_psrlv_q_512:
- case Intrinsic::x86_avx512_psrlv_w_128:
- case Intrinsic::x86_avx512_psrlv_w_256:
- case Intrinsic::x86_avx512_psrlv_w_512:
- LogicalShift = true;
- ShiftLeft = false;
- break;
- case Intrinsic::x86_avx2_psllv_d:
- case Intrinsic::x86_avx2_psllv_d_256:
- case Intrinsic::x86_avx2_psllv_q:
- case Intrinsic::x86_avx2_psllv_q_256:
- case Intrinsic::x86_avx512_psllv_d_512:
- case Intrinsic::x86_avx512_psllv_q_512:
- case Intrinsic::x86_avx512_psllv_w_128:
- case Intrinsic::x86_avx512_psllv_w_256:
- case Intrinsic::x86_avx512_psllv_w_512:
- LogicalShift = true;
- ShiftLeft = true;
- break;
- }
- assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left");
-
- auto Vec = II.getArgOperand(0);
- auto Amt = II.getArgOperand(1);
- auto VT = cast<VectorType>(II.getType());
- auto SVT = VT->getElementType();
- int NumElts = VT->getNumElements();
- int BitWidth = SVT->getIntegerBitWidth();
-
- // If the shift amount is guaranteed to be in-range we can replace it with a
- // generic shift.
- APInt UpperBits =
- APInt::getHighBitsSet(BitWidth, BitWidth - Log2_32(BitWidth));
- if (llvm::MaskedValueIsZero(Amt, UpperBits,
- II.getModule()->getDataLayout())) {
- return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt)
- : Builder.CreateLShr(Vec, Amt))
- : Builder.CreateAShr(Vec, Amt));
- }
-
- // Simplify if all shift amounts are constant/undef.
- auto *CShift = dyn_cast<Constant>(Amt);
- if (!CShift)
- return nullptr;
-
- // Collect each element's shift amount.
- // We also collect special cases: UNDEF = -1, OUT-OF-RANGE = BitWidth.
- bool AnyOutOfRange = false;
- SmallVector<int, 8> ShiftAmts;
- for (int I = 0; I < NumElts; ++I) {
- auto *CElt = CShift->getAggregateElement(I);
- if (CElt && isa<UndefValue>(CElt)) {
- ShiftAmts.push_back(-1);
- continue;
- }
-
- auto *COp = dyn_cast_or_null<ConstantInt>(CElt);
- if (!COp)
- return nullptr;
-
- // Handle out of range shifts.
- // If LogicalShift - set to BitWidth (special case).
- // If ArithmeticShift - set to (BitWidth - 1) (sign splat).
- APInt ShiftVal = COp->getValue();
- if (ShiftVal.uge(BitWidth)) {
- AnyOutOfRange = LogicalShift;
- ShiftAmts.push_back(LogicalShift ? BitWidth : BitWidth - 1);
- continue;
- }
-
- ShiftAmts.push_back((int)ShiftVal.getZExtValue());
- }
-
- // If all elements out of range or UNDEF, return vector of zeros/undefs.
- // ArithmeticShift should only hit this if they are all UNDEF.
- auto OutOfRange = [&](int Idx) { return (Idx < 0) || (BitWidth <= Idx); };
- if (llvm::all_of(ShiftAmts, OutOfRange)) {
- SmallVector<Constant *, 8> ConstantVec;
- for (int Idx : ShiftAmts) {
- if (Idx < 0) {
- ConstantVec.push_back(UndefValue::get(SVT));
- } else {
- assert(LogicalShift && "Logical shift expected");
- ConstantVec.push_back(ConstantInt::getNullValue(SVT));
- }
- }
- return ConstantVector::get(ConstantVec);
- }
-
- // We can't handle only some out of range values with generic logical shifts.
- if (AnyOutOfRange)
- return nullptr;
-
- // Build the shift amount constant vector.
- SmallVector<Constant *, 8> ShiftVecAmts;
- for (int Idx : ShiftAmts) {
- if (Idx < 0)
- ShiftVecAmts.push_back(UndefValue::get(SVT));
- else
- ShiftVecAmts.push_back(ConstantInt::get(SVT, Idx));
- }
- auto ShiftVec = ConstantVector::get(ShiftVecAmts);
-
- if (ShiftLeft)
- return Builder.CreateShl(Vec, ShiftVec);
-
- if (LogicalShift)
- return Builder.CreateLShr(Vec, ShiftVec);
-
- return Builder.CreateAShr(Vec, ShiftVec);
-}
-
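
As the comments in the removed shift helpers note, the X86 vector shift intrinsics have defined behaviour for out-of-range amounts: logical shifts produce zero and arithmetic shifts clamp to bitwidth - 1, splatting the sign bit. This change drops these folds from the generic combiner (presumably in favour of the new TargetTransformInfo-based hook). A per-element scalar model in standalone C++, illustration only; the helper names are ad hoc and it assumes arithmetic right shift on signed int:

#include <cassert>
#include <cstdint>

// Per-element models of the semantics the removed helpers reason about.
static uint32_t psrl_model(uint32_t V, uint64_t Amt) {
  return Amt >= 32 ? 0u : V >> Amt;    // logical: out of range -> 0
}
static int32_t psra_model(int32_t V, uint64_t Amt) {
  return V >> (Amt >= 32 ? 31 : Amt);  // arithmetic: clamp to 31 (sign splat)
}

int main() {
  assert(psrl_model(0xdeadbeefu, 40) == 0u);
  assert(psra_model(-8, 100) == -1);   // sign bit splatted
  assert(psra_model(8, 100) == 0);
  assert(psrl_model(0x80u, 4) == 0x8u); // in-range amounts behave normally
  return 0;
}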
-static Value *simplifyX86pack(IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder, bool IsSigned) {
- Value *Arg0 = II.getArgOperand(0);
- Value *Arg1 = II.getArgOperand(1);
- Type *ResTy = II.getType();
-
- // Fast all undef handling.
- if (isa<UndefValue>(Arg0) && isa<UndefValue>(Arg1))
- return UndefValue::get(ResTy);
-
- auto *ArgTy = cast<VectorType>(Arg0->getType());
- unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128;
- unsigned NumSrcElts = ArgTy->getNumElements();
- assert(cast<VectorType>(ResTy)->getNumElements() == (2 * NumSrcElts) &&
- "Unexpected packing types");
-
- unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes;
- unsigned DstScalarSizeInBits = ResTy->getScalarSizeInBits();
- unsigned SrcScalarSizeInBits = ArgTy->getScalarSizeInBits();
- assert(SrcScalarSizeInBits == (2 * DstScalarSizeInBits) &&
- "Unexpected packing types");
-
- // Constant folding.
- if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1))
- return nullptr;
-
- // Clamp Values - signed/unsigned both use signed clamp values, but they
- // differ on the min/max values.
- APInt MinValue, MaxValue;
- if (IsSigned) {
- // PACKSS: Truncate signed value with signed saturation.
- // Source values less than dst minint are saturated to minint.
- // Source values greater than dst maxint are saturated to maxint.
- MinValue =
- APInt::getSignedMinValue(DstScalarSizeInBits).sext(SrcScalarSizeInBits);
- MaxValue =
- APInt::getSignedMaxValue(DstScalarSizeInBits).sext(SrcScalarSizeInBits);
- } else {
- // PACKUS: Truncate signed value with unsigned saturation.
- // Source values less than zero are saturated to zero.
- // Source values greater than dst maxuint are saturated to maxuint.
- MinValue = APInt::getNullValue(SrcScalarSizeInBits);
- MaxValue = APInt::getLowBitsSet(SrcScalarSizeInBits, DstScalarSizeInBits);
- }
-
- auto *MinC = Constant::getIntegerValue(ArgTy, MinValue);
- auto *MaxC = Constant::getIntegerValue(ArgTy, MaxValue);
- Arg0 = Builder.CreateSelect(Builder.CreateICmpSLT(Arg0, MinC), MinC, Arg0);
- Arg1 = Builder.CreateSelect(Builder.CreateICmpSLT(Arg1, MinC), MinC, Arg1);
- Arg0 = Builder.CreateSelect(Builder.CreateICmpSGT(Arg0, MaxC), MaxC, Arg0);
- Arg1 = Builder.CreateSelect(Builder.CreateICmpSGT(Arg1, MaxC), MaxC, Arg1);
-
- // Shuffle clamped args together at the lane level.
- SmallVector<int, 32> PackMask;
- for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
- for (unsigned Elt = 0; Elt != NumSrcEltsPerLane; ++Elt)
- PackMask.push_back(Elt + (Lane * NumSrcEltsPerLane));
- for (unsigned Elt = 0; Elt != NumSrcEltsPerLane; ++Elt)
- PackMask.push_back(Elt + (Lane * NumSrcEltsPerLane) + NumSrcElts);
- }
- auto *Shuffle = Builder.CreateShuffleVector(Arg0, Arg1, PackMask);
-
- // Truncate to dst size.
- return Builder.CreateTrunc(Shuffle, ResTy);
-}
-
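
The removed simplifyX86pack helper models PACKSS/PACKUS as a clamp to the destination range followed by truncation. Per element (dword-to-word case, PACKSSDW/PACKUSDW) that is ordinary saturation, sketched standalone in C++ below; illustration only, helper names are ad hoc:

#include <cassert>
#include <cstdint>

// Signed saturation: clamp to the signed i16 range, then truncate.
static int16_t packss_scalar(int32_t V) {
  if (V < INT16_MIN) V = INT16_MIN;
  if (V > INT16_MAX) V = INT16_MAX;
  return static_cast<int16_t>(V);
}

// Unsigned saturation: clamp to [0, 0xFFFF], then truncate.
static uint16_t packus_scalar(int32_t V) {
  if (V < 0) V = 0;
  if (V > 0xFFFF) V = 0xFFFF;
  return static_cast<uint16_t>(V);
}

int main() {
  assert(packss_scalar(70000) == INT16_MAX);
  assert(packss_scalar(-70000) == INT16_MIN);
  assert(packss_scalar(-5) == -5);
  assert(packus_scalar(-5) == 0);
  assert(packus_scalar(70000) == 0xFFFF);
  assert(packus_scalar(1234) == 1234);
  return 0;
}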
-static Value *simplifyX86movmsk(const IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder) {
- Value *Arg = II.getArgOperand(0);
- Type *ResTy = II.getType();
-
- // movmsk(undef) -> zero as we must ensure the upper bits are zero.
- if (isa<UndefValue>(Arg))
- return Constant::getNullValue(ResTy);
-
- auto *ArgTy = dyn_cast<VectorType>(Arg->getType());
- // We can't easily peek through x86_mmx types.
- if (!ArgTy)
- return nullptr;
-
- // Expand MOVMSK to compare/bitcast/zext:
- // e.g. PMOVMSKB(v16i8 x):
- // %cmp = icmp slt <16 x i8> %x, zeroinitializer
- // %int = bitcast <16 x i1> %cmp to i16
- // %res = zext i16 %int to i32
- unsigned NumElts = ArgTy->getNumElements();
- Type *IntegerVecTy = VectorType::getInteger(ArgTy);
- Type *IntegerTy = Builder.getIntNTy(NumElts);
-
- Value *Res = Builder.CreateBitCast(Arg, IntegerVecTy);
- Res = Builder.CreateICmpSLT(Res, Constant::getNullValue(IntegerVecTy));
- Res = Builder.CreateBitCast(Res, IntegerTy);
- Res = Builder.CreateZExtOrTrunc(Res, ResTy);
- return Res;
-}
-
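
The removed simplifyX86movmsk helper expands movmsk into a signed compare against zero, a bitcast of the i1 vector to an integer, and a zext: element I of the input contributes bit I of the result exactly when its sign bit is set. A scalar C++ model, illustration only, with an ad hoc helper name:

#include <cassert>
#include <cstdint>

// Scalar model of the PMOVMSKB-style expansion described above.
static uint32_t movmsk_scalar(const int8_t *Elts, unsigned NumElts) {
  uint32_t Mask = 0;
  for (unsigned I = 0; I != NumElts; ++I)
    if (Elts[I] < 0)                 // icmp slt element, 0
      Mask |= 1u << I;               // becomes bit I after bitcast + zext
  return Mask;
}

int main() {
  const int8_t V[16] = {-1, 0, 5, -128, 127, -7, 0, 0,
                        1, -1, -1, 2, 3, -4, 0, -100};
  // Elements 0, 3, 5, 9, 10, 13 and 15 are negative.
  assert(movmsk_scalar(V, 16) == ((1u << 0) | (1u << 3) | (1u << 5) |
                                  (1u << 9) | (1u << 10) | (1u << 13) |
                                  (1u << 15)));
  return 0;
}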
-static Value *simplifyX86addcarry(const IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder) {
- Value *CarryIn = II.getArgOperand(0);
- Value *Op1 = II.getArgOperand(1);
- Value *Op2 = II.getArgOperand(2);
- Type *RetTy = II.getType();
- Type *OpTy = Op1->getType();
- assert(RetTy->getStructElementType(0)->isIntegerTy(8) &&
- RetTy->getStructElementType(1) == OpTy && OpTy == Op2->getType() &&
- "Unexpected types for x86 addcarry");
-
- // If carry-in is zero, this is just an unsigned add with overflow.
- if (match(CarryIn, m_ZeroInt())) {
- Value *UAdd = Builder.CreateIntrinsic(Intrinsic::uadd_with_overflow, OpTy,
- { Op1, Op2 });
- // The types have to be adjusted to match the x86 call types.
- Value *UAddResult = Builder.CreateExtractValue(UAdd, 0);
- Value *UAddOV = Builder.CreateZExt(Builder.CreateExtractValue(UAdd, 1),
- Builder.getInt8Ty());
- Value *Res = UndefValue::get(RetTy);
- Res = Builder.CreateInsertValue(Res, UAddOV, 0);
- return Builder.CreateInsertValue(Res, UAddResult, 1);
- }
-
- return nullptr;
-}
-
-static Value *simplifyX86insertps(const IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder) {
- auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2));
- if (!CInt)
- return nullptr;
-
- VectorType *VecTy = cast<VectorType>(II.getType());
- assert(VecTy->getNumElements() == 4 && "insertps with wrong vector type");
-
- // The immediate permute control byte looks like this:
- // [3:0] - zero mask for each 32-bit lane
- // [5:4] - select one 32-bit destination lane
- // [7:6] - select one 32-bit source lane
-
- uint8_t Imm = CInt->getZExtValue();
- uint8_t ZMask = Imm & 0xf;
- uint8_t DestLane = (Imm >> 4) & 0x3;
- uint8_t SourceLane = (Imm >> 6) & 0x3;
-
- ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy);
-
- // If all zero mask bits are set, this was just a weird way to
- // generate a zero vector.
- if (ZMask == 0xf)
- return ZeroVector;
-
- // Initialize by passing all of the first source bits through.
- int ShuffleMask[4] = {0, 1, 2, 3};
-
- // We may replace the second operand with the zero vector.
- Value *V1 = II.getArgOperand(1);
-
- if (ZMask) {
- // If the zero mask is being used with a single input or the zero mask
- // overrides the destination lane, this is a shuffle with the zero vector.
- if ((II.getArgOperand(0) == II.getArgOperand(1)) ||
- (ZMask & (1 << DestLane))) {
- V1 = ZeroVector;
- // We may still move 32-bits of the first source vector from one lane
- // to another.
- ShuffleMask[DestLane] = SourceLane;
- // The zero mask may override the previous insert operation.
- for (unsigned i = 0; i < 4; ++i)
- if ((ZMask >> i) & 0x1)
- ShuffleMask[i] = i + 4;
- } else {
- // TODO: Model this case as 2 shuffles or a 'logical and' plus shuffle?
- return nullptr;
- }
- } else {
- // Replace the selected destination lane with the selected source lane.
- ShuffleMask[DestLane] = SourceLane + 4;
- }
-
- return Builder.CreateShuffleVector(II.getArgOperand(0), V1, ShuffleMask);
-}
-
-/// Attempt to simplify SSE4A EXTRQ/EXTRQI instructions using constant folding
-/// or conversion to a shuffle vector.
-static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0,
- ConstantInt *CILength, ConstantInt *CIIndex,
- InstCombiner::BuilderTy &Builder) {
- auto LowConstantHighUndef = [&](uint64_t Val) {
- Type *IntTy64 = Type::getInt64Ty(II.getContext());
- Constant *Args[] = {ConstantInt::get(IntTy64, Val),
- UndefValue::get(IntTy64)};
- return ConstantVector::get(Args);
- };
-
- // See if we're dealing with constant values.
- Constant *C0 = dyn_cast<Constant>(Op0);
- ConstantInt *CI0 =
- C0 ? dyn_cast_or_null<ConstantInt>(C0->getAggregateElement((unsigned)0))
- : nullptr;
-
- // Attempt to constant fold.
- if (CILength && CIIndex) {
- // From AMD documentation: "The bit index and field length are each six
- // bits in length other bits of the field are ignored."
- APInt APIndex = CIIndex->getValue().zextOrTrunc(6);
- APInt APLength = CILength->getValue().zextOrTrunc(6);
-
- unsigned Index = APIndex.getZExtValue();
-
- // From AMD documentation: "a value of zero in the field length is
- // defined as length of 64".
- unsigned Length = APLength == 0 ? 64 : APLength.getZExtValue();
-
- // From AMD documentation: "If the sum of the bit index + length field
- // is greater than 64, the results are undefined".
- unsigned End = Index + Length;
-
- // Note that both field index and field length are 8-bit quantities.
- // Since variables 'Index' and 'Length' are unsigned values
- // obtained from zero-extending field index and field length
- // respectively, their sum should never wrap around.
- if (End > 64)
- return UndefValue::get(II.getType());
-
- // If we are inserting whole bytes, we can convert this to a shuffle.
- // Lowering can recognize EXTRQI shuffle masks.
- if ((Length % 8) == 0 && (Index % 8) == 0) {
- // Convert bit indices to byte indices.
- Length /= 8;
- Index /= 8;
-
- Type *IntTy8 = Type::getInt8Ty(II.getContext());
- auto *ShufTy = FixedVectorType::get(IntTy8, 16);
-
- SmallVector<int, 16> ShuffleMask;
- for (int i = 0; i != (int)Length; ++i)
- ShuffleMask.push_back(i + Index);
- for (int i = Length; i != 8; ++i)
- ShuffleMask.push_back(i + 16);
- for (int i = 8; i != 16; ++i)
- ShuffleMask.push_back(-1);
-
- Value *SV = Builder.CreateShuffleVector(
- Builder.CreateBitCast(Op0, ShufTy),
- ConstantAggregateZero::get(ShufTy), ShuffleMask);
- return Builder.CreateBitCast(SV, II.getType());
- }
-
- // Constant Fold - shift Index'th bit to lowest position and mask off
- // Length bits.
- if (CI0) {
- APInt Elt = CI0->getValue();
- Elt.lshrInPlace(Index);
- Elt = Elt.zextOrTrunc(Length);
- return LowConstantHighUndef(Elt.getZExtValue());
- }
-
- // If we were an EXTRQ call, we'll save registers if we convert to EXTRQI.
- if (II.getIntrinsicID() == Intrinsic::x86_sse4a_extrq) {
- Value *Args[] = {Op0, CILength, CIIndex};
- Module *M = II.getModule();
- Function *F = Intrinsic::getDeclaration(M, Intrinsic::x86_sse4a_extrqi);
- return Builder.CreateCall(F, Args);
- }
- }
-
- // Constant Fold - extraction from zero is always {zero, undef}.
- if (CI0 && CI0->isZero())
- return LowConstantHighUndef(0);
-
- return nullptr;
-}
-
-/// Attempt to simplify SSE4A INSERTQ/INSERTQI instructions using constant
-/// folding or conversion to a shuffle vector.
-static Value *simplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1,
- APInt APLength, APInt APIndex,
- InstCombiner::BuilderTy &Builder) {
- // From AMD documentation: "The bit index and field length are each six bits
- // in length other bits of the field are ignored."
- APIndex = APIndex.zextOrTrunc(6);
- APLength = APLength.zextOrTrunc(6);
-
- // Attempt to constant fold.
- unsigned Index = APIndex.getZExtValue();
-
- // From AMD documentation: "a value of zero in the field length is
- // defined as length of 64".
- unsigned Length = APLength == 0 ? 64 : APLength.getZExtValue();
-
- // From AMD documentation: "If the sum of the bit index + length field
- // is greater than 64, the results are undefined".
- unsigned End = Index + Length;
-
- // Note that both field index and field length are 8-bit quantities.
- // Since variables 'Index' and 'Length' are unsigned values
- // obtained from zero-extending field index and field length
- // respectively, their sum should never wrap around.
- if (End > 64)
- return UndefValue::get(II.getType());
-
- // If we are inserting whole bytes, we can convert this to a shuffle.
- // Lowering can recognize INSERTQI shuffle masks.
- if ((Length % 8) == 0 && (Index % 8) == 0) {
- // Convert bit indices to byte indices.
- Length /= 8;
- Index /= 8;
-
- Type *IntTy8 = Type::getInt8Ty(II.getContext());
- auto *ShufTy = FixedVectorType::get(IntTy8, 16);
-
- SmallVector<int, 16> ShuffleMask;
- for (int i = 0; i != (int)Index; ++i)
- ShuffleMask.push_back(i);
- for (int i = 0; i != (int)Length; ++i)
- ShuffleMask.push_back(i + 16);
- for (int i = Index + Length; i != 8; ++i)
- ShuffleMask.push_back(i);
- for (int i = 8; i != 16; ++i)
- ShuffleMask.push_back(-1);
-
- Value *SV = Builder.CreateShuffleVector(Builder.CreateBitCast(Op0, ShufTy),
- Builder.CreateBitCast(Op1, ShufTy),
- ShuffleMask);
- return Builder.CreateBitCast(SV, II.getType());
- }
-
- // See if we're dealing with constant values.
- Constant *C0 = dyn_cast<Constant>(Op0);
- Constant *C1 = dyn_cast<Constant>(Op1);
- ConstantInt *CI00 =
- C0 ? dyn_cast_or_null<ConstantInt>(C0->getAggregateElement((unsigned)0))
- : nullptr;
- ConstantInt *CI10 =
- C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)0))
- : nullptr;
-
- // Constant Fold - insert bottom Length bits starting at the Index'th bit.
- if (CI00 && CI10) {
- APInt V00 = CI00->getValue();
- APInt V10 = CI10->getValue();
- APInt Mask = APInt::getLowBitsSet(64, Length).shl(Index);
- V00 = V00 & ~Mask;
- V10 = V10.zextOrTrunc(Length).zextOrTrunc(64).shl(Index);
- APInt Val = V00 | V10;
- Type *IntTy64 = Type::getInt64Ty(II.getContext());
- Constant *Args[] = {ConstantInt::get(IntTy64, Val.getZExtValue()),
- UndefValue::get(IntTy64)};
- return ConstantVector::get(Args);
- }
-
- // If we were an INSERTQ call, we'll save demanded elements if we convert to
- // INSERTQI.
- if (II.getIntrinsicID() == Intrinsic::x86_sse4a_insertq) {
- Type *IntTy8 = Type::getInt8Ty(II.getContext());
- Constant *CILength = ConstantInt::get(IntTy8, Length, false);
- Constant *CIIndex = ConstantInt::get(IntTy8, Index, false);
-
- Value *Args[] = {Op0, Op1, CILength, CIIndex};
- Module *M = II.getModule();
- Function *F = Intrinsic::getDeclaration(M, Intrinsic::x86_sse4a_insertqi);
- return Builder.CreateCall(F, Args);
- }
-
- return nullptr;
-}
-
-/// Attempt to convert pshufb* to shufflevector if the mask is constant.
-static Value *simplifyX86pshufb(const IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder) {
- Constant *V = dyn_cast<Constant>(II.getArgOperand(1));
- if (!V)
- return nullptr;
-
- auto *VecTy = cast<VectorType>(II.getType());
- unsigned NumElts = VecTy->getNumElements();
- assert((NumElts == 16 || NumElts == 32 || NumElts == 64) &&
- "Unexpected number of elements in shuffle mask!");
-
- // Construct a shuffle mask from constant integers or UNDEFs.
- int Indexes[64];
-
- // Each byte in the shuffle control mask forms an index to permute the
- // corresponding byte in the destination operand.
- for (unsigned I = 0; I < NumElts; ++I) {
- Constant *COp = V->getAggregateElement(I);
- if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp)))
- return nullptr;
-
- if (isa<UndefValue>(COp)) {
- Indexes[I] = -1;
- continue;
- }
-
- int8_t Index = cast<ConstantInt>(COp)->getValue().getZExtValue();
-
- // If the most significant bit (bit[7]) of each byte of the shuffle
- // control mask is set, then zero is written in the result byte.
- // The zero vector is in the right-hand side of the resulting
- // shufflevector.
-
- // The value of each index for the high 128-bit lane is the least
- // significant 4 bits of the respective shuffle control byte.
- Index = ((Index < 0) ? NumElts : Index & 0x0F) + (I & 0xF0);
- Indexes[I] = Index;
- }
-
- auto V1 = II.getArgOperand(0);
- auto V2 = Constant::getNullValue(VecTy);
- return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes, NumElts));
-}
-
-/// Attempt to convert vpermilvar* to shufflevector if the mask is constant.
-static Value *simplifyX86vpermilvar(const IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder) {
- Constant *V = dyn_cast<Constant>(II.getArgOperand(1));
- if (!V)
- return nullptr;
-
- auto *VecTy = cast<VectorType>(II.getType());
- unsigned NumElts = VecTy->getNumElements();
- bool IsPD = VecTy->getScalarType()->isDoubleTy();
- unsigned NumLaneElts = IsPD ? 2 : 4;
- assert(NumElts == 16 || NumElts == 8 || NumElts == 4 || NumElts == 2);
-
- // Construct a shuffle mask from constant integers or UNDEFs.
- int Indexes[16];
-
- // The intrinsics only read one or two bits, clear the rest.
- for (unsigned I = 0; I < NumElts; ++I) {
- Constant *COp = V->getAggregateElement(I);
- if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp)))
- return nullptr;
-
- if (isa<UndefValue>(COp)) {
- Indexes[I] = -1;
- continue;
- }
-
- APInt Index = cast<ConstantInt>(COp)->getValue();
- Index = Index.zextOrTrunc(32).getLoBits(2);
-
- // The PD variants uses bit 1 to select per-lane element index, so
- // shift down to convert to generic shuffle mask index.
- if (IsPD)
- Index.lshrInPlace(1);
-
- // The _256 variants are a bit trickier since the mask bits always index
- // into the corresponding 128 half. In order to convert to a generic
- // shuffle, we have to make that explicit.
- Index += APInt(32, (I / NumLaneElts) * NumLaneElts);
-
- Indexes[I] = Index.getZExtValue();
- }
-
- auto V1 = II.getArgOperand(0);
- auto V2 = UndefValue::get(V1->getType());
- return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes, NumElts));
-}
-
-/// Attempt to convert vpermd/vpermps to shufflevector if the mask is constant.
-static Value *simplifyX86vpermv(const IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder) {
- auto *V = dyn_cast<Constant>(II.getArgOperand(1));
- if (!V)
- return nullptr;
-
- auto *VecTy = cast<VectorType>(II.getType());
- unsigned Size = VecTy->getNumElements();
- assert((Size == 4 || Size == 8 || Size == 16 || Size == 32 || Size == 64) &&
- "Unexpected shuffle mask size");
-
- // Construct a shuffle mask from constant integers or UNDEFs.
- int Indexes[64];
-
- for (unsigned I = 0; I < Size; ++I) {
- Constant *COp = V->getAggregateElement(I);
- if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp)))
- return nullptr;
-
- if (isa<UndefValue>(COp)) {
- Indexes[I] = -1;
- continue;
- }
-
- uint32_t Index = cast<ConstantInt>(COp)->getZExtValue();
- Index &= Size - 1;
- Indexes[I] = Index;
- }
-
- auto V1 = II.getArgOperand(0);
- auto V2 = UndefValue::get(VecTy);
- return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes, Size));
-}
-
// TODO, Obvious Missing Transforms:
// * Narrow width by halfs excluding zero/undef lanes
-Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) {
+Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
Value *LoadPtr = II.getArgOperand(0);
const Align Alignment =
cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
@@ -1118,9 +289,8 @@ Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) {
// If we can unconditionally load from this address, replace with a
// load/select idiom. TODO: use DT for context sensitive query
- if (isDereferenceableAndAlignedPointer(LoadPtr, II.getType(), Alignment,
- II.getModule()->getDataLayout(), &II,
- nullptr)) {
+ if (isDereferenceablePointer(LoadPtr, II.getType(),
+ II.getModule()->getDataLayout(), &II, nullptr)) {
Value *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment,
"unmaskedload");
return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3));
@@ -1132,7 +302,7 @@ Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) {
// TODO, Obvious Missing Transforms:
// * Single constant active lane -> store
// * Narrow width by halfs excluding zero/undef lanes
-Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) {
+Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
if (!ConstMask)
return nullptr;
@@ -1148,11 +318,14 @@ Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) {
return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment);
}
+ if (isa<ScalableVectorType>(ConstMask->getType()))
+ return nullptr;
+
// Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
APInt UndefElts(DemandedElts.getBitWidth(), 0);
- if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0),
- DemandedElts, UndefElts))
+ if (Value *V =
+ SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, UndefElts))
return replaceOperand(II, 0, V);
return nullptr;
@@ -1165,7 +338,7 @@ Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) {
// * Narrow width by halfs excluding zero/undef lanes
// * Vector splat address w/known mask -> scalar load
// * Vector incrementing address -> vector masked load
-Instruction *InstCombiner::simplifyMaskedGather(IntrinsicInst &II) {
+Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
return nullptr;
}
@@ -1175,7 +348,7 @@ Instruction *InstCombiner::simplifyMaskedGather(IntrinsicInst &II) {
// * Narrow store width by halfs excluding zero/undef lanes
// * Vector splat address w/known mask -> scalar store
// * Vector incrementing address -> vector masked store
-Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) {
+Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
if (!ConstMask)
return nullptr;
@@ -1184,14 +357,17 @@ Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) {
if (ConstMask->isNullValue())
return eraseInstFromFunction(II);
+ if (isa<ScalableVectorType>(ConstMask->getType()))
+ return nullptr;
+
// Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
APInt UndefElts(DemandedElts.getBitWidth(), 0);
- if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0),
- DemandedElts, UndefElts))
+ if (Value *V =
+ SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, UndefElts))
return replaceOperand(II, 0, V);
- if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1),
- DemandedElts, UndefElts))
+ if (Value *V =
+ SimplifyDemandedVectorElts(II.getOperand(1), DemandedElts, UndefElts))
return replaceOperand(II, 1, V);
return nullptr;
@@ -1206,7 +382,7 @@ Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) {
/// This is legal because it preserves the most recent information about
/// the presence or absence of invariant.group.
static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
auto *Arg = II.getArgOperand(0);
auto *StrippedArg = Arg->stripPointerCasts();
auto *StrippedInvariantGroupsArg = Arg->stripPointerCastsAndInvariantGroups();
@@ -1231,7 +407,7 @@ static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II,
return cast<Instruction>(Result);
}
-static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) {
+static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {
assert((II.getIntrinsicID() == Intrinsic::cttz ||
II.getIntrinsicID() == Intrinsic::ctlz) &&
"Expected cttz or ctlz intrinsic");
@@ -1257,6 +433,9 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) {
SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor;
if (SPF == SPF_ABS || SPF == SPF_NABS)
return IC.replaceOperand(II, 0, X);
+
+ if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X))))
+ return IC.replaceOperand(II, 0, X);
}
KnownBits Known = IC.computeKnownBits(Op0, 0, &II);
@@ -1301,7 +480,7 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) {
return nullptr;
}
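// A scalar illustration of the cttz folds in the function above: negation
// only complements the bits above the lowest set bit, so x, -x and |x| all
// have the same number of trailing zeros. The helper is only for this sketch.
#include <cstdint>

static unsigned trailingZeros64(uint64_t V) {
  if (V == 0)
    return 64;
  unsigned N = 0;
  while ((V & 1) == 0) {
    V >>= 1;
    ++N;
  }
  return N;
}
// trailingZeros64(X) == trailingZeros64(0 - X) for every X (unsigned negation
// wraps), which is why abs/nabs can be stripped from the cttz operand.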
-static Instruction *foldCtpop(IntrinsicInst &II, InstCombiner &IC) {
+static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) {
assert(II.getIntrinsicID() == Intrinsic::ctpop &&
"Expected ctpop intrinsic");
Type *Ty = II.getType();
@@ -1356,107 +535,6 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombiner &IC) {
return nullptr;
}
-// TODO: If the x86 backend knew how to convert a bool vector mask back to an
-// XMM register mask efficiently, we could transform all x86 masked intrinsics
-// to LLVM masked intrinsics and remove the x86 masked intrinsic defs.
-static Instruction *simplifyX86MaskedLoad(IntrinsicInst &II, InstCombiner &IC) {
- Value *Ptr = II.getOperand(0);
- Value *Mask = II.getOperand(1);
- Constant *ZeroVec = Constant::getNullValue(II.getType());
-
- // Special case a zero mask since that's not a ConstantDataVector.
- // This masked load instruction creates a zero vector.
- if (isa<ConstantAggregateZero>(Mask))
- return IC.replaceInstUsesWith(II, ZeroVec);
-
- auto *ConstMask = dyn_cast<ConstantDataVector>(Mask);
- if (!ConstMask)
- return nullptr;
-
- // The mask is constant. Convert this x86 intrinsic to the LLVM intrinsic
- // to allow target-independent optimizations.
-
- // First, cast the x86 intrinsic scalar pointer to a vector pointer to match
- // the LLVM intrinsic definition for the pointer argument.
- unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
- PointerType *VecPtrTy = PointerType::get(II.getType(), AddrSpace);
- Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec");
-
- // Second, convert the x86 XMM integer vector mask to a vector of bools based
- // on each element's most significant bit (the sign bit).
- Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask);
-
- // The pass-through vector for an x86 masked load is a zero vector.
- CallInst *NewMaskedLoad =
- IC.Builder.CreateMaskedLoad(PtrCast, Align(1), BoolMask, ZeroVec);
- return IC.replaceInstUsesWith(II, NewMaskedLoad);
-}
-
-// TODO: If the x86 backend knew how to convert a bool vector mask back to an
-// XMM register mask efficiently, we could transform all x86 masked intrinsics
-// to LLVM masked intrinsics and remove the x86 masked intrinsic defs.
-static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) {
- Value *Ptr = II.getOperand(0);
- Value *Mask = II.getOperand(1);
- Value *Vec = II.getOperand(2);
-
- // Special case a zero mask since that's not a ConstantDataVector:
- // this masked store instruction does nothing.
- if (isa<ConstantAggregateZero>(Mask)) {
- IC.eraseInstFromFunction(II);
- return true;
- }
-
- // The SSE2 version is too weird (eg, unaligned but non-temporal) to do
- // anything else at this level.
- if (II.getIntrinsicID() == Intrinsic::x86_sse2_maskmov_dqu)
- return false;
-
- auto *ConstMask = dyn_cast<ConstantDataVector>(Mask);
- if (!ConstMask)
- return false;
-
- // The mask is constant. Convert this x86 intrinsic to the LLVM intrinsic
- // to allow target-independent optimizations.
-
- // First, cast the x86 intrinsic scalar pointer to a vector pointer to match
- // the LLVM intrinsic definition for the pointer argument.
- unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
- PointerType *VecPtrTy = PointerType::get(Vec->getType(), AddrSpace);
- Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec");
-
- // Second, convert the x86 XMM integer vector mask to a vector of bools based
- // on each element's most significant bit (the sign bit).
- Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask);
-
- IC.Builder.CreateMaskedStore(Vec, PtrCast, Align(1), BoolMask);
-
- // 'Replace uses' doesn't work for stores. Erase the original masked store.
- IC.eraseInstFromFunction(II);
- return true;
-}
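// A scalar sketch of the mask conversion both helpers above rely on; the
// function name is only for this sketch. An x86 integer vector mask enables a
// lane iff that element's most significant (sign) bit is set, which is what
// getNegativeIsTrueBoolVec captures per lane.
#include <array>
#include <cstdint>

static std::array<bool, 4> maskToLaneBools(const std::array<int32_t, 4> &Mask) {
  std::array<bool, 4> Lanes{};
  for (unsigned I = 0; I != 4; ++I)
    Lanes[I] = Mask[I] < 0;        // sign bit set <=> lane is active
  return Lanes;
}
// e.g. a mask of {-1, 0, INT32_MIN, 1} enables lanes 0 and 2 only.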
-
-// Constant fold llvm.amdgcn.fmed3 intrinsics for standard inputs.
-//
-// A single NaN input is folded to minnum, so we rely on that folding for
-// handling NaNs.
-static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1,
- const APFloat &Src2) {
- APFloat Max3 = maxnum(maxnum(Src0, Src1), Src2);
-
- APFloat::cmpResult Cmp0 = Max3.compare(Src0);
- assert(Cmp0 != APFloat::cmpUnordered && "nans handled separately");
- if (Cmp0 == APFloat::cmpEqual)
- return maxnum(Src1, Src2);
-
- APFloat::cmpResult Cmp1 = Max3.compare(Src1);
- assert(Cmp1 != APFloat::cmpUnordered && "nans handled separately");
- if (Cmp1 == APFloat::cmpEqual)
- return maxnum(Src0, Src2);
-
- return maxnum(Src0, Src1);
-}
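// A standalone restatement of the selection above for ordinary (non-NaN)
// inputs, using std::fmax; NaNs are handled separately by the intrinsic's
// minnum fold, as the comment above notes. The name is only for this sketch.
#include <cmath>

static double med3(double A, double B, double C) {
  double Max3 = std::fmax(std::fmax(A, B), C);
  if (Max3 == A)
    return std::fmax(B, C);   // A is the largest, so the median is max(B, C)
  if (Max3 == B)
    return std::fmax(A, C);
  return std::fmax(A, B);
}
// e.g. med3(5.0, 4.0, 6.0) == 5.0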
-
/// Convert a table lookup to shufflevector if the mask is constant.
/// This could benefit tbl1 if the mask is { 7,6,5,4,3,2,1,0 }, in
/// which case we could lower the shufflevector with rev64 instructions
@@ -1468,7 +546,7 @@ static Value *simplifyNeonTbl1(const IntrinsicInst &II,
if (!C)
return nullptr;
- auto *VecTy = cast<VectorType>(II.getType());
+ auto *VecTy = cast<FixedVectorType>(II.getType());
unsigned NumElts = VecTy->getNumElements();
// Only perform this transformation for <8 x i8> vector types.
@@ -1495,28 +573,6 @@ static Value *simplifyNeonTbl1(const IntrinsicInst &II,
return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes));
}
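// A scalar sketch of why a constant-mask table lookup is just a shuffle: each
// result byte is a fixed element of the source vector. Out-of-range handling
// here is only illustrative; the transform above restricts itself to in-range
// <8 x i8> constant masks.
#include <array>
#include <cstdint>

static std::array<uint8_t, 8> tbl1(const std::array<uint8_t, 8> &Table,
                                   const std::array<uint8_t, 8> &Index) {
  std::array<uint8_t, 8> Result{};
  for (unsigned I = 0; I != 8; ++I)
    Result[I] = Index[I] < 8 ? Table[Index[I]] : 0;
  return Result;
}
// With Index = {7,6,5,4,3,2,1,0} this is a byte reverse, which is the case the
// comment above singles out as lowerable with rev64.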
-/// Convert a vector load intrinsic into a simple llvm load instruction.
-/// This is beneficial when the underlying object being addressed comes
-/// from a constant, since we get constant-folding for free.
-static Value *simplifyNeonVld1(const IntrinsicInst &II,
- unsigned MemAlign,
- InstCombiner::BuilderTy &Builder) {
- auto *IntrAlign = dyn_cast<ConstantInt>(II.getArgOperand(1));
-
- if (!IntrAlign)
- return nullptr;
-
- unsigned Alignment = IntrAlign->getLimitedValue() < MemAlign ?
- MemAlign : IntrAlign->getLimitedValue();
-
- if (!isPowerOf2_32(Alignment))
- return nullptr;
-
- auto *BCastInst = Builder.CreateBitCast(II.getArgOperand(0),
- PointerType::get(II.getType(), 0));
- return Builder.CreateAlignedLoad(II.getType(), BCastInst, Align(Alignment));
-}
-
// Returns true iff the 2 intrinsics have the same operands, limiting the
// comparison to the first NumOperands.
static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E,
@@ -1538,9 +594,9 @@ static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E,
// call @llvm.foo.start(i1 0) ; This one won't be skipped: it will be removed
// call @llvm.foo.end(i1 0)
// call @llvm.foo.end(i1 0) ; &I
-static bool removeTriviallyEmptyRange(
- IntrinsicInst &EndI, InstCombiner &IC,
- std::function<bool(const IntrinsicInst &)> IsStart) {
+static bool
+removeTriviallyEmptyRange(IntrinsicInst &EndI, InstCombinerImpl &IC,
+ std::function<bool(const IntrinsicInst &)> IsStart) {
// We start from the end intrinsic and scan backwards, so that InstCombine
// has already processed (and potentially removed) all the instructions
// before the end intrinsic.
@@ -1566,256 +622,7 @@ static bool removeTriviallyEmptyRange(
return false;
}
-// Convert NVVM intrinsics to target-generic LLVM code where possible.
-static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
- // Each NVVM intrinsic we can simplify can be replaced with one of:
- //
- // * an LLVM intrinsic,
- // * an LLVM cast operation,
- // * an LLVM binary operation, or
- // * ad-hoc LLVM IR for the particular operation.
-
- // Some transformations are only valid when the module's
- // flush-denormals-to-zero (ftz) setting is true/false, whereas other
- // transformations are valid regardless of the module's ftz setting.
- enum FtzRequirementTy {
- FTZ_Any, // Any ftz setting is ok.
- FTZ_MustBeOn, // Transformation is valid only if ftz is on.
- FTZ_MustBeOff, // Transformation is valid only if ftz is off.
- };
- // Classes of NVVM intrinsics that can't be replaced one-to-one with a
- // target-generic intrinsic, cast op, or binary op but that we can nonetheless
- // simplify.
- enum SpecialCase {
- SPC_Reciprocal,
- };
-
- // SimplifyAction is a poor-man's variant (plus an additional flag) that
- // represents how to replace an NVVM intrinsic with target-generic LLVM IR.
- struct SimplifyAction {
- // Invariant: At most one of these Optionals has a value.
- Optional<Intrinsic::ID> IID;
- Optional<Instruction::CastOps> CastOp;
- Optional<Instruction::BinaryOps> BinaryOp;
- Optional<SpecialCase> Special;
-
- FtzRequirementTy FtzRequirement = FTZ_Any;
-
- SimplifyAction() = default;
-
- SimplifyAction(Intrinsic::ID IID, FtzRequirementTy FtzReq)
- : IID(IID), FtzRequirement(FtzReq) {}
-
- // Cast operations don't have anything to do with FTZ, so we skip that
- // argument.
- SimplifyAction(Instruction::CastOps CastOp) : CastOp(CastOp) {}
-
- SimplifyAction(Instruction::BinaryOps BinaryOp, FtzRequirementTy FtzReq)
- : BinaryOp(BinaryOp), FtzRequirement(FtzReq) {}
-
- SimplifyAction(SpecialCase Special, FtzRequirementTy FtzReq)
- : Special(Special), FtzRequirement(FtzReq) {}
- };
-
- // Try to generate a SimplifyAction describing how to replace our
- // IntrinsicInstr with target-generic LLVM IR.
- const SimplifyAction Action = [II]() -> SimplifyAction {
- switch (II->getIntrinsicID()) {
- // NVVM intrinsics that map directly to LLVM intrinsics.
- case Intrinsic::nvvm_ceil_d:
- return {Intrinsic::ceil, FTZ_Any};
- case Intrinsic::nvvm_ceil_f:
- return {Intrinsic::ceil, FTZ_MustBeOff};
- case Intrinsic::nvvm_ceil_ftz_f:
- return {Intrinsic::ceil, FTZ_MustBeOn};
- case Intrinsic::nvvm_fabs_d:
- return {Intrinsic::fabs, FTZ_Any};
- case Intrinsic::nvvm_fabs_f:
- return {Intrinsic::fabs, FTZ_MustBeOff};
- case Intrinsic::nvvm_fabs_ftz_f:
- return {Intrinsic::fabs, FTZ_MustBeOn};
- case Intrinsic::nvvm_floor_d:
- return {Intrinsic::floor, FTZ_Any};
- case Intrinsic::nvvm_floor_f:
- return {Intrinsic::floor, FTZ_MustBeOff};
- case Intrinsic::nvvm_floor_ftz_f:
- return {Intrinsic::floor, FTZ_MustBeOn};
- case Intrinsic::nvvm_fma_rn_d:
- return {Intrinsic::fma, FTZ_Any};
- case Intrinsic::nvvm_fma_rn_f:
- return {Intrinsic::fma, FTZ_MustBeOff};
- case Intrinsic::nvvm_fma_rn_ftz_f:
- return {Intrinsic::fma, FTZ_MustBeOn};
- case Intrinsic::nvvm_fmax_d:
- return {Intrinsic::maxnum, FTZ_Any};
- case Intrinsic::nvvm_fmax_f:
- return {Intrinsic::maxnum, FTZ_MustBeOff};
- case Intrinsic::nvvm_fmax_ftz_f:
- return {Intrinsic::maxnum, FTZ_MustBeOn};
- case Intrinsic::nvvm_fmin_d:
- return {Intrinsic::minnum, FTZ_Any};
- case Intrinsic::nvvm_fmin_f:
- return {Intrinsic::minnum, FTZ_MustBeOff};
- case Intrinsic::nvvm_fmin_ftz_f:
- return {Intrinsic::minnum, FTZ_MustBeOn};
- case Intrinsic::nvvm_round_d:
- return {Intrinsic::round, FTZ_Any};
- case Intrinsic::nvvm_round_f:
- return {Intrinsic::round, FTZ_MustBeOff};
- case Intrinsic::nvvm_round_ftz_f:
- return {Intrinsic::round, FTZ_MustBeOn};
- case Intrinsic::nvvm_sqrt_rn_d:
- return {Intrinsic::sqrt, FTZ_Any};
- case Intrinsic::nvvm_sqrt_f:
- // nvvm_sqrt_f is a special case. For most intrinsics, foo_ftz_f is the
- // ftz version, and foo_f is the non-ftz version. But nvvm_sqrt_f adopts
- // the ftz-ness of the surrounding code. sqrt_rn_f and sqrt_rn_ftz_f are
- // the versions with explicit ftz-ness.
- return {Intrinsic::sqrt, FTZ_Any};
- case Intrinsic::nvvm_sqrt_rn_f:
- return {Intrinsic::sqrt, FTZ_MustBeOff};
- case Intrinsic::nvvm_sqrt_rn_ftz_f:
- return {Intrinsic::sqrt, FTZ_MustBeOn};
- case Intrinsic::nvvm_trunc_d:
- return {Intrinsic::trunc, FTZ_Any};
- case Intrinsic::nvvm_trunc_f:
- return {Intrinsic::trunc, FTZ_MustBeOff};
- case Intrinsic::nvvm_trunc_ftz_f:
- return {Intrinsic::trunc, FTZ_MustBeOn};
-
- // NVVM intrinsics that map to LLVM cast operations.
- //
- // Note that llvm's target-generic conversion operators correspond to the rz
- // (round to zero) versions of the nvvm conversion intrinsics, even though
- // most everything else here uses the rn (round to nearest even) nvvm ops.
- case Intrinsic::nvvm_d2i_rz:
- case Intrinsic::nvvm_f2i_rz:
- case Intrinsic::nvvm_d2ll_rz:
- case Intrinsic::nvvm_f2ll_rz:
- return {Instruction::FPToSI};
- case Intrinsic::nvvm_d2ui_rz:
- case Intrinsic::nvvm_f2ui_rz:
- case Intrinsic::nvvm_d2ull_rz:
- case Intrinsic::nvvm_f2ull_rz:
- return {Instruction::FPToUI};
- case Intrinsic::nvvm_i2d_rz:
- case Intrinsic::nvvm_i2f_rz:
- case Intrinsic::nvvm_ll2d_rz:
- case Intrinsic::nvvm_ll2f_rz:
- return {Instruction::SIToFP};
- case Intrinsic::nvvm_ui2d_rz:
- case Intrinsic::nvvm_ui2f_rz:
- case Intrinsic::nvvm_ull2d_rz:
- case Intrinsic::nvvm_ull2f_rz:
- return {Instruction::UIToFP};
-
- // NVVM intrinsics that map to LLVM binary ops.
- case Intrinsic::nvvm_add_rn_d:
- return {Instruction::FAdd, FTZ_Any};
- case Intrinsic::nvvm_add_rn_f:
- return {Instruction::FAdd, FTZ_MustBeOff};
- case Intrinsic::nvvm_add_rn_ftz_f:
- return {Instruction::FAdd, FTZ_MustBeOn};
- case Intrinsic::nvvm_mul_rn_d:
- return {Instruction::FMul, FTZ_Any};
- case Intrinsic::nvvm_mul_rn_f:
- return {Instruction::FMul, FTZ_MustBeOff};
- case Intrinsic::nvvm_mul_rn_ftz_f:
- return {Instruction::FMul, FTZ_MustBeOn};
- case Intrinsic::nvvm_div_rn_d:
- return {Instruction::FDiv, FTZ_Any};
- case Intrinsic::nvvm_div_rn_f:
- return {Instruction::FDiv, FTZ_MustBeOff};
- case Intrinsic::nvvm_div_rn_ftz_f:
- return {Instruction::FDiv, FTZ_MustBeOn};
-
- // The remainder of cases are NVVM intrinsics that map to LLVM idioms, but
- // need special handling.
- //
- // We seem to be missing intrinsics for rcp.approx.{ftz.}f32, which is just
- // as well.
- case Intrinsic::nvvm_rcp_rn_d:
- return {SPC_Reciprocal, FTZ_Any};
- case Intrinsic::nvvm_rcp_rn_f:
- return {SPC_Reciprocal, FTZ_MustBeOff};
- case Intrinsic::nvvm_rcp_rn_ftz_f:
- return {SPC_Reciprocal, FTZ_MustBeOn};
-
- // We do not currently simplify intrinsics that give an approximate answer.
- // These include:
- //
- // - nvvm_cos_approx_{f,ftz_f}
- // - nvvm_ex2_approx_{d,f,ftz_f}
- // - nvvm_lg2_approx_{d,f,ftz_f}
- // - nvvm_sin_approx_{f,ftz_f}
- // - nvvm_sqrt_approx_{f,ftz_f}
- // - nvvm_rsqrt_approx_{d,f,ftz_f}
- // - nvvm_div_approx_{ftz_d,ftz_f,f}
- // - nvvm_rcp_approx_ftz_d
- //
- // Ideally we'd encode them as e.g. "fast call @llvm.cos", where "fast"
- // means that fastmath is enabled in the intrinsic. Unfortunately only
- // binary operators (currently) have a fastmath bit in SelectionDAG, so this
- // information gets lost and we can't select on it.
- //
- // TODO: div and rcp are lowered to a binary op, so these we could in theory
- // lower them to "fast fdiv".
-
- default:
- return {};
- }
- }();
-
- // If Action.FtzRequirementTy is not satisfied by the module's ftz state, we
- // can bail out now. (Notice that in the case that IID is not an NVVM
- // intrinsic, we don't have to look up any module metadata, as
- // FtzRequirementTy will be FTZ_Any.)
- if (Action.FtzRequirement != FTZ_Any) {
- StringRef Attr = II->getFunction()
- ->getFnAttribute("denormal-fp-math-f32")
- .getValueAsString();
- DenormalMode Mode = parseDenormalFPAttribute(Attr);
- bool FtzEnabled = Mode.Output != DenormalMode::IEEE;
-
- if (FtzEnabled != (Action.FtzRequirement == FTZ_MustBeOn))
- return nullptr;
- }
-
- // Simplify to target-generic intrinsic.
- if (Action.IID) {
- SmallVector<Value *, 4> Args(II->arg_operands());
- // All the target-generic intrinsics currently of interest to us have one
- // type argument, equal to that of the nvvm intrinsic's argument.
- Type *Tys[] = {II->getArgOperand(0)->getType()};
- return CallInst::Create(
- Intrinsic::getDeclaration(II->getModule(), *Action.IID, Tys), Args);
- }
-
- // Simplify to target-generic binary op.
- if (Action.BinaryOp)
- return BinaryOperator::Create(*Action.BinaryOp, II->getArgOperand(0),
- II->getArgOperand(1), II->getName());
-
- // Simplify to target-generic cast op.
- if (Action.CastOp)
- return CastInst::Create(*Action.CastOp, II->getArgOperand(0), II->getType(),
- II->getName());
-
- // All that's left are the special cases.
- if (!Action.Special)
- return nullptr;
-
- switch (*Action.Special) {
- case SPC_Reciprocal:
- // Simplify reciprocal.
- return BinaryOperator::Create(
- Instruction::FDiv, ConstantFP::get(II->getArgOperand(0)->getType(), 1),
- II->getArgOperand(0), II->getName());
- }
- llvm_unreachable("All SpecialCase enumerators should be handled in switch.");
-}
-
-Instruction *InstCombiner::visitVAEndInst(VAEndInst &I) {
+Instruction *InstCombinerImpl::visitVAEndInst(VAEndInst &I) {
removeTriviallyEmptyRange(I, *this, [](const IntrinsicInst &I) {
return I.getIntrinsicID() == Intrinsic::vastart ||
I.getIntrinsicID() == Intrinsic::vacopy;
@@ -1823,7 +630,7 @@ Instruction *InstCombiner::visitVAEndInst(VAEndInst &I) {
return nullptr;
}
-static Instruction *canonicalizeConstantArg0ToArg1(CallInst &Call) {
+static CallInst *canonicalizeConstantArg0ToArg1(CallInst &Call) {
assert(Call.getNumArgOperands() > 1 && "Need at least 2 args to swap");
Value *Arg0 = Call.getArgOperand(0), *Arg1 = Call.getArgOperand(1);
if (isa<Constant>(Arg0) && !isa<Constant>(Arg1)) {
@@ -1834,20 +641,44 @@ static Instruction *canonicalizeConstantArg0ToArg1(CallInst &Call) {
return nullptr;
}
-Instruction *InstCombiner::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) {
+/// Creates a result tuple for an overflow intrinsic \p II with a given
+/// \p Result and a constant \p Overflow value.
+static Instruction *createOverflowTuple(IntrinsicInst *II, Value *Result,
+ Constant *Overflow) {
+ Constant *V[] = {UndefValue::get(Result->getType()), Overflow};
+ StructType *ST = cast<StructType>(II->getType());
+ Constant *Struct = ConstantStruct::get(ST, V);
+ return InsertValueInst::Create(Struct, Result, 0);
+}
+
+Instruction *
+InstCombinerImpl::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) {
WithOverflowInst *WO = cast<WithOverflowInst>(II);
Value *OperationResult = nullptr;
Constant *OverflowResult = nullptr;
if (OptimizeOverflowCheck(WO->getBinaryOp(), WO->isSigned(), WO->getLHS(),
WO->getRHS(), *WO, OperationResult, OverflowResult))
- return CreateOverflowTuple(WO, OperationResult, OverflowResult);
+ return createOverflowTuple(WO, OperationResult, OverflowResult);
return nullptr;
}
+static Optional<bool> getKnownSign(Value *Op, Instruction *CxtI,
+ const DataLayout &DL, AssumptionCache *AC,
+ DominatorTree *DT) {
+ KnownBits Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT);
+ if (Known.isNonNegative())
+ return false;
+ if (Known.isNegative())
+ return true;
+
+ return isImpliedByDomCondition(
+ ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL);
+}
+
/// CallInst simplification. This mostly only handles folding of intrinsic
/// instructions. For normal calls, it allows visitCallBase to do the heavy
/// lifting.
-Instruction *InstCombiner::visitCallInst(CallInst &CI) {
+Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// Don't try to simplify calls without uses. It will not do anything useful,
// but will result in the following folds being skipped.
if (!CI.use_empty())
@@ -1953,31 +784,84 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
}
}
- if (Instruction *I = SimplifyNVVMIntrinsic(II, *this))
- return I;
-
- auto SimplifyDemandedVectorEltsLow = [this](Value *Op, unsigned Width,
- unsigned DemandedWidth) {
- APInt UndefElts(Width, 0);
- APInt DemandedElts = APInt::getLowBitsSet(Width, DemandedWidth);
- return SimplifyDemandedVectorElts(Op, DemandedElts, UndefElts);
- };
+ if (II->isCommutative()) {
+ if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI))
+ return NewCall;
+ }
Intrinsic::ID IID = II->getIntrinsicID();
switch (IID) {
- default: break;
case Intrinsic::objectsize:
if (Value *V = lowerObjectSizeCall(II, DL, &TLI, /*MustSucceed=*/false))
return replaceInstUsesWith(CI, V);
return nullptr;
+ case Intrinsic::abs: {
+ Value *IIOperand = II->getArgOperand(0);
+ bool IntMinIsPoison = cast<Constant>(II->getArgOperand(1))->isOneValue();
+
+ // abs(-x) -> abs(x)
+ // TODO: Copy nsw if it was present on the neg?
+ Value *X;
+ if (match(IIOperand, m_Neg(m_Value(X))))
+ return replaceOperand(*II, 0, X);
+ if (match(IIOperand, m_Select(m_Value(), m_Value(X), m_Neg(m_Deferred(X)))))
+ return replaceOperand(*II, 0, X);
+ if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X))))
+ return replaceOperand(*II, 0, X);
+
+ if (Optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) {
+ // abs(x) -> x if x >= 0
+ if (!*Sign)
+ return replaceInstUsesWith(*II, IIOperand);
+
+ // abs(x) -> -x if x < 0
+ if (IntMinIsPoison)
+ return BinaryOperator::CreateNSWNeg(IIOperand);
+ return BinaryOperator::CreateNeg(IIOperand);
+ }
+
+ // abs (sext X) --> zext (abs X*)
+ // Clear the IsIntMin (nsw) bit on the abs to allow narrowing.
+ if (match(IIOperand, m_OneUse(m_SExt(m_Value(X))))) {
+ Value *NarrowAbs =
+ Builder.CreateBinaryIntrinsic(Intrinsic::abs, X, Builder.getFalse());
+ return CastInst::Create(Instruction::ZExt, NarrowAbs, II->getType());
+ }
+
+ break;
+ }
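// A scalar sketch of the sext/zext part of the abs folds above; the helper
// names are only for this sketch. (abs(-x) == abs(x) needs no code: negation
// never changes the magnitude.)
#include <cstdint>
#include <cstdlib>

static uint32_t absSextThenWiden(int8_t X) {
  int32_t Wide = X;                                 // sext i8 -> i32
  return static_cast<uint32_t>(std::abs(Wide));     // abs in the wide type
}

static uint32_t absNarrowThenZext(int8_t X) {
  // Narrow abs with int-min-is-poison conceptually cleared: INT8_MIN wraps to
  // 0x80, which is its correct unsigned magnitude.
  uint8_t NarrowAbs = X < 0 ? static_cast<uint8_t>(-X) : static_cast<uint8_t>(X);
  return NarrowAbs;                                 // zext i8 -> i32
}
// The two agree for every int8_t, including INT8_MIN, which is why the
// narrowed abs above is created with its poison flag set to false.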
+ case Intrinsic::umax:
+ case Intrinsic::umin: {
+ Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1);
+ Value *X, *Y;
+ if (match(I0, m_ZExt(m_Value(X))) && match(I1, m_ZExt(m_Value(Y))) &&
+ (I0->hasOneUse() || I1->hasOneUse()) && X->getType() == Y->getType()) {
+ Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, Y);
+ return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType());
+ }
+ // If both operands of unsigned min/max are sign-extended, it is still ok
+ // to narrow the operation.
+ LLVM_FALLTHROUGH;
+ }
+ case Intrinsic::smax:
+ case Intrinsic::smin: {
+ Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1);
+ Value *X, *Y;
+ if (match(I0, m_SExt(m_Value(X))) && match(I1, m_SExt(m_Value(Y))) &&
+ (I0->hasOneUse() || I1->hasOneUse()) && X->getType() == Y->getType()) {
+ Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, Y);
+ return CastInst::Create(Instruction::SExt, NarrowMaxMin, II->getType());
+ }
+ break;
+ }
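// A scalar sketch of the narrowing above: when both operands come from the
// same extension, the min/max can be done in the narrow type and only the
// single result extended. Names are only for this sketch.
#include <algorithm>
#include <cstdint>

static uint32_t uminWideThenCompare(uint8_t A, uint8_t B) {
  return std::min<uint32_t>(A, B);                  // zext both, then umin
}
static uint32_t uminNarrowThenZext(uint8_t A, uint8_t B) {
  return static_cast<uint32_t>(std::min(A, B));     // umin, then zext the result
}
// The two agree for all inputs; the signed variants work the same way through
// sext, and because sign-extension preserves the unsigned order of same-width
// values, sign-extended operands are also fine for the unsigned min/max,
// which is what the fall-through above relies on.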
case Intrinsic::bswap: {
Value *IIOperand = II->getArgOperand(0);
Value *X = nullptr;
// bswap(trunc(bswap(x))) -> trunc(lshr(x, c))
if (match(IIOperand, m_Trunc(m_BSwap(m_Value(X))))) {
- unsigned C = X->getType()->getPrimitiveSizeInBits() -
- IIOperand->getType()->getPrimitiveSizeInBits();
+ unsigned C = X->getType()->getScalarSizeInBits() -
+ IIOperand->getType()->getScalarSizeInBits();
Value *CV = ConstantInt::get(X->getType(), C);
Value *V = Builder.CreateLShr(X, CV);
return new TruncInst(V, IIOperand->getType());
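// A scalar check of the fold above: byte-swapping, truncating and swapping
// again is the same as shifting the original value right by the width
// difference and truncating. The GCC/Clang builtins are used only to keep
// this sketch short.
#include <cstdint>

static uint16_t bswapTruncBswap(uint64_t X) {
  return __builtin_bswap16(static_cast<uint16_t>(__builtin_bswap64(X)));
}
static uint16_t shrThenTrunc(uint64_t X) {
  return static_cast<uint16_t>(X >> (64 - 16));     // C = 64 - 16 bits
}
// bswapTruncBswap(X) == shrThenTrunc(X) for every X.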
@@ -2002,15 +886,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::powi:
if (ConstantInt *Power = dyn_cast<ConstantInt>(II->getArgOperand(1))) {
// 0 and 1 are handled in instsimplify
-
// powi(x, -1) -> 1/x
if (Power->isMinusOne())
- return BinaryOperator::CreateFDiv(ConstantFP::get(CI.getType(), 1.0),
- II->getArgOperand(0));
+ return BinaryOperator::CreateFDivFMF(ConstantFP::get(CI.getType(), 1.0),
+ II->getArgOperand(0), II);
// powi(x, 2) -> x*x
if (Power->equalsInt(2))
- return BinaryOperator::CreateFMul(II->getArgOperand(0),
- II->getArgOperand(0));
+ return BinaryOperator::CreateFMulFMF(II->getArgOperand(0),
+ II->getArgOperand(0), II);
}
break;
@@ -2031,8 +914,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
Type *Ty = II->getType();
unsigned BitWidth = Ty->getScalarSizeInBits();
Constant *ShAmtC;
- if (match(II->getArgOperand(2), m_Constant(ShAmtC)) &&
- !isa<ConstantExpr>(ShAmtC) && !ShAmtC->containsConstantExpression()) {
+ if (match(II->getArgOperand(2), m_ImmConstant(ShAmtC)) &&
+ !ShAmtC->containsConstantExpression()) {
// Canonicalize a shift amount constant operand to modulo the bit-width.
Constant *WidthC = ConstantInt::get(Ty, BitWidth);
Constant *ModuloC = ConstantExpr::getURem(ShAmtC, WidthC);
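// A reference sketch of the funnel-shift semantics the canonicalization above
// relies on: the effective shift amount is the operand modulo the bit width,
// so a constant shift operand can be pre-reduced the same way. The function
// name is only for this sketch.
#include <cstdint>

static uint32_t fshl32(uint32_t Hi, uint32_t Lo, uint32_t ShAmt) {
  unsigned Sh = ShAmt % 32;                  // shift amount modulo the width
  if (Sh == 0)
    return Hi;                               // avoid the undefined shift by 32
  return (Hi << Sh) | (Lo >> (32 - Sh));
}
// fshl32(Hi, Lo, S) == fshl32(Hi, Lo, S % 32) for every S, which is the
// equivalence ConstantExpr::getURem applies to the constant operand above.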
@@ -2092,8 +975,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
}
case Intrinsic::uadd_with_overflow:
case Intrinsic::sadd_with_overflow: {
- if (Instruction *I = canonicalizeConstantArg0ToArg1(CI))
- return I;
if (Instruction *I = foldIntrinsicWithOverflowCommon(II))
return I;
@@ -2121,10 +1002,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::umul_with_overflow:
case Intrinsic::smul_with_overflow:
- if (Instruction *I = canonicalizeConstantArg0ToArg1(CI))
- return I;
- LLVM_FALLTHROUGH;
-
case Intrinsic::usub_with_overflow:
if (Instruction *I = foldIntrinsicWithOverflowCommon(II))
return I;
@@ -2155,9 +1032,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::uadd_sat:
case Intrinsic::sadd_sat:
- if (Instruction *I = canonicalizeConstantArg0ToArg1(CI))
- return I;
- LLVM_FALLTHROUGH;
case Intrinsic::usub_sat:
case Intrinsic::ssub_sat: {
SaturatingInst *SI = cast<SaturatingInst>(II);
@@ -2238,8 +1112,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::maxnum:
case Intrinsic::minimum:
case Intrinsic::maximum: {
- if (Instruction *I = canonicalizeConstantArg0ToArg1(CI))
- return I;
Value *Arg0 = II->getArgOperand(0);
Value *Arg1 = II->getArgOperand(1);
Value *X, *Y;
@@ -2348,9 +1220,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
LLVM_FALLTHROUGH;
}
case Intrinsic::fma: {
- if (Instruction *I = canonicalizeConstantArg0ToArg1(CI))
- return I;
-
// fma fneg(x), fneg(y), z -> fma x, y, z
Value *Src0 = II->getArgOperand(0);
Value *Src1 = II->getArgOperand(1);
@@ -2390,40 +1259,52 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
break;
}
case Intrinsic::copysign: {
- if (SignBitMustBeZero(II->getArgOperand(1), &TLI)) {
+ Value *Mag = II->getArgOperand(0), *Sign = II->getArgOperand(1);
+ if (SignBitMustBeZero(Sign, &TLI)) {
// If we know that the sign argument is positive, reduce to FABS:
- // copysign X, Pos --> fabs X
- Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs,
- II->getArgOperand(0), II);
+ // copysign Mag, +Sign --> fabs Mag
+ Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II);
return replaceInstUsesWith(*II, Fabs);
}
// TODO: There should be a ValueTracking sibling like SignBitMustBeOne.
const APFloat *C;
- if (match(II->getArgOperand(1), m_APFloat(C)) && C->isNegative()) {
+ if (match(Sign, m_APFloat(C)) && C->isNegative()) {
// If we know that the sign argument is negative, reduce to FNABS:
- // copysign X, Neg --> fneg (fabs X)
- Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs,
- II->getArgOperand(0), II);
+ // copysign Mag, -Sign --> fneg (fabs Mag)
+ Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II);
return replaceInstUsesWith(*II, Builder.CreateFNegFMF(Fabs, II));
}
// Propagate sign argument through nested calls:
- // copysign X, (copysign ?, SignArg) --> copysign X, SignArg
- Value *SignArg;
- if (match(II->getArgOperand(1),
- m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(SignArg))))
- return replaceOperand(*II, 1, SignArg);
+ // copysign Mag, (copysign ?, X) --> copysign Mag, X
+ Value *X;
+ if (match(Sign, m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(X))))
+ return replaceOperand(*II, 1, X);
+
+ // Peek through changes of magnitude's sign-bit. This call rewrites those:
+ // copysign (fabs X), Sign --> copysign X, Sign
+ // copysign (fneg X), Sign --> copysign X, Sign
+ if (match(Mag, m_FAbs(m_Value(X))) || match(Mag, m_FNeg(m_Value(X))))
+ return replaceOperand(*II, 0, X);
break;
}
case Intrinsic::fabs: {
- Value *Cond;
- Constant *LHS, *RHS;
+ Value *Cond, *TVal, *FVal;
if (match(II->getArgOperand(0),
- m_Select(m_Value(Cond), m_Constant(LHS), m_Constant(RHS)))) {
- CallInst *Call0 = Builder.CreateCall(II->getCalledFunction(), {LHS});
- CallInst *Call1 = Builder.CreateCall(II->getCalledFunction(), {RHS});
- return SelectInst::Create(Cond, Call0, Call1);
+ m_Select(m_Value(Cond), m_Value(TVal), m_Value(FVal)))) {
+ // fabs (select Cond, TrueC, FalseC) --> select Cond, AbsT, AbsF
+ if (isa<Constant>(TVal) && isa<Constant>(FVal)) {
+ CallInst *AbsT = Builder.CreateCall(II->getCalledFunction(), {TVal});
+ CallInst *AbsF = Builder.CreateCall(II->getCalledFunction(), {FVal});
+ return SelectInst::Create(Cond, AbsT, AbsF);
+ }
+ // fabs (select Cond, -FVal, FVal) --> fabs FVal
+ if (match(TVal, m_FNeg(m_Specific(FVal))))
+ return replaceOperand(*II, 0, FVal);
+ // fabs (select Cond, TVal, -TVal) --> fabs TVal
+ if (match(FVal, m_FNeg(m_Specific(TVal))))
+ return replaceOperand(*II, 0, TVal);
}
LLVM_FALLTHROUGH;
@@ -2465,932 +1346,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
}
break;
}
- case Intrinsic::ppc_altivec_lvx:
- case Intrinsic::ppc_altivec_lvxl:
- // Turn PPC lvx -> load if the pointer is known aligned.
- if (getOrEnforceKnownAlignment(II->getArgOperand(0), Align(16), DL, II, &AC,
- &DT) >= 16) {
- Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0),
- PointerType::getUnqual(II->getType()));
- return new LoadInst(II->getType(), Ptr, "", false, Align(16));
- }
- break;
- case Intrinsic::ppc_vsx_lxvw4x:
- case Intrinsic::ppc_vsx_lxvd2x: {
- // Turn PPC VSX loads into normal loads.
- Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0),
- PointerType::getUnqual(II->getType()));
- return new LoadInst(II->getType(), Ptr, Twine(""), false, Align(1));
- }
- case Intrinsic::ppc_altivec_stvx:
- case Intrinsic::ppc_altivec_stvxl:
- // Turn stvx -> store if the pointer is known aligned.
- if (getOrEnforceKnownAlignment(II->getArgOperand(1), Align(16), DL, II, &AC,
- &DT) >= 16) {
- Type *OpPtrTy =
- PointerType::getUnqual(II->getArgOperand(0)->getType());
- Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy);
- return new StoreInst(II->getArgOperand(0), Ptr, false, Align(16));
- }
- break;
- case Intrinsic::ppc_vsx_stxvw4x:
- case Intrinsic::ppc_vsx_stxvd2x: {
- // Turn PPC VSX stores into normal stores.
- Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType());
- Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy);
- return new StoreInst(II->getArgOperand(0), Ptr, false, Align(1));
- }
- case Intrinsic::ppc_qpx_qvlfs:
- // Turn PPC QPX qvlfs -> load if the pointer is known aligned.
- if (getOrEnforceKnownAlignment(II->getArgOperand(0), Align(16), DL, II, &AC,
- &DT) >= 16) {
- Type *VTy =
- VectorType::get(Builder.getFloatTy(),
- cast<VectorType>(II->getType())->getElementCount());
- Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0),
- PointerType::getUnqual(VTy));
- Value *Load = Builder.CreateLoad(VTy, Ptr);
- return new FPExtInst(Load, II->getType());
- }
- break;
- case Intrinsic::ppc_qpx_qvlfd:
- // Turn PPC QPX qvlfd -> load if the pointer is known aligned.
- if (getOrEnforceKnownAlignment(II->getArgOperand(0), Align(32), DL, II, &AC,
- &DT) >= 32) {
- Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0),
- PointerType::getUnqual(II->getType()));
- return new LoadInst(II->getType(), Ptr, "", false, Align(32));
- }
- break;
- case Intrinsic::ppc_qpx_qvstfs:
- // Turn PPC QPX qvstfs -> store if the pointer is known aligned.
- if (getOrEnforceKnownAlignment(II->getArgOperand(1), Align(16), DL, II, &AC,
- &DT) >= 16) {
- Type *VTy = VectorType::get(
- Builder.getFloatTy(),
- cast<VectorType>(II->getArgOperand(0)->getType())->getElementCount());
- Value *TOp = Builder.CreateFPTrunc(II->getArgOperand(0), VTy);
- Type *OpPtrTy = PointerType::getUnqual(VTy);
- Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy);
- return new StoreInst(TOp, Ptr, false, Align(16));
- }
- break;
- case Intrinsic::ppc_qpx_qvstfd:
- // Turn PPC QPX qvstfd -> store if the pointer is known aligned.
- if (getOrEnforceKnownAlignment(II->getArgOperand(1), Align(32), DL, II, &AC,
- &DT) >= 32) {
- Type *OpPtrTy =
- PointerType::getUnqual(II->getArgOperand(0)->getType());
- Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy);
- return new StoreInst(II->getArgOperand(0), Ptr, false, Align(32));
- }
- break;
-
- case Intrinsic::x86_bmi_bextr_32:
- case Intrinsic::x86_bmi_bextr_64:
- case Intrinsic::x86_tbm_bextri_u32:
- case Intrinsic::x86_tbm_bextri_u64:
- // If the RHS is a constant we can try some simplifications.
- if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(1))) {
- uint64_t Shift = C->getZExtValue();
- uint64_t Length = (Shift >> 8) & 0xff;
- Shift &= 0xff;
- unsigned BitWidth = II->getType()->getIntegerBitWidth();
- // If the length is 0 or the shift is out of range, replace with zero.
- if (Length == 0 || Shift >= BitWidth)
- return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0));
- // If the LHS is also a constant, we can completely constant fold this.
- if (auto *InC = dyn_cast<ConstantInt>(II->getArgOperand(0))) {
- uint64_t Result = InC->getZExtValue() >> Shift;
- if (Length > BitWidth)
- Length = BitWidth;
- Result &= maskTrailingOnes<uint64_t>(Length);
- return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result));
- }
- // TODO should we turn this into 'and' if shift is 0? Or 'shl' if we
- // are only masking bits that a shift already cleared?
- }
- break;
-
- case Intrinsic::x86_bmi_bzhi_32:
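// A scalar sketch of the BEXTR constant fold removed above, for the 64-bit
// form: control byte 0 is the starting bit, byte 1 the field length, and the
// result is zero when the length is zero or the start is out of range. The
// name is only for this sketch.
#include <cstdint>

static uint64_t bextr64(uint64_t Src, uint64_t Ctrl) {
  uint64_t Shift = Ctrl & 0xff;
  uint64_t Length = (Ctrl >> 8) & 0xff;
  if (Length == 0 || Shift >= 64)
    return 0;
  uint64_t Field = Src >> Shift;
  if (Length >= 64)
    return Field;                  // a full-width field keeps everything
  return Field & ((uint64_t(1) << Length) - 1);
}
// e.g. bextr64(0xabcd, 0x0804) == 0xbc (an 8-bit field starting at bit 4)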
- case Intrinsic::x86_bmi_bzhi_64:
- // If the RHS is a constant we can try some simplifications.
- if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(1))) {
- uint64_t Index = C->getZExtValue() & 0xff;
- unsigned BitWidth = II->getType()->getIntegerBitWidth();
- if (Index >= BitWidth)
- return replaceInstUsesWith(CI, II->getArgOperand(0));
- if (Index == 0)
- return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0));
- // If the LHS is also a constant, we can completely constant fold this.
- if (auto *InC = dyn_cast<ConstantInt>(II->getArgOperand(0))) {
- uint64_t Result = InC->getZExtValue();
- Result &= maskTrailingOnes<uint64_t>(Index);
- return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result));
- }
- // TODO should we convert this to an AND if the RHS is constant?
- }
- break;
- case Intrinsic::x86_bmi_pext_32:
- case Intrinsic::x86_bmi_pext_64:
- if (auto *MaskC = dyn_cast<ConstantInt>(II->getArgOperand(1))) {
- if (MaskC->isNullValue())
- return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0));
- if (MaskC->isAllOnesValue())
- return replaceInstUsesWith(CI, II->getArgOperand(0));
-
- if (auto *SrcC = dyn_cast<ConstantInt>(II->getArgOperand(0))) {
- uint64_t Src = SrcC->getZExtValue();
- uint64_t Mask = MaskC->getZExtValue();
- uint64_t Result = 0;
- uint64_t BitToSet = 1;
-
- while (Mask) {
- // Isolate lowest set bit.
- uint64_t BitToTest = Mask & -Mask;
- if (BitToTest & Src)
- Result |= BitToSet;
-
- BitToSet <<= 1;
- // Clear lowest set bit.
- Mask &= Mask - 1;
- }
-
- return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result));
- }
- }
- break;
- case Intrinsic::x86_bmi_pdep_32:
- case Intrinsic::x86_bmi_pdep_64:
- if (auto *MaskC = dyn_cast<ConstantInt>(II->getArgOperand(1))) {
- if (MaskC->isNullValue())
- return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0));
- if (MaskC->isAllOnesValue())
- return replaceInstUsesWith(CI, II->getArgOperand(0));
-
- if (auto *SrcC = dyn_cast<ConstantInt>(II->getArgOperand(0))) {
- uint64_t Src = SrcC->getZExtValue();
- uint64_t Mask = MaskC->getZExtValue();
- uint64_t Result = 0;
- uint64_t BitToTest = 1;
-
- while (Mask) {
- // Isolate lowest set bit.
- uint64_t BitToSet = Mask & -Mask;
- if (BitToTest & Src)
- Result |= BitToSet;
-
- BitToTest <<= 1;
- // Clear lowest set bit.
- Mask &= Mask - 1;
- }
-
- return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result));
- }
- }
- break;
-
- case Intrinsic::x86_sse_cvtss2si:
- case Intrinsic::x86_sse_cvtss2si64:
- case Intrinsic::x86_sse_cvttss2si:
- case Intrinsic::x86_sse_cvttss2si64:
- case Intrinsic::x86_sse2_cvtsd2si:
- case Intrinsic::x86_sse2_cvtsd2si64:
- case Intrinsic::x86_sse2_cvttsd2si:
- case Intrinsic::x86_sse2_cvttsd2si64:
- case Intrinsic::x86_avx512_vcvtss2si32:
- case Intrinsic::x86_avx512_vcvtss2si64:
- case Intrinsic::x86_avx512_vcvtss2usi32:
- case Intrinsic::x86_avx512_vcvtss2usi64:
- case Intrinsic::x86_avx512_vcvtsd2si32:
- case Intrinsic::x86_avx512_vcvtsd2si64:
- case Intrinsic::x86_avx512_vcvtsd2usi32:
- case Intrinsic::x86_avx512_vcvtsd2usi64:
- case Intrinsic::x86_avx512_cvttss2si:
- case Intrinsic::x86_avx512_cvttss2si64:
- case Intrinsic::x86_avx512_cvttss2usi:
- case Intrinsic::x86_avx512_cvttss2usi64:
- case Intrinsic::x86_avx512_cvttsd2si:
- case Intrinsic::x86_avx512_cvttsd2si64:
- case Intrinsic::x86_avx512_cvttsd2usi:
- case Intrinsic::x86_avx512_cvttsd2usi64: {
- // These intrinsics only demand the 0th element of their input vectors. If
- // we can simplify the input based on that, do so now.
- Value *Arg = II->getArgOperand(0);
- unsigned VWidth = cast<VectorType>(Arg->getType())->getNumElements();
- if (Value *V = SimplifyDemandedVectorEltsLow(Arg, VWidth, 1))
- return replaceOperand(*II, 0, V);
- break;
- }
-
- case Intrinsic::x86_mmx_pmovmskb:
- case Intrinsic::x86_sse_movmsk_ps:
- case Intrinsic::x86_sse2_movmsk_pd:
- case Intrinsic::x86_sse2_pmovmskb_128:
- case Intrinsic::x86_avx_movmsk_pd_256:
- case Intrinsic::x86_avx_movmsk_ps_256:
- case Intrinsic::x86_avx2_pmovmskb:
- if (Value *V = simplifyX86movmsk(*II, Builder))
- return replaceInstUsesWith(*II, V);
- break;
-
- case Intrinsic::x86_sse_comieq_ss:
- case Intrinsic::x86_sse_comige_ss:
- case Intrinsic::x86_sse_comigt_ss:
- case Intrinsic::x86_sse_comile_ss:
- case Intrinsic::x86_sse_comilt_ss:
- case Intrinsic::x86_sse_comineq_ss:
- case Intrinsic::x86_sse_ucomieq_ss:
- case Intrinsic::x86_sse_ucomige_ss:
- case Intrinsic::x86_sse_ucomigt_ss:
- case Intrinsic::x86_sse_ucomile_ss:
- case Intrinsic::x86_sse_ucomilt_ss:
- case Intrinsic::x86_sse_ucomineq_ss:
- case Intrinsic::x86_sse2_comieq_sd:
- case Intrinsic::x86_sse2_comige_sd:
- case Intrinsic::x86_sse2_comigt_sd:
- case Intrinsic::x86_sse2_comile_sd:
- case Intrinsic::x86_sse2_comilt_sd:
- case Intrinsic::x86_sse2_comineq_sd:
- case Intrinsic::x86_sse2_ucomieq_sd:
- case Intrinsic::x86_sse2_ucomige_sd:
- case Intrinsic::x86_sse2_ucomigt_sd:
- case Intrinsic::x86_sse2_ucomile_sd:
- case Intrinsic::x86_sse2_ucomilt_sd:
- case Intrinsic::x86_sse2_ucomineq_sd:
- case Intrinsic::x86_avx512_vcomi_ss:
- case Intrinsic::x86_avx512_vcomi_sd:
- case Intrinsic::x86_avx512_mask_cmp_ss:
- case Intrinsic::x86_avx512_mask_cmp_sd: {
- // These intrinsics only demand the 0th element of their input vectors. If
- // we can simplify the input based on that, do so now.
- bool MadeChange = false;
- Value *Arg0 = II->getArgOperand(0);
- Value *Arg1 = II->getArgOperand(1);
- unsigned VWidth = cast<VectorType>(Arg0->getType())->getNumElements();
- if (Value *V = SimplifyDemandedVectorEltsLow(Arg0, VWidth, 1)) {
- replaceOperand(*II, 0, V);
- MadeChange = true;
- }
- if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) {
- replaceOperand(*II, 1, V);
- MadeChange = true;
- }
- if (MadeChange)
- return II;
- break;
- }
- case Intrinsic::x86_avx512_cmp_pd_128:
- case Intrinsic::x86_avx512_cmp_pd_256:
- case Intrinsic::x86_avx512_cmp_pd_512:
- case Intrinsic::x86_avx512_cmp_ps_128:
- case Intrinsic::x86_avx512_cmp_ps_256:
- case Intrinsic::x86_avx512_cmp_ps_512: {
- // Folding cmp(sub(a,b),0) -> cmp(a,b) and cmp(0,sub(a,b)) -> cmp(b,a)
- Value *Arg0 = II->getArgOperand(0);
- Value *Arg1 = II->getArgOperand(1);
- bool Arg0IsZero = match(Arg0, m_PosZeroFP());
- if (Arg0IsZero)
- std::swap(Arg0, Arg1);
- Value *A, *B;
- // This fold requires only the NINF(not +/- inf) since inf minus
- // inf is nan.
- // NSZ(No Signed Zeros) is not needed because zeros of any sign are
- // equal for both compares.
- // NNAN is not needed because nans compare the same for both compares.
- // The compare intrinsic uses the above assumptions and therefore
- // doesn't require additional flags.
- if ((match(Arg0, m_OneUse(m_FSub(m_Value(A), m_Value(B)))) &&
- match(Arg1, m_PosZeroFP()) && isa<Instruction>(Arg0) &&
- cast<Instruction>(Arg0)->getFastMathFlags().noInfs())) {
- if (Arg0IsZero)
- std::swap(A, B);
- replaceOperand(*II, 0, A);
- replaceOperand(*II, 1, B);
- return II;
- }
- break;
- }
-
- case Intrinsic::x86_avx512_add_ps_512:
- case Intrinsic::x86_avx512_div_ps_512:
- case Intrinsic::x86_avx512_mul_ps_512:
- case Intrinsic::x86_avx512_sub_ps_512:
- case Intrinsic::x86_avx512_add_pd_512:
- case Intrinsic::x86_avx512_div_pd_512:
- case Intrinsic::x86_avx512_mul_pd_512:
- case Intrinsic::x86_avx512_sub_pd_512:
- // If the rounding mode is CUR_DIRECTION(4) we can turn these into regular
- // IR operations.
- if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(2))) {
- if (R->getValue() == 4) {
- Value *Arg0 = II->getArgOperand(0);
- Value *Arg1 = II->getArgOperand(1);
-
- Value *V;
- switch (IID) {
- default: llvm_unreachable("Case stmts out of sync!");
- case Intrinsic::x86_avx512_add_ps_512:
- case Intrinsic::x86_avx512_add_pd_512:
- V = Builder.CreateFAdd(Arg0, Arg1);
- break;
- case Intrinsic::x86_avx512_sub_ps_512:
- case Intrinsic::x86_avx512_sub_pd_512:
- V = Builder.CreateFSub(Arg0, Arg1);
- break;
- case Intrinsic::x86_avx512_mul_ps_512:
- case Intrinsic::x86_avx512_mul_pd_512:
- V = Builder.CreateFMul(Arg0, Arg1);
- break;
- case Intrinsic::x86_avx512_div_ps_512:
- case Intrinsic::x86_avx512_div_pd_512:
- V = Builder.CreateFDiv(Arg0, Arg1);
- break;
- }
-
- return replaceInstUsesWith(*II, V);
- }
- }
- break;
-
- case Intrinsic::x86_avx512_mask_add_ss_round:
- case Intrinsic::x86_avx512_mask_div_ss_round:
- case Intrinsic::x86_avx512_mask_mul_ss_round:
- case Intrinsic::x86_avx512_mask_sub_ss_round:
- case Intrinsic::x86_avx512_mask_add_sd_round:
- case Intrinsic::x86_avx512_mask_div_sd_round:
- case Intrinsic::x86_avx512_mask_mul_sd_round:
- case Intrinsic::x86_avx512_mask_sub_sd_round:
- // If the rounding mode is CUR_DIRECTION(4) we can turn these into regular
- // IR operations.
- if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(4))) {
- if (R->getValue() == 4) {
- // Extract the element as scalars.
- Value *Arg0 = II->getArgOperand(0);
- Value *Arg1 = II->getArgOperand(1);
- Value *LHS = Builder.CreateExtractElement(Arg0, (uint64_t)0);
- Value *RHS = Builder.CreateExtractElement(Arg1, (uint64_t)0);
-
- Value *V;
- switch (IID) {
- default: llvm_unreachable("Case stmts out of sync!");
- case Intrinsic::x86_avx512_mask_add_ss_round:
- case Intrinsic::x86_avx512_mask_add_sd_round:
- V = Builder.CreateFAdd(LHS, RHS);
- break;
- case Intrinsic::x86_avx512_mask_sub_ss_round:
- case Intrinsic::x86_avx512_mask_sub_sd_round:
- V = Builder.CreateFSub(LHS, RHS);
- break;
- case Intrinsic::x86_avx512_mask_mul_ss_round:
- case Intrinsic::x86_avx512_mask_mul_sd_round:
- V = Builder.CreateFMul(LHS, RHS);
- break;
- case Intrinsic::x86_avx512_mask_div_ss_round:
- case Intrinsic::x86_avx512_mask_div_sd_round:
- V = Builder.CreateFDiv(LHS, RHS);
- break;
- }
-
- // Handle the masking aspect of the intrinsic.
- Value *Mask = II->getArgOperand(3);
- auto *C = dyn_cast<ConstantInt>(Mask);
- // We don't need a select if we know the mask bit is a 1.
- if (!C || !C->getValue()[0]) {
- // Cast the mask to an i1 vector and then extract the lowest element.
- auto *MaskTy = FixedVectorType::get(
- Builder.getInt1Ty(),
- cast<IntegerType>(Mask->getType())->getBitWidth());
- Mask = Builder.CreateBitCast(Mask, MaskTy);
- Mask = Builder.CreateExtractElement(Mask, (uint64_t)0);
- // Extract the lowest element from the passthru operand.
- Value *Passthru = Builder.CreateExtractElement(II->getArgOperand(2),
- (uint64_t)0);
- V = Builder.CreateSelect(Mask, V, Passthru);
- }
-
- // Insert the result back into the original argument 0.
- V = Builder.CreateInsertElement(Arg0, V, (uint64_t)0);
-
- return replaceInstUsesWith(*II, V);
- }
- }
- break;
-
- // Constant fold ashr( <A x Bi>, Ci ).
- // Constant fold lshr( <A x Bi>, Ci ).
- // Constant fold shl( <A x Bi>, Ci ).
- case Intrinsic::x86_sse2_psrai_d:
- case Intrinsic::x86_sse2_psrai_w:
- case Intrinsic::x86_avx2_psrai_d:
- case Intrinsic::x86_avx2_psrai_w:
- case Intrinsic::x86_avx512_psrai_q_128:
- case Intrinsic::x86_avx512_psrai_q_256:
- case Intrinsic::x86_avx512_psrai_d_512:
- case Intrinsic::x86_avx512_psrai_q_512:
- case Intrinsic::x86_avx512_psrai_w_512:
- case Intrinsic::x86_sse2_psrli_d:
- case Intrinsic::x86_sse2_psrli_q:
- case Intrinsic::x86_sse2_psrli_w:
- case Intrinsic::x86_avx2_psrli_d:
- case Intrinsic::x86_avx2_psrli_q:
- case Intrinsic::x86_avx2_psrli_w:
- case Intrinsic::x86_avx512_psrli_d_512:
- case Intrinsic::x86_avx512_psrli_q_512:
- case Intrinsic::x86_avx512_psrli_w_512:
- case Intrinsic::x86_sse2_pslli_d:
- case Intrinsic::x86_sse2_pslli_q:
- case Intrinsic::x86_sse2_pslli_w:
- case Intrinsic::x86_avx2_pslli_d:
- case Intrinsic::x86_avx2_pslli_q:
- case Intrinsic::x86_avx2_pslli_w:
- case Intrinsic::x86_avx512_pslli_d_512:
- case Intrinsic::x86_avx512_pslli_q_512:
- case Intrinsic::x86_avx512_pslli_w_512:
- if (Value *V = simplifyX86immShift(*II, Builder))
- return replaceInstUsesWith(*II, V);
- break;
-
- case Intrinsic::x86_sse2_psra_d:
- case Intrinsic::x86_sse2_psra_w:
- case Intrinsic::x86_avx2_psra_d:
- case Intrinsic::x86_avx2_psra_w:
- case Intrinsic::x86_avx512_psra_q_128:
- case Intrinsic::x86_avx512_psra_q_256:
- case Intrinsic::x86_avx512_psra_d_512:
- case Intrinsic::x86_avx512_psra_q_512:
- case Intrinsic::x86_avx512_psra_w_512:
- case Intrinsic::x86_sse2_psrl_d:
- case Intrinsic::x86_sse2_psrl_q:
- case Intrinsic::x86_sse2_psrl_w:
- case Intrinsic::x86_avx2_psrl_d:
- case Intrinsic::x86_avx2_psrl_q:
- case Intrinsic::x86_avx2_psrl_w:
- case Intrinsic::x86_avx512_psrl_d_512:
- case Intrinsic::x86_avx512_psrl_q_512:
- case Intrinsic::x86_avx512_psrl_w_512:
- case Intrinsic::x86_sse2_psll_d:
- case Intrinsic::x86_sse2_psll_q:
- case Intrinsic::x86_sse2_psll_w:
- case Intrinsic::x86_avx2_psll_d:
- case Intrinsic::x86_avx2_psll_q:
- case Intrinsic::x86_avx2_psll_w:
- case Intrinsic::x86_avx512_psll_d_512:
- case Intrinsic::x86_avx512_psll_q_512:
- case Intrinsic::x86_avx512_psll_w_512: {
- if (Value *V = simplifyX86immShift(*II, Builder))
- return replaceInstUsesWith(*II, V);
-
- // SSE2/AVX2 uses only the first 64-bits of the 128-bit vector
- // operand to compute the shift amount.
- Value *Arg1 = II->getArgOperand(1);
- assert(Arg1->getType()->getPrimitiveSizeInBits() == 128 &&
- "Unexpected packed shift size");
- unsigned VWidth = cast<VectorType>(Arg1->getType())->getNumElements();
-
- if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, VWidth / 2))
- return replaceOperand(*II, 1, V);
- break;
- }
-
- case Intrinsic::x86_avx2_psllv_d:
- case Intrinsic::x86_avx2_psllv_d_256:
- case Intrinsic::x86_avx2_psllv_q:
- case Intrinsic::x86_avx2_psllv_q_256:
- case Intrinsic::x86_avx512_psllv_d_512:
- case Intrinsic::x86_avx512_psllv_q_512:
- case Intrinsic::x86_avx512_psllv_w_128:
- case Intrinsic::x86_avx512_psllv_w_256:
- case Intrinsic::x86_avx512_psllv_w_512:
- case Intrinsic::x86_avx2_psrav_d:
- case Intrinsic::x86_avx2_psrav_d_256:
- case Intrinsic::x86_avx512_psrav_q_128:
- case Intrinsic::x86_avx512_psrav_q_256:
- case Intrinsic::x86_avx512_psrav_d_512:
- case Intrinsic::x86_avx512_psrav_q_512:
- case Intrinsic::x86_avx512_psrav_w_128:
- case Intrinsic::x86_avx512_psrav_w_256:
- case Intrinsic::x86_avx512_psrav_w_512:
- case Intrinsic::x86_avx2_psrlv_d:
- case Intrinsic::x86_avx2_psrlv_d_256:
- case Intrinsic::x86_avx2_psrlv_q:
- case Intrinsic::x86_avx2_psrlv_q_256:
- case Intrinsic::x86_avx512_psrlv_d_512:
- case Intrinsic::x86_avx512_psrlv_q_512:
- case Intrinsic::x86_avx512_psrlv_w_128:
- case Intrinsic::x86_avx512_psrlv_w_256:
- case Intrinsic::x86_avx512_psrlv_w_512:
- if (Value *V = simplifyX86varShift(*II, Builder))
- return replaceInstUsesWith(*II, V);
- break;
-
- case Intrinsic::x86_sse2_packssdw_128:
- case Intrinsic::x86_sse2_packsswb_128:
- case Intrinsic::x86_avx2_packssdw:
- case Intrinsic::x86_avx2_packsswb:
- case Intrinsic::x86_avx512_packssdw_512:
- case Intrinsic::x86_avx512_packsswb_512:
- if (Value *V = simplifyX86pack(*II, Builder, true))
- return replaceInstUsesWith(*II, V);
- break;
-
- case Intrinsic::x86_sse2_packuswb_128:
- case Intrinsic::x86_sse41_packusdw:
- case Intrinsic::x86_avx2_packusdw:
- case Intrinsic::x86_avx2_packuswb:
- case Intrinsic::x86_avx512_packusdw_512:
- case Intrinsic::x86_avx512_packuswb_512:
- if (Value *V = simplifyX86pack(*II, Builder, false))
- return replaceInstUsesWith(*II, V);
- break;
-
- case Intrinsic::x86_pclmulqdq:
- case Intrinsic::x86_pclmulqdq_256:
- case Intrinsic::x86_pclmulqdq_512: {
- if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(2))) {
- unsigned Imm = C->getZExtValue();
-
- bool MadeChange = false;
- Value *Arg0 = II->getArgOperand(0);
- Value *Arg1 = II->getArgOperand(1);
- unsigned VWidth = cast<VectorType>(Arg0->getType())->getNumElements();
-
- APInt UndefElts1(VWidth, 0);
- APInt DemandedElts1 = APInt::getSplat(VWidth,
- APInt(2, (Imm & 0x01) ? 2 : 1));
- if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts1,
- UndefElts1)) {
- replaceOperand(*II, 0, V);
- MadeChange = true;
- }
-
- APInt UndefElts2(VWidth, 0);
- APInt DemandedElts2 = APInt::getSplat(VWidth,
- APInt(2, (Imm & 0x10) ? 2 : 1));
- if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts2,
- UndefElts2)) {
- replaceOperand(*II, 1, V);
- MadeChange = true;
- }
-
- // If either input elements are undef, the result is zero.
- if (DemandedElts1.isSubsetOf(UndefElts1) ||
- DemandedElts2.isSubsetOf(UndefElts2))
- return replaceInstUsesWith(*II,
- ConstantAggregateZero::get(II->getType()));
-
- if (MadeChange)
- return II;
- }
- break;
- }
-
- case Intrinsic::x86_sse41_insertps:
- if (Value *V = simplifyX86insertps(*II, Builder))
- return replaceInstUsesWith(*II, V);
- break;
-
- case Intrinsic::x86_sse4a_extrq: {
- Value *Op0 = II->getArgOperand(0);
- Value *Op1 = II->getArgOperand(1);
- unsigned VWidth0 = cast<VectorType>(Op0->getType())->getNumElements();
- unsigned VWidth1 = cast<VectorType>(Op1->getType())->getNumElements();
- assert(Op0->getType()->getPrimitiveSizeInBits() == 128 &&
- Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth0 == 2 &&
- VWidth1 == 16 && "Unexpected operand sizes");
-
- // See if we're dealing with constant values.
- Constant *C1 = dyn_cast<Constant>(Op1);
- ConstantInt *CILength =
- C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)0))
- : nullptr;
- ConstantInt *CIIndex =
- C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)1))
- : nullptr;
-
- // Attempt to simplify to a constant, shuffle vector or EXTRQI call.
- if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, Builder))
- return replaceInstUsesWith(*II, V);
-
- // EXTRQ only uses the lowest 64-bits of the first 128-bit vector
- // operands and the lowest 16-bits of the second.
- bool MadeChange = false;
- if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth0, 1)) {
- replaceOperand(*II, 0, V);
- MadeChange = true;
- }
- if (Value *V = SimplifyDemandedVectorEltsLow(Op1, VWidth1, 2)) {
- replaceOperand(*II, 1, V);
- MadeChange = true;
- }
- if (MadeChange)
- return II;
- break;
- }
-
- case Intrinsic::x86_sse4a_extrqi: {
- // EXTRQI: Extract Length bits starting from Index. Zero pad the remaining
- // bits of the lower 64-bits. The upper 64-bits are undefined.
- Value *Op0 = II->getArgOperand(0);
- unsigned VWidth = cast<VectorType>(Op0->getType())->getNumElements();
- assert(Op0->getType()->getPrimitiveSizeInBits() == 128 && VWidth == 2 &&
- "Unexpected operand size");
-
- // See if we're dealing with constant values.
- ConstantInt *CILength = dyn_cast<ConstantInt>(II->getArgOperand(1));
- ConstantInt *CIIndex = dyn_cast<ConstantInt>(II->getArgOperand(2));
-
- // Attempt to simplify to a constant or shuffle vector.
- if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, Builder))
- return replaceInstUsesWith(*II, V);
-
- // EXTRQI only uses the lowest 64-bits of the first 128-bit vector
- // operand.
- if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth, 1))
- return replaceOperand(*II, 0, V);
- break;
- }
-
- case Intrinsic::x86_sse4a_insertq: {
- Value *Op0 = II->getArgOperand(0);
- Value *Op1 = II->getArgOperand(1);
- unsigned VWidth = cast<VectorType>(Op0->getType())->getNumElements();
- assert(Op0->getType()->getPrimitiveSizeInBits() == 128 &&
- Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth == 2 &&
- cast<VectorType>(Op1->getType())->getNumElements() == 2 &&
- "Unexpected operand size");
-
- // See if we're dealing with constant values.
- Constant *C1 = dyn_cast<Constant>(Op1);
- ConstantInt *CI11 =
- C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)1))
- : nullptr;
-
- // Attempt to simplify to a constant, shuffle vector or INSERTQI call.
- if (CI11) {
- const APInt &V11 = CI11->getValue();
- APInt Len = V11.zextOrTrunc(6);
- APInt Idx = V11.lshr(8).zextOrTrunc(6);
- if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, Builder))
- return replaceInstUsesWith(*II, V);
- }
-
- // INSERTQ only uses the lowest 64-bits of the first 128-bit vector
- // operand.
- if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth, 1))
- return replaceOperand(*II, 0, V);
- break;
- }
-
- case Intrinsic::x86_sse4a_insertqi: {
- // INSERTQI: Extract lowest Length bits from lower half of second source and
- // insert over first source starting at Index bit. The upper 64-bits are
- // undefined.
- Value *Op0 = II->getArgOperand(0);
- Value *Op1 = II->getArgOperand(1);
- unsigned VWidth0 = cast<VectorType>(Op0->getType())->getNumElements();
- unsigned VWidth1 = cast<VectorType>(Op1->getType())->getNumElements();
- assert(Op0->getType()->getPrimitiveSizeInBits() == 128 &&
- Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth0 == 2 &&
- VWidth1 == 2 && "Unexpected operand sizes");
-
- // See if we're dealing with constant values.
- ConstantInt *CILength = dyn_cast<ConstantInt>(II->getArgOperand(2));
- ConstantInt *CIIndex = dyn_cast<ConstantInt>(II->getArgOperand(3));
-
- // Attempt to simplify to a constant or shuffle vector.
- if (CILength && CIIndex) {
- APInt Len = CILength->getValue().zextOrTrunc(6);
- APInt Idx = CIIndex->getValue().zextOrTrunc(6);
- if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, Builder))
- return replaceInstUsesWith(*II, V);
- }
-
- // INSERTQI only uses the lowest 64-bits of the first two 128-bit vector
- // operands.
- bool MadeChange = false;
- if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth0, 1)) {
- replaceOperand(*II, 0, V);
- MadeChange = true;
- }
- if (Value *V = SimplifyDemandedVectorEltsLow(Op1, VWidth1, 1)) {
- replaceOperand(*II, 1, V);
- MadeChange = true;
- }
- if (MadeChange)
- return II;
- break;
- }
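Similarly, a standalone sketch of the INSERTQI bit manipulation described above (illustrative names, not code from this patch):

    #include <cstdint>

    // The low Length bits of Src2 overwrite Src1 starting at bit Index; the
    // other bits of Src1's low 64 are preserved, and the undefined high 64
    // bits are ignored here. Length == 0 is treated as a full 64-bit field.
    uint64_t insertqi(uint64_t Src1, uint64_t Src2, unsigned Length,
                      unsigned Index) {
      Length &= 63;
      Index &= 63;
      uint64_t FieldMask = Length ? (((uint64_t)1 << Length) - 1) : ~(uint64_t)0;
      return (Src1 & ~(FieldMask << Index)) | ((Src2 & FieldMask) << Index);
    }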
-
- case Intrinsic::x86_sse41_pblendvb:
- case Intrinsic::x86_sse41_blendvps:
- case Intrinsic::x86_sse41_blendvpd:
- case Intrinsic::x86_avx_blendv_ps_256:
- case Intrinsic::x86_avx_blendv_pd_256:
- case Intrinsic::x86_avx2_pblendvb: {
- // fold (blend A, A, Mask) -> A
- Value *Op0 = II->getArgOperand(0);
- Value *Op1 = II->getArgOperand(1);
- Value *Mask = II->getArgOperand(2);
- if (Op0 == Op1)
- return replaceInstUsesWith(CI, Op0);
-
- // Zero Mask - select 1st argument.
- if (isa<ConstantAggregateZero>(Mask))
- return replaceInstUsesWith(CI, Op0);
-
- // Constant Mask - select 1st/2nd argument lane based on top bit of mask.
- if (auto *ConstantMask = dyn_cast<ConstantDataVector>(Mask)) {
- Constant *NewSelector = getNegativeIsTrueBoolVec(ConstantMask);
- return SelectInst::Create(NewSelector, Op1, Op0, "blendv");
- }
-
- // Convert to a vector select if we can bypass casts and find a boolean
- // vector condition value.
- Value *BoolVec;
- Mask = peekThroughBitcast(Mask);
- if (match(Mask, m_SExt(m_Value(BoolVec))) &&
- BoolVec->getType()->isVectorTy() &&
- BoolVec->getType()->getScalarSizeInBits() == 1) {
- assert(Mask->getType()->getPrimitiveSizeInBits() ==
- II->getType()->getPrimitiveSizeInBits() &&
- "Not expecting mask and operands with different sizes");
-
- unsigned NumMaskElts =
- cast<VectorType>(Mask->getType())->getNumElements();
- unsigned NumOperandElts =
- cast<VectorType>(II->getType())->getNumElements();
- if (NumMaskElts == NumOperandElts)
- return SelectInst::Create(BoolVec, Op1, Op0);
-
-      // If the mask has fewer elements than the operands, each mask bit maps to
- // multiple elements of the operands. Bitcast back and forth.
- if (NumMaskElts < NumOperandElts) {
- Value *CastOp0 = Builder.CreateBitCast(Op0, Mask->getType());
- Value *CastOp1 = Builder.CreateBitCast(Op1, Mask->getType());
- Value *Sel = Builder.CreateSelect(BoolVec, CastOp1, CastOp0);
- return new BitCastInst(Sel, II->getType());
- }
- }
-
- break;
- }
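The per-lane behaviour that makes the select rewrite above valid can be sketched in standalone C++ (a sketch of the documented blendv semantics, not code from this patch):

    #include <array>
    #include <cstdint>

    // Each result lane takes Op1's lane when the corresponding mask lane has
    // its top (sign) bit set, and Op0's lane otherwise -- exactly what an IR
    // select on the sign-bit vector computes.
    std::array<float, 4> blendvps(const std::array<float, 4> &Op0,
                                  const std::array<float, 4> &Op1,
                                  const std::array<int32_t, 4> &Mask) {
      std::array<float, 4> R{};
      for (int i = 0; i < 4; ++i)
        R[i] = (Mask[i] < 0) ? Op1[i] : Op0[i];
      return R;
    }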
-
- case Intrinsic::x86_ssse3_pshuf_b_128:
- case Intrinsic::x86_avx2_pshuf_b:
- case Intrinsic::x86_avx512_pshuf_b_512:
- if (Value *V = simplifyX86pshufb(*II, Builder))
- return replaceInstUsesWith(*II, V);
- break;
-
- case Intrinsic::x86_avx_vpermilvar_ps:
- case Intrinsic::x86_avx_vpermilvar_ps_256:
- case Intrinsic::x86_avx512_vpermilvar_ps_512:
- case Intrinsic::x86_avx_vpermilvar_pd:
- case Intrinsic::x86_avx_vpermilvar_pd_256:
- case Intrinsic::x86_avx512_vpermilvar_pd_512:
- if (Value *V = simplifyX86vpermilvar(*II, Builder))
- return replaceInstUsesWith(*II, V);
- break;
-
- case Intrinsic::x86_avx2_permd:
- case Intrinsic::x86_avx2_permps:
- case Intrinsic::x86_avx512_permvar_df_256:
- case Intrinsic::x86_avx512_permvar_df_512:
- case Intrinsic::x86_avx512_permvar_di_256:
- case Intrinsic::x86_avx512_permvar_di_512:
- case Intrinsic::x86_avx512_permvar_hi_128:
- case Intrinsic::x86_avx512_permvar_hi_256:
- case Intrinsic::x86_avx512_permvar_hi_512:
- case Intrinsic::x86_avx512_permvar_qi_128:
- case Intrinsic::x86_avx512_permvar_qi_256:
- case Intrinsic::x86_avx512_permvar_qi_512:
- case Intrinsic::x86_avx512_permvar_sf_512:
- case Intrinsic::x86_avx512_permvar_si_512:
- if (Value *V = simplifyX86vpermv(*II, Builder))
- return replaceInstUsesWith(*II, V);
- break;
-
- case Intrinsic::x86_avx_maskload_ps:
- case Intrinsic::x86_avx_maskload_pd:
- case Intrinsic::x86_avx_maskload_ps_256:
- case Intrinsic::x86_avx_maskload_pd_256:
- case Intrinsic::x86_avx2_maskload_d:
- case Intrinsic::x86_avx2_maskload_q:
- case Intrinsic::x86_avx2_maskload_d_256:
- case Intrinsic::x86_avx2_maskload_q_256:
- if (Instruction *I = simplifyX86MaskedLoad(*II, *this))
- return I;
- break;
-
- case Intrinsic::x86_sse2_maskmov_dqu:
- case Intrinsic::x86_avx_maskstore_ps:
- case Intrinsic::x86_avx_maskstore_pd:
- case Intrinsic::x86_avx_maskstore_ps_256:
- case Intrinsic::x86_avx_maskstore_pd_256:
- case Intrinsic::x86_avx2_maskstore_d:
- case Intrinsic::x86_avx2_maskstore_q:
- case Intrinsic::x86_avx2_maskstore_d_256:
- case Intrinsic::x86_avx2_maskstore_q_256:
- if (simplifyX86MaskedStore(*II, *this))
- return nullptr;
- break;
-
- case Intrinsic::x86_addcarry_32:
- case Intrinsic::x86_addcarry_64:
- if (Value *V = simplifyX86addcarry(*II, Builder))
- return replaceInstUsesWith(*II, V);
- break;
-
- case Intrinsic::ppc_altivec_vperm:
- // Turn vperm(V1,V2,mask) -> shuffle(V1,V2,mask) if mask is a constant.
- // Note that ppc_altivec_vperm has a big-endian bias, so when creating
-    // a vector shuffle for little endian, we must undo the transformation
- // performed on vec_perm in altivec.h. That is, we must complement
- // the permutation mask with respect to 31 and reverse the order of
- // V1 and V2.
- if (Constant *Mask = dyn_cast<Constant>(II->getArgOperand(2))) {
- assert(cast<VectorType>(Mask->getType())->getNumElements() == 16 &&
- "Bad type for intrinsic!");
-
- // Check that all of the elements are integer constants or undefs.
- bool AllEltsOk = true;
- for (unsigned i = 0; i != 16; ++i) {
- Constant *Elt = Mask->getAggregateElement(i);
- if (!Elt || !(isa<ConstantInt>(Elt) || isa<UndefValue>(Elt))) {
- AllEltsOk = false;
- break;
- }
- }
-
- if (AllEltsOk) {
- // Cast the input vectors to byte vectors.
- Value *Op0 = Builder.CreateBitCast(II->getArgOperand(0),
- Mask->getType());
- Value *Op1 = Builder.CreateBitCast(II->getArgOperand(1),
- Mask->getType());
- Value *Result = UndefValue::get(Op0->getType());
-
- // Only extract each element once.
- Value *ExtractedElts[32];
- memset(ExtractedElts, 0, sizeof(ExtractedElts));
-
- for (unsigned i = 0; i != 16; ++i) {
- if (isa<UndefValue>(Mask->getAggregateElement(i)))
- continue;
- unsigned Idx =
- cast<ConstantInt>(Mask->getAggregateElement(i))->getZExtValue();
- Idx &= 31; // Match the hardware behavior.
- if (DL.isLittleEndian())
- Idx = 31 - Idx;
-
- if (!ExtractedElts[Idx]) {
- Value *Op0ToUse = (DL.isLittleEndian()) ? Op1 : Op0;
- Value *Op1ToUse = (DL.isLittleEndian()) ? Op0 : Op1;
- ExtractedElts[Idx] =
- Builder.CreateExtractElement(Idx < 16 ? Op0ToUse : Op1ToUse,
- Builder.getInt32(Idx&15));
- }
-
- // Insert this value into the result vector.
- Result = Builder.CreateInsertElement(Result, ExtractedElts[Idx],
- Builder.getInt32(i));
- }
- return CastInst::Create(Instruction::BitCast, Result, CI.getType());
- }
- }
- break;
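A standalone sketch of the vperm byte shuffle and the little-endian index adjustment described in the comment above (illustrative only, not part of the patch):

    #include <array>
    #include <cstdint>

    // Each result byte is chosen by the low 5 bits of the corresponding mask
    // byte from the 32-byte concatenation of the two inputs. For little
    // endian the index is complemented with respect to 31 and the inputs are
    // consulted in swapped order, mirroring the code above.
    std::array<uint8_t, 16> vperm(const std::array<uint8_t, 16> &V1,
                                  const std::array<uint8_t, 16> &V2,
                                  const std::array<uint8_t, 16> &Mask,
                                  bool LittleEndian) {
      std::array<uint8_t, 16> R{};
      for (int i = 0; i != 16; ++i) {
        unsigned Idx = Mask[i] & 31;
        if (LittleEndian)
          Idx = 31 - Idx;
        const auto &First = LittleEndian ? V2 : V1;
        const auto &Second = LittleEndian ? V1 : V2;
        R[i] = Idx < 16 ? First[Idx] : Second[Idx & 15];
      }
      return R;
    }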
-
- case Intrinsic::arm_neon_vld1: {
- Align MemAlign = getKnownAlignment(II->getArgOperand(0), DL, II, &AC, &DT);
- if (Value *V = simplifyNeonVld1(*II, MemAlign.value(), Builder))
- return replaceInstUsesWith(*II, V);
- break;
- }
-
- case Intrinsic::arm_neon_vld2:
- case Intrinsic::arm_neon_vld3:
- case Intrinsic::arm_neon_vld4:
- case Intrinsic::arm_neon_vld2lane:
- case Intrinsic::arm_neon_vld3lane:
- case Intrinsic::arm_neon_vld4lane:
- case Intrinsic::arm_neon_vst1:
- case Intrinsic::arm_neon_vst2:
- case Intrinsic::arm_neon_vst3:
- case Intrinsic::arm_neon_vst4:
- case Intrinsic::arm_neon_vst2lane:
- case Intrinsic::arm_neon_vst3lane:
- case Intrinsic::arm_neon_vst4lane: {
- Align MemAlign = getKnownAlignment(II->getArgOperand(0), DL, II, &AC, &DT);
- unsigned AlignArg = II->getNumArgOperands() - 1;
- Value *AlignArgOp = II->getArgOperand(AlignArg);
- MaybeAlign Align = cast<ConstantInt>(AlignArgOp)->getMaybeAlignValue();
- if (Align && *Align < MemAlign)
- return replaceOperand(*II, AlignArg,
- ConstantInt::get(Type::getInt32Ty(II->getContext()),
- MemAlign.value(), false));
- break;
- }
case Intrinsic::arm_neon_vtbl1:
case Intrinsic::aarch64_neon_tbl1:
@@ -3453,690 +1408,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
}
break;
}
- case Intrinsic::arm_mve_pred_i2v: {
- Value *Arg = II->getArgOperand(0);
- Value *ArgArg;
- if (match(Arg, m_Intrinsic<Intrinsic::arm_mve_pred_v2i>(m_Value(ArgArg))) &&
- II->getType() == ArgArg->getType())
- return replaceInstUsesWith(*II, ArgArg);
- Constant *XorMask;
- if (match(Arg,
- m_Xor(m_Intrinsic<Intrinsic::arm_mve_pred_v2i>(m_Value(ArgArg)),
- m_Constant(XorMask))) &&
- II->getType() == ArgArg->getType()) {
- if (auto *CI = dyn_cast<ConstantInt>(XorMask)) {
- if (CI->getValue().trunc(16).isAllOnesValue()) {
- auto TrueVector = Builder.CreateVectorSplat(
- cast<VectorType>(II->getType())->getNumElements(),
- Builder.getTrue());
- return BinaryOperator::Create(Instruction::Xor, ArgArg, TrueVector);
- }
- }
- }
- KnownBits ScalarKnown(32);
- if (SimplifyDemandedBits(II, 0, APInt::getLowBitsSet(32, 16),
- ScalarKnown, 0))
- return II;
- break;
- }
- case Intrinsic::arm_mve_pred_v2i: {
- Value *Arg = II->getArgOperand(0);
- Value *ArgArg;
- if (match(Arg, m_Intrinsic<Intrinsic::arm_mve_pred_i2v>(m_Value(ArgArg))))
- return replaceInstUsesWith(*II, ArgArg);
- if (!II->getMetadata(LLVMContext::MD_range)) {
- Type *IntTy32 = Type::getInt32Ty(II->getContext());
- Metadata *M[] = {
- ConstantAsMetadata::get(ConstantInt::get(IntTy32, 0)),
- ConstantAsMetadata::get(ConstantInt::get(IntTy32, 0xFFFF))
- };
- II->setMetadata(LLVMContext::MD_range, MDNode::get(II->getContext(), M));
- return II;
- }
- break;
- }
- case Intrinsic::arm_mve_vadc:
- case Intrinsic::arm_mve_vadc_predicated: {
- unsigned CarryOp =
- (II->getIntrinsicID() == Intrinsic::arm_mve_vadc_predicated) ? 3 : 2;
- assert(II->getArgOperand(CarryOp)->getType()->getScalarSizeInBits() == 32 &&
- "Bad type for intrinsic!");
-
- KnownBits CarryKnown(32);
- if (SimplifyDemandedBits(II, CarryOp, APInt::getOneBitSet(32, 29),
- CarryKnown))
- return II;
- break;
- }
- case Intrinsic::amdgcn_rcp: {
- Value *Src = II->getArgOperand(0);
-
- // TODO: Move to ConstantFolding/InstSimplify?
- if (isa<UndefValue>(Src)) {
- Type *Ty = II->getType();
- auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics()));
- return replaceInstUsesWith(CI, QNaN);
- }
-
- if (II->isStrictFP())
- break;
-
- if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) {
- const APFloat &ArgVal = C->getValueAPF();
- APFloat Val(ArgVal.getSemantics(), 1);
- Val.divide(ArgVal, APFloat::rmNearestTiesToEven);
-
- // This is more precise than the instruction may give.
- //
- // TODO: The instruction always flushes denormal results (except for f16),
- // should this also?
- return replaceInstUsesWith(CI, ConstantFP::get(II->getContext(), Val));
- }
-
- break;
- }
- case Intrinsic::amdgcn_rsq: {
- Value *Src = II->getArgOperand(0);
-
- // TODO: Move to ConstantFolding/InstSimplify?
- if (isa<UndefValue>(Src)) {
- Type *Ty = II->getType();
- auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics()));
- return replaceInstUsesWith(CI, QNaN);
- }
-
- break;
- }
- case Intrinsic::amdgcn_frexp_mant:
- case Intrinsic::amdgcn_frexp_exp: {
- Value *Src = II->getArgOperand(0);
- if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) {
- int Exp;
- APFloat Significand = frexp(C->getValueAPF(), Exp,
- APFloat::rmNearestTiesToEven);
-
- if (IID == Intrinsic::amdgcn_frexp_mant) {
- return replaceInstUsesWith(CI, ConstantFP::get(II->getContext(),
- Significand));
- }
-
- // Match instruction special case behavior.
- if (Exp == APFloat::IEK_NaN || Exp == APFloat::IEK_Inf)
- Exp = 0;
-
- return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Exp));
- }
-
- if (isa<UndefValue>(Src))
- return replaceInstUsesWith(CI, UndefValue::get(II->getType()));
-
- break;
- }
- case Intrinsic::amdgcn_class: {
- enum {
- S_NAN = 1 << 0, // Signaling NaN
- Q_NAN = 1 << 1, // Quiet NaN
- N_INFINITY = 1 << 2, // Negative infinity
- N_NORMAL = 1 << 3, // Negative normal
- N_SUBNORMAL = 1 << 4, // Negative subnormal
- N_ZERO = 1 << 5, // Negative zero
- P_ZERO = 1 << 6, // Positive zero
- P_SUBNORMAL = 1 << 7, // Positive subnormal
- P_NORMAL = 1 << 8, // Positive normal
- P_INFINITY = 1 << 9 // Positive infinity
- };
-
- const uint32_t FullMask = S_NAN | Q_NAN | N_INFINITY | N_NORMAL |
- N_SUBNORMAL | N_ZERO | P_ZERO | P_SUBNORMAL | P_NORMAL | P_INFINITY;
-
- Value *Src0 = II->getArgOperand(0);
- Value *Src1 = II->getArgOperand(1);
- const ConstantInt *CMask = dyn_cast<ConstantInt>(Src1);
- if (!CMask) {
- if (isa<UndefValue>(Src0))
- return replaceInstUsesWith(*II, UndefValue::get(II->getType()));
-
- if (isa<UndefValue>(Src1))
- return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), false));
- break;
- }
-
- uint32_t Mask = CMask->getZExtValue();
-
- // If all tests are made, it doesn't matter what the value is.
- if ((Mask & FullMask) == FullMask)
- return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), true));
-
- if ((Mask & FullMask) == 0)
- return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), false));
-
- if (Mask == (S_NAN | Q_NAN)) {
- // Equivalent of isnan. Replace with standard fcmp.
- Value *FCmp = Builder.CreateFCmpUNO(Src0, Src0);
- FCmp->takeName(II);
- return replaceInstUsesWith(*II, FCmp);
- }
-
- if (Mask == (N_ZERO | P_ZERO)) {
- // Equivalent of == 0.
- Value *FCmp = Builder.CreateFCmpOEQ(
- Src0, ConstantFP::get(Src0->getType(), 0.0));
-
- FCmp->takeName(II);
- return replaceInstUsesWith(*II, FCmp);
- }
-
- // fp_class (nnan x), qnan|snan|other -> fp_class (nnan x), other
- if (((Mask & S_NAN) || (Mask & Q_NAN)) && isKnownNeverNaN(Src0, &TLI))
- return replaceOperand(*II, 1, ConstantInt::get(Src1->getType(),
- Mask & ~(S_NAN | Q_NAN)));
-
- const ConstantFP *CVal = dyn_cast<ConstantFP>(Src0);
- if (!CVal) {
- if (isa<UndefValue>(Src0))
- return replaceInstUsesWith(*II, UndefValue::get(II->getType()));
-
- // Clamp mask to used bits
- if ((Mask & FullMask) != Mask) {
- CallInst *NewCall = Builder.CreateCall(II->getCalledFunction(),
- { Src0, ConstantInt::get(Src1->getType(), Mask & FullMask) }
- );
-
- NewCall->takeName(II);
- return replaceInstUsesWith(*II, NewCall);
- }
-
- break;
- }
-
- const APFloat &Val = CVal->getValueAPF();
-
- bool Result =
- ((Mask & S_NAN) && Val.isNaN() && Val.isSignaling()) ||
- ((Mask & Q_NAN) && Val.isNaN() && !Val.isSignaling()) ||
- ((Mask & N_INFINITY) && Val.isInfinity() && Val.isNegative()) ||
- ((Mask & N_NORMAL) && Val.isNormal() && Val.isNegative()) ||
- ((Mask & N_SUBNORMAL) && Val.isDenormal() && Val.isNegative()) ||
- ((Mask & N_ZERO) && Val.isZero() && Val.isNegative()) ||
- ((Mask & P_ZERO) && Val.isZero() && !Val.isNegative()) ||
- ((Mask & P_SUBNORMAL) && Val.isDenormal() && !Val.isNegative()) ||
- ((Mask & P_NORMAL) && Val.isNormal() && !Val.isNegative()) ||
- ((Mask & P_INFINITY) && Val.isInfinity() && !Val.isNegative());
-
- return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), Result));
- }
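The class-mask test being constant-folded above can be mirrored in a standalone C++ sketch. One stated simplification: the sketch does not distinguish signaling from quiet NaNs, so both NaN bits are treated together.

    #include <cmath>
    #include <cstdint>

    // Mirrors the bit assignments of the enum above.
    bool fpClass(double Val, uint32_t Mask) {
      enum : uint32_t {
        S_NAN = 1u << 0, Q_NAN = 1u << 1,
        N_INFINITY = 1u << 2, N_NORMAL = 1u << 3, N_SUBNORMAL = 1u << 4,
        N_ZERO = 1u << 5, P_ZERO = 1u << 6, P_SUBNORMAL = 1u << 7,
        P_NORMAL = 1u << 8, P_INFINITY = 1u << 9
      };
      bool Neg = std::signbit(Val);
      switch (std::fpclassify(Val)) {
      case FP_NAN:       return (Mask & (S_NAN | Q_NAN)) != 0;
      case FP_INFINITE:  return (Mask & (Neg ? N_INFINITY : P_INFINITY)) != 0;
      case FP_NORMAL:    return (Mask & (Neg ? N_NORMAL : P_NORMAL)) != 0;
      case FP_SUBNORMAL: return (Mask & (Neg ? N_SUBNORMAL : P_SUBNORMAL)) != 0;
      default:           return (Mask & (Neg ? N_ZERO : P_ZERO)) != 0;  // zero
      }
    }

With Mask == (S_NAN | Q_NAN) this reduces to an isnan test, which is exactly the fcmp uno rewrite performed above.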
- case Intrinsic::amdgcn_cvt_pkrtz: {
- Value *Src0 = II->getArgOperand(0);
- Value *Src1 = II->getArgOperand(1);
- if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
- if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
- const fltSemantics &HalfSem
- = II->getType()->getScalarType()->getFltSemantics();
- bool LosesInfo;
- APFloat Val0 = C0->getValueAPF();
- APFloat Val1 = C1->getValueAPF();
- Val0.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo);
- Val1.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo);
-
- Constant *Folded = ConstantVector::get({
- ConstantFP::get(II->getContext(), Val0),
- ConstantFP::get(II->getContext(), Val1) });
- return replaceInstUsesWith(*II, Folded);
- }
- }
-
- if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1))
- return replaceInstUsesWith(*II, UndefValue::get(II->getType()));
-
- break;
- }
- case Intrinsic::amdgcn_cvt_pknorm_i16:
- case Intrinsic::amdgcn_cvt_pknorm_u16:
- case Intrinsic::amdgcn_cvt_pk_i16:
- case Intrinsic::amdgcn_cvt_pk_u16: {
- Value *Src0 = II->getArgOperand(0);
- Value *Src1 = II->getArgOperand(1);
-
- if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1))
- return replaceInstUsesWith(*II, UndefValue::get(II->getType()));
-
- break;
- }
- case Intrinsic::amdgcn_ubfe:
- case Intrinsic::amdgcn_sbfe: {
- // Decompose simple cases into standard shifts.
- Value *Src = II->getArgOperand(0);
- if (isa<UndefValue>(Src))
- return replaceInstUsesWith(*II, Src);
-
- unsigned Width;
- Type *Ty = II->getType();
- unsigned IntSize = Ty->getIntegerBitWidth();
-
- ConstantInt *CWidth = dyn_cast<ConstantInt>(II->getArgOperand(2));
- if (CWidth) {
- Width = CWidth->getZExtValue();
- if ((Width & (IntSize - 1)) == 0)
- return replaceInstUsesWith(*II, ConstantInt::getNullValue(Ty));
-
- // Hardware ignores high bits, so remove those.
- if (Width >= IntSize)
- return replaceOperand(*II, 2, ConstantInt::get(CWidth->getType(),
- Width & (IntSize - 1)));
- }
-
- unsigned Offset;
- ConstantInt *COffset = dyn_cast<ConstantInt>(II->getArgOperand(1));
- if (COffset) {
- Offset = COffset->getZExtValue();
- if (Offset >= IntSize)
- return replaceOperand(*II, 1, ConstantInt::get(COffset->getType(),
- Offset & (IntSize - 1)));
- }
-
- bool Signed = IID == Intrinsic::amdgcn_sbfe;
-
- if (!CWidth || !COffset)
- break;
-
-    // The case of Width == 0 is handled above, which makes this transformation
-    // safe. If Width == 0, then the ashr and lshr instructions would produce
-    // poison values since the shift amount would be equal to the bit size.
- assert(Width != 0);
-
- // TODO: This allows folding to undef when the hardware has specific
- // behavior?
- if (Offset + Width < IntSize) {
- Value *Shl = Builder.CreateShl(Src, IntSize - Offset - Width);
- Value *RightShift = Signed ? Builder.CreateAShr(Shl, IntSize - Width)
- : Builder.CreateLShr(Shl, IntSize - Width);
- RightShift->takeName(II);
- return replaceInstUsesWith(*II, RightShift);
- }
-
- Value *RightShift = Signed ? Builder.CreateAShr(Src, Offset)
- : Builder.CreateLShr(Src, Offset);
-
- RightShift->takeName(II);
- return replaceInstUsesWith(*II, RightShift);
- }
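A standalone sketch of the shift decomposition used above, written for the unsigned 32-bit case (it assumes Offset and Width have already been reduced below 32, as the folds above arrange; the signed variant would use an arithmetic right shift):

    #include <cstdint>

    // Unsigned bitfield extract decomposed into the same shl/lshr pair the
    // code above emits.
    uint32_t ubfe32(uint32_t Src, unsigned Offset, unsigned Width) {
      if (Width == 0)
        return 0;                                   // width 0 folds to zero
      if (Offset + Width < 32)
        return (Src << (32 - Offset - Width)) >> (32 - Width);
      return Src >> Offset;                         // field reaches the top bit
    }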
- case Intrinsic::amdgcn_exp:
- case Intrinsic::amdgcn_exp_compr: {
- ConstantInt *En = cast<ConstantInt>(II->getArgOperand(1));
- unsigned EnBits = En->getZExtValue();
- if (EnBits == 0xf)
- break; // All inputs enabled.
-
- bool IsCompr = IID == Intrinsic::amdgcn_exp_compr;
- bool Changed = false;
- for (int I = 0; I < (IsCompr ? 2 : 4); ++I) {
- if ((!IsCompr && (EnBits & (1 << I)) == 0) ||
- (IsCompr && ((EnBits & (0x3 << (2 * I))) == 0))) {
- Value *Src = II->getArgOperand(I + 2);
- if (!isa<UndefValue>(Src)) {
- replaceOperand(*II, I + 2, UndefValue::get(Src->getType()));
- Changed = true;
- }
- }
- }
-
- if (Changed)
- return II;
-
- break;
- }
- case Intrinsic::amdgcn_fmed3: {
- // Note this does not preserve proper sNaN behavior if IEEE-mode is enabled
- // for the shader.
-
- Value *Src0 = II->getArgOperand(0);
- Value *Src1 = II->getArgOperand(1);
- Value *Src2 = II->getArgOperand(2);
-
- // Checking for NaN before canonicalization provides better fidelity when
- // mapping other operations onto fmed3 since the order of operands is
- // unchanged.
- CallInst *NewCall = nullptr;
- if (match(Src0, m_NaN()) || isa<UndefValue>(Src0)) {
- NewCall = Builder.CreateMinNum(Src1, Src2);
- } else if (match(Src1, m_NaN()) || isa<UndefValue>(Src1)) {
- NewCall = Builder.CreateMinNum(Src0, Src2);
- } else if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) {
- NewCall = Builder.CreateMaxNum(Src0, Src1);
- }
-
- if (NewCall) {
- NewCall->copyFastMathFlags(II);
- NewCall->takeName(II);
- return replaceInstUsesWith(*II, NewCall);
- }
-
- bool Swap = false;
- // Canonicalize constants to RHS operands.
- //
- // fmed3(c0, x, c1) -> fmed3(x, c0, c1)
- if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
- std::swap(Src0, Src1);
- Swap = true;
- }
-
- if (isa<Constant>(Src1) && !isa<Constant>(Src2)) {
- std::swap(Src1, Src2);
- Swap = true;
- }
-
- if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
- std::swap(Src0, Src1);
- Swap = true;
- }
-
- if (Swap) {
- II->setArgOperand(0, Src0);
- II->setArgOperand(1, Src1);
- II->setArgOperand(2, Src2);
- return II;
- }
-
- if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
- if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
- if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) {
- APFloat Result = fmed3AMDGCN(C0->getValueAPF(), C1->getValueAPF(),
- C2->getValueAPF());
- return replaceInstUsesWith(*II,
- ConstantFP::get(Builder.getContext(), Result));
- }
- }
- }
-
- break;
- }
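A scalar standalone sketch of the fmed3 folds above (it ignores the IEEE-mode and sNaN caveats from the comment): the median of three values can be written with min/max, and when one input is NaN it collapses to a two-operand min or max of the remaining values, matching the CreateMinNum/CreateMaxNum calls.

    #include <algorithm>
    #include <cmath>

    float fmed3(float A, float B, float C) {
      if (std::isnan(A)) return std::fmin(B, C);
      if (std::isnan(B)) return std::fmin(A, C);
      if (std::isnan(C)) return std::fmax(A, B);
      return std::max(std::min(A, B), std::min(std::max(A, B), C));
    }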
- case Intrinsic::amdgcn_icmp:
- case Intrinsic::amdgcn_fcmp: {
- const ConstantInt *CC = cast<ConstantInt>(II->getArgOperand(2));
- // Guard against invalid arguments.
- int64_t CCVal = CC->getZExtValue();
- bool IsInteger = IID == Intrinsic::amdgcn_icmp;
- if ((IsInteger && (CCVal < CmpInst::FIRST_ICMP_PREDICATE ||
- CCVal > CmpInst::LAST_ICMP_PREDICATE)) ||
- (!IsInteger && (CCVal < CmpInst::FIRST_FCMP_PREDICATE ||
- CCVal > CmpInst::LAST_FCMP_PREDICATE)))
- break;
-
- Value *Src0 = II->getArgOperand(0);
- Value *Src1 = II->getArgOperand(1);
-
- if (auto *CSrc0 = dyn_cast<Constant>(Src0)) {
- if (auto *CSrc1 = dyn_cast<Constant>(Src1)) {
- Constant *CCmp = ConstantExpr::getCompare(CCVal, CSrc0, CSrc1);
- if (CCmp->isNullValue()) {
- return replaceInstUsesWith(
- *II, ConstantExpr::getSExt(CCmp, II->getType()));
- }
-
- // The result of V_ICMP/V_FCMP assembly instructions (which this
- // intrinsic exposes) is one bit per thread, masked with the EXEC
- // register (which contains the bitmask of live threads). So a
- // comparison that always returns true is the same as a read of the
- // EXEC register.
- Function *NewF = Intrinsic::getDeclaration(
- II->getModule(), Intrinsic::read_register, II->getType());
- Metadata *MDArgs[] = {MDString::get(II->getContext(), "exec")};
- MDNode *MD = MDNode::get(II->getContext(), MDArgs);
- Value *Args[] = {MetadataAsValue::get(II->getContext(), MD)};
- CallInst *NewCall = Builder.CreateCall(NewF, Args);
- NewCall->addAttribute(AttributeList::FunctionIndex,
- Attribute::Convergent);
- NewCall->takeName(II);
- return replaceInstUsesWith(*II, NewCall);
- }
-
- // Canonicalize constants to RHS.
- CmpInst::Predicate SwapPred
- = CmpInst::getSwappedPredicate(static_cast<CmpInst::Predicate>(CCVal));
- II->setArgOperand(0, Src1);
- II->setArgOperand(1, Src0);
- II->setArgOperand(2, ConstantInt::get(CC->getType(),
- static_cast<int>(SwapPred)));
- return II;
- }
-
- if (CCVal != CmpInst::ICMP_EQ && CCVal != CmpInst::ICMP_NE)
- break;
-
- // Canonicalize compare eq with true value to compare != 0
- // llvm.amdgcn.icmp(zext (i1 x), 1, eq)
- // -> llvm.amdgcn.icmp(zext (i1 x), 0, ne)
- // llvm.amdgcn.icmp(sext (i1 x), -1, eq)
- // -> llvm.amdgcn.icmp(sext (i1 x), 0, ne)
- Value *ExtSrc;
- if (CCVal == CmpInst::ICMP_EQ &&
- ((match(Src1, m_One()) && match(Src0, m_ZExt(m_Value(ExtSrc)))) ||
- (match(Src1, m_AllOnes()) && match(Src0, m_SExt(m_Value(ExtSrc))))) &&
- ExtSrc->getType()->isIntegerTy(1)) {
- replaceOperand(*II, 1, ConstantInt::getNullValue(Src1->getType()));
- replaceOperand(*II, 2, ConstantInt::get(CC->getType(), CmpInst::ICMP_NE));
- return II;
- }
-
- CmpInst::Predicate SrcPred;
- Value *SrcLHS;
- Value *SrcRHS;
-
- // Fold compare eq/ne with 0 from a compare result as the predicate to the
- // intrinsic. The typical use is a wave vote function in the library, which
- // will be fed from a user code condition compared with 0. Fold in the
- // redundant compare.
-
- // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, ne)
- // -> llvm.amdgcn.[if]cmp(a, b, pred)
- //
- // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, eq)
- // -> llvm.amdgcn.[if]cmp(a, b, inv pred)
- if (match(Src1, m_Zero()) &&
- match(Src0,
- m_ZExtOrSExt(m_Cmp(SrcPred, m_Value(SrcLHS), m_Value(SrcRHS))))) {
- if (CCVal == CmpInst::ICMP_EQ)
- SrcPred = CmpInst::getInversePredicate(SrcPred);
-
- Intrinsic::ID NewIID = CmpInst::isFPPredicate(SrcPred) ?
- Intrinsic::amdgcn_fcmp : Intrinsic::amdgcn_icmp;
-
- Type *Ty = SrcLHS->getType();
- if (auto *CmpType = dyn_cast<IntegerType>(Ty)) {
- // Promote to next legal integer type.
- unsigned Width = CmpType->getBitWidth();
- unsigned NewWidth = Width;
-
- // Don't do anything for i1 comparisons.
- if (Width == 1)
- break;
-
- if (Width <= 16)
- NewWidth = 16;
- else if (Width <= 32)
- NewWidth = 32;
- else if (Width <= 64)
- NewWidth = 64;
- else if (Width > 64)
- break; // Can't handle this.
-
- if (Width != NewWidth) {
- IntegerType *CmpTy = Builder.getIntNTy(NewWidth);
- if (CmpInst::isSigned(SrcPred)) {
- SrcLHS = Builder.CreateSExt(SrcLHS, CmpTy);
- SrcRHS = Builder.CreateSExt(SrcRHS, CmpTy);
- } else {
- SrcLHS = Builder.CreateZExt(SrcLHS, CmpTy);
- SrcRHS = Builder.CreateZExt(SrcRHS, CmpTy);
- }
- }
- } else if (!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy())
- break;
-
- Function *NewF =
- Intrinsic::getDeclaration(II->getModule(), NewIID,
- { II->getType(),
- SrcLHS->getType() });
- Value *Args[] = { SrcLHS, SrcRHS,
- ConstantInt::get(CC->getType(), SrcPred) };
- CallInst *NewCall = Builder.CreateCall(NewF, Args);
- NewCall->takeName(II);
- return replaceInstUsesWith(*II, NewCall);
- }
-
- break;
- }
- case Intrinsic::amdgcn_ballot: {
- if (auto *Src = dyn_cast<ConstantInt>(II->getArgOperand(0))) {
- if (Src->isZero()) {
- // amdgcn.ballot(i1 0) is zero.
- return replaceInstUsesWith(*II, Constant::getNullValue(II->getType()));
- }
-
- if (Src->isOne()) {
- // amdgcn.ballot(i1 1) is exec.
- const char *RegName = "exec";
- if (II->getType()->isIntegerTy(32))
- RegName = "exec_lo";
- else if (!II->getType()->isIntegerTy(64))
- break;
-
- Function *NewF = Intrinsic::getDeclaration(
- II->getModule(), Intrinsic::read_register, II->getType());
- Metadata *MDArgs[] = {MDString::get(II->getContext(), RegName)};
- MDNode *MD = MDNode::get(II->getContext(), MDArgs);
- Value *Args[] = {MetadataAsValue::get(II->getContext(), MD)};
- CallInst *NewCall = Builder.CreateCall(NewF, Args);
- NewCall->addAttribute(AttributeList::FunctionIndex,
- Attribute::Convergent);
- NewCall->takeName(II);
- return replaceInstUsesWith(*II, NewCall);
- }
- }
- break;
- }
- case Intrinsic::amdgcn_wqm_vote: {
- // wqm_vote is identity when the argument is constant.
- if (!isa<Constant>(II->getArgOperand(0)))
- break;
-
- return replaceInstUsesWith(*II, II->getArgOperand(0));
- }
- case Intrinsic::amdgcn_kill: {
- const ConstantInt *C = dyn_cast<ConstantInt>(II->getArgOperand(0));
- if (!C || !C->getZExtValue())
- break;
-
- // amdgcn.kill(i1 1) is a no-op
- return eraseInstFromFunction(CI);
- }
- case Intrinsic::amdgcn_update_dpp: {
- Value *Old = II->getArgOperand(0);
-
- auto BC = cast<ConstantInt>(II->getArgOperand(5));
- auto RM = cast<ConstantInt>(II->getArgOperand(3));
- auto BM = cast<ConstantInt>(II->getArgOperand(4));
- if (BC->isZeroValue() ||
- RM->getZExtValue() != 0xF ||
- BM->getZExtValue() != 0xF ||
- isa<UndefValue>(Old))
- break;
-
-    // If bound_ctrl = 1 and row mask = bank mask = 0xf, we can omit the old value.
- return replaceOperand(*II, 0, UndefValue::get(Old->getType()));
- }
- case Intrinsic::amdgcn_permlane16:
- case Intrinsic::amdgcn_permlanex16: {
- // Discard vdst_in if it's not going to be read.
- Value *VDstIn = II->getArgOperand(0);
- if (isa<UndefValue>(VDstIn))
- break;
-
- ConstantInt *FetchInvalid = cast<ConstantInt>(II->getArgOperand(4));
- ConstantInt *BoundCtrl = cast<ConstantInt>(II->getArgOperand(5));
- if (!FetchInvalid->getZExtValue() && !BoundCtrl->getZExtValue())
- break;
-
- return replaceOperand(*II, 0, UndefValue::get(VDstIn->getType()));
- }
- case Intrinsic::amdgcn_readfirstlane:
- case Intrinsic::amdgcn_readlane: {
- // A constant value is trivially uniform.
- if (Constant *C = dyn_cast<Constant>(II->getArgOperand(0)))
- return replaceInstUsesWith(*II, C);
-
-    // The rest of these may not be safe if exec is not the same between
-    // the def and use.
- Value *Src = II->getArgOperand(0);
- Instruction *SrcInst = dyn_cast<Instruction>(Src);
- if (SrcInst && SrcInst->getParent() != II->getParent())
- break;
-
- // readfirstlane (readfirstlane x) -> readfirstlane x
- // readlane (readfirstlane x), y -> readfirstlane x
- if (match(Src, m_Intrinsic<Intrinsic::amdgcn_readfirstlane>()))
- return replaceInstUsesWith(*II, Src);
-
- if (IID == Intrinsic::amdgcn_readfirstlane) {
- // readfirstlane (readlane x, y) -> readlane x, y
- if (match(Src, m_Intrinsic<Intrinsic::amdgcn_readlane>()))
- return replaceInstUsesWith(*II, Src);
- } else {
- // readlane (readlane x, y), y -> readlane x, y
- if (match(Src, m_Intrinsic<Intrinsic::amdgcn_readlane>(
- m_Value(), m_Specific(II->getArgOperand(1)))))
- return replaceInstUsesWith(*II, Src);
- }
-
- break;
- }
- case Intrinsic::amdgcn_ldexp: {
- // FIXME: This doesn't introduce new instructions and belongs in
- // InstructionSimplify.
- Type *Ty = II->getType();
- Value *Op0 = II->getArgOperand(0);
- Value *Op1 = II->getArgOperand(1);
-
- // Folding undef to qnan is safe regardless of the FP mode.
- if (isa<UndefValue>(Op0)) {
- auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics()));
- return replaceInstUsesWith(*II, QNaN);
- }
-
- const APFloat *C = nullptr;
- match(Op0, m_APFloat(C));
-
- // FIXME: Should flush denorms depending on FP mode, but that's ignored
- // everywhere else.
- //
- // These cases should be safe, even with strictfp.
- // ldexp(0.0, x) -> 0.0
- // ldexp(-0.0, x) -> -0.0
- // ldexp(inf, x) -> inf
- // ldexp(-inf, x) -> -inf
- if (C && (C->isZero() || C->isInfinity()))
- return replaceInstUsesWith(*II, Op0);
-
- // With strictfp, be more careful about possibly needing to flush denormals
- // or not, and snan behavior depends on ieee_mode.
- if (II->isStrictFP())
- break;
-
- if (C && C->isNaN()) {
- // FIXME: We just need to make the nan quiet here, but that's unavailable
- // on APFloat, only IEEEfloat
- auto *Quieted = ConstantFP::get(
- Ty, scalbn(*C, 0, APFloat::rmNearestTiesToEven));
- return replaceInstUsesWith(*II, Quieted);
- }
-
- // ldexp(x, 0) -> x
- // ldexp(x, undef) -> x
- if (isa<UndefValue>(Op1) || match(Op1, m_ZeroInt()))
- return replaceInstUsesWith(*II, Op0);
-
- break;
- }
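The identities relied on above can be checked against the C library's ldexp in a standalone program (denormal flushing and strictfp concerns are deliberately ignored here, as the comment itself notes):

    #include <cassert>
    #include <cmath>
    #include <limits>

    int main() {
      double Inf = std::numeric_limits<double>::infinity();
      assert(std::ldexp(0.0, 7) == 0.0);              // ldexp(0.0, x)  -> 0.0
      assert(std::signbit(std::ldexp(-0.0, 7)));      // ldexp(-0.0, x) -> -0.0
      assert(std::ldexp(Inf, 7) == Inf);              // ldexp(inf, x)  -> inf
      assert(std::ldexp(-Inf, 7) == -Inf);            // ldexp(-inf, x) -> -inf
      assert(std::ldexp(1.5, 0) == 1.5);              // ldexp(x, 0)    -> x
      return 0;
    }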
case Intrinsic::hexagon_V6_vandvrt:
case Intrinsic::hexagon_V6_vandvrt_128B: {
// Simplify Q -> V -> Q conversion.
@@ -4238,14 +1509,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
FunctionType *AssumeIntrinsicTy = II->getFunctionType();
Value *AssumeIntrinsic = II->getCalledOperand();
Value *A, *B;
- if (match(IIOperand, m_And(m_Value(A), m_Value(B)))) {
+ if (match(IIOperand, m_LogicalAnd(m_Value(A), m_Value(B)))) {
Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, A, OpBundles,
II->getName());
Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, B, II->getName());
return eraseInstFromFunction(*II);
}
// assume(!(a || b)) -> assume(!a); assume(!b);
- if (match(IIOperand, m_Not(m_Or(m_Value(A), m_Value(B))))) {
+ if (match(IIOperand, m_Not(m_LogicalOr(m_Value(A), m_Value(B))))) {
Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic,
Builder.CreateNot(A), OpBundles, II->getName());
Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic,
@@ -4282,59 +1553,104 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
AC.updateAffectedValues(II);
break;
}
- case Intrinsic::experimental_gc_relocate: {
- auto &GCR = *cast<GCRelocateInst>(II);
-
- // If we have two copies of the same pointer in the statepoint argument
- // list, canonicalize to one. This may let us common gc.relocates.
- if (GCR.getBasePtr() == GCR.getDerivedPtr() &&
- GCR.getBasePtrIndex() != GCR.getDerivedPtrIndex()) {
- auto *OpIntTy = GCR.getOperand(2)->getType();
- return replaceOperand(*II, 2,
- ConstantInt::get(OpIntTy, GCR.getBasePtrIndex()));
- }
+ case Intrinsic::experimental_gc_statepoint: {
+ GCStatepointInst &GCSP = *cast<GCStatepointInst>(II);
+ SmallPtrSet<Value *, 32> LiveGcValues;
+ for (const GCRelocateInst *Reloc : GCSP.getGCRelocates()) {
+ GCRelocateInst &GCR = *const_cast<GCRelocateInst *>(Reloc);
- // Translate facts known about a pointer before relocating into
- // facts about the relocate value, while being careful to
- // preserve relocation semantics.
- Value *DerivedPtr = GCR.getDerivedPtr();
+ // Remove the relocation if unused.
+ if (GCR.use_empty()) {
+ eraseInstFromFunction(GCR);
+ continue;
+ }
- // Remove the relocation if unused, note that this check is required
- // to prevent the cases below from looping forever.
- if (II->use_empty())
- return eraseInstFromFunction(*II);
+ Value *DerivedPtr = GCR.getDerivedPtr();
+ Value *BasePtr = GCR.getBasePtr();
- // Undef is undef, even after relocation.
- // TODO: provide a hook for this in GCStrategy. This is clearly legal for
- // most practical collectors, but there was discussion in the review thread
- // about whether it was legal for all possible collectors.
- if (isa<UndefValue>(DerivedPtr))
- // Use undef of gc_relocate's type to replace it.
- return replaceInstUsesWith(*II, UndefValue::get(II->getType()));
-
- if (auto *PT = dyn_cast<PointerType>(II->getType())) {
- // The relocation of null will be null for most any collector.
- // TODO: provide a hook for this in GCStrategy. There might be some
- // weird collector this property does not hold for.
- if (isa<ConstantPointerNull>(DerivedPtr))
- // Use null-pointer of gc_relocate's type to replace it.
- return replaceInstUsesWith(*II, ConstantPointerNull::get(PT));
-
- // isKnownNonNull -> nonnull attribute
- if (!II->hasRetAttr(Attribute::NonNull) &&
- isKnownNonZero(DerivedPtr, DL, 0, &AC, II, &DT)) {
- II->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
- return II;
+ // Undef is undef, even after relocation.
+ if (isa<UndefValue>(DerivedPtr) || isa<UndefValue>(BasePtr)) {
+ replaceInstUsesWith(GCR, UndefValue::get(GCR.getType()));
+ eraseInstFromFunction(GCR);
+ continue;
+ }
+
+ if (auto *PT = dyn_cast<PointerType>(GCR.getType())) {
+ // The relocation of null will be null for most any collector.
+ // TODO: provide a hook for this in GCStrategy. There might be some
+ // weird collector this property does not hold for.
+ if (isa<ConstantPointerNull>(DerivedPtr)) {
+ // Use null-pointer of gc_relocate's type to replace it.
+ replaceInstUsesWith(GCR, ConstantPointerNull::get(PT));
+ eraseInstFromFunction(GCR);
+ continue;
+ }
+
+ // isKnownNonNull -> nonnull attribute
+ if (!GCR.hasRetAttr(Attribute::NonNull) &&
+ isKnownNonZero(DerivedPtr, DL, 0, &AC, II, &DT)) {
+ GCR.addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
+ // We discovered new fact, re-check users.
+ Worklist.pushUsersToWorkList(GCR);
+ }
+ }
+
+ // If we have two copies of the same pointer in the statepoint argument
+ // list, canonicalize to one. This may let us common gc.relocates.
+ if (GCR.getBasePtr() == GCR.getDerivedPtr() &&
+ GCR.getBasePtrIndex() != GCR.getDerivedPtrIndex()) {
+ auto *OpIntTy = GCR.getOperand(2)->getType();
+ GCR.setOperand(2, ConstantInt::get(OpIntTy, GCR.getBasePtrIndex()));
}
- }
- // TODO: bitcast(relocate(p)) -> relocate(bitcast(p))
- // Canonicalize on the type from the uses to the defs
+ // TODO: bitcast(relocate(p)) -> relocate(bitcast(p))
+ // Canonicalize on the type from the uses to the defs
- // TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...)
+ // TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...)
+ LiveGcValues.insert(BasePtr);
+ LiveGcValues.insert(DerivedPtr);
+ }
+ Optional<OperandBundleUse> Bundle =
+ GCSP.getOperandBundle(LLVMContext::OB_gc_live);
+ unsigned NumOfGCLives = LiveGcValues.size();
+ if (!Bundle.hasValue() || NumOfGCLives == Bundle->Inputs.size())
+ break;
+    // We can reduce the size of the gc live bundle.
+ DenseMap<Value *, unsigned> Val2Idx;
+ std::vector<Value *> NewLiveGc;
+ for (unsigned I = 0, E = Bundle->Inputs.size(); I < E; ++I) {
+ Value *V = Bundle->Inputs[I];
+ if (Val2Idx.count(V))
+ continue;
+ if (LiveGcValues.count(V)) {
+ Val2Idx[V] = NewLiveGc.size();
+ NewLiveGc.push_back(V);
+ } else
+ Val2Idx[V] = NumOfGCLives;
+ }
+ // Update all gc.relocates
+ for (const GCRelocateInst *Reloc : GCSP.getGCRelocates()) {
+ GCRelocateInst &GCR = *const_cast<GCRelocateInst *>(Reloc);
+ Value *BasePtr = GCR.getBasePtr();
+ assert(Val2Idx.count(BasePtr) && Val2Idx[BasePtr] != NumOfGCLives &&
+ "Missed live gc for base pointer");
+ auto *OpIntTy1 = GCR.getOperand(1)->getType();
+ GCR.setOperand(1, ConstantInt::get(OpIntTy1, Val2Idx[BasePtr]));
+ Value *DerivedPtr = GCR.getDerivedPtr();
+ assert(Val2Idx.count(DerivedPtr) && Val2Idx[DerivedPtr] != NumOfGCLives &&
+ "Missed live gc for derived pointer");
+ auto *OpIntTy2 = GCR.getOperand(2)->getType();
+ GCR.setOperand(2, ConstantInt::get(OpIntTy2, Val2Idx[DerivedPtr]));
+ }
+ // Create new statepoint instruction.
+ OperandBundleDef NewBundle("gc-live", NewLiveGc);
+ if (isa<CallInst>(II))
+ return CallInst::CreateWithReplacedBundle(cast<CallInst>(II), NewBundle);
+ else
+ return InvokeInst::CreateWithReplacedBundle(cast<InvokeInst>(II),
+ NewBundle);
break;
}
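The live-bundle compaction above follows a common dedup-and-remap idiom. A standalone sketch with plain ints standing in for IR values (illustrative, not the patch's code): keep only values that are still live, give each kept value the index of its first occurrence, and record a sentinel for dead entries so the later index rewriting can assert it never sees one.

    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    std::vector<int> compactLiveBundle(const std::vector<int> &Inputs,
                                       const std::unordered_set<int> &Live,
                                       std::unordered_map<int, unsigned> &ValToIdx) {
      std::vector<int> NewLive;
      const unsigned DeadSentinel = Inputs.size();
      for (int V : Inputs) {
        if (ValToIdx.count(V))
          continue;                        // duplicates collapse to one slot
        if (Live.count(V)) {
          ValToIdx[V] = NewLive.size();
          NewLive.push_back(V);
        } else {
          ValToIdx[V] = DeadSentinel;      // value no longer live
        }
      }
      return NewLive;
    }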
-
case Intrinsic::experimental_guard: {
// Is this guard followed by another guard? We scan forward over a small
// fixed window of instructions to handle common cases with conditions
@@ -4367,12 +1683,114 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
}
break;
}
+ case Intrinsic::experimental_vector_insert: {
+ Value *Vec = II->getArgOperand(0);
+ Value *SubVec = II->getArgOperand(1);
+ Value *Idx = II->getArgOperand(2);
+ auto *DstTy = dyn_cast<FixedVectorType>(II->getType());
+ auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
+ auto *SubVecTy = dyn_cast<FixedVectorType>(SubVec->getType());
+
+ // Only canonicalize if the destination vector, Vec, and SubVec are all
+ // fixed vectors.
+ if (DstTy && VecTy && SubVecTy) {
+ unsigned DstNumElts = DstTy->getNumElements();
+ unsigned VecNumElts = VecTy->getNumElements();
+ unsigned SubVecNumElts = SubVecTy->getNumElements();
+ unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue();
+
+ // The result of this call is undefined if IdxN is not a constant multiple
+ // of the SubVec's minimum vector length OR the insertion overruns Vec.
+ if (IdxN % SubVecNumElts != 0 || IdxN + SubVecNumElts > VecNumElts) {
+ replaceInstUsesWith(CI, UndefValue::get(CI.getType()));
+ return eraseInstFromFunction(CI);
+ }
+
+ // An insert that entirely overwrites Vec with SubVec is a nop.
+ if (VecNumElts == SubVecNumElts) {
+ replaceInstUsesWith(CI, SubVec);
+ return eraseInstFromFunction(CI);
+ }
+
+ // Widen SubVec into a vector of the same width as Vec, since
+ // shufflevector requires the two input vectors to be the same width.
+ // Elements beyond the bounds of SubVec within the widened vector are
+ // undefined.
+ SmallVector<int, 8> WidenMask;
+ unsigned i;
+ for (i = 0; i != SubVecNumElts; ++i)
+ WidenMask.push_back(i);
+ for (; i != VecNumElts; ++i)
+ WidenMask.push_back(UndefMaskElem);
+
+ Value *WidenShuffle = Builder.CreateShuffleVector(SubVec, WidenMask);
+
+ SmallVector<int, 8> Mask;
+ for (unsigned i = 0; i != IdxN; ++i)
+ Mask.push_back(i);
+ for (unsigned i = DstNumElts; i != DstNumElts + SubVecNumElts; ++i)
+ Mask.push_back(i);
+ for (unsigned i = IdxN + SubVecNumElts; i != DstNumElts; ++i)
+ Mask.push_back(i);
+
+ Value *Shuffle = Builder.CreateShuffleVector(Vec, WidenShuffle, Mask);
+ replaceInstUsesWith(CI, Shuffle);
+ return eraseInstFromFunction(CI);
+ }
+ break;
+ }
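The two-operand shuffle mask built above can be reproduced in isolation. A standalone sketch of just the index arithmetic (illustrative helper, not part of the patch): lanes [0, IdxN) and [IdxN + SubVecNumElts, DstNumElts) come from Vec (operand 0, indices 0..DstNumElts-1), while the inserted window comes from the widened SubVec (operand 1, indices DstNumElts and up).

    #include <vector>

    std::vector<int> insertShuffleMask(unsigned DstNumElts,
                                       unsigned SubVecNumElts, unsigned IdxN) {
      std::vector<int> Mask;
      for (unsigned i = 0; i != IdxN; ++i)
        Mask.push_back(i);                            // leading lanes of Vec
      for (unsigned i = DstNumElts; i != DstNumElts + SubVecNumElts; ++i)
        Mask.push_back(i);                            // widened SubVec lanes
      for (unsigned i = IdxN + SubVecNumElts; i != DstNumElts; ++i)
        Mask.push_back(i);                            // trailing lanes of Vec
      return Mask;
    }

For example, inserting a 2-element subvector at index 4 of an 8-element vector gives the mask {0, 1, 2, 3, 8, 9, 6, 7}.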
+ case Intrinsic::experimental_vector_extract: {
+ Value *Vec = II->getArgOperand(0);
+ Value *Idx = II->getArgOperand(1);
+
+ auto *DstTy = dyn_cast<FixedVectorType>(II->getType());
+ auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
+
+    // Only canonicalize if the destination vector and Vec are fixed
+    // vectors.
+ if (DstTy && VecTy) {
+ unsigned DstNumElts = DstTy->getNumElements();
+ unsigned VecNumElts = VecTy->getNumElements();
+ unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue();
+
+ // The result of this call is undefined if IdxN is not a constant multiple
+ // of the result type's minimum vector length OR the extraction overruns
+ // Vec.
+ if (IdxN % DstNumElts != 0 || IdxN + DstNumElts > VecNumElts) {
+ replaceInstUsesWith(CI, UndefValue::get(CI.getType()));
+ return eraseInstFromFunction(CI);
+ }
+
+ // Extracting the entirety of Vec is a nop.
+ if (VecNumElts == DstNumElts) {
+ replaceInstUsesWith(CI, Vec);
+ return eraseInstFromFunction(CI);
+ }
+
+ SmallVector<int, 8> Mask;
+ for (unsigned i = 0; i != DstNumElts; ++i)
+ Mask.push_back(IdxN + i);
+
+ Value *Shuffle =
+ Builder.CreateShuffleVector(Vec, UndefValue::get(VecTy), Mask);
+ replaceInstUsesWith(CI, Shuffle);
+ return eraseInstFromFunction(CI);
+ }
+ break;
+ }
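The extract case is simpler: its mask is a contiguous range. A matching standalone sketch (again illustrative, not part of the patch):

    #include <vector>

    // Extracting DstNumElts consecutive lanes starting at IdxN is a shuffle
    // whose mask is the range [IdxN, IdxN + DstNumElts).
    std::vector<int> extractShuffleMask(unsigned DstNumElts, unsigned IdxN) {
      std::vector<int> Mask;
      for (unsigned i = 0; i != DstNumElts; ++i)
        Mask.push_back(IdxN + i);
      return Mask;
    }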
+ default: {
+ // Handle target specific intrinsics
+ Optional<Instruction *> V = targetInstCombineIntrinsic(*II);
+ if (V.hasValue())
+ return V.getValue();
+ break;
+ }
}
return visitCallBase(*II);
}
// Fence instruction simplification
-Instruction *InstCombiner::visitFenceInst(FenceInst &FI) {
+Instruction *InstCombinerImpl::visitFenceInst(FenceInst &FI) {
// Remove identical consecutive fences.
Instruction *Next = FI.getNextNonDebugInstruction();
if (auto *NFI = dyn_cast<FenceInst>(Next))
@@ -4382,12 +1800,12 @@ Instruction *InstCombiner::visitFenceInst(FenceInst &FI) {
}
// InvokeInst simplification
-Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) {
+Instruction *InstCombinerImpl::visitInvokeInst(InvokeInst &II) {
return visitCallBase(II);
}
// CallBrInst simplification
-Instruction *InstCombiner::visitCallBrInst(CallBrInst &CBI) {
+Instruction *InstCombinerImpl::visitCallBrInst(CallBrInst &CBI) {
return visitCallBase(CBI);
}
@@ -4427,7 +1845,7 @@ static bool isSafeToEliminateVarargsCast(const CallBase &Call,
return true;
}
-Instruction *InstCombiner::tryOptimizeCall(CallInst *CI) {
+Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) {
if (!CI->getCalledFunction()) return nullptr;
auto InstCombineRAUW = [this](Instruction *From, Value *With) {
@@ -4584,7 +2002,7 @@ static void annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) {
}
/// Improvements for call, callbr and invoke instructions.
-Instruction *InstCombiner::visitCallBase(CallBase &Call) {
+Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
if (isAllocationFn(&Call, &TLI))
annotateAnyAllocSite(Call, &TLI);
@@ -4640,7 +2058,7 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) {
!CalleeF->isDeclaration()) {
Instruction *OldCall = &Call;
CreateNonTerminatorUnreachable(OldCall);
- // If OldCall does not return void then replaceAllUsesWith undef.
+ // If OldCall does not return void then replaceInstUsesWith undef.
// This allows ValueHandlers and custom metadata to adjust itself.
if (!OldCall->getType()->isVoidTy())
replaceInstUsesWith(*OldCall, UndefValue::get(OldCall->getType()));
@@ -4659,7 +2077,7 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) {
if ((isa<ConstantPointerNull>(Callee) &&
!NullPointerIsDefined(Call.getFunction())) ||
isa<UndefValue>(Callee)) {
- // If Call does not return void then replaceAllUsesWith undef.
+ // If Call does not return void then replaceInstUsesWith undef.
// This allows ValueHandlers and custom metadata to adjust itself.
if (!Call.getType()->isVoidTy())
replaceInstUsesWith(Call, UndefValue::get(Call.getType()));
@@ -4735,7 +2153,7 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) {
/// If the callee is a constexpr cast of a function, attempt to move the cast to
/// the arguments of the call/callbr/invoke.
-bool InstCombiner::transformConstExprCastCall(CallBase &Call) {
+bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
auto *Callee =
dyn_cast<Function>(Call.getCalledOperand()->stripPointerCasts());
if (!Callee)
@@ -4834,6 +2252,9 @@ bool InstCombiner::transformConstExprCastCall(CallBase &Call) {
if (Call.isInAllocaArgument(i))
return false; // Cannot transform to and from inalloca.
+ if (CallerPAL.hasParamAttribute(i, Attribute::SwiftError))
+ return false;
+
// If the parameter is passed as a byval argument, then we have to have a
// sized type and the sized type has to have the same size as the old type.
if (ParamTy != ActTy && CallerPAL.hasParamAttribute(i, Attribute::ByVal)) {
@@ -5019,8 +2440,8 @@ bool InstCombiner::transformConstExprCastCall(CallBase &Call) {
/// Turn a call to a function created by init_trampoline / adjust_trampoline
/// intrinsic pair into a direct call to the underlying function.
Instruction *
-InstCombiner::transformCallThroughTrampoline(CallBase &Call,
- IntrinsicInst &Tramp) {
+InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call,
+ IntrinsicInst &Tramp) {
Value *Callee = Call.getCalledOperand();
Type *CalleeTy = Callee->getType();
FunctionType *FTy = Call.getFunctionType();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 3639edb5df4d..0b53007bb6dc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -14,10 +14,11 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DIBuilder.h"
+#include "llvm/IR/DataLayout.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <numeric>
using namespace llvm;
using namespace PatternMatch;
@@ -81,8 +82,8 @@ static Value *decomposeSimpleLinearExpr(Value *Val, unsigned &Scale,
/// If we find a cast of an allocation instruction, try to eliminate the cast by
/// moving the type information into the alloc.
-Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI,
- AllocaInst &AI) {
+Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI,
+ AllocaInst &AI) {
PointerType *PTy = cast<PointerType>(CI.getType());
IRBuilderBase::InsertPointGuard Guard(Builder);
@@ -93,6 +94,18 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI,
Type *CastElTy = PTy->getElementType();
if (!AllocElTy->isSized() || !CastElTy->isSized()) return nullptr;
+ // This optimisation does not work for cases where the cast type
+  // is scalable and the allocated type is not. This is because we need to
+  // know how many times the casted type fits into the allocated type.
+  // For the opposite case, where the allocated type is scalable and the
+  // cast type is not, this leads to poor code quality due to the
+ // introduction of 'vscale' into the calculations. It seems better to
+ // bail out for this case too until we've done a proper cost-benefit
+ // analysis.
+ bool AllocIsScalable = isa<ScalableVectorType>(AllocElTy);
+ bool CastIsScalable = isa<ScalableVectorType>(CastElTy);
+ if (AllocIsScalable != CastIsScalable) return nullptr;
+
Align AllocElTyAlign = DL.getABITypeAlign(AllocElTy);
Align CastElTyAlign = DL.getABITypeAlign(CastElTy);
if (CastElTyAlign < AllocElTyAlign) return nullptr;
@@ -102,14 +115,15 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI,
// same, we open the door to infinite loops of various kinds.
if (!AI.hasOneUse() && CastElTyAlign == AllocElTyAlign) return nullptr;
- uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy);
- uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy);
+ // The alloc and cast types should be either both fixed or both scalable.
+ uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinSize();
+ uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinSize();
if (CastElTySize == 0 || AllocElTySize == 0) return nullptr;
// If the allocation has multiple uses, only promote it if we're not
// shrinking the amount of memory being allocated.
- uint64_t AllocElTyStoreSize = DL.getTypeStoreSize(AllocElTy);
- uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy);
+ uint64_t AllocElTyStoreSize = DL.getTypeStoreSize(AllocElTy).getKnownMinSize();
+ uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinSize();
if (!AI.hasOneUse() && CastElTyStoreSize < AllocElTyStoreSize) return nullptr;
// See if we can satisfy the modulus by pulling a scale out of the array
@@ -124,6 +138,9 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI,
if ((AllocElTySize*ArraySizeScale) % CastElTySize != 0 ||
(AllocElTySize*ArrayOffset ) % CastElTySize != 0) return nullptr;
+ // We don't currently support arrays of scalable types.
+ assert(!AllocIsScalable || (ArrayOffset == 1 && ArraySizeScale == 0));
+
unsigned Scale = (AllocElTySize*ArraySizeScale)/CastElTySize;
Value *Amt = nullptr;
if (Scale == 1) {
@@ -160,8 +177,8 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI,
/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
/// true for, actually insert the code to evaluate the expression.
-Value *InstCombiner::EvaluateInDifferentType(Value *V, Type *Ty,
- bool isSigned) {
+Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
+ bool isSigned) {
if (Constant *C = dyn_cast<Constant>(V)) {
C = ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/);
// If we got a constantexpr back, try to simplify it with DL info.
@@ -229,8 +246,9 @@ Value *InstCombiner::EvaluateInDifferentType(Value *V, Type *Ty,
return InsertNewInstWith(Res, *I);
}
-Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1,
- const CastInst *CI2) {
+Instruction::CastOps
+InstCombinerImpl::isEliminableCastPair(const CastInst *CI1,
+ const CastInst *CI2) {
Type *SrcTy = CI1->getSrcTy();
Type *MidTy = CI1->getDestTy();
Type *DstTy = CI2->getDestTy();
@@ -257,7 +275,7 @@ Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1,
}
/// Implement the transforms common to all CastInst visitors.
-Instruction *InstCombiner::commonCastTransforms(CastInst &CI) {
+Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
Value *Src = CI.getOperand(0);
// Try to eliminate a cast of a cast.
@@ -342,7 +360,7 @@ static bool canNotEvaluateInType(Value *V, Type *Ty) {
///
/// This function works on both vectors and scalars.
///
-static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC,
+static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
Instruction *CxtI) {
if (canAlwaysEvaluateInType(V, Ty))
return true;
@@ -459,7 +477,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC,
/// trunc (lshr (bitcast <4 x i32> %X to i128), 32) to i32
/// --->
/// extractelement <4 x i32> %X, 1
-static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) {
+static Instruction *foldVecTruncToExtElt(TruncInst &Trunc,
+ InstCombinerImpl &IC) {
Value *TruncOp = Trunc.getOperand(0);
Type *DestType = Trunc.getType();
if (!TruncOp->hasOneUse() || !isa<IntegerType>(DestType))
@@ -496,9 +515,9 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) {
return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
}
-/// Rotate left/right may occur in a wider type than necessary because of type
-/// promotion rules. Try to narrow the inputs and convert to funnel shift.
-Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) {
+/// Funnel/Rotate left/right may occur in a wider type than necessary because of
+/// type promotion rules. Try to narrow the inputs and convert to funnel shift.
+Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
assert((isa<VectorType>(Trunc.getSrcTy()) ||
shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) &&
"Don't narrow to an illegal scalar type");
@@ -510,32 +529,43 @@ Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) {
if (!isPowerOf2_32(NarrowWidth))
return nullptr;
- // First, find an or'd pair of opposite shifts with the same shifted operand:
- // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1))
- Value *Or0, *Or1;
- if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1)))))
+ // First, find an or'd pair of opposite shifts:
+ // trunc (or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1))
+ BinaryOperator *Or0, *Or1;
+ if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1)))))
return nullptr;
- Value *ShVal, *ShAmt0, *ShAmt1;
- if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) ||
- !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1)))))
+ Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
+ if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
+ !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
+ Or0->getOpcode() == Or1->getOpcode())
return nullptr;
- auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode();
- auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode();
- if (ShiftOpcode0 == ShiftOpcode1)
- return nullptr;
+ // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
+ if (Or0->getOpcode() == BinaryOperator::LShr) {
+ std::swap(Or0, Or1);
+ std::swap(ShVal0, ShVal1);
+ std::swap(ShAmt0, ShAmt1);
+ }
+ assert(Or0->getOpcode() == BinaryOperator::Shl &&
+ Or1->getOpcode() == BinaryOperator::LShr &&
+ "Illegal or(shift,shift) pair");
- // Match the shift amount operands for a rotate pattern. This always matches
- // a subtraction on the R operand.
- auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * {
+ // Match the shift amount operands for a funnel/rotate pattern. This always
+ // matches a subtraction on the R operand.
+ auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
// The shift amounts may add up to the narrow bit width:
- // (shl ShVal, L) | (lshr ShVal, Width - L)
+ // (shl ShVal0, L) | (lshr ShVal1, Width - L)
if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L)))))
return L;
+ // The following patterns currently only work for rotation patterns.
+ // TODO: Add more general funnel-shift compatible patterns.
+ if (ShVal0 != ShVal1)
+ return nullptr;
+
// The shift amount may be masked with negation:
- // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
+ // (shl ShVal0, (X & (Width - 1))) | (lshr ShVal1, ((-X) & (Width - 1)))
Value *X;
unsigned Mask = Width - 1;
if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
@@ -551,10 +581,10 @@ Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) {
};
Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth);
- bool SubIsOnLHS = false;
+ bool IsFshl = true; // Sub on LSHR.
if (!ShAmt) {
ShAmt = matchShiftAmount(ShAmt1, ShAmt0, NarrowWidth);
- SubIsOnLHS = true;
+ IsFshl = false; // Sub on SHL.
}
if (!ShAmt)
return nullptr;
@@ -563,26 +593,28 @@ Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) {
// will be a zext, but it could also be the result of an 'and' or 'shift'.
unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits();
APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth);
- if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc))
+ if (!MaskedValueIsZero(ShVal0, HiBitMask, 0, &Trunc) ||
+ !MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc))
return nullptr;
// We have an unnecessarily wide rotate!
- // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt))
+ // trunc (or (lshr ShVal0, ShAmt), (shl ShVal1, BitWidth - ShAmt))
// Narrow the inputs and convert to funnel shift intrinsic:
// llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt))
Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy);
- Value *X = Builder.CreateTrunc(ShVal, DestTy);
- bool IsFshl = (!SubIsOnLHS && ShiftOpcode0 == BinaryOperator::Shl) ||
- (SubIsOnLHS && ShiftOpcode1 == BinaryOperator::Shl);
+ Value *X, *Y;
+ X = Y = Builder.CreateTrunc(ShVal0, DestTy);
+ if (ShVal0 != ShVal1)
+ Y = Builder.CreateTrunc(ShVal1, DestTy);
Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
Function *F = Intrinsic::getDeclaration(Trunc.getModule(), IID, DestTy);
- return IntrinsicInst::Create(F, { X, X, NarrowShAmt });
+ return IntrinsicInst::Create(F, {X, Y, NarrowShAmt});
}
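// Illustrative standalone sketch (not part of the patch above): the scalar
// identity that narrowFunnelShift relies on, assuming the wide value already
// has its high 24 bits known to be zero and a rotate amount in [0, 7].
#include <cstdint>
constexpr uint8_t wide_rotate(uint32_t v, unsigned amt) {
  // trunc (or (shl v, amt), (lshr v, 8 - amt)) computed in the wide type.
  return static_cast<uint8_t>((v << amt) | (v >> (8 - amt)));
}
constexpr uint8_t narrow_fshl(uint8_t v, unsigned amt) {
  // llvm.fshl.i8(v, v, amt): a true 8-bit rotate, amount taken modulo 8.
  unsigned s = amt % 8;
  return static_cast<uint8_t>((v << s) | (v >> ((8 - s) % 8)));
}
static_assert(wide_rotate(0xB7, 3) == narrow_fshl(0xB7, 3), "rotate by 3");
static_assert(wide_rotate(0x5A, 1) == narrow_fshl(0x5A, 1), "rotate by 1");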
/// Try to narrow the width of math or bitwise logic instructions by pulling a
/// truncate ahead of binary operators.
/// TODO: Transforms for truncated shifts should be moved into here.
-Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) {
+Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) {
Type *SrcTy = Trunc.getSrcTy();
Type *DestTy = Trunc.getType();
if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
@@ -631,7 +663,7 @@ Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) {
default: break;
}
- if (Instruction *NarrowOr = narrowRotate(Trunc))
+ if (Instruction *NarrowOr = narrowFunnelShift(Trunc))
return NarrowOr;
return nullptr;
@@ -687,7 +719,7 @@ static Instruction *shrinkInsertElt(CastInst &Trunc,
return nullptr;
}
-Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) {
+Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
if (Instruction *Result = commonCastTransforms(Trunc))
return Result;
@@ -695,7 +727,6 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) {
Type *DestTy = Trunc.getType(), *SrcTy = Src->getType();
unsigned DestWidth = DestTy->getScalarSizeInBits();
unsigned SrcWidth = SrcTy->getScalarSizeInBits();
- ConstantInt *Cst;
// Attempt to truncate the entire input expression tree to the destination
// type. Only do this if the dest type is a simple type, don't convert the
@@ -782,56 +813,60 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) {
}
}
- // FIXME: Maybe combine the next two transforms to handle the no cast case
- // more efficiently. Support vector types. Cleanup code by using m_OneUse.
-
- // Transform trunc(lshr (zext A), Cst) to eliminate one type conversion.
- Value *A = nullptr;
- if (Src->hasOneUse() &&
- match(Src, m_LShr(m_ZExt(m_Value(A)), m_ConstantInt(Cst)))) {
- // We have three types to worry about here, the type of A, the source of
- // the truncate (MidSize), and the destination of the truncate. We know that
- // ASize < MidSize and MidSize > ResultSize, but don't know the relation
- // between ASize and ResultSize.
- unsigned ASize = A->getType()->getPrimitiveSizeInBits();
-
- // If the shift amount is larger than the size of A, then the result is
- // known to be zero because all the input bits got shifted out.
- if (Cst->getZExtValue() >= ASize)
- return replaceInstUsesWith(Trunc, Constant::getNullValue(DestTy));
-
- // Since we're doing an lshr and a zero extend, and know that the shift
- // amount is smaller than ASize, it is always safe to do the shift in A's
- // type, then zero extend or truncate to the result.
- Value *Shift = Builder.CreateLShr(A, Cst->getZExtValue());
- Shift->takeName(Src);
- return CastInst::CreateIntegerCast(Shift, DestTy, false);
- }
-
- const APInt *C;
- if (match(Src, m_LShr(m_SExt(m_Value(A)), m_APInt(C)))) {
+ Value *A;
+ Constant *C;
+ if (match(Src, m_LShr(m_SExt(m_Value(A)), m_Constant(C)))) {
unsigned AWidth = A->getType()->getScalarSizeInBits();
unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth);
+ auto *OldSh = cast<Instruction>(Src);
+ bool IsExact = OldSh->isExact();
// If the shift is small enough, all zero bits created by the shift are
// removed by the trunc.
- if (C->getZExtValue() <= MaxShiftAmt) {
+ if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
+ APInt(SrcWidth, MaxShiftAmt)))) {
// trunc (lshr (sext A), C) --> ashr A, C
if (A->getType() == DestTy) {
- unsigned ShAmt = std::min((unsigned)C->getZExtValue(), DestWidth - 1);
- return BinaryOperator::CreateAShr(A, ConstantInt::get(DestTy, ShAmt));
+ Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false);
+ Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
+ ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
+ ShAmt = Constant::mergeUndefsWith(ShAmt, C);
+ return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt)
+ : BinaryOperator::CreateAShr(A, ShAmt);
}
// The types are mismatched, so create a cast after shifting:
// trunc (lshr (sext A), C) --> sext/trunc (ashr A, C)
if (Src->hasOneUse()) {
- unsigned ShAmt = std::min((unsigned)C->getZExtValue(), AWidth - 1);
- Value *Shift = Builder.CreateAShr(A, ShAmt);
+ Constant *MaxAmt = ConstantInt::get(SrcTy, AWidth - 1, false);
+ Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
+ ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
+ Value *Shift = Builder.CreateAShr(A, ShAmt, "", IsExact);
return CastInst::CreateIntegerCast(Shift, DestTy, true);
}
}
// TODO: Mask high bits with 'and'.
}
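// Illustrative standalone sketch (not part of the patch above): the scalar
// identity behind "trunc (lshr (sext A), C) --> ashr A, C", shown for
// i8 -> i32 -> i8 with C no larger than SrcWidth - DestWidth and the narrow
// shift amount clamped to DestWidth - 1, matching the umin used above.
#include <cstdint>
constexpr uint8_t trunc_lshr_sext(int8_t a, unsigned c) {
  uint32_t wide = static_cast<uint32_t>(static_cast<int32_t>(a)); // sext to i32
  return static_cast<uint8_t>(wide >> c);                         // lshr + trunc
}
constexpr uint8_t narrow_ashr(int8_t a, unsigned c) {
  return static_cast<uint8_t>(a >> (c < 7 ? c : 7));              // ashr i8, umin(C, 7)
}
static_assert(trunc_lshr_sext(-100, 3) == narrow_ashr(-100, 3), "");
static_assert(trunc_lshr_sext(37, 5) == narrow_ashr(37, 5), "");
static_assert(trunc_lshr_sext(-100, 10) == narrow_ashr(-100, 10), "");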
+ // trunc (*shr (trunc A), C) --> trunc(*shr A, C)
+ if (match(Src, m_OneUse(m_Shr(m_Trunc(m_Value(A)), m_Constant(C))))) {
+ unsigned MaxShiftAmt = SrcWidth - DestWidth;
+
+ // If the shift is small enough, all zero/sign bits created by the shift are
+ // removed by the trunc.
+ if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
+ APInt(SrcWidth, MaxShiftAmt)))) {
+ auto *OldShift = cast<Instruction>(Src);
+ bool IsExact = OldShift->isExact();
+ auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
+ ShAmt = Constant::mergeUndefsWith(ShAmt, C);
+ Value *Shift =
+ OldShift->getOpcode() == Instruction::AShr
+ ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)
+ : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact);
+ return CastInst::CreateTruncOrBitCast(Shift, DestTy);
+ }
+ }
+
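// Illustrative standalone sketch (not part of the patch above): the identity
// behind "trunc (*shr (trunc A), C) --> trunc (*shr A, C)", shown for
// i64 -> i32 -> i16 with a shift amount no larger than SrcWidth - DestWidth.
#include <cstdint>
constexpr uint16_t shr_of_narrowed(uint64_t a, unsigned c) {
  uint32_t mid = static_cast<uint32_t>(a);   // inner trunc to i32
  return static_cast<uint16_t>(mid >> c);    // lshr i32, then trunc to i16
}
constexpr uint16_t shr_of_wide(uint64_t a, unsigned c) {
  return static_cast<uint16_t>(a >> c);      // lshr i64, single trunc to i16
}
static_assert(shr_of_narrowed(0x123456789ABCDEF0ULL, 12) ==
                  shr_of_wide(0x123456789ABCDEF0ULL, 12), "");
static_assert(shr_of_narrowed(0xFFFFFFFF0000FFFFULL, 16) ==
                  shr_of_wide(0xFFFFFFFF0000FFFFULL, 16), "");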
if (Instruction *I = narrowBinOp(Trunc))
return I;
@@ -841,20 +876,19 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) {
if (Instruction *I = shrinkInsertElt(Trunc, Builder))
return I;
- if (Src->hasOneUse() && isa<IntegerType>(SrcTy) &&
- shouldChangeType(SrcTy, DestTy)) {
+ if (Src->hasOneUse() &&
+ (isa<VectorType>(SrcTy) || shouldChangeType(SrcTy, DestTy))) {
// Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the
// dest type is native and cst < dest size.
- if (match(Src, m_Shl(m_Value(A), m_ConstantInt(Cst))) &&
+ if (match(Src, m_Shl(m_Value(A), m_Constant(C))) &&
!match(A, m_Shr(m_Value(), m_Constant()))) {
// Skip shifts of shift by constants. It undoes a combine in
// FoldShiftByConstant and is the extend in reg pattern.
- if (Cst->getValue().ult(DestWidth)) {
+ APInt Threshold = APInt(C->getType()->getScalarSizeInBits(), DestWidth);
+ if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold))) {
Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr");
-
- return BinaryOperator::Create(
- Instruction::Shl, NewTrunc,
- ConstantInt::get(DestTy, Cst->getValue().trunc(DestWidth)));
+ return BinaryOperator::Create(Instruction::Shl, NewTrunc,
+ ConstantExpr::getTrunc(C, DestTy));
}
}
}
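// Illustrative standalone sketch (not part of the patch above): for
// "trunc (shl X, C) --> shl (trunc X), C", the low DestWidth bits of the
// result only depend on the low DestWidth bits of X once C < DestWidth;
// shown for i32 -> i8 with C < 8.
#include <cstdint>
constexpr uint8_t shl_then_trunc(uint32_t x, unsigned c) {
  return static_cast<uint8_t>(x << c);
}
constexpr uint8_t trunc_then_shl(uint32_t x, unsigned c) {
  return static_cast<uint8_t>(static_cast<uint8_t>(x) << c);
}
static_assert(shl_then_trunc(0xDEADBEEFu, 5) == trunc_then_shl(0xDEADBEEFu, 5), "");
static_assert(shl_then_trunc(0x0000017Fu, 7) == trunc_then_shl(0x0000017Fu, 7), "");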
@@ -871,21 +905,23 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) {
// --->
// extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
Value *VecOp;
+ ConstantInt *Cst;
if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) {
auto *VecOpTy = cast<VectorType>(VecOp->getType());
- unsigned VecNumElts = VecOpTy->getNumElements();
+ auto VecElts = VecOpTy->getElementCount();
// A badly fit destination size would result in an invalid cast.
if (SrcWidth % DestWidth == 0) {
uint64_t TruncRatio = SrcWidth / DestWidth;
- uint64_t BitCastNumElts = VecNumElts * TruncRatio;
+ uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
uint64_t VecOpIdx = Cst->getZExtValue();
uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1
: VecOpIdx * TruncRatio;
assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
"overflow 32-bits");
- auto *BitCastTo = FixedVectorType::get(DestTy, BitCastNumElts);
+ auto *BitCastTo =
+ VectorType::get(DestTy, BitCastNumElts, VecElts.isScalable());
Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo);
return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx));
}
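// Illustrative standalone sketch (not part of the patch above): the index
// arithmetic used for "trunc (extractelement)" -> "extractelement (bitcast)",
// reinterpreting a <4 x i64> as an <8 x i32> on a little-endian host.
#include <cassert>
#include <cstdint>
#include <cstring>
int main() {
  uint64_t Wide[4] = {0x1111111122222222ULL, 0x3333333344444444ULL,
                      0x5555555566666666ULL, 0x7777777788888888ULL};
  uint32_t Narrow[8];
  std::memcpy(Narrow, Wide, sizeof Wide);          // the vector bitcast
  const uint64_t TruncRatio = 64 / 32;
  for (uint64_t VecOpIdx = 0; VecOpIdx != 4; ++VecOpIdx) {
    uint64_t NewIdx = VecOpIdx * TruncRatio;       // little-endian formula
    assert(static_cast<uint32_t>(Wide[VecOpIdx]) == Narrow[NewIdx]);
  }
  return 0;
}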
@@ -894,8 +930,8 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) {
return nullptr;
}
-Instruction *InstCombiner::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
- bool DoTransform) {
+Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
+ bool DoTransform) {
// If we are just checking for a icmp eq of a single bit and zext'ing it
// to an integer, then shift the bit to the appropriate place and then
// cast to integer to avoid the comparison.
@@ -1031,7 +1067,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
///
/// This function works on both vectors and scalars.
static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
- InstCombiner &IC, Instruction *CxtI) {
+ InstCombinerImpl &IC, Instruction *CxtI) {
BitsToClear = 0;
if (canAlwaysEvaluateInType(V, Ty))
return true;
@@ -1136,7 +1172,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
}
}
-Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
+Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) {
// If this zero extend is only used by a truncate, let the truncate be
// eliminated before we try to optimize this zext.
if (CI.hasOneUse() && isa<TruncInst>(CI.user_back()))
@@ -1274,7 +1310,8 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
}
/// Transform (sext icmp) to bitwise / integer operations to eliminate the icmp.
-Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) {
+Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *ICI,
+ Instruction &CI) {
Value *Op0 = ICI->getOperand(0), *Op1 = ICI->getOperand(1);
ICmpInst::Predicate Pred = ICI->getPredicate();
@@ -1410,7 +1447,7 @@ static bool canEvaluateSExtd(Value *V, Type *Ty) {
return false;
}
-Instruction *InstCombiner::visitSExt(SExtInst &CI) {
+Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) {
// If this sign extend is only used by a truncate, let the truncate be
// eliminated before we try to optimize this sext.
if (CI.hasOneUse() && isa<TruncInst>(CI.user_back()))
@@ -1473,31 +1510,33 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
// for a truncate. If the source and dest are the same type, eliminate the
// trunc and extend and just do shifts. For example, turn:
// %a = trunc i32 %i to i8
- // %b = shl i8 %a, 6
- // %c = ashr i8 %b, 6
+ // %b = shl i8 %a, C
+ // %c = ashr i8 %b, C
// %d = sext i8 %c to i32
// into:
- // %a = shl i32 %i, 30
- // %d = ashr i32 %a, 30
+ // %a = shl i32 %i, 32-(8-C)
+ // %d = ashr i32 %a, 32-(8-C)
Value *A = nullptr;
// TODO: Eventually this could be subsumed by EvaluateInDifferentType.
Constant *BA = nullptr, *CA = nullptr;
if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_Constant(BA)),
m_Constant(CA))) &&
- BA == CA && A->getType() == CI.getType()) {
- unsigned MidSize = Src->getType()->getScalarSizeInBits();
- unsigned SrcDstSize = CI.getType()->getScalarSizeInBits();
- Constant *SizeDiff = ConstantInt::get(CA->getType(), SrcDstSize - MidSize);
- Constant *ShAmt = ConstantExpr::getAdd(CA, SizeDiff);
- Constant *ShAmtExt = ConstantExpr::getSExt(ShAmt, CI.getType());
- A = Builder.CreateShl(A, ShAmtExt, CI.getName());
- return BinaryOperator::CreateAShr(A, ShAmtExt);
+ BA->isElementWiseEqual(CA) && A->getType() == DestTy) {
+ Constant *WideCurrShAmt = ConstantExpr::getSExt(CA, DestTy);
+ Constant *NumLowbitsLeft = ConstantExpr::getSub(
+ ConstantInt::get(DestTy, SrcTy->getScalarSizeInBits()), WideCurrShAmt);
+ Constant *NewShAmt = ConstantExpr::getSub(
+ ConstantInt::get(DestTy, DestTy->getScalarSizeInBits()),
+ NumLowbitsLeft);
+ NewShAmt =
+ Constant::mergeUndefsWith(Constant::mergeUndefsWith(NewShAmt, BA), CA);
+ A = Builder.CreateShl(A, NewShAmt, CI.getName());
+ return BinaryOperator::CreateAShr(A, NewShAmt);
}
return nullptr;
}
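// Illustrative standalone sketch (not part of the patch above): the narrow
// shl/ashr pair keeps the low 8 - C bits of %i and sign-extends them, which
// the wide form reproduces with a shift amount of 32 - (8 - C). Assumes
// C < 8 and two's-complement narrowing.
#include <cstdint>
constexpr int32_t narrow_form(int32_t i, unsigned c) {
  int8_t a = static_cast<int8_t>(i);                      // trunc i32 -> i8
  uint32_t shifted = static_cast<uint8_t>(a) << c;        // shl i8, C (done unsigned)
  int8_t b = static_cast<int8_t>(static_cast<uint8_t>(shifted));
  return static_cast<int8_t>(b >> c);                     // ashr i8, C then sext
}
constexpr int32_t wide_form(int32_t i, unsigned c) {
  unsigned amt = 32 - (8 - c);
  return static_cast<int32_t>(static_cast<uint32_t>(i) << amt) >> amt; // shl/ashr i32
}
static_assert(narrow_form(0x0123456F, 3) == wide_form(0x0123456F, 3), "");
static_assert(narrow_form(-77, 2) == wide_form(-77, 2), "");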
-
/// Return a Constant* for the specified floating-point constant if it fits
/// in the specified FP type without changing its value.
static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
@@ -1535,7 +1574,7 @@ static Type *shrinkFPConstantVector(Value *V) {
Type *MinType = nullptr;
- unsigned NumElts = CVVTy->getNumElements();
+ unsigned NumElts = cast<FixedVectorType>(CVVTy)->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i));
if (!CFP)
@@ -1616,7 +1655,7 @@ static bool isKnownExactCastIntToFP(CastInst &I) {
return false;
}
-Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
+Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
if (Instruction *I = commonCastTransforms(FPT))
return I;
@@ -1800,7 +1839,7 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
return nullptr;
}
-Instruction *InstCombiner::visitFPExt(CastInst &FPExt) {
+Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) {
// If the source operand is a cast from integer to FP and known exact, then
// cast the integer operand directly to the destination type.
Type *Ty = FPExt.getType();
@@ -1818,7 +1857,7 @@ Instruction *InstCombiner::visitFPExt(CastInst &FPExt) {
/// This is safe if the intermediate type has enough bits in its mantissa to
/// accurately represent all values of X. For example, this won't work with
/// i64 -> float -> i64.
-Instruction *InstCombiner::foldItoFPtoI(CastInst &FI) {
+Instruction *InstCombinerImpl::foldItoFPtoI(CastInst &FI) {
if (!isa<UIToFPInst>(FI.getOperand(0)) && !isa<SIToFPInst>(FI.getOperand(0)))
return nullptr;
@@ -1858,29 +1897,29 @@ Instruction *InstCombiner::foldItoFPtoI(CastInst &FI) {
return replaceInstUsesWith(FI, X);
}
-Instruction *InstCombiner::visitFPToUI(FPToUIInst &FI) {
+Instruction *InstCombinerImpl::visitFPToUI(FPToUIInst &FI) {
if (Instruction *I = foldItoFPtoI(FI))
return I;
return commonCastTransforms(FI);
}
-Instruction *InstCombiner::visitFPToSI(FPToSIInst &FI) {
+Instruction *InstCombinerImpl::visitFPToSI(FPToSIInst &FI) {
if (Instruction *I = foldItoFPtoI(FI))
return I;
return commonCastTransforms(FI);
}
-Instruction *InstCombiner::visitUIToFP(CastInst &CI) {
+Instruction *InstCombinerImpl::visitUIToFP(CastInst &CI) {
return commonCastTransforms(CI);
}
-Instruction *InstCombiner::visitSIToFP(CastInst &CI) {
+Instruction *InstCombinerImpl::visitSIToFP(CastInst &CI) {
return commonCastTransforms(CI);
}
-Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) {
+Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) {
// If the source integer type is not the intptr_t type for this target, do a
// trunc or zext to the intptr_t type, then inttoptr of it. This allows the
// cast to be exposed to other transforms.
@@ -1903,7 +1942,7 @@ Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) {
}
/// Implement the transforms for cast of pointer (bitcast/ptrtoint)
-Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) {
+Instruction *InstCombinerImpl::commonPointerCastTransforms(CastInst &CI) {
Value *Src = CI.getOperand(0);
if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Src)) {
@@ -1925,26 +1964,37 @@ Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) {
return commonCastTransforms(CI);
}
-Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) {
+Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
// If the destination integer type is not the intptr_t type for this target,
// do a ptrtoint to intptr_t then do a trunc or zext. This allows the cast
// to be exposed to other transforms.
-
+ Value *SrcOp = CI.getPointerOperand();
Type *Ty = CI.getType();
unsigned AS = CI.getPointerAddressSpace();
+ unsigned TySize = Ty->getScalarSizeInBits();
+ unsigned PtrSize = DL.getPointerSizeInBits(AS);
+ if (TySize != PtrSize) {
+ Type *IntPtrTy = DL.getIntPtrType(CI.getContext(), AS);
+ // Handle vectors of pointers.
+ if (auto *VecTy = dyn_cast<VectorType>(Ty))
+ IntPtrTy = VectorType::get(IntPtrTy, VecTy->getElementCount());
- if (Ty->getScalarSizeInBits() == DL.getPointerSizeInBits(AS))
- return commonPointerCastTransforms(CI);
+ Value *P = Builder.CreatePtrToInt(SrcOp, IntPtrTy);
+ return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
+ }
- Type *PtrTy = DL.getIntPtrType(CI.getContext(), AS);
- if (auto *VTy = dyn_cast<VectorType>(Ty)) {
- // Handle vectors of pointers.
- // FIXME: what should happen for scalable vectors?
- PtrTy = FixedVectorType::get(PtrTy, VTy->getNumElements());
+ Value *Vec, *Scalar, *Index;
+ if (match(SrcOp, m_OneUse(m_InsertElt(m_IntToPtr(m_Value(Vec)),
+ m_Value(Scalar), m_Value(Index)))) &&
+ Vec->getType() == Ty) {
+ assert(Vec->getType()->getScalarSizeInBits() == PtrSize && "Wrong type");
+ // Convert the scalar to int followed by insert to eliminate one cast:
+  // p2i (ins (i2p Vec), Scalar, Index) --> ins Vec, (p2i Scalar), Index
+ Value *NewCast = Builder.CreatePtrToInt(Scalar, Ty->getScalarType());
+ return InsertElementInst::Create(Vec, NewCast, Index);
}
- Value *P = Builder.CreatePtrToInt(CI.getOperand(0), PtrTy);
- return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
+ return commonPointerCastTransforms(CI);
}
/// This input value (which is known to have vector type) is being zero extended
@@ -1963,9 +2013,9 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) {
/// Try to replace it with a shuffle (and vector/vector bitcast) if possible.
///
/// The source and destination vector types may have different element types.
-static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal,
- VectorType *DestTy,
- InstCombiner &IC) {
+static Instruction *
+optimizeVectorResizeWithIntegerBitCasts(Value *InVal, VectorType *DestTy,
+ InstCombinerImpl &IC) {
// We can only do this optimization if the output is a multiple of the input
// element size, or the input is a multiple of the output element size.
// Convert the input type to have the same element type as the output.
@@ -1981,13 +2031,14 @@ static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal,
return nullptr;
SrcTy =
- FixedVectorType::get(DestTy->getElementType(), SrcTy->getNumElements());
+ FixedVectorType::get(DestTy->getElementType(),
+ cast<FixedVectorType>(SrcTy)->getNumElements());
InVal = IC.Builder.CreateBitCast(InVal, SrcTy);
}
bool IsBigEndian = IC.getDataLayout().isBigEndian();
- unsigned SrcElts = SrcTy->getNumElements();
- unsigned DestElts = DestTy->getNumElements();
+ unsigned SrcElts = cast<FixedVectorType>(SrcTy)->getNumElements();
+ unsigned DestElts = cast<FixedVectorType>(DestTy)->getNumElements();
assert(SrcElts != DestElts && "Element counts should be different.");
@@ -2165,8 +2216,8 @@ static bool collectInsertionElements(Value *V, unsigned Shift,
///
/// Into two insertelements that do "buildvector{%inc, %inc5}".
static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI,
- InstCombiner &IC) {
- VectorType *DestVecTy = cast<VectorType>(CI.getType());
+ InstCombinerImpl &IC) {
+ auto *DestVecTy = cast<FixedVectorType>(CI.getType());
Value *IntInput = CI.getOperand(0);
SmallVector<Value*, 8> Elements(DestVecTy->getNumElements());
@@ -2194,7 +2245,7 @@ static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI,
/// vectors better than bitcasts of scalars because vector registers are
/// usually not type-specific like scalar integer or scalar floating-point.
static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
// TODO: Create and use a pattern matcher for ExtractElementInst.
auto *ExtElt = dyn_cast<ExtractElementInst>(BitCast.getOperand(0));
if (!ExtElt || !ExtElt->hasOneUse())
@@ -2206,8 +2257,7 @@ static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast,
if (!VectorType::isValidElementType(DestType))
return nullptr;
- unsigned NumElts = ExtElt->getVectorOperandType()->getNumElements();
- auto *NewVecType = FixedVectorType::get(DestType, NumElts);
+ auto *NewVecType = VectorType::get(DestType, ExtElt->getVectorOperandType());
auto *NewBC = IC.Builder.CreateBitCast(ExtElt->getVectorOperand(),
NewVecType, "bc");
return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand());
@@ -2270,12 +2320,11 @@ static Instruction *foldBitCastSelect(BitCastInst &BitCast,
// A vector select must maintain the same number of elements in its operands.
Type *CondTy = Cond->getType();
Type *DestTy = BitCast.getType();
- if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) {
- if (!DestTy->isVectorTy())
+ if (auto *CondVTy = dyn_cast<VectorType>(CondTy))
+ if (!DestTy->isVectorTy() ||
+ CondVTy->getElementCount() !=
+ cast<VectorType>(DestTy)->getElementCount())
return nullptr;
- if (cast<VectorType>(DestTy)->getNumElements() != CondVTy->getNumElements())
- return nullptr;
- }
// FIXME: This transform is restricted from changing the select between
// scalars and vectors to avoid backend problems caused by creating
@@ -2320,7 +2369,8 @@ static bool hasStoreUsersOnly(CastInst &CI) {
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
-Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
+Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI,
+ PHINode *PN) {
// BitCast used by Store can be handled in InstCombineLoadStoreAlloca.cpp.
if (hasStoreUsersOnly(CI))
return nullptr;
@@ -2450,10 +2500,7 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
Instruction *RetVal = nullptr;
for (auto *OldPN : OldPhiNodes) {
PHINode *NewPN = NewPNodes[OldPN];
- for (auto It = OldPN->user_begin(), End = OldPN->user_end(); It != End; ) {
- User *V = *It;
- // We may remove this user, advance to avoid iterator invalidation.
- ++It;
+ for (User *V : make_early_inc_range(OldPN->users())) {
if (auto *SI = dyn_cast<StoreInst>(V)) {
assert(SI->isSimple() && SI->getOperand(0) == OldPN);
Builder.SetInsertPoint(SI);
@@ -2473,7 +2520,7 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
if (BCI == &CI)
RetVal = I;
} else if (auto *PHI = dyn_cast<PHINode>(V)) {
- assert(OldPhiNodes.count(PHI) > 0);
+ assert(OldPhiNodes.contains(PHI));
(void) PHI;
} else {
llvm_unreachable("all uses should be handled");
@@ -2484,7 +2531,7 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
return RetVal;
}
-Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
+Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
// If the operands are integer typed then apply the integer transforms,
// otherwise just apply the common ones.
Value *Src = CI.getOperand(0);
@@ -2608,12 +2655,11 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
// a bitcast to a vector with the same # elts.
Value *ShufOp0 = Shuf->getOperand(0);
Value *ShufOp1 = Shuf->getOperand(1);
- unsigned NumShufElts = Shuf->getType()->getNumElements();
- unsigned NumSrcVecElts =
- cast<VectorType>(ShufOp0->getType())->getNumElements();
+ auto ShufElts = cast<VectorType>(Shuf->getType())->getElementCount();
+ auto SrcVecElts = cast<VectorType>(ShufOp0->getType())->getElementCount();
if (Shuf->hasOneUse() && DestTy->isVectorTy() &&
- cast<VectorType>(DestTy)->getNumElements() == NumShufElts &&
- NumShufElts == NumSrcVecElts) {
+ cast<VectorType>(DestTy)->getElementCount() == ShufElts &&
+ ShufElts == SrcVecElts) {
BitCastInst *Tmp;
// If either of the operands is a cast from CI.getType(), then
// evaluating the shuffle in the casted destination's type will allow
@@ -2636,8 +2682,9 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
// TODO: We should match the related pattern for bitreverse.
if (DestTy->isIntegerTy() &&
DL.isLegalInteger(DestTy->getScalarSizeInBits()) &&
- SrcTy->getScalarSizeInBits() == 8 && NumShufElts % 2 == 0 &&
- Shuf->hasOneUse() && Shuf->isReverse()) {
+ SrcTy->getScalarSizeInBits() == 8 &&
+ ShufElts.getKnownMinValue() % 2 == 0 && Shuf->hasOneUse() &&
+ Shuf->isReverse()) {
assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask");
assert(isa<UndefValue>(ShufOp1) && "Unexpected shuffle op");
Function *Bswap =
@@ -2666,7 +2713,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
return commonCastTransforms(CI);
}
-Instruction *InstCombiner::visitAddrSpaceCast(AddrSpaceCastInst &CI) {
+Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) {
// If the destination pointer element type is not the same as the source's
// first do a bitcast to the destination type, and then the addrspacecast.
// This allows the cast to be exposed to other transforms.
@@ -2677,11 +2724,9 @@ Instruction *InstCombiner::visitAddrSpaceCast(AddrSpaceCastInst &CI) {
Type *DestElemTy = DestTy->getElementType();
if (SrcTy->getElementType() != DestElemTy) {
Type *MidTy = PointerType::get(DestElemTy, SrcTy->getAddressSpace());
- if (VectorType *VT = dyn_cast<VectorType>(CI.getType())) {
- // Handle vectors of pointers.
- // FIXME: what should happen for scalable vectors?
- MidTy = FixedVectorType::get(MidTy, VT->getNumElements());
- }
+ // Handle vectors of pointers.
+ if (VectorType *VT = dyn_cast<VectorType>(CI.getType()))
+ MidTy = VectorType::get(MidTy, VT->getElementCount());
Value *NewBitCast = Builder.CreateBitCast(Src, MidTy);
return new AddrSpaceCastInst(NewBitCast, CI.getType());
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index f1233b62445d..cd9a036179b6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -24,6 +24,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
using namespace llvm;
using namespace PatternMatch;
@@ -95,45 +96,6 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) {
return false;
}
-/// Given a signed integer type and a set of known zero and one bits, compute
-/// the maximum and minimum values that could have the specified known zero and
-/// known one bits, returning them in Min/Max.
-/// TODO: Move to method on KnownBits struct?
-static void computeSignedMinMaxValuesFromKnownBits(const KnownBits &Known,
- APInt &Min, APInt &Max) {
- assert(Known.getBitWidth() == Min.getBitWidth() &&
- Known.getBitWidth() == Max.getBitWidth() &&
- "KnownZero, KnownOne and Min, Max must have equal bitwidth.");
- APInt UnknownBits = ~(Known.Zero|Known.One);
-
- // The minimum value is when all unknown bits are zeros, EXCEPT for the sign
- // bit if it is unknown.
- Min = Known.One;
- Max = Known.One|UnknownBits;
-
- if (UnknownBits.isNegative()) { // Sign bit is unknown
- Min.setSignBit();
- Max.clearSignBit();
- }
-}
-
-/// Given an unsigned integer type and a set of known zero and one bits, compute
-/// the maximum and minimum values that could have the specified known zero and
-/// known one bits, returning them in Min/Max.
-/// TODO: Move to method on KnownBits struct?
-static void computeUnsignedMinMaxValuesFromKnownBits(const KnownBits &Known,
- APInt &Min, APInt &Max) {
- assert(Known.getBitWidth() == Min.getBitWidth() &&
- Known.getBitWidth() == Max.getBitWidth() &&
- "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth.");
- APInt UnknownBits = ~(Known.Zero|Known.One);
-
- // The minimum value is when the unknown bits are all zeros.
- Min = Known.One;
- // The maximum value is when the unknown bits are all ones.
- Max = Known.One|UnknownBits;
-}
-
/// This is called when we see this pattern:
/// cmp pred (load (gep GV, ...)), cmpcst
/// where GV is a global variable with a constant initializer. Try to simplify
@@ -142,10 +104,10 @@ static void computeUnsignedMinMaxValuesFromKnownBits(const KnownBits &Known,
///
/// If AndCst is non-null, then the loaded value is masked with that constant
/// before doing the comparison. This handles cases like "A[i]&4 == 0".
-Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
- GlobalVariable *GV,
- CmpInst &ICI,
- ConstantInt *AndCst) {
+Instruction *
+InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
+ GlobalVariable *GV, CmpInst &ICI,
+ ConstantInt *AndCst) {
Constant *Init = GV->getInitializer();
if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init))
return nullptr;
@@ -313,7 +275,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
if (!GEP->isInBounds()) {
Type *IntPtrTy = DL.getIntPtrType(GEP->getType());
unsigned PtrSize = IntPtrTy->getIntegerBitWidth();
- if (Idx->getType()->getPrimitiveSizeInBits() > PtrSize)
+ if (Idx->getType()->getPrimitiveSizeInBits().getFixedSize() > PtrSize)
Idx = Builder.CreateTrunc(Idx, IntPtrTy);
}
@@ -422,7 +384,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
///
/// If we can't emit an optimized form for this expression, this returns null.
///
-static Value *evaluateGEPOffsetExpression(User *GEP, InstCombiner &IC,
+static Value *evaluateGEPOffsetExpression(User *GEP, InstCombinerImpl &IC,
const DataLayout &DL) {
gep_type_iterator GTI = gep_type_begin(GEP);
@@ -486,7 +448,8 @@ static Value *evaluateGEPOffsetExpression(User *GEP, InstCombiner &IC,
// Cast to intptrty in case a truncation occurs. If an extension is needed,
// we don't need to bother extending: the extension won't affect where the
// computation crosses zero.
- if (VariableIdx->getType()->getPrimitiveSizeInBits() > IntPtrWidth) {
+ if (VariableIdx->getType()->getPrimitiveSizeInBits().getFixedSize() >
+ IntPtrWidth) {
VariableIdx = IC.Builder.CreateTrunc(VariableIdx, IntPtrTy);
}
return VariableIdx;
@@ -539,7 +502,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base,
Value *V = WorkList.back();
- if (Explored.count(V) != 0) {
+ if (Explored.contains(V)) {
WorkList.pop_back();
continue;
}
@@ -551,7 +514,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base,
return false;
if (isa<IntToPtrInst>(V) || isa<PtrToIntInst>(V)) {
- auto *CI = dyn_cast<CastInst>(V);
+ auto *CI = cast<CastInst>(V);
if (!CI->isNoopCast(DL))
return false;
@@ -841,9 +804,9 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
/// Fold comparisons between a GEP instruction and something else. At this point
/// we know that the GEP is on the LHS of the comparison.
-Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
- ICmpInst::Predicate Cond,
- Instruction &I) {
+Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
+ ICmpInst::Predicate Cond,
+ Instruction &I) {
// Don't transform signed compares of GEPs into index compares. Even if the
// GEP is inbounds, the final add of the base pointer can have signed overflow
// and would change the result of the icmp.
@@ -897,8 +860,8 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
// For vectors, we apply the same reasoning on a per-lane basis.
auto *Base = GEPLHS->getPointerOperand();
if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) {
- int NumElts = cast<VectorType>(GEPLHS->getType())->getNumElements();
- Base = Builder.CreateVectorSplat(NumElts, Base);
+ auto EC = cast<VectorType>(GEPLHS->getType())->getElementCount();
+ Base = Builder.CreateVectorSplat(EC, Base);
}
return new ICmpInst(Cond, Base,
ConstantExpr::getPointerBitCastOrAddrSpaceCast(
@@ -941,8 +904,8 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
Type *LHSIndexTy = LOffset->getType();
Type *RHSIndexTy = ROffset->getType();
if (LHSIndexTy != RHSIndexTy) {
- if (LHSIndexTy->getPrimitiveSizeInBits() <
- RHSIndexTy->getPrimitiveSizeInBits()) {
+ if (LHSIndexTy->getPrimitiveSizeInBits().getFixedSize() <
+ RHSIndexTy->getPrimitiveSizeInBits().getFixedSize()) {
ROffset = Builder.CreateTrunc(ROffset, LHSIndexTy);
} else
LOffset = Builder.CreateTrunc(LOffset, RHSIndexTy);
@@ -1021,9 +984,9 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
return transformToIndexedCompare(GEPLHS, RHS, Cond, DL);
}
-Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI,
- const AllocaInst *Alloca,
- const Value *Other) {
+Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI,
+ const AllocaInst *Alloca,
+ const Value *Other) {
assert(ICI.isEquality() && "Cannot fold non-equality comparison.");
// It would be tempting to fold away comparisons between allocas and any
@@ -1099,8 +1062,8 @@ Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI,
}
/// Fold "icmp pred (X+C), X".
-Instruction *InstCombiner::foldICmpAddOpConst(Value *X, const APInt &C,
- ICmpInst::Predicate Pred) {
+Instruction *InstCombinerImpl::foldICmpAddOpConst(Value *X, const APInt &C,
+ ICmpInst::Predicate Pred) {
// From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0,
// so the values can never be equal. Similarly for all other "or equals"
// operators.
@@ -1149,9 +1112,9 @@ Instruction *InstCombiner::foldICmpAddOpConst(Value *X, const APInt &C,
/// Handle "(icmp eq/ne (ashr/lshr AP2, A), AP1)" ->
/// (icmp eq/ne A, Log2(AP2/AP1)) ->
/// (icmp eq/ne A, Log2(AP2) - Log2(AP1)).
-Instruction *InstCombiner::foldICmpShrConstConst(ICmpInst &I, Value *A,
- const APInt &AP1,
- const APInt &AP2) {
+Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A,
+ const APInt &AP1,
+ const APInt &AP2) {
assert(I.isEquality() && "Cannot fold icmp gt/lt");
auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
@@ -1208,9 +1171,9 @@ Instruction *InstCombiner::foldICmpShrConstConst(ICmpInst &I, Value *A,
/// Handle "(icmp eq/ne (shl AP2, A), AP1)" ->
/// (icmp eq/ne A, TrailingZeros(AP1) - TrailingZeros(AP2)).
-Instruction *InstCombiner::foldICmpShlConstConst(ICmpInst &I, Value *A,
- const APInt &AP1,
- const APInt &AP2) {
+Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A,
+ const APInt &AP1,
+ const APInt &AP2) {
assert(I.isEquality() && "Cannot fold icmp gt/lt");
auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
@@ -1254,7 +1217,7 @@ Instruction *InstCombiner::foldICmpShlConstConst(ICmpInst &I, Value *A,
///
static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
ConstantInt *CI2, ConstantInt *CI1,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
// The transformation we're trying to do here is to transform this into an
// llvm.sadd.with.overflow. To do this, we have to replace the original add
// with a narrower add, and discard the add-with-constant that is part of the
@@ -1340,7 +1303,7 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
/// icmp eq/ne (urem/srem %x, %y), 0
/// iff %y is a power-of-two, we can replace this with a bit test:
/// icmp eq/ne (and %x, (add %y, -1)), 0
-Instruction *InstCombiner::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
+Instruction *InstCombinerImpl::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
// This fold is only valid for equality predicates.
if (!I.isEquality())
return nullptr;
@@ -1359,7 +1322,7 @@ Instruction *InstCombiner::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
/// Fold equality-comparison between zero and any (maybe truncated) right-shift
/// by one-less-than-bitwidth into a sign test on the original value.
-Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) {
+Instruction *InstCombinerImpl::foldSignBitTest(ICmpInst &I) {
Instruction *Val;
ICmpInst::Predicate Pred;
if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero())))
@@ -1390,7 +1353,7 @@ Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) {
}
// Handle icmp pred X, 0
-Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
+Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) {
CmpInst::Predicate Pred = Cmp.getPredicate();
if (!match(Cmp.getOperand(1), m_Zero()))
return nullptr;
@@ -1431,7 +1394,7 @@ Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
/// should be moved to some other helper and extended as noted below (it is also
/// possible that code has been made unnecessary - do we canonicalize IR to
/// overflow/saturating intrinsics or not?).
-Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) {
+Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
// Match the following pattern, which is a common idiom when writing
// overflow-safe integer arithmetic functions. The source performs an addition
// in wider type and explicitly checks for overflow using comparisons against
@@ -1477,7 +1440,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) {
}
/// Canonicalize icmp instructions based on dominating conditions.
-Instruction *InstCombiner::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
+Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
// This is a cheap/incomplete check for dominance - just match a single
// predecessor with a conditional branch.
BasicBlock *CmpBB = Cmp.getParent();
@@ -1547,9 +1510,9 @@ Instruction *InstCombiner::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
}
/// Fold icmp (trunc X, Y), C.
-Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp,
- TruncInst *Trunc,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
+ TruncInst *Trunc,
+ const APInt &C) {
ICmpInst::Predicate Pred = Cmp.getPredicate();
Value *X = Trunc->getOperand(0);
if (C.isOneValue() && C.getBitWidth() > 1) {
@@ -1580,9 +1543,9 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp,
}
/// Fold icmp (xor X, Y), C.
-Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
- BinaryOperator *Xor,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp,
+ BinaryOperator *Xor,
+ const APInt &C) {
Value *X = Xor->getOperand(0);
Value *Y = Xor->getOperand(1);
const APInt *XorC;
@@ -1612,15 +1575,13 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
if (Xor->hasOneUse()) {
// (icmp u/s (xor X SignMask), C) -> (icmp s/u X, (xor C SignMask))
if (!Cmp.isEquality() && XorC->isSignMask()) {
- Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate()
- : Cmp.getSignedPredicate();
+ Pred = Cmp.getFlippedSignednessPredicate();
return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC));
}
// (icmp u/s (xor X ~SignMask), C) -> (icmp s/u X, (xor C ~SignMask))
if (!Cmp.isEquality() && XorC->isMaxSignedValue()) {
- Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate()
- : Cmp.getSignedPredicate();
+ Pred = Cmp.getFlippedSignednessPredicate();
Pred = Cmp.getSwappedPredicate(Pred);
return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC));
}
@@ -1649,8 +1610,10 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
}
/// Fold icmp (and (sh X, Y), C2), C1.
-Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
- const APInt &C1, const APInt &C2) {
+Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp,
+ BinaryOperator *And,
+ const APInt &C1,
+ const APInt &C2) {
BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0));
if (!Shift || !Shift->isShift())
return nullptr;
@@ -1733,9 +1696,9 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
}
/// Fold icmp (and X, C2), C1.
-Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
- BinaryOperator *And,
- const APInt &C1) {
+Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
+ BinaryOperator *And,
+ const APInt &C1) {
bool isICMP_NE = Cmp.getPredicate() == ICmpInst::ICMP_NE;
// For vectors: icmp ne (and X, 1), 0 --> trunc X to N x i1
@@ -1841,9 +1804,9 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
}
/// Fold icmp (and X, Y), C.
-Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp,
- BinaryOperator *And,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
+ BinaryOperator *And,
+ const APInt &C) {
if (Instruction *I = foldICmpAndConstConst(Cmp, And, C))
return I;
@@ -1883,7 +1846,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp,
if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) {
Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1);
if (auto *AndVTy = dyn_cast<VectorType>(And->getType()))
- NTy = FixedVectorType::get(NTy, AndVTy->getNumElements());
+ NTy = VectorType::get(NTy, AndVTy->getElementCount());
Value *Trunc = Builder.CreateTrunc(X, NTy);
auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE
: CmpInst::ICMP_SLT;
@@ -1895,8 +1858,9 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp,
}
/// Fold icmp (or X, Y), C.
-Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
+ BinaryOperator *Or,
+ const APInt &C) {
ICmpInst::Predicate Pred = Cmp.getPredicate();
if (C.isOneValue()) {
// icmp slt signum(V) 1 --> icmp slt V, 1
@@ -1960,9 +1924,9 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
}
/// Fold icmp (mul X, Y), C.
-Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp,
- BinaryOperator *Mul,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
+ BinaryOperator *Mul,
+ const APInt &C) {
const APInt *MulC;
if (!match(Mul->getOperand(1), m_APInt(MulC)))
return nullptr;
@@ -1977,6 +1941,21 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp,
Constant::getNullValue(Mul->getType()));
}
+ // If the multiply does not wrap, try to divide the compare constant by the
+ // multiplication factor.
+ if (Cmp.isEquality() && !MulC->isNullValue()) {
+ // (mul nsw X, MulC) == C --> X == C /s MulC
+ if (Mul->hasNoSignedWrap() && C.srem(*MulC).isNullValue()) {
+ Constant *NewC = ConstantInt::get(Mul->getType(), C.sdiv(*MulC));
+ return new ICmpInst(Pred, Mul->getOperand(0), NewC);
+ }
+ // (mul nuw X, MulC) == C --> X == C /u MulC
+ if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isNullValue()) {
+ Constant *NewC = ConstantInt::get(Mul->getType(), C.udiv(*MulC));
+ return new ICmpInst(Pred, Mul->getOperand(0), NewC);
+ }
+ }
+
return nullptr;
}
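// Illustrative standalone sketch (not part of the patch above): with no
// signed wrap and MulC dividing C exactly, "X * MulC == C" agrees with
// "X == C /s MulC"; here MulC = 6 and C = -42.
#include <cstdint>
constexpr bool mul_eq_c(int32_t x) { return x * 6 == -42; } // no nsw overflow for these x
constexpr bool div_eq_c(int32_t x) { return x == -42 / 6; }
static_assert(mul_eq_c(-7) == div_eq_c(-7), "");
static_assert(mul_eq_c(5) == div_eq_c(5), "");
static_assert(mul_eq_c(-6) == div_eq_c(-6), "");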
@@ -2043,9 +2022,9 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
}
/// Fold icmp (shl X, Y), C.
-Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
- BinaryOperator *Shl,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
+ BinaryOperator *Shl,
+ const APInt &C) {
const APInt *ShiftVal;
if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal)))
return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal);
@@ -2173,7 +2152,7 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
DL.isLegalInteger(TypeBits - Amt)) {
Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
if (auto *ShVTy = dyn_cast<VectorType>(ShType))
- TruncTy = FixedVectorType::get(TruncTy, ShVTy->getNumElements());
+ TruncTy = VectorType::get(TruncTy, ShVTy->getElementCount());
Constant *NewC =
ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC);
@@ -2183,9 +2162,9 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
}
/// Fold icmp ({al}shr X, Y), C.
-Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp,
- BinaryOperator *Shr,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
+ BinaryOperator *Shr,
+ const APInt &C) {
// An exact shr only shifts out zero bits, so:
// icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0
Value *X = Shr->getOperand(0);
@@ -2231,6 +2210,21 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp,
(ShiftedC + 1).ashr(ShAmtVal) == (C + 1))
return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
}
+
+ // If the compare constant has significant bits above the lowest sign-bit,
+ // then convert an unsigned cmp to a test of the sign-bit:
+ // (ashr X, ShiftC) u> C --> X s< 0
+ // (ashr X, ShiftC) u< C --> X s> -1
+ if (C.getBitWidth() > 2 && C.getNumSignBits() <= ShAmtVal) {
+ if (Pred == CmpInst::ICMP_UGT) {
+ return new ICmpInst(CmpInst::ICMP_SLT, X,
+ ConstantInt::getNullValue(ShrTy));
+ }
+ if (Pred == CmpInst::ICMP_ULT) {
+ return new ICmpInst(CmpInst::ICMP_SGT, X,
+ ConstantInt::getAllOnesValue(ShrTy));
+ }
+ }
} else {
if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) {
// icmp ult (lshr X, ShAmtC), C --> icmp ult X, (C << ShAmtC)
@@ -2276,9 +2270,9 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp,
return nullptr;
}
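// Illustrative standalone sketch (not part of the patch above): after an
// arithmetic shift right by 24, a non-negative i32 is at most 127 while a
// negative one becomes a huge unsigned value, so "u> 1000" is just a sign
// test on X; 1000 satisfies the C.getNumSignBits() <= ShAmtVal guard above.
#include <cstdint>
constexpr bool ashr_ugt_c(int32_t x) {
  return static_cast<uint32_t>(x >> 24) > 1000u; // (ashr X, 24) u> 1000
}
constexpr bool sign_test(int32_t x) { return x < 0; } // X s< 0
static_assert(ashr_ugt_c(-1) == sign_test(-1), "");
static_assert(ashr_ugt_c(0x7FFFFFFF) == sign_test(0x7FFFFFFF), "");
static_assert(ashr_ugt_c(123456) == sign_test(123456), "");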
-Instruction *InstCombiner::foldICmpSRemConstant(ICmpInst &Cmp,
- BinaryOperator *SRem,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
+ BinaryOperator *SRem,
+ const APInt &C) {
// Match an 'is positive' or 'is negative' comparison of remainder by a
// constant power-of-2 value:
// (X % pow2C) sgt/slt 0
@@ -2315,9 +2309,9 @@ Instruction *InstCombiner::foldICmpSRemConstant(ICmpInst &Cmp,
}
/// Fold icmp (udiv X, Y), C.
-Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp,
- BinaryOperator *UDiv,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp,
+ BinaryOperator *UDiv,
+ const APInt &C) {
const APInt *C2;
if (!match(UDiv->getOperand(0), m_APInt(C2)))
return nullptr;
@@ -2344,9 +2338,9 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp,
}
/// Fold icmp ({su}div X, Y), C.
-Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
- BinaryOperator *Div,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp,
+ BinaryOperator *Div,
+ const APInt &C) {
// Fold: icmp pred ([us]div X, C2), C -> range test
// Fold this div into the comparison, producing a range check.
// Determine, based on the divide type, what the range is being
@@ -2514,9 +2508,9 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
}
/// Fold icmp (sub X, Y), C.
-Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
- BinaryOperator *Sub,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp,
+ BinaryOperator *Sub,
+ const APInt &C) {
Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1);
ICmpInst::Predicate Pred = Cmp.getPredicate();
const APInt *C2;
@@ -2576,9 +2570,9 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
}
/// Fold icmp (add X, Y), C.
-Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
- BinaryOperator *Add,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
+ BinaryOperator *Add,
+ const APInt &C) {
Value *Y = Add->getOperand(1);
const APInt *C2;
if (Cmp.isEquality() || !match(Y, m_APInt(C2)))
@@ -2642,10 +2636,10 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
return nullptr;
}
-bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
- Value *&RHS, ConstantInt *&Less,
- ConstantInt *&Equal,
- ConstantInt *&Greater) {
+bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
+ Value *&RHS, ConstantInt *&Less,
+ ConstantInt *&Equal,
+ ConstantInt *&Greater) {
// TODO: Generalize this to work with other comparison idioms or ensure
// they get canonicalized into this form.
@@ -2682,7 +2676,8 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) {
// x sgt C-1 <--> x sge C <--> not(x slt C)
auto FlippedStrictness =
- getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2));
+ InstCombiner::getFlippedStrictnessPredicateAndConstant(
+ PredB, cast<Constant>(RHS2));
if (!FlippedStrictness)
return false;
assert(FlippedStrictness->first == ICmpInst::ICMP_SGE && "Sanity check");
@@ -2694,9 +2689,9 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
return PredB == ICmpInst::ICMP_SLT && RHS == RHS2;
}
-Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp,
- SelectInst *Select,
- ConstantInt *C) {
+Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp,
+ SelectInst *Select,
+ ConstantInt *C) {
assert(C && "Cmp RHS should be a constant int!");
// If we're testing a constant value against the result of a three way
@@ -2794,7 +2789,7 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
const APInt *C;
bool TrueIfSigned;
if (match(Op1, m_APInt(C)) && Bitcast->hasOneUse() &&
- isSignBitCheck(Pred, *C, TrueIfSigned)) {
+ InstCombiner::isSignBitCheck(Pred, *C, TrueIfSigned)) {
if (match(BCSrcOp, m_FPExt(m_Value(X))) ||
match(BCSrcOp, m_FPTrunc(m_Value(X)))) {
// (bitcast (fpext/fptrunc X)) to iX) < 0 --> (bitcast X to iY) < 0
@@ -2806,7 +2801,7 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
Type *NewType = Builder.getIntNTy(XType->getScalarSizeInBits());
if (auto *XVTy = dyn_cast<VectorType>(XType))
- NewType = FixedVectorType::get(NewType, XVTy->getNumElements());
+ NewType = VectorType::get(NewType, XVTy->getElementCount());
Value *NewBitcast = Builder.CreateBitCast(X, NewType);
if (TrueIfSigned)
return new ICmpInst(ICmpInst::ICMP_SLT, NewBitcast,
@@ -2870,7 +2865,7 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
/// Try to fold integer comparisons with a constant operand: icmp Pred X, C
/// where X is some kind of instruction.
-Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) {
+Instruction *InstCombinerImpl::foldICmpInstWithConstant(ICmpInst &Cmp) {
const APInt *C;
if (!match(Cmp.getOperand(1), m_APInt(C)))
return nullptr;
@@ -2955,9 +2950,8 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) {
/// Fold an icmp equality instruction with binary operator LHS and constant RHS:
/// icmp eq/ne BO, C.
-Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
- BinaryOperator *BO,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant(
+ ICmpInst &Cmp, BinaryOperator *BO, const APInt &C) {
// TODO: Some of these folds could work with arbitrary constants, but this
// function is limited to scalar and vector splat constants.
if (!Cmp.isEquality())
@@ -3047,17 +3041,6 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
}
break;
}
- case Instruction::Mul:
- if (C.isNullValue() && BO->hasNoSignedWrap()) {
- const APInt *BOC;
- if (match(BOp1, m_APInt(BOC)) && !BOC->isNullValue()) {
- // The trivial case (mul X, 0) is handled by InstSimplify.
- // General case : (mul X, C) != 0 iff X != 0
- // (mul X, C) == 0 iff X == 0
- return new ICmpInst(Pred, BOp0, Constant::getNullValue(RHS->getType()));
- }
- }
- break;
case Instruction::UDiv:
if (C.isNullValue()) {
// (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A)
@@ -3072,12 +3055,19 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
}
/// Fold an equality icmp with LLVM intrinsic and constant operand.
-Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
- IntrinsicInst *II,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
+ ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) {
Type *Ty = II->getType();
unsigned BitWidth = C.getBitWidth();
switch (II->getIntrinsicID()) {
+ case Intrinsic::abs:
+ // abs(A) == 0 -> A == 0
+ // abs(A) == INT_MIN -> A == INT_MIN
+ if (C.isNullValue() || C.isMinSignedValue())
+ return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
+ ConstantInt::get(Ty, C));
+ break;
+
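// Illustrative standalone sketch (not part of the patch above): llvm.abs maps
// only 0 to 0 and (with INT_MIN wrapping) only INT_MIN to INT_MIN, so an
// equality test against either constant can look through the abs; i32 shown.
#include <cstdint>
constexpr int32_t abs_wrap(int32_t a) {
  return a < 0 ? static_cast<int32_t>(0u - static_cast<uint32_t>(a)) : a;
}
static_assert((abs_wrap(0) == 0) == (0 == 0), "abs(A) == 0  <->  A == 0");
static_assert((abs_wrap(-7) == 0) == (-7 == 0), "abs(A) == 0  <->  A == 0");
static_assert((abs_wrap(INT32_MIN) == INT32_MIN) == (INT32_MIN == INT32_MIN),
              "abs(A) == INT_MIN  <->  A == INT_MIN");
static_assert((abs_wrap(7) == INT32_MIN) == (7 == INT32_MIN), "");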
case Intrinsic::bswap:
// bswap(A) == C -> A == bswap(C)
return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
@@ -3145,18 +3135,31 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
}
/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C.
-Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
- IntrinsicInst *II,
- const APInt &C) {
+Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
+ IntrinsicInst *II,
+ const APInt &C) {
if (Cmp.isEquality())
return foldICmpEqIntrinsicWithConstant(Cmp, II, C);
Type *Ty = II->getType();
unsigned BitWidth = C.getBitWidth();
+ ICmpInst::Predicate Pred = Cmp.getPredicate();
switch (II->getIntrinsicID()) {
+ case Intrinsic::ctpop: {
+ // (ctpop X > BitWidth - 1) --> X == -1
+ Value *X = II->getArgOperand(0);
+ if (C == BitWidth - 1 && Pred == ICmpInst::ICMP_UGT)
+ return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, X,
+ ConstantInt::getAllOnesValue(Ty));
+ // (ctpop X < BitWidth) --> X != -1
+ if (C == BitWidth && Pred == ICmpInst::ICMP_ULT)
+ return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, X,
+ ConstantInt::getAllOnesValue(Ty));
+ break;
+ }
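// Illustrative standalone sketch (not part of the patch above): only the
// all-ones value has a population count greater than BitWidth - 1, so the
// ugt form collapses to an equality with -1 (i32 shown).
#include <cstdint>
constexpr unsigned popcnt(uint32_t x) {
  unsigned n = 0;
  for (; x != 0; x &= x - 1)
    ++n;
  return n;
}
constexpr bool ctpop_ugt_31(uint32_t x) { return popcnt(x) > 31u; }
constexpr bool is_all_ones(uint32_t x) { return x == 0xFFFFFFFFu; }
static_assert(ctpop_ugt_31(0xFFFFFFFFu) == is_all_ones(0xFFFFFFFFu), "");
static_assert(ctpop_ugt_31(0xFFFFFFFEu) == is_all_ones(0xFFFFFFFEu), "");
static_assert(ctpop_ugt_31(0x12345678u) == is_all_ones(0x12345678u), "");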
case Intrinsic::ctlz: {
// ctlz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX < 0b00010000
- if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && C.ult(BitWidth)) {
+ if (Pred == ICmpInst::ICMP_UGT && C.ult(BitWidth)) {
unsigned Num = C.getLimitedValue();
APInt Limit = APInt::getOneBitSet(BitWidth, BitWidth - Num - 1);
return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT,
@@ -3164,8 +3167,7 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
}
// ctlz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX > 0b00011111
- if (Cmp.getPredicate() == ICmpInst::ICMP_ULT &&
- C.uge(1) && C.ule(BitWidth)) {
+ if (Pred == ICmpInst::ICMP_ULT && C.uge(1) && C.ule(BitWidth)) {
unsigned Num = C.getLimitedValue();
APInt Limit = APInt::getLowBitsSet(BitWidth, BitWidth - Num);
return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT,
@@ -3179,7 +3181,7 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
return nullptr;
// cttz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX & 0b00001111 == 0
- if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && C.ult(BitWidth)) {
+ if (Pred == ICmpInst::ICMP_UGT && C.ult(BitWidth)) {
APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue() + 1);
return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ,
Builder.CreateAnd(II->getArgOperand(0), Mask),
@@ -3187,8 +3189,7 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
}
// cttz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX & 0b00000111 != 0
- if (Cmp.getPredicate() == ICmpInst::ICMP_ULT &&
- C.uge(1) && C.ule(BitWidth)) {
+ if (Pred == ICmpInst::ICMP_ULT && C.uge(1) && C.ule(BitWidth)) {
APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue());
return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE,
Builder.CreateAnd(II->getArgOperand(0), Mask),
@@ -3204,7 +3205,7 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
}
/// Handle icmp with constant (but not simple integer constant) RHS.
-Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) {
+Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Constant *RHSC = dyn_cast<Constant>(Op1);
Instruction *LHSI = dyn_cast<Instruction>(Op0);
@@ -3383,8 +3384,8 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
// those elements by copying an existing, defined, and safe scalar constant.
Type *OpTy = M->getType();
auto *VecC = dyn_cast<Constant>(M);
- if (OpTy->isVectorTy() && VecC && VecC->containsUndefElement()) {
- auto *OpVTy = cast<VectorType>(OpTy);
+ auto *OpVTy = dyn_cast<FixedVectorType>(OpTy);
+ if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) {
Constant *SafeReplacementConstant = nullptr;
for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) {
if (!isa<UndefValue>(VecC->getAggregateElement(i))) {
@@ -3650,7 +3651,7 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
/// @llvm.umul.with.overflow(x, y) plus extraction of overflow bit
/// Note that the comparison is commutative, while inverted (u>=, ==) predicate
/// will mean that we are looking for the opposite answer.
-Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
+Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
ICmpInst::Predicate Pred;
Value *X, *Y;
Instruction *Mul;
@@ -3712,11 +3713,28 @@ Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
return Res;
}
+static Instruction *foldICmpXNegX(ICmpInst &I) {
+ CmpInst::Predicate Pred;
+ Value *X;
+ if (!match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X))))
+ return nullptr;
+
+ if (ICmpInst::isSigned(Pred))
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ else if (ICmpInst::isUnsigned(Pred))
+ Pred = ICmpInst::getSignedPredicate(Pred);
+ // else for equality-comparisons just keep the predicate.
+
+ return ICmpInst::Create(Instruction::ICmp, Pred, X,
+ Constant::getNullValue(X->getType()), I.getName());
+}
+
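// Illustrative standalone sketch (not part of the patch above): for the nsw
// negation matched by foldICmpXNegX, "(-X) slt X" is equivalent to "X sgt 0",
// i.e. the signed predicate is swapped and X is compared against zero.
#include <cstdint>
constexpr bool neg_slt_x(int32_t x) { return -x < x; }  // nsw: x != INT32_MIN
constexpr bool x_sgt_zero(int32_t x) { return x > 0; }
static_assert(neg_slt_x(7) == x_sgt_zero(7), "");
static_assert(neg_slt_x(-7) == x_sgt_zero(-7), "");
static_assert(neg_slt_x(0) == x_sgt_zero(0), "");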
/// Try to fold icmp (binop), X or icmp X, (binop).
/// TODO: A large part of this logic is duplicated in InstSimplify's
/// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
/// duplication.
-Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) {
+Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
+ const SimplifyQuery &SQ) {
const SimplifyQuery Q = SQ.getWithInstruction(&I);
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
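Illustrative sketch (editorial, not part of the patch): the new foldICmpXNegX rewrites "icmp pred (-X), X", where the negation carries the nsw flag, into a compare of X against zero, swapping signed predicates and turning unsigned ones into their signed counterparts. A standalone brute-force check on i8; the nsw requirement is modelled by skipping INT8_MIN, whose negation would overflow.

#include <cassert>
#include <cstdint>

int main() {
  for (int X = -127; X <= 127; ++X) {
    int8_t S = static_cast<int8_t>(X);
    int8_t NegS = static_cast<int8_t>(-X);
    uint8_t U = static_cast<uint8_t>(S);
    uint8_t NegU = static_cast<uint8_t>(NegS);

    // Signed predicates: swap the predicate and compare X with 0.
    assert((NegS < S) == (S > 0));   // slt -> sgt
    assert((NegS > S) == (S < 0));   // sgt -> slt
    assert((NegS <= S) == (S >= 0)); // sle -> sge
    assert((NegS >= S) == (S <= 0)); // sge -> sle

    // Unsigned predicates: keep the direction, compare signed X with 0.
    assert((NegU < U) == (S < 0));   // ult -> slt
    assert((NegU > U) == (S > 0));   // ugt -> sgt
    assert((NegU <= U) == (S <= 0)); // ule -> sle
    assert((NegU >= U) == (S >= 0)); // uge -> sge

    // Equality predicates are left untouched.
    assert((NegS == S) == (S == 0));
    assert((NegS != S) == (S != 0));
  }
  return 0;
}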
@@ -3726,6 +3744,9 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) {
if (!BO0 && !BO1)
return nullptr;
+ if (Instruction *NewICmp = foldICmpXNegX(I))
+ return NewICmp;
+
const CmpInst::Predicate Pred = I.getPredicate();
Value *X;
@@ -3946,6 +3967,19 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) {
ConstantExpr::getNeg(RHSC));
}
+ {
+ // Try to remove shared constant multiplier from equality comparison:
+ // X * C == Y * C (with no overflowing/aliasing) --> X == Y
+ Value *X, *Y;
+ const APInt *C;
+ if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 &&
+ match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality())
+ if (!C->countTrailingZeros() ||
+ (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) ||
+ (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap()))
+ return new ICmpInst(Pred, X, Y);
+ }
+
BinaryOperator *SRem = nullptr;
// icmp (srem X, Y), Y
if (BO0 && BO0->getOpcode() == Instruction::SRem && Op1 == BO0->getOperand(1))
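Illustrative sketch (editorial, not part of the patch): the equality fold added above drops a shared multiplier when it is odd (no trailing zeros) or when both multiplications carry matching no-wrap flags. The odd case holds because multiplication by an odd constant is invertible modulo 2^n even when it wraps; a standalone brute-force check of that case on i8 (the nsw/nuw cases are not exercised here):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned C = 1; C <= 0xFF; C += 2)   // odd multipliers only
    for (unsigned X = 0; X <= 0xFF; ++X)
      for (unsigned Y = 0; Y <= 0xFF; ++Y) {
        uint8_t XC = static_cast<uint8_t>(X * C); // wrapping i8 multiply
        uint8_t YC = static_cast<uint8_t>(Y * C);
        assert((XC == YC) == (X == Y));
      }
  return 0;
}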
@@ -3990,15 +4024,13 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) {
if (match(BO0->getOperand(1), m_APInt(C))) {
// icmp u/s (a ^ signmask), (b ^ signmask) --> icmp s/u a, b
if (C->isSignMask()) {
- ICmpInst::Predicate NewPred =
- I.isSigned() ? I.getUnsignedPredicate() : I.getSignedPredicate();
+ ICmpInst::Predicate NewPred = I.getFlippedSignednessPredicate();
return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0));
}
// icmp u/s (a ^ maxsignval), (b ^ maxsignval) --> icmp s/u' a, b
if (BO0->getOpcode() == Instruction::Xor && C->isMaxSignedValue()) {
- ICmpInst::Predicate NewPred =
- I.isSigned() ? I.getUnsignedPredicate() : I.getSignedPredicate();
+ ICmpInst::Predicate NewPred = I.getFlippedSignednessPredicate();
NewPred = I.getSwappedPredicate(NewPred);
return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0));
}
@@ -4022,10 +4054,6 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) {
Value *And2 = Builder.CreateAnd(BO1->getOperand(0), Mask);
return new ICmpInst(Pred, And1, And2);
}
- // If there are no trailing zeros in the multiplier, just eliminate
- // the multiplies (no masking is needed):
- // icmp eq/ne (X * C), (Y * C) --> icmp eq/ne X, Y
- return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0));
}
break;
}
@@ -4170,7 +4198,7 @@ static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) {
return nullptr;
}
-Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
+Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
if (!I.isEquality())
return nullptr;
@@ -4438,7 +4466,7 @@ static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp,
}
/// Handle icmp (cast x), (cast or constant).
-Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) {
+Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0));
if (!CastOp0)
return nullptr;
@@ -4493,9 +4521,10 @@ static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {
}
}
-OverflowResult InstCombiner::computeOverflow(
- Instruction::BinaryOps BinaryOp, bool IsSigned,
- Value *LHS, Value *RHS, Instruction *CxtI) const {
+OverflowResult
+InstCombinerImpl::computeOverflow(Instruction::BinaryOps BinaryOp,
+ bool IsSigned, Value *LHS, Value *RHS,
+ Instruction *CxtI) const {
switch (BinaryOp) {
default:
llvm_unreachable("Unsupported binary op");
@@ -4517,9 +4546,11 @@ OverflowResult InstCombiner::computeOverflow(
}
}
-bool InstCombiner::OptimizeOverflowCheck(
- Instruction::BinaryOps BinaryOp, bool IsSigned, Value *LHS, Value *RHS,
- Instruction &OrigI, Value *&Result, Constant *&Overflow) {
+bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
+ bool IsSigned, Value *LHS,
+ Value *RHS, Instruction &OrigI,
+ Value *&Result,
+ Constant *&Overflow) {
if (OrigI.isCommutative() && isa<Constant>(LHS) && !isa<Constant>(RHS))
std::swap(LHS, RHS);
@@ -4529,9 +4560,13 @@ bool InstCombiner::OptimizeOverflowCheck(
// compare.
Builder.SetInsertPoint(&OrigI);
+ Type *OverflowTy = Type::getInt1Ty(LHS->getContext());
+ if (auto *LHSTy = dyn_cast<VectorType>(LHS->getType()))
+ OverflowTy = VectorType::get(OverflowTy, LHSTy->getElementCount());
+
if (isNeutralValue(BinaryOp, RHS)) {
Result = LHS;
- Overflow = Builder.getFalse();
+ Overflow = ConstantInt::getFalse(OverflowTy);
return true;
}
@@ -4542,12 +4577,12 @@ bool InstCombiner::OptimizeOverflowCheck(
case OverflowResult::AlwaysOverflowsHigh:
Result = Builder.CreateBinOp(BinaryOp, LHS, RHS);
Result->takeName(&OrigI);
- Overflow = Builder.getTrue();
+ Overflow = ConstantInt::getTrue(OverflowTy);
return true;
case OverflowResult::NeverOverflows:
Result = Builder.CreateBinOp(BinaryOp, LHS, RHS);
Result->takeName(&OrigI);
- Overflow = Builder.getFalse();
+ Overflow = ConstantInt::getFalse(OverflowTy);
if (auto *Inst = dyn_cast<Instruction>(Result)) {
if (IsSigned)
Inst->setHasNoSignedWrap();
@@ -4575,7 +4610,8 @@ bool InstCombiner::OptimizeOverflowCheck(
/// \returns Instruction which must replace the compare instruction, NULL if no
/// replacement required.
static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
- Value *OtherVal, InstCombiner &IC) {
+ Value *OtherVal,
+ InstCombinerImpl &IC) {
// Don't bother doing this transformation for pointers, don't do it for
// vectors.
if (!isa<IntegerType>(MulVal->getType()))
@@ -4723,15 +4759,14 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
Function *F = Intrinsic::getDeclaration(
I.getModule(), Intrinsic::umul_with_overflow, MulType);
CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul");
- IC.Worklist.push(MulInstr);
+ IC.addToWorklist(MulInstr);
// If there are uses of mul result other than the comparison, we know that
// they are truncation or binary AND. Change them to use result of
// mul.with.overflow and adjust properly mask/size.
if (MulVal->hasNUsesOrMore(2)) {
Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
- for (auto UI = MulVal->user_begin(), UE = MulVal->user_end(); UI != UE;) {
- User *U = *UI++;
+ for (User *U : make_early_inc_range(MulVal->users())) {
if (U == &I || U == OtherVal)
continue;
if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
@@ -4750,11 +4785,11 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
} else {
llvm_unreachable("Unexpected Binary operation");
}
- IC.Worklist.push(cast<Instruction>(U));
+ IC.addToWorklist(cast<Instruction>(U));
}
}
if (isa<Instruction>(OtherVal))
- IC.Worklist.push(cast<Instruction>(OtherVal));
+ IC.addToWorklist(cast<Instruction>(OtherVal));
// The original icmp gets replaced with the overflow value, maybe inverted
// depending on predicate.
@@ -4799,7 +4834,7 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) {
// If this is a normal comparison, it demands all bits. If it is a sign bit
// comparison, it only demands the sign bit.
bool UnusedBit;
- if (isSignBitCheck(I.getPredicate(), *RHS, UnusedBit))
+ if (InstCombiner::isSignBitCheck(I.getPredicate(), *RHS, UnusedBit))
return APInt::getSignMask(BitWidth);
switch (I.getPredicate()) {
@@ -4856,9 +4891,9 @@ static bool swapMayExposeCSEOpportunities(const Value *Op0, const Value *Op1) {
/// \return true when \p UI is the only use of \p DI in the parent block
/// and all other uses of \p DI are in blocks dominated by \p DB.
///
-bool InstCombiner::dominatesAllUses(const Instruction *DI,
- const Instruction *UI,
- const BasicBlock *DB) const {
+bool InstCombinerImpl::dominatesAllUses(const Instruction *DI,
+ const Instruction *UI,
+ const BasicBlock *DB) const {
assert(DI && UI && "Instruction not defined\n");
// Ignore incomplete definitions.
if (!DI->getParent())
@@ -4931,9 +4966,9 @@ static bool isChainSelectCmpBranch(const SelectInst *SI) {
/// major restriction since a NE compare should be 'normalized' to an equal
/// compare, which usually happens in the combiner and test case
/// select-cmp-br.ll checks for it.
-bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,
- const ICmpInst *Icmp,
- const unsigned SIOpd) {
+bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
+ const ICmpInst *Icmp,
+ const unsigned SIOpd) {
assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!");
if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) {
BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1);
@@ -4959,7 +4994,7 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,
/// Try to fold the comparison based on range information we can get by checking
/// whether bits are known to be zero or one in the inputs.
-Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
+Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = Op0->getType();
ICmpInst::Predicate Pred = I.getPredicate();
@@ -4990,11 +5025,15 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0);
APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0);
if (I.isSigned()) {
- computeSignedMinMaxValuesFromKnownBits(Op0Known, Op0Min, Op0Max);
- computeSignedMinMaxValuesFromKnownBits(Op1Known, Op1Min, Op1Max);
+ Op0Min = Op0Known.getSignedMinValue();
+ Op0Max = Op0Known.getSignedMaxValue();
+ Op1Min = Op1Known.getSignedMinValue();
+ Op1Max = Op1Known.getSignedMaxValue();
} else {
- computeUnsignedMinMaxValuesFromKnownBits(Op0Known, Op0Min, Op0Max);
- computeUnsignedMinMaxValuesFromKnownBits(Op1Known, Op1Min, Op1Max);
+ Op0Min = Op0Known.getMinValue();
+ Op0Max = Op0Known.getMaxValue();
+ Op1Min = Op1Known.getMinValue();
+ Op1Max = Op1Known.getMaxValue();
}
// If Min and Max are known to be the same, then SimplifyDemandedBits figured
@@ -5012,11 +5051,9 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
llvm_unreachable("Unknown icmp opcode!");
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_NE: {
- if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) {
- return Pred == CmpInst::ICMP_EQ
- ? replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()))
- : replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
- }
+ if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
+ return replaceInstUsesWith(
+ I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE));
// If all bits are known zero except for one, then we know at most one bit
// is set. If the comparison is against zero, then this is a check to see if
@@ -5186,8 +5223,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
}
llvm::Optional<std::pair<CmpInst::Predicate, Constant *>>
-llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
- Constant *C) {
+InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
+ Constant *C) {
assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
"Only for relational integer predicates.");
@@ -5209,8 +5246,8 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
// Bail out if the constant can't be safely incremented/decremented.
if (!ConstantIsOk(CI))
return llvm::None;
- } else if (auto *VTy = dyn_cast<VectorType>(Type)) {
- unsigned NumElts = VTy->getNumElements();
+ } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
+ unsigned NumElts = FVTy->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
@@ -5236,7 +5273,8 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
// It may not be safe to change a compare predicate in the presence of
// undefined elements, so replace those elements with the first safe constant
// that we found.
- if (C->containsUndefElement()) {
+ // TODO: in case of poison, it is safe; let's replace undefs only.
+ if (C->containsUndefOrPoisonElement()) {
assert(SafeReplacementConstant && "Replacement constant not set");
C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
}
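Illustrative sketch (editorial, not part of the patch): getFlippedStrictnessPredicateAndConstant trades a strict relational predicate for the adjacent non-strict one by nudging the constant by one, which is only legal when the constant cannot wrap (the safety checks above). The underlying integer identities, brute-forced on i8:

#include <cassert>
#include <cstdint>

int main() {
  for (int X = -128; X <= 127; ++X)
    for (int C = -128; C <= 127; ++C) {
      // Signed forms; the boundary constants are excluded, as in the code.
      if (C != -128) {
        assert((X < C) == (X <= C - 1));  // slt C  <->  sle C-1
        assert((X >= C) == (X > C - 1));  // sge C  <->  sgt C-1
      }
      if (C != 127) {
        assert((X > C) == (X >= C + 1));  // sgt C  <->  sge C+1
        assert((X <= C) == (X < C + 1));  // sle C  <->  slt C+1
      }
      // Unsigned forms.
      unsigned UX = static_cast<uint8_t>(X), UC = static_cast<uint8_t>(C);
      if (UC != 0) {
        assert((UX < UC) == (UX <= UC - 1));  // ult C  <->  ule C-1
        assert((UX >= UC) == (UX > UC - 1));  // uge C  <->  ugt C-1
      }
      if (UC != 255) {
        assert((UX > UC) == (UX >= UC + 1));  // ugt C  <->  uge C+1
        assert((UX <= UC) == (UX < UC + 1));  // ule C  <->  ult C+1
      }
    }
  return 0;
}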
@@ -5256,7 +5294,7 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
ICmpInst::Predicate Pred = I.getPredicate();
if (ICmpInst::isEquality(Pred) || !ICmpInst::isIntPredicate(Pred) ||
- isCanonicalPredicate(Pred))
+ InstCombiner::isCanonicalPredicate(Pred))
return nullptr;
Value *Op0 = I.getOperand(0);
@@ -5265,7 +5303,8 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
if (!Op1C)
return nullptr;
- auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
+ auto FlippedStrictness =
+ InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C);
if (!FlippedStrictness)
return nullptr;
@@ -5274,14 +5313,14 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
/// If we have a comparison with a non-canonical predicate, if we can update
/// all the users, invert the predicate and adjust all the users.
-static CmpInst *canonicalizeICmpPredicate(CmpInst &I) {
+CmpInst *InstCombinerImpl::canonicalizeICmpPredicate(CmpInst &I) {
// Is the predicate already canonical?
CmpInst::Predicate Pred = I.getPredicate();
- if (isCanonicalPredicate(Pred))
+ if (InstCombiner::isCanonicalPredicate(Pred))
return nullptr;
// Can all users be adjusted to predicate inversion?
- if (!canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr))
+ if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr))
return nullptr;
// Ok, we can canonicalize comparison!
@@ -5289,26 +5328,8 @@ static CmpInst *canonicalizeICmpPredicate(CmpInst &I) {
I.setPredicate(CmpInst::getInversePredicate(Pred));
I.setName(I.getName() + ".not");
- // And now let's adjust every user.
- for (User *U : I.users()) {
- switch (cast<Instruction>(U)->getOpcode()) {
- case Instruction::Select: {
- auto *SI = cast<SelectInst>(U);
- SI->swapValues();
- SI->swapProfMetadata();
- break;
- }
- case Instruction::Br:
- cast<BranchInst>(U)->swapSuccessors(); // swaps prof metadata too
- break;
- case Instruction::Xor:
- U->replaceAllUsesWith(&I);
- break;
- default:
- llvm_unreachable("Got unexpected user - out of sync with "
- "canFreelyInvertAllUsersOf() ?");
- }
- }
+ // And, adapt users.
+ freelyInvertAllUsersOf(&I);
return &I;
}
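Illustrative sketch (editorial, not part of the patch): the user-adjustment switch removed above is now centralized in freelyInvertAllUsersOf, and it is sound because each user kind admitted by canFreelyInvertAllUsersOf can absorb the inversion: a select swaps its arms, a branch swaps its successors, and a 'not' (xor with true) simply becomes the inverted compare. The select and xor cases, checked over a small range:

#include <cassert>

int main() {
  for (int X = -4; X <= 4; ++X)
    for (int Y = -4; Y <= 4; ++Y)
      for (int A = 0; A <= 1; ++A)
        for (int B = 0; B <= 1; ++B) {
          bool Cond = X > Y; // original compare
          bool Inv = X <= Y; // canonicalized (inverted) compare
          // select cond, A, B  ==  select inv, B, A  (swapped arms)
          assert((Cond ? A : B) == (Inv ? B : A));
          // xor(cond, true)  ==  inv  (the 'not' user folds away)
          assert(!Cond == Inv);
        }
  return 0;
}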
@@ -5510,7 +5531,7 @@ static Instruction *foldICmpOfUAddOv(ICmpInst &I) {
return ExtractValueInst::Create(UAddOv, 1);
}
-Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
+Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
bool Changed = false;
const SimplifyQuery Q = SQ.getWithInstruction(&I);
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -5634,10 +5655,10 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
// Try to optimize equality comparisons against alloca-based pointers.
if (Op0->getType()->isPointerTy() && I.isEquality()) {
assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?");
- if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op0, DL)))
+ if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op0)))
if (Instruction *New = foldAllocaCmp(I, Alloca, Op1))
return New;
- if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op1, DL)))
+ if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op1)))
if (Instruction *New = foldAllocaCmp(I, Alloca, Op0))
return New;
}
@@ -5748,8 +5769,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
}
/// Fold fcmp ([us]itofp x, cst) if possible.
-Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
- Constant *RHSC) {
+Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I,
+ Instruction *LHSI,
+ Constant *RHSC) {
if (!isa<ConstantFP>(RHSC)) return nullptr;
const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF();
@@ -6034,9 +6056,9 @@ static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI,
}
/// Optimize fabs(X) compared with zero.
-static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombiner &IC) {
+static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
Value *X;
- if (!match(I.getOperand(0), m_Intrinsic<Intrinsic::fabs>(m_Value(X))) ||
+ if (!match(I.getOperand(0), m_FAbs(m_Value(X))) ||
!match(I.getOperand(1), m_PosZeroFP()))
return nullptr;
@@ -6096,7 +6118,7 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombiner &IC) {
}
}
-Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
+Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
bool Changed = false;
/// Orders the operands of the compare so that they are listed from most
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index f918dc7198ca..79e9d5c46c70 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -15,40 +15,32 @@
#ifndef LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H
#define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/TargetFolder.h"
#include "llvm/Analysis/ValueTracking.h"
-#include "llvm/IR/Argument.h"
-#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/Constant.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
-#include "llvm/IR/InstrTypes.h"
-#include "llvm/IR/Instruction.h"
-#include "llvm/IR/IntrinsicInst.h"
-#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
-#include "llvm/IR/Use.h"
-#include "llvm/IR/Value.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
-#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
-#include <cstdint>
#define DEBUG_TYPE "instcombine"
using namespace llvm::PatternMatch;
+// As a default, let's assume that we want to be aggressive,
+// and attempt to traverse with no limits in an attempt to sink negation.
+static constexpr unsigned NegatorDefaultMaxDepth = ~0U;
+
+// Let's guesstimate that most often we will end up visiting/producing
+// a fairly small number of new instructions.
+static constexpr unsigned NegatorMaxNodesSSO = 16;
+
namespace llvm {
class AAResults;
@@ -65,305 +57,26 @@ class ProfileSummaryInfo;
class TargetLibraryInfo;
class User;
-/// Assign a complexity or rank value to LLVM Values. This is used to reduce
-/// the amount of pattern matching needed for compares and commutative
-/// instructions. For example, if we have:
-/// icmp ugt X, Constant
-/// or
-/// xor (add X, Constant), cast Z
-///
-/// We do not have to consider the commuted variants of these patterns because
-/// canonicalization based on complexity guarantees the above ordering.
-///
-/// This routine maps IR values to various complexity ranks:
-/// 0 -> undef
-/// 1 -> Constants
-/// 2 -> Other non-instructions
-/// 3 -> Arguments
-/// 4 -> Cast and (f)neg/not instructions
-/// 5 -> Other instructions
-static inline unsigned getComplexity(Value *V) {
- if (isa<Instruction>(V)) {
- if (isa<CastInst>(V) || match(V, m_Neg(m_Value())) ||
- match(V, m_Not(m_Value())) || match(V, m_FNeg(m_Value())))
- return 4;
- return 5;
- }
- if (isa<Argument>(V))
- return 3;
- return isa<Constant>(V) ? (isa<UndefValue>(V) ? 0 : 1) : 2;
-}
-
-/// Predicate canonicalization reduces the number of patterns that need to be
-/// matched by other transforms. For example, we may swap the operands of a
-/// conditional branch or select to create a compare with a canonical (inverted)
-/// predicate which is then more likely to be matched with other values.
-static inline bool isCanonicalPredicate(CmpInst::Predicate Pred) {
- switch (Pred) {
- case CmpInst::ICMP_NE:
- case CmpInst::ICMP_ULE:
- case CmpInst::ICMP_SLE:
- case CmpInst::ICMP_UGE:
- case CmpInst::ICMP_SGE:
- // TODO: There are 16 FCMP predicates. Should others be (not) canonical?
- case CmpInst::FCMP_ONE:
- case CmpInst::FCMP_OLE:
- case CmpInst::FCMP_OGE:
- return false;
- default:
- return true;
- }
-}
-
-/// Given an exploded icmp instruction, return true if the comparison only
-/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the
-/// result of the comparison is true when the input value is signed.
-inline bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
- bool &TrueIfSigned) {
- switch (Pred) {
- case ICmpInst::ICMP_SLT: // True if LHS s< 0
- TrueIfSigned = true;
- return RHS.isNullValue();
- case ICmpInst::ICMP_SLE: // True if LHS s<= -1
- TrueIfSigned = true;
- return RHS.isAllOnesValue();
- case ICmpInst::ICMP_SGT: // True if LHS s> -1
- TrueIfSigned = false;
- return RHS.isAllOnesValue();
- case ICmpInst::ICMP_SGE: // True if LHS s>= 0
- TrueIfSigned = false;
- return RHS.isNullValue();
- case ICmpInst::ICMP_UGT:
- // True if LHS u> RHS and RHS == sign-bit-mask - 1
- TrueIfSigned = true;
- return RHS.isMaxSignedValue();
- case ICmpInst::ICMP_UGE:
- // True if LHS u>= RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
- TrueIfSigned = true;
- return RHS.isMinSignedValue();
- case ICmpInst::ICMP_ULT:
- // True if LHS u< RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
- TrueIfSigned = false;
- return RHS.isMinSignedValue();
- case ICmpInst::ICMP_ULE:
- // True if LHS u<= RHS and RHS == sign-bit-mask - 1
- TrueIfSigned = false;
- return RHS.isMaxSignedValue();
- default:
- return false;
- }
-}
-
-llvm::Optional<std::pair<CmpInst::Predicate, Constant *>>
-getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, Constant *C);
-
-/// Return the source operand of a potentially bitcasted value while optionally
-/// checking if it has one use. If there is no bitcast or the one use check is
-/// not met, return the input value itself.
-static inline Value *peekThroughBitcast(Value *V, bool OneUseOnly = false) {
- if (auto *BitCast = dyn_cast<BitCastInst>(V))
- if (!OneUseOnly || BitCast->hasOneUse())
- return BitCast->getOperand(0);
-
- // V is not a bitcast or V has more than one use and OneUseOnly is true.
- return V;
-}
-
-/// Add one to a Constant
-static inline Constant *AddOne(Constant *C) {
- return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1));
-}
-
-/// Subtract one from a Constant
-static inline Constant *SubOne(Constant *C) {
- return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
-}
-
-/// Return true if the specified value is free to invert (apply ~ to).
-/// This happens in cases where the ~ can be eliminated. If WillInvertAllUses
-/// is true, work under the assumption that the caller intends to remove all
-/// uses of V and only keep uses of ~V.
-///
-/// See also: canFreelyInvertAllUsersOf()
-static inline bool isFreeToInvert(Value *V, bool WillInvertAllUses) {
- // ~(~(X)) -> X.
- if (match(V, m_Not(m_Value())))
- return true;
-
- // Constants can be considered to be not'ed values.
- if (match(V, m_AnyIntegralConstant()))
- return true;
-
- // Compares can be inverted if all of their uses are being modified to use the
- // ~V.
- if (isa<CmpInst>(V))
- return WillInvertAllUses;
-
- // If `V` is of the form `A + Constant` then `-1 - V` can be folded into `(-1
- // - Constant) - A` if we are willing to invert all of the uses.
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(V))
- if (BO->getOpcode() == Instruction::Add ||
- BO->getOpcode() == Instruction::Sub)
- if (isa<Constant>(BO->getOperand(0)) || isa<Constant>(BO->getOperand(1)))
- return WillInvertAllUses;
-
- // Selects with invertible operands are freely invertible
- if (match(V, m_Select(m_Value(), m_Not(m_Value()), m_Not(m_Value()))))
- return WillInvertAllUses;
-
- return false;
-}
-
-/// Given i1 V, can every user of V be freely adapted if V is changed to !V ?
-/// InstCombine's canonicalizeICmpPredicate() must be kept in sync with this fn.
-///
-/// See also: isFreeToInvert()
-static inline bool canFreelyInvertAllUsersOf(Value *V, Value *IgnoredUser) {
- // Look at every user of V.
- for (Use &U : V->uses()) {
- if (U.getUser() == IgnoredUser)
- continue; // Don't consider this user.
-
- auto *I = cast<Instruction>(U.getUser());
- switch (I->getOpcode()) {
- case Instruction::Select:
- if (U.getOperandNo() != 0) // Only if the value is used as select cond.
- return false;
- break;
- case Instruction::Br:
- assert(U.getOperandNo() == 0 && "Must be branching on that value.");
- break; // Free to invert by swapping true/false values/destinations.
- case Instruction::Xor: // Can invert 'xor' if it's a 'not', by ignoring it.
- if (!match(I, m_Not(m_Value())))
- return false; // Not a 'not'.
- break;
- default:
- return false; // Don't know, likely not freely invertible.
- }
- // So far all users were free to invert...
- }
- return true; // Can freely invert all users!
-}
-
-/// Some binary operators require special handling to avoid poison and undefined
-/// behavior. If a constant vector has undef elements, replace those undefs with
-/// identity constants if possible because those are always safe to execute.
-/// If no identity constant exists, replace undef with some other safe constant.
-static inline Constant *getSafeVectorConstantForBinop(
- BinaryOperator::BinaryOps Opcode, Constant *In, bool IsRHSConstant) {
- auto *InVTy = dyn_cast<VectorType>(In->getType());
- assert(InVTy && "Not expecting scalars here");
-
- Type *EltTy = InVTy->getElementType();
- auto *SafeC = ConstantExpr::getBinOpIdentity(Opcode, EltTy, IsRHSConstant);
- if (!SafeC) {
- // TODO: Should this be available as a constant utility function? It is
- // similar to getBinOpAbsorber().
- if (IsRHSConstant) {
- switch (Opcode) {
- case Instruction::SRem: // X % 1 = 0
- case Instruction::URem: // X %u 1 = 0
- SafeC = ConstantInt::get(EltTy, 1);
- break;
- case Instruction::FRem: // X % 1.0 (doesn't simplify, but it is safe)
- SafeC = ConstantFP::get(EltTy, 1.0);
- break;
- default:
- llvm_unreachable("Only rem opcodes have no identity constant for RHS");
- }
- } else {
- switch (Opcode) {
- case Instruction::Shl: // 0 << X = 0
- case Instruction::LShr: // 0 >>u X = 0
- case Instruction::AShr: // 0 >> X = 0
- case Instruction::SDiv: // 0 / X = 0
- case Instruction::UDiv: // 0 /u X = 0
- case Instruction::SRem: // 0 % X = 0
- case Instruction::URem: // 0 %u X = 0
- case Instruction::Sub: // 0 - X (doesn't simplify, but it is safe)
- case Instruction::FSub: // 0.0 - X (doesn't simplify, but it is safe)
- case Instruction::FDiv: // 0.0 / X (doesn't simplify, but it is safe)
- case Instruction::FRem: // 0.0 % X = 0
- SafeC = Constant::getNullValue(EltTy);
- break;
- default:
- llvm_unreachable("Expected to find identity constant for opcode");
- }
- }
- }
- assert(SafeC && "Must have safe constant for binop");
- unsigned NumElts = InVTy->getNumElements();
- SmallVector<Constant *, 16> Out(NumElts);
- for (unsigned i = 0; i != NumElts; ++i) {
- Constant *C = In->getAggregateElement(i);
- Out[i] = isa<UndefValue>(C) ? SafeC : C;
- }
- return ConstantVector::get(Out);
-}
-
-/// The core instruction combiner logic.
-///
-/// This class provides both the logic to recursively visit instructions and
-/// combine them.
-class LLVM_LIBRARY_VISIBILITY InstCombiner
- : public InstVisitor<InstCombiner, Instruction *> {
- // FIXME: These members shouldn't be public.
+class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
+ : public InstCombiner,
+ public InstVisitor<InstCombinerImpl, Instruction *> {
public:
- /// A worklist of the instructions that need to be simplified.
- InstCombineWorklist &Worklist;
-
- /// An IRBuilder that automatically inserts new instructions into the
- /// worklist.
- using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>;
- BuilderTy &Builder;
-
-private:
- // Mode in which we are running the combiner.
- const bool MinimizeSize;
-
- AAResults *AA;
-
- // Required analyses.
- AssumptionCache &AC;
- TargetLibraryInfo &TLI;
- DominatorTree &DT;
- const DataLayout &DL;
- const SimplifyQuery SQ;
- OptimizationRemarkEmitter &ORE;
- BlockFrequencyInfo *BFI;
- ProfileSummaryInfo *PSI;
+ InstCombinerImpl(InstCombineWorklist &Worklist, BuilderTy &Builder,
+ bool MinimizeSize, AAResults *AA, AssumptionCache &AC,
+ TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
+ DominatorTree &DT, OptimizationRemarkEmitter &ORE,
+ BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI,
+ const DataLayout &DL, LoopInfo *LI)
+ : InstCombiner(Worklist, Builder, MinimizeSize, AA, AC, TLI, TTI, DT, ORE,
+ BFI, PSI, DL, LI) {}
- // Optional analyses. When non-null, these can both be used to do better
- // combining and will be updated to reflect any changes.
- LoopInfo *LI;
-
- bool MadeIRChange = false;
-
-public:
- InstCombiner(InstCombineWorklist &Worklist, BuilderTy &Builder,
- bool MinimizeSize, AAResults *AA,
- AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT,
- OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
- ProfileSummaryInfo *PSI, const DataLayout &DL, LoopInfo *LI)
- : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize),
- AA(AA), AC(AC), TLI(TLI), DT(DT),
- DL(DL), SQ(DL, &TLI, &DT, &AC), ORE(ORE), BFI(BFI), PSI(PSI), LI(LI) {}
+ virtual ~InstCombinerImpl() {}
/// Run the combiner over the entire worklist until it is empty.
///
/// \returns true if the IR is changed.
bool run();
- AssumptionCache &getAssumptionCache() const { return AC; }
-
- const DataLayout &getDataLayout() const { return DL; }
-
- DominatorTree &getDominatorTree() const { return DT; }
-
- LoopInfo *getLoopInfo() const { return LI; }
-
- TargetLibraryInfo &getTargetLibraryInfo() const { return TLI; }
-
// Visitation implementation - Implement instruction combining for different
// instruction types. The semantics are as follows:
// Return Value:
@@ -384,9 +97,7 @@ public:
Instruction *visitSRem(BinaryOperator &I);
Instruction *visitFRem(BinaryOperator &I);
bool simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I);
- Instruction *commonRemTransforms(BinaryOperator &I);
Instruction *commonIRemTransforms(BinaryOperator &I);
- Instruction *commonDivTransforms(BinaryOperator &I);
Instruction *commonIDivTransforms(BinaryOperator &I);
Instruction *visitUDiv(BinaryOperator &I);
Instruction *visitSDiv(BinaryOperator &I);
@@ -394,6 +105,7 @@ public:
Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted);
Instruction *visitAnd(BinaryOperator &I);
Instruction *visitOr(BinaryOperator &I);
+ bool sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I);
Instruction *visitXor(BinaryOperator &I);
Instruction *visitShl(BinaryOperator &I);
Value *reassociateShiftAmtsOfTwoSameDirectionShifts(
@@ -407,6 +119,7 @@ public:
Instruction *visitLShr(BinaryOperator &I);
Instruction *commonShiftTransforms(BinaryOperator &I);
Instruction *visitFCmpInst(FCmpInst &I);
+ CmpInst *canonicalizeICmpPredicate(CmpInst &I);
Instruction *visitICmpInst(ICmpInst &I);
Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1,
BinaryOperator &I);
@@ -445,6 +158,9 @@ public:
Instruction *visitFenceInst(FenceInst &FI);
Instruction *visitSwitchInst(SwitchInst &SI);
Instruction *visitReturnInst(ReturnInst &RI);
+ Instruction *visitUnreachableInst(UnreachableInst &I);
+ Instruction *
+ foldAggregateConstructionIntoAggregateReuse(InsertValueInst &OrigIVI);
Instruction *visitInsertValueInst(InsertValueInst &IV);
Instruction *visitInsertElementInst(InsertElementInst &IE);
Instruction *visitExtractElementInst(ExtractElementInst &EI);
@@ -467,11 +183,6 @@ public:
bool replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp,
const unsigned SIOpd);
- /// Try to replace instruction \p I with value \p V which are pointers
- /// in different address space.
- /// \return true if successful.
- bool replacePointer(Instruction &I, Value *V);
-
LoadInst *combineLoadToNewType(LoadInst &LI, Type *NewTy,
const Twine &Suffix = "");
@@ -609,10 +320,12 @@ private:
Instruction *narrowBinOp(TruncInst &Trunc);
Instruction *narrowMaskedBinOp(BinaryOperator &And);
Instruction *narrowMathIfNoOverflow(BinaryOperator &I);
- Instruction *narrowRotate(TruncInst &Trunc);
+ Instruction *narrowFunnelShift(TruncInst &Trunc);
Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
Instruction *matchSAddSubSat(SelectInst &MinMax1);
+ void freelyInvertAllUsersOf(Value *V);
+
/// Determine if a pair of casts can be replaced by a single cast.
///
/// \param CI1 The first of a pair of casts.
@@ -653,7 +366,7 @@ public:
"New instruction already inserted into a basic block!");
BasicBlock *BB = Old.getParent();
BB->getInstList().insert(Old.getIterator(), New); // Insert inst
- Worklist.push(New);
+ Worklist.add(New);
return New;
}
@@ -685,6 +398,7 @@ public:
<< " with " << *V << '\n');
I.replaceAllUsesWith(V);
+ MadeIRChange = true;
return &I;
}
@@ -726,7 +440,7 @@ public:
/// When dealing with an instruction that has side effects or produces a void
/// value, we can't rely on DCE to delete the instruction. Instead, visit
/// methods should return the value returned by this function.
- Instruction *eraseInstFromFunction(Instruction &I) {
+ Instruction *eraseInstFromFunction(Instruction &I) override {
LLVM_DEBUG(dbgs() << "IC: ERASE " << I << '\n');
assert(I.use_empty() && "Cannot erase instruction that is used!");
salvageDebugInfo(I);
@@ -808,10 +522,6 @@ public:
Instruction::BinaryOps BinaryOp, bool IsSigned,
Value *LHS, Value *RHS, Instruction *CxtI) const;
- /// Maximum size of array considered when transforming.
- uint64_t MaxArraySizeForCombine = 0;
-
-private:
/// Performs a few simplifications for operators which are associative
/// or commutative.
bool SimplifyAssociativeOrCommutative(BinaryOperator &I);
@@ -857,7 +567,7 @@ private:
unsigned Depth, Instruction *CxtI);
bool SimplifyDemandedBits(Instruction *I, unsigned Op,
const APInt &DemandedMask, KnownBits &Known,
- unsigned Depth = 0);
+ unsigned Depth = 0) override;
/// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne
/// bits. It also tries to handle simplifications that can be done based on
@@ -877,13 +587,10 @@ private:
/// demanded bits.
bool SimplifyDemandedInstructionBits(Instruction &Inst);
- Value *simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II,
- APInt DemandedElts,
- int DmaskIdx = -1);
-
- Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
- APInt &UndefElts, unsigned Depth = 0,
- bool AllowMultipleUsers = false);
+ virtual Value *
+ SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts,
+ unsigned Depth = 0,
+ bool AllowMultipleUsers = false) override;
/// Canonicalize the position of binops relative to shufflevector.
Instruction *foldVectorBinop(BinaryOperator &Inst);
@@ -907,16 +614,18 @@ private:
/// Try to rotate an operation below a PHI node, using PHI nodes for
/// its operands.
- Instruction *FoldPHIArgOpIntoPHI(PHINode &PN);
- Instruction *FoldPHIArgBinOpIntoPHI(PHINode &PN);
- Instruction *FoldPHIArgGEPIntoPHI(PHINode &PN);
- Instruction *FoldPHIArgLoadIntoPHI(PHINode &PN);
- Instruction *FoldPHIArgZextsIntoPHI(PHINode &PN);
+ Instruction *foldPHIArgOpIntoPHI(PHINode &PN);
+ Instruction *foldPHIArgBinOpIntoPHI(PHINode &PN);
+ Instruction *foldPHIArgInsertValueInstructionIntoPHI(PHINode &PN);
+ Instruction *foldPHIArgExtractValueInstructionIntoPHI(PHINode &PN);
+ Instruction *foldPHIArgGEPIntoPHI(PHINode &PN);
+ Instruction *foldPHIArgLoadIntoPHI(PHINode &PN);
+ Instruction *foldPHIArgZextsIntoPHI(PHINode &PN);
/// If an integer typed PHI has only one use which is an IntToPtr operation,
/// replace the PHI with an existing pointer typed PHI if it exists. Otherwise
/// insert a new pointer typed PHI and replace the original one.
- Instruction *FoldIntegerTypedPHI(PHINode &PN);
+ Instruction *foldIntegerTypedPHI(PHINode &PN);
/// Helper function for FoldPHIArgXIntoPHI() to set debug location for the
/// folded operation.
@@ -999,18 +708,18 @@ private:
Value *A, Value *B, Instruction &Outer,
SelectPatternFlavor SPF2, Value *C);
Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI);
-
- Instruction *OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS,
- ConstantInt *AndRHS, BinaryOperator &TheAnd);
+ Instruction *foldSelectValueEquivalence(SelectInst &SI, ICmpInst &ICI);
Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
bool isSigned, bool Inside);
Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI);
bool mergeStoreIntoSuccessor(StoreInst &SI);
- /// Given an 'or' instruction, check to see if it is part of a bswap idiom.
- /// If so, return the equivalent bswap intrinsic.
- Instruction *matchBSwap(BinaryOperator &Or);
+ /// Given an 'or' instruction, check to see if it is part of a
+ /// bswap/bitreverse idiom. If so, return the equivalent bswap/bitreverse
+ /// intrinsic.
+ Instruction *matchBSwapOrBitReverse(BinaryOperator &Or, bool MatchBSwaps,
+ bool MatchBitReversals);
Instruction *SimplifyAnyMemTransfer(AnyMemTransferInst *MI);
Instruction *SimplifyAnyMemSet(AnyMemSetInst *MI);
@@ -1023,18 +732,6 @@ private:
Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap);
};
-namespace {
-
-// As a default, let's assume that we want to be aggressive,
-// and attempt to traverse with no limits in attempt to sink negation.
-static constexpr unsigned NegatorDefaultMaxDepth = ~0U;
-
-// Let's guesstimate that most often we will end up visiting/producing
-// fairly small number of new instructions.
-static constexpr unsigned NegatorMaxNodesSSO = 16;
-
-} // namespace
-
class Negator final {
/// Top-to-bottom, def-to-use negated instruction tree we produced.
SmallVector<Instruction *, NegatorMaxNodesSSO> NewInstructions;
@@ -1061,6 +758,8 @@ class Negator final {
using Result = std::pair<ArrayRef<Instruction *> /*NewInstructions*/,
Value * /*NegatedRoot*/>;
+ std::array<Value *, 2> getSortedOperandsOfBinOp(Instruction *I);
+
LLVM_NODISCARD Value *visitImpl(Value *V, unsigned Depth);
LLVM_NODISCARD Value *negate(Value *V, unsigned Depth);
@@ -1078,7 +777,7 @@ public:
/// Attempt to negate \p Root. Returns nullptr if negation can't be performed,
/// otherwise returns the negated value.
LLVM_NODISCARD static Value *Negate(bool LHSIsZero, Value *Root,
- InstCombiner &IC);
+ InstCombinerImpl &IC);
};
} // end namespace llvm
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index dad2f23120bd..c7b5f6f78069 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -23,6 +23,7 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
@@ -166,7 +167,8 @@ static bool isDereferenceableForAllocaSize(const Value *V, const AllocaInst *AI,
APInt(64, AllocaSize), DL);
}
-static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) {
+static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC,
+ AllocaInst &AI) {
// Check for array size of 1 (scalar allocation).
if (!AI.isArrayAllocation()) {
// i32 1 is the canonical array size for scalar allocations.
@@ -234,47 +236,45 @@ namespace {
// instruction.
class PointerReplacer {
public:
- PointerReplacer(InstCombiner &IC) : IC(IC) {}
+ PointerReplacer(InstCombinerImpl &IC) : IC(IC) {}
+
+ bool collectUsers(Instruction &I);
void replacePointer(Instruction &I, Value *V);
private:
- void findLoadAndReplace(Instruction &I);
void replace(Instruction *I);
Value *getReplacement(Value *I);
- SmallVector<Instruction *, 4> Path;
+ SmallSetVector<Instruction *, 4> Worklist;
MapVector<Value *, Value *> WorkMap;
- InstCombiner &IC;
+ InstCombinerImpl &IC;
};
} // end anonymous namespace
-void PointerReplacer::findLoadAndReplace(Instruction &I) {
+bool PointerReplacer::collectUsers(Instruction &I) {
for (auto U : I.users()) {
- auto *Inst = dyn_cast<Instruction>(&*U);
- if (!Inst)
- return;
- LLVM_DEBUG(dbgs() << "Found pointer user: " << *U << '\n');
- if (isa<LoadInst>(Inst)) {
- for (auto P : Path)
- replace(P);
- replace(Inst);
+ Instruction *Inst = cast<Instruction>(&*U);
+ if (LoadInst *Load = dyn_cast<LoadInst>(Inst)) {
+ if (Load->isVolatile())
+ return false;
+ Worklist.insert(Load);
} else if (isa<GetElementPtrInst>(Inst) || isa<BitCastInst>(Inst)) {
- Path.push_back(Inst);
- findLoadAndReplace(*Inst);
- Path.pop_back();
+ Worklist.insert(Inst);
+ if (!collectUsers(*Inst))
+ return false;
+ } else if (isa<MemTransferInst>(Inst)) {
+ Worklist.insert(Inst);
} else {
- return;
+ LLVM_DEBUG(dbgs() << "Cannot handle pointer user: " << *U << '\n');
+ return false;
}
}
-}
-Value *PointerReplacer::getReplacement(Value *V) {
- auto Loc = WorkMap.find(V);
- if (Loc != WorkMap.end())
- return Loc->second;
- return nullptr;
+ return true;
}
+Value *PointerReplacer::getReplacement(Value *V) { return WorkMap.lookup(V); }
+
void PointerReplacer::replace(Instruction *I) {
if (getReplacement(I))
return;
@@ -282,9 +282,12 @@ void PointerReplacer::replace(Instruction *I) {
if (auto *LT = dyn_cast<LoadInst>(I)) {
auto *V = getReplacement(LT->getPointerOperand());
assert(V && "Operand not replaced");
- auto *NewI = new LoadInst(I->getType(), V, "", false,
- IC.getDataLayout().getABITypeAlign(I->getType()));
+ auto *NewI = new LoadInst(LT->getType(), V, "", LT->isVolatile(),
+ LT->getAlign(), LT->getOrdering(),
+ LT->getSyncScopeID());
NewI->takeName(LT);
+ copyMetadataForLoad(*NewI, *LT);
+
IC.InsertNewInstWith(NewI, *LT);
IC.replaceInstUsesWith(*LT, NewI);
WorkMap[LT] = NewI;
@@ -307,6 +310,28 @@ void PointerReplacer::replace(Instruction *I) {
IC.InsertNewInstWith(NewI, *BC);
NewI->takeName(BC);
WorkMap[BC] = NewI;
+ } else if (auto *MemCpy = dyn_cast<MemTransferInst>(I)) {
+ auto *SrcV = getReplacement(MemCpy->getRawSource());
+ // The pointer may appear in the destination of a copy, but we don't want to
+ // replace it.
+ if (!SrcV) {
+ assert(getReplacement(MemCpy->getRawDest()) &&
+ "destination not in replace list");
+ return;
+ }
+
+ IC.Builder.SetInsertPoint(MemCpy);
+ auto *NewI = IC.Builder.CreateMemTransferInst(
+ MemCpy->getIntrinsicID(), MemCpy->getRawDest(), MemCpy->getDestAlign(),
+ SrcV, MemCpy->getSourceAlign(), MemCpy->getLength(),
+ MemCpy->isVolatile());
+ AAMDNodes AAMD;
+ MemCpy->getAAMetadata(AAMD);
+ if (AAMD)
+ NewI->setAAMetadata(AAMD);
+
+ IC.eraseInstFromFunction(*MemCpy);
+ WorkMap[MemCpy] = NewI;
} else {
llvm_unreachable("should never reach here");
}
@@ -320,10 +345,12 @@ void PointerReplacer::replacePointer(Instruction &I, Value *V) {
"Invalid usage");
#endif
WorkMap[&I] = V;
- findLoadAndReplace(I);
+
+ for (Instruction *Workitem : Worklist)
+ replace(Workitem);
}
-Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) {
+Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
if (auto *I = simplifyAllocaArraySize(*this, AI))
return I;
@@ -374,23 +401,21 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) {
// read.
SmallVector<Instruction *, 4> ToDelete;
if (MemTransferInst *Copy = isOnlyCopiedFromConstantMemory(AA, &AI, ToDelete)) {
+ Value *TheSrc = Copy->getSource();
Align AllocaAlign = AI.getAlign();
Align SourceAlign = getOrEnforceKnownAlignment(
- Copy->getSource(), AllocaAlign, DL, &AI, &AC, &DT);
+ TheSrc, AllocaAlign, DL, &AI, &AC, &DT);
if (AllocaAlign <= SourceAlign &&
- isDereferenceableForAllocaSize(Copy->getSource(), &AI, DL)) {
+ isDereferenceableForAllocaSize(TheSrc, &AI, DL)) {
LLVM_DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n');
LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n');
- for (unsigned i = 0, e = ToDelete.size(); i != e; ++i)
- eraseInstFromFunction(*ToDelete[i]);
- Value *TheSrc = Copy->getSource();
- auto *SrcTy = TheSrc->getType();
- auto *DestTy = PointerType::get(AI.getType()->getPointerElementType(),
- SrcTy->getPointerAddressSpace());
- Value *Cast =
- Builder.CreatePointerBitCastOrAddrSpaceCast(TheSrc, DestTy);
- if (AI.getType()->getPointerAddressSpace() ==
- SrcTy->getPointerAddressSpace()) {
+ unsigned SrcAddrSpace = TheSrc->getType()->getPointerAddressSpace();
+ auto *DestTy = PointerType::get(AI.getAllocatedType(), SrcAddrSpace);
+ if (AI.getType()->getAddressSpace() == SrcAddrSpace) {
+ for (Instruction *Delete : ToDelete)
+ eraseInstFromFunction(*Delete);
+
+ Value *Cast = Builder.CreateBitCast(TheSrc, DestTy);
Instruction *NewI = replaceInstUsesWith(AI, Cast);
eraseInstFromFunction(*Copy);
++NumGlobalCopies;
@@ -398,8 +423,14 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) {
}
PointerReplacer PtrReplacer(*this);
- PtrReplacer.replacePointer(AI, Cast);
- ++NumGlobalCopies;
+ if (PtrReplacer.collectUsers(AI)) {
+ for (Instruction *Delete : ToDelete)
+ eraseInstFromFunction(*Delete);
+
+ Value *Cast = Builder.CreateBitCast(TheSrc, DestTy);
+ PtrReplacer.replacePointer(AI, Cast);
+ ++NumGlobalCopies;
+ }
}
}
@@ -421,9 +452,9 @@ static bool isSupportedAtomicType(Type *Ty) {
/// that pointer type, load it, etc.
///
/// Note that this will create all of the instructions with whatever insert
-/// point the \c InstCombiner currently is using.
-LoadInst *InstCombiner::combineLoadToNewType(LoadInst &LI, Type *NewTy,
- const Twine &Suffix) {
+/// point the \c InstCombinerImpl currently is using.
+LoadInst *InstCombinerImpl::combineLoadToNewType(LoadInst &LI, Type *NewTy,
+ const Twine &Suffix) {
assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) &&
"can't fold an atomic load to requested type");
@@ -445,7 +476,8 @@ LoadInst *InstCombiner::combineLoadToNewType(LoadInst &LI, Type *NewTy,
/// Combine a store to a new type.
///
/// Returns the newly created store instruction.
-static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value *V) {
+static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI,
+ Value *V) {
assert((!SI.isAtomic() || isSupportedAtomicType(V->getType())) &&
"can't fold an atomic store of requested type");
@@ -485,6 +517,7 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value
break;
case LLVMContext::MD_invariant_load:
case LLVMContext::MD_nonnull:
+ case LLVMContext::MD_noundef:
case LLVMContext::MD_range:
case LLVMContext::MD_align:
case LLVMContext::MD_dereferenceable:
@@ -502,7 +535,7 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value
static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) {
assert(V->getType()->isPointerTy() && "Expected pointer type.");
// Ignore possible ty* to ixx* bitcast.
- V = peekThroughBitcast(V);
+ V = InstCombiner::peekThroughBitcast(V);
// Check that select is select ((cmp load V1, load V2), V1, V2) - minmax
// pattern.
CmpInst::Predicate Pred;
@@ -537,7 +570,8 @@ static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) {
/// or a volatile load. This is debatable, and might be reasonable to change
/// later. However, it is risky in case some backend or other part of LLVM is
/// relying on the exact type loaded to select appropriate atomic operations.
-static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) {
+static Instruction *combineLoadToOperationType(InstCombinerImpl &IC,
+ LoadInst &LI) {
// FIXME: We could probably with some care handle both volatile and ordered
// atomic loads here but it isn't clear that this is important.
if (!LI.isUnordered())
@@ -550,62 +584,38 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) {
if (LI.getPointerOperand()->isSwiftError())
return nullptr;
- Type *Ty = LI.getType();
const DataLayout &DL = IC.getDataLayout();
- // Try to canonicalize loads which are only ever stored to operate over
- // integers instead of any other type. We only do this when the loaded type
- // is sized and has a size exactly the same as its store size and the store
- // size is a legal integer type.
- // Do not perform canonicalization if minmax pattern is found (to avoid
- // infinite loop).
- Type *Dummy;
- if (!Ty->isIntegerTy() && Ty->isSized() && !isa<ScalableVectorType>(Ty) &&
- DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) &&
- DL.typeSizeEqualsStoreSize(Ty) && !DL.isNonIntegralPointerType(Ty) &&
- !isMinMaxWithLoads(
- peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true),
- Dummy)) {
- if (all_of(LI.users(), [&LI](User *U) {
- auto *SI = dyn_cast<StoreInst>(U);
- return SI && SI->getPointerOperand() != &LI &&
- !SI->getPointerOperand()->isSwiftError();
- })) {
- LoadInst *NewLoad = IC.combineLoadToNewType(
- LI, Type::getIntNTy(LI.getContext(), DL.getTypeStoreSizeInBits(Ty)));
- // Replace all the stores with stores of the newly loaded value.
- for (auto UI = LI.user_begin(), UE = LI.user_end(); UI != UE;) {
- auto *SI = cast<StoreInst>(*UI++);
- IC.Builder.SetInsertPoint(SI);
- combineStoreToNewValue(IC, *SI, NewLoad);
- IC.eraseInstFromFunction(*SI);
- }
- assert(LI.use_empty() && "Failed to remove all users of the load!");
- // Return the old load so the combiner can delete it safely.
- return &LI;
+ // Fold away bit casts of the loaded value by loading the desired type.
+ // Note that we should not do this for pointer<->integer casts,
+ // because that would result in type punning.
+ if (LI.hasOneUse()) {
+ // Don't transform when the type is x86_amx; it makes the pass that lowers
+ // the x86_amx type happy.
+ if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) {
+ assert(!LI.getType()->isX86_AMXTy() &&
+ "load from x86_amx* should not happen!");
+ if (BC->getType()->isX86_AMXTy())
+ return nullptr;
}
- }
- // Fold away bit casts of the loaded value by loading the desired type.
- // We can do this for BitCastInsts as well as casts from and to pointer types,
- // as long as those are noops (i.e., the source or dest type have the same
- // bitwidth as the target's pointers).
- if (LI.hasOneUse())
if (auto* CI = dyn_cast<CastInst>(LI.user_back()))
- if (CI->isNoopCast(DL))
+ if (CI->isNoopCast(DL) && LI.getType()->isPtrOrPtrVectorTy() ==
+ CI->getDestTy()->isPtrOrPtrVectorTy())
if (!LI.isAtomic() || isSupportedAtomicType(CI->getDestTy())) {
LoadInst *NewLoad = IC.combineLoadToNewType(LI, CI->getDestTy());
CI->replaceAllUsesWith(NewLoad);
IC.eraseInstFromFunction(*CI);
return &LI;
}
+ }
// FIXME: We should also canonicalize loads of vectors when their elements are
// cast to other types.
return nullptr;
}
-static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) {
+static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) {
// FIXME: We could probably with some care handle both volatile and atomic
// stores here but it isn't clear that this is important.
if (!LI.isSimple())
@@ -743,8 +753,7 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize,
}
if (PHINode *PN = dyn_cast<PHINode>(P)) {
- for (Value *IncValue : PN->incoming_values())
- Worklist.push_back(IncValue);
+ append_range(Worklist, PN->incoming_values());
continue;
}
@@ -804,8 +813,9 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize,
// not zero. Currently, we only handle the first such index. Also, we could
// also search through non-zero constant indices if we kept track of the
// offsets those indices implied.
-static bool canReplaceGEPIdxWithZero(InstCombiner &IC, GetElementPtrInst *GEPI,
- Instruction *MemI, unsigned &Idx) {
+static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC,
+ GetElementPtrInst *GEPI, Instruction *MemI,
+ unsigned &Idx) {
if (GEPI->getNumOperands() < 2)
return false;
@@ -834,12 +844,17 @@ static bool canReplaceGEPIdxWithZero(InstCombiner &IC, GetElementPtrInst *GEPI,
return false;
SmallVector<Value *, 4> Ops(GEPI->idx_begin(), GEPI->idx_begin() + Idx);
- Type *AllocTy =
- GetElementPtrInst::getIndexedType(GEPI->getSourceElementType(), Ops);
+ Type *SourceElementType = GEPI->getSourceElementType();
+ // Size information about scalable vectors is not available, so we cannot
+ // deduce whether indexing at n is undefined behaviour or not. Bail out.
+ if (isa<ScalableVectorType>(SourceElementType))
+ return false;
+
+ Type *AllocTy = GetElementPtrInst::getIndexedType(SourceElementType, Ops);
if (!AllocTy || !AllocTy->isSized())
return false;
const DataLayout &DL = IC.getDataLayout();
- uint64_t TyAllocSize = DL.getTypeAllocSize(AllocTy);
+ uint64_t TyAllocSize = DL.getTypeAllocSize(AllocTy).getFixedSize();
// If there are more indices after the one we might replace with a zero, make
// sure they're all non-negative. If any of them are negative, the overall
@@ -874,7 +889,7 @@ static bool canReplaceGEPIdxWithZero(InstCombiner &IC, GetElementPtrInst *GEPI,
// access, but the object has only one element, we can assume that the index
// will always be zero. If we replace the GEP, return it.
template <typename T>
-static Instruction *replaceGEPIdxWithZero(InstCombiner &IC, Value *Ptr,
+static Instruction *replaceGEPIdxWithZero(InstCombinerImpl &IC, Value *Ptr,
T &MemI) {
if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Ptr)) {
unsigned Idx;
@@ -916,7 +931,7 @@ static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) {
return false;
}
-Instruction *InstCombiner::visitLoadInst(LoadInst &LI) {
+Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
Value *Op = LI.getOperand(0);
// Try to canonicalize the loaded type.
@@ -1033,7 +1048,7 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) {
/// and the layout of a <2 x double> is isomorphic to a [2 x double],
/// then %V1 can be safely approximated by a conceptual "bitcast" of %U.
/// Note that %U may contain non-undef values where %V1 has undef.
-static Value *likeBitCastFromVector(InstCombiner &IC, Value *V) {
+static Value *likeBitCastFromVector(InstCombinerImpl &IC, Value *V) {
Value *U = nullptr;
while (auto *IV = dyn_cast<InsertValueInst>(V)) {
auto *E = dyn_cast<ExtractElementInst>(IV->getInsertedValueOperand());
@@ -1060,11 +1075,11 @@ static Value *likeBitCastFromVector(InstCombiner &IC, Value *V) {
return nullptr;
}
if (auto *AT = dyn_cast<ArrayType>(VT)) {
- if (AT->getNumElements() != UT->getNumElements())
+ if (AT->getNumElements() != cast<FixedVectorType>(UT)->getNumElements())
return nullptr;
} else {
auto *ST = cast<StructType>(VT);
- if (ST->getNumElements() != UT->getNumElements())
+ if (ST->getNumElements() != cast<FixedVectorType>(UT)->getNumElements())
return nullptr;
for (const auto *EltT : ST->elements()) {
if (EltT != UT->getElementType())
@@ -1094,7 +1109,7 @@ static Value *likeBitCastFromVector(InstCombiner &IC, Value *V) {
/// the caller must erase the store instruction. We have to let the caller erase
/// the store instruction as otherwise there is no way to signal whether it was
/// combined or not: IC.EraseInstFromFunction returns a null pointer.
-static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) {
+static bool combineStoreToValueType(InstCombinerImpl &IC, StoreInst &SI) {
// FIXME: We could probably with some care handle both volatile and ordered
// atomic stores here but it isn't clear that this is important.
if (!SI.isUnordered())
@@ -1108,7 +1123,13 @@ static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) {
// Fold away bit casts of the stored value by storing the original type.
if (auto *BC = dyn_cast<BitCastInst>(V)) {
+ assert(!BC->getType()->isX86_AMXTy() &&
+ "store to x86_amx* should not happen!");
V = BC->getOperand(0);
+ // Don't transform when the type is x86_amx; it makes the pass that lowers
+ // the x86_amx type happy.
+ if (V->getType()->isX86_AMXTy())
+ return false;
if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) {
combineStoreToNewValue(IC, SI, V);
return true;
@@ -1126,7 +1147,7 @@ static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) {
return false;
}
-static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) {
+static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) {
// FIXME: We could probably with some care handle both volatile and atomic
// stores here but it isn't clear that this is important.
if (!SI.isSimple())
@@ -1266,7 +1287,7 @@ static bool equivalentAddressValues(Value *A, Value *B) {
/// Converts store (bitcast (load (bitcast (select ...)))) to
/// store (load (select ...)), where select is minmax:
/// select ((cmp load V1, load V2), V1, V2).
-static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC,
+static bool removeBitcastsFromLoadStoreOnMinMax(InstCombinerImpl &IC,
StoreInst &SI) {
// bitcast?
if (!match(SI.getPointerOperand(), m_BitCast(m_Value())))
@@ -1296,7 +1317,8 @@ static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC,
if (!all_of(LI->users(), [LI, LoadAddr](User *U) {
auto *SI = dyn_cast<StoreInst>(U);
return SI && SI->getPointerOperand() != LI &&
- peekThroughBitcast(SI->getPointerOperand()) != LoadAddr &&
+ InstCombiner::peekThroughBitcast(SI->getPointerOperand()) !=
+ LoadAddr &&
!SI->getPointerOperand()->isSwiftError();
}))
return false;
@@ -1314,7 +1336,7 @@ static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC,
return true;
}
-Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
+Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
Value *Val = SI.getOperand(0);
Value *Ptr = SI.getOperand(1);
@@ -1433,7 +1455,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
/// or:
/// *P = v1; if () { *P = v2; }
/// into a phi node with a store in the successor.
-bool InstCombiner::mergeStoreIntoSuccessor(StoreInst &SI) {
+bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) {
if (!SI.isUnordered())
return false; // This code has not been audited for volatile/ordered case.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index c6233a68847d..4b485a0ad85e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -32,6 +32,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include <cassert>
#include <cstddef>
@@ -46,7 +47,7 @@ using namespace PatternMatch;
/// The specific integer value is used in a context where it is known to be
/// non-zero. If this allows us to simplify the computation, do so and return
/// the new operand, otherwise return null.
-static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC,
+static Value *simplifyValueKnownNonZero(Value *V, InstCombinerImpl &IC,
Instruction &CxtI) {
// If V has multiple uses, then we would have to do more analysis to determine
// if this is safe. For example, the use could be in dynamically unreached
@@ -94,39 +95,6 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC,
return MadeChange ? V : nullptr;
}
-/// A helper routine of InstCombiner::visitMul().
-///
-/// If C is a scalar/fixed width vector of known powers of 2, then this
-/// function returns a new scalar/fixed width vector obtained from logBase2
-/// of C.
-/// Return a null pointer otherwise.
-static Constant *getLogBase2(Type *Ty, Constant *C) {
- const APInt *IVal;
- if (match(C, m_APInt(IVal)) && IVal->isPowerOf2())
- return ConstantInt::get(Ty, IVal->logBase2());
-
- // FIXME: We can extract pow of 2 of splat constant for scalable vectors.
- if (!isa<FixedVectorType>(Ty))
- return nullptr;
-
- SmallVector<Constant *, 4> Elts;
- for (unsigned I = 0, E = cast<FixedVectorType>(Ty)->getNumElements(); I != E;
- ++I) {
- Constant *Elt = C->getAggregateElement(I);
- if (!Elt)
- return nullptr;
- if (isa<UndefValue>(Elt)) {
- Elts.push_back(UndefValue::get(Ty->getScalarType()));
- continue;
- }
- if (!match(Elt, m_APInt(IVal)) || !IVal->isPowerOf2())
- return nullptr;
- Elts.push_back(ConstantInt::get(Ty->getScalarType(), IVal->logBase2()));
- }
-
- return ConstantVector::get(Elts);
-}
-
// TODO: This is a specific form of a much more general pattern.
// We could detect a select with any binop identity constant, or we
// could use SimplifyBinOp to see if either arm of the select reduces.
@@ -171,7 +139,7 @@ static Value *foldMulSelectToNegate(BinaryOperator &I,
return nullptr;
}
-Instruction *InstCombiner::visitMul(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Value *V = SimplifyMulInst(I.getOperand(0), I.getOperand(1),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -185,8 +153,10 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
if (Value *V = SimplifyUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
- // X * -1 == 0 - X
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ unsigned BitWidth = I.getType()->getScalarSizeInBits();
+
+ // X * -1 == 0 - X
if (match(Op1, m_AllOnes())) {
BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName());
if (I.hasNoSignedWrap())
@@ -216,7 +186,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) {
// Replace X*(2^C) with X << C, where C is either a scalar or a vector.
- if (Constant *NewCst = getLogBase2(NewOp->getType(), C1)) {
+ if (Constant *NewCst = ConstantExpr::getExactLogBase2(C1)) {
BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst);
if (I.hasNoUnsignedWrap())
@@ -232,29 +202,12 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
}
}
- if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
- // (Y - X) * (-(2**n)) -> (X - Y) * (2**n), for positive nonzero n
- // (Y + const) * (-(2**n)) -> (-constY) * (2**n), for positive nonzero n
- // The "* (2**n)" thus becomes a potential shifting opportunity.
- {
- const APInt & Val = CI->getValue();
- const APInt &PosVal = Val.abs();
- if (Val.isNegative() && PosVal.isPowerOf2()) {
- Value *X = nullptr, *Y = nullptr;
- if (Op0->hasOneUse()) {
- ConstantInt *C1;
- Value *Sub = nullptr;
- if (match(Op0, m_Sub(m_Value(Y), m_Value(X))))
- Sub = Builder.CreateSub(X, Y, "suba");
- else if (match(Op0, m_Add(m_Value(Y), m_ConstantInt(C1))))
- Sub = Builder.CreateSub(Builder.CreateNeg(C1), Y, "subc");
- if (Sub)
- return
- BinaryOperator::CreateMul(Sub,
- ConstantInt::get(Y->getType(), PosVal));
- }
- }
- }
+ if (Op0->hasOneUse() && match(Op1, m_NegatedPower2())) {
+ // Interpret X * (-1<<C) as (-X) * (1<<C) and try to sink the negation.
+ // The "* (1<<C)" thus becomes a potential shifting opportunity.
+ if (Value *NegOp0 = Negator::Negate(/*IsNegation*/ true, Op0, *this))
+ return BinaryOperator::CreateMul(
+ NegOp0, ConstantExpr::getNeg(cast<Constant>(Op1)), I.getName());
}
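
The arithmetic identity this negated-power-of-two fold relies on can be checked with a minimal standalone C++ sketch (not part of the LLVM change; it assumes 32-bit two's-complement integers and uses an example shift amount c = 3):

    #include <cassert>
    #include <cstdint>

    int main() {
      const unsigned c = 3;                          // example shift amount
      for (int32_t x = -100; x <= 100; ++x) {
        int32_t viaMul = x * -(int32_t(1) << c);     // X * (-1 << C)
        int32_t viaShl = (-x) * (int32_t(1) << c);   // (-X) * (1 << C), i.e. (-X) << C
        assert(viaMul == viaShl);                    // the negation can be sunk into X
      }
      return 0;
    }
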
if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I))
@@ -284,6 +237,9 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor;
if (SPF == SPF_ABS || SPF == SPF_NABS)
return BinaryOperator::CreateMul(X, X);
+
+ if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X))))
+ return BinaryOperator::CreateMul(X, X);
}
// -X * C --> X * -C
@@ -406,6 +362,19 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
if (match(Op1, m_LShr(m_Value(X), m_APInt(C))) && *C == C->getBitWidth() - 1)
return BinaryOperator::CreateAnd(Builder.CreateAShr(X, *C), Op0);
+ // ((ashr X, 31) | 1) * X --> abs(X)
+ // X * ((ashr X, 31) | 1) --> abs(X)
+ if (match(&I, m_c_BinOp(m_Or(m_AShr(m_Value(X),
+ m_SpecificIntAllowUndef(BitWidth - 1)),
+ m_One()),
+ m_Deferred(X)))) {
+ Value *Abs = Builder.CreateBinaryIntrinsic(
+ Intrinsic::abs, X,
+ ConstantInt::getBool(I.getContext(), I.hasNoSignedWrap()));
+ Abs->takeName(&I);
+ return replaceInstUsesWith(I, Abs);
+ }
+
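
A small standalone C++ check of the identity behind this fold (illustration only, assuming '>>' of a negative int32_t is an arithmetic shift, which C++20 guarantees and mainstream compilers already provide):

    #include <cassert>
    #include <cstdint>
    #include <cstdlib>

    int main() {
      for (int32_t x = -1000; x <= 1000; ++x) {
        int32_t sign_or_one = (x >> 31) | 1;   // -1 for negative x, +1 otherwise
        assert(sign_or_one * x == std::abs(x));
      }
      return 0;
    }
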
if (Instruction *Ext = narrowMathIfNoOverflow(I))
return Ext;
@@ -423,7 +392,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
return Changed ? &I : nullptr;
}
-Instruction *InstCombiner::foldFPSignBitOps(BinaryOperator &I) {
+Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) {
BinaryOperator::BinaryOps Opcode = I.getOpcode();
assert((Opcode == Instruction::FMul || Opcode == Instruction::FDiv) &&
"Expected fmul or fdiv");
@@ -438,13 +407,12 @@ Instruction *InstCombiner::foldFPSignBitOps(BinaryOperator &I) {
// fabs(X) * fabs(X) -> X * X
// fabs(X) / fabs(X) -> X / X
- if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))))
+ if (Op0 == Op1 && match(Op0, m_FAbs(m_Value(X))))
return BinaryOperator::CreateWithCopiedFlags(Opcode, X, X, &I);
// fabs(X) * fabs(Y) --> fabs(X * Y)
// fabs(X) / fabs(Y) --> fabs(X / Y)
- if (match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))) &&
- match(Op1, m_Intrinsic<Intrinsic::fabs>(m_Value(Y))) &&
+ if (match(Op0, m_FAbs(m_Value(X))) && match(Op1, m_FAbs(m_Value(Y))) &&
(Op0->hasOneUse() || Op1->hasOneUse())) {
IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
Builder.setFastMathFlags(I.getFastMathFlags());
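
The fabs folds above rest on the sign-symmetry of IEEE multiply and divide; a minimal standalone C++ sketch of the identity (not LLVM code, finite non-NaN inputs assumed):

    #include <cassert>
    #include <cmath>

    int main() {
      double xs[] = {-3.5, -0.25, 1.0, 7.75};
      for (double x : xs)
        for (double y : xs) {
          assert(std::fabs(x) * std::fabs(y) == std::fabs(x * y));  // fabs(X)*fabs(Y) --> fabs(X*Y)
          assert(std::fabs(x) / std::fabs(y) == std::fabs(x / y));  // fabs(X)/fabs(Y) --> fabs(X/Y)
        }
      return 0;
    }
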
@@ -457,7 +425,7 @@ Instruction *InstCombiner::foldFPSignBitOps(BinaryOperator &I) {
return nullptr;
}
-Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
if (Value *V = SimplifyFMulInst(I.getOperand(0), I.getOperand(1),
I.getFastMathFlags(),
SQ.getWithInstruction(&I)))
@@ -553,6 +521,21 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
return replaceInstUsesWith(I, Sqrt);
}
+ // The following transforms are done irrespective of the number of uses
+ // for the expression "1.0/sqrt(X)".
+ // 1) 1.0/sqrt(X) * X -> X/sqrt(X)
+ // 2) X * 1.0/sqrt(X) -> X/sqrt(X)
+ // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it
+ // has the necessary (reassoc) fast-math-flags.
+ if (I.hasNoSignedZeros() &&
+ match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
+ match(Y, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && Op1 == X)
+ return BinaryOperator::CreateFDivFMF(X, Y, &I);
+ if (I.hasNoSignedZeros() &&
+ match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
+ match(Y, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && Op0 == X)
+ return BinaryOperator::CreateFDivFMF(X, Y, &I);
+
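
A numeric illustration of the 1.0/sqrt(X) rewrite above (a standalone C++ sketch, not LLVM code; the three forms agree mathematically but may differ by a rounding step, which is why the fold is gated on fast-math flags):

    #include <cassert>
    #include <cmath>

    int main() {
      for (double x : {0.5, 2.0, 10.0, 12345.678}) {
        double a = x * (1.0 / std::sqrt(x)); // original form
        double b = x / std::sqrt(x);         // form produced by the fold
        assert(std::fabs(a - b) <= 1e-12 * b);            // differ by at most a few ulps
        assert(std::fabs(b - std::sqrt(x)) <= 1e-12 * b); // backend is expected to form sqrt(X)
      }
      return 0;
    }
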
// Like the similar transform in instsimplify, this requires 'nsz' because
// sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 &&
@@ -637,7 +620,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
/// Fold a divide or remainder with a select instruction divisor when one of the
/// select operands is zero. In that case, we can use the other select operand
/// because div/rem by zero is undefined.
-bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) {
+bool InstCombinerImpl::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) {
SelectInst *SI = dyn_cast<SelectInst>(I.getOperand(1));
if (!SI)
return false;
@@ -738,7 +721,7 @@ static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient,
/// instructions (udiv and sdiv). It is called by the visitors to those integer
/// division instructions.
/// Common integer divide transforms
-Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
+Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
bool IsSigned = I.getOpcode() == Instruction::SDiv;
Type *Ty = I.getType();
@@ -874,7 +857,7 @@ namespace {
using FoldUDivOperandCb = Instruction *(*)(Value *Op0, Value *Op1,
const BinaryOperator &I,
- InstCombiner &IC);
+ InstCombinerImpl &IC);
/// Used to maintain state for visitUDivOperand().
struct UDivFoldAction {
@@ -903,8 +886,9 @@ struct UDivFoldAction {
// X udiv 2^C -> X >> C
static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1,
- const BinaryOperator &I, InstCombiner &IC) {
- Constant *C1 = getLogBase2(Op0->getType(), cast<Constant>(Op1));
+ const BinaryOperator &I,
+ InstCombinerImpl &IC) {
+ Constant *C1 = ConstantExpr::getExactLogBase2(cast<Constant>(Op1));
if (!C1)
llvm_unreachable("Failed to constant fold udiv -> logbase2");
BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, C1);
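
The X udiv 2^C fold is the usual power-of-two strength reduction; a minimal standalone C++ check (not LLVM code):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t x = 0; x < 4096; ++x)
        for (unsigned c = 0; c < 12; ++c)
          assert(x / (uint32_t(1) << c) == x >> c);   // X udiv 2^C  ==  X u>> C
      return 0;
    }
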
@@ -916,7 +900,7 @@ static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1,
// X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2)
// X udiv (zext (C1 << N)), where C1 is "1<<C2" --> X >> (N+C2)
static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
Value *ShiftLeft;
if (!match(Op1, m_ZExt(m_Value(ShiftLeft))))
ShiftLeft = Op1;
@@ -925,7 +909,7 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I,
Value *N;
if (!match(ShiftLeft, m_Shl(m_Constant(CI), m_Value(N))))
llvm_unreachable("match should never fail here!");
- Constant *Log2Base = getLogBase2(N->getType(), CI);
+ Constant *Log2Base = ConstantExpr::getExactLogBase2(CI);
if (!Log2Base)
llvm_unreachable("getLogBase2 should never fail here!");
N = IC.Builder.CreateAdd(N, Log2Base);
@@ -944,6 +928,8 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I,
static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I,
SmallVectorImpl<UDivFoldAction> &Actions,
unsigned Depth = 0) {
+ // FIXME: assert that Op1 isn't/doesn't contain undef.
+
// Check to see if this is an unsigned division with an exact power of 2,
// if so, convert to a right shift.
if (match(Op1, m_Power2())) {
@@ -963,6 +949,9 @@ static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I,
return 0;
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
+ // FIXME: missed optimization: if one of the hands of select is/contains
+ // undef, just directly pick the other one.
+ // FIXME: can both hands contain undef?
if (size_t LHSIdx =
visitUDivOperand(Op0, SI->getOperand(1), I, Actions, Depth))
if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions, Depth)) {
@@ -1010,7 +999,7 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
return nullptr;
}
-Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
if (Value *V = SimplifyUDivInst(I.getOperand(0), I.getOperand(1),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -1104,7 +1093,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
return nullptr;
}
-Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
if (Value *V = SimplifySDivInst(I.getOperand(0), I.getOperand(1),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -1117,6 +1106,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
return Common;
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ Type *Ty = I.getType();
Value *X;
// sdiv Op0, -1 --> -Op0
// sdiv Op0, (sext i1 X) --> -Op0 (because if X is 0, the op is undefined)
@@ -1126,16 +1116,26 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
// X / INT_MIN --> X == INT_MIN
if (match(Op1, m_SignMask()))
- return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), I.getType());
+ return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), Ty);
+
+ // sdiv exact X, 1<<C --> ashr exact X, C iff 1<<C is non-negative
+ // sdiv exact X, -1<<C --> -(ashr exact X, C)
+ if (I.isExact() && ((match(Op1, m_Power2()) && match(Op1, m_NonNegative())) ||
+ match(Op1, m_NegatedPower2()))) {
+ bool DivisorWasNegative = match(Op1, m_NegatedPower2());
+ if (DivisorWasNegative)
+ Op1 = ConstantExpr::getNeg(cast<Constant>(Op1));
+ auto *AShr = BinaryOperator::CreateExactAShr(
+ Op0, ConstantExpr::getExactLogBase2(cast<Constant>(Op1)), I.getName());
+ if (!DivisorWasNegative)
+ return AShr;
+ Builder.Insert(AShr);
+ AShr->setName(I.getName() + ".neg");
+ return BinaryOperator::CreateNeg(AShr, I.getName());
+ }
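
A standalone C++ sketch of the exact-sdiv fold above (not LLVM code; 'sdiv exact' promises no remainder, so the dividend is built as a multiple of 2^C, and an arithmetic '>>' on negative values is assumed):

    #include <cassert>
    #include <cstdint>

    int main() {
      const unsigned c = 4;                           // example log2 of the divisor
      for (int32_t m = -50; m <= 50; ++m) {
        int32_t x = m * (int32_t(1) << c);            // exact multiple of 2^c
        assert(x / (int32_t(1) << c) == (x >> c));    // sdiv exact X, 1<<C  --> ashr exact X, C
        assert(x / -(int32_t(1) << c) == -(x >> c));  // sdiv exact X, -1<<C --> -(ashr exact X, C)
      }
      return 0;
    }
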
const APInt *Op1C;
if (match(Op1, m_APInt(Op1C))) {
- // sdiv exact X, C --> ashr exact X, log2(C)
- if (I.isExact() && Op1C->isNonNegative() && Op1C->isPowerOf2()) {
- Value *ShAmt = ConstantInt::get(Op1->getType(), Op1C->exactLogBase2());
- return BinaryOperator::CreateExactAShr(Op0, ShAmt, I.getName());
- }
-
// If the dividend is sign-extended and the constant divisor is small enough
// to fit in the source type, shrink the division to the narrower type:
// (sext X) sdiv C --> sext (X sdiv C)
@@ -1150,7 +1150,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
Constant *NarrowDivisor =
ConstantExpr::getTrunc(cast<Constant>(Op1), Op0Src->getType());
Value *NarrowOp = Builder.CreateSDiv(Op0Src, NarrowDivisor);
- return new SExtInst(NarrowOp, Op0->getType());
+ return new SExtInst(NarrowOp, Ty);
}
// -X / C --> X / -C (if the negation doesn't overflow).
@@ -1158,7 +1158,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
// checking if all elements are not the min-signed-val.
if (!Op1C->isMinSignedValue() &&
match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) {
- Constant *NegC = ConstantInt::get(I.getType(), -(*Op1C));
+ Constant *NegC = ConstantInt::get(Ty, -(*Op1C));
Instruction *BO = BinaryOperator::CreateSDiv(X, NegC);
BO->setIsExact(I.isExact());
return BO;
@@ -1171,9 +1171,19 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
return BinaryOperator::CreateNSWNeg(
Builder.CreateSDiv(X, Y, I.getName(), I.isExact()));
+ // abs(X) / X --> X > -1 ? 1 : -1
+ // X / abs(X) --> X > -1 ? 1 : -1
+ if (match(&I, m_c_BinOp(
+ m_OneUse(m_Intrinsic<Intrinsic::abs>(m_Value(X), m_One())),
+ m_Deferred(X)))) {
+ Constant *NegOne = ConstantInt::getAllOnesValue(Ty);
+ Value *Cond = Builder.CreateICmpSGT(X, NegOne);
+ return SelectInst::Create(Cond, ConstantInt::get(Ty, 1), NegOne);
+ }
+
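
The abs(X)/X fold reduces to a sign select; a minimal standalone C++ check (not LLVM code; X == 0 and INT_MIN are excluded, matching the division-by-zero and abs-overflow caveats the intrinsic form carries):

    #include <cassert>
    #include <cstdint>
    #include <cstdlib>

    int main() {
      for (int32_t x = -100; x <= 100; ++x) {
        if (x == 0)
          continue;
        assert(std::abs(x) / x == (x > -1 ? 1 : -1));
        assert(x / std::abs(x) == (x > -1 ? 1 : -1));
      }
      return 0;
    }
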
// If the sign bits of both operands are zero (i.e. we can prove they are
// unsigned inputs), turn this into a udiv.
- APInt Mask(APInt::getSignMask(I.getType()->getScalarSizeInBits()));
+ APInt Mask(APInt::getSignMask(Ty->getScalarSizeInBits()));
if (MaskedValueIsZero(Op0, Mask, 0, &I)) {
if (MaskedValueIsZero(Op1, Mask, 0, &I)) {
// X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set
@@ -1182,6 +1192,13 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
return BO;
}
+ if (match(Op1, m_NegatedPower2())) {
+ // X sdiv (-(1 << C)) -> -(X sdiv (1 << C)) ->
+ // -> -(X udiv (1 << C)) -> -(X u>> C)
+ return BinaryOperator::CreateNeg(Builder.Insert(foldUDivPow2Cst(
+ Op0, ConstantExpr::getNeg(cast<Constant>(Op1)), I, *this)));
+ }
+
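
A standalone C++ sketch of the negated-power-of-two sdiv fold above (not LLVM code; only non-negative dividends are tested, mirroring the MaskedValueIsZero sign-bit requirement, and C++ integer division truncates toward zero just like sdiv):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int32_t x = 0; x <= 1000; ++x)
        for (unsigned c = 0; c < 8; ++c)
          assert(x / -(int32_t(1) << c) == -int32_t(uint32_t(x) >> c));  // -(X u>> C)
      return 0;
    }
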
if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) {
// X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y)
// Safe because the only negative value (1 << Y) can take on is
@@ -1258,7 +1275,7 @@ static Instruction *foldFDivConstantDividend(BinaryOperator &I) {
return BinaryOperator::CreateFDivFMF(NewC, X, &I);
}
-Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
if (Value *V = SimplifyFDivInst(I.getOperand(0), I.getOperand(1),
I.getFastMathFlags(),
SQ.getWithInstruction(&I)))
@@ -1350,10 +1367,8 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
// X / fabs(X) -> copysign(1.0, X)
// fabs(X) / X -> copysign(1.0, X)
if (I.hasNoNaNs() && I.hasNoInfs() &&
- (match(&I,
- m_FDiv(m_Value(X), m_Intrinsic<Intrinsic::fabs>(m_Deferred(X)))) ||
- match(&I, m_FDiv(m_Intrinsic<Intrinsic::fabs>(m_Value(X)),
- m_Deferred(X))))) {
+ (match(&I, m_FDiv(m_Value(X), m_FAbs(m_Deferred(X)))) ||
+ match(&I, m_FDiv(m_FAbs(m_Value(X)), m_Deferred(X))))) {
Value *V = Builder.CreateBinaryIntrinsic(
Intrinsic::copysign, ConstantFP::get(I.getType(), 1.0), X, &I);
return replaceInstUsesWith(I, V);
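
A minimal standalone C++ check of the copysign identity used above (not LLVM code; finite, non-zero, non-NaN inputs assumed, mirroring the nnan/ninf guard):

    #include <cassert>
    #include <cmath>

    int main() {
      for (double x : {-8.0, -0.5, 0.25, 42.0}) {
        assert(x / std::fabs(x) == std::copysign(1.0, x));  // X / fabs(X) -> copysign(1.0, X)
        assert(std::fabs(x) / x == std::copysign(1.0, x));  // fabs(X) / X -> copysign(1.0, X)
      }
      return 0;
    }
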
@@ -1365,7 +1380,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
/// instructions (urem and srem). It is called by the visitors to those integer
/// remainder instructions.
/// Common integer remainder transforms
-Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {
+Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
// The RHS is known non-zero.
@@ -1403,7 +1418,7 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {
return nullptr;
}
-Instruction *InstCombiner::visitURem(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) {
if (Value *V = SimplifyURemInst(I.getOperand(0), I.getOperand(1),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -1454,7 +1469,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) {
return nullptr;
}
-Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitSRem(BinaryOperator &I) {
if (Value *V = SimplifySRemInst(I.getOperand(0), I.getOperand(1),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -1477,7 +1492,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
// -X srem Y --> -(X srem Y)
Value *X, *Y;
if (match(&I, m_SRem(m_OneUse(m_NSWSub(m_Zero(), m_Value(X))), m_Value(Y))))
- return BinaryOperator::CreateNSWNeg(Builder.CreateSRem(X, Y));
+ return BinaryOperator::CreateNSWNeg(Builder.CreateSRem(X, Y));
// If the sign bits of both operands are zero (i.e. we can prove they are
// unsigned inputs), turn this into a urem.
@@ -1491,7 +1506,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
// If it's a constant vector, flip any negative values positive.
if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) {
Constant *C = cast<Constant>(Op1);
- unsigned VWidth = cast<VectorType>(C->getType())->getNumElements();
+ unsigned VWidth = cast<FixedVectorType>(C->getType())->getNumElements();
bool hasNegative = false;
bool hasMissing = false;
@@ -1526,7 +1541,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
return nullptr;
}
-Instruction *InstCombiner::visitFRem(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitFRem(BinaryOperator &I) {
if (Value *V = SimplifyFRemInst(I.getOperand(0), I.getOperand(1),
I.getFastMathFlags(),
SQ.getWithInstruction(&I)))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
index 3fe615ac5439..7718c8b0eedd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
@@ -42,6 +42,9 @@
#include "llvm/Support/DebugCounter.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
+#include <cassert>
+#include <cstdint>
#include <functional>
#include <tuple>
#include <type_traits>
@@ -112,6 +115,19 @@ Negator::~Negator() {
}
#endif
+// Due to InstCombine's worklist management, there is no guarantee that each
+// instruction we encounter has already been visited by InstCombine.
+// Most importantly for us, that means we have to canonicalize constants to
+// the RHS ourselves, since that is sometimes helpful.
+std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
+ assert(I->getNumOperands() == 2 && "Only for binops!");
+ std::array<Value *, 2> Ops{I->getOperand(0), I->getOperand(1)};
+ if (I->isCommutative() && InstCombiner::getComplexity(I->getOperand(0)) <
+ InstCombiner::getComplexity(I->getOperand(1)))
+ std::swap(Ops[0], Ops[1]);
+ return Ops;
+}
+
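
A standalone sketch of the same constants-to-RHS idea in plain C++ (illustration only; the complexity() ranking and Operand type here are hypothetical stand-ins, not the real InstCombiner::getComplexity or llvm::Value):

    #include <array>
    #include <string>
    #include <utility>

    // Hypothetical stand-in for a complexity ranking: constants rank lowest,
    // so for a commutative op they end up on the RHS after sorting.
    struct Operand { std::string name; bool isConstant; };
    static int complexity(const Operand &Op) { return Op.isConstant ? 0 : 1; }

    static std::array<Operand, 2> sortOperands(Operand L, Operand R,
                                               bool isCommutative) {
      std::array<Operand, 2> Ops{std::move(L), std::move(R)};
      if (isCommutative && complexity(Ops[0]) < complexity(Ops[1]))
        std::swap(Ops[0], Ops[1]);   // put the "simpler" operand (the constant) last
      return Ops;
    }

    int main() {
      auto Ops = sortOperands({"42", /*isConstant=*/true}, {"%x", false},
                              /*isCommutative=*/true);
      return Ops[1].isConstant ? 0 : 1;   // the constant ends up on the RHS
    }
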
// FIXME: can this be reworked into a worklist-based algorithm while preserving
// the depth-first, early bailout traversal?
LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
@@ -156,11 +172,13 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
// In some cases we can give the answer without further recursion.
switch (I->getOpcode()) {
- case Instruction::Add:
+ case Instruction::Add: {
+ std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I);
// `inc` is always negatible.
- if (match(I->getOperand(1), m_One()))
- return Builder.CreateNot(I->getOperand(0), I->getName() + ".neg");
+ if (match(Ops[1], m_One()))
+ return Builder.CreateNot(Ops[0], I->getName() + ".neg");
break;
+ }
case Instruction::Xor:
// `not` is always negatible.
if (match(I, m_Not(m_Value(X))))
@@ -181,6 +199,10 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
}
return BO;
}
+ // While we could negate exact arithmetic shift:
+ // ashr exact %x, C --> sdiv exact i8 %x, -1<<C
+ // iff C != 0 and C u< bitwidth(%x), we don't want to,
+ // because division is *THAT* much worse than a shift.
break;
}
case Instruction::SExt:
@@ -197,26 +219,28 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
break; // Other instructions require recursive reasoning.
}
+ if (I->getOpcode() == Instruction::Sub &&
+ (I->hasOneUse() || match(I->getOperand(0), m_ImmConstant()))) {
+ // `sub` is always negatible.
+ // However, only do this either if the old `sub` doesn't stick around, or
+ // it was subtracting from a constant. Otherwise, this isn't profitable.
+ return Builder.CreateSub(I->getOperand(1), I->getOperand(0),
+ I->getName() + ".neg");
+ }
+
// Some other cases, while still don't require recursion,
// are restricted to the one-use case.
if (!V->hasOneUse())
return nullptr;
switch (I->getOpcode()) {
- case Instruction::Sub:
- // `sub` is always negatible.
- // But if the old `sub` sticks around, even thought we don't increase
- // instruction count, this is a likely regression since we increased
- // live-range of *both* of the operands, which might lead to more spilling.
- return Builder.CreateSub(I->getOperand(1), I->getOperand(0),
- I->getName() + ".neg");
case Instruction::SDiv:
// `sdiv` is negatible if divisor is not undef/INT_MIN/1.
// While this is normally not behind a use-check,
// let's consider division to be special since it's costly.
if (auto *Op1C = dyn_cast<Constant>(I->getOperand(1))) {
- if (!Op1C->containsUndefElement() && Op1C->isNotMinSignedValue() &&
- Op1C->isNotOneValue()) {
+ if (!Op1C->containsUndefOrPoisonElement() &&
+ Op1C->isNotMinSignedValue() && Op1C->isNotOneValue()) {
Value *BO =
Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(Op1C),
I->getName() + ".neg");
@@ -237,6 +261,13 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
}
switch (I->getOpcode()) {
+ case Instruction::Freeze: {
+ // `freeze` is negatible if its operand is negatible.
+ Value *NegOp = negate(I->getOperand(0), Depth + 1);
+ if (!NegOp) // Early return.
+ return nullptr;
+ return Builder.CreateFreeze(NegOp, I->getName() + ".neg");
+ }
case Instruction::PHI: {
// `phi` is negatible if all the incoming values are negatible.
auto *PHI = cast<PHINode>(I);
@@ -254,20 +285,16 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
return NegatedPHI;
}
case Instruction::Select: {
- {
- // `abs`/`nabs` is always negatible.
- Value *LHS, *RHS;
- SelectPatternFlavor SPF =
- matchSelectPattern(I, LHS, RHS, /*CastOp=*/nullptr, Depth).Flavor;
- if (SPF == SPF_ABS || SPF == SPF_NABS) {
- auto *NewSelect = cast<SelectInst>(I->clone());
- // Just swap the operands of the select.
- NewSelect->swapValues();
- // Don't swap prof metadata, we didn't change the branch behavior.
- NewSelect->setName(I->getName() + ".neg");
- Builder.Insert(NewSelect);
- return NewSelect;
- }
+ if (isKnownNegation(I->getOperand(1), I->getOperand(2))) {
+ // If one hand of the select is known to be the negation of the other hand,
+ // just swap the hands around.
+ auto *NewSelect = cast<SelectInst>(I->clone());
+ // Just swap the operands of the select.
+ NewSelect->swapValues();
+ // Don't swap prof metadata, we didn't change the branch behavior.
+ NewSelect->setName(I->getName() + ".neg");
+ Builder.Insert(NewSelect);
+ return NewSelect;
}
// `select` is negatible if both hands of `select` are negatible.
Value *NegOp1 = negate(I->getOperand(1), Depth + 1);
@@ -323,51 +350,81 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
}
case Instruction::Shl: {
// `shl` is negatible if the first operand is negatible.
- Value *NegOp0 = negate(I->getOperand(0), Depth + 1);
- if (!NegOp0) // Early return.
+ if (Value *NegOp0 = negate(I->getOperand(0), Depth + 1))
+ return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg");
+ // Otherwise, `shl %x, C` can be interpreted as `mul %x, 1<<C`.
+ auto *Op1C = dyn_cast<Constant>(I->getOperand(1));
+ if (!Op1C) // Early return.
return nullptr;
- return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg");
+ return Builder.CreateMul(
+ I->getOperand(0),
+ ConstantExpr::getShl(Constant::getAllOnesValue(Op1C->getType()), Op1C),
+ I->getName() + ".neg");
}
- case Instruction::Or:
+ case Instruction::Or: {
if (!haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL, &AC, I,
&DT))
return nullptr; // Don't know how to handle `or` in general.
+ std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I);
// `or`/`add` are interchangeable when operands have no common bits set.
// `inc` is always negatible.
- if (match(I->getOperand(1), m_One()))
- return Builder.CreateNot(I->getOperand(0), I->getName() + ".neg");
+ if (match(Ops[1], m_One()))
+ return Builder.CreateNot(Ops[0], I->getName() + ".neg");
// Else, just defer to Instruction::Add handling.
LLVM_FALLTHROUGH;
+ }
case Instruction::Add: {
// `add` is negatible if both of its operands are negatible.
- Value *NegOp0 = negate(I->getOperand(0), Depth + 1);
- if (!NegOp0) // Early return.
- return nullptr;
- Value *NegOp1 = negate(I->getOperand(1), Depth + 1);
- if (!NegOp1)
+ SmallVector<Value *, 2> NegatedOps, NonNegatedOps;
+ for (Value *Op : I->operands()) {
+ // Can we sink the negation into this operand?
+ if (Value *NegOp = negate(Op, Depth + 1)) {
+ NegatedOps.emplace_back(NegOp); // Successfully negated operand!
+ continue;
+ }
+ // Failed to sink negation into this operand. IFF we started from negation
+ // and we manage to sink negation into one operand, we can still do this.
+ if (!IsTrulyNegation)
+ return nullptr;
+ NonNegatedOps.emplace_back(Op); // Just record which operand that was.
+ }
+ assert((NegatedOps.size() + NonNegatedOps.size()) == 2 &&
+ "Internal consistency sanity check.");
+ // Did we manage to sink negation into both of the operands?
+ if (NegatedOps.size() == 2) // Then we get to keep the `add`!
+ return Builder.CreateAdd(NegatedOps[0], NegatedOps[1],
+ I->getName() + ".neg");
+ assert(IsTrulyNegation && "We should have early-exited then.");
+ // Completely failed to sink negation?
+ if (NonNegatedOps.size() == 2)
return nullptr;
- return Builder.CreateAdd(NegOp0, NegOp1, I->getName() + ".neg");
+ // 0-(a+b) --> (-a)-b
+ return Builder.CreateSub(NegatedOps[0], NonNegatedOps[0],
+ I->getName() + ".neg");
}
- case Instruction::Xor:
+ case Instruction::Xor: {
+ std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I);
// `xor` is negatible if one of its operands is invertible.
// FIXME: InstCombineInverter? But how to connect Inverter and Negator?
- if (auto *C = dyn_cast<Constant>(I->getOperand(1))) {
- Value *Xor = Builder.CreateXor(I->getOperand(0), ConstantExpr::getNot(C));
+ if (auto *C = dyn_cast<Constant>(Ops[1])) {
+ Value *Xor = Builder.CreateXor(Ops[0], ConstantExpr::getNot(C));
return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1),
I->getName() + ".neg");
}
return nullptr;
+ }
case Instruction::Mul: {
+ std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I);
// `mul` is negatible if one of its operands is negatible.
Value *NegatedOp, *OtherOp;
// First try the second operand, in case it's a constant it will be best to
// just invert it instead of sinking the `neg` deeper.
- if (Value *NegOp1 = negate(I->getOperand(1), Depth + 1)) {
+ if (Value *NegOp1 = negate(Ops[1], Depth + 1)) {
NegatedOp = NegOp1;
- OtherOp = I->getOperand(0);
- } else if (Value *NegOp0 = negate(I->getOperand(0), Depth + 1)) {
+ OtherOp = Ops[0];
+ } else if (Value *NegOp0 = negate(Ops[0], Depth + 1)) {
NegatedOp = NegOp0;
- OtherOp = I->getOperand(1);
+ OtherOp = Ops[1];
} else
// Can't negate either of them.
return nullptr;
@@ -430,7 +487,7 @@ LLVM_NODISCARD Optional<Negator::Result> Negator::run(Value *Root) {
}
LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
++NegatorTotalNegationsAttempted;
LLVM_DEBUG(dbgs() << "Negator: attempting to sink negation into " << *Root
<< "\n");
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 2b2f2e1b9470..d687ec654438 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -13,11 +13,14 @@
#include "InstCombineInternal.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/Local.h"
+
using namespace llvm;
using namespace llvm::PatternMatch;
@@ -27,10 +30,16 @@ static cl::opt<unsigned>
MaxNumPhis("instcombine-max-num-phis", cl::init(512),
cl::desc("Maximum number phis to handle in intptr/ptrint folding"));
+STATISTIC(NumPHIsOfInsertValues,
+ "Number of phi-of-insertvalue turned into insertvalue-of-phis");
+STATISTIC(NumPHIsOfExtractValues,
+ "Number of phi-of-extractvalue turned into extractvalue-of-phi");
+STATISTIC(NumPHICSEs, "Number of PHI's that got CSE'd");
+
/// The PHI arguments will be folded into a single operation with a PHI node
/// as input. The debug location of the single operation will be the merged
/// locations of the original PHI node arguments.
-void InstCombiner::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) {
+void InstCombinerImpl::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) {
auto *FirstInst = cast<Instruction>(PN.getIncomingValue(0));
Inst->setDebugLoc(FirstInst->getDebugLoc());
// We do not expect a CallInst here, otherwise, N-way merging of DebugLoc
@@ -93,7 +102,7 @@ void InstCombiner::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) {
// ptr_val_inc = ...
// ...
//
-Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) {
+Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
if (!PN.getType()->isIntegerTy())
return nullptr;
if (!PN.hasOneUse())
@@ -290,9 +299,86 @@ Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) {
IntToPtr->getOperand(0)->getType());
}
+/// If we have something like phi [insertvalue(a,b,0), insertvalue(c,d,0)],
+/// turn this into a phi[a,c] and phi[b,d] and a single insertvalue.
+Instruction *
+InstCombinerImpl::foldPHIArgInsertValueInstructionIntoPHI(PHINode &PN) {
+ auto *FirstIVI = cast<InsertValueInst>(PN.getIncomingValue(0));
+
+ // Scan to see if all operands are `insertvalue`'s with the same indices,
+ // and all have a single use.
+ for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) {
+ auto *I = dyn_cast<InsertValueInst>(PN.getIncomingValue(i));
+ if (!I || !I->hasOneUser() || I->getIndices() != FirstIVI->getIndices())
+ return nullptr;
+ }
+
+ // For each operand of an `insertvalue`
+ std::array<PHINode *, 2> NewOperands;
+ for (int OpIdx : {0, 1}) {
+ auto *&NewOperand = NewOperands[OpIdx];
+ // Create a new PHI node to receive the values the operand has in each
+ // incoming basic block.
+ NewOperand = PHINode::Create(
+ FirstIVI->getOperand(OpIdx)->getType(), PN.getNumIncomingValues(),
+ FirstIVI->getOperand(OpIdx)->getName() + ".pn");
+ // And populate each operand's PHI with said values.
+ for (auto Incoming : zip(PN.blocks(), PN.incoming_values()))
+ NewOperand->addIncoming(
+ cast<InsertValueInst>(std::get<1>(Incoming))->getOperand(OpIdx),
+ std::get<0>(Incoming));
+ InsertNewInstBefore(NewOperand, PN);
+ }
+
+ // And finally, create `insertvalue` over the newly-formed PHI nodes.
+ auto *NewIVI = InsertValueInst::Create(NewOperands[0], NewOperands[1],
+ FirstIVI->getIndices(), PN.getName());
+
+ PHIArgMergedDebugLoc(NewIVI, PN);
+ ++NumPHIsOfInsertValues;
+ return NewIVI;
+}
+
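
A source-level C++ analogue of what this fold (and the matching extractvalue fold below) achieves: instead of building a whole aggregate on each incoming path and merging the aggregates, merge the scalar members and build the aggregate once. This is only an illustration of the shape; the function names are hypothetical:

    #include <cassert>
    #include <utility>

    // Aggregate built on each path, then merged (the "phi of insertvalue" shape).
    static std::pair<int, int> beforeFold(bool cond, int a, int b, int c, int d) {
      return cond ? std::make_pair(a, b) : std::make_pair(c, d);
    }

    // Members merged first, aggregate built once (the "insertvalue of phis" shape).
    static std::pair<int, int> afterFold(bool cond, int a, int b, int c, int d) {
      int first = cond ? a : c;    // phi of the first members
      int second = cond ? b : d;   // phi of the second members
      return {first, second};      // single insertvalue
    }

    int main() {
      for (bool cond : {false, true})
        assert(beforeFold(cond, 1, 2, 3, 4) == afterFold(cond, 1, 2, 3, 4));
      return 0;
    }
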
+/// If we have something like phi [extractvalue(a,0), extractvalue(b,0)],
+/// turn this into a phi[a,b] and a single extractvalue.
+Instruction *
+InstCombinerImpl::foldPHIArgExtractValueInstructionIntoPHI(PHINode &PN) {
+ auto *FirstEVI = cast<ExtractValueInst>(PN.getIncomingValue(0));
+
+ // Scan to see if all operands are `extractvalue`'s with the same indices,
+ // and all have a single use.
+ for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) {
+ auto *I = dyn_cast<ExtractValueInst>(PN.getIncomingValue(i));
+ if (!I || !I->hasOneUser() || I->getIndices() != FirstEVI->getIndices() ||
+ I->getAggregateOperand()->getType() !=
+ FirstEVI->getAggregateOperand()->getType())
+ return nullptr;
+ }
+
+ // Create a new PHI node to receive the values the aggregate operand has
+ // in each incoming basic block.
+ auto *NewAggregateOperand = PHINode::Create(
+ FirstEVI->getAggregateOperand()->getType(), PN.getNumIncomingValues(),
+ FirstEVI->getAggregateOperand()->getName() + ".pn");
+ // And populate the PHI with said values.
+ for (auto Incoming : zip(PN.blocks(), PN.incoming_values()))
+ NewAggregateOperand->addIncoming(
+ cast<ExtractValueInst>(std::get<1>(Incoming))->getAggregateOperand(),
+ std::get<0>(Incoming));
+ InsertNewInstBefore(NewAggregateOperand, PN);
+
+ // And finally, create `extractvalue` over the newly-formed PHI nodes.
+ auto *NewEVI = ExtractValueInst::Create(NewAggregateOperand,
+ FirstEVI->getIndices(), PN.getName());
+
+ PHIArgMergedDebugLoc(NewEVI, PN);
+ ++NumPHIsOfExtractValues;
+ return NewEVI;
+}
+
/// If we have something like phi [add (a,b), add(a,c)] and if a/b/c and the
-/// adds all have a single use, turn this into a phi and a single binop.
-Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) {
+/// adds all have a single user, turn this into a phi and a single binop.
+Instruction *InstCombinerImpl::foldPHIArgBinOpIntoPHI(PHINode &PN) {
Instruction *FirstInst = cast<Instruction>(PN.getIncomingValue(0));
assert(isa<BinaryOperator>(FirstInst) || isa<CmpInst>(FirstInst));
unsigned Opc = FirstInst->getOpcode();
@@ -302,10 +388,10 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) {
Type *LHSType = LHSVal->getType();
Type *RHSType = RHSVal->getType();
- // Scan to see if all operands are the same opcode, and all have one use.
+ // Scan to see if all operands are the same opcode, and all have one user.
for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) {
Instruction *I = dyn_cast<Instruction>(PN.getIncomingValue(i));
- if (!I || I->getOpcode() != Opc || !I->hasOneUse() ||
+ if (!I || I->getOpcode() != Opc || !I->hasOneUser() ||
// Verify type of the LHS matches so we don't fold cmp's of different
// types.
I->getOperand(0)->getType() != LHSType ||
@@ -385,7 +471,7 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) {
return NewBinOp;
}
-Instruction *InstCombiner::FoldPHIArgGEPIntoPHI(PHINode &PN) {
+Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) {
GetElementPtrInst *FirstInst =cast<GetElementPtrInst>(PN.getIncomingValue(0));
SmallVector<Value*, 16> FixedOperands(FirstInst->op_begin(),
@@ -401,11 +487,12 @@ Instruction *InstCombiner::FoldPHIArgGEPIntoPHI(PHINode &PN) {
bool AllInBounds = true;
- // Scan to see if all operands are the same opcode, and all have one use.
+ // Scan to see if all operands are the same opcode, and all have one user.
for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) {
- GetElementPtrInst *GEP= dyn_cast<GetElementPtrInst>(PN.getIncomingValue(i));
- if (!GEP || !GEP->hasOneUse() || GEP->getType() != FirstInst->getType() ||
- GEP->getNumOperands() != FirstInst->getNumOperands())
+ GetElementPtrInst *GEP =
+ dyn_cast<GetElementPtrInst>(PN.getIncomingValue(i));
+ if (!GEP || !GEP->hasOneUser() || GEP->getType() != FirstInst->getType() ||
+ GEP->getNumOperands() != FirstInst->getNumOperands())
return nullptr;
AllInBounds &= GEP->isInBounds();
@@ -494,7 +581,6 @@ Instruction *InstCombiner::FoldPHIArgGEPIntoPHI(PHINode &PN) {
return NewGEP;
}
-
/// Return true if we know that it is safe to sink the load out of the block
/// that defines it. This means that it must be obvious the value of the load is
/// not changed from the point of the load to the end of the block it is in.
@@ -540,7 +626,7 @@ static bool isSafeAndProfitableToSinkLoad(LoadInst *L) {
return true;
}
-Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) {
+Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) {
LoadInst *FirstLI = cast<LoadInst>(PN.getIncomingValue(0));
// FIXME: This is overconservative; this transform is allowed in some cases
@@ -573,7 +659,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) {
// Check to see if all arguments are the same operation.
for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) {
LoadInst *LI = dyn_cast<LoadInst>(PN.getIncomingValue(i));
- if (!LI || !LI->hasOneUse())
+ if (!LI || !LI->hasOneUser())
return nullptr;
// We can't sink the load if the loaded value could be modified between
@@ -654,7 +740,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) {
/// TODO: This function could handle other cast types, but then it might
/// require special-casing a cast from the 'i1' type. See the comment in
/// FoldPHIArgOpIntoPHI() about pessimizing illegal integer types.
-Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) {
+Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) {
// We cannot create a new instruction after the PHI if the terminator is an
// EHPad because there is no valid insertion point.
if (Instruction *TI = Phi.getParent()->getTerminator())
@@ -686,8 +772,8 @@ Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) {
unsigned NumConsts = 0;
for (Value *V : Phi.incoming_values()) {
if (auto *Zext = dyn_cast<ZExtInst>(V)) {
- // All zexts must be identical and have one use.
- if (Zext->getSrcTy() != NarrowType || !Zext->hasOneUse())
+ // All zexts must be identical and have one user.
+ if (Zext->getSrcTy() != NarrowType || !Zext->hasOneUser())
return nullptr;
NewIncoming.push_back(Zext->getOperand(0));
NumZexts++;
@@ -728,7 +814,7 @@ Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) {
/// If all operands to a PHI node are the same "unary" operator and they all are
/// only used by the PHI, PHI together their inputs, and do the operation once,
/// to the result of the PHI.
-Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) {
+Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) {
// We cannot create a new instruction after the PHI if the terminator is an
// EHPad because there is no valid insertion point.
if (Instruction *TI = PN.getParent()->getTerminator())
@@ -738,9 +824,13 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) {
Instruction *FirstInst = cast<Instruction>(PN.getIncomingValue(0));
if (isa<GetElementPtrInst>(FirstInst))
- return FoldPHIArgGEPIntoPHI(PN);
+ return foldPHIArgGEPIntoPHI(PN);
if (isa<LoadInst>(FirstInst))
- return FoldPHIArgLoadIntoPHI(PN);
+ return foldPHIArgLoadIntoPHI(PN);
+ if (isa<InsertValueInst>(FirstInst))
+ return foldPHIArgInsertValueInstructionIntoPHI(PN);
+ if (isa<ExtractValueInst>(FirstInst))
+ return foldPHIArgExtractValueInstructionIntoPHI(PN);
// Scan the instruction, looking for input operations that can be folded away.
// If all input operands to the phi are the same instruction (e.g. a cast from
@@ -763,7 +853,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) {
// otherwise call FoldPHIArgBinOpIntoPHI.
ConstantOp = dyn_cast<Constant>(FirstInst->getOperand(1));
if (!ConstantOp)
- return FoldPHIArgBinOpIntoPHI(PN);
+ return foldPHIArgBinOpIntoPHI(PN);
} else {
return nullptr; // Cannot fold this operation.
}
@@ -771,7 +861,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) {
// Check to see if all arguments are the same operation.
for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) {
Instruction *I = dyn_cast<Instruction>(PN.getIncomingValue(i));
- if (!I || !I->hasOneUse() || !I->isSameOperationAs(FirstInst))
+ if (!I || !I->hasOneUser() || !I->isSameOperationAs(FirstInst))
return nullptr;
if (CastSrcTy) {
if (I->getOperand(0)->getType() != CastSrcTy)
@@ -923,7 +1013,7 @@ struct LoweredPHIRecord {
LoweredPHIRecord(PHINode *pn, unsigned Sh)
: PN(pn), Shift(Sh), Width(0) {}
};
-}
+} // namespace
namespace llvm {
template<>
@@ -944,7 +1034,7 @@ namespace llvm {
LHS.Width == RHS.Width;
}
};
-}
+} // namespace llvm
/// This is an integer PHI and we know that it has an illegal type: see if it is
@@ -955,7 +1045,7 @@ namespace llvm {
/// TODO: The user of the trunc may be an bitcast to float/double/vector or an
/// inttoptr. We should produce new PHIs in the right type.
///
-Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) {
+Instruction *InstCombinerImpl::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) {
// PHIUsers - Keep track of all of the truncated values extracted from a set
// of PHIs, along with their offset. These are the things we want to rewrite.
SmallVector<PHIUsageRecord, 16> PHIUsers;
@@ -1129,13 +1219,85 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) {
return replaceInstUsesWith(FirstPhi, Undef);
}
+static Value *SimplifyUsingControlFlow(InstCombiner &Self, PHINode &PN,
+ const DominatorTree &DT) {
+ // Simplify the following patterns:
+ // if (cond)
+ // / \
+ // ... ...
+ // \ /
+ // phi [true] [false]
+ if (!PN.getType()->isIntegerTy(1))
+ return nullptr;
+
+ if (PN.getNumOperands() != 2)
+ return nullptr;
+
+ // Make sure all inputs are constants.
+ if (!all_of(PN.operands(), [](Value *V) { return isa<ConstantInt>(V); }))
+ return nullptr;
+
+ BasicBlock *BB = PN.getParent();
+ // Do not bother with unreachable instructions.
+ if (!DT.isReachableFromEntry(BB))
+ return nullptr;
+
+ // Same inputs.
+ if (PN.getOperand(0) == PN.getOperand(1))
+ return PN.getOperand(0);
+
+ BasicBlock *TruePred = nullptr, *FalsePred = nullptr;
+ for (auto *Pred : predecessors(BB)) {
+ auto *Input = cast<ConstantInt>(PN.getIncomingValueForBlock(Pred));
+ if (Input->isAllOnesValue())
+ TruePred = Pred;
+ else
+ FalsePred = Pred;
+ }
+ assert(TruePred && FalsePred && "Must be!");
+
+ // Check which edge of the dominator dominates the true input. If it is the
+ // false edge, we should invert the condition.
+ auto *IDom = DT.getNode(BB)->getIDom()->getBlock();
+ auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
+ if (!BI || BI->isUnconditional())
+ return nullptr;
+
+ // Check that the edges outgoing from the idom's terminator dominate the
+ // respective inputs of the Phi.
+ BasicBlockEdge TrueOutEdge(IDom, BI->getSuccessor(0));
+ BasicBlockEdge FalseOutEdge(IDom, BI->getSuccessor(1));
+
+ BasicBlockEdge TrueIncEdge(TruePred, BB);
+ BasicBlockEdge FalseIncEdge(FalsePred, BB);
+
+ auto *Cond = BI->getCondition();
+ if (DT.dominates(TrueOutEdge, TrueIncEdge) &&
+ DT.dominates(FalseOutEdge, FalseIncEdge))
+ // This Phi is actually equivalent to branching condition of IDom.
+ return Cond;
+ else if (DT.dominates(TrueOutEdge, FalseIncEdge) &&
+ DT.dominates(FalseOutEdge, TrueIncEdge)) {
+ // This Phi is actually opposite to branching condition of IDom. We invert
+ // the condition that will potentially open up some opportunities for
+ // sinking.
+ auto InsertPt = BB->getFirstInsertionPt();
+ if (InsertPt != BB->end()) {
+ Self.Builder.SetInsertPoint(&*InsertPt);
+ return Self.Builder.CreateNot(Cond);
+ }
+ }
+
+ return nullptr;
+}
+
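
A source-level C++ analogue of the "phi [true] [false]" shape this simplification targets (illustration only; the function names are hypothetical): the boolean merged at the join point is just the branch condition itself, possibly inverted when the edges are swapped.

    #include <cassert>

    // r gets 'true' on the edge taken when cond is true and 'false' on the
    // other edge, so r is simply cond.
    static bool beforeFold(bool cond) {
      bool r;
      if (cond)
        r = true;
      else
        r = false;
      return r;
    }

    static bool afterFold(bool cond) { return cond; }

    int main() {
      for (bool cond : {false, true})
        assert(beforeFold(cond) == afterFold(cond));
      return 0;
    }
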
// PHINode simplification
//
-Instruction *InstCombiner::visitPHINode(PHINode &PN) {
+Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
if (Value *V = SimplifyInstruction(&PN, SQ.getWithInstruction(&PN)))
return replaceInstUsesWith(PN, V);
- if (Instruction *Result = FoldPHIArgZextsIntoPHI(PN))
+ if (Instruction *Result = foldPHIArgZextsIntoPHI(PN))
return Result;
// If all PHI operands are the same operation, pull them through the PHI,
@@ -1143,18 +1305,16 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) {
if (isa<Instruction>(PN.getIncomingValue(0)) &&
isa<Instruction>(PN.getIncomingValue(1)) &&
cast<Instruction>(PN.getIncomingValue(0))->getOpcode() ==
- cast<Instruction>(PN.getIncomingValue(1))->getOpcode() &&
- // FIXME: The hasOneUse check will fail for PHIs that use the value more
- // than themselves more than once.
- PN.getIncomingValue(0)->hasOneUse())
- if (Instruction *Result = FoldPHIArgOpIntoPHI(PN))
+ cast<Instruction>(PN.getIncomingValue(1))->getOpcode() &&
+ PN.getIncomingValue(0)->hasOneUser())
+ if (Instruction *Result = foldPHIArgOpIntoPHI(PN))
return Result;
// If this is a trivial cycle in the PHI node graph, remove it. Basically, if
// this PHI only has a single use (a PHI), and if that PHI only has one use (a
// PHI)... break the cycle.
if (PN.hasOneUse()) {
- if (Instruction *Result = FoldIntegerTypedPHI(PN))
+ if (Instruction *Result = foldIntegerTypedPHI(PN))
return Result;
Instruction *PHIUser = cast<Instruction>(PN.user_back());
@@ -1267,6 +1427,21 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) {
}
}
+ // Is there an identical PHI node in this basic block?
+ for (PHINode &IdenticalPN : PN.getParent()->phis()) {
+ // Ignore the PHI node itself.
+ if (&IdenticalPN == &PN)
+ continue;
+ // Note that even though we've just canonicalized this PHI, due to the
+ // worklist visitation order, there are no guarantees that *every* PHI
+ // has been canonicalized, so we can't just compare operand ranges.
+ if (!PN.isIdenticalToWhenDefined(&IdenticalPN))
+ continue;
+ // Just use that PHI instead then.
+ ++NumPHICSEs;
+ return replaceInstUsesWith(PN, &IdenticalPN);
+ }
+
// If this is an integer PHI and we know that it has an illegal type, see if
// it is only used by trunc or trunc(lshr) operations. If so, we split the
// PHI into the various pieces being extracted. This sort of thing is
@@ -1276,5 +1451,9 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) {
if (Instruction *Res = SliceUpIllegalIntegerPHI(PN))
return Res;
+ // Ultimately, try to replace this Phi with a dominating condition.
+ if (auto *V = SimplifyUsingControlFlow(*this, PN, DT))
+ return replaceInstUsesWith(PN, V);
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 17124f717af7..f26c194d31b9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -38,6 +38,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <cassert>
#include <utility>
@@ -46,6 +47,11 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
+/// FIXME: Enabled by default until the pattern is supported well.
+static cl::opt<bool> EnableUnsafeSelectTransform(
+ "instcombine-unsafe-select-transform", cl::init(true),
+ cl::desc("Enable poison-unsafe select to and/or transform"));
+
static Value *createMinMax(InstCombiner::BuilderTy &Builder,
SelectPatternFlavor SPF, Value *A, Value *B) {
CmpInst::Predicate Pred = getMinMaxPred(SPF);
@@ -57,7 +63,7 @@ static Value *createMinMax(InstCombiner::BuilderTy &Builder,
/// constant of a binop.
static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
const TargetLibraryInfo &TLI,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
// The select condition must be an equality compare with a constant operand.
Value *X;
Constant *C;
@@ -258,29 +264,9 @@ static unsigned getSelectFoldableOperands(BinaryOperator *I) {
}
}
-/// For the same transformation as the previous function, return the identity
-/// constant that goes into the select.
-static APInt getSelectFoldableConstant(BinaryOperator *I) {
- switch (I->getOpcode()) {
- default: llvm_unreachable("This cannot happen!");
- case Instruction::Add:
- case Instruction::Sub:
- case Instruction::Or:
- case Instruction::Xor:
- case Instruction::Shl:
- case Instruction::LShr:
- case Instruction::AShr:
- return APInt::getNullValue(I->getType()->getScalarSizeInBits());
- case Instruction::And:
- return APInt::getAllOnesValue(I->getType()->getScalarSizeInBits());
- case Instruction::Mul:
- return APInt(I->getType()->getScalarSizeInBits(), 1);
- }
-}
-
/// We have (select c, TI, FI), and we know that TI and FI have the same opcode.
-Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
- Instruction *FI) {
+Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
+ Instruction *FI) {
// Don't break up min/max patterns. The hasOneUse checks below prevent that
// for most cases, but vector min/max with bitcasts can be transformed. If the
// one-use restrictions are eased for other patterns, we still don't want to
@@ -302,10 +288,9 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
// The select condition may be a vector. We may only change the operand
// type if the vector width remains the same (and matches the condition).
if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) {
- if (!FIOpndTy->isVectorTy())
- return nullptr;
- if (CondVTy->getNumElements() !=
- cast<VectorType>(FIOpndTy)->getNumElements())
+ if (!FIOpndTy->isVectorTy() ||
+ CondVTy->getElementCount() !=
+ cast<VectorType>(FIOpndTy)->getElementCount())
return nullptr;
// TODO: If the backend knew how to deal with casts better, we could
@@ -418,8 +403,8 @@ static bool isSelect01(const APInt &C1I, const APInt &C2I) {
/// Try to fold the select into one of the operands to allow further
/// optimization.
-Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
- Value *FalseVal) {
+Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
+ Value *FalseVal) {
// See the comment above GetSelectFoldableOperands for a description of the
// transformation we are doing here.
if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) {
@@ -433,14 +418,15 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
}
if (OpToFold) {
- APInt CI = getSelectFoldableConstant(TVI);
+ Constant *C = ConstantExpr::getBinOpIdentity(TVI->getOpcode(),
+ TVI->getType(), true);
Value *OOp = TVI->getOperand(2-OpToFold);
// Avoid creating select between 2 constants unless it's selecting
// between 0, 1 and -1.
const APInt *OOpC;
bool OOpIsAPInt = match(OOp, m_APInt(OOpC));
- if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) {
- Value *C = ConstantInt::get(OOp->getType(), CI);
+ if (!isa<Constant>(OOp) ||
+ (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C);
NewSel->takeName(TVI);
BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(),
@@ -464,14 +450,15 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
}
if (OpToFold) {
- APInt CI = getSelectFoldableConstant(FVI);
+ Constant *C = ConstantExpr::getBinOpIdentity(FVI->getOpcode(),
+ FVI->getType(), true);
Value *OOp = FVI->getOperand(2-OpToFold);
// Avoid creating select between 2 constants unless it's selecting
// between 0, 1 and -1.
const APInt *OOpC;
bool OOpIsAPInt = match(OOp, m_APInt(OOpC));
- if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) {
- Value *C = ConstantInt::get(OOp->getType(), CI);
+ if (!isa<Constant>(OOp) ||
+ (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp);
NewSel->takeName(FVI);
BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(),
@@ -782,25 +769,24 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
// Match unsigned saturated add of 2 variables with an unnecessary 'not'.
// There are 8 commuted variants.
- // Canonicalize -1 (saturated result) to true value of the select. Just
- // swapping the compare operands is legal, because the selected value is the
- // same in case of equality, so we can interchange u< and u<=.
+ // Canonicalize -1 (saturated result) to true value of the select.
if (match(FVal, m_AllOnes())) {
std::swap(TVal, FVal);
- std::swap(Cmp0, Cmp1);
+ Pred = CmpInst::getInversePredicate(Pred);
}
if (!match(TVal, m_AllOnes()))
return nullptr;
- // Canonicalize predicate to 'ULT'.
- if (Pred == ICmpInst::ICMP_UGT) {
- Pred = ICmpInst::ICMP_ULT;
+ // Canonicalize predicate to less-than or less-or-equal-than.
+ if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
std::swap(Cmp0, Cmp1);
+ Pred = CmpInst::getSwappedPredicate(Pred);
}
- if (Pred != ICmpInst::ICMP_ULT)
+ if (Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_ULE)
return nullptr;
// Match unsigned saturated add of 2 variables with an unnecessary 'not'.
+ // Strictness of the comparison is irrelevant.
Value *Y;
if (match(Cmp0, m_Not(m_Value(X))) &&
match(FVal, m_c_Add(m_Specific(X), m_Value(Y))) && Y == Cmp1) {
@@ -809,6 +795,7 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, X, Y);
}
// The 'not' op may be included in the sum but not the compare.
+ // Strictness of the comparison is irrelevant.
X = Cmp0;
Y = Cmp1;
if (match(FVal, m_c_Add(m_Not(m_Specific(X)), m_Specific(Y)))) {
@@ -819,7 +806,9 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1));
}
// The overflow may be detected via the add wrapping round.
- if (match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) &&
+ // This is only valid for strict comparison!
+ if (Pred == ICmpInst::ICMP_ULT &&
+ match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) &&
match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) {
// ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y)
// ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y)
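
A hedged scalar model of the saturated-add pattern canonicalized above (plain C++, not the in-tree matcher; names are illustrative). It also shows why strictness is irrelevant for the 'not' variant: when ~x == y the sum is already all-ones, so u< and u<= pick the same value:

    #include <cassert>
    #include <cstdint>

    // (~x u< y) ? -1 : (x + y)  behaves exactly like uadd.sat(x, y).
    uint32_t selectForm(uint32_t x, uint32_t y) {
      return (~x < y) ? 0xFFFFFFFFu : x + y;
    }
    uint32_t uaddSatRef(uint32_t x, uint32_t y) {
      uint64_t s = (uint64_t)x + y;
      return s > 0xFFFFFFFFu ? 0xFFFFFFFFu : (uint32_t)s;
    }

    int main() {
      for (uint32_t x : {0u, 1u, 7u, 0xFFFFFFFEu, 0xFFFFFFFFu})
        for (uint32_t y : {0u, 1u, 0x80000000u, 0xFFFFFFFFu})
          assert(selectForm(x, y) == uaddSatRef(x, y));
    }
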
@@ -1024,9 +1013,9 @@ static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) {
/// select (icmp Pred X, C1), C2, X --> select (icmp Pred' X, C2), X, C2
/// Note: if C1 != C2, this will change the icmp constant to the existing
/// constant operand of the select.
-static Instruction *
-canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
- InstCombiner &IC) {
+static Instruction *canonicalizeMinMaxWithConstant(SelectInst &Sel,
+ ICmpInst &Cmp,
+ InstCombinerImpl &IC) {
if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)))
return nullptr;
@@ -1063,105 +1052,29 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
return &Sel;
}
-/// There are many select variants for each of ABS/NABS.
-/// In matchSelectPattern(), there are different compare constants, compare
-/// predicates/operands and select operands.
-/// In isKnownNegation(), there are different formats of negated operands.
-/// Canonicalize all these variants to 1 pattern.
-/// This makes CSE more likely.
static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)))
return nullptr;
- // Choose a sign-bit check for the compare (likely simpler for codegen).
- // ABS: (X <s 0) ? -X : X
- // NABS: (X <s 0) ? X : -X
Value *LHS, *RHS;
SelectPatternFlavor SPF = matchSelectPattern(&Sel, LHS, RHS).Flavor;
if (SPF != SelectPatternFlavor::SPF_ABS &&
SPF != SelectPatternFlavor::SPF_NABS)
return nullptr;
- Value *TVal = Sel.getTrueValue();
- Value *FVal = Sel.getFalseValue();
- assert(isKnownNegation(TVal, FVal) &&
- "Unexpected result from matchSelectPattern");
-
- // The compare may use the negated abs()/nabs() operand, or it may use
- // negation in non-canonical form such as: sub A, B.
- bool CmpUsesNegatedOp = match(Cmp.getOperand(0), m_Neg(m_Specific(TVal))) ||
- match(Cmp.getOperand(0), m_Neg(m_Specific(FVal)));
-
- bool CmpCanonicalized = !CmpUsesNegatedOp &&
- match(Cmp.getOperand(1), m_ZeroInt()) &&
- Cmp.getPredicate() == ICmpInst::ICMP_SLT;
- bool RHSCanonicalized = match(RHS, m_Neg(m_Specific(LHS)));
-
- // Is this already canonical?
- if (CmpCanonicalized && RHSCanonicalized)
- return nullptr;
-
- // If RHS is not canonical but is used by other instructions, don't
- // canonicalize it and potentially increase the instruction count.
- if (!RHSCanonicalized)
- if (!(RHS->hasOneUse() || (RHS->hasNUses(2) && CmpUsesNegatedOp)))
- return nullptr;
+ // Note that NSW flag can only be propagated for normal, non-negated abs!
+ bool IntMinIsPoison = SPF == SelectPatternFlavor::SPF_ABS &&
+ match(RHS, m_NSWNeg(m_Specific(LHS)));
+ Constant *IntMinIsPoisonC =
+ ConstantInt::get(Type::getInt1Ty(Sel.getContext()), IntMinIsPoison);
+ Instruction *Abs =
+ IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC);
- // Create the canonical compare: icmp slt LHS 0.
- if (!CmpCanonicalized) {
- Cmp.setPredicate(ICmpInst::ICMP_SLT);
- Cmp.setOperand(1, ConstantInt::getNullValue(Cmp.getOperand(0)->getType()));
- if (CmpUsesNegatedOp)
- Cmp.setOperand(0, LHS);
- }
-
- // Create the canonical RHS: RHS = sub (0, LHS).
- if (!RHSCanonicalized) {
- assert(RHS->hasOneUse() && "RHS use number is not right");
- RHS = IC.Builder.CreateNeg(LHS);
- if (TVal == LHS) {
- // Replace false value.
- IC.replaceOperand(Sel, 2, RHS);
- FVal = RHS;
- } else {
- // Replace true value.
- IC.replaceOperand(Sel, 1, RHS);
- TVal = RHS;
- }
- }
+ if (SPF == SelectPatternFlavor::SPF_NABS)
+ return BinaryOperator::CreateNeg(Abs); // Always without NSW flag!
- // If the select operands do not change, we're done.
- if (SPF == SelectPatternFlavor::SPF_NABS) {
- if (TVal == LHS)
- return &Sel;
- assert(FVal == LHS && "Unexpected results from matchSelectPattern");
- } else {
- if (FVal == LHS)
- return &Sel;
- assert(TVal == LHS && "Unexpected results from matchSelectPattern");
- }
-
- // We are swapping the select operands, so swap the metadata too.
- Sel.swapValues();
- Sel.swapProfMetadata();
- return &Sel;
-}
-
-static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp,
- const SimplifyQuery &Q) {
- // If this is a binary operator, try to simplify it with the replaced op
-  // because we know Op and ReplaceOp are equivalent.
- // For example: V = X + 1, Op = X, ReplaceOp = 42
- // Simplifies as: add(42, 1) --> 43
- if (auto *BO = dyn_cast<BinaryOperator>(V)) {
- if (BO->getOperand(0) == Op)
- return SimplifyBinOp(BO->getOpcode(), ReplaceOp, BO->getOperand(1), Q);
- if (BO->getOperand(1) == Op)
- return SimplifyBinOp(BO->getOpcode(), BO->getOperand(0), ReplaceOp, Q);
- }
-
- return nullptr;
+ return IC.replaceInstUsesWith(Sel, Abs);
}
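
The rewrite above replaces the old select canonicalization with direct emission of the llvm.abs intrinsic. A standalone C++ sketch of the underlying identities (illustrative only); INT32_MIN is deliberately left out because, when the int-min-is-poison flag is derived from an nsw negation, that input is poison:

    #include <cassert>
    #include <cstdint>

    // ABS:  (x <s 0) ? -x : x        NABS: (x <s 0) ? x : -x  ==  -abs(x)
    int32_t absSelect(int32_t x)  { return x < 0 ? -x : x; }
    int32_t nabsSelect(int32_t x) { return x < 0 ? x : -x; }

    int main() {
      for (int32_t x : {-42, -1, 0, 1, 7, INT32_MAX})  // INT32_MIN omitted (see above)
        assert(absSelect(x) >= 0 && nabsSelect(x) == -absSelect(x));
    }
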
/// If we have a select with an equality comparison, then we know the value in
@@ -1180,30 +1093,97 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp,
///
/// We can't replace %sel with %add unless we strip away the flags.
/// TODO: Wrapping flags could be preserved in some cases with better analysis.
-static Value *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp,
- const SimplifyQuery &Q) {
+Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
+ ICmpInst &Cmp) {
if (!Cmp.isEquality())
return nullptr;
// Canonicalize the pattern to ICMP_EQ by swapping the select operands.
Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue();
- if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
+ bool Swapped = false;
+ if (Cmp.getPredicate() == ICmpInst::ICMP_NE) {
std::swap(TrueVal, FalseVal);
+ Swapped = true;
+ }
+
+ // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand.
+ // Make sure Y cannot be undef though, as we might pick different values for
+ // undef in the icmp and in f(Y). Additionally, take care to avoid replacing
+ // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite
+ // replacement cycle.
+ Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
+ if (TrueVal != CmpLHS &&
+ isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
+ if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
+ /* AllowRefinement */ true))
+ return replaceOperand(Sel, Swapped ? 2 : 1, V);
+
+ // Even if TrueVal does not simplify, we can directly replace a use of
+ // CmpLHS with CmpRHS, as long as the instruction is not used anywhere
+ // else and is safe to speculatively execute (we may end up executing it
+ // with different operands, which should not cause side-effects or trigger
+ // undefined behavior). Only do this if CmpRHS is a constant, as
+ // profitability is not clear for other cases.
+ // FIXME: The replacement could be performed recursively.
+ if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()))
+ if (auto *I = dyn_cast<Instruction>(TrueVal))
+ if (I->hasOneUse() && isSafeToSpeculativelyExecute(I))
+ for (Use &U : I->operands())
+ if (U == CmpLHS) {
+ replaceUse(U, CmpRHS);
+ return &Sel;
+ }
+ }
+ if (TrueVal != CmpRHS &&
+ isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT))
+ if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ,
+ /* AllowRefinement */ true))
+ return replaceOperand(Sel, Swapped ? 2 : 1, V);
+
+ auto *FalseInst = dyn_cast<Instruction>(FalseVal);
+ if (!FalseInst)
+ return nullptr;
+
+ // InstSimplify already performed this fold if it was possible subject to
+ // current poison-generating flags. Try the transform again with
+ // poison-generating flags temporarily dropped.
+ bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false;
+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) {
+ WasNUW = OBO->hasNoUnsignedWrap();
+ WasNSW = OBO->hasNoSignedWrap();
+ FalseInst->setHasNoUnsignedWrap(false);
+ FalseInst->setHasNoSignedWrap(false);
+ }
+ if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) {
+ WasExact = PEO->isExact();
+ FalseInst->setIsExact(false);
+ }
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) {
+ WasInBounds = GEP->isInBounds();
+ GEP->setIsInBounds(false);
+ }
// Try each equivalence substitution possibility.
// We have an 'EQ' comparison, so the select's false value will propagate.
// Example:
// (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
- // (X == 42) ? (X + 1) : 43 --> (X == 42) ? (42 + 1) : 43 --> 43
- Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
- if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q) == TrueVal ||
- simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q) == TrueVal ||
- simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q) == FalseVal ||
- simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q) == FalseVal) {
- if (auto *FalseInst = dyn_cast<Instruction>(FalseVal))
- FalseInst->dropPoisonGeneratingFlags();
- return FalseVal;
- }
+ if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
+ /* AllowRefinement */ false) == TrueVal ||
+ SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
+ /* AllowRefinement */ false) == TrueVal) {
+ return replaceInstUsesWith(Sel, FalseVal);
+ }
+
+ // Restore poison-generating flags if the transform did not apply.
+ if (WasNUW)
+ FalseInst->setHasNoUnsignedWrap();
+ if (WasNSW)
+ FalseInst->setHasNoSignedWrap();
+ if (WasExact)
+ FalseInst->setIsExact();
+ if (WasInBounds)
+ cast<GetElementPtrInst>(FalseInst)->setIsInBounds();
+
return nullptr;
}
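
A plain-C++ illustration of the equivalence substitution performed above (illustrative names, not the LLVM API): inside the arm guarded by x == 42, every use of x may be treated as the constant 42.

    #include <cstdint>

    // (x == 42) ? (x + 1) : y  ==>  (x == 42) ? 43 : y
    constexpr int32_t beforeFold(int32_t x, int32_t y) { return x == 42 ? x + 1 : y; }
    constexpr int32_t afterFold(int32_t x, int32_t y)  { return x == 42 ? 43 : y; }

    static_assert(beforeFold(42, 7) == afterFold(42, 7));
    static_assert(beforeFold(5, 7) == afterFold(5, 7));
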
@@ -1253,7 +1233,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
APInt::getAllOnesValue(
C0->getType()->getScalarSizeInBits()))))
return nullptr; // Can't do, have all-ones element[s].
- C0 = AddOne(C0);
+ C0 = InstCombiner::AddOne(C0);
std::swap(X, Sel1);
break;
case ICmpInst::Predicate::ICMP_UGE:
@@ -1313,7 +1293,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
APInt::getSignedMaxValue(
C2->getType()->getScalarSizeInBits()))))
return nullptr; // Can't do, have signed max element[s].
- C2 = AddOne(C2);
+ C2 = InstCombiner::AddOne(C2);
LLVM_FALLTHROUGH;
case ICmpInst::Predicate::ICMP_SGE:
// Also non-canonical, but here we don't need to change C2,
@@ -1360,7 +1340,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
// and swap the hands of select.
static Instruction *
tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
ICmpInst::Predicate Pred;
Value *X;
Constant *C0;
@@ -1375,7 +1355,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
// If comparison predicate is non-canonical, then we certainly won't be able
// to make it canonical; canonicalizeCmpWithConstant() already tried.
- if (!isCanonicalPredicate(Pred))
+ if (!InstCombiner::isCanonicalPredicate(Pred))
return nullptr;
// If the [input] type of comparison and select type are different, lets abort
@@ -1403,7 +1383,8 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
return nullptr;
// Check the constant we'd have with flipped-strictness predicate.
- auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0);
+ auto FlippedStrictness =
+ InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0);
if (!FlippedStrictness)
return nullptr;
@@ -1426,10 +1407,10 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
}
/// Visit a SelectInst that has an ICmpInst as its first operand.
-Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
- ICmpInst *ICI) {
- if (Value *V = foldSelectValueEquivalence(SI, *ICI, SQ))
- return replaceInstUsesWith(SI, V);
+Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
+ ICmpInst *ICI) {
+ if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI))
+ return NewSel;
if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *this))
return NewSel;
@@ -1579,11 +1560,11 @@ static bool canSelectOperandBeMappingIntoPredBlock(const Value *V,
/// We have an SPF (e.g. a min or max) of an SPF of the form:
/// SPF2(SPF1(A, B), C)
-Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner,
- SelectPatternFlavor SPF1,
- Value *A, Value *B,
- Instruction &Outer,
- SelectPatternFlavor SPF2, Value *C) {
+Instruction *InstCombinerImpl::foldSPFofSPF(Instruction *Inner,
+ SelectPatternFlavor SPF1, Value *A,
+ Value *B, Instruction &Outer,
+ SelectPatternFlavor SPF2,
+ Value *C) {
if (Outer.getType() != Inner->getType())
return nullptr;
@@ -1900,7 +1881,7 @@ foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) {
return CallInst::Create(F, {X, Y});
}
-Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) {
+Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
Constant *C;
if (!match(Sel.getTrueValue(), m_Constant(C)) &&
!match(Sel.getFalseValue(), m_Constant(C)))
@@ -1966,10 +1947,11 @@ Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) {
static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Constant *CondC;
- if (!CondVal->getType()->isVectorTy() || !match(CondVal, m_Constant(CondC)))
+ auto *CondValTy = dyn_cast<FixedVectorType>(CondVal->getType());
+ if (!CondValTy || !match(CondVal, m_Constant(CondC)))
return nullptr;
- unsigned NumElts = cast<VectorType>(CondVal->getType())->getNumElements();
+ unsigned NumElts = CondValTy->getNumElements();
SmallVector<int, 16> Mask;
Mask.reserve(NumElts);
for (unsigned i = 0; i != NumElts; ++i) {
@@ -2001,8 +1983,8 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) {
/// to a vector select by splatting the condition. A splat may get folded with
/// other operations in IR and having all operands of a select be vector types
/// is likely better for vector codegen.
-static Instruction *canonicalizeScalarSelectOfVecs(
- SelectInst &Sel, InstCombiner &IC) {
+static Instruction *canonicalizeScalarSelectOfVecs(SelectInst &Sel,
+ InstCombinerImpl &IC) {
auto *Ty = dyn_cast<VectorType>(Sel.getType());
if (!Ty)
return nullptr;
@@ -2015,8 +1997,8 @@ static Instruction *canonicalizeScalarSelectOfVecs(
// select (extelt V, Index), T, F --> select (splat V, Index), T, F
// Splatting the extracted condition reduces code (we could directly create a
// splat shuffle of the source vector to eliminate the intermediate step).
- unsigned NumElts = Ty->getNumElements();
- return IC.replaceOperand(Sel, 0, IC.Builder.CreateVectorSplat(NumElts, Cond));
+ return IC.replaceOperand(
+ Sel, 0, IC.Builder.CreateVectorSplat(Ty->getElementCount(), Cond));
}
/// Reuse bitcasted operands between a compare and select:
@@ -2172,7 +2154,7 @@ static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X,
}
/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
-Instruction *InstCombiner::matchSAddSubSat(SelectInst &MinMax1) {
+Instruction *InstCombinerImpl::matchSAddSubSat(SelectInst &MinMax1) {
Type *Ty = MinMax1.getType();
// We are looking for a tree of:
@@ -2293,34 +2275,42 @@ static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS,
return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp);
}
-/// Try to reduce a rotate pattern that includes a compare and select into a
-/// funnel shift intrinsic. Example:
+/// Try to reduce a funnel/rotate pattern that includes a compare and select
+/// into a funnel shift intrinsic. Example:
/// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b)))
/// --> call llvm.fshl.i32(a, a, b)
-static Instruction *foldSelectRotate(SelectInst &Sel) {
- // The false value of the select must be a rotate of the true value.
- Value *Or0, *Or1;
- if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1)))))
+/// fshl32(a, b, c) --> (c == 0 ? a : ((b >> (32 - c)) | (a << c)))
+/// --> call llvm.fshl.i32(a, b, c)
+/// fshr32(a, b, c) --> (c == 0 ? b : ((a >> (32 - c)) | (b << c)))
+/// --> call llvm.fshr.i32(a, b, c)
+static Instruction *foldSelectFunnelShift(SelectInst &Sel,
+ InstCombiner::BuilderTy &Builder) {
+ // This must be a power-of-2 type for a bitmasking transform to be valid.
+ unsigned Width = Sel.getType()->getScalarSizeInBits();
+ if (!isPowerOf2_32(Width))
return nullptr;
- Value *TVal = Sel.getTrueValue();
- Value *SA0, *SA1;
- if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA0)))) ||
- !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA1)))))
+ BinaryOperator *Or0, *Or1;
+ if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1)))))
return nullptr;
- auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode();
- auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode();
- if (ShiftOpcode0 == ShiftOpcode1)
+ Value *SV0, *SV1, *SA0, *SA1;
+ if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(SV0),
+ m_ZExtOrSelf(m_Value(SA0))))) ||
+ !match(Or1, m_OneUse(m_LogicalShift(m_Value(SV1),
+ m_ZExtOrSelf(m_Value(SA1))))) ||
+ Or0->getOpcode() == Or1->getOpcode())
return nullptr;
- // We have one of these patterns so far:
- // select ?, TVal, (or (lshr TVal, SA0), (shl TVal, SA1))
- // select ?, TVal, (or (shl TVal, SA0), (lshr TVal, SA1))
- // This must be a power-of-2 rotate for a bitmasking transform to be valid.
- unsigned Width = Sel.getType()->getScalarSizeInBits();
- if (!isPowerOf2_32(Width))
- return nullptr;
+ // Canonicalize to or(shl(SV0, SA0), lshr(SV1, SA1)).
+ if (Or0->getOpcode() == BinaryOperator::LShr) {
+ std::swap(Or0, Or1);
+ std::swap(SV0, SV1);
+ std::swap(SA0, SA1);
+ }
+ assert(Or0->getOpcode() == BinaryOperator::Shl &&
+ Or1->getOpcode() == BinaryOperator::LShr &&
+ "Illegal or(shift,shift) pair");
// Check the shift amounts to see if they are an opposite pair.
Value *ShAmt;
@@ -2331,6 +2321,15 @@ static Instruction *foldSelectRotate(SelectInst &Sel) {
else
return nullptr;
+ // We should now have this pattern:
+ // select ?, TVal, (or (shl SV0, SA0), (lshr SV1, SA1))
+ // The false value of the select must be a funnel-shift of the true value:
+ // IsFShl -> TVal must be SV0 else TVal must be SV1.
+ bool IsFshl = (ShAmt == SA0);
+ Value *TVal = Sel.getTrueValue();
+ if ((IsFshl && TVal != SV0) || (!IsFshl && TVal != SV1))
+ return nullptr;
+
// Finally, see if the select is filtering out a shift-by-zero.
Value *Cond = Sel.getCondition();
ICmpInst::Predicate Pred;
@@ -2338,13 +2337,21 @@ static Instruction *foldSelectRotate(SelectInst &Sel) {
Pred != ICmpInst::ICMP_EQ)
return nullptr;
- // This is a rotate that avoids shift-by-bitwidth UB in a suboptimal way.
+ // If this is not a rotate then the select was blocking poison from the
+ // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
+ if (SV0 != SV1) {
+ if (IsFshl && !llvm::isGuaranteedNotToBePoison(SV1))
+ SV1 = Builder.CreateFreeze(SV1);
+ else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(SV0))
+ SV0 = Builder.CreateFreeze(SV0);
+ }
+
+ // This is a funnel/rotate that avoids shift-by-bitwidth UB in a suboptimal way.
// Convert to funnel shift intrinsic.
- bool IsFshl = (ShAmt == SA0 && ShiftOpcode0 == BinaryOperator::Shl) ||
- (ShAmt == SA1 && ShiftOpcode1 == BinaryOperator::Shl);
Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType());
- return IntrinsicInst::Create(F, { TVal, TVal, ShAmt });
+ ShAmt = Builder.CreateZExt(ShAmt, Sel.getType());
+ return IntrinsicInst::Create(F, { SV0, SV1, ShAmt });
}
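
A self-contained scalar model of the funnel-shift equivalence exploited above (not the in-tree code; the 64-bit concatenation mirrors the fshl semantics described in the comments):

    #include <cassert>
    #include <cstdint>

    // Select form guarded against a shift-by-zero, for c in [0, 31]:
    //   (c == 0) ? a : ((b >> (32 - c)) | (a << c))
    uint32_t selectForm(uint32_t a, uint32_t b, uint32_t c) {
      return c == 0 ? a : ((b >> (32 - c)) | (a << c));
    }
    // fshl(a, b, c): shift the concatenation a:b left by c % 32, keep the high word.
    uint32_t fshl32(uint32_t a, uint32_t b, uint32_t c) {
      uint64_t concat = ((uint64_t)a << 32) | b;
      return (uint32_t)((concat << (c & 31)) >> 32);
    }

    int main() {
      for (uint32_t c = 0; c < 32; ++c)
        assert(selectForm(0xDEADBEEFu, 0x12345678u, c) ==
               fshl32(0xDEADBEEFu, 0x12345678u, c));
    }
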
static Instruction *foldSelectToCopysign(SelectInst &Sel,
@@ -2368,7 +2375,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
bool IsTrueIfSignSet;
ICmpInst::Predicate Pred;
if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
- !isSignBitCheck(Pred, *C, IsTrueIfSignSet) || X->getType() != SelType)
+ !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
+ X->getType() != SelType)
return nullptr;
// If needed, negate the value that will be the sign argument of the copysign:
@@ -2389,7 +2397,7 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
return CopySign;
}
-Instruction *InstCombiner::foldVectorSelect(SelectInst &Sel) {
+Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
auto *VecTy = dyn_cast<FixedVectorType>(Sel.getType());
if (!VecTy)
return nullptr;
@@ -2469,6 +2477,10 @@ static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB,
} else
return nullptr;
+ // Make sure the branches are actually different.
+ if (TrueSucc == FalseSucc)
+ return nullptr;
+
// We want to replace select %cond, %a, %b with a phi that takes value %a
// for all incoming edges that are dominated by condition `%cond == true`,
// and value %b for edges dominated by condition `%cond == false`. If %a
@@ -2515,7 +2527,33 @@ static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT,
return nullptr;
}
-Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
+static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
+ FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
+ if (!FI)
+ return nullptr;
+
+ Value *Cond = FI->getOperand(0);
+ Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue();
+
+ // select (freeze(x == y)), x, y --> y
+ // select (freeze(x != y)), x, y --> x
+ // The freeze should be only used by this select. Otherwise, remaining uses of
+ // the freeze can observe a contradictory value.
+ // c = freeze(x == y) ; Let's assume that y = poison & x = 42; c is 0 or 1
+ // a = select c, x, y ;
+ // f(a, c) ; f(poison, 1) cannot happen, but if a is folded
+ // ; to y, this can happen.
+ CmpInst::Predicate Pred;
+ if (FI->hasOneUse() &&
+ match(Cond, m_c_ICmp(Pred, m_Specific(TrueVal), m_Specific(FalseVal))) &&
+ (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)) {
+ return Pred == ICmpInst::ICMP_EQ ? FalseVal : TrueVal;
+ }
+
+ return nullptr;
+}
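
Value-level intuition for the frozen-compare fold above, as a standalone sketch (poison, which the one-use restriction on the freeze guards against, cannot be modelled in plain C++):

    #include <cassert>
    #include <cstdint>

    // select (x == y), x, y always yields y: when the compare is true, x equals y.
    // Dually, select (x != y), x, y always yields x.
    int32_t selEq(int32_t x, int32_t y) { return (x == y) ? x : y; }
    int32_t selNe(int32_t x, int32_t y) { return (x != y) ? x : y; }

    int main() {
      for (int32_t x : {-3, 0, 7})
        for (int32_t y : {-3, 0, 7})
          assert(selEq(x, y) == y && selNe(x, y) == x);
    }
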
+
+Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
Value *FalseVal = SI.getFalseValue();
@@ -2551,38 +2589,45 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (SelType->isIntOrIntVectorTy(1) &&
TrueVal->getType() == CondVal->getType()) {
- if (match(TrueVal, m_One())) {
+ if (match(TrueVal, m_One()) &&
+ (EnableUnsafeSelectTransform || impliesPoison(FalseVal, CondVal))) {
// Change: A = select B, true, C --> A = or B, C
return BinaryOperator::CreateOr(CondVal, FalseVal);
}
- if (match(TrueVal, m_Zero())) {
- // Change: A = select B, false, C --> A = and !B, C
- Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
- return BinaryOperator::CreateAnd(NotCond, FalseVal);
- }
- if (match(FalseVal, m_Zero())) {
+ if (match(FalseVal, m_Zero()) &&
+ (EnableUnsafeSelectTransform || impliesPoison(TrueVal, CondVal))) {
// Change: A = select B, C, false --> A = and B, C
return BinaryOperator::CreateAnd(CondVal, TrueVal);
}
+
+ // select a, false, b -> select !a, b, false
+ if (match(TrueVal, m_Zero())) {
+ Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
+ return SelectInst::Create(NotCond, FalseVal,
+ ConstantInt::getFalse(SelType));
+ }
+ // select a, b, true -> select !a, true, b
if (match(FalseVal, m_One())) {
- // Change: A = select B, C, true --> A = or !B, C
Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
- return BinaryOperator::CreateOr(NotCond, TrueVal);
+ return SelectInst::Create(NotCond, ConstantInt::getTrue(SelType),
+ TrueVal);
}
- // select a, a, b -> a | b
- // select a, b, a -> a & b
+ // select a, a, b -> select a, true, b
if (CondVal == TrueVal)
- return BinaryOperator::CreateOr(CondVal, FalseVal);
+ return replaceOperand(SI, 1, ConstantInt::getTrue(SelType));
+ // select a, b, a -> select a, b, false
if (CondVal == FalseVal)
- return BinaryOperator::CreateAnd(CondVal, TrueVal);
+ return replaceOperand(SI, 2, ConstantInt::getFalse(SelType));
- // select a, ~a, b -> (~a) & b
- // select a, b, ~a -> (~a) | b
+ // select a, !a, b -> select !a, b, false
if (match(TrueVal, m_Not(m_Specific(CondVal))))
- return BinaryOperator::CreateAnd(TrueVal, FalseVal);
+ return SelectInst::Create(TrueVal, FalseVal,
+ ConstantInt::getFalse(SelType));
+ // select a, b, !a -> select !a, true, b
if (match(FalseVal, m_Not(m_Specific(CondVal))))
- return BinaryOperator::CreateOr(TrueVal, FalseVal);
+ return SelectInst::Create(FalseVal, ConstantInt::getTrue(SelType),
+ TrueVal);
}
// Selecting between two integer or vector splat integer constants?
@@ -2591,7 +2636,10 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
// select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0>
// because that may need 3 instructions to splat the condition value:
// extend, insertelement, shufflevector.
- if (SelType->isIntOrIntVectorTy() &&
+ //
+  // Do not handle i1 TrueVal and FalseVal; otherwise this would result in
+  // a zext/sext of i1 to i1.
+ if (SelType->isIntOrIntVectorTy() && !SelType->isIntOrIntVectorTy(1) &&
CondVal->getType()->isVectorTy() == SelType->isVectorTy()) {
// select C, 1, 0 -> zext C to int
if (match(TrueVal, m_One()) && match(FalseVal, m_Zero()))
@@ -2838,8 +2886,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
return replaceOperand(SI, 1, TrueSI->getTrueValue());
}
// select(C0, select(C1, a, b), b) -> select(C0&C1, a, b)
- // We choose this as normal form to enable folding on the And and shortening
- // paths for the values (this helps GetUnderlyingObjects() for example).
+ // We choose this as normal form to enable folding on the And and
+ // shortening paths for the values (this helps getUnderlyingObjects() for
+ // example).
if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) {
Value *And = Builder.CreateAnd(CondVal, TrueSI->getCondition());
replaceOperand(SI, 0, And);
@@ -2922,7 +2971,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
}
Value *NotCond;
- if (match(CondVal, m_Not(m_Value(NotCond)))) {
+ if (match(CondVal, m_Not(m_Value(NotCond))) &&
+ !InstCombiner::shouldAvoidAbsorbingNotIntoSelect(SI)) {
replaceOperand(SI, 0, NotCond);
SI.swapValues();
SI.swapProfMetadata();
@@ -2956,8 +3006,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI, *this))
return Select;
- if (Instruction *Rot = foldSelectRotate(SI))
- return Rot;
+ if (Instruction *Funnel = foldSelectFunnelShift(SI, Builder))
+ return Funnel;
if (Instruction *Copysign = foldSelectToCopysign(SI, Builder))
return Copysign;
@@ -2965,5 +3015,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
if (Instruction *PN = foldSelectToPhi(SI, DT, Builder))
return replaceInstUsesWith(SI, PN);
+ if (Value *Fr = foldSelectWithFrozenICmp(SI, Builder))
+ return replaceInstUsesWith(SI, Fr);
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 0a842b4e1047..7295369365c4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -15,6 +15,7 @@
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
using namespace llvm;
using namespace PatternMatch;
@@ -31,7 +32,7 @@ using namespace PatternMatch;
//
// AnalyzeForSignBitExtraction indicates that we will only analyze whether this
// pattern has any 2 right-shifts that sum to 1 less than original bit width.
-Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
+Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts(
BinaryOperator *Sh0, const SimplifyQuery &SQ,
bool AnalyzeForSignBitExtraction) {
// Look for a shift of some instruction, ignore zext of shift amount if any.
@@ -327,8 +328,8 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse())
return nullptr;
- const APInt *C0, *C1;
- if (!match(I.getOperand(1), m_APInt(C1)))
+ Constant *C0, *C1;
+ if (!match(I.getOperand(1), m_Constant(C1)))
return nullptr;
Instruction::BinaryOps ShiftOpcode = I.getOpcode();
@@ -339,10 +340,12 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
// TODO: Remove the one-use check if the other logic operand (Y) is constant.
Value *X, *Y;
auto matchFirstShift = [&](Value *V) {
- return !isa<ConstantExpr>(V) &&
- match(V, m_OneUse(m_Shift(m_Value(X), m_APInt(C0)))) &&
- cast<BinaryOperator>(V)->getOpcode() == ShiftOpcode &&
- (*C0 + *C1).ult(Ty->getScalarSizeInBits());
+ BinaryOperator *BO;
+ APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits());
+ return match(V, m_BinOp(BO)) && BO->getOpcode() == ShiftOpcode &&
+ match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) &&
+ match(ConstantExpr::getAdd(C0, C1),
+ m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold));
};
// Logic ops are commutative, so check each operand for a match.
@@ -354,13 +357,13 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
return nullptr;
// shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
- Constant *ShiftSumC = ConstantInt::get(Ty, *C0 + *C1);
+ Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1);
Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC);
Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1));
return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2);
}
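
The fold above relies on shifts distributing over bitwise logic when the two shift amounts sum to less than the bit width. A minimal constexpr check (illustrative names only, not the in-tree code):

    #include <cstdint>

    // shl (and (shl x, 3), y), 2  ==>  and (shl x, 5), (shl y, 2)   (3 + 2 < 32)
    constexpr uint32_t beforeFold(uint32_t x, uint32_t y) { return ((x << 3) & y) << 2; }
    constexpr uint32_t afterFold(uint32_t x, uint32_t y)  { return (x << 5) & (y << 2); }

    static_assert(beforeFold(0x1234u, 0xFF0Fu) == afterFold(0x1234u, 0xFF0Fu));
    static_assert(beforeFold(0xFFFFFFFFu, 0x0F0F0F0Fu) == afterFold(0xFFFFFFFFu, 0x0F0F0F0Fu));
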
-Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
+Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
assert(Op0->getType() == Op1->getType());
@@ -399,15 +402,15 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
return BinaryOperator::Create(
I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), Op0, C), A);
- // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2.
+ // X shift (A srem C) -> X shift (A and (C - 1)) iff C is a power of 2.
// Because shifts by negative values (which could occur if A were negative)
// are undefined.
- const APInt *B;
- if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) {
+ if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Constant(C))) &&
+ match(C, m_Power2())) {
// FIXME: Should this get moved into SimplifyDemandedBits by saying we don't
// demand the sign bit (and many others) here??
- Value *Rem = Builder.CreateAnd(A, ConstantInt::get(I.getType(), *B - 1),
- Op1->getName());
+ Constant *Mask = ConstantExpr::getSub(C, ConstantInt::get(I.getType(), 1));
+ Value *Rem = Builder.CreateAnd(A, Mask, Op1->getName());
return replaceOperand(I, 1, Rem);
}
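
A small runnable check of the srem-to-mask rewrite above for non-negative operands (a negative remainder would already be an out-of-range shift amount, which is why the mask form is acceptable); plain C++, illustrative only:

    #include <cassert>
    #include <cstdint>

    int main() {
      // x shift (a srem 8)  ==>  x shift (a and 7), for the defined (non-negative) cases.
      for (int32_t a : {0, 1, 7, 8, 9, 123456})
        assert((a % 8) == (a & 7) && (1u << (a % 8)) == (1u << (a & 7)));
    }
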
@@ -420,8 +423,8 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
/// Return true if we can simplify two logical (either left or right) shifts
/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
- Instruction *InnerShift, InstCombiner &IC,
- Instruction *CxtI) {
+ Instruction *InnerShift,
+ InstCombinerImpl &IC, Instruction *CxtI) {
assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
// We need constant scalar or constant splat shifts.
@@ -472,7 +475,7 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
/// where the client will ask if E can be computed shifted right by 64-bits. If
/// this succeeds, getShiftedValue() will be called to produce the value.
static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
- InstCombiner &IC, Instruction *CxtI) {
+ InstCombinerImpl &IC, Instruction *CxtI) {
// We can always evaluate constants shifted.
if (isa<Constant>(V))
return true;
@@ -480,31 +483,6 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
Instruction *I = dyn_cast<Instruction>(V);
if (!I) return false;
- // If this is the opposite shift, we can directly reuse the input of the shift
- // if the needed bits are already zero in the input. This allows us to reuse
- // the value which means that we don't care if the shift has multiple uses.
- // TODO: Handle opposite shift by exact value.
- ConstantInt *CI = nullptr;
- if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) ||
- (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) {
- if (CI->getValue() == NumBits) {
- // TODO: Check that the input bits are already zero with MaskedValueIsZero
-#if 0
- // If this is a truncate of a logical shr, we can truncate it to a smaller
- // lshr iff we know that the bits we would otherwise be shifting in are
- // already zeros.
- uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
- uint32_t BitWidth = Ty->getScalarSizeInBits();
- if (MaskedValueIsZero(I->getOperand(0),
- APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) &&
- CI->getLimitedValue(BitWidth) < BitWidth) {
- return CanEvaluateTruncated(I->getOperand(0), Ty);
- }
-#endif
-
- }
- }
-
// We can't mutate something that has multiple uses: doing so would
// require duplicating the instruction in general, which isn't profitable.
if (!I->hasOneUse()) return false;
@@ -608,7 +586,7 @@ static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
/// When canEvaluateShifted() returns true for an expression, this function
/// inserts the new computation that produces the shifted value.
static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
- InstCombiner &IC, const DataLayout &DL) {
+ InstCombinerImpl &IC, const DataLayout &DL) {
// We can always evaluate constants shifted.
if (Constant *C = dyn_cast<Constant>(V)) {
if (isLeftShift)
@@ -618,7 +596,7 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
}
Instruction *I = cast<Instruction>(V);
- IC.Worklist.push(I);
+ IC.addToWorklist(I);
switch (I->getOpcode()) {
default: llvm_unreachable("Inconsistency with CanEvaluateShifted");
@@ -666,14 +644,17 @@ static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift,
case Instruction::Add:
return Shift.getOpcode() == Instruction::Shl;
case Instruction::Or:
- case Instruction::Xor:
case Instruction::And:
return true;
+ case Instruction::Xor:
+ // Do not change a 'not' of logical shift because that would create a normal
+ // 'xor'. The 'not' is likely better for analysis, SCEV, and codegen.
+ return !(Shift.isLogicalShift() && match(BO, m_Not(m_Value())));
}
}
-Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
- BinaryOperator &I) {
+Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
+ BinaryOperator &I) {
bool isLeftShift = I.getOpcode() == Instruction::Shl;
const APInt *Op1C;
@@ -695,8 +676,8 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// See if we can simplify any instructions used by the instruction whose sole
// purpose is to compute bits we don't care about.
- unsigned TypeBits = Op0->getType()->getScalarSizeInBits();
-
+ Type *Ty = I.getType();
+ unsigned TypeBits = Ty->getScalarSizeInBits();
assert(!Op1C->uge(TypeBits) &&
"Shift over the type width should have been removed already");
@@ -704,18 +685,20 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
return FoldedShift;
// Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2))
- if (TruncInst *TI = dyn_cast<TruncInst>(Op0)) {
- Instruction *TrOp = dyn_cast<Instruction>(TI->getOperand(0));
+ if (auto *TI = dyn_cast<TruncInst>(Op0)) {
// If 'shift2' is an ashr, we would have to get the sign bit into a funny
// place. Don't try to do this transformation in this case. Also, we
// require that the input operand is a shift-by-constant so that we have
// confidence that the shifts will get folded together. We could do this
// xform in more cases, but it is unlikely to be profitable.
- if (TrOp && I.isLogicalShift() && TrOp->isShift() &&
- isa<ConstantInt>(TrOp->getOperand(1))) {
+ const APInt *TrShiftAmt;
+ if (I.isLogicalShift() &&
+ match(TI->getOperand(0), m_Shift(m_Value(), m_APInt(TrShiftAmt)))) {
+ auto *TrOp = cast<Instruction>(TI->getOperand(0));
+ Type *SrcTy = TrOp->getType();
+
// Okay, we'll do this xform. Make the shift of shift.
- Constant *ShAmt =
- ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType());
+ Constant *ShAmt = ConstantExpr::getZExt(Op1, SrcTy);
// (shift2 (shift1 & 0x00FF), c2)
Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName());
@@ -723,36 +706,27 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// part of the register be zeros. Emulate this by inserting an AND to
// clear the top bits as needed. This 'and' will usually be zapped by
// other xforms later if dead.
- unsigned SrcSize = TrOp->getType()->getScalarSizeInBits();
- unsigned DstSize = TI->getType()->getScalarSizeInBits();
- APInt MaskV(APInt::getLowBitsSet(SrcSize, DstSize));
+ unsigned SrcSize = SrcTy->getScalarSizeInBits();
+ Constant *MaskV =
+ ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcSize, TypeBits));
// The mask we constructed says what the trunc would do if occurring
// between the shifts. We want to know the effect *after* the second
// shift. We know that it is a logical shift by a constant, so adjust the
// mask as appropriate.
- if (I.getOpcode() == Instruction::Shl)
- MaskV <<= Op1C->getZExtValue();
- else {
- assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift");
- MaskV.lshrInPlace(Op1C->getZExtValue());
- }
-
+ MaskV = ConstantExpr::get(I.getOpcode(), MaskV, ShAmt);
// shift1 & 0x00FF
- Value *And = Builder.CreateAnd(NSh,
- ConstantInt::get(I.getContext(), MaskV),
- TI->getName());
-
+ Value *And = Builder.CreateAnd(NSh, MaskV, TI->getName());
// Return the value truncated to the interesting size.
- return new TruncInst(And, I.getType());
+ return new TruncInst(And, Ty);
}
}
if (Op0->hasOneUse()) {
if (BinaryOperator *Op0BO = dyn_cast<BinaryOperator>(Op0)) {
// Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C)
- Value *V1, *V2;
- ConstantInt *CC;
+ Value *V1;
+ const APInt *CC;
switch (Op0BO->getOpcode()) {
default: break;
case Instruction::Add:
@@ -770,25 +744,22 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1,
Op0BO->getOperand(1)->getName());
unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
-
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
- Constant *Mask = ConstantInt::get(I.getContext(), Bits);
- if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
- Mask = ConstantVector::getSplat(VT->getElementCount(), Mask);
+ Constant *Mask = ConstantInt::get(Ty, Bits);
return BinaryOperator::CreateAnd(X, Mask);
}
// Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C))
Value *Op0BOOp1 = Op0BO->getOperand(1);
if (isLeftShift && Op0BOOp1->hasOneUse() &&
- match(Op0BOOp1,
- m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
- m_ConstantInt(CC)))) {
- Value *YS = // (Y << C)
- Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
+ match(Op0BOOp1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
+ m_APInt(CC)))) {
+ Value *YS = // (Y << C)
+ Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
// X & (CC << C)
- Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1),
- V1->getName()+".mask");
+ Value *XM = Builder.CreateAnd(
+ V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1),
+ V1->getName() + ".mask");
return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM);
}
LLVM_FALLTHROUGH;
@@ -805,25 +776,22 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS,
Op0BO->getOperand(0)->getName());
unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
-
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
- Constant *Mask = ConstantInt::get(I.getContext(), Bits);
- if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
- Mask = ConstantVector::getSplat(VT->getElementCount(), Mask);
+ Constant *Mask = ConstantInt::get(Ty, Bits);
return BinaryOperator::CreateAnd(X, Mask);
}
// Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C)
if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
match(Op0BO->getOperand(0),
- m_And(m_OneUse(m_Shr(m_Value(V1), m_Value(V2))),
- m_ConstantInt(CC))) && V2 == Op1) {
+ m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
+ m_APInt(CC)))) {
Value *YS = // (Y << C)
- Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
+ Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
// X & (CC << C)
- Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1),
- V1->getName()+".mask");
-
+ Value *XM = Builder.CreateAnd(
+ V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1),
+ V1->getName() + ".mask");
return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS);
}
@@ -831,7 +799,6 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
}
}
-
// If the operand is a bitwise operator with a constant RHS, and the
// shift is the only use, we can pull it out of the shift.
const APInt *Op0C;
@@ -915,7 +882,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
return nullptr;
}
-Instruction *InstCombiner::visitShl(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
const SimplifyQuery Q = SQ.getWithInstruction(&I);
if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1),
@@ -955,10 +922,9 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
}
- // FIXME: we do not yet transform non-exact shr's. The backend (DAGCombine)
- // needs a few fixes for the rotate pattern recognition first.
const APInt *ShOp1;
- if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) {
+ if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1)))) &&
+ ShOp1->ult(BitWidth)) {
unsigned ShrAmt = ShOp1->getZExtValue();
if (ShrAmt < ShAmt) {
// If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1)
@@ -978,7 +944,33 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
}
}
- if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
+ if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(ShOp1)))) &&
+ ShOp1->ult(BitWidth)) {
+ unsigned ShrAmt = ShOp1->getZExtValue();
+ if (ShrAmt < ShAmt) {
+ // If C1 < C2: (X >>? C1) << C2 --> X << (C2 - C1) & (-1 << C2)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt);
+ auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
+ NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
+ Builder.Insert(NewShl);
+ APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
+ }
+ if (ShrAmt > ShAmt) {
+ // If C1 > C2: (X >>? C1) << C2 --> X >>? (C1 - C2) & (-1 << C2)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt);
+ auto *OldShr = cast<BinaryOperator>(Op0);
+ auto *NewShr =
+ BinaryOperator::Create(OldShr->getOpcode(), X, ShiftDiff);
+ NewShr->setIsExact(OldShr->isExact());
+ Builder.Insert(NewShr);
+ APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(NewShr, ConstantInt::get(Ty, Mask));
+ }
+ }
+
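
The newly added block folds a right shift followed by a left shift into a single shift plus mask. A constexpr sketch of both directions for lshr (illustrative names; not the in-tree code):

    #include <cstdint>

    // C1 > C2: (x >>u 5) << 2  ==>  (x >>u 3) & (-1 << 2)
    constexpr uint32_t shrShlA(uint32_t x) { return (x >> 5) << 2; }
    constexpr uint32_t foldedA(uint32_t x) { return (x >> 3) & (0xFFFFFFFFu << 2); }
    // C1 < C2: (x >>u 2) << 5  ==>  (x << 3) & (-1 << 5)
    constexpr uint32_t shrShlB(uint32_t x) { return (x >> 2) << 5; }
    constexpr uint32_t foldedB(uint32_t x) { return (x << 3) & (0xFFFFFFFFu << 5); }

    static_assert(shrShlA(0xDEADBEEFu) == foldedA(0xDEADBEEFu));
    static_assert(shrShlB(0xDEADBEEFu) == foldedB(0xDEADBEEFu));
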
+ if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) {
unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
// Oversized shifts are simplified to zero in InstSimplify.
if (AmtSum < BitWidth)
@@ -1037,7 +1029,7 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
return nullptr;
}
-Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -1139,6 +1131,12 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
}
}
+ // lshr i32 (X -nsw Y), 31 --> zext (X < Y)
+ Value *Y;
+ if (ShAmt == BitWidth - 1 &&
+ match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
+ return new ZExtInst(Builder.CreateICmpSLT(X, Y), Ty);
+
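
The new lshr fold above (and its ashr/sext counterpart later in this file) turns a sign-bit extraction of a no-signed-wrap subtraction into a comparison. A standalone check under the no-overflow assumption:

    #include <cassert>
    #include <cstdint>

    int main() {
      // If x - y cannot overflow (the nsw requirement), its sign bit is set
      // exactly when x < y, so lshr by 31 is just zext(x <s y).
      for (int32_t x : {-100, -1, 0, 1, 100})
        for (int32_t y : {-100, -1, 0, 1, 100})
          assert(((uint32_t)(x - y) >> 31) == (uint32_t)(x < y));
    }
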
if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) {
unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
// Oversized shifts are simplified to zero in InstSimplify.
@@ -1167,7 +1165,7 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
}
Instruction *
-InstCombiner::foldVariableSignZeroExtensionOfVariableHighBitExtract(
+InstCombinerImpl::foldVariableSignZeroExtensionOfVariableHighBitExtract(
BinaryOperator &OldAShr) {
assert(OldAShr.getOpcode() == Instruction::AShr &&
"Must be called with arithmetic right-shift instruction only.");
@@ -1235,7 +1233,7 @@ InstCombiner::foldVariableSignZeroExtensionOfVariableHighBitExtract(
return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType());
}
-Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
+Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -1301,6 +1299,12 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
return new SExtInst(NewSh, Ty);
}
+ // ashr i32 (X -nsw Y), 31 --> sext (X < Y)
+ Value *Y;
+ if (ShAmt == BitWidth - 1 &&
+ match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
+ return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
+
// If the shifted-out value is known-zero, then this is an exact shift.
if (!I.isExact() &&
MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 7cfe4c8b5892..c265516213aa 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -12,29 +12,18 @@
//===----------------------------------------------------------------------===//
#include "InstCombineInternal.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/IntrinsicInst.h"
-#include "llvm/IR/IntrinsicsAMDGPU.h"
-#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
using namespace llvm;
using namespace llvm::PatternMatch;
#define DEBUG_TYPE "instcombine"
-namespace {
-
-struct AMDGPUImageDMaskIntrinsic {
- unsigned Intr;
-};
-
-#define GET_AMDGPUImageDMaskIntrinsicTable_IMPL
-#include "InstCombineTables.inc"
-
-} // end anonymous namespace
-
/// Check to see if the specified operand of the specified instruction is a
/// constant integer. If so, check to see if there are any bits set in the
/// constant that are not demanded. If so, shrink the constant and return true.
@@ -63,7 +52,7 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
/// the instruction has any properties that allow us to simplify its operands.
-bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) {
+bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
unsigned BitWidth = Inst.getType()->getScalarSizeInBits();
KnownBits Known(BitWidth);
APInt DemandedMask(APInt::getAllOnesValue(BitWidth));
@@ -79,22 +68,20 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) {
/// This form of SimplifyDemandedBits simplifies the specified instruction
/// operand if possible, updating it in place. It returns true if it made any
/// change and false otherwise.
-bool InstCombiner::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
- const APInt &DemandedMask,
- KnownBits &Known,
- unsigned Depth) {
+bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
+ const APInt &DemandedMask,
+ KnownBits &Known, unsigned Depth) {
Use &U = I->getOperandUse(OpNo);
Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, Known,
Depth, I);
if (!NewVal) return false;
if (Instruction* OpInst = dyn_cast<Instruction>(U))
salvageDebugInfo(*OpInst);
-
+
replaceUse(U, NewVal);
return true;
}
-
/// This function attempts to replace V with a simpler value based on the
/// demanded bits. When this function is called, it is known that only the bits
/// set in DemandedMask of the result of V are ever used downstream.
@@ -118,11 +105,12 @@ bool InstCombiner::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
/// operands based on the information about what bits are demanded. This returns
/// some other non-null value if it found out that V is equal to another value
/// in the context where the specified bits are demanded, but not for all users.
-Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
- KnownBits &Known, unsigned Depth,
- Instruction *CxtI) {
+Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
+ KnownBits &Known,
+ unsigned Depth,
+ Instruction *CxtI) {
assert(V != nullptr && "Null pointer of Value???");
- assert(Depth <= 6 && "Limit Search Depth");
+ assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
uint32_t BitWidth = DemandedMask.getBitWidth();
Type *VTy = V->getType();
assert(
@@ -139,7 +127,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (DemandedMask.isNullValue()) // Not demanding any bits from V.
return UndefValue::get(VTy);
- if (Depth == 6) // Limit search depth.
+ if (Depth == MaxAnalysisRecursionDepth)
+ return nullptr;
+
+ if (isa<ScalableVectorType>(VTy))
return nullptr;
Instruction *I = dyn_cast<Instruction>(V);
@@ -268,35 +259,44 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return InsertNewInstWith(And, *I);
}
- // If the RHS is a constant, see if we can simplify it.
- // FIXME: for XOR, we prefer to force bits to 1 if they will make a -1.
- if (ShrinkDemandedConstant(I, 1, DemandedMask))
- return I;
+ // If the RHS is a constant, see if we can change it. Don't alter a -1
+ // constant because that's a canonical 'not' op, and that is better for
+ // combining, SCEV, and codegen.
+ const APInt *C;
+ if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnesValue()) {
+ if ((*C | ~DemandedMask).isAllOnesValue()) {
+ // Force bits to 1 to create a 'not' op.
+ I->setOperand(1, ConstantInt::getAllOnesValue(VTy));
+ return I;
+ }
+ // If we can't turn this into a 'not', try to shrink the constant.
+ if (ShrinkDemandedConstant(I, 1, DemandedMask))
+ return I;
+ }
// If our LHS is an 'and' and if it has one use, and if any of the bits we
// are flipping are known to be set, then the xor is just resetting those
// bits to zero. We can just knock out bits from the 'and' and the 'xor',
// simplifying both of them.
- if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0)))
+ if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0))) {
+ ConstantInt *AndRHS, *XorRHS;
if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() &&
- isa<ConstantInt>(I->getOperand(1)) &&
- isa<ConstantInt>(LHSInst->getOperand(1)) &&
+ match(I->getOperand(1), m_ConstantInt(XorRHS)) &&
+ match(LHSInst->getOperand(1), m_ConstantInt(AndRHS)) &&
(LHSKnown.One & RHSKnown.One & DemandedMask) != 0) {
- ConstantInt *AndRHS = cast<ConstantInt>(LHSInst->getOperand(1));
- ConstantInt *XorRHS = cast<ConstantInt>(I->getOperand(1));
APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask);
Constant *AndC =
- ConstantInt::get(I->getType(), NewMask & AndRHS->getValue());
+ ConstantInt::get(I->getType(), NewMask & AndRHS->getValue());
Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
InsertNewInstWith(NewAnd, *I);
Constant *XorC =
- ConstantInt::get(I->getType(), NewMask & XorRHS->getValue());
+ ConstantInt::get(I->getType(), NewMask & XorRHS->getValue());
Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC);
return InsertNewInstWith(NewXor, *I);
}
-
+ }
break;
}
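
The reworked xor handling above prefers widening the constant to all-ones (a canonical 'not') whenever every demanded bit of the constant is already set. A constexpr sketch of the precondition and of the resulting equivalence on the demanded bits, with illustrative values:

    #include <cstdint>

    constexpr uint32_t DemandedMask = 0x0000FFFFu;
    constexpr uint32_t C = 0x1234FFFFu;                 // has every demanded bit set
    static_assert((C | ~DemandedMask) == 0xFFFFFFFFu);  // the transform's precondition

    // On the demanded bits, x ^ C and x ^ -1 agree, so C may become all-ones.
    constexpr uint32_t X = 0xCAFEBABEu;
    static_assert(((X ^ C) & DemandedMask) == ((X ^ 0xFFFFFFFFu) & DemandedMask));
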
case Instruction::Select: {
@@ -339,7 +339,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// we can. This helps not break apart (or helps put back together)
// canonical patterns like min and max.
auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo,
- APInt DemandedMask) {
+ const APInt &DemandedMask) {
const APInt *SelC;
if (!match(I->getOperand(OpNo), m_APInt(SelC)))
return false;
@@ -367,8 +367,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return I;
// Only known if known in both the LHS and RHS.
- Known.One = RHSKnown.One & LHSKnown.One;
- Known.Zero = RHSKnown.Zero & LHSKnown.Zero;
+ Known = KnownBits::commonBits(LHSKnown, RHSKnown);
break;
}
case Instruction::ZExt:
@@ -391,7 +390,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (VectorType *DstVTy = dyn_cast<VectorType>(I->getType())) {
if (VectorType *SrcVTy =
dyn_cast<VectorType>(I->getOperand(0)->getType())) {
- if (DstVTy->getNumElements() != SrcVTy->getNumElements())
+ if (cast<FixedVectorType>(DstVTy)->getNumElements() !=
+ cast<FixedVectorType>(SrcVTy)->getNumElements())
// Don't touch a bitcast between vectors of different element counts.
return nullptr;
} else
@@ -669,8 +669,9 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
break;
}
- case Instruction::SRem:
- if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+ case Instruction::SRem: {
+ ConstantInt *Rem;
+ if (match(I->getOperand(1), m_ConstantInt(Rem))) {
// X % -1 demands all the bits because we don't want to introduce
// INT_MIN % -1 (== undef) by accident.
if (Rem->isMinusOne())
@@ -713,6 +714,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
Known.makeNonNegative();
}
break;
+ }
case Instruction::URem: {
KnownBits Known2(BitWidth);
APInt AllOnes = APInt::getAllOnesValue(BitWidth);
@@ -728,7 +730,6 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
bool KnownBitsComputed = false;
if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
switch (II->getIntrinsicID()) {
- default: break;
case Intrinsic::bswap: {
// If the only bits demanded come from one byte of the bswap result,
// just shift the input byte into position to eliminate the bswap.
@@ -784,39 +785,14 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
KnownBitsComputed = true;
break;
}
- case Intrinsic::x86_mmx_pmovmskb:
- case Intrinsic::x86_sse_movmsk_ps:
- case Intrinsic::x86_sse2_movmsk_pd:
- case Intrinsic::x86_sse2_pmovmskb_128:
- case Intrinsic::x86_avx_movmsk_ps_256:
- case Intrinsic::x86_avx_movmsk_pd_256:
- case Intrinsic::x86_avx2_pmovmskb: {
- // MOVMSK copies the vector elements' sign bits to the low bits
- // and zeros the high bits.
- unsigned ArgWidth;
- if (II->getIntrinsicID() == Intrinsic::x86_mmx_pmovmskb) {
- ArgWidth = 8; // Arg is x86_mmx, but treated as <8 x i8>.
- } else {
- auto Arg = II->getArgOperand(0);
- auto ArgType = cast<VectorType>(Arg->getType());
- ArgWidth = ArgType->getNumElements();
- }
-
- // If we don't need any of low bits then return zero,
- // we know that DemandedMask is non-zero already.
- APInt DemandedElts = DemandedMask.zextOrTrunc(ArgWidth);
- if (DemandedElts.isNullValue())
- return ConstantInt::getNullValue(VTy);
-
- // We know that the upper bits are set to zero.
- Known.Zero.setBitsFrom(ArgWidth);
- KnownBitsComputed = true;
+ default: {
+ // Handle target specific intrinsics
+ Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic(
+ *II, DemandedMask, Known, KnownBitsComputed);
+ if (V.hasValue())
+ return V.getValue();
break;
}
- case Intrinsic::x86_sse42_crc32_64_64:
- Known.Zero.setBitsFrom(32);
- KnownBitsComputed = true;
- break;
}
}
@@ -836,11 +812,9 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
/// Helper routine of SimplifyDemandedUseBits. It computes Known
/// bits. It also tries to handle simplifications that can be done based on
/// DemandedMask, but without modifying the Instruction.
-Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I,
- const APInt &DemandedMask,
- KnownBits &Known,
- unsigned Depth,
- Instruction *CxtI) {
+Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
+ Instruction *I, const APInt &DemandedMask, KnownBits &Known, unsigned Depth,
+ Instruction *CxtI) {
unsigned BitWidth = DemandedMask.getBitWidth();
Type *ITy = I->getType();
@@ -925,6 +899,33 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I,
break;
}
+ case Instruction::AShr: {
+ // Compute the Known bits to simplify things downstream.
+ computeKnownBits(I, Known, Depth, CxtI);
+
+ // If this user is only demanding bits that we know, return the known
+ // constant.
+ if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
+ return Constant::getIntegerValue(ITy, Known.One);
+
+    // If operand 0 of the right shift is the result of a left shift by the
+    // same amount, this is probably a zero/sign extension, which may be
+    // unnecessary if we do not demand any of the new sign bits. In that case,
+    // return the original operand instead.
+ const APInt *ShiftRC;
+ const APInt *ShiftLC;
+ Value *X;
+ unsigned BitWidth = DemandedMask.getBitWidth();
+ if (match(I,
+ m_AShr(m_Shl(m_Value(X), m_APInt(ShiftLC)), m_APInt(ShiftRC))) &&
+ ShiftLC == ShiftRC &&
+ DemandedMask.isSubsetOf(APInt::getLowBitsSet(
+ BitWidth, BitWidth - ShiftRC->getZExtValue()))) {
+ return X;
+ }
+
+ break;
+ }
default:
// Compute the Known bits to simplify things downstream.
computeKnownBits(I, Known, Depth, CxtI);
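The new AShr handling above lets a multi-use ashr be bypassed when a use only demands bits that the shift pair does not actually change. A minimal LLVM IR sketch of the shape it targets (function and value names are illustrative, not taken from the patch):

    ; %a sign-extends the low 24 bits of %x and has two uses. The trunc only
    ; demands the low 8 bits, so for that use %a can be replaced by %x directly,
    ; leaving the store untouched.
    define i8 @ashr_shl_demo(i32 %x, i32* %p) {
      %s = shl i32 %x, 8
      %a = ashr i32 %s, 8
      store i32 %a, i32* %p
      %t = trunc i32 %a to i8
      ret i8 %t
    }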
@@ -940,7 +941,6 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I,
return nullptr;
}
-
/// Helper routine of SimplifyDemandedUseBits. It tries to simplify
/// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into
/// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign
@@ -958,11 +958,9 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I,
///
/// As with SimplifyDemandedUseBits, it returns NULL if the simplification was
/// not successful.
-Value *
-InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1,
- Instruction *Shl, const APInt &ShlOp1,
- const APInt &DemandedMask,
- KnownBits &Known) {
+Value *InstCombinerImpl::simplifyShrShlDemandedBits(
+ Instruction *Shr, const APInt &ShrOp1, Instruction *Shl,
+ const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known) {
if (!ShlOp1 || !ShrOp1)
return nullptr; // No-op.
@@ -1022,156 +1020,9 @@ InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1,
return nullptr;
}
-/// Implement SimplifyDemandedVectorElts for amdgcn buffer and image intrinsics.
-///
-/// Note: This only supports non-TFE/LWE image intrinsic calls; those have
-/// struct returns.
-Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II,
- APInt DemandedElts,
- int DMaskIdx) {
-
- // FIXME: Allow v3i16/v3f16 in buffer intrinsics when the types are fully supported.
- if (DMaskIdx < 0 &&
- II->getType()->getScalarSizeInBits() != 32 &&
- DemandedElts.getActiveBits() == 3)
- return nullptr;
-
- auto *IIVTy = cast<VectorType>(II->getType());
- unsigned VWidth = IIVTy->getNumElements();
- if (VWidth == 1)
- return nullptr;
-
- IRBuilderBase::InsertPointGuard Guard(Builder);
- Builder.SetInsertPoint(II);
-
- // Assume the arguments are unchanged and later override them, if needed.
- SmallVector<Value *, 16> Args(II->arg_begin(), II->arg_end());
-
- if (DMaskIdx < 0) {
- // Buffer case.
-
- const unsigned ActiveBits = DemandedElts.getActiveBits();
- const unsigned UnusedComponentsAtFront = DemandedElts.countTrailingZeros();
-
- // Start assuming the prefix of elements is demanded, but possibly clear
- // some other bits if there are trailing zeros (unused components at front)
- // and update offset.
- DemandedElts = (1 << ActiveBits) - 1;
-
- if (UnusedComponentsAtFront > 0) {
- static const unsigned InvalidOffsetIdx = 0xf;
-
- unsigned OffsetIdx;
- switch (II->getIntrinsicID()) {
- case Intrinsic::amdgcn_raw_buffer_load:
- OffsetIdx = 1;
- break;
- case Intrinsic::amdgcn_s_buffer_load:
- // If resulting type is vec3, there is no point in trimming the
- // load with updated offset, as the vec3 would most likely be widened to
- // vec4 anyway during lowering.
- if (ActiveBits == 4 && UnusedComponentsAtFront == 1)
- OffsetIdx = InvalidOffsetIdx;
- else
- OffsetIdx = 1;
- break;
- case Intrinsic::amdgcn_struct_buffer_load:
- OffsetIdx = 2;
- break;
- default:
- // TODO: handle tbuffer* intrinsics.
- OffsetIdx = InvalidOffsetIdx;
- break;
- }
-
- if (OffsetIdx != InvalidOffsetIdx) {
- // Clear demanded bits and update the offset.
- DemandedElts &= ~((1 << UnusedComponentsAtFront) - 1);
- auto *Offset = II->getArgOperand(OffsetIdx);
- unsigned SingleComponentSizeInBits =
- getDataLayout().getTypeSizeInBits(II->getType()->getScalarType());
- unsigned OffsetAdd =
- UnusedComponentsAtFront * SingleComponentSizeInBits / 8;
- auto *OffsetAddVal = ConstantInt::get(Offset->getType(), OffsetAdd);
- Args[OffsetIdx] = Builder.CreateAdd(Offset, OffsetAddVal);
- }
- }
- } else {
- // Image case.
-
- ConstantInt *DMask = cast<ConstantInt>(II->getArgOperand(DMaskIdx));
- unsigned DMaskVal = DMask->getZExtValue() & 0xf;
-
- // Mask off values that are undefined because the dmask doesn't cover them
- DemandedElts &= (1 << countPopulation(DMaskVal)) - 1;
-
- unsigned NewDMaskVal = 0;
- unsigned OrigLoadIdx = 0;
- for (unsigned SrcIdx = 0; SrcIdx < 4; ++SrcIdx) {
- const unsigned Bit = 1 << SrcIdx;
- if (!!(DMaskVal & Bit)) {
- if (!!DemandedElts[OrigLoadIdx])
- NewDMaskVal |= Bit;
- OrigLoadIdx++;
- }
- }
-
- if (DMaskVal != NewDMaskVal)
- Args[DMaskIdx] = ConstantInt::get(DMask->getType(), NewDMaskVal);
- }
-
- unsigned NewNumElts = DemandedElts.countPopulation();
- if (!NewNumElts)
- return UndefValue::get(II->getType());
-
- if (NewNumElts >= VWidth && DemandedElts.isMask()) {
- if (DMaskIdx >= 0)
- II->setArgOperand(DMaskIdx, Args[DMaskIdx]);
- return nullptr;
- }
-
- // Validate function argument and return types, extracting overloaded types
- // along the way.
- SmallVector<Type *, 6> OverloadTys;
- if (!Intrinsic::getIntrinsicSignature(II->getCalledFunction(), OverloadTys))
- return nullptr;
-
- Module *M = II->getParent()->getParent()->getParent();
- Type *EltTy = IIVTy->getElementType();
- Type *NewTy =
- (NewNumElts == 1) ? EltTy : FixedVectorType::get(EltTy, NewNumElts);
-
- OverloadTys[0] = NewTy;
- Function *NewIntrin =
- Intrinsic::getDeclaration(M, II->getIntrinsicID(), OverloadTys);
-
- CallInst *NewCall = Builder.CreateCall(NewIntrin, Args);
- NewCall->takeName(II);
- NewCall->copyMetadata(*II);
-
- if (NewNumElts == 1) {
- return Builder.CreateInsertElement(UndefValue::get(II->getType()), NewCall,
- DemandedElts.countTrailingZeros());
- }
-
- SmallVector<int, 8> EltMask;
- unsigned NewLoadIdx = 0;
- for (unsigned OrigLoadIdx = 0; OrigLoadIdx < VWidth; ++OrigLoadIdx) {
- if (!!DemandedElts[OrigLoadIdx])
- EltMask.push_back(NewLoadIdx++);
- else
- EltMask.push_back(NewNumElts);
- }
-
- Value *Shuffle =
- Builder.CreateShuffleVector(NewCall, UndefValue::get(NewTy), EltMask);
-
- return Shuffle;
-}
-
/// The specified value produces a vector with any number of elements.
-/// This method analyzes which elements of the operand are undef and returns
-/// that information in UndefElts.
+/// This method analyzes which elements of the operand are undef or poison and
+/// returns that information in UndefElts.
///
/// DemandedElts contains the set of elements that are actually used by the
/// caller, and by default (AllowMultipleUsers equals false) the value is
@@ -1182,10 +1033,11 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II,
/// If the information about demanded elements can be used to simplify the
/// operation, the operation is simplified, then the resultant value is
/// returned. This returns null if no change was made.
-Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
- APInt &UndefElts,
- unsigned Depth,
- bool AllowMultipleUsers) {
+Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
+ APInt DemandedElts,
+ APInt &UndefElts,
+ unsigned Depth,
+ bool AllowMultipleUsers) {
// Cannot analyze scalable type. The number of vector elements is not a
// compile-time constant.
if (isa<ScalableVectorType>(V->getType()))
@@ -1196,14 +1048,14 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
if (isa<UndefValue>(V)) {
- // If the entire vector is undefined, just return this info.
+ // If the entire vector is undef or poison, just return this info.
UndefElts = EltMask;
return nullptr;
}
- if (DemandedElts.isNullValue()) { // If nothing is demanded, provide undef.
+ if (DemandedElts.isNullValue()) { // If nothing is demanded, provide poison.
UndefElts = EltMask;
- return UndefValue::get(V->getType());
+ return PoisonValue::get(V->getType());
}
UndefElts = 0;
@@ -1215,11 +1067,11 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
return nullptr;
Type *EltTy = cast<VectorType>(V->getType())->getElementType();
- Constant *Undef = UndefValue::get(EltTy);
+ Constant *Poison = PoisonValue::get(EltTy);
SmallVector<Constant*, 16> Elts;
for (unsigned i = 0; i != VWidth; ++i) {
- if (!DemandedElts[i]) { // If not demanded, set to undef.
- Elts.push_back(Undef);
+ if (!DemandedElts[i]) { // If not demanded, set to poison.
+ Elts.push_back(Poison);
UndefElts.setBit(i);
continue;
}
@@ -1227,12 +1079,9 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
Constant *Elt = C->getAggregateElement(i);
if (!Elt) return nullptr;
- if (isa<UndefValue>(Elt)) { // Already undef.
- Elts.push_back(Undef);
+ Elts.push_back(Elt);
+ if (isa<UndefValue>(Elt)) // Already undef or poison.
UndefElts.setBit(i);
- } else { // Otherwise, defined.
- Elts.push_back(Elt);
- }
}
// If we changed the constant, return it.
@@ -1292,12 +1141,12 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
};
if (mayIndexStructType(cast<GetElementPtrInst>(*I)))
break;
-
+
// Conservatively track the demanded elements back through any vector
// operands we may have. We know there must be at least one, or we
// wouldn't have a vector result to get here. Note that we intentionally
// merge the undef bits here since gepping with either an undef base or
- // index results in undef.
+ // index results in undef.
for (unsigned i = 0; i < I->getNumOperands(); i++) {
if (isa<UndefValue>(I->getOperand(i))) {
// If the entire vector is undefined, just return this info.
@@ -1331,6 +1180,19 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
if (IdxNo < VWidth)
PreInsertDemandedElts.clearBit(IdxNo);
+ // If we only demand the element that is being inserted and that element
+ // was extracted from the same index in another vector with the same type,
+ // replace this insert with that other vector.
+ // Note: This is attempted before the call to simplifyAndSetOp because that
+ // may change UndefElts to a value that does not match with Vec.
+ Value *Vec;
+ if (PreInsertDemandedElts == 0 &&
+ match(I->getOperand(1),
+ m_ExtractElt(m_Value(Vec), m_SpecificInt(IdxNo))) &&
+ Vec->getType() == I->getType()) {
+ return Vec;
+ }
+
simplifyAndSetOp(I, 0, PreInsertDemandedElts, UndefElts);
// If this is inserting an element that isn't demanded, remove this
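The insertelement change above reuses an existing vector when the only demanded lane is the one being re-inserted from that same vector. A minimal LLVM IR sketch of the shape the check is aimed at (names are illustrative; other InstCombine folds may catch variants of this pattern through different routes):

    ; The shuffle demands only lane 1 of %ins, and the scalar inserted at lane 1
    ; was extracted from lane 1 of %v, so %ins can be replaced by %v for this use.
    define <2 x i32> @inselt_extelt_demo(<4 x i32> %v, <4 x i32> %w) {
      %e = extractelement <4 x i32> %v, i64 1
      %ins = insertelement <4 x i32> %w, i32 %e, i64 1
      %r = shufflevector <4 x i32> %ins, <4 x i32> undef, <2 x i32> <i32 1, i32 1>
      ret <2 x i32> %r
    }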
@@ -1349,8 +1211,8 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
assert(Shuffle->getOperand(0)->getType() ==
Shuffle->getOperand(1)->getType() &&
"Expected shuffle operands to have same type");
- unsigned OpWidth =
- cast<VectorType>(Shuffle->getOperand(0)->getType())->getNumElements();
+ unsigned OpWidth = cast<FixedVectorType>(Shuffle->getOperand(0)->getType())
+ ->getNumElements();
// Handle trivial case of a splat. Only check the first element of LHS
// operand.
if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) &&
@@ -1451,7 +1313,8 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
// this constant vector to single insertelement instruction.
// shufflevector V, C, <v1, v2, .., ci, .., vm> ->
// insertelement V, C[ci], ci-n
- if (OpWidth == Shuffle->getType()->getNumElements()) {
+ if (OpWidth ==
+ cast<FixedVectorType>(Shuffle->getType())->getNumElements()) {
Value *Op = nullptr;
Constant *Value = nullptr;
unsigned Idx = -1u;
@@ -1538,7 +1401,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
// Vector->vector casts only.
VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType());
if (!VTy) break;
- unsigned InVWidth = VTy->getNumElements();
+ unsigned InVWidth = cast<FixedVectorType>(VTy)->getNumElements();
APInt InputDemandedElts(InVWidth, 0);
UndefElts2 = APInt(InVWidth, 0);
unsigned Ratio;
@@ -1621,227 +1484,19 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
if (II->getIntrinsicID() == Intrinsic::masked_gather)
simplifyAndSetOp(II, 0, DemandedPtrs, UndefElts2);
simplifyAndSetOp(II, 3, DemandedPassThrough, UndefElts3);
-
+
// Output elements are undefined if the element from both sources are.
// TODO: can strengthen via mask as well.
UndefElts = UndefElts2 & UndefElts3;
break;
}
- case Intrinsic::x86_xop_vfrcz_ss:
- case Intrinsic::x86_xop_vfrcz_sd:
- // The instructions for these intrinsics are speced to zero upper bits not
- // pass them through like other scalar intrinsics. So we shouldn't just
- // use Arg0 if DemandedElts[0] is clear like we do for other intrinsics.
- // Instead we should return a zero vector.
- if (!DemandedElts[0]) {
- Worklist.push(II);
- return ConstantAggregateZero::get(II->getType());
- }
-
- // Only the lower element is used.
- DemandedElts = 1;
- simplifyAndSetOp(II, 0, DemandedElts, UndefElts);
-
- // Only the lower element is undefined. The high elements are zero.
- UndefElts = UndefElts[0];
- break;
-
- // Unary scalar-as-vector operations that work column-wise.
- case Intrinsic::x86_sse_rcp_ss:
- case Intrinsic::x86_sse_rsqrt_ss:
- simplifyAndSetOp(II, 0, DemandedElts, UndefElts);
-
- // If lowest element of a scalar op isn't used then use Arg0.
- if (!DemandedElts[0]) {
- Worklist.push(II);
- return II->getArgOperand(0);
- }
- // TODO: If only low elt lower SQRT to FSQRT (with rounding/exceptions
- // checks).
- break;
-
- // Binary scalar-as-vector operations that work column-wise. The high
- // elements come from operand 0. The low element is a function of both
- // operands.
- case Intrinsic::x86_sse_min_ss:
- case Intrinsic::x86_sse_max_ss:
- case Intrinsic::x86_sse_cmp_ss:
- case Intrinsic::x86_sse2_min_sd:
- case Intrinsic::x86_sse2_max_sd:
- case Intrinsic::x86_sse2_cmp_sd: {
- simplifyAndSetOp(II, 0, DemandedElts, UndefElts);
-
- // If lowest element of a scalar op isn't used then use Arg0.
- if (!DemandedElts[0]) {
- Worklist.push(II);
- return II->getArgOperand(0);
- }
-
- // Only lower element is used for operand 1.
- DemandedElts = 1;
- simplifyAndSetOp(II, 1, DemandedElts, UndefElts2);
-
- // Lower element is undefined if both lower elements are undefined.
- // Consider things like undef&0. The result is known zero, not undef.
- if (!UndefElts2[0])
- UndefElts.clearBit(0);
-
- break;
- }
-
- // Binary scalar-as-vector operations that work column-wise. The high
- // elements come from operand 0 and the low element comes from operand 1.
- case Intrinsic::x86_sse41_round_ss:
- case Intrinsic::x86_sse41_round_sd: {
- // Don't use the low element of operand 0.
- APInt DemandedElts2 = DemandedElts;
- DemandedElts2.clearBit(0);
- simplifyAndSetOp(II, 0, DemandedElts2, UndefElts);
-
- // If lowest element of a scalar op isn't used then use Arg0.
- if (!DemandedElts[0]) {
- Worklist.push(II);
- return II->getArgOperand(0);
- }
-
- // Only lower element is used for operand 1.
- DemandedElts = 1;
- simplifyAndSetOp(II, 1, DemandedElts, UndefElts2);
-
- // Take the high undef elements from operand 0 and take the lower element
- // from operand 1.
- UndefElts.clearBit(0);
- UndefElts |= UndefElts2[0];
- break;
- }
-
- // Three input scalar-as-vector operations that work column-wise. The high
- // elements come from operand 0 and the low element is a function of all
- // three inputs.
- case Intrinsic::x86_avx512_mask_add_ss_round:
- case Intrinsic::x86_avx512_mask_div_ss_round:
- case Intrinsic::x86_avx512_mask_mul_ss_round:
- case Intrinsic::x86_avx512_mask_sub_ss_round:
- case Intrinsic::x86_avx512_mask_max_ss_round:
- case Intrinsic::x86_avx512_mask_min_ss_round:
- case Intrinsic::x86_avx512_mask_add_sd_round:
- case Intrinsic::x86_avx512_mask_div_sd_round:
- case Intrinsic::x86_avx512_mask_mul_sd_round:
- case Intrinsic::x86_avx512_mask_sub_sd_round:
- case Intrinsic::x86_avx512_mask_max_sd_round:
- case Intrinsic::x86_avx512_mask_min_sd_round:
- simplifyAndSetOp(II, 0, DemandedElts, UndefElts);
-
- // If lowest element of a scalar op isn't used then use Arg0.
- if (!DemandedElts[0]) {
- Worklist.push(II);
- return II->getArgOperand(0);
- }
-
- // Only lower element is used for operand 1 and 2.
- DemandedElts = 1;
- simplifyAndSetOp(II, 1, DemandedElts, UndefElts2);
- simplifyAndSetOp(II, 2, DemandedElts, UndefElts3);
-
- // Lower element is undefined if all three lower elements are undefined.
- // Consider things like undef&0. The result is known zero, not undef.
- if (!UndefElts2[0] || !UndefElts3[0])
- UndefElts.clearBit(0);
-
- break;
-
- case Intrinsic::x86_sse2_packssdw_128:
- case Intrinsic::x86_sse2_packsswb_128:
- case Intrinsic::x86_sse2_packuswb_128:
- case Intrinsic::x86_sse41_packusdw:
- case Intrinsic::x86_avx2_packssdw:
- case Intrinsic::x86_avx2_packsswb:
- case Intrinsic::x86_avx2_packusdw:
- case Intrinsic::x86_avx2_packuswb:
- case Intrinsic::x86_avx512_packssdw_512:
- case Intrinsic::x86_avx512_packsswb_512:
- case Intrinsic::x86_avx512_packusdw_512:
- case Intrinsic::x86_avx512_packuswb_512: {
- auto *Ty0 = II->getArgOperand(0)->getType();
- unsigned InnerVWidth = cast<VectorType>(Ty0)->getNumElements();
- assert(VWidth == (InnerVWidth * 2) && "Unexpected input size");
-
- unsigned NumLanes = Ty0->getPrimitiveSizeInBits() / 128;
- unsigned VWidthPerLane = VWidth / NumLanes;
- unsigned InnerVWidthPerLane = InnerVWidth / NumLanes;
-
- // Per lane, pack the elements of the first input and then the second.
- // e.g.
- // v8i16 PACK(v4i32 X, v4i32 Y) - (X[0..3],Y[0..3])
- // v32i8 PACK(v16i16 X, v16i16 Y) - (X[0..7],Y[0..7]),(X[8..15],Y[8..15])
- for (int OpNum = 0; OpNum != 2; ++OpNum) {
- APInt OpDemandedElts(InnerVWidth, 0);
- for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
- unsigned LaneIdx = Lane * VWidthPerLane;
- for (unsigned Elt = 0; Elt != InnerVWidthPerLane; ++Elt) {
- unsigned Idx = LaneIdx + Elt + InnerVWidthPerLane * OpNum;
- if (DemandedElts[Idx])
- OpDemandedElts.setBit((Lane * InnerVWidthPerLane) + Elt);
- }
- }
-
- // Demand elements from the operand.
- APInt OpUndefElts(InnerVWidth, 0);
- simplifyAndSetOp(II, OpNum, OpDemandedElts, OpUndefElts);
-
- // Pack the operand's UNDEF elements, one lane at a time.
- OpUndefElts = OpUndefElts.zext(VWidth);
- for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
- APInt LaneElts = OpUndefElts.lshr(InnerVWidthPerLane * Lane);
- LaneElts = LaneElts.getLoBits(InnerVWidthPerLane);
- LaneElts <<= InnerVWidthPerLane * (2 * Lane + OpNum);
- UndefElts |= LaneElts;
- }
- }
- break;
- }
-
- // PSHUFB
- case Intrinsic::x86_ssse3_pshuf_b_128:
- case Intrinsic::x86_avx2_pshuf_b:
- case Intrinsic::x86_avx512_pshuf_b_512:
- // PERMILVAR
- case Intrinsic::x86_avx_vpermilvar_ps:
- case Intrinsic::x86_avx_vpermilvar_ps_256:
- case Intrinsic::x86_avx512_vpermilvar_ps_512:
- case Intrinsic::x86_avx_vpermilvar_pd:
- case Intrinsic::x86_avx_vpermilvar_pd_256:
- case Intrinsic::x86_avx512_vpermilvar_pd_512:
- // PERMV
- case Intrinsic::x86_avx2_permd:
- case Intrinsic::x86_avx2_permps: {
- simplifyAndSetOp(II, 1, DemandedElts, UndefElts);
- break;
- }
-
- // SSE4A instructions leave the upper 64-bits of the 128-bit result
- // in an undefined state.
- case Intrinsic::x86_sse4a_extrq:
- case Intrinsic::x86_sse4a_extrqi:
- case Intrinsic::x86_sse4a_insertq:
- case Intrinsic::x86_sse4a_insertqi:
- UndefElts.setHighBits(VWidth / 2);
- break;
- case Intrinsic::amdgcn_buffer_load:
- case Intrinsic::amdgcn_buffer_load_format:
- case Intrinsic::amdgcn_raw_buffer_load:
- case Intrinsic::amdgcn_raw_buffer_load_format:
- case Intrinsic::amdgcn_raw_tbuffer_load:
- case Intrinsic::amdgcn_s_buffer_load:
- case Intrinsic::amdgcn_struct_buffer_load:
- case Intrinsic::amdgcn_struct_buffer_load_format:
- case Intrinsic::amdgcn_struct_tbuffer_load:
- case Intrinsic::amdgcn_tbuffer_load:
- return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts);
default: {
- if (getAMDGPUImageDMaskIntrinsic(II->getIntrinsicID()))
- return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts, 0);
-
+ // Handle target specific intrinsics
+ Optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic(
+ *II, DemandedElts, UndefElts, UndefElts2, UndefElts3,
+ simplifyAndSetOp);
+ if (V.hasValue())
+ return V.getValue();
break;
}
} // switch on IntrinsicID
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineTables.td b/llvm/lib/Transforms/InstCombine/InstCombineTables.td
deleted file mode 100644
index 98b2adc442fa..000000000000
--- a/llvm/lib/Transforms/InstCombine/InstCombineTables.td
+++ /dev/null
@@ -1,11 +0,0 @@
-include "llvm/TableGen/SearchableTable.td"
-include "llvm/IR/Intrinsics.td"
-
-def AMDGPUImageDMaskIntrinsicTable : GenericTable {
- let FilterClass = "AMDGPUImageDMaskIntrinsic";
- let Fields = ["Intr"];
-
- let PrimaryKey = ["Intr"];
- let PrimaryKeyName = "getAMDGPUImageDMaskIntrinsic";
- let PrimaryKeyEarlyOut = 1;
-}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index ff70347569ab..06f22cdfb63d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -18,6 +18,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/BasicBlock.h"
@@ -35,6 +36,7 @@
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <cassert>
#include <cstdint>
#include <iterator>
@@ -45,6 +47,10 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
+STATISTIC(NumAggregateReconstructionsSimplified,
+ "Number of aggregate reconstructions turned into reuse of the "
+ "original aggregate");
+
/// Return true if the value is cheaper to scalarize than it is to leave as a
/// vector operation. IsConstantExtractIndex indicates whether we are extracting
/// one known element from a vector constant.
@@ -85,7 +91,8 @@ static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) {
// If we have a PHI node with a vector type that is only used to feed
// itself and be an operand of extractelement at a constant location,
// try to replace the PHI of the vector type with a PHI of a scalar type.
-Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) {
+Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
+ PHINode *PN) {
SmallVector<Instruction *, 2> Extracts;
// The users we want the PHI to have are:
// 1) The EI ExtractElement (we already know this)
@@ -178,15 +185,19 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext,
// extelt (bitcast VecX), IndexC --> bitcast X[IndexC]
auto *SrcTy = cast<VectorType>(X->getType());
Type *DestTy = Ext.getType();
- unsigned NumSrcElts = SrcTy->getNumElements();
- unsigned NumElts = Ext.getVectorOperandType()->getNumElements();
+ ElementCount NumSrcElts = SrcTy->getElementCount();
+ ElementCount NumElts =
+ cast<VectorType>(Ext.getVectorOperandType())->getElementCount();
if (NumSrcElts == NumElts)
if (Value *Elt = findScalarElement(X, ExtIndexC))
return new BitCastInst(Elt, DestTy);
+ assert(NumSrcElts.isScalable() == NumElts.isScalable() &&
+ "Src and Dst must be the same sort of vector type");
+
// If the source elements are wider than the destination, try to shift and
// truncate a subset of scalar bits of an insert op.
- if (NumSrcElts < NumElts) {
+ if (NumSrcElts.getKnownMinValue() < NumElts.getKnownMinValue()) {
Value *Scalar;
uint64_t InsIndexC;
if (!match(X, m_InsertElt(m_Value(), m_Value(Scalar),
@@ -197,7 +208,8 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext,
// into. Example: if we inserted element 1 of a <2 x i64> and we are
// extracting an i16 (narrowing ratio = 4), then this extract must be from 1
// of elements 4-7 of the bitcasted vector.
- unsigned NarrowingRatio = NumElts / NumSrcElts;
+ unsigned NarrowingRatio =
+ NumElts.getKnownMinValue() / NumSrcElts.getKnownMinValue();
if (ExtIndexC / NarrowingRatio != InsIndexC)
return nullptr;
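To make the NarrowingRatio check above concrete, here is a minimal LLVM IR sketch (illustrative names only): extracting a narrow lane through a bitcast of an insert is only considered when that lane falls inside the wide element that was inserted.

    ; <2 x i64> reinterpreted as <8 x i16>: NarrowingRatio = 8 / 2 = 4, and
    ; ExtIndexC / NarrowingRatio = 5 / 4 = 1, which matches the insert index,
    ; so the extract can be computed from the scalar %s by shift and truncate.
    define i16 @bitcast_extelt_demo(<2 x i64> %v, i64 %s) {
      %ins = insertelement <2 x i64> %v, i64 %s, i64 1
      %bc = bitcast <2 x i64> %ins to <8 x i16>
      %e = extractelement <8 x i16> %bc, i64 5
      ret i16 %e
    }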
@@ -259,7 +271,7 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext,
/// Find elements of V demanded by UserInstr.
static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
- unsigned VWidth = cast<VectorType>(V->getType())->getNumElements();
+ unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements();
// Conservatively assume that all elements are needed.
APInt UsedElts(APInt::getAllOnesValue(VWidth));
@@ -277,7 +289,7 @@ static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
case Instruction::ShuffleVector: {
ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr);
unsigned MaskNumElts =
- cast<VectorType>(UserInstr->getType())->getNumElements();
+ cast<FixedVectorType>(UserInstr->getType())->getNumElements();
UsedElts = APInt(VWidth, 0);
for (unsigned i = 0; i < MaskNumElts; i++) {
@@ -303,7 +315,7 @@ static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
/// no user demands an element of V, then the corresponding bit
/// remains unset in the returned value.
static APInt findDemandedEltsByAllUsers(Value *V) {
- unsigned VWidth = cast<VectorType>(V->getType())->getNumElements();
+ unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements();
APInt UnionUsedElts(VWidth, 0);
for (const Use &U : V->uses()) {
@@ -321,7 +333,7 @@ static APInt findDemandedEltsByAllUsers(Value *V) {
return UnionUsedElts;
}
-Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
+Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
Value *SrcVec = EI.getVectorOperand();
Value *Index = EI.getIndexOperand();
if (Value *V = SimplifyExtractElementInst(SrcVec, Index,
@@ -333,17 +345,17 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
auto *IndexC = dyn_cast<ConstantInt>(Index);
if (IndexC) {
ElementCount EC = EI.getVectorOperandType()->getElementCount();
- unsigned NumElts = EC.Min;
+ unsigned NumElts = EC.getKnownMinValue();
// InstSimplify should handle cases where the index is invalid.
// For fixed-length vector, it's invalid to extract out-of-range element.
- if (!EC.Scalable && IndexC->getValue().uge(NumElts))
+ if (!EC.isScalable() && IndexC->getValue().uge(NumElts))
return nullptr;
// This instruction only demands the single element from the input vector.
// Skip for scalable type, the number of elements is unknown at
// compile-time.
- if (!EC.Scalable && NumElts != 1) {
+ if (!EC.isScalable() && NumElts != 1) {
// If the input vector has a single use, simplify it based on this use
// property.
if (SrcVec->hasOneUse()) {
@@ -460,7 +472,7 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS,
SmallVectorImpl<int> &Mask) {
assert(LHS->getType() == RHS->getType() &&
"Invalid CollectSingleShuffleElements");
- unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
+ unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements();
if (isa<UndefValue>(V)) {
Mask.assign(NumElts, -1);
@@ -502,7 +514,7 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS,
unsigned ExtractedIdx =
cast<ConstantInt>(EI->getOperand(1))->getZExtValue();
unsigned NumLHSElts =
- cast<VectorType>(LHS->getType())->getNumElements();
+ cast<FixedVectorType>(LHS->getType())->getNumElements();
// This must be extracting from either LHS or RHS.
if (EI->getOperand(0) == LHS || EI->getOperand(0) == RHS) {
@@ -531,9 +543,9 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS,
/// shufflevector to replace one or more insert/extract pairs.
static void replaceExtractElements(InsertElementInst *InsElt,
ExtractElementInst *ExtElt,
- InstCombiner &IC) {
- VectorType *InsVecType = InsElt->getType();
- VectorType *ExtVecType = ExtElt->getVectorOperandType();
+ InstCombinerImpl &IC) {
+ auto *InsVecType = cast<FixedVectorType>(InsElt->getType());
+ auto *ExtVecType = cast<FixedVectorType>(ExtElt->getVectorOperandType());
unsigned NumInsElts = InsVecType->getNumElements();
unsigned NumExtElts = ExtVecType->getNumElements();
@@ -614,7 +626,7 @@ using ShuffleOps = std::pair<Value *, Value *>;
static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask,
Value *PermittedRHS,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
assert(V->getType()->isVectorTy() && "Invalid shuffle!");
unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements();
@@ -661,7 +673,7 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask,
}
unsigned NumLHSElts =
- cast<VectorType>(RHS->getType())->getNumElements();
+ cast<FixedVectorType>(RHS->getType())->getNumElements();
Mask[InsertedIdx % NumElts] = NumLHSElts + ExtractedIdx;
return std::make_pair(LR.first, RHS);
}
@@ -670,7 +682,8 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask,
// We've gone as far as we can: anything on the other side of the
// extractelement will already have been converted into a shuffle.
unsigned NumLHSElts =
- cast<VectorType>(EI->getOperand(0)->getType())->getNumElements();
+ cast<FixedVectorType>(EI->getOperand(0)->getType())
+ ->getNumElements();
for (unsigned i = 0; i != NumElts; ++i)
Mask.push_back(i == InsertedIdx ? ExtractedIdx : NumLHSElts + i);
return std::make_pair(EI->getOperand(0), PermittedRHS);
@@ -692,6 +705,285 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask,
return std::make_pair(V, nullptr);
}
+/// Look for a chain of insertvalue's that fully define an aggregate, and trace
+/// back the values inserted to see if they were all extractvalue'd from
+/// the same source aggregate at the exact same element indexes.
+/// If they were, just reuse the source aggregate.
+/// This potentially deals with PHI indirections.
+Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
+ InsertValueInst &OrigIVI) {
+ Type *AggTy = OrigIVI.getType();
+ unsigned NumAggElts;
+ switch (AggTy->getTypeID()) {
+ case Type::StructTyID:
+ NumAggElts = AggTy->getStructNumElements();
+ break;
+ case Type::ArrayTyID:
+ NumAggElts = AggTy->getArrayNumElements();
+ break;
+ default:
+ llvm_unreachable("Unhandled aggregate type?");
+ }
+
+ // Arbitrary aggregate size cut-off. Motivation for limit of 2 is to be able
+  // to handle clang C++ exception struct (which is hardcoded as {i8*, i32}).
+  // FIXME: are there any interesting patterns to be caught with a larger limit?
+ assert(NumAggElts > 0 && "Aggregate should have elements.");
+ if (NumAggElts > 2)
+ return nullptr;
+
+ static constexpr auto NotFound = None;
+ static constexpr auto FoundMismatch = nullptr;
+
+ // Try to find a value of each element of an aggregate.
+ // FIXME: deal with more complex, not one-dimensional, aggregate types
+ SmallVector<Optional<Value *>, 2> AggElts(NumAggElts, NotFound);
+
+ // Do we know values for each element of the aggregate?
+ auto KnowAllElts = [&AggElts]() {
+ return all_of(AggElts,
+ [](Optional<Value *> Elt) { return Elt != NotFound; });
+ };
+
+ int Depth = 0;
+
+ // Arbitrary `insertvalue` visitation depth limit. Let's be okay with
+ // every element being overwritten twice, which should never happen.
+ static const int DepthLimit = 2 * NumAggElts;
+
+ // Recurse up the chain of `insertvalue` aggregate operands until either we've
+  // reconstructed the full initializer or we can't visit any more `insertvalue`'s.
+ for (InsertValueInst *CurrIVI = &OrigIVI;
+ Depth < DepthLimit && CurrIVI && !KnowAllElts();
+ CurrIVI = dyn_cast<InsertValueInst>(CurrIVI->getAggregateOperand()),
+ ++Depth) {
+ Value *InsertedValue = CurrIVI->getInsertedValueOperand();
+ ArrayRef<unsigned int> Indices = CurrIVI->getIndices();
+
+ // Don't bother with more than single-level aggregates.
+ if (Indices.size() != 1)
+ return nullptr; // FIXME: deal with more complex aggregates?
+
+    // Now, we may have already recorded the value for this element of the
+    // aggregate. If we did, that means the CurrIVI will later be
+ // overwritten with the already-recorded value. But if not, let's record it!
+ Optional<Value *> &Elt = AggElts[Indices.front()];
+ Elt = Elt.getValueOr(InsertedValue);
+
+ // FIXME: should we handle chain-terminating undef base operand?
+ }
+
+ // Was that sufficient to deduce the full initializer for the aggregate?
+ if (!KnowAllElts())
+ return nullptr; // Give up then.
+
+ // We now want to find the source[s] of the aggregate elements we've found.
+ // And with "source" we mean the original aggregate[s] from which
+ // the inserted elements were extracted. This may require PHI translation.
+
+ enum class AggregateDescription {
+ /// When analyzing the value that was inserted into an aggregate, we did
+    /// not manage to find a defining `extractvalue` instruction to analyze.
+    NotFound,
+    /// When analyzing the value that was inserted into an aggregate, we did
+    /// manage to find the defining `extractvalue` instruction[s], and everything
+    /// matched perfectly - aggregate type, element insertion/extraction index.
+    Found,
+    /// When analyzing the value that was inserted into an aggregate, we did
+    /// manage to find a defining `extractvalue` instruction, but there was
+    /// a mismatch: either the source type from which the extraction was done
+    /// didn't match the aggregate type into which the insertion was done,
+ /// or the extraction/insertion channels mismatched,
+ /// or different elements had different source aggregates.
+ FoundMismatch
+ };
+ auto Describe = [](Optional<Value *> SourceAggregate) {
+ if (SourceAggregate == NotFound)
+ return AggregateDescription::NotFound;
+ if (*SourceAggregate == FoundMismatch)
+ return AggregateDescription::FoundMismatch;
+ return AggregateDescription::Found;
+ };
+
+ // Given the value \p Elt that was being inserted into element \p EltIdx of an
+ // aggregate AggTy, see if \p Elt was originally defined by an
+ // appropriate extractvalue (same element index, same aggregate type).
+  // If found, return the source aggregate from which the extraction was made.
+  // If \p PredBB is provided, PHI translation of \p Elt is performed first.
+ auto FindSourceAggregate =
+ [&](Value *Elt, unsigned EltIdx, Optional<BasicBlock *> UseBB,
+ Optional<BasicBlock *> PredBB) -> Optional<Value *> {
+ // For now(?), only deal with, at most, a single level of PHI indirection.
+ if (UseBB && PredBB)
+ Elt = Elt->DoPHITranslation(*UseBB, *PredBB);
+ // FIXME: deal with multiple levels of PHI indirection?
+
+ // Did we find an extraction?
+ auto *EVI = dyn_cast<ExtractValueInst>(Elt);
+ if (!EVI)
+ return NotFound;
+
+ Value *SourceAggregate = EVI->getAggregateOperand();
+
+ // Is the extraction from the same type into which the insertion was?
+ if (SourceAggregate->getType() != AggTy)
+ return FoundMismatch;
+ // And the element index doesn't change between extraction and insertion?
+ if (EVI->getNumIndices() != 1 || EltIdx != EVI->getIndices().front())
+ return FoundMismatch;
+
+ return SourceAggregate; // AggregateDescription::Found
+ };
+
+  // Given the elements AggElts that construct the aggregate OrigIVI, see if we
+  // can find an appropriate source aggregate for each of the elements, and
+  // check that it's the same aggregate for every element. If so, return it.
+ auto FindCommonSourceAggregate =
+ [&](Optional<BasicBlock *> UseBB,
+ Optional<BasicBlock *> PredBB) -> Optional<Value *> {
+ Optional<Value *> SourceAggregate;
+
+ for (auto I : enumerate(AggElts)) {
+ assert(Describe(SourceAggregate) != AggregateDescription::FoundMismatch &&
+ "We don't store nullptr in SourceAggregate!");
+ assert((Describe(SourceAggregate) == AggregateDescription::Found) ==
+ (I.index() != 0) &&
+ "SourceAggregate should be valid after the the first element,");
+
+ // For this element, is there a plausible source aggregate?
+ // FIXME: we could special-case undef element, IFF we know that in the
+ // source aggregate said element isn't poison.
+ Optional<Value *> SourceAggregateForElement =
+ FindSourceAggregate(*I.value(), I.index(), UseBB, PredBB);
+
+ // Okay, what have we found? Does that correlate with previous findings?
+
+      // Regardless of whether or not we have previously found a source
+      // aggregate for previous elements (if any), if we didn't find one for
+      // this element, pass through whatever we have just found.
+ if (Describe(SourceAggregateForElement) != AggregateDescription::Found)
+ return SourceAggregateForElement;
+
+ // Okay, we have found source aggregate for this element.
+ // Let's see what we already know from previous elements, if any.
+ switch (Describe(SourceAggregate)) {
+ case AggregateDescription::NotFound:
+ // This is apparently the first element that we have examined.
+ SourceAggregate = SourceAggregateForElement; // Record the aggregate!
+ continue; // Great, now look at next element.
+ case AggregateDescription::Found:
+ // We have previously already successfully examined other elements.
+ // Is this the same source aggregate we've found for other elements?
+ if (*SourceAggregateForElement != *SourceAggregate)
+ return FoundMismatch;
+ continue; // Still the same aggregate, look at next element.
+ case AggregateDescription::FoundMismatch:
+ llvm_unreachable("Can't happen. We would have early-exited then.");
+ };
+ }
+
+ assert(Describe(SourceAggregate) == AggregateDescription::Found &&
+ "Must be a valid Value");
+ return *SourceAggregate;
+ };
+
+ Optional<Value *> SourceAggregate;
+
+ // Can we find the source aggregate without looking at predecessors?
+ SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/None, /*PredBB=*/None);
+ if (Describe(SourceAggregate) != AggregateDescription::NotFound) {
+ if (Describe(SourceAggregate) == AggregateDescription::FoundMismatch)
+ return nullptr; // Conflicting source aggregates!
+ ++NumAggregateReconstructionsSimplified;
+ return replaceInstUsesWith(OrigIVI, *SourceAggregate);
+ }
+
+ // Okay, apparently we need to look at predecessors.
+
+  // We should be smart about picking the "use" basic block, which will be the
+  // merge point for the aggregate, where we'll insert the final PHI that will
+  // be used instead of OrigIVI. The basic block of OrigIVI is *not* the right
+  // choice. We should look at the blocks in which each of the AggElts is
+  // defined; they should all be defined in the same basic block.
+ BasicBlock *UseBB = nullptr;
+
+ for (const Optional<Value *> &Elt : AggElts) {
+ // If this element's value was not defined by an instruction, ignore it.
+ auto *I = dyn_cast<Instruction>(*Elt);
+ if (!I)
+ continue;
+ // Otherwise, in which basic block is this instruction located?
+ BasicBlock *BB = I->getParent();
+ // If it's the first instruction we've encountered, record the basic block.
+ if (!UseBB) {
+ UseBB = BB;
+ continue;
+ }
+ // Otherwise, this must be the same basic block we've seen previously.
+ if (UseBB != BB)
+ return nullptr;
+ }
+
+ // If *all* of the elements are basic-block-independent, meaning they are
+  // either function arguments or constant expressions, then if we didn't
+ // handle them without predecessor-aware handling, we won't handle them now.
+ if (!UseBB)
+ return nullptr;
+
+  // If we didn't manage to find the source aggregate without looking at
+ // predecessors, and there are no predecessors to look at, then we're done.
+ if (pred_empty(UseBB))
+ return nullptr;
+
+ // Arbitrary predecessor count limit.
+ static const int PredCountLimit = 64;
+
+ // Cache the (non-uniqified!) list of predecessors in a vector,
+ // checking the limit at the same time for efficiency.
+ SmallVector<BasicBlock *, 4> Preds; // May have duplicates!
+ for (BasicBlock *Pred : predecessors(UseBB)) {
+ // Don't bother if there are too many predecessors.
+ if (Preds.size() >= PredCountLimit) // FIXME: only count duplicates once?
+ return nullptr;
+ Preds.emplace_back(Pred);
+ }
+
+  // For each predecessor, what is the source aggregate
+  // from which all the elements were originally extracted?
+ // Note that we want for the map to have stable iteration order!
+ SmallDenseMap<BasicBlock *, Value *, 4> SourceAggregates;
+ for (BasicBlock *Pred : Preds) {
+ std::pair<decltype(SourceAggregates)::iterator, bool> IV =
+ SourceAggregates.insert({Pred, nullptr});
+ // Did we already evaluate this predecessor?
+ if (!IV.second)
+ continue;
+
+ // Let's hope that when coming from predecessor Pred, all elements of the
+ // aggregate produced by OrigIVI must have been originally extracted from
+ // the same aggregate. Is that so? Can we find said original aggregate?
+ SourceAggregate = FindCommonSourceAggregate(UseBB, Pred);
+ if (Describe(SourceAggregate) != AggregateDescription::Found)
+ return nullptr; // Give up.
+ IV.first->second = *SourceAggregate;
+ }
+
+ // All good! Now we just need to thread the source aggregates here.
+ // Note that we have to insert the new PHI here, ourselves, because we can't
+ // rely on InstCombinerImpl::run() inserting it into the right basic block.
+ // Note that the same block can be a predecessor more than once,
+ // and we need to preserve that invariant for the PHI node.
+ BuilderTy::InsertPointGuard Guard(Builder);
+ Builder.SetInsertPoint(UseBB->getFirstNonPHI());
+ auto *PHI =
+ Builder.CreatePHI(AggTy, Preds.size(), OrigIVI.getName() + ".merged");
+ for (BasicBlock *Pred : Preds)
+ PHI->addIncoming(SourceAggregates[Pred], Pred);
+
+ ++NumAggregateReconstructionsSimplified;
+ return replaceInstUsesWith(OrigIVI, PHI);
+}
+
/// Try to find redundant insertvalue instructions, like the following ones:
/// %0 = insertvalue { i8, i32 } undef, i8 %x, 0
/// %1 = insertvalue { i8, i32 } %0, i8 %y, 0
@@ -699,7 +991,7 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask,
/// first one, making the first one redundant.
/// It should be transformed to:
/// %0 = insertvalue { i8, i32 } undef, i8 %y, 0
-Instruction *InstCombiner::visitInsertValueInst(InsertValueInst &I) {
+Instruction *InstCombinerImpl::visitInsertValueInst(InsertValueInst &I) {
bool IsRedundant = false;
ArrayRef<unsigned int> FirstIndices = I.getIndices();
@@ -724,6 +1016,10 @@ Instruction *InstCombiner::visitInsertValueInst(InsertValueInst &I) {
if (IsRedundant)
return replaceInstUsesWith(I, I.getOperand(0));
+
+ if (Instruction *NewI = foldAggregateConstructionIntoAggregateReuse(I))
+ return NewI;
+
return nullptr;
}
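The aggregate-reconstruction fold wired in above collapses an insertvalue chain that merely rebuilds an existing aggregate. A minimal LLVM IR sketch of the non-PHI case (function and value names are illustrative only):

    ; %i1 reassembles %agg element by element from extractvalue's of %agg itself,
    ; so the whole chain is replaced by %agg.
    define { i8, i32 } @aggregate_reuse_demo({ i8, i32 } %agg) {
      %a = extractvalue { i8, i32 } %agg, 0
      %b = extractvalue { i8, i32 } %agg, 1
      %i0 = insertvalue { i8, i32 } undef, i8 %a, 0
      %i1 = insertvalue { i8, i32 } %i0, i32 %b, 1
      ret { i8, i32 } %i1
    }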
@@ -854,7 +1150,8 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) {
// For example:
// inselt (shuf (inselt undef, X, 0), undef, <0,undef,0,undef>), X, 1
// --> shuf (inselt undef, X, 0), undef, <0,0,0,undef>
- unsigned NumMaskElts = Shuf->getType()->getNumElements();
+ unsigned NumMaskElts =
+ cast<FixedVectorType>(Shuf->getType())->getNumElements();
SmallVector<int, 16> NewMask(NumMaskElts);
for (unsigned i = 0; i != NumMaskElts; ++i)
NewMask[i] = i == IdxC ? 0 : Shuf->getMaskValue(i);
@@ -892,7 +1189,8 @@ static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) {
// that same index value.
// For example:
// inselt (shuf X, IdMask), (extelt X, IdxC), IdxC --> shuf X, IdMask'
- unsigned NumMaskElts = Shuf->getType()->getNumElements();
+ unsigned NumMaskElts =
+ cast<FixedVectorType>(Shuf->getType())->getNumElements();
SmallVector<int, 16> NewMask(NumMaskElts);
ArrayRef<int> OldMask = Shuf->getShuffleMask();
for (unsigned i = 0; i != NumMaskElts; ++i) {
@@ -1041,7 +1339,7 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) {
return nullptr;
}
-Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) {
+Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
Value *VecOp = IE.getOperand(0);
Value *ScalarOp = IE.getOperand(1);
Value *IdxOp = IE.getOperand(2);
@@ -1189,7 +1487,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
// Propagating an undefined shuffle mask element to integer div/rem is not
// allowed because those opcodes can create immediate undefined behavior
// from an undefined element in an operand.
- if (llvm::any_of(Mask, [](int M){ return M == -1; }))
+ if (llvm::is_contained(Mask, -1))
return false;
LLVM_FALLTHROUGH;
case Instruction::Add:
@@ -1222,7 +1520,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
// longer vector ops, but that may result in more expensive codegen.
Type *ITy = I->getType();
if (ITy->isVectorTy() &&
- Mask.size() > cast<VectorType>(ITy)->getNumElements())
+ Mask.size() > cast<FixedVectorType>(ITy)->getNumElements())
return false;
for (Value *Operand : I->operands()) {
if (!canEvaluateShuffled(Operand, Mask, Depth - 1))
@@ -1380,7 +1678,8 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) {
case Instruction::GetElementPtr: {
SmallVector<Value*, 8> NewOps;
bool NeedsRebuild =
- (Mask.size() != cast<VectorType>(I->getType())->getNumElements());
+ (Mask.size() !=
+ cast<FixedVectorType>(I->getType())->getNumElements());
for (int i = 0, e = I->getNumOperands(); i != e; ++i) {
Value *V;
// Recursively call evaluateInDifferentElementOrder on vector arguments
@@ -1435,7 +1734,7 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) {
static bool isShuffleExtractingFromLHS(ShuffleVectorInst &SVI,
ArrayRef<int> Mask) {
unsigned LHSElems =
- cast<VectorType>(SVI.getOperand(0)->getType())->getNumElements();
+ cast<FixedVectorType>(SVI.getOperand(0)->getType())->getNumElements();
unsigned MaskElems = Mask.size();
unsigned BegIdx = Mask.front();
unsigned EndIdx = Mask.back();
@@ -1525,7 +1824,7 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) {
is_contained(Mask, UndefMaskElem) &&
(Instruction::isIntDivRem(BOpcode) || Instruction::isShift(BOpcode));
if (MightCreatePoisonOrUB)
- NewC = getSafeVectorConstantForBinop(BOpcode, NewC, true);
+ NewC = InstCombiner::getSafeVectorConstantForBinop(BOpcode, NewC, true);
// shuf (bop X, C), X, M --> bop X, C'
// shuf X, (bop X, C), M --> bop X, C'
@@ -1567,7 +1866,8 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf,
// For example:
// shuf (inselt undef, X, 2), undef, <2,2,undef>
// --> shuf (inselt undef, X, 0), undef, <0,0,undef>
- unsigned NumMaskElts = Shuf.getType()->getNumElements();
+ unsigned NumMaskElts =
+ cast<FixedVectorType>(Shuf.getType())->getNumElements();
SmallVector<int, 16> NewMask(NumMaskElts, 0);
for (unsigned i = 0; i != NumMaskElts; ++i)
if (Mask[i] == UndefMaskElem)
@@ -1585,7 +1885,7 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf,
// Canonicalize to choose from operand 0 first unless operand 1 is undefined.
// Commuting undef to operand 0 conflicts with another canonicalization.
- unsigned NumElts = Shuf.getType()->getNumElements();
+ unsigned NumElts = cast<FixedVectorType>(Shuf.getType())->getNumElements();
if (!isa<UndefValue>(Shuf.getOperand(1)) &&
Shuf.getMaskValue(0) >= (int)NumElts) {
// TODO: Can we assert that both operands of a shuffle-select are not undef
@@ -1652,7 +1952,8 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf,
is_contained(Mask, UndefMaskElem) &&
(Instruction::isIntDivRem(BOpc) || Instruction::isShift(BOpc));
if (MightCreatePoisonOrUB)
- NewC = getSafeVectorConstantForBinop(BOpc, NewC, ConstantsAreOp1);
+ NewC = InstCombiner::getSafeVectorConstantForBinop(BOpc, NewC,
+ ConstantsAreOp1);
Value *V;
if (X == Y) {
@@ -1719,8 +2020,8 @@ static Instruction *foldTruncShuffle(ShuffleVectorInst &Shuf,
// and the source element type must be larger than the shuffle element type.
Type *SrcType = X->getType();
if (!SrcType->isVectorTy() || !SrcType->isIntOrIntVectorTy() ||
- cast<VectorType>(SrcType)->getNumElements() !=
- cast<VectorType>(DestType)->getNumElements() ||
+ cast<FixedVectorType>(SrcType)->getNumElements() !=
+ cast<FixedVectorType>(DestType)->getNumElements() ||
SrcType->getScalarSizeInBits() % DestType->getScalarSizeInBits() != 0)
return nullptr;
@@ -1736,8 +2037,7 @@ static Instruction *foldTruncShuffle(ShuffleVectorInst &Shuf,
if (Mask[i] == UndefMaskElem)
continue;
uint64_t LSBIndex = IsBigEndian ? (i + 1) * TruncRatio - 1 : i * TruncRatio;
- assert(LSBIndex <= std::numeric_limits<int32_t>::max() &&
- "Overflowed 32-bits");
+ assert(LSBIndex <= INT32_MAX && "Overflowed 32-bits");
if (Mask[i] != (int)LSBIndex)
return nullptr;
}
@@ -1764,19 +2064,19 @@ static Instruction *narrowVectorSelect(ShuffleVectorInst &Shuf,
// We need a narrow condition value. It must be extended with undef elements
// and have the same number of elements as this shuffle.
- unsigned NarrowNumElts = Shuf.getType()->getNumElements();
+ unsigned NarrowNumElts =
+ cast<FixedVectorType>(Shuf.getType())->getNumElements();
Value *NarrowCond;
if (!match(Cond, m_OneUse(m_Shuffle(m_Value(NarrowCond), m_Undef()))) ||
- cast<VectorType>(NarrowCond->getType())->getNumElements() !=
+ cast<FixedVectorType>(NarrowCond->getType())->getNumElements() !=
NarrowNumElts ||
!cast<ShuffleVectorInst>(Cond)->isIdentityWithPadding())
return nullptr;
// shuf (sel (shuf NarrowCond, undef, WideMask), X, Y), undef, NarrowMask) -->
// sel NarrowCond, (shuf X, undef, NarrowMask), (shuf Y, undef, NarrowMask)
- Value *Undef = UndefValue::get(X->getType());
- Value *NarrowX = Builder.CreateShuffleVector(X, Undef, Shuf.getShuffleMask());
- Value *NarrowY = Builder.CreateShuffleVector(Y, Undef, Shuf.getShuffleMask());
+ Value *NarrowX = Builder.CreateShuffleVector(X, Shuf.getShuffleMask());
+ Value *NarrowY = Builder.CreateShuffleVector(Y, Shuf.getShuffleMask());
return SelectInst::Create(NarrowCond, NarrowX, NarrowY);
}
@@ -1807,7 +2107,7 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) {
// new shuffle mask. Otherwise, copy the original mask element. Example:
// shuf (shuf X, Y, <C0, C1, C2, undef, C4>), undef, <0, undef, 2, 3> -->
// shuf X, Y, <C0, undef, C2, undef>
- unsigned NumElts = Shuf.getType()->getNumElements();
+ unsigned NumElts = cast<FixedVectorType>(Shuf.getType())->getNumElements();
SmallVector<int, 16> NewMask(NumElts);
assert(NumElts < Mask.size() &&
"Identity with extract must have less elements than its inputs");
@@ -1823,7 +2123,7 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) {
/// Try to replace a shuffle with an insertelement or try to replace a shuffle
/// operand with the operand of an insertelement.
static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
Value *V0 = Shuf.getOperand(0), *V1 = Shuf.getOperand(1);
SmallVector<int, 16> Mask;
Shuf.getShuffleMask(Mask);
@@ -1832,7 +2132,7 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
// TODO: This restriction could be removed if the insert has only one use
// (because the transform would require a new length-changing shuffle).
int NumElts = Mask.size();
- if (NumElts != (int)(cast<VectorType>(V0->getType())->getNumElements()))
+ if (NumElts != (int)(cast<FixedVectorType>(V0->getType())->getNumElements()))
return nullptr;
// This is a specialization of a fold in SimplifyDemandedVectorElts. We may
@@ -1844,7 +2144,7 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
uint64_t IdxC;
if (match(V0, m_InsertElt(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) {
// shuf (inselt X, ?, IdxC), ?, Mask --> shuf X, ?, Mask
- if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; }))
+ if (!is_contained(Mask, (int)IdxC))
return IC.replaceOperand(Shuf, 0, X);
}
if (match(V1, m_InsertElt(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) {
@@ -1852,7 +2152,7 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf,
// accesses to the 2nd vector input of the shuffle.
IdxC += NumElts;
// shuf ?, (inselt X, ?, IdxC), Mask --> shuf ?, X, Mask
- if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; }))
+ if (!is_contained(Mask, (int)IdxC))
return IC.replaceOperand(Shuf, 1, X);
}
@@ -1927,9 +2227,10 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) {
Value *X = Shuffle0->getOperand(0);
Value *Y = Shuffle1->getOperand(0);
if (X->getType() != Y->getType() ||
- !isPowerOf2_32(Shuf.getType()->getNumElements()) ||
- !isPowerOf2_32(Shuffle0->getType()->getNumElements()) ||
- !isPowerOf2_32(cast<VectorType>(X->getType())->getNumElements()) ||
+ !isPowerOf2_32(cast<FixedVectorType>(Shuf.getType())->getNumElements()) ||
+ !isPowerOf2_32(
+ cast<FixedVectorType>(Shuffle0->getType())->getNumElements()) ||
+ !isPowerOf2_32(cast<FixedVectorType>(X->getType())->getNumElements()) ||
isa<UndefValue>(X) || isa<UndefValue>(Y))
return nullptr;
assert(isa<UndefValue>(Shuffle0->getOperand(1)) &&
@@ -1940,8 +2241,8 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) {
// operands directly by adjusting the shuffle mask to account for the narrower
// types:
// shuf (widen X), (widen Y), Mask --> shuf X, Y, Mask'
- int NarrowElts = cast<VectorType>(X->getType())->getNumElements();
- int WideElts = Shuffle0->getType()->getNumElements();
+ int NarrowElts = cast<FixedVectorType>(X->getType())->getNumElements();
+ int WideElts = cast<FixedVectorType>(Shuffle0->getType())->getNumElements();
assert(WideElts > NarrowElts && "Unexpected types for identity with padding");
ArrayRef<int> Mask = Shuf.getShuffleMask();
@@ -1974,7 +2275,7 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) {
return new ShuffleVectorInst(X, Y, NewMask);
}
-Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
+Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
Value *LHS = SVI.getOperand(0);
Value *RHS = SVI.getOperand(1);
SimplifyQuery ShufQuery = SQ.getWithInstruction(&SVI);
@@ -1982,9 +2283,13 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
SVI.getType(), ShufQuery))
return replaceInstUsesWith(SVI, V);
+ // Bail out for scalable vectors
+ if (isa<ScalableVectorType>(LHS->getType()))
+ return nullptr;
+
// shuffle x, x, mask --> shuffle x, undef, mask'
- unsigned VWidth = SVI.getType()->getNumElements();
- unsigned LHSWidth = cast<VectorType>(LHS->getType())->getNumElements();
+ unsigned VWidth = cast<FixedVectorType>(SVI.getType())->getNumElements();
+ unsigned LHSWidth = cast<FixedVectorType>(LHS->getType())->getNumElements();
ArrayRef<int> Mask = SVI.getShuffleMask();
Type *Int32Ty = Type::getInt32Ty(SVI.getContext());
@@ -1998,7 +2303,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
if (match(LHS, m_BitCast(m_Value(X))) && match(RHS, m_Undef()) &&
X->getType()->isVectorTy() && VWidth == LHSWidth) {
// Try to create a scaled mask constant.
- auto *XType = cast<VectorType>(X->getType());
+ auto *XType = cast<FixedVectorType>(X->getType());
unsigned XNumElts = XType->getNumElements();
SmallVector<int, 16> ScaledMask;
if (XNumElts >= VWidth) {
@@ -2106,7 +2411,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
if (isShuffleExtractingFromLHS(SVI, Mask)) {
Value *V = LHS;
unsigned MaskElems = Mask.size();
- VectorType *SrcTy = cast<VectorType>(V->getType());
+ auto *SrcTy = cast<FixedVectorType>(V->getType());
unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedSize();
unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType());
assert(SrcElemBitWidth && "vector elements must have a bitwidth");
@@ -2138,8 +2443,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
SmallVector<int, 16> ShuffleMask(SrcNumElems, -1);
for (unsigned I = 0, E = MaskElems, Idx = BegIdx; I != E; ++Idx, ++I)
ShuffleMask[I] = Idx;
- V = Builder.CreateShuffleVector(V, UndefValue::get(V->getType()),
- ShuffleMask,
+ V = Builder.CreateShuffleVector(V, ShuffleMask,
SVI.getName() + ".extract");
BegIdx = 0;
}
@@ -2224,11 +2528,11 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
if (LHSShuffle) {
LHSOp0 = LHSShuffle->getOperand(0);
LHSOp1 = LHSShuffle->getOperand(1);
- LHSOp0Width = cast<VectorType>(LHSOp0->getType())->getNumElements();
+ LHSOp0Width = cast<FixedVectorType>(LHSOp0->getType())->getNumElements();
}
if (RHSShuffle) {
RHSOp0 = RHSShuffle->getOperand(0);
- RHSOp0Width = cast<VectorType>(RHSOp0->getType())->getNumElements();
+ RHSOp0Width = cast<FixedVectorType>(RHSOp0->getType())->getNumElements();
}
Value* newLHS = LHS;
Value* newRHS = RHS;
@@ -2331,17 +2635,9 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
// If the result mask is equal to one of the original shuffle masks,
// or is a splat, do the replacement.
if (isSplat || newMask == LHSMask || newMask == RHSMask || newMask == Mask) {
- SmallVector<Constant*, 16> Elts;
- for (unsigned i = 0, e = newMask.size(); i != e; ++i) {
- if (newMask[i] < 0) {
- Elts.push_back(UndefValue::get(Int32Ty));
- } else {
- Elts.push_back(ConstantInt::get(Int32Ty, newMask[i]));
- }
- }
if (!newRHS)
newRHS = UndefValue::get(newLHS->getType());
- return new ShuffleVectorInst(newLHS, newRHS, ConstantVector::get(Elts));
+ return new ShuffleVectorInst(newLHS, newRHS, newMask);
}
return MadeChange ? &SVI : nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index b3254c10a0b2..518e909e8ab4 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -59,6 +59,7 @@
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/TargetFolder.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/BasicBlock.h"
@@ -113,6 +114,9 @@ using namespace llvm::PatternMatch;
#define DEBUG_TYPE "instcombine"
+STATISTIC(NumWorklistIterations,
+ "Number of instruction combining iterations performed");
+
STATISTIC(NumCombined , "Number of insts combined");
STATISTIC(NumConstProp, "Number of constant folds");
STATISTIC(NumDeadInst , "Number of dead inst eliminated");
@@ -123,8 +127,13 @@ STATISTIC(NumReassoc , "Number of reassociations");
DEBUG_COUNTER(VisitCounter, "instcombine-visit",
"Controls which instructions are visited");
+// FIXME: these limits eventually should be as low as 2.
static constexpr unsigned InstCombineDefaultMaxIterations = 1000;
+#ifndef NDEBUG
+static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 100;
+#else
static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 1000;
+#endif
static cl::opt<bool>
EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"),
@@ -155,7 +164,41 @@ MaxArraySize("instcombine-maxarray-size", cl::init(1024),
static cl::opt<unsigned> ShouldLowerDbgDeclare("instcombine-lower-dbg-declare",
cl::Hidden, cl::init(true));
-Value *InstCombiner::EmitGEPOffset(User *GEP) {
+Optional<Instruction *>
+InstCombiner::targetInstCombineIntrinsic(IntrinsicInst &II) {
+ // Handle target specific intrinsics
+ if (II.getCalledFunction()->isTargetIntrinsic()) {
+ return TTI.instCombineIntrinsic(*this, II);
+ }
+ return None;
+}
+
+Optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic(
+ IntrinsicInst &II, APInt DemandedMask, KnownBits &Known,
+ bool &KnownBitsComputed) {
+ // Handle target specific intrinsics
+ if (II.getCalledFunction()->isTargetIntrinsic()) {
+ return TTI.simplifyDemandedUseBitsIntrinsic(*this, II, DemandedMask, Known,
+ KnownBitsComputed);
+ }
+ return None;
+}
+
+Optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic(
+ IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts, APInt &UndefElts2,
+ APInt &UndefElts3,
+ std::function<void(Instruction *, unsigned, APInt, APInt &)>
+ SimplifyAndSetOp) {
+ // Handle target specific intrinsics
+ if (II.getCalledFunction()->isTargetIntrinsic()) {
+ return TTI.simplifyDemandedVectorEltsIntrinsic(
+ *this, II, DemandedElts, UndefElts, UndefElts2, UndefElts3,
+ SimplifyAndSetOp);
+ }
+ return None;
+}
+
+Value *InstCombinerImpl::EmitGEPOffset(User *GEP) {
return llvm::EmitGEPOffset(&Builder, DL, GEP);
}
@@ -168,8 +211,8 @@ Value *InstCombiner::EmitGEPOffset(User *GEP) {
/// legal to convert to, in order to open up more combining opportunities.
/// NOTE: this treats i8, i16 and i32 specially, due to them being so common
/// from frontend languages.
-bool InstCombiner::shouldChangeType(unsigned FromWidth,
- unsigned ToWidth) const {
+bool InstCombinerImpl::shouldChangeType(unsigned FromWidth,
+ unsigned ToWidth) const {
bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth);
bool ToLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth);
@@ -196,7 +239,7 @@ bool InstCombiner::shouldChangeType(unsigned FromWidth,
/// to a larger illegal type. i1 is always treated as a legal type because it is
/// a fundamental type in IR, and there are many specialized optimizations for
/// i1 types.
-bool InstCombiner::shouldChangeType(Type *From, Type *To) const {
+bool InstCombinerImpl::shouldChangeType(Type *From, Type *To) const {
// TODO: This could be extended to allow vectors. Datalayout changes might be
// needed to properly support that.
if (!From->isIntegerTy() || !To->isIntegerTy())
@@ -264,7 +307,8 @@ static void ClearSubclassDataAfterReassociation(BinaryOperator &I) {
/// cast to eliminate one of the associative operations:
/// (op (cast (op X, C2)), C1) --> (cast (op X, op (C1, C2)))
/// (op (cast (op X, C2)), C1) --> (op (cast X), op (C1, C2))
-static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1, InstCombiner &IC) {
+static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1,
+ InstCombinerImpl &IC) {
auto *Cast = dyn_cast<CastInst>(BinOp1->getOperand(0));
if (!Cast || !Cast->hasOneUse())
return false;
@@ -322,7 +366,7 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1, InstCombiner &IC) {
/// 5. Transform: "A op (B op C)" ==> "B op (C op A)" if "C op A" simplifies.
/// 6. Transform: "(A op C1) op (B op C2)" ==> "(A op B) op (C1 op C2)"
/// if C1 and C2 are constants.
-bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
+bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
Instruction::BinaryOps Opcode = I.getOpcode();
bool Changed = false;
@@ -550,9 +594,10 @@ getBinOpsForFactorization(Instruction::BinaryOps TopOpcode, BinaryOperator *Op,
/// This tries to simplify binary operations by factorizing out common terms
/// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
-Value *InstCombiner::tryFactorization(BinaryOperator &I,
- Instruction::BinaryOps InnerOpcode,
- Value *A, Value *B, Value *C, Value *D) {
+Value *InstCombinerImpl::tryFactorization(BinaryOperator &I,
+ Instruction::BinaryOps InnerOpcode,
+ Value *A, Value *B, Value *C,
+ Value *D) {
assert(A && B && C && D && "All values must be provided");
Value *V = nullptr;
@@ -655,7 +700,7 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I,
/// (eg "(A*B)+(A*C)" -> "A*(B+C)") or expanding out if this results in
/// simplifications (eg: "A & (B | C) -> (A&B) | (A&C)" if this is a win).
/// Returns the simplified value, or null if it didn't simplify.
-Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
+Value *InstCombinerImpl::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS);
@@ -698,8 +743,10 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
Value *A = Op0->getOperand(0), *B = Op0->getOperand(1), *C = RHS;
Instruction::BinaryOps InnerOpcode = Op0->getOpcode(); // op'
- Value *L = SimplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I));
- Value *R = SimplifyBinOp(TopLevelOpcode, B, C, SQ.getWithInstruction(&I));
+ // Disable the use of undef because it's not safe to distribute undef.
+ auto SQDistributive = SQ.getWithInstruction(&I).getWithoutUndef();
+ Value *L = SimplifyBinOp(TopLevelOpcode, A, C, SQDistributive);
+ Value *R = SimplifyBinOp(TopLevelOpcode, B, C, SQDistributive);
// Do "A op C" and "B op C" both simplify?
if (L && R) {
@@ -735,8 +782,10 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
Value *A = LHS, *B = Op1->getOperand(0), *C = Op1->getOperand(1);
Instruction::BinaryOps InnerOpcode = Op1->getOpcode(); // op'
- Value *L = SimplifyBinOp(TopLevelOpcode, A, B, SQ.getWithInstruction(&I));
- Value *R = SimplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I));
+ // Disable the use of undef because it's not safe to distribute undef.
+ auto SQDistributive = SQ.getWithInstruction(&I).getWithoutUndef();
+ Value *L = SimplifyBinOp(TopLevelOpcode, A, B, SQDistributive);
+ Value *R = SimplifyBinOp(TopLevelOpcode, A, C, SQDistributive);
// Do "A op B" and "A op C" both simplify?
if (L && R) {
@@ -769,8 +818,9 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
}
-Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
- Value *LHS, Value *RHS) {
+Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
+ Value *LHS,
+ Value *RHS) {
Value *A, *B, *C, *D, *E, *F;
bool LHSIsSelect = match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C)));
bool RHSIsSelect = match(RHS, m_Select(m_Value(D), m_Value(E), m_Value(F)));
@@ -820,9 +870,33 @@ Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
return SI;
}
+/// Freely adapt every user of V as-if V was changed to !V.
+/// WARNING: only if canFreelyInvertAllUsersOf() said this can be done.
+void InstCombinerImpl::freelyInvertAllUsersOf(Value *I) {
+ for (User *U : I->users()) {
+ switch (cast<Instruction>(U)->getOpcode()) {
+ case Instruction::Select: {
+ auto *SI = cast<SelectInst>(U);
+ SI->swapValues();
+ SI->swapProfMetadata();
+ break;
+ }
+ case Instruction::Br:
+ cast<BranchInst>(U)->swapSuccessors(); // swaps prof metadata too
+ break;
+ case Instruction::Xor:
+ replaceInstUsesWith(cast<Instruction>(*U), I);
+ break;
+ default:
+ llvm_unreachable("Got unexpected user - out of sync with "
+ "canFreelyInvertAllUsersOf() ?");
+ }
+ }
+}
+
/// Given a 'sub' instruction, return the RHS of the instruction if the LHS is a
/// constant zero (which is the 'negate' form).
-Value *InstCombiner::dyn_castNegVal(Value *V) const {
+Value *InstCombinerImpl::dyn_castNegVal(Value *V) const {
Value *NegV;
if (match(V, m_Neg(m_Value(NegV))))
return NegV;
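To illustrate the freelyInvertAllUsersOf helper added in the hunk above, here is a sketch with a hypothetical boolean %v whose only users are a select condition and a conditional branch. When %v is later rewritten to compute its logical negation, the users compensate by swapping their operands, so the overall behavior is unchanged; an xor-with-true user would simply be replaced by %v itself (the third case in the switch).

    ; users of %v before the inversion
    %s = select i1 %v, i32 %a, i32 %b
    br i1 %v, label %bb.then, label %bb.else
    ; users after %v has been changed to its negation
    %s = select i1 %v, i32 %b, i32 %a
    br i1 %v, label %bb.else, label %bb.then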
@@ -883,7 +957,8 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO,
return RI;
}
-Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) {
+Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op,
+ SelectInst *SI) {
// Don't modify shared select instructions.
if (!SI->hasOneUse())
return nullptr;
@@ -908,7 +983,7 @@ Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) {
return nullptr;
// If vectors, verify that they have the same number of elements.
- if (SrcTy && SrcTy->getNumElements() != DestTy->getNumElements())
+ if (SrcTy && SrcTy->getElementCount() != DestTy->getElementCount())
return nullptr;
}
@@ -978,7 +1053,7 @@ static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV,
return RI;
}
-Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) {
+Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
unsigned NumPHIValues = PN->getNumIncomingValues();
if (NumPHIValues == 0)
return nullptr;
@@ -1004,7 +1079,9 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) {
BasicBlock *NonConstBB = nullptr;
for (unsigned i = 0; i != NumPHIValues; ++i) {
Value *InVal = PN->getIncomingValue(i);
- if (isa<Constant>(InVal) && !isa<ConstantExpr>(InVal))
+ // If I is a freeze instruction, count undef as a non-constant.
+ if (match(InVal, m_ImmConstant()) &&
+ (!isa<FreezeInst>(I) || isGuaranteedNotToBeUndefOrPoison(InVal)))
continue;
if (isa<PHINode>(InVal)) return nullptr; // Itself a phi.
@@ -1029,9 +1106,11 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) {
// operation in that block. However, if this is a critical edge, we would be
// inserting the computation on some other paths (e.g. inside a loop). Only
// do this if the pred block is unconditionally branching into the phi block.
+ // Also, make sure that the pred block is not dead code.
if (NonConstBB != nullptr) {
BranchInst *BI = dyn_cast<BranchInst>(NonConstBB->getTerminator());
- if (!BI || !BI->isUnconditional()) return nullptr;
+ if (!BI || !BI->isUnconditional() || !DT.isReachableFromEntry(NonConstBB))
+ return nullptr;
}
// Okay, we can do the transformation: create the new PHI node.
@@ -1063,7 +1142,7 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) {
// FalseVInPred versus TrueVInPred. When we have individual nonzero
// elements in the vector, we will incorrectly fold InC to
// `TrueVInPred`.
- if (InC && !isa<ConstantExpr>(InC) && isa<ConstantInt>(InC))
+ if (InC && isa<ConstantInt>(InC))
InV = InC->isNullValue() ? FalseVInPred : TrueVInPred;
else {
// Generate the select in the same block as PN's current incoming block.
@@ -1097,6 +1176,15 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) {
Builder);
NewPN->addIncoming(InV, PN->getIncomingBlock(i));
}
+ } else if (isa<FreezeInst>(&I)) {
+ for (unsigned i = 0; i != NumPHIValues; ++i) {
+ Value *InV;
+ if (NonConstBB == PN->getIncomingBlock(i))
+ InV = Builder.CreateFreeze(PN->getIncomingValue(i), "phi.fr");
+ else
+ InV = PN->getIncomingValue(i);
+ NewPN->addIncoming(InV, PN->getIncomingBlock(i));
+ }
} else {
CastInst *CI = cast<CastInst>(&I);
Type *RetTy = CI->getType();
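A sketch of the freeze case added to foldOpIntoPhi above, with hypothetical blocks %bb1/%bb2 and value %x: only the single non-constant incoming value is frozen, in its predecessor block, and the freeze of the phi itself goes away (the builder names the new value "phi.fr"; the names below are illustrative).

    %p = phi i32 [ 7, %bb1 ], [ %x, %bb2 ]
    %f = freeze i32 %p
    ; --> freeze only the non-constant input, emitted in %bb2
    %x.fr = freeze i32 %x
    %p = phi i32 [ 7, %bb1 ], [ %x.fr, %bb2 ]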
@@ -1111,8 +1199,8 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) {
}
}
- for (auto UI = PN->user_begin(), E = PN->user_end(); UI != E;) {
- Instruction *User = cast<Instruction>(*UI++);
+ for (User *U : make_early_inc_range(PN->users())) {
+ Instruction *User = cast<Instruction>(U);
if (User == &I) continue;
replaceInstUsesWith(*User, NewPN);
eraseInstFromFunction(*User);
@@ -1120,7 +1208,7 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) {
return replaceInstUsesWith(I, NewPN);
}
-Instruction *InstCombiner::foldBinOpIntoSelectOrPhi(BinaryOperator &I) {
+Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) {
if (!isa<Constant>(I.getOperand(1)))
return nullptr;
@@ -1138,8 +1226,9 @@ Instruction *InstCombiner::foldBinOpIntoSelectOrPhi(BinaryOperator &I) {
/// is a sequence of GEP indices into the pointed type that will land us at the
/// specified offset. If so, fill them into NewIndices and return the resultant
/// element type, otherwise return null.
-Type *InstCombiner::FindElementAtOffset(PointerType *PtrTy, int64_t Offset,
- SmallVectorImpl<Value *> &NewIndices) {
+Type *
+InstCombinerImpl::FindElementAtOffset(PointerType *PtrTy, int64_t Offset,
+ SmallVectorImpl<Value *> &NewIndices) {
Type *Ty = PtrTy->getElementType();
if (!Ty->isSized())
return nullptr;
@@ -1208,7 +1297,7 @@ static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) {
/// Return a value X such that Val = X * Scale, or null if none.
/// If the multiplication is known not to overflow, then NoSignedWrap is set.
-Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) {
+Value *InstCombinerImpl::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) {
assert(isa<IntegerType>(Val->getType()) && "Can only descale integers!");
assert(cast<IntegerType>(Val->getType())->getBitWidth() ==
Scale.getBitWidth() && "Scale not compatible with value!");
@@ -1448,9 +1537,8 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) {
} while (true);
}
-Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) {
- // FIXME: some of this is likely fine for scalable vectors
- if (!isa<FixedVectorType>(Inst.getType()))
+Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
+ if (!isa<VectorType>(Inst.getType()))
return nullptr;
BinaryOperator::BinaryOps Opcode = Inst.getOpcode();
@@ -1539,13 +1627,15 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) {
// intends to move shuffles closer to other shuffles and binops closer to
// other binops, so they can be folded. It may also enable demanded elements
// transforms.
- unsigned NumElts = cast<FixedVectorType>(Inst.getType())->getNumElements();
Constant *C;
- if (match(&Inst,
+ auto *InstVTy = dyn_cast<FixedVectorType>(Inst.getType());
+ if (InstVTy &&
+ match(&Inst,
m_c_BinOp(m_OneUse(m_Shuffle(m_Value(V1), m_Undef(), m_Mask(Mask))),
- m_Constant(C))) &&
- cast<FixedVectorType>(V1->getType())->getNumElements() <= NumElts) {
- assert(Inst.getType()->getScalarType() == V1->getType()->getScalarType() &&
+ m_ImmConstant(C))) &&
+ cast<FixedVectorType>(V1->getType())->getNumElements() <=
+ InstVTy->getNumElements()) {
+ assert(InstVTy->getScalarType() == V1->getType()->getScalarType() &&
"Shuffle should not change scalar type");
// Find constant NewC that has property:
@@ -1560,6 +1650,7 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) {
UndefValue *UndefScalar = UndefValue::get(C->getType()->getScalarType());
SmallVector<Constant *, 16> NewVecC(SrcVecNumElts, UndefScalar);
bool MayChange = true;
+ unsigned NumElts = InstVTy->getNumElements();
for (unsigned I = 0; I < NumElts; ++I) {
Constant *CElt = C->getAggregateElement(I);
if (ShMask[I] >= 0) {
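A worked IR example of the shuffle-with-constant fold that this hunk restricts to m_ImmConstant (the value %x and the constants are hypothetical): the binop is applied to the un-shuffled operand with a constant NewC chosen so that shuffling NewC with the same mask reproduces C; lanes of NewC that the mask never reads stay undef.

    %s = shufflevector <4 x i32> %x, <4 x i32> undef, <4 x i32> <i32 1, i32 1, i32 0, i32 0>
    %r = add <4 x i32> %s, <i32 10, i32 10, i32 20, i32 20>
    ; --> NewC = <20, 10, undef, undef>; shuffling NewC with the mask gives back C
    %b = add <4 x i32> %x, <i32 20, i32 10, i32 undef, i32 undef>
    %r = shufflevector <4 x i32> %b, <4 x i32> undef, <4 x i32> <i32 1, i32 1, i32 0, i32 0>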
@@ -1648,9 +1739,8 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) {
// values followed by a splat followed by the 2nd binary operation:
// bo (splat X), (bo Y, OtherOp) --> bo (splat (bo X, Y)), OtherOp
Value *NewBO = Builder.CreateBinOp(Opcode, X, Y);
- UndefValue *Undef = UndefValue::get(Inst.getType());
SmallVector<int, 8> NewMask(MaskC.size(), SplatIndex);
- Value *NewSplat = Builder.CreateShuffleVector(NewBO, Undef, NewMask);
+ Value *NewSplat = Builder.CreateShuffleVector(NewBO, NewMask);
Instruction *R = BinaryOperator::Create(Opcode, NewSplat, OtherOp);
// Intersect FMF on both new binops. Other (poison-generating) flags are
@@ -1670,7 +1760,7 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) {
/// Try to narrow the width of a binop if at least 1 operand is an extend of
/// a value. This requires a potentially expensive known bits check to make
/// sure the narrow op does not overflow.
-Instruction *InstCombiner::narrowMathIfNoOverflow(BinaryOperator &BO) {
+Instruction *InstCombinerImpl::narrowMathIfNoOverflow(BinaryOperator &BO) {
// We need at least one extended operand.
Value *Op0 = BO.getOperand(0), *Op1 = BO.getOperand(1);
@@ -1750,7 +1840,7 @@ static Instruction *foldSelectGEP(GetElementPtrInst &GEP,
// gep (select Cond, TrueC, FalseC), IndexC --> select Cond, TrueC', FalseC'
// Propagate 'inbounds' and metadata from existing instructions.
// Note: using IRBuilder to create the constants for efficiency.
- SmallVector<Value *, 4> IndexC(GEP.idx_begin(), GEP.idx_end());
+ SmallVector<Value *, 4> IndexC(GEP.indices());
bool IsInBounds = GEP.isInBounds();
Value *NewTrueC = IsInBounds ? Builder.CreateInBoundsGEP(TrueC, IndexC)
: Builder.CreateGEP(TrueC, IndexC);
@@ -1759,8 +1849,8 @@ static Instruction *foldSelectGEP(GetElementPtrInst &GEP,
return SelectInst::Create(Cond, NewTrueC, NewFalseC, "", nullptr, Sel);
}
-Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
- SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end());
+Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+ SmallVector<Value *, 8> Ops(GEP.operands());
Type *GEPType = GEP.getType();
Type *GEPEltType = GEP.getSourceElementType();
bool IsGEPSrcEleScalable = isa<ScalableVectorType>(GEPEltType);
@@ -2130,7 +2220,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
// GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ?
if (CATy->getElementType() == StrippedPtrEltTy) {
// -> GEP i8* X, ...
- SmallVector<Value*, 8> Idx(GEP.idx_begin()+1, GEP.idx_end());
+ SmallVector<Value *, 8> Idx(drop_begin(GEP.indices()));
GetElementPtrInst *Res = GetElementPtrInst::Create(
StrippedPtrEltTy, StrippedPtr, Idx, GEP.getName());
Res->setIsInBounds(GEP.isInBounds());
@@ -2166,7 +2256,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
// ->
// %0 = GEP [10 x i8] addrspace(1)* X, ...
// addrspacecast i8 addrspace(1)* %0 to i8*
- SmallVector<Value*, 8> Idx(GEP.idx_begin(), GEP.idx_end());
+ SmallVector<Value *, 8> Idx(GEP.indices());
Value *NewGEP =
GEP.isInBounds()
? Builder.CreateInBoundsGEP(StrippedPtrEltTy, StrippedPtr,
@@ -2308,15 +2398,15 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
// gep (bitcast [c x ty]* X to <c x ty>*), Y, Z --> gep X, Y, Z
auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy,
const DataLayout &DL) {
- auto *VecVTy = cast<VectorType>(VecTy);
+ auto *VecVTy = cast<FixedVectorType>(VecTy);
return ArrTy->getArrayElementType() == VecVTy->getElementType() &&
ArrTy->getArrayNumElements() == VecVTy->getNumElements() &&
DL.getTypeAllocSize(ArrTy) == DL.getTypeAllocSize(VecTy);
};
if (GEP.getNumOperands() == 3 &&
- ((GEPEltType->isArrayTy() && SrcEltType->isVectorTy() &&
+ ((GEPEltType->isArrayTy() && isa<FixedVectorType>(SrcEltType) &&
areMatchingArrayAndVecTypes(GEPEltType, SrcEltType, DL)) ||
- (GEPEltType->isVectorTy() && SrcEltType->isArrayTy() &&
+ (isa<FixedVectorType>(GEPEltType) && SrcEltType->isArrayTy() &&
areMatchingArrayAndVecTypes(SrcEltType, GEPEltType, DL)))) {
// Create a new GEP here, as using `setOperand()` followed by
@@ -2511,7 +2601,7 @@ static bool isAllocSiteRemovable(Instruction *AI,
return true;
}
-Instruction *InstCombiner::visitAllocSite(Instruction &MI) {
+Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) {
// If we have a malloc call which is only used in any amount of comparisons to
// null and free calls, delete the calls and replace the comparisons with true
// or false as appropriate.
@@ -2526,10 +2616,10 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) {
// If we are removing an alloca with a dbg.declare, insert dbg.value calls
// before each store.
- TinyPtrVector<DbgVariableIntrinsic *> DIIs;
+ SmallVector<DbgVariableIntrinsic *, 8> DVIs;
std::unique_ptr<DIBuilder> DIB;
if (isa<AllocaInst>(MI)) {
- DIIs = FindDbgAddrUses(&MI);
+ findDbgUsers(DVIs, &MI);
DIB.reset(new DIBuilder(*MI.getModule(), /*AllowUnresolved=*/false));
}
@@ -2563,8 +2653,9 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) {
ConstantInt::get(Type::getInt1Ty(C->getContext()),
C->isFalseWhenEqual()));
} else if (auto *SI = dyn_cast<StoreInst>(I)) {
- for (auto *DII : DIIs)
- ConvertDebugDeclareToDebugValue(DII, SI, *DIB);
+ for (auto *DVI : DVIs)
+ if (DVI->isAddressOfVariable())
+ ConvertDebugDeclareToDebugValue(DVI, SI, *DIB);
} else {
// Casts, GEP, or anything else: we're about to delete this instruction,
// so it can not have any valid uses.
@@ -2581,8 +2672,31 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) {
None, "", II->getParent());
}
- for (auto *DII : DIIs)
- eraseInstFromFunction(*DII);
+ // Remove debug intrinsics which describe the value contained within the
+ // alloca. In addition to removing dbg.{declare,addr} which simply point to
+ // the alloca, remove dbg.value(<alloca>, ..., DW_OP_deref)'s as well, e.g.:
+ //
+ // ```
+ // define void @foo(i32 %0) {
+ // %a = alloca i32 ; Deleted.
+ // store i32 %0, i32* %a
+ // dbg.value(i32 %0, "arg0") ; Not deleted.
+ // dbg.value(i32* %a, "arg0", DW_OP_deref) ; Deleted.
+ // call void @trivially_inlinable_no_op(i32* %a)
+ // ret void
+ // }
+ // ```
+ //
+ // This may not be required if we stop describing the contents of allocas
+ // using dbg.value(<alloca>, ..., DW_OP_deref), but we currently do this in
+ // the LowerDbgDeclare utility.
+ //
+ // If there is a dead store to `%a` in @trivially_inlinable_no_op, the
+ // "arg0" dbg.value may be stale after the call. However, failing to remove
+ // the DW_OP_deref dbg.value causes large gaps in location coverage.
+ for (auto *DVI : DVIs)
+ if (DVI->isAddressOfVariable() || DVI->getExpression()->startsWithDeref())
+ DVI->eraseFromParent();
return eraseInstFromFunction(MI);
}
@@ -2670,7 +2784,7 @@ static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI,
return &FI;
}
-Instruction *InstCombiner::visitFree(CallInst &FI) {
+Instruction *InstCombinerImpl::visitFree(CallInst &FI) {
Value *Op = FI.getArgOperand(0);
// free undef -> unreachable.
@@ -2711,7 +2825,7 @@ static bool isMustTailCall(Value *V) {
return false;
}
-Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) {
+Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) {
if (RI.getNumOperands() == 0) // ret void
return nullptr;
@@ -2734,7 +2848,31 @@ Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) {
return nullptr;
}
-Instruction *InstCombiner::visitUnconditionalBranchInst(BranchInst &BI) {
+Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) {
+ // Try to remove the previous instruction if it must lead to unreachable.
+ // This includes instructions like stores and "llvm.assume" that may not get
+ // removed by simple dead code elimination.
+ Instruction *Prev = I.getPrevNonDebugInstruction();
+ if (Prev && !Prev->isEHPad() &&
+ isGuaranteedToTransferExecutionToSuccessor(Prev)) {
+ // Temporarily disable removal of volatile stores preceding unreachable,
+ // pending a potential LangRef change permitting volatile stores to trap.
+ // TODO: Either remove this code, or properly integrate the check into
+ // isGuaranteedToTransferExecutionToSuccessor().
+ if (auto *SI = dyn_cast<StoreInst>(Prev))
+ if (SI->isVolatile())
+ return nullptr;
+
+ // A value may still have uses before we process it here (for example, in
+ // another unreachable block), so convert those to undef.
+ replaceInstUsesWith(*Prev, UndefValue::get(Prev->getType()));
+ eraseInstFromFunction(*Prev);
+ return &I;
+ }
+ return nullptr;
+}
+
+Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) {
assert(BI.isUnconditional() && "Only for unconditional branches.");
// If this store is the second-to-last instruction in the basic block
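A minimal sketch of what the new visitUnreachableInst handles (pointer %p is hypothetical): a non-volatile store immediately preceding unreachable is guaranteed to transfer execution to its successor and its effect can never be observed, so it is deleted; volatile stores are deliberately kept for now, per the TODO in the code above.

    bb:
      store i32 0, i32* %p
      unreachable
    ; -->
    bb:
      unreachable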
@@ -2763,7 +2901,7 @@ Instruction *InstCombiner::visitUnconditionalBranchInst(BranchInst &BI) {
return nullptr;
}
-Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
+Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
if (BI.isUnconditional())
return visitUnconditionalBranchInst(BI);
@@ -2799,7 +2937,7 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
return nullptr;
}
-Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
+Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
Value *Cond = SI.getCondition();
Value *Op0;
ConstantInt *AddRHS;
@@ -2830,7 +2968,7 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
unsigned NewWidth = Known.getBitWidth() - std::max(LeadingKnownZeros, LeadingKnownOnes);
// Shrink the condition operand if the new type is smaller than the old type.
- // But do not shrink to a non-standard type, because backend can't generate
+ // But do not shrink to a non-standard type, because backend can't generate
// good code for that yet.
// TODO: We can make it aggressive again after fixing PR39569.
if (NewWidth > 0 && NewWidth < Known.getBitWidth() &&
@@ -2849,7 +2987,7 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
return nullptr;
}
-Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) {
+Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
Value *Agg = EV.getAggregateOperand();
if (!EV.hasIndices())
@@ -2994,10 +3132,11 @@ static bool isCatchAll(EHPersonality Personality, Constant *TypeInfo) {
case EHPersonality::GNU_CXX_SjLj:
case EHPersonality::GNU_ObjC:
case EHPersonality::MSVC_X86SEH:
- case EHPersonality::MSVC_Win64SEH:
+ case EHPersonality::MSVC_TableSEH:
case EHPersonality::MSVC_CXX:
case EHPersonality::CoreCLR:
case EHPersonality::Wasm_CXX:
+ case EHPersonality::XL_CXX:
return TypeInfo->isNullValue();
}
llvm_unreachable("invalid enum");
@@ -3010,7 +3149,7 @@ static bool shorter_filter(const Value *LHS, const Value *RHS) {
cast<ArrayType>(RHS->getType())->getNumElements();
}
-Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) {
+Instruction *InstCombinerImpl::visitLandingPadInst(LandingPadInst &LI) {
// The logic here should be correct for any real-world personality function.
// However if that turns out not to be true, the offending logic can always
// be conditioned on the personality function, like the catch-all logic is.
@@ -3319,12 +3458,46 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) {
return nullptr;
}
-Instruction *InstCombiner::visitFreeze(FreezeInst &I) {
+Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) {
Value *Op0 = I.getOperand(0);
if (Value *V = SimplifyFreezeInst(Op0, SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
+ // freeze (phi const, x) --> phi const, (freeze x)
+ if (auto *PN = dyn_cast<PHINode>(Op0)) {
+ if (Instruction *NV = foldOpIntoPhi(I, PN))
+ return NV;
+ }
+
+ if (match(Op0, m_Undef())) {
+ // If I is freeze(undef), see its uses and fold it to the best constant.
+ // - or: pick -1
+ // - select's condition: pick the value that leads to choosing a constant
+ // - other ops: pick 0
+ Constant *BestValue = nullptr;
+ Constant *NullValue = Constant::getNullValue(I.getType());
+ for (const auto *U : I.users()) {
+ Constant *C = NullValue;
+
+ if (match(U, m_Or(m_Value(), m_Value())))
+ C = Constant::getAllOnesValue(I.getType());
+ else if (const auto *SI = dyn_cast<SelectInst>(U)) {
+ if (SI->getCondition() == &I) {
+ APInt CondVal(1, isa<Constant>(SI->getFalseValue()) ? 0 : 1);
+ C = Constant::getIntegerValue(I.getType(), CondVal);
+ }
+ }
+
+ if (!BestValue)
+ BestValue = C;
+ else if (BestValue != C)
+ BestValue = NullValue;
+ }
+
+ return replaceInstUsesWith(I, BestValue);
+ }
+
return nullptr;
}
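A sketch of the freeze(undef) fold above with a single hypothetical user: the false arm of the select is not a constant, so the code picks the condition value that makes the select choose its constant arm. If several users disagreed on a best value, the null value (zero) would be used instead.

    %f = freeze i1 undef
    %s = select i1 %f, i32 42, i32 %v
    ; --> %f is folded to true, so the select returns the constant arm
    %s = select i1 true, i32 42, i32 %v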
@@ -3430,7 +3603,7 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) {
return true;
}
-bool InstCombiner::run() {
+bool InstCombinerImpl::run() {
while (!Worklist.isEmpty()) {
// Walk deferred instructions in reverse order, and push them to the
// worklist, which means they'll end up popped from the worklist in-order.
@@ -3492,7 +3665,9 @@ bool InstCombiner::run() {
else
UserParent = UserInst->getParent();
- if (UserParent != BB) {
+ // Try sinking to another block. If that block is unreachable, then do
+ // not bother. SimplifyCFG should handle it.
+ if (UserParent != BB && DT.isReachableFromEntry(UserParent)) {
// See if the user is one of our successors that has only one
// predecessor, so that we don't have to split the critical edge.
bool ShouldSink = UserParent->getUniquePredecessor() == BB;
@@ -3526,7 +3701,8 @@ bool InstCombiner::run() {
// Now that we have an instruction, try combining it to simplify it.
Builder.SetInsertPoint(I);
- Builder.SetCurrentDebugLocation(I->getDebugLoc());
+ Builder.CollectMetadataToCopy(
+ I, {LLVMContext::MD_dbg, LLVMContext::MD_annotation});
#ifndef NDEBUG
std::string OrigI;
@@ -3541,8 +3717,8 @@ bool InstCombiner::run() {
LLVM_DEBUG(dbgs() << "IC: Old = " << *I << '\n'
<< " New = " << *Result << '\n');
- if (I->getDebugLoc())
- Result->setDebugLoc(I->getDebugLoc());
+ Result->copyMetadata(*I,
+ {LLVMContext::MD_dbg, LLVMContext::MD_annotation});
// Everything uses the new instruction now.
I->replaceAllUsesWith(Result);
@@ -3553,10 +3729,14 @@ bool InstCombiner::run() {
BasicBlock *InstParent = I->getParent();
BasicBlock::iterator InsertPos = I->getIterator();
- // If we replace a PHI with something that isn't a PHI, fix up the
- // insertion point.
- if (!isa<PHINode>(Result) && isa<PHINode>(InsertPos))
- InsertPos = InstParent->getFirstInsertionPt();
+      // Are we replacing a PHI with something that isn't a PHI, or vice versa?
+ if (isa<PHINode>(Result) != isa<PHINode>(I)) {
+ // We need to fix up the insertion point.
+ if (isa<PHINode>(I)) // PHI -> Non-PHI
+ InsertPos = InstParent->getFirstInsertionPt();
+ else // Non-PHI -> PHI
+ InsertPos = InstParent->getFirstNonPHI()->getIterator();
+ }
InstParent->getInstList().insert(InsertPos, Result);
@@ -3586,6 +3766,55 @@ bool InstCombiner::run() {
return MadeIRChange;
}
+// Track the scopes used by !alias.scope and !noalias. In a function, a
+// @llvm.experimental.noalias.scope.decl is only useful if that scope is used
+// by both sets. If not, the declaration of the scope can be safely omitted.
+// The MDNode of the scope can be omitted as well for the instructions that are
+// part of this function. We do not do that at this point, as this might become
+// too time consuming to do.
+class AliasScopeTracker {
+ SmallPtrSet<const MDNode *, 8> UsedAliasScopesAndLists;
+ SmallPtrSet<const MDNode *, 8> UsedNoAliasScopesAndLists;
+
+public:
+ void analyse(Instruction *I) {
+ // This seems to be faster than checking 'mayReadOrWriteMemory()'.
+ if (!I->hasMetadataOtherThanDebugLoc())
+ return;
+
+ auto Track = [](Metadata *ScopeList, auto &Container) {
+ const auto *MDScopeList = dyn_cast_or_null<MDNode>(ScopeList);
+ if (!MDScopeList || !Container.insert(MDScopeList).second)
+ return;
+ for (auto &MDOperand : MDScopeList->operands())
+ if (auto *MDScope = dyn_cast<MDNode>(MDOperand))
+ Container.insert(MDScope);
+ };
+
+ Track(I->getMetadata(LLVMContext::MD_alias_scope), UsedAliasScopesAndLists);
+ Track(I->getMetadata(LLVMContext::MD_noalias), UsedNoAliasScopesAndLists);
+ }
+
+ bool isNoAliasScopeDeclDead(Instruction *Inst) {
+ NoAliasScopeDeclInst *Decl = dyn_cast<NoAliasScopeDeclInst>(Inst);
+ if (!Decl)
+ return false;
+
+ assert(Decl->use_empty() &&
+ "llvm.experimental.noalias.scope.decl in use ?");
+ const MDNode *MDSL = Decl->getScopeList();
+ assert(MDSL->getNumOperands() == 1 &&
+ "llvm.experimental.noalias.scope should refer to a single scope");
+ auto &MDOperand = MDSL->getOperand(0);
+ if (auto *MD = dyn_cast<MDNode>(MDOperand))
+ return !UsedAliasScopesAndLists.contains(MD) ||
+ !UsedNoAliasScopesAndLists.contains(MD);
+
+    // Not an MDNode? Throw it away.
+ return true;
+ }
+};
+
/// Populate the IC worklist from a function, by walking it in depth-first
/// order and adding all reachable code to the worklist.
///
@@ -3604,6 +3833,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
SmallVector<Instruction*, 128> InstrsForInstCombineWorklist;
DenseMap<Constant *, Constant *> FoldedConstants;
+ AliasScopeTracker SeenAliasScopes;
do {
BasicBlock *BB = Worklist.pop_back_val();
@@ -3650,8 +3880,10 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
// Skip processing debug intrinsics in InstCombine. Processing these call instructions
// consumes non-trivial amount of time and provides no value for the optimization.
- if (!isa<DbgInfoIntrinsic>(Inst))
+ if (!isa<DbgInfoIntrinsic>(Inst)) {
InstrsForInstCombineWorklist.push_back(Inst);
+ SeenAliasScopes.analyse(Inst);
+ }
}
// Recursively visit successors. If this is a branch or switch on a
@@ -3671,8 +3903,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
}
}
- for (BasicBlock *SuccBB : successors(TI))
- Worklist.push_back(SuccBB);
+ append_range(Worklist, successors(TI));
} while (!Worklist.empty());
// Remove instructions inside unreachable blocks. This prevents the
@@ -3682,8 +3913,12 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
if (Visited.count(&BB))
continue;
- unsigned NumDeadInstInBB = removeAllNonTerminatorAndEHPadInstructions(&BB);
- MadeIRChange |= NumDeadInstInBB > 0;
+ unsigned NumDeadInstInBB;
+ unsigned NumDeadDbgInstInBB;
+ std::tie(NumDeadInstInBB, NumDeadDbgInstInBB) =
+ removeAllNonTerminatorAndEHPadInstructions(&BB);
+
+ MadeIRChange |= NumDeadInstInBB + NumDeadDbgInstInBB > 0;
NumDeadInst += NumDeadInstInBB;
}
@@ -3696,7 +3931,8 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
for (Instruction *Inst : reverse(InstrsForInstCombineWorklist)) {
// DCE instruction if trivially dead. As we iterate in reverse program
// order here, we will clean up whole chains of dead instructions.
- if (isInstructionTriviallyDead(Inst, TLI)) {
+ if (isInstructionTriviallyDead(Inst, TLI) ||
+ SeenAliasScopes.isNoAliasScopeDeclDead(Inst)) {
++NumDeadInst;
LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n');
salvageDebugInfo(*Inst);
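A hedged IR sketch of the dead-declaration case that AliasScopeTracker::isNoAliasScopeDeclDead catches (the metadata numbering is illustrative): scope !3 appears in an !alias.scope list but never in any !noalias list, so the declaration conveys no aliasing information for this function and is dropped during worklist preparation.

    call void @llvm.experimental.noalias.scope.decl(metadata !2)
    %v = load i32, i32* %p, !alias.scope !2
    ; no instruction in the function names scope !3 in a !noalias list
    !2 = !{!3}
    !3 = distinct !{!3, !4, !"scope"}
    !4 = distinct !{!4, !"domain"}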
@@ -3713,8 +3949,8 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
static bool combineInstructionsOverFunction(
Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA,
- AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT,
- OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
+ AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
+ DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) {
auto &DL = F.getParent()->getDataLayout();
MaxIterations = std::min(MaxIterations, LimitMaxIterations.getValue());
@@ -3738,6 +3974,7 @@ static bool combineInstructionsOverFunction(
// Iterate while there is work to do.
unsigned Iteration = 0;
while (true) {
+ ++NumWorklistIterations;
++Iteration;
if (Iteration > InfiniteLoopDetectionThreshold) {
@@ -3758,8 +3995,8 @@ static bool combineInstructionsOverFunction(
MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist);
- InstCombiner IC(Worklist, Builder, F.hasMinSize(), AA,
- AC, TLI, DT, ORE, BFI, PSI, DL, LI);
+ InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT,
+ ORE, BFI, PSI, DL, LI);
IC.MaxArraySizeForCombine = MaxArraySize;
if (!IC.run())
@@ -3782,6 +4019,7 @@ PreservedAnalyses InstCombinePass::run(Function &F,
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+ auto &TTI = AM.getResult<TargetIRAnalysis>(F);
auto *LI = AM.getCachedResult<LoopAnalysis>(F);
@@ -3792,8 +4030,8 @@ PreservedAnalyses InstCombinePass::run(Function &F,
auto *BFI = (PSI && PSI->hasProfileSummary()) ?
&AM.getResult<BlockFrequencyAnalysis>(F) : nullptr;
- if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, BFI,
- PSI, MaxIterations, LI))
+ if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
+ BFI, PSI, MaxIterations, LI))
// No changes, all analyses are preserved.
return PreservedAnalyses::all();
@@ -3811,6 +4049,7 @@ void InstructionCombiningPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequired<AAResultsWrapperPass>();
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
@@ -3829,6 +4068,7 @@ bool InstructionCombiningPass::runOnFunction(Function &F) {
auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+ auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
@@ -3842,8 +4082,8 @@ bool InstructionCombiningPass::runOnFunction(Function &F) {
&getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() :
nullptr;
- return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, BFI,
- PSI, MaxIterations, LI);
+ return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
+ BFI, PSI, MaxIterations, LI);
}
char InstructionCombiningPass::ID = 0;
@@ -3862,6 +4102,7 @@ INITIALIZE_PASS_BEGIN(InstructionCombiningPass, "instcombine",
"Combine redundant instructions", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)