about | summary | refs | log | tree | commit | diff
path: root/lib/Target/X86/X86ISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r--  lib/Target/X86/X86ISelLowering.cpp  |  317
1 files changed, 233 insertions, 84 deletions
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 7dcdb7967058..2820004cfc6d 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -1800,17 +1800,19 @@ X86TargetLowering::getPreferredVectorAction(EVT VT) const {
}
MVT X86TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
+ CallingConv::ID CC,
EVT VT) const {
if (VT == MVT::v32i1 && Subtarget.hasAVX512() && !Subtarget.hasBWI())
return MVT::v32i8;
- return TargetLowering::getRegisterTypeForCallingConv(Context, VT);
+ return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
}
unsigned X86TargetLowering::getNumRegistersForCallingConv(LLVMContext &Context,
+ CallingConv::ID CC,
EVT VT) const {
if (VT == MVT::v32i1 && Subtarget.hasAVX512() && !Subtarget.hasBWI())
return 1;
- return TargetLowering::getNumRegistersForCallingConv(Context, VT);
+ return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
}
EVT X86TargetLowering::getSetCCResultType(const DataLayout &DL,
@@ -23366,7 +23368,7 @@ static SDValue convertShiftLeftToScale(SDValue Amt, const SDLoc &dl,
return DAG.getBuildVector(VT, dl, Elts);
}
- // If the target doesn't support variable shifts, use either FP conversion
+ // If the target doesn't support variable shifts, use either FP conversion
// or integer multiplication to avoid shifting each element individually.
if (VT == MVT::v4i32) {
Amt = DAG.getNode(ISD::SHL, dl, VT, Amt, DAG.getConstant(23, dl, VT));
@@ -23509,6 +23511,24 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget &Subtarget,
if (SDValue Scale = convertShiftLeftToScale(Amt, dl, Subtarget, DAG))
return DAG.getNode(ISD::MUL, dl, VT, R, Scale);
+ // Constant ISD::SRL can be performed efficiently on vXi8/vXi16 vectors as we
+ // can replace with ISD::MULHU, creating scale factor from (NumEltBits - Amt).
+ // TODO: Improve support for the shift by zero special case.
+ if (Op.getOpcode() == ISD::SRL && ConstantAmt &&
+ ((Subtarget.hasSSE41() && VT == MVT::v8i16) ||
+ DAG.isKnownNeverZero(Amt)) &&
+ (VT == MVT::v16i8 || VT == MVT::v8i16 ||
+ ((VT == MVT::v32i8 || VT == MVT::v16i16) && Subtarget.hasInt256()))) {
+ SDValue EltBits = DAG.getConstant(VT.getScalarSizeInBits(), dl, VT);
+ SDValue RAmt = DAG.getNode(ISD::SUB, dl, VT, EltBits, Amt);
+ if (SDValue Scale = convertShiftLeftToScale(RAmt, dl, Subtarget, DAG)) {
+ SDValue Zero = DAG.getConstant(0, dl, VT);
+ SDValue ZAmt = DAG.getSetCC(dl, VT, Amt, Zero, ISD::SETEQ);
+ SDValue Res = DAG.getNode(ISD::MULHU, dl, VT, R, Scale);
+ return DAG.getSelect(dl, VT, ZAmt, R, Res);
+ }
+ }
+
// v4i32 Non Uniform Shifts.
// If the shift amount is constant we can shift each lane using the SSE2
// immediate shifts, else we need to zero-extend each lane to the lower i64
@@ -33425,33 +33445,32 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG,
}
}
- // Handle (CMOV C-1, (ADD (CTTZ X), C), (X != 0)) ->
- // (ADD (CMOV (CTTZ X), -1, (X != 0)), C) or
- // (CMOV (ADD (CTTZ X), C), C-1, (X == 0)) ->
- // (ADD (CMOV C-1, (CTTZ X), (X == 0)), C)
- if (CC == X86::COND_NE || CC == X86::COND_E) {
- auto *Cnst = CC == X86::COND_E ? dyn_cast<ConstantSDNode>(TrueOp)
- : dyn_cast<ConstantSDNode>(FalseOp);
- SDValue Add = CC == X86::COND_E ? FalseOp : TrueOp;
-
- if (Cnst && Add.getOpcode() == ISD::ADD && Add.hasOneUse()) {
- auto *AddOp1 = dyn_cast<ConstantSDNode>(Add.getOperand(1));
- SDValue AddOp2 = Add.getOperand(0);
- if (AddOp1 && (AddOp2.getOpcode() == ISD::CTTZ_ZERO_UNDEF ||
- AddOp2.getOpcode() == ISD::CTTZ)) {
- APInt Diff = Cnst->getAPIntValue() - AddOp1->getAPIntValue();
- if (CC == X86::COND_E) {
- Add = DAG.getNode(X86ISD::CMOV, DL, Add.getValueType(), AddOp2,
- DAG.getConstant(Diff, DL, Add.getValueType()),
- DAG.getConstant(CC, DL, MVT::i8), Cond);
- } else {
- Add = DAG.getNode(X86ISD::CMOV, DL, Add.getValueType(),
- DAG.getConstant(Diff, DL, Add.getValueType()),
- AddOp2, DAG.getConstant(CC, DL, MVT::i8), Cond);
- }
- return DAG.getNode(X86ISD::ADD, DL, Add.getValueType(), Add,
- SDValue(AddOp1, 0));
- }
+ // Fold (CMOV C1, (ADD (CTTZ X), C2), (X != 0)) ->
+ // (ADD (CMOV C1-C2, (CTTZ X), (X != 0)), C2)
+ // Or (CMOV (ADD (CTTZ X), C2), C1, (X == 0)) ->
+ // (ADD (CMOV (CTTZ X), C1-C2, (X == 0)), C2)
+ if ((CC == X86::COND_NE || CC == X86::COND_E) &&
+ Cond.getOpcode() == X86ISD::CMP && isNullConstant(Cond.getOperand(1))) {
+ SDValue Add = TrueOp;
+ SDValue Const = FalseOp;
+ // Canonicalize the condition code for easier matching and output.
+ if (CC == X86::COND_E) {
+ std::swap(Add, Const);
+ CC = X86::COND_NE;
+ }
+
+ // Ok, now make sure that Add is (add (cttz X), C2) and Const is a constant.
+ if (isa<ConstantSDNode>(Const) && Add.getOpcode() == ISD::ADD &&
+ Add.hasOneUse() && isa<ConstantSDNode>(Add.getOperand(1)) &&
+ (Add.getOperand(0).getOpcode() == ISD::CTTZ_ZERO_UNDEF ||
+ Add.getOperand(0).getOpcode() == ISD::CTTZ) &&
+ Add.getOperand(0).getOperand(0) == Cond.getOperand(0)) {
+ EVT VT = N->getValueType(0);
+ // This should constant fold.
+ SDValue Diff = DAG.getNode(ISD::SUB, DL, VT, Const, Add.getOperand(1));
+ SDValue CMov = DAG.getNode(X86ISD::CMOV, DL, VT, Diff, Add.getOperand(0),
+ DAG.getConstant(CC, DL, MVT::i8), Cond);
+ return DAG.getNode(ISD::ADD, DL, VT, CMov, Add.getOperand(1));
}
}
@@ -33873,31 +33892,42 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
ConstantSDNode *C = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (!C)
return SDValue();
- uint64_t MulAmt = C->getZExtValue();
- if (isPowerOf2_64(MulAmt))
+ if (isPowerOf2_64(C->getZExtValue()))
return SDValue();
+ int64_t SignMulAmt = C->getSExtValue();
+ assert(SignMulAmt != INT64_MIN && "Int min should have been handled!");
+ uint64_t AbsMulAmt = SignMulAmt < 0 ? -SignMulAmt : SignMulAmt;
+
SDLoc DL(N);
- if (MulAmt == 3 || MulAmt == 5 || MulAmt == 9)
- return DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
- N->getOperand(1));
+ if (AbsMulAmt == 3 || AbsMulAmt == 5 || AbsMulAmt == 9) {
+ SDValue NewMul = DAG.getNode(X86ISD::MUL_IMM, DL, VT, N->getOperand(0),
+ DAG.getConstant(AbsMulAmt, DL, VT));
+ if (SignMulAmt < 0)
+ NewMul = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
+ NewMul);
+
+ return NewMul;
+ }
uint64_t MulAmt1 = 0;
uint64_t MulAmt2 = 0;
- if ((MulAmt % 9) == 0) {
+ if ((AbsMulAmt % 9) == 0) {
MulAmt1 = 9;
- MulAmt2 = MulAmt / 9;
- } else if ((MulAmt % 5) == 0) {
+ MulAmt2 = AbsMulAmt / 9;
+ } else if ((AbsMulAmt % 5) == 0) {
MulAmt1 = 5;
- MulAmt2 = MulAmt / 5;
- } else if ((MulAmt % 3) == 0) {
+ MulAmt2 = AbsMulAmt / 5;
+ } else if ((AbsMulAmt % 3) == 0) {
MulAmt1 = 3;
- MulAmt2 = MulAmt / 3;
+ MulAmt2 = AbsMulAmt / 3;
}
SDValue NewMul;
+ // For negative multiply amounts, only allow MulAmt2 to be a power of 2.
if (MulAmt2 &&
- (isPowerOf2_64(MulAmt2) || MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9)){
+ (isPowerOf2_64(MulAmt2) ||
+ (SignMulAmt >= 0 && (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9)))) {
if (isPowerOf2_64(MulAmt2) &&
!(N->hasOneUse() && N->use_begin()->getOpcode() == ISD::ADD))
@@ -33919,17 +33949,19 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
else
NewMul = DAG.getNode(X86ISD::MUL_IMM, DL, VT, NewMul,
DAG.getConstant(MulAmt2, DL, VT));
+
+ // Negate the result.
+ if (SignMulAmt < 0)
+ NewMul = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
+ NewMul);
} else if (!Subtarget.slowLEA())
- NewMul = combineMulSpecial(MulAmt, N, DAG, VT, DL);
+ NewMul = combineMulSpecial(C->getZExtValue(), N, DAG, VT, DL);
if (!NewMul) {
- assert(MulAmt != 0 &&
- MulAmt != (VT == MVT::i64 ? UINT64_MAX : UINT32_MAX) &&
+ assert(C->getZExtValue() != 0 &&
+ C->getZExtValue() != (VT == MVT::i64 ? UINT64_MAX : UINT32_MAX) &&
"Both cases that could cause potential overflows should have "
"already been handled.");
- int64_t SignMulAmt = C->getSExtValue();
- assert(SignMulAmt != INT64_MIN && "Int min should have been handled!");
- uint64_t AbsMulAmt = SignMulAmt < 0 ? -SignMulAmt : SignMulAmt;
if (isPowerOf2_64(AbsMulAmt - 1)) {
// (mul x, 2^N + 1) => (add (shl x, N), x)
NewMul = DAG.getNode(
@@ -36738,6 +36770,145 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
return DAG.getNode(Opc, DL, VT, LHS, RHS);
}
+// Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes
+// from one vector with signed bytes from another vector, adds together
+// adjacent pairs of 16-bit products, and saturates the result before
+// truncating to 16-bits.
+//
+// Which looks something like this:
+// (i16 (ssat (add (mul (zext (even elts (i8 A))), (sext (even elts (i8 B)))),
+// (mul (zext (odd elts (i8 A)), (sext (odd elts (i8 B))))))))
+static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget,
+ const SDLoc &DL) {
+ if (!VT.isVector() || !Subtarget.hasSSSE3())
+ return SDValue();
+
+ unsigned NumElems = VT.getVectorNumElements();
+ EVT ScalarVT = VT.getVectorElementType();
+ if (ScalarVT != MVT::i16 || NumElems < 8 || !isPowerOf2_32(NumElems))
+ return SDValue();
+
+ SDValue SSatVal = detectSSatPattern(In, VT);
+ if (!SSatVal || SSatVal.getOpcode() != ISD::ADD)
+ return SDValue();
+
+ // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs
+ // of multiplies from even/odd elements.
+ SDValue N0 = SSatVal.getOperand(0);
+ SDValue N1 = SSatVal.getOperand(1);
+
+ if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
+ return SDValue();
+
+ SDValue N00 = N0.getOperand(0);
+ SDValue N01 = N0.getOperand(1);
+ SDValue N10 = N1.getOperand(0);
+ SDValue N11 = N1.getOperand(1);
+
+ // TODO: Handle constant vectors and use knownbits/computenumsignbits?
+ // Canonicalize zero_extend to LHS.
+ if (N01.getOpcode() == ISD::ZERO_EXTEND)
+ std::swap(N00, N01);
+ if (N11.getOpcode() == ISD::ZERO_EXTEND)
+ std::swap(N10, N11);
+
+ // Ensure we have a zero_extend and a sign_extend.
+ if (N00.getOpcode() != ISD::ZERO_EXTEND ||
+ N01.getOpcode() != ISD::SIGN_EXTEND ||
+ N10.getOpcode() != ISD::ZERO_EXTEND ||
+ N11.getOpcode() != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ // Peek through the extends.
+ N00 = N00.getOperand(0);
+ N01 = N01.getOperand(0);
+ N10 = N10.getOperand(0);
+ N11 = N11.getOperand(0);
+
+ // Ensure the extend is from vXi8.
+ if (N00.getValueType().getVectorElementType() != MVT::i8 ||
+ N01.getValueType().getVectorElementType() != MVT::i8 ||
+ N10.getValueType().getVectorElementType() != MVT::i8 ||
+ N11.getValueType().getVectorElementType() != MVT::i8)
+ return SDValue();
+
+ // All inputs should be build_vectors.
+ if (N00.getOpcode() != ISD::BUILD_VECTOR ||
+ N01.getOpcode() != ISD::BUILD_VECTOR ||
+ N10.getOpcode() != ISD::BUILD_VECTOR ||
+ N11.getOpcode() != ISD::BUILD_VECTOR)
+ return SDValue();
+
+ // N00/N10 are zero extended. N01/N11 are sign extended.
+
+ // For each element, we need to ensure we have an odd element from one vector
+ // multiplied by the odd element of another vector and the even element from
+ // one of the same vectors being multiplied by the even element from the
+ // other vector. So we need to make sure for each element i, this operator
+ // is being performed:
+ // A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1]
+ SDValue ZExtIn, SExtIn;
+ for (unsigned i = 0; i != NumElems; ++i) {
+ SDValue N00Elt = N00.getOperand(i);
+ SDValue N01Elt = N01.getOperand(i);
+ SDValue N10Elt = N10.getOperand(i);
+ SDValue N11Elt = N11.getOperand(i);
+ // TODO: Be more tolerant to undefs.
+ if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+ return SDValue();
+ auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
+ auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
+ auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
+ auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
+ if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
+ return SDValue();
+ unsigned IdxN00 = ConstN00Elt->getZExtValue();
+ unsigned IdxN01 = ConstN01Elt->getZExtValue();
+ unsigned IdxN10 = ConstN10Elt->getZExtValue();
+ unsigned IdxN11 = ConstN11Elt->getZExtValue();
+ // Add is commutative so indices can be reordered.
+ if (IdxN00 > IdxN10) {
+ std::swap(IdxN00, IdxN10);
+ std::swap(IdxN01, IdxN11);
+ }
+ // N0 indices must be the even element. N1 indices must be the next odd element.
+ if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
+ IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
+ return SDValue();
+ SDValue N00In = N00Elt.getOperand(0);
+ SDValue N01In = N01Elt.getOperand(0);
+ SDValue N10In = N10Elt.getOperand(0);
+ SDValue N11In = N11Elt.getOperand(0);
+ // First time we find an input capture it.
+ if (!ZExtIn) {
+ ZExtIn = N00In;
+ SExtIn = N01In;
+ }
+ if (ZExtIn != N00In || SExtIn != N01In ||
+ ZExtIn != N10In || SExtIn != N11In)
+ return SDValue();
+ }
+
+ auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
+ ArrayRef<SDValue> Ops) {
+ // Shrink by adding truncate nodes and let DAGCombine fold with the
+ // sources.
+ EVT InVT = Ops[0].getValueType();
+ assert(InVT.getScalarType() == MVT::i8 &&
+ "Unexpected scalar element type");
+ assert(InVT == Ops[1].getValueType() && "Operands' types mismatch");
+ EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
+ InVT.getVectorNumElements() / 2);
+ return DAG.getNode(X86ISD::VPMADDUBSW, DL, ResVT, Ops[0], Ops[1]);
+ };
+ return SplitOpsAndApply(DAG, Subtarget, DL, VT, { ZExtIn, SExtIn },
+ PMADDBuilder);
+}
+
static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
@@ -36752,6 +36923,10 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL))
return Avg;
+ // Try to detect PMADD
+ if (SDValue PMAdd = detectPMADDUBSW(Src, VT, DAG, Subtarget, DL))
+ return PMAdd;
+
// Try to combine truncation with signed/unsigned saturation.
if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget))
return Val;
@@ -36793,38 +36968,14 @@ static SDValue isFNEG(SDNode *N) {
if (!Op1.getValueType().isFloatingPoint())
return SDValue();
- SDValue Op0 = peekThroughBitcasts(Op.getOperand(0));
-
- unsigned EltBits = Op1.getScalarValueSizeInBits();
- auto isSignMask = [&](const ConstantFP *C) {
- return C->getValueAPF().bitcastToAPInt() == APInt::getSignMask(EltBits);
- };
-
- // There is more than one way to represent the same constant on
- // the different X86 targets. The type of the node may also depend on size.
- // - load scalar value and broadcast
- // - BUILD_VECTOR node
- // - load from a constant pool.
- // We check all variants here.
- if (Op1.getOpcode() == X86ISD::VBROADCAST) {
- if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
- if (isSignMask(cast<ConstantFP>(C)))
- return Op0;
-
- } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
- if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
- if (isSignMask(CN->getConstantFPValue()))
- return Op0;
+ // Extract constant bits and see if they are all sign bit masks.
+ APInt UndefElts;
+ SmallVector<APInt, 16> EltBits;
+ if (getTargetConstantBitsFromNode(Op1, Op1.getScalarValueSizeInBits(),
+ UndefElts, EltBits, false, false))
+ if (llvm::all_of(EltBits, [](APInt &I) { return I.isSignMask(); }))
+ return peekThroughBitcasts(Op.getOperand(0));
- } else if (auto *C = getTargetConstantFromNode(Op1)) {
- if (C->getType()->isVectorTy()) {
- if (auto *SplatV = C->getSplatValue())
- if (isSignMask(cast<ConstantFP>(SplatV)))
- return Op0;
- } else if (auto *FPConst = dyn_cast<ConstantFP>(C))
- if (isSignMask(FPConst))
- return Op0;
- }
return SDValue();
}
@@ -37777,8 +37928,7 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
// Look through extract_vector_elts. If it comes from an FNEG, create a
// new extract from the FNEG input.
if (V.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
- isa<ConstantSDNode>(V.getOperand(1)) &&
- cast<ConstantSDNode>(V.getOperand(1))->getZExtValue() == 0) {
+ isNullConstant(V.getOperand(1))) {
if (SDValue NegVal = isFNEG(V.getOperand(0).getNode())) {
NegVal = DAG.getBitcast(V.getOperand(0).getValueType(), NegVal);
V = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(V), V.getValueType(),
@@ -38896,7 +39046,7 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
std::swap(IdxN00, IdxN10);
std::swap(IdxN01, IdxN11);
}
- // N0 indices be the even elemtn. N1 indices must be the next odd element.
+ // N0 indices must be the even element. N1 indices must be the next odd element.
if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
return SDValue();
@@ -39322,8 +39472,7 @@ static SDValue combineInsertSubvector(SDNode *N, SelectionDAG &DAG,
if ((IdxVal == OpVT.getVectorNumElements() / 2) &&
Vec.getOpcode() == ISD::INSERT_SUBVECTOR &&
OpVT.getSizeInBits() == SubVecVT.getSizeInBits() * 2) {
- auto *Idx2 = dyn_cast<ConstantSDNode>(Vec.getOperand(2));
- if (Idx2 && Idx2->getZExtValue() == 0) {
+ if (isNullConstant(Vec.getOperand(2))) {
SDValue SubVec2 = Vec.getOperand(1);
// If needed, look through bitcasts to get to the load.
if (auto *FirstLd = dyn_cast<LoadSDNode>(peekThroughBitcasts(SubVec2))) {