diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 201 |
1 files changed, 150 insertions, 51 deletions
diff --git a/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 5b6b82aebf30..e8fedfeffde7 100644 --- a/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -24,9 +24,8 @@ using namespace llvm; SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI, - const SPIRVSubtarget &ST, SPIRVGlobalRegistry *GR) - : CallLowering(&TLI), ST(ST), GR(GR) {} + : CallLowering(&TLI), GR(GR) {} bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, const Value *Val, ArrayRef<Register> VRegs, @@ -36,11 +35,13 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, // TODO: handle the case of multiple registers. if (VRegs.size() > 1) return false; - if (Val) + if (Val) { + const auto &STI = MIRBuilder.getMF().getSubtarget(); return MIRBuilder.buildInstr(SPIRV::OpReturnValue) .addUse(VRegs[0]) - .constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(), - *ST.getRegBankInfo()); + .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(), + *STI.getRegBankInfo()); + } MIRBuilder.buildInstr(SPIRV::OpReturn); return true; } @@ -63,6 +64,56 @@ static uint32_t getFunctionControl(const Function &F) { return FuncControl; } +static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) { + if (MD->getNumOperands() > NumOp) { + auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp)); + if (CMeta) + return dyn_cast<ConstantInt>(CMeta->getValue()); + } + return nullptr; +} + +// This code restores function args/retvalue types for composite cases +// because the final types should still be aggregate whereas they're i32 +// during the translation to cope with aggregate flattening etc. +static FunctionType *getOriginalFunctionType(const Function &F) { + auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs"); + if (NamedMD == nullptr) + return F.getFunctionType(); + + Type *RetTy = F.getFunctionType()->getReturnType(); + SmallVector<Type *, 4> ArgTypes; + for (auto &Arg : F.args()) + ArgTypes.push_back(Arg.getType()); + + auto ThisFuncMDIt = + std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) { + return isa<MDString>(N->getOperand(0)) && + cast<MDString>(N->getOperand(0))->getString() == F.getName(); + }); + // TODO: probably one function can have numerous type mutations, + // so we should support this. + if (ThisFuncMDIt != NamedMD->op_end()) { + auto *ThisFuncMD = *ThisFuncMDIt; + MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1)); + assert(MD && "MDNode operand is expected"); + ConstantInt *Const = getConstInt(MD, 0); + if (Const) { + auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1)); + assert(CMeta && "ConstantAsMetadata operand is expected"); + assert(Const->getSExtValue() >= -1); + // Currently -1 indicates return value, greater values mean + // argument numbers. + if (Const->getSExtValue() == -1) + RetTy = CMeta->getType(); + else + ArgTypes[Const->getSExtValue()] = CMeta->getType(); + } + } + + return FunctionType::get(RetTy, ArgTypes, F.isVarArg()); +} + bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, const Function &F, ArrayRef<ArrayRef<Register>> VRegs, @@ -71,7 +122,8 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, GR->setCurrentFunc(MIRBuilder.getMF()); // Assign types and names to all args, and store their types for later. - SmallVector<Register, 4> ArgTypeVRegs; + FunctionType *FTy = getOriginalFunctionType(F); + SmallVector<SPIRVType *, 4> ArgTypeVRegs; if (VRegs.size() > 0) { unsigned i = 0; for (const auto &Arg : F.args()) { @@ -79,9 +131,18 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, // TODO: handle the case of multiple registers. if (VRegs[i].size() > 1) return false; - auto *SpirvTy = - GR->assignTypeToVReg(Arg.getType(), VRegs[i][0], MIRBuilder); - ArgTypeVRegs.push_back(GR->getSPIRVTypeID(SpirvTy)); + Type *ArgTy = FTy->getParamType(i); + SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite; + MDNode *Node = F.getMetadata("kernel_arg_access_qual"); + if (Node && i < Node->getNumOperands()) { + StringRef AQString = cast<MDString>(Node->getOperand(i))->getString(); + if (AQString.compare("read_only") == 0) + AQ = SPIRV::AccessQualifier::ReadOnly; + else if (AQString.compare("write_only") == 0) + AQ = SPIRV::AccessQualifier::WriteOnly; + } + auto *SpirvTy = GR->assignTypeToVReg(ArgTy, VRegs[i][0], MIRBuilder, AQ); + ArgTypeVRegs.push_back(SpirvTy); if (Arg.hasName()) buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder); @@ -92,8 +153,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, SPIRV::Decoration::MaxByteOffset, {DerefBytes}); } if (Arg.hasAttribute(Attribute::Alignment)) { + auto Alignment = static_cast<unsigned>( + Arg.getAttribute(Attribute::Alignment).getValueAsInt()); buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment, - {static_cast<unsigned>(Arg.getParamAlignment())}); + {Alignment}); } if (Arg.hasAttribute(Attribute::ReadOnly)) { auto Attr = @@ -107,6 +170,38 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::FuncParamAttr, {Attr}); } + if (Arg.hasAttribute(Attribute::NoAlias)) { + auto Attr = + static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias); + buildOpDecorate(VRegs[i][0], MIRBuilder, + SPIRV::Decoration::FuncParamAttr, {Attr}); + } + Node = F.getMetadata("kernel_arg_type_qual"); + if (Node && i < Node->getNumOperands()) { + StringRef TypeQual = cast<MDString>(Node->getOperand(i))->getString(); + if (TypeQual.compare("volatile") == 0) + buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Volatile, + {}); + } + Node = F.getMetadata("spirv.ParameterDecorations"); + if (Node && i < Node->getNumOperands() && + isa<MDNode>(Node->getOperand(i))) { + MDNode *MD = cast<MDNode>(Node->getOperand(i)); + for (const MDOperand &MDOp : MD->operands()) { + MDNode *MD2 = dyn_cast<MDNode>(MDOp); + assert(MD2 && "Metadata operand is expected"); + ConstantInt *Const = getConstInt(MD2, 0); + assert(Const && "MDOperand should be ConstantInt"); + auto Dec = static_cast<SPIRV::Decoration>(Const->getZExtValue()); + std::vector<uint32_t> DecVec; + for (unsigned j = 1; j < MD2->getNumOperands(); j++) { + ConstantInt *Const = getConstInt(MD2, j); + assert(Const && "MDOperand should be ConstantInt"); + DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue())); + } + buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec); + } + } ++i; } } @@ -117,30 +212,30 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); if (F.isDeclaration()) GR->add(&F, &MIRBuilder.getMF(), FuncVReg); - - auto *FTy = F.getFunctionType(); - auto FuncTy = GR->assignTypeToVReg(FTy, FuncVReg, MIRBuilder); + SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder); + SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs( + FTy, RetTy, ArgTypeVRegs, MIRBuilder); // Build the OpTypeFunction declaring it. - Register ReturnTypeID = FuncTy->getOperand(1).getReg(); uint32_t FuncControl = getFunctionControl(F); MIRBuilder.buildInstr(SPIRV::OpFunction) .addDef(FuncVReg) - .addUse(ReturnTypeID) + .addUse(GR->getSPIRVTypeID(RetTy)) .addImm(FuncControl) .addUse(GR->getSPIRVTypeID(FuncTy)); // Add OpFunctionParameters. - const unsigned NumArgs = ArgTypeVRegs.size(); - for (unsigned i = 0; i < NumArgs; ++i) { + int i = 0; + for (const auto &Arg : F.args()) { assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs"); MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass); MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) .addDef(VRegs[i][0]) - .addUse(ArgTypeVRegs[i]); + .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i])); if (F.isDeclaration()) - GR->add(F.getArg(i), &MIRBuilder.getMF(), VRegs[i][0]); + GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]); + i++; } // Name the function. if (F.hasName()) @@ -169,48 +264,51 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, // TODO: handle the case of multiple registers. if (Info.OrigRet.Regs.size() > 1) return false; + MachineFunction &MF = MIRBuilder.getMF(); + GR->setCurrentFunc(MF); + FunctionType *FTy = nullptr; + const Function *CF = nullptr; - GR->setCurrentFunc(MIRBuilder.getMF()); - Register ResVReg = - Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; // Emit a regular OpFunctionCall. If it's an externally declared function, - // be sure to emit its type and function declaration here. It will be - // hoisted globally later. + // be sure to emit its type and function declaration here. It will be hoisted + // globally later. if (Info.Callee.isGlobal()) { - auto *CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal()); + CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal()); // TODO: support constexpr casts and indirect calls. if (CF == nullptr) return false; - if (CF->isDeclaration()) { - // Emit the type info and forward function declaration to the first MBB - // to ensure VReg definition dependencies are valid across all MBBs. - MachineBasicBlock::iterator OldII = MIRBuilder.getInsertPt(); - MachineBasicBlock &OldBB = MIRBuilder.getMBB(); - MachineBasicBlock &FirstBB = *MIRBuilder.getMF().getBlockNumbered(0); - MIRBuilder.setInsertPt(FirstBB, FirstBB.instr_end()); - - SmallVector<ArrayRef<Register>, 8> VRegArgs; - SmallVector<SmallVector<Register, 1>, 8> ToInsert; - for (const Argument &Arg : CF->args()) { - if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) - continue; // Don't handle zero sized types. - ToInsert.push_back({MIRBuilder.getMRI()->createGenericVirtualRegister( - LLT::scalar(32))}); - VRegArgs.push_back(ToInsert.back()); - } - // TODO: Reuse FunctionLoweringInfo. - FunctionLoweringInfo FuncInfo; - lowerFormalArguments(MIRBuilder, *CF, VRegArgs, FuncInfo); - MIRBuilder.setInsertPt(OldBB, OldII); + FTy = getOriginalFunctionType(*CF); + } + + Register ResVReg = + Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; + if (CF && CF->isDeclaration() && + !GR->find(CF, &MIRBuilder.getMF()).isValid()) { + // Emit the type info and forward function declaration to the first MBB + // to ensure VReg definition dependencies are valid across all MBBs. + MachineIRBuilder FirstBlockBuilder; + FirstBlockBuilder.setMF(MF); + FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0)); + + SmallVector<ArrayRef<Register>, 8> VRegArgs; + SmallVector<SmallVector<Register, 1>, 8> ToInsert; + for (const Argument &Arg : CF->args()) { + if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) + continue; // Don't handle zero sized types. + ToInsert.push_back( + {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))}); + VRegArgs.push_back(ToInsert.back()); } + // TODO: Reuse FunctionLoweringInfo + FunctionLoweringInfo FuncInfo; + lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo); } // Make sure there's a valid return reg, even for functions returning void. - if (!ResVReg.isValid()) { + if (!ResVReg.isValid()) ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); - } SPIRVType *RetType = - GR->assignTypeToVReg(Info.OrigRet.Ty, ResVReg, MIRBuilder); + GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder); // Emit the OpFunctionCall and its args. auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall) @@ -224,6 +322,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, return false; MIB.addUse(Arg.Regs[0]); } - return MIB.constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(), - *ST.getRegBankInfo()); + const auto &STI = MF.getSubtarget(); + return MIB.constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(), + *STI.getRegBankInfo()); } |