xref: /freebsd/contrib/llvm-project/clang/lib/Sema/SemaCoroutine.cpp (revision f157ca4696f5922275d5d451736005b9332eb136)
1 //===-- SemaCoroutine.cpp - Semantic Analysis for Coroutines --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  This file implements semantic analysis for C++ Coroutines.
10 //
11 //  This file contains references to sections of the Coroutines TS, which
12 //  can be found at http://wg21.link/coroutines.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "CoroutineStmtBuilder.h"
17 #include "clang/AST/ASTLambda.h"
18 #include "clang/AST/Decl.h"
19 #include "clang/AST/ExprCXX.h"
20 #include "clang/AST/StmtCXX.h"
21 #include "clang/Lex/Preprocessor.h"
22 #include "clang/Sema/Initialization.h"
23 #include "clang/Sema/Overload.h"
24 #include "clang/Sema/ScopeInfo.h"
25 #include "clang/Sema/SemaInternal.h"
26 
27 using namespace clang;
28 using namespace sema;
29 
30 static LookupResult lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
31                                  SourceLocation Loc, bool &Res) {
32   DeclarationName DN = S.PP.getIdentifierInfo(Name);
33   LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
34   // Suppress diagnostics when a private member is selected. The same warnings
35   // will be produced again when building the call.
36   LR.suppressDiagnostics();
37   Res = S.LookupQualifiedName(LR, RD);
38   return LR;
39 }
40 
41 static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
42                          SourceLocation Loc) {
43   bool Res;
44   lookupMember(S, Name, RD, Loc, Res);
45   return Res;
46 }
47 
48 /// Look up the std::coroutine_traits<...>::promise_type for the given
49 /// function type.
50 static QualType lookupPromiseType(Sema &S, const FunctionDecl *FD,
51                                   SourceLocation KwLoc) {
52   const FunctionProtoType *FnType = FD->getType()->castAs<FunctionProtoType>();
53   const SourceLocation FuncLoc = FD->getLocation();
54   // FIXME: Cache std::coroutine_traits once we've found it.
55   NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
56   if (!StdExp) {
57     S.Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
58         << "std::experimental::coroutine_traits";
59     return QualType();
60   }
61 
62   ClassTemplateDecl *CoroTraits = S.lookupCoroutineTraits(KwLoc, FuncLoc);
63   if (!CoroTraits) {
64     return QualType();
65   }
66 
67   // Form template argument list for coroutine_traits<R, P1, P2, ...> according
68   // to [dcl.fct.def.coroutine]3
69   TemplateArgumentListInfo Args(KwLoc, KwLoc);
70   auto AddArg = [&](QualType T) {
71     Args.addArgument(TemplateArgumentLoc(
72         TemplateArgument(T), S.Context.getTrivialTypeSourceInfo(T, KwLoc)));
73   };
74   AddArg(FnType->getReturnType());
75   // If the function is a non-static member function, add the type
76   // of the implicit object parameter before the formal parameters.
77   if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
78     if (MD->isInstance()) {
79       // [over.match.funcs]4
80       // For non-static member functions, the type of the implicit object
81       // parameter is
82       //  -- "lvalue reference to cv X" for functions declared without a
83       //      ref-qualifier or with the & ref-qualifier
84       //  -- "rvalue reference to cv X" for functions declared with the &&
85       //      ref-qualifier
86       QualType T = MD->getThisType()->getAs<PointerType>()->getPointeeType();
87       T = FnType->getRefQualifier() == RQ_RValue
88               ? S.Context.getRValueReferenceType(T)
89               : S.Context.getLValueReferenceType(T, /*SpelledAsLValue*/ true);
90       AddArg(T);
91     }
92   }
93   for (QualType T : FnType->getParamTypes())
94     AddArg(T);
95 
96   // Build the template-id.
97   QualType CoroTrait =
98       S.CheckTemplateIdType(TemplateName(CoroTraits), KwLoc, Args);
99   if (CoroTrait.isNull())
100     return QualType();
101   if (S.RequireCompleteType(KwLoc, CoroTrait,
102                             diag::err_coroutine_type_missing_specialization))
103     return QualType();
104 
105   auto *RD = CoroTrait->getAsCXXRecordDecl();
106   assert(RD && "specialization of class template is not a class?");
107 
108   // Look up the ::promise_type member.
109   LookupResult R(S, &S.PP.getIdentifierTable().get("promise_type"), KwLoc,
110                  Sema::LookupOrdinaryName);
111   S.LookupQualifiedName(R, RD);
112   auto *Promise = R.getAsSingle<TypeDecl>();
113   if (!Promise) {
114     S.Diag(FuncLoc,
115            diag::err_implied_std_coroutine_traits_promise_type_not_found)
116         << RD;
117     return QualType();
118   }
119   // The promise type is required to be a class type.
120   QualType PromiseType = S.Context.getTypeDeclType(Promise);
121 
122   auto buildElaboratedType = [&]() {
123     auto *NNS = NestedNameSpecifier::Create(S.Context, nullptr, StdExp);
124     NNS = NestedNameSpecifier::Create(S.Context, NNS, false,
125                                       CoroTrait.getTypePtr());
126     return S.Context.getElaboratedType(ETK_None, NNS, PromiseType);
127   };
128 
129   if (!PromiseType->getAsCXXRecordDecl()) {
130     S.Diag(FuncLoc,
131            diag::err_implied_std_coroutine_traits_promise_type_not_class)
132         << buildElaboratedType();
133     return QualType();
134   }
135   if (S.RequireCompleteType(FuncLoc, buildElaboratedType(),
136                             diag::err_coroutine_promise_type_incomplete))
137     return QualType();
138 
139   return PromiseType;
140 }
141 
142 /// Look up the std::experimental::coroutine_handle<PromiseType>.
143 static QualType lookupCoroutineHandleType(Sema &S, QualType PromiseType,
144                                           SourceLocation Loc) {
145   if (PromiseType.isNull())
146     return QualType();
147 
148   NamespaceDecl *StdExp = S.lookupStdExperimentalNamespace();
149   assert(StdExp && "Should already be diagnosed");
150 
151   LookupResult Result(S, &S.PP.getIdentifierTable().get("coroutine_handle"),
152                       Loc, Sema::LookupOrdinaryName);
153   if (!S.LookupQualifiedName(Result, StdExp)) {
154     S.Diag(Loc, diag::err_implied_coroutine_type_not_found)
155         << "std::experimental::coroutine_handle";
156     return QualType();
157   }
158 
159   ClassTemplateDecl *CoroHandle = Result.getAsSingle<ClassTemplateDecl>();
160   if (!CoroHandle) {
161     Result.suppressDiagnostics();
162     // We found something weird. Complain about the first thing we found.
163     NamedDecl *Found = *Result.begin();
164     S.Diag(Found->getLocation(), diag::err_malformed_std_coroutine_handle);
165     return QualType();
166   }
167 
168   // Form template argument list for coroutine_handle<Promise>.
169   TemplateArgumentListInfo Args(Loc, Loc);
170   Args.addArgument(TemplateArgumentLoc(
171       TemplateArgument(PromiseType),
172       S.Context.getTrivialTypeSourceInfo(PromiseType, Loc)));
173 
174   // Build the template-id.
175   QualType CoroHandleType =
176       S.CheckTemplateIdType(TemplateName(CoroHandle), Loc, Args);
177   if (CoroHandleType.isNull())
178     return QualType();
179   if (S.RequireCompleteType(Loc, CoroHandleType,
180                             diag::err_coroutine_type_missing_specialization))
181     return QualType();
182 
183   return CoroHandleType;
184 }
185 
186 static bool isValidCoroutineContext(Sema &S, SourceLocation Loc,
187                                     StringRef Keyword) {
188   // [expr.await]p2 dictates that 'co_await' and 'co_yield' must be used within
189   // a function body.
190   // FIXME: This also covers [expr.await]p2: "An await-expression shall not
191   // appear in a default argument." But the diagnostic QoI here could be
192   // improved to inform the user that default arguments specifically are not
193   // allowed.
194   auto *FD = dyn_cast<FunctionDecl>(S.CurContext);
195   if (!FD) {
196     S.Diag(Loc, isa<ObjCMethodDecl>(S.CurContext)
197                     ? diag::err_coroutine_objc_method
198                     : diag::err_coroutine_outside_function) << Keyword;
199     return false;
200   }
201 
202   // An enumeration for mapping the diagnostic type to the correct diagnostic
203   // selection index.
204   enum InvalidFuncDiag {
205     DiagCtor = 0,
206     DiagDtor,
207     DiagMain,
208     DiagConstexpr,
209     DiagAutoRet,
210     DiagVarargs,
211     DiagConsteval,
212   };
213   bool Diagnosed = false;
214   auto DiagInvalid = [&](InvalidFuncDiag ID) {
215     S.Diag(Loc, diag::err_coroutine_invalid_func_context) << ID << Keyword;
216     Diagnosed = true;
217     return false;
218   };
219 
220   // Diagnose when a constructor, destructor
221   // or the function 'main' are declared as a coroutine.
222   auto *MD = dyn_cast<CXXMethodDecl>(FD);
223   // [class.ctor]p11: "A constructor shall not be a coroutine."
224   if (MD && isa<CXXConstructorDecl>(MD))
225     return DiagInvalid(DiagCtor);
226   // [class.dtor]p17: "A destructor shall not be a coroutine."
227   else if (MD && isa<CXXDestructorDecl>(MD))
228     return DiagInvalid(DiagDtor);
229   // [basic.start.main]p3: "The function main shall not be a coroutine."
230   else if (FD->isMain())
231     return DiagInvalid(DiagMain);
232 
233   // Emit a diagnostics for each of the following conditions which is not met.
234   // [expr.const]p2: "An expression e is a core constant expression unless the
235   // evaluation of e [...] would evaluate one of the following expressions:
236   // [...] an await-expression [...] a yield-expression."
237   if (FD->isConstexpr())
238     DiagInvalid(FD->isConsteval() ? DiagConsteval : DiagConstexpr);
239   // [dcl.spec.auto]p15: "A function declared with a return type that uses a
240   // placeholder type shall not be a coroutine."
241   if (FD->getReturnType()->isUndeducedType())
242     DiagInvalid(DiagAutoRet);
243   // [dcl.fct.def.coroutine]p1: "The parameter-declaration-clause of the
244   // coroutine shall not terminate with an ellipsis that is not part of a
245   // parameter-declaration."
246   if (FD->isVariadic())
247     DiagInvalid(DiagVarargs);
248 
249   return !Diagnosed;
250 }
251 
252 static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
253                                                  SourceLocation Loc) {
254   DeclarationName OpName =
255       SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
256   LookupResult Operators(SemaRef, OpName, SourceLocation(),
257                          Sema::LookupOperatorName);
258   SemaRef.LookupName(Operators, S);
259 
260   assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
261   const auto &Functions = Operators.asUnresolvedSet();
262   bool IsOverloaded =
263       Functions.size() > 1 ||
264       (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin()));
265   Expr *CoawaitOp = UnresolvedLookupExpr::Create(
266       SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
267       DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
268       Functions.begin(), Functions.end());
269   assert(CoawaitOp);
270   return CoawaitOp;
271 }
272 
273 /// Build a call to 'operator co_await' if there is a suitable operator for
274 /// the given expression.
275 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
276                                            Expr *E,
277                                            UnresolvedLookupExpr *Lookup) {
278   UnresolvedSet<16> Functions;
279   Functions.append(Lookup->decls_begin(), Lookup->decls_end());
280   return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
281 }
282 
283 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
284                                            SourceLocation Loc, Expr *E) {
285   ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
286   if (R.isInvalid())
287     return ExprError();
288   return buildOperatorCoawaitCall(SemaRef, Loc, E,
289                                   cast<UnresolvedLookupExpr>(R.get()));
290 }
291 
292 static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id,
293                               MultiExprArg CallArgs) {
294   StringRef Name = S.Context.BuiltinInfo.getName(Id);
295   LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName);
296   S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true);
297 
298   auto *BuiltInDecl = R.getAsSingle<FunctionDecl>();
299   assert(BuiltInDecl && "failed to find builtin declaration");
300 
301   ExprResult DeclRef =
302       S.BuildDeclRefExpr(BuiltInDecl, BuiltInDecl->getType(), VK_LValue, Loc);
303   assert(DeclRef.isUsable() && "Builtin reference cannot fail");
304 
305   ExprResult Call =
306       S.BuildCallExpr(/*Scope=*/nullptr, DeclRef.get(), Loc, CallArgs, Loc);
307 
308   assert(!Call.isInvalid() && "Call to builtin cannot fail!");
309   return Call.get();
310 }
311 
312 static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType,
313                                        SourceLocation Loc) {
314   QualType CoroHandleType = lookupCoroutineHandleType(S, PromiseType, Loc);
315   if (CoroHandleType.isNull())
316     return ExprError();
317 
318   DeclContext *LookupCtx = S.computeDeclContext(CoroHandleType);
319   LookupResult Found(S, &S.PP.getIdentifierTable().get("from_address"), Loc,
320                      Sema::LookupOrdinaryName);
321   if (!S.LookupQualifiedName(Found, LookupCtx)) {
322     S.Diag(Loc, diag::err_coroutine_handle_missing_member)
323         << "from_address";
324     return ExprError();
325   }
326 
327   Expr *FramePtr =
328       buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});
329 
330   CXXScopeSpec SS;
331   ExprResult FromAddr =
332       S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
333   if (FromAddr.isInvalid())
334     return ExprError();
335 
336   return S.BuildCallExpr(nullptr, FromAddr.get(), Loc, FramePtr, Loc);
337 }
338 
339 struct ReadySuspendResumeResult {
340   enum AwaitCallType { ACT_Ready, ACT_Suspend, ACT_Resume };
341   Expr *Results[3];
342   OpaqueValueExpr *OpaqueValue;
343   bool IsInvalid;
344 };
345 
346 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
347                                   StringRef Name, MultiExprArg Args) {
348   DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
349 
350   // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
351   CXXScopeSpec SS;
352   ExprResult Result = S.BuildMemberReferenceExpr(
353       Base, Base->getType(), Loc, /*IsPtr=*/false, SS,
354       SourceLocation(), nullptr, NameInfo, /*TemplateArgs=*/nullptr,
355       /*Scope=*/nullptr);
356   if (Result.isInvalid())
357     return ExprError();
358 
359   // We meant exactly what we asked for. No need for typo correction.
360   if (auto *TE = dyn_cast<TypoExpr>(Result.get())) {
361     S.clearDelayedTypo(TE);
362     S.Diag(Loc, diag::err_no_member)
363         << NameInfo.getName() << Base->getType()->getAsCXXRecordDecl()
364         << Base->getSourceRange();
365     return ExprError();
366   }
367 
368   return S.BuildCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
369 }
370 
371 // See if return type is coroutine-handle and if so, invoke builtin coro-resume
372 // on its address. This is to enable experimental support for coroutine-handle
373 // returning await_suspend that results in a guaranteed tail call to the target
374 // coroutine.
375 static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
376                            SourceLocation Loc) {
377   if (RetType->isReferenceType())
378     return nullptr;
379   Type const *T = RetType.getTypePtr();
380   if (!T->isClassType() && !T->isStructureType())
381     return nullptr;
382 
383   // FIXME: Add convertability check to coroutine_handle<>. Possibly via
384   // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment
385   // a private function in SemaExprCXX.cpp
386 
387   ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", None);
388   if (AddressExpr.isInvalid())
389     return nullptr;
390 
391   Expr *JustAddress = AddressExpr.get();
392   // FIXME: Check that the type of AddressExpr is void*
393   return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume,
394                           JustAddress);
395 }
396 
397 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
398 /// expression.
399 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
400                                                   SourceLocation Loc, Expr *E) {
401   OpaqueValueExpr *Operand = new (S.Context)
402       OpaqueValueExpr(Loc, E->getType(), VK_LValue, E->getObjectKind(), E);
403 
404   // Assume invalid until we see otherwise.
405   ReadySuspendResumeResult Calls = {{}, Operand, /*IsInvalid=*/true};
406 
407   ExprResult CoroHandleRes = buildCoroutineHandle(S, CoroPromise->getType(), Loc);
408   if (CoroHandleRes.isInvalid())
409     return Calls;
410   Expr *CoroHandle = CoroHandleRes.get();
411 
412   const StringRef Funcs[] = {"await_ready", "await_suspend", "await_resume"};
413   MultiExprArg Args[] = {None, CoroHandle, None};
414   for (size_t I = 0, N = llvm::array_lengthof(Funcs); I != N; ++I) {
415     ExprResult Result = buildMemberCall(S, Operand, Loc, Funcs[I], Args[I]);
416     if (Result.isInvalid())
417       return Calls;
418     Calls.Results[I] = Result.get();
419   }
420 
421   // Assume the calls are valid; all further checking should make them invalid.
422   Calls.IsInvalid = false;
423 
424   using ACT = ReadySuspendResumeResult::AwaitCallType;
425   CallExpr *AwaitReady = cast<CallExpr>(Calls.Results[ACT::ACT_Ready]);
426   if (!AwaitReady->getType()->isDependentType()) {
427     // [expr.await]p3 [...]
428     // — await-ready is the expression e.await_ready(), contextually converted
429     // to bool.
430     ExprResult Conv = S.PerformContextuallyConvertToBool(AwaitReady);
431     if (Conv.isInvalid()) {
432       S.Diag(AwaitReady->getDirectCallee()->getBeginLoc(),
433              diag::note_await_ready_no_bool_conversion);
434       S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
435           << AwaitReady->getDirectCallee() << E->getSourceRange();
436       Calls.IsInvalid = true;
437     }
438     Calls.Results[ACT::ACT_Ready] = Conv.get();
439   }
440   CallExpr *AwaitSuspend = cast<CallExpr>(Calls.Results[ACT::ACT_Suspend]);
441   if (!AwaitSuspend->getType()->isDependentType()) {
442     // [expr.await]p3 [...]
443     //   - await-suspend is the expression e.await_suspend(h), which shall be
444     //     a prvalue of type void or bool.
445     QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
446 
447     // Experimental support for coroutine_handle returning await_suspend.
448     if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc))
449       Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
450     else {
451       // non-class prvalues always have cv-unqualified types
452       if (RetType->isReferenceType() ||
453           (!RetType->isBooleanType() && !RetType->isVoidType())) {
454         S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
455                diag::err_await_suspend_invalid_return_type)
456             << RetType;
457         S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
458             << AwaitSuspend->getDirectCallee();
459         Calls.IsInvalid = true;
460       }
461     }
462   }
463 
464   return Calls;
465 }
466 
467 static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise,
468                                    SourceLocation Loc, StringRef Name,
469                                    MultiExprArg Args) {
470 
471   // Form a reference to the promise.
472   ExprResult PromiseRef = S.BuildDeclRefExpr(
473       Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
474   if (PromiseRef.isInvalid())
475     return ExprError();
476 
477   return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
478 }
479 
480 VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
481   assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
482   auto *FD = cast<FunctionDecl>(CurContext);
483   bool IsThisDependentType = [&] {
484     if (auto *MD = dyn_cast_or_null<CXXMethodDecl>(FD))
485       return MD->isInstance() && MD->getThisType()->isDependentType();
486     else
487       return false;
488   }();
489 
490   QualType T = FD->getType()->isDependentType() || IsThisDependentType
491                    ? Context.DependentTy
492                    : lookupPromiseType(*this, FD, Loc);
493   if (T.isNull())
494     return nullptr;
495 
496   auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
497                              &PP.getIdentifierTable().get("__promise"), T,
498                              Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
499   CheckVariableDeclarationType(VD);
500   if (VD->isInvalidDecl())
501     return nullptr;
502 
503   auto *ScopeInfo = getCurFunction();
504   // Build a list of arguments, based on the coroutine functions arguments,
505   // that will be passed to the promise type's constructor.
506   llvm::SmallVector<Expr *, 4> CtorArgExprs;
507 
508   // Add implicit object parameter.
509   if (auto *MD = dyn_cast<CXXMethodDecl>(FD)) {
510     if (MD->isInstance() && !isLambdaCallOperator(MD)) {
511       ExprResult ThisExpr = ActOnCXXThis(Loc);
512       if (ThisExpr.isInvalid())
513         return nullptr;
514       ThisExpr = CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
515       if (ThisExpr.isInvalid())
516         return nullptr;
517       CtorArgExprs.push_back(ThisExpr.get());
518     }
519   }
520 
521   auto &Moves = ScopeInfo->CoroutineParameterMoves;
522   for (auto *PD : FD->parameters()) {
523     if (PD->getType()->isDependentType())
524       continue;
525 
526     auto RefExpr = ExprEmpty();
527     auto Move = Moves.find(PD);
528     assert(Move != Moves.end() &&
529            "Coroutine function parameter not inserted into move map");
530     // If a reference to the function parameter exists in the coroutine
531     // frame, use that reference.
532     auto *MoveDecl =
533         cast<VarDecl>(cast<DeclStmt>(Move->second)->getSingleDecl());
534     RefExpr =
535         BuildDeclRefExpr(MoveDecl, MoveDecl->getType().getNonReferenceType(),
536                          ExprValueKind::VK_LValue, FD->getLocation());
537     if (RefExpr.isInvalid())
538       return nullptr;
539     CtorArgExprs.push_back(RefExpr.get());
540   }
541 
542   // Create an initialization sequence for the promise type using the
543   // constructor arguments, wrapped in a parenthesized list expression.
544   Expr *PLE = ParenListExpr::Create(Context, FD->getLocation(),
545                                     CtorArgExprs, FD->getLocation());
546   InitializedEntity Entity = InitializedEntity::InitializeVariable(VD);
547   InitializationKind Kind = InitializationKind::CreateForInit(
548       VD->getLocation(), /*DirectInit=*/true, PLE);
549   InitializationSequence InitSeq(*this, Entity, Kind, CtorArgExprs,
550                                  /*TopLevelOfInitList=*/false,
551                                  /*TreatUnavailableAsInvalid=*/false);
552 
553   // Attempt to initialize the promise type with the arguments.
554   // If that fails, fall back to the promise type's default constructor.
555   if (InitSeq) {
556     ExprResult Result = InitSeq.Perform(*this, Entity, Kind, CtorArgExprs);
557     if (Result.isInvalid()) {
558       VD->setInvalidDecl();
559     } else if (Result.get()) {
560       VD->setInit(MaybeCreateExprWithCleanups(Result.get()));
561       VD->setInitStyle(VarDecl::CallInit);
562       CheckCompleteVariableDeclaration(VD);
563     }
564   } else
565     ActOnUninitializedDecl(VD);
566 
567   FD->addDecl(VD);
568   return VD;
569 }
570 
571 /// Check that this is a context in which a coroutine suspension can appear.
572 static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
573                                                 StringRef Keyword,
574                                                 bool IsImplicit = false) {
575   if (!isValidCoroutineContext(S, Loc, Keyword))
576     return nullptr;
577 
578   assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
579 
580   auto *ScopeInfo = S.getCurFunction();
581   assert(ScopeInfo && "missing function scope for function");
582 
583   if (ScopeInfo->FirstCoroutineStmtLoc.isInvalid() && !IsImplicit)
584     ScopeInfo->setFirstCoroutineStmt(Loc, Keyword);
585 
586   if (ScopeInfo->CoroutinePromise)
587     return ScopeInfo;
588 
589   if (!S.buildCoroutineParameterMoves(Loc))
590     return nullptr;
591 
592   ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
593   if (!ScopeInfo->CoroutinePromise)
594     return nullptr;
595 
596   return ScopeInfo;
597 }
598 
599 bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc,
600                                    StringRef Keyword) {
601   if (!checkCoroutineContext(*this, KWLoc, Keyword))
602     return false;
603   auto *ScopeInfo = getCurFunction();
604   assert(ScopeInfo->CoroutinePromise);
605 
606   // If we have existing coroutine statements then we have already built
607   // the initial and final suspend points.
608   if (!ScopeInfo->NeedsCoroutineSuspends)
609     return true;
610 
611   ScopeInfo->setNeedsCoroutineSuspends(false);
612 
613   auto *Fn = cast<FunctionDecl>(CurContext);
614   SourceLocation Loc = Fn->getLocation();
615   // Build the initial suspend point
616   auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
617     ExprResult Suspend =
618         buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None);
619     if (Suspend.isInvalid())
620       return StmtError();
621     Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get());
622     if (Suspend.isInvalid())
623       return StmtError();
624     Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(),
625                                        /*IsImplicit*/ true);
626     Suspend = ActOnFinishFullExpr(Suspend.get(), /*DiscardedValue*/ false);
627     if (Suspend.isInvalid()) {
628       Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required)
629           << ((Name == "initial_suspend") ? 0 : 1);
630       Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
631       return StmtError();
632     }
633     return cast<Stmt>(Suspend.get());
634   };
635 
636   StmtResult InitSuspend = buildSuspends("initial_suspend");
637   if (InitSuspend.isInvalid())
638     return true;
639 
640   StmtResult FinalSuspend = buildSuspends("final_suspend");
641   if (FinalSuspend.isInvalid())
642     return true;
643 
644   ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
645 
646   return true;
647 }
648 
649 // Recursively walks up the scope hierarchy until either a 'catch' or a function
650 // scope is found, whichever comes first.
651 static bool isWithinCatchScope(Scope *S) {
652   // 'co_await' and 'co_yield' keywords are disallowed within catch blocks, but
653   // lambdas that use 'co_await' are allowed. The loop below ends when a
654   // function scope is found in order to ensure the following behavior:
655   //
656   // void foo() {      // <- function scope
657   //   try {           //
658   //     co_await x;   // <- 'co_await' is OK within a function scope
659   //   } catch {       // <- catch scope
660   //     co_await x;   // <- 'co_await' is not OK within a catch scope
661   //     []() {        // <- function scope
662   //       co_await x; // <- 'co_await' is OK within a function scope
663   //     }();
664   //   }
665   // }
666   while (S && !(S->getFlags() & Scope::FnScope)) {
667     if (S->getFlags() & Scope::CatchScope)
668       return true;
669     S = S->getParent();
670   }
671   return false;
672 }
673 
674 // [expr.await]p2, emphasis added: "An await-expression shall appear only in
675 // a *potentially evaluated* expression within the compound-statement of a
676 // function-body *outside of a handler* [...] A context within a function
677 // where an await-expression can appear is called a suspension context of the
678 // function."
679 static void checkSuspensionContext(Sema &S, SourceLocation Loc,
680                                    StringRef Keyword) {
681   // First emphasis of [expr.await]p2: must be a potentially evaluated context.
682   // That is, 'co_await' and 'co_yield' cannot appear in subexpressions of
683   // \c sizeof.
684   if (S.isUnevaluatedContext())
685     S.Diag(Loc, diag::err_coroutine_unevaluated_context) << Keyword;
686 
687   // Second emphasis of [expr.await]p2: must be outside of an exception handler.
688   if (isWithinCatchScope(S.getCurScope()))
689     S.Diag(Loc, diag::err_coroutine_within_handler) << Keyword;
690 }
691 
692 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
693   if (!ActOnCoroutineBodyStart(S, Loc, "co_await")) {
694     CorrectDelayedTyposInExpr(E);
695     return ExprError();
696   }
697 
698   checkSuspensionContext(*this, Loc, "co_await");
699 
700   if (E->getType()->isPlaceholderType()) {
701     ExprResult R = CheckPlaceholderExpr(E);
702     if (R.isInvalid()) return ExprError();
703     E = R.get();
704   }
705   ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
706   if (Lookup.isInvalid())
707     return ExprError();
708   return BuildUnresolvedCoawaitExpr(Loc, E,
709                                    cast<UnresolvedLookupExpr>(Lookup.get()));
710 }
711 
712 ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E,
713                                             UnresolvedLookupExpr *Lookup) {
714   auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
715   if (!FSI)
716     return ExprError();
717 
718   if (E->getType()->isPlaceholderType()) {
719     ExprResult R = CheckPlaceholderExpr(E);
720     if (R.isInvalid())
721       return ExprError();
722     E = R.get();
723   }
724 
725   auto *Promise = FSI->CoroutinePromise;
726   if (Promise->getType()->isDependentType()) {
727     Expr *Res =
728         new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
729     return Res;
730   }
731 
732   auto *RD = Promise->getType()->getAsCXXRecordDecl();
733   if (lookupMember(*this, "await_transform", RD, Loc)) {
734     ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E);
735     if (R.isInvalid()) {
736       Diag(Loc,
737            diag::note_coroutine_promise_implicit_await_transform_required_here)
738           << E->getSourceRange();
739       return ExprError();
740     }
741     E = R.get();
742   }
743   ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
744   if (Awaitable.isInvalid())
745     return ExprError();
746 
747   return BuildResolvedCoawaitExpr(Loc, Awaitable.get());
748 }
749 
750 ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E,
751                                   bool IsImplicit) {
752   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit);
753   if (!Coroutine)
754     return ExprError();
755 
756   if (E->getType()->isPlaceholderType()) {
757     ExprResult R = CheckPlaceholderExpr(E);
758     if (R.isInvalid()) return ExprError();
759     E = R.get();
760   }
761 
762   if (E->getType()->isDependentType()) {
763     Expr *Res = new (Context)
764         CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit);
765     return Res;
766   }
767 
768   // If the expression is a temporary, materialize it as an lvalue so that we
769   // can use it multiple times.
770   if (E->getValueKind() == VK_RValue)
771     E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
772 
773   // The location of the `co_await` token cannot be used when constructing
774   // the member call expressions since it's before the location of `Expr`, which
775   // is used as the start of the member call expression.
776   SourceLocation CallLoc = E->getExprLoc();
777 
778   // Build the await_ready, await_suspend, await_resume calls.
779   ReadySuspendResumeResult RSS =
780       buildCoawaitCalls(*this, Coroutine->CoroutinePromise, CallLoc, E);
781   if (RSS.IsInvalid)
782     return ExprError();
783 
784   Expr *Res =
785       new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
786                                 RSS.Results[2], RSS.OpaqueValue, IsImplicit);
787 
788   return Res;
789 }
790 
791 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
792   if (!ActOnCoroutineBodyStart(S, Loc, "co_yield")) {
793     CorrectDelayedTyposInExpr(E);
794     return ExprError();
795   }
796 
797   checkSuspensionContext(*this, Loc, "co_yield");
798 
799   // Build yield_value call.
800   ExprResult Awaitable = buildPromiseCall(
801       *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E);
802   if (Awaitable.isInvalid())
803     return ExprError();
804 
805   // Build 'operator co_await' call.
806   Awaitable = buildOperatorCoawaitCall(*this, S, Loc, Awaitable.get());
807   if (Awaitable.isInvalid())
808     return ExprError();
809 
810   return BuildCoyieldExpr(Loc, Awaitable.get());
811 }
812 ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
813   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
814   if (!Coroutine)
815     return ExprError();
816 
817   if (E->getType()->isPlaceholderType()) {
818     ExprResult R = CheckPlaceholderExpr(E);
819     if (R.isInvalid()) return ExprError();
820     E = R.get();
821   }
822 
823   if (E->getType()->isDependentType()) {
824     Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
825     return Res;
826   }
827 
828   // If the expression is a temporary, materialize it as an lvalue so that we
829   // can use it multiple times.
830   if (E->getValueKind() == VK_RValue)
831     E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
832 
833   // Build the await_ready, await_suspend, await_resume calls.
834   ReadySuspendResumeResult RSS =
835       buildCoawaitCalls(*this, Coroutine->CoroutinePromise, Loc, E);
836   if (RSS.IsInvalid)
837     return ExprError();
838 
839   Expr *Res =
840       new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
841                                 RSS.Results[2], RSS.OpaqueValue);
842 
843   return Res;
844 }
845 
846 StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
847   if (!ActOnCoroutineBodyStart(S, Loc, "co_return")) {
848     CorrectDelayedTyposInExpr(E);
849     return StmtError();
850   }
851   return BuildCoreturnStmt(Loc, E);
852 }
853 
/// Build a 'co_return' statement.
///
/// \param Loc        Location of the 'co_return' keyword.
/// \param E          Optional operand (null for 'co_return;').
/// \param IsImplicit Whether the statement is compiler-synthesized.
///
/// A non-void operand (or a braced init list) produces a call to
/// 'p.return_value(E)'; otherwise the operand, if any, is turned into a
/// discarded-value expression and 'p.return_void()' is called.
StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E,
                                   bool IsImplicit) {
  auto *FSI = checkCoroutineContext(*this, Loc, "co_return", IsImplicit);
  if (!FSI)
    return StmtError();

  // Resolve placeholder types, except overload sets, which are left for
  // overload resolution of the promise call below.
  if (E && E->getType()->isPlaceholderType() &&
      !E->getType()->isSpecificPlaceholderType(BuiltinType::Overload)) {
    ExprResult R = CheckPlaceholderExpr(E);
    if (R.isInvalid()) return StmtError();
    E = R.get();
  }

  // Move the return value if we can
  // (an implicit move of a copy-elision candidate, as if by std::move; if
  // the move/copy initialization fails, the original operand is kept).
  if (E) {
    auto NRVOCandidate = this->getCopyElisionCandidate(E->getType(), E, CES_AsIfByStdMove);
    if (NRVOCandidate) {
      InitializedEntity Entity =
          InitializedEntity::InitializeResult(Loc, E->getType(), NRVOCandidate);
      ExprResult MoveResult = this->PerformMoveOrCopyInitialization(
          Entity, NRVOCandidate, E->getType(), E);
      if (MoveResult.get())
        E = MoveResult.get();
    }
  }

  // FIXME: If the operand is a reference to a variable that's about to go out
  // of scope, we should treat the operand as an xvalue for this overload
  // resolution.
  // NOTE(review): the block above already attempts an implicit move for
  // copy-elision candidates; confirm whether this FIXME is still needed.
  VarDecl *Promise = FSI->CoroutinePromise;
  ExprResult PC;
  if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) {
    PC = buildPromiseCall(*this, Promise, Loc, "return_value", E);
  } else {
    E = MakeFullDiscardedValueExpr(E).get();
    PC = buildPromiseCall(*this, Promise, Loc, "return_void", None);
  }
  if (PC.isInvalid())
    return StmtError();

  Expr *PCE = ActOnFinishFullExpr(PC.get(), /*DiscardedValue*/ false).get();

  Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit);
  return Res;
}
898 }
899 
900 /// Look up the std::nothrow object.
901 static Expr *buildStdNoThrowDeclRef(Sema &S, SourceLocation Loc) {
902   NamespaceDecl *Std = S.getStdNamespace();
903   assert(Std && "Should already be diagnosed");
904 
905   LookupResult Result(S, &S.PP.getIdentifierTable().get("nothrow"), Loc,
906                       Sema::LookupOrdinaryName);
907   if (!S.LookupQualifiedName(Result, Std)) {
908     // FIXME: <experimental/coroutine> should have been included already.
909     // If we require it to include <new> then this diagnostic is no longer
910     // needed.
911     S.Diag(Loc, diag::err_implicit_coroutine_std_nothrow_type_not_found);
912     return nullptr;
913   }
914 
915   auto *VD = Result.getAsSingle<VarDecl>();
916   if (!VD) {
917     Result.suppressDiagnostics();
918     // We found something weird. Complain about the first thing we found.
919     NamedDecl *Found = *Result.begin();
920     S.Diag(Found->getLocation(), diag::err_malformed_std_nothrow);
921     return nullptr;
922   }
923 
924   ExprResult DR = S.BuildDeclRefExpr(VD, VD->getType(), VK_LValue, Loc);
925   if (DR.isInvalid())
926     return nullptr;
927 
928   return DR.get();
929 }
930 
931 // Find an appropriate delete for the promise.
932 static FunctionDecl *findDeleteForPromise(Sema &S, SourceLocation Loc,
933                                           QualType PromiseType) {
934   FunctionDecl *OperatorDelete = nullptr;
935 
936   DeclarationName DeleteName =
937       S.Context.DeclarationNames.getCXXOperatorName(OO_Delete);
938 
939   auto *PointeeRD = PromiseType->getAsCXXRecordDecl();
940   assert(PointeeRD && "PromiseType must be a CxxRecordDecl type");
941 
942   if (S.FindDeallocationFunction(Loc, PointeeRD, DeleteName, OperatorDelete))
943     return nullptr;
944 
945   if (!OperatorDelete) {
946     // Look for a global declaration.
947     const bool CanProvideSize = S.isCompleteType(Loc, PromiseType);
948     const bool Overaligned = false;
949     OperatorDelete = S.FindUsualDeallocationFunction(Loc, CanProvideSize,
950                                                      Overaligned, DeleteName);
951   }
952   S.MarkFunctionReferenced(Loc, OperatorDelete);
953   return OperatorDelete;
954 }
955 
956 
957 void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
958   FunctionScopeInfo *Fn = getCurFunction();
959   assert(Fn && Fn->isCoroutine() && "not a coroutine");
960   if (!Body) {
961     assert(FD->isInvalidDecl() &&
962            "a null body is only allowed for invalid declarations");
963     return;
964   }
965   // We have a function that uses coroutine keywords, but we failed to build
966   // the promise type.
967   if (!Fn->CoroutinePromise)
968     return FD->setInvalidDecl();
969 
970   if (isa<CoroutineBodyStmt>(Body)) {
971     // Nothing todo. the body is already a transformed coroutine body statement.
972     return;
973   }
974 
975   // Coroutines [stmt.return]p1:
976   //   A return statement shall not appear in a coroutine.
977   if (Fn->FirstReturnLoc.isValid()) {
978     assert(Fn->FirstCoroutineStmtLoc.isValid() &&
979                    "first coroutine location not set");
980     Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
981     Diag(Fn->FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
982             << Fn->getFirstCoroutineStmtKeyword();
983   }
984   CoroutineStmtBuilder Builder(*this, *FD, *Fn, Body);
985   if (Builder.isInvalid() || !Builder.buildStatements())
986     return FD->setInvalidDecl();
987 
988   // Build body for the coroutine wrapper statement.
989   Body = CoroutineBodyStmt::Create(Context, Builder);
990 }
991 
992 CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD,
993                                            sema::FunctionScopeInfo &Fn,
994                                            Stmt *Body)
995     : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
996       IsPromiseDependentType(
997           !Fn.CoroutinePromise ||
998           Fn.CoroutinePromise->getType()->isDependentType()) {
999   this->Body = Body;
1000 
1001   for (auto KV : Fn.CoroutineParameterMoves)
1002     this->ParamMovesVector.push_back(KV.second);
1003   this->ParamMoves = this->ParamMovesVector;
1004 
1005   if (!IsPromiseDependentType) {
1006     PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
1007     assert(PromiseRecordDecl && "Type should have already been checked");
1008   }
1009   this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend();
1010 }
1011 
1012 bool CoroutineStmtBuilder::buildStatements() {
1013   assert(this->IsValid && "coroutine already invalid");
1014   this->IsValid = makeReturnObject();
1015   if (this->IsValid && !IsPromiseDependentType)
1016     buildDependentStatements();
1017   return this->IsValid;
1018 }
1019 
1020 bool CoroutineStmtBuilder::buildDependentStatements() {
1021   assert(this->IsValid && "coroutine already invalid");
1022   assert(!this->IsPromiseDependentType &&
1023          "coroutine cannot have a dependent promise type");
1024   this->IsValid = makeOnException() && makeOnFallthrough() &&
1025                   makeGroDeclAndReturnStmt() && makeReturnOnAllocFailure() &&
1026                   makeNewAndDeleteExpr();
1027   return this->IsValid;
1028 }
1029 
1030 bool CoroutineStmtBuilder::makePromiseStmt() {
1031   // Form a declaration statement for the promise declaration, so that AST
1032   // visitors can more easily find it.
1033   StmtResult PromiseStmt =
1034       S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
1035   if (PromiseStmt.isInvalid())
1036     return false;
1037 
1038   this->Promise = PromiseStmt.get();
1039   return true;
1040 }
1041 
1042 bool CoroutineStmtBuilder::makeInitialAndFinalSuspend() {
1043   if (Fn.hasInvalidCoroutineSuspends())
1044     return false;
1045   this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first);
1046   this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second);
1047   return true;
1048 }
1049 
1050 static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
1051                                      CXXRecordDecl *PromiseRecordDecl,
1052                                      FunctionScopeInfo &Fn) {
1053   auto Loc = E->getExprLoc();
1054   if (auto *DeclRef = dyn_cast_or_null<DeclRefExpr>(E)) {
1055     auto *Decl = DeclRef->getDecl();
1056     if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) {
1057       if (Method->isStatic())
1058         return true;
1059       else
1060         Loc = Decl->getLocation();
1061     }
1062   }
1063 
1064   S.Diag(
1065       Loc,
1066       diag::err_coroutine_promise_get_return_object_on_allocation_failure)
1067       << PromiseRecordDecl;
1068   S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1069       << Fn.getFirstCoroutineStmtKeyword();
1070   return false;
1071 }
1072 
/// If the promise class declares get_return_object_on_allocation_failure,
/// build 'return get_return_object_on_allocation_failure();' and store it in
/// ReturnStmtOnAllocFailure.
///
/// Returns true without building anything when the member is not found.
/// Returns false in any of these cases:
///   - the name cannot be referenced;
///   - the name is not a static member function (diagReturnOnAllocFailure);
///   - the call cannot be formed;
///   - the return statement is ill-formed (notes then point at the member
///     and at the first coroutine statement).
bool CoroutineStmtBuilder::makeReturnOnAllocFailure() {
  assert(!IsPromiseDependentType &&
         "cannot make statement while the promise type is dependent");

  // [dcl.fct.def.coroutine]/8
  // The unqualified-id get_return_object_on_allocation_failure is looked up in
  // the scope of class P by class member access lookup (3.4.5). ...
  // If an allocation function returns nullptr, ... the coroutine return value
  // is obtained by a call to ... get_return_object_on_allocation_failure().

  DeclarationName DN =
      S.PP.getIdentifierInfo("get_return_object_on_allocation_failure");
  LookupResult Found(S, DN, Loc, Sema::LookupMemberName);
  // Member absent: nothing to build, and not an error.
  if (!S.LookupQualifiedName(Found, PromiseRecordDecl))
    return true;

  CXXScopeSpec SS;
  ExprResult DeclNameExpr =
      S.BuildDeclarationNameExpr(SS, Found, /*NeedsADL=*/false);
  if (DeclNameExpr.isInvalid())
    return false;

  // Reject anything other than a static member function.
  if (!diagReturnOnAllocFailure(S, DeclNameExpr.get(), PromiseRecordDecl, Fn))
    return false;

  ExprResult ReturnObjectOnAllocationFailure =
      S.BuildCallExpr(nullptr, DeclNameExpr.get(), Loc, {}, Loc);
  if (ReturnObjectOnAllocationFailure.isInvalid())
    return false;

  StmtResult ReturnStmt =
      S.BuildReturnStmt(Loc, ReturnObjectOnAllocationFailure.get());
  if (ReturnStmt.isInvalid()) {
    S.Diag(Found.getFoundDecl()->getLocation(), diag::note_member_declared_here)
        << DN;
    S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
        << Fn.getFirstCoroutineStmtKeyword();
    return false;
  }

  this->ReturnStmtOnAllocFailure = ReturnStmt.get();
  return true;
}
1115 }
1116 
1117 bool CoroutineStmtBuilder::makeNewAndDeleteExpr() {
1118   // Form and check allocation and deallocation calls.
1119   assert(!IsPromiseDependentType &&
1120          "cannot make statement while the promise type is dependent");
1121   QualType PromiseType = Fn.CoroutinePromise->getType();
1122 
1123   if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
1124     return false;
1125 
1126   const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr;
1127 
1128   // [dcl.fct.def.coroutine]/7
1129   // Lookup allocation functions using a parameter list composed of the
1130   // requested size of the coroutine state being allocated, followed by
1131   // the coroutine function's arguments. If a matching allocation function
1132   // exists, use it. Otherwise, use an allocation function that just takes
1133   // the requested size.
1134 
1135   FunctionDecl *OperatorNew = nullptr;
1136   FunctionDecl *OperatorDelete = nullptr;
1137   FunctionDecl *UnusedResult = nullptr;
1138   bool PassAlignment = false;
1139   SmallVector<Expr *, 1> PlacementArgs;
1140 
1141   // [dcl.fct.def.coroutine]/7
1142   // "The allocation function’s name is looked up in the scope of P.
1143   // [...] If the lookup finds an allocation function in the scope of P,
1144   // overload resolution is performed on a function call created by assembling
1145   // an argument list. The first argument is the amount of space requested,
1146   // and has type std::size_t. The lvalues p1 ... pn are the succeeding
1147   // arguments."
1148   //
1149   // ...where "p1 ... pn" are defined earlier as:
1150   //
1151   // [dcl.fct.def.coroutine]/3
1152   // "For a coroutine f that is a non-static member function, let P1 denote the
1153   // type of the implicit object parameter (13.3.1) and P2 ... Pn be the types
1154   // of the function parameters; otherwise let P1 ... Pn be the types of the
1155   // function parameters. Let p1 ... pn be lvalues denoting those objects."
1156   if (auto *MD = dyn_cast<CXXMethodDecl>(&FD)) {
1157     if (MD->isInstance() && !isLambdaCallOperator(MD)) {
1158       ExprResult ThisExpr = S.ActOnCXXThis(Loc);
1159       if (ThisExpr.isInvalid())
1160         return false;
1161       ThisExpr = S.CreateBuiltinUnaryOp(Loc, UO_Deref, ThisExpr.get());
1162       if (ThisExpr.isInvalid())
1163         return false;
1164       PlacementArgs.push_back(ThisExpr.get());
1165     }
1166   }
1167   for (auto *PD : FD.parameters()) {
1168     if (PD->getType()->isDependentType())
1169       continue;
1170 
1171     // Build a reference to the parameter.
1172     auto PDLoc = PD->getLocation();
1173     ExprResult PDRefExpr =
1174         S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
1175                            ExprValueKind::VK_LValue, PDLoc);
1176     if (PDRefExpr.isInvalid())
1177       return false;
1178 
1179     PlacementArgs.push_back(PDRefExpr.get());
1180   }
1181   S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
1182                             /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1183                             /*isArray*/ false, PassAlignment, PlacementArgs,
1184                             OperatorNew, UnusedResult, /*Diagnose*/ false);
1185 
1186   // [dcl.fct.def.coroutine]/7
1187   // "If no matching function is found, overload resolution is performed again
1188   // on a function call created by passing just the amount of space required as
1189   // an argument of type std::size_t."
1190   if (!OperatorNew && !PlacementArgs.empty()) {
1191     PlacementArgs.clear();
1192     S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Class,
1193                               /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1194                               /*isArray*/ false, PassAlignment, PlacementArgs,
1195                               OperatorNew, UnusedResult, /*Diagnose*/ false);
1196   }
1197 
1198   // [dcl.fct.def.coroutine]/7
1199   // "The allocation function’s name is looked up in the scope of P. If this
1200   // lookup fails, the allocation function’s name is looked up in the global
1201   // scope."
1202   if (!OperatorNew) {
1203     S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Global,
1204                               /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1205                               /*isArray*/ false, PassAlignment, PlacementArgs,
1206                               OperatorNew, UnusedResult);
1207   }
1208 
1209   bool IsGlobalOverload =
1210       OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext());
1211   // If we didn't find a class-local new declaration and non-throwing new
1212   // was is required then we need to lookup the non-throwing global operator
1213   // instead.
1214   if (RequiresNoThrowAlloc && (!OperatorNew || IsGlobalOverload)) {
1215     auto *StdNoThrow = buildStdNoThrowDeclRef(S, Loc);
1216     if (!StdNoThrow)
1217       return false;
1218     PlacementArgs = {StdNoThrow};
1219     OperatorNew = nullptr;
1220     S.FindAllocationFunctions(Loc, SourceRange(), /*NewScope*/ Sema::AFS_Both,
1221                               /*DeleteScope*/ Sema::AFS_Both, PromiseType,
1222                               /*isArray*/ false, PassAlignment, PlacementArgs,
1223                               OperatorNew, UnusedResult);
1224   }
1225 
1226   if (!OperatorNew)
1227     return false;
1228 
1229   if (RequiresNoThrowAlloc) {
1230     const auto *FT = OperatorNew->getType()->getAs<FunctionProtoType>();
1231     if (!FT->isNothrow(/*ResultIfDependent*/ false)) {
1232       S.Diag(OperatorNew->getLocation(),
1233              diag::err_coroutine_promise_new_requires_nothrow)
1234           << OperatorNew;
1235       S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
1236           << OperatorNew;
1237       return false;
1238     }
1239   }
1240 
1241   if ((OperatorDelete = findDeleteForPromise(S, Loc, PromiseType)) == nullptr)
1242     return false;
1243 
1244   Expr *FramePtr =
1245       buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_frame, {});
1246 
1247   Expr *FrameSize =
1248       buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_size, {});
1249 
1250   // Make new call.
1251 
1252   ExprResult NewRef =
1253       S.BuildDeclRefExpr(OperatorNew, OperatorNew->getType(), VK_LValue, Loc);
1254   if (NewRef.isInvalid())
1255     return false;
1256 
1257   SmallVector<Expr *, 2> NewArgs(1, FrameSize);
1258   for (auto Arg : PlacementArgs)
1259     NewArgs.push_back(Arg);
1260 
1261   ExprResult NewExpr =
1262       S.BuildCallExpr(S.getCurScope(), NewRef.get(), Loc, NewArgs, Loc);
1263   NewExpr = S.ActOnFinishFullExpr(NewExpr.get(), /*DiscardedValue*/ false);
1264   if (NewExpr.isInvalid())
1265     return false;
1266 
1267   // Make delete call.
1268 
1269   QualType OpDeleteQualType = OperatorDelete->getType();
1270 
1271   ExprResult DeleteRef =
1272       S.BuildDeclRefExpr(OperatorDelete, OpDeleteQualType, VK_LValue, Loc);
1273   if (DeleteRef.isInvalid())
1274     return false;
1275 
1276   Expr *CoroFree =
1277       buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_free, {FramePtr});
1278 
1279   SmallVector<Expr *, 2> DeleteArgs{CoroFree};
1280 
1281   // Check if we need to pass the size.
1282   const auto *OpDeleteType =
1283       OpDeleteQualType.getTypePtr()->getAs<FunctionProtoType>();
1284   if (OpDeleteType->getNumParams() > 1)
1285     DeleteArgs.push_back(FrameSize);
1286 
1287   ExprResult DeleteExpr =
1288       S.BuildCallExpr(S.getCurScope(), DeleteRef.get(), Loc, DeleteArgs, Loc);
1289   DeleteExpr =
1290       S.ActOnFinishFullExpr(DeleteExpr.get(), /*DiscardedValue*/ false);
1291   if (DeleteExpr.isInvalid())
1292     return false;
1293 
1294   this->Allocate = NewExpr.get();
1295   this->Deallocate = DeleteExpr.get();
1296 
1297   return true;
1298 }
1299 
/// Build the statement run when control flows off the end of the coroutine
/// body, and store it in OnFallthrough.
///
/// If the promise class declares return_void, this is the equivalent of
/// 'co_return;' (built with BuildCoreturnStmt). If only return_value is
/// declared, no statement is built and OnFallthrough is set to null.
/// Declaring both, or neither, is diagnosed and returns false.
bool CoroutineStmtBuilder::makeOnFallthrough() {
  assert(!IsPromiseDependentType &&
         "cannot make statement while the promise type is dependent");

  // [dcl.fct.def.coroutine]/4
  // The unqualified-ids 'return_void' and 'return_value' are looked up in
  // the scope of class P. If both are found, the program is ill-formed.
  bool HasRVoid, HasRValue;
  LookupResult LRVoid =
      lookupMember(S, "return_void", PromiseRecordDecl, Loc, HasRVoid);
  LookupResult LRValue =
      lookupMember(S, "return_value", PromiseRecordDecl, Loc, HasRValue);

  StmtResult Fallthrough;
  if (HasRVoid && HasRValue) {
    // FIXME Improve this diagnostic
    S.Diag(FD.getLocation(),
           diag::err_coroutine_promise_incompatible_return_functions)
        << PromiseRecordDecl;
    S.Diag(LRVoid.getRepresentativeDecl()->getLocation(),
           diag::note_member_first_declared_here)
        << LRVoid.getLookupName();
    S.Diag(LRValue.getRepresentativeDecl()->getLocation(),
           diag::note_member_first_declared_here)
        << LRValue.getLookupName();
    return false;
  } else if (!HasRVoid && !HasRValue) {
    // FIXME: The PDTS currently specifies this case as UB, not ill-formed.
    // However, we still diagnose this as an error until the PDTS is fixed.
    S.Diag(FD.getLocation(),
           diag::err_coroutine_promise_requires_return_function)
        << PromiseRecordDecl;
    S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
        << PromiseRecordDecl;
    return false;
  } else if (HasRVoid) {
    // If the unqualified-id return_void is found, flowing off the end of a
    // coroutine is equivalent to a co_return with no operand. Otherwise,
    // flowing off the end of a coroutine results in undefined behavior.
    Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr,
                                      /*IsImplicit*/false);
    Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
    if (Fallthrough.isInvalid())
      return false;
  }

  // Only return_value was found: Fallthrough is still empty, so no
  // fallthrough statement is recorded.
  this->OnFallthrough = Fallthrough.get();
  return true;
}
1348 }
1349 
1350 bool CoroutineStmtBuilder::makeOnException() {
1351   // Try to form 'p.unhandled_exception();'
1352   assert(!IsPromiseDependentType &&
1353          "cannot make statement while the promise type is dependent");
1354 
1355   const bool RequireUnhandledException = S.getLangOpts().CXXExceptions;
1356 
1357   if (!lookupMember(S, "unhandled_exception", PromiseRecordDecl, Loc)) {
1358     auto DiagID =
1359         RequireUnhandledException
1360             ? diag::err_coroutine_promise_unhandled_exception_required
1361             : diag::
1362                   warn_coroutine_promise_unhandled_exception_required_with_exceptions;
1363     S.Diag(Loc, DiagID) << PromiseRecordDecl;
1364     S.Diag(PromiseRecordDecl->getLocation(), diag::note_defined_here)
1365         << PromiseRecordDecl;
1366     return !RequireUnhandledException;
1367   }
1368 
1369   // If exceptions are disabled, don't try to build OnException.
1370   if (!S.getLangOpts().CXXExceptions)
1371     return true;
1372 
1373   ExprResult UnhandledException = buildPromiseCall(S, Fn.CoroutinePromise, Loc,
1374                                                    "unhandled_exception", None);
1375   UnhandledException = S.ActOnFinishFullExpr(UnhandledException.get(), Loc,
1376                                              /*DiscardedValue*/ false);
1377   if (UnhandledException.isInvalid())
1378     return false;
1379 
1380   // Since the body of the coroutine will be wrapped in try-catch, it will
1381   // be incompatible with SEH __try if present in a function.
1382   if (!S.getLangOpts().Borland && Fn.FirstSEHTryLoc.isValid()) {
1383     S.Diag(Fn.FirstSEHTryLoc, diag::err_seh_in_a_coroutine_with_cxx_exceptions);
1384     S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1385         << Fn.getFirstCoroutineStmtKeyword();
1386     return false;
1387   }
1388 
1389   this->OnException = UnhandledException.get();
1390   return true;
1391 }
1392 
1393 bool CoroutineStmtBuilder::makeReturnObject() {
1394   // Build implicit 'p.get_return_object()' expression and form initialization
1395   // of return type from it.
1396   ExprResult ReturnObject =
1397       buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None);
1398   if (ReturnObject.isInvalid())
1399     return false;
1400 
1401   this->ReturnValue = ReturnObject.get();
1402   return true;
1403 }
1404 
1405 static void noteMemberDeclaredHere(Sema &S, Expr *E, FunctionScopeInfo &Fn) {
1406   if (auto *MbrRef = dyn_cast<CXXMemberCallExpr>(E)) {
1407     auto *MethodDecl = MbrRef->getMethodDecl();
1408     S.Diag(MethodDecl->getLocation(), diag::note_member_declared_here)
1409         << MethodDecl;
1410   }
1411   S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
1412       << Fn.getFirstCoroutineStmtKeyword();
1413 }
1414 
1415 bool CoroutineStmtBuilder::makeGroDeclAndReturnStmt() {
1416   assert(!IsPromiseDependentType &&
1417          "cannot make statement while the promise type is dependent");
1418   assert(this->ReturnValue && "ReturnValue must be already formed");
1419 
1420   QualType const GroType = this->ReturnValue->getType();
1421   assert(!GroType->isDependentType() &&
1422          "get_return_object type must no longer be dependent");
1423 
1424   QualType const FnRetType = FD.getReturnType();
1425   assert(!FnRetType->isDependentType() &&
1426          "get_return_object type must no longer be dependent");
1427 
1428   if (FnRetType->isVoidType()) {
1429     ExprResult Res =
1430         S.ActOnFinishFullExpr(this->ReturnValue, Loc, /*DiscardedValue*/ false);
1431     if (Res.isInvalid())
1432       return false;
1433 
1434     this->ResultDecl = Res.get();
1435     return true;
1436   }
1437 
1438   if (GroType->isVoidType()) {
1439     // Trigger a nice error message.
1440     InitializedEntity Entity =
1441         InitializedEntity::InitializeResult(Loc, FnRetType, false);
1442     S.PerformMoveOrCopyInitialization(Entity, nullptr, FnRetType, ReturnValue);
1443     noteMemberDeclaredHere(S, ReturnValue, Fn);
1444     return false;
1445   }
1446 
1447   auto *GroDecl = VarDecl::Create(
1448       S.Context, &FD, FD.getLocation(), FD.getLocation(),
1449       &S.PP.getIdentifierTable().get("__coro_gro"), GroType,
1450       S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None);
1451 
1452   S.CheckVariableDeclarationType(GroDecl);
1453   if (GroDecl->isInvalidDecl())
1454     return false;
1455 
1456   InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl);
1457   ExprResult Res = S.PerformMoveOrCopyInitialization(Entity, nullptr, GroType,
1458                                                      this->ReturnValue);
1459   if (Res.isInvalid())
1460     return false;
1461 
1462   Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false);
1463   if (Res.isInvalid())
1464     return false;
1465 
1466   S.AddInitializerToDecl(GroDecl, Res.get(),
1467                          /*DirectInit=*/false);
1468 
1469   S.FinalizeDeclaration(GroDecl);
1470 
1471   // Form a declaration statement for the return declaration, so that AST
1472   // visitors can more easily find it.
1473   StmtResult GroDeclStmt =
1474       S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc);
1475   if (GroDeclStmt.isInvalid())
1476     return false;
1477 
1478   this->ResultDecl = GroDeclStmt.get();
1479 
1480   ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc);
1481   if (declRef.isInvalid())
1482     return false;
1483 
1484   StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get());
1485   if (ReturnStmt.isInvalid()) {
1486     noteMemberDeclaredHere(S, ReturnValue, Fn);
1487     return false;
1488   }
1489   if (cast<clang::ReturnStmt>(ReturnStmt.get())->getNRVOCandidate() == GroDecl)
1490     GroDecl->setNRVOVariable(true);
1491 
1492   this->ReturnStmt = ReturnStmt.get();
1493   return true;
1494 }
1495 
1496 // Create a static_cast\<T&&>(expr).
1497 static Expr *castForMoving(Sema &S, Expr *E, QualType T = QualType()) {
1498   if (T.isNull())
1499     T = E->getType();
1500   QualType TargetType = S.BuildReferenceType(
1501       T, /*SpelledAsLValue*/ false, SourceLocation(), DeclarationName());
1502   SourceLocation ExprLoc = E->getBeginLoc();
1503   TypeSourceInfo *TargetLoc =
1504       S.Context.getTrivialTypeSourceInfo(TargetType, ExprLoc);
1505 
1506   return S
1507       .BuildCXXNamedCast(ExprLoc, tok::kw_static_cast, TargetLoc, E,
1508                          SourceRange(ExprLoc, ExprLoc), E->getSourceRange())
1509       .get();
1510 }
1511 
1512 /// Build a variable declaration for move parameter.
1513 static VarDecl *buildVarDecl(Sema &S, SourceLocation Loc, QualType Type,
1514                              IdentifierInfo *II) {
1515   TypeSourceInfo *TInfo = S.Context.getTrivialTypeSourceInfo(Type, Loc);
1516   VarDecl *Decl = VarDecl::Create(S.Context, S.CurContext, Loc, Loc, II, Type,
1517                                   TInfo, SC_None);
1518   Decl->setImplicit();
1519   return Decl;
1520 }
1521 
1522 // Build statements that move coroutine function parameters to the coroutine
1523 // frame, and store them on the function scope info.
1524 bool Sema::buildCoroutineParameterMoves(SourceLocation Loc) {
1525   assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
1526   auto *FD = cast<FunctionDecl>(CurContext);
1527 
1528   auto *ScopeInfo = getCurFunction();
1529   assert(ScopeInfo->CoroutineParameterMoves.empty() &&
1530          "Should not build parameter moves twice");
1531 
1532   for (auto *PD : FD->parameters()) {
1533     if (PD->getType()->isDependentType())
1534       continue;
1535 
1536     ExprResult PDRefExpr =
1537         BuildDeclRefExpr(PD, PD->getType().getNonReferenceType(),
1538                          ExprValueKind::VK_LValue, Loc); // FIXME: scope?
1539     if (PDRefExpr.isInvalid())
1540       return false;
1541 
1542     Expr *CExpr = nullptr;
1543     if (PD->getType()->getAsCXXRecordDecl() ||
1544         PD->getType()->isRValueReferenceType())
1545       CExpr = castForMoving(*this, PDRefExpr.get());
1546     else
1547       CExpr = PDRefExpr.get();
1548 
1549     auto D = buildVarDecl(*this, Loc, PD->getType(), PD->getIdentifier());
1550     AddInitializerToDecl(D, CExpr, /*DirectInit=*/true);
1551 
1552     // Convert decl to a statement.
1553     StmtResult Stmt = ActOnDeclStmt(ConvertDeclToDeclGroup(D), Loc, Loc);
1554     if (Stmt.isInvalid())
1555       return false;
1556 
1557     ScopeInfo->CoroutineParameterMoves.insert(std::make_pair(PD, Stmt.get()));
1558   }
1559   return true;
1560 }
1561 
1562 StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) {
1563   CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args);
1564   if (!Res)
1565     return StmtError();
1566   return Res;
1567 }
1568 
1569 ClassTemplateDecl *Sema::lookupCoroutineTraits(SourceLocation KwLoc,
1570                                                SourceLocation FuncLoc) {
1571   if (!StdCoroutineTraitsCache) {
1572     if (auto StdExp = lookupStdExperimentalNamespace()) {
1573       LookupResult Result(*this,
1574                           &PP.getIdentifierTable().get("coroutine_traits"),
1575                           FuncLoc, LookupOrdinaryName);
1576       if (!LookupQualifiedName(Result, StdExp)) {
1577         Diag(KwLoc, diag::err_implied_coroutine_type_not_found)
1578             << "std::experimental::coroutine_traits";
1579         return nullptr;
1580       }
1581       if (!(StdCoroutineTraitsCache =
1582                 Result.getAsSingle<ClassTemplateDecl>())) {
1583         Result.suppressDiagnostics();
1584         NamedDecl *Found = *Result.begin();
1585         Diag(Found->getLocation(), diag::err_malformed_std_coroutine_traits);
1586         return nullptr;
1587       }
1588     }
1589   }
1590   return StdCoroutineTraitsCache;
1591 }
1592