diff options
Diffstat (limited to 'llvm/lib/Target/DirectX/DXILShaderFlags.cpp')
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 186 |
1 files changed, 167 insertions, 19 deletions
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index 9fa137b4c025..6a15bac153d8 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -13,36 +13,135 @@ #include "DXILShaderFlags.h" #include "DirectX.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/DXILResource.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsDirectX.h" #include "llvm/IR/Module.h" +#include "llvm/InitializePasses.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" using namespace llvm; using namespace llvm::dxil; -static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) { - Type *Ty = I.getType(); - if (Ty->isDoubleTy()) { - Flags.Doubles = true; +/// Update the shader flags mask based on the given instruction. +/// \param CSF Shader flags mask to update. +/// \param I Instruction to check. +void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF, + const Instruction &I, + DXILResourceTypeMap &DRTM) { + if (!CSF.Doubles) + CSF.Doubles = I.getType()->isDoubleTy(); + + if (!CSF.Doubles) { + for (const Value *Op : I.operands()) { + if (Op->getType()->isDoubleTy()) { + CSF.Doubles = true; + break; + } + } + } + + if (CSF.Doubles) { switch (I.getOpcode()) { case Instruction::FDiv: case Instruction::UIToFP: case Instruction::SIToFP: case Instruction::FPToUI: case Instruction::FPToSI: - Flags.DX11_1_DoubleExtensions = true; + CSF.DX11_1_DoubleExtensions = true; break; } } + + if (auto *II = dyn_cast<IntrinsicInst>(&I)) { + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::dx_resource_handlefrombinding: + switch (DRTM[cast<TargetExtType>(II->getType())].getResourceKind()) { + case dxil::ResourceKind::StructuredBuffer: + case dxil::ResourceKind::RawBuffer: + CSF.EnableRawAndStructuredBuffers = true; + break; + default: + break; + } + break; + case Intrinsic::dx_resource_load_typedbuffer: { + dxil::ResourceTypeInfo &RTI = + DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())]; + if (RTI.isTyped()) + CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1; + break; + } + } + } + // Handle call instructions + if (auto *CI = dyn_cast<CallInst>(&I)) { + const Function *CF = CI->getCalledFunction(); + // Merge-in shader flags mask of the called function in the current module + if (FunctionFlags.contains(CF)) + CSF.merge(FunctionFlags[CF]); + + // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic + // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554 + } } -ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) { - ComputedShaderFlags Flags; - for (const auto &F : M) - for (const auto &BB : F) - for (const auto &I : BB) - updateFlags(Flags, I); - return Flags; +/// Construct ModuleShaderFlags for module Module M +void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM) { + CallGraph CG(M); + + // Compute Shader Flags Mask for all functions using post-order visit of SCC + // of the call graph. + for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd(); + ++SCCI) { + const std::vector<CallGraphNode *> &CurSCC = *SCCI; + + // Union of shader masks of all functions in CurSCC + ComputedShaderFlags SCCSF; + // List of functions in CurSCC that are neither external nor declarations + // and hence whose flags are collected + SmallVector<Function *> CurSCCFuncs; + for (CallGraphNode *CGN : CurSCC) { + Function *F = CGN->getFunction(); + if (!F) + continue; + + if (F->isDeclaration()) { + assert(!F->getName().starts_with("dx.op.") && + "DXIL Shader Flag analysis should not be run post-lowering."); + continue; + } + + ComputedShaderFlags CSF; + for (const auto &BB : *F) + for (const auto &I : BB) + updateFunctionFlags(CSF, I, DRTM); + // Update combined shader flags mask for all functions in this SCC + SCCSF.merge(CSF); + + CurSCCFuncs.push_back(F); + } + + // Update combined shader flags mask for all functions of the module + CombinedSFMask.merge(SCCSF); + + // Shader flags mask of each of the functions in an SCC of the call graph is + // the union of all functions in the SCC. Update shader flags masks of + // functions in CurSCC accordingly. This is trivially true if SCC contains + // one function. + for (Function *F : CurSCCFuncs) + // Merge SCCSF with that of F + FunctionFlags[F].merge(SCCSF); + } } void ComputedShaderFlags::print(raw_ostream &OS) const { @@ -63,21 +162,70 @@ void ComputedShaderFlags::print(raw_ostream &OS) const { OS << ";\n"; } +/// Return the shader flags mask of the specified function Func. +const ComputedShaderFlags & +ModuleShaderFlags::getFunctionFlags(const Function *Func) const { + auto Iter = FunctionFlags.find(Func); + assert((Iter != FunctionFlags.end() && Iter->first == Func) && + "Get Shader Flags : No Shader Flags Mask exists for function"); + return Iter->second; +} + +//===----------------------------------------------------------------------===// +// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass + +// Provide an explicit template instantiation for the static ID. AnalysisKey ShaderFlagsAnalysis::Key; -ComputedShaderFlags ShaderFlagsAnalysis::run(Module &M, - ModuleAnalysisManager &AM) { - return ComputedShaderFlags::computeFlags(M); +ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M, + ModuleAnalysisManager &AM) { + DXILResourceTypeMap &DRTM = AM.getResult<DXILResourceTypeAnalysis>(M); + + ModuleShaderFlags MSFI; + MSFI.initialize(M, DRTM); + + return MSFI; } PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M, ModuleAnalysisManager &AM) { - ComputedShaderFlags Flags = AM.getResult<ShaderFlagsAnalysis>(M); - Flags.print(OS); + const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M); + // Print description of combined shader flags for all module functions + OS << "; Combined Shader Flags for Module\n"; + FlagsInfo.getCombinedFlags().print(OS); + // Print shader flags mask for each of the module functions + OS << "; Shader Flags for Module Functions\n"; + for (const auto &F : M.getFunctionList()) { + if (F.isDeclaration()) + continue; + const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F); + OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(), + (uint64_t)(SFMask)); + } + return PreservedAnalyses::all(); } +//===----------------------------------------------------------------------===// +// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass + +bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) { + DXILResourceTypeMap &DRTM = + getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); + + MSFI.initialize(M, DRTM); + return false; +} + +void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive<DXILResourceTypeWrapperPass>(); +} + char ShaderFlagsAnalysisWrapper::ID = 0; -INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", - "DXIL Shader Flag Analysis", true, true) +INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", + "DXIL Shader Flag Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) +INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", + "DXIL Shader Flag Analysis", true, true) |
