aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp164
1 files changed, 161 insertions, 3 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index 605bf949187f..6d60bd5e3c97 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -21,6 +21,7 @@
#include "SPIRVUtils.h"
#include "TargetInfo/SPIRVTargetInfo.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineConstantPool.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
@@ -58,9 +59,14 @@ public:
void outputModuleSection(SPIRV::ModuleSectionType MSType);
void outputEntryPoints();
void outputDebugSourceAndStrings(const Module &M);
+ void outputOpExtInstImports(const Module &M);
void outputOpMemoryModel();
void outputOpFunctionEnd();
void outputExtFuncDecls();
+ void outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
+ SPIRV::ExecutionMode EM);
+ void outputExecutionMode(const Module &M);
+ void outputAnnotations(const Module &M);
void outputModuleSections();
void emitInstruction(const MachineInstr *MI) override;
@@ -127,6 +133,8 @@ void SPIRVAsmPrinter::emitFunctionBodyEnd() {
}
void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
+ if (MAI->MBBsToSkip.contains(&MBB))
+ return;
MCInst LabelInst;
LabelInst.setOpcode(SPIRV::OpLabel);
LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));
@@ -237,6 +245,13 @@ void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) {
}
void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) {
+ // Output OpSourceExtensions.
+ for (auto &Str : MAI->SrcExt) {
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpSourceExtension);
+ addStringImm(Str.first(), Inst);
+ outputMCInst(Inst);
+ }
// Output OpSource.
MCInst Inst;
Inst.setOpcode(SPIRV::OpSource);
@@ -246,6 +261,19 @@ void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) {
outputMCInst(Inst);
}
+void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) {
+ for (auto &CU : MAI->ExtInstSetMap) {
+ unsigned Set = CU.first;
+ Register Reg = CU.second;
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExtInstImport);
+ Inst.addOperand(MCOperand::createReg(Reg));
+ addStringImm(getExtInstSetName(static_cast<SPIRV::InstructionSet>(Set)),
+ Inst);
+ outputMCInst(Inst);
+ }
+}
+
void SPIRVAsmPrinter::outputOpMemoryModel() {
MCInst Inst;
Inst.setOpcode(SPIRV::OpMemoryModel);
@@ -301,6 +329,135 @@ void SPIRVAsmPrinter::outputExtFuncDecls() {
}
}
+// Encode LLVM type by SPIR-V execution mode VecTypeHint.
+static unsigned encodeVecTypeHint(Type *Ty) {
+ if (Ty->isHalfTy())
+ return 4;
+ if (Ty->isFloatTy())
+ return 5;
+ if (Ty->isDoubleTy())
+ return 6;
+ if (IntegerType *IntTy = dyn_cast<IntegerType>(Ty)) {
+ switch (IntTy->getIntegerBitWidth()) {
+ case 8:
+ return 0;
+ case 16:
+ return 1;
+ case 32:
+ return 2;
+ case 64:
+ return 3;
+ default:
+ llvm_unreachable("invalid integer type");
+ }
+ }
+ if (FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Ty)) {
+ Type *EleTy = VecTy->getElementType();
+ unsigned Size = VecTy->getNumElements();
+ return Size << 16 | encodeVecTypeHint(EleTy);
+ }
+ llvm_unreachable("invalid type");
+}
+
+static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst,
+ SPIRV::ModuleAnalysisInfo *MAI) {
+ for (const MDOperand &MDOp : MDN->operands()) {
+ if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
+ Constant *C = CMeta->getValue();
+ if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
+ Inst.addOperand(MCOperand::createImm(Const->getZExtValue()));
+ } else if (auto *CE = dyn_cast<Function>(C)) {
+ Register FuncReg = MAI->getFuncReg(CE->getName().str());
+ assert(FuncReg.isValid());
+ Inst.addOperand(MCOperand::createReg(FuncReg));
+ }
+ }
+ }
+}
+
+void SPIRVAsmPrinter::outputExecutionModeFromMDNode(Register Reg, MDNode *Node,
+ SPIRV::ExecutionMode EM) {
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExecutionMode);
+ Inst.addOperand(MCOperand::createReg(Reg));
+ Inst.addOperand(MCOperand::createImm(static_cast<unsigned>(EM)));
+ addOpsFromMDNode(Node, Inst, MAI);
+ outputMCInst(Inst);
+}
+
+void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
+ NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
+ if (Node) {
+ for (unsigned i = 0; i < Node->getNumOperands(); i++) {
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExecutionMode);
+ addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
+ outputMCInst(Inst);
+ }
+ }
+ for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
+ const Function &F = *FI;
+ if (F.isDeclaration())
+ continue;
+ Register FReg = MAI->getFuncReg(F.getGlobalIdentifier());
+ assert(FReg.isValid());
+ if (MDNode *Node = F.getMetadata("reqd_work_group_size"))
+ outputExecutionModeFromMDNode(FReg, Node,
+ SPIRV::ExecutionMode::LocalSize);
+ if (MDNode *Node = F.getMetadata("work_group_size_hint"))
+ outputExecutionModeFromMDNode(FReg, Node,
+ SPIRV::ExecutionMode::LocalSizeHint);
+ if (MDNode *Node = F.getMetadata("intel_reqd_sub_group_size"))
+ outputExecutionModeFromMDNode(FReg, Node,
+ SPIRV::ExecutionMode::SubgroupSize);
+ if (MDNode *Node = F.getMetadata("vec_type_hint")) {
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExecutionMode);
+ Inst.addOperand(MCOperand::createReg(FReg));
+ unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::VecTypeHint);
+ Inst.addOperand(MCOperand::createImm(EM));
+ unsigned TypeCode = encodeVecTypeHint(getMDOperandAsType(Node, 0));
+ Inst.addOperand(MCOperand::createImm(TypeCode));
+ outputMCInst(Inst);
+ }
+ }
+}
+
+void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
+ outputModuleSection(SPIRV::MB_Annotations);
+ // Process llvm.global.annotations special global variable.
+ for (auto F = M.global_begin(), E = M.global_end(); F != E; ++F) {
+ if ((*F).getName() != "llvm.global.annotations")
+ continue;
+ const GlobalVariable *V = &(*F);
+ const ConstantArray *CA = cast<ConstantArray>(V->getOperand(0));
+ for (Value *Op : CA->operands()) {
+ ConstantStruct *CS = cast<ConstantStruct>(Op);
+ // The first field of the struct contains a pointer to
+ // the annotated variable.
+ Value *AnnotatedVar = CS->getOperand(0)->stripPointerCasts();
+ if (!isa<Function>(AnnotatedVar))
+ llvm_unreachable("Unsupported value in llvm.global.annotations");
+ Function *Func = cast<Function>(AnnotatedVar);
+ Register Reg = MAI->getFuncReg(Func->getGlobalIdentifier());
+
+ // The second field contains a pointer to a global annotation string.
+ GlobalVariable *GV =
+ cast<GlobalVariable>(CS->getOperand(1)->stripPointerCasts());
+
+ StringRef AnnotationString;
+ getConstantStringInfo(GV, AnnotationString);
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpDecorate);
+ Inst.addOperand(MCOperand::createReg(Reg));
+ unsigned Dec = static_cast<unsigned>(SPIRV::Decoration::UserSemantic);
+ Inst.addOperand(MCOperand::createImm(Dec));
+ addStringImm(AnnotationString, Inst);
+ outputMCInst(Inst);
+ }
+ }
+}
+
void SPIRVAsmPrinter::outputModuleSections() {
const Module *M = MMI->getModule();
// Get the global subtarget to output module-level info.
@@ -311,13 +468,14 @@ void SPIRVAsmPrinter::outputModuleSections() {
// Output instructions according to the Logical Layout of a Module:
// TODO: 1,2. All OpCapability instructions, then optional OpExtension
// instructions.
- // TODO: 3. Optional OpExtInstImport instructions.
+ // 3. Optional OpExtInstImport instructions.
+ outputOpExtInstImports(*M);
// 4. The single required OpMemoryModel instruction.
outputOpMemoryModel();
// 5. All entry point declarations, using OpEntryPoint.
outputEntryPoints();
// 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
- // TODO:
+ outputExecutionMode(*M);
// 7a. Debug: all OpString, OpSourceExtension, OpSource, and
// OpSourceContinued, without forward references.
outputDebugSourceAndStrings(*M);
@@ -326,7 +484,7 @@ void SPIRVAsmPrinter::outputModuleSections() {
// 7c. Debug: all OpModuleProcessed instructions.
outputModuleSection(SPIRV::MB_DebugModuleProcessed);
// 8. All annotation instructions (all decorations).
- outputModuleSection(SPIRV::MB_Annotations);
+ outputAnnotations(*M);
// 9. All type declarations (OpTypeXXX instructions), all constant
// instructions, and all global variable declarations. This section is
// the first section to allow use of: OpLine and OpNoLine debug information;