Diffstat (limited to 'lib/Target/NVPTX/NVPTXISelLowering.cpp')
-rw-r--r-- | lib/Target/NVPTX/NVPTXISelLowering.cpp | 1694
1 file changed, 851 insertions, 843 deletions
diff --git a/lib/Target/NVPTX/NVPTXISelLowering.cpp b/lib/Target/NVPTX/NVPTXISelLowering.cpp index 7a760fd38d0f..4d06912054a2 100644 --- a/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -79,6 +79,60 @@ FMAContractLevelOpt("nvptx-fma-level", cl::ZeroOrMore, cl::Hidden, " 1: do it 2: do it aggressively"), cl::init(2)); +static cl::opt<int> UsePrecDivF32( + "nvptx-prec-divf32", cl::ZeroOrMore, cl::Hidden, + cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use" + " IEEE Compliant F32 div.rnd if available."), + cl::init(2)); + +static cl::opt<bool> UsePrecSqrtF32( + "nvptx-prec-sqrtf32", cl::Hidden, + cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."), + cl::init(true)); + +static cl::opt<bool> FtzEnabled( + "nvptx-f32ftz", cl::ZeroOrMore, cl::Hidden, + cl::desc("NVPTX Specific: Flush f32 subnormals to sign-preserving zero."), + cl::init(false)); + +int NVPTXTargetLowering::getDivF32Level() const { + if (UsePrecDivF32.getNumOccurrences() > 0) { + // If nvptx-prec-div32=N is used on the command-line, always honor it + return UsePrecDivF32; + } else { + // Otherwise, use div.approx if fast math is enabled + if (getTargetMachine().Options.UnsafeFPMath) + return 0; + else + return 2; + } +} + +bool NVPTXTargetLowering::usePrecSqrtF32() const { + if (UsePrecSqrtF32.getNumOccurrences() > 0) { + // If nvptx-prec-sqrtf32 is used on the command-line, always honor it + return UsePrecSqrtF32; + } else { + // Otherwise, use sqrt.approx if fast math is enabled + return !getTargetMachine().Options.UnsafeFPMath; + } +} + +bool NVPTXTargetLowering::useF32FTZ(const MachineFunction &MF) const { + // TODO: Get rid of this flag; there can be only one way to do this. + if (FtzEnabled.getNumOccurrences() > 0) { + // If nvptx-f32ftz is used on the command-line, always honor it + return FtzEnabled; + } else { + const Function *F = MF.getFunction(); + // Otherwise, check for an nvptx-f32ftz attribute on the function + if (F->hasFnAttribute("nvptx-f32ftz")) + return F->getFnAttribute("nvptx-f32ftz").getValueAsString() == "true"; + else + return false; + } +} + static bool IsPTXVectorType(MVT VT) { switch (VT.SimpleTy) { default: @@ -92,6 +146,9 @@ static bool IsPTXVectorType(MVT VT) { case MVT::v2i32: case MVT::v4i32: case MVT::v2i64: + case MVT::v2f16: + case MVT::v4f16: + case MVT::v8f16: // <4 x f16x2> case MVT::v2f32: case MVT::v4f32: case MVT::v2f64: @@ -116,13 +173,24 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL, for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) { EVT VT = TempVTs[i]; uint64_t Off = TempOffsets[i]; - if (VT.isVector()) - for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) { - ValueVTs.push_back(VT.getVectorElementType()); + // Split vectors into individual elements, except for v2f16, which + // we will pass as a single scalar. + if (VT.isVector()) { + unsigned NumElts = VT.getVectorNumElements(); + EVT EltVT = VT.getVectorElementType(); + // Vectors with an even number of f16 elements will be passed to + // us as an array of v2f16 elements. We must match this so we + // stay in sync with Ins/Outs. 
+ if (EltVT == MVT::f16 && NumElts % 2 == 0) { + EltVT = MVT::v2f16; + NumElts /= 2; + } + for (unsigned j = 0; j != NumElts; ++j) { + ValueVTs.push_back(EltVT); if (Offsets) - Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize()); + Offsets->push_back(Off + j * EltVT.getStoreSize()); } - else { + } else { ValueVTs.push_back(VT); if (Offsets) Offsets->push_back(Off); @@ -130,6 +198,125 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL, } } +// Check whether we can merge loads/stores of some of the pieces of a +// flattened function parameter or return value into a single vector +// load/store. +// +// The flattened parameter is represented as a list of EVTs and +// offsets, and the whole structure is aligned to ParamAlignment. This +// function determines whether we can load/store pieces of the +// parameter starting at index Idx using a single vectorized op of +// size AccessSize. If so, it returns the number of param pieces +// covered by the vector op. Otherwise, it returns 1. +static unsigned CanMergeParamLoadStoresStartingAt( + unsigned Idx, uint32_t AccessSize, const SmallVectorImpl<EVT> &ValueVTs, + const SmallVectorImpl<uint64_t> &Offsets, unsigned ParamAlignment) { + assert(isPowerOf2_32(AccessSize) && "must be a power of 2!"); + + // Can't vectorize if param alignment is not sufficient. + if (AccessSize > ParamAlignment) + return 1; + // Can't vectorize if offset is not aligned. + if (Offsets[Idx] & (AccessSize - 1)) + return 1; + + EVT EltVT = ValueVTs[Idx]; + unsigned EltSize = EltVT.getStoreSize(); + + // Element is too large to vectorize. + if (EltSize >= AccessSize) + return 1; + + unsigned NumElts = AccessSize / EltSize; + // Can't vectorize if AccessBytes if not a multiple of EltSize. + if (AccessSize != EltSize * NumElts) + return 1; + + // We don't have enough elements to vectorize. + if (Idx + NumElts > ValueVTs.size()) + return 1; + + // PTX ISA can only deal with 2- and 4-element vector ops. + if (NumElts != 4 && NumElts != 2) + return 1; + + for (unsigned j = Idx + 1; j < Idx + NumElts; ++j) { + // Types do not match. + if (ValueVTs[j] != EltVT) + return 1; + + // Elements are not contiguous. + if (Offsets[j] - Offsets[j - 1] != EltSize) + return 1; + } + // OK. We can vectorize ValueVTs[i..i+NumElts) + return NumElts; +} + +// Flags for tracking per-element vectorization state of loads/stores +// of a flattened function parameter or return value. +enum ParamVectorizationFlags { + PVF_INNER = 0x0, // Middle elements of a vector. + PVF_FIRST = 0x1, // First element of the vector. + PVF_LAST = 0x2, // Last element of the vector. + // Scalar is effectively a 1-element vector. + PVF_SCALAR = PVF_FIRST | PVF_LAST +}; + +// Computes whether and how we can vectorize the loads/stores of a +// flattened function parameter or return value. +// +// The flattened parameter is represented as the list of ValueVTs and +// Offsets, and is aligned to ParamAlignment bytes. We return a vector +// of the same size as ValueVTs indicating how each piece should be +// loaded/stored (i.e. as a scalar, or as part of a vector +// load/store). +static SmallVector<ParamVectorizationFlags, 16> +VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs, + const SmallVectorImpl<uint64_t> &Offsets, + unsigned ParamAlignment) { + // Set vector size to match ValueVTs and mark all elements as + // scalars by default. 
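// [Editorial worked example -- not part of this patch; the types, offsets and
//  alignment below are invented purely for illustration.]
// Suppose ComputePTXValueVTs has flattened an argument into
//   ValueVTs = {f32, f32, f32, f32}, Offsets = {0, 4, 8, 12}
// and the parameter is aligned to ParamAlignment = 16. Then
// CanMergeParamLoadStoresStartingAt(0, /*AccessSize=*/16, ...) returns 4, the
// flags computed below come out as {PVF_FIRST, PVF_INNER, PVF_INNER, PVF_LAST},
// and the callers emit one 4-element vector op (e.g. StoreParamV4) instead of
// four scalar ops. If ParamAlignment were only 8, the 16-byte access would be
// rejected and two 8-byte accesses would succeed instead, giving
//   {PVF_FIRST, PVF_LAST, PVF_FIRST, PVF_LAST}   (two v2 ops).
// Similarly, a <4 x half> argument is flattened by ComputePTXValueVTs above
// into ValueVTs = {v2f16, v2f16} with Offsets = {0, 4}, so only an 8-byte (v2)
// access is possible, packing both v2f16 pieces into a single v2 op.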
+ SmallVector<ParamVectorizationFlags, 16> VectorInfo; + VectorInfo.assign(ValueVTs.size(), PVF_SCALAR); + + // Check what we can vectorize using 128/64/32-bit accesses. + for (int I = 0, E = ValueVTs.size(); I != E; ++I) { + // Skip elements we've already processed. + assert(VectorInfo[I] == PVF_SCALAR && "Unexpected vector info state."); + for (unsigned AccessSize : {16, 8, 4, 2}) { + unsigned NumElts = CanMergeParamLoadStoresStartingAt( + I, AccessSize, ValueVTs, Offsets, ParamAlignment); + // Mark vectorized elements. + switch (NumElts) { + default: + llvm_unreachable("Unexpected return value"); + case 1: + // Can't vectorize using this size, try next smaller size. + continue; + case 2: + assert(I + 1 < E && "Not enough elements."); + VectorInfo[I] = PVF_FIRST; + VectorInfo[I + 1] = PVF_LAST; + I += 1; + break; + case 4: + assert(I + 3 < E && "Not enough elements."); + VectorInfo[I] = PVF_FIRST; + VectorInfo[I + 1] = PVF_INNER; + VectorInfo[I + 2] = PVF_INNER; + VectorInfo[I + 3] = PVF_LAST; + I += 3; + break; + } + // Break out of the inner loop because we've already succeeded + // using largest possible AccessSize. + break; + } + } + return VectorInfo; +} + // NVPTXTargetLowering Constructor. NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, const NVPTXSubtarget &STI) @@ -158,14 +345,32 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, else setSchedulingPreference(Sched::Source); + auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action, + LegalizeAction NoF16Action) { + setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action); + }; + addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass); addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass); addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass); addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass); addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass); addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass); + addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass); + addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass); + + // Conversion to/from FP16/FP16x2 is always legal. + setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal); + setOperationAction(ISD::FP_TO_SINT, MVT::f16, Legal); + setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f16, Custom); + + setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote); + setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand); // Operations not directly supported by NVPTX. 
+ setOperationAction(ISD::SELECT_CC, MVT::f16, Expand); + setOperationAction(ISD::SELECT_CC, MVT::v2f16, Expand); setOperationAction(ISD::SELECT_CC, MVT::f32, Expand); setOperationAction(ISD::SELECT_CC, MVT::f64, Expand); setOperationAction(ISD::SELECT_CC, MVT::i1, Expand); @@ -173,6 +378,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction(ISD::SELECT_CC, MVT::i16, Expand); setOperationAction(ISD::SELECT_CC, MVT::i32, Expand); setOperationAction(ISD::SELECT_CC, MVT::i64, Expand); + setOperationAction(ISD::BR_CC, MVT::f16, Expand); + setOperationAction(ISD::BR_CC, MVT::v2f16, Expand); setOperationAction(ISD::BR_CC, MVT::f32, Expand); setOperationAction(ISD::BR_CC, MVT::f64, Expand); setOperationAction(ISD::BR_CC, MVT::i1, Expand); @@ -195,6 +402,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction(ISD::SRA_PARTS, MVT::i64 , Custom); setOperationAction(ISD::SRL_PARTS, MVT::i64 , Custom); + setOperationAction(ISD::BITREVERSE, MVT::i32, Legal); + setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); + if (STI.hasROT64()) { setOperationAction(ISD::ROTL, MVT::i64, Legal); setOperationAction(ISD::ROTR, MVT::i64, Legal); @@ -259,6 +469,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // This is legal in NVPTX setOperationAction(ISD::ConstantFP, MVT::f64, Legal); setOperationAction(ISD::ConstantFP, MVT::f32, Legal); + setOperationAction(ISD::ConstantFP, MVT::f16, Legal); // TRAP can be lowered to PTX trap setOperationAction(ISD::TRAP, MVT::Other, Legal); @@ -278,15 +489,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, // Custom handling for i8 intrinsics setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom); - setOperationAction(ISD::CTLZ, MVT::i16, Legal); - setOperationAction(ISD::CTLZ, MVT::i32, Legal); - setOperationAction(ISD::CTLZ, MVT::i64, Legal); + for (const auto& Ty : {MVT::i16, MVT::i32, MVT::i64}) { + setOperationAction(ISD::SMIN, Ty, Legal); + setOperationAction(ISD::SMAX, Ty, Legal); + setOperationAction(ISD::UMIN, Ty, Legal); + setOperationAction(ISD::UMAX, Ty, Legal); + + setOperationAction(ISD::CTPOP, Ty, Legal); + setOperationAction(ISD::CTLZ, Ty, Legal); + } + setOperationAction(ISD::CTTZ, MVT::i16, Expand); setOperationAction(ISD::CTTZ, MVT::i32, Expand); setOperationAction(ISD::CTTZ, MVT::i64, Expand); - setOperationAction(ISD::CTPOP, MVT::i16, Legal); - setOperationAction(ISD::CTPOP, MVT::i32, Legal); - setOperationAction(ISD::CTPOP, MVT::i64, Legal); // PTX does not directly support SELP of i1, so promote to i32 first setOperationAction(ISD::SELECT, MVT::i1, Custom); @@ -301,28 +516,60 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setTargetDAGCombine(ISD::FADD); setTargetDAGCombine(ISD::MUL); setTargetDAGCombine(ISD::SHL); - setTargetDAGCombine(ISD::SELECT); setTargetDAGCombine(ISD::SREM); setTargetDAGCombine(ISD::UREM); - // Library functions. These default to Expand, but we have instructions - // for them. 
- setOperationAction(ISD::FCEIL, MVT::f32, Legal); - setOperationAction(ISD::FCEIL, MVT::f64, Legal); - setOperationAction(ISD::FFLOOR, MVT::f32, Legal); - setOperationAction(ISD::FFLOOR, MVT::f64, Legal); - setOperationAction(ISD::FNEARBYINT, MVT::f32, Legal); - setOperationAction(ISD::FNEARBYINT, MVT::f64, Legal); - setOperationAction(ISD::FRINT, MVT::f32, Legal); - setOperationAction(ISD::FRINT, MVT::f64, Legal); - setOperationAction(ISD::FROUND, MVT::f32, Legal); - setOperationAction(ISD::FROUND, MVT::f64, Legal); - setOperationAction(ISD::FTRUNC, MVT::f32, Legal); - setOperationAction(ISD::FTRUNC, MVT::f64, Legal); - setOperationAction(ISD::FMINNUM, MVT::f32, Legal); - setOperationAction(ISD::FMINNUM, MVT::f64, Legal); - setOperationAction(ISD::FMAXNUM, MVT::f32, Legal); - setOperationAction(ISD::FMAXNUM, MVT::f64, Legal); + // setcc for f16x2 needs special handling to prevent legalizer's + // attempt to scalarize it due to v2i1 not being legal. + if (STI.allowFP16Math()) + setTargetDAGCombine(ISD::SETCC); + + // Promote fp16 arithmetic if fp16 hardware isn't available or the + // user passed --nvptx-no-fp16-math. The flag is useful because, + // although sm_53+ GPUs have some sort of FP16 support in + // hardware, only sm_53 and sm_60 have full implementation. Others + // only have token amount of hardware and are likely to run faster + // by using fp32 units instead. + for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) { + setFP16OperationAction(Op, MVT::f16, Legal, Promote); + setFP16OperationAction(Op, MVT::v2f16, Legal, Expand); + } + + // There's no neg.f16 instruction. Expand to (0-x). + setOperationAction(ISD::FNEG, MVT::f16, Expand); + setOperationAction(ISD::FNEG, MVT::v2f16, Expand); + + // (would be) Library functions. + + // These map to conversion instructions for scalar FP types. + for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT, + ISD::FROUND, ISD::FTRUNC}) { + setOperationAction(Op, MVT::f16, Legal); + setOperationAction(Op, MVT::f32, Legal); + setOperationAction(Op, MVT::f64, Legal); + setOperationAction(Op, MVT::v2f16, Expand); + } + + // 'Expand' implements FCOPYSIGN without calling an external library. + setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand); + setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand); + setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand); + setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand); + + // These map to corresponding instructions for f32/f64. f16 must be + // promoted to f32. v2f16 is expanded to f16, which is then promoted + // to f32. + for (const auto &Op : {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, + ISD::FABS, ISD::FMINNUM, ISD::FMAXNUM}) { + setOperationAction(Op, MVT::f16, Promote); + setOperationAction(Op, MVT::f32, Legal); + setOperationAction(Op, MVT::f64, Legal); + setOperationAction(Op, MVT::v2f16, Expand); + } + setOperationAction(ISD::FMINNUM, MVT::f16, Promote); + setOperationAction(ISD::FMAXNUM, MVT::f16, Promote); + setOperationAction(ISD::FMINNAN, MVT::f16, Promote); + setOperationAction(ISD::FMAXNAN, MVT::f16, Promote); // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate. // No FPOW or FREM in PTX. 
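// [Editorial note -- not part of this patch.] The net effect of the
// setFP16OperationAction() calls above: when STI.allowFP16Math() is true
// (roughly sm_53+ with a recent enough PTX version and no --nvptx-no-fp16-math),
// scalar f16 FADD/FSUB/FMUL/FMA stay Legal and select native fp16 instructions;
// otherwise they are Promoted, i.e. an IR "fadd half %a, %b" is widened to f32,
// computed with f32 instructions, and rounded back to f16 (in PTX terms,
// something like cvt.f32.f16 + add.rn.f32 + cvt.rn.f16.f32). v2f16 arithmetic
// is either Legal (native f16x2 instructions) or Expanded into two scalar f16
// operations, which may in turn be promoted as described above.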
@@ -434,6 +681,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "NVPTXISD::FUN_SHFR_CLAMP"; case NVPTXISD::IMAD: return "NVPTXISD::IMAD"; + case NVPTXISD::SETP_F16X2: + return "NVPTXISD::SETP_F16X2"; case NVPTXISD::Dummy: return "NVPTXISD::Dummy"; case NVPTXISD::MUL_WIDE_SIGNED: @@ -932,10 +1181,60 @@ TargetLoweringBase::LegalizeTypeAction NVPTXTargetLowering::getPreferredVectorAction(EVT VT) const { if (VT.getVectorNumElements() != 1 && VT.getScalarType() == MVT::i1) return TypeSplitVector; - + if (VT == MVT::v2f16) + return TypeLegal; return TargetLoweringBase::getPreferredVectorAction(VT); } +SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, + int Enabled, int &ExtraSteps, + bool &UseOneConst, + bool Reciprocal) const { + if (!(Enabled == ReciprocalEstimate::Enabled || + (Enabled == ReciprocalEstimate::Unspecified && !usePrecSqrtF32()))) + return SDValue(); + + if (ExtraSteps == ReciprocalEstimate::Unspecified) + ExtraSteps = 0; + + SDLoc DL(Operand); + EVT VT = Operand.getValueType(); + bool Ftz = useF32FTZ(DAG.getMachineFunction()); + + auto MakeIntrinsicCall = [&](Intrinsic::ID IID) { + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getConstant(IID, DL, MVT::i32), Operand); + }; + + // The sqrt and rsqrt refinement processes assume we always start out with an + // approximation of the rsqrt. Therefore, if we're going to do any refinement + // (i.e. ExtraSteps > 0), we must return an rsqrt. But if we're *not* doing + // any refinement, we must return a regular sqrt. + if (Reciprocal || ExtraSteps > 0) { + if (VT == MVT::f32) + return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_rsqrt_approx_ftz_f + : Intrinsic::nvvm_rsqrt_approx_f); + else if (VT == MVT::f64) + return MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d); + else + return SDValue(); + } else { + if (VT == MVT::f32) + return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f + : Intrinsic::nvvm_sqrt_approx_f); + else { + // There's no sqrt.approx.f64 instruction, so we emit + // reciprocal(rsqrt(x)). This is faster than + // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain + // x * rsqrt(x).) + return DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32), + MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d)); + } + } +} + SDValue NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const { SDLoc dl(Op); @@ -967,19 +1266,21 @@ std::string NVPTXTargetLowering::getPrototype( unsigned size = 0; if (auto *ITy = dyn_cast<IntegerType>(retTy)) { size = ITy->getBitWidth(); - if (size < 32) - size = 32; } else { assert(retTy->isFloatingPointTy() && "Floating point type expected here"); size = retTy->getPrimitiveSizeInBits(); } + // PTX ABI requires all scalar return values to be at least 32 + // bits in size. fp16 normally uses .b16 as its storage type in + // PTX, so its size must be adjusted here, too. 
+ if (size < 32) + size = 32; O << ".param .b" << size << " _"; } else if (isa<PointerType>(retTy)) { O << ".param .b" << PtrVT.getSizeInBits() << " _"; - } else if ((retTy->getTypeID() == Type::StructTyID) || - isa<VectorType>(retTy)) { + } else if (retTy->isAggregateType() || retTy->isVectorTy()) { auto &DL = CS->getCalledFunction()->getParent()->getDataLayout(); O << ".param .align " << retAlignment << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]"; @@ -1018,7 +1319,7 @@ std::string NVPTXTargetLowering::getPrototype( OIdx += len - 1; continue; } - // i8 types in IR will be i16 types in SDAG + // i8 types in IR will be i16 types in SDAG assert((getValueType(DL, Ty) == Outs[OIdx].VT || (getValueType(DL, Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) && "type mismatch between callee prototype and arguments"); @@ -1028,8 +1329,13 @@ std::string NVPTXTargetLowering::getPrototype( sz = cast<IntegerType>(Ty)->getBitWidth(); if (sz < 32) sz = 32; - } else if (isa<PointerType>(Ty)) + } else if (isa<PointerType>(Ty)) { sz = PtrVT.getSizeInBits(); + } else if (Ty->isHalfTy()) + // PTX ABI requires all scalar parameters to be at least 32 + // bits in size. fp16 normally uses .b16 as its storage type + // in PTX, so its size must be adjusted here, too. + sz = 32; else sz = Ty->getPrimitiveSizeInBits(); O << ".param .b" << sz << " "; @@ -1113,21 +1419,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SDValue Callee = CLI.Callee; bool &isTailCall = CLI.IsTailCall; ArgListTy &Args = CLI.getArgs(); - Type *retTy = CLI.RetTy; + Type *RetTy = CLI.RetTy; ImmutableCallSite *CS = CLI.CS; + const DataLayout &DL = DAG.getDataLayout(); bool isABI = (STI.getSmVersion() >= 20); assert(isABI && "Non-ABI compilation is not supported"); if (!isABI) return Chain; - MachineFunction &MF = DAG.getMachineFunction(); - const Function *F = MF.getFunction(); - auto &DL = MF.getDataLayout(); SDValue tempChain = Chain; - Chain = DAG.getCALLSEQ_START(Chain, - DAG.getIntPtrConstant(uniqueCallSite, dl, true), - dl); + Chain = DAG.getCALLSEQ_START( + Chain, DAG.getIntPtrConstant(uniqueCallSite, dl, true), dl); SDValue InFlag = Chain.getValue(1); unsigned paramCount = 0; @@ -1148,240 +1451,124 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, Type *Ty = Args[i].Ty; if (!Outs[OIdx].Flags.isByVal()) { - if (Ty->isAggregateType()) { - // aggregate - SmallVector<EVT, 16> vtparts; - SmallVector<uint64_t, 16> Offsets; - ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts, &Offsets, - 0); - - unsigned align = - getArgumentAlignment(Callee, CS, Ty, paramCount + 1, DL); + SmallVector<EVT, 16> VTs; + SmallVector<uint64_t, 16> Offsets; + ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets); + unsigned ArgAlign = + getArgumentAlignment(Callee, CS, Ty, paramCount + 1, DL); + unsigned AllocSize = DL.getTypeAllocSize(Ty); + SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + bool NeedAlign; // Does argument declaration specify alignment? 
+ if (Ty->isAggregateType() || Ty->isVectorTy()) { // declare .param .align <align> .b8 .param<n>[<size>]; - unsigned sz = DL.getTypeAllocSize(Ty); - SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, dl, - MVT::i32), - DAG.getConstant(paramCount, dl, MVT::i32), - DAG.getConstant(sz, dl, MVT::i32), - InFlag }; + SDValue DeclareParamOps[] = { + Chain, DAG.getConstant(ArgAlign, dl, MVT::i32), + DAG.getConstant(paramCount, dl, MVT::i32), + DAG.getConstant(AllocSize, dl, MVT::i32), InFlag}; Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, DeclareParamOps); - InFlag = Chain.getValue(1); - for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { - EVT elemtype = vtparts[j]; - unsigned ArgAlign = GreatestCommonDivisor64(align, Offsets[j]); - if (elemtype.isInteger() && (sz < 8)) - sz = 8; - SDValue StVal = OutVals[OIdx]; - if (elemtype.getSizeInBits() < 16) { - StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal); - } - SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue CopyParamOps[] = { Chain, - DAG.getConstant(paramCount, dl, MVT::i32), - DAG.getConstant(Offsets[j], dl, MVT::i32), - StVal, InFlag }; - Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, - CopyParamVTs, CopyParamOps, - elemtype, MachinePointerInfo(), - ArgAlign); - InFlag = Chain.getValue(1); - ++OIdx; + NeedAlign = true; + } else { + // declare .param .b<size> .param<n>; + if ((VT.isInteger() || VT.isFloatingPoint()) && AllocSize < 4) { + // PTX ABI requires integral types to be at least 32 bits in + // size. FP16 is loaded/stored using i16, so it's handled + // here as well. + AllocSize = 4; } - if (vtparts.size() > 0) - --OIdx; - ++paramCount; - continue; + SDValue DeclareScalarParamOps[] = { + Chain, DAG.getConstant(paramCount, dl, MVT::i32), + DAG.getConstant(AllocSize * 8, dl, MVT::i32), + DAG.getConstant(0, dl, MVT::i32), InFlag}; + Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs, + DeclareScalarParamOps); + NeedAlign = false; } - if (Ty->isVectorTy()) { - EVT ObjectVT = getValueType(DL, Ty); - unsigned align = - getArgumentAlignment(Callee, CS, Ty, paramCount + 1, DL); - // declare .param .align <align> .b8 .param<n>[<size>]; - unsigned sz = DL.getTypeAllocSize(Ty); - SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue DeclareParamOps[] = { Chain, - DAG.getConstant(align, dl, MVT::i32), - DAG.getConstant(paramCount, dl, MVT::i32), - DAG.getConstant(sz, dl, MVT::i32), - InFlag }; - Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, - DeclareParamOps); - InFlag = Chain.getValue(1); - unsigned NumElts = ObjectVT.getVectorNumElements(); - EVT EltVT = ObjectVT.getVectorElementType(); - EVT MemVT = EltVT; - bool NeedExtend = false; - if (EltVT.getSizeInBits() < 16) { - NeedExtend = true; - EltVT = MVT::i16; + InFlag = Chain.getValue(1); + + // PTX Interoperability Guide 3.3(A): [Integer] Values shorter + // than 32-bits are sign extended or zero extended, depending on + // whether they are signed or unsigned types. This case applies + // only to scalar parameters and not to aggregate values. + bool ExtendIntegerParam = + Ty->isIntegerTy() && DL.getTypeAllocSizeInBits(Ty) < 32; + + auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); + SmallVector<SDValue, 6> StoreOperands; + for (unsigned j = 0, je = VTs.size(); j != je; ++j) { + // New store. 
+ if (VectorInfo[j] & PVF_FIRST) { + assert(StoreOperands.empty() && "Unfinished preceeding store."); + StoreOperands.push_back(Chain); + StoreOperands.push_back(DAG.getConstant(paramCount, dl, MVT::i32)); + StoreOperands.push_back(DAG.getConstant(Offsets[j], dl, MVT::i32)); } - // V1 store - if (NumElts == 1) { - SDValue Elt = OutVals[OIdx++]; - if (NeedExtend) - Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt); - - SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue CopyParamOps[] = { Chain, - DAG.getConstant(paramCount, dl, MVT::i32), - DAG.getConstant(0, dl, MVT::i32), Elt, - InFlag }; - Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, - CopyParamVTs, CopyParamOps, - MemVT, MachinePointerInfo()); - InFlag = Chain.getValue(1); - } else if (NumElts == 2) { - SDValue Elt0 = OutVals[OIdx++]; - SDValue Elt1 = OutVals[OIdx++]; - if (NeedExtend) { - Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0); - Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1); + EVT EltVT = VTs[j]; + SDValue StVal = OutVals[OIdx]; + if (ExtendIntegerParam) { + assert(VTs.size() == 1 && "Scalar can't have multiple parts."); + // zext/sext to i32 + StVal = DAG.getNode(Outs[OIdx].Flags.isSExt() ? ISD::SIGN_EXTEND + : ISD::ZERO_EXTEND, + dl, MVT::i32, StVal); + } else if (EltVT.getSizeInBits() < 16) { + // Use 16-bit registers for small stores as it's the + // smallest general purpose register size supported by NVPTX. + StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal); + } + + // Record the value to store. + StoreOperands.push_back(StVal); + + if (VectorInfo[j] & PVF_LAST) { + unsigned NumElts = StoreOperands.size() - 3; + NVPTXISD::NodeType Op; + switch (NumElts) { + case 1: + Op = NVPTXISD::StoreParam; + break; + case 2: + Op = NVPTXISD::StoreParamV2; + break; + case 4: + Op = NVPTXISD::StoreParamV4; + break; + default: + llvm_unreachable("Invalid vector info."); } - SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue CopyParamOps[] = { Chain, - DAG.getConstant(paramCount, dl, MVT::i32), - DAG.getConstant(0, dl, MVT::i32), Elt0, - Elt1, InFlag }; - Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl, - CopyParamVTs, CopyParamOps, - MemVT, MachinePointerInfo()); - InFlag = Chain.getValue(1); - } else { - unsigned curOffset = 0; - // V4 stores - // We have at least 4 elements (<3 x Ty> expands to 4 elements) and - // the - // vector will be expanded to a power of 2 elements, so we know we can - // always round up to the next multiple of 4 when creating the vector - // stores. - // e.g. 4 elem => 1 st.v4 - // 6 elem => 2 st.v4 - // 8 elem => 2 st.v4 - // 11 elem => 3 st.v4 - unsigned VecSize = 4; - if (EltVT.getSizeInBits() == 64) - VecSize = 2; - - // This is potentially only part of a vector, so assume all elements - // are packed together. 
- unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize; - - for (unsigned i = 0; i < NumElts; i += VecSize) { - // Get values - SDValue StoreVal; - SmallVector<SDValue, 8> Ops; - Ops.push_back(Chain); - Ops.push_back(DAG.getConstant(paramCount, dl, MVT::i32)); - Ops.push_back(DAG.getConstant(curOffset, dl, MVT::i32)); - - unsigned Opc = NVPTXISD::StoreParamV2; - - StoreVal = OutVals[OIdx++]; - if (NeedExtend) - StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); - Ops.push_back(StoreVal); - - if (i + 1 < NumElts) { - StoreVal = OutVals[OIdx++]; - if (NeedExtend) - StoreVal = - DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); - } else { - StoreVal = DAG.getUNDEF(EltVT); - } - Ops.push_back(StoreVal); - - if (VecSize == 4) { - Opc = NVPTXISD::StoreParamV4; - if (i + 2 < NumElts) { - StoreVal = OutVals[OIdx++]; - if (NeedExtend) - StoreVal = - DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); - } else { - StoreVal = DAG.getUNDEF(EltVT); - } - Ops.push_back(StoreVal); - - if (i + 3 < NumElts) { - StoreVal = OutVals[OIdx++]; - if (NeedExtend) - StoreVal = - DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); - } else { - StoreVal = DAG.getUNDEF(EltVT); - } - Ops.push_back(StoreVal); - } + StoreOperands.push_back(InFlag); - Ops.push_back(InFlag); + // Adjust type of the store op if we've extended the scalar + // return value. + EVT TheStoreType = ExtendIntegerParam ? MVT::i32 : VTs[j]; + unsigned EltAlign = + NeedAlign ? GreatestCommonDivisor64(ArgAlign, Offsets[j]) : 0; - SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, Ops, - MemVT, MachinePointerInfo()); - InFlag = Chain.getValue(1); - curOffset += PerStoreOffset; - } + Chain = DAG.getMemIntrinsicNode( + Op, dl, DAG.getVTList(MVT::Other, MVT::Glue), StoreOperands, + TheStoreType, MachinePointerInfo(), EltAlign); + InFlag = Chain.getValue(1); + + // Cleanup. 
+ StoreOperands.clear(); } - ++paramCount; - --OIdx; - continue; - } - // Plain scalar - // for ABI, declare .param .b<size> .param<n>; - unsigned sz = VT.getSizeInBits(); - bool needExtend = false; - if (VT.isInteger()) { - if (sz < 16) - needExtend = true; - if (sz < 32) - sz = 32; + ++OIdx; } - SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue DeclareParamOps[] = { Chain, - DAG.getConstant(paramCount, dl, MVT::i32), - DAG.getConstant(sz, dl, MVT::i32), - DAG.getConstant(0, dl, MVT::i32), InFlag }; - Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs, - DeclareParamOps); - InFlag = Chain.getValue(1); - SDValue OutV = OutVals[OIdx]; - if (needExtend) { - // zext/sext i1 to i16 - unsigned opc = ISD::ZERO_EXTEND; - if (Outs[OIdx].Flags.isSExt()) - opc = ISD::SIGN_EXTEND; - OutV = DAG.getNode(opc, dl, MVT::i16, OutV); - } - SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue CopyParamOps[] = { Chain, - DAG.getConstant(paramCount, dl, MVT::i32), - DAG.getConstant(0, dl, MVT::i32), OutV, - InFlag }; - - unsigned opcode = NVPTXISD::StoreParam; - if (Outs[OIdx].Flags.isZExt() && VT.getSizeInBits() < 32) - opcode = NVPTXISD::StoreParamU32; - else if (Outs[OIdx].Flags.isSExt() && VT.getSizeInBits() < 32) - opcode = NVPTXISD::StoreParamS32; - Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps, - VT, MachinePointerInfo()); - - InFlag = Chain.getValue(1); + assert(StoreOperands.empty() && "Unfinished parameter store."); + if (VTs.size() > 0) + --OIdx; ++paramCount; continue; } - // struct or vector - SmallVector<EVT, 16> vtparts; + + // ByVal arguments + SmallVector<EVT, 16> VTs; SmallVector<uint64_t, 16> Offsets; auto *PTy = dyn_cast<PointerType>(Args[i].Ty); assert(PTy && "Type of a byval parameter should be pointer"); - ComputePTXValueVTs(*this, DAG.getDataLayout(), PTy->getElementType(), - vtparts, &Offsets, 0); + ComputePTXValueVTs(*this, DL, PTy->getElementType(), VTs, &Offsets, 0); // declare .param .align <align> .b8 .param<n>[<size>]; unsigned sz = Outs[OIdx].Flags.getByValSize(); @@ -1402,11 +1589,11 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, DeclareParamOps); InFlag = Chain.getValue(1); - for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { - EVT elemtype = vtparts[j]; + for (unsigned j = 0, je = VTs.size(); j != je; ++j) { + EVT elemtype = VTs[j]; int curOffset = Offsets[j]; unsigned PartAlign = GreatestCommonDivisor64(ArgAlign, curOffset); - auto PtrVT = getPointerTy(DAG.getDataLayout()); + auto PtrVT = getPointerTy(DL); SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, OutVals[OIdx], DAG.getConstant(curOffset, dl, PtrVT)); SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr, @@ -1434,18 +1621,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // Handle Result if (Ins.size() > 0) { SmallVector<EVT, 16> resvtparts; - ComputeValueVTs(*this, DL, retTy, resvtparts); + ComputeValueVTs(*this, DL, RetTy, resvtparts); // Declare // .param .align 16 .b8 retval0[<size-in-bytes>], or // .param .b<size-in-bits> retval0 - unsigned resultsz = DL.getTypeAllocSizeInBits(retTy); + unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy); // Emit ".param .b<size-in-bits> retval0" instead of byte arrays only for // these three types to match the logic in // NVPTXAsmPrinter::printReturnValStr and NVPTXTargetLowering::getPrototype. // Plus, this behavior is consistent with nvcc's. 
- if (retTy->isFloatingPointTy() || retTy->isIntegerTy() || - retTy->isPointerTy()) { + if (RetTy->isFloatingPointTy() || RetTy->isIntegerTy() || + RetTy->isPointerTy()) { // Scalar needs to be at least 32bit wide if (resultsz < 32) resultsz = 32; @@ -1457,7 +1644,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, DeclareRetOps); InFlag = Chain.getValue(1); } else { - retAlignment = getArgumentAlignment(Callee, CS, retTy, 0, DL); + retAlignment = getArgumentAlignment(Callee, CS, RetTy, 0, DL); SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue); SDValue DeclareRetOps[] = { Chain, DAG.getConstant(retAlignment, dl, MVT::i32), @@ -1478,8 +1665,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // The prototype is embedded in a string and put as the operand for a // CallPrototype SDNode which will print out to the value of the string. SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue); - std::string Proto = - getPrototype(DAG.getDataLayout(), retTy, Args, Outs, retAlignment, CS); + std::string Proto = getPrototype(DL, RetTy, Args, Outs, retAlignment, CS); const char *ProtoStr = nvTM->getManagedStrPool()->getManagedString(Proto.c_str())->c_str(); SDValue ProtoOps[] = { @@ -1544,175 +1730,84 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // Generate loads from param memory/moves from registers for result if (Ins.size() > 0) { - if (retTy && retTy->isVectorTy()) { - EVT ObjectVT = getValueType(DL, retTy); - unsigned NumElts = ObjectVT.getVectorNumElements(); - EVT EltVT = ObjectVT.getVectorElementType(); - assert(STI.getTargetLowering()->getNumRegisters(F->getContext(), - ObjectVT) == NumElts && - "Vector was not scalarized"); - unsigned sz = EltVT.getSizeInBits(); - bool needTruncate = sz < 8; - - if (NumElts == 1) { - // Just a simple load - SmallVector<EVT, 4> LoadRetVTs; - if (EltVT == MVT::i1 || EltVT == MVT::i8) { - // If loading i1/i8 result, generate - // load.b8 i16 - // if i1 - // trunc i16 to i1 - LoadRetVTs.push_back(MVT::i16); - } else - LoadRetVTs.push_back(EltVT); - LoadRetVTs.push_back(MVT::Other); - LoadRetVTs.push_back(MVT::Glue); - SDValue LoadRetOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32), - DAG.getConstant(0, dl, MVT::i32), InFlag}; - SDValue retval = DAG.getMemIntrinsicNode( - NVPTXISD::LoadParam, dl, - DAG.getVTList(LoadRetVTs), LoadRetOps, EltVT, MachinePointerInfo()); - Chain = retval.getValue(1); - InFlag = retval.getValue(2); - SDValue Ret0 = retval; - if (needTruncate) - Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0); - InVals.push_back(Ret0); - } else if (NumElts == 2) { - // LoadV2 - SmallVector<EVT, 4> LoadRetVTs; - if (EltVT == MVT::i1 || EltVT == MVT::i8) { - // If loading i1/i8 result, generate - // load.b8 i16 - // if i1 - // trunc i16 to i1 - LoadRetVTs.push_back(MVT::i16); - LoadRetVTs.push_back(MVT::i16); - } else { - LoadRetVTs.push_back(EltVT); - LoadRetVTs.push_back(EltVT); - } - LoadRetVTs.push_back(MVT::Other); - LoadRetVTs.push_back(MVT::Glue); - SDValue LoadRetOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32), - DAG.getConstant(0, dl, MVT::i32), InFlag}; - SDValue retval = DAG.getMemIntrinsicNode( - NVPTXISD::LoadParamV2, dl, - DAG.getVTList(LoadRetVTs), LoadRetOps, EltVT, MachinePointerInfo()); - Chain = retval.getValue(2); - InFlag = retval.getValue(3); - SDValue Ret0 = retval.getValue(0); - SDValue Ret1 = retval.getValue(1); - if (needTruncate) { - Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0); - InVals.push_back(Ret0); - 
Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1); - InVals.push_back(Ret1); - } else { - InVals.push_back(Ret0); - InVals.push_back(Ret1); - } - } else { - // Split into N LoadV4 - unsigned Ofst = 0; - unsigned VecSize = 4; - unsigned Opc = NVPTXISD::LoadParamV4; - if (EltVT.getSizeInBits() == 64) { - VecSize = 2; - Opc = NVPTXISD::LoadParamV2; - } - EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize); - for (unsigned i = 0; i < NumElts; i += VecSize) { - SmallVector<EVT, 8> LoadRetVTs; - if (EltVT == MVT::i1 || EltVT == MVT::i8) { - // If loading i1/i8 result, generate - // load.b8 i16 - // if i1 - // trunc i16 to i1 - for (unsigned j = 0; j < VecSize; ++j) - LoadRetVTs.push_back(MVT::i16); - } else { - for (unsigned j = 0; j < VecSize; ++j) - LoadRetVTs.push_back(EltVT); - } - LoadRetVTs.push_back(MVT::Other); - LoadRetVTs.push_back(MVT::Glue); - SDValue LoadRetOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32), - DAG.getConstant(Ofst, dl, MVT::i32), InFlag}; - SDValue retval = DAG.getMemIntrinsicNode( - Opc, dl, DAG.getVTList(LoadRetVTs), - LoadRetOps, EltVT, MachinePointerInfo()); - if (VecSize == 2) { - Chain = retval.getValue(2); - InFlag = retval.getValue(3); - } else { - Chain = retval.getValue(4); - InFlag = retval.getValue(5); - } + SmallVector<EVT, 16> VTs; + SmallVector<uint64_t, 16> Offsets; + ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets, 0); + assert(VTs.size() == Ins.size() && "Bad value decomposition"); + + unsigned RetAlign = getArgumentAlignment(Callee, CS, RetTy, 0, DL); + auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, RetAlign); + + SmallVector<EVT, 6> LoadVTs; + int VecIdx = -1; // Index of the first element of the vector. + + // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than + // 32-bits are sign extended or zero extended, depending on whether + // they are signed or unsigned types. + bool ExtendIntegerRetVal = + RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32; + + for (unsigned i = 0, e = VTs.size(); i != e; ++i) { + bool needTruncate = false; + EVT TheLoadType = VTs[i]; + EVT EltType = Ins[i].VT; + unsigned EltAlign = GreatestCommonDivisor64(RetAlign, Offsets[i]); + if (ExtendIntegerRetVal) { + TheLoadType = MVT::i32; + EltType = MVT::i32; + needTruncate = true; + } else if (TheLoadType.getSizeInBits() < 16) { + if (VTs[i].isInteger()) + needTruncate = true; + EltType = MVT::i16; + } - for (unsigned j = 0; j < VecSize; ++j) { - if (i + j >= NumElts) - break; - SDValue Elt = retval.getValue(j); - if (needTruncate) - Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt); - InVals.push_back(Elt); - } - Ofst += DL.getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); - } + // Record index of the very first element of the vector. 
+ if (VectorInfo[i] & PVF_FIRST) { + assert(VecIdx == -1 && LoadVTs.empty() && "Orphaned operand list."); + VecIdx = i; } - } else { - SmallVector<EVT, 16> VTs; - SmallVector<uint64_t, 16> Offsets; - auto &DL = DAG.getDataLayout(); - ComputePTXValueVTs(*this, DL, retTy, VTs, &Offsets, 0); - assert(VTs.size() == Ins.size() && "Bad value decomposition"); - unsigned RetAlign = getArgumentAlignment(Callee, CS, retTy, 0, DL); - for (unsigned i = 0, e = Ins.size(); i != e; ++i) { - unsigned sz = VTs[i].getSizeInBits(); - unsigned AlignI = GreatestCommonDivisor64(RetAlign, Offsets[i]); - bool needTruncate = false; - if (VTs[i].isInteger() && sz < 8) { - sz = 8; - needTruncate = true; + + LoadVTs.push_back(EltType); + + if (VectorInfo[i] & PVF_LAST) { + unsigned NumElts = LoadVTs.size(); + LoadVTs.push_back(MVT::Other); + LoadVTs.push_back(MVT::Glue); + NVPTXISD::NodeType Op; + switch (NumElts) { + case 1: + Op = NVPTXISD::LoadParam; + break; + case 2: + Op = NVPTXISD::LoadParamV2; + break; + case 4: + Op = NVPTXISD::LoadParamV4; + break; + default: + llvm_unreachable("Invalid vector info."); } - SmallVector<EVT, 4> LoadRetVTs; - EVT TheLoadType = VTs[i]; - if (retTy->isIntegerTy() && DL.getTypeAllocSizeInBits(retTy) < 32) { - // This is for integer types only, and specifically not for - // aggregates. - LoadRetVTs.push_back(MVT::i32); - TheLoadType = MVT::i32; - needTruncate = true; - } else if (sz < 16) { - // If loading i1/i8 result, generate - // load i8 (-> i16) - // trunc i16 to i1/i8 - - // FIXME: Do we need to set needTruncate to true here, too? We could - // not figure out what this branch is for in D17872, so we left it - // alone. The comment above about loading i1/i8 may be wrong, as the - // branch above seems to cover integers of size < 32. - LoadRetVTs.push_back(MVT::i16); - } else - LoadRetVTs.push_back(Ins[i].VT); - LoadRetVTs.push_back(MVT::Other); - LoadRetVTs.push_back(MVT::Glue); - - SDValue LoadRetOps[] = {Chain, DAG.getConstant(1, dl, MVT::i32), - DAG.getConstant(Offsets[i], dl, MVT::i32), - InFlag}; - SDValue retval = DAG.getMemIntrinsicNode( - NVPTXISD::LoadParam, dl, - DAG.getVTList(LoadRetVTs), LoadRetOps, - TheLoadType, MachinePointerInfo(), AlignI); - Chain = retval.getValue(1); - InFlag = retval.getValue(2); - SDValue Ret0 = retval.getValue(0); - if (needTruncate) - Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0); - InVals.push_back(Ret0); + SDValue LoadOperands[] = { + Chain, DAG.getConstant(1, dl, MVT::i32), + DAG.getConstant(Offsets[VecIdx], dl, MVT::i32), InFlag}; + SDValue RetVal = DAG.getMemIntrinsicNode( + Op, dl, DAG.getVTList(LoadVTs), LoadOperands, TheLoadType, + MachinePointerInfo(), EltAlign); + + for (unsigned j = 0; j < NumElts; ++j) { + SDValue Ret = RetVal.getValue(j); + if (needTruncate) + Ret = DAG.getNode(ISD::TRUNCATE, dl, Ins[VecIdx + j].VT, Ret); + InVals.push_back(Ret); + } + Chain = RetVal.getValue(NumElts); + InFlag = RetVal.getValue(NumElts + 1); + + // Cleanup + VecIdx = -1; + LoadVTs.clear(); } } } @@ -1752,6 +1847,55 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const { return DAG.getBuildVector(Node->getValueType(0), dl, Ops); } +// We can init constant f16x2 with a single .b32 move. Normally it +// would get lowered as two constant loads and vector-packing move. +// mov.b16 %h1, 0x4000; +// mov.b16 %h2, 0x3C00; +// mov.b32 %hh2, {%h2, %h1}; +// Instead we want just a constant move: +// mov.b32 %hh2, 0x40003C00 +// +// This results in better SASS code with CUDA 7.x. 
Ptxas in CUDA 8.0 +// generates good SASS in both cases. +SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op, + SelectionDAG &DAG) const { + //return Op; + if (!(Op->getValueType(0) == MVT::v2f16 && + isa<ConstantFPSDNode>(Op->getOperand(0)) && + isa<ConstantFPSDNode>(Op->getOperand(1)))) + return Op; + + APInt E0 = + cast<ConstantFPSDNode>(Op->getOperand(0))->getValueAPF().bitcastToAPInt(); + APInt E1 = + cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt(); + SDValue Const = + DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32); + return DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const); +} + +SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, + SelectionDAG &DAG) const { + SDValue Index = Op->getOperand(1); + // Constant index will be matched by tablegen. + if (isa<ConstantSDNode>(Index.getNode())) + return Op; + + // Extract individual elements and select one of them. + SDValue Vector = Op->getOperand(0); + EVT VectorVT = Vector.getValueType(); + assert(VectorVT == MVT::v2f16 && "Unexpected vector type."); + EVT EltVT = VectorVT.getVectorElementType(); + + SDLoc dl(Op.getNode()); + SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector, + DAG.getIntPtrConstant(0, dl)); + SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, Vector, + DAG.getIntPtrConstant(1, dl)); + return DAG.getSelectCC(dl, Index, DAG.getIntPtrConstant(0, dl), E0, E1, + ISD::CondCode::SETEQ); +} + /// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which /// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift /// amount, or @@ -1885,8 +2029,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { case ISD::INTRINSIC_W_CHAIN: return Op; case ISD::BUILD_VECTOR: + return LowerBUILD_VECTOR(Op, DAG); case ISD::EXTRACT_SUBVECTOR: return Op; + case ISD::EXTRACT_VECTOR_ELT: + return LowerEXTRACT_VECTOR_ELT(Op, DAG); case ISD::CONCAT_VECTORS: return LowerCONCAT_VECTORS(Op, DAG); case ISD::STORE: @@ -1924,8 +2071,21 @@ SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const { SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const { if (Op.getValueType() == MVT::i1) return LowerLOADi1(Op, DAG); - else - return SDValue(); + + // v2f16 is legal, so we can't rely on legalizer to handle unaligned + // loads and have to handle it here. + if (Op.getValueType() == MVT::v2f16) { + LoadSDNode *Load = cast<LoadSDNode>(Op); + EVT MemVT = Load->getMemoryVT(); + if (!allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT, + Load->getAddressSpace(), Load->getAlignment())) { + SDValue Ops[2]; + std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG); + return DAG.getMergeValues(Ops, SDLoc(Op)); + } + } + + return SDValue(); } // v = ld i1* addr @@ -1951,13 +2111,23 @@ SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const { } SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const { - EVT ValVT = Op.getOperand(1).getValueType(); - if (ValVT == MVT::i1) + StoreSDNode *Store = cast<StoreSDNode>(Op); + EVT VT = Store->getMemoryVT(); + + if (VT == MVT::i1) return LowerSTOREi1(Op, DAG); - else if (ValVT.isVector()) + + // v2f16 is legal, so we can't rely on legalizer to handle unaligned + // stores and have to handle it here. 
+ if (VT == MVT::v2f16 && + !allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT, + Store->getAddressSpace(), Store->getAlignment())) + return expandUnalignedStore(Store, DAG); + + if (VT.isVector()) return LowerSTOREVector(Op, DAG); - else - return SDValue(); + + return SDValue(); } SDValue @@ -1980,12 +2150,15 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const { case MVT::v2i16: case MVT::v2i32: case MVT::v2i64: + case MVT::v2f16: case MVT::v2f32: case MVT::v2f64: case MVT::v4i8: case MVT::v4i16: case MVT::v4i32: + case MVT::v4f16: case MVT::v4f32: + case MVT::v8f16: // <4 x f16x2> // This is a "native" vector type break; } @@ -2016,6 +2189,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const { if (EltVT.getSizeInBits() < 16) NeedExt = true; + bool StoreF16x2 = false; switch (NumElts) { default: return SDValue(); @@ -2025,6 +2199,14 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const { case 4: Opcode = NVPTXISD::StoreV4; break; + case 8: + // v8f16 is a special case. PTX doesn't have st.v8.f16 + // instruction. Instead, we split the vector into v2f16 chunks and + // store them with st.v4.b32. + assert(EltVT == MVT::f16 && "Wrong type for the vector."); + Opcode = NVPTXISD::StoreV4; + StoreF16x2 = true; + break; } SmallVector<SDValue, 8> Ops; @@ -2032,23 +2214,36 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const { // First is the chain Ops.push_back(N->getOperand(0)); - // Then the split values - for (unsigned i = 0; i < NumElts; ++i) { - SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val, - DAG.getIntPtrConstant(i, DL)); - if (NeedExt) - ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal); - Ops.push_back(ExtVal); + if (StoreF16x2) { + // Combine f16,f16 -> v2f16 + NumElts /= 2; + for (unsigned i = 0; i < NumElts; ++i) { + SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val, + DAG.getIntPtrConstant(i * 2, DL)); + SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val, + DAG.getIntPtrConstant(i * 2 + 1, DL)); + SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1); + Ops.push_back(V2); + } + } else { + // Then the split values + for (unsigned i = 0; i < NumElts; ++i) { + SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val, + DAG.getIntPtrConstant(i, DL)); + if (NeedExt) + ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal); + Ops.push_back(ExtVal); + } } // Then any remaining arguments Ops.append(N->op_begin() + 2, N->op_end()); - SDValue NewSt = DAG.getMemIntrinsicNode( - Opcode, DL, DAG.getVTList(MVT::Other), Ops, - MemSD->getMemoryVT(), MemSD->getMemOperand()); + SDValue NewSt = + DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops, + MemSD->getMemoryVT(), MemSD->getMemOperand()); - //return DCI.CombineTo(N, NewSt, true); + // return DCI.CombineTo(N, NewSt, true); return NewSt; } @@ -2120,7 +2315,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( auto PtrVT = getPointerTy(DAG.getDataLayout()); const Function *F = MF.getFunction(); - const AttributeSet &PAL = F->getAttributes(); + const AttributeList &PAL = F->getAttributes(); const TargetLowering *TLI = STI.getTargetLowering(); SDValue Root = DAG.getRoot(); @@ -2200,177 +2395,80 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( // to newly created nodes. The SDNodes for params have to // appear in the same order as their order of appearance // in the original function. "idx+1" holds that order. 
- if (!PAL.hasAttribute(i + 1, Attribute::ByVal)) { - if (Ty->isAggregateType()) { - SmallVector<EVT, 16> vtparts; - SmallVector<uint64_t, 16> offsets; + if (!PAL.hasParamAttribute(i, Attribute::ByVal)) { + bool aggregateIsPacked = false; + if (StructType *STy = dyn_cast<StructType>(Ty)) + aggregateIsPacked = STy->isPacked(); - // NOTE: Here, we lose the ability to issue vector loads for vectors - // that are a part of a struct. This should be investigated in the - // future. - ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts, &offsets, - 0); - assert(vtparts.size() > 0 && "empty aggregate type not expected"); - bool aggregateIsPacked = false; - if (StructType *STy = dyn_cast<StructType>(Ty)) - aggregateIsPacked = STy->isPacked(); + SmallVector<EVT, 16> VTs; + SmallVector<uint64_t, 16> Offsets; + ComputePTXValueVTs(*this, DL, Ty, VTs, &Offsets, 0); + assert(VTs.size() > 0 && "Unexpected empty type."); + auto VectorInfo = + VectorizePTXValueVTs(VTs, Offsets, DL.getABITypeAlignment(Ty)); - SDValue Arg = getParamSymbol(DAG, idx, PtrVT); - for (unsigned parti = 0, parte = vtparts.size(); parti != parte; - ++parti) { - EVT partVT = vtparts[parti]; - Value *srcValue = Constant::getNullValue( - PointerType::get(partVT.getTypeForEVT(F->getContext()), - ADDRESS_SPACE_PARAM)); - SDValue srcAddr = - DAG.getNode(ISD::ADD, dl, PtrVT, Arg, - DAG.getConstant(offsets[parti], dl, PtrVT)); - unsigned partAlign = aggregateIsPacked - ? 1 - : DL.getABITypeAlignment( - partVT.getTypeForEVT(F->getContext())); - SDValue p; - if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits()) { - ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ? - ISD::SEXTLOAD : ISD::ZEXTLOAD; - p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, srcAddr, - MachinePointerInfo(srcValue), partVT, partAlign); - } else { - p = DAG.getLoad(partVT, dl, Root, srcAddr, - MachinePointerInfo(srcValue), partAlign); - } - if (p.getNode()) - p.getNode()->setIROrder(idx + 1); - InVals.push_back(p); - ++InsIdx; + SDValue Arg = getParamSymbol(DAG, idx, PtrVT); + int VecIdx = -1; // Index of the first element of the current vector. + for (unsigned parti = 0, parte = VTs.size(); parti != parte; ++parti) { + if (VectorInfo[parti] & PVF_FIRST) { + assert(VecIdx == -1 && "Orphaned vector."); + VecIdx = parti; } - if (vtparts.size() > 0) - --InsIdx; - continue; - } - if (Ty->isVectorTy()) { - EVT ObjectVT = getValueType(DL, Ty); - SDValue Arg = getParamSymbol(DAG, idx, PtrVT); - unsigned NumElts = ObjectVT.getVectorNumElements(); - assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts && - "Vector was not scalarized"); - EVT EltVT = ObjectVT.getVectorElementType(); - - // V1 load - // f32 = load ... - if (NumElts == 1) { - // We only have one element, so just directly load it - Value *SrcValue = Constant::getNullValue(PointerType::get( - EltVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM)); - SDValue P = DAG.getLoad( - EltVT, dl, Root, Arg, MachinePointerInfo(SrcValue), - DL.getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())), - MachineMemOperand::MODereferenceable | - MachineMemOperand::MOInvariant); - if (P.getNode()) - P.getNode()->setIROrder(idx + 1); - if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) - P = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, P); - InVals.push_back(P); - ++InsIdx; - } else if (NumElts == 2) { - // V2 load - // f32,f32 = load ... 
- EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2); - Value *SrcValue = Constant::getNullValue(PointerType::get( - VecVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM)); - SDValue P = DAG.getLoad( - VecVT, dl, Root, Arg, MachinePointerInfo(SrcValue), - DL.getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())), - MachineMemOperand::MODereferenceable | - MachineMemOperand::MOInvariant); + // That's the last element of this store op. + if (VectorInfo[parti] & PVF_LAST) { + unsigned NumElts = parti - VecIdx + 1; + EVT EltVT = VTs[parti]; + // i1 is loaded/stored as i8. + EVT LoadVT = EltVT; + if (EltVT == MVT::i1) + LoadVT = MVT::i8; + else if (EltVT == MVT::v2f16) + // getLoad needs a vector type, but it can't handle + // vectors which contain v2f16 elements. So we must load + // using i32 here and then bitcast back. + LoadVT = MVT::i32; + + EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts); + SDValue VecAddr = + DAG.getNode(ISD::ADD, dl, PtrVT, Arg, + DAG.getConstant(Offsets[VecIdx], dl, PtrVT)); + Value *srcValue = Constant::getNullValue(PointerType::get( + EltVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM)); + SDValue P = + DAG.getLoad(VecVT, dl, Root, VecAddr, + MachinePointerInfo(srcValue), aggregateIsPacked, + MachineMemOperand::MODereferenceable | + MachineMemOperand::MOInvariant); if (P.getNode()) P.getNode()->setIROrder(idx + 1); - - SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P, - DAG.getIntPtrConstant(0, dl)); - SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P, - DAG.getIntPtrConstant(1, dl)); - - if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) { - Elt0 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt0); - Elt1 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt1); - } - - InVals.push_back(Elt0); - InVals.push_back(Elt1); - InsIdx += 2; - } else { - // V4 loads - // We have at least 4 elements (<3 x Ty> expands to 4 elements) and - // the vector will be expanded to a power of 2 elements, so we know we - // can always round up to the next multiple of 4 when creating the - // vector loads. - // e.g. 
4 elem => 1 ld.v4 - // 6 elem => 2 ld.v4 - // 8 elem => 2 ld.v4 - // 11 elem => 3 ld.v4 - unsigned VecSize = 4; - if (EltVT.getSizeInBits() == 64) { - VecSize = 2; - } - EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize); - unsigned Ofst = 0; - for (unsigned i = 0; i < NumElts; i += VecSize) { - Value *SrcValue = Constant::getNullValue( - PointerType::get(VecVT.getTypeForEVT(F->getContext()), - ADDRESS_SPACE_PARAM)); - SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, Arg, - DAG.getConstant(Ofst, dl, PtrVT)); - SDValue P = DAG.getLoad( - VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), - DL.getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())), - MachineMemOperand::MODereferenceable | - MachineMemOperand::MOInvariant); - if (P.getNode()) - P.getNode()->setIROrder(idx + 1); - - for (unsigned j = 0; j < VecSize; ++j) { - if (i + j >= NumElts) - break; - SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P, - DAG.getIntPtrConstant(j, dl)); - if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) - Elt = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt); - InVals.push_back(Elt); + for (unsigned j = 0; j < NumElts; ++j) { + SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P, + DAG.getIntPtrConstant(j, dl)); + // We've loaded i1 as an i8 and now must truncate it back to i1 + if (EltVT == MVT::i1) + Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt); + // v2f16 was loaded as an i32. Now we must bitcast it back. + else if (EltVT == MVT::v2f16) + Elt = DAG.getNode(ISD::BITCAST, dl, MVT::v2f16, Elt); + // Extend the element if necesary (e.g. an i8 is loaded + // into an i16 register) + if (Ins[InsIdx].VT.isInteger() && + Ins[InsIdx].VT.getSizeInBits() > LoadVT.getSizeInBits()) { + unsigned Extend = Ins[InsIdx].Flags.isSExt() ? ISD::SIGN_EXTEND + : ISD::ZERO_EXTEND; + Elt = DAG.getNode(Extend, dl, Ins[InsIdx].VT, Elt); } - Ofst += DL.getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); + InVals.push_back(Elt); } - InsIdx += NumElts; - } - if (NumElts > 0) - --InsIdx; - continue; - } - // A plain scalar. - EVT ObjectVT = getValueType(DL, Ty); - // If ABI, load from the param symbol - SDValue Arg = getParamSymbol(DAG, idx, PtrVT); - Value *srcValue = Constant::getNullValue(PointerType::get( - ObjectVT.getTypeForEVT(F->getContext()), ADDRESS_SPACE_PARAM)); - SDValue p; - if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits()) { - ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ? - ISD::SEXTLOAD : ISD::ZEXTLOAD; - p = DAG.getExtLoad( - ExtOp, dl, Ins[InsIdx].VT, Root, Arg, MachinePointerInfo(srcValue), - ObjectVT, - DL.getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext()))); - } else { - p = DAG.getLoad( - Ins[InsIdx].VT, dl, Root, Arg, MachinePointerInfo(srcValue), - DL.getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext()))); + // Reset vector tracking state. 
+ VecIdx = -1; + } + ++InsIdx; } - if (p.getNode()) - p.getNode()->setIROrder(idx + 1); - InVals.push_back(p); + if (VTs.size() > 0) + --InsIdx; continue; } @@ -2412,164 +2510,77 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl, SelectionDAG &DAG) const { MachineFunction &MF = DAG.getMachineFunction(); - const Function *F = MF.getFunction(); - Type *RetTy = F->getReturnType(); - const DataLayout &TD = DAG.getDataLayout(); + Type *RetTy = MF.getFunction()->getReturnType(); bool isABI = (STI.getSmVersion() >= 20); assert(isABI && "Non-ABI compilation is not supported"); if (!isABI) return Chain; - if (VectorType *VTy = dyn_cast<VectorType>(RetTy)) { - // If we have a vector type, the OutVals array will be the scalarized - // components and we have combine them into 1 or more vector stores. - unsigned NumElts = VTy->getNumElements(); - assert(NumElts == Outs.size() && "Bad scalarization of return value"); + const DataLayout DL = DAG.getDataLayout(); + SmallVector<EVT, 16> VTs; + SmallVector<uint64_t, 16> Offsets; + ComputePTXValueVTs(*this, DL, RetTy, VTs, &Offsets); + assert(VTs.size() == OutVals.size() && "Bad return value decomposition"); + + auto VectorInfo = VectorizePTXValueVTs( + VTs, Offsets, RetTy->isSized() ? DL.getABITypeAlignment(RetTy) : 1); + + // PTX Interoperability Guide 3.3(A): [Integer] Values shorter than + // 32-bits are sign extended or zero extended, depending on whether + // they are signed or unsigned types. + bool ExtendIntegerRetVal = + RetTy->isIntegerTy() && DL.getTypeAllocSizeInBits(RetTy) < 32; + + SmallVector<SDValue, 6> StoreOperands; + for (unsigned i = 0, e = VTs.size(); i != e; ++i) { + // New load/store. Record chain and offset operands. + if (VectorInfo[i] & PVF_FIRST) { + assert(StoreOperands.empty() && "Orphaned operand list."); + StoreOperands.push_back(Chain); + StoreOperands.push_back(DAG.getConstant(Offsets[i], dl, MVT::i32)); + } - // const_cast can be removed in later LLVM versions - EVT EltVT = getValueType(TD, RetTy).getVectorElementType(); - bool NeedExtend = false; - if (EltVT.getSizeInBits() < 16) - NeedExtend = true; - - // V1 store - if (NumElts == 1) { - SDValue StoreVal = OutVals[0]; - // We only have one element, so just directly store it - if (NeedExtend) - StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); - SDValue Ops[] = { Chain, DAG.getConstant(0, dl, MVT::i32), StoreVal }; - Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl, - DAG.getVTList(MVT::Other), Ops, - EltVT, MachinePointerInfo()); - } else if (NumElts == 2) { - // V2 store - SDValue StoreVal0 = OutVals[0]; - SDValue StoreVal1 = OutVals[1]; - - if (NeedExtend) { - StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0); - StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1); - } + SDValue RetVal = OutVals[i]; + if (ExtendIntegerRetVal) { + RetVal = DAG.getNode(Outs[i].Flags.isSExt() ? ISD::SIGN_EXTEND + : ISD::ZERO_EXTEND, + dl, MVT::i32, RetVal); + } else if (RetVal.getValueSizeInBits() < 16) { + // Use 16-bit registers for small load-stores as it's the + // smallest general purpose register size supported by NVPTX. 
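+ // For illustration (hypothetical signatures, not from a specific
+ // test): an 'i8 @f()' has an integer return type narrower than 32
+ // bits, so ExtendIntegerRetVal widens the value to i32 above and the
+ // st.param emitted below uses an i32 store type, whereas an i1 piece
+ // of an aggregate return is not covered by that rule and is merely
+ // any-extended to the 16-bit register width here.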
+ RetVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, RetVal); + } - SDValue Ops[] = { Chain, DAG.getConstant(0, dl, MVT::i32), StoreVal0, - StoreVal1 }; - Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl, - DAG.getVTList(MVT::Other), Ops, - EltVT, MachinePointerInfo()); - } else { - // V4 stores - // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the - // vector will be expanded to a power of 2 elements, so we know we can - // always round up to the next multiple of 4 when creating the vector - // stores. - // e.g. 4 elem => 1 st.v4 - // 6 elem => 2 st.v4 - // 8 elem => 2 st.v4 - // 11 elem => 3 st.v4 - - unsigned VecSize = 4; - if (OutVals[0].getValueSizeInBits() == 64) - VecSize = 2; - - unsigned Offset = 0; - - EVT VecVT = - EVT::getVectorVT(F->getContext(), EltVT, VecSize); - unsigned PerStoreOffset = - TD.getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); - - for (unsigned i = 0; i < NumElts; i += VecSize) { - // Get values - SDValue StoreVal; - SmallVector<SDValue, 8> Ops; - Ops.push_back(Chain); - Ops.push_back(DAG.getConstant(Offset, dl, MVT::i32)); - unsigned Opc = NVPTXISD::StoreRetvalV2; - EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType(); - - StoreVal = OutVals[i]; - if (NeedExtend) - StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); - Ops.push_back(StoreVal); - - if (i + 1 < NumElts) { - StoreVal = OutVals[i + 1]; - if (NeedExtend) - StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); - } else { - StoreVal = DAG.getUNDEF(ExtendedVT); - } - Ops.push_back(StoreVal); - - if (VecSize == 4) { - Opc = NVPTXISD::StoreRetvalV4; - if (i + 2 < NumElts) { - StoreVal = OutVals[i + 2]; - if (NeedExtend) - StoreVal = - DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); - } else { - StoreVal = DAG.getUNDEF(ExtendedVT); - } - Ops.push_back(StoreVal); - - if (i + 3 < NumElts) { - StoreVal = OutVals[i + 3]; - if (NeedExtend) - StoreVal = - DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); - } else { - StoreVal = DAG.getUNDEF(ExtendedVT); - } - Ops.push_back(StoreVal); - } + // Record the value to return. + StoreOperands.push_back(RetVal); - // Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size()); - Chain = - DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), Ops, - EltVT, MachinePointerInfo()); - Offset += PerStoreOffset; - } - } - } else { - SmallVector<EVT, 16> ValVTs; - SmallVector<uint64_t, 16> Offsets; - ComputePTXValueVTs(*this, DAG.getDataLayout(), RetTy, ValVTs, &Offsets, 0); - assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition"); - - for (unsigned i = 0, e = Outs.size(); i != e; ++i) { - SDValue theVal = OutVals[i]; - EVT TheValType = theVal.getValueType(); - unsigned numElems = 1; - if (TheValType.isVector()) - numElems = TheValType.getVectorNumElements(); - for (unsigned j = 0, je = numElems; j != je; ++j) { - SDValue TmpVal = theVal; - if (TheValType.isVector()) - TmpVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, - TheValType.getVectorElementType(), TmpVal, - DAG.getIntPtrConstant(j, dl)); - EVT TheStoreType = ValVTs[i]; - if (RetTy->isIntegerTy() && TD.getTypeAllocSizeInBits(RetTy) < 32) { - // The following zero-extension is for integer types only, and - // specifically not for aggregates. 
- TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal); - TheStoreType = MVT::i32; - } - else if (TmpVal.getValueSizeInBits() < 16) - TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal); - - SDValue Ops[] = { - Chain, - DAG.getConstant(Offsets[i], dl, MVT::i32), - TmpVal }; - Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl, - DAG.getVTList(MVT::Other), Ops, - TheStoreType, - MachinePointerInfo()); + // That's the last element of this store op. + if (VectorInfo[i] & PVF_LAST) { + NVPTXISD::NodeType Op; + unsigned NumElts = StoreOperands.size() - 2; + switch (NumElts) { + case 1: + Op = NVPTXISD::StoreRetval; + break; + case 2: + Op = NVPTXISD::StoreRetvalV2; + break; + case 4: + Op = NVPTXISD::StoreRetvalV4; + break; + default: + llvm_unreachable("Invalid vector info."); } + + // Adjust type of load/store op if we've extended the scalar + // return value. + EVT TheStoreType = ExtendIntegerRetVal ? MVT::i32 : VTs[i]; + Chain = DAG.getMemIntrinsicNode(Op, dl, DAG.getVTList(MVT::Other), + StoreOperands, TheStoreType, + MachinePointerInfo(), 1); + // Cleanup vector state. + StoreOperands.clear(); } } @@ -3863,27 +3874,35 @@ NVPTXTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, bool NVPTXTargetLowering::allowFMA(MachineFunction &MF, CodeGenOpt::Level OptLevel) const { - const Function *F = MF.getFunction(); - const TargetOptions &TO = MF.getTarget().Options; - // Always honor command-line argument - if (FMAContractLevelOpt.getNumOccurrences() > 0) { + if (FMAContractLevelOpt.getNumOccurrences() > 0) return FMAContractLevelOpt > 0; - } else if (OptLevel == 0) { - // Do not contract if we're not optimizing the code + + // Do not contract if we're not optimizing the code. + if (OptLevel == 0) return false; - } else if (TO.AllowFPOpFusion == FPOpFusion::Fast || TO.UnsafeFPMath) { - // Honor TargetOptions flags that explicitly say fusion is okay + + // Honor TargetOptions flags that explicitly say fusion is okay. + if (MF.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast) return true; - } else if (F->hasFnAttribute("unsafe-fp-math")) { - // Check for unsafe-fp-math=true coming from Clang + + return allowUnsafeFPMath(MF); +} + +bool NVPTXTargetLowering::allowUnsafeFPMath(MachineFunction &MF) const { + // Honor TargetOptions flags that explicitly say unsafe math is okay. + if (MF.getTarget().Options.UnsafeFPMath) + return true; + + // Allow unsafe math if unsafe-fp-math attribute explicitly says so. + const Function *F = MF.getFunction(); + if (F->hasFnAttribute("unsafe-fp-math")) { Attribute Attr = F->getFnAttribute("unsafe-fp-math"); StringRef Val = Attr.getValueAsString(); if (Val == "true") return true; } - // We did not have a clear indication that fusion is allowed, so assume not return false; } @@ -4088,67 +4107,6 @@ static SDValue PerformANDCombine(SDNode *N, return SDValue(); } -static SDValue PerformSELECTCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI) { - // Currently this detects patterns for integer min and max and - // lowers them to PTX-specific intrinsics that enable hardware - // support. 
- - const SDValue Cond = N->getOperand(0); - if (Cond.getOpcode() != ISD::SETCC) return SDValue(); - - const SDValue LHS = Cond.getOperand(0); - const SDValue RHS = Cond.getOperand(1); - const SDValue True = N->getOperand(1); - const SDValue False = N->getOperand(2); - if (!(LHS == True && RHS == False) && !(LHS == False && RHS == True)) - return SDValue(); - - const EVT VT = N->getValueType(0); - if (VT != MVT::i32 && VT != MVT::i64) return SDValue(); - - const ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get(); - SDValue Larger; // The larger of LHS and RHS when condition is true. - switch (CC) { - case ISD::SETULT: - case ISD::SETULE: - case ISD::SETLT: - case ISD::SETLE: - Larger = RHS; - break; - - case ISD::SETGT: - case ISD::SETGE: - case ISD::SETUGT: - case ISD::SETUGE: - Larger = LHS; - break; - - default: - return SDValue(); - } - const bool IsMax = (Larger == True); - const bool IsSigned = ISD::isSignedIntSetCC(CC); - - unsigned IntrinsicId; - if (VT == MVT::i32) { - if (IsSigned) - IntrinsicId = IsMax ? Intrinsic::nvvm_max_i : Intrinsic::nvvm_min_i; - else - IntrinsicId = IsMax ? Intrinsic::nvvm_max_ui : Intrinsic::nvvm_min_ui; - } else { - assert(VT == MVT::i64); - if (IsSigned) - IntrinsicId = IsMax ? Intrinsic::nvvm_max_ll : Intrinsic::nvvm_min_ll; - else - IntrinsicId = IsMax ? Intrinsic::nvvm_max_ull : Intrinsic::nvvm_min_ull; - } - - SDLoc DL(N); - return DCI.DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, - DCI.DAG.getConstant(IntrinsicId, DL, VT), LHS, RHS); -} - static SDValue PerformREMCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, CodeGenOpt::Level OptLevel) { @@ -4344,6 +4302,27 @@ static SDValue PerformSHLCombine(SDNode *N, return SDValue(); } +static SDValue PerformSETCCCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + EVT CCType = N->getValueType(0); + SDValue A = N->getOperand(0); + SDValue B = N->getOperand(1); + + if (CCType != MVT::v2i1 || A.getValueType() != MVT::v2f16) + return SDValue(); + + SDLoc DL(N); + // setp.f16x2 returns two scalar predicates, which we need to + // convert back to v2i1. The returned result will be scalarized by + // the legalizer, but the comparison will remain a single vector + // instruction. 
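+ // A worked example (hypothetical IR, shown only to illustrate the
+ // combine): '%c = fcmp olt <2 x half> %a, %b' reaches this hook as a
+ // (setcc v2i1 A, B, setolt) node; rewriting it as SETP_F16X2 yields
+ // two i1 predicates that can be selected as a single setp.lt.f16x2,
+ // and the BUILD_VECTOR below reassembles them into the expected v2i1.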
+ SDValue CCNode = DCI.DAG.getNode(NVPTXISD::SETP_F16X2, DL, + DCI.DAG.getVTList(MVT::i1, MVT::i1), + {A, B, N->getOperand(2)}); + return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, CCType, CCNode.getValue(0), + CCNode.getValue(1)); +} + SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { CodeGenOpt::Level OptLevel = getTargetMachine().getOptLevel(); @@ -4358,11 +4337,11 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, return PerformSHLCombine(N, DCI, OptLevel); case ISD::AND: return PerformANDCombine(N, DCI); - case ISD::SELECT: - return PerformSELECTCombine(N, DCI); case ISD::UREM: case ISD::SREM: return PerformREMCombine(N, DCI, OptLevel); + case ISD::SETCC: + return PerformSETCCCombine(N, DCI); } return SDValue(); } @@ -4386,12 +4365,15 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG, case MVT::v2i16: case MVT::v2i32: case MVT::v2i64: + case MVT::v2f16: case MVT::v2f32: case MVT::v2f64: case MVT::v4i8: case MVT::v4i16: case MVT::v4i32: + case MVT::v4f16: case MVT::v4f32: + case MVT::v8f16: // <4 x f16x2> // This is a "native" vector type break; } @@ -4425,6 +4407,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG, unsigned Opcode = 0; SDVTList LdResVTs; + bool LoadF16x2 = false; switch (NumElts) { default: @@ -4439,6 +4422,18 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG, LdResVTs = DAG.getVTList(ListVTs); break; } + case 8: { + // v8f16 is a special case. PTX doesn't have ld.v8.f16 + // instruction. Instead, we split the vector into v2f16 chunks and + // load them with ld.v4.b32. + assert(EltVT == MVT::f16 && "Unsupported v8 vector type."); + LoadF16x2 = true; + Opcode = NVPTXISD::LoadV4; + EVT ListVTs[] = {MVT::v2f16, MVT::v2f16, MVT::v2f16, MVT::v2f16, + MVT::Other}; + LdResVTs = DAG.getVTList(ListVTs); + break; + } } // Copy regular operands @@ -4452,13 +4447,26 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG, LD->getMemoryVT(), LD->getMemOperand()); - SmallVector<SDValue, 4> ScalarRes; - - for (unsigned i = 0; i < NumElts; ++i) { - SDValue Res = NewLD.getValue(i); - if (NeedTrunc) - Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res); - ScalarRes.push_back(Res); + SmallVector<SDValue, 8> ScalarRes; + if (LoadF16x2) { + // Split v2f16 subvectors back into individual elements. + NumElts /= 2; + for (unsigned i = 0; i < NumElts; ++i) { + SDValue SubVector = NewLD.getValue(i); + SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SubVector, + DAG.getIntPtrConstant(0, DL)); + SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SubVector, + DAG.getIntPtrConstant(1, DL)); + ScalarRes.push_back(E0); + ScalarRes.push_back(E1); + } + } else { + for (unsigned i = 0; i < NumElts; ++i) { + SDValue Res = NewLD.getValue(i); + if (NeedTrunc) + Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res); + ScalarRes.push_back(Res); + } } SDValue LoadChain = NewLD.getValue(NumElts); |
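  // To make the v8f16 path above concrete (a sketch, not tied to any
  // particular test case): a load of <8 x half> is rewritten as an
  // NVPTXISD::LoadV4 producing four v2f16 values plus a chain, which is
  // typically selected as a single ld.v4.b32. NumElts was halved to 4 in
  // that branch, so the chain is still found at value index NumElts, and
  // each v2f16 result is split back into the eight scalar halves that the
  // original vector load's users expect.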