Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp')
-rw-r--r--   llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   127
1 file changed, 111 insertions, 16 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index ae4e03974428..d5b81bf46c80 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -91,15 +91,11 @@ SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
     return std::make_pair(0u, RC);
 
   if (VT.isFloatingPoint())
-    RC = VT.isVector() ? &SPIRV::vfIDRegClass
-                       : (VT.getScalarSizeInBits() > 32 ? &SPIRV::fID64RegClass
-                                                        : &SPIRV::fIDRegClass);
+    RC = VT.isVector() ? &SPIRV::vfIDRegClass : &SPIRV::fIDRegClass;
   else if (VT.isInteger())
-    RC = VT.isVector() ? &SPIRV::vIDRegClass
-                       : (VT.getScalarSizeInBits() > 32 ? &SPIRV::ID64RegClass
-                                                        : &SPIRV::IDRegClass);
+    RC = VT.isVector() ? &SPIRV::vIDRegClass : &SPIRV::iIDRegClass;
   else
-    RC = &SPIRV::IDRegClass;
+    RC = &SPIRV::iIDRegClass;
 
   return std::make_pair(0u, RC);
 }
@@ -115,8 +111,8 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                             SPIRVGlobalRegistry &GR, MachineInstr &I,
                             Register OpReg, unsigned OpIdx,
                             SPIRVType *NewPtrType) {
-  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
   MachineIRBuilder MIB(I);
+  Register NewReg = createVirtualRegister(NewPtrType, &GR, MRI, MIB.getMF());
   bool Res = MIB.buildInstr(SPIRV::OpBitcast)
                  .addDef(NewReg)
                  .addUse(GR.getSPIRVTypeID(NewPtrType))
@@ -125,8 +121,6 @@ static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
                      *STI.getRegBankInfo());
   if (!Res)
     report_fatal_error("insert validation bitcast: cannot constrain all uses");
-  MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
-  GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
   I.getOperand(OpIdx).setReg(NewReg);
 }
 
@@ -203,10 +197,34 @@ static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
 }
 
-static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
-                                      MachineRegisterInfo *MRI,
-                                      SPIRVGlobalRegistry &GR, MachineInstr &I,
-                                      unsigned OpIdx) {
+static void validateLifetimeStart(const SPIRVSubtarget &STI,
+                                  MachineRegisterInfo *MRI,
+                                  SPIRVGlobalRegistry &GR, MachineInstr &I) {
+  Register PtrReg = I.getOperand(0).getReg();
+  MachineFunction *MF = I.getParent()->getParent();
+  Register PtrTypeReg = getTypeReg(MRI, PtrReg);
+  SPIRVType *PtrType = GR.getSPIRVTypeForVReg(PtrTypeReg, MF);
+  SPIRVType *PonteeElemType = PtrType ? GR.getPointeeType(PtrType) : nullptr;
+  if (!PonteeElemType || PonteeElemType->getOpcode() == SPIRV::OpTypeVoid ||
+      (PonteeElemType->getOpcode() == SPIRV::OpTypeInt &&
+       PonteeElemType->getOperand(1).getImm() == 8))
+    return;
+  // To keep the code valid a bitcast must be inserted
+  SPIRV::StorageClass::StorageClass SC =
+      static_cast<SPIRV::StorageClass::StorageClass>(
+          PtrType->getOperand(1).getImm());
+  MachineIRBuilder MIB(I);
+  LLVMContext &Context = MF->getFunction().getContext();
+  SPIRVType *ElemType =
+      GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB);
+  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
+  doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
+}
+
+static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
+                                         MachineRegisterInfo *MRI,
+                                         SPIRVGlobalRegistry &GR,
+                                         MachineInstr &I, unsigned OpIdx) {
   MachineFunction *MF = I.getParent()->getParent();
   Register OpReg = I.getOperand(OpIdx).getReg();
   Register OpTypeReg = getTypeReg(MRI, OpReg);
@@ -333,6 +351,7 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
   GR.setCurrentFunc(MF);
   for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) {
     MachineBasicBlock *MBB = &*I;
+    SmallPtrSet<MachineInstr *, 8> ToMove;
     for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end();
          MBBI != MBBE;) {
       MachineInstr &MI = *MBBI++;
@@ -375,6 +394,7 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
       case SPIRV::OpGenericCastToPtr:
         validateAccessChain(STI, MRI, GR, MI);
         break;
+      case SPIRV::OpPtrAccessChain:
       case SPIRV::OpInBoundsPtrAccessChain:
         if (MI.getNumOperands() == 4)
           validateAccessChain(STI, MRI, GR, MI);
@@ -393,6 +413,17 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
         validateForwardCalls(STI, MRI, GR, MI);
         break;
 
+      // ensure that LLVM IR add/sub instructions result in logical SPIR-V
+      // instructions when applied to bool type
+      case SPIRV::OpIAddS:
+      case SPIRV::OpIAddV:
+      case SPIRV::OpISubS:
+      case SPIRV::OpISubV:
+        if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
+                                      SPIRV::OpTypeBool))
+          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
+        break;
+
       // ensure that LLVM IR bitwise instructions result in logical SPIR-V
       // instructions when applied to bool type
       case SPIRV::OpBitwiseOrS:
@@ -413,9 +444,14 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
                                       SPIRV::OpTypeBool))
          MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
         break;
+      case SPIRV::OpLifetimeStart:
+      case SPIRV::OpLifetimeStop:
+        if (MI.getOperand(1).getImm() > 0)
+          validateLifetimeStart(STI, MRI, GR, MI);
+        break;
       case SPIRV::OpGroupAsyncCopy:
-        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
-        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
+        validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
+        validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
         break;
       case SPIRV::OpGroupWaitEvents:
         // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
@@ -431,8 +467,67 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
             MI.removeOperand(i);
         }
       } break;
+      case SPIRV::OpPhi: {
+        // Phi refers to a type definition that goes after the Phi
+        // instruction, so that the virtual register definition of the type
+        // doesn't dominate all uses. Let's place the type definition
+        // instruction at the end of the predecessor.
+        MachineBasicBlock *Curr = MI.getParent();
+        SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
+        if (Type->getParent() == Curr && !Curr->pred_empty())
+          ToMove.insert(const_cast<MachineInstr *>(Type));
+      } break;
+      case SPIRV::OpExtInst: {
+        // prefetch
+        if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
+            MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
+          continue;
+        switch (MI.getOperand(3).getImm()) {
+        case SPIRV::OpenCLExtInst::frexp:
+        case SPIRV::OpenCLExtInst::lgamma_r:
+        case SPIRV::OpenCLExtInst::remquo: {
+          // The last operand must be of a pointer to i32 or vector of i32
+          // values.
+          MachineIRBuilder MIB(MI);
+          SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
+          SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
+          assert(RetType && "Expected return type");
+          validatePtrTypes(
+              STI, MRI, GR, MI, MI.getNumOperands() - 1,
+              RetType->getOpcode() != SPIRV::OpTypeVector
+                  ? Int32Type
+                  : GR.getOrCreateSPIRVVectorType(
+                        Int32Type, RetType->getOperand(2).getImm(), MIB));
+        } break;
+        case SPIRV::OpenCLExtInst::fract:
+        case SPIRV::OpenCLExtInst::modf:
+        case SPIRV::OpenCLExtInst::sincos:
+          // The last operand must be of a pointer to the base type represented
+          // by the previous operand.
+          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
+                 "Expected v-reg");
+          validatePtrTypes(
+              STI, MRI, GR, MI, MI.getNumOperands() - 1,
+              GR.getSPIRVTypeForVReg(
+                  MI.getOperand(MI.getNumOperands() - 2).getReg()));
+          break;
+        case SPIRV::OpenCLExtInst::prefetch:
+          // Expected `ptr` type is a pointer to float, integer or vector, but
+          // the pontee value can be wrapped into a struct.
+          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
+                 "Expected v-reg");
+          validatePtrUnwrapStructField(STI, MRI, GR, MI,
+                                       MI.getNumOperands() - 2);
+          break;
+        }
+      } break;
       }
     }
+    for (MachineInstr *MI : ToMove) {
+      MachineBasicBlock *Curr = MI->getParent();
+      MachineBasicBlock *Pred = *Curr->pred_begin();
+      Pred->insert(Pred->getFirstTerminator(), Curr->remove_instr(MI));
+    }
   }
   ProcessedMF.insert(&MF);
   TargetLowering::finalizeLowering(MF);
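Side note (not part of the commit): the OpPhi handling above collects the offending type definitions into the per-block ToMove set and only splices them after the instruction scan finishes, since detaching instructions mid-iteration would invalidate the block iterator. A minimal sketch of that splice pattern, using only the MachineBasicBlock calls that appear in the patch; the helper name hoistIntoPredecessor is illustrative, not something this change adds:

// Sketch only: assumes LLVM's CodeGen headers and an already-populated set.
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineInstr.h"
using namespace llvm;

static void hoistIntoPredecessor(const SmallPtrSetImpl<MachineInstr *> &ToMove) {
  for (MachineInstr *MI : ToMove) {
    MachineBasicBlock *Curr = MI->getParent();
    if (Curr->pred_empty())
      continue; // no predecessor to hoist into; leave the definition in place
    MachineBasicBlock *Pred = *Curr->pred_begin();
    // remove_instr() detaches MI from its block without deleting it;
    // insert() re-attaches it in front of Pred's first terminator, so the
    // type definition now dominates the OpPhi that refers to it.
    Pred->insert(Pred->getFirstTerminator(), Curr->remove_instr(MI));
  }
}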
