diff options
Diffstat (limited to 'llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp')
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp | 464 |
1 files changed, 428 insertions, 36 deletions
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index ae6d6f96904c..5fd5c226eef8 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -1,75 +1,467 @@ -//===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata ---*- C++ -*-===// +//===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -/// -//===----------------------------------------------------------------------===// -#include "DXILMetadata.h" +#include "DXILTranslateMetadata.h" #include "DXILResource.h" #include "DXILResourceAnalysis.h" #include "DXILShaderFlags.h" #include "DirectX.h" -#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/DXILMetadataAnalysis.h" +#include "llvm/Analysis/DXILResource.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/DiagnosticPrinter.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/VersionTuple.h" #include "llvm/TargetParser/Triple.h" +#include <cstdint> using namespace llvm; using namespace llvm::dxil; namespace { -class DXILTranslateMetadata : public ModulePass { -public: - static char ID; // Pass identification, replacement for typeid - explicit DXILTranslateMetadata() : ModulePass(ID) {} +/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic +/// for TranslateMetadata pass +class DiagnosticInfoTranslateMD : public DiagnosticInfo { +private: + const Twine &Msg; + const Module &Mod; - StringRef getPassName() const override { return "DXIL Metadata Emit"; } +public: + /// \p M is the module for which the diagnostic is being emitted. \p Msg is + /// the message to show. Note that this class does not copy this message, so + /// this reference must be valid for the whole life time of the diagnostic. + DiagnosticInfoTranslateMD(const Module &M, const Twine &Msg, + DiagnosticSeverity Severity = DS_Error) + : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {} - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - AU.addRequired<DXILResourceWrapper>(); - AU.addRequired<ShaderFlagsAnalysisWrapper>(); + void print(DiagnosticPrinter &DP) const override { + DP << Mod.getName() << ": " << Msg << '\n'; } +}; - bool runOnModule(Module &M) override; +enum class EntryPropsTag { + ShaderFlags = 0, + GSState, + DSState, + HSState, + NumThreads, + AutoBindingSpace, + RayPayloadSize, + RayAttribSize, + ShaderKind, + MSState, + ASStateTag, + WaveSize, + EntryRootSig, }; } // namespace -bool DXILTranslateMetadata::runOnModule(Module &M) { +static NamedMDNode *emitResourceMetadata(Module &M, DXILBindingMap &DBM, + DXILResourceTypeMap &DRTM, + const dxil::Resources &MDResources) { + LLVMContext &Context = M.getContext(); + + for (ResourceBindingInfo &RI : DBM) + if (!RI.hasSymbol()) + RI.createSymbol(M, DRTM[RI.getHandleTy()].createElementStruct()); + + SmallVector<Metadata *> SRVs, UAVs, CBufs, Smps; + for (const ResourceBindingInfo &RI : DBM.srvs()) + SRVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); + for (const ResourceBindingInfo &RI : DBM.uavs()) + UAVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); + for (const ResourceBindingInfo &RI : DBM.cbuffers()) + CBufs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); + for (const ResourceBindingInfo &RI : DBM.samplers()) + Smps.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); + + Metadata *SRVMD = SRVs.empty() ? nullptr : MDNode::get(Context, SRVs); + Metadata *UAVMD = UAVs.empty() ? nullptr : MDNode::get(Context, UAVs); + Metadata *CBufMD = CBufs.empty() ? nullptr : MDNode::get(Context, CBufs); + Metadata *SmpMD = Smps.empty() ? nullptr : MDNode::get(Context, Smps); + bool HasResources = !DBM.empty(); + + if (MDResources.hasUAVs()) { + assert(!UAVMD && "Old and new UAV representations can't coexist"); + UAVMD = MDResources.writeUAVs(M); + HasResources = true; + } + + if (MDResources.hasCBuffers()) { + assert(!CBufMD && "Old and new cbuffer representations can't coexist"); + CBufMD = MDResources.writeCBuffers(M); + HasResources = true; + } + + if (!HasResources) + return nullptr; + + NamedMDNode *ResourceMD = M.getOrInsertNamedMetadata("dx.resources"); + ResourceMD->addOperand( + MDNode::get(M.getContext(), {SRVMD, UAVMD, CBufMD, SmpMD})); + + return ResourceMD; +} + +static StringRef getShortShaderStage(Triple::EnvironmentType Env) { + switch (Env) { + case Triple::Pixel: + return "ps"; + case Triple::Vertex: + return "vs"; + case Triple::Geometry: + return "gs"; + case Triple::Hull: + return "hs"; + case Triple::Domain: + return "ds"; + case Triple::Compute: + return "cs"; + case Triple::Library: + return "lib"; + case Triple::Mesh: + return "ms"; + case Triple::Amplification: + return "as"; + default: + break; + } + llvm_unreachable("Unsupported environment for DXIL generation."); +} + +static uint32_t getShaderStage(Triple::EnvironmentType Env) { + return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel; +} + +static SmallVector<Metadata *> +getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) { + SmallVector<Metadata *> MDVals; + MDVals.emplace_back(ConstantAsMetadata::get( + ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag)))); + switch (Tag) { + case EntryPropsTag::ShaderFlags: + MDVals.emplace_back(ConstantAsMetadata::get( + ConstantInt::get(Type::getInt64Ty(Ctx), Value))); + break; + case EntryPropsTag::ShaderKind: + MDVals.emplace_back(ConstantAsMetadata::get( + ConstantInt::get(Type::getInt32Ty(Ctx), Value))); + break; + case EntryPropsTag::GSState: + case EntryPropsTag::DSState: + case EntryPropsTag::HSState: + case EntryPropsTag::NumThreads: + case EntryPropsTag::AutoBindingSpace: + case EntryPropsTag::RayPayloadSize: + case EntryPropsTag::RayAttribSize: + case EntryPropsTag::MSState: + case EntryPropsTag::ASStateTag: + case EntryPropsTag::WaveSize: + case EntryPropsTag::EntryRootSig: + llvm_unreachable("NYI: Unhandled entry property tag"); + } + return MDVals; +} + +static MDTuple * +getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags, + const Triple::EnvironmentType ShaderProfile) { + SmallVector<Metadata *> MDVals; + LLVMContext &Ctx = EP.Entry->getContext(); + if (EntryShaderFlags != 0) + MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags, + EntryShaderFlags, Ctx)); + + if (EP.Entry != nullptr) { + // FIXME: support more props. + // See https://github.com/llvm/llvm-project/issues/57948. + // Add shader kind for lib entries. + if (ShaderProfile == Triple::EnvironmentType::Library && + EP.ShaderStage != Triple::EnvironmentType::Library) + MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind, + getShaderStage(EP.ShaderStage), Ctx)); + + if (EP.ShaderStage == Triple::EnvironmentType::Compute) { + MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads)))); + Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(Ctx), EP.NumThreadsX)), + ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(Ctx), EP.NumThreadsY)), + ConstantAsMetadata::get(ConstantInt::get( + Type::getInt32Ty(Ctx), EP.NumThreadsZ))}; + MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals)); + } + } + if (MDVals.empty()) + return nullptr; + return MDNode::get(Ctx, MDVals); +} + +MDTuple *constructEntryMetadata(const Function *EntryFn, MDTuple *Signatures, + MDNode *Resources, MDTuple *Properties, + LLVMContext &Ctx) { + // Each entry point metadata record specifies: + // * reference to the entry point function global symbol + // * unmangled name + // * list of signatures + // * list of resources + // * list of tag-value pairs of shader capabilities and other properties + Metadata *MDVals[5]; + MDVals[0] = + EntryFn ? ValueAsMetadata::get(const_cast<Function *>(EntryFn)) : nullptr; + MDVals[1] = MDString::get(Ctx, EntryFn ? EntryFn->getName() : ""); + MDVals[2] = Signatures; + MDVals[3] = Resources; + MDVals[4] = Properties; + return MDNode::get(Ctx, MDVals); +} + +static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures, + MDNode *MDResources, + const uint64_t EntryShaderFlags, + const Triple::EnvironmentType ShaderProfile) { + MDTuple *Properties = + getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile); + return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties, + EP.Entry->getContext()); +} + +static void emitValidatorVersionMD(Module &M, const ModuleMetadataInfo &MMDI) { + if (MMDI.ValidatorVersion.empty()) + return; + + LLVMContext &Ctx = M.getContext(); + IRBuilder<> IRB(Ctx); + Metadata *MDVals[2]; + MDVals[0] = + ConstantAsMetadata::get(IRB.getInt32(MMDI.ValidatorVersion.getMajor())); + MDVals[1] = ConstantAsMetadata::get( + IRB.getInt32(MMDI.ValidatorVersion.getMinor().value_or(0))); + NamedMDNode *ValVerNode = M.getOrInsertNamedMetadata("dx.valver"); + // Set validator version obtained from DXIL Metadata Analysis pass + ValVerNode->clearOperands(); + ValVerNode->addOperand(MDNode::get(Ctx, MDVals)); +} + +static void emitShaderModelVersionMD(Module &M, + const ModuleMetadataInfo &MMDI) { + LLVMContext &Ctx = M.getContext(); + IRBuilder<> IRB(Ctx); + Metadata *SMVals[3]; + VersionTuple SM = MMDI.ShaderModelVersion; + SMVals[0] = MDString::get(Ctx, getShortShaderStage(MMDI.ShaderProfile)); + SMVals[1] = ConstantAsMetadata::get(IRB.getInt32(SM.getMajor())); + SMVals[2] = ConstantAsMetadata::get(IRB.getInt32(SM.getMinor().value_or(0))); + NamedMDNode *SMMDNode = M.getOrInsertNamedMetadata("dx.shaderModel"); + SMMDNode->addOperand(MDNode::get(Ctx, SMVals)); +} + +static void emitDXILVersionTupleMD(Module &M, const ModuleMetadataInfo &MMDI) { + LLVMContext &Ctx = M.getContext(); + IRBuilder<> IRB(Ctx); + VersionTuple DXILVer = MMDI.DXILVersion; + Metadata *DXILVals[2]; + DXILVals[0] = ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMajor())); + DXILVals[1] = + ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMinor().value_or(0))); + NamedMDNode *DXILVerMDNode = M.getOrInsertNamedMetadata("dx.version"); + DXILVerMDNode->addOperand(MDNode::get(Ctx, DXILVals)); +} + +static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD, + uint64_t ShaderFlags) { + LLVMContext &Ctx = M.getContext(); + MDTuple *Properties = nullptr; + if (ShaderFlags != 0) { + SmallVector<Metadata *> MDVals; + MDVals.append( + getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx)); + Properties = MDNode::get(Ctx, MDVals); + } + // Library has an entry metadata with resource table metadata and all other + // MDNodes as null. + return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx); +} + +// TODO: We might need to refactor this to be more generic, +// in case we need more metadata to be replaced. +static void translateBranchMetadata(Module &M) { + for (Function &F : M) { + for (BasicBlock &BB : F) { + Instruction *BBTerminatorInst = BB.getTerminator(); + + MDNode *HlslControlFlowMD = + BBTerminatorInst->getMetadata("hlsl.controlflow.hint"); + + if (!HlslControlFlowMD) + continue; + + assert(HlslControlFlowMD->getNumOperands() == 2 && + "invalid operands for hlsl.controlflow.hint"); - dxil::ValidatorVersionMD ValVerMD(M); - if (ValVerMD.isEmpty()) - ValVerMD.update(VersionTuple(1, 0)); - dxil::createShaderModelMD(M); - dxil::createDXILVersionMD(M); + MDBuilder MDHelper(M.getContext()); + ConstantInt *Op1 = + mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1)); - const dxil::Resources &Res = - getAnalysis<DXILResourceWrapper>().getDXILResource(); - Res.write(M); + SmallVector<llvm::Metadata *, 2> Vals( + ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"), + MDHelper.createConstant(Op1)}); - const uint64_t Flags = static_cast<uint64_t>( - getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags()); - dxil::createEntryMD(M, Flags); + MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals); + + BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode); + BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr); + } + } +} + +static void translateMetadata(Module &M, DXILBindingMap &DBM, + DXILResourceTypeMap &DRTM, + const Resources &MDResources, + const ModuleShaderFlags &ShaderFlags, + const ModuleMetadataInfo &MMDI) { + LLVMContext &Ctx = M.getContext(); + IRBuilder<> IRB(Ctx); + SmallVector<MDNode *> EntryFnMDNodes; + + emitValidatorVersionMD(M, MMDI); + emitShaderModelVersionMD(M, MMDI); + emitDXILVersionTupleMD(M, MMDI); + NamedMDNode *NamedResourceMD = + emitResourceMetadata(M, DBM, DRTM, MDResources); + auto *ResourceMD = + (NamedResourceMD != nullptr) ? NamedResourceMD->getOperand(0) : nullptr; + // FIXME: Add support to construct Signatures + // See https://github.com/llvm/llvm-project/issues/57928 + MDTuple *Signatures = nullptr; + + if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) { + // Get the combined shader flag mask of all functions in the library to be + // used as shader flags mask value associated with top-level library entry + // metadata. + uint64_t CombinedMask = ShaderFlags.getCombinedFlags(); + EntryFnMDNodes.emplace_back( + emitTopLevelLibraryNode(M, ResourceMD, CombinedMask)); + } else if (MMDI.EntryPropertyVec.size() > 1) { + M.getContext().diagnose(DiagnosticInfoTranslateMD( + M, "Non-library shader: One and only one entry expected")); + } + + for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) { + const ComputedShaderFlags &EntrySFMask = + ShaderFlags.getFunctionFlags(EntryProp.Entry); + + // If ShaderProfile is Library, mask is already consolidated in the + // top-level library node. Hence it is not emitted. + uint64_t EntryShaderFlags = 0; + if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) { + EntryShaderFlags = EntrySFMask; + if (EntryProp.ShaderStage != MMDI.ShaderProfile) { + M.getContext().diagnose(DiagnosticInfoTranslateMD( + M, + "Shader stage '" + + Twine(getShortShaderStage(EntryProp.ShaderStage) + + "' for entry '" + Twine(EntryProp.Entry->getName()) + + "' different from specified target profile '" + + Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) + + "'")))); + } + } + EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD, + EntryShaderFlags, + MMDI.ShaderProfile)); + } - return false; + NamedMDNode *EntryPointsNamedMD = + M.getOrInsertNamedMetadata("dx.entryPoints"); + for (auto *Entry : EntryFnMDNodes) + EntryPointsNamedMD->addOperand(Entry); } -char DXILTranslateMetadata::ID = 0; +PreservedAnalyses DXILTranslateMetadata::run(Module &M, + ModuleAnalysisManager &MAM) { + DXILBindingMap &DBM = MAM.getResult<DXILResourceBindingAnalysis>(M); + DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M); + const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M); + const ModuleShaderFlags &ShaderFlags = MAM.getResult<ShaderFlagsAnalysis>(M); + const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M); + + translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI); + translateBranchMetadata(M); + + return PreservedAnalyses::all(); +} + +namespace { +class DXILTranslateMetadataLegacy : public ModulePass { +public: + static char ID; // Pass identification, replacement for typeid + explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {} + + StringRef getPassName() const override { return "DXIL Translate Metadata"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<DXILResourceTypeWrapperPass>(); + AU.addRequired<DXILResourceBindingWrapperPass>(); + AU.addRequired<DXILResourceMDWrapper>(); + AU.addRequired<ShaderFlagsAnalysisWrapper>(); + AU.addRequired<DXILMetadataAnalysisWrapperPass>(); + AU.addPreserved<DXILResourceBindingWrapperPass>(); + AU.addPreserved<DXILResourceMDWrapper>(); + AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); + AU.addPreserved<ShaderFlagsAnalysisWrapper>(); + } + + bool runOnModule(Module &M) override { + DXILBindingMap &DBM = + getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap(); + DXILResourceTypeMap &DRTM = + getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); + const dxil::Resources &MDResources = + getAnalysis<DXILResourceMDWrapper>().getDXILResource(); + const ModuleShaderFlags &ShaderFlags = + getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags(); + dxil::ModuleMetadataInfo MMDI = + getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); + + translateMetadata(M, DBM, DRTM, MDResources, ShaderFlags, MMDI); + translateBranchMetadata(M); + return true; + } +}; + +} // namespace + +char DXILTranslateMetadataLegacy::ID = 0; -ModulePass *llvm::createDXILTranslateMetadataPass() { - return new DXILTranslateMetadata(); +ModulePass *llvm::createDXILTranslateMetadataLegacyPass() { + return new DXILTranslateMetadataLegacy(); } -INITIALIZE_PASS_BEGIN(DXILTranslateMetadata, "dxil-metadata-emit", - "DXIL Metadata Emit", false, false) -INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapper) +INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata", + "DXIL Translate Metadata", false, false) +INITIALIZE_PASS_DEPENDENCY(DXILResourceBindingWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper) INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper) -INITIALIZE_PASS_END(DXILTranslateMetadata, "dxil-metadata-emit", - "DXIL Metadata Emit", false, false) +INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) +INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata", + "DXIL Translate Metadata", false, false) |
