Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 190
1 file changed, 154 insertions(+), 36 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e141179fb5c8..a26bbc77f248 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -962,6 +962,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setMinFunctionAlignment(Align(4));
// Set preferred alignments.
setPrefLoopAlignment(Align(1ULL << STI.getPrefLoopLogAlignment()));
+ setMaxBytesForAlignment(STI.getMaxBytesForLoopAlignment());
setPrefFunctionAlignment(Align(1ULL << STI.getPrefFunctionLogAlignment()));
// Only change the limit for entries in a jump table if specified by
@@ -1205,6 +1206,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SRL, VT, Custom);
setOperationAction(ISD::SRA, VT, Custom);
setOperationAction(ISD::ABS, VT, Custom);
+ setOperationAction(ISD::ABDS, VT, Custom);
+ setOperationAction(ISD::ABDU, VT, Custom);
setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
@@ -1245,6 +1248,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SELECT_CC, VT, Expand);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
+ setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
// There are no legal MVT::nxv16f## based types.
if (VT != MVT::nxv16i1) {
@@ -1831,6 +1835,28 @@ void AArch64TargetLowering::computeKnownBitsForTargetNode(
Known = KnownBits::commonBits(Known, Known2);
break;
}
+ case AArch64ISD::BICi: {
+ // Compute the bit cleared value.
+ uint64_t Mask =
+ ~(Op->getConstantOperandVal(1) << Op->getConstantOperandVal(2));
+ Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
+ Known &= KnownBits::makeConstant(APInt(Known.getBitWidth(), Mask));
+ break;
+ }
+ case AArch64ISD::VLSHR: {
+ KnownBits Known2;
+ Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
+ Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
+ Known = KnownBits::lshr(Known, Known2);
+ break;
+ }
+ case AArch64ISD::VASHR: {
+ KnownBits Known2;
+ Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
+ Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
+ Known = KnownBits::ashr(Known, Known2);
+ break;
+ }
case AArch64ISD::LOADgot:
case AArch64ISD::ADDlow: {
if (!Subtarget->isTargetILP32())
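
For the new BICi case above: the node AND's each lane with ~(Imm << Shift), so every bit in that mask becomes a known zero in the result (the VLSHR/VASHR cases simply defer to KnownBits::lshr/ashr). A minimal standalone sketch of that bookkeeping, using a plain struct in place of llvm::KnownBits; the struct and helper below are illustrative only, not part of this patch.

#include <cassert>
#include <cstdint>

struct SimpleKnown {
  uint64_t Zero = 0; // bits known to be 0
  uint64_t One = 0;  // bits known to be 1
};

static SimpleKnown applyBICi(SimpleKnown Op0, uint64_t Imm, uint64_t Shift) {
  uint64_t Cleared = Imm << Shift;   // each lane is AND'ed with ~(Imm << Shift)
  SimpleKnown R;
  R.One = Op0.One & ~Cleared;        // a known 1 survives only outside the mask
  R.Zero = Op0.Zero | Cleared;       // every masked bit becomes a known 0
  return R;
}

int main() {
  SimpleKnown In;                              // nothing known about the input
  SimpleKnown Out = applyBICi(In, 0xFF, 8);
  assert(Out.Zero == 0xFF00u && Out.One == 0); // bits 8..15 are now known zero
  return 0;
}
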
@@ -1971,6 +1997,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::CSINC)
MAKE_CASE(AArch64ISD::THREAD_POINTER)
MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ)
+ MAKE_CASE(AArch64ISD::ABDS_PRED)
+ MAKE_CASE(AArch64ISD::ABDU_PRED)
MAKE_CASE(AArch64ISD::ADD_PRED)
MAKE_CASE(AArch64ISD::MUL_PRED)
MAKE_CASE(AArch64ISD::MULHS_PRED)
@@ -2173,6 +2201,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::INSR)
MAKE_CASE(AArch64ISD::PTEST)
MAKE_CASE(AArch64ISD::PTRUE)
+ MAKE_CASE(AArch64ISD::PFALSE)
MAKE_CASE(AArch64ISD::LD1_MERGE_ZERO)
MAKE_CASE(AArch64ISD::LD1S_MERGE_ZERO)
MAKE_CASE(AArch64ISD::LDNF1_MERGE_ZERO)
@@ -5173,6 +5202,10 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerFixedLengthVectorSelectToSVE(Op, DAG);
case ISD::ABS:
return LowerABS(Op, DAG);
+ case ISD::ABDS:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABDS_PRED);
+ case ISD::ABDU:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABDU_PRED);
case ISD::BITREVERSE:
return LowerBitreverse(Op, DAG);
case ISD::BSWAP:
@@ -5380,7 +5413,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
llvm_unreachable("RegVT not supported by FORMAL_ARGUMENTS Lowering");
// Transform the arguments in physical registers into virtual ones.
- unsigned Reg = MF.addLiveIn(VA.getLocReg(), RC);
+ Register Reg = MF.addLiveIn(VA.getLocReg(), RC);
ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, RegVT);
// If this is an 8, 16 or 32-bit value, it is really passed promoted
@@ -5542,7 +5575,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
// Conservatively forward X8, since it might be used for aggregate return.
if (!CCInfo.isAllocated(AArch64::X8)) {
- unsigned X8VReg = MF.addLiveIn(AArch64::X8, &AArch64::GPR64RegClass);
+ Register X8VReg = MF.addLiveIn(AArch64::X8, &AArch64::GPR64RegClass);
Forwards.push_back(ForwardedRegister(X8VReg, AArch64::X8, MVT::i64));
}
}
@@ -5626,7 +5659,7 @@ void AArch64TargetLowering::saveVarArgRegisters(CCState &CCInfo,
SDValue FIN = DAG.getFrameIndex(GPRIdx, PtrVT);
for (unsigned i = FirstVariadicGPR; i < NumGPRArgRegs; ++i) {
- unsigned VReg = MF.addLiveIn(GPRArgRegs[i], &AArch64::GPR64RegClass);
+ Register VReg = MF.addLiveIn(GPRArgRegs[i], &AArch64::GPR64RegClass);
SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::i64);
SDValue Store =
DAG.getStore(Val.getValue(1), DL, Val, FIN,
@@ -5656,7 +5689,7 @@ void AArch64TargetLowering::saveVarArgRegisters(CCState &CCInfo,
SDValue FIN = DAG.getFrameIndex(FPRIdx, PtrVT);
for (unsigned i = FirstVariadicFPR; i < NumFPRArgRegs; ++i) {
- unsigned VReg = MF.addLiveIn(FPRArgRegs[i], &AArch64::FPR128RegClass);
+ Register VReg = MF.addLiveIn(FPRArgRegs[i], &AArch64::FPR128RegClass);
SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::f128);
SDValue Store = DAG.getStore(Val.getValue(1), DL, Val, FIN,
@@ -7256,6 +7289,9 @@ SDValue AArch64TargetLowering::LowerFCOPYSIGN(SDValue Op,
return getSVESafeBitCast(VT, IntResult, DAG);
}
+ if (!Subtarget->hasNEON())
+ return SDValue();
+
if (SrcVT.bitsLT(VT))
In2 = DAG.getNode(ISD::FP_EXTEND, DL, VT, In2);
else if (SrcVT.bitsGT(VT))
@@ -7795,10 +7831,37 @@ SDValue AArch64TargetLowering::LowerVECTOR_SPLICE(SDValue Op,
SelectionDAG &DAG) const {
EVT Ty = Op.getValueType();
auto Idx = Op.getConstantOperandAPInt(2);
+ int64_t IdxVal = Idx.getSExtValue();
+ assert(Ty.isScalableVector() &&
+ "Only expect scalable vectors for custom lowering of VECTOR_SPLICE");
+
+ // We can use the splice instruction for certain index values where we are
+ // able to efficiently generate the correct predicate. The index will be
+ // inverted and used directly as the input to the ptrue instruction, i.e.
+ // -1 -> vl1, -2 -> vl2, etc. The predicate will then be reversed to get the
+ // splice predicate. However, we can only do this if we can guarantee that
+ // there are enough elements in the vector, hence we check the index <= min
+ // number of elements.
+ Optional<unsigned> PredPattern;
+ if (Ty.isScalableVector() && IdxVal < 0 &&
+ (PredPattern = getSVEPredPatternFromNumElements(std::abs(IdxVal))) !=
+ None) {
+ SDLoc DL(Op);
+
+ // Create a predicate where all but the last -IdxVal elements are false.
+ EVT PredVT = Ty.changeVectorElementType(MVT::i1);
+ SDValue Pred = getPTrue(DAG, DL, PredVT, *PredPattern);
+ Pred = DAG.getNode(ISD::VECTOR_REVERSE, DL, PredVT, Pred);
+
+ // Now splice the two inputs together using the predicate.
+ return DAG.getNode(AArch64ISD::SPLICE, DL, Ty, Pred, Op.getOperand(0),
+ Op.getOperand(1));
+ }
// This will select to an EXT instruction, which has a maximum immediate
// value of 255, hence 2048-bits is the maximum value we can lower.
- if (Idx.sge(-1) && Idx.slt(2048 / Ty.getVectorElementType().getSizeInBits()))
+ if (IdxVal >= 0 &&
+ IdxVal < int64_t(2048 / Ty.getVectorElementType().getSizeInBits()))
return Op;
return SDValue();
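
For the VECTOR_SPLICE lowering above, the value-level semantics being targeted: with a negative index -N, the result is the last N elements of the first operand followed by the leading elements of the second, which is why a reversed ptrue vlN makes a suitable splice predicate. A small standalone sketch of that semantics; spliceLastN is an illustrative helper, not LLVM API.

#include <cassert>
#include <cstddef>
#include <vector>

template <typename T>
std::vector<T> spliceLastN(const std::vector<T> &Vec0,
                           const std::vector<T> &Vec1, std::size_t N) {
  assert(N <= Vec0.size() && "need at least N trailing elements in Vec0");
  std::vector<T> Result(Vec0.end() - N, Vec0.end());   // last N lanes of Vec0
  Result.insert(Result.end(), Vec1.begin(),
                Vec1.begin() + (Vec1.size() - N));      // leading lanes of Vec1
  return Result;
}

int main() {
  std::vector<int> A{1, 2, 3, 4}, B{5, 6, 7, 8};
  // vector_splice(A, B, -2) keeps the last two lanes of A, then reads from B.
  assert((spliceLastN(A, B, 2) == std::vector<int>{3, 4, 5, 6}));
  return 0;
}
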
@@ -8227,7 +8290,7 @@ SDValue AArch64TargetLowering::LowerRETURNADDR(SDValue Op,
} else {
// Return LR, which contains the return address. Mark it an implicit
// live-in.
- unsigned Reg = MF.addLiveIn(AArch64::LR, &AArch64::GPR64RegClass);
+ Register Reg = MF.addLiveIn(AArch64::LR, &AArch64::GPR64RegClass);
ReturnAddress = DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, VT);
}
@@ -9631,14 +9694,12 @@ static SDValue constructDup(SDValue V, int Lane, SDLoc dl, EVT VT,
MVT CastVT;
if (getScaledOffsetDup(V, Lane, CastVT)) {
V = DAG.getBitcast(CastVT, V.getOperand(0).getOperand(0));
- } else if (V.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+ } else if (V.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+ V.getOperand(0).getValueType().is128BitVector()) {
// The lane is incremented by the index of the extract.
// Example: dup v2f32 (extract v4f32 X, 2), 1 --> dup v4f32 X, 3
- auto VecVT = V.getOperand(0).getValueType();
- if (VecVT.isFixedLengthVector() && VecVT.getFixedSizeInBits() <= 128) {
- Lane += V.getConstantOperandVal(1);
- V = V.getOperand(0);
- }
+ Lane += V.getConstantOperandVal(1);
+ V = V.getOperand(0);
} else if (V.getOpcode() == ISD::CONCAT_VECTORS) {
// The lane is decremented if we are splatting from the 2nd operand.
// Example: dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1
@@ -9925,7 +9986,7 @@ SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op,
// lowering code.
if (auto *ConstVal = dyn_cast<ConstantSDNode>(SplatVal)) {
if (ConstVal->isZero())
- return SDValue(DAG.getMachineNode(AArch64::PFALSE, dl, VT), 0);
+ return DAG.getNode(AArch64ISD::PFALSE, dl, VT);
if (ConstVal->isOne())
return getPTrue(DAG, dl, VT, AArch64SVEPredPattern::all);
}
@@ -10978,6 +11039,28 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
if (!isTypeLegal(VT))
return SDValue();
+ // Break down insert_subvector into simpler parts.
+ if (VT.getVectorElementType() == MVT::i1) {
+ unsigned NumElts = VT.getVectorMinNumElements();
+ EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
+
+ SDValue Lo, Hi;
+ Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
+ DAG.getVectorIdxConstant(0, DL));
+ Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
+ DAG.getVectorIdxConstant(NumElts / 2, DL));
+ if (Idx < (NumElts / 2)) {
+ SDValue NewLo = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Lo, Vec1,
+ DAG.getVectorIdxConstant(Idx, DL));
+ return DAG.getNode(AArch64ISD::UZP1, DL, VT, NewLo, Hi);
+ } else {
+ SDValue NewHi =
+ DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Hi, Vec1,
+ DAG.getVectorIdxConstant(Idx - (NumElts / 2), DL));
+ return DAG.getNode(AArch64ISD::UZP1, DL, VT, Lo, NewHi);
+ }
+ }
+
// Ensure the subvector is half the size of the main vector.
if (VT.getVectorElementCount() != (InVT.getVectorElementCount() * 2))
return SDValue();
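
The i1 insert_subvector path above works on whole halves: extract both halves of the destination predicate, insert the subvector into whichever half contains Idx, and re-pack the two halves (via UZP1 in the actual lowering). A value-level sketch under the assumption that the subvector fits entirely in one half; insertSubvector below is an illustrative helper, not LLVM API.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

static std::vector<bool> insertSubvector(std::vector<bool> Vec0,
                                         const std::vector<bool> &Vec1,
                                         std::size_t Idx) {
  std::size_t Half = Vec0.size() / 2;
  std::vector<bool> Lo(Vec0.begin(), Vec0.begin() + Half);
  std::vector<bool> Hi(Vec0.begin() + Half, Vec0.end());
  // Assumes the subvector lands entirely inside one half (Idx is a multiple
  // of the subvector length and the subvector is no wider than a half).
  if (Idx < Half)
    std::copy(Vec1.begin(), Vec1.end(), Lo.begin() + Idx);
  else
    std::copy(Vec1.begin(), Vec1.end(), Hi.begin() + (Idx - Half));
  Lo.insert(Lo.end(), Hi.begin(), Hi.end()); // re-join the two halves
  return Lo;
}

int main() {
  std::vector<bool> Dst(8, false), Sub{true, true};
  std::vector<bool> Res = insertSubvector(Dst, Sub, 6);
  assert(Res.size() == 8 && Res[6] && Res[7] && !Res[5]);
  return 0;
}
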
@@ -11012,10 +11095,10 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
if (Vec0.isUndef())
return Op;
- unsigned int PredPattern =
+ Optional<unsigned> PredPattern =
getSVEPredPatternFromNumElements(InVT.getVectorNumElements());
auto PredTy = VT.changeVectorElementType(MVT::i1);
- SDValue PTrue = getPTrue(DAG, DL, PredTy, PredPattern);
+ SDValue PTrue = getPTrue(DAG, DL, PredTy, *PredPattern);
SDValue ScalableVec1 = convertToScalableVector(DAG, VT, Vec1);
return DAG.getNode(ISD::VSELECT, DL, VT, PTrue, ScalableVec1, Vec0);
}
@@ -11730,10 +11813,10 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
case Intrinsic::aarch64_ldxr: {
PointerType *PtrTy = cast<PointerType>(I.getArgOperand(0)->getType());
Info.opc = ISD::INTRINSIC_W_CHAIN;
- Info.memVT = MVT::getVT(PtrTy->getElementType());
+ Info.memVT = MVT::getVT(PtrTy->getPointerElementType());
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
- Info.align = DL.getABITypeAlign(PtrTy->getElementType());
+ Info.align = DL.getABITypeAlign(PtrTy->getPointerElementType());
Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile;
return true;
}
@@ -11741,10 +11824,10 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
case Intrinsic::aarch64_stxr: {
PointerType *PtrTy = cast<PointerType>(I.getArgOperand(1)->getType());
Info.opc = ISD::INTRINSIC_W_CHAIN;
- Info.memVT = MVT::getVT(PtrTy->getElementType());
+ Info.memVT = MVT::getVT(PtrTy->getPointerElementType());
Info.ptrVal = I.getArgOperand(1);
Info.offset = 0;
- Info.align = DL.getABITypeAlign(PtrTy->getElementType());
+ Info.align = DL.getABITypeAlign(PtrTy->getPointerElementType());
Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MOVolatile;
return true;
}
@@ -11772,7 +11855,7 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
Info.memVT = MVT::getVT(I.getType());
Info.ptrVal = I.getArgOperand(1);
Info.offset = 0;
- Info.align = DL.getABITypeAlign(PtrTy->getElementType());
+ Info.align = DL.getABITypeAlign(PtrTy->getPointerElementType());
Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MONonTemporal;
return true;
}
@@ -11782,7 +11865,7 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
Info.memVT = MVT::getVT(I.getOperand(0)->getType());
Info.ptrVal = I.getArgOperand(2);
Info.offset = 0;
- Info.align = DL.getABITypeAlign(PtrTy->getElementType());
+ Info.align = DL.getABITypeAlign(PtrTy->getPointerElementType());
Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MONonTemporal;
return true;
}
@@ -12320,7 +12403,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
Value *PTrue = nullptr;
if (UseScalable) {
- unsigned PgPattern =
+ Optional<unsigned> PgPattern =
getSVEPredPatternFromNumElements(FVTy->getNumElements());
if (Subtarget->getMinSVEVectorSizeInBits() ==
Subtarget->getMaxSVEVectorSizeInBits() &&
@@ -12328,7 +12411,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
PgPattern = AArch64SVEPredPattern::all;
auto *PTruePat =
- ConstantInt::get(Type::getInt32Ty(LDVTy->getContext()), PgPattern);
+ ConstantInt::get(Type::getInt32Ty(LDVTy->getContext()), *PgPattern);
PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy},
{PTruePat});
}
@@ -12500,7 +12583,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
Value *PTrue = nullptr;
if (UseScalable) {
- unsigned PgPattern =
+ Optional<unsigned> PgPattern =
getSVEPredPatternFromNumElements(SubVecTy->getNumElements());
if (Subtarget->getMinSVEVectorSizeInBits() ==
Subtarget->getMaxSVEVectorSizeInBits() &&
@@ -12509,7 +12592,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
PgPattern = AArch64SVEPredPattern::all;
auto *PTruePat =
- ConstantInt::get(Type::getInt32Ty(STVTy->getContext()), PgPattern);
+ ConstantInt::get(Type::getInt32Ty(STVTy->getContext()), *PgPattern);
PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy},
{PTruePat});
}
@@ -12901,7 +12984,7 @@ bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
return false;
- return (Index == 0 || Index == ResVT.getVectorNumElements());
+ return (Index == 0 || Index == ResVT.getVectorMinNumElements());
}
/// Turn vector tests of the signbit in the form of:
@@ -14261,6 +14344,7 @@ static SDValue performConcatVectorsCombine(SDNode *N,
static SDValue
performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
+ SDLoc DL(N);
SDValue Vec = N->getOperand(0);
SDValue SubVec = N->getOperand(1);
uint64_t IdxVal = N->getConstantOperandVal(2);
@@ -14286,7 +14370,6 @@ performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
// Fold insert_subvector -> concat_vectors
// insert_subvector(Vec,Sub,lo) -> concat_vectors(Sub,extract(Vec,hi))
// insert_subvector(Vec,Sub,hi) -> concat_vectors(extract(Vec,lo),Sub)
- SDLoc DL(N);
SDValue Lo, Hi;
if (IdxVal == 0) {
Lo = SubVec;
@@ -15004,7 +15087,15 @@ static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
Zero);
}
-static bool isAllActivePredicate(SDValue N) {
+static bool isAllInactivePredicate(SDValue N) {
+ // Look through cast.
+ while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+ N = N.getOperand(0);
+
+ return N.getOpcode() == AArch64ISD::PFALSE;
+}
+
+static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
unsigned NumElts = N.getValueType().getVectorMinNumElements();
// Look through cast.
@@ -15023,6 +15114,21 @@ static bool isAllActivePredicate(SDValue N) {
N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
return N.getValueType().getVectorMinNumElements() >= NumElts;
+ // If we're compiling for a specific vector-length, we can check if the
+ // pattern's VL equals that of the scalable vector at runtime.
+ if (N.getOpcode() == AArch64ISD::PTRUE) {
+ const auto &Subtarget =
+ static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
+ unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
+ unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
+ if (MaxSVESize && MinSVESize == MaxSVESize) {
+ unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
+ unsigned PatNumElts =
+ getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
+ return PatNumElts == (NumElts * VScale);
+ }
+ }
+
return false;
}
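
The fixed-vector-length case added to isAllActivePredicate reduces to simple arithmetic: when the minimum and maximum SVE vector lengths coincide, vscale is known at compile time, and a ptrue with pattern vlN is all-active exactly when N equals NumElts * vscale. A standalone sketch of that check; the helper name is illustrative, not part of this patch.

#include <cassert>

constexpr unsigned SVEBitsPerBlock = 128; // bits per "vscale" granule

static bool ptrueIsAllActive(unsigned PatNumElts, unsigned NumElts,
                             unsigned MinSVESize, unsigned MaxSVESize) {
  if (!MaxSVESize || MinSVESize != MaxSVESize)
    return false;                            // runtime vector length unknown
  unsigned VScale = MaxSVESize / SVEBitsPerBlock;
  return PatNumElts == NumElts * VScale;     // pattern covers every lane
}

int main() {
  // 256-bit SVE: vscale = 2, so ptrue vl8 fills every lane of an nxv4i1 ...
  assert(ptrueIsAllActive(/*PatNumElts=*/8, /*NumElts=*/4, 256, 256));
  // ... while ptrue vl4 only covers half of them.
  assert(!ptrueIsAllActive(/*PatNumElts=*/4, /*NumElts=*/4, 256, 256));
  return 0;
}
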
@@ -15039,7 +15145,7 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
SDValue Op2 = N->getOperand(SwapOperands ? 2 : 3);
// ISD way to specify an all active predicate.
- if (isAllActivePredicate(Pg)) {
+ if (isAllActivePredicate(DAG, Pg)) {
if (UnpredOp)
return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op1, Op2);
@@ -15870,7 +15976,7 @@ static SDValue performPostLD1Combine(SDNode *N,
SelectionDAG &DAG = DCI.DAG;
EVT VT = N->getValueType(0);
- if (VT.isScalableVector())
+ if (!VT.is128BitVector() && !VT.is64BitVector())
return SDValue();
unsigned LoadIdx = IsLaneOp ? 1 : 0;
@@ -16710,6 +16816,12 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
SDValue N0 = N->getOperand(0);
EVT CCVT = N0.getValueType();
+ if (isAllActivePredicate(DAG, N0))
+ return N->getOperand(1);
+
+ if (isAllInactivePredicate(N0))
+ return N->getOperand(2);
+
// Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
// into (OR (ASR lhs, N-1), 1), which requires less instructions for the
// supported types.
@@ -18753,7 +18865,7 @@ static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
"Expected legal fixed length vector!");
- unsigned PgPattern =
+ Optional<unsigned> PgPattern =
getSVEPredPatternFromNumElements(VT.getVectorNumElements());
assert(PgPattern && "Unexpected element count for SVE predicate");
@@ -18789,7 +18901,7 @@ static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
break;
}
- return getPTrue(DAG, DL, MaskVT, PgPattern);
+ return getPTrue(DAG, DL, MaskVT, *PgPattern);
}
static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
@@ -19281,7 +19393,12 @@ SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
default:
return SDValue();
case ISD::VECREDUCE_OR:
- return getPTest(DAG, VT, Pg, Op, AArch64CC::ANY_ACTIVE);
+ if (isAllActivePredicate(DAG, Pg))
+ // The predicate can be 'Op' because
+ // vecreduce_or(Op & <all true>) <=> vecreduce_or(Op).
+ return getPTest(DAG, VT, Op, Op, AArch64CC::ANY_ACTIVE);
+ else
+ return getPTest(DAG, VT, Pg, Op, AArch64CC::ANY_ACTIVE);
case ISD::VECREDUCE_AND: {
Op = DAG.getNode(ISD::XOR, DL, OpVT, Op, Pg);
return getPTest(DAG, VT, Pg, Op, AArch64CC::NONE_ACTIVE);
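
The VECREDUCE_OR special case relies on the identity noted in the comment: with an all-true governing predicate, "any lane of (Pg AND Op) active" asks the same question as "any lane of Op active", so the PTEST can use Op as its own governing predicate. A trivial standalone check of that identity; anyActive is an illustrative stand-in for the PTEST ANY_ACTIVE condition, not LLVM API.

#include <cassert>
#include <cstddef>
#include <vector>

static bool anyActive(const std::vector<bool> &Pg,
                      const std::vector<bool> &Op) {
  for (std::size_t I = 0; I < Op.size(); ++I)
    if (Pg[I] && Op[I])
      return true;
  return false;
}

int main() {
  std::vector<bool> AllTrue(4, true), Op{false, true, false, false};
  // With an all-true governing predicate, the answer depends only on Op.
  assert(anyActive(AllTrue, Op) == anyActive(Op, Op));
  return 0;
}
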
@@ -19725,8 +19842,9 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
return Op;
}
-bool AArch64TargetLowering::isAllActivePredicate(SDValue N) const {
- return ::isAllActivePredicate(N);
+bool AArch64TargetLowering::isAllActivePredicate(SelectionDAG &DAG,
+ SDValue N) const {
+ return ::isAllActivePredicate(DAG, N);
}
EVT AArch64TargetLowering::getPromotedVTForPredicate(EVT VT) const {
@@ -19777,7 +19895,7 @@ bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth);
}
-bool AArch64TargetLowering::isConstantUnsignedBitfieldExtactLegal(
+bool AArch64TargetLowering::isConstantUnsignedBitfieldExtractLegal(
unsigned Opc, LLT Ty1, LLT Ty2) const {
return Ty1 == Ty2 && (Ty1 == LLT::scalar(32) || Ty1 == LLT::scalar(64));
}