author | Dimitry Andric <dim@FreeBSD.org> | 2021-02-16 20:13:02 +0000
committer | Dimitry Andric <dim@FreeBSD.org> | 2021-02-16 20:13:02 +0000
commit | b60736ec1405bb0a8dd40989f67ef4c93da068ab (patch)
tree | 5c43fbb7c9fc45f0f87e0e6795a86267dbd12f9d /llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
parent | cfca06d7963fa0909f90483b42a6d7d194d01e08 (diff)
Vendor import of llvm-project main 8e464dd76bef, the last commit before
the upstream release/12.x branch was created.
(tag: vendor/llvm-project/llvmorg-12-init-17869-g8e464dd76bef)
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3221
1 file changed, 2665 insertions(+), 556 deletions(-)
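Many of the hunks below reroute operations through the patch's predicated-lowering helpers (LowerToPredicatedOp and friends) instead of libcalls or open-coded expansion. As orientation, here is a minimal sketch of that pattern; it is illustrative only, and it assumes a helper getPredicateForVector that returns an all-active predicate for the given vector type, as the real file provides:

    // Sketch: wrap an unpredicated node in its predicated AArch64ISD form.
    // The real LowerToPredicatedOp also handles fixed-length vectors by
    // first moving operands into scalable containers.
    static SDValue lowerToPredicatedOpSketch(SDValue Op, SelectionDAG &DAG,
                                             unsigned NewOp) {
      SDLoc DL(Op);
      EVT VT = Op.getValueType();
      SDValue Pg = getPredicateForVector(DAG, DL, VT); // all lanes active
      SmallVector<SDValue, 4> Ops = {Pg};
      for (const SDValue &V : Op->op_values())
        Ops.push_back(V);
      return DAG.getNode(NewOp, DL, VT, Ops);
    }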
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 85db14ab66fe..1be09186dc0a 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -27,7 +27,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/Triple.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/VectorUtils.h" @@ -113,9 +112,76 @@ EnableOptimizeLogicalImm("aarch64-enable-logical-imm", cl::Hidden, "optimization"), cl::init(true)); +// Temporary option added for the purpose of testing functionality added +// to DAGCombiner.cpp in D92230. It is expected that this can be removed +// in future when both implementations will be based off MGATHER rather +// than the GLD1 nodes added for the SVE gather load intrinsics. +static cl::opt<bool> +EnableCombineMGatherIntrinsics("aarch64-enable-mgather-combine", cl::Hidden, + cl::desc("Combine extends of AArch64 masked " + "gather intrinsics"), + cl::init(true)); + /// Value type used for condition codes. static const MVT MVT_CC = MVT::i32; +static inline EVT getPackedSVEVectorVT(EVT VT) { + switch (VT.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("unexpected element type for vector"); + case MVT::i8: + return MVT::nxv16i8; + case MVT::i16: + return MVT::nxv8i16; + case MVT::i32: + return MVT::nxv4i32; + case MVT::i64: + return MVT::nxv2i64; + case MVT::f16: + return MVT::nxv8f16; + case MVT::f32: + return MVT::nxv4f32; + case MVT::f64: + return MVT::nxv2f64; + case MVT::bf16: + return MVT::nxv8bf16; + } +} + +// NOTE: Currently there's only a need to return integer vector types. If this +// changes then just add an extra "type" parameter. +static inline EVT getPackedSVEVectorVT(ElementCount EC) { + switch (EC.getKnownMinValue()) { + default: + llvm_unreachable("unexpected element count for vector"); + case 16: + return MVT::nxv16i8; + case 8: + return MVT::nxv8i16; + case 4: + return MVT::nxv4i32; + case 2: + return MVT::nxv2i64; + } +} + +static inline EVT getPromotedVTForPredicate(EVT VT) { + assert(VT.isScalableVector() && (VT.getVectorElementType() == MVT::i1) && + "Expected scalable predicate vector type!"); + switch (VT.getVectorMinNumElements()) { + default: + llvm_unreachable("unexpected element count for vector"); + case 2: + return MVT::nxv2i64; + case 4: + return MVT::nxv4i32; + case 8: + return MVT::nxv8i16; + case 16: + return MVT::nxv16i8; + } +} + /// Returns true if VT's elements occupy the lowest bit positions of its /// associated register class without any intervening space. /// @@ -128,6 +194,42 @@ static inline bool isPackedVectorType(EVT VT, SelectionDAG &DAG) { VT.getSizeInBits().getKnownMinSize() == AArch64::SVEBitsPerBlock; } +// Returns true for ####_MERGE_PASSTHRU opcodes, whose operands have a leading +// predicate and end with a passthru value matching the result type. 
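// Annotation (not part of this commit): a unary merging node built with
// this convention takes (predicate, source, passthru), and inactive lanes
// of the result are taken from the passthru operand, e.g.
//   SDValue Res = DAG.getNode(AArch64ISD::FABS_MERGE_PASSTHRU, DL, VT,
//                             Pg, Src, PassThru);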
+static bool isMergePassthruOpcode(unsigned Opc) { + switch (Opc) { + default: + return false; + case AArch64ISD::BITREVERSE_MERGE_PASSTHRU: + case AArch64ISD::BSWAP_MERGE_PASSTHRU: + case AArch64ISD::CTLZ_MERGE_PASSTHRU: + case AArch64ISD::CTPOP_MERGE_PASSTHRU: + case AArch64ISD::DUP_MERGE_PASSTHRU: + case AArch64ISD::ABS_MERGE_PASSTHRU: + case AArch64ISD::NEG_MERGE_PASSTHRU: + case AArch64ISD::FNEG_MERGE_PASSTHRU: + case AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU: + case AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU: + case AArch64ISD::FCEIL_MERGE_PASSTHRU: + case AArch64ISD::FFLOOR_MERGE_PASSTHRU: + case AArch64ISD::FNEARBYINT_MERGE_PASSTHRU: + case AArch64ISD::FRINT_MERGE_PASSTHRU: + case AArch64ISD::FROUND_MERGE_PASSTHRU: + case AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU: + case AArch64ISD::FTRUNC_MERGE_PASSTHRU: + case AArch64ISD::FP_ROUND_MERGE_PASSTHRU: + case AArch64ISD::FP_EXTEND_MERGE_PASSTHRU: + case AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU: + case AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU: + case AArch64ISD::FCVTZU_MERGE_PASSTHRU: + case AArch64ISD::FCVTZS_MERGE_PASSTHRU: + case AArch64ISD::FSQRT_MERGE_PASSTHRU: + case AArch64ISD::FRECPX_MERGE_PASSTHRU: + case AArch64ISD::FABS_MERGE_PASSTHRU: + return true; + } +} + AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, const AArch64Subtarget &STI) : TargetLowering(TM), Subtarget(&STI) { @@ -161,7 +263,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, addDRTypeForNEON(MVT::v1i64); addDRTypeForNEON(MVT::v1f64); addDRTypeForNEON(MVT::v4f16); - addDRTypeForNEON(MVT::v4bf16); + if (Subtarget->hasBF16()) + addDRTypeForNEON(MVT::v4bf16); addQRTypeForNEON(MVT::v4f32); addQRTypeForNEON(MVT::v2f64); @@ -170,7 +273,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, addQRTypeForNEON(MVT::v4i32); addQRTypeForNEON(MVT::v2i64); addQRTypeForNEON(MVT::v8f16); - addQRTypeForNEON(MVT::v8bf16); + if (Subtarget->hasBF16()) + addQRTypeForNEON(MVT::v8bf16); } if (Subtarget->hasSVE()) { @@ -199,7 +303,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, addRegisterClass(MVT::nxv8bf16, &AArch64::ZPRRegClass); } - if (useSVEForFixedLengthVectors()) { + if (Subtarget->useSVEForFixedLengthVectors()) { for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) if (useSVEForFixedLengthVectorVT(VT)) addRegisterClass(VT, &AArch64::ZPRRegClass); @@ -230,7 +334,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, MVT::nxv2f64 }) { setCondCodeAction(ISD::SETO, VT, Expand); setCondCodeAction(ISD::SETOLT, VT, Expand); + setCondCodeAction(ISD::SETLT, VT, Expand); setCondCodeAction(ISD::SETOLE, VT, Expand); + setCondCodeAction(ISD::SETLE, VT, Expand); setCondCodeAction(ISD::SETULT, VT, Expand); setCondCodeAction(ISD::SETULE, VT, Expand); setCondCodeAction(ISD::SETUGE, VT, Expand); @@ -296,12 +402,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, // Virtually no operation on f128 is legal, but LLVM can't expand them when // there's a valid register class, so we need custom operations in most cases. 
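// Annotation (not part of this commit): the hunk below switches several
// f128 operations from Custom to LibCall, so generic legalization emits
// the compiler-rt soft-float routine directly and the hand-written
// LowerF128Call helper can be deleted (see further down). For instance:
//   %r = fadd fp128 %a, %b     ; selects to:  bl __addtf3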
setOperationAction(ISD::FABS, MVT::f128, Expand); - setOperationAction(ISD::FADD, MVT::f128, Custom); + setOperationAction(ISD::FADD, MVT::f128, LibCall); setOperationAction(ISD::FCOPYSIGN, MVT::f128, Expand); setOperationAction(ISD::FCOS, MVT::f128, Expand); - setOperationAction(ISD::FDIV, MVT::f128, Custom); + setOperationAction(ISD::FDIV, MVT::f128, LibCall); setOperationAction(ISD::FMA, MVT::f128, Expand); - setOperationAction(ISD::FMUL, MVT::f128, Custom); + setOperationAction(ISD::FMUL, MVT::f128, LibCall); setOperationAction(ISD::FNEG, MVT::f128, Expand); setOperationAction(ISD::FPOW, MVT::f128, Expand); setOperationAction(ISD::FREM, MVT::f128, Expand); @@ -309,7 +415,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FSIN, MVT::f128, Expand); setOperationAction(ISD::FSINCOS, MVT::f128, Expand); setOperationAction(ISD::FSQRT, MVT::f128, Expand); - setOperationAction(ISD::FSUB, MVT::f128, Custom); + setOperationAction(ISD::FSUB, MVT::f128, LibCall); setOperationAction(ISD::FTRUNC, MVT::f128, Expand); setOperationAction(ISD::SETCC, MVT::f128, Custom); setOperationAction(ISD::STRICT_FSETCC, MVT::f128, Custom); @@ -345,8 +451,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i32, Custom); setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i64, Custom); setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i128, Custom); + setOperationAction(ISD::FP_ROUND, MVT::f16, Custom); setOperationAction(ISD::FP_ROUND, MVT::f32, Custom); setOperationAction(ISD::FP_ROUND, MVT::f64, Custom); + setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Custom); setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Custom); setOperationAction(ISD::STRICT_FP_ROUND, MVT::f64, Custom); @@ -401,6 +509,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::CTPOP, MVT::i64, Custom); setOperationAction(ISD::CTPOP, MVT::i128, Custom); + setOperationAction(ISD::ABS, MVT::i32, Custom); + setOperationAction(ISD::ABS, MVT::i64, Custom); + setOperationAction(ISD::SDIVREM, MVT::i32, Expand); setOperationAction(ISD::SDIVREM, MVT::i64, Expand); for (MVT VT : MVT::fixedlen_vector_valuetypes()) { @@ -588,6 +699,57 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i32, Custom); setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i64, Custom); + // Generate outline atomics library calls only if LSE was not specified for + // subtarget + if (Subtarget->outlineAtomics() && !Subtarget->hasLSE()) { + setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i8, LibCall); + setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i16, LibCall); + setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i32, LibCall); + setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i64, LibCall); + setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i128, LibCall); + setOperationAction(ISD::ATOMIC_SWAP, MVT::i8, LibCall); + setOperationAction(ISD::ATOMIC_SWAP, MVT::i16, LibCall); + setOperationAction(ISD::ATOMIC_SWAP, MVT::i32, LibCall); + setOperationAction(ISD::ATOMIC_SWAP, MVT::i64, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i8, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i16, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i32, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i64, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i8, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i16, LibCall); + 
setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i32, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i64, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i8, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i16, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i32, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i64, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i8, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i16, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i32, LibCall); + setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i64, LibCall); +#define LCALLNAMES(A, B, N) \ + setLibcallName(A##N##_RELAX, #B #N "_relax"); \ + setLibcallName(A##N##_ACQ, #B #N "_acq"); \ + setLibcallName(A##N##_REL, #B #N "_rel"); \ + setLibcallName(A##N##_ACQ_REL, #B #N "_acq_rel"); +#define LCALLNAME4(A, B) \ + LCALLNAMES(A, B, 1) \ + LCALLNAMES(A, B, 2) LCALLNAMES(A, B, 4) LCALLNAMES(A, B, 8) +#define LCALLNAME5(A, B) \ + LCALLNAMES(A, B, 1) \ + LCALLNAMES(A, B, 2) \ + LCALLNAMES(A, B, 4) LCALLNAMES(A, B, 8) LCALLNAMES(A, B, 16) + LCALLNAME5(RTLIB::OUTLINE_ATOMIC_CAS, __aarch64_cas) + LCALLNAME4(RTLIB::OUTLINE_ATOMIC_SWP, __aarch64_swp) + LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDADD, __aarch64_ldadd) + LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDSET, __aarch64_ldset) + LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDCLR, __aarch64_ldclr) + LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDEOR, __aarch64_ldeor) +#undef LCALLNAMES +#undef LCALLNAME4 +#undef LCALLNAME5 + } + // 128-bit loads and stores can be done without expanding setOperationAction(ISD::LOAD, MVT::i128, Custom); setOperationAction(ISD::STORE, MVT::i128, Custom); @@ -677,8 +839,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, // Trap. setOperationAction(ISD::TRAP, MVT::Other, Legal); - if (Subtarget->isTargetWindows()) - setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal); + setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal); + setOperationAction(ISD::UBSANTRAP, MVT::Other, Legal); // We combine OR nodes for bitfield operations. setTargetDAGCombine(ISD::OR); @@ -688,6 +850,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, // Vector add and sub nodes may conceal a high-half opportunity. // Also, try to fold ADD into CSINC/CSINV.. 
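// Annotation (not part of this commit): the ADD-into-CSINC fold mentioned
// above replaces a materialized flag plus an add with a single conditional
// increment, e.g.
//   cset  w9, eq
//   add   w0, w10, w9
// can become
//   csinc w0, w10, w10, ne    ; w0 = (ne ? w10 : w10 + 1)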
setTargetDAGCombine(ISD::ADD); + setTargetDAGCombine(ISD::ABS); setTargetDAGCombine(ISD::SUB); setTargetDAGCombine(ISD::SRL); setTargetDAGCombine(ISD::XOR); @@ -704,11 +867,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::ZERO_EXTEND); setTargetDAGCombine(ISD::SIGN_EXTEND); setTargetDAGCombine(ISD::SIGN_EXTEND_INREG); + setTargetDAGCombine(ISD::TRUNCATE); setTargetDAGCombine(ISD::CONCAT_VECTORS); setTargetDAGCombine(ISD::STORE); if (Subtarget->supportsAddressTopByteIgnored()) setTargetDAGCombine(ISD::LOAD); + setTargetDAGCombine(ISD::MGATHER); + setTargetDAGCombine(ISD::MSCATTER); + setTargetDAGCombine(ISD::MUL); setTargetDAGCombine(ISD::SELECT); @@ -717,6 +884,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::INTRINSIC_VOID); setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN); setTargetDAGCombine(ISD::INSERT_VECTOR_ELT); + setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT); + setTargetDAGCombine(ISD::VECREDUCE_ADD); setTargetDAGCombine(ISD::GlobalAddress); @@ -836,28 +1005,33 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::MUL, MVT::v4i32, Custom); setOperationAction(ISD::MUL, MVT::v2i64, Custom); + // Saturates for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) { - // Vector reductions - setOperationAction(ISD::VECREDUCE_ADD, VT, Custom); - setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); - setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); - setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); - setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); - - // Saturates setOperationAction(ISD::SADDSAT, VT, Legal); setOperationAction(ISD::UADDSAT, VT, Legal); setOperationAction(ISD::SSUBSAT, VT, Legal); setOperationAction(ISD::USUBSAT, VT, Legal); - - setOperationAction(ISD::TRUNCATE, VT, Custom); } + + // Vector reductions for (MVT VT : { MVT::v4f16, MVT::v2f32, MVT::v8f16, MVT::v4f32, MVT::v2f64 }) { setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); + + if (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16()) + setOperationAction(ISD::VECREDUCE_FADD, VT, Legal); } + for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32, + MVT::v16i8, MVT::v8i16, MVT::v4i32 }) { + setOperationAction(ISD::VECREDUCE_ADD, VT, Custom); + setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); + setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); + setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); + } + setOperationAction(ISD::VECREDUCE_ADD, MVT::v2i64, Custom); setOperationAction(ISD::ANY_EXTEND, MVT::v4i32, Legal); setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand); @@ -918,43 +1092,112 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, // FIXME: Add custom lowering of MLOAD to handle different passthrus (not a // splat of 0 or undef) once vector selects supported in SVE codegen. See // D68877 for more details. 
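// Annotation (not part of this commit): each setTargetDAGCombine(Opc)
// registration above makes the generic combiner hand nodes with that
// opcode back to AArch64TargetLowering::PerformDAGCombine, which
// dispatches roughly like (handler name illustrative):
//   switch (N->getOpcode()) {
//   case ISD::MGATHER:
//   case ISD::MSCATTER:
//     return performMaskedGatherScatterCombine(N, DCI, DAG);
//   ...
//   }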
- for (MVT VT : MVT::integer_scalable_vector_valuetypes()) { - if (isTypeLegal(VT)) { - setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); - setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); - setOperationAction(ISD::SELECT, VT, Custom); - setOperationAction(ISD::SDIV, VT, Custom); - setOperationAction(ISD::UDIV, VT, Custom); - setOperationAction(ISD::SMIN, VT, Custom); - setOperationAction(ISD::UMIN, VT, Custom); - setOperationAction(ISD::SMAX, VT, Custom); - setOperationAction(ISD::UMAX, VT, Custom); - setOperationAction(ISD::SHL, VT, Custom); - setOperationAction(ISD::SRL, VT, Custom); - setOperationAction(ISD::SRA, VT, Custom); - if (VT.getScalarType() == MVT::i1) - setOperationAction(ISD::SETCC, VT, Custom); - } + for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64}) { + setOperationAction(ISD::BITREVERSE, VT, Custom); + setOperationAction(ISD::BSWAP, VT, Custom); + setOperationAction(ISD::CTLZ, VT, Custom); + setOperationAction(ISD::CTPOP, VT, Custom); + setOperationAction(ISD::CTTZ, VT, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); + setOperationAction(ISD::UINT_TO_FP, VT, Custom); + setOperationAction(ISD::SINT_TO_FP, VT, Custom); + setOperationAction(ISD::FP_TO_UINT, VT, Custom); + setOperationAction(ISD::FP_TO_SINT, VT, Custom); + setOperationAction(ISD::MGATHER, VT, Custom); + setOperationAction(ISD::MSCATTER, VT, Custom); + setOperationAction(ISD::MUL, VT, Custom); + setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + setOperationAction(ISD::SELECT, VT, Custom); + setOperationAction(ISD::SDIV, VT, Custom); + setOperationAction(ISD::UDIV, VT, Custom); + setOperationAction(ISD::SMIN, VT, Custom); + setOperationAction(ISD::UMIN, VT, Custom); + setOperationAction(ISD::SMAX, VT, Custom); + setOperationAction(ISD::UMAX, VT, Custom); + setOperationAction(ISD::SHL, VT, Custom); + setOperationAction(ISD::SRL, VT, Custom); + setOperationAction(ISD::SRA, VT, Custom); + setOperationAction(ISD::ABS, VT, Custom); + setOperationAction(ISD::VECREDUCE_ADD, VT, Custom); + setOperationAction(ISD::VECREDUCE_AND, VT, Custom); + setOperationAction(ISD::VECREDUCE_OR, VT, Custom); + setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); + setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); + setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); } - for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) + // Illegal unpacked integer vector types. 
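// Annotation (not part of this commit): "unpacked" means the type does not
// fill the 128-bit SVE granule, so each element sits in the low part of a
// wider lane. For example nxv2i32 keeps two i32 values per granule, each
// in the bottom half of a 64-bit lane:
//   nxv2i32  ->  lanes of an nxv2i64-sized Z register, low 32 bits used
// which is why only subvector extraction/insertion is made Custom for
// these types below.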
+ for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) { setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); + } - setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom); - setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom); - - for (MVT VT : MVT::fp_scalable_vector_valuetypes()) { - if (isTypeLegal(VT)) { - setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); - setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); - setOperationAction(ISD::SELECT, VT, Custom); - setOperationAction(ISD::FMA, VT, Custom); + for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) { + setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); + setOperationAction(ISD::SELECT, VT, Custom); + setOperationAction(ISD::SETCC, VT, Custom); + setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction(ISD::VECREDUCE_AND, VT, Custom); + setOperationAction(ISD::VECREDUCE_OR, VT, Custom); + setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); + + // There are no legal MVT::nxv16f## based types. + if (VT != MVT::nxv16i1) { + setOperationAction(ISD::SINT_TO_FP, VT, Custom); + setOperationAction(ISD::UINT_TO_FP, VT, Custom); } } + for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32, + MVT::nxv4f32, MVT::nxv2f64}) { + setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); + setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); + setOperationAction(ISD::MGATHER, VT, Custom); + setOperationAction(ISD::MSCATTER, VT, Custom); + setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + setOperationAction(ISD::SELECT, VT, Custom); + setOperationAction(ISD::FADD, VT, Custom); + setOperationAction(ISD::FDIV, VT, Custom); + setOperationAction(ISD::FMA, VT, Custom); + setOperationAction(ISD::FMAXNUM, VT, Custom); + setOperationAction(ISD::FMINNUM, VT, Custom); + setOperationAction(ISD::FMUL, VT, Custom); + setOperationAction(ISD::FNEG, VT, Custom); + setOperationAction(ISD::FSUB, VT, Custom); + setOperationAction(ISD::FCEIL, VT, Custom); + setOperationAction(ISD::FFLOOR, VT, Custom); + setOperationAction(ISD::FNEARBYINT, VT, Custom); + setOperationAction(ISD::FRINT, VT, Custom); + setOperationAction(ISD::FROUND, VT, Custom); + setOperationAction(ISD::FROUNDEVEN, VT, Custom); + setOperationAction(ISD::FTRUNC, VT, Custom); + setOperationAction(ISD::FSQRT, VT, Custom); + setOperationAction(ISD::FABS, VT, Custom); + setOperationAction(ISD::FP_EXTEND, VT, Custom); + setOperationAction(ISD::FP_ROUND, VT, Custom); + setOperationAction(ISD::VECREDUCE_FADD, VT, Custom); + setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); + setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom); + } + + for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) { + setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); + setOperationAction(ISD::MGATHER, VT, Custom); + setOperationAction(ISD::MSCATTER, VT, Custom); + } + + setOperationAction(ISD::SPLAT_VECTOR, MVT::nxv8bf16, Custom); + + setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom); + setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom); + // NOTE: Currently this has to happen after computeRegisterProperties rather // than the preferred option of combining it with the addRegisterClass call. 
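// Annotation (not part of this commit): when a minimum SVE width is known
// (e.g. llc -aarch64-sve-vector-bits-min=512), fixed-length vectors that
// were given ZPR register classes above have their operations mapped onto
// the scalable forms registered in addTypeForFixedLengthSVE, e.g.
//   (add v16i32 x, y)
//     -> insert x, y into nxv4i32 containers
//     -> AArch64ISD::ADD_PRED nxv4i32 (ptrue), x', y'
//     -> extract the fixed v16i32 result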
- if (useSVEForFixedLengthVectors()) { + if (Subtarget->useSVEForFixedLengthVectors()) { for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) if (useSVEForFixedLengthVectorVT(VT)) addTypeForFixedLengthSVE(VT); @@ -972,6 +1215,61 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::TRUNCATE, VT, Custom); for (auto VT : {MVT::v8f16, MVT::v4f32}) setOperationAction(ISD::FP_ROUND, VT, Expand); + + // These operations are not supported on NEON but SVE can do them. + setOperationAction(ISD::BITREVERSE, MVT::v1i64, Custom); + setOperationAction(ISD::CTLZ, MVT::v1i64, Custom); + setOperationAction(ISD::CTLZ, MVT::v2i64, Custom); + setOperationAction(ISD::CTTZ, MVT::v1i64, Custom); + setOperationAction(ISD::MUL, MVT::v1i64, Custom); + setOperationAction(ISD::MUL, MVT::v2i64, Custom); + setOperationAction(ISD::SDIV, MVT::v8i8, Custom); + setOperationAction(ISD::SDIV, MVT::v16i8, Custom); + setOperationAction(ISD::SDIV, MVT::v4i16, Custom); + setOperationAction(ISD::SDIV, MVT::v8i16, Custom); + setOperationAction(ISD::SDIV, MVT::v2i32, Custom); + setOperationAction(ISD::SDIV, MVT::v4i32, Custom); + setOperationAction(ISD::SDIV, MVT::v1i64, Custom); + setOperationAction(ISD::SDIV, MVT::v2i64, Custom); + setOperationAction(ISD::SMAX, MVT::v1i64, Custom); + setOperationAction(ISD::SMAX, MVT::v2i64, Custom); + setOperationAction(ISD::SMIN, MVT::v1i64, Custom); + setOperationAction(ISD::SMIN, MVT::v2i64, Custom); + setOperationAction(ISD::UDIV, MVT::v8i8, Custom); + setOperationAction(ISD::UDIV, MVT::v16i8, Custom); + setOperationAction(ISD::UDIV, MVT::v4i16, Custom); + setOperationAction(ISD::UDIV, MVT::v8i16, Custom); + setOperationAction(ISD::UDIV, MVT::v2i32, Custom); + setOperationAction(ISD::UDIV, MVT::v4i32, Custom); + setOperationAction(ISD::UDIV, MVT::v1i64, Custom); + setOperationAction(ISD::UDIV, MVT::v2i64, Custom); + setOperationAction(ISD::UMAX, MVT::v1i64, Custom); + setOperationAction(ISD::UMAX, MVT::v2i64, Custom); + setOperationAction(ISD::UMIN, MVT::v1i64, Custom); + setOperationAction(ISD::UMIN, MVT::v2i64, Custom); + setOperationAction(ISD::VECREDUCE_SMAX, MVT::v2i64, Custom); + setOperationAction(ISD::VECREDUCE_SMIN, MVT::v2i64, Custom); + setOperationAction(ISD::VECREDUCE_UMAX, MVT::v2i64, Custom); + setOperationAction(ISD::VECREDUCE_UMIN, MVT::v2i64, Custom); + + // Int operations with no NEON support. + for (auto VT : {MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, + MVT::v2i32, MVT::v4i32, MVT::v2i64}) { + setOperationAction(ISD::BITREVERSE, VT, Custom); + setOperationAction(ISD::CTTZ, VT, Custom); + setOperationAction(ISD::VECREDUCE_AND, VT, Custom); + setOperationAction(ISD::VECREDUCE_OR, VT, Custom); + setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); + } + + // FP operations with no NEON support. + for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, + MVT::v1f64, MVT::v2f64}) + setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom); + + // Use SVE for vectors with more than 2 elements. + for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32}) + setOperationAction(ISD::VECREDUCE_FADD, VT, Custom); } } @@ -1043,6 +1341,7 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT, MVT PromotedBitwiseVT) { // F[MIN|MAX][NUM|NAN] are available for all FP NEON types. 
if (VT.isFloatingPoint() && + VT.getVectorElementType() != MVT::bf16 && (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16())) for (unsigned Opcode : {ISD::FMINIMUM, ISD::FMAXIMUM, ISD::FMINNUM, ISD::FMAXNUM}) @@ -1068,11 +1367,64 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) { setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); // Lower fixed length vector operations to scalable equivalents. + setOperationAction(ISD::ABS, VT, Custom); setOperationAction(ISD::ADD, VT, Custom); + setOperationAction(ISD::AND, VT, Custom); + setOperationAction(ISD::ANY_EXTEND, VT, Custom); + setOperationAction(ISD::BITREVERSE, VT, Custom); + setOperationAction(ISD::BSWAP, VT, Custom); + setOperationAction(ISD::CTLZ, VT, Custom); + setOperationAction(ISD::CTPOP, VT, Custom); + setOperationAction(ISD::CTTZ, VT, Custom); setOperationAction(ISD::FADD, VT, Custom); + setOperationAction(ISD::FCEIL, VT, Custom); + setOperationAction(ISD::FDIV, VT, Custom); + setOperationAction(ISD::FFLOOR, VT, Custom); + setOperationAction(ISD::FMA, VT, Custom); + setOperationAction(ISD::FMAXNUM, VT, Custom); + setOperationAction(ISD::FMINNUM, VT, Custom); + setOperationAction(ISD::FMUL, VT, Custom); + setOperationAction(ISD::FNEARBYINT, VT, Custom); + setOperationAction(ISD::FNEG, VT, Custom); + setOperationAction(ISD::FRINT, VT, Custom); + setOperationAction(ISD::FROUND, VT, Custom); + setOperationAction(ISD::FSQRT, VT, Custom); + setOperationAction(ISD::FSUB, VT, Custom); + setOperationAction(ISD::FTRUNC, VT, Custom); setOperationAction(ISD::LOAD, VT, Custom); + setOperationAction(ISD::MUL, VT, Custom); + setOperationAction(ISD::OR, VT, Custom); + setOperationAction(ISD::SDIV, VT, Custom); + setOperationAction(ISD::SETCC, VT, Custom); + setOperationAction(ISD::SHL, VT, Custom); + setOperationAction(ISD::SIGN_EXTEND, VT, Custom); + setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Custom); + setOperationAction(ISD::SMAX, VT, Custom); + setOperationAction(ISD::SMIN, VT, Custom); + setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + setOperationAction(ISD::SRA, VT, Custom); + setOperationAction(ISD::SRL, VT, Custom); setOperationAction(ISD::STORE, VT, Custom); + setOperationAction(ISD::SUB, VT, Custom); setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction(ISD::UDIV, VT, Custom); + setOperationAction(ISD::UMAX, VT, Custom); + setOperationAction(ISD::UMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_ADD, VT, Custom); + setOperationAction(ISD::VECREDUCE_AND, VT, Custom); + setOperationAction(ISD::VECREDUCE_FADD, VT, Custom); + setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom); + setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); + setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_OR, VT, Custom); + setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); + setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); + setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); + setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); + setOperationAction(ISD::VSELECT, VT, Custom); + setOperationAction(ISD::XOR, VT, Custom); + setOperationAction(ISD::ZERO_EXTEND, VT, Custom); } void AArch64TargetLowering::addDRTypeForNEON(MVT VT) { @@ -1244,8 +1596,7 @@ void AArch64TargetLowering::computeKnownBitsForTargetNode( KnownBits Known2; Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1); Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1); - Known.Zero &= Known2.Zero; - Known.One &= Known2.One; + 
Known = KnownBits::commonBits(Known, Known2); break; } case AArch64ISD::LOADgot: @@ -1385,15 +1736,38 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::THREAD_POINTER) MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ) MAKE_CASE(AArch64ISD::ADD_PRED) + MAKE_CASE(AArch64ISD::MUL_PRED) MAKE_CASE(AArch64ISD::SDIV_PRED) + MAKE_CASE(AArch64ISD::SHL_PRED) + MAKE_CASE(AArch64ISD::SMAX_PRED) + MAKE_CASE(AArch64ISD::SMIN_PRED) + MAKE_CASE(AArch64ISD::SRA_PRED) + MAKE_CASE(AArch64ISD::SRL_PRED) + MAKE_CASE(AArch64ISD::SUB_PRED) MAKE_CASE(AArch64ISD::UDIV_PRED) - MAKE_CASE(AArch64ISD::SMIN_MERGE_OP1) - MAKE_CASE(AArch64ISD::UMIN_MERGE_OP1) - MAKE_CASE(AArch64ISD::SMAX_MERGE_OP1) - MAKE_CASE(AArch64ISD::UMAX_MERGE_OP1) - MAKE_CASE(AArch64ISD::SHL_MERGE_OP1) - MAKE_CASE(AArch64ISD::SRL_MERGE_OP1) - MAKE_CASE(AArch64ISD::SRA_MERGE_OP1) + MAKE_CASE(AArch64ISD::UMAX_PRED) + MAKE_CASE(AArch64ISD::UMIN_PRED) + MAKE_CASE(AArch64ISD::FNEG_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FCEIL_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FFLOOR_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FNEARBYINT_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FRINT_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FROUND_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FTRUNC_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FP_ROUND_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FCVTZU_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FCVTZS_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FSQRT_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FRECPX_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::FABS_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::ABS_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::NEG_MERGE_PASSTHRU) MAKE_CASE(AArch64ISD::SETCC_MERGE_ZERO) MAKE_CASE(AArch64ISD::ADC) MAKE_CASE(AArch64ISD::SBC) @@ -1462,10 +1836,14 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::UADDV) MAKE_CASE(AArch64ISD::SRHADD) MAKE_CASE(AArch64ISD::URHADD) + MAKE_CASE(AArch64ISD::SHADD) + MAKE_CASE(AArch64ISD::UHADD) MAKE_CASE(AArch64ISD::SMINV) MAKE_CASE(AArch64ISD::UMINV) MAKE_CASE(AArch64ISD::SMAXV) MAKE_CASE(AArch64ISD::UMAXV) + MAKE_CASE(AArch64ISD::SADDV_PRED) + MAKE_CASE(AArch64ISD::UADDV_PRED) MAKE_CASE(AArch64ISD::SMAXV_PRED) MAKE_CASE(AArch64ISD::UMAXV_PRED) MAKE_CASE(AArch64ISD::SMINV_PRED) @@ -1483,12 +1861,16 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::FADD_PRED) MAKE_CASE(AArch64ISD::FADDA_PRED) MAKE_CASE(AArch64ISD::FADDV_PRED) + MAKE_CASE(AArch64ISD::FDIV_PRED) MAKE_CASE(AArch64ISD::FMA_PRED) MAKE_CASE(AArch64ISD::FMAXV_PRED) + MAKE_CASE(AArch64ISD::FMAXNM_PRED) MAKE_CASE(AArch64ISD::FMAXNMV_PRED) MAKE_CASE(AArch64ISD::FMINV_PRED) + MAKE_CASE(AArch64ISD::FMINNM_PRED) MAKE_CASE(AArch64ISD::FMINNMV_PRED) - MAKE_CASE(AArch64ISD::NOT) + MAKE_CASE(AArch64ISD::FMUL_PRED) + MAKE_CASE(AArch64ISD::FSUB_PRED) MAKE_CASE(AArch64ISD::BIT) MAKE_CASE(AArch64ISD::CBZ) MAKE_CASE(AArch64ISD::CBNZ) @@ -1600,8 +1982,15 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::LDP) MAKE_CASE(AArch64ISD::STP) MAKE_CASE(AArch64ISD::STNP) + MAKE_CASE(AArch64ISD::BITREVERSE_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::BSWAP_MERGE_PASSTHRU) + 
MAKE_CASE(AArch64ISD::CTLZ_MERGE_PASSTHRU) + MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU) MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU) MAKE_CASE(AArch64ISD::INDEX_VECTOR) + MAKE_CASE(AArch64ISD::UABD) + MAKE_CASE(AArch64ISD::SABD) + MAKE_CASE(AArch64ISD::CALL_RVMARKER) } #undef MAKE_CASE return nullptr; @@ -1689,6 +2078,7 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( case TargetOpcode::STACKMAP: case TargetOpcode::PATCHPOINT: + case TargetOpcode::STATEPOINT: return emitPatchPoint(MI, BB); case AArch64::CATCHRET: @@ -2514,21 +2904,10 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) { return std::make_pair(Value, Overflow); } -SDValue AArch64TargetLowering::LowerF128Call(SDValue Op, SelectionDAG &DAG, - RTLIB::Libcall Call) const { - bool IsStrict = Op->isStrictFPOpcode(); - unsigned Offset = IsStrict ? 1 : 0; - SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue(); - SmallVector<SDValue, 2> Ops(Op->op_begin() + Offset, Op->op_end()); - MakeLibCallOptions CallOptions; - SDValue Result; - SDLoc dl(Op); - std::tie(Result, Chain) = makeLibCall(DAG, Call, Op.getValueType(), Ops, - CallOptions, dl, Chain); - return IsStrict ? DAG.getMergeValues({Result, Chain}, dl) : Result; -} +SDValue AArch64TargetLowering::LowerXOR(SDValue Op, SelectionDAG &DAG) const { + if (useSVEForFixedLengthVectorVT(Op.getValueType())) + return LowerToScalableOp(Op, DAG); -static SDValue LowerXOR(SDValue Op, SelectionDAG &DAG) { SDValue Sel = Op.getOperand(0); SDValue Other = Op.getOperand(1); SDLoc dl(Sel); @@ -2703,16 +3082,18 @@ static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) { SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const { - assert(Op.getValueType() == MVT::f128 && "Unexpected lowering"); - - RTLIB::Libcall LC; - LC = RTLIB::getFPEXT(Op.getOperand(0).getValueType(), Op.getValueType()); + if (Op.getValueType().isScalableVector()) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU); - return LowerF128Call(Op, DAG, LC); + assert(Op.getValueType() == MVT::f128 && "Unexpected lowering"); + return SDValue(); } SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const { + if (Op.getValueType().isScalableVector()) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU); + bool IsStrict = Op->isStrictFPOpcode(); SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0); EVT SrcVT = SrcVal.getValueType(); @@ -2726,19 +3107,7 @@ SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op, return Op; } - RTLIB::Libcall LC; - LC = RTLIB::getFPROUND(SrcVT, Op.getValueType()); - - // FP_ROUND node has a second operand indicating whether it is known to be - // precise. That doesn't take part in the LibCall so we can't directly use - // LowerF128Call. - MakeLibCallOptions CallOptions; - SDValue Chain = IsStrict ? Op.getOperand(0) : SDValue(); - SDValue Result; - SDLoc dl(Op); - std::tie(Result, Chain) = makeLibCall(DAG, LC, Op.getValueType(), SrcVal, - CallOptions, dl, Chain); - return IsStrict ? DAG.getMergeValues({Result, Chain}, dl) : Result; + return SDValue(); } SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op, @@ -2748,6 +3117,14 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op, // in the cost tables. EVT InVT = Op.getOperand(0).getValueType(); EVT VT = Op.getValueType(); + + if (VT.isScalableVector()) { + unsigned Opcode = Op.getOpcode() == ISD::FP_TO_UINT + ? 
AArch64ISD::FCVTZU_MERGE_PASSTHRU + : AArch64ISD::FCVTZS_MERGE_PASSTHRU; + return LowerToPredicatedOp(Op, DAG, Opcode); + } + unsigned NumElts = InVT.getVectorNumElements(); // f16 conversions are promoted to f32 when full fp16 is not supported. @@ -2760,7 +3137,9 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op, DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0))); } - if (VT.getSizeInBits() < InVT.getSizeInBits()) { + uint64_t VTSize = VT.getFixedSizeInBits(); + uint64_t InVTSize = InVT.getFixedSizeInBits(); + if (VTSize < InVTSize) { SDLoc dl(Op); SDValue Cv = DAG.getNode(Op.getOpcode(), dl, InVT.changeVectorElementTypeToInteger(), @@ -2768,7 +3147,7 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op, return DAG.getNode(ISD::TRUNCATE, dl, VT, Cv); } - if (VT.getSizeInBits() > InVT.getSizeInBits()) { + if (VTSize > InVTSize) { SDLoc dl(Op); MVT ExtVT = MVT::getVectorVT(MVT::getFloatingPointVT(VT.getScalarSizeInBits()), @@ -2803,17 +3182,11 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op, return Op; } - RTLIB::Libcall LC; - if (Op.getOpcode() == ISD::FP_TO_SINT || - Op.getOpcode() == ISD::STRICT_FP_TO_SINT) - LC = RTLIB::getFPTOSINT(SrcVal.getValueType(), Op.getValueType()); - else - LC = RTLIB::getFPTOUINT(SrcVal.getValueType(), Op.getValueType()); - - return LowerF128Call(Op, DAG, LC); + return SDValue(); } -static SDValue LowerVectorINT_TO_FP(SDValue Op, SelectionDAG &DAG) { +SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op, + SelectionDAG &DAG) const { // Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp. // Any additional optimization in this function should be recorded // in the cost tables. @@ -2821,21 +3194,38 @@ static SDValue LowerVectorINT_TO_FP(SDValue Op, SelectionDAG &DAG) { SDLoc dl(Op); SDValue In = Op.getOperand(0); EVT InVT = In.getValueType(); + unsigned Opc = Op.getOpcode(); + bool IsSigned = Opc == ISD::SINT_TO_FP || Opc == ISD::STRICT_SINT_TO_FP; + + if (VT.isScalableVector()) { + if (InVT.getVectorElementType() == MVT::i1) { + // We can't directly extend an SVE predicate; extend it first. + unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + EVT CastVT = getPromotedVTForPredicate(InVT); + In = DAG.getNode(CastOpc, dl, CastVT, In); + return DAG.getNode(Opc, dl, VT, In); + } + + unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU + : AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU; + return LowerToPredicatedOp(Op, DAG, Opcode); + } - if (VT.getSizeInBits() < InVT.getSizeInBits()) { + uint64_t VTSize = VT.getFixedSizeInBits(); + uint64_t InVTSize = InVT.getFixedSizeInBits(); + if (VTSize < InVTSize) { MVT CastVT = MVT::getVectorVT(MVT::getFloatingPointVT(InVT.getScalarSizeInBits()), InVT.getVectorNumElements()); - In = DAG.getNode(Op.getOpcode(), dl, CastVT, In); + In = DAG.getNode(Opc, dl, CastVT, In); return DAG.getNode(ISD::FP_ROUND, dl, VT, In, DAG.getIntPtrConstant(0, dl)); } - if (VT.getSizeInBits() > InVT.getSizeInBits()) { - unsigned CastOpc = - Op.getOpcode() == ISD::SINT_TO_FP ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + if (VTSize > InVTSize) { + unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; EVT CastVT = VT.changeVectorElementTypeToInteger(); In = DAG.getNode(CastOpc, dl, CastVT, In); - return DAG.getNode(Op.getOpcode(), dl, VT, In); + return DAG.getNode(Opc, dl, VT, In); } return Op; @@ -2868,15 +3258,7 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op, // fp128. 
if (Op.getValueType() != MVT::f128) return Op; - - RTLIB::Libcall LC; - if (Op.getOpcode() == ISD::SINT_TO_FP || - Op.getOpcode() == ISD::STRICT_SINT_TO_FP) - LC = RTLIB::getSINTTOFP(SrcVal.getValueType(), Op.getValueType()); - else - LC = RTLIB::getUINTTOFP(SrcVal.getValueType(), Op.getValueType()); - - return LowerF128Call(Op, DAG, LC); + return SDValue(); } SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op, @@ -2990,7 +3372,8 @@ static bool isExtendedBUILD_VECTOR(SDNode *N, SelectionDAG &DAG, } static SDValue skipExtensionForVectorMULL(SDNode *N, SelectionDAG &DAG) { - if (N->getOpcode() == ISD::SIGN_EXTEND || N->getOpcode() == ISD::ZERO_EXTEND) + if (N->getOpcode() == ISD::SIGN_EXTEND || + N->getOpcode() == ISD::ZERO_EXTEND || N->getOpcode() == ISD::ANY_EXTEND) return addRequiredExtensionForVectorMULL(N->getOperand(0), DAG, N->getOperand(0)->getValueType(0), N->getValueType(0), @@ -3015,11 +3398,13 @@ static SDValue skipExtensionForVectorMULL(SDNode *N, SelectionDAG &DAG) { static bool isSignExtended(SDNode *N, SelectionDAG &DAG) { return N->getOpcode() == ISD::SIGN_EXTEND || + N->getOpcode() == ISD::ANY_EXTEND || isExtendedBUILD_VECTOR(N, DAG, true); } static bool isZeroExtended(SDNode *N, SelectionDAG &DAG) { return N->getOpcode() == ISD::ZERO_EXTEND || + N->getOpcode() == ISD::ANY_EXTEND || isExtendedBUILD_VECTOR(N, DAG, false); } @@ -3068,10 +3453,17 @@ SDValue AArch64TargetLowering::LowerFLT_ROUNDS_(SDValue Op, return DAG.getMergeValues({AND, Chain}, dl); } -static SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) { +SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + + // If SVE is available then i64 vector multiplications can also be made legal. + bool OverrideNEON = VT == MVT::v2i64 || VT == MVT::v1i64; + + if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT, OverrideNEON)) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED, OverrideNEON); + // Multiplications are only custom-lowered for 128-bit vectors so that // VMULL can be detected. Otherwise v2i64 multiplications are not legal. - EVT VT = Op.getValueType(); assert(VT.is128BitVector() && VT.isInteger() && "unexpected type for custom-lowering ISD::MUL"); SDNode *N0 = Op.getOperand(0).getNode(); @@ -3230,11 +3622,77 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::aarch64_sve_ptrue: return DAG.getNode(AArch64ISD::PTRUE, dl, Op.getValueType(), Op.getOperand(1)); + case Intrinsic::aarch64_sve_clz: + return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_cnt: { + SDValue Data = Op.getOperand(3); + // CTPOP only supports integer operands. 
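// Annotation (not part of this commit): the bitcast below is fine because
// packed SVE vectors of equal size alias bit-for-bit, so the population
// count is taken over the value's bit pattern, e.g.
//   nxv4f32 data  --BITCAST-->  nxv4i32  --CTPOP_MERGE_PASSTHRU-->  nxv4i32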
+ if (Data.getValueType().isFloatingPoint()) + Data = DAG.getNode(ISD::BITCAST, dl, Op.getValueType(), Data); + return DAG.getNode(AArch64ISD::CTPOP_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Data, Op.getOperand(1)); + } case Intrinsic::aarch64_sve_dupq_lane: return LowerDUPQLane(Op, DAG); case Intrinsic::aarch64_sve_convert_from_svbool: return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(), Op.getOperand(1)); + case Intrinsic::aarch64_sve_fneg: + return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_frintp: + return DAG.getNode(AArch64ISD::FCEIL_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_frintm: + return DAG.getNode(AArch64ISD::FFLOOR_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_frinti: + return DAG.getNode(AArch64ISD::FNEARBYINT_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_frintx: + return DAG.getNode(AArch64ISD::FRINT_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_frinta: + return DAG.getNode(AArch64ISD::FROUND_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_frintn: + return DAG.getNode(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_frintz: + return DAG.getNode(AArch64ISD::FTRUNC_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_ucvtf: + return DAG.getNode(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU, dl, + Op.getValueType(), Op.getOperand(2), Op.getOperand(3), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_scvtf: + return DAG.getNode(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU, dl, + Op.getValueType(), Op.getOperand(2), Op.getOperand(3), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_fcvtzu: + return DAG.getNode(AArch64ISD::FCVTZU_MERGE_PASSTHRU, dl, + Op.getValueType(), Op.getOperand(2), Op.getOperand(3), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_fcvtzs: + return DAG.getNode(AArch64ISD::FCVTZS_MERGE_PASSTHRU, dl, + Op.getValueType(), Op.getOperand(2), Op.getOperand(3), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_fsqrt: + return DAG.getNode(AArch64ISD::FSQRT_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_frecpx: + return DAG.getNode(AArch64ISD::FRECPX_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_fabs: + return DAG.getNode(AArch64ISD::FABS_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_abs: + return DAG.getNode(AArch64ISD::ABS_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_neg: + return DAG.getNode(AArch64ISD::NEG_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); case Intrinsic::aarch64_sve_convert_to_svbool: { EVT OutVT = Op.getValueType(); EVT InVT = Op.getOperand(1).getValueType(); @@ -3260,6 +3718,49 @@ 
SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, return DAG.getNode(AArch64ISD::INSR, dl, Op.getValueType(), Op.getOperand(1), Scalar); } + case Intrinsic::aarch64_sve_rbit: + return DAG.getNode(AArch64ISD::BITREVERSE_MERGE_PASSTHRU, dl, + Op.getValueType(), Op.getOperand(2), Op.getOperand(3), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_revb: + return DAG.getNode(AArch64ISD::BSWAP_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); + case Intrinsic::aarch64_sve_sxtb: + return DAG.getNode( + AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), + DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i8)), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_sxth: + return DAG.getNode( + AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), + DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i16)), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_sxtw: + return DAG.getNode( + AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), + DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_uxtb: + return DAG.getNode( + AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), + DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i8)), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_uxth: + return DAG.getNode( + AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), + DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i16)), + Op.getOperand(1)); + case Intrinsic::aarch64_sve_uxtw: + return DAG.getNode( + AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(), + Op.getOperand(2), Op.getOperand(3), + DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)), + Op.getOperand(1)); case Intrinsic::localaddress: { const auto &MF = DAG.getMachineFunction(); @@ -3299,19 +3800,291 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, } case Intrinsic::aarch64_neon_srhadd: - case Intrinsic::aarch64_neon_urhadd: { - bool IsSignedAdd = IntNo == Intrinsic::aarch64_neon_srhadd; - unsigned Opcode = IsSignedAdd ? AArch64ISD::SRHADD : AArch64ISD::URHADD; + case Intrinsic::aarch64_neon_urhadd: + case Intrinsic::aarch64_neon_shadd: + case Intrinsic::aarch64_neon_uhadd: { + bool IsSignedAdd = (IntNo == Intrinsic::aarch64_neon_srhadd || + IntNo == Intrinsic::aarch64_neon_shadd); + bool IsRoundingAdd = (IntNo == Intrinsic::aarch64_neon_srhadd || + IntNo == Intrinsic::aarch64_neon_urhadd); + unsigned Opcode = + IsSignedAdd ? (IsRoundingAdd ? AArch64ISD::SRHADD : AArch64ISD::SHADD) + : (IsRoundingAdd ? 
AArch64ISD::URHADD : AArch64ISD::UHADD); return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); } + + case Intrinsic::aarch64_neon_uabd: { + return DAG.getNode(AArch64ISD::UABD, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + } + case Intrinsic::aarch64_neon_sabd: { + return DAG.getNode(AArch64ISD::SABD, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2)); + } } } +bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const { + if (VT.getVectorElementType() == MVT::i32 && + VT.getVectorElementCount().getKnownMinValue() >= 4) + return true; + + return false; +} + bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const { return ExtVal.getValueType().isScalableVector(); } +unsigned getGatherVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) { + std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = { + {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false), + AArch64ISD::GLD1_MERGE_ZERO}, + {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true), + AArch64ISD::GLD1_UXTW_MERGE_ZERO}, + {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false), + AArch64ISD::GLD1_MERGE_ZERO}, + {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true), + AArch64ISD::GLD1_SXTW_MERGE_ZERO}, + {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false), + AArch64ISD::GLD1_SCALED_MERGE_ZERO}, + {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true), + AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO}, + {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false), + AArch64ISD::GLD1_SCALED_MERGE_ZERO}, + {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true), + AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO}, + }; + auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend); + return AddrModes.find(Key)->second; +} + +unsigned getScatterVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) { + std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = { + {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false), + AArch64ISD::SST1_PRED}, + {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true), + AArch64ISD::SST1_UXTW_PRED}, + {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false), + AArch64ISD::SST1_PRED}, + {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true), + AArch64ISD::SST1_SXTW_PRED}, + {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false), + AArch64ISD::SST1_SCALED_PRED}, + {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true), + AArch64ISD::SST1_UXTW_SCALED_PRED}, + {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false), + AArch64ISD::SST1_SCALED_PRED}, + {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true), + AArch64ISD::SST1_SXTW_SCALED_PRED}, + }; + auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend); + return AddrModes.find(Key)->second; +} + +unsigned getSignExtendedGatherOpcode(unsigned Opcode) { + switch (Opcode) { + default: + llvm_unreachable("unimplemented opcode"); + return Opcode; + case AArch64ISD::GLD1_MERGE_ZERO: + return AArch64ISD::GLD1S_MERGE_ZERO; + case AArch64ISD::GLD1_IMM_MERGE_ZERO: + return AArch64ISD::GLD1S_IMM_MERGE_ZERO; + case AArch64ISD::GLD1_UXTW_MERGE_ZERO: + return AArch64ISD::GLD1S_UXTW_MERGE_ZERO; + case AArch64ISD::GLD1_SXTW_MERGE_ZERO: + return AArch64ISD::GLD1S_SXTW_MERGE_ZERO; + case AArch64ISD::GLD1_SCALED_MERGE_ZERO: + return 
AArch64ISD::GLD1S_SCALED_MERGE_ZERO; + case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO: + return AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO; + case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO: + return AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO; + } +} + +bool getGatherScatterIndexIsExtended(SDValue Index) { + unsigned Opcode = Index.getOpcode(); + if (Opcode == ISD::SIGN_EXTEND_INREG) + return true; + + if (Opcode == ISD::AND) { + SDValue Splat = Index.getOperand(1); + if (Splat.getOpcode() != ISD::SPLAT_VECTOR) + return false; + ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(Splat.getOperand(0)); + if (!Mask || Mask->getZExtValue() != 0xFFFFFFFF) + return false; + return true; + } + + return false; +} + +// If the base pointer of a masked gather or scatter is null, we +// may be able to swap BasePtr & Index and use the vector + register +// or vector + immediate addressing mode, e.g. +// VECTOR + REGISTER: +// getelementptr nullptr, <vscale x N x T> (splat(%offset)) + %indices) +// -> getelementptr %offset, <vscale x N x T> %indices +// VECTOR + IMMEDIATE: +// getelementptr nullptr, <vscale x N x T> (splat(#x)) + %indices) +// -> getelementptr #x, <vscale x N x T> %indices +void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, EVT MemVT, + unsigned &Opcode, bool IsGather, + SelectionDAG &DAG) { + if (!isNullConstant(BasePtr)) + return; + + ConstantSDNode *Offset = nullptr; + if (Index.getOpcode() == ISD::ADD) + if (auto SplatVal = DAG.getSplatValue(Index.getOperand(1))) { + if (isa<ConstantSDNode>(SplatVal)) + Offset = cast<ConstantSDNode>(SplatVal); + else { + BasePtr = SplatVal; + Index = Index->getOperand(0); + return; + } + } + + unsigned NewOp = + IsGather ? AArch64ISD::GLD1_IMM_MERGE_ZERO : AArch64ISD::SST1_IMM_PRED; + + if (!Offset) { + std::swap(BasePtr, Index); + Opcode = NewOp; + return; + } + + uint64_t OffsetVal = Offset->getZExtValue(); + unsigned ScalarSizeInBytes = MemVT.getScalarSizeInBits() / 8; + auto ConstOffset = DAG.getConstant(OffsetVal, SDLoc(Index), MVT::i64); + + if (OffsetVal % ScalarSizeInBytes || OffsetVal / ScalarSizeInBytes > 31) { + // Index is out of range for the immediate addressing mode + BasePtr = ConstOffset; + Index = Index->getOperand(0); + return; + } + + // Immediate is in range + Opcode = NewOp; + BasePtr = Index->getOperand(0); + Index = ConstOffset; +} + +SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op); + assert(MGT && "Can only custom lower gather load nodes"); + + SDValue Index = MGT->getIndex(); + SDValue Chain = MGT->getChain(); + SDValue PassThru = MGT->getPassThru(); + SDValue Mask = MGT->getMask(); + SDValue BasePtr = MGT->getBasePtr(); + ISD::LoadExtType ExtTy = MGT->getExtensionType(); + + ISD::MemIndexType IndexType = MGT->getIndexType(); + bool IsScaled = + IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED; + bool IsSigned = + IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED; + bool IdxNeedsExtend = + getGatherScatterIndexIsExtended(Index) || + Index.getSimpleValueType().getVectorElementType() == MVT::i32; + bool ResNeedsSignExtend = ExtTy == ISD::EXTLOAD || ExtTy == ISD::SEXTLOAD; + + EVT VT = PassThru.getSimpleValueType(); + EVT MemVT = MGT->getMemoryVT(); + SDValue InputVT = DAG.getValueType(MemVT); + + if (VT.getVectorElementType() == MVT::bf16 && + !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16()) + return SDValue(); + + // Handle FP data by using an integer 
gather and casting the result. + if (VT.isFloatingPoint()) { + EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount()); + PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG); + InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger()); + } + + SDVTList VTs = DAG.getVTList(PassThru.getSimpleValueType(), MVT::Other); + + if (getGatherScatterIndexIsExtended(Index)) + Index = Index.getOperand(0); + + unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend); + selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode, + /*isGather=*/true, DAG); + + if (ResNeedsSignExtend) + Opcode = getSignExtendedGatherOpcode(Opcode); + + SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT, PassThru}; + SDValue Gather = DAG.getNode(Opcode, DL, VTs, Ops); + + if (VT.isFloatingPoint()) { + SDValue Cast = getSVESafeBitCast(VT, Gather, DAG); + return DAG.getMergeValues({Cast, Gather}, DL); + } + + return Gather; +} + +SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Op); + assert(MSC && "Can only custom lower scatter store nodes"); + + SDValue Index = MSC->getIndex(); + SDValue Chain = MSC->getChain(); + SDValue StoreVal = MSC->getValue(); + SDValue Mask = MSC->getMask(); + SDValue BasePtr = MSC->getBasePtr(); + + ISD::MemIndexType IndexType = MSC->getIndexType(); + bool IsScaled = + IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED; + bool IsSigned = + IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED; + bool NeedsExtend = + getGatherScatterIndexIsExtended(Index) || + Index.getSimpleValueType().getVectorElementType() == MVT::i32; + + EVT VT = StoreVal.getSimpleValueType(); + SDVTList VTs = DAG.getVTList(MVT::Other); + EVT MemVT = MSC->getMemoryVT(); + SDValue InputVT = DAG.getValueType(MemVT); + + if (VT.getVectorElementType() == MVT::bf16 && + !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16()) + return SDValue(); + + // Handle FP data by casting the data so an integer scatter can be used. + if (VT.isFloatingPoint()) { + EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount()); + StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG); + InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger()); + } + + if (getGatherScatterIndexIsExtended(Index)) + Index = Index.getOperand(0); + + unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend); + selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode, + /*isGather=*/false, DAG); + + SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT}; + return DAG.getNode(Opcode, DL, VTs, Ops); +} + // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16. static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST, EVT VT, EVT MemVT, @@ -3377,8 +4150,9 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op, // 256 bit non-temporal stores can be lowered to STNP. Do this as part of // the custom lowering, as there are no un-paired non-temporal stores and // legalization will break up 256 bit inputs. 
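// Annotation (not part of this commit): the code below splits the 256-bit
// value into two 128-bit halves so a single STNP stores the pair, e.g. for
// a non-temporal store of v8f32:
//   Lo = extract_subvector V, 0   ; v4f32 -> q0
//   Hi = extract_subvector V, 4   ; v4f32 -> q1
//   AArch64ISD::STNP Chain, Lo, Hi, Ptr   ->   stnp q0, q1, [x0]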
+ ElementCount EC = MemVT.getVectorElementCount(); if (StoreNode->isNonTemporal() && MemVT.getSizeInBits() == 256u && - MemVT.getVectorElementCount().Min % 2u == 0 && + EC.isKnownEven() && ((MemVT.getScalarSizeInBits() == 8u || MemVT.getScalarSizeInBits() == 16u || MemVT.getScalarSizeInBits() == 32u || @@ -3387,11 +4161,11 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op, DAG.getNode(ISD::EXTRACT_SUBVECTOR, Dl, MemVT.getHalfNumVectorElementsVT(*DAG.getContext()), StoreNode->getValue(), DAG.getConstant(0, Dl, MVT::i64)); - SDValue Hi = DAG.getNode( - ISD::EXTRACT_SUBVECTOR, Dl, - MemVT.getHalfNumVectorElementsVT(*DAG.getContext()), - StoreNode->getValue(), - DAG.getConstant(MemVT.getVectorElementCount().Min / 2, Dl, MVT::i64)); + SDValue Hi = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, Dl, + MemVT.getHalfNumVectorElementsVT(*DAG.getContext()), + StoreNode->getValue(), + DAG.getConstant(EC.getKnownMinValue() / 2, Dl, MVT::i64)); SDValue Result = DAG.getMemIntrinsicNode( AArch64ISD::STNP, Dl, DAG.getVTList(MVT::Other), {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()}, @@ -3416,6 +4190,25 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op, return SDValue(); } +// Generate SUBS and CSEL for integer abs. +SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const { + MVT VT = Op.getSimpleValueType(); + + if (VT.isVector()) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABS_MERGE_PASSTHRU); + + SDLoc DL(Op); + SDValue Neg = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), + Op.getOperand(0)); + // Generate SUBS & CSEL. + SDValue Cmp = + DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::i32), + Op.getOperand(0), DAG.getConstant(0, DL, VT)); + return DAG.getNode(AArch64ISD::CSEL, DL, VT, Op.getOperand(0), Neg, + DAG.getConstant(AArch64CC::PL, DL, MVT::i32), + Cmp.getValue(1)); +} + SDValue AArch64TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { LLVM_DEBUG(dbgs() << "Custom lowering: "); @@ -3468,17 +4261,35 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::UMULO: return LowerXALUO(Op, DAG); case ISD::FADD: - if (useSVEForFixedLengthVectorVT(Op.getValueType())) - return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED); - return LowerF128Call(Op, DAG, RTLIB::ADD_F128); + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED); case ISD::FSUB: - return LowerF128Call(Op, DAG, RTLIB::SUB_F128); + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED); case ISD::FMUL: - return LowerF128Call(Op, DAG, RTLIB::MUL_F128); + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); case ISD::FMA: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED); case ISD::FDIV: - return LowerF128Call(Op, DAG, RTLIB::DIV_F128); + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED); + case ISD::FNEG: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU); + case ISD::FCEIL: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FCEIL_MERGE_PASSTHRU); + case ISD::FFLOOR: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FFLOOR_MERGE_PASSTHRU); + case ISD::FNEARBYINT: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEARBYINT_MERGE_PASSTHRU); + case ISD::FRINT: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FRINT_MERGE_PASSTHRU); + case ISD::FROUND: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUND_MERGE_PASSTHRU); + case ISD::FROUNDEVEN: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU); + case ISD::FTRUNC: + return LowerToPredicatedOp(Op, DAG, 
AArch64ISD::FTRUNC_MERGE_PASSTHRU); + case ISD::FSQRT: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU); + case ISD::FABS: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FABS_MERGE_PASSTHRU); case ISD::FP_ROUND: case ISD::STRICT_FP_ROUND: return LowerFP_ROUND(Op, DAG); @@ -3492,6 +4303,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerRETURNADDR(Op, DAG); case ISD::ADDROFRETURNADDR: return LowerADDROFRETURNADDR(Op, DAG); + case ISD::CONCAT_VECTORS: + return LowerCONCAT_VECTORS(Op, DAG); case ISD::INSERT_VECTOR_ELT: return LowerINSERT_VECTOR_ELT(Op, DAG); case ISD::EXTRACT_VECTOR_ELT: @@ -3507,17 +4320,20 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::INSERT_SUBVECTOR: return LowerINSERT_SUBVECTOR(Op, DAG); case ISD::SDIV: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::SDIV_PRED); case ISD::UDIV: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::UDIV_PRED); + return LowerDIV(Op, DAG); case ISD::SMIN: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMIN_MERGE_OP1); + return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMIN_PRED, + /*OverrideNEON=*/true); case ISD::UMIN: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMIN_MERGE_OP1); + return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMIN_PRED, + /*OverrideNEON=*/true); case ISD::SMAX: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMAX_MERGE_OP1); + return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMAX_PRED, + /*OverrideNEON=*/true); case ISD::UMAX: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMAX_MERGE_OP1); + return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMAX_PRED, + /*OverrideNEON=*/true); case ISD::SRA: case ISD::SRL: case ISD::SHL: @@ -3557,11 +4373,21 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerINTRINSIC_WO_CHAIN(Op, DAG); case ISD::STORE: return LowerSTORE(Op, DAG); + case ISD::MGATHER: + return LowerMGATHER(Op, DAG); + case ISD::MSCATTER: + return LowerMSCATTER(Op, DAG); + case ISD::VECREDUCE_SEQ_FADD: + return LowerVECREDUCE_SEQ_FADD(Op, DAG); case ISD::VECREDUCE_ADD: + case ISD::VECREDUCE_AND: + case ISD::VECREDUCE_OR: + case ISD::VECREDUCE_XOR: case ISD::VECREDUCE_SMAX: case ISD::VECREDUCE_SMIN: case ISD::VECREDUCE_UMAX: case ISD::VECREDUCE_UMIN: + case ISD::VECREDUCE_FADD: case ISD::VECREDUCE_FMAX: case ISD::VECREDUCE_FMIN: return LowerVECREDUCE(Op, DAG); @@ -3573,6 +4399,21 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerDYNAMIC_STACKALLOC(Op, DAG); case ISD::VSCALE: return LowerVSCALE(Op, DAG); + case ISD::ANY_EXTEND: + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + return LowerFixedLengthVectorIntExtendToSVE(Op, DAG); + case ISD::SIGN_EXTEND_INREG: { + // Only custom lower when ExtraVT has a legal byte based element type. 
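// For example (illustrative types): (sign_extend_inreg nxv4i32 %x, nxv4i8)
// has an i8-based ExtraVT and maps onto the predicated extend node
// (ultimately an SXTB-style extend on the 32-bit lanes), whereas a non-byte
// element type such as i3 returns SDValue() here and is left to the generic
// shift-based expansion.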
+ EVT ExtraVT = cast<VTSDNode>(Op.getOperand(1))->getVT(); + EVT ExtraEltVT = ExtraVT.getVectorElementType(); + if ((ExtraEltVT != MVT::i8) && (ExtraEltVT != MVT::i16) && + (ExtraEltVT != MVT::i32) && (ExtraEltVT != MVT::i64)) + return SDValue(); + + return LowerToPredicatedOp(Op, DAG, + AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU); + } case ISD::TRUNCATE: return LowerTRUNCATE(Op, DAG); case ISD::LOAD: @@ -3580,31 +4421,49 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerFixedLengthVectorLoadToSVE(Op, DAG); llvm_unreachable("Unexpected request to lower ISD::LOAD"); case ISD::ADD: - if (useSVEForFixedLengthVectorVT(Op.getValueType())) - return LowerToPredicatedOp(Op, DAG, AArch64ISD::ADD_PRED); - llvm_unreachable("Unexpected request to lower ISD::ADD"); + return LowerToPredicatedOp(Op, DAG, AArch64ISD::ADD_PRED); + case ISD::AND: + return LowerToScalableOp(Op, DAG); + case ISD::SUB: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::SUB_PRED); + case ISD::FMAXNUM: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED); + case ISD::FMINNUM: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED); + case ISD::VSELECT: + return LowerFixedLengthVectorSelectToSVE(Op, DAG); + case ISD::ABS: + return LowerABS(Op, DAG); + case ISD::BITREVERSE: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::BITREVERSE_MERGE_PASSTHRU, + /*OverrideNEON=*/true); + case ISD::BSWAP: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::BSWAP_MERGE_PASSTHRU); + case ISD::CTLZ: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::CTLZ_MERGE_PASSTHRU, + /*OverrideNEON=*/true); + case ISD::CTTZ: + return LowerCTTZ(Op, DAG); } } -bool AArch64TargetLowering::useSVEForFixedLengthVectors() const { - // Prefer NEON unless larger SVE registers are available. - return Subtarget->hasSVE() && Subtarget->getMinSVEVectorSizeInBits() >= 256; +bool AArch64TargetLowering::mergeStoresAfterLegalization(EVT VT) const { + return !Subtarget->useSVEForFixedLengthVectors(); } -bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(EVT VT) const { - if (!useSVEForFixedLengthVectors()) +bool AArch64TargetLowering::useSVEForFixedLengthVectorVT( + EVT VT, bool OverrideNEON) const { + if (!Subtarget->useSVEForFixedLengthVectors()) return false; if (!VT.isFixedLengthVector()) return false; - // Fixed length predicates should be promoted to i8. - // NOTE: This is consistent with how NEON (and thus 64/128bit vectors) work. - if (VT.getVectorElementType() == MVT::i1) - return false; - // Don't use SVE for vectors we cannot scalarize if required. switch (VT.getVectorElementType().getSimpleVT().SimpleTy) { + // Fixed length predicates should be promoted to i8. + // NOTE: This is consistent with how NEON (and thus 64/128bit vectors) work. + case MVT::i1: default: return false; case MVT::i8: @@ -3617,12 +4476,16 @@ bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(EVT VT) const { break; } + // All SVE implementations support NEON sized vectors. + if (OverrideNEON && (VT.is128BitVector() || VT.is64BitVector())) + return true; + // Ensure NEON MVTs only belong to a single register class. - if (VT.getSizeInBits() <= 128) + if (VT.getFixedSizeInBits() <= 128) return false; // Don't use SVE for types that don't fit. 
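// For example, assuming a 256-bit minimum SVE register size: v8i32 occupies
// exactly 256 bits and passes this size check, while v16i32 (512 bits) is
// rejected here and handled by generic type legalization instead.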
- if (VT.getSizeInBits() > Subtarget->getMinSVEVectorSizeInBits()) + if (VT.getFixedSizeInBits() > Subtarget->getMinSVEVectorSizeInBits()) return false; // TODO: Perhaps an artificial restriction, but worth having whilst getting @@ -3721,10 +4584,10 @@ SDValue AArch64TargetLowering::LowerFormalArguments( assert(!Res && "Call operand has unhandled type"); (void)Res; } - assert(ArgLocs.size() == Ins.size()); SmallVector<SDValue, 16> ArgValues; - for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { - CCValAssign &VA = ArgLocs[i]; + unsigned ExtraArgLocs = 0; + for (unsigned i = 0, e = Ins.size(); i != e; ++i) { + CCValAssign &VA = ArgLocs[i - ExtraArgLocs]; if (Ins[i].Flags.isByVal()) { // Byval is used for HFAs in the PCS, but the system should work in a @@ -3852,16 +4715,44 @@ SDValue AArch64TargetLowering::LowerFormalArguments( if (VA.getLocInfo() == CCValAssign::Indirect) { assert(VA.getValVT().isScalableVector() && "Only scalable vectors can be passed indirectly"); - // If value is passed via pointer - do a load. - ArgValue = - DAG.getLoad(VA.getValVT(), DL, Chain, ArgValue, MachinePointerInfo()); - } - if (Subtarget->isTargetILP32() && Ins[i].Flags.isPointer()) - ArgValue = DAG.getNode(ISD::AssertZext, DL, ArgValue.getValueType(), - ArgValue, DAG.getValueType(MVT::i32)); - InVals.push_back(ArgValue); + uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinSize(); + unsigned NumParts = 1; + if (Ins[i].Flags.isInConsecutiveRegs()) { + assert(!Ins[i].Flags.isInConsecutiveRegsLast()); + while (!Ins[i + NumParts - 1].Flags.isInConsecutiveRegsLast()) + ++NumParts; + } + + MVT PartLoad = VA.getValVT(); + SDValue Ptr = ArgValue; + + // Ensure we generate all loads for each tuple part, whilst updating the + // pointer after each load correctly using vscale. + while (NumParts > 0) { + ArgValue = DAG.getLoad(PartLoad, DL, Chain, Ptr, MachinePointerInfo()); + InVals.push_back(ArgValue); + NumParts--; + if (NumParts > 0) { + SDValue BytesIncrement = DAG.getVScale( + DL, Ptr.getValueType(), + APInt(Ptr.getValueSizeInBits().getFixedSize(), PartSize)); + SDNodeFlags Flags; + Flags.setNoUnsignedWrap(true); + Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, + BytesIncrement, Flags); + ExtraArgLocs++; + i++; + } + } + } else { + if (Subtarget->isTargetILP32() && Ins[i].Flags.isPointer()) + ArgValue = DAG.getNode(ISD::AssertZext, DL, ArgValue.getValueType(), + ArgValue, DAG.getValueType(MVT::i32)); + InVals.push_back(ArgValue); + } } + assert((ArgLocs.size() + ExtraArgLocs) == Ins.size()); // varargs AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); @@ -4036,9 +4927,7 @@ SDValue AArch64TargetLowering::LowerCallResult( const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL, SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals, bool isThisReturn, SDValue ThisVal) const { - CCAssignFn *RetCC = CallConv == CallingConv::WebKit_JS - ? RetCC_AArch64_WebKit_JS - : RetCC_AArch64_AAPCS; + CCAssignFn *RetCC = CCAssignFnForReturn(CallConv); // Assign locations to each value returned by this call. 
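// (CCAssignFnForReturn above presumably selects RetCC_AArch64_WebKit_JS for
// webkit_jscc callees and RetCC_AArch64_AAPCS for everything else, matching
// the open-coded ternary it replaces.)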
SmallVector<CCValAssign, 16> RVLocs; DenseMap<unsigned, SDValue> CopiedRegs; @@ -4104,6 +4993,7 @@ static bool canGuaranteeTCO(CallingConv::ID CC) { static bool mayTailCallThisCC(CallingConv::ID CC) { switch (CC) { case CallingConv::C: + case CallingConv::AArch64_SVE_VectorCall: case CallingConv::PreserveMost: case CallingConv::Swift: return true; @@ -4123,6 +5013,15 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( MachineFunction &MF = DAG.getMachineFunction(); const Function &CallerF = MF.getFunction(); CallingConv::ID CallerCC = CallerF.getCallingConv(); + + // If this function uses the C calling convention but has an SVE signature, + // then it preserves more registers and should assume the SVE_VectorCall CC. + // The check for matching callee-saved regs will determine whether it is + // eligible for TCO. + if (CallerCC == CallingConv::C && + AArch64RegisterInfo::hasSVEArgsOrReturn(&MF)) + CallerCC = CallingConv::AArch64_SVE_VectorCall; + bool CCMatch = CallerCC == CalleeCC; // When using the Windows calling convention on a non-windows OS, we want @@ -4310,6 +5209,20 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt; bool IsSibCall = false; + // Check callee args/returns for SVE registers and set calling convention + // accordingly. + if (CallConv == CallingConv::C) { + bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){ + return Out.VT.isScalableVector(); + }); + bool CalleeInSVE = any_of(Ins, [](ISD::InputArg &In){ + return In.VT.isScalableVector(); + }); + + if (CalleeInSVE || CalleeOutSVE) + CallConv = CallingConv::AArch64_SVE_VectorCall; + } + if (IsTailCall) { // Check if it's really possible to do a tail call. IsTailCall = isEligibleForTailCallOptimization( @@ -4339,6 +5252,10 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, for (unsigned i = 0; i != NumArgs; ++i) { MVT ArgVT = Outs[i].VT; + if (!Outs[i].IsFixed && ArgVT.isScalableVector()) + report_fatal_error("Passing SVE types to variadic functions is " + "currently not supported"); + ISD::ArgFlagsTy ArgFlags = Outs[i].Flags; CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, /*IsVarArg=*/ !Outs[i].IsFixed); @@ -4433,8 +5350,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, } // Walk the register/memloc assignments, inserting copies/loads. 
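// The ExtraArgLocs counter below compensates for indirectly passed scalable
// tuples: each such tuple occupies several consecutive Outs entries but only
// one CCValAssign, so the loop indexes ArgLocs[i - ExtraArgLocs] to keep the
// two lists in step.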
- for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { - CCValAssign &VA = ArgLocs[i]; + unsigned ExtraArgLocs = 0; + for (unsigned i = 0, e = Outs.size(); i != e; ++i) { + CCValAssign &VA = ArgLocs[i - ExtraArgLocs]; SDValue Arg = OutVals[i]; ISD::ArgFlagsTy Flags = Outs[i].Flags; @@ -4476,18 +5394,49 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, case CCValAssign::Indirect: assert(VA.getValVT().isScalableVector() && "Only scalable vectors can be passed indirectly"); + + uint64_t StoreSize = VA.getValVT().getStoreSize().getKnownMinSize(); + uint64_t PartSize = StoreSize; + unsigned NumParts = 1; + if (Outs[i].Flags.isInConsecutiveRegs()) { + assert(!Outs[i].Flags.isInConsecutiveRegsLast()); + while (!Outs[i + NumParts - 1].Flags.isInConsecutiveRegsLast()) + ++NumParts; + StoreSize *= NumParts; + } + MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo(); Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext()); Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty); - int FI = MFI.CreateStackObject( - VA.getValVT().getStoreSize().getKnownMinSize(), Alignment, false); - MFI.setStackID(FI, TargetStackID::SVEVector); + int FI = MFI.CreateStackObject(StoreSize, Alignment, false); + MFI.setStackID(FI, TargetStackID::ScalableVector); - SDValue SpillSlot = DAG.getFrameIndex( + MachinePointerInfo MPI = + MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI); + SDValue Ptr = DAG.getFrameIndex( FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout())); - Chain = DAG.getStore( - Chain, DL, Arg, SpillSlot, - MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI)); + SDValue SpillSlot = Ptr; + + // Ensure we generate all stores for each tuple part, whilst updating the + // pointer after each store correctly using vscale. + while (NumParts) { + Chain = DAG.getStore(Chain, DL, OutVals[i], Ptr, MPI); + NumParts--; + if (NumParts > 0) { + SDValue BytesIncrement = DAG.getVScale( + DL, Ptr.getValueType(), + APInt(Ptr.getValueSizeInBits().getFixedSize(), PartSize)); + SDNodeFlags Flags; + Flags.setNoUnsignedWrap(true); + + MPI = MachinePointerInfo(MPI.getAddrSpace()); + Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, + BytesIncrement, Flags); + ExtraArgLocs++; + i++; + } + } + Arg = SpillSlot; break; } @@ -4507,20 +5456,18 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // take care of putting the two halves in the right place but we have to // combine them. SDValue &Bits = - std::find_if(RegsToPass.begin(), RegsToPass.end(), - [=](const std::pair<unsigned, SDValue> &Elt) { - return Elt.first == VA.getLocReg(); - }) + llvm::find_if(RegsToPass, + [=](const std::pair<unsigned, SDValue> &Elt) { + return Elt.first == VA.getLocReg(); + }) ->second; Bits = DAG.getNode(ISD::OR, DL, Bits.getValueType(), Bits, Arg); // Call site info is used for function's parameter entry value // tracking. For now we track only simple cases when parameter // is transferred through whole register. 
- CSInfo.erase(std::remove_if(CSInfo.begin(), CSInfo.end(), - [&VA](MachineFunction::ArgRegPair ArgReg) { - return ArgReg.Reg == VA.getLocReg(); - }), - CSInfo.end()); + llvm::erase_if(CSInfo, [&VA](MachineFunction::ArgRegPair ArgReg) { + return ArgReg.Reg == VA.getLocReg(); + }); } else { RegsToPass.emplace_back(VA.getLocReg(), Arg); RegsUsed.insert(VA.getLocReg()); @@ -4539,7 +5486,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, uint32_t BEAlign = 0; unsigned OpSize; if (VA.getLocInfo() == CCValAssign::Indirect) - OpSize = VA.getLocVT().getSizeInBits(); + OpSize = VA.getLocVT().getFixedSizeInBits(); else OpSize = Flags.isByVal() ? Flags.getByValSize() * 8 : VA.getValVT().getSizeInBits(); @@ -4663,20 +5610,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, Ops.push_back(DAG.getRegister(RegToPass.first, RegToPass.second.getValueType())); - // Check callee args/returns for SVE registers and set calling convention - // accordingly. - if (CallConv == CallingConv::C) { - bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){ - return Out.VT.isScalableVector(); - }); - bool CalleeInSVE = any_of(Ins, [](ISD::InputArg &In){ - return In.VT.isScalableVector(); - }); - - if (CalleeInSVE || CalleeOutSVE) - CallConv = CallingConv::AArch64_SVE_VectorCall; - } - // Add a register mask operand representing the call-preserved registers. const uint32_t *Mask; const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); @@ -4713,8 +5646,17 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, return Ret; } + unsigned CallOpc = AArch64ISD::CALL; + // Calls marked with "rv_marker" are special. They should be expanded to the + // call, directly followed by a special marker sequence. Use the CALL_RVMARKER + // to do that. + if (CLI.CB && CLI.CB->hasRetAttr("rv_marker")) { + assert(!IsTailCall && "tail calls cannot be marked with rv_marker"); + CallOpc = AArch64ISD::CALL_RVMARKER; + } + // Returns a chain and a flag for retval copy to use. - Chain = DAG.getNode(AArch64ISD::CALL, DL, NodeTys, Ops); + Chain = DAG.getNode(CallOpc, DL, NodeTys, Ops); DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge); InFlag = Chain.getValue(1); DAG.addCallSiteInfo(Chain.getNode(), std::move(CSInfo)); @@ -4738,9 +5680,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, bool AArch64TargetLowering::CanLowerReturn( CallingConv::ID CallConv, MachineFunction &MF, bool isVarArg, const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext &Context) const { - CCAssignFn *RetCC = CallConv == CallingConv::WebKit_JS - ? RetCC_AArch64_WebKit_JS - : RetCC_AArch64_AAPCS; + CCAssignFn *RetCC = CCAssignFnForReturn(CallConv); SmallVector<CCValAssign, 16> RVLocs; CCState CCInfo(CallConv, isVarArg, MF, RVLocs, Context); return CCInfo.CheckReturn(Outs, RetCC); @@ -4755,9 +5695,7 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, auto &MF = DAG.getMachineFunction(); auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); - CCAssignFn *RetCC = CallConv == CallingConv::WebKit_JS - ? 
RetCC_AArch64_WebKit_JS - : RetCC_AArch64_AAPCS; + CCAssignFn *RetCC = CCAssignFnForReturn(CallConv); SmallVector<CCValAssign, 16> RVLocs; CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), RVLocs, *DAG.getContext()); @@ -4802,11 +5740,9 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, if (RegsUsed.count(VA.getLocReg())) { SDValue &Bits = - std::find_if(RetVals.begin(), RetVals.end(), - [=](const std::pair<unsigned, SDValue> &Elt) { - return Elt.first == VA.getLocReg(); - }) - ->second; + llvm::find_if(RetVals, [=](const std::pair<unsigned, SDValue> &Elt) { + return Elt.first == VA.getLocReg(); + })->second; Bits = DAG.getNode(ISD::OR, DL, Bits.getValueType(), Bits, Arg); } else { RetVals.emplace_back(VA.getLocReg(), Arg); @@ -5026,7 +5962,7 @@ AArch64TargetLowering::LowerDarwinGlobalTLSAddress(SDValue Op, SDValue FuncTLVGet = DAG.getLoad( PtrMemVT, DL, Chain, DescAddr, MachinePointerInfo::getGOT(DAG.getMachineFunction()), - /* Alignment = */ PtrMemVT.getSizeInBits() / 8, + Align(PtrMemVT.getSizeInBits() / 8), MachineMemOperand::MOInvariant | MachineMemOperand::MODereferenceable); Chain = FuncTLVGet.getValue(1); @@ -5341,6 +6277,22 @@ SDValue AArch64TargetLowering::LowerGlobalTLSAddress(SDValue Op, llvm_unreachable("Unexpected platform trying to use TLS"); } +// Looks through \param Val to determine the bit that can be used to +// check the sign of the value. It returns the unextended value and +// the sign bit position. +std::pair<SDValue, uint64_t> lookThroughSignExtension(SDValue Val) { + if (Val.getOpcode() == ISD::SIGN_EXTEND_INREG) + return {Val.getOperand(0), + cast<VTSDNode>(Val.getOperand(1))->getVT().getFixedSizeInBits() - + 1}; + + if (Val.getOpcode() == ISD::SIGN_EXTEND) + return {Val.getOperand(0), + Val.getOperand(0)->getValueType(0).getFixedSizeInBits() - 1}; + + return {Val, Val.getValueSizeInBits() - 1}; +} + SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const { SDValue Chain = Op.getOperand(0); ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(1))->get(); @@ -5435,9 +6387,10 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const { // Don't combine AND since emitComparison converts the AND to an ANDS // (a.k.a. TST) and the test in the test bit and branch instruction // becomes redundant. This would also increase register pressure. - uint64_t Mask = LHS.getValueSizeInBits() - 1; + uint64_t SignBitPos; + std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS); return DAG.getNode(AArch64ISD::TBNZ, dl, MVT::Other, Chain, LHS, - DAG.getConstant(Mask, dl, MVT::i64), Dest); + DAG.getConstant(SignBitPos, dl, MVT::i64), Dest); } } if (RHSC && RHSC->getSExtValue() == -1 && CC == ISD::SETGT && @@ -5445,9 +6398,10 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const { // Don't combine AND since emitComparison converts the AND to an ANDS // (a.k.a. TST) and the test in the test bit and branch instruction // becomes redundant. This would also increase register pressure. 
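// For example, (br_cc setgt, (sext_inreg %x, i8), -1, dest) now tests bit 7
// of the unextended value directly:
//   tbz w0, #7, dest
// i.e. branch when the sign bit is clear, without materializing the
// extension just to test its top bit.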
- uint64_t Mask = LHS.getValueSizeInBits() - 1; + uint64_t SignBitPos; + std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS); return DAG.getNode(AArch64ISD::TBZ, dl, MVT::Other, Chain, LHS, - DAG.getConstant(Mask, dl, MVT::i64), Dest); + DAG.getConstant(SignBitPos, dl, MVT::i64), Dest); } SDValue CCVal; @@ -5594,6 +6548,9 @@ SDValue AArch64TargetLowering::LowerCTPOP(SDValue Op, SelectionDAG &DAG) const { return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i128, UaddLV); } + if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT)) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::CTPOP_MERGE_PASSTHRU); + assert((VT == MVT::v1i64 || VT == MVT::v2i64 || VT == MVT::v2i32 || VT == MVT::v4i32 || VT == MVT::v4i16 || VT == MVT::v8i16) && "Unexpected type for custom ctpop lowering"); @@ -5617,6 +6574,16 @@ SDValue AArch64TargetLowering::LowerCTPOP(SDValue Op, SelectionDAG &DAG) const { return Val; } +SDValue AArch64TargetLowering::LowerCTTZ(SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + assert(VT.isScalableVector() || + useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true)); + + SDLoc DL(Op); + SDValue RBIT = DAG.getNode(ISD::BITREVERSE, DL, VT, Op.getOperand(0)); + return DAG.getNode(ISD::CTLZ, DL, VT, RBIT); +} + SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { if (Op.getValueType().isVector()) @@ -5774,7 +6741,8 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, // instead of a CSEL in that case. if (TrueVal == ~FalseVal) { Opcode = AArch64ISD::CSINV; - } else if (TrueVal == -FalseVal) { + } else if (FalseVal > std::numeric_limits<int64_t>::min() && + TrueVal == -FalseVal) { Opcode = AArch64ISD::CSNEG; } else if (TVal.getValueType() == MVT::i32) { // If our operands are only 32-bit wide, make sure we use 32-bit @@ -5974,6 +6942,9 @@ SDValue AArch64TargetLowering::LowerBR_JT(SDValue Op, SDValue Entry = Op.getOperand(2); int JTI = cast<JumpTableSDNode>(JT.getNode())->getIndex(); + auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>(); + AFI->setJumpTableEntryInfo(JTI, 4, nullptr); + SDNode *Dest = DAG.getMachineNode(AArch64::JumpTableDest32, DL, MVT::i64, MVT::i64, JT, Entry, DAG.getTargetJumpTable(JTI, MVT::i32)); @@ -6040,11 +7011,13 @@ SDValue AArch64TargetLowering::LowerWin64_VASTART(SDValue Op, } SDValue AArch64TargetLowering::LowerAAPCS_VASTART(SDValue Op, - SelectionDAG &DAG) const { + SelectionDAG &DAG) const { // The layout of the va_list struct is specified in the AArch64 Procedure Call // Standard, section B.3. MachineFunction &MF = DAG.getMachineFunction(); AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); + unsigned PtrSize = Subtarget->isTargetILP32() ? 
4 : 8; + auto PtrMemVT = getPointerMemTy(DAG.getDataLayout()); auto PtrVT = getPointerTy(DAG.getDataLayout()); SDLoc DL(Op); @@ -6054,56 +7027,64 @@ SDValue AArch64TargetLowering::LowerAAPCS_VASTART(SDValue Op, SmallVector<SDValue, 4> MemOps; // void *__stack at offset 0 + unsigned Offset = 0; SDValue Stack = DAG.getFrameIndex(FuncInfo->getVarArgsStackIndex(), PtrVT); + Stack = DAG.getZExtOrTrunc(Stack, DL, PtrMemVT); MemOps.push_back(DAG.getStore(Chain, DL, Stack, VAList, - MachinePointerInfo(SV), /* Alignment = */ 8)); + MachinePointerInfo(SV), Align(PtrSize))); - // void *__gr_top at offset 8 + // void *__gr_top at offset 8 (4 on ILP32) + Offset += PtrSize; int GPRSize = FuncInfo->getVarArgsGPRSize(); if (GPRSize > 0) { SDValue GRTop, GRTopAddr; - GRTopAddr = - DAG.getNode(ISD::ADD, DL, PtrVT, VAList, DAG.getConstant(8, DL, PtrVT)); + GRTopAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList, + DAG.getConstant(Offset, DL, PtrVT)); GRTop = DAG.getFrameIndex(FuncInfo->getVarArgsGPRIndex(), PtrVT); GRTop = DAG.getNode(ISD::ADD, DL, PtrVT, GRTop, DAG.getConstant(GPRSize, DL, PtrVT)); + GRTop = DAG.getZExtOrTrunc(GRTop, DL, PtrMemVT); MemOps.push_back(DAG.getStore(Chain, DL, GRTop, GRTopAddr, - MachinePointerInfo(SV, 8), - /* Alignment = */ 8)); + MachinePointerInfo(SV, Offset), + Align(PtrSize))); } - // void *__vr_top at offset 16 + // void *__vr_top at offset 16 (8 on ILP32) + Offset += PtrSize; int FPRSize = FuncInfo->getVarArgsFPRSize(); if (FPRSize > 0) { SDValue VRTop, VRTopAddr; VRTopAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList, - DAG.getConstant(16, DL, PtrVT)); + DAG.getConstant(Offset, DL, PtrVT)); VRTop = DAG.getFrameIndex(FuncInfo->getVarArgsFPRIndex(), PtrVT); VRTop = DAG.getNode(ISD::ADD, DL, PtrVT, VRTop, DAG.getConstant(FPRSize, DL, PtrVT)); + VRTop = DAG.getZExtOrTrunc(VRTop, DL, PtrMemVT); MemOps.push_back(DAG.getStore(Chain, DL, VRTop, VRTopAddr, - MachinePointerInfo(SV, 16), - /* Alignment = */ 8)); - } - - // int __gr_offs at offset 24 - SDValue GROffsAddr = - DAG.getNode(ISD::ADD, DL, PtrVT, VAList, DAG.getConstant(24, DL, PtrVT)); - MemOps.push_back(DAG.getStore( - Chain, DL, DAG.getConstant(-GPRSize, DL, MVT::i32), GROffsAddr, - MachinePointerInfo(SV, 24), /* Alignment = */ 4)); - - // int __vr_offs at offset 28 - SDValue VROffsAddr = - DAG.getNode(ISD::ADD, DL, PtrVT, VAList, DAG.getConstant(28, DL, PtrVT)); - MemOps.push_back(DAG.getStore( - Chain, DL, DAG.getConstant(-FPRSize, DL, MVT::i32), VROffsAddr, - MachinePointerInfo(SV, 28), /* Alignment = */ 4)); + MachinePointerInfo(SV, Offset), + Align(PtrSize))); + } + + // int __gr_offs at offset 24 (12 on ILP32) + Offset += PtrSize; + SDValue GROffsAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList, + DAG.getConstant(Offset, DL, PtrVT)); + MemOps.push_back( + DAG.getStore(Chain, DL, DAG.getConstant(-GPRSize, DL, MVT::i32), + GROffsAddr, MachinePointerInfo(SV, Offset), Align(4))); + + // int __vr_offs at offset 28 (16 on ILP32) + Offset += 4; + SDValue VROffsAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList, + DAG.getConstant(Offset, DL, PtrVT)); + MemOps.push_back( + DAG.getStore(Chain, DL, DAG.getConstant(-FPRSize, DL, MVT::i32), + VROffsAddr, MachinePointerInfo(SV, Offset), Align(4))); return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOps); } @@ -6126,8 +7107,10 @@ SDValue AArch64TargetLowering::LowerVACOPY(SDValue Op, // pointer. SDLoc DL(Op); unsigned PtrSize = Subtarget->isTargetILP32() ? 4 : 8; - unsigned VaListSize = (Subtarget->isTargetDarwin() || - Subtarget->isTargetWindows()) ? 
PtrSize : 32; + unsigned VaListSize = + (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows()) + ? PtrSize + : Subtarget->isTargetILP32() ? 20 : 32; const Value *DestSV = cast<SrcValueSDNode>(Op.getOperand(3))->getValue(); const Value *SrcSV = cast<SrcValueSDNode>(Op.getOperand(4))->getValue(); @@ -6155,6 +7138,10 @@ SDValue AArch64TargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const { Chain = VAList.getValue(1); VAList = DAG.getZExtOrTrunc(VAList, DL, PtrVT); + if (VT.isScalableVector()) + report_fatal_error("Passing SVE types to variadic functions is " + "currently not supported"); + if (Align && *Align > MinSlotSize) { VAList = DAG.getNode(ISD::ADD, DL, PtrVT, VAList, DAG.getConstant(Align->value() - 1, DL, PtrVT)); @@ -6276,17 +7263,34 @@ SDValue AArch64TargetLowering::LowerRETURNADDR(SDValue Op, EVT VT = Op.getValueType(); SDLoc DL(Op); unsigned Depth = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue(); + SDValue ReturnAddress; if (Depth) { SDValue FrameAddr = LowerFRAMEADDR(Op, DAG); SDValue Offset = DAG.getConstant(8, DL, getPointerTy(DAG.getDataLayout())); - return DAG.getLoad(VT, DL, DAG.getEntryNode(), - DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset), - MachinePointerInfo()); + ReturnAddress = DAG.getLoad( + VT, DL, DAG.getEntryNode(), + DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset), MachinePointerInfo()); + } else { + // Return LR, which contains the return address. Mark it an implicit + // live-in. + unsigned Reg = MF.addLiveIn(AArch64::LR, &AArch64::GPR64RegClass); + ReturnAddress = DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, VT); + } + + // The XPACLRI instruction assembles to a hint-space instruction before + // Armv8.3-A therefore this instruction can be safely used for any pre + // Armv8.3-A architectures. On Armv8.3-A and onwards XPACI is available so use + // that instead. + SDNode *St; + if (Subtarget->hasPAuth()) { + St = DAG.getMachineNode(AArch64::XPACI, DL, VT, ReturnAddress); + } else { + // XPACLRI operates on LR therefore we must move the operand accordingly. + SDValue Chain = + DAG.getCopyToReg(DAG.getEntryNode(), DL, AArch64::LR, ReturnAddress); + St = DAG.getMachineNode(AArch64::XPACLRI, DL, VT, Chain); } - - // Return LR, which contains the return address. Mark it an implicit live-in. 
- unsigned Reg = MF.addLiveIn(AArch64::LR, &AArch64::GPR64RegClass); - return DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, VT); + return SDValue(St, 0); } /// LowerShiftRightParts - Lower SRA_PARTS, which returns two @@ -6467,6 +7471,22 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode, return SDValue(); } +SDValue +AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG, + const DenormalMode &Mode) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); + return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); +} + +SDValue +AArch64TargetLowering::getSqrtResultForDenormInput(SDValue Op, + SelectionDAG &DAG) const { + return Op; +} + SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps, @@ -6490,17 +7510,8 @@ SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand, Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags); Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags); } - if (!Reciprocal) { - EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), - VT); - SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); - SDValue Eq = DAG.getSetCC(DL, CCVT, Operand, FPZero, ISD::SETEQ); - + if (!Reciprocal) Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags); - // Correct the result if the operand is 0.0. - Estimate = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, - VT, Eq, Operand, Estimate); - } ExtraSteps = 0; return Estimate; @@ -6676,23 +7687,30 @@ AArch64TargetLowering::getRegForInlineAsmConstraint( if (Constraint.size() == 1) { switch (Constraint[0]) { case 'r': - if (VT.getSizeInBits() == 64) + if (VT.isScalableVector()) + return std::make_pair(0U, nullptr); + if (VT.getFixedSizeInBits() == 64) return std::make_pair(0U, &AArch64::GPR64commonRegClass); return std::make_pair(0U, &AArch64::GPR32commonRegClass); - case 'w': + case 'w': { if (!Subtarget->hasFPARMv8()) break; - if (VT.isScalableVector()) - return std::make_pair(0U, &AArch64::ZPRRegClass); - if (VT.getSizeInBits() == 16) + if (VT.isScalableVector()) { + if (VT.getVectorElementType() != MVT::i1) + return std::make_pair(0U, &AArch64::ZPRRegClass); + return std::make_pair(0U, nullptr); + } + uint64_t VTSize = VT.getFixedSizeInBits(); + if (VTSize == 16) return std::make_pair(0U, &AArch64::FPR16RegClass); - if (VT.getSizeInBits() == 32) + if (VTSize == 32) return std::make_pair(0U, &AArch64::FPR32RegClass); - if (VT.getSizeInBits() == 64) + if (VTSize == 64) return std::make_pair(0U, &AArch64::FPR64RegClass); - if (VT.getSizeInBits() == 128) + if (VTSize == 128) return std::make_pair(0U, &AArch64::FPR128RegClass); break; + } // The instructions that this constraint is designed for can // only take 128-bit registers so just use that regclass. case 'x': @@ -6713,10 +7731,11 @@ AArch64TargetLowering::getRegForInlineAsmConstraint( } else { PredicateConstraint PC = parsePredicateConstraint(Constraint); if (PC != PredicateConstraint::Invalid) { - assert(VT.isScalableVector()); + if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1) + return std::make_pair(0U, nullptr); bool restricted = (PC == PredicateConstraint::Upl); return restricted ? 
std::make_pair(0U, &AArch64::PPR_3bRegClass) - : std::make_pair(0U, &AArch64::PPRRegClass); + : std::make_pair(0U, &AArch64::PPRRegClass); } } if (StringRef("{cc}").equals_lower(Constraint)) @@ -6955,6 +7974,8 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op, LLVM_DEBUG(dbgs() << "AArch64TargetLowering::ReconstructShuffle\n"); SDLoc dl(Op); EVT VT = Op.getValueType(); + assert(!VT.isScalableVector() && + "Scalable vectors cannot be used with ISD::BUILD_VECTOR"); unsigned NumElts = VT.getVectorNumElements(); struct ShuffleSourceInfo { @@ -7025,8 +8046,9 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op, } } unsigned ResMultiplier = - VT.getScalarSizeInBits() / SmallestEltTy.getSizeInBits(); - NumElts = VT.getSizeInBits() / SmallestEltTy.getSizeInBits(); + VT.getScalarSizeInBits() / SmallestEltTy.getFixedSizeInBits(); + uint64_t VTSize = VT.getFixedSizeInBits(); + NumElts = VTSize / SmallestEltTy.getFixedSizeInBits(); EVT ShuffleVT = EVT::getVectorVT(*DAG.getContext(), SmallestEltTy, NumElts); // If the source vector is too wide or too narrow, we may nevertheless be able @@ -7035,17 +8057,18 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op, for (auto &Src : Sources) { EVT SrcVT = Src.ShuffleVec.getValueType(); - if (SrcVT.getSizeInBits() == VT.getSizeInBits()) + uint64_t SrcVTSize = SrcVT.getFixedSizeInBits(); + if (SrcVTSize == VTSize) continue; // This stage of the search produces a source with the same element type as // the original, but with a total width matching the BUILD_VECTOR output. EVT EltVT = SrcVT.getVectorElementType(); - unsigned NumSrcElts = VT.getSizeInBits() / EltVT.getSizeInBits(); + unsigned NumSrcElts = VTSize / EltVT.getFixedSizeInBits(); EVT DestVT = EVT::getVectorVT(*DAG.getContext(), EltVT, NumSrcElts); - if (SrcVT.getSizeInBits() < VT.getSizeInBits()) { - assert(2 * SrcVT.getSizeInBits() == VT.getSizeInBits()); + if (SrcVTSize < VTSize) { + assert(2 * SrcVTSize == VTSize); // We can pad out the smaller vector for free, so if it's part of a // shuffle... Src.ShuffleVec = @@ -7054,7 +8077,11 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op, continue; } - assert(SrcVT.getSizeInBits() == 2 * VT.getSizeInBits()); + if (SrcVTSize != 2 * VTSize) { + LLVM_DEBUG( + dbgs() << "Reshuffle failed: result vector too small to extract\n"); + return SDValue(); + } if (Src.MaxElt - Src.MinElt >= NumSrcElts) { LLVM_DEBUG( @@ -7083,6 +8110,13 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op, DAG.getConstant(NumSrcElts, dl, MVT::i64)); unsigned Imm = Src.MinElt * getExtFactor(VEXTSrc1); + if (!SrcVT.is64BitVector()) { + LLVM_DEBUG( + dbgs() << "Reshuffle failed: don't know how to lower AArch64ISD::EXT " + "for SVE vectors."); + return SDValue(); + } + Src.ShuffleVec = DAG.getNode(AArch64ISD::EXT, dl, DestVT, VEXTSrc1, VEXTSrc2, DAG.getConstant(Imm, dl, MVT::i32)); @@ -7099,7 +8133,8 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op, continue; assert(ShuffleVT.getVectorElementType() == SmallestEltTy); Src.ShuffleVec = DAG.getNode(ISD::BITCAST, dl, ShuffleVT, Src.ShuffleVec); - Src.WindowScale = SrcEltTy.getSizeInBits() / SmallestEltTy.getSizeInBits(); + Src.WindowScale = + SrcEltTy.getFixedSizeInBits() / SmallestEltTy.getFixedSizeInBits(); Src.WindowBase *= Src.WindowScale; } @@ -7123,8 +8158,8 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op, // trunc. So only std::min(SrcBits, DestBits) actually get defined in this // segment. 
EVT OrigEltTy = Entry.getOperand(0).getValueType().getVectorElementType(); - int BitsDefined = - std::min(OrigEltTy.getSizeInBits(), VT.getScalarSizeInBits()); + int BitsDefined = std::min(OrigEltTy.getScalarSizeInBits(), + VT.getScalarSizeInBits()); int LanesDefined = BitsDefined / BitsPerShuffleLane; // This source is expected to fill ResMultiplier lanes of the final shuffle, @@ -7188,6 +8223,81 @@ static bool isSingletonEXTMask(ArrayRef<int> M, EVT VT, unsigned &Imm) { return true; } +/// Check if a vector shuffle corresponds to a DUP instructions with a larger +/// element width than the vector lane type. If that is the case the function +/// returns true and writes the value of the DUP instruction lane operand into +/// DupLaneOp +static bool isWideDUPMask(ArrayRef<int> M, EVT VT, unsigned BlockSize, + unsigned &DupLaneOp) { + assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64) && + "Only possible block sizes for wide DUP are: 16, 32, 64"); + + if (BlockSize <= VT.getScalarSizeInBits()) + return false; + if (BlockSize % VT.getScalarSizeInBits() != 0) + return false; + if (VT.getSizeInBits() % BlockSize != 0) + return false; + + size_t SingleVecNumElements = VT.getVectorNumElements(); + size_t NumEltsPerBlock = BlockSize / VT.getScalarSizeInBits(); + size_t NumBlocks = VT.getSizeInBits() / BlockSize; + + // We are looking for masks like + // [0, 1, 0, 1] or [2, 3, 2, 3] or [4, 5, 6, 7, 4, 5, 6, 7] where any element + // might be replaced by 'undefined'. BlockIndices will eventually contain + // lane indices of the duplicated block (i.e. [0, 1], [2, 3] and [4, 5, 6, 7] + // for the above examples) + SmallVector<int, 8> BlockElts(NumEltsPerBlock, -1); + for (size_t BlockIndex = 0; BlockIndex < NumBlocks; BlockIndex++) + for (size_t I = 0; I < NumEltsPerBlock; I++) { + int Elt = M[BlockIndex * NumEltsPerBlock + I]; + if (Elt < 0) + continue; + // For now we don't support shuffles that use the second operand + if ((unsigned)Elt >= SingleVecNumElements) + return false; + if (BlockElts[I] < 0) + BlockElts[I] = Elt; + else if (BlockElts[I] != Elt) + return false; + } + + // We found a candidate block (possibly with some undefs). It must be a + // sequence of consecutive integers starting with a value divisible by + // NumEltsPerBlock with some values possibly replaced by undef-s. + + // Find first non-undef element + auto FirstRealEltIter = find_if(BlockElts, [](int Elt) { return Elt >= 0; }); + assert(FirstRealEltIter != BlockElts.end() && + "Shuffle with all-undefs must have been caught by previous cases, " + "e.g. isSplat()"); + if (FirstRealEltIter == BlockElts.end()) { + DupLaneOp = 0; + return true; + } + + // Index of FirstRealElt in BlockElts + size_t FirstRealIndex = FirstRealEltIter - BlockElts.begin(); + + if ((unsigned)*FirstRealEltIter < FirstRealIndex) + return false; + // BlockElts[0] must have the following value if it isn't undef: + size_t Elt0 = *FirstRealEltIter - FirstRealIndex; + + // Check the first element + if (Elt0 % NumEltsPerBlock != 0) + return false; + // Check that the sequence indeed consists of consecutive integers (modulo + // undefs) + for (size_t I = 0; I < NumEltsPerBlock; I++) + if (BlockElts[I] >= 0 && (unsigned)BlockElts[I] != Elt0 + I) + return false; + + DupLaneOp = Elt0 / NumEltsPerBlock; + return true; +} + // check if an EXT instruction can handle the shuffle mask when the // vector sources of the shuffle are different. 
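// For example (illustrative mask): shuffling two v8i8 sources with the mask
// <3,4,5,6,7,8,9,10> selects a contiguous window across the concatenated
// sources and matches a single EXT with Imm = 3.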
static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT, @@ -7621,6 +8731,60 @@ static unsigned getDUPLANEOp(EVT EltType) { llvm_unreachable("Invalid vector element type?"); } +static SDValue constructDup(SDValue V, int Lane, SDLoc dl, EVT VT, + unsigned Opcode, SelectionDAG &DAG) { + // Try to eliminate a bitcasted extract subvector before a DUPLANE. + auto getScaledOffsetDup = [](SDValue BitCast, int &LaneC, MVT &CastVT) { + // Match: dup (bitcast (extract_subv X, C)), LaneC + if (BitCast.getOpcode() != ISD::BITCAST || + BitCast.getOperand(0).getOpcode() != ISD::EXTRACT_SUBVECTOR) + return false; + + // The extract index must align in the destination type. That may not + // happen if the bitcast is from narrow to wide type. + SDValue Extract = BitCast.getOperand(0); + unsigned ExtIdx = Extract.getConstantOperandVal(1); + unsigned SrcEltBitWidth = Extract.getScalarValueSizeInBits(); + unsigned ExtIdxInBits = ExtIdx * SrcEltBitWidth; + unsigned CastedEltBitWidth = BitCast.getScalarValueSizeInBits(); + if (ExtIdxInBits % CastedEltBitWidth != 0) + return false; + + // Update the lane value by offsetting with the scaled extract index. + LaneC += ExtIdxInBits / CastedEltBitWidth; + + // Determine the casted vector type of the wide vector input. + // dup (bitcast (extract_subv X, C)), LaneC --> dup (bitcast X), LaneC' + // Examples: + // dup (bitcast (extract_subv v2f64 X, 1) to v2f32), 1 --> dup v4f32 X, 3 + // dup (bitcast (extract_subv v16i8 X, 8) to v4i16), 1 --> dup v8i16 X, 5 + unsigned SrcVecNumElts = + Extract.getOperand(0).getValueSizeInBits() / CastedEltBitWidth; + CastVT = MVT::getVectorVT(BitCast.getSimpleValueType().getScalarType(), + SrcVecNumElts); + return true; + }; + MVT CastVT; + if (getScaledOffsetDup(V, Lane, CastVT)) { + V = DAG.getBitcast(CastVT, V.getOperand(0).getOperand(0)); + } else if (V.getOpcode() == ISD::EXTRACT_SUBVECTOR) { + // The lane is incremented by the index of the extract. + // Example: dup v2f32 (extract v4f32 X, 2), 1 --> dup v4f32 X, 3 + Lane += V.getConstantOperandVal(1); + V = V.getOperand(0); + } else if (V.getOpcode() == ISD::CONCAT_VECTORS) { + // The lane is decremented if we are splatting from the 2nd operand. + // Example: dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1 + unsigned Idx = Lane >= (int)VT.getVectorNumElements() / 2; + Lane -= Idx * VT.getVectorNumElements() / 2; + V = WidenVector(V.getOperand(Idx), DAG); + } else if (VT.getSizeInBits() == 64) { + // Widen the operand to 128-bit register with undef. + V = WidenVector(V, DAG); + } + return DAG.getNode(Opcode, dl, VT, V, DAG.getConstant(Lane, dl, MVT::i64)); +} + SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const { SDLoc dl(Op); @@ -7654,57 +8818,26 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, // Otherwise, duplicate from the lane of the input vector. unsigned Opcode = getDUPLANEOp(V1.getValueType().getVectorElementType()); - - // Try to eliminate a bitcasted extract subvector before a DUPLANE. - auto getScaledOffsetDup = [](SDValue BitCast, int &LaneC, MVT &CastVT) { - // Match: dup (bitcast (extract_subv X, C)), LaneC - if (BitCast.getOpcode() != ISD::BITCAST || - BitCast.getOperand(0).getOpcode() != ISD::EXTRACT_SUBVECTOR) - return false; - - // The extract index must align in the destination type. That may not - // happen if the bitcast is from narrow to wide type. 
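// For instance, extracting a v4i16 subvector at index 1 gives a 16-bit lane
// offset; if the bitcast views the data as 32-bit elements, 16 % 32 != 0 and
// the rewrite is skipped.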
- SDValue Extract = BitCast.getOperand(0); - unsigned ExtIdx = Extract.getConstantOperandVal(1); - unsigned SrcEltBitWidth = Extract.getScalarValueSizeInBits(); - unsigned ExtIdxInBits = ExtIdx * SrcEltBitWidth; - unsigned CastedEltBitWidth = BitCast.getScalarValueSizeInBits(); - if (ExtIdxInBits % CastedEltBitWidth != 0) - return false; - - // Update the lane value by offsetting with the scaled extract index. - LaneC += ExtIdxInBits / CastedEltBitWidth; - - // Determine the casted vector type of the wide vector input. - // dup (bitcast (extract_subv X, C)), LaneC --> dup (bitcast X), LaneC' - // Examples: - // dup (bitcast (extract_subv v2f64 X, 1) to v2f32), 1 --> dup v4f32 X, 3 - // dup (bitcast (extract_subv v16i8 X, 8) to v4i16), 1 --> dup v8i16 X, 5 - unsigned SrcVecNumElts = - Extract.getOperand(0).getValueSizeInBits() / CastedEltBitWidth; - CastVT = MVT::getVectorVT(BitCast.getSimpleValueType().getScalarType(), - SrcVecNumElts); - return true; - }; - MVT CastVT; - if (getScaledOffsetDup(V1, Lane, CastVT)) { - V1 = DAG.getBitcast(CastVT, V1.getOperand(0).getOperand(0)); - } else if (V1.getOpcode() == ISD::EXTRACT_SUBVECTOR) { - // The lane is incremented by the index of the extract. - // Example: dup v2f32 (extract v4f32 X, 2), 1 --> dup v4f32 X, 3 - Lane += V1.getConstantOperandVal(1); - V1 = V1.getOperand(0); - } else if (V1.getOpcode() == ISD::CONCAT_VECTORS) { - // The lane is decremented if we are splatting from the 2nd operand. - // Example: dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1 - unsigned Idx = Lane >= (int)VT.getVectorNumElements() / 2; - Lane -= Idx * VT.getVectorNumElements() / 2; - V1 = WidenVector(V1.getOperand(Idx), DAG); - } else if (VT.getSizeInBits() == 64) { - // Widen the operand to 128-bit register with undef. - V1 = WidenVector(V1, DAG); - } - return DAG.getNode(Opcode, dl, VT, V1, DAG.getConstant(Lane, dl, MVT::i64)); + return constructDup(V1, Lane, dl, VT, Opcode, DAG); + } + + // Check if the mask matches a DUP for a wider element + for (unsigned LaneSize : {64U, 32U, 16U}) { + unsigned Lane = 0; + if (isWideDUPMask(ShuffleMask, VT, LaneSize, Lane)) { + unsigned Opcode = LaneSize == 64 ? AArch64ISD::DUPLANE64 + : LaneSize == 32 ? AArch64ISD::DUPLANE32 + : AArch64ISD::DUPLANE16; + // Cast V1 to an integer vector with required lane size + MVT NewEltTy = MVT::getIntegerVT(LaneSize); + unsigned NewEltCount = VT.getSizeInBits() / LaneSize; + MVT NewVecTy = MVT::getVectorVT(NewEltTy, NewEltCount); + V1 = DAG.getBitcast(NewVecTy, V1); + // Constuct the DUP instruction + V1 = constructDup(V1, Lane, dl, NewVecTy, Opcode, DAG); + // Cast back to the original type + return DAG.getBitcast(VT, V1); + } } if (isREVMask(ShuffleMask, VT, 64)) @@ -7775,7 +8908,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, EVT ScalarVT = VT.getVectorElementType(); - if (ScalarVT.getSizeInBits() < 32 && ScalarVT.isInteger()) + if (ScalarVT.getFixedSizeInBits() < 32 && ScalarVT.isInteger()) ScalarVT = MVT::i32; return DAG.getNode( @@ -7814,9 +8947,11 @@ SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op, SDLoc dl(Op); EVT VT = Op.getValueType(); EVT ElemVT = VT.getScalarType(); - SDValue SplatVal = Op.getOperand(0); + if (useSVEForFixedLengthVectorVT(VT)) + return LowerToScalableOp(Op, DAG); + // Extend input splat value where needed to fit into a GPR (32b or 64b only) // FPRs don't have this restriction. 
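// For example, an i8 splat value is widened to i32 here so it can live in a
// W register for the GPR-source DUP; only the low 8 bits of the register are
// read per lane, so the kind of extension does not matter.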
switch (ElemVT.getSimpleVT().SimpleTy) { @@ -8246,6 +9381,9 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) { SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op, SelectionDAG &DAG) const { + if (useSVEForFixedLengthVectorVT(Op.getValueType())) + return LowerToScalableOp(Op, DAG); + // Attempt to form a vector S[LR]I from (or (and X, C1), (lsl Y, C2)) if (SDValue Res = tryLowerToSLI(Op.getNode(), DAG)) return Res; @@ -8404,14 +9542,18 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, bool isConstant = true; bool AllLanesExtractElt = true; unsigned NumConstantLanes = 0; + unsigned NumDifferentLanes = 0; + unsigned NumUndefLanes = 0; SDValue Value; SDValue ConstantValue; for (unsigned i = 0; i < NumElts; ++i) { SDValue V = Op.getOperand(i); if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT) AllLanesExtractElt = false; - if (V.isUndef()) + if (V.isUndef()) { + ++NumUndefLanes; continue; + } if (i > 0) isOnlyLowElement = false; if (!isa<ConstantFPSDNode>(V) && !isa<ConstantSDNode>(V)) @@ -8427,8 +9569,10 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, if (!Value.getNode()) Value = V; - else if (V != Value) + else if (V != Value) { usesOnlyOneValue = false; + ++NumDifferentLanes; + } } if (!Value.getNode()) { @@ -8554,11 +9698,20 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, } } + // If we need to insert a small number of different non-constant elements and + // the vector width is sufficiently large, prefer using DUP with the common + // value and INSERT_VECTOR_ELT for the different lanes. If DUP is preferred, + // skip the constant lane handling below. + bool PreferDUPAndInsert = + !isConstant && NumDifferentLanes >= 1 && + NumDifferentLanes < ((NumElts - NumUndefLanes) / 2) && + NumDifferentLanes >= NumConstantLanes; + // If there was only one constant value used and for more than one lane, // start by splatting that value, then replace the non-constant lanes. This // is better than the default, which will perform a separate initialization // for each lane. - if (NumConstantLanes > 0 && usesOnlyOneConstantValue) { + if (!PreferDUPAndInsert && NumConstantLanes > 0 && usesOnlyOneConstantValue) { // Firstly, try to materialize the splat constant. SDValue Vec = DAG.getSplatBuildVector(VT, dl, ConstantValue), Val = ConstantBuildVector(Vec, DAG); @@ -8594,6 +9747,22 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, return shuffle; } + if (PreferDUPAndInsert) { + // First, build a constant vector with the common element. + SmallVector<SDValue, 8> Ops; + for (unsigned I = 0; I < NumElts; ++I) + Ops.push_back(Value); + SDValue NewVector = LowerBUILD_VECTOR(DAG.getBuildVector(VT, dl, Ops), DAG); + // Next, insert the elements that do not match the common value. + for (unsigned I = 0; I < NumElts; ++I) + if (Op.getOperand(I) != Value) + NewVector = + DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, NewVector, + Op.getOperand(I), DAG.getConstant(I, dl, MVT::i64)); + + return NewVector; + } + // If all else fails, just use a sequence of INSERT_VECTOR_ELT when we // know the default expansion would otherwise fall back on something even // worse. 
For a vector with one or two non-undef values, that's @@ -8642,6 +9811,18 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, return SDValue(); } +SDValue AArch64TargetLowering::LowerCONCAT_VECTORS(SDValue Op, + SelectionDAG &DAG) const { + assert(Op.getValueType().isScalableVector() && + isTypeLegal(Op.getValueType()) && + "Expected legal scalable vector type!"); + + if (isTypeLegal(Op.getOperand(0).getValueType()) && Op.getNumOperands() == 2) + return Op; + + return SDValue(); +} + SDValue AArch64TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const { assert(Op.getOpcode() == ISD::INSERT_VECTOR_ELT && "Unknown opcode!"); @@ -8737,7 +9918,8 @@ SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op, // If this is extracting the upper 64-bits of a 128-bit vector, we match // that directly. - if (Size == 64 && Idx * InVT.getScalarSizeInBits() == 64) + if (Size == 64 && Idx * InVT.getScalarSizeInBits() == 64 && + InVT.getSizeInBits() == 128) return Op; return SDValue(); @@ -8751,9 +9933,34 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op, EVT InVT = Op.getOperand(1).getValueType(); unsigned Idx = cast<ConstantSDNode>(Op.getOperand(2))->getZExtValue(); - // We don't have any patterns for scalable vector yet. - if (InVT.isScalableVector() || !useSVEForFixedLengthVectorVT(InVT)) + if (InVT.isScalableVector()) { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + + if (!isTypeLegal(VT) || !VT.isInteger()) + return SDValue(); + + SDValue Vec0 = Op.getOperand(0); + SDValue Vec1 = Op.getOperand(1); + + // Ensure the subvector is half the size of the main vector. + if (VT.getVectorElementCount() != (InVT.getVectorElementCount() * 2)) + return SDValue(); + + // Extend elements of smaller vector... + EVT WideVT = InVT.widenIntegerVectorElementType(*(DAG.getContext())); + SDValue ExtVec = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Vec1); + + if (Idx == 0) { + SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0); + return DAG.getNode(AArch64ISD::UZP1, DL, VT, ExtVec, HiVec0); + } else if (Idx == InVT.getVectorMinNumElements()) { + SDValue LoVec0 = DAG.getNode(AArch64ISD::UUNPKLO, DL, WideVT, Vec0); + return DAG.getNode(AArch64ISD::UZP1, DL, VT, LoVec0, ExtVec); + } + return SDValue(); + } // This will be matched by custom code during ISelDAGToDAG. if (Idx == 0 && isPackedVectorType(InVT, DAG) && Op.getOperand(0).isUndef()) @@ -8762,6 +9969,42 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op, return SDValue(); } +SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + + if (useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true)) + return LowerFixedLengthVectorIntDivideToSVE(Op, DAG); + + assert(VT.isScalableVector() && "Expected a scalable vector."); + + bool Signed = Op.getOpcode() == ISD::SDIV; + unsigned PredOpcode = Signed ? AArch64ISD::SDIV_PRED : AArch64ISD::UDIV_PRED; + + if (VT == MVT::nxv4i32 || VT == MVT::nxv2i64) + return LowerToPredicatedOp(Op, DAG, PredOpcode); + + // SVE doesn't have i8 and i16 DIV operations; widen them to 32-bit + // operations, and truncate the result. + EVT WidenedVT; + if (VT == MVT::nxv16i8) + WidenedVT = MVT::nxv8i16; + else if (VT == MVT::nxv8i16) + WidenedVT = MVT::nxv4i32; + else + llvm_unreachable("Unexpected Custom DIV operation"); + + SDLoc dl(Op); + unsigned UnpkLo = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO; + unsigned UnpkHi = Signed ? 
AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI; + SDValue Op0Lo = DAG.getNode(UnpkLo, dl, WidenedVT, Op.getOperand(0)); + SDValue Op1Lo = DAG.getNode(UnpkLo, dl, WidenedVT, Op.getOperand(1)); + SDValue Op0Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(0)); + SDValue Op1Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(1)); + SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Lo, Op1Lo); + SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Hi, Op1Hi); + return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLo, ResultHi); +} + bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const { // Currently no fixed length shuffles that require SVE are legal. if (useSVEForFixedLengthVectorVT(VT)) @@ -8846,79 +10089,27 @@ static bool isVShiftRImm(SDValue Op, EVT VT, bool isNarrow, int64_t &Cnt) { return (Cnt >= 1 && Cnt <= (isNarrow ? ElementBits / 2 : ElementBits)); } -// Attempt to form urhadd(OpA, OpB) from -// truncate(vlshr(sub(zext(OpB), xor(zext(OpA), Ones(ElemSizeInBits))), 1)). -// The original form of this expression is -// truncate(srl(add(zext(OpB), add(zext(OpA), 1)), 1)) and before this function -// is called the srl will have been lowered to AArch64ISD::VLSHR and the -// ((OpA + OpB + 1) >> 1) expression will have been changed to (OpB - (~OpA)). -// This pass can also recognize a variant of this pattern that uses sign -// extension instead of zero extension and form a srhadd(OpA, OpB) from it. SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { EVT VT = Op.getValueType(); + if (VT.getScalarType() == MVT::i1) { + // Lower i1 truncate to `(x & 1) != 0`. + SDLoc dl(Op); + EVT OpVT = Op.getOperand(0).getValueType(); + SDValue Zero = DAG.getConstant(0, dl, OpVT); + SDValue One = DAG.getConstant(1, dl, OpVT); + SDValue And = DAG.getNode(ISD::AND, dl, OpVT, Op.getOperand(0), One); + return DAG.getSetCC(dl, VT, And, Zero, ISD::SETNE); + } + if (!VT.isVector() || VT.isScalableVector()) - return Op; + return SDValue(); if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType())) return LowerFixedLengthVectorTruncateToSVE(Op, DAG); - // Since we are looking for a right shift by a constant value of 1 and we are - // operating on types at least 16 bits in length (sign/zero extended OpA and - // OpB, which are at least 8 bits), it follows that the truncate will always - // discard the shifted-in bit and therefore the right shift will be logical - // regardless of the signedness of OpA and OpB. - SDValue Shift = Op.getOperand(0); - if (Shift.getOpcode() != AArch64ISD::VLSHR) - return Op; - - // Is the right shift using an immediate value of 1? - uint64_t ShiftAmount = Shift.getConstantOperandVal(1); - if (ShiftAmount != 1) - return Op; - - SDValue Sub = Shift->getOperand(0); - if (Sub.getOpcode() != ISD::SUB) - return Op; - - SDValue Xor = Sub.getOperand(1); - if (Xor.getOpcode() != ISD::XOR) - return Op; - - SDValue ExtendOpA = Xor.getOperand(0); - SDValue ExtendOpB = Sub.getOperand(0); - unsigned ExtendOpAOpc = ExtendOpA.getOpcode(); - unsigned ExtendOpBOpc = ExtendOpB.getOpcode(); - if (!(ExtendOpAOpc == ExtendOpBOpc && - (ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND))) - return Op; - - // Is the result of the right shift being truncated to the same value type as - // the original operands, OpA and OpB? 
- SDValue OpA = ExtendOpA.getOperand(0); - SDValue OpB = ExtendOpB.getOperand(0); - EVT OpAVT = OpA.getValueType(); - assert(ExtendOpA.getValueType() == ExtendOpB.getValueType()); - if (!(VT == OpAVT && OpAVT == OpB.getValueType())) - return Op; - - // Is the XOR using a constant amount of all ones in the right hand side? - uint64_t C; - if (!isAllConstantBuildVector(Xor.getOperand(1), C)) - return Op; - - unsigned ElemSizeInBits = VT.getScalarSizeInBits(); - APInt CAsAPInt(ElemSizeInBits, C); - if (CAsAPInt != APInt::getAllOnesValue(ElemSizeInBits)) - return Op; - - SDLoc DL(Op); - bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND; - unsigned RHADDOpc = IsSignExtend ? AArch64ISD::SRHADD : AArch64ISD::URHADD; - SDValue ResultURHADD = DAG.getNode(RHADDOpc, DL, VT, OpA, OpB); - - return ResultURHADD; + return SDValue(); } SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op, @@ -8936,8 +10127,8 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op, llvm_unreachable("unexpected shift opcode"); case ISD::SHL: - if (VT.isScalableVector()) - return LowerToPredicatedOp(Op, DAG, AArch64ISD::SHL_MERGE_OP1); + if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT)) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::SHL_PRED); if (isVShiftLImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize) return DAG.getNode(AArch64ISD::VSHL, DL, VT, Op.getOperand(0), @@ -8948,9 +10139,9 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op, Op.getOperand(0), Op.getOperand(1)); case ISD::SRA: case ISD::SRL: - if (VT.isScalableVector()) { - unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_MERGE_OP1 - : AArch64ISD::SRL_MERGE_OP1; + if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT)) { + unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED + : AArch64ISD::SRL_PRED; return LowerToPredicatedOp(Op, DAG, Opc); } @@ -9002,7 +10193,7 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS, Fcmeq = DAG.getNode(AArch64ISD::FCMEQz, dl, VT, LHS); else Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS); - return DAG.getNode(AArch64ISD::NOT, dl, VT, Fcmeq); + return DAG.getNOT(dl, Fcmeq, VT); } case AArch64CC::EQ: if (IsZero) @@ -9041,7 +10232,7 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS, Cmeq = DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS); else Cmeq = DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS); - return DAG.getNode(AArch64ISD::NOT, dl, VT, Cmeq); + return DAG.getNOT(dl, Cmeq, VT); } case AArch64CC::EQ: if (IsZero) @@ -9082,6 +10273,9 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op, return LowerToPredicatedOp(Op, DAG, AArch64ISD::SETCC_MERGE_ZERO); } + if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType())) + return LowerFixedLengthVectorSetccToSVE(Op, DAG); + ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get(); SDValue LHS = Op.getOperand(0); SDValue RHS = Op.getOperand(1); @@ -9154,6 +10348,51 @@ static SDValue getReductionSDNode(unsigned Op, SDLoc DL, SDValue ScalarOp, SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const { + SDValue Src = Op.getOperand(0); + + // Try to lower fixed length reductions to SVE. 
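// Aside — a minimal scalar model (illustrative sketch only; names invented)
// of what the SVE reductions targeted here compute: every lane enabled by
// the governing predicate is folded into one scalar, e.g. for an
// unsigned-add reduction over eight 16-bit lanes:
#include <cstdint>

static uint64_t uaddv_model(const bool Pg[8], const uint16_t Zn[8]) {
  uint64_t Sum = 0; // UADDV produces a 64-bit scalar regardless of lane type
  for (int I = 0; I < 8; ++I)
    if (Pg[I]) // inactive lanes do not contribute
      Sum += Zn[I];
  return Sum;
}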
+ EVT SrcVT = Src.getValueType(); + bool OverrideNEON = Op.getOpcode() == ISD::VECREDUCE_AND || + Op.getOpcode() == ISD::VECREDUCE_OR || + Op.getOpcode() == ISD::VECREDUCE_XOR || + Op.getOpcode() == ISD::VECREDUCE_FADD || + (Op.getOpcode() != ISD::VECREDUCE_ADD && + SrcVT.getVectorElementType() == MVT::i64); + if (SrcVT.isScalableVector() || + useSVEForFixedLengthVectorVT(SrcVT, OverrideNEON)) { + + if (SrcVT.getVectorElementType() == MVT::i1) + return LowerPredReductionToSVE(Op, DAG); + + switch (Op.getOpcode()) { + case ISD::VECREDUCE_ADD: + return LowerReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG); + case ISD::VECREDUCE_AND: + return LowerReductionToSVE(AArch64ISD::ANDV_PRED, Op, DAG); + case ISD::VECREDUCE_OR: + return LowerReductionToSVE(AArch64ISD::ORV_PRED, Op, DAG); + case ISD::VECREDUCE_SMAX: + return LowerReductionToSVE(AArch64ISD::SMAXV_PRED, Op, DAG); + case ISD::VECREDUCE_SMIN: + return LowerReductionToSVE(AArch64ISD::SMINV_PRED, Op, DAG); + case ISD::VECREDUCE_UMAX: + return LowerReductionToSVE(AArch64ISD::UMAXV_PRED, Op, DAG); + case ISD::VECREDUCE_UMIN: + return LowerReductionToSVE(AArch64ISD::UMINV_PRED, Op, DAG); + case ISD::VECREDUCE_XOR: + return LowerReductionToSVE(AArch64ISD::EORV_PRED, Op, DAG); + case ISD::VECREDUCE_FADD: + return LowerReductionToSVE(AArch64ISD::FADDV_PRED, Op, DAG); + case ISD::VECREDUCE_FMAX: + return LowerReductionToSVE(AArch64ISD::FMAXNMV_PRED, Op, DAG); + case ISD::VECREDUCE_FMIN: + return LowerReductionToSVE(AArch64ISD::FMINNMV_PRED, Op, DAG); + default: + llvm_unreachable("Unhandled fixed length reduction"); + } + } + + // Lower NEON reductions. SDLoc dl(Op); switch (Op.getOpcode()) { case ISD::VECREDUCE_ADD: @@ -9167,18 +10406,16 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op, case ISD::VECREDUCE_UMIN: return getReductionSDNode(AArch64ISD::UMINV, dl, Op, DAG); case ISD::VECREDUCE_FMAX: { - assert(Op->getFlags().hasNoNaNs() && "fmax vector reduction needs NoNaN flag"); return DAG.getNode( ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(), DAG.getConstant(Intrinsic::aarch64_neon_fmaxnmv, dl, MVT::i32), - Op.getOperand(0)); + Src); } case ISD::VECREDUCE_FMIN: { - assert(Op->getFlags().hasNoNaNs() && "fmin vector reduction needs NoNaN flag"); return DAG.getNode( ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(), DAG.getConstant(Intrinsic::aarch64_neon_fminnmv, dl, MVT::i32), - Op.getOperand(0)); + Src); } default: llvm_unreachable("Unhandled reduction"); @@ -9188,7 +10425,7 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op, SDValue AArch64TargetLowering::LowerATOMIC_LOAD_SUB(SDValue Op, SelectionDAG &DAG) const { auto &Subtarget = static_cast<const AArch64Subtarget &>(DAG.getSubtarget()); - if (!Subtarget.hasLSE()) + if (!Subtarget.hasLSE() && !Subtarget.outlineAtomics()) return SDValue(); // LSE has an atomic load-add instruction, but not a load-sub. @@ -9205,7 +10442,7 @@ SDValue AArch64TargetLowering::LowerATOMIC_LOAD_SUB(SDValue Op, SDValue AArch64TargetLowering::LowerATOMIC_LOAD_AND(SDValue Op, SelectionDAG &DAG) const { auto &Subtarget = static_cast<const AArch64Subtarget &>(DAG.getSubtarget()); - if (!Subtarget.hasLSE()) + if (!Subtarget.hasLSE() && !Subtarget.outlineAtomics()) return SDValue(); // LSE has an atomic load-clear instruction, but not a load-and. @@ -9306,16 +10543,17 @@ SDValue AArch64TargetLowering::LowerVSCALE(SDValue Op, /// Set the IntrinsicInfo for the `aarch64_sve_st<N>` intrinsics. 
template <unsigned NumVecs> -static bool setInfoSVEStN(AArch64TargetLowering::IntrinsicInfo &Info, - const CallInst &CI) { +static bool +setInfoSVEStN(const AArch64TargetLowering &TLI, const DataLayout &DL, + AArch64TargetLowering::IntrinsicInfo &Info, const CallInst &CI) { Info.opc = ISD::INTRINSIC_VOID; // Retrieve EC from first vector argument. - const EVT VT = EVT::getEVT(CI.getArgOperand(0)->getType()); + const EVT VT = TLI.getMemValueType(DL, CI.getArgOperand(0)->getType()); ElementCount EC = VT.getVectorElementCount(); #ifndef NDEBUG // Check the assumption that all input vectors are the same type. for (unsigned I = 0; I < NumVecs; ++I) - assert(VT == EVT::getEVT(CI.getArgOperand(I)->getType()) && + assert(VT == TLI.getMemValueType(DL, CI.getArgOperand(I)->getType()) && "Invalid type."); #endif // memVT is `NumVecs * VT`. @@ -9338,11 +10576,11 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, auto &DL = I.getModule()->getDataLayout(); switch (Intrinsic) { case Intrinsic::aarch64_sve_st2: - return setInfoSVEStN<2>(Info, I); + return setInfoSVEStN<2>(*this, DL, Info, I); case Intrinsic::aarch64_sve_st3: - return setInfoSVEStN<3>(Info, I); + return setInfoSVEStN<3>(*this, DL, Info, I); case Intrinsic::aarch64_sve_st4: - return setInfoSVEStN<4>(Info, I); + return setInfoSVEStN<4>(*this, DL, Info, I); case Intrinsic::aarch64_neon_ld2: case Intrinsic::aarch64_neon_ld3: case Intrinsic::aarch64_neon_ld4: @@ -9498,15 +10736,15 @@ bool AArch64TargetLowering::shouldReduceLoadWidth(SDNode *Load, bool AArch64TargetLowering::isTruncateFree(Type *Ty1, Type *Ty2) const { if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy()) return false; - unsigned NumBits1 = Ty1->getPrimitiveSizeInBits(); - unsigned NumBits2 = Ty2->getPrimitiveSizeInBits(); + uint64_t NumBits1 = Ty1->getPrimitiveSizeInBits().getFixedSize(); + uint64_t NumBits2 = Ty2->getPrimitiveSizeInBits().getFixedSize(); return NumBits1 > NumBits2; } bool AArch64TargetLowering::isTruncateFree(EVT VT1, EVT VT2) const { if (VT1.isVector() || VT2.isVector() || !VT1.isInteger() || !VT2.isInteger()) return false; - unsigned NumBits1 = VT1.getSizeInBits(); - unsigned NumBits2 = VT2.getSizeInBits(); + uint64_t NumBits1 = VT1.getFixedSizeInBits(); + uint64_t NumBits2 = VT2.getFixedSizeInBits(); return NumBits1 > NumBits2; } @@ -9748,6 +10986,43 @@ bool AArch64TargetLowering::shouldSinkOperands( return true; } + case Instruction::Mul: { + bool IsProfitable = false; + for (auto &Op : I->operands()) { + // Make sure we are not already sinking this operand + if (any_of(Ops, [&](Use *U) { return U->get() == Op; })) + continue; + + ShuffleVectorInst *Shuffle = dyn_cast<ShuffleVectorInst>(Op); + if (!Shuffle || !Shuffle->isZeroEltSplat()) + continue; + + Value *ShuffleOperand = Shuffle->getOperand(0); + InsertElementInst *Insert = dyn_cast<InsertElementInst>(ShuffleOperand); + if (!Insert) + continue; + + Instruction *OperandInstr = dyn_cast<Instruction>(Insert->getOperand(1)); + if (!OperandInstr) + continue; + + ConstantInt *ElementConstant = + dyn_cast<ConstantInt>(Insert->getOperand(2)); + // Check that the insertelement is inserting into element 0 + if (!ElementConstant || ElementConstant->getZExtValue() != 0) + continue; + + unsigned Opcode = OperandInstr->getOpcode(); + if (Opcode != Instruction::SExt && Opcode != Instruction::ZExt) + continue; + + Ops.push_back(&Shuffle->getOperandUse(0)); + Ops.push_back(&Op); + IsProfitable = true; + } + + return IsProfitable; + } default: return false; } @@ -10083,11 +11358,12 @@ SDValue 
AArch64TargetLowering::LowerSVEStructLoad(unsigned Intrinsic, {Intrinsic::aarch64_sve_ld4, {4, AArch64ISD::SVE_LD4_MERGE_ZERO}}}; std::tie(N, Opcode) = IntrinsicMap[Intrinsic]; - assert(VT.getVectorElementCount().Min % N == 0 && + assert(VT.getVectorElementCount().getKnownMinValue() % N == 0 && "invalid tuple vector type!"); - EVT SplitVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(), - VT.getVectorElementCount() / N); + EVT SplitVT = + EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(), + VT.getVectorElementCount().divideCoefficientBy(N)); assert(isTypeLegal(SplitVT)); SmallVector<EVT, 5> VTs(N, SplitVT); @@ -10378,32 +11654,77 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG, return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0)); } -// Generate SUBS and CSEL for integer abs. -static SDValue performIntegerAbsCombine(SDNode *N, SelectionDAG &DAG) { - EVT VT = N->getValueType(0); +// VECREDUCE_ADD( EXTEND(v16i8_type) ) to +// VECREDUCE_ADD( DOTv16i8(v16i8_type) ) +static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG, + const AArch64Subtarget *ST) { + SDValue Op0 = N->getOperand(0); + if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32) + return SDValue(); - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - SDLoc DL(N); + if (Op0.getValueType().getVectorElementType() != MVT::i32) + return SDValue(); - // Check pattern of XOR(ADD(X,Y), Y) where Y is SRA(X, size(X)-1) - // and change it to SUB and CSEL. - if (VT.isInteger() && N->getOpcode() == ISD::XOR && - N0.getOpcode() == ISD::ADD && N0.getOperand(1) == N1 && - N1.getOpcode() == ISD::SRA && N1.getOperand(0) == N0.getOperand(0)) - if (ConstantSDNode *Y1C = dyn_cast<ConstantSDNode>(N1.getOperand(1))) - if (Y1C->getAPIntValue() == VT.getSizeInBits() - 1) { - SDValue Neg = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), - N0.getOperand(0)); - // Generate SUBS & CSEL. - SDValue Cmp = - DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::i32), - N0.getOperand(0), DAG.getConstant(0, DL, VT)); - return DAG.getNode(AArch64ISD::CSEL, DL, VT, N0.getOperand(0), Neg, - DAG.getConstant(AArch64CC::PL, DL, MVT::i32), - SDValue(Cmp.getNode(), 1)); - } - return SDValue(); + unsigned ExtOpcode = Op0.getOpcode(); + if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND) + return SDValue(); + + EVT Op0VT = Op0.getOperand(0).getValueType(); + if (Op0VT != MVT::v16i8) + return SDValue(); + + SDLoc DL(Op0); + SDValue Ones = DAG.getConstant(1, DL, Op0VT); + SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32); + auto DotIntrisic = (ExtOpcode == ISD::ZERO_EXTEND) + ? Intrinsic::aarch64_neon_udot + : Intrinsic::aarch64_neon_sdot; + SDValue Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Zeros.getValueType(), + DAG.getConstant(DotIntrisic, DL, MVT::i32), Zeros, + Ones, Op0.getOperand(0)); + return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot); +} + +// Given a ABS node, detect the following pattern: +// (ABS (SUB (EXTEND a), (EXTEND b))). +// Generates UABD/SABD instruction. +static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const AArch64Subtarget *Subtarget) { + SDValue AbsOp1 = N->getOperand(0); + SDValue Op0, Op1; + + if (AbsOp1.getOpcode() != ISD::SUB) + return SDValue(); + + Op0 = AbsOp1.getOperand(0); + Op1 = AbsOp1.getOperand(1); + + unsigned Opc0 = Op0.getOpcode(); + // Check if the operands of the sub are (zero|sign)-extended. 
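// Aside — why the ABS(SUB(EXTEND, EXTEND)) pattern is safe to turn into a
// single absolute-difference node (illustrative sketch, unsigned case only;
// helper names invented): the absolute difference of two extended values
// always fits back in the pre-extension width, so a per-lane UABD plus a
// zero-extend reproduces it exactly.
#include <cassert>
#include <cstdint>

static uint16_t abs_sub_zext(uint8_t A, uint8_t B) {
  int32_t Diff = int32_t(A) - int32_t(B);   // sub(zext a, zext b)
  return uint16_t(Diff < 0 ? -Diff : Diff); // abs; result is at most 255
}

static uint16_t uabd_then_zext(uint8_t A, uint8_t B) {
  uint8_t Abd = A > B ? uint8_t(A - B) : uint8_t(B - A); // per-lane uabd
  return Abd;                                            // zext to i16
}

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B)
      assert(abs_sub_zext(uint8_t(A), uint8_t(B)) ==
             uabd_then_zext(uint8_t(A), uint8_t(B)));
  return 0;
}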
+ if (Opc0 != Op1.getOpcode() || + (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND)) + return SDValue(); + + EVT VectorT1 = Op0.getOperand(0).getValueType(); + EVT VectorT2 = Op1.getOperand(0).getValueType(); + // Check if vectors are of same type and valid size. + uint64_t Size = VectorT1.getFixedSizeInBits(); + if (VectorT1 != VectorT2 || (Size != 64 && Size != 128)) + return SDValue(); + + // Check if vector element types are valid. + EVT VT1 = VectorT1.getVectorElementType(); + if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32) + return SDValue(); + + Op0 = Op0.getOperand(0); + Op1 = Op1.getOperand(0); + unsigned ABDOpcode = + (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD; + SDValue ABD = + DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1); + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD); } static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG, @@ -10412,10 +11733,7 @@ static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG, if (DCI.isBeforeLegalizeOps()) return SDValue(); - if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget)) - return Cmp; - - return performIntegerAbsCombine(N, DAG); + return foldVectorXorShiftIntoCmp(N, DAG, Subtarget); } SDValue @@ -10474,9 +11792,157 @@ static bool IsSVECntIntrinsic(SDValue S) { return false; } +/// Calculates what the pre-extend type is, based on the extension +/// operation node provided by \p Extend. +/// +/// In the case that \p Extend is a SIGN_EXTEND or a ZERO_EXTEND, the +/// pre-extend type is pulled directly from the operand, while other extend +/// operations need a bit more inspection to get this information. +/// +/// \param Extend The SDNode from the DAG that represents the extend operation +/// \param DAG The SelectionDAG hosting the \p Extend node +/// +/// \returns The type representing the \p Extend source type, or \p MVT::Other +/// if no valid type can be determined +static EVT calculatePreExtendType(SDValue Extend, SelectionDAG &DAG) { + switch (Extend.getOpcode()) { + case ISD::SIGN_EXTEND: + case ISD::ZERO_EXTEND: + return Extend.getOperand(0).getValueType(); + case ISD::AssertSext: + case ISD::AssertZext: + case ISD::SIGN_EXTEND_INREG: { + VTSDNode *TypeNode = dyn_cast<VTSDNode>(Extend.getOperand(1)); + if (!TypeNode) + return MVT::Other; + return TypeNode->getVT(); + } + case ISD::AND: { + ConstantSDNode *Constant = + dyn_cast<ConstantSDNode>(Extend.getOperand(1).getNode()); + if (!Constant) + return MVT::Other; + + uint32_t Mask = Constant->getZExtValue(); + + if (Mask == UCHAR_MAX) + return MVT::i8; + else if (Mask == USHRT_MAX) + return MVT::i16; + else if (Mask == UINT_MAX) + return MVT::i32; + + return MVT::Other; + } + default: + return MVT::Other; + } + + llvm_unreachable("Code path unhandled in calculatePreExtendType!"); +} + +/// Combines a dup(sext/zext) node pattern into sext/zext(dup) +/// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt +static SDValue performCommonVectorExtendCombine(SDValue VectorShuffle, + SelectionDAG &DAG) { + + ShuffleVectorSDNode *ShuffleNode = + dyn_cast<ShuffleVectorSDNode>(VectorShuffle.getNode()); + if (!ShuffleNode) + return SDValue(); + + // Ensuring the mask is zero before continuing + if (!ShuffleNode->isSplat() || ShuffleNode->getSplatIndex() != 0) + return SDValue(); + + SDValue InsertVectorElt = VectorShuffle.getOperand(0); + + if (InsertVectorElt.getOpcode() != ISD::INSERT_VECTOR_ELT) + return SDValue(); + + SDValue InsertLane = InsertVectorElt.getOperand(2); + 
ConstantSDNode *Constant = dyn_cast<ConstantSDNode>(InsertLane.getNode()); + // Ensures the insert is inserting into lane 0 + if (!Constant || Constant->getZExtValue() != 0) + return SDValue(); + + SDValue Extend = InsertVectorElt.getOperand(1); + unsigned ExtendOpcode = Extend.getOpcode(); + + bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND || + ExtendOpcode == ISD::SIGN_EXTEND_INREG || + ExtendOpcode == ISD::AssertSext; + if (!IsSExt && ExtendOpcode != ISD::ZERO_EXTEND && + ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND) + return SDValue(); + + EVT TargetType = VectorShuffle.getValueType(); + EVT PreExtendType = calculatePreExtendType(Extend, DAG); + + if ((TargetType != MVT::v8i16 && TargetType != MVT::v4i32 && + TargetType != MVT::v2i64) || + (PreExtendType == MVT::Other)) + return SDValue(); + + // Restrict valid pre-extend data type + if (PreExtendType != MVT::i8 && PreExtendType != MVT::i16 && + PreExtendType != MVT::i32) + return SDValue(); + + EVT PreExtendVT = TargetType.changeVectorElementType(PreExtendType); + + if (PreExtendVT.getVectorElementCount() != TargetType.getVectorElementCount()) + return SDValue(); + + if (TargetType.getScalarSizeInBits() != PreExtendVT.getScalarSizeInBits() * 2) + return SDValue(); + + SDLoc DL(VectorShuffle); + + SDValue InsertVectorNode = DAG.getNode( + InsertVectorElt.getOpcode(), DL, PreExtendVT, DAG.getUNDEF(PreExtendVT), + DAG.getAnyExtOrTrunc(Extend.getOperand(0), DL, PreExtendType), + DAG.getConstant(0, DL, MVT::i64)); + + std::vector<int> ShuffleMask(TargetType.getVectorElementCount().getValue()); + + SDValue VectorShuffleNode = + DAG.getVectorShuffle(PreExtendVT, DL, InsertVectorNode, + DAG.getUNDEF(PreExtendVT), ShuffleMask); + + SDValue ExtendNode = DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, + DL, TargetType, VectorShuffleNode); + + return ExtendNode; +} + +/// Combines a mul(dup(sext/zext)) node pattern into mul(sext/zext(dup)) +/// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt +static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) { + // If the value type isn't a vector, none of the operands are going to be dups + if (!Mul->getValueType(0).isVector()) + return SDValue(); + + SDValue Op0 = performCommonVectorExtendCombine(Mul->getOperand(0), DAG); + SDValue Op1 = performCommonVectorExtendCombine(Mul->getOperand(1), DAG); + + // Neither operands have been changed, don't make any further changes + if (!Op0 && !Op1) + return SDValue(); + + SDLoc DL(Mul); + return DAG.getNode(Mul->getOpcode(), DL, Mul->getValueType(0), + Op0 ? Op0 : Mul->getOperand(0), + Op1 ? Op1 : Mul->getOperand(1)); +} + static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget) { + + if (SDValue Ext = performMulVectorExtendCombine(N, DAG)) + return Ext; + if (DCI.isBeforeLegalizeOps()) return SDValue(); @@ -11011,6 +12477,9 @@ static SDValue performSVEAndCombine(SDNode *N, return DAG.getNode(Opc, DL, N->getValueType(0), And); } + if (!EnableCombineMGatherIntrinsics) + return SDValue(); + SDValue Mask = N->getOperand(1); if (!Src.hasOneUse()) @@ -11064,6 +12533,11 @@ static SDValue performANDCombine(SDNode *N, if (VT.isScalableVector()) return performSVEAndCombine(N, DCI); + // The combining code below works only for NEON vectors. In particular, it + // does not work for SVE when dealing with vectors wider than 128 bits. 
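// Aside — the dup(sext/zext) -> sext/zext(dup) rewrite performed by
// performCommonVectorExtendCombine above relies on splat and extend
// commuting; a scalar model of the two orderings (illustrative sketch):
#include <cassert>
#include <cstdint>

static void extend_then_splat(int8_t X, int16_t Out[8]) {
  int16_t Wide = int16_t(X); // scalar sxtb, then dup v0.8h
  for (int I = 0; I < 8; ++I)
    Out[I] = Wide;
}

static void splat_then_extend(int8_t X, int16_t Out[8]) {
  int8_t Narrow[8];
  for (int I = 0; I < 8; ++I) // dup v0.8b first...
    Narrow[I] = X;
  for (int I = 0; I < 8; ++I) // ...then a vector sign-extend (sshll)
    Out[I] = int16_t(Narrow[I]);
}

int main() {
  int16_t A[8], B[8];
  extend_then_splat(-5, A);
  splat_then_extend(-5, B);
  for (int I = 0; I < 8; ++I)
    assert(A[I] == B[I]);
  return 0;
}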
+ if (!(VT.is64BitVector() || VT.is128BitVector())) + return SDValue(); + BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(N->getOperand(1).getNode()); if (!BVN) @@ -11124,6 +12598,143 @@ static SDValue performSRLCombine(SDNode *N, return SDValue(); } +// Attempt to form urhadd(OpA, OpB) from +// truncate(vlshr(sub(zext(OpB), xor(zext(OpA), Ones(ElemSizeInBits))), 1)) +// or uhadd(OpA, OpB) from truncate(vlshr(add(zext(OpA), zext(OpB)), 1)). +// The original form of the first expression is +// truncate(srl(add(zext(OpB), add(zext(OpA), 1)), 1)) and the +// (OpA + OpB + 1) subexpression will have been changed to (OpB - (~OpA)). +// Before this function is called the srl will have been lowered to +// AArch64ISD::VLSHR. +// This pass can also recognize signed variants of the patterns that use sign +// extension instead of zero extension and form a srhadd(OpA, OpB) or a +// shadd(OpA, OpB) from them. +static SDValue +performVectorTruncateCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + + // Since we are looking for a right shift by a constant value of 1 and we are + // operating on types at least 16 bits in length (sign/zero extended OpA and + // OpB, which are at least 8 bits), it follows that the truncate will always + // discard the shifted-in bit and therefore the right shift will be logical + // regardless of the signedness of OpA and OpB. + SDValue Shift = N->getOperand(0); + if (Shift.getOpcode() != AArch64ISD::VLSHR) + return SDValue(); + + // Is the right shift using an immediate value of 1? + uint64_t ShiftAmount = Shift.getConstantOperandVal(1); + if (ShiftAmount != 1) + return SDValue(); + + SDValue ExtendOpA, ExtendOpB; + SDValue ShiftOp0 = Shift.getOperand(0); + unsigned ShiftOp0Opc = ShiftOp0.getOpcode(); + if (ShiftOp0Opc == ISD::SUB) { + + SDValue Xor = ShiftOp0.getOperand(1); + if (Xor.getOpcode() != ISD::XOR) + return SDValue(); + + // Is the XOR using a constant amount of all ones in the right hand side? + uint64_t C; + if (!isAllConstantBuildVector(Xor.getOperand(1), C)) + return SDValue(); + + unsigned ElemSizeInBits = VT.getScalarSizeInBits(); + APInt CAsAPInt(ElemSizeInBits, C); + if (CAsAPInt != APInt::getAllOnesValue(ElemSizeInBits)) + return SDValue(); + + ExtendOpA = Xor.getOperand(0); + ExtendOpB = ShiftOp0.getOperand(0); + } else if (ShiftOp0Opc == ISD::ADD) { + ExtendOpA = ShiftOp0.getOperand(0); + ExtendOpB = ShiftOp0.getOperand(1); + } else + return SDValue(); + + unsigned ExtendOpAOpc = ExtendOpA.getOpcode(); + unsigned ExtendOpBOpc = ExtendOpB.getOpcode(); + if (!(ExtendOpAOpc == ExtendOpBOpc && + (ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND))) + return SDValue(); + + // Is the result of the right shift being truncated to the same value type as + // the original operands, OpA and OpB? + SDValue OpA = ExtendOpA.getOperand(0); + SDValue OpB = ExtendOpB.getOperand(0); + EVT OpAVT = OpA.getValueType(); + assert(ExtendOpA.getValueType() == ExtendOpB.getValueType()); + if (!(VT == OpAVT && OpAVT == OpB.getValueType())) + return SDValue(); + + SDLoc DL(N); + bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND; + bool IsRHADD = ShiftOp0Opc == ISD::SUB; + unsigned HADDOpc = IsSignExtend + ? (IsRHADD ? AArch64ISD::SRHADD : AArch64ISD::SHADD) + : (IsRHADD ? 
AArch64ISD::URHADD : AArch64ISD::UHADD); + SDValue ResultHADD = DAG.getNode(HADDOpc, DL, VT, OpA, OpB); + + return ResultHADD; +} + +static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) { + switch (Opcode) { + case ISD::FADD: + return (FullFP16 && VT == MVT::f16) || VT == MVT::f32 || VT == MVT::f64; + case ISD::ADD: + return VT == MVT::i64; + default: + return false; + } +} + +static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) { + SDValue N0 = N->getOperand(0), N1 = N->getOperand(1); + ConstantSDNode *ConstantN1 = dyn_cast<ConstantSDNode>(N1); + + EVT VT = N->getValueType(0); + const bool FullFP16 = + static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16(); + + // Rewrite for pairwise fadd pattern + // (f32 (extract_vector_elt + // (fadd (vXf32 Other) + // (vector_shuffle (vXf32 Other) undef <1,X,...> )) 0)) + // -> + // (f32 (fadd (extract_vector_elt (vXf32 Other) 0) + // (extract_vector_elt (vXf32 Other) 1)) + if (ConstantN1 && ConstantN1->getZExtValue() == 0 && + hasPairwiseAdd(N0->getOpcode(), VT, FullFP16)) { + SDLoc DL(N0); + SDValue N00 = N0->getOperand(0); + SDValue N01 = N0->getOperand(1); + + ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(N01); + SDValue Other = N00; + + // And handle the commutative case. + if (!Shuffle) { + Shuffle = dyn_cast<ShuffleVectorSDNode>(N00); + Other = N01; + } + + if (Shuffle && Shuffle->getMaskElt(0) == 1 && + Other == Shuffle->getOperand(0)) { + return DAG.getNode(N0->getOpcode(), DL, VT, + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other, + DAG.getConstant(0, DL, MVT::i64)), + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other, + DAG.getConstant(1, DL, MVT::i64))); + } + } + + return SDValue(); +} + static SDValue performConcatVectorsCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) { @@ -11169,9 +12780,9 @@ static SDValue performConcatVectorsCombine(SDNode *N, if (DCI.isBeforeLegalizeOps()) return SDValue(); - // Optimise concat_vectors of two [us]rhadds that use extracted subvectors - // from the same original vectors. Combine these into a single [us]rhadd that - // operates on the two original vectors. Example: + // Optimise concat_vectors of two [us]rhadds or [us]hadds that use extracted + // subvectors from the same original vectors. Combine these into a single + // [us]rhadd or [us]hadd that operates on the two original vectors. Example: // (v16i8 (concat_vectors (v8i8 (urhadd (extract_subvector (v16i8 OpA, <0>), // extract_subvector (v16i8 OpB, // <0>))), @@ -11181,7 +12792,8 @@ static SDValue performConcatVectorsCombine(SDNode *N, // -> // (v16i8(urhadd(v16i8 OpA, v16i8 OpB))) if (N->getNumOperands() == 2 && N0Opc == N1Opc && - (N0Opc == AArch64ISD::URHADD || N0Opc == AArch64ISD::SRHADD)) { + (N0Opc == AArch64ISD::URHADD || N0Opc == AArch64ISD::SRHADD || + N0Opc == AArch64ISD::UHADD || N0Opc == AArch64ISD::SHADD)) { SDValue N00 = N0->getOperand(0); SDValue N01 = N0->getOperand(1); SDValue N10 = N1->getOperand(0); @@ -11486,6 +13098,43 @@ static SDValue performSetccAddFolding(SDNode *Op, SelectionDAG &DAG) { return DAG.getNode(AArch64ISD::CSEL, dl, VT, RHS, LHS, CCVal, Cmp); } +// ADD(UADDV a, UADDV b) --> UADDV(ADD a, b) +static SDValue performUADDVCombine(SDNode *N, SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + // Only scalar integer and vector types. 
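// Aside — the arithmetic that makes the rounding-halving-add recognition in
// performVectorTruncateCombine above sound (illustrative sketch): in two's
// complement ~a == -a - 1, so within the widened type
// zext(b) - (zext(a) ^ all-ones) == a + b + 1 (mod 2^W, W == 16 here), and
// the logical shift plus truncate then yield exactly the urhadd result:
#include <cassert>
#include <cstdint>

static uint8_t urhadd_reference(uint8_t A, uint8_t B) {
  return uint8_t((unsigned(A) + unsigned(B) + 1) >> 1);
}

static uint8_t urhadd_as_matched(uint8_t A, uint8_t B) {
  uint16_t NotA = uint16_t(uint16_t(A) ^ 0xFFFFu); // xor(zext a, ones)
  uint16_t Sub = uint16_t(uint16_t(B) - NotA);     // == a + b + 1 mod 2^16
  return uint8_t(Sub >> 1);                        // vlshr #1, then truncate
}

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B)
      assert(urhadd_reference(uint8_t(A), uint8_t(B)) ==
             urhadd_as_matched(uint8_t(A), uint8_t(B)));
  return 0;
}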
+ if (N->getOpcode() != ISD::ADD || !VT.isScalarInteger()) + return SDValue(); + + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); + if (LHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT || + RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT || LHS.getValueType() != VT) + return SDValue(); + + auto *LHSN1 = dyn_cast<ConstantSDNode>(LHS->getOperand(1)); + auto *RHSN1 = dyn_cast<ConstantSDNode>(RHS->getOperand(1)); + if (!LHSN1 || LHSN1 != RHSN1 || !RHSN1->isNullValue()) + return SDValue(); + + SDValue Op1 = LHS->getOperand(0); + SDValue Op2 = RHS->getOperand(0); + EVT OpVT1 = Op1.getValueType(); + EVT OpVT2 = Op2.getValueType(); + if (Op1.getOpcode() != AArch64ISD::UADDV || OpVT1 != OpVT2 || + Op2.getOpcode() != AArch64ISD::UADDV || + OpVT1.getVectorElementType() != VT) + return SDValue(); + + SDValue Val1 = Op1.getOperand(0); + SDValue Val2 = Op2.getOperand(0); + EVT ValVT = Val1->getValueType(0); + SDLoc DL(N); + SDValue AddVal = DAG.getNode(ISD::ADD, DL, ValVT, Val1, Val2); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, + DAG.getNode(AArch64ISD::UADDV, DL, ValVT, AddVal), + DAG.getConstant(0, DL, MVT::i64)); +} + // The basic add/sub long vector instructions have variants with "2" on the end // which act on the high-half of their inputs. They are normally matched by // patterns like: @@ -11539,6 +13188,16 @@ static SDValue performAddSubLongCombine(SDNode *N, return DAG.getNode(N->getOpcode(), SDLoc(N), VT, LHS, RHS); } +static SDValue performAddSubCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + // Try to change sum of two reductions. + if (SDValue Val = performUADDVCombine(N, DAG)) + return Val; + + return performAddSubLongCombine(N, DCI, DAG); +} + // Massage DAGs which we can use the high-half "long" operations on into // something isel will recognize better. E.g. // @@ -11552,8 +13211,8 @@ static SDValue tryCombineLongOpWithDup(unsigned IID, SDNode *N, if (DCI.isBeforeLegalizeOps()) return SDValue(); - SDValue LHS = N->getOperand(1); - SDValue RHS = N->getOperand(2); + SDValue LHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 0 : 1); + SDValue RHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 
1 : 2); assert(LHS.getValueType().is64BitVector() && RHS.getValueType().is64BitVector() && "unexpected shape for long operation"); @@ -11571,6 +13230,9 @@ static SDValue tryCombineLongOpWithDup(unsigned IID, SDNode *N, return SDValue(); } + if (IID == Intrinsic::not_intrinsic) + return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), LHS, RHS); + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), N->getValueType(0), N->getOperand(0), LHS, RHS); } @@ -11669,34 +13331,6 @@ static SDValue combineAcrossLanesIntrinsic(unsigned Opc, SDNode *N, DAG.getConstant(0, dl, MVT::i64)); } -static SDValue LowerSVEIntReduction(SDNode *N, unsigned Opc, - SelectionDAG &DAG) { - SDLoc dl(N); - LLVMContext &Ctx = *DAG.getContext(); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - - EVT VT = N->getValueType(0); - SDValue Pred = N->getOperand(1); - SDValue Data = N->getOperand(2); - EVT DataVT = Data.getValueType(); - - if (DataVT.getVectorElementType().isScalarInteger() && - (VT == MVT::i8 || VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64)) { - if (!TLI.isTypeLegal(DataVT)) - return SDValue(); - - EVT OutputVT = EVT::getVectorVT(Ctx, VT, - AArch64::NeonBitsPerVector / VT.getSizeInBits()); - SDValue Reduce = DAG.getNode(Opc, dl, OutputVT, Pred, Data); - SDValue Zero = DAG.getConstant(0, dl, MVT::i64); - SDValue Result = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Reduce, Zero); - - return Result; - } - - return SDValue(); -} - static SDValue LowerSVEIntrinsicIndex(SDNode *N, SelectionDAG &DAG) { SDLoc DL(N); SDValue Op1 = N->getOperand(1); @@ -11739,7 +13373,8 @@ static SDValue LowerSVEIntrinsicEXT(SDNode *N, SelectionDAG &DAG) { unsigned ElemSize = VT.getVectorElementType().getSizeInBits() / 8; unsigned ByteSize = VT.getSizeInBits().getKnownMinSize() / 8; - EVT ByteVT = EVT::getVectorVT(Ctx, MVT::i8, { ByteSize, true }); + EVT ByteVT = + EVT::getVectorVT(Ctx, MVT::i8, ElementCount::getScalable(ByteSize)); // Convert everything to the domain of EXT (i.e bytes). SDValue Op0 = DAG.getNode(ISD::BITCAST, dl, ByteVT, N->getOperand(1)); @@ -11839,6 +13474,25 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op, return DAG.getZExtOrTrunc(Res, DL, VT); } +static SDValue combineSVEReductionInt(SDNode *N, unsigned Opc, + SelectionDAG &DAG) { + SDLoc DL(N); + + SDValue Pred = N->getOperand(1); + SDValue VecToReduce = N->getOperand(2); + + // NOTE: The integer reduction's result type is not always linked to the + // operand's element type so we construct it from the intrinsic's result type. + EVT ReduceVT = getPackedSVEVectorVT(N->getValueType(0)); + SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, VecToReduce); + + // SVE reductions set the whole vector register with the first element + // containing the reduction result, which we'll now extract. + SDValue Zero = DAG.getConstant(0, DL, MVT::i64); + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce, + Zero); +} + static SDValue combineSVEReductionFP(SDNode *N, unsigned Opc, SelectionDAG &DAG) { SDLoc DL(N); @@ -11879,6 +13533,25 @@ static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc, Zero); } +// If a merged operation has no inactive lanes we can relax it to a predicated +// or unpredicated operation, which potentially allows better isel (perhaps +// using immediate forms) or relaxing register reuse requirements. 
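// For example (illustrative sketch; helper names invented): with an
// all-active governing predicate the merged and predicated forms agree on
// every lane, so nothing is lost by dropping the tie to the first operand:
#include <cstdint>

// Merged form: inactive lanes keep the value of the first data operand.
static void smin_merge(const bool Pg[4], const int32_t Zd[4],
                       const int32_t Zm[4], int32_t Out[4]) {
  for (int I = 0; I < 4; ++I)
    Out[I] = Pg[I] ? (Zd[I] < Zm[I] ? Zd[I] : Zm[I]) : Zd[I];
}

// Predicated form: inactive lanes are unspecified. When Pg is all true the
// two functions are lane-for-lane identical, which is the property the
// combine below checks for via AArch64ISD::PTRUE(all).
static void smin_pred(const bool Pg[4], const int32_t Zn[4],
                      const int32_t Zm[4], int32_t Out[4]) {
  for (int I = 0; I < 4; ++I)
    if (Pg[I])
      Out[I] = Zn[I] < Zm[I] ? Zn[I] : Zm[I];
}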
+static SDValue convertMergedOpToPredOp(SDNode *N, unsigned PredOpc, + SelectionDAG &DAG) { + assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Expected intrinsic!"); + assert(N->getNumOperands() == 4 && "Expected 3 operand intrinsic!"); + SDValue Pg = N->getOperand(1); + + // ISD way to specify an all active predicate. + if ((Pg.getOpcode() == AArch64ISD::PTRUE) && + (Pg.getConstantOperandVal(0) == AArch64SVEPredPattern::all)) + return DAG.getNode(PredOpc, SDLoc(N), N->getValueType(0), Pg, + N->getOperand(2), N->getOperand(3)); + + // FUTURE: SplatVector(true) + return SDValue(); +} + static SDValue performIntrinsicCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget *Subtarget) { @@ -11933,20 +13606,28 @@ static SDValue performIntrinsicCombine(SDNode *N, case Intrinsic::aarch64_crc32h: case Intrinsic::aarch64_crc32ch: return tryCombineCRC32(0xffff, N, DAG); + case Intrinsic::aarch64_sve_saddv: + // There is no i64 version of SADDV because the sign is irrelevant. + if (N->getOperand(2)->getValueType(0).getVectorElementType() == MVT::i64) + return combineSVEReductionInt(N, AArch64ISD::UADDV_PRED, DAG); + else + return combineSVEReductionInt(N, AArch64ISD::SADDV_PRED, DAG); + case Intrinsic::aarch64_sve_uaddv: + return combineSVEReductionInt(N, AArch64ISD::UADDV_PRED, DAG); case Intrinsic::aarch64_sve_smaxv: - return LowerSVEIntReduction(N, AArch64ISD::SMAXV_PRED, DAG); + return combineSVEReductionInt(N, AArch64ISD::SMAXV_PRED, DAG); case Intrinsic::aarch64_sve_umaxv: - return LowerSVEIntReduction(N, AArch64ISD::UMAXV_PRED, DAG); + return combineSVEReductionInt(N, AArch64ISD::UMAXV_PRED, DAG); case Intrinsic::aarch64_sve_sminv: - return LowerSVEIntReduction(N, AArch64ISD::SMINV_PRED, DAG); + return combineSVEReductionInt(N, AArch64ISD::SMINV_PRED, DAG); case Intrinsic::aarch64_sve_uminv: - return LowerSVEIntReduction(N, AArch64ISD::UMINV_PRED, DAG); + return combineSVEReductionInt(N, AArch64ISD::UMINV_PRED, DAG); case Intrinsic::aarch64_sve_orv: - return LowerSVEIntReduction(N, AArch64ISD::ORV_PRED, DAG); + return combineSVEReductionInt(N, AArch64ISD::ORV_PRED, DAG); case Intrinsic::aarch64_sve_eorv: - return LowerSVEIntReduction(N, AArch64ISD::EORV_PRED, DAG); + return combineSVEReductionInt(N, AArch64ISD::EORV_PRED, DAG); case Intrinsic::aarch64_sve_andv: - return LowerSVEIntReduction(N, AArch64ISD::ANDV_PRED, DAG); + return combineSVEReductionInt(N, AArch64ISD::ANDV_PRED, DAG); case Intrinsic::aarch64_sve_index: return LowerSVEIntrinsicIndex(N, DAG); case Intrinsic::aarch64_sve_dup: @@ -11957,26 +13638,19 @@ static SDValue performIntrinsicCombine(SDNode *N, case Intrinsic::aarch64_sve_ext: return LowerSVEIntrinsicEXT(N, DAG); case Intrinsic::aarch64_sve_smin: - return DAG.getNode(AArch64ISD::SMIN_MERGE_OP1, SDLoc(N), N->getValueType(0), - N->getOperand(1), N->getOperand(2), N->getOperand(3)); + return convertMergedOpToPredOp(N, AArch64ISD::SMIN_PRED, DAG); case Intrinsic::aarch64_sve_umin: - return DAG.getNode(AArch64ISD::UMIN_MERGE_OP1, SDLoc(N), N->getValueType(0), - N->getOperand(1), N->getOperand(2), N->getOperand(3)); + return convertMergedOpToPredOp(N, AArch64ISD::UMIN_PRED, DAG); case Intrinsic::aarch64_sve_smax: - return DAG.getNode(AArch64ISD::SMAX_MERGE_OP1, SDLoc(N), N->getValueType(0), - N->getOperand(1), N->getOperand(2), N->getOperand(3)); + return convertMergedOpToPredOp(N, AArch64ISD::SMAX_PRED, DAG); case Intrinsic::aarch64_sve_umax: - return DAG.getNode(AArch64ISD::UMAX_MERGE_OP1, SDLoc(N), N->getValueType(0), - N->getOperand(1), 
N->getOperand(2), N->getOperand(3)); + return convertMergedOpToPredOp(N, AArch64ISD::UMAX_PRED, DAG); case Intrinsic::aarch64_sve_lsl: - return DAG.getNode(AArch64ISD::SHL_MERGE_OP1, SDLoc(N), N->getValueType(0), - N->getOperand(1), N->getOperand(2), N->getOperand(3)); + return convertMergedOpToPredOp(N, AArch64ISD::SHL_PRED, DAG); case Intrinsic::aarch64_sve_lsr: - return DAG.getNode(AArch64ISD::SRL_MERGE_OP1, SDLoc(N), N->getValueType(0), - N->getOperand(1), N->getOperand(2), N->getOperand(3)); + return convertMergedOpToPredOp(N, AArch64ISD::SRL_PRED, DAG); case Intrinsic::aarch64_sve_asr: - return DAG.getNode(AArch64ISD::SRA_MERGE_OP1, SDLoc(N), N->getValueType(0), - N->getOperand(1), N->getOperand(2), N->getOperand(3)); + return convertMergedOpToPredOp(N, AArch64ISD::SRA_PRED, DAG); case Intrinsic::aarch64_sve_cmphs: if (!N->getOperand(2).getValueType().isFloatingPoint()) return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N), @@ -12069,18 +13743,15 @@ static SDValue performExtendCombine(SDNode *N, // helps the backend to decide that an sabdl2 would be useful, saving a real // extract_high operation. if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND && - N->getOperand(0).getOpcode() == ISD::INTRINSIC_WO_CHAIN) { + (N->getOperand(0).getOpcode() == AArch64ISD::UABD || + N->getOperand(0).getOpcode() == AArch64ISD::SABD)) { SDNode *ABDNode = N->getOperand(0).getNode(); - unsigned IID = getIntrinsicID(ABDNode); - if (IID == Intrinsic::aarch64_neon_sabd || - IID == Intrinsic::aarch64_neon_uabd) { - SDValue NewABD = tryCombineLongOpWithDup(IID, ABDNode, DCI, DAG); - if (!NewABD.getNode()) - return SDValue(); + SDValue NewABD = + tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG); + if (!NewABD.getNode()) + return SDValue(); - return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), - NewABD); - } + return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), NewABD); } // This is effectively a custom type legalization for AArch64. @@ -12288,6 +13959,9 @@ static SDValue performLD1ReplicateCombine(SDNode *N, SelectionDAG &DAG) { "Unsupported opcode."); SDLoc DL(N); EVT VT = N->getValueType(0); + if (VT == MVT::nxv8bf16 && + !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16()) + return SDValue(); EVT LoadVT = VT; if (VT.isFloatingPoint()) @@ -12560,6 +14234,31 @@ static SDValue splitStores(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, S->getMemOperand()->getFlags()); } +static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG) { + SDLoc DL(N); + SDValue Op0 = N->getOperand(0); + SDValue Op1 = N->getOperand(1); + EVT ResVT = N->getValueType(0); + + // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z) + if (Op0.getOpcode() == AArch64ISD::UUNPKLO) { + if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) { + SDValue X = Op0.getOperand(0).getOperand(0); + return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1); + } + } + + // uzp1(x, unpkhi(uzp1(y, z))) => uzp1(x, z) + if (Op1.getOpcode() == AArch64ISD::UUNPKHI) { + if (Op1.getOperand(0).getOpcode() == AArch64ISD::UZP1) { + SDValue Z = Op1.getOperand(0).getOperand(1); + return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z); + } + } + + return SDValue(); +} + /// Target-specific DAG combine function for post-increment LD1 (lane) and /// post-increment LD1R. 
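// In pointer terms the combine below folds a separate base-register update
// into the load itself, i.e. a post-indexed access (illustrative sketch):
#include <cstdint>

static int32_t load_post_increment(const int32_t *&Base) {
  int32_t Val = *Base; // ld1 {v0.s}[0], [x0]
  Base += 1;           // add x0, x0, #4  -- merged into "[x0], #4"
  return Val;
}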
static SDValue performPostLD1Combine(SDNode *N, @@ -12698,6 +14397,54 @@ static SDValue performSTORECombine(SDNode *N, return SDValue(); } +static SDValue performMaskedGatherScatterCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N); + assert(MGS && "Can only combine gather load or scatter store nodes"); + + SDLoc DL(MGS); + SDValue Chain = MGS->getChain(); + SDValue Scale = MGS->getScale(); + SDValue Index = MGS->getIndex(); + SDValue Mask = MGS->getMask(); + SDValue BasePtr = MGS->getBasePtr(); + ISD::MemIndexType IndexType = MGS->getIndexType(); + + EVT IdxVT = Index.getValueType(); + + if (DCI.isBeforeLegalize()) { + // SVE gather/scatter requires indices of i32/i64. Promote anything smaller + // prior to legalisation so the result can be split if required. + if ((IdxVT.getVectorElementType() == MVT::i8) || + (IdxVT.getVectorElementType() == MVT::i16)) { + EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32); + if (MGS->isIndexSigned()) + Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index); + else + Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index); + + if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) { + SDValue PassThru = MGT->getPassThru(); + SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale }; + return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other), + PassThru.getValueType(), DL, Ops, + MGT->getMemOperand(), + MGT->getIndexType(), MGT->getExtensionType()); + } else { + auto *MSC = cast<MaskedScatterSDNode>(MGS); + SDValue Data = MSC->getValue(); + SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale }; + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), + MSC->getMemoryVT(), DL, Ops, + MSC->getMemOperand(), IndexType, + MSC->isTruncatingStore()); + } + } + } + + return SDValue(); +} /// Target-specific DAG combine function for NEON load/store intrinsics /// to merge base address updates. @@ -13669,9 +15416,6 @@ static SDValue performGatherLoadCombine(SDNode *N, SelectionDAG &DAG, static SDValue performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) { - if (DCI.isBeforeLegalizeOps()) - return SDValue(); - SDLoc DL(N); SDValue Src = N->getOperand(0); unsigned Opc = Src->getOpcode(); @@ -13698,9 +15442,7 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, assert((EltTy == MVT::i8 || EltTy == MVT::i16 || EltTy == MVT::i32) && "Sign extending from an invalid type"); - EVT ExtVT = EVT::getVectorVT(*DAG.getContext(), - VT.getVectorElementType(), - VT.getVectorElementCount() * 2); + EVT ExtVT = VT.getDoubleNumVectorElementsVT(*DAG.getContext()); SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, ExtOp.getValueType(), ExtOp, DAG.getValueType(ExtVT)); @@ -13708,6 +15450,12 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, return DAG.getNode(SOpc, DL, N->getValueType(0), Ext); } + if (DCI.isBeforeLegalizeOps()) + return SDValue(); + + if (!EnableCombineMGatherIntrinsics) + return SDValue(); + // SVE load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates // for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes. 
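// Aside — what SIGN_EXTEND_INREG means per lane (illustrative sketch): the
// low bits are re-signed in place, so a 32-bit lane carrying an i8 payload
// becomes an ordinary sign extension, which a sign-extending gather
// (GLD1S*) already performs as part of the load:
#include <cassert>
#include <cstdint>

static int32_t sext_inreg_i8(int32_t X) {
  return int32_t(int8_t(X & 0xFF)); // equivalent to sxtb on the lane
}

int main() {
  assert(sext_inreg_i8(0x000000FF) == -1);
  assert(sext_inreg_i8(0x0000007F) == 127);
  assert(sext_inreg_i8(0x12345681) == -127); // upper bits are discarded
  return 0;
}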
unsigned NewOpc; @@ -13847,9 +15595,11 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, default: LLVM_DEBUG(dbgs() << "Custom combining: skipping\n"); break; + case ISD::ABS: + return performABSCombine(N, DAG, DCI, Subtarget); case ISD::ADD: case ISD::SUB: - return performAddSubLongCombine(N, DCI, DAG); + return performAddSubCombine(N, DCI, DAG); case ISD::XOR: return performXorCombine(N, DAG, DCI, Subtarget); case ISD::MUL: @@ -13876,6 +15626,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, return performExtendCombine(N, DCI, DAG); case ISD::SIGN_EXTEND_INREG: return performSignExtendInRegCombine(N, DCI, DAG); + case ISD::TRUNCATE: + return performVectorTruncateCombine(N, DCI, DAG); case ISD::CONCAT_VECTORS: return performConcatVectorsCombine(N, DCI, DAG); case ISD::SELECT: @@ -13888,6 +15640,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, break; case ISD::STORE: return performSTORECombine(N, DCI, DAG, Subtarget); + case ISD::MGATHER: + case ISD::MSCATTER: + return performMaskedGatherScatterCombine(N, DCI, DAG); case AArch64ISD::BRCOND: return performBRCONDCombine(N, DCI, DAG); case AArch64ISD::TBNZ: @@ -13899,8 +15654,14 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, return performPostLD1Combine(N, DCI, false); case AArch64ISD::NVCAST: return performNVCASTCombine(N); + case AArch64ISD::UZP1: + return performUzpCombine(N, DAG); case ISD::INSERT_VECTOR_ELT: return performPostLD1Combine(N, DCI, true); + case ISD::EXTRACT_VECTOR_ELT: + return performExtractVectorEltCombine(N, DAG); + case ISD::VECREDUCE_ADD: + return performVecReduceAddCombine(N, DCI.DAG, Subtarget); case ISD::INTRINSIC_VOID: case ISD::INTRINSIC_W_CHAIN: switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) { @@ -14049,10 +15810,10 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, uint64_t IdxConst = cast<ConstantSDNode>(Idx)->getZExtValue(); EVT ResVT = N->getValueType(0); - uint64_t NumLanes = ResVT.getVectorElementCount().Min; + uint64_t NumLanes = ResVT.getVectorElementCount().getKnownMinValue(); + SDValue ExtIdx = DAG.getVectorIdxConstant(IdxConst * NumLanes, DL); SDValue Val = - DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Src1, - DAG.getConstant(IdxConst * NumLanes, DL, MVT::i32)); + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Src1, ExtIdx); return DAG.getMergeValues({Val, Chain}, DL); } case Intrinsic::aarch64_sve_tuple_set: { @@ -14063,10 +15824,11 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, SDValue Vec = N->getOperand(4); EVT TupleVT = Tuple.getValueType(); - uint64_t TupleLanes = TupleVT.getVectorElementCount().Min; + uint64_t TupleLanes = TupleVT.getVectorElementCount().getKnownMinValue(); uint64_t IdxConst = cast<ConstantSDNode>(Idx)->getZExtValue(); - uint64_t NumLanes = Vec.getValueType().getVectorElementCount().Min; + uint64_t NumLanes = + Vec.getValueType().getVectorElementCount().getKnownMinValue(); if ((TupleLanes % NumLanes) != 0) report_fatal_error("invalid tuple vector!"); @@ -14078,9 +15840,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, if (I == IdxConst) Opnds.push_back(Vec); else { - Opnds.push_back( - DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Vec.getValueType(), Tuple, - DAG.getConstant(I * NumLanes, DL, MVT::i32))); + SDValue ExtIdx = DAG.getVectorIdxConstant(I * NumLanes, DL); + Opnds.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, + Vec.getValueType(), Tuple, ExtIdx)); } } SDValue Concat = @@ -14302,7 +16064,7 @@ void 
AArch64TargetLowering::ReplaceExtractSubVectorResults( ElementCount ResEC = VT.getVectorElementCount(); - if (InVT.getVectorElementCount().Min != (ResEC.Min * 2)) + if (InVT.getVectorElementCount() != (ResEC * 2)) return; auto *CIndex = dyn_cast<ConstantSDNode>(N->getOperand(1)); @@ -14310,7 +16072,7 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults( return; unsigned Index = CIndex->getZExtValue(); - if ((Index != 0) && (Index != ResEC.Min)) + if ((Index != 0) && (Index != ResEC.getKnownMinValue())) return; unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI; @@ -14345,7 +16107,7 @@ static void ReplaceCMP_SWAP_128Results(SDNode *N, assert(N->getValueType(0) == MVT::i128 && "AtomicCmpSwap on types less than 128 should be legal"); - if (Subtarget->hasLSE()) { + if (Subtarget->hasLSE() || Subtarget->outlineAtomics()) { // LSE has a 128-bit compare and swap (CASP), but i128 is not a legal type, // so lower it here, wrapped in REG_SEQUENCE and EXTRACT_SUBREG. SDValue Ops[] = { @@ -14426,7 +16188,8 @@ void AArch64TargetLowering::ReplaceNodeResults( return; case ISD::CTPOP: - Results.push_back(LowerCTPOP(SDValue(N, 0), DAG)); + if (SDValue Result = LowerCTPOP(SDValue(N, 0), DAG)) + Results.push_back(Result); return; case AArch64ISD::SADDV: ReplaceReductionResults(N, Results, DAG, ISD::ADD, AArch64ISD::SADDV); @@ -14574,14 +16337,30 @@ AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { // Nand not supported in LSE. if (AI->getOperation() == AtomicRMWInst::Nand) return AtomicExpansionKind::LLSC; // Leave 128 bits to LLSC. - return (Subtarget->hasLSE() && Size < 128) ? AtomicExpansionKind::None : AtomicExpansionKind::LLSC; + if (Subtarget->hasLSE() && Size < 128) + return AtomicExpansionKind::None; + if (Subtarget->outlineAtomics() && Size < 128) { + // [U]Min/[U]Max RWM atomics are used in __sync_fetch_ libcalls so far. + // Don't outline them unless + // (1) high level <atomic> support approved: + // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf + // (2) low level libgcc and compiler-rt support implemented by: + // min/max outline atomics helpers + if (AI->getOperation() != AtomicRMWInst::Min && + AI->getOperation() != AtomicRMWInst::Max && + AI->getOperation() != AtomicRMWInst::UMin && + AI->getOperation() != AtomicRMWInst::UMax) { + return AtomicExpansionKind::None; + } + } + return AtomicExpansionKind::LLSC; } TargetLowering::AtomicExpansionKind AArch64TargetLowering::shouldExpandAtomicCmpXchgInIR( AtomicCmpXchgInst *AI) const { // If subtarget has LSE, leave cmpxchg intact for codegen. - if (Subtarget->hasLSE()) + if (Subtarget->hasLSE() || Subtarget->outlineAtomics()) return AtomicExpansionKind::None; // At -O0, fast-regalloc cannot cope with the live vregs necessary to // implement cmpxchg without spilling. 
If the address being exchanged is also @@ -14676,7 +16455,14 @@ Value *AArch64TargetLowering::emitStoreConditional(IRBuilder<> &Builder, bool AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters( Type *Ty, CallingConv::ID CallConv, bool isVarArg) const { - return Ty->isArrayTy(); + if (Ty->isArrayTy()) + return true; + + const TypeSize &TySize = Ty->getPrimitiveSizeInBits(); + if (TySize.isScalable() && TySize.getKnownMinSize() > 128) + return true; + + return false; } bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &, @@ -14909,6 +16695,11 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const { if (isa<ScalableVectorType>(Inst.getOperand(i)->getType())) return true; + if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) { + if (isa<ScalableVectorType>(AI->getAllocatedType())) + return true; + } + return false; } @@ -15080,6 +16871,92 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE( Store->isTruncatingStore()); } +SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE( + SDValue Op, SelectionDAG &DAG) const { + SDLoc dl(Op); + EVT VT = Op.getValueType(); + EVT EltVT = VT.getVectorElementType(); + + bool Signed = Op.getOpcode() == ISD::SDIV; + unsigned PredOpcode = Signed ? AArch64ISD::SDIV_PRED : AArch64ISD::UDIV_PRED; + + // Scalable vector i32/i64 DIV is supported. + if (EltVT == MVT::i32 || EltVT == MVT::i64) + return LowerToPredicatedOp(Op, DAG, PredOpcode, /*OverrideNEON=*/true); + + // Scalable vector i8/i16 DIV is not supported. Promote it to i32. + EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT); + EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext()); + EVT FixedWidenedVT = HalfVT.widenIntegerVectorElementType(*DAG.getContext()); + EVT ScalableWidenedVT = getContainerForFixedLengthVector(DAG, FixedWidenedVT); + + // Convert the operands to scalable vectors. + SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0)); + SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1)); + + // Extend the scalable operands. + unsigned UnpkLo = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO; + unsigned UnpkHi = Signed ? AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI; + SDValue Op0Lo = DAG.getNode(UnpkLo, dl, ScalableWidenedVT, Op0); + SDValue Op1Lo = DAG.getNode(UnpkLo, dl, ScalableWidenedVT, Op1); + SDValue Op0Hi = DAG.getNode(UnpkHi, dl, ScalableWidenedVT, Op0); + SDValue Op1Hi = DAG.getNode(UnpkHi, dl, ScalableWidenedVT, Op1); + + // Convert back to fixed vectors so the DIV can be further lowered. + Op0Lo = convertFromScalableVector(DAG, FixedWidenedVT, Op0Lo); + Op1Lo = convertFromScalableVector(DAG, FixedWidenedVT, Op1Lo); + Op0Hi = convertFromScalableVector(DAG, FixedWidenedVT, Op0Hi); + Op1Hi = convertFromScalableVector(DAG, FixedWidenedVT, Op1Hi); + SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, FixedWidenedVT, + Op0Lo, Op1Lo); + SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, FixedWidenedVT, + Op0Hi, Op1Hi); + + // Convert again to scalable vectors to truncate. 
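// Aside — why widening the i8/i16 division is lossless (illustrative
// sketch): the quotient of two i8 values fits back in i8 (with -128 / -1
// wrapping the same way the target's sdiv result truncates), so dividing in
// a wider type and narrowing changes nothing:
#include <cassert>
#include <cstdint>

static int8_t sdiv_i8_via_i32(int8_t A, int8_t B) {
  int32_t Wide = int32_t(A) / int32_t(B); // sdiv on the unpacked wide lanes
  return int8_t(Wide);                    // uzp1 narrows the results back
}

int main() {
  assert(sdiv_i8_via_i32(100, 7) == 14);
  assert(sdiv_i8_via_i32(-100, 7) == -14);
  assert(sdiv_i8_via_i32(-128, -1) == -128); // 128 wraps to -128 on truncate
  return 0;
}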
+ ResultLo = convertToScalableVector(DAG, ScalableWidenedVT, ResultLo); + ResultHi = convertToScalableVector(DAG, ScalableWidenedVT, ResultHi); + SDValue ScalableResult = DAG.getNode(AArch64ISD::UZP1, dl, ContainerVT, + ResultLo, ResultHi); + + return convertFromScalableVector(DAG, VT, ScalableResult); +} + +SDValue AArch64TargetLowering::LowerFixedLengthVectorIntExtendToSVE( + SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + assert(VT.isFixedLengthVector() && "Expected fixed length vector type!"); + + SDLoc DL(Op); + SDValue Val = Op.getOperand(0); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, Val.getValueType()); + Val = convertToScalableVector(DAG, ContainerVT, Val); + + bool Signed = Op.getOpcode() == ISD::SIGN_EXTEND; + unsigned ExtendOpc = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO; + + // Repeatedly unpack Val until the result is of the desired element type. + switch (ContainerVT.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("unimplemented container type"); + case MVT::nxv16i8: + Val = DAG.getNode(ExtendOpc, DL, MVT::nxv8i16, Val); + if (VT.getVectorElementType() == MVT::i16) + break; + LLVM_FALLTHROUGH; + case MVT::nxv8i16: + Val = DAG.getNode(ExtendOpc, DL, MVT::nxv4i32, Val); + if (VT.getVectorElementType() == MVT::i32) + break; + LLVM_FALLTHROUGH; + case MVT::nxv4i32: + Val = DAG.getNode(ExtendOpc, DL, MVT::nxv2i64, Val); + assert(VT.getVectorElementType() == MVT::i64 && "Unexpected element type!"); + break; + } + + return convertFromScalableVector(DAG, VT, Val); +} + SDValue AArch64TargetLowering::LowerFixedLengthVectorTruncateToSVE( SDValue Op, SelectionDAG &DAG) const { EVT VT = Op.getValueType(); @@ -15116,17 +16993,21 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorTruncateToSVE( return convertFromScalableVector(DAG, VT, Val); } +// Convert vector operation 'Op' to an equivalent predicated operation whereby +// the original operation's type is used to construct a suitable predicate. +// NOTE: The results for inactive lanes are undefined. SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op, SelectionDAG &DAG, - unsigned NewOp) const { + unsigned NewOp, + bool OverrideNEON) const { EVT VT = Op.getValueType(); SDLoc DL(Op); auto Pg = getPredicateForVector(DAG, DL, VT); - if (useSVEForFixedLengthVectorVT(VT)) { + if (useSVEForFixedLengthVectorVT(VT, OverrideNEON)) { EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT); - // Create list of operands by convereting existing ones to scalable types. + // Create list of operands by converting existing ones to scalable types. 
SmallVector<SDValue, 4> Operands = {Pg}; for (const SDValue &V : Op->op_values()) { if (isa<CondCodeSDNode>(V)) { @@ -15134,11 +17015,21 @@ SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op, continue; } - assert(useSVEForFixedLengthVectorVT(V.getValueType()) && + if (const VTSDNode *VTNode = dyn_cast<VTSDNode>(V)) { + EVT VTArg = VTNode->getVT().getVectorElementType(); + EVT NewVTArg = ContainerVT.changeVectorElementType(VTArg); + Operands.push_back(DAG.getValueType(NewVTArg)); + continue; + } + + assert(useSVEForFixedLengthVectorVT(V.getValueType(), OverrideNEON) && "Only fixed length vectors are supported!"); Operands.push_back(convertToScalableVector(DAG, ContainerVT, V)); } + if (isMergePassthruOpcode(NewOp)) + Operands.push_back(DAG.getUNDEF(ContainerVT)); + auto ScalableRes = DAG.getNode(NewOp, DL, ContainerVT, Operands); return convertFromScalableVector(DAG, VT, ScalableRes); } @@ -15147,10 +17038,228 @@ SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op, SmallVector<SDValue, 4> Operands = {Pg}; for (const SDValue &V : Op->op_values()) { - assert((isa<CondCodeSDNode>(V) || V.getValueType().isScalableVector()) && + assert((!V.getValueType().isVector() || + V.getValueType().isScalableVector()) && "Only scalable vectors are supported!"); Operands.push_back(V); } + if (isMergePassthruOpcode(NewOp)) + Operands.push_back(DAG.getUNDEF(VT)); + return DAG.getNode(NewOp, DL, VT, Operands); } + +// If a fixed length vector operation has no side effects when applied to +// undefined elements, we can safely use scalable vectors to perform the same +// operation without needing to worry about predication. +SDValue AArch64TargetLowering::LowerToScalableOp(SDValue Op, + SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + assert(useSVEForFixedLengthVectorVT(VT) && + "Only expected to lower fixed length vector operation!"); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT); + + // Create list of operands by converting existing ones to scalable types. + SmallVector<SDValue, 4> Ops; + for (const SDValue &V : Op->op_values()) { + assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!"); + + // Pass through non-vector operands. + if (!V.getValueType().isVector()) { + Ops.push_back(V); + continue; + } + + // "cast" fixed length vector to a scalable vector. + assert(useSVEForFixedLengthVectorVT(V.getValueType()) && + "Only fixed length vectors are supported!"); + Ops.push_back(convertToScalableVector(DAG, ContainerVT, V)); + } + + auto ScalableRes = DAG.getNode(Op.getOpcode(), SDLoc(Op), ContainerVT, Ops); + return convertFromScalableVector(DAG, VT, ScalableRes); +} + +SDValue AArch64TargetLowering::LowerVECREDUCE_SEQ_FADD(SDValue ScalarOp, + SelectionDAG &DAG) const { + SDLoc DL(ScalarOp); + SDValue AccOp = ScalarOp.getOperand(0); + SDValue VecOp = ScalarOp.getOperand(1); + EVT SrcVT = VecOp.getValueType(); + EVT ResVT = SrcVT.getVectorElementType(); + + EVT ContainerVT = SrcVT; + if (SrcVT.isFixedLengthVector()) { + ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT); + VecOp = convertToScalableVector(DAG, ContainerVT, VecOp); + } + + SDValue Pg = getPredicateForVector(DAG, DL, SrcVT); + SDValue Zero = DAG.getConstant(0, DL, MVT::i64); + + // Convert operands to Scalable. + AccOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT, + DAG.getUNDEF(ContainerVT), AccOp, Zero); + + // Perform reduction. 
+  SDValue Rdx = DAG.getNode(AArch64ISD::FADDA_PRED, DL, ContainerVT,
+                            Pg, AccOp, VecOp);
+
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Rdx, Zero);
+}
+
+SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
+                                                       SelectionDAG &DAG) const {
+  SDLoc DL(ReduceOp);
+  SDValue Op = ReduceOp.getOperand(0);
+  EVT OpVT = Op.getValueType();
+  EVT VT = ReduceOp.getValueType();
+
+  if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1)
+    return SDValue();
+
+  SDValue Pg = getPredicateForVector(DAG, DL, OpVT);
+
+  switch (ReduceOp.getOpcode()) {
+  default:
+    return SDValue();
+  case ISD::VECREDUCE_OR:
+    return getPTest(DAG, VT, Pg, Op, AArch64CC::ANY_ACTIVE);
+  case ISD::VECREDUCE_AND: {
+    Op = DAG.getNode(ISD::XOR, DL, OpVT, Op, Pg);
+    return getPTest(DAG, VT, Pg, Op, AArch64CC::NONE_ACTIVE);
+  }
+  case ISD::VECREDUCE_XOR: {
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64);
+    SDValue Cntp =
+        DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64, ID, Pg, Op);
+    return DAG.getAnyExtOrTrunc(Cntp, DL, VT);
+  }
+  }
+
+  return SDValue();
+}
+
+SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
+                                                   SDValue ScalarOp,
+                                                   SelectionDAG &DAG) const {
+  SDLoc DL(ScalarOp);
+  SDValue VecOp = ScalarOp.getOperand(0);
+  EVT SrcVT = VecOp.getValueType();
+
+  if (useSVEForFixedLengthVectorVT(SrcVT, true)) {
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
+    VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
+  }
+
+  // UADDV always returns an i64 result.
+  EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64 :
+                                                   SrcVT.getVectorElementType();
+  EVT RdxVT = SrcVT;
+  if (SrcVT.isFixedLengthVector() || Opcode == AArch64ISD::UADDV_PRED)
+    RdxVT = getPackedSVEVectorVT(ResVT);
+
+  SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
+  SDValue Rdx = DAG.getNode(Opcode, DL, RdxVT, Pg, VecOp);
+  SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT,
+                            Rdx, DAG.getConstant(0, DL, MVT::i64));
+
+  // The VEC_REDUCE nodes expect an element size result.
+  if (ResVT != ScalarOp.getValueType())
+    Res = DAG.getAnyExtOrTrunc(Res, DL, ScalarOp.getValueType());
+
+  return Res;
+}
+
+SDValue
+AArch64TargetLowering::LowerFixedLengthVectorSelectToSVE(SDValue Op,
+                                                         SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+  SDLoc DL(Op);
+
+  EVT InVT = Op.getOperand(1).getValueType();
+  EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
+  SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(1));
+  SDValue Op2 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(2));
+
+  // Convert the mask to a predicate (NOTE: We don't need to worry about
+  // inactive lanes since VSELECT is safe when given undefined elements).
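+  // The fixed length mask arrives as an integer vector whose lanes hold all
+  // ones or all zeros, so after widening it into its scalable container it
+  // only needs truncating to an i1 predicate type to drive the VSELECT.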
+  EVT MaskVT = Op.getOperand(0).getValueType();
+  EVT MaskContainerVT = getContainerForFixedLengthVector(DAG, MaskVT);
+  auto Mask = convertToScalableVector(DAG, MaskContainerVT, Op.getOperand(0));
+  Mask = DAG.getNode(ISD::TRUNCATE, DL,
+                     MaskContainerVT.changeVectorElementType(MVT::i1), Mask);
+
+  auto ScalableRes = DAG.getNode(ISD::VSELECT, DL, ContainerVT,
+                                 Mask, Op1, Op2);
+
+  return convertFromScalableVector(DAG, VT, ScalableRes);
+}
+
+SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
+    SDValue Op, SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT InVT = Op.getOperand(0).getValueType();
+  EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
+
+  assert(useSVEForFixedLengthVectorVT(InVT) &&
+         "Only expected to lower fixed length vector operation!");
+  assert(Op.getValueType() == InVT.changeTypeToInteger() &&
+         "Expected integer result of the same bit length as the inputs!");
+
+  // Expand floating point vector comparisons.
+  if (InVT.isFloatingPoint())
+    return SDValue();
+
+  auto Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
+  auto Op2 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1));
+  auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
+
+  EVT CmpVT = Pg.getValueType();
+  auto Cmp = DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, CmpVT,
+                         {Pg, Op1, Op2, Op.getOperand(2)});
+
+  EVT PromoteVT = ContainerVT.changeTypeToInteger();
+  auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, PromoteVT, InVT);
+  return convertFromScalableVector(DAG, Op.getValueType(), Promote);
}

+SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT InVT = Op.getValueType();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  (void)TLI;
+
+  assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
+         InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
+         "Only expect to cast between legal scalable vector types!");
+  assert((VT.getVectorElementType() == MVT::i1) ==
+             (InVT.getVectorElementType() == MVT::i1) &&
+         "Cannot cast between data and predicate scalable vector types!");
+
+  if (InVT == VT)
+    return Op;
+
+  if (VT.getVectorElementType() == MVT::i1)
+    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+  EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
+  EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
+  assert((VT == PackedVT || InVT == PackedInVT) &&
+         "Cannot cast between unpacked scalable vector types!");
+
+  // Pack input if required.
+  if (InVT != PackedInVT)
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
+
+  Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+
+  // Unpack result if required.
+  if (VT != PackedVT)
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+  return Op;
+}
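
The pack/unpack rules in getSVESafeBitCast are easier to follow with a concrete trace. Below is a minimal standalone sketch in plain C++ with no LLVM dependencies; MiniVT, packedVT and traceSafeBitCast are invented names for illustration (packedVT mirrors what getPackedSVEVectorVT does with the 128-bit SVEBitsPerBlock granule), and the model covers only the data-vector path, not predicates.

#include <cassert>
#include <cstdio>

// Toy model of a scalable vector type: a minimum lane count plus an element
// width in bits. A "packed" SVE type fills a whole 128-bit block, so casts
// between packed types are plain bit reinterpretations.
struct MiniVT {
  unsigned MinElts, EltBits;
  bool operator==(const MiniVT &O) const {
    return MinElts == O.MinElts && EltBits == O.EltBits;
  }
};

// Analogue of getPackedSVEVectorVT: the type of this element width that
// fills all 128 bits of each block.
static MiniVT packedVT(unsigned EltBits) { return {128 / EltBits, EltBits}; }

static void show(const char *Step, MiniVT From, MiniVT To) {
  // The model ignores the float/int distinction; only lane shape matters.
  std::printf("%-17s nxv%u x i%u -> nxv%u x i%u\n", Step, From.MinElts,
              From.EltBits, To.MinElts, To.EltBits);
}

// Mirrors the data-vector path above: reinterpret the input as its packed
// form if needed, bitcast packed-to-packed, then reinterpret the result
// back to an unpacked form if needed.
static void traceSafeBitCast(MiniVT In, MiniVT Out) {
  MiniVT PackedIn = packedVT(In.EltBits);
  MiniVT PackedOut = packedVT(Out.EltBits);
  assert((In == PackedIn || Out == PackedOut) &&
         "cannot cast between two unpacked scalable vector types");
  if (!(In == PackedIn))
    show("REINTERPRET_CAST", In, PackedIn);   // pack the input
  show("BITCAST", PackedIn, PackedOut);       // packed-to-packed is safe
  if (!(Out == PackedOut))
    show("REINTERPRET_CAST", PackedOut, Out); // unpack the result
}

int main() {
  // Casting an unpacked nxv2f32 to nxv2i64: the input is packed to nxv4f32
  // first; nxv2i64 is already packed, so no unpack step is emitted.
  traceSafeBitCast({2, 32}, {2, 64});
}

Predicate vectors (i1 elements) never take this path; as the real code shows, they are handled up front by a single REINTERPRET_CAST.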