aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp201
1 files changed, 150 insertions, 51 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 5b6b82aebf30..e8fedfeffde7 100644
--- a/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -24,9 +24,8 @@
using namespace llvm;
SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
- const SPIRVSubtarget &ST,
SPIRVGlobalRegistry *GR)
- : CallLowering(&TLI), ST(ST), GR(GR) {}
+ : CallLowering(&TLI), GR(GR) {}
bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
const Value *Val, ArrayRef<Register> VRegs,
@@ -36,11 +35,13 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
// TODO: handle the case of multiple registers.
if (VRegs.size() > 1)
return false;
- if (Val)
+ if (Val) {
+ const auto &STI = MIRBuilder.getMF().getSubtarget();
return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
.addUse(VRegs[0])
- .constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(),
- *ST.getRegBankInfo());
+ .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
+ *STI.getRegBankInfo());
+ }
MIRBuilder.buildInstr(SPIRV::OpReturn);
return true;
}
@@ -63,6 +64,56 @@ static uint32_t getFunctionControl(const Function &F) {
return FuncControl;
}
+static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
+ if (MD->getNumOperands() > NumOp) {
+ auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp));
+ if (CMeta)
+ return dyn_cast<ConstantInt>(CMeta->getValue());
+ }
+ return nullptr;
+}
+
+// This code restores function args/retvalue types for composite cases
+// because the final types should still be aggregate whereas they're i32
+// during the translation to cope with aggregate flattening etc.
+static FunctionType *getOriginalFunctionType(const Function &F) {
+ auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
+ if (NamedMD == nullptr)
+ return F.getFunctionType();
+
+ Type *RetTy = F.getFunctionType()->getReturnType();
+ SmallVector<Type *, 4> ArgTypes;
+ for (auto &Arg : F.args())
+ ArgTypes.push_back(Arg.getType());
+
+ auto ThisFuncMDIt =
+ std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
+ return isa<MDString>(N->getOperand(0)) &&
+ cast<MDString>(N->getOperand(0))->getString() == F.getName();
+ });
+ // TODO: probably one function can have numerous type mutations,
+ // so we should support this.
+ if (ThisFuncMDIt != NamedMD->op_end()) {
+ auto *ThisFuncMD = *ThisFuncMDIt;
+ MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
+ assert(MD && "MDNode operand is expected");
+ ConstantInt *Const = getConstInt(MD, 0);
+ if (Const) {
+ auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
+ assert(CMeta && "ConstantAsMetadata operand is expected");
+ assert(Const->getSExtValue() >= -1);
+ // Currently -1 indicates return value, greater values mean
+ // argument numbers.
+ if (Const->getSExtValue() == -1)
+ RetTy = CMeta->getType();
+ else
+ ArgTypes[Const->getSExtValue()] = CMeta->getType();
+ }
+ }
+
+ return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
+}
+
bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
const Function &F,
ArrayRef<ArrayRef<Register>> VRegs,
@@ -71,7 +122,8 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
GR->setCurrentFunc(MIRBuilder.getMF());
// Assign types and names to all args, and store their types for later.
- SmallVector<Register, 4> ArgTypeVRegs;
+ FunctionType *FTy = getOriginalFunctionType(F);
+ SmallVector<SPIRVType *, 4> ArgTypeVRegs;
if (VRegs.size() > 0) {
unsigned i = 0;
for (const auto &Arg : F.args()) {
@@ -79,9 +131,18 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// TODO: handle the case of multiple registers.
if (VRegs[i].size() > 1)
return false;
- auto *SpirvTy =
- GR->assignTypeToVReg(Arg.getType(), VRegs[i][0], MIRBuilder);
- ArgTypeVRegs.push_back(GR->getSPIRVTypeID(SpirvTy));
+ Type *ArgTy = FTy->getParamType(i);
+ SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite;
+ MDNode *Node = F.getMetadata("kernel_arg_access_qual");
+ if (Node && i < Node->getNumOperands()) {
+ StringRef AQString = cast<MDString>(Node->getOperand(i))->getString();
+ if (AQString.compare("read_only") == 0)
+ AQ = SPIRV::AccessQualifier::ReadOnly;
+ else if (AQString.compare("write_only") == 0)
+ AQ = SPIRV::AccessQualifier::WriteOnly;
+ }
+ auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ);
+ ArgTypeVRegs.push_back(SpirvTy);
if (Arg.hasName())
buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
@@ -92,8 +153,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
SPIRV::Decoration::MaxByteOffset, {DerefBytes});
}
if (Arg.hasAttribute(Attribute::Alignment)) {
+ auto Alignment = static_cast<unsigned>(
+ Arg.getAttribute(Attribute::Alignment).getValueAsInt());
buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
- {static_cast<unsigned>(Arg.getParamAlignment())});
+ {Alignment});
}
if (Arg.hasAttribute(Attribute::ReadOnly)) {
auto Attr =
@@ -107,6 +170,38 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
buildOpDecorate(VRegs[i][0], MIRBuilder,
SPIRV::Decoration::FuncParamAttr, {Attr});
}
+ if (Arg.hasAttribute(Attribute::NoAlias)) {
+ auto Attr =
+ static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
+ buildOpDecorate(VRegs[i][0], MIRBuilder,
+ SPIRV::Decoration::FuncParamAttr, {Attr});
+ }
+ Node = F.getMetadata("kernel_arg_type_qual");
+ if (Node && i < Node->getNumOperands()) {
+ StringRef TypeQual = cast<MDString>(Node->getOperand(i))->getString();
+ if (TypeQual.compare("volatile") == 0)
+ buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile,
+ {});
+ }
+ Node = F.getMetadata("spirv.ParameterDecorations");
+ if (Node && i < Node->getNumOperands() &&
+ isa<MDNode>(Node->getOperand(i))) {
+ MDNode *MD = cast<MDNode>(Node->getOperand(i));
+ for (const MDOperand &MDOp : MD->operands()) {
+ MDNode *MD2 = dyn_cast<MDNode>(MDOp);
+ assert(MD2 && "Metadata operand is expected");
+ ConstantInt *Const = getConstInt(MD2, 0);
+ assert(Const && "MDOperand should be ConstantInt");
+ auto Dec = static_cast<SPIRV::Decoration>(Const->getZExtValue());
+ std::vector<uint32_t> DecVec;
+ for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
+ ConstantInt *Const = getConstInt(MD2, j);
+ assert(Const && "MDOperand should be ConstantInt");
+ DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue()));
+ }
+ buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec);
+ }
+ }
++i;
}
}
@@ -117,30 +212,30 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
if (F.isDeclaration())
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
-
- auto *FTy = F.getFunctionType();
- auto FuncTy = GR->assignTypeToVReg(FTy, FuncVReg, MIRBuilder);
+ SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
+ SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
+ FTy, RetTy, ArgTypeVRegs, MIRBuilder);
// Build the OpTypeFunction declaring it.
- Register ReturnTypeID = FuncTy->getOperand(1).getReg();
uint32_t FuncControl = getFunctionControl(F);
MIRBuilder.buildInstr(SPIRV::OpFunction)
.addDef(FuncVReg)
- .addUse(ReturnTypeID)
+ .addUse(GR->getSPIRVTypeID(RetTy))
.addImm(FuncControl)
.addUse(GR->getSPIRVTypeID(FuncTy));
// Add OpFunctionParameters.
- const unsigned NumArgs = ArgTypeVRegs.size();
- for (unsigned i = 0; i < NumArgs; ++i) {
+ int i = 0;
+ for (const auto &Arg : F.args()) {
assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
.addDef(VRegs[i][0])
- .addUse(ArgTypeVRegs[i]);
+ .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
if (F.isDeclaration())
- GR->add(F.getArg(i), &MIRBuilder.getMF(), VRegs[i][0]);
+ GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
+ i++;
}
// Name the function.
if (F.hasName())
@@ -169,48 +264,51 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// TODO: handle the case of multiple registers.
if (Info.OrigRet.Regs.size() > 1)
return false;
+ MachineFunction &MF = MIRBuilder.getMF();
+ GR->setCurrentFunc(MF);
+ FunctionType *FTy = nullptr;
+ const Function *CF = nullptr;
- GR->setCurrentFunc(MIRBuilder.getMF());
- Register ResVReg =
- Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
// Emit a regular OpFunctionCall. If it's an externally declared function,
- // be sure to emit its type and function declaration here. It will be
- // hoisted globally later.
+ // be sure to emit its type and function declaration here. It will be hoisted
+ // globally later.
if (Info.Callee.isGlobal()) {
- auto *CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
+ CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
return false;
- if (CF->isDeclaration()) {
- // Emit the type info and forward function declaration to the first MBB
- // to ensure VReg definition dependencies are valid across all MBBs.
- MachineBasicBlock::iterator OldII = MIRBuilder.getInsertPt();
- MachineBasicBlock &OldBB = MIRBuilder.getMBB();
- MachineBasicBlock &FirstBB = *MIRBuilder.getMF().getBlockNumbered(0);
- MIRBuilder.setInsertPt(FirstBB, FirstBB.instr_end());
-
- SmallVector<ArrayRef<Register>, 8> VRegArgs;
- SmallVector<SmallVector<Register, 1>, 8> ToInsert;
- for (const Argument &Arg : CF->args()) {
- if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
- continue; // Don't handle zero sized types.
- ToInsert.push_back({MIRBuilder.getMRI()->createGenericVirtualRegister(
- LLT::scalar(32))});
- VRegArgs.push_back(ToInsert.back());
- }
- // TODO: Reuse FunctionLoweringInfo.
- FunctionLoweringInfo FuncInfo;
- lowerFormalArguments(MIRBuilder, *CF, VRegArgs, FuncInfo);
- MIRBuilder.setInsertPt(OldBB, OldII);
+ FTy = getOriginalFunctionType(*CF);
+ }
+
+ Register ResVReg =
+ Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
+ if (CF && CF->isDeclaration() &&
+ !GR->find(CF, &MIRBuilder.getMF()).isValid()) {
+ // Emit the type info and forward function declaration to the first MBB
+ // to ensure VReg definition dependencies are valid across all MBBs.
+ MachineIRBuilder FirstBlockBuilder;
+ FirstBlockBuilder.setMF(MF);
+ FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));
+
+ SmallVector<ArrayRef<Register>, 8> VRegArgs;
+ SmallVector<SmallVector<Register, 1>, 8> ToInsert;
+ for (const Argument &Arg : CF->args()) {
+ if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
+ continue; // Don't handle zero sized types.
+ ToInsert.push_back(
+ {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))});
+ VRegArgs.push_back(ToInsert.back());
}
+ // TODO: Reuse FunctionLoweringInfo
+ FunctionLoweringInfo FuncInfo;
+ lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
}
// Make sure there's a valid return reg, even for functions returning void.
- if (!ResVReg.isValid()) {
+ if (!ResVReg.isValid())
ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
- }
SPIRVType *RetType =
- GR->assignTypeToVReg(Info.OrigRet.Ty, ResVReg, MIRBuilder);
+ GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
// Emit the OpFunctionCall and its args.
auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
@@ -224,6 +322,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
return false;
MIB.addUse(Arg.Regs[0]);
}
- return MIB.constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(),
- *ST.getRegBankInfo());
+ const auto &STI = MF.getSubtarget();
+ return MIB.constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
+ *STI.getRegBankInfo());
}