Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64ISelLowering.cpp  2521
1 file changed, 1969 insertions(+), 552 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1be09186dc0a..e7282aad05e2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29,7 +29,9 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/ObjCARCUtil.h"
#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
@@ -343,6 +345,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setCondCodeAction(ISD::SETUGT, VT, Expand);
setCondCodeAction(ISD::SETUEQ, VT, Expand);
setCondCodeAction(ISD::SETUNE, VT, Expand);
+
+ setOperationAction(ISD::FREM, VT, Expand);
+ setOperationAction(ISD::FPOW, VT, Expand);
+ setOperationAction(ISD::FPOWI, VT, Expand);
+ setOperationAction(ISD::FCOS, VT, Expand);
+ setOperationAction(ISD::FSIN, VT, Expand);
+ setOperationAction(ISD::FSINCOS, VT, Expand);
+ setOperationAction(ISD::FEXP, VT, Expand);
+ setOperationAction(ISD::FEXP2, VT, Expand);
+ setOperationAction(ISD::FLOG, VT, Expand);
+ setOperationAction(ISD::FLOG2, VT, Expand);
+ setOperationAction(ISD::FLOG10, VT, Expand);
}
}
@@ -458,6 +472,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Custom);
setOperationAction(ISD::STRICT_FP_ROUND, MVT::f64, Custom);
+ setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i32, Custom);
+ setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i64, Custom);
+ setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i32, Custom);
+ setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i64, Custom);
+
// Variable arguments.
setOperationAction(ISD::VASTART, MVT::Other, Custom);
setOperationAction(ISD::VAARG, MVT::Other, Custom);
@@ -604,6 +623,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FNEARBYINT, MVT::f16, Promote);
setOperationAction(ISD::FRINT, MVT::f16, Promote);
setOperationAction(ISD::FROUND, MVT::f16, Promote);
+ setOperationAction(ISD::FROUNDEVEN, MVT::f16, Promote);
setOperationAction(ISD::FTRUNC, MVT::f16, Promote);
setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
@@ -623,6 +643,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FABS, MVT::v4f16, Expand);
setOperationAction(ISD::FNEG, MVT::v4f16, Expand);
setOperationAction(ISD::FROUND, MVT::v4f16, Expand);
+ setOperationAction(ISD::FROUNDEVEN, MVT::v4f16, Expand);
setOperationAction(ISD::FMA, MVT::v4f16, Expand);
setOperationAction(ISD::SETCC, MVT::v4f16, Expand);
setOperationAction(ISD::BR_CC, MVT::v4f16, Expand);
@@ -647,6 +668,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FNEARBYINT, MVT::v8f16, Expand);
setOperationAction(ISD::FNEG, MVT::v8f16, Expand);
setOperationAction(ISD::FROUND, MVT::v8f16, Expand);
+ setOperationAction(ISD::FROUNDEVEN, MVT::v8f16, Expand);
setOperationAction(ISD::FRINT, MVT::v8f16, Expand);
setOperationAction(ISD::FSQRT, MVT::v8f16, Expand);
setOperationAction(ISD::FSUB, MVT::v8f16, Expand);
@@ -666,6 +688,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FRINT, Ty, Legal);
setOperationAction(ISD::FTRUNC, Ty, Legal);
setOperationAction(ISD::FROUND, Ty, Legal);
+ setOperationAction(ISD::FROUNDEVEN, Ty, Legal);
setOperationAction(ISD::FMINNUM, Ty, Legal);
setOperationAction(ISD::FMAXNUM, Ty, Legal);
setOperationAction(ISD::FMINIMUM, Ty, Legal);
@@ -683,6 +706,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FRINT, MVT::f16, Legal);
setOperationAction(ISD::FTRUNC, MVT::f16, Legal);
setOperationAction(ISD::FROUND, MVT::f16, Legal);
+ setOperationAction(ISD::FROUNDEVEN, MVT::f16, Legal);
setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
setOperationAction(ISD::FMINIMUM, MVT::f16, Legal);
@@ -692,6 +716,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::PREFETCH, MVT::Other, Custom);
setOperationAction(ISD::FLT_ROUNDS_, MVT::i32, Custom);
+ setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i128, Custom);
setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i32, Custom);
@@ -857,15 +882,20 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::SINT_TO_FP);
setTargetDAGCombine(ISD::UINT_TO_FP);
+ // TODO: Do the same for FP_TO_*INT_SAT.
setTargetDAGCombine(ISD::FP_TO_SINT);
setTargetDAGCombine(ISD::FP_TO_UINT);
setTargetDAGCombine(ISD::FDIV);
+ // Try and combine setcc with csel
+ setTargetDAGCombine(ISD::SETCC);
+
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
setTargetDAGCombine(ISD::ANY_EXTEND);
setTargetDAGCombine(ISD::ZERO_EXTEND);
setTargetDAGCombine(ISD::SIGN_EXTEND);
+ setTargetDAGCombine(ISD::VECTOR_SPLICE);
setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
setTargetDAGCombine(ISD::TRUNCATE);
setTargetDAGCombine(ISD::CONCAT_VECTORS);
@@ -873,9 +903,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (Subtarget->supportsAddressTopByteIgnored())
setTargetDAGCombine(ISD::LOAD);
- setTargetDAGCombine(ISD::MGATHER);
- setTargetDAGCombine(ISD::MSCATTER);
-
setTargetDAGCombine(ISD::MUL);
setTargetDAGCombine(ISD::SELECT);
@@ -886,6 +913,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::INSERT_VECTOR_ELT);
setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
setTargetDAGCombine(ISD::VECREDUCE_ADD);
+ setTargetDAGCombine(ISD::STEP_VECTOR);
setTargetDAGCombine(ISD::GlobalAddress);
@@ -944,6 +972,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FPOW, MVT::v1f64, Expand);
setOperationAction(ISD::FREM, MVT::v1f64, Expand);
setOperationAction(ISD::FROUND, MVT::v1f64, Expand);
+ setOperationAction(ISD::FROUNDEVEN, MVT::v1f64, Expand);
setOperationAction(ISD::FRINT, MVT::v1f64, Expand);
setOperationAction(ISD::FSIN, MVT::v1f64, Expand);
setOperationAction(ISD::FSINCOS, MVT::v1f64, Expand);
@@ -968,9 +997,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
// elements smaller than i32, so promote the input to i32 first.
setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v4i8, MVT::v4i32);
setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v4i8, MVT::v4i32);
- // i8 vector elements also need promotion to i32 for v8i8
setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v8i8, MVT::v8i32);
setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v8i8, MVT::v8i32);
+ setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v16i8, MVT::v16i32);
+ setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v16i8, MVT::v16i32);
+
// Similarly, there is no direct i32 -> f64 vector conversion instruction.
setOperationAction(ISD::SINT_TO_FP, MVT::v2i32, Custom);
setOperationAction(ISD::UINT_TO_FP, MVT::v2i32, Custom);
@@ -997,6 +1028,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::CTLZ, MVT::v1i64, Expand);
setOperationAction(ISD::CTLZ, MVT::v2i64, Expand);
+ setOperationAction(ISD::BITREVERSE, MVT::v8i8, Legal);
+ setOperationAction(ISD::BITREVERSE, MVT::v16i8, Legal);
+ setOperationAction(ISD::BITREVERSE, MVT::v2i32, Custom);
+ setOperationAction(ISD::BITREVERSE, MVT::v4i32, Custom);
+ setOperationAction(ISD::BITREVERSE, MVT::v1i64, Custom);
+ setOperationAction(ISD::BITREVERSE, MVT::v2i64, Custom);
// AArch64 doesn't have MUL.2d:
setOperationAction(ISD::MUL, MVT::v2i64, Expand);
@@ -1014,14 +1051,21 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::USUBSAT, VT, Legal);
}
+ for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
+ MVT::v4i32}) {
+ setOperationAction(ISD::ABDS, VT, Legal);
+ setOperationAction(ISD::ABDU, VT, Legal);
+ }
+
// Vector reductions
for (MVT VT : { MVT::v4f16, MVT::v2f32,
MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
- setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
- setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
+ if (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16()) {
+ setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
+ setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
- if (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16())
setOperationAction(ISD::VECREDUCE_FADD, VT, Legal);
+ }
}
for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
MVT::v16i8, MVT::v8i16, MVT::v4i32 }) {
@@ -1069,6 +1113,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FRINT, Ty, Legal);
setOperationAction(ISD::FTRUNC, Ty, Legal);
setOperationAction(ISD::FROUND, Ty, Legal);
+ setOperationAction(ISD::FROUNDEVEN, Ty, Legal);
}
if (Subtarget->hasFullFP16()) {
@@ -1079,6 +1124,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FRINT, Ty, Legal);
setOperationAction(ISD::FTRUNC, Ty, Legal);
setOperationAction(ISD::FROUND, Ty, Legal);
+ setOperationAction(ISD::FROUNDEVEN, Ty, Legal);
}
}
@@ -1086,12 +1132,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VSCALE, MVT::i32, Custom);
setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
+
+ setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
+ setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
+ setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
+ setLoadExtAction(ISD::SEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
+ setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
}
if (Subtarget->hasSVE()) {
- // FIXME: Add custom lowering of MLOAD to handle different passthrus (not a
- // splat of 0 or undef) once vector selects supported in SVE codegen. See
- // D68877 for more details.
for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64}) {
setOperationAction(ISD::BITREVERSE, VT, Custom);
setOperationAction(ISD::BSWAP, VT, Custom);
@@ -1105,9 +1155,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
+ setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::MUL, VT, Custom);
+ setOperationAction(ISD::MULHS, VT, Custom);
+ setOperationAction(ISD::MULHU, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+ setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
setOperationAction(ISD::SELECT, VT, Custom);
+ setOperationAction(ISD::SETCC, VT, Custom);
setOperationAction(ISD::SDIV, VT, Custom);
setOperationAction(ISD::UDIV, VT, Custom);
setOperationAction(ISD::SMIN, VT, Custom);
@@ -1126,6 +1181,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
+
+ setOperationAction(ISD::UMUL_LOHI, VT, Expand);
+ setOperationAction(ISD::SMUL_LOHI, VT, Expand);
+ setOperationAction(ISD::SELECT_CC, VT, Expand);
+ setOperationAction(ISD::ROTL, VT, Expand);
+ setOperationAction(ISD::ROTR, VT, Expand);
}
// Illegal unpacked integer vector types.
@@ -1134,6 +1195,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
}
+ // Legalize unpacked bitcasts to REINTERPRET_CAST.
+ for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32, MVT::nxv2bf16,
+ MVT::nxv2f16, MVT::nxv4f16, MVT::nxv2f32})
+ setOperationAction(ISD::BITCAST, VT, Custom);
+
for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) {
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::SELECT, VT, Custom);
@@ -1144,6 +1210,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
+ setOperationAction(ISD::SELECT_CC, VT, Expand);
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
+ setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
+
// There are no legal MVT::nxv16f## based types.
if (VT != MVT::nxv16i1) {
setOperationAction(ISD::SINT_TO_FP, VT, Custom);
@@ -1151,18 +1221,50 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
}
}
+ // NEON doesn't support masked loads/stores/gathers/scatters, but SVE does
+ for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, MVT::v1f64,
+ MVT::v2f64, MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
+ MVT::v2i32, MVT::v4i32, MVT::v1i64, MVT::v2i64}) {
+ setOperationAction(ISD::MLOAD, VT, Custom);
+ setOperationAction(ISD::MSTORE, VT, Custom);
+ setOperationAction(ISD::MGATHER, VT, Custom);
+ setOperationAction(ISD::MSCATTER, VT, Custom);
+ }
+
+ for (MVT VT : MVT::fp_scalable_vector_valuetypes()) {
+ for (MVT InnerVT : MVT::fp_scalable_vector_valuetypes()) {
+ // Avoid marking truncating FP stores as legal to prevent the
+ // DAGCombiner from creating unsupported truncating stores.
+ setTruncStoreAction(VT, InnerVT, Expand);
+ // SVE does not have floating-point extending loads.
+ setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
+ setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand);
+ setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Expand);
+ }
+ }
+
+ // SVE supports truncating stores of 64- and 128-bit vectors
+ setTruncStoreAction(MVT::v2i64, MVT::v2i8, Custom);
+ setTruncStoreAction(MVT::v2i64, MVT::v2i16, Custom);
+ setTruncStoreAction(MVT::v2i64, MVT::v2i32, Custom);
+ setTruncStoreAction(MVT::v2i32, MVT::v2i8, Custom);
+ setTruncStoreAction(MVT::v2i32, MVT::v2i16, Custom);
+
for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32,
MVT::nxv4f32, MVT::nxv2f64}) {
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
+ setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
setOperationAction(ISD::SELECT, VT, Custom);
setOperationAction(ISD::FADD, VT, Custom);
setOperationAction(ISD::FDIV, VT, Custom);
setOperationAction(ISD::FMA, VT, Custom);
+ setOperationAction(ISD::FMAXIMUM, VT, Custom);
setOperationAction(ISD::FMAXNUM, VT, Custom);
+ setOperationAction(ISD::FMINIMUM, VT, Custom);
setOperationAction(ISD::FMINNUM, VT, Custom);
setOperationAction(ISD::FMUL, VT, Custom);
setOperationAction(ISD::FNEG, VT, Custom);
@@ -1182,12 +1284,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
+ setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
+
+ setOperationAction(ISD::SELECT_CC, VT, Expand);
}
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::MGATHER, VT, Custom);
setOperationAction(ISD::MSCATTER, VT, Custom);
+ setOperationAction(ISD::MLOAD, VT, Custom);
}
setOperationAction(ISD::SPLAT_VECTOR, MVT::nxv8bf16, Custom);
@@ -1214,7 +1320,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32})
setOperationAction(ISD::TRUNCATE, VT, Custom);
for (auto VT : {MVT::v8f16, MVT::v4f32})
- setOperationAction(ISD::FP_ROUND, VT, Expand);
+ setOperationAction(ISD::FP_ROUND, VT, Custom);
// These operations are not supported on NEON but SVE can do them.
setOperationAction(ISD::BITREVERSE, MVT::v1i64, Custom);
@@ -1223,6 +1329,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::CTTZ, MVT::v1i64, Custom);
setOperationAction(ISD::MUL, MVT::v1i64, Custom);
setOperationAction(ISD::MUL, MVT::v2i64, Custom);
+ setOperationAction(ISD::MULHS, MVT::v1i64, Custom);
+ setOperationAction(ISD::MULHS, MVT::v2i64, Custom);
+ setOperationAction(ISD::MULHU, MVT::v1i64, Custom);
+ setOperationAction(ISD::MULHU, MVT::v2i64, Custom);
setOperationAction(ISD::SDIV, MVT::v8i8, Custom);
setOperationAction(ISD::SDIV, MVT::v16i8, Custom);
setOperationAction(ISD::SDIV, MVT::v4i16, Custom);
@@ -1271,12 +1381,17 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32})
setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
}
+
+ setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv2i1, MVT::nxv2i64);
+ setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv4i1, MVT::nxv4i32);
+ setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv8i1, MVT::nxv8i16);
+ setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv16i1, MVT::nxv16i8);
}
PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
}
-void AArch64TargetLowering::addTypeForNEON(MVT VT, MVT PromotedBitwiseVT) {
+void AArch64TargetLowering::addTypeForNEON(MVT VT) {
assert(VT.isVector() && "VT should be a vector type");
if (VT.isFloatingPoint()) {
@@ -1295,10 +1410,12 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT, MVT PromotedBitwiseVT) {
setOperationAction(ISD::FLOG10, VT, Expand);
setOperationAction(ISD::FEXP, VT, Expand);
setOperationAction(ISD::FEXP2, VT, Expand);
+ }
- // But we do support custom-lowering for FCOPYSIGN.
+ // But we do support custom-lowering for FCOPYSIGN.
+ if (VT == MVT::v2f32 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
+ ((VT == MVT::v4f16 || VT == MVT::v8f16) && Subtarget->hasFullFP16()))
setOperationAction(ISD::FCOPYSIGN, VT, Custom);
- }
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
@@ -1366,48 +1483,93 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
// We use EXTRACT_SUBVECTOR to "cast" a scalable vector to a fixed length one.
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
+ if (VT.isFloatingPoint()) {
+ setCondCodeAction(ISD::SETO, VT, Expand);
+ setCondCodeAction(ISD::SETOLT, VT, Expand);
+ setCondCodeAction(ISD::SETLT, VT, Expand);
+ setCondCodeAction(ISD::SETOLE, VT, Expand);
+ setCondCodeAction(ISD::SETLE, VT, Expand);
+ setCondCodeAction(ISD::SETULT, VT, Expand);
+ setCondCodeAction(ISD::SETULE, VT, Expand);
+ setCondCodeAction(ISD::SETUGE, VT, Expand);
+ setCondCodeAction(ISD::SETUGT, VT, Expand);
+ setCondCodeAction(ISD::SETUEQ, VT, Expand);
+ setCondCodeAction(ISD::SETUNE, VT, Expand);
+ }
+
+ // Mark integer truncating stores as having custom lowering
+ if (VT.isInteger()) {
+ MVT InnerVT = VT.changeVectorElementType(MVT::i8);
+ while (InnerVT != VT) {
+ setTruncStoreAction(VT, InnerVT, Custom);
+ InnerVT = InnerVT.changeVectorElementType(
+ MVT::getIntegerVT(2 * InnerVT.getScalarSizeInBits()));
+ }
+ }
+
// Lower fixed length vector operations to scalable equivalents.
setOperationAction(ISD::ABS, VT, Custom);
setOperationAction(ISD::ADD, VT, Custom);
setOperationAction(ISD::AND, VT, Custom);
setOperationAction(ISD::ANY_EXTEND, VT, Custom);
+ setOperationAction(ISD::BITCAST, VT, Custom);
setOperationAction(ISD::BITREVERSE, VT, Custom);
setOperationAction(ISD::BSWAP, VT, Custom);
+ setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::CTLZ, VT, Custom);
setOperationAction(ISD::CTPOP, VT, Custom);
setOperationAction(ISD::CTTZ, VT, Custom);
+ setOperationAction(ISD::FABS, VT, Custom);
setOperationAction(ISD::FADD, VT, Custom);
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
setOperationAction(ISD::FCEIL, VT, Custom);
setOperationAction(ISD::FDIV, VT, Custom);
setOperationAction(ISD::FFLOOR, VT, Custom);
setOperationAction(ISD::FMA, VT, Custom);
+ setOperationAction(ISD::FMAXIMUM, VT, Custom);
setOperationAction(ISD::FMAXNUM, VT, Custom);
+ setOperationAction(ISD::FMINIMUM, VT, Custom);
setOperationAction(ISD::FMINNUM, VT, Custom);
setOperationAction(ISD::FMUL, VT, Custom);
setOperationAction(ISD::FNEARBYINT, VT, Custom);
setOperationAction(ISD::FNEG, VT, Custom);
+ setOperationAction(ISD::FP_EXTEND, VT, Custom);
+ setOperationAction(ISD::FP_ROUND, VT, Custom);
+ setOperationAction(ISD::FP_TO_SINT, VT, Custom);
+ setOperationAction(ISD::FP_TO_UINT, VT, Custom);
setOperationAction(ISD::FRINT, VT, Custom);
setOperationAction(ISD::FROUND, VT, Custom);
+ setOperationAction(ISD::FROUNDEVEN, VT, Custom);
setOperationAction(ISD::FSQRT, VT, Custom);
setOperationAction(ISD::FSUB, VT, Custom);
setOperationAction(ISD::FTRUNC, VT, Custom);
setOperationAction(ISD::LOAD, VT, Custom);
+ setOperationAction(ISD::MGATHER, VT, Custom);
+ setOperationAction(ISD::MLOAD, VT, Custom);
+ setOperationAction(ISD::MSCATTER, VT, Custom);
+ setOperationAction(ISD::MSTORE, VT, Custom);
setOperationAction(ISD::MUL, VT, Custom);
+ setOperationAction(ISD::MULHS, VT, Custom);
+ setOperationAction(ISD::MULHU, VT, Custom);
setOperationAction(ISD::OR, VT, Custom);
setOperationAction(ISD::SDIV, VT, Custom);
+ setOperationAction(ISD::SELECT, VT, Custom);
setOperationAction(ISD::SETCC, VT, Custom);
setOperationAction(ISD::SHL, VT, Custom);
setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Custom);
+ setOperationAction(ISD::SINT_TO_FP, VT, Custom);
setOperationAction(ISD::SMAX, VT, Custom);
setOperationAction(ISD::SMIN, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+ setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
setOperationAction(ISD::SRA, VT, Custom);
setOperationAction(ISD::SRL, VT, Custom);
setOperationAction(ISD::STORE, VT, Custom);
setOperationAction(ISD::SUB, VT, Custom);
setOperationAction(ISD::TRUNCATE, VT, Custom);
setOperationAction(ISD::UDIV, VT, Custom);
+ setOperationAction(ISD::UINT_TO_FP, VT, Custom);
setOperationAction(ISD::UMAX, VT, Custom);
setOperationAction(ISD::UMIN, VT, Custom);
setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
@@ -1417,11 +1579,13 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
+ setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
+ setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
setOperationAction(ISD::VSELECT, VT, Custom);
setOperationAction(ISD::XOR, VT, Custom);
setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
@@ -1429,12 +1593,12 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
void AArch64TargetLowering::addDRTypeForNEON(MVT VT) {
addRegisterClass(VT, &AArch64::FPR64RegClass);
- addTypeForNEON(VT, MVT::v2i32);
+ addTypeForNEON(VT);
}
void AArch64TargetLowering::addQRTypeForNEON(MVT VT) {
addRegisterClass(VT, &AArch64::FPR128RegClass);
- addTypeForNEON(VT, MVT::v4i32);
+ addTypeForNEON(VT);
}
EVT AArch64TargetLowering::getSetCCResultType(const DataLayout &,
@@ -1659,7 +1823,7 @@ MVT AArch64TargetLowering::getScalarShiftAmountTy(const DataLayout &DL,
}
bool AArch64TargetLowering::allowsMisalignedMemoryAccesses(
- EVT VT, unsigned AddrSpace, unsigned Align, MachineMemOperand::Flags Flags,
+ EVT VT, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags,
bool *Fast) const {
if (Subtarget->requiresStrictAlign())
return false;
@@ -1673,7 +1837,7 @@ bool AArch64TargetLowering::allowsMisalignedMemoryAccesses(
// Code that uses clang vector extensions can mark that it
// wants unaligned accesses to be treated as fast by
// underspecifying alignment to be 1 or 2.
- Align <= 2 ||
+ Alignment <= 2 ||
// Disregard v2i64. Memcpy lowering produces those and splitting
// them regresses performance on micro-benchmarks and olden/bh.
@@ -1703,7 +1867,7 @@ bool AArch64TargetLowering::allowsMisalignedMemoryAccesses(
// Disregard v2i64. Memcpy lowering produces those and splitting
// them regresses performance on micro-benchmarks and olden/bh.
- Ty == LLT::vector(2, 64);
+ Ty == LLT::fixed_vector(2, 64);
}
return true;
}
@@ -1729,7 +1893,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::RET_FLAG)
MAKE_CASE(AArch64ISD::BRCOND)
MAKE_CASE(AArch64ISD::CSEL)
- MAKE_CASE(AArch64ISD::FCSEL)
MAKE_CASE(AArch64ISD::CSINV)
MAKE_CASE(AArch64ISD::CSNEG)
MAKE_CASE(AArch64ISD::CSINC)
@@ -1737,6 +1900,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ)
MAKE_CASE(AArch64ISD::ADD_PRED)
MAKE_CASE(AArch64ISD::MUL_PRED)
+ MAKE_CASE(AArch64ISD::MULHS_PRED)
+ MAKE_CASE(AArch64ISD::MULHU_PRED)
MAKE_CASE(AArch64ISD::SDIV_PRED)
MAKE_CASE(AArch64ISD::SHL_PRED)
MAKE_CASE(AArch64ISD::SMAX_PRED)
@@ -1797,7 +1962,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::BICi)
MAKE_CASE(AArch64ISD::ORRi)
MAKE_CASE(AArch64ISD::BSP)
- MAKE_CASE(AArch64ISD::NEG)
MAKE_CASE(AArch64ISD::EXTR)
MAKE_CASE(AArch64ISD::ZIP1)
MAKE_CASE(AArch64ISD::ZIP2)
@@ -1809,6 +1973,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::REV32)
MAKE_CASE(AArch64ISD::REV64)
MAKE_CASE(AArch64ISD::EXT)
+ MAKE_CASE(AArch64ISD::SPLICE)
MAKE_CASE(AArch64ISD::VSHL)
MAKE_CASE(AArch64ISD::VLSHR)
MAKE_CASE(AArch64ISD::VASHR)
@@ -1838,6 +2003,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::URHADD)
MAKE_CASE(AArch64ISD::SHADD)
MAKE_CASE(AArch64ISD::UHADD)
+ MAKE_CASE(AArch64ISD::SDOT)
+ MAKE_CASE(AArch64ISD::UDOT)
MAKE_CASE(AArch64ISD::SMINV)
MAKE_CASE(AArch64ISD::UMINV)
MAKE_CASE(AArch64ISD::SMAXV)
@@ -1855,7 +2022,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::CLASTB_N)
MAKE_CASE(AArch64ISD::LASTA)
MAKE_CASE(AArch64ISD::LASTB)
- MAKE_CASE(AArch64ISD::REV)
MAKE_CASE(AArch64ISD::REINTERPRET_CAST)
MAKE_CASE(AArch64ISD::TBL)
MAKE_CASE(AArch64ISD::FADD_PRED)
@@ -1863,14 +2029,17 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::FADDV_PRED)
MAKE_CASE(AArch64ISD::FDIV_PRED)
MAKE_CASE(AArch64ISD::FMA_PRED)
+ MAKE_CASE(AArch64ISD::FMAX_PRED)
MAKE_CASE(AArch64ISD::FMAXV_PRED)
MAKE_CASE(AArch64ISD::FMAXNM_PRED)
MAKE_CASE(AArch64ISD::FMAXNMV_PRED)
+ MAKE_CASE(AArch64ISD::FMIN_PRED)
MAKE_CASE(AArch64ISD::FMINV_PRED)
MAKE_CASE(AArch64ISD::FMINNM_PRED)
MAKE_CASE(AArch64ISD::FMINNMV_PRED)
MAKE_CASE(AArch64ISD::FMUL_PRED)
MAKE_CASE(AArch64ISD::FSUB_PRED)
+ MAKE_CASE(AArch64ISD::BIC)
MAKE_CASE(AArch64ISD::BIT)
MAKE_CASE(AArch64ISD::CBZ)
MAKE_CASE(AArch64ISD::CBNZ)
@@ -1881,6 +2050,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::SITOF)
MAKE_CASE(AArch64ISD::UITOF)
MAKE_CASE(AArch64ISD::NVCAST)
+ MAKE_CASE(AArch64ISD::MRS)
MAKE_CASE(AArch64ISD::SQSHL_I)
MAKE_CASE(AArch64ISD::UQSHL_I)
MAKE_CASE(AArch64ISD::SRSHR_I)
@@ -1988,8 +2158,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU)
MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
MAKE_CASE(AArch64ISD::INDEX_VECTOR)
- MAKE_CASE(AArch64ISD::UABD)
- MAKE_CASE(AArch64ISD::SABD)
+ MAKE_CASE(AArch64ISD::UADDLP)
MAKE_CASE(AArch64ISD::CALL_RVMARKER)
}
#undef MAKE_CASE
@@ -2094,6 +2263,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
// Lowering Code
//===----------------------------------------------------------------------===//
+// Forward declarations of SVE fixed length lowering helpers
+static EVT getContainerForFixedLengthVector(SelectionDAG &DAG, EVT VT);
+static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
+static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
+static SDValue convertFixedMaskToScalableVector(SDValue Mask,
+ SelectionDAG &DAG);
+
+/// isZerosVector - Check whether SDNode N is a zero-filled vector.
+static bool isZerosVector(const SDNode *N) {
+ // Look through a bit convert.
+ while (N->getOpcode() == ISD::BITCAST)
+ N = N->getOperand(0).getNode();
+
+ if (ISD::isConstantSplatVectorAllZeros(N))
+ return true;
+
+ if (N->getOpcode() != AArch64ISD::DUP)
+ return false;
+
+ auto Opnd0 = N->getOperand(0);
+ auto *CINT = dyn_cast<ConstantSDNode>(Opnd0);
+ auto *CFP = dyn_cast<ConstantFPSDNode>(Opnd0);
+ return (CINT && CINT->isNullValue()) || (CFP && CFP->isZero());
+}
+
/// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64
/// CC
static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) {
@@ -2823,50 +3017,25 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) {
CC = AArch64CC::NE;
bool IsSigned = Op.getOpcode() == ISD::SMULO;
if (Op.getValueType() == MVT::i32) {
+ // Extend to 64 bits, then perform a 64-bit multiply.
unsigned ExtendOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
- // For a 32 bit multiply with overflow check we want the instruction
- // selector to generate a widening multiply (SMADDL/UMADDL). For that we
- // need to generate the following pattern:
- // (i64 add 0, (i64 mul (i64 sext|zext i32 %a), (i64 sext|zext i32 %b))
LHS = DAG.getNode(ExtendOpc, DL, MVT::i64, LHS);
RHS = DAG.getNode(ExtendOpc, DL, MVT::i64, RHS);
SDValue Mul = DAG.getNode(ISD::MUL, DL, MVT::i64, LHS, RHS);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::i64, Mul,
- DAG.getConstant(0, DL, MVT::i64));
- // On AArch64 the upper 32 bits are always zero extended for a 32 bit
- // operation. We need to clear out the upper 32 bits, because we used a
- // widening multiply that wrote all 64 bits. In the end this should be a
- // noop.
- Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Add);
+ Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Mul);
+
+ // Check that the result fits into a 32-bit integer.
+ SDVTList VTs = DAG.getVTList(MVT::i64, MVT_CC);
if (IsSigned) {
- // The signed overflow check requires more than just a simple check for
- // any bit set in the upper 32 bits of the result. These bits could be
- // just the sign bits of a negative number. To perform the overflow
- // check we have to arithmetic shift right the 32nd bit of the result by
- // 31 bits. Then we compare the result to the upper 32 bits.
- SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Add,
- DAG.getConstant(32, DL, MVT::i64));
- UpperBits = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, UpperBits);
- SDValue LowerBits = DAG.getNode(ISD::SRA, DL, MVT::i32, Value,
- DAG.getConstant(31, DL, MVT::i64));
- // It is important that LowerBits is last, otherwise the arithmetic
- // shift will not be folded into the compare (SUBS).
- SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i32);
- Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, UpperBits, LowerBits)
- .getValue(1);
+ // cmp xreg, wreg, sxtw
+ SDValue SExtMul = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Value);
+ Overflow =
+ DAG.getNode(AArch64ISD::SUBS, DL, VTs, Mul, SExtMul).getValue(1);
} else {
- // The overflow check for unsigned multiply is easy. We only need to
- // check if any of the upper 32 bits are set. This can be done with a
- // CMP (shifted register). For that we need to generate the following
- // pattern:
- // (i64 AArch64ISD::SUBS i64 0, (i64 srl i64 %Mul, i64 32)
- SDValue UpperBits = DAG.getNode(ISD::SRL, DL, MVT::i64, Mul,
- DAG.getConstant(32, DL, MVT::i64));
- SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
+ // tst xreg, #0xffffffff00000000
+ SDValue UpperBits = DAG.getConstant(0xFFFFFFFF00000000, DL, MVT::i64);
Overflow =
- DAG.getNode(AArch64ISD::SUBS, DL, VTs,
- DAG.getConstant(0, DL, MVT::i64),
- UpperBits).getValue(1);
+ DAG.getNode(AArch64ISD::ANDS, DL, VTs, Mul, UpperBits).getValue(1);
}
break;
}
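The rewritten i32 SMULO/UMULO lowering widens both operands to 64 bits, multiplies once, and derives the overflow flag from the 64-bit product: signed overflow is detected by comparing the product against the sign-extended low 32 bits (the "cmp xreg, wreg, sxtw" form noted above), and unsigned overflow by testing the high 32 bits ("tst xreg, #0xffffffff00000000"). A minimal standalone C++ sketch of the same two checks, independent of the DAG code above and purely illustrative:

#include <cstdint>
#include <cstdio>

// signed:   overflow iff Mul != sext(trunc(Mul))            (SUBS with sxtw)
// unsigned: overflow iff (Mul & 0xffffffff00000000) != 0    (ANDS / tst)
static bool smulo32(int32_t a, int32_t b, int32_t &lo) {
  int64_t mul = (int64_t)a * (int64_t)b;    // widened 64-bit multiply
  lo = (int32_t)mul;                        // truncate to 32 bits
  return mul != (int64_t)lo;                // compare against sign-extension
}

static bool umulo32(uint32_t a, uint32_t b, uint32_t &lo) {
  uint64_t mul = (uint64_t)a * (uint64_t)b; // widened 64-bit multiply
  lo = (uint32_t)mul;
  return (mul & 0xffffffff00000000ULL) != 0; // any high bit set => overflow
}

int main() {
  int32_t s;
  uint32_t u;
  printf("%d %d\n", smulo32(0x40000000, 2, s), smulo32(-3, 5, s)); // 1 0
  printf("%d %d\n", umulo32(0x80000000u, 2, u), umulo32(7, 9, u)); // 1 0
  return 0;
}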
@@ -3082,9 +3251,13 @@ static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) {
SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
SelectionDAG &DAG) const {
- if (Op.getValueType().isScalableVector())
+ EVT VT = Op.getValueType();
+ if (VT.isScalableVector())
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU);
+ if (useSVEForFixedLengthVectorVT(VT))
+ return LowerFixedLengthFPExtendToSVE(Op, DAG);
+
assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
return SDValue();
}
@@ -3098,6 +3271,9 @@ SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
EVT SrcVT = SrcVal.getValueType();
+ if (useSVEForFixedLengthVectorVT(SrcVT))
+ return LowerFixedLengthFPRoundToSVE(Op, DAG);
+
if (SrcVT != MVT::f128) {
// Expand cases where the input is a vector bigger than NEON.
if (useSVEForFixedLengthVectorVT(SrcVT))
@@ -3125,6 +3301,9 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
return LowerToPredicatedOp(Op, DAG, Opcode);
}
+ if (useSVEForFixedLengthVectorVT(VT) || useSVEForFixedLengthVectorVT(InVT))
+ return LowerFixedLengthFPToIntToSVE(Op, DAG);
+
unsigned NumElts = InVT.getVectorNumElements();
// f16 conversions are promoted to f32 when full fp16 is not supported.
@@ -3185,6 +3364,44 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
return SDValue();
}
+SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
+ SelectionDAG &DAG) const {
+ // AArch64 FP-to-int conversions saturate to the destination register size, so
+ // we can lower common saturating conversions to simple instructions.
+ SDValue SrcVal = Op.getOperand(0);
+
+ EVT SrcVT = SrcVal.getValueType();
+ EVT DstVT = Op.getValueType();
+
+ EVT SatVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
+ uint64_t SatWidth = SatVT.getScalarSizeInBits();
+ uint64_t DstWidth = DstVT.getScalarSizeInBits();
+ assert(SatWidth <= DstWidth && "Saturation width cannot exceed result width");
+
+ // TODO: Support lowering of NEON and SVE conversions.
+ if (SrcVT.isVector())
+ return SDValue();
+
+ // TODO: Saturate to SatWidth explicitly.
+ if (SatWidth != DstWidth)
+ return SDValue();
+
+ // In the absence of FP16 support, promote f16 to f32, like LowerFP_TO_INT().
+ if (SrcVT == MVT::f16 && !Subtarget->hasFullFP16())
+ return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
+ DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, SrcVal),
+ Op.getOperand(1));
+
+ // Cases that we can emit directly.
+ if ((SrcVT == MVT::f64 || SrcVT == MVT::f32 ||
+ (SrcVT == MVT::f16 && Subtarget->hasFullFP16())) &&
+ (DstVT == MVT::i64 || DstVT == MVT::i32))
+ return Op;
+
+ // For all other cases, fall back on the expanded form.
+ return SDValue();
+}
+
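The cases emitted directly above need no extra clamping because the scalar FCVTZS/FCVTZU instructions already saturate out-of-range inputs and convert NaN to zero, which matches the llvm.fptosi.sat/llvm.fptoui.sat semantics when the saturation width equals the destination width. A small C++ model of that saturating behaviour for an f64 to i32 conversion (a sketch of the semantics only; the helper name is illustrative, not an LLVM API):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// Models the saturating semantics of a scalar fcvtzs (f64 -> i32):
// NaN becomes 0, out-of-range values clamp to INT32_MIN/INT32_MAX,
// and everything else is truncated toward zero.
static int32_t fcvtzs_f64_to_i32(double x) {
  if (std::isnan(x))
    return 0;
  if (x <= (double)std::numeric_limits<int32_t>::min())
    return std::numeric_limits<int32_t>::min();
  if (x >= (double)std::numeric_limits<int32_t>::max())
    return std::numeric_limits<int32_t>::max();
  return (int32_t)x; // truncation toward zero
}

int main() {
  printf("%d\n", fcvtzs_f64_to_i32(1e12));  // 2147483647
  printf("%d\n", fcvtzs_f64_to_i32(-1e12)); // -2147483648
  printf("%d\n", fcvtzs_f64_to_i32(NAN));   // 0
  printf("%d\n", fcvtzs_f64_to_i32(-7.9));  // -7
  return 0;
}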
SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
SelectionDAG &DAG) const {
// Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
@@ -3211,6 +3428,9 @@ SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
return LowerToPredicatedOp(Op, DAG, Opcode);
}
+ if (useSVEForFixedLengthVectorVT(VT) || useSVEForFixedLengthVectorVT(InVT))
+ return LowerFixedLengthIntToFPToSVE(Op, DAG);
+
uint64_t VTSize = VT.getFixedSizeInBits();
uint64_t InVTSize = InVT.getFixedSizeInBits();
if (VTSize < InVTSize) {
@@ -3295,12 +3515,32 @@ SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op,
return CallResult.first;
}
-static SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) {
+static MVT getSVEContainerType(EVT ContentTy);
+
+SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
+ SelectionDAG &DAG) const {
EVT OpVT = Op.getValueType();
+ EVT ArgVT = Op.getOperand(0).getValueType();
+
+ if (useSVEForFixedLengthVectorVT(OpVT))
+ return LowerFixedLengthBitcastToSVE(Op, DAG);
+
+ if (OpVT.isScalableVector()) {
+ if (isTypeLegal(OpVT) && !isTypeLegal(ArgVT)) {
+ assert(OpVT.isFloatingPoint() && !ArgVT.isFloatingPoint() &&
+ "Expected int->fp bitcast!");
+ SDValue ExtResult =
+ DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT),
+ Op.getOperand(0));
+ return getSVESafeBitCast(OpVT, ExtResult, DAG);
+ }
+ return getSVESafeBitCast(OpVT, Op.getOperand(0), DAG);
+ }
+
if (OpVT != MVT::f16 && OpVT != MVT::bf16)
return SDValue();
- assert(Op.getOperand(0).getValueType() == MVT::i16);
+ assert(ArgVT == MVT::i16);
SDLoc DL(Op);
Op = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op.getOperand(0));
@@ -3453,6 +3693,50 @@ SDValue AArch64TargetLowering::LowerFLT_ROUNDS_(SDValue Op,
return DAG.getMergeValues({AND, Chain}, dl);
}
+SDValue AArch64TargetLowering::LowerSET_ROUNDING(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ SDValue Chain = Op->getOperand(0);
+ SDValue RMValue = Op->getOperand(1);
+
+ // The rounding mode is in bits 23:22 of the FPCR.
+ // The mapping from the llvm.set.rounding argument value to the rounding mode
+ // in FPCR is 0->3, 1->0, 2->1, 3->2. The formula we use to implement this is
+ // (((arg - 1) & 3) << 22).
+ //
+ // The argument of llvm.set.rounding must be within the segment [0, 3], so
+ // NearestTiesToAway (4) is not handled here. It is the responsibility of the
+ // code that generates llvm.set.rounding to ensure this condition.
+
+ // Calculate new value of FPCR[23:22].
+ RMValue = DAG.getNode(ISD::SUB, DL, MVT::i32, RMValue,
+ DAG.getConstant(1, DL, MVT::i32));
+ RMValue = DAG.getNode(ISD::AND, DL, MVT::i32, RMValue,
+ DAG.getConstant(0x3, DL, MVT::i32));
+ RMValue =
+ DAG.getNode(ISD::SHL, DL, MVT::i32, RMValue,
+ DAG.getConstant(AArch64::RoundingBitsPos, DL, MVT::i32));
+ RMValue = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, RMValue);
+
+ // Get current value of FPCR.
+ SDValue Ops[] = {
+ Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)};
+ SDValue FPCR =
+ DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, Ops);
+ Chain = FPCR.getValue(1);
+ FPCR = FPCR.getValue(0);
+
+ // Put the new rounding mode into FPCR[23:22].
+ const int RMMask = ~(AArch64::Rounding::rmMask << AArch64::RoundingBitsPos);
+ FPCR = DAG.getNode(ISD::AND, DL, MVT::i64, FPCR,
+ DAG.getConstant(RMMask, DL, MVT::i64));
+ FPCR = DAG.getNode(ISD::OR, DL, MVT::i64, FPCR, RMValue);
+ SDValue Ops2[] = {
+ Chain, DAG.getTargetConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64),
+ FPCR};
+ return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
+}
+
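The (((arg - 1) & 3) << 22) trick above rotates the llvm.set.rounding encoding (FLT_ROUNDS convention: 0 = toward zero, 1 = to nearest, 2 = toward +inf, 3 = toward -inf) into the FPCR.RMode encoding in bits 23:22 (0 = RN, 1 = RP, 2 = RM, 3 = RZ). A quick standalone check of the mapping, as an illustrative sketch rather than target code:

#include <cstdio>

// FPCR.RMode (bits 23:22): 0 = RN, 1 = RP, 2 = RM, 3 = RZ.
// llvm.set.rounding argument: 0 = toward zero, 1 = to nearest,
// 2 = toward +inf, 3 = toward -inf.
static unsigned roundingArgToFPCRField(unsigned Arg) {
  return ((Arg - 1) & 3) << 22; // same formula as the lowering above
}

int main() {
  for (unsigned Arg = 0; Arg < 4; ++Arg)
    printf("arg %u -> FPCR[23:22] = %u\n", Arg,
           roundingArgToFPCRField(Arg) >> 22);
  // Prints 0 -> 3, 1 -> 0, 2 -> 1, 3 -> 2, matching the comment above.
  return 0;
}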
SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
@@ -3535,6 +3819,37 @@ static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
DAG.getTargetConstant(Pattern, DL, MVT::i32));
}
+static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) {
+ SDLoc DL(Op);
+ EVT OutVT = Op.getValueType();
+ SDValue InOp = Op.getOperand(1);
+ EVT InVT = InOp.getValueType();
+
+ // Return the operand if the cast isn't changing type,
+ // i.e. <n x 16 x i1> -> <n x 16 x i1>
+ if (InVT == OutVT)
+ return InOp;
+
+ SDValue Reinterpret =
+ DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, InOp);
+
+ // If the argument converted to an svbool is a ptrue or a comparison, the
+ // lanes introduced by the widening are zero by construction.
+ switch (InOp.getOpcode()) {
+ case AArch64ISD::SETCC_MERGE_ZERO:
+ return Reinterpret;
+ case ISD::INTRINSIC_WO_CHAIN:
+ if (InOp.getConstantOperandVal(0) == Intrinsic::aarch64_sve_ptrue)
+ return Reinterpret;
+ }
+
+ // Otherwise, zero the newly introduced lanes.
+ SDValue Mask = getPTrue(DAG, DL, InVT, AArch64SVEPredPattern::all);
+ SDValue MaskReinterpret =
+ DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, Mask);
+ return DAG.getNode(ISD::AND, DL, OutVT, Reinterpret, MaskReinterpret);
+}
+
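An SVE predicate register holds one bit per byte of vector data, so a predicate over 32-bit elements only uses every fourth bit; reinterpreting it as an svbool leaves the in-between bits holding whatever was already in the register. As the code above notes, a ptrue or a compare at the narrow element size only ever sets the used bits, so those sources can skip the masking step; anything else gets ANDed with a ptrue of the narrow element size. A toy C++ model of that zeroing, with hypothetical helper names and no dependence on the SVE ACLE:

#include <array>
#include <cstdio>

// Toy model of a predicate register for a 128-bit vector: 16 predicate bits,
// one per byte. A predicate over .s (32-bit) elements uses every 4th bit; the
// bits in between must be zero for a well-formed svbool.
using PredBits = std::array<unsigned, 16>;

static PredBits ptrueS() {           // "ptrue over .s elements"
  PredBits P{};
  for (int i = 0; i < 16; i += 4)
    P[i] = 1;
  return P;
}

static PredBits zeroUnusedLanes(const PredBits &Reinterpreted) {
  PredBits Mask = ptrueS(), Out{};
  for (int i = 0; i < 16; ++i)
    Out[i] = Reinterpreted[i] & Mask[i]; // the AND emitted above
  return Out;
}

int main() {
  PredBits Dirty{};                  // pretend the register holds stale bits
  for (int i = 0; i < 16; ++i)
    Dirty[i] = 1;
  for (unsigned B : zeroUnusedLanes(Dirty))
    printf("%u", B);                 // 1000100010001000: only .s lanes survive
  printf("\n");
  return 0;
}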
SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
SelectionDAG &DAG) const {
unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
@@ -3596,7 +3911,7 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
return DAG.getNode(AArch64ISD::LASTB, dl, Op.getValueType(),
Op.getOperand(1), Op.getOperand(2));
case Intrinsic::aarch64_sve_rev:
- return DAG.getNode(AArch64ISD::REV, dl, Op.getValueType(),
+ return DAG.getNode(ISD::VECTOR_REVERSE, dl, Op.getValueType(),
Op.getOperand(1));
case Intrinsic::aarch64_sve_tbl:
return DAG.getNode(AArch64ISD::TBL, dl, Op.getValueType(),
@@ -3619,6 +3934,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case Intrinsic::aarch64_sve_zip2:
return DAG.getNode(AArch64ISD::ZIP2, dl, Op.getValueType(),
Op.getOperand(1), Op.getOperand(2));
+ case Intrinsic::aarch64_sve_splice:
+ return DAG.getNode(AArch64ISD::SPLICE, dl, Op.getValueType(),
+ Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
case Intrinsic::aarch64_sve_ptrue:
return DAG.getNode(AArch64ISD::PTRUE, dl, Op.getValueType(),
Op.getOperand(1));
@@ -3638,6 +3956,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case Intrinsic::aarch64_sve_convert_from_svbool:
return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(),
Op.getOperand(1));
+ case Intrinsic::aarch64_sve_convert_to_svbool:
+ return lowerConvertToSVBool(Op, DAG);
case Intrinsic::aarch64_sve_fneg:
return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(),
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
@@ -3693,22 +4013,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case Intrinsic::aarch64_sve_neg:
return DAG.getNode(AArch64ISD::NEG_MERGE_PASSTHRU, dl, Op.getValueType(),
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
- case Intrinsic::aarch64_sve_convert_to_svbool: {
- EVT OutVT = Op.getValueType();
- EVT InVT = Op.getOperand(1).getValueType();
- // Return the operand if the cast isn't changing type,
- // i.e. <n x 16 x i1> -> <n x 16 x i1>
- if (InVT == OutVT)
- return Op.getOperand(1);
- // Otherwise, zero the newly introduced lanes.
- SDValue Reinterpret =
- DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, OutVT, Op.getOperand(1));
- SDValue Mask = getPTrue(DAG, dl, InVT, AArch64SVEPredPattern::all);
- SDValue MaskReinterpret =
- DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, OutVT, Mask);
- return DAG.getNode(ISD::AND, dl, OutVT, Reinterpret, MaskReinterpret);
- }
-
case Intrinsic::aarch64_sve_insr: {
SDValue Scalar = Op.getOperand(2);
EVT ScalarTy = Scalar.getValueType();
@@ -3813,18 +4117,40 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2));
}
-
+ case Intrinsic::aarch64_neon_sabd:
case Intrinsic::aarch64_neon_uabd: {
- return DAG.getNode(AArch64ISD::UABD, dl, Op.getValueType(),
- Op.getOperand(1), Op.getOperand(2));
+ unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? ISD::ABDU
+ : ISD::ABDS;
+ return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
+ Op.getOperand(2));
}
- case Intrinsic::aarch64_neon_sabd: {
- return DAG.getNode(AArch64ISD::SABD, dl, Op.getValueType(),
- Op.getOperand(1), Op.getOperand(2));
+ case Intrinsic::aarch64_neon_uaddlp: {
+ unsigned Opcode = AArch64ISD::UADDLP;
+ return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1));
+ }
+ case Intrinsic::aarch64_neon_sdot:
+ case Intrinsic::aarch64_neon_udot:
+ case Intrinsic::aarch64_sve_sdot:
+ case Intrinsic::aarch64_sve_udot: {
+ unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot ||
+ IntNo == Intrinsic::aarch64_sve_udot)
+ ? AArch64ISD::UDOT
+ : AArch64ISD::SDOT;
+ return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
+ Op.getOperand(2), Op.getOperand(3));
}
}
}
+bool AArch64TargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
+ if (VT.getVectorElementType() == MVT::i8 ||
+ VT.getVectorElementType() == MVT::i16) {
+ EltTy = MVT::i32;
+ return true;
+ }
+ return false;
+}
+
bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
if (VT.getVectorElementType() == MVT::i32 &&
VT.getVectorElementCount().getKnownMinValue() >= 4)
@@ -3938,6 +4264,12 @@ void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, EVT MemVT,
if (!isNullConstant(BasePtr))
return;
+ // FIXME: This will not match for fixed vector type codegen as the nodes in
+ // question will have fixed<->scalable conversions around them. This should be
+ // moved to a DAG combine or complex pattern so that it executes after all of
+ // the fixed vector inserts and extracts have been removed. This deficiency
+ // will result in a sub-optimal addressing mode being used, i.e. an ADD not
+ // being folded into the scatter/gather.
ConstantSDNode *Offset = nullptr;
if (Index.getOpcode() == ISD::ADD)
if (auto SplatVal = DAG.getSplatValue(Index.getOperand(1))) {
@@ -3982,6 +4314,8 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op);
assert(MGT && "Can only custom lower gather load nodes");
+ bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector();
+
SDValue Index = MGT->getIndex();
SDValue Chain = MGT->getChain();
SDValue PassThru = MGT->getPassThru();
@@ -4000,6 +4334,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
bool ResNeedsSignExtend = ExtTy == ISD::EXTLOAD || ExtTy == ISD::SEXTLOAD;
EVT VT = PassThru.getSimpleValueType();
+ EVT IndexVT = Index.getSimpleValueType();
EVT MemVT = MGT->getMemoryVT();
SDValue InputVT = DAG.getValueType(MemVT);
@@ -4007,14 +4342,27 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
!static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
return SDValue();
- // Handle FP data by using an integer gather and casting the result.
- if (VT.isFloatingPoint()) {
- EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount());
- PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG);
+ if (IsFixedLength) {
+ assert(Subtarget->useSVEForFixedLengthVectors() &&
+ "Cannot lower when not using SVE for fixed vectors");
+ IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
+ MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
+ InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
+ }
+
+ if (PassThru->isUndef() || isZerosVector(PassThru.getNode()))
+ PassThru = SDValue();
+
+ if (VT.isFloatingPoint() && !IsFixedLength) {
+ // Handle FP data by using an integer gather and casting the result.
+ if (PassThru) {
+ EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount());
+ PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG);
+ }
InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
}
- SDVTList VTs = DAG.getVTList(PassThru.getSimpleValueType(), MVT::Other);
+ SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other);
if (getGatherScatterIndexIsExtended(Index))
Index = Index.getOperand(0);
@@ -4026,15 +4374,36 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
if (ResNeedsSignExtend)
Opcode = getSignExtendedGatherOpcode(Opcode);
- SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT, PassThru};
- SDValue Gather = DAG.getNode(Opcode, DL, VTs, Ops);
+ if (IsFixedLength) {
+ if (Index.getSimpleValueType().isFixedLengthVector())
+ Index = convertToScalableVector(DAG, IndexVT, Index);
+ if (BasePtr.getSimpleValueType().isFixedLengthVector())
+ BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
+ Mask = convertFixedMaskToScalableVector(Mask, DAG);
+ }
- if (VT.isFloatingPoint()) {
- SDValue Cast = getSVESafeBitCast(VT, Gather, DAG);
- return DAG.getMergeValues({Cast, Gather}, DL);
+ SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT};
+ SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops);
+ Chain = Result.getValue(1);
+
+ if (IsFixedLength) {
+ Result = convertFromScalableVector(
+ DAG, VT.changeVectorElementType(IndexVT.getVectorElementType()),
+ Result);
+ Result = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Result);
+ Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
+
+ if (PassThru)
+ Result = DAG.getSelect(DL, VT, MGT->getMask(), Result, PassThru);
+ } else {
+ if (PassThru)
+ Result = DAG.getSelect(DL, IndexVT, Mask, Result, PassThru);
+
+ if (VT.isFloatingPoint())
+ Result = getSVESafeBitCast(VT, Result, DAG);
}
- return Gather;
+ return DAG.getMergeValues({Result, Chain}, DL);
}
SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
@@ -4043,6 +4412,8 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Op);
assert(MSC && "Can only custom lower scatter store nodes");
+ bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector();
+
SDValue Index = MSC->getIndex();
SDValue Chain = MSC->getChain();
SDValue StoreVal = MSC->getValue();
@@ -4059,6 +4430,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
Index.getSimpleValueType().getVectorElementType() == MVT::i32;
EVT VT = StoreVal.getSimpleValueType();
+ EVT IndexVT = Index.getSimpleValueType();
SDVTList VTs = DAG.getVTList(MVT::Other);
EVT MemVT = MSC->getMemoryVT();
SDValue InputVT = DAG.getValueType(MemVT);
@@ -4067,8 +4439,21 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
!static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
return SDValue();
- // Handle FP data by casting the data so an integer scatter can be used.
- if (VT.isFloatingPoint()) {
+ if (IsFixedLength) {
+ assert(Subtarget->useSVEForFixedLengthVectors() &&
+ "Cannot lower when not using SVE for fixed vectors");
+ IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
+ MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
+ InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
+
+ StoreVal =
+ DAG.getNode(ISD::BITCAST, DL, VT.changeTypeToInteger(), StoreVal);
+ StoreVal = DAG.getNode(
+ ISD::ANY_EXTEND, DL,
+ VT.changeVectorElementType(IndexVT.getVectorElementType()), StoreVal);
+ StoreVal = convertToScalableVector(DAG, IndexVT, StoreVal);
+ } else if (VT.isFloatingPoint()) {
+ // Handle FP data by casting the data so an integer scatter can be used.
EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
@@ -4081,10 +4466,44 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
/*isGather=*/false, DAG);
+ if (IsFixedLength) {
+ if (Index.getSimpleValueType().isFixedLengthVector())
+ Index = convertToScalableVector(DAG, IndexVT, Index);
+ if (BasePtr.getSimpleValueType().isFixedLengthVector())
+ BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
+ Mask = convertFixedMaskToScalableVector(Mask, DAG);
+ }
+
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
return DAG.getNode(Opcode, DL, VTs, Ops);
}
+SDValue AArch64TargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ MaskedLoadSDNode *LoadNode = cast<MaskedLoadSDNode>(Op);
+ assert(LoadNode && "Expected custom lowering of a masked load node");
+ EVT VT = Op->getValueType(0);
+
+ if (useSVEForFixedLengthVectorVT(VT, true))
+ return LowerFixedLengthVectorMLoadToSVE(Op, DAG);
+
+ SDValue PassThru = LoadNode->getPassThru();
+ SDValue Mask = LoadNode->getMask();
+
+ if (PassThru->isUndef() || isZerosVector(PassThru.getNode()))
+ return Op;
+
+ SDValue Load = DAG.getMaskedLoad(
+ VT, DL, LoadNode->getChain(), LoadNode->getBasePtr(),
+ LoadNode->getOffset(), Mask, DAG.getUNDEF(VT), LoadNode->getMemoryVT(),
+ LoadNode->getMemOperand(), LoadNode->getAddressingMode(),
+ LoadNode->getExtensionType());
+
+ SDValue Result = DAG.getSelect(DL, VT, Mask, Load, PassThru);
+
+ return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
+}
+
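LowerMLOAD above canonicalises a masked load whose passthru is neither undef nor an all-zeros vector into a masked load with an undef passthru followed by a vector select, which is what the deleted FIXME about non-zero/undef passthrus was asking for. A scalar C++ sketch of that equivalence, with illustrative lane values:

#include <cstdio>

// Scalar model of the rewrite: a masked load with an arbitrary passthru is
// expressed as (select mask, masked-load-with-undef, passthru).
int main() {
  const int Mem[4] = {10, 20, 30, 40};
  const bool Mask[4] = {true, false, true, false};
  const int PassThru[4] = {-1, -2, -3, -4};

  int Loaded[4]; // masked load with "undef" passthru: inactive lanes unspecified
  for (int i = 0; i < 4; ++i)
    Loaded[i] = Mask[i] ? Mem[i] : 0; // model the undef lanes as zero here

  int Result[4]; // the follow-up select re-introduces the real passthru
  for (int i = 0; i < 4; ++i)
    Result[i] = Mask[i] ? Loaded[i] : PassThru[i];

  for (int V : Result)
    printf("%d ", V);                // 10 -2 30 -4
  printf("\n");
  return 0;
}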
// Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16.
static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST,
EVT VT, EVT MemVT,
@@ -4132,19 +4551,20 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
EVT MemVT = StoreNode->getMemoryVT();
if (VT.isVector()) {
- if (useSVEForFixedLengthVectorVT(VT))
+ if (useSVEForFixedLengthVectorVT(VT, true))
return LowerFixedLengthVectorStoreToSVE(Op, DAG);
unsigned AS = StoreNode->getAddressSpace();
Align Alignment = StoreNode->getAlign();
if (Alignment < MemVT.getStoreSize() &&
- !allowsMisalignedMemoryAccesses(MemVT, AS, Alignment.value(),
+ !allowsMisalignedMemoryAccesses(MemVT, AS, Alignment,
StoreNode->getMemOperand()->getFlags(),
nullptr)) {
return scalarizeVectorStore(StoreNode, DAG);
}
- if (StoreNode->isTruncatingStore()) {
+ if (StoreNode->isTruncatingStore() && VT == MVT::v4i16 &&
+ MemVT == MVT::v4i8) {
return LowerTruncateVectorStore(Dl, StoreNode, VT, MemVT, DAG);
}
// 256 bit non-temporal stores can be lowered to STNP. Do this as part of
@@ -4190,6 +4610,40 @@ SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
return SDValue();
}
+// Custom lowering for extending v4i8 vector loads.
+SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ LoadSDNode *LoadNode = cast<LoadSDNode>(Op);
+ assert(LoadNode && "Expected custom lowering of a load node");
+ EVT VT = Op->getValueType(0);
+ assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32");
+
+ if (LoadNode->getMemoryVT() != MVT::v4i8)
+ return SDValue();
+
+ unsigned ExtType;
+ if (LoadNode->getExtensionType() == ISD::SEXTLOAD)
+ ExtType = ISD::SIGN_EXTEND;
+ else if (LoadNode->getExtensionType() == ISD::ZEXTLOAD ||
+ LoadNode->getExtensionType() == ISD::EXTLOAD)
+ ExtType = ISD::ZERO_EXTEND;
+ else
+ return SDValue();
+
+ SDValue Load = DAG.getLoad(MVT::f32, DL, LoadNode->getChain(),
+ LoadNode->getBasePtr(), MachinePointerInfo());
+ SDValue Chain = Load.getValue(1);
+ SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f32, Load);
+ SDValue BC = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Vec);
+ SDValue Ext = DAG.getNode(ExtType, DL, MVT::v8i16, BC);
+ Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i16, Ext,
+ DAG.getConstant(0, DL, MVT::i64));
+ if (VT == MVT::v4i32)
+ Ext = DAG.getNode(ExtType, DL, MVT::v4i32, Ext);
+ return DAG.getMergeValues({Ext, Chain}, DL);
+}
+
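The sequence built above handles an extending v4i8 load by loading the four bytes as a single 32-bit scalar, placing it in a vector register, reinterpreting it as v8i8, extending all eight lanes, and then keeping only the low four. A byte-level C++ sketch of that data flow (illustrative only, not the DAG lowering itself):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Models the v4i8 -> v4i32 zero-extending load: one 32-bit load, reinterpret
// as 8 x i8, extend, keep the low 4 lanes, extend again.
int main() {
  const uint8_t Mem[4] = {0x01, 0x80, 0xff, 0x7f};

  uint32_t Scalar;                // the single 32-bit (f32-typed) load
  std::memcpy(&Scalar, Mem, 4);

  uint8_t Bytes[8] = {0};         // bitcast of the 64-bit register to v8i8;
  std::memcpy(Bytes, &Scalar, 4); // upper 4 lanes are don't-care, zero here

  uint16_t Ext16[8];              // zero-extend v8i8 -> v8i16
  for (int i = 0; i < 8; ++i)
    Ext16[i] = Bytes[i];

  uint32_t Ext32[4];              // extract low v4i16, extend -> v4i32
  for (int i = 0; i < 4; ++i)
    Ext32[i] = Ext16[i];

  for (uint32_t V : Ext32)
    printf("%u ", (unsigned)V);   // 1 128 255 127
  printf("\n");
  return 0;
}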
// Generate SUBS and CSEL for integer abs.
SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
MVT VT = Op.getSimpleValueType();
@@ -4339,10 +4793,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::SHL:
return LowerVectorSRA_SRL_SHL(Op, DAG);
case ISD::SHL_PARTS:
- return LowerShiftLeftParts(Op, DAG);
case ISD::SRL_PARTS:
case ISD::SRA_PARTS:
- return LowerShiftRightParts(Op, DAG);
+ return LowerShiftParts(Op, DAG);
case ISD::CTPOP:
return LowerCTPOP(Op, DAG);
case ISD::FCOPYSIGN:
@@ -4363,16 +4816,29 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::STRICT_FP_TO_SINT:
case ISD::STRICT_FP_TO_UINT:
return LowerFP_TO_INT(Op, DAG);
+ case ISD::FP_TO_SINT_SAT:
+ case ISD::FP_TO_UINT_SAT:
+ return LowerFP_TO_INT_SAT(Op, DAG);
case ISD::FSINCOS:
return LowerFSINCOS(Op, DAG);
case ISD::FLT_ROUNDS_:
return LowerFLT_ROUNDS_(Op, DAG);
+ case ISD::SET_ROUNDING:
+ return LowerSET_ROUNDING(Op, DAG);
case ISD::MUL:
return LowerMUL(Op, DAG);
+ case ISD::MULHS:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::MULHS_PRED,
+ /*OverrideNEON=*/true);
+ case ISD::MULHU:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::MULHU_PRED,
+ /*OverrideNEON=*/true);
case ISD::INTRINSIC_WO_CHAIN:
return LowerINTRINSIC_WO_CHAIN(Op, DAG);
case ISD::STORE:
return LowerSTORE(Op, DAG);
+ case ISD::MSTORE:
+ return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
case ISD::MGATHER:
return LowerMGATHER(Op, DAG);
case ISD::MSCATTER:
@@ -4416,18 +4882,24 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
}
case ISD::TRUNCATE:
return LowerTRUNCATE(Op, DAG);
+ case ISD::MLOAD:
+ return LowerMLOAD(Op, DAG);
case ISD::LOAD:
if (useSVEForFixedLengthVectorVT(Op.getValueType()))
return LowerFixedLengthVectorLoadToSVE(Op, DAG);
- llvm_unreachable("Unexpected request to lower ISD::LOAD");
+ return LowerLOAD(Op, DAG);
case ISD::ADD:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::ADD_PRED);
case ISD::AND:
return LowerToScalableOp(Op, DAG);
case ISD::SUB:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::SUB_PRED);
+ case ISD::FMAXIMUM:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAX_PRED);
case ISD::FMAXNUM:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED);
+ case ISD::FMINIMUM:
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMIN_PRED);
case ISD::FMINNUM:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED);
case ISD::VSELECT:
@@ -4435,8 +4907,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::ABS:
return LowerABS(Op, DAG);
case ISD::BITREVERSE:
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::BITREVERSE_MERGE_PASSTHRU,
- /*OverrideNEON=*/true);
+ return LowerBitreverse(Op, DAG);
case ISD::BSWAP:
return LowerToPredicatedOp(Op, DAG, AArch64ISD::BSWAP_MERGE_PASSTHRU);
case ISD::CTLZ:
@@ -4444,6 +4915,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
/*OverrideNEON=*/true);
case ISD::CTTZ:
return LowerCTTZ(Op, DAG);
+ case ISD::VECTOR_SPLICE:
+ return LowerVECTOR_SPLICE(Op, DAG);
}
}
@@ -4515,6 +4988,8 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
case CallingConv::PreserveMost:
case CallingConv::CXX_FAST_TLS:
case CallingConv::Swift:
+ case CallingConv::SwiftTail:
+ case CallingConv::Tail:
if (Subtarget->isTargetWindows() && IsVarArg)
return CC_AArch64_Win64_VarArg;
if (!Subtarget->isTargetDarwin())
@@ -4578,7 +5053,10 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
else if (ActualMVT == MVT::i16)
ValVT = MVT::i16;
}
- CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, /*IsVarArg=*/false);
+ bool UseVarArgCC = false;
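+    // On Windows, the fixed arguments of a vararg function are laid out the
+    // same way as the variadic ones, so assign them with the vararg calling
+    // convention here as well (LowerCall does the same at call sites).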
+ if (IsWin64)
+ UseVarArgCC = isVarArg;
+ CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, UseVarArgCC);
bool Res =
AssignFn(i, ValVT, ValVT, CCValAssign::Full, Ins[i].Flags, CCInfo);
assert(!Res && "Call operand has unhandled type");
@@ -4606,6 +5084,9 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
continue;
}
+ if (Ins[i].Flags.isSwiftAsync())
+ MF.getInfo<AArch64FunctionInfo>()->setHasSwiftAsyncContext(true);
+
SDValue ArgValue;
if (VA.isRegLoc()) {
// Arguments stored in registers.
@@ -4709,7 +5190,6 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
ExtType, DL, VA.getLocVT(), Chain, FIN,
MachinePointerInfo::getFixedStack(DAG.getMachineFunction(), FI),
MemVT);
-
}
if (VA.getLocInfo() == CCValAssign::Indirect) {
@@ -4985,8 +5465,9 @@ SDValue AArch64TargetLowering::LowerCallResult(
}
/// Return true if the calling convention is one that we can guarantee TCO for.
-static bool canGuaranteeTCO(CallingConv::ID CC) {
- return CC == CallingConv::Fast;
+static bool canGuaranteeTCO(CallingConv::ID CC, bool GuaranteeTailCalls) {
+ return (CC == CallingConv::Fast && GuaranteeTailCalls) ||
+ CC == CallingConv::Tail || CC == CallingConv::SwiftTail;
}
/// Return true if we might ever do TCO for calls with this calling convention.
@@ -4996,9 +5477,12 @@ static bool mayTailCallThisCC(CallingConv::ID CC) {
case CallingConv::AArch64_SVE_VectorCall:
case CallingConv::PreserveMost:
case CallingConv::Swift:
+ case CallingConv::SwiftTail:
+ case CallingConv::Tail:
+ case CallingConv::Fast:
return true;
default:
- return canGuaranteeTCO(CC);
+ return false;
}
}
@@ -5014,11 +5498,11 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
const Function &CallerF = MF.getFunction();
CallingConv::ID CallerCC = CallerF.getCallingConv();
- // If this function uses the C calling convention but has an SVE signature,
- // then it preserves more registers and should assume the SVE_VectorCall CC.
+ // Functions using the C or Fast calling convention that have an SVE signature
+ // preserve more registers and should assume the SVE_VectorCall CC.
// The check for matching callee-saved regs will determine whether it is
// eligible for TCO.
- if (CallerCC == CallingConv::C &&
+ if ((CallerCC == CallingConv::C || CallerCC == CallingConv::Fast) &&
AArch64RegisterInfo::hasSVEArgsOrReturn(&MF))
CallerCC = CallingConv::AArch64_SVE_VectorCall;
@@ -5050,8 +5534,8 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
return false;
}
- if (getTargetMachine().Options.GuaranteedTailCallOpt)
- return canGuaranteeTCO(CalleeCC) && CCMatch;
+ if (canGuaranteeTCO(CalleeCC, getTargetMachine().Options.GuaranteedTailCallOpt))
+ return CCMatch;
// Externally-defined functions with weak linkage should not be
// tail-called on AArch64 when the OS does not support dynamic
@@ -5182,7 +5666,8 @@ SDValue AArch64TargetLowering::addTokenForArgument(SDValue Chain,
bool AArch64TargetLowering::DoesCalleeRestoreStack(CallingConv::ID CallCC,
bool TailCallOpt) const {
- return CallCC == CallingConv::Fast && TailCallOpt;
+ return (CallCC == CallingConv::Fast && TailCallOpt) ||
+ CallCC == CallingConv::Tail || CallCC == CallingConv::SwiftTail;
}
/// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
@@ -5208,10 +5693,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
bool IsSibCall = false;
+ bool IsCalleeWin64 = Subtarget->isCallingConvWin64(CallConv);
// Check callee args/returns for SVE registers and set calling convention
// accordingly.
- if (CallConv == CallingConv::C) {
+ if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) {
bool CalleeOutSVE = any_of(Outs, [](ISD::OutputArg &Out){
return Out.VT.isScalableVector();
});
@@ -5227,19 +5713,21 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// Check if it's really possible to do a tail call.
IsTailCall = isEligibleForTailCallOptimization(
Callee, CallConv, IsVarArg, Outs, OutVals, Ins, DAG);
- if (!IsTailCall && CLI.CB && CLI.CB->isMustTailCall())
- report_fatal_error("failed to perform tail call elimination on a call "
- "site marked musttail");
// A sibling call is one where we're under the usual C ABI and not planning
// to change that but can still do a tail call:
- if (!TailCallOpt && IsTailCall)
+ if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
+ CallConv != CallingConv::SwiftTail)
IsSibCall = true;
if (IsTailCall)
++NumTailCalls;
}
+ if (!IsTailCall && CLI.CB && CLI.CB->isMustTailCall())
+ report_fatal_error("failed to perform tail call elimination on a call "
+ "site marked musttail");
+
// Analyze operands of the call, assigning locations to each operand.
SmallVector<CCValAssign, 16> ArgLocs;
CCState CCInfo(CallConv, IsVarArg, DAG.getMachineFunction(), ArgLocs,
@@ -5257,8 +5745,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
"currently not supported");
ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
- CCAssignFn *AssignFn = CCAssignFnForCall(CallConv,
- /*IsVarArg=*/ !Outs[i].IsFixed);
+ bool UseVarArgCC = !Outs[i].IsFixed;
+ // On Windows, the fixed arguments in a vararg call are passed in GPRs
+ // too, so use the vararg CC to force them to integer registers.
+ if (IsCalleeWin64)
+ UseVarArgCC = true;
+ CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, UseVarArgCC);
bool Res = AssignFn(i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, CCInfo);
assert(!Res && "Call operand has unhandled type");
(void)Res;
@@ -5320,6 +5812,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// can actually shrink the stack.
FPDiff = NumReusableBytes - NumBytes;
+ // Update the required reserved area if this is the tail call requiring the
+ // most argument stack space.
+ if (FPDiff < 0 && FuncInfo->getTailCallReservedStack() < (unsigned)-FPDiff)
+ FuncInfo->setTailCallReservedStack(-FPDiff);
+
// The stack pointer must be 16-byte aligned at all times it's used for a
// memory operation, which in practice means at *all* times and in
// particular across call boundaries. Therefore our own arguments started at
@@ -5331,7 +5828,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// Adjust the stack pointer for the new arguments...
// These operations are automatically eliminated by the prolog/epilog pass
if (!IsSibCall)
- Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, DL);
+ Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
getPointerTy(DAG.getDataLayout()));
@@ -5485,7 +5982,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// common case. It should also work for fundamental types too.
uint32_t BEAlign = 0;
unsigned OpSize;
- if (VA.getLocInfo() == CCValAssign::Indirect)
+ if (VA.getLocInfo() == CCValAssign::Indirect ||
+ VA.getLocInfo() == CCValAssign::Trunc)
OpSize = VA.getLocVT().getFixedSizeInBits();
else
OpSize = Flags.isByVal() ? Flags.getByValSize() * 8
@@ -5588,7 +6086,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
// we've carefully laid out the parameters so that when sp is reset they'll be
// in the correct location.
if (IsTailCall && !IsSibCall) {
- Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(NumBytes, DL, true),
+ Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(0, DL, true),
DAG.getIntPtrConstant(0, DL, true), InFlag, DL);
InFlag = Chain.getValue(1);
}
@@ -5647,11 +6145,12 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
}
unsigned CallOpc = AArch64ISD::CALL;
- // Calls marked with "rv_marker" are special. They should be expanded to the
- // call, directly followed by a special marker sequence. Use the CALL_RVMARKER
- // to do that.
- if (CLI.CB && CLI.CB->hasRetAttr("rv_marker")) {
- assert(!IsTailCall && "tail calls cannot be marked with rv_marker");
+ // Calls with operand bundle "clang.arc.attachedcall" are special. They should
+ // be expanded to the call, directly followed by a special marker sequence.
+ // Use the CALL_RVMARKER to do that.
+ if (CLI.CB && objcarc::hasAttachedCallOpBundle(CLI.CB)) {
+ assert(!IsTailCall &&
+ "tail calls cannot be marked with clang.arc.attachedcall");
CallOpc = AArch64ISD::CALL_RVMARKER;
}
@@ -6584,6 +7083,56 @@ SDValue AArch64TargetLowering::LowerCTTZ(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(ISD::CTLZ, DL, VT, RBIT);
}
+SDValue AArch64TargetLowering::LowerBitreverse(SDValue Op,
+ SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+
+ if (VT.isScalableVector() ||
+ useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true))
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::BITREVERSE_MERGE_PASSTHRU,
+ true);
+
+ SDLoc DL(Op);
+ SDValue REVB;
+ MVT VST;
+
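+  // Reversing the bytes within each element (REV32/REV64) and then
+  // bit-reversing every byte (the v8i8/v16i8 BITREVERSE below) together
+  // reverse all of the bits of each element.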
+ switch (VT.getSimpleVT().SimpleTy) {
+ default:
+ llvm_unreachable("Invalid type for bitreverse!");
+
+ case MVT::v2i32: {
+ VST = MVT::v8i8;
+ REVB = DAG.getNode(AArch64ISD::REV32, DL, VST, Op.getOperand(0));
+
+ break;
+ }
+
+ case MVT::v4i32: {
+ VST = MVT::v16i8;
+ REVB = DAG.getNode(AArch64ISD::REV32, DL, VST, Op.getOperand(0));
+
+ break;
+ }
+
+ case MVT::v1i64: {
+ VST = MVT::v8i8;
+ REVB = DAG.getNode(AArch64ISD::REV64, DL, VST, Op.getOperand(0));
+
+ break;
+ }
+
+ case MVT::v2i64: {
+ VST = MVT::v16i8;
+ REVB = DAG.getNode(AArch64ISD::REV64, DL, VST, Op.getOperand(0));
+
+ break;
+ }
+ }
+
+ return DAG.getNode(AArch64ISD::NVCAST, DL, VT,
+ DAG.getNode(ISD::BITREVERSE, DL, VST, REVB));
+}
+
SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
if (Op.getValueType().isVector())
@@ -6700,13 +7249,26 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
assert((LHS.getValueType() == RHS.getValueType()) &&
(LHS.getValueType() == MVT::i32 || LHS.getValueType() == MVT::i64));
+ ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal);
+ ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal);
+ ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS);
+  // Check for sign pattern (SELECT_CC setgt, iN lhs, -1, 1, -1) and transform
+  // into (OR (ASR lhs, N-1), 1), which requires fewer instructions for the
+  // supported types.
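+  // For example, with i32 operands:
+  //   (select_cc setgt, x, -1, 1, -1) --> (or (sra x, 31), 1)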
+ if (CC == ISD::SETGT && RHSC && RHSC->isAllOnesValue() && CTVal && CFVal &&
+ CTVal->isOne() && CFVal->isAllOnesValue() &&
+ LHS.getValueType() == TVal.getValueType()) {
+ EVT VT = LHS.getValueType();
+ SDValue Shift =
+ DAG.getNode(ISD::SRA, dl, VT, LHS,
+ DAG.getConstant(VT.getSizeInBits() - 1, dl, VT));
+ return DAG.getNode(ISD::OR, dl, VT, Shift, DAG.getConstant(1, dl, VT));
+ }
+
unsigned Opcode = AArch64ISD::CSEL;
// If both the TVal and the FVal are constants, see if we can swap them in
// order to for a CSINV or CSINC out of them.
- ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal);
- ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal);
-
if (CTVal && CFVal && CTVal->isAllOnesValue() && CFVal->isNullValue()) {
std::swap(TVal, FVal);
std::swap(CTVal, CFVal);
@@ -6861,6 +7423,16 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
return CS1;
}
+SDValue AArch64TargetLowering::LowerVECTOR_SPLICE(SDValue Op,
+ SelectionDAG &DAG) const {
+
+ EVT Ty = Op.getValueType();
+ auto Idx = Op.getConstantOperandAPInt(2);
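+  // Splice indices in the range [-1, MinNumElts) are left as-is for
+  // instruction selection to handle; everything else falls back to the
+  // default expansion.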
+ if (Idx.sge(-1) && Idx.slt(Ty.getVectorMinNumElements()))
+ return Op;
+ return SDValue();
+}
+
SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
SelectionDAG &DAG) const {
ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(4))->get();
@@ -6887,6 +7459,17 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal);
}
+ if (useSVEForFixedLengthVectorVT(Ty)) {
+    // FIXME: Ideally this would be the same as above using i1 types; however,
+    // for the moment we can't deal with fixed i1 vector types properly, so
+    // instead extend the predicate to a result-type-sized integer vector.
+ MVT SplatValVT = MVT::getIntegerVT(Ty.getScalarSizeInBits());
+ MVT PredVT = MVT::getVectorVT(SplatValVT, Ty.getVectorElementCount());
+ SDValue SplatVal = DAG.getSExtOrTrunc(CCVal, DL, SplatValVT);
+ SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, PredVT, SplatVal);
+ return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal);
+ }
+
// Optimize {s|u}{add|sub|mul}.with.overflow feeding into a select
// instruction.
if (ISD::isOverflowIntrOpRes(CCVal)) {
@@ -6909,7 +7492,7 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
if (CCVal.getOpcode() == ISD::SETCC) {
LHS = CCVal.getOperand(0);
RHS = CCVal.getOperand(1);
- CC = cast<CondCodeSDNode>(CCVal->getOperand(2))->get();
+ CC = cast<CondCodeSDNode>(CCVal.getOperand(2))->get();
} else {
LHS = CCVal;
RHS = DAG.getConstant(0, DL, CCVal.getValueType());
@@ -7293,112 +7876,13 @@ SDValue AArch64TargetLowering::LowerRETURNADDR(SDValue Op,
return SDValue(St, 0);
}
-/// LowerShiftRightParts - Lower SRA_PARTS, which returns two
-/// i64 values and take a 2 x i64 value to shift plus a shift amount.
-SDValue AArch64TargetLowering::LowerShiftRightParts(SDValue Op,
- SelectionDAG &DAG) const {
- assert(Op.getNumOperands() == 3 && "Not a double-shift!");
- EVT VT = Op.getValueType();
- unsigned VTBits = VT.getSizeInBits();
- SDLoc dl(Op);
- SDValue ShOpLo = Op.getOperand(0);
- SDValue ShOpHi = Op.getOperand(1);
- SDValue ShAmt = Op.getOperand(2);
- unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
-
- assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
-
- SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64,
- DAG.getConstant(VTBits, dl, MVT::i64), ShAmt);
- SDValue HiBitsForLo = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
-
- // Unfortunately, if ShAmt == 0, we just calculated "(SHL ShOpHi, 64)" which
- // is "undef". We wanted 0, so CSEL it directly.
- SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(0, dl, MVT::i64),
- ISD::SETEQ, dl, DAG);
- SDValue CCVal = DAG.getConstant(AArch64CC::EQ, dl, MVT::i32);
- HiBitsForLo =
- DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64),
- HiBitsForLo, CCVal, Cmp);
-
- SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, ShAmt,
- DAG.getConstant(VTBits, dl, MVT::i64));
-
- SDValue LoBitsForLo = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
- SDValue LoForNormalShift =
- DAG.getNode(ISD::OR, dl, VT, LoBitsForLo, HiBitsForLo);
-
- Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64), ISD::SETGE,
- dl, DAG);
- CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32);
- SDValue LoForBigShift = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
- SDValue Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, LoForBigShift,
- LoForNormalShift, CCVal, Cmp);
-
- // AArch64 shifts larger than the register width are wrapped rather than
- // clamped, so we can't just emit "hi >> x".
- SDValue HiForNormalShift = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
- SDValue HiForBigShift =
- Opc == ISD::SRA
- ? DAG.getNode(Opc, dl, VT, ShOpHi,
- DAG.getConstant(VTBits - 1, dl, MVT::i64))
- : DAG.getConstant(0, dl, VT);
- SDValue Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, HiForBigShift,
- HiForNormalShift, CCVal, Cmp);
-
- SDValue Ops[2] = { Lo, Hi };
- return DAG.getMergeValues(Ops, dl);
-}
-
-/// LowerShiftLeftParts - Lower SHL_PARTS, which returns two
-/// i64 values and take a 2 x i64 value to shift plus a shift amount.
-SDValue AArch64TargetLowering::LowerShiftLeftParts(SDValue Op,
- SelectionDAG &DAG) const {
- assert(Op.getNumOperands() == 3 && "Not a double-shift!");
- EVT VT = Op.getValueType();
- unsigned VTBits = VT.getSizeInBits();
- SDLoc dl(Op);
- SDValue ShOpLo = Op.getOperand(0);
- SDValue ShOpHi = Op.getOperand(1);
- SDValue ShAmt = Op.getOperand(2);
-
- assert(Op.getOpcode() == ISD::SHL_PARTS);
- SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64,
- DAG.getConstant(VTBits, dl, MVT::i64), ShAmt);
- SDValue LoBitsForHi = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
-
- // Unfortunately, if ShAmt == 0, we just calculated "(SRL ShOpLo, 64)" which
- // is "undef". We wanted 0, so CSEL it directly.
- SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(0, dl, MVT::i64),
- ISD::SETEQ, dl, DAG);
- SDValue CCVal = DAG.getConstant(AArch64CC::EQ, dl, MVT::i32);
- LoBitsForHi =
- DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64),
- LoBitsForHi, CCVal, Cmp);
-
- SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, ShAmt,
- DAG.getConstant(VTBits, dl, MVT::i64));
- SDValue HiBitsForHi = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
- SDValue HiForNormalShift =
- DAG.getNode(ISD::OR, dl, VT, LoBitsForHi, HiBitsForHi);
-
- SDValue HiForBigShift = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
-
- Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64), ISD::SETGE,
- dl, DAG);
- CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32);
- SDValue Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, HiForBigShift,
- HiForNormalShift, CCVal, Cmp);
-
- // AArch64 shifts of larger than register sizes are wrapped rather than
- // clamped, so we can't just emit "lo << a" if a is too big.
- SDValue LoForBigShift = DAG.getConstant(0, dl, VT);
- SDValue LoForNormalShift = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
- SDValue Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, LoForBigShift,
- LoForNormalShift, CCVal, Cmp);
-
- SDValue Ops[2] = { Lo, Hi };
- return DAG.getMergeValues(Ops, dl);
+/// LowerShiftParts - Lower SHL_PARTS/SRA_PARTS/SRL_PARTS, which return two
+/// i64 values and take a 2 x i64 value to shift plus a shift amount.
+SDValue AArch64TargetLowering::LowerShiftParts(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDValue Lo, Hi;
+ expandShiftParts(Op.getNode(), Lo, Hi, DAG);
+ return DAG.getMergeValues({Lo, Hi}, SDLoc(Op));
}
bool AArch64TargetLowering::isOffsetFoldingLegal(
@@ -7738,7 +8222,7 @@ AArch64TargetLowering::getRegForInlineAsmConstraint(
: std::make_pair(0U, &AArch64::PPRRegClass);
}
}
- if (StringRef("{cc}").equals_lower(Constraint))
+ if (StringRef("{cc}").equals_insensitive(Constraint))
return std::make_pair(unsigned(AArch64::NZCV), &AArch64::CCRRegClass);
// Use the default implementation in TargetLowering to convert the register
@@ -7814,10 +8298,6 @@ void AArch64TargetLowering::LowerAsmOperandForConstraint(
dyn_cast<BlockAddressSDNode>(Op)) {
Result =
DAG.getTargetBlockAddress(BA->getBlockAddress(), BA->getValueType(0));
- } else if (const ExternalSymbolSDNode *ES =
- dyn_cast<ExternalSymbolSDNode>(Op)) {
- Result =
- DAG.getTargetExternalSymbol(ES->getSymbol(), ES->getValueType(0));
} else
return;
break;
@@ -7944,7 +8424,7 @@ static SDValue WidenVector(SDValue V64Reg, SelectionDAG &DAG) {
SDLoc DL(V64Reg);
return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideTy, DAG.getUNDEF(WideTy),
- V64Reg, DAG.getConstant(0, DL, MVT::i32));
+ V64Reg, DAG.getConstant(0, DL, MVT::i64));
}
/// getExtFactor - Determine the adjustment factor for the position when
@@ -8792,6 +9272,9 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
+ if (useSVEForFixedLengthVectorVT(VT))
+ return LowerFixedLengthVECTOR_SHUFFLEToSVE(Op, DAG);
+
// Convert shuffles that are directly supported on NEON to target-specific
// DAG nodes, instead of keeping them as shuffles and matching them again
// during code selection. This is more efficient and avoids the possibility
@@ -8801,6 +9284,10 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
SDValue V1 = Op.getOperand(0);
SDValue V2 = Op.getOperand(1);
+ assert(V1.getValueType() == VT && "Unexpected VECTOR_SHUFFLE type!");
+ assert(ShuffleMask.size() == VT.getVectorNumElements() &&
+ "Unexpected VECTOR_SHUFFLE mask size!");
+
if (SVN->isSplat()) {
int Lane = SVN->getSplatIndex();
// If this is undef splat, generate it via "just" vdup, if possible.
@@ -8847,6 +9334,14 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
if (isREVMask(ShuffleMask, VT, 16))
return DAG.getNode(AArch64ISD::REV16, dl, V1.getValueType(), V1, V2);
+ if (((VT.getVectorNumElements() == 8 && VT.getScalarSizeInBits() == 16) ||
+ (VT.getVectorNumElements() == 16 && VT.getScalarSizeInBits() == 8)) &&
+ ShuffleVectorInst::isReverseMask(ShuffleMask)) {
+ SDValue Rev = DAG.getNode(AArch64ISD::REV64, dl, VT, V1);
+ return DAG.getNode(AArch64ISD::EXT, dl, VT, Rev, Rev,
+ DAG.getConstant(8, dl, MVT::i32));
+ }
+
bool ReverseEXT = false;
unsigned Imm;
if (isEXTMask(ShuffleMask, VT, ReverseEXT, Imm)) {
@@ -9027,9 +9522,7 @@ SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op,
SDValue SplatOne = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, One);
// create the vector 0,1,0,1,...
- SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
- SDValue SV = DAG.getNode(AArch64ISD::INDEX_VECTOR,
- DL, MVT::nxv2i64, Zero, One);
+ SDValue SV = DAG.getStepVector(DL, MVT::nxv2i64);
SV = DAG.getNode(ISD::AND, DL, MVT::nxv2i64, SV, SplatOne);
// create the vector idx64,idx64+1,idx64,idx64+1,...
@@ -9556,10 +10049,10 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
}
if (i > 0)
isOnlyLowElement = false;
- if (!isa<ConstantFPSDNode>(V) && !isa<ConstantSDNode>(V))
+ if (!isIntOrFPConstant(V))
isConstant = false;
- if (isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V)) {
+ if (isIntOrFPConstant(V)) {
++NumConstantLanes;
if (!ConstantValue.getNode())
ConstantValue = V;
@@ -9584,7 +10077,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
// Convert BUILD_VECTOR where all elements but the lowest are undef into
// SCALAR_TO_VECTOR, except for when we have a single-element constant vector
// as SimplifyDemandedBits will just turn that back into BUILD_VECTOR.
- if (isOnlyLowElement && !(NumElts == 1 && isa<ConstantSDNode>(Value))) {
+ if (isOnlyLowElement && !(NumElts == 1 && isIntOrFPConstant(Value))) {
LLVM_DEBUG(dbgs() << "LowerBUILD_VECTOR: only low element used, creating 1 "
"SCALAR_TO_VECTOR node\n");
return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Value);
@@ -9725,7 +10218,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
for (unsigned i = 0; i < NumElts; ++i) {
SDValue V = Op.getOperand(i);
SDValue LaneIdx = DAG.getConstant(i, dl, MVT::i64);
- if (!isa<ConstantSDNode>(V) && !isa<ConstantFPSDNode>(V))
+ if (!isIntOrFPConstant(V))
// Note that type legalization likely mucked about with the VT of the
// source operand, so we may have to convert it here before inserting.
Val = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Val, V, LaneIdx);
@@ -9749,9 +10242,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
if (PreferDUPAndInsert) {
// First, build a constant vector with the common element.
- SmallVector<SDValue, 8> Ops;
- for (unsigned I = 0; I < NumElts; ++I)
- Ops.push_back(Value);
+ SmallVector<SDValue, 8> Ops(NumElts, Value);
SDValue NewVector = LowerBUILD_VECTOR(DAG.getBuildVector(VT, dl, Ops), DAG);
// Next, insert the elements that do not match the common value.
for (unsigned I = 0; I < NumElts; ++I)
@@ -9813,6 +10304,9 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
SDValue AArch64TargetLowering::LowerCONCAT_VECTORS(SDValue Op,
SelectionDAG &DAG) const {
+ if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+ return LowerFixedLengthConcatVectorsToSVE(Op, DAG);
+
assert(Op.getValueType().isScalableVector() &&
isTypeLegal(Op.getValueType()) &&
"Expected legal scalable vector type!");
@@ -9827,13 +10321,32 @@ SDValue AArch64TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
SelectionDAG &DAG) const {
assert(Op.getOpcode() == ISD::INSERT_VECTOR_ELT && "Unknown opcode!");
+ if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+ return LowerFixedLengthInsertVectorElt(Op, DAG);
+
// Check for non-constant or out of range lane.
EVT VT = Op.getOperand(0).getValueType();
+
+ if (VT.getScalarType() == MVT::i1) {
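+    // There is no direct insert into an i1 element, so widen the whole vector
+    // to its promoted integer type, insert the (extended) value there, and
+    // truncate the result back to the predicate type.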
+ EVT VectorVT = getPromotedVTForPredicate(VT);
+ SDLoc DL(Op);
+ SDValue ExtendedVector =
+ DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, VectorVT);
+ SDValue ExtendedValue =
+ DAG.getAnyExtOrTrunc(Op.getOperand(1), DL,
+ VectorVT.getScalarType().getSizeInBits() < 32
+ ? MVT::i32
+ : VectorVT.getScalarType());
+ ExtendedVector =
+ DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VectorVT, ExtendedVector,
+ ExtendedValue, Op.getOperand(2));
+ return DAG.getAnyExtOrTrunc(ExtendedVector, DL, VT);
+ }
+
ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(2));
if (!CI || CI->getZExtValue() >= VT.getVectorNumElements())
return SDValue();
-
// Insertion/extraction are legal for V128 types.
if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 ||
VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
@@ -9861,14 +10374,29 @@ SDValue
AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
SelectionDAG &DAG) const {
assert(Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT && "Unknown opcode!");
+ EVT VT = Op.getOperand(0).getValueType();
+
+ if (VT.getScalarType() == MVT::i1) {
+ // We can't directly extract from an SVE predicate; extend it first.
+ // (This isn't the only possible lowering, but it's straightforward.)
+ EVT VectorVT = getPromotedVTForPredicate(VT);
+ SDLoc DL(Op);
+ SDValue Extend =
+ DAG.getNode(ISD::ANY_EXTEND, DL, VectorVT, Op.getOperand(0));
+ MVT ExtractTy = VectorVT == MVT::nxv2i64 ? MVT::i64 : MVT::i32;
+ SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractTy,
+ Extend, Op.getOperand(1));
+ return DAG.getAnyExtOrTrunc(Extract, DL, Op.getValueType());
+ }
+
+ if (useSVEForFixedLengthVectorVT(VT))
+ return LowerFixedLengthExtractVectorElt(Op, DAG);
// Check for non-constant or out of range lane.
- EVT VT = Op.getOperand(0).getValueType();
ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(1));
if (!CI || CI->getZExtValue() >= VT.getVectorNumElements())
return SDValue();
-
// Insertion/extraction are legal for V128 types.
if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 ||
VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
@@ -10159,7 +10687,8 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
unsigned Opc = (Op.getOpcode() == ISD::SRA) ? Intrinsic::aarch64_neon_sshl
: Intrinsic::aarch64_neon_ushl;
// negate the shift amount
- SDValue NegShift = DAG.getNode(AArch64ISD::NEG, DL, VT, Op.getOperand(1));
+ SDValue NegShift = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
+ Op.getOperand(1));
SDValue NegShiftLeft =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
DAG.getConstant(Opc, DL, MVT::i32), Op.getOperand(0),
@@ -10267,11 +10796,8 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
SelectionDAG &DAG) const {
- if (Op.getValueType().isScalableVector()) {
- if (Op.getOperand(0).getValueType().isFloatingPoint())
- return Op;
+ if (Op.getValueType().isScalableVector())
return LowerToPredicatedOp(Op, DAG, AArch64ISD::SETCC_MERGE_ZERO);
- }
if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType()))
return LowerFixedLengthVectorSetccToSVE(Op, DAG);
@@ -11391,8 +11917,8 @@ EVT AArch64TargetLowering::getOptimalMemOpType(
if (Op.isAligned(AlignCheck))
return true;
bool Fast;
- return allowsMisalignedMemoryAccesses(VT, 0, 1, MachineMemOperand::MONone,
- &Fast) &&
+ return allowsMisalignedMemoryAccesses(VT, 0, Align(1),
+ MachineMemOperand::MONone, &Fast) &&
Fast;
};
@@ -11422,14 +11948,14 @@ LLT AArch64TargetLowering::getOptimalMemOpLLT(
if (Op.isAligned(AlignCheck))
return true;
bool Fast;
- return allowsMisalignedMemoryAccesses(VT, 0, 1, MachineMemOperand::MONone,
- &Fast) &&
+ return allowsMisalignedMemoryAccesses(VT, 0, Align(1),
+ MachineMemOperand::MONone, &Fast) &&
Fast;
};
if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
AlignmentIsAcceptable(MVT::v2i64, Align(16)))
- return LLT::vector(2, 64);
+ return LLT::fixed_vector(2, 64);
if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
return LLT::scalar(128);
if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))
@@ -11482,8 +12008,12 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
return false;
// FIXME: Update this method to support scalable addressing modes.
- if (isa<ScalableVectorType>(Ty))
- return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale;
+ if (isa<ScalableVectorType>(Ty)) {
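+    // Scalable vector loads/stores support a register offset that is scaled
+    // by the element size in bytes (e.g. [x0, x1, lsl #2] for 32-bit
+    // elements), so accept a scale of 0 or one matching that element size.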
+ uint64_t VecElemNumBytes =
+ DL.getTypeSizeInBits(cast<VectorType>(Ty)->getElementType()) / 8;
+ return AM.HasBaseReg && !AM.BaseOffs &&
+ (AM.Scale == 0 || (uint64_t)AM.Scale == VecElemNumBytes);
+ }
// check reg + imm case:
// i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12
@@ -11521,9 +12051,8 @@ bool AArch64TargetLowering::shouldConsiderGEPOffsetSplit() const {
return true;
}
-int AArch64TargetLowering::getScalingFactorCost(const DataLayout &DL,
- const AddrMode &AM, Type *Ty,
- unsigned AS) const {
+InstructionCost AArch64TargetLowering::getScalingFactorCost(
+ const DataLayout &DL, const AddrMode &AM, Type *Ty, unsigned AS) const {
// Scaling factors are not free at all.
// Operands | Rt Latency
// -------------------------------------------
@@ -11546,6 +12075,8 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(
return false;
switch (VT.getSimpleVT().SimpleTy) {
+ case MVT::f16:
+ return Subtarget->hasFullFP16();
case MVT::f32:
case MVT::f64:
return true;
@@ -11567,6 +12098,11 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(const Function &F,
}
}
+bool AArch64TargetLowering::generateFMAsInMachineCombiner(
+ EVT VT, CodeGenOpt::Level OptLevel) const {
+ return (OptLevel >= CodeGenOpt::Aggressive) && !VT.isScalableVector();
+}
+
const MCPhysReg *
AArch64TargetLowering::getScratchRegisters(CallingConv::ID) const {
// LR is a callee-save register, but we must treat it as clobbered by any call
@@ -11654,77 +12190,142 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0));
}
-// VECREDUCE_ADD( EXTEND(v16i8_type) ) to
-// VECREDUCE_ADD( DOTv16i8(v16i8_type) )
-static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
- const AArch64Subtarget *ST) {
- SDValue Op0 = N->getOperand(0);
- if (!ST->hasDotProd() || N->getValueType(0) != MVT::i32)
+// Given a vecreduce_add node, detect the pattern below and convert it to a
+// node sequence with UABDL, [S|U]ABD and UADDLP.
+//
+// i32 vecreduce_add(
+//  v16i32 abs(
+//   v16i32 sub(
+//    v16i32 [sign|zero]_extend(v16i8 a), v16i32 [sign|zero]_extend(v16i8 b))))
+// =================>
+// i32 vecreduce_add(
+//  v4i32 UADDLP(
+//   v8i16 add(
+//    v8i16 zext(
+//     v8i8 [S|U]ABD low8:v16i8 a, low8:v16i8 b),
+//    v8i16 zext(
+//     v8i8 [S|U]ABD high8:v16i8 a, high8:v16i8 b))))
+static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
+ SelectionDAG &DAG) {
+ // Assumed i32 vecreduce_add
+ if (N->getValueType(0) != MVT::i32)
return SDValue();
- if (Op0.getValueType().getVectorElementType() != MVT::i32)
+ SDValue VecReduceOp0 = N->getOperand(0);
+ unsigned Opcode = VecReduceOp0.getOpcode();
+ // Assumed v16i32 abs
+ if (Opcode != ISD::ABS || VecReduceOp0->getValueType(0) != MVT::v16i32)
return SDValue();
- unsigned ExtOpcode = Op0.getOpcode();
- if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
+ SDValue ABS = VecReduceOp0;
+ // Assumed v16i32 sub
+ if (ABS->getOperand(0)->getOpcode() != ISD::SUB ||
+ ABS->getOperand(0)->getValueType(0) != MVT::v16i32)
return SDValue();
- EVT Op0VT = Op0.getOperand(0).getValueType();
- if (Op0VT != MVT::v16i8)
+ SDValue SUB = ABS->getOperand(0);
+ unsigned Opcode0 = SUB->getOperand(0).getOpcode();
+ unsigned Opcode1 = SUB->getOperand(1).getOpcode();
+ // Assumed v16i32 type
+ if (SUB->getOperand(0)->getValueType(0) != MVT::v16i32 ||
+ SUB->getOperand(1)->getValueType(0) != MVT::v16i32)
return SDValue();
- SDLoc DL(Op0);
- SDValue Ones = DAG.getConstant(1, DL, Op0VT);
- SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32);
- auto DotIntrisic = (ExtOpcode == ISD::ZERO_EXTEND)
- ? Intrinsic::aarch64_neon_udot
- : Intrinsic::aarch64_neon_sdot;
- SDValue Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Zeros.getValueType(),
- DAG.getConstant(DotIntrisic, DL, MVT::i32), Zeros,
- Ones, Op0.getOperand(0));
- return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
-}
-
-// Given a ABS node, detect the following pattern:
-// (ABS (SUB (EXTEND a), (EXTEND b))).
-// Generates UABD/SABD instruction.
-static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI,
- const AArch64Subtarget *Subtarget) {
- SDValue AbsOp1 = N->getOperand(0);
- SDValue Op0, Op1;
+ // Assumed zext or sext
+ bool IsZExt = false;
+ if (Opcode0 == ISD::ZERO_EXTEND && Opcode1 == ISD::ZERO_EXTEND) {
+ IsZExt = true;
+ } else if (Opcode0 == ISD::SIGN_EXTEND && Opcode1 == ISD::SIGN_EXTEND) {
+ IsZExt = false;
+ } else
+ return SDValue();
- if (AbsOp1.getOpcode() != ISD::SUB)
+ SDValue EXT0 = SUB->getOperand(0);
+ SDValue EXT1 = SUB->getOperand(1);
+ // Assumed zext's operand has v16i8 type
+ if (EXT0->getOperand(0)->getValueType(0) != MVT::v16i8 ||
+ EXT1->getOperand(0)->getValueType(0) != MVT::v16i8)
return SDValue();
- Op0 = AbsOp1.getOperand(0);
- Op1 = AbsOp1.getOperand(1);
+  // Pattern is detected. Let's convert it to a sequence of nodes.
+ SDLoc DL(N);
+
+ // First, create the node pattern of UABD/SABD.
+ SDValue UABDHigh8Op0 =
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0),
+ DAG.getConstant(8, DL, MVT::i64));
+ SDValue UABDHigh8Op1 =
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
+ DAG.getConstant(8, DL, MVT::i64));
+ SDValue UABDHigh8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
+ UABDHigh8Op0, UABDHigh8Op1);
+ SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8);
+
+ // Second, create the node pattern of UABAL.
+ SDValue UABDLo8Op0 =
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0),
+ DAG.getConstant(0, DL, MVT::i64));
+ SDValue UABDLo8Op1 =
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
+ DAG.getConstant(0, DL, MVT::i64));
+ SDValue UABDLo8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
+ UABDLo8Op0, UABDLo8Op1);
+ SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8);
+ SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD);
+
+ // Third, create the node of UADDLP.
+ SDValue UADDLP = DAG.getNode(AArch64ISD::UADDLP, DL, MVT::v4i32, UABAL);
+
+ // Fourth, create the node of VECREDUCE_ADD.
+ return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
+}
+
+// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
+// vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
+// vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
+static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+ const AArch64Subtarget *ST) {
+ if (!ST->hasDotProd())
+ return performVecReduceAddCombineWithUADDLP(N, DAG);
- unsigned Opc0 = Op0.getOpcode();
- // Check if the operands of the sub are (zero|sign)-extended.
- if (Opc0 != Op1.getOpcode() ||
- (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
+ SDValue Op0 = N->getOperand(0);
+ if (N->getValueType(0) != MVT::i32 ||
+ Op0.getValueType().getVectorElementType() != MVT::i32)
return SDValue();
- EVT VectorT1 = Op0.getOperand(0).getValueType();
- EVT VectorT2 = Op1.getOperand(0).getValueType();
- // Check if vectors are of same type and valid size.
- uint64_t Size = VectorT1.getFixedSizeInBits();
- if (VectorT1 != VectorT2 || (Size != 64 && Size != 128))
+ unsigned ExtOpcode = Op0.getOpcode();
+ SDValue A = Op0;
+ SDValue B;
+ if (ExtOpcode == ISD::MUL) {
+ A = Op0.getOperand(0);
+ B = Op0.getOperand(1);
+ if (A.getOpcode() != B.getOpcode() ||
+ A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
+ return SDValue();
+ ExtOpcode = A.getOpcode();
+ }
+ if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
return SDValue();
- // Check if vector element types are valid.
- EVT VT1 = VectorT1.getVectorElementType();
- if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32)
+ EVT Op0VT = A.getOperand(0).getValueType();
+ if (Op0VT != MVT::v8i8 && Op0VT != MVT::v16i8)
return SDValue();
- Op0 = Op0.getOperand(0);
- Op1 = Op1.getOperand(0);
- unsigned ABDOpcode =
- (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD;
- SDValue ABD =
- DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
- return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
+ SDLoc DL(Op0);
+  // For non-MLA reductions B can be set to 1. For MLA reductions we use the
+  // operand of the extend B instead.
+ if (!B)
+ B = DAG.getConstant(1, DL, Op0VT);
+ else
+ B = B.getOperand(0);
+
+ SDValue Zeros =
+ DAG.getConstant(0, DL, Op0VT == MVT::v8i8 ? MVT::v2i32 : MVT::v4i32);
+ auto DotOpcode =
+ (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
+ SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
+ A.getOperand(0), B);
+ return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
}
static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
@@ -11972,6 +12573,7 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
// e.g. 6=3*2=(2+1)*2.
// TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
// which equals to (1+2)*16-(1+2).
+
// TrailingZeroes is used to test if the mul can be lowered to
// shift+add+shift.
unsigned TrailingZeroes = ConstValue.countTrailingZeros();
@@ -12350,6 +12952,11 @@ static SDValue tryCombineToBSL(SDNode *N,
if (!VT.isVector())
return SDValue();
+ // The combining code currently only works for NEON vectors. In particular,
+ // it does not work for SVE when dealing with vectors wider than 128 bits.
+ if (!VT.is64BitVector() && !VT.is128BitVector())
+ return SDValue();
+
SDValue N0 = N->getOperand(0);
if (N0.getOpcode() != ISD::AND)
return SDValue();
@@ -12358,6 +12965,44 @@ static SDValue tryCombineToBSL(SDNode *N,
if (N1.getOpcode() != ISD::AND)
return SDValue();
+ // InstCombine does (not (neg a)) => (add a -1).
+ // Try: (or (and (neg a) b) (and (add a -1) c)) => (bsl (neg a) b c)
+ // Loop over all combinations of AND operands.
+ for (int i = 1; i >= 0; --i) {
+ for (int j = 1; j >= 0; --j) {
+ SDValue O0 = N0->getOperand(i);
+ SDValue O1 = N1->getOperand(j);
+ SDValue Sub, Add, SubSibling, AddSibling;
+
+ // Find a SUB and an ADD operand, one from each AND.
+ if (O0.getOpcode() == ISD::SUB && O1.getOpcode() == ISD::ADD) {
+ Sub = O0;
+ Add = O1;
+ SubSibling = N0->getOperand(1 - i);
+ AddSibling = N1->getOperand(1 - j);
+ } else if (O0.getOpcode() == ISD::ADD && O1.getOpcode() == ISD::SUB) {
+ Add = O0;
+ Sub = O1;
+ AddSibling = N0->getOperand(1 - i);
+ SubSibling = N1->getOperand(1 - j);
+ } else
+ continue;
+
+ if (!ISD::isBuildVectorAllZeros(Sub.getOperand(0).getNode()))
+ continue;
+
+      // The all-ones constant is always the right-hand operand of the Add.
+ if (!ISD::isBuildVectorAllOnes(Add.getOperand(1).getNode()))
+ continue;
+
+ if (Sub.getOperand(1) != Add.getOperand(0))
+ continue;
+
+ return DAG.getNode(AArch64ISD::BSP, DL, VT, Sub, SubSibling, AddSibling);
+ }
+ }
+
+ // (or (and a b) (and (not a) c)) => (bsl a b c)
// We only have to look for constant vectors here since the general, variable
// case can be handled in TableGen.
unsigned Bits = VT.getScalarSizeInBits();
@@ -13065,6 +13710,13 @@ static SDValue performSetccAddFolding(SDNode *Op, SelectionDAG &DAG) {
SDValue RHS = Op->getOperand(1);
SetCCInfoAndKind InfoAndKind;
+ // If both operands are a SET_CC, then we don't want to perform this
+ // folding and create another csel as this results in more instructions
+ // (and higher register usage).
+ if (isSetCCOrZExtSetCC(LHS, InfoAndKind) &&
+ isSetCCOrZExtSetCC(RHS, InfoAndKind))
+ return SDValue();
+
// If neither operand is a SET_CC, give up.
if (!isSetCCOrZExtSetCC(LHS, InfoAndKind)) {
std::swap(LHS, RHS);
@@ -13135,6 +13787,29 @@ static SDValue performUADDVCombine(SDNode *N, SelectionDAG &DAG) {
DAG.getConstant(0, DL, MVT::i64));
}
+// ADD(UDOT(zero, x, y), A) --> UDOT(A, x, y)
+static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) {
+ EVT VT = N->getValueType(0);
+ if (N->getOpcode() != ISD::ADD)
+ return SDValue();
+
+ SDValue Dot = N->getOperand(0);
+ SDValue A = N->getOperand(1);
+  // Handle commutativity
+ auto isZeroDot = [](SDValue Dot) {
+ return (Dot.getOpcode() == AArch64ISD::UDOT ||
+ Dot.getOpcode() == AArch64ISD::SDOT) &&
+ isZerosVector(Dot.getOperand(0).getNode());
+ };
+ if (!isZeroDot(Dot))
+ std::swap(Dot, A);
+ if (!isZeroDot(Dot))
+ return SDValue();
+
+ return DAG.getNode(Dot.getOpcode(), SDLoc(N), VT, A, Dot.getOperand(1),
+ Dot.getOperand(2));
+}
+
// The basic add/sub long vector instructions have variants with "2" on the end
// which act on the high-half of their inputs. They are normally matched by
// patterns like:
@@ -13194,6 +13869,8 @@ static SDValue performAddSubCombine(SDNode *N,
// Try to change sum of two reductions.
if (SDValue Val = performUADDVCombine(N, DAG))
return Val;
+ if (SDValue Val = performAddDotCombine(N, DAG))
+ return Val;
return performAddSubLongCombine(N, DCI, DAG);
}
@@ -13335,15 +14012,16 @@ static SDValue LowerSVEIntrinsicIndex(SDNode *N, SelectionDAG &DAG) {
SDLoc DL(N);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- EVT ScalarTy = Op1.getValueType();
-
- if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16)) {
- Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
- Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
- }
+ EVT ScalarTy = Op2.getValueType();
+ if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
+ ScalarTy = MVT::i32;
- return DAG.getNode(AArch64ISD::INDEX_VECTOR, DL, N->getValueType(0),
- Op1, Op2);
+  // Lower index_vector(base, step) to
+  // mul(splat(step), step_vector(1)) + splat(base).
+ SDValue StepVector = DAG.getStepVector(DL, N->getValueType(0));
+ SDValue Step = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op2);
+ SDValue Mul = DAG.getNode(ISD::MUL, DL, N->getValueType(0), StepVector, Step);
+ SDValue Base = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op1);
+ return DAG.getNode(ISD::ADD, DL, N->getValueType(0), Mul, Base);
}
static SDValue LowerSVEIntrinsicDUP(SDNode *N, SelectionDAG &DAG) {
@@ -13533,20 +14211,47 @@ static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
Zero);
}
+static bool isAllActivePredicate(SDValue N) {
+ unsigned NumElts = N.getValueType().getVectorMinNumElements();
+
+ // Look through cast.
+ while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
+ N = N.getOperand(0);
+ // When reinterpreting from a type with fewer elements the "new" elements
+ // are not active, so bail if they're likely to be used.
+ if (N.getValueType().getVectorMinNumElements() < NumElts)
+ return false;
+ }
+
+ // "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
+ // or smaller than the implicit element type represented by N.
+ // NOTE: A larger element count implies a smaller element type.
+ if (N.getOpcode() == AArch64ISD::PTRUE &&
+ N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
+ return N.getValueType().getVectorMinNumElements() >= NumElts;
+
+ return false;
+}
+
// If a merged operation has no inactive lanes we can relax it to a predicated
// or unpredicated operation, which potentially allows better isel (perhaps
// using immediate forms) or relaxing register reuse requirements.
-static SDValue convertMergedOpToPredOp(SDNode *N, unsigned PredOpc,
- SelectionDAG &DAG) {
+static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
+ SelectionDAG &DAG,
+ bool UnpredOp = false) {
assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Expected intrinsic!");
assert(N->getNumOperands() == 4 && "Expected 3 operand intrinsic!");
SDValue Pg = N->getOperand(1);
// ISD way to specify an all active predicate.
- if ((Pg.getOpcode() == AArch64ISD::PTRUE) &&
- (Pg.getConstantOperandVal(0) == AArch64SVEPredPattern::all))
- return DAG.getNode(PredOpc, SDLoc(N), N->getValueType(0), Pg,
- N->getOperand(2), N->getOperand(3));
+ if (isAllActivePredicate(Pg)) {
+ if (UnpredOp)
+ return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), N->getOperand(2),
+ N->getOperand(3));
+ else
+ return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Pg,
+ N->getOperand(2), N->getOperand(3));
+ }
// FUTURE: SplatVector(true)
return SDValue();
@@ -13637,6 +14342,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
N->getOperand(1));
case Intrinsic::aarch64_sve_ext:
return LowerSVEIntrinsicEXT(N, DAG);
+ case Intrinsic::aarch64_sve_mul:
+ return convertMergedOpToPredOp(N, AArch64ISD::MUL_PRED, DAG);
+ case Intrinsic::aarch64_sve_smulh:
+ return convertMergedOpToPredOp(N, AArch64ISD::MULHS_PRED, DAG);
+ case Intrinsic::aarch64_sve_umulh:
+ return convertMergedOpToPredOp(N, AArch64ISD::MULHU_PRED, DAG);
case Intrinsic::aarch64_sve_smin:
return convertMergedOpToPredOp(N, AArch64ISD::SMIN_PRED, DAG);
case Intrinsic::aarch64_sve_umin:
@@ -13651,6 +14362,44 @@ static SDValue performIntrinsicCombine(SDNode *N,
return convertMergedOpToPredOp(N, AArch64ISD::SRL_PRED, DAG);
case Intrinsic::aarch64_sve_asr:
return convertMergedOpToPredOp(N, AArch64ISD::SRA_PRED, DAG);
+ case Intrinsic::aarch64_sve_fadd:
+ return convertMergedOpToPredOp(N, AArch64ISD::FADD_PRED, DAG);
+ case Intrinsic::aarch64_sve_fsub:
+ return convertMergedOpToPredOp(N, AArch64ISD::FSUB_PRED, DAG);
+ case Intrinsic::aarch64_sve_fmul:
+ return convertMergedOpToPredOp(N, AArch64ISD::FMUL_PRED, DAG);
+ case Intrinsic::aarch64_sve_add:
+ return convertMergedOpToPredOp(N, ISD::ADD, DAG, true);
+ case Intrinsic::aarch64_sve_sub:
+ return convertMergedOpToPredOp(N, ISD::SUB, DAG, true);
+ case Intrinsic::aarch64_sve_and:
+ return convertMergedOpToPredOp(N, ISD::AND, DAG, true);
+ case Intrinsic::aarch64_sve_bic:
+ return convertMergedOpToPredOp(N, AArch64ISD::BIC, DAG, true);
+ case Intrinsic::aarch64_sve_eor:
+ return convertMergedOpToPredOp(N, ISD::XOR, DAG, true);
+ case Intrinsic::aarch64_sve_orr:
+ return convertMergedOpToPredOp(N, ISD::OR, DAG, true);
+ case Intrinsic::aarch64_sve_sqadd:
+ return convertMergedOpToPredOp(N, ISD::SADDSAT, DAG, true);
+ case Intrinsic::aarch64_sve_sqsub:
+ return convertMergedOpToPredOp(N, ISD::SSUBSAT, DAG, true);
+ case Intrinsic::aarch64_sve_uqadd:
+ return convertMergedOpToPredOp(N, ISD::UADDSAT, DAG, true);
+ case Intrinsic::aarch64_sve_uqsub:
+ return convertMergedOpToPredOp(N, ISD::USUBSAT, DAG, true);
+ case Intrinsic::aarch64_sve_sqadd_x:
+ return DAG.getNode(ISD::SADDSAT, SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2));
+ case Intrinsic::aarch64_sve_sqsub_x:
+ return DAG.getNode(ISD::SSUBSAT, SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2));
+ case Intrinsic::aarch64_sve_uqadd_x:
+ return DAG.getNode(ISD::UADDSAT, SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2));
+ case Intrinsic::aarch64_sve_uqsub_x:
+ return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2));
case Intrinsic::aarch64_sve_cmphs:
if (!N->getOperand(2).getValueType().isFloatingPoint())
return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
@@ -13663,29 +14412,34 @@ static SDValue performIntrinsicCombine(SDNode *N,
N->getValueType(0), N->getOperand(1), N->getOperand(2),
N->getOperand(3), DAG.getCondCode(ISD::SETUGT));
break;
+ case Intrinsic::aarch64_sve_fcmpge:
case Intrinsic::aarch64_sve_cmpge:
- if (!N->getOperand(2).getValueType().isFloatingPoint())
- return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
- N->getValueType(0), N->getOperand(1), N->getOperand(2),
- N->getOperand(3), DAG.getCondCode(ISD::SETGE));
+ return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
+ N->getValueType(0), N->getOperand(1), N->getOperand(2),
+ N->getOperand(3), DAG.getCondCode(ISD::SETGE));
break;
+ case Intrinsic::aarch64_sve_fcmpgt:
case Intrinsic::aarch64_sve_cmpgt:
- if (!N->getOperand(2).getValueType().isFloatingPoint())
- return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
- N->getValueType(0), N->getOperand(1), N->getOperand(2),
- N->getOperand(3), DAG.getCondCode(ISD::SETGT));
+ return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
+ N->getValueType(0), N->getOperand(1), N->getOperand(2),
+ N->getOperand(3), DAG.getCondCode(ISD::SETGT));
break;
+ case Intrinsic::aarch64_sve_fcmpeq:
case Intrinsic::aarch64_sve_cmpeq:
- if (!N->getOperand(2).getValueType().isFloatingPoint())
- return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
- N->getValueType(0), N->getOperand(1), N->getOperand(2),
- N->getOperand(3), DAG.getCondCode(ISD::SETEQ));
+ return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
+ N->getValueType(0), N->getOperand(1), N->getOperand(2),
+ N->getOperand(3), DAG.getCondCode(ISD::SETEQ));
break;
+ case Intrinsic::aarch64_sve_fcmpne:
case Intrinsic::aarch64_sve_cmpne:
- if (!N->getOperand(2).getValueType().isFloatingPoint())
- return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
- N->getValueType(0), N->getOperand(1), N->getOperand(2),
- N->getOperand(3), DAG.getCondCode(ISD::SETNE));
+ return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
+ N->getValueType(0), N->getOperand(1), N->getOperand(2),
+ N->getOperand(3), DAG.getCondCode(ISD::SETNE));
+ break;
+ case Intrinsic::aarch64_sve_fcmpuo:
+ return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
+ N->getValueType(0), N->getOperand(1), N->getOperand(2),
+ N->getOperand(3), DAG.getCondCode(ISD::SETUO));
break;
case Intrinsic::aarch64_sve_fadda:
return combineSVEReductionOrderedFP(N, AArch64ISD::FADDA_PRED, DAG);
@@ -13743,8 +14497,8 @@ static SDValue performExtendCombine(SDNode *N,
// helps the backend to decide that an sabdl2 would be useful, saving a real
// extract_high operation.
if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
- (N->getOperand(0).getOpcode() == AArch64ISD::UABD ||
- N->getOperand(0).getOpcode() == AArch64ISD::SABD)) {
+ (N->getOperand(0).getOpcode() == ISD::ABDU ||
+ N->getOperand(0).getOpcode() == ISD::ABDS)) {
SDNode *ABDNode = N->getOperand(0).getNode();
SDValue NewABD =
tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
@@ -13753,78 +14507,7 @@ static SDValue performExtendCombine(SDNode *N,
return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), NewABD);
}
-
- // This is effectively a custom type legalization for AArch64.
- //
- // Type legalization will split an extend of a small, legal, type to a larger
- // illegal type by first splitting the destination type, often creating
- // illegal source types, which then get legalized in isel-confusing ways,
- // leading to really terrible codegen. E.g.,
- // %result = v8i32 sext v8i8 %value
- // becomes
- // %losrc = extract_subreg %value, ...
- // %hisrc = extract_subreg %value, ...
- // %lo = v4i32 sext v4i8 %losrc
- // %hi = v4i32 sext v4i8 %hisrc
- // Things go rapidly downhill from there.
- //
- // For AArch64, the [sz]ext vector instructions can only go up one element
- // size, so we can, e.g., extend from i8 to i16, but to go from i8 to i32
- // take two instructions.
- //
- // This implies that the most efficient way to do the extend from v8i8
- // to two v4i32 values is to first extend the v8i8 to v8i16, then do
- // the normal splitting to happen for the v8i16->v8i32.
-
- // This is pre-legalization to catch some cases where the default
- // type legalization will create ill-tempered code.
- if (!DCI.isBeforeLegalizeOps())
- return SDValue();
-
- // We're only interested in cleaning things up for non-legal vector types
- // here. If both the source and destination are legal, things will just
- // work naturally without any fiddling.
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- EVT ResVT = N->getValueType(0);
- if (!ResVT.isVector() || TLI.isTypeLegal(ResVT))
- return SDValue();
- // If the vector type isn't a simple VT, it's beyond the scope of what
- // we're worried about here. Let legalization do its thing and hope for
- // the best.
- SDValue Src = N->getOperand(0);
- EVT SrcVT = Src->getValueType(0);
- if (!ResVT.isSimple() || !SrcVT.isSimple())
- return SDValue();
-
- // If the source VT is a 64-bit fixed or scalable vector, we can play games
- // and get the better results we want.
- if (SrcVT.getSizeInBits().getKnownMinSize() != 64)
- return SDValue();
-
- unsigned SrcEltSize = SrcVT.getScalarSizeInBits();
- ElementCount SrcEC = SrcVT.getVectorElementCount();
- SrcVT = MVT::getVectorVT(MVT::getIntegerVT(SrcEltSize * 2), SrcEC);
- SDLoc DL(N);
- Src = DAG.getNode(N->getOpcode(), DL, SrcVT, Src);
-
- // Now split the rest of the operation into two halves, each with a 64
- // bit source.
- EVT LoVT, HiVT;
- SDValue Lo, Hi;
- LoVT = HiVT = ResVT.getHalfNumVectorElementsVT(*DAG.getContext());
-
- EVT InNVT = EVT::getVectorVT(*DAG.getContext(), SrcVT.getVectorElementType(),
- LoVT.getVectorElementCount());
- Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InNVT, Src,
- DAG.getConstant(0, DL, MVT::i64));
- Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InNVT, Src,
- DAG.getConstant(InNVT.getVectorMinNumElements(), DL, MVT::i64));
- Lo = DAG.getNode(N->getOpcode(), DL, LoVT, Lo);
- Hi = DAG.getNode(N->getOpcode(), DL, HiVT, Hi);
-
- // Now combine the parts back together so we still have a single result
- // like the combiner expects.
- return DAG.getNode(ISD::CONCAT_VECTORS, DL, ResVT, Lo, Hi);
+ return SDValue();
}
static SDValue splitStoreSplat(SelectionDAG &DAG, StoreSDNode &St,
@@ -14234,6 +14917,16 @@ static SDValue splitStores(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
S->getMemOperand()->getFlags());
}
+static SDValue performSpliceCombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == AArch64ISD::SPLICE && "Unexpected Opcode!");
+
+ // splice(pg, op1, undef) -> op1
+ if (N->getOperand(2).isUndef())
+ return N->getOperand(1);
+
+ return SDValue();
+}
+
static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG) {
SDLoc DL(N);
SDValue Op0 = N->getOperand(0);
@@ -14259,6 +14952,86 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}
+static SDValue performGLD1Combine(SDNode *N, SelectionDAG &DAG) {
+ unsigned Opc = N->getOpcode();
+
+ assert(((Opc >= AArch64ISD::GLD1_MERGE_ZERO && // unsigned gather loads
+ Opc <= AArch64ISD::GLD1_IMM_MERGE_ZERO) ||
+ (Opc >= AArch64ISD::GLD1S_MERGE_ZERO && // signed gather loads
+ Opc <= AArch64ISD::GLD1S_IMM_MERGE_ZERO)) &&
+ "Invalid opcode.");
+
+ const bool Scaled = Opc == AArch64ISD::GLD1_SCALED_MERGE_ZERO ||
+ Opc == AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
+ const bool Signed = Opc == AArch64ISD::GLD1S_MERGE_ZERO ||
+ Opc == AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
+ const bool Extended = Opc == AArch64ISD::GLD1_SXTW_MERGE_ZERO ||
+ Opc == AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO ||
+ Opc == AArch64ISD::GLD1_UXTW_MERGE_ZERO ||
+ Opc == AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO;
+
+ SDLoc DL(N);
+ SDValue Chain = N->getOperand(0);
+ SDValue Pg = N->getOperand(1);
+ SDValue Base = N->getOperand(2);
+ SDValue Offset = N->getOperand(3);
+ SDValue Ty = N->getOperand(4);
+
+ EVT ResVT = N->getValueType(0);
+
+ const auto OffsetOpc = Offset.getOpcode();
+ const bool OffsetIsZExt =
+ OffsetOpc == AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU;
+ const bool OffsetIsSExt =
+ OffsetOpc == AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU;
+
+ // Fold sign/zero extensions of vector offsets into GLD1 nodes where possible.
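+  // For example, a GLD1 whose offset is the result of a predicated
+  // sign-extend from i32 can be replaced by the equivalent GLD1_SXTW node,
+  // dropping the separate extend (and likewise UXTW for zero-extends).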
+ if (!Extended && (OffsetIsSExt || OffsetIsZExt)) {
+ SDValue ExtPg = Offset.getOperand(0);
+ VTSDNode *ExtFrom = cast<VTSDNode>(Offset.getOperand(2).getNode());
+ EVT ExtFromEVT = ExtFrom->getVT().getVectorElementType();
+
+    // If the predicate for the sign- or zero-extended offset is the
+    // same as the predicate used for this load and the sign-/zero-extension
+    // was from a 32-bit value, the extension can be folded into the gather.
+ if (ExtPg == Pg && ExtFromEVT == MVT::i32) {
+ SDValue UnextendedOffset = Offset.getOperand(1);
+
+ unsigned NewOpc = getGatherVecOpcode(Scaled, OffsetIsSExt, true);
+ if (Signed)
+ NewOpc = getSignExtendedGatherOpcode(NewOpc);
+
+ return DAG.getNode(NewOpc, DL, {ResVT, MVT::Other},
+ {Chain, Pg, Base, UnextendedOffset, Ty});
+ }
+ }
+
+ return SDValue();
+}
+
+/// Optimize a vector shift instruction and its operand if shifted out
+/// bits are not used.
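+/// For example (illustrative, scalar notation for a single lane), in
+///   VLSHR (or x, 0xF), 4
+/// the low four bits of the OR are shifted out, so SimplifyDemandedBits can
+/// drop the OR and shift x directly.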
+static SDValue performVectorShiftCombine(SDNode *N,
+ const AArch64TargetLowering &TLI,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ assert(N->getOpcode() == AArch64ISD::VASHR ||
+ N->getOpcode() == AArch64ISD::VLSHR);
+
+ SDValue Op = N->getOperand(0);
+ unsigned OpScalarSize = Op.getScalarValueSizeInBits();
+
+ unsigned ShiftImm = N->getConstantOperandVal(1);
+ assert(OpScalarSize > ShiftImm && "Invalid shift imm");
+
+ APInt ShiftedOutBits = APInt::getLowBitsSet(OpScalarSize, ShiftImm);
+ APInt DemandedMask = ~ShiftedOutBits;
+
+ if (TLI.SimplifyDemandedBits(Op, DemandedMask, DCI))
+ return SDValue(N, 0);
+
+ return SDValue();
+}
+
/// Target-specific DAG combine function for post-increment LD1 (lane) and
/// post-increment LD1R.
static SDValue performPostLD1Combine(SDNode *N,
@@ -14383,6 +15156,29 @@ static bool performTBISimplification(SDValue Addr,
return false;
}
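+// Fold a truncating store of an extended value into a plain store of the
+// original value when the memory type matches the pre-extension type, e.g.
+// (illustrative):
+//   truncstore<nxv2i32> (zext nxv2i32 x to nxv2i64), ptr
+//     --> store nxv2i32 x, ptr
+// The assert below also admits masked stores, but only regular stores are
+// rewritten here.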
+static SDValue foldTruncStoreOfExt(SelectionDAG &DAG, SDNode *N) {
+ assert((N->getOpcode() == ISD::STORE || N->getOpcode() == ISD::MSTORE) &&
+ "Expected STORE dag node in input!");
+
+ if (auto Store = dyn_cast<StoreSDNode>(N)) {
+ if (!Store->isTruncatingStore() || Store->isIndexed())
+ return SDValue();
+ SDValue Ext = Store->getValue();
+ auto ExtOpCode = Ext.getOpcode();
+ if (ExtOpCode != ISD::ZERO_EXTEND && ExtOpCode != ISD::SIGN_EXTEND &&
+ ExtOpCode != ISD::ANY_EXTEND)
+ return SDValue();
+ SDValue Orig = Ext->getOperand(0);
+ if (Store->getMemoryVT() != Orig->getValueType(0))
+ return SDValue();
+ return DAG.getStore(Store->getChain(), SDLoc(Store), Orig,
+ Store->getBasePtr(), Store->getPointerInfo(),
+ Store->getAlign());
+ }
+
+ return SDValue();
+}
+
static SDValue performSTORECombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG,
@@ -14394,54 +15190,8 @@ static SDValue performSTORECombine(SDNode *N,
performTBISimplification(N->getOperand(2), DCI, DAG))
return SDValue(N, 0);
- return SDValue();
-}
-
-static SDValue performMaskedGatherScatterCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI,
- SelectionDAG &DAG) {
- MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
- assert(MGS && "Can only combine gather load or scatter store nodes");
-
- SDLoc DL(MGS);
- SDValue Chain = MGS->getChain();
- SDValue Scale = MGS->getScale();
- SDValue Index = MGS->getIndex();
- SDValue Mask = MGS->getMask();
- SDValue BasePtr = MGS->getBasePtr();
- ISD::MemIndexType IndexType = MGS->getIndexType();
-
- EVT IdxVT = Index.getValueType();
-
- if (DCI.isBeforeLegalize()) {
- // SVE gather/scatter requires indices of i32/i64. Promote anything smaller
- // prior to legalisation so the result can be split if required.
- if ((IdxVT.getVectorElementType() == MVT::i8) ||
- (IdxVT.getVectorElementType() == MVT::i16)) {
- EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
- if (MGS->isIndexSigned())
- Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
- else
- Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
-
- if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
- SDValue PassThru = MGT->getPassThru();
- SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
- return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
- PassThru.getValueType(), DL, Ops,
- MGT->getMemOperand(),
- MGT->getIndexType(), MGT->getExtensionType());
- } else {
- auto *MSC = cast<MaskedScatterSDNode>(MGS);
- SDValue Data = MSC->getValue();
- SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
- return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
- MSC->getMemoryVT(), DL, Ops,
- MSC->getMemOperand(), IndexType,
- MSC->isTruncatingStore());
- }
- }
- }
+ if (SDValue Store = foldTruncStoreOfExt(DAG, N))
+ return Store;
return SDValue();
}
@@ -14902,6 +15652,67 @@ static SDValue performBRCONDCombine(SDNode *N,
return SDValue();
}
+// Optimize CSEL instructions
+static SDValue performCSELCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ SelectionDAG &DAG) {
+ // CSEL x, x, cc -> x
+ if (N->getOperand(0) == N->getOperand(1))
+ return N->getOperand(0);
+
+ return performCONDCombine(N, DCI, DAG, 2, 3);
+}
+
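+// Fold integer compares against a materialised CSEL boolean. For example
+// (illustrative), "setcc (csel 0, 1, cc, X), 1, ne" is true exactly when the
+// CSEL produced 0, i.e. when cc held, so it can be rewritten as
+// "csel 0, 1, !cc, X" and the outer compare removed.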
+static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::SETCC && "Unexpected opcode!");
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+ ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
+
+ // setcc (csel 0, 1, cond, X), 1, ne ==> csel 0, 1, !cond, X
+ if (Cond == ISD::SETNE && isOneConstant(RHS) &&
+ LHS->getOpcode() == AArch64ISD::CSEL &&
+ isNullConstant(LHS->getOperand(0)) && isOneConstant(LHS->getOperand(1)) &&
+ LHS->hasOneUse()) {
+ SDLoc DL(N);
+
+ // Invert CSEL's condition.
+ auto *OpCC = cast<ConstantSDNode>(LHS.getOperand(2));
+ auto OldCond = static_cast<AArch64CC::CondCode>(OpCC->getZExtValue());
+ auto NewCond = getInvertedCondCode(OldCond);
+
+ // csel 0, 1, !cond, X
+ SDValue CSEL =
+ DAG.getNode(AArch64ISD::CSEL, DL, LHS.getValueType(), LHS.getOperand(0),
+ LHS.getOperand(1), DAG.getConstant(NewCond, DL, MVT::i32),
+ LHS.getOperand(3));
+ return DAG.getZExtOrTrunc(CSEL, DL, N->getValueType(0));
+ }
+
+ return SDValue();
+}
+
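+// An SVE SETCC_MERGE_ZERO already yields a canonical true/false value in each
+// active lane, so sign-extending it and re-comparing against zero with SETNE
+// under the same governing predicate reproduces the original result, which is
+// why the fold below can return the inner compare directly.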
+static SDValue performSetccMergeZeroCombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
+ "Unexpected opcode!");
+
+ SDValue Pred = N->getOperand(0);
+ SDValue LHS = N->getOperand(1);
+ SDValue RHS = N->getOperand(2);
+ ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(3))->get();
+
+ // setcc_merge_zero pred (sign_extend (setcc_merge_zero ... pred ...)), 0, ne
+ // => inner setcc_merge_zero
+ if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) &&
+ LHS->getOpcode() == ISD::SIGN_EXTEND &&
+ LHS->getOperand(0)->getValueType(0) == N->getValueType(0) &&
+ LHS->getOperand(0)->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
+ LHS->getOperand(0)->getOperand(0) == Pred)
+ return LHS->getOperand(0);
+
+ return SDValue();
+}
+
// Optimize some simple tbz/tbnz cases. Returns the new operand and bit to test
// as well as whether the test should be inverted. This code is required to
// catch these cases (as opposed to standard dag combines) because
@@ -15014,7 +15825,41 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
SDValue N0 = N->getOperand(0);
EVT CCVT = N0.getValueType();
- if (N0.getOpcode() != ISD::SETCC || CCVT.getVectorNumElements() != 1 ||
+ // Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
+  // into (OR (ASR lhs, N-1), 1), which requires fewer instructions for the
+ // supported types.
+ SDValue SetCC = N->getOperand(0);
+ if (SetCC.getOpcode() == ISD::SETCC &&
+ SetCC.getOperand(2) == DAG.getCondCode(ISD::SETGT)) {
+ SDValue CmpLHS = SetCC.getOperand(0);
+ EVT VT = CmpLHS.getValueType();
+ SDNode *CmpRHS = SetCC.getOperand(1).getNode();
+ SDNode *SplatLHS = N->getOperand(1).getNode();
+ SDNode *SplatRHS = N->getOperand(2).getNode();
+ APInt SplatLHSVal;
+ if (CmpLHS.getValueType() == N->getOperand(1).getValueType() &&
+ VT.isSimple() &&
+ is_contained(
+ makeArrayRef({MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
+ MVT::v2i32, MVT::v4i32, MVT::v2i64}),
+ VT.getSimpleVT().SimpleTy) &&
+ ISD::isConstantSplatVector(SplatLHS, SplatLHSVal) &&
+ SplatLHSVal.isOneValue() && ISD::isConstantSplatVectorAllOnes(CmpRHS) &&
+ ISD::isConstantSplatVectorAllOnes(SplatRHS)) {
+ unsigned NumElts = VT.getVectorNumElements();
+ SmallVector<SDValue, 8> Ops(
+ NumElts, DAG.getConstant(VT.getScalarSizeInBits() - 1, SDLoc(N),
+ VT.getScalarType()));
+ SDValue Val = DAG.getBuildVector(VT, SDLoc(N), Ops);
+
+ auto Shift = DAG.getNode(ISD::SRA, SDLoc(N), VT, CmpLHS, Val);
+ auto Or = DAG.getNode(ISD::OR, SDLoc(N), VT, Shift, N->getOperand(1));
+ return Or;
+ }
+ }
+
+ if (N0.getOpcode() != ISD::SETCC ||
+ CCVT.getVectorElementCount() != ElementCount::getFixed(1) ||
CCVT.getVectorElementType() != MVT::i1)
return SDValue();
@@ -15027,10 +15872,9 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
SDValue IfTrue = N->getOperand(1);
SDValue IfFalse = N->getOperand(2);
- SDValue SetCC =
- DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(),
- N0.getOperand(0), N0.getOperand(1),
- cast<CondCodeSDNode>(N0.getOperand(2))->get());
+ SetCC = DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(),
+ N0.getOperand(0), N0.getOperand(1),
+ cast<CondCodeSDNode>(N0.getOperand(2))->get());
return DAG.getNode(ISD::VSELECT, SDLoc(N), ResVT, SetCC,
IfTrue, IfFalse);
}
@@ -15048,6 +15892,9 @@ static SDValue performSelectCombine(SDNode *N,
if (N0.getOpcode() != ISD::SETCC)
return SDValue();
+ if (ResVT.isScalableVector())
+ return SDValue();
+
// Make sure the SETCC result is either i1 (initial DAG), or i32, the lowered
// scalar SetCCResultType. We also don't expect vectors, because we assume
// that selects fed by vector SETCCs are canonicalized to VSELECT.
@@ -15180,7 +16027,6 @@ static SDValue getScaledOffsetForBitWidth(SelectionDAG &DAG, SDValue Offset,
/// [<Zn>.[S|D]{, #<imm>}]
///
/// where <imm> = sizeof(<T>) * k, for k = 0, 1, ..., 31.
-
inline static bool isValidImmForSVEVecImmAddrMode(unsigned OffsetInBytes,
unsigned ScalarSizeInBytes) {
// The immediate is not a multiple of the scalar size.
@@ -15588,6 +16434,97 @@ static SDValue combineSVEPrefetchVecBaseImmOff(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(N->getOpcode(), DL, DAG.getVTList(MVT::Other), Ops);
}
+// Return true if the vector operation can guarantee only the first lane of its
+// result contains data, with all bits in other lanes set to zero.
+static bool isLanes1toNKnownZero(SDValue Op) {
+ switch (Op.getOpcode()) {
+ default:
+ return false;
+ case AArch64ISD::ANDV_PRED:
+ case AArch64ISD::EORV_PRED:
+ case AArch64ISD::FADDA_PRED:
+ case AArch64ISD::FADDV_PRED:
+ case AArch64ISD::FMAXNMV_PRED:
+ case AArch64ISD::FMAXV_PRED:
+ case AArch64ISD::FMINNMV_PRED:
+ case AArch64ISD::FMINV_PRED:
+ case AArch64ISD::ORV_PRED:
+ case AArch64ISD::SADDV_PRED:
+ case AArch64ISD::SMAXV_PRED:
+ case AArch64ISD::SMINV_PRED:
+ case AArch64ISD::UADDV_PRED:
+ case AArch64ISD::UMAXV_PRED:
+ case AArch64ISD::UMINV_PRED:
+ return true;
+ }
+}
+
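+// Recognise explicit re-zeroing of lanes 1..N of a reduction result, e.g.
+// (illustrative):
+//   insert_vector_elt (zero vec), (extract_vector_elt (UADDV_PRED pg, z), 0), 0
+// The predicated reduction already leaves every lane but the first zeroed, so
+// the insert/extract pair is redundant and the reduction can be used directly.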
+static SDValue removeRedundantInsertVectorElt(SDNode *N) {
+ assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT && "Unexpected node!");
+ SDValue InsertVec = N->getOperand(0);
+ SDValue InsertElt = N->getOperand(1);
+ SDValue InsertIdx = N->getOperand(2);
+
+ // We only care about inserts into the first element...
+ if (!isNullConstant(InsertIdx))
+ return SDValue();
+ // ...of a zero'd vector...
+ if (!ISD::isConstantSplatVectorAllZeros(InsertVec.getNode()))
+ return SDValue();
+ // ...where the inserted data was previously extracted...
+ if (InsertElt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+ return SDValue();
+
+ SDValue ExtractVec = InsertElt.getOperand(0);
+ SDValue ExtractIdx = InsertElt.getOperand(1);
+
+ // ...from the first element of a vector.
+ if (!isNullConstant(ExtractIdx))
+ return SDValue();
+
+ // If we get here we are effectively trying to zero lanes 1-N of a vector.
+
+ // Ensure there's no type conversion going on.
+ if (N->getValueType(0) != ExtractVec.getValueType())
+ return SDValue();
+
+ if (!isLanes1toNKnownZero(ExtractVec))
+ return SDValue();
+
+ // The explicit zeroing is redundant.
+ return ExtractVec;
+}
+
+static SDValue
+performInsertVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+ if (SDValue Res = removeRedundantInsertVectorElt(N))
+ return Res;
+
+ return performPostLD1Combine(N, DCI, true);
+}
+
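+// Perform a floating-point VECTOR_SPLICE on the corresponding packed integer
+// type instead, e.g. (illustrative):
+//   nxv2f32 splice --> bitcast to nxv2i32, any-extend to nxv2i64, splice,
+//   then truncate and bitcast back to nxv2f32.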
+static SDValue performSVESpliceCombine(SDNode *N, SelectionDAG &DAG) {
+ EVT Ty = N->getValueType(0);
+ if (Ty.isInteger())
+ return SDValue();
+
+ EVT IntTy = Ty.changeVectorElementTypeToInteger();
+ EVT ExtIntTy = getPackedSVEVectorVT(IntTy.getVectorElementCount());
+ if (ExtIntTy.getVectorElementType().getScalarSizeInBits() <
+ IntTy.getVectorElementType().getScalarSizeInBits())
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue LHS = DAG.getAnyExtOrTrunc(DAG.getBitcast(IntTy, N->getOperand(0)),
+ DL, ExtIntTy);
+ SDValue RHS = DAG.getAnyExtOrTrunc(DAG.getBitcast(IntTy, N->getOperand(1)),
+ DL, ExtIntTy);
+ SDValue Idx = N->getOperand(2);
+ SDValue Splice = DAG.getNode(ISD::VECTOR_SPLICE, DL, ExtIntTy, LHS, RHS, Idx);
+ SDValue Trunc = DAG.getAnyExtOrTrunc(Splice, DL, IntTy);
+ return DAG.getBitcast(Ty, Trunc);
+}
+
SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
@@ -15595,8 +16532,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
default:
LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
break;
- case ISD::ABS:
- return performABSCombine(N, DAG, DCI, Subtarget);
case ISD::ADD:
case ISD::SUB:
return performAddSubCombine(N, DCI, DAG);
@@ -15634,30 +16569,53 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performSelectCombine(N, DCI);
case ISD::VSELECT:
return performVSelectCombine(N, DCI.DAG);
+ case ISD::SETCC:
+ return performSETCCCombine(N, DAG);
case ISD::LOAD:
if (performTBISimplification(N->getOperand(1), DCI, DAG))
return SDValue(N, 0);
break;
case ISD::STORE:
return performSTORECombine(N, DCI, DAG, Subtarget);
- case ISD::MGATHER:
- case ISD::MSCATTER:
- return performMaskedGatherScatterCombine(N, DCI, DAG);
+ case ISD::VECTOR_SPLICE:
+ return performSVESpliceCombine(N, DAG);
case AArch64ISD::BRCOND:
return performBRCONDCombine(N, DCI, DAG);
case AArch64ISD::TBNZ:
case AArch64ISD::TBZ:
return performTBZCombine(N, DCI, DAG);
case AArch64ISD::CSEL:
- return performCONDCombine(N, DCI, DAG, 2, 3);
+ return performCSELCombine(N, DCI, DAG);
case AArch64ISD::DUP:
return performPostLD1Combine(N, DCI, false);
case AArch64ISD::NVCAST:
return performNVCASTCombine(N);
+ case AArch64ISD::SPLICE:
+ return performSpliceCombine(N, DAG);
case AArch64ISD::UZP1:
return performUzpCombine(N, DAG);
+ case AArch64ISD::SETCC_MERGE_ZERO:
+ return performSetccMergeZeroCombine(N, DAG);
+ case AArch64ISD::GLD1_MERGE_ZERO:
+ case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
+ case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
+ case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
+ case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
+ case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
+ case AArch64ISD::GLD1_IMM_MERGE_ZERO:
+ case AArch64ISD::GLD1S_MERGE_ZERO:
+ case AArch64ISD::GLD1S_SCALED_MERGE_ZERO:
+ case AArch64ISD::GLD1S_UXTW_MERGE_ZERO:
+ case AArch64ISD::GLD1S_SXTW_MERGE_ZERO:
+ case AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO:
+ case AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO:
+ case AArch64ISD::GLD1S_IMM_MERGE_ZERO:
+ return performGLD1Combine(N, DAG);
+ case AArch64ISD::VASHR:
+ case AArch64ISD::VLSHR:
+ return performVectorShiftCombine(N, *this, DCI);
case ISD::INSERT_VECTOR_ELT:
- return performPostLD1Combine(N, DCI, true);
+ return performInsertVectorEltCombine(N, DCI);
case ISD::EXTRACT_VECTOR_ELT:
return performExtractVectorEltCombine(N, DAG);
case ISD::VECREDUCE_ADD:
@@ -15881,6 +16839,24 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
LowerSVEStructLoad(IntrinsicID, LoadOps, N->getValueType(0), DAG, DL);
return DAG.getMergeValues({Result, Chain}, DL);
}
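+  // For aarch64_rndr/aarch64_rndrrs below: an MRS read of the RNDR/RNDRRS
+  // system register yields the random value and sets NZCV, and the CSINC
+  // materialises that NZCV status as the intrinsic's i1 result.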
+ case Intrinsic::aarch64_rndr:
+ case Intrinsic::aarch64_rndrrs: {
+ unsigned IntrinsicID =
+ cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+ auto Register =
+ (IntrinsicID == Intrinsic::aarch64_rndr ? AArch64SysReg::RNDR
+ : AArch64SysReg::RNDRRS);
+ SDLoc DL(N);
+ SDValue A = DAG.getNode(
+ AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other),
+ N->getOperand(0), DAG.getConstant(Register, DL, MVT::i64));
+ SDValue B = DAG.getNode(
+ AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32),
+ DAG.getConstant(0, DL, MVT::i32),
+ DAG.getConstant(AArch64CC::NE, DL, MVT::i32), A.getValue(1));
+ return DAG.getMergeValues(
+ {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
+ }
default:
break;
}
@@ -16007,13 +16983,22 @@ bool AArch64TargetLowering::getPostIndexedAddressParts(
return true;
}
-static void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
- SelectionDAG &DAG) {
+void AArch64TargetLowering::ReplaceBITCASTResults(
+ SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
SDLoc DL(N);
SDValue Op = N->getOperand(0);
+ EVT VT = N->getValueType(0);
+ EVT SrcVT = Op.getValueType();
+
+ if (VT.isScalableVector() && !isTypeLegal(VT) && isTypeLegal(SrcVT)) {
+ assert(!VT.isFloatingPoint() && SrcVT.isFloatingPoint() &&
+ "Expected fp->int bitcast!");
+ SDValue CastResult = getSVESafeBitCast(getSVEContainerType(VT), Op, DAG);
+ Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, CastResult));
+ return;
+ }
- if (N->getValueType(0) != MVT::i16 ||
- (Op.getValueType() != MVT::f16 && Op.getValueType() != MVT::bf16))
+ if (VT != MVT::i16 || (SrcVT != MVT::f16 && SrcVT != MVT::bf16))
return;
Op = SDValue(
@@ -16107,6 +17092,7 @@ static void ReplaceCMP_SWAP_128Results(SDNode *N,
assert(N->getValueType(0) == MVT::i128 &&
"AtomicCmpSwap on types less than 128 should be legal");
+ MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand();
if (Subtarget->hasLSE() || Subtarget->outlineAtomics()) {
// LSE has a 128-bit compare and swap (CASP), but i128 is not a legal type,
// so lower it here, wrapped in REG_SEQUENCE and EXTRACT_SUBREG.
@@ -16117,10 +17103,8 @@ static void ReplaceCMP_SWAP_128Results(SDNode *N,
N->getOperand(0), // Chain in
};
- MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand();
-
unsigned Opcode;
- switch (MemOp->getOrdering()) {
+ switch (MemOp->getMergedOrdering()) {
case AtomicOrdering::Monotonic:
Opcode = AArch64::CASPX;
break;
@@ -16155,15 +17139,32 @@ static void ReplaceCMP_SWAP_128Results(SDNode *N,
return;
}
+ unsigned Opcode;
+ switch (MemOp->getMergedOrdering()) {
+ case AtomicOrdering::Monotonic:
+ Opcode = AArch64::CMP_SWAP_128_MONOTONIC;
+ break;
+ case AtomicOrdering::Acquire:
+ Opcode = AArch64::CMP_SWAP_128_ACQUIRE;
+ break;
+ case AtomicOrdering::Release:
+ Opcode = AArch64::CMP_SWAP_128_RELEASE;
+ break;
+ case AtomicOrdering::AcquireRelease:
+ case AtomicOrdering::SequentiallyConsistent:
+ Opcode = AArch64::CMP_SWAP_128;
+ break;
+ default:
+ llvm_unreachable("Unexpected ordering!");
+ }
+
auto Desired = splitInt128(N->getOperand(2), DAG);
auto New = splitInt128(N->getOperand(3), DAG);
SDValue Ops[] = {N->getOperand(1), Desired.first, Desired.second,
New.first, New.second, N->getOperand(0)};
SDNode *CmpSwap = DAG.getMachineNode(
- AArch64::CMP_SWAP_128, SDLoc(N),
- DAG.getVTList(MVT::i64, MVT::i64, MVT::i32, MVT::Other), Ops);
-
- MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand();
+ Opcode, SDLoc(N), DAG.getVTList(MVT::i64, MVT::i64, MVT::i32, MVT::Other),
+ Ops);
DAG.setNodeMemRefs(cast<MachineSDNode>(CmpSwap), {MemOp});
Results.push_back(DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128,
@@ -16241,6 +17242,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
case ISD::EXTRACT_SUBVECTOR:
ReplaceExtractSubVectorResults(N, Results, DAG);
return;
+ case ISD::INSERT_SUBVECTOR:
+ // Custom lowering has been requested for INSERT_SUBVECTOR -- but delegate
+ // to common code for result type legalisation
+ return;
case ISD::INTRINSIC_WO_CHAIN: {
EVT VT = N->getValueType(0);
assert((VT == MVT::i8 || VT == MVT::i16) &&
@@ -16334,25 +17339,36 @@ AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
unsigned Size = AI->getType()->getPrimitiveSizeInBits();
if (Size > 128) return AtomicExpansionKind::None;
- // Nand not supported in LSE.
- if (AI->getOperation() == AtomicRMWInst::Nand) return AtomicExpansionKind::LLSC;
- // Leave 128 bits to LLSC.
- if (Subtarget->hasLSE() && Size < 128)
- return AtomicExpansionKind::None;
- if (Subtarget->outlineAtomics() && Size < 128) {
- // [U]Min/[U]Max RWM atomics are used in __sync_fetch_ libcalls so far.
- // Don't outline them unless
- // (1) high level <atomic> support approved:
- // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf
- // (2) low level libgcc and compiler-rt support implemented by:
- // min/max outline atomics helpers
- if (AI->getOperation() != AtomicRMWInst::Min &&
- AI->getOperation() != AtomicRMWInst::Max &&
- AI->getOperation() != AtomicRMWInst::UMin &&
- AI->getOperation() != AtomicRMWInst::UMax) {
+
+ // Nand is not supported in LSE.
+ // Leave 128 bits to LLSC or CmpXChg.
+ if (AI->getOperation() != AtomicRMWInst::Nand && Size < 128) {
+ if (Subtarget->hasLSE())
return AtomicExpansionKind::None;
+ if (Subtarget->outlineAtomics()) {
+      // [U]Min/[U]Max RMW atomics are used in __sync_fetch_ libcalls so far.
+ // Don't outline them unless
+ // (1) high level <atomic> support approved:
+ // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf
+ // (2) low level libgcc and compiler-rt support implemented by:
+ // min/max outline atomics helpers
+ if (AI->getOperation() != AtomicRMWInst::Min &&
+ AI->getOperation() != AtomicRMWInst::Max &&
+ AI->getOperation() != AtomicRMWInst::UMin &&
+ AI->getOperation() != AtomicRMWInst::UMax) {
+ return AtomicExpansionKind::None;
+ }
}
}
+
+ // At -O0, fast-regalloc cannot cope with the live vregs necessary to
+ // implement atomicrmw without spilling. If the target address is also on the
+ // stack and close enough to the spill slot, this can lead to a situation
+ // where the monitor always gets cleared and the atomic operation can never
+ // succeed. So at -O0 lower this operation to a CAS loop.
+ if (getTargetMachine().getOptLevel() == CodeGenOpt::None)
+ return AtomicExpansionKind::CmpXChg;
+
return AtomicExpansionKind::LLSC;
}
@@ -16369,19 +17385,26 @@ AArch64TargetLowering::shouldExpandAtomicCmpXchgInIR(
// can never succeed. So at -O0 we need a late-expanded pseudo-inst instead.
if (getTargetMachine().getOptLevel() == CodeGenOpt::None)
return AtomicExpansionKind::None;
+
+ // 128-bit atomic cmpxchg is weird; AtomicExpand doesn't know how to expand
+ // it.
+ unsigned Size = AI->getCompareOperand()->getType()->getPrimitiveSizeInBits();
+ if (Size > 64)
+ return AtomicExpansionKind::None;
+
return AtomicExpansionKind::LLSC;
}
-Value *AArch64TargetLowering::emitLoadLinked(IRBuilder<> &Builder, Value *Addr,
+Value *AArch64TargetLowering::emitLoadLinked(IRBuilderBase &Builder,
+ Type *ValueTy, Value *Addr,
AtomicOrdering Ord) const {
Module *M = Builder.GetInsertBlock()->getParent()->getParent();
- Type *ValTy = cast<PointerType>(Addr->getType())->getElementType();
bool IsAcquire = isAcquireOrStronger(Ord);
// Since i128 isn't legal and intrinsics don't get type-lowered, the ldrexd
// intrinsic must return {i64, i64} and we have to recombine them into a
// single i128 here.
- if (ValTy->getPrimitiveSizeInBits() == 128) {
+ if (ValueTy->getPrimitiveSizeInBits() == 128) {
Intrinsic::ID Int =
IsAcquire ? Intrinsic::aarch64_ldaxp : Intrinsic::aarch64_ldxp;
Function *Ldxr = Intrinsic::getDeclaration(M, Int);
@@ -16391,10 +17414,10 @@ Value *AArch64TargetLowering::emitLoadLinked(IRBuilder<> &Builder, Value *Addr,
Value *Lo = Builder.CreateExtractValue(LoHi, 0, "lo");
Value *Hi = Builder.CreateExtractValue(LoHi, 1, "hi");
- Lo = Builder.CreateZExt(Lo, ValTy, "lo64");
- Hi = Builder.CreateZExt(Hi, ValTy, "hi64");
+ Lo = Builder.CreateZExt(Lo, ValueTy, "lo64");
+ Hi = Builder.CreateZExt(Hi, ValueTy, "hi64");
return Builder.CreateOr(
- Lo, Builder.CreateShl(Hi, ConstantInt::get(ValTy, 64)), "val64");
+ Lo, Builder.CreateShl(Hi, ConstantInt::get(ValueTy, 64)), "val64");
}
Type *Tys[] = { Addr->getType() };
@@ -16402,22 +17425,20 @@ Value *AArch64TargetLowering::emitLoadLinked(IRBuilder<> &Builder, Value *Addr,
IsAcquire ? Intrinsic::aarch64_ldaxr : Intrinsic::aarch64_ldxr;
Function *Ldxr = Intrinsic::getDeclaration(M, Int, Tys);
- Type *EltTy = cast<PointerType>(Addr->getType())->getElementType();
-
const DataLayout &DL = M->getDataLayout();
- IntegerType *IntEltTy = Builder.getIntNTy(DL.getTypeSizeInBits(EltTy));
+ IntegerType *IntEltTy = Builder.getIntNTy(DL.getTypeSizeInBits(ValueTy));
Value *Trunc = Builder.CreateTrunc(Builder.CreateCall(Ldxr, Addr), IntEltTy);
- return Builder.CreateBitCast(Trunc, EltTy);
+ return Builder.CreateBitCast(Trunc, ValueTy);
}
void AArch64TargetLowering::emitAtomicCmpXchgNoStoreLLBalance(
- IRBuilder<> &Builder) const {
+ IRBuilderBase &Builder) const {
Module *M = Builder.GetInsertBlock()->getParent()->getParent();
Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::aarch64_clrex));
}
-Value *AArch64TargetLowering::emitStoreConditional(IRBuilder<> &Builder,
+Value *AArch64TargetLowering::emitStoreConditional(IRBuilderBase &Builder,
Value *Val, Value *Addr,
AtomicOrdering Ord) const {
Module *M = Builder.GetInsertBlock()->getParent()->getParent();
@@ -16454,15 +17475,17 @@ Value *AArch64TargetLowering::emitStoreConditional(IRBuilder<> &Builder,
}
bool AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters(
- Type *Ty, CallingConv::ID CallConv, bool isVarArg) const {
- if (Ty->isArrayTy())
- return true;
-
- const TypeSize &TySize = Ty->getPrimitiveSizeInBits();
- if (TySize.isScalable() && TySize.getKnownMinSize() > 128)
- return true;
+ Type *Ty, CallingConv::ID CallConv, bool isVarArg,
+ const DataLayout &DL) const {
+ if (!Ty->isArrayTy()) {
+ const TypeSize &TySize = Ty->getPrimitiveSizeInBits();
+ return TySize.isScalable() && TySize.getKnownMinSize() > 128;
+ }
- return false;
+  // All non-aggregate members of the type must have the same type.
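+  // (Illustratively, [4 x float] lowers to four f32 values and qualifies,
+  // whereas an array of structs with differently-typed fields, such as
+  // [2 x {float, i64}], does not.)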
+ SmallVector<EVT> ValueVTs;
+ ComputeValueVTs(*this, DL, Ty, ValueVTs);
+ return is_splat(ValueVTs);
}
bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &,
@@ -16470,7 +17493,7 @@ bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &,
return false;
}
-static Value *UseTlsOffset(IRBuilder<> &IRB, unsigned Offset) {
+static Value *UseTlsOffset(IRBuilderBase &IRB, unsigned Offset) {
Module *M = IRB.GetInsertBlock()->getParent()->getParent();
Function *ThreadPointerFunc =
Intrinsic::getDeclaration(M, Intrinsic::thread_pointer);
@@ -16480,7 +17503,7 @@ static Value *UseTlsOffset(IRBuilder<> &IRB, unsigned Offset) {
IRB.getInt8PtrTy()->getPointerTo(0));
}
-Value *AArch64TargetLowering::getIRStackGuard(IRBuilder<> &IRB) const {
+Value *AArch64TargetLowering::getIRStackGuard(IRBuilderBase &IRB) const {
// Android provides a fixed TLS slot for the stack cookie. See the definition
// of TLS_SLOT_STACK_GUARD in
// https://android.googlesource.com/platform/bionic/+/master/libc/private/bionic_tls.h
@@ -16529,7 +17552,8 @@ Function *AArch64TargetLowering::getSSPStackGuardCheck(const Module &M) const {
return TargetLowering::getSSPStackGuardCheck(M);
}
-Value *AArch64TargetLowering::getSafeStackPointerLocation(IRBuilder<> &IRB) const {
+Value *
+AArch64TargetLowering::getSafeStackPointerLocation(IRBuilderBase &IRB) const {
// Android provides a fixed TLS slot for the SafeStack pointer. See the
// definition of TLS_SLOT_SAFESTACK in
// https://android.googlesource.com/platform/bionic/+/master/libc/private/bionic_tls.h
@@ -16854,6 +17878,66 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorLoadToSVE(
return DAG.getMergeValues(MergedValues, DL);
}
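+// Convert a fixed-length vector mask into an SVE predicate: move the mask into
+// its scalable container and test each lane against zero under the natural
+// predicate for the fixed-length type, i.e. a SETCC_MERGE_ZERO with SETNE.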
+static SDValue convertFixedMaskToScalableVector(SDValue Mask,
+ SelectionDAG &DAG) {
+ SDLoc DL(Mask);
+ EVT InVT = Mask.getValueType();
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
+
+ auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
+ auto Op2 = DAG.getConstant(0, DL, ContainerVT);
+ auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
+
+ EVT CmpVT = Pg.getValueType();
+ return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, CmpVT,
+ {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)});
+}
+
+// Convert all fixed length vector loads larger than NEON to masked_loads.
+SDValue AArch64TargetLowering::LowerFixedLengthVectorMLoadToSVE(
+ SDValue Op, SelectionDAG &DAG) const {
+ auto Load = cast<MaskedLoadSDNode>(Op);
+
+ if (Load->getExtensionType() != ISD::LoadExtType::NON_EXTLOAD)
+ return SDValue();
+
+ SDLoc DL(Op);
+ EVT VT = Op.getValueType();
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+
+ SDValue Mask = convertFixedMaskToScalableVector(Load->getMask(), DAG);
+
+ SDValue PassThru;
+ bool IsPassThruZeroOrUndef = false;
+
+ if (Load->getPassThru()->isUndef()) {
+ PassThru = DAG.getUNDEF(ContainerVT);
+ IsPassThruZeroOrUndef = true;
+ } else {
+ if (ContainerVT.isInteger())
+ PassThru = DAG.getConstant(0, DL, ContainerVT);
+ else
+ PassThru = DAG.getConstantFP(0, DL, ContainerVT);
+ if (isZerosVector(Load->getPassThru().getNode()))
+ IsPassThruZeroOrUndef = true;
+ }
+
+ auto NewLoad = DAG.getMaskedLoad(
+ ContainerVT, DL, Load->getChain(), Load->getBasePtr(), Load->getOffset(),
+ Mask, PassThru, Load->getMemoryVT(), Load->getMemOperand(),
+ Load->getAddressingMode(), Load->getExtensionType());
+
+ if (!IsPassThruZeroOrUndef) {
+ SDValue OldPassThru =
+ convertToScalableVector(DAG, ContainerVT, Load->getPassThru());
+ NewLoad = DAG.getSelect(DL, ContainerVT, Mask, NewLoad, OldPassThru);
+ }
+
+ auto Result = convertFromScalableVector(DAG, VT, NewLoad);
+ SDValue MergedValues[2] = {Result, Load->getChain()};
+ return DAG.getMergeValues(MergedValues, DL);
+}
+
// Convert all fixed length vector stores larger than NEON to masked_stores.
SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
SDValue Op, SelectionDAG &DAG) const {
@@ -16871,6 +17955,26 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
Store->isTruncatingStore());
}
+SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
+ SDValue Op, SelectionDAG &DAG) const {
+ auto Store = cast<MaskedStoreSDNode>(Op);
+
+ if (Store->isTruncatingStore())
+ return SDValue();
+
+ SDLoc DL(Op);
+ EVT VT = Store->getValue().getValueType();
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+
+ auto NewValue = convertToScalableVector(DAG, ContainerVT, Store->getValue());
+ SDValue Mask = convertFixedMaskToScalableVector(Store->getMask(), DAG);
+
+ return DAG.getMaskedStore(
+ Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
+ Mask, Store->getMemoryVT(), Store->getMemOperand(),
+ Store->getAddressingMode(), Store->isTruncatingStore());
+}
+
SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
SDValue Op, SelectionDAG &DAG) const {
SDLoc dl(Op);
@@ -16890,6 +17994,16 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
EVT FixedWidenedVT = HalfVT.widenIntegerVectorElementType(*DAG.getContext());
EVT ScalableWidenedVT = getContainerForFixedLengthVector(DAG, FixedWidenedVT);
+ // If this is not a full vector, extend, div, and truncate it.
+ EVT WidenedVT = VT.widenIntegerVectorElementType(*DAG.getContext());
+ if (DAG.getTargetLoweringInfo().isTypeLegal(WidenedVT)) {
+ unsigned ExtendOpcode = Signed ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+ SDValue Op0 = DAG.getNode(ExtendOpcode, dl, WidenedVT, Op.getOperand(0));
+ SDValue Op1 = DAG.getNode(ExtendOpcode, dl, WidenedVT, Op.getOperand(1));
+ SDValue Div = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0, Op1);
+ return DAG.getNode(ISD::TRUNCATE, dl, VT, Div);
+ }
+
// Convert the operands to scalable vectors.
SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1));
@@ -16993,6 +18107,35 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorTruncateToSVE(
return convertFromScalableVector(DAG, VT, Val);
}
+SDValue AArch64TargetLowering::LowerFixedLengthExtractVectorElt(
+ SDValue Op, SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+ EVT InVT = Op.getOperand(0).getValueType();
+ assert(InVT.isFixedLengthVector() && "Expected fixed length vector type!");
+
+ SDLoc DL(Op);
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
+ SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0));
+
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Op.getOperand(1));
+}
+
+SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt(
+ SDValue Op, SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+ assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
+
+ SDLoc DL(Op);
+ EVT InVT = Op.getOperand(0).getValueType();
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
+ SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0));
+
+ auto ScalableRes = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT, Op0,
+ Op.getOperand(1), Op.getOperand(2));
+
+ return convertFromScalableVector(DAG, VT, ScalableRes);
+}
+
// Convert vector operation 'Op' to an equivalent predicated operation whereby
// the original operation's type is used to construct a suitable predicate.
// NOTE: The results for inactive lanes are undefined.
@@ -17209,10 +18352,6 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
assert(Op.getValueType() == InVT.changeTypeToInteger() &&
"Expected integer result of the same bit length as the inputs!");
- // Expand floating point vector comparisons.
- if (InVT.isFloatingPoint())
- return SDValue();
-
auto Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
auto Op2 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1));
auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
@@ -17226,6 +18365,229 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
return convertFromScalableVector(DAG, Op.getValueType(), Promote);
}
+SDValue
+AArch64TargetLowering::LowerFixedLengthBitcastToSVE(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ auto SrcOp = Op.getOperand(0);
+ EVT VT = Op.getValueType();
+ EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT);
+ EVT ContainerSrcVT =
+ getContainerForFixedLengthVector(DAG, SrcOp.getValueType());
+
+ SrcOp = convertToScalableVector(DAG, ContainerSrcVT, SrcOp);
+ Op = DAG.getNode(ISD::BITCAST, DL, ContainerDstVT, SrcOp);
+ return convertFromScalableVector(DAG, VT, Op);
+}
+
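+// Lower CONCAT_VECTORS of fixed-length vectors via AArch64ISD::SPLICE: with a
+// predicate that is active for exactly one source operand's worth of lanes,
+// splice(pg, a, b) produces the elements of a followed by those of b
+// (illustrative summary; more than two operands are first concatenated
+// pairwise).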
+SDValue AArch64TargetLowering::LowerFixedLengthConcatVectorsToSVE(
+ SDValue Op, SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ unsigned NumOperands = Op->getNumOperands();
+
+ assert(NumOperands > 1 && isPowerOf2_32(NumOperands) &&
+ "Unexpected number of operands in CONCAT_VECTORS");
+
+ auto SrcOp1 = Op.getOperand(0);
+ auto SrcOp2 = Op.getOperand(1);
+ EVT VT = Op.getValueType();
+ EVT SrcVT = SrcOp1.getValueType();
+
+ if (NumOperands > 2) {
+ SmallVector<SDValue, 4> Ops;
+ EVT PairVT = SrcVT.getDoubleNumVectorElementsVT(*DAG.getContext());
+ for (unsigned I = 0; I < NumOperands; I += 2)
+ Ops.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, PairVT,
+ Op->getOperand(I), Op->getOperand(I + 1)));
+
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Ops);
+ }
+
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+
+ SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, SrcVT);
+ SrcOp1 = convertToScalableVector(DAG, ContainerVT, SrcOp1);
+ SrcOp2 = convertToScalableVector(DAG, ContainerVT, SrcOp2);
+
+ Op = DAG.getNode(AArch64ISD::SPLICE, DL, ContainerVT, Pg, SrcOp1, SrcOp2);
+
+ return convertFromScalableVector(DAG, VT, Op);
+}
+
+SDValue
+AArch64TargetLowering::LowerFixedLengthFPExtendToSVE(SDValue Op,
+ SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+ assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
+
+ SDLoc DL(Op);
+ SDValue Val = Op.getOperand(0);
+ SDValue Pg = getPredicateForVector(DAG, DL, VT);
+ EVT SrcVT = Val.getValueType();
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+ EVT ExtendVT = ContainerVT.changeVectorElementType(
+ SrcVT.getVectorElementType());
+
+ Val = DAG.getNode(ISD::BITCAST, DL, SrcVT.changeTypeToInteger(), Val);
+ Val = DAG.getNode(ISD::ANY_EXTEND, DL, VT.changeTypeToInteger(), Val);
+
+ Val = convertToScalableVector(DAG, ContainerVT.changeTypeToInteger(), Val);
+ Val = getSVESafeBitCast(ExtendVT, Val, DAG);
+ Val = DAG.getNode(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU, DL, ContainerVT,
+ Pg, Val, DAG.getUNDEF(ContainerVT));
+
+ return convertFromScalableVector(DAG, VT, Val);
+}
+
+SDValue
+AArch64TargetLowering::LowerFixedLengthFPRoundToSVE(SDValue Op,
+ SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+ assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
+
+ SDLoc DL(Op);
+ SDValue Val = Op.getOperand(0);
+ EVT SrcVT = Val.getValueType();
+ EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT);
+ EVT RoundVT = ContainerSrcVT.changeVectorElementType(
+ VT.getVectorElementType());
+ SDValue Pg = getPredicateForVector(DAG, DL, RoundVT);
+
+ Val = convertToScalableVector(DAG, ContainerSrcVT, Val);
+ Val = DAG.getNode(AArch64ISD::FP_ROUND_MERGE_PASSTHRU, DL, RoundVT, Pg, Val,
+ Op.getOperand(1), DAG.getUNDEF(RoundVT));
+ Val = getSVESafeBitCast(ContainerSrcVT.changeTypeToInteger(), Val, DAG);
+ Val = convertFromScalableVector(DAG, SrcVT.changeTypeToInteger(), Val);
+
+ Val = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Val);
+ return DAG.getNode(ISD::BITCAST, DL, VT, Val);
+}
+
+SDValue
+AArch64TargetLowering::LowerFixedLengthIntToFPToSVE(SDValue Op,
+ SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+ assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
+
+ bool IsSigned = Op.getOpcode() == ISD::SINT_TO_FP;
+ unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
+ : AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
+
+ SDLoc DL(Op);
+ SDValue Val = Op.getOperand(0);
+ EVT SrcVT = Val.getValueType();
+ EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT);
+ EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT);
+
+ if (ContainerSrcVT.getVectorElementType().getSizeInBits() <=
+ ContainerDstVT.getVectorElementType().getSizeInBits()) {
+ SDValue Pg = getPredicateForVector(DAG, DL, VT);
+
+ Val = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
+ VT.changeTypeToInteger(), Val);
+
+ Val = convertToScalableVector(DAG, ContainerSrcVT, Val);
+ Val = getSVESafeBitCast(ContainerDstVT.changeTypeToInteger(), Val, DAG);
+ // Safe to use a larger than specified operand since we just unpacked the
+ // data, hence the upper bits are zero.
+ Val = DAG.getNode(Opcode, DL, ContainerDstVT, Pg, Val,
+ DAG.getUNDEF(ContainerDstVT));
+ return convertFromScalableVector(DAG, VT, Val);
+ } else {
+ EVT CvtVT = ContainerSrcVT.changeVectorElementType(
+ ContainerDstVT.getVectorElementType());
+ SDValue Pg = getPredicateForVector(DAG, DL, CvtVT);
+
+ Val = convertToScalableVector(DAG, ContainerSrcVT, Val);
+ Val = DAG.getNode(Opcode, DL, CvtVT, Pg, Val, DAG.getUNDEF(CvtVT));
+ Val = getSVESafeBitCast(ContainerSrcVT, Val, DAG);
+ Val = convertFromScalableVector(DAG, SrcVT, Val);
+
+ Val = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Val);
+ return DAG.getNode(ISD::BITCAST, DL, VT, Val);
+ }
+}
+
+SDValue
+AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
+ SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+ assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
+
+ bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT;
+ unsigned Opcode = IsSigned ? AArch64ISD::FCVTZS_MERGE_PASSTHRU
+ : AArch64ISD::FCVTZU_MERGE_PASSTHRU;
+
+ SDLoc DL(Op);
+ SDValue Val = Op.getOperand(0);
+ EVT SrcVT = Val.getValueType();
+ EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT);
+ EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT);
+
+ if (ContainerSrcVT.getVectorElementType().getSizeInBits() <=
+ ContainerDstVT.getVectorElementType().getSizeInBits()) {
+ EVT CvtVT = ContainerDstVT.changeVectorElementType(
+ ContainerSrcVT.getVectorElementType());
+ SDValue Pg = getPredicateForVector(DAG, DL, VT);
+
+ Val = DAG.getNode(ISD::BITCAST, DL, SrcVT.changeTypeToInteger(), Val);
+ Val = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Val);
+
+ Val = convertToScalableVector(DAG, ContainerSrcVT, Val);
+ Val = getSVESafeBitCast(CvtVT, Val, DAG);
+ Val = DAG.getNode(Opcode, DL, ContainerDstVT, Pg, Val,
+ DAG.getUNDEF(ContainerDstVT));
+ return convertFromScalableVector(DAG, VT, Val);
+ } else {
+ EVT CvtVT = ContainerSrcVT.changeTypeToInteger();
+ SDValue Pg = getPredicateForVector(DAG, DL, CvtVT);
+
+ // Safe to use a larger than specified result since an fp_to_int where the
+ // result doesn't fit into the destination is undefined.
+ Val = convertToScalableVector(DAG, ContainerSrcVT, Val);
+ Val = DAG.getNode(Opcode, DL, CvtVT, Pg, Val, DAG.getUNDEF(CvtVT));
+ Val = convertFromScalableVector(DAG, SrcVT.changeTypeToInteger(), Val);
+
+ return DAG.getNode(ISD::TRUNCATE, DL, VT, Val);
+ }
+}
+
+SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
+ SDValue Op, SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+ assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
+
+ auto *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
+ auto ShuffleMask = SVN->getMask();
+
+ SDLoc DL(Op);
+ SDValue Op1 = Op.getOperand(0);
+ SDValue Op2 = Op.getOperand(1);
+
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+ Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
+ Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
+
+ bool ReverseEXT = false;
+ unsigned Imm;
+ if (isEXTMask(ShuffleMask, VT, ReverseEXT, Imm) &&
+ Imm == VT.getVectorNumElements() - 1) {
+ if (ReverseEXT)
+ std::swap(Op1, Op2);
+
+ EVT ScalarTy = VT.getVectorElementType();
+ if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
+ ScalarTy = MVT::i32;
+ SDValue Scalar = DAG.getNode(
+ ISD::EXTRACT_VECTOR_ELT, DL, ScalarTy, Op1,
+ DAG.getConstant(VT.getVectorNumElements() - 1, DL, MVT::i64));
+ Op = DAG.getNode(AArch64ISD::INSR, DL, ContainerVT, Op2, Scalar);
+ return convertFromScalableVector(DAG, VT, Op);
+ }
+
+ return SDValue();
+}
+
SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
@@ -17248,8 +18610,6 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
- assert((VT == PackedVT || InVT == PackedInVT) &&
- "Cannot cast between unpacked scalable vector types!");
// Pack input if required.
if (InVT != PackedInVT)
@@ -17263,3 +18623,60 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
return Op;
}
+
+bool AArch64TargetLowering::isAllActivePredicate(SDValue N) const {
+ return ::isAllActivePredicate(N);
+}
+
+EVT AArch64TargetLowering::getPromotedVTForPredicate(EVT VT) const {
+ return ::getPromotedVTForPredicate(VT);
+}
+
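+// Target hook for the demanded-bits analysis. For example (illustrative), in
+//   VSHL (VLSHR x, 8), 8
+// the shift pair only clears the low eight bits; if the caller does not demand
+// those bits either, the whole expression can be replaced by x, which is what
+// the VSHL case below does.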
+bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
+ SDValue Op, const APInt &OriginalDemandedBits,
+ const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
+ unsigned Depth) const {
+
+ unsigned Opc = Op.getOpcode();
+ switch (Opc) {
+ case AArch64ISD::VSHL: {
+ // Match (VSHL (VLSHR Val X) X)
+ SDValue ShiftL = Op;
+ SDValue ShiftR = Op->getOperand(0);
+ if (ShiftR->getOpcode() != AArch64ISD::VLSHR)
+ return false;
+
+ if (!ShiftL.hasOneUse() || !ShiftR.hasOneUse())
+ return false;
+
+ unsigned ShiftLBits = ShiftL->getConstantOperandVal(1);
+ unsigned ShiftRBits = ShiftR->getConstantOperandVal(1);
+
+ // Other cases can be handled as well, but this is not
+ // implemented.
+ if (ShiftRBits != ShiftLBits)
+ return false;
+
+ unsigned ScalarSize = Op.getScalarValueSizeInBits();
+ assert(ScalarSize > ShiftLBits && "Invalid shift imm");
+
+ APInt ZeroBits = APInt::getLowBitsSet(ScalarSize, ShiftLBits);
+ APInt UnusedBits = ~OriginalDemandedBits;
+
+ if ((ZeroBits & UnusedBits) != ZeroBits)
+ return false;
+
+ // All bits that are zeroed by (VSHL (VLSHR Val X) X) are not
+ // used - simplify to just Val.
+ return TLO.CombineTo(Op, ShiftR->getOperand(0));
+ }
+ }
+
+ return TargetLowering::SimplifyDemandedBitsForTargetNode(
+ Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth);
+}
+
+bool AArch64TargetLowering::isConstantUnsignedBitfieldExtactLegal(
+ unsigned Opc, LLT Ty1, LLT Ty2) const {
+ return Ty1 == Ty2 && (Ty1 == LLT::scalar(32) || Ty1 == LLT::scalar(64));
+}