xref: /freebsd/contrib/llvm-project/clang/lib/CIR/CodeGen/CIRGenStmt.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Emit Stmt nodes as CIR code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CIRGenBuilder.h"
14 #include "CIRGenFunction.h"
15 
16 #include "mlir/IR/Builders.h"
17 #include "clang/AST/ExprCXX.h"
18 #include "clang/AST/Stmt.h"
19 #include "clang/AST/StmtOpenACC.h"
20 #include "clang/CIR/MissingFeatures.h"
21 
22 using namespace clang;
23 using namespace clang::CIRGen;
24 using namespace cir;
25 
emitCompoundStmtWithoutScope(const CompoundStmt & s)26 void CIRGenFunction::emitCompoundStmtWithoutScope(const CompoundStmt &s) {
27   for (auto *curStmt : s.body()) {
28     if (emitStmt(curStmt, /*useCurrentScope=*/false).failed())
29       getCIRGenModule().errorNYI(curStmt->getSourceRange(),
30                                  std::string("emitCompoundStmtWithoutScope: ") +
31                                      curStmt->getStmtClassName());
32   }
33 }
34 
emitCompoundStmt(const CompoundStmt & s)35 void CIRGenFunction::emitCompoundStmt(const CompoundStmt &s) {
36   mlir::Location scopeLoc = getLoc(s.getSourceRange());
37   mlir::OpBuilder::InsertPoint scopeInsPt;
38   builder.create<cir::ScopeOp>(
39       scopeLoc, [&](mlir::OpBuilder &b, mlir::Type &type, mlir::Location loc) {
40         scopeInsPt = b.saveInsertionPoint();
41       });
42   {
43     mlir::OpBuilder::InsertionGuard guard(builder);
44     builder.restoreInsertionPoint(scopeInsPt);
45     LexicalScope lexScope(*this, scopeLoc, builder.getInsertionBlock());
46     emitCompoundStmtWithoutScope(s);
47   }
48 }
49 
emitStopPoint(const Stmt * s)50 void CIRGenFunction::emitStopPoint(const Stmt *s) {
51   assert(!cir::MissingFeatures::generateDebugInfo());
52 }
53 
54 // Build CIR for a statement. useCurrentScope should be true if no new scopes
55 // need to be created when finding a compound statement.
emitStmt(const Stmt * s,bool useCurrentScope,ArrayRef<const Attr * > attr)56 mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
57                                              bool useCurrentScope,
58                                              ArrayRef<const Attr *> attr) {
59   if (mlir::succeeded(emitSimpleStmt(s, useCurrentScope)))
60     return mlir::success();
61 
62   switch (s->getStmtClass()) {
63   case Stmt::NoStmtClass:
64   case Stmt::CXXCatchStmtClass:
65   case Stmt::SEHExceptStmtClass:
66   case Stmt::SEHFinallyStmtClass:
67   case Stmt::MSDependentExistsStmtClass:
68     llvm_unreachable("invalid statement class to emit generically");
69   case Stmt::BreakStmtClass:
70   case Stmt::NullStmtClass:
71   case Stmt::CompoundStmtClass:
72   case Stmt::ContinueStmtClass:
73   case Stmt::DeclStmtClass:
74   case Stmt::ReturnStmtClass:
75     llvm_unreachable("should have emitted these statements as simple");
76 
77 #define STMT(Type, Base)
78 #define ABSTRACT_STMT(Op)
79 #define EXPR(Type, Base) case Stmt::Type##Class:
80 #include "clang/AST/StmtNodes.inc"
81     {
82       // Remember the block we came in on.
83       mlir::Block *incoming = builder.getInsertionBlock();
84       assert(incoming && "expression emission must have an insertion point");
85 
86       emitIgnoredExpr(cast<Expr>(s));
87 
88       mlir::Block *outgoing = builder.getInsertionBlock();
89       assert(outgoing && "expression emission cleared block!");
90       return mlir::success();
91     }
92   case Stmt::IfStmtClass:
93     return emitIfStmt(cast<IfStmt>(*s));
94   case Stmt::SwitchStmtClass:
95     return emitSwitchStmt(cast<SwitchStmt>(*s));
96   case Stmt::ForStmtClass:
97     return emitForStmt(cast<ForStmt>(*s));
98   case Stmt::WhileStmtClass:
99     return emitWhileStmt(cast<WhileStmt>(*s));
100   case Stmt::DoStmtClass:
101     return emitDoStmt(cast<DoStmt>(*s));
102   case Stmt::CXXForRangeStmtClass:
103     return emitCXXForRangeStmt(cast<CXXForRangeStmt>(*s), attr);
104   case Stmt::OpenACCComputeConstructClass:
105     return emitOpenACCComputeConstruct(cast<OpenACCComputeConstruct>(*s));
106   case Stmt::OpenACCLoopConstructClass:
107     return emitOpenACCLoopConstruct(cast<OpenACCLoopConstruct>(*s));
108   case Stmt::OpenACCCombinedConstructClass:
109     return emitOpenACCCombinedConstruct(cast<OpenACCCombinedConstruct>(*s));
110   case Stmt::OpenACCDataConstructClass:
111     return emitOpenACCDataConstruct(cast<OpenACCDataConstruct>(*s));
112   case Stmt::OpenACCEnterDataConstructClass:
113     return emitOpenACCEnterDataConstruct(cast<OpenACCEnterDataConstruct>(*s));
114   case Stmt::OpenACCExitDataConstructClass:
115     return emitOpenACCExitDataConstruct(cast<OpenACCExitDataConstruct>(*s));
116   case Stmt::OpenACCHostDataConstructClass:
117     return emitOpenACCHostDataConstruct(cast<OpenACCHostDataConstruct>(*s));
118   case Stmt::OpenACCWaitConstructClass:
119     return emitOpenACCWaitConstruct(cast<OpenACCWaitConstruct>(*s));
120   case Stmt::OpenACCInitConstructClass:
121     return emitOpenACCInitConstruct(cast<OpenACCInitConstruct>(*s));
122   case Stmt::OpenACCShutdownConstructClass:
123     return emitOpenACCShutdownConstruct(cast<OpenACCShutdownConstruct>(*s));
124   case Stmt::OpenACCSetConstructClass:
125     return emitOpenACCSetConstruct(cast<OpenACCSetConstruct>(*s));
126   case Stmt::OpenACCUpdateConstructClass:
127     return emitOpenACCUpdateConstruct(cast<OpenACCUpdateConstruct>(*s));
128   case Stmt::OpenACCCacheConstructClass:
129     return emitOpenACCCacheConstruct(cast<OpenACCCacheConstruct>(*s));
130   case Stmt::OpenACCAtomicConstructClass:
131     return emitOpenACCAtomicConstruct(cast<OpenACCAtomicConstruct>(*s));
132   case Stmt::OMPScopeDirectiveClass:
133   case Stmt::OMPErrorDirectiveClass:
134   case Stmt::LabelStmtClass:
135   case Stmt::AttributedStmtClass:
136   case Stmt::GotoStmtClass:
137   case Stmt::DefaultStmtClass:
138   case Stmt::CaseStmtClass:
139   case Stmt::SEHLeaveStmtClass:
140   case Stmt::SYCLKernelCallStmtClass:
141   case Stmt::CoroutineBodyStmtClass:
142   case Stmt::CoreturnStmtClass:
143   case Stmt::CXXTryStmtClass:
144   case Stmt::IndirectGotoStmtClass:
145   case Stmt::GCCAsmStmtClass:
146   case Stmt::MSAsmStmtClass:
147   case Stmt::OMPParallelDirectiveClass:
148   case Stmt::OMPTaskwaitDirectiveClass:
149   case Stmt::OMPTaskyieldDirectiveClass:
150   case Stmt::OMPBarrierDirectiveClass:
151   case Stmt::CapturedStmtClass:
152   case Stmt::ObjCAtTryStmtClass:
153   case Stmt::ObjCAtThrowStmtClass:
154   case Stmt::ObjCAtSynchronizedStmtClass:
155   case Stmt::ObjCForCollectionStmtClass:
156   case Stmt::ObjCAutoreleasePoolStmtClass:
157   case Stmt::SEHTryStmtClass:
158   case Stmt::OMPMetaDirectiveClass:
159   case Stmt::OMPCanonicalLoopClass:
160   case Stmt::OMPSimdDirectiveClass:
161   case Stmt::OMPTileDirectiveClass:
162   case Stmt::OMPUnrollDirectiveClass:
163   case Stmt::OMPForDirectiveClass:
164   case Stmt::OMPForSimdDirectiveClass:
165   case Stmt::OMPSectionsDirectiveClass:
166   case Stmt::OMPSectionDirectiveClass:
167   case Stmt::OMPSingleDirectiveClass:
168   case Stmt::OMPMasterDirectiveClass:
169   case Stmt::OMPCriticalDirectiveClass:
170   case Stmt::OMPParallelForDirectiveClass:
171   case Stmt::OMPParallelForSimdDirectiveClass:
172   case Stmt::OMPParallelMasterDirectiveClass:
173   case Stmt::OMPParallelSectionsDirectiveClass:
174   case Stmt::OMPTaskDirectiveClass:
175   case Stmt::OMPTaskgroupDirectiveClass:
176   case Stmt::OMPFlushDirectiveClass:
177   case Stmt::OMPDepobjDirectiveClass:
178   case Stmt::OMPScanDirectiveClass:
179   case Stmt::OMPOrderedDirectiveClass:
180   case Stmt::OMPAtomicDirectiveClass:
181   case Stmt::OMPTargetDirectiveClass:
182   case Stmt::OMPTeamsDirectiveClass:
183   case Stmt::OMPCancellationPointDirectiveClass:
184   case Stmt::OMPCancelDirectiveClass:
185   case Stmt::OMPTargetDataDirectiveClass:
186   case Stmt::OMPTargetEnterDataDirectiveClass:
187   case Stmt::OMPTargetExitDataDirectiveClass:
188   case Stmt::OMPTargetParallelDirectiveClass:
189   case Stmt::OMPTargetParallelForDirectiveClass:
190   case Stmt::OMPTaskLoopDirectiveClass:
191   case Stmt::OMPTaskLoopSimdDirectiveClass:
192   case Stmt::OMPMaskedTaskLoopDirectiveClass:
193   case Stmt::OMPMaskedTaskLoopSimdDirectiveClass:
194   case Stmt::OMPMasterTaskLoopDirectiveClass:
195   case Stmt::OMPMasterTaskLoopSimdDirectiveClass:
196   case Stmt::OMPParallelGenericLoopDirectiveClass:
197   case Stmt::OMPParallelMaskedDirectiveClass:
198   case Stmt::OMPParallelMaskedTaskLoopDirectiveClass:
199   case Stmt::OMPParallelMaskedTaskLoopSimdDirectiveClass:
200   case Stmt::OMPParallelMasterTaskLoopDirectiveClass:
201   case Stmt::OMPParallelMasterTaskLoopSimdDirectiveClass:
202   case Stmt::OMPDistributeDirectiveClass:
203   case Stmt::OMPDistributeParallelForDirectiveClass:
204   case Stmt::OMPDistributeParallelForSimdDirectiveClass:
205   case Stmt::OMPDistributeSimdDirectiveClass:
206   case Stmt::OMPTargetParallelGenericLoopDirectiveClass:
207   case Stmt::OMPTargetParallelForSimdDirectiveClass:
208   case Stmt::OMPTargetSimdDirectiveClass:
209   case Stmt::OMPTargetTeamsGenericLoopDirectiveClass:
210   case Stmt::OMPTargetUpdateDirectiveClass:
211   case Stmt::OMPTeamsDistributeDirectiveClass:
212   case Stmt::OMPTeamsDistributeSimdDirectiveClass:
213   case Stmt::OMPTeamsDistributeParallelForSimdDirectiveClass:
214   case Stmt::OMPTeamsDistributeParallelForDirectiveClass:
215   case Stmt::OMPTeamsGenericLoopDirectiveClass:
216   case Stmt::OMPTargetTeamsDirectiveClass:
217   case Stmt::OMPTargetTeamsDistributeDirectiveClass:
218   case Stmt::OMPTargetTeamsDistributeParallelForDirectiveClass:
219   case Stmt::OMPTargetTeamsDistributeParallelForSimdDirectiveClass:
220   case Stmt::OMPTargetTeamsDistributeSimdDirectiveClass:
221   case Stmt::OMPInteropDirectiveClass:
222   case Stmt::OMPDispatchDirectiveClass:
223   case Stmt::OMPGenericLoopDirectiveClass:
224   case Stmt::OMPReverseDirectiveClass:
225   case Stmt::OMPInterchangeDirectiveClass:
226   case Stmt::OMPAssumeDirectiveClass:
227   case Stmt::OMPMaskedDirectiveClass:
228   case Stmt::OMPStripeDirectiveClass:
229   case Stmt::ObjCAtCatchStmtClass:
230   case Stmt::ObjCAtFinallyStmtClass:
231     cgm.errorNYI(s->getSourceRange(),
232                  std::string("emitStmt: ") + s->getStmtClassName());
233     return mlir::failure();
234   }
235 
236   llvm_unreachable("Unexpected statement class");
237 }
238 
emitSimpleStmt(const Stmt * s,bool useCurrentScope)239 mlir::LogicalResult CIRGenFunction::emitSimpleStmt(const Stmt *s,
240                                                    bool useCurrentScope) {
241   switch (s->getStmtClass()) {
242   default:
243     return mlir::failure();
244   case Stmt::DeclStmtClass:
245     return emitDeclStmt(cast<DeclStmt>(*s));
246   case Stmt::CompoundStmtClass:
247     if (useCurrentScope)
248       emitCompoundStmtWithoutScope(cast<CompoundStmt>(*s));
249     else
250       emitCompoundStmt(cast<CompoundStmt>(*s));
251     break;
252   case Stmt::ContinueStmtClass:
253     return emitContinueStmt(cast<ContinueStmt>(*s));
254 
255   // NullStmt doesn't need any handling, but we need to say we handled it.
256   case Stmt::NullStmtClass:
257     break;
258   case Stmt::CaseStmtClass:
259   case Stmt::DefaultStmtClass:
260     // If we reached here, we must not handling a switch case in the top level.
261     return emitSwitchCase(cast<SwitchCase>(*s),
262                           /*buildingTopLevelCase=*/false);
263     break;
264 
265   case Stmt::BreakStmtClass:
266     return emitBreakStmt(cast<BreakStmt>(*s));
267   case Stmt::ReturnStmtClass:
268     return emitReturnStmt(cast<ReturnStmt>(*s));
269   }
270 
271   return mlir::success();
272 }
273 
274 // Add a terminating yield on a body region if no other terminators are used.
terminateBody(CIRGenBuilderTy & builder,mlir::Region & r,mlir::Location loc)275 static void terminateBody(CIRGenBuilderTy &builder, mlir::Region &r,
276                           mlir::Location loc) {
277   if (r.empty())
278     return;
279 
280   SmallVector<mlir::Block *, 4> eraseBlocks;
281   unsigned numBlocks = r.getBlocks().size();
282   for (auto &block : r.getBlocks()) {
283     // Already cleanup after return operations, which might create
284     // empty blocks if emitted as last stmt.
285     if (numBlocks != 1 && block.empty() && block.hasNoPredecessors() &&
286         block.hasNoSuccessors())
287       eraseBlocks.push_back(&block);
288 
289     if (block.empty() ||
290         !block.back().hasTrait<mlir::OpTrait::IsTerminator>()) {
291       mlir::OpBuilder::InsertionGuard guardCase(builder);
292       builder.setInsertionPointToEnd(&block);
293       builder.createYield(loc);
294     }
295   }
296 
297   for (auto *b : eraseBlocks)
298     b->erase();
299 }
300 
emitIfStmt(const IfStmt & s)301 mlir::LogicalResult CIRGenFunction::emitIfStmt(const IfStmt &s) {
302   mlir::LogicalResult res = mlir::success();
303   // The else branch of a consteval if statement is always the only branch
304   // that can be runtime evaluated.
305   const Stmt *constevalExecuted;
306   if (s.isConsteval()) {
307     constevalExecuted = s.isNegatedConsteval() ? s.getThen() : s.getElse();
308     if (!constevalExecuted) {
309       // No runtime code execution required
310       return res;
311     }
312   }
313 
314   // C99 6.8.4.1: The first substatement is executed if the expression
315   // compares unequal to 0.  The condition must be a scalar type.
316   auto ifStmtBuilder = [&]() -> mlir::LogicalResult {
317     if (s.isConsteval())
318       return emitStmt(constevalExecuted, /*useCurrentScope=*/true);
319 
320     if (s.getInit())
321       if (emitStmt(s.getInit(), /*useCurrentScope=*/true).failed())
322         return mlir::failure();
323 
324     if (s.getConditionVariable())
325       emitDecl(*s.getConditionVariable());
326 
327     // If the condition folds to a constant and this is an 'if constexpr',
328     // we simplify it early in CIRGen to avoid emitting the full 'if'.
329     bool condConstant;
330     if (constantFoldsToBool(s.getCond(), condConstant, s.isConstexpr())) {
331       if (s.isConstexpr()) {
332         // Handle "if constexpr" explicitly here to avoid generating some
333         // ill-formed code since in CIR the "if" is no longer simplified
334         // in this lambda like in Clang but postponed to other MLIR
335         // passes.
336         if (const Stmt *executed = condConstant ? s.getThen() : s.getElse())
337           return emitStmt(executed, /*useCurrentScope=*/true);
338         // There is nothing to execute at runtime.
339         // TODO(cir): there is still an empty cir.scope generated by the caller.
340         return mlir::success();
341       }
342     }
343 
344     assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
345     assert(!cir::MissingFeatures::incrementProfileCounter());
346     return emitIfOnBoolExpr(s.getCond(), s.getThen(), s.getElse());
347   };
348 
349   // TODO: Add a new scoped symbol table.
350   // LexicalScope ConditionScope(*this, S.getCond()->getSourceRange());
351   // The if scope contains the full source range for IfStmt.
352   mlir::Location scopeLoc = getLoc(s.getSourceRange());
353   builder.create<cir::ScopeOp>(
354       scopeLoc, /*scopeBuilder=*/
355       [&](mlir::OpBuilder &b, mlir::Location loc) {
356         LexicalScope lexScope{*this, scopeLoc, builder.getInsertionBlock()};
357         res = ifStmtBuilder();
358       });
359 
360   return res;
361 }
362 
emitDeclStmt(const DeclStmt & s)363 mlir::LogicalResult CIRGenFunction::emitDeclStmt(const DeclStmt &s) {
364   assert(builder.getInsertionBlock() && "expected valid insertion point");
365 
366   for (const Decl *I : s.decls())
367     emitDecl(*I);
368 
369   return mlir::success();
370 }
371 
emitReturnStmt(const ReturnStmt & s)372 mlir::LogicalResult CIRGenFunction::emitReturnStmt(const ReturnStmt &s) {
373   mlir::Location loc = getLoc(s.getSourceRange());
374   const Expr *rv = s.getRetValue();
375 
376   if (getContext().getLangOpts().ElideConstructors && s.getNRVOCandidate() &&
377       s.getNRVOCandidate()->isNRVOVariable()) {
378     getCIRGenModule().errorNYI(s.getSourceRange(),
379                                "named return value optimization");
380   } else if (!rv) {
381     // No return expression. Do nothing.
382   } else if (rv->getType()->isVoidType()) {
383     // Make sure not to return anything, but evaluate the expression
384     // for side effects.
385     if (rv) {
386       emitAnyExpr(rv);
387     }
388   } else if (cast<FunctionDecl>(curGD.getDecl())
389                  ->getReturnType()
390                  ->isReferenceType()) {
391     // If this function returns a reference, take the address of the
392     // expression rather than the value.
393     RValue result = emitReferenceBindingToExpr(rv);
394     builder.CIRBaseBuilderTy::createStore(loc, result.getValue(), *fnRetAlloca);
395   } else {
396     mlir::Value value = nullptr;
397     switch (CIRGenFunction::getEvaluationKind(rv->getType())) {
398     case cir::TEK_Scalar:
399       value = emitScalarExpr(rv);
400       if (value) { // Change this to an assert once emitScalarExpr is complete
401         builder.CIRBaseBuilderTy::createStore(loc, value, *fnRetAlloca);
402       }
403       break;
404     default:
405       getCIRGenModule().errorNYI(s.getSourceRange(),
406                                  "non-scalar function return type");
407       break;
408     }
409   }
410 
411   auto *retBlock = curLexScope->getOrCreateRetBlock(*this, loc);
412   builder.create<cir::BrOp>(loc, retBlock);
413   builder.createBlock(builder.getBlock()->getParent());
414 
415   return mlir::success();
416 }
417 
418 mlir::LogicalResult
emitContinueStmt(const clang::ContinueStmt & s)419 CIRGenFunction::emitContinueStmt(const clang::ContinueStmt &s) {
420   builder.createContinue(getLoc(s.getContinueLoc()));
421 
422   // Insert the new block to continue codegen after the continue statement.
423   builder.createBlock(builder.getBlock()->getParent());
424 
425   return mlir::success();
426 }
427 
emitBreakStmt(const clang::BreakStmt & s)428 mlir::LogicalResult CIRGenFunction::emitBreakStmt(const clang::BreakStmt &s) {
429   builder.createBreak(getLoc(s.getBreakLoc()));
430 
431   // Insert the new block to continue codegen after the break statement.
432   builder.createBlock(builder.getBlock()->getParent());
433 
434   return mlir::success();
435 }
436 
437 template <typename T>
438 mlir::LogicalResult
emitCaseDefaultCascade(const T * stmt,mlir::Type condType,mlir::ArrayAttr value,CaseOpKind kind,bool buildingTopLevelCase)439 CIRGenFunction::emitCaseDefaultCascade(const T *stmt, mlir::Type condType,
440                                        mlir::ArrayAttr value, CaseOpKind kind,
441                                        bool buildingTopLevelCase) {
442 
443   assert((isa<CaseStmt, DefaultStmt>(stmt)) &&
444          "only case or default stmt go here");
445 
446   mlir::LogicalResult result = mlir::success();
447 
448   mlir::Location loc = getLoc(stmt->getBeginLoc());
449 
450   enum class SubStmtKind { Case, Default, Other };
451   SubStmtKind subStmtKind = SubStmtKind::Other;
452   const Stmt *sub = stmt->getSubStmt();
453 
454   mlir::OpBuilder::InsertPoint insertPoint;
455   builder.create<CaseOp>(loc, value, kind, insertPoint);
456 
457   {
458     mlir::OpBuilder::InsertionGuard guardSwitch(builder);
459     builder.restoreInsertionPoint(insertPoint);
460 
461     if (isa<DefaultStmt>(sub) && isa<CaseStmt>(stmt)) {
462       subStmtKind = SubStmtKind::Default;
463       builder.createYield(loc);
464     } else if (isa<CaseStmt>(sub) && isa<DefaultStmt, CaseStmt>(stmt)) {
465       subStmtKind = SubStmtKind::Case;
466       builder.createYield(loc);
467     } else {
468       result = emitStmt(sub, /*useCurrentScope=*/!isa<CompoundStmt>(sub));
469     }
470 
471     insertPoint = builder.saveInsertionPoint();
472   }
473 
474   // If the substmt is default stmt or case stmt, try to handle the special case
475   // to make it into the simple form. e.g.
476   //
477   //  swtich () {
478   //    case 1:
479   //    default:
480   //      ...
481   //  }
482   //
483   // we prefer generating
484   //
485   //  cir.switch() {
486   //     cir.case(equal, 1) {
487   //        cir.yield
488   //     }
489   //     cir.case(default) {
490   //        ...
491   //     }
492   //  }
493   //
494   // than
495   //
496   //  cir.switch() {
497   //     cir.case(equal, 1) {
498   //       cir.case(default) {
499   //         ...
500   //       }
501   //     }
502   //  }
503   //
504   // We don't need to revert this if we find the current switch can't be in
505   // simple form later since the conversion itself should be harmless.
506   if (subStmtKind == SubStmtKind::Case) {
507     result = emitCaseStmt(*cast<CaseStmt>(sub), condType, buildingTopLevelCase);
508   } else if (subStmtKind == SubStmtKind::Default) {
509     result = emitDefaultStmt(*cast<DefaultStmt>(sub), condType,
510                              buildingTopLevelCase);
511   } else if (buildingTopLevelCase) {
512     // If we're building a top level case, try to restore the insert point to
513     // the case we're building, then we can attach more random stmts to the
514     // case to make generating `cir.switch` operation to be a simple form.
515     builder.restoreInsertionPoint(insertPoint);
516   }
517 
518   return result;
519 }
520 
emitCaseStmt(const CaseStmt & s,mlir::Type condType,bool buildingTopLevelCase)521 mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
522                                                  mlir::Type condType,
523                                                  bool buildingTopLevelCase) {
524   cir::CaseOpKind kind;
525   mlir::ArrayAttr value;
526   llvm::APSInt intVal = s.getLHS()->EvaluateKnownConstInt(getContext());
527 
528   // If the case statement has an RHS value, it is representing a GNU
529   // case range statement, where LHS is the beginning of the range
530   // and RHS is the end of the range.
531   if (const Expr *rhs = s.getRHS()) {
532     llvm::APSInt endVal = rhs->EvaluateKnownConstInt(getContext());
533     value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal),
534                                   cir::IntAttr::get(condType, endVal)});
535     kind = cir::CaseOpKind::Range;
536   } else {
537     value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)});
538     kind = cir::CaseOpKind::Equal;
539   }
540 
541   return emitCaseDefaultCascade(&s, condType, value, kind,
542                                 buildingTopLevelCase);
543 }
544 
emitDefaultStmt(const clang::DefaultStmt & s,mlir::Type condType,bool buildingTopLevelCase)545 mlir::LogicalResult CIRGenFunction::emitDefaultStmt(const clang::DefaultStmt &s,
546                                                     mlir::Type condType,
547                                                     bool buildingTopLevelCase) {
548   return emitCaseDefaultCascade(&s, condType, builder.getArrayAttr({}),
549                                 cir::CaseOpKind::Default, buildingTopLevelCase);
550 }
551 
emitSwitchCase(const SwitchCase & s,bool buildingTopLevelCase)552 mlir::LogicalResult CIRGenFunction::emitSwitchCase(const SwitchCase &s,
553                                                    bool buildingTopLevelCase) {
554   assert(!condTypeStack.empty() &&
555          "build switch case without specifying the type of the condition");
556 
557   if (s.getStmtClass() == Stmt::CaseStmtClass)
558     return emitCaseStmt(cast<CaseStmt>(s), condTypeStack.back(),
559                         buildingTopLevelCase);
560 
561   if (s.getStmtClass() == Stmt::DefaultStmtClass)
562     return emitDefaultStmt(cast<DefaultStmt>(s), condTypeStack.back(),
563                            buildingTopLevelCase);
564 
565   llvm_unreachable("expect case or default stmt");
566 }
567 
568 mlir::LogicalResult
emitCXXForRangeStmt(const CXXForRangeStmt & s,ArrayRef<const Attr * > forAttrs)569 CIRGenFunction::emitCXXForRangeStmt(const CXXForRangeStmt &s,
570                                     ArrayRef<const Attr *> forAttrs) {
571   cir::ForOp forOp;
572 
573   // TODO(cir): pass in array of attributes.
574   auto forStmtBuilder = [&]() -> mlir::LogicalResult {
575     mlir::LogicalResult loopRes = mlir::success();
576     // Evaluate the first pieces before the loop.
577     if (s.getInit())
578       if (emitStmt(s.getInit(), /*useCurrentScope=*/true).failed())
579         return mlir::failure();
580     if (emitStmt(s.getRangeStmt(), /*useCurrentScope=*/true).failed())
581       return mlir::failure();
582     if (emitStmt(s.getBeginStmt(), /*useCurrentScope=*/true).failed())
583       return mlir::failure();
584     if (emitStmt(s.getEndStmt(), /*useCurrentScope=*/true).failed())
585       return mlir::failure();
586 
587     assert(!cir::MissingFeatures::loopInfoStack());
588     // From LLVM: if there are any cleanups between here and the loop-exit
589     // scope, create a block to stage a loop exit along.
590     // We probably already do the right thing because of ScopeOp, but make
591     // sure we handle all cases.
592     assert(!cir::MissingFeatures::requiresCleanups());
593 
594     forOp = builder.createFor(
595         getLoc(s.getSourceRange()),
596         /*condBuilder=*/
597         [&](mlir::OpBuilder &b, mlir::Location loc) {
598           assert(!cir::MissingFeatures::createProfileWeightsForLoop());
599           assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
600           mlir::Value condVal = evaluateExprAsBool(s.getCond());
601           builder.createCondition(condVal);
602         },
603         /*bodyBuilder=*/
604         [&](mlir::OpBuilder &b, mlir::Location loc) {
605           // https://en.cppreference.com/w/cpp/language/for
606           // In C++ the scope of the init-statement and the scope of
607           // statement are one and the same.
608           bool useCurrentScope = true;
609           if (emitStmt(s.getLoopVarStmt(), useCurrentScope).failed())
610             loopRes = mlir::failure();
611           if (emitStmt(s.getBody(), useCurrentScope).failed())
612             loopRes = mlir::failure();
613           emitStopPoint(&s);
614         },
615         /*stepBuilder=*/
616         [&](mlir::OpBuilder &b, mlir::Location loc) {
617           if (s.getInc())
618             if (emitStmt(s.getInc(), /*useCurrentScope=*/true).failed())
619               loopRes = mlir::failure();
620           builder.createYield(loc);
621         });
622     return loopRes;
623   };
624 
625   mlir::LogicalResult res = mlir::success();
626   mlir::Location scopeLoc = getLoc(s.getSourceRange());
627   builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/
628                                [&](mlir::OpBuilder &b, mlir::Location loc) {
629                                  // Create a cleanup scope for the condition
630                                  // variable cleanups. Logical equivalent from
631                                  // LLVM codegn for LexicalScope
632                                  // ConditionScope(*this, S.getSourceRange())...
633                                  LexicalScope lexScope{
634                                      *this, loc, builder.getInsertionBlock()};
635                                  res = forStmtBuilder();
636                                });
637 
638   if (res.failed())
639     return res;
640 
641   terminateBody(builder, forOp.getBody(), getLoc(s.getEndLoc()));
642   return mlir::success();
643 }
644 
emitForStmt(const ForStmt & s)645 mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) {
646   cir::ForOp forOp;
647 
648   // TODO: pass in an array of attributes.
649   auto forStmtBuilder = [&]() -> mlir::LogicalResult {
650     mlir::LogicalResult loopRes = mlir::success();
651     // Evaluate the first part before the loop.
652     if (s.getInit())
653       if (emitStmt(s.getInit(), /*useCurrentScope=*/true).failed())
654         return mlir::failure();
655     assert(!cir::MissingFeatures::loopInfoStack());
656     // In the classic codegen, if there are any cleanups between here and the
657     // loop-exit scope, a block is created to stage the loop exit. We probably
658     // already do the right thing because of ScopeOp, but we need more testing
659     // to be sure we handle all cases.
660     assert(!cir::MissingFeatures::requiresCleanups());
661 
662     forOp = builder.createFor(
663         getLoc(s.getSourceRange()),
664         /*condBuilder=*/
665         [&](mlir::OpBuilder &b, mlir::Location loc) {
666           assert(!cir::MissingFeatures::createProfileWeightsForLoop());
667           assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
668           mlir::Value condVal;
669           if (s.getCond()) {
670             // If the for statement has a condition scope,
671             // emit the local variable declaration.
672             if (s.getConditionVariable())
673               emitDecl(*s.getConditionVariable());
674             // C99 6.8.5p2/p4: The first substatement is executed if the
675             // expression compares unequal to 0. The condition must be a
676             // scalar type.
677             condVal = evaluateExprAsBool(s.getCond());
678           } else {
679             condVal = b.create<cir::ConstantOp>(loc, builder.getTrueAttr());
680           }
681           builder.createCondition(condVal);
682         },
683         /*bodyBuilder=*/
684         [&](mlir::OpBuilder &b, mlir::Location loc) {
685           // The scope of the for loop body is nested within the scope of the
686           // for loop's init-statement and condition.
687           if (emitStmt(s.getBody(), /*useCurrentScope=*/false).failed())
688             loopRes = mlir::failure();
689           emitStopPoint(&s);
690         },
691         /*stepBuilder=*/
692         [&](mlir::OpBuilder &b, mlir::Location loc) {
693           if (s.getInc())
694             if (emitStmt(s.getInc(), /*useCurrentScope=*/true).failed())
695               loopRes = mlir::failure();
696           builder.createYield(loc);
697         });
698     return loopRes;
699   };
700 
701   auto res = mlir::success();
702   auto scopeLoc = getLoc(s.getSourceRange());
703   builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/
704                                [&](mlir::OpBuilder &b, mlir::Location loc) {
705                                  LexicalScope lexScope{
706                                      *this, loc, builder.getInsertionBlock()};
707                                  res = forStmtBuilder();
708                                });
709 
710   if (res.failed())
711     return res;
712 
713   terminateBody(builder, forOp.getBody(), getLoc(s.getEndLoc()));
714   return mlir::success();
715 }
716 
emitDoStmt(const DoStmt & s)717 mlir::LogicalResult CIRGenFunction::emitDoStmt(const DoStmt &s) {
718   cir::DoWhileOp doWhileOp;
719 
720   // TODO: pass in array of attributes.
721   auto doStmtBuilder = [&]() -> mlir::LogicalResult {
722     mlir::LogicalResult loopRes = mlir::success();
723     assert(!cir::MissingFeatures::loopInfoStack());
724     // From LLVM: if there are any cleanups between here and the loop-exit
725     // scope, create a block to stage a loop exit along.
726     // We probably already do the right thing because of ScopeOp, but make
727     // sure we handle all cases.
728     assert(!cir::MissingFeatures::requiresCleanups());
729 
730     doWhileOp = builder.createDoWhile(
731         getLoc(s.getSourceRange()),
732         /*condBuilder=*/
733         [&](mlir::OpBuilder &b, mlir::Location loc) {
734           assert(!cir::MissingFeatures::createProfileWeightsForLoop());
735           assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
736           // C99 6.8.5p2/p4: The first substatement is executed if the
737           // expression compares unequal to 0. The condition must be a
738           // scalar type.
739           mlir::Value condVal = evaluateExprAsBool(s.getCond());
740           builder.createCondition(condVal);
741         },
742         /*bodyBuilder=*/
743         [&](mlir::OpBuilder &b, mlir::Location loc) {
744           // The scope of the do-while loop body is a nested scope.
745           if (emitStmt(s.getBody(), /*useCurrentScope=*/false).failed())
746             loopRes = mlir::failure();
747           emitStopPoint(&s);
748         });
749     return loopRes;
750   };
751 
752   mlir::LogicalResult res = mlir::success();
753   mlir::Location scopeLoc = getLoc(s.getSourceRange());
754   builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/
755                                [&](mlir::OpBuilder &b, mlir::Location loc) {
756                                  LexicalScope lexScope{
757                                      *this, loc, builder.getInsertionBlock()};
758                                  res = doStmtBuilder();
759                                });
760 
761   if (res.failed())
762     return res;
763 
764   terminateBody(builder, doWhileOp.getBody(), getLoc(s.getEndLoc()));
765   return mlir::success();
766 }
767 
emitWhileStmt(const WhileStmt & s)768 mlir::LogicalResult CIRGenFunction::emitWhileStmt(const WhileStmt &s) {
769   cir::WhileOp whileOp;
770 
771   // TODO: pass in array of attributes.
772   auto whileStmtBuilder = [&]() -> mlir::LogicalResult {
773     mlir::LogicalResult loopRes = mlir::success();
774     assert(!cir::MissingFeatures::loopInfoStack());
775     // From LLVM: if there are any cleanups between here and the loop-exit
776     // scope, create a block to stage a loop exit along.
777     // We probably already do the right thing because of ScopeOp, but make
778     // sure we handle all cases.
779     assert(!cir::MissingFeatures::requiresCleanups());
780 
781     whileOp = builder.createWhile(
782         getLoc(s.getSourceRange()),
783         /*condBuilder=*/
784         [&](mlir::OpBuilder &b, mlir::Location loc) {
785           assert(!cir::MissingFeatures::createProfileWeightsForLoop());
786           assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
787           mlir::Value condVal;
788           // If the for statement has a condition scope,
789           // emit the local variable declaration.
790           if (s.getConditionVariable())
791             emitDecl(*s.getConditionVariable());
792           // C99 6.8.5p2/p4: The first substatement is executed if the
793           // expression compares unequal to 0. The condition must be a
794           // scalar type.
795           condVal = evaluateExprAsBool(s.getCond());
796           builder.createCondition(condVal);
797         },
798         /*bodyBuilder=*/
799         [&](mlir::OpBuilder &b, mlir::Location loc) {
800           // The scope of the while loop body is a nested scope.
801           if (emitStmt(s.getBody(), /*useCurrentScope=*/false).failed())
802             loopRes = mlir::failure();
803           emitStopPoint(&s);
804         });
805     return loopRes;
806   };
807 
808   mlir::LogicalResult res = mlir::success();
809   mlir::Location scopeLoc = getLoc(s.getSourceRange());
810   builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/
811                                [&](mlir::OpBuilder &b, mlir::Location loc) {
812                                  LexicalScope lexScope{
813                                      *this, loc, builder.getInsertionBlock()};
814                                  res = whileStmtBuilder();
815                                });
816 
817   if (res.failed())
818     return res;
819 
820   terminateBody(builder, whileOp.getBody(), getLoc(s.getEndLoc()));
821   return mlir::success();
822 }
823 
emitSwitchBody(const Stmt * s)824 mlir::LogicalResult CIRGenFunction::emitSwitchBody(const Stmt *s) {
825   // It is rare but legal if the switch body is not a compound stmt. e.g.,
826   //
827   //  switch(a)
828   //    while(...) {
829   //      case1
830   //      ...
831   //      case2
832   //      ...
833   //    }
834   if (!isa<CompoundStmt>(s))
835     return emitStmt(s, /*useCurrentScope=*/true);
836 
837   auto *compoundStmt = cast<CompoundStmt>(s);
838 
839   mlir::Block *swtichBlock = builder.getBlock();
840   for (auto *c : compoundStmt->body()) {
841     if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
842       builder.setInsertionPointToEnd(swtichBlock);
843       // Reset insert point automatically, so that we can attach following
844       // random stmt to the region of previous built case op to try to make
845       // the being generated `cir.switch` to be in simple form.
846       if (mlir::failed(
847               emitSwitchCase(*switchCase, /*buildingTopLevelCase=*/true)))
848         return mlir::failure();
849 
850       continue;
851     }
852 
853     // Otherwise, just build the statements in the nearest case region.
854     if (mlir::failed(emitStmt(c, /*useCurrentScope=*/!isa<CompoundStmt>(c))))
855       return mlir::failure();
856   }
857 
858   return mlir::success();
859 }
860 
emitSwitchStmt(const clang::SwitchStmt & s)861 mlir::LogicalResult CIRGenFunction::emitSwitchStmt(const clang::SwitchStmt &s) {
862   // TODO: LLVM codegen does some early optimization to fold the condition and
863   // only emit live cases. CIR should use MLIR to achieve similar things,
864   // nothing to be done here.
865   // if (ConstantFoldsToSimpleInteger(S.getCond(), ConstantCondValue))...
866   assert(!cir::MissingFeatures::constantFoldSwitchStatement());
867 
868   SwitchOp swop;
869   auto switchStmtBuilder = [&]() -> mlir::LogicalResult {
870     if (s.getInit())
871       if (emitStmt(s.getInit(), /*useCurrentScope=*/true).failed())
872         return mlir::failure();
873 
874     if (s.getConditionVariable())
875       emitDecl(*s.getConditionVariable());
876 
877     mlir::Value condV = emitScalarExpr(s.getCond());
878 
879     // TODO: PGO and likelihood (e.g. PGO.haveRegionCounts())
880     assert(!cir::MissingFeatures::pgoUse());
881     assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
882     // TODO: if the switch has a condition wrapped by __builtin_unpredictable?
883     assert(!cir::MissingFeatures::insertBuiltinUnpredictable());
884 
885     mlir::LogicalResult res = mlir::success();
886     swop = builder.create<SwitchOp>(
887         getLoc(s.getBeginLoc()), condV,
888         /*switchBuilder=*/
889         [&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
890           curLexScope->setAsSwitch();
891 
892           condTypeStack.push_back(condV.getType());
893 
894           res = emitSwitchBody(s.getBody());
895 
896           condTypeStack.pop_back();
897         });
898 
899     return res;
900   };
901 
902   // The switch scope contains the full source range for SwitchStmt.
903   mlir::Location scopeLoc = getLoc(s.getSourceRange());
904   mlir::LogicalResult res = mlir::success();
905   builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/
906                                [&](mlir::OpBuilder &b, mlir::Location loc) {
907                                  LexicalScope lexScope{
908                                      *this, loc, builder.getInsertionBlock()};
909                                  res = switchStmtBuilder();
910                                });
911 
912   llvm::SmallVector<CaseOp> cases;
913   swop.collectCases(cases);
914   for (auto caseOp : cases)
915     terminateBody(builder, caseOp.getCaseRegion(), caseOp.getLoc());
916   terminateBody(builder, swop.getBody(), swop.getLoc());
917 
918   return res;
919 }
920