diff options
Diffstat (limited to 'lib/Target/PTX/PTXISelLowering.cpp')
-rw-r--r-- | lib/Target/PTX/PTXISelLowering.cpp | 273 |
1 files changed, 174 insertions, 99 deletions
diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp index 6fcf710e3f1f..3307d91a6188 100644 --- a/lib/Target/PTX/PTXISelLowering.cpp +++ b/lib/Target/PTX/PTXISelLowering.cpp @@ -16,23 +16,19 @@ #include "PTXMachineFunctionInfo.h" #include "PTXRegisterInfo.h" #include "PTXSubtarget.h" +#include "llvm/Function.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; //===----------------------------------------------------------------------===// -// Calling Convention Implementation -//===----------------------------------------------------------------------===// - -#include "PTXGenCallingConv.inc" - -//===----------------------------------------------------------------------===// // TargetLowering Implementation //===----------------------------------------------------------------------===// @@ -47,57 +43,58 @@ PTXTargetLowering::PTXTargetLowering(TargetMachine &TM) addRegisterClass(MVT::f64, PTX::RegF64RegisterClass); setBooleanContents(ZeroOrOneBooleanContent); + setBooleanVectorContents(ZeroOrOneBooleanContent); // FIXME: Is this correct? setMinFunctionAlignment(2); - + //////////////////////////////////// /////////// Expansion ////////////// //////////////////////////////////// - + // (any/zero/sign) extload => load + (any/zero/sign) extend - + setLoadExtAction(ISD::EXTLOAD, MVT::i16, Expand); setLoadExtAction(ISD::ZEXTLOAD, MVT::i16, Expand); setLoadExtAction(ISD::SEXTLOAD, MVT::i16, Expand); - + // f32 extload => load + fextend - - setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand); - + + setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand); + // f64 truncstore => trunc + store - - setTruncStoreAction(MVT::f64, MVT::f32, Expand); - + + setTruncStoreAction(MVT::f64, MVT::f32, Expand); + // sign_extend_inreg => sign_extend - + setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand); - + // br_cc => brcond - + setOperationAction(ISD::BR_CC, MVT::Other, Expand); // select_cc => setcc - + setOperationAction(ISD::SELECT_CC, MVT::Other, Expand); setOperationAction(ISD::SELECT_CC, MVT::f32, Expand); setOperationAction(ISD::SELECT_CC, MVT::f64, Expand); - + //////////////////////////////////// //////////// Legal ///////////////// //////////////////////////////////// - + setOperationAction(ISD::ConstantFP, MVT::f32, Legal); setOperationAction(ISD::ConstantFP, MVT::f64, Legal); - + //////////////////////////////////// //////////// Custom //////////////// //////////////////////////////////// - + // customise setcc to use bitwise logic if possible - + setOperationAction(ISD::SETCC, MVT::i1, Custom); // customize translation of memory addresses - + setOperationAction(ISD::GlobalAddress, MVT::i32, Custom); setOperationAction(ISD::GlobalAddress, MVT::i64, Custom); @@ -105,7 +102,7 @@ PTXTargetLowering::PTXTargetLowering(TargetMachine &TM) computeRegisterProperties(); } -MVT::SimpleValueType PTXTargetLowering::getSetCCResultType(EVT VT) const { +EVT PTXTargetLowering::getSetCCResultType(EVT VT) const { return MVT::i1; } @@ -130,10 +127,16 @@ const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "PTXISD::LOAD_PARAM"; case PTXISD::STORE_PARAM: return "PTXISD::STORE_PARAM"; + case PTXISD::READ_PARAM: + return "PTXISD::READ_PARAM"; + case PTXISD::WRITE_PARAM: + return "PTXISD::WRITE_PARAM"; case PTXISD::EXIT: return "PTXISD::EXIT"; case PTXISD::RET: return "PTXISD::RET"; + case PTXISD::CALL: + return "PTXISD::CALL"; } } @@ -149,7 +152,7 @@ SDValue PTXTargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const { DebugLoc dl = Op.getDebugLoc(); ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get(); - // Look for X == 0, X == 1, X != 0, or X != 1 + // Look for X == 0, X == 1, X != 0, or X != 1 // We can simplify these to bitwise logic if (Op1.getOpcode() == ISD::Constant && @@ -197,6 +200,7 @@ SDValue PTXTargetLowering:: MachineFunction &MF = DAG.getMachineFunction(); const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>(); PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); + PTXParamManager &PM = MFI->getParamManager(); switch (CallConv) { default: @@ -216,68 +220,34 @@ SDValue PTXTargetLowering:: if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) { // We just need to emit the proper LOAD_PARAM ISDs for (unsigned i = 0, e = Ins.size(); i != e; ++i) { - assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) && "Kernels cannot take pred operands"); + unsigned ParamSize = Ins[i].VT.getStoreSizeInBits(); + unsigned Param = PM.addArgumentParam(ParamSize); + const std::string &ParamName = PM.getParamName(Param); + SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), + MVT::Other); SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain, - DAG.getTargetConstant(i, MVT::i32)); + ParamValue); InVals.push_back(ArgValue); - - // Instead of storing a physical register in our argument list, we just - // store the total size of the parameter, in bits. The ASM printer - // knows how to process this. - MFI->addArgReg(Ins[i].VT.getStoreSizeInBits()); } } else { - // For device functions, we use the PTX calling convention to do register - // assignments then create CopyFromReg ISDs for the allocated registers - - SmallVector<CCValAssign, 16> ArgLocs; - CCState CCInfo(CallConv, isVarArg, MF, getTargetMachine(), ArgLocs, - *DAG.getContext()); - - CCInfo.AnalyzeFormalArguments(Ins, CC_PTX); - - for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { - - CCValAssign& VA = ArgLocs[i]; - EVT RegVT = VA.getLocVT(); - TargetRegisterClass* TRC = 0; - - assert(VA.isRegLoc() && "CCValAssign must be RegLoc"); - - // Determine which register class we need - if (RegVT == MVT::i1) { - TRC = PTX::RegPredRegisterClass; - } - else if (RegVT == MVT::i16) { - TRC = PTX::RegI16RegisterClass; - } - else if (RegVT == MVT::i32) { - TRC = PTX::RegI32RegisterClass; - } - else if (RegVT == MVT::i64) { - TRC = PTX::RegI64RegisterClass; - } - else if (RegVT == MVT::f32) { - TRC = PTX::RegF32RegisterClass; - } - else if (RegVT == MVT::f64) { - TRC = PTX::RegF64RegisterClass; - } - else { - llvm_unreachable("Unknown parameter type"); - } + for (unsigned i = 0, e = Ins.size(); i != e; ++i) { + EVT RegVT = Ins[i].VT; + TargetRegisterClass* TRC = getRegClassFor(RegVT); + // Use a unique index in the instruction to prevent instruction folding. + // Yes, this is a hack. + SDValue Index = DAG.getTargetConstant(i, MVT::i32); unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC); - MF.getRegInfo().addLiveIn(VA.getLocReg(), Reg); + SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain, + Index); - SDValue ArgValue = DAG.getCopyFromReg(Chain, dl, Reg, RegVT); InVals.push_back(ArgValue); - MFI->addArgReg(VA.getLocReg()); + MFI->addArgReg(Reg); } } @@ -301,41 +271,66 @@ SDValue PTXTargetLowering:: assert(Outs.size() == 0 && "Kernel must return void."); return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain); case CallingConv::PTX_Device: - //assert(Outs.size() <= 1 && "Can at most return one value."); + assert(Outs.size() <= 1 && "Can at most return one value."); break; } MachineFunction& MF = DAG.getMachineFunction(); PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); + PTXParamManager &PM = MFI->getParamManager(); SDValue Flag; + const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>(); - // Even though we could use the .param space for return arguments for - // device functions if SM >= 2.0 and the number of return arguments is - // only 1, we just always use registers since this makes the codegen - // easier. - SmallVector<CCValAssign, 16> RVLocs; - CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), - getTargetMachine(), RVLocs, *DAG.getContext()); - - CCInfo.AnalyzeReturn(Outs, RetCC_PTX); - - for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) { - CCValAssign& VA = RVLocs[i]; - - assert(VA.isRegLoc() && "CCValAssign must be RegLoc"); + if (ST.useParamSpaceForDeviceArgs()) { + assert(Outs.size() < 2 && "Device functions can return at most one value"); + + if (Outs.size() == 1) { + unsigned ParamSize = OutVals[0].getValueType().getSizeInBits(); + unsigned Param = PM.addReturnParam(ParamSize); + const std::string &ParamName = PM.getParamName(Param); + SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), + MVT::Other); + Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, + ParamValue, OutVals[0]); + } + } else { + for (unsigned i = 0, e = Outs.size(); i != e; ++i) { + EVT RegVT = Outs[i].VT; + TargetRegisterClass* TRC = 0; - unsigned Reg = VA.getLocReg(); + // Determine which register class we need + if (RegVT == MVT::i1) { + TRC = PTX::RegPredRegisterClass; + } + else if (RegVT == MVT::i16) { + TRC = PTX::RegI16RegisterClass; + } + else if (RegVT == MVT::i32) { + TRC = PTX::RegI32RegisterClass; + } + else if (RegVT == MVT::i64) { + TRC = PTX::RegI64RegisterClass; + } + else if (RegVT == MVT::f32) { + TRC = PTX::RegF32RegisterClass; + } + else if (RegVT == MVT::f64) { + TRC = PTX::RegF64RegisterClass; + } + else { + llvm_unreachable("Unknown parameter type"); + } - DAG.getMachineFunction().getRegInfo().addLiveOut(Reg); + unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC); - Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag); + SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/); + SDValue OutReg = DAG.getRegister(Reg, RegVT); - // Guarantee that all emitted copies are stuck together, - // avoiding something bad - Flag = Chain.getValue(1); + Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg); - MFI->addRetReg(Reg); + MFI->addRetReg(Reg); + } } if (Flag.getNode() == 0) { @@ -345,3 +340,83 @@ SDValue PTXTargetLowering:: return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag); } } + +SDValue +PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee, + CallingConv::ID CallConv, bool isVarArg, + bool &isTailCall, + const SmallVectorImpl<ISD::OutputArg> &Outs, + const SmallVectorImpl<SDValue> &OutVals, + const SmallVectorImpl<ISD::InputArg> &Ins, + DebugLoc dl, SelectionDAG &DAG, + SmallVectorImpl<SDValue> &InVals) const { + + MachineFunction& MF = DAG.getMachineFunction(); + PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); + PTXParamManager &PM = MFI->getParamManager(); + + assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() && + "Calls are not handled for the target device"); + + std::vector<SDValue> Ops; + // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs] + Ops.resize(Outs.size() + Ins.size() + 4); + + Ops[0] = Chain; + + // Identify the callee function + const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal(); + assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device && + "PTX function calls must be to PTX device functions"); + Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); + Ops[Ins.size()+2] = Callee; + + // Generate STORE_PARAM nodes for each function argument. In PTX, function + // arguments are explicitly stored into .param variables and passed as + // arguments. There is no register/stack-based calling convention in PTX. + Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32); + for (unsigned i = 0; i != OutVals.size(); ++i) { + unsigned Size = OutVals[i].getValueType().getSizeInBits(); + unsigned Param = PM.addLocalParam(Size); + const std::string &ParamName = PM.getParamName(Param); + SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), + MVT::Other); + Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, + ParamValue, OutVals[i]); + Ops[i+Ins.size()+4] = ParamValue; + } + + std::vector<SDValue> InParams; + + // Generate list of .param variables to hold the return value(s). + Ops[1] = DAG.getTargetConstant(Ins.size(), MVT::i32); + for (unsigned i = 0; i < Ins.size(); ++i) { + unsigned Size = Ins[i].VT.getStoreSizeInBits(); + unsigned Param = PM.addLocalParam(Size); + const std::string &ParamName = PM.getParamName(Param); + SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(), + MVT::Other); + Ops[i+2] = ParamValue; + InParams.push_back(ParamValue); + } + + Ops[0] = Chain; + + // Create the CALL node. + Chain = DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size()); + + // Create the LOAD_PARAM nodes that retrieve the function return value(s). + for (unsigned i = 0; i < Ins.size(); ++i) { + SDValue Load = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain, + InParams[i]); + InVals.push_back(Load); + } + + return Chain; +} + +unsigned PTXTargetLowering::getNumRegisters(LLVMContext &Context, EVT VT) { + // All arguments consist of one "register," regardless of the type. + return 1; +} + |