aboutsummaryrefslogtreecommitdiff
path: root/lib/Target/NVPTX/NVPTXISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Target/NVPTX/NVPTXISelLowering.cpp')
-rw-r--r--lib/Target/NVPTX/NVPTXISelLowering.cpp1694
1 files changed, 851 insertions, 843 deletions
diff --git a/lib/Target/NVPTX/NVPTXISelLowering.cpp b/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7a760fd38d0f..4d06912054a2 100644
--- a/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -79,6 +79,60 @@ FMAContractLevelOpt("nvptx-fma-level", cl::ZeroOrMore, cl::Hidden,
" 1: do it 2: do it aggressively"),
cl::init(2));
+static cl::opt<int> UsePrecDivF32(
+ "nvptx-prec-divf32", cl::ZeroOrMore, cl::Hidden,
+ cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
+ " IEEE Compliant F32 div.rnd if available."),
+ cl::init(2));
+
+static cl::opt<bool> UsePrecSqrtF32(
+ "nvptx-prec-sqrtf32", cl::Hidden,
+ cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
+ cl::init(true));
+
+static cl::opt<bool> FtzEnabled(
+ "nvptx-f32ftz", cl::ZeroOrMore, cl::Hidden,
+ cl::desc("NVPTX Specific: Flush f32 subnormals to sign-preserving zero."),
+ cl::init(false));
+
+int NVPTXTargetLowering::getDivF32Level() const {
+ if (UsePrecDivF32.getNumOccurrences() > 0) {
+ // If nvptx-prec-div32=N is used on the command-line, always honor it
+ return UsePrecDivF32;
+ } else {
+ // Otherwise, use div.approx if fast math is enabled
+ if (getTargetMachine().Options.UnsafeFPMath)
+ return 0;
+ else
+ return 2;
+ }
+}
+
+bool NVPTXTargetLowering::usePrecSqrtF32() const {
+ if (UsePrecSqrtF32.getNumOccurrences() > 0) {
+ // If nvptx-prec-sqrtf32 is used on the command-line, always honor it
+ return UsePrecSqrtF32;
+ } else {
+ // Otherwise, use sqrt.approx if fast math is enabled
+ return !getTargetMachine().Options.UnsafeFPMath;
+ }
+}
+
+bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const {
+ // TODO: Get rid of this flag; there can be only one way to do this.
+ if (FtzEnabled.getNumOccurrences() > 0) {
+ // If nvptx-f32ftz is used on the command-line, always honor it
+ return FtzEnabled;
+ } else {
+ const Function *F = MF.getFunction();
+ // Otherwise, check for an nvptx-f32ftz attribute on the function
+ if (F->hasFnAttribute("nvptx-f32ftz"))
+ return F->getFnAttribute("nvptx-f32ftz").getValueAsString() == "true";
+ else
+ return false;
+ }
+}
+
static bool IsPTXVectorType(MVT VT) {
switch (VT.SimpleTy) {
default:
@@ -92,6 +146,9 @@ static bool IsPTXVectorType(MVT VT) {
case MVT::v2i32:
case MVT::v4i32:
case MVT::v2i64:
+ case MVT::v2f16:
+ case MVT::v4f16:
+ case MVT::v8f16: // <4 x f16x2>
case MVT::v2f32:
case MVT::v4f32:
case MVT::v2f64:
@@ -116,13 +173,24 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
EVT VT = TempVTs[i];
uint64_t Off = TempOffsets[i];
- if (VT.isVector())
- for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
- ValueVTs.push_back(VT.getVectorElementType());
+ // Split vectors into individual elements, except for v2f16, which
+ // we will pass as a single scalar.
+ if (VT.isVector()) {
+ unsigned NumElts = VT.getVectorNumElements();
+ EVT EltVT = VT.getVectorElementType();
+ // Vectors with an even number of f16 elements will be passed to
+ // us as an array of v2f16 elements. We must match this so we
+ // stay in sync with Ins/Outs.
+ if (EltVT == MVT::f16 && NumElts % 2 == 0) {
+ EltVT = MVT::v2f16;
+ NumElts /= 2;
+ }
+ for (unsigned j = 0; j != NumElts; ++j) {
+ ValueVTs.push_back(EltVT);
if (Offsets)
- Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize());
+ Offsets->push_back(Off + j * EltVT.getStoreSize());
}
- else {
+ } else {
ValueVTs.push_back(VT);
if (Offsets)
Offsets->push_back(Off);
@@ -130,6 +198,125 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
}
}
+// Check whether we can merge loads/stores of some of the pieces of a
+// flattened function parameter or return value into a single vector
+// load/store.
+//
+// The flattened parameter is represented as a list of EVTs and
+// offsets, and the whole structure is aligned to ParamAlignment. This
+// function determines whether we can load/store pieces of the
+// parameter starting at index Idx using a single vectorized op of
+// size AccessSize. If so, it returns the number of param pieces
+// covered by the vector op. Otherwise, it returns 1.
+static unsigned CanMergeParamLoadStoresStartingAt(
+ unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs,
+ const SmallVectorImpl<uint64_t> &Offsets, unsigned ParamAlignment) {
+ assert(isPowerOf2_32(AccessSize) && "must be a power of 2!");
+
+ // Can't vectorize if param alignment is not sufficient.
+ if (AccessSize > ParamAlignment)
+ return 1;
+ // Can't vectorize if offset is not aligned.
+ if (Offsets[Idx] & (AccessSize - 1))
+ return 1;
+
+ EVT EltVT = ValueVTs[Idx];
+ unsigned EltSize = EltVT.getStoreSize();
+
+ // Element is too large to vectorize.
+ if (EltSize >= AccessSize)
+ return 1;
+
+ unsigned NumElts = AccessSize / EltSize;
+ // Can't vectorize if AccessBytes if not a multiple of EltSize.
+ if (AccessSize != EltSize * NumElts)
+ return 1;
+
+ // We don't have enough elements to vectorize.
+ if (Idx + NumElts > ValueVTs.size())
+ return 1;
+
+ // PTX ISA can only deal with 2- and 4-element vector ops.
+ if (NumElts != 4 && NumElts != 2)
+ return 1;
+
+ for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) {
+ // Types do not match.
+ if (ValueVTs[j] != EltVT)
+ return 1;
+
+ // Elements are not contiguous.
+ if (Offsets[j] - Offsets[j - 1] != EltSize)
+ return 1;
+ }
+ // OK. We can vectorize ValueVTs[i..i+NumElts)
+ return NumElts;
+}
+
+// Flags for tracking per-element vectorization state of loads/stores
+// of a flattened function parameter or return value.
+enum ParamVectorizationFlags {
+ PVF_INNER = 0x0, // Middle elements of a vector.
+ PVF_FIRST = 0x1, // First element of the vector.
+ PVF_LAST = 0x2, // Last element of the vector.
+ // Scalar is effectively a 1-element vector.
+ PVF_SCALAR = PVF_FIRST | PVF_LAST
+};
+
+// Computes whether and how we can vectorize the loads/stores of a
+// flattened function parameter or return value.
+//
+// The flattened parameter is represented as the list of ValueVTs and
+// Offsets, and is aligned to ParamAlignment bytes. We return a vector
+// of the same size as ValueVTs indicating how each piece should be
+// loaded/stored (i.e. as a scalar, or as part of a vector
+// load/store).
+static SmallVector<ParamVectorizationFlags, 16>
+VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
+ const SmallVectorImpl<uint64_t> &Offsets,
+ unsigned ParamAlignment) {
+ // Set vector size to match ValueVTs and mark all elements as
+ // scalars by default.
+ SmallVector<ParamVectorizationFlags, 16> VectorInfo;
+ VectorInfo.assign(ValueVTs.size(), PVF_SCALAR);
+
+ // Check what we can vectorize using 128/64/32-bit accesses.
+ for (int I = 0, E = ValueVTs.size(); I != E; ++I) {
+ // Skip elements we've already processed.
+ assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state.");
+ for (unsigned AccessSize : {16, 8, 4, 2}) {
+ unsigned NumElts = CanMergeParamLoadStoresStartingAt(
+ I, AccessSize, ValueVTs, Offsets, ParamAlignment);
+ // Mark vectorized elements.
+ switch (NumElts) {
+ default:
+ llvm_unreachable("Unexpected return value");
+ case 1:
+ // Can't vectorize using this size, try next smaller size.
+ continue;
+ case 2:
+ assert(I + 1 < E && "Not enough elements.");
+ VectorInfo[I] = PVF_FIRST;
+ VectorInfo[I + 1] = PVF_LAST;
+ I += 1;
+ break;
+ case 4:
+ assert(I + 3 < E && "Not enough elements.");
+ VectorInfo[I] = PVF_FIRST;
+ VectorInfo[I + 1] = PVF_INNER;
+ VectorInfo[I + 2] = PVF_INNER;
+ VectorInfo[I + 3] = PVF_LAST;
+ I += 3;
+ break;
+ }
+ // Break out of the inner loop because we've already succeeded
+ // using largest possible AccessSize.
+ break;
+ }
+ }
+ return VectorInfo;
+}
+
// NVPTXTargetLowering Constructor.
NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
const NVPTXSubtarget &STI)
@@ -158,14 +345,32 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
else
setSchedulingPreference(Sched::Source);
+ auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
+ LegalizeAction NoF16Action) {
+ setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
+ };
+
addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
+ addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass);
+ addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass);
+
+ // Conversion to/from FP16/FP16x2 is always legal.
+ setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal);
+ setOperationAction(ISD::FP_TO_SINT, MVT::f16, Legal);
+ setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f16, Custom);
+
+ setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
+ setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
// Operations not directly supported by NVPTX.
+ setOperationAction(ISD::SELECT_CC, MVT::f16, Expand);
+ setOperationAction(ISD::SELECT_CC, MVT::v2f16, Expand);
setOperationAction(ISD::SELECT_CC, MVT::f32, Expand);
setOperationAction(ISD::SELECT_CC, MVT::f64, Expand);
setOperationAction(ISD::SELECT_CC, MVT::i1, Expand);
@@ -173,6 +378,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::SELECT_CC, MVT::i16, Expand);
setOperationAction(ISD::SELECT_CC, MVT::i32, Expand);
setOperationAction(ISD::SELECT_CC, MVT::i64, Expand);
+ setOperationAction(ISD::BR_CC, MVT::f16, Expand);
+ setOperationAction(ISD::BR_CC, MVT::v2f16, Expand);
setOperationAction(ISD::BR_CC, MVT::f32, Expand);
setOperationAction(ISD::BR_CC, MVT::f64, Expand);
setOperationAction(ISD::BR_CC, MVT::i1, Expand);
@@ -195,6 +402,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::SRA_PARTS, MVT::i64 , Custom);
setOperationAction(ISD::SRL_PARTS, MVT::i64 , Custom);
+ setOperationAction(ISD::BITREVERSE, MVT::i32, Legal);
+ setOperationAction(ISD::BITREVERSE, MVT::i64, Legal);
+
if (STI.hasROT64()) {
setOperationAction(ISD::ROTL, MVT::i64, Legal);
setOperationAction(ISD::ROTR, MVT::i64, Legal);
@@ -259,6 +469,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// This is legal in NVPTX
setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
+ setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
// TRAP can be lowered to PTX trap
setOperationAction(ISD::TRAP, MVT::Other, Legal);
@@ -278,15 +489,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// Custom handling for i8 intrinsics
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
- setOperationAction(ISD::CTLZ, MVT::i16, Legal);
- setOperationAction(ISD::CTLZ, MVT::i32, Legal);
- setOperationAction(ISD::CTLZ, MVT::i64, Legal);
+ for (const auto& Ty : {MVT::i16, MVT::i32, MVT::i64}) {
+ setOperationAction(ISD::SMIN, Ty, Legal);
+ setOperationAction(ISD::SMAX, Ty, Legal);
+ setOperationAction(ISD::UMIN, Ty, Legal);
+ setOperationAction(ISD::UMAX, Ty, Legal);
+
+ setOperationAction(ISD::CTPOP, Ty, Legal);
+ setOperationAction(ISD::CTLZ, Ty, Legal);
+ }
+
setOperationAction(ISD::CTTZ, MVT::i16, Expand);
setOperationAction(ISD::CTTZ, MVT::i32, Expand);
setOperationAction(ISD::CTTZ, MVT::i64, Expand);
- setOperationAction(ISD::CTPOP, MVT::i16, Legal);
- setOperationAction(ISD::CTPOP, MVT::i32, Legal);
- setOperationAction(ISD::CTPOP, MVT::i64, Legal);
// PTX does not directly support SELP of i1, so promote to i32 first
setOperationAction(ISD::SELECT, MVT::i1, Custom);
@@ -301,28 +516,60 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setTargetDAGCombine(ISD::FADD);
setTargetDAGCombine(ISD::MUL);
setTargetDAGCombine(ISD::SHL);
- setTargetDAGCombine(ISD::SELECT);
setTargetDAGCombine(ISD::SREM);
setTargetDAGCombine(ISD::UREM);
- // Library functions. These default to Expand, but we have instructions
- // for them.
- setOperationAction(ISD::FCEIL, MVT::f32, Legal);
- setOperationAction(ISD::FCEIL, MVT::f64, Legal);
- setOperationAction(ISD::FFLOOR, MVT::f32, Legal);
- setOperationAction(ISD::FFLOOR, MVT::f64, Legal);
- setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal);
- setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal);
- setOperationAction(ISD::FRINT, MVT::f32, Legal);
- setOperationAction(ISD::FRINT, MVT::f64, Legal);
- setOperationAction(ISD::FROUND, MVT::f32, Legal);
- setOperationAction(ISD::FROUND, MVT::f64, Legal);
- setOperationAction(ISD::FTRUNC, MVT::f32, Legal);
- setOperationAction(ISD::FTRUNC, MVT::f64, Legal);
- setOperationAction(ISD::FMINNUM, MVT::f32, Legal);
- setOperationAction(ISD::FMINNUM, MVT::f64, Legal);
- setOperationAction(ISD::FMAXNUM, MVT::f32, Legal);
- setOperationAction(ISD::FMAXNUM, MVT::f64, Legal);
+ // setcc for f16x2 needs special handling to prevent legalizer's
+ // attempt to scalarize it due to v2i1 not being legal.
+ if (STI.allowFP16Math())
+ setTargetDAGCombine(ISD::SETCC);
+
+ // Promote fp16 arithmetic if fp16 hardware isn't available or the
+ // user passed --nvptx-no-fp16-math. The flag is useful because,
+ // although sm_53+ GPUs have some sort of FP16 support in
+ // hardware, only sm_53 and sm_60 have full implementation. Others
+ // only have token amount of hardware and are likely to run faster
+ // by using fp32 units instead.
+ for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
+ setFP16OperationAction(Op, MVT::f16, Legal, Promote);
+ setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
+ }
+
+ // There's no neg.f16 instruction. Expand to (0-x).
+ setOperationAction(ISD::FNEG, MVT::f16, Expand);
+ setOperationAction(ISD::FNEG, MVT::v2f16, Expand);
+
+ // (would be) Library functions.
+
+ // These map to conversion instructions for scalar FP types.
+ for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
+ ISD::FROUND, ISD::FTRUNC}) {
+ setOperationAction(Op, MVT::f16, Legal);
+ setOperationAction(Op, MVT::f32, Legal);
+ setOperationAction(Op, MVT::f64, Legal);
+ setOperationAction(Op, MVT::v2f16, Expand);
+ }
+
+ // 'Expand' implements FCOPYSIGN without calling an external library.
+ setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
+ setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
+ setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand);
+ setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);
+
+ // These map to corresponding instructions for f32/f64. f16 must be
+ // promoted to f32. v2f16 is expanded to f16, which is then promoted
+ // to f32.
+ for (const auto &Op : {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS,
+ ISD::FABS, ISD::FMINNUM, ISD::FMAXNUM}) {
+ setOperationAction(Op, MVT::f16, Promote);
+ setOperationAction(Op, MVT::f32, Legal);
+ setOperationAction(Op, MVT::f64, Legal);
+ setOperationAction(Op, MVT::v2f16, Expand);
+ }
+ setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
+ setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
+ setOperationAction(ISD::FMINNAN, MVT::f16, Promote);
+ setOperationAction(ISD::FMAXNAN, MVT::f16, Promote);
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
// No FPOW or FREM in PTX.
@@ -434,6 +681,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "NVPTXISD::FUN_SHFR_CLAMP";
case NVPTXISD::IMAD:
return "NVPTXISD::IMAD";
+ case NVPTXISD::SETP_F16X2:
+ return "NVPTXISD::SETP_F16X2";
case NVPTXISD::Dummy:
return "NVPTXISD::Dummy";
case NVPTXISD::MUL_WIDE_SIGNED:
@@ -932,10 +1181,60 @@ TargetLoweringBase::LegalizeTypeAction
NVPTXTargetLowering::getPreferredVectorAction(EVT VT) const {
if (VT.getVectorNumElements() != 1 && VT.getScalarType() == MVT::i1)
return TypeSplitVector;
-
+ if (VT == MVT::v2f16)
+ return TypeLegal;
return TargetLoweringBase::getPreferredVectorAction(VT);
}
+SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
+ int Enabled, int &ExtraSteps,
+ bool &UseOneConst,
+ bool Reciprocal) const {
+ if (!(Enabled == ReciprocalEstimate::Enabled ||
+ (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32())))
+ return SDValue();
+
+ if (ExtraSteps == ReciprocalEstimate::Unspecified)
+ ExtraSteps = 0;
+
+ SDLoc DL(Operand);
+ EVT VT = Operand.getValueType();
+ bool Ftz = useF32FTZ(DAG.getMachineFunction());
+
+ auto MakeIntrinsicCall = [&](Intrinsic::ID IID) {
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(IID, DL, MVT::i32), Operand);
+ };
+
+ // The sqrt and rsqrt refinement processes assume we always start out with an
+ // approximation of the rsqrt. Therefore, if we're going to do any refinement
+ // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing
+ // any refinement, we must return a regular sqrt.
+ if (Reciprocal || ExtraSteps > 0) {
+ if (VT == MVT::f32)
+ return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f
+ : Intrinsic::nvvm_rsqrt_approx_f);
+ else if (VT == MVT::f64)
+ return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d);
+ else
+ return SDValue();
+ } else {
+ if (VT == MVT::f32)
+ return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f
+ : Intrinsic::nvvm_sqrt_approx_f);
+ else {
+ // There's no sqrt.approx.f64 instruction, so we emit
+ // reciprocal(rsqrt(x)). This is faster than
+ // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain
+ // x * rsqrt(x).)
+ return DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32),
+ MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d));
+ }
+ }
+}
+
SDValue
NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
SDLoc dl(Op);
@@ -967,19 +1266,21 @@ std::string NVPTXTargetLowering::getPrototype(
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
size = ITy->getBitWidth();
- if (size < 32)
- size = 32;
} else {
assert(retTy->isFloatingPointTy() &&
"Floating point type expected here");
size = retTy->getPrimitiveSizeInBits();
}
+ // PTX ABI requires all scalar return values to be at least 32
+ // bits in size. fp16 normally uses .b16 as its storage type in
+ // PTX, so its size must be adjusted here, too.
+ if (size < 32)
+ size = 32;
O << ".param .b" << size << " _";
} else if (isa<PointerType>(retTy)) {
O << ".param .b" << PtrVT.getSizeInBits() << " _";
- } else if ((retTy->getTypeID() == Type::StructTyID) ||
- isa<VectorType>(retTy)) {
+ } else if (retTy->isAggregateType() || retTy->isVectorTy()) {
auto &DL = CS->getCalledFunction()->getParent()->getDataLayout();
O << ".param .align " << retAlignment << " .b8 _["
<< DL.getTypeAllocSize(retTy) << "]";
@@ -1018,7 +1319,7 @@ std::string NVPTXTargetLowering::getPrototype(
OIdx += len - 1;
continue;
}
- // i8 types in IR will be i16 types in SDAG
+ // i8 types in IR will be i16 types in SDAG
assert((getValueType(DL, Ty) == Outs[OIdx].VT ||
(getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
"type mismatch between callee prototype and arguments");
@@ -1028,8 +1329,13 @@ std::string NVPTXTargetLowering::getPrototype(
sz = cast<IntegerType>(Ty)->getBitWidth();
if (sz < 32)
sz = 32;
- } else if (isa<PointerType>(Ty))
+ } else if (isa<PointerType>(Ty)) {
sz = PtrVT.getSizeInBits();
+ } else if (Ty->isHalfTy())
+ // PTX ABI requires all scalar parameters to be at least 32
+ // bits in size. fp16 normally uses .b16 as its storage type
+ // in PTX, so its size must be adjusted here, too.
+ sz = 32;
else
sz = Ty->getPrimitiveSizeInBits();
O << ".param .b" << sz << " ";
@@ -1113,21 +1419,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDValue Callee = CLI.Callee;
bool &isTailCall = CLI.IsTailCall;
ArgListTy &Args = CLI.getArgs();
- Type *retTy = CLI.RetTy;
+ Type *RetTy = CLI.RetTy;
ImmutableCallSite *CS = CLI.CS;
+ const DataLayout &DL = DAG.getDataLayout();
bool isABI = (STI.getSmVersion() >= 20);
assert(isABI && "Non-ABI compilation is not supported");
if (!isABI)
return Chain;
- MachineFunction &MF = DAG.getMachineFunction();
- const Function *F = MF.getFunction();
- auto &DL = MF.getDataLayout();
SDValue tempChain = Chain;
- Chain = DAG.getCALLSEQ_START(Chain,
- DAG.getIntPtrConstant(uniqueCallSite, dl, true),
- dl);
+ Chain = DAG.getCALLSEQ_START(
+ Chain, DAG.getIntPtrConstant(uniqueCallSite, dl, true), dl);
SDValue InFlag = Chain.getValue(1);
unsigned paramCount = 0;
@@ -1148,240 +1451,124 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Type *Ty = Args[i].Ty;
if (!Outs[OIdx].Flags.isByVal()) {
- if (Ty->isAggregateType()) {
- // aggregate
- SmallVector<EVT, 16> vtparts;
- SmallVector<uint64_t, 16> Offsets;
- ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts, &Offsets,
- 0);
-
- unsigned align =
- getArgumentAlignment(Callee, CS, Ty, paramCount + 1, DL);
+ SmallVector<EVT, 16> VTs;
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets);
+ unsigned ArgAlign =
+ getArgumentAlignment(Callee, CS, Ty, paramCount + 1, DL);
+ unsigned AllocSize = DL.getTypeAllocSize(Ty);
+ SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
+ bool NeedAlign; // Does argument declaration specify alignment?
+ if (Ty->isAggregateType() || Ty->isVectorTy()) {
// declare .param .align <align> .b8 .param<n>[<size>];
- unsigned sz = DL.getTypeAllocSize(Ty);
- SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, dl,
- MVT::i32),
- DAG.getConstant(paramCount, dl, MVT::i32),
- DAG.getConstant(sz, dl, MVT::i32),
- InFlag };
+ SDValue DeclareParamOps[] = {
+ Chain, DAG.getConstant(ArgAlign, dl, MVT::i32),
+ DAG.getConstant(paramCount, dl, MVT::i32),
+ DAG.getConstant(AllocSize, dl, MVT::i32), InFlag};
Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
DeclareParamOps);
- InFlag = Chain.getValue(1);
- for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
- EVT elemtype = vtparts[j];
- unsigned ArgAlign = GreatestCommonDivisor64(align, Offsets[j]);
- if (elemtype.isInteger() && (sz < 8))
- sz = 8;
- SDValue StVal = OutVals[OIdx];
- if (elemtype.getSizeInBits() < 16) {
- StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
- }
- SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue CopyParamOps[] = { Chain,
- DAG.getConstant(paramCount, dl, MVT::i32),
- DAG.getConstant(Offsets[j], dl, MVT::i32),
- StVal, InFlag };
- Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
- CopyParamVTs, CopyParamOps,
- elemtype, MachinePointerInfo(),
- ArgAlign);
- InFlag = Chain.getValue(1);
- ++OIdx;
+ NeedAlign = true;
+ } else {
+ // declare .param .b<size> .param<n>;
+ if ((VT.isInteger() || VT.isFloatingPoint()) && AllocSize < 4) {
+ // PTX ABI requires integral types to be at least 32 bits in
+ // size. FP16 is loaded/stored using i16, so it's handled
+ // here as well.
+ AllocSize = 4;
}
- if (vtparts.size() > 0)
- --OIdx;
- ++paramCount;
- continue;
+ SDValue DeclareScalarParamOps[] = {
+ Chain, DAG.getConstant(paramCount, dl, MVT::i32),
+ DAG.getConstant(AllocSize * 8, dl, MVT::i32),
+ DAG.getConstant(0, dl, MVT::i32), InFlag};
+ Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
+ DeclareScalarParamOps);
+ NeedAlign = false;
}
- if (Ty->isVectorTy()) {
- EVT ObjectVT = getValueType(DL, Ty);
- unsigned align =
- getArgumentAlignment(Callee, CS, Ty, paramCount + 1, DL);
- // declare .param .align <align> .b8 .param<n>[<size>];
- unsigned sz = DL.getTypeAllocSize(Ty);
- SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue DeclareParamOps[] = { Chain,
- DAG.getConstant(align, dl, MVT::i32),
- DAG.getConstant(paramCount, dl, MVT::i32),
- DAG.getConstant(sz, dl, MVT::i32),
- InFlag };
- Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
- DeclareParamOps);
- InFlag = Chain.getValue(1);
- unsigned NumElts = ObjectVT.getVectorNumElements();
- EVT EltVT = ObjectVT.getVectorElementType();
- EVT MemVT = EltVT;
- bool NeedExtend = false;
- if (EltVT.getSizeInBits() < 16) {
- NeedExtend = true;
- EltVT = MVT::i16;
+ InFlag = Chain.getValue(1);
+
+ // PTX Interoperability Guide 3.3(A): [Integer] Values shorter
+ // than 32-bits are sign extended or zero extended, depending on
+ // whether they are signed or unsigned types. This case applies
+ // only to scalar parameters and not to aggregate values.
+ bool ExtendIntegerParam =
+ Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32;
+
+ auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign);
+ SmallVector<SDValue, 6> StoreOperands;
+ for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
+ // New store.
+ if (VectorInfo[j] & PVF_FIRST) {
+ assert(StoreOperands.empty() && "Unfinished preceeding store.");
+ StoreOperands.push_back(Chain);
+ StoreOperands.push_back(DAG.getConstant(paramCount, dl, MVT::i32));
+ StoreOperands.push_back(DAG.getConstant(Offsets[j], dl, MVT::i32));
}
- // V1 store
- if (NumElts == 1) {
- SDValue Elt = OutVals[OIdx++];
- if (NeedExtend)
- Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt);
-
- SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue CopyParamOps[] = { Chain,
- DAG.getConstant(paramCount, dl, MVT::i32),
- DAG.getConstant(0, dl, MVT::i32), Elt,
- InFlag };
- Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
- CopyParamVTs, CopyParamOps,
- MemVT, MachinePointerInfo());
- InFlag = Chain.getValue(1);
- } else if (NumElts == 2) {
- SDValue Elt0 = OutVals[OIdx++];
- SDValue Elt1 = OutVals[OIdx++];
- if (NeedExtend) {
- Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0);
- Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1);
+ EVT EltVT = VTs[j];
+ SDValue StVal = OutVals[OIdx];
+ if (ExtendIntegerParam) {
+ assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
+ // zext/sext to i32
+ StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
+ : ISD::ZERO_EXTEND,
+ dl, MVT::i32, StVal);
+ } else if (EltVT.getSizeInBits() < 16) {
+ // Use 16-bit registers for small stores as it's the
+ // smallest general purpose register size supported by NVPTX.
+ StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
+ }
+
+ // Record the value to store.
+ StoreOperands.push_back(StVal);
+
+ if (VectorInfo[j] & PVF_LAST) {
+ unsigned NumElts = StoreOperands.size() - 3;
+ NVPTXISD::NodeType Op;
+ switch (NumElts) {
+ case 1:
+ Op = NVPTXISD::StoreParam;
+ break;
+ case 2:
+ Op = NVPTXISD::StoreParamV2;
+ break;
+ case 4:
+ Op = NVPTXISD::StoreParamV4;
+ break;
+ default:
+ llvm_unreachable("Invalid vector info.");
}
- SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue CopyParamOps[] = { Chain,
- DAG.getConstant(paramCount, dl, MVT::i32),
- DAG.getConstant(0, dl, MVT::i32), Elt0,
- Elt1, InFlag };
- Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl,
- CopyParamVTs, CopyParamOps,
- MemVT, MachinePointerInfo());
- InFlag = Chain.getValue(1);
- } else {
- unsigned curOffset = 0;
- // V4 stores
- // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
- // the
- // vector will be expanded to a power of 2 elements, so we know we can
- // always round up to the next multiple of 4 when creating the vector
- // stores.
- // e.g. 4 elem => 1 st.v4
- // 6 elem => 2 st.v4
- // 8 elem => 2 st.v4
- // 11 elem => 3 st.v4
- unsigned VecSize = 4;
- if (EltVT.getSizeInBits() == 64)
- VecSize = 2;
-
- // This is potentially only part of a vector, so assume all elements
- // are packed together.
- unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize;
-
- for (unsigned i = 0; i < NumElts; i += VecSize) {
- // Get values
- SDValue StoreVal;
- SmallVector<SDValue, 8> Ops;
- Ops.push_back(Chain);
- Ops.push_back(DAG.getConstant(paramCount, dl, MVT::i32));
- Ops.push_back(DAG.getConstant(curOffset, dl, MVT::i32));
-
- unsigned Opc = NVPTXISD::StoreParamV2;
-
- StoreVal = OutVals[OIdx++];
- if (NeedExtend)
- StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
- Ops.push_back(StoreVal);
-
- if (i + 1 < NumElts) {
- StoreVal = OutVals[OIdx++];
- if (NeedExtend)
- StoreVal =
- DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
- } else {
- StoreVal = DAG.getUNDEF(EltVT);
- }
- Ops.push_back(StoreVal);
-
- if (VecSize == 4) {
- Opc = NVPTXISD::StoreParamV4;
- if (i + 2 < NumElts) {
- StoreVal = OutVals[OIdx++];
- if (NeedExtend)
- StoreVal =
- DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
- } else {
- StoreVal = DAG.getUNDEF(EltVT);
- }
- Ops.push_back(StoreVal);
-
- if (i + 3 < NumElts) {
- StoreVal = OutVals[OIdx++];
- if (NeedExtend)
- StoreVal =
- DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
- } else {
- StoreVal = DAG.getUNDEF(EltVT);
- }
- Ops.push_back(StoreVal);
- }
+ StoreOperands.push_back(InFlag);
- Ops.push_back(InFlag);
+ // Adjust type of the store op if we've extended the scalar
+ // return value.
+ EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : VTs[j];
+ unsigned EltAlign =
+ NeedAlign ? GreatestCommonDivisor64(ArgAlign, Offsets[j]) : 0;
- SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, Ops,
- MemVT, MachinePointerInfo());
- InFlag = Chain.getValue(1);
- curOffset += PerStoreOffset;
- }
+ Chain = DAG.getMemIntrinsicNode(
+ Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands,
+ TheStoreType, MachinePointerInfo(), EltAlign);
+ InFlag = Chain.getValue(1);
+
+ // Cleanup.
+ StoreOperands.clear();
}
- ++paramCount;
- --OIdx;
- continue;
- }
- // Plain scalar
- // for ABI, declare .param .b<size> .param<n>;
- unsigned sz = VT.getSizeInBits();
- bool needExtend = false;
- if (VT.isInteger()) {
- if (sz < 16)
- needExtend = true;
- if (sz < 32)
- sz = 32;
+ ++OIdx;
}
- SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue DeclareParamOps[] = { Chain,
- DAG.getConstant(paramCount, dl, MVT::i32),
- DAG.getConstant(sz, dl, MVT::i32),
- DAG.getConstant(0, dl, MVT::i32), InFlag };
- Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
- DeclareParamOps);
- InFlag = Chain.getValue(1);
- SDValue OutV = OutVals[OIdx];
- if (needExtend) {
- // zext/sext i1 to i16
- unsigned opc = ISD::ZERO_EXTEND;
- if (Outs[OIdx].Flags.isSExt())
- opc = ISD::SIGN_EXTEND;
- OutV = DAG.getNode(opc, dl, MVT::i16, OutV);
- }
- SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- SDValue CopyParamOps[] = { Chain,
- DAG.getConstant(paramCount, dl, MVT::i32),
- DAG.getConstant(0, dl, MVT::i32), OutV,
- InFlag };
-
- unsigned opcode = NVPTXISD::StoreParam;
- if (Outs[OIdx].Flags.isZExt() && VT.getSizeInBits() < 32)
- opcode = NVPTXISD::StoreParamU32;
- else if (Outs[OIdx].Flags.isSExt() && VT.getSizeInBits() < 32)
- opcode = NVPTXISD::StoreParamS32;
- Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps,
- VT, MachinePointerInfo());
-
- InFlag = Chain.getValue(1);
+ assert(StoreOperands.empty() && "Unfinished parameter store.");
+ if (VTs.size() > 0)
+ --OIdx;
++paramCount;
continue;
}
- // struct or vector
- SmallVector<EVT, 16> vtparts;
+
+ // ByVal arguments
+ SmallVector<EVT, 16> VTs;
SmallVector<uint64_t, 16> Offsets;
auto *PTy = dyn_cast<PointerType>(Args[i].Ty);
assert(PTy && "Type of a byval parameter should be pointer");
- ComputePTXValueVTs(*this, DAG.getDataLayout(), PTy->getElementType(),
- vtparts, &Offsets, 0);
+ ComputePTXValueVTs(*this, DL, PTy->getElementType(), VTs, &Offsets, 0);
// declare .param .align <align> .b8 .param<n>[<size>];
unsigned sz = Outs[OIdx].Flags.getByValSize();
@@ -1402,11 +1589,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
DeclareParamOps);
InFlag = Chain.getValue(1);
- for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
- EVT elemtype = vtparts[j];
+ for (unsigned j = 0, je = VTs.size(); j != je; ++j) {
+ EVT elemtype = VTs[j];
int curOffset = Offsets[j];
unsigned PartAlign = GreatestCommonDivisor64(ArgAlign, curOffset);
- auto PtrVT = getPointerTy(DAG.getDataLayout());
+ auto PtrVT = getPointerTy(DL);
SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, OutVals[OIdx],
DAG.getConstant(curOffset, dl, PtrVT));
SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
@@ -1434,18 +1621,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// Handle Result
if (Ins.size() > 0) {
SmallVector<EVT, 16> resvtparts;
- ComputeValueVTs(*this, DL, retTy, resvtparts);
+ ComputeValueVTs(*this, DL, RetTy, resvtparts);
// Declare
// .param .align 16 .b8 retval0[<size-in-bytes>], or
// .param .b<size-in-bits> retval0
- unsigned resultsz = DL.getTypeAllocSizeInBits(retTy);
+ unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
// Emit ".param .b<size-in-bits> retval0" instead of byte arrays only for
// these three types to match the logic in
// NVPTXAsmPrinter::printReturnValStr and NVPTXTargetLowering::getPrototype.
// Plus, this behavior is consistent with nvcc's.
- if (retTy->isFloatingPointTy() || retTy->isIntegerTy() ||
- retTy->isPointerTy()) {
+ if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy() ||
+ RetTy->isPointerTy()) {
// Scalar needs to be at least 32bit wide
if (resultsz < 32)
resultsz = 32;
@@ -1457,7 +1644,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
DeclareRetOps);
InFlag = Chain.getValue(1);
} else {
- retAlignment = getArgumentAlignment(Callee, CS, retTy, 0, DL);
+ retAlignment = getArgumentAlignment(Callee, CS, RetTy, 0, DL);
SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue DeclareRetOps[] = { Chain,
DAG.getConstant(retAlignment, dl, MVT::i32),
@@ -1478,8 +1665,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// The prototype is embedded in a string and put as the operand for a
// CallPrototype SDNode which will print out to the value of the string.
SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
- std::string Proto =
- getPrototype(DAG.getDataLayout(), retTy, Args, Outs, retAlignment, CS);
+ std::string Proto = getPrototype(DL, RetTy, Args, Outs, retAlignment, CS);
const char *ProtoStr =
nvTM->getManagedStrPool()->getManagedString(Proto.c_str())->c_str();
SDValue ProtoOps[] = {
@@ -1544,175 +1730,84 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// Generate loads from param memory/moves from registers for result
if (Ins.size() > 0) {
- if (retTy && retTy->isVectorTy()) {
- EVT ObjectVT = getValueType(DL, retTy);
- unsigned NumElts = ObjectVT.getVectorNumElements();
- EVT EltVT = ObjectVT.getVectorElementType();
- assert(STI.getTargetLowering()->getNumRegisters(F->getContext(),
- ObjectVT) == NumElts &&
- "Vector was not scalarized");
- unsigned sz = EltVT.getSizeInBits();
- bool needTruncate = sz < 8;
-
- if (NumElts == 1) {
- // Just a simple load
- SmallVector<EVT, 4> LoadRetVTs;
- if (EltVT == MVT::i1 || EltVT == MVT::i8) {
- // If loading i1/i8 result, generate
- // load.b8 i16
- // if i1
- // trunc i16 to i1
- LoadRetVTs.push_back(MVT::i16);
- } else
- LoadRetVTs.push_back(EltVT);
- LoadRetVTs.push_back(MVT::Other);
- LoadRetVTs.push_back(MVT::Glue);
- SDValue LoadRetOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
- DAG.getConstant(0, dl, MVT::i32), InFlag};
- SDValue retval = DAG.getMemIntrinsicNode(
- NVPTXISD::LoadParam, dl,
- DAG.getVTList(LoadRetVTs), LoadRetOps, EltVT, MachinePointerInfo());
- Chain = retval.getValue(1);
- InFlag = retval.getValue(2);
- SDValue Ret0 = retval;
- if (needTruncate)
- Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0);
- InVals.push_back(Ret0);
- } else if (NumElts == 2) {
- // LoadV2
- SmallVector<EVT, 4> LoadRetVTs;
- if (EltVT == MVT::i1 || EltVT == MVT::i8) {
- // If loading i1/i8 result, generate
- // load.b8 i16
- // if i1
- // trunc i16 to i1
- LoadRetVTs.push_back(MVT::i16);
- LoadRetVTs.push_back(MVT::i16);
- } else {
- LoadRetVTs.push_back(EltVT);
- LoadRetVTs.push_back(EltVT);
- }
- LoadRetVTs.push_back(MVT::Other);
- LoadRetVTs.push_back(MVT::Glue);
- SDValue LoadRetOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
- DAG.getConstant(0, dl, MVT::i32), InFlag};
- SDValue retval = DAG.getMemIntrinsicNode(
- NVPTXISD::LoadParamV2, dl,
- DAG.getVTList(LoadRetVTs), LoadRetOps, EltVT, MachinePointerInfo());
- Chain = retval.getValue(2);
- InFlag = retval.getValue(3);
- SDValue Ret0 = retval.getValue(0);
- SDValue Ret1 = retval.getValue(1);
- if (needTruncate) {
- Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0);
- InVals.push_back(Ret0);
- Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1);
- InVals.push_back(Ret1);
- } else {
- InVals.push_back(Ret0);
- InVals.push_back(Ret1);
- }
- } else {
- // Split into N LoadV4
- unsigned Ofst = 0;
- unsigned VecSize = 4;
- unsigned Opc = NVPTXISD::LoadParamV4;
- if (EltVT.getSizeInBits() == 64) {
- VecSize = 2;
- Opc = NVPTXISD::LoadParamV2;
- }
- EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
- for (unsigned i = 0; i < NumElts; i += VecSize) {
- SmallVector<EVT, 8> LoadRetVTs;
- if (EltVT == MVT::i1 || EltVT == MVT::i8) {
- // If loading i1/i8 result, generate
- // load.b8 i16
- // if i1
- // trunc i16 to i1
- for (unsigned j = 0; j < VecSize; ++j)
- LoadRetVTs.push_back(MVT::i16);
- } else {
- for (unsigned j = 0; j < VecSize; ++j)
- LoadRetVTs.push_back(EltVT);
- }
- LoadRetVTs.push_back(MVT::Other);
- LoadRetVTs.push_back(MVT::Glue);
- SDValue LoadRetOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
- DAG.getConstant(Ofst, dl, MVT::i32), InFlag};
- SDValue retval = DAG.getMemIntrinsicNode(
- Opc, dl, DAG.getVTList(LoadRetVTs),
- LoadRetOps, EltVT, MachinePointerInfo());
- if (VecSize == 2) {
- Chain = retval.getValue(2);
- InFlag = retval.getValue(3);
- } else {
- Chain = retval.getValue(4);
- InFlag = retval.getValue(5);
- }
+ SmallVector<EVT, 16> VTs;
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0);
+ assert(VTs.size() == Ins.size() && "Bad value decomposition");
+
+ unsigned RetAlign = getArgumentAlignment(Callee, CS, RetTy, 0, DL);
+ auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign);
+
+ SmallVector<EVT, 6> LoadVTs;
+ int VecIdx = -1; // Index of the first element of the vector.
+
+ // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
+ // 32-bits are sign extended or zero extended, depending on whether
+ // they are signed or unsigned types.
+ bool ExtendIntegerRetVal =
+ RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
+
+ for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
+ bool needTruncate = false;
+ EVT TheLoadType = VTs[i];
+ EVT EltType = Ins[i].VT;
+ unsigned EltAlign = GreatestCommonDivisor64(RetAlign, Offsets[i]);
+ if (ExtendIntegerRetVal) {
+ TheLoadType = MVT::i32;
+ EltType = MVT::i32;
+ needTruncate = true;
+ } else if (TheLoadType.getSizeInBits() < 16) {
+ if (VTs[i].isInteger())
+ needTruncate = true;
+ EltType = MVT::i16;
+ }
- for (unsigned j = 0; j < VecSize; ++j) {
- if (i + j >= NumElts)
- break;
- SDValue Elt = retval.getValue(j);
- if (needTruncate)
- Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
- InVals.push_back(Elt);
- }
- Ofst += DL.getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
- }
+ // Record index of the very first element of the vector.
+ if (VectorInfo[i] & PVF_FIRST) {
+ assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list.");
+ VecIdx = i;
}
- } else {
- SmallVector<EVT, 16> VTs;
- SmallVector<uint64_t, 16> Offsets;
- auto &DL = DAG.getDataLayout();
- ComputePTXValueVTs(*this, DL, retTy, VTs, &Offsets, 0);
- assert(VTs.size() == Ins.size() && "Bad value decomposition");
- unsigned RetAlign = getArgumentAlignment(Callee, CS, retTy, 0, DL);
- for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
- unsigned sz = VTs[i].getSizeInBits();
- unsigned AlignI = GreatestCommonDivisor64(RetAlign, Offsets[i]);
- bool needTruncate = false;
- if (VTs[i].isInteger() && sz < 8) {
- sz = 8;
- needTruncate = true;
+
+ LoadVTs.push_back(EltType);
+
+ if (VectorInfo[i] & PVF_LAST) {
+ unsigned NumElts = LoadVTs.size();
+ LoadVTs.push_back(MVT::Other);
+ LoadVTs.push_back(MVT::Glue);
+ NVPTXISD::NodeType Op;
+ switch (NumElts) {
+ case 1:
+ Op = NVPTXISD::LoadParam;
+ break;
+ case 2:
+ Op = NVPTXISD::LoadParamV2;
+ break;
+ case 4:
+ Op = NVPTXISD::LoadParamV4;
+ break;
+ default:
+ llvm_unreachable("Invalid vector info.");
}
- SmallVector<EVT, 4> LoadRetVTs;
- EVT TheLoadType = VTs[i];
- if (retTy->isIntegerTy() && DL.getTypeAllocSizeInBits(retTy) < 32) {
- // This is for integer types only, and specifically not for
- // aggregates.
- LoadRetVTs.push_back(MVT::i32);
- TheLoadType = MVT::i32;
- needTruncate = true;
- } else if (sz < 16) {
- // If loading i1/i8 result, generate
- // load i8 (-> i16)
- // trunc i16 to i1/i8
-
- // FIXME: Do we need to set needTruncate to true here, too? We could
- // not figure out what this branch is for in D17872, so we left it
- // alone. The comment above about loading i1/i8 may be wrong, as the
- // branch above seems to cover integers of size < 32.
- LoadRetVTs.push_back(MVT::i16);
- } else
- LoadRetVTs.push_back(Ins[i].VT);
- LoadRetVTs.push_back(MVT::Other);
- LoadRetVTs.push_back(MVT::Glue);
-
- SDValue LoadRetOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
- DAG.getConstant(Offsets[i], dl, MVT::i32),
- InFlag};
- SDValue retval = DAG.getMemIntrinsicNode(
- NVPTXISD::LoadParam, dl,
- DAG.getVTList(LoadRetVTs), LoadRetOps,
- TheLoadType, MachinePointerInfo(), AlignI);
- Chain = retval.getValue(1);
- InFlag = retval.getValue(2);
- SDValue Ret0 = retval.getValue(0);
- if (needTruncate)
- Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
- InVals.push_back(Ret0);
+ SDValue LoadOperands[] = {
+ Chain, DAG.getConstant(1, dl, MVT::i32),
+ DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InFlag};
+ SDValue RetVal = DAG.getMemIntrinsicNode(
+ Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType,
+ MachinePointerInfo(), EltAlign);
+
+ for (unsigned j = 0; j < NumElts; ++j) {
+ SDValue Ret = RetVal.getValue(j);
+ if (needTruncate)
+ Ret = DAG.getNode(ISD::TRUNCATE, dl, Ins[VecIdx + j].VT, Ret);
+ InVals.push_back(Ret);
+ }
+ Chain = RetVal.getValue(NumElts);
+ InFlag = RetVal.getValue(NumElts + 1);
+
+ // Cleanup
+ VecIdx = -1;
+ LoadVTs.clear();
}
}
}
@@ -1752,6 +1847,55 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
}
+// We can init constant f16x2 with a single .b32 move. Normally it
+// would get lowered as two constant loads and vector-packing move.
+// mov.b16 %h1, 0x4000;
+// mov.b16 %h2, 0x3C00;
+// mov.b32 %hh2, {%h2, %h1};
+// Instead we want just a constant move:
+// mov.b32 %hh2, 0x40003C00
+//
+// This results in better SASS code with CUDA 7.x. Ptxas in CUDA 8.0
+// generates good SASS in both cases.
+SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
+ SelectionDAG &DAG) const {
+ //return Op;
+ if (!(Op->getValueType(0) == MVT::v2f16 &&
+ isa<ConstantFPSDNode>(Op->getOperand(0)) &&
+ isa<ConstantFPSDNode>(Op->getOperand(1))))
+ return Op;
+
+ APInt E0 =
+ cast<ConstantFPSDNode>(Op->getOperand(0))->getValueAPF().bitcastToAPInt();
+ APInt E1 =
+ cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt();
+ SDValue Const =
+ DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
+ return DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const);
+}
+
+SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDValue Index = Op->getOperand(1);
+ // Constant index will be matched by tablegen.
+ if (isa<ConstantSDNode>(Index.getNode()))
+ return Op;
+
+ // Extract individual elements and select one of them.
+ SDValue Vector = Op->getOperand(0);
+ EVT VectorVT = Vector.getValueType();
+ assert(VectorVT == MVT::v2f16 && "Unexpected vector type.");
+ EVT EltVT = VectorVT.getVectorElementType();
+
+ SDLoc dl(Op.getNode());
+ SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
+ DAG.getIntPtrConstant(0, dl));
+ SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector,
+ DAG.getIntPtrConstant(1, dl));
+ return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1,
+ ISD::CondCode::SETEQ);
+}
+
/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
/// amount, or
@@ -1885,8 +2029,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::INTRINSIC_W_CHAIN:
return Op;
case ISD::BUILD_VECTOR:
+ return LowerBUILD_VECTOR(Op, DAG);
case ISD::EXTRACT_SUBVECTOR:
return Op;
+ case ISD::EXTRACT_VECTOR_ELT:
+ return LowerEXTRACT_VECTOR_ELT(Op, DAG);
case ISD::CONCAT_VECTORS:
return LowerCONCAT_VECTORS(Op, DAG);
case ISD::STORE:
@@ -1924,8 +2071,21 @@ SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const {
SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
if (Op.getValueType() == MVT::i1)
return LowerLOADi1(Op, DAG);
- else
- return SDValue();
+
+ // v2f16 is legal, so we can't rely on legalizer to handle unaligned
+ // loads and have to handle it here.
+ if (Op.getValueType() == MVT::v2f16) {
+ LoadSDNode *Load = cast<LoadSDNode>(Op);
+ EVT MemVT = Load->getMemoryVT();
+ if (!allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
+ Load->getAddressSpace(), Load->getAlignment())) {
+ SDValue Ops[2];
+ std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG);
+ return DAG.getMergeValues(Ops, SDLoc(Op));
+ }
+ }
+
+ return SDValue();
}
// v = ld i1* addr
@@ -1951,13 +2111,23 @@ SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
}
SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
- EVT ValVT = Op.getOperand(1).getValueType();
- if (ValVT == MVT::i1)
+ StoreSDNode *Store = cast<StoreSDNode>(Op);
+ EVT VT = Store->getMemoryVT();
+
+ if (VT == MVT::i1)
return LowerSTOREi1(Op, DAG);
- else if (ValVT.isVector())
+
+ // v2f16 is legal, so we can't rely on legalizer to handle unaligned
+ // stores and have to handle it here.
+ if (VT == MVT::v2f16 &&
+ !allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
+ Store->getAddressSpace(), Store->getAlignment()))
+ return expandUnalignedStore(Store, DAG);
+
+ if (VT.isVector())
return LowerSTOREVector(Op, DAG);
- else
- return SDValue();
+
+ return SDValue();
}
SDValue
@@ -1980,12 +2150,15 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
case MVT::v2i16:
case MVT::v2i32:
case MVT::v2i64:
+ case MVT::v2f16:
case MVT::v2f32:
case MVT::v2f64:
case MVT::v4i8:
case MVT::v4i16:
case MVT::v4i32:
+ case MVT::v4f16:
case MVT::v4f32:
+ case MVT::v8f16: // <4 x f16x2>
// This is a "native" vector type
break;
}
@@ -2016,6 +2189,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
if (EltVT.getSizeInBits() < 16)
NeedExt = true;
+ bool StoreF16x2 = false;
switch (NumElts) {
default:
return SDValue();
@@ -2025,6 +2199,14 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
case 4:
Opcode = NVPTXISD::StoreV4;
break;
+ case 8:
+ // v8f16 is a special case. PTX doesn't have st.v8.f16
+ // instruction. Instead, we split the vector into v2f16 chunks and
+ // store them with st.v4.b32.
+ assert(EltVT == MVT::f16 && "Wrong type for the vector.");
+ Opcode = NVPTXISD::StoreV4;
+ StoreF16x2 = true;
+ break;
}
SmallVector<SDValue, 8> Ops;
@@ -2032,23 +2214,36 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
// First is the chain
Ops.push_back(N->getOperand(0));
- // Then the split values
- for (unsigned i = 0; i < NumElts; ++i) {
- SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
- DAG.getIntPtrConstant(i, DL));
- if (NeedExt)
- ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
- Ops.push_back(ExtVal);
+ if (StoreF16x2) {
+ // Combine f16,f16 -> v2f16
+ NumElts /= 2;
+ for (unsigned i = 0; i < NumElts; ++i) {
+ SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+ DAG.getIntPtrConstant(i * 2, DL));
+ SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+ DAG.getIntPtrConstant(i * 2 + 1, DL));
+ SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1);
+ Ops.push_back(V2);
+ }
+ } else {
+ // Then the split values
+ for (unsigned i = 0; i < NumElts; ++i) {
+ SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
+ DAG.getIntPtrConstant(i, DL));
+ if (NeedExt)
+ ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
+ Ops.push_back(ExtVal);
+ }
}
// Then any remaining arguments
Ops.append(N->op_begin() + 2, N->op_end());
- SDValue NewSt = DAG.getMemIntrinsicNode(
- Opcode, DL, DAG.getVTList(MVT::Other), Ops,
- MemSD->getMemoryVT(), MemSD->getMemOperand());
+ SDValue NewSt =
+ DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+ MemSD->getMemoryVT(), MemSD->getMemOperand());
- //return DCI.CombineTo(N, NewSt, true);
+ // return DCI.CombineTo(N, NewSt, true);
return NewSt;
}
@@ -2120,7 +2315,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
auto PtrVT = getPointerTy(DAG.getDataLayout());
const Function *F = MF.getFunction();
- const AttributeSet &PAL = F->getAttributes();
+ const AttributeList &PAL = F->getAttributes();
const TargetLowering *TLI = STI.getTargetLowering();
SDValue Root = DAG.getRoot();
@@ -2200,177 +2395,80 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
// to newly created nodes. The SDNodes for params have to
// appear in the same order as their order of appearance
// in the original function. "idx+1" holds that order.
- if (!PAL.hasAttribute(i + 1, Attribute::ByVal)) {
- if (Ty->isAggregateType()) {
- SmallVector<EVT, 16> vtparts;
- SmallVector<uint64_t, 16> offsets;
+ if (!PAL.hasParamAttribute(i, Attribute::ByVal)) {
+ bool aggregateIsPacked = false;
+ if (StructType *STy = dyn_cast<StructType>(Ty))
+ aggregateIsPacked = STy->isPacked();
- // NOTE: Here, we lose the ability to issue vector loads for vectors
- // that are a part of a struct. This should be investigated in the
- // future.
- ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts, &offsets,
- 0);
- assert(vtparts.size() > 0 && "empty aggregate type not expected");
- bool aggregateIsPacked = false;
- if (StructType *STy = dyn_cast<StructType>(Ty))
- aggregateIsPacked = STy->isPacked();
+ SmallVector<EVT, 16> VTs;
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0);
+ assert(VTs.size() > 0 && "Unexpected empty type.");
+ auto VectorInfo =
+ VectorizePTXValueVTs(VTs, Offsets, DL.getABITypeAlignment(Ty));
- SDValue Arg = getParamSymbol(DAG, idx, PtrVT);
- for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
- ++parti) {
- EVT partVT = vtparts[parti];
- Value *srcValue = Constant::getNullValue(
- PointerType::get(partVT.getTypeForEVT(F->getContext()),
- ADDRESS_SPACE_PARAM));
- SDValue srcAddr =
- DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
- DAG.getConstant(offsets[parti], dl, PtrVT));
- unsigned partAlign = aggregateIsPacked
- ? 1
- : DL.getABITypeAlignment(
- partVT.getTypeForEVT(F->getContext()));
- SDValue p;
- if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits()) {
- ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ?
- ISD::SEXTLOAD : ISD::ZEXTLOAD;
- p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, srcAddr,
- MachinePointerInfo(srcValue), partVT, partAlign);
- } else {
- p = DAG.getLoad(partVT, dl, Root, srcAddr,
- MachinePointerInfo(srcValue), partAlign);
- }
- if (p.getNode())
- p.getNode()->setIROrder(idx + 1);
- InVals.push_back(p);
- ++InsIdx;
+ SDValue Arg = getParamSymbol(DAG, idx, PtrVT);
+ int VecIdx = -1; // Index of the first element of the current vector.
+ for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) {
+ if (VectorInfo[parti] & PVF_FIRST) {
+ assert(VecIdx == -1 && "Orphaned vector.");
+ VecIdx = parti;
}
- if (vtparts.size() > 0)
- --InsIdx;
- continue;
- }
- if (Ty->isVectorTy()) {
- EVT ObjectVT = getValueType(DL, Ty);
- SDValue Arg = getParamSymbol(DAG, idx, PtrVT);
- unsigned NumElts = ObjectVT.getVectorNumElements();
- assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
- "Vector was not scalarized");
- EVT EltVT = ObjectVT.getVectorElementType();
-
- // V1 load
- // f32 = load ...
- if (NumElts == 1) {
- // We only have one element, so just directly load it
- Value *SrcValue = Constant::getNullValue(PointerType::get(
- EltVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM));
- SDValue P = DAG.getLoad(
- EltVT, dl, Root, Arg, MachinePointerInfo(SrcValue),
- DL.getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())),
- MachineMemOperand::MODereferenceable |
- MachineMemOperand::MOInvariant);
- if (P.getNode())
- P.getNode()->setIROrder(idx + 1);
- if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
- P = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, P);
- InVals.push_back(P);
- ++InsIdx;
- } else if (NumElts == 2) {
- // V2 load
- // f32,f32 = load ...
- EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
- Value *SrcValue = Constant::getNullValue(PointerType::get(
- VecVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM));
- SDValue P = DAG.getLoad(
- VecVT, dl, Root, Arg, MachinePointerInfo(SrcValue),
- DL.getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())),
- MachineMemOperand::MODereferenceable |
- MachineMemOperand::MOInvariant);
+ // That's the last element of this store op.
+ if (VectorInfo[parti] & PVF_LAST) {
+ unsigned NumElts = parti - VecIdx + 1;
+ EVT EltVT = VTs[parti];
+ // i1 is loaded/stored as i8.
+ EVT LoadVT = EltVT;
+ if (EltVT == MVT::i1)
+ LoadVT = MVT::i8;
+ else if (EltVT == MVT::v2f16)
+ // getLoad needs a vector type, but it can't handle
+ // vectors which contain v2f16 elements. So we must load
+ // using i32 here and then bitcast back.
+ LoadVT = MVT::i32;
+
+ EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
+ SDValue VecAddr =
+ DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
+ DAG.getConstant(Offsets[VecIdx], dl, PtrVT));
+ Value *srcValue = Constant::getNullValue(PointerType::get(
+ EltVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM));
+ SDValue P =
+ DAG.getLoad(VecVT, dl, Root, VecAddr,
+ MachinePointerInfo(srcValue), aggregateIsPacked,
+ MachineMemOperand::MODereferenceable |
+ MachineMemOperand::MOInvariant);
if (P.getNode())
P.getNode()->setIROrder(idx + 1);
-
- SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
- DAG.getIntPtrConstant(0, dl));
- SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
- DAG.getIntPtrConstant(1, dl));
-
- if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) {
- Elt0 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt0);
- Elt1 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt1);
- }
-
- InVals.push_back(Elt0);
- InVals.push_back(Elt1);
- InsIdx += 2;
- } else {
- // V4 loads
- // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
- // the vector will be expanded to a power of 2 elements, so we know we
- // can always round up to the next multiple of 4 when creating the
- // vector loads.
- // e.g. 4 elem => 1 ld.v4
- // 6 elem => 2 ld.v4
- // 8 elem => 2 ld.v4
- // 11 elem => 3 ld.v4
- unsigned VecSize = 4;
- if (EltVT.getSizeInBits() == 64) {
- VecSize = 2;
- }
- EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
- unsigned Ofst = 0;
- for (unsigned i = 0; i < NumElts; i += VecSize) {
- Value *SrcValue = Constant::getNullValue(
- PointerType::get(VecVT.getTypeForEVT(F->getContext()),
- ADDRESS_SPACE_PARAM));
- SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, Arg,
- DAG.getConstant(Ofst, dl, PtrVT));
- SDValue P = DAG.getLoad(
- VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue),
- DL.getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())),
- MachineMemOperand::MODereferenceable |
- MachineMemOperand::MOInvariant);
- if (P.getNode())
- P.getNode()->setIROrder(idx + 1);
-
- for (unsigned j = 0; j < VecSize; ++j) {
- if (i + j >= NumElts)
- break;
- SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
- DAG.getIntPtrConstant(j, dl));
- if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
- Elt = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt);
- InVals.push_back(Elt);
+ for (unsigned j = 0; j < NumElts; ++j) {
+ SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
+ DAG.getIntPtrConstant(j, dl));
+ // We've loaded i1 as an i8 and now must truncate it back to i1
+ if (EltVT == MVT::i1)
+ Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
+ // v2f16 was loaded as an i32. Now we must bitcast it back.
+ else if (EltVT == MVT::v2f16)
+ Elt = DAG.getNode(ISD::BITCAST, dl, MVT::v2f16, Elt);
+ // Extend the element if necesary (e.g. an i8 is loaded
+ // into an i16 register)
+ if (Ins[InsIdx].VT.isInteger() &&
+ Ins[InsIdx].VT.getSizeInBits() > LoadVT.getSizeInBits()) {
+ unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND
+ : ISD::ZERO_EXTEND;
+ Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt);
}
- Ofst += DL.getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
+ InVals.push_back(Elt);
}
- InsIdx += NumElts;
- }
- if (NumElts > 0)
- --InsIdx;
- continue;
- }
- // A plain scalar.
- EVT ObjectVT = getValueType(DL, Ty);
- // If ABI, load from the param symbol
- SDValue Arg = getParamSymbol(DAG, idx, PtrVT);
- Value *srcValue = Constant::getNullValue(PointerType::get(
- ObjectVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM));
- SDValue p;
- if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits()) {
- ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ?
- ISD::SEXTLOAD : ISD::ZEXTLOAD;
- p = DAG.getExtLoad(
- ExtOp, dl, Ins[InsIdx].VT, Root, Arg, MachinePointerInfo(srcValue),
- ObjectVT,
- DL.getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
- } else {
- p = DAG.getLoad(
- Ins[InsIdx].VT, dl, Root, Arg, MachinePointerInfo(srcValue),
- DL.getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext())));
+ // Reset vector tracking state.
+ VecIdx = -1;
+ }
+ ++InsIdx;
}
- if (p.getNode())
- p.getNode()->setIROrder(idx + 1);
- InVals.push_back(p);
+ if (VTs.size() > 0)
+ --InsIdx;
continue;
}
@@ -2412,164 +2510,77 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
const SmallVectorImpl<SDValue> &OutVals,
const SDLoc &dl, SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
- const Function *F = MF.getFunction();
- Type *RetTy = F->getReturnType();
- const DataLayout &TD = DAG.getDataLayout();
+ Type *RetTy = MF.getFunction()->getReturnType();
bool isABI = (STI.getSmVersion() >= 20);
assert(isABI && "Non-ABI compilation is not supported");
if (!isABI)
return Chain;
- if (VectorType *VTy = dyn_cast<VectorType>(RetTy)) {
- // If we have a vector type, the OutVals array will be the scalarized
- // components and we have combine them into 1 or more vector stores.
- unsigned NumElts = VTy->getNumElements();
- assert(NumElts == Outs.size() && "Bad scalarization of return value");
+ const DataLayout DL = DAG.getDataLayout();
+ SmallVector<EVT, 16> VTs;
+ SmallVector<uint64_t, 16> Offsets;
+ ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets);
+ assert(VTs.size() == OutVals.size() && "Bad return value decomposition");
+
+ auto VectorInfo = VectorizePTXValueVTs(
+ VTs, Offsets, RetTy->isSized() ? DL.getABITypeAlignment(RetTy) : 1);
+
+ // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than
+ // 32-bits are sign extended or zero extended, depending on whether
+ // they are signed or unsigned types.
+ bool ExtendIntegerRetVal =
+ RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32;
+
+ SmallVector<SDValue, 6> StoreOperands;
+ for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
+ // New load/store. Record chain and offset operands.
+ if (VectorInfo[i] & PVF_FIRST) {
+ assert(StoreOperands.empty() && "Orphaned operand list.");
+ StoreOperands.push_back(Chain);
+ StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32));
+ }
- // const_cast can be removed in later LLVM versions
- EVT EltVT = getValueType(TD, RetTy).getVectorElementType();
- bool NeedExtend = false;
- if (EltVT.getSizeInBits() < 16)
- NeedExtend = true;
-
- // V1 store
- if (NumElts == 1) {
- SDValue StoreVal = OutVals[0];
- // We only have one element, so just directly store it
- if (NeedExtend)
- StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
- SDValue Ops[] = { Chain, DAG.getConstant(0, dl, MVT::i32), StoreVal };
- Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
- DAG.getVTList(MVT::Other), Ops,
- EltVT, MachinePointerInfo());
- } else if (NumElts == 2) {
- // V2 store
- SDValue StoreVal0 = OutVals[0];
- SDValue StoreVal1 = OutVals[1];
-
- if (NeedExtend) {
- StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0);
- StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1);
- }
+ SDValue RetVal = OutVals[i];
+ if (ExtendIntegerRetVal) {
+ RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND
+ : ISD::ZERO_EXTEND,
+ dl, MVT::i32, RetVal);
+ } else if (RetVal.getValueSizeInBits() < 16) {
+ // Use 16-bit registers for small load-stores as it's the
+ // smallest general purpose register size supported by NVPTX.
+ RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal);
+ }
- SDValue Ops[] = { Chain, DAG.getConstant(0, dl, MVT::i32), StoreVal0,
- StoreVal1 };
- Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl,
- DAG.getVTList(MVT::Other), Ops,
- EltVT, MachinePointerInfo());
- } else {
- // V4 stores
- // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
- // vector will be expanded to a power of 2 elements, so we know we can
- // always round up to the next multiple of 4 when creating the vector
- // stores.
- // e.g. 4 elem => 1 st.v4
- // 6 elem => 2 st.v4
- // 8 elem => 2 st.v4
- // 11 elem => 3 st.v4
-
- unsigned VecSize = 4;
- if (OutVals[0].getValueSizeInBits() == 64)
- VecSize = 2;
-
- unsigned Offset = 0;
-
- EVT VecVT =
- EVT::getVectorVT(F->getContext(), EltVT, VecSize);
- unsigned PerStoreOffset =
- TD.getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
-
- for (unsigned i = 0; i < NumElts; i += VecSize) {
- // Get values
- SDValue StoreVal;
- SmallVector<SDValue, 8> Ops;
- Ops.push_back(Chain);
- Ops.push_back(DAG.getConstant(Offset, dl, MVT::i32));
- unsigned Opc = NVPTXISD::StoreRetvalV2;
- EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType();
-
- StoreVal = OutVals[i];
- if (NeedExtend)
- StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
- Ops.push_back(StoreVal);
-
- if (i + 1 < NumElts) {
- StoreVal = OutVals[i + 1];
- if (NeedExtend)
- StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
- } else {
- StoreVal = DAG.getUNDEF(ExtendedVT);
- }
- Ops.push_back(StoreVal);
-
- if (VecSize == 4) {
- Opc = NVPTXISD::StoreRetvalV4;
- if (i + 2 < NumElts) {
- StoreVal = OutVals[i + 2];
- if (NeedExtend)
- StoreVal =
- DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
- } else {
- StoreVal = DAG.getUNDEF(ExtendedVT);
- }
- Ops.push_back(StoreVal);
-
- if (i + 3 < NumElts) {
- StoreVal = OutVals[i + 3];
- if (NeedExtend)
- StoreVal =
- DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
- } else {
- StoreVal = DAG.getUNDEF(ExtendedVT);
- }
- Ops.push_back(StoreVal);
- }
+ // Record the value to return.
+ StoreOperands.push_back(RetVal);
- // Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
- Chain =
- DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), Ops,
- EltVT, MachinePointerInfo());
- Offset += PerStoreOffset;
- }
- }
- } else {
- SmallVector<EVT, 16> ValVTs;
- SmallVector<uint64_t, 16> Offsets;
- ComputePTXValueVTs(*this, DAG.getDataLayout(), RetTy, ValVTs, &Offsets, 0);
- assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition");
-
- for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
- SDValue theVal = OutVals[i];
- EVT TheValType = theVal.getValueType();
- unsigned numElems = 1;
- if (TheValType.isVector())
- numElems = TheValType.getVectorNumElements();
- for (unsigned j = 0, je = numElems; j != je; ++j) {
- SDValue TmpVal = theVal;
- if (TheValType.isVector())
- TmpVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
- TheValType.getVectorElementType(), TmpVal,
- DAG.getIntPtrConstant(j, dl));
- EVT TheStoreType = ValVTs[i];
- if (RetTy->isIntegerTy() && TD.getTypeAllocSizeInBits(RetTy) < 32) {
- // The following zero-extension is for integer types only, and
- // specifically not for aggregates.
- TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal);
- TheStoreType = MVT::i32;
- }
- else if (TmpVal.getValueSizeInBits() < 16)
- TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal);
-
- SDValue Ops[] = {
- Chain,
- DAG.getConstant(Offsets[i], dl, MVT::i32),
- TmpVal };
- Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
- DAG.getVTList(MVT::Other), Ops,
- TheStoreType,
- MachinePointerInfo());
+ // That's the last element of this store op.
+ if (VectorInfo[i] & PVF_LAST) {
+ NVPTXISD::NodeType Op;
+ unsigned NumElts = StoreOperands.size() - 2;
+ switch (NumElts) {
+ case 1:
+ Op = NVPTXISD::StoreRetval;
+ break;
+ case 2:
+ Op = NVPTXISD::StoreRetvalV2;
+ break;
+ case 4:
+ Op = NVPTXISD::StoreRetvalV4;
+ break;
+ default:
+ llvm_unreachable("Invalid vector info.");
}
+
+ // Adjust type of load/store op if we've extended the scalar
+ // return value.
+ EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[i];
+ Chain = DAG.getMemIntrinsicNode(Op, dl, DAG.getVTList(MVT::Other),
+ StoreOperands, TheStoreType,
+ MachinePointerInfo(), 1);
+ // Cleanup vector state.
+ StoreOperands.clear();
}
}
@@ -3863,27 +3874,35 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
bool NVPTXTargetLowering::allowFMA(MachineFunction &MF,
CodeGenOpt::Level OptLevel) const {
- const Function *F = MF.getFunction();
- const TargetOptions &TO = MF.getTarget().Options;
-
// Always honor command-line argument
- if (FMAContractLevelOpt.getNumOccurrences() > 0) {
+ if (FMAContractLevelOpt.getNumOccurrences() > 0)
return FMAContractLevelOpt > 0;
- } else if (OptLevel == 0) {
- // Do not contract if we're not optimizing the code
+
+ // Do not contract if we're not optimizing the code.
+ if (OptLevel == 0)
return false;
- } else if (TO.AllowFPOpFusion == FPOpFusion::Fast || TO.UnsafeFPMath) {
- // Honor TargetOptions flags that explicitly say fusion is okay
+
+ // Honor TargetOptions flags that explicitly say fusion is okay.
+ if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast)
return true;
- } else if (F->hasFnAttribute("unsafe-fp-math")) {
- // Check for unsafe-fp-math=true coming from Clang
+
+ return allowUnsafeFPMath(MF);
+}
+
+bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const {
+ // Honor TargetOptions flags that explicitly say unsafe math is okay.
+ if (MF.getTarget().Options.UnsafeFPMath)
+ return true;
+
+ // Allow unsafe math if unsafe-fp-math attribute explicitly says so.
+ const Function *F = MF.getFunction();
+ if (F->hasFnAttribute("unsafe-fp-math")) {
Attribute Attr = F->getFnAttribute("unsafe-fp-math");
StringRef Val = Attr.getValueAsString();
if (Val == "true")
return true;
}
- // We did not have a clear indication that fusion is allowed, so assume not
return false;
}
@@ -4088,67 +4107,6 @@ static SDValue PerformANDCombine(SDNode *N,
return SDValue();
}
-static SDValue PerformSELECTCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI) {
- // Currently this detects patterns for integer min and max and
- // lowers them to PTX-specific intrinsics that enable hardware
- // support.
-
- const SDValue Cond = N->getOperand(0);
- if (Cond.getOpcode() != ISD::SETCC) return SDValue();
-
- const SDValue LHS = Cond.getOperand(0);
- const SDValue RHS = Cond.getOperand(1);
- const SDValue True = N->getOperand(1);
- const SDValue False = N->getOperand(2);
- if (!(LHS == True && RHS == False) && !(LHS == False && RHS == True))
- return SDValue();
-
- const EVT VT = N->getValueType(0);
- if (VT != MVT::i32 && VT != MVT::i64) return SDValue();
-
- const ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
- SDValue Larger; // The larger of LHS and RHS when condition is true.
- switch (CC) {
- case ISD::SETULT:
- case ISD::SETULE:
- case ISD::SETLT:
- case ISD::SETLE:
- Larger = RHS;
- break;
-
- case ISD::SETGT:
- case ISD::SETGE:
- case ISD::SETUGT:
- case ISD::SETUGE:
- Larger = LHS;
- break;
-
- default:
- return SDValue();
- }
- const bool IsMax = (Larger == True);
- const bool IsSigned = ISD::isSignedIntSetCC(CC);
-
- unsigned IntrinsicId;
- if (VT == MVT::i32) {
- if (IsSigned)
- IntrinsicId = IsMax ? Intrinsic::nvvm_max_i : Intrinsic::nvvm_min_i;
- else
- IntrinsicId = IsMax ? Intrinsic::nvvm_max_ui : Intrinsic::nvvm_min_ui;
- } else {
- assert(VT == MVT::i64);
- if (IsSigned)
- IntrinsicId = IsMax ? Intrinsic::nvvm_max_ll : Intrinsic::nvvm_min_ll;
- else
- IntrinsicId = IsMax ? Intrinsic::nvvm_max_ull : Intrinsic::nvvm_min_ull;
- }
-
- SDLoc DL(N);
- return DCI.DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
- DCI.DAG.getConstant(IntrinsicId, DL, VT), LHS, RHS);
-}
-
static SDValue PerformREMCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOpt::Level OptLevel) {
@@ -4344,6 +4302,27 @@ static SDValue PerformSHLCombine(SDNode *N,
return SDValue();
}
+static SDValue PerformSETCCCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ EVT CCType = N->getValueType(0);
+ SDValue A = N->getOperand(0);
+ SDValue B = N->getOperand(1);
+
+ if (CCType != MVT::v2i1 || A.getValueType() != MVT::v2f16)
+ return SDValue();
+
+ SDLoc DL(N);
+ // setp.f16x2 returns two scalar predicates, which we need to
+ // convert back to v2i1. The returned result will be scalarized by
+ // the legalizer, but the comparison will remain a single vector
+ // instruction.
+ SDValue CCNode = DCI.DAG.getNode(NVPTXISD::SETP_F16X2, DL,
+ DCI.DAG.getVTList(MVT::i1, MVT::i1),
+ {A, B, N->getOperand(2)});
+ return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, CCType, CCNode.getValue(0),
+ CCNode.getValue(1));
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOpt::Level OptLevel = getTargetMachine().getOptLevel();
@@ -4358,11 +4337,11 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformSHLCombine(N, DCI, OptLevel);
case ISD::AND:
return PerformANDCombine(N, DCI);
- case ISD::SELECT:
- return PerformSELECTCombine(N, DCI);
case ISD::UREM:
case ISD::SREM:
return PerformREMCombine(N, DCI, OptLevel);
+ case ISD::SETCC:
+ return PerformSETCCCombine(N, DCI);
}
return SDValue();
}
@@ -4386,12 +4365,15 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
case MVT::v2i16:
case MVT::v2i32:
case MVT::v2i64:
+ case MVT::v2f16:
case MVT::v2f32:
case MVT::v2f64:
case MVT::v4i8:
case MVT::v4i16:
case MVT::v4i32:
+ case MVT::v4f16:
case MVT::v4f32:
+ case MVT::v8f16: // <4 x f16x2>
// This is a "native" vector type
break;
}
@@ -4425,6 +4407,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
unsigned Opcode = 0;
SDVTList LdResVTs;
+ bool LoadF16x2 = false;
switch (NumElts) {
default:
@@ -4439,6 +4422,18 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
LdResVTs = DAG.getVTList(ListVTs);
break;
}
+ case 8: {
+ // v8f16 is a special case. PTX doesn't have ld.v8.f16
+ // instruction. Instead, we split the vector into v2f16 chunks and
+ // load them with ld.v4.b32.
+ assert(EltVT == MVT::f16 && "Unsupported v8 vector type.");
+ LoadF16x2 = true;
+ Opcode = NVPTXISD::LoadV4;
+ EVT ListVTs[] = {MVT::v2f16, MVT::v2f16, MVT::v2f16, MVT::v2f16,
+ MVT::Other};
+ LdResVTs = DAG.getVTList(ListVTs);
+ break;
+ }
}
// Copy regular operands
@@ -4452,13 +4447,26 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
LD->getMemoryVT(),
LD->getMemOperand());
- SmallVector<SDValue, 4> ScalarRes;
-
- for (unsigned i = 0; i < NumElts; ++i) {
- SDValue Res = NewLD.getValue(i);
- if (NeedTrunc)
- Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
- ScalarRes.push_back(Res);
+ SmallVector<SDValue, 8> ScalarRes;
+ if (LoadF16x2) {
+ // Split v2f16 subvectors back into individual elements.
+ NumElts /= 2;
+ for (unsigned i = 0; i < NumElts; ++i) {
+ SDValue SubVector = NewLD.getValue(i);
+ SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SubVector,
+ DAG.getIntPtrConstant(0, DL));
+ SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SubVector,
+ DAG.getIntPtrConstant(1, DL));
+ ScalarRes.push_back(E0);
+ ScalarRes.push_back(E1);
+ }
+ } else {
+ for (unsigned i = 0; i < NumElts; ++i) {
+ SDValue Res = NewLD.getValue(i);
+ if (NeedTrunc)
+ Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
+ ScalarRes.push_back(Res);
+ }
}
SDValue LoadChain = NewLD.getValue(NumElts);