xref: /freebsd/contrib/llvm-project/clang/lib/CIR/CodeGen/CIRGenStmtOpenACCLoop.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Emit OpenACC Loop Stmt node as CIR code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CIRGenBuilder.h"
14 #include "CIRGenFunction.h"
15 
16 #include "clang/AST/StmtOpenACC.h"
17 
18 #include "mlir/Dialect/OpenACC/OpenACC.h"
19 
20 using namespace clang;
21 using namespace clang::CIRGen;
22 using namespace cir;
23 using namespace mlir::acc;
24 
updateLoopOpParallelism(mlir::acc::LoopOp & op,bool isOrphan,OpenACCDirectiveKind dk)25 void CIRGenFunction::updateLoopOpParallelism(mlir::acc::LoopOp &op,
26                                              bool isOrphan,
27                                              OpenACCDirectiveKind dk) {
28   // Check that at least one of auto, independent, or seq is present
29   // for the device-independent default clauses.
30   if (op.hasParallelismFlag(mlir::acc::DeviceType::None))
31     return;
32 
33   switch (dk) {
34   default:
35     llvm_unreachable("Invalid parent directive kind");
36   case OpenACCDirectiveKind::Invalid:
37   case OpenACCDirectiveKind::Parallel:
38   case OpenACCDirectiveKind::ParallelLoop:
39     op.addIndependent(builder.getContext(), {});
40     return;
41   case OpenACCDirectiveKind::Kernels:
42   case OpenACCDirectiveKind::KernelsLoop:
43     op.addAuto(builder.getContext(), {});
44     return;
45   case OpenACCDirectiveKind::Serial:
46   case OpenACCDirectiveKind::SerialLoop:
47     if (op.hasDefaultGangWorkerVector())
48       op.addAuto(builder.getContext(), {});
49     else
50       op.addSeq(builder.getContext(), {});
51     return;
52   };
53 }
54 
55 mlir::LogicalResult
emitOpenACCLoopConstruct(const OpenACCLoopConstruct & s)56 CIRGenFunction::emitOpenACCLoopConstruct(const OpenACCLoopConstruct &s) {
57   mlir::Location start = getLoc(s.getSourceRange().getBegin());
58   mlir::Location end = getLoc(s.getSourceRange().getEnd());
59   llvm::SmallVector<mlir::Type> retTy;
60   llvm::SmallVector<mlir::Value> operands;
61   auto op = builder.create<LoopOp>(start, retTy, operands);
62 
63   // TODO(OpenACC): In the future we are going to need to come up with a
64   // transformation here that can teach the acc.loop how to figure out the
65   // 'lowerbound', 'upperbound', and 'step'.
66   //
67   // -'upperbound' should fortunately be pretty easy as it should be
68   // in the initialization section of the cir.for loop. In Sema, we limit to
69   // just the forms 'Var = init', `Type Var = init`, or `Var = init` (where it
70   // is an operator= call)`.  However, as those are all necessary to emit for
71   // the init section of the for loop, they should be inside the initial
72   // cir.scope.
73   //
74   // -'upperbound' should be somewhat easy to determine. Sema is limiting this
75   // to: ==, <, >, !=,  <=, >= builtin operators, the overloaded 'comparison'
76   // operations, and member-call expressions.
77   //
78   // For the builtin comparison operators, we can pretty well deduce based on
79   // the comparison what the 'end' object is going to be, and the inclusive
80   // nature of it.
81   //
82   // For the overloaded operators, Sema will ensure that at least one side of
83   // the operator is the init variable, so we can deduce the comparison there
84   // too. The standard places no real bounds on WHAT the comparison operators do
85   // for a `RandomAccessIterator` however, so we'll have to just 'assume' they
86   // do the right thing? Note that this might be incrementing by a different
87   // 'object', not an integral, so it isn't really clear to me what we can do to
88   // determine the other side.
89   //
90   // Member-call expressions are the difficult ones. I don't think there is
91   // anything we can deduce from this to determine the 'end', so we might end up
92   // having to go back to Sema and make this ill-formed.
93   //
94   // HOWEVER: What ACC dialect REALLY cares about is the tripcount, which you
95   // cannot get (in the case of `RandomAccessIterator`) from JUST 'upperbound'
96   // and 'lowerbound'. We will likely have to provide a 'recipe' equivalent to
97   // `std::distance` instead.  In the case of integer/pointers, it is fairly
98   // simple to find: it is just the mathematical subtraction. Howver, in the
99   // case of `RandomAccessIterator`, we have to enable the use of `operator-`.
100   // FORTUNATELY the standard requires this to work correctly for
101   // `RandomAccessIterator`, so we don't have to implement a `std::distance`
102   // that loops through, like we would for a forward/etc iterator.
103   //
104   // 'step': Sema is currently allowing builtin ++,--, +=, -=, *=, /=, and =
105   // operators. Additionally, it allows the equivalent for the operator-call, as
106   // well as member-call.
107   //
108   // For builtin operators, we perhaps should refine the assignment here. It
109   // doesn't really help us know the 'step' count at all, but we could perhaps
110   // do one more step of analysis in Sema to allow something like Var = Var + 1.
111   // For the others, this should get us the step reasonably well.
112   //
113   // For the overloaded operators, we have the same problems as for
114   // 'upperbound', plus not really knowing what they do. Member-call expressions
115   // are again difficult, and we might want to reconsider allowing these in
116   // Sema.
117   //
118 
119   // Emit all clauses.
120   emitOpenACCClauses(op, s.getDirectiveKind(), s.getDirectiveLoc(),
121                      s.clauses());
122 
123   updateLoopOpParallelism(op, s.isOrphanedLoopConstruct(),
124                           s.getParentComputeConstructKind());
125 
126   mlir::LogicalResult stmtRes = mlir::success();
127   // Emit body.
128   {
129     mlir::Block &block = op.getRegion().emplaceBlock();
130     mlir::OpBuilder::InsertionGuard guardCase(builder);
131     builder.setInsertionPointToEnd(&block);
132     LexicalScope ls{*this, start, builder.getInsertionBlock()};
133     ActiveOpenACCLoopRAII activeLoop{*this, &op};
134 
135     stmtRes = emitStmt(s.getLoop(), /*useCurrentScope=*/true);
136     builder.create<mlir::acc::YieldOp>(end);
137   }
138 
139   return stmtRes;
140 }
141