1*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
2*700637cbSDimitry Andric //
3*700637cbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*700637cbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*700637cbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*700637cbSDimitry Andric //
7*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
8*700637cbSDimitry Andric //
9*700637cbSDimitry Andric // Emit OpenACC Loop Stmt node as CIR code.
10*700637cbSDimitry Andric //
11*700637cbSDimitry Andric //===----------------------------------------------------------------------===//
12*700637cbSDimitry Andric
13*700637cbSDimitry Andric #include "CIRGenBuilder.h"
14*700637cbSDimitry Andric #include "CIRGenFunction.h"
15*700637cbSDimitry Andric
16*700637cbSDimitry Andric #include "clang/AST/StmtOpenACC.h"
17*700637cbSDimitry Andric
18*700637cbSDimitry Andric #include "mlir/Dialect/OpenACC/OpenACC.h"
19*700637cbSDimitry Andric
20*700637cbSDimitry Andric using namespace clang;
21*700637cbSDimitry Andric using namespace clang::CIRGen;
22*700637cbSDimitry Andric using namespace cir;
23*700637cbSDimitry Andric using namespace mlir::acc;
24*700637cbSDimitry Andric
updateLoopOpParallelism(mlir::acc::LoopOp & op,bool isOrphan,OpenACCDirectiveKind dk)25*700637cbSDimitry Andric void CIRGenFunction::updateLoopOpParallelism(mlir::acc::LoopOp &op,
26*700637cbSDimitry Andric bool isOrphan,
27*700637cbSDimitry Andric OpenACCDirectiveKind dk) {
28*700637cbSDimitry Andric // Check that at least one of auto, independent, or seq is present
29*700637cbSDimitry Andric // for the device-independent default clauses.
30*700637cbSDimitry Andric if (op.hasParallelismFlag(mlir::acc::DeviceType::None))
31*700637cbSDimitry Andric return;
32*700637cbSDimitry Andric
33*700637cbSDimitry Andric switch (dk) {
34*700637cbSDimitry Andric default:
35*700637cbSDimitry Andric llvm_unreachable("Invalid parent directive kind");
36*700637cbSDimitry Andric case OpenACCDirectiveKind::Invalid:
37*700637cbSDimitry Andric case OpenACCDirectiveKind::Parallel:
38*700637cbSDimitry Andric case OpenACCDirectiveKind::ParallelLoop:
39*700637cbSDimitry Andric op.addIndependent(builder.getContext(), {});
40*700637cbSDimitry Andric return;
41*700637cbSDimitry Andric case OpenACCDirectiveKind::Kernels:
42*700637cbSDimitry Andric case OpenACCDirectiveKind::KernelsLoop:
43*700637cbSDimitry Andric op.addAuto(builder.getContext(), {});
44*700637cbSDimitry Andric return;
45*700637cbSDimitry Andric case OpenACCDirectiveKind::Serial:
46*700637cbSDimitry Andric case OpenACCDirectiveKind::SerialLoop:
47*700637cbSDimitry Andric if (op.hasDefaultGangWorkerVector())
48*700637cbSDimitry Andric op.addAuto(builder.getContext(), {});
49*700637cbSDimitry Andric else
50*700637cbSDimitry Andric op.addSeq(builder.getContext(), {});
51*700637cbSDimitry Andric return;
52*700637cbSDimitry Andric };
53*700637cbSDimitry Andric }
54*700637cbSDimitry Andric
55*700637cbSDimitry Andric mlir::LogicalResult
emitOpenACCLoopConstruct(const OpenACCLoopConstruct & s)56*700637cbSDimitry Andric CIRGenFunction::emitOpenACCLoopConstruct(const OpenACCLoopConstruct &s) {
57*700637cbSDimitry Andric mlir::Location start = getLoc(s.getSourceRange().getBegin());
58*700637cbSDimitry Andric mlir::Location end = getLoc(s.getSourceRange().getEnd());
59*700637cbSDimitry Andric llvm::SmallVector<mlir::Type> retTy;
60*700637cbSDimitry Andric llvm::SmallVector<mlir::Value> operands;
61*700637cbSDimitry Andric auto op = builder.create<LoopOp>(start, retTy, operands);
62*700637cbSDimitry Andric
63*700637cbSDimitry Andric // TODO(OpenACC): In the future we are going to need to come up with a
64*700637cbSDimitry Andric // transformation here that can teach the acc.loop how to figure out the
65*700637cbSDimitry Andric // 'lowerbound', 'upperbound', and 'step'.
66*700637cbSDimitry Andric //
67*700637cbSDimitry Andric // -'upperbound' should fortunately be pretty easy as it should be
68*700637cbSDimitry Andric // in the initialization section of the cir.for loop. In Sema, we limit to
69*700637cbSDimitry Andric // just the forms 'Var = init', `Type Var = init`, or `Var = init` (where it
70*700637cbSDimitry Andric // is an operator= call)`. However, as those are all necessary to emit for
71*700637cbSDimitry Andric // the init section of the for loop, they should be inside the initial
72*700637cbSDimitry Andric // cir.scope.
73*700637cbSDimitry Andric //
74*700637cbSDimitry Andric // -'upperbound' should be somewhat easy to determine. Sema is limiting this
75*700637cbSDimitry Andric // to: ==, <, >, !=, <=, >= builtin operators, the overloaded 'comparison'
76*700637cbSDimitry Andric // operations, and member-call expressions.
77*700637cbSDimitry Andric //
78*700637cbSDimitry Andric // For the builtin comparison operators, we can pretty well deduce based on
79*700637cbSDimitry Andric // the comparison what the 'end' object is going to be, and the inclusive
80*700637cbSDimitry Andric // nature of it.
81*700637cbSDimitry Andric //
82*700637cbSDimitry Andric // For the overloaded operators, Sema will ensure that at least one side of
83*700637cbSDimitry Andric // the operator is the init variable, so we can deduce the comparison there
84*700637cbSDimitry Andric // too. The standard places no real bounds on WHAT the comparison operators do
85*700637cbSDimitry Andric // for a `RandomAccessIterator` however, so we'll have to just 'assume' they
86*700637cbSDimitry Andric // do the right thing? Note that this might be incrementing by a different
87*700637cbSDimitry Andric // 'object', not an integral, so it isn't really clear to me what we can do to
88*700637cbSDimitry Andric // determine the other side.
89*700637cbSDimitry Andric //
90*700637cbSDimitry Andric // Member-call expressions are the difficult ones. I don't think there is
91*700637cbSDimitry Andric // anything we can deduce from this to determine the 'end', so we might end up
92*700637cbSDimitry Andric // having to go back to Sema and make this ill-formed.
93*700637cbSDimitry Andric //
94*700637cbSDimitry Andric // HOWEVER: What ACC dialect REALLY cares about is the tripcount, which you
95*700637cbSDimitry Andric // cannot get (in the case of `RandomAccessIterator`) from JUST 'upperbound'
96*700637cbSDimitry Andric // and 'lowerbound'. We will likely have to provide a 'recipe' equivalent to
97*700637cbSDimitry Andric // `std::distance` instead. In the case of integer/pointers, it is fairly
98*700637cbSDimitry Andric // simple to find: it is just the mathematical subtraction. Howver, in the
99*700637cbSDimitry Andric // case of `RandomAccessIterator`, we have to enable the use of `operator-`.
100*700637cbSDimitry Andric // FORTUNATELY the standard requires this to work correctly for
101*700637cbSDimitry Andric // `RandomAccessIterator`, so we don't have to implement a `std::distance`
102*700637cbSDimitry Andric // that loops through, like we would for a forward/etc iterator.
103*700637cbSDimitry Andric //
104*700637cbSDimitry Andric // 'step': Sema is currently allowing builtin ++,--, +=, -=, *=, /=, and =
105*700637cbSDimitry Andric // operators. Additionally, it allows the equivalent for the operator-call, as
106*700637cbSDimitry Andric // well as member-call.
107*700637cbSDimitry Andric //
108*700637cbSDimitry Andric // For builtin operators, we perhaps should refine the assignment here. It
109*700637cbSDimitry Andric // doesn't really help us know the 'step' count at all, but we could perhaps
110*700637cbSDimitry Andric // do one more step of analysis in Sema to allow something like Var = Var + 1.
111*700637cbSDimitry Andric // For the others, this should get us the step reasonably well.
112*700637cbSDimitry Andric //
113*700637cbSDimitry Andric // For the overloaded operators, we have the same problems as for
114*700637cbSDimitry Andric // 'upperbound', plus not really knowing what they do. Member-call expressions
115*700637cbSDimitry Andric // are again difficult, and we might want to reconsider allowing these in
116*700637cbSDimitry Andric // Sema.
117*700637cbSDimitry Andric //
118*700637cbSDimitry Andric
119*700637cbSDimitry Andric // Emit all clauses.
120*700637cbSDimitry Andric emitOpenACCClauses(op, s.getDirectiveKind(), s.getDirectiveLoc(),
121*700637cbSDimitry Andric s.clauses());
122*700637cbSDimitry Andric
123*700637cbSDimitry Andric updateLoopOpParallelism(op, s.isOrphanedLoopConstruct(),
124*700637cbSDimitry Andric s.getParentComputeConstructKind());
125*700637cbSDimitry Andric
126*700637cbSDimitry Andric mlir::LogicalResult stmtRes = mlir::success();
127*700637cbSDimitry Andric // Emit body.
128*700637cbSDimitry Andric {
129*700637cbSDimitry Andric mlir::Block &block = op.getRegion().emplaceBlock();
130*700637cbSDimitry Andric mlir::OpBuilder::InsertionGuard guardCase(builder);
131*700637cbSDimitry Andric builder.setInsertionPointToEnd(&block);
132*700637cbSDimitry Andric LexicalScope ls{*this, start, builder.getInsertionBlock()};
133*700637cbSDimitry Andric ActiveOpenACCLoopRAII activeLoop{*this, &op};
134*700637cbSDimitry Andric
135*700637cbSDimitry Andric stmtRes = emitStmt(s.getLoop(), /*useCurrentScope=*/true);
136*700637cbSDimitry Andric builder.create<mlir::acc::YieldOp>(end);
137*700637cbSDimitry Andric }
138*700637cbSDimitry Andric
139*700637cbSDimitry Andric return stmtRes;
140*700637cbSDimitry Andric }
141