aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/AMDGPU/AMDGPUReplaceLDSUseWithPointer.cpp
blob: d55bf3917e9c1bd2c48bd8fd057c79eaf1beeea1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
//===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces all the uses of LDS within non-kernel functions with
// their corresponding pointer counterparts.
//
// The main motivation behind this pass is to prevent the subsequent LDS
// lowering pass from directly packing LDS (assumed to be large) into a struct
// type, which would otherwise cause a huge amount of memory to be allocated
// for the struct instance within every kernel.
//
// Brief sketch of the algorithm implemented in this pass is as below:
//
//   1. Collect all the LDS defined in the module which qualify for pointer
//      replacement, say it is, LDSGlobals set.
//
//   2. Collect all the reachable callees for each kernel defined in the module,
//      say it is, KernelToCallees map.
//
//   3. FOR (each global GV from LDSGlobals set) DO
//        LDSUsedNonKernels = Collect all non-kernel functions which use GV.
//        FOR (each kernel K in KernelToCallees map) DO
//           ReachableCallees = KernelToCallees[K]
//           ReachableAndLDSUsedCallees =
//              SetIntersect(LDSUsedNonKernels, ReachableCallees)
//           IF (ReachableAndLDSUsedCallees is not empty) THEN
//             Pointer = Create a pointer to point-to GV if not created.
//             Initialize Pointer to point-to GV within kernel K.
//           ENDIF
//        ENDFOR
//        Replace all uses of GV within non kernel functions by Pointer.
//      ENDFOR
//
// LLVM IR example:
//
//    Input IR:
//
//    @lds = internal addrspace(3) global [4 x i32] undef, align 16
//
//    define internal void @f0() {
//    entry:
//      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds,
//             i32 0, i32 0
//      ret void
//    }
//
//    define protected amdgpu_kernel void @k0() {
//    entry:
//      call void @f0()
//      ret void
//    }
//
//    Output IR:
//
//    @lds = internal addrspace(3) global [4 x i32] undef, align 16
//    @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2
//
//    define internal void @f0() {
//    entry:
//      %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2
//      %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0
//      %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)*
//      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2,
//             i32 0, i32 0
//      ret void
//    }
//
//    define protected amdgpu_kernel void @k0() {
//    entry:
//      store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16),
//            i16 addrspace(3)* @lds.ptr, align 2
//      call void @f0()
//      ret void
//    }
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "GCNSubtarget.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "Utils/AMDGPULDSUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <algorithm>
#include <vector>

#define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer"

using namespace llvm;

namespace {

class ReplaceLDSUseImpl {
  Module &M;
  LLVMContext &Ctx;
  const DataLayout &DL;
  Constant *LDSMemBaseAddr;

  DenseMap<GlobalVariable *, GlobalVariable *> LDSToPointer;
  DenseMap<GlobalVariable *, SmallPtrSet<Function *, 8>> LDSToNonKernels;
  DenseMap<Function *, SmallPtrSet<Function *, 8>> KernelToCallees;
  DenseMap<Function *, SmallPtrSet<GlobalVariable *, 8>> KernelToLDSPointers;
  DenseMap<Function *, BasicBlock *> KernelToInitBB;
  DenseMap<Function *, DenseMap<GlobalVariable *, Value *>>
      FunctionToLDSToReplaceInst;

  // Collect LDS which requires their uses to be replaced by pointer.
  std::vector<GlobalVariable *> collectLDSRequiringPointerReplace() {
    // Collect LDS which requires module lowering.
    std::vector<GlobalVariable *> LDSGlobals = AMDGPU::findVariablesToLower(M);

    // Remove LDS which don't qualify for replacement.
    llvm::erase_if(LDSGlobals, [&](GlobalVariable *GV) {
      return shouldIgnorePointerReplacement(GV);
    });

    return LDSGlobals;
  }

  // Returns true if uses of given LDS global within non-kernel functions should
  // be keep as it is without pointer replacement.
  bool shouldIgnorePointerReplacement(GlobalVariable *GV) {
    // LDS whose size is very small and doesn't exceed pointer size is not worth
    // replacing.
    if (DL.getTypeAllocSize(GV->getValueType()) <= 2)
      return true;

    // LDS which is not used from non-kernel function scope or it is used from
    // global scope does not qualify for replacement.
    LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV);
    return LDSToNonKernels[GV].empty();

    // FIXME: When GV is used within all (or within most of the kernels), then
    // it does not make sense to create a pointer for it.
  }

  // Insert new global LDS pointer which points to LDS.
  GlobalVariable *createLDSPointer(GlobalVariable *GV) {
    // LDS pointer which points to LDS is already created? Return it.
    auto PointerEntry = LDSToPointer.insert(std::make_pair(GV, nullptr));
    if (!PointerEntry.second)
      return PointerEntry.first->second;

    // We need to create new LDS pointer which points to LDS.
    //
    // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to
    // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address.
    auto *I16Ty = Type::getInt16Ty(Ctx);
    GlobalVariable *LDSPointer = new GlobalVariable(
        M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty),
        GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal,
        AMDGPUAS::LOCAL_ADDRESS);

    LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
    LDSPointer->setAlignment(AMDGPU::getAlign(DL, LDSPointer));

    // Mark that an associated LDS pointer is created for LDS.
    LDSToPointer[GV] = LDSPointer;

    return LDSPointer;
  }

  // Split entry basic block in such a way that only lane 0 of each wave does
  // the LDS pointer initialization, and return newly created basic block.
  BasicBlock *activateLaneZero(Function *K) {
    // If the entry basic block of kernel K is already split, then return
    // newly created basic block.
    auto BasicBlockEntry = KernelToInitBB.insert(std::make_pair(K, nullptr));
    if (!BasicBlockEntry.second)
      return BasicBlockEntry.first->second;

    // Split entry basic block of kernel K.
    auto *EI = &(*(K->getEntryBlock().getFirstInsertionPt()));
    IRBuilder<> Builder(EI);

    Value *Mbcnt =
        Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {},
                                {Builder.getInt32(-1), Builder.getInt32(0)});
    Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0));
    Instruction *WB = cast<Instruction>(
        Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {}));

    BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent();

    // Mark that the entry basic block of kernel K is split.
    KernelToInitBB[K] = NBB;

    return NBB;
  }

  // Within given kernel, initialize given LDS pointer to point to given LDS.
  void initializeLDSPointer(Function *K, GlobalVariable *GV,
                            GlobalVariable *LDSPointer) {
    // If LDS pointer is already initialized within K, then nothing to do.
    auto PointerEntry = KernelToLDSPointers.insert(
        std::make_pair(K, SmallPtrSet<GlobalVariable *, 8>()));
    if (!PointerEntry.second)
      if (PointerEntry.first->second.contains(LDSPointer))
        return;

    // Insert instructions at EI which initialize LDS pointer to point-to LDS
    // within kernel K.
    //
    // That is, convert pointer type of GV to i16, and then store this converted
    // i16 value within LDSPointer which is of type i16*.
    auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt()));
    IRBuilder<> Builder(EI);
    Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)),
                        LDSPointer);

    // Mark that LDS pointer is initialized within kernel K.
    KernelToLDSPointers[K].insert(LDSPointer);
  }

  // We have created an LDS pointer for LDS, and initialized it to point-to LDS
  // within all relevant kernels. Now replace all the uses of LDS within
  // non-kernel functions by LDS pointer.
  void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) {
    SmallVector<User *, 8> LDSUsers(GV->users());
    for (auto *U : LDSUsers) {
      // When `U` is a constant expression, it is possible that same constant
      // expression exists within multiple instructions, and within multiple
      // non-kernel functions. Collect all those non-kernel functions and all
      // those instructions within which `U` exist.
      auto FunctionToInsts =
          AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/);

      for (auto FI = FunctionToInsts.begin(), FE = FunctionToInsts.end();
           FI != FE; ++FI) {
        Function *F = FI->first;
        auto &Insts = FI->second;
        for (auto *I : Insts) {
          // If `U` is a constant expression, then we need to break the
          // associated instruction into a set of separate instructions by
          // converting constant expressions into instructions.
          SmallPtrSet<Instruction *, 8> UserInsts;

          if (U == I) {
            // `U` is an instruction, conversion from constant expression to
            // set of instructions is *not* required.
            UserInsts.insert(I);
          } else {
            // `U` is a constant expression, convert it into corresponding set
            // of instructions.
            auto *CE = cast<ConstantExpr>(U);
            convertConstantExprsToInstructions(I, CE, &UserInsts);
          }

          // Go through all the user instructions, if LDS exist within them as
          // an operand, then replace it by replace instruction.
          for (auto *II : UserInsts) {
            auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer);
            II->replaceUsesOfWith(GV, ReplaceInst);
          }
        }
      }
    }
  }

  // Create a set of replacement instructions which together replace LDS within
  // non-kernel function F by accessing LDS indirectly using LDS pointer.
  Value *getReplacementInst(Function *F, GlobalVariable *GV,
                            GlobalVariable *LDSPointer) {
    // If the instruction which replaces LDS within F is already created, then
    // return it.
    auto LDSEntry = FunctionToLDSToReplaceInst.insert(
        std::make_pair(F, DenseMap<GlobalVariable *, Value *>()));
    if (!LDSEntry.second) {
      auto ReplaceInstEntry =
          LDSEntry.first->second.insert(std::make_pair(GV, nullptr));
      if (!ReplaceInstEntry.second)
        return ReplaceInstEntry.first->second;
    }

    // Get the instruction insertion point within the beginning of the entry
    // block of current non-kernel function.
    auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt()));
    IRBuilder<> Builder(EI);

    // Insert required set of instructions which replace LDS within F.
    auto *V = Builder.CreateBitCast(
        Builder.CreateGEP(
            Builder.getInt8Ty(), LDSMemBaseAddr,
            Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)),
        GV->getType());

    // Mark that the replacement instruction which replace LDS within F is
    // created.
    FunctionToLDSToReplaceInst[F][GV] = V;

    return V;
  }

public:
  ReplaceLDSUseImpl(Module &M)
      : M(M), Ctx(M.getContext()), DL(M.getDataLayout()) {
    LDSMemBaseAddr = Constant::getIntegerValue(
        PointerType::get(Type::getInt8Ty(M.getContext()),
                         AMDGPUAS::LOCAL_ADDRESS),
        APInt(32, 0));
  }

  // Entry-point function which interface ReplaceLDSUseImpl with outside of the
  // class.
  bool replaceLDSUse();

private:
  // For a given LDS from collected LDS globals set, replace its non-kernel
  // function scope uses by pointer.
  bool replaceLDSUse(GlobalVariable *GV);
};

// Handle one LDS global from the collected set: initialize its pointer slot in
// every kernel that reaches a non-kernel accessor of GV, then rewrite the
// non-kernel uses. Returns false when no kernel needed the slot, in which case
// no slot is created and no use is rewritten.
bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) {
  // Non-kernel functions inside which GV is accessed.
  SmallPtrSet<Function *, 8> &NonKernelUsers = LDSToNonKernels[GV];

  // Stays null until some kernel turns out to need the slot.
  GlobalVariable *Slot = nullptr;

  for (auto &KernelEntry : KernelToCallees) {
    // Intersect a copy of the kernel's reachable callees with the accessors
    // of GV; the stored callee set itself is left untouched.
    SmallPtrSet<Function *, 8> ReachableUsers = KernelEntry.second;
    set_intersect(ReachableUsers, NonKernelUsers);

    // Only kernels that can reach an accessor of GV need to initialize the
    // slot. The slot is created on the first such kernel and reused after.
    if (!ReachableUsers.empty()) {
      Slot = createLDSPointer(GV);
      initializeLDSPointer(KernelEntry.first, GV, Slot);
    }
  }

  // The slot is now initialized in all relevant kernels; switch the
  // non-kernel uses of GV over to it.
  if (Slot) {
    replaceLDSUseByPointer(GV, Slot);
    return true;
  }

  // No kernel reaches an accessor of GV: nothing was created or replaced.
  return false;
}

// Entry-point function which interface ReplaceLDSUseImpl with outside of the
// class.
bool ReplaceLDSUseImpl::replaceLDSUse() {
  // Collect LDS which requires their uses to be replaced by pointer.
  std::vector<GlobalVariable *> LDSGlobals =
      collectLDSRequiringPointerReplace();

  // No LDS to pointer-replace. Nothing to do.
  if (LDSGlobals.empty())
    return false;

  // Collect reachable callee set for each kernel defined in the module.
  AMDGPU::collectReachableCallees(M, KernelToCallees);

  if (KernelToCallees.empty()) {
    // Either module does not have any kernel definitions, or none of the kernel
    // has a call to non-kernel functions, or we could not resolve any of the
    // call sites to proper non-kernel functions, because of the situations like
    // inline asm calls. Nothing to replace.
    return false;
  }

  // For every LDS from collected LDS globals set, replace its non-kernel
  // function scope use by pointer.
  bool Changed = false;
  for (auto *GV : LDSGlobals)
    Changed |= replaceLDSUse(GV);

  return Changed;
}

class AMDGPUReplaceLDSUseWithPointer : public ModulePass {
public:
  static char ID;

  AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) {
    initializeAMDGPUReplaceLDSUseWithPointerPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
  }
};

} // namespace

// Storage for the pass ID that the pass constructor hands to ModulePass.
char AMDGPUReplaceLDSUseWithPointer::ID = 0;
// Reference to that same ID, defined in namespace llvm so that code outside
// this file's anonymous namespace can name it.
char &llvm::AMDGPUReplaceLDSUseWithPointerID =
    AMDGPUReplaceLDSUseWithPointer::ID;

// Pass registration under the DEBUG_TYPE name, with TargetPassConfig listed as
// a dependency. NOTE(review): this macro group presumably generates the
// initializeAMDGPUReplaceLDSUseWithPointerPass() function called from the pass
// constructor - confirm against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(
    AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
    "Replace within non-kernel function use of LDS with pointer",
    false /*only look at the cfg*/, false /*analysis pass*/)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(
    AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
    "Replace within non-kernel function use of LDS with pointer",
    false /*only look at the cfg*/, false /*analysis pass*/)

// Run the replacement on M and report whether anything was changed.
bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) {
  // The worker is only needed for the duration of this one call.
  return ReplaceLDSUseImpl(M).replaceLDSUse();
}

// Factory for the ModulePass wrapper. The object is heap-allocated and handed
// out as a raw pointer; NOTE(review): presumably the caller (pass manager)
// takes ownership - confirm at the call sites.
ModulePass *llvm::createAMDGPUReplaceLDSUseWithPointerPass() {
  auto *NewPass = new AMDGPUReplaceLDSUseWithPointer();
  return NewPass;
}

// Pass-manager entry point taking a ModuleAnalysisManager: runs the
// replacement on M.
//
// Returns PreservedAnalyses::all() only when the module was left untouched.
// When replaceLDSUse() reports a change (new globals, split kernel entry
// blocks, rewritten uses), nothing is claimed as preserved, so that cached
// analyses are not reused against the modified IR.
PreservedAnalyses
AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) {
  ReplaceLDSUseImpl LDSUseReplacer{M};
  const bool Changed = LDSUseReplacer.replaceLDSUse();
  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}