Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2521
1 file changed, 1969 insertions, 552 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 1be09186dc0a..e7282aad05e2 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -29,7 +29,9 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/ObjCARCUtil.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/CodeGen/Analysis.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/CodeGen/MachineFrameInfo.h" @@ -343,6 +345,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setCondCodeAction(ISD::SETUGT, VT, Expand); setCondCodeAction(ISD::SETUEQ, VT, Expand); setCondCodeAction(ISD::SETUNE, VT, Expand); + + setOperationAction(ISD::FREM, VT, Expand); + setOperationAction(ISD::FPOW, VT, Expand); + setOperationAction(ISD::FPOWI, VT, Expand); + setOperationAction(ISD::FCOS, VT, Expand); + setOperationAction(ISD::FSIN, VT, Expand); + setOperationAction(ISD::FSINCOS, VT, Expand); + setOperationAction(ISD::FEXP, VT, Expand); + setOperationAction(ISD::FEXP2, VT, Expand); + setOperationAction(ISD::FLOG, VT, Expand); + setOperationAction(ISD::FLOG2, VT, Expand); + setOperationAction(ISD::FLOG10, VT, Expand); } } @@ -458,6 +472,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Custom); setOperationAction(ISD::STRICT_FP_ROUND, MVT::f64, Custom); + setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i32, Custom); + setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i64, Custom); + setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i32, Custom); + setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i64, Custom); + // Variable arguments. 
setOperationAction(ISD::VASTART, MVT::Other, Custom); setOperationAction(ISD::VAARG, MVT::Other, Custom); @@ -604,6 +623,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FNEARBYINT, MVT::f16, Promote); setOperationAction(ISD::FRINT, MVT::f16, Promote); setOperationAction(ISD::FROUND, MVT::f16, Promote); + setOperationAction(ISD::FROUNDEVEN, MVT::f16, Promote); setOperationAction(ISD::FTRUNC, MVT::f16, Promote); setOperationAction(ISD::FMINNUM, MVT::f16, Promote); setOperationAction(ISD::FMAXNUM, MVT::f16, Promote); @@ -623,6 +643,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FABS, MVT::v4f16, Expand); setOperationAction(ISD::FNEG, MVT::v4f16, Expand); setOperationAction(ISD::FROUND, MVT::v4f16, Expand); + setOperationAction(ISD::FROUNDEVEN, MVT::v4f16, Expand); setOperationAction(ISD::FMA, MVT::v4f16, Expand); setOperationAction(ISD::SETCC, MVT::v4f16, Expand); setOperationAction(ISD::BR_CC, MVT::v4f16, Expand); @@ -647,6 +668,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FNEARBYINT, MVT::v8f16, Expand); setOperationAction(ISD::FNEG, MVT::v8f16, Expand); setOperationAction(ISD::FROUND, MVT::v8f16, Expand); + setOperationAction(ISD::FROUNDEVEN, MVT::v8f16, Expand); setOperationAction(ISD::FRINT, MVT::v8f16, Expand); setOperationAction(ISD::FSQRT, MVT::v8f16, Expand); setOperationAction(ISD::FSUB, MVT::v8f16, Expand); @@ -666,6 +688,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FRINT, Ty, Legal); setOperationAction(ISD::FTRUNC, Ty, Legal); setOperationAction(ISD::FROUND, Ty, Legal); + setOperationAction(ISD::FROUNDEVEN, Ty, Legal); setOperationAction(ISD::FMINNUM, Ty, Legal); setOperationAction(ISD::FMAXNUM, Ty, Legal); setOperationAction(ISD::FMINIMUM, Ty, Legal); @@ -683,6 +706,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FRINT, MVT::f16, Legal); setOperationAction(ISD::FTRUNC, MVT::f16, Legal); setOperationAction(ISD::FROUND, MVT::f16, Legal); + setOperationAction(ISD::FROUNDEVEN, MVT::f16, Legal); setOperationAction(ISD::FMINNUM, MVT::f16, Legal); setOperationAction(ISD::FMAXNUM, MVT::f16, Legal); setOperationAction(ISD::FMINIMUM, MVT::f16, Legal); @@ -692,6 +716,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::PREFETCH, MVT::Other, Custom); setOperationAction(ISD::FLT_ROUNDS_, MVT::i32, Custom); + setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom); setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i128, Custom); setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i32, Custom); @@ -857,15 +882,20 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::SINT_TO_FP); setTargetDAGCombine(ISD::UINT_TO_FP); + // TODO: Do the same for FP_TO_*INT_SAT. 
setTargetDAGCombine(ISD::FP_TO_SINT); setTargetDAGCombine(ISD::FP_TO_UINT); setTargetDAGCombine(ISD::FDIV); + // Try and combine setcc with csel + setTargetDAGCombine(ISD::SETCC); + setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN); setTargetDAGCombine(ISD::ANY_EXTEND); setTargetDAGCombine(ISD::ZERO_EXTEND); setTargetDAGCombine(ISD::SIGN_EXTEND); + setTargetDAGCombine(ISD::VECTOR_SPLICE); setTargetDAGCombine(ISD::SIGN_EXTEND_INREG); setTargetDAGCombine(ISD::TRUNCATE); setTargetDAGCombine(ISD::CONCAT_VECTORS); @@ -873,9 +903,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, if (Subtarget->supportsAddressTopByteIgnored()) setTargetDAGCombine(ISD::LOAD); - setTargetDAGCombine(ISD::MGATHER); - setTargetDAGCombine(ISD::MSCATTER); - setTargetDAGCombine(ISD::MUL); setTargetDAGCombine(ISD::SELECT); @@ -886,6 +913,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setTargetDAGCombine(ISD::INSERT_VECTOR_ELT); setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT); setTargetDAGCombine(ISD::VECREDUCE_ADD); + setTargetDAGCombine(ISD::STEP_VECTOR); setTargetDAGCombine(ISD::GlobalAddress); @@ -944,6 +972,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FPOW, MVT::v1f64, Expand); setOperationAction(ISD::FREM, MVT::v1f64, Expand); setOperationAction(ISD::FROUND, MVT::v1f64, Expand); + setOperationAction(ISD::FROUNDEVEN, MVT::v1f64, Expand); setOperationAction(ISD::FRINT, MVT::v1f64, Expand); setOperationAction(ISD::FSIN, MVT::v1f64, Expand); setOperationAction(ISD::FSINCOS, MVT::v1f64, Expand); @@ -968,9 +997,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, // elements smaller than i32, so promote the input to i32 first. setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v4i8, MVT::v4i32); setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v4i8, MVT::v4i32); - // i8 vector elements also need promotion to i32 for v8i8 setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v8i8, MVT::v8i32); setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v8i8, MVT::v8i32); + setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v16i8, MVT::v16i32); + setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v16i8, MVT::v16i32); + // Similarly, there is no direct i32 -> f64 vector conversion instruction. 
setOperationAction(ISD::SINT_TO_FP, MVT::v2i32, Custom); setOperationAction(ISD::UINT_TO_FP, MVT::v2i32, Custom); @@ -997,6 +1028,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::CTLZ, MVT::v1i64, Expand); setOperationAction(ISD::CTLZ, MVT::v2i64, Expand); + setOperationAction(ISD::BITREVERSE, MVT::v8i8, Legal); + setOperationAction(ISD::BITREVERSE, MVT::v16i8, Legal); + setOperationAction(ISD::BITREVERSE, MVT::v2i32, Custom); + setOperationAction(ISD::BITREVERSE, MVT::v4i32, Custom); + setOperationAction(ISD::BITREVERSE, MVT::v1i64, Custom); + setOperationAction(ISD::BITREVERSE, MVT::v2i64, Custom); // AArch64 doesn't have MUL.2d: setOperationAction(ISD::MUL, MVT::v2i64, Expand); @@ -1014,14 +1051,21 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::USUBSAT, VT, Legal); } + for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16, + MVT::v4i32}) { + setOperationAction(ISD::ABDS, VT, Legal); + setOperationAction(ISD::ABDU, VT, Legal); + } + // Vector reductions for (MVT VT : { MVT::v4f16, MVT::v2f32, MVT::v8f16, MVT::v4f32, MVT::v2f64 }) { - setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); - setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); + if (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16()) { + setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); + setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); - if (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16()) setOperationAction(ISD::VECREDUCE_FADD, VT, Legal); + } } for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16, MVT::v4i32 }) { @@ -1069,6 +1113,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FRINT, Ty, Legal); setOperationAction(ISD::FTRUNC, Ty, Legal); setOperationAction(ISD::FROUND, Ty, Legal); + setOperationAction(ISD::FROUNDEVEN, Ty, Legal); } if (Subtarget->hasFullFP16()) { @@ -1079,6 +1124,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FRINT, Ty, Legal); setOperationAction(ISD::FTRUNC, Ty, Legal); setOperationAction(ISD::FROUND, Ty, Legal); + setOperationAction(ISD::FROUNDEVEN, Ty, Legal); } } @@ -1086,12 +1132,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VSCALE, MVT::i32, Custom); setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom); + + setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom); + setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom); + setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom); + setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom); + setLoadExtAction(ISD::SEXTLOAD, MVT::v4i32, MVT::v4i8, Custom); + setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i32, MVT::v4i8, Custom); } if (Subtarget->hasSVE()) { - // FIXME: Add custom lowering of MLOAD to handle different passthrus (not a - // splat of 0 or undef) once vector selects supported in SVE codegen. See - // D68877 for more details. 
for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64}) { setOperationAction(ISD::BITREVERSE, VT, Custom); setOperationAction(ISD::BSWAP, VT, Custom); @@ -1105,9 +1155,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::FP_TO_SINT, VT, Custom); setOperationAction(ISD::MGATHER, VT, Custom); setOperationAction(ISD::MSCATTER, VT, Custom); + setOperationAction(ISD::MLOAD, VT, Custom); setOperationAction(ISD::MUL, VT, Custom); + setOperationAction(ISD::MULHS, VT, Custom); + setOperationAction(ISD::MULHU, VT, Custom); setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + setOperationAction(ISD::VECTOR_SPLICE, VT, Custom); setOperationAction(ISD::SELECT, VT, Custom); + setOperationAction(ISD::SETCC, VT, Custom); setOperationAction(ISD::SDIV, VT, Custom); setOperationAction(ISD::UDIV, VT, Custom); setOperationAction(ISD::SMIN, VT, Custom); @@ -1126,6 +1181,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); + + setOperationAction(ISD::UMUL_LOHI, VT, Expand); + setOperationAction(ISD::SMUL_LOHI, VT, Expand); + setOperationAction(ISD::SELECT_CC, VT, Expand); + setOperationAction(ISD::ROTL, VT, Expand); + setOperationAction(ISD::ROTR, VT, Expand); } // Illegal unpacked integer vector types. @@ -1134,6 +1195,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); } + // Legalize unpacked bitcasts to REINTERPRET_CAST. + for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32, MVT::nxv2bf16, + MVT::nxv2f16, MVT::nxv4f16, MVT::nxv2f32}) + setOperationAction(ISD::BITCAST, VT, Custom); + for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) { setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::SELECT, VT, Custom); @@ -1144,6 +1210,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_OR, VT, Custom); setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); + setOperationAction(ISD::SELECT_CC, VT, Expand); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); + setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); + // There are no legal MVT::nxv16f## based types. if (VT != MVT::nxv16i1) { setOperationAction(ISD::SINT_TO_FP, VT, Custom); @@ -1151,18 +1221,50 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, } } + // NEON doesn't support masked loads/stores/gathers/scatters, but SVE does + for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, MVT::v1f64, + MVT::v2f64, MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, + MVT::v2i32, MVT::v4i32, MVT::v1i64, MVT::v2i64}) { + setOperationAction(ISD::MLOAD, VT, Custom); + setOperationAction(ISD::MSTORE, VT, Custom); + setOperationAction(ISD::MGATHER, VT, Custom); + setOperationAction(ISD::MSCATTER, VT, Custom); + } + + for (MVT VT : MVT::fp_scalable_vector_valuetypes()) { + for (MVT InnerVT : MVT::fp_scalable_vector_valuetypes()) { + // Avoid marking truncating FP stores as legal to prevent the + // DAGCombiner from creating unsupported truncating stores. + setTruncStoreAction(VT, InnerVT, Expand); + // SVE does not have floating-point extending loads. 
+ setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand); + setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand); + setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Expand); + } + } + + // SVE supports truncating stores of 64 and 128-bit vectors + setTruncStoreAction(MVT::v2i64, MVT::v2i8, Custom); + setTruncStoreAction(MVT::v2i64, MVT::v2i16, Custom); + setTruncStoreAction(MVT::v2i64, MVT::v2i32, Custom); + setTruncStoreAction(MVT::v2i32, MVT::v2i8, Custom); + setTruncStoreAction(MVT::v2i32, MVT::v2i16, Custom); + for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv2f64}) { setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); setOperationAction(ISD::MGATHER, VT, Custom); setOperationAction(ISD::MSCATTER, VT, Custom); + setOperationAction(ISD::MLOAD, VT, Custom); setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); setOperationAction(ISD::SELECT, VT, Custom); setOperationAction(ISD::FADD, VT, Custom); setOperationAction(ISD::FDIV, VT, Custom); setOperationAction(ISD::FMA, VT, Custom); + setOperationAction(ISD::FMAXIMUM, VT, Custom); setOperationAction(ISD::FMAXNUM, VT, Custom); + setOperationAction(ISD::FMINIMUM, VT, Custom); setOperationAction(ISD::FMINNUM, VT, Custom); setOperationAction(ISD::FMUL, VT, Custom); setOperationAction(ISD::FNEG, VT, Custom); @@ -1182,12 +1284,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom); + setOperationAction(ISD::VECTOR_SPLICE, VT, Custom); + + setOperationAction(ISD::SELECT_CC, VT, Expand); } for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) { setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::MGATHER, VT, Custom); setOperationAction(ISD::MSCATTER, VT, Custom); + setOperationAction(ISD::MLOAD, VT, Custom); } setOperationAction(ISD::SPLAT_VECTOR, MVT::nxv8bf16, Custom); @@ -1214,7 +1320,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, for (auto VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32}) setOperationAction(ISD::TRUNCATE, VT, Custom); for (auto VT : {MVT::v8f16, MVT::v4f32}) - setOperationAction(ISD::FP_ROUND, VT, Expand); + setOperationAction(ISD::FP_ROUND, VT, Custom); // These operations are not supported on NEON but SVE can do them. 
setOperationAction(ISD::BITREVERSE, MVT::v1i64, Custom); @@ -1223,6 +1329,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::CTTZ, MVT::v1i64, Custom); setOperationAction(ISD::MUL, MVT::v1i64, Custom); setOperationAction(ISD::MUL, MVT::v2i64, Custom); + setOperationAction(ISD::MULHS, MVT::v1i64, Custom); + setOperationAction(ISD::MULHS, MVT::v2i64, Custom); + setOperationAction(ISD::MULHU, MVT::v1i64, Custom); + setOperationAction(ISD::MULHU, MVT::v2i64, Custom); setOperationAction(ISD::SDIV, MVT::v8i8, Custom); setOperationAction(ISD::SDIV, MVT::v16i8, Custom); setOperationAction(ISD::SDIV, MVT::v4i16, Custom); @@ -1271,12 +1381,17 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32}) setOperationAction(ISD::VECREDUCE_FADD, VT, Custom); } + + setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv2i1, MVT::nxv2i64); + setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv4i1, MVT::nxv4i32); + setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv8i1, MVT::nxv8i16); + setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv16i1, MVT::nxv16i8); } PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive(); } -void AArch64TargetLowering::addTypeForNEON(MVT VT, MVT PromotedBitwiseVT) { +void AArch64TargetLowering::addTypeForNEON(MVT VT) { assert(VT.isVector() && "VT should be a vector type"); if (VT.isFloatingPoint()) { @@ -1295,10 +1410,12 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT, MVT PromotedBitwiseVT) { setOperationAction(ISD::FLOG10, VT, Expand); setOperationAction(ISD::FEXP, VT, Expand); setOperationAction(ISD::FEXP2, VT, Expand); + } - // But we do support custom-lowering for FCOPYSIGN. + // But we do support custom-lowering for FCOPYSIGN. + if (VT == MVT::v2f32 || VT == MVT::v4f32 || VT == MVT::v2f64 || + ((VT == MVT::v4f16 || VT == MVT::v8f16) && Subtarget->hasFullFP16())) setOperationAction(ISD::FCOPYSIGN, VT, Custom); - } setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); @@ -1366,48 +1483,93 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) { // We use EXTRACT_SUBVECTOR to "cast" a scalable vector to a fixed length one. setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom); + if (VT.isFloatingPoint()) { + setCondCodeAction(ISD::SETO, VT, Expand); + setCondCodeAction(ISD::SETOLT, VT, Expand); + setCondCodeAction(ISD::SETLT, VT, Expand); + setCondCodeAction(ISD::SETOLE, VT, Expand); + setCondCodeAction(ISD::SETLE, VT, Expand); + setCondCodeAction(ISD::SETULT, VT, Expand); + setCondCodeAction(ISD::SETULE, VT, Expand); + setCondCodeAction(ISD::SETUGE, VT, Expand); + setCondCodeAction(ISD::SETUGT, VT, Expand); + setCondCodeAction(ISD::SETUEQ, VT, Expand); + setCondCodeAction(ISD::SETUNE, VT, Expand); + } + + // Mark integer truncating stores as having custom lowering + if (VT.isInteger()) { + MVT InnerVT = VT.changeVectorElementType(MVT::i8); + while (InnerVT != VT) { + setTruncStoreAction(VT, InnerVT, Custom); + InnerVT = InnerVT.changeVectorElementType( + MVT::getIntegerVT(2 * InnerVT.getScalarSizeInBits())); + } + } + // Lower fixed length vector operations to scalable equivalents. 
setOperationAction(ISD::ABS, VT, Custom); setOperationAction(ISD::ADD, VT, Custom); setOperationAction(ISD::AND, VT, Custom); setOperationAction(ISD::ANY_EXTEND, VT, Custom); + setOperationAction(ISD::BITCAST, VT, Custom); setOperationAction(ISD::BITREVERSE, VT, Custom); setOperationAction(ISD::BSWAP, VT, Custom); + setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::CTLZ, VT, Custom); setOperationAction(ISD::CTPOP, VT, Custom); setOperationAction(ISD::CTTZ, VT, Custom); + setOperationAction(ISD::FABS, VT, Custom); setOperationAction(ISD::FADD, VT, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom); setOperationAction(ISD::FCEIL, VT, Custom); setOperationAction(ISD::FDIV, VT, Custom); setOperationAction(ISD::FFLOOR, VT, Custom); setOperationAction(ISD::FMA, VT, Custom); + setOperationAction(ISD::FMAXIMUM, VT, Custom); setOperationAction(ISD::FMAXNUM, VT, Custom); + setOperationAction(ISD::FMINIMUM, VT, Custom); setOperationAction(ISD::FMINNUM, VT, Custom); setOperationAction(ISD::FMUL, VT, Custom); setOperationAction(ISD::FNEARBYINT, VT, Custom); setOperationAction(ISD::FNEG, VT, Custom); + setOperationAction(ISD::FP_EXTEND, VT, Custom); + setOperationAction(ISD::FP_ROUND, VT, Custom); + setOperationAction(ISD::FP_TO_SINT, VT, Custom); + setOperationAction(ISD::FP_TO_UINT, VT, Custom); setOperationAction(ISD::FRINT, VT, Custom); setOperationAction(ISD::FROUND, VT, Custom); + setOperationAction(ISD::FROUNDEVEN, VT, Custom); setOperationAction(ISD::FSQRT, VT, Custom); setOperationAction(ISD::FSUB, VT, Custom); setOperationAction(ISD::FTRUNC, VT, Custom); setOperationAction(ISD::LOAD, VT, Custom); + setOperationAction(ISD::MGATHER, VT, Custom); + setOperationAction(ISD::MLOAD, VT, Custom); + setOperationAction(ISD::MSCATTER, VT, Custom); + setOperationAction(ISD::MSTORE, VT, Custom); setOperationAction(ISD::MUL, VT, Custom); + setOperationAction(ISD::MULHS, VT, Custom); + setOperationAction(ISD::MULHU, VT, Custom); setOperationAction(ISD::OR, VT, Custom); setOperationAction(ISD::SDIV, VT, Custom); + setOperationAction(ISD::SELECT, VT, Custom); setOperationAction(ISD::SETCC, VT, Custom); setOperationAction(ISD::SHL, VT, Custom); setOperationAction(ISD::SIGN_EXTEND, VT, Custom); setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Custom); + setOperationAction(ISD::SINT_TO_FP, VT, Custom); setOperationAction(ISD::SMAX, VT, Custom); setOperationAction(ISD::SMIN, VT, Custom); setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + setOperationAction(ISD::VECTOR_SPLICE, VT, Custom); setOperationAction(ISD::SRA, VT, Custom); setOperationAction(ISD::SRL, VT, Custom); setOperationAction(ISD::STORE, VT, Custom); setOperationAction(ISD::SUB, VT, Custom); setOperationAction(ISD::TRUNCATE, VT, Custom); setOperationAction(ISD::UDIV, VT, Custom); + setOperationAction(ISD::UINT_TO_FP, VT, Custom); setOperationAction(ISD::UMAX, VT, Custom); setOperationAction(ISD::UMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_ADD, VT, Custom); @@ -1417,11 +1579,13 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) { setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_OR, VT, Custom); + setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_XOR, 
VT, Custom); + setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom); setOperationAction(ISD::VSELECT, VT, Custom); setOperationAction(ISD::XOR, VT, Custom); setOperationAction(ISD::ZERO_EXTEND, VT, Custom); @@ -1429,12 +1593,12 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) { void AArch64TargetLowering::addDRTypeForNEON(MVT VT) { addRegisterClass(VT, &AArch64::FPR64RegClass); - addTypeForNEON(VT, MVT::v2i32); + addTypeForNEON(VT); } void AArch64TargetLowering::addQRTypeForNEON(MVT VT) { addRegisterClass(VT, &AArch64::FPR128RegClass); - addTypeForNEON(VT, MVT::v4i32); + addTypeForNEON(VT); } EVT AArch64TargetLowering::getSetCCResultType(const DataLayout &, @@ -1659,7 +1823,7 @@ MVT AArch64TargetLowering::getScalarShiftAmountTy(const DataLayout &DL, } bool AArch64TargetLowering::allowsMisalignedMemoryAccesses( - EVT VT, unsigned AddrSpace, unsigned Align, MachineMemOperand::Flags Flags, + EVT VT, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags, bool *Fast) const { if (Subtarget->requiresStrictAlign()) return false; @@ -1673,7 +1837,7 @@ bool AArch64TargetLowering::allowsMisalignedMemoryAccesses( // Code that uses clang vector extensions can mark that it // wants unaligned accesses to be treated as fast by // underspecifying alignment to be 1 or 2. - Align <= 2 || + Alignment <= 2 || // Disregard v2i64. Memcpy lowering produces those and splitting // them regresses performance on micro-benchmarks and olden/bh. @@ -1703,7 +1867,7 @@ bool AArch64TargetLowering::allowsMisalignedMemoryAccesses( // Disregard v2i64. Memcpy lowering produces those and splitting // them regresses performance on micro-benchmarks and olden/bh. - Ty == LLT::vector(2, 64); + Ty == LLT::fixed_vector(2, 64); } return true; } @@ -1729,7 +1893,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::RET_FLAG) MAKE_CASE(AArch64ISD::BRCOND) MAKE_CASE(AArch64ISD::CSEL) - MAKE_CASE(AArch64ISD::FCSEL) MAKE_CASE(AArch64ISD::CSINV) MAKE_CASE(AArch64ISD::CSNEG) MAKE_CASE(AArch64ISD::CSINC) @@ -1737,6 +1900,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ) MAKE_CASE(AArch64ISD::ADD_PRED) MAKE_CASE(AArch64ISD::MUL_PRED) + MAKE_CASE(AArch64ISD::MULHS_PRED) + MAKE_CASE(AArch64ISD::MULHU_PRED) MAKE_CASE(AArch64ISD::SDIV_PRED) MAKE_CASE(AArch64ISD::SHL_PRED) MAKE_CASE(AArch64ISD::SMAX_PRED) @@ -1797,7 +1962,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::BICi) MAKE_CASE(AArch64ISD::ORRi) MAKE_CASE(AArch64ISD::BSP) - MAKE_CASE(AArch64ISD::NEG) MAKE_CASE(AArch64ISD::EXTR) MAKE_CASE(AArch64ISD::ZIP1) MAKE_CASE(AArch64ISD::ZIP2) @@ -1809,6 +1973,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::REV32) MAKE_CASE(AArch64ISD::REV64) MAKE_CASE(AArch64ISD::EXT) + MAKE_CASE(AArch64ISD::SPLICE) MAKE_CASE(AArch64ISD::VSHL) MAKE_CASE(AArch64ISD::VLSHR) MAKE_CASE(AArch64ISD::VASHR) @@ -1838,6 +2003,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::URHADD) MAKE_CASE(AArch64ISD::SHADD) MAKE_CASE(AArch64ISD::UHADD) + MAKE_CASE(AArch64ISD::SDOT) + MAKE_CASE(AArch64ISD::UDOT) MAKE_CASE(AArch64ISD::SMINV) MAKE_CASE(AArch64ISD::UMINV) MAKE_CASE(AArch64ISD::SMAXV) @@ -1855,7 +2022,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::CLASTB_N) MAKE_CASE(AArch64ISD::LASTA) 
MAKE_CASE(AArch64ISD::LASTB) - MAKE_CASE(AArch64ISD::REV) MAKE_CASE(AArch64ISD::REINTERPRET_CAST) MAKE_CASE(AArch64ISD::TBL) MAKE_CASE(AArch64ISD::FADD_PRED) @@ -1863,14 +2029,17 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::FADDV_PRED) MAKE_CASE(AArch64ISD::FDIV_PRED) MAKE_CASE(AArch64ISD::FMA_PRED) + MAKE_CASE(AArch64ISD::FMAX_PRED) MAKE_CASE(AArch64ISD::FMAXV_PRED) MAKE_CASE(AArch64ISD::FMAXNM_PRED) MAKE_CASE(AArch64ISD::FMAXNMV_PRED) + MAKE_CASE(AArch64ISD::FMIN_PRED) MAKE_CASE(AArch64ISD::FMINV_PRED) MAKE_CASE(AArch64ISD::FMINNM_PRED) MAKE_CASE(AArch64ISD::FMINNMV_PRED) MAKE_CASE(AArch64ISD::FMUL_PRED) MAKE_CASE(AArch64ISD::FSUB_PRED) + MAKE_CASE(AArch64ISD::BIC) MAKE_CASE(AArch64ISD::BIT) MAKE_CASE(AArch64ISD::CBZ) MAKE_CASE(AArch64ISD::CBNZ) @@ -1881,6 +2050,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::SITOF) MAKE_CASE(AArch64ISD::UITOF) MAKE_CASE(AArch64ISD::NVCAST) + MAKE_CASE(AArch64ISD::MRS) MAKE_CASE(AArch64ISD::SQSHL_I) MAKE_CASE(AArch64ISD::UQSHL_I) MAKE_CASE(AArch64ISD::SRSHR_I) @@ -1988,8 +2158,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU) MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU) MAKE_CASE(AArch64ISD::INDEX_VECTOR) - MAKE_CASE(AArch64ISD::UABD) - MAKE_CASE(AArch64ISD::SABD) + MAKE_CASE(AArch64ISD::UADDLP) MAKE_CASE(AArch64ISD::CALL_RVMARKER) } #undef MAKE_CASE @@ -2094,6 +2263,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter( // Lowering Code //===----------------------------------------------------------------------===// +// Forward declarations of SVE fixed length lowering helpers +static EVT getContainerForFixedLengthVector(SelectionDAG &DAG, EVT VT); +static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V); +static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V); +static SDValue convertFixedMaskToScalableVector(SDValue Mask, + SelectionDAG &DAG); + +/// isZerosVector - Check whether SDNode N is a zero-filled vector. +static bool isZerosVector(const SDNode *N) { + // Look through a bit convert. + while (N->getOpcode() == ISD::BITCAST) + N = N->getOperand(0).getNode(); + + if (ISD::isConstantSplatVectorAllZeros(N)) + return true; + + if (N->getOpcode() != AArch64ISD::DUP) + return false; + + auto Opnd0 = N->getOperand(0); + auto *CINT = dyn_cast<ConstantSDNode>(Opnd0); + auto *CFP = dyn_cast<ConstantFPSDNode>(Opnd0); + return (CINT && CINT->isNullValue()) || (CFP && CFP->isZero()); +} + /// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64 /// CC static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) { @@ -2823,50 +3017,25 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) { CC = AArch64CC::NE; bool IsSigned = Op.getOpcode() == ISD::SMULO; if (Op.getValueType() == MVT::i32) { + // Extend to 64-bits, then perform a 64-bit multiply. unsigned ExtendOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; - // For a 32 bit multiply with overflow check we want the instruction - // selector to generate a widening multiply (SMADDL/UMADDL). 
For that we - // need to generate the following pattern: - // (i64 add 0, (i64 mul (i64 sext|zext i32 %a), (i64 sext|zext i32 %b)) LHS = DAG.getNode(ExtendOpc, DL, MVT::i64, LHS); RHS = DAG.getNode(ExtendOpc, DL, MVT::i64, RHS); SDValue Mul = DAG.getNode(ISD::MUL, DL, MVT::i64, LHS, RHS); - SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::i64, Mul, - DAG.getConstant(0, DL, MVT::i64)); - // On AArch64 the upper 32 bits are always zero extended for a 32 bit - // operation. We need to clear out the upper 32 bits, because we used a - // widening multiply that wrote all 64 bits. In the end this should be a - // noop. - Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Add); + Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Mul); + + // Check that the result fits into a 32-bit integer. + SDVTList VTs = DAG.getVTList(MVT::i64, MVT_CC); if (IsSigned) { - // The signed overflow check requires more than just a simple check for - // any bit set in the upper 32 bits of the result. These bits could be - // just the sign bits of a negative number. To perform the overflow - // check we have to arithmetic shift right the 32nd bit of the result by - // 31 bits. Then we compare the result to the upper 32 bits. - SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Add, - DAG.getConstant(32, DL, MVT::i64)); - UpperBits = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, UpperBits); - SDValue LowerBits = DAG.getNode(ISD::SRA, DL, MVT::i32, Value, - DAG.getConstant(31, DL, MVT::i64)); - // It is important that LowerBits is last, otherwise the arithmetic - // shift will not be folded into the compare (SUBS). - SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i32); - Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, UpperBits, LowerBits) - .getValue(1); + // cmp xreg, wreg, sxtw + SDValue SExtMul = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Value); + Overflow = + DAG.getNode(AArch64ISD::SUBS, DL, VTs, Mul, SExtMul).getValue(1); } else { - // The overflow check for unsigned multiply is easy. We only need to - // check if any of the upper 32 bits are set. This can be done with a - // CMP (shifted register). For that we need to generate the following - // pattern: - // (i64 AArch64ISD::SUBS i64 0, (i64 srl i64 %Mul, i64 32) - SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Mul, - DAG.getConstant(32, DL, MVT::i64)); - SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32); + // tst xreg, #0xffffffff00000000 + SDValue UpperBits = DAG.getConstant(0xFFFFFFFF00000000, DL, MVT::i64); Overflow = - DAG.getNode(AArch64ISD::SUBS, DL, VTs, - DAG.getConstant(0, DL, MVT::i64), - UpperBits).getValue(1); + DAG.getNode(AArch64ISD::ANDS, DL, VTs, Mul, UpperBits).getValue(1); } break; } @@ -3082,9 +3251,13 @@ static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) { SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const { - if (Op.getValueType().isScalableVector()) + EVT VT = Op.getValueType(); + if (VT.isScalableVector()) return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU); + if (useSVEForFixedLengthVectorVT(VT)) + return LowerFixedLengthFPExtendToSVE(Op, DAG); + assert(Op.getValueType() == MVT::f128 && "Unexpected lowering"); return SDValue(); } @@ -3098,6 +3271,9 @@ SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op, SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0); EVT SrcVT = SrcVal.getValueType(); + if (useSVEForFixedLengthVectorVT(SrcVT)) + return LowerFixedLengthFPRoundToSVE(Op, DAG); + if (SrcVT != MVT::f128) { // Expand cases where the input is a vector bigger than NEON. 
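Note: the reworked i32 overflow checks above replace the old shift-and-compare sequences with a single check against the widened product ("cmp xreg, wreg, sxtw" for signed, "tst xreg, #0xffffffff00000000" for unsigned). A minimal C++ sketch of the semantics being implemented; the helper names are hypothetical and not part of the patch:

#include <cstdint>

// Reference model only. The lowering widens to a 64-bit multiply, then:
//   signed:   overflow iff product != sext(low 32 bits)   ("cmp xreg, wreg, sxtw")
//   unsigned: overflow iff any of the upper 32 bits is set ("tst xreg, #0xffffffff00000000")
static bool smulo32(int32_t A, int32_t B, int32_t &Res) {
  int64_t Wide = (int64_t)A * (int64_t)B; // SMULL
  Res = (int32_t)Wide;
  return Wide != (int64_t)Res;
}
static bool umulo32(uint32_t A, uint32_t B, uint32_t &Res) {
  uint64_t Wide = (uint64_t)A * (uint64_t)B; // UMULL
  Res = (uint32_t)Wide;
  return (Wide & 0xFFFFFFFF00000000ULL) != 0;
}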
if (useSVEForFixedLengthVectorVT(SrcVT)) @@ -3125,6 +3301,9 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op, return LowerToPredicatedOp(Op, DAG, Opcode); } + if (useSVEForFixedLengthVectorVT(VT) || useSVEForFixedLengthVectorVT(InVT)) + return LowerFixedLengthFPToIntToSVE(Op, DAG); + unsigned NumElts = InVT.getVectorNumElements(); // f16 conversions are promoted to f32 when full fp16 is not supported. @@ -3185,6 +3364,44 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op, return SDValue(); } +SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op, + SelectionDAG &DAG) const { + // AArch64 FP-to-int conversions saturate to the destination register size, so + // we can lower common saturating conversions to simple instructions. + SDValue SrcVal = Op.getOperand(0); + + EVT SrcVT = SrcVal.getValueType(); + EVT DstVT = Op.getValueType(); + + EVT SatVT = cast<VTSDNode>(Op.getOperand(1))->getVT(); + uint64_t SatWidth = SatVT.getScalarSizeInBits(); + uint64_t DstWidth = DstVT.getScalarSizeInBits(); + assert(SatWidth <= DstWidth && "Saturation width cannot exceed result width"); + + // TODO: Support lowering of NEON and SVE conversions. + if (SrcVT.isVector()) + return SDValue(); + + // TODO: Saturate to SatWidth explicitly. + if (SatWidth != DstWidth) + return SDValue(); + + // In the absence of FP16 support, promote f32 to f16, like LowerFP_TO_INT(). + if (SrcVT == MVT::f16 && !Subtarget->hasFullFP16()) + return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), + DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, SrcVal), + Op.getOperand(1)); + + // Cases that we can emit directly. + if ((SrcVT == MVT::f64 || SrcVT == MVT::f32 || + (SrcVT == MVT::f16 && Subtarget->hasFullFP16())) && + (DstVT == MVT::i64 || DstVT == MVT::i32)) + return Op; + + // For all other cases, fall back on the expanded form. + return SDValue(); +} + SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op, SelectionDAG &DAG) const { // Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp. 
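For context on why LowerFP_TO_INT_SAT above can return Op directly in the common cases: the llvm.fptosi.sat/llvm.fptoui.sat semantics (clamp to the destination range, NaN becomes zero) match what AArch64's fcvtzs/fcvtzu already do in hardware. A scalar reference sketch in C++; the helper name is an assumption for illustration, not from the patch:

#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical scalar model of @llvm.fptosi.sat.i32.f32.
static int32_t fptosi_sat_i32(float X) {
  if (std::isnan(X))
    return 0;                                   // NaN -> 0
  if (X <= -2147483648.0f)
    return std::numeric_limits<int32_t>::min(); // clamp to INT32_MIN
  if (X >= 2147483648.0f)
    return std::numeric_limits<int32_t>::max(); // clamp to INT32_MAX
  return (int32_t)X;                            // in range: truncate toward zero
}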
@@ -3211,6 +3428,9 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op, return LowerToPredicatedOp(Op, DAG, Opcode); } + if (useSVEForFixedLengthVectorVT(VT) || useSVEForFixedLengthVectorVT(InVT)) + return LowerFixedLengthIntToFPToSVE(Op, DAG); + uint64_t VTSize = VT.getFixedSizeInBits(); uint64_t InVTSize = InVT.getFixedSizeInBits(); if (VTSize < InVTSize) { @@ -3295,12 +3515,32 @@ SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op, return CallResult.first; } -static SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) { +static MVT getSVEContainerType(EVT ContentTy); + +SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op, + SelectionDAG &DAG) const { EVT OpVT = Op.getValueType(); + EVT ArgVT = Op.getOperand(0).getValueType(); + + if (useSVEForFixedLengthVectorVT(OpVT)) + return LowerFixedLengthBitcastToSVE(Op, DAG); + + if (OpVT.isScalableVector()) { + if (isTypeLegal(OpVT) && !isTypeLegal(ArgVT)) { + assert(OpVT.isFloatingPoint() && !ArgVT.isFloatingPoint() && + "Expected int->fp bitcast!"); + SDValue ExtResult = + DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT), + Op.getOperand(0)); + return getSVESafeBitCast(OpVT, ExtResult, DAG); + } + return getSVESafeBitCast(OpVT, Op.getOperand(0), DAG); + } + if (OpVT != MVT::f16 && OpVT != MVT::bf16) return SDValue(); - assert(Op.getOperand(0).getValueType() == MVT::i16); + assert(ArgVT == MVT::i16); SDLoc DL(Op); Op = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op.getOperand(0)); @@ -3453,6 +3693,50 @@ SDValue AArch64TargetLowering::LowerFLT_ROUNDS_(SDValue Op, return DAG.getMergeValues({AND, Chain}, dl); } +SDValue AArch64TargetLowering::LowerSET_ROUNDING(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + SDValue Chain = Op->getOperand(0); + SDValue RMValue = Op->getOperand(1); + + // The rounding mode is in bits 23:22 of the FPCR. + // The llvm.set.rounding argument value to the rounding mode in FPCR mapping + // is 0->3, 1->0, 2->1, 3->2. The formula we use to implement this is + // ((arg - 1) & 3) << 22). + // + // The argument of llvm.set.rounding must be within the segment [0, 3], so + // NearestTiesToAway (4) is not handled here. It is responsibility of the code + // generated llvm.set.rounding to ensure this condition. + + // Calculate new value of FPCR[23:22]. + RMValue = DAG.getNode(ISD::SUB, DL, MVT::i32, RMValue, + DAG.getConstant(1, DL, MVT::i32)); + RMValue = DAG.getNode(ISD::AND, DL, MVT::i32, RMValue, + DAG.getConstant(0x3, DL, MVT::i32)); + RMValue = + DAG.getNode(ISD::SHL, DL, MVT::i32, RMValue, + DAG.getConstant(AArch64::RoundingBitsPos, DL, MVT::i32)); + RMValue = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, RMValue); + + // Get current value of FPCR. + SDValue Ops[] = { + Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)}; + SDValue FPCR = + DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, Ops); + Chain = FPCR.getValue(1); + FPCR = FPCR.getValue(0); + + // Put new rounding mode into FPSCR[23:22]. 
+ const int RMMask = ~(AArch64::Rounding::rmMask << AArch64::RoundingBitsPos); + FPCR = DAG.getNode(ISD::AND, DL, MVT::i64, FPCR, + DAG.getConstant(RMMask, DL, MVT::i64)); + FPCR = DAG.getNode(ISD::OR, DL, MVT::i64, FPCR, RMValue); + SDValue Ops2[] = { + Chain, DAG.getTargetConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), + FPCR}; + return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2); +} + SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const { EVT VT = Op.getValueType(); @@ -3535,6 +3819,37 @@ static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT, DAG.getTargetConstant(Pattern, DL, MVT::i32)); } +static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) { + SDLoc DL(Op); + EVT OutVT = Op.getValueType(); + SDValue InOp = Op.getOperand(1); + EVT InVT = InOp.getValueType(); + + // Return the operand if the cast isn't changing type, + // i.e. <n x 16 x i1> -> <n x 16 x i1> + if (InVT == OutVT) + return InOp; + + SDValue Reinterpret = + DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, InOp); + + // If the argument converted to an svbool is a ptrue or a comparison, the + // lanes introduced by the widening are zero by construction. + switch (InOp.getOpcode()) { + case AArch64ISD::SETCC_MERGE_ZERO: + return Reinterpret; + case ISD::INTRINSIC_WO_CHAIN: + if (InOp.getConstantOperandVal(0) == Intrinsic::aarch64_sve_ptrue) + return Reinterpret; + } + + // Otherwise, zero the newly introduced lanes. + SDValue Mask = getPTrue(DAG, DL, InVT, AArch64SVEPredPattern::all); + SDValue MaskReinterpret = + DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, Mask); + return DAG.getNode(ISD::AND, DL, OutVT, Reinterpret, MaskReinterpret); +} + SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const { unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue(); @@ -3596,7 +3911,7 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, return DAG.getNode(AArch64ISD::LASTB, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); case Intrinsic::aarch64_sve_rev: - return DAG.getNode(AArch64ISD::REV, dl, Op.getValueType(), + return DAG.getNode(ISD::VECTOR_REVERSE, dl, Op.getValueType(), Op.getOperand(1)); case Intrinsic::aarch64_sve_tbl: return DAG.getNode(AArch64ISD::TBL, dl, Op.getValueType(), @@ -3619,6 +3934,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::aarch64_sve_zip2: return DAG.getNode(AArch64ISD::ZIP2, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_sve_splice: + return DAG.getNode(AArch64ISD::SPLICE, dl, Op.getValueType(), + Op.getOperand(1), Op.getOperand(2), Op.getOperand(3)); case Intrinsic::aarch64_sve_ptrue: return DAG.getNode(AArch64ISD::PTRUE, dl, Op.getValueType(), Op.getOperand(1)); @@ -3638,6 +3956,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::aarch64_sve_convert_from_svbool: return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(), Op.getOperand(1)); + case Intrinsic::aarch64_sve_convert_to_svbool: + return lowerConvertToSVBool(Op, DAG); case Intrinsic::aarch64_sve_fneg: return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(), Op.getOperand(2), Op.getOperand(3), Op.getOperand(1)); @@ -3693,22 +4013,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, case Intrinsic::aarch64_sve_neg: return DAG.getNode(AArch64ISD::NEG_MERGE_PASSTHRU, dl, Op.getValueType(), Op.getOperand(2), Op.getOperand(3), 
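As a sanity check on the mapping LowerSET_ROUNDING uses above (llvm.set.rounding argument to FPCR.RMode via ((arg - 1) & 3)), here is a standalone compile-time sketch; the function name is hypothetical and the snippet is not part of the diff:

// 0 (TowardZero)        -> 3 (RZ)
// 1 (NearestTiesToEven) -> 0 (RN)
// 2 (TowardPositive)    -> 1 (RP)
// 3 (TowardNegative)    -> 2 (RM)
constexpr unsigned fpcrRMode(unsigned Arg) { return (Arg - 1u) & 3u; }
static_assert(fpcrRMode(0) == 3 && fpcrRMode(1) == 0 &&
              fpcrRMode(2) == 1 && fpcrRMode(3) == 2,
              "llvm.set.rounding argument to FPCR.RMode mapping");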
Op.getOperand(1)); - case Intrinsic::aarch64_sve_convert_to_svbool: { - EVT OutVT = Op.getValueType(); - EVT InVT = Op.getOperand(1).getValueType(); - // Return the operand if the cast isn't changing type, - // i.e. <n x 16 x i1> -> <n x 16 x i1> - if (InVT == OutVT) - return Op.getOperand(1); - // Otherwise, zero the newly introduced lanes. - SDValue Reinterpret = - DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, OutVT, Op.getOperand(1)); - SDValue Mask = getPTrue(DAG, dl, InVT, AArch64SVEPredPattern::all); - SDValue MaskReinterpret = - DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, OutVT, Mask); - return DAG.getNode(ISD::AND, dl, OutVT, Reinterpret, MaskReinterpret); - } - case Intrinsic::aarch64_sve_insr: { SDValue Scalar = Op.getOperand(2); EVT ScalarTy = Scalar.getValueType(); @@ -3813,18 +4117,40 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); } - + case Intrinsic::aarch64_neon_sabd: case Intrinsic::aarch64_neon_uabd: { - return DAG.getNode(AArch64ISD::UABD, dl, Op.getValueType(), - Op.getOperand(1), Op.getOperand(2)); + unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? ISD::ABDU + : ISD::ABDS; + return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1), + Op.getOperand(2)); } - case Intrinsic::aarch64_neon_sabd: { - return DAG.getNode(AArch64ISD::SABD, dl, Op.getValueType(), - Op.getOperand(1), Op.getOperand(2)); + case Intrinsic::aarch64_neon_uaddlp: { + unsigned Opcode = AArch64ISD::UADDLP; + return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1)); + } + case Intrinsic::aarch64_neon_sdot: + case Intrinsic::aarch64_neon_udot: + case Intrinsic::aarch64_sve_sdot: + case Intrinsic::aarch64_sve_udot: { + unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot || + IntNo == Intrinsic::aarch64_sve_udot) + ? AArch64ISD::UDOT + : AArch64ISD::SDOT; + return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1), + Op.getOperand(2), Op.getOperand(3)); } } } +bool AArch64TargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const { + if (VT.getVectorElementType() == MVT::i8 || + VT.getVectorElementType() == MVT::i16) { + EltTy = MVT::i32; + return true; + } + return false; +} + bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const { if (VT.getVectorElementType() == MVT::i32 && VT.getVectorElementCount().getKnownMinValue() >= 4) @@ -3938,6 +4264,12 @@ void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, EVT MemVT, if (!isNullConstant(BasePtr)) return; + // FIXME: This will not match for fixed vector type codegen as the nodes in + // question will have fixed<->scalable conversions around them. This should be + // moved to a DAG combine or complex pattern so that is executes after all of + // the fixed vector insert and extracts have been removed. This deficiency + // will result in a sub-optimal addressing mode being used, i.e. an ADD not + // being folded into the scatter/gather. 
ConstantSDNode *Offset = nullptr; if (Index.getOpcode() == ISD::ADD) if (auto SplatVal = DAG.getSplatValue(Index.getOperand(1))) { @@ -3982,6 +4314,8 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op, MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op); assert(MGT && "Can only custom lower gather load nodes"); + bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector(); + SDValue Index = MGT->getIndex(); SDValue Chain = MGT->getChain(); SDValue PassThru = MGT->getPassThru(); @@ -4000,6 +4334,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op, bool ResNeedsSignExtend = ExtTy == ISD::EXTLOAD || ExtTy == ISD::SEXTLOAD; EVT VT = PassThru.getSimpleValueType(); + EVT IndexVT = Index.getSimpleValueType(); EVT MemVT = MGT->getMemoryVT(); SDValue InputVT = DAG.getValueType(MemVT); @@ -4007,14 +4342,27 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op, !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16()) return SDValue(); - // Handle FP data by using an integer gather and casting the result. - if (VT.isFloatingPoint()) { - EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount()); - PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG); + if (IsFixedLength) { + assert(Subtarget->useSVEForFixedLengthVectors() && + "Cannot lower when not using SVE for fixed vectors"); + IndexVT = getContainerForFixedLengthVector(DAG, IndexVT); + MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType()); + InputVT = DAG.getValueType(MemVT.changeTypeToInteger()); + } + + if (PassThru->isUndef() || isZerosVector(PassThru.getNode())) + PassThru = SDValue(); + + if (VT.isFloatingPoint() && !IsFixedLength) { + // Handle FP data by using an integer gather and casting the result. + if (PassThru) { + EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount()); + PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG); + } InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger()); } - SDVTList VTs = DAG.getVTList(PassThru.getSimpleValueType(), MVT::Other); + SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other); if (getGatherScatterIndexIsExtended(Index)) Index = Index.getOperand(0); @@ -4026,15 +4374,36 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op, if (ResNeedsSignExtend) Opcode = getSignExtendedGatherOpcode(Opcode); - SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT, PassThru}; - SDValue Gather = DAG.getNode(Opcode, DL, VTs, Ops); + if (IsFixedLength) { + if (Index.getSimpleValueType().isFixedLengthVector()) + Index = convertToScalableVector(DAG, IndexVT, Index); + if (BasePtr.getSimpleValueType().isFixedLengthVector()) + BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr); + Mask = convertFixedMaskToScalableVector(Mask, DAG); + } - if (VT.isFloatingPoint()) { - SDValue Cast = getSVESafeBitCast(VT, Gather, DAG); - return DAG.getMergeValues({Cast, Gather}, DL); + SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT}; + SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops); + Chain = Result.getValue(1); + + if (IsFixedLength) { + Result = convertFromScalableVector( + DAG, VT.changeVectorElementType(IndexVT.getVectorElementType()), + Result); + Result = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Result); + Result = DAG.getNode(ISD::BITCAST, DL, VT, Result); + + if (PassThru) + Result = DAG.getSelect(DL, VT, MGT->getMask(), Result, PassThru); + } else { + if (PassThru) + Result = DAG.getSelect(DL, IndexVT, Mask, Result, PassThru); + + if (VT.isFloatingPoint()) + Result = getSVESafeBitCast(VT, Result, DAG); } - 
return Gather; + return DAG.getMergeValues({Result, Chain}, DL); } SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op, @@ -4043,6 +4412,8 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op, MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Op); assert(MSC && "Can only custom lower scatter store nodes"); + bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector(); + SDValue Index = MSC->getIndex(); SDValue Chain = MSC->getChain(); SDValue StoreVal = MSC->getValue(); @@ -4059,6 +4430,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op, Index.getSimpleValueType().getVectorElementType() == MVT::i32; EVT VT = StoreVal.getSimpleValueType(); + EVT IndexVT = Index.getSimpleValueType(); SDVTList VTs = DAG.getVTList(MVT::Other); EVT MemVT = MSC->getMemoryVT(); SDValue InputVT = DAG.getValueType(MemVT); @@ -4067,8 +4439,21 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op, !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16()) return SDValue(); - // Handle FP data by casting the data so an integer scatter can be used. - if (VT.isFloatingPoint()) { + if (IsFixedLength) { + assert(Subtarget->useSVEForFixedLengthVectors() && + "Cannot lower when not using SVE for fixed vectors"); + IndexVT = getContainerForFixedLengthVector(DAG, IndexVT); + MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType()); + InputVT = DAG.getValueType(MemVT.changeTypeToInteger()); + + StoreVal = + DAG.getNode(ISD::BITCAST, DL, VT.changeTypeToInteger(), StoreVal); + StoreVal = DAG.getNode( + ISD::ANY_EXTEND, DL, + VT.changeVectorElementType(IndexVT.getVectorElementType()), StoreVal); + StoreVal = convertToScalableVector(DAG, IndexVT, StoreVal); + } else if (VT.isFloatingPoint()) { + // Handle FP data by casting the data so an integer scatter can be used. 
EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount()); StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG); InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger()); @@ -4081,10 +4466,44 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op, selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode, /*isGather=*/false, DAG); + if (IsFixedLength) { + if (Index.getSimpleValueType().isFixedLengthVector()) + Index = convertToScalableVector(DAG, IndexVT, Index); + if (BasePtr.getSimpleValueType().isFixedLengthVector()) + BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr); + Mask = convertFixedMaskToScalableVector(Mask, DAG); + } + SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT}; return DAG.getNode(Opcode, DL, VTs, Ops); } +SDValue AArch64TargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const { + SDLoc DL(Op); + MaskedLoadSDNode *LoadNode = cast<MaskedLoadSDNode>(Op); + assert(LoadNode && "Expected custom lowering of a masked load node"); + EVT VT = Op->getValueType(0); + + if (useSVEForFixedLengthVectorVT(VT, true)) + return LowerFixedLengthVectorMLoadToSVE(Op, DAG); + + SDValue PassThru = LoadNode->getPassThru(); + SDValue Mask = LoadNode->getMask(); + + if (PassThru->isUndef() || isZerosVector(PassThru.getNode())) + return Op; + + SDValue Load = DAG.getMaskedLoad( + VT, DL, LoadNode->getChain(), LoadNode->getBasePtr(), + LoadNode->getOffset(), Mask, DAG.getUNDEF(VT), LoadNode->getMemoryVT(), + LoadNode->getMemOperand(), LoadNode->getAddressingMode(), + LoadNode->getExtensionType()); + + SDValue Result = DAG.getSelect(DL, VT, Mask, Load, PassThru); + + return DAG.getMergeValues({Result, Load.getValue(1)}, DL); +} + // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16. static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST, EVT VT, EVT MemVT, @@ -4132,19 +4551,20 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op, EVT MemVT = StoreNode->getMemoryVT(); if (VT.isVector()) { - if (useSVEForFixedLengthVectorVT(VT)) + if (useSVEForFixedLengthVectorVT(VT, true)) return LowerFixedLengthVectorStoreToSVE(Op, DAG); unsigned AS = StoreNode->getAddressSpace(); Align Alignment = StoreNode->getAlign(); if (Alignment < MemVT.getStoreSize() && - !allowsMisalignedMemoryAccesses(MemVT, AS, Alignment.value(), + !allowsMisalignedMemoryAccesses(MemVT, AS, Alignment, StoreNode->getMemOperand()->getFlags(), nullptr)) { return scalarizeVectorStore(StoreNode, DAG); } - if (StoreNode->isTruncatingStore()) { + if (StoreNode->isTruncatingStore() && VT == MVT::v4i16 && + MemVT == MVT::v4i8) { return LowerTruncateVectorStore(Dl, StoreNode, VT, MemVT, DAG); } // 256 bit non-temporal stores can be lowered to STNP. Do this as part of @@ -4190,6 +4610,40 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op, return SDValue(); } +// Custom lowering for extending v4i8 vector loads. 
+SDValue AArch64TargetLowering::LowerLOAD(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + LoadSDNode *LoadNode = cast<LoadSDNode>(Op); + assert(LoadNode && "Expected custom lowering of a load node"); + EVT VT = Op->getValueType(0); + assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32"); + + if (LoadNode->getMemoryVT() != MVT::v4i8) + return SDValue(); + + unsigned ExtType; + if (LoadNode->getExtensionType() == ISD::SEXTLOAD) + ExtType = ISD::SIGN_EXTEND; + else if (LoadNode->getExtensionType() == ISD::ZEXTLOAD || + LoadNode->getExtensionType() == ISD::EXTLOAD) + ExtType = ISD::ZERO_EXTEND; + else + return SDValue(); + + SDValue Load = DAG.getLoad(MVT::f32, DL, LoadNode->getChain(), + LoadNode->getBasePtr(), MachinePointerInfo()); + SDValue Chain = Load.getValue(1); + SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f32, Load); + SDValue BC = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Vec); + SDValue Ext = DAG.getNode(ExtType, DL, MVT::v8i16, BC); + Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i16, Ext, + DAG.getConstant(0, DL, MVT::i64)); + if (VT == MVT::v4i32) + Ext = DAG.getNode(ExtType, DL, MVT::v4i32, Ext); + return DAG.getMergeValues({Ext, Chain}, DL); +} + // Generate SUBS and CSEL for integer abs. SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const { MVT VT = Op.getSimpleValueType(); @@ -4339,10 +4793,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::SHL: return LowerVectorSRA_SRL_SHL(Op, DAG); case ISD::SHL_PARTS: - return LowerShiftLeftParts(Op, DAG); case ISD::SRL_PARTS: case ISD::SRA_PARTS: - return LowerShiftRightParts(Op, DAG); + return LowerShiftParts(Op, DAG); case ISD::CTPOP: return LowerCTPOP(Op, DAG); case ISD::FCOPYSIGN: @@ -4363,16 +4816,29 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::STRICT_FP_TO_SINT: case ISD::STRICT_FP_TO_UINT: return LowerFP_TO_INT(Op, DAG); + case ISD::FP_TO_SINT_SAT: + case ISD::FP_TO_UINT_SAT: + return LowerFP_TO_INT_SAT(Op, DAG); case ISD::FSINCOS: return LowerFSINCOS(Op, DAG); case ISD::FLT_ROUNDS_: return LowerFLT_ROUNDS_(Op, DAG); + case ISD::SET_ROUNDING: + return LowerSET_ROUNDING(Op, DAG); case ISD::MUL: return LowerMUL(Op, DAG); + case ISD::MULHS: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::MULHS_PRED, + /*OverrideNEON=*/true); + case ISD::MULHU: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::MULHU_PRED, + /*OverrideNEON=*/true); case ISD::INTRINSIC_WO_CHAIN: return LowerINTRINSIC_WO_CHAIN(Op, DAG); case ISD::STORE: return LowerSTORE(Op, DAG); + case ISD::MSTORE: + return LowerFixedLengthVectorMStoreToSVE(Op, DAG); case ISD::MGATHER: return LowerMGATHER(Op, DAG); case ISD::MSCATTER: @@ -4416,18 +4882,24 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, } case ISD::TRUNCATE: return LowerTRUNCATE(Op, DAG); + case ISD::MLOAD: + return LowerMLOAD(Op, DAG); case ISD::LOAD: if (useSVEForFixedLengthVectorVT(Op.getValueType())) return LowerFixedLengthVectorLoadToSVE(Op, DAG); - llvm_unreachable("Unexpected request to lower ISD::LOAD"); + return LowerLOAD(Op, DAG); case ISD::ADD: return LowerToPredicatedOp(Op, DAG, AArch64ISD::ADD_PRED); case ISD::AND: return LowerToScalableOp(Op, DAG); case ISD::SUB: return LowerToPredicatedOp(Op, DAG, AArch64ISD::SUB_PRED); + case ISD::FMAXIMUM: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAX_PRED); case ISD::FMAXNUM: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED); + case ISD::FMINIMUM: + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMIN_PRED); 
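The LowerLOAD path above widens a v4i8 extending load by performing one 32-bit scalar load and then extending the lanes in registers. A reference-semantics sketch of what that sequence computes, with a hypothetical helper name (illustration only, not from the patch):

#include <cstdint>
#include <cstring>

// Model of a zero-extending v4i8 -> v4i16 load: a single 4-byte memory access
// (the f32 LDR in the lowering above) followed by per-lane widening (e.g. USHLL).
static void zextload_v4i8_to_v4i16(const uint8_t *Src, uint16_t Dst[4]) {
  uint8_t Bytes[4];
  std::memcpy(Bytes, Src, 4); // one 32-bit load
  for (int I = 0; I < 4; ++I)
    Dst[I] = Bytes[I];        // zero-extend each lane
}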
case ISD::FMINNUM: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED); case ISD::VSELECT: @@ -4435,8 +4907,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::ABS: return LowerABS(Op, DAG); case ISD::BITREVERSE: - return LowerToPredicatedOp(Op, DAG, AArch64ISD::BITREVERSE_MERGE_PASSTHRU, - /*OverrideNEON=*/true); + return LowerBitreverse(Op, DAG); case ISD::BSWAP: return LowerToPredicatedOp(Op, DAG, AArch64ISD::BSWAP_MERGE_PASSTHRU); case ISD::CTLZ: @@ -4444,6 +4915,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, /*OverrideNEON=*/true); case ISD::CTTZ: return LowerCTTZ(Op, DAG); + case ISD::VECTOR_SPLICE: + return LowerVECTOR_SPLICE(Op, DAG); } } @@ -4515,6 +4988,8 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC, case CallingConv::PreserveMost: case CallingConv::CXX_FAST_TLS: case CallingConv::Swift: + case CallingConv::SwiftTail: + case CallingConv::Tail: if (Subtarget->isTargetWindows() && IsVarArg) return CC_AArch64_Win64_VarArg; if (!Subtarget->isTargetDarwin()) @@ -4578,7 +5053,10 @@ SDValue AArch64TargetLowering::LowerFormalArguments( else if (ActualMVT == MVT::i16) ValVT = MVT::i16; } - CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, /*IsVarArg=*/false); + bool UseVarArgCC = false; + if (IsWin64) + UseVarArgCC = isVarArg; + CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, UseVarArgCC); bool Res = AssignFn(i, ValVT, ValVT, CCValAssign::Full, Ins[i].Flags, CCInfo); assert(!Res && "Call operand has unhandled type"); @@ -4606,6 +5084,9 @@ SDValue AArch64TargetLowering::LowerFormalArguments( continue; } + if (Ins[i].Flags.isSwiftAsync()) + MF.getInfo<AArch64FunctionInfo>()->setHasSwiftAsyncContext(true); + SDValue ArgValue; if (VA.isRegLoc()) { // Arguments stored in registers. @@ -4709,7 +5190,6 @@ SDValue AArch64TargetLowering::LowerFormalArguments( ExtType, DL, VA.getLocVT(), Chain, FIN, MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI), MemVT); - } if (VA.getLocInfo() == CCValAssign::Indirect) { @@ -4985,8 +5465,9 @@ SDValue AArch64TargetLowering::LowerCallResult( } /// Return true if the calling convention is one that we can guarantee TCO for. -static bool canGuaranteeTCO(CallingConv::ID CC) { - return CC == CallingConv::Fast; +static bool canGuaranteeTCO(CallingConv::ID CC, bool GuaranteeTailCalls) { + return (CC == CallingConv::Fast && GuaranteeTailCalls) || + CC == CallingConv::Tail || CC == CallingConv::SwiftTail; } /// Return true if we might ever do TCO for calls with this calling convention. @@ -4996,9 +5477,12 @@ static bool mayTailCallThisCC(CallingConv::ID CC) { case CallingConv::AArch64_SVE_VectorCall: case CallingConv::PreserveMost: case CallingConv::Swift: + case CallingConv::SwiftTail: + case CallingConv::Tail: + case CallingConv::Fast: return true; default: - return canGuaranteeTCO(CC); + return false; } } @@ -5014,11 +5498,11 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( const Function &CallerF = MF.getFunction(); CallingConv::ID CallerCC = CallerF.getCallingConv(); - // If this function uses the C calling convention but has an SVE signature, - // then it preserves more registers and should assume the SVE_VectorCall CC. + // Functions using the C or Fast calling convention that have an SVE signature + // preserve more registers and should assume the SVE_VectorCall CC. // The check for matching callee-saved regs will determine whether it is // eligible for TCO. 
- if (CallerCC == CallingConv::C && + if ((CallerCC == CallingConv::C || CallerCC == CallingConv::Fast) && AArch64RegisterInfo::hasSVEArgsOrReturn(&MF)) CallerCC = CallingConv::AArch64_SVE_VectorCall; @@ -5050,8 +5534,8 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( return false; } - if (getTargetMachine().Options.GuaranteedTailCallOpt) - return canGuaranteeTCO(CalleeCC) && CCMatch; + if (canGuaranteeTCO(CalleeCC, getTargetMachine().Options.GuaranteedTailCallOpt)) + return CCMatch; // Externally-defined functions with weak linkage should not be // tail-called on AArch64 when the OS does not support dynamic @@ -5182,7 +5666,8 @@ SDValue AArch64TargetLowering::addTokenForArgument(SDValue Chain, bool AArch64TargetLowering::DoesCalleeRestoreStack(CallingConv::ID CallCC, bool TailCallOpt) const { - return CallCC == CallingConv::Fast && TailCallOpt; + return (CallCC == CallingConv::Fast && TailCallOpt) || + CallCC == CallingConv::Tail || CallCC == CallingConv::SwiftTail; } /// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain, @@ -5208,10 +5693,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>(); bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt; bool IsSibCall = false; + bool IsCalleeWin64 = Subtarget->isCallingConvWin64(CallConv); // Check callee args/returns for SVE registers and set calling convention // accordingly. - if (CallConv == CallingConv::C) { + if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) { bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){ return Out.VT.isScalableVector(); }); @@ -5227,19 +5713,21 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // Check if it's really possible to do a tail call. IsTailCall = isEligibleForTailCallOptimization( Callee, CallConv, IsVarArg, Outs, OutVals, Ins, DAG); - if (!IsTailCall && CLI.CB && CLI.CB->isMustTailCall()) - report_fatal_error("failed to perform tail call elimination on a call " - "site marked musttail"); // A sibling call is one where we're under the usual C ABI and not planning // to change that but can still do a tail call: - if (!TailCallOpt && IsTailCall) + if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail && + CallConv != CallingConv::SwiftTail) IsSibCall = true; if (IsTailCall) ++NumTailCalls; } + if (!IsTailCall && CLI.CB && CLI.CB->isMustTailCall()) + report_fatal_error("failed to perform tail call elimination on a call " + "site marked musttail"); + // Analyze operands of the call, assigning locations to each operand. SmallVector<CCValAssign, 16> ArgLocs; CCState CCInfo(CallConv, IsVarArg, DAG.getMachineFunction(), ArgLocs, @@ -5257,8 +5745,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, "currently not supported"); ISD::ArgFlagsTy ArgFlags = Outs[i].Flags; - CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, - /*IsVarArg=*/ !Outs[i].IsFixed); + bool UseVarArgCC = !Outs[i].IsFixed; + // On Windows, the fixed arguments in a vararg call are passed in GPRs + // too, so use the vararg CC to force them to integer registers. + if (IsCalleeWin64) + UseVarArgCC = true; + CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, UseVarArgCC); bool Res = AssignFn(i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, CCInfo); assert(!Res && "Call operand has unhandled type"); (void)Res; @@ -5320,6 +5812,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // can actually shrink the stack. 
FPDiff = NumReusableBytes - NumBytes; + // Update the required reserved area if this is the tail call requiring the + // most argument stack space. + if (FPDiff < 0 && FuncInfo->getTailCallReservedStack() < (unsigned)-FPDiff) + FuncInfo->setTailCallReservedStack(-FPDiff); + // The stack pointer must be 16-byte aligned at all times it's used for a // memory operation, which in practice means at *all* times and in // particular across call boundaries. Therefore our own arguments started at @@ -5331,7 +5828,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // Adjust the stack pointer for the new arguments... // These operations are automatically eliminated by the prolog/epilog pass if (!IsSibCall) - Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, DL); + Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL); SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP, getPointerTy(DAG.getDataLayout())); @@ -5485,7 +5982,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // common case. It should also work for fundamental types too. uint32_t BEAlign = 0; unsigned OpSize; - if (VA.getLocInfo() == CCValAssign::Indirect) + if (VA.getLocInfo() == CCValAssign::Indirect || + VA.getLocInfo() == CCValAssign::Trunc) OpSize = VA.getLocVT().getFixedSizeInBits(); else OpSize = Flags.isByVal() ? Flags.getByValSize() * 8 @@ -5588,7 +6086,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // we've carefully laid out the parameters so that when sp is reset they'll be // in the correct location. if (IsTailCall && !IsSibCall) { - Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(NumBytes, DL, true), + Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(0, DL, true), DAG.getIntPtrConstant(0, DL, true), InFlag, DL); InFlag = Chain.getValue(1); } @@ -5647,11 +6145,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, } unsigned CallOpc = AArch64ISD::CALL; - // Calls marked with "rv_marker" are special. They should be expanded to the - // call, directly followed by a special marker sequence. Use the CALL_RVMARKER - // to do that. - if (CLI.CB && CLI.CB->hasRetAttr("rv_marker")) { - assert(!IsTailCall && "tail calls cannot be marked with rv_marker"); + // Calls with operand bundle "clang.arc.attachedcall" are special. They should + // be expanded to the call, directly followed by a special marker sequence. + // Use the CALL_RVMARKER to do that. 
+ if (CLI.CB && objcarc::hasAttachedCallOpBundle(CLI.CB)) { + assert(!IsTailCall && + "tail calls cannot be marked with clang.arc.attachedcall"); CallOpc = AArch64ISD::CALL_RVMARKER; } @@ -6584,6 +7083,56 @@ SDValue AArch64TargetLowering::LowerCTTZ(SDValue Op, SelectionDAG &DAG) const { return DAG.getNode(ISD::CTLZ, DL, VT, RBIT); } +SDValue AArch64TargetLowering::LowerBitreverse(SDValue Op, + SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + + if (VT.isScalableVector() || + useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true)) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::BITREVERSE_MERGE_PASSTHRU, + true); + + SDLoc DL(Op); + SDValue REVB; + MVT VST; + + switch (VT.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Invalid type for bitreverse!"); + + case MVT::v2i32: { + VST = MVT::v8i8; + REVB = DAG.getNode(AArch64ISD::REV32, DL, VST, Op.getOperand(0)); + + break; + } + + case MVT::v4i32: { + VST = MVT::v16i8; + REVB = DAG.getNode(AArch64ISD::REV32, DL, VST, Op.getOperand(0)); + + break; + } + + case MVT::v1i64: { + VST = MVT::v8i8; + REVB = DAG.getNode(AArch64ISD::REV64, DL, VST, Op.getOperand(0)); + + break; + } + + case MVT::v2i64: { + VST = MVT::v16i8; + REVB = DAG.getNode(AArch64ISD::REV64, DL, VST, Op.getOperand(0)); + + break; + } + } + + return DAG.getNode(AArch64ISD::NVCAST, DL, VT, + DAG.getNode(ISD::BITREVERSE, DL, VST, REVB)); +} + SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { if (Op.getValueType().isVector()) @@ -6700,13 +7249,26 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, assert((LHS.getValueType() == RHS.getValueType()) && (LHS.getValueType() == MVT::i32 || LHS.getValueType() == MVT::i64)); + ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal); + ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal); + ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS); + // Check for sign pattern (SELECT_CC setgt, iN lhs, -1, 1, -1) and transform + // into (OR (ASR lhs, N-1), 1), which requires less instructions for the + // supported types. + if (CC == ISD::SETGT && RHSC && RHSC->isAllOnesValue() && CTVal && CFVal && + CTVal->isOne() && CFVal->isAllOnesValue() && + LHS.getValueType() == TVal.getValueType()) { + EVT VT = LHS.getValueType(); + SDValue Shift = + DAG.getNode(ISD::SRA, dl, VT, LHS, + DAG.getConstant(VT.getSizeInBits() - 1, dl, VT)); + return DAG.getNode(ISD::OR, dl, VT, Shift, DAG.getConstant(1, dl, VT)); + } + unsigned Opcode = AArch64ISD::CSEL; // If both the TVal and the FVal are constants, see if we can swap them in // order to for a CSINV or CSINC out of them. 
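The new LowerSELECT_CC special case rewrites "lhs > -1 ? 1 : -1" into "(lhs >> (N-1)) | 1" using an arithmetic shift. A small C++ check of that identity for i32, assuming the usual arithmetic right shift of signed values (illustrative sketch only):

#include <cassert>
#include <cstdint>

int32_t selectForm(int32_t X) { return X > -1 ? 1 : -1; }

int32_t shiftOrForm(int32_t X) {
  // X >> 31 is 0 for non-negative X and all-ones for negative X (arithmetic
  // shift), so OR-ing in 1 yields exactly +1 or -1.
  return (X >> 31) | 1;
}

int main() {
  for (int32_t X : {INT32_MIN, -42, -1, 0, 1, 42, INT32_MAX})
    assert(selectForm(X) == shiftOrForm(X));
}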
- ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal); - ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal); - if (CTVal && CFVal && CTVal->isAllOnesValue() && CFVal->isNullValue()) { std::swap(TVal, FVal); std::swap(CTVal, CFVal); @@ -6861,6 +7423,16 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS, return CS1; } +SDValue AArch64TargetLowering::LowerVECTOR_SPLICE(SDValue Op, + SelectionDAG &DAG) const { + + EVT Ty = Op.getValueType(); + auto Idx = Op.getConstantOperandAPInt(2); + if (Idx.sge(-1) && Idx.slt(Ty.getVectorMinNumElements())) + return Op; + return SDValue(); +} + SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const { ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(4))->get(); @@ -6887,6 +7459,17 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op, return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal); } + if (useSVEForFixedLengthVectorVT(Ty)) { + // FIXME: Ideally this would be the same as above using i1 types, however + // for the moment we can't deal with fixed i1 vector types properly, so + // instead extend the predicate to a result type sized integer vector. + MVT SplatValVT = MVT::getIntegerVT(Ty.getScalarSizeInBits()); + MVT PredVT = MVT::getVectorVT(SplatValVT, Ty.getVectorElementCount()); + SDValue SplatVal = DAG.getSExtOrTrunc(CCVal, DL, SplatValVT); + SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, PredVT, SplatVal); + return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal); + } + // Optimize {s|u}{add|sub|mul}.with.overflow feeding into a select // instruction. if (ISD::isOverflowIntrOpRes(CCVal)) { @@ -6909,7 +7492,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op, if (CCVal.getOpcode() == ISD::SETCC) { LHS = CCVal.getOperand(0); RHS = CCVal.getOperand(1); - CC = cast<CondCodeSDNode>(CCVal->getOperand(2))->get(); + CC = cast<CondCodeSDNode>(CCVal.getOperand(2))->get(); } else { LHS = CCVal; RHS = DAG.getConstant(0, DL, CCVal.getValueType()); @@ -7293,112 +7876,13 @@ SDValue AArch64TargetLowering::LowerRETURNADDR(SDValue Op, return SDValue(St, 0); } -/// LowerShiftRightParts - Lower SRA_PARTS, which returns two -/// i64 values and take a 2 x i64 value to shift plus a shift amount. -SDValue AArch64TargetLowering::LowerShiftRightParts(SDValue Op, - SelectionDAG &DAG) const { - assert(Op.getNumOperands() == 3 && "Not a double-shift!"); - EVT VT = Op.getValueType(); - unsigned VTBits = VT.getSizeInBits(); - SDLoc dl(Op); - SDValue ShOpLo = Op.getOperand(0); - SDValue ShOpHi = Op.getOperand(1); - SDValue ShAmt = Op.getOperand(2); - unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL; - - assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS); - - SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, - DAG.getConstant(VTBits, dl, MVT::i64), ShAmt); - SDValue HiBitsForLo = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt); - - // Unfortunately, if ShAmt == 0, we just calculated "(SHL ShOpHi, 64)" which - // is "undef". We wanted 0, so CSEL it directly. 
- SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(0, dl, MVT::i64), - ISD::SETEQ, dl, DAG); - SDValue CCVal = DAG.getConstant(AArch64CC::EQ, dl, MVT::i32); - HiBitsForLo = - DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64), - HiBitsForLo, CCVal, Cmp); - - SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, ShAmt, - DAG.getConstant(VTBits, dl, MVT::i64)); - - SDValue LoBitsForLo = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt); - SDValue LoForNormalShift = - DAG.getNode(ISD::OR, dl, VT, LoBitsForLo, HiBitsForLo); - - Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64), ISD::SETGE, - dl, DAG); - CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32); - SDValue LoForBigShift = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt); - SDValue Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, LoForBigShift, - LoForNormalShift, CCVal, Cmp); - - // AArch64 shifts larger than the register width are wrapped rather than - // clamped, so we can't just emit "hi >> x". - SDValue HiForNormalShift = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt); - SDValue HiForBigShift = - Opc == ISD::SRA - ? DAG.getNode(Opc, dl, VT, ShOpHi, - DAG.getConstant(VTBits - 1, dl, MVT::i64)) - : DAG.getConstant(0, dl, VT); - SDValue Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, HiForBigShift, - HiForNormalShift, CCVal, Cmp); - - SDValue Ops[2] = { Lo, Hi }; - return DAG.getMergeValues(Ops, dl); -} - -/// LowerShiftLeftParts - Lower SHL_PARTS, which returns two -/// i64 values and take a 2 x i64 value to shift plus a shift amount. -SDValue AArch64TargetLowering::LowerShiftLeftParts(SDValue Op, - SelectionDAG &DAG) const { - assert(Op.getNumOperands() == 3 && "Not a double-shift!"); - EVT VT = Op.getValueType(); - unsigned VTBits = VT.getSizeInBits(); - SDLoc dl(Op); - SDValue ShOpLo = Op.getOperand(0); - SDValue ShOpHi = Op.getOperand(1); - SDValue ShAmt = Op.getOperand(2); - - assert(Op.getOpcode() == ISD::SHL_PARTS); - SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, - DAG.getConstant(VTBits, dl, MVT::i64), ShAmt); - SDValue LoBitsForHi = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt); - - // Unfortunately, if ShAmt == 0, we just calculated "(SRL ShOpLo, 64)" which - // is "undef". We wanted 0, so CSEL it directly. - SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(0, dl, MVT::i64), - ISD::SETEQ, dl, DAG); - SDValue CCVal = DAG.getConstant(AArch64CC::EQ, dl, MVT::i32); - LoBitsForHi = - DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64), - LoBitsForHi, CCVal, Cmp); - - SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, ShAmt, - DAG.getConstant(VTBits, dl, MVT::i64)); - SDValue HiBitsForHi = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt); - SDValue HiForNormalShift = - DAG.getNode(ISD::OR, dl, VT, LoBitsForHi, HiBitsForHi); - - SDValue HiForBigShift = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt); - - Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64), ISD::SETGE, - dl, DAG); - CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32); - SDValue Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, HiForBigShift, - HiForNormalShift, CCVal, Cmp); - - // AArch64 shifts of larger than register sizes are wrapped rather than - // clamped, so we can't just emit "lo << a" if a is too big. 
- SDValue LoForBigShift = DAG.getConstant(0, dl, VT); - SDValue LoForNormalShift = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt); - SDValue Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, LoForBigShift, - LoForNormalShift, CCVal, Cmp); - - SDValue Ops[2] = { Lo, Hi }; - return DAG.getMergeValues(Ops, dl); +/// LowerShiftParts - Lower SHL_PARTS/SRA_PARTS/SRL_PARTS, which returns two +/// i32 values and take a 2 x i32 value to shift plus a shift amount. +SDValue AArch64TargetLowering::LowerShiftParts(SDValue Op, + SelectionDAG &DAG) const { + SDValue Lo, Hi; + expandShiftParts(Op.getNode(), Lo, Hi, DAG); + return DAG.getMergeValues({Lo, Hi}, SDLoc(Op)); } bool AArch64TargetLowering::isOffsetFoldingLegal( @@ -7738,7 +8222,7 @@ AArch64TargetLowering::getRegForInlineAsmConstraint( : std::make_pair(0U, &AArch64::PPRRegClass); } } - if (StringRef("{cc}").equals_lower(Constraint)) + if (StringRef("{cc}").equals_insensitive(Constraint)) return std::make_pair(unsigned(AArch64::NZCV), &AArch64::CCRRegClass); // Use the default implementation in TargetLowering to convert the register @@ -7814,10 +8298,6 @@ void AArch64TargetLowering::LowerAsmOperandForConstraint( dyn_cast<BlockAddressSDNode>(Op)) { Result = DAG.getTargetBlockAddress(BA->getBlockAddress(), BA->getValueType(0)); - } else if (const ExternalSymbolSDNode *ES = - dyn_cast<ExternalSymbolSDNode>(Op)) { - Result = - DAG.getTargetExternalSymbol(ES->getSymbol(), ES->getValueType(0)); } else return; break; @@ -7944,7 +8424,7 @@ static SDValue WidenVector(SDValue V64Reg, SelectionDAG &DAG) { SDLoc DL(V64Reg); return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideTy, DAG.getUNDEF(WideTy), - V64Reg, DAG.getConstant(0, DL, MVT::i32)); + V64Reg, DAG.getConstant(0, DL, MVT::i64)); } /// getExtFactor - Determine the adjustment factor for the position when @@ -8792,6 +9272,9 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode()); + if (useSVEForFixedLengthVectorVT(VT)) + return LowerFixedLengthVECTOR_SHUFFLEToSVE(Op, DAG); + // Convert shuffles that are directly supported on NEON to target-specific // DAG nodes, instead of keeping them as shuffles and matching them again // during code selection. This is more efficient and avoids the possibility @@ -8801,6 +9284,10 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, SDValue V1 = Op.getOperand(0); SDValue V2 = Op.getOperand(1); + assert(V1.getValueType() == VT && "Unexpected VECTOR_SHUFFLE type!"); + assert(ShuffleMask.size() == VT.getVectorNumElements() && + "Unexpected VECTOR_SHUFFLE mask size!"); + if (SVN->isSplat()) { int Lane = SVN->getSplatIndex(); // If this is undef splat, generate it via "just" vdup, if possible. 
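LowerShiftParts now defers to the shared expandShiftParts helper instead of the two hand-written expansions removed above. A reference model of what the SRL_PARTS case computes on two 64-bit halves (an illustrative sketch of the semantics, not the DAG expansion itself):

#include <cstdint>
#include <utility>

// Logical right shift of the 128-bit value Hi:Lo by Amt (0 <= Amt < 128),
// returning the new {Lo, Hi} pair.
std::pair<uint64_t, uint64_t> lsr128(uint64_t Lo, uint64_t Hi, unsigned Amt) {
  if (Amt == 0)
    return {Lo, Hi};                    // avoid the out-of-range 64-bit shift below
  if (Amt < 64)
    return {(Lo >> Amt) | (Hi << (64 - Amt)), Hi >> Amt};
  return {Hi >> (Amt - 64), 0};
}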
@@ -8847,6 +9334,14 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, if (isREVMask(ShuffleMask, VT, 16)) return DAG.getNode(AArch64ISD::REV16, dl, V1.getValueType(), V1, V2); + if (((VT.getVectorNumElements() == 8 && VT.getScalarSizeInBits() == 16) || + (VT.getVectorNumElements() == 16 && VT.getScalarSizeInBits() == 8)) && + ShuffleVectorInst::isReverseMask(ShuffleMask)) { + SDValue Rev = DAG.getNode(AArch64ISD::REV64, dl, VT, V1); + return DAG.getNode(AArch64ISD::EXT, dl, VT, Rev, Rev, + DAG.getConstant(8, dl, MVT::i32)); + } + bool ReverseEXT = false; unsigned Imm; if (isEXTMask(ShuffleMask, VT, ReverseEXT, Imm)) { @@ -9027,9 +9522,7 @@ SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op, SDValue SplatOne = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, One); // create the vector 0,1,0,1,... - SDValue Zero = DAG.getConstant(0, DL, MVT::i64); - SDValue SV = DAG.getNode(AArch64ISD::INDEX_VECTOR, - DL, MVT::nxv2i64, Zero, One); + SDValue SV = DAG.getStepVector(DL, MVT::nxv2i64); SV = DAG.getNode(ISD::AND, DL, MVT::nxv2i64, SV, SplatOne); // create the vector idx64,idx64+1,idx64,idx64+1,... @@ -9556,10 +10049,10 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, } if (i > 0) isOnlyLowElement = false; - if (!isa<ConstantFPSDNode>(V) && !isa<ConstantSDNode>(V)) + if (!isIntOrFPConstant(V)) isConstant = false; - if (isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V)) { + if (isIntOrFPConstant(V)) { ++NumConstantLanes; if (!ConstantValue.getNode()) ConstantValue = V; @@ -9584,7 +10077,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, // Convert BUILD_VECTOR where all elements but the lowest are undef into // SCALAR_TO_VECTOR, except for when we have a single-element constant vector // as SimplifyDemandedBits will just turn that back into BUILD_VECTOR. - if (isOnlyLowElement && !(NumElts == 1 && isa<ConstantSDNode>(Value))) { + if (isOnlyLowElement && !(NumElts == 1 && isIntOrFPConstant(Value))) { LLVM_DEBUG(dbgs() << "LowerBUILD_VECTOR: only low element used, creating 1 " "SCALAR_TO_VECTOR node\n"); return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Value); @@ -9725,7 +10218,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, for (unsigned i = 0; i < NumElts; ++i) { SDValue V = Op.getOperand(i); SDValue LaneIdx = DAG.getConstant(i, dl, MVT::i64); - if (!isa<ConstantSDNode>(V) && !isa<ConstantFPSDNode>(V)) + if (!isIntOrFPConstant(V)) // Note that type legalization likely mucked about with the VT of the // source operand, so we may have to convert it here before inserting. Val = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Val, V, LaneIdx); @@ -9749,9 +10242,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, if (PreferDUPAndInsert) { // First, build a constant vector with the common element. - SmallVector<SDValue, 8> Ops; - for (unsigned I = 0; I < NumElts; ++I) - Ops.push_back(Value); + SmallVector<SDValue, 8> Ops(NumElts, Value); SDValue NewVector = LowerBUILD_VECTOR(DAG.getBuildVector(VT, dl, Ops), DAG); // Next, insert the elements that do not match the common value. 
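The new LowerVECTOR_SHUFFLE case maps a full lane reversal of a 128-bit vector onto REV64 followed by EXT #8. A self-contained C++ check of that equivalence for sixteen byte lanes (illustrative; the helpers model the instruction semantics, not compiler code):

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>

using V16 = std::array<uint8_t, 16>;

// REV64: reverse the byte lanes within each 64-bit half.
V16 rev64(const V16 &A) {
  V16 R;
  for (int H = 0; H < 2; ++H)
    for (int I = 0; I < 8; ++I)
      R[8 * H + I] = A[8 * H + (7 - I)];
  return R;
}

// EXT #8: bytes 8..15 of the first operand followed by bytes 0..7 of the second.
V16 ext8(const V16 &Lo, const V16 &Hi) {
  V16 R;
  for (int I = 0; I < 16; ++I)
    R[I] = I < 8 ? Lo[I + 8] : Hi[I - 8];
  return R;
}

int main() {
  V16 A;
  for (int I = 0; I < 16; ++I)
    A[I] = static_cast<uint8_t>(I);
  V16 Rev = rev64(A);
  V16 Reversed = A;
  std::reverse(Reversed.begin(), Reversed.end());
  assert(ext8(Rev, Rev) == Reversed); // REV64 + EXT #8 == full reverse
}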
for (unsigned I = 0; I < NumElts; ++I) @@ -9813,6 +10304,9 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue AArch64TargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const { + if (useSVEForFixedLengthVectorVT(Op.getValueType())) + return LowerFixedLengthConcatVectorsToSVE(Op, DAG); + assert(Op.getValueType().isScalableVector() && isTypeLegal(Op.getValueType()) && "Expected legal scalable vector type!"); @@ -9827,13 +10321,32 @@ SDValue AArch64TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const { assert(Op.getOpcode() == ISD::INSERT_VECTOR_ELT && "Unknown opcode!"); + if (useSVEForFixedLengthVectorVT(Op.getValueType())) + return LowerFixedLengthInsertVectorElt(Op, DAG); + // Check for non-constant or out of range lane. EVT VT = Op.getOperand(0).getValueType(); + + if (VT.getScalarType() == MVT::i1) { + EVT VectorVT = getPromotedVTForPredicate(VT); + SDLoc DL(Op); + SDValue ExtendedVector = + DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, VectorVT); + SDValue ExtendedValue = + DAG.getAnyExtOrTrunc(Op.getOperand(1), DL, + VectorVT.getScalarType().getSizeInBits() < 32 + ? MVT::i32 + : VectorVT.getScalarType()); + ExtendedVector = + DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VectorVT, ExtendedVector, + ExtendedValue, Op.getOperand(2)); + return DAG.getAnyExtOrTrunc(ExtendedVector, DL, VT); + } + ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(2)); if (!CI || CI->getZExtValue() >= VT.getVectorNumElements()) return SDValue(); - // Insertion/extraction are legal for V128 types. if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64 || @@ -9861,14 +10374,29 @@ SDValue AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const { assert(Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT && "Unknown opcode!"); + EVT VT = Op.getOperand(0).getValueType(); + + if (VT.getScalarType() == MVT::i1) { + // We can't directly extract from an SVE predicate; extend it first. + // (This isn't the only possible lowering, but it's straightforward.) + EVT VectorVT = getPromotedVTForPredicate(VT); + SDLoc DL(Op); + SDValue Extend = + DAG.getNode(ISD::ANY_EXTEND, DL, VectorVT, Op.getOperand(0)); + MVT ExtractTy = VectorVT == MVT::nxv2i64 ? MVT::i64 : MVT::i32; + SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractTy, + Extend, Op.getOperand(1)); + return DAG.getAnyExtOrTrunc(Extract, DL, Op.getValueType()); + } + + if (useSVEForFixedLengthVectorVT(VT)) + return LowerFixedLengthExtractVectorElt(Op, DAG); // Check for non-constant or out of range lane. - EVT VT = Op.getOperand(0).getValueType(); ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(1)); if (!CI || CI->getZExtValue() >= VT.getVectorNumElements()) return SDValue(); - // Insertion/extraction are legal for V128 types. if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64 || @@ -10159,7 +10687,8 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op, unsigned Opc = (Op.getOpcode() == ISD::SRA) ? 
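For predicate (i1) vectors, the new INSERT_VECTOR_ELT path any-extends the vector to an integer element type, inserts there, and truncates back. A fixed-width C++ model of that round trip (illustrative only; a real SVE predicate is scalable):

#include <array>
#include <cstdint>

std::array<bool, 4> insertPredLane(std::array<bool, 4> Pred, bool Val,
                                   unsigned Idx) {
  std::array<int32_t, 4> Wide;     // "any-extended" copy of the predicate
  for (unsigned I = 0; I < 4; ++I)
    Wide[I] = Pred[I] ? 1 : 0;
  Wide[Idx] = Val ? 1 : 0;         // insert into the promoted vector
  std::array<bool, 4> Out;
  for (unsigned I = 0; I < 4; ++I)
    Out[I] = (Wide[I] & 1) != 0;   // truncate each lane back to i1
  return Out;
}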
Intrinsic::aarch64_neon_sshl : Intrinsic::aarch64_neon_ushl; // negate the shift amount - SDValue NegShift = DAG.getNode(AArch64ISD::NEG, DL, VT, Op.getOperand(1)); + SDValue NegShift = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), + Op.getOperand(1)); SDValue NegShiftLeft = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, DAG.getConstant(Opc, DL, MVT::i32), Op.getOperand(0), @@ -10267,11 +10796,8 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS, SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op, SelectionDAG &DAG) const { - if (Op.getValueType().isScalableVector()) { - if (Op.getOperand(0).getValueType().isFloatingPoint()) - return Op; + if (Op.getValueType().isScalableVector()) return LowerToPredicatedOp(Op, DAG, AArch64ISD::SETCC_MERGE_ZERO); - } if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType())) return LowerFixedLengthVectorSetccToSVE(Op, DAG); @@ -11391,8 +11917,8 @@ EVT AArch64TargetLowering::getOptimalMemOpType( if (Op.isAligned(AlignCheck)) return true; bool Fast; - return allowsMisalignedMemoryAccesses(VT, 0, 1, MachineMemOperand::MONone, - &Fast) && + return allowsMisalignedMemoryAccesses(VT, 0, Align(1), + MachineMemOperand::MONone, &Fast) && Fast; }; @@ -11422,14 +11948,14 @@ LLT AArch64TargetLowering::getOptimalMemOpLLT( if (Op.isAligned(AlignCheck)) return true; bool Fast; - return allowsMisalignedMemoryAccesses(VT, 0, 1, MachineMemOperand::MONone, - &Fast) && + return allowsMisalignedMemoryAccesses(VT, 0, Align(1), + MachineMemOperand::MONone, &Fast) && Fast; }; if (CanUseNEON && Op.isMemset() && !IsSmallMemset && AlignmentIsAcceptable(MVT::v2i64, Align(16))) - return LLT::vector(2, 64); + return LLT::fixed_vector(2, 64); if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16))) return LLT::scalar(128); if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8))) @@ -11482,8 +12008,12 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL, return false; // FIXME: Update this method to support scalable addressing modes. - if (isa<ScalableVectorType>(Ty)) - return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale; + if (isa<ScalableVectorType>(Ty)) { + uint64_t VecElemNumBytes = + DL.getTypeSizeInBits(cast<VectorType>(Ty)->getElementType()) / 8; + return AM.HasBaseReg && !AM.BaseOffs && + (AM.Scale == 0 || (uint64_t)AM.Scale == VecElemNumBytes); + } // check reg + imm case: // i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12 @@ -11521,9 +12051,8 @@ bool AArch64TargetLowering::shouldConsiderGEPOffsetSplit() const { return true; } -int AArch64TargetLowering::getScalingFactorCost(const DataLayout &DL, - const AddrMode &AM, Type *Ty, - unsigned AS) const { +InstructionCost AArch64TargetLowering::getScalingFactorCost( + const DataLayout &DL, const AddrMode &AM, Type *Ty, unsigned AS) const { // Scaling factors are not free at all. 
// Operands | Rt Latency // ------------------------------------------- @@ -11546,6 +12075,8 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd( return false; switch (VT.getSimpleVT().SimpleTy) { + case MVT::f16: + return Subtarget->hasFullFP16(); case MVT::f32: case MVT::f64: return true; @@ -11567,6 +12098,11 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(const Function &F, } } +bool AArch64TargetLowering::generateFMAsInMachineCombiner( + EVT VT, CodeGenOpt::Level OptLevel) const { + return (OptLevel >= CodeGenOpt::Aggressive) && !VT.isScalableVector(); +} + const MCPhysReg * AArch64TargetLowering::getScratchRegisters(CallingConv::ID) const { // LR is a callee-save register, but we must treat it as clobbered by any call @@ -11654,77 +12190,142 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG, return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0)); } -// VECREDUCE_ADD( EXTEND(v16i8_type) ) to -// VECREDUCE_ADD( DOTv16i8(v16i8_type) ) -static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG, - const AArch64Subtarget *ST) { - SDValue Op0 = N->getOperand(0); - if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32) +// Given a vecreduce_add node, detect the below pattern and convert it to the +// node sequence with UABDL, [S|U]ADB and UADDLP. +// +// i32 vecreduce_add( +// v16i32 abs( +// v16i32 sub( +// v16i32 [sign|zero]_extend(v16i8 a), v16i32 [sign|zero]_extend(v16i8 b)))) +// =================> +// i32 vecreduce_add( +// v4i32 UADDLP( +// v8i16 add( +// v8i16 zext( +// v8i8 [S|U]ABD low8:v16i8 a, low8:v16i8 b +// v8i16 zext( +// v8i8 [S|U]ABD high8:v16i8 a, high8:v16i8 b +static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N, + SelectionDAG &DAG) { + // Assumed i32 vecreduce_add + if (N->getValueType(0) != MVT::i32) return SDValue(); - if (Op0.getValueType().getVectorElementType() != MVT::i32) + SDValue VecReduceOp0 = N->getOperand(0); + unsigned Opcode = VecReduceOp0.getOpcode(); + // Assumed v16i32 abs + if (Opcode != ISD::ABS || VecReduceOp0->getValueType(0) != MVT::v16i32) return SDValue(); - unsigned ExtOpcode = Op0.getOpcode(); - if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND) + SDValue ABS = VecReduceOp0; + // Assumed v16i32 sub + if (ABS->getOperand(0)->getOpcode() != ISD::SUB || + ABS->getOperand(0)->getValueType(0) != MVT::v16i32) return SDValue(); - EVT Op0VT = Op0.getOperand(0).getValueType(); - if (Op0VT != MVT::v16i8) + SDValue SUB = ABS->getOperand(0); + unsigned Opcode0 = SUB->getOperand(0).getOpcode(); + unsigned Opcode1 = SUB->getOperand(1).getOpcode(); + // Assumed v16i32 type + if (SUB->getOperand(0)->getValueType(0) != MVT::v16i32 || + SUB->getOperand(1)->getValueType(0) != MVT::v16i32) return SDValue(); - SDLoc DL(Op0); - SDValue Ones = DAG.getConstant(1, DL, Op0VT); - SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32); - auto DotIntrisic = (ExtOpcode == ISD::ZERO_EXTEND) - ? Intrinsic::aarch64_neon_udot - : Intrinsic::aarch64_neon_sdot; - SDValue Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Zeros.getValueType(), - DAG.getConstant(DotIntrisic, DL, MVT::i32), Zeros, - Ones, Op0.getOperand(0)); - return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot); -} - -// Given a ABS node, detect the following pattern: -// (ABS (SUB (EXTEND a), (EXTEND b))). -// Generates UABD/SABD instruction. 
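The UADDLP rewrite described in the comment above is a sum-of-absolute-differences computation. A scalar C++ reference showing that the rewritten sequence (two 8-lane absolute differences, widened, added, pairwise-added, then reduced) matches the original reduce-of-abs-of-sub for the unsigned case (illustrative sketch):

#include <cassert>
#include <cstdint>
#include <cstdlib>

// Original form: i32 vecreduce_add(abs(zext(a) - zext(b))) over 16 byte lanes.
uint32_t sadDirect(const uint8_t A[16], const uint8_t B[16]) {
  uint32_t Sum = 0;
  for (int I = 0; I < 16; ++I)
    Sum += static_cast<uint32_t>(std::abs(int(A[I]) - int(B[I])));
  return Sum;
}

// Rewritten form: UABD on each half, zero-extend to 16 bits, add (UABAL),
// pairwise widen-and-add (UADDLP), then the final vecreduce_add.
uint32_t sadRewritten(const uint8_t A[16], const uint8_t B[16]) {
  uint16_t Acc[8];
  for (int I = 0; I < 8; ++I) {
    uint16_t Lo = std::abs(int(A[I]) - int(B[I]));         // UABD, low half
    uint16_t Hi = std::abs(int(A[I + 8]) - int(B[I + 8])); // UABD, high half
    Acc[I] = Lo + Hi;                  // each term <= 255, so no overflow
  }
  uint32_t Pairs[4];
  for (int I = 0; I < 4; ++I)
    Pairs[I] = uint32_t(Acc[2 * I]) + Acc[2 * I + 1];      // UADDLP
  return Pairs[0] + Pairs[1] + Pairs[2] + Pairs[3];
}

int main() {
  uint8_t A[16], B[16];
  for (int I = 0; I < 16; ++I) {
    A[I] = static_cast<uint8_t>(3 * I + 1);
    B[I] = static_cast<uint8_t>(250 - 7 * I);
  }
  assert(sadDirect(A, B) == sadRewritten(A, B));
}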
-static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI, - const AArch64Subtarget *Subtarget) { - SDValue AbsOp1 = N->getOperand(0); - SDValue Op0, Op1; + // Assumed zext or sext + bool IsZExt = false; + if (Opcode0 == ISD::ZERO_EXTEND && Opcode1 == ISD::ZERO_EXTEND) { + IsZExt = true; + } else if (Opcode0 == ISD::SIGN_EXTEND && Opcode1 == ISD::SIGN_EXTEND) { + IsZExt = false; + } else + return SDValue(); - if (AbsOp1.getOpcode() != ISD::SUB) + SDValue EXT0 = SUB->getOperand(0); + SDValue EXT1 = SUB->getOperand(1); + // Assumed zext's operand has v16i8 type + if (EXT0->getOperand(0)->getValueType(0) != MVT::v16i8 || + EXT1->getOperand(0)->getValueType(0) != MVT::v16i8) return SDValue(); - Op0 = AbsOp1.getOperand(0); - Op1 = AbsOp1.getOperand(1); + // Pattern is dectected. Let's convert it to sequence of nodes. + SDLoc DL(N); + + // First, create the node pattern of UABD/SABD. + SDValue UABDHigh8Op0 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0), + DAG.getConstant(8, DL, MVT::i64)); + SDValue UABDHigh8Op1 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0), + DAG.getConstant(8, DL, MVT::i64)); + SDValue UABDHigh8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8, + UABDHigh8Op0, UABDHigh8Op1); + SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8); + + // Second, create the node pattern of UABAL. + SDValue UABDLo8Op0 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0), + DAG.getConstant(0, DL, MVT::i64)); + SDValue UABDLo8Op1 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0), + DAG.getConstant(0, DL, MVT::i64)); + SDValue UABDLo8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8, + UABDLo8Op0, UABDLo8Op1); + SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8); + SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD); + + // Third, create the node of UADDLP. + SDValue UADDLP = DAG.getNode(AArch64ISD::UADDLP, DL, MVT::v4i32, UABAL); + + // Fourth, create the node of VECREDUCE_ADD. + return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP); +} + +// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce +// vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one)) +// vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B)) +static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG, + const AArch64Subtarget *ST) { + if (!ST->hasDotProd()) + return performVecReduceAddCombineWithUADDLP(N, DAG); - unsigned Opc0 = Op0.getOpcode(); - // Check if the operands of the sub are (zero|sign)-extended. - if (Opc0 != Op1.getOpcode() || - (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND)) + SDValue Op0 = N->getOperand(0); + if (N->getValueType(0) != MVT::i32 || + Op0.getValueType().getVectorElementType() != MVT::i32) return SDValue(); - EVT VectorT1 = Op0.getOperand(0).getValueType(); - EVT VectorT2 = Op1.getOperand(0).getValueType(); - // Check if vectors are of same type and valid size. 
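The dot-product path rests on two identities: summing extended byte lanes equals a dot product against an all-ones vector, and summing products of two extended vectors equals a dot product of the original narrow vectors. A scalar C++ model of the unsigned v16i8 case (illustrative sketch):

#include <cassert>
#include <cstdint>

// vecreduce.add(zext(A) * zext(B)) computed directly.
uint32_t reduceMulZext(const uint8_t A[16], const uint8_t B[16]) {
  uint32_t Sum = 0;
  for (int I = 0; I < 16; ++I)
    Sum += uint32_t(A[I]) * B[I];
  return Sum;
}

// The same value via UDOT: each of the four i32 lanes accumulates four byte
// products, and the final vecreduce.add sums the four lanes.
uint32_t reduceViaUdot(const uint8_t A[16], const uint8_t B[16]) {
  uint32_t Lanes[4] = {0, 0, 0, 0};       // the zero accumulator operand
  for (int L = 0; L < 4; ++L)
    for (int J = 0; J < 4; ++J)
      Lanes[L] += uint32_t(A[4 * L + J]) * B[4 * L + J];
  return Lanes[0] + Lanes[1] + Lanes[2] + Lanes[3];
}

int main() {
  uint8_t A[16], B[16], Ones[16];
  uint32_t PlainSum = 0;
  for (int I = 0; I < 16; ++I) {
    A[I] = static_cast<uint8_t>(17 * I + 5);
    B[I] = static_cast<uint8_t>(200 - 9 * I);
    Ones[I] = 1;
    PlainSum += A[I];
  }
  assert(reduceViaUdot(A, B) == reduceMulZext(A, B)); // mul(ext, ext) form
  assert(reduceViaUdot(A, Ones) == PlainSum);         // plain ext form, B == splat(1)
}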
- uint64_t Size = VectorT1.getFixedSizeInBits(); - if (VectorT1 != VectorT2 || (Size != 64 && Size != 128)) + unsigned ExtOpcode = Op0.getOpcode(); + SDValue A = Op0; + SDValue B; + if (ExtOpcode == ISD::MUL) { + A = Op0.getOperand(0); + B = Op0.getOperand(1); + if (A.getOpcode() != B.getOpcode() || + A.getOperand(0).getValueType() != B.getOperand(0).getValueType()) + return SDValue(); + ExtOpcode = A.getOpcode(); + } + if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND) return SDValue(); - // Check if vector element types are valid. - EVT VT1 = VectorT1.getVectorElementType(); - if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32) + EVT Op0VT = A.getOperand(0).getValueType(); + if (Op0VT != MVT::v8i8 && Op0VT != MVT::v16i8) return SDValue(); - Op0 = Op0.getOperand(0); - Op1 = Op1.getOperand(0); - unsigned ABDOpcode = - (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD; - SDValue ABD = - DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1); - return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD); + SDLoc DL(Op0); + // For non-mla reductions B can be set to 1. For MLA we take the operand of + // the extend B. + if (!B) + B = DAG.getConstant(1, DL, Op0VT); + else + B = B.getOperand(0); + + SDValue Zeros = + DAG.getConstant(0, DL, Op0VT == MVT::v8i8 ? MVT::v2i32 : MVT::v4i32); + auto DotOpcode = + (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT; + SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, + A.getOperand(0), B); + return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot); } static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG, @@ -11972,6 +12573,7 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG, // e.g. 6=3*2=(2+1)*2. // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45 // which equals to (1+2)*16-(1+2). + // TrailingZeroes is used to test if the mul can be lowered to // shift+add+shift. unsigned TrailingZeroes = ConstValue.countTrailingZeros(); @@ -12350,6 +12952,11 @@ static SDValue tryCombineToBSL(SDNode *N, if (!VT.isVector()) return SDValue(); + // The combining code currently only works for NEON vectors. In particular, + // it does not work for SVE when dealing with vectors wider than 128 bits. + if (!VT.is64BitVector() && !VT.is128BitVector()) + return SDValue(); + SDValue N0 = N->getOperand(0); if (N0.getOpcode() != ISD::AND) return SDValue(); @@ -12358,6 +12965,44 @@ static SDValue tryCombineToBSL(SDNode *N, if (N1.getOpcode() != ISD::AND) return SDValue(); + // InstCombine does (not (neg a)) => (add a -1). + // Try: (or (and (neg a) b) (and (add a -1) c)) => (bsl (neg a) b c) + // Loop over all combinations of AND operands. + for (int i = 1; i >= 0; --i) { + for (int j = 1; j >= 0; --j) { + SDValue O0 = N0->getOperand(i); + SDValue O1 = N1->getOperand(j); + SDValue Sub, Add, SubSibling, AddSibling; + + // Find a SUB and an ADD operand, one from each AND. + if (O0.getOpcode() == ISD::SUB && O1.getOpcode() == ISD::ADD) { + Sub = O0; + Add = O1; + SubSibling = N0->getOperand(1 - i); + AddSibling = N1->getOperand(1 - j); + } else if (O0.getOpcode() == ISD::ADD && O1.getOpcode() == ISD::SUB) { + Add = O0; + Sub = O1; + AddSibling = N0->getOperand(1 - i); + SubSibling = N1->getOperand(1 - j); + } else + continue; + + if (!ISD::isBuildVectorAllZeros(Sub.getOperand(0).getNode())) + continue; + + // Constant ones is always righthand operand of the Add. 
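The new tryCombineToBSL case relies on a two's-complement identity: with M = 0 - a, the value a - 1 equals ~M, so (M & b) | ((a - 1) & c) is a bitwise select on M. A short C++ check (illustrative sketch):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t A = 0x00F0u, B = 0xDEADBEEFu, C = 0x12345678u;
  const uint32_t M = 0u - A;                   // the (neg a) operand
  assert(A - 1 == ~M);                         // InstCombine's (not (neg a)) => (add a -1)
  uint32_t OrOfAnds = (M & B) | ((A - 1) & C); // the pattern as matched
  uint32_t Bsl = (M & B) | (~M & C);           // BSL/BSP: take B where M is set, else C
  assert(OrOfAnds == Bsl);
}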
+ if (!ISD::isBuildVectorAllOnes(Add.getOperand(1).getNode())) + continue; + + if (Sub.getOperand(1) != Add.getOperand(0)) + continue; + + return DAG.getNode(AArch64ISD::BSP, DL, VT, Sub, SubSibling, AddSibling); + } + } + + // (or (and a b) (and (not a) c)) => (bsl a b c) // We only have to look for constant vectors here since the general, variable // case can be handled in TableGen. unsigned Bits = VT.getScalarSizeInBits(); @@ -13065,6 +13710,13 @@ static SDValue performSetccAddFolding(SDNode *Op, SelectionDAG &DAG) { SDValue RHS = Op->getOperand(1); SetCCInfoAndKind InfoAndKind; + // If both operands are a SET_CC, then we don't want to perform this + // folding and create another csel as this results in more instructions + // (and higher register usage). + if (isSetCCOrZExtSetCC(LHS, InfoAndKind) && + isSetCCOrZExtSetCC(RHS, InfoAndKind)) + return SDValue(); + // If neither operand is a SET_CC, give up. if (!isSetCCOrZExtSetCC(LHS, InfoAndKind)) { std::swap(LHS, RHS); @@ -13135,6 +13787,29 @@ static SDValue performUADDVCombine(SDNode *N, SelectionDAG &DAG) { DAG.getConstant(0, DL, MVT::i64)); } +// ADD(UDOT(zero, x, y), A) --> UDOT(A, x, y) +static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + if (N->getOpcode() != ISD::ADD) + return SDValue(); + + SDValue Dot = N->getOperand(0); + SDValue A = N->getOperand(1); + // Handle commutivity + auto isZeroDot = [](SDValue Dot) { + return (Dot.getOpcode() == AArch64ISD::UDOT || + Dot.getOpcode() == AArch64ISD::SDOT) && + isZerosVector(Dot.getOperand(0).getNode()); + }; + if (!isZeroDot(Dot)) + std::swap(Dot, A); + if (!isZeroDot(Dot)) + return SDValue(); + + return DAG.getNode(Dot.getOpcode(), SDLoc(N), VT, A, Dot.getOperand(1), + Dot.getOperand(2)); +} + // The basic add/sub long vector instructions have variants with "2" on the end // which act on the high-half of their inputs. They are normally matched by // patterns like: @@ -13194,6 +13869,8 @@ static SDValue performAddSubCombine(SDNode *N, // Try to change sum of two reductions. if (SDValue Val = performUADDVCombine(N, DAG)) return Val; + if (SDValue Val = performAddDotCombine(N, DAG)) + return Val; return performAddSubLongCombine(N, DCI, DAG); } @@ -13335,15 +14012,16 @@ static SDValue LowerSVEIntrinsicIndex(SDNode *N, SelectionDAG &DAG) { SDLoc DL(N); SDValue Op1 = N->getOperand(1); SDValue Op2 = N->getOperand(2); - EVT ScalarTy = Op1.getValueType(); - - if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16)) { - Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1); - Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2); - } + EVT ScalarTy = Op2.getValueType(); + if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16)) + ScalarTy = MVT::i32; - return DAG.getNode(AArch64ISD::INDEX_VECTOR, DL, N->getValueType(0), - Op1, Op2); + // Lower index_vector(base, step) to mul(step step_vector(1)) + splat(base). 
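As the comment above says, index_vector(base, step) is now built as splat(base) + step * step_vector. Lane-wise both expressions produce base + i * step; a tiny C++ model of the lowered form (illustrative sketch):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t Base = 5, Step = -3;
  int64_t StepVec[8];
  for (int64_t I = 0; I < 8; ++I)
    StepVec[I] = I;                              // step_vector: 0, 1, 2, ...
  for (int64_t I = 0; I < 8; ++I) {
    int64_t IndexLane = Base + I * Step;         // lane I of index_vector(base, step)
    int64_t LoweredLane = StepVec[I] * Step + Base; // mul(step_vector, splat(step)) + splat(base)
    assert(IndexLane == LoweredLane);
  }
}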
+ SDValue StepVector = DAG.getStepVector(DL, N->getValueType(0)); + SDValue Step = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op2); + SDValue Mul = DAG.getNode(ISD::MUL, DL, N->getValueType(0), StepVector, Step); + SDValue Base = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op1); + return DAG.getNode(ISD::ADD, DL, N->getValueType(0), Mul, Base); } static SDValue LowerSVEIntrinsicDUP(SDNode *N, SelectionDAG &DAG) { @@ -13533,20 +14211,47 @@ static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc, Zero); } +static bool isAllActivePredicate(SDValue N) { + unsigned NumElts = N.getValueType().getVectorMinNumElements(); + + // Look through cast. + while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) { + N = N.getOperand(0); + // When reinterpreting from a type with fewer elements the "new" elements + // are not active, so bail if they're likely to be used. + if (N.getValueType().getVectorMinNumElements() < NumElts) + return false; + } + + // "ptrue p.<ty>, all" can be considered all active when <ty> is the same size + // or smaller than the implicit element type represented by N. + // NOTE: A larger element count implies a smaller element type. + if (N.getOpcode() == AArch64ISD::PTRUE && + N.getConstantOperandVal(0) == AArch64SVEPredPattern::all) + return N.getValueType().getVectorMinNumElements() >= NumElts; + + return false; +} + // If a merged operation has no inactive lanes we can relax it to a predicated // or unpredicated operation, which potentially allows better isel (perhaps // using immediate forms) or relaxing register reuse requirements. -static SDValue convertMergedOpToPredOp(SDNode *N, unsigned PredOpc, - SelectionDAG &DAG) { +static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc, + SelectionDAG &DAG, + bool UnpredOp = false) { assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Expected intrinsic!"); assert(N->getNumOperands() == 4 && "Expected 3 operand intrinsic!"); SDValue Pg = N->getOperand(1); // ISD way to specify an all active predicate. 
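convertMergedOpToPredOp (together with the new isAllActivePredicate check) relaxes a merged SVE operation to its plain form when the governing predicate covers every lane, since merging only matters for inactive lanes. A fixed-width C++ model of that observation (illustrative only; real SVE vectors are scalable):

#include <array>
#include <cstdint>

using V4 = std::array<int32_t, 4>;
using P4 = std::array<bool, 4>;

// Merged form: active lanes get A + B, inactive lanes keep A.
V4 addMerged(const P4 &Pg, const V4 &A, const V4 &B) {
  V4 R;
  for (unsigned I = 0; I < 4; ++I)
    R[I] = Pg[I] ? A[I] + B[I] : A[I];
  return R;
}
// With Pg all-true there are no inactive lanes, so addMerged(Pg, A, B) is just
// the unpredicated lane-wise add, which is what the combine emits instead.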
- if ((Pg.getOpcode() == AArch64ISD::PTRUE) && - (Pg.getConstantOperandVal(0) == AArch64SVEPredPattern::all)) - return DAG.getNode(PredOpc, SDLoc(N), N->getValueType(0), Pg, - N->getOperand(2), N->getOperand(3)); + if (isAllActivePredicate(Pg)) { + if (UnpredOp) + return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), N->getOperand(2), + N->getOperand(3)); + else + return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Pg, + N->getOperand(2), N->getOperand(3)); + } // FUTURE: SplatVector(true) return SDValue(); @@ -13637,6 +14342,12 @@ static SDValue performIntrinsicCombine(SDNode *N, N->getOperand(1)); case Intrinsic::aarch64_sve_ext: return LowerSVEIntrinsicEXT(N, DAG); + case Intrinsic::aarch64_sve_mul: + return convertMergedOpToPredOp(N, AArch64ISD::MUL_PRED, DAG); + case Intrinsic::aarch64_sve_smulh: + return convertMergedOpToPredOp(N, AArch64ISD::MULHS_PRED, DAG); + case Intrinsic::aarch64_sve_umulh: + return convertMergedOpToPredOp(N, AArch64ISD::MULHU_PRED, DAG); case Intrinsic::aarch64_sve_smin: return convertMergedOpToPredOp(N, AArch64ISD::SMIN_PRED, DAG); case Intrinsic::aarch64_sve_umin: @@ -13651,6 +14362,44 @@ static SDValue performIntrinsicCombine(SDNode *N, return convertMergedOpToPredOp(N, AArch64ISD::SRL_PRED, DAG); case Intrinsic::aarch64_sve_asr: return convertMergedOpToPredOp(N, AArch64ISD::SRA_PRED, DAG); + case Intrinsic::aarch64_sve_fadd: + return convertMergedOpToPredOp(N, AArch64ISD::FADD_PRED, DAG); + case Intrinsic::aarch64_sve_fsub: + return convertMergedOpToPredOp(N, AArch64ISD::FSUB_PRED, DAG); + case Intrinsic::aarch64_sve_fmul: + return convertMergedOpToPredOp(N, AArch64ISD::FMUL_PRED, DAG); + case Intrinsic::aarch64_sve_add: + return convertMergedOpToPredOp(N, ISD::ADD, DAG, true); + case Intrinsic::aarch64_sve_sub: + return convertMergedOpToPredOp(N, ISD::SUB, DAG, true); + case Intrinsic::aarch64_sve_and: + return convertMergedOpToPredOp(N, ISD::AND, DAG, true); + case Intrinsic::aarch64_sve_bic: + return convertMergedOpToPredOp(N, AArch64ISD::BIC, DAG, true); + case Intrinsic::aarch64_sve_eor: + return convertMergedOpToPredOp(N, ISD::XOR, DAG, true); + case Intrinsic::aarch64_sve_orr: + return convertMergedOpToPredOp(N, ISD::OR, DAG, true); + case Intrinsic::aarch64_sve_sqadd: + return convertMergedOpToPredOp(N, ISD::SADDSAT, DAG, true); + case Intrinsic::aarch64_sve_sqsub: + return convertMergedOpToPredOp(N, ISD::SSUBSAT, DAG, true); + case Intrinsic::aarch64_sve_uqadd: + return convertMergedOpToPredOp(N, ISD::UADDSAT, DAG, true); + case Intrinsic::aarch64_sve_uqsub: + return convertMergedOpToPredOp(N, ISD::USUBSAT, DAG, true); + case Intrinsic::aarch64_sve_sqadd_x: + return DAG.getNode(ISD::SADDSAT, SDLoc(N), N->getValueType(0), + N->getOperand(1), N->getOperand(2)); + case Intrinsic::aarch64_sve_sqsub_x: + return DAG.getNode(ISD::SSUBSAT, SDLoc(N), N->getValueType(0), + N->getOperand(1), N->getOperand(2)); + case Intrinsic::aarch64_sve_uqadd_x: + return DAG.getNode(ISD::UADDSAT, SDLoc(N), N->getValueType(0), + N->getOperand(1), N->getOperand(2)); + case Intrinsic::aarch64_sve_uqsub_x: + return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0), + N->getOperand(1), N->getOperand(2)); case Intrinsic::aarch64_sve_cmphs: if (!N->getOperand(2).getValueType().isFloatingPoint()) return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N), @@ -13663,29 +14412,34 @@ static SDValue performIntrinsicCombine(SDNode *N, N->getValueType(0), N->getOperand(1), N->getOperand(2), N->getOperand(3), DAG.getCondCode(ISD::SETUGT)); break; + case 
Intrinsic::aarch64_sve_fcmpge: case Intrinsic::aarch64_sve_cmpge: - if (!N->getOperand(2).getValueType().isFloatingPoint()) - return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N), - N->getValueType(0), N->getOperand(1), N->getOperand(2), - N->getOperand(3), DAG.getCondCode(ISD::SETGE)); + return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N), + N->getValueType(0), N->getOperand(1), N->getOperand(2), + N->getOperand(3), DAG.getCondCode(ISD::SETGE)); break; + case Intrinsic::aarch64_sve_fcmpgt: case Intrinsic::aarch64_sve_cmpgt: - if (!N->getOperand(2).getValueType().isFloatingPoint()) - return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N), - N->getValueType(0), N->getOperand(1), N->getOperand(2), - N->getOperand(3), DAG.getCondCode(ISD::SETGT)); + return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N), + N->getValueType(0), N->getOperand(1), N->getOperand(2), + N->getOperand(3), DAG.getCondCode(ISD::SETGT)); break; + case Intrinsic::aarch64_sve_fcmpeq: case Intrinsic::aarch64_sve_cmpeq: - if (!N->getOperand(2).getValueType().isFloatingPoint()) - return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N), - N->getValueType(0), N->getOperand(1), N->getOperand(2), - N->getOperand(3), DAG.getCondCode(ISD::SETEQ)); + return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N), + N->getValueType(0), N->getOperand(1), N->getOperand(2), + N->getOperand(3), DAG.getCondCode(ISD::SETEQ)); break; + case Intrinsic::aarch64_sve_fcmpne: case Intrinsic::aarch64_sve_cmpne: - if (!N->getOperand(2).getValueType().isFloatingPoint()) - return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N), - N->getValueType(0), N->getOperand(1), N->getOperand(2), - N->getOperand(3), DAG.getCondCode(ISD::SETNE)); + return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N), + N->getValueType(0), N->getOperand(1), N->getOperand(2), + N->getOperand(3), DAG.getCondCode(ISD::SETNE)); + break; + case Intrinsic::aarch64_sve_fcmpuo: + return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N), + N->getValueType(0), N->getOperand(1), N->getOperand(2), + N->getOperand(3), DAG.getCondCode(ISD::SETUO)); break; case Intrinsic::aarch64_sve_fadda: return combineSVEReductionOrderedFP(N, AArch64ISD::FADDA_PRED, DAG); @@ -13743,8 +14497,8 @@ static SDValue performExtendCombine(SDNode *N, // helps the backend to decide that an sabdl2 would be useful, saving a real // extract_high operation. if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND && - (N->getOperand(0).getOpcode() == AArch64ISD::UABD || - N->getOperand(0).getOpcode() == AArch64ISD::SABD)) { + (N->getOperand(0).getOpcode() == ISD::ABDU || + N->getOperand(0).getOpcode() == ISD::ABDS)) { SDNode *ABDNode = N->getOperand(0).getNode(); SDValue NewABD = tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG); @@ -13753,78 +14507,7 @@ static SDValue performExtendCombine(SDNode *N, return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), NewABD); } - - // This is effectively a custom type legalization for AArch64. - // - // Type legalization will split an extend of a small, legal, type to a larger - // illegal type by first splitting the destination type, often creating - // illegal source types, which then get legalized in isel-confusing ways, - // leading to really terrible codegen. E.g., - // %result = v8i32 sext v8i8 %value - // becomes - // %losrc = extract_subreg %value, ... - // %hisrc = extract_subreg %value, ... - // %lo = v4i32 sext v4i8 %losrc - // %hi = v4i32 sext v4i8 %hisrc - // Things go rapidly downhill from there. 
- // - // For AArch64, the [sz]ext vector instructions can only go up one element - // size, so we can, e.g., extend from i8 to i16, but to go from i8 to i32 - // take two instructions. - // - // This implies that the most efficient way to do the extend from v8i8 - // to two v4i32 values is to first extend the v8i8 to v8i16, then do - // the normal splitting to happen for the v8i16->v8i32. - - // This is pre-legalization to catch some cases where the default - // type legalization will create ill-tempered code. - if (!DCI.isBeforeLegalizeOps()) - return SDValue(); - - // We're only interested in cleaning things up for non-legal vector types - // here. If both the source and destination are legal, things will just - // work naturally without any fiddling. - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - EVT ResVT = N->getValueType(0); - if (!ResVT.isVector() || TLI.isTypeLegal(ResVT)) - return SDValue(); - // If the vector type isn't a simple VT, it's beyond the scope of what - // we're worried about here. Let legalization do its thing and hope for - // the best. - SDValue Src = N->getOperand(0); - EVT SrcVT = Src->getValueType(0); - if (!ResVT.isSimple() || !SrcVT.isSimple()) - return SDValue(); - - // If the source VT is a 64-bit fixed or scalable vector, we can play games - // and get the better results we want. - if (SrcVT.getSizeInBits().getKnownMinSize() != 64) - return SDValue(); - - unsigned SrcEltSize = SrcVT.getScalarSizeInBits(); - ElementCount SrcEC = SrcVT.getVectorElementCount(); - SrcVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize * 2), SrcEC); - SDLoc DL(N); - Src = DAG.getNode(N->getOpcode(), DL, SrcVT, Src); - - // Now split the rest of the operation into two halves, each with a 64 - // bit source. - EVT LoVT, HiVT; - SDValue Lo, Hi; - LoVT = HiVT = ResVT.getHalfNumVectorElementsVT(*DAG.getContext()); - - EVT InNVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getVectorElementType(), - LoVT.getVectorElementCount()); - Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InNVT, Src, - DAG.getConstant(0, DL, MVT::i64)); - Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InNVT, Src, - DAG.getConstant(InNVT.getVectorMinNumElements(), DL, MVT::i64)); - Lo = DAG.getNode(N->getOpcode(), DL, LoVT, Lo); - Hi = DAG.getNode(N->getOpcode(), DL, HiVT, Hi); - - // Now combine the parts back together so we still have a single result - // like the combiner expects. 
- return DAG.getNode(ISD::CONCAT_VECTORS, DL, ResVT, Lo, Hi); + return SDValue(); } static SDValue splitStoreSplat(SelectionDAG &DAG, StoreSDNode &St, @@ -14234,6 +14917,16 @@ static SDValue splitStores(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, S->getMemOperand()->getFlags()); } +static SDValue performSpliceCombine(SDNode *N, SelectionDAG &DAG) { + assert(N->getOpcode() == AArch64ISD::SPLICE && "Unexepected Opcode!"); + + // splice(pg, op1, undef) -> op1 + if (N->getOperand(2).isUndef()) + return N->getOperand(1); + + return SDValue(); +} + static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG) { SDLoc DL(N); SDValue Op0 = N->getOperand(0); @@ -14259,6 +14952,86 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG) { return SDValue(); } +static SDValue performGLD1Combine(SDNode *N, SelectionDAG &DAG) { + unsigned Opc = N->getOpcode(); + + assert(((Opc >= AArch64ISD::GLD1_MERGE_ZERO && // unsigned gather loads + Opc <= AArch64ISD::GLD1_IMM_MERGE_ZERO) || + (Opc >= AArch64ISD::GLD1S_MERGE_ZERO && // signed gather loads + Opc <= AArch64ISD::GLD1S_IMM_MERGE_ZERO)) && + "Invalid opcode."); + + const bool Scaled = Opc == AArch64ISD::GLD1_SCALED_MERGE_ZERO || + Opc == AArch64ISD::GLD1S_SCALED_MERGE_ZERO; + const bool Signed = Opc == AArch64ISD::GLD1S_MERGE_ZERO || + Opc == AArch64ISD::GLD1S_SCALED_MERGE_ZERO; + const bool Extended = Opc == AArch64ISD::GLD1_SXTW_MERGE_ZERO || + Opc == AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO || + Opc == AArch64ISD::GLD1_UXTW_MERGE_ZERO || + Opc == AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO; + + SDLoc DL(N); + SDValue Chain = N->getOperand(0); + SDValue Pg = N->getOperand(1); + SDValue Base = N->getOperand(2); + SDValue Offset = N->getOperand(3); + SDValue Ty = N->getOperand(4); + + EVT ResVT = N->getValueType(0); + + const auto OffsetOpc = Offset.getOpcode(); + const bool OffsetIsZExt = + OffsetOpc == AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU; + const bool OffsetIsSExt = + OffsetOpc == AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU; + + // Fold sign/zero extensions of vector offsets into GLD1 nodes where possible. + if (!Extended && (OffsetIsSExt || OffsetIsZExt)) { + SDValue ExtPg = Offset.getOperand(0); + VTSDNode *ExtFrom = cast<VTSDNode>(Offset.getOperand(2).getNode()); + EVT ExtFromEVT = ExtFrom->getVT().getVectorElementType(); + + // If the predicate for the sign- or zero-extended offset is the + // same as the predicate used for this load and the sign-/zero-extension + // was from a 32-bits... + if (ExtPg == Pg && ExtFromEVT == MVT::i32) { + SDValue UnextendedOffset = Offset.getOperand(1); + + unsigned NewOpc = getGatherVecOpcode(Scaled, OffsetIsSExt, true); + if (Signed) + NewOpc = getSignExtendedGatherOpcode(NewOpc); + + return DAG.getNode(NewOpc, DL, {ResVT, MVT::Other}, + {Chain, Pg, Base, UnextendedOffset, Ty}); + } + } + + return SDValue(); +} + +/// Optimize a vector shift instruction and its operand if shifted out +/// bits are not used. 
+static SDValue performVectorShiftCombine(SDNode *N, + const AArch64TargetLowering &TLI, + TargetLowering::DAGCombinerInfo &DCI) { + assert(N->getOpcode() == AArch64ISD::VASHR || + N->getOpcode() == AArch64ISD::VLSHR); + + SDValue Op = N->getOperand(0); + unsigned OpScalarSize = Op.getScalarValueSizeInBits(); + + unsigned ShiftImm = N->getConstantOperandVal(1); + assert(OpScalarSize > ShiftImm && "Invalid shift imm"); + + APInt ShiftedOutBits = APInt::getLowBitsSet(OpScalarSize, ShiftImm); + APInt DemandedMask = ~ShiftedOutBits; + + if (TLI.SimplifyDemandedBits(Op, DemandedMask, DCI)) + return SDValue(N, 0); + + return SDValue(); +} + /// Target-specific DAG combine function for post-increment LD1 (lane) and /// post-increment LD1R. static SDValue performPostLD1Combine(SDNode *N, @@ -14383,6 +15156,29 @@ static bool performTBISimplification(SDValue Addr, return false; } +static SDValue foldTruncStoreOfExt(SelectionDAG &DAG, SDNode *N) { + assert((N->getOpcode() == ISD::STORE || N->getOpcode() == ISD::MSTORE) && + "Expected STORE dag node in input!"); + + if (auto Store = dyn_cast<StoreSDNode>(N)) { + if (!Store->isTruncatingStore() || Store->isIndexed()) + return SDValue(); + SDValue Ext = Store->getValue(); + auto ExtOpCode = Ext.getOpcode(); + if (ExtOpCode != ISD::ZERO_EXTEND && ExtOpCode != ISD::SIGN_EXTEND && + ExtOpCode != ISD::ANY_EXTEND) + return SDValue(); + SDValue Orig = Ext->getOperand(0); + if (Store->getMemoryVT() != Orig->getValueType(0)) + return SDValue(); + return DAG.getStore(Store->getChain(), SDLoc(Store), Orig, + Store->getBasePtr(), Store->getPointerInfo(), + Store->getAlign()); + } + + return SDValue(); +} + static SDValue performSTORECombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG, @@ -14394,54 +15190,8 @@ static SDValue performSTORECombine(SDNode *N, performTBISimplification(N->getOperand(2), DCI, DAG)) return SDValue(N, 0); - return SDValue(); -} - -static SDValue performMaskedGatherScatterCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, - SelectionDAG &DAG) { - MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N); - assert(MGS && "Can only combine gather load or scatter store nodes"); - - SDLoc DL(MGS); - SDValue Chain = MGS->getChain(); - SDValue Scale = MGS->getScale(); - SDValue Index = MGS->getIndex(); - SDValue Mask = MGS->getMask(); - SDValue BasePtr = MGS->getBasePtr(); - ISD::MemIndexType IndexType = MGS->getIndexType(); - - EVT IdxVT = Index.getValueType(); - - if (DCI.isBeforeLegalize()) { - // SVE gather/scatter requires indices of i32/i64. Promote anything smaller - // prior to legalisation so the result can be split if required. 
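performVectorShiftCombine asks SimplifyDemandedBits to ignore the operand bits that a right shift by an immediate discards. The underlying observation, shown for a single lane (illustrative sketch):

#include <cassert>
#include <cstdint>

int main() {
  const unsigned Imm = 3;                 // VLSHR #3: bits [2:0] never reach the result
  uint32_t X = 0xABCDEF12u;
  uint32_t Y = (X & ~0x7u) | 0x5u;        // same value except in bits [2:0]
  assert((X >> Imm) == (Y >> Imm));       // so those bits are not demanded
}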
- if ((IdxVT.getVectorElementType() == MVT::i8) || - (IdxVT.getVectorElementType() == MVT::i16)) { - EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32); - if (MGS->isIndexSigned()) - Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index); - else - Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index); - - if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) { - SDValue PassThru = MGT->getPassThru(); - SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale }; - return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other), - PassThru.getValueType(), DL, Ops, - MGT->getMemOperand(), - MGT->getIndexType(), MGT->getExtensionType()); - } else { - auto *MSC = cast<MaskedScatterSDNode>(MGS); - SDValue Data = MSC->getValue(); - SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale }; - return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), - MSC->getMemoryVT(), DL, Ops, - MSC->getMemOperand(), IndexType, - MSC->isTruncatingStore()); - } - } - } + if (SDValue Store = foldTruncStoreOfExt(DAG, N)) + return Store; return SDValue(); } @@ -14902,6 +15652,67 @@ static SDValue performBRCONDCombine(SDNode *N, return SDValue(); } +// Optimize CSEL instructions +static SDValue performCSELCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + // CSEL x, x, cc -> x + if (N->getOperand(0) == N->getOperand(1)) + return N->getOperand(0); + + return performCONDCombine(N, DCI, DAG, 2, 3); +} + +static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG) { + assert(N->getOpcode() == ISD::SETCC && "Unexpected opcode!"); + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); + ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get(); + + // setcc (csel 0, 1, cond, X), 1, ne ==> csel 0, 1, !cond, X + if (Cond == ISD::SETNE && isOneConstant(RHS) && + LHS->getOpcode() == AArch64ISD::CSEL && + isNullConstant(LHS->getOperand(0)) && isOneConstant(LHS->getOperand(1)) && + LHS->hasOneUse()) { + SDLoc DL(N); + + // Invert CSEL's condition. + auto *OpCC = cast<ConstantSDNode>(LHS.getOperand(2)); + auto OldCond = static_cast<AArch64CC::CondCode>(OpCC->getZExtValue()); + auto NewCond = getInvertedCondCode(OldCond); + + // csel 0, 1, !cond, X + SDValue CSEL = + DAG.getNode(AArch64ISD::CSEL, DL, LHS.getValueType(), LHS.getOperand(0), + LHS.getOperand(1), DAG.getConstant(NewCond, DL, MVT::i32), + LHS.getOperand(3)); + return DAG.getZExtOrTrunc(CSEL, DL, N->getValueType(0)); + } + + return SDValue(); +} + +static SDValue performSetccMergeZeroCombine(SDNode *N, SelectionDAG &DAG) { + assert(N->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO && + "Unexpected opcode!"); + + SDValue Pred = N->getOperand(0); + SDValue LHS = N->getOperand(1); + SDValue RHS = N->getOperand(2); + ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(3))->get(); + + // setcc_merge_zero pred (sign_extend (setcc_merge_zero ... pred ...)), 0, ne + // => inner setcc_merge_zero + if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) && + LHS->getOpcode() == ISD::SIGN_EXTEND && + LHS->getOperand(0)->getValueType(0) == N->getValueType(0) && + LHS->getOperand(0)->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO && + LHS->getOperand(0)->getOperand(0) == Pred) + return LHS->getOperand(0); + + return SDValue(); +} + // Optimize some simple tbz/tbnz cases. Returns the new operand and bit to test // as well as whether the test should be inverted. 
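The new performSETCCCombine turns "setcc ne (csel 0, 1, cc, x), 1" into "csel 0, 1, !cc, x". Both forms compute the truth value of cc itself, as a two-line C++ check shows (illustrative; CC stands in for the NZCV condition):

#include <cassert>

int main() {
  for (bool CC : {false, true}) {
    int Before = (CC ? 0 : 1) != 1;  // setcc ne (csel 0, 1, cc, x), 1
    int After = !CC ? 0 : 1;         // csel 0, 1, !cc, x
    assert(Before == After);
  }
}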
This code is required to // catch these cases (as opposed to standard dag combines) because @@ -15014,7 +15825,41 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) { SDValue N0 = N->getOperand(0); EVT CCVT = N0.getValueType(); - if (N0.getOpcode() != ISD::SETCC || CCVT.getVectorNumElements() != 1 || + // Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform + // into (OR (ASR lhs, N-1), 1), which requires less instructions for the + // supported types. + SDValue SetCC = N->getOperand(0); + if (SetCC.getOpcode() == ISD::SETCC && + SetCC.getOperand(2) == DAG.getCondCode(ISD::SETGT)) { + SDValue CmpLHS = SetCC.getOperand(0); + EVT VT = CmpLHS.getValueType(); + SDNode *CmpRHS = SetCC.getOperand(1).getNode(); + SDNode *SplatLHS = N->getOperand(1).getNode(); + SDNode *SplatRHS = N->getOperand(2).getNode(); + APInt SplatLHSVal; + if (CmpLHS.getValueType() == N->getOperand(1).getValueType() && + VT.isSimple() && + is_contained( + makeArrayRef({MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, + MVT::v2i32, MVT::v4i32, MVT::v2i64}), + VT.getSimpleVT().SimpleTy) && + ISD::isConstantSplatVector(SplatLHS, SplatLHSVal) && + SplatLHSVal.isOneValue() && ISD::isConstantSplatVectorAllOnes(CmpRHS) && + ISD::isConstantSplatVectorAllOnes(SplatRHS)) { + unsigned NumElts = VT.getVectorNumElements(); + SmallVector<SDValue, 8> Ops( + NumElts, DAG.getConstant(VT.getScalarSizeInBits() - 1, SDLoc(N), + VT.getScalarType())); + SDValue Val = DAG.getBuildVector(VT, SDLoc(N), Ops); + + auto Shift = DAG.getNode(ISD::SRA, SDLoc(N), VT, CmpLHS, Val); + auto Or = DAG.getNode(ISD::OR, SDLoc(N), VT, Shift, N->getOperand(1)); + return Or; + } + } + + if (N0.getOpcode() != ISD::SETCC || + CCVT.getVectorElementCount() != ElementCount::getFixed(1) || CCVT.getVectorElementType() != MVT::i1) return SDValue(); @@ -15027,10 +15872,9 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) { SDValue IfTrue = N->getOperand(1); SDValue IfFalse = N->getOperand(2); - SDValue SetCC = - DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(), - N0.getOperand(0), N0.getOperand(1), - cast<CondCodeSDNode>(N0.getOperand(2))->get()); + SetCC = DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(), + N0.getOperand(0), N0.getOperand(1), + cast<CondCodeSDNode>(N0.getOperand(2))->get()); return DAG.getNode(ISD::VSELECT, SDLoc(N), ResVT, SetCC, IfTrue, IfFalse); } @@ -15048,6 +15892,9 @@ static SDValue performSelectCombine(SDNode *N, if (N0.getOpcode() != ISD::SETCC) return SDValue(); + if (ResVT.isScalableVector()) + return SDValue(); + // Make sure the SETCC result is either i1 (initial DAG), or i32, the lowered // scalar SetCCResultType. We also don't expect vectors, because we assume // that selects fed by vector SETCCs are canonicalized to VSELECT. @@ -15180,7 +16027,6 @@ static SDValue getScaledOffsetForBitWidth(SelectionDAG &DAG, SDValue Offset, /// [<Zn>.[S|D]{, #<imm>}] /// /// where <imm> = sizeof(<T>) * k, for k = 0, 1, ..., 31. - inline static bool isValidImmForSVEVecImmAddrMode(unsigned OffsetInBytes, unsigned ScalarSizeInBytes) { // The immediate is not a multiple of the scalar size. @@ -15588,6 +16434,97 @@ static SDValue combineSVEPrefetchVecBaseImmOff(SDNode *N, SelectionDAG &DAG, return DAG.getNode(N->getOpcode(), DL, DAG.getVTList(MVT::Other), Ops); } +// Return true if the vector operation can guarantee only the first lane of its +// result contains data, with all bits in other lanes set to zero. 
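The new sign-pattern fold in performVSelectCombine treats each lane of (VSELECT (SETGT lhs, -1), 1, -1) as select(x > -1, 1, -1) and rewrites it to (x >> (N-1)) | 1 using an arithmetic shift. A scalar sanity check of that identity (plain C++, assuming the usual arithmetic right shift for signed values, which C++20 guarantees):

#include <cassert>
#include <cstdint>

int main() {
  const int32_t TestValues[] = {INT32_MIN, -7, -1, 0, 1, 42, INT32_MAX};
  for (int32_t x : TestValues) {
    int32_t Selected = (x > -1) ? 1 : -1;
    // x >> 31 is 0 for non-negative x and -1 for negative x; OR-ing in the
    // splat-of-1 operand then yields 1 or -1 respectively.
    int32_t Shifted = (x >> 31) | 1;
    assert(Selected == Shifted);
  }
  return 0;
}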
+static bool isLanes1toNKnownZero(SDValue Op) { + switch (Op.getOpcode()) { + default: + return false; + case AArch64ISD::ANDV_PRED: + case AArch64ISD::EORV_PRED: + case AArch64ISD::FADDA_PRED: + case AArch64ISD::FADDV_PRED: + case AArch64ISD::FMAXNMV_PRED: + case AArch64ISD::FMAXV_PRED: + case AArch64ISD::FMINNMV_PRED: + case AArch64ISD::FMINV_PRED: + case AArch64ISD::ORV_PRED: + case AArch64ISD::SADDV_PRED: + case AArch64ISD::SMAXV_PRED: + case AArch64ISD::SMINV_PRED: + case AArch64ISD::UADDV_PRED: + case AArch64ISD::UMAXV_PRED: + case AArch64ISD::UMINV_PRED: + return true; + } +} + +static SDValue removeRedundantInsertVectorElt(SDNode *N) { + assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT && "Unexpected node!"); + SDValue InsertVec = N->getOperand(0); + SDValue InsertElt = N->getOperand(1); + SDValue InsertIdx = N->getOperand(2); + + // We only care about inserts into the first element... + if (!isNullConstant(InsertIdx)) + return SDValue(); + // ...of a zero'd vector... + if (!ISD::isConstantSplatVectorAllZeros(InsertVec.getNode())) + return SDValue(); + // ...where the inserted data was previously extracted... + if (InsertElt.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + return SDValue(); + + SDValue ExtractVec = InsertElt.getOperand(0); + SDValue ExtractIdx = InsertElt.getOperand(1); + + // ...from the first element of a vector. + if (!isNullConstant(ExtractIdx)) + return SDValue(); + + // If we get here we are effectively trying to zero lanes 1-N of a vector. + + // Ensure there's no type conversion going on. + if (N->getValueType(0) != ExtractVec.getValueType()) + return SDValue(); + + if (!isLanes1toNKnownZero(ExtractVec)) + return SDValue(); + + // The explicit zeroing is redundant. + return ExtractVec; +} + +static SDValue +performInsertVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { + if (SDValue Res = removeRedundantInsertVectorElt(N)) + return Res; + + return performPostLD1Combine(N, DCI, true); +} + +SDValue performSVESpliceCombine(SDNode *N, SelectionDAG &DAG) { + EVT Ty = N->getValueType(0); + if (Ty.isInteger()) + return SDValue(); + + EVT IntTy = Ty.changeVectorElementTypeToInteger(); + EVT ExtIntTy = getPackedSVEVectorVT(IntTy.getVectorElementCount()); + if (ExtIntTy.getVectorElementType().getScalarSizeInBits() < + IntTy.getVectorElementType().getScalarSizeInBits()) + return SDValue(); + + SDLoc DL(N); + SDValue LHS = DAG.getAnyExtOrTrunc(DAG.getBitcast(IntTy, N->getOperand(0)), + DL, ExtIntTy); + SDValue RHS = DAG.getAnyExtOrTrunc(DAG.getBitcast(IntTy, N->getOperand(1)), + DL, ExtIntTy); + SDValue Idx = N->getOperand(2); + SDValue Splice = DAG.getNode(ISD::VECTOR_SPLICE, DL, ExtIntTy, LHS, RHS, Idx); + SDValue Trunc = DAG.getAnyExtOrTrunc(Splice, DL, IntTy); + return DAG.getBitcast(Ty, Trunc); +} + SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { SelectionDAG &DAG = DCI.DAG; @@ -15595,8 +16532,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, default: LLVM_DEBUG(dbgs() << "Custom combining: skipping\n"); break; - case ISD::ABS: - return performABSCombine(N, DAG, DCI, Subtarget); case ISD::ADD: case ISD::SUB: return performAddSubCombine(N, DCI, DAG); @@ -15634,30 +16569,53 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, return performSelectCombine(N, DCI); case ISD::VSELECT: return performVSelectCombine(N, DCI.DAG); + case ISD::SETCC: + return performSETCCCombine(N, DAG); case ISD::LOAD: if (performTBISimplification(N->getOperand(1), DCI, DAG)) return SDValue(N, 0); break; case 
ISD::STORE: return performSTORECombine(N, DCI, DAG, Subtarget); - case ISD::MGATHER: - case ISD::MSCATTER: - return performMaskedGatherScatterCombine(N, DCI, DAG); + case ISD::VECTOR_SPLICE: + return performSVESpliceCombine(N, DAG); case AArch64ISD::BRCOND: return performBRCONDCombine(N, DCI, DAG); case AArch64ISD::TBNZ: case AArch64ISD::TBZ: return performTBZCombine(N, DCI, DAG); case AArch64ISD::CSEL: - return performCONDCombine(N, DCI, DAG, 2, 3); + return performCSELCombine(N, DCI, DAG); case AArch64ISD::DUP: return performPostLD1Combine(N, DCI, false); case AArch64ISD::NVCAST: return performNVCASTCombine(N); + case AArch64ISD::SPLICE: + return performSpliceCombine(N, DAG); case AArch64ISD::UZP1: return performUzpCombine(N, DAG); + case AArch64ISD::SETCC_MERGE_ZERO: + return performSetccMergeZeroCombine(N, DAG); + case AArch64ISD::GLD1_MERGE_ZERO: + case AArch64ISD::GLD1_SCALED_MERGE_ZERO: + case AArch64ISD::GLD1_UXTW_MERGE_ZERO: + case AArch64ISD::GLD1_SXTW_MERGE_ZERO: + case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO: + case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO: + case AArch64ISD::GLD1_IMM_MERGE_ZERO: + case AArch64ISD::GLD1S_MERGE_ZERO: + case AArch64ISD::GLD1S_SCALED_MERGE_ZERO: + case AArch64ISD::GLD1S_UXTW_MERGE_ZERO: + case AArch64ISD::GLD1S_SXTW_MERGE_ZERO: + case AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO: + case AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO: + case AArch64ISD::GLD1S_IMM_MERGE_ZERO: + return performGLD1Combine(N, DAG); + case AArch64ISD::VASHR: + case AArch64ISD::VLSHR: + return performVectorShiftCombine(N, *this, DCI); case ISD::INSERT_VECTOR_ELT: - return performPostLD1Combine(N, DCI, true); + return performInsertVectorEltCombine(N, DCI); case ISD::EXTRACT_VECTOR_ELT: return performExtractVectorEltCombine(N, DAG); case ISD::VECREDUCE_ADD: @@ -15881,6 +16839,24 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, LowerSVEStructLoad(IntrinsicID, LoadOps, N->getValueType(0), DAG, DL); return DAG.getMergeValues({Result, Chain}, DL); } + case Intrinsic::aarch64_rndr: + case Intrinsic::aarch64_rndrrs: { + unsigned IntrinsicID = + cast<ConstantSDNode>(N->getOperand(1))->getZExtValue(); + auto Register = + (IntrinsicID == Intrinsic::aarch64_rndr ? 
AArch64SysReg::RNDR + : AArch64SysReg::RNDRRS); + SDLoc DL(N); + SDValue A = DAG.getNode( + AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other), + N->getOperand(0), DAG.getConstant(Register, DL, MVT::i64)); + SDValue B = DAG.getNode( + AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32), + DAG.getConstant(0, DL, MVT::i32), + DAG.getConstant(AArch64CC::NE, DL, MVT::i32), A.getValue(1)); + return DAG.getMergeValues( + {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL); + } default: break; } @@ -16007,13 +16983,22 @@ bool AArch64TargetLowering::getPostIndexedAddressParts( return true; } -static void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl<SDValue> &Results, - SelectionDAG &DAG) { +void AArch64TargetLowering::ReplaceBITCASTResults( + SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const { SDLoc DL(N); SDValue Op = N->getOperand(0); + EVT VT = N->getValueType(0); + EVT SrcVT = Op.getValueType(); + + if (VT.isScalableVector() && !isTypeLegal(VT) && isTypeLegal(SrcVT)) { + assert(!VT.isFloatingPoint() && SrcVT.isFloatingPoint() && + "Expected fp->int bitcast!"); + SDValue CastResult = getSVESafeBitCast(getSVEContainerType(VT), Op, DAG); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, CastResult)); + return; + } - if (N->getValueType(0) != MVT::i16 || - (Op.getValueType() != MVT::f16 && Op.getValueType() != MVT::bf16)) + if (VT != MVT::i16 || (SrcVT != MVT::f16 && SrcVT != MVT::bf16)) return; Op = SDValue( @@ -16107,6 +17092,7 @@ static void ReplaceCMP_SWAP_128Results(SDNode *N, assert(N->getValueType(0) == MVT::i128 && "AtomicCmpSwap on types less than 128 should be legal"); + MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand(); if (Subtarget->hasLSE() || Subtarget->outlineAtomics()) { // LSE has a 128-bit compare and swap (CASP), but i128 is not a legal type, // so lower it here, wrapped in REG_SEQUENCE and EXTRACT_SUBREG. 
@@ -16117,10 +17103,8 @@ static void ReplaceCMP_SWAP_128Results(SDNode *N, N->getOperand(0), // Chain in }; - MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand(); - unsigned Opcode; - switch (MemOp->getOrdering()) { + switch (MemOp->getMergedOrdering()) { case AtomicOrdering::Monotonic: Opcode = AArch64::CASPX; break; @@ -16155,15 +17139,32 @@ static void ReplaceCMP_SWAP_128Results(SDNode *N, return; } + unsigned Opcode; + switch (MemOp->getMergedOrdering()) { + case AtomicOrdering::Monotonic: + Opcode = AArch64::CMP_SWAP_128_MONOTONIC; + break; + case AtomicOrdering::Acquire: + Opcode = AArch64::CMP_SWAP_128_ACQUIRE; + break; + case AtomicOrdering::Release: + Opcode = AArch64::CMP_SWAP_128_RELEASE; + break; + case AtomicOrdering::AcquireRelease: + case AtomicOrdering::SequentiallyConsistent: + Opcode = AArch64::CMP_SWAP_128; + break; + default: + llvm_unreachable("Unexpected ordering!"); + } + auto Desired = splitInt128(N->getOperand(2), DAG); auto New = splitInt128(N->getOperand(3), DAG); SDValue Ops[] = {N->getOperand(1), Desired.first, Desired.second, New.first, New.second, N->getOperand(0)}; SDNode *CmpSwap = DAG.getMachineNode( - AArch64::CMP_SWAP_128, SDLoc(N), - DAG.getVTList(MVT::i64, MVT::i64, MVT::i32, MVT::Other), Ops); - - MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand(); + Opcode, SDLoc(N), DAG.getVTList(MVT::i64, MVT::i64, MVT::i32, MVT::Other), + Ops); DAG.setNodeMemRefs(cast<MachineSDNode>(CmpSwap), {MemOp}); Results.push_back(DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128, @@ -16241,6 +17242,10 @@ void AArch64TargetLowering::ReplaceNodeResults( case ISD::EXTRACT_SUBVECTOR: ReplaceExtractSubVectorResults(N, Results, DAG); return; + case ISD::INSERT_SUBVECTOR: + // Custom lowering has been requested for INSERT_SUBVECTOR -- but delegate + // to common code for result type legalisation + return; case ISD::INTRINSIC_WO_CHAIN: { EVT VT = N->getValueType(0); assert((VT == MVT::i8 || VT == MVT::i16) && @@ -16334,25 +17339,36 @@ AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { unsigned Size = AI->getType()->getPrimitiveSizeInBits(); if (Size > 128) return AtomicExpansionKind::None; - // Nand not supported in LSE. - if (AI->getOperation() == AtomicRMWInst::Nand) return AtomicExpansionKind::LLSC; - // Leave 128 bits to LLSC. - if (Subtarget->hasLSE() && Size < 128) - return AtomicExpansionKind::None; - if (Subtarget->outlineAtomics() && Size < 128) { - // [U]Min/[U]Max RWM atomics are used in __sync_fetch_ libcalls so far. - // Don't outline them unless - // (1) high level <atomic> support approved: - // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf - // (2) low level libgcc and compiler-rt support implemented by: - // min/max outline atomics helpers - if (AI->getOperation() != AtomicRMWInst::Min && - AI->getOperation() != AtomicRMWInst::Max && - AI->getOperation() != AtomicRMWInst::UMin && - AI->getOperation() != AtomicRMWInst::UMax) { + + // Nand is not supported in LSE. + // Leave 128 bits to LLSC or CmpXChg. + if (AI->getOperation() != AtomicRMWInst::Nand && Size < 128) { + if (Subtarget->hasLSE()) return AtomicExpansionKind::None; + if (Subtarget->outlineAtomics()) { + // [U]Min/[U]Max RWM atomics are used in __sync_fetch_ libcalls so far. 
+ // Don't outline them unless + // (1) high level <atomic> support approved: + // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf + // (2) low level libgcc and compiler-rt support implemented by: + // min/max outline atomics helpers + if (AI->getOperation() != AtomicRMWInst::Min && + AI->getOperation() != AtomicRMWInst::Max && + AI->getOperation() != AtomicRMWInst::UMin && + AI->getOperation() != AtomicRMWInst::UMax) { + return AtomicExpansionKind::None; + } } } + + // At -O0, fast-regalloc cannot cope with the live vregs necessary to + // implement atomicrmw without spilling. If the target address is also on the + // stack and close enough to the spill slot, this can lead to a situation + // where the monitor always gets cleared and the atomic operation can never + // succeed. So at -O0 lower this operation to a CAS loop. + if (getTargetMachine().getOptLevel() == CodeGenOpt::None) + return AtomicExpansionKind::CmpXChg; + return AtomicExpansionKind::LLSC; } @@ -16369,19 +17385,26 @@ AArch64TargetLowering::shouldExpandAtomicCmpXchgInIR( // can never succeed. So at -O0 we need a late-expanded pseudo-inst instead. if (getTargetMachine().getOptLevel() == CodeGenOpt::None) return AtomicExpansionKind::None; + + // 128-bit atomic cmpxchg is weird; AtomicExpand doesn't know how to expand + // it. + unsigned Size = AI->getCompareOperand()->getType()->getPrimitiveSizeInBits(); + if (Size > 64) + return AtomicExpansionKind::None; + return AtomicExpansionKind::LLSC; } -Value *AArch64TargetLowering::emitLoadLinked(IRBuilder<> &Builder, Value *Addr, +Value *AArch64TargetLowering::emitLoadLinked(IRBuilderBase &Builder, + Type *ValueTy, Value *Addr, AtomicOrdering Ord) const { Module *M = Builder.GetInsertBlock()->getParent()->getParent(); - Type *ValTy = cast<PointerType>(Addr->getType())->getElementType(); bool IsAcquire = isAcquireOrStronger(Ord); // Since i128 isn't legal and intrinsics don't get type-lowered, the ldrexd // intrinsic must return {i64, i64} and we have to recombine them into a // single i128 here. - if (ValTy->getPrimitiveSizeInBits() == 128) { + if (ValueTy->getPrimitiveSizeInBits() == 128) { Intrinsic::ID Int = IsAcquire ? Intrinsic::aarch64_ldaxp : Intrinsic::aarch64_ldxp; Function *Ldxr = Intrinsic::getDeclaration(M, Int); @@ -16391,10 +17414,10 @@ Value *AArch64TargetLowering::emitLoadLinked(IRBuilder<> &Builder, Value *Addr, Value *Lo = Builder.CreateExtractValue(LoHi, 0, "lo"); Value *Hi = Builder.CreateExtractValue(LoHi, 1, "hi"); - Lo = Builder.CreateZExt(Lo, ValTy, "lo64"); - Hi = Builder.CreateZExt(Hi, ValTy, "hi64"); + Lo = Builder.CreateZExt(Lo, ValueTy, "lo64"); + Hi = Builder.CreateZExt(Hi, ValueTy, "hi64"); return Builder.CreateOr( - Lo, Builder.CreateShl(Hi, ConstantInt::get(ValTy, 64)), "val64"); + Lo, Builder.CreateShl(Hi, ConstantInt::get(ValueTy, 64)), "val64"); } Type *Tys[] = { Addr->getType() }; @@ -16402,22 +17425,20 @@ Value *AArch64TargetLowering::emitLoadLinked(IRBuilder<> &Builder, Value *Addr, IsAcquire ? 
Intrinsic::aarch64_ldaxr : Intrinsic::aarch64_ldxr; Function *Ldxr = Intrinsic::getDeclaration(M, Int, Tys); - Type *EltTy = cast<PointerType>(Addr->getType())->getElementType(); - const DataLayout &DL = M->getDataLayout(); - IntegerType *IntEltTy = Builder.getIntNTy(DL.getTypeSizeInBits(EltTy)); + IntegerType *IntEltTy = Builder.getIntNTy(DL.getTypeSizeInBits(ValueTy)); Value *Trunc = Builder.CreateTrunc(Builder.CreateCall(Ldxr, Addr), IntEltTy); - return Builder.CreateBitCast(Trunc, EltTy); + return Builder.CreateBitCast(Trunc, ValueTy); } void AArch64TargetLowering::emitAtomicCmpXchgNoStoreLLBalance( - IRBuilder<> &Builder) const { + IRBuilderBase &Builder) const { Module *M = Builder.GetInsertBlock()->getParent()->getParent(); Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::aarch64_clrex)); } -Value *AArch64TargetLowering::emitStoreConditional(IRBuilder<> &Builder, +Value *AArch64TargetLowering::emitStoreConditional(IRBuilderBase &Builder, Value *Val, Value *Addr, AtomicOrdering Ord) const { Module *M = Builder.GetInsertBlock()->getParent()->getParent(); @@ -16454,15 +17475,17 @@ Value *AArch64TargetLowering::emitStoreConditional(IRBuilder<> &Builder, } bool AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters( - Type *Ty, CallingConv::ID CallConv, bool isVarArg) const { - if (Ty->isArrayTy()) - return true; - - const TypeSize &TySize = Ty->getPrimitiveSizeInBits(); - if (TySize.isScalable() && TySize.getKnownMinSize() > 128) - return true; + Type *Ty, CallingConv::ID CallConv, bool isVarArg, + const DataLayout &DL) const { + if (!Ty->isArrayTy()) { + const TypeSize &TySize = Ty->getPrimitiveSizeInBits(); + return TySize.isScalable() && TySize.getKnownMinSize() > 128; + } - return false; + // All non aggregate members of the type must have the same type + SmallVector<EVT> ValueVTs; + ComputeValueVTs(*this, DL, Ty, ValueVTs); + return is_splat(ValueVTs); } bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &, @@ -16470,7 +17493,7 @@ bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &, return false; } -static Value *UseTlsOffset(IRBuilder<> &IRB, unsigned Offset) { +static Value *UseTlsOffset(IRBuilderBase &IRB, unsigned Offset) { Module *M = IRB.GetInsertBlock()->getParent()->getParent(); Function *ThreadPointerFunc = Intrinsic::getDeclaration(M, Intrinsic::thread_pointer); @@ -16480,7 +17503,7 @@ static Value *UseTlsOffset(IRBuilder<> &IRB, unsigned Offset) { IRB.getInt8PtrTy()->getPointerTo(0)); } -Value *AArch64TargetLowering::getIRStackGuard(IRBuilder<> &IRB) const { +Value *AArch64TargetLowering::getIRStackGuard(IRBuilderBase &IRB) const { // Android provides a fixed TLS slot for the stack cookie. See the definition // of TLS_SLOT_STACK_GUARD in // https://android.googlesource.com/platform/bionic/+/master/libc/private/bionic_tls.h @@ -16529,7 +17552,8 @@ Function *AArch64TargetLowering::getSSPStackGuardCheck(const Module &M) const { return TargetLowering::getSSPStackGuardCheck(M); } -Value *AArch64TargetLowering::getSafeStackPointerLocation(IRBuilder<> &IRB) const { +Value * +AArch64TargetLowering::getSafeStackPointerLocation(IRBuilderBase &IRB) const { // Android provides a fixed TLS slot for the SafeStack pointer. 
See the // definition of TLS_SLOT_SAFESTACK in // https://android.googlesource.com/platform/bionic/+/master/libc/private/bionic_tls.h @@ -16854,6 +17878,66 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorLoadToSVE( return DAG.getMergeValues(MergedValues, DL); } +static SDValue convertFixedMaskToScalableVector(SDValue Mask, + SelectionDAG &DAG) { + SDLoc DL(Mask); + EVT InVT = Mask.getValueType(); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT); + + auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask); + auto Op2 = DAG.getConstant(0, DL, ContainerVT); + auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT); + + EVT CmpVT = Pg.getValueType(); + return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, CmpVT, + {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)}); +} + +// Convert all fixed length vector loads larger than NEON to masked_loads. +SDValue AArch64TargetLowering::LowerFixedLengthVectorMLoadToSVE( + SDValue Op, SelectionDAG &DAG) const { + auto Load = cast<MaskedLoadSDNode>(Op); + + if (Load->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD) + return SDValue(); + + SDLoc DL(Op); + EVT VT = Op.getValueType(); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT); + + SDValue Mask = convertFixedMaskToScalableVector(Load->getMask(), DAG); + + SDValue PassThru; + bool IsPassThruZeroOrUndef = false; + + if (Load->getPassThru()->isUndef()) { + PassThru = DAG.getUNDEF(ContainerVT); + IsPassThruZeroOrUndef = true; + } else { + if (ContainerVT.isInteger()) + PassThru = DAG.getConstant(0, DL, ContainerVT); + else + PassThru = DAG.getConstantFP(0, DL, ContainerVT); + if (isZerosVector(Load->getPassThru().getNode())) + IsPassThruZeroOrUndef = true; + } + + auto NewLoad = DAG.getMaskedLoad( + ContainerVT, DL, Load->getChain(), Load->getBasePtr(), Load->getOffset(), + Mask, PassThru, Load->getMemoryVT(), Load->getMemOperand(), + Load->getAddressingMode(), Load->getExtensionType()); + + if (!IsPassThruZeroOrUndef) { + SDValue OldPassThru = + convertToScalableVector(DAG, ContainerVT, Load->getPassThru()); + NewLoad = DAG.getSelect(DL, ContainerVT, Mask, NewLoad, OldPassThru); + } + + auto Result = convertFromScalableVector(DAG, VT, NewLoad); + SDValue MergedValues[2] = {Result, Load->getChain()}; + return DAG.getMergeValues(MergedValues, DL); +} + // Convert all fixed length vector stores larger than NEON to masked_stores. 
SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE( SDValue Op, SelectionDAG &DAG) const { @@ -16871,6 +17955,26 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE( Store->isTruncatingStore()); } +SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE( + SDValue Op, SelectionDAG &DAG) const { + auto Store = cast<MaskedStoreSDNode>(Op); + + if (Store->isTruncatingStore()) + return SDValue(); + + SDLoc DL(Op); + EVT VT = Store->getValue().getValueType(); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT); + + auto NewValue = convertToScalableVector(DAG, ContainerVT, Store->getValue()); + SDValue Mask = convertFixedMaskToScalableVector(Store->getMask(), DAG); + + return DAG.getMaskedStore( + Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(), + Mask, Store->getMemoryVT(), Store->getMemOperand(), + Store->getAddressingMode(), Store->isTruncatingStore()); +} + SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE( SDValue Op, SelectionDAG &DAG) const { SDLoc dl(Op); @@ -16890,6 +17994,16 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE( EVT FixedWidenedVT = HalfVT.widenIntegerVectorElementType(*DAG.getContext()); EVT ScalableWidenedVT = getContainerForFixedLengthVector(DAG, FixedWidenedVT); + // If this is not a full vector, extend, div, and truncate it. + EVT WidenedVT = VT.widenIntegerVectorElementType(*DAG.getContext()); + if (DAG.getTargetLoweringInfo().isTypeLegal(WidenedVT)) { + unsigned ExtendOpcode = Signed ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + SDValue Op0 = DAG.getNode(ExtendOpcode, dl, WidenedVT, Op.getOperand(0)); + SDValue Op1 = DAG.getNode(ExtendOpcode, dl, WidenedVT, Op.getOperand(1)); + SDValue Div = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0, Op1); + return DAG.getNode(ISD::TRUNCATE, dl, VT, Div); + } + // Convert the operands to scalable vectors. SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0)); SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1)); @@ -16993,6 +18107,35 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorTruncateToSVE( return convertFromScalableVector(DAG, VT, Val); } +SDValue AArch64TargetLowering::LowerFixedLengthExtractVectorElt( + SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + EVT InVT = Op.getOperand(0).getValueType(); + assert(InVT.isFixedLengthVector() && "Expected fixed length vector type!"); + + SDLoc DL(Op); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT); + SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0)); + + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Op.getOperand(1)); +} + +SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt( + SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + assert(VT.isFixedLengthVector() && "Expected fixed length vector type!"); + + SDLoc DL(Op); + EVT InVT = Op.getOperand(0).getValueType(); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT); + SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0)); + + auto ScalableRes = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT, Op0, + Op.getOperand(1), Op.getOperand(2)); + + return convertFromScalableVector(DAG, VT, ScalableRes); +} + // Convert vector operation 'Op' to an equivalent predicated operation whereby // the original operation's type is used to construct a suitable predicate. // NOTE: The results for inactive lanes are undefined. 
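The widen-then-divide path added to LowerFixedLengthVectorIntDivideToSVE (extend, divide, truncate) is legitimate because an integer quotient never needs more bits than the dividend, so nothing is lost by truncating back to the original element width. A small scalar illustration of the signed case in plain C++ (i8 division done via i32; the unsigned case works the same way with zero-extension):

#include <cassert>
#include <cstdint>

int main() {
  for (int a = -128; a <= 127; ++a) {
    for (int b = -128; b <= 127; ++b) {
      if (b == 0 || (a == -128 && b == -1))
        continue; // division is undefined for these i8 inputs
      // Sign-extend both operands, divide at the wider width, truncate back.
      int32_t Wide = int32_t(int8_t(a)) / int32_t(int8_t(b));
      int8_t Narrow = int8_t(Wide);
      assert(int32_t(Narrow) == Wide); // the i8 quotient always fits
    }
  }
  return 0;
}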
@@ -17209,10 +18352,6 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE( assert(Op.getValueType() == InVT.changeTypeToInteger() && "Expected integer result of the same bit length as the inputs!"); - // Expand floating point vector comparisons. - if (InVT.isFloatingPoint()) - return SDValue(); - auto Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0)); auto Op2 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1)); auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT); @@ -17226,6 +18365,229 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE( return convertFromScalableVector(DAG, Op.getValueType(), Promote); } +SDValue +AArch64TargetLowering::LowerFixedLengthBitcastToSVE(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + auto SrcOp = Op.getOperand(0); + EVT VT = Op.getValueType(); + EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT); + EVT ContainerSrcVT = + getContainerForFixedLengthVector(DAG, SrcOp.getValueType()); + + SrcOp = convertToScalableVector(DAG, ContainerSrcVT, SrcOp); + Op = DAG.getNode(ISD::BITCAST, DL, ContainerDstVT, SrcOp); + return convertFromScalableVector(DAG, VT, Op); +} + +SDValue AArch64TargetLowering::LowerFixedLengthConcatVectorsToSVE( + SDValue Op, SelectionDAG &DAG) const { + SDLoc DL(Op); + unsigned NumOperands = Op->getNumOperands(); + + assert(NumOperands > 1 && isPowerOf2_32(NumOperands) && + "Unexpected number of operands in CONCAT_VECTORS"); + + auto SrcOp1 = Op.getOperand(0); + auto SrcOp2 = Op.getOperand(1); + EVT VT = Op.getValueType(); + EVT SrcVT = SrcOp1.getValueType(); + + if (NumOperands > 2) { + SmallVector<SDValue, 4> Ops; + EVT PairVT = SrcVT.getDoubleNumVectorElementsVT(*DAG.getContext()); + for (unsigned I = 0; I < NumOperands; I += 2) + Ops.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, PairVT, + Op->getOperand(I), Op->getOperand(I + 1))); + + return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Ops); + } + + EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT); + + SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, SrcVT); + SrcOp1 = convertToScalableVector(DAG, ContainerVT, SrcOp1); + SrcOp2 = convertToScalableVector(DAG, ContainerVT, SrcOp2); + + Op = DAG.getNode(AArch64ISD::SPLICE, DL, ContainerVT, Pg, SrcOp1, SrcOp2); + + return convertFromScalableVector(DAG, VT, Op); +} + +SDValue +AArch64TargetLowering::LowerFixedLengthFPExtendToSVE(SDValue Op, + SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + assert(VT.isFixedLengthVector() && "Expected fixed length vector type!"); + + SDLoc DL(Op); + SDValue Val = Op.getOperand(0); + SDValue Pg = getPredicateForVector(DAG, DL, VT); + EVT SrcVT = Val.getValueType(); + EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT); + EVT ExtendVT = ContainerVT.changeVectorElementType( + SrcVT.getVectorElementType()); + + Val = DAG.getNode(ISD::BITCAST, DL, SrcVT.changeTypeToInteger(), Val); + Val = DAG.getNode(ISD::ANY_EXTEND, DL, VT.changeTypeToInteger(), Val); + + Val = convertToScalableVector(DAG, ContainerVT.changeTypeToInteger(), Val); + Val = getSVESafeBitCast(ExtendVT, Val, DAG); + Val = DAG.getNode(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU, DL, ContainerVT, + Pg, Val, DAG.getUNDEF(ContainerVT)); + + return convertFromScalableVector(DAG, VT, Val); +} + +SDValue +AArch64TargetLowering::LowerFixedLengthFPRoundToSVE(SDValue Op, + SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + assert(VT.isFixedLengthVector() && "Expected fixed length vector type!"); + + SDLoc DL(Op); + SDValue Val = Op.getOperand(0); + 
EVT SrcVT = Val.getValueType(); + EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT); + EVT RoundVT = ContainerSrcVT.changeVectorElementType( + VT.getVectorElementType()); + SDValue Pg = getPredicateForVector(DAG, DL, RoundVT); + + Val = convertToScalableVector(DAG, ContainerSrcVT, Val); + Val = DAG.getNode(AArch64ISD::FP_ROUND_MERGE_PASSTHRU, DL, RoundVT, Pg, Val, + Op.getOperand(1), DAG.getUNDEF(RoundVT)); + Val = getSVESafeBitCast(ContainerSrcVT.changeTypeToInteger(), Val, DAG); + Val = convertFromScalableVector(DAG, SrcVT.changeTypeToInteger(), Val); + + Val = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Val); + return DAG.getNode(ISD::BITCAST, DL, VT, Val); +} + +SDValue +AArch64TargetLowering::LowerFixedLengthIntToFPToSVE(SDValue Op, + SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + assert(VT.isFixedLengthVector() && "Expected fixed length vector type!"); + + bool IsSigned = Op.getOpcode() == ISD::SINT_TO_FP; + unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU + : AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU; + + SDLoc DL(Op); + SDValue Val = Op.getOperand(0); + EVT SrcVT = Val.getValueType(); + EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT); + EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT); + + if (ContainerSrcVT.getVectorElementType().getSizeInBits() <= + ContainerDstVT.getVectorElementType().getSizeInBits()) { + SDValue Pg = getPredicateForVector(DAG, DL, VT); + + Val = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, + VT.changeTypeToInteger(), Val); + + Val = convertToScalableVector(DAG, ContainerSrcVT, Val); + Val = getSVESafeBitCast(ContainerDstVT.changeTypeToInteger(), Val, DAG); + // Safe to use a larger than specified operand since we just unpacked the + // data, hence the upper bits are zero. + Val = DAG.getNode(Opcode, DL, ContainerDstVT, Pg, Val, + DAG.getUNDEF(ContainerDstVT)); + return convertFromScalableVector(DAG, VT, Val); + } else { + EVT CvtVT = ContainerSrcVT.changeVectorElementType( + ContainerDstVT.getVectorElementType()); + SDValue Pg = getPredicateForVector(DAG, DL, CvtVT); + + Val = convertToScalableVector(DAG, ContainerSrcVT, Val); + Val = DAG.getNode(Opcode, DL, CvtVT, Pg, Val, DAG.getUNDEF(CvtVT)); + Val = getSVESafeBitCast(ContainerSrcVT, Val, DAG); + Val = convertFromScalableVector(DAG, SrcVT, Val); + + Val = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Val); + return DAG.getNode(ISD::BITCAST, DL, VT, Val); + } +} + +SDValue +AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op, + SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + assert(VT.isFixedLengthVector() && "Expected fixed length vector type!"); + + bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT; + unsigned Opcode = IsSigned ? 
AArch64ISD::FCVTZS_MERGE_PASSTHRU + : AArch64ISD::FCVTZU_MERGE_PASSTHRU; + + SDLoc DL(Op); + SDValue Val = Op.getOperand(0); + EVT SrcVT = Val.getValueType(); + EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT); + EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT); + + if (ContainerSrcVT.getVectorElementType().getSizeInBits() <= + ContainerDstVT.getVectorElementType().getSizeInBits()) { + EVT CvtVT = ContainerDstVT.changeVectorElementType( + ContainerSrcVT.getVectorElementType()); + SDValue Pg = getPredicateForVector(DAG, DL, VT); + + Val = DAG.getNode(ISD::BITCAST, DL, SrcVT.changeTypeToInteger(), Val); + Val = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Val); + + Val = convertToScalableVector(DAG, ContainerSrcVT, Val); + Val = getSVESafeBitCast(CvtVT, Val, DAG); + Val = DAG.getNode(Opcode, DL, ContainerDstVT, Pg, Val, + DAG.getUNDEF(ContainerDstVT)); + return convertFromScalableVector(DAG, VT, Val); + } else { + EVT CvtVT = ContainerSrcVT.changeTypeToInteger(); + SDValue Pg = getPredicateForVector(DAG, DL, CvtVT); + + // Safe to use a larger than specified result since an fp_to_int where the + // result doesn't fit into the destination is undefined. + Val = convertToScalableVector(DAG, ContainerSrcVT, Val); + Val = DAG.getNode(Opcode, DL, CvtVT, Pg, Val, DAG.getUNDEF(CvtVT)); + Val = convertFromScalableVector(DAG, SrcVT.changeTypeToInteger(), Val); + + return DAG.getNode(ISD::TRUNCATE, DL, VT, Val); + } +} + +SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE( + SDValue Op, SelectionDAG &DAG) const { + EVT VT = Op.getValueType(); + assert(VT.isFixedLengthVector() && "Expected fixed length vector type!"); + + auto *SVN = cast<ShuffleVectorSDNode>(Op.getNode()); + auto ShuffleMask = SVN->getMask(); + + SDLoc DL(Op); + SDValue Op1 = Op.getOperand(0); + SDValue Op2 = Op.getOperand(1); + + EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT); + Op1 = convertToScalableVector(DAG, ContainerVT, Op1); + Op2 = convertToScalableVector(DAG, ContainerVT, Op2); + + bool ReverseEXT = false; + unsigned Imm; + if (isEXTMask(ShuffleMask, VT, ReverseEXT, Imm) && + Imm == VT.getVectorNumElements() - 1) { + if (ReverseEXT) + std::swap(Op1, Op2); + + EVT ScalarTy = VT.getVectorElementType(); + if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16)) + ScalarTy = MVT::i32; + SDValue Scalar = DAG.getNode( + ISD::EXTRACT_VECTOR_ELT, DL, ScalarTy, Op1, + DAG.getConstant(VT.getVectorNumElements() - 1, DL, MVT::i64)); + Op = DAG.getNode(AArch64ISD::INSR, DL, ContainerVT, Op2, Scalar); + return convertFromScalableVector(DAG, VT, Op); + } + + return SDValue(); +} + SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); @@ -17248,8 +18610,6 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op, EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType()); EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType()); - assert((VT == PackedVT || InVT == PackedInVT) && - "Cannot cast between unpacked scalable vector types!"); // Pack input if required. 
if (InVT != PackedInVT) @@ -17263,3 +18623,60 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op, return Op; } + +bool AArch64TargetLowering::isAllActivePredicate(SDValue N) const { + return ::isAllActivePredicate(N); +} + +EVT AArch64TargetLowering::getPromotedVTForPredicate(EVT VT) const { + return ::getPromotedVTForPredicate(VT); +} + +bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode( + SDValue Op, const APInt &OriginalDemandedBits, + const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO, + unsigned Depth) const { + + unsigned Opc = Op.getOpcode(); + switch (Opc) { + case AArch64ISD::VSHL: { + // Match (VSHL (VLSHR Val X) X) + SDValue ShiftL = Op; + SDValue ShiftR = Op->getOperand(0); + if (ShiftR->getOpcode() != AArch64ISD::VLSHR) + return false; + + if (!ShiftL.hasOneUse() || !ShiftR.hasOneUse()) + return false; + + unsigned ShiftLBits = ShiftL->getConstantOperandVal(1); + unsigned ShiftRBits = ShiftR->getConstantOperandVal(1); + + // Other cases can be handled as well, but this is not + // implemented. + if (ShiftRBits != ShiftLBits) + return false; + + unsigned ScalarSize = Op.getScalarValueSizeInBits(); + assert(ScalarSize > ShiftLBits && "Invalid shift imm"); + + APInt ZeroBits = APInt::getLowBitsSet(ScalarSize, ShiftLBits); + APInt UnusedBits = ~OriginalDemandedBits; + + if ((ZeroBits & UnusedBits) != ZeroBits) + return false; + + // All bits that are zeroed by (VSHL (VLSHR Val X) X) are not + // used - simplify to just Val. + return TLO.CombineTo(Op, ShiftR->getOperand(0)); + } + } + + return TargetLowering::SimplifyDemandedBitsForTargetNode( + Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth); +} + +bool AArch64TargetLowering::isConstantUnsignedBitfieldExtactLegal( + unsigned Opc, LLT Ty1, LLT Ty2) const { + return Ty1 == Ty2 && (Ty1 == LLT::scalar(32) || Ty1 == LLT::scalar(64)); +} |
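A scalar restatement of the (VSHL (VLSHR Val, C), C) case handled in SimplifyDemandedBitsForTargetNode above: shifting right then left by the same amount only zeroes the low C bits, so a consumer that never demands those bits sees Val unchanged and the shift pair can be dropped. A quick C++ sketch (illustrative only, using a 64-bit scalar as a stand-in for each vector lane):

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t Val = 0xfedcba9876543210ULL;
  for (unsigned C = 1; C < 64; ++C) {
    uint64_t RoundTripped = (Val >> C) << C;
    uint64_t LowMask = (uint64_t(1) << C) - 1;
    // Only the low C bits can differ between Val and the shifted pair...
    assert((RoundTripped & ~LowMask) == (Val & ~LowMask));
    // ...so with those bits excluded from the demanded set, both agree.
    assert((RoundTripped | LowMask) == (Val | LowMask));
  }
  return 0;
}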