aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp
blob: 5e5e513cdfdabe17c38c0f5c2cde901c3b730ed8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
//===- CoroEarly.cpp - Coroutine Early Function Pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Coroutines/CoroEarly.h"
#include "CoroInternal.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"

using namespace llvm;

#define DEBUG_TYPE "coro-early"

namespace {
// Performs the lowering of the early coroutine intrinsics handled by this
// file. Created on demand if the coro-early pass has work to do.
class Lowerer : public coro::LowererBase {
  // Builder used by the individual lowerings to emit replacement IR.
  IRBuilder<> Builder;
  // Pointer to a non-variadic function type returning void and taking a single
  // Int8Ptr argument. lowerCoroPromise uses it for the two leading
  // function-pointer slots of its mock coroutine frame.
  PointerType *const AnyResumeFnPtrTy;
  // Lazily created by lowerCoroNoop; remains null until the first coro.noop
  // is lowered, and is reused for every later one.
  Constant *NoopCoro = nullptr;

  void lowerResumeOrDestroy(CallBase &CB, CoroSubFnInst::ResumeKind);
  void lowerCoroPromise(CoroPromiseInst *Intrin);
  void lowerCoroDone(IntrinsicInst *II);
  void lowerCoroNoop(IntrinsicInst *II);

public:
  // Marked explicit so that a Module is never implicitly converted into a
  // Lowerer; all users in this file construct it directly.
  explicit Lowerer(Module &M)
      : LowererBase(M), Builder(Context),
        AnyResumeFnPtrTy(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
                                           /*isVarArg=*/false)
                             ->getPointerTo()) {}

  // Scans every instruction of F and lowers or annotates the early coroutine
  // intrinsics it finds. Returns true if any handled intrinsic was seen.
  bool lowerEarlyIntrinsics(Function &F);
};
} // end anonymous namespace

// Turn a direct call to coro.resume / coro.destroy into an indirect call whose
// callee is the address produced by a coro.subfn.addr intrinsic call. Doing it
// this way lets CGPassManager notice devirtualization once the CoroElide pass
// swaps the coro.subfn.addr call for the actual function address.
//
// CB is the coro.resume / coro.destroy call; its first argument is forwarded
// to the coro.subfn.addr call. Index selects which sub-function is requested.
void Lowerer::lowerResumeOrDestroy(CallBase &CB,
                                   CoroSubFnInst::ResumeKind Index) {
  Value *FramePtr = CB.getArgOperand(0);
  Value *SubFnAddr = makeSubFnCall(FramePtr, Index, &CB);

  // Retarget the existing call in place rather than building a new one.
  CB.setCallingConv(CallingConv::Fast);
  CB.setCalledOperand(SubFnAddr);
}

// The coroutine promise lives at a fixed offset from the start of the
// coroutine frame. The i8* coro.promise(i8*, i1 from) intrinsic converts
// between a frame pointer and a promise pointer by adding (or, when going from
// the promise back to the frame, subtracting) that offset. The concrete frame
// type is unknown here, so a stand-in frame is modelled as two function
// pointers followed by the promise, and the promise offset is rounded up to
// the alignment carried by the intrinsic.
// TODO: Handle the case when coroutine promise alloca has align override.
void Lowerer::lowerCoroPromise(CoroPromiseInst *Intrin) {
  Type *ByteTy = Builder.getInt8Ty();

  // Mock frame: { resume fn ptr, destroy fn ptr, promise }.
  auto *MockFrameTy =
      StructType::get(Context, {AnyResumeFnPtrTy, AnyResumeFnPtrTy, ByteTy});
  const auto *MockLayout = TheModule.getDataLayout().getStructLayout(MockFrameTy);

  // Offset of the third field, bumped up to the requested promise alignment.
  const int64_t PromiseOffset =
      alignTo(MockLayout->getElementOffset(2), Intrin->getAlignment());
  // Walking from the promise back to the frame moves in the other direction.
  const int64_t Delta = Intrin->isFromPromise() ? -PromiseOffset : PromiseOffset;

  Builder.SetInsertPoint(Intrin);
  Value *Adjusted = Builder.CreateConstInBoundsGEP1_32(
      ByteTy, Intrin->getArgOperand(0), Delta);

  Intrin->replaceAllUsesWith(Adjusted);
  Intrin->eraseFromParent();
}

// Lowers llvm.coro.done, which asks whether a coroutine is suspended at its
// final suspend point. On reaching the final suspend point a coroutine zeros
// out ResumeFnAddr in its frame (resuming from there is UB), so the intrinsic
// becomes: load the first pointer-sized slot of the frame and compare it
// against null.
void Lowerer::lowerCoroDone(IntrinsicInst *II) {
  // Reading the very first slot is only right while the resume function is
  // the first element of the coroutine frame.
  static_assert(coro::Shape::SwitchFieldIndex::Resume == 0,
                "resume function not at offset zero");

  auto *SlotTy = Int8Ptr;

  Builder.SetInsertPoint(II);
  auto *SlotAddr =
      Builder.CreateBitCast(II->getArgOperand(0), SlotTy->getPointerTo());
  auto *ResumeFnAddr = Builder.CreateLoad(SlotTy, SlotAddr);
  auto *IsDone = Builder.CreateICmpEQ(ResumeFnAddr, NullPtr);

  II->replaceAllUsesWith(IsDone);
  II->eraseFromParent();
}

// Lowers a coro.noop intrinsic call to a pointer to a constant, private
// "no-op coroutine frame" global whose two function-pointer fields both refer
// to a private function that simply returns. The global (and the function) are
// built the first time this is called and cached in NoopCoro for reuse.
void Lowerer::lowerCoroNoop(IntrinsicInst *II) {
  if (NoopCoro == nullptr) {
    LLVMContext &Ctx = Builder.getContext();
    Module &Mod = *II->getModule();

    // The frame type is created without a body first, because its fields are
    // pointers to functions that themselves take a pointer to the frame.
    StructType *NoopFrameTy = StructType::create(Ctx, "NoopCoro.Frame");
    auto *NoopFnTy =
        FunctionType::get(Type::getVoidTy(Ctx), NoopFrameTy->getPointerTo(),
                          /*isVarArg=*/false);
    auto *NoopFnPtrTy = NoopFnTy->getPointerTo();
    NoopFrameTy->setBody({NoopFnPtrTy, NoopFnPtrTy});

    // A function consisting of a single block that immediately returns.
    Function *ResumeDestroyFn =
        Function::Create(NoopFnTy, GlobalValue::LinkageTypes::PrivateLinkage,
                         "NoopCoro.ResumeDestroy", &Mod);
    ResumeDestroyFn->setCallingConv(CallingConv::Fast);
    ReturnInst::Create(Ctx, BasicBlock::Create(Ctx, "entry", ResumeDestroyFn));

    // The frame constant: both slots point at the function created above.
    Constant *FrameFields[] = {ResumeDestroyFn, ResumeDestroyFn};
    Constant *FrameInit = ConstantStruct::get(NoopFrameTy, FrameFields);
    NoopCoro = new GlobalVariable(Mod, FrameInit->getType(),
                                  /*isConstant=*/true,
                                  GlobalVariable::PrivateLinkage, FrameInit,
                                  "NoopCoro.Frame.Const");
  }

  Builder.SetInsertPoint(II);
  auto *FrameAsInt8Ptr = Builder.CreateBitCast(NoopCoro, Int8Ptr);
  II->replaceAllUsesWith(FrameAsInt8Ptr);
  II->eraseFromParent();
}

// Marks every coro.begin that uses the given coro.id as NoDuplicate. Before
// CoroSplit runs this is required, since CoroSplit assumes there is exactly
// one coro.begin. After CoroSplit the NoDuplicate attribute is removed from
// coro.begin again, because otherwise it would interfere with inlining.
static void setCannotDuplicate(CoroIdInst *CoroId) {
  for (User *Usr : CoroId->users()) {
    auto *Begin = dyn_cast<CoroBeginInst>(Usr);
    if (!Begin)
      continue; // Some other kind of user of the coro.id; leave it alone.
    Begin->setCannotDuplicate();
  }
}

// Walks all instructions of F and handles each early coroutine intrinsic call:
// some are lowered to plain IR (coro.noop, coro.resume, coro.destroy,
// coro.promise, coro.done), others are only annotated (coro.suspend, coro.end,
// coro.id and its variants), and coro.free calls are collected so that they
// can be pointed at the function's coro.id afterwards.
//
// Returns true if at least one of the handled intrinsics was encountered, and
// false if F contained none of them.
bool Lowerer::lowerEarlyIntrinsics(Function &F) {
  bool Changed = false;
  CoroIdInst *CoroId = nullptr;
  SmallVector<CoroFreeInst *, 4> CoroFrees;
  bool HasCoroSuspend = false;
  for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) {
    // Step the iterator before dispatching: several of the lowerings below
    // erase the instruction they are given.
    Instruction &I = *IB++;
    auto *CB = dyn_cast<CallBase>(&I);
    if (!CB)
      continue;

    switch (CB->getIntrinsicID()) {
    default:
      // Not an intrinsic this pass cares about; skips the Changed update.
      continue;
    case Intrinsic::coro_free:
      CoroFrees.push_back(cast<CoroFreeInst>(&I));
      break;
    case Intrinsic::coro_suspend:
      // Make sure that final suspend point is not duplicated as CoroSplit
      // pass expects that there is at most one final suspend point.
      if (cast<CoroSuspendInst>(&I)->isFinal())
        CB->setCannotDuplicate();
      HasCoroSuspend = true;
      break;
    case Intrinsic::coro_end_async:
    case Intrinsic::coro_end:
      // Make sure that fallthrough coro.end is not duplicated as CoroSplit
      // pass expects that there is at most one fallthrough coro.end.
      if (cast<AnyCoroEndInst>(&I)->isFallthrough())
        CB->setCannotDuplicate();
      break;
    case Intrinsic::coro_noop:
      lowerCoroNoop(cast<IntrinsicInst>(&I));
      break;
    case Intrinsic::coro_id: {
      // Mark a function that comes out of the frontend that has a coro.id
      // with a coroutine attribute.
      // cast<> never yields null, so the result needs no null check and can
      // be reused directly as the function's coro.id.
      auto *CII = cast<CoroIdInst>(&I);
      if (CII->getInfo().isPreSplit()) {
        F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT);
        setCannotDuplicate(CII);
        CII->setCoroutineSelf();
        CoroId = CII;
      }
      break;
    }
    case Intrinsic::coro_id_retcon:
    case Intrinsic::coro_id_retcon_once:
    case Intrinsic::coro_id_async:
      F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
      break;
    case Intrinsic::coro_resume:
      lowerResumeOrDestroy(*CB, CoroSubFnInst::ResumeIndex);
      break;
    case Intrinsic::coro_destroy:
      lowerResumeOrDestroy(*CB, CoroSubFnInst::DestroyIndex);
      break;
    case Intrinsic::coro_promise:
      lowerCoroPromise(cast<CoroPromiseInst>(&I));
      break;
    case Intrinsic::coro_done:
      lowerCoroDone(cast<IntrinsicInst>(&I));
      break;
    }
    Changed = true;
  }
  // Make sure that all CoroFree reference the coro.id intrinsic.
  // Token type is not exposed through coroutine C/C++ builtins to plain C, so
  // we allow specifying none and fixing it up here.
  if (CoroId)
    for (CoroFreeInst *CF : CoroFrees)
      CF->setArgOperand(0, CoroId);
  // Coroutine suspension could potentially lead to any argument modified
  // outside of the function, hence arguments should not have noalias
  // attributes.
  if (HasCoroSuspend)
    for (Argument &A : F.args())
      if (A.hasNoAliasAttr())
        A.removeAttr(Attribute::NoAlias);
  return Changed;
}

// Asks coro::declaresIntrinsics about the module M, passing the names of all
// intrinsics that lowerEarlyIntrinsics dispatches on, and returns its answer.
// Used as a cheap gate before constructing a Lowerer.
static bool declaresCoroEarlyIntrinsics(const Module &M) {
  return coro::declaresIntrinsics(M, {
                                         // coro.id and its ABI variants.
                                         "llvm.coro.id",
                                         "llvm.coro.id.retcon",
                                         "llvm.coro.id.retcon.once",
                                         "llvm.coro.id.async",
                                         // Remaining intrinsics.
                                         "llvm.coro.destroy",
                                         "llvm.coro.done",
                                         "llvm.coro.end",
                                         "llvm.coro.end.async",
                                         "llvm.coro.noop",
                                         "llvm.coro.free",
                                         "llvm.coro.promise",
                                         "llvm.coro.resume",
                                         "llvm.coro.suspend",
                                     });
}

// New pass manager entry point. Reports all analyses as preserved when the
// enclosing module declares none of the early coroutine intrinsics or when
// lowering saw nothing to handle in F; otherwise only the CFG analyses set is
// reported as preserved.
PreservedAnalyses CoroEarlyPass::run(Function &F, FunctionAnalysisManager &) {
  Module &ParentModule = *F.getParent();

  // Cheap check first: a Lowerer is only built when there may be work to do.
  if (!declaresCoroEarlyIntrinsics(ParentModule))
    return PreservedAnalyses::all();

  Lowerer EarlyLowerer(ParentModule);
  if (!EarlyLowerer.lowerEarlyIntrinsics(F))
    return PreservedAnalyses::all();

  PreservedAnalyses Preserved;
  Preserved.preserveSet<CFGAnalyses>();
  return Preserved;
}

namespace {

// Legacy pass manager wrapper around Lowerer::lowerEarlyIntrinsics.
struct CoroEarlyLegacy : public FunctionPass {
  static char ID; // Pass identification, replacement for typeid.
  CoroEarlyLegacy() : FunctionPass(ID) {
    initializeCoroEarlyLegacyPass(*PassRegistry::getPassRegistry());
  }

  // Set by doInitialization when the module declares any of the intrinsics
  // this pass lowers; null means runOnFunction has nothing to do.
  std::unique_ptr<Lowerer> L;

  // Builds the Lowerer only if the module declares intrinsics we are going to
  // lower. Always reports that the module itself was not modified.
  bool doInitialization(Module &M) override {
    const bool HasWork = declaresCoroEarlyIntrinsics(M);
    if (HasWork)
      L = std::make_unique<Lowerer>(M);
    return false;
  }

  // Delegates to the Lowerer when one exists; otherwise reports no change.
  bool runOnFunction(Function &F) override {
    return L && L->lowerEarlyIntrinsics(F);
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
  }

  StringRef getPassName() const override {
    return "Lower early coroutine intrinsics";
  }
};
} // end anonymous namespace

// Definition of the static pass identifier declared in CoroEarlyLegacy; the
// constructor passes it to FunctionPass.
char CoroEarlyLegacy::ID = 0;
// Pass registration boilerplate for CoroEarlyLegacy, giving it the argument
// string "coro-early" and a human-readable description.
INITIALIZE_PASS(CoroEarlyLegacy, "coro-early",
                "Lower early coroutine intrinsics", false, false)

// Factory for the legacy pass: returns a newly allocated CoroEarlyLegacy.
Pass *llvm::createCoroEarlyLegacyPass() { return new CoroEarlyLegacy(); }