xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This pass builds the coroutine frame and outlines resume and destroy parts
9 // of the coroutine into separate functions.
10 //
11 // We present a coroutine to an LLVM as an ordinary function with suspension
12 // points marked up with intrinsics. We let the optimizer party on the coroutine
13 // as a single function for as long as possible. Shortly before the coroutine is
14 // eligible to be inlined into its callers, we split up the coroutine into parts
15 // corresponding to an initial, resume and destroy invocations of the coroutine,
16 // add them to the current SCC and restart the IPO pipeline to optimize the
17 // coroutine subfunctions we extracted before proceeding to the caller of the
18 // coroutine.
19 //===----------------------------------------------------------------------===//
20 
21 #include "llvm/Transforms/Coroutines/CoroSplit.h"
22 #include "CoroCloner.h"
23 #include "CoroInternal.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/PriorityWorklist.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/StringExtras.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/ADT/Twine.h"
32 #include "llvm/Analysis/CFG.h"
33 #include "llvm/Analysis/CallGraph.h"
34 #include "llvm/Analysis/ConstantFolding.h"
35 #include "llvm/Analysis/LazyCallGraph.h"
36 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
37 #include "llvm/Analysis/TargetTransformInfo.h"
38 #include "llvm/BinaryFormat/Dwarf.h"
39 #include "llvm/IR/Argument.h"
40 #include "llvm/IR/Attributes.h"
41 #include "llvm/IR/BasicBlock.h"
42 #include "llvm/IR/CFG.h"
43 #include "llvm/IR/CallingConv.h"
44 #include "llvm/IR/Constants.h"
45 #include "llvm/IR/DIBuilder.h"
46 #include "llvm/IR/DataLayout.h"
47 #include "llvm/IR/DebugInfo.h"
48 #include "llvm/IR/DerivedTypes.h"
49 #include "llvm/IR/Dominators.h"
50 #include "llvm/IR/GlobalValue.h"
51 #include "llvm/IR/GlobalVariable.h"
52 #include "llvm/IR/InstIterator.h"
53 #include "llvm/IR/InstrTypes.h"
54 #include "llvm/IR/Instruction.h"
55 #include "llvm/IR/Instructions.h"
56 #include "llvm/IR/IntrinsicInst.h"
57 #include "llvm/IR/LLVMContext.h"
58 #include "llvm/IR/Module.h"
59 #include "llvm/IR/Type.h"
60 #include "llvm/IR/Value.h"
61 #include "llvm/IR/Verifier.h"
62 #include "llvm/Support/Casting.h"
63 #include "llvm/Support/Debug.h"
64 #include "llvm/Support/PrettyStackTrace.h"
65 #include "llvm/Support/raw_ostream.h"
66 #include "llvm/Transforms/Coroutines/MaterializationUtils.h"
67 #include "llvm/Transforms/Scalar.h"
68 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
69 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
70 #include "llvm/Transforms/Utils/Cloning.h"
71 #include "llvm/Transforms/Utils/Local.h"
72 #include <cassert>
73 #include <cstddef>
74 #include <cstdint>
75 #include <initializer_list>
76 #include <iterator>
77 
78 using namespace llvm;
79 
80 #define DEBUG_TYPE "coro-split"
81 
82 // FIXME:
83 // Lower the intrinisc in CoroEarly phase if coroutine frame doesn't escape
84 // and it is known that other transformations, for example, sanitizers
85 // won't lead to incorrect code.
lowerAwaitSuspend(IRBuilder<> & Builder,CoroAwaitSuspendInst * CB,coro::Shape & Shape)86 static void lowerAwaitSuspend(IRBuilder<> &Builder, CoroAwaitSuspendInst *CB,
87                               coro::Shape &Shape) {
88   auto Wrapper = CB->getWrapperFunction();
89   auto Awaiter = CB->getAwaiter();
90   auto FramePtr = CB->getFrame();
91 
92   Builder.SetInsertPoint(CB);
93 
94   CallBase *NewCall = nullptr;
95   // await_suspend has only 2 parameters, awaiter and handle.
96   // Copy parameter attributes from the intrinsic call, but remove the last,
97   // because the last parameter now becomes the function that is being called.
98   AttributeList NewAttributes =
99       CB->getAttributes().removeParamAttributes(CB->getContext(), 2);
100 
101   if (auto Invoke = dyn_cast<InvokeInst>(CB)) {
102     auto WrapperInvoke =
103         Builder.CreateInvoke(Wrapper, Invoke->getNormalDest(),
104                              Invoke->getUnwindDest(), {Awaiter, FramePtr});
105 
106     WrapperInvoke->setCallingConv(Invoke->getCallingConv());
107     std::copy(Invoke->bundle_op_info_begin(), Invoke->bundle_op_info_end(),
108               WrapperInvoke->bundle_op_info_begin());
109     WrapperInvoke->setAttributes(NewAttributes);
110     WrapperInvoke->setDebugLoc(Invoke->getDebugLoc());
111     NewCall = WrapperInvoke;
112   } else if (auto Call = dyn_cast<CallInst>(CB)) {
113     auto WrapperCall = Builder.CreateCall(Wrapper, {Awaiter, FramePtr});
114 
115     WrapperCall->setAttributes(NewAttributes);
116     WrapperCall->setDebugLoc(Call->getDebugLoc());
117     NewCall = WrapperCall;
118   } else {
119     llvm_unreachable("Unexpected coro_await_suspend invocation method");
120   }
121 
122   if (CB->getCalledFunction()->getIntrinsicID() ==
123       Intrinsic::coro_await_suspend_handle) {
124     // Follow the lowered await_suspend call above with a lowered resume call
125     // to the returned coroutine.
126     if (auto *Invoke = dyn_cast<InvokeInst>(CB)) {
127       // If the await_suspend call is an invoke, we continue in the next block.
128       Builder.SetInsertPoint(Invoke->getNormalDest()->getFirstInsertionPt());
129     }
130 
131     coro::LowererBase LB(*Wrapper->getParent());
132     auto *ResumeAddr = LB.makeSubFnCall(NewCall, CoroSubFnInst::ResumeIndex,
133                                         &*Builder.GetInsertPoint());
134 
135     LLVMContext &Ctx = Builder.getContext();
136     FunctionType *ResumeTy = FunctionType::get(
137         Type::getVoidTy(Ctx), PointerType::getUnqual(Ctx), false);
138     auto *ResumeCall = Builder.CreateCall(ResumeTy, ResumeAddr, {NewCall});
139     ResumeCall->setCallingConv(CallingConv::Fast);
140 
141     // We can't insert the 'ret' instruction and adjust the cc until the
142     // function has been split, so remember this for later.
143     Shape.SymmetricTransfers.push_back(ResumeCall);
144 
145     NewCall = ResumeCall;
146   }
147 
148   CB->replaceAllUsesWith(NewCall);
149   CB->eraseFromParent();
150 }
151 
lowerAwaitSuspends(Function & F,coro::Shape & Shape)152 static void lowerAwaitSuspends(Function &F, coro::Shape &Shape) {
153   IRBuilder<> Builder(F.getContext());
154   for (auto *AWS : Shape.CoroAwaitSuspends)
155     lowerAwaitSuspend(Builder, AWS, Shape);
156 }
157 
maybeFreeRetconStorage(IRBuilder<> & Builder,const coro::Shape & Shape,Value * FramePtr,CallGraph * CG)158 static void maybeFreeRetconStorage(IRBuilder<> &Builder,
159                                    const coro::Shape &Shape, Value *FramePtr,
160                                    CallGraph *CG) {
161   assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce);
162   if (Shape.RetconLowering.IsFrameInlineInStorage)
163     return;
164 
165   Shape.emitDealloc(Builder, FramePtr, CG);
166 }
167 
168 /// Replace an llvm.coro.end.async.
169 /// Will inline the must tail call function call if there is one.
170 /// \returns true if cleanup of the coro.end block is needed, false otherwise.
replaceCoroEndAsync(AnyCoroEndInst * End)171 static bool replaceCoroEndAsync(AnyCoroEndInst *End) {
172   IRBuilder<> Builder(End);
173 
174   auto *EndAsync = dyn_cast<CoroAsyncEndInst>(End);
175   if (!EndAsync) {
176     Builder.CreateRetVoid();
177     return true /*needs cleanup of coro.end block*/;
178   }
179 
180   auto *MustTailCallFunc = EndAsync->getMustTailCallFunction();
181   if (!MustTailCallFunc) {
182     Builder.CreateRetVoid();
183     return true /*needs cleanup of coro.end block*/;
184   }
185 
186   // Move the must tail call from the predecessor block into the end block.
187   auto *CoroEndBlock = End->getParent();
188   auto *MustTailCallFuncBlock = CoroEndBlock->getSinglePredecessor();
189   assert(MustTailCallFuncBlock && "Must have a single predecessor block");
190   auto It = MustTailCallFuncBlock->getTerminator()->getIterator();
191   auto *MustTailCall = cast<CallInst>(&*std::prev(It));
192   CoroEndBlock->splice(End->getIterator(), MustTailCallFuncBlock,
193                        MustTailCall->getIterator());
194 
195   // Insert the return instruction.
196   Builder.SetInsertPoint(End);
197   Builder.CreateRetVoid();
198   InlineFunctionInfo FnInfo;
199 
200   // Remove the rest of the block, by splitting it into an unreachable block.
201   auto *BB = End->getParent();
202   BB->splitBasicBlock(End);
203   BB->getTerminator()->eraseFromParent();
204 
205   auto InlineRes = InlineFunction(*MustTailCall, FnInfo);
206   assert(InlineRes.isSuccess() && "Expected inlining to succeed");
207   (void)InlineRes;
208 
209   // We have cleaned up the coro.end block above.
210   return false;
211 }
212 
213 /// Replace a non-unwind call to llvm.coro.end.
replaceFallthroughCoroEnd(AnyCoroEndInst * End,const coro::Shape & Shape,Value * FramePtr,bool InResume,CallGraph * CG)214 static void replaceFallthroughCoroEnd(AnyCoroEndInst *End,
215                                       const coro::Shape &Shape, Value *FramePtr,
216                                       bool InResume, CallGraph *CG) {
217   // Start inserting right before the coro.end.
218   IRBuilder<> Builder(End);
219 
220   // Create the return instruction.
221   switch (Shape.ABI) {
222   // The cloned functions in switch-lowering always return void.
223   case coro::ABI::Switch:
224     assert(!cast<CoroEndInst>(End)->hasResults() &&
225            "switch coroutine should not return any values");
226     // coro.end doesn't immediately end the coroutine in the main function
227     // in this lowering, because we need to deallocate the coroutine.
228     if (!InResume)
229       return;
230     Builder.CreateRetVoid();
231     break;
232 
233   // In async lowering this returns.
234   case coro::ABI::Async: {
235     bool CoroEndBlockNeedsCleanup = replaceCoroEndAsync(End);
236     if (!CoroEndBlockNeedsCleanup)
237       return;
238     break;
239   }
240 
241   // In unique continuation lowering, the continuations always return void.
242   // But we may have implicitly allocated storage.
243   case coro::ABI::RetconOnce: {
244     maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
245     auto *CoroEnd = cast<CoroEndInst>(End);
246     auto *RetTy = Shape.getResumeFunctionType()->getReturnType();
247 
248     if (!CoroEnd->hasResults()) {
249       assert(RetTy->isVoidTy());
250       Builder.CreateRetVoid();
251       break;
252     }
253 
254     auto *CoroResults = CoroEnd->getResults();
255     unsigned NumReturns = CoroResults->numReturns();
256 
257     if (auto *RetStructTy = dyn_cast<StructType>(RetTy)) {
258       assert(RetStructTy->getNumElements() == NumReturns &&
259              "numbers of returns should match resume function singature");
260       Value *ReturnValue = PoisonValue::get(RetStructTy);
261       unsigned Idx = 0;
262       for (Value *RetValEl : CoroResults->return_values())
263         ReturnValue = Builder.CreateInsertValue(ReturnValue, RetValEl, Idx++);
264       Builder.CreateRet(ReturnValue);
265     } else if (NumReturns == 0) {
266       assert(RetTy->isVoidTy());
267       Builder.CreateRetVoid();
268     } else {
269       assert(NumReturns == 1);
270       Builder.CreateRet(*CoroResults->retval_begin());
271     }
272     CoroResults->replaceAllUsesWith(
273         ConstantTokenNone::get(CoroResults->getContext()));
274     CoroResults->eraseFromParent();
275     break;
276   }
277 
278   // In non-unique continuation lowering, we signal completion by returning
279   // a null continuation.
280   case coro::ABI::Retcon: {
281     assert(!cast<CoroEndInst>(End)->hasResults() &&
282            "retcon coroutine should not return any values");
283     maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
284     auto RetTy = Shape.getResumeFunctionType()->getReturnType();
285     auto RetStructTy = dyn_cast<StructType>(RetTy);
286     PointerType *ContinuationTy =
287         cast<PointerType>(RetStructTy ? RetStructTy->getElementType(0) : RetTy);
288 
289     Value *ReturnValue = ConstantPointerNull::get(ContinuationTy);
290     if (RetStructTy) {
291       ReturnValue = Builder.CreateInsertValue(PoisonValue::get(RetStructTy),
292                                               ReturnValue, 0);
293     }
294     Builder.CreateRet(ReturnValue);
295     break;
296   }
297   }
298 
299   // Remove the rest of the block, by splitting it into an unreachable block.
300   auto *BB = End->getParent();
301   BB->splitBasicBlock(End);
302   BB->getTerminator()->eraseFromParent();
303 }
304 
305 // Mark a coroutine as done, which implies that the coroutine is finished and
306 // never gets resumed.
307 //
308 // In resume-switched ABI, the done state is represented by storing zero in
309 // ResumeFnAddr.
310 //
311 // NOTE: We couldn't omit the argument `FramePtr`. It is necessary because the
312 // pointer to the frame in splitted function is not stored in `Shape`.
markCoroutineAsDone(IRBuilder<> & Builder,const coro::Shape & Shape,Value * FramePtr)313 static void markCoroutineAsDone(IRBuilder<> &Builder, const coro::Shape &Shape,
314                                 Value *FramePtr) {
315   assert(
316       Shape.ABI == coro::ABI::Switch &&
317       "markCoroutineAsDone is only supported for Switch-Resumed ABI for now.");
318   auto *GepIndex = Builder.CreateStructGEP(
319       Shape.FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Resume,
320       "ResumeFn.addr");
321   auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
322       Shape.FrameTy->getTypeAtIndex(coro::Shape::SwitchFieldIndex::Resume)));
323   Builder.CreateStore(NullPtr, GepIndex);
324 
325   // If the coroutine don't have unwind coro end, we could omit the store to
326   // the final suspend point since we could infer the coroutine is suspended
327   // at the final suspend point by the nullness of ResumeFnAddr.
328   // However, we can't skip it if the coroutine have unwind coro end. Since
329   // the coroutine reaches unwind coro end is considered suspended at the
330   // final suspend point (the ResumeFnAddr is null) but in fact the coroutine
331   // didn't complete yet. We need the IndexVal for the final suspend point
332   // to make the states clear.
333   if (Shape.SwitchLowering.HasUnwindCoroEnd &&
334       Shape.SwitchLowering.HasFinalSuspend) {
335     assert(cast<CoroSuspendInst>(Shape.CoroSuspends.back())->isFinal() &&
336            "The final suspend should only live in the last position of "
337            "CoroSuspends.");
338     ConstantInt *IndexVal = Shape.getIndex(Shape.CoroSuspends.size() - 1);
339     auto *FinalIndex = Builder.CreateStructGEP(
340         Shape.FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
341 
342     Builder.CreateStore(IndexVal, FinalIndex);
343   }
344 }
345 
346 /// Replace an unwind call to llvm.coro.end.
replaceUnwindCoroEnd(AnyCoroEndInst * End,const coro::Shape & Shape,Value * FramePtr,bool InResume,CallGraph * CG)347 static void replaceUnwindCoroEnd(AnyCoroEndInst *End, const coro::Shape &Shape,
348                                  Value *FramePtr, bool InResume,
349                                  CallGraph *CG) {
350   IRBuilder<> Builder(End);
351 
352   switch (Shape.ABI) {
353   // In switch-lowering, this does nothing in the main function.
354   case coro::ABI::Switch: {
355     // In C++'s specification, the coroutine should be marked as done
356     // if promise.unhandled_exception() throws.  The frontend will
357     // call coro.end(true) along this path.
358     //
359     // FIXME: We should refactor this once there is other language
360     // which uses Switch-Resumed style other than C++.
361     markCoroutineAsDone(Builder, Shape, FramePtr);
362     if (!InResume)
363       return;
364     break;
365   }
366   // In async lowering this does nothing.
367   case coro::ABI::Async:
368     break;
369   // In continuation-lowering, this frees the continuation storage.
370   case coro::ABI::Retcon:
371   case coro::ABI::RetconOnce:
372     maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
373     break;
374   }
375 
376   // If coro.end has an associated bundle, add cleanupret instruction.
377   if (auto Bundle = End->getOperandBundle(LLVMContext::OB_funclet)) {
378     auto *FromPad = cast<CleanupPadInst>(Bundle->Inputs[0]);
379     auto *CleanupRet = Builder.CreateCleanupRet(FromPad, nullptr);
380     End->getParent()->splitBasicBlock(End);
381     CleanupRet->getParent()->getTerminator()->eraseFromParent();
382   }
383 }
384 
replaceCoroEnd(AnyCoroEndInst * End,const coro::Shape & Shape,Value * FramePtr,bool InResume,CallGraph * CG)385 static void replaceCoroEnd(AnyCoroEndInst *End, const coro::Shape &Shape,
386                            Value *FramePtr, bool InResume, CallGraph *CG) {
387   if (End->isUnwind())
388     replaceUnwindCoroEnd(End, Shape, FramePtr, InResume, CG);
389   else
390     replaceFallthroughCoroEnd(End, Shape, FramePtr, InResume, CG);
391 
392   auto &Context = End->getContext();
393   End->replaceAllUsesWith(InResume ? ConstantInt::getTrue(Context)
394                                    : ConstantInt::getFalse(Context));
395   End->eraseFromParent();
396 }
397 
398 // In the resume function, we remove the last case  (when coro::Shape is built,
399 // the final suspend point (if present) is always the last element of
400 // CoroSuspends array) since it is an undefined behavior to resume a coroutine
401 // suspended at the final suspend point.
402 // In the destroy function, if it isn't possible that the ResumeFnAddr is NULL
403 // and the coroutine doesn't suspend at the final suspend point actually (this
404 // is possible since the coroutine is considered suspended at the final suspend
405 // point if promise.unhandled_exception() exits via an exception), we can
406 // remove the last case.
handleFinalSuspend()407 void coro::BaseCloner::handleFinalSuspend() {
408   assert(Shape.ABI == coro::ABI::Switch &&
409          Shape.SwitchLowering.HasFinalSuspend);
410 
411   if (isSwitchDestroyFunction() && Shape.SwitchLowering.HasUnwindCoroEnd)
412     return;
413 
414   auto *Switch = cast<SwitchInst>(VMap[Shape.SwitchLowering.ResumeSwitch]);
415   auto FinalCaseIt = std::prev(Switch->case_end());
416   BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
417   Switch->removeCase(FinalCaseIt);
418   if (isSwitchDestroyFunction()) {
419     BasicBlock *OldSwitchBB = Switch->getParent();
420     auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
421     Builder.SetInsertPoint(OldSwitchBB->getTerminator());
422 
423     if (NewF->isCoroOnlyDestroyWhenComplete()) {
424       // When the coroutine can only be destroyed when complete, we don't need
425       // to generate code for other cases.
426       Builder.CreateBr(ResumeBB);
427     } else {
428       auto *GepIndex = Builder.CreateStructGEP(
429           Shape.FrameTy, NewFramePtr, coro::Shape::SwitchFieldIndex::Resume,
430           "ResumeFn.addr");
431       auto *Load =
432           Builder.CreateLoad(Shape.getSwitchResumePointerType(), GepIndex);
433       auto *Cond = Builder.CreateIsNull(Load);
434       Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
435     }
436     OldSwitchBB->getTerminator()->eraseFromParent();
437   }
438 }
439 
440 static FunctionType *
getFunctionTypeFromAsyncSuspend(AnyCoroSuspendInst * Suspend)441 getFunctionTypeFromAsyncSuspend(AnyCoroSuspendInst *Suspend) {
442   auto *AsyncSuspend = cast<CoroSuspendAsyncInst>(Suspend);
443   auto *StructTy = cast<StructType>(AsyncSuspend->getType());
444   auto &Context = Suspend->getParent()->getParent()->getContext();
445   auto *VoidTy = Type::getVoidTy(Context);
446   return FunctionType::get(VoidTy, StructTy->elements(), false);
447 }
448 
createCloneDeclaration(Function & OrigF,coro::Shape & Shape,const Twine & Suffix,Module::iterator InsertBefore,AnyCoroSuspendInst * ActiveSuspend)449 static Function *createCloneDeclaration(Function &OrigF, coro::Shape &Shape,
450                                         const Twine &Suffix,
451                                         Module::iterator InsertBefore,
452                                         AnyCoroSuspendInst *ActiveSuspend) {
453   Module *M = OrigF.getParent();
454   auto *FnTy = (Shape.ABI != coro::ABI::Async)
455                    ? Shape.getResumeFunctionType()
456                    : getFunctionTypeFromAsyncSuspend(ActiveSuspend);
457 
458   Function *NewF =
459       Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage,
460                        OrigF.getName() + Suffix);
461 
462   M->getFunctionList().insert(InsertBefore, NewF);
463 
464   return NewF;
465 }
466 
467 /// Replace uses of the active llvm.coro.suspend.retcon/async call with the
468 /// arguments to the continuation function.
469 ///
470 /// This assumes that the builder has a meaningful insertion point.
replaceRetconOrAsyncSuspendUses()471 void coro::BaseCloner::replaceRetconOrAsyncSuspendUses() {
472   assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce ||
473          Shape.ABI == coro::ABI::Async);
474 
475   auto NewS = VMap[ActiveSuspend];
476   if (NewS->use_empty())
477     return;
478 
479   // Copy out all the continuation arguments after the buffer pointer into
480   // an easily-indexed data structure for convenience.
481   SmallVector<Value *, 8> Args;
482   // The async ABI includes all arguments -- including the first argument.
483   bool IsAsyncABI = Shape.ABI == coro::ABI::Async;
484   for (auto I = IsAsyncABI ? NewF->arg_begin() : std::next(NewF->arg_begin()),
485             E = NewF->arg_end();
486        I != E; ++I)
487     Args.push_back(&*I);
488 
489   // If the suspend returns a single scalar value, we can just do a simple
490   // replacement.
491   if (!isa<StructType>(NewS->getType())) {
492     assert(Args.size() == 1);
493     NewS->replaceAllUsesWith(Args.front());
494     return;
495   }
496 
497   // Try to peephole extracts of an aggregate return.
498   for (Use &U : llvm::make_early_inc_range(NewS->uses())) {
499     auto *EVI = dyn_cast<ExtractValueInst>(U.getUser());
500     if (!EVI || EVI->getNumIndices() != 1)
501       continue;
502 
503     EVI->replaceAllUsesWith(Args[EVI->getIndices().front()]);
504     EVI->eraseFromParent();
505   }
506 
507   // If we have no remaining uses, we're done.
508   if (NewS->use_empty())
509     return;
510 
511   // Otherwise, we need to create an aggregate.
512   Value *Aggr = PoisonValue::get(NewS->getType());
513   for (auto [Idx, Arg] : llvm::enumerate(Args))
514     Aggr = Builder.CreateInsertValue(Aggr, Arg, Idx);
515 
516   NewS->replaceAllUsesWith(Aggr);
517 }
518 
replaceCoroSuspends()519 void coro::BaseCloner::replaceCoroSuspends() {
520   Value *SuspendResult;
521 
522   switch (Shape.ABI) {
523   // In switch lowering, replace coro.suspend with the appropriate value
524   // for the type of function we're extracting.
525   // Replacing coro.suspend with (0) will result in control flow proceeding to
526   // a resume label associated with a suspend point, replacing it with (1) will
527   // result in control flow proceeding to a cleanup label associated with this
528   // suspend point.
529   case coro::ABI::Switch:
530     SuspendResult = Builder.getInt8(isSwitchDestroyFunction() ? 1 : 0);
531     break;
532 
533   // In async lowering there are no uses of the result.
534   case coro::ABI::Async:
535     return;
536 
537   // In returned-continuation lowering, the arguments from earlier
538   // continuations are theoretically arbitrary, and they should have been
539   // spilled.
540   case coro::ABI::RetconOnce:
541   case coro::ABI::Retcon:
542     return;
543   }
544 
545   for (AnyCoroSuspendInst *CS : Shape.CoroSuspends) {
546     // The active suspend was handled earlier.
547     if (CS == ActiveSuspend)
548       continue;
549 
550     auto *MappedCS = cast<AnyCoroSuspendInst>(VMap[CS]);
551     MappedCS->replaceAllUsesWith(SuspendResult);
552     MappedCS->eraseFromParent();
553   }
554 }
555 
replaceCoroEnds()556 void coro::BaseCloner::replaceCoroEnds() {
557   for (AnyCoroEndInst *CE : Shape.CoroEnds) {
558     // We use a null call graph because there's no call graph node for
559     // the cloned function yet.  We'll just be rebuilding that later.
560     auto *NewCE = cast<AnyCoroEndInst>(VMap[CE]);
561     replaceCoroEnd(NewCE, Shape, NewFramePtr, /*in resume*/ true, nullptr);
562   }
563 }
564 
replaceSwiftErrorOps(Function & F,coro::Shape & Shape,ValueToValueMapTy * VMap)565 static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape,
566                                  ValueToValueMapTy *VMap) {
567   if (Shape.ABI == coro::ABI::Async && Shape.CoroSuspends.empty())
568     return;
569   Value *CachedSlot = nullptr;
570   auto getSwiftErrorSlot = [&](Type *ValueTy) -> Value * {
571     if (CachedSlot)
572       return CachedSlot;
573 
574     // Check if the function has a swifterror argument.
575     for (auto &Arg : F.args()) {
576       if (Arg.isSwiftError()) {
577         CachedSlot = &Arg;
578         return &Arg;
579       }
580     }
581 
582     // Create a swifterror alloca.
583     IRBuilder<> Builder(&F.getEntryBlock(),
584                         F.getEntryBlock().getFirstNonPHIOrDbg());
585     auto Alloca = Builder.CreateAlloca(ValueTy);
586     Alloca->setSwiftError(true);
587 
588     CachedSlot = Alloca;
589     return Alloca;
590   };
591 
592   for (CallInst *Op : Shape.SwiftErrorOps) {
593     auto MappedOp = VMap ? cast<CallInst>((*VMap)[Op]) : Op;
594     IRBuilder<> Builder(MappedOp);
595 
596     // If there are no arguments, this is a 'get' operation.
597     Value *MappedResult;
598     if (Op->arg_empty()) {
599       auto ValueTy = Op->getType();
600       auto Slot = getSwiftErrorSlot(ValueTy);
601       MappedResult = Builder.CreateLoad(ValueTy, Slot);
602     } else {
603       assert(Op->arg_size() == 1);
604       auto Value = MappedOp->getArgOperand(0);
605       auto ValueTy = Value->getType();
606       auto Slot = getSwiftErrorSlot(ValueTy);
607       Builder.CreateStore(Value, Slot);
608       MappedResult = Slot;
609     }
610 
611     MappedOp->replaceAllUsesWith(MappedResult);
612     MappedOp->eraseFromParent();
613   }
614 
615   // If we're updating the original function, we've invalidated SwiftErrorOps.
616   if (VMap == nullptr) {
617     Shape.SwiftErrorOps.clear();
618   }
619 }
620 
621 /// Returns all DbgVariableIntrinsic in F.
622 static std::pair<SmallVector<DbgVariableIntrinsic *, 8>,
623                  SmallVector<DbgVariableRecord *>>
collectDbgVariableIntrinsics(Function & F)624 collectDbgVariableIntrinsics(Function &F) {
625   SmallVector<DbgVariableIntrinsic *, 8> Intrinsics;
626   SmallVector<DbgVariableRecord *> DbgVariableRecords;
627   for (auto &I : instructions(F)) {
628     for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange()))
629       DbgVariableRecords.push_back(&DVR);
630     if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I))
631       Intrinsics.push_back(DVI);
632   }
633   return {Intrinsics, DbgVariableRecords};
634 }
635 
replaceSwiftErrorOps()636 void coro::BaseCloner::replaceSwiftErrorOps() {
637   ::replaceSwiftErrorOps(*NewF, Shape, &VMap);
638 }
639 
salvageDebugInfo()640 void coro::BaseCloner::salvageDebugInfo() {
641   auto [Worklist, DbgVariableRecords] = collectDbgVariableIntrinsics(*NewF);
642   SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap;
643 
644   // Only 64-bit ABIs have a register we can refer to with the entry value.
645   bool UseEntryValue = OrigF.getParent()->getTargetTriple().isArch64Bit();
646   for (DbgVariableIntrinsic *DVI : Worklist)
647     coro::salvageDebugInfo(ArgToAllocaMap, *DVI, UseEntryValue);
648   for (DbgVariableRecord *DVR : DbgVariableRecords)
649     coro::salvageDebugInfo(ArgToAllocaMap, *DVR, UseEntryValue);
650 
651   // Remove all salvaged dbg.declare intrinsics that became
652   // either unreachable or stale due to the CoroSplit transformation.
653   DominatorTree DomTree(*NewF);
654   auto IsUnreachableBlock = [&](BasicBlock *BB) {
655     return !isPotentiallyReachable(&NewF->getEntryBlock(), BB, nullptr,
656                                    &DomTree);
657   };
658   auto RemoveOne = [&](auto *DVI) {
659     if (IsUnreachableBlock(DVI->getParent()))
660       DVI->eraseFromParent();
661     else if (isa_and_nonnull<AllocaInst>(DVI->getVariableLocationOp(0))) {
662       // Count all non-debuginfo uses in reachable blocks.
663       unsigned Uses = 0;
664       for (auto *User : DVI->getVariableLocationOp(0)->users())
665         if (auto *I = dyn_cast<Instruction>(User))
666           if (!isa<AllocaInst>(I) && !IsUnreachableBlock(I->getParent()))
667             ++Uses;
668       if (!Uses)
669         DVI->eraseFromParent();
670     }
671   };
672   for_each(Worklist, RemoveOne);
673   for_each(DbgVariableRecords, RemoveOne);
674 }
675 
replaceEntryBlock()676 void coro::BaseCloner::replaceEntryBlock() {
677   // In the original function, the AllocaSpillBlock is a block immediately
678   // following the allocation of the frame object which defines GEPs for
679   // all the allocas that have been moved into the frame, and it ends by
680   // branching to the original beginning of the coroutine.  Make this
681   // the entry block of the cloned function.
682   auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
683   auto *OldEntry = &NewF->getEntryBlock();
684   Entry->setName("entry" + Suffix);
685   Entry->moveBefore(OldEntry);
686   Entry->getTerminator()->eraseFromParent();
687 
688   // Clear all predecessors of the new entry block.  There should be
689   // exactly one predecessor, which we created when splitting out
690   // AllocaSpillBlock to begin with.
691   assert(Entry->hasOneUse());
692   auto BranchToEntry = cast<BranchInst>(Entry->user_back());
693   assert(BranchToEntry->isUnconditional());
694   Builder.SetInsertPoint(BranchToEntry);
695   Builder.CreateUnreachable();
696   BranchToEntry->eraseFromParent();
697 
698   // Branch from the entry to the appropriate place.
699   Builder.SetInsertPoint(Entry);
700   switch (Shape.ABI) {
701   case coro::ABI::Switch: {
702     // In switch-lowering, we built a resume-entry block in the original
703     // function.  Make the entry block branch to this.
704     auto *SwitchBB =
705         cast<BasicBlock>(VMap[Shape.SwitchLowering.ResumeEntryBlock]);
706     Builder.CreateBr(SwitchBB);
707     SwitchBB->moveAfter(Entry);
708     break;
709   }
710   case coro::ABI::Async:
711   case coro::ABI::Retcon:
712   case coro::ABI::RetconOnce: {
713     // In continuation ABIs, we want to branch to immediately after the
714     // active suspend point.  Earlier phases will have put the suspend in its
715     // own basic block, so just thread our jump directly to its successor.
716     assert((Shape.ABI == coro::ABI::Async &&
717             isa<CoroSuspendAsyncInst>(ActiveSuspend)) ||
718            ((Shape.ABI == coro::ABI::Retcon ||
719              Shape.ABI == coro::ABI::RetconOnce) &&
720             isa<CoroSuspendRetconInst>(ActiveSuspend)));
721     auto *MappedCS = cast<AnyCoroSuspendInst>(VMap[ActiveSuspend]);
722     auto Branch = cast<BranchInst>(MappedCS->getNextNode());
723     assert(Branch->isUnconditional());
724     Builder.CreateBr(Branch->getSuccessor(0));
725     break;
726   }
727   }
728 
729   // Any static alloca that's still being used but not reachable from the new
730   // entry needs to be moved to the new entry.
731   Function *F = OldEntry->getParent();
732   DominatorTree DT{*F};
733   for (Instruction &I : llvm::make_early_inc_range(instructions(F))) {
734     auto *Alloca = dyn_cast<AllocaInst>(&I);
735     if (!Alloca || I.use_empty())
736       continue;
737     if (DT.isReachableFromEntry(I.getParent()) ||
738         !isa<ConstantInt>(Alloca->getArraySize()))
739       continue;
740     I.moveBefore(*Entry, Entry->getFirstInsertionPt());
741   }
742 }
743 
744 /// Derive the value of the new frame pointer.
deriveNewFramePointer()745 Value *coro::BaseCloner::deriveNewFramePointer() {
746   // Builder should be inserting to the front of the new entry block.
747 
748   switch (Shape.ABI) {
749   // In switch-lowering, the argument is the frame pointer.
750   case coro::ABI::Switch:
751     return &*NewF->arg_begin();
752   // In async-lowering, one of the arguments is an async context as determined
753   // by the `llvm.coro.id.async` intrinsic. We can retrieve the async context of
754   // the resume function from the async context projection function associated
755   // with the active suspend. The frame is located as a tail to the async
756   // context header.
757   case coro::ABI::Async: {
758     auto *ActiveAsyncSuspend = cast<CoroSuspendAsyncInst>(ActiveSuspend);
759     auto ContextIdx = ActiveAsyncSuspend->getStorageArgumentIndex() & 0xff;
760     auto *CalleeContext = NewF->getArg(ContextIdx);
761     auto *ProjectionFunc =
762         ActiveAsyncSuspend->getAsyncContextProjectionFunction();
763     auto DbgLoc =
764         cast<CoroSuspendAsyncInst>(VMap[ActiveSuspend])->getDebugLoc();
765     // Calling i8* (i8*)
766     auto *CallerContext = Builder.CreateCall(ProjectionFunc->getFunctionType(),
767                                              ProjectionFunc, CalleeContext);
768     CallerContext->setCallingConv(ProjectionFunc->getCallingConv());
769     CallerContext->setDebugLoc(DbgLoc);
770     // The frame is located after the async_context header.
771     auto &Context = Builder.getContext();
772     auto *FramePtrAddr = Builder.CreateConstInBoundsGEP1_32(
773         Type::getInt8Ty(Context), CallerContext,
774         Shape.AsyncLowering.FrameOffset, "async.ctx.frameptr");
775     // Inline the projection function.
776     InlineFunctionInfo InlineInfo;
777     auto InlineRes = InlineFunction(*CallerContext, InlineInfo);
778     assert(InlineRes.isSuccess());
779     (void)InlineRes;
780     return FramePtrAddr;
781   }
782   // In continuation-lowering, the argument is the opaque storage.
783   case coro::ABI::Retcon:
784   case coro::ABI::RetconOnce: {
785     Argument *NewStorage = &*NewF->arg_begin();
786     auto FramePtrTy = PointerType::getUnqual(Shape.FrameTy->getContext());
787 
788     // If the storage is inline, just bitcast to the storage to the frame type.
789     if (Shape.RetconLowering.IsFrameInlineInStorage)
790       return NewStorage;
791 
792     // Otherwise, load the real frame from the opaque storage.
793     return Builder.CreateLoad(FramePtrTy, NewStorage);
794   }
795   }
796   llvm_unreachable("bad ABI");
797 }
798 
799 /// Adjust the scope line of the funclet to the first line number after the
800 /// suspend point. This avoids a jump in the line table from the function
801 /// declaration (where prologue instructions are attributed to) to the suspend
802 /// point.
803 /// Only adjust the scope line when the files are the same.
804 /// If no candidate line number is found, fallback to the line of ActiveSuspend.
updateScopeLine(Instruction * ActiveSuspend,DISubprogram & SPToUpdate)805 static void updateScopeLine(Instruction *ActiveSuspend,
806                             DISubprogram &SPToUpdate) {
807   if (!ActiveSuspend)
808     return;
809 
810   // No subsequent instruction -> fallback to the location of ActiveSuspend.
811   if (!ActiveSuspend->getNextNonDebugInstruction()) {
812     if (auto DL = ActiveSuspend->getDebugLoc())
813       if (SPToUpdate.getFile() == DL->getFile())
814         SPToUpdate.setScopeLine(DL->getLine());
815     return;
816   }
817 
818   BasicBlock::iterator Successor =
819       ActiveSuspend->getNextNonDebugInstruction()->getIterator();
820   // Corosplit splits the BB around ActiveSuspend, so the meaningful
821   // instructions are not in the same BB.
822   if (auto *Branch = dyn_cast_or_null<BranchInst>(Successor);
823       Branch && Branch->isUnconditional())
824     Successor = Branch->getSuccessor(0)->getFirstNonPHIOrDbg();
825 
826   // Find the first successor of ActiveSuspend with a non-zero line location.
827   // If that matches the file of ActiveSuspend, use it.
828   BasicBlock *PBB = Successor->getParent();
829   for (; Successor != PBB->end(); Successor = std::next(Successor)) {
830     Successor = skipDebugIntrinsics(Successor);
831     auto DL = Successor->getDebugLoc();
832     if (!DL || DL.getLine() == 0)
833       continue;
834 
835     if (SPToUpdate.getFile() == DL->getFile()) {
836       SPToUpdate.setScopeLine(DL.getLine());
837       return;
838     }
839 
840     break;
841   }
842 
843   // If the search above failed, fallback to the location of ActiveSuspend.
844   if (auto DL = ActiveSuspend->getDebugLoc())
845     if (SPToUpdate.getFile() == DL->getFile())
846       SPToUpdate.setScopeLine(DL->getLine());
847 }
848 
addFramePointerAttrs(AttributeList & Attrs,LLVMContext & Context,unsigned ParamIndex,uint64_t Size,Align Alignment,bool NoAlias)849 static void addFramePointerAttrs(AttributeList &Attrs, LLVMContext &Context,
850                                  unsigned ParamIndex, uint64_t Size,
851                                  Align Alignment, bool NoAlias) {
852   AttrBuilder ParamAttrs(Context);
853   ParamAttrs.addAttribute(Attribute::NonNull);
854   ParamAttrs.addAttribute(Attribute::NoUndef);
855 
856   if (NoAlias)
857     ParamAttrs.addAttribute(Attribute::NoAlias);
858 
859   ParamAttrs.addAlignmentAttr(Alignment);
860   ParamAttrs.addDereferenceableAttr(Size);
861   Attrs = Attrs.addParamAttributes(Context, ParamIndex, ParamAttrs);
862 }
863 
addAsyncContextAttrs(AttributeList & Attrs,LLVMContext & Context,unsigned ParamIndex)864 static void addAsyncContextAttrs(AttributeList &Attrs, LLVMContext &Context,
865                                  unsigned ParamIndex) {
866   AttrBuilder ParamAttrs(Context);
867   ParamAttrs.addAttribute(Attribute::SwiftAsync);
868   Attrs = Attrs.addParamAttributes(Context, ParamIndex, ParamAttrs);
869 }
870 
addSwiftSelfAttrs(AttributeList & Attrs,LLVMContext & Context,unsigned ParamIndex)871 static void addSwiftSelfAttrs(AttributeList &Attrs, LLVMContext &Context,
872                               unsigned ParamIndex) {
873   AttrBuilder ParamAttrs(Context);
874   ParamAttrs.addAttribute(Attribute::SwiftSelf);
875   Attrs = Attrs.addParamAttributes(Context, ParamIndex, ParamAttrs);
876 }
877 
878 /// Clone the body of the original function into a resume function of
879 /// some sort.
create()880 void coro::BaseCloner::create() {
881   assert(NewF);
882 
883   // Replace all args with dummy instructions. If an argument is the old frame
884   // pointer, the dummy will be replaced by the new frame pointer once it is
885   // computed below. Uses of all other arguments should have already been
886   // rewritten by buildCoroutineFrame() to use loads/stores on the coroutine
887   // frame.
888   SmallVector<Instruction *> DummyArgs;
889   for (Argument &A : OrigF.args()) {
890     DummyArgs.push_back(new FreezeInst(PoisonValue::get(A.getType())));
891     VMap[&A] = DummyArgs.back();
892   }
893 
894   SmallVector<ReturnInst *, 4> Returns;
895 
896   // Ignore attempts to change certain attributes of the function.
897   // TODO: maybe there should be a way to suppress this during cloning?
898   auto savedVisibility = NewF->getVisibility();
899   auto savedUnnamedAddr = NewF->getUnnamedAddr();
900   auto savedDLLStorageClass = NewF->getDLLStorageClass();
901 
902   // NewF's linkage (which CloneFunctionInto does *not* change) might not
903   // be compatible with the visibility of OrigF (which it *does* change),
904   // so protect against that.
905   auto savedLinkage = NewF->getLinkage();
906   NewF->setLinkage(llvm::GlobalValue::ExternalLinkage);
907 
908   CloneFunctionInto(NewF, &OrigF, VMap,
909                     CloneFunctionChangeType::LocalChangesOnly, Returns);
910 
911   auto &Context = NewF->getContext();
912 
913   if (DISubprogram *SP = NewF->getSubprogram()) {
914     assert(SP != OrigF.getSubprogram() && SP->isDistinct());
915     updateScopeLine(ActiveSuspend, *SP);
916 
917     // Update the linkage name and the function name to reflect the modified
918     // name.
919     MDString *NewLinkageName = MDString::get(Context, NewF->getName());
920     SP->replaceLinkageName(NewLinkageName);
921     if (DISubprogram *Decl = SP->getDeclaration()) {
922       TempDISubprogram NewDecl = Decl->clone();
923       NewDecl->replaceLinkageName(NewLinkageName);
924       SP->replaceDeclaration(MDNode::replaceWithUniqued(std::move(NewDecl)));
925     }
926   }
927 
928   NewF->setLinkage(savedLinkage);
929   NewF->setVisibility(savedVisibility);
930   NewF->setUnnamedAddr(savedUnnamedAddr);
931   NewF->setDLLStorageClass(savedDLLStorageClass);
932   // The function sanitizer metadata needs to match the signature of the
933   // function it is being attached to. However this does not hold for split
934   // functions here. Thus remove the metadata for split functions.
935   if (Shape.ABI == coro::ABI::Switch &&
936       NewF->hasMetadata(LLVMContext::MD_func_sanitize))
937     NewF->eraseMetadata(LLVMContext::MD_func_sanitize);
938 
939   // Replace the attributes of the new function:
940   auto OrigAttrs = NewF->getAttributes();
941   auto NewAttrs = AttributeList();
942 
943   switch (Shape.ABI) {
944   case coro::ABI::Switch:
945     // Bootstrap attributes by copying function attributes from the
946     // original function.  This should include optimization settings and so on.
947     NewAttrs = NewAttrs.addFnAttributes(
948         Context, AttrBuilder(Context, OrigAttrs.getFnAttrs()));
949 
950     addFramePointerAttrs(NewAttrs, Context, 0, Shape.FrameSize,
951                          Shape.FrameAlign, /*NoAlias=*/false);
952     break;
953   case coro::ABI::Async: {
954     auto *ActiveAsyncSuspend = cast<CoroSuspendAsyncInst>(ActiveSuspend);
955     if (OrigF.hasParamAttribute(Shape.AsyncLowering.ContextArgNo,
956                                 Attribute::SwiftAsync)) {
957       uint32_t ArgAttributeIndices =
958           ActiveAsyncSuspend->getStorageArgumentIndex();
959       auto ContextArgIndex = ArgAttributeIndices & 0xff;
960       addAsyncContextAttrs(NewAttrs, Context, ContextArgIndex);
961 
962       // `swiftasync` must preceed `swiftself` so 0 is not a valid index for
963       // `swiftself`.
964       auto SwiftSelfIndex = ArgAttributeIndices >> 8;
965       if (SwiftSelfIndex)
966         addSwiftSelfAttrs(NewAttrs, Context, SwiftSelfIndex);
967     }
968 
969     // Transfer the original function's attributes.
970     auto FnAttrs = OrigF.getAttributes().getFnAttrs();
971     NewAttrs = NewAttrs.addFnAttributes(Context, AttrBuilder(Context, FnAttrs));
972     break;
973   }
974   case coro::ABI::Retcon:
975   case coro::ABI::RetconOnce:
976     // If we have a continuation prototype, just use its attributes,
977     // full-stop.
978     NewAttrs = Shape.RetconLowering.ResumePrototype->getAttributes();
979 
980     /// FIXME: Is it really good to add the NoAlias attribute?
981     addFramePointerAttrs(NewAttrs, Context, 0,
982                          Shape.getRetconCoroId()->getStorageSize(),
983                          Shape.getRetconCoroId()->getStorageAlignment(),
984                          /*NoAlias=*/true);
985 
986     break;
987   }
988 
989   switch (Shape.ABI) {
990   // In these ABIs, the cloned functions always return 'void', and the
991   // existing return sites are meaningless.  Note that for unique
992   // continuations, this includes the returns associated with suspends;
993   // this is fine because we can't suspend twice.
994   case coro::ABI::Switch:
995   case coro::ABI::RetconOnce:
996     // Remove old returns.
997     for (ReturnInst *Return : Returns)
998       changeToUnreachable(Return);
999     break;
1000 
1001   // With multi-suspend continuations, we'll already have eliminated the
1002   // original returns and inserted returns before all the suspend points,
1003   // so we want to leave any returns in place.
1004   case coro::ABI::Retcon:
1005     break;
1006   // Async lowering will insert musttail call functions at all suspend points
1007   // followed by a return.
1008   // Don't change returns to unreachable because that will trip up the verifier.
1009   // These returns should be unreachable from the clone.
1010   case coro::ABI::Async:
1011     break;
1012   }
1013 
1014   NewF->setAttributes(NewAttrs);
1015   NewF->setCallingConv(Shape.getResumeFunctionCC());
1016 
1017   // Set up the new entry block.
1018   replaceEntryBlock();
1019 
1020   // Turn symmetric transfers into musttail calls.
1021   for (CallInst *ResumeCall : Shape.SymmetricTransfers) {
1022     ResumeCall = cast<CallInst>(VMap[ResumeCall]);
1023     if (TTI.supportsTailCallFor(ResumeCall)) {
1024       // FIXME: Could we support symmetric transfer effectively without
1025       // musttail?
1026       ResumeCall->setTailCallKind(CallInst::TCK_MustTail);
1027     }
1028 
1029     // Put a 'ret void' after the call, and split any remaining instructions to
1030     // an unreachable block.
1031     BasicBlock *BB = ResumeCall->getParent();
1032     BB->splitBasicBlock(ResumeCall->getNextNode());
1033     Builder.SetInsertPoint(BB->getTerminator());
1034     Builder.CreateRetVoid();
1035     BB->getTerminator()->eraseFromParent();
1036   }
1037 
1038   Builder.SetInsertPoint(&NewF->getEntryBlock().front());
1039   NewFramePtr = deriveNewFramePointer();
1040 
1041   // Remap frame pointer.
1042   Value *OldFramePtr = VMap[Shape.FramePtr];
1043   NewFramePtr->takeName(OldFramePtr);
1044   OldFramePtr->replaceAllUsesWith(NewFramePtr);
1045 
1046   // Remap vFrame pointer.
1047   auto *NewVFrame = Builder.CreateBitCast(
1048       NewFramePtr, PointerType::getUnqual(Builder.getContext()), "vFrame");
1049   Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
1050   if (OldVFrame != NewVFrame)
1051     OldVFrame->replaceAllUsesWith(NewVFrame);
1052 
1053   // All uses of the arguments should have been resolved by this point,
1054   // so we can safely remove the dummy values.
1055   for (Instruction *DummyArg : DummyArgs) {
1056     DummyArg->replaceAllUsesWith(PoisonValue::get(DummyArg->getType()));
1057     DummyArg->deleteValue();
1058   }
1059 
1060   switch (Shape.ABI) {
1061   case coro::ABI::Switch:
1062     // Rewrite final suspend handling as it is not done via switch (allows to
1063     // remove final case from the switch, since it is undefined behavior to
1064     // resume the coroutine suspended at the final suspend point.
1065     if (Shape.SwitchLowering.HasFinalSuspend)
1066       handleFinalSuspend();
1067     break;
1068   case coro::ABI::Async:
1069   case coro::ABI::Retcon:
1070   case coro::ABI::RetconOnce:
1071     // Replace uses of the active suspend with the corresponding
1072     // continuation-function arguments.
1073     assert(ActiveSuspend != nullptr &&
1074            "no active suspend when lowering a continuation-style coroutine");
1075     replaceRetconOrAsyncSuspendUses();
1076     break;
1077   }
1078 
1079   // Handle suspends.
1080   replaceCoroSuspends();
1081 
1082   // Handle swifterror.
1083   replaceSwiftErrorOps();
1084 
1085   // Remove coro.end intrinsics.
1086   replaceCoroEnds();
1087 
1088   // Salvage debug info that points into the coroutine frame.
1089   salvageDebugInfo();
1090 }
1091 
create()1092 void coro::SwitchCloner::create() {
1093   // Create a new function matching the original type
1094   NewF = createCloneDeclaration(OrigF, Shape, Suffix, OrigF.getParent()->end(),
1095                                 ActiveSuspend);
1096 
1097   // Clone the function
1098   coro::BaseCloner::create();
1099 
1100   // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
1101   // to suppress deallocation code.
1102   coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
1103                         /*Elide=*/FKind == coro::CloneKind::SwitchCleanup);
1104 }
1105 
updateAsyncFuncPointerContextSize(coro::Shape & Shape)1106 static void updateAsyncFuncPointerContextSize(coro::Shape &Shape) {
1107   assert(Shape.ABI == coro::ABI::Async);
1108 
1109   auto *FuncPtrStruct = cast<ConstantStruct>(
1110       Shape.AsyncLowering.AsyncFuncPointer->getInitializer());
1111   auto *OrigRelativeFunOffset = FuncPtrStruct->getOperand(0);
1112   auto *OrigContextSize = FuncPtrStruct->getOperand(1);
1113   auto *NewContextSize = ConstantInt::get(OrigContextSize->getType(),
1114                                           Shape.AsyncLowering.ContextSize);
1115   auto *NewFuncPtrStruct = ConstantStruct::get(
1116       FuncPtrStruct->getType(), OrigRelativeFunOffset, NewContextSize);
1117 
1118   Shape.AsyncLowering.AsyncFuncPointer->setInitializer(NewFuncPtrStruct);
1119 }
1120 
getFrameSizeForShape(coro::Shape & Shape)1121 static TypeSize getFrameSizeForShape(coro::Shape &Shape) {
1122   // In the same function all coro.sizes should have the same result type.
1123   auto *SizeIntrin = Shape.CoroSizes.back();
1124   Module *M = SizeIntrin->getModule();
1125   const DataLayout &DL = M->getDataLayout();
1126   return DL.getTypeAllocSize(Shape.FrameTy);
1127 }
1128 
replaceFrameSizeAndAlignment(coro::Shape & Shape)1129 static void replaceFrameSizeAndAlignment(coro::Shape &Shape) {
1130   if (Shape.ABI == coro::ABI::Async)
1131     updateAsyncFuncPointerContextSize(Shape);
1132 
1133   for (CoroAlignInst *CA : Shape.CoroAligns) {
1134     CA->replaceAllUsesWith(
1135         ConstantInt::get(CA->getType(), Shape.FrameAlign.value()));
1136     CA->eraseFromParent();
1137   }
1138 
1139   if (Shape.CoroSizes.empty())
1140     return;
1141 
1142   // In the same function all coro.sizes should have the same result type.
1143   auto *SizeIntrin = Shape.CoroSizes.back();
1144   auto *SizeConstant =
1145       ConstantInt::get(SizeIntrin->getType(), getFrameSizeForShape(Shape));
1146 
1147   for (CoroSizeInst *CS : Shape.CoroSizes) {
1148     CS->replaceAllUsesWith(SizeConstant);
1149     CS->eraseFromParent();
1150   }
1151 }
1152 
postSplitCleanup(Function & F)1153 static void postSplitCleanup(Function &F) {
1154   removeUnreachableBlocks(F);
1155 
1156 #ifndef NDEBUG
1157   // For now, we do a mandatory verification step because we don't
1158   // entirely trust this pass.  Note that we don't want to add a verifier
1159   // pass to FPM below because it will also verify all the global data.
1160   if (verifyFunction(F, &errs()))
1161     report_fatal_error("Broken function");
1162 #endif
1163 }
1164 
1165 // Coroutine has no suspend points. Remove heap allocation for the coroutine
1166 // frame if possible.
handleNoSuspendCoroutine(coro::Shape & Shape)1167 static void handleNoSuspendCoroutine(coro::Shape &Shape) {
1168   auto *CoroBegin = Shape.CoroBegin;
1169   switch (Shape.ABI) {
1170   case coro::ABI::Switch: {
1171     auto SwitchId = Shape.getSwitchCoroId();
1172     auto *AllocInst = SwitchId->getCoroAlloc();
1173     coro::replaceCoroFree(SwitchId, /*Elide=*/AllocInst != nullptr);
1174     if (AllocInst) {
1175       IRBuilder<> Builder(AllocInst);
1176       auto *Frame = Builder.CreateAlloca(Shape.FrameTy);
1177       Frame->setAlignment(Shape.FrameAlign);
1178       AllocInst->replaceAllUsesWith(Builder.getFalse());
1179       AllocInst->eraseFromParent();
1180       CoroBegin->replaceAllUsesWith(Frame);
1181     } else {
1182       CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
1183     }
1184 
1185     break;
1186   }
1187   case coro::ABI::Async:
1188   case coro::ABI::Retcon:
1189   case coro::ABI::RetconOnce:
1190     CoroBegin->replaceAllUsesWith(PoisonValue::get(CoroBegin->getType()));
1191     break;
1192   }
1193 
1194   CoroBegin->eraseFromParent();
1195   Shape.CoroBegin = nullptr;
1196 }
1197 
1198 // SimplifySuspendPoint needs to check that there is no calls between
1199 // coro_save and coro_suspend, since any of the calls may potentially resume
1200 // the coroutine and if that is the case we cannot eliminate the suspend point.
hasCallsInBlockBetween(iterator_range<BasicBlock::iterator> R)1201 static bool hasCallsInBlockBetween(iterator_range<BasicBlock::iterator> R) {
1202   for (Instruction &I : R) {
1203     // Assume that no intrinsic can resume the coroutine.
1204     if (isa<IntrinsicInst>(I))
1205       continue;
1206 
1207     if (isa<CallBase>(I))
1208       return true;
1209   }
1210   return false;
1211 }
1212 
hasCallsInBlocksBetween(BasicBlock * SaveBB,BasicBlock * ResDesBB)1213 static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) {
1214   SmallPtrSet<BasicBlock *, 8> Set;
1215   SmallVector<BasicBlock *, 8> Worklist;
1216 
1217   Set.insert(SaveBB);
1218   Worklist.push_back(ResDesBB);
1219 
1220   // Accumulate all blocks between SaveBB and ResDesBB. Because CoroSaveIntr
1221   // returns a token consumed by suspend instruction, all blocks in between
1222   // will have to eventually hit SaveBB when going backwards from ResDesBB.
1223   while (!Worklist.empty()) {
1224     auto *BB = Worklist.pop_back_val();
1225     Set.insert(BB);
1226     for (auto *Pred : predecessors(BB))
1227       if (!Set.contains(Pred))
1228         Worklist.push_back(Pred);
1229   }
1230 
1231   // SaveBB and ResDesBB are checked separately in hasCallsBetween.
1232   Set.erase(SaveBB);
1233   Set.erase(ResDesBB);
1234 
1235   for (auto *BB : Set)
1236     if (hasCallsInBlockBetween({BB->getFirstNonPHIIt(), BB->end()}))
1237       return true;
1238 
1239   return false;
1240 }
1241 
hasCallsBetween(Instruction * Save,Instruction * ResumeOrDestroy)1242 static bool hasCallsBetween(Instruction *Save, Instruction *ResumeOrDestroy) {
1243   auto *SaveBB = Save->getParent();
1244   auto *ResumeOrDestroyBB = ResumeOrDestroy->getParent();
1245   BasicBlock::iterator SaveIt = Save->getIterator();
1246   BasicBlock::iterator ResumeOrDestroyIt = ResumeOrDestroy->getIterator();
1247 
1248   if (SaveBB == ResumeOrDestroyBB)
1249     return hasCallsInBlockBetween({std::next(SaveIt), ResumeOrDestroyIt});
1250 
1251   // Any calls from Save to the end of the block?
1252   if (hasCallsInBlockBetween({std::next(SaveIt), SaveBB->end()}))
1253     return true;
1254 
1255   // Any calls from begging of the block up to ResumeOrDestroy?
1256   if (hasCallsInBlockBetween(
1257           {ResumeOrDestroyBB->getFirstNonPHIIt(), ResumeOrDestroyIt}))
1258     return true;
1259 
1260   // Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB?
1261   if (hasCallsInBlocksBetween(SaveBB, ResumeOrDestroyBB))
1262     return true;
1263 
1264   return false;
1265 }
1266 
1267 // If a SuspendIntrin is preceded by Resume or Destroy, we can eliminate the
1268 // suspend point and replace it with nornal control flow.
simplifySuspendPoint(CoroSuspendInst * Suspend,CoroBeginInst * CoroBegin)1269 static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
1270                                  CoroBeginInst *CoroBegin) {
1271   Instruction *Prev = Suspend->getPrevNode();
1272   if (!Prev) {
1273     auto *Pred = Suspend->getParent()->getSinglePredecessor();
1274     if (!Pred)
1275       return false;
1276     Prev = Pred->getTerminator();
1277   }
1278 
1279   CallBase *CB = dyn_cast<CallBase>(Prev);
1280   if (!CB)
1281     return false;
1282 
1283   auto *Callee = CB->getCalledOperand()->stripPointerCasts();
1284 
1285   // See if the callsite is for resumption or destruction of the coroutine.
1286   auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
1287   if (!SubFn)
1288     return false;
1289 
1290   // Does not refer to the current coroutine, we cannot do anything with it.
1291   if (SubFn->getFrame() != CoroBegin)
1292     return false;
1293 
1294   // See if the transformation is safe. Specifically, see if there are any
1295   // calls in between Save and CallInstr. They can potenitally resume the
1296   // coroutine rendering this optimization unsafe.
1297   auto *Save = Suspend->getCoroSave();
1298   if (hasCallsBetween(Save, CB))
1299     return false;
1300 
1301   // Replace llvm.coro.suspend with the value that results in resumption over
1302   // the resume or cleanup path.
1303   Suspend->replaceAllUsesWith(SubFn->getRawIndex());
1304   Suspend->eraseFromParent();
1305   Save->eraseFromParent();
1306 
1307   // No longer need a call to coro.resume or coro.destroy.
1308   if (auto *Invoke = dyn_cast<InvokeInst>(CB)) {
1309     BranchInst::Create(Invoke->getNormalDest(), Invoke->getIterator());
1310   }
1311 
1312   // Grab the CalledValue from CB before erasing the CallInstr.
1313   auto *CalledValue = CB->getCalledOperand();
1314   CB->eraseFromParent();
1315 
1316   // If no more users remove it. Usually it is a bitcast of SubFn.
1317   if (CalledValue != SubFn && CalledValue->user_empty())
1318     if (auto *I = dyn_cast<Instruction>(CalledValue))
1319       I->eraseFromParent();
1320 
1321   // Now we are good to remove SubFn.
1322   if (SubFn->user_empty())
1323     SubFn->eraseFromParent();
1324 
1325   return true;
1326 }
1327 
1328 // Remove suspend points that are simplified.
simplifySuspendPoints(coro::Shape & Shape)1329 static void simplifySuspendPoints(coro::Shape &Shape) {
1330   // Currently, the only simplification we do is switch-lowering-specific.
1331   if (Shape.ABI != coro::ABI::Switch)
1332     return;
1333 
1334   auto &S = Shape.CoroSuspends;
1335   size_t I = 0, N = S.size();
1336   if (N == 0)
1337     return;
1338 
1339   size_t ChangedFinalIndex = std::numeric_limits<size_t>::max();
1340   while (true) {
1341     auto SI = cast<CoroSuspendInst>(S[I]);
1342     // Leave final.suspend to handleFinalSuspend since it is undefined behavior
1343     // to resume a coroutine suspended at the final suspend point.
1344     if (!SI->isFinal() && simplifySuspendPoint(SI, Shape.CoroBegin)) {
1345       if (--N == I)
1346         break;
1347 
1348       std::swap(S[I], S[N]);
1349 
1350       if (cast<CoroSuspendInst>(S[I])->isFinal()) {
1351         assert(Shape.SwitchLowering.HasFinalSuspend);
1352         ChangedFinalIndex = I;
1353       }
1354 
1355       continue;
1356     }
1357     if (++I == N)
1358       break;
1359   }
1360   S.resize(N);
1361 
1362   // Maintain final.suspend in case final suspend was swapped.
1363   // Due to we requrie the final suspend to be the last element of CoroSuspends.
1364   if (ChangedFinalIndex < N) {
1365     assert(cast<CoroSuspendInst>(S[ChangedFinalIndex])->isFinal());
1366     std::swap(S[ChangedFinalIndex], S.back());
1367   }
1368 }
1369 
1370 namespace {
1371 
1372 struct SwitchCoroutineSplitter {
split__anon98d7ec870411::SwitchCoroutineSplitter1373   static void split(Function &F, coro::Shape &Shape,
1374                     SmallVectorImpl<Function *> &Clones,
1375                     TargetTransformInfo &TTI) {
1376     assert(Shape.ABI == coro::ABI::Switch);
1377 
1378     // Create a resume clone by cloning the body of the original function,
1379     // setting new entry block and replacing coro.suspend an appropriate value
1380     // to force resume or cleanup pass for every suspend point.
1381     createResumeEntryBlock(F, Shape);
1382     auto *ResumeClone = coro::SwitchCloner::createClone(
1383         F, ".resume", Shape, coro::CloneKind::SwitchResume, TTI);
1384     auto *DestroyClone = coro::SwitchCloner::createClone(
1385         F, ".destroy", Shape, coro::CloneKind::SwitchUnwind, TTI);
1386     auto *CleanupClone = coro::SwitchCloner::createClone(
1387         F, ".cleanup", Shape, coro::CloneKind::SwitchCleanup, TTI);
1388 
1389     postSplitCleanup(*ResumeClone);
1390     postSplitCleanup(*DestroyClone);
1391     postSplitCleanup(*CleanupClone);
1392 
1393     // Store addresses resume/destroy/cleanup functions in the coroutine frame.
1394     updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
1395 
1396     assert(Clones.empty());
1397     Clones.push_back(ResumeClone);
1398     Clones.push_back(DestroyClone);
1399     Clones.push_back(CleanupClone);
1400 
1401     // Create a constant array referring to resume/destroy/clone functions
1402     // pointed by the last argument of @llvm.coro.info, so that CoroElide pass
1403     // can determined correct function to call.
1404     setCoroInfo(F, Shape, Clones);
1405   }
1406 
1407   // Create a variant of ramp function that does not perform heap allocation
1408   // for a switch ABI coroutine.
1409   //
1410   // The newly split `.noalloc` ramp function has the following differences:
1411   //  - Has one additional frame pointer parameter in lieu of dynamic
1412   //  allocation.
1413   //  - Suppressed allocations by replacing coro.alloc and coro.free.
createNoAllocVariant__anon98d7ec870411::SwitchCoroutineSplitter1414   static Function *createNoAllocVariant(Function &F, coro::Shape &Shape,
1415                                         SmallVectorImpl<Function *> &Clones) {
1416     assert(Shape.ABI == coro::ABI::Switch);
1417     auto *OrigFnTy = F.getFunctionType();
1418     auto OldParams = OrigFnTy->params();
1419 
1420     SmallVector<Type *> NewParams;
1421     NewParams.reserve(OldParams.size() + 1);
1422     NewParams.append(OldParams.begin(), OldParams.end());
1423     NewParams.push_back(PointerType::getUnqual(Shape.FrameTy->getContext()));
1424 
1425     auto *NewFnTy = FunctionType::get(OrigFnTy->getReturnType(), NewParams,
1426                                       OrigFnTy->isVarArg());
1427     Function *NoAllocF =
1428         Function::Create(NewFnTy, F.getLinkage(), F.getName() + ".noalloc");
1429 
1430     ValueToValueMapTy VMap;
1431     unsigned int Idx = 0;
1432     for (const auto &I : F.args()) {
1433       VMap[&I] = NoAllocF->getArg(Idx++);
1434     }
1435     // We just appended the frame pointer as the last argument of the new
1436     // function.
1437     auto FrameIdx = NoAllocF->arg_size() - 1;
1438     SmallVector<ReturnInst *, 4> Returns;
1439     CloneFunctionInto(NoAllocF, &F, VMap,
1440                       CloneFunctionChangeType::LocalChangesOnly, Returns);
1441 
1442     if (Shape.CoroBegin) {
1443       auto *NewCoroBegin =
1444           cast_if_present<CoroBeginInst>(VMap[Shape.CoroBegin]);
1445       auto *NewCoroId = cast<CoroIdInst>(NewCoroBegin->getId());
1446       coro::replaceCoroFree(NewCoroId, /*Elide=*/true);
1447       coro::suppressCoroAllocs(NewCoroId);
1448       NewCoroBegin->replaceAllUsesWith(NoAllocF->getArg(FrameIdx));
1449       NewCoroBegin->eraseFromParent();
1450     }
1451 
1452     Module *M = F.getParent();
1453     M->getFunctionList().insert(M->end(), NoAllocF);
1454 
1455     removeUnreachableBlocks(*NoAllocF);
1456     auto NewAttrs = NoAllocF->getAttributes();
1457     // When we elide allocation, we read these attributes to determine the
1458     // frame size and alignment.
1459     addFramePointerAttrs(NewAttrs, NoAllocF->getContext(), FrameIdx,
1460                          Shape.FrameSize, Shape.FrameAlign,
1461                          /*NoAlias=*/false);
1462 
1463     NoAllocF->setAttributes(NewAttrs);
1464 
1465     Clones.push_back(NoAllocF);
1466     // Reset the original function's coro info, make the new noalloc variant
1467     // connected to the original ramp function.
1468     setCoroInfo(F, Shape, Clones);
1469     // After copying, set the linkage to internal linkage. Original function
1470     // may have different linkage, but optimization dependent on this function
1471     // generally relies on LTO.
1472     NoAllocF->setLinkage(llvm::GlobalValue::InternalLinkage);
1473     return NoAllocF;
1474   }
1475 
1476 private:
1477   // Create an entry block for a resume function with a switch that will jump to
1478   // suspend points.
createResumeEntryBlock__anon98d7ec870411::SwitchCoroutineSplitter1479   static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
1480     LLVMContext &C = F.getContext();
1481 
1482     DIBuilder DBuilder(*F.getParent(), /*AllowUnresolved*/ false);
1483     DISubprogram *DIS = F.getSubprogram();
1484     // If there is no DISubprogram for F, it implies the function is compiled
1485     // without debug info. So we also don't generate debug info for the
1486     // suspension points.
1487     bool AddDebugLabels = DIS && DIS->getUnit() &&
1488                           (DIS->getUnit()->getEmissionKind() ==
1489                            DICompileUnit::DebugEmissionKind::FullDebug);
1490 
1491     // resume.entry:
1492     //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32
1493     //  0, i32 2 % index = load i32, i32* %index.addr switch i32 %index, label
1494     //  %unreachable [
1495     //    i32 0, label %resume.0
1496     //    i32 1, label %resume.1
1497     //    ...
1498     //  ]
1499 
1500     auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
1501     auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
1502 
1503     IRBuilder<> Builder(NewEntry);
1504     auto *FramePtr = Shape.FramePtr;
1505     auto *FrameTy = Shape.FrameTy;
1506     auto *GepIndex = Builder.CreateStructGEP(
1507         FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
1508     auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
1509     auto *Switch =
1510         Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
1511     Shape.SwitchLowering.ResumeSwitch = Switch;
1512 
1513     // Split all coro.suspend calls
1514     size_t SuspendIndex = 0;
1515     for (auto *AnyS : Shape.CoroSuspends) {
1516       auto *S = cast<CoroSuspendInst>(AnyS);
1517       ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
1518 
1519       // Replace CoroSave with a store to Index:
1520       //    %index.addr = getelementptr %f.frame... (index field number)
1521       //    store i32 %IndexVal, i32* %index.addr1
1522       auto *Save = S->getCoroSave();
1523       Builder.SetInsertPoint(Save);
1524       if (S->isFinal()) {
1525         // The coroutine should be marked done if it reaches the final suspend
1526         // point.
1527         markCoroutineAsDone(Builder, Shape, FramePtr);
1528       } else {
1529         auto *GepIndex = Builder.CreateStructGEP(
1530             FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
1531         Builder.CreateStore(IndexVal, GepIndex);
1532       }
1533 
1534       Save->replaceAllUsesWith(ConstantTokenNone::get(C));
1535       Save->eraseFromParent();
1536 
1537       // Split block before and after coro.suspend and add a jump from an entry
1538       // switch:
1539       //
1540       //  whateverBB:
1541       //    whatever
1542       //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
1543       //    switch i8 %0, label %suspend[i8 0, label %resume
1544       //                                 i8 1, label %cleanup]
1545       // becomes:
1546       //
1547       //  whateverBB:
1548       //     whatever
1549       //     br label %resume.0.landing
1550       //
1551       //  resume.0: ; <--- jump from the switch in the resume.entry
1552       //        #dbg_label(...)  ; <--- artificial label for debuggers
1553       //     %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
1554       //     br label %resume.0.landing
1555       //
1556       //  resume.0.landing:
1557       //     %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
1558       //     switch i8 % 1, label %suspend [i8 0, label %resume
1559       //                                    i8 1, label %cleanup]
1560 
1561       auto *SuspendBB = S->getParent();
1562       auto *ResumeBB =
1563           SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
1564       auto *LandingBB = ResumeBB->splitBasicBlock(
1565           S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
1566       Switch->addCase(IndexVal, ResumeBB);
1567 
1568       cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
1569       auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "");
1570       PN->insertBefore(LandingBB->begin());
1571       S->replaceAllUsesWith(PN);
1572       PN->addIncoming(Builder.getInt8(-1), SuspendBB);
1573       PN->addIncoming(S, ResumeBB);
1574 
1575       if (AddDebugLabels) {
1576         if (DebugLoc SuspendLoc = S->getDebugLoc()) {
1577           std::string LabelName =
1578               ("__coro_resume_" + Twine(SuspendIndex)).str();
1579           DILocation &DILoc = *SuspendLoc.get();
1580           DILabel *ResumeLabel =
1581               DBuilder.createLabel(DIS, LabelName, DILoc.getFile(),
1582                                    SuspendLoc.getLine(), SuspendLoc.getCol(),
1583                                    /*IsArtificial=*/true,
1584                                    /*CoroSuspendIdx=*/SuspendIndex,
1585                                    /*AlwaysPreserve=*/false);
1586           DBuilder.insertLabel(ResumeLabel, &DILoc, ResumeBB->begin());
1587         }
1588       }
1589 
1590       ++SuspendIndex;
1591     }
1592 
1593     Builder.SetInsertPoint(UnreachBB);
1594     Builder.CreateUnreachable();
1595     DBuilder.finalize();
1596 
1597     Shape.SwitchLowering.ResumeEntryBlock = NewEntry;
1598   }
1599 
1600   // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
updateCoroFrame__anon98d7ec870411::SwitchCoroutineSplitter1601   static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
1602                               Function *DestroyFn, Function *CleanupFn) {
1603     IRBuilder<> Builder(&*Shape.getInsertPtAfterFramePtr());
1604 
1605     auto *ResumeAddr = Builder.CreateStructGEP(
1606         Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume,
1607         "resume.addr");
1608     Builder.CreateStore(ResumeFn, ResumeAddr);
1609 
1610     Value *DestroyOrCleanupFn = DestroyFn;
1611 
1612     CoroIdInst *CoroId = Shape.getSwitchCoroId();
1613     if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
1614       // If there is a CoroAlloc and it returns false (meaning we elide the
1615       // allocation, use CleanupFn instead of DestroyFn).
1616       DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
1617     }
1618 
1619     auto *DestroyAddr = Builder.CreateStructGEP(
1620         Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy,
1621         "destroy.addr");
1622     Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
1623   }
1624 
1625   // Create a global constant array containing pointers to functions provided
1626   // and set Info parameter of CoroBegin to point at this constant. Example:
1627   //
1628   //   @f.resumers = internal constant [2 x void(%f.frame*)*]
1629   //                    [void(%f.frame*)* @f.resume, void(%f.frame*)*
1630   //                    @f.destroy]
1631   //   define void @f() {
1632   //     ...
1633   //     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
1634   //                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to
1635   //                    i8*))
1636   //
1637   // Assumes that all the functions have the same signature.
setCoroInfo__anon98d7ec870411::SwitchCoroutineSplitter1638   static void setCoroInfo(Function &F, coro::Shape &Shape,
1639                           ArrayRef<Function *> Fns) {
1640     // This only works under the switch-lowering ABI because coro elision
1641     // only works on the switch-lowering ABI.
1642     SmallVector<Constant *, 4> Args(Fns);
1643     assert(!Args.empty());
1644     Function *Part = *Fns.begin();
1645     Module *M = Part->getParent();
1646     auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
1647 
1648     auto *ConstVal = ConstantArray::get(ArrTy, Args);
1649     auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
1650                                   GlobalVariable::PrivateLinkage, ConstVal,
1651                                   F.getName() + Twine(".resumers"));
1652 
1653     // Update coro.begin instruction to refer to this constant.
1654     LLVMContext &C = F.getContext();
1655     auto *BC = ConstantExpr::getPointerCast(GV, PointerType::getUnqual(C));
1656     Shape.getSwitchCoroId()->setInfo(BC);
1657   }
1658 };
1659 
1660 } // namespace
1661 
replaceAsyncResumeFunction(CoroSuspendAsyncInst * Suspend,Value * Continuation)1662 static void replaceAsyncResumeFunction(CoroSuspendAsyncInst *Suspend,
1663                                        Value *Continuation) {
1664   auto *ResumeIntrinsic = Suspend->getResumeFunction();
1665   auto &Context = Suspend->getParent()->getParent()->getContext();
1666   auto *Int8PtrTy = PointerType::getUnqual(Context);
1667 
1668   IRBuilder<> Builder(ResumeIntrinsic);
1669   auto *Val = Builder.CreateBitOrPointerCast(Continuation, Int8PtrTy);
1670   ResumeIntrinsic->replaceAllUsesWith(Val);
1671   ResumeIntrinsic->eraseFromParent();
1672   Suspend->setOperand(CoroSuspendAsyncInst::ResumeFunctionArg,
1673                       PoisonValue::get(Int8PtrTy));
1674 }
1675 
1676 /// Coerce the arguments in \p FnArgs according to \p FnTy in \p CallArgs.
coerceArguments(IRBuilder<> & Builder,FunctionType * FnTy,ArrayRef<Value * > FnArgs,SmallVectorImpl<Value * > & CallArgs)1677 static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy,
1678                             ArrayRef<Value *> FnArgs,
1679                             SmallVectorImpl<Value *> &CallArgs) {
1680   size_t ArgIdx = 0;
1681   for (auto *paramTy : FnTy->params()) {
1682     assert(ArgIdx < FnArgs.size());
1683     if (paramTy != FnArgs[ArgIdx]->getType())
1684       CallArgs.push_back(
1685           Builder.CreateBitOrPointerCast(FnArgs[ArgIdx], paramTy));
1686     else
1687       CallArgs.push_back(FnArgs[ArgIdx]);
1688     ++ArgIdx;
1689   }
1690 }
1691 
createMustTailCall(DebugLoc Loc,Function * MustTailCallFn,TargetTransformInfo & TTI,ArrayRef<Value * > Arguments,IRBuilder<> & Builder)1692 CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
1693                                    TargetTransformInfo &TTI,
1694                                    ArrayRef<Value *> Arguments,
1695                                    IRBuilder<> &Builder) {
1696   auto *FnTy = MustTailCallFn->getFunctionType();
1697   // Coerce the arguments, llvm optimizations seem to ignore the types in
1698   // vaarg functions and throws away casts in optimized mode.
1699   SmallVector<Value *, 8> CallArgs;
1700   coerceArguments(Builder, FnTy, Arguments, CallArgs);
1701 
1702   auto *TailCall = Builder.CreateCall(FnTy, MustTailCallFn, CallArgs);
1703   // Skip targets which don't support tail call.
1704   if (TTI.supportsTailCallFor(TailCall)) {
1705     TailCall->setTailCallKind(CallInst::TCK_MustTail);
1706   }
1707   TailCall->setDebugLoc(Loc);
1708   TailCall->setCallingConv(MustTailCallFn->getCallingConv());
1709   return TailCall;
1710 }
1711 
splitCoroutine(Function & F,coro::Shape & Shape,SmallVectorImpl<Function * > & Clones,TargetTransformInfo & TTI)1712 void coro::AsyncABI::splitCoroutine(Function &F, coro::Shape &Shape,
1713                                     SmallVectorImpl<Function *> &Clones,
1714                                     TargetTransformInfo &TTI) {
1715   assert(Shape.ABI == coro::ABI::Async);
1716   assert(Clones.empty());
1717   // Reset various things that the optimizer might have decided it
1718   // "knows" about the coroutine function due to not seeing a return.
1719   F.removeFnAttr(Attribute::NoReturn);
1720   F.removeRetAttr(Attribute::NoAlias);
1721   F.removeRetAttr(Attribute::NonNull);
1722 
1723   auto &Context = F.getContext();
1724   auto *Int8PtrTy = PointerType::getUnqual(Context);
1725 
1726   auto *Id = Shape.getAsyncCoroId();
1727   IRBuilder<> Builder(Id);
1728 
1729   auto *FramePtr = Id->getStorage();
1730   FramePtr = Builder.CreateBitOrPointerCast(FramePtr, Int8PtrTy);
1731   FramePtr = Builder.CreateConstInBoundsGEP1_32(
1732       Type::getInt8Ty(Context), FramePtr, Shape.AsyncLowering.FrameOffset,
1733       "async.ctx.frameptr");
1734 
1735   // Map all uses of llvm.coro.begin to the allocated frame pointer.
1736   {
1737     // Make sure we don't invalidate Shape.FramePtr.
1738     TrackingVH<Value> Handle(Shape.FramePtr);
1739     Shape.CoroBegin->replaceAllUsesWith(FramePtr);
1740     Shape.FramePtr = Handle.getValPtr();
1741   }
1742 
1743   // Create all the functions in order after the main function.
1744   auto NextF = std::next(F.getIterator());
1745 
1746   // Create a continuation function for each of the suspend points.
1747   Clones.reserve(Shape.CoroSuspends.size());
1748   for (auto [Idx, CS] : llvm::enumerate(Shape.CoroSuspends)) {
1749     auto *Suspend = cast<CoroSuspendAsyncInst>(CS);
1750 
1751     // Create the clone declaration.
1752     auto ResumeNameSuffix = ".resume.";
1753     auto ProjectionFunctionName =
1754         Suspend->getAsyncContextProjectionFunction()->getName();
1755     bool UseSwiftMangling = false;
1756     if (ProjectionFunctionName == "__swift_async_resume_project_context") {
1757       ResumeNameSuffix = "TQ";
1758       UseSwiftMangling = true;
1759     } else if (ProjectionFunctionName == "__swift_async_resume_get_context") {
1760       ResumeNameSuffix = "TY";
1761       UseSwiftMangling = true;
1762     }
1763     auto *Continuation = createCloneDeclaration(
1764         F, Shape,
1765         UseSwiftMangling ? ResumeNameSuffix + Twine(Idx) + "_"
1766                          : ResumeNameSuffix + Twine(Idx),
1767         NextF, Suspend);
1768     Clones.push_back(Continuation);
1769 
1770     // Insert a branch to a new return block immediately before the suspend
1771     // point.
1772     auto *SuspendBB = Suspend->getParent();
1773     auto *NewSuspendBB = SuspendBB->splitBasicBlock(Suspend);
1774     auto *Branch = cast<BranchInst>(SuspendBB->getTerminator());
1775 
1776     // Place it before the first suspend.
1777     auto *ReturnBB =
1778         BasicBlock::Create(F.getContext(), "coro.return", &F, NewSuspendBB);
1779     Branch->setSuccessor(0, ReturnBB);
1780 
1781     IRBuilder<> Builder(ReturnBB);
1782 
1783     // Insert the call to the tail call function and inline it.
1784     auto *Fn = Suspend->getMustTailCallFunction();
1785     SmallVector<Value *, 8> Args(Suspend->args());
1786     auto FnArgs = ArrayRef<Value *>(Args).drop_front(
1787         CoroSuspendAsyncInst::MustTailCallFuncArg + 1);
1788     auto *TailCall = coro::createMustTailCall(Suspend->getDebugLoc(), Fn, TTI,
1789                                               FnArgs, Builder);
1790     Builder.CreateRetVoid();
1791     InlineFunctionInfo FnInfo;
1792     (void)InlineFunction(*TailCall, FnInfo);
1793 
1794     // Replace the lvm.coro.async.resume intrisic call.
1795     replaceAsyncResumeFunction(Suspend, Continuation);
1796   }
1797 
1798   assert(Clones.size() == Shape.CoroSuspends.size());
1799 
1800   for (auto [Idx, CS] : llvm::enumerate(Shape.CoroSuspends)) {
1801     auto *Suspend = CS;
1802     auto *Clone = Clones[Idx];
1803 
1804     coro::BaseCloner::createClone(F, "resume." + Twine(Idx), Shape, Clone,
1805                                   Suspend, TTI);
1806   }
1807 }
1808 
splitCoroutine(Function & F,coro::Shape & Shape,SmallVectorImpl<Function * > & Clones,TargetTransformInfo & TTI)1809 void coro::AnyRetconABI::splitCoroutine(Function &F, coro::Shape &Shape,
1810                                         SmallVectorImpl<Function *> &Clones,
1811                                         TargetTransformInfo &TTI) {
1812   assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce);
1813   assert(Clones.empty());
1814 
1815   // Reset various things that the optimizer might have decided it
1816   // "knows" about the coroutine function due to not seeing a return.
1817   F.removeFnAttr(Attribute::NoReturn);
1818   F.removeRetAttr(Attribute::NoAlias);
1819   F.removeRetAttr(Attribute::NonNull);
1820 
1821   // Allocate the frame.
1822   auto *Id = Shape.getRetconCoroId();
1823   Value *RawFramePtr;
1824   if (Shape.RetconLowering.IsFrameInlineInStorage) {
1825     RawFramePtr = Id->getStorage();
1826   } else {
1827     IRBuilder<> Builder(Id);
1828 
1829     // Determine the size of the frame.
1830     const DataLayout &DL = F.getDataLayout();
1831     auto Size = DL.getTypeAllocSize(Shape.FrameTy);
1832 
1833     // Allocate.  We don't need to update the call graph node because we're
1834     // going to recompute it from scratch after splitting.
1835     // FIXME: pass the required alignment
1836     RawFramePtr = Shape.emitAlloc(Builder, Builder.getInt64(Size), nullptr);
1837     RawFramePtr =
1838         Builder.CreateBitCast(RawFramePtr, Shape.CoroBegin->getType());
1839 
1840     // Stash the allocated frame pointer in the continuation storage.
1841     Builder.CreateStore(RawFramePtr, Id->getStorage());
1842   }
1843 
1844   // Map all uses of llvm.coro.begin to the allocated frame pointer.
1845   {
1846     // Make sure we don't invalidate Shape.FramePtr.
1847     TrackingVH<Value> Handle(Shape.FramePtr);
1848     Shape.CoroBegin->replaceAllUsesWith(RawFramePtr);
1849     Shape.FramePtr = Handle.getValPtr();
1850   }
1851 
1852   // Create a unique return block.
1853   BasicBlock *ReturnBB = nullptr;
1854   PHINode *ContinuationPhi = nullptr;
1855   SmallVector<PHINode *, 4> ReturnPHIs;
1856 
1857   // Create all the functions in order after the main function.
1858   auto NextF = std::next(F.getIterator());
1859 
1860   // Create a continuation function for each of the suspend points.
1861   Clones.reserve(Shape.CoroSuspends.size());
1862   for (auto [Idx, CS] : llvm::enumerate(Shape.CoroSuspends)) {
1863     auto Suspend = cast<CoroSuspendRetconInst>(CS);
1864 
1865     // Create the clone declaration.
1866     auto Continuation = createCloneDeclaration(
1867         F, Shape, ".resume." + Twine(Idx), NextF, nullptr);
1868     Clones.push_back(Continuation);
1869 
1870     // Insert a branch to the unified return block immediately before
1871     // the suspend point.
1872     auto SuspendBB = Suspend->getParent();
1873     auto NewSuspendBB = SuspendBB->splitBasicBlock(Suspend);
1874     auto Branch = cast<BranchInst>(SuspendBB->getTerminator());
1875 
1876     // Create the unified return block.
1877     if (!ReturnBB) {
1878       // Place it before the first suspend.
1879       ReturnBB =
1880           BasicBlock::Create(F.getContext(), "coro.return", &F, NewSuspendBB);
1881       Shape.RetconLowering.ReturnBlock = ReturnBB;
1882 
1883       IRBuilder<> Builder(ReturnBB);
1884 
1885       // First, the continuation.
1886       ContinuationPhi =
1887           Builder.CreatePHI(Continuation->getType(), Shape.CoroSuspends.size());
1888 
1889       // Create PHIs for all other return values.
1890       assert(ReturnPHIs.empty());
1891 
1892       // Next, all the directly-yielded values.
1893       for (auto *ResultTy : Shape.getRetconResultTypes())
1894         ReturnPHIs.push_back(
1895             Builder.CreatePHI(ResultTy, Shape.CoroSuspends.size()));
1896 
1897       // Build the return value.
1898       auto RetTy = F.getReturnType();
1899 
1900       // Cast the continuation value if necessary.
1901       // We can't rely on the types matching up because that type would
1902       // have to be infinite.
1903       auto CastedContinuationTy =
1904           (ReturnPHIs.empty() ? RetTy : RetTy->getStructElementType(0));
1905       auto *CastedContinuation =
1906           Builder.CreateBitCast(ContinuationPhi, CastedContinuationTy);
1907 
1908       Value *RetV = CastedContinuation;
1909       if (!ReturnPHIs.empty()) {
1910         auto ValueIdx = 0;
1911         RetV = PoisonValue::get(RetTy);
1912         RetV = Builder.CreateInsertValue(RetV, CastedContinuation, ValueIdx++);
1913 
1914         for (auto Phi : ReturnPHIs)
1915           RetV = Builder.CreateInsertValue(RetV, Phi, ValueIdx++);
1916       }
1917 
1918       Builder.CreateRet(RetV);
1919     }
1920 
1921     // Branch to the return block.
1922     Branch->setSuccessor(0, ReturnBB);
1923     assert(ContinuationPhi);
1924     ContinuationPhi->addIncoming(Continuation, SuspendBB);
1925     for (auto [Phi, VUse] :
1926          llvm::zip_equal(ReturnPHIs, Suspend->value_operands()))
1927       Phi->addIncoming(VUse, SuspendBB);
1928   }
1929 
1930   assert(Clones.size() == Shape.CoroSuspends.size());
1931 
1932   for (auto [Idx, CS] : llvm::enumerate(Shape.CoroSuspends)) {
1933     auto Suspend = CS;
1934     auto Clone = Clones[Idx];
1935 
1936     coro::BaseCloner::createClone(F, "resume." + Twine(Idx), Shape, Clone,
1937                                   Suspend, TTI);
1938   }
1939 }
1940 
1941 namespace {
1942 class PrettyStackTraceFunction : public PrettyStackTraceEntry {
1943   Function &F;
1944 
1945 public:
PrettyStackTraceFunction(Function & F)1946   PrettyStackTraceFunction(Function &F) : F(F) {}
print(raw_ostream & OS) const1947   void print(raw_ostream &OS) const override {
1948     OS << "While splitting coroutine ";
1949     F.printAsOperand(OS, /*print type*/ false, F.getParent());
1950     OS << "\n";
1951   }
1952 };
1953 } // namespace
1954 
1955 /// Remove calls to llvm.coro.end in the original function.
removeCoroEndsFromRampFunction(const coro::Shape & Shape)1956 static void removeCoroEndsFromRampFunction(const coro::Shape &Shape) {
1957   if (Shape.ABI != coro::ABI::Switch) {
1958     for (auto *End : Shape.CoroEnds) {
1959       replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr);
1960     }
1961   } else {
1962     for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) {
1963       auto &Context = End->getContext();
1964       End->replaceAllUsesWith(ConstantInt::getFalse(Context));
1965       End->eraseFromParent();
1966     }
1967   }
1968 }
1969 
hasSafeElideCaller(Function & F)1970 static bool hasSafeElideCaller(Function &F) {
1971   for (auto *U : F.users()) {
1972     if (auto *CB = dyn_cast<CallBase>(U)) {
1973       auto *Caller = CB->getFunction();
1974       if (Caller && Caller->isPresplitCoroutine() &&
1975           CB->hasFnAttr(llvm::Attribute::CoroElideSafe))
1976         return true;
1977     }
1978   }
1979   return false;
1980 }
1981 
/// Split a switch-ABI coroutine. All work is delegated to
/// SwitchCoroutineSplitter::split, which receives the arguments unchanged.
void coro::SwitchABI::splitCoroutine(Function &F, coro::Shape &Shape,
                                     SmallVectorImpl<Function *> &Clones,
                                     TargetTransformInfo &TTI) {
  SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
}
1987 
doSplitCoroutine(Function & F,SmallVectorImpl<Function * > & Clones,coro::BaseABI & ABI,TargetTransformInfo & TTI,bool OptimizeFrame)1988 static void doSplitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
1989                              coro::BaseABI &ABI, TargetTransformInfo &TTI,
1990                              bool OptimizeFrame) {
1991   PrettyStackTraceFunction prettyStackTrace(F);
1992 
1993   auto &Shape = ABI.Shape;
1994   assert(Shape.CoroBegin);
1995 
1996   lowerAwaitSuspends(F, Shape);
1997 
1998   simplifySuspendPoints(Shape);
1999 
2000   normalizeCoroutine(F, Shape, TTI);
2001   ABI.buildCoroutineFrame(OptimizeFrame);
2002   replaceFrameSizeAndAlignment(Shape);
2003 
2004   bool isNoSuspendCoroutine = Shape.CoroSuspends.empty();
2005 
2006   bool shouldCreateNoAllocVariant =
2007       !isNoSuspendCoroutine && Shape.ABI == coro::ABI::Switch &&
2008       hasSafeElideCaller(F) && !F.hasFnAttribute(llvm::Attribute::NoInline);
2009 
2010   // If there are no suspend points, no split required, just remove
2011   // the allocation and deallocation blocks, they are not needed.
2012   if (isNoSuspendCoroutine) {
2013     handleNoSuspendCoroutine(Shape);
2014   } else {
2015     ABI.splitCoroutine(F, Shape, Clones, TTI);
2016   }
2017 
2018   // Replace all the swifterror operations in the original function.
2019   // This invalidates SwiftErrorOps in the Shape.
2020   replaceSwiftErrorOps(F, Shape, nullptr);
2021 
2022   // Salvage debug intrinsics that point into the coroutine frame in the
2023   // original function. The Cloner has already salvaged debug info in the new
2024   // coroutine funclets.
2025   SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap;
2026   auto [DbgInsts, DbgVariableRecords] = collectDbgVariableIntrinsics(F);
2027   for (auto *DDI : DbgInsts)
2028     coro::salvageDebugInfo(ArgToAllocaMap, *DDI, false /*UseEntryValue*/);
2029   for (DbgVariableRecord *DVR : DbgVariableRecords)
2030     coro::salvageDebugInfo(ArgToAllocaMap, *DVR, false /*UseEntryValue*/);
2031 
2032   removeCoroEndsFromRampFunction(Shape);
2033 
2034   if (shouldCreateNoAllocVariant)
2035     SwitchCoroutineSplitter::createNoAllocVariant(F, Shape, Clones);
2036 }
2037 
updateCallGraphAfterCoroutineSplit(LazyCallGraph::Node & N,const coro::Shape & Shape,const SmallVectorImpl<Function * > & Clones,LazyCallGraph::SCC & C,LazyCallGraph & CG,CGSCCAnalysisManager & AM,CGSCCUpdateResult & UR,FunctionAnalysisManager & FAM)2038 static LazyCallGraph::SCC &updateCallGraphAfterCoroutineSplit(
2039     LazyCallGraph::Node &N, const coro::Shape &Shape,
2040     const SmallVectorImpl<Function *> &Clones, LazyCallGraph::SCC &C,
2041     LazyCallGraph &CG, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR,
2042     FunctionAnalysisManager &FAM) {
2043 
2044   auto *CurrentSCC = &C;
2045   if (!Clones.empty()) {
2046     switch (Shape.ABI) {
2047     case coro::ABI::Switch:
2048       // Each clone in the Switch lowering is independent of the other clones.
2049       // Let the LazyCallGraph know about each one separately.
2050       for (Function *Clone : Clones)
2051         CG.addSplitFunction(N.getFunction(), *Clone);
2052       break;
2053     case coro::ABI::Async:
2054     case coro::ABI::Retcon:
2055     case coro::ABI::RetconOnce:
2056       // Each clone in the Async/Retcon lowering references of the other clones.
2057       // Let the LazyCallGraph know about all of them at once.
2058       if (!Clones.empty())
2059         CG.addSplitRefRecursiveFunctions(N.getFunction(), Clones);
2060       break;
2061     }
2062 
2063     // Let the CGSCC infra handle the changes to the original function.
2064     CurrentSCC = &updateCGAndAnalysisManagerForCGSCCPass(CG, *CurrentSCC, N, AM,
2065                                                          UR, FAM);
2066   }
2067 
2068   // Do some cleanup and let the CGSCC infra see if we've cleaned up any edges
2069   // to the split functions.
2070   postSplitCleanup(N.getFunction());
2071   CurrentSCC = &updateCGAndAnalysisManagerForFunctionPass(CG, *CurrentSCC, N,
2072                                                           AM, UR, FAM);
2073   return *CurrentSCC;
2074 }
2075 
2076 /// Replace a call to llvm.coro.prepare.retcon.
replacePrepare(CallInst * Prepare,LazyCallGraph & CG,LazyCallGraph::SCC & C)2077 static void replacePrepare(CallInst *Prepare, LazyCallGraph &CG,
2078                            LazyCallGraph::SCC &C) {
2079   auto CastFn = Prepare->getArgOperand(0); // as an i8*
2080   auto Fn = CastFn->stripPointerCasts();   // as its original type
2081 
2082   // Attempt to peephole this pattern:
2083   //    %0 = bitcast [[TYPE]] @some_function to i8*
2084   //    %1 = call @llvm.coro.prepare.retcon(i8* %0)
2085   //    %2 = bitcast %1 to [[TYPE]]
2086   // ==>
2087   //    %2 = @some_function
2088   for (Use &U : llvm::make_early_inc_range(Prepare->uses())) {
2089     // Look for bitcasts back to the original function type.
2090     auto *Cast = dyn_cast<BitCastInst>(U.getUser());
2091     if (!Cast || Cast->getType() != Fn->getType())
2092       continue;
2093 
2094     // Replace and remove the cast.
2095     Cast->replaceAllUsesWith(Fn);
2096     Cast->eraseFromParent();
2097   }
2098 
2099   // Replace any remaining uses with the function as an i8*.
2100   // This can never directly be a callee, so we don't need to update CG.
2101   Prepare->replaceAllUsesWith(CastFn);
2102   Prepare->eraseFromParent();
2103 
2104   // Kill dead bitcasts.
2105   while (auto *Cast = dyn_cast<BitCastInst>(CastFn)) {
2106     if (!Cast->use_empty())
2107       break;
2108     CastFn = Cast->getOperand(0);
2109     Cast->eraseFromParent();
2110   }
2111 }
2112 
replaceAllPrepares(Function * PrepareFn,LazyCallGraph & CG,LazyCallGraph::SCC & C)2113 static bool replaceAllPrepares(Function *PrepareFn, LazyCallGraph &CG,
2114                                LazyCallGraph::SCC &C) {
2115   bool Changed = false;
2116   for (Use &P : llvm::make_early_inc_range(PrepareFn->uses())) {
2117     // Intrinsics can only be used in calls.
2118     auto *Prepare = cast<CallInst>(P.getUser());
2119     replacePrepare(Prepare, CG, C);
2120     Changed = true;
2121   }
2122 
2123   return Changed;
2124 }
2125 
addPrepareFunction(const Module & M,SmallVectorImpl<Function * > & Fns,StringRef Name)2126 static void addPrepareFunction(const Module &M,
2127                                SmallVectorImpl<Function *> &Fns,
2128                                StringRef Name) {
2129   auto *PrepareFn = M.getFunction(Name);
2130   if (PrepareFn && !PrepareFn->use_empty())
2131     Fns.push_back(PrepareFn);
2132 }
2133 
2134 static std::unique_ptr<coro::BaseABI>
CreateNewABI(Function & F,coro::Shape & S,std::function<bool (Instruction &)> IsMatCallback,const SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs)2135 CreateNewABI(Function &F, coro::Shape &S,
2136              std::function<bool(Instruction &)> IsMatCallback,
2137              const SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs) {
2138   if (S.CoroBegin->hasCustomABI()) {
2139     unsigned CustomABI = S.CoroBegin->getCustomABI();
2140     if (CustomABI >= GenCustomABIs.size())
2141       llvm_unreachable("Custom ABI not found amoung those specified");
2142     return GenCustomABIs[CustomABI](F, S);
2143   }
2144 
2145   switch (S.ABI) {
2146   case coro::ABI::Switch:
2147     return std::make_unique<coro::SwitchABI>(F, S, IsMatCallback);
2148   case coro::ABI::Async:
2149     return std::make_unique<coro::AsyncABI>(F, S, IsMatCallback);
2150   case coro::ABI::Retcon:
2151     return std::make_unique<coro::AnyRetconABI>(F, S, IsMatCallback);
2152   case coro::ABI::RetconOnce:
2153     return std::make_unique<coro::AnyRetconABI>(F, S, IsMatCallback);
2154   }
2155   llvm_unreachable("Unknown ABI");
2156 }
2157 
CoroSplitPass(bool OptimizeFrame)2158 CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
2159     : CreateAndInitABI([](Function &F, coro::Shape &S) {
2160         std::unique_ptr<coro::BaseABI> ABI =
2161             CreateNewABI(F, S, coro::isTriviallyMaterializable, {});
2162         ABI->init();
2163         return ABI;
2164       }),
2165       OptimizeFrame(OptimizeFrame) {}
2166 
CoroSplitPass(SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs,bool OptimizeFrame)2167 CoroSplitPass::CoroSplitPass(
2168     SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
2169     : CreateAndInitABI([=](Function &F, coro::Shape &S) {
2170         std::unique_ptr<coro::BaseABI> ABI =
2171             CreateNewABI(F, S, coro::isTriviallyMaterializable, GenCustomABIs);
2172         ABI->init();
2173         return ABI;
2174       }),
2175       OptimizeFrame(OptimizeFrame) {}
2176 
2177 // For back compatibility, constructor takes a materializable callback and
2178 // creates a generator for an ABI with a modified materializable callback.
CoroSplitPass(std::function<bool (Instruction &)> IsMatCallback,bool OptimizeFrame)2179 CoroSplitPass::CoroSplitPass(std::function<bool(Instruction &)> IsMatCallback,
2180                              bool OptimizeFrame)
2181     : CreateAndInitABI([=](Function &F, coro::Shape &S) {
2182         std::unique_ptr<coro::BaseABI> ABI =
2183             CreateNewABI(F, S, IsMatCallback, {});
2184         ABI->init();
2185         return ABI;
2186       }),
2187       OptimizeFrame(OptimizeFrame) {}
2188 
2189 // For back compatibility, constructor takes a materializable callback and
2190 // creates a generator for an ABI with a modified materializable callback.
CoroSplitPass(std::function<bool (Instruction &)> IsMatCallback,SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs,bool OptimizeFrame)2191 CoroSplitPass::CoroSplitPass(
2192     std::function<bool(Instruction &)> IsMatCallback,
2193     SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame)
2194     : CreateAndInitABI([=](Function &F, coro::Shape &S) {
2195         std::unique_ptr<coro::BaseABI> ABI =
2196             CreateNewABI(F, S, IsMatCallback, GenCustomABIs);
2197         ABI->init();
2198         return ABI;
2199       }),
2200       OptimizeFrame(OptimizeFrame) {}
2201 
run(LazyCallGraph::SCC & C,CGSCCAnalysisManager & AM,LazyCallGraph & CG,CGSCCUpdateResult & UR)2202 PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
2203                                      CGSCCAnalysisManager &AM,
2204                                      LazyCallGraph &CG, CGSCCUpdateResult &UR) {
2205   // NB: One invariant of a valid LazyCallGraph::SCC is that it must contain a
2206   //     non-zero number of nodes, so we assume that here and grab the first
2207   //     node's function's module.
2208   Module &M = *C.begin()->getFunction().getParent();
2209   auto &FAM =
2210       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
2211 
2212   // Check for uses of llvm.coro.prepare.retcon/async.
2213   SmallVector<Function *, 2> PrepareFns;
2214   addPrepareFunction(M, PrepareFns, "llvm.coro.prepare.retcon");
2215   addPrepareFunction(M, PrepareFns, "llvm.coro.prepare.async");
2216 
2217   // Find coroutines for processing.
2218   SmallVector<LazyCallGraph::Node *> Coroutines;
2219   for (LazyCallGraph::Node &N : C)
2220     if (N.getFunction().isPresplitCoroutine())
2221       Coroutines.push_back(&N);
2222 
2223   if (Coroutines.empty() && PrepareFns.empty())
2224     return PreservedAnalyses::all();
2225 
2226   auto *CurrentSCC = &C;
2227   // Split all the coroutines.
2228   for (LazyCallGraph::Node *N : Coroutines) {
2229     Function &F = N->getFunction();
2230     LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F.getName()
2231                       << "\n");
2232 
2233     // The suspend-crossing algorithm in buildCoroutineFrame gets tripped up
2234     // by unreachable blocks, so remove them as a first pass. Remove the
2235     // unreachable blocks before collecting intrinsics into Shape.
2236     removeUnreachableBlocks(F);
2237 
2238     coro::Shape Shape(F);
2239     if (!Shape.CoroBegin)
2240       continue;
2241 
2242     F.setSplittedCoroutine();
2243 
2244     std::unique_ptr<coro::BaseABI> ABI = CreateAndInitABI(F, Shape);
2245 
2246     SmallVector<Function *, 4> Clones;
2247     auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
2248     doSplitCoroutine(F, Clones, *ABI, TTI, OptimizeFrame);
2249     CurrentSCC = &updateCallGraphAfterCoroutineSplit(
2250         *N, Shape, Clones, *CurrentSCC, CG, AM, UR, FAM);
2251 
2252     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
2253     ORE.emit([&]() {
2254       return OptimizationRemark(DEBUG_TYPE, "CoroSplit", &F)
2255              << "Split '" << ore::NV("function", F.getName())
2256              << "' (frame_size=" << ore::NV("frame_size", Shape.FrameSize)
2257              << ", align=" << ore::NV("align", Shape.FrameAlign.value()) << ")";
2258     });
2259 
2260     if (!Shape.CoroSuspends.empty()) {
2261       // Run the CGSCC pipeline on the original and newly split functions.
2262       UR.CWorklist.insert(CurrentSCC);
2263       for (Function *Clone : Clones)
2264         UR.CWorklist.insert(CG.lookupSCC(CG.get(*Clone)));
2265     }
2266   }
2267 
2268   for (auto *PrepareFn : PrepareFns) {
2269     replaceAllPrepares(PrepareFn, CG, *CurrentSCC);
2270   }
2271 
2272   return PreservedAnalyses::none();
2273 }
2274