1 //===- ConstructCompositionT.h -- Composing compound constructs -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
// Given a list of leaf constructs, each with a set of clauses, generate the
// compound construct whose leaf constructs are the given list, and whose clause
// list is the merge of the individual leaf clause lists.
11 //
12 // *** At the moment it assumes that the individual constructs and their clauses
13 // *** are a subset of those created by splitting a valid compound construct.
14 //===----------------------------------------------------------------------===//
15 #ifndef LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
16 #define LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
17
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/BitVector.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Frontend/OpenMP/ClauseT.h"
23 #include "llvm/Frontend/OpenMP/OMP.h"
24
25 #include <iterator>
26 #include <optional>
27 #include <tuple>
28 #include <unordered_map>
29 #include <unordered_set>
30 #include <utility>
31
32 namespace tomp {
33 template <typename ClauseType> struct ConstructCompositionT {
34 using ClauseTy = ClauseType;
35
36 using TypeTy = typename ClauseTy::TypeTy;
37 using IdTy = typename ClauseTy::IdTy;
38 using ExprTy = typename ClauseTy::ExprTy;
39
40 ConstructCompositionT(uint32_t version,
41 llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> leafs);
42
43 DirectiveWithClauses<ClauseTy> merged;
44
45 private:
46 // Use an ordered container, since we beed to maintain the order in which
47 // clauses are added to it. This is to avoid non-deterministic output.
48 using ClauseSet = ListT<ClauseTy>;
49
50 enum class Presence {
51 All, // Clause is preesnt on all leaf constructs that allow it.
52 Some, // Clause is present on some, but not on all constructs.
53 None, // Clause is absent on all constructs.
54 };
55
56 template <typename S>
makeClauseConstructCompositionT57 ClauseTy makeClause(llvm::omp::Clause clauseId, S &&specific) {
58 return typename ClauseTy::BaseT{clauseId, std::move(specific)};
59 }
60
61 llvm::omp::Directive
62 makeCompound(llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts);
63
64 Presence checkPresence(llvm::omp::Clause clauseId);
65
66 // There are clauses that need special handling:
67 // 1. "if": the "directive-name-modifier" on the merged clause may need
68 // to be set appropriately.
69 // 2. "reduction": implies "privateness" of all objects (incompatible
70 // with "shared"); there are rules for merging modifiers
71 void mergeIf();
72 void mergeReduction();
73 void mergeDSA();
74
75 uint32_t version;
76 llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> leafs;
77
78 // clause id -> set of leaf constructs that contain it
79 std::unordered_map<llvm::omp::Clause, llvm::BitVector> clausePresence;
80 // clause id -> set of instances of that clause
81 std::unordered_map<llvm::omp::Clause, ClauseSet> clauseSets;
82 };
83
84 template <typename C>
ConstructCompositionT(uint32_t version,llvm::ArrayRef<DirectiveWithClauses<C>> leafs)85 ConstructCompositionT<C>::ConstructCompositionT(
86 uint32_t version, llvm::ArrayRef<DirectiveWithClauses<C>> leafs)
87 : version(version), leafs(leafs) {
88 // Merge the list of constructs with clauses into a compound construct
89 // with a single list of clauses.
90 // The intended use of this function is in splitting compound constructs,
91 // while preserving composite constituent constructs:
92 // Step 1: split compound construct into leaf constructs.
93 // Step 2: identify composite sub-construct, and merge the constituent leafs.
94 //
95 // *** At the moment it assumes that the individual constructs and their
96 // *** clauses are a subset of those created by splitting a valid compound
97 // *** construct.
98 //
99 // 1. Deduplicate clauses
100 // - exact duplicates: e.g. shared(x) shared(x) -> shared(x)
101 // - special cases of clauses differing in modifier:
102 // (a) reduction: inscan + (none|default) = inscan
103 // (b) reduction: task + (none|default) = task
104 // (c) combine repeated "if" clauses if possible
105 // 2. Merge DSA clauses: e.g. private(x) private(y) -> private(x, y).
106 // 3. Resolve potential DSA conflicts (typically due to implied clauses).
107
108 if (leafs.empty())
109 return;
110
111 merged.id = makeCompound(leafs);
112
113 // Populate the two maps:
114 for (const auto &[index, leaf] : llvm::enumerate(leafs)) {
115 for (const auto &clause : leaf.clauses) {
116 // Update clausePresence.
117 auto &pset = clausePresence[clause.id];
118 if (pset.size() < leafs.size())
119 pset.resize(leafs.size());
120 pset.set(index);
121 // Update clauseSets.
122 ClauseSet &cset = clauseSets[clause.id];
123 if (!llvm::is_contained(cset, clause))
124 cset.push_back(clause);
125 }
126 }
127
128 mergeIf();
129 mergeReduction();
130 mergeDSA();
131
132 // Fir the rest of the clauses, just copy them.
133 for (auto &[id, clauses] : clauseSets) {
134 // Skip clauses we've already dealt with.
135 switch (id) {
136 case llvm::omp::Clause::OMPC_if:
137 case llvm::omp::Clause::OMPC_reduction:
138 case llvm::omp::Clause::OMPC_shared:
139 case llvm::omp::Clause::OMPC_private:
140 case llvm::omp::Clause::OMPC_firstprivate:
141 case llvm::omp::Clause::OMPC_lastprivate:
142 continue;
143 default:
144 break;
145 }
146 llvm::append_range(merged.clauses, clauses);
147 }
148 }
149
150 template <typename C>
makeCompound(llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts)151 llvm::omp::Directive ConstructCompositionT<C>::makeCompound(
152 llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts) {
153 llvm::SmallVector<llvm::omp::Directive> dirIds;
154 llvm::transform(parts, std::back_inserter(dirIds),
155 [](auto &&dwc) { return dwc.id; });
156
157 return llvm::omp::getCompoundConstruct(dirIds);
158 }
159
160 template <typename C>
161 auto ConstructCompositionT<C>::checkPresence(llvm::omp::Clause clauseId)
162 -> Presence {
163 auto found = clausePresence.find(clauseId);
164 if (found == clausePresence.end())
165 return Presence::None;
166
167 bool OnAll = true, OnNone = true;
168 for (const auto &[index, leaf] : llvm::enumerate(leafs)) {
169 if (!llvm::omp::isAllowedClauseForDirective(leaf.id, clauseId, version))
170 continue;
171
172 if (found->second.test(index))
173 OnNone = false;
174 else
175 OnAll = false;
176 }
177
178 if (OnNone)
179 return Presence::None;
180 if (OnAll)
181 return Presence::All;
182 return Presence::Some;
183 }
184
mergeIf()185 template <typename C> void ConstructCompositionT<C>::mergeIf() {
186 using IfTy = tomp::clause::IfT<TypeTy, IdTy, ExprTy>;
187 // Deal with the "if" clauses. If it's on all leafs that allow it, then it
188 // will apply to the compound construct. Otherwise it will apply to the
189 // single (assumed) leaf construct.
190 // This assumes that the "if" clauses have the same expression.
191 Presence presence = checkPresence(llvm::omp::Clause::OMPC_if);
192 if (presence == Presence::None)
193 return;
194
195 const ClauseTy &some = *clauseSets[llvm::omp::Clause::OMPC_if].begin();
196 const auto &someIf = std::get<IfTy>(some.u);
197
198 if (presence == Presence::All) {
199 // Create "if" without "directive-name-modifier".
200 merged.clauses.emplace_back(
201 makeClause(llvm::omp::Clause::OMPC_if,
202 IfTy{{/*DirectiveNameModifier=*/std::nullopt,
203 /*IfExpression=*/std::get<typename IfTy::IfExpression>(
204 someIf.t)}}));
205 } else {
206 // Find out where it's present and create "if" with the corresponding
207 // "directive-name-modifier".
208 int Idx = clausePresence[llvm::omp::Clause::OMPC_if].find_first();
209 assert(Idx >= 0);
210 merged.clauses.emplace_back(
211 makeClause(llvm::omp::Clause::OMPC_if,
212 IfTy{{/*DirectiveNameModifier=*/leafs[Idx].id,
213 /*IfExpression=*/std::get<typename IfTy::IfExpression>(
214 someIf.t)}}));
215 }
216 }
217
mergeReduction()218 template <typename C> void ConstructCompositionT<C>::mergeReduction() {
219 Presence presence = checkPresence(llvm::omp::Clause::OMPC_reduction);
220 if (presence == Presence::None)
221 return;
222
223 using ReductionTy = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
224 using ModifierTy = typename ReductionTy::ReductionModifier;
225 using IdentifiersTy = typename ReductionTy::ReductionIdentifiers;
226 using ListTy = typename ReductionTy::List;
227 // There are exceptions on which constructs "reduction" may appear
228 // (specifically "parallel", and "teams"). Assume that if "reduction"
229 // is present, it can be applied to the compound construct.
230
231 // What's left is to see if there are any modifiers present. Again,
232 // assume that there are no conflicting modifiers.
233 // There can be, however, multiple reductions on different objects.
234 auto equal = [](const ClauseTy &red1, const ClauseTy &red2) {
235 // Extract actual reductions.
236 const auto r1 = std::get<ReductionTy>(red1.u);
237 const auto r2 = std::get<ReductionTy>(red2.u);
238 // Compare everything except modifiers.
239 if (std::get<IdentifiersTy>(r1.t) != std::get<IdentifiersTy>(r2.t))
240 return false;
241 if (std::get<ListTy>(r1.t) != std::get<ListTy>(r2.t))
242 return false;
243 return true;
244 };
245
246 auto getModifier = [](const ClauseTy &clause) {
247 const ReductionTy &red = std::get<ReductionTy>(clause.u);
248 return std::get<std::optional<ModifierTy>>(red.t);
249 };
250
251 const ClauseSet &reductions = clauseSets[llvm::omp::Clause::OMPC_reduction];
252 std::unordered_set<const ClauseTy *> visited;
253 while (reductions.size() != visited.size()) {
254 typename ClauseSet::const_iterator first;
255
256 // Find first non-visited reduction.
257 for (first = reductions.begin(); first != reductions.end(); ++first) {
258 if (visited.count(&*first))
259 continue;
260 visited.insert(&*first);
261 break;
262 }
263
264 std::optional<ModifierTy> modifier = getModifier(*first);
265
266 // Visit all other reductions that are "equal" (with respect to the
267 // definition above) to "first". Collect modifiers.
268 for (auto iter = std::next(first); iter != reductions.end(); ++iter) {
269 if (!equal(*first, *iter))
270 continue;
271 visited.insert(&*iter);
272 if (!modifier || *modifier == ModifierTy::Default)
273 modifier = getModifier(*iter);
274 }
275
276 const auto &firstRed = std::get<ReductionTy>(first->u);
277 merged.clauses.emplace_back(makeClause(
278 llvm::omp::Clause::OMPC_reduction,
279 ReductionTy{
280 {/*ReductionModifier=*/modifier,
281 /*ReductionIdentifiers=*/std::get<IdentifiersTy>(firstRed.t),
282 /*List=*/std::get<ListTy>(firstRed.t)}}));
283 }
284 }
285
mergeDSA()286 template <typename C> void ConstructCompositionT<C>::mergeDSA() {
287 using ObjectTy = tomp::type::ObjectT<IdTy, ExprTy>;
288
289 // Resolve data-sharing attributes.
290 enum DSA : int {
291 None = 0,
292 Shared = 1 << 0,
293 Private = 1 << 1,
294 FirstPrivate = 1 << 2,
295 LastPrivate = 1 << 3,
296 LastPrivateConditional = 1 << 4,
297 };
298
299 // Use ordered containers to avoid non-deterministic output.
300 llvm::SmallVector<std::pair<ObjectTy, int>, 8> objectDsa;
301
302 auto getDsa = [&](const ObjectTy &object) -> std::pair<ObjectTy, int> & {
303 auto found = llvm::find_if(objectDsa, [&](std::pair<ObjectTy, int> &p) {
304 return p.first.id() == object.id();
305 });
306 if (found != objectDsa.end())
307 return *found;
308 return objectDsa.emplace_back(object, DSA::None);
309 };
310
311 using SharedTy = tomp::clause::SharedT<TypeTy, IdTy, ExprTy>;
312 using PrivateTy = tomp::clause::PrivateT<TypeTy, IdTy, ExprTy>;
313 using FirstprivateTy = tomp::clause::FirstprivateT<TypeTy, IdTy, ExprTy>;
314 using LastprivateTy = tomp::clause::LastprivateT<TypeTy, IdTy, ExprTy>;
315
316 // Visit clauses that affect DSA.
317 for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_shared]) {
318 for (auto &object : std::get<SharedTy>(clause.u).v)
319 getDsa(object).second |= DSA::Shared;
320 }
321
322 for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_private]) {
323 for (auto &object : std::get<PrivateTy>(clause.u).v)
324 getDsa(object).second |= DSA::Private;
325 }
326
327 for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_firstprivate]) {
328 for (auto &object : std::get<FirstprivateTy>(clause.u).v)
329 getDsa(object).second |= DSA::FirstPrivate;
330 }
331
332 for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_lastprivate]) {
333 using ModifierTy = typename LastprivateTy::LastprivateModifier;
334 using ListTy = typename LastprivateTy::List;
335 const auto &lastp = std::get<LastprivateTy>(clause.u);
336 for (auto &object : std::get<ListTy>(lastp.t)) {
337 auto &mod = std::get<std::optional<ModifierTy>>(lastp.t);
338 if (mod && *mod == ModifierTy::Conditional) {
339 getDsa(object).second |= DSA::LastPrivateConditional;
340 } else {
341 getDsa(object).second |= DSA::LastPrivate;
342 }
343 }
344 }
345
346 // Check other privatizing clauses as well, clear "shared" if set.
347 for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_in_reduction]) {
348 using InReductionTy = tomp::clause::InReductionT<TypeTy, IdTy, ExprTy>;
349 using ListTy = typename InReductionTy::List;
350 for (auto &object : std::get<ListTy>(std::get<InReductionTy>(clause.u).t))
351 getDsa(object).second &= ~DSA::Shared;
352 }
353 for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_linear]) {
354 using LinearTy = tomp::clause::LinearT<TypeTy, IdTy, ExprTy>;
355 using ListTy = typename LinearTy::List;
356 for (auto &object : std::get<ListTy>(std::get<LinearTy>(clause.u).t))
357 getDsa(object).second &= ~DSA::Shared;
358 }
359 for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_reduction]) {
360 using ReductionTy = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
361 using ListTy = typename ReductionTy::List;
362 for (auto &object : std::get<ListTy>(std::get<ReductionTy>(clause.u).t))
363 getDsa(object).second &= ~DSA::Shared;
364 }
365 for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_task_reduction]) {
366 using TaskReductionTy = tomp::clause::TaskReductionT<TypeTy, IdTy, ExprTy>;
367 using ListTy = typename TaskReductionTy::List;
368 for (auto &object : std::get<ListTy>(std::get<TaskReductionTy>(clause.u).t))
369 getDsa(object).second &= ~DSA::Shared;
370 }
371
372 tomp::ListT<ObjectTy> privateObj, sharedObj, firstpObj, lastpObj, lastpcObj;
373 for (auto &[object, dsa] : objectDsa) {
374 if (dsa &
375 (DSA::FirstPrivate | DSA::LastPrivate | DSA::LastPrivateConditional)) {
376 if (dsa & DSA::FirstPrivate)
377 firstpObj.push_back(object); // no else
378 if (dsa & DSA::LastPrivateConditional)
379 lastpcObj.push_back(object);
380 else if (dsa & DSA::LastPrivate)
381 lastpObj.push_back(object);
382 } else if (dsa & DSA::Private) {
383 privateObj.push_back(object);
384 } else if (dsa & DSA::Shared) {
385 sharedObj.push_back(object);
386 }
387 }
388
389 // Materialize each clause.
390 if (!privateObj.empty()) {
391 merged.clauses.emplace_back(
392 makeClause(llvm::omp::Clause::OMPC_private,
393 PrivateTy{/*List=*/std::move(privateObj)}));
394 }
395 if (!sharedObj.empty()) {
396 merged.clauses.emplace_back(
397 makeClause(llvm::omp::Clause::OMPC_shared,
398 SharedTy{/*List=*/std::move(sharedObj)}));
399 }
400 if (!firstpObj.empty()) {
401 merged.clauses.emplace_back(
402 makeClause(llvm::omp::Clause::OMPC_firstprivate,
403 FirstprivateTy{/*List=*/std::move(firstpObj)}));
404 }
405 if (!lastpObj.empty()) {
406 merged.clauses.emplace_back(
407 makeClause(llvm::omp::Clause::OMPC_lastprivate,
408 LastprivateTy{{/*LastprivateModifier=*/std::nullopt,
409 /*List=*/std::move(lastpObj)}}));
410 }
411 if (!lastpcObj.empty()) {
412 auto conditional = LastprivateTy::LastprivateModifier::Conditional;
413 merged.clauses.emplace_back(
414 makeClause(llvm::omp::Clause::OMPC_lastprivate,
415 LastprivateTy{{/*LastprivateModifier=*/conditional,
416 /*List=*/std::move(lastpcObj)}}));
417 }
418 }
419 } // namespace tomp
420
421 #endif // LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
422