xref: /freebsd/contrib/llvm-project/clang/lib/Sema/SemaOpenACCAtomic.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //== SemaOpenACCAtomic.cpp - Semantic Analysis for OpenACC Atomic Construct===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file implements semantic analysis for the OpenACC atomic construct.
10 ///
11 //===----------------------------------------------------------------------===//
12 
13 #include "clang/AST/ExprCXX.h"
14 #include "clang/Basic/DiagnosticSema.h"
15 #include "clang/Sema/SemaOpenACC.h"
16 
17 #include <optional>
18 
19 using namespace clang;
20 
21 namespace {
22 
23 class AtomicOperandChecker {
24   SemaOpenACC &SemaRef;
25   OpenACCAtomicKind AtKind;
26   SourceLocation AtomicDirLoc;
27   StmtResult AssocStmt;
28 
29   // Do a diagnostic, which sets the correct error, then displays passed note.
DiagnoseInvalidAtomic(SourceLocation Loc,PartialDiagnostic NoteDiag)30   bool DiagnoseInvalidAtomic(SourceLocation Loc, PartialDiagnostic NoteDiag) {
31     SemaRef.Diag(AtomicDirLoc, diag::err_acc_invalid_atomic)
32         << (AtKind != OpenACCAtomicKind::None) << AtKind;
33     SemaRef.Diag(Loc, NoteDiag);
34     return true;
35   }
36 
37   // Create a replacement recovery expr in case we find an error here.  This
38   // allows us to ignore this during template instantiation so we only get a
39   // single error.
getRecoveryExpr()40   StmtResult getRecoveryExpr() {
41     if (!AssocStmt.isUsable())
42       return AssocStmt;
43 
44     if (!SemaRef.getASTContext().getLangOpts().RecoveryAST)
45       return StmtError();
46 
47     Expr *E = dyn_cast<Expr>(AssocStmt.get());
48     QualType T = E ? E->getType() : SemaRef.getASTContext().DependentTy;
49 
50     return RecoveryExpr::Create(SemaRef.getASTContext(), T,
51                                 AssocStmt.get()->getBeginLoc(),
52                                 AssocStmt.get()->getEndLoc(),
53                                 E ? ArrayRef<Expr *>{E} : ArrayRef<Expr *>{});
54   }
55 
56   // OpenACC 3.3 2.12: 'expr' is an expression with scalar type.
CheckOperandExpr(const Expr * E,PartialDiagnostic PD)57   bool CheckOperandExpr(const Expr *E, PartialDiagnostic PD) {
58     QualType ExprTy = E->getType();
59 
60     // Scalar allowed, plus we allow instantiation dependent to support
61     // templates.
62     if (ExprTy->isInstantiationDependentType() || ExprTy->isScalarType())
63       return false;
64 
65     return DiagnoseInvalidAtomic(E->getExprLoc(),
66                                  PD << diag::OACCLValScalar::Scalar << ExprTy);
67   }
68 
69   // OpenACC 3.3 2.12: 'x' and 'v' (as applicable) are boht l-value expressoins
70   // with scalar type.
CheckOperandVariable(const Expr * E,PartialDiagnostic PD)71   bool CheckOperandVariable(const Expr *E, PartialDiagnostic PD) {
72     if (CheckOperandExpr(E, PD))
73       return true;
74 
75     if (E->isLValue())
76       return false;
77 
78     return DiagnoseInvalidAtomic(E->getExprLoc(),
79                                  PD << diag::OACCLValScalar::LVal);
80   }
81 
RequireExpr(Stmt * Stmt,PartialDiagnostic ExpectedNote)82   Expr *RequireExpr(Stmt *Stmt, PartialDiagnostic ExpectedNote) {
83     if (Expr *E = dyn_cast<Expr>(Stmt))
84       return E->IgnoreImpCasts();
85 
86     DiagnoseInvalidAtomic(Stmt->getBeginLoc(), ExpectedNote);
87     return nullptr;
88   }
89 
90   // A struct to hold the return the inner components of any operands, which
91   // allows for compound checking.
92   struct BinaryOpInfo {
93     const Expr *FoundExpr = nullptr;
94     const Expr *LHS = nullptr;
95     const Expr *RHS = nullptr;
96     BinaryOperatorKind Operator;
97   };
98 
99   struct UnaryOpInfo {
100     const Expr *FoundExpr = nullptr;
101     const Expr *SubExpr = nullptr;
102     UnaryOperatorKind Operator;
103 
IsIncrementOp__anonb915406b0111::AtomicOperandChecker::UnaryOpInfo104     bool IsIncrementOp() {
105       return Operator == UO_PostInc || Operator == UO_PreInc;
106     }
107   };
108 
GetUnaryOperatorInfo(const Expr * E)109   std::optional<UnaryOpInfo> GetUnaryOperatorInfo(const Expr *E) {
110     // If this is a simple unary operator, just return its details.
111     if (const auto *UO = dyn_cast<UnaryOperator>(E))
112       return UnaryOpInfo{UO, UO->getSubExpr()->IgnoreImpCasts(),
113                          UO->getOpcode()};
114 
115     // This might be an overloaded operator or a dependent context, so make sure
116     // we can get as many details out of this as we can.
117     if (const auto *OpCall = dyn_cast<CXXOperatorCallExpr>(E)) {
118       UnaryOpInfo Inf;
119       Inf.FoundExpr = OpCall;
120 
121       switch (OpCall->getOperator()) {
122       default:
123         return std::nullopt;
124       case OO_PlusPlus:
125         Inf.Operator = OpCall->getNumArgs() == 1 ? UO_PreInc : UO_PostInc;
126         break;
127       case OO_MinusMinus:
128         Inf.Operator = OpCall->getNumArgs() == 1 ? UO_PreDec : UO_PostDec;
129         break;
130       case OO_Amp:
131         Inf.Operator = UO_AddrOf;
132         break;
133       case OO_Star:
134         Inf.Operator = UO_Deref;
135         break;
136       case OO_Plus:
137         Inf.Operator = UO_Plus;
138         break;
139       case OO_Minus:
140         Inf.Operator = UO_Minus;
141         break;
142       case OO_Tilde:
143         Inf.Operator = UO_Not;
144         break;
145       case OO_Exclaim:
146         Inf.Operator = UO_LNot;
147         break;
148       case OO_Coawait:
149         Inf.Operator = UO_Coawait;
150         break;
151       }
152 
153       // Some of the above can be both binary and unary operations, so make sure
154       // we get the right one.
155       if (Inf.Operator != UO_PostInc && Inf.Operator != UO_PostDec &&
156           OpCall->getNumArgs() != 1)
157         return std::nullopt;
158 
159       Inf.SubExpr = OpCall->getArg(0);
160       return Inf;
161     }
162     return std::nullopt;
163   }
164 
165   // Get a normalized version of a binary operator.
GetBinaryOperatorInfo(const Expr * E)166   std::optional<BinaryOpInfo> GetBinaryOperatorInfo(const Expr *E) {
167     if (const auto *BO = dyn_cast<BinaryOperator>(E))
168       return BinaryOpInfo{BO, BO->getLHS()->IgnoreImpCasts(),
169                           BO->getRHS()->IgnoreImpCasts(), BO->getOpcode()};
170 
171     // In case this is an operator-call, which allows us to support overloaded
172     // operators and dependent expression.
173     if (const auto *OpCall = dyn_cast<CXXOperatorCallExpr>(E)) {
174       BinaryOpInfo Inf;
175       Inf.FoundExpr = OpCall;
176 
177       switch (OpCall->getOperator()) {
178       default:
179         return std::nullopt;
180       case OO_Plus:
181         Inf.Operator = BO_Add;
182         break;
183       case OO_Minus:
184         Inf.Operator = BO_Sub;
185         break;
186       case OO_Star:
187         Inf.Operator = BO_Mul;
188         break;
189       case OO_Slash:
190         Inf.Operator = BO_Div;
191         break;
192       case OO_Percent:
193         Inf.Operator = BO_Rem;
194         break;
195       case OO_Caret:
196         Inf.Operator = BO_Xor;
197         break;
198       case OO_Amp:
199         Inf.Operator = BO_And;
200         break;
201       case OO_Pipe:
202         Inf.Operator = BO_Or;
203         break;
204       case OO_Equal:
205         Inf.Operator = BO_Assign;
206         break;
207       case OO_Spaceship:
208         Inf.Operator = BO_Cmp;
209         break;
210       case OO_Less:
211         Inf.Operator = BO_LT;
212         break;
213       case OO_Greater:
214         Inf.Operator = BO_GT;
215         break;
216       case OO_PlusEqual:
217         Inf.Operator = BO_AddAssign;
218         break;
219       case OO_MinusEqual:
220         Inf.Operator = BO_SubAssign;
221         break;
222       case OO_StarEqual:
223         Inf.Operator = BO_MulAssign;
224         break;
225       case OO_SlashEqual:
226         Inf.Operator = BO_DivAssign;
227         break;
228       case OO_PercentEqual:
229         Inf.Operator = BO_RemAssign;
230         break;
231       case OO_CaretEqual:
232         Inf.Operator = BO_XorAssign;
233         break;
234       case OO_AmpEqual:
235         Inf.Operator = BO_AndAssign;
236         break;
237       case OO_PipeEqual:
238         Inf.Operator = BO_OrAssign;
239         break;
240       case OO_LessLess:
241         Inf.Operator = BO_Shl;
242         break;
243       case OO_GreaterGreater:
244         Inf.Operator = BO_Shr;
245         break;
246       case OO_LessLessEqual:
247         Inf.Operator = BO_ShlAssign;
248         break;
249       case OO_GreaterGreaterEqual:
250         Inf.Operator = BO_ShrAssign;
251         break;
252       case OO_EqualEqual:
253         Inf.Operator = BO_EQ;
254         break;
255       case OO_ExclaimEqual:
256         Inf.Operator = BO_NE;
257         break;
258       case OO_LessEqual:
259         Inf.Operator = BO_LE;
260         break;
261       case OO_GreaterEqual:
262         Inf.Operator = BO_GE;
263         break;
264       case OO_AmpAmp:
265         Inf.Operator = BO_LAnd;
266         break;
267       case OO_PipePipe:
268         Inf.Operator = BO_LOr;
269         break;
270       case OO_Comma:
271         Inf.Operator = BO_Comma;
272         break;
273       case OO_ArrowStar:
274         Inf.Operator = BO_PtrMemI;
275         break;
276       }
277 
278       // This isn't a binary operator unless there are two arguments.
279       if (OpCall->getNumArgs() != 2)
280         return std::nullopt;
281 
282       // Callee is the call-operator, so we only need to extract the two
283       // arguments here.
284       Inf.LHS = OpCall->getArg(0)->IgnoreImpCasts();
285       Inf.RHS = OpCall->getArg(1)->IgnoreImpCasts();
286       return Inf;
287     }
288 
289     return std::nullopt;
290   }
291 
292   // Checks a required assignment operation, but don't check the LHS or RHS,
293   // callers have to do that here.
CheckAssignment(const Expr * E)294   std::optional<BinaryOpInfo> CheckAssignment(const Expr *E) {
295     std::optional<BinaryOpInfo> Inf = GetBinaryOperatorInfo(E);
296 
297     if (!Inf) {
298       DiagnoseInvalidAtomic(E->getExprLoc(),
299                             SemaRef.PDiag(diag::note_acc_atomic_expr_must_be)
300                                 << diag::OACCAtomicExpr::Assign);
301       return std::nullopt;
302     }
303 
304     if (Inf->Operator != BO_Assign) {
305       DiagnoseInvalidAtomic(Inf->FoundExpr->getExprLoc(),
306                             SemaRef.PDiag(diag::note_acc_atomic_expr_must_be)
307                                 << diag::OACCAtomicExpr::Assign);
308       return std::nullopt;
309     }
310 
311     // Assignment always requires an lvalue/scalar on the LHS.
312     if (CheckOperandVariable(
313             Inf->LHS, SemaRef.PDiag(diag::note_acc_atomic_operand_lvalue_scalar)
314                           << /*left=*/0 << diag::OACCAtomicOpKind::Assign))
315       return std::nullopt;
316 
317     return Inf;
318   }
319 
320   struct IDACInfo {
321     bool Failed = false;
322     enum ExprKindTy {
323       Invalid,
324       // increment/decrement ops.
325       Unary,
326       // v = x
327       SimpleAssign,
328       // x = expr
329       ExprAssign,
330       // x binop= expr
331       CompoundAssign,
332       // x = x binop expr
333       // x = expr binop x
334       AssignBinOp
335     } ExprKind;
336 
337     // The variable referred to as 'x' in all of the grammar, such that it is
338     // needed in compound statement checking of capture to check between the two
339     // expressions.
340     const Expr *X_Var = nullptr;
341 
Fail__anonb915406b0111::AtomicOperandChecker::IDACInfo342     static IDACInfo Fail() { return IDACInfo{true, Invalid, nullptr}; };
343   };
344 
345   // Helper for CheckIncDecAssignCompoundAssign, does checks for inc/dec.
CheckIncDec(UnaryOpInfo Inf)346   IDACInfo CheckIncDec(UnaryOpInfo Inf) {
347 
348     if (!UnaryOperator::isIncrementDecrementOp(Inf.Operator)) {
349       DiagnoseInvalidAtomic(
350           Inf.FoundExpr->getExprLoc(),
351           SemaRef.PDiag(diag::note_acc_atomic_unsupported_unary_operator));
352       return IDACInfo::Fail();
353     }
354     bool Failed = CheckOperandVariable(
355         Inf.SubExpr,
356         SemaRef.PDiag(diag::note_acc_atomic_operand_lvalue_scalar)
357             << /*none=*/2
358             << (Inf.IsIncrementOp() ? diag::OACCAtomicOpKind::Inc
359                                     : diag::OACCAtomicOpKind::Dec));
360     // For increment/decrements, the subexpr is the 'x' (x++, ++x, etc).
361     return IDACInfo{Failed, IDACInfo::Unary, Inf.SubExpr};
362   }
363 
364   enum class SimpleAssignKind { None, Var, Expr };
365 
366   // Check an assignment, and ensure the RHS is either x binop expr or expr
367   // binop x.
368   // If AllowSimpleAssign, also allows v = x;
CheckAssignmentWithBinOpOnRHS(BinaryOpInfo AssignInf,SimpleAssignKind SAK)369   IDACInfo CheckAssignmentWithBinOpOnRHS(BinaryOpInfo AssignInf,
370                                          SimpleAssignKind SAK) {
371     PartialDiagnostic PD =
372         SemaRef.PDiag(diag::note_acc_atomic_operand_lvalue_scalar)
373         << /*left=*/0 << diag::OACCAtomicOpKind::Assign;
374     if (CheckOperandVariable(AssignInf.LHS, PD))
375       return IDACInfo::Fail();
376 
377     std::optional<BinaryOpInfo> BinInf = GetBinaryOperatorInfo(AssignInf.RHS);
378 
379     if (!BinInf) {
380 
381       // Capture in a compound statement allows v = x assignment.  So make sure
382       // we permit that here.
383       if (SAK != SimpleAssignKind::None) {
384         PartialDiagnostic PD =
385             SemaRef.PDiag(diag::note_acc_atomic_operand_lvalue_scalar)
386             << /*right=*/1 << diag::OACCAtomicOpKind::Assign;
387         if (SAK == SimpleAssignKind::Var) {
388           // In the var version, everywhere we allow v = x;, X is the RHS.
389           return IDACInfo{CheckOperandVariable(AssignInf.RHS, PD),
390                           IDACInfo::SimpleAssign, AssignInf.RHS};
391         }
392         assert(SAK == SimpleAssignKind::Expr);
393         // In the expression version, supported by v=x; x = expr;, we need to
394         // set to the LHS here.
395         return IDACInfo{CheckOperandExpr(AssignInf.RHS, PD),
396                         IDACInfo::ExprAssign, AssignInf.LHS};
397       }
398 
399       DiagnoseInvalidAtomic(
400           AssignInf.RHS->getExprLoc(),
401           SemaRef.PDiag(diag::note_acc_atomic_expected_binop));
402 
403       return IDACInfo::Fail();
404     }
405     switch (BinInf->Operator) {
406     default:
407       DiagnoseInvalidAtomic(
408           BinInf->FoundExpr->getExprLoc(),
409           SemaRef.PDiag(diag::note_acc_atomic_unsupported_binary_operator));
410       return IDACInfo::Fail();
411       // binop is one of +, *, -, /, &, ^, |, <<, or >>
412     case BO_Add:
413     case BO_Mul:
414     case BO_Sub:
415     case BO_Div:
416     case BO_And:
417     case BO_Xor:
418     case BO_Or:
419     case BO_Shl:
420     case BO_Shr:
421       // Handle these outside of the switch.
422       break;
423     }
424 
425     llvm::FoldingSetNodeID LHS_ID, InnerLHS_ID, InnerRHS_ID;
426     AssignInf.LHS->Profile(LHS_ID, SemaRef.getASTContext(),
427                            /*Canonical=*/true);
428     BinInf->LHS->Profile(InnerLHS_ID, SemaRef.getASTContext(),
429                          /*Canonical=*/true);
430 
431     // This is X = X binop expr;
432     // Check the RHS is an expression.
433     if (LHS_ID == InnerLHS_ID)
434       return IDACInfo{
435           CheckOperandExpr(
436               BinInf->RHS,
437               SemaRef.PDiag(diag::note_acc_atomic_operand_lvalue_scalar
438                             << /*right=*/1
439                             << diag::OACCAtomicOpKind::CompoundAssign)),
440           IDACInfo::AssignBinOp, AssignInf.LHS};
441 
442     BinInf->RHS->Profile(InnerRHS_ID, SemaRef.getASTContext(),
443                          /*Canonical=*/true);
444     // This is X = expr binop X;
445     // Check the LHS is an expression
446     if (LHS_ID == InnerRHS_ID)
447       return IDACInfo{
448           CheckOperandExpr(
449               BinInf->LHS,
450               SemaRef.PDiag(diag::note_acc_atomic_operand_lvalue_scalar)
451                   << /*left=*/0 << diag::OACCAtomicOpKind::CompoundAssign),
452           IDACInfo::AssignBinOp, AssignInf.LHS};
453 
454     // If nothing matches, error out.
455     DiagnoseInvalidAtomic(BinInf->FoundExpr->getExprLoc(),
456                           SemaRef.PDiag(diag::note_acc_atomic_mismatch_operand)
457                               << const_cast<Expr *>(AssignInf.LHS)
458                               << const_cast<Expr *>(BinInf->LHS)
459                               << const_cast<Expr *>(BinInf->RHS));
460     return IDACInfo::Fail();
461   }
462 
463   // Ensures that the expression is an increment/decrement, an assignment, or a
464   // compound assignment. If its an assignment, allows the x binop expr/x binop
465   // expr syntax. If it is a compound-assignment, allows any expr on the RHS.
CheckIncDecAssignCompoundAssign(const Expr * E,SimpleAssignKind SAK)466   IDACInfo CheckIncDecAssignCompoundAssign(const Expr *E,
467                                            SimpleAssignKind SAK) {
468     std::optional<UnaryOpInfo> UInf = GetUnaryOperatorInfo(E);
469 
470     // If this is a unary operator, only increment/decrement are allowed, so get
471     // unary operator, then check everything we can.
472     if (UInf)
473       return CheckIncDec(*UInf);
474 
475     std::optional<BinaryOpInfo> BinInf = GetBinaryOperatorInfo(E);
476 
477     // Unary or binary operator were the only choices, so error here.
478     if (!BinInf) {
479       DiagnoseInvalidAtomic(E->getExprLoc(),
480                             SemaRef.PDiag(diag::note_acc_atomic_expr_must_be)
481                                 << diag::OACCAtomicExpr::UnaryCompAssign);
482       return IDACInfo::Fail();
483     }
484 
485     switch (BinInf->Operator) {
486     default:
487       DiagnoseInvalidAtomic(
488           BinInf->FoundExpr->getExprLoc(),
489           SemaRef.PDiag(
490               diag::note_acc_atomic_unsupported_compound_binary_operator));
491       return IDACInfo::Fail();
492     case BO_Assign:
493       return CheckAssignmentWithBinOpOnRHS(*BinInf, SAK);
494     case BO_AddAssign:
495     case BO_MulAssign:
496     case BO_SubAssign:
497     case BO_DivAssign:
498     case BO_AndAssign:
499     case BO_XorAssign:
500     case BO_OrAssign:
501     case BO_ShlAssign:
502     case BO_ShrAssign: {
503       PartialDiagnostic LPD =
504           SemaRef.PDiag(diag::note_acc_atomic_operand_lvalue_scalar)
505           << /*left=*/0 << diag::OACCAtomicOpKind::CompoundAssign;
506       PartialDiagnostic RPD =
507           SemaRef.PDiag(diag::note_acc_atomic_operand_lvalue_scalar)
508           << /*right=*/1 << diag::OACCAtomicOpKind::CompoundAssign;
509       // nothing to do other than check the variable expressions.
510       // success or failure
511       bool Failed = CheckOperandVariable(BinInf->LHS, LPD) ||
512                     CheckOperandExpr(BinInf->RHS, RPD);
513 
514       return IDACInfo{Failed, IDACInfo::CompoundAssign, BinInf->LHS};
515     }
516     }
517     llvm_unreachable("all binary operator kinds should be checked above");
518   }
519 
CheckRead()520   StmtResult CheckRead() {
521     Expr *AssocExpr = RequireExpr(
522         AssocStmt.get(), SemaRef.PDiag(diag::note_acc_atomic_expr_must_be)
523                              << diag::OACCAtomicExpr::Assign);
524 
525     if (!AssocExpr)
526       return getRecoveryExpr();
527 
528     std::optional<BinaryOpInfo> AssignRes = CheckAssignment(AssocExpr);
529     if (!AssignRes)
530       return getRecoveryExpr();
531 
532     PartialDiagnostic PD =
533         SemaRef.PDiag(diag::note_acc_atomic_operand_lvalue_scalar)
534         << /*right=*/1 << diag::OACCAtomicOpKind::Assign;
535 
536     // Finally, check the RHS.
537     if (CheckOperandVariable(AssignRes->RHS, PD))
538       return getRecoveryExpr();
539 
540     return AssocStmt;
541   }
542 
CheckWrite()543   StmtResult CheckWrite() {
544     Expr *AssocExpr = RequireExpr(
545         AssocStmt.get(), SemaRef.PDiag(diag::note_acc_atomic_expr_must_be)
546                              << diag::OACCAtomicExpr::Assign);
547 
548     if (!AssocExpr)
549       return getRecoveryExpr();
550 
551     std::optional<BinaryOpInfo> AssignRes = CheckAssignment(AssocExpr);
552     if (!AssignRes)
553       return getRecoveryExpr();
554 
555     PartialDiagnostic PD =
556         SemaRef.PDiag(diag::note_acc_atomic_operand_lvalue_scalar)
557         << /*right=*/1 << diag::OACCAtomicOpKind::Assign;
558 
559     // Finally, check the RHS.
560     if (CheckOperandExpr(AssignRes->RHS, PD))
561       return getRecoveryExpr();
562 
563     return AssocStmt;
564   }
565 
CheckUpdate()566   StmtResult CheckUpdate() {
567     Expr *AssocExpr = RequireExpr(
568         AssocStmt.get(), SemaRef.PDiag(diag::note_acc_atomic_expr_must_be)
569                              << diag::OACCAtomicExpr::UnaryCompAssign);
570 
571     if (!AssocExpr ||
572         CheckIncDecAssignCompoundAssign(AssocExpr, SimpleAssignKind::None)
573             .Failed)
574       return getRecoveryExpr();
575 
576     return AssocStmt;
577   }
578 
CheckVarRefsSame(IDACInfo::ExprKindTy FirstKind,const Expr * FirstX,IDACInfo::ExprKindTy SecondKind,const Expr * SecondX)579   bool CheckVarRefsSame(IDACInfo::ExprKindTy FirstKind, const Expr *FirstX,
580                         IDACInfo::ExprKindTy SecondKind, const Expr *SecondX) {
581     llvm::FoldingSetNodeID First_ID, Second_ID;
582     FirstX->Profile(First_ID, SemaRef.getASTContext(), /*Canonical=*/true);
583     SecondX->Profile(Second_ID, SemaRef.getASTContext(), /*Canonical=*/true);
584 
585     if (First_ID == Second_ID)
586       return false;
587 
588     PartialDiagnostic PD =
589         SemaRef.PDiag(diag::note_acc_atomic_mismatch_compound_operand)
590         << FirstKind << const_cast<Expr *>(FirstX) << SecondKind
591         << const_cast<Expr *>(SecondX);
592 
593     return DiagnoseInvalidAtomic(SecondX->getExprLoc(), PD);
594   }
595 
CheckCapture()596   StmtResult CheckCapture() {
597     if (const auto *CmpdStmt = dyn_cast<CompoundStmt>(AssocStmt.get())) {
598       auto *const *BodyItr = CmpdStmt->body().begin();
599       PartialDiagnostic PD = SemaRef.PDiag(diag::note_acc_atomic_expr_must_be)
600                              << diag::OACCAtomicExpr::UnaryCompAssign;
601       // If we don't have at least 1 statement, error.
602       if (BodyItr == CmpdStmt->body().end()) {
603         DiagnoseInvalidAtomic(CmpdStmt->getBeginLoc(), PD);
604         return getRecoveryExpr();
605       }
606 
607       // First Expr can be inc/dec, assign, or compound assign.
608       Expr *FirstExpr = RequireExpr(*BodyItr, PD);
609       if (!FirstExpr)
610         return getRecoveryExpr();
611 
612       IDACInfo FirstExprResults =
613           CheckIncDecAssignCompoundAssign(FirstExpr, SimpleAssignKind::Var);
614       if (FirstExprResults.Failed)
615         return getRecoveryExpr();
616 
617       ++BodyItr;
618 
619       // If we don't have second statement, error.
620       if (BodyItr == CmpdStmt->body().end()) {
621         DiagnoseInvalidAtomic(CmpdStmt->getEndLoc(), PD);
622         return getRecoveryExpr();
623       }
624 
625       Expr *SecondExpr = RequireExpr(*BodyItr, PD);
626       if (!SecondExpr)
627         return getRecoveryExpr();
628 
629       assert(FirstExprResults.ExprKind != IDACInfo::Invalid);
630 
631       switch (FirstExprResults.ExprKind) {
632       case IDACInfo::Invalid:
633       case IDACInfo::ExprAssign:
634         llvm_unreachable("Should have error'ed out by now");
635       case IDACInfo::Unary:
636       case IDACInfo::CompoundAssign:
637       case IDACInfo::AssignBinOp: {
638         // Everything but simple-assign can only be followed by a simple
639         // assignment.
640         std::optional<BinaryOpInfo> AssignRes = CheckAssignment(SecondExpr);
641         if (!AssignRes)
642           return getRecoveryExpr();
643 
644         PartialDiagnostic PD =
645             SemaRef.PDiag(diag::note_acc_atomic_operand_lvalue_scalar)
646             << /*right=*/1 << diag::OACCAtomicOpKind::Assign;
647 
648         if (CheckOperandVariable(AssignRes->RHS, PD))
649           return getRecoveryExpr();
650 
651         if (CheckVarRefsSame(FirstExprResults.ExprKind, FirstExprResults.X_Var,
652                              IDACInfo::SimpleAssign, AssignRes->RHS))
653           return getRecoveryExpr();
654         break;
655       }
656       case IDACInfo::SimpleAssign: {
657         // If the first was v = x, anything but simple expression is allowed.
658         IDACInfo SecondExprResults =
659             CheckIncDecAssignCompoundAssign(SecondExpr, SimpleAssignKind::Expr);
660         if (SecondExprResults.Failed)
661           return getRecoveryExpr();
662 
663         if (CheckVarRefsSame(FirstExprResults.ExprKind, FirstExprResults.X_Var,
664                              SecondExprResults.ExprKind,
665                              SecondExprResults.X_Var))
666           return getRecoveryExpr();
667         break;
668       }
669       }
670       ++BodyItr;
671       if (BodyItr != CmpdStmt->body().end()) {
672         DiagnoseInvalidAtomic(
673             (*BodyItr)->getBeginLoc(),
674             SemaRef.PDiag(diag::note_acc_atomic_too_many_stmts));
675         return getRecoveryExpr();
676       }
677     } else {
678       // This check doesn't need to happen if it is a compound stmt.
679       Expr *AssocExpr = RequireExpr(
680           AssocStmt.get(), SemaRef.PDiag(diag::note_acc_atomic_expr_must_be)
681                                << diag::OACCAtomicExpr::Assign);
682       if (!AssocExpr)
683         return getRecoveryExpr();
684 
685       // First, we require an assignment.
686       std::optional<BinaryOpInfo> AssignRes = CheckAssignment(AssocExpr);
687 
688       if (!AssignRes)
689         return getRecoveryExpr();
690 
691       if (CheckIncDecAssignCompoundAssign(AssignRes->RHS,
692                                           SimpleAssignKind::None)
693               .Failed)
694         return getRecoveryExpr();
695     }
696 
697     return AssocStmt;
698   }
699 
700 public:
AtomicOperandChecker(SemaOpenACC & S,OpenACCAtomicKind AtKind,SourceLocation DirLoc,StmtResult AssocStmt)701   AtomicOperandChecker(SemaOpenACC &S, OpenACCAtomicKind AtKind,
702                        SourceLocation DirLoc, StmtResult AssocStmt)
703       : SemaRef(S), AtKind(AtKind), AtomicDirLoc(DirLoc), AssocStmt(AssocStmt) {
704   }
705 
Check()706   StmtResult Check() {
707 
708     switch (AtKind) {
709     case OpenACCAtomicKind::Read:
710       return CheckRead();
711     case OpenACCAtomicKind::Write:
712       return CheckWrite();
713     case OpenACCAtomicKind::None:
714     case OpenACCAtomicKind::Update:
715       return CheckUpdate();
716     case OpenACCAtomicKind::Capture:
717       return CheckCapture();
718     }
719     llvm_unreachable("Unhandled atomic kind?");
720   }
721 };
722 } // namespace
723 
CheckAtomicAssociatedStmt(SourceLocation AtomicDirLoc,OpenACCAtomicKind AtKind,StmtResult AssocStmt)724 StmtResult SemaOpenACC::CheckAtomicAssociatedStmt(SourceLocation AtomicDirLoc,
725                                                   OpenACCAtomicKind AtKind,
726                                                   StmtResult AssocStmt) {
727   if (!AssocStmt.isUsable())
728     return AssocStmt;
729 
730   if (isa<RecoveryExpr>(AssocStmt.get()))
731     return AssocStmt;
732 
733   AtomicOperandChecker Checker{*this, AtKind, AtomicDirLoc, AssocStmt};
734   return Checker.Check();
735 }
736