diff options
Diffstat (limited to 'contrib/llvm-project/clang/lib/CodeGen/CGGPUBuiltin.cpp')
-rw-r--r-- | contrib/llvm-project/clang/lib/CodeGen/CGGPUBuiltin.cpp | 142 |
1 files changed, 102 insertions, 40 deletions
diff --git a/contrib/llvm-project/clang/lib/CodeGen/CGGPUBuiltin.cpp b/contrib/llvm-project/clang/lib/CodeGen/CGGPUBuiltin.cpp index f860623e2bc3..75fb06de9384 100644 --- a/contrib/llvm-project/clang/lib/CodeGen/CGGPUBuiltin.cpp +++ b/contrib/llvm-project/clang/lib/CodeGen/CGGPUBuiltin.cpp @@ -21,13 +21,14 @@ using namespace clang; using namespace CodeGen; -static llvm::Function *GetVprintfDeclaration(llvm::Module &M) { +namespace { +llvm::Function *GetVprintfDeclaration(llvm::Module &M) { llvm::Type *ArgTypes[] = {llvm::Type::getInt8PtrTy(M.getContext()), llvm::Type::getInt8PtrTy(M.getContext())}; llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get( llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false); - if (auto* F = M.getFunction("vprintf")) { + if (auto *F = M.getFunction("vprintf")) { // Our CUDA system header declares vprintf with the right signature, so // nobody else should have been able to declare vprintf with a bogus // signature. @@ -41,6 +42,28 @@ static llvm::Function *GetVprintfDeclaration(llvm::Module &M) { VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, "vprintf", &M); } +llvm::Function *GetOpenMPVprintfDeclaration(CodeGenModule &CGM) { + const char *Name = "__llvm_omp_vprintf"; + llvm::Module &M = CGM.getModule(); + llvm::Type *ArgTypes[] = {llvm::Type::getInt8PtrTy(M.getContext()), + llvm::Type::getInt8PtrTy(M.getContext()), + llvm::Type::getInt32Ty(M.getContext())}; + llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get( + llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false); + + if (auto *F = M.getFunction(Name)) { + if (F->getFunctionType() != VprintfFuncType) { + CGM.Error(SourceLocation(), + "Invalid type declaration for __llvm_omp_vprintf"); + return nullptr; + } + return F; + } + + return llvm::Function::Create( + VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, Name, &M); +} + // Transforms a call to printf into a call to the NVPTX vprintf syscall (which // isn't particularly special; it's invoked just like a regular function). // vprintf takes two args: A format string, and a pointer to a buffer containing @@ -66,39 +89,22 @@ static llvm::Function *GetVprintfDeclaration(llvm::Module &M) { // // Note that by the time this function runs, E's args have already undergone the // standard C vararg promotion (short -> int, float -> double, etc.). -RValue -CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { - assert(getTarget().getTriple().isNVPTX()); - assert(E->getBuiltinCallee() == Builtin::BIprintf); - assert(E->getNumArgs() >= 1); // printf always has at least one arg. - const llvm::DataLayout &DL = CGM.getDataLayout(); - llvm::LLVMContext &Ctx = CGM.getLLVMContext(); - - CallArgList Args; - EmitCallArgs(Args, - E->getDirectCallee()->getType()->getAs<FunctionProtoType>(), - E->arguments(), E->getDirectCallee(), - /* ParamsToSkip = */ 0); - - // We don't know how to emit non-scalar varargs. - if (std::any_of(Args.begin() + 1, Args.end(), [&](const CallArg &A) { - return !A.getRValue(*this).isScalar(); - })) { - CGM.ErrorUnsupported(E, "non-scalar arg to printf"); - return RValue::get(llvm::ConstantInt::get(IntTy, 0)); - } +std::pair<llvm::Value *, llvm::TypeSize> +packArgsIntoNVPTXFormatBuffer(CodeGenFunction *CGF, const CallArgList &Args) { + const llvm::DataLayout &DL = CGF->CGM.getDataLayout(); + llvm::LLVMContext &Ctx = CGF->CGM.getLLVMContext(); + CGBuilderTy &Builder = CGF->Builder; // Construct and fill the args buffer that we'll pass to vprintf. - llvm::Value *BufferPtr; if (Args.size() <= 1) { - // If there are no args, pass a null pointer to vprintf. - BufferPtr = llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx)); + // If there are no args, pass a null pointer and size 0 + llvm::Value * BufferPtr = llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx)); + return {BufferPtr, llvm::TypeSize::Fixed(0)}; } else { llvm::SmallVector<llvm::Type *, 8> ArgTypes; for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) - ArgTypes.push_back(Args[I].getRValue(*this).getScalarVal()->getType()); + ArgTypes.push_back(Args[I].getRValue(*CGF).getScalarVal()->getType()); // Using llvm::StructType is correct only because printf doesn't accept // aggregates. If we had to handle aggregates here, we'd have to manually @@ -106,25 +112,71 @@ CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E, // that the alignment of the llvm type was the same as the alignment of the // clang type. llvm::Type *AllocaTy = llvm::StructType::create(ArgTypes, "printf_args"); - llvm::Value *Alloca = CreateTempAlloca(AllocaTy); + llvm::Value *Alloca = CGF->CreateTempAlloca(AllocaTy); for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) { llvm::Value *P = Builder.CreateStructGEP(AllocaTy, Alloca, I - 1); - llvm::Value *Arg = Args[I].getRValue(*this).getScalarVal(); + llvm::Value *Arg = Args[I].getRValue(*CGF).getScalarVal(); Builder.CreateAlignedStore(Arg, P, DL.getPrefTypeAlign(Arg->getType())); } - BufferPtr = Builder.CreatePointerCast(Alloca, llvm::Type::getInt8PtrTy(Ctx)); + llvm::Value *BufferPtr = + Builder.CreatePointerCast(Alloca, llvm::Type::getInt8PtrTy(Ctx)); + return {BufferPtr, DL.getTypeAllocSize(AllocaTy)}; + } +} + +bool containsNonScalarVarargs(CodeGenFunction *CGF, const CallArgList &Args) { + return llvm::any_of(llvm::drop_begin(Args), [&](const CallArg &A) { + return !A.getRValue(*CGF).isScalar(); + }); +} + +RValue EmitDevicePrintfCallExpr(const CallExpr *E, CodeGenFunction *CGF, + llvm::Function *Decl, bool WithSizeArg) { + CodeGenModule &CGM = CGF->CGM; + CGBuilderTy &Builder = CGF->Builder; + assert(E->getBuiltinCallee() == Builtin::BIprintf); + assert(E->getNumArgs() >= 1); // printf always has at least one arg. + + // Uses the same format as nvptx for the argument packing, but also passes + // an i32 for the total size of the passed pointer + CallArgList Args; + CGF->EmitCallArgs(Args, + E->getDirectCallee()->getType()->getAs<FunctionProtoType>(), + E->arguments(), E->getDirectCallee(), + /* ParamsToSkip = */ 0); + + // We don't know how to emit non-scalar varargs. + if (containsNonScalarVarargs(CGF, Args)) { + CGM.ErrorUnsupported(E, "non-scalar arg to printf"); + return RValue::get(llvm::ConstantInt::get(CGF->IntTy, 0)); } - // Invoke vprintf and return. - llvm::Function* VprintfFunc = GetVprintfDeclaration(CGM.getModule()); - return RValue::get(Builder.CreateCall( - VprintfFunc, {Args[0].getRValue(*this).getScalarVal(), BufferPtr})); + auto r = packArgsIntoNVPTXFormatBuffer(CGF, Args); + llvm::Value *BufferPtr = r.first; + + llvm::SmallVector<llvm::Value *, 3> Vec = { + Args[0].getRValue(*CGF).getScalarVal(), BufferPtr}; + if (WithSizeArg) { + // Passing > 32bit of data as a local alloca doesn't work for nvptx or + // amdgpu + llvm::Constant *Size = + llvm::ConstantInt::get(llvm::Type::getInt32Ty(CGM.getLLVMContext()), + static_cast<uint32_t>(r.second.getFixedValue())); + + Vec.push_back(Size); + } + return RValue::get(Builder.CreateCall(Decl, Vec)); } +} // namespace -RValue -CodeGenFunction::EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E, - ReturnValueSlot ReturnValue) { +RValue CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E) { + assert(getTarget().getTriple().isNVPTX()); + return EmitDevicePrintfCallExpr( + E, this, GetVprintfDeclaration(CGM.getModule()), false); +} + +RValue CodeGenFunction::EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E) { assert(getTarget().getTriple().getArch() == llvm::Triple::amdgcn); assert(E->getBuiltinCallee() == Builtin::BIprintf || E->getBuiltinCallee() == Builtin::BI__builtin_printf); @@ -137,7 +189,7 @@ CodeGenFunction::EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E, /* ParamsToSkip = */ 0); SmallVector<llvm::Value *, 8> Args; - for (auto A : CallArgs) { + for (const auto &A : CallArgs) { // We don't know how to emit non-scalar varargs. if (!A.getRValue(*this).isScalar()) { CGM.ErrorUnsupported(E, "non-scalar arg to printf"); @@ -150,7 +202,17 @@ CodeGenFunction::EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E, llvm::IRBuilder<> IRB(Builder.GetInsertBlock(), Builder.GetInsertPoint()); IRB.SetCurrentDebugLocation(Builder.getCurrentDebugLocation()); - auto Printf = llvm::emitAMDGPUPrintfCall(IRB, Args); + + bool isBuffered = (CGM.getTarget().getTargetOpts().AMDGPUPrintfKindVal == + clang::TargetOptions::AMDGPUPrintfKind::Buffered); + auto Printf = llvm::emitAMDGPUPrintfCall(IRB, Args, isBuffered); Builder.SetInsertPoint(IRB.GetInsertBlock(), IRB.GetInsertPoint()); return RValue::get(Printf); } + +RValue CodeGenFunction::EmitOpenMPDevicePrintfCallExpr(const CallExpr *E) { + assert(getTarget().getTriple().isNVPTX() || + getTarget().getTriple().isAMDGCN()); + return EmitDevicePrintfCallExpr(E, this, GetOpenMPVprintfDeclaration(CGM), + true); +} |