xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- Coroutines.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the common infrastructure for Coroutine Passes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CoroInstr.h"
14 #include "CoroInternal.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/Analysis/CallGraph.h"
18 #include "llvm/IR/Attributes.h"
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/DerivedTypes.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/InstIterator.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/IR/Intrinsics.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/IR/Type.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Transforms/Utils/Local.h"
31 #include <cassert>
32 #include <cstddef>
33 #include <utility>
34 
35 using namespace llvm;
36 
37 // Construct the lowerer base class and initialize its members.
LowererBase(Module & M)38 coro::LowererBase::LowererBase(Module &M)
39     : TheModule(M), Context(M.getContext()),
40       Int8Ptr(PointerType::get(Context, 0)),
41       ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
42                                      /*isVarArg=*/false)),
43       NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
44 
45 // Creates a call to llvm.coro.subfn.addr to obtain a resume function address.
46 // It generates the following:
47 //
48 //    call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index)
49 
makeSubFnCall(Value * Arg,int Index,Instruction * InsertPt)50 CallInst *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
51                                            Instruction *InsertPt) {
52   auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
53   auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
54 
55   assert(Index >= CoroSubFnInst::IndexFirst &&
56          Index < CoroSubFnInst::IndexLast &&
57          "makeSubFnCall: Index value out of range");
58   return CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt->getIterator());
59 }
60 
// Names of every coroutine intrinsic this file knows about. The table is
// scanned by coro::declaresAnyIntrinsic and, in debug builds, searched by
// isCoroutineIntrinsicName through Intrinsic::lookupLLVMIntrinsicByName.
//
// NOTE: Must be sorted! lookupLLVMIntrinsicByName expects a lexicographically
// sorted table, so keep new entries in strict order.
static const char *const CoroIntrinsics[] = {
    "llvm.coro.align",
    "llvm.coro.alloc",
    // llvm.coro.async.*
    "llvm.coro.async.context.alloc", "llvm.coro.async.context.dealloc",
    "llvm.coro.async.resume", "llvm.coro.async.size.replace",
    "llvm.coro.async.store_resume",
    // llvm.coro.await.suspend.*
    "llvm.coro.await.suspend.bool", "llvm.coro.await.suspend.handle",
    "llvm.coro.await.suspend.void",
    "llvm.coro.begin",
    "llvm.coro.destroy",
    "llvm.coro.done",
    // llvm.coro.end*
    "llvm.coro.end", "llvm.coro.end.async",
    "llvm.coro.frame",
    "llvm.coro.free",
    // llvm.coro.id*
    "llvm.coro.id", "llvm.coro.id.async", "llvm.coro.id.retcon",
    "llvm.coro.id.retcon.once",
    "llvm.coro.noop",
    // llvm.coro.prepare.*
    "llvm.coro.prepare.async", "llvm.coro.prepare.retcon",
    "llvm.coro.promise",
    "llvm.coro.resume",
    "llvm.coro.save",
    "llvm.coro.size",
    "llvm.coro.subfn.addr",
    // llvm.coro.suspend*
    "llvm.coro.suspend", "llvm.coro.suspend.async",
    "llvm.coro.suspend.retcon",
};
96 
97 #ifndef NDEBUG
isCoroutineIntrinsicName(StringRef Name)98 static bool isCoroutineIntrinsicName(StringRef Name) {
99   return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
100 }
101 #endif
102 
declaresAnyIntrinsic(const Module & M)103 bool coro::declaresAnyIntrinsic(const Module &M) {
104   for (StringRef Name : CoroIntrinsics) {
105     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
106     if (M.getNamedValue(Name))
107       return true;
108   }
109 
110   return false;
111 }
112 
113 // Verifies if a module has named values listed. Also, in debug mode verifies
114 // that names are intrinsic names.
declaresIntrinsics(const Module & M,const std::initializer_list<StringRef> List)115 bool coro::declaresIntrinsics(const Module &M,
116                               const std::initializer_list<StringRef> List) {
117   for (StringRef Name : List) {
118     assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
119     if (M.getNamedValue(Name))
120       return true;
121   }
122 
123   return false;
124 }
125 
126 // Replace all coro.frees associated with the provided CoroId either with 'null'
127 // if Elide is true and with its frame parameter otherwise.
replaceCoroFree(CoroIdInst * CoroId,bool Elide)128 void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
129   SmallVector<CoroFreeInst *, 4> CoroFrees;
130   for (User *U : CoroId->users())
131     if (auto CF = dyn_cast<CoroFreeInst>(U))
132       CoroFrees.push_back(CF);
133 
134   if (CoroFrees.empty())
135     return;
136 
137   Value *Replacement =
138       Elide
139           ? ConstantPointerNull::get(PointerType::get(CoroId->getContext(), 0))
140           : CoroFrees.front()->getFrame();
141 
142   for (CoroFreeInst *CF : CoroFrees) {
143     CF->replaceAllUsesWith(Replacement);
144     CF->eraseFromParent();
145   }
146 }
147 
clear(coro::Shape & Shape)148 static void clear(coro::Shape &Shape) {
149   Shape.CoroBegin = nullptr;
150   Shape.CoroEnds.clear();
151   Shape.CoroSizes.clear();
152   Shape.CoroSuspends.clear();
153 
154   Shape.FrameTy = nullptr;
155   Shape.FramePtr = nullptr;
156   Shape.AllocaSpillBlock = nullptr;
157 }
158 
createCoroSave(CoroBeginInst * CoroBegin,CoroSuspendInst * SuspendInst)159 static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
160                                     CoroSuspendInst *SuspendInst) {
161   Module *M = SuspendInst->getModule();
162   auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
163   auto *SaveInst = cast<CoroSaveInst>(
164       CallInst::Create(Fn, CoroBegin, "", SuspendInst->getIterator()));
165   assert(!SuspendInst->getCoroSave());
166   SuspendInst->setArgOperand(0, SaveInst);
167   return SaveInst;
168 }
169 
170 // Collect "interesting" coroutine intrinsics.
buildFrom(Function & F)171 void coro::Shape::buildFrom(Function &F) {
172   bool HasFinalSuspend = false;
173   bool HasUnwindCoroEnd = false;
174   size_t FinalSuspendIndex = 0;
175   clear(*this);
176   SmallVector<CoroFrameInst *, 8> CoroFrames;
177   SmallVector<CoroSaveInst *, 2> UnusedCoroSaves;
178 
179   for (Instruction &I : instructions(F)) {
180     // FIXME: coro_await_suspend_* are not proper `IntrinisicInst`s
181     // because they might be invoked
182     if (auto AWS = dyn_cast<CoroAwaitSuspendInst>(&I)) {
183       CoroAwaitSuspends.push_back(AWS);
184     } else if (auto II = dyn_cast<IntrinsicInst>(&I)) {
185       switch (II->getIntrinsicID()) {
186       default:
187         continue;
188       case Intrinsic::coro_size:
189         CoroSizes.push_back(cast<CoroSizeInst>(II));
190         break;
191       case Intrinsic::coro_align:
192         CoroAligns.push_back(cast<CoroAlignInst>(II));
193         break;
194       case Intrinsic::coro_frame:
195         CoroFrames.push_back(cast<CoroFrameInst>(II));
196         break;
197       case Intrinsic::coro_save:
198         // After optimizations, coro_suspends using this coro_save might have
199         // been removed, remember orphaned coro_saves to remove them later.
200         if (II->use_empty())
201           UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
202         break;
203       case Intrinsic::coro_suspend_async: {
204         auto *Suspend = cast<CoroSuspendAsyncInst>(II);
205         Suspend->checkWellFormed();
206         CoroSuspends.push_back(Suspend);
207         break;
208       }
209       case Intrinsic::coro_suspend_retcon: {
210         auto Suspend = cast<CoroSuspendRetconInst>(II);
211         CoroSuspends.push_back(Suspend);
212         break;
213       }
214       case Intrinsic::coro_suspend: {
215         auto Suspend = cast<CoroSuspendInst>(II);
216         CoroSuspends.push_back(Suspend);
217         if (Suspend->isFinal()) {
218           if (HasFinalSuspend)
219             report_fatal_error(
220               "Only one suspend point can be marked as final");
221           HasFinalSuspend = true;
222           FinalSuspendIndex = CoroSuspends.size() - 1;
223         }
224         break;
225       }
226       case Intrinsic::coro_begin: {
227         auto CB = cast<CoroBeginInst>(II);
228 
229         // Ignore coro id's that aren't pre-split.
230         auto Id = dyn_cast<CoroIdInst>(CB->getId());
231         if (Id && !Id->getInfo().isPreSplit())
232           break;
233 
234         if (CoroBegin)
235           report_fatal_error(
236                 "coroutine should have exactly one defining @llvm.coro.begin");
237         CB->addRetAttr(Attribute::NonNull);
238         CB->addRetAttr(Attribute::NoAlias);
239         CB->removeFnAttr(Attribute::NoDuplicate);
240         CoroBegin = CB;
241         break;
242       }
243       case Intrinsic::coro_end_async:
244       case Intrinsic::coro_end:
245         CoroEnds.push_back(cast<AnyCoroEndInst>(II));
246         if (auto *AsyncEnd = dyn_cast<CoroAsyncEndInst>(II)) {
247           AsyncEnd->checkWellFormed();
248         }
249 
250         if (CoroEnds.back()->isUnwind())
251           HasUnwindCoroEnd = true;
252 
253         if (CoroEnds.back()->isFallthrough() && isa<CoroEndInst>(II)) {
254           // Make sure that the fallthrough coro.end is the first element in the
255           // CoroEnds vector.
256           // Note: I don't think this is neccessary anymore.
257           if (CoroEnds.size() > 1) {
258             if (CoroEnds.front()->isFallthrough())
259               report_fatal_error(
260                   "Only one coro.end can be marked as fallthrough");
261             std::swap(CoroEnds.front(), CoroEnds.back());
262           }
263         }
264         break;
265       }
266     }
267   }
268 
269   // If for some reason, we were not able to find coro.begin, bailout.
270   if (!CoroBegin) {
271     // Replace coro.frame which are supposed to be lowered to the result of
272     // coro.begin with undef.
273     auto *Undef = UndefValue::get(PointerType::get(F.getContext(), 0));
274     for (CoroFrameInst *CF : CoroFrames) {
275       CF->replaceAllUsesWith(Undef);
276       CF->eraseFromParent();
277     }
278 
279     // Replace all coro.suspend with undef and remove related coro.saves if
280     // present.
281     for (AnyCoroSuspendInst *CS : CoroSuspends) {
282       CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
283       CS->eraseFromParent();
284       if (auto *CoroSave = CS->getCoroSave())
285         CoroSave->eraseFromParent();
286     }
287 
288     // Replace all coro.ends with unreachable instruction.
289     for (AnyCoroEndInst *CE : CoroEnds)
290       changeToUnreachable(CE);
291 
292     return;
293   }
294 
295   auto Id = CoroBegin->getId();
296   switch (auto IdIntrinsic = Id->getIntrinsicID()) {
297   case Intrinsic::coro_id: {
298     auto SwitchId = cast<CoroIdInst>(Id);
299     this->ABI = coro::ABI::Switch;
300     this->SwitchLowering.HasFinalSuspend = HasFinalSuspend;
301     this->SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd;
302     this->SwitchLowering.ResumeSwitch = nullptr;
303     this->SwitchLowering.PromiseAlloca = SwitchId->getPromise();
304     this->SwitchLowering.ResumeEntryBlock = nullptr;
305 
306     for (auto *AnySuspend : CoroSuspends) {
307       auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
308       if (!Suspend) {
309 #ifndef NDEBUG
310         AnySuspend->dump();
311 #endif
312         report_fatal_error("coro.id must be paired with coro.suspend");
313       }
314 
315       if (!Suspend->getCoroSave())
316         createCoroSave(CoroBegin, Suspend);
317     }
318     break;
319   }
320   case Intrinsic::coro_id_async: {
321     auto *AsyncId = cast<CoroIdAsyncInst>(Id);
322     AsyncId->checkWellFormed();
323     this->ABI = coro::ABI::Async;
324     this->AsyncLowering.Context = AsyncId->getStorage();
325     this->AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
326     this->AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
327     this->AsyncLowering.ContextAlignment =
328         AsyncId->getStorageAlignment().value();
329     this->AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
330     this->AsyncLowering.AsyncCC = F.getCallingConv();
331     break;
332   };
333   case Intrinsic::coro_id_retcon:
334   case Intrinsic::coro_id_retcon_once: {
335     auto ContinuationId = cast<AnyCoroIdRetconInst>(Id);
336     ContinuationId->checkWellFormed();
337     this->ABI = (IdIntrinsic == Intrinsic::coro_id_retcon
338                   ? coro::ABI::Retcon
339                   : coro::ABI::RetconOnce);
340     auto Prototype = ContinuationId->getPrototype();
341     this->RetconLowering.ResumePrototype = Prototype;
342     this->RetconLowering.Alloc = ContinuationId->getAllocFunction();
343     this->RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
344     this->RetconLowering.ReturnBlock = nullptr;
345     this->RetconLowering.IsFrameInlineInStorage = false;
346 
347     // Determine the result value types, and make sure they match up with
348     // the values passed to the suspends.
349     auto ResultTys = getRetconResultTypes();
350     auto ResumeTys = getRetconResumeTypes();
351 
352     for (auto *AnySuspend : CoroSuspends) {
353       auto Suspend = dyn_cast<CoroSuspendRetconInst>(AnySuspend);
354       if (!Suspend) {
355 #ifndef NDEBUG
356         AnySuspend->dump();
357 #endif
358         report_fatal_error("coro.id.retcon.* must be paired with "
359                            "coro.suspend.retcon");
360       }
361 
362       // Check that the argument types of the suspend match the results.
363       auto SI = Suspend->value_begin(), SE = Suspend->value_end();
364       auto RI = ResultTys.begin(), RE = ResultTys.end();
365       for (; SI != SE && RI != RE; ++SI, ++RI) {
366         auto SrcTy = (*SI)->getType();
367         if (SrcTy != *RI) {
368           // The optimizer likes to eliminate bitcasts leading into variadic
369           // calls, but that messes with our invariants.  Re-insert the
370           // bitcast and ignore this type mismatch.
371           if (CastInst::isBitCastable(SrcTy, *RI)) {
372             auto BCI = new BitCastInst(*SI, *RI, "", Suspend->getIterator());
373             SI->set(BCI);
374             continue;
375           }
376 
377 #ifndef NDEBUG
378           Suspend->dump();
379           Prototype->getFunctionType()->dump();
380 #endif
381           report_fatal_error("argument to coro.suspend.retcon does not "
382                              "match corresponding prototype function result");
383         }
384       }
385       if (SI != SE || RI != RE) {
386 #ifndef NDEBUG
387         Suspend->dump();
388         Prototype->getFunctionType()->dump();
389 #endif
390         report_fatal_error("wrong number of arguments to coro.suspend.retcon");
391       }
392 
393       // Check that the result type of the suspend matches the resume types.
394       Type *SResultTy = Suspend->getType();
395       ArrayRef<Type*> SuspendResultTys;
396       if (SResultTy->isVoidTy()) {
397         // leave as empty array
398       } else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
399         SuspendResultTys = SResultStructTy->elements();
400       } else {
401         // forms an ArrayRef using SResultTy, be careful
402         SuspendResultTys = SResultTy;
403       }
404       if (SuspendResultTys.size() != ResumeTys.size()) {
405 #ifndef NDEBUG
406         Suspend->dump();
407         Prototype->getFunctionType()->dump();
408 #endif
409         report_fatal_error("wrong number of results from coro.suspend.retcon");
410       }
411       for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
412         if (SuspendResultTys[I] != ResumeTys[I]) {
413 #ifndef NDEBUG
414           Suspend->dump();
415           Prototype->getFunctionType()->dump();
416 #endif
417           report_fatal_error("result from coro.suspend.retcon does not "
418                              "match corresponding prototype function param");
419         }
420       }
421     }
422     break;
423   }
424 
425   default:
426     llvm_unreachable("coro.begin is not dependent on a coro.id call");
427   }
428 
429   // The coro.free intrinsic is always lowered to the result of coro.begin.
430   for (CoroFrameInst *CF : CoroFrames) {
431     CF->replaceAllUsesWith(CoroBegin);
432     CF->eraseFromParent();
433   }
434 
435   // Move final suspend to be the last element in the CoroSuspends vector.
436   if (ABI == coro::ABI::Switch &&
437       SwitchLowering.HasFinalSuspend &&
438       FinalSuspendIndex != CoroSuspends.size() - 1)
439     std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
440 
441   // Remove orphaned coro.saves.
442   for (CoroSaveInst *CoroSave : UnusedCoroSaves)
443     CoroSave->eraseFromParent();
444 }
445 
propagateCallAttrsFromCallee(CallInst * Call,Function * Callee)446 static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
447   Call->setCallingConv(Callee->getCallingConv());
448   // TODO: attributes?
449 }
450 
addCallToCallGraph(CallGraph * CG,CallInst * Call,Function * Callee)451 static void addCallToCallGraph(CallGraph *CG, CallInst *Call, Function *Callee){
452   if (CG)
453     (*CG)[Call->getFunction()]->addCalledFunction(Call, (*CG)[Callee]);
454 }
455 
emitAlloc(IRBuilder<> & Builder,Value * Size,CallGraph * CG) const456 Value *coro::Shape::emitAlloc(IRBuilder<> &Builder, Value *Size,
457                               CallGraph *CG) const {
458   switch (ABI) {
459   case coro::ABI::Switch:
460     llvm_unreachable("can't allocate memory in coro switch-lowering");
461 
462   case coro::ABI::Retcon:
463   case coro::ABI::RetconOnce: {
464     auto Alloc = RetconLowering.Alloc;
465     Size = Builder.CreateIntCast(Size,
466                                  Alloc->getFunctionType()->getParamType(0),
467                                  /*is signed*/ false);
468     auto *Call = Builder.CreateCall(Alloc, Size);
469     propagateCallAttrsFromCallee(Call, Alloc);
470     addCallToCallGraph(CG, Call, Alloc);
471     return Call;
472   }
473   case coro::ABI::Async:
474     llvm_unreachable("can't allocate memory in coro async-lowering");
475   }
476   llvm_unreachable("Unknown coro::ABI enum");
477 }
478 
emitDealloc(IRBuilder<> & Builder,Value * Ptr,CallGraph * CG) const479 void coro::Shape::emitDealloc(IRBuilder<> &Builder, Value *Ptr,
480                               CallGraph *CG) const {
481   switch (ABI) {
482   case coro::ABI::Switch:
483     llvm_unreachable("can't allocate memory in coro switch-lowering");
484 
485   case coro::ABI::Retcon:
486   case coro::ABI::RetconOnce: {
487     auto Dealloc = RetconLowering.Dealloc;
488     Ptr = Builder.CreateBitCast(Ptr,
489                                 Dealloc->getFunctionType()->getParamType(0));
490     auto *Call = Builder.CreateCall(Dealloc, Ptr);
491     propagateCallAttrsFromCallee(Call, Dealloc);
492     addCallToCallGraph(CG, Call, Dealloc);
493     return;
494   }
495   case coro::ABI::Async:
496     llvm_unreachable("can't allocate memory in coro async-lowering");
497   }
498   llvm_unreachable("Unknown coro::ABI enum");
499 }
500 
fail(const Instruction * I,const char * Reason,Value * V)501 [[noreturn]] static void fail(const Instruction *I, const char *Reason,
502                               Value *V) {
503 #ifndef NDEBUG
504   I->dump();
505   if (V) {
506     errs() << "  Value: ";
507     V->printAsOperand(llvm::errs());
508     errs() << '\n';
509   }
510 #endif
511   report_fatal_error(Reason);
512 }
513 
514 /// Check that the given value is a well-formed prototype for the
515 /// llvm.coro.id.retcon.* intrinsics.
checkWFRetconPrototype(const AnyCoroIdRetconInst * I,Value * V)516 static void checkWFRetconPrototype(const AnyCoroIdRetconInst *I, Value *V) {
517   auto F = dyn_cast<Function>(V->stripPointerCasts());
518   if (!F)
519     fail(I, "llvm.coro.id.retcon.* prototype not a Function", V);
520 
521   auto FT = F->getFunctionType();
522 
523   if (isa<CoroIdRetconInst>(I)) {
524     bool ResultOkay;
525     if (FT->getReturnType()->isPointerTy()) {
526       ResultOkay = true;
527     } else if (auto SRetTy = dyn_cast<StructType>(FT->getReturnType())) {
528       ResultOkay = (!SRetTy->isOpaque() &&
529                     SRetTy->getNumElements() > 0 &&
530                     SRetTy->getElementType(0)->isPointerTy());
531     } else {
532       ResultOkay = false;
533     }
534     if (!ResultOkay)
535       fail(I, "llvm.coro.id.retcon prototype must return pointer as first "
536               "result", F);
537 
538     if (FT->getReturnType() !=
539           I->getFunction()->getFunctionType()->getReturnType())
540       fail(I, "llvm.coro.id.retcon prototype return type must be same as"
541               "current function return type", F);
542   } else {
543     // No meaningful validation to do here for llvm.coro.id.unique.once.
544   }
545 
546   if (FT->getNumParams() == 0 || !FT->getParamType(0)->isPointerTy())
547     fail(I, "llvm.coro.id.retcon.* prototype must take pointer as "
548             "its first parameter", F);
549 }
550 
551 /// Check that the given value is a well-formed allocator.
checkWFAlloc(const Instruction * I,Value * V)552 static void checkWFAlloc(const Instruction *I, Value *V) {
553   auto F = dyn_cast<Function>(V->stripPointerCasts());
554   if (!F)
555     fail(I, "llvm.coro.* allocator not a Function", V);
556 
557   auto FT = F->getFunctionType();
558   if (!FT->getReturnType()->isPointerTy())
559     fail(I, "llvm.coro.* allocator must return a pointer", F);
560 
561   if (FT->getNumParams() != 1 ||
562       !FT->getParamType(0)->isIntegerTy())
563     fail(I, "llvm.coro.* allocator must take integer as only param", F);
564 }
565 
566 /// Check that the given value is a well-formed deallocator.
checkWFDealloc(const Instruction * I,Value * V)567 static void checkWFDealloc(const Instruction *I, Value *V) {
568   auto F = dyn_cast<Function>(V->stripPointerCasts());
569   if (!F)
570     fail(I, "llvm.coro.* deallocator not a Function", V);
571 
572   auto FT = F->getFunctionType();
573   if (!FT->getReturnType()->isVoidTy())
574     fail(I, "llvm.coro.* deallocator must return void", F);
575 
576   if (FT->getNumParams() != 1 ||
577       !FT->getParamType(0)->isPointerTy())
578     fail(I, "llvm.coro.* deallocator must take pointer as only param", F);
579 }
580 
checkConstantInt(const Instruction * I,Value * V,const char * Reason)581 static void checkConstantInt(const Instruction *I, Value *V,
582                              const char *Reason) {
583   if (!isa<ConstantInt>(V)) {
584     fail(I, Reason, V);
585   }
586 }
587 
checkWellFormed() const588 void AnyCoroIdRetconInst::checkWellFormed() const {
589   checkConstantInt(this, getArgOperand(SizeArg),
590                    "size argument to coro.id.retcon.* must be constant");
591   checkConstantInt(this, getArgOperand(AlignArg),
592                    "alignment argument to coro.id.retcon.* must be constant");
593   checkWFRetconPrototype(this, getArgOperand(PrototypeArg));
594   checkWFAlloc(this, getArgOperand(AllocArg));
595   checkWFDealloc(this, getArgOperand(DeallocArg));
596 }
597 
checkAsyncFuncPointer(const Instruction * I,Value * V)598 static void checkAsyncFuncPointer(const Instruction *I, Value *V) {
599   auto *AsyncFuncPtrAddr = dyn_cast<GlobalVariable>(V->stripPointerCasts());
600   if (!AsyncFuncPtrAddr)
601     fail(I, "llvm.coro.id.async async function pointer not a global", V);
602 }
603 
checkWellFormed() const604 void CoroIdAsyncInst::checkWellFormed() const {
605   checkConstantInt(this, getArgOperand(SizeArg),
606                    "size argument to coro.id.async must be constant");
607   checkConstantInt(this, getArgOperand(AlignArg),
608                    "alignment argument to coro.id.async must be constant");
609   checkConstantInt(this, getArgOperand(StorageArg),
610                    "storage argument offset to coro.id.async must be constant");
611   checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg));
612 }
613 
checkAsyncContextProjectFunction(const Instruction * I,Function * F)614 static void checkAsyncContextProjectFunction(const Instruction *I,
615                                              Function *F) {
616   auto *FunTy = cast<FunctionType>(F->getValueType());
617   if (!FunTy->getReturnType()->isPointerTy())
618     fail(I,
619          "llvm.coro.suspend.async resume function projection function must "
620          "return a ptr type",
621          F);
622   if (FunTy->getNumParams() != 1 || !FunTy->getParamType(0)->isPointerTy())
623     fail(I,
624          "llvm.coro.suspend.async resume function projection function must "
625          "take one ptr type as parameter",
626          F);
627 }
628 
checkWellFormed() const629 void CoroSuspendAsyncInst::checkWellFormed() const {
630   checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
631 }
632 
checkWellFormed() const633 void CoroAsyncEndInst::checkWellFormed() const {
634   auto *MustTailCallFunc = getMustTailCallFunction();
635   if (!MustTailCallFunc)
636     return;
637   auto *FnTy = MustTailCallFunc->getFunctionType();
638   if (FnTy->getNumParams() != (arg_size() - 3))
639     fail(this,
640          "llvm.coro.end.async must tail call function argument type must "
641          "match the tail arguments",
642          MustTailCallFunc);
643 }
644