xref: /freebsd/contrib/llvm-project/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Emit OpenACC Stmt nodes as CIR code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CIRGenBuilder.h"
14 #include "CIRGenFunction.h"
15 #include "mlir/Dialect/OpenACC/OpenACC.h"
16 #include "clang/AST/OpenACCClause.h"
17 #include "clang/AST/StmtOpenACC.h"
18 
19 using namespace clang;
20 using namespace clang::CIRGen;
21 using namespace cir;
22 using namespace mlir::acc;
23 
24 template <typename Op, typename TermOp>
emitOpenACCOpAssociatedStmt(mlir::Location start,mlir::Location end,OpenACCDirectiveKind dirKind,SourceLocation dirLoc,llvm::ArrayRef<const OpenACCClause * > clauses,const Stmt * associatedStmt)25 mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
26     mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
27     SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
28     const Stmt *associatedStmt) {
29   mlir::LogicalResult res = mlir::success();
30 
31   llvm::SmallVector<mlir::Type> retTy;
32   llvm::SmallVector<mlir::Value> operands;
33   auto op = builder.create<Op>(start, retTy, operands);
34 
35   emitOpenACCClauses(op, dirKind, dirLoc, clauses);
36 
37   {
38     mlir::Block &block = op.getRegion().emplaceBlock();
39     mlir::OpBuilder::InsertionGuard guardCase(builder);
40     builder.setInsertionPointToEnd(&block);
41 
42     LexicalScope ls{*this, start, builder.getInsertionBlock()};
43     res = emitStmt(associatedStmt, /*useCurrentScope=*/true);
44 
45     builder.create<TermOp>(end);
46   }
47   return res;
48 }
49 
50 namespace {
51 template <typename Op> struct CombinedType;
52 template <> struct CombinedType<ParallelOp> {
53   static constexpr mlir::acc::CombinedConstructsType value =
54       mlir::acc::CombinedConstructsType::ParallelLoop;
55 };
56 template <> struct CombinedType<SerialOp> {
57   static constexpr mlir::acc::CombinedConstructsType value =
58       mlir::acc::CombinedConstructsType::SerialLoop;
59 };
60 template <> struct CombinedType<KernelsOp> {
61   static constexpr mlir::acc::CombinedConstructsType value =
62       mlir::acc::CombinedConstructsType::KernelsLoop;
63 };
64 } // namespace
65 
66 template <typename Op, typename TermOp>
emitOpenACCOpCombinedConstruct(mlir::Location start,mlir::Location end,OpenACCDirectiveKind dirKind,SourceLocation dirLoc,llvm::ArrayRef<const OpenACCClause * > clauses,const Stmt * loopStmt)67 mlir::LogicalResult CIRGenFunction::emitOpenACCOpCombinedConstruct(
68     mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
69     SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
70     const Stmt *loopStmt) {
71   mlir::LogicalResult res = mlir::success();
72 
73   llvm::SmallVector<mlir::Type> retTy;
74   llvm::SmallVector<mlir::Value> operands;
75 
76   auto computeOp = builder.create<Op>(start, retTy, operands);
77   computeOp.setCombinedAttr(builder.getUnitAttr());
78   mlir::acc::LoopOp loopOp;
79 
80   // First, emit the bodies of both operations, with the loop inside the body of
81   // the combined construct.
82   {
83     mlir::Block &block = computeOp.getRegion().emplaceBlock();
84     mlir::OpBuilder::InsertionGuard guardCase(builder);
85     builder.setInsertionPointToEnd(&block);
86 
87     LexicalScope ls{*this, start, builder.getInsertionBlock()};
88     auto loopOp = builder.create<LoopOp>(start, retTy, operands);
89     loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get(
90         builder.getContext(), CombinedType<Op>::value));
91 
92     {
93       mlir::Block &innerBlock = loopOp.getRegion().emplaceBlock();
94       mlir::OpBuilder::InsertionGuard guardCase(builder);
95       builder.setInsertionPointToEnd(&innerBlock);
96 
97       LexicalScope ls{*this, start, builder.getInsertionBlock()};
98       ActiveOpenACCLoopRAII activeLoop{*this, &loopOp};
99 
100       res = emitStmt(loopStmt, /*useCurrentScope=*/true);
101 
102       builder.create<mlir::acc::YieldOp>(end);
103     }
104 
105     emitOpenACCClauses(computeOp, loopOp, dirKind, dirLoc, clauses);
106 
107     updateLoopOpParallelism(loopOp, /*isOrphan=*/false, dirKind);
108 
109     builder.create<TermOp>(end);
110   }
111 
112   return res;
113 }
114 
115 template <typename Op>
emitOpenACCOp(mlir::Location start,OpenACCDirectiveKind dirKind,SourceLocation dirLoc,llvm::ArrayRef<const OpenACCClause * > clauses)116 Op CIRGenFunction::emitOpenACCOp(
117     mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
118     llvm::ArrayRef<const OpenACCClause *> clauses) {
119   llvm::SmallVector<mlir::Type> retTy;
120   llvm::SmallVector<mlir::Value> operands;
121   auto op = builder.create<Op>(start, retTy, operands);
122 
123   emitOpenACCClauses(op, dirKind, dirLoc, clauses);
124   return op;
125 }
126 
127 mlir::LogicalResult
emitOpenACCComputeConstruct(const OpenACCComputeConstruct & s)128 CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
129   mlir::Location start = getLoc(s.getSourceRange().getBegin());
130   mlir::Location end = getLoc(s.getSourceRange().getEnd());
131 
132   switch (s.getDirectiveKind()) {
133   case OpenACCDirectiveKind::Parallel:
134     return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(
135         start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
136         s.getStructuredBlock());
137   case OpenACCDirectiveKind::Serial:
138     return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(
139         start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
140         s.getStructuredBlock());
141   case OpenACCDirectiveKind::Kernels:
142     return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(
143         start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
144         s.getStructuredBlock());
145   default:
146     llvm_unreachable("invalid compute construct kind");
147   }
148 }
149 
150 mlir::LogicalResult
emitOpenACCDataConstruct(const OpenACCDataConstruct & s)151 CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
152   mlir::Location start = getLoc(s.getSourceRange().getBegin());
153   mlir::Location end = getLoc(s.getSourceRange().getEnd());
154 
155   return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(
156       start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
157       s.getStructuredBlock());
158 }
159 
160 mlir::LogicalResult
emitOpenACCInitConstruct(const OpenACCInitConstruct & s)161 CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) {
162   mlir::Location start = getLoc(s.getSourceRange().getBegin());
163   emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
164                                s.clauses());
165   return mlir::success();
166 }
167 
168 mlir::LogicalResult
emitOpenACCSetConstruct(const OpenACCSetConstruct & s)169 CIRGenFunction::emitOpenACCSetConstruct(const OpenACCSetConstruct &s) {
170   mlir::Location start = getLoc(s.getSourceRange().getBegin());
171   emitOpenACCOp<SetOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
172                               s.clauses());
173   return mlir::success();
174 }
175 
emitOpenACCShutdownConstruct(const OpenACCShutdownConstruct & s)176 mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct(
177     const OpenACCShutdownConstruct &s) {
178   mlir::Location start = getLoc(s.getSourceRange().getBegin());
179   emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(),
180                                    s.getDirectiveLoc(), s.clauses());
181   return mlir::success();
182 }
183 
184 mlir::LogicalResult
emitOpenACCWaitConstruct(const OpenACCWaitConstruct & s)185 CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
186   mlir::Location start = getLoc(s.getSourceRange().getBegin());
187   auto waitOp = emitOpenACCOp<WaitOp>(start, s.getDirectiveKind(),
188                                    s.getDirectiveLoc(), s.clauses());
189 
190   auto createIntExpr = [this](const Expr *intExpr) {
191     mlir::Value expr = emitScalarExpr(intExpr);
192     mlir::Location exprLoc = cgm.getLoc(intExpr->getBeginLoc());
193 
194     mlir::IntegerType targetType = mlir::IntegerType::get(
195         &getMLIRContext(), getContext().getIntWidth(intExpr->getType()),
196         intExpr->getType()->isSignedIntegerOrEnumerationType()
197             ? mlir::IntegerType::SignednessSemantics::Signed
198             : mlir::IntegerType::SignednessSemantics::Unsigned);
199 
200     auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
201         exprLoc, targetType, expr);
202     return conversionOp.getResult(0);
203   };
204 
205   // Emit the correct 'wait' clauses.
206   {
207     mlir::OpBuilder::InsertionGuard guardCase(builder);
208     builder.setInsertionPoint(waitOp);
209 
210     if (s.hasDevNumExpr())
211       waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr()));
212 
213     for (Expr *QueueExpr : s.getQueueIdExprs())
214       waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr));
215   }
216 
217   return mlir::success();
218 }
219 
emitOpenACCCombinedConstruct(const OpenACCCombinedConstruct & s)220 mlir::LogicalResult CIRGenFunction::emitOpenACCCombinedConstruct(
221     const OpenACCCombinedConstruct &s) {
222   mlir::Location start = getLoc(s.getSourceRange().getBegin());
223   mlir::Location end = getLoc(s.getSourceRange().getEnd());
224 
225   switch (s.getDirectiveKind()) {
226   case OpenACCDirectiveKind::ParallelLoop:
227     return emitOpenACCOpCombinedConstruct<ParallelOp, mlir::acc::YieldOp>(
228         start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
229         s.getLoop());
230   case OpenACCDirectiveKind::SerialLoop:
231     return emitOpenACCOpCombinedConstruct<SerialOp, mlir::acc::YieldOp>(
232         start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
233         s.getLoop());
234   case OpenACCDirectiveKind::KernelsLoop:
235     return emitOpenACCOpCombinedConstruct<KernelsOp, mlir::acc::TerminatorOp>(
236         start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
237         s.getLoop());
238   default:
239     llvm_unreachable("invalid compute construct kind");
240   }
241 }
242 
emitOpenACCHostDataConstruct(const OpenACCHostDataConstruct & s)243 mlir::LogicalResult CIRGenFunction::emitOpenACCHostDataConstruct(
244     const OpenACCHostDataConstruct &s) {
245   mlir::Location start = getLoc(s.getSourceRange().getBegin());
246   mlir::Location end = getLoc(s.getSourceRange().getEnd());
247 
248   return emitOpenACCOpAssociatedStmt<HostDataOp, mlir::acc::TerminatorOp>(
249       start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
250       s.getStructuredBlock());
251 }
252 
emitOpenACCEnterDataConstruct(const OpenACCEnterDataConstruct & s)253 mlir::LogicalResult CIRGenFunction::emitOpenACCEnterDataConstruct(
254     const OpenACCEnterDataConstruct &s) {
255   mlir::Location start = getLoc(s.getSourceRange().getBegin());
256   emitOpenACCOp<EnterDataOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
257                              s.clauses());
258   return mlir::success();
259 }
260 
emitOpenACCExitDataConstruct(const OpenACCExitDataConstruct & s)261 mlir::LogicalResult CIRGenFunction::emitOpenACCExitDataConstruct(
262     const OpenACCExitDataConstruct &s) {
263   mlir::Location start = getLoc(s.getSourceRange().getBegin());
264   emitOpenACCOp<ExitDataOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
265                             s.clauses());
266   return mlir::success();
267 }
268 
269 mlir::LogicalResult
emitOpenACCUpdateConstruct(const OpenACCUpdateConstruct & s)270 CIRGenFunction::emitOpenACCUpdateConstruct(const OpenACCUpdateConstruct &s) {
271   mlir::Location start = getLoc(s.getSourceRange().getBegin());
272   emitOpenACCOp<UpdateOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
273                           s.clauses());
274   return mlir::success();
275 }
276 
277 mlir::LogicalResult
emitOpenACCCacheConstruct(const OpenACCCacheConstruct & s)278 CIRGenFunction::emitOpenACCCacheConstruct(const OpenACCCacheConstruct &s) {
279   // The 'cache' directive 'may' be at the top of a loop by standard, but
280   // doesn't have to be. Additionally, there is nothing that requires this be a
281   // loop affected by an OpenACC pragma. Sema doesn't do any level of
282   // enforcement here, since it isn't particularly valuable to do so thanks to
283   // that. Instead, we treat cache as a 'noop' if there is no acc.loop to apply
284   // it to.
285   if (!activeLoopOp)
286     return mlir::success();
287 
288   mlir::acc::LoopOp loopOp = *activeLoopOp;
289 
290   mlir::OpBuilder::InsertionGuard guard(builder);
291   builder.setInsertionPoint(loopOp);
292 
293   for (const Expr *var : s.getVarList()) {
294     CIRGenFunction::OpenACCDataOperandInfo opInfo =
295         getOpenACCDataOperandInfo(var);
296 
297     auto cacheOp = builder.create<CacheOp>(
298         opInfo.beginLoc, opInfo.varValue,
299         /*structured=*/false, /*implicit=*/false, opInfo.name, opInfo.bounds);
300 
301     loopOp.getCacheOperandsMutable().append(cacheOp.getResult());
302   }
303 
304   return mlir::success();
305 }
306 
307 mlir::LogicalResult
emitOpenACCAtomicConstruct(const OpenACCAtomicConstruct & s)308 CIRGenFunction::emitOpenACCAtomicConstruct(const OpenACCAtomicConstruct &s) {
309   cgm.errorNYI(s.getSourceRange(), "OpenACC Atomic Construct");
310   return mlir::failure();
311 }
312