1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Emit OpenACC Stmt nodes as CIR code.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "CIRGenBuilder.h"
14 #include "CIRGenFunction.h"
15 #include "mlir/Dialect/OpenACC/OpenACC.h"
16 #include "clang/AST/OpenACCClause.h"
17 #include "clang/AST/StmtOpenACC.h"
18
19 using namespace clang;
20 using namespace clang::CIRGen;
21 using namespace cir;
22 using namespace mlir::acc;
23
24 template <typename Op, typename TermOp>
emitOpenACCOpAssociatedStmt(mlir::Location start,mlir::Location end,OpenACCDirectiveKind dirKind,SourceLocation dirLoc,llvm::ArrayRef<const OpenACCClause * > clauses,const Stmt * associatedStmt)25 mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
26 mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
27 SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
28 const Stmt *associatedStmt) {
29 mlir::LogicalResult res = mlir::success();
30
31 llvm::SmallVector<mlir::Type> retTy;
32 llvm::SmallVector<mlir::Value> operands;
33 auto op = builder.create<Op>(start, retTy, operands);
34
35 emitOpenACCClauses(op, dirKind, dirLoc, clauses);
36
37 {
38 mlir::Block &block = op.getRegion().emplaceBlock();
39 mlir::OpBuilder::InsertionGuard guardCase(builder);
40 builder.setInsertionPointToEnd(&block);
41
42 LexicalScope ls{*this, start, builder.getInsertionBlock()};
43 res = emitStmt(associatedStmt, /*useCurrentScope=*/true);
44
45 builder.create<TermOp>(end);
46 }
47 return res;
48 }
49
50 namespace {
51 template <typename Op> struct CombinedType;
52 template <> struct CombinedType<ParallelOp> {
53 static constexpr mlir::acc::CombinedConstructsType value =
54 mlir::acc::CombinedConstructsType::ParallelLoop;
55 };
56 template <> struct CombinedType<SerialOp> {
57 static constexpr mlir::acc::CombinedConstructsType value =
58 mlir::acc::CombinedConstructsType::SerialLoop;
59 };
60 template <> struct CombinedType<KernelsOp> {
61 static constexpr mlir::acc::CombinedConstructsType value =
62 mlir::acc::CombinedConstructsType::KernelsLoop;
63 };
64 } // namespace
65
66 template <typename Op, typename TermOp>
emitOpenACCOpCombinedConstruct(mlir::Location start,mlir::Location end,OpenACCDirectiveKind dirKind,SourceLocation dirLoc,llvm::ArrayRef<const OpenACCClause * > clauses,const Stmt * loopStmt)67 mlir::LogicalResult CIRGenFunction::emitOpenACCOpCombinedConstruct(
68 mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
69 SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
70 const Stmt *loopStmt) {
71 mlir::LogicalResult res = mlir::success();
72
73 llvm::SmallVector<mlir::Type> retTy;
74 llvm::SmallVector<mlir::Value> operands;
75
76 auto computeOp = builder.create<Op>(start, retTy, operands);
77 computeOp.setCombinedAttr(builder.getUnitAttr());
78 mlir::acc::LoopOp loopOp;
79
80 // First, emit the bodies of both operations, with the loop inside the body of
81 // the combined construct.
82 {
83 mlir::Block &block = computeOp.getRegion().emplaceBlock();
84 mlir::OpBuilder::InsertionGuard guardCase(builder);
85 builder.setInsertionPointToEnd(&block);
86
87 LexicalScope ls{*this, start, builder.getInsertionBlock()};
88 auto loopOp = builder.create<LoopOp>(start, retTy, operands);
89 loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get(
90 builder.getContext(), CombinedType<Op>::value));
91
92 {
93 mlir::Block &innerBlock = loopOp.getRegion().emplaceBlock();
94 mlir::OpBuilder::InsertionGuard guardCase(builder);
95 builder.setInsertionPointToEnd(&innerBlock);
96
97 LexicalScope ls{*this, start, builder.getInsertionBlock()};
98 ActiveOpenACCLoopRAII activeLoop{*this, &loopOp};
99
100 res = emitStmt(loopStmt, /*useCurrentScope=*/true);
101
102 builder.create<mlir::acc::YieldOp>(end);
103 }
104
105 emitOpenACCClauses(computeOp, loopOp, dirKind, dirLoc, clauses);
106
107 updateLoopOpParallelism(loopOp, /*isOrphan=*/false, dirKind);
108
109 builder.create<TermOp>(end);
110 }
111
112 return res;
113 }
114
115 template <typename Op>
emitOpenACCOp(mlir::Location start,OpenACCDirectiveKind dirKind,SourceLocation dirLoc,llvm::ArrayRef<const OpenACCClause * > clauses)116 Op CIRGenFunction::emitOpenACCOp(
117 mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
118 llvm::ArrayRef<const OpenACCClause *> clauses) {
119 llvm::SmallVector<mlir::Type> retTy;
120 llvm::SmallVector<mlir::Value> operands;
121 auto op = builder.create<Op>(start, retTy, operands);
122
123 emitOpenACCClauses(op, dirKind, dirLoc, clauses);
124 return op;
125 }
126
127 mlir::LogicalResult
emitOpenACCComputeConstruct(const OpenACCComputeConstruct & s)128 CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
129 mlir::Location start = getLoc(s.getSourceRange().getBegin());
130 mlir::Location end = getLoc(s.getSourceRange().getEnd());
131
132 switch (s.getDirectiveKind()) {
133 case OpenACCDirectiveKind::Parallel:
134 return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(
135 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
136 s.getStructuredBlock());
137 case OpenACCDirectiveKind::Serial:
138 return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(
139 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
140 s.getStructuredBlock());
141 case OpenACCDirectiveKind::Kernels:
142 return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(
143 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
144 s.getStructuredBlock());
145 default:
146 llvm_unreachable("invalid compute construct kind");
147 }
148 }
149
150 mlir::LogicalResult
emitOpenACCDataConstruct(const OpenACCDataConstruct & s)151 CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
152 mlir::Location start = getLoc(s.getSourceRange().getBegin());
153 mlir::Location end = getLoc(s.getSourceRange().getEnd());
154
155 return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(
156 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
157 s.getStructuredBlock());
158 }
159
160 mlir::LogicalResult
emitOpenACCInitConstruct(const OpenACCInitConstruct & s)161 CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) {
162 mlir::Location start = getLoc(s.getSourceRange().getBegin());
163 emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
164 s.clauses());
165 return mlir::success();
166 }
167
168 mlir::LogicalResult
emitOpenACCSetConstruct(const OpenACCSetConstruct & s)169 CIRGenFunction::emitOpenACCSetConstruct(const OpenACCSetConstruct &s) {
170 mlir::Location start = getLoc(s.getSourceRange().getBegin());
171 emitOpenACCOp<SetOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
172 s.clauses());
173 return mlir::success();
174 }
175
emitOpenACCShutdownConstruct(const OpenACCShutdownConstruct & s)176 mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct(
177 const OpenACCShutdownConstruct &s) {
178 mlir::Location start = getLoc(s.getSourceRange().getBegin());
179 emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(),
180 s.getDirectiveLoc(), s.clauses());
181 return mlir::success();
182 }
183
184 mlir::LogicalResult
emitOpenACCWaitConstruct(const OpenACCWaitConstruct & s)185 CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
186 mlir::Location start = getLoc(s.getSourceRange().getBegin());
187 auto waitOp = emitOpenACCOp<WaitOp>(start, s.getDirectiveKind(),
188 s.getDirectiveLoc(), s.clauses());
189
190 auto createIntExpr = [this](const Expr *intExpr) {
191 mlir::Value expr = emitScalarExpr(intExpr);
192 mlir::Location exprLoc = cgm.getLoc(intExpr->getBeginLoc());
193
194 mlir::IntegerType targetType = mlir::IntegerType::get(
195 &getMLIRContext(), getContext().getIntWidth(intExpr->getType()),
196 intExpr->getType()->isSignedIntegerOrEnumerationType()
197 ? mlir::IntegerType::SignednessSemantics::Signed
198 : mlir::IntegerType::SignednessSemantics::Unsigned);
199
200 auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
201 exprLoc, targetType, expr);
202 return conversionOp.getResult(0);
203 };
204
205 // Emit the correct 'wait' clauses.
206 {
207 mlir::OpBuilder::InsertionGuard guardCase(builder);
208 builder.setInsertionPoint(waitOp);
209
210 if (s.hasDevNumExpr())
211 waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr()));
212
213 for (Expr *QueueExpr : s.getQueueIdExprs())
214 waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr));
215 }
216
217 return mlir::success();
218 }
219
emitOpenACCCombinedConstruct(const OpenACCCombinedConstruct & s)220 mlir::LogicalResult CIRGenFunction::emitOpenACCCombinedConstruct(
221 const OpenACCCombinedConstruct &s) {
222 mlir::Location start = getLoc(s.getSourceRange().getBegin());
223 mlir::Location end = getLoc(s.getSourceRange().getEnd());
224
225 switch (s.getDirectiveKind()) {
226 case OpenACCDirectiveKind::ParallelLoop:
227 return emitOpenACCOpCombinedConstruct<ParallelOp, mlir::acc::YieldOp>(
228 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
229 s.getLoop());
230 case OpenACCDirectiveKind::SerialLoop:
231 return emitOpenACCOpCombinedConstruct<SerialOp, mlir::acc::YieldOp>(
232 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
233 s.getLoop());
234 case OpenACCDirectiveKind::KernelsLoop:
235 return emitOpenACCOpCombinedConstruct<KernelsOp, mlir::acc::TerminatorOp>(
236 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
237 s.getLoop());
238 default:
239 llvm_unreachable("invalid compute construct kind");
240 }
241 }
242
emitOpenACCHostDataConstruct(const OpenACCHostDataConstruct & s)243 mlir::LogicalResult CIRGenFunction::emitOpenACCHostDataConstruct(
244 const OpenACCHostDataConstruct &s) {
245 mlir::Location start = getLoc(s.getSourceRange().getBegin());
246 mlir::Location end = getLoc(s.getSourceRange().getEnd());
247
248 return emitOpenACCOpAssociatedStmt<HostDataOp, mlir::acc::TerminatorOp>(
249 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
250 s.getStructuredBlock());
251 }
252
emitOpenACCEnterDataConstruct(const OpenACCEnterDataConstruct & s)253 mlir::LogicalResult CIRGenFunction::emitOpenACCEnterDataConstruct(
254 const OpenACCEnterDataConstruct &s) {
255 mlir::Location start = getLoc(s.getSourceRange().getBegin());
256 emitOpenACCOp<EnterDataOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
257 s.clauses());
258 return mlir::success();
259 }
260
emitOpenACCExitDataConstruct(const OpenACCExitDataConstruct & s)261 mlir::LogicalResult CIRGenFunction::emitOpenACCExitDataConstruct(
262 const OpenACCExitDataConstruct &s) {
263 mlir::Location start = getLoc(s.getSourceRange().getBegin());
264 emitOpenACCOp<ExitDataOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
265 s.clauses());
266 return mlir::success();
267 }
268
269 mlir::LogicalResult
emitOpenACCUpdateConstruct(const OpenACCUpdateConstruct & s)270 CIRGenFunction::emitOpenACCUpdateConstruct(const OpenACCUpdateConstruct &s) {
271 mlir::Location start = getLoc(s.getSourceRange().getBegin());
272 emitOpenACCOp<UpdateOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
273 s.clauses());
274 return mlir::success();
275 }
276
277 mlir::LogicalResult
emitOpenACCCacheConstruct(const OpenACCCacheConstruct & s)278 CIRGenFunction::emitOpenACCCacheConstruct(const OpenACCCacheConstruct &s) {
279 // The 'cache' directive 'may' be at the top of a loop by standard, but
280 // doesn't have to be. Additionally, there is nothing that requires this be a
281 // loop affected by an OpenACC pragma. Sema doesn't do any level of
282 // enforcement here, since it isn't particularly valuable to do so thanks to
283 // that. Instead, we treat cache as a 'noop' if there is no acc.loop to apply
284 // it to.
285 if (!activeLoopOp)
286 return mlir::success();
287
288 mlir::acc::LoopOp loopOp = *activeLoopOp;
289
290 mlir::OpBuilder::InsertionGuard guard(builder);
291 builder.setInsertionPoint(loopOp);
292
293 for (const Expr *var : s.getVarList()) {
294 CIRGenFunction::OpenACCDataOperandInfo opInfo =
295 getOpenACCDataOperandInfo(var);
296
297 auto cacheOp = builder.create<CacheOp>(
298 opInfo.beginLoc, opInfo.varValue,
299 /*structured=*/false, /*implicit=*/false, opInfo.name, opInfo.bounds);
300
301 loopOp.getCacheOperandsMutable().append(cacheOp.getResult());
302 }
303
304 return mlir::success();
305 }
306
307 mlir::LogicalResult
emitOpenACCAtomicConstruct(const OpenACCAtomicConstruct & s)308 CIRGenFunction::emitOpenACCAtomicConstruct(const OpenACCAtomicConstruct &s) {
309 cgm.errorNYI(s.getSourceRange(), "OpenACC Atomic Construct");
310 return mlir::failure();
311 }
312