path: root/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
author     Dimitry Andric <dim@FreeBSD.org>    2021-02-16 20:13:02 +0000
committer  Dimitry Andric <dim@FreeBSD.org>    2021-02-16 20:13:02 +0000
commit     b60736ec1405bb0a8dd40989f67ef4c93da068ab (patch)
tree       5c43fbb7c9fc45f0f87e0e6795a86267dbd12f9d /llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
parent     cfca06d7963fa0909f90483b42a6d7d194d01e08 (diff)
download   src-b60736ec1405bb0a8dd40989f67ef4c93da068ab.tar.gz
download   src-b60736ec1405bb0a8dd40989f67ef4c93da068ab.zip
Vendor import of llvm-project main 8e464dd76bef, the last commit before
the upstream release/12.x branch was created.
(vendor/llvm-project/llvmorg-12-init-17869-g8e464dd76bef)
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64ISelLowering.cpp  3221
1 file changed, 2665 insertions, 556 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 85db14ab66fe..1be09186dc0a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -27,7 +27,6 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Triple.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/VectorUtils.h"
@@ -113,9 +112,76 @@ EnableOptimizeLogicalImm("aarch64-enable-logical-imm", cl::Hidden,
"optimization"),
cl::init(true));
+// Temporary option added for the purpose of testing functionality added
+// to DAGCombiner.cpp in D92230. It is expected that this can be removed
+// in future when both implementations will be based off MGATHER rather
+// than the GLD1 nodes added for the SVE gather load intrinsics.
+static cl::opt<bool>
+EnableCombineMGatherIntrinsics("aarch64-enable-mgather-combine", cl::Hidden,
+ cl::desc("Combine extends of AArch64 masked "
+ "gather intrinsics"),
+ cl::init(true));
+
/// Value type used for condition codes.
static const MVT MVT_CC = MVT::i32;
+static inline EVT getPackedSVEVectorVT(EVT VT) {
+ switch (VT.getSimpleVT().SimpleTy) {
+ default:
+ llvm_unreachable("unexpected element type for vector");
+ case MVT::i8:
+ return MVT::nxv16i8;
+ case MVT::i16:
+ return MVT::nxv8i16;
+ case MVT::i32:
+ return MVT::nxv4i32;
+ case MVT::i64:
+ return MVT::nxv2i64;
+ case MVT::f16:
+ return MVT::nxv8f16;
+ case MVT::f32:
+ return MVT::nxv4f32;
+ case MVT::f64:
+ return MVT::nxv2f64;
+ case MVT::bf16:
+ return MVT::nxv8bf16;
+ }
+}
+
+// NOTE: Currently there's only a need to return integer vector types. If this
+// changes then just add an extra "type" parameter.
+static inline EVT getPackedSVEVectorVT(ElementCount EC) {
+ switch (EC.getKnownMinValue()) {
+ default:
+ llvm_unreachable("unexpected element count for vector");
+ case 16:
+ return MVT::nxv16i8;
+ case 8:
+ return MVT::nxv8i16;
+ case 4:
+ return MVT::nxv4i32;
+ case 2:
+ return MVT::nxv2i64;
+ }
+}
+
+static inline EVT getPromotedVTForPredicate(EVT VT) {
+ assert(VT.isScalableVector() && (VT.getVectorElementType() == MVT::i1) &&
+ "Expected scalable predicate vector type!");
+ switch (VT.getVectorMinNumElements()) {
+ default:
+ llvm_unreachable("unexpected element count for vector");
+ case 2:
+ return MVT::nxv2i64;
+ case 4:
+ return MVT::nxv4i32;
+ case 8:
+ return MVT::nxv8i16;
+ case 16:
+ return MVT::nxv16i8;
+ }
+}
+
/// Returns true if VT's elements occupy the lowest bit positions of its
/// associated register class without any intervening space.
///
@@ -128,6 +194,42 @@ static inline bool isPackedVectorType(EVT VT, SelectionDAG &DAG) {
VT.getSizeInBits().getKnownMinSize() == AArch64::SVEBitsPerBlock;
}
+// Returns true for ####_MERGE_PASSTHRU opcodes, whose operands have a leading
+// predicate and end with a passthru value matching the result type.
+static bool isMergePassthruOpcode(unsigned Opc) {
+ switch (Opc) {
+ default:
+ return false;
+ case AArch64ISD::BITREVERSE_MERGE_PASSTHRU:
+ case AArch64ISD::BSWAP_MERGE_PASSTHRU:
+ case AArch64ISD::CTLZ_MERGE_PASSTHRU:
+ case AArch64ISD::CTPOP_MERGE_PASSTHRU:
+ case AArch64ISD::DUP_MERGE_PASSTHRU:
+ case AArch64ISD::ABS_MERGE_PASSTHRU:
+ case AArch64ISD::NEG_MERGE_PASSTHRU:
+ case AArch64ISD::FNEG_MERGE_PASSTHRU:
+ case AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU:
+ case AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU:
+ case AArch64ISD::FCEIL_MERGE_PASSTHRU:
+ case AArch64ISD::FFLOOR_MERGE_PASSTHRU:
+ case AArch64ISD::FNEARBYINT_MERGE_PASSTHRU:
+ case AArch64ISD::FRINT_MERGE_PASSTHRU:
+ case AArch64ISD::FROUND_MERGE_PASSTHRU:
+ case AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU:
+ case AArch64ISD::FTRUNC_MERGE_PASSTHRU:
+ case AArch64ISD::FP_ROUND_MERGE_PASSTHRU:
+ case AArch64ISD::FP_EXTEND_MERGE_PASSTHRU:
+ case AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU:
+ case AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU:
+ case AArch64ISD::FCVTZU_MERGE_PASSTHRU:
+ case AArch64ISD::FCVTZS_MERGE_PASSTHRU:
+ case AArch64ISD::FSQRT_MERGE_PASSTHRU:
+ case AArch64ISD::FRECPX_MERGE_PASSTHRU:
+ case AArch64ISD::FABS_MERGE_PASSTHRU:
+ return true;
+ }
+}
+
AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
const AArch64Subtarget &STI)
: TargetLowering(TM), Subtarget(&STI) {
@@ -161,7 +263,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addDRTypeForNEON(MVT::v1i64);
addDRTypeForNEON(MVT::v1f64);
addDRTypeForNEON(MVT::v4f16);
- addDRTypeForNEON(MVT::v4bf16);
+ if (Subtarget->hasBF16())
+ addDRTypeForNEON(MVT::v4bf16);
addQRTypeForNEON(MVT::v4f32);
addQRTypeForNEON(MVT::v2f64);
@@ -170,7 +273,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addQRTypeForNEON(MVT::v4i32);
addQRTypeForNEON(MVT::v2i64);
addQRTypeForNEON(MVT::v8f16);
- addQRTypeForNEON(MVT::v8bf16);
+ if (Subtarget->hasBF16())
+ addQRTypeForNEON(MVT::v8bf16);
}
if (Subtarget->hasSVE()) {
@@ -199,7 +303,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::nxv8bf16, &AArch64::ZPRRegClass);
}
- if (useSVEForFixedLengthVectors()) {
+ if (Subtarget->useSVEForFixedLengthVectors()) {
for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
if (useSVEForFixedLengthVectorVT(VT))
addRegisterClass(VT, &AArch64::ZPRRegClass);
@@ -230,7 +334,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::nxv2f64 }) {
setCondCodeAction(ISD::SETO, VT, Expand);
setCondCodeAction(ISD::SETOLT, VT, Expand);
+ setCondCodeAction(ISD::SETLT, VT, Expand);
setCondCodeAction(ISD::SETOLE, VT, Expand);
+ setCondCodeAction(ISD::SETLE, VT, Expand);
setCondCodeAction(ISD::SETULT, VT, Expand);
setCondCodeAction(ISD::SETULE, VT, Expand);
setCondCodeAction(ISD::SETUGE, VT, Expand);
@@ -296,12 +402,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// Virtually no operation on f128 is legal, but LLVM can't expand them when
// there's a valid register class, so we need custom operations in most cases.
setOperationAction(ISD::FABS, MVT::f128, Expand);
- setOperationAction(ISD::FADD, MVT::f128, Custom);
+ setOperationAction(ISD::FADD, MVT::f128, LibCall);
setOperationAction(ISD::FCOPYSIGN, MVT::f128, Expand);
setOperationAction(ISD::FCOS, MVT::f128, Expand);
- setOperationAction(ISD::FDIV, MVT::f128, Custom);
+ setOperationAction(ISD::FDIV, MVT::f128, LibCall);
setOperationAction(ISD::FMA, MVT::f128, Expand);
- setOperationAction(ISD::FMUL, MVT::f128, Custom);
+ setOperationAction(ISD::FMUL, MVT::f128, LibCall);
setOperationAction(ISD::FNEG, MVT::f128, Expand);
setOperationAction(ISD::FPOW, MVT::f128, Expand);
setOperationAction(ISD::FREM, MVT::f128, Expand);
@@ -309,7 +415,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FSIN, MVT::f128, Expand);
setOperationAction(ISD::FSINCOS, MVT::f128, Expand);
setOperationAction(ISD::FSQRT, MVT::f128, Expand);
- setOperationAction(ISD::FSUB, MVT::f128, Custom);
+ setOperationAction(ISD::FSUB, MVT::f128, LibCall);
setOperationAction(ISD::FTRUNC, MVT::f128, Expand);
setOperationAction(ISD::SETCC, MVT::f128, Custom);
setOperationAction(ISD::STRICT_FSETCC, MVT::f128, Custom);
@@ -345,8 +451,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i32, Custom);
setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i64, Custom);
setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i128, Custom);
+ setOperationAction(ISD::FP_ROUND, MVT::f16, Custom);
setOperationAction(ISD::FP_ROUND, MVT::f32, Custom);
setOperationAction(ISD::FP_ROUND, MVT::f64, Custom);
+ setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Custom);
setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Custom);
setOperationAction(ISD::STRICT_FP_ROUND, MVT::f64, Custom);
@@ -401,6 +509,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::CTPOP, MVT::i64, Custom);
setOperationAction(ISD::CTPOP, MVT::i128, Custom);
+ setOperationAction(ISD::ABS, MVT::i32, Custom);
+ setOperationAction(ISD::ABS, MVT::i64, Custom);
+
setOperationAction(ISD::SDIVREM, MVT::i32, Expand);
setOperationAction(ISD::SDIVREM, MVT::i64, Expand);
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
@@ -588,6 +699,57 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i32, Custom);
setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i64, Custom);
+ // Generate outline atomics library calls only if LSE was not specified for
+ // subtarget
+ if (Subtarget->outlineAtomics() && !Subtarget->hasLSE()) {
+ setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i8, LibCall);
+ setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i16, LibCall);
+ setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i32, LibCall);
+ setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i64, LibCall);
+ setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i128, LibCall);
+ setOperationAction(ISD::ATOMIC_SWAP, MVT::i8, LibCall);
+ setOperationAction(ISD::ATOMIC_SWAP, MVT::i16, LibCall);
+ setOperationAction(ISD::ATOMIC_SWAP, MVT::i32, LibCall);
+ setOperationAction(ISD::ATOMIC_SWAP, MVT::i64, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i8, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i16, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i32, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i64, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i8, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i16, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i32, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i64, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i8, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i16, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i32, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i64, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i8, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i16, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i32, LibCall);
+ setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i64, LibCall);
+#define LCALLNAMES(A, B, N) \
+ setLibcallName(A##N##_RELAX, #B #N "_relax"); \
+ setLibcallName(A##N##_ACQ, #B #N "_acq"); \
+ setLibcallName(A##N##_REL, #B #N "_rel"); \
+ setLibcallName(A##N##_ACQ_REL, #B #N "_acq_rel");
+#define LCALLNAME4(A, B) \
+ LCALLNAMES(A, B, 1) \
+ LCALLNAMES(A, B, 2) LCALLNAMES(A, B, 4) LCALLNAMES(A, B, 8)
+#define LCALLNAME5(A, B) \
+ LCALLNAMES(A, B, 1) \
+ LCALLNAMES(A, B, 2) \
+ LCALLNAMES(A, B, 4) LCALLNAMES(A, B, 8) LCALLNAMES(A, B, 16)
+ LCALLNAME5(RTLIB::OUTLINE_ATOMIC_CAS, __aarch64_cas)
+ LCALLNAME4(RTLIB::OUTLINE_ATOMIC_SWP, __aarch64_swp)
+ LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDADD, __aarch64_ldadd)
+ LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDSET, __aarch64_ldset)
+ LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDCLR, __aarch64_ldclr)
+ LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDEOR, __aarch64_ldeor)
+#undef LCALLNAMES
+#undef LCALLNAME4
+#undef LCALLNAME5
+ }
+
// 128-bit loads and stores can be done without expanding
setOperationAction(ISD::LOAD, MVT::i128, Custom);
setOperationAction(ISD::STORE, MVT::i128, Custom);
@@ -677,8 +839,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// Trap.
setOperationAction(ISD::TRAP, MVT::Other, Legal);
- if (Subtarget->isTargetWindows())
- setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
+ setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
+ setOperationAction(ISD::UBSANTRAP, MVT::Other, Legal);
// We combine OR nodes for bitfield operations.
setTargetDAGCombine(ISD::OR);
@@ -688,6 +850,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// Vector add and sub nodes may conceal a high-half opportunity.
// Also, try to fold ADD into CSINC/CSINV..
setTargetDAGCombine(ISD::ADD);
+ setTargetDAGCombine(ISD::ABS);
setTargetDAGCombine(ISD::SUB);
setTargetDAGCombine(ISD::SRL);
setTargetDAGCombine(ISD::XOR);
@@ -704,11 +867,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::ZERO_EXTEND);
setTargetDAGCombine(ISD::SIGN_EXTEND);
setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
+ setTargetDAGCombine(ISD::TRUNCATE);
setTargetDAGCombine(ISD::CONCAT_VECTORS);
setTargetDAGCombine(ISD::STORE);
if (Subtarget->supportsAddressTopByteIgnored())
setTargetDAGCombine(ISD::LOAD);
+ setTargetDAGCombine(ISD::MGATHER);
+ setTargetDAGCombine(ISD::MSCATTER);
+
setTargetDAGCombine(ISD::MUL);
setTargetDAGCombine(ISD::SELECT);
@@ -717,6 +884,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::INTRINSIC_VOID);
setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN);
setTargetDAGCombine(ISD::INSERT_VECTOR_ELT);
+ setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
+ setTargetDAGCombine(ISD::VECREDUCE_ADD);
setTargetDAGCombine(ISD::GlobalAddress);
@@ -836,28 +1005,33 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::MUL, MVT::v4i32, Custom);
setOperationAction(ISD::MUL, MVT::v2i64, Custom);
+ // Saturates
for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
- // Vector reductions
- setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
- setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
- setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
- setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
- setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
-
- // Saturates
setOperationAction(ISD::SADDSAT, VT, Legal);
setOperationAction(ISD::UADDSAT, VT, Legal);
setOperationAction(ISD::SSUBSAT, VT, Legal);
setOperationAction(ISD::USUBSAT, VT, Legal);
-
- setOperationAction(ISD::TRUNCATE, VT, Custom);
}
+
+ // Vector reductions
for (MVT VT : { MVT::v4f16, MVT::v2f32,
MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
+
+ if (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16())
+ setOperationAction(ISD::VECREDUCE_FADD, VT, Legal);
}
+ for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
+ MVT::v16i8, MVT::v8i16, MVT::v4i32 }) {
+ setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+ }
+ setOperationAction(ISD::VECREDUCE_ADD, MVT::v2i64, Custom);
setOperationAction(ISD::ANY_EXTEND, MVT::v4i32, Legal);
setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
@@ -918,43 +1092,112 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// FIXME: Add custom lowering of MLOAD to handle different passthrus (not a
// splat of 0 or undef) once vector selects supported in SVE codegen. See
// D68877 for more details.
- for (MVT VT : MVT::integer_scalable_vector_valuetypes()) {
- if (isTypeLegal(VT)) {
- setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
- setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
- setOperationAction(ISD::SELECT, VT, Custom);
- setOperationAction(ISD::SDIV, VT, Custom);
- setOperationAction(ISD::UDIV, VT, Custom);
- setOperationAction(ISD::SMIN, VT, Custom);
- setOperationAction(ISD::UMIN, VT, Custom);
- setOperationAction(ISD::SMAX, VT, Custom);
- setOperationAction(ISD::UMAX, VT, Custom);
- setOperationAction(ISD::SHL, VT, Custom);
- setOperationAction(ISD::SRL, VT, Custom);
- setOperationAction(ISD::SRA, VT, Custom);
- if (VT.getScalarType() == MVT::i1)
- setOperationAction(ISD::SETCC, VT, Custom);
- }
+ for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64}) {
+ setOperationAction(ISD::BITREVERSE, VT, Custom);
+ setOperationAction(ISD::BSWAP, VT, Custom);
+ setOperationAction(ISD::CTLZ, VT, Custom);
+ setOperationAction(ISD::CTPOP, VT, Custom);
+ setOperationAction(ISD::CTTZ, VT, Custom);
+ setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
+ setOperationAction(ISD::UINT_TO_FP, VT, Custom);
+ setOperationAction(ISD::SINT_TO_FP, VT, Custom);
+ setOperationAction(ISD::FP_TO_UINT, VT, Custom);
+ setOperationAction(ISD::FP_TO_SINT, VT, Custom);
+ setOperationAction(ISD::MGATHER, VT, Custom);
+ setOperationAction(ISD::MSCATTER, VT, Custom);
+ setOperationAction(ISD::MUL, VT, Custom);
+ setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+ setOperationAction(ISD::SELECT, VT, Custom);
+ setOperationAction(ISD::SDIV, VT, Custom);
+ setOperationAction(ISD::UDIV, VT, Custom);
+ setOperationAction(ISD::SMIN, VT, Custom);
+ setOperationAction(ISD::UMIN, VT, Custom);
+ setOperationAction(ISD::SMAX, VT, Custom);
+ setOperationAction(ISD::UMAX, VT, Custom);
+ setOperationAction(ISD::SHL, VT, Custom);
+ setOperationAction(ISD::SRL, VT, Custom);
+ setOperationAction(ISD::SRA, VT, Custom);
+ setOperationAction(ISD::ABS, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
}
- for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32})
+ // Illegal unpacked integer vector types.
+ for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) {
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
+ setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
+ }
- setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
- setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom);
-
- for (MVT VT : MVT::fp_scalable_vector_valuetypes()) {
- if (isTypeLegal(VT)) {
- setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
- setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
- setOperationAction(ISD::SELECT, VT, Custom);
- setOperationAction(ISD::FMA, VT, Custom);
+ for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) {
+ setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
+ setOperationAction(ISD::SELECT, VT, Custom);
+ setOperationAction(ISD::SETCC, VT, Custom);
+ setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+ setOperationAction(ISD::TRUNCATE, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
+
+ // There are no legal MVT::nxv16f## based types.
+ if (VT != MVT::nxv16i1) {
+ setOperationAction(ISD::SINT_TO_FP, VT, Custom);
+ setOperationAction(ISD::UINT_TO_FP, VT, Custom);
}
}
+ for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32,
+ MVT::nxv4f32, MVT::nxv2f64}) {
+ setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
+ setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
+ setOperationAction(ISD::MGATHER, VT, Custom);
+ setOperationAction(ISD::MSCATTER, VT, Custom);
+ setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+ setOperationAction(ISD::SELECT, VT, Custom);
+ setOperationAction(ISD::FADD, VT, Custom);
+ setOperationAction(ISD::FDIV, VT, Custom);
+ setOperationAction(ISD::FMA, VT, Custom);
+ setOperationAction(ISD::FMAXNUM, VT, Custom);
+ setOperationAction(ISD::FMINNUM, VT, Custom);
+ setOperationAction(ISD::FMUL, VT, Custom);
+ setOperationAction(ISD::FNEG, VT, Custom);
+ setOperationAction(ISD::FSUB, VT, Custom);
+ setOperationAction(ISD::FCEIL, VT, Custom);
+ setOperationAction(ISD::FFLOOR, VT, Custom);
+ setOperationAction(ISD::FNEARBYINT, VT, Custom);
+ setOperationAction(ISD::FRINT, VT, Custom);
+ setOperationAction(ISD::FROUND, VT, Custom);
+ setOperationAction(ISD::FROUNDEVEN, VT, Custom);
+ setOperationAction(ISD::FTRUNC, VT, Custom);
+ setOperationAction(ISD::FSQRT, VT, Custom);
+ setOperationAction(ISD::FABS, VT, Custom);
+ setOperationAction(ISD::FP_EXTEND, VT, Custom);
+ setOperationAction(ISD::FP_ROUND, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
+ }
+
+ for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
+ setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
+ setOperationAction(ISD::MGATHER, VT, Custom);
+ setOperationAction(ISD::MSCATTER, VT, Custom);
+ }
+
+ setOperationAction(ISD::SPLAT_VECTOR, MVT::nxv8bf16, Custom);
+
+ setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
+ setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom);
+
// NOTE: Currently this has to happen after computeRegisterProperties rather
// than the preferred option of combining it with the addRegisterClass call.
- if (useSVEForFixedLengthVectors()) {
+ if (Subtarget->useSVEForFixedLengthVectors()) {
for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
if (useSVEForFixedLengthVectorVT(VT))
addTypeForFixedLengthSVE(VT);
@@ -972,6 +1215,61 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::TRUNCATE, VT, Custom);
for (auto VT : {MVT::v8f16, MVT::v4f32})
setOperationAction(ISD::FP_ROUND, VT, Expand);
+
+ // These operations are not supported on NEON but SVE can do them.
+ setOperationAction(ISD::BITREVERSE, MVT::v1i64, Custom);
+ setOperationAction(ISD::CTLZ, MVT::v1i64, Custom);
+ setOperationAction(ISD::CTLZ, MVT::v2i64, Custom);
+ setOperationAction(ISD::CTTZ, MVT::v1i64, Custom);
+ setOperationAction(ISD::MUL, MVT::v1i64, Custom);
+ setOperationAction(ISD::MUL, MVT::v2i64, Custom);
+ setOperationAction(ISD::SDIV, MVT::v8i8, Custom);
+ setOperationAction(ISD::SDIV, MVT::v16i8, Custom);
+ setOperationAction(ISD::SDIV, MVT::v4i16, Custom);
+ setOperationAction(ISD::SDIV, MVT::v8i16, Custom);
+ setOperationAction(ISD::SDIV, MVT::v2i32, Custom);
+ setOperationAction(ISD::SDIV, MVT::v4i32, Custom);
+ setOperationAction(ISD::SDIV, MVT::v1i64, Custom);
+ setOperationAction(ISD::SDIV, MVT::v2i64, Custom);
+ setOperationAction(ISD::SMAX, MVT::v1i64, Custom);
+ setOperationAction(ISD::SMAX, MVT::v2i64, Custom);
+ setOperationAction(ISD::SMIN, MVT::v1i64, Custom);
+ setOperationAction(ISD::SMIN, MVT::v2i64, Custom);
+ setOperationAction(ISD::UDIV, MVT::v8i8, Custom);
+ setOperationAction(ISD::UDIV, MVT::v16i8, Custom);
+ setOperationAction(ISD::UDIV, MVT::v4i16, Custom);
+ setOperationAction(ISD::UDIV, MVT::v8i16, Custom);
+ setOperationAction(ISD::UDIV, MVT::v2i32, Custom);
+ setOperationAction(ISD::UDIV, MVT::v4i32, Custom);
+ setOperationAction(ISD::UDIV, MVT::v1i64, Custom);
+ setOperationAction(ISD::UDIV, MVT::v2i64, Custom);
+ setOperationAction(ISD::UMAX, MVT::v1i64, Custom);
+ setOperationAction(ISD::UMAX, MVT::v2i64, Custom);
+ setOperationAction(ISD::UMIN, MVT::v1i64, Custom);
+ setOperationAction(ISD::UMIN, MVT::v2i64, Custom);
+ setOperationAction(ISD::VECREDUCE_SMAX, MVT::v2i64, Custom);
+ setOperationAction(ISD::VECREDUCE_SMIN, MVT::v2i64, Custom);
+ setOperationAction(ISD::VECREDUCE_UMAX, MVT::v2i64, Custom);
+ setOperationAction(ISD::VECREDUCE_UMIN, MVT::v2i64, Custom);
+
+ // Int operations with no NEON support.
+ for (auto VT : {MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
+ MVT::v2i32, MVT::v4i32, MVT::v2i64}) {
+ setOperationAction(ISD::BITREVERSE, VT, Custom);
+ setOperationAction(ISD::CTTZ, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
+ }
+
+ // FP operations with no NEON support.
+ for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32,
+ MVT::v1f64, MVT::v2f64})
+ setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
+
+ // Use SVE for vectors with more than 2 elements.
+ for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32})
+ setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
}
}
@@ -1043,6 +1341,7 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT, MVT PromotedBitwiseVT) {
// F[MIN|MAX][NUM|NAN] are available for all FP NEON types.
if (VT.isFloatingPoint() &&
+ VT.getVectorElementType() != MVT::bf16 &&
(VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16()))
for (unsigned Opcode :
{ISD::FMINIMUM, ISD::FMAXIMUM, ISD::FMINNUM, ISD::FMAXNUM})
@@ -1068,11 +1367,64 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
// Lower fixed length vector operations to scalable equivalents.
+ setOperationAction(ISD::ABS, VT, Custom);
setOperationAction(ISD::ADD, VT, Custom);
+ setOperationAction(ISD::AND, VT, Custom);
+ setOperationAction(ISD::ANY_EXTEND, VT, Custom);
+ setOperationAction(ISD::BITREVERSE, VT, Custom);
+ setOperationAction(ISD::BSWAP, VT, Custom);
+ setOperationAction(ISD::CTLZ, VT, Custom);
+ setOperationAction(ISD::CTPOP, VT, Custom);
+ setOperationAction(ISD::CTTZ, VT, Custom);
setOperationAction(ISD::FADD, VT, Custom);
+ setOperationAction(ISD::FCEIL, VT, Custom);
+ setOperationAction(ISD::FDIV, VT, Custom);
+ setOperationAction(ISD::FFLOOR, VT, Custom);
+ setOperationAction(ISD::FMA, VT, Custom);
+ setOperationAction(ISD::FMAXNUM, VT, Custom);
+ setOperationAction(ISD::FMINNUM, VT, Custom);
+ setOperationAction(ISD::FMUL, VT, Custom);
+ setOperationAction(ISD::FNEARBYINT, VT, Custom);
+ setOperationAction(ISD::FNEG, VT, Custom);
+ setOperationAction(ISD::FRINT, VT, Custom);
+ setOperationAction(ISD::FROUND, VT, Custom);
+ setOperationAction(ISD::FSQRT, VT, Custom);
+ setOperationAction(ISD::FSUB, VT, Custom);
+ setOperationAction(ISD::FTRUNC, VT, Custom);
setOperationAction(ISD::LOAD, VT, Custom);
+ setOperationAction(ISD::MUL, VT, Custom);
+ setOperationAction(ISD::OR, VT, Custom);
+ setOperationAction(ISD::SDIV, VT, Custom);
+ setOperationAction(ISD::SETCC, VT, Custom);
+ setOperationAction(ISD::SHL, VT, Custom);
+ setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
+ setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Custom);
+ setOperationAction(ISD::SMAX, VT, Custom);
+ setOperationAction(ISD::SMIN, VT, Custom);
+ setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+ setOperationAction(ISD::SRA, VT, Custom);
+ setOperationAction(ISD::SRL, VT, Custom);
setOperationAction(ISD::STORE, VT, Custom);
+ setOperationAction(ISD::SUB, VT, Custom);
setOperationAction(ISD::TRUNCATE, VT, Custom);
+ setOperationAction(ISD::UDIV, VT, Custom);
+ setOperationAction(ISD::UMAX, VT, Custom);
+ setOperationAction(ISD::UMIN, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
+ setOperationAction(ISD::VSELECT, VT, Custom);
+ setOperationAction(ISD::XOR, VT, Custom);
+ setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
}
void AArch64TargetLowering::addDRTypeForNEON(MVT VT) {
@@ -1244,8 +1596,7 @@ void AArch64TargetLowering::computeKnownBitsForTargetNode(
KnownBits Known2;
Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
- Known.Zero &= Known2.Zero;
- Known.One &= Known2.One;
+ Known = KnownBits::commonBits(Known, Known2);
break;
}
case AArch64ISD::LOADgot:
@@ -1385,15 +1736,38 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::THREAD_POINTER)
MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ)
MAKE_CASE(AArch64ISD::ADD_PRED)
+ MAKE_CASE(AArch64ISD::MUL_PRED)
MAKE_CASE(AArch64ISD::SDIV_PRED)
+ MAKE_CASE(AArch64ISD::SHL_PRED)
+ MAKE_CASE(AArch64ISD::SMAX_PRED)
+ MAKE_CASE(AArch64ISD::SMIN_PRED)
+ MAKE_CASE(AArch64ISD::SRA_PRED)
+ MAKE_CASE(AArch64ISD::SRL_PRED)
+ MAKE_CASE(AArch64ISD::SUB_PRED)
MAKE_CASE(AArch64ISD::UDIV_PRED)
- MAKE_CASE(AArch64ISD::SMIN_MERGE_OP1)
- MAKE_CASE(AArch64ISD::UMIN_MERGE_OP1)
- MAKE_CASE(AArch64ISD::SMAX_MERGE_OP1)
- MAKE_CASE(AArch64ISD::UMAX_MERGE_OP1)
- MAKE_CASE(AArch64ISD::SHL_MERGE_OP1)
- MAKE_CASE(AArch64ISD::SRL_MERGE_OP1)
- MAKE_CASE(AArch64ISD::SRA_MERGE_OP1)
+ MAKE_CASE(AArch64ISD::UMAX_PRED)
+ MAKE_CASE(AArch64ISD::UMIN_PRED)
+ MAKE_CASE(AArch64ISD::FNEG_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FCEIL_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FFLOOR_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FNEARBYINT_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FRINT_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FROUND_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FTRUNC_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FP_ROUND_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FCVTZU_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FCVTZS_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FSQRT_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FRECPX_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::FABS_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::ABS_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::NEG_MERGE_PASSTHRU)
MAKE_CASE(AArch64ISD::SETCC_MERGE_ZERO)
MAKE_CASE(AArch64ISD::ADC)
MAKE_CASE(AArch64ISD::SBC)
@@ -1462,10 +1836,14 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::UADDV)
MAKE_CASE(AArch64ISD::SRHADD)
MAKE_CASE(AArch64ISD::URHADD)
+ MAKE_CASE(AArch64ISD::SHADD)
+ MAKE_CASE(AArch64ISD::UHADD)
MAKE_CASE(AArch64ISD::SMINV)
MAKE_CASE(AArch64ISD::UMINV)
MAKE_CASE(AArch64ISD::SMAXV)
MAKE_CASE(AArch64ISD::UMAXV)
+ MAKE_CASE(AArch64ISD::SADDV_PRED)
+ MAKE_CASE(AArch64ISD::UADDV_PRED)
MAKE_CASE(AArch64ISD::SMAXV_PRED)
MAKE_CASE(AArch64ISD::UMAXV_PRED)
MAKE_CASE(AArch64ISD::SMINV_PRED)
@@ -1483,12 +1861,16 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::FADD_PRED)
MAKE_CASE(AArch64ISD::FADDA_PRED)
MAKE_CASE(AArch64ISD::FADDV_PRED)
+ MAKE_CASE(AArch64ISD::FDIV_PRED)
MAKE_CASE(AArch64ISD::FMA_PRED)
MAKE_CASE(AArch64ISD::FMAXV_PRED)
+ MAKE_CASE(AArch64ISD::FMAXNM_PRED)
MAKE_CASE(AArch64ISD::FMAXNMV_PRED)
MAKE_CASE(AArch64ISD::FMINV_PRED)
+ MAKE_CASE(AArch64ISD::FMINNM_PRED)
MAKE_CASE(AArch64ISD::FMINNMV_PRED)
- MAKE_CASE(AArch64ISD::NOT)
+ MAKE_CASE(AArch64ISD::FMUL_PRED)
+ MAKE_CASE(AArch64ISD::FSUB_PRED)
MAKE_CASE(AArch64ISD::BIT)
MAKE_CASE(AArch64ISD::CBZ)
MAKE_CASE(AArch64ISD::CBNZ)
@@ -1600,8 +1982,15 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::LDP)
MAKE_CASE(AArch64ISD::STP)
MAKE_CASE(AArch64ISD::STNP)
+ MAKE_CASE(AArch64ISD::BITREVERSE_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::BSWAP_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::CTLZ_MERGE_PASSTHRU)
+ MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU)
MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
MAKE_CASE(AArch64ISD::INDEX_VECTOR)
+ MAKE_CASE(AArch64ISD::UABD)
+ MAKE_CASE(AArch64ISD::SABD)
+ MAKE_CASE(AArch64ISD::CALL_RVMARKER)
}
#undef MAKE_CASE
return nullptr;
@@ -1689,6 +2078,7 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
case TargetOpcode::STACKMAP:
case TargetOpcode::PATCHPOINT:
+ case TargetOpcode::STATEPOINT:
return emitPatchPoint(MI, BB);
case AArch64::CATCHRET:
@@ -2514,21 +2904,10 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) {
return std::make_pair(Value, Overflow);
}
-SDValue AArch64TargetLowering::LowerF128Call(SDValue Op, SelectionDAG &DAG,
- RTLIB::Libcall Call) const {
- bool IsStrict = Op->isStrictFPOpcode();
- unsigned Offset = IsStrict ? 1 : 0;
- SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
- SmallVector<SDValue, 2> Ops(Op->op_begin() + Offset, Op->op_end());
- MakeLibCallOptions CallOptions;
- SDValue Result;
- SDLoc dl(Op);
- std::tie(Result, Chain) = makeLibCall(DAG, Call, Op.getValueType(), Ops,
- CallOptions, dl, Chain);
- return IsStrict ? DAG.getMergeValues({Result, Chain}, dl) : Result;
-}
+SDValue AArch64TargetLowering::LowerXOR(SDValue Op, SelectionDAG &DAG) const {
+ if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+ return LowerToScalableOp(Op, DAG);
-static SDValue LowerXOR(SDValue Op, SelectionDAG &DAG) {
SDValue Sel = Op.getOperand(0);
SDValue Other = Op.getOperand(1);
SDLoc dl(Sel);
@@ -2703,16 +3082,18 @@ static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) {
SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
SelectionDAG &DAG) const {
- assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
-
- RTLIB::Libcall LC;
- LC = RTLIB::getFPEXT(Op.getOperand(0).getValueType(), Op.getValueType());
+ if (Op.getValueType().isScalableVector())
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU);
- return LowerF128Call(Op, DAG, LC);
+ assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
+ return SDValue();
}
SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
SelectionDAG &DAG) const {
+ if (Op.getValueType().isScalableVector())
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
+
bool IsStrict = Op->isStrictFPOpcode();
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
EVT SrcVT = SrcVal.getValueType();
@@ -2726,19 +3107,7 @@ SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
return Op;
}
- RTLIB::Libcall LC;
- LC = RTLIB::getFPROUND(SrcVT, Op.getValueType());
-
- // FP_ROUND node has a second operand indicating whether it is known to be
- // precise. That doesn't take part in the LibCall so we can't directly use
- // LowerF128Call.
- MakeLibCallOptions CallOptions;
- SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue();
- SDValue Result;
- SDLoc dl(Op);
- std::tie(Result, Chain) = makeLibCall(DAG, LC, Op.getValueType(), SrcVal,
- CallOptions, dl, Chain);
- return IsStrict ? DAG.getMergeValues({Result, Chain}, dl) : Result;
+ return SDValue();
}
SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
@@ -2748,6 +3117,14 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
// in the cost tables.
EVT InVT = Op.getOperand(0).getValueType();
EVT VT = Op.getValueType();
+
+ if (VT.isScalableVector()) {
+ unsigned Opcode = Op.getOpcode() == ISD::FP_TO_UINT
+ ? AArch64ISD::FCVTZU_MERGE_PASSTHRU
+ : AArch64ISD::FCVTZS_MERGE_PASSTHRU;
+ return LowerToPredicatedOp(Op, DAG, Opcode);
+ }
+
unsigned NumElts = InVT.getVectorNumElements();
// f16 conversions are promoted to f32 when full fp16 is not supported.
@@ -2760,7 +3137,9 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
}
- if (VT.getSizeInBits() < InVT.getSizeInBits()) {
+ uint64_t VTSize = VT.getFixedSizeInBits();
+ uint64_t InVTSize = InVT.getFixedSizeInBits();
+ if (VTSize < InVTSize) {
SDLoc dl(Op);
SDValue Cv =
DAG.getNode(Op.getOpcode(), dl, InVT.changeVectorElementTypeToInteger(),
@@ -2768,7 +3147,7 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
return DAG.getNode(ISD::TRUNCATE, dl, VT, Cv);
}
- if (VT.getSizeInBits() > InVT.getSizeInBits()) {
+ if (VTSize > InVTSize) {
SDLoc dl(Op);
MVT ExtVT =
MVT::getVectorVT(MVT::getFloatingPointVT(VT.getScalarSizeInBits()),
@@ -2803,17 +3182,11 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
return Op;
}
- RTLIB::Libcall LC;
- if (Op.getOpcode() == ISD::FP_TO_SINT ||
- Op.getOpcode() == ISD::STRICT_FP_TO_SINT)
- LC = RTLIB::getFPTOSINT(SrcVal.getValueType(), Op.getValueType());
- else
- LC = RTLIB::getFPTOUINT(SrcVal.getValueType(), Op.getValueType());
-
- return LowerF128Call(Op, DAG, LC);
+ return SDValue();
}
-static SDValue LowerVectorINT_TO_FP(SDValue Op, SelectionDAG &DAG) {
+SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
+ SelectionDAG &DAG) const {
// Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
// Any additional optimization in this function should be recorded
// in the cost tables.
@@ -2821,21 +3194,38 @@ static SDValue LowerVectorINT_TO_FP(SDValue Op, SelectionDAG &DAG) {
SDLoc dl(Op);
SDValue In = Op.getOperand(0);
EVT InVT = In.getValueType();
+ unsigned Opc = Op.getOpcode();
+ bool IsSigned = Opc == ISD::SINT_TO_FP || Opc == ISD::STRICT_SINT_TO_FP;
+
+ if (VT.isScalableVector()) {
+ if (InVT.getVectorElementType() == MVT::i1) {
+ // We can't directly extend an SVE predicate; extend it first.
+ unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+ EVT CastVT = getPromotedVTForPredicate(InVT);
+ In = DAG.getNode(CastOpc, dl, CastVT, In);
+ return DAG.getNode(Opc, dl, VT, In);
+ }
+
+ unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
+ : AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
+ return LowerToPredicatedOp(Op, DAG, Opcode);
+ }
- if (VT.getSizeInBits() < InVT.getSizeInBits()) {
+ uint64_t VTSize = VT.getFixedSizeInBits();
+ uint64_t InVTSize = InVT.getFixedSizeInBits();
+ if (VTSize < InVTSize) {
MVT CastVT =
MVT::getVectorVT(MVT::getFloatingPointVT(InVT.getScalarSizeInBits()),
InVT.getVectorNumElements());
- In = DAG.getNode(Op.getOpcode(), dl, CastVT, In);
+ In = DAG.getNode(Opc, dl, CastVT, In);
return DAG.getNode(ISD::FP_ROUND, dl, VT, In, DAG.getIntPtrConstant(0, dl));
}
- if (VT.getSizeInBits() > InVT.getSizeInBits()) {
- unsigned CastOpc =
- Op.getOpcode() == ISD::SINT_TO_FP ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+ if (VTSize > InVTSize) {
+ unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
EVT CastVT = VT.changeVectorElementTypeToInteger();
In = DAG.getNode(CastOpc, dl, CastVT, In);
- return DAG.getNode(Op.getOpcode(), dl, VT, In);
+ return DAG.getNode(Opc, dl, VT, In);
}
return Op;
@@ -2868,15 +3258,7 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
// fp128.
if (Op.getValueType() != MVT::f128)
return Op;
-
- RTLIB::Libcall LC;
- if (Op.getOpcode() == ISD::SINT_TO_FP ||
- Op.getOpcode() == ISD::STRICT_SINT_TO_FP)
- LC = RTLIB::getSINTTOFP(SrcVal.getValueType(), Op.getValueType());
- else
- LC = RTLIB::getUINTTOFP(SrcVal.getValueType(), Op.getValueType());
-
- return LowerF128Call(Op, DAG, LC);
+ return SDValue();
}
SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op,
@@ -2990,7 +3372,8 @@ static bool isExtendedBUILD_VECTOR(SDNode *N, SelectionDAG &DAG,
}
static SDValue skipExtensionForVectorMULL(SDNode *N, SelectionDAG &DAG) {
- if (N->getOpcode() == ISD::SIGN_EXTEND || N->getOpcode() == ISD::ZERO_EXTEND)
+ if (N->getOpcode() == ISD::SIGN_EXTEND ||
+ N->getOpcode() == ISD::ZERO_EXTEND || N->getOpcode() == ISD::ANY_EXTEND)
return addRequiredExtensionForVectorMULL(N->getOperand(0), DAG,
N->getOperand(0)->getValueType(0),
N->getValueType(0),
@@ -3015,11 +3398,13 @@ static SDValue skipExtensionForVectorMULL(SDNode *N, SelectionDAG &DAG) {
static bool isSignExtended(SDNode *N, SelectionDAG &DAG) {
return N->getOpcode() == ISD::SIGN_EXTEND ||
+ N->getOpcode() == ISD::ANY_EXTEND ||
isExtendedBUILD_VECTOR(N, DAG, true);
}
static bool isZeroExtended(SDNode *N, SelectionDAG &DAG) {
return N->getOpcode() == ISD::ZERO_EXTEND ||
+ N->getOpcode() == ISD::ANY_EXTEND ||
isExtendedBUILD_VECTOR(N, DAG, false);
}
@@ -3068,10 +3453,17 @@ SDValue AArch64TargetLowering::LowerFLT_ROUNDS_(SDValue Op,
return DAG.getMergeValues({AND, Chain}, dl);
}
-static SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) {
+SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+
+ // If SVE is available then i64 vector multiplications can also be made legal.
+ bool OverrideNEON = VT == MVT::v2i64 || VT == MVT::v1i64;
+
+ if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT, OverrideNEON))
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED, OverrideNEON);
+
// Multiplications are only custom-lowered for 128-bit vectors so that
// VMULL can be detected. Otherwise v2i64 multiplications are not legal.
- EVT VT = Op.getValueType();
assert(VT.is128BitVector() && VT.isInteger() &&
"unexpected type for custom-lowering ISD::MUL");
SDNode *N0 = Op.getOperand(0).getNode();
@@ -3230,11 +3622,77 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case Intrinsic::aarch64_sve_ptrue:
return DAG.getNode(AArch64ISD::PTRUE, dl, Op.getValueType(),
Op.getOperand(1));
+ case Intrinsic::aarch64_sve_clz:
+ return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_cnt: {
+ SDValue Data = Op.getOperand(3);
+ // CTPOP only supports integer operands.
+ if (Data.getValueType().isFloatingPoint())
+ Data = DAG.getNode(ISD::BITCAST, dl, Op.getValueType(), Data);
+ return DAG.getNode(AArch64ISD::CTPOP_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Data, Op.getOperand(1));
+ }
case Intrinsic::aarch64_sve_dupq_lane:
return LowerDUPQLane(Op, DAG);
case Intrinsic::aarch64_sve_convert_from_svbool:
return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(),
Op.getOperand(1));
+ case Intrinsic::aarch64_sve_fneg:
+ return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_frintp:
+ return DAG.getNode(AArch64ISD::FCEIL_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_frintm:
+ return DAG.getNode(AArch64ISD::FFLOOR_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_frinti:
+ return DAG.getNode(AArch64ISD::FNEARBYINT_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_frintx:
+ return DAG.getNode(AArch64ISD::FRINT_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_frinta:
+ return DAG.getNode(AArch64ISD::FROUND_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_frintn:
+ return DAG.getNode(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_frintz:
+ return DAG.getNode(AArch64ISD::FTRUNC_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_ucvtf:
+ return DAG.getNode(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU, dl,
+ Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
+ Op.getOperand(1));
+ case Intrinsic::aarch64_sve_scvtf:
+ return DAG.getNode(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU, dl,
+ Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
+ Op.getOperand(1));
+ case Intrinsic::aarch64_sve_fcvtzu:
+ return DAG.getNode(AArch64ISD::FCVTZU_MERGE_PASSTHRU, dl,
+ Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
+ Op.getOperand(1));
+ case Intrinsic::aarch64_sve_fcvtzs:
+ return DAG.getNode(AArch64ISD::FCVTZS_MERGE_PASSTHRU, dl,
+ Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
+ Op.getOperand(1));
+ case Intrinsic::aarch64_sve_fsqrt:
+ return DAG.getNode(AArch64ISD::FSQRT_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_frecpx:
+ return DAG.getNode(AArch64ISD::FRECPX_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_fabs:
+ return DAG.getNode(AArch64ISD::FABS_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_abs:
+ return DAG.getNode(AArch64ISD::ABS_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_neg:
+ return DAG.getNode(AArch64ISD::NEG_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
case Intrinsic::aarch64_sve_convert_to_svbool: {
EVT OutVT = Op.getValueType();
EVT InVT = Op.getOperand(1).getValueType();
@@ -3260,6 +3718,49 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
return DAG.getNode(AArch64ISD::INSR, dl, Op.getValueType(),
Op.getOperand(1), Scalar);
}
+ case Intrinsic::aarch64_sve_rbit:
+ return DAG.getNode(AArch64ISD::BITREVERSE_MERGE_PASSTHRU, dl,
+ Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
+ Op.getOperand(1));
+ case Intrinsic::aarch64_sve_revb:
+ return DAG.getNode(AArch64ISD::BSWAP_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
+ case Intrinsic::aarch64_sve_sxtb:
+ return DAG.getNode(
+ AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3),
+ DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i8)),
+ Op.getOperand(1));
+ case Intrinsic::aarch64_sve_sxth:
+ return DAG.getNode(
+ AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3),
+ DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i16)),
+ Op.getOperand(1));
+ case Intrinsic::aarch64_sve_sxtw:
+ return DAG.getNode(
+ AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3),
+ DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)),
+ Op.getOperand(1));
+ case Intrinsic::aarch64_sve_uxtb:
+ return DAG.getNode(
+ AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3),
+ DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i8)),
+ Op.getOperand(1));
+ case Intrinsic::aarch64_sve_uxth:
+ return DAG.getNode(
+ AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3),
+ DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i16)),
+ Op.getOperand(1));
+ case Intrinsic::aarch64_sve_uxtw:
+ return DAG.getNode(
+ AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
+ Op.getOperand(2), Op.getOperand(3),
+ DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)),
+ Op.getOperand(1));
case Intrinsic::localaddress: {
const auto &MF = DAG.getMachineFunction();
@@ -3299,19 +3800,291 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
}
case Intrinsic::aarch64_neon_srhadd:
- case Intrinsic::aarch64_neon_urhadd: {
- bool IsSignedAdd = IntNo == Intrinsic::aarch64_neon_srhadd;
- unsigned Opcode = IsSignedAdd ? AArch64ISD::SRHADD : AArch64ISD::URHADD;
+ case Intrinsic::aarch64_neon_urhadd:
+ case Intrinsic::aarch64_neon_shadd:
+ case Intrinsic::aarch64_neon_uhadd: {
+ bool IsSignedAdd = (IntNo == Intrinsic::aarch64_neon_srhadd ||
+ IntNo == Intrinsic::aarch64_neon_shadd);
+ bool IsRoundingAdd = (IntNo == Intrinsic::aarch64_neon_srhadd ||
+ IntNo == Intrinsic::aarch64_neon_urhadd);
+ unsigned Opcode =
+ IsSignedAdd ? (IsRoundingAdd ? AArch64ISD::SRHADD : AArch64ISD::SHADD)
+ : (IsRoundingAdd ? AArch64ISD::URHADD : AArch64ISD::UHADD);
return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2));
}
+
+ case Intrinsic::aarch64_neon_uabd: {
+ return DAG.getNode(AArch64ISD::UABD, dl, Op.getValueType(),
+ Op.getOperand(1), Op.getOperand(2));
+ }
+ case Intrinsic::aarch64_neon_sabd: {
+ return DAG.getNode(AArch64ISD::SABD, dl, Op.getValueType(),
+ Op.getOperand(1), Op.getOperand(2));
+ }
}
}
+bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
+ if (VT.getVectorElementType() == MVT::i32 &&
+ VT.getVectorElementCount().getKnownMinValue() >= 4)
+ return true;
+
+ return false;
+}
+
bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
return ExtVal.getValueType().isScalableVector();
}
+unsigned getGatherVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
+ std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = {
+ {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false),
+ AArch64ISD::GLD1_MERGE_ZERO},
+ {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true),
+ AArch64ISD::GLD1_UXTW_MERGE_ZERO},
+ {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false),
+ AArch64ISD::GLD1_MERGE_ZERO},
+ {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true),
+ AArch64ISD::GLD1_SXTW_MERGE_ZERO},
+ {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false),
+ AArch64ISD::GLD1_SCALED_MERGE_ZERO},
+ {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true),
+ AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO},
+ {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false),
+ AArch64ISD::GLD1_SCALED_MERGE_ZERO},
+ {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true),
+ AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO},
+ };
+ auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend);
+ return AddrModes.find(Key)->second;
+}
+
+unsigned getScatterVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
+ std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = {
+ {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false),
+ AArch64ISD::SST1_PRED},
+ {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true),
+ AArch64ISD::SST1_UXTW_PRED},
+ {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false),
+ AArch64ISD::SST1_PRED},
+ {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true),
+ AArch64ISD::SST1_SXTW_PRED},
+ {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false),
+ AArch64ISD::SST1_SCALED_PRED},
+ {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true),
+ AArch64ISD::SST1_UXTW_SCALED_PRED},
+ {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false),
+ AArch64ISD::SST1_SCALED_PRED},
+ {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true),
+ AArch64ISD::SST1_SXTW_SCALED_PRED},
+ };
+ auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend);
+ return AddrModes.find(Key)->second;
+}
+
+unsigned getSignExtendedGatherOpcode(unsigned Opcode) {
+ switch (Opcode) {
+ default:
+ llvm_unreachable("unimplemented opcode");
+ return Opcode;
+ case AArch64ISD::GLD1_MERGE_ZERO:
+ return AArch64ISD::GLD1S_MERGE_ZERO;
+ case AArch64ISD::GLD1_IMM_MERGE_ZERO:
+ return AArch64ISD::GLD1S_IMM_MERGE_ZERO;
+ case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
+ return AArch64ISD::GLD1S_UXTW_MERGE_ZERO;
+ case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
+ return AArch64ISD::GLD1S_SXTW_MERGE_ZERO;
+ case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
+ return AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
+ case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
+ return AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO;
+ case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
+ return AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO;
+ }
+}
+
+bool getGatherScatterIndexIsExtended(SDValue Index) {
+ unsigned Opcode = Index.getOpcode();
+ if (Opcode == ISD::SIGN_EXTEND_INREG)
+ return true;
+
+ if (Opcode == ISD::AND) {
+ SDValue Splat = Index.getOperand(1);
+ if (Splat.getOpcode() != ISD::SPLAT_VECTOR)
+ return false;
+ ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(Splat.getOperand(0));
+ if (!Mask || Mask->getZExtValue() != 0xFFFFFFFF)
+ return false;
+ return true;
+ }
+
+ return false;
+}
+
+// If the base pointer of a masked gather or scatter is null, we
+// may be able to swap BasePtr & Index and use the vector + register
+// or vector + immediate addressing mode, e.g.
+// VECTOR + REGISTER:
+// getelementptr nullptr, <vscale x N x T> (splat(%offset)) + %indices)
+// -> getelementptr %offset, <vscale x N x T> %indices
+// VECTOR + IMMEDIATE:
+// getelementptr nullptr, <vscale x N x T> (splat(#x)) + %indices)
+// -> getelementptr #x, <vscale x N x T> %indices
+void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, EVT MemVT,
+ unsigned &Opcode, bool IsGather,
+ SelectionDAG &DAG) {
+ if (!isNullConstant(BasePtr))
+ return;
+
+ ConstantSDNode *Offset = nullptr;
+ if (Index.getOpcode() == ISD::ADD)
+ if (auto SplatVal = DAG.getSplatValue(Index.getOperand(1))) {
+ if (isa<ConstantSDNode>(SplatVal))
+ Offset = cast<ConstantSDNode>(SplatVal);
+ else {
+ BasePtr = SplatVal;
+ Index = Index->getOperand(0);
+ return;
+ }
+ }
+
+ unsigned NewOp =
+ IsGather ? AArch64ISD::GLD1_IMM_MERGE_ZERO : AArch64ISD::SST1_IMM_PRED;
+
+ if (!Offset) {
+ std::swap(BasePtr, Index);
+ Opcode = NewOp;
+ return;
+ }
+
+ uint64_t OffsetVal = Offset->getZExtValue();
+ unsigned ScalarSizeInBytes = MemVT.getScalarSizeInBits() / 8;
+ auto ConstOffset = DAG.getConstant(OffsetVal, SDLoc(Index), MVT::i64);
+
+ if (OffsetVal % ScalarSizeInBytes || OffsetVal / ScalarSizeInBytes > 31) {
+ // Index is out of range for the immediate addressing mode
+ BasePtr = ConstOffset;
+ Index = Index->getOperand(0);
+ return;
+ }
+
+ // Immediate is in range
+ Opcode = NewOp;
+ BasePtr = Index->getOperand(0);
+ Index = ConstOffset;
+}
+
+SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op);
+ assert(MGT && "Can only custom lower gather load nodes");
+
+ SDValue Index = MGT->getIndex();
+ SDValue Chain = MGT->getChain();
+ SDValue PassThru = MGT->getPassThru();
+ SDValue Mask = MGT->getMask();
+ SDValue BasePtr = MGT->getBasePtr();
+ ISD::LoadExtType ExtTy = MGT->getExtensionType();
+
+ ISD::MemIndexType IndexType = MGT->getIndexType();
+ bool IsScaled =
+ IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
+ bool IsSigned =
+ IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
+ bool IdxNeedsExtend =
+ getGatherScatterIndexIsExtended(Index) ||
+ Index.getSimpleValueType().getVectorElementType() == MVT::i32;
+ bool ResNeedsSignExtend = ExtTy == ISD::EXTLOAD || ExtTy == ISD::SEXTLOAD;
+
+ EVT VT = PassThru.getSimpleValueType();
+ EVT MemVT = MGT->getMemoryVT();
+ SDValue InputVT = DAG.getValueType(MemVT);
+
+ if (VT.getVectorElementType() == MVT::bf16 &&
+ !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
+ return SDValue();
+
+ // Handle FP data by using an integer gather and casting the result.
+ if (VT.isFloatingPoint()) {
+ EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount());
+ PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG);
+ InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
+ }
+
+ SDVTList VTs = DAG.getVTList(PassThru.getSimpleValueType(), MVT::Other);
+
+ if (getGatherScatterIndexIsExtended(Index))
+ Index = Index.getOperand(0);
+
+ unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend);
+ selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
+ /*isGather=*/true, DAG);
+
+ if (ResNeedsSignExtend)
+ Opcode = getSignExtendedGatherOpcode(Opcode);
+
+ SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT, PassThru};
+ SDValue Gather = DAG.getNode(Opcode, DL, VTs, Ops);
+
+ if (VT.isFloatingPoint()) {
+ SDValue Cast = getSVESafeBitCast(VT, Gather, DAG);
+ return DAG.getMergeValues({Cast, Gather}, DL);
+ }
+
+ return Gather;
+}
+
+SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Op);
+ assert(MSC && "Can only custom lower scatter store nodes");
+
+ SDValue Index = MSC->getIndex();
+ SDValue Chain = MSC->getChain();
+ SDValue StoreVal = MSC->getValue();
+ SDValue Mask = MSC->getMask();
+ SDValue BasePtr = MSC->getBasePtr();
+
+ ISD::MemIndexType IndexType = MSC->getIndexType();
+ bool IsScaled =
+ IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
+ bool IsSigned =
+ IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
+ bool NeedsExtend =
+ getGatherScatterIndexIsExtended(Index) ||
+ Index.getSimpleValueType().getVectorElementType() == MVT::i32;
+
+ EVT VT = StoreVal.getSimpleValueType();
+ SDVTList VTs = DAG.getVTList(MVT::Other);
+ EVT MemVT = MSC->getMemoryVT();
+ SDValue InputVT = DAG.getValueType(MemVT);
+
+ if (VT.getVectorElementType() == MVT::bf16 &&
+ !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
+ return SDValue();
+
+ // Handle FP data by casting the data so an integer scatter can be used.
+ if (VT.isFloatingPoint()) {
+ EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
+ StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
+ InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
+ }
+
+ if (getGatherScatterIndexIsExtended(Index))
+ Index = Index.getOperand(0);
+
+ unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend);
+ selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
+ /*isGather=*/false, DAG);
+
+ SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
+ return DAG.getNode(Opcode, DL, VTs, Ops);
+}
+
// Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16.
static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST,
EVT VT, EVT MemVT,
@@ -3377,8 +4150,9 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
// 256 bit non-temporal stores can be lowered to STNP. Do this as part of
// the custom lowering, as there are no un-paired non-temporal stores and
// legalization will break up 256 bit inputs.
+ ElementCount EC = MemVT.getVectorElementCount();
if (StoreNode->isNonTemporal() && MemVT.getSizeInBits() == 256u &&
- MemVT.getVectorElementCount().Min % 2u == 0 &&
+ EC.isKnownEven() &&
((MemVT.getScalarSizeInBits() == 8u ||
MemVT.getScalarSizeInBits() == 16u ||
MemVT.getScalarSizeInBits() == 32u ||
@@ -3387,11 +4161,11 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
DAG.getNode(ISD::EXTRACT_SUBVECTOR, Dl,
MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
StoreNode->getValue(), DAG.getConstant(0, Dl, MVT::i64));
- SDValue Hi = DAG.getNode(
- ISD::EXTRACT_SUBVECTOR, Dl,
- MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
- StoreNode->getValue(),
- DAG.getConstant(MemVT.getVectorElementCount().Min / 2, Dl, MVT::i64));
+ SDValue Hi =
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, Dl,
+ MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
+ StoreNode->getValue(),
+ DAG.getConstant(EC.getKnownMinValue() / 2, Dl, MVT::i64));
SDValue Result = DAG.getMemIntrinsicNode(
AArch64ISD::STNP, Dl, DAG.getVTList(MVT::Other),
{StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()},
@@ -3416,6 +4190,25 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
return SDValue();
}
+// Generate SUBS and CSEL for integer abs.
+SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
+ MVT VT = Op.getSimpleValueType();
+
+ if (VT.isVector())
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABS_MERGE_PASSTHRU);
+
+ SDLoc DL(Op);
+ SDValue Neg = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
+ Op.getOperand(0));
+ // Generate SUBS & CSEL.
+ SDValue Cmp =
+ DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::i32),
+ Op.getOperand(0), DAG.getConstant(0, DL, VT));
+ return DAG.getNode(AArch64ISD::CSEL, DL, VT, Op.getOperand(0), Neg,
+ DAG.getConstant(AArch64CC::PL, DL, MVT::i32),
+ Cmp.getValue(1));
+}
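
// A minimal scalar sketch (hypothetical helper, not LLVM API) of the SUBS +
// CSEL sequence built above: the flags of SUBS(X, 0) drive a CSEL that keeps
// X when it is positive or zero (PL) and otherwise selects the negation.
static long long absViaCsel(long long X) {
  long long Neg = 0 - X;   // the ISD::SUB node built above
  return X >= 0 ? X : Neg; // CSEL X, Neg, PL on the SUBS(X, 0) flags
}
// e.g. absViaCsel(-7) == 7 and absViaCsel(7) == 7.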
+
SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
SelectionDAG &DAG) const {
LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -3468,17 +4261,35 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::UMULO:
return LowerXALUO(Op, DAG);
case ISD::FADD:
- if (useSVEForFixedLengthVectorVT(Op.getValueType()))
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED);
- return LowerF128Call(Op, DAG, RTLIB::ADD_F128);
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED);
case ISD::FSUB:
- return LowerF128Call(Op, DAG, RTLIB::SUB_F128);
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
case ISD::FMUL:
- return LowerF128Call(Op, DAG, RTLIB::MUL_F128);
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
case ISD::FMA:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
case ISD::FDIV:
- return LowerF128Call(Op, DAG, RTLIB::DIV_F128);
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
+ case ISD::FNEG:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
+ case ISD::FCEIL:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FCEIL_MERGE_PASSTHRU);
+ case ISD::FFLOOR:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FFLOOR_MERGE_PASSTHRU);
+ case ISD::FNEARBYINT:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEARBYINT_MERGE_PASSTHRU);
+ case ISD::FRINT:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FRINT_MERGE_PASSTHRU);
+ case ISD::FROUND:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUND_MERGE_PASSTHRU);
+ case ISD::FROUNDEVEN:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU);
+ case ISD::FTRUNC:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU);
+ case ISD::FSQRT:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU);
+ case ISD::FABS:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FABS_MERGE_PASSTHRU);
case ISD::FP_ROUND:
case ISD::STRICT_FP_ROUND:
return LowerFP_ROUND(Op, DAG);
@@ -3492,6 +4303,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerRETURNADDR(Op, DAG);
case ISD::ADDROFRETURNADDR:
return LowerADDROFRETURNADDR(Op, DAG);
+ case ISD::CONCAT_VECTORS:
+ return LowerCONCAT_VECTORS(Op, DAG);
case ISD::INSERT_VECTOR_ELT:
return LowerINSERT_VECTOR_ELT(Op, DAG);
case ISD::EXTRACT_VECTOR_ELT:
@@ -3507,17 +4320,20 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::INSERT_SUBVECTOR:
return LowerINSERT_SUBVECTOR(Op, DAG);
case ISD::SDIV:
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::SDIV_PRED);
case ISD::UDIV:
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::UDIV_PRED);
+ return LowerDIV(Op, DAG);
case ISD::SMIN:
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMIN_MERGE_OP1);
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMIN_PRED,
+ /*OverrideNEON=*/true);
case ISD::UMIN:
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMIN_MERGE_OP1);
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMIN_PRED,
+ /*OverrideNEON=*/true);
case ISD::SMAX:
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMAX_MERGE_OP1);
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMAX_PRED,
+ /*OverrideNEON=*/true);
case ISD::UMAX:
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMAX_MERGE_OP1);
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMAX_PRED,
+ /*OverrideNEON=*/true);
case ISD::SRA:
case ISD::SRL:
case ISD::SHL:
@@ -3557,11 +4373,21 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerINTRINSIC_WO_CHAIN(Op, DAG);
case ISD::STORE:
return LowerSTORE(Op, DAG);
+ case ISD::MGATHER:
+ return LowerMGATHER(Op, DAG);
+ case ISD::MSCATTER:
+ return LowerMSCATTER(Op, DAG);
+ case ISD::VECREDUCE_SEQ_FADD:
+ return LowerVECREDUCE_SEQ_FADD(Op, DAG);
case ISD::VECREDUCE_ADD:
+ case ISD::VECREDUCE_AND:
+ case ISD::VECREDUCE_OR:
+ case ISD::VECREDUCE_XOR:
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
+ case ISD::VECREDUCE_FADD:
case ISD::VECREDUCE_FMAX:
case ISD::VECREDUCE_FMIN:
return LowerVECREDUCE(Op, DAG);
@@ -3573,6 +4399,21 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerDYNAMIC_STACKALLOC(Op, DAG);
case ISD::VSCALE:
return LowerVSCALE(Op, DAG);
+ case ISD::ANY_EXTEND:
+ case ISD::SIGN_EXTEND:
+ case ISD::ZERO_EXTEND:
+ return LowerFixedLengthVectorIntExtendToSVE(Op, DAG);
+ case ISD::SIGN_EXTEND_INREG: {
+    // Only custom lower when ExtraVT has a legal byte-based element type.
+ EVT ExtraVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
+ EVT ExtraEltVT = ExtraVT.getVectorElementType();
+ if ((ExtraEltVT != MVT::i8) && (ExtraEltVT != MVT::i16) &&
+ (ExtraEltVT != MVT::i32) && (ExtraEltVT != MVT::i64))
+ return SDValue();
+
+ return LowerToPredicatedOp(Op, DAG,
+ AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU);
+ }
case ISD::TRUNCATE:
return LowerTRUNCATE(Op, DAG);
case ISD::LOAD:
@@ -3580,31 +4421,49 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerFixedLengthVectorLoadToSVE(Op, DAG);
llvm_unreachable("Unexpected request to lower ISD::LOAD");
case ISD::ADD:
- if (useSVEForFixedLengthVectorVT(Op.getValueType()))
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::ADD_PRED);
- llvm_unreachable("Unexpected request to lower ISD::ADD");
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::ADD_PRED);
+ case ISD::AND:
+ return LowerToScalableOp(Op, DAG);
+ case ISD::SUB:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::SUB_PRED);
+ case ISD::FMAXNUM:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED);
+ case ISD::FMINNUM:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED);
+ case ISD::VSELECT:
+ return LowerFixedLengthVectorSelectToSVE(Op, DAG);
+ case ISD::ABS:
+ return LowerABS(Op, DAG);
+ case ISD::BITREVERSE:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::BITREVERSE_MERGE_PASSTHRU,
+ /*OverrideNEON=*/true);
+ case ISD::BSWAP:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::BSWAP_MERGE_PASSTHRU);
+ case ISD::CTLZ:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::CTLZ_MERGE_PASSTHRU,
+ /*OverrideNEON=*/true);
+ case ISD::CTTZ:
+ return LowerCTTZ(Op, DAG);
}
}
-bool AArch64TargetLowering::useSVEForFixedLengthVectors() const {
- // Prefer NEON unless larger SVE registers are available.
- return Subtarget->hasSVE() && Subtarget->getMinSVEVectorSizeInBits() >= 256;
+bool AArch64TargetLowering::mergeStoresAfterLegalization(EVT VT) const {
+ return !Subtarget->useSVEForFixedLengthVectors();
}
-bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(EVT VT) const {
- if (!useSVEForFixedLengthVectors())
+bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(
+ EVT VT, bool OverrideNEON) const {
+ if (!Subtarget->useSVEForFixedLengthVectors())
return false;
if (!VT.isFixedLengthVector())
return false;
- // Fixed length predicates should be promoted to i8.
- // NOTE: This is consistent with how NEON (and thus 64/128bit vectors) work.
- if (VT.getVectorElementType() == MVT::i1)
- return false;
-
// Don't use SVE for vectors we cannot scalarize if required.
switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
+ // Fixed length predicates should be promoted to i8.
+ // NOTE: This is consistent with how NEON (and thus 64/128bit vectors) work.
+ case MVT::i1:
default:
return false;
case MVT::i8:
@@ -3617,12 +4476,16 @@ bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(EVT VT) const {
break;
}
+  // All SVE implementations support NEON-sized vectors.
+ if (OverrideNEON && (VT.is128BitVector() || VT.is64BitVector()))
+ return true;
+
// Ensure NEON MVTs only belong to a single register class.
- if (VT.getSizeInBits() <= 128)
+ if (VT.getFixedSizeInBits() <= 128)
return false;
// Don't use SVE for types that don't fit.
- if (VT.getSizeInBits() > Subtarget->getMinSVEVectorSizeInBits())
+ if (VT.getFixedSizeInBits() > Subtarget->getMinSVEVectorSizeInBits())
return false;
// TODO: Perhaps an artificial restriction, but worth having whilst getting
@@ -3721,10 +4584,10 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
assert(!Res && "Call operand has unhandled type");
(void)Res;
}
- assert(ArgLocs.size() == Ins.size());
SmallVector<SDValue, 16> ArgValues;
- for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
- CCValAssign &VA = ArgLocs[i];
+ unsigned ExtraArgLocs = 0;
+ for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
+ CCValAssign &VA = ArgLocs[i - ExtraArgLocs];
if (Ins[i].Flags.isByVal()) {
// Byval is used for HFAs in the PCS, but the system should work in a
@@ -3852,16 +4715,44 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
if (VA.getLocInfo() == CCValAssign::Indirect) {
assert(VA.getValVT().isScalableVector() &&
"Only scalable vectors can be passed indirectly");
- // If value is passed via pointer - do a load.
- ArgValue =
- DAG.getLoad(VA.getValVT(), DL, Chain, ArgValue, MachinePointerInfo());
- }
- if (Subtarget->isTargetILP32() && Ins[i].Flags.isPointer())
- ArgValue = DAG.getNode(ISD::AssertZext, DL, ArgValue.getValueType(),
- ArgValue, DAG.getValueType(MVT::i32));
- InVals.push_back(ArgValue);
+ uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinSize();
+ unsigned NumParts = 1;
+ if (Ins[i].Flags.isInConsecutiveRegs()) {
+ assert(!Ins[i].Flags.isInConsecutiveRegsLast());
+ while (!Ins[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
+ ++NumParts;
+ }
+
+ MVT PartLoad = VA.getValVT();
+ SDValue Ptr = ArgValue;
+
+ // Ensure we generate all loads for each tuple part, whilst updating the
+ // pointer after each load correctly using vscale.
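+      // For example, a pair of nxv4i32 parts passed indirectly becomes two
+      // loads, with the pointer advanced by vscale * 16 bytes in between.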
+ while (NumParts > 0) {
+ ArgValue = DAG.getLoad(PartLoad, DL, Chain, Ptr, MachinePointerInfo());
+ InVals.push_back(ArgValue);
+ NumParts--;
+ if (NumParts > 0) {
+ SDValue BytesIncrement = DAG.getVScale(
+ DL, Ptr.getValueType(),
+ APInt(Ptr.getValueSizeInBits().getFixedSize(), PartSize));
+ SDNodeFlags Flags;
+ Flags.setNoUnsignedWrap(true);
+ Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
+ BytesIncrement, Flags);
+ ExtraArgLocs++;
+ i++;
+ }
+ }
+ } else {
+ if (Subtarget->isTargetILP32() && Ins[i].Flags.isPointer())
+ ArgValue = DAG.getNode(ISD::AssertZext, DL, ArgValue.getValueType(),
+ ArgValue, DAG.getValueType(MVT::i32));
+ InVals.push_back(ArgValue);
+ }
}
+ assert((ArgLocs.size() + ExtraArgLocs) == Ins.size());
// varargs
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
@@ -4036,9 +4927,7 @@ SDValue AArch64TargetLowering::LowerCallResult(
const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
SDValue ThisVal) const {
- CCAssignFn *RetCC = CallConv == CallingConv::WebKit_JS
- ? RetCC_AArch64_WebKit_JS
- : RetCC_AArch64_AAPCS;
+ CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
// Assign locations to each value returned by this call.
SmallVector<CCValAssign, 16> RVLocs;
DenseMap<unsigned, SDValue> CopiedRegs;
@@ -4104,6 +4993,7 @@ static bool canGuaranteeTCO(CallingConv::ID CC) {
static bool mayTailCallThisCC(CallingConv::ID CC) {
switch (CC) {
case CallingConv::C:
+ case CallingConv::AArch64_SVE_VectorCall:
case CallingConv::PreserveMost:
case CallingConv::Swift:
return true;
@@ -4123,6 +5013,15 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
MachineFunction &MF = DAG.getMachineFunction();
const Function &CallerF = MF.getFunction();
CallingConv::ID CallerCC = CallerF.getCallingConv();
+
+ // If this function uses the C calling convention but has an SVE signature,
+ // then it preserves more registers and should assume the SVE_VectorCall CC.
+ // The check for matching callee-saved regs will determine whether it is
+ // eligible for TCO.
+ if (CallerCC == CallingConv::C &&
+ AArch64RegisterInfo::hasSVEArgsOrReturn(&MF))
+ CallerCC = CallingConv::AArch64_SVE_VectorCall;
+
bool CCMatch = CallerCC == CalleeCC;
// When using the Windows calling convention on a non-windows OS, we want
@@ -4310,6 +5209,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
bool IsSibCall = false;
+ // Check callee args/returns for SVE registers and set calling convention
+ // accordingly.
+ if (CallConv == CallingConv::C) {
+ bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){
+ return Out.VT.isScalableVector();
+ });
+ bool CalleeInSVE = any_of(Ins, [](ISD::InputArg &In){
+ return In.VT.isScalableVector();
+ });
+
+ if (CalleeInSVE || CalleeOutSVE)
+ CallConv = CallingConv::AArch64_SVE_VectorCall;
+ }
+
if (IsTailCall) {
// Check if it's really possible to do a tail call.
IsTailCall = isEligibleForTailCallOptimization(
@@ -4339,6 +5252,10 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
for (unsigned i = 0; i != NumArgs; ++i) {
MVT ArgVT = Outs[i].VT;
+ if (!Outs[i].IsFixed && ArgVT.isScalableVector())
+ report_fatal_error("Passing SVE types to variadic functions is "
+ "currently not supported");
+
ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
CCAssignFn *AssignFn = CCAssignFnForCall(CallConv,
/*IsVarArg=*/ !Outs[i].IsFixed);
@@ -4433,8 +5350,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
}
// Walk the register/memloc assignments, inserting copies/loads.
- for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
- CCValAssign &VA = ArgLocs[i];
+ unsigned ExtraArgLocs = 0;
+ for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
+ CCValAssign &VA = ArgLocs[i - ExtraArgLocs];
SDValue Arg = OutVals[i];
ISD::ArgFlagsTy Flags = Outs[i].Flags;
@@ -4476,18 +5394,49 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
case CCValAssign::Indirect:
assert(VA.getValVT().isScalableVector() &&
"Only scalable vectors can be passed indirectly");
+
+ uint64_t StoreSize = VA.getValVT().getStoreSize().getKnownMinSize();
+ uint64_t PartSize = StoreSize;
+ unsigned NumParts = 1;
+ if (Outs[i].Flags.isInConsecutiveRegs()) {
+ assert(!Outs[i].Flags.isInConsecutiveRegsLast());
+ while (!Outs[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
+ ++NumParts;
+ StoreSize *= NumParts;
+ }
+
MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
- int FI = MFI.CreateStackObject(
- VA.getValVT().getStoreSize().getKnownMinSize(), Alignment, false);
- MFI.setStackID(FI, TargetStackID::SVEVector);
+ int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
+ MFI.setStackID(FI, TargetStackID::ScalableVector);
- SDValue SpillSlot = DAG.getFrameIndex(
+ MachinePointerInfo MPI =
+ MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI);
+ SDValue Ptr = DAG.getFrameIndex(
FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
- Chain = DAG.getStore(
- Chain, DL, Arg, SpillSlot,
- MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI));
+ SDValue SpillSlot = Ptr;
+
+ // Ensure we generate all stores for each tuple part, whilst updating the
+ // pointer after each store correctly using vscale.
+ while (NumParts) {
+ Chain = DAG.getStore(Chain, DL, OutVals[i], Ptr, MPI);
+ NumParts--;
+ if (NumParts > 0) {
+ SDValue BytesIncrement = DAG.getVScale(
+ DL, Ptr.getValueType(),
+ APInt(Ptr.getValueSizeInBits().getFixedSize(), PartSize));
+ SDNodeFlags Flags;
+ Flags.setNoUnsignedWrap(true);
+
+ MPI = MachinePointerInfo(MPI.getAddrSpace());
+ Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
+ BytesIncrement, Flags);
+ ExtraArgLocs++;
+ i++;
+ }
+ }
+
Arg = SpillSlot;
break;
}
@@ -4507,20 +5456,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// take care of putting the two halves in the right place but we have to
// combine them.
SDValue &Bits =
- std::find_if(RegsToPass.begin(), RegsToPass.end(),
- [=](const std::pair<unsigned, SDValue> &Elt) {
- return Elt.first == VA.getLocReg();
- })
+ llvm::find_if(RegsToPass,
+ [=](const std::pair<unsigned, SDValue> &Elt) {
+ return Elt.first == VA.getLocReg();
+ })
->second;
Bits = DAG.getNode(ISD::OR, DL, Bits.getValueType(), Bits, Arg);
// Call site info is used for function's parameter entry value
// tracking. For now we track only simple cases when parameter
// is transferred through whole register.
- CSInfo.erase(std::remove_if(CSInfo.begin(), CSInfo.end(),
- [&VA](MachineFunction::ArgRegPair ArgReg) {
- return ArgReg.Reg == VA.getLocReg();
- }),
- CSInfo.end());
+ llvm::erase_if(CSInfo, [&VA](MachineFunction::ArgRegPair ArgReg) {
+ return ArgReg.Reg == VA.getLocReg();
+ });
} else {
RegsToPass.emplace_back(VA.getLocReg(), Arg);
RegsUsed.insert(VA.getLocReg());
@@ -4539,7 +5486,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
uint32_t BEAlign = 0;
unsigned OpSize;
if (VA.getLocInfo() == CCValAssign::Indirect)
- OpSize = VA.getLocVT().getSizeInBits();
+ OpSize = VA.getLocVT().getFixedSizeInBits();
else
OpSize = Flags.isByVal() ? Flags.getByValSize() * 8
: VA.getValVT().getSizeInBits();
@@ -4663,20 +5610,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
Ops.push_back(DAG.getRegister(RegToPass.first,
RegToPass.second.getValueType()));
- // Check callee args/returns for SVE registers and set calling convention
- // accordingly.
- if (CallConv == CallingConv::C) {
- bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){
- return Out.VT.isScalableVector();
- });
- bool CalleeInSVE = any_of(Ins, [](ISD::InputArg &In){
- return In.VT.isScalableVector();
- });
-
- if (CalleeInSVE || CalleeOutSVE)
- CallConv = CallingConv::AArch64_SVE_VectorCall;
- }
-
// Add a register mask operand representing the call-preserved registers.
const uint32_t *Mask;
const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
@@ -4713,8 +5646,17 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
return Ret;
}
+ unsigned CallOpc = AArch64ISD::CALL;
+  // Calls marked with "rv_marker" are special. They should be expanded to the
+  // call, directly followed by a special marker sequence. Use the
+  // CALL_RVMARKER opcode to do that.
+ if (CLI.CB && CLI.CB->hasRetAttr("rv_marker")) {
+ assert(!IsTailCall && "tail calls cannot be marked with rv_marker");
+ CallOpc = AArch64ISD::CALL_RVMARKER;
+ }
+
// Returns a chain and a flag for retval copy to use.
- Chain = DAG.getNode(AArch64ISD::CALL, DL, NodeTys, Ops);
+ Chain = DAG.getNode(CallOpc, DL, NodeTys, Ops);
DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
InFlag = Chain.getValue(1);
DAG.addCallSiteInfo(Chain.getNode(), std::move(CSInfo));
@@ -4738,9 +5680,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
bool AArch64TargetLowering::CanLowerReturn(
CallingConv::ID CallConv, MachineFunction &MF, bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext &Context) const {
- CCAssignFn *RetCC = CallConv == CallingConv::WebKit_JS
- ? RetCC_AArch64_WebKit_JS
- : RetCC_AArch64_AAPCS;
+ CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
SmallVector<CCValAssign, 16> RVLocs;
CCState CCInfo(CallConv, isVarArg, MF, RVLocs, Context);
return CCInfo.CheckReturn(Outs, RetCC);
@@ -4755,9 +5695,7 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
auto &MF = DAG.getMachineFunction();
auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
- CCAssignFn *RetCC = CallConv == CallingConv::WebKit_JS
- ? RetCC_AArch64_WebKit_JS
- : RetCC_AArch64_AAPCS;
+ CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
SmallVector<CCValAssign, 16> RVLocs;
CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), RVLocs,
*DAG.getContext());
@@ -4802,11 +5740,9 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
if (RegsUsed.count(VA.getLocReg())) {
SDValue &Bits =
- std::find_if(RetVals.begin(), RetVals.end(),
- [=](const std::pair<unsigned, SDValue> &Elt) {
- return Elt.first == VA.getLocReg();
- })
- ->second;
+ llvm::find_if(RetVals, [=](const std::pair<unsigned, SDValue> &Elt) {
+ return Elt.first == VA.getLocReg();
+ })->second;
Bits = DAG.getNode(ISD::OR, DL, Bits.getValueType(), Bits, Arg);
} else {
RetVals.emplace_back(VA.getLocReg(), Arg);
@@ -5026,7 +5962,7 @@ AArch64TargetLowering::LowerDarwinGlobalTLSAddress(SDValue Op,
SDValue FuncTLVGet = DAG.getLoad(
PtrMemVT, DL, Chain, DescAddr,
MachinePointerInfo::getGOT(DAG.getMachineFunction()),
- /* Alignment = */ PtrMemVT.getSizeInBits() / 8,
+ Align(PtrMemVT.getSizeInBits() / 8),
MachineMemOperand::MOInvariant | MachineMemOperand::MODereferenceable);
Chain = FuncTLVGet.getValue(1);
@@ -5341,6 +6277,22 @@ SDValue AArch64TargetLowering::LowerGlobalTLSAddress(SDValue Op,
llvm_unreachable("Unexpected platform trying to use TLS");
}
+// Looks through \param Val to determine the bit that can be used to
+// check the sign of the value. It returns the unextended value and
+// the sign bit position.
+std::pair<SDValue, uint64_t> lookThroughSignExtension(SDValue Val) {
+ if (Val.getOpcode() == ISD::SIGN_EXTEND_INREG)
+ return {Val.getOperand(0),
+ cast<VTSDNode>(Val.getOperand(1))->getVT().getFixedSizeInBits() -
+ 1};
+
+ if (Val.getOpcode() == ISD::SIGN_EXTEND)
+ return {Val.getOperand(0),
+ Val.getOperand(0)->getValueType(0).getFixedSizeInBits() - 1};
+
+ return {Val, Val.getValueSizeInBits() - 1};
+}
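
// A minimal standalone sketch (hypothetical helper, not LLVM API): for a
// value that was sign-extended from i8, the sign can be read from bit 7 of
// the unextended value, which is exactly the pair returned above and what
// lets LowerBR_CC emit TBNZ/TBZ on that bit directly.
static bool signBitOfSExtI8(unsigned long long Unextended) {
  return (Unextended >> 7) & 1; // TB(N)Z ..., #7 instead of bit 63
}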
+
SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const {
SDValue Chain = Op.getOperand(0);
ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(1))->get();
@@ -5435,9 +6387,10 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const {
// Don't combine AND since emitComparison converts the AND to an ANDS
// (a.k.a. TST) and the test in the test bit and branch instruction
// becomes redundant. This would also increase register pressure.
- uint64_t Mask = LHS.getValueSizeInBits() - 1;
+ uint64_t SignBitPos;
+ std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS);
return DAG.getNode(AArch64ISD::TBNZ, dl, MVT::Other, Chain, LHS,
- DAG.getConstant(Mask, dl, MVT::i64), Dest);
+ DAG.getConstant(SignBitPos, dl, MVT::i64), Dest);
}
}
if (RHSC && RHSC->getSExtValue() == -1 && CC == ISD::SETGT &&
@@ -5445,9 +6398,10 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const {
// Don't combine AND since emitComparison converts the AND to an ANDS
// (a.k.a. TST) and the test in the test bit and branch instruction
// becomes redundant. This would also increase register pressure.
- uint64_t Mask = LHS.getValueSizeInBits() - 1;
+ uint64_t SignBitPos;
+ std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS);
return DAG.getNode(AArch64ISD::TBZ, dl, MVT::Other, Chain, LHS,
- DAG.getConstant(Mask, dl, MVT::i64), Dest);
+ DAG.getConstant(SignBitPos, dl, MVT::i64), Dest);
}
SDValue CCVal;
@@ -5594,6 +6548,9 @@ SDValue AArch64TargetLowering::LowerCTPOP(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i128, UaddLV);
}
+ if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT))
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::CTPOP_MERGE_PASSTHRU);
+
assert((VT == MVT::v1i64 || VT == MVT::v2i64 || VT == MVT::v2i32 ||
VT == MVT::v4i32 || VT == MVT::v4i16 || VT == MVT::v8i16) &&
"Unexpected type for custom ctpop lowering");
@@ -5617,6 +6574,16 @@ SDValue AArch64TargetLowering::LowerCTPOP(SDValue Op, SelectionDAG &DAG) const {
return Val;
}
+SDValue AArch64TargetLowering::LowerCTTZ(SDValue Op, SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+ assert(VT.isScalableVector() ||
+ useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true));
+
+ SDLoc DL(Op);
+ SDValue RBIT = DAG.getNode(ISD::BITREVERSE, DL, VT, Op.getOperand(0));
+ return DAG.getNode(ISD::CTLZ, DL, VT, RBIT);
+}
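
// A minimal scalar sketch (hypothetical helper, not LLVM API, assuming 32-bit
// unsigned) of the identity used above: the trailing-zero count of X equals
// the leading-zero count of the bit-reversed X, which is what the
// BITREVERSE + CTLZ expansion relies on.
static unsigned cttzViaBitReverse(unsigned X) {
  unsigned Rev = 0;
  for (unsigned I = 0; I < 32; ++I) // scalar stand-in for RBIT
    Rev |= ((X >> I) & 1u) << (31 - I);
  unsigned LeadingZeros = 0;        // scalar stand-in for CLZ
  while (LeadingZeros < 32 && !((Rev >> (31 - LeadingZeros)) & 1u))
    ++LeadingZeros;
  return LeadingZeros;              // e.g. cttzViaBitReverse(8) == 3
}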
+
SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
if (Op.getValueType().isVector())
@@ -5774,7 +6741,8 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
// instead of a CSEL in that case.
if (TrueVal == ~FalseVal) {
Opcode = AArch64ISD::CSINV;
- } else if (TrueVal == -FalseVal) {
+ } else if (FalseVal > std::numeric_limits<int64_t>::min() &&
+ TrueVal == -FalseVal) {
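+      // The lower-bound check keeps -FalseVal from being computed for
+      // INT64_MIN, whose negation is not representable in int64_t.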
Opcode = AArch64ISD::CSNEG;
} else if (TVal.getValueType() == MVT::i32) {
// If our operands are only 32-bit wide, make sure we use 32-bit
@@ -5974,6 +6942,9 @@ SDValue AArch64TargetLowering::LowerBR_JT(SDValue Op,
SDValue Entry = Op.getOperand(2);
int JTI = cast<JumpTableSDNode>(JT.getNode())->getIndex();
+ auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
+ AFI->setJumpTableEntryInfo(JTI, 4, nullptr);
+
SDNode *Dest =
DAG.getMachineNode(AArch64::JumpTableDest32, DL, MVT::i64, MVT::i64, JT,
Entry, DAG.getTargetJumpTable(JTI, MVT::i32));
@@ -6040,11 +7011,13 @@ SDValue AArch64TargetLowering::LowerWin64_VASTART(SDValue Op,
}
SDValue AArch64TargetLowering::LowerAAPCS_VASTART(SDValue Op,
- SelectionDAG &DAG) const {
+ SelectionDAG &DAG) const {
// The layout of the va_list struct is specified in the AArch64 Procedure Call
// Standard, section B.3.
MachineFunction &MF = DAG.getMachineFunction();
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
+ unsigned PtrSize = Subtarget->isTargetILP32() ? 4 : 8;
+ auto PtrMemVT = getPointerMemTy(DAG.getDataLayout());
auto PtrVT = getPointerTy(DAG.getDataLayout());
SDLoc DL(Op);
@@ -6054,56 +7027,64 @@ SDValue AArch64TargetLowering::LowerAAPCS_VASTART(SDValue Op,
SmallVector<SDValue, 4> MemOps;
// void *__stack at offset 0
+ unsigned Offset = 0;
SDValue Stack = DAG.getFrameIndex(FuncInfo->getVarArgsStackIndex(), PtrVT);
+ Stack = DAG.getZExtOrTrunc(Stack, DL, PtrMemVT);
MemOps.push_back(DAG.getStore(Chain, DL, Stack, VAList,
- MachinePointerInfo(SV), /* Alignment = */ 8));
+ MachinePointerInfo(SV), Align(PtrSize)));
- // void *__gr_top at offset 8
+ // void *__gr_top at offset 8 (4 on ILP32)
+ Offset += PtrSize;
int GPRSize = FuncInfo->getVarArgsGPRSize();
if (GPRSize > 0) {
SDValue GRTop, GRTopAddr;
- GRTopAddr =
- DAG.getNode(ISD::ADD, DL, PtrVT, VAList, DAG.getConstant(8, DL, PtrVT));
+ GRTopAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
+ DAG.getConstant(Offset, DL, PtrVT));
GRTop = DAG.getFrameIndex(FuncInfo->getVarArgsGPRIndex(), PtrVT);
GRTop = DAG.getNode(ISD::ADD, DL, PtrVT, GRTop,
DAG.getConstant(GPRSize, DL, PtrVT));
+ GRTop = DAG.getZExtOrTrunc(GRTop, DL, PtrMemVT);
MemOps.push_back(DAG.getStore(Chain, DL, GRTop, GRTopAddr,
- MachinePointerInfo(SV, 8),
- /* Alignment = */ 8));
+ MachinePointerInfo(SV, Offset),
+ Align(PtrSize)));
}
- // void *__vr_top at offset 16
+ // void *__vr_top at offset 16 (8 on ILP32)
+ Offset += PtrSize;
int FPRSize = FuncInfo->getVarArgsFPRSize();
if (FPRSize > 0) {
SDValue VRTop, VRTopAddr;
VRTopAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
- DAG.getConstant(16, DL, PtrVT));
+ DAG.getConstant(Offset, DL, PtrVT));
VRTop = DAG.getFrameIndex(FuncInfo->getVarArgsFPRIndex(), PtrVT);
VRTop = DAG.getNode(ISD::ADD, DL, PtrVT, VRTop,
DAG.getConstant(FPRSize, DL, PtrVT));
+ VRTop = DAG.getZExtOrTrunc(VRTop, DL, PtrMemVT);
MemOps.push_back(DAG.getStore(Chain, DL, VRTop, VRTopAddr,
- MachinePointerInfo(SV, 16),
- /* Alignment = */ 8));
- }
-
- // int __gr_offs at offset 24
- SDValue GROffsAddr =
- DAG.getNode(ISD::ADD, DL, PtrVT, VAList, DAG.getConstant(24, DL, PtrVT));
- MemOps.push_back(DAG.getStore(
- Chain, DL, DAG.getConstant(-GPRSize, DL, MVT::i32), GROffsAddr,
- MachinePointerInfo(SV, 24), /* Alignment = */ 4));
-
- // int __vr_offs at offset 28
- SDValue VROffsAddr =
- DAG.getNode(ISD::ADD, DL, PtrVT, VAList, DAG.getConstant(28, DL, PtrVT));
- MemOps.push_back(DAG.getStore(
- Chain, DL, DAG.getConstant(-FPRSize, DL, MVT::i32), VROffsAddr,
- MachinePointerInfo(SV, 28), /* Alignment = */ 4));
+ MachinePointerInfo(SV, Offset),
+ Align(PtrSize)));
+ }
+
+ // int __gr_offs at offset 24 (12 on ILP32)
+ Offset += PtrSize;
+ SDValue GROffsAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
+ DAG.getConstant(Offset, DL, PtrVT));
+ MemOps.push_back(
+ DAG.getStore(Chain, DL, DAG.getConstant(-GPRSize, DL, MVT::i32),
+ GROffsAddr, MachinePointerInfo(SV, Offset), Align(4)));
+
+ // int __vr_offs at offset 28 (16 on ILP32)
+ Offset += 4;
+ SDValue VROffsAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
+ DAG.getConstant(Offset, DL, PtrVT));
+ MemOps.push_back(
+ DAG.getStore(Chain, DL, DAG.getConstant(-FPRSize, DL, MVT::i32),
+ VROffsAddr, MachinePointerInfo(SV, Offset), Align(4)));
return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOps);
}
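
// For reference, the va_list layout being populated above (AAPCS64,
// section B.3); LP64 offsets first, ILP32 in parentheses:
//   void *__stack;   // 0  (0)
//   void *__gr_top;  // 8  (4)
//   void *__vr_top;  // 16 (8)
//   int   __gr_offs; // 24 (12)
//   int   __vr_offs; // 28 (16)
// Total size is 32 bytes (20 on ILP32), cf. LowerVACOPY below.
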
@@ -6126,8 +7107,10 @@ SDValue AArch64TargetLowering::LowerVACOPY(SDValue Op,
// pointer.
SDLoc DL(Op);
unsigned PtrSize = Subtarget->isTargetILP32() ? 4 : 8;
- unsigned VaListSize = (Subtarget->isTargetDarwin() ||
- Subtarget->isTargetWindows()) ? PtrSize : 32;
+ unsigned VaListSize =
+ (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows())
+ ? PtrSize
+ : Subtarget->isTargetILP32() ? 20 : 32;
const Value *DestSV = cast<SrcValueSDNode>(Op.getOperand(3))->getValue();
const Value *SrcSV = cast<SrcValueSDNode>(Op.getOperand(4))->getValue();
@@ -6155,6 +7138,10 @@ SDValue AArch64TargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
Chain = VAList.getValue(1);
VAList = DAG.getZExtOrTrunc(VAList, DL, PtrVT);
+ if (VT.isScalableVector())
+ report_fatal_error("Passing SVE types to variadic functions is "
+ "currently not supported");
+
if (Align && *Align > MinSlotSize) {
VAList = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
DAG.getConstant(Align->value() - 1, DL, PtrVT));
@@ -6276,17 +7263,34 @@ SDValue AArch64TargetLowering::LowerRETURNADDR(SDValue Op,
EVT VT = Op.getValueType();
SDLoc DL(Op);
unsigned Depth = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
+ SDValue ReturnAddress;
if (Depth) {
SDValue FrameAddr = LowerFRAMEADDR(Op, DAG);
SDValue Offset = DAG.getConstant(8, DL, getPointerTy(DAG.getDataLayout()));
- return DAG.getLoad(VT, DL, DAG.getEntryNode(),
- DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset),
- MachinePointerInfo());
+ ReturnAddress = DAG.getLoad(
+ VT, DL, DAG.getEntryNode(),
+ DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset), MachinePointerInfo());
+ } else {
+ // Return LR, which contains the return address. Mark it an implicit
+ // live-in.
+ unsigned Reg = MF.addLiveIn(AArch64::LR, &AArch64::GPR64RegClass);
+ ReturnAddress = DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, VT);
+ }
+
+  // The XPACLRI instruction assembles to a hint-space instruction before
+  // Armv8.3-A, so it can be used safely on any pre-Armv8.3-A architecture. On
+  // Armv8.3-A and later, XPACI is available, so use that instead.
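+  // Stripping the authentication code means the value returned for
+  // llvm.returnaddress is a canonical code address even when the return
+  // address has been signed.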
+ SDNode *St;
+ if (Subtarget->hasPAuth()) {
+ St = DAG.getMachineNode(AArch64::XPACI, DL, VT, ReturnAddress);
+ } else {
+ // XPACLRI operates on LR therefore we must move the operand accordingly.
+ SDValue Chain =
+ DAG.getCopyToReg(DAG.getEntryNode(), DL, AArch64::LR, ReturnAddress);
+ St = DAG.getMachineNode(AArch64::XPACLRI, DL, VT, Chain);
}
-
- // Return LR, which contains the return address. Mark it an implicit live-in.
- unsigned Reg = MF.addLiveIn(AArch64::LR, &AArch64::GPR64RegClass);
- return DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, VT);
+ return SDValue(St, 0);
}
/// LowerShiftRightParts - Lower SRA_PARTS, which returns two
@@ -6467,6 +7471,22 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
return SDValue();
}
+SDValue
+AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
+ const DenormalMode &Mode) const {
+ SDLoc DL(Op);
+ EVT VT = Op.getValueType();
+ EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+ SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
+ return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
+}
+
+SDValue
+AArch64TargetLowering::getSqrtResultForDenormInput(SDValue Op,
+ SelectionDAG &DAG) const {
+ return Op;
+}
+
SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
SelectionDAG &DAG, int Enabled,
int &ExtraSteps,
@@ -6490,17 +7510,8 @@ SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags);
Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
}
- if (!Reciprocal) {
- EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
- VT);
- SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
- SDValue Eq = DAG.getSetCC(DL, CCVT, Operand, FPZero, ISD::SETEQ);
-
+ if (!Reciprocal)
Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags);
- // Correct the result if the operand is 0.0.
- Estimate = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL,
- VT, Eq, Operand, Estimate);
- }
ExtraSteps = 0;
return Estimate;
@@ -6676,23 +7687,30 @@ AArch64TargetLowering::getRegForInlineAsmConstraint(
if (Constraint.size() == 1) {
switch (Constraint[0]) {
case 'r':
- if (VT.getSizeInBits() == 64)
+ if (VT.isScalableVector())
+ return std::make_pair(0U, nullptr);
+ if (VT.getFixedSizeInBits() == 64)
return std::make_pair(0U, &AArch64::GPR64commonRegClass);
return std::make_pair(0U, &AArch64::GPR32commonRegClass);
- case 'w':
+ case 'w': {
if (!Subtarget->hasFPARMv8())
break;
- if (VT.isScalableVector())
- return std::make_pair(0U, &AArch64::ZPRRegClass);
- if (VT.getSizeInBits() == 16)
+ if (VT.isScalableVector()) {
+ if (VT.getVectorElementType() != MVT::i1)
+ return std::make_pair(0U, &AArch64::ZPRRegClass);
+ return std::make_pair(0U, nullptr);
+ }
+ uint64_t VTSize = VT.getFixedSizeInBits();
+ if (VTSize == 16)
return std::make_pair(0U, &AArch64::FPR16RegClass);
- if (VT.getSizeInBits() == 32)
+ if (VTSize == 32)
return std::make_pair(0U, &AArch64::FPR32RegClass);
- if (VT.getSizeInBits() == 64)
+ if (VTSize == 64)
return std::make_pair(0U, &AArch64::FPR64RegClass);
- if (VT.getSizeInBits() == 128)
+ if (VTSize == 128)
return std::make_pair(0U, &AArch64::FPR128RegClass);
break;
+ }
// The instructions that this constraint is designed for can
// only take 128-bit registers so just use that regclass.
case 'x':
@@ -6713,10 +7731,11 @@ AArch64TargetLowering::getRegForInlineAsmConstraint(
} else {
PredicateConstraint PC = parsePredicateConstraint(Constraint);
if (PC != PredicateConstraint::Invalid) {
- assert(VT.isScalableVector());
+ if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1)
+ return std::make_pair(0U, nullptr);
bool restricted = (PC == PredicateConstraint::Upl);
return restricted ? std::make_pair(0U, &AArch64::PPR_3bRegClass)
- : std::make_pair(0U, &AArch64::PPRRegClass);
+ : std::make_pair(0U, &AArch64::PPRRegClass);
}
}
if (StringRef("{cc}").equals_lower(Constraint))
@@ -6955,6 +7974,8 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op,
LLVM_DEBUG(dbgs() << "AArch64TargetLowering::ReconstructShuffle\n");
SDLoc dl(Op);
EVT VT = Op.getValueType();
+ assert(!VT.isScalableVector() &&
+ "Scalable vectors cannot be used with ISD::BUILD_VECTOR");
unsigned NumElts = VT.getVectorNumElements();
struct ShuffleSourceInfo {
@@ -7025,8 +8046,9 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op,
}
}
unsigned ResMultiplier =
- VT.getScalarSizeInBits() / SmallestEltTy.getSizeInBits();
- NumElts = VT.getSizeInBits() / SmallestEltTy.getSizeInBits();
+ VT.getScalarSizeInBits() / SmallestEltTy.getFixedSizeInBits();
+ uint64_t VTSize = VT.getFixedSizeInBits();
+ NumElts = VTSize / SmallestEltTy.getFixedSizeInBits();
EVT ShuffleVT = EVT::getVectorVT(*DAG.getContext(), SmallestEltTy, NumElts);
// If the source vector is too wide or too narrow, we may nevertheless be able
@@ -7035,17 +8057,18 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op,
for (auto &Src : Sources) {
EVT SrcVT = Src.ShuffleVec.getValueType();
- if (SrcVT.getSizeInBits() == VT.getSizeInBits())
+ uint64_t SrcVTSize = SrcVT.getFixedSizeInBits();
+ if (SrcVTSize == VTSize)
continue;
// This stage of the search produces a source with the same element type as
// the original, but with a total width matching the BUILD_VECTOR output.
EVT EltVT = SrcVT.getVectorElementType();
- unsigned NumSrcElts = VT.getSizeInBits() / EltVT.getSizeInBits();
+ unsigned NumSrcElts = VTSize / EltVT.getFixedSizeInBits();
EVT DestVT = EVT::getVectorVT(*DAG.getContext(), EltVT, NumSrcElts);
- if (SrcVT.getSizeInBits() < VT.getSizeInBits()) {
- assert(2 * SrcVT.getSizeInBits() == VT.getSizeInBits());
+ if (SrcVTSize < VTSize) {
+ assert(2 * SrcVTSize == VTSize);
// We can pad out the smaller vector for free, so if it's part of a
// shuffle...
Src.ShuffleVec =
@@ -7054,7 +8077,11 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op,
continue;
}
- assert(SrcVT.getSizeInBits() == 2 * VT.getSizeInBits());
+ if (SrcVTSize != 2 * VTSize) {
+ LLVM_DEBUG(
+ dbgs() << "Reshuffle failed: result vector too small to extract\n");
+ return SDValue();
+ }
if (Src.MaxElt - Src.MinElt >= NumSrcElts) {
LLVM_DEBUG(
@@ -7083,6 +8110,13 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op,
DAG.getConstant(NumSrcElts, dl, MVT::i64));
unsigned Imm = Src.MinElt * getExtFactor(VEXTSrc1);
+ if (!SrcVT.is64BitVector()) {
+ LLVM_DEBUG(
+ dbgs() << "Reshuffle failed: don't know how to lower AArch64ISD::EXT "
+ "for SVE vectors.");
+ return SDValue();
+ }
+
Src.ShuffleVec = DAG.getNode(AArch64ISD::EXT, dl, DestVT, VEXTSrc1,
VEXTSrc2,
DAG.getConstant(Imm, dl, MVT::i32));
@@ -7099,7 +8133,8 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op,
continue;
assert(ShuffleVT.getVectorElementType() == SmallestEltTy);
Src.ShuffleVec = DAG.getNode(ISD::BITCAST, dl, ShuffleVT, Src.ShuffleVec);
- Src.WindowScale = SrcEltTy.getSizeInBits() / SmallestEltTy.getSizeInBits();
+ Src.WindowScale =
+ SrcEltTy.getFixedSizeInBits() / SmallestEltTy.getFixedSizeInBits();
Src.WindowBase *= Src.WindowScale;
}
@@ -7123,8 +8158,8 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op,
// trunc. So only std::min(SrcBits, DestBits) actually get defined in this
// segment.
EVT OrigEltTy = Entry.getOperand(0).getValueType().getVectorElementType();
- int BitsDefined =
- std::min(OrigEltTy.getSizeInBits(), VT.getScalarSizeInBits());
+ int BitsDefined = std::min(OrigEltTy.getScalarSizeInBits(),
+ VT.getScalarSizeInBits());
int LanesDefined = BitsDefined / BitsPerShuffleLane;
// This source is expected to fill ResMultiplier lanes of the final shuffle,
@@ -7188,6 +8223,81 @@ static bool isSingletonEXTMask(ArrayRef<int> M, EVT VT, unsigned &Imm) {
return true;
}
+/// Check if a vector shuffle corresponds to a DUP instruction with a larger
+/// element width than the vector lane type. If that is the case, the function
+/// returns true and writes the value of the DUP instruction lane operand into
+/// DupLaneOp.
+static bool isWideDUPMask(ArrayRef<int> M, EVT VT, unsigned BlockSize,
+ unsigned &DupLaneOp) {
+ assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64) &&
+ "Only possible block sizes for wide DUP are: 16, 32, 64");
+
+ if (BlockSize <= VT.getScalarSizeInBits())
+ return false;
+ if (BlockSize % VT.getScalarSizeInBits() != 0)
+ return false;
+ if (VT.getSizeInBits() % BlockSize != 0)
+ return false;
+
+ size_t SingleVecNumElements = VT.getVectorNumElements();
+ size_t NumEltsPerBlock = BlockSize / VT.getScalarSizeInBits();
+ size_t NumBlocks = VT.getSizeInBits() / BlockSize;
+
+ // We are looking for masks like
+ // [0, 1, 0, 1] or [2, 3, 2, 3] or [4, 5, 6, 7, 4, 5, 6, 7] where any element
+ // might be replaced by 'undefined'. BlockIndices will eventually contain
+ // lane indices of the duplicated block (i.e. [0, 1], [2, 3] and [4, 5, 6, 7]
+ // for the above examples)
+ SmallVector<int, 8> BlockElts(NumEltsPerBlock, -1);
+ for (size_t BlockIndex = 0; BlockIndex < NumBlocks; BlockIndex++)
+ for (size_t I = 0; I < NumEltsPerBlock; I++) {
+ int Elt = M[BlockIndex * NumEltsPerBlock + I];
+ if (Elt < 0)
+ continue;
+ // For now we don't support shuffles that use the second operand
+ if ((unsigned)Elt >= SingleVecNumElements)
+ return false;
+ if (BlockElts[I] < 0)
+ BlockElts[I] = Elt;
+ else if (BlockElts[I] != Elt)
+ return false;
+ }
+
+  // We found a candidate block (possibly with some undefs). It must be a
+  // sequence of consecutive integers starting with a value divisible by
+  // NumEltsPerBlock, with some values possibly replaced by undefs.
+
+ // Find first non-undef element
+ auto FirstRealEltIter = find_if(BlockElts, [](int Elt) { return Elt >= 0; });
+ assert(FirstRealEltIter != BlockElts.end() &&
+ "Shuffle with all-undefs must have been caught by previous cases, "
+ "e.g. isSplat()");
+ if (FirstRealEltIter == BlockElts.end()) {
+ DupLaneOp = 0;
+ return true;
+ }
+
+ // Index of FirstRealElt in BlockElts
+ size_t FirstRealIndex = FirstRealEltIter - BlockElts.begin();
+
+ if ((unsigned)*FirstRealEltIter < FirstRealIndex)
+ return false;
+ // BlockElts[0] must have the following value if it isn't undef:
+ size_t Elt0 = *FirstRealEltIter - FirstRealIndex;
+
+ // Check the first element
+ if (Elt0 % NumEltsPerBlock != 0)
+ return false;
+ // Check that the sequence indeed consists of consecutive integers (modulo
+ // undefs)
+ for (size_t I = 0; I < NumEltsPerBlock; I++)
+ if (BlockElts[I] >= 0 && (unsigned)BlockElts[I] != Elt0 + I)
+ return false;
+
+ DupLaneOp = Elt0 / NumEltsPerBlock;
+ return true;
+}
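
// Worked example for the matcher above: a v4i32 shuffle mask [2, 3, 2, 3]
// checked with BlockSize == 64 gives NumEltsPerBlock == 2 and the candidate
// block [2, 3]; the block starts at a multiple of 2 and is consecutive, so
// the function returns true with DupLaneOp == 1, i.e. a DUPLANE64 of lane 1
// after bitcasting the source to v2i64 (see LowerVECTOR_SHUFFLE below).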
+
// check if an EXT instruction can handle the shuffle mask when the
// vector sources of the shuffle are different.
static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT,
@@ -7621,6 +8731,60 @@ static unsigned getDUPLANEOp(EVT EltType) {
llvm_unreachable("Invalid vector element type?");
}
+static SDValue constructDup(SDValue V, int Lane, SDLoc dl, EVT VT,
+ unsigned Opcode, SelectionDAG &DAG) {
+ // Try to eliminate a bitcasted extract subvector before a DUPLANE.
+ auto getScaledOffsetDup = [](SDValue BitCast, int &LaneC, MVT &CastVT) {
+ // Match: dup (bitcast (extract_subv X, C)), LaneC
+ if (BitCast.getOpcode() != ISD::BITCAST ||
+ BitCast.getOperand(0).getOpcode() != ISD::EXTRACT_SUBVECTOR)
+ return false;
+
+ // The extract index must align in the destination type. That may not
+ // happen if the bitcast is from narrow to wide type.
+ SDValue Extract = BitCast.getOperand(0);
+ unsigned ExtIdx = Extract.getConstantOperandVal(1);
+ unsigned SrcEltBitWidth = Extract.getScalarValueSizeInBits();
+ unsigned ExtIdxInBits = ExtIdx * SrcEltBitWidth;
+ unsigned CastedEltBitWidth = BitCast.getScalarValueSizeInBits();
+ if (ExtIdxInBits % CastedEltBitWidth != 0)
+ return false;
+
+ // Update the lane value by offsetting with the scaled extract index.
+ LaneC += ExtIdxInBits / CastedEltBitWidth;
+
+ // Determine the casted vector type of the wide vector input.
+ // dup (bitcast (extract_subv X, C)), LaneC --> dup (bitcast X), LaneC'
+ // Examples:
+ // dup (bitcast (extract_subv v2f64 X, 1) to v2f32), 1 --> dup v4f32 X, 3
+ // dup (bitcast (extract_subv v16i8 X, 8) to v4i16), 1 --> dup v8i16 X, 5
+ unsigned SrcVecNumElts =
+ Extract.getOperand(0).getValueSizeInBits() / CastedEltBitWidth;
+ CastVT = MVT::getVectorVT(BitCast.getSimpleValueType().getScalarType(),
+ SrcVecNumElts);
+ return true;
+ };
+ MVT CastVT;
+ if (getScaledOffsetDup(V, Lane, CastVT)) {
+ V = DAG.getBitcast(CastVT, V.getOperand(0).getOperand(0));
+ } else if (V.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+ // The lane is incremented by the index of the extract.
+ // Example: dup v2f32 (extract v4f32 X, 2), 1 --> dup v4f32 X, 3
+ Lane += V.getConstantOperandVal(1);
+ V = V.getOperand(0);
+ } else if (V.getOpcode() == ISD::CONCAT_VECTORS) {
+ // The lane is decremented if we are splatting from the 2nd operand.
+ // Example: dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1
+ unsigned Idx = Lane >= (int)VT.getVectorNumElements() / 2;
+ Lane -= Idx * VT.getVectorNumElements() / 2;
+ V = WidenVector(V.getOperand(Idx), DAG);
+ } else if (VT.getSizeInBits() == 64) {
+ // Widen the operand to 128-bit register with undef.
+ V = WidenVector(V, DAG);
+ }
+ return DAG.getNode(Opcode, dl, VT, V, DAG.getConstant(Lane, dl, MVT::i64));
+}
+
SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
SelectionDAG &DAG) const {
SDLoc dl(Op);
@@ -7654,57 +8818,26 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
// Otherwise, duplicate from the lane of the input vector.
unsigned Opcode = getDUPLANEOp(V1.getValueType().getVectorElementType());
-
- // Try to eliminate a bitcasted extract subvector before a DUPLANE.
- auto getScaledOffsetDup = [](SDValue BitCast, int &LaneC, MVT &CastVT) {
- // Match: dup (bitcast (extract_subv X, C)), LaneC
- if (BitCast.getOpcode() != ISD::BITCAST ||
- BitCast.getOperand(0).getOpcode() != ISD::EXTRACT_SUBVECTOR)
- return false;
-
- // The extract index must align in the destination type. That may not
- // happen if the bitcast is from narrow to wide type.
- SDValue Extract = BitCast.getOperand(0);
- unsigned ExtIdx = Extract.getConstantOperandVal(1);
- unsigned SrcEltBitWidth = Extract.getScalarValueSizeInBits();
- unsigned ExtIdxInBits = ExtIdx * SrcEltBitWidth;
- unsigned CastedEltBitWidth = BitCast.getScalarValueSizeInBits();
- if (ExtIdxInBits % CastedEltBitWidth != 0)
- return false;
-
- // Update the lane value by offsetting with the scaled extract index.
- LaneC += ExtIdxInBits / CastedEltBitWidth;
-
- // Determine the casted vector type of the wide vector input.
- // dup (bitcast (extract_subv X, C)), LaneC --> dup (bitcast X), LaneC'
- // Examples:
- // dup (bitcast (extract_subv v2f64 X, 1) to v2f32), 1 --> dup v4f32 X, 3
- // dup (bitcast (extract_subv v16i8 X, 8) to v4i16), 1 --> dup v8i16 X, 5
- unsigned SrcVecNumElts =
- Extract.getOperand(0).getValueSizeInBits() / CastedEltBitWidth;
- CastVT = MVT::getVectorVT(BitCast.getSimpleValueType().getScalarType(),
- SrcVecNumElts);
- return true;
- };
- MVT CastVT;
- if (getScaledOffsetDup(V1, Lane, CastVT)) {
- V1 = DAG.getBitcast(CastVT, V1.getOperand(0).getOperand(0));
- } else if (V1.getOpcode() == ISD::EXTRACT_SUBVECTOR) {
- // The lane is incremented by the index of the extract.
- // Example: dup v2f32 (extract v4f32 X, 2), 1 --> dup v4f32 X, 3
- Lane += V1.getConstantOperandVal(1);
- V1 = V1.getOperand(0);
- } else if (V1.getOpcode() == ISD::CONCAT_VECTORS) {
- // The lane is decremented if we are splatting from the 2nd operand.
- // Example: dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1
- unsigned Idx = Lane >= (int)VT.getVectorNumElements() / 2;
- Lane -= Idx * VT.getVectorNumElements() / 2;
- V1 = WidenVector(V1.getOperand(Idx), DAG);
- } else if (VT.getSizeInBits() == 64) {
- // Widen the operand to 128-bit register with undef.
- V1 = WidenVector(V1, DAG);
- }
- return DAG.getNode(Opcode, dl, VT, V1, DAG.getConstant(Lane, dl, MVT::i64));
+ return constructDup(V1, Lane, dl, VT, Opcode, DAG);
+ }
+
+ // Check if the mask matches a DUP for a wider element
+ for (unsigned LaneSize : {64U, 32U, 16U}) {
+ unsigned Lane = 0;
+ if (isWideDUPMask(ShuffleMask, VT, LaneSize, Lane)) {
+ unsigned Opcode = LaneSize == 64 ? AArch64ISD::DUPLANE64
+ : LaneSize == 32 ? AArch64ISD::DUPLANE32
+ : AArch64ISD::DUPLANE16;
+      // Cast V1 to an integer vector with the required lane size
+ MVT NewEltTy = MVT::getIntegerVT(LaneSize);
+ unsigned NewEltCount = VT.getSizeInBits() / LaneSize;
+ MVT NewVecTy = MVT::getVectorVT(NewEltTy, NewEltCount);
+ V1 = DAG.getBitcast(NewVecTy, V1);
+      // Construct the DUP instruction
+ V1 = constructDup(V1, Lane, dl, NewVecTy, Opcode, DAG);
+ // Cast back to the original type
+ return DAG.getBitcast(VT, V1);
+ }
}
if (isREVMask(ShuffleMask, VT, 64))
@@ -7775,7 +8908,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
EVT ScalarVT = VT.getVectorElementType();
- if (ScalarVT.getSizeInBits() < 32 && ScalarVT.isInteger())
+ if (ScalarVT.getFixedSizeInBits() < 32 && ScalarVT.isInteger())
ScalarVT = MVT::i32;
return DAG.getNode(
@@ -7814,9 +8947,11 @@ SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op,
SDLoc dl(Op);
EVT VT = Op.getValueType();
EVT ElemVT = VT.getScalarType();
-
SDValue SplatVal = Op.getOperand(0);
+ if (useSVEForFixedLengthVectorVT(VT))
+ return LowerToScalableOp(Op, DAG);
+
// Extend input splat value where needed to fit into a GPR (32b or 64b only)
// FPRs don't have this restriction.
switch (ElemVT.getSimpleVT().SimpleTy) {
@@ -8246,6 +9381,9 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
SelectionDAG &DAG) const {
+ if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+ return LowerToScalableOp(Op, DAG);
+
// Attempt to form a vector S[LR]I from (or (and X, C1), (lsl Y, C2))
if (SDValue Res = tryLowerToSLI(Op.getNode(), DAG))
return Res;
@@ -8404,14 +9542,18 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
bool isConstant = true;
bool AllLanesExtractElt = true;
unsigned NumConstantLanes = 0;
+ unsigned NumDifferentLanes = 0;
+ unsigned NumUndefLanes = 0;
SDValue Value;
SDValue ConstantValue;
for (unsigned i = 0; i < NumElts; ++i) {
SDValue V = Op.getOperand(i);
if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
AllLanesExtractElt = false;
- if (V.isUndef())
+ if (V.isUndef()) {
+ ++NumUndefLanes;
continue;
+ }
if (i > 0)
isOnlyLowElement = false;
if (!isa<ConstantFPSDNode>(V) && !isa<ConstantSDNode>(V))
@@ -8427,8 +9569,10 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
if (!Value.getNode())
Value = V;
- else if (V != Value)
+ else if (V != Value) {
usesOnlyOneValue = false;
+ ++NumDifferentLanes;
+ }
}
if (!Value.getNode()) {
@@ -8554,11 +9698,20 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
}
}
+ // If we need to insert a small number of different non-constant elements and
+ // the vector width is sufficiently large, prefer using DUP with the common
+ // value and INSERT_VECTOR_ELT for the different lanes. If DUP is preferred,
+ // skip the constant lane handling below.
+ bool PreferDUPAndInsert =
+ !isConstant && NumDifferentLanes >= 1 &&
+ NumDifferentLanes < ((NumElts - NumUndefLanes) / 2) &&
+ NumDifferentLanes >= NumConstantLanes;
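+  // For example, a v8i16 build_vector with seven lanes equal to one
+  // non-constant value and a single differing lane takes this path: one DUP
+  // plus one INSERT_VECTOR_ELT instead of eight lane insertions.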
+
// If there was only one constant value used and for more than one lane,
// start by splatting that value, then replace the non-constant lanes. This
// is better than the default, which will perform a separate initialization
// for each lane.
- if (NumConstantLanes > 0 && usesOnlyOneConstantValue) {
+ if (!PreferDUPAndInsert && NumConstantLanes > 0 && usesOnlyOneConstantValue) {
// Firstly, try to materialize the splat constant.
SDValue Vec = DAG.getSplatBuildVector(VT, dl, ConstantValue),
Val = ConstantBuildVector(Vec, DAG);
@@ -8594,6 +9747,22 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
return shuffle;
}
+ if (PreferDUPAndInsert) {
+ // First, build a constant vector with the common element.
+ SmallVector<SDValue, 8> Ops;
+ for (unsigned I = 0; I < NumElts; ++I)
+ Ops.push_back(Value);
+ SDValue NewVector = LowerBUILD_VECTOR(DAG.getBuildVector(VT, dl, Ops), DAG);
+ // Next, insert the elements that do not match the common value.
+ for (unsigned I = 0; I < NumElts; ++I)
+ if (Op.getOperand(I) != Value)
+ NewVector =
+ DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, NewVector,
+ Op.getOperand(I), DAG.getConstant(I, dl, MVT::i64));
+
+ return NewVector;
+ }
+
// If all else fails, just use a sequence of INSERT_VECTOR_ELT when we
// know the default expansion would otherwise fall back on something even
// worse. For a vector with one or two non-undef values, that's
@@ -8642,6 +9811,18 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
return SDValue();
}
+SDValue AArch64TargetLowering::LowerCONCAT_VECTORS(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(Op.getValueType().isScalableVector() &&
+ isTypeLegal(Op.getValueType()) &&
+ "Expected legal scalable vector type!");
+
+ if (isTypeLegal(Op.getOperand(0).getValueType()) && Op.getNumOperands() == 2)
+ return Op;
+
+ return SDValue();
+}
+
SDValue AArch64TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
SelectionDAG &DAG) const {
assert(Op.getOpcode() == ISD::INSERT_VECTOR_ELT && "Unknown opcode!");
@@ -8737,7 +9918,8 @@ SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op,
// If this is extracting the upper 64-bits of a 128-bit vector, we match
// that directly.
- if (Size == 64 && Idx * InVT.getScalarSizeInBits() == 64)
+ if (Size == 64 && Idx * InVT.getScalarSizeInBits() == 64 &&
+ InVT.getSizeInBits() == 128)
return Op;
return SDValue();
@@ -8751,9 +9933,34 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
EVT InVT = Op.getOperand(1).getValueType();
unsigned Idx = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue();
- // We don't have any patterns for scalable vector yet.
- if (InVT.isScalableVector() || !useSVEForFixedLengthVectorVT(InVT))
+ if (InVT.isScalableVector()) {
+ SDLoc DL(Op);
+ EVT VT = Op.getValueType();
+
+ if (!isTypeLegal(VT) || !VT.isInteger())
+ return SDValue();
+
+ SDValue Vec0 = Op.getOperand(0);
+ SDValue Vec1 = Op.getOperand(1);
+
+ // Ensure the subvector is half the size of the main vector.
+ if (VT.getVectorElementCount() != (InVT.getVectorElementCount() * 2))
+ return SDValue();
+
+ // Extend elements of smaller vector...
+ EVT WideVT = InVT.widenIntegerVectorElementType(*(DAG.getContext()));
+ SDValue ExtVec = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Vec1);
+
+ if (Idx == 0) {
+ SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0);
+ return DAG.getNode(AArch64ISD::UZP1, DL, VT, ExtVec, HiVec0);
+ } else if (Idx == InVT.getVectorMinNumElements()) {
+ SDValue LoVec0 = DAG.getNode(AArch64ISD::UUNPKLO, DL, WideVT, Vec0);
+ return DAG.getNode(AArch64ISD::UZP1, DL, VT, LoVec0, ExtVec);
+ }
+
return SDValue();
+ }
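
// Worked example for the scalable path above: inserting an nxv2i32 subvector
// at index 0 of an nxv4i32 vector any-extends the subvector to nxv2i64 (its
// values land in the even 32-bit lanes), widens the high half of the
// destination with UUNPKHI in the same way, and then UZP1 picks those even
// lanes, yielding the original vector with its low half replaced by the
// subvector.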
// This will be matched by custom code during ISelDAGToDAG.
if (Idx == 0 && isPackedVectorType(InVT, DAG) && Op.getOperand(0).isUndef())
@@ -8762,6 +9969,42 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
return SDValue();
}
+SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+
+ if (useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true))
+ return LowerFixedLengthVectorIntDivideToSVE(Op, DAG);
+
+ assert(VT.isScalableVector() && "Expected a scalable vector.");
+
+ bool Signed = Op.getOpcode() == ISD::SDIV;
+ unsigned PredOpcode = Signed ? AArch64ISD::SDIV_PRED : AArch64ISD::UDIV_PRED;
+
+ if (VT == MVT::nxv4i32 || VT == MVT::nxv2i64)
+ return LowerToPredicatedOp(Op, DAG, PredOpcode);
+
+ // SVE doesn't have i8 and i16 DIV operations; widen them to 32-bit
+ // operations, and truncate the result.
+ EVT WidenedVT;
+ if (VT == MVT::nxv16i8)
+ WidenedVT = MVT::nxv8i16;
+ else if (VT == MVT::nxv8i16)
+ WidenedVT = MVT::nxv4i32;
+ else
+ llvm_unreachable("Unexpected Custom DIV operation");
+
+ SDLoc dl(Op);
+ unsigned UnpkLo = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
+ unsigned UnpkHi = Signed ? AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI;
+ SDValue Op0Lo = DAG.getNode(UnpkLo, dl, WidenedVT, Op.getOperand(0));
+ SDValue Op1Lo = DAG.getNode(UnpkLo, dl, WidenedVT, Op.getOperand(1));
+ SDValue Op0Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(0));
+ SDValue Op1Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(1));
+ SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Lo, Op1Lo);
+ SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Hi, Op1Hi);
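+ // UZP1 keeps the even-numbered (low) halves of the widened results, i.e. it
+ // truncates both halves back to the original element type and concatenates
+ // them in the original lane order.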
+ return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLo, ResultHi);
+}
+
bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
// Currently no fixed length shuffles that require SVE are legal.
if (useSVEForFixedLengthVectorVT(VT))
@@ -8846,79 +10089,27 @@ static bool isVShiftRImm(SDValue Op, EVT VT, bool isNarrow, int64_t &Cnt) {
return (Cnt >= 1 && Cnt <= (isNarrow ? ElementBits / 2 : ElementBits));
}
-// Attempt to form urhadd(OpA, OpB) from
-// truncate(vlshr(sub(zext(OpB), xor(zext(OpA), Ones(ElemSizeInBits))), 1)).
-// The original form of this expression is
-// truncate(srl(add(zext(OpB), add(zext(OpA), 1)), 1)) and before this function
-// is called the srl will have been lowered to AArch64ISD::VLSHR and the
-// ((OpA + OpB + 1) >> 1) expression will have been changed to (OpB - (~OpA)).
-// This pass can also recognize a variant of this pattern that uses sign
-// extension instead of zero extension and form a srhadd(OpA, OpB) from it.
SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
+ if (VT.getScalarType() == MVT::i1) {
+ // Lower i1 truncate to `(x & 1) != 0`.
+ SDLoc dl(Op);
+ EVT OpVT = Op.getOperand(0).getValueType();
+ SDValue Zero = DAG.getConstant(0, dl, OpVT);
+ SDValue One = DAG.getConstant(1, dl, OpVT);
+ SDValue And = DAG.getNode(ISD::AND, dl, OpVT, Op.getOperand(0), One);
+ return DAG.getSetCC(dl, VT, And, Zero, ISD::SETNE);
+ }
+
if (!VT.isVector() || VT.isScalableVector())
- return Op;
+ return SDValue();
if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType()))
return LowerFixedLengthVectorTruncateToSVE(Op, DAG);
- // Since we are looking for a right shift by a constant value of 1 and we are
- // operating on types at least 16 bits in length (sign/zero extended OpA and
- // OpB, which are at least 8 bits), it follows that the truncate will always
- // discard the shifted-in bit and therefore the right shift will be logical
- // regardless of the signedness of OpA and OpB.
- SDValue Shift = Op.getOperand(0);
- if (Shift.getOpcode() != AArch64ISD::VLSHR)
- return Op;
-
- // Is the right shift using an immediate value of 1?
- uint64_t ShiftAmount = Shift.getConstantOperandVal(1);
- if (ShiftAmount != 1)
- return Op;
-
- SDValue Sub = Shift->getOperand(0);
- if (Sub.getOpcode() != ISD::SUB)
- return Op;
-
- SDValue Xor = Sub.getOperand(1);
- if (Xor.getOpcode() != ISD::XOR)
- return Op;
-
- SDValue ExtendOpA = Xor.getOperand(0);
- SDValue ExtendOpB = Sub.getOperand(0);
- unsigned ExtendOpAOpc = ExtendOpA.getOpcode();
- unsigned ExtendOpBOpc = ExtendOpB.getOpcode();
- if (!(ExtendOpAOpc == ExtendOpBOpc &&
- (ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND)))
- return Op;
-
- // Is the result of the right shift being truncated to the same value type as
- // the original operands, OpA and OpB?
- SDValue OpA = ExtendOpA.getOperand(0);
- SDValue OpB = ExtendOpB.getOperand(0);
- EVT OpAVT = OpA.getValueType();
- assert(ExtendOpA.getValueType() == ExtendOpB.getValueType());
- if (!(VT == OpAVT && OpAVT == OpB.getValueType()))
- return Op;
-
- // Is the XOR using a constant amount of all ones in the right hand side?
- uint64_t C;
- if (!isAllConstantBuildVector(Xor.getOperand(1), C))
- return Op;
-
- unsigned ElemSizeInBits = VT.getScalarSizeInBits();
- APInt CAsAPInt(ElemSizeInBits, C);
- if (CAsAPInt != APInt::getAllOnesValue(ElemSizeInBits))
- return Op;
-
- SDLoc DL(Op);
- bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND;
- unsigned RHADDOpc = IsSignExtend ? AArch64ISD::SRHADD : AArch64ISD::URHADD;
- SDValue ResultURHADD = DAG.getNode(RHADDOpc, DL, VT, OpA, OpB);
-
- return ResultURHADD;
+ return SDValue();
}
SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
@@ -8936,8 +10127,8 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
llvm_unreachable("unexpected shift opcode");
case ISD::SHL:
- if (VT.isScalableVector())
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::SHL_MERGE_OP1);
+ if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT))
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::SHL_PRED);
if (isVShiftLImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize)
return DAG.getNode(AArch64ISD::VSHL, DL, VT, Op.getOperand(0),
@@ -8948,9 +10139,9 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
Op.getOperand(0), Op.getOperand(1));
case ISD::SRA:
case ISD::SRL:
- if (VT.isScalableVector()) {
- unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_MERGE_OP1
- : AArch64ISD::SRL_MERGE_OP1;
+ if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT)) {
+ unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
+ : AArch64ISD::SRL_PRED;
return LowerToPredicatedOp(Op, DAG, Opc);
}
@@ -9002,7 +10193,7 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
Fcmeq = DAG.getNode(AArch64ISD::FCMEQz, dl, VT, LHS);
else
Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
- return DAG.getNode(AArch64ISD::NOT, dl, VT, Fcmeq);
+ return DAG.getNOT(dl, Fcmeq, VT);
}
case AArch64CC::EQ:
if (IsZero)
@@ -9041,7 +10232,7 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
Cmeq = DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
else
Cmeq = DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
- return DAG.getNode(AArch64ISD::NOT, dl, VT, Cmeq);
+ return DAG.getNOT(dl, Cmeq, VT);
}
case AArch64CC::EQ:
if (IsZero)
@@ -9082,6 +10273,9 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
return LowerToPredicatedOp(Op, DAG, AArch64ISD::SETCC_MERGE_ZERO);
}
+ if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType()))
+ return LowerFixedLengthVectorSetccToSVE(Op, DAG);
+
ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
SDValue LHS = Op.getOperand(0);
SDValue RHS = Op.getOperand(1);
@@ -9154,6 +10348,51 @@ static SDValue getReductionSDNode(unsigned Op, SDLoc DL, SDValue ScalarOp,
SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
SelectionDAG &DAG) const {
+ SDValue Src = Op.getOperand(0);
+
+ // Try to lower fixed length reductions to SVE.
+ EVT SrcVT = Src.getValueType();
+ bool OverrideNEON = Op.getOpcode() == ISD::VECREDUCE_AND ||
+ Op.getOpcode() == ISD::VECREDUCE_OR ||
+ Op.getOpcode() == ISD::VECREDUCE_XOR ||
+ Op.getOpcode() == ISD::VECREDUCE_FADD ||
+ (Op.getOpcode() != ISD::VECREDUCE_ADD &&
+ SrcVT.getVectorElementType() == MVT::i64);
+ if (SrcVT.isScalableVector() ||
+ useSVEForFixedLengthVectorVT(SrcVT, OverrideNEON)) {
+
+ if (SrcVT.getVectorElementType() == MVT::i1)
+ return LowerPredReductionToSVE(Op, DAG);
+
+ switch (Op.getOpcode()) {
+ case ISD::VECREDUCE_ADD:
+ return LowerReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
+ case ISD::VECREDUCE_AND:
+ return LowerReductionToSVE(AArch64ISD::ANDV_PRED, Op, DAG);
+ case ISD::VECREDUCE_OR:
+ return LowerReductionToSVE(AArch64ISD::ORV_PRED, Op, DAG);
+ case ISD::VECREDUCE_SMAX:
+ return LowerReductionToSVE(AArch64ISD::SMAXV_PRED, Op, DAG);
+ case ISD::VECREDUCE_SMIN:
+ return LowerReductionToSVE(AArch64ISD::SMINV_PRED, Op, DAG);
+ case ISD::VECREDUCE_UMAX:
+ return LowerReductionToSVE(AArch64ISD::UMAXV_PRED, Op, DAG);
+ case ISD::VECREDUCE_UMIN:
+ return LowerReductionToSVE(AArch64ISD::UMINV_PRED, Op, DAG);
+ case ISD::VECREDUCE_XOR:
+ return LowerReductionToSVE(AArch64ISD::EORV_PRED, Op, DAG);
+ case ISD::VECREDUCE_FADD:
+ return LowerReductionToSVE(AArch64ISD::FADDV_PRED, Op, DAG);
+ case ISD::VECREDUCE_FMAX:
+ return LowerReductionToSVE(AArch64ISD::FMAXNMV_PRED, Op, DAG);
+ case ISD::VECREDUCE_FMIN:
+ return LowerReductionToSVE(AArch64ISD::FMINNMV_PRED, Op, DAG);
+ default:
+ llvm_unreachable("Unhandled fixed length reduction");
+ }
+ }
+
+ // Lower NEON reductions.
SDLoc dl(Op);
switch (Op.getOpcode()) {
case ISD::VECREDUCE_ADD:
@@ -9167,18 +10406,16 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
case ISD::VECREDUCE_UMIN:
return getReductionSDNode(AArch64ISD::UMINV, dl, Op, DAG);
case ISD::VECREDUCE_FMAX: {
- assert(Op->getFlags().hasNoNaNs() && "fmax vector reduction needs NoNaN flag");
return DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(),
DAG.getConstant(Intrinsic::aarch64_neon_fmaxnmv, dl, MVT::i32),
- Op.getOperand(0));
+ Src);
}
case ISD::VECREDUCE_FMIN: {
- assert(Op->getFlags().hasNoNaNs() && "fmin vector reduction needs NoNaN flag");
return DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(),
DAG.getConstant(Intrinsic::aarch64_neon_fminnmv, dl, MVT::i32),
- Op.getOperand(0));
+ Src);
}
default:
llvm_unreachable("Unhandled reduction");
@@ -9188,7 +10425,7 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
SDValue AArch64TargetLowering::LowerATOMIC_LOAD_SUB(SDValue Op,
SelectionDAG &DAG) const {
auto &Subtarget = static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
- if (!Subtarget.hasLSE())
+ if (!Subtarget.hasLSE() && !Subtarget.outlineAtomics())
return SDValue();
// LSE has an atomic load-add instruction, but not a load-sub.
@@ -9205,7 +10442,7 @@ SDValue AArch64TargetLowering::LowerATOMIC_LOAD_SUB(SDValue Op,
SDValue AArch64TargetLowering::LowerATOMIC_LOAD_AND(SDValue Op,
SelectionDAG &DAG) const {
auto &Subtarget = static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
- if (!Subtarget.hasLSE())
+ if (!Subtarget.hasLSE() && !Subtarget.outlineAtomics())
return SDValue();
// LSE has an atomic load-clear instruction, but not a load-and.
@@ -9306,16 +10543,17 @@ SDValue AArch64TargetLowering::LowerVSCALE(SDValue Op,
/// Set the IntrinsicInfo for the `aarch64_sve_st<N>` intrinsics.
template <unsigned NumVecs>
-static bool setInfoSVEStN(AArch64TargetLowering::IntrinsicInfo &Info,
- const CallInst &CI) {
+static bool
+setInfoSVEStN(const AArch64TargetLowering &TLI, const DataLayout &DL,
+ AArch64TargetLowering::IntrinsicInfo &Info, const CallInst &CI) {
Info.opc = ISD::INTRINSIC_VOID;
// Retrieve EC from first vector argument.
- const EVT VT = EVT::getEVT(CI.getArgOperand(0)->getType());
+ const EVT VT = TLI.getMemValueType(DL, CI.getArgOperand(0)->getType());
ElementCount EC = VT.getVectorElementCount();
#ifndef NDEBUG
// Check the assumption that all input vectors are the same type.
for (unsigned I = 0; I < NumVecs; ++I)
- assert(VT == EVT::getEVT(CI.getArgOperand(I)->getType()) &&
+ assert(VT == TLI.getMemValueType(DL, CI.getArgOperand(I)->getType()) &&
"Invalid type.");
#endif
// memVT is `NumVecs * VT`.
@@ -9338,11 +10576,11 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
auto &DL = I.getModule()->getDataLayout();
switch (Intrinsic) {
case Intrinsic::aarch64_sve_st2:
- return setInfoSVEStN<2>(Info, I);
+ return setInfoSVEStN<2>(*this, DL, Info, I);
case Intrinsic::aarch64_sve_st3:
- return setInfoSVEStN<3>(Info, I);
+ return setInfoSVEStN<3>(*this, DL, Info, I);
case Intrinsic::aarch64_sve_st4:
- return setInfoSVEStN<4>(Info, I);
+ return setInfoSVEStN<4>(*this, DL, Info, I);
case Intrinsic::aarch64_neon_ld2:
case Intrinsic::aarch64_neon_ld3:
case Intrinsic::aarch64_neon_ld4:
@@ -9498,15 +10736,15 @@ bool AArch64TargetLowering::shouldReduceLoadWidth(SDNode *Load,
bool AArch64TargetLowering::isTruncateFree(Type *Ty1, Type *Ty2) const {
if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
return false;
- unsigned NumBits1 = Ty1->getPrimitiveSizeInBits();
- unsigned NumBits2 = Ty2->getPrimitiveSizeInBits();
+ uint64_t NumBits1 = Ty1->getPrimitiveSizeInBits().getFixedSize();
+ uint64_t NumBits2 = Ty2->getPrimitiveSizeInBits().getFixedSize();
return NumBits1 > NumBits2;
}
bool AArch64TargetLowering::isTruncateFree(EVT VT1, EVT VT2) const {
if (VT1.isVector() || VT2.isVector() || !VT1.isInteger() || !VT2.isInteger())
return false;
- unsigned NumBits1 = VT1.getSizeInBits();
- unsigned NumBits2 = VT2.getSizeInBits();
+ uint64_t NumBits1 = VT1.getFixedSizeInBits();
+ uint64_t NumBits2 = VT2.getFixedSizeInBits();
return NumBits1 > NumBits2;
}
@@ -9748,6 +10986,43 @@ bool AArch64TargetLowering::shouldSinkOperands(
return true;
}
+ case Instruction::Mul: {
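+ // Splats of sign/zero-extended scalars feeding a multiply can be matched as
+ // widening multiplies by a duplicated element (e.g. smull/umull), but only
+ // if ISel sees the extend, insert and shuffle in the same block as the mul.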
+ bool IsProfitable = false;
+ for (auto &Op : I->operands()) {
+ // Make sure we are not already sinking this operand
+ if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
+ continue;
+
+ ShuffleVectorInst *Shuffle = dyn_cast<ShuffleVectorInst>(Op);
+ if (!Shuffle || !Shuffle->isZeroEltSplat())
+ continue;
+
+ Value *ShuffleOperand = Shuffle->getOperand(0);
+ InsertElementInst *Insert = dyn_cast<InsertElementInst>(ShuffleOperand);
+ if (!Insert)
+ continue;
+
+ Instruction *OperandInstr = dyn_cast<Instruction>(Insert->getOperand(1));
+ if (!OperandInstr)
+ continue;
+
+ ConstantInt *ElementConstant =
+ dyn_cast<ConstantInt>(Insert->getOperand(2));
+ // Check that the insertelement is inserting into element 0
+ if (!ElementConstant || ElementConstant->getZExtValue() != 0)
+ continue;
+
+ unsigned Opcode = OperandInstr->getOpcode();
+ if (Opcode != Instruction::SExt && Opcode != Instruction::ZExt)
+ continue;
+
+ Ops.push_back(&Shuffle->getOperandUse(0));
+ Ops.push_back(&Op);
+ IsProfitable = true;
+ }
+
+ return IsProfitable;
+ }
default:
return false;
}
@@ -10083,11 +11358,12 @@ SDValue AArch64TargetLowering::LowerSVEStructLoad(unsigned Intrinsic,
{Intrinsic::aarch64_sve_ld4, {4, AArch64ISD::SVE_LD4_MERGE_ZERO}}};
std::tie(N, Opcode) = IntrinsicMap[Intrinsic];
- assert(VT.getVectorElementCount().Min % N == 0 &&
+ assert(VT.getVectorElementCount().getKnownMinValue() % N == 0 &&
"invalid tuple vector type!");
- EVT SplitVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
- VT.getVectorElementCount() / N);
+ EVT SplitVT =
+ EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
+ VT.getVectorElementCount().divideCoefficientBy(N));
assert(isTypeLegal(SplitVT));
SmallVector<EVT, 5> VTs(N, SplitVT);
@@ -10378,32 +11654,77 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0));
}
-// Generate SUBS and CSEL for integer abs.
-static SDValue performIntegerAbsCombine(SDNode *N, SelectionDAG &DAG) {
- EVT VT = N->getValueType(0);
+// VECREDUCE_ADD( EXTEND(v16i8_type) ) to
+// VECREDUCE_ADD( DOTv16i8(v16i8_type) )
+static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+ const AArch64Subtarget *ST) {
+ SDValue Op0 = N->getOperand(0);
+ if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32)
+ return SDValue();
- SDValue N0 = N->getOperand(0);
- SDValue N1 = N->getOperand(1);
- SDLoc DL(N);
+ if (Op0.getValueType().getVectorElementType() != MVT::i32)
+ return SDValue();
- // Check pattern of XOR(ADD(X,Y), Y) where Y is SRA(X, size(X)-1)
- // and change it to SUB and CSEL.
- if (VT.isInteger() && N->getOpcode() == ISD::XOR &&
- N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1 &&
- N1.getOpcode() == ISD::SRA && N1.getOperand(0) == N0.getOperand(0))
- if (ConstantSDNode *Y1C = dyn_cast<ConstantSDNode>(N1.getOperand(1)))
- if (Y1C->getAPIntValue() == VT.getSizeInBits() - 1) {
- SDValue Neg = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
- N0.getOperand(0));
- // Generate SUBS & CSEL.
- SDValue Cmp =
- DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::i32),
- N0.getOperand(0), DAG.getConstant(0, DL, VT));
- return DAG.getNode(AArch64ISD::CSEL, DL, VT, N0.getOperand(0), Neg,
- DAG.getConstant(AArch64CC::PL, DL, MVT::i32),
- SDValue(Cmp.getNode(), 1));
- }
- return SDValue();
+ unsigned ExtOpcode = Op0.getOpcode();
+ if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ EVT Op0VT = Op0.getOperand(0).getValueType();
+ if (Op0VT != MVT::v16i8)
+ return SDValue();
+
+ SDLoc DL(Op0);
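+ // A dot product against a splat of 1 multiplies each i8 lane by one and
+ // accumulates groups of four lanes into the i32 lanes of the zero
+ // accumulator, which is equivalent to summing the extended elements.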
+ SDValue Ones = DAG.getConstant(1, DL, Op0VT);
+ SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32);
+ auto DotIntrinsic = (ExtOpcode == ISD::ZERO_EXTEND)
+                         ? Intrinsic::aarch64_neon_udot
+                         : Intrinsic::aarch64_neon_sdot;
+ SDValue Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Zeros.getValueType(),
+                           DAG.getConstant(DotIntrinsic, DL, MVT::i32), Zeros,
+ Ones, Op0.getOperand(0));
+ return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
+}
+
+// Given an ABS node, detect the following pattern:
+// (ABS (SUB (EXTEND a), (EXTEND b))).
+// Generates a UABD/SABD instruction.
+static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const AArch64Subtarget *Subtarget) {
+ SDValue AbsOp1 = N->getOperand(0);
+ SDValue Op0, Op1;
+
+ if (AbsOp1.getOpcode() != ISD::SUB)
+ return SDValue();
+
+ Op0 = AbsOp1.getOperand(0);
+ Op1 = AbsOp1.getOperand(1);
+
+ unsigned Opc0 = Op0.getOpcode();
+ // Check if the operands of the sub are (zero|sign)-extended.
+ if (Opc0 != Op1.getOpcode() ||
+ (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
+ return SDValue();
+
+ EVT VectorT1 = Op0.getOperand(0).getValueType();
+ EVT VectorT2 = Op1.getOperand(0).getValueType();
+ // Check if vectors are of same type and valid size.
+ uint64_t Size = VectorT1.getFixedSizeInBits();
+ if (VectorT1 != VectorT2 || (Size != 64 && Size != 128))
+ return SDValue();
+
+ // Check if vector element types are valid.
+ EVT VT1 = VectorT1.getVectorElementType();
+ if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32)
+ return SDValue();
+
+ Op0 = Op0.getOperand(0);
+ Op1 = Op1.getOperand(0);
+ unsigned ABDOpcode =
+ (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD;
+ SDValue ABD =
+ DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
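+ // The absolute difference is non-negative and fits in the narrow element
+ // type, so zero extension yields the correct value for both the signed and
+ // unsigned variants.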
+ return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
}
static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
@@ -10412,10 +11733,7 @@ static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
if (DCI.isBeforeLegalizeOps())
return SDValue();
- if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
- return Cmp;
-
- return performIntegerAbsCombine(N, DAG);
+ return foldVectorXorShiftIntoCmp(N, DAG, Subtarget);
}
SDValue
@@ -10474,9 +11792,157 @@ static bool IsSVECntIntrinsic(SDValue S) {
return false;
}
+/// Calculates what the pre-extend type is, based on the extension
+/// operation node provided by \p Extend.
+///
+/// In the case that \p Extend is a SIGN_EXTEND or a ZERO_EXTEND, the
+/// pre-extend type is pulled directly from the operand, while other extend
+/// operations need a bit more inspection to get this information.
+///
+/// \param Extend The SDNode from the DAG that represents the extend operation
+/// \param DAG The SelectionDAG hosting the \p Extend node
+///
+/// \returns The type representing the \p Extend source type, or \p MVT::Other
+/// if no valid type can be determined
+static EVT calculatePreExtendType(SDValue Extend, SelectionDAG &DAG) {
+ switch (Extend.getOpcode()) {
+ case ISD::SIGN_EXTEND:
+ case ISD::ZERO_EXTEND:
+ return Extend.getOperand(0).getValueType();
+ case ISD::AssertSext:
+ case ISD::AssertZext:
+ case ISD::SIGN_EXTEND_INREG: {
+ VTSDNode *TypeNode = dyn_cast<VTSDNode>(Extend.getOperand(1));
+ if (!TypeNode)
+ return MVT::Other;
+ return TypeNode->getVT();
+ }
+ case ISD::AND: {
+ ConstantSDNode *Constant =
+ dyn_cast<ConstantSDNode>(Extend.getOperand(1).getNode());
+ if (!Constant)
+ return MVT::Other;
+
+ uint32_t Mask = Constant->getZExtValue();
+
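+ // An AND with an all-ones mask of 8/16/32 bits behaves like a zero
+ // extension from i8/i16/i32 respectively.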
+ if (Mask == UCHAR_MAX)
+ return MVT::i8;
+ else if (Mask == USHRT_MAX)
+ return MVT::i16;
+ else if (Mask == UINT_MAX)
+ return MVT::i32;
+
+ return MVT::Other;
+ }
+ default:
+ return MVT::Other;
+ }
+
+ llvm_unreachable("Code path unhandled in calculatePreExtendType!");
+}
+
+/// Combines a dup(sext/zext) node pattern into sext/zext(dup)
+/// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
+static SDValue performCommonVectorExtendCombine(SDValue VectorShuffle,
+ SelectionDAG &DAG) {
+
+ ShuffleVectorSDNode *ShuffleNode =
+ dyn_cast<ShuffleVectorSDNode>(VectorShuffle.getNode());
+ if (!ShuffleNode)
+ return SDValue();
+
+ // Ensure the shuffle is a splat of lane 0 before continuing.
+ if (!ShuffleNode->isSplat() || ShuffleNode->getSplatIndex() != 0)
+ return SDValue();
+
+ SDValue InsertVectorElt = VectorShuffle.getOperand(0);
+
+ if (InsertVectorElt.getOpcode() != ISD::INSERT_VECTOR_ELT)
+ return SDValue();
+
+ SDValue InsertLane = InsertVectorElt.getOperand(2);
+ ConstantSDNode *Constant = dyn_cast<ConstantSDNode>(InsertLane.getNode());
+ // Ensures the insert is inserting into lane 0
+ if (!Constant || Constant->getZExtValue() != 0)
+ return SDValue();
+
+ SDValue Extend = InsertVectorElt.getOperand(1);
+ unsigned ExtendOpcode = Extend.getOpcode();
+
+ bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
+ ExtendOpcode == ISD::SIGN_EXTEND_INREG ||
+ ExtendOpcode == ISD::AssertSext;
+ if (!IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
+ ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND)
+ return SDValue();
+
+ EVT TargetType = VectorShuffle.getValueType();
+ EVT PreExtendType = calculatePreExtendType(Extend, DAG);
+
+ if ((TargetType != MVT::v8i16 && TargetType != MVT::v4i32 &&
+ TargetType != MVT::v2i64) ||
+ (PreExtendType == MVT::Other))
+ return SDValue();
+
+ // Restrict valid pre-extend data type
+ if (PreExtendType != MVT::i8 && PreExtendType != MVT::i16 &&
+ PreExtendType != MVT::i32)
+ return SDValue();
+
+ EVT PreExtendVT = TargetType.changeVectorElementType(PreExtendType);
+
+ if (PreExtendVT.getVectorElementCount() != TargetType.getVectorElementCount())
+ return SDValue();
+
+ if (TargetType.getScalarSizeInBits() != PreExtendVT.getScalarSizeInBits() * 2)
+ return SDValue();
+
+ SDLoc DL(VectorShuffle);
+
+ SDValue InsertVectorNode = DAG.getNode(
+ InsertVectorElt.getOpcode(), DL, PreExtendVT, DAG.getUNDEF(PreExtendVT),
+ DAG.getAnyExtOrTrunc(Extend.getOperand(0), DL, PreExtendType),
+ DAG.getConstant(0, DL, MVT::i64));
+
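+ // A value-initialized mask is all zeroes, i.e. a splat of lane 0.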
+ std::vector<int> ShuffleMask(TargetType.getVectorElementCount().getValue());
+
+ SDValue VectorShuffleNode =
+ DAG.getVectorShuffle(PreExtendVT, DL, InsertVectorNode,
+ DAG.getUNDEF(PreExtendVT), ShuffleMask);
+
+ SDValue ExtendNode = DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND,
+ DL, TargetType, VectorShuffleNode);
+
+ return ExtendNode;
+}
+
+/// Combines a mul(dup(sext/zext)) node pattern into mul(sext/zext(dup))
+/// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
+static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
+ // If the value type isn't a vector, none of the operands are going to be dups
+ if (!Mul->getValueType(0).isVector())
+ return SDValue();
+
+ SDValue Op0 = performCommonVectorExtendCombine(Mul->getOperand(0), DAG);
+ SDValue Op1 = performCommonVectorExtendCombine(Mul->getOperand(1), DAG);
+
+ // Neither operand has been changed; don't make any further changes.
+ if (!Op0 && !Op1)
+ return SDValue();
+
+ SDLoc DL(Mul);
+ return DAG.getNode(Mul->getOpcode(), DL, Mul->getValueType(0),
+ Op0 ? Op0 : Mul->getOperand(0),
+ Op1 ? Op1 : Mul->getOperand(1));
+}
+
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
+
+ if (SDValue Ext = performMulVectorExtendCombine(N, DAG))
+ return Ext;
+
if (DCI.isBeforeLegalizeOps())
return SDValue();
@@ -11011,6 +12477,9 @@ static SDValue performSVEAndCombine(SDNode *N,
return DAG.getNode(Opc, DL, N->getValueType(0), And);
}
+ if (!EnableCombineMGatherIntrinsics)
+ return SDValue();
+
SDValue Mask = N->getOperand(1);
if (!Src.hasOneUse())
@@ -11064,6 +12533,11 @@ static SDValue performANDCombine(SDNode *N,
if (VT.isScalableVector())
return performSVEAndCombine(N, DCI);
+ // The combining code below works only for NEON vectors. In particular, it
+ // does not work for SVE when dealing with vectors wider than 128 bits.
+ if (!(VT.is64BitVector() || VT.is128BitVector()))
+ return SDValue();
+
BuildVectorSDNode *BVN =
dyn_cast<BuildVectorSDNode>(N->getOperand(1).getNode());
if (!BVN)
@@ -11124,6 +12598,143 @@ static SDValue performSRLCombine(SDNode *N,
return SDValue();
}
+// Attempt to form urhadd(OpA, OpB) from
+// truncate(vlshr(sub(zext(OpB), xor(zext(OpA), Ones(ElemSizeInBits))), 1))
+// or uhadd(OpA, OpB) from truncate(vlshr(add(zext(OpA), zext(OpB)), 1)).
+// The original form of the first expression is
+// truncate(srl(add(zext(OpB), add(zext(OpA), 1)), 1)) and the
+// (OpA + OpB + 1) subexpression will have been changed to (OpB - (~OpA)).
+// Before this function is called the srl will have been lowered to
+// AArch64ISD::VLSHR.
+// This combine can also recognize signed variants of the patterns that use
+// sign extension instead of zero extension and form a srhadd(OpA, OpB) or a
+// shadd(OpA, OpB) from them.
+static SDValue
+performVectorTruncateCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ SelectionDAG &DAG) {
+ EVT VT = N->getValueType(0);
+
+ // Since we are looking for a right shift by a constant value of 1 and we are
+ // operating on types at least 16 bits in length (sign/zero extended OpA and
+ // OpB, which are at least 8 bits), it follows that the truncate will always
+ // discard the shifted-in bit and therefore the right shift will be logical
+ // regardless of the signedness of OpA and OpB.
+ SDValue Shift = N->getOperand(0);
+ if (Shift.getOpcode() != AArch64ISD::VLSHR)
+ return SDValue();
+
+ // Is the right shift using an immediate value of 1?
+ uint64_t ShiftAmount = Shift.getConstantOperandVal(1);
+ if (ShiftAmount != 1)
+ return SDValue();
+
+ SDValue ExtendOpA, ExtendOpB;
+ SDValue ShiftOp0 = Shift.getOperand(0);
+ unsigned ShiftOp0Opc = ShiftOp0.getOpcode();
+ if (ShiftOp0Opc == ISD::SUB) {
+
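+ // In two's complement ~OpA == -OpA - 1, so OpB - (~OpA) == OpA + OpB + 1,
+ // the numerator of the rounding halving add.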
+ SDValue Xor = ShiftOp0.getOperand(1);
+ if (Xor.getOpcode() != ISD::XOR)
+ return SDValue();
+
+ // Is the XOR using a constant amount of all ones in the right hand side?
+ uint64_t C;
+ if (!isAllConstantBuildVector(Xor.getOperand(1), C))
+ return SDValue();
+
+ unsigned ElemSizeInBits = VT.getScalarSizeInBits();
+ APInt CAsAPInt(ElemSizeInBits, C);
+ if (CAsAPInt != APInt::getAllOnesValue(ElemSizeInBits))
+ return SDValue();
+
+ ExtendOpA = Xor.getOperand(0);
+ ExtendOpB = ShiftOp0.getOperand(0);
+ } else if (ShiftOp0Opc == ISD::ADD) {
+ ExtendOpA = ShiftOp0.getOperand(0);
+ ExtendOpB = ShiftOp0.getOperand(1);
+ } else
+ return SDValue();
+
+ unsigned ExtendOpAOpc = ExtendOpA.getOpcode();
+ unsigned ExtendOpBOpc = ExtendOpB.getOpcode();
+ if (!(ExtendOpAOpc == ExtendOpBOpc &&
+ (ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND)))
+ return SDValue();
+
+ // Is the result of the right shift being truncated to the same value type as
+ // the original operands, OpA and OpB?
+ SDValue OpA = ExtendOpA.getOperand(0);
+ SDValue OpB = ExtendOpB.getOperand(0);
+ EVT OpAVT = OpA.getValueType();
+ assert(ExtendOpA.getValueType() == ExtendOpB.getValueType());
+ if (!(VT == OpAVT && OpAVT == OpB.getValueType()))
+ return SDValue();
+
+ SDLoc DL(N);
+ bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND;
+ bool IsRHADD = ShiftOp0Opc == ISD::SUB;
+ unsigned HADDOpc = IsSignExtend
+ ? (IsRHADD ? AArch64ISD::SRHADD : AArch64ISD::SHADD)
+ : (IsRHADD ? AArch64ISD::URHADD : AArch64ISD::UHADD);
+ SDValue ResultHADD = DAG.getNode(HADDOpc, DL, VT, OpA, OpB);
+
+ return ResultHADD;
+}
+
+static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
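+ // Scalar pairwise add forms: FADDP exists for f16 (with full FP16), f32 and
+ // f64, while the integer scalar ADDP only operates on 64-bit elements.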
+ switch (Opcode) {
+ case ISD::FADD:
+ return (FullFP16 && VT == MVT::f16) || VT == MVT::f32 || VT == MVT::f64;
+ case ISD::ADD:
+ return VT == MVT::i64;
+ default:
+ return false;
+ }
+}
+
+static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) {
+ SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+ ConstantSDNode *ConstantN1 = dyn_cast<ConstantSDNode>(N1);
+
+ EVT VT = N->getValueType(0);
+ const bool FullFP16 =
+ static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16();
+
+ // Rewrite for pairwise fadd pattern
+ // (f32 (extract_vector_elt
+ // (fadd (vXf32 Other)
+ // (vector_shuffle (vXf32 Other) undef <1,X,...> )) 0))
+ // ->
+ // (f32 (fadd (extract_vector_elt (vXf32 Other) 0)
+ // (extract_vector_elt (vXf32 Other) 1))
+ if (ConstantN1 && ConstantN1->getZExtValue() == 0 &&
+ hasPairwiseAdd(N0->getOpcode(), VT, FullFP16)) {
+ SDLoc DL(N0);
+ SDValue N00 = N0->getOperand(0);
+ SDValue N01 = N0->getOperand(1);
+
+ ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(N01);
+ SDValue Other = N00;
+
+ // And handle the commutative case.
+ if (!Shuffle) {
+ Shuffle = dyn_cast<ShuffleVectorSDNode>(N00);
+ Other = N01;
+ }
+
+ if (Shuffle && Shuffle->getMaskElt(0) == 1 &&
+ Other == Shuffle->getOperand(0)) {
+ return DAG.getNode(N0->getOpcode(), DL, VT,
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+ DAG.getConstant(0, DL, MVT::i64)),
+ DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+ DAG.getConstant(1, DL, MVT::i64)));
+ }
+ }
+
+ return SDValue();
+}
+
static SDValue performConcatVectorsCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
@@ -11169,9 +12780,9 @@ static SDValue performConcatVectorsCombine(SDNode *N,
if (DCI.isBeforeLegalizeOps())
return SDValue();
- // Optimise concat_vectors of two [us]rhadds that use extracted subvectors
- // from the same original vectors. Combine these into a single [us]rhadd that
- // operates on the two original vectors. Example:
+ // Optimise concat_vectors of two [us]rhadds or [us]hadds that use extracted
+ // subvectors from the same original vectors. Combine these into a single
+ // [us]rhadd or [us]hadd that operates on the two original vectors. Example:
// (v16i8 (concat_vectors (v8i8 (urhadd (extract_subvector (v16i8 OpA, <0>),
// extract_subvector (v16i8 OpB,
// <0>))),
@@ -11181,7 +12792,8 @@ static SDValue performConcatVectorsCombine(SDNode *N,
// ->
// (v16i8(urhadd(v16i8 OpA, v16i8 OpB)))
if (N->getNumOperands() == 2 && N0Opc == N1Opc &&
- (N0Opc == AArch64ISD::URHADD || N0Opc == AArch64ISD::SRHADD)) {
+ (N0Opc == AArch64ISD::URHADD || N0Opc == AArch64ISD::SRHADD ||
+ N0Opc == AArch64ISD::UHADD || N0Opc == AArch64ISD::SHADD)) {
SDValue N00 = N0->getOperand(0);
SDValue N01 = N0->getOperand(1);
SDValue N10 = N1->getOperand(0);
@@ -11486,6 +13098,43 @@ static SDValue performSetccAddFolding(SDNode *Op, SelectionDAG &DAG) {
return DAG.getNode(AArch64ISD::CSEL, dl, VT, RHS, LHS, CCVal, Cmp);
}
+// ADD(UADDV a, UADDV b) --> UADDV(ADD a, b)
+static SDValue performUADDVCombine(SDNode *N, SelectionDAG &DAG) {
+ EVT VT = N->getValueType(0);
+ // Only scalar integer and vector types.
+ if (N->getOpcode() != ISD::ADD || !VT.isScalarInteger())
+ return SDValue();
+
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+ if (LHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT || LHS.getValueType() != VT)
+ return SDValue();
+
+ auto *LHSN1 = dyn_cast<ConstantSDNode>(LHS->getOperand(1));
+ auto *RHSN1 = dyn_cast<ConstantSDNode>(RHS->getOperand(1));
+ if (!LHSN1 || LHSN1 != RHSN1 || !RHSN1->isNullValue())
+ return SDValue();
+
+ SDValue Op1 = LHS->getOperand(0);
+ SDValue Op2 = RHS->getOperand(0);
+ EVT OpVT1 = Op1.getValueType();
+ EVT OpVT2 = Op2.getValueType();
+ if (Op1.getOpcode() != AArch64ISD::UADDV || OpVT1 != OpVT2 ||
+ Op2.getOpcode() != AArch64ISD::UADDV ||
+ OpVT1.getVectorElementType() != VT)
+ return SDValue();
+
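+ // UADDV leaves the scalar sum in lane 0, so instead of reducing twice and
+ // adding the results, add the reduction inputs first and reduce once.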
+ SDValue Val1 = Op1.getOperand(0);
+ SDValue Val2 = Op2.getOperand(0);
+ EVT ValVT = Val1->getValueType(0);
+ SDLoc DL(N);
+ SDValue AddVal = DAG.getNode(ISD::ADD, DL, ValVT, Val1, Val2);
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
+ DAG.getNode(AArch64ISD::UADDV, DL, ValVT, AddVal),
+ DAG.getConstant(0, DL, MVT::i64));
+}
+
// The basic add/sub long vector instructions have variants with "2" on the end
// which act on the high-half of their inputs. They are normally matched by
// patterns like:
@@ -11539,6 +13188,16 @@ static SDValue performAddSubLongCombine(SDNode *N,
return DAG.getNode(N->getOpcode(), SDLoc(N), VT, LHS, RHS);
}
+static SDValue performAddSubCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ SelectionDAG &DAG) {
+ // Try to change sum of two reductions.
+ if (SDValue Val = performUADDVCombine(N, DAG))
+ return Val;
+
+ return performAddSubLongCombine(N, DCI, DAG);
+}
+
// Massage DAGs which we can use the high-half "long" operations on into
// something isel will recognize better. E.g.
//
@@ -11552,8 +13211,8 @@ static SDValue tryCombineLongOpWithDup(unsigned IID, SDNode *N,
if (DCI.isBeforeLegalizeOps())
return SDValue();
- SDValue LHS = N->getOperand(1);
- SDValue RHS = N->getOperand(2);
+ SDValue LHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 0 : 1);
+ SDValue RHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 1 : 2);
assert(LHS.getValueType().is64BitVector() &&
RHS.getValueType().is64BitVector() &&
"unexpected shape for long operation");
@@ -11571,6 +13230,9 @@ static SDValue tryCombineLongOpWithDup(unsigned IID, SDNode *N,
return SDValue();
}
+ if (IID == Intrinsic::not_intrinsic)
+ return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), LHS, RHS);
+
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), N->getValueType(0),
N->getOperand(0), LHS, RHS);
}
@@ -11669,34 +13331,6 @@ static SDValue combineAcrossLanesIntrinsic(unsigned Opc, SDNode *N,
DAG.getConstant(0, dl, MVT::i64));
}
-static SDValue LowerSVEIntReduction(SDNode *N, unsigned Opc,
- SelectionDAG &DAG) {
- SDLoc dl(N);
- LLVMContext &Ctx = *DAG.getContext();
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-
- EVT VT = N->getValueType(0);
- SDValue Pred = N->getOperand(1);
- SDValue Data = N->getOperand(2);
- EVT DataVT = Data.getValueType();
-
- if (DataVT.getVectorElementType().isScalarInteger() &&
- (VT == MVT::i8 || VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64)) {
- if (!TLI.isTypeLegal(DataVT))
- return SDValue();
-
- EVT OutputVT = EVT::getVectorVT(Ctx, VT,
- AArch64::NeonBitsPerVector / VT.getSizeInBits());
- SDValue Reduce = DAG.getNode(Opc, dl, OutputVT, Pred, Data);
- SDValue Zero = DAG.getConstant(0, dl, MVT::i64);
- SDValue Result = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Reduce, Zero);
-
- return Result;
- }
-
- return SDValue();
-}
-
static SDValue LowerSVEIntrinsicIndex(SDNode *N, SelectionDAG &DAG) {
SDLoc DL(N);
SDValue Op1 = N->getOperand(1);
@@ -11739,7 +13373,8 @@ static SDValue LowerSVEIntrinsicEXT(SDNode *N, SelectionDAG &DAG) {
unsigned ElemSize = VT.getVectorElementType().getSizeInBits() / 8;
unsigned ByteSize = VT.getSizeInBits().getKnownMinSize() / 8;
- EVT ByteVT = EVT::getVectorVT(Ctx, MVT::i8, { ByteSize, true });
+ EVT ByteVT =
+ EVT::getVectorVT(Ctx, MVT::i8, ElementCount::getScalable(ByteSize));
// Convert everything to the domain of EXT (i.e bytes).
SDValue Op0 = DAG.getNode(ISD::BITCAST, dl, ByteVT, N->getOperand(1));
@@ -11839,6 +13474,25 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
return DAG.getZExtOrTrunc(Res, DL, VT);
}
+static SDValue combineSVEReductionInt(SDNode *N, unsigned Opc,
+ SelectionDAG &DAG) {
+ SDLoc DL(N);
+
+ SDValue Pred = N->getOperand(1);
+ SDValue VecToReduce = N->getOperand(2);
+
+ // NOTE: The integer reduction's result type is not always linked to the
+ // operand's element type so we construct it from the intrinsic's result type.
+ EVT ReduceVT = getPackedSVEVectorVT(N->getValueType(0));
+ SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, VecToReduce);
+
+ // SVE reductions set the whole vector register with the first element
+ // containing the reduction result, which we'll now extract.
+ SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
+ Zero);
+}
+
static SDValue combineSVEReductionFP(SDNode *N, unsigned Opc,
SelectionDAG &DAG) {
SDLoc DL(N);
@@ -11879,6 +13533,25 @@ static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
Zero);
}
+// If a merged operation has no inactive lanes we can relax it to a predicated
+// or unpredicated operation, which potentially allows better isel (perhaps
+// using immediate forms) or relaxing register reuse requirements.
+static SDValue convertMergedOpToPredOp(SDNode *N, unsigned PredOpc,
+ SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Expected intrinsic!");
+ assert(N->getNumOperands() == 4 && "Expected 3 operand intrinsic!");
+ SDValue Pg = N->getOperand(1);
+
+ // ISD way to specify an all active predicate.
+ if ((Pg.getOpcode() == AArch64ISD::PTRUE) &&
+ (Pg.getConstantOperandVal(0) == AArch64SVEPredPattern::all))
+ return DAG.getNode(PredOpc, SDLoc(N), N->getValueType(0), Pg,
+ N->getOperand(2), N->getOperand(3));
+
+ // FUTURE: SplatVector(true)
+ return SDValue();
+}
+
static SDValue performIntrinsicCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -11933,20 +13606,28 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::aarch64_crc32h:
case Intrinsic::aarch64_crc32ch:
return tryCombineCRC32(0xffff, N, DAG);
+ case Intrinsic::aarch64_sve_saddv:
+ // There is no i64 version of SADDV because the sign is irrelevant.
+ if (N->getOperand(2)->getValueType(0).getVectorElementType() == MVT::i64)
+ return combineSVEReductionInt(N, AArch64ISD::UADDV_PRED, DAG);
+ else
+ return combineSVEReductionInt(N, AArch64ISD::SADDV_PRED, DAG);
+ case Intrinsic::aarch64_sve_uaddv:
+ return combineSVEReductionInt(N, AArch64ISD::UADDV_PRED, DAG);
case Intrinsic::aarch64_sve_smaxv:
- return LowerSVEIntReduction(N, AArch64ISD::SMAXV_PRED, DAG);
+ return combineSVEReductionInt(N, AArch64ISD::SMAXV_PRED, DAG);
case Intrinsic::aarch64_sve_umaxv:
- return LowerSVEIntReduction(N, AArch64ISD::UMAXV_PRED, DAG);
+ return combineSVEReductionInt(N, AArch64ISD::UMAXV_PRED, DAG);
case Intrinsic::aarch64_sve_sminv:
- return LowerSVEIntReduction(N, AArch64ISD::SMINV_PRED, DAG);
+ return combineSVEReductionInt(N, AArch64ISD::SMINV_PRED, DAG);
case Intrinsic::aarch64_sve_uminv:
- return LowerSVEIntReduction(N, AArch64ISD::UMINV_PRED, DAG);
+ return combineSVEReductionInt(N, AArch64ISD::UMINV_PRED, DAG);
case Intrinsic::aarch64_sve_orv:
- return LowerSVEIntReduction(N, AArch64ISD::ORV_PRED, DAG);
+ return combineSVEReductionInt(N, AArch64ISD::ORV_PRED, DAG);
case Intrinsic::aarch64_sve_eorv:
- return LowerSVEIntReduction(N, AArch64ISD::EORV_PRED, DAG);
+ return combineSVEReductionInt(N, AArch64ISD::EORV_PRED, DAG);
case Intrinsic::aarch64_sve_andv:
- return LowerSVEIntReduction(N, AArch64ISD::ANDV_PRED, DAG);
+ return combineSVEReductionInt(N, AArch64ISD::ANDV_PRED, DAG);
case Intrinsic::aarch64_sve_index:
return LowerSVEIntrinsicIndex(N, DAG);
case Intrinsic::aarch64_sve_dup:
@@ -11957,26 +13638,19 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::aarch64_sve_ext:
return LowerSVEIntrinsicEXT(N, DAG);
case Intrinsic::aarch64_sve_smin:
- return DAG.getNode(AArch64ISD::SMIN_MERGE_OP1, SDLoc(N), N->getValueType(0),
- N->getOperand(1), N->getOperand(2), N->getOperand(3));
+ return convertMergedOpToPredOp(N, AArch64ISD::SMIN_PRED, DAG);
case Intrinsic::aarch64_sve_umin:
- return DAG.getNode(AArch64ISD::UMIN_MERGE_OP1, SDLoc(N), N->getValueType(0),
- N->getOperand(1), N->getOperand(2), N->getOperand(3));
+ return convertMergedOpToPredOp(N, AArch64ISD::UMIN_PRED, DAG);
case Intrinsic::aarch64_sve_smax:
- return DAG.getNode(AArch64ISD::SMAX_MERGE_OP1, SDLoc(N), N->getValueType(0),
- N->getOperand(1), N->getOperand(2), N->getOperand(3));
+ return convertMergedOpToPredOp(N, AArch64ISD::SMAX_PRED, DAG);
case Intrinsic::aarch64_sve_umax:
- return DAG.getNode(AArch64ISD::UMAX_MERGE_OP1, SDLoc(N), N->getValueType(0),
- N->getOperand(1), N->getOperand(2), N->getOperand(3));
+ return convertMergedOpToPredOp(N, AArch64ISD::UMAX_PRED, DAG);
case Intrinsic::aarch64_sve_lsl:
- return DAG.getNode(AArch64ISD::SHL_MERGE_OP1, SDLoc(N), N->getValueType(0),
- N->getOperand(1), N->getOperand(2), N->getOperand(3));
+ return convertMergedOpToPredOp(N, AArch64ISD::SHL_PRED, DAG);
case Intrinsic::aarch64_sve_lsr:
- return DAG.getNode(AArch64ISD::SRL_MERGE_OP1, SDLoc(N), N->getValueType(0),
- N->getOperand(1), N->getOperand(2), N->getOperand(3));
+ return convertMergedOpToPredOp(N, AArch64ISD::SRL_PRED, DAG);
case Intrinsic::aarch64_sve_asr:
- return DAG.getNode(AArch64ISD::SRA_MERGE_OP1, SDLoc(N), N->getValueType(0),
- N->getOperand(1), N->getOperand(2), N->getOperand(3));
+ return convertMergedOpToPredOp(N, AArch64ISD::SRA_PRED, DAG);
case Intrinsic::aarch64_sve_cmphs:
if (!N->getOperand(2).getValueType().isFloatingPoint())
return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
@@ -12069,18 +13743,15 @@ static SDValue performExtendCombine(SDNode *N,
// helps the backend to decide that an sabdl2 would be useful, saving a real
// extract_high operation.
if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
- N->getOperand(0).getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
+ (N->getOperand(0).getOpcode() == AArch64ISD::UABD ||
+ N->getOperand(0).getOpcode() == AArch64ISD::SABD)) {
SDNode *ABDNode = N->getOperand(0).getNode();
- unsigned IID = getIntrinsicID(ABDNode);
- if (IID == Intrinsic::aarch64_neon_sabd ||
- IID == Intrinsic::aarch64_neon_uabd) {
- SDValue NewABD = tryCombineLongOpWithDup(IID, ABDNode, DCI, DAG);
- if (!NewABD.getNode())
- return SDValue();
+ SDValue NewABD =
+ tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
+ if (!NewABD.getNode())
+ return SDValue();
- return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0),
- NewABD);
- }
+ return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), NewABD);
}
// This is effectively a custom type legalization for AArch64.
@@ -12288,6 +13959,9 @@ static SDValue performLD1ReplicateCombine(SDNode *N, SelectionDAG &DAG) {
"Unsupported opcode.");
SDLoc DL(N);
EVT VT = N->getValueType(0);
+ if (VT == MVT::nxv8bf16 &&
+ !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
+ return SDValue();
EVT LoadVT = VT;
if (VT.isFloatingPoint())
@@ -12560,6 +14234,31 @@ static SDValue splitStores(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
S->getMemOperand()->getFlags());
}
+static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG) {
+ SDLoc DL(N);
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ EVT ResVT = N->getValueType(0);
+
+ // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
+ if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
+ if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
+ SDValue X = Op0.getOperand(0).getOperand(0);
+ return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
+ }
+ }
+
+ // uzp1(x, unpkhi(uzp1(y, z))) => uzp1(x, z)
+ if (Op1.getOpcode() == AArch64ISD::UUNPKHI) {
+ if (Op1.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
+ SDValue Z = Op1.getOperand(0).getOperand(1);
+ return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
+ }
+ }
+
+ return SDValue();
+}
+
/// Target-specific DAG combine function for post-increment LD1 (lane) and
/// post-increment LD1R.
static SDValue performPostLD1Combine(SDNode *N,
@@ -12698,6 +14397,54 @@ static SDValue performSTORECombine(SDNode *N,
return SDValue();
}
+static SDValue performMaskedGatherScatterCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ SelectionDAG &DAG) {
+ MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+ assert(MGS && "Can only combine gather load or scatter store nodes");
+
+ SDLoc DL(MGS);
+ SDValue Chain = MGS->getChain();
+ SDValue Scale = MGS->getScale();
+ SDValue Index = MGS->getIndex();
+ SDValue Mask = MGS->getMask();
+ SDValue BasePtr = MGS->getBasePtr();
+ ISD::MemIndexType IndexType = MGS->getIndexType();
+
+ EVT IdxVT = Index.getValueType();
+
+ if (DCI.isBeforeLegalize()) {
+ // SVE gather/scatter requires indices of i32/i64. Promote anything smaller
+ // prior to legalisation so the result can be split if required.
+ if ((IdxVT.getVectorElementType() == MVT::i8) ||
+ (IdxVT.getVectorElementType() == MVT::i16)) {
+ EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
+ if (MGS->isIndexSigned())
+ Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
+ else
+ Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
+
+ if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+ SDValue PassThru = MGT->getPassThru();
+ SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
+ return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
+ PassThru.getValueType(), DL, Ops,
+ MGT->getMemOperand(),
+ MGT->getIndexType(), MGT->getExtensionType());
+ } else {
+ auto *MSC = cast<MaskedScatterSDNode>(MGS);
+ SDValue Data = MSC->getValue();
+ SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
+ return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
+ MSC->getMemoryVT(), DL, Ops,
+ MSC->getMemOperand(), IndexType,
+ MSC->isTruncatingStore());
+ }
+ }
+ }
+
+ return SDValue();
+}
/// Target-specific DAG combine function for NEON load/store intrinsics
/// to merge base address updates.
@@ -13669,9 +15416,6 @@ static SDValue performGatherLoadCombine(SDNode *N, SelectionDAG &DAG,
static SDValue
performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
- if (DCI.isBeforeLegalizeOps())
- return SDValue();
-
SDLoc DL(N);
SDValue Src = N->getOperand(0);
unsigned Opc = Src->getOpcode();
@@ -13698,9 +15442,7 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
assert((EltTy == MVT::i8 || EltTy == MVT::i16 || EltTy == MVT::i32) &&
"Sign extending from an invalid type");
- EVT ExtVT = EVT::getVectorVT(*DAG.getContext(),
- VT.getVectorElementType(),
- VT.getVectorElementCount() * 2);
+ EVT ExtVT = VT.getDoubleNumVectorElementsVT(*DAG.getContext());
SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, ExtOp.getValueType(),
ExtOp, DAG.getValueType(ExtVT));
@@ -13708,6 +15450,12 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
}
+ if (DCI.isBeforeLegalizeOps())
+ return SDValue();
+
+ if (!EnableCombineMGatherIntrinsics)
+ return SDValue();
+
// SVE load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates
// for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes.
unsigned NewOpc;
@@ -13847,9 +15595,11 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
default:
LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
break;
+ case ISD::ABS:
+ return performABSCombine(N, DAG, DCI, Subtarget);
case ISD::ADD:
case ISD::SUB:
- return performAddSubLongCombine(N, DCI, DAG);
+ return performAddSubCombine(N, DCI, DAG);
case ISD::XOR:
return performXorCombine(N, DAG, DCI, Subtarget);
case ISD::MUL:
@@ -13876,6 +15626,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performExtendCombine(N, DCI, DAG);
case ISD::SIGN_EXTEND_INREG:
return performSignExtendInRegCombine(N, DCI, DAG);
+ case ISD::TRUNCATE:
+ return performVectorTruncateCombine(N, DCI, DAG);
case ISD::CONCAT_VECTORS:
return performConcatVectorsCombine(N, DCI, DAG);
case ISD::SELECT:
@@ -13888,6 +15640,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
break;
case ISD::STORE:
return performSTORECombine(N, DCI, DAG, Subtarget);
+ case ISD::MGATHER:
+ case ISD::MSCATTER:
+ return performMaskedGatherScatterCombine(N, DCI, DAG);
case AArch64ISD::BRCOND:
return performBRCONDCombine(N, DCI, DAG);
case AArch64ISD::TBNZ:
@@ -13899,8 +15654,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performPostLD1Combine(N, DCI, false);
case AArch64ISD::NVCAST:
return performNVCASTCombine(N);
+ case AArch64ISD::UZP1:
+ return performUzpCombine(N, DAG);
case ISD::INSERT_VECTOR_ELT:
return performPostLD1Combine(N, DCI, true);
+ case ISD::EXTRACT_VECTOR_ELT:
+ return performExtractVectorEltCombine(N, DAG);
+ case ISD::VECREDUCE_ADD:
+ return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
case ISD::INTRINSIC_VOID:
case ISD::INTRINSIC_W_CHAIN:
switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {
@@ -14049,10 +15810,10 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
uint64_t IdxConst = cast<ConstantSDNode>(Idx)->getZExtValue();
EVT ResVT = N->getValueType(0);
- uint64_t NumLanes = ResVT.getVectorElementCount().Min;
+ uint64_t NumLanes = ResVT.getVectorElementCount().getKnownMinValue();
+ SDValue ExtIdx = DAG.getVectorIdxConstant(IdxConst * NumLanes, DL);
SDValue Val =
- DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Src1,
- DAG.getConstant(IdxConst * NumLanes, DL, MVT::i32));
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Src1, ExtIdx);
return DAG.getMergeValues({Val, Chain}, DL);
}
case Intrinsic::aarch64_sve_tuple_set: {
@@ -14063,10 +15824,11 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
SDValue Vec = N->getOperand(4);
EVT TupleVT = Tuple.getValueType();
- uint64_t TupleLanes = TupleVT.getVectorElementCount().Min;
+ uint64_t TupleLanes = TupleVT.getVectorElementCount().getKnownMinValue();
uint64_t IdxConst = cast<ConstantSDNode>(Idx)->getZExtValue();
- uint64_t NumLanes = Vec.getValueType().getVectorElementCount().Min;
+ uint64_t NumLanes =
+ Vec.getValueType().getVectorElementCount().getKnownMinValue();
if ((TupleLanes % NumLanes) != 0)
report_fatal_error("invalid tuple vector!");
@@ -14078,9 +15840,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
if (I == IdxConst)
Opnds.push_back(Vec);
else {
- Opnds.push_back(
- DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Vec.getValueType(), Tuple,
- DAG.getConstant(I * NumLanes, DL, MVT::i32)));
+ SDValue ExtIdx = DAG.getVectorIdxConstant(I * NumLanes, DL);
+ Opnds.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL,
+ Vec.getValueType(), Tuple, ExtIdx));
}
}
SDValue Concat =
@@ -14302,7 +16064,7 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults(
ElementCount ResEC = VT.getVectorElementCount();
- if (InVT.getVectorElementCount().Min != (ResEC.Min * 2))
+ if (InVT.getVectorElementCount() != (ResEC * 2))
return;
auto *CIndex = dyn_cast<ConstantSDNode>(N->getOperand(1));
@@ -14310,7 +16072,7 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults(
return;
unsigned Index = CIndex->getZExtValue();
- if ((Index != 0) && (Index != ResEC.Min))
+ if ((Index != 0) && (Index != ResEC.getKnownMinValue()))
return;
unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI;
@@ -14345,7 +16107,7 @@ static void ReplaceCMP_SWAP_128Results(SDNode *N,
assert(N->getValueType(0) == MVT::i128 &&
"AtomicCmpSwap on types less than 128 should be legal");
- if (Subtarget->hasLSE()) {
+ if (Subtarget->hasLSE() || Subtarget->outlineAtomics()) {
// LSE has a 128-bit compare and swap (CASP), but i128 is not a legal type,
// so lower it here, wrapped in REG_SEQUENCE and EXTRACT_SUBREG.
SDValue Ops[] = {
@@ -14426,7 +16188,8 @@ void AArch64TargetLowering::ReplaceNodeResults(
return;
case ISD::CTPOP:
- Results.push_back(LowerCTPOP(SDValue(N, 0), DAG));
+ if (SDValue Result = LowerCTPOP(SDValue(N, 0), DAG))
+ Results.push_back(Result);
return;
case AArch64ISD::SADDV:
ReplaceReductionResults(N, Results, DAG, ISD::ADD, AArch64ISD::SADDV);
@@ -14574,14 +16337,30 @@ AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
// Nand not supported in LSE.
if (AI->getOperation() == AtomicRMWInst::Nand) return AtomicExpansionKind::LLSC;
// Leave 128 bits to LLSC.
- return (Subtarget->hasLSE() && Size < 128) ? AtomicExpansionKind::None : AtomicExpansionKind::LLSC;
+ if (Subtarget->hasLSE() && Size < 128)
+ return AtomicExpansionKind::None;
+ if (Subtarget->outlineAtomics() && Size < 128) {
+ // [U]Min/[U]Max RMW atomics are used in __sync_fetch_ libcalls so far.
+ // Don't outline them unless
+ // (1) high level <atomic> support approved:
+ // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf
+ // (2) low level libgcc and compiler-rt support implemented by:
+ // min/max outline atomics helpers
+ if (AI->getOperation() != AtomicRMWInst::Min &&
+ AI->getOperation() != AtomicRMWInst::Max &&
+ AI->getOperation() != AtomicRMWInst::UMin &&
+ AI->getOperation() != AtomicRMWInst::UMax) {
+ return AtomicExpansionKind::None;
+ }
+ }
+ return AtomicExpansionKind::LLSC;
}
TargetLowering::AtomicExpansionKind
AArch64TargetLowering::shouldExpandAtomicCmpXchgInIR(
AtomicCmpXchgInst *AI) const {
// If subtarget has LSE, leave cmpxchg intact for codegen.
- if (Subtarget->hasLSE())
+ if (Subtarget->hasLSE() || Subtarget->outlineAtomics())
return AtomicExpansionKind::None;
// At -O0, fast-regalloc cannot cope with the live vregs necessary to
// implement cmpxchg without spilling. If the address being exchanged is also
@@ -14676,7 +16455,14 @@ Value *AArch64TargetLowering::emitStoreConditional(IRBuilder<> &Builder,
bool AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters(
Type *Ty, CallingConv::ID CallConv, bool isVarArg) const {
- return Ty->isArrayTy();
+ if (Ty->isArrayTy())
+ return true;
+
+ const TypeSize &TySize = Ty->getPrimitiveSizeInBits();
+ if (TySize.isScalable() && TySize.getKnownMinSize() > 128)
+ return true;
+
+ return false;
}
bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &,
@@ -14909,6 +16695,11 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
if (isa<ScalableVectorType>(Inst.getOperand(i)->getType()))
return true;
+ if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
+ if (isa<ScalableVectorType>(AI->getAllocatedType()))
+ return true;
+ }
+
return false;
}
@@ -15080,6 +16871,92 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
Store->isTruncatingStore());
}
+SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
+ SDValue Op, SelectionDAG &DAG) const {
+ SDLoc dl(Op);
+ EVT VT = Op.getValueType();
+ EVT EltVT = VT.getVectorElementType();
+
+ bool Signed = Op.getOpcode() == ISD::SDIV;
+ unsigned PredOpcode = Signed ? AArch64ISD::SDIV_PRED : AArch64ISD::UDIV_PRED;
+
+ // Scalable vector i32/i64 DIV is supported.
+ if (EltVT == MVT::i32 || EltVT == MVT::i64)
+ return LowerToPredicatedOp(Op, DAG, PredOpcode, /*OverrideNEON=*/true);
+
+ // Scalable vector i8/i16 DIV is not supported. Promote it to i32.
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+ EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
+ EVT FixedWidenedVT = HalfVT.widenIntegerVectorElementType(*DAG.getContext());
+ EVT ScalableWidenedVT = getContainerForFixedLengthVector(DAG, FixedWidenedVT);
+
+ // Convert the operands to scalable vectors.
+ SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
+ SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1));
+
+ // Extend the scalable operands.
+ unsigned UnpkLo = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
+ unsigned UnpkHi = Signed ? AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI;
+ SDValue Op0Lo = DAG.getNode(UnpkLo, dl, ScalableWidenedVT, Op0);
+ SDValue Op1Lo = DAG.getNode(UnpkLo, dl, ScalableWidenedVT, Op1);
+ SDValue Op0Hi = DAG.getNode(UnpkHi, dl, ScalableWidenedVT, Op0);
+ SDValue Op1Hi = DAG.getNode(UnpkHi, dl, ScalableWidenedVT, Op1);
+
+ // Convert back to fixed vectors so the DIV can be further lowered.
+ Op0Lo = convertFromScalableVector(DAG, FixedWidenedVT, Op0Lo);
+ Op1Lo = convertFromScalableVector(DAG, FixedWidenedVT, Op1Lo);
+ Op0Hi = convertFromScalableVector(DAG, FixedWidenedVT, Op0Hi);
+ Op1Hi = convertFromScalableVector(DAG, FixedWidenedVT, Op1Hi);
+ SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, FixedWidenedVT,
+ Op0Lo, Op1Lo);
+ SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, FixedWidenedVT,
+ Op0Hi, Op1Hi);
+
+ // Convert again to scalable vectors to truncate.
+ ResultLo = convertToScalableVector(DAG, ScalableWidenedVT, ResultLo);
+ ResultHi = convertToScalableVector(DAG, ScalableWidenedVT, ResultHi);
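+ // UZP1 keeps the even-indexed narrow elements, i.e. the low part of every
+ // widened lane, from ResultLo then ResultHi, performing the truncation and
+ // re-concatenation in one step.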
+ SDValue ScalableResult = DAG.getNode(AArch64ISD::UZP1, dl, ContainerVT,
+ ResultLo, ResultHi);
+
+ return convertFromScalableVector(DAG, VT, ScalableResult);
+}
+
+SDValue AArch64TargetLowering::LowerFixedLengthVectorIntExtendToSVE(
+ SDValue Op, SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+ assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
+
+ SDLoc DL(Op);
+ SDValue Val = Op.getOperand(0);
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, Val.getValueType());
+ Val = convertToScalableVector(DAG, ContainerVT, Val);
+
+ bool Signed = Op.getOpcode() == ISD::SIGN_EXTEND;
+ unsigned ExtendOpc = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
+
+ // Repeatedly unpack Val until the result is of the desired element type.
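+ // Each [SU]UNPKLO doubles the element width using the low half of its
+ // input, so the chain below widens i8->i16->i32->i64 until the requested
+ // element type is reached.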
+ switch (ContainerVT.getSimpleVT().SimpleTy) {
+ default:
+ llvm_unreachable("unimplemented container type");
+ case MVT::nxv16i8:
+ Val = DAG.getNode(ExtendOpc, DL, MVT::nxv8i16, Val);
+ if (VT.getVectorElementType() == MVT::i16)
+ break;
+ LLVM_FALLTHROUGH;
+ case MVT::nxv8i16:
+ Val = DAG.getNode(ExtendOpc, DL, MVT::nxv4i32, Val);
+ if (VT.getVectorElementType() == MVT::i32)
+ break;
+ LLVM_FALLTHROUGH;
+ case MVT::nxv4i32:
+ Val = DAG.getNode(ExtendOpc, DL, MVT::nxv2i64, Val);
+ assert(VT.getVectorElementType() == MVT::i64 && "Unexpected element type!");
+ break;
+ }
+
+ return convertFromScalableVector(DAG, VT, Val);
+}
+
SDValue AArch64TargetLowering::LowerFixedLengthVectorTruncateToSVE(
SDValue Op, SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
@@ -15116,17 +16993,21 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorTruncateToSVE(
return convertFromScalableVector(DAG, VT, Val);
}
+// Convert vector operation 'Op' to an equivalent predicated operation whereby
+// the original operation's type is used to construct a suitable predicate.
+// NOTE: The results for inactive lanes are undefined.
SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
SelectionDAG &DAG,
- unsigned NewOp) const {
+ unsigned NewOp,
+ bool OverrideNEON) const {
EVT VT = Op.getValueType();
SDLoc DL(Op);
auto Pg = getPredicateForVector(DAG, DL, VT);
- if (useSVEForFixedLengthVectorVT(VT)) {
+ if (useSVEForFixedLengthVectorVT(VT, OverrideNEON)) {
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
- // Create list of operands by convereting existing ones to scalable types.
+ // Create list of operands by converting existing ones to scalable types.
SmallVector<SDValue, 4> Operands = {Pg};
for (const SDValue &V : Op->op_values()) {
if (isa<CondCodeSDNode>(V)) {
@@ -15134,11 +17015,21 @@ SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
continue;
}
- assert(useSVEForFixedLengthVectorVT(V.getValueType()) &&
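+ // VT operands (e.g. the value type carried by SIGN_EXTEND_INREG) are
+ // rebuilt in terms of the scalable container type, keeping their original
+ // element type.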
+ if (const VTSDNode *VTNode = dyn_cast<VTSDNode>(V)) {
+ EVT VTArg = VTNode->getVT().getVectorElementType();
+ EVT NewVTArg = ContainerVT.changeVectorElementType(VTArg);
+ Operands.push_back(DAG.getValueType(NewVTArg));
+ continue;
+ }
+
+ assert(useSVEForFixedLengthVectorVT(V.getValueType(), OverrideNEON) &&
"Only fixed length vectors are supported!");
Operands.push_back(convertToScalableVector(DAG, ContainerVT, V));
}
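+ // Merge-passthru opcodes take an extra passthru operand; results for
+ // inactive lanes are undefined here, so undef suffices.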
+ if (isMergePassthruOpcode(NewOp))
+ Operands.push_back(DAG.getUNDEF(ContainerVT));
+
auto ScalableRes = DAG.getNode(NewOp, DL, ContainerVT, Operands);
return convertFromScalableVector(DAG, VT, ScalableRes);
}
@@ -15147,10 +17038,228 @@ SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
SmallVector<SDValue, 4> Operands = {Pg};
for (const SDValue &V : Op->op_values()) {
- assert((isa<CondCodeSDNode>(V) || V.getValueType().isScalableVector()) &&
+ assert((!V.getValueType().isVector() ||
+ V.getValueType().isScalableVector()) &&
"Only scalable vectors are supported!");
Operands.push_back(V);
}
+ if (isMergePassthruOpcode(NewOp))
+ Operands.push_back(DAG.getUNDEF(VT));
+
return DAG.getNode(NewOp, DL, VT, Operands);
}
+
+// If a fixed length vector operation has no side effects when applied to
+// undefined elements, we can safely use scalable vectors to perform the same
+// operation without needing to worry about predication.
+SDValue AArch64TargetLowering::LowerToScalableOp(SDValue Op,
+ SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+ assert(useSVEForFixedLengthVectorVT(VT) &&
+ "Only expected to lower fixed length vector operation!");
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+
+ // Create list of operands by converting existing ones to scalable types.
+ SmallVector<SDValue, 4> Ops;
+ for (const SDValue &V : Op->op_values()) {
+ assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
+
+ // Pass through non-vector operands.
+ if (!V.getValueType().isVector()) {
+ Ops.push_back(V);
+ continue;
+ }
+
+ // "cast" fixed length vector to a scalable vector.
+ assert(useSVEForFixedLengthVectorVT(V.getValueType()) &&
+ "Only fixed length vectors are supported!");
+ Ops.push_back(convertToScalableVector(DAG, ContainerVT, V));
+ }
+
+ auto ScalableRes = DAG.getNode(Op.getOpcode(), SDLoc(Op), ContainerVT, Ops);
+ return convertFromScalableVector(DAG, VT, ScalableRes);
+}
+
+SDValue AArch64TargetLowering::LowerVECREDUCE_SEQ_FADD(SDValue ScalarOp,
+ SelectionDAG &DAG) const {
+ SDLoc DL(ScalarOp);
+ SDValue AccOp = ScalarOp.getOperand(0);
+ SDValue VecOp = ScalarOp.getOperand(1);
+ EVT SrcVT = VecOp.getValueType();
+ EVT ResVT = SrcVT.getVectorElementType();
+
+ EVT ContainerVT = SrcVT;
+ if (SrcVT.isFixedLengthVector()) {
+ ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
+ VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
+ }
+
+ SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
+ SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+
+ // Convert the scalar accumulator into lane 0 of a scalable vector.
+ AccOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT,
+ DAG.getUNDEF(ContainerVT), AccOp, Zero);
+
+ // Perform reduction.
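+ // FADDA performs a strictly-ordered floating-point add reduction, starting
+ // from the accumulator placed in lane 0 above.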
+ SDValue Rdx = DAG.getNode(AArch64ISD::FADDA_PRED, DL, ContainerVT,
+ Pg, AccOp, VecOp);
+
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Rdx, Zero);
+}
+
+SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
+ SelectionDAG &DAG) const {
+ SDLoc DL(ReduceOp);
+ SDValue Op = ReduceOp.getOperand(0);
+ EVT OpVT = Op.getValueType();
+ EVT VT = ReduceOp.getValueType();
+
+ if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1)
+ return SDValue();
+
+ SDValue Pg = getPredicateForVector(DAG, DL, OpVT);
+
+ switch (ReduceOp.getOpcode()) {
+ default:
+ return SDValue();
+ case ISD::VECREDUCE_OR:
+ return getPTest(DAG, VT, Pg, Op, AArch64CC::ANY_ACTIVE);
+ case ISD::VECREDUCE_AND: {
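+ // De Morgan: all governed lanes of Op are set iff no lane of (Op XOR Pg)
+ // is active.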
+ Op = DAG.getNode(ISD::XOR, DL, OpVT, Op, Pg);
+ return getPTest(DAG, VT, Pg, Op, AArch64CC::NONE_ACTIVE);
+ }
+ case ISD::VECREDUCE_XOR: {
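+ // XOR-reducing an i1 vector yields the parity of the active lane count,
+ // i.e. the low bit of CNTP's result.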
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64);
+ SDValue Cntp =
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64, ID, Pg, Op);
+ return DAG.getAnyExtOrTrunc(Cntp, DL, VT);
+ }
+ }
+
+ return SDValue();
+}
+
+SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
+ SDValue ScalarOp,
+ SelectionDAG &DAG) const {
+ SDLoc DL(ScalarOp);
+ SDValue VecOp = ScalarOp.getOperand(0);
+ EVT SrcVT = VecOp.getValueType();
+
+ if (useSVEForFixedLengthVectorVT(SrcVT, true)) {
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
+ VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
+ }
+
+ // UADDV always returns an i64 result.
+ EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64 :
+ SrcVT.getVectorElementType();
+ EVT RdxVT = SrcVT;
+ if (SrcVT.isFixedLengthVector() || Opcode == AArch64ISD::UADDV_PRED)
+ RdxVT = getPackedSVEVectorVT(ResVT);
+
+ SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
+ SDValue Rdx = DAG.getNode(Opcode, DL, RdxVT, Pg, VecOp);
+ SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT,
+ Rdx, DAG.getConstant(0, DL, MVT::i64));
+
+ // The VECREDUCE nodes expect a result of their original element-sized
+ // scalar type.
+ if (ResVT != ScalarOp.getValueType())
+ Res = DAG.getAnyExtOrTrunc(Res, DL, ScalarOp.getValueType());
+
+ return Res;
+}
+
+SDValue
+AArch64TargetLowering::LowerFixedLengthVectorSelectToSVE(SDValue Op,
+ SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+ SDLoc DL(Op);
+
+ EVT InVT = Op.getOperand(1).getValueType();
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
+ SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(1));
+ SDValue Op2 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(2));
+
+ // Convert the mask to a predicate (NOTE: We don't need to worry about
+ // inactive lanes since VSELECT is safe when given undefined elements).
+ EVT MaskVT = Op.getOperand(0).getValueType();
+ EVT MaskContainerVT = getContainerForFixedLengthVector(DAG, MaskVT);
+ auto Mask = convertToScalableVector(DAG, MaskContainerVT, Op.getOperand(0));
+ Mask = DAG.getNode(ISD::TRUNCATE, DL,
+ MaskContainerVT.changeVectorElementType(MVT::i1), Mask);
+
+ auto ScalableRes = DAG.getNode(ISD::VSELECT, DL, ContainerVT,
+ Mask, Op1, Op2);
+
+ return convertFromScalableVector(DAG, VT, ScalableRes);
+}
+
+SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
+ SDValue Op, SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ EVT InVT = Op.getOperand(0).getValueType();
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
+
+ assert(useSVEForFixedLengthVectorVT(InVT) &&
+ "Only expected to lower fixed length vector operation!");
+ assert(Op.getValueType() == InVT.changeTypeToInteger() &&
+ "Expected integer result of the same bit length as the inputs!");
+
+ // Expand floating point vector comparisons.
+ if (InVT.isFloatingPoint())
+ return SDValue();
+
+ auto Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
+ auto Op2 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1));
+ auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
+
+ EVT CmpVT = Pg.getValueType();
+ auto Cmp = DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, CmpVT,
+ {Pg, Op1, Op2, Op.getOperand(2)});
+
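+ // The compare yields an i1 predicate; extend it to the integer container
+ // type so the result matches the fixed length SETCC's integer result type.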
+ EVT PromoteVT = ContainerVT.changeTypeToInteger();
+ auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, PromoteVT, InVT);
+ return convertFromScalableVector(DAG, Op.getValueType(), Promote);
+}
+
+SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ EVT InVT = Op.getValueType();
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ (void)TLI;
+
+ assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
+ InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
+ "Only expect to cast between legal scalable vector types!");
+ assert((VT.getVectorElementType() == MVT::i1) ==
+ (InVT.getVectorElementType() == MVT::i1) &&
+ "Cannot cast between data and predicate scalable vector types!");
+
+ if (InVT == VT)
+ return Op;
+
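+ // Both element types are i1 here (see the assert above), so a
+ // predicate-to-predicate cast is a simple reinterpret.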
+ if (VT.getVectorElementType() == MVT::i1)
+ return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+ EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
+ EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
+ assert((VT == PackedVT || InVT == PackedInVT) &&
+ "Cannot cast between unpacked scalable vector types!");
+
+ // Pack input if required.
+ if (InVT != PackedInVT)
+ Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
+
+ Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+
+ // Unpack result if required.
+ if (VT != PackedVT)
+ Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+ return Op;
+}