aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp425
1 files changed, 384 insertions, 41 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 5f890c003cbc..5c8fa7adfbdf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -24,6 +24,24 @@ using namespace llvm;
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
: PointerSize(PointerSize) {}
+SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
+ Register VReg,
+ MachineInstr &I,
+ const SPIRVInstrInfo &TII) {
+ SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
+ assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
+ return SpirvType;
+}
+
+SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
+ SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
+ const SPIRVInstrInfo &TII) {
+ SPIRVType *SpirvType =
+ getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
+ assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
+ return SpirvType;
+}
+
SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier AccessQual, bool EmitIR) {
@@ -96,6 +114,65 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
return MIB;
}
+std::tuple<Register, ConstantInt *, bool>
+SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
+ MachineIRBuilder *MIRBuilder,
+ MachineInstr *I,
+ const SPIRVInstrInfo *TII) {
+ const IntegerType *LLVMIntTy;
+ if (SpvType)
+ LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
+ else
+ LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
+ bool NewInstr = false;
+ // Find a constant in DT or build a new one.
+ ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
+ Register Res = DT.find(CI, CurMF);
+ if (!Res.isValid()) {
+ unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
+ LLT LLTy = LLT::scalar(32);
+ Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+ if (MIRBuilder)
+ assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
+ else
+ assignIntTypeToVReg(BitWidth, Res, *I, *TII);
+ DT.add(CI, CurMF, Res);
+ NewInstr = true;
+ }
+ return std::make_tuple(Res, CI, NewInstr);
+}
+
+Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
+ SPIRVType *SpvType,
+ const SPIRVInstrInfo &TII) {
+ assert(SpvType);
+ ConstantInt *CI;
+ Register Res;
+ bool New;
+ std::tie(Res, CI, New) =
+ getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
+ // If we have found Res register which is defined by the passed G_CONSTANT
+ // machine instruction, a new constant instruction should be created.
+ if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
+ return Res;
+ MachineInstrBuilder MIB;
+ MachineBasicBlock &BB = *I.getParent();
+ if (Val) {
+ MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
+ .addDef(Res)
+ .addUse(getSPIRVTypeID(SpvType));
+ addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
+ } else {
+ MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
+ .addDef(Res)
+ .addUse(getSPIRVTypeID(SpvType));
+ }
+ const auto &ST = CurMF->getSubtarget();
+ constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
+ *ST.getRegisterInfo(), *ST.getRegBankInfo());
+ return Res;
+}
+
Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType,
@@ -112,14 +189,32 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
Register Res = DT.find(ConstInt, &MF);
if (!Res.isValid()) {
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
- Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
- assignTypeToVReg(LLVMIntTy, Res, MIRBuilder);
- if (EmitIR)
+ LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
+ Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
+ assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, EmitIR);
+ DT.add(ConstInt, &MIRBuilder.getMF(), Res);
+ if (EmitIR) {
MIRBuilder.buildConstant(Res, *ConstInt);
- else
- MIRBuilder.buildInstr(SPIRV::OpConstantI)
- .addDef(Res)
- .addImm(ConstInt->getSExtValue());
+ } else {
+ MachineInstrBuilder MIB;
+ if (Val) {
+ assert(SpvType);
+ MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
+ .addDef(Res)
+ .addUse(getSPIRVTypeID(SpvType));
+ addNumImm(APInt(BitWidth, Val), MIB);
+ } else {
+ assert(SpvType);
+ MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
+ .addDef(Res)
+ .addUse(getSPIRVTypeID(SpvType));
+ }
+ const auto &Subtarget = CurMF->getSubtarget();
+ constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
+ *Subtarget.getRegisterInfo(),
+ *Subtarget.getRegBankInfo());
+ }
}
return Res;
}
@@ -142,11 +237,63 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
+ DT.add(ConstFP, &MF, Res);
MIRBuilder.buildFConstant(Res, *ConstFP);
}
return Res;
}
+Register
+SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
+ SPIRVType *SpvType,
+ const SPIRVInstrInfo &TII) {
+ const Type *LLVMTy = getTypeForSPIRVType(SpvType);
+ assert(LLVMTy->isVectorTy());
+ const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
+ Type *LLVMBaseTy = LLVMVecTy->getElementType();
+ // Find a constant vector in DT or build a new one.
+ const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
+ auto ConstVec =
+ ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
+ Register Res = DT.find(ConstVec, CurMF);
+ if (!Res.isValid()) {
+ unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
+ SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
+ // SpvScalConst should be created before SpvVecConst to avoid undefined ID
+ // error on validation.
+ // TODO: can moved below once sorting of types/consts/defs is implemented.
+ Register SpvScalConst;
+ if (Val)
+ SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
+ // TODO: maybe use bitwidth of base type.
+ LLT LLTy = LLT::scalar(32);
+ Register SpvVecConst =
+ CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+ const unsigned ElemCnt = SpvType->getOperand(2).getImm();
+ assignVectTypeToVReg(SpvBaseType, ElemCnt, SpvVecConst, I, TII);
+ DT.add(ConstVec, CurMF, SpvVecConst);
+ MachineInstrBuilder MIB;
+ MachineBasicBlock &BB = *I.getParent();
+ if (Val) {
+ MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
+ .addDef(SpvVecConst)
+ .addUse(getSPIRVTypeID(SpvType));
+ for (unsigned i = 0; i < ElemCnt; ++i)
+ MIB.addUse(SpvScalConst);
+ } else {
+ MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
+ .addDef(SpvVecConst)
+ .addUse(getSPIRVTypeID(SpvType));
+ }
+ const auto &Subtarget = CurMF->getSubtarget();
+ constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
+ *Subtarget.getRegisterInfo(),
+ *Subtarget.getRegBankInfo());
+ return SpvVecConst;
+ }
+ return Res;
+}
+
Register SPIRVGlobalRegistry::buildGlobalVariable(
Register ResVReg, SPIRVType *BaseType, StringRef Name,
const GlobalValue *GV, SPIRV::StorageClass Storage,
@@ -169,7 +316,13 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
}
GV = GVar;
}
- Register Reg;
+ Register Reg = DT.find(GVar, &MIRBuilder.getMF());
+ if (Reg.isValid()) {
+ if (Reg != ResVReg)
+ MIRBuilder.buildCopy(ResVReg, Reg);
+ return ResVReg;
+ }
+
auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
.addDef(ResVReg)
.addUse(getSPIRVTypeID(BaseType))
@@ -234,14 +387,76 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
return MIB;
}
+SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
+ MachineIRBuilder &MIRBuilder) {
+ assert(Ty->hasName());
+ const StringRef Name = Ty->hasName() ? Ty->getName() : "";
+ Register ResVReg = createTypeVReg(MIRBuilder);
+ auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
+ addStringImm(Name, MIB);
+ buildOpName(ResVReg, Name, MIRBuilder);
+ return MIB;
+}
+
+SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
+ MachineIRBuilder &MIRBuilder,
+ bool EmitIR) {
+ SmallVector<Register, 4> FieldTypes;
+ for (const auto &Elem : Ty->elements()) {
+ SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
+ assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
+ "Invalid struct element type");
+ FieldTypes.push_back(getSPIRVTypeID(ElemTy));
+ }
+ Register ResVReg = createTypeVReg(MIRBuilder);
+ auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
+ for (const auto &Ty : FieldTypes)
+ MIB.addUse(Ty);
+ if (Ty->hasName())
+ buildOpName(ResVReg, Ty->getName(), MIRBuilder);
+ if (Ty->isPacked())
+ buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
+ return MIB;
+}
+
+static bool isOpenCLBuiltinType(const StructType *SType) {
+ return SType->isOpaque() && SType->hasName() &&
+ SType->getName().startswith("opencl.");
+}
+
+static bool isSPIRVBuiltinType(const StructType *SType) {
+ return SType->isOpaque() && SType->hasName() &&
+ SType->getName().startswith("spirv.");
+}
+
+static bool isSpecialType(const Type *Ty) {
+ if (auto PType = dyn_cast<PointerType>(Ty)) {
+ if (!PType->isOpaque())
+ Ty = PType->getNonOpaquePointerElementType();
+ }
+ if (auto SType = dyn_cast<StructType>(Ty))
+ return isOpenCLBuiltinType(SType) || isSPIRVBuiltinType(SType);
+ return false;
+}
+
SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC,
SPIRVType *ElemType,
- MachineIRBuilder &MIRBuilder) {
- auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePointer)
- .addDef(createTypeVReg(MIRBuilder))
- .addImm(static_cast<uint32_t>(SC))
- .addUse(getSPIRVTypeID(ElemType));
- return MIB;
+ MachineIRBuilder &MIRBuilder,
+ Register Reg) {
+ if (!Reg.isValid())
+ Reg = createTypeVReg(MIRBuilder);
+ return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
+ .addDef(Reg)
+ .addImm(static_cast<uint32_t>(SC))
+ .addUse(getSPIRVTypeID(ElemType));
+}
+
+SPIRVType *
+SPIRVGlobalRegistry::getOpTypeForwardPointer(SPIRV::StorageClass SC,
+ MachineIRBuilder &MIRBuilder) {
+ return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
+ .addUse(createTypeVReg(MIRBuilder))
+ .addImm(static_cast<uint32_t>(SC));
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
@@ -255,10 +470,49 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
return MIB;
}
+SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
+ const Type *Ty, SPIRVType *RetType,
+ const SmallVectorImpl<SPIRVType *> &ArgTypes,
+ MachineIRBuilder &MIRBuilder) {
+ Register Reg = DT.find(Ty, &MIRBuilder.getMF());
+ if (Reg.isValid())
+ return getSPIRVTypeForVReg(Reg);
+ SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
+ return finishCreatingSPIRVType(Ty, SpirvType);
+}
+
+SPIRVType *SPIRVGlobalRegistry::findSPIRVType(const Type *Ty,
+ MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier AccQual,
+ bool EmitIR) {
+ Register Reg = DT.find(Ty, &MIRBuilder.getMF());
+ if (Reg.isValid())
+ return getSPIRVTypeForVReg(Reg);
+ if (ForwardPointerTypes.find(Ty) != ForwardPointerTypes.end())
+ return ForwardPointerTypes[Ty];
+ return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
+}
+
+Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
+ assert(SpirvType && "Attempting to get type id for nullptr type.");
+ if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
+ return SpirvType->uses().begin()->getReg();
+ return SpirvType->defs().begin()->getReg();
+}
+
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty,
MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier AccQual,
bool EmitIR) {
+ assert(!isSpecialType(Ty));
+ auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
+ auto t = TypeToSPIRVTypeMap.find(Ty);
+ if (t != TypeToSPIRVTypeMap.end()) {
+ auto tt = t->second.find(&MIRBuilder.getMF());
+ if (tt != t->second.end())
+ return getSPIRVTypeForVReg(tt->second);
+ }
+
if (auto IType = dyn_cast<IntegerType>(Ty)) {
const unsigned Width = IType->getBitWidth();
return Width == 1 ? getOpTypeBool(MIRBuilder)
@@ -269,21 +523,25 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty,
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
- auto El = getOrCreateSPIRVType(cast<FixedVectorType>(Ty)->getElementType(),
- MIRBuilder);
+ SPIRVType *El =
+ findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
MIRBuilder);
}
if (Ty->isArrayTy()) {
- auto *El = getOrCreateSPIRVType(Ty->getArrayElementType(), MIRBuilder);
+ SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
}
- assert(!isa<StructType>(Ty) && "Unsupported StructType");
+ if (auto SType = dyn_cast<StructType>(Ty)) {
+ if (SType->isOpaque())
+ return getOpTypeOpaque(SType, MIRBuilder);
+ return getOpTypeStruct(SType, MIRBuilder, EmitIR);
+ }
if (auto FType = dyn_cast<FunctionType>(Ty)) {
- SPIRVType *RetTy = getOrCreateSPIRVType(FType->getReturnType(), MIRBuilder);
+ SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
SmallVector<SPIRVType *, 4> ParamTypes;
for (const auto &t : FType->params()) {
- ParamTypes.push_back(getOrCreateSPIRVType(t, MIRBuilder));
+ ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
}
return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
}
@@ -292,24 +550,51 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty,
// At the moment, all opaque pointers correspond to i8 element type.
// TODO: change the implementation once opaque pointers are supported
// in the SPIR-V specification.
- if (PType->isOpaque()) {
+ if (PType->isOpaque())
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
- } else {
- Type *ElemType = PType->getNonOpaquePointerElementType();
- // TODO: support OpenCL and SPIRV builtins like image2d_t that are passed
- // as pointers, but should be treated as custom types like OpTypeImage.
- assert(!isa<StructType>(ElemType) && "Unsupported StructType pointer");
-
- // Otherwise, treat it as a regular pointer type.
- SpvElementType = getOrCreateSPIRVType(
- ElemType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR);
- }
+ else
+ SpvElementType =
+ findSPIRVType(PType->getNonOpaquePointerElementType(), MIRBuilder,
+ SPIRV::AccessQualifier::ReadWrite, EmitIR);
auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
- return getOpTypePointer(SC, SpvElementType, MIRBuilder);
+ // Null pointer means we have a loop in type definitions, make and
+ // return corresponding OpTypeForwardPointer.
+ if (SpvElementType == nullptr) {
+ if (ForwardPointerTypes.find(Ty) == ForwardPointerTypes.end())
+ ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder);
+ return ForwardPointerTypes[PType];
+ }
+ Register Reg(0);
+ // If we have forward pointer associated with this type, use its register
+ // operand to create OpTypePointer.
+ if (ForwardPointerTypes.find(PType) != ForwardPointerTypes.end())
+ Reg = getSPIRVTypeID(ForwardPointerTypes[PType]);
+
+ return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
}
llvm_unreachable("Unable to convert LLVM type to SPIRVType");
}
+SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
+ const Type *Ty, MachineIRBuilder &MIRBuilder,
+ SPIRV::AccessQualifier AccessQual, bool EmitIR) {
+ if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
+ return nullptr;
+ TypesInProcessing.insert(Ty);
+ SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
+ TypesInProcessing.erase(Ty);
+ VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
+ SPIRVToLLVMType[SpirvType] = Ty;
+ Register Reg = DT.find(Ty, &MIRBuilder.getMF());
+ // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
+ // will be added later. For special types it is already added to DT.
+ if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
+ !isSpecialType(Ty))
+ DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
+
+ return SpirvType;
+}
+
SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
auto t = VRegToTypeMap.find(CurMF);
if (t != VRegToTypeMap.end()) {
@@ -321,13 +606,26 @@ SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
- const Type *Type, MachineIRBuilder &MIRBuilder,
+ const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier AccessQual, bool EmitIR) {
- Register Reg = DT.find(Type, &MIRBuilder.getMF());
+ Register Reg = DT.find(Ty, &MIRBuilder.getMF());
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
- SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
- return restOfCreateSPIRVType(Type, SpirvType);
+ TypesInProcessing.clear();
+ SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
+ // Create normal pointer types for the corresponding OpTypeForwardPointers.
+ for (auto &CU : ForwardPointerTypes) {
+ const Type *Ty2 = CU.first;
+ SPIRVType *STy2 = CU.second;
+ if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
+ STy2 = getSPIRVTypeForVReg(Reg);
+ else
+ STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
+ if (Ty == Ty2)
+ STy = STy2;
+ }
+ ForwardPointerTypes.clear();
+ return STy;
}
bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
@@ -393,8 +691,8 @@ SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
MIRBuilder);
}
-SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(const Type *LLVMTy,
- SPIRVType *SpirvType) {
+SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
+ SPIRVType *SpirvType) {
assert(CurMF == SpirvType->getMF());
VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
SPIRVToLLVMType[SpirvType] = LLVMTy;
@@ -413,7 +711,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth)
.addImm(0);
- return restOfCreateSPIRVType(LLVMTy, MIB);
+ return finishCreatingSPIRVType(LLVMTy, MIB);
}
SPIRVType *
@@ -423,6 +721,19 @@ SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
MIRBuilder);
}
+SPIRVType *
+SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
+ const SPIRVInstrInfo &TII) {
+ Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
+ Register Reg = DT.find(LLVMTy, CurMF);
+ if (Reg.isValid())
+ return getSPIRVTypeForVReg(Reg);
+ MachineBasicBlock &BB = *I.getParent();
+ auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
+ .addDef(createTypeVReg(CurMF->getRegInfo()));
+ return finishCreatingSPIRVType(LLVMTy, MIB);
+}
+
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
return getOrCreateSPIRVType(
@@ -436,12 +747,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
const SPIRVInstrInfo &TII) {
Type *LLVMTy = FixedVectorType::get(
const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
+ Register Reg = DT.find(LLVMTy, CurMF);
+ if (Reg.isValid())
+ return getSPIRVTypeForVReg(Reg);
MachineBasicBlock &BB = *I.getParent();
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addUse(getSPIRVTypeID(BaseType))
.addImm(NumElements);
- return restOfCreateSPIRVType(LLVMTy, MIB);
+ return finishCreatingSPIRVType(LLVMTy, MIB);
}
SPIRVType *
@@ -460,10 +774,39 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
Type *LLVMTy =
PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
storageClassToAddressSpace(SC));
+ Register Reg = DT.find(LLVMTy, CurMF);
+ if (Reg.isValid())
+ return getSPIRVTypeForVReg(Reg);
MachineBasicBlock &BB = *I.getParent();
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(static_cast<uint32_t>(SC))
.addUse(getSPIRVTypeID(BaseType));
- return restOfCreateSPIRVType(LLVMTy, MIB);
+ return finishCreatingSPIRVType(LLVMTy, MIB);
+}
+
+Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
+ SPIRVType *SpvType,
+ const SPIRVInstrInfo &TII) {
+ assert(SpvType);
+ const Type *LLVMTy = getTypeForSPIRVType(SpvType);
+ assert(LLVMTy);
+ // Find a constant in DT or build a new one.
+ UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
+ Register Res = DT.find(UV, CurMF);
+ if (Res.isValid())
+ return Res;
+ LLT LLTy = LLT::scalar(32);
+ Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
+ assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
+ DT.add(UV, CurMF, Res);
+
+ MachineInstrBuilder MIB;
+ MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
+ .addDef(Res)
+ .addUse(getSPIRVTypeID(SpvType));
+ const auto &ST = CurMF->getSubtarget();
+ constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
+ *ST.getRegisterInfo(), *ST.getRegBankInfo());
+ return Res;
}