aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp127
1 files changed, 111 insertions, 16 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index ae4e03974428..d5b81bf46c80 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -91,15 +91,11 @@ SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
return std::make_pair(0u, RC);
if (VT.isFloatingPoint())
- RC = VT.isVector() ? &SPIRV::vfIDRegClass
- : (VT.getScalarSizeInBits() > 32 ? &SPIRV::fID64RegClass
- : &SPIRV::fIDRegClass);
+ RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
else if (VT.isInteger())
- RC = VT.isVector() ? &SPIRV::vIDRegClass
- : (VT.getScalarSizeInBits() > 32 ? &SPIRV::ID64RegClass
- : &SPIRV::IDRegClass);
+ RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
else
- RC = &SPIRV::IDRegClass;
+ RC = &SPIRV::iIDRegClass;
return std::make_pair(0u, RC);
}
@@ -115,8 +111,8 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
SPIRVGlobalRegistry &GR, MachineInstr &I,
Register OpReg, unsigned OpIdx,
SPIRVType *NewPtrType) {
- Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
MachineIRBuilder MIB(I);
+ Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
bool Res = MIB.buildInstr(SPIRV::OpBitcast)
.addDef(NewReg)
.addUse(GR.getSPIRVTypeID(NewPtrType))
@@ -125,8 +121,6 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
*STI.getRegBankInfo());
if (!Res)
report_fatal_error("insert validation bitcast: cannot constrain all uses");
- MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
- GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
I.getOperand(OpIdx).setReg(NewReg);
}
@@ -203,10 +197,34 @@ static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}
-static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
- MachineRegisterInfo *MRI,
- SPIRVGlobalRegistry &GR, MachineInstr &I,
- unsigned OpIdx) {
+static void validateLifetimeStart(const SPIRVSubtarget &STI,
+ MachineRegisterInfo *MRI,
+ SPIRVGlobalRegistry &GR, MachineInstr &I) {
+ Register PtrReg = I.getOperand(0).getReg();
+ MachineFunction *MF = I.getParent()->getParent();
+ Register PtrTypeReg = getTypeReg(MRI, PtrReg);
+ SPIRVType *PtrType = GR.getSPIRVTypeForVReg(PtrTypeReg, MF);
+ SPIRVType *PonteeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
+ if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid ||
+ (PonteeElemType->getOpcode() == SPIRV::OpTypeInt &&
+ PonteeElemType->getOperand(1).getImm() == 8))
+ return;
+ // To keep the code valid a bitcast must be inserted
+ SPIRV::StorageClass::StorageClass SC =
+ static_cast<SPIRV::StorageClass::StorageClass>(
+ PtrType->getOperand(1).getImm());
+ MachineIRBuilder MIB(I);
+ LLVMContext &Context = MF->getFunction().getContext();
+ SPIRVType *ElemType =
+ GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB);
+ SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
+ doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
+}
+
+static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
+ MachineRegisterInfo *MRI,
+ SPIRVGlobalRegistry &GR,
+ MachineInstr &I, unsigned OpIdx) {
MachineFunction *MF = I.getParent()->getParent();
Register OpReg = I.getOperand(OpIdx).getReg();
Register OpTypeReg = getTypeReg(MRI, OpReg);
@@ -333,6 +351,7 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
GR.setCurrentFunc(MF);
for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
MachineBasicBlock *MBB = &*I;
+ SmallPtrSet<MachineInstr *, 8> ToMove;
for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
MBBI != MBBE;) {
MachineInstr &MI = *MBBI++;
@@ -375,6 +394,7 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
case SPIRV::OpGenericCastToPtr:
validateAccessChain(STI, MRI, GR, MI);
break;
+ case SPIRV::OpPtrAccessChain:
case SPIRV::OpInBoundsPtrAccessChain:
if (MI.getNumOperands() == 4)
validateAccessChain(STI, MRI, GR, MI);
@@ -393,6 +413,17 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
validateForwardCalls(STI, MRI, GR, MI);
break;
+ // ensure that LLVM IR add/sub instructions result in logical SPIR-V
+ // instructions when applied to bool type
+ case SPIRV::OpIAddS:
+ case SPIRV::OpIAddV:
+ case SPIRV::OpISubS:
+ case SPIRV::OpISubV:
+ if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
+ SPIRV::OpTypeBool))
+ MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
+ break;
+
// ensure that LLVM IR bitwise instructions result in logical SPIR-V
// instructions when applied to bool type
case SPIRV::OpBitwiseOrS:
@@ -413,9 +444,14 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
SPIRV::OpTypeBool))
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
break;
+ case SPIRV::OpLifetimeStart:
+ case SPIRV::OpLifetimeStop:
+ if (MI.getOperand(1).getImm() > 0)
+ validateLifetimeStart(STI, MRI, GR, MI);
+ break;
case SPIRV::OpGroupAsyncCopy:
- validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
- validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
+ validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
+ validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
break;
case SPIRV::OpGroupWaitEvents:
// OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
@@ -431,8 +467,67 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
MI.removeOperand(i);
}
} break;
+ case SPIRV::OpPhi: {
+ // Phi refers to a type definition that goes after the Phi
+ // instruction, so that the virtual register definition of the type
+ // doesn't dominate all uses. Let's place the type definition
+ // instruction at the end of the predecessor.
+ MachineBasicBlock *Curr = MI.getParent();
+ SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
+ if (Type->getParent() == Curr && !Curr->pred_empty())
+ ToMove.insert(const_cast<MachineInstr *>(Type));
+ } break;
+ case SPIRV::OpExtInst: {
+ // prefetch
+ if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
+ MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
+ continue;
+ switch (MI.getOperand(3).getImm()) {
+ case SPIRV::OpenCLExtInst::frexp:
+ case SPIRV::OpenCLExtInst::lgamma_r:
+ case SPIRV::OpenCLExtInst::remquo: {
+ // The last operand must be of a pointer to i32 or vector of i32
+ // values.
+ MachineIRBuilder MIB(MI);
+ SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
+ SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
+ assert(RetType && "Expected return type");
+ validatePtrTypes(
+ STI, MRI, GR, MI, MI.getNumOperands() - 1,
+ RetType->getOpcode() != SPIRV::OpTypeVector
+ ? Int32Type
+ : GR.getOrCreateSPIRVVectorType(
+ Int32Type, RetType->getOperand(2).getImm(), MIB));
+ } break;
+ case SPIRV::OpenCLExtInst::fract:
+ case SPIRV::OpenCLExtInst::modf:
+ case SPIRV::OpenCLExtInst::sincos:
+ // The last operand must be of a pointer to the base type represented
+ // by the previous operand.
+ assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
+ "Expected v-reg");
+ validatePtrTypes(
+ STI, MRI, GR, MI, MI.getNumOperands() - 1,
+ GR.getSPIRVTypeForVReg(
+ MI.getOperand(MI.getNumOperands() - 2).getReg()));
+ break;
+ case SPIRV::OpenCLExtInst::prefetch:
+ // Expected `ptr` type is a pointer to float, integer or vector, but
+ // the pontee value can be wrapped into a struct.
+ assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
+ "Expected v-reg");
+ validatePtrUnwrapStructField(STI, MRI, GR, MI,
+ MI.getNumOperands() - 2);
+ break;
+ }
+ } break;
}
}
+ for (MachineInstr *MI : ToMove) {
+ MachineBasicBlock *Curr = MI->getParent();
+ MachineBasicBlock *Pred = *Curr->pred_begin();
+ Pred->insert(Pred->getFirstTerminator(), Curr->remove_instr(MI));
+ }
}
ProcessedMF.insert(&MF);
TargetLowering::finalizeLowering(MF);