author    | Dimitry Andric <dim@FreeBSD.org> | 2021-02-16 20:13:02 +0000
committer | Dimitry Andric <dim@FreeBSD.org> | 2021-02-16 20:13:02 +0000
commit    | b60736ec1405bb0a8dd40989f67ef4c93da068ab (patch)
tree      | 5c43fbb7c9fc45f0f87e0e6795a86267dbd12f9d /llvm/lib/Transforms/InstCombine
parent    | cfca06d7963fa0909f90483b42a6d7d194d01e08 (diff)
download  | src-b60736ec1405bb0a8dd40989f67ef4c93da068ab.tar.gz, src-b60736ec1405bb0a8dd40989f67ef4c93da068ab.zip
Vendor import of llvm-project main 8e464dd76bef, the last commit before the upstream release/12.x branch was created.

Tag: vendor/llvm-project/llvmorg-12-init-17869-g8e464dd76bef
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
17 files changed, 3267 insertions, 5440 deletions
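One of the new folds visible in the InstCombineAddSub.cpp hunk below rewrites `usub.sat(A, B) + B` into `umax(A, B)`. The identity this relies on is easy to check outside of LLVM; the following is a minimal, self-contained C++ sketch (not part of the imported patch; all names are illustrative) that exercises the equivalence on a few unsigned values:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Saturating unsigned subtraction: a - b, clamped at 0 (the semantics of
// the llvm.usub.sat intrinsic on unsigned operands).
static uint32_t usub_sat(uint32_t a, uint32_t b) {
  return a > b ? a - b : 0;
}

int main() {
  // usub.sat(A, B) + B == umax(A, B) for all unsigned A, B:
  // if A > B the sum is (A - B) + B = A, otherwise it is 0 + B = B.
  const uint32_t samples[] = {0, 1, 7, 100, 0x7fffffffu, 0xffffffffu};
  for (uint32_t a : samples)
    for (uint32_t b : samples)
      assert(usub_sat(a, b) + b == std::max(a, b));
  return 0;
}
```

The new `visitAdd` case in the diff performs the same rewrite at the IR level, replacing the add of a one-use `usub.sat` result with a call to the `umax` intrinsic.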
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index a7f5e0a7774d..bacb8689892a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -29,6 +29,7 @@ #include "llvm/Support/AlignOf.h" #include "llvm/Support/Casting.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" #include <cassert> #include <utility> @@ -81,11 +82,11 @@ namespace { private: bool insaneIntVal(int V) { return V > 4 || V < -4; } - APFloat *getFpValPtr() - { return reinterpret_cast<APFloat *>(&FpValBuf.buffer[0]); } + APFloat *getFpValPtr() { return reinterpret_cast<APFloat *>(&FpValBuf); } - const APFloat *getFpValPtr() const - { return reinterpret_cast<const APFloat *>(&FpValBuf.buffer[0]); } + const APFloat *getFpValPtr() const { + return reinterpret_cast<const APFloat *>(&FpValBuf); + } const APFloat &getFpVal() const { assert(IsFp && BufHasFpVal && "Incorret state"); @@ -860,7 +861,7 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add, return nullptr; } -Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { +Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1); Constant *Op1C; if (!match(Op1, m_Constant(Op1C))) @@ -886,15 +887,15 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { // zext(bool) + C -> bool ? C + 1 : C if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->getScalarSizeInBits() == 1) - return SelectInst::Create(X, AddOne(Op1C), Op1); + return SelectInst::Create(X, InstCombiner::AddOne(Op1C), Op1); // sext(bool) + C -> bool ? C - 1 : C if (match(Op0, m_SExt(m_Value(X))) && X->getType()->getScalarSizeInBits() == 1) - return SelectInst::Create(X, SubOne(Op1C), Op1); + return SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1); // ~X + C --> (C-1) - X if (match(Op0, m_Not(m_Value(X)))) - return BinaryOperator::CreateSub(SubOne(Op1C), X); + return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X); const APInt *C; if (!match(Op1, m_APInt(C))) @@ -923,6 +924,39 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C) return CastInst::Create(Instruction::SExt, X, Ty); + if (match(Op0, m_Xor(m_Value(X), m_APInt(C2)))) { + // (X ^ signmask) + C --> (X + (signmask ^ C)) + if (C2->isSignMask()) + return BinaryOperator::CreateAdd(X, ConstantInt::get(Ty, *C2 ^ *C)); + + // If X has no high-bits set above an xor mask: + // add (xor X, LowMaskC), C --> sub (LowMaskC + C), X + if (C2->isMask()) { + KnownBits LHSKnown = computeKnownBits(X, 0, &Add); + if ((*C2 | LHSKnown.Zero).isAllOnesValue()) + return BinaryOperator::CreateSub(ConstantInt::get(Ty, *C2 + *C), X); + } + + // Look for a math+logic pattern that corresponds to sext-in-register of a + // value with cleared high bits. 
Convert that into a pair of shifts: + // add (xor X, 0x80), 0xF..F80 --> (X << ShAmtC) >>s ShAmtC + // add (xor X, 0xF..F80), 0x80 --> (X << ShAmtC) >>s ShAmtC + if (Op0->hasOneUse() && *C2 == -(*C)) { + unsigned BitWidth = Ty->getScalarSizeInBits(); + unsigned ShAmt = 0; + if (C->isPowerOf2()) + ShAmt = BitWidth - C->logBase2() - 1; + else if (C2->isPowerOf2()) + ShAmt = BitWidth - C2->logBase2() - 1; + if (ShAmt && MaskedValueIsZero(X, APInt::getHighBitsSet(BitWidth, ShAmt), + 0, &Add)) { + Constant *ShAmtC = ConstantInt::get(Ty, ShAmt); + Value *NewShl = Builder.CreateShl(X, ShAmtC, "sext"); + return BinaryOperator::CreateAShr(NewShl, ShAmtC); + } + } + } + if (C->isOneValue() && Op0->hasOneUse()) { // add (sext i1 X), 1 --> zext (not X) // TODO: The smallest IR representation is (select X, 0, 1), and that would @@ -943,6 +977,15 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { } } + // If all bits affected by the add are included in a high-bit-mask, do the + // add before the mask op: + // (X & 0xFF00) + xx00 --> (X + xx00) & 0xFF00 + if (match(Op0, m_OneUse(m_And(m_Value(X), m_APInt(C2)))) && + C2->isNegative() && C2->isShiftedMask() && *C == (*C & *C2)) { + Value *NewAdd = Builder.CreateAdd(X, ConstantInt::get(Ty, *C)); + return BinaryOperator::CreateAnd(NewAdd, ConstantInt::get(Ty, *C2)); + } + return nullptr; } @@ -1021,7 +1064,7 @@ static bool MulWillOverflow(APInt &C0, APInt &C1, bool IsSigned) { // Simplifies X % C0 + (( X / C0 ) % C1) * C0 to X % (C0 * C1), where (C0 * C1) // does not overflow. -Value *InstCombiner::SimplifyAddWithRemainder(BinaryOperator &I) { +Value *InstCombinerImpl::SimplifyAddWithRemainder(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); Value *X, *MulOpV; APInt C0, MulOpC; @@ -1097,9 +1140,9 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) { return nullptr; } -Instruction * -InstCombiner::canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract( - BinaryOperator &I) { +Instruction *InstCombinerImpl:: + canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract( + BinaryOperator &I) { assert((I.getOpcode() == Instruction::Add || I.getOpcode() == Instruction::Or || I.getOpcode() == Instruction::Sub) && @@ -1198,7 +1241,44 @@ InstCombiner::canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract( return TruncInst::CreateTruncOrBitCast(NewAShr, I.getType()); } -Instruction *InstCombiner::visitAdd(BinaryOperator &I) { +/// This is a specialization of a more general transform from +/// SimplifyUsingDistributiveLaws. If that code can be made to work optimally +/// for multi-use cases or propagating nsw/nuw, then we would not need this. +static Instruction *factorizeMathWithShlOps(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + // TODO: Also handle mul by doubling the shift amount? + assert((I.getOpcode() == Instruction::Add || + I.getOpcode() == Instruction::Sub) && + "Expected add/sub"); + auto *Op0 = dyn_cast<BinaryOperator>(I.getOperand(0)); + auto *Op1 = dyn_cast<BinaryOperator>(I.getOperand(1)); + if (!Op0 || !Op1 || !(Op0->hasOneUse() || Op1->hasOneUse())) + return nullptr; + + Value *X, *Y, *ShAmt; + if (!match(Op0, m_Shl(m_Value(X), m_Value(ShAmt))) || + !match(Op1, m_Shl(m_Value(Y), m_Specific(ShAmt)))) + return nullptr; + + // No-wrap propagates only when all ops have no-wrap. 
+ bool HasNSW = I.hasNoSignedWrap() && Op0->hasNoSignedWrap() && + Op1->hasNoSignedWrap(); + bool HasNUW = I.hasNoUnsignedWrap() && Op0->hasNoUnsignedWrap() && + Op1->hasNoUnsignedWrap(); + + // add/sub (X << ShAmt), (Y << ShAmt) --> (add/sub X, Y) << ShAmt + Value *NewMath = Builder.CreateBinOp(I.getOpcode(), X, Y); + if (auto *NewI = dyn_cast<BinaryOperator>(NewMath)) { + NewI->setHasNoSignedWrap(HasNSW); + NewI->setHasNoUnsignedWrap(HasNUW); + } + auto *NewShl = BinaryOperator::CreateShl(NewMath, ShAmt); + NewShl->setHasNoSignedWrap(HasNSW); + NewShl->setHasNoUnsignedWrap(HasNUW); + return NewShl; +} + +Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), SQ.getWithInstruction(&I))) @@ -1214,59 +1294,17 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); + if (Instruction *R = factorizeMathWithShlOps(I, Builder)) + return R; + if (Instruction *X = foldAddWithConstant(I)) return X; if (Instruction *X = foldNoWrapAdd(I, Builder)) return X; - // FIXME: This should be moved into the above helper function to allow these - // transforms for general constant or constant splat vectors. Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); Type *Ty = I.getType(); - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - Value *XorLHS = nullptr; ConstantInt *XorRHS = nullptr; - if (match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) { - unsigned TySizeBits = Ty->getScalarSizeInBits(); - const APInt &RHSVal = CI->getValue(); - unsigned ExtendAmt = 0; - // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext. - // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext. - if (XorRHS->getValue() == -RHSVal) { - if (RHSVal.isPowerOf2()) - ExtendAmt = TySizeBits - RHSVal.logBase2() - 1; - else if (XorRHS->getValue().isPowerOf2()) - ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1; - } - - if (ExtendAmt) { - APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt); - if (!MaskedValueIsZero(XorLHS, Mask, 0, &I)) - ExtendAmt = 0; - } - - if (ExtendAmt) { - Constant *ShAmt = ConstantInt::get(Ty, ExtendAmt); - Value *NewShl = Builder.CreateShl(XorLHS, ShAmt, "sext"); - return BinaryOperator::CreateAShr(NewShl, ShAmt); - } - - // If this is a xor that was canonicalized from a sub, turn it back into - // a sub and fuse this add with it. - if (LHS->hasOneUse() && (XorRHS->getValue()+1).isPowerOf2()) { - KnownBits LHSKnown = computeKnownBits(XorLHS, 0, &I); - if ((XorRHS->getValue() | LHSKnown.Zero).isAllOnesValue()) - return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI), - XorLHS); - } - // (X + signmask) + C could have gotten canonicalized to (X^signmask) + C, - // transform them into (X + (signmask ^ C)) - if (XorRHS->getValue().isSignMask()) - return BinaryOperator::CreateAdd(XorLHS, - ConstantExpr::getXor(XorRHS, CI)); - } - } - if (Ty->isIntOrIntVectorTy(1)) return BinaryOperator::CreateXor(LHS, RHS); @@ -1329,34 +1367,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); - // FIXME: We already did a check for ConstantInt RHS above this. - // FIXME: Is this pattern covered by another fold? No regression tests fail on - // removal. 
- if (ConstantInt *CRHS = dyn_cast<ConstantInt>(RHS)) { - // (X & FF00) + xx00 -> (X+xx00) & FF00 - Value *X; - ConstantInt *C2; - if (LHS->hasOneUse() && - match(LHS, m_And(m_Value(X), m_ConstantInt(C2))) && - CRHS->getValue() == (CRHS->getValue() & C2->getValue())) { - // See if all bits from the first bit set in the Add RHS up are included - // in the mask. First, get the rightmost bit. - const APInt &AddRHSV = CRHS->getValue(); - - // Form a mask of all bits from the lowest bit added through the top. - APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1)); - - // See if the and mask includes all of these bits. - APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue()); - - if (AddRHSHighBits == AddRHSHighBitsAnd) { - // Okay, the xform is safe. Insert the new add pronto. - Value *NewAdd = Builder.CreateAdd(X, CRHS, LHS->getName()); - return BinaryOperator::CreateAnd(NewAdd, C2); - } - } - } - // add (select X 0 (sub n A)) A --> select X A n { SelectInst *SI = dyn_cast<SelectInst>(LHS); @@ -1424,6 +1434,14 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Instruction *SatAdd = foldToUnsignedSaturatedAdd(I)) return SatAdd; + // usub.sat(A, B) + B => umax(A, B) + if (match(&I, m_c_BinOp( + m_OneUse(m_Intrinsic<Intrinsic::usub_sat>(m_Value(A), m_Value(B))), + m_Deferred(B)))) { + return replaceInstUsesWith(I, + Builder.CreateIntrinsic(Intrinsic::umax, {I.getType()}, {A, B})); + } + return Changed ? &I : nullptr; } @@ -1486,7 +1504,7 @@ static Instruction *factorizeFAddFSub(BinaryOperator &I, : BinaryOperator::CreateFDivFMF(XY, Z, &I); } -Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) { if (Value *V = SimplifyFAddInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), SQ.getWithInstruction(&I))) @@ -1600,49 +1618,33 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { /// Optimize pointer differences into the same array into a size. Consider: /// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer /// operands to the ptrtoint instructions for the LHS/RHS of the subtract. -Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, - Type *Ty, bool IsNUW) { +Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS, + Type *Ty, bool IsNUW) { // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize // this. bool Swapped = false; GEPOperator *GEP1 = nullptr, *GEP2 = nullptr; + if (!isa<GEPOperator>(LHS) && isa<GEPOperator>(RHS)) { + std::swap(LHS, RHS); + Swapped = true; + } - // For now we require one side to be the base pointer "A" or a constant - // GEP derived from it. - if (GEPOperator *LHSGEP = dyn_cast<GEPOperator>(LHS)) { + // Require at least one GEP with a common base pointer on both sides. + if (auto *LHSGEP = dyn_cast<GEPOperator>(LHS)) { // (gep X, ...) - X if (LHSGEP->getOperand(0) == RHS) { GEP1 = LHSGEP; - Swapped = false; - } else if (GEPOperator *RHSGEP = dyn_cast<GEPOperator>(RHS)) { + } else if (auto *RHSGEP = dyn_cast<GEPOperator>(RHS)) { // (gep X, ...) - (gep X, ...) if (LHSGEP->getOperand(0)->stripPointerCasts() == - RHSGEP->getOperand(0)->stripPointerCasts()) { - GEP2 = RHSGEP; + RHSGEP->getOperand(0)->stripPointerCasts()) { GEP1 = LHSGEP; - Swapped = false; - } - } - } - - if (GEPOperator *RHSGEP = dyn_cast<GEPOperator>(RHS)) { - // X - (gep X, ...) - if (RHSGEP->getOperand(0) == LHS) { - GEP1 = RHSGEP; - Swapped = true; - } else if (GEPOperator *LHSGEP = dyn_cast<GEPOperator>(LHS)) { - // (gep X, ...) 
- (gep X, ...) - if (RHSGEP->getOperand(0)->stripPointerCasts() == - LHSGEP->getOperand(0)->stripPointerCasts()) { - GEP2 = LHSGEP; - GEP1 = RHSGEP; - Swapped = true; + GEP2 = RHSGEP; } } } if (!GEP1) - // No GEP found. return nullptr; if (GEP2) { @@ -1670,19 +1672,18 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, Value *Result = EmitGEPOffset(GEP1); // If this is a single inbounds GEP and the original sub was nuw, - // then the final multiplication is also nuw. We match an extra add zero - // here, because that's what EmitGEPOffset() generates. - Instruction *I; - if (IsNUW && !GEP2 && !Swapped && GEP1->isInBounds() && - match(Result, m_Add(m_Instruction(I), m_Zero())) && - I->getOpcode() == Instruction::Mul) - I->setHasNoUnsignedWrap(); - - // If we had a constant expression GEP on the other side offsetting the - // pointer, subtract it from the offset we have. + // then the final multiplication is also nuw. + if (auto *I = dyn_cast<Instruction>(Result)) + if (IsNUW && !GEP2 && !Swapped && GEP1->isInBounds() && + I->getOpcode() == Instruction::Mul) + I->setHasNoUnsignedWrap(); + + // If we have a 2nd GEP of the same base pointer, subtract the offsets. + // If both GEPs are inbounds, then the subtract does not have signed overflow. if (GEP2) { Value *Offset = EmitGEPOffset(GEP2); - Result = Builder.CreateSub(Result, Offset); + Result = Builder.CreateSub(Result, Offset, "gepdiff", /* NUW */ false, + GEP1->isInBounds() && GEP2->isInBounds()); } // If we have p - gep(p, ...) then we have to negate the result. @@ -1692,7 +1693,7 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, return Builder.CreateIntCast(Result, Ty, true); } -Instruction *InstCombiner::visitSub(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { if (Value *V = SimplifySubInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), SQ.getWithInstruction(&I))) @@ -1721,6 +1722,19 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return Res; } + // Try this before Negator to preserve NSW flag. 
+ if (Instruction *R = factorizeMathWithShlOps(I, Builder)) + return R; + + if (Constant *C = dyn_cast<Constant>(Op0)) { + Value *X; + Constant *C2; + + // C-(X+C2) --> (C-C2)-X + if (match(Op1, m_Add(m_Value(X), m_Constant(C2)))) + return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); + } + auto TryToNarrowDeduceFlags = [this, &I, &Op0, &Op1]() -> Instruction * { if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; @@ -1788,8 +1802,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { } auto m_AddRdx = [](Value *&Vec) { - return m_OneUse( - m_Intrinsic<Intrinsic::experimental_vector_reduce_add>(m_Value(Vec))); + return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec))); }; Value *V0, *V1; if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) && @@ -1797,8 +1810,8 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // Difference of sums is sum of differences: // add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1) Value *Sub = Builder.CreateSub(V0, V1); - Value *Rdx = Builder.CreateIntrinsic( - Intrinsic::experimental_vector_reduce_add, {Sub->getType()}, {Sub}); + Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_add, + {Sub->getType()}, {Sub}); return replaceInstUsesWith(I, Rdx); } @@ -1806,14 +1819,14 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Value *X; if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) // C - (zext bool) --> bool ? C - 1 : C - return SelectInst::Create(X, SubOne(C), C); + return SelectInst::Create(X, InstCombiner::SubOne(C), C); if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) // C - (sext bool) --> bool ? C + 1 : C - return SelectInst::Create(X, AddOne(C), C); + return SelectInst::Create(X, InstCombiner::AddOne(C), C); // C - ~X == X + (1+C) if (match(Op1, m_Not(m_Value(X)))) - return BinaryOperator::CreateAdd(X, AddOne(C)); + return BinaryOperator::CreateAdd(X, InstCombiner::AddOne(C)); // Try to fold constant sub into select arguments. 
if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) @@ -1828,12 +1841,8 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Constant *C2; // C-(C2-X) --> X+(C-C2) - if (match(Op1, m_Sub(m_Constant(C2), m_Value(X))) && !isa<ConstantExpr>(C2)) + if (match(Op1, m_Sub(m_ImmConstant(C2), m_Value(X)))) return BinaryOperator::CreateAdd(X, ConstantExpr::getSub(C, C2)); - - // C-(X+C2) --> (C-C2)-X - if (match(Op1, m_Add(m_Value(X), m_Constant(C2)))) - return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); } const APInt *Op0C; @@ -1864,6 +1873,22 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateXor(A, B); } + // (sub (add A, B) (or A, B)) --> (and A, B) + { + Value *A, *B; + if (match(Op0, m_Add(m_Value(A), m_Value(B))) && + match(Op1, m_c_Or(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateAnd(A, B); + } + + // (sub (add A, B) (and A, B)) --> (or A, B) + { + Value *A, *B; + if (match(Op0, m_Add(m_Value(A), m_Value(B))) && + match(Op1, m_c_And(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); + } + // (sub (and A, B) (or A, B)) --> neg (xor A, B) { Value *A, *B; @@ -2042,6 +2067,20 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return SelectInst::Create(Cmp, Neg, A); } + // If we are subtracting a low-bit masked subset of some value from an add + // of that same value with no low bits changed, that is clearing some low bits + // of the sum: + // sub (X + AddC), (X & AndC) --> and (X + AddC), ~AndC + const APInt *AddC, *AndC; + if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) && + match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) { + unsigned BitWidth = Ty->getScalarSizeInBits(); + unsigned Cttz = AddC->countTrailingZeros(); + APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz)); + if ((HighMask & *AndC).isNullValue()) + return BinaryOperator::CreateAnd(Op0, ConstantInt::get(Ty, ~(*AndC))); + } + if (Instruction *V = canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) return V; @@ -2094,11 +2133,11 @@ static Instruction *hoistFNegAboveFMulFDiv(Instruction &I, return nullptr; } -Instruction *InstCombiner::visitFNeg(UnaryOperator &I) { +Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { Value *Op = I.getOperand(0); if (Value *V = SimplifyFNegInst(Op, I.getFastMathFlags(), - SQ.getWithInstruction(&I))) + getSimplifyQuery().getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (Instruction *X = foldFNegIntoConstant(I)) @@ -2117,10 +2156,10 @@ Instruction *InstCombiner::visitFNeg(UnaryOperator &I) { return nullptr; } -Instruction *InstCombiner::visitFSub(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), - SQ.getWithInstruction(&I))) + getSimplifyQuery().getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (Instruction *X = foldVectorBinop(I)) @@ -2175,7 +2214,7 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { // X - C --> X + (-C) // But don't transform constant expressions because there's an inverse fold // for X + (-Y) --> X - Y. 
- if (match(Op1, m_Constant(C)) && !isa<ConstantExpr>(Op1)) + if (match(Op1, m_ImmConstant(C))) return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); // X - (-Y) --> X + Y @@ -2244,9 +2283,8 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { } auto m_FaddRdx = [](Value *&Sum, Value *&Vec) { - return m_OneUse( - m_Intrinsic<Intrinsic::experimental_vector_reduce_v2_fadd>( - m_Value(Sum), m_Value(Vec))); + return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_fadd>(m_Value(Sum), + m_Value(Vec))); }; Value *A0, *A1, *V0, *V1; if (match(Op0, m_FaddRdx(A0, V0)) && match(Op1, m_FaddRdx(A1, V1)) && @@ -2254,9 +2292,8 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { // Difference of sums is sum of differences: // add_rdx(A0, V0) - add_rdx(A1, V1) --> add_rdx(A0, V0 - V1) - A1 Value *Sub = Builder.CreateFSubFMF(V0, V1, &I); - Value *Rdx = Builder.CreateIntrinsic( - Intrinsic::experimental_vector_reduce_v2_fadd, - {A0->getType(), Sub->getType()}, {A0, Sub}, &I); + Value *Rdx = Builder.CreateIntrinsic(Intrinsic::vector_reduce_fadd, + {Sub->getType()}, {A0, Sub}, &I); return BinaryOperator::CreateFSubFMF(Rdx, A1, &I); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index d3c718a919c0..68c4156af2c4 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -13,10 +13,12 @@ #include "InstCombineInternal.h" #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +#include "llvm/Transforms/Utils/Local.h" + using namespace llvm; using namespace PatternMatch; @@ -112,57 +114,12 @@ static Value *SimplifyBSwap(BinaryOperator &I, return Builder.CreateCall(F, BinOp); } -/// This handles expressions of the form ((val OP C1) & C2). Where -/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. -Instruction *InstCombiner::OptAndOp(BinaryOperator *Op, - ConstantInt *OpRHS, - ConstantInt *AndRHS, - BinaryOperator &TheAnd) { - Value *X = Op->getOperand(0); - - switch (Op->getOpcode()) { - default: break; - case Instruction::Add: - if (Op->hasOneUse()) { - // Adding a one to a single bit bit-field should be turned into an XOR - // of the bit. First thing to check is to see if this AND is with a - // single bit constant. - const APInt &AndRHSV = AndRHS->getValue(); - - // If there is only one bit set. - if (AndRHSV.isPowerOf2()) { - // Ok, at this point, we know that we are masking the result of the - // ADD down to exactly one bit. If the constant we are adding has - // no bits set below this bit, then we can eliminate the ADD. - const APInt& AddRHS = OpRHS->getValue(); - - // Check to see if any bits below the one bit set in AndRHSV are set. - if ((AddRHS & (AndRHSV - 1)).isNullValue()) { - // If not, the only thing that can effect the output of the AND is - // the bit specified by AndRHSV. If that bit is set, the effect of - // the XOR is to toggle the bit. If it is clear, then the ADD has - // no effect. - if ((AddRHS & AndRHSV).isNullValue()) { // Bit is not set, noop - return replaceOperand(TheAnd, 0, X); - } else { - // Pull the XOR out of the AND. 
- Value *NewAnd = Builder.CreateAnd(X, AndRHS); - NewAnd->takeName(Op); - return BinaryOperator::CreateXor(NewAnd, AndRHS); - } - } - } - } - break; - } - return nullptr; -} - /// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise /// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates /// whether to treat V, Lo, and Hi as signed or not. -Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, - bool isSigned, bool Inside) { +Value *InstCombinerImpl::insertRangeTest(Value *V, const APInt &Lo, + const APInt &Hi, bool isSigned, + bool Inside) { assert((isSigned ? Lo.slt(Hi) : Lo.ult(Hi)) && "Lo is not < Hi in range emission code!"); @@ -437,11 +394,10 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, /// (icmp(A & X) ==/!= Y), where the left-hand side is of type Mask_NotAllZeros /// and the right hand side is of type BMask_Mixed. For example, /// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8). -static Value * foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( - ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, - Value *A, Value *B, Value *C, Value *D, Value *E, - ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, - llvm::InstCombiner::BuilderTy &Builder) { +static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( + ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C, + Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, + InstCombiner::BuilderTy &Builder) { // We are given the canonical form: // (icmp ne (A & B), 0) & (icmp eq (A & D), E). // where D & E == E. @@ -452,17 +408,9 @@ static Value * foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( // // We currently handle the case of B, C, D, E are constant. // - ConstantInt *BCst = dyn_cast<ConstantInt>(B); - if (!BCst) - return nullptr; - ConstantInt *CCst = dyn_cast<ConstantInt>(C); - if (!CCst) - return nullptr; - ConstantInt *DCst = dyn_cast<ConstantInt>(D); - if (!DCst) - return nullptr; - ConstantInt *ECst = dyn_cast<ConstantInt>(E); - if (!ECst) + ConstantInt *BCst, *CCst, *DCst, *ECst; + if (!match(B, m_ConstantInt(BCst)) || !match(C, m_ConstantInt(CCst)) || + !match(D, m_ConstantInt(DCst)) || !match(E, m_ConstantInt(ECst))) return nullptr; ICmpInst::Predicate NewCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; @@ -568,11 +516,9 @@ static Value * foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( /// (icmp(A & X) ==/!= Y), where the left-hand side and the right hand side /// aren't of the common mask pattern type. static Value *foldLogOpOfMaskedICmpsAsymmetric( - ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, - Value *A, Value *B, Value *C, Value *D, Value *E, - ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, - unsigned LHSMask, unsigned RHSMask, - llvm::InstCombiner::BuilderTy &Builder) { + ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C, + Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, + unsigned LHSMask, unsigned RHSMask, InstCombiner::BuilderTy &Builder) { assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && "Expected equality predicates for masked type of icmps."); // Handle Mask_NotAllZeros-BMask_Mixed cases. @@ -603,7 +549,7 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric( /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) /// into a single (icmp(A & X) ==/!= Y). 
static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, - llvm::InstCombiner::BuilderTy &Builder) { + InstCombiner::BuilderTy &Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); Optional<std::pair<unsigned, unsigned>> MaskPair = @@ -673,11 +619,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // Remaining cases assume at least that B and D are constant, and depend on // their actual values. This isn't strictly necessary, just a "handle the // easy cases for now" decision. - ConstantInt *BCst = dyn_cast<ConstantInt>(B); - if (!BCst) - return nullptr; - ConstantInt *DCst = dyn_cast<ConstantInt>(D); - if (!DCst) + ConstantInt *BCst, *DCst; + if (!match(B, m_ConstantInt(BCst)) || !match(D, m_ConstantInt(DCst))) return nullptr; if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) { @@ -718,11 +661,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // We can't simply use C and E because we might actually handle // (icmp ne (A & B), B) & (icmp eq (A & D), D) // with B and D, having a single bit set. - ConstantInt *CCst = dyn_cast<ConstantInt>(C); - if (!CCst) - return nullptr; - ConstantInt *ECst = dyn_cast<ConstantInt>(E); - if (!ECst) + ConstantInt *CCst, *ECst; + if (!match(C, m_ConstantInt(CCst)) || !match(E, m_ConstantInt(ECst))) return nullptr; if (PredL != NewCC) CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst)); @@ -748,8 +688,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, /// Example: (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n /// If \p Inverted is true then the check is for the inverted range, e.g. /// (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n -Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, - bool Inverted) { +Value *InstCombinerImpl::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool Inverted) { // Check the lower range comparison, e.g. x >= 0 // InstCombine already ensured that if there is a constant it's on the RHS. 
ConstantInt *RangeStart = dyn_cast<ConstantInt>(Cmp0->getOperand(1)); @@ -856,8 +796,9 @@ foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS, // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) -Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, - BinaryOperator &Logic) { +Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, + ICmpInst *RHS, + BinaryOperator &Logic) { bool JoinedByAnd = Logic.getOpcode() == Instruction::And; assert((JoinedByAnd || Logic.getOpcode() == Instruction::Or) && "Wrong opcode"); @@ -869,10 +810,8 @@ Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ) return nullptr; - // TODO support vector splats - ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); - ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); - if (!LHSC || !RHSC || !LHSC->isZero() || !RHSC->isZero()) + if (!match(LHS->getOperand(1), m_Zero()) || + !match(RHS->getOperand(1), m_Zero())) return nullptr; Value *A, *B, *C, *D; @@ -1148,11 +1087,12 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, assert((IsAnd || Logic.getOpcode() == Instruction::Or) && "Wrong logic op"); // Match an equality compare with a non-poison constant as Cmp0. + // Also, give up if the compare can be constant-folded to avoid looping. ICmpInst::Predicate Pred0; Value *X; Constant *C; if (!match(Cmp0, m_ICmp(Pred0, m_Value(X), m_Constant(C))) || - !isGuaranteedNotToBeUndefOrPoison(C)) + !isGuaranteedNotToBeUndefOrPoison(C) || isa<Constant>(X)) return nullptr; if ((IsAnd && Pred0 != ICmpInst::ICMP_EQ) || (!IsAnd && Pred0 != ICmpInst::ICMP_NE)) @@ -1183,8 +1123,8 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, } /// Fold (icmp)&(icmp) if possible. -Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, - BinaryOperator &And) { +Value *InstCombinerImpl::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, + BinaryOperator &And) { const SimplifyQuery Q = SQ.getWithInstruction(&And); // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) @@ -1243,9 +1183,10 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). 
Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); - ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); - ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); - if (!LHSC || !RHSC) + + ConstantInt *LHSC, *RHSC; + if (!match(LHS->getOperand(1), m_ConstantInt(LHSC)) || + !match(RHS->getOperand(1), m_ConstantInt(RHSC))) return nullptr; if (LHSC == RHSC && PredL == PredR) { @@ -1403,7 +1344,8 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, return nullptr; } -Value *InstCombiner::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { +Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, + bool IsAnd) { Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); @@ -1513,8 +1455,8 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I, Value *A, *B; if (match(I.getOperand(0), m_OneUse(m_Not(m_Value(A)))) && match(I.getOperand(1), m_OneUse(m_Not(m_Value(B)))) && - !isFreeToInvert(A, A->hasOneUse()) && - !isFreeToInvert(B, B->hasOneUse())) { + !InstCombiner::isFreeToInvert(A, A->hasOneUse()) && + !InstCombiner::isFreeToInvert(B, B->hasOneUse())) { Value *AndOr = Builder.CreateBinOp(Opcode, A, B, I.getName() + ".demorgan"); return BinaryOperator::CreateNot(AndOr); } @@ -1522,7 +1464,7 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I, return nullptr; } -bool InstCombiner::shouldOptimizeCast(CastInst *CI) { +bool InstCombinerImpl::shouldOptimizeCast(CastInst *CI) { Value *CastSrc = CI->getOperand(0); // Noop casts and casts of constants should be eliminated trivially. @@ -1578,7 +1520,7 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, } /// Fold {and,or,xor} (cast X), Y. -Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { +Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) { auto LogicOpc = I.getOpcode(); assert(I.isBitwiseLogicOp() && "Unexpected opcode for bitwise logic folding"); @@ -1685,6 +1627,14 @@ static Instruction *foldOrToXor(BinaryOperator &I, match(Op1, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); + // Operand complexity canonicalization guarantees that the 'xor' is Op0. + // (A ^ B) | ~(A | B) --> ~(A & B) + // (A ^ B) | ~(B | A) --> ~(A & B) + if (Op0->hasOneUse() || Op1->hasOneUse()) + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) + return BinaryOperator::CreateNot(Builder.CreateAnd(A, B)); + // (A & ~B) | (~A & B) --> A ^ B // (A & ~B) | (B & ~A) --> A ^ B // (~B & A) | (~A & B) --> A ^ B @@ -1699,32 +1649,13 @@ static Instruction *foldOrToXor(BinaryOperator &I, /// Return true if a constant shift amount is always less than the specified /// bit-width. If not, the shift could create poison in the narrower type. static bool canNarrowShiftAmt(Constant *C, unsigned BitWidth) { - if (auto *ScalarC = dyn_cast<ConstantInt>(C)) - return ScalarC->getZExtValue() < BitWidth; - - if (C->getType()->isVectorTy()) { - // Check each element of a constant vector. 
- unsigned NumElts = cast<VectorType>(C->getType())->getNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = C->getAggregateElement(i); - if (!Elt) - return false; - if (isa<UndefValue>(Elt)) - continue; - auto *CI = dyn_cast<ConstantInt>(Elt); - if (!CI || CI->getZExtValue() >= BitWidth) - return false; - } - return true; - } - - // The constant is a constant expression or unknown. - return false; + APInt Threshold(C->getType()->getScalarSizeInBits(), BitWidth); + return match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); } /// Try to use narrower ops (sink zext ops) for an 'and' with binop operand and /// a common zext operand: and (binop (zext X), C), (zext X). -Instruction *InstCombiner::narrowMaskedBinOp(BinaryOperator &And) { +Instruction *InstCombinerImpl::narrowMaskedBinOp(BinaryOperator &And) { // This transform could also apply to {or, and, xor}, but there are better // folds for those cases, so we don't expect those patterns here. AShr is not // handled because it should always be transformed to LShr in this sequence. @@ -1766,7 +1697,9 @@ Instruction *InstCombiner::narrowMaskedBinOp(BinaryOperator &And) { // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. -Instruction *InstCombiner::visitAnd(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { + Type *Ty = I.getType(); + if (Value *V = SimplifyAndInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1794,21 +1727,22 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return replaceInstUsesWith(I, V); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + Value *X, *Y; + if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X)))) && + match(Op1, m_One())) { + // (1 << X) & 1 --> zext(X == 0) + // (1 >> X) & 1 --> zext(X == 0) + Value *IsZero = Builder.CreateICmpEQ(X, ConstantInt::get(Ty, 0)); + return new ZExtInst(IsZero, Ty); + } + const APInt *C; if (match(Op1, m_APInt(C))) { - Value *X, *Y; - if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X)))) && - C->isOneValue()) { - // (1 << X) & 1 --> zext(X == 0) - // (1 >> X) & 1 --> zext(X == 0) - Value *IsZero = Builder.CreateICmpEQ(X, ConstantInt::get(I.getType(), 0)); - return new ZExtInst(IsZero, I.getType()); - } - const APInt *XorC; if (match(Op0, m_OneUse(m_Xor(m_Value(X), m_APInt(XorC))))) { // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) - Constant *NewC = ConstantInt::get(I.getType(), *C & *XorC); + Constant *NewC = ConstantInt::get(Ty, *C & *XorC); Value *And = Builder.CreateAnd(X, Op1); And->takeName(Op0); return BinaryOperator::CreateXor(And, NewC); @@ -1823,11 +1757,9 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { // that aren't set in C2. Meaning we can replace (C1&C2) with C1 in // above, but this feels safer. APInt Together = *C & *OrC; - Value *And = Builder.CreateAnd(X, ConstantInt::get(I.getType(), - Together ^ *C)); + Value *And = Builder.CreateAnd(X, ConstantInt::get(Ty, Together ^ *C)); And->takeName(Op0); - return BinaryOperator::CreateOr(And, ConstantInt::get(I.getType(), - Together)); + return BinaryOperator::CreateOr(And, ConstantInt::get(Ty, Together)); } // If the mask is only needed on one incoming arm, push the 'and' op up. 
@@ -1848,27 +1780,49 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return BinaryOperator::Create(BinOp, NewLHS, Y); } } + + unsigned Width = Ty->getScalarSizeInBits(); const APInt *ShiftC; if (match(Op0, m_OneUse(m_SExt(m_AShr(m_Value(X), m_APInt(ShiftC)))))) { - unsigned Width = I.getType()->getScalarSizeInBits(); if (*C == APInt::getLowBitsSet(Width, Width - ShiftC->getZExtValue())) { // We are clearing high bits that were potentially set by sext+ashr: // and (sext (ashr X, ShiftC)), C --> lshr (sext X), ShiftC - Value *Sext = Builder.CreateSExt(X, I.getType()); - Constant *ShAmtC = ConstantInt::get(I.getType(), ShiftC->zext(Width)); + Value *Sext = Builder.CreateSExt(X, Ty); + Constant *ShAmtC = ConstantInt::get(Ty, ShiftC->zext(Width)); return BinaryOperator::CreateLShr(Sext, ShAmtC); } } + + const APInt *AddC; + if (match(Op0, m_Add(m_Value(X), m_APInt(AddC)))) { + // If we add zeros to every bit below a mask, the add has no effect: + // (X + AddC) & LowMaskC --> X & LowMaskC + unsigned Ctlz = C->countLeadingZeros(); + APInt LowMask(APInt::getLowBitsSet(Width, Width - Ctlz)); + if ((*AddC & LowMask).isNullValue()) + return BinaryOperator::CreateAnd(X, Op1); + + // If we are masking the result of the add down to exactly one bit and + // the constant we are adding has no bits set below that bit, then the + // add is flipping a single bit. Example: + // (X + 4) & 4 --> (X & 4) ^ 4 + if (Op0->hasOneUse() && C->isPowerOf2() && (*AddC & (*C - 1)) == 0) { + assert((*C & *AddC) != 0 && "Expected common bit"); + Value *NewAnd = Builder.CreateAnd(X, Op1); + return BinaryOperator::CreateXor(NewAnd, Op1); + } + } } - if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) { + ConstantInt *AndRHS; + if (match(Op1, m_ConstantInt(AndRHS))) { const APInt &AndRHSMask = AndRHS->getValue(); // Optimize a variety of ((val OP C1) & C2) combinations... if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { // ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth // of X and OP behaves well when given trunc(C1) and X. - // TODO: Do this for vectors by using m_APInt isntead of m_ConstantInt. + // TODO: Do this for vectors by using m_APInt instead of m_ConstantInt. switch (Op0I->getOpcode()) { default: break; @@ -1893,31 +1847,30 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { BinOp = Builder.CreateBinOp(Op0I->getOpcode(), TruncC1, X); auto *TruncC2 = ConstantExpr::getTrunc(AndRHS, X->getType()); auto *And = Builder.CreateAnd(BinOp, TruncC2); - return new ZExtInst(And, I.getType()); + return new ZExtInst(And, Ty); } } } - - if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) - if (Instruction *Res = OptAndOp(Op0I, Op0CI, AndRHS, I)) - return Res; } + } - // If this is an integer truncation, and if the source is an 'and' with - // immediate, transform it. This frequently occurs for bitfield accesses. - { - Value *X = nullptr; ConstantInt *YC = nullptr; - if (match(Op0, m_Trunc(m_And(m_Value(X), m_ConstantInt(YC))))) { - // Change: and (trunc (and X, YC) to T), C2 - // into : and (trunc X to T), trunc(YC) & C2 - // This will fold the two constants together, which may allow - // other simplifications. 
- Value *NewCast = Builder.CreateTrunc(X, I.getType(), "and.shrunk"); - Constant *C3 = ConstantExpr::getTrunc(YC, I.getType()); - C3 = ConstantExpr::getAnd(C3, AndRHS); - return BinaryOperator::CreateAnd(NewCast, C3); - } - } + if (match(&I, m_And(m_OneUse(m_Shl(m_ZExt(m_Value(X)), m_Value(Y))), + m_SignMask())) && + match(Y, m_SpecificInt_ICMP( + ICmpInst::Predicate::ICMP_EQ, + APInt(Ty->getScalarSizeInBits(), + Ty->getScalarSizeInBits() - + X->getType()->getScalarSizeInBits())))) { + auto *SExt = Builder.CreateSExt(X, Ty, X->getName() + ".signext"); + auto *SanitizedSignMask = cast<Constant>(Op1); + // We must be careful with the undef elements of the sign bit mask, however: + // the mask elt can be undef iff the shift amount for that lane was undef, + // otherwise we need to sanitize undef masks to zero. + SanitizedSignMask = Constant::replaceUndefsWith( + SanitizedSignMask, ConstantInt::getNullValue(Ty->getScalarType())); + SanitizedSignMask = + Constant::mergeUndefsWith(SanitizedSignMask, cast<Constant>(Y)); + return BinaryOperator::CreateAnd(SExt, SanitizedSignMask); } if (Instruction *Z = narrowMaskedBinOp(I)) @@ -1938,6 +1891,13 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (match(Op0, m_OneUse(m_c_Xor(m_Specific(Op1), m_Value(B))))) return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(B)); + // A & ~(A ^ B) --> A & B + if (match(Op1, m_Not(m_c_Xor(m_Specific(Op0), m_Value(B))))) + return BinaryOperator::CreateAnd(Op0, B); + // ~(A ^ B) & A --> A & B + if (match(Op0, m_Not(m_c_Xor(m_Specific(Op1), m_Value(B))))) + return BinaryOperator::CreateAnd(Op1, B); + // (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) @@ -1976,7 +1936,6 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { // TODO: Make this recursive; it's a little tricky because an arbitrary // number of 'and' instructions might have to be created. - Value *X, *Y; if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { if (auto *Cmp = dyn_cast<ICmpInst>(X)) if (Value *Res = foldAndOfICmps(LHS, Cmp, I)) @@ -2010,29 +1969,30 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { Value *A; if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) && A->getType()->isIntOrIntVectorTy(1)) - return SelectInst::Create(A, Op1, Constant::getNullValue(I.getType())); + return SelectInst::Create(A, Op1, Constant::getNullValue(Ty)); if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) && A->getType()->isIntOrIntVectorTy(1)) - return SelectInst::Create(A, Op0, Constant::getNullValue(I.getType())); + return SelectInst::Create(A, Op0, Constant::getNullValue(Ty)); // and(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? X : 0. 
- { - Value *X, *Y; - const APInt *ShAmt; - Type *Ty = I.getType(); - if (match(&I, m_c_And(m_OneUse(m_AShr(m_NSWSub(m_Value(Y), m_Value(X)), - m_APInt(ShAmt))), - m_Deferred(X))) && - *ShAmt == Ty->getScalarSizeInBits() - 1) { - Value *NewICmpInst = Builder.CreateICmpSGT(X, Y); - return SelectInst::Create(NewICmpInst, X, ConstantInt::getNullValue(Ty)); - } + if (match(&I, m_c_And(m_OneUse(m_AShr( + m_NSWSub(m_Value(Y), m_Value(X)), + m_SpecificInt(Ty->getScalarSizeInBits() - 1))), + m_Deferred(X)))) { + Value *NewICmpInst = Builder.CreateICmpSGT(X, Y); + return SelectInst::Create(NewICmpInst, X, ConstantInt::getNullValue(Ty)); } + // (~x) & y --> ~(x | (~y)) iff that gets rid of inversions + if (sinkNotIntoOtherHandOfAndOrOr(I)) + return &I; + return nullptr; } -Instruction *InstCombiner::matchBSwap(BinaryOperator &Or) { +Instruction *InstCombinerImpl::matchBSwapOrBitReverse(BinaryOperator &Or, + bool MatchBSwaps, + bool MatchBitReversals) { assert(Or.getOpcode() == Instruction::Or && "bswap requires an 'or'"); Value *Op0 = Or.getOperand(0), *Op1 = Or.getOperand(1); @@ -2044,33 +2004,32 @@ Instruction *InstCombiner::matchBSwap(BinaryOperator &Or) { Op1 = Ext->getOperand(0); // (A | B) | C and A | (B | C) -> bswap if possible. - bool OrOfOrs = match(Op0, m_Or(m_Value(), m_Value())) || - match(Op1, m_Or(m_Value(), m_Value())); + bool OrWithOrs = match(Op0, m_Or(m_Value(), m_Value())) || + match(Op1, m_Or(m_Value(), m_Value())); - // (A >> B) | (C << D) and (A << B) | (B >> C) -> bswap if possible. - bool OrOfShifts = match(Op0, m_LogicalShift(m_Value(), m_Value())) && - match(Op1, m_LogicalShift(m_Value(), m_Value())); + // (A >> B) | C and (A << B) | C -> bswap if possible. + bool OrWithShifts = match(Op0, m_LogicalShift(m_Value(), m_Value())) || + match(Op1, m_LogicalShift(m_Value(), m_Value())); - // (A & B) | (C & D) -> bswap if possible. - bool OrOfAnds = match(Op0, m_And(m_Value(), m_Value())) && - match(Op1, m_And(m_Value(), m_Value())); + // (A & B) | C and A | (B & C) -> bswap if possible. + bool OrWithAnds = match(Op0, m_And(m_Value(), m_Value())) || + match(Op1, m_And(m_Value(), m_Value())); - // (A << B) | (C & D) -> bswap if possible. - // The bigger pattern here is ((A & C1) << C2) | ((B >> C2) & C1), which is a - // part of the bswap idiom for specific values of C1, C2 (e.g. C1 = 16711935, - // C2 = 8 for i32). - // This pattern can occur when the operands of the 'or' are not canonicalized - // for some reason (not having only one use, for example). - bool OrOfAndAndSh = (match(Op0, m_LogicalShift(m_Value(), m_Value())) && - match(Op1, m_And(m_Value(), m_Value()))) || - (match(Op0, m_And(m_Value(), m_Value())) && - match(Op1, m_LogicalShift(m_Value(), m_Value()))); + // fshl(A,B,C) | D and A | fshl(B,C,D) -> bswap if possible. + // fshr(A,B,C) | D and A | fshr(B,C,D) -> bswap if possible. + bool OrWithFunnels = match(Op0, m_FShl(m_Value(), m_Value(), m_Value())) || + match(Op0, m_FShr(m_Value(), m_Value(), m_Value())) || + match(Op0, m_FShl(m_Value(), m_Value(), m_Value())) || + match(Op0, m_FShr(m_Value(), m_Value(), m_Value())); - if (!OrOfOrs && !OrOfShifts && !OrOfAnds && !OrOfAndAndSh) + // TODO: Do we need all these filtering checks or should we just rely on + // recognizeBSwapOrBitReverseIdiom + collectBitParts to reject them quickly? 
+ if (!OrWithOrs && !OrWithShifts && !OrWithAnds && !OrWithFunnels) return nullptr; - SmallVector<Instruction*, 4> Insts; - if (!recognizeBSwapOrBitReverseIdiom(&Or, true, false, Insts)) + SmallVector<Instruction *, 4> Insts; + if (!recognizeBSwapOrBitReverseIdiom(&Or, MatchBSwaps, MatchBitReversals, + Insts)) return nullptr; Instruction *LastInst = Insts.pop_back_val(); LastInst->removeFromParent(); @@ -2080,34 +2039,72 @@ Instruction *InstCombiner::matchBSwap(BinaryOperator &Or) { return LastInst; } -/// Transform UB-safe variants of bitwise rotate to the funnel shift intrinsic. -static Instruction *matchRotate(Instruction &Or) { +/// Match UB-safe variants of the funnel shift intrinsic. +static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) { // TODO: Can we reduce the code duplication between this and the related // rotate matching code under visitSelect and visitTrunc? unsigned Width = Or.getType()->getScalarSizeInBits(); - if (!isPowerOf2_32(Width)) - return nullptr; - // First, find an or'd pair of opposite shifts with the same shifted operand: - // or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1) + // First, find an or'd pair of opposite shifts: + // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1) BinaryOperator *Or0, *Or1; if (!match(Or.getOperand(0), m_BinOp(Or0)) || !match(Or.getOperand(1), m_BinOp(Or1))) return nullptr; - Value *ShVal, *ShAmt0, *ShAmt1; - if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) || - !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1))))) - return nullptr; + Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) || + Or0->getOpcode() == Or1->getOpcode()) + return nullptr; + + // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)). + if (Or0->getOpcode() == BinaryOperator::LShr) { + std::swap(Or0, Or1); + std::swap(ShVal0, ShVal1); + std::swap(ShAmt0, ShAmt1); + } + assert(Or0->getOpcode() == BinaryOperator::Shl && + Or1->getOpcode() == BinaryOperator::LShr && + "Illegal or(shift,shift) pair"); + + // Match the shift amount operands for a funnel shift pattern. This always + // matches a subtraction on the R operand. + auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * { + // Check for constant shift amounts that sum to the bitwidth. + const APInt *LI, *RI; + if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI))) + if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width) + return ConstantInt::get(L->getType(), *LI); + + Constant *LC, *RC; + if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) && + match(L, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) && + match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) && + match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width))) + return ConstantExpr::mergeUndefsWith(LC, RC); + + // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width. + // We limit this to X < Width in case the backend re-expands the intrinsic, + // and has to reintroduce a shift modulo operation (InstCombine might remove + // it after this fold). This still doesn't guarantee that the final codegen + // will match this original pattern. + if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) { + KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or); + return KnownL.getMaxValue().ult(Width) ? 
L : nullptr; + } - BinaryOperator::BinaryOps ShiftOpcode0 = Or0->getOpcode(); - BinaryOperator::BinaryOps ShiftOpcode1 = Or1->getOpcode(); - if (ShiftOpcode0 == ShiftOpcode1) - return nullptr; + // For non-constant cases, the following patterns currently only work for + // rotation patterns. + // TODO: Add general funnel-shift compatible patterns. + if (ShVal0 != ShVal1) + return nullptr; + + // For non-constant cases we don't support non-pow2 shift masks. + // TODO: Is it worth matching urem as well? + if (!isPowerOf2_32(Width)) + return nullptr; - // Match the shift amount operands for a rotate pattern. This always matches - // a subtraction on the R operand. - auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * { // The shift amount may be masked with negation: // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1))) Value *X; @@ -2123,23 +2120,25 @@ static Instruction *matchRotate(Instruction &Or) { m_SpecificInt(Mask)))) return L; + if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) && + match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))) + return L; + return nullptr; }; Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width); - bool SubIsOnLHS = false; + bool IsFshl = true; // Sub on LSHR. if (!ShAmt) { ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width); - SubIsOnLHS = true; + IsFshl = false; // Sub on SHL. } if (!ShAmt) return nullptr; - bool IsFshl = (!SubIsOnLHS && ShiftOpcode0 == BinaryOperator::Shl) || - (SubIsOnLHS && ShiftOpcode1 == BinaryOperator::Shl); Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType()); - return IntrinsicInst::Create(F, { ShVal, ShVal, ShAmt }); + return IntrinsicInst::Create(F, {ShVal0, ShVal1, ShAmt}); } /// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns. @@ -2197,7 +2196,7 @@ static Instruction *matchOrConcat(Instruction &Or, /// If all elements of two constant vectors are 0/-1 and inverses, return true. static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { - unsigned NumElts = cast<VectorType>(C1->getType())->getNumElements(); + unsigned NumElts = cast<FixedVectorType>(C1->getType())->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *EltC1 = C1->getAggregateElement(i); Constant *EltC2 = C2->getAggregateElement(i); @@ -2215,7 +2214,7 @@ static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { /// We have an expression of the form (A & C) | (B & D). If A is a scalar or /// vector composed of all-zeros or all-ones values and is the bitwise 'not' of /// B, it can be used as the condition operand of a select instruction. -Value *InstCombiner::getSelectCondition(Value *A, Value *B) { +Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { // Step 1: We may have peeked through bitcasts in the caller. // Exit immediately if we don't have (vector) integer types. Type *Ty = A->getType(); @@ -2272,8 +2271,8 @@ Value *InstCombiner::getSelectCondition(Value *A, Value *B) { /// We have an expression of the form (A & C) | (B & D). Try to simplify this /// to "A' ? C : D", where A' is a boolean or vector of booleans. -Value *InstCombiner::matchSelectFromAndOr(Value *A, Value *C, Value *B, - Value *D) { +Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, + Value *D) { // The potential condition of the select may be bitcasted. In that case, look // through its bitcast and the corresponding bitcast of the 'not' condition. 
Type *OrigType = A->getType(); @@ -2293,8 +2292,8 @@ Value *InstCombiner::matchSelectFromAndOr(Value *A, Value *C, Value *B, } /// Fold (icmp)|(icmp) if possible. -Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, - BinaryOperator &Or) { +Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, + BinaryOperator &Or) { const SimplifyQuery Q = SQ.getWithInstruction(&Or); // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) @@ -2303,9 +2302,10 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, return V; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); - - ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); - ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); + Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); + Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1); + auto *LHSC = dyn_cast<ConstantInt>(LHS1); + auto *RHSC = dyn_cast<ConstantInt>(RHS1); // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3) // --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3) @@ -2317,24 +2317,20 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // 3) C1 ^ C2 is one-bit mask. // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask. // This implies all values in the two ranges differ by exactly one bit. - if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) && PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() && LHSC->getType() == RHSC->getType() && LHSC->getValue() == (RHSC->getValue())) { - Value *LAdd = LHS->getOperand(0); - Value *RAdd = RHS->getOperand(0); - - Value *LAddOpnd, *RAddOpnd; + Value *AddOpnd; ConstantInt *LAddC, *RAddC; - if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddC))) && - match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddC))) && + if (match(LHS0, m_Add(m_Value(AddOpnd), m_ConstantInt(LAddC))) && + match(RHS0, m_Add(m_Specific(AddOpnd), m_ConstantInt(RAddC))) && LAddC->getValue().ugt(LHSC->getValue()) && RAddC->getValue().ugt(LHSC->getValue())) { APInt DiffC = LAddC->getValue() ^ RAddC->getValue(); - if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2()) { + if (DiffC.isPowerOf2()) { ConstantInt *MaxAddC = nullptr; if (LAddC->getValue().ult(RAddC->getValue())) MaxAddC = RAddC; @@ -2354,7 +2350,7 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, RangeDiff.ugt(LHSC->getValue())) { Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC); - Value *NewAnd = Builder.CreateAnd(LAddOpnd, MaskC); + Value *NewAnd = Builder.CreateAnd(AddOpnd, MaskC); Value *NewAdd = Builder.CreateAdd(NewAnd, MaxAddC); return Builder.CreateICmp(LHS->getPredicate(), NewAdd, LHSC); } @@ -2364,15 +2360,12 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) if (predicatesFoldable(PredL, PredR)) { - if (LHS->getOperand(0) == RHS->getOperand(1) && - LHS->getOperand(1) == RHS->getOperand(0)) + if (LHS0 == RHS1 && LHS1 == RHS0) LHS->swapOperands(); - if (LHS->getOperand(0) == RHS->getOperand(0) && - LHS->getOperand(1) == RHS->getOperand(1)) { - Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); + if (LHS0 == RHS0 && LHS1 == RHS1) { unsigned Code = getICmpCode(LHS) | getICmpCode(RHS); bool IsSigned = LHS->isSigned() || RHS->isSigned(); - return getNewICmpValue(Code, IsSigned, Op0, Op1, Builder); + return getNewICmpValue(Code, IsSigned, LHS0, LHS1, Builder); } } @@ -2381,31 +2374,30 @@ Value 
*InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, false, Builder)) return V; - Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); if (LHS->hasOneUse() || RHS->hasOneUse()) { // (icmp eq B, 0) | (icmp ult A, B) -> (icmp ule A, B-1) // (icmp eq B, 0) | (icmp ugt B, A) -> (icmp ule A, B-1) Value *A = nullptr, *B = nullptr; - if (PredL == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero()) { + if (PredL == ICmpInst::ICMP_EQ && match(LHS1, m_Zero())) { B = LHS0; - if (PredR == ICmpInst::ICMP_ULT && LHS0 == RHS->getOperand(1)) + if (PredR == ICmpInst::ICMP_ULT && LHS0 == RHS1) A = RHS0; else if (PredR == ICmpInst::ICMP_UGT && LHS0 == RHS0) - A = RHS->getOperand(1); + A = RHS1; } // (icmp ult A, B) | (icmp eq B, 0) -> (icmp ule A, B-1) // (icmp ugt B, A) | (icmp eq B, 0) -> (icmp ule A, B-1) - else if (PredR == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) { + else if (PredR == ICmpInst::ICMP_EQ && match(RHS1, m_Zero())) { B = RHS0; - if (PredL == ICmpInst::ICMP_ULT && RHS0 == LHS->getOperand(1)) + if (PredL == ICmpInst::ICMP_ULT && RHS0 == LHS1) A = LHS0; - else if (PredL == ICmpInst::ICMP_UGT && LHS0 == RHS0) - A = LHS->getOperand(1); + else if (PredL == ICmpInst::ICMP_UGT && RHS0 == LHS0) + A = LHS1; } - if (A && B) + if (A && B && B->getType()->isIntOrIntVectorTy()) return Builder.CreateICmp( ICmpInst::ICMP_UGE, - Builder.CreateAdd(B, ConstantInt::getSigned(B->getType(), -1)), A); + Builder.CreateAdd(B, Constant::getAllOnesValue(B->getType())), A); } if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, Or, Builder, Q)) @@ -2434,18 +2426,21 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/false, Q, Builder)) return X; + // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) + // TODO: Remove this when foldLogOpOfMaskedICmps can handle vectors. + if (PredL == ICmpInst::ICMP_NE && match(LHS1, m_Zero()) && + PredR == ICmpInst::ICMP_NE && match(RHS1, m_Zero()) && + LHS0->getType()->isIntOrIntVectorTy() && + LHS0->getType() == RHS0->getType()) { + Value *NewOr = Builder.CreateOr(LHS0, RHS0); + return Builder.CreateICmp(PredL, NewOr, + Constant::getNullValue(NewOr->getType())); + } + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). if (!LHSC || !RHSC) return nullptr; - if (LHSC == RHSC && PredL == PredR) { - // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) - if (PredL == ICmpInst::ICMP_NE && LHSC->isZero()) { - Value *NewOr = Builder.CreateOr(LHS0, RHS0); - return Builder.CreateICmp(PredL, NewOr, LHSC); - } - } - // (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1) // iff C2 + CA == C1. if (PredL == ICmpInst::ICMP_ULT && PredR == ICmpInst::ICMP_EQ) { @@ -2559,7 +2554,7 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. 
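A small self-contained check (not part of the patch) of two of the or-of-icmps identities used above; the unsigned-wrap behaviour of `B - 1` is what makes the second fold handle `B == 0` for free.

#include <cassert>
#include <cstdint>

// (icmp ne A, 0) | (icmp ne B, 0)  -->  (icmp ne (A | B), 0)
static bool foldNeZero(uint32_t A, uint32_t B) { return (A | B) != 0; }

// (icmp eq B, 0) | (icmp ult A, B)  -->  (icmp uge (B - 1), A)
// B - 1 wraps to UINT32_MAX when B == 0, making the comparison always true.
static bool foldEqZeroOrUlt(uint32_t A, uint32_t B) { return (B - 1u) >= A; }

int main() {
  const uint32_t Vals[] = {0u, 1u, 2u, 7u, 0x80000000u, 0xffffffffu};
  for (uint32_t A : Vals)
    for (uint32_t B : Vals) {
      assert(((A != 0) || (B != 0)) == foldNeZero(A, B));
      assert(((B == 0) || (A < B)) == foldEqZeroOrUlt(A, B));
    }
  return 0;
}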
-Instruction *InstCombiner::visitOr(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (Value *V = SimplifyOrInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -2589,11 +2584,12 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) return FoldedLogic; - if (Instruction *BSwap = matchBSwap(I)) + if (Instruction *BSwap = matchBSwapOrBitReverse(I, /*MatchBSwaps*/ true, + /*MatchBitReversals*/ false)) return BSwap; - if (Instruction *Rotate = matchRotate(I)) - return Rotate; + if (Instruction *Funnel = matchFunnelShift(I, *this)) + return Funnel; if (Instruction *Concat = matchOrConcat(I, Builder)) return replaceInstUsesWith(I, Concat); @@ -2613,9 +2609,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { Value *A, *B, *C, *D; if (match(Op0, m_And(m_Value(A), m_Value(C))) && match(Op1, m_And(m_Value(B), m_Value(D)))) { - ConstantInt *C1 = dyn_cast<ConstantInt>(C); - ConstantInt *C2 = dyn_cast<ConstantInt>(D); - if (C1 && C2) { // (A & C1)|(B & C2) + // (A & C1)|(B & C2) + ConstantInt *C1, *C2; + if (match(C, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2))) { Value *V1 = nullptr, *V2 = nullptr; if ((C1->getValue() & C2->getValue()).isNullValue()) { // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2) @@ -2806,7 +2802,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // ORs in the hopes that we'll be able to simplify it this way. // (X|C) | V --> (X|V) | C ConstantInt *CI; - if (Op0->hasOneUse() && !isa<ConstantInt>(Op1) && + if (Op0->hasOneUse() && !match(Op1, m_ConstantInt()) && match(Op0, m_Or(m_Value(A), m_ConstantInt(CI)))) { Value *Inner = Builder.CreateOr(A, Op1); Inner->takeName(Op0); @@ -2827,18 +2823,17 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } - // or(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? -1 : X. + // or(ashr(subNSW(Y, X), ScalarSizeInBits(Y) - 1), X) --> X s> Y ? -1 : X. { Value *X, *Y; - const APInt *ShAmt; Type *Ty = I.getType(); - if (match(&I, m_c_Or(m_OneUse(m_AShr(m_NSWSub(m_Value(Y), m_Value(X)), - m_APInt(ShAmt))), - m_Deferred(X))) && - *ShAmt == Ty->getScalarSizeInBits() - 1) { + if (match(&I, m_c_Or(m_OneUse(m_AShr( + m_NSWSub(m_Value(Y), m_Value(X)), + m_SpecificInt(Ty->getScalarSizeInBits() - 1))), + m_Deferred(X)))) { Value *NewICmpInst = Builder.CreateICmpSGT(X, Y); - return SelectInst::Create(NewICmpInst, ConstantInt::getAllOnesValue(Ty), - X); + Value *AllOnes = ConstantInt::getAllOnesValue(Ty); + return SelectInst::Create(NewICmpInst, AllOnes, X); } } @@ -2872,6 +2867,10 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } + // (~x) | y --> ~(x & (~y)) iff that gets rid of inversions + if (sinkNotIntoOtherHandOfAndOrOr(I)) + return &I; + return nullptr; } @@ -2928,8 +2927,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, return nullptr; } -Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, - BinaryOperator &I) { +Value *InstCombinerImpl::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, + BinaryOperator &I) { assert(I.getOpcode() == Instruction::Xor && I.getOperand(0) == LHS && I.getOperand(1) == RHS && "Should be 'xor' with these operands"); @@ -3087,9 +3086,9 @@ static Instruction *sinkNotIntoXor(BinaryOperator &I, return nullptr; // We only want to do the transform if it is free to do. - if (isFreeToInvert(X, X->hasOneUse())) { + if (InstCombiner::isFreeToInvert(X, X->hasOneUse())) { // Ok, good. 
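A sketch (assumptions: 32-bit values, no overflow in the subtraction, matching the nsw requirement) of why `or(ashr(sub nsw Y, X, bw-1), X)` above becomes `X s> Y ? -1 : X`: the arithmetic shift splats the sign of Y - X, and or-ing an all-ones mask into X yields -1.

#include <cassert>
#include <cstdint>

static int32_t foldedForm(int32_t X, int32_t Y) { return X > Y ? -1 : X; }

static int32_t originalForm(int32_t X, int32_t Y) {
  int32_t Diff = Y - X;      // assumed not to overflow (the nsw precondition)
  int32_t Mask = Diff >> 31; // arithmetic shift: 0 when X <= Y, -1 when X > Y
  return Mask | X;
}

int main() {
  const int32_t Vals[] = {-100, -1, 0, 1, 42, 1000};
  for (int32_t X : Vals)
    for (int32_t Y : Vals)
      assert(originalForm(X, Y) == foldedForm(X, Y));
  return 0;
}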
- } else if (isFreeToInvert(Y, Y->hasOneUse())) { + } else if (InstCombiner::isFreeToInvert(Y, Y->hasOneUse())) { std::swap(X, Y); } else return nullptr; @@ -3098,10 +3097,52 @@ static Instruction *sinkNotIntoXor(BinaryOperator &I, return BinaryOperator::CreateXor(NotX, Y, I.getName() + ".demorgan"); } +// Transform +// z = (~x) &/| y +// into: +// z = ~(x |/& (~y)) +// iff y is free to invert and all uses of z can be freely updated. +bool InstCombinerImpl::sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I) { + Instruction::BinaryOps NewOpc; + switch (I.getOpcode()) { + case Instruction::And: + NewOpc = Instruction::Or; + break; + case Instruction::Or: + NewOpc = Instruction::And; + break; + default: + return false; + }; + + Value *X, *Y; + if (!match(&I, m_c_BinOp(m_Not(m_Value(X)), m_Value(Y)))) + return false; + + // Will we be able to fold the `not` into Y eventually? + if (!InstCombiner::isFreeToInvert(Y, Y->hasOneUse())) + return false; + + // And can our users be adapted? + if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) + return false; + + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + Value *NewBinOp = + BinaryOperator::Create(NewOpc, X, NotY, I.getName() + ".not"); + Builder.Insert(NewBinOp); + replaceInstUsesWith(I, NewBinOp); + // We can not just create an outer `not`, it will most likely be immediately + // folded back, reconstructing our initial pattern, and causing an + // infinite combine loop, so immediately manually fold it away. + freelyInvertAllUsersOf(NewBinOp); + return true; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. -Instruction *InstCombiner::visitXor(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { if (Value *V = SimplifyXorInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -3128,6 +3169,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { return replaceInstUsesWith(I, V); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); // Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M) // This it a special case in haveNoCommonBitsSet, but the computeKnownBits @@ -3199,7 +3241,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { match(C, m_Negative())) { // We matched a negative constant, so propagating undef is unsafe. // Clamp undef elements to -1. - Type *EltTy = C->getType()->getScalarType(); + Type *EltTy = Ty->getScalarType(); C = Constant::replaceUndefsWith(C, ConstantInt::getAllOnesValue(EltTy)); return BinaryOperator::CreateLShr(ConstantExpr::getNot(C), Y); } @@ -3209,7 +3251,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { match(C, m_NonNegative())) { // We matched a non-negative constant, so propagating undef is unsafe. // Clamp undef elements to 0. 
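A quick verification (not from the patch) of the De Morgan identities behind sinkNotIntoOtherHandOfAndOrOr above; the transform is only applied when the inner `~y` folds away and every user of the result can absorb the outer negation.

#include <cassert>
#include <cstdint>

// (~x) & y == ~(x | ~y)      and      (~x) | y == ~(x & ~y)
int main() {
  const uint32_t Vals[] = {0u, 1u, 0xdeadbeefu, 0x80000000u, 0xffffffffu};
  for (uint32_t X : Vals)
    for (uint32_t Y : Vals) {
      assert((~X & Y) == ~(X | ~Y));
      assert((~X | Y) == ~(X & ~Y));
    }
  return 0;
}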
- Type *EltTy = C->getType()->getScalarType(); + Type *EltTy = Ty->getScalarType(); C = Constant::replaceUndefsWith(C, ConstantInt::getNullValue(EltTy)); return BinaryOperator::CreateAShr(ConstantExpr::getNot(C), Y); } @@ -3217,6 +3259,11 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // ~(X + C) --> -(C + 1) - X if (match(Op0, m_Add(m_Value(X), m_Constant(C)))) return BinaryOperator::CreateSub(ConstantExpr::getNeg(AddOne(C)), X); + + // ~(~X + Y) --> X - Y + if (match(NotVal, m_c_Add(m_Not(m_Value(X)), m_Value(Y)))) + return BinaryOperator::CreateWithCopiedFlags(Instruction::Sub, X, Y, + NotVal); } // Use DeMorgan and reassociation to eliminate a 'not' op. @@ -3247,52 +3294,56 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (match(Op1, m_APInt(RHSC))) { Value *X; const APInt *C; - if (RHSC->isSignMask() && match(Op0, m_Sub(m_APInt(C), m_Value(X)))) { - // (C - X) ^ signmask -> (C + signmask - X) - Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC); - return BinaryOperator::CreateSub(NewC, X); - } - if (RHSC->isSignMask() && match(Op0, m_Add(m_Value(X), m_APInt(C)))) { - // (X + C) ^ signmask -> (X + C + signmask) - Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC); - return BinaryOperator::CreateAdd(X, NewC); - } + // (C - X) ^ signmaskC --> (C + signmaskC) - X + if (RHSC->isSignMask() && match(Op0, m_Sub(m_APInt(C), m_Value(X)))) + return BinaryOperator::CreateSub(ConstantInt::get(Ty, *C + *RHSC), X); + + // (X + C) ^ signmaskC --> X + (C + signmaskC) + if (RHSC->isSignMask() && match(Op0, m_Add(m_Value(X), m_APInt(C)))) + return BinaryOperator::CreateAdd(X, ConstantInt::get(Ty, *C + *RHSC)); - // (X|C1)^C2 -> X^(C1^C2) iff X&~C1 == 0 + // (X | C) ^ RHSC --> X ^ (C ^ RHSC) iff X & C == 0 if (match(Op0, m_Or(m_Value(X), m_APInt(C))) && - MaskedValueIsZero(X, *C, 0, &I)) { - Constant *NewC = ConstantInt::get(I.getType(), *C ^ *RHSC); - return BinaryOperator::CreateXor(X, NewC); + MaskedValueIsZero(X, *C, 0, &I)) + return BinaryOperator::CreateXor(X, ConstantInt::get(Ty, *C ^ *RHSC)); + + // If RHSC is inverting the remaining bits of shifted X, + // canonicalize to a 'not' before the shift to help SCEV and codegen: + // (X << C) ^ RHSC --> ~X << C + if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_APInt(C)))) && + *RHSC == APInt::getAllOnesValue(Ty->getScalarSizeInBits()).shl(*C)) { + Value *NotX = Builder.CreateNot(X); + return BinaryOperator::CreateShl(NotX, ConstantInt::get(Ty, *C)); } + // (X >>u C) ^ RHSC --> ~X >>u C + if (match(Op0, m_OneUse(m_LShr(m_Value(X), m_APInt(C)))) && + *RHSC == APInt::getAllOnesValue(Ty->getScalarSizeInBits()).lshr(*C)) { + Value *NotX = Builder.CreateNot(X); + return BinaryOperator::CreateLShr(NotX, ConstantInt::get(Ty, *C)); + } + // TODO: We could handle 'ashr' here as well. That would be matching + // a 'not' op and moving it before the shift. Doing that requires + // preventing the inverse fold in canShiftBinOpWithConstantRHS(). 
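A standalone sketch of the two's-complement identities used by the 'not'/xor folds above, checked with wrapping uint32_t arithmetic (the IR folds reason modulo 2^N the same way). Not part of the patch.

#include <cassert>
#include <cstdint>

//   ~(X + C)                   ==  -(C + 1) - X
//   ~(~X + Y)                  ==  X - Y
//   (X << C) ^ (AllOnes << C)  ==  (~X) << C
int main() {
  const uint32_t Vals[] = {0u, 1u, 5u, 0x7fffffffu, 0x80000000u, 0xffffffffu};
  for (uint32_t X : Vals)
    for (uint32_t Y : Vals) {
      assert(~(X + Y) == 0u - (Y + 1u) - X);
      assert(~(~X + Y) == X - Y);
      for (uint32_t C = 0; C < 32; ++C)
        assert(((X << C) ^ (0xffffffffu << C)) == (~X << C));
    }
  return 0;
}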
} } - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) { - if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { - if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) { - if (Op0I->getOpcode() == Instruction::LShr) { - // ((X^C1) >> C2) ^ C3 -> (X>>C2) ^ ((C1>>C2)^C3) - // E1 = "X ^ C1" - BinaryOperator *E1; - ConstantInt *C1; - if (Op0I->hasOneUse() && - (E1 = dyn_cast<BinaryOperator>(Op0I->getOperand(0))) && - E1->getOpcode() == Instruction::Xor && - (C1 = dyn_cast<ConstantInt>(E1->getOperand(1)))) { - // fold (C1 >> C2) ^ C3 - ConstantInt *C2 = Op0CI, *C3 = RHSC; - APInt FoldConst = C1->getValue().lshr(C2->getValue()); - FoldConst ^= C3->getValue(); - // Prepare the two operands. - Value *Opnd0 = Builder.CreateLShr(E1->getOperand(0), C2); - Opnd0->takeName(Op0I); - cast<Instruction>(Opnd0)->setDebugLoc(I.getDebugLoc()); - Value *FoldVal = ConstantInt::get(Opnd0->getType(), FoldConst); - - return BinaryOperator::CreateXor(Opnd0, FoldVal); - } - } - } + // FIXME: This should not be limited to scalar (pull into APInt match above). + { + Value *X; + ConstantInt *C1, *C2, *C3; + // ((X^C1) >> C2) ^ C3 -> (X>>C2) ^ ((C1>>C2)^C3) + if (match(Op1, m_ConstantInt(C3)) && + match(Op0, m_LShr(m_Xor(m_Value(X), m_ConstantInt(C1)), + m_ConstantInt(C2))) && + Op0->hasOneUse()) { + // fold (C1 >> C2) ^ C3 + APInt FoldConst = C1->getValue().lshr(C2->getValue()); + FoldConst ^= C3->getValue(); + // Prepare the two operands. + auto *Opnd0 = cast<Instruction>(Builder.CreateLShr(X, C2)); + Opnd0->takeName(cast<Instruction>(Op0)); + Opnd0->setDebugLoc(I.getDebugLoc()); + return BinaryOperator::CreateXor(Opnd0, ConstantInt::get(Ty, FoldConst)); } } @@ -3349,6 +3400,25 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { match(Op1, m_Not(m_Specific(A)))) return BinaryOperator::CreateNot(Builder.CreateAnd(A, B)); + // (~A & B) ^ A --> A | B -- There are 4 commuted variants. + if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(A)), m_Value(B)), m_Deferred(A)))) + return BinaryOperator::CreateOr(A, B); + + // (A | B) ^ (A | C) --> (B ^ C) & ~A -- There are 4 commuted variants. + // TODO: Loosen one-use restriction if common operand is a constant. + Value *D; + if (match(Op0, m_OneUse(m_Or(m_Value(A), m_Value(B)))) && + match(Op1, m_OneUse(m_Or(m_Value(C), m_Value(D))))) { + if (B == C || B == D) + std::swap(A, B); + if (A == C) + std::swap(C, D); + if (A == D) { + Value *NotA = Builder.CreateNot(A); + return BinaryOperator::CreateAnd(Builder.CreateXor(B, C), NotA); + } + } + if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) if (Value *V = foldXorOfICmps(LHS, RHS, I)) @@ -3366,7 +3436,6 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { std::swap(Op0, Op1); const APInt *ShAmt; - Type *Ty = I.getType(); if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 && match(Op0, m_OneUse(m_c_Add(m_Specific(A), m_Specific(Op1))))) { @@ -3425,19 +3494,30 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } - // Pull 'not' into operands of select if both operands are one-use compares. + // Pull 'not' into operands of select if both operands are one-use compares + // or one is one-use compare and the other one is a constant. // Inverting the predicates eliminates the 'not' operation. // Example: - // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?) --> + // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?) 
--> // select ?, (cmp InvTPred, ?, ?), (cmp InvFPred, ?, ?) - // TODO: Canonicalize by hoisting 'not' into an arm of the select if only - // 1 select operand is a cmp? + // not (select ?, (cmp TPred, ?, ?), true --> + // select ?, (cmp InvTPred, ?, ?), false if (auto *Sel = dyn_cast<SelectInst>(Op0)) { - auto *CmpT = dyn_cast<CmpInst>(Sel->getTrueValue()); - auto *CmpF = dyn_cast<CmpInst>(Sel->getFalseValue()); - if (CmpT && CmpF && CmpT->hasOneUse() && CmpF->hasOneUse()) { - CmpT->setPredicate(CmpT->getInversePredicate()); - CmpF->setPredicate(CmpF->getInversePredicate()); + Value *TV = Sel->getTrueValue(); + Value *FV = Sel->getFalseValue(); + auto *CmpT = dyn_cast<CmpInst>(TV); + auto *CmpF = dyn_cast<CmpInst>(FV); + bool InvertibleT = (CmpT && CmpT->hasOneUse()) || isa<Constant>(TV); + bool InvertibleF = (CmpF && CmpF->hasOneUse()) || isa<Constant>(FV); + if (InvertibleT && InvertibleF) { + if (CmpT) + CmpT->setPredicate(CmpT->getInversePredicate()); + else + Sel->setTrueValue(ConstantExpr::getNot(cast<Constant>(TV))); + if (CmpF) + CmpF->setPredicate(CmpF->getInversePredicate()); + else + Sel->setFalseValue(ConstantExpr::getNot(cast<Constant>(FV))); return replaceInstUsesWith(I, Sel); } } @@ -3446,5 +3526,15 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Instruction *NewXor = sinkNotIntoXor(I, Builder)) return NewXor; + // Otherwise, if all else failed, try to hoist the xor-by-constant: + // (X ^ C) ^ Y --> (X ^ Y) ^ C + // Just like we do in other places, we completely avoid the fold + // for constantexprs, at least to avoid endless combine loop. + if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_CombineAnd(m_Value(X), + m_Unless(m_ConstantExpr())), + m_ImmConstant(C1))), + m_Value(Y)))) + return BinaryOperator::CreateXor(Builder.CreateXor(X, Y), C1); + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp index ba1cf982229d..495493aab4b5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp @@ -9,8 +9,10 @@ // This file implements the visit functions for atomic rmw instructions. // //===----------------------------------------------------------------------===// + #include "InstCombineInternal.h" #include "llvm/IR/Instructions.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" using namespace llvm; @@ -30,7 +32,7 @@ bool isIdempotentRMW(AtomicRMWInst& RMWI) { default: return false; }; - + auto C = dyn_cast<ConstantInt>(RMWI.getValOperand()); if(!C) return false; @@ -91,13 +93,13 @@ bool isSaturating(AtomicRMWInst& RMWI) { return C->isMaxValue(false); }; } -} +} // namespace -Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) { +Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) { // Volatile RMWs perform a load and a store, we cannot replace this by just a // load or just a store. We chose not to canonicalize out of general paranoia - // about user expectations around volatile. + // about user expectations around volatile. if (RMWI.isVolatile()) return nullptr; @@ -115,7 +117,7 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) { "AtomicRMWs don't make sense with Unordered or NotAtomic"); // Any atomicrmw xchg with no uses can be converted to a atomic store if the - // ordering is compatible. + // ordering is compatible. 
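A brute-force check (illustrative only) of the bitwise identities behind the xor folds above, including the commuted-operand patterns and the xor-by-constant reassociation.

#include <cassert>
#include <cstdint>

//   (~A & B) ^ A       ==  A | B
//   (A | B) ^ (A | C)  ==  (B ^ C) & ~A
//   (X ^ C) ^ Y        ==  (X ^ Y) ^ C   (plain reassociation)
int main() {
  const uint32_t Vals[] = {0u, 1u, 0x0f0f0f0fu, 0x80000000u, 0xffffffffu};
  for (uint32_t A : Vals)
    for (uint32_t B : Vals)
      for (uint32_t C : Vals) {
        assert(((~A & B) ^ A) == (A | B));
        assert(((A | B) ^ (A | C)) == ((B ^ C) & ~A));
        assert(((A ^ C) ^ B) == ((A ^ B) ^ C));
      }
  return 0;
}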
if (RMWI.getOperation() == AtomicRMWInst::Xchg && RMWI.use_empty()) { if (Ordering != AtomicOrdering::Release && @@ -127,14 +129,14 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) { SI->setAlignment(DL.getABITypeAlign(RMWI.getType())); return eraseInstFromFunction(RMWI); } - + if (!isIdempotentRMW(RMWI)) return nullptr; // We chose to canonicalize all idempotent operations to an single // operation code and constant. This makes it easier for the rest of the // optimizer to match easily. The choices of or w/0 and fadd w/-0.0 are - // arbitrary. + // arbitrary. if (RMWI.getType()->isIntegerTy() && RMWI.getOperation() != AtomicRMWInst::Or) { RMWI.setOperation(AtomicRMWInst::Or); @@ -149,7 +151,7 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) { if (Ordering != AtomicOrdering::Acquire && Ordering != AtomicOrdering::Monotonic) return nullptr; - + LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand(), "", false, DL.getABITypeAlign(RMWI.getType()), Ordering, RMWI.getSyncScopeID()); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index c734c9a68fb2..5482b944e347 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -28,6 +28,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Attributes.h" @@ -47,9 +48,6 @@ #include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/IntrinsicsARM.h" #include "llvm/IR/IntrinsicsHexagon.h" -#include "llvm/IR/IntrinsicsNVPTX.h" -#include "llvm/IR/IntrinsicsPowerPC.h" -#include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PatternMatch.h" @@ -68,6 +66,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SimplifyLibCalls.h" #include <algorithm> @@ -100,24 +99,7 @@ static Type *getPromotedType(Type *Ty) { return Ty; } -/// Return a constant boolean vector that has true elements in all positions -/// where the input constant data vector has an element with the sign bit set. -static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { - SmallVector<Constant *, 32> BoolVec; - IntegerType *BoolTy = Type::getInt1Ty(V->getContext()); - for (unsigned I = 0, E = V->getNumElements(); I != E; ++I) { - Constant *Elt = V->getElementAsConstant(I); - assert((isa<ConstantInt>(Elt) || isa<ConstantFP>(Elt)) && - "Unexpected constant data vector element type"); - bool Sign = V->getElementType()->isIntegerTy() - ? 
cast<ConstantInt>(Elt)->isNegative() - : cast<ConstantFP>(Elt)->isNegative(); - BoolVec.push_back(ConstantInt::get(BoolTy, Sign)); - } - return ConstantVector::get(BoolVec); -} - -Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { +Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { Align DstAlign = getKnownAlignment(MI->getRawDest(), DL, MI, &AC, &DT); MaybeAlign CopyDstAlign = MI->getDestAlign(); if (!CopyDstAlign || *CopyDstAlign < DstAlign) { @@ -232,7 +214,7 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { return MI; } -Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) { +Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { const Align KnownAlignment = getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT); MaybeAlign MemSetAlign = MI->getDestAlign(); @@ -292,820 +274,9 @@ Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) { return nullptr; } -static Value *simplifyX86immShift(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - bool LogicalShift = false; - bool ShiftLeft = false; - bool IsImm = false; - - switch (II.getIntrinsicID()) { - default: llvm_unreachable("Unexpected intrinsic!"); - case Intrinsic::x86_sse2_psrai_d: - case Intrinsic::x86_sse2_psrai_w: - case Intrinsic::x86_avx2_psrai_d: - case Intrinsic::x86_avx2_psrai_w: - case Intrinsic::x86_avx512_psrai_q_128: - case Intrinsic::x86_avx512_psrai_q_256: - case Intrinsic::x86_avx512_psrai_d_512: - case Intrinsic::x86_avx512_psrai_q_512: - case Intrinsic::x86_avx512_psrai_w_512: - IsImm = true; - LLVM_FALLTHROUGH; - case Intrinsic::x86_sse2_psra_d: - case Intrinsic::x86_sse2_psra_w: - case Intrinsic::x86_avx2_psra_d: - case Intrinsic::x86_avx2_psra_w: - case Intrinsic::x86_avx512_psra_q_128: - case Intrinsic::x86_avx512_psra_q_256: - case Intrinsic::x86_avx512_psra_d_512: - case Intrinsic::x86_avx512_psra_q_512: - case Intrinsic::x86_avx512_psra_w_512: - LogicalShift = false; - ShiftLeft = false; - break; - case Intrinsic::x86_sse2_psrli_d: - case Intrinsic::x86_sse2_psrli_q: - case Intrinsic::x86_sse2_psrli_w: - case Intrinsic::x86_avx2_psrli_d: - case Intrinsic::x86_avx2_psrli_q: - case Intrinsic::x86_avx2_psrli_w: - case Intrinsic::x86_avx512_psrli_d_512: - case Intrinsic::x86_avx512_psrli_q_512: - case Intrinsic::x86_avx512_psrli_w_512: - IsImm = true; - LLVM_FALLTHROUGH; - case Intrinsic::x86_sse2_psrl_d: - case Intrinsic::x86_sse2_psrl_q: - case Intrinsic::x86_sse2_psrl_w: - case Intrinsic::x86_avx2_psrl_d: - case Intrinsic::x86_avx2_psrl_q: - case Intrinsic::x86_avx2_psrl_w: - case Intrinsic::x86_avx512_psrl_d_512: - case Intrinsic::x86_avx512_psrl_q_512: - case Intrinsic::x86_avx512_psrl_w_512: - LogicalShift = true; - ShiftLeft = false; - break; - case Intrinsic::x86_sse2_pslli_d: - case Intrinsic::x86_sse2_pslli_q: - case Intrinsic::x86_sse2_pslli_w: - case Intrinsic::x86_avx2_pslli_d: - case Intrinsic::x86_avx2_pslli_q: - case Intrinsic::x86_avx2_pslli_w: - case Intrinsic::x86_avx512_pslli_d_512: - case Intrinsic::x86_avx512_pslli_q_512: - case Intrinsic::x86_avx512_pslli_w_512: - IsImm = true; - LLVM_FALLTHROUGH; - case Intrinsic::x86_sse2_psll_d: - case Intrinsic::x86_sse2_psll_q: - case Intrinsic::x86_sse2_psll_w: - case Intrinsic::x86_avx2_psll_d: - case Intrinsic::x86_avx2_psll_q: - case Intrinsic::x86_avx2_psll_w: - case Intrinsic::x86_avx512_psll_d_512: - case Intrinsic::x86_avx512_psll_q_512: - case Intrinsic::x86_avx512_psll_w_512: - LogicalShift = true; - ShiftLeft = true; 
- break; - } - assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left"); - - auto Vec = II.getArgOperand(0); - auto Amt = II.getArgOperand(1); - auto VT = cast<VectorType>(Vec->getType()); - auto SVT = VT->getElementType(); - auto AmtVT = Amt->getType(); - unsigned VWidth = VT->getNumElements(); - unsigned BitWidth = SVT->getPrimitiveSizeInBits(); - - // If the shift amount is guaranteed to be in-range we can replace it with a - // generic shift. If its guaranteed to be out of range, logical shifts combine to - // zero and arithmetic shifts are clamped to (BitWidth - 1). - if (IsImm) { - assert(AmtVT ->isIntegerTy(32) && - "Unexpected shift-by-immediate type"); - KnownBits KnownAmtBits = - llvm::computeKnownBits(Amt, II.getModule()->getDataLayout()); - if (KnownAmtBits.getMaxValue().ult(BitWidth)) { - Amt = Builder.CreateZExtOrTrunc(Amt, SVT); - Amt = Builder.CreateVectorSplat(VWidth, Amt); - return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt) - : Builder.CreateLShr(Vec, Amt)) - : Builder.CreateAShr(Vec, Amt)); - } - if (KnownAmtBits.getMinValue().uge(BitWidth)) { - if (LogicalShift) - return ConstantAggregateZero::get(VT); - Amt = ConstantInt::get(SVT, BitWidth - 1); - return Builder.CreateAShr(Vec, Builder.CreateVectorSplat(VWidth, Amt)); - } - } else { - // Ensure the first element has an in-range value and the rest of the - // elements in the bottom 64 bits are zero. - assert(AmtVT->isVectorTy() && AmtVT->getPrimitiveSizeInBits() == 128 && - cast<VectorType>(AmtVT)->getElementType() == SVT && - "Unexpected shift-by-scalar type"); - unsigned NumAmtElts = cast<VectorType>(AmtVT)->getNumElements(); - APInt DemandedLower = APInt::getOneBitSet(NumAmtElts, 0); - APInt DemandedUpper = APInt::getBitsSet(NumAmtElts, 1, NumAmtElts / 2); - KnownBits KnownLowerBits = llvm::computeKnownBits( - Amt, DemandedLower, II.getModule()->getDataLayout()); - KnownBits KnownUpperBits = llvm::computeKnownBits( - Amt, DemandedUpper, II.getModule()->getDataLayout()); - if (KnownLowerBits.getMaxValue().ult(BitWidth) && - (DemandedUpper.isNullValue() || KnownUpperBits.isZero())) { - SmallVector<int, 16> ZeroSplat(VWidth, 0); - Amt = Builder.CreateShuffleVector(Amt, Amt, ZeroSplat); - return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt) - : Builder.CreateLShr(Vec, Amt)) - : Builder.CreateAShr(Vec, Amt)); - } - } - - // Simplify if count is constant vector. - auto CDV = dyn_cast<ConstantDataVector>(Amt); - if (!CDV) - return nullptr; - - // SSE2/AVX2 uses all the first 64-bits of the 128-bit vector - // operand to compute the shift amount. - assert(AmtVT->isVectorTy() && AmtVT->getPrimitiveSizeInBits() == 128 && - cast<VectorType>(AmtVT)->getElementType() == SVT && - "Unexpected shift-by-scalar type"); - - // Concatenate the sub-elements to create the 64-bit value. - APInt Count(64, 0); - for (unsigned i = 0, NumSubElts = 64 / BitWidth; i != NumSubElts; ++i) { - unsigned SubEltIdx = (NumSubElts - 1) - i; - auto SubElt = cast<ConstantInt>(CDV->getElementAsConstant(SubEltIdx)); - Count <<= BitWidth; - Count |= SubElt->getValue().zextOrTrunc(64); - } - - // If shift-by-zero then just return the original value. - if (Count.isNullValue()) - return Vec; - - // Handle cases when Shift >= BitWidth. - if (Count.uge(BitWidth)) { - // If LogicalShift - just return zero. - if (LogicalShift) - return ConstantAggregateZero::get(VT); - - // If ArithmeticShift - clamp Shift to (BitWidth - 1). 
- Count = APInt(64, BitWidth - 1); - } - - // Get a constant vector of the same type as the first operand. - auto ShiftAmt = ConstantInt::get(SVT, Count.zextOrTrunc(BitWidth)); - auto ShiftVec = Builder.CreateVectorSplat(VWidth, ShiftAmt); - - if (ShiftLeft) - return Builder.CreateShl(Vec, ShiftVec); - - if (LogicalShift) - return Builder.CreateLShr(Vec, ShiftVec); - - return Builder.CreateAShr(Vec, ShiftVec); -} - -// Attempt to simplify AVX2 per-element shift intrinsics to a generic IR shift. -// Unlike the generic IR shifts, the intrinsics have defined behaviour for out -// of range shift amounts (logical - set to zero, arithmetic - splat sign bit). -static Value *simplifyX86varShift(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - bool LogicalShift = false; - bool ShiftLeft = false; - - switch (II.getIntrinsicID()) { - default: llvm_unreachable("Unexpected intrinsic!"); - case Intrinsic::x86_avx2_psrav_d: - case Intrinsic::x86_avx2_psrav_d_256: - case Intrinsic::x86_avx512_psrav_q_128: - case Intrinsic::x86_avx512_psrav_q_256: - case Intrinsic::x86_avx512_psrav_d_512: - case Intrinsic::x86_avx512_psrav_q_512: - case Intrinsic::x86_avx512_psrav_w_128: - case Intrinsic::x86_avx512_psrav_w_256: - case Intrinsic::x86_avx512_psrav_w_512: - LogicalShift = false; - ShiftLeft = false; - break; - case Intrinsic::x86_avx2_psrlv_d: - case Intrinsic::x86_avx2_psrlv_d_256: - case Intrinsic::x86_avx2_psrlv_q: - case Intrinsic::x86_avx2_psrlv_q_256: - case Intrinsic::x86_avx512_psrlv_d_512: - case Intrinsic::x86_avx512_psrlv_q_512: - case Intrinsic::x86_avx512_psrlv_w_128: - case Intrinsic::x86_avx512_psrlv_w_256: - case Intrinsic::x86_avx512_psrlv_w_512: - LogicalShift = true; - ShiftLeft = false; - break; - case Intrinsic::x86_avx2_psllv_d: - case Intrinsic::x86_avx2_psllv_d_256: - case Intrinsic::x86_avx2_psllv_q: - case Intrinsic::x86_avx2_psllv_q_256: - case Intrinsic::x86_avx512_psllv_d_512: - case Intrinsic::x86_avx512_psllv_q_512: - case Intrinsic::x86_avx512_psllv_w_128: - case Intrinsic::x86_avx512_psllv_w_256: - case Intrinsic::x86_avx512_psllv_w_512: - LogicalShift = true; - ShiftLeft = true; - break; - } - assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left"); - - auto Vec = II.getArgOperand(0); - auto Amt = II.getArgOperand(1); - auto VT = cast<VectorType>(II.getType()); - auto SVT = VT->getElementType(); - int NumElts = VT->getNumElements(); - int BitWidth = SVT->getIntegerBitWidth(); - - // If the shift amount is guaranteed to be in-range we can replace it with a - // generic shift. - APInt UpperBits = - APInt::getHighBitsSet(BitWidth, BitWidth - Log2_32(BitWidth)); - if (llvm::MaskedValueIsZero(Amt, UpperBits, - II.getModule()->getDataLayout())) { - return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt) - : Builder.CreateLShr(Vec, Amt)) - : Builder.CreateAShr(Vec, Amt)); - } - - // Simplify if all shift amounts are constant/undef. - auto *CShift = dyn_cast<Constant>(Amt); - if (!CShift) - return nullptr; - - // Collect each element's shift amount. - // We also collect special cases: UNDEF = -1, OUT-OF-RANGE = BitWidth. - bool AnyOutOfRange = false; - SmallVector<int, 8> ShiftAmts; - for (int I = 0; I < NumElts; ++I) { - auto *CElt = CShift->getAggregateElement(I); - if (CElt && isa<UndefValue>(CElt)) { - ShiftAmts.push_back(-1); - continue; - } - - auto *COp = dyn_cast_or_null<ConstantInt>(CElt); - if (!COp) - return nullptr; - - // Handle out of range shifts. - // If LogicalShift - set to BitWidth (special case). 
- // If ArithmeticShift - set to (BitWidth - 1) (sign splat). - APInt ShiftVal = COp->getValue(); - if (ShiftVal.uge(BitWidth)) { - AnyOutOfRange = LogicalShift; - ShiftAmts.push_back(LogicalShift ? BitWidth : BitWidth - 1); - continue; - } - - ShiftAmts.push_back((int)ShiftVal.getZExtValue()); - } - - // If all elements out of range or UNDEF, return vector of zeros/undefs. - // ArithmeticShift should only hit this if they are all UNDEF. - auto OutOfRange = [&](int Idx) { return (Idx < 0) || (BitWidth <= Idx); }; - if (llvm::all_of(ShiftAmts, OutOfRange)) { - SmallVector<Constant *, 8> ConstantVec; - for (int Idx : ShiftAmts) { - if (Idx < 0) { - ConstantVec.push_back(UndefValue::get(SVT)); - } else { - assert(LogicalShift && "Logical shift expected"); - ConstantVec.push_back(ConstantInt::getNullValue(SVT)); - } - } - return ConstantVector::get(ConstantVec); - } - - // We can't handle only some out of range values with generic logical shifts. - if (AnyOutOfRange) - return nullptr; - - // Build the shift amount constant vector. - SmallVector<Constant *, 8> ShiftVecAmts; - for (int Idx : ShiftAmts) { - if (Idx < 0) - ShiftVecAmts.push_back(UndefValue::get(SVT)); - else - ShiftVecAmts.push_back(ConstantInt::get(SVT, Idx)); - } - auto ShiftVec = ConstantVector::get(ShiftVecAmts); - - if (ShiftLeft) - return Builder.CreateShl(Vec, ShiftVec); - - if (LogicalShift) - return Builder.CreateLShr(Vec, ShiftVec); - - return Builder.CreateAShr(Vec, ShiftVec); -} - -static Value *simplifyX86pack(IntrinsicInst &II, - InstCombiner::BuilderTy &Builder, bool IsSigned) { - Value *Arg0 = II.getArgOperand(0); - Value *Arg1 = II.getArgOperand(1); - Type *ResTy = II.getType(); - - // Fast all undef handling. - if (isa<UndefValue>(Arg0) && isa<UndefValue>(Arg1)) - return UndefValue::get(ResTy); - - auto *ArgTy = cast<VectorType>(Arg0->getType()); - unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128; - unsigned NumSrcElts = ArgTy->getNumElements(); - assert(cast<VectorType>(ResTy)->getNumElements() == (2 * NumSrcElts) && - "Unexpected packing types"); - - unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes; - unsigned DstScalarSizeInBits = ResTy->getScalarSizeInBits(); - unsigned SrcScalarSizeInBits = ArgTy->getScalarSizeInBits(); - assert(SrcScalarSizeInBits == (2 * DstScalarSizeInBits) && - "Unexpected packing types"); - - // Constant folding. - if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1)) - return nullptr; - - // Clamp Values - signed/unsigned both use signed clamp values, but they - // differ on the min/max values. - APInt MinValue, MaxValue; - if (IsSigned) { - // PACKSS: Truncate signed value with signed saturation. - // Source values less than dst minint are saturated to minint. - // Source values greater than dst maxint are saturated to maxint. - MinValue = - APInt::getSignedMinValue(DstScalarSizeInBits).sext(SrcScalarSizeInBits); - MaxValue = - APInt::getSignedMaxValue(DstScalarSizeInBits).sext(SrcScalarSizeInBits); - } else { - // PACKUS: Truncate signed value with unsigned saturation. - // Source values less than zero are saturated to zero. - // Source values greater than dst maxuint are saturated to maxuint. 
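A scalar model (an assumption-laden sketch, not LLVM code) of the defined out-of-range behaviour the removed psrav/psrlv/psllv simplification above has to respect: logical shifts by >= BitWidth produce zero, while arithmetic shifts clamp the amount to BitWidth - 1 and so splat the sign bit.

#include <cassert>
#include <cstdint>

static uint32_t varLshr(uint32_t V, uint32_t Amt) {
  return Amt >= 32 ? 0u : V >> Amt;            // logical shift: zero when out of range
}
static int32_t varAshr(int32_t V, uint32_t Amt) {
  return V >> (Amt >= 32 ? 31u : Amt);         // arithmetic shift, amount clamped to 31
}

int main() {
  assert(varLshr(0xdeadbeefu, 40) == 0u);
  assert(varAshr(-8, 100) == -1 && varAshr(8, 100) == 0);
  assert(varLshr(16u, 2) == 4u && varAshr(-16, 2) == -4);
  return 0;
}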
- MinValue = APInt::getNullValue(SrcScalarSizeInBits); - MaxValue = APInt::getLowBitsSet(SrcScalarSizeInBits, DstScalarSizeInBits); - } - - auto *MinC = Constant::getIntegerValue(ArgTy, MinValue); - auto *MaxC = Constant::getIntegerValue(ArgTy, MaxValue); - Arg0 = Builder.CreateSelect(Builder.CreateICmpSLT(Arg0, MinC), MinC, Arg0); - Arg1 = Builder.CreateSelect(Builder.CreateICmpSLT(Arg1, MinC), MinC, Arg1); - Arg0 = Builder.CreateSelect(Builder.CreateICmpSGT(Arg0, MaxC), MaxC, Arg0); - Arg1 = Builder.CreateSelect(Builder.CreateICmpSGT(Arg1, MaxC), MaxC, Arg1); - - // Shuffle clamped args together at the lane level. - SmallVector<int, 32> PackMask; - for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { - for (unsigned Elt = 0; Elt != NumSrcEltsPerLane; ++Elt) - PackMask.push_back(Elt + (Lane * NumSrcEltsPerLane)); - for (unsigned Elt = 0; Elt != NumSrcEltsPerLane; ++Elt) - PackMask.push_back(Elt + (Lane * NumSrcEltsPerLane) + NumSrcElts); - } - auto *Shuffle = Builder.CreateShuffleVector(Arg0, Arg1, PackMask); - - // Truncate to dst size. - return Builder.CreateTrunc(Shuffle, ResTy); -} - -static Value *simplifyX86movmsk(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - Value *Arg = II.getArgOperand(0); - Type *ResTy = II.getType(); - - // movmsk(undef) -> zero as we must ensure the upper bits are zero. - if (isa<UndefValue>(Arg)) - return Constant::getNullValue(ResTy); - - auto *ArgTy = dyn_cast<VectorType>(Arg->getType()); - // We can't easily peek through x86_mmx types. - if (!ArgTy) - return nullptr; - - // Expand MOVMSK to compare/bitcast/zext: - // e.g. PMOVMSKB(v16i8 x): - // %cmp = icmp slt <16 x i8> %x, zeroinitializer - // %int = bitcast <16 x i1> %cmp to i16 - // %res = zext i16 %int to i32 - unsigned NumElts = ArgTy->getNumElements(); - Type *IntegerVecTy = VectorType::getInteger(ArgTy); - Type *IntegerTy = Builder.getIntNTy(NumElts); - - Value *Res = Builder.CreateBitCast(Arg, IntegerVecTy); - Res = Builder.CreateICmpSLT(Res, Constant::getNullValue(IntegerVecTy)); - Res = Builder.CreateBitCast(Res, IntegerTy); - Res = Builder.CreateZExtOrTrunc(Res, ResTy); - return Res; -} - -static Value *simplifyX86addcarry(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - Value *CarryIn = II.getArgOperand(0); - Value *Op1 = II.getArgOperand(1); - Value *Op2 = II.getArgOperand(2); - Type *RetTy = II.getType(); - Type *OpTy = Op1->getType(); - assert(RetTy->getStructElementType(0)->isIntegerTy(8) && - RetTy->getStructElementType(1) == OpTy && OpTy == Op2->getType() && - "Unexpected types for x86 addcarry"); - - // If carry-in is zero, this is just an unsigned add with overflow. - if (match(CarryIn, m_ZeroInt())) { - Value *UAdd = Builder.CreateIntrinsic(Intrinsic::uadd_with_overflow, OpTy, - { Op1, Op2 }); - // The types have to be adjusted to match the x86 call types. 
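A scalar sketch (names and widths chosen for illustration) of the clamp-then-truncate lowering used by the pack simplification above: PACKSS saturates a signed source into the signed destination range, PACKUS into the unsigned one, and only then truncates.

#include <cassert>
#include <cstdint>

static int8_t packss(int16_t V) {
  if (V < -128) return -128;
  if (V > 127) return 127;
  return static_cast<int8_t>(V);
}
static uint8_t packus(int16_t V) {
  if (V < 0) return 0;
  if (V > 255) return 255;
  return static_cast<uint8_t>(V);
}

int main() {
  assert(packss(300) == 127 && packss(-300) == -128 && packss(5) == 5);
  assert(packus(300) == 255 && packus(-300) == 0 && packus(5) == 5);
  return 0;
}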
- Value *UAddResult = Builder.CreateExtractValue(UAdd, 0); - Value *UAddOV = Builder.CreateZExt(Builder.CreateExtractValue(UAdd, 1), - Builder.getInt8Ty()); - Value *Res = UndefValue::get(RetTy); - Res = Builder.CreateInsertValue(Res, UAddOV, 0); - return Builder.CreateInsertValue(Res, UAddResult, 1); - } - - return nullptr; -} - -static Value *simplifyX86insertps(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2)); - if (!CInt) - return nullptr; - - VectorType *VecTy = cast<VectorType>(II.getType()); - assert(VecTy->getNumElements() == 4 && "insertps with wrong vector type"); - - // The immediate permute control byte looks like this: - // [3:0] - zero mask for each 32-bit lane - // [5:4] - select one 32-bit destination lane - // [7:6] - select one 32-bit source lane - - uint8_t Imm = CInt->getZExtValue(); - uint8_t ZMask = Imm & 0xf; - uint8_t DestLane = (Imm >> 4) & 0x3; - uint8_t SourceLane = (Imm >> 6) & 0x3; - - ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy); - - // If all zero mask bits are set, this was just a weird way to - // generate a zero vector. - if (ZMask == 0xf) - return ZeroVector; - - // Initialize by passing all of the first source bits through. - int ShuffleMask[4] = {0, 1, 2, 3}; - - // We may replace the second operand with the zero vector. - Value *V1 = II.getArgOperand(1); - - if (ZMask) { - // If the zero mask is being used with a single input or the zero mask - // overrides the destination lane, this is a shuffle with the zero vector. - if ((II.getArgOperand(0) == II.getArgOperand(1)) || - (ZMask & (1 << DestLane))) { - V1 = ZeroVector; - // We may still move 32-bits of the first source vector from one lane - // to another. - ShuffleMask[DestLane] = SourceLane; - // The zero mask may override the previous insert operation. - for (unsigned i = 0; i < 4; ++i) - if ((ZMask >> i) & 0x1) - ShuffleMask[i] = i + 4; - } else { - // TODO: Model this case as 2 shuffles or a 'logical and' plus shuffle? - return nullptr; - } - } else { - // Replace the selected destination lane with the selected source lane. - ShuffleMask[DestLane] = SourceLane + 4; - } - - return Builder.CreateShuffleVector(II.getArgOperand(0), V1, ShuffleMask); -} - -/// Attempt to simplify SSE4A EXTRQ/EXTRQI instructions using constant folding -/// or conversion to a shuffle vector. -static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0, - ConstantInt *CILength, ConstantInt *CIIndex, - InstCombiner::BuilderTy &Builder) { - auto LowConstantHighUndef = [&](uint64_t Val) { - Type *IntTy64 = Type::getInt64Ty(II.getContext()); - Constant *Args[] = {ConstantInt::get(IntTy64, Val), - UndefValue::get(IntTy64)}; - return ConstantVector::get(Args); - }; - - // See if we're dealing with constant values. - Constant *C0 = dyn_cast<Constant>(Op0); - ConstantInt *CI0 = - C0 ? dyn_cast_or_null<ConstantInt>(C0->getAggregateElement((unsigned)0)) - : nullptr; - - // Attempt to constant fold. - if (CILength && CIIndex) { - // From AMD documentation: "The bit index and field length are each six - // bits in length other bits of the field are ignored." - APInt APIndex = CIIndex->getValue().zextOrTrunc(6); - APInt APLength = CILength->getValue().zextOrTrunc(6); - - unsigned Index = APIndex.getZExtValue(); - - // From AMD documentation: "a value of zero in the field length is - // defined as length of 64". - unsigned Length = APLength == 0 ? 
64 : APLength.getZExtValue(); - - // From AMD documentation: "If the sum of the bit index + length field - // is greater than 64, the results are undefined". - unsigned End = Index + Length; - - // Note that both field index and field length are 8-bit quantities. - // Since variables 'Index' and 'Length' are unsigned values - // obtained from zero-extending field index and field length - // respectively, their sum should never wrap around. - if (End > 64) - return UndefValue::get(II.getType()); - - // If we are inserting whole bytes, we can convert this to a shuffle. - // Lowering can recognize EXTRQI shuffle masks. - if ((Length % 8) == 0 && (Index % 8) == 0) { - // Convert bit indices to byte indices. - Length /= 8; - Index /= 8; - - Type *IntTy8 = Type::getInt8Ty(II.getContext()); - auto *ShufTy = FixedVectorType::get(IntTy8, 16); - - SmallVector<int, 16> ShuffleMask; - for (int i = 0; i != (int)Length; ++i) - ShuffleMask.push_back(i + Index); - for (int i = Length; i != 8; ++i) - ShuffleMask.push_back(i + 16); - for (int i = 8; i != 16; ++i) - ShuffleMask.push_back(-1); - - Value *SV = Builder.CreateShuffleVector( - Builder.CreateBitCast(Op0, ShufTy), - ConstantAggregateZero::get(ShufTy), ShuffleMask); - return Builder.CreateBitCast(SV, II.getType()); - } - - // Constant Fold - shift Index'th bit to lowest position and mask off - // Length bits. - if (CI0) { - APInt Elt = CI0->getValue(); - Elt.lshrInPlace(Index); - Elt = Elt.zextOrTrunc(Length); - return LowConstantHighUndef(Elt.getZExtValue()); - } - - // If we were an EXTRQ call, we'll save registers if we convert to EXTRQI. - if (II.getIntrinsicID() == Intrinsic::x86_sse4a_extrq) { - Value *Args[] = {Op0, CILength, CIIndex}; - Module *M = II.getModule(); - Function *F = Intrinsic::getDeclaration(M, Intrinsic::x86_sse4a_extrqi); - return Builder.CreateCall(F, Args); - } - } - - // Constant Fold - extraction from zero is always {zero, undef}. - if (CI0 && CI0->isZero()) - return LowConstantHighUndef(0); - - return nullptr; -} - -/// Attempt to simplify SSE4A INSERTQ/INSERTQI instructions using constant -/// folding or conversion to a shuffle vector. -static Value *simplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, - APInt APLength, APInt APIndex, - InstCombiner::BuilderTy &Builder) { - // From AMD documentation: "The bit index and field length are each six bits - // in length other bits of the field are ignored." - APIndex = APIndex.zextOrTrunc(6); - APLength = APLength.zextOrTrunc(6); - - // Attempt to constant fold. - unsigned Index = APIndex.getZExtValue(); - - // From AMD documentation: "a value of zero in the field length is - // defined as length of 64". - unsigned Length = APLength == 0 ? 64 : APLength.getZExtValue(); - - // From AMD documentation: "If the sum of the bit index + length field - // is greater than 64, the results are undefined". - unsigned End = Index + Length; - - // Note that both field index and field length are 8-bit quantities. - // Since variables 'Index' and 'Length' are unsigned values - // obtained from zero-extending field index and field length - // respectively, their sum should never wrap around. - if (End > 64) - return UndefValue::get(II.getType()); - - // If we are inserting whole bytes, we can convert this to a shuffle. - // Lowering can recognize INSERTQI shuffle masks. - if ((Length % 8) == 0 && (Index % 8) == 0) { - // Convert bit indices to byte indices. 
- Length /= 8; - Index /= 8; - - Type *IntTy8 = Type::getInt8Ty(II.getContext()); - auto *ShufTy = FixedVectorType::get(IntTy8, 16); - - SmallVector<int, 16> ShuffleMask; - for (int i = 0; i != (int)Index; ++i) - ShuffleMask.push_back(i); - for (int i = 0; i != (int)Length; ++i) - ShuffleMask.push_back(i + 16); - for (int i = Index + Length; i != 8; ++i) - ShuffleMask.push_back(i); - for (int i = 8; i != 16; ++i) - ShuffleMask.push_back(-1); - - Value *SV = Builder.CreateShuffleVector(Builder.CreateBitCast(Op0, ShufTy), - Builder.CreateBitCast(Op1, ShufTy), - ShuffleMask); - return Builder.CreateBitCast(SV, II.getType()); - } - - // See if we're dealing with constant values. - Constant *C0 = dyn_cast<Constant>(Op0); - Constant *C1 = dyn_cast<Constant>(Op1); - ConstantInt *CI00 = - C0 ? dyn_cast_or_null<ConstantInt>(C0->getAggregateElement((unsigned)0)) - : nullptr; - ConstantInt *CI10 = - C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)0)) - : nullptr; - - // Constant Fold - insert bottom Length bits starting at the Index'th bit. - if (CI00 && CI10) { - APInt V00 = CI00->getValue(); - APInt V10 = CI10->getValue(); - APInt Mask = APInt::getLowBitsSet(64, Length).shl(Index); - V00 = V00 & ~Mask; - V10 = V10.zextOrTrunc(Length).zextOrTrunc(64).shl(Index); - APInt Val = V00 | V10; - Type *IntTy64 = Type::getInt64Ty(II.getContext()); - Constant *Args[] = {ConstantInt::get(IntTy64, Val.getZExtValue()), - UndefValue::get(IntTy64)}; - return ConstantVector::get(Args); - } - - // If we were an INSERTQ call, we'll save demanded elements if we convert to - // INSERTQI. - if (II.getIntrinsicID() == Intrinsic::x86_sse4a_insertq) { - Type *IntTy8 = Type::getInt8Ty(II.getContext()); - Constant *CILength = ConstantInt::get(IntTy8, Length, false); - Constant *CIIndex = ConstantInt::get(IntTy8, Index, false); - - Value *Args[] = {Op0, Op1, CILength, CIIndex}; - Module *M = II.getModule(); - Function *F = Intrinsic::getDeclaration(M, Intrinsic::x86_sse4a_insertqi); - return Builder.CreateCall(F, Args); - } - - return nullptr; -} - -/// Attempt to convert pshufb* to shufflevector if the mask is constant. -static Value *simplifyX86pshufb(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - Constant *V = dyn_cast<Constant>(II.getArgOperand(1)); - if (!V) - return nullptr; - - auto *VecTy = cast<VectorType>(II.getType()); - unsigned NumElts = VecTy->getNumElements(); - assert((NumElts == 16 || NumElts == 32 || NumElts == 64) && - "Unexpected number of elements in shuffle mask!"); - - // Construct a shuffle mask from constant integers or UNDEFs. - int Indexes[64]; - - // Each byte in the shuffle control mask forms an index to permute the - // corresponding byte in the destination operand. - for (unsigned I = 0; I < NumElts; ++I) { - Constant *COp = V->getAggregateElement(I); - if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp))) - return nullptr; - - if (isa<UndefValue>(COp)) { - Indexes[I] = -1; - continue; - } - - int8_t Index = cast<ConstantInt>(COp)->getValue().getZExtValue(); - - // If the most significant bit (bit[7]) of each byte of the shuffle - // control mask is set, then zero is written in the result byte. - // The zero vector is in the right-hand side of the resulting - // shufflevector. - - // The value of each index for the high 128-bit lane is the least - // significant 4 bits of the respective shuffle control byte. - Index = ((Index < 0) ? 
NumElts : Index & 0x0F) + (I & 0xF0); - Indexes[I] = Index; - } - - auto V1 = II.getArgOperand(0); - auto V2 = Constant::getNullValue(VecTy); - return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes, NumElts)); -} - -/// Attempt to convert vpermilvar* to shufflevector if the mask is constant. -static Value *simplifyX86vpermilvar(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - Constant *V = dyn_cast<Constant>(II.getArgOperand(1)); - if (!V) - return nullptr; - - auto *VecTy = cast<VectorType>(II.getType()); - unsigned NumElts = VecTy->getNumElements(); - bool IsPD = VecTy->getScalarType()->isDoubleTy(); - unsigned NumLaneElts = IsPD ? 2 : 4; - assert(NumElts == 16 || NumElts == 8 || NumElts == 4 || NumElts == 2); - - // Construct a shuffle mask from constant integers or UNDEFs. - int Indexes[16]; - - // The intrinsics only read one or two bits, clear the rest. - for (unsigned I = 0; I < NumElts; ++I) { - Constant *COp = V->getAggregateElement(I); - if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp))) - return nullptr; - - if (isa<UndefValue>(COp)) { - Indexes[I] = -1; - continue; - } - - APInt Index = cast<ConstantInt>(COp)->getValue(); - Index = Index.zextOrTrunc(32).getLoBits(2); - - // The PD variants uses bit 1 to select per-lane element index, so - // shift down to convert to generic shuffle mask index. - if (IsPD) - Index.lshrInPlace(1); - - // The _256 variants are a bit trickier since the mask bits always index - // into the corresponding 128 half. In order to convert to a generic - // shuffle, we have to make that explicit. - Index += APInt(32, (I / NumLaneElts) * NumLaneElts); - - Indexes[I] = Index.getZExtValue(); - } - - auto V1 = II.getArgOperand(0); - auto V2 = UndefValue::get(V1->getType()); - return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes, NumElts)); -} - -/// Attempt to convert vpermd/vpermps to shufflevector if the mask is constant. -static Value *simplifyX86vpermv(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - auto *V = dyn_cast<Constant>(II.getArgOperand(1)); - if (!V) - return nullptr; - - auto *VecTy = cast<VectorType>(II.getType()); - unsigned Size = VecTy->getNumElements(); - assert((Size == 4 || Size == 8 || Size == 16 || Size == 32 || Size == 64) && - "Unexpected shuffle mask size"); - - // Construct a shuffle mask from constant integers or UNDEFs. - int Indexes[64]; - - for (unsigned I = 0; I < Size; ++I) { - Constant *COp = V->getAggregateElement(I); - if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp))) - return nullptr; - - if (isa<UndefValue>(COp)) { - Indexes[I] = -1; - continue; - } - - uint32_t Index = cast<ConstantInt>(COp)->getZExtValue(); - Index &= Size - 1; - Indexes[I] = Index; - } - - auto V1 = II.getArgOperand(0); - auto V2 = UndefValue::get(VecTy); - return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes, Size)); -} - // TODO, Obvious Missing Transforms: // * Narrow width by halfs excluding zero/undef lanes -Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) { +Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) { Value *LoadPtr = II.getArgOperand(0); const Align Alignment = cast<ConstantInt>(II.getArgOperand(1))->getAlignValue(); @@ -1118,9 +289,8 @@ Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) { // If we can unconditionally load from this address, replace with a // load/select idiom. 
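A scalar model (illustrative only, single 128-bit lane) of the pshufb semantics the shuffle conversion above depends on: a control byte with its sign bit set produces zero, otherwise its low four bits index into the same 16-byte lane of the source.

#include <array>
#include <cassert>
#include <cstdint>

static std::array<uint8_t, 16> pshufb128(const std::array<uint8_t, 16> &Src,
                                         const std::array<uint8_t, 16> &Ctl) {
  std::array<uint8_t, 16> Dst{};
  for (unsigned I = 0; I != 16; ++I)
    Dst[I] = (Ctl[I] & 0x80) ? 0 : Src[Ctl[I] & 0x0F];
  return Dst;
}

int main() {
  std::array<uint8_t, 16> Src{};
  for (unsigned I = 0; I != 16; ++I)
    Src[I] = static_cast<uint8_t>(0x10 + I);
  std::array<uint8_t, 16> Ctl{15, 14, 13, 12, 0x80, 0x80, 2, 2,
                              0, 1, 2, 3, 0xFF, 7, 6, 5};
  auto Dst = pshufb128(Src, Ctl);
  assert(Dst[0] == 0x1F && Dst[4] == 0 && Dst[12] == 0 && Dst[6] == 0x12);
  return 0;
}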
TODO: use DT for context sensitive query - if (isDereferenceableAndAlignedPointer(LoadPtr, II.getType(), Alignment, - II.getModule()->getDataLayout(), &II, - nullptr)) { + if (isDereferenceablePointer(LoadPtr, II.getType(), + II.getModule()->getDataLayout(), &II, nullptr)) { Value *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, "unmaskedload"); return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3)); @@ -1132,7 +302,7 @@ Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) { // TODO, Obvious Missing Transforms: // * Single constant active lane -> store // * Narrow width by halfs excluding zero/undef lanes -Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) { +Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) { auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); if (!ConstMask) return nullptr; @@ -1148,11 +318,14 @@ Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) { return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment); } + if (isa<ScalableVectorType>(ConstMask->getType())) + return nullptr; + // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask); APInt UndefElts(DemandedElts.getBitWidth(), 0); - if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), - DemandedElts, UndefElts)) + if (Value *V = + SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, UndefElts)) return replaceOperand(II, 0, V); return nullptr; @@ -1165,7 +338,7 @@ Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) { // * Narrow width by halfs excluding zero/undef lanes // * Vector splat address w/known mask -> scalar load // * Vector incrementing address -> vector masked load -Instruction *InstCombiner::simplifyMaskedGather(IntrinsicInst &II) { +Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) { return nullptr; } @@ -1175,7 +348,7 @@ Instruction *InstCombiner::simplifyMaskedGather(IntrinsicInst &II) { // * Narrow store width by halfs excluding zero/undef lanes // * Vector splat address w/known mask -> scalar store // * Vector incrementing address -> vector masked store -Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) { +Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); if (!ConstMask) return nullptr; @@ -1184,14 +357,17 @@ Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) { if (ConstMask->isNullValue()) return eraseInstFromFunction(II); + if (isa<ScalableVectorType>(ConstMask->getType())) + return nullptr; + // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask); APInt UndefElts(DemandedElts.getBitWidth(), 0); - if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), - DemandedElts, UndefElts)) + if (Value *V = + SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, UndefElts)) return replaceOperand(II, 0, V); - if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1), - DemandedElts, UndefElts)) + if (Value *V = + SimplifyDemandedVectorElts(II.getOperand(1), DemandedElts, UndefElts)) return replaceOperand(II, 1, V); return nullptr; @@ -1206,7 +382,7 @@ Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) { /// This is legal because it preserves the most recent information about /// the presence or absence of invariant.group. 
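A scalar sketch (assumed fully dereferenceable memory, fixed-width vector) of the load/select idiom simplifyMaskedLoad produces above: load unconditionally, then pick per lane between the loaded value and the passthru.

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

template <std::size_t N>
std::array<int32_t, N> maskedLoad(const std::array<int32_t, N> &Mem,
                                  const std::array<bool, N> &Mask,
                                  const std::array<int32_t, N> &Passthru) {
  std::array<int32_t, N> Loaded = Mem; // unconditional load is safe here
  std::array<int32_t, N> Res{};
  for (std::size_t I = 0; I != N; ++I)
    Res[I] = Mask[I] ? Loaded[I] : Passthru[I];
  return Res;
}

int main() {
  std::array<int32_t, 4> Mem{10, 20, 30, 40}, Pass{-1, -2, -3, -4};
  std::array<bool, 4> Mask{true, false, true, false};
  auto R = maskedLoad(Mem, Mask, Pass);
  assert(R[0] == 10 && R[1] == -2 && R[2] == 30 && R[3] == -4);
  return 0;
}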
static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II, - InstCombiner &IC) { + InstCombinerImpl &IC) { auto *Arg = II.getArgOperand(0); auto *StrippedArg = Arg->stripPointerCasts(); auto *StrippedInvariantGroupsArg = Arg->stripPointerCastsAndInvariantGroups(); @@ -1231,7 +407,7 @@ static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II, return cast<Instruction>(Result); } -static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) { +static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { assert((II.getIntrinsicID() == Intrinsic::cttz || II.getIntrinsicID() == Intrinsic::ctlz) && "Expected cttz or ctlz intrinsic"); @@ -1257,6 +433,9 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) { SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor; if (SPF == SPF_ABS || SPF == SPF_NABS) return IC.replaceOperand(II, 0, X); + + if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X)))) + return IC.replaceOperand(II, 0, X); } KnownBits Known = IC.computeKnownBits(Op0, 0, &II); @@ -1301,7 +480,7 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) { return nullptr; } -static Instruction *foldCtpop(IntrinsicInst &II, InstCombiner &IC) { +static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) { assert(II.getIntrinsicID() == Intrinsic::ctpop && "Expected ctpop intrinsic"); Type *Ty = II.getType(); @@ -1356,107 +535,6 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombiner &IC) { return nullptr; } -// TODO: If the x86 backend knew how to convert a bool vector mask back to an -// XMM register mask efficiently, we could transform all x86 masked intrinsics -// to LLVM masked intrinsics and remove the x86 masked intrinsic defs. -static Instruction *simplifyX86MaskedLoad(IntrinsicInst &II, InstCombiner &IC) { - Value *Ptr = II.getOperand(0); - Value *Mask = II.getOperand(1); - Constant *ZeroVec = Constant::getNullValue(II.getType()); - - // Special case a zero mask since that's not a ConstantDataVector. - // This masked load instruction creates a zero vector. - if (isa<ConstantAggregateZero>(Mask)) - return IC.replaceInstUsesWith(II, ZeroVec); - - auto *ConstMask = dyn_cast<ConstantDataVector>(Mask); - if (!ConstMask) - return nullptr; - - // The mask is constant. Convert this x86 intrinsic to the LLVM instrinsic - // to allow target-independent optimizations. - - // First, cast the x86 intrinsic scalar pointer to a vector pointer to match - // the LLVM intrinsic definition for the pointer argument. - unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace(); - PointerType *VecPtrTy = PointerType::get(II.getType(), AddrSpace); - Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec"); - - // Second, convert the x86 XMM integer vector mask to a vector of bools based - // on each element's most significant bit (the sign bit). - Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask); - - // The pass-through vector for an x86 masked load is a zero vector. - CallInst *NewMaskedLoad = - IC.Builder.CreateMaskedLoad(PtrCast, Align(1), BoolMask, ZeroVec); - return IC.replaceInstUsesWith(II, NewMaskedLoad); -} - -// TODO: If the x86 backend knew how to convert a bool vector mask back to an -// XMM register mask efficiently, we could transform all x86 masked intrinsics -// to LLVM masked intrinsics and remove the x86 masked intrinsic defs. 
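A quick check (not part of the patch) of why foldCttzCtlz can look through abs/nabs and llvm.abs above: negation flips every bit above the lowest set bit and keeps everything at or below it, so the trailing-zero count is unchanged.

#include <cassert>
#include <cstdint>

static unsigned cttz32(uint32_t V) {
  if (V == 0) return 32;
  unsigned N = 0;
  while ((V & 1u) == 0) { V >>= 1; ++N; }
  return N;
}

int main() {
  const uint32_t Vals[] = {1u, 2u, 12u, 0x80u, 0x80000000u, 0xdeadbeefu};
  for (uint32_t V : Vals)
    assert(cttz32(V) == cttz32(0u - V)); // cttz(x) == cttz(-x) == cttz(|x|)
  return 0;
}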
-static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) { - Value *Ptr = II.getOperand(0); - Value *Mask = II.getOperand(1); - Value *Vec = II.getOperand(2); - - // Special case a zero mask since that's not a ConstantDataVector: - // this masked store instruction does nothing. - if (isa<ConstantAggregateZero>(Mask)) { - IC.eraseInstFromFunction(II); - return true; - } - - // The SSE2 version is too weird (eg, unaligned but non-temporal) to do - // anything else at this level. - if (II.getIntrinsicID() == Intrinsic::x86_sse2_maskmov_dqu) - return false; - - auto *ConstMask = dyn_cast<ConstantDataVector>(Mask); - if (!ConstMask) - return false; - - // The mask is constant. Convert this x86 intrinsic to the LLVM instrinsic - // to allow target-independent optimizations. - - // First, cast the x86 intrinsic scalar pointer to a vector pointer to match - // the LLVM intrinsic definition for the pointer argument. - unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace(); - PointerType *VecPtrTy = PointerType::get(Vec->getType(), AddrSpace); - Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec"); - - // Second, convert the x86 XMM integer vector mask to a vector of bools based - // on each element's most significant bit (the sign bit). - Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask); - - IC.Builder.CreateMaskedStore(Vec, PtrCast, Align(1), BoolMask); - - // 'Replace uses' doesn't work for stores. Erase the original masked store. - IC.eraseInstFromFunction(II); - return true; -} - -// Constant fold llvm.amdgcn.fmed3 intrinsics for standard inputs. -// -// A single NaN input is folded to minnum, so we rely on that folding for -// handling NaNs. -static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1, - const APFloat &Src2) { - APFloat Max3 = maxnum(maxnum(Src0, Src1), Src2); - - APFloat::cmpResult Cmp0 = Max3.compare(Src0); - assert(Cmp0 != APFloat::cmpUnordered && "nans handled separately"); - if (Cmp0 == APFloat::cmpEqual) - return maxnum(Src1, Src2); - - APFloat::cmpResult Cmp1 = Max3.compare(Src1); - assert(Cmp1 != APFloat::cmpUnordered && "nans handled separately"); - if (Cmp1 == APFloat::cmpEqual) - return maxnum(Src0, Src2); - - return maxnum(Src0, Src1); -} - /// Convert a table lookup to shufflevector if the mask is constant. /// This could benefit tbl1 if the mask is { 7,6,5,4,3,2,1,0 }, in /// which case we could lower the shufflevector with rev64 instructions @@ -1468,7 +546,7 @@ static Value *simplifyNeonTbl1(const IntrinsicInst &II, if (!C) return nullptr; - auto *VecTy = cast<VectorType>(II.getType()); + auto *VecTy = cast<FixedVectorType>(II.getType()); unsigned NumElts = VecTy->getNumElements(); // Only perform this transformation for <8 x i8> vector types. @@ -1495,28 +573,6 @@ static Value *simplifyNeonTbl1(const IntrinsicInst &II, return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes)); } -/// Convert a vector load intrinsic into a simple llvm load instruction. -/// This is beneficial when the underlying object being addressed comes -/// from a constant, since we get constant-folding for free. -static Value *simplifyNeonVld1(const IntrinsicInst &II, - unsigned MemAlign, - InstCombiner::BuilderTy &Builder) { - auto *IntrAlign = dyn_cast<ConstantInt>(II.getArgOperand(1)); - - if (!IntrAlign) - return nullptr; - - unsigned Alignment = IntrAlign->getLimitedValue() < MemAlign ? 
- MemAlign : IntrAlign->getLimitedValue(); - - if (!isPowerOf2_32(Alignment)) - return nullptr; - - auto *BCastInst = Builder.CreateBitCast(II.getArgOperand(0), - PointerType::get(II.getType(), 0)); - return Builder.CreateAlignedLoad(II.getType(), BCastInst, Align(Alignment)); -} - // Returns true iff the 2 intrinsics have the same operands, limiting the // comparison to the first NumOperands. static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, @@ -1538,9 +594,9 @@ static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, // call @llvm.foo.start(i1 0) ; This one won't be skipped: it will be removed // call @llvm.foo.end(i1 0) // call @llvm.foo.end(i1 0) ; &I -static bool removeTriviallyEmptyRange( - IntrinsicInst &EndI, InstCombiner &IC, - std::function<bool(const IntrinsicInst &)> IsStart) { +static bool +removeTriviallyEmptyRange(IntrinsicInst &EndI, InstCombinerImpl &IC, + std::function<bool(const IntrinsicInst &)> IsStart) { // We start from the end intrinsic and scan backwards, so that InstCombine // has already processed (and potentially removed) all the instructions // before the end intrinsic. @@ -1566,256 +622,7 @@ static bool removeTriviallyEmptyRange( return false; } -// Convert NVVM intrinsics to target-generic LLVM code where possible. -static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) { - // Each NVVM intrinsic we can simplify can be replaced with one of: - // - // * an LLVM intrinsic, - // * an LLVM cast operation, - // * an LLVM binary operation, or - // * ad-hoc LLVM IR for the particular operation. - - // Some transformations are only valid when the module's - // flush-denormals-to-zero (ftz) setting is true/false, whereas other - // transformations are valid regardless of the module's ftz setting. - enum FtzRequirementTy { - FTZ_Any, // Any ftz setting is ok. - FTZ_MustBeOn, // Transformation is valid only if ftz is on. - FTZ_MustBeOff, // Transformation is valid only if ftz is off. - }; - // Classes of NVVM intrinsics that can't be replaced one-to-one with a - // target-generic intrinsic, cast op, or binary op but that we can nonetheless - // simplify. - enum SpecialCase { - SPC_Reciprocal, - }; - - // SimplifyAction is a poor-man's variant (plus an additional flag) that - // represents how to replace an NVVM intrinsic with target-generic LLVM IR. - struct SimplifyAction { - // Invariant: At most one of these Optionals has a value. - Optional<Intrinsic::ID> IID; - Optional<Instruction::CastOps> CastOp; - Optional<Instruction::BinaryOps> BinaryOp; - Optional<SpecialCase> Special; - - FtzRequirementTy FtzRequirement = FTZ_Any; - - SimplifyAction() = default; - - SimplifyAction(Intrinsic::ID IID, FtzRequirementTy FtzReq) - : IID(IID), FtzRequirement(FtzReq) {} - - // Cast operations don't have anything to do with FTZ, so we skip that - // argument. - SimplifyAction(Instruction::CastOps CastOp) : CastOp(CastOp) {} - - SimplifyAction(Instruction::BinaryOps BinaryOp, FtzRequirementTy FtzReq) - : BinaryOp(BinaryOp), FtzRequirement(FtzReq) {} - - SimplifyAction(SpecialCase Special, FtzRequirementTy FtzReq) - : Special(Special), FtzRequirement(FtzReq) {} - }; - - // Try to generate a SimplifyAction describing how to replace our - // IntrinsicInstr with target-generic LLVM IR. - const SimplifyAction Action = [II]() -> SimplifyAction { - switch (II->getIntrinsicID()) { - // NVVM intrinsics that map directly to LLVM intrinsics. 
- case Intrinsic::nvvm_ceil_d: - return {Intrinsic::ceil, FTZ_Any}; - case Intrinsic::nvvm_ceil_f: - return {Intrinsic::ceil, FTZ_MustBeOff}; - case Intrinsic::nvvm_ceil_ftz_f: - return {Intrinsic::ceil, FTZ_MustBeOn}; - case Intrinsic::nvvm_fabs_d: - return {Intrinsic::fabs, FTZ_Any}; - case Intrinsic::nvvm_fabs_f: - return {Intrinsic::fabs, FTZ_MustBeOff}; - case Intrinsic::nvvm_fabs_ftz_f: - return {Intrinsic::fabs, FTZ_MustBeOn}; - case Intrinsic::nvvm_floor_d: - return {Intrinsic::floor, FTZ_Any}; - case Intrinsic::nvvm_floor_f: - return {Intrinsic::floor, FTZ_MustBeOff}; - case Intrinsic::nvvm_floor_ftz_f: - return {Intrinsic::floor, FTZ_MustBeOn}; - case Intrinsic::nvvm_fma_rn_d: - return {Intrinsic::fma, FTZ_Any}; - case Intrinsic::nvvm_fma_rn_f: - return {Intrinsic::fma, FTZ_MustBeOff}; - case Intrinsic::nvvm_fma_rn_ftz_f: - return {Intrinsic::fma, FTZ_MustBeOn}; - case Intrinsic::nvvm_fmax_d: - return {Intrinsic::maxnum, FTZ_Any}; - case Intrinsic::nvvm_fmax_f: - return {Intrinsic::maxnum, FTZ_MustBeOff}; - case Intrinsic::nvvm_fmax_ftz_f: - return {Intrinsic::maxnum, FTZ_MustBeOn}; - case Intrinsic::nvvm_fmin_d: - return {Intrinsic::minnum, FTZ_Any}; - case Intrinsic::nvvm_fmin_f: - return {Intrinsic::minnum, FTZ_MustBeOff}; - case Intrinsic::nvvm_fmin_ftz_f: - return {Intrinsic::minnum, FTZ_MustBeOn}; - case Intrinsic::nvvm_round_d: - return {Intrinsic::round, FTZ_Any}; - case Intrinsic::nvvm_round_f: - return {Intrinsic::round, FTZ_MustBeOff}; - case Intrinsic::nvvm_round_ftz_f: - return {Intrinsic::round, FTZ_MustBeOn}; - case Intrinsic::nvvm_sqrt_rn_d: - return {Intrinsic::sqrt, FTZ_Any}; - case Intrinsic::nvvm_sqrt_f: - // nvvm_sqrt_f is a special case. For most intrinsics, foo_ftz_f is the - // ftz version, and foo_f is the non-ftz version. But nvvm_sqrt_f adopts - // the ftz-ness of the surrounding code. sqrt_rn_f and sqrt_rn_ftz_f are - // the versions with explicit ftz-ness. - return {Intrinsic::sqrt, FTZ_Any}; - case Intrinsic::nvvm_sqrt_rn_f: - return {Intrinsic::sqrt, FTZ_MustBeOff}; - case Intrinsic::nvvm_sqrt_rn_ftz_f: - return {Intrinsic::sqrt, FTZ_MustBeOn}; - case Intrinsic::nvvm_trunc_d: - return {Intrinsic::trunc, FTZ_Any}; - case Intrinsic::nvvm_trunc_f: - return {Intrinsic::trunc, FTZ_MustBeOff}; - case Intrinsic::nvvm_trunc_ftz_f: - return {Intrinsic::trunc, FTZ_MustBeOn}; - - // NVVM intrinsics that map to LLVM cast operations. - // - // Note that llvm's target-generic conversion operators correspond to the rz - // (round to zero) versions of the nvvm conversion intrinsics, even though - // most everything else here uses the rn (round to nearest even) nvvm ops. - case Intrinsic::nvvm_d2i_rz: - case Intrinsic::nvvm_f2i_rz: - case Intrinsic::nvvm_d2ll_rz: - case Intrinsic::nvvm_f2ll_rz: - return {Instruction::FPToSI}; - case Intrinsic::nvvm_d2ui_rz: - case Intrinsic::nvvm_f2ui_rz: - case Intrinsic::nvvm_d2ull_rz: - case Intrinsic::nvvm_f2ull_rz: - return {Instruction::FPToUI}; - case Intrinsic::nvvm_i2d_rz: - case Intrinsic::nvvm_i2f_rz: - case Intrinsic::nvvm_ll2d_rz: - case Intrinsic::nvvm_ll2f_rz: - return {Instruction::SIToFP}; - case Intrinsic::nvvm_ui2d_rz: - case Intrinsic::nvvm_ui2f_rz: - case Intrinsic::nvvm_ull2d_rz: - case Intrinsic::nvvm_ull2f_rz: - return {Instruction::UIToFP}; - - // NVVM intrinsics that map to LLVM binary ops. 
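A small aside on why only the *_rz conversions above map to plain fp-to-int casts: LLVM's generic FPToSI/FPToUI, like a C++ float-to-int conversion, round toward zero, so only the round-to-zero NVVM variants are equivalent. A tiny standalone illustration:

#include <cstdio>

int main() {
  // fptosi-style conversion truncates toward zero, matching the "rz"
  // (round-to-zero) NVVM conversion intrinsics; the "rn" (round-to-nearest)
  // variants have no direct cast-instruction equivalent.
  double Vals[] = {2.7, -2.7, 0.9, -0.9};
  for (double D : Vals)
    std::printf("%5.1f -> %d\n", D, static_cast<int>(D));
  // Output: 2.7 -> 2, -2.7 -> -2, 0.9 -> 0, -0.9 -> 0
}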
- case Intrinsic::nvvm_add_rn_d: - return {Instruction::FAdd, FTZ_Any}; - case Intrinsic::nvvm_add_rn_f: - return {Instruction::FAdd, FTZ_MustBeOff}; - case Intrinsic::nvvm_add_rn_ftz_f: - return {Instruction::FAdd, FTZ_MustBeOn}; - case Intrinsic::nvvm_mul_rn_d: - return {Instruction::FMul, FTZ_Any}; - case Intrinsic::nvvm_mul_rn_f: - return {Instruction::FMul, FTZ_MustBeOff}; - case Intrinsic::nvvm_mul_rn_ftz_f: - return {Instruction::FMul, FTZ_MustBeOn}; - case Intrinsic::nvvm_div_rn_d: - return {Instruction::FDiv, FTZ_Any}; - case Intrinsic::nvvm_div_rn_f: - return {Instruction::FDiv, FTZ_MustBeOff}; - case Intrinsic::nvvm_div_rn_ftz_f: - return {Instruction::FDiv, FTZ_MustBeOn}; - - // The remainder of cases are NVVM intrinsics that map to LLVM idioms, but - // need special handling. - // - // We seem to be missing intrinsics for rcp.approx.{ftz.}f32, which is just - // as well. - case Intrinsic::nvvm_rcp_rn_d: - return {SPC_Reciprocal, FTZ_Any}; - case Intrinsic::nvvm_rcp_rn_f: - return {SPC_Reciprocal, FTZ_MustBeOff}; - case Intrinsic::nvvm_rcp_rn_ftz_f: - return {SPC_Reciprocal, FTZ_MustBeOn}; - - // We do not currently simplify intrinsics that give an approximate answer. - // These include: - // - // - nvvm_cos_approx_{f,ftz_f} - // - nvvm_ex2_approx_{d,f,ftz_f} - // - nvvm_lg2_approx_{d,f,ftz_f} - // - nvvm_sin_approx_{f,ftz_f} - // - nvvm_sqrt_approx_{f,ftz_f} - // - nvvm_rsqrt_approx_{d,f,ftz_f} - // - nvvm_div_approx_{ftz_d,ftz_f,f} - // - nvvm_rcp_approx_ftz_d - // - // Ideally we'd encode them as e.g. "fast call @llvm.cos", where "fast" - // means that fastmath is enabled in the intrinsic. Unfortunately only - // binary operators (currently) have a fastmath bit in SelectionDAG, so this - // information gets lost and we can't select on it. - // - // TODO: div and rcp are lowered to a binary op, so these we could in theory - // lower them to "fast fdiv". - - default: - return {}; - } - }(); - - // If Action.FtzRequirementTy is not satisfied by the module's ftz state, we - // can bail out now. (Notice that in the case that IID is not an NVVM - // intrinsic, we don't have to look up any module metadata, as - // FtzRequirementTy will be FTZ_Any.) - if (Action.FtzRequirement != FTZ_Any) { - StringRef Attr = II->getFunction() - ->getFnAttribute("denormal-fp-math-f32") - .getValueAsString(); - DenormalMode Mode = parseDenormalFPAttribute(Attr); - bool FtzEnabled = Mode.Output != DenormalMode::IEEE; - - if (FtzEnabled != (Action.FtzRequirement == FTZ_MustBeOn)) - return nullptr; - } - - // Simplify to target-generic intrinsic. - if (Action.IID) { - SmallVector<Value *, 4> Args(II->arg_operands()); - // All the target-generic intrinsics currently of interest to us have one - // type argument, equal to that of the nvvm intrinsic's argument. - Type *Tys[] = {II->getArgOperand(0)->getType()}; - return CallInst::Create( - Intrinsic::getDeclaration(II->getModule(), *Action.IID, Tys), Args); - } - - // Simplify to target-generic binary op. - if (Action.BinaryOp) - return BinaryOperator::Create(*Action.BinaryOp, II->getArgOperand(0), - II->getArgOperand(1), II->getName()); - - // Simplify to target-generic cast op. - if (Action.CastOp) - return CastInst::Create(*Action.CastOp, II->getArgOperand(0), II->getType(), - II->getName()); - - // All that's left are the special cases. - if (!Action.Special) - return nullptr; - - switch (*Action.Special) { - case SPC_Reciprocal: - // Simplify reciprocal. 
- return BinaryOperator::Create( - Instruction::FDiv, ConstantFP::get(II->getArgOperand(0)->getType(), 1), - II->getArgOperand(0), II->getName()); - } - llvm_unreachable("All SpecialCase enumerators should be handled in switch."); -} - -Instruction *InstCombiner::visitVAEndInst(VAEndInst &I) { +Instruction *InstCombinerImpl::visitVAEndInst(VAEndInst &I) { removeTriviallyEmptyRange(I, *this, [](const IntrinsicInst &I) { return I.getIntrinsicID() == Intrinsic::vastart || I.getIntrinsicID() == Intrinsic::vacopy; @@ -1823,7 +630,7 @@ Instruction *InstCombiner::visitVAEndInst(VAEndInst &I) { return nullptr; } -static Instruction *canonicalizeConstantArg0ToArg1(CallInst &Call) { +static CallInst *canonicalizeConstantArg0ToArg1(CallInst &Call) { assert(Call.getNumArgOperands() > 1 && "Need at least 2 args to swap"); Value *Arg0 = Call.getArgOperand(0), *Arg1 = Call.getArgOperand(1); if (isa<Constant>(Arg0) && !isa<Constant>(Arg1)) { @@ -1834,20 +641,44 @@ static Instruction *canonicalizeConstantArg0ToArg1(CallInst &Call) { return nullptr; } -Instruction *InstCombiner::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) { +/// Creates a result tuple for an overflow intrinsic \p II with a given +/// \p Result and a constant \p Overflow value. +static Instruction *createOverflowTuple(IntrinsicInst *II, Value *Result, + Constant *Overflow) { + Constant *V[] = {UndefValue::get(Result->getType()), Overflow}; + StructType *ST = cast<StructType>(II->getType()); + Constant *Struct = ConstantStruct::get(ST, V); + return InsertValueInst::Create(Struct, Result, 0); +} + +Instruction * +InstCombinerImpl::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) { WithOverflowInst *WO = cast<WithOverflowInst>(II); Value *OperationResult = nullptr; Constant *OverflowResult = nullptr; if (OptimizeOverflowCheck(WO->getBinaryOp(), WO->isSigned(), WO->getLHS(), WO->getRHS(), *WO, OperationResult, OverflowResult)) - return CreateOverflowTuple(WO, OperationResult, OverflowResult); + return createOverflowTuple(WO, OperationResult, OverflowResult); return nullptr; } +static Optional<bool> getKnownSign(Value *Op, Instruction *CxtI, + const DataLayout &DL, AssumptionCache *AC, + DominatorTree *DT) { + KnownBits Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT); + if (Known.isNonNegative()) + return false; + if (Known.isNegative()) + return true; + + return isImpliedByDomCondition( + ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL); +} + /// CallInst simplification. This mostly only handles folding of intrinsic /// instructions. For normal calls, it allows visitCallBase to do the heavy /// lifting. -Instruction *InstCombiner::visitCallInst(CallInst &CI) { +Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // Don't try to simplify calls without uses. It will not do anything useful, // but will result in the following folds being skipped. 
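For reference, the value an *.with.overflow intrinsic produces is a {result, overflow} pair, which is what the createOverflowTuple helper above materializes once the overflow bit becomes a known constant. A minimal standalone model of the signed-add case (using the GCC/Clang builtin __builtin_add_overflow; illustrative only):

#include <cstdio>
#include <utility>

// Sketch of the {result, overflow} pair a *.with.overflow intrinsic yields,
// here for a signed 32-bit add.
static std::pair<int, bool> saddWithOverflow(int A, int B) {
  int Result;
  bool Overflow = __builtin_add_overflow(A, B, &Result);
  return {Result, Overflow};
}

int main() {
  auto [V1, O1] = saddWithOverflow(1, 2);
  auto [V2, O2] = saddWithOverflow(0x7fffffff, 1);
  std::printf("%d %d\n", V1, O1); // 3 0
  std::printf("%d %d\n", V2, O2); // wrapped value, overflow = 1
}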
if (!CI.use_empty()) @@ -1953,31 +784,84 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } } - if (Instruction *I = SimplifyNVVMIntrinsic(II, *this)) - return I; - - auto SimplifyDemandedVectorEltsLow = [this](Value *Op, unsigned Width, - unsigned DemandedWidth) { - APInt UndefElts(Width, 0); - APInt DemandedElts = APInt::getLowBitsSet(Width, DemandedWidth); - return SimplifyDemandedVectorElts(Op, DemandedElts, UndefElts); - }; + if (II->isCommutative()) { + if (CallInst *NewCall = canonicalizeConstantArg0ToArg1(CI)) + return NewCall; + } Intrinsic::ID IID = II->getIntrinsicID(); switch (IID) { - default: break; case Intrinsic::objectsize: if (Value *V = lowerObjectSizeCall(II, DL, &TLI, /*MustSucceed=*/false)) return replaceInstUsesWith(CI, V); return nullptr; + case Intrinsic::abs: { + Value *IIOperand = II->getArgOperand(0); + bool IntMinIsPoison = cast<Constant>(II->getArgOperand(1))->isOneValue(); + + // abs(-x) -> abs(x) + // TODO: Copy nsw if it was present on the neg? + Value *X; + if (match(IIOperand, m_Neg(m_Value(X)))) + return replaceOperand(*II, 0, X); + if (match(IIOperand, m_Select(m_Value(), m_Value(X), m_Neg(m_Deferred(X))))) + return replaceOperand(*II, 0, X); + if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X)))) + return replaceOperand(*II, 0, X); + + if (Optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) { + // abs(x) -> x if x >= 0 + if (!*Sign) + return replaceInstUsesWith(*II, IIOperand); + + // abs(x) -> -x if x < 0 + if (IntMinIsPoison) + return BinaryOperator::CreateNSWNeg(IIOperand); + return BinaryOperator::CreateNeg(IIOperand); + } + + // abs (sext X) --> zext (abs X*) + // Clear the IsIntMin (nsw) bit on the abs to allow narrowing. + if (match(IIOperand, m_OneUse(m_SExt(m_Value(X))))) { + Value *NarrowAbs = + Builder.CreateBinaryIntrinsic(Intrinsic::abs, X, Builder.getFalse()); + return CastInst::Create(Instruction::ZExt, NarrowAbs, II->getType()); + } + + break; + } + case Intrinsic::umax: + case Intrinsic::umin: { + Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1); + Value *X, *Y; + if (match(I0, m_ZExt(m_Value(X))) && match(I1, m_ZExt(m_Value(Y))) && + (I0->hasOneUse() || I1->hasOneUse()) && X->getType() == Y->getType()) { + Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, Y); + return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType()); + } + // If both operands of unsigned min/max are sign-extended, it is still ok + // to narrow the operation. 
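The narrowing claim in the comment above can be checked exhaustively for small types: taking the unsigned min of two zero-extended (or two sign-extended) i8 values gives the same result as extending the i8 unsigned min. A standalone sketch of that check:

#include <cstdint>
#include <cstdio>

// Checks the narrowing claim on all i8 pairs: umin of two zero-extended
// (or two sign-extended) i8 values equals the extension of the i8 umin,
// so the min/max can be performed in the narrow type.
int main() {
  for (int A = 0; A != 256; ++A) {
    for (int B = 0; B != 256; ++B) {
      uint8_t NA = (uint8_t)A, NB = (uint8_t)B;
      uint8_t NarrowMin = NA < NB ? NA : NB;

      uint32_t ZA = NA, ZB = NB;                   // zext
      uint32_t SA = (uint32_t)(int32_t)(int8_t)NA; // sext
      uint32_t SB = (uint32_t)(int32_t)(int8_t)NB;
      uint32_t ZMin = ZA < ZB ? ZA : ZB;
      uint32_t SMin = SA < SB ? SA : SB;

      if (ZMin != (uint32_t)NarrowMin ||
          SMin != (uint32_t)(int32_t)(int8_t)NarrowMin) {
        std::printf("mismatch at %d,%d\n", A, B);
        return 1;
      }
    }
  }
  std::printf("umin narrowing holds for all i8 pairs\n");
}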
+ LLVM_FALLTHROUGH; + } + case Intrinsic::smax: + case Intrinsic::smin: { + Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1); + Value *X, *Y; + if (match(I0, m_SExt(m_Value(X))) && match(I1, m_SExt(m_Value(Y))) && + (I0->hasOneUse() || I1->hasOneUse()) && X->getType() == Y->getType()) { + Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, Y); + return CastInst::Create(Instruction::SExt, NarrowMaxMin, II->getType()); + } + break; + } case Intrinsic::bswap: { Value *IIOperand = II->getArgOperand(0); Value *X = nullptr; // bswap(trunc(bswap(x))) -> trunc(lshr(x, c)) if (match(IIOperand, m_Trunc(m_BSwap(m_Value(X))))) { - unsigned C = X->getType()->getPrimitiveSizeInBits() - - IIOperand->getType()->getPrimitiveSizeInBits(); + unsigned C = X->getType()->getScalarSizeInBits() - + IIOperand->getType()->getScalarSizeInBits(); Value *CV = ConstantInt::get(X->getType(), C); Value *V = Builder.CreateLShr(X, CV); return new TruncInst(V, IIOperand->getType()); @@ -2002,15 +886,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::powi: if (ConstantInt *Power = dyn_cast<ConstantInt>(II->getArgOperand(1))) { // 0 and 1 are handled in instsimplify - // powi(x, -1) -> 1/x if (Power->isMinusOne()) - return BinaryOperator::CreateFDiv(ConstantFP::get(CI.getType(), 1.0), - II->getArgOperand(0)); + return BinaryOperator::CreateFDivFMF(ConstantFP::get(CI.getType(), 1.0), + II->getArgOperand(0), II); // powi(x, 2) -> x*x if (Power->equalsInt(2)) - return BinaryOperator::CreateFMul(II->getArgOperand(0), - II->getArgOperand(0)); + return BinaryOperator::CreateFMulFMF(II->getArgOperand(0), + II->getArgOperand(0), II); } break; @@ -2031,8 +914,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Type *Ty = II->getType(); unsigned BitWidth = Ty->getScalarSizeInBits(); Constant *ShAmtC; - if (match(II->getArgOperand(2), m_Constant(ShAmtC)) && - !isa<ConstantExpr>(ShAmtC) && !ShAmtC->containsConstantExpression()) { + if (match(II->getArgOperand(2), m_ImmConstant(ShAmtC)) && + !ShAmtC->containsConstantExpression()) { // Canonicalize a shift amount constant operand to modulo the bit-width. 
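The canonicalization above relies on funnel-shift amounts only mattering modulo the bit width. A standalone sketch of fshl on i32 (a rotate when both inputs are the same value), assuming that modulo semantics:

#include <cstdint>
#include <cstdio>

// Funnel-shift semantics: the shift amount only matters modulo the bit
// width, which is why a constant shift amount can be canonicalized to
// urem(amount, width).
static uint32_t fshl32(uint32_t Hi, uint32_t Lo, uint32_t Amt) {
  Amt %= 32;
  if (Amt == 0)
    return Hi;
  return (Hi << Amt) | (Lo >> (32 - Amt));
}

int main() {
  uint32_t X = 0x12345678;
  // A shift amount of 36 behaves exactly like a shift amount of 4.
  std::printf("%08x\n", fshl32(X, X, 36)); // 23456781
  std::printf("%08x\n", fshl32(X, X, 4));  // 23456781
}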
Constant *WidthC = ConstantInt::get(Ty, BitWidth); Constant *ModuloC = ConstantExpr::getURem(ShAmtC, WidthC); @@ -2092,8 +975,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } case Intrinsic::uadd_with_overflow: case Intrinsic::sadd_with_overflow: { - if (Instruction *I = canonicalizeConstantArg0ToArg1(CI)) - return I; if (Instruction *I = foldIntrinsicWithOverflowCommon(II)) return I; @@ -2121,10 +1002,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: - if (Instruction *I = canonicalizeConstantArg0ToArg1(CI)) - return I; - LLVM_FALLTHROUGH; - case Intrinsic::usub_with_overflow: if (Instruction *I = foldIntrinsicWithOverflowCommon(II)) return I; @@ -2155,9 +1032,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::uadd_sat: case Intrinsic::sadd_sat: - if (Instruction *I = canonicalizeConstantArg0ToArg1(CI)) - return I; - LLVM_FALLTHROUGH; case Intrinsic::usub_sat: case Intrinsic::ssub_sat: { SaturatingInst *SI = cast<SaturatingInst>(II); @@ -2238,8 +1112,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::maxnum: case Intrinsic::minimum: case Intrinsic::maximum: { - if (Instruction *I = canonicalizeConstantArg0ToArg1(CI)) - return I; Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); Value *X, *Y; @@ -2348,9 +1220,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { LLVM_FALLTHROUGH; } case Intrinsic::fma: { - if (Instruction *I = canonicalizeConstantArg0ToArg1(CI)) - return I; - // fma fneg(x), fneg(y), z -> fma x, y, z Value *Src0 = II->getArgOperand(0); Value *Src1 = II->getArgOperand(1); @@ -2390,40 +1259,52 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } case Intrinsic::copysign: { - if (SignBitMustBeZero(II->getArgOperand(1), &TLI)) { + Value *Mag = II->getArgOperand(0), *Sign = II->getArgOperand(1); + if (SignBitMustBeZero(Sign, &TLI)) { // If we know that the sign argument is positive, reduce to FABS: - // copysign X, Pos --> fabs X - Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, - II->getArgOperand(0), II); + // copysign Mag, +Sign --> fabs Mag + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II); return replaceInstUsesWith(*II, Fabs); } // TODO: There should be a ValueTracking sibling like SignBitMustBeOne. const APFloat *C; - if (match(II->getArgOperand(1), m_APFloat(C)) && C->isNegative()) { + if (match(Sign, m_APFloat(C)) && C->isNegative()) { // If we know that the sign argument is negative, reduce to FNABS: - // copysign X, Neg --> fneg (fabs X) - Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, - II->getArgOperand(0), II); + // copysign Mag, -Sign --> fneg (fabs Mag) + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II); return replaceInstUsesWith(*II, Builder.CreateFNegFMF(Fabs, II)); } // Propagate sign argument through nested calls: - // copysign X, (copysign ?, SignArg) --> copysign X, SignArg - Value *SignArg; - if (match(II->getArgOperand(1), - m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(SignArg)))) - return replaceOperand(*II, 1, SignArg); + // copysign Mag, (copysign ?, X) --> copysign Mag, X + Value *X; + if (match(Sign, m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(X)))) + return replaceOperand(*II, 1, X); + + // Peek through changes of magnitude's sign-bit. 
This call rewrites those: + // copysign (fabs X), Sign --> copysign X, Sign + // copysign (fneg X), Sign --> copysign X, Sign + if (match(Mag, m_FAbs(m_Value(X))) || match(Mag, m_FNeg(m_Value(X)))) + return replaceOperand(*II, 0, X); break; } case Intrinsic::fabs: { - Value *Cond; - Constant *LHS, *RHS; + Value *Cond, *TVal, *FVal; if (match(II->getArgOperand(0), - m_Select(m_Value(Cond), m_Constant(LHS), m_Constant(RHS)))) { - CallInst *Call0 = Builder.CreateCall(II->getCalledFunction(), {LHS}); - CallInst *Call1 = Builder.CreateCall(II->getCalledFunction(), {RHS}); - return SelectInst::Create(Cond, Call0, Call1); + m_Select(m_Value(Cond), m_Value(TVal), m_Value(FVal)))) { + // fabs (select Cond, TrueC, FalseC) --> select Cond, AbsT, AbsF + if (isa<Constant>(TVal) && isa<Constant>(FVal)) { + CallInst *AbsT = Builder.CreateCall(II->getCalledFunction(), {TVal}); + CallInst *AbsF = Builder.CreateCall(II->getCalledFunction(), {FVal}); + return SelectInst::Create(Cond, AbsT, AbsF); + } + // fabs (select Cond, -FVal, FVal) --> fabs FVal + if (match(TVal, m_FNeg(m_Specific(FVal)))) + return replaceOperand(*II, 0, FVal); + // fabs (select Cond, TVal, -TVal) --> fabs TVal + if (match(FVal, m_FNeg(m_Specific(TVal)))) + return replaceOperand(*II, 0, TVal); } LLVM_FALLTHROUGH; @@ -2465,932 +1346,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; } - case Intrinsic::ppc_altivec_lvx: - case Intrinsic::ppc_altivec_lvxl: - // Turn PPC lvx -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), Align(16), DL, II, &AC, - &DT) >= 16) { - Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), - PointerType::getUnqual(II->getType())); - return new LoadInst(II->getType(), Ptr, "", false, Align(16)); - } - break; - case Intrinsic::ppc_vsx_lxvw4x: - case Intrinsic::ppc_vsx_lxvd2x: { - // Turn PPC VSX loads into normal loads. - Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), - PointerType::getUnqual(II->getType())); - return new LoadInst(II->getType(), Ptr, Twine(""), false, Align(1)); - } - case Intrinsic::ppc_altivec_stvx: - case Intrinsic::ppc_altivec_stvxl: - // Turn stvx -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), Align(16), DL, II, &AC, - &DT) >= 16) { - Type *OpPtrTy = - PointerType::getUnqual(II->getArgOperand(0)->getType()); - Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); - return new StoreInst(II->getArgOperand(0), Ptr, false, Align(16)); - } - break; - case Intrinsic::ppc_vsx_stxvw4x: - case Intrinsic::ppc_vsx_stxvd2x: { - // Turn PPC VSX stores into normal stores. - Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); - Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); - return new StoreInst(II->getArgOperand(0), Ptr, false, Align(1)); - } - case Intrinsic::ppc_qpx_qvlfs: - // Turn PPC QPX qvlfs -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), Align(16), DL, II, &AC, - &DT) >= 16) { - Type *VTy = - VectorType::get(Builder.getFloatTy(), - cast<VectorType>(II->getType())->getElementCount()); - Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), - PointerType::getUnqual(VTy)); - Value *Load = Builder.CreateLoad(VTy, Ptr); - return new FPExtInst(Load, II->getType()); - } - break; - case Intrinsic::ppc_qpx_qvlfd: - // Turn PPC QPX qvlfd -> load if the pointer is known aligned. 
- if (getOrEnforceKnownAlignment(II->getArgOperand(0), Align(32), DL, II, &AC, - &DT) >= 32) { - Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), - PointerType::getUnqual(II->getType())); - return new LoadInst(II->getType(), Ptr, "", false, Align(32)); - } - break; - case Intrinsic::ppc_qpx_qvstfs: - // Turn PPC QPX qvstfs -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), Align(16), DL, II, &AC, - &DT) >= 16) { - Type *VTy = VectorType::get( - Builder.getFloatTy(), - cast<VectorType>(II->getArgOperand(0)->getType())->getElementCount()); - Value *TOp = Builder.CreateFPTrunc(II->getArgOperand(0), VTy); - Type *OpPtrTy = PointerType::getUnqual(VTy); - Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); - return new StoreInst(TOp, Ptr, false, Align(16)); - } - break; - case Intrinsic::ppc_qpx_qvstfd: - // Turn PPC QPX qvstfd -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), Align(32), DL, II, &AC, - &DT) >= 32) { - Type *OpPtrTy = - PointerType::getUnqual(II->getArgOperand(0)->getType()); - Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); - return new StoreInst(II->getArgOperand(0), Ptr, false, Align(32)); - } - break; - - case Intrinsic::x86_bmi_bextr_32: - case Intrinsic::x86_bmi_bextr_64: - case Intrinsic::x86_tbm_bextri_u32: - case Intrinsic::x86_tbm_bextri_u64: - // If the RHS is a constant we can try some simplifications. - if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(1))) { - uint64_t Shift = C->getZExtValue(); - uint64_t Length = (Shift >> 8) & 0xff; - Shift &= 0xff; - unsigned BitWidth = II->getType()->getIntegerBitWidth(); - // If the length is 0 or the shift is out of range, replace with zero. - if (Length == 0 || Shift >= BitWidth) - return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0)); - // If the LHS is also a constant, we can completely constant fold this. - if (auto *InC = dyn_cast<ConstantInt>(II->getArgOperand(0))) { - uint64_t Result = InC->getZExtValue() >> Shift; - if (Length > BitWidth) - Length = BitWidth; - Result &= maskTrailingOnes<uint64_t>(Length); - return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result)); - } - // TODO should we turn this into 'and' if shift is 0? Or 'shl' if we - // are only masking bits that a shift already cleared? - } - break; - - case Intrinsic::x86_bmi_bzhi_32: - case Intrinsic::x86_bmi_bzhi_64: - // If the RHS is a constant we can try some simplifications. - if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(1))) { - uint64_t Index = C->getZExtValue() & 0xff; - unsigned BitWidth = II->getType()->getIntegerBitWidth(); - if (Index >= BitWidth) - return replaceInstUsesWith(CI, II->getArgOperand(0)); - if (Index == 0) - return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0)); - // If the LHS is also a constant, we can completely constant fold this. - if (auto *InC = dyn_cast<ConstantInt>(II->getArgOperand(0))) { - uint64_t Result = InC->getZExtValue(); - Result &= maskTrailingOnes<uint64_t>(Index); - return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result)); - } - // TODO should we convert this to an AND if the RHS is constant? 
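A standalone model of the BEXTR folding that was moved out of this file: the control operand packs a start bit in bits [7:0] and a length in bits [15:8], and the result is zero when the length is zero or the start is out of range (illustrative only):

#include <cstdint>
#include <cstdio>

// Software model of BEXTR on a 32-bit source, matching the removed
// constant-folding logic above.
static uint32_t bextr32(uint32_t Src, uint32_t Control) {
  uint32_t Start = Control & 0xff;
  uint32_t Length = (Control >> 8) & 0xff;
  if (Length == 0 || Start >= 32)
    return 0;
  uint64_t Extracted = (uint64_t)Src >> Start;
  uint64_t Mask = Length >= 64 ? ~0ull : ((1ull << Length) - 1);
  return (uint32_t)(Extracted & Mask);
}

int main() {
  // Extract 8 bits starting at bit 4 of 0xABCD: 0xBC.
  std::printf("%#x\n", bextr32(0xABCD, (8 << 8) | 4));
}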
- } - break; - case Intrinsic::x86_bmi_pext_32: - case Intrinsic::x86_bmi_pext_64: - if (auto *MaskC = dyn_cast<ConstantInt>(II->getArgOperand(1))) { - if (MaskC->isNullValue()) - return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0)); - if (MaskC->isAllOnesValue()) - return replaceInstUsesWith(CI, II->getArgOperand(0)); - - if (auto *SrcC = dyn_cast<ConstantInt>(II->getArgOperand(0))) { - uint64_t Src = SrcC->getZExtValue(); - uint64_t Mask = MaskC->getZExtValue(); - uint64_t Result = 0; - uint64_t BitToSet = 1; - - while (Mask) { - // Isolate lowest set bit. - uint64_t BitToTest = Mask & -Mask; - if (BitToTest & Src) - Result |= BitToSet; - - BitToSet <<= 1; - // Clear lowest set bit. - Mask &= Mask - 1; - } - - return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result)); - } - } - break; - case Intrinsic::x86_bmi_pdep_32: - case Intrinsic::x86_bmi_pdep_64: - if (auto *MaskC = dyn_cast<ConstantInt>(II->getArgOperand(1))) { - if (MaskC->isNullValue()) - return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0)); - if (MaskC->isAllOnesValue()) - return replaceInstUsesWith(CI, II->getArgOperand(0)); - - if (auto *SrcC = dyn_cast<ConstantInt>(II->getArgOperand(0))) { - uint64_t Src = SrcC->getZExtValue(); - uint64_t Mask = MaskC->getZExtValue(); - uint64_t Result = 0; - uint64_t BitToTest = 1; - - while (Mask) { - // Isolate lowest set bit. - uint64_t BitToSet = Mask & -Mask; - if (BitToTest & Src) - Result |= BitToSet; - - BitToTest <<= 1; - // Clear lowest set bit; - Mask &= Mask - 1; - } - - return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result)); - } - } - break; - - case Intrinsic::x86_sse_cvtss2si: - case Intrinsic::x86_sse_cvtss2si64: - case Intrinsic::x86_sse_cvttss2si: - case Intrinsic::x86_sse_cvttss2si64: - case Intrinsic::x86_sse2_cvtsd2si: - case Intrinsic::x86_sse2_cvtsd2si64: - case Intrinsic::x86_sse2_cvttsd2si: - case Intrinsic::x86_sse2_cvttsd2si64: - case Intrinsic::x86_avx512_vcvtss2si32: - case Intrinsic::x86_avx512_vcvtss2si64: - case Intrinsic::x86_avx512_vcvtss2usi32: - case Intrinsic::x86_avx512_vcvtss2usi64: - case Intrinsic::x86_avx512_vcvtsd2si32: - case Intrinsic::x86_avx512_vcvtsd2si64: - case Intrinsic::x86_avx512_vcvtsd2usi32: - case Intrinsic::x86_avx512_vcvtsd2usi64: - case Intrinsic::x86_avx512_cvttss2si: - case Intrinsic::x86_avx512_cvttss2si64: - case Intrinsic::x86_avx512_cvttss2usi: - case Intrinsic::x86_avx512_cvttss2usi64: - case Intrinsic::x86_avx512_cvttsd2si: - case Intrinsic::x86_avx512_cvttsd2si64: - case Intrinsic::x86_avx512_cvttsd2usi: - case Intrinsic::x86_avx512_cvttsd2usi64: { - // These intrinsics only demand the 0th element of their input vectors. If - // we can simplify the input based on that, do so now. 
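The PEXT/PDEP constant folding removed above follows the usual software model of those BMI2 operations; a standalone sketch of PEXT, gathering the source bits selected by the mask into the low bits of the result:

#include <cstdint>
#include <cstdio>

// Software model of the BMI2 PEXT operation matching the removed
// constant-folding loop: gather the Src bits selected by Mask and pack
// them into the low bits of the result.
static uint64_t pext64(uint64_t Src, uint64_t Mask) {
  uint64_t Result = 0;
  uint64_t BitToSet = 1;
  while (Mask) {
    uint64_t BitToTest = Mask & -Mask; // isolate lowest set mask bit
    if (Src & BitToTest)
      Result |= BitToSet;
    BitToSet <<= 1;
    Mask &= Mask - 1; // clear lowest set mask bit
  }
  return Result;
}

int main() {
  // Pick out bits 4..7 and 12..15 of 0xABCD: packs 0xC and 0xA into 0xAC.
  std::printf("%#llx\n", (unsigned long long)pext64(0xABCD, 0xF0F0));
}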
- Value *Arg = II->getArgOperand(0); - unsigned VWidth = cast<VectorType>(Arg->getType())->getNumElements(); - if (Value *V = SimplifyDemandedVectorEltsLow(Arg, VWidth, 1)) - return replaceOperand(*II, 0, V); - break; - } - - case Intrinsic::x86_mmx_pmovmskb: - case Intrinsic::x86_sse_movmsk_ps: - case Intrinsic::x86_sse2_movmsk_pd: - case Intrinsic::x86_sse2_pmovmskb_128: - case Intrinsic::x86_avx_movmsk_pd_256: - case Intrinsic::x86_avx_movmsk_ps_256: - case Intrinsic::x86_avx2_pmovmskb: - if (Value *V = simplifyX86movmsk(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_sse_comieq_ss: - case Intrinsic::x86_sse_comige_ss: - case Intrinsic::x86_sse_comigt_ss: - case Intrinsic::x86_sse_comile_ss: - case Intrinsic::x86_sse_comilt_ss: - case Intrinsic::x86_sse_comineq_ss: - case Intrinsic::x86_sse_ucomieq_ss: - case Intrinsic::x86_sse_ucomige_ss: - case Intrinsic::x86_sse_ucomigt_ss: - case Intrinsic::x86_sse_ucomile_ss: - case Intrinsic::x86_sse_ucomilt_ss: - case Intrinsic::x86_sse_ucomineq_ss: - case Intrinsic::x86_sse2_comieq_sd: - case Intrinsic::x86_sse2_comige_sd: - case Intrinsic::x86_sse2_comigt_sd: - case Intrinsic::x86_sse2_comile_sd: - case Intrinsic::x86_sse2_comilt_sd: - case Intrinsic::x86_sse2_comineq_sd: - case Intrinsic::x86_sse2_ucomieq_sd: - case Intrinsic::x86_sse2_ucomige_sd: - case Intrinsic::x86_sse2_ucomigt_sd: - case Intrinsic::x86_sse2_ucomile_sd: - case Intrinsic::x86_sse2_ucomilt_sd: - case Intrinsic::x86_sse2_ucomineq_sd: - case Intrinsic::x86_avx512_vcomi_ss: - case Intrinsic::x86_avx512_vcomi_sd: - case Intrinsic::x86_avx512_mask_cmp_ss: - case Intrinsic::x86_avx512_mask_cmp_sd: { - // These intrinsics only demand the 0th element of their input vectors. If - // we can simplify the input based on that, do so now. - bool MadeChange = false; - Value *Arg0 = II->getArgOperand(0); - Value *Arg1 = II->getArgOperand(1); - unsigned VWidth = cast<VectorType>(Arg0->getType())->getNumElements(); - if (Value *V = SimplifyDemandedVectorEltsLow(Arg0, VWidth, 1)) { - replaceOperand(*II, 0, V); - MadeChange = true; - } - if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { - replaceOperand(*II, 1, V); - MadeChange = true; - } - if (MadeChange) - return II; - break; - } - case Intrinsic::x86_avx512_cmp_pd_128: - case Intrinsic::x86_avx512_cmp_pd_256: - case Intrinsic::x86_avx512_cmp_pd_512: - case Intrinsic::x86_avx512_cmp_ps_128: - case Intrinsic::x86_avx512_cmp_ps_256: - case Intrinsic::x86_avx512_cmp_ps_512: { - // Folding cmp(sub(a,b),0) -> cmp(a,b) and cmp(0,sub(a,b)) -> cmp(b,a) - Value *Arg0 = II->getArgOperand(0); - Value *Arg1 = II->getArgOperand(1); - bool Arg0IsZero = match(Arg0, m_PosZeroFP()); - if (Arg0IsZero) - std::swap(Arg0, Arg1); - Value *A, *B; - // This fold requires only the NINF(not +/- inf) since inf minus - // inf is nan. - // NSZ(No Signed Zeros) is not needed because zeros of any sign are - // equal for both compares. - // NNAN is not needed because nans compare the same for both compares. - // The compare intrinsic uses the above assumptions and therefore - // doesn't require additional flags. 
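A quick standalone demonstration of why the cmp(sub(a, b), 0) rewrite needs the no-infs guarantee described above: with both inputs +inf, the subtraction yields NaN and the two comparisons disagree:

#include <cstdio>
#include <limits>

// With a = b = +inf the difference is NaN, so comparing (a - b) against
// zero no longer agrees with comparing a and b directly. Signed zeros and
// NaN inputs, by contrast, give the same answer either way.
int main() {
  double Inf = std::numeric_limits<double>::infinity();
  double A = Inf, B = Inf;
  std::printf("a - b == 0 : %d\n", (A - B) == 0.0); // 0 (inf - inf is NaN)
  std::printf("a == b     : %d\n", A == B);         // 1
}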
- if ((match(Arg0, m_OneUse(m_FSub(m_Value(A), m_Value(B)))) && - match(Arg1, m_PosZeroFP()) && isa<Instruction>(Arg0) && - cast<Instruction>(Arg0)->getFastMathFlags().noInfs())) { - if (Arg0IsZero) - std::swap(A, B); - replaceOperand(*II, 0, A); - replaceOperand(*II, 1, B); - return II; - } - break; - } - - case Intrinsic::x86_avx512_add_ps_512: - case Intrinsic::x86_avx512_div_ps_512: - case Intrinsic::x86_avx512_mul_ps_512: - case Intrinsic::x86_avx512_sub_ps_512: - case Intrinsic::x86_avx512_add_pd_512: - case Intrinsic::x86_avx512_div_pd_512: - case Intrinsic::x86_avx512_mul_pd_512: - case Intrinsic::x86_avx512_sub_pd_512: - // If the rounding mode is CUR_DIRECTION(4) we can turn these into regular - // IR operations. - if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(2))) { - if (R->getValue() == 4) { - Value *Arg0 = II->getArgOperand(0); - Value *Arg1 = II->getArgOperand(1); - - Value *V; - switch (IID) { - default: llvm_unreachable("Case stmts out of sync!"); - case Intrinsic::x86_avx512_add_ps_512: - case Intrinsic::x86_avx512_add_pd_512: - V = Builder.CreateFAdd(Arg0, Arg1); - break; - case Intrinsic::x86_avx512_sub_ps_512: - case Intrinsic::x86_avx512_sub_pd_512: - V = Builder.CreateFSub(Arg0, Arg1); - break; - case Intrinsic::x86_avx512_mul_ps_512: - case Intrinsic::x86_avx512_mul_pd_512: - V = Builder.CreateFMul(Arg0, Arg1); - break; - case Intrinsic::x86_avx512_div_ps_512: - case Intrinsic::x86_avx512_div_pd_512: - V = Builder.CreateFDiv(Arg0, Arg1); - break; - } - - return replaceInstUsesWith(*II, V); - } - } - break; - - case Intrinsic::x86_avx512_mask_add_ss_round: - case Intrinsic::x86_avx512_mask_div_ss_round: - case Intrinsic::x86_avx512_mask_mul_ss_round: - case Intrinsic::x86_avx512_mask_sub_ss_round: - case Intrinsic::x86_avx512_mask_add_sd_round: - case Intrinsic::x86_avx512_mask_div_sd_round: - case Intrinsic::x86_avx512_mask_mul_sd_round: - case Intrinsic::x86_avx512_mask_sub_sd_round: - // If the rounding mode is CUR_DIRECTION(4) we can turn these into regular - // IR operations. - if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(4))) { - if (R->getValue() == 4) { - // Extract the element as scalars. - Value *Arg0 = II->getArgOperand(0); - Value *Arg1 = II->getArgOperand(1); - Value *LHS = Builder.CreateExtractElement(Arg0, (uint64_t)0); - Value *RHS = Builder.CreateExtractElement(Arg1, (uint64_t)0); - - Value *V; - switch (IID) { - default: llvm_unreachable("Case stmts out of sync!"); - case Intrinsic::x86_avx512_mask_add_ss_round: - case Intrinsic::x86_avx512_mask_add_sd_round: - V = Builder.CreateFAdd(LHS, RHS); - break; - case Intrinsic::x86_avx512_mask_sub_ss_round: - case Intrinsic::x86_avx512_mask_sub_sd_round: - V = Builder.CreateFSub(LHS, RHS); - break; - case Intrinsic::x86_avx512_mask_mul_ss_round: - case Intrinsic::x86_avx512_mask_mul_sd_round: - V = Builder.CreateFMul(LHS, RHS); - break; - case Intrinsic::x86_avx512_mask_div_ss_round: - case Intrinsic::x86_avx512_mask_div_sd_round: - V = Builder.CreateFDiv(LHS, RHS); - break; - } - - // Handle the masking aspect of the intrinsic. - Value *Mask = II->getArgOperand(3); - auto *C = dyn_cast<ConstantInt>(Mask); - // We don't need a select if we know the mask bit is a 1. - if (!C || !C->getValue()[0]) { - // Cast the mask to an i1 vector and then extract the lowest element. 
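The masked scalar operations being unpacked above compute only element 0 and use the low bit of the integer mask to choose between the computed value and the passthru element; a minimal standalone model of that select:

#include <cstdint>
#include <cstdio>

// Sketch of the scalar-with-mask semantics: only element 0 is computed,
// and the low mask bit selects between the computed value and the
// passthru element, i.e. a simple select once the bit is extracted.
static float maskedScalarAdd(float A0, float B0, float Passthru0,
                             uint8_t Mask) {
  float V = A0 + B0;
  return (Mask & 1) ? V : Passthru0;
}

int main() {
  std::printf("%g\n", maskedScalarAdd(1.0f, 2.0f, 42.0f, 0x1)); // 3
  std::printf("%g\n", maskedScalarAdd(1.0f, 2.0f, 42.0f, 0x0)); // 42
}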
- auto *MaskTy = FixedVectorType::get( - Builder.getInt1Ty(), - cast<IntegerType>(Mask->getType())->getBitWidth()); - Mask = Builder.CreateBitCast(Mask, MaskTy); - Mask = Builder.CreateExtractElement(Mask, (uint64_t)0); - // Extract the lowest element from the passthru operand. - Value *Passthru = Builder.CreateExtractElement(II->getArgOperand(2), - (uint64_t)0); - V = Builder.CreateSelect(Mask, V, Passthru); - } - - // Insert the result back into the original argument 0. - V = Builder.CreateInsertElement(Arg0, V, (uint64_t)0); - - return replaceInstUsesWith(*II, V); - } - } - break; - - // Constant fold ashr( <A x Bi>, Ci ). - // Constant fold lshr( <A x Bi>, Ci ). - // Constant fold shl( <A x Bi>, Ci ). - case Intrinsic::x86_sse2_psrai_d: - case Intrinsic::x86_sse2_psrai_w: - case Intrinsic::x86_avx2_psrai_d: - case Intrinsic::x86_avx2_psrai_w: - case Intrinsic::x86_avx512_psrai_q_128: - case Intrinsic::x86_avx512_psrai_q_256: - case Intrinsic::x86_avx512_psrai_d_512: - case Intrinsic::x86_avx512_psrai_q_512: - case Intrinsic::x86_avx512_psrai_w_512: - case Intrinsic::x86_sse2_psrli_d: - case Intrinsic::x86_sse2_psrli_q: - case Intrinsic::x86_sse2_psrli_w: - case Intrinsic::x86_avx2_psrli_d: - case Intrinsic::x86_avx2_psrli_q: - case Intrinsic::x86_avx2_psrli_w: - case Intrinsic::x86_avx512_psrli_d_512: - case Intrinsic::x86_avx512_psrli_q_512: - case Intrinsic::x86_avx512_psrli_w_512: - case Intrinsic::x86_sse2_pslli_d: - case Intrinsic::x86_sse2_pslli_q: - case Intrinsic::x86_sse2_pslli_w: - case Intrinsic::x86_avx2_pslli_d: - case Intrinsic::x86_avx2_pslli_q: - case Intrinsic::x86_avx2_pslli_w: - case Intrinsic::x86_avx512_pslli_d_512: - case Intrinsic::x86_avx512_pslli_q_512: - case Intrinsic::x86_avx512_pslli_w_512: - if (Value *V = simplifyX86immShift(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_sse2_psra_d: - case Intrinsic::x86_sse2_psra_w: - case Intrinsic::x86_avx2_psra_d: - case Intrinsic::x86_avx2_psra_w: - case Intrinsic::x86_avx512_psra_q_128: - case Intrinsic::x86_avx512_psra_q_256: - case Intrinsic::x86_avx512_psra_d_512: - case Intrinsic::x86_avx512_psra_q_512: - case Intrinsic::x86_avx512_psra_w_512: - case Intrinsic::x86_sse2_psrl_d: - case Intrinsic::x86_sse2_psrl_q: - case Intrinsic::x86_sse2_psrl_w: - case Intrinsic::x86_avx2_psrl_d: - case Intrinsic::x86_avx2_psrl_q: - case Intrinsic::x86_avx2_psrl_w: - case Intrinsic::x86_avx512_psrl_d_512: - case Intrinsic::x86_avx512_psrl_q_512: - case Intrinsic::x86_avx512_psrl_w_512: - case Intrinsic::x86_sse2_psll_d: - case Intrinsic::x86_sse2_psll_q: - case Intrinsic::x86_sse2_psll_w: - case Intrinsic::x86_avx2_psll_d: - case Intrinsic::x86_avx2_psll_q: - case Intrinsic::x86_avx2_psll_w: - case Intrinsic::x86_avx512_psll_d_512: - case Intrinsic::x86_avx512_psll_q_512: - case Intrinsic::x86_avx512_psll_w_512: { - if (Value *V = simplifyX86immShift(*II, Builder)) - return replaceInstUsesWith(*II, V); - - // SSE2/AVX2 uses only the first 64-bits of the 128-bit vector - // operand to compute the shift amount. 
- Value *Arg1 = II->getArgOperand(1); - assert(Arg1->getType()->getPrimitiveSizeInBits() == 128 && - "Unexpected packed shift size"); - unsigned VWidth = cast<VectorType>(Arg1->getType())->getNumElements(); - - if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, VWidth / 2)) - return replaceOperand(*II, 1, V); - break; - } - - case Intrinsic::x86_avx2_psllv_d: - case Intrinsic::x86_avx2_psllv_d_256: - case Intrinsic::x86_avx2_psllv_q: - case Intrinsic::x86_avx2_psllv_q_256: - case Intrinsic::x86_avx512_psllv_d_512: - case Intrinsic::x86_avx512_psllv_q_512: - case Intrinsic::x86_avx512_psllv_w_128: - case Intrinsic::x86_avx512_psllv_w_256: - case Intrinsic::x86_avx512_psllv_w_512: - case Intrinsic::x86_avx2_psrav_d: - case Intrinsic::x86_avx2_psrav_d_256: - case Intrinsic::x86_avx512_psrav_q_128: - case Intrinsic::x86_avx512_psrav_q_256: - case Intrinsic::x86_avx512_psrav_d_512: - case Intrinsic::x86_avx512_psrav_q_512: - case Intrinsic::x86_avx512_psrav_w_128: - case Intrinsic::x86_avx512_psrav_w_256: - case Intrinsic::x86_avx512_psrav_w_512: - case Intrinsic::x86_avx2_psrlv_d: - case Intrinsic::x86_avx2_psrlv_d_256: - case Intrinsic::x86_avx2_psrlv_q: - case Intrinsic::x86_avx2_psrlv_q_256: - case Intrinsic::x86_avx512_psrlv_d_512: - case Intrinsic::x86_avx512_psrlv_q_512: - case Intrinsic::x86_avx512_psrlv_w_128: - case Intrinsic::x86_avx512_psrlv_w_256: - case Intrinsic::x86_avx512_psrlv_w_512: - if (Value *V = simplifyX86varShift(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_sse2_packssdw_128: - case Intrinsic::x86_sse2_packsswb_128: - case Intrinsic::x86_avx2_packssdw: - case Intrinsic::x86_avx2_packsswb: - case Intrinsic::x86_avx512_packssdw_512: - case Intrinsic::x86_avx512_packsswb_512: - if (Value *V = simplifyX86pack(*II, Builder, true)) - return replaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_sse2_packuswb_128: - case Intrinsic::x86_sse41_packusdw: - case Intrinsic::x86_avx2_packusdw: - case Intrinsic::x86_avx2_packuswb: - case Intrinsic::x86_avx512_packusdw_512: - case Intrinsic::x86_avx512_packuswb_512: - if (Value *V = simplifyX86pack(*II, Builder, false)) - return replaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_pclmulqdq: - case Intrinsic::x86_pclmulqdq_256: - case Intrinsic::x86_pclmulqdq_512: { - if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(2))) { - unsigned Imm = C->getZExtValue(); - - bool MadeChange = false; - Value *Arg0 = II->getArgOperand(0); - Value *Arg1 = II->getArgOperand(1); - unsigned VWidth = cast<VectorType>(Arg0->getType())->getNumElements(); - - APInt UndefElts1(VWidth, 0); - APInt DemandedElts1 = APInt::getSplat(VWidth, - APInt(2, (Imm & 0x01) ? 2 : 1)); - if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts1, - UndefElts1)) { - replaceOperand(*II, 0, V); - MadeChange = true; - } - - APInt UndefElts2(VWidth, 0); - APInt DemandedElts2 = APInt::getSplat(VWidth, - APInt(2, (Imm & 0x10) ? 2 : 1)); - if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts2, - UndefElts2)) { - replaceOperand(*II, 1, V); - MadeChange = true; - } - - // If either input elements are undef, the result is zero. 
- if (DemandedElts1.isSubsetOf(UndefElts1) || - DemandedElts2.isSubsetOf(UndefElts2)) - return replaceInstUsesWith(*II, - ConstantAggregateZero::get(II->getType())); - - if (MadeChange) - return II; - } - break; - } - - case Intrinsic::x86_sse41_insertps: - if (Value *V = simplifyX86insertps(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_sse4a_extrq: { - Value *Op0 = II->getArgOperand(0); - Value *Op1 = II->getArgOperand(1); - unsigned VWidth0 = cast<VectorType>(Op0->getType())->getNumElements(); - unsigned VWidth1 = cast<VectorType>(Op1->getType())->getNumElements(); - assert(Op0->getType()->getPrimitiveSizeInBits() == 128 && - Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth0 == 2 && - VWidth1 == 16 && "Unexpected operand sizes"); - - // See if we're dealing with constant values. - Constant *C1 = dyn_cast<Constant>(Op1); - ConstantInt *CILength = - C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)0)) - : nullptr; - ConstantInt *CIIndex = - C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)1)) - : nullptr; - - // Attempt to simplify to a constant, shuffle vector or EXTRQI call. - if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, Builder)) - return replaceInstUsesWith(*II, V); - - // EXTRQ only uses the lowest 64-bits of the first 128-bit vector - // operands and the lowest 16-bits of the second. - bool MadeChange = false; - if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth0, 1)) { - replaceOperand(*II, 0, V); - MadeChange = true; - } - if (Value *V = SimplifyDemandedVectorEltsLow(Op1, VWidth1, 2)) { - replaceOperand(*II, 1, V); - MadeChange = true; - } - if (MadeChange) - return II; - break; - } - - case Intrinsic::x86_sse4a_extrqi: { - // EXTRQI: Extract Length bits starting from Index. Zero pad the remaining - // bits of the lower 64-bits. The upper 64-bits are undefined. - Value *Op0 = II->getArgOperand(0); - unsigned VWidth = cast<VectorType>(Op0->getType())->getNumElements(); - assert(Op0->getType()->getPrimitiveSizeInBits() == 128 && VWidth == 2 && - "Unexpected operand size"); - - // See if we're dealing with constant values. - ConstantInt *CILength = dyn_cast<ConstantInt>(II->getArgOperand(1)); - ConstantInt *CIIndex = dyn_cast<ConstantInt>(II->getArgOperand(2)); - - // Attempt to simplify to a constant or shuffle vector. - if (Value *V = simplifyX86extrq(*II, Op0, CILength, CIIndex, Builder)) - return replaceInstUsesWith(*II, V); - - // EXTRQI only uses the lowest 64-bits of the first 128-bit vector - // operand. - if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth, 1)) - return replaceOperand(*II, 0, V); - break; - } - - case Intrinsic::x86_sse4a_insertq: { - Value *Op0 = II->getArgOperand(0); - Value *Op1 = II->getArgOperand(1); - unsigned VWidth = cast<VectorType>(Op0->getType())->getNumElements(); - assert(Op0->getType()->getPrimitiveSizeInBits() == 128 && - Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth == 2 && - cast<VectorType>(Op1->getType())->getNumElements() == 2 && - "Unexpected operand size"); - - // See if we're dealing with constant values. - Constant *C1 = dyn_cast<Constant>(Op1); - ConstantInt *CI11 = - C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)1)) - : nullptr; - - // Attempt to simplify to a constant, shuffle vector or INSERTQI call. 
- if (CI11) { - const APInt &V11 = CI11->getValue(); - APInt Len = V11.zextOrTrunc(6); - APInt Idx = V11.lshr(8).zextOrTrunc(6); - if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, Builder)) - return replaceInstUsesWith(*II, V); - } - - // INSERTQ only uses the lowest 64-bits of the first 128-bit vector - // operand. - if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth, 1)) - return replaceOperand(*II, 0, V); - break; - } - - case Intrinsic::x86_sse4a_insertqi: { - // INSERTQI: Extract lowest Length bits from lower half of second source and - // insert over first source starting at Index bit. The upper 64-bits are - // undefined. - Value *Op0 = II->getArgOperand(0); - Value *Op1 = II->getArgOperand(1); - unsigned VWidth0 = cast<VectorType>(Op0->getType())->getNumElements(); - unsigned VWidth1 = cast<VectorType>(Op1->getType())->getNumElements(); - assert(Op0->getType()->getPrimitiveSizeInBits() == 128 && - Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth0 == 2 && - VWidth1 == 2 && "Unexpected operand sizes"); - - // See if we're dealing with constant values. - ConstantInt *CILength = dyn_cast<ConstantInt>(II->getArgOperand(2)); - ConstantInt *CIIndex = dyn_cast<ConstantInt>(II->getArgOperand(3)); - - // Attempt to simplify to a constant or shuffle vector. - if (CILength && CIIndex) { - APInt Len = CILength->getValue().zextOrTrunc(6); - APInt Idx = CIIndex->getValue().zextOrTrunc(6); - if (Value *V = simplifyX86insertq(*II, Op0, Op1, Len, Idx, Builder)) - return replaceInstUsesWith(*II, V); - } - - // INSERTQI only uses the lowest 64-bits of the first two 128-bit vector - // operands. - bool MadeChange = false; - if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth0, 1)) { - replaceOperand(*II, 0, V); - MadeChange = true; - } - if (Value *V = SimplifyDemandedVectorEltsLow(Op1, VWidth1, 1)) { - replaceOperand(*II, 1, V); - MadeChange = true; - } - if (MadeChange) - return II; - break; - } - - case Intrinsic::x86_sse41_pblendvb: - case Intrinsic::x86_sse41_blendvps: - case Intrinsic::x86_sse41_blendvpd: - case Intrinsic::x86_avx_blendv_ps_256: - case Intrinsic::x86_avx_blendv_pd_256: - case Intrinsic::x86_avx2_pblendvb: { - // fold (blend A, A, Mask) -> A - Value *Op0 = II->getArgOperand(0); - Value *Op1 = II->getArgOperand(1); - Value *Mask = II->getArgOperand(2); - if (Op0 == Op1) - return replaceInstUsesWith(CI, Op0); - - // Zero Mask - select 1st argument. - if (isa<ConstantAggregateZero>(Mask)) - return replaceInstUsesWith(CI, Op0); - - // Constant Mask - select 1st/2nd argument lane based on top bit of mask. - if (auto *ConstantMask = dyn_cast<ConstantDataVector>(Mask)) { - Constant *NewSelector = getNegativeIsTrueBoolVec(ConstantMask); - return SelectInst::Create(NewSelector, Op1, Op0, "blendv"); - } - - // Convert to a vector select if we can bypass casts and find a boolean - // vector condition value. 
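A standalone sketch of the blendv semantics behind the folds above: the sign bit of each mask lane picks operand 1 (set) or operand 0 (clear), which is why blend(A, A, Mask) is A, an all-zero mask selects operand 0, and a constant mask becomes an ordinary vector select:

#include <array>
#include <cstdint>
#include <cstdio>

// For each lane, the sign bit of the mask picks Op1 (set) or Op0 (clear).
static std::array<int32_t, 4> blendv(const std::array<int32_t, 4> &Op0,
                                     const std::array<int32_t, 4> &Op1,
                                     const std::array<int32_t, 4> &Mask) {
  std::array<int32_t, 4> R{};
  for (int I = 0; I != 4; ++I)
    R[I] = Mask[I] < 0 ? Op1[I] : Op0[I];
  return R;
}

int main() {
  std::array<int32_t, 4> A = {1, 2, 3, 4}, B = {5, 6, 7, 8};
  std::array<int32_t, 4> M = {-1, 0, -1, 0};
  for (int32_t V : blendv(A, B, M))
    std::printf("%d ", V); // 5 2 7 4
  std::printf("\n");
}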
- Value *BoolVec; - Mask = peekThroughBitcast(Mask); - if (match(Mask, m_SExt(m_Value(BoolVec))) && - BoolVec->getType()->isVectorTy() && - BoolVec->getType()->getScalarSizeInBits() == 1) { - assert(Mask->getType()->getPrimitiveSizeInBits() == - II->getType()->getPrimitiveSizeInBits() && - "Not expecting mask and operands with different sizes"); - - unsigned NumMaskElts = - cast<VectorType>(Mask->getType())->getNumElements(); - unsigned NumOperandElts = - cast<VectorType>(II->getType())->getNumElements(); - if (NumMaskElts == NumOperandElts) - return SelectInst::Create(BoolVec, Op1, Op0); - - // If the mask has less elements than the operands, each mask bit maps to - // multiple elements of the operands. Bitcast back and forth. - if (NumMaskElts < NumOperandElts) { - Value *CastOp0 = Builder.CreateBitCast(Op0, Mask->getType()); - Value *CastOp1 = Builder.CreateBitCast(Op1, Mask->getType()); - Value *Sel = Builder.CreateSelect(BoolVec, CastOp1, CastOp0); - return new BitCastInst(Sel, II->getType()); - } - } - - break; - } - - case Intrinsic::x86_ssse3_pshuf_b_128: - case Intrinsic::x86_avx2_pshuf_b: - case Intrinsic::x86_avx512_pshuf_b_512: - if (Value *V = simplifyX86pshufb(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_avx_vpermilvar_ps: - case Intrinsic::x86_avx_vpermilvar_ps_256: - case Intrinsic::x86_avx512_vpermilvar_ps_512: - case Intrinsic::x86_avx_vpermilvar_pd: - case Intrinsic::x86_avx_vpermilvar_pd_256: - case Intrinsic::x86_avx512_vpermilvar_pd_512: - if (Value *V = simplifyX86vpermilvar(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_avx2_permd: - case Intrinsic::x86_avx2_permps: - case Intrinsic::x86_avx512_permvar_df_256: - case Intrinsic::x86_avx512_permvar_df_512: - case Intrinsic::x86_avx512_permvar_di_256: - case Intrinsic::x86_avx512_permvar_di_512: - case Intrinsic::x86_avx512_permvar_hi_128: - case Intrinsic::x86_avx512_permvar_hi_256: - case Intrinsic::x86_avx512_permvar_hi_512: - case Intrinsic::x86_avx512_permvar_qi_128: - case Intrinsic::x86_avx512_permvar_qi_256: - case Intrinsic::x86_avx512_permvar_qi_512: - case Intrinsic::x86_avx512_permvar_sf_512: - case Intrinsic::x86_avx512_permvar_si_512: - if (Value *V = simplifyX86vpermv(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; - - case Intrinsic::x86_avx_maskload_ps: - case Intrinsic::x86_avx_maskload_pd: - case Intrinsic::x86_avx_maskload_ps_256: - case Intrinsic::x86_avx_maskload_pd_256: - case Intrinsic::x86_avx2_maskload_d: - case Intrinsic::x86_avx2_maskload_q: - case Intrinsic::x86_avx2_maskload_d_256: - case Intrinsic::x86_avx2_maskload_q_256: - if (Instruction *I = simplifyX86MaskedLoad(*II, *this)) - return I; - break; - - case Intrinsic::x86_sse2_maskmov_dqu: - case Intrinsic::x86_avx_maskstore_ps: - case Intrinsic::x86_avx_maskstore_pd: - case Intrinsic::x86_avx_maskstore_ps_256: - case Intrinsic::x86_avx_maskstore_pd_256: - case Intrinsic::x86_avx2_maskstore_d: - case Intrinsic::x86_avx2_maskstore_q: - case Intrinsic::x86_avx2_maskstore_d_256: - case Intrinsic::x86_avx2_maskstore_q_256: - if (simplifyX86MaskedStore(*II, *this)) - return nullptr; - break; - - case Intrinsic::x86_addcarry_32: - case Intrinsic::x86_addcarry_64: - if (Value *V = simplifyX86addcarry(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; - - case Intrinsic::ppc_altivec_vperm: - // Turn vperm(V1,V2,mask) -> shuffle(V1,V2,mask) if mask is a constant. 
- // Note that ppc_altivec_vperm has a big-endian bias, so when creating - // a vectorshuffle for little endian, we must undo the transformation - // performed on vec_perm in altivec.h. That is, we must complement - // the permutation mask with respect to 31 and reverse the order of - // V1 and V2. - if (Constant *Mask = dyn_cast<Constant>(II->getArgOperand(2))) { - assert(cast<VectorType>(Mask->getType())->getNumElements() == 16 && - "Bad type for intrinsic!"); - - // Check that all of the elements are integer constants or undefs. - bool AllEltsOk = true; - for (unsigned i = 0; i != 16; ++i) { - Constant *Elt = Mask->getAggregateElement(i); - if (!Elt || !(isa<ConstantInt>(Elt) || isa<UndefValue>(Elt))) { - AllEltsOk = false; - break; - } - } - - if (AllEltsOk) { - // Cast the input vectors to byte vectors. - Value *Op0 = Builder.CreateBitCast(II->getArgOperand(0), - Mask->getType()); - Value *Op1 = Builder.CreateBitCast(II->getArgOperand(1), - Mask->getType()); - Value *Result = UndefValue::get(Op0->getType()); - - // Only extract each element once. - Value *ExtractedElts[32]; - memset(ExtractedElts, 0, sizeof(ExtractedElts)); - - for (unsigned i = 0; i != 16; ++i) { - if (isa<UndefValue>(Mask->getAggregateElement(i))) - continue; - unsigned Idx = - cast<ConstantInt>(Mask->getAggregateElement(i))->getZExtValue(); - Idx &= 31; // Match the hardware behavior. - if (DL.isLittleEndian()) - Idx = 31 - Idx; - - if (!ExtractedElts[Idx]) { - Value *Op0ToUse = (DL.isLittleEndian()) ? Op1 : Op0; - Value *Op1ToUse = (DL.isLittleEndian()) ? Op0 : Op1; - ExtractedElts[Idx] = - Builder.CreateExtractElement(Idx < 16 ? Op0ToUse : Op1ToUse, - Builder.getInt32(Idx&15)); - } - - // Insert this value into the result vector. - Result = Builder.CreateInsertElement(Result, ExtractedElts[Idx], - Builder.getInt32(i)); - } - return CastInst::Create(Instruction::BitCast, Result, CI.getType()); - } - } - break; - - case Intrinsic::arm_neon_vld1: { - Align MemAlign = getKnownAlignment(II->getArgOperand(0), DL, II, &AC, &DT); - if (Value *V = simplifyNeonVld1(*II, MemAlign.value(), Builder)) - return replaceInstUsesWith(*II, V); - break; - } - - case Intrinsic::arm_neon_vld2: - case Intrinsic::arm_neon_vld3: - case Intrinsic::arm_neon_vld4: - case Intrinsic::arm_neon_vld2lane: - case Intrinsic::arm_neon_vld3lane: - case Intrinsic::arm_neon_vld4lane: - case Intrinsic::arm_neon_vst1: - case Intrinsic::arm_neon_vst2: - case Intrinsic::arm_neon_vst3: - case Intrinsic::arm_neon_vst4: - case Intrinsic::arm_neon_vst2lane: - case Intrinsic::arm_neon_vst3lane: - case Intrinsic::arm_neon_vst4lane: { - Align MemAlign = getKnownAlignment(II->getArgOperand(0), DL, II, &AC, &DT); - unsigned AlignArg = II->getNumArgOperands() - 1; - Value *AlignArgOp = II->getArgOperand(AlignArg); - MaybeAlign Align = cast<ConstantInt>(AlignArgOp)->getMaybeAlignValue(); - if (Align && *Align < MemAlign) - return replaceOperand(*II, AlignArg, - ConstantInt::get(Type::getInt32Ty(II->getContext()), - MemAlign.value(), false)); - break; - } case Intrinsic::arm_neon_vtbl1: case Intrinsic::aarch64_neon_tbl1: @@ -3453,690 +1408,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; } - case Intrinsic::arm_mve_pred_i2v: { - Value *Arg = II->getArgOperand(0); - Value *ArgArg; - if (match(Arg, m_Intrinsic<Intrinsic::arm_mve_pred_v2i>(m_Value(ArgArg))) && - II->getType() == ArgArg->getType()) - return replaceInstUsesWith(*II, ArgArg); - Constant *XorMask; - if (match(Arg, - m_Xor(m_Intrinsic<Intrinsic::arm_mve_pred_v2i>(m_Value(ArgArg)), - 
m_Constant(XorMask))) && - II->getType() == ArgArg->getType()) { - if (auto *CI = dyn_cast<ConstantInt>(XorMask)) { - if (CI->getValue().trunc(16).isAllOnesValue()) { - auto TrueVector = Builder.CreateVectorSplat( - cast<VectorType>(II->getType())->getNumElements(), - Builder.getTrue()); - return BinaryOperator::Create(Instruction::Xor, ArgArg, TrueVector); - } - } - } - KnownBits ScalarKnown(32); - if (SimplifyDemandedBits(II, 0, APInt::getLowBitsSet(32, 16), - ScalarKnown, 0)) - return II; - break; - } - case Intrinsic::arm_mve_pred_v2i: { - Value *Arg = II->getArgOperand(0); - Value *ArgArg; - if (match(Arg, m_Intrinsic<Intrinsic::arm_mve_pred_i2v>(m_Value(ArgArg)))) - return replaceInstUsesWith(*II, ArgArg); - if (!II->getMetadata(LLVMContext::MD_range)) { - Type *IntTy32 = Type::getInt32Ty(II->getContext()); - Metadata *M[] = { - ConstantAsMetadata::get(ConstantInt::get(IntTy32, 0)), - ConstantAsMetadata::get(ConstantInt::get(IntTy32, 0xFFFF)) - }; - II->setMetadata(LLVMContext::MD_range, MDNode::get(II->getContext(), M)); - return II; - } - break; - } - case Intrinsic::arm_mve_vadc: - case Intrinsic::arm_mve_vadc_predicated: { - unsigned CarryOp = - (II->getIntrinsicID() == Intrinsic::arm_mve_vadc_predicated) ? 3 : 2; - assert(II->getArgOperand(CarryOp)->getType()->getScalarSizeInBits() == 32 && - "Bad type for intrinsic!"); - - KnownBits CarryKnown(32); - if (SimplifyDemandedBits(II, CarryOp, APInt::getOneBitSet(32, 29), - CarryKnown)) - return II; - break; - } - case Intrinsic::amdgcn_rcp: { - Value *Src = II->getArgOperand(0); - - // TODO: Move to ConstantFolding/InstSimplify? - if (isa<UndefValue>(Src)) { - Type *Ty = II->getType(); - auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics())); - return replaceInstUsesWith(CI, QNaN); - } - - if (II->isStrictFP()) - break; - - if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) { - const APFloat &ArgVal = C->getValueAPF(); - APFloat Val(ArgVal.getSemantics(), 1); - Val.divide(ArgVal, APFloat::rmNearestTiesToEven); - - // This is more precise than the instruction may give. - // - // TODO: The instruction always flushes denormal results (except for f16), - // should this also? - return replaceInstUsesWith(CI, ConstantFP::get(II->getContext(), Val)); - } - - break; - } - case Intrinsic::amdgcn_rsq: { - Value *Src = II->getArgOperand(0); - - // TODO: Move to ConstantFolding/InstSimplify? - if (isa<UndefValue>(Src)) { - Type *Ty = II->getType(); - auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics())); - return replaceInstUsesWith(CI, QNaN); - } - - break; - } - case Intrinsic::amdgcn_frexp_mant: - case Intrinsic::amdgcn_frexp_exp: { - Value *Src = II->getArgOperand(0); - if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) { - int Exp; - APFloat Significand = frexp(C->getValueAPF(), Exp, - APFloat::rmNearestTiesToEven); - - if (IID == Intrinsic::amdgcn_frexp_mant) { - return replaceInstUsesWith(CI, ConstantFP::get(II->getContext(), - Significand)); - } - - // Match instruction special case behavior. 
- if (Exp == APFloat::IEK_NaN || Exp == APFloat::IEK_Inf) - Exp = 0; - - return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Exp)); - } - - if (isa<UndefValue>(Src)) - return replaceInstUsesWith(CI, UndefValue::get(II->getType())); - - break; - } - case Intrinsic::amdgcn_class: { - enum { - S_NAN = 1 << 0, // Signaling NaN - Q_NAN = 1 << 1, // Quiet NaN - N_INFINITY = 1 << 2, // Negative infinity - N_NORMAL = 1 << 3, // Negative normal - N_SUBNORMAL = 1 << 4, // Negative subnormal - N_ZERO = 1 << 5, // Negative zero - P_ZERO = 1 << 6, // Positive zero - P_SUBNORMAL = 1 << 7, // Positive subnormal - P_NORMAL = 1 << 8, // Positive normal - P_INFINITY = 1 << 9 // Positive infinity - }; - - const uint32_t FullMask = S_NAN | Q_NAN | N_INFINITY | N_NORMAL | - N_SUBNORMAL | N_ZERO | P_ZERO | P_SUBNORMAL | P_NORMAL | P_INFINITY; - - Value *Src0 = II->getArgOperand(0); - Value *Src1 = II->getArgOperand(1); - const ConstantInt *CMask = dyn_cast<ConstantInt>(Src1); - if (!CMask) { - if (isa<UndefValue>(Src0)) - return replaceInstUsesWith(*II, UndefValue::get(II->getType())); - - if (isa<UndefValue>(Src1)) - return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), false)); - break; - } - - uint32_t Mask = CMask->getZExtValue(); - - // If all tests are made, it doesn't matter what the value is. - if ((Mask & FullMask) == FullMask) - return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), true)); - - if ((Mask & FullMask) == 0) - return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), false)); - - if (Mask == (S_NAN | Q_NAN)) { - // Equivalent of isnan. Replace with standard fcmp. - Value *FCmp = Builder.CreateFCmpUNO(Src0, Src0); - FCmp->takeName(II); - return replaceInstUsesWith(*II, FCmp); - } - - if (Mask == (N_ZERO | P_ZERO)) { - // Equivalent of == 0. 
- Value *FCmp = Builder.CreateFCmpOEQ( - Src0, ConstantFP::get(Src0->getType(), 0.0)); - - FCmp->takeName(II); - return replaceInstUsesWith(*II, FCmp); - } - - // fp_class (nnan x), qnan|snan|other -> fp_class (nnan x), other - if (((Mask & S_NAN) || (Mask & Q_NAN)) && isKnownNeverNaN(Src0, &TLI)) - return replaceOperand(*II, 1, ConstantInt::get(Src1->getType(), - Mask & ~(S_NAN | Q_NAN))); - - const ConstantFP *CVal = dyn_cast<ConstantFP>(Src0); - if (!CVal) { - if (isa<UndefValue>(Src0)) - return replaceInstUsesWith(*II, UndefValue::get(II->getType())); - - // Clamp mask to used bits - if ((Mask & FullMask) != Mask) { - CallInst *NewCall = Builder.CreateCall(II->getCalledFunction(), - { Src0, ConstantInt::get(Src1->getType(), Mask & FullMask) } - ); - - NewCall->takeName(II); - return replaceInstUsesWith(*II, NewCall); - } - - break; - } - - const APFloat &Val = CVal->getValueAPF(); - - bool Result = - ((Mask & S_NAN) && Val.isNaN() && Val.isSignaling()) || - ((Mask & Q_NAN) && Val.isNaN() && !Val.isSignaling()) || - ((Mask & N_INFINITY) && Val.isInfinity() && Val.isNegative()) || - ((Mask & N_NORMAL) && Val.isNormal() && Val.isNegative()) || - ((Mask & N_SUBNORMAL) && Val.isDenormal() && Val.isNegative()) || - ((Mask & N_ZERO) && Val.isZero() && Val.isNegative()) || - ((Mask & P_ZERO) && Val.isZero() && !Val.isNegative()) || - ((Mask & P_SUBNORMAL) && Val.isDenormal() && !Val.isNegative()) || - ((Mask & P_NORMAL) && Val.isNormal() && !Val.isNegative()) || - ((Mask & P_INFINITY) && Val.isInfinity() && !Val.isNegative()); - - return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), Result)); - } - case Intrinsic::amdgcn_cvt_pkrtz: { - Value *Src0 = II->getArgOperand(0); - Value *Src1 = II->getArgOperand(1); - if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) { - if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) { - const fltSemantics &HalfSem - = II->getType()->getScalarType()->getFltSemantics(); - bool LosesInfo; - APFloat Val0 = C0->getValueAPF(); - APFloat Val1 = C1->getValueAPF(); - Val0.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo); - Val1.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo); - - Constant *Folded = ConstantVector::get({ - ConstantFP::get(II->getContext(), Val0), - ConstantFP::get(II->getContext(), Val1) }); - return replaceInstUsesWith(*II, Folded); - } - } - - if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1)) - return replaceInstUsesWith(*II, UndefValue::get(II->getType())); - - break; - } - case Intrinsic::amdgcn_cvt_pknorm_i16: - case Intrinsic::amdgcn_cvt_pknorm_u16: - case Intrinsic::amdgcn_cvt_pk_i16: - case Intrinsic::amdgcn_cvt_pk_u16: { - Value *Src0 = II->getArgOperand(0); - Value *Src1 = II->getArgOperand(1); - - if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1)) - return replaceInstUsesWith(*II, UndefValue::get(II->getType())); - - break; - } - case Intrinsic::amdgcn_ubfe: - case Intrinsic::amdgcn_sbfe: { - // Decompose simple cases into standard shifts. - Value *Src = II->getArgOperand(0); - if (isa<UndefValue>(Src)) - return replaceInstUsesWith(*II, Src); - - unsigned Width; - Type *Ty = II->getType(); - unsigned IntSize = Ty->getIntegerBitWidth(); - - ConstantInt *CWidth = dyn_cast<ConstantInt>(II->getArgOperand(2)); - if (CWidth) { - Width = CWidth->getZExtValue(); - if ((Width & (IntSize - 1)) == 0) - return replaceInstUsesWith(*II, ConstantInt::getNullValue(Ty)); - - // Hardware ignores high bits, so remove those. 
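The two special-case folds just above rest on simple predicate identities: a class mask covering exactly the NaN classes is the same predicate as an unordered self-compare, and a mask covering exactly the two zero classes is the same predicate as an ordered compare with 0.0. A standalone C++ sketch, illustrative only; classTest and its mask constants are invented for the sketch and model just the classes used here:

#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical scalar model of the intrinsic's mask test for the classes used
// in the folds above (signaling vs. quiet NaN is not distinguished here).
enum : unsigned { S_NAN = 1, Q_NAN = 2, N_ZERO = 32, P_ZERO = 64 };

static bool classTest(float X, unsigned Mask) {
  bool R = false;
  if (Mask & (S_NAN | Q_NAN))
    R |= std::isnan(X);
  if (Mask & N_ZERO)
    R |= (X == 0.0f) && std::signbit(X);
  if (Mask & P_ZERO)
    R |= (X == 0.0f) && !std::signbit(X);
  return R;
}

int main() {
  const float Vals[] = {0.0f, -0.0f, 1.5f, -2.25f,
                        std::numeric_limits<float>::infinity(),
                        std::numeric_limits<float>::quiet_NaN()};
  for (float V : Vals) {
    assert(classTest(V, S_NAN | Q_NAN) == (V != V));      // fcmp uno V, V
    assert(classTest(V, N_ZERO | P_ZERO) == (V == 0.0f)); // fcmp oeq V, 0.0
  }
  return 0;
}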
- if (Width >= IntSize) - return replaceOperand(*II, 2, ConstantInt::get(CWidth->getType(), - Width & (IntSize - 1))); - } - - unsigned Offset; - ConstantInt *COffset = dyn_cast<ConstantInt>(II->getArgOperand(1)); - if (COffset) { - Offset = COffset->getZExtValue(); - if (Offset >= IntSize) - return replaceOperand(*II, 1, ConstantInt::get(COffset->getType(), - Offset & (IntSize - 1))); - } - - bool Signed = IID == Intrinsic::amdgcn_sbfe; - - if (!CWidth || !COffset) - break; - - // The case of Width == 0 is handled above, which makes this tranformation - // safe. If Width == 0, then the ashr and lshr instructions become poison - // value since the shift amount would be equal to the bit size. - assert(Width != 0); - - // TODO: This allows folding to undef when the hardware has specific - // behavior? - if (Offset + Width < IntSize) { - Value *Shl = Builder.CreateShl(Src, IntSize - Offset - Width); - Value *RightShift = Signed ? Builder.CreateAShr(Shl, IntSize - Width) - : Builder.CreateLShr(Shl, IntSize - Width); - RightShift->takeName(II); - return replaceInstUsesWith(*II, RightShift); - } - - Value *RightShift = Signed ? Builder.CreateAShr(Src, Offset) - : Builder.CreateLShr(Src, Offset); - - RightShift->takeName(II); - return replaceInstUsesWith(*II, RightShift); - } - case Intrinsic::amdgcn_exp: - case Intrinsic::amdgcn_exp_compr: { - ConstantInt *En = cast<ConstantInt>(II->getArgOperand(1)); - unsigned EnBits = En->getZExtValue(); - if (EnBits == 0xf) - break; // All inputs enabled. - - bool IsCompr = IID == Intrinsic::amdgcn_exp_compr; - bool Changed = false; - for (int I = 0; I < (IsCompr ? 2 : 4); ++I) { - if ((!IsCompr && (EnBits & (1 << I)) == 0) || - (IsCompr && ((EnBits & (0x3 << (2 * I))) == 0))) { - Value *Src = II->getArgOperand(I + 2); - if (!isa<UndefValue>(Src)) { - replaceOperand(*II, I + 2, UndefValue::get(Src->getType())); - Changed = true; - } - } - } - - if (Changed) - return II; - - break; - } - case Intrinsic::amdgcn_fmed3: { - // Note this does not preserve proper sNaN behavior if IEEE-mode is enabled - // for the shader. - - Value *Src0 = II->getArgOperand(0); - Value *Src1 = II->getArgOperand(1); - Value *Src2 = II->getArgOperand(2); - - // Checking for NaN before canonicalization provides better fidelity when - // mapping other operations onto fmed3 since the order of operands is - // unchanged. - CallInst *NewCall = nullptr; - if (match(Src0, m_NaN()) || isa<UndefValue>(Src0)) { - NewCall = Builder.CreateMinNum(Src1, Src2); - } else if (match(Src1, m_NaN()) || isa<UndefValue>(Src1)) { - NewCall = Builder.CreateMinNum(Src0, Src2); - } else if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) { - NewCall = Builder.CreateMaxNum(Src0, Src1); - } - - if (NewCall) { - NewCall->copyFastMathFlags(II); - NewCall->takeName(II); - return replaceInstUsesWith(*II, NewCall); - } - - bool Swap = false; - // Canonicalize constants to RHS operands. 
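The amdgcn_ubfe/amdgcn_sbfe decomposition above is the usual shift-pair bitfield extract. A standalone C++ sketch for the 32-bit case, assuming Width != 0, Offset + Width <= 32, and the usual two's-complement behaviour for narrowing casts and signed right shifts:

#include <cassert>
#include <cstdint>

// Unsigned extract of Width bits starting at Offset: shift the field to the
// top, then shift it back down with a logical right shift.
static uint32_t ubfe32(uint32_t Src, unsigned Offset, unsigned Width) {
  if (Offset + Width < 32)
    return (Src << (32 - Offset - Width)) >> (32 - Width); // shl + lshr
  return Src >> Offset;                                    // lshr only
}

// Signed extract: same shifts, but the right shift is arithmetic.
static int32_t sbfe32(int32_t Src, unsigned Offset, unsigned Width) {
  if (Offset + Width < 32)
    return (int32_t)((uint32_t)Src << (32 - Offset - Width)) >> (32 - Width); // shl + ashr
  return Src >> Offset;                                    // ashr only
}

int main() {
  assert(ubfe32(0xABCD1234u, 8, 8) == 0x12);
  assert(sbfe32((int32_t)0xABCD1234, 8, 8) == 0x12);
  assert(sbfe32((int32_t)0xABCD8234, 8, 8) == (int32_t)0xFFFFFF82); // field sign bit set
  return 0;
}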
- // - // fmed3(c0, x, c1) -> fmed3(x, c0, c1) - if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { - std::swap(Src0, Src1); - Swap = true; - } - - if (isa<Constant>(Src1) && !isa<Constant>(Src2)) { - std::swap(Src1, Src2); - Swap = true; - } - - if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { - std::swap(Src0, Src1); - Swap = true; - } - - if (Swap) { - II->setArgOperand(0, Src0); - II->setArgOperand(1, Src1); - II->setArgOperand(2, Src2); - return II; - } - - if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) { - if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) { - if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) { - APFloat Result = fmed3AMDGCN(C0->getValueAPF(), C1->getValueAPF(), - C2->getValueAPF()); - return replaceInstUsesWith(*II, - ConstantFP::get(Builder.getContext(), Result)); - } - } - } - - break; - } - case Intrinsic::amdgcn_icmp: - case Intrinsic::amdgcn_fcmp: { - const ConstantInt *CC = cast<ConstantInt>(II->getArgOperand(2)); - // Guard against invalid arguments. - int64_t CCVal = CC->getZExtValue(); - bool IsInteger = IID == Intrinsic::amdgcn_icmp; - if ((IsInteger && (CCVal < CmpInst::FIRST_ICMP_PREDICATE || - CCVal > CmpInst::LAST_ICMP_PREDICATE)) || - (!IsInteger && (CCVal < CmpInst::FIRST_FCMP_PREDICATE || - CCVal > CmpInst::LAST_FCMP_PREDICATE))) - break; - - Value *Src0 = II->getArgOperand(0); - Value *Src1 = II->getArgOperand(1); - - if (auto *CSrc0 = dyn_cast<Constant>(Src0)) { - if (auto *CSrc1 = dyn_cast<Constant>(Src1)) { - Constant *CCmp = ConstantExpr::getCompare(CCVal, CSrc0, CSrc1); - if (CCmp->isNullValue()) { - return replaceInstUsesWith( - *II, ConstantExpr::getSExt(CCmp, II->getType())); - } - - // The result of V_ICMP/V_FCMP assembly instructions (which this - // intrinsic exposes) is one bit per thread, masked with the EXEC - // register (which contains the bitmask of live threads). So a - // comparison that always returns true is the same as a read of the - // EXEC register. - Function *NewF = Intrinsic::getDeclaration( - II->getModule(), Intrinsic::read_register, II->getType()); - Metadata *MDArgs[] = {MDString::get(II->getContext(), "exec")}; - MDNode *MD = MDNode::get(II->getContext(), MDArgs); - Value *Args[] = {MetadataAsValue::get(II->getContext(), MD)}; - CallInst *NewCall = Builder.CreateCall(NewF, Args); - NewCall->addAttribute(AttributeList::FunctionIndex, - Attribute::Convergent); - NewCall->takeName(II); - return replaceInstUsesWith(*II, NewCall); - } - - // Canonicalize constants to RHS. 
- CmpInst::Predicate SwapPred - = CmpInst::getSwappedPredicate(static_cast<CmpInst::Predicate>(CCVal)); - II->setArgOperand(0, Src1); - II->setArgOperand(1, Src0); - II->setArgOperand(2, ConstantInt::get(CC->getType(), - static_cast<int>(SwapPred))); - return II; - } - - if (CCVal != CmpInst::ICMP_EQ && CCVal != CmpInst::ICMP_NE) - break; - - // Canonicalize compare eq with true value to compare != 0 - // llvm.amdgcn.icmp(zext (i1 x), 1, eq) - // -> llvm.amdgcn.icmp(zext (i1 x), 0, ne) - // llvm.amdgcn.icmp(sext (i1 x), -1, eq) - // -> llvm.amdgcn.icmp(sext (i1 x), 0, ne) - Value *ExtSrc; - if (CCVal == CmpInst::ICMP_EQ && - ((match(Src1, m_One()) && match(Src0, m_ZExt(m_Value(ExtSrc)))) || - (match(Src1, m_AllOnes()) && match(Src0, m_SExt(m_Value(ExtSrc))))) && - ExtSrc->getType()->isIntegerTy(1)) { - replaceOperand(*II, 1, ConstantInt::getNullValue(Src1->getType())); - replaceOperand(*II, 2, ConstantInt::get(CC->getType(), CmpInst::ICMP_NE)); - return II; - } - - CmpInst::Predicate SrcPred; - Value *SrcLHS; - Value *SrcRHS; - - // Fold compare eq/ne with 0 from a compare result as the predicate to the - // intrinsic. The typical use is a wave vote function in the library, which - // will be fed from a user code condition compared with 0. Fold in the - // redundant compare. - - // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, ne) - // -> llvm.amdgcn.[if]cmp(a, b, pred) - // - // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, eq) - // -> llvm.amdgcn.[if]cmp(a, b, inv pred) - if (match(Src1, m_Zero()) && - match(Src0, - m_ZExtOrSExt(m_Cmp(SrcPred, m_Value(SrcLHS), m_Value(SrcRHS))))) { - if (CCVal == CmpInst::ICMP_EQ) - SrcPred = CmpInst::getInversePredicate(SrcPred); - - Intrinsic::ID NewIID = CmpInst::isFPPredicate(SrcPred) ? - Intrinsic::amdgcn_fcmp : Intrinsic::amdgcn_icmp; - - Type *Ty = SrcLHS->getType(); - if (auto *CmpType = dyn_cast<IntegerType>(Ty)) { - // Promote to next legal integer type. - unsigned Width = CmpType->getBitWidth(); - unsigned NewWidth = Width; - - // Don't do anything for i1 comparisons. - if (Width == 1) - break; - - if (Width <= 16) - NewWidth = 16; - else if (Width <= 32) - NewWidth = 32; - else if (Width <= 64) - NewWidth = 64; - else if (Width > 64) - break; // Can't handle this. - - if (Width != NewWidth) { - IntegerType *CmpTy = Builder.getIntNTy(NewWidth); - if (CmpInst::isSigned(SrcPred)) { - SrcLHS = Builder.CreateSExt(SrcLHS, CmpTy); - SrcRHS = Builder.CreateSExt(SrcRHS, CmpTy); - } else { - SrcLHS = Builder.CreateZExt(SrcLHS, CmpTy); - SrcRHS = Builder.CreateZExt(SrcRHS, CmpTy); - } - } - } else if (!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy()) - break; - - Function *NewF = - Intrinsic::getDeclaration(II->getModule(), NewIID, - { II->getType(), - SrcLHS->getType() }); - Value *Args[] = { SrcLHS, SrcRHS, - ConstantInt::get(CC->getType(), SrcPred) }; - CallInst *NewCall = Builder.CreateCall(NewF, Args); - NewCall->takeName(II); - return replaceInstUsesWith(*II, NewCall); - } - - break; - } - case Intrinsic::amdgcn_ballot: { - if (auto *Src = dyn_cast<ConstantInt>(II->getArgOperand(0))) { - if (Src->isZero()) { - // amdgcn.ballot(i1 0) is zero. - return replaceInstUsesWith(*II, Constant::getNullValue(II->getType())); - } - - if (Src->isOne()) { - // amdgcn.ballot(i1 1) is exec. 
- const char *RegName = "exec"; - if (II->getType()->isIntegerTy(32)) - RegName = "exec_lo"; - else if (!II->getType()->isIntegerTy(64)) - break; - - Function *NewF = Intrinsic::getDeclaration( - II->getModule(), Intrinsic::read_register, II->getType()); - Metadata *MDArgs[] = {MDString::get(II->getContext(), RegName)}; - MDNode *MD = MDNode::get(II->getContext(), MDArgs); - Value *Args[] = {MetadataAsValue::get(II->getContext(), MD)}; - CallInst *NewCall = Builder.CreateCall(NewF, Args); - NewCall->addAttribute(AttributeList::FunctionIndex, - Attribute::Convergent); - NewCall->takeName(II); - return replaceInstUsesWith(*II, NewCall); - } - } - break; - } - case Intrinsic::amdgcn_wqm_vote: { - // wqm_vote is identity when the argument is constant. - if (!isa<Constant>(II->getArgOperand(0))) - break; - - return replaceInstUsesWith(*II, II->getArgOperand(0)); - } - case Intrinsic::amdgcn_kill: { - const ConstantInt *C = dyn_cast<ConstantInt>(II->getArgOperand(0)); - if (!C || !C->getZExtValue()) - break; - - // amdgcn.kill(i1 1) is a no-op - return eraseInstFromFunction(CI); - } - case Intrinsic::amdgcn_update_dpp: { - Value *Old = II->getArgOperand(0); - - auto BC = cast<ConstantInt>(II->getArgOperand(5)); - auto RM = cast<ConstantInt>(II->getArgOperand(3)); - auto BM = cast<ConstantInt>(II->getArgOperand(4)); - if (BC->isZeroValue() || - RM->getZExtValue() != 0xF || - BM->getZExtValue() != 0xF || - isa<UndefValue>(Old)) - break; - - // If bound_ctrl = 1, row mask = bank mask = 0xf we can omit old value. - return replaceOperand(*II, 0, UndefValue::get(Old->getType())); - } - case Intrinsic::amdgcn_permlane16: - case Intrinsic::amdgcn_permlanex16: { - // Discard vdst_in if it's not going to be read. - Value *VDstIn = II->getArgOperand(0); - if (isa<UndefValue>(VDstIn)) - break; - - ConstantInt *FetchInvalid = cast<ConstantInt>(II->getArgOperand(4)); - ConstantInt *BoundCtrl = cast<ConstantInt>(II->getArgOperand(5)); - if (!FetchInvalid->getZExtValue() && !BoundCtrl->getZExtValue()) - break; - - return replaceOperand(*II, 0, UndefValue::get(VDstIn->getType())); - } - case Intrinsic::amdgcn_readfirstlane: - case Intrinsic::amdgcn_readlane: { - // A constant value is trivially uniform. - if (Constant *C = dyn_cast<Constant>(II->getArgOperand(0))) - return replaceInstUsesWith(*II, C); - - // The rest of these may not be safe if the exec may not be the same between - // the def and use. - Value *Src = II->getArgOperand(0); - Instruction *SrcInst = dyn_cast<Instruction>(Src); - if (SrcInst && SrcInst->getParent() != II->getParent()) - break; - - // readfirstlane (readfirstlane x) -> readfirstlane x - // readlane (readfirstlane x), y -> readfirstlane x - if (match(Src, m_Intrinsic<Intrinsic::amdgcn_readfirstlane>())) - return replaceInstUsesWith(*II, Src); - - if (IID == Intrinsic::amdgcn_readfirstlane) { - // readfirstlane (readlane x, y) -> readlane x, y - if (match(Src, m_Intrinsic<Intrinsic::amdgcn_readlane>())) - return replaceInstUsesWith(*II, Src); - } else { - // readlane (readlane x, y), y -> readlane x, y - if (match(Src, m_Intrinsic<Intrinsic::amdgcn_readlane>( - m_Value(), m_Specific(II->getArgOperand(1))))) - return replaceInstUsesWith(*II, Src); - } - - break; - } - case Intrinsic::amdgcn_ldexp: { - // FIXME: This doesn't introduce new instructions and belongs in - // InstructionSimplify. - Type *Ty = II->getType(); - Value *Op0 = II->getArgOperand(0); - Value *Op1 = II->getArgOperand(1); - - // Folding undef to qnan is safe regardless of the FP mode. 
- if (isa<UndefValue>(Op0)) { - auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics())); - return replaceInstUsesWith(*II, QNaN); - } - - const APFloat *C = nullptr; - match(Op0, m_APFloat(C)); - - // FIXME: Should flush denorms depending on FP mode, but that's ignored - // everywhere else. - // - // These cases should be safe, even with strictfp. - // ldexp(0.0, x) -> 0.0 - // ldexp(-0.0, x) -> -0.0 - // ldexp(inf, x) -> inf - // ldexp(-inf, x) -> -inf - if (C && (C->isZero() || C->isInfinity())) - return replaceInstUsesWith(*II, Op0); - - // With strictfp, be more careful about possibly needing to flush denormals - // or not, and snan behavior depends on ieee_mode. - if (II->isStrictFP()) - break; - - if (C && C->isNaN()) { - // FIXME: We just need to make the nan quiet here, but that's unavailable - // on APFloat, only IEEEfloat - auto *Quieted = ConstantFP::get( - Ty, scalbn(*C, 0, APFloat::rmNearestTiesToEven)); - return replaceInstUsesWith(*II, Quieted); - } - - // ldexp(x, 0) -> x - // ldexp(x, undef) -> x - if (isa<UndefValue>(Op1) || match(Op1, m_ZeroInt())) - return replaceInstUsesWith(*II, Op0); - - break; - } case Intrinsic::hexagon_V6_vandvrt: case Intrinsic::hexagon_V6_vandvrt_128B: { // Simplify Q -> V -> Q conversion. @@ -4238,14 +1509,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { FunctionType *AssumeIntrinsicTy = II->getFunctionType(); Value *AssumeIntrinsic = II->getCalledOperand(); Value *A, *B; - if (match(IIOperand, m_And(m_Value(A), m_Value(B)))) { + if (match(IIOperand, m_LogicalAnd(m_Value(A), m_Value(B)))) { Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, A, OpBundles, II->getName()); Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, B, II->getName()); return eraseInstFromFunction(*II); } // assume(!(a || b)) -> assume(!a); assume(!b); - if (match(IIOperand, m_Not(m_Or(m_Value(A), m_Value(B))))) { + if (match(IIOperand, m_Not(m_LogicalOr(m_Value(A), m_Value(B))))) { Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, Builder.CreateNot(A), OpBundles, II->getName()); Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, @@ -4282,59 +1553,104 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { AC.updateAffectedValues(II); break; } - case Intrinsic::experimental_gc_relocate: { - auto &GCR = *cast<GCRelocateInst>(II); - - // If we have two copies of the same pointer in the statepoint argument - // list, canonicalize to one. This may let us common gc.relocates. - if (GCR.getBasePtr() == GCR.getDerivedPtr() && - GCR.getBasePtrIndex() != GCR.getDerivedPtrIndex()) { - auto *OpIntTy = GCR.getOperand(2)->getType(); - return replaceOperand(*II, 2, - ConstantInt::get(OpIntTy, GCR.getBasePtrIndex())); - } + case Intrinsic::experimental_gc_statepoint: { + GCStatepointInst &GCSP = *cast<GCStatepointInst>(II); + SmallPtrSet<Value *, 32> LiveGcValues; + for (const GCRelocateInst *Reloc : GCSP.getGCRelocates()) { + GCRelocateInst &GCR = *const_cast<GCRelocateInst *>(Reloc); - // Translate facts known about a pointer before relocating into - // facts about the relocate value, while being careful to - // preserve relocation semantics. - Value *DerivedPtr = GCR.getDerivedPtr(); + // Remove the relocation if unused. + if (GCR.use_empty()) { + eraseInstFromFunction(GCR); + continue; + } - // Remove the relocation if unused, note that this check is required - // to prevent the cases below from looping forever. 
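The amdgcn_ldexp constant folds earlier in this hunk (zero, infinity, and zero exponent) mirror ordinary ldexp identities, setting aside the denormal-flushing and strictfp caveats noted in the code. A quick standalone check with std::ldexp:

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const double Inf = std::numeric_limits<double>::infinity();
  assert(std::ldexp(0.0, 7) == 0.0);                    // ldexp(0.0, x)  -> 0.0
  assert(std::ldexp(-0.0, 7) == -0.0 &&
         std::signbit(std::ldexp(-0.0, 7)));            // ldexp(-0.0, x) -> -0.0
  assert(std::ldexp(Inf, -3) == Inf);                   // ldexp(inf, x)  -> inf
  assert(std::ldexp(-Inf, 3) == -Inf);                  // ldexp(-inf, x) -> -inf
  assert(std::ldexp(1.25, 0) == 1.25);                  // ldexp(x, 0)    -> x
  return 0;
}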
- if (II->use_empty()) - return eraseInstFromFunction(*II); + Value *DerivedPtr = GCR.getDerivedPtr(); + Value *BasePtr = GCR.getBasePtr(); - // Undef is undef, even after relocation. - // TODO: provide a hook for this in GCStrategy. This is clearly legal for - // most practical collectors, but there was discussion in the review thread - // about whether it was legal for all possible collectors. - if (isa<UndefValue>(DerivedPtr)) - // Use undef of gc_relocate's type to replace it. - return replaceInstUsesWith(*II, UndefValue::get(II->getType())); - - if (auto *PT = dyn_cast<PointerType>(II->getType())) { - // The relocation of null will be null for most any collector. - // TODO: provide a hook for this in GCStrategy. There might be some - // weird collector this property does not hold for. - if (isa<ConstantPointerNull>(DerivedPtr)) - // Use null-pointer of gc_relocate's type to replace it. - return replaceInstUsesWith(*II, ConstantPointerNull::get(PT)); - - // isKnownNonNull -> nonnull attribute - if (!II->hasRetAttr(Attribute::NonNull) && - isKnownNonZero(DerivedPtr, DL, 0, &AC, II, &DT)) { - II->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); - return II; + // Undef is undef, even after relocation. + if (isa<UndefValue>(DerivedPtr) || isa<UndefValue>(BasePtr)) { + replaceInstUsesWith(GCR, UndefValue::get(GCR.getType())); + eraseInstFromFunction(GCR); + continue; + } + + if (auto *PT = dyn_cast<PointerType>(GCR.getType())) { + // The relocation of null will be null for most any collector. + // TODO: provide a hook for this in GCStrategy. There might be some + // weird collector this property does not hold for. + if (isa<ConstantPointerNull>(DerivedPtr)) { + // Use null-pointer of gc_relocate's type to replace it. + replaceInstUsesWith(GCR, ConstantPointerNull::get(PT)); + eraseInstFromFunction(GCR); + continue; + } + + // isKnownNonNull -> nonnull attribute + if (!GCR.hasRetAttr(Attribute::NonNull) && + isKnownNonZero(DerivedPtr, DL, 0, &AC, II, &DT)) { + GCR.addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + // We discovered new fact, re-check users. + Worklist.pushUsersToWorkList(GCR); + } + } + + // If we have two copies of the same pointer in the statepoint argument + // list, canonicalize to one. This may let us common gc.relocates. + if (GCR.getBasePtr() == GCR.getDerivedPtr() && + GCR.getBasePtrIndex() != GCR.getDerivedPtrIndex()) { + auto *OpIntTy = GCR.getOperand(2)->getType(); + GCR.setOperand(2, ConstantInt::get(OpIntTy, GCR.getBasePtrIndex())); } - } - // TODO: bitcast(relocate(p)) -> relocate(bitcast(p)) - // Canonicalize on the type from the uses to the defs + // TODO: bitcast(relocate(p)) -> relocate(bitcast(p)) + // Canonicalize on the type from the uses to the defs - // TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...) + // TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...) + LiveGcValues.insert(BasePtr); + LiveGcValues.insert(DerivedPtr); + } + Optional<OperandBundleUse> Bundle = + GCSP.getOperandBundle(LLVMContext::OB_gc_live); + unsigned NumOfGCLives = LiveGcValues.size(); + if (!Bundle.hasValue() || NumOfGCLives == Bundle->Inputs.size()) + break; + // We can reduce the size of gc live bundle. 
+ DenseMap<Value *, unsigned> Val2Idx; + std::vector<Value *> NewLiveGc; + for (unsigned I = 0, E = Bundle->Inputs.size(); I < E; ++I) { + Value *V = Bundle->Inputs[I]; + if (Val2Idx.count(V)) + continue; + if (LiveGcValues.count(V)) { + Val2Idx[V] = NewLiveGc.size(); + NewLiveGc.push_back(V); + } else + Val2Idx[V] = NumOfGCLives; + } + // Update all gc.relocates + for (const GCRelocateInst *Reloc : GCSP.getGCRelocates()) { + GCRelocateInst &GCR = *const_cast<GCRelocateInst *>(Reloc); + Value *BasePtr = GCR.getBasePtr(); + assert(Val2Idx.count(BasePtr) && Val2Idx[BasePtr] != NumOfGCLives && + "Missed live gc for base pointer"); + auto *OpIntTy1 = GCR.getOperand(1)->getType(); + GCR.setOperand(1, ConstantInt::get(OpIntTy1, Val2Idx[BasePtr])); + Value *DerivedPtr = GCR.getDerivedPtr(); + assert(Val2Idx.count(DerivedPtr) && Val2Idx[DerivedPtr] != NumOfGCLives && + "Missed live gc for derived pointer"); + auto *OpIntTy2 = GCR.getOperand(2)->getType(); + GCR.setOperand(2, ConstantInt::get(OpIntTy2, Val2Idx[DerivedPtr])); + } + // Create new statepoint instruction. + OperandBundleDef NewBundle("gc-live", NewLiveGc); + if (isa<CallInst>(II)) + return CallInst::CreateWithReplacedBundle(cast<CallInst>(II), NewBundle); + else + return InvokeInst::CreateWithReplacedBundle(cast<InvokeInst>(II), + NewBundle); break; } - case Intrinsic::experimental_guard: { // Is this guard followed by another guard? We scan forward over a small // fixed window of instructions to handle common cases with conditions @@ -4367,12 +1683,114 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::experimental_vector_insert: { + Value *Vec = II->getArgOperand(0); + Value *SubVec = II->getArgOperand(1); + Value *Idx = II->getArgOperand(2); + auto *DstTy = dyn_cast<FixedVectorType>(II->getType()); + auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType()); + auto *SubVecTy = dyn_cast<FixedVectorType>(SubVec->getType()); + + // Only canonicalize if the destination vector, Vec, and SubVec are all + // fixed vectors. + if (DstTy && VecTy && SubVecTy) { + unsigned DstNumElts = DstTy->getNumElements(); + unsigned VecNumElts = VecTy->getNumElements(); + unsigned SubVecNumElts = SubVecTy->getNumElements(); + unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue(); + + // The result of this call is undefined if IdxN is not a constant multiple + // of the SubVec's minimum vector length OR the insertion overruns Vec. + if (IdxN % SubVecNumElts != 0 || IdxN + SubVecNumElts > VecNumElts) { + replaceInstUsesWith(CI, UndefValue::get(CI.getType())); + return eraseInstFromFunction(CI); + } + + // An insert that entirely overwrites Vec with SubVec is a nop. + if (VecNumElts == SubVecNumElts) { + replaceInstUsesWith(CI, SubVec); + return eraseInstFromFunction(CI); + } + + // Widen SubVec into a vector of the same width as Vec, since + // shufflevector requires the two input vectors to be the same width. + // Elements beyond the bounds of SubVec within the widened vector are + // undefined. 
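Roughly, the gc-live bundle shrinking above walks the existing bundle inputs, keeps only the values still used by some gc.relocate, deduplicates them, and records each value's new index so the relocates can be rewritten. A standalone C++ sketch with std::string standing in for Value*; the helper name is invented for the sketch:

#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

static std::vector<std::string>
shrinkGCLiveBundle(const std::vector<std::string> &BundleInputs,
                   const std::unordered_set<std::string> &LiveGcValues,
                   std::unordered_map<std::string, unsigned> &Val2Idx) {
  std::vector<std::string> NewLiveGc;
  for (const std::string &V : BundleInputs) {
    if (Val2Idx.count(V))
      continue;                                    // already assigned an index
    if (LiveGcValues.count(V)) {
      Val2Idx[V] = (unsigned)NewLiveGc.size();     // new, compacted index
      NewLiveGc.push_back(V);
    } else {
      Val2Idx[V] = (unsigned)LiveGcValues.size();  // "not live" sentinel
    }
  }
  return NewLiveGc; // relocates are then rewritten through Val2Idx
}

int main() {
  std::unordered_map<std::string, unsigned> Val2Idx;
  auto NewLive = shrinkGCLiveBundle({"a", "b", "a", "c"}, {"a", "c"}, Val2Idx);
  assert((NewLive == std::vector<std::string>{"a", "c"}));
  assert(Val2Idx["b"] == 2); // sentinel equals the number of live values
  return 0;
}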
+ SmallVector<int, 8> WidenMask; + unsigned i; + for (i = 0; i != SubVecNumElts; ++i) + WidenMask.push_back(i); + for (; i != VecNumElts; ++i) + WidenMask.push_back(UndefMaskElem); + + Value *WidenShuffle = Builder.CreateShuffleVector(SubVec, WidenMask); + + SmallVector<int, 8> Mask; + for (unsigned i = 0; i != IdxN; ++i) + Mask.push_back(i); + for (unsigned i = DstNumElts; i != DstNumElts + SubVecNumElts; ++i) + Mask.push_back(i); + for (unsigned i = IdxN + SubVecNumElts; i != DstNumElts; ++i) + Mask.push_back(i); + + Value *Shuffle = Builder.CreateShuffleVector(Vec, WidenShuffle, Mask); + replaceInstUsesWith(CI, Shuffle); + return eraseInstFromFunction(CI); + } + break; + } + case Intrinsic::experimental_vector_extract: { + Value *Vec = II->getArgOperand(0); + Value *Idx = II->getArgOperand(1); + + auto *DstTy = dyn_cast<FixedVectorType>(II->getType()); + auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType()); + + // Only canonicalize if the the destination vector and Vec are fixed + // vectors. + if (DstTy && VecTy) { + unsigned DstNumElts = DstTy->getNumElements(); + unsigned VecNumElts = VecTy->getNumElements(); + unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue(); + + // The result of this call is undefined if IdxN is not a constant multiple + // of the result type's minimum vector length OR the extraction overruns + // Vec. + if (IdxN % DstNumElts != 0 || IdxN + DstNumElts > VecNumElts) { + replaceInstUsesWith(CI, UndefValue::get(CI.getType())); + return eraseInstFromFunction(CI); + } + + // Extracting the entirety of Vec is a nop. + if (VecNumElts == DstNumElts) { + replaceInstUsesWith(CI, Vec); + return eraseInstFromFunction(CI); + } + + SmallVector<int, 8> Mask; + for (unsigned i = 0; i != DstNumElts; ++i) + Mask.push_back(IdxN + i); + + Value *Shuffle = + Builder.CreateShuffleVector(Vec, UndefValue::get(VecTy), Mask); + replaceInstUsesWith(CI, Shuffle); + return eraseInstFromFunction(CI); + } + break; + } + default: { + // Handle target specific intrinsics + Optional<Instruction *> V = targetInstCombineIntrinsic(*II); + if (V.hasValue()) + return V.getValue(); + break; + } } return visitCallBase(*II); } // Fence instruction simplification -Instruction *InstCombiner::visitFenceInst(FenceInst &FI) { +Instruction *InstCombinerImpl::visitFenceInst(FenceInst &FI) { // Remove identical consecutive fences. Instruction *Next = FI.getNextNonDebugInstruction(); if (auto *NFI = dyn_cast<FenceInst>(Next)) @@ -4382,12 +1800,12 @@ Instruction *InstCombiner::visitFenceInst(FenceInst &FI) { } // InvokeInst simplification -Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) { +Instruction *InstCombinerImpl::visitInvokeInst(InvokeInst &II) { return visitCallBase(II); } // CallBrInst simplification -Instruction *InstCombiner::visitCallBrInst(CallBrInst &CBI) { +Instruction *InstCombinerImpl::visitCallBrInst(CallBrInst &CBI) { return visitCallBase(CBI); } @@ -4427,7 +1845,7 @@ static bool isSafeToEliminateVarargsCast(const CallBase &Call, return true; } -Instruction *InstCombiner::tryOptimizeCall(CallInst *CI) { +Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) { if (!CI->getCalledFunction()) return nullptr; auto InstCombineRAUW = [this](Instruction *From, Value *With) { @@ -4584,7 +2002,7 @@ static void annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) { } /// Improvements for call, callbr and invoke instructions. 
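The vector.insert and vector.extract canonicalisations above build shufflevector masks whose net effect is simple: insert splices SubVec into Vec at offset IdxN, and extract takes DstNumElts consecutive elements of Vec starting at IdxN. A scalar C++ model of those semantics, illustrative only; the index legality checks are assumed to have passed, as in the code above:

#include <cassert>
#include <vector>

static std::vector<int> insertVec(std::vector<int> Vec,
                                  const std::vector<int> &Sub, unsigned IdxN) {
  for (unsigned I = 0; I != Sub.size(); ++I)
    Vec[IdxN + I] = Sub[I];   // SubVec overwrites Vec starting at IdxN
  return Vec;
}

static std::vector<int> extractVec(const std::vector<int> &Vec, unsigned IdxN,
                                   unsigned DstNumElts) {
  return std::vector<int>(Vec.begin() + IdxN,
                          Vec.begin() + IdxN + DstNumElts);
}

int main() {
  std::vector<int> Vec = {0, 1, 2, 3, 4, 5, 6, 7};
  std::vector<int> Sub = {40, 41};
  assert((insertVec(Vec, Sub, 4) ==
          std::vector<int>{0, 1, 2, 3, 40, 41, 6, 7}));
  assert((extractVec(Vec, 4, 2) == std::vector<int>{4, 5}));
  return 0;
}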
-Instruction *InstCombiner::visitCallBase(CallBase &Call) { +Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { if (isAllocationFn(&Call, &TLI)) annotateAnyAllocSite(Call, &TLI); @@ -4640,7 +2058,7 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) { !CalleeF->isDeclaration()) { Instruction *OldCall = &Call; CreateNonTerminatorUnreachable(OldCall); - // If OldCall does not return void then replaceAllUsesWith undef. + // If OldCall does not return void then replaceInstUsesWith undef. // This allows ValueHandlers and custom metadata to adjust itself. if (!OldCall->getType()->isVoidTy()) replaceInstUsesWith(*OldCall, UndefValue::get(OldCall->getType())); @@ -4659,7 +2077,7 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) { if ((isa<ConstantPointerNull>(Callee) && !NullPointerIsDefined(Call.getFunction())) || isa<UndefValue>(Callee)) { - // If Call does not return void then replaceAllUsesWith undef. + // If Call does not return void then replaceInstUsesWith undef. // This allows ValueHandlers and custom metadata to adjust itself. if (!Call.getType()->isVoidTy()) replaceInstUsesWith(Call, UndefValue::get(Call.getType())); @@ -4735,7 +2153,7 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) { /// If the callee is a constexpr cast of a function, attempt to move the cast to /// the arguments of the call/callbr/invoke. -bool InstCombiner::transformConstExprCastCall(CallBase &Call) { +bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { auto *Callee = dyn_cast<Function>(Call.getCalledOperand()->stripPointerCasts()); if (!Callee) @@ -4834,6 +2252,9 @@ bool InstCombiner::transformConstExprCastCall(CallBase &Call) { if (Call.isInAllocaArgument(i)) return false; // Cannot transform to and from inalloca. + if (CallerPAL.hasParamAttribute(i, Attribute::SwiftError)) + return false; + // If the parameter is passed as a byval argument, then we have to have a // sized type and the sized type has to have the same size as the old type. if (ParamTy != ActTy && CallerPAL.hasParamAttribute(i, Attribute::ByVal)) { @@ -5019,8 +2440,8 @@ bool InstCombiner::transformConstExprCastCall(CallBase &Call) { /// Turn a call to a function created by init_trampoline / adjust_trampoline /// intrinsic pair into a direct call to the underlying function. Instruction * -InstCombiner::transformCallThroughTrampoline(CallBase &Call, - IntrinsicInst &Tramp) { +InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, + IntrinsicInst &Tramp) { Value *Callee = Call.getCalledOperand(); Type *CalleeTy = Callee->getType(); FunctionType *FTy = Call.getFunctionType(); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 3639edb5df4d..0b53007bb6dc 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -14,10 +14,11 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/DIBuilder.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" #include <numeric> using namespace llvm; using namespace PatternMatch; @@ -81,8 +82,8 @@ static Value *decomposeSimpleLinearExpr(Value *Val, unsigned &Scale, /// If we find a cast of an allocation instruction, try to eliminate the cast by /// moving the type information into the alloc. 
-Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, - AllocaInst &AI) { +Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI, + AllocaInst &AI) { PointerType *PTy = cast<PointerType>(CI.getType()); IRBuilderBase::InsertPointGuard Guard(Builder); @@ -93,6 +94,18 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, Type *CastElTy = PTy->getElementType(); if (!AllocElTy->isSized() || !CastElTy->isSized()) return nullptr; + // This optimisation does not work for cases where the cast type + // is scalable and the allocated type is not. This because we need to + // know how many times the casted type fits into the allocated type. + // For the opposite case where the allocated type is scalable and the + // cast type is not this leads to poor code quality due to the + // introduction of 'vscale' into the calculations. It seems better to + // bail out for this case too until we've done a proper cost-benefit + // analysis. + bool AllocIsScalable = isa<ScalableVectorType>(AllocElTy); + bool CastIsScalable = isa<ScalableVectorType>(CastElTy); + if (AllocIsScalable != CastIsScalable) return nullptr; + Align AllocElTyAlign = DL.getABITypeAlign(AllocElTy); Align CastElTyAlign = DL.getABITypeAlign(CastElTy); if (CastElTyAlign < AllocElTyAlign) return nullptr; @@ -102,14 +115,15 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, // same, we open the door to infinite loops of various kinds. if (!AI.hasOneUse() && CastElTyAlign == AllocElTyAlign) return nullptr; - uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy); - uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy); + // The alloc and cast types should be either both fixed or both scalable. + uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinSize(); + uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinSize(); if (CastElTySize == 0 || AllocElTySize == 0) return nullptr; // If the allocation has multiple uses, only promote it if we're not // shrinking the amount of memory being allocated. - uint64_t AllocElTyStoreSize = DL.getTypeStoreSize(AllocElTy); - uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy); + uint64_t AllocElTyStoreSize = DL.getTypeStoreSize(AllocElTy).getKnownMinSize(); + uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinSize(); if (!AI.hasOneUse() && CastElTyStoreSize < AllocElTyStoreSize) return nullptr; // See if we can satisfy the modulus by pulling a scale out of the array @@ -124,6 +138,9 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, if ((AllocElTySize*ArraySizeScale) % CastElTySize != 0 || (AllocElTySize*ArrayOffset ) % CastElTySize != 0) return nullptr; + // We don't currently support arrays of scalable types. + assert(!AllocIsScalable || (ArrayOffset == 1 && ArraySizeScale == 0)); + unsigned Scale = (AllocElTySize*ArraySizeScale)/CastElTySize; Value *Amt = nullptr; if (Scale == 1) { @@ -160,8 +177,8 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, /// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns /// true for, actually insert the code to evaluate the expression. -Value *InstCombiner::EvaluateInDifferentType(Value *V, Type *Ty, - bool isSigned) { +Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, + bool isSigned) { if (Constant *C = dyn_cast<Constant>(V)) { C = ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/); // If we got a constantexpr back, try to simplify it with DL info. 
@@ -229,8 +246,9 @@ Value *InstCombiner::EvaluateInDifferentType(Value *V, Type *Ty, return InsertNewInstWith(Res, *I); } -Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1, - const CastInst *CI2) { +Instruction::CastOps +InstCombinerImpl::isEliminableCastPair(const CastInst *CI1, + const CastInst *CI2) { Type *SrcTy = CI1->getSrcTy(); Type *MidTy = CI1->getDestTy(); Type *DstTy = CI2->getDestTy(); @@ -257,7 +275,7 @@ Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1, } /// Implement the transforms common to all CastInst visitors. -Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { +Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { Value *Src = CI.getOperand(0); // Try to eliminate a cast of a cast. @@ -342,7 +360,7 @@ static bool canNotEvaluateInType(Value *V, Type *Ty) { /// /// This function works on both vectors and scalars. /// -static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, +static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, Instruction *CxtI) { if (canAlwaysEvaluateInType(V, Ty)) return true; @@ -459,7 +477,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, /// trunc (lshr (bitcast <4 x i32> %X to i128), 32) to i32 /// ---> /// extractelement <4 x i32> %X, 1 -static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) { +static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, + InstCombinerImpl &IC) { Value *TruncOp = Trunc.getOperand(0); Type *DestType = Trunc.getType(); if (!TruncOp->hasOneUse() || !isa<IntegerType>(DestType)) @@ -496,9 +515,9 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) { return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt)); } -/// Rotate left/right may occur in a wider type than necessary because of type -/// promotion rules. Try to narrow the inputs and convert to funnel shift. -Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { +/// Funnel/Rotate left/right may occur in a wider type than necessary because of +/// type promotion rules. Try to narrow the inputs and convert to funnel shift. 
+Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) { assert((isa<VectorType>(Trunc.getSrcTy()) || shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) && "Don't narrow to an illegal scalar type"); @@ -510,32 +529,43 @@ Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { if (!isPowerOf2_32(NarrowWidth)) return nullptr; - // First, find an or'd pair of opposite shifts with the same shifted operand: - // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1)) - Value *Or0, *Or1; - if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1))))) + // First, find an or'd pair of opposite shifts: + // trunc (or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)) + BinaryOperator *Or0, *Or1; + if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1))))) return nullptr; - Value *ShVal, *ShAmt0, *ShAmt1; - if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) || - !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1))))) + Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) || + Or0->getOpcode() == Or1->getOpcode()) return nullptr; - auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode(); - auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode(); - if (ShiftOpcode0 == ShiftOpcode1) - return nullptr; + // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)). + if (Or0->getOpcode() == BinaryOperator::LShr) { + std::swap(Or0, Or1); + std::swap(ShVal0, ShVal1); + std::swap(ShAmt0, ShAmt1); + } + assert(Or0->getOpcode() == BinaryOperator::Shl && + Or1->getOpcode() == BinaryOperator::LShr && + "Illegal or(shift,shift) pair"); - // Match the shift amount operands for a rotate pattern. This always matches - // a subtraction on the R operand. - auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * { + // Match the shift amount operands for a funnel/rotate pattern. This always + // matches a subtraction on the R operand. + auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * { // The shift amounts may add up to the narrow bit width: - // (shl ShVal, L) | (lshr ShVal, Width - L) + // (shl ShVal0, L) | (lshr ShVal1, Width - L) if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) return L; + // The following patterns currently only work for rotation patterns. + // TODO: Add more general funnel-shift compatible patterns. + if (ShVal0 != ShVal1) + return nullptr; + // The shift amount may be masked with negation: - // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1))) + // (shl ShVal0, (X & (Width - 1))) | (lshr ShVal1, ((-X) & (Width - 1))) Value *X; unsigned Mask = Width - 1; if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) && @@ -551,10 +581,10 @@ Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { }; Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth); - bool SubIsOnLHS = false; + bool IsFshl = true; // Sub on LSHR. if (!ShAmt) { ShAmt = matchShiftAmount(ShAmt1, ShAmt0, NarrowWidth); - SubIsOnLHS = true; + IsFshl = false; // Sub on SHL. } if (!ShAmt) return nullptr; @@ -563,26 +593,28 @@ Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { // will be a zext, but it could also be the result of an 'and' or 'shift'. 
unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits(); APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth); - if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc)) + if (!MaskedValueIsZero(ShVal0, HiBitMask, 0, &Trunc) || + !MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc)) return nullptr; // We have an unnecessarily wide rotate! - // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt)) + // trunc (or (lshr ShVal0, ShAmt), (shl ShVal1, BitWidth - ShAmt)) // Narrow the inputs and convert to funnel shift intrinsic: // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt)) Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy); - Value *X = Builder.CreateTrunc(ShVal, DestTy); - bool IsFshl = (!SubIsOnLHS && ShiftOpcode0 == BinaryOperator::Shl) || - (SubIsOnLHS && ShiftOpcode1 == BinaryOperator::Shl); + Value *X, *Y; + X = Y = Builder.CreateTrunc(ShVal0, DestTy); + if (ShVal0 != ShVal1) + Y = Builder.CreateTrunc(ShVal1, DestTy); Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; Function *F = Intrinsic::getDeclaration(Trunc.getModule(), IID, DestTy); - return IntrinsicInst::Create(F, { X, X, NarrowShAmt }); + return IntrinsicInst::Create(F, {X, Y, NarrowShAmt}); } /// Try to narrow the width of math or bitwise logic instructions by pulling a /// truncate ahead of binary operators. /// TODO: Transforms for truncated shifts should be moved into here. -Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) { +Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) { Type *SrcTy = Trunc.getSrcTy(); Type *DestTy = Trunc.getType(); if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) @@ -631,7 +663,7 @@ Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) { default: break; } - if (Instruction *NarrowOr = narrowRotate(Trunc)) + if (Instruction *NarrowOr = narrowFunnelShift(Trunc)) return NarrowOr; return nullptr; @@ -687,7 +719,7 @@ static Instruction *shrinkInsertElt(CastInst &Trunc, return nullptr; } -Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) { +Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { if (Instruction *Result = commonCastTransforms(Trunc)) return Result; @@ -695,7 +727,6 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) { Type *DestTy = Trunc.getType(), *SrcTy = Src->getType(); unsigned DestWidth = DestTy->getScalarSizeInBits(); unsigned SrcWidth = SrcTy->getScalarSizeInBits(); - ConstantInt *Cst; // Attempt to truncate the entire input expression tree to the destination // type. Only do this if the dest type is a simple type, don't convert the @@ -782,56 +813,60 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) { } } - // FIXME: Maybe combine the next two transforms to handle the no cast case - // more efficiently. Support vector types. Cleanup code by using m_OneUse. - - // Transform trunc(lshr (zext A), Cst) to eliminate one type conversion. - Value *A = nullptr; - if (Src->hasOneUse() && - match(Src, m_LShr(m_ZExt(m_Value(A)), m_ConstantInt(Cst)))) { - // We have three types to worry about here, the type of A, the source of - // the truncate (MidSize), and the destination of the truncate. We know that - // ASize < MidSize and MidSize > ResultSize, but don't know the relation - // between ASize and ResultSize. - unsigned ASize = A->getType()->getPrimitiveSizeInBits(); - - // If the shift amount is larger than the size of A, then the result is - // known to be zero because all the input bits got shifted out. 
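For the rotate case of the narrowing above (both shifted values the same and the high bits of the wide value known zero), the truncated wide or-of-shifts is exactly the narrow rotate that the funnel-shift intrinsic computes. A standalone C++ check over all i8 values, keeping the shift amount in [1, 7] to sidestep the zero-amount edge case:

#include <cassert>
#include <cstdint>

// Narrow rotate-left by S bits, the result fshl(X, X, S) would produce on i8.
static uint8_t rot8(uint8_t X, unsigned S) {
  return (uint8_t)((X << S) | (X >> (8 - S)));
}

int main() {
  for (unsigned X = 0; X < 256; ++X) {
    for (unsigned S = 1; S < 8; ++S) {
      uint32_t Wide = X;                                 // zext i8 -> i32, high bits zero
      uint8_t Narrowed = (uint8_t)((Wide << S) | (Wide >> (8 - S)));
      assert(Narrowed == rot8((uint8_t)X, S));           // wide rotate + trunc == narrow rotate
    }
  }
  return 0;
}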
- if (Cst->getZExtValue() >= ASize) - return replaceInstUsesWith(Trunc, Constant::getNullValue(DestTy)); - - // Since we're doing an lshr and a zero extend, and know that the shift - // amount is smaller than ASize, it is always safe to do the shift in A's - // type, then zero extend or truncate to the result. - Value *Shift = Builder.CreateLShr(A, Cst->getZExtValue()); - Shift->takeName(Src); - return CastInst::CreateIntegerCast(Shift, DestTy, false); - } - - const APInt *C; - if (match(Src, m_LShr(m_SExt(m_Value(A)), m_APInt(C)))) { + Value *A; + Constant *C; + if (match(Src, m_LShr(m_SExt(m_Value(A)), m_Constant(C)))) { unsigned AWidth = A->getType()->getScalarSizeInBits(); unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth); + auto *OldSh = cast<Instruction>(Src); + bool IsExact = OldSh->isExact(); // If the shift is small enough, all zero bits created by the shift are // removed by the trunc. - if (C->getZExtValue() <= MaxShiftAmt) { + if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, + APInt(SrcWidth, MaxShiftAmt)))) { // trunc (lshr (sext A), C) --> ashr A, C if (A->getType() == DestTy) { - unsigned ShAmt = std::min((unsigned)C->getZExtValue(), DestWidth - 1); - return BinaryOperator::CreateAShr(A, ConstantInt::get(DestTy, ShAmt)); + Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false); + Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); + ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); + ShAmt = Constant::mergeUndefsWith(ShAmt, C); + return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt) + : BinaryOperator::CreateAShr(A, ShAmt); } // The types are mismatched, so create a cast after shifting: // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C) if (Src->hasOneUse()) { - unsigned ShAmt = std::min((unsigned)C->getZExtValue(), AWidth - 1); - Value *Shift = Builder.CreateAShr(A, ShAmt); + Constant *MaxAmt = ConstantInt::get(SrcTy, AWidth - 1, false); + Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); + ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); + Value *Shift = Builder.CreateAShr(A, ShAmt, "", IsExact); return CastInst::CreateIntegerCast(Shift, DestTy, true); } } // TODO: Mask high bits with 'and'. } + // trunc (*shr (trunc A), C) --> trunc(*shr A, C) + if (match(Src, m_OneUse(m_Shr(m_Trunc(m_Value(A)), m_Constant(C))))) { + unsigned MaxShiftAmt = SrcWidth - DestWidth; + + // If the shift is small enough, all zero/sign bits created by the shift are + // removed by the trunc. + if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, + APInt(SrcWidth, MaxShiftAmt)))) { + auto *OldShift = cast<Instruction>(Src); + bool IsExact = OldShift->isExact(); + auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true); + ShAmt = Constant::mergeUndefsWith(ShAmt, C); + Value *Shift = + OldShift->getOpcode() == Instruction::AShr + ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) + : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact); + return CastInst::CreateTruncOrBitCast(Shift, DestTy); + } + } + if (Instruction *I = narrowBinOp(Trunc)) return I; @@ -841,20 +876,19 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) { if (Instruction *I = shrinkInsertElt(Trunc, Builder)) return I; - if (Src->hasOneUse() && isa<IntegerType>(SrcTy) && - shouldChangeType(SrcTy, DestTy)) { + if (Src->hasOneUse() && + (isa<VectorType>(SrcTy) || shouldChangeType(SrcTy, DestTy))) { // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the // dest type is native and cst < dest size. 
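The trunc (lshr (sext A), C) --> ashr A, C fold above, with the shift amount clamped to DestWidth - 1, can be checked directly for the i8 -> i32 -> i8 case, where MaxShiftAmt is 24. A standalone C++ sketch, assuming the usual two's-complement narrowing casts and arithmetic signed right shift:

#include <cassert>
#include <cstdint>

int main() {
  for (int V = -128; V <= 127; ++V) {
    int8_t A = (int8_t)V;
    for (unsigned C = 0; C <= 24; ++C) {
      uint32_t Wide = (uint32_t)(int32_t)A;              // sext i8 -> i32
      int8_t Truncated = (int8_t)(Wide >> C);            // trunc (lshr ..., C)
      unsigned ShAmt = C < 7 ? C : 7;                     // umin(C, DestWidth - 1)
      assert(Truncated == (int8_t)(A >> ShAmt));          // ashr A, ShAmt
    }
  }
  return 0;
}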
- if (match(Src, m_Shl(m_Value(A), m_ConstantInt(Cst))) && + if (match(Src, m_Shl(m_Value(A), m_Constant(C))) && !match(A, m_Shr(m_Value(), m_Constant()))) { // Skip shifts of shift by constants. It undoes a combine in // FoldShiftByConstant and is the extend in reg pattern. - if (Cst->getValue().ult(DestWidth)) { + APInt Threshold = APInt(C->getType()->getScalarSizeInBits(), DestWidth); + if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold))) { Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr"); - - return BinaryOperator::Create( - Instruction::Shl, NewTrunc, - ConstantInt::get(DestTy, Cst->getValue().trunc(DestWidth))); + return BinaryOperator::Create(Instruction::Shl, NewTrunc, + ConstantExpr::getTrunc(C, DestTy)); } } } @@ -871,21 +905,23 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) { // ---> // extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0 Value *VecOp; + ConstantInt *Cst; if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) { auto *VecOpTy = cast<VectorType>(VecOp->getType()); - unsigned VecNumElts = VecOpTy->getNumElements(); + auto VecElts = VecOpTy->getElementCount(); // A badly fit destination size would result in an invalid cast. if (SrcWidth % DestWidth == 0) { uint64_t TruncRatio = SrcWidth / DestWidth; - uint64_t BitCastNumElts = VecNumElts * TruncRatio; + uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio; uint64_t VecOpIdx = Cst->getZExtValue(); uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1 : VecOpIdx * TruncRatio; assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits"); - auto *BitCastTo = FixedVectorType::get(DestTy, BitCastNumElts); + auto *BitCastTo = + VectorType::get(DestTy, BitCastNumElts, VecElts.isScalable()); Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo); return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx)); } @@ -894,8 +930,8 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) { return nullptr; } -Instruction *InstCombiner::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, - bool DoTransform) { +Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, + bool DoTransform) { // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. @@ -1031,7 +1067,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, /// /// This function works on both vectors and scalars. static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, - InstCombiner &IC, Instruction *CxtI) { + InstCombinerImpl &IC, Instruction *CxtI) { BitsToClear = 0; if (canAlwaysEvaluateInType(V, Ty)) return true; @@ -1136,7 +1172,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, } } -Instruction *InstCombiner::visitZExt(ZExtInst &CI) { +Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { // If this zero extend is only used by a truncate, let the truncate be // eliminated before we try to optimize this zext. if (CI.hasOneUse() && isa<TruncInst>(CI.user_back())) @@ -1274,7 +1310,8 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { } /// Transform (sext icmp) to bitwise / integer operations to eliminate the icmp. 
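The trunc (shl X, cst) -> shl (trunc X), cst rewrite near the start of this hunk only needs cst to be smaller than the destination width, because truncation and left shift commute modulo 2^DestWidth. A short standalone C++ check for the i32 -> i8 case:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Samples[] = {0u, 1u, 0x12345678u, 0xFFFFFFFFu, 0x80u, 0xABCDu};
  for (uint32_t X : Samples)
    for (unsigned C = 0; C < 8; ++C)
      assert((uint8_t)(X << C) == (uint8_t)((uint8_t)X << C)); // trunc(shl) == shl(trunc)
  return 0;
}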
-Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { +Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *ICI, + Instruction &CI) { Value *Op0 = ICI->getOperand(0), *Op1 = ICI->getOperand(1); ICmpInst::Predicate Pred = ICI->getPredicate(); @@ -1410,7 +1447,7 @@ static bool canEvaluateSExtd(Value *V, Type *Ty) { return false; } -Instruction *InstCombiner::visitSExt(SExtInst &CI) { +Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { // If this sign extend is only used by a truncate, let the truncate be // eliminated before we try to optimize this sext. if (CI.hasOneUse() && isa<TruncInst>(CI.user_back())) @@ -1473,31 +1510,33 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { // for a truncate. If the source and dest are the same type, eliminate the // trunc and extend and just do shifts. For example, turn: // %a = trunc i32 %i to i8 - // %b = shl i8 %a, 6 - // %c = ashr i8 %b, 6 + // %b = shl i8 %a, C + // %c = ashr i8 %b, C // %d = sext i8 %c to i32 // into: - // %a = shl i32 %i, 30 - // %d = ashr i32 %a, 30 + // %a = shl i32 %i, 32-(8-C) + // %d = ashr i32 %a, 32-(8-C) Value *A = nullptr; // TODO: Eventually this could be subsumed by EvaluateInDifferentType. Constant *BA = nullptr, *CA = nullptr; if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_Constant(BA)), m_Constant(CA))) && - BA == CA && A->getType() == CI.getType()) { - unsigned MidSize = Src->getType()->getScalarSizeInBits(); - unsigned SrcDstSize = CI.getType()->getScalarSizeInBits(); - Constant *SizeDiff = ConstantInt::get(CA->getType(), SrcDstSize - MidSize); - Constant *ShAmt = ConstantExpr::getAdd(CA, SizeDiff); - Constant *ShAmtExt = ConstantExpr::getSExt(ShAmt, CI.getType()); - A = Builder.CreateShl(A, ShAmtExt, CI.getName()); - return BinaryOperator::CreateAShr(A, ShAmtExt); + BA->isElementWiseEqual(CA) && A->getType() == DestTy) { + Constant *WideCurrShAmt = ConstantExpr::getSExt(CA, DestTy); + Constant *NumLowbitsLeft = ConstantExpr::getSub( + ConstantInt::get(DestTy, SrcTy->getScalarSizeInBits()), WideCurrShAmt); + Constant *NewShAmt = ConstantExpr::getSub( + ConstantInt::get(DestTy, DestTy->getScalarSizeInBits()), + NumLowbitsLeft); + NewShAmt = + Constant::mergeUndefsWith(Constant::mergeUndefsWith(NewShAmt, BA), CA); + A = Builder.CreateShl(A, NewShAmt, CI.getName()); + return BinaryOperator::CreateAShr(A, NewShAmt); } return nullptr; } - /// Return a Constant* for the specified floating-point constant if it fits /// in the specified FP type without changing its value. 
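The updated shift-amount computation above generalises the old sext (ashr (shl (trunc A), C), C) rewrite: both the narrow shl/ashr pair followed by sext and the single wide shl/ashr pair with amount DestWidth - (MidWidth - C) sign-extend the low (MidWidth - C) bits of A. A standalone C++ check for the i32 -> i8 -> i32 case, doing the left shifts in unsigned arithmetic so the sketch stays well-defined for negative inputs:

#include <cassert>
#include <cstdint>

int main() {
  const int32_t Samples[] = {0, 1, -1, 37, -37, 0x7F, -0x80, 123456, -123456};
  for (int32_t A : Samples) {
    for (unsigned C = 0; C < 8; ++C) {
      int8_t T = (int8_t)A;                               // trunc i32 -> i8
      int8_t Narrow = (int8_t)((uint8_t)T << C);          // shl i8, C
      Narrow = (int8_t)(Narrow >> C);                     // ashr i8, C
      int32_t Old = Narrow;                               // sext i8 -> i32

      unsigned NewShAmt = 32 - (8 - C);                   // DestWidth - (MidWidth - C)
      int32_t New = (int32_t)((uint32_t)A << NewShAmt) >> NewShAmt; // wide shl + ashr
      assert(Old == New);
    }
  }
  return 0;
}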
static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { @@ -1535,7 +1574,7 @@ static Type *shrinkFPConstantVector(Value *V) { Type *MinType = nullptr; - unsigned NumElts = CVVTy->getNumElements(); + unsigned NumElts = cast<FixedVectorType>(CVVTy)->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); if (!CFP) @@ -1616,7 +1655,7 @@ static bool isKnownExactCastIntToFP(CastInst &I) { return false; } -Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { +Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { if (Instruction *I = commonCastTransforms(FPT)) return I; @@ -1800,7 +1839,7 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { return nullptr; } -Instruction *InstCombiner::visitFPExt(CastInst &FPExt) { +Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) { // If the source operand is a cast from integer to FP and known exact, then // cast the integer operand directly to the destination type. Type *Ty = FPExt.getType(); @@ -1818,7 +1857,7 @@ Instruction *InstCombiner::visitFPExt(CastInst &FPExt) { /// This is safe if the intermediate type has enough bits in its mantissa to /// accurately represent all values of X. For example, this won't work with /// i64 -> float -> i64. -Instruction *InstCombiner::foldItoFPtoI(CastInst &FI) { +Instruction *InstCombinerImpl::foldItoFPtoI(CastInst &FI) { if (!isa<UIToFPInst>(FI.getOperand(0)) && !isa<SIToFPInst>(FI.getOperand(0))) return nullptr; @@ -1858,29 +1897,29 @@ Instruction *InstCombiner::foldItoFPtoI(CastInst &FI) { return replaceInstUsesWith(FI, X); } -Instruction *InstCombiner::visitFPToUI(FPToUIInst &FI) { +Instruction *InstCombinerImpl::visitFPToUI(FPToUIInst &FI) { if (Instruction *I = foldItoFPtoI(FI)) return I; return commonCastTransforms(FI); } -Instruction *InstCombiner::visitFPToSI(FPToSIInst &FI) { +Instruction *InstCombinerImpl::visitFPToSI(FPToSIInst &FI) { if (Instruction *I = foldItoFPtoI(FI)) return I; return commonCastTransforms(FI); } -Instruction *InstCombiner::visitUIToFP(CastInst &CI) { +Instruction *InstCombinerImpl::visitUIToFP(CastInst &CI) { return commonCastTransforms(CI); } -Instruction *InstCombiner::visitSIToFP(CastInst &CI) { +Instruction *InstCombinerImpl::visitSIToFP(CastInst &CI) { return commonCastTransforms(CI); } -Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) { +Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) { // If the source integer type is not the intptr_t type for this target, do a // trunc or zext to the intptr_t type, then inttoptr of it. This allows the // cast to be exposed to other transforms. @@ -1903,7 +1942,7 @@ Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) { } /// Implement the transforms for cast of pointer (bitcast/ptrtoint) -Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) { +Instruction *InstCombinerImpl::commonPointerCastTransforms(CastInst &CI) { Value *Src = CI.getOperand(0); if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Src)) { @@ -1925,26 +1964,37 @@ Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) { return commonCastTransforms(CI); } -Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) { +Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { // If the destination integer type is not the intptr_t type for this target, // do a ptrtoint to intptr_t then do a trunc or zext. This allows the cast // to be exposed to other transforms. 
- + Value *SrcOp = CI.getPointerOperand(); Type *Ty = CI.getType(); unsigned AS = CI.getPointerAddressSpace(); + unsigned TySize = Ty->getScalarSizeInBits(); + unsigned PtrSize = DL.getPointerSizeInBits(AS); + if (TySize != PtrSize) { + Type *IntPtrTy = DL.getIntPtrType(CI.getContext(), AS); + // Handle vectors of pointers. + if (auto *VecTy = dyn_cast<VectorType>(Ty)) + IntPtrTy = VectorType::get(IntPtrTy, VecTy->getElementCount()); - if (Ty->getScalarSizeInBits() == DL.getPointerSizeInBits(AS)) - return commonPointerCastTransforms(CI); + Value *P = Builder.CreatePtrToInt(SrcOp, IntPtrTy); + return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false); + } - Type *PtrTy = DL.getIntPtrType(CI.getContext(), AS); - if (auto *VTy = dyn_cast<VectorType>(Ty)) { - // Handle vectors of pointers. - // FIXME: what should happen for scalable vectors? - PtrTy = FixedVectorType::get(PtrTy, VTy->getNumElements()); + Value *Vec, *Scalar, *Index; + if (match(SrcOp, m_OneUse(m_InsertElt(m_IntToPtr(m_Value(Vec)), + m_Value(Scalar), m_Value(Index)))) && + Vec->getType() == Ty) { + assert(Vec->getType()->getScalarSizeInBits() == PtrSize && "Wrong type"); + // Convert the scalar to int followed by insert to eliminate one cast: + // p2i (ins (i2p Vec), Scalar, Index --> ins Vec, (p2i Scalar), Index + Value *NewCast = Builder.CreatePtrToInt(Scalar, Ty->getScalarType()); + return InsertElementInst::Create(Vec, NewCast, Index); } - Value *P = Builder.CreatePtrToInt(CI.getOperand(0), PtrTy); - return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false); + return commonPointerCastTransforms(CI); } /// This input value (which is known to have vector type) is being zero extended @@ -1963,9 +2013,9 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) { /// Try to replace it with a shuffle (and vector/vector bitcast) if possible. /// /// The source and destination vector types may have different element types. -static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal, - VectorType *DestTy, - InstCombiner &IC) { +static Instruction * +optimizeVectorResizeWithIntegerBitCasts(Value *InVal, VectorType *DestTy, + InstCombinerImpl &IC) { // We can only do this optimization if the output is a multiple of the input // element size, or the input is a multiple of the output element size. // Convert the input type to have the same element type as the output. @@ -1981,13 +2031,14 @@ static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal, return nullptr; SrcTy = - FixedVectorType::get(DestTy->getElementType(), SrcTy->getNumElements()); + FixedVectorType::get(DestTy->getElementType(), + cast<FixedVectorType>(SrcTy)->getNumElements()); InVal = IC.Builder.CreateBitCast(InVal, SrcTy); } bool IsBigEndian = IC.getDataLayout().isBigEndian(); - unsigned SrcElts = SrcTy->getNumElements(); - unsigned DestElts = DestTy->getNumElements(); + unsigned SrcElts = cast<FixedVectorType>(SrcTy)->getNumElements(); + unsigned DestElts = cast<FixedVectorType>(DestTy)->getNumElements(); assert(SrcElts != DestElts && "Element counts should be different."); @@ -2165,8 +2216,8 @@ static bool collectInsertionElements(Value *V, unsigned Shift, /// /// Into two insertelements that do "buildvector{%inc, %inc5}". 
static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI, - InstCombiner &IC) { - VectorType *DestVecTy = cast<VectorType>(CI.getType()); + InstCombinerImpl &IC) { + auto *DestVecTy = cast<FixedVectorType>(CI.getType()); Value *IntInput = CI.getOperand(0); SmallVector<Value*, 8> Elements(DestVecTy->getNumElements()); @@ -2194,7 +2245,7 @@ static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI, /// vectors better than bitcasts of scalars because vector registers are /// usually not type-specific like scalar integer or scalar floating-point. static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast, - InstCombiner &IC) { + InstCombinerImpl &IC) { // TODO: Create and use a pattern matcher for ExtractElementInst. auto *ExtElt = dyn_cast<ExtractElementInst>(BitCast.getOperand(0)); if (!ExtElt || !ExtElt->hasOneUse()) @@ -2206,8 +2257,7 @@ static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast, if (!VectorType::isValidElementType(DestType)) return nullptr; - unsigned NumElts = ExtElt->getVectorOperandType()->getNumElements(); - auto *NewVecType = FixedVectorType::get(DestType, NumElts); + auto *NewVecType = VectorType::get(DestType, ExtElt->getVectorOperandType()); auto *NewBC = IC.Builder.CreateBitCast(ExtElt->getVectorOperand(), NewVecType, "bc"); return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand()); @@ -2270,12 +2320,11 @@ static Instruction *foldBitCastSelect(BitCastInst &BitCast, // A vector select must maintain the same number of elements in its operands. Type *CondTy = Cond->getType(); Type *DestTy = BitCast.getType(); - if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) { - if (!DestTy->isVectorTy()) + if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) + if (!DestTy->isVectorTy() || + CondVTy->getElementCount() != + cast<VectorType>(DestTy)->getElementCount()) return nullptr; - if (cast<VectorType>(DestTy)->getNumElements() != CondVTy->getNumElements()) - return nullptr; - } // FIXME: This transform is restricted from changing the select between // scalars and vectors to avoid backend problems caused by creating @@ -2320,7 +2369,8 @@ static bool hasStoreUsersOnly(CastInst &CI) { /// /// All the related PHI nodes can be replaced by new PHI nodes with type A. /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. -Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { +Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI, + PHINode *PN) { // BitCast used by Store can be handled in InstCombineLoadStoreAlloca.cpp. if (hasStoreUsersOnly(CI)) return nullptr; @@ -2450,10 +2500,7 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { Instruction *RetVal = nullptr; for (auto *OldPN : OldPhiNodes) { PHINode *NewPN = NewPNodes[OldPN]; - for (auto It = OldPN->user_begin(), End = OldPN->user_end(); It != End; ) { - User *V = *It; - // We may remove this user, advance to avoid iterator invalidation. 
- ++It; + for (User *V : make_early_inc_range(OldPN->users())) { if (auto *SI = dyn_cast<StoreInst>(V)) { assert(SI->isSimple() && SI->getOperand(0) == OldPN); Builder.SetInsertPoint(SI); @@ -2473,7 +2520,7 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { if (BCI == &CI) RetVal = I; } else if (auto *PHI = dyn_cast<PHINode>(V)) { - assert(OldPhiNodes.count(PHI) > 0); + assert(OldPhiNodes.contains(PHI)); (void) PHI; } else { llvm_unreachable("all uses should be handled"); @@ -2484,7 +2531,7 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { return RetVal; } -Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { +Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { // If the operands are integer typed then apply the integer transforms, // otherwise just apply the common ones. Value *Src = CI.getOperand(0); @@ -2608,12 +2655,11 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // a bitcast to a vector with the same # elts. Value *ShufOp0 = Shuf->getOperand(0); Value *ShufOp1 = Shuf->getOperand(1); - unsigned NumShufElts = Shuf->getType()->getNumElements(); - unsigned NumSrcVecElts = - cast<VectorType>(ShufOp0->getType())->getNumElements(); + auto ShufElts = cast<VectorType>(Shuf->getType())->getElementCount(); + auto SrcVecElts = cast<VectorType>(ShufOp0->getType())->getElementCount(); if (Shuf->hasOneUse() && DestTy->isVectorTy() && - cast<VectorType>(DestTy)->getNumElements() == NumShufElts && - NumShufElts == NumSrcVecElts) { + cast<VectorType>(DestTy)->getElementCount() == ShufElts && + ShufElts == SrcVecElts) { BitCastInst *Tmp; // If either of the operands is a cast from CI.getType(), then // evaluating the shuffle in the casted destination's type will allow @@ -2636,8 +2682,9 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // TODO: We should match the related pattern for bitreverse. if (DestTy->isIntegerTy() && DL.isLegalInteger(DestTy->getScalarSizeInBits()) && - SrcTy->getScalarSizeInBits() == 8 && NumShufElts % 2 == 0 && - Shuf->hasOneUse() && Shuf->isReverse()) { + SrcTy->getScalarSizeInBits() == 8 && + ShufElts.getKnownMinValue() % 2 == 0 && Shuf->hasOneUse() && + Shuf->isReverse()) { assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask"); assert(isa<UndefValue>(ShufOp1) && "Unexpected shuffle op"); Function *Bswap = @@ -2666,7 +2713,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { return commonCastTransforms(CI); } -Instruction *InstCombiner::visitAddrSpaceCast(AddrSpaceCastInst &CI) { +Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) { // If the destination pointer element type is not the same as the source's // first do a bitcast to the destination type, and then the addrspacecast. // This allows the cast to be exposed to other transforms. @@ -2677,11 +2724,9 @@ Instruction *InstCombiner::visitAddrSpaceCast(AddrSpaceCastInst &CI) { Type *DestElemTy = DestTy->getElementType(); if (SrcTy->getElementType() != DestElemTy) { Type *MidTy = PointerType::get(DestElemTy, SrcTy->getAddressSpace()); - if (VectorType *VT = dyn_cast<VectorType>(CI.getType())) { - // Handle vectors of pointers. - // FIXME: what should happen for scalable vectors? - MidTy = FixedVectorType::get(MidTy, VT->getNumElements()); - } + // Handle vectors of pointers. 
+ if (VectorType *VT = dyn_cast<VectorType>(CI.getType())) + MidTy = VectorType::get(MidTy, VT->getElementCount()); Value *NewBitCast = Builder.CreateBitCast(Src, MidTy); return new AddrSpaceCastInst(NewBitCast, CI.getType()); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index f1233b62445d..cd9a036179b6 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -24,6 +24,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" using namespace llvm; using namespace PatternMatch; @@ -95,45 +96,6 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) { return false; } -/// Given a signed integer type and a set of known zero and one bits, compute -/// the maximum and minimum values that could have the specified known zero and -/// known one bits, returning them in Min/Max. -/// TODO: Move to method on KnownBits struct? -static void computeSignedMinMaxValuesFromKnownBits(const KnownBits &Known, - APInt &Min, APInt &Max) { - assert(Known.getBitWidth() == Min.getBitWidth() && - Known.getBitWidth() == Max.getBitWidth() && - "KnownZero, KnownOne and Min, Max must have equal bitwidth."); - APInt UnknownBits = ~(Known.Zero|Known.One); - - // The minimum value is when all unknown bits are zeros, EXCEPT for the sign - // bit if it is unknown. - Min = Known.One; - Max = Known.One|UnknownBits; - - if (UnknownBits.isNegative()) { // Sign bit is unknown - Min.setSignBit(); - Max.clearSignBit(); - } -} - -/// Given an unsigned integer type and a set of known zero and one bits, compute -/// the maximum and minimum values that could have the specified known zero and -/// known one bits, returning them in Min/Max. -/// TODO: Move to method on KnownBits struct? -static void computeUnsignedMinMaxValuesFromKnownBits(const KnownBits &Known, - APInt &Min, APInt &Max) { - assert(Known.getBitWidth() == Min.getBitWidth() && - Known.getBitWidth() == Max.getBitWidth() && - "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth."); - APInt UnknownBits = ~(Known.Zero|Known.One); - - // The minimum value is when the unknown bits are all zeros. - Min = Known.One; - // The maximum value is when the unknown bits are all ones. - Max = Known.One|UnknownBits; -} - /// This is called when we see this pattern: /// cmp pred (load (gep GV, ...)), cmpcst /// where GV is a global variable with a constant initializer. Try to simplify @@ -142,10 +104,10 @@ static void computeUnsignedMinMaxValuesFromKnownBits(const KnownBits &Known, /// /// If AndCst is non-null, then the loaded value is masked with that constant /// before doing the comparison. This handles cases like "A[i]&4 == 0". 
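The two helpers removed above derived value ranges from known bits; that logic now lives on KnownBits itself (getMinValue/getMaxValue and the signed variants, used later in foldICmpUsingKnownBits). A small standalone C++ sketch of the unsigned case, not part of the patch:

    #include <cassert>
    #include <cstdint>

    int main() {
      // Unsigned range implied by known bits: the minimum clears every unknown
      // bit, the maximum sets every unknown bit (the signed variant additionally
      // special-cases an unknown sign bit).
      uint8_t KnownOne  = 0b00010000;
      uint8_t KnownZero = 0b10000001;
      uint8_t Unknown   = (uint8_t)~(KnownOne | KnownZero);
      uint8_t Min = KnownOne;                       // 0b00010000
      uint8_t Max = (uint8_t)(KnownOne | Unknown);  // 0b01111110
      assert(Min == 0x10 && Max == 0x7E);
      return 0;
    }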
-Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, - GlobalVariable *GV, - CmpInst &ICI, - ConstantInt *AndCst) { +Instruction * +InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, + GlobalVariable *GV, CmpInst &ICI, + ConstantInt *AndCst) { Constant *Init = GV->getInitializer(); if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init)) return nullptr; @@ -313,7 +275,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, if (!GEP->isInBounds()) { Type *IntPtrTy = DL.getIntPtrType(GEP->getType()); unsigned PtrSize = IntPtrTy->getIntegerBitWidth(); - if (Idx->getType()->getPrimitiveSizeInBits() > PtrSize) + if (Idx->getType()->getPrimitiveSizeInBits().getFixedSize() > PtrSize) Idx = Builder.CreateTrunc(Idx, IntPtrTy); } @@ -422,7 +384,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, /// /// If we can't emit an optimized form for this expression, this returns null. /// -static Value *evaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, +static Value *evaluateGEPOffsetExpression(User *GEP, InstCombinerImpl &IC, const DataLayout &DL) { gep_type_iterator GTI = gep_type_begin(GEP); @@ -486,7 +448,8 @@ static Value *evaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, // Cast to intptrty in case a truncation occurs. If an extension is needed, // we don't need to bother extending: the extension won't affect where the // computation crosses zero. - if (VariableIdx->getType()->getPrimitiveSizeInBits() > IntPtrWidth) { + if (VariableIdx->getType()->getPrimitiveSizeInBits().getFixedSize() > + IntPtrWidth) { VariableIdx = IC.Builder.CreateTrunc(VariableIdx, IntPtrTy); } return VariableIdx; @@ -539,7 +502,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base, Value *V = WorkList.back(); - if (Explored.count(V) != 0) { + if (Explored.contains(V)) { WorkList.pop_back(); continue; } @@ -551,7 +514,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base, return false; if (isa<IntToPtrInst>(V) || isa<PtrToIntInst>(V)) { - auto *CI = dyn_cast<CastInst>(V); + auto *CI = cast<CastInst>(V); if (!CI->isNoopCast(DL)) return false; @@ -841,9 +804,9 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, /// Fold comparisons between a GEP instruction and something else. At this point /// we know that the GEP is on the LHS of the comparison. -Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, - ICmpInst::Predicate Cond, - Instruction &I) { +Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, + Instruction &I) { // Don't transform signed compares of GEPs into index compares. Even if the // GEP is inbounds, the final add of the base pointer can have signed overflow // and would change the result of the icmp. @@ -897,8 +860,8 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // For vectors, we apply the same reasoning on a per-lane basis. 
auto *Base = GEPLHS->getPointerOperand(); if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) { - int NumElts = cast<VectorType>(GEPLHS->getType())->getNumElements(); - Base = Builder.CreateVectorSplat(NumElts, Base); + auto EC = cast<VectorType>(GEPLHS->getType())->getElementCount(); + Base = Builder.CreateVectorSplat(EC, Base); } return new ICmpInst(Cond, Base, ConstantExpr::getPointerBitCastOrAddrSpaceCast( @@ -941,8 +904,8 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, Type *LHSIndexTy = LOffset->getType(); Type *RHSIndexTy = ROffset->getType(); if (LHSIndexTy != RHSIndexTy) { - if (LHSIndexTy->getPrimitiveSizeInBits() < - RHSIndexTy->getPrimitiveSizeInBits()) { + if (LHSIndexTy->getPrimitiveSizeInBits().getFixedSize() < + RHSIndexTy->getPrimitiveSizeInBits().getFixedSize()) { ROffset = Builder.CreateTrunc(ROffset, LHSIndexTy); } else LOffset = Builder.CreateTrunc(LOffset, RHSIndexTy); @@ -1021,9 +984,9 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); } -Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, - const AllocaInst *Alloca, - const Value *Other) { +Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI, + const AllocaInst *Alloca, + const Value *Other) { assert(ICI.isEquality() && "Cannot fold non-equality comparison."); // It would be tempting to fold away comparisons between allocas and any @@ -1099,8 +1062,8 @@ Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, } /// Fold "icmp pred (X+C), X". -Instruction *InstCombiner::foldICmpAddOpConst(Value *X, const APInt &C, - ICmpInst::Predicate Pred) { +Instruction *InstCombinerImpl::foldICmpAddOpConst(Value *X, const APInt &C, + ICmpInst::Predicate Pred) { // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similarly for all other "or equals" // operators. @@ -1149,9 +1112,9 @@ Instruction *InstCombiner::foldICmpAddOpConst(Value *X, const APInt &C, /// Handle "(icmp eq/ne (ashr/lshr AP2, A), AP1)" -> /// (icmp eq/ne A, Log2(AP2/AP1)) -> /// (icmp eq/ne A, Log2(AP2) - Log2(AP1)). -Instruction *InstCombiner::foldICmpShrConstConst(ICmpInst &I, Value *A, - const APInt &AP1, - const APInt &AP2) { +Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A, + const APInt &AP1, + const APInt &AP2) { assert(I.isEquality() && "Cannot fold icmp gt/lt"); auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { @@ -1208,9 +1171,9 @@ Instruction *InstCombiner::foldICmpShrConstConst(ICmpInst &I, Value *A, /// Handle "(icmp eq/ne (shl AP2, A), AP1)" -> /// (icmp eq/ne A, TrailingZeros(AP1) - TrailingZeros(AP2)). -Instruction *InstCombiner::foldICmpShlConstConst(ICmpInst &I, Value *A, - const APInt &AP1, - const APInt &AP2) { +Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A, + const APInt &AP1, + const APInt &AP2) { assert(I.isEquality() && "Cannot fold icmp gt/lt"); auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { @@ -1254,7 +1217,7 @@ Instruction *InstCombiner::foldICmpShlConstConst(ICmpInst &I, Value *A, /// static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, ConstantInt *CI2, ConstantInt *CI1, - InstCombiner &IC) { + InstCombinerImpl &IC) { // The transformation we're trying to do here is to transform this into an // llvm.sadd.with.overflow. 
To do this, we have to replace the original add // with a narrower add, and discard the add-with-constant that is part of the @@ -1340,7 +1303,7 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, /// icmp eq/ne (urem/srem %x, %y), 0 /// iff %y is a power-of-two, we can replace this with a bit test: /// icmp eq/ne (and %x, (add %y, -1)), 0 -Instruction *InstCombiner::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) { +Instruction *InstCombinerImpl::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) { // This fold is only valid for equality predicates. if (!I.isEquality()) return nullptr; @@ -1359,7 +1322,7 @@ Instruction *InstCombiner::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) { /// Fold equality-comparison between zero and any (maybe truncated) right-shift /// by one-less-than-bitwidth into a sign test on the original value. -Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) { +Instruction *InstCombinerImpl::foldSignBitTest(ICmpInst &I) { Instruction *Val; ICmpInst::Predicate Pred; if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero()))) @@ -1390,7 +1353,7 @@ Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) { } // Handle icmp pred X, 0 -Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { +Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) { CmpInst::Predicate Pred = Cmp.getPredicate(); if (!match(Cmp.getOperand(1), m_Zero())) return nullptr; @@ -1431,7 +1394,7 @@ Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { /// should be moved to some other helper and extended as noted below (it is also /// possible that code has been made unnecessary - do we canonicalize IR to /// overflow/saturating intrinsics or not?). -Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { +Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) { // Match the following pattern, which is a common idiom when writing // overflow-safe integer arithmetic functions. The source performs an addition // in wider type and explicitly checks for overflow using comparisons against @@ -1477,7 +1440,7 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { } /// Canonicalize icmp instructions based on dominating conditions. -Instruction *InstCombiner::foldICmpWithDominatingICmp(ICmpInst &Cmp) { +Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { // This is a cheap/incomplete check for dominance - just match a single // predecessor with a conditional branch. BasicBlock *CmpBB = Cmp.getParent(); @@ -1547,9 +1510,9 @@ Instruction *InstCombiner::foldICmpWithDominatingICmp(ICmpInst &Cmp) { } /// Fold icmp (trunc X, Y), C. -Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, - TruncInst *Trunc, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, + TruncInst *Trunc, + const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Trunc->getOperand(0); if (C.isOneValue() && C.getBitWidth() > 1) { @@ -1580,9 +1543,9 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, } /// Fold icmp (xor X, Y), C. 
-Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, - BinaryOperator *Xor, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp, + BinaryOperator *Xor, + const APInt &C) { Value *X = Xor->getOperand(0); Value *Y = Xor->getOperand(1); const APInt *XorC; @@ -1612,15 +1575,13 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, if (Xor->hasOneUse()) { // (icmp u/s (xor X SignMask), C) -> (icmp s/u X, (xor C SignMask)) if (!Cmp.isEquality() && XorC->isSignMask()) { - Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() - : Cmp.getSignedPredicate(); + Pred = Cmp.getFlippedSignednessPredicate(); return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC)); } // (icmp u/s (xor X ~SignMask), C) -> (icmp s/u X, (xor C ~SignMask)) if (!Cmp.isEquality() && XorC->isMaxSignedValue()) { - Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() - : Cmp.getSignedPredicate(); + Pred = Cmp.getFlippedSignednessPredicate(); Pred = Cmp.getSwappedPredicate(Pred); return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC)); } @@ -1649,8 +1610,10 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, } /// Fold icmp (and (sh X, Y), C2), C1. -Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, - const APInt &C1, const APInt &C2) { +Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp, + BinaryOperator *And, + const APInt &C1, + const APInt &C2) { BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0)); if (!Shift || !Shift->isShift()) return nullptr; @@ -1733,9 +1696,9 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, } /// Fold icmp (and X, C2), C1. -Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, - BinaryOperator *And, - const APInt &C1) { +Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, + BinaryOperator *And, + const APInt &C1) { bool isICMP_NE = Cmp.getPredicate() == ICmpInst::ICMP_NE; // For vectors: icmp ne (and X, 1), 0 --> trunc X to N x i1 @@ -1841,9 +1804,9 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, } /// Fold icmp (and X, Y), C. -Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, - BinaryOperator *And, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, + BinaryOperator *And, + const APInt &C) { if (Instruction *I = foldICmpAndConstConst(Cmp, And, C)) return I; @@ -1883,7 +1846,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1); if (auto *AndVTy = dyn_cast<VectorType>(And->getType())) - NTy = FixedVectorType::get(NTy, AndVTy->getNumElements()); + NTy = VectorType::get(NTy, AndVTy->getElementCount()); Value *Trunc = Builder.CreateTrunc(X, NTy); auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE : CmpInst::ICMP_SLT; @@ -1895,8 +1858,9 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, } /// Fold icmp (or X, Y), C. 
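The xor-by-signmask rewrites above (now expressed via getFlippedSignednessPredicate) rely on the fact that flipping the sign bit maps signed order onto unsigned order. A minimal standalone C++ check of that identity, included only as an illustration:

    #include <cassert>
    #include <cstdint>

    int main() {
      const int32_t Vals[] = {INT32_MIN, -7, -1, 0, 1, 42, INT32_MAX};
      for (int32_t X : Vals)
        for (int32_t Y : Vals) {
          uint32_t UX = (uint32_t)X ^ 0x80000000u;
          uint32_t UY = (uint32_t)Y ^ 0x80000000u;
          // signed slt on X,Y <=> unsigned ult after xor'ing both with signmask
          assert((X < Y) == (UX < UY));
        }
      return 0;
    }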
-Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, + BinaryOperator *Or, + const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); if (C.isOneValue()) { // icmp slt signum(V) 1 --> icmp slt V, 1 @@ -1960,9 +1924,9 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, } /// Fold icmp (mul X, Y), C. -Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, - BinaryOperator *Mul, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, + BinaryOperator *Mul, + const APInt &C) { const APInt *MulC; if (!match(Mul->getOperand(1), m_APInt(MulC))) return nullptr; @@ -1977,6 +1941,21 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, Constant::getNullValue(Mul->getType())); } + // If the multiply does not wrap, try to divide the compare constant by the + // multiplication factor. + if (Cmp.isEquality() && !MulC->isNullValue()) { + // (mul nsw X, MulC) == C --> X == C /s MulC + if (Mul->hasNoSignedWrap() && C.srem(*MulC).isNullValue()) { + Constant *NewC = ConstantInt::get(Mul->getType(), C.sdiv(*MulC)); + return new ICmpInst(Pred, Mul->getOperand(0), NewC); + } + // (mul nuw X, MulC) == C --> X == C /u MulC + if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isNullValue()) { + Constant *NewC = ConstantInt::get(Mul->getType(), C.udiv(*MulC)); + return new ICmpInst(Pred, Mul->getOperand(0), NewC); + } + } + return nullptr; } @@ -2043,9 +2022,9 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, } /// Fold icmp (shl X, Y), C. -Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, - BinaryOperator *Shl, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp, + BinaryOperator *Shl, + const APInt &C) { const APInt *ShiftVal; if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal))) return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal); @@ -2173,7 +2152,7 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, DL.isLegalInteger(TypeBits - Amt)) { Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); if (auto *ShVTy = dyn_cast<VectorType>(ShType)) - TruncTy = FixedVectorType::get(TruncTy, ShVTy->getNumElements()); + TruncTy = VectorType::get(TruncTy, ShVTy->getElementCount()); Constant *NewC = ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt)); return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC); @@ -2183,9 +2162,9 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, } /// Fold icmp ({al}shr X, Y), C. 
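The new equality fold in foldICmpMulConstant above divides the compare constant by the non-wrapping multiplier. A small standalone C++ sketch of why that is valid (not part of the patch):

    #include <cassert>

    int main() {
      // When the multiply cannot wrap, "X * MulC == C" can only hold if C is
      // divisible by MulC, and then it collapses to "X == C / MulC".
      const int MulC = 12;
      for (int X = -1000; X <= 1000; ++X) {
        int Prod = X * MulC; // never overflows for this range, i.e. "nsw" holds
        assert((Prod == 84) == (84 % MulC == 0 && X == 84 / MulC));
        assert((Prod == 30) == (30 % MulC == 0 && X == 30 / MulC)); // 30 % 12 != 0
      }
      return 0;
    }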
-Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, - BinaryOperator *Shr, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, + BinaryOperator *Shr, + const APInt &C) { // An exact shr only shifts out zero bits, so: // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0 Value *X = Shr->getOperand(0); @@ -2231,6 +2210,21 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, (ShiftedC + 1).ashr(ShAmtVal) == (C + 1)) return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); } + + // If the compare constant has significant bits above the lowest sign-bit, + // then convert an unsigned cmp to a test of the sign-bit: + // (ashr X, ShiftC) u> C --> X s< 0 + // (ashr X, ShiftC) u< C --> X s> -1 + if (C.getBitWidth() > 2 && C.getNumSignBits() <= ShAmtVal) { + if (Pred == CmpInst::ICMP_UGT) { + return new ICmpInst(CmpInst::ICMP_SLT, X, + ConstantInt::getNullValue(ShrTy)); + } + if (Pred == CmpInst::ICMP_ULT) { + return new ICmpInst(CmpInst::ICMP_SGT, X, + ConstantInt::getAllOnesValue(ShrTy)); + } + } } else { if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) { // icmp ult (lshr X, ShAmtC), C --> icmp ult X, (C << ShAmtC) @@ -2276,9 +2270,9 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, return nullptr; } -Instruction *InstCombiner::foldICmpSRemConstant(ICmpInst &Cmp, - BinaryOperator *SRem, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp, + BinaryOperator *SRem, + const APInt &C) { // Match an 'is positive' or 'is negative' comparison of remainder by a // constant power-of-2 value: // (X % pow2C) sgt/slt 0 @@ -2315,9 +2309,9 @@ Instruction *InstCombiner::foldICmpSRemConstant(ICmpInst &Cmp, } /// Fold icmp (udiv X, Y), C. -Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, - BinaryOperator *UDiv, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp, + BinaryOperator *UDiv, + const APInt &C) { const APInt *C2; if (!match(UDiv->getOperand(0), m_APInt(C2))) return nullptr; @@ -2344,9 +2338,9 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, } /// Fold icmp ({su}div X, Y), C. -Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, - BinaryOperator *Div, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, + BinaryOperator *Div, + const APInt &C) { // Fold: icmp pred ([us]div X, C2), C -> range test // Fold this div into the comparison, producing a range check. // Determine, based on the divide type, what the range is being @@ -2514,9 +2508,9 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, } /// Fold icmp (sub X, Y), C. -Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, - BinaryOperator *Sub, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp, + BinaryOperator *Sub, + const APInt &C) { Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1); ICmpInst::Predicate Pred = Cmp.getPredicate(); const APInt *C2; @@ -2576,9 +2570,9 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, } /// Fold icmp (add X, Y), C. 
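The ashr hunk above also adds a fold that turns an unsigned compare of an arithmetic shift into a plain sign test. One concrete instance, sketched in standalone C++ (assumes >> on a negative value is an arithmetic shift, guaranteed since C++20); this only demonstrates a single case of the more general precondition in the patch:

    #include <cassert>
    #include <cstdint>

    int main() {
      // "(ashr X, 8) u> 0x40000000 --> X s< 0": a non-negative X shifted right
      // by 8 stays below 2^24, while a negative X has its top 8 bits set, so
      // the unsigned compare is really a sign test.
      const int32_t Vals[] = {INT32_MIN, -12345, -1, 0, 1, 12345, INT32_MAX};
      for (int32_t X : Vals) {
        uint32_t Shr = (uint32_t)(X >> 8);
        assert((Shr > 0x40000000u) == (X < 0));
      }
      return 0;
    }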
-Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, - BinaryOperator *Add, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, + BinaryOperator *Add, + const APInt &C) { Value *Y = Add->getOperand(1); const APInt *C2; if (Cmp.isEquality() || !match(Y, m_APInt(C2))) @@ -2642,10 +2636,10 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, return nullptr; } -bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, - Value *&RHS, ConstantInt *&Less, - ConstantInt *&Equal, - ConstantInt *&Greater) { +bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, + Value *&RHS, ConstantInt *&Less, + ConstantInt *&Equal, + ConstantInt *&Greater) { // TODO: Generalize this to work with other comparison idioms or ensure // they get canonicalized into this form. @@ -2682,7 +2676,8 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) { // x sgt C-1 <--> x sge C <--> not(x slt C) auto FlippedStrictness = - getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2)); + InstCombiner::getFlippedStrictnessPredicateAndConstant( + PredB, cast<Constant>(RHS2)); if (!FlippedStrictness) return false; assert(FlippedStrictness->first == ICmpInst::ICMP_SGE && "Sanity check"); @@ -2694,9 +2689,9 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, return PredB == ICmpInst::ICMP_SLT && RHS == RHS2; } -Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, - SelectInst *Select, - ConstantInt *C) { +Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp, + SelectInst *Select, + ConstantInt *C) { assert(C && "Cmp RHS should be a constant int!"); // If we're testing a constant value against the result of a three way @@ -2794,7 +2789,7 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp, const APInt *C; bool TrueIfSigned; if (match(Op1, m_APInt(C)) && Bitcast->hasOneUse() && - isSignBitCheck(Pred, *C, TrueIfSigned)) { + InstCombiner::isSignBitCheck(Pred, *C, TrueIfSigned)) { if (match(BCSrcOp, m_FPExt(m_Value(X))) || match(BCSrcOp, m_FPTrunc(m_Value(X)))) { // (bitcast (fpext/fptrunc X)) to iX) < 0 --> (bitcast X to iY) < 0 @@ -2806,7 +2801,7 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp, Type *NewType = Builder.getIntNTy(XType->getScalarSizeInBits()); if (auto *XVTy = dyn_cast<VectorType>(XType)) - NewType = FixedVectorType::get(NewType, XVTy->getNumElements()); + NewType = VectorType::get(NewType, XVTy->getElementCount()); Value *NewBitcast = Builder.CreateBitCast(X, NewType); if (TrueIfSigned) return new ICmpInst(ICmpInst::ICMP_SLT, NewBitcast, @@ -2870,7 +2865,7 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp, /// Try to fold integer comparisons with a constant operand: icmp Pred X, C /// where X is some kind of instruction. -Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { +Instruction *InstCombinerImpl::foldICmpInstWithConstant(ICmpInst &Cmp) { const APInt *C; if (!match(Cmp.getOperand(1), m_APInt(C))) return nullptr; @@ -2955,9 +2950,8 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { /// Fold an icmp equality instruction with binary operator LHS and constant RHS: /// icmp eq/ne BO, C. 
-Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, - BinaryOperator *BO, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( + ICmpInst &Cmp, BinaryOperator *BO, const APInt &C) { // TODO: Some of these folds could work with arbitrary constants, but this // function is limited to scalar and vector splat constants. if (!Cmp.isEquality()) @@ -3047,17 +3041,6 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, } break; } - case Instruction::Mul: - if (C.isNullValue() && BO->hasNoSignedWrap()) { - const APInt *BOC; - if (match(BOp1, m_APInt(BOC)) && !BOC->isNullValue()) { - // The trivial case (mul X, 0) is handled by InstSimplify. - // General case : (mul X, C) != 0 iff X != 0 - // (mul X, C) == 0 iff X == 0 - return new ICmpInst(Pred, BOp0, Constant::getNullValue(RHS->getType())); - } - } - break; case Instruction::UDiv: if (C.isNullValue()) { // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) @@ -3072,12 +3055,19 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, } /// Fold an equality icmp with LLVM intrinsic and constant operand. -Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp, - IntrinsicInst *II, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( + ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) { Type *Ty = II->getType(); unsigned BitWidth = C.getBitWidth(); switch (II->getIntrinsicID()) { + case Intrinsic::abs: + // abs(A) == 0 -> A == 0 + // abs(A) == INT_MIN -> A == INT_MIN + if (C.isNullValue() || C.isMinSignedValue()) + return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), + ConstantInt::get(Ty, C)); + break; + case Intrinsic::bswap: // bswap(A) == C -> A == bswap(C) return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), @@ -3145,18 +3135,31 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp, } /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. 
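The new llvm.abs equality folds above hold because a two's-complement absolute value fixes exactly 0 and INT_MIN. A standalone C++ sketch with a wrapping abs helper (the helper name is illustrative and mirrors llvm.abs only when its int_min_is_poison flag is false):

    #include <cassert>
    #include <cstdint>

    uint32_t absWrap(int32_t X) { return X < 0 ? 0u - (uint32_t)X : (uint32_t)X; }

    int main() {
      const int32_t Vals[] = {INT32_MIN, -7, -1, 0, 1, 42, INT32_MAX};
      for (int32_t X : Vals) {
        assert((absWrap(X) == 0u) == (X == 0));                  // abs(A) == 0 -> A == 0
        assert((absWrap(X) == 0x80000000u) == (X == INT32_MIN)); // abs(A) == INT_MIN -> A == INT_MIN
      }
      return 0;
    }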
-Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, - IntrinsicInst *II, - const APInt &C) { +Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, + IntrinsicInst *II, + const APInt &C) { if (Cmp.isEquality()) return foldICmpEqIntrinsicWithConstant(Cmp, II, C); Type *Ty = II->getType(); unsigned BitWidth = C.getBitWidth(); + ICmpInst::Predicate Pred = Cmp.getPredicate(); switch (II->getIntrinsicID()) { + case Intrinsic::ctpop: { + // (ctpop X > BitWidth - 1) --> X == -1 + Value *X = II->getArgOperand(0); + if (C == BitWidth - 1 && Pred == ICmpInst::ICMP_UGT) + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, X, + ConstantInt::getAllOnesValue(Ty)); + // (ctpop X < BitWidth) --> X != -1 + if (C == BitWidth && Pred == ICmpInst::ICMP_ULT) + return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, X, + ConstantInt::getAllOnesValue(Ty)); + break; + } case Intrinsic::ctlz: { // ctlz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX < 0b00010000 - if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && C.ult(BitWidth)) { + if (Pred == ICmpInst::ICMP_UGT && C.ult(BitWidth)) { unsigned Num = C.getLimitedValue(); APInt Limit = APInt::getOneBitSet(BitWidth, BitWidth - Num - 1); return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT, @@ -3164,8 +3167,7 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, } // ctlz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX > 0b00011111 - if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && - C.uge(1) && C.ule(BitWidth)) { + if (Pred == ICmpInst::ICMP_ULT && C.uge(1) && C.ule(BitWidth)) { unsigned Num = C.getLimitedValue(); APInt Limit = APInt::getLowBitsSet(BitWidth, BitWidth - Num); return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT, @@ -3179,7 +3181,7 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, return nullptr; // cttz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX & 0b00001111 == 0 - if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && C.ult(BitWidth)) { + if (Pred == ICmpInst::ICMP_UGT && C.ult(BitWidth)) { APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue() + 1); return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, Builder.CreateAnd(II->getArgOperand(0), Mask), @@ -3187,8 +3189,7 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, } // cttz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX & 0b00000111 != 0 - if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && - C.uge(1) && C.ule(BitWidth)) { + if (Pred == ICmpInst::ICMP_ULT && C.uge(1) && C.ule(BitWidth)) { APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue()); return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE, Builder.CreateAnd(II->getArgOperand(0), Mask), @@ -3204,7 +3205,7 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, } /// Handle icmp with constant (but not simple integer constant) RHS. -Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { +Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Constant *RHSC = dyn_cast<Constant>(Op1); Instruction *LHSI = dyn_cast<Instruction>(Op0); @@ -3383,8 +3384,8 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, // those elements by copying an existing, defined, and safe scalar constant. 
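The ctpop compare folds added in the hunk above follow from a simple counting argument: only the all-ones value has a full population count. A minimal standalone C++ check (std::popcount requires C++20), included purely as an illustration:

    #include <bit>
    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t Vals[] = {0u, 1u, 0x7FFFFFFFu, 0xFFFFFFFEu, 0xFFFFFFFFu};
      for (uint32_t X : Vals) {
        assert((std::popcount(X) > 31) == (X == 0xFFFFFFFFu)); // ctpop X u> 31 --> X == -1
        assert((std::popcount(X) < 32) == (X != 0xFFFFFFFFu)); // ctpop X u< 32 --> X != -1
      }
      return 0;
    }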
Type *OpTy = M->getType(); auto *VecC = dyn_cast<Constant>(M); - if (OpTy->isVectorTy() && VecC && VecC->containsUndefElement()) { - auto *OpVTy = cast<VectorType>(OpTy); + auto *OpVTy = dyn_cast<FixedVectorType>(OpTy); + if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) { Constant *SafeReplacementConstant = nullptr; for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) { if (!isa<UndefValue>(VecC->getAggregateElement(i))) { @@ -3650,7 +3651,7 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, /// @llvm.umul.with.overflow(x, y) plus extraction of overflow bit /// Note that the comparison is commutative, while inverted (u>=, ==) predicate /// will mean that we are looking for the opposite answer. -Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { +Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { ICmpInst::Predicate Pred; Value *X, *Y; Instruction *Mul; @@ -3712,11 +3713,28 @@ Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { return Res; } +static Instruction *foldICmpXNegX(ICmpInst &I) { + CmpInst::Predicate Pred; + Value *X; + if (!match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X)))) + return nullptr; + + if (ICmpInst::isSigned(Pred)) + Pred = ICmpInst::getSwappedPredicate(Pred); + else if (ICmpInst::isUnsigned(Pred)) + Pred = ICmpInst::getSignedPredicate(Pred); + // else for equality-comparisons just keep the predicate. + + return ICmpInst::Create(Instruction::ICmp, Pred, X, + Constant::getNullValue(X->getType()), I.getName()); +} + /// Try to fold icmp (binop), X or icmp X, (binop). /// TODO: A large part of this logic is duplicated in InstSimplify's /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code /// duplication. -Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) { +Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, + const SimplifyQuery &SQ) { const SimplifyQuery Q = SQ.getWithInstruction(&I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -3726,6 +3744,9 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) { if (!BO0 && !BO1) return nullptr; + if (Instruction *NewICmp = foldICmpXNegX(I)) + return NewICmp; + const CmpInst::Predicate Pred = I.getPredicate(); Value *X; @@ -3946,6 +3967,19 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) { ConstantExpr::getNeg(RHSC)); } + { + // Try to remove shared constant multiplier from equality comparison: + // X * C == Y * C (with no overflowing/aliasing) --> X == Y + Value *X, *Y; + const APInt *C; + if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 && + match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality()) + if (!C->countTrailingZeros() || + (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) || + (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) + return new ICmpInst(Pred, X, Y); + } + BinaryOperator *SRem = nullptr; // icmp (srem X, Y), Y if (BO0 && BO0->getOpcode() == Instruction::SRem && Op1 == BO0->getOperand(1)) @@ -3990,15 +4024,13 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) { if (match(BO0->getOperand(1), m_APInt(C))) { // icmp u/s (a ^ signmask), (b ^ signmask) --> icmp s/u a, b if (C->isSignMask()) { - ICmpInst::Predicate NewPred = - I.isSigned() ? 
I.getUnsignedPredicate() : I.getSignedPredicate(); + ICmpInst::Predicate NewPred = I.getFlippedSignednessPredicate(); return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0)); } // icmp u/s (a ^ maxsignval), (b ^ maxsignval) --> icmp s/u' a, b if (BO0->getOpcode() == Instruction::Xor && C->isMaxSignedValue()) { - ICmpInst::Predicate NewPred = - I.isSigned() ? I.getUnsignedPredicate() : I.getSignedPredicate(); + ICmpInst::Predicate NewPred = I.getFlippedSignednessPredicate(); NewPred = I.getSwappedPredicate(NewPred); return new ICmpInst(NewPred, BO0->getOperand(0), BO1->getOperand(0)); } @@ -4022,10 +4054,6 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) { Value *And2 = Builder.CreateAnd(BO1->getOperand(0), Mask); return new ICmpInst(Pred, And1, And2); } - // If there are no trailing zeros in the multiplier, just eliminate - // the multiplies (no masking is needed): - // icmp eq/ne (X * C), (Y * C) --> icmp eq/ne X, Y - return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); } break; } @@ -4170,7 +4198,7 @@ static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) { return nullptr; } -Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { +Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { if (!I.isEquality()) return nullptr; @@ -4438,7 +4466,7 @@ static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp, } /// Handle icmp (cast x), (cast or constant). -Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) { +Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0)); if (!CastOp0) return nullptr; @@ -4493,9 +4521,10 @@ static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { } } -OverflowResult InstCombiner::computeOverflow( - Instruction::BinaryOps BinaryOp, bool IsSigned, - Value *LHS, Value *RHS, Instruction *CxtI) const { +OverflowResult +InstCombinerImpl::computeOverflow(Instruction::BinaryOps BinaryOp, + bool IsSigned, Value *LHS, Value *RHS, + Instruction *CxtI) const { switch (BinaryOp) { default: llvm_unreachable("Unsupported binary op"); @@ -4517,9 +4546,11 @@ OverflowResult InstCombiner::computeOverflow( } } -bool InstCombiner::OptimizeOverflowCheck( - Instruction::BinaryOps BinaryOp, bool IsSigned, Value *LHS, Value *RHS, - Instruction &OrigI, Value *&Result, Constant *&Overflow) { +bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp, + bool IsSigned, Value *LHS, + Value *RHS, Instruction &OrigI, + Value *&Result, + Constant *&Overflow) { if (OrigI.isCommutative() && isa<Constant>(LHS) && !isa<Constant>(RHS)) std::swap(LHS, RHS); @@ -4529,9 +4560,13 @@ bool InstCombiner::OptimizeOverflowCheck( // compare. 
Builder.SetInsertPoint(&OrigI); + Type *OverflowTy = Type::getInt1Ty(LHS->getContext()); + if (auto *LHSTy = dyn_cast<VectorType>(LHS->getType())) + OverflowTy = VectorType::get(OverflowTy, LHSTy->getElementCount()); + if (isNeutralValue(BinaryOp, RHS)) { Result = LHS; - Overflow = Builder.getFalse(); + Overflow = ConstantInt::getFalse(OverflowTy); return true; } @@ -4542,12 +4577,12 @@ bool InstCombiner::OptimizeOverflowCheck( case OverflowResult::AlwaysOverflowsHigh: Result = Builder.CreateBinOp(BinaryOp, LHS, RHS); Result->takeName(&OrigI); - Overflow = Builder.getTrue(); + Overflow = ConstantInt::getTrue(OverflowTy); return true; case OverflowResult::NeverOverflows: Result = Builder.CreateBinOp(BinaryOp, LHS, RHS); Result->takeName(&OrigI); - Overflow = Builder.getFalse(); + Overflow = ConstantInt::getFalse(OverflowTy); if (auto *Inst = dyn_cast<Instruction>(Result)) { if (IsSigned) Inst->setHasNoSignedWrap(); @@ -4575,7 +4610,8 @@ bool InstCombiner::OptimizeOverflowCheck( /// \returns Instruction which must replace the compare instruction, NULL if no /// replacement required. static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, - Value *OtherVal, InstCombiner &IC) { + Value *OtherVal, + InstCombinerImpl &IC) { // Don't bother doing this transformation for pointers, don't do it for // vectors. if (!isa<IntegerType>(MulVal->getType())) @@ -4723,15 +4759,14 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, Function *F = Intrinsic::getDeclaration( I.getModule(), Intrinsic::umul_with_overflow, MulType); CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul"); - IC.Worklist.push(MulInstr); + IC.addToWorklist(MulInstr); // If there are uses of mul result other than the comparison, we know that // they are truncation or binary AND. Change them to use result of // mul.with.overflow and adjust properly mask/size. if (MulVal->hasNUsesOrMore(2)) { Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value"); - for (auto UI = MulVal->user_begin(), UE = MulVal->user_end(); UI != UE;) { - User *U = *UI++; + for (User *U : make_early_inc_range(MulVal->users())) { if (U == &I || U == OtherVal) continue; if (TruncInst *TI = dyn_cast<TruncInst>(U)) { @@ -4750,11 +4785,11 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, } else { llvm_unreachable("Unexpected Binary operation"); } - IC.Worklist.push(cast<Instruction>(U)); + IC.addToWorklist(cast<Instruction>(U)); } } if (isa<Instruction>(OtherVal)) - IC.Worklist.push(cast<Instruction>(OtherVal)); + IC.addToWorklist(cast<Instruction>(OtherVal)); // The original icmp gets replaced with the overflow value, maybe inverted // depending on predicate. @@ -4799,7 +4834,7 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { // If this is a normal comparison, it demands all bits. If it is a sign bit // comparison, it only demands the sign bit. bool UnusedBit; - if (isSignBitCheck(I.getPredicate(), *RHS, UnusedBit)) + if (InstCombiner::isSignBitCheck(I.getPredicate(), *RHS, UnusedBit)) return APInt::getSignMask(BitWidth); switch (I.getPredicate()) { @@ -4856,9 +4891,9 @@ static bool swapMayExposeCSEOpportunities(const Value *Op0, const Value *Op1) { /// \return true when \p UI is the only use of \p DI in the parent block /// and all other uses of \p DI are in blocks dominated by \p DB. 
/// -bool InstCombiner::dominatesAllUses(const Instruction *DI, - const Instruction *UI, - const BasicBlock *DB) const { +bool InstCombinerImpl::dominatesAllUses(const Instruction *DI, + const Instruction *UI, + const BasicBlock *DB) const { assert(DI && UI && "Instruction not defined\n"); // Ignore incomplete definitions. if (!DI->getParent()) @@ -4931,9 +4966,9 @@ static bool isChainSelectCmpBranch(const SelectInst *SI) { /// major restriction since a NE compare should be 'normalized' to an equal /// compare, which usually happens in the combiner and test case /// select-cmp-br.ll checks for it. -bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, - const ICmpInst *Icmp, - const unsigned SIOpd) { +bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI, + const ICmpInst *Icmp, + const unsigned SIOpd) { assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!"); if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) { BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1); @@ -4959,7 +4994,7 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, /// Try to fold the comparison based on range information we can get by checking /// whether bits are known to be zero or one in the inputs. -Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { +Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = Op0->getType(); ICmpInst::Predicate Pred = I.getPredicate(); @@ -4990,11 +5025,15 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); if (I.isSigned()) { - computeSignedMinMaxValuesFromKnownBits(Op0Known, Op0Min, Op0Max); - computeSignedMinMaxValuesFromKnownBits(Op1Known, Op1Min, Op1Max); + Op0Min = Op0Known.getSignedMinValue(); + Op0Max = Op0Known.getSignedMaxValue(); + Op1Min = Op1Known.getSignedMinValue(); + Op1Max = Op1Known.getSignedMaxValue(); } else { - computeUnsignedMinMaxValuesFromKnownBits(Op0Known, Op0Min, Op0Max); - computeUnsignedMinMaxValuesFromKnownBits(Op1Known, Op1Min, Op1Max); + Op0Min = Op0Known.getMinValue(); + Op0Max = Op0Known.getMaxValue(); + Op1Min = Op1Known.getMinValue(); + Op1Max = Op1Known.getMaxValue(); } // If Min and Max are known to be the same, then SimplifyDemandedBits figured @@ -5012,11 +5051,9 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { llvm_unreachable("Unknown icmp opcode!"); case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_NE: { - if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) { - return Pred == CmpInst::ICMP_EQ - ? replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())) - : replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - } + if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) + return replaceInstUsesWith( + I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE)); // If all bits are known zero except for one, then we know at most one bit // is set. 
If the comparison is against zero, then this is a check to see if @@ -5186,8 +5223,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { } llvm::Optional<std::pair<CmpInst::Predicate, Constant *>> -llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, - Constant *C) { +InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, + Constant *C) { assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && "Only for relational integer predicates."); @@ -5209,8 +5246,8 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, // Bail out if the constant can't be safely incremented/decremented. if (!ConstantIsOk(CI)) return llvm::None; - } else if (auto *VTy = dyn_cast<VectorType>(Type)) { - unsigned NumElts = VTy->getNumElements(); + } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) { + unsigned NumElts = FVTy->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *Elt = C->getAggregateElement(i); if (!Elt) @@ -5236,7 +5273,8 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, // It may not be safe to change a compare predicate in the presence of // undefined elements, so replace those elements with the first safe constant // that we found. - if (C->containsUndefElement()) { + // TODO: in case of poison, it is safe; let's replace undefs only. + if (C->containsUndefOrPoisonElement()) { assert(SafeReplacementConstant && "Replacement constant not set"); C = Constant::replaceUndefsWith(C, SafeReplacementConstant); } @@ -5256,7 +5294,7 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { ICmpInst::Predicate Pred = I.getPredicate(); if (ICmpInst::isEquality(Pred) || !ICmpInst::isIntPredicate(Pred) || - isCanonicalPredicate(Pred)) + InstCombiner::isCanonicalPredicate(Pred)) return nullptr; Value *Op0 = I.getOperand(0); @@ -5265,7 +5303,8 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { if (!Op1C) return nullptr; - auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C); + auto FlippedStrictness = + InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, Op1C); if (!FlippedStrictness) return nullptr; @@ -5274,14 +5313,14 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { /// If we have a comparison with a non-canonical predicate, if we can update /// all the users, invert the predicate and adjust all the users. -static CmpInst *canonicalizeICmpPredicate(CmpInst &I) { +CmpInst *InstCombinerImpl::canonicalizeICmpPredicate(CmpInst &I) { // Is the predicate already canonical? CmpInst::Predicate Pred = I.getPredicate(); - if (isCanonicalPredicate(Pred)) + if (InstCombiner::isCanonicalPredicate(Pred)) return nullptr; // Can all users be adjusted to predicate inversion? - if (!canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) + if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) return nullptr; // Ok, we can canonicalize comparison! @@ -5289,26 +5328,8 @@ static CmpInst *canonicalizeICmpPredicate(CmpInst &I) { I.setPredicate(CmpInst::getInversePredicate(Pred)); I.setName(I.getName() + ".not"); - // And now let's adjust every user. 
- for (User *U : I.users()) { - switch (cast<Instruction>(U)->getOpcode()) { - case Instruction::Select: { - auto *SI = cast<SelectInst>(U); - SI->swapValues(); - SI->swapProfMetadata(); - break; - } - case Instruction::Br: - cast<BranchInst>(U)->swapSuccessors(); // swaps prof metadata too - break; - case Instruction::Xor: - U->replaceAllUsesWith(&I); - break; - default: - llvm_unreachable("Got unexpected user - out of sync with " - "canFreelyInvertAllUsersOf() ?"); - } - } + // And, adapt users. + freelyInvertAllUsersOf(&I); return &I; } @@ -5510,7 +5531,7 @@ static Instruction *foldICmpOfUAddOv(ICmpInst &I) { return ExtractValueInst::Create(UAddOv, 1); } -Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { +Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { bool Changed = false; const SimplifyQuery Q = SQ.getWithInstruction(&I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -5634,10 +5655,10 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // Try to optimize equality comparisons against alloca-based pointers. if (Op0->getType()->isPointerTy() && I.isEquality()) { assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?"); - if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op0, DL))) + if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op0))) if (Instruction *New = foldAllocaCmp(I, Alloca, Op1)) return New; - if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op1, DL))) + if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op1))) if (Instruction *New = foldAllocaCmp(I, Alloca, Op0)) return New; } @@ -5748,8 +5769,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } /// Fold fcmp ([us]itofp x, cst) if possible. -Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, - Constant *RHSC) { +Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, + Instruction *LHSI, + Constant *RHSC) { if (!isa<ConstantFP>(RHSC)) return nullptr; const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF(); @@ -6034,9 +6056,9 @@ static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI, } /// Optimize fabs(X) compared with zero. 
-static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombiner &IC) { +static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) { Value *X; - if (!match(I.getOperand(0), m_Intrinsic<Intrinsic::fabs>(m_Value(X))) || + if (!match(I.getOperand(0), m_FAbs(m_Value(X))) || !match(I.getOperand(1), m_PosZeroFP())) return nullptr; @@ -6096,7 +6118,7 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombiner &IC) { } } -Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { +Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { bool Changed = false; /// Orders the operands of the compare so that they are listed from most diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index f918dc7198ca..79e9d5c46c70 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -15,40 +15,32 @@ #ifndef LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H #define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Argument.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constant.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/IR/Use.h" -#include "llvm/IR/Value.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> -#include <cstdint> #define DEBUG_TYPE "instcombine" using namespace llvm::PatternMatch; +// As a default, let's assume that we want to be aggressive, +// and attempt to traverse with no limits in attempt to sink negation. +static constexpr unsigned NegatorDefaultMaxDepth = ~0U; + +// Let's guesstimate that most often we will end up visiting/producing +// fairly small number of new instructions. +static constexpr unsigned NegatorMaxNodesSSO = 16; + namespace llvm { class AAResults; @@ -65,305 +57,26 @@ class ProfileSummaryInfo; class TargetLibraryInfo; class User; -/// Assign a complexity or rank value to LLVM Values. This is used to reduce -/// the amount of pattern matching needed for compares and commutative -/// instructions. For example, if we have: -/// icmp ugt X, Constant -/// or -/// xor (add X, Constant), cast Z -/// -/// We do not have to consider the commuted variants of these patterns because -/// canonicalization based on complexity guarantees the above ordering. 
-/// -/// This routine maps IR values to various complexity ranks: -/// 0 -> undef -/// 1 -> Constants -/// 2 -> Other non-instructions -/// 3 -> Arguments -/// 4 -> Cast and (f)neg/not instructions -/// 5 -> Other instructions -static inline unsigned getComplexity(Value *V) { - if (isa<Instruction>(V)) { - if (isa<CastInst>(V) || match(V, m_Neg(m_Value())) || - match(V, m_Not(m_Value())) || match(V, m_FNeg(m_Value()))) - return 4; - return 5; - } - if (isa<Argument>(V)) - return 3; - return isa<Constant>(V) ? (isa<UndefValue>(V) ? 0 : 1) : 2; -} - -/// Predicate canonicalization reduces the number of patterns that need to be -/// matched by other transforms. For example, we may swap the operands of a -/// conditional branch or select to create a compare with a canonical (inverted) -/// predicate which is then more likely to be matched with other values. -static inline bool isCanonicalPredicate(CmpInst::Predicate Pred) { - switch (Pred) { - case CmpInst::ICMP_NE: - case CmpInst::ICMP_ULE: - case CmpInst::ICMP_SLE: - case CmpInst::ICMP_UGE: - case CmpInst::ICMP_SGE: - // TODO: There are 16 FCMP predicates. Should others be (not) canonical? - case CmpInst::FCMP_ONE: - case CmpInst::FCMP_OLE: - case CmpInst::FCMP_OGE: - return false; - default: - return true; - } -} - -/// Given an exploded icmp instruction, return true if the comparison only -/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the -/// result of the comparison is true when the input value is signed. -inline bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS, - bool &TrueIfSigned) { - switch (Pred) { - case ICmpInst::ICMP_SLT: // True if LHS s< 0 - TrueIfSigned = true; - return RHS.isNullValue(); - case ICmpInst::ICMP_SLE: // True if LHS s<= -1 - TrueIfSigned = true; - return RHS.isAllOnesValue(); - case ICmpInst::ICMP_SGT: // True if LHS s> -1 - TrueIfSigned = false; - return RHS.isAllOnesValue(); - case ICmpInst::ICMP_SGE: // True if LHS s>= 0 - TrueIfSigned = false; - return RHS.isNullValue(); - case ICmpInst::ICMP_UGT: - // True if LHS u> RHS and RHS == sign-bit-mask - 1 - TrueIfSigned = true; - return RHS.isMaxSignedValue(); - case ICmpInst::ICMP_UGE: - // True if LHS u>= RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc) - TrueIfSigned = true; - return RHS.isMinSignedValue(); - case ICmpInst::ICMP_ULT: - // True if LHS u< RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc) - TrueIfSigned = false; - return RHS.isMinSignedValue(); - case ICmpInst::ICMP_ULE: - // True if LHS u<= RHS and RHS == sign-bit-mask - 1 - TrueIfSigned = false; - return RHS.isMaxSignedValue(); - default: - return false; - } -} - -llvm::Optional<std::pair<CmpInst::Predicate, Constant *>> -getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, Constant *C); - -/// Return the source operand of a potentially bitcasted value while optionally -/// checking if it has one use. If there is no bitcast or the one use check is -/// not met, return the input value itself. -static inline Value *peekThroughBitcast(Value *V, bool OneUseOnly = false) { - if (auto *BitCast = dyn_cast<BitCastInst>(V)) - if (!OneUseOnly || BitCast->hasOneUse()) - return BitCast->getOperand(0); - - // V is not a bitcast or V has more than one use and OneUseOnly is true. 
- return V; -} - -/// Add one to a Constant -static inline Constant *AddOne(Constant *C) { - return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1)); -} - -/// Subtract one from a Constant -static inline Constant *SubOne(Constant *C) { - return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); -} - -/// Return true if the specified value is free to invert (apply ~ to). -/// This happens in cases where the ~ can be eliminated. If WillInvertAllUses -/// is true, work under the assumption that the caller intends to remove all -/// uses of V and only keep uses of ~V. -/// -/// See also: canFreelyInvertAllUsersOf() -static inline bool isFreeToInvert(Value *V, bool WillInvertAllUses) { - // ~(~(X)) -> X. - if (match(V, m_Not(m_Value()))) - return true; - - // Constants can be considered to be not'ed values. - if (match(V, m_AnyIntegralConstant())) - return true; - - // Compares can be inverted if all of their uses are being modified to use the - // ~V. - if (isa<CmpInst>(V)) - return WillInvertAllUses; - - // If `V` is of the form `A + Constant` then `-1 - V` can be folded into `(-1 - // - Constant) - A` if we are willing to invert all of the uses. - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(V)) - if (BO->getOpcode() == Instruction::Add || - BO->getOpcode() == Instruction::Sub) - if (isa<Constant>(BO->getOperand(0)) || isa<Constant>(BO->getOperand(1))) - return WillInvertAllUses; - - // Selects with invertible operands are freely invertible - if (match(V, m_Select(m_Value(), m_Not(m_Value()), m_Not(m_Value())))) - return WillInvertAllUses; - - return false; -} - -/// Given i1 V, can every user of V be freely adapted if V is changed to !V ? -/// InstCombine's canonicalizeICmpPredicate() must be kept in sync with this fn. -/// -/// See also: isFreeToInvert() -static inline bool canFreelyInvertAllUsersOf(Value *V, Value *IgnoredUser) { - // Look at every user of V. - for (Use &U : V->uses()) { - if (U.getUser() == IgnoredUser) - continue; // Don't consider this user. - - auto *I = cast<Instruction>(U.getUser()); - switch (I->getOpcode()) { - case Instruction::Select: - if (U.getOperandNo() != 0) // Only if the value is used as select cond. - return false; - break; - case Instruction::Br: - assert(U.getOperandNo() == 0 && "Must be branching on that value."); - break; // Free to invert by swapping true/false values/destinations. - case Instruction::Xor: // Can invert 'xor' if it's a 'not', by ignoring it. - if (!match(I, m_Not(m_Value()))) - return false; // Not a 'not'. - break; - default: - return false; // Don't know, likely not freely invertible. - } - // So far all users were free to invert... - } - return true; // Can freely invert all users! -} - -/// Some binary operators require special handling to avoid poison and undefined -/// behavior. If a constant vector has undef elements, replace those undefs with -/// identity constants if possible because those are always safe to execute. -/// If no identity constant exists, replace undef with some other safe constant. -static inline Constant *getSafeVectorConstantForBinop( - BinaryOperator::BinaryOps Opcode, Constant *In, bool IsRHSConstant) { - auto *InVTy = dyn_cast<VectorType>(In->getType()); - assert(InVTy && "Not expecting scalars here"); - - Type *EltTy = InVTy->getElementType(); - auto *SafeC = ConstantExpr::getBinOpIdentity(Opcode, EltTy, IsRHSConstant); - if (!SafeC) { - // TODO: Should this be available as a constant utility function? It is - // similar to getBinOpAbsorber(). 
- if (IsRHSConstant) { - switch (Opcode) { - case Instruction::SRem: // X % 1 = 0 - case Instruction::URem: // X %u 1 = 0 - SafeC = ConstantInt::get(EltTy, 1); - break; - case Instruction::FRem: // X % 1.0 (doesn't simplify, but it is safe) - SafeC = ConstantFP::get(EltTy, 1.0); - break; - default: - llvm_unreachable("Only rem opcodes have no identity constant for RHS"); - } - } else { - switch (Opcode) { - case Instruction::Shl: // 0 << X = 0 - case Instruction::LShr: // 0 >>u X = 0 - case Instruction::AShr: // 0 >> X = 0 - case Instruction::SDiv: // 0 / X = 0 - case Instruction::UDiv: // 0 /u X = 0 - case Instruction::SRem: // 0 % X = 0 - case Instruction::URem: // 0 %u X = 0 - case Instruction::Sub: // 0 - X (doesn't simplify, but it is safe) - case Instruction::FSub: // 0.0 - X (doesn't simplify, but it is safe) - case Instruction::FDiv: // 0.0 / X (doesn't simplify, but it is safe) - case Instruction::FRem: // 0.0 % X = 0 - SafeC = Constant::getNullValue(EltTy); - break; - default: - llvm_unreachable("Expected to find identity constant for opcode"); - } - } - } - assert(SafeC && "Must have safe constant for binop"); - unsigned NumElts = InVTy->getNumElements(); - SmallVector<Constant *, 16> Out(NumElts); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *C = In->getAggregateElement(i); - Out[i] = isa<UndefValue>(C) ? SafeC : C; - } - return ConstantVector::get(Out); -} - -/// The core instruction combiner logic. -/// -/// This class provides both the logic to recursively visit instructions and -/// combine them. -class LLVM_LIBRARY_VISIBILITY InstCombiner - : public InstVisitor<InstCombiner, Instruction *> { - // FIXME: These members shouldn't be public. +class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final + : public InstCombiner, + public InstVisitor<InstCombinerImpl, Instruction *> { public: - /// A worklist of the instructions that need to be simplified. - InstCombineWorklist &Worklist; - - /// An IRBuilder that automatically inserts new instructions into the - /// worklist. - using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>; - BuilderTy &Builder; - -private: - // Mode in which we are running the combiner. - const bool MinimizeSize; - - AAResults *AA; - - // Required analyses. - AssumptionCache &AC; - TargetLibraryInfo &TLI; - DominatorTree &DT; - const DataLayout &DL; - const SimplifyQuery SQ; - OptimizationRemarkEmitter &ORE; - BlockFrequencyInfo *BFI; - ProfileSummaryInfo *PSI; + InstCombinerImpl(InstCombineWorklist &Worklist, BuilderTy &Builder, + bool MinimizeSize, AAResults *AA, AssumptionCache &AC, + TargetLibraryInfo &TLI, TargetTransformInfo &TTI, + DominatorTree &DT, OptimizationRemarkEmitter &ORE, + BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, + const DataLayout &DL, LoopInfo *LI) + : InstCombiner(Worklist, Builder, MinimizeSize, AA, AC, TLI, TTI, DT, ORE, + BFI, PSI, DL, LI) {} - // Optional analyses. When non-null, these can both be used to do better - // combining and will be updated to reflect any changes. 
- LoopInfo *LI; - - bool MadeIRChange = false; - -public: - InstCombiner(InstCombineWorklist &Worklist, BuilderTy &Builder, - bool MinimizeSize, AAResults *AA, - AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, - OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, - ProfileSummaryInfo *PSI, const DataLayout &DL, LoopInfo *LI) - : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize), - AA(AA), AC(AC), TLI(TLI), DT(DT), - DL(DL), SQ(DL, &TLI, &DT, &AC), ORE(ORE), BFI(BFI), PSI(PSI), LI(LI) {} + virtual ~InstCombinerImpl() {} /// Run the combiner over the entire worklist until it is empty. /// /// \returns true if the IR is changed. bool run(); - AssumptionCache &getAssumptionCache() const { return AC; } - - const DataLayout &getDataLayout() const { return DL; } - - DominatorTree &getDominatorTree() const { return DT; } - - LoopInfo *getLoopInfo() const { return LI; } - - TargetLibraryInfo &getTargetLibraryInfo() const { return TLI; } - // Visitation implementation - Implement instruction combining for different // instruction types. The semantics are as follows: // Return Value: @@ -384,9 +97,7 @@ public: Instruction *visitSRem(BinaryOperator &I); Instruction *visitFRem(BinaryOperator &I); bool simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I); - Instruction *commonRemTransforms(BinaryOperator &I); Instruction *commonIRemTransforms(BinaryOperator &I); - Instruction *commonDivTransforms(BinaryOperator &I); Instruction *commonIDivTransforms(BinaryOperator &I); Instruction *visitUDiv(BinaryOperator &I); Instruction *visitSDiv(BinaryOperator &I); @@ -394,6 +105,7 @@ public: Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted); Instruction *visitAnd(BinaryOperator &I); Instruction *visitOr(BinaryOperator &I); + bool sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I); Instruction *visitXor(BinaryOperator &I); Instruction *visitShl(BinaryOperator &I); Value *reassociateShiftAmtsOfTwoSameDirectionShifts( @@ -407,6 +119,7 @@ public: Instruction *visitLShr(BinaryOperator &I); Instruction *commonShiftTransforms(BinaryOperator &I); Instruction *visitFCmpInst(FCmpInst &I); + CmpInst *canonicalizeICmpPredicate(CmpInst &I); Instruction *visitICmpInst(ICmpInst &I); Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I); @@ -445,6 +158,9 @@ public: Instruction *visitFenceInst(FenceInst &FI); Instruction *visitSwitchInst(SwitchInst &SI); Instruction *visitReturnInst(ReturnInst &RI); + Instruction *visitUnreachableInst(UnreachableInst &I); + Instruction * + foldAggregateConstructionIntoAggregateReuse(InsertValueInst &OrigIVI); Instruction *visitInsertValueInst(InsertValueInst &IV); Instruction *visitInsertElementInst(InsertElementInst &IE); Instruction *visitExtractElementInst(ExtractElementInst &EI); @@ -467,11 +183,6 @@ public: bool replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp, const unsigned SIOpd); - /// Try to replace instruction \p I with value \p V which are pointers - /// in different address space. - /// \return true if successful. 
- bool replacePointer(Instruction &I, Value *V); - LoadInst *combineLoadToNewType(LoadInst &LI, Type *NewTy, const Twine &Suffix = ""); @@ -609,10 +320,12 @@ private: Instruction *narrowBinOp(TruncInst &Trunc); Instruction *narrowMaskedBinOp(BinaryOperator &And); Instruction *narrowMathIfNoOverflow(BinaryOperator &I); - Instruction *narrowRotate(TruncInst &Trunc); + Instruction *narrowFunnelShift(TruncInst &Trunc); Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); Instruction *matchSAddSubSat(SelectInst &MinMax1); + void freelyInvertAllUsersOf(Value *V); + /// Determine if a pair of casts can be replaced by a single cast. /// /// \param CI1 The first of a pair of casts. @@ -653,7 +366,7 @@ public: "New instruction already inserted into a basic block!"); BasicBlock *BB = Old.getParent(); BB->getInstList().insert(Old.getIterator(), New); // Insert inst - Worklist.push(New); + Worklist.add(New); return New; } @@ -685,6 +398,7 @@ public: << " with " << *V << '\n'); I.replaceAllUsesWith(V); + MadeIRChange = true; return &I; } @@ -726,7 +440,7 @@ public: /// When dealing with an instruction that has side effects or produces a void /// value, we can't rely on DCE to delete the instruction. Instead, visit /// methods should return the value returned by this function. - Instruction *eraseInstFromFunction(Instruction &I) { + Instruction *eraseInstFromFunction(Instruction &I) override { LLVM_DEBUG(dbgs() << "IC: ERASE " << I << '\n'); assert(I.use_empty() && "Cannot erase instruction that is used!"); salvageDebugInfo(I); @@ -808,10 +522,6 @@ public: Instruction::BinaryOps BinaryOp, bool IsSigned, Value *LHS, Value *RHS, Instruction *CxtI) const; - /// Maximum size of array considered when transforming. - uint64_t MaxArraySizeForCombine = 0; - -private: /// Performs a few simplifications for operators which are associative /// or commutative. bool SimplifyAssociativeOrCommutative(BinaryOperator &I); @@ -857,7 +567,7 @@ private: unsigned Depth, Instruction *CxtI); bool SimplifyDemandedBits(Instruction *I, unsigned Op, const APInt &DemandedMask, KnownBits &Known, - unsigned Depth = 0); + unsigned Depth = 0) override; /// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne /// bits. It also tries to handle simplifications that can be done based on @@ -877,13 +587,10 @@ private: /// demanded bits. bool SimplifyDemandedInstructionBits(Instruction &Inst); - Value *simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, - APInt DemandedElts, - int DmaskIdx = -1); - - Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, - APInt &UndefElts, unsigned Depth = 0, - bool AllowMultipleUsers = false); + virtual Value * + SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, + unsigned Depth = 0, + bool AllowMultipleUsers = false) override; /// Canonicalize the position of binops relative to shufflevector. Instruction *foldVectorBinop(BinaryOperator &Inst); @@ -907,16 +614,18 @@ private: /// Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. 
- Instruction *FoldPHIArgOpIntoPHI(PHINode &PN); - Instruction *FoldPHIArgBinOpIntoPHI(PHINode &PN); - Instruction *FoldPHIArgGEPIntoPHI(PHINode &PN); - Instruction *FoldPHIArgLoadIntoPHI(PHINode &PN); - Instruction *FoldPHIArgZextsIntoPHI(PHINode &PN); + Instruction *foldPHIArgOpIntoPHI(PHINode &PN); + Instruction *foldPHIArgBinOpIntoPHI(PHINode &PN); + Instruction *foldPHIArgInsertValueInstructionIntoPHI(PHINode &PN); + Instruction *foldPHIArgExtractValueInstructionIntoPHI(PHINode &PN); + Instruction *foldPHIArgGEPIntoPHI(PHINode &PN); + Instruction *foldPHIArgLoadIntoPHI(PHINode &PN); + Instruction *foldPHIArgZextsIntoPHI(PHINode &PN); /// If an integer typed PHI has only one use which is an IntToPtr operation, /// replace the PHI with an existing pointer typed PHI if it exists. Otherwise /// insert a new pointer typed PHI and replace the original one. - Instruction *FoldIntegerTypedPHI(PHINode &PN); + Instruction *foldIntegerTypedPHI(PHINode &PN); /// Helper function for FoldPHIArgXIntoPHI() to set debug location for the /// folded operation. @@ -999,18 +708,18 @@ private: Value *A, Value *B, Instruction &Outer, SelectPatternFlavor SPF2, Value *C); Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); - - Instruction *OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS, - ConstantInt *AndRHS, BinaryOperator &TheAnd); + Instruction *foldSelectValueEquivalence(SelectInst &SI, ICmpInst &ICI); Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, bool isSigned, bool Inside); Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI); bool mergeStoreIntoSuccessor(StoreInst &SI); - /// Given an 'or' instruction, check to see if it is part of a bswap idiom. - /// If so, return the equivalent bswap intrinsic. - Instruction *matchBSwap(BinaryOperator &Or); + /// Given an 'or' instruction, check to see if it is part of a + /// bswap/bitreverse idiom. If so, return the equivalent bswap/bitreverse + /// intrinsic. + Instruction *matchBSwapOrBitReverse(BinaryOperator &Or, bool MatchBSwaps, + bool MatchBitReversals); Instruction *SimplifyAnyMemTransfer(AnyMemTransferInst *MI); Instruction *SimplifyAnyMemSet(AnyMemSetInst *MI); @@ -1023,18 +732,6 @@ private: Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap); }; -namespace { - -// As a default, let's assume that we want to be aggressive, -// and attempt to traverse with no limits in attempt to sink negation. -static constexpr unsigned NegatorDefaultMaxDepth = ~0U; - -// Let's guesstimate that most often we will end up visiting/producing -// fairly small number of new instructions. -static constexpr unsigned NegatorMaxNodesSSO = 16; - -} // namespace - class Negator final { /// Top-to-bottom, def-to-use negated instruction tree we produced. SmallVector<Instruction *, NegatorMaxNodesSSO> NewInstructions; @@ -1061,6 +758,8 @@ class Negator final { using Result = std::pair<ArrayRef<Instruction *> /*NewInstructions*/, Value * /*NegatedRoot*/>; + std::array<Value *, 2> getSortedOperandsOfBinOp(Instruction *I); + LLVM_NODISCARD Value *visitImpl(Value *V, unsigned Depth); LLVM_NODISCARD Value *negate(Value *V, unsigned Depth); @@ -1078,7 +777,7 @@ public: /// Attempt to negate \p Root. Retuns nullptr if negation can't be performed, /// otherwise returns negated value. 
LLVM_NODISCARD static Value *Negate(bool LHSIsZero, Value *Root, - InstCombiner &IC); + InstCombinerImpl &IC); }; } // end namespace llvm diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index dad2f23120bd..c7b5f6f78069 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -23,6 +23,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -166,7 +167,8 @@ static bool isDereferenceableForAllocaSize(const Value *V, const AllocaInst *AI, APInt(64, AllocaSize), DL); } -static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { +static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, + AllocaInst &AI) { // Check for array size of 1 (scalar allocation). if (!AI.isArrayAllocation()) { // i32 1 is the canonical array size for scalar allocations. @@ -234,47 +236,45 @@ namespace { // instruction. class PointerReplacer { public: - PointerReplacer(InstCombiner &IC) : IC(IC) {} + PointerReplacer(InstCombinerImpl &IC) : IC(IC) {} + + bool collectUsers(Instruction &I); void replacePointer(Instruction &I, Value *V); private: - void findLoadAndReplace(Instruction &I); void replace(Instruction *I); Value *getReplacement(Value *I); - SmallVector<Instruction *, 4> Path; + SmallSetVector<Instruction *, 4> Worklist; MapVector<Value *, Value *> WorkMap; - InstCombiner &IC; + InstCombinerImpl &IC; }; } // end anonymous namespace -void PointerReplacer::findLoadAndReplace(Instruction &I) { +bool PointerReplacer::collectUsers(Instruction &I) { for (auto U : I.users()) { - auto *Inst = dyn_cast<Instruction>(&*U); - if (!Inst) - return; - LLVM_DEBUG(dbgs() << "Found pointer user: " << *U << '\n'); - if (isa<LoadInst>(Inst)) { - for (auto P : Path) - replace(P); - replace(Inst); + Instruction *Inst = cast<Instruction>(&*U); + if (LoadInst *Load = dyn_cast<LoadInst>(Inst)) { + if (Load->isVolatile()) + return false; + Worklist.insert(Load); } else if (isa<GetElementPtrInst>(Inst) || isa<BitCastInst>(Inst)) { - Path.push_back(Inst); - findLoadAndReplace(*Inst); - Path.pop_back(); + Worklist.insert(Inst); + if (!collectUsers(*Inst)) + return false; + } else if (isa<MemTransferInst>(Inst)) { + Worklist.insert(Inst); } else { - return; + LLVM_DEBUG(dbgs() << "Cannot handle pointer user: " << *U << '\n'); + return false; } } -} -Value *PointerReplacer::getReplacement(Value *V) { - auto Loc = WorkMap.find(V); - if (Loc != WorkMap.end()) - return Loc->second; - return nullptr; + return true; } +Value *PointerReplacer::getReplacement(Value *V) { return WorkMap.lookup(V); } + void PointerReplacer::replace(Instruction *I) { if (getReplacement(I)) return; @@ -282,9 +282,12 @@ void PointerReplacer::replace(Instruction *I) { if (auto *LT = dyn_cast<LoadInst>(I)) { auto *V = getReplacement(LT->getPointerOperand()); assert(V && "Operand not replaced"); - auto *NewI = new LoadInst(I->getType(), V, "", false, - IC.getDataLayout().getABITypeAlign(I->getType())); + auto *NewI = new LoadInst(LT->getType(), V, "", LT->isVolatile(), + LT->getAlign(), LT->getOrdering(), + LT->getSyncScopeID()); NewI->takeName(LT); + copyMetadataForLoad(*NewI, *LT); + IC.InsertNewInstWith(NewI, *LT); IC.replaceInstUsesWith(*LT, 
NewI); WorkMap[LT] = NewI; @@ -307,6 +310,28 @@ void PointerReplacer::replace(Instruction *I) { IC.InsertNewInstWith(NewI, *BC); NewI->takeName(BC); WorkMap[BC] = NewI; + } else if (auto *MemCpy = dyn_cast<MemTransferInst>(I)) { + auto *SrcV = getReplacement(MemCpy->getRawSource()); + // The pointer may appear in the destination of a copy, but we don't want to + // replace it. + if (!SrcV) { + assert(getReplacement(MemCpy->getRawDest()) && + "destination not in replace list"); + return; + } + + IC.Builder.SetInsertPoint(MemCpy); + auto *NewI = IC.Builder.CreateMemTransferInst( + MemCpy->getIntrinsicID(), MemCpy->getRawDest(), MemCpy->getDestAlign(), + SrcV, MemCpy->getSourceAlign(), MemCpy->getLength(), + MemCpy->isVolatile()); + AAMDNodes AAMD; + MemCpy->getAAMetadata(AAMD); + if (AAMD) + NewI->setAAMetadata(AAMD); + + IC.eraseInstFromFunction(*MemCpy); + WorkMap[MemCpy] = NewI; } else { llvm_unreachable("should never reach here"); } @@ -320,10 +345,12 @@ void PointerReplacer::replacePointer(Instruction &I, Value *V) { "Invalid usage"); #endif WorkMap[&I] = V; - findLoadAndReplace(I); + + for (Instruction *Workitem : Worklist) + replace(Workitem); } -Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { +Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { if (auto *I = simplifyAllocaArraySize(*this, AI)) return I; @@ -374,23 +401,21 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { // read. SmallVector<Instruction *, 4> ToDelete; if (MemTransferInst *Copy = isOnlyCopiedFromConstantMemory(AA, &AI, ToDelete)) { + Value *TheSrc = Copy->getSource(); Align AllocaAlign = AI.getAlign(); Align SourceAlign = getOrEnforceKnownAlignment( - Copy->getSource(), AllocaAlign, DL, &AI, &AC, &DT); + TheSrc, AllocaAlign, DL, &AI, &AC, &DT); if (AllocaAlign <= SourceAlign && - isDereferenceableForAllocaSize(Copy->getSource(), &AI, DL)) { + isDereferenceableForAllocaSize(TheSrc, &AI, DL)) { LLVM_DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); - for (unsigned i = 0, e = ToDelete.size(); i != e; ++i) - eraseInstFromFunction(*ToDelete[i]); - Value *TheSrc = Copy->getSource(); - auto *SrcTy = TheSrc->getType(); - auto *DestTy = PointerType::get(AI.getType()->getPointerElementType(), - SrcTy->getPointerAddressSpace()); - Value *Cast = - Builder.CreatePointerBitCastOrAddrSpaceCast(TheSrc, DestTy); - if (AI.getType()->getPointerAddressSpace() == - SrcTy->getPointerAddressSpace()) { + unsigned SrcAddrSpace = TheSrc->getType()->getPointerAddressSpace(); + auto *DestTy = PointerType::get(AI.getAllocatedType(), SrcAddrSpace); + if (AI.getType()->getAddressSpace() == SrcAddrSpace) { + for (Instruction *Delete : ToDelete) + eraseInstFromFunction(*Delete); + + Value *Cast = Builder.CreateBitCast(TheSrc, DestTy); Instruction *NewI = replaceInstUsesWith(AI, Cast); eraseInstFromFunction(*Copy); ++NumGlobalCopies; @@ -398,8 +423,14 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { } PointerReplacer PtrReplacer(*this); - PtrReplacer.replacePointer(AI, Cast); - ++NumGlobalCopies; + if (PtrReplacer.collectUsers(AI)) { + for (Instruction *Delete : ToDelete) + eraseInstFromFunction(*Delete); + + Value *Cast = Builder.CreateBitCast(TheSrc, DestTy); + PtrReplacer.replacePointer(AI, Cast); + ++NumGlobalCopies; + } } } @@ -421,9 +452,9 @@ static bool isSupportedAtomicType(Type *Ty) { /// that pointer type, load it, etc. 
/// /// Note that this will create all of the instructions with whatever insert -/// point the \c InstCombiner currently is using. -LoadInst *InstCombiner::combineLoadToNewType(LoadInst &LI, Type *NewTy, - const Twine &Suffix) { +/// point the \c InstCombinerImpl currently is using. +LoadInst *InstCombinerImpl::combineLoadToNewType(LoadInst &LI, Type *NewTy, + const Twine &Suffix) { assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) && "can't fold an atomic load to requested type"); @@ -445,7 +476,8 @@ LoadInst *InstCombiner::combineLoadToNewType(LoadInst &LI, Type *NewTy, /// Combine a store to a new type. /// /// Returns the newly created store instruction. -static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value *V) { +static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI, + Value *V) { assert((!SI.isAtomic() || isSupportedAtomicType(V->getType())) && "can't fold an atomic store of requested type"); @@ -485,6 +517,7 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value break; case LLVMContext::MD_invariant_load: case LLVMContext::MD_nonnull: + case LLVMContext::MD_noundef: case LLVMContext::MD_range: case LLVMContext::MD_align: case LLVMContext::MD_dereferenceable: @@ -502,7 +535,7 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) { assert(V->getType()->isPointerTy() && "Expected pointer type."); // Ignore possible ty* to ixx* bitcast. - V = peekThroughBitcast(V); + V = InstCombiner::peekThroughBitcast(V); // Check that select is select ((cmp load V1, load V2), V1, V2) - minmax // pattern. CmpInst::Predicate Pred; @@ -537,7 +570,8 @@ static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) { /// or a volatile load. This is debatable, and might be reasonable to change /// later. However, it is risky in case some backend or other part of LLVM is /// relying on the exact type loaded to select appropriate atomic operations. -static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { +static Instruction *combineLoadToOperationType(InstCombinerImpl &IC, + LoadInst &LI) { // FIXME: We could probably with some care handle both volatile and ordered // atomic loads here but it isn't clear that this is important. if (!LI.isUnordered()) @@ -550,62 +584,38 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { if (LI.getPointerOperand()->isSwiftError()) return nullptr; - Type *Ty = LI.getType(); const DataLayout &DL = IC.getDataLayout(); - // Try to canonicalize loads which are only ever stored to operate over - // integers instead of any other type. We only do this when the loaded type - // is sized and has a size exactly the same as its store size and the store - // size is a legal integer type. - // Do not perform canonicalization if minmax pattern is found (to avoid - // infinite loop). 
- Type *Dummy; - if (!Ty->isIntegerTy() && Ty->isSized() && !isa<ScalableVectorType>(Ty) && - DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) && - DL.typeSizeEqualsStoreSize(Ty) && !DL.isNonIntegralPointerType(Ty) && - !isMinMaxWithLoads( - peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true), - Dummy)) { - if (all_of(LI.users(), [&LI](User *U) { - auto *SI = dyn_cast<StoreInst>(U); - return SI && SI->getPointerOperand() != &LI && - !SI->getPointerOperand()->isSwiftError(); - })) { - LoadInst *NewLoad = IC.combineLoadToNewType( - LI, Type::getIntNTy(LI.getContext(), DL.getTypeStoreSizeInBits(Ty))); - // Replace all the stores with stores of the newly loaded value. - for (auto UI = LI.user_begin(), UE = LI.user_end(); UI != UE;) { - auto *SI = cast<StoreInst>(*UI++); - IC.Builder.SetInsertPoint(SI); - combineStoreToNewValue(IC, *SI, NewLoad); - IC.eraseInstFromFunction(*SI); - } - assert(LI.use_empty() && "Failed to remove all users of the load!"); - // Return the old load so the combiner can delete it safely. - return &LI; + // Fold away bit casts of the loaded value by loading the desired type. + // Note that we should not do this for pointer<->integer casts, + // because that would result in type punning. + if (LI.hasOneUse()) { + // Don't transform when the type is x86_amx, it makes the pass that lower + // x86_amx type happy. + if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) { + assert(!LI.getType()->isX86_AMXTy() && + "load from x86_amx* should not happen!"); + if (BC->getType()->isX86_AMXTy()) + return nullptr; } - } - // Fold away bit casts of the loaded value by loading the desired type. - // We can do this for BitCastInsts as well as casts from and to pointer types, - // as long as those are noops (i.e., the source or dest type have the same - // bitwidth as the target's pointers). - if (LI.hasOneUse()) if (auto* CI = dyn_cast<CastInst>(LI.user_back())) - if (CI->isNoopCast(DL)) + if (CI->isNoopCast(DL) && LI.getType()->isPtrOrPtrVectorTy() == + CI->getDestTy()->isPtrOrPtrVectorTy()) if (!LI.isAtomic() || isSupportedAtomicType(CI->getDestTy())) { LoadInst *NewLoad = IC.combineLoadToNewType(LI, CI->getDestTy()); CI->replaceAllUsesWith(NewLoad); IC.eraseInstFromFunction(*CI); return &LI; } + } // FIXME: We should also canonicalize loads of vectors when their elements are // cast to other types. return nullptr; } -static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { +static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { // FIXME: We could probably with some care handle both volatile and atomic // stores here but it isn't clear that this is important. if (!LI.isSimple()) @@ -743,8 +753,7 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize, } if (PHINode *PN = dyn_cast<PHINode>(P)) { - for (Value *IncValue : PN->incoming_values()) - Worklist.push_back(IncValue); + append_range(Worklist, PN->incoming_values()); continue; } @@ -804,8 +813,9 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize, // not zero. Currently, we only handle the first such index. Also, we could // also search through non-zero constant indices if we kept track of the // offsets those indices implied. 
-static bool canReplaceGEPIdxWithZero(InstCombiner &IC, GetElementPtrInst *GEPI, - Instruction *MemI, unsigned &Idx) { +static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC, + GetElementPtrInst *GEPI, Instruction *MemI, + unsigned &Idx) { if (GEPI->getNumOperands() < 2) return false; @@ -834,12 +844,17 @@ static bool canReplaceGEPIdxWithZero(InstCombiner &IC, GetElementPtrInst *GEPI, return false; SmallVector<Value *, 4> Ops(GEPI->idx_begin(), GEPI->idx_begin() + Idx); - Type *AllocTy = - GetElementPtrInst::getIndexedType(GEPI->getSourceElementType(), Ops); + Type *SourceElementType = GEPI->getSourceElementType(); + // Size information about scalable vectors is not available, so we cannot + // deduce whether indexing at n is undefined behaviour or not. Bail out. + if (isa<ScalableVectorType>(SourceElementType)) + return false; + + Type *AllocTy = GetElementPtrInst::getIndexedType(SourceElementType, Ops); if (!AllocTy || !AllocTy->isSized()) return false; const DataLayout &DL = IC.getDataLayout(); - uint64_t TyAllocSize = DL.getTypeAllocSize(AllocTy); + uint64_t TyAllocSize = DL.getTypeAllocSize(AllocTy).getFixedSize(); // If there are more indices after the one we might replace with a zero, make // sure they're all non-negative. If any of them are negative, the overall @@ -874,7 +889,7 @@ static bool canReplaceGEPIdxWithZero(InstCombiner &IC, GetElementPtrInst *GEPI, // access, but the object has only one element, we can assume that the index // will always be zero. If we replace the GEP, return it. template <typename T> -static Instruction *replaceGEPIdxWithZero(InstCombiner &IC, Value *Ptr, +static Instruction *replaceGEPIdxWithZero(InstCombinerImpl &IC, Value *Ptr, T &MemI) { if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Ptr)) { unsigned Idx; @@ -916,7 +931,7 @@ static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) { return false; } -Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { +Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) { Value *Op = LI.getOperand(0); // Try to canonicalize the loaded type. @@ -1033,7 +1048,7 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { /// and the layout of a <2 x double> is isomorphic to a [2 x double], /// then %V1 can be safely approximated by a conceptual "bitcast" of %U. /// Note that %U may contain non-undef values where %V1 has undef. -static Value *likeBitCastFromVector(InstCombiner &IC, Value *V) { +static Value *likeBitCastFromVector(InstCombinerImpl &IC, Value *V) { Value *U = nullptr; while (auto *IV = dyn_cast<InsertValueInst>(V)) { auto *E = dyn_cast<ExtractElementInst>(IV->getInsertedValueOperand()); @@ -1060,11 +1075,11 @@ static Value *likeBitCastFromVector(InstCombiner &IC, Value *V) { return nullptr; } if (auto *AT = dyn_cast<ArrayType>(VT)) { - if (AT->getNumElements() != UT->getNumElements()) + if (AT->getNumElements() != cast<FixedVectorType>(UT)->getNumElements()) return nullptr; } else { auto *ST = cast<StructType>(VT); - if (ST->getNumElements() != UT->getNumElements()) + if (ST->getNumElements() != cast<FixedVectorType>(UT)->getNumElements()) return nullptr; for (const auto *EltT : ST->elements()) { if (EltT != UT->getElementType()) @@ -1094,7 +1109,7 @@ static Value *likeBitCastFromVector(InstCombiner &IC, Value *V) { /// the caller must erase the store instruction. We have to let the caller erase /// the store instruction as otherwise there is no way to signal whether it was /// combined or not: IC.EraseInstFromFunction returns a null pointer. 
-static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) { +static bool combineStoreToValueType(InstCombinerImpl &IC, StoreInst &SI) { // FIXME: We could probably with some care handle both volatile and ordered // atomic stores here but it isn't clear that this is important. if (!SI.isUnordered()) @@ -1108,7 +1123,13 @@ static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) { // Fold away bit casts of the stored value by storing the original type. if (auto *BC = dyn_cast<BitCastInst>(V)) { + assert(!BC->getType()->isX86_AMXTy() && + "store to x86_amx* should not happen!"); V = BC->getOperand(0); + // Don't transform when the type is x86_amx, it makes the pass that lower + // x86_amx type happy. + if (V->getType()->isX86_AMXTy()) + return false; if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) { combineStoreToNewValue(IC, SI, V); return true; @@ -1126,7 +1147,7 @@ static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) { return false; } -static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { +static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { // FIXME: We could probably with some care handle both volatile and atomic // stores here but it isn't clear that this is important. if (!SI.isSimple()) @@ -1266,7 +1287,7 @@ static bool equivalentAddressValues(Value *A, Value *B) { /// Converts store (bitcast (load (bitcast (select ...)))) to /// store (load (select ...)), where select is minmax: /// select ((cmp load V1, load V2), V1, V2). -static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, +static bool removeBitcastsFromLoadStoreOnMinMax(InstCombinerImpl &IC, StoreInst &SI) { // bitcast? if (!match(SI.getPointerOperand(), m_BitCast(m_Value()))) @@ -1296,7 +1317,8 @@ static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, if (!all_of(LI->users(), [LI, LoadAddr](User *U) { auto *SI = dyn_cast<StoreInst>(U); return SI && SI->getPointerOperand() != LI && - peekThroughBitcast(SI->getPointerOperand()) != LoadAddr && + InstCombiner::peekThroughBitcast(SI->getPointerOperand()) != + LoadAddr && !SI->getPointerOperand()->isSwiftError(); })) return false; @@ -1314,7 +1336,7 @@ static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, return true; } -Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { +Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { Value *Val = SI.getOperand(0); Value *Ptr = SI.getOperand(1); @@ -1433,7 +1455,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { /// or: /// *P = v1; if () { *P = v2; } /// into a phi node with a store in the successor. -bool InstCombiner::mergeStoreIntoSuccessor(StoreInst &SI) { +bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { if (!SI.isUnordered()) return false; // This code has not been audited for volatile/ordered case. 
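For illustration, a minimal sketch of the kind of IR the reworked visitAllocaInst / PointerReplacer path above is meant to handle: an alloca whose only initialization is a memcpy from constant memory. The names (@lut, %buf, @lookup) and the exact types and alignments here are hypothetical and simplified, not taken from the patch:

  @lut = private unnamed_addr constant [4 x i32] [i32 1, i32 2, i32 4, i32 8]

  declare void @llvm.memcpy.p0i8.p0i8.i64(i8*, i8*, i64, i1)

  define i32 @lookup(i64 %i) {
    %buf = alloca [4 x i32], align 4
    %buf.i8 = bitcast [4 x i32]* %buf to i8*
    ; the alloca's only write is a copy from constant memory
    call void @llvm.memcpy.p0i8.p0i8.i64(i8* %buf.i8, i8* bitcast ([4 x i32]* @lut to i8*), i64 16, i1 false)
    %p = getelementptr inbounds [4 x i32], [4 x i32]* %buf, i64 0, i64 %i
    %v = load i32, i32* %p
    ret i32 %v
  }

When the source is sufficiently aligned and dereferenceable for the alloca's size and the address spaces match, the code above redirects the users to a bitcast of @lut and erases the alloca together with the memcpy; in the other case, the extended PointerReplacer first collects the GEP, bitcast, load and memcpy users and only rewrites them if all of them are supported.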
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index c6233a68847d..4b485a0ad85e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -32,6 +32,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include <cassert> #include <cstddef> @@ -46,7 +47,7 @@ using namespace PatternMatch; /// The specific integer value is used in a context where it is known to be /// non-zero. If this allows us to simplify the computation, do so and return /// the new operand, otherwise return null. -static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, +static Value *simplifyValueKnownNonZero(Value *V, InstCombinerImpl &IC, Instruction &CxtI) { // If V has multiple uses, then we would have to do more analysis to determine // if this is safe. For example, the use could be in dynamically unreached @@ -94,39 +95,6 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, return MadeChange ? V : nullptr; } -/// A helper routine of InstCombiner::visitMul(). -/// -/// If C is a scalar/fixed width vector of known powers of 2, then this -/// function returns a new scalar/fixed width vector obtained from logBase2 -/// of C. -/// Return a null pointer otherwise. -static Constant *getLogBase2(Type *Ty, Constant *C) { - const APInt *IVal; - if (match(C, m_APInt(IVal)) && IVal->isPowerOf2()) - return ConstantInt::get(Ty, IVal->logBase2()); - - // FIXME: We can extract pow of 2 of splat constant for scalable vectors. - if (!isa<FixedVectorType>(Ty)) - return nullptr; - - SmallVector<Constant *, 4> Elts; - for (unsigned I = 0, E = cast<FixedVectorType>(Ty)->getNumElements(); I != E; - ++I) { - Constant *Elt = C->getAggregateElement(I); - if (!Elt) - return nullptr; - if (isa<UndefValue>(Elt)) { - Elts.push_back(UndefValue::get(Ty->getScalarType())); - continue; - } - if (!match(Elt, m_APInt(IVal)) || !IVal->isPowerOf2()) - return nullptr; - Elts.push_back(ConstantInt::get(Ty->getScalarType(), IVal->logBase2())); - } - - return ConstantVector::get(Elts); -} - // TODO: This is a specific form of a much more general pattern. // We could detect a select with any binop identity constant, or we // could use SimplifyBinOp to see if either arm of the select reduces. @@ -171,7 +139,7 @@ static Value *foldMulSelectToNegate(BinaryOperator &I, return nullptr; } -Instruction *InstCombiner::visitMul(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Value *V = SimplifyMulInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -185,8 +153,10 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); - // X * -1 == 0 - X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + unsigned BitWidth = I.getType()->getScalarSizeInBits(); + + // X * -1 == 0 - X if (match(Op1, m_AllOnes())) { BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName()); if (I.hasNoSignedWrap()) @@ -216,7 +186,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) { // Replace X*(2^C) with X << C, where C is either a scalar or a vector. 
- if (Constant *NewCst = getLogBase2(NewOp->getType(), C1)) { + if (Constant *NewCst = ConstantExpr::getExactLogBase2(C1)) { BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst); if (I.hasNoUnsignedWrap()) @@ -232,29 +202,12 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - // (Y - X) * (-(2**n)) -> (X - Y) * (2**n), for positive nonzero n - // (Y + const) * (-(2**n)) -> (-constY) * (2**n), for positive nonzero n - // The "* (2**n)" thus becomes a potential shifting opportunity. - { - const APInt & Val = CI->getValue(); - const APInt &PosVal = Val.abs(); - if (Val.isNegative() && PosVal.isPowerOf2()) { - Value *X = nullptr, *Y = nullptr; - if (Op0->hasOneUse()) { - ConstantInt *C1; - Value *Sub = nullptr; - if (match(Op0, m_Sub(m_Value(Y), m_Value(X)))) - Sub = Builder.CreateSub(X, Y, "suba"); - else if (match(Op0, m_Add(m_Value(Y), m_ConstantInt(C1)))) - Sub = Builder.CreateSub(Builder.CreateNeg(C1), Y, "subc"); - if (Sub) - return - BinaryOperator::CreateMul(Sub, - ConstantInt::get(Y->getType(), PosVal)); - } - } - } + if (Op0->hasOneUse() && match(Op1, m_NegatedPower2())) { + // Interpret X * (-1<<C) as (-X) * (1<<C) and try to sink the negation. + // The "* (1<<C)" thus becomes a potential shifting opportunity. + if (Value *NegOp0 = Negator::Negate(/*IsNegation*/ true, Op0, *this)) + return BinaryOperator::CreateMul( + NegOp0, ConstantExpr::getNeg(cast<Constant>(Op1)), I.getName()); } if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) @@ -284,6 +237,9 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor; if (SPF == SPF_ABS || SPF == SPF_NABS) return BinaryOperator::CreateMul(X, X); + + if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X)))) + return BinaryOperator::CreateMul(X, X); } // -X * C --> X * -C @@ -406,6 +362,19 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (match(Op1, m_LShr(m_Value(X), m_APInt(C))) && *C == C->getBitWidth() - 1) return BinaryOperator::CreateAnd(Builder.CreateAShr(X, *C), Op0); + // ((ashr X, 31) | 1) * X --> abs(X) + // X * ((ashr X, 31) | 1) --> abs(X) + if (match(&I, m_c_BinOp(m_Or(m_AShr(m_Value(X), + m_SpecificIntAllowUndef(BitWidth - 1)), + m_One()), + m_Deferred(X)))) { + Value *Abs = Builder.CreateBinaryIntrinsic( + Intrinsic::abs, X, + ConstantInt::getBool(I.getContext(), I.hasNoSignedWrap())); + Abs->takeName(&I); + return replaceInstUsesWith(I, Abs); + } + if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; @@ -423,7 +392,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { return Changed ? 
&I : nullptr; } -Instruction *InstCombiner::foldFPSignBitOps(BinaryOperator &I) { +Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) { BinaryOperator::BinaryOps Opcode = I.getOpcode(); assert((Opcode == Instruction::FMul || Opcode == Instruction::FDiv) && "Expected fmul or fdiv"); @@ -438,13 +407,12 @@ Instruction *InstCombiner::foldFPSignBitOps(BinaryOperator &I) { // fabs(X) * fabs(X) -> X * X // fabs(X) / fabs(X) -> X / X - if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X)))) + if (Op0 == Op1 && match(Op0, m_FAbs(m_Value(X)))) return BinaryOperator::CreateWithCopiedFlags(Opcode, X, X, &I); // fabs(X) * fabs(Y) --> fabs(X * Y) // fabs(X) / fabs(Y) --> fabs(X / Y) - if (match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))) && - match(Op1, m_Intrinsic<Intrinsic::fabs>(m_Value(Y))) && + if (match(Op0, m_FAbs(m_Value(X))) && match(Op1, m_FAbs(m_Value(Y))) && (Op0->hasOneUse() || Op1->hasOneUse())) { IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); Builder.setFastMathFlags(I.getFastMathFlags()); @@ -457,7 +425,7 @@ Instruction *InstCombiner::foldFPSignBitOps(BinaryOperator &I) { return nullptr; } -Instruction *InstCombiner::visitFMul(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { if (Value *V = SimplifyFMulInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), SQ.getWithInstruction(&I))) @@ -553,6 +521,21 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { return replaceInstUsesWith(I, Sqrt); } + // The following transforms are done irrespective of the number of uses + // for the expression "1.0/sqrt(X)". + // 1) 1.0/sqrt(X) * X -> X/sqrt(X) + // 2) X * 1.0/sqrt(X) -> X/sqrt(X) + // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it + // has the necessary (reassoc) fast-math-flags. + if (I.hasNoSignedZeros() && + match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && + match(Y, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && Op1 == X) + return BinaryOperator::CreateFDivFMF(X, Y, &I); + if (I.hasNoSignedZeros() && + match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && + match(Y, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && Op0 == X) + return BinaryOperator::CreateFDivFMF(X, Y, &I); + // Like the similar transform in instsimplify, this requires 'nsz' because // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0. if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && @@ -637,7 +620,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { /// Fold a divide or remainder with a select instruction divisor when one of the /// select operands is zero. In that case, we can use the other select operand /// because div/rem by zero is undefined. -bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { +bool InstCombinerImpl::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { SelectInst *SI = dyn_cast<SelectInst>(I.getOperand(1)); if (!SI) return false; @@ -738,7 +721,7 @@ static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, /// instructions (udiv and sdiv). It is called by the visitors to those integer /// division instructions. 
/// Common integer divide transforms -Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { +Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); bool IsSigned = I.getOpcode() == Instruction::SDiv; Type *Ty = I.getType(); @@ -874,7 +857,7 @@ namespace { using FoldUDivOperandCb = Instruction *(*)(Value *Op0, Value *Op1, const BinaryOperator &I, - InstCombiner &IC); + InstCombinerImpl &IC); /// Used to maintain state for visitUDivOperand(). struct UDivFoldAction { @@ -903,8 +886,9 @@ struct UDivFoldAction { // X udiv 2^C -> X >> C static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, - const BinaryOperator &I, InstCombiner &IC) { - Constant *C1 = getLogBase2(Op0->getType(), cast<Constant>(Op1)); + const BinaryOperator &I, + InstCombinerImpl &IC) { + Constant *C1 = ConstantExpr::getExactLogBase2(cast<Constant>(Op1)); if (!C1) llvm_unreachable("Failed to constant fold udiv -> logbase2"); BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, C1); @@ -916,7 +900,7 @@ static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) // X udiv (zext (C1 << N)), where C1 is "1<<C2" --> X >> (N+C2) static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, - InstCombiner &IC) { + InstCombinerImpl &IC) { Value *ShiftLeft; if (!match(Op1, m_ZExt(m_Value(ShiftLeft)))) ShiftLeft = Op1; @@ -925,7 +909,7 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, Value *N; if (!match(ShiftLeft, m_Shl(m_Constant(CI), m_Value(N)))) llvm_unreachable("match should never fail here!"); - Constant *Log2Base = getLogBase2(N->getType(), CI); + Constant *Log2Base = ConstantExpr::getExactLogBase2(CI); if (!Log2Base) llvm_unreachable("getLogBase2 should never fail here!"); N = IC.Builder.CreateAdd(N, Log2Base); @@ -944,6 +928,8 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I, SmallVectorImpl<UDivFoldAction> &Actions, unsigned Depth = 0) { + // FIXME: assert that Op1 isn't/doesn't contain undef. + // Check to see if this is an unsigned division with an exact power of 2, // if so, convert to a right shift. if (match(Op1, m_Power2())) { @@ -963,6 +949,9 @@ static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I, return 0; if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) + // FIXME: missed optimization: if one of the hands of select is/contains + // undef, just directly pick the other one. + // FIXME: can both hands contain undef? 
if (size_t LHSIdx = visitUDivOperand(Op0, SI->getOperand(1), I, Actions, Depth)) if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions, Depth)) { @@ -1010,7 +999,7 @@ static Instruction *narrowUDivURem(BinaryOperator &I, return nullptr; } -Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { if (Value *V = SimplifyUDivInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1104,7 +1093,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { return nullptr; } -Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { if (Value *V = SimplifySDivInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1117,6 +1106,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return Common; Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); Value *X; // sdiv Op0, -1 --> -Op0 // sdiv Op0, (sext i1 X) --> -Op0 (because if X is 0, the op is undefined) @@ -1126,16 +1116,26 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { // X / INT_MIN --> X == INT_MIN if (match(Op1, m_SignMask())) - return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), I.getType()); + return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), Ty); + + // sdiv exact X, 1<<C --> ashr exact X, C iff 1<<C is non-negative + // sdiv exact X, -1<<C --> -(ashr exact X, C) + if (I.isExact() && ((match(Op1, m_Power2()) && match(Op1, m_NonNegative())) || + match(Op1, m_NegatedPower2()))) { + bool DivisorWasNegative = match(Op1, m_NegatedPower2()); + if (DivisorWasNegative) + Op1 = ConstantExpr::getNeg(cast<Constant>(Op1)); + auto *AShr = BinaryOperator::CreateExactAShr( + Op0, ConstantExpr::getExactLogBase2(cast<Constant>(Op1)), I.getName()); + if (!DivisorWasNegative) + return AShr; + Builder.Insert(AShr); + AShr->setName(I.getName() + ".neg"); + return BinaryOperator::CreateNeg(AShr, I.getName()); + } const APInt *Op1C; if (match(Op1, m_APInt(Op1C))) { - // sdiv exact X, C --> ashr exact X, log2(C) - if (I.isExact() && Op1C->isNonNegative() && Op1C->isPowerOf2()) { - Value *ShAmt = ConstantInt::get(Op1->getType(), Op1C->exactLogBase2()); - return BinaryOperator::CreateExactAShr(Op0, ShAmt, I.getName()); - } - // If the dividend is sign-extended and the constant divisor is small enough // to fit in the source type, shrink the division to the narrower type: // (sext X) sdiv C --> sext (X sdiv C) @@ -1150,7 +1150,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { Constant *NarrowDivisor = ConstantExpr::getTrunc(cast<Constant>(Op1), Op0Src->getType()); Value *NarrowOp = Builder.CreateSDiv(Op0Src, NarrowDivisor); - return new SExtInst(NarrowOp, Op0->getType()); + return new SExtInst(NarrowOp, Ty); } // -X / C --> X / -C (if the negation doesn't overflow). @@ -1158,7 +1158,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { // checking if all elements are not the min-signed-val. if (!Op1C->isMinSignedValue() && match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) { - Constant *NegC = ConstantInt::get(I.getType(), -(*Op1C)); + Constant *NegC = ConstantInt::get(Ty, -(*Op1C)); Instruction *BO = BinaryOperator::CreateSDiv(X, NegC); BO->setIsExact(I.isExact()); return BO; @@ -1171,9 +1171,19 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return BinaryOperator::CreateNSWNeg( Builder.CreateSDiv(X, Y, I.getName(), I.isExact())); + // abs(X) / X --> X > -1 ? 
1 : -1 + // X / abs(X) --> X > -1 ? 1 : -1 + if (match(&I, m_c_BinOp( + m_OneUse(m_Intrinsic<Intrinsic::abs>(m_Value(X), m_One())), + m_Deferred(X)))) { + Constant *NegOne = ConstantInt::getAllOnesValue(Ty); + Value *Cond = Builder.CreateICmpSGT(X, NegOne); + return SelectInst::Create(Cond, ConstantInt::get(Ty, 1), NegOne); + } + // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a udiv. - APInt Mask(APInt::getSignMask(I.getType()->getScalarSizeInBits())); + APInt Mask(APInt::getSignMask(Ty->getScalarSizeInBits())); if (MaskedValueIsZero(Op0, Mask, 0, &I)) { if (MaskedValueIsZero(Op1, Mask, 0, &I)) { // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set @@ -1182,6 +1192,13 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return BO; } + if (match(Op1, m_NegatedPower2())) { + // X sdiv (-(1 << C)) -> -(X sdiv (1 << C)) -> + // -> -(X udiv (1 << C)) -> -(X u>> C) + return BinaryOperator::CreateNeg(Builder.Insert(foldUDivPow2Cst( + Op0, ConstantExpr::getNeg(cast<Constant>(Op1)), I, *this))); + } + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) { // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y) // Safe because the only negative value (1 << Y) can take on is @@ -1258,7 +1275,7 @@ static Instruction *foldFDivConstantDividend(BinaryOperator &I) { return BinaryOperator::CreateFDivFMF(NewC, X, &I); } -Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { if (Value *V = SimplifyFDivInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), SQ.getWithInstruction(&I))) @@ -1350,10 +1367,8 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { // X / fabs(X) -> copysign(1.0, X) // fabs(X) / X -> copysign(1.0, X) if (I.hasNoNaNs() && I.hasNoInfs() && - (match(&I, - m_FDiv(m_Value(X), m_Intrinsic<Intrinsic::fabs>(m_Deferred(X)))) || - match(&I, m_FDiv(m_Intrinsic<Intrinsic::fabs>(m_Value(X)), - m_Deferred(X))))) { + (match(&I, m_FDiv(m_Value(X), m_FAbs(m_Deferred(X)))) || + match(&I, m_FDiv(m_FAbs(m_Value(X)), m_Deferred(X))))) { Value *V = Builder.CreateBinaryIntrinsic( Intrinsic::copysign, ConstantFP::get(I.getType(), 1.0), X, &I); return replaceInstUsesWith(I, V); @@ -1365,7 +1380,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { /// instructions (urem and srem). It is called by the visitors to those integer /// remainder instructions. /// Common integer remainder transforms -Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { +Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // The RHS is known non-zero. 
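The new visitSDiv folds above rest on two integer identities: an exact signed division by 1<<C (or by -(1<<C)) is an arithmetic right shift (possibly negated), and abs(X)/X collapses to a sign test. A standalone sketch, not part of this change, that checks both; it feeds only evenly divisible dividends to model the "exact" flag and assumes the usual arithmetic right shift for signed values:

    // sdiv_exact_demo.cpp: exact sdiv by +/-2^c vs. arithmetic shift, and abs(x)/x.
    #include <cassert>
    #include <cstdint>
    #include <cstdlib>

    int main() {
      for (int32_t q : {-5, -1, 0, 1, 9, 1000}) {
        for (uint32_t c = 1; c < 20; ++c) {
          const int32_t pow2 = int32_t(1) << c;
          const int32_t x = q * pow2;        // dividend chosen so the division is exact
          assert(x / pow2 == (x >> c));      // sdiv exact X, 1<<C  --> ashr exact X, C
          assert(x / -pow2 == -(x >> c));    // sdiv exact X, -1<<C --> -(ashr exact X, C)
        }
      }
      for (int32_t x : {-7, -1, 1, 42}) {    // x == 0 would divide by zero, skip it
        assert(std::abs(x) / x == (x > -1 ? 1 : -1));  // abs(X)/X --> X > -1 ? 1 : -1
      }
      return 0;
    }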
@@ -1403,7 +1418,7 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { return nullptr; } -Instruction *InstCombiner::visitURem(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) { if (Value *V = SimplifyURemInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1454,7 +1469,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { return nullptr; } -Instruction *InstCombiner::visitSRem(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitSRem(BinaryOperator &I) { if (Value *V = SimplifySRemInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1477,7 +1492,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { // -X srem Y --> -(X srem Y) Value *X, *Y; if (match(&I, m_SRem(m_OneUse(m_NSWSub(m_Zero(), m_Value(X))), m_Value(Y)))) - return BinaryOperator::CreateNSWNeg(Builder.CreateSRem(X, Y)); + return BinaryOperator::CreateNSWNeg(Builder.CreateSRem(X, Y)); // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a urem. @@ -1491,7 +1506,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { // If it's a constant vector, flip any negative values positive. if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) { Constant *C = cast<Constant>(Op1); - unsigned VWidth = cast<VectorType>(C->getType())->getNumElements(); + unsigned VWidth = cast<FixedVectorType>(C->getType())->getNumElements(); bool hasNegative = false; bool hasMissing = false; @@ -1526,7 +1541,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { return nullptr; } -Instruction *InstCombiner::visitFRem(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitFRem(BinaryOperator &I) { if (Value *V = SimplifyFRemInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), SQ.getWithInstruction(&I))) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp index 3fe615ac5439..7718c8b0eedd 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -42,6 +42,9 @@ #include "llvm/Support/DebugCounter.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" +#include <cassert> +#include <cstdint> #include <functional> #include <tuple> #include <type_traits> @@ -112,6 +115,19 @@ Negator::~Negator() { } #endif +// Due to the InstCombine's worklist management, there are no guarantees that +// each instruction we'll encounter has been visited by InstCombine already. +// In particular, most importantly for us, that means we have to canonicalize +// constants to RHS ourselves, since that is helpful sometimes. +std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { + assert(I->getNumOperands() == 2 && "Only for binops!"); + std::array<Value *, 2> Ops{I->getOperand(0), I->getOperand(1)}; + if (I->isCommutative() && InstCombiner::getComplexity(I->getOperand(0)) < + InstCombiner::getComplexity(I->getOperand(1))) + std::swap(Ops[0], Ops[1]); + return Ops; +} + // FIXME: can this be reworked into a worklist-based algorithm while preserving // the depth-first, early bailout traversal? 
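The -X srem Y --> -(X srem Y) fold retained above is the sign rule of truncating remainder: the result takes the sign of the dividend, so negating the dividend negates the remainder. A standalone check, not part of this change, using values away from INT_MIN (the NSW requirement in the matcher excludes that case):

    // srem_neg_demo.cpp: (-x) % y equals -(x % y) for truncating remainder.
    #include <cassert>

    int main() {
      for (int x : {0, 1, 3, 7, 10, 123}) {
        for (int y : {1, 2, 3, 5, 8, 97}) {
          assert((-x) % y == -(x % y));
          assert((-x) % -y == -(x % -y));  // the divisor's sign does not matter
        }
      }
      return 0;
    }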
LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { @@ -156,11 +172,13 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { // In some cases we can give the answer without further recursion. switch (I->getOpcode()) { - case Instruction::Add: + case Instruction::Add: { + std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I); // `inc` is always negatible. - if (match(I->getOperand(1), m_One())) - return Builder.CreateNot(I->getOperand(0), I->getName() + ".neg"); + if (match(Ops[1], m_One())) + return Builder.CreateNot(Ops[0], I->getName() + ".neg"); break; + } case Instruction::Xor: // `not` is always negatible. if (match(I, m_Not(m_Value(X)))) @@ -181,6 +199,10 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { } return BO; } + // While we could negate exact arithmetic shift: + // ashr exact %x, C --> sdiv exact i8 %x, -1<<C + // iff C != 0 and C u< bitwidth(%x), we don't want to, + // because division is *THAT* much worse than a shift. break; } case Instruction::SExt: @@ -197,26 +219,28 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { break; // Other instructions require recursive reasoning. } + if (I->getOpcode() == Instruction::Sub && + (I->hasOneUse() || match(I->getOperand(0), m_ImmConstant()))) { + // `sub` is always negatible. + // However, only do this either if the old `sub` doesn't stick around, or + // it was subtracting from a constant. Otherwise, this isn't profitable. + return Builder.CreateSub(I->getOperand(1), I->getOperand(0), + I->getName() + ".neg"); + } + // Some other cases, while still don't require recursion, // are restricted to the one-use case. if (!V->hasOneUse()) return nullptr; switch (I->getOpcode()) { - case Instruction::Sub: - // `sub` is always negatible. - // But if the old `sub` sticks around, even thought we don't increase - // instruction count, this is a likely regression since we increased - // live-range of *both* of the operands, which might lead to more spilling. - return Builder.CreateSub(I->getOperand(1), I->getOperand(0), - I->getName() + ".neg"); case Instruction::SDiv: // `sdiv` is negatible if divisor is not undef/INT_MIN/1. // While this is normally not behind a use-check, // let's consider division to be special since it's costly. if (auto *Op1C = dyn_cast<Constant>(I->getOperand(1))) { - if (!Op1C->containsUndefElement() && Op1C->isNotMinSignedValue() && - Op1C->isNotOneValue()) { + if (!Op1C->containsUndefOrPoisonElement() && + Op1C->isNotMinSignedValue() && Op1C->isNotOneValue()) { Value *BO = Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(Op1C), I->getName() + ".neg"); @@ -237,6 +261,13 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { } switch (I->getOpcode()) { + case Instruction::Freeze: { + // `freeze` is negatible if its operand is negatible. + Value *NegOp = negate(I->getOperand(0), Depth + 1); + if (!NegOp) // Early return. + return nullptr; + return Builder.CreateFreeze(NegOp, I->getName() + ".neg"); + } case Instruction::PHI: { // `phi` is negatible if all the incoming values are negatible. auto *PHI = cast<PHINode>(I); @@ -254,20 +285,16 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { return NegatedPHI; } case Instruction::Select: { - { - // `abs`/`nabs` is always negatible. 
- Value *LHS, *RHS; - SelectPatternFlavor SPF = - matchSelectPattern(I, LHS, RHS, /*CastOp=*/nullptr, Depth).Flavor; - if (SPF == SPF_ABS || SPF == SPF_NABS) { - auto *NewSelect = cast<SelectInst>(I->clone()); - // Just swap the operands of the select. - NewSelect->swapValues(); - // Don't swap prof metadata, we didn't change the branch behavior. - NewSelect->setName(I->getName() + ".neg"); - Builder.Insert(NewSelect); - return NewSelect; - } + if (isKnownNegation(I->getOperand(1), I->getOperand(2))) { + // Of one hand of select is known to be negation of another hand, + // just swap the hands around. + auto *NewSelect = cast<SelectInst>(I->clone()); + // Just swap the operands of the select. + NewSelect->swapValues(); + // Don't swap prof metadata, we didn't change the branch behavior. + NewSelect->setName(I->getName() + ".neg"); + Builder.Insert(NewSelect); + return NewSelect; } // `select` is negatible if both hands of `select` are negatible. Value *NegOp1 = negate(I->getOperand(1), Depth + 1); @@ -323,51 +350,81 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { } case Instruction::Shl: { // `shl` is negatible if the first operand is negatible. - Value *NegOp0 = negate(I->getOperand(0), Depth + 1); - if (!NegOp0) // Early return. + if (Value *NegOp0 = negate(I->getOperand(0), Depth + 1)) + return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg"); + // Otherwise, `shl %x, C` can be interpreted as `mul %x, 1<<C`. + auto *Op1C = dyn_cast<Constant>(I->getOperand(1)); + if (!Op1C) // Early return. return nullptr; - return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg"); + return Builder.CreateMul( + I->getOperand(0), + ConstantExpr::getShl(Constant::getAllOnesValue(Op1C->getType()), Op1C), + I->getName() + ".neg"); } - case Instruction::Or: + case Instruction::Or: { if (!haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL, &AC, I, &DT)) return nullptr; // Don't know how to handle `or` in general. + std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I); // `or`/`add` are interchangeable when operands have no common bits set. // `inc` is always negatible. - if (match(I->getOperand(1), m_One())) - return Builder.CreateNot(I->getOperand(0), I->getName() + ".neg"); + if (match(Ops[1], m_One())) + return Builder.CreateNot(Ops[0], I->getName() + ".neg"); // Else, just defer to Instruction::Add handling. LLVM_FALLTHROUGH; + } case Instruction::Add: { // `add` is negatible if both of its operands are negatible. - Value *NegOp0 = negate(I->getOperand(0), Depth + 1); - if (!NegOp0) // Early return. - return nullptr; - Value *NegOp1 = negate(I->getOperand(1), Depth + 1); - if (!NegOp1) + SmallVector<Value *, 2> NegatedOps, NonNegatedOps; + for (Value *Op : I->operands()) { + // Can we sink the negation into this operand? + if (Value *NegOp = negate(Op, Depth + 1)) { + NegatedOps.emplace_back(NegOp); // Successfully negated operand! + continue; + } + // Failed to sink negation into this operand. IFF we started from negation + // and we manage to sink negation into one operand, we can still do this. + if (!IsTrulyNegation) + return nullptr; + NonNegatedOps.emplace_back(Op); // Just record which operand that was. + } + assert((NegatedOps.size() + NonNegatedOps.size()) == 2 && + "Internal consistency sanity check."); + // Did we manage to sink negation into both of the operands? + if (NegatedOps.size() == 2) // Then we get to keep the `add`! 
+ return Builder.CreateAdd(NegatedOps[0], NegatedOps[1], + I->getName() + ".neg"); + assert(IsTrulyNegation && "We should have early-exited then."); + // Completely failed to sink negation? + if (NonNegatedOps.size() == 2) return nullptr; - return Builder.CreateAdd(NegOp0, NegOp1, I->getName() + ".neg"); + // 0-(a+b) --> (-a)-b + return Builder.CreateSub(NegatedOps[0], NonNegatedOps[0], + I->getName() + ".neg"); } - case Instruction::Xor: + case Instruction::Xor: { + std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I); // `xor` is negatible if one of its operands is invertible. // FIXME: InstCombineInverter? But how to connect Inverter and Negator? - if (auto *C = dyn_cast<Constant>(I->getOperand(1))) { - Value *Xor = Builder.CreateXor(I->getOperand(0), ConstantExpr::getNot(C)); + if (auto *C = dyn_cast<Constant>(Ops[1])) { + Value *Xor = Builder.CreateXor(Ops[0], ConstantExpr::getNot(C)); return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1), I->getName() + ".neg"); } return nullptr; + } case Instruction::Mul: { + std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I); // `mul` is negatible if one of its operands is negatible. Value *NegatedOp, *OtherOp; // First try the second operand, in case it's a constant it will be best to // just invert it instead of sinking the `neg` deeper. - if (Value *NegOp1 = negate(I->getOperand(1), Depth + 1)) { + if (Value *NegOp1 = negate(Ops[1], Depth + 1)) { NegatedOp = NegOp1; - OtherOp = I->getOperand(0); - } else if (Value *NegOp0 = negate(I->getOperand(0), Depth + 1)) { + OtherOp = Ops[0]; + } else if (Value *NegOp0 = negate(Ops[0], Depth + 1)) { NegatedOp = NegOp0; - OtherOp = I->getOperand(1); + OtherOp = Ops[1]; } else // Can't negate either of them. return nullptr; @@ -430,7 +487,7 @@ LLVM_NODISCARD Optional<Negator::Result> Negator::run(Value *Root) { } LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root, - InstCombiner &IC) { + InstCombinerImpl &IC) { ++NegatorTotalNegationsAttempted; LLVM_DEBUG(dbgs() << "Negator: attempting to sink negation into " << *Root << "\n"); diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 2b2f2e1b9470..d687ec654438 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -13,11 +13,14 @@ #include "InstCombineInternal.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/Local.h" + using namespace llvm; using namespace llvm::PatternMatch; @@ -27,10 +30,16 @@ static cl::opt<unsigned> MaxNumPhis("instcombine-max-num-phis", cl::init(512), cl::desc("Maximum number phis to handle in intptr/ptrint folding")); +STATISTIC(NumPHIsOfInsertValues, + "Number of phi-of-insertvalue turned into insertvalue-of-phis"); +STATISTIC(NumPHIsOfExtractValues, + "Number of phi-of-extractvalue turned into extractvalue-of-phi"); +STATISTIC(NumPHICSEs, "Number of PHI's that got CSE'd"); + /// The PHI arguments will be folded into a single operation with a PHI node /// as input. The debug location of the single operation will be the merged /// locations of the original PHI node arguments. 
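Several Negator::visitImpl cases above are plain two's-complement identities. A standalone sketch, not part of this change, that checks them with wrap-around uint32_t arithmetic, which models the default (no nsw/nuw) LLVM integer semantics:

    // negator_demo.cpp: identities used when sinking a negation.
    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t vals[] = {0u, 1u, 5u, 0x80000000u, 0xFFFFFFFFu, 0xDEADBEEFu};
      const uint32_t C = 0x00FF00FFu;
      for (uint32_t a : vals) {
        for (uint32_t b : vals) {
          assert(-(a + 1u) == ~a);                   // neg (add X, 1)  --> not X
          assert(-(a + b) == (-a) - b);              // 0 - (a + b)     --> (-a) - b
          assert(-(a - b) == b - a);                 // neg (sub A, B)  --> sub B, A
          assert(-(a << 3) == a * -(1u << 3));       // neg (shl X, C)  --> mul X, -(1<<C)
          assert(-(a ^ C) == (a ^ ~C) + 1u);         // neg (xor X, C)  --> add (xor X, ~C), 1
        }
      }
      return 0;
    }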
-void InstCombiner::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) { +void InstCombinerImpl::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) { auto *FirstInst = cast<Instruction>(PN.getIncomingValue(0)); Inst->setDebugLoc(FirstInst->getDebugLoc()); // We do not expect a CallInst here, otherwise, N-way merging of DebugLoc @@ -93,7 +102,7 @@ void InstCombiner::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) { // ptr_val_inc = ... // ... // -Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) { +Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { if (!PN.getType()->isIntegerTy()) return nullptr; if (!PN.hasOneUse()) @@ -290,9 +299,86 @@ Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) { IntToPtr->getOperand(0)->getType()); } +/// If we have something like phi [insertvalue(a,b,0), insertvalue(c,d,0)], +/// turn this into a phi[a,c] and phi[b,d] and a single insertvalue. +Instruction * +InstCombinerImpl::foldPHIArgInsertValueInstructionIntoPHI(PHINode &PN) { + auto *FirstIVI = cast<InsertValueInst>(PN.getIncomingValue(0)); + + // Scan to see if all operands are `insertvalue`'s with the same indicies, + // and all have a single use. + for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) { + auto *I = dyn_cast<InsertValueInst>(PN.getIncomingValue(i)); + if (!I || !I->hasOneUser() || I->getIndices() != FirstIVI->getIndices()) + return nullptr; + } + + // For each operand of an `insertvalue` + std::array<PHINode *, 2> NewOperands; + for (int OpIdx : {0, 1}) { + auto *&NewOperand = NewOperands[OpIdx]; + // Create a new PHI node to receive the values the operand has in each + // incoming basic block. + NewOperand = PHINode::Create( + FirstIVI->getOperand(OpIdx)->getType(), PN.getNumIncomingValues(), + FirstIVI->getOperand(OpIdx)->getName() + ".pn"); + // And populate each operand's PHI with said values. + for (auto Incoming : zip(PN.blocks(), PN.incoming_values())) + NewOperand->addIncoming( + cast<InsertValueInst>(std::get<1>(Incoming))->getOperand(OpIdx), + std::get<0>(Incoming)); + InsertNewInstBefore(NewOperand, PN); + } + + // And finally, create `insertvalue` over the newly-formed PHI nodes. + auto *NewIVI = InsertValueInst::Create(NewOperands[0], NewOperands[1], + FirstIVI->getIndices(), PN.getName()); + + PHIArgMergedDebugLoc(NewIVI, PN); + ++NumPHIsOfInsertValues; + return NewIVI; +} + +/// If we have something like phi [extractvalue(a,0), extractvalue(b,0)], +/// turn this into a phi[a,b] and a single extractvalue. +Instruction * +InstCombinerImpl::foldPHIArgExtractValueInstructionIntoPHI(PHINode &PN) { + auto *FirstEVI = cast<ExtractValueInst>(PN.getIncomingValue(0)); + + // Scan to see if all operands are `extractvalue`'s with the same indicies, + // and all have a single use. + for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) { + auto *I = dyn_cast<ExtractValueInst>(PN.getIncomingValue(i)); + if (!I || !I->hasOneUser() || I->getIndices() != FirstEVI->getIndices() || + I->getAggregateOperand()->getType() != + FirstEVI->getAggregateOperand()->getType()) + return nullptr; + } + + // Create a new PHI node to receive the values the aggregate operand has + // in each incoming basic block. + auto *NewAggregateOperand = PHINode::Create( + FirstEVI->getAggregateOperand()->getType(), PN.getNumIncomingValues(), + FirstEVI->getAggregateOperand()->getName() + ".pn"); + // And populate the PHI with said values. 
+ for (auto Incoming : zip(PN.blocks(), PN.incoming_values())) + NewAggregateOperand->addIncoming( + cast<ExtractValueInst>(std::get<1>(Incoming))->getAggregateOperand(), + std::get<0>(Incoming)); + InsertNewInstBefore(NewAggregateOperand, PN); + + // And finally, create `extractvalue` over the newly-formed PHI nodes. + auto *NewEVI = ExtractValueInst::Create(NewAggregateOperand, + FirstEVI->getIndices(), PN.getName()); + + PHIArgMergedDebugLoc(NewEVI, PN); + ++NumPHIsOfExtractValues; + return NewEVI; +} + /// If we have something like phi [add (a,b), add(a,c)] and if a/b/c and the -/// adds all have a single use, turn this into a phi and a single binop. -Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { +/// adds all have a single user, turn this into a phi and a single binop. +Instruction *InstCombinerImpl::foldPHIArgBinOpIntoPHI(PHINode &PN) { Instruction *FirstInst = cast<Instruction>(PN.getIncomingValue(0)); assert(isa<BinaryOperator>(FirstInst) || isa<CmpInst>(FirstInst)); unsigned Opc = FirstInst->getOpcode(); @@ -302,10 +388,10 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { Type *LHSType = LHSVal->getType(); Type *RHSType = RHSVal->getType(); - // Scan to see if all operands are the same opcode, and all have one use. + // Scan to see if all operands are the same opcode, and all have one user. for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) { Instruction *I = dyn_cast<Instruction>(PN.getIncomingValue(i)); - if (!I || I->getOpcode() != Opc || !I->hasOneUse() || + if (!I || I->getOpcode() != Opc || !I->hasOneUser() || // Verify type of the LHS matches so we don't fold cmp's of different // types. I->getOperand(0)->getType() != LHSType || @@ -385,7 +471,7 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { return NewBinOp; } -Instruction *InstCombiner::FoldPHIArgGEPIntoPHI(PHINode &PN) { +Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) { GetElementPtrInst *FirstInst =cast<GetElementPtrInst>(PN.getIncomingValue(0)); SmallVector<Value*, 16> FixedOperands(FirstInst->op_begin(), @@ -401,11 +487,12 @@ Instruction *InstCombiner::FoldPHIArgGEPIntoPHI(PHINode &PN) { bool AllInBounds = true; - // Scan to see if all operands are the same opcode, and all have one use. + // Scan to see if all operands are the same opcode, and all have one user. for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) { - GetElementPtrInst *GEP= dyn_cast<GetElementPtrInst>(PN.getIncomingValue(i)); - if (!GEP || !GEP->hasOneUse() || GEP->getType() != FirstInst->getType() || - GEP->getNumOperands() != FirstInst->getNumOperands()) + GetElementPtrInst *GEP = + dyn_cast<GetElementPtrInst>(PN.getIncomingValue(i)); + if (!GEP || !GEP->hasOneUser() || GEP->getType() != FirstInst->getType() || + GEP->getNumOperands() != FirstInst->getNumOperands()) return nullptr; AllInBounds &= GEP->isInBounds(); @@ -494,7 +581,6 @@ Instruction *InstCombiner::FoldPHIArgGEPIntoPHI(PHINode &PN) { return NewGEP; } - /// Return true if we know that it is safe to sink the load out of the block /// that defines it. This means that it must be obvious the value of the load is /// not changed from the point of the load to the end of the block it is in. 
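foldPHIArgBinOpIntoPHI, and the new foldPHIArgInsertValueInstructionIntoPHI / foldPHIArgExtractValueInstructionIntoPHI above, hoist one repeated operation out of a phi when every incoming value performs the same operation. A source-level C++ analogy, not part of this change (the real transform rewrites IR phi nodes; this only shows the algebraic step):

    // phi_binop_demo.cpp: (cond ? a + k : b + k) equals (cond ? a : b) + k.
    #include <cassert>

    int addPerPath(bool cond, int a, int b, int k) {
      return cond ? (a + k) : (b + k);   // before: one add on each incoming path
    }

    int addAfterMerge(bool cond, int a, int b, int k) {
      return (cond ? a : b) + k;         // after: a single add of the merged value
    }

    int main() {
      for (bool cond : {false, true})
        for (int a : {-3, 0, 7})
          for (int b : {-1, 2, 9})
            assert(addPerPath(cond, a, b, 4) == addAfterMerge(cond, a, b, 4));
      return 0;
    }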
@@ -540,7 +626,7 @@ static bool isSafeAndProfitableToSinkLoad(LoadInst *L) { return true; } -Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { +Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { LoadInst *FirstLI = cast<LoadInst>(PN.getIncomingValue(0)); // FIXME: This is overconservative; this transform is allowed in some cases @@ -573,7 +659,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { // Check to see if all arguments are the same operation. for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) { LoadInst *LI = dyn_cast<LoadInst>(PN.getIncomingValue(i)); - if (!LI || !LI->hasOneUse()) + if (!LI || !LI->hasOneUser()) return nullptr; // We can't sink the load if the loaded value could be modified between @@ -654,7 +740,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { /// TODO: This function could handle other cast types, but then it might /// require special-casing a cast from the 'i1' type. See the comment in /// FoldPHIArgOpIntoPHI() about pessimizing illegal integer types. -Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) { +Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) { // We cannot create a new instruction after the PHI if the terminator is an // EHPad because there is no valid insertion point. if (Instruction *TI = Phi.getParent()->getTerminator()) @@ -686,8 +772,8 @@ Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) { unsigned NumConsts = 0; for (Value *V : Phi.incoming_values()) { if (auto *Zext = dyn_cast<ZExtInst>(V)) { - // All zexts must be identical and have one use. - if (Zext->getSrcTy() != NarrowType || !Zext->hasOneUse()) + // All zexts must be identical and have one user. + if (Zext->getSrcTy() != NarrowType || !Zext->hasOneUser()) return nullptr; NewIncoming.push_back(Zext->getOperand(0)); NumZexts++; @@ -728,7 +814,7 @@ Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) { /// If all operands to a PHI node are the same "unary" operator and they all are /// only used by the PHI, PHI together their inputs, and do the operation once, /// to the result of the PHI. -Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { +Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) { // We cannot create a new instruction after the PHI if the terminator is an // EHPad because there is no valid insertion point. if (Instruction *TI = PN.getParent()->getTerminator()) @@ -738,9 +824,13 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { Instruction *FirstInst = cast<Instruction>(PN.getIncomingValue(0)); if (isa<GetElementPtrInst>(FirstInst)) - return FoldPHIArgGEPIntoPHI(PN); + return foldPHIArgGEPIntoPHI(PN); if (isa<LoadInst>(FirstInst)) - return FoldPHIArgLoadIntoPHI(PN); + return foldPHIArgLoadIntoPHI(PN); + if (isa<InsertValueInst>(FirstInst)) + return foldPHIArgInsertValueInstructionIntoPHI(PN); + if (isa<ExtractValueInst>(FirstInst)) + return foldPHIArgExtractValueInstructionIntoPHI(PN); // Scan the instruction, looking for input operations that can be folded away. // If all input operands to the phi are the same instruction (e.g. a cast from @@ -763,7 +853,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { // otherwise call FoldPHIArgBinOpIntoPHI. ConstantOp = dyn_cast<Constant>(FirstInst->getOperand(1)); if (!ConstantOp) - return FoldPHIArgBinOpIntoPHI(PN); + return foldPHIArgBinOpIntoPHI(PN); } else { return nullptr; // Cannot fold this operation. 
} @@ -771,7 +861,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { // Check to see if all arguments are the same operation. for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) { Instruction *I = dyn_cast<Instruction>(PN.getIncomingValue(i)); - if (!I || !I->hasOneUse() || !I->isSameOperationAs(FirstInst)) + if (!I || !I->hasOneUser() || !I->isSameOperationAs(FirstInst)) return nullptr; if (CastSrcTy) { if (I->getOperand(0)->getType() != CastSrcTy) @@ -923,7 +1013,7 @@ struct LoweredPHIRecord { LoweredPHIRecord(PHINode *pn, unsigned Sh) : PN(pn), Shift(Sh), Width(0) {} }; -} +} // namespace namespace llvm { template<> @@ -944,7 +1034,7 @@ namespace llvm { LHS.Width == RHS.Width; } }; -} +} // namespace llvm /// This is an integer PHI and we know that it has an illegal type: see if it is @@ -955,7 +1045,7 @@ namespace llvm { /// TODO: The user of the trunc may be an bitcast to float/double/vector or an /// inttoptr. We should produce new PHIs in the right type. /// -Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { +Instruction *InstCombinerImpl::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // PHIUsers - Keep track of all of the truncated values extracted from a set // of PHIs, along with their offset. These are the things we want to rewrite. SmallVector<PHIUsageRecord, 16> PHIUsers; @@ -1129,13 +1219,85 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { return replaceInstUsesWith(FirstPhi, Undef); } +static Value *SimplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, + const DominatorTree &DT) { + // Simplify the following patterns: + // if (cond) + // / \ + // ... ... + // \ / + // phi [true] [false] + if (!PN.getType()->isIntegerTy(1)) + return nullptr; + + if (PN.getNumOperands() != 2) + return nullptr; + + // Make sure all inputs are constants. + if (!all_of(PN.operands(), [](Value *V) { return isa<ConstantInt>(V); })) + return nullptr; + + BasicBlock *BB = PN.getParent(); + // Do not bother with unreachable instructions. + if (!DT.isReachableFromEntry(BB)) + return nullptr; + + // Same inputs. + if (PN.getOperand(0) == PN.getOperand(1)) + return PN.getOperand(0); + + BasicBlock *TruePred = nullptr, *FalsePred = nullptr; + for (auto *Pred : predecessors(BB)) { + auto *Input = cast<ConstantInt>(PN.getIncomingValueForBlock(Pred)); + if (Input->isAllOnesValue()) + TruePred = Pred; + else + FalsePred = Pred; + } + assert(TruePred && FalsePred && "Must be!"); + + // Check which edge of the dominator dominates the true input. If it is the + // false edge, we should invert the condition. + auto *IDom = DT.getNode(BB)->getIDom()->getBlock(); + auto *BI = dyn_cast<BranchInst>(IDom->getTerminator()); + if (!BI || BI->isUnconditional()) + return nullptr; + + // Check that edges outgoing from the idom's terminators dominate respective + // inputs of the Phi. + BasicBlockEdge TrueOutEdge(IDom, BI->getSuccessor(0)); + BasicBlockEdge FalseOutEdge(IDom, BI->getSuccessor(1)); + + BasicBlockEdge TrueIncEdge(TruePred, BB); + BasicBlockEdge FalseIncEdge(FalsePred, BB); + + auto *Cond = BI->getCondition(); + if (DT.dominates(TrueOutEdge, TrueIncEdge) && + DT.dominates(FalseOutEdge, FalseIncEdge)) + // This Phi is actually equivalent to branching condition of IDom. + return Cond; + else if (DT.dominates(TrueOutEdge, FalseIncEdge) && + DT.dominates(FalseOutEdge, TrueIncEdge)) { + // This Phi is actually opposite to branching condition of IDom. 
We invert + // the condition that will potentially open up some opportunities for + // sinking. + auto InsertPt = BB->getFirstInsertionPt(); + if (InsertPt != BB->end()) { + Self.Builder.SetInsertPoint(&*InsertPt); + return Self.Builder.CreateNot(Cond); + } + } + + return nullptr; +} + // PHINode simplification // -Instruction *InstCombiner::visitPHINode(PHINode &PN) { +Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { if (Value *V = SimplifyInstruction(&PN, SQ.getWithInstruction(&PN))) return replaceInstUsesWith(PN, V); - if (Instruction *Result = FoldPHIArgZextsIntoPHI(PN)) + if (Instruction *Result = foldPHIArgZextsIntoPHI(PN)) return Result; // If all PHI operands are the same operation, pull them through the PHI, @@ -1143,18 +1305,16 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { if (isa<Instruction>(PN.getIncomingValue(0)) && isa<Instruction>(PN.getIncomingValue(1)) && cast<Instruction>(PN.getIncomingValue(0))->getOpcode() == - cast<Instruction>(PN.getIncomingValue(1))->getOpcode() && - // FIXME: The hasOneUse check will fail for PHIs that use the value more - // than themselves more than once. - PN.getIncomingValue(0)->hasOneUse()) - if (Instruction *Result = FoldPHIArgOpIntoPHI(PN)) + cast<Instruction>(PN.getIncomingValue(1))->getOpcode() && + PN.getIncomingValue(0)->hasOneUser()) + if (Instruction *Result = foldPHIArgOpIntoPHI(PN)) return Result; // If this is a trivial cycle in the PHI node graph, remove it. Basically, if // this PHI only has a single use (a PHI), and if that PHI only has one use (a // PHI)... break the cycle. if (PN.hasOneUse()) { - if (Instruction *Result = FoldIntegerTypedPHI(PN)) + if (Instruction *Result = foldIntegerTypedPHI(PN)) return Result; Instruction *PHIUser = cast<Instruction>(PN.user_back()); @@ -1267,6 +1427,21 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { } } + // Is there an identical PHI node in this basic block? + for (PHINode &IdenticalPN : PN.getParent()->phis()) { + // Ignore the PHI node itself. + if (&IdenticalPN == &PN) + continue; + // Note that even though we've just canonicalized this PHI, due to the + // worklist visitation order, there are no guarantess that *every* PHI + // has been canonicalized, so we can't just compare operands ranges. + if (!PN.isIdenticalToWhenDefined(&IdenticalPN)) + continue; + // Just use that PHI instead then. + ++NumPHICSEs; + return replaceInstUsesWith(PN, &IdenticalPN); + } + // If this is an integer PHI and we know that it has an illegal type, see if // it is only used by trunc or trunc(lshr) operations. If so, we split the // PHI into the various pieces being extracted. This sort of thing is @@ -1276,5 +1451,9 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { if (Instruction *Res = SliceUpIllegalIntegerPHI(PN)) return Res; + // Ultimately, try to replace this Phi with a dominating condition. 
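SimplifyUsingControlFlow above replaces a boolean phi whose constant inputs arrive from the two sides of a conditional branch in the immediate dominator with that branch's condition, inverted if the edges are swapped. A small source-level analogy, not part of this change:

    // phi_controlflow_demo.cpp: a bool merged as true on the taken side and
    // false on the other side is just the branch condition itself.
    #include <cassert>

    bool diamond(bool cond) {
      bool merged;        // plays the role of the i1 phi
      if (cond)
        merged = true;    // incoming value on the true edge
      else
        merged = false;   // incoming value on the false edge
      return merged;      // foldable to: return cond;
    }

    int main() {
      assert(diamond(true) == true);
      assert(diamond(false) == false);
      return 0;
    }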
+ if (auto *V = SimplifyUsingControlFlow(*this, PN, DT)) + return replaceInstUsesWith(PN, V); + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 17124f717af7..f26c194d31b9 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -38,6 +38,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" #include <cassert> #include <utility> @@ -46,6 +47,11 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" +/// FIXME: Enabled by default until the pattern is supported well. +static cl::opt<bool> EnableUnsafeSelectTransform( + "instcombine-unsafe-select-transform", cl::init(true), + cl::desc("Enable poison-unsafe select to and/or transform")); + static Value *createMinMax(InstCombiner::BuilderTy &Builder, SelectPatternFlavor SPF, Value *A, Value *B) { CmpInst::Predicate Pred = getMinMaxPred(SPF); @@ -57,7 +63,7 @@ static Value *createMinMax(InstCombiner::BuilderTy &Builder, /// constant of a binop. static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, const TargetLibraryInfo &TLI, - InstCombiner &IC) { + InstCombinerImpl &IC) { // The select condition must be an equality compare with a constant operand. Value *X; Constant *C; @@ -258,29 +264,9 @@ static unsigned getSelectFoldableOperands(BinaryOperator *I) { } } -/// For the same transformation as the previous function, return the identity -/// constant that goes into the select. -static APInt getSelectFoldableConstant(BinaryOperator *I) { - switch (I->getOpcode()) { - default: llvm_unreachable("This cannot happen!"); - case Instruction::Add: - case Instruction::Sub: - case Instruction::Or: - case Instruction::Xor: - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - return APInt::getNullValue(I->getType()->getScalarSizeInBits()); - case Instruction::And: - return APInt::getAllOnesValue(I->getType()->getScalarSizeInBits()); - case Instruction::Mul: - return APInt(I->getType()->getScalarSizeInBits(), 1); - } -} - /// We have (select c, TI, FI), and we know that TI and FI have the same opcode. -Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, - Instruction *FI) { +Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, + Instruction *FI) { // Don't break up min/max patterns. The hasOneUse checks below prevent that // for most cases, but vector min/max with bitcasts can be transformed. If the // one-use restrictions are eased for other patterns, we still don't want to @@ -302,10 +288,9 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, // The select condition may be a vector. We may only change the operand // type if the vector width remains the same (and matches the condition). if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) { - if (!FIOpndTy->isVectorTy()) - return nullptr; - if (CondVTy->getNumElements() != - cast<VectorType>(FIOpndTy)->getNumElements()) + if (!FIOpndTy->isVectorTy() || + CondVTy->getElementCount() != + cast<VectorType>(FIOpndTy)->getElementCount()) return nullptr; // TODO: If the backend knew how to deal with casts better, we could @@ -418,8 +403,8 @@ static bool isSelect01(const APInt &C1I, const APInt &C2I) { /// Try to fold the select into one of the operands to allow further /// optimization. 
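The identity-constant table removed above (getSelectFoldableConstant) encoded, per opcode, the value that leaves the other operand unchanged: 0 for add/sub/or/xor and the shifts, all-ones for and, 1 for mul. Feeding that identity into the select makes both arms the same operation, which is what foldSelectIntoOp exploits. A standalone check of the identities themselves, not part of this change:

    // select_identity_demo.cpp: select c, (X op Y), X equals X op (select c, Y, identity).
    #include <cassert>
    #include <cstdint>

    int main() {
      for (bool c : {false, true}) {
        for (uint32_t x : {0u, 3u, 0xF0F0F0F0u}) {
          for (uint32_t y : {0u, 1u, 0x0FF00FF0u}) {
            assert((c ? x + y : x) == x + (c ? y : 0u));              // add: identity 0
            assert((c ? (x & y) : x) == (x & (c ? y : 0xFFFFFFFFu))); // and: identity all-ones
            assert((c ? (x | y) : x) == (x | (c ? y : 0u)));          // or:  identity 0
            assert((c ? (x ^ y) : x) == (x ^ (c ? y : 0u)));          // xor: identity 0
            assert((c ? x * y : x) == x * (c ? y : 1u));              // mul: identity 1
          }
        }
      }
      return 0;
    }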
-Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, - Value *FalseVal) { +Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, + Value *FalseVal) { // See the comment above GetSelectFoldableOperands for a description of the // transformation we are doing here. if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) { @@ -433,14 +418,15 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - APInt CI = getSelectFoldableConstant(TVI); + Constant *C = ConstantExpr::getBinOpIdentity(TVI->getOpcode(), + TVI->getType(), true); Value *OOp = TVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. const APInt *OOpC; bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { - Value *C = ConstantInt::get(OOp->getType(), CI); + if (!isa<Constant>(OOp) || + (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C); NewSel->takeName(TVI); BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(), @@ -464,14 +450,15 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - APInt CI = getSelectFoldableConstant(FVI); + Constant *C = ConstantExpr::getBinOpIdentity(FVI->getOpcode(), + FVI->getType(), true); Value *OOp = FVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. const APInt *OOpC; bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { - Value *C = ConstantInt::get(OOp->getType(), CI); + if (!isa<Constant>(OOp) || + (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp); NewSel->takeName(FVI); BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(), @@ -782,25 +769,24 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, // Match unsigned saturated add of 2 variables with an unnecessary 'not'. // There are 8 commuted variants. - // Canonicalize -1 (saturated result) to true value of the select. Just - // swapping the compare operands is legal, because the selected value is the - // same in case of equality, so we can interchange u< and u<=. + // Canonicalize -1 (saturated result) to true value of the select. if (match(FVal, m_AllOnes())) { std::swap(TVal, FVal); - std::swap(Cmp0, Cmp1); + Pred = CmpInst::getInversePredicate(Pred); } if (!match(TVal, m_AllOnes())) return nullptr; - // Canonicalize predicate to 'ULT'. - if (Pred == ICmpInst::ICMP_UGT) { - Pred = ICmpInst::ICMP_ULT; + // Canonicalize predicate to less-than or less-or-equal-than. + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) { std::swap(Cmp0, Cmp1); + Pred = CmpInst::getSwappedPredicate(Pred); } - if (Pred != ICmpInst::ICMP_ULT) + if (Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_ULE) return nullptr; // Match unsigned saturated add of 2 variables with an unnecessary 'not'. + // Strictness of the comparison is irrelevant. Value *Y; if (match(Cmp0, m_Not(m_Value(X))) && match(FVal, m_c_Add(m_Specific(X), m_Value(Y))) && Y == Cmp1) { @@ -809,6 +795,7 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, X, Y); } // The 'not' op may be included in the sum but not the compare. 
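canonicalizeSaturatedAdd above now also accepts the non-strict u<=/u>= forms of the compare; the pattern it recognizes is the usual overflow check where the wrapped sum is compared against one of the addends. A standalone sketch of that idiom, not part of this change:

    // uadd_sat_demo.cpp: ((x + y) u< x) ? -1 : (x + y) is a saturating unsigned add.
    #include <cassert>
    #include <cstdint>
    #include <limits>

    uint32_t uadd_sat_via_select(uint32_t x, uint32_t y) {
      uint32_t sum = x + y;                        // may wrap around
      return sum < x ? std::numeric_limits<uint32_t>::max() : sum;
    }

    uint32_t uadd_sat_reference(uint32_t x, uint32_t y) {
      uint64_t wide = uint64_t(x) + uint64_t(y);   // cannot wrap in 64 bits
      return wide > 0xFFFFFFFFull ? 0xFFFFFFFFu : uint32_t(wide);
    }

    int main() {
      const uint32_t vals[] = {0u, 1u, 100u, 0x7FFFFFFFu, 0xFFFFFFFEu, 0xFFFFFFFFu};
      for (uint32_t x : vals)
        for (uint32_t y : vals)
          assert(uadd_sat_via_select(x, y) == uadd_sat_reference(x, y));
      return 0;
    }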
+ // Strictness of the comparison is irrelevant. X = Cmp0; Y = Cmp1; if (match(FVal, m_c_Add(m_Not(m_Specific(X)), m_Specific(Y)))) { @@ -819,7 +806,9 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1)); } // The overflow may be detected via the add wrapping round. - if (match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) && + // This is only valid for strict comparison! + if (Pred == ICmpInst::ICMP_ULT && + match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) && match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) { // ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y) // ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y) @@ -1024,9 +1013,9 @@ static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { /// select (icmp Pred X, C1), C2, X --> select (icmp Pred' X, C2), X, C2 /// Note: if C1 != C2, this will change the icmp constant to the existing /// constant operand of the select. -static Instruction * -canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, - InstCombiner &IC) { +static Instruction *canonicalizeMinMaxWithConstant(SelectInst &Sel, + ICmpInst &Cmp, + InstCombinerImpl &IC) { if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) return nullptr; @@ -1063,105 +1052,29 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, return &Sel; } -/// There are many select variants for each of ABS/NABS. -/// In matchSelectPattern(), there are different compare constants, compare -/// predicates/operands and select operands. -/// In isKnownNegation(), there are different formats of negated operands. -/// Canonicalize all these variants to 1 pattern. -/// This makes CSE more likely. static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp, - InstCombiner &IC) { + InstCombinerImpl &IC) { if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) return nullptr; - // Choose a sign-bit check for the compare (likely simpler for codegen). - // ABS: (X <s 0) ? -X : X - // NABS: (X <s 0) ? X : -X Value *LHS, *RHS; SelectPatternFlavor SPF = matchSelectPattern(&Sel, LHS, RHS).Flavor; if (SPF != SelectPatternFlavor::SPF_ABS && SPF != SelectPatternFlavor::SPF_NABS) return nullptr; - Value *TVal = Sel.getTrueValue(); - Value *FVal = Sel.getFalseValue(); - assert(isKnownNegation(TVal, FVal) && - "Unexpected result from matchSelectPattern"); - - // The compare may use the negated abs()/nabs() operand, or it may use - // negation in non-canonical form such as: sub A, B. - bool CmpUsesNegatedOp = match(Cmp.getOperand(0), m_Neg(m_Specific(TVal))) || - match(Cmp.getOperand(0), m_Neg(m_Specific(FVal))); - - bool CmpCanonicalized = !CmpUsesNegatedOp && - match(Cmp.getOperand(1), m_ZeroInt()) && - Cmp.getPredicate() == ICmpInst::ICMP_SLT; - bool RHSCanonicalized = match(RHS, m_Neg(m_Specific(LHS))); - - // Is this already canonical? - if (CmpCanonicalized && RHSCanonicalized) - return nullptr; - - // If RHS is not canonical but is used by other instructions, don't - // canonicalize it and potentially increase the instruction count. - if (!RHSCanonicalized) - if (!(RHS->hasOneUse() || (RHS->hasNUses(2) && CmpUsesNegatedOp))) - return nullptr; + // Note that NSW flag can only be propagated for normal, non-negated abs! 
+ bool IntMinIsPoison = SPF == SelectPatternFlavor::SPF_ABS && + match(RHS, m_NSWNeg(m_Specific(LHS))); + Constant *IntMinIsPoisonC = + ConstantInt::get(Type::getInt1Ty(Sel.getContext()), IntMinIsPoison); + Instruction *Abs = + IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC); - // Create the canonical compare: icmp slt LHS 0. - if (!CmpCanonicalized) { - Cmp.setPredicate(ICmpInst::ICMP_SLT); - Cmp.setOperand(1, ConstantInt::getNullValue(Cmp.getOperand(0)->getType())); - if (CmpUsesNegatedOp) - Cmp.setOperand(0, LHS); - } - - // Create the canonical RHS: RHS = sub (0, LHS). - if (!RHSCanonicalized) { - assert(RHS->hasOneUse() && "RHS use number is not right"); - RHS = IC.Builder.CreateNeg(LHS); - if (TVal == LHS) { - // Replace false value. - IC.replaceOperand(Sel, 2, RHS); - FVal = RHS; - } else { - // Replace true value. - IC.replaceOperand(Sel, 1, RHS); - TVal = RHS; - } - } + if (SPF == SelectPatternFlavor::SPF_NABS) + return BinaryOperator::CreateNeg(Abs); // Always without NSW flag! - // If the select operands do not change, we're done. - if (SPF == SelectPatternFlavor::SPF_NABS) { - if (TVal == LHS) - return &Sel; - assert(FVal == LHS && "Unexpected results from matchSelectPattern"); - } else { - if (FVal == LHS) - return &Sel; - assert(TVal == LHS && "Unexpected results from matchSelectPattern"); - } - - // We are swapping the select operands, so swap the metadata too. - Sel.swapValues(); - Sel.swapProfMetadata(); - return &Sel; -} - -static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp, - const SimplifyQuery &Q) { - // If this is a binary operator, try to simplify it with the replaced op - // because we know Op and ReplaceOp are equivalant. - // For example: V = X + 1, Op = X, ReplaceOp = 42 - // Simplifies as: add(42, 1) --> 43 - if (auto *BO = dyn_cast<BinaryOperator>(V)) { - if (BO->getOperand(0) == Op) - return SimplifyBinOp(BO->getOpcode(), ReplaceOp, BO->getOperand(1), Q); - if (BO->getOperand(1) == Op) - return SimplifyBinOp(BO->getOpcode(), BO->getOperand(0), ReplaceOp, Q); - } - - return nullptr; + return IC.replaceInstUsesWith(Sel, Abs); } /// If we have a select with an equality comparison, then we know the value in @@ -1180,30 +1093,97 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp, /// /// We can't replace %sel with %add unless we strip away the flags. /// TODO: Wrapping flags could be preserved in some cases with better analysis. -static Value *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp, - const SimplifyQuery &Q) { +Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, + ICmpInst &Cmp) { if (!Cmp.isEquality()) return nullptr; // Canonicalize the pattern to ICMP_EQ by swapping the select operands. Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue(); - if (Cmp.getPredicate() == ICmpInst::ICMP_NE) + bool Swapped = false; + if (Cmp.getPredicate() == ICmpInst::ICMP_NE) { std::swap(TrueVal, FalseVal); + Swapped = true; + } + + // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand. + // Make sure Y cannot be undef though, as we might pick different values for + // undef in the icmp and in f(Y). Additionally, take care to avoid replacing + // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite + // replacement cycle. 
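foldSelectValueEquivalence above substitutes the compared-equal value into the selected arm (now guarded so the substituted value is not undef or poison). The arithmetic core is simply that under the equality both operands are interchangeable. A standalone check, not part of this change:

    // select_equiv_demo.cpp: (x == 42) ? (x + 1) : z may evaluate the true arm at 42.
    #include <cassert>

    int beforeFold(int x, int z) { return x == 42 ? x + 1 : z; }
    int afterFold(int x, int z)  { return x == 42 ? 43 : z; }  // f(x) replaced by f(42)

    int main() {
      for (int x : {-5, 0, 41, 42, 43, 1000})
        for (int z : {-1, 0, 7})
          assert(beforeFold(x, z) == afterFold(x, z));
      return 0;
    }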
+ Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); + if (TrueVal != CmpLHS && + isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) { + if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ, + /* AllowRefinement */ true)) + return replaceOperand(Sel, Swapped ? 2 : 1, V); + + // Even if TrueVal does not simplify, we can directly replace a use of + // CmpLHS with CmpRHS, as long as the instruction is not used anywhere + // else and is safe to speculatively execute (we may end up executing it + // with different operands, which should not cause side-effects or trigger + // undefined behavior). Only do this if CmpRHS is a constant, as + // profitability is not clear for other cases. + // FIXME: The replacement could be performed recursively. + if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant())) + if (auto *I = dyn_cast<Instruction>(TrueVal)) + if (I->hasOneUse() && isSafeToSpeculativelyExecute(I)) + for (Use &U : I->operands()) + if (U == CmpLHS) { + replaceUse(U, CmpRHS); + return &Sel; + } + } + if (TrueVal != CmpRHS && + isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT)) + if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ, + /* AllowRefinement */ true)) + return replaceOperand(Sel, Swapped ? 2 : 1, V); + + auto *FalseInst = dyn_cast<Instruction>(FalseVal); + if (!FalseInst) + return nullptr; + + // InstSimplify already performed this fold if it was possible subject to + // current poison-generating flags. Try the transform again with + // poison-generating flags temporarily dropped. + bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false; + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) { + WasNUW = OBO->hasNoUnsignedWrap(); + WasNSW = OBO->hasNoSignedWrap(); + FalseInst->setHasNoUnsignedWrap(false); + FalseInst->setHasNoSignedWrap(false); + } + if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) { + WasExact = PEO->isExact(); + FalseInst->setIsExact(false); + } + if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) { + WasInBounds = GEP->isInBounds(); + GEP->setIsInBounds(false); + } // Try each equivalence substitution possibility. // We have an 'EQ' comparison, so the select's false value will propagate. // Example: // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1 - // (X == 42) ? (X + 1) : 43 --> (X == 42) ? (42 + 1) : 43 --> 43 - Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); - if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q) == TrueVal || - simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q) == TrueVal || - simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q) == FalseVal || - simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q) == FalseVal) { - if (auto *FalseInst = dyn_cast<Instruction>(FalseVal)) - FalseInst->dropPoisonGeneratingFlags(); - return FalseVal; - } + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ, + /* AllowRefinement */ false) == TrueVal || + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ, + /* AllowRefinement */ false) == TrueVal) { + return replaceInstUsesWith(Sel, FalseVal); + } + + // Restore poison-generating flags if the transform did not apply. 
+ if (WasNUW) + FalseInst->setHasNoUnsignedWrap(); + if (WasNSW) + FalseInst->setHasNoSignedWrap(); + if (WasExact) + FalseInst->setIsExact(); + if (WasInBounds) + cast<GetElementPtrInst>(FalseInst)->setIsInBounds(); + return nullptr; } @@ -1253,7 +1233,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, APInt::getAllOnesValue( C0->getType()->getScalarSizeInBits())))) return nullptr; // Can't do, have all-ones element[s]. - C0 = AddOne(C0); + C0 = InstCombiner::AddOne(C0); std::swap(X, Sel1); break; case ICmpInst::Predicate::ICMP_UGE: @@ -1313,7 +1293,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, APInt::getSignedMaxValue( C2->getType()->getScalarSizeInBits())))) return nullptr; // Can't do, have signed max element[s]. - C2 = AddOne(C2); + C2 = InstCombiner::AddOne(C2); LLVM_FALLTHROUGH; case ICmpInst::Predicate::ICMP_SGE: // Also non-canonical, but here we don't need to change C2, @@ -1360,7 +1340,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, // and swap the hands of select. static Instruction * tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, - InstCombiner &IC) { + InstCombinerImpl &IC) { ICmpInst::Predicate Pred; Value *X; Constant *C0; @@ -1375,7 +1355,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, // If comparison predicate is non-canonical, then we certainly won't be able // to make it canonical; canonicalizeCmpWithConstant() already tried. - if (!isCanonicalPredicate(Pred)) + if (!InstCombiner::isCanonicalPredicate(Pred)) return nullptr; // If the [input] type of comparison and select type are different, lets abort @@ -1403,7 +1383,8 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, return nullptr; // Check the constant we'd have with flipped-strictness predicate. - auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0); + auto FlippedStrictness = + InstCombiner::getFlippedStrictnessPredicateAndConstant(Pred, C0); if (!FlippedStrictness) return nullptr; @@ -1426,10 +1407,10 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, } /// Visit a SelectInst that has an ICmpInst as its first operand. -Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, - ICmpInst *ICI) { - if (Value *V = foldSelectValueEquivalence(SI, *ICI, SQ)) - return replaceInstUsesWith(SI, V); +Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, + ICmpInst *ICI) { + if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI)) + return NewSel; if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *this)) return NewSel; @@ -1579,11 +1560,11 @@ static bool canSelectOperandBeMappingIntoPredBlock(const Value *V, /// We have an SPF (e.g. 
a min or max) of an SPF of the form: /// SPF2(SPF1(A, B), C) -Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, - SelectPatternFlavor SPF1, - Value *A, Value *B, - Instruction &Outer, - SelectPatternFlavor SPF2, Value *C) { +Instruction *InstCombinerImpl::foldSPFofSPF(Instruction *Inner, + SelectPatternFlavor SPF1, Value *A, + Value *B, Instruction &Outer, + SelectPatternFlavor SPF2, + Value *C) { if (Outer.getType() != Inner->getType()) return nullptr; @@ -1900,7 +1881,7 @@ foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) { return CallInst::Create(F, {X, Y}); } -Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { +Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) { Constant *C; if (!match(Sel.getTrueValue(), m_Constant(C)) && !match(Sel.getFalseValue(), m_Constant(C))) @@ -1966,10 +1947,11 @@ Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { Value *CondVal = SI.getCondition(); Constant *CondC; - if (!CondVal->getType()->isVectorTy() || !match(CondVal, m_Constant(CondC))) + auto *CondValTy = dyn_cast<FixedVectorType>(CondVal->getType()); + if (!CondValTy || !match(CondVal, m_Constant(CondC))) return nullptr; - unsigned NumElts = cast<VectorType>(CondVal->getType())->getNumElements(); + unsigned NumElts = CondValTy->getNumElements(); SmallVector<int, 16> Mask; Mask.reserve(NumElts); for (unsigned i = 0; i != NumElts; ++i) { @@ -2001,8 +1983,8 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { /// to a vector select by splatting the condition. A splat may get folded with /// other operations in IR and having all operands of a select be vector types /// is likely better for vector codegen. -static Instruction *canonicalizeScalarSelectOfVecs( - SelectInst &Sel, InstCombiner &IC) { +static Instruction *canonicalizeScalarSelectOfVecs(SelectInst &Sel, + InstCombinerImpl &IC) { auto *Ty = dyn_cast<VectorType>(Sel.getType()); if (!Ty) return nullptr; @@ -2015,8 +1997,8 @@ static Instruction *canonicalizeScalarSelectOfVecs( // select (extelt V, Index), T, F --> select (splat V, Index), T, F // Splatting the extracted condition reduces code (we could directly create a // splat shuffle of the source vector to eliminate the intermediate step). - unsigned NumElts = Ty->getNumElements(); - return IC.replaceOperand(Sel, 0, IC.Builder.CreateVectorSplat(NumElts, Cond)); + return IC.replaceOperand( + Sel, 0, IC.Builder.CreateVectorSplat(Ty->getElementCount(), Cond)); } /// Reuse bitcasted operands between a compare and select: @@ -2172,7 +2154,7 @@ static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X, } /// Match a sadd_sat or ssub_sat which is using min/max to clamp the value. -Instruction *InstCombiner::matchSAddSubSat(SelectInst &MinMax1) { +Instruction *InstCombinerImpl::matchSAddSubSat(SelectInst &MinMax1) { Type *Ty = MinMax1.getType(); // We are looking for a tree of: @@ -2293,34 +2275,42 @@ static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp); } -/// Try to reduce a rotate pattern that includes a compare and select into a -/// funnel shift intrinsic. Example: +/// Try to reduce a funnel/rotate pattern that includes a compare and select +/// into a funnel shift intrinsic. Example: /// rotl32(a, b) --> (b == 0 ? 
a : ((a >> (32 - b)) | (a << b))) /// --> call llvm.fshl.i32(a, a, b) -static Instruction *foldSelectRotate(SelectInst &Sel) { - // The false value of the select must be a rotate of the true value. - Value *Or0, *Or1; - if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1))))) +/// fshl32(a, b, c) --> (c == 0 ? a : ((b >> (32 - c)) | (a << c))) +/// --> call llvm.fshl.i32(a, b, c) +/// fshr32(a, b, c) --> (c == 0 ? b : ((a >> (32 - c)) | (b << c))) +/// --> call llvm.fshr.i32(a, b, c) +static Instruction *foldSelectFunnelShift(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + // This must be a power-of-2 type for a bitmasking transform to be valid. + unsigned Width = Sel.getType()->getScalarSizeInBits(); + if (!isPowerOf2_32(Width)) return nullptr; - Value *TVal = Sel.getTrueValue(); - Value *SA0, *SA1; - if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA0)))) || - !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA1))))) + BinaryOperator *Or0, *Or1; + if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1))))) return nullptr; - auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode(); - auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode(); - if (ShiftOpcode0 == ShiftOpcode1) + Value *SV0, *SV1, *SA0, *SA1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(SV0), + m_ZExtOrSelf(m_Value(SA0))))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Value(SV1), + m_ZExtOrSelf(m_Value(SA1))))) || + Or0->getOpcode() == Or1->getOpcode()) return nullptr; - // We have one of these patterns so far: - // select ?, TVal, (or (lshr TVal, SA0), (shl TVal, SA1)) - // select ?, TVal, (or (shl TVal, SA0), (lshr TVal, SA1)) - // This must be a power-of-2 rotate for a bitmasking transform to be valid. - unsigned Width = Sel.getType()->getScalarSizeInBits(); - if (!isPowerOf2_32(Width)) - return nullptr; + // Canonicalize to or(shl(SV0, SA0), lshr(SV1, SA1)). + if (Or0->getOpcode() == BinaryOperator::LShr) { + std::swap(Or0, Or1); + std::swap(SV0, SV1); + std::swap(SA0, SA1); + } + assert(Or0->getOpcode() == BinaryOperator::Shl && + Or1->getOpcode() == BinaryOperator::LShr && + "Illegal or(shift,shift) pair"); // Check the shift amounts to see if they are an opposite pair. Value *ShAmt; @@ -2331,6 +2321,15 @@ static Instruction *foldSelectRotate(SelectInst &Sel) { else return nullptr; + // We should now have this pattern: + // select ?, TVal, (or (shl SV0, SA0), (lshr SV1, SA1)) + // The false value of the select must be a funnel-shift of the true value: + // IsFShl -> TVal must be SV0 else TVal must be SV1. + bool IsFshl = (ShAmt == SA0); + Value *TVal = Sel.getTrueValue(); + if ((IsFshl && TVal != SV0) || (!IsFshl && TVal != SV1)) + return nullptr; + // Finally, see if the select is filtering out a shift-by-zero. Value *Cond = Sel.getCondition(); ICmpInst::Predicate Pred; @@ -2338,13 +2337,21 @@ static Instruction *foldSelectRotate(SelectInst &Sel) { Pred != ICmpInst::ICMP_EQ) return nullptr; - // This is a rotate that avoids shift-by-bitwidth UB in a suboptimal way. + // If this is not a rotate then the select was blocking poison from the + // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it. + if (SV0 != SV1) { + if (IsFshl && !llvm::isGuaranteedNotToBePoison(SV1)) + SV1 = Builder.CreateFreeze(SV1); + else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(SV0)) + SV0 = Builder.CreateFreeze(SV0); + } + + // This is a funnel/rotate that avoids shift-by-bitwidth UB in a suboptimal way. // Convert to funnel shift intrinsic. 
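foldSelectFunnelShift generalizes the old rotate matcher to funnel shifts; the select in the source pattern exists only to dodge the shift-by-bitwidth case, which the intrinsic sidesteps by taking the shift amount modulo the width. A standalone sketch of the rotl form quoted in the comment above, not part of this change:

    // funnel_shift_demo.cpp: rotl32(a, b) == (b == 0 ? a : ((a >> (32 - b)) | (a << b))).
    #include <cassert>
    #include <cstdint>

    uint32_t rotl_select_form(uint32_t a, uint32_t b) {
      // The select guards b == 0, where "a >> (32 - b)" would shift by the bitwidth.
      return b == 0 ? a : ((a >> (32 - b)) | (a << b));
    }

    uint32_t rotl_reference(uint32_t a, uint32_t b) {
      b &= 31;                                     // amount taken modulo the width
      return (a << b) | (b ? (a >> (32 - b)) : 0u);
    }

    int main() {
      const uint32_t vals[] = {0u, 1u, 0x80000000u, 0xDEADBEEFu, 0xFFFFFFFFu};
      for (uint32_t a : vals)
        for (uint32_t b = 0; b < 32; ++b)
          assert(rotl_select_form(a, b) == rotl_reference(a, b));
      return 0;
    }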
- bool IsFshl = (ShAmt == SA0 && ShiftOpcode0 == BinaryOperator::Shl) || - (ShAmt == SA1 && ShiftOpcode1 == BinaryOperator::Shl); Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType()); - return IntrinsicInst::Create(F, { TVal, TVal, ShAmt }); + ShAmt = Builder.CreateZExt(ShAmt, Sel.getType()); + return IntrinsicInst::Create(F, { SV0, SV1, ShAmt }); } static Instruction *foldSelectToCopysign(SelectInst &Sel, @@ -2368,7 +2375,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, bool IsTrueIfSignSet; ICmpInst::Predicate Pred; if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) || - !isSignBitCheck(Pred, *C, IsTrueIfSignSet) || X->getType() != SelType) + !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) || + X->getType() != SelType) return nullptr; // If needed, negate the value that will be the sign argument of the copysign: @@ -2389,7 +2397,7 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, return CopySign; } -Instruction *InstCombiner::foldVectorSelect(SelectInst &Sel) { +Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { auto *VecTy = dyn_cast<FixedVectorType>(Sel.getType()); if (!VecTy) return nullptr; @@ -2469,6 +2477,10 @@ static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB, } else return nullptr; + // Make sure the branches are actually different. + if (TrueSucc == FalseSucc) + return nullptr; + // We want to replace select %cond, %a, %b with a phi that takes value %a // for all incoming edges that are dominated by condition `%cond == true`, // and value %b for edges dominated by condition `%cond == false`. If %a @@ -2515,7 +2527,33 @@ static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT, return nullptr; } -Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { +static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) { + FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition()); + if (!FI) + return nullptr; + + Value *Cond = FI->getOperand(0); + Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue(); + + // select (freeze(x == y)), x, y --> y + // select (freeze(x != y)), x, y --> x + // The freeze should be only used by this select. Otherwise, remaining uses of + // the freeze can observe a contradictory value. + // c = freeze(x == y) ; Let's assume that y = poison & x = 42; c is 0 or 1 + // a = select c, x, y ; + // f(a, c) ; f(poison, 1) cannot happen, but if a is folded + // ; to y, this can happen. + CmpInst::Predicate Pred; + if (FI->hasOneUse() && + match(Cond, m_c_ICmp(Pred, m_Specific(TrueVal), m_Specific(FalseVal))) && + (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)) { + return Pred == ICmpInst::ICMP_EQ ? 
FalseVal : TrueVal; + } + + return nullptr; +} + +Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); @@ -2551,38 +2589,45 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (SelType->isIntOrIntVectorTy(1) && TrueVal->getType() == CondVal->getType()) { - if (match(TrueVal, m_One())) { + if (match(TrueVal, m_One()) && + (EnableUnsafeSelectTransform || impliesPoison(FalseVal, CondVal))) { // Change: A = select B, true, C --> A = or B, C return BinaryOperator::CreateOr(CondVal, FalseVal); } - if (match(TrueVal, m_Zero())) { - // Change: A = select B, false, C --> A = and !B, C - Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); - return BinaryOperator::CreateAnd(NotCond, FalseVal); - } - if (match(FalseVal, m_Zero())) { + if (match(FalseVal, m_Zero()) && + (EnableUnsafeSelectTransform || impliesPoison(TrueVal, CondVal))) { // Change: A = select B, C, false --> A = and B, C return BinaryOperator::CreateAnd(CondVal, TrueVal); } + + // select a, false, b -> select !a, b, false + if (match(TrueVal, m_Zero())) { + Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); + return SelectInst::Create(NotCond, FalseVal, + ConstantInt::getFalse(SelType)); + } + // select a, b, true -> select !a, true, b if (match(FalseVal, m_One())) { - // Change: A = select B, C, true --> A = or !B, C Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); - return BinaryOperator::CreateOr(NotCond, TrueVal); + return SelectInst::Create(NotCond, ConstantInt::getTrue(SelType), + TrueVal); } - // select a, a, b -> a | b - // select a, b, a -> a & b + // select a, a, b -> select a, true, b if (CondVal == TrueVal) - return BinaryOperator::CreateOr(CondVal, FalseVal); + return replaceOperand(SI, 1, ConstantInt::getTrue(SelType)); + // select a, b, a -> select a, b, false if (CondVal == FalseVal) - return BinaryOperator::CreateAnd(CondVal, TrueVal); + return replaceOperand(SI, 2, ConstantInt::getFalse(SelType)); - // select a, ~a, b -> (~a) & b - // select a, b, ~a -> (~a) | b + // select a, !a, b -> select !a, b, false if (match(TrueVal, m_Not(m_Specific(CondVal)))) - return BinaryOperator::CreateAnd(TrueVal, FalseVal); + return SelectInst::Create(TrueVal, FalseVal, + ConstantInt::getFalse(SelType)); + // select a, b, !a -> select !a, true, b if (match(FalseVal, m_Not(m_Specific(CondVal)))) - return BinaryOperator::CreateOr(TrueVal, FalseVal); + return SelectInst::Create(FalseVal, ConstantInt::getTrue(SelType), + TrueVal); } // Selecting between two integer or vector splat integer constants? @@ -2591,7 +2636,10 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // select i1 %c, <2 x i8> <1, 1>, <2 x i8> <0, 0> // because that may need 3 instructions to splat the condition value: // extend, insertelement, shufflevector. - if (SelType->isIntOrIntVectorTy() && + // + // Do not handle i1 TrueVal and FalseVal otherwise would result in + // zext/sext i1 to i1. 
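[Editor's note on the i1-select hunks above: the and/or rewrites are now gated on impliesPoison (or EnableUnsafeSelectTransform) because a select only exposes the poison of the arm it actually picks, while and/or read both operands unconditionally. A hedged sketch with illustrative names:]

  bool selectOrForm(bool a, bool b) {
    // "a ? true : b": rewriting to (a | b) is safe only if b's poison is implied
    // by a; otherwise the 'or' could turn a well-defined true into poison.
    return a ? true : b;
  }
  bool selectAndForm(bool a, bool b) {
    // "a ? b : false": the (a & b) form has the symmetric caveat for b.
    return a ? b : false;
  }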
+ if (SelType->isIntOrIntVectorTy() && !SelType->isIntOrIntVectorTy(1) && CondVal->getType()->isVectorTy() == SelType->isVectorTy()) { // select C, 1, 0 -> zext C to int if (match(TrueVal, m_One()) && match(FalseVal, m_Zero())) @@ -2838,8 +2886,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return replaceOperand(SI, 1, TrueSI->getTrueValue()); } // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b) - // We choose this as normal form to enable folding on the And and shortening - // paths for the values (this helps GetUnderlyingObjects() for example). + // We choose this as normal form to enable folding on the And and + // shortening paths for the values (this helps getUnderlyingObjects() for + // example). if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) { Value *And = Builder.CreateAnd(CondVal, TrueSI->getCondition()); replaceOperand(SI, 0, And); @@ -2922,7 +2971,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } Value *NotCond; - if (match(CondVal, m_Not(m_Value(NotCond)))) { + if (match(CondVal, m_Not(m_Value(NotCond))) && + !InstCombiner::shouldAvoidAbsorbingNotIntoSelect(SI)) { replaceOperand(SI, 0, NotCond); SI.swapValues(); SI.swapProfMetadata(); @@ -2956,8 +3006,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI, *this)) return Select; - if (Instruction *Rot = foldSelectRotate(SI)) - return Rot; + if (Instruction *Funnel = foldSelectFunnelShift(SI, Builder)) + return Funnel; if (Instruction *Copysign = foldSelectToCopysign(SI, Builder)) return Copysign; @@ -2965,5 +3015,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *PN = foldSelectToPhi(SI, DT, Builder)) return replaceInstUsesWith(SI, PN); + if (Value *Fr = foldSelectWithFrozenICmp(SI, Builder)) + return replaceInstUsesWith(SI, Fr); + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 0a842b4e1047..7295369365c4 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -15,6 +15,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" using namespace llvm; using namespace PatternMatch; @@ -31,7 +32,7 @@ using namespace PatternMatch; // // AnalyzeForSignBitExtraction indicates that we will only analyze whether this // pattern has any 2 right-shifts that sum to 1 less than original bit width. -Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts( +Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts( BinaryOperator *Sh0, const SimplifyQuery &SQ, bool AnalyzeForSignBitExtraction) { // Look for a shift of some instruction, ignore zext of shift amount if any. @@ -327,8 +328,8 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse()) return nullptr; - const APInt *C0, *C1; - if (!match(I.getOperand(1), m_APInt(C1))) + Constant *C0, *C1; + if (!match(I.getOperand(1), m_Constant(C1))) return nullptr; Instruction::BinaryOps ShiftOpcode = I.getOpcode(); @@ -339,10 +340,12 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, // TODO: Remove the one-use check if the other logic operand (Y) is constant. 
Value *X, *Y; auto matchFirstShift = [&](Value *V) { - return !isa<ConstantExpr>(V) && - match(V, m_OneUse(m_Shift(m_Value(X), m_APInt(C0)))) && - cast<BinaryOperator>(V)->getOpcode() == ShiftOpcode && - (*C0 + *C1).ult(Ty->getScalarSizeInBits()); + BinaryOperator *BO; + APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits()); + return match(V, m_BinOp(BO)) && BO->getOpcode() == ShiftOpcode && + match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) && + match(ConstantExpr::getAdd(C0, C1), + m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); }; // Logic ops are commutative, so check each operand for a match. @@ -354,13 +357,13 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, return nullptr; // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1) - Constant *ShiftSumC = ConstantInt::get(Ty, *C0 + *C1); + Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1); Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC); Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1)); return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2); } -Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { +Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); assert(Op0->getType() == Op1->getType()); @@ -399,15 +402,15 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { return BinaryOperator::Create( I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), Op0, C), A); - // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2. + // X shift (A srem C) -> X shift (A and (C - 1)) iff C is a power of 2. // Because shifts by negative values (which could occur if A were negative) // are undefined. - const APInt *B; - if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) { + if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Constant(C))) && + match(C, m_Power2())) { // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't // demand the sign bit (and many others) here?? - Value *Rem = Builder.CreateAnd(A, ConstantInt::get(I.getType(), *B - 1), - Op1->getName()); + Constant *Mask = ConstantExpr::getSub(C, ConstantInt::get(I.getType(), 1)); + Value *Rem = Builder.CreateAnd(A, Mask, Op1->getName()); return replaceOperand(I, 1, Rem); } @@ -420,8 +423,8 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { /// Return true if we can simplify two logical (either left or right) shifts /// that have constant shift amounts: OuterShift (InnerShift X, C1), C2. static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, - Instruction *InnerShift, InstCombiner &IC, - Instruction *CxtI) { + Instruction *InnerShift, + InstCombinerImpl &IC, Instruction *CxtI) { assert(InnerShift->isLogicalShift() && "Unexpected instruction type"); // We need constant scalar or constant splat shifts. @@ -472,7 +475,7 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, /// where the client will ask if E can be computed shifted right by 64-bits. If /// this succeeds, getShiftedValue() will be called to produce the value. static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, - InstCombiner &IC, Instruction *CxtI) { + InstCombinerImpl &IC, Instruction *CxtI) { // We can always evaluate constants shifted. 
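[Editor's note: a sketch of the srem-by-power-of-two fold restated above, assuming a typical 32-bit int; the masked form avoids a possibly-negative shift amount, which would be undefined anyway.]

  unsigned shiftBySRem(unsigned x, int a) {
    // IR shape is roughly shl x, (srem a, 8); the fold may rewrite the amount
    // to (a & 7), since a negative remainder would make the shift undefined.
    return x << (a % 8);
  }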
if (isa<Constant>(V)) return true; @@ -480,31 +483,6 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, Instruction *I = dyn_cast<Instruction>(V); if (!I) return false; - // If this is the opposite shift, we can directly reuse the input of the shift - // if the needed bits are already zero in the input. This allows us to reuse - // the value which means that we don't care if the shift has multiple uses. - // TODO: Handle opposite shift by exact value. - ConstantInt *CI = nullptr; - if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || - (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { - if (CI->getValue() == NumBits) { - // TODO: Check that the input bits are already zero with MaskedValueIsZero -#if 0 - // If this is a truncate of a logical shr, we can truncate it to a smaller - // lshr iff we know that the bits we would otherwise be shifting in are - // already zeros. - uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); - uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (MaskedValueIsZero(I->getOperand(0), - APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) && - CI->getLimitedValue(BitWidth) < BitWidth) { - return CanEvaluateTruncated(I->getOperand(0), Ty); - } -#endif - - } - } - // We can't mutate something that has multiple uses: doing so would // require duplicating the instruction in general, which isn't profitable. if (!I->hasOneUse()) return false; @@ -608,7 +586,7 @@ static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt, /// When canEvaluateShifted() returns true for an expression, this function /// inserts the new computation that produces the shifted value. static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, - InstCombiner &IC, const DataLayout &DL) { + InstCombinerImpl &IC, const DataLayout &DL) { // We can always evaluate constants shifted. if (Constant *C = dyn_cast<Constant>(V)) { if (isLeftShift) @@ -618,7 +596,7 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, } Instruction *I = cast<Instruction>(V); - IC.Worklist.push(I); + IC.addToWorklist(I); switch (I->getOpcode()) { default: llvm_unreachable("Inconsistency with CanEvaluateShifted"); @@ -666,14 +644,17 @@ static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift, case Instruction::Add: return Shift.getOpcode() == Instruction::Shl; case Instruction::Or: - case Instruction::Xor: case Instruction::And: return true; + case Instruction::Xor: + // Do not change a 'not' of logical shift because that would create a normal + // 'xor'. The 'not' is likely better for analysis, SCEV, and codegen. + return !(Shift.isLogicalShift() && match(BO, m_Not(m_Value()))); } } -Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, - BinaryOperator &I) { +Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, + BinaryOperator &I) { bool isLeftShift = I.getOpcode() == Instruction::Shl; const APInt *Op1C; @@ -695,8 +676,8 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. 
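[Editor's note: the new Xor case above declines to sink a logical shift into a 'not'. An illustrative instance of the shape it protects (assumed, based on the comment):]

  unsigned keepNot(unsigned x) {
    // (~x) << 3 is kept as-is: rewriting it to (x << 3) ^ (~0u << 3) would trade
    // a canonical 'not' for a plain xor, which the comment says is worse for
    // analysis, SCEV, and codegen.
    return (~x) << 3;
  }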
- unsigned TypeBits = Op0->getType()->getScalarSizeInBits(); - + Type *Ty = I.getType(); + unsigned TypeBits = Ty->getScalarSizeInBits(); assert(!Op1C->uge(TypeBits) && "Shift over the type width should have been removed already"); @@ -704,18 +685,20 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, return FoldedShift; // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2)) - if (TruncInst *TI = dyn_cast<TruncInst>(Op0)) { - Instruction *TrOp = dyn_cast<Instruction>(TI->getOperand(0)); + if (auto *TI = dyn_cast<TruncInst>(Op0)) { // If 'shift2' is an ashr, we would have to get the sign bit into a funny // place. Don't try to do this transformation in this case. Also, we // require that the input operand is a shift-by-constant so that we have // confidence that the shifts will get folded together. We could do this // xform in more cases, but it is unlikely to be profitable. - if (TrOp && I.isLogicalShift() && TrOp->isShift() && - isa<ConstantInt>(TrOp->getOperand(1))) { + const APInt *TrShiftAmt; + if (I.isLogicalShift() && + match(TI->getOperand(0), m_Shift(m_Value(), m_APInt(TrShiftAmt)))) { + auto *TrOp = cast<Instruction>(TI->getOperand(0)); + Type *SrcTy = TrOp->getType(); + // Okay, we'll do this xform. Make the shift of shift. - Constant *ShAmt = - ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType()); + Constant *ShAmt = ConstantExpr::getZExt(Op1, SrcTy); // (shift2 (shift1 & 0x00FF), c2) Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName()); @@ -723,36 +706,27 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // part of the register be zeros. Emulate this by inserting an AND to // clear the top bits as needed. This 'and' will usually be zapped by // other xforms later if dead. - unsigned SrcSize = TrOp->getType()->getScalarSizeInBits(); - unsigned DstSize = TI->getType()->getScalarSizeInBits(); - APInt MaskV(APInt::getLowBitsSet(SrcSize, DstSize)); + unsigned SrcSize = SrcTy->getScalarSizeInBits(); + Constant *MaskV = + ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcSize, TypeBits)); // The mask we constructed says what the trunc would do if occurring // between the shifts. We want to know the effect *after* the second // shift. We know that it is a logical shift by a constant, so adjust the // mask as appropriate. - if (I.getOpcode() == Instruction::Shl) - MaskV <<= Op1C->getZExtValue(); - else { - assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift"); - MaskV.lshrInPlace(Op1C->getZExtValue()); - } - + MaskV = ConstantExpr::get(I.getOpcode(), MaskV, ShAmt); // shift1 & 0x00FF - Value *And = Builder.CreateAnd(NSh, - ConstantInt::get(I.getContext(), MaskV), - TI->getName()); - + Value *And = Builder.CreateAnd(NSh, MaskV, TI->getName()); // Return the value truncated to the interesting size. 
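[Editor's note: a small sketch of the shift-of-truncated-shift fold handled above, assuming 32-bit unsigned and 64-bit unsigned long long (illustrative sizes only).]

  unsigned shiftOfTruncatedShift(unsigned long long x) {
    unsigned t = (unsigned)(x >> 13); // trunc(lshr x, 13)
    return t << 5;                    // shl of the trunc: both shifts can be done in
                                      // 64 bits, masked, and truncated once at the end
  }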
- return new TruncInst(And, I.getType()); + return new TruncInst(And, Ty); } } if (Op0->hasOneUse()) { if (BinaryOperator *Op0BO = dyn_cast<BinaryOperator>(Op0)) { // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) - Value *V1, *V2; - ConstantInt *CC; + Value *V1; + const APInt *CC; switch (Op0BO->getOpcode()) { default: break; case Instruction::Add: @@ -770,25 +744,22 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1, Op0BO->getOperand(1)->getName()); unsigned Op1Val = Op1C->getLimitedValue(TypeBits); - APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); - Constant *Mask = ConstantInt::get(I.getContext(), Bits); - if (VectorType *VT = dyn_cast<VectorType>(X->getType())) - Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); + Constant *Mask = ConstantInt::get(Ty, Bits); return BinaryOperator::CreateAnd(X, Mask); } // Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C)) Value *Op0BOOp1 = Op0BO->getOperand(1); if (isLeftShift && Op0BOOp1->hasOneUse() && - match(Op0BOOp1, - m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), - m_ConstantInt(CC)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); + match(Op0BOOp1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), + m_APInt(CC)))) { + Value *YS = // (Y << C) + Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); // X & (CC << C) - Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1), - V1->getName()+".mask"); + Value *XM = Builder.CreateAnd( + V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1), + V1->getName() + ".mask"); return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM); } LLVM_FALLTHROUGH; @@ -805,25 +776,22 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS, Op0BO->getOperand(0)->getName()); unsigned Op1Val = Op1C->getLimitedValue(TypeBits); - APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); - Constant *Mask = ConstantInt::get(I.getContext(), Bits); - if (VectorType *VT = dyn_cast<VectorType>(X->getType())) - Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); + Constant *Mask = ConstantInt::get(Ty, Bits); return BinaryOperator::CreateAnd(X, Mask); } // Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C) if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && match(Op0BO->getOperand(0), - m_And(m_OneUse(m_Shr(m_Value(V1), m_Value(V2))), - m_ConstantInt(CC))) && V2 == Op1) { + m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), + m_APInt(CC)))) { Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); + Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); // X & (CC << C) - Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1), - V1->getName()+".mask"); - + Value *XM = Builder.CreateAnd( + V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1), + V1->getName() + ".mask"); return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS); } @@ -831,7 +799,6 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } } - // If the operand is a bitwise operator with a constant RHS, and the // shift is the only use, we can pull it out of the shift. 
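[Editor's note: a worked instance of the first pattern in the hunk above, ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C), with C = 4.]

  unsigned shlOfAddOfShr(unsigned x, unsigned y) {
    // may become (x + (y << 4)) & ~15u, removing the shift that depends on x
    return ((x >> 4) + y) << 4;
  }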
const APInt *Op0C; @@ -915,7 +882,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, return nullptr; } -Instruction *InstCombiner::visitShl(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { const SimplifyQuery Q = SQ.getWithInstruction(&I); if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), @@ -955,10 +922,9 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } - // FIXME: we do not yet transform non-exact shr's. The backend (DAGCombine) - // needs a few fixes for the rotate pattern recognition first. const APInt *ShOp1; - if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) { + if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1)))) && + ShOp1->ult(BitWidth)) { unsigned ShrAmt = ShOp1->getZExtValue(); if (ShrAmt < ShAmt) { // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1) @@ -978,7 +944,33 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { } } - if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) { + if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(ShOp1)))) && + ShOp1->ult(BitWidth)) { + unsigned ShrAmt = ShOp1->getZExtValue(); + if (ShrAmt < ShAmt) { + // If C1 < C2: (X >>? C1) << C2 --> X << (C2 - C1) & (-1 << C2) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); + auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); + NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); + Builder.Insert(NewShl); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); + } + if (ShrAmt > ShAmt) { + // If C1 > C2: (X >>? C1) << C2 --> X >>? (C1 - C2) & (-1 << C2) + Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); + auto *OldShr = cast<BinaryOperator>(Op0); + auto *NewShr = + BinaryOperator::Create(OldShr->getOpcode(), X, ShiftDiff); + NewShr->setIsExact(OldShr->isExact()); + Builder.Insert(NewShr); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + return BinaryOperator::CreateAnd(NewShr, ConstantInt::get(Ty, Mask)); + } + } + + if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) { unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); // Oversized shifts are simplified to zero in InstSimplify. if (AmtSum < BitWidth) @@ -1037,7 +1029,7 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { return nullptr; } -Instruction *InstCombiner::visitLShr(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1139,6 +1131,12 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { } } + // lshr i32 (X -nsw Y), 31 --> zext (X < Y) + Value *Y; + if (ShAmt == BitWidth - 1 && + match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y))))) + return new ZExtInst(Builder.CreateICmpSLT(X, Y), Ty); + if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) { unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); // Oversized shifts are simplified to zero in InstSimplify. 
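[Editor's note: the new lshr fold above, lshr i32 (X -nsw Y), 31 --> zext (X < Y), in source form; the ashr sibling later in this diff produces a sext instead. Sketch assumes 32-bit int.]

  unsigned signBitOfDifference(int x, int y) {
    // signed int subtraction is emitted with nsw (signed overflow is UB), so the
    // shifted-out sign bit is just the result of x < y
    return (unsigned)(x - y) >> 31;
  }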
@@ -1167,7 +1165,7 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { } Instruction * -InstCombiner::foldVariableSignZeroExtensionOfVariableHighBitExtract( +InstCombinerImpl::foldVariableSignZeroExtensionOfVariableHighBitExtract( BinaryOperator &OldAShr) { assert(OldAShr.getOpcode() == Instruction::AShr && "Must be called with arithmetic right-shift instruction only."); @@ -1235,7 +1233,7 @@ InstCombiner::foldVariableSignZeroExtensionOfVariableHighBitExtract( return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType()); } -Instruction *InstCombiner::visitAShr(BinaryOperator &I) { +Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1301,6 +1299,12 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { return new SExtInst(NewSh, Ty); } + // ashr i32 (X -nsw Y), 31 --> sext (X < Y) + Value *Y; + if (ShAmt == BitWidth - 1 && + match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y))))) + return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty); + // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 7cfe4c8b5892..c265516213aa 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -12,29 +12,18 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/IntrinsicsAMDGPU.h" -#include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombiner.h" using namespace llvm; using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" -namespace { - -struct AMDGPUImageDMaskIntrinsic { - unsigned Intr; -}; - -#define GET_AMDGPUImageDMaskIntrinsicTable_IMPL -#include "InstCombineTables.inc" - -} // end anonymous namespace - /// Check to see if the specified operand of the specified instruction is a /// constant integer. If so, check to see if there are any bits set in the /// constant that are not demanded. If so, shrink the constant and return true. @@ -63,7 +52,7 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if /// the instruction has any properties that allow us to simplify its operands. -bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { +bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { unsigned BitWidth = Inst.getType()->getScalarSizeInBits(); KnownBits Known(BitWidth); APInt DemandedMask(APInt::getAllOnesValue(BitWidth)); @@ -79,22 +68,20 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { /// This form of SimplifyDemandedBits simplifies the specified instruction /// operand if possible, updating it in place. It returns true if it made any /// change and false otherwise. 
-bool InstCombiner::SimplifyDemandedBits(Instruction *I, unsigned OpNo, - const APInt &DemandedMask, - KnownBits &Known, - unsigned Depth) { +bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo, + const APInt &DemandedMask, + KnownBits &Known, unsigned Depth) { Use &U = I->getOperandUse(OpNo); Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, Known, Depth, I); if (!NewVal) return false; if (Instruction* OpInst = dyn_cast<Instruction>(U)) salvageDebugInfo(*OpInst); - + replaceUse(U, NewVal); return true; } - /// This function attempts to replace V with a simpler value based on the /// demanded bits. When this function is called, it is known that only the bits /// set in DemandedMask of the result of V are ever used downstream. @@ -118,11 +105,12 @@ bool InstCombiner::SimplifyDemandedBits(Instruction *I, unsigned OpNo, /// operands based on the information about what bits are demanded. This returns /// some other non-null value if it found out that V is equal to another value /// in the context where the specified bits are demanded, but not for all users. -Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, - KnownBits &Known, unsigned Depth, - Instruction *CxtI) { +Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, + KnownBits &Known, + unsigned Depth, + Instruction *CxtI) { assert(V != nullptr && "Null pointer of Value???"); - assert(Depth <= 6 && "Limit Search Depth"); + assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); uint32_t BitWidth = DemandedMask.getBitWidth(); Type *VTy = V->getType(); assert( @@ -139,7 +127,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (DemandedMask.isNullValue()) // Not demanding any bits from V. return UndefValue::get(VTy); - if (Depth == 6) // Limit search depth. + if (Depth == MaxAnalysisRecursionDepth) + return nullptr; + + if (isa<ScalableVectorType>(VTy)) return nullptr; Instruction *I = dyn_cast<Instruction>(V); @@ -268,35 +259,44 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return InsertNewInstWith(And, *I); } - // If the RHS is a constant, see if we can simplify it. - // FIXME: for XOR, we prefer to force bits to 1 if they will make a -1. - if (ShrinkDemandedConstant(I, 1, DemandedMask)) - return I; + // If the RHS is a constant, see if we can change it. Don't alter a -1 + // constant because that's a canonical 'not' op, and that is better for + // combining, SCEV, and codegen. + const APInt *C; + if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnesValue()) { + if ((*C | ~DemandedMask).isAllOnesValue()) { + // Force bits to 1 to create a 'not' op. + I->setOperand(1, ConstantInt::getAllOnesValue(VTy)); + return I; + } + // If we can't turn this into a 'not', try to shrink the constant. + if (ShrinkDemandedConstant(I, 1, DemandedMask)) + return I; + } // If our LHS is an 'and' and if it has one use, and if any of the bits we // are flipping are known to be set, then the xor is just resetting those // bits to zero. We can just knock out bits from the 'and' and the 'xor', // simplifying both of them. 
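[Editor's note: a sketch of the demanded-bits xor change above. When every bit the xor may affect is already covered by its constant, the constant can be forced to all-ones so the xor becomes a canonical 'not'; the values below are illustrative.]

  unsigned xorToNot(unsigned x) {
    // only bits 4..7 are demanded by the mask, and the xor constant already covers
    // them, so the constant may be widened to ~0u: (~x) & 0xF0u yields the same bits
    return (x ^ 0xF0u) & 0xF0u;
  }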
- if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0))) + if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0))) { + ConstantInt *AndRHS, *XorRHS; if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() && - isa<ConstantInt>(I->getOperand(1)) && - isa<ConstantInt>(LHSInst->getOperand(1)) && + match(I->getOperand(1), m_ConstantInt(XorRHS)) && + match(LHSInst->getOperand(1), m_ConstantInt(AndRHS)) && (LHSKnown.One & RHSKnown.One & DemandedMask) != 0) { - ConstantInt *AndRHS = cast<ConstantInt>(LHSInst->getOperand(1)); - ConstantInt *XorRHS = cast<ConstantInt>(I->getOperand(1)); APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask); Constant *AndC = - ConstantInt::get(I->getType(), NewMask & AndRHS->getValue()); + ConstantInt::get(I->getType(), NewMask & AndRHS->getValue()); Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC); InsertNewInstWith(NewAnd, *I); Constant *XorC = - ConstantInt::get(I->getType(), NewMask & XorRHS->getValue()); + ConstantInt::get(I->getType(), NewMask & XorRHS->getValue()); Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC); return InsertNewInstWith(NewXor, *I); } - + } break; } case Instruction::Select: { @@ -339,7 +339,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // we can. This helps not break apart (or helps put back together) // canonical patterns like min and max. auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo, - APInt DemandedMask) { + const APInt &DemandedMask) { const APInt *SelC; if (!match(I->getOperand(OpNo), m_APInt(SelC))) return false; @@ -367,8 +367,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I; // Only known if known in both the LHS and RHS. - Known.One = RHSKnown.One & LHSKnown.One; - Known.Zero = RHSKnown.Zero & LHSKnown.Zero; + Known = KnownBits::commonBits(LHSKnown, RHSKnown); break; } case Instruction::ZExt: @@ -391,7 +390,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (VectorType *DstVTy = dyn_cast<VectorType>(I->getType())) { if (VectorType *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) { - if (DstVTy->getNumElements() != SrcVTy->getNumElements()) + if (cast<FixedVectorType>(DstVTy)->getNumElements() != + cast<FixedVectorType>(SrcVTy)->getNumElements()) // Don't touch a bitcast between vectors of different element counts. return nullptr; } else @@ -669,8 +669,9 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } break; } - case Instruction::SRem: - if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) { + case Instruction::SRem: { + ConstantInt *Rem; + if (match(I->getOperand(1), m_ConstantInt(Rem))) { // X % -1 demands all the bits because we don't want to introduce // INT_MIN % -1 (== undef) by accident. if (Rem->isMinusOne()) @@ -713,6 +714,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Known.makeNonNegative(); } break; + } case Instruction::URem: { KnownBits Known2(BitWidth); APInt AllOnes = APInt::getAllOnesValue(BitWidth); @@ -728,7 +730,6 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, bool KnownBitsComputed = false; if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { - default: break; case Intrinsic::bswap: { // If the only bits demanded come from one byte of the bswap result, // just shift the input byte into position to eliminate the bswap. 
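[Editor's note: the bswap case that begins above fires when a single byte of the result is demanded; a minimal example using the usual GCC/Clang builtin.]

  unsigned lowByteOfBswap(unsigned x) {
    // only the low byte of the bswap result is demanded, which is the original
    // high byte, so the bswap may be replaced by a plain shift: (x >> 24)
    return __builtin_bswap32(x) & 0xFFu;
  }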
@@ -784,39 +785,14 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBitsComputed = true; break; } - case Intrinsic::x86_mmx_pmovmskb: - case Intrinsic::x86_sse_movmsk_ps: - case Intrinsic::x86_sse2_movmsk_pd: - case Intrinsic::x86_sse2_pmovmskb_128: - case Intrinsic::x86_avx_movmsk_ps_256: - case Intrinsic::x86_avx_movmsk_pd_256: - case Intrinsic::x86_avx2_pmovmskb: { - // MOVMSK copies the vector elements' sign bits to the low bits - // and zeros the high bits. - unsigned ArgWidth; - if (II->getIntrinsicID() == Intrinsic::x86_mmx_pmovmskb) { - ArgWidth = 8; // Arg is x86_mmx, but treated as <8 x i8>. - } else { - auto Arg = II->getArgOperand(0); - auto ArgType = cast<VectorType>(Arg->getType()); - ArgWidth = ArgType->getNumElements(); - } - - // If we don't need any of low bits then return zero, - // we know that DemandedMask is non-zero already. - APInt DemandedElts = DemandedMask.zextOrTrunc(ArgWidth); - if (DemandedElts.isNullValue()) - return ConstantInt::getNullValue(VTy); - - // We know that the upper bits are set to zero. - Known.Zero.setBitsFrom(ArgWidth); - KnownBitsComputed = true; + default: { + // Handle target specific intrinsics + Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( + *II, DemandedMask, Known, KnownBitsComputed); + if (V.hasValue()) + return V.getValue(); break; } - case Intrinsic::x86_sse42_crc32_64_64: - Known.Zero.setBitsFrom(32); - KnownBitsComputed = true; - break; } } @@ -836,11 +812,9 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, /// Helper routine of SimplifyDemandedUseBits. It computes Known /// bits. It also tries to handle simplifications that can be done based on /// DemandedMask, but without modifying the Instruction. -Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, - const APInt &DemandedMask, - KnownBits &Known, - unsigned Depth, - Instruction *CxtI) { +Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( + Instruction *I, const APInt &DemandedMask, KnownBits &Known, unsigned Depth, + Instruction *CxtI) { unsigned BitWidth = DemandedMask.getBitWidth(); Type *ITy = I->getType(); @@ -925,6 +899,33 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, break; } + case Instruction::AShr: { + // Compute the Known bits to simplify things downstream. + computeKnownBits(I, Known, Depth, CxtI); + + // If this user is only demanding bits that we know, return the known + // constant. + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(ITy, Known.One); + + // If the right shift operand 0 is a result of a left shift by the same + // amount, this is probably a zero/sign extension, which may be unnecessary, + // if we do not demand any of the new sign bits. So, return the original + // operand instead. + const APInt *ShiftRC; + const APInt *ShiftLC; + Value *X; + unsigned BitWidth = DemandedMask.getBitWidth(); + if (match(I, + m_AShr(m_Shl(m_Value(X), m_APInt(ShiftLC)), m_APInt(ShiftRC))) && + ShiftLC == ShiftRC && + DemandedMask.isSubsetOf(APInt::getLowBitsSet( + BitWidth, BitWidth - ShiftRC->getZExtValue()))) { + return X; + } + + break; + } default: // Compute the Known bits to simplify things downstream. computeKnownBits(I, Known, Depth, CxtI); @@ -940,7 +941,6 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, return nullptr; } - /// Helper routine of SimplifyDemandedUseBits. 
It tries to simplify /// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into /// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign @@ -958,11 +958,9 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, /// /// As with SimplifyDemandedUseBits, it returns NULL if the simplification was /// not successful. -Value * -InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1, - Instruction *Shl, const APInt &ShlOp1, - const APInt &DemandedMask, - KnownBits &Known) { +Value *InstCombinerImpl::simplifyShrShlDemandedBits( + Instruction *Shr, const APInt &ShrOp1, Instruction *Shl, + const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known) { if (!ShlOp1 || !ShrOp1) return nullptr; // No-op. @@ -1022,156 +1020,9 @@ InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1, return nullptr; } -/// Implement SimplifyDemandedVectorElts for amdgcn buffer and image intrinsics. -/// -/// Note: This only supports non-TFE/LWE image intrinsic calls; those have -/// struct returns. -Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, - APInt DemandedElts, - int DMaskIdx) { - - // FIXME: Allow v3i16/v3f16 in buffer intrinsics when the types are fully supported. - if (DMaskIdx < 0 && - II->getType()->getScalarSizeInBits() != 32 && - DemandedElts.getActiveBits() == 3) - return nullptr; - - auto *IIVTy = cast<VectorType>(II->getType()); - unsigned VWidth = IIVTy->getNumElements(); - if (VWidth == 1) - return nullptr; - - IRBuilderBase::InsertPointGuard Guard(Builder); - Builder.SetInsertPoint(II); - - // Assume the arguments are unchanged and later override them, if needed. - SmallVector<Value *, 16> Args(II->arg_begin(), II->arg_end()); - - if (DMaskIdx < 0) { - // Buffer case. - - const unsigned ActiveBits = DemandedElts.getActiveBits(); - const unsigned UnusedComponentsAtFront = DemandedElts.countTrailingZeros(); - - // Start assuming the prefix of elements is demanded, but possibly clear - // some other bits if there are trailing zeros (unused components at front) - // and update offset. - DemandedElts = (1 << ActiveBits) - 1; - - if (UnusedComponentsAtFront > 0) { - static const unsigned InvalidOffsetIdx = 0xf; - - unsigned OffsetIdx; - switch (II->getIntrinsicID()) { - case Intrinsic::amdgcn_raw_buffer_load: - OffsetIdx = 1; - break; - case Intrinsic::amdgcn_s_buffer_load: - // If resulting type is vec3, there is no point in trimming the - // load with updated offset, as the vec3 would most likely be widened to - // vec4 anyway during lowering. - if (ActiveBits == 4 && UnusedComponentsAtFront == 1) - OffsetIdx = InvalidOffsetIdx; - else - OffsetIdx = 1; - break; - case Intrinsic::amdgcn_struct_buffer_load: - OffsetIdx = 2; - break; - default: - // TODO: handle tbuffer* intrinsics. - OffsetIdx = InvalidOffsetIdx; - break; - } - - if (OffsetIdx != InvalidOffsetIdx) { - // Clear demanded bits and update the offset. - DemandedElts &= ~((1 << UnusedComponentsAtFront) - 1); - auto *Offset = II->getArgOperand(OffsetIdx); - unsigned SingleComponentSizeInBits = - getDataLayout().getTypeSizeInBits(II->getType()->getScalarType()); - unsigned OffsetAdd = - UnusedComponentsAtFront * SingleComponentSizeInBits / 8; - auto *OffsetAddVal = ConstantInt::get(Offset->getType(), OffsetAdd); - Args[OffsetIdx] = Builder.CreateAdd(Offset, OffsetAddVal); - } - } - } else { - // Image case. 
- - ConstantInt *DMask = cast<ConstantInt>(II->getArgOperand(DMaskIdx)); - unsigned DMaskVal = DMask->getZExtValue() & 0xf; - - // Mask off values that are undefined because the dmask doesn't cover them - DemandedElts &= (1 << countPopulation(DMaskVal)) - 1; - - unsigned NewDMaskVal = 0; - unsigned OrigLoadIdx = 0; - for (unsigned SrcIdx = 0; SrcIdx < 4; ++SrcIdx) { - const unsigned Bit = 1 << SrcIdx; - if (!!(DMaskVal & Bit)) { - if (!!DemandedElts[OrigLoadIdx]) - NewDMaskVal |= Bit; - OrigLoadIdx++; - } - } - - if (DMaskVal != NewDMaskVal) - Args[DMaskIdx] = ConstantInt::get(DMask->getType(), NewDMaskVal); - } - - unsigned NewNumElts = DemandedElts.countPopulation(); - if (!NewNumElts) - return UndefValue::get(II->getType()); - - if (NewNumElts >= VWidth && DemandedElts.isMask()) { - if (DMaskIdx >= 0) - II->setArgOperand(DMaskIdx, Args[DMaskIdx]); - return nullptr; - } - - // Validate function argument and return types, extracting overloaded types - // along the way. - SmallVector<Type *, 6> OverloadTys; - if (!Intrinsic::getIntrinsicSignature(II->getCalledFunction(), OverloadTys)) - return nullptr; - - Module *M = II->getParent()->getParent()->getParent(); - Type *EltTy = IIVTy->getElementType(); - Type *NewTy = - (NewNumElts == 1) ? EltTy : FixedVectorType::get(EltTy, NewNumElts); - - OverloadTys[0] = NewTy; - Function *NewIntrin = - Intrinsic::getDeclaration(M, II->getIntrinsicID(), OverloadTys); - - CallInst *NewCall = Builder.CreateCall(NewIntrin, Args); - NewCall->takeName(II); - NewCall->copyMetadata(*II); - - if (NewNumElts == 1) { - return Builder.CreateInsertElement(UndefValue::get(II->getType()), NewCall, - DemandedElts.countTrailingZeros()); - } - - SmallVector<int, 8> EltMask; - unsigned NewLoadIdx = 0; - for (unsigned OrigLoadIdx = 0; OrigLoadIdx < VWidth; ++OrigLoadIdx) { - if (!!DemandedElts[OrigLoadIdx]) - EltMask.push_back(NewLoadIdx++); - else - EltMask.push_back(NewNumElts); - } - - Value *Shuffle = - Builder.CreateShuffleVector(NewCall, UndefValue::get(NewTy), EltMask); - - return Shuffle; -} - /// The specified value produces a vector with any number of elements. -/// This method analyzes which elements of the operand are undef and returns -/// that information in UndefElts. +/// This method analyzes which elements of the operand are undef or poison and +/// returns that information in UndefElts. /// /// DemandedElts contains the set of elements that are actually used by the /// caller, and by default (AllowMultipleUsers equals false) the value is @@ -1182,10 +1033,11 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, /// If the information about demanded elements can be used to simplify the /// operation, the operation is simplified, then the resultant value is /// returned. This returns null if no change was made. -Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, - APInt &UndefElts, - unsigned Depth, - bool AllowMultipleUsers) { +Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, + APInt DemandedElts, + APInt &UndefElts, + unsigned Depth, + bool AllowMultipleUsers) { // Cannot analyze scalable type. The number of vector elements is not a // compile-time constant. if (isa<ScalableVectorType>(V->getType())) @@ -1196,14 +1048,14 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); if (isa<UndefValue>(V)) { - // If the entire vector is undefined, just return this info. 
+ // If the entire vector is undef or poison, just return this info. UndefElts = EltMask; return nullptr; } - if (DemandedElts.isNullValue()) { // If nothing is demanded, provide undef. + if (DemandedElts.isNullValue()) { // If nothing is demanded, provide poison. UndefElts = EltMask; - return UndefValue::get(V->getType()); + return PoisonValue::get(V->getType()); } UndefElts = 0; @@ -1215,11 +1067,11 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, return nullptr; Type *EltTy = cast<VectorType>(V->getType())->getElementType(); - Constant *Undef = UndefValue::get(EltTy); + Constant *Poison = PoisonValue::get(EltTy); SmallVector<Constant*, 16> Elts; for (unsigned i = 0; i != VWidth; ++i) { - if (!DemandedElts[i]) { // If not demanded, set to undef. - Elts.push_back(Undef); + if (!DemandedElts[i]) { // If not demanded, set to poison. + Elts.push_back(Poison); UndefElts.setBit(i); continue; } @@ -1227,12 +1079,9 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, Constant *Elt = C->getAggregateElement(i); if (!Elt) return nullptr; - if (isa<UndefValue>(Elt)) { // Already undef. - Elts.push_back(Undef); + Elts.push_back(Elt); + if (isa<UndefValue>(Elt)) // Already undef or poison. UndefElts.setBit(i); - } else { // Otherwise, defined. - Elts.push_back(Elt); - } } // If we changed the constant, return it. @@ -1292,12 +1141,12 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, }; if (mayIndexStructType(cast<GetElementPtrInst>(*I))) break; - + // Conservatively track the demanded elements back through any vector // operands we may have. We know there must be at least one, or we // wouldn't have a vector result to get here. Note that we intentionally // merge the undef bits here since gepping with either an undef base or - // index results in undef. + // index results in undef. for (unsigned i = 0; i < I->getNumOperands(); i++) { if (isa<UndefValue>(I->getOperand(i))) { // If the entire vector is undefined, just return this info. @@ -1331,6 +1180,19 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, if (IdxNo < VWidth) PreInsertDemandedElts.clearBit(IdxNo); + // If we only demand the element that is being inserted and that element + // was extracted from the same index in another vector with the same type, + // replace this insert with that other vector. + // Note: This is attempted before the call to simplifyAndSetOp because that + // may change UndefElts to a value that does not match with Vec. + Value *Vec; + if (PreInsertDemandedElts == 0 && + match(I->getOperand(1), + m_ExtractElt(m_Value(Vec), m_SpecificInt(IdxNo))) && + Vec->getType() == I->getType()) { + return Vec; + } + simplifyAndSetOp(I, 0, PreInsertDemandedElts, UndefElts); // If this is inserting an element that isn't demanded, remove this @@ -1349,8 +1211,8 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, assert(Shuffle->getOperand(0)->getType() == Shuffle->getOperand(1)->getType() && "Expected shuffle operands to have same type"); - unsigned OpWidth = - cast<VectorType>(Shuffle->getOperand(0)->getType())->getNumElements(); + unsigned OpWidth = cast<FixedVectorType>(Shuffle->getOperand(0)->getType()) + ->getNumElements(); // Handle trivial case of a splat. Only check the first element of LHS // operand. 
if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) && @@ -1451,7 +1313,8 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // this constant vector to single insertelement instruction. // shufflevector V, C, <v1, v2, .., ci, .., vm> -> // insertelement V, C[ci], ci-n - if (OpWidth == Shuffle->getType()->getNumElements()) { + if (OpWidth == + cast<FixedVectorType>(Shuffle->getType())->getNumElements()) { Value *Op = nullptr; Constant *Value = nullptr; unsigned Idx = -1u; @@ -1538,7 +1401,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // Vector->vector casts only. VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType()); if (!VTy) break; - unsigned InVWidth = VTy->getNumElements(); + unsigned InVWidth = cast<FixedVectorType>(VTy)->getNumElements(); APInt InputDemandedElts(InVWidth, 0); UndefElts2 = APInt(InVWidth, 0); unsigned Ratio; @@ -1621,227 +1484,19 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, if (II->getIntrinsicID() == Intrinsic::masked_gather) simplifyAndSetOp(II, 0, DemandedPtrs, UndefElts2); simplifyAndSetOp(II, 3, DemandedPassThrough, UndefElts3); - + // Output elements are undefined if the element from both sources are. // TODO: can strengthen via mask as well. UndefElts = UndefElts2 & UndefElts3; break; } - case Intrinsic::x86_xop_vfrcz_ss: - case Intrinsic::x86_xop_vfrcz_sd: - // The instructions for these intrinsics are speced to zero upper bits not - // pass them through like other scalar intrinsics. So we shouldn't just - // use Arg0 if DemandedElts[0] is clear like we do for other intrinsics. - // Instead we should return a zero vector. - if (!DemandedElts[0]) { - Worklist.push(II); - return ConstantAggregateZero::get(II->getType()); - } - - // Only the lower element is used. - DemandedElts = 1; - simplifyAndSetOp(II, 0, DemandedElts, UndefElts); - - // Only the lower element is undefined. The high elements are zero. - UndefElts = UndefElts[0]; - break; - - // Unary scalar-as-vector operations that work column-wise. - case Intrinsic::x86_sse_rcp_ss: - case Intrinsic::x86_sse_rsqrt_ss: - simplifyAndSetOp(II, 0, DemandedElts, UndefElts); - - // If lowest element of a scalar op isn't used then use Arg0. - if (!DemandedElts[0]) { - Worklist.push(II); - return II->getArgOperand(0); - } - // TODO: If only low elt lower SQRT to FSQRT (with rounding/exceptions - // checks). - break; - - // Binary scalar-as-vector operations that work column-wise. The high - // elements come from operand 0. The low element is a function of both - // operands. - case Intrinsic::x86_sse_min_ss: - case Intrinsic::x86_sse_max_ss: - case Intrinsic::x86_sse_cmp_ss: - case Intrinsic::x86_sse2_min_sd: - case Intrinsic::x86_sse2_max_sd: - case Intrinsic::x86_sse2_cmp_sd: { - simplifyAndSetOp(II, 0, DemandedElts, UndefElts); - - // If lowest element of a scalar op isn't used then use Arg0. - if (!DemandedElts[0]) { - Worklist.push(II); - return II->getArgOperand(0); - } - - // Only lower element is used for operand 1. - DemandedElts = 1; - simplifyAndSetOp(II, 1, DemandedElts, UndefElts2); - - // Lower element is undefined if both lower elements are undefined. - // Consider things like undef&0. The result is known zero, not undef. - if (!UndefElts2[0]) - UndefElts.clearBit(0); - - break; - } - - // Binary scalar-as-vector operations that work column-wise. The high - // elements come from operand 0 and the low element comes from operand 1. 
- case Intrinsic::x86_sse41_round_ss: - case Intrinsic::x86_sse41_round_sd: { - // Don't use the low element of operand 0. - APInt DemandedElts2 = DemandedElts; - DemandedElts2.clearBit(0); - simplifyAndSetOp(II, 0, DemandedElts2, UndefElts); - - // If lowest element of a scalar op isn't used then use Arg0. - if (!DemandedElts[0]) { - Worklist.push(II); - return II->getArgOperand(0); - } - - // Only lower element is used for operand 1. - DemandedElts = 1; - simplifyAndSetOp(II, 1, DemandedElts, UndefElts2); - - // Take the high undef elements from operand 0 and take the lower element - // from operand 1. - UndefElts.clearBit(0); - UndefElts |= UndefElts2[0]; - break; - } - - // Three input scalar-as-vector operations that work column-wise. The high - // elements come from operand 0 and the low element is a function of all - // three inputs. - case Intrinsic::x86_avx512_mask_add_ss_round: - case Intrinsic::x86_avx512_mask_div_ss_round: - case Intrinsic::x86_avx512_mask_mul_ss_round: - case Intrinsic::x86_avx512_mask_sub_ss_round: - case Intrinsic::x86_avx512_mask_max_ss_round: - case Intrinsic::x86_avx512_mask_min_ss_round: - case Intrinsic::x86_avx512_mask_add_sd_round: - case Intrinsic::x86_avx512_mask_div_sd_round: - case Intrinsic::x86_avx512_mask_mul_sd_round: - case Intrinsic::x86_avx512_mask_sub_sd_round: - case Intrinsic::x86_avx512_mask_max_sd_round: - case Intrinsic::x86_avx512_mask_min_sd_round: - simplifyAndSetOp(II, 0, DemandedElts, UndefElts); - - // If lowest element of a scalar op isn't used then use Arg0. - if (!DemandedElts[0]) { - Worklist.push(II); - return II->getArgOperand(0); - } - - // Only lower element is used for operand 1 and 2. - DemandedElts = 1; - simplifyAndSetOp(II, 1, DemandedElts, UndefElts2); - simplifyAndSetOp(II, 2, DemandedElts, UndefElts3); - - // Lower element is undefined if all three lower elements are undefined. - // Consider things like undef&0. The result is known zero, not undef. - if (!UndefElts2[0] || !UndefElts3[0]) - UndefElts.clearBit(0); - - break; - - case Intrinsic::x86_sse2_packssdw_128: - case Intrinsic::x86_sse2_packsswb_128: - case Intrinsic::x86_sse2_packuswb_128: - case Intrinsic::x86_sse41_packusdw: - case Intrinsic::x86_avx2_packssdw: - case Intrinsic::x86_avx2_packsswb: - case Intrinsic::x86_avx2_packusdw: - case Intrinsic::x86_avx2_packuswb: - case Intrinsic::x86_avx512_packssdw_512: - case Intrinsic::x86_avx512_packsswb_512: - case Intrinsic::x86_avx512_packusdw_512: - case Intrinsic::x86_avx512_packuswb_512: { - auto *Ty0 = II->getArgOperand(0)->getType(); - unsigned InnerVWidth = cast<VectorType>(Ty0)->getNumElements(); - assert(VWidth == (InnerVWidth * 2) && "Unexpected input size"); - - unsigned NumLanes = Ty0->getPrimitiveSizeInBits() / 128; - unsigned VWidthPerLane = VWidth / NumLanes; - unsigned InnerVWidthPerLane = InnerVWidth / NumLanes; - - // Per lane, pack the elements of the first input and then the second. - // e.g. - // v8i16 PACK(v4i32 X, v4i32 Y) - (X[0..3],Y[0..3]) - // v32i8 PACK(v16i16 X, v16i16 Y) - (X[0..7],Y[0..7]),(X[8..15],Y[8..15]) - for (int OpNum = 0; OpNum != 2; ++OpNum) { - APInt OpDemandedElts(InnerVWidth, 0); - for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { - unsigned LaneIdx = Lane * VWidthPerLane; - for (unsigned Elt = 0; Elt != InnerVWidthPerLane; ++Elt) { - unsigned Idx = LaneIdx + Elt + InnerVWidthPerLane * OpNum; - if (DemandedElts[Idx]) - OpDemandedElts.setBit((Lane * InnerVWidthPerLane) + Elt); - } - } - - // Demand elements from the operand. 
- APInt OpUndefElts(InnerVWidth, 0); - simplifyAndSetOp(II, OpNum, OpDemandedElts, OpUndefElts); - - // Pack the operand's UNDEF elements, one lane at a time. - OpUndefElts = OpUndefElts.zext(VWidth); - for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { - APInt LaneElts = OpUndefElts.lshr(InnerVWidthPerLane * Lane); - LaneElts = LaneElts.getLoBits(InnerVWidthPerLane); - LaneElts <<= InnerVWidthPerLane * (2 * Lane + OpNum); - UndefElts |= LaneElts; - } - } - break; - } - - // PSHUFB - case Intrinsic::x86_ssse3_pshuf_b_128: - case Intrinsic::x86_avx2_pshuf_b: - case Intrinsic::x86_avx512_pshuf_b_512: - // PERMILVAR - case Intrinsic::x86_avx_vpermilvar_ps: - case Intrinsic::x86_avx_vpermilvar_ps_256: - case Intrinsic::x86_avx512_vpermilvar_ps_512: - case Intrinsic::x86_avx_vpermilvar_pd: - case Intrinsic::x86_avx_vpermilvar_pd_256: - case Intrinsic::x86_avx512_vpermilvar_pd_512: - // PERMV - case Intrinsic::x86_avx2_permd: - case Intrinsic::x86_avx2_permps: { - simplifyAndSetOp(II, 1, DemandedElts, UndefElts); - break; - } - - // SSE4A instructions leave the upper 64-bits of the 128-bit result - // in an undefined state. - case Intrinsic::x86_sse4a_extrq: - case Intrinsic::x86_sse4a_extrqi: - case Intrinsic::x86_sse4a_insertq: - case Intrinsic::x86_sse4a_insertqi: - UndefElts.setHighBits(VWidth / 2); - break; - case Intrinsic::amdgcn_buffer_load: - case Intrinsic::amdgcn_buffer_load_format: - case Intrinsic::amdgcn_raw_buffer_load: - case Intrinsic::amdgcn_raw_buffer_load_format: - case Intrinsic::amdgcn_raw_tbuffer_load: - case Intrinsic::amdgcn_s_buffer_load: - case Intrinsic::amdgcn_struct_buffer_load: - case Intrinsic::amdgcn_struct_buffer_load_format: - case Intrinsic::amdgcn_struct_tbuffer_load: - case Intrinsic::amdgcn_tbuffer_load: - return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts); default: { - if (getAMDGPUImageDMaskIntrinsic(II->getIntrinsicID())) - return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts, 0); - + // Handle target specific intrinsics + Optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic( + *II, DemandedElts, UndefElts, UndefElts2, UndefElts3, + simplifyAndSetOp); + if (V.hasValue()) + return V.getValue(); break; } } // switch on IntrinsicID diff --git a/llvm/lib/Transforms/InstCombine/InstCombineTables.td b/llvm/lib/Transforms/InstCombine/InstCombineTables.td deleted file mode 100644 index 98b2adc442fa..000000000000 --- a/llvm/lib/Transforms/InstCombine/InstCombineTables.td +++ /dev/null @@ -1,11 +0,0 @@ -include "llvm/TableGen/SearchableTable.td" -include "llvm/IR/Intrinsics.td" - -def AMDGPUImageDMaskIntrinsicTable : GenericTable { - let FilterClass = "AMDGPUImageDMaskIntrinsic"; - let Fields = ["Intr"]; - - let PrimaryKey = ["Intr"]; - let PrimaryKeyName = "getAMDGPUImageDMaskIntrinsic"; - let PrimaryKeyEarlyOut = 1; -} diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index ff70347569ab..06f22cdfb63d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" @@ -35,6 +36,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" 
+#include "llvm/Transforms/InstCombine/InstCombiner.h" #include <cassert> #include <cstdint> #include <iterator> @@ -45,6 +47,10 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" +STATISTIC(NumAggregateReconstructionsSimplified, + "Number of aggregate reconstructions turned into reuse of the " + "original aggregate"); + /// Return true if the value is cheaper to scalarize than it is to leave as a /// vector operation. IsConstantExtractIndex indicates whether we are extracting /// one known element from a vector constant. @@ -85,7 +91,8 @@ static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) { // If we have a PHI node with a vector type that is only used to feed // itself and be an operand of extractelement at a constant location, // try to replace the PHI of the vector type with a PHI of a scalar type. -Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { +Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, + PHINode *PN) { SmallVector<Instruction *, 2> Extracts; // The users we want the PHI to have are: // 1) The EI ExtractElement (we already know this) @@ -178,15 +185,19 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, // extelt (bitcast VecX), IndexC --> bitcast X[IndexC] auto *SrcTy = cast<VectorType>(X->getType()); Type *DestTy = Ext.getType(); - unsigned NumSrcElts = SrcTy->getNumElements(); - unsigned NumElts = Ext.getVectorOperandType()->getNumElements(); + ElementCount NumSrcElts = SrcTy->getElementCount(); + ElementCount NumElts = + cast<VectorType>(Ext.getVectorOperandType())->getElementCount(); if (NumSrcElts == NumElts) if (Value *Elt = findScalarElement(X, ExtIndexC)) return new BitCastInst(Elt, DestTy); + assert(NumSrcElts.isScalable() == NumElts.isScalable() && + "Src and Dst must be the same sort of vector type"); + // If the source elements are wider than the destination, try to shift and // truncate a subset of scalar bits of an insert op. - if (NumSrcElts < NumElts) { + if (NumSrcElts.getKnownMinValue() < NumElts.getKnownMinValue()) { Value *Scalar; uint64_t InsIndexC; if (!match(X, m_InsertElt(m_Value(), m_Value(Scalar), @@ -197,7 +208,8 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, // into. Example: if we inserted element 1 of a <2 x i64> and we are // extracting an i16 (narrowing ratio = 4), then this extract must be from 1 // of elements 4-7 of the bitcasted vector. - unsigned NarrowingRatio = NumElts / NumSrcElts; + unsigned NarrowingRatio = + NumElts.getKnownMinValue() / NumSrcElts.getKnownMinValue(); if (ExtIndexC / NarrowingRatio != InsIndexC) return nullptr; @@ -259,7 +271,7 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, /// Find elements of V demanded by UserInstr. static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { - unsigned VWidth = cast<VectorType>(V->getType())->getNumElements(); + unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); // Conservatively assume that all elements are needed. 
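// As a small illustration (hypothetical IR): a user such as
//   %e = extractelement <4 x float> %v, i32 2
// demands only element 2 of %v, whereas a variable index or an unhandled kind
// of user falls back to the conservative all-elements answer below.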
APInt UsedElts(APInt::getAllOnesValue(VWidth)); @@ -277,7 +289,7 @@ static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { case Instruction::ShuffleVector: { ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr); unsigned MaskNumElts = - cast<VectorType>(UserInstr->getType())->getNumElements(); + cast<FixedVectorType>(UserInstr->getType())->getNumElements(); UsedElts = APInt(VWidth, 0); for (unsigned i = 0; i < MaskNumElts; i++) { @@ -303,7 +315,7 @@ static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { /// no user demands an element of V, then the corresponding bit /// remains unset in the returned value. static APInt findDemandedEltsByAllUsers(Value *V) { - unsigned VWidth = cast<VectorType>(V->getType())->getNumElements(); + unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); APInt UnionUsedElts(VWidth, 0); for (const Use &U : V->uses()) { @@ -321,7 +333,7 @@ static APInt findDemandedEltsByAllUsers(Value *V) { return UnionUsedElts; } -Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { +Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { Value *SrcVec = EI.getVectorOperand(); Value *Index = EI.getIndexOperand(); if (Value *V = SimplifyExtractElementInst(SrcVec, Index, @@ -333,17 +345,17 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { auto *IndexC = dyn_cast<ConstantInt>(Index); if (IndexC) { ElementCount EC = EI.getVectorOperandType()->getElementCount(); - unsigned NumElts = EC.Min; + unsigned NumElts = EC.getKnownMinValue(); // InstSimplify should handle cases where the index is invalid. // For fixed-length vector, it's invalid to extract out-of-range element. - if (!EC.Scalable && IndexC->getValue().uge(NumElts)) + if (!EC.isScalable() && IndexC->getValue().uge(NumElts)) return nullptr; // This instruction only demands the single element from the input vector. // Skip for scalable type, the number of elements is unknown at // compile-time. - if (!EC.Scalable && NumElts != 1) { + if (!EC.isScalable() && NumElts != 1) { // If the input vector has a single use, simplify it based on this use // property. if (SrcVec->hasOneUse()) { @@ -460,7 +472,7 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, SmallVectorImpl<int> &Mask) { assert(LHS->getType() == RHS->getType() && "Invalid CollectSingleShuffleElements"); - unsigned NumElts = cast<VectorType>(V->getType())->getNumElements(); + unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements(); if (isa<UndefValue>(V)) { Mask.assign(NumElts, -1); @@ -502,7 +514,7 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, unsigned ExtractedIdx = cast<ConstantInt>(EI->getOperand(1))->getZExtValue(); unsigned NumLHSElts = - cast<VectorType>(LHS->getType())->getNumElements(); + cast<FixedVectorType>(LHS->getType())->getNumElements(); // This must be extracting from either LHS or RHS. if (EI->getOperand(0) == LHS || EI->getOperand(0) == RHS) { @@ -531,9 +543,9 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, /// shufflevector to replace one or more insert/extract pairs. 
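// A brief example of the insert/extract pairs meant here (values hypothetical):
//   %e = extractelement <4 x float> %x, i32 0
//   %i = insertelement <4 x float> %y, float %e, i32 2
// can be expressed as a single shuffle of the two source vectors:
//   %i = shufflevector <4 x float> %y, <4 x float> %x, <4 x i32> <i32 0, i32 1, i32 4, i32 3>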
static void replaceExtractElements(InsertElementInst *InsElt, ExtractElementInst *ExtElt, - InstCombiner &IC) { - VectorType *InsVecType = InsElt->getType(); - VectorType *ExtVecType = ExtElt->getVectorOperandType(); + InstCombinerImpl &IC) { + auto *InsVecType = cast<FixedVectorType>(InsElt->getType()); + auto *ExtVecType = cast<FixedVectorType>(ExtElt->getVectorOperandType()); unsigned NumInsElts = InsVecType->getNumElements(); unsigned NumExtElts = ExtVecType->getNumElements(); @@ -614,7 +626,7 @@ using ShuffleOps = std::pair<Value *, Value *>; static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask, Value *PermittedRHS, - InstCombiner &IC) { + InstCombinerImpl &IC) { assert(V->getType()->isVectorTy() && "Invalid shuffle!"); unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements(); @@ -661,7 +673,7 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask, } unsigned NumLHSElts = - cast<VectorType>(RHS->getType())->getNumElements(); + cast<FixedVectorType>(RHS->getType())->getNumElements(); Mask[InsertedIdx % NumElts] = NumLHSElts + ExtractedIdx; return std::make_pair(LR.first, RHS); } @@ -670,7 +682,8 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask, // We've gone as far as we can: anything on the other side of the // extractelement will already have been converted into a shuffle. unsigned NumLHSElts = - cast<VectorType>(EI->getOperand(0)->getType())->getNumElements(); + cast<FixedVectorType>(EI->getOperand(0)->getType()) + ->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) Mask.push_back(i == InsertedIdx ? ExtractedIdx : NumLHSElts + i); return std::make_pair(EI->getOperand(0), PermittedRHS); @@ -692,6 +705,285 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask, return std::make_pair(V, nullptr); } +/// Look for chain of insertvalue's that fully define an aggregate, and trace +/// back the values inserted, see if they are all were extractvalue'd from +/// the same source aggregate from the exact same element indexes. +/// If they were, just reuse the source aggregate. +/// This potentially deals with PHI indirections. +Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( + InsertValueInst &OrigIVI) { + Type *AggTy = OrigIVI.getType(); + unsigned NumAggElts; + switch (AggTy->getTypeID()) { + case Type::StructTyID: + NumAggElts = AggTy->getStructNumElements(); + break; + case Type::ArrayTyID: + NumAggElts = AggTy->getArrayNumElements(); + break; + default: + llvm_unreachable("Unhandled aggregate type?"); + } + + // Arbitrary aggregate size cut-off. Motivation for limit of 2 is to be able + // to handle clang C++ exception struct (which is hardcoded as {i8*, i32}), + // FIXME: any interesting patterns to be caught with larger limit? + assert(NumAggElts > 0 && "Aggregate should have elements."); + if (NumAggElts > 2) + return nullptr; + + static constexpr auto NotFound = None; + static constexpr auto FoundMismatch = nullptr; + + // Try to find a value of each element of an aggregate. + // FIXME: deal with more complex, not one-dimensional, aggregate types + SmallVector<Optional<Value *>, 2> AggElts(NumAggElts, NotFound); + + // Do we know values for each element of the aggregate? + auto KnowAllElts = [&AggElts]() { + return all_of(AggElts, + [](Optional<Value *> Elt) { return Elt != NotFound; }); + }; + + int Depth = 0; + + // Arbitrary `insertvalue` visitation depth limit. 
Let's be okay with + // every element being overwritten twice, which should never happen. + static const int DepthLimit = 2 * NumAggElts; + + // Recurse up the chain of `insertvalue` aggregate operands until either we've + // reconstructed full initializer or can't visit any more `insertvalue`'s. + for (InsertValueInst *CurrIVI = &OrigIVI; + Depth < DepthLimit && CurrIVI && !KnowAllElts(); + CurrIVI = dyn_cast<InsertValueInst>(CurrIVI->getAggregateOperand()), + ++Depth) { + Value *InsertedValue = CurrIVI->getInsertedValueOperand(); + ArrayRef<unsigned int> Indices = CurrIVI->getIndices(); + + // Don't bother with more than single-level aggregates. + if (Indices.size() != 1) + return nullptr; // FIXME: deal with more complex aggregates? + + // Now, we may have already previously recorded the value for this element + // of an aggregate. If we did, that means the CurrIVI will later be + // overwritten with the already-recorded value. But if not, let's record it! + Optional<Value *> &Elt = AggElts[Indices.front()]; + Elt = Elt.getValueOr(InsertedValue); + + // FIXME: should we handle chain-terminating undef base operand? + } + + // Was that sufficient to deduce the full initializer for the aggregate? + if (!KnowAllElts()) + return nullptr; // Give up then. + + // We now want to find the source[s] of the aggregate elements we've found. + // And with "source" we mean the original aggregate[s] from which + // the inserted elements were extracted. This may require PHI translation. + + enum class AggregateDescription { + /// When analyzing the value that was inserted into an aggregate, we did + /// not manage to find defining `extractvalue` instruction to analyze. + NotFound, + /// When analyzing the value that was inserted into an aggregate, we did + /// manage to find defining `extractvalue` instruction[s], and everything + /// matched perfectly - aggregate type, element insertion/extraction index. + Found, + /// When analyzing the value that was inserted into an aggregate, we did + /// manage to find defining `extractvalue` instruction, but there was + /// a mismatch: either the source type from which the extraction was didn't + /// match the aggregate type into which the insertion was, + /// or the extraction/insertion channels mismatched, + /// or different elements had different source aggregates. + FoundMismatch + }; + auto Describe = [](Optional<Value *> SourceAggregate) { + if (SourceAggregate == NotFound) + return AggregateDescription::NotFound; + if (*SourceAggregate == FoundMismatch) + return AggregateDescription::FoundMismatch; + return AggregateDescription::Found; + }; + + // Given the value \p Elt that was being inserted into element \p EltIdx of an + // aggregate AggTy, see if \p Elt was originally defined by an + // appropriate extractvalue (same element index, same aggregate type). + // If found, return the source aggregate from which the extraction was. + // If \p PredBB is provided, does PHI translation of an \p Elt first. + auto FindSourceAggregate = + [&](Value *Elt, unsigned EltIdx, Optional<BasicBlock *> UseBB, + Optional<BasicBlock *> PredBB) -> Optional<Value *> { + // For now(?), only deal with, at most, a single level of PHI indirection. + if (UseBB && PredBB) + Elt = Elt->DoPHITranslation(*UseBB, *PredBB); + // FIXME: deal with multiple levels of PHI indirection? + + // Did we find an extraction? 
+ auto *EVI = dyn_cast<ExtractValueInst>(Elt); + if (!EVI) + return NotFound; + + Value *SourceAggregate = EVI->getAggregateOperand(); + + // Is the extraction from the same type into which the insertion was? + if (SourceAggregate->getType() != AggTy) + return FoundMismatch; + // And the element index doesn't change between extraction and insertion? + if (EVI->getNumIndices() != 1 || EltIdx != EVI->getIndices().front()) + return FoundMismatch; + + return SourceAggregate; // AggregateDescription::Found + }; + + // Given elements AggElts that were constructing an aggregate OrigIVI, + // see if we can find appropriate source aggregate for each of the elements, + // and see it's the same aggregate for each element. If so, return it. + auto FindCommonSourceAggregate = + [&](Optional<BasicBlock *> UseBB, + Optional<BasicBlock *> PredBB) -> Optional<Value *> { + Optional<Value *> SourceAggregate; + + for (auto I : enumerate(AggElts)) { + assert(Describe(SourceAggregate) != AggregateDescription::FoundMismatch && + "We don't store nullptr in SourceAggregate!"); + assert((Describe(SourceAggregate) == AggregateDescription::Found) == + (I.index() != 0) && + "SourceAggregate should be valid after the the first element,"); + + // For this element, is there a plausible source aggregate? + // FIXME: we could special-case undef element, IFF we know that in the + // source aggregate said element isn't poison. + Optional<Value *> SourceAggregateForElement = + FindSourceAggregate(*I.value(), I.index(), UseBB, PredBB); + + // Okay, what have we found? Does that correlate with previous findings? + + // Regardless of whether or not we have previously found source + // aggregate for previous elements (if any), if we didn't find one for + // this element, passthrough whatever we have just found. + if (Describe(SourceAggregateForElement) != AggregateDescription::Found) + return SourceAggregateForElement; + + // Okay, we have found source aggregate for this element. + // Let's see what we already know from previous elements, if any. + switch (Describe(SourceAggregate)) { + case AggregateDescription::NotFound: + // This is apparently the first element that we have examined. + SourceAggregate = SourceAggregateForElement; // Record the aggregate! + continue; // Great, now look at next element. + case AggregateDescription::Found: + // We have previously already successfully examined other elements. + // Is this the same source aggregate we've found for other elements? + if (*SourceAggregateForElement != *SourceAggregate) + return FoundMismatch; + continue; // Still the same aggregate, look at next element. + case AggregateDescription::FoundMismatch: + llvm_unreachable("Can't happen. We would have early-exited then."); + }; + } + + assert(Describe(SourceAggregate) == AggregateDescription::Found && + "Must be a valid Value"); + return *SourceAggregate; + }; + + Optional<Value *> SourceAggregate; + + // Can we find the source aggregate without looking at predecessors? + SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/None, /*PredBB=*/None); + if (Describe(SourceAggregate) != AggregateDescription::NotFound) { + if (Describe(SourceAggregate) == AggregateDescription::FoundMismatch) + return nullptr; // Conflicting source aggregates! + ++NumAggregateReconstructionsSimplified; + return replaceInstUsesWith(OrigIVI, *SourceAggregate); + } + + // Okay, apparently we need to look at predecessors. 
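// A minimal sketch of the predecessor-aware case this is building towards
// (block and value names are hypothetical):
//   left:
//     %l0 = extractvalue { i8, i32 } %agg.left, 0
//     %l1 = extractvalue { i8, i32 } %agg.left, 1
//     br label %merge
//   right:
//     %r0 = extractvalue { i8, i32 } %agg.right, 0
//     %r1 = extractvalue { i8, i32 } %agg.right, 1
//     br label %merge
//   merge:
//     %phi0 = phi i8  [ %l0, %left ], [ %r0, %right ]
//     %phi1 = phi i32 [ %l1, %left ], [ %r1, %right ]
//     %i0   = insertvalue { i8, i32 } undef, i8 %phi0, 0
//     %agg  = insertvalue { i8, i32 } %i0, i32 %phi1, 1
// Within each predecessor every element traces back to one source aggregate,
// so the whole reconstruction can collapse to a single PHI of aggregates:
//   %agg = phi { i8, i32 } [ %agg.left, %left ], [ %agg.right, %right ]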
+ + // We should be smart about picking the "use" basic block, which will be the + // merge point for aggregate, where we'll insert the final PHI that will be + // used instead of OrigIVI. Basic block of OrigIVI is *not* the right choice. + // We should look in which blocks each of the AggElts is being defined, + // they all should be defined in the same basic block. + BasicBlock *UseBB = nullptr; + + for (const Optional<Value *> &Elt : AggElts) { + // If this element's value was not defined by an instruction, ignore it. + auto *I = dyn_cast<Instruction>(*Elt); + if (!I) + continue; + // Otherwise, in which basic block is this instruction located? + BasicBlock *BB = I->getParent(); + // If it's the first instruction we've encountered, record the basic block. + if (!UseBB) { + UseBB = BB; + continue; + } + // Otherwise, this must be the same basic block we've seen previously. + if (UseBB != BB) + return nullptr; + } + + // If *all* of the elements are basic-block-independent, meaning they are + // either function arguments, or constant expressions, then if we didn't + // handle them without predecessor-aware handling, we won't handle them now. + if (!UseBB) + return nullptr; + + // If we didn't manage to find source aggregate without looking at + // predecessors, and there are no predecessors to look at, then we're done. + if (pred_empty(UseBB)) + return nullptr; + + // Arbitrary predecessor count limit. + static const int PredCountLimit = 64; + + // Cache the (non-uniqified!) list of predecessors in a vector, + // checking the limit at the same time for efficiency. + SmallVector<BasicBlock *, 4> Preds; // May have duplicates! + for (BasicBlock *Pred : predecessors(UseBB)) { + // Don't bother if there are too many predecessors. + if (Preds.size() >= PredCountLimit) // FIXME: only count duplicates once? + return nullptr; + Preds.emplace_back(Pred); + } + + // For each predecessor, what is the source aggregate, + // from which all the elements were originally extracted from? + // Note that we want for the map to have stable iteration order! + SmallDenseMap<BasicBlock *, Value *, 4> SourceAggregates; + for (BasicBlock *Pred : Preds) { + std::pair<decltype(SourceAggregates)::iterator, bool> IV = + SourceAggregates.insert({Pred, nullptr}); + // Did we already evaluate this predecessor? + if (!IV.second) + continue; + + // Let's hope that when coming from predecessor Pred, all elements of the + // aggregate produced by OrigIVI must have been originally extracted from + // the same aggregate. Is that so? Can we find said original aggregate? + SourceAggregate = FindCommonSourceAggregate(UseBB, Pred); + if (Describe(SourceAggregate) != AggregateDescription::Found) + return nullptr; // Give up. + IV.first->second = *SourceAggregate; + } + + // All good! Now we just need to thread the source aggregates here. + // Note that we have to insert the new PHI here, ourselves, because we can't + // rely on InstCombinerImpl::run() inserting it into the right basic block. + // Note that the same block can be a predecessor more than once, + // and we need to preserve that invariant for the PHI node. 
+ BuilderTy::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(UseBB->getFirstNonPHI()); + auto *PHI = + Builder.CreatePHI(AggTy, Preds.size(), OrigIVI.getName() + ".merged"); + for (BasicBlock *Pred : Preds) + PHI->addIncoming(SourceAggregates[Pred], Pred); + + ++NumAggregateReconstructionsSimplified; + return replaceInstUsesWith(OrigIVI, PHI); +} + /// Try to find redundant insertvalue instructions, like the following ones: /// %0 = insertvalue { i8, i32 } undef, i8 %x, 0 /// %1 = insertvalue { i8, i32 } %0, i8 %y, 0 @@ -699,7 +991,7 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask, /// first one, making the first one redundant. /// It should be transformed to: /// %0 = insertvalue { i8, i32 } undef, i8 %y, 0 -Instruction *InstCombiner::visitInsertValueInst(InsertValueInst &I) { +Instruction *InstCombinerImpl::visitInsertValueInst(InsertValueInst &I) { bool IsRedundant = false; ArrayRef<unsigned int> FirstIndices = I.getIndices(); @@ -724,6 +1016,10 @@ Instruction *InstCombiner::visitInsertValueInst(InsertValueInst &I) { if (IsRedundant) return replaceInstUsesWith(I, I.getOperand(0)); + + if (Instruction *NewI = foldAggregateConstructionIntoAggregateReuse(I)) + return NewI; + return nullptr; } @@ -854,7 +1150,8 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) { // For example: // inselt (shuf (inselt undef, X, 0), undef, <0,undef,0,undef>), X, 1 // --> shuf (inselt undef, X, 0), undef, <0,0,0,undef> - unsigned NumMaskElts = Shuf->getType()->getNumElements(); + unsigned NumMaskElts = + cast<FixedVectorType>(Shuf->getType())->getNumElements(); SmallVector<int, 16> NewMask(NumMaskElts); for (unsigned i = 0; i != NumMaskElts; ++i) NewMask[i] = i == IdxC ? 0 : Shuf->getMaskValue(i); @@ -892,7 +1189,8 @@ static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) { // that same index value. // For example: // inselt (shuf X, IdMask), (extelt X, IdxC), IdxC --> shuf X, IdMask' - unsigned NumMaskElts = Shuf->getType()->getNumElements(); + unsigned NumMaskElts = + cast<FixedVectorType>(Shuf->getType())->getNumElements(); SmallVector<int, 16> NewMask(NumMaskElts); ArrayRef<int> OldMask = Shuf->getShuffleMask(); for (unsigned i = 0; i != NumMaskElts; ++i) { @@ -1041,7 +1339,7 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { return nullptr; } -Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { +Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { Value *VecOp = IE.getOperand(0); Value *ScalarOp = IE.getOperand(1); Value *IdxOp = IE.getOperand(2); @@ -1189,7 +1487,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, // Propagating an undefined shuffle mask element to integer div/rem is not // allowed because those opcodes can create immediate undefined behavior // from an undefined element in an operand. - if (llvm::any_of(Mask, [](int M){ return M == -1; })) + if (llvm::is_contained(Mask, -1)) return false; LLVM_FALLTHROUGH; case Instruction::Add: @@ -1222,7 +1520,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, // longer vector ops, but that may result in more expensive codegen. 
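// For example (hypothetical IR), with a length-increasing mask:
//   %a = add <2 x i32> %x, %y
//   %s = shufflevector <2 x i32> %a, <2 x i32> undef, <4 x i32> <i32 0, i32 1, i32 0, i32 1>
// evaluating the add in the shuffled order would require widening it to
// <4 x i32>, so the length check below rejects such cases.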
Type *ITy = I->getType(); if (ITy->isVectorTy() && - Mask.size() > cast<VectorType>(ITy)->getNumElements()) + Mask.size() > cast<FixedVectorType>(ITy)->getNumElements()) return false; for (Value *Operand : I->operands()) { if (!canEvaluateShuffled(Operand, Mask, Depth - 1)) @@ -1380,7 +1678,8 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { case Instruction::GetElementPtr: { SmallVector<Value*, 8> NewOps; bool NeedsRebuild = - (Mask.size() != cast<VectorType>(I->getType())->getNumElements()); + (Mask.size() != + cast<FixedVectorType>(I->getType())->getNumElements()); for (int i = 0, e = I->getNumOperands(); i != e; ++i) { Value *V; // Recursively call evaluateInDifferentElementOrder on vector arguments @@ -1435,7 +1734,7 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { static bool isShuffleExtractingFromLHS(ShuffleVectorInst &SVI, ArrayRef<int> Mask) { unsigned LHSElems = - cast<VectorType>(SVI.getOperand(0)->getType())->getNumElements(); + cast<FixedVectorType>(SVI.getOperand(0)->getType())->getNumElements(); unsigned MaskElems = Mask.size(); unsigned BegIdx = Mask.front(); unsigned EndIdx = Mask.back(); @@ -1525,7 +1824,7 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { is_contained(Mask, UndefMaskElem) && (Instruction::isIntDivRem(BOpcode) || Instruction::isShift(BOpcode)); if (MightCreatePoisonOrUB) - NewC = getSafeVectorConstantForBinop(BOpcode, NewC, true); + NewC = InstCombiner::getSafeVectorConstantForBinop(BOpcode, NewC, true); // shuf (bop X, C), X, M --> bop X, C' // shuf X, (bop X, C), M --> bop X, C' @@ -1567,7 +1866,8 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, // For example: // shuf (inselt undef, X, 2), undef, <2,2,undef> // --> shuf (inselt undef, X, 0), undef, <0,0,undef> - unsigned NumMaskElts = Shuf.getType()->getNumElements(); + unsigned NumMaskElts = + cast<FixedVectorType>(Shuf.getType())->getNumElements(); SmallVector<int, 16> NewMask(NumMaskElts, 0); for (unsigned i = 0; i != NumMaskElts; ++i) if (Mask[i] == UndefMaskElem) @@ -1585,7 +1885,7 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, // Canonicalize to choose from operand 0 first unless operand 1 is undefined. // Commuting undef to operand 0 conflicts with another canonicalization. - unsigned NumElts = Shuf.getType()->getNumElements(); + unsigned NumElts = cast<FixedVectorType>(Shuf.getType())->getNumElements(); if (!isa<UndefValue>(Shuf.getOperand(1)) && Shuf.getMaskValue(0) >= (int)NumElts) { // TODO: Can we assert that both operands of a shuffle-select are not undef @@ -1652,7 +1952,8 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, is_contained(Mask, UndefMaskElem) && (Instruction::isIntDivRem(BOpc) || Instruction::isShift(BOpc)); if (MightCreatePoisonOrUB) - NewC = getSafeVectorConstantForBinop(BOpc, NewC, ConstantsAreOp1); + NewC = InstCombiner::getSafeVectorConstantForBinop(BOpc, NewC, + ConstantsAreOp1); Value *V; if (X == Y) { @@ -1719,8 +2020,8 @@ static Instruction *foldTruncShuffle(ShuffleVectorInst &Shuf, // and the source element type must be larger than the shuffle element type. 
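// A concrete shape of the pattern (little-endian, hypothetical values):
//   %bc = bitcast <2 x i64> %x to <4 x i32>
//   %s  = shufflevector <4 x i32> %bc, <4 x i32> undef, <2 x i32> <i32 0, i32 2>
// keeps only the low i32 of each i64 element and is therefore equivalent to
//   %s  = trunc <2 x i64> %x to <2 x i32>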
Type *SrcType = X->getType(); if (!SrcType->isVectorTy() || !SrcType->isIntOrIntVectorTy() || - cast<VectorType>(SrcType)->getNumElements() != - cast<VectorType>(DestType)->getNumElements() || + cast<FixedVectorType>(SrcType)->getNumElements() != + cast<FixedVectorType>(DestType)->getNumElements() || SrcType->getScalarSizeInBits() % DestType->getScalarSizeInBits() != 0) return nullptr; @@ -1736,8 +2037,7 @@ static Instruction *foldTruncShuffle(ShuffleVectorInst &Shuf, if (Mask[i] == UndefMaskElem) continue; uint64_t LSBIndex = IsBigEndian ? (i + 1) * TruncRatio - 1 : i * TruncRatio; - assert(LSBIndex <= std::numeric_limits<int32_t>::max() && - "Overflowed 32-bits"); + assert(LSBIndex <= INT32_MAX && "Overflowed 32-bits"); if (Mask[i] != (int)LSBIndex) return nullptr; } @@ -1764,19 +2064,19 @@ static Instruction *narrowVectorSelect(ShuffleVectorInst &Shuf, // We need a narrow condition value. It must be extended with undef elements // and have the same number of elements as this shuffle. - unsigned NarrowNumElts = Shuf.getType()->getNumElements(); + unsigned NarrowNumElts = + cast<FixedVectorType>(Shuf.getType())->getNumElements(); Value *NarrowCond; if (!match(Cond, m_OneUse(m_Shuffle(m_Value(NarrowCond), m_Undef()))) || - cast<VectorType>(NarrowCond->getType())->getNumElements() != + cast<FixedVectorType>(NarrowCond->getType())->getNumElements() != NarrowNumElts || !cast<ShuffleVectorInst>(Cond)->isIdentityWithPadding()) return nullptr; // shuf (sel (shuf NarrowCond, undef, WideMask), X, Y), undef, NarrowMask) --> // sel NarrowCond, (shuf X, undef, NarrowMask), (shuf Y, undef, NarrowMask) - Value *Undef = UndefValue::get(X->getType()); - Value *NarrowX = Builder.CreateShuffleVector(X, Undef, Shuf.getShuffleMask()); - Value *NarrowY = Builder.CreateShuffleVector(Y, Undef, Shuf.getShuffleMask()); + Value *NarrowX = Builder.CreateShuffleVector(X, Shuf.getShuffleMask()); + Value *NarrowY = Builder.CreateShuffleVector(Y, Shuf.getShuffleMask()); return SelectInst::Create(NarrowCond, NarrowX, NarrowY); } @@ -1807,7 +2107,7 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { // new shuffle mask. Otherwise, copy the original mask element. Example: // shuf (shuf X, Y, <C0, C1, C2, undef, C4>), undef, <0, undef, 2, 3> --> // shuf X, Y, <C0, undef, C2, undef> - unsigned NumElts = Shuf.getType()->getNumElements(); + unsigned NumElts = cast<FixedVectorType>(Shuf.getType())->getNumElements(); SmallVector<int, 16> NewMask(NumElts); assert(NumElts < Mask.size() && "Identity with extract must have less elements than its inputs"); @@ -1823,7 +2123,7 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { /// Try to replace a shuffle with an insertelement or try to replace a shuffle /// operand with the operand of an insertelement. static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf, - InstCombiner &IC) { + InstCombinerImpl &IC) { Value *V0 = Shuf.getOperand(0), *V1 = Shuf.getOperand(1); SmallVector<int, 16> Mask; Shuf.getShuffleMask(Mask); @@ -1832,7 +2132,7 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf, // TODO: This restriction could be removed if the insert has only one use // (because the transform would require a new length-changing shuffle). int NumElts = Mask.size(); - if (NumElts != (int)(cast<VectorType>(V0->getType())->getNumElements())) + if (NumElts != (int)(cast<FixedVectorType>(V0->getType())->getNumElements())) return nullptr; // This is a specialization of a fold in SimplifyDemandedVectorElts. 
We may @@ -1844,7 +2144,7 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf, uint64_t IdxC; if (match(V0, m_InsertElt(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { // shuf (inselt X, ?, IdxC), ?, Mask --> shuf X, ?, Mask - if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) + if (!is_contained(Mask, (int)IdxC)) return IC.replaceOperand(Shuf, 0, X); } if (match(V1, m_InsertElt(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { @@ -1852,7 +2152,7 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf, // accesses to the 2nd vector input of the shuffle. IdxC += NumElts; // shuf ?, (inselt X, ?, IdxC), Mask --> shuf ?, X, Mask - if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) + if (!is_contained(Mask, (int)IdxC)) return IC.replaceOperand(Shuf, 1, X); } @@ -1927,9 +2227,10 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { Value *X = Shuffle0->getOperand(0); Value *Y = Shuffle1->getOperand(0); if (X->getType() != Y->getType() || - !isPowerOf2_32(Shuf.getType()->getNumElements()) || - !isPowerOf2_32(Shuffle0->getType()->getNumElements()) || - !isPowerOf2_32(cast<VectorType>(X->getType())->getNumElements()) || + !isPowerOf2_32(cast<FixedVectorType>(Shuf.getType())->getNumElements()) || + !isPowerOf2_32( + cast<FixedVectorType>(Shuffle0->getType())->getNumElements()) || + !isPowerOf2_32(cast<FixedVectorType>(X->getType())->getNumElements()) || isa<UndefValue>(X) || isa<UndefValue>(Y)) return nullptr; assert(isa<UndefValue>(Shuffle0->getOperand(1)) && @@ -1940,8 +2241,8 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { // operands directly by adjusting the shuffle mask to account for the narrower // types: // shuf (widen X), (widen Y), Mask --> shuf X, Y, Mask' - int NarrowElts = cast<VectorType>(X->getType())->getNumElements(); - int WideElts = Shuffle0->getType()->getNumElements(); + int NarrowElts = cast<FixedVectorType>(X->getType())->getNumElements(); + int WideElts = cast<FixedVectorType>(Shuffle0->getType())->getNumElements(); assert(WideElts > NarrowElts && "Unexpected types for identity with padding"); ArrayRef<int> Mask = Shuf.getShuffleMask(); @@ -1974,7 +2275,7 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { return new ShuffleVectorInst(X, Y, NewMask); } -Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { +Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *LHS = SVI.getOperand(0); Value *RHS = SVI.getOperand(1); SimplifyQuery ShufQuery = SQ.getWithInstruction(&SVI); @@ -1982,9 +2283,13 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { SVI.getType(), ShufQuery)) return replaceInstUsesWith(SVI, V); + // Bail out for scalable vectors + if (isa<ScalableVectorType>(LHS->getType())) + return nullptr; + // shuffle x, x, mask --> shuffle x, undef, mask' - unsigned VWidth = SVI.getType()->getNumElements(); - unsigned LHSWidth = cast<VectorType>(LHS->getType())->getNumElements(); + unsigned VWidth = cast<FixedVectorType>(SVI.getType())->getNumElements(); + unsigned LHSWidth = cast<FixedVectorType>(LHS->getType())->getNumElements(); ArrayRef<int> Mask = SVI.getShuffleMask(); Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); @@ -1998,7 +2303,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (match(LHS, m_BitCast(m_Value(X))) && match(RHS, m_Undef()) && X->getType()->isVectorTy() && VWidth == LHSWidth) { // Try to 
create a scaled mask constant. - auto *XType = cast<VectorType>(X->getType()); + auto *XType = cast<FixedVectorType>(X->getType()); unsigned XNumElts = XType->getNumElements(); SmallVector<int, 16> ScaledMask; if (XNumElts >= VWidth) { @@ -2106,7 +2411,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (isShuffleExtractingFromLHS(SVI, Mask)) { Value *V = LHS; unsigned MaskElems = Mask.size(); - VectorType *SrcTy = cast<VectorType>(V->getType()); + auto *SrcTy = cast<FixedVectorType>(V->getType()); unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedSize(); unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType()); assert(SrcElemBitWidth && "vector elements must have a bitwidth"); @@ -2138,8 +2443,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { SmallVector<int, 16> ShuffleMask(SrcNumElems, -1); for (unsigned I = 0, E = MaskElems, Idx = BegIdx; I != E; ++Idx, ++I) ShuffleMask[I] = Idx; - V = Builder.CreateShuffleVector(V, UndefValue::get(V->getType()), - ShuffleMask, + V = Builder.CreateShuffleVector(V, ShuffleMask, SVI.getName() + ".extract"); BegIdx = 0; } @@ -2224,11 +2528,11 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (LHSShuffle) { LHSOp0 = LHSShuffle->getOperand(0); LHSOp1 = LHSShuffle->getOperand(1); - LHSOp0Width = cast<VectorType>(LHSOp0->getType())->getNumElements(); + LHSOp0Width = cast<FixedVectorType>(LHSOp0->getType())->getNumElements(); } if (RHSShuffle) { RHSOp0 = RHSShuffle->getOperand(0); - RHSOp0Width = cast<VectorType>(RHSOp0->getType())->getNumElements(); + RHSOp0Width = cast<FixedVectorType>(RHSOp0->getType())->getNumElements(); } Value* newLHS = LHS; Value* newRHS = RHS; @@ -2331,17 +2635,9 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // If the result mask is equal to one of the original shuffle masks, // or is a splat, do the replacement. if (isSplat || newMask == LHSMask || newMask == RHSMask || newMask == Mask) { - SmallVector<Constant*, 16> Elts; - for (unsigned i = 0, e = newMask.size(); i != e; ++i) { - if (newMask[i] < 0) { - Elts.push_back(UndefValue::get(Int32Ty)); - } else { - Elts.push_back(ConstantInt::get(Int32Ty, newMask[i])); - } - } if (!newRHS) newRHS = UndefValue::get(newLHS->getType()); - return new ShuffleVectorInst(newLHS, newRHS, ConstantVector::get(Elts)); + return new ShuffleVectorInst(newLHS, newRHS, newMask); } return MadeChange ? 
&SVI : nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index b3254c10a0b2..518e909e8ab4 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -59,6 +59,7 @@ #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" @@ -113,6 +114,9 @@ using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" +STATISTIC(NumWorklistIterations, + "Number of instruction combining iterations performed"); + STATISTIC(NumCombined , "Number of insts combined"); STATISTIC(NumConstProp, "Number of constant folds"); STATISTIC(NumDeadInst , "Number of dead inst eliminated"); @@ -123,8 +127,13 @@ STATISTIC(NumReassoc , "Number of reassociations"); DEBUG_COUNTER(VisitCounter, "instcombine-visit", "Controls which instructions are visited"); +// FIXME: these limits eventually should be as low as 2. static constexpr unsigned InstCombineDefaultMaxIterations = 1000; +#ifndef NDEBUG +static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 100; +#else static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 1000; +#endif static cl::opt<bool> EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"), @@ -155,7 +164,41 @@ MaxArraySize("instcombine-maxarray-size", cl::init(1024), static cl::opt<unsigned> ShouldLowerDbgDeclare("instcombine-lower-dbg-declare", cl::Hidden, cl::init(true)); -Value *InstCombiner::EmitGEPOffset(User *GEP) { +Optional<Instruction *> +InstCombiner::targetInstCombineIntrinsic(IntrinsicInst &II) { + // Handle target specific intrinsics + if (II.getCalledFunction()->isTargetIntrinsic()) { + return TTI.instCombineIntrinsic(*this, II); + } + return None; +} + +Optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic( + IntrinsicInst &II, APInt DemandedMask, KnownBits &Known, + bool &KnownBitsComputed) { + // Handle target specific intrinsics + if (II.getCalledFunction()->isTargetIntrinsic()) { + return TTI.simplifyDemandedUseBitsIntrinsic(*this, II, DemandedMask, Known, + KnownBitsComputed); + } + return None; +} + +Optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic( + IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts, APInt &UndefElts2, + APInt &UndefElts3, + std::function<void(Instruction *, unsigned, APInt, APInt &)> + SimplifyAndSetOp) { + // Handle target specific intrinsics + if (II.getCalledFunction()->isTargetIntrinsic()) { + return TTI.simplifyDemandedVectorEltsIntrinsic( + *this, II, DemandedElts, UndefElts, UndefElts2, UndefElts3, + SimplifyAndSetOp); + } + return None; +} + +Value *InstCombinerImpl::EmitGEPOffset(User *GEP) { return llvm::EmitGEPOffset(&Builder, DL, GEP); } @@ -168,8 +211,8 @@ Value *InstCombiner::EmitGEPOffset(User *GEP) { /// legal to convert to, in order to open up more combining opportunities. /// NOTE: this treats i8, i16 and i32 specially, due to them being so common /// from frontend languages. 
-bool InstCombiner::shouldChangeType(unsigned FromWidth, - unsigned ToWidth) const { +bool InstCombinerImpl::shouldChangeType(unsigned FromWidth, + unsigned ToWidth) const { bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth); bool ToLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth); @@ -196,7 +239,7 @@ bool InstCombiner::shouldChangeType(unsigned FromWidth, /// to a larger illegal type. i1 is always treated as a legal type because it is /// a fundamental type in IR, and there are many specialized optimizations for /// i1 types. -bool InstCombiner::shouldChangeType(Type *From, Type *To) const { +bool InstCombinerImpl::shouldChangeType(Type *From, Type *To) const { // TODO: This could be extended to allow vectors. Datalayout changes might be // needed to properly support that. if (!From->isIntegerTy() || !To->isIntegerTy()) @@ -264,7 +307,8 @@ static void ClearSubclassDataAfterReassociation(BinaryOperator &I) { /// cast to eliminate one of the associative operations: /// (op (cast (op X, C2)), C1) --> (cast (op X, op (C1, C2))) /// (op (cast (op X, C2)), C1) --> (op (cast X), op (C1, C2)) -static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1, InstCombiner &IC) { +static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1, + InstCombinerImpl &IC) { auto *Cast = dyn_cast<CastInst>(BinOp1->getOperand(0)); if (!Cast || !Cast->hasOneUse()) return false; @@ -322,7 +366,7 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1, InstCombiner &IC) { /// 5. Transform: "A op (B op C)" ==> "B op (C op A)" if "C op A" simplifies. /// 6. Transform: "(A op C1) op (B op C2)" ==> "(A op B) op (C1 op C2)" /// if C1 and C2 are constants. -bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { +bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) { Instruction::BinaryOps Opcode = I.getOpcode(); bool Changed = false; @@ -550,9 +594,10 @@ getBinOpsForFactorization(Instruction::BinaryOps TopOpcode, BinaryOperator *Op, /// This tries to simplify binary operations by factorizing out common terms /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). -Value *InstCombiner::tryFactorization(BinaryOperator &I, - Instruction::BinaryOps InnerOpcode, - Value *A, Value *B, Value *C, Value *D) { +Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, + Instruction::BinaryOps InnerOpcode, + Value *A, Value *B, Value *C, + Value *D) { assert(A && B && C && D && "All values must be provided"); Value *V = nullptr; @@ -655,7 +700,7 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I, /// (eg "(A*B)+(A*C)" -> "A*(B+C)") or expanding out if this results in /// simplifications (eg: "A & (B | C) -> (A&B) | (A&C)" if this is a win). /// Returns the simplified value, or null if it didn't simplify. -Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { +Value *InstCombinerImpl::SimplifyUsingDistributiveLaws(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS); BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS); @@ -698,8 +743,10 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { Value *A = Op0->getOperand(0), *B = Op0->getOperand(1), *C = RHS; Instruction::BinaryOps InnerOpcode = Op0->getOpcode(); // op' - Value *L = SimplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I)); - Value *R = SimplifyBinOp(TopLevelOpcode, B, C, SQ.getWithInstruction(&I)); + // Disable the use of undef because it's not safe to distribute undef. 
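// A sketch of the general hazard (not necessarily the exact guarded case):
// distributing "(A + B) * C" into "(A * C) + (B * C)" duplicates C, e.g.
//   %ab = add i32 %a, %b
//   %m  = mul i32 %ab, undef        ; one use of one undef value
// versus, after distribution,
//   %l = mul i32 %a, undef          ; two uses that a simplifier could treat
//   %r = mul i32 %b, undef          ; as independent undef values
//   %m = add i32 %l, %r
// The rewritten form can take values the original never could (e.g. odd sums
// when %a == %b == 1), so undef-based folds are disabled for these queries.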
+ auto SQDistributive = SQ.getWithInstruction(&I).getWithoutUndef(); + Value *L = SimplifyBinOp(TopLevelOpcode, A, C, SQDistributive); + Value *R = SimplifyBinOp(TopLevelOpcode, B, C, SQDistributive); // Do "A op C" and "B op C" both simplify? if (L && R) { @@ -735,8 +782,10 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { Value *A = LHS, *B = Op1->getOperand(0), *C = Op1->getOperand(1); Instruction::BinaryOps InnerOpcode = Op1->getOpcode(); // op' - Value *L = SimplifyBinOp(TopLevelOpcode, A, B, SQ.getWithInstruction(&I)); - Value *R = SimplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I)); + // Disable the use of undef because it's not safe to distribute undef. + auto SQDistributive = SQ.getWithInstruction(&I).getWithoutUndef(); + Value *L = SimplifyBinOp(TopLevelOpcode, A, B, SQDistributive); + Value *R = SimplifyBinOp(TopLevelOpcode, A, C, SQDistributive); // Do "A op B" and "A op C" both simplify? if (L && R) { @@ -769,8 +818,9 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { return SimplifySelectsFeedingBinaryOp(I, LHS, RHS); } -Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, - Value *LHS, Value *RHS) { +Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, + Value *LHS, + Value *RHS) { Value *A, *B, *C, *D, *E, *F; bool LHSIsSelect = match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C))); bool RHSIsSelect = match(RHS, m_Select(m_Value(D), m_Value(E), m_Value(F))); @@ -820,9 +870,33 @@ Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, return SI; } +/// Freely adapt every user of V as-if V was changed to !V. +/// WARNING: only if canFreelyInvertAllUsersOf() said this can be done. +void InstCombinerImpl::freelyInvertAllUsersOf(Value *I) { + for (User *U : I->users()) { + switch (cast<Instruction>(U)->getOpcode()) { + case Instruction::Select: { + auto *SI = cast<SelectInst>(U); + SI->swapValues(); + SI->swapProfMetadata(); + break; + } + case Instruction::Br: + cast<BranchInst>(U)->swapSuccessors(); // swaps prof metadata too + break; + case Instruction::Xor: + replaceInstUsesWith(cast<Instruction>(*U), I); + break; + default: + llvm_unreachable("Got unexpected user - out of sync with " + "canFreelyInvertAllUsersOf() ?"); + } + } +} + /// Given a 'sub' instruction, return the RHS of the instruction if the LHS is a /// constant zero (which is the 'negate' form). -Value *InstCombiner::dyn_castNegVal(Value *V) const { +Value *InstCombinerImpl::dyn_castNegVal(Value *V) const { Value *NegV; if (match(V, m_Neg(m_Value(NegV)))) return NegV; @@ -883,7 +957,8 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, return RI; } -Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) { +Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, + SelectInst *SI) { // Don't modify shared select instructions. if (!SI->hasOneUse()) return nullptr; @@ -908,7 +983,7 @@ Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) { return nullptr; // If vectors, verify that they have the same number of elements. 
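// The overall select fold here, as a small hypothetical example:
//   %s = select i1 %c, i32 2, i32 6
//   %r = mul i32 %s, 5
// becomes, provided %s has no other uses,
//   %r = select i1 %c, i32 10, i32 30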
- if (SrcTy && SrcTy->getNumElements() != DestTy->getNumElements()) + if (SrcTy && SrcTy->getElementCount() != DestTy->getElementCount()) return nullptr; } @@ -978,7 +1053,7 @@ static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV, return RI; } -Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { +Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { unsigned NumPHIValues = PN->getNumIncomingValues(); if (NumPHIValues == 0) return nullptr; @@ -1004,7 +1079,9 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { BasicBlock *NonConstBB = nullptr; for (unsigned i = 0; i != NumPHIValues; ++i) { Value *InVal = PN->getIncomingValue(i); - if (isa<Constant>(InVal) && !isa<ConstantExpr>(InVal)) + // If I is a freeze instruction, count undef as a non-constant. + if (match(InVal, m_ImmConstant()) && + (!isa<FreezeInst>(I) || isGuaranteedNotToBeUndefOrPoison(InVal))) continue; if (isa<PHINode>(InVal)) return nullptr; // Itself a phi. @@ -1029,9 +1106,11 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { // operation in that block. However, if this is a critical edge, we would be // inserting the computation on some other paths (e.g. inside a loop). Only // do this if the pred block is unconditionally branching into the phi block. + // Also, make sure that the pred block is not dead code. if (NonConstBB != nullptr) { BranchInst *BI = dyn_cast<BranchInst>(NonConstBB->getTerminator()); - if (!BI || !BI->isUnconditional()) return nullptr; + if (!BI || !BI->isUnconditional() || !DT.isReachableFromEntry(NonConstBB)) + return nullptr; } // Okay, we can do the transformation: create the new PHI node. @@ -1063,7 +1142,7 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { // FalseVInPred versus TrueVInPred. When we have individual nonzero // elements in the vector, we will incorrectly fold InC to // `TrueVInPred`. - if (InC && !isa<ConstantExpr>(InC) && isa<ConstantInt>(InC)) + if (InC && isa<ConstantInt>(InC)) InV = InC->isNullValue() ? FalseVInPred : TrueVInPred; else { // Generate the select in the same block as PN's current incoming block. 
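// A rough example of the surrounding phi fold (blocks and values hypothetical):
//   %p = phi i32 [ 3, %bb1 ], [ %x, %bb2 ]
//   %r = add i32 %p, 1
// can become, when %bb2 branches here unconditionally,
//   ; computed at the end of %bb2:  %x.add = add i32 %x, 1
//   %r = phi i32 [ 4, %bb1 ], [ %x.add, %bb2 ]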
@@ -1097,6 +1176,15 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { Builder); NewPN->addIncoming(InV, PN->getIncomingBlock(i)); } + } else if (isa<FreezeInst>(&I)) { + for (unsigned i = 0; i != NumPHIValues; ++i) { + Value *InV; + if (NonConstBB == PN->getIncomingBlock(i)) + InV = Builder.CreateFreeze(PN->getIncomingValue(i), "phi.fr"); + else + InV = PN->getIncomingValue(i); + NewPN->addIncoming(InV, PN->getIncomingBlock(i)); + } } else { CastInst *CI = cast<CastInst>(&I); Type *RetTy = CI->getType(); @@ -1111,8 +1199,8 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { } } - for (auto UI = PN->user_begin(), E = PN->user_end(); UI != E;) { - Instruction *User = cast<Instruction>(*UI++); + for (User *U : make_early_inc_range(PN->users())) { + Instruction *User = cast<Instruction>(U); if (User == &I) continue; replaceInstUsesWith(*User, NewPN); eraseInstFromFunction(*User); @@ -1120,7 +1208,7 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { return replaceInstUsesWith(I, NewPN); } -Instruction *InstCombiner::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { +Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { if (!isa<Constant>(I.getOperand(1))) return nullptr; @@ -1138,8 +1226,9 @@ Instruction *InstCombiner::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { /// is a sequence of GEP indices into the pointed type that will land us at the /// specified offset. If so, fill them into NewIndices and return the resultant /// element type, otherwise return null. -Type *InstCombiner::FindElementAtOffset(PointerType *PtrTy, int64_t Offset, - SmallVectorImpl<Value *> &NewIndices) { +Type * +InstCombinerImpl::FindElementAtOffset(PointerType *PtrTy, int64_t Offset, + SmallVectorImpl<Value *> &NewIndices) { Type *Ty = PtrTy->getElementType(); if (!Ty->isSized()) return nullptr; @@ -1208,7 +1297,7 @@ static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) { /// Return a value X such that Val = X * Scale, or null if none. /// If the multiplication is known not to overflow, then NoSignedWrap is set. -Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { +Value *InstCombinerImpl::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { assert(isa<IntegerType>(Val->getType()) && "Can only descale integers!"); assert(cast<IntegerType>(Val->getType())->getBitWidth() == Scale.getBitWidth() && "Scale not compatible with value!"); @@ -1448,9 +1537,8 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { } while (true); } -Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { - // FIXME: some of this is likely fine for scalable vectors - if (!isa<FixedVectorType>(Inst.getType())) +Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { + if (!isa<VectorType>(Inst.getType())) return nullptr; BinaryOperator::BinaryOps Opcode = Inst.getOpcode(); @@ -1539,13 +1627,15 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { // intends to move shuffles closer to other shuffles and binops closer to // other binops, so they can be folded. It may also enable demanded elements // transforms. 
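// For instance (hypothetical constants), with a length-preserving mask:
//   %s = shufflevector <4 x i32> %x, <4 x i32> undef, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
//   %r = add <4 x i32> %s, <i32 10, i32 20, i32 30, i32 40>
// can be rewritten with a reordered constant so the shuffle moves outward:
//   %a = add <4 x i32> %x, <i32 20, i32 10, i32 40, i32 30>
//   %r = shufflevector <4 x i32> %a, <4 x i32> undef, <4 x i32> <i32 1, i32 0, i32 3, i32 2>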
- unsigned NumElts = cast<FixedVectorType>(Inst.getType())->getNumElements(); Constant *C; - if (match(&Inst, + auto *InstVTy = dyn_cast<FixedVectorType>(Inst.getType()); + if (InstVTy && + match(&Inst, m_c_BinOp(m_OneUse(m_Shuffle(m_Value(V1), m_Undef(), m_Mask(Mask))), - m_Constant(C))) && - cast<FixedVectorType>(V1->getType())->getNumElements() <= NumElts) { - assert(Inst.getType()->getScalarType() == V1->getType()->getScalarType() && + m_ImmConstant(C))) && + cast<FixedVectorType>(V1->getType())->getNumElements() <= + InstVTy->getNumElements()) { + assert(InstVTy->getScalarType() == V1->getType()->getScalarType() && "Shuffle should not change scalar type"); // Find constant NewC that has property: @@ -1560,6 +1650,7 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { UndefValue *UndefScalar = UndefValue::get(C->getType()->getScalarType()); SmallVector<Constant *, 16> NewVecC(SrcVecNumElts, UndefScalar); bool MayChange = true; + unsigned NumElts = InstVTy->getNumElements(); for (unsigned I = 0; I < NumElts; ++I) { Constant *CElt = C->getAggregateElement(I); if (ShMask[I] >= 0) { @@ -1648,9 +1739,8 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { // values followed by a splat followed by the 2nd binary operation: // bo (splat X), (bo Y, OtherOp) --> bo (splat (bo X, Y)), OtherOp Value *NewBO = Builder.CreateBinOp(Opcode, X, Y); - UndefValue *Undef = UndefValue::get(Inst.getType()); SmallVector<int, 8> NewMask(MaskC.size(), SplatIndex); - Value *NewSplat = Builder.CreateShuffleVector(NewBO, Undef, NewMask); + Value *NewSplat = Builder.CreateShuffleVector(NewBO, NewMask); Instruction *R = BinaryOperator::Create(Opcode, NewSplat, OtherOp); // Intersect FMF on both new binops. Other (poison-generating) flags are @@ -1670,7 +1760,7 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { /// Try to narrow the width of a binop if at least 1 operand is an extend of /// of a value. This requires a potentially expensive known bits check to make /// sure the narrow op does not overflow. -Instruction *InstCombiner::narrowMathIfNoOverflow(BinaryOperator &BO) { +Instruction *InstCombinerImpl::narrowMathIfNoOverflow(BinaryOperator &BO) { // We need at least one extended operand. Value *Op0 = BO.getOperand(0), *Op1 = BO.getOperand(1); @@ -1750,7 +1840,7 @@ static Instruction *foldSelectGEP(GetElementPtrInst &GEP, // gep (select Cond, TrueC, FalseC), IndexC --> select Cond, TrueC', FalseC' // Propagate 'inbounds' and metadata from existing instructions. // Note: using IRBuilder to create the constants for efficiency. - SmallVector<Value *, 4> IndexC(GEP.idx_begin(), GEP.idx_end()); + SmallVector<Value *, 4> IndexC(GEP.indices()); bool IsInBounds = GEP.isInBounds(); Value *NewTrueC = IsInBounds ? 
Builder.CreateInBoundsGEP(TrueC, IndexC) : Builder.CreateGEP(TrueC, IndexC); @@ -1759,8 +1849,8 @@ static Instruction *foldSelectGEP(GetElementPtrInst &GEP, return SelectInst::Create(Cond, NewTrueC, NewFalseC, "", nullptr, Sel); } -Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { - SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); +Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { + SmallVector<Value *, 8> Ops(GEP.operands()); Type *GEPType = GEP.getType(); Type *GEPEltType = GEP.getSourceElementType(); bool IsGEPSrcEleScalable = isa<ScalableVectorType>(GEPEltType); @@ -2130,7 +2220,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ? if (CATy->getElementType() == StrippedPtrEltTy) { // -> GEP i8* X, ... - SmallVector<Value*, 8> Idx(GEP.idx_begin()+1, GEP.idx_end()); + SmallVector<Value *, 8> Idx(drop_begin(GEP.indices())); GetElementPtrInst *Res = GetElementPtrInst::Create( StrippedPtrEltTy, StrippedPtr, Idx, GEP.getName()); Res->setIsInBounds(GEP.isInBounds()); @@ -2166,7 +2256,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // -> // %0 = GEP [10 x i8] addrspace(1)* X, ... // addrspacecast i8 addrspace(1)* %0 to i8* - SmallVector<Value*, 8> Idx(GEP.idx_begin(), GEP.idx_end()); + SmallVector<Value *, 8> Idx(GEP.indices()); Value *NewGEP = GEP.isInBounds() ? Builder.CreateInBoundsGEP(StrippedPtrEltTy, StrippedPtr, @@ -2308,15 +2398,15 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // gep (bitcast [c x ty]* X to <c x ty>*), Y, Z --> gep X, Y, Z auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy, const DataLayout &DL) { - auto *VecVTy = cast<VectorType>(VecTy); + auto *VecVTy = cast<FixedVectorType>(VecTy); return ArrTy->getArrayElementType() == VecVTy->getElementType() && ArrTy->getArrayNumElements() == VecVTy->getNumElements() && DL.getTypeAllocSize(ArrTy) == DL.getTypeAllocSize(VecTy); }; if (GEP.getNumOperands() == 3 && - ((GEPEltType->isArrayTy() && SrcEltType->isVectorTy() && + ((GEPEltType->isArrayTy() && isa<FixedVectorType>(SrcEltType) && areMatchingArrayAndVecTypes(GEPEltType, SrcEltType, DL)) || - (GEPEltType->isVectorTy() && SrcEltType->isArrayTy() && + (isa<FixedVectorType>(GEPEltType) && SrcEltType->isArrayTy() && areMatchingArrayAndVecTypes(SrcEltType, GEPEltType, DL)))) { // Create a new GEP here, as using `setOperand()` followed by @@ -2511,7 +2601,7 @@ static bool isAllocSiteRemovable(Instruction *AI, return true; } -Instruction *InstCombiner::visitAllocSite(Instruction &MI) { +Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { // If we have a malloc call which is only used in any amount of comparisons to // null and free calls, delete the calls and replace the comparisons with true // or false as appropriate. @@ -2526,10 +2616,10 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { // If we are removing an alloca with a dbg.declare, insert dbg.value calls // before each store. 
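// Roughly (metadata operands abbreviated), for an alloca that is going away:
//   call void @llvm.dbg.declare(metadata i32* %a, metadata !var, metadata !DIExpression())
//   store i32 %v, i32* %a
// each store gets a corresponding value-based debug record:
//   call void @llvm.dbg.value(metadata i32 %v, metadata !var, metadata !DIExpression())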
- TinyPtrVector<DbgVariableIntrinsic *> DIIs; + SmallVector<DbgVariableIntrinsic *, 8> DVIs; std::unique_ptr<DIBuilder> DIB; if (isa<AllocaInst>(MI)) { - DIIs = FindDbgAddrUses(&MI); + findDbgUsers(DVIs, &MI); DIB.reset(new DIBuilder(*MI.getModule(), /*AllowUnresolved=*/false)); } @@ -2563,8 +2653,9 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { ConstantInt::get(Type::getInt1Ty(C->getContext()), C->isFalseWhenEqual())); } else if (auto *SI = dyn_cast<StoreInst>(I)) { - for (auto *DII : DIIs) - ConvertDebugDeclareToDebugValue(DII, SI, *DIB); + for (auto *DVI : DVIs) + if (DVI->isAddressOfVariable()) + ConvertDebugDeclareToDebugValue(DVI, SI, *DIB); } else { // Casts, GEP, or anything else: we're about to delete this instruction, // so it can not have any valid uses. @@ -2581,8 +2672,31 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { None, "", II->getParent()); } - for (auto *DII : DIIs) - eraseInstFromFunction(*DII); + // Remove debug intrinsics which describe the value contained within the + // alloca. In addition to removing dbg.{declare,addr} which simply point to + // the alloca, remove dbg.value(<alloca>, ..., DW_OP_deref)'s as well, e.g.: + // + // ``` + // define void @foo(i32 %0) { + // %a = alloca i32 ; Deleted. + // store i32 %0, i32* %a + // dbg.value(i32 %0, "arg0") ; Not deleted. + // dbg.value(i32* %a, "arg0", DW_OP_deref) ; Deleted. + // call void @trivially_inlinable_no_op(i32* %a) + // ret void + // } + // ``` + // + // This may not be required if we stop describing the contents of allocas + // using dbg.value(<alloca>, ..., DW_OP_deref), but we currently do this in + // the LowerDbgDeclare utility. + // + // If there is a dead store to `%a` in @trivially_inlinable_no_op, the + // "arg0" dbg.value may be stale after the call. However, failing to remove + // the DW_OP_deref dbg.value causes large gaps in location coverage. + for (auto *DVI : DVIs) + if (DVI->isAddressOfVariable() || DVI->getExpression()->startsWithDeref()) + DVI->eraseFromParent(); return eraseInstFromFunction(MI); } @@ -2670,7 +2784,7 @@ static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI, return &FI; } -Instruction *InstCombiner::visitFree(CallInst &FI) { +Instruction *InstCombinerImpl::visitFree(CallInst &FI) { Value *Op = FI.getArgOperand(0); // free undef -> unreachable. @@ -2711,7 +2825,7 @@ static bool isMustTailCall(Value *V) { return false; } -Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) { +Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) { if (RI.getNumOperands() == 0) // ret void return nullptr; @@ -2734,7 +2848,31 @@ Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) { return nullptr; } -Instruction *InstCombiner::visitUnconditionalBranchInst(BranchInst &BI) { +Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) { + // Try to remove the previous instruction if it must lead to unreachable. + // This includes instructions like stores and "llvm.assume" that may not get + // removed by simple dead code elimination. + Instruction *Prev = I.getPrevNonDebugInstruction(); + if (Prev && !Prev->isEHPad() && + isGuaranteedToTransferExecutionToSuccessor(Prev)) { + // Temporarily disable removal of volatile stores preceding unreachable, + // pending a potential LangRef change permitting volatile stores to trap. + // TODO: Either remove this code, or properly integrate the check into + // isGuaranteedToTransferExecutionToSuccessor(). 
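// A small example of what the surrounding visitUnreachableInst handling allows
// (hypothetical IR): in
//   store i32 %x, i32* %p
//   unreachable
// the non-volatile store still transfers execution to its successor, so it can
// be dropped, leaving only the unreachable.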
+    if (auto *SI = dyn_cast<StoreInst>(Prev))
+      if (SI->isVolatile())
+        return nullptr;
+
+    // A value may still have uses before we process it here (for example, in
+    // another unreachable block), so convert those to undef.
+    replaceInstUsesWith(*Prev, UndefValue::get(Prev->getType()));
+    eraseInstFromFunction(*Prev);
+    return &I;
+  }
+  return nullptr;
+}
+
+Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) {
   assert(BI.isUnconditional() && "Only for unconditional branches.");
   // If this store is the second-to-last instruction in the basic block
@@ -2763,7 +2901,7 @@ Instruction *InstCombiner::visitUnconditionalBranchInst(BranchInst &BI) {
   return nullptr;
 }
-Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
+Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
   if (BI.isUnconditional())
     return visitUnconditionalBranchInst(BI);
@@ -2799,7 +2937,7 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
   return nullptr;
 }
-Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
+Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
   Value *Cond = SI.getCondition();
   Value *Op0;
   ConstantInt *AddRHS;
@@ -2830,7 +2968,7 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
   unsigned NewWidth = Known.getBitWidth() -
                       std::max(LeadingKnownZeros, LeadingKnownOnes);
   // Shrink the condition operand if the new type is smaller than the old type.
-  // But do not shrink to a non-standard type, because backend can't generate 
+  // But do not shrink to a non-standard type, because backend can't generate
   // good code for that yet.
   // TODO: We can make it aggressive again after fixing PR39569.
   if (NewWidth > 0 && NewWidth < Known.getBitWidth() &&
@@ -2849,7 +2987,7 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
   return nullptr;
 }
-Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) {
+Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
   Value *Agg = EV.getAggregateOperand();
   if (!EV.hasIndices())
@@ -2994,10 +3132,11 @@ static bool isCatchAll(EHPersonality Personality, Constant *TypeInfo) {
   case EHPersonality::GNU_CXX_SjLj:
   case EHPersonality::GNU_ObjC:
   case EHPersonality::MSVC_X86SEH:
-  case EHPersonality::MSVC_Win64SEH:
+  case EHPersonality::MSVC_TableSEH:
   case EHPersonality::MSVC_CXX:
   case EHPersonality::CoreCLR:
   case EHPersonality::Wasm_CXX:
+  case EHPersonality::XL_CXX:
     return TypeInfo->isNullValue();
   }
   llvm_unreachable("invalid enum");
@@ -3010,7 +3149,7 @@ static bool shorter_filter(const Value *LHS, const Value *RHS) {
          cast<ArrayType>(RHS->getType())->getNumElements();
 }
-Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) {
+Instruction *InstCombinerImpl::visitLandingPadInst(LandingPadInst &LI) {
   // The logic here should be correct for any real-world personality function.
   // However if that turns out not to be true, the offending logic can always
   // be conditioned on the personality function, like the catch-all logic is.
@@ -3319,12 +3458,46 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) {
   return nullptr;
 }
-Instruction *InstCombiner::visitFreeze(FreezeInst &I) {
+Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) {
   Value *Op0 = I.getOperand(0);
   if (Value *V = SimplifyFreezeInst(Op0, SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
+  // freeze (phi const, x) --> phi const, (freeze x)
+  if (auto *PN = dyn_cast<PHINode>(Op0)) {
+    if (Instruction *NV = foldOpIntoPhi(I, PN))
+      return NV;
+  }
+
+  if (match(Op0, m_Undef())) {
+    // If I is freeze(undef), see its uses and fold it to the best constant.
+    // - or: pick -1
+    // - select's condition: pick the value that leads to choosing a constant
+    // - other ops: pick 0
+    Constant *BestValue = nullptr;
+    Constant *NullValue = Constant::getNullValue(I.getType());
+    for (const auto *U : I.users()) {
+      Constant *C = NullValue;
+
+      if (match(U, m_Or(m_Value(), m_Value())))
+        C = Constant::getAllOnesValue(I.getType());
+      else if (const auto *SI = dyn_cast<SelectInst>(U)) {
+        if (SI->getCondition() == &I) {
+          APInt CondVal(1, isa<Constant>(SI->getFalseValue()) ? 0 : 1);
+          C = Constant::getIntegerValue(I.getType(), CondVal);
+        }
+      }
+
+      if (!BestValue)
+        BestValue = C;
+      else if (BestValue != C)
+        BestValue = NullValue;
+    }
+
+    return replaceInstUsesWith(I, BestValue);
+  }
+
   return nullptr;
 }
@@ -3430,7 +3603,7 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) {
   return true;
 }
-bool InstCombiner::run() {
+bool InstCombinerImpl::run() {
   while (!Worklist.isEmpty()) {
     // Walk deferred instructions in reverse order, and push them to the
     // worklist, which means they'll end up popped from the worklist in-order.
@@ -3492,7 +3665,9 @@ bool InstCombiner::run() {
       else
         UserParent = UserInst->getParent();
-      if (UserParent != BB) {
+      // Try sinking to another block. If that block is unreachable, then do
+      // not bother. SimplifyCFG should handle it.
+      if (UserParent != BB && DT.isReachableFromEntry(UserParent)) {
        // See if the user is one of our successors that has only one
        // predecessor, so that we don't have to split the critical edge.
        bool ShouldSink = UserParent->getUniquePredecessor() == BB;
@@ -3526,7 +3701,8 @@ bool InstCombiner::run() {
     // Now that we have an instruction, try combining it to simplify it.
     Builder.SetInsertPoint(I);
-    Builder.SetCurrentDebugLocation(I->getDebugLoc());
+    Builder.CollectMetadataToCopy(
+        I, {LLVMContext::MD_dbg, LLVMContext::MD_annotation});
 #ifndef NDEBUG
     std::string OrigI;
@@ -3541,8 +3717,8 @@ bool InstCombiner::run() {
       LLVM_DEBUG(dbgs() << "IC: Old = " << *I << '\n'
                         << "    New = " << *Result << '\n');
-      if (I->getDebugLoc())
-        Result->setDebugLoc(I->getDebugLoc());
+      Result->copyMetadata(*I,
+                           {LLVMContext::MD_dbg, LLVMContext::MD_annotation});
       // Everything uses the new instruction now.
       I->replaceAllUsesWith(Result);
@@ -3553,10 +3729,14 @@ bool InstCombiner::run() {
       BasicBlock *InstParent = I->getParent();
       BasicBlock::iterator InsertPos = I->getIterator();
-      // If we replace a PHI with something that isn't a PHI, fix up the
-      // insertion point.
-      if (!isa<PHINode>(Result) && isa<PHINode>(InsertPos))
-        InsertPos = InstParent->getFirstInsertionPt();
+      // Are we replace a PHI with something that isn't a PHI, or vice versa?
+      if (isa<PHINode>(Result) != isa<PHINode>(I)) {
+        // We need to fix up the insertion point.
+        if (isa<PHINode>(I)) // PHI -> Non-PHI
+          InsertPos = InstParent->getFirstInsertionPt();
+        else // Non-PHI -> PHI
+          InsertPos = InstParent->getFirstNonPHI()->getIterator();
+      }
       InstParent->getInstList().insert(InsertPos, Result);
@@ -3586,6 +3766,55 @@ bool InstCombiner::run() {
   return MadeIRChange;
 }
+// Track the scopes used by !alias.scope and !noalias. In a function, a
+// @llvm.experimental.noalias.scope.decl is only useful if that scope is used
+// by both sets. If not, the declaration of the scope can be safely omitted.
+// The MDNode of the scope can be omitted as well for the instructions that are
+// part of this function. We do not do that at this point, as this might become
+// too time consuming to do.
+class AliasScopeTracker {
+  SmallPtrSet<const MDNode *, 8> UsedAliasScopesAndLists;
+  SmallPtrSet<const MDNode *, 8> UsedNoAliasScopesAndLists;
+
+public:
+  void analyse(Instruction *I) {
+    // This seems to be faster than checking 'mayReadOrWriteMemory()'.
+    if (!I->hasMetadataOtherThanDebugLoc())
+      return;
+
+    auto Track = [](Metadata *ScopeList, auto &Container) {
+      const auto *MDScopeList = dyn_cast_or_null<MDNode>(ScopeList);
+      if (!MDScopeList || !Container.insert(MDScopeList).second)
+        return;
+      for (auto &MDOperand : MDScopeList->operands())
+        if (auto *MDScope = dyn_cast<MDNode>(MDOperand))
+          Container.insert(MDScope);
+    };
+
+    Track(I->getMetadata(LLVMContext::MD_alias_scope), UsedAliasScopesAndLists);
+    Track(I->getMetadata(LLVMContext::MD_noalias), UsedNoAliasScopesAndLists);
+  }
+
+  bool isNoAliasScopeDeclDead(Instruction *Inst) {
+    NoAliasScopeDeclInst *Decl = dyn_cast<NoAliasScopeDeclInst>(Inst);
+    if (!Decl)
+      return false;
+
+    assert(Decl->use_empty() &&
+           "llvm.experimental.noalias.scope.decl in use ?");
+    const MDNode *MDSL = Decl->getScopeList();
+    assert(MDSL->getNumOperands() == 1 &&
+           "llvm.experimental.noalias.scope should refer to a single scope");
+    auto &MDOperand = MDSL->getOperand(0);
+    if (auto *MD = dyn_cast<MDNode>(MDOperand))
+      return !UsedAliasScopesAndLists.contains(MD) ||
+             !UsedNoAliasScopesAndLists.contains(MD);
+
+    // Not an MDNode ? throw away.
+    return true;
+  }
+};
+
 /// Populate the IC worklist from a function, by walking it in depth-first
 /// order and adding all reachable code to the worklist.
 ///
@@ -3604,6 +3833,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
   SmallVector<Instruction*, 128> InstrsForInstCombineWorklist;
   DenseMap<Constant *, Constant *> FoldedConstants;
+  AliasScopeTracker SeenAliasScopes;
   do {
     BasicBlock *BB = Worklist.pop_back_val();
@@ -3650,8 +3880,10 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
       // Skip processing debug intrinsics in InstCombine. Processing these call instructions
       // consumes non-trivial amount of time and provides no value for the optimization.
-      if (!isa<DbgInfoIntrinsic>(Inst))
+      if (!isa<DbgInfoIntrinsic>(Inst)) {
         InstrsForInstCombineWorklist.push_back(Inst);
+        SeenAliasScopes.analyse(Inst);
+      }
     }
     // Recursively visit successors. If this is a branch or switch on a
@@ -3671,8 +3903,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
       }
     }
-    for (BasicBlock *SuccBB : successors(TI))
-      Worklist.push_back(SuccBB);
+    append_range(Worklist, successors(TI));
   } while (!Worklist.empty());
   // Remove instructions inside unreachable blocks. This prevents the
@@ -3682,8 +3913,12 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
     if (Visited.count(&BB))
       continue;
-    unsigned NumDeadInstInBB = removeAllNonTerminatorAndEHPadInstructions(&BB);
-    MadeIRChange |= NumDeadInstInBB > 0;
+    unsigned NumDeadInstInBB;
+    unsigned NumDeadDbgInstInBB;
+    std::tie(NumDeadInstInBB, NumDeadDbgInstInBB) =
+        removeAllNonTerminatorAndEHPadInstructions(&BB);
+
+    MadeIRChange |= NumDeadInstInBB + NumDeadDbgInstInBB > 0;
     NumDeadInst += NumDeadInstInBB;
   }
@@ -3696,7 +3931,8 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
   for (Instruction *Inst : reverse(InstrsForInstCombineWorklist)) {
     // DCE instruction if trivially dead. As we iterate in reverse program
     // order here, we will clean up whole chains of dead instructions.
-    if (isInstructionTriviallyDead(Inst, TLI)) {
+    if (isInstructionTriviallyDead(Inst, TLI) ||
+        SeenAliasScopes.isNoAliasScopeDeclDead(Inst)) {
       ++NumDeadInst;
       LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n');
       salvageDebugInfo(*Inst);
@@ -3713,8 +3949,8 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
 static bool combineInstructionsOverFunction(
     Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA,
-    AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT,
-    OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
+    AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
+    DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
     ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) {
   auto &DL = F.getParent()->getDataLayout();
   MaxIterations = std::min(MaxIterations, LimitMaxIterations.getValue());
@@ -3738,6 +3974,7 @@ static bool combineInstructionsOverFunction(
   // Iterate while there is work to do.
   unsigned Iteration = 0;
   while (true) {
+    ++NumWorklistIterations;
     ++Iteration;
     if (Iteration > InfiniteLoopDetectionThreshold) {
@@ -3758,8 +3995,8 @@ static bool combineInstructionsOverFunction(
     MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist);
-    InstCombiner IC(Worklist, Builder, F.hasMinSize(), AA,
-                    AC, TLI, DT, ORE, BFI, PSI, DL, LI);
+    InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT,
+                        ORE, BFI, PSI, DL, LI);
     IC.MaxArraySizeForCombine = MaxArraySize;
     if (!IC.run())
@@ -3782,6 +4019,7 @@ PreservedAnalyses InstCombinePass::run(Function &F,
   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
   auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
   auto *LI = AM.getCachedResult<LoopAnalysis>(F);
@@ -3792,8 +4030,8 @@ PreservedAnalyses InstCombinePass::run(Function &F,
   auto *BFI = (PSI && PSI->hasProfileSummary()) ?
       &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr;
-  if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, BFI,
-                                       PSI, MaxIterations, LI))
+  if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
+                                       BFI, PSI, MaxIterations, LI))
     // No changes, all analyses are preserved.
     return PreservedAnalyses::all();
@@ -3811,6 +4049,7 @@ void InstructionCombiningPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequired<AAResultsWrapperPass>();
   AU.addRequired<AssumptionCacheTracker>();
   AU.addRequired<TargetLibraryInfoWrapperPass>();
+  AU.addRequired<TargetTransformInfoWrapperPass>();
   AU.addRequired<DominatorTreeWrapperPass>();
   AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
   AU.addPreserved<DominatorTreeWrapperPass>();
@@ -3829,6 +4068,7 @@ bool InstructionCombiningPass::runOnFunction(Function &F) {
   auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
@@ -3842,8 +4082,8 @@ bool InstructionCombiningPass::runOnFunction(Function &F) {
       &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() :
       nullptr;
-  return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, BFI,
-                                         PSI, MaxIterations, LI);
+  return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
+                                         BFI, PSI, MaxIterations, LI);
 }
 char InstructionCombiningPass::ID = 0;
@@ -3862,6 +4102,7 @@ INITIALIZE_PASS_BEGIN(InstructionCombiningPass, "instcombine",
                       "Combine redundant instructions", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
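To make the new visitFreeze handling above concrete, here is a small, hypothetical LLVM IR example; the function name @freeze_or_example and the expected result are illustrative assumptions, not part of the imported diff. A freeze of undef whose only user is an 'or' is folded to the all-ones constant, after which the 'or' itself simplifies:

    ; Illustrative input, not taken from the commit.
    define i32 @freeze_or_example(i32 %x) {
      ; %f has a single user, an 'or', so the fold above is expected to
      ; choose BestValue = -1 and replace the freeze with it.
      %f = freeze i32 undef
      %o = or i32 %f, %x    ; 'or i32 -1, %x' then folds to -1
      ret i32 %o
    }

Under those assumptions, running 'opt -instcombine' from a build containing this change would be expected to reduce the body to 'ret i32 -1'.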