xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Frontend/OpenMP/ConstructCompositionT.h (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
//===- ConstructCompositionT.h -- Composing compound constructs -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Given a list of leaf constructs, each with a set of clauses, generate the
// compound construct whose leaf constructs are the given list, and whose clause
// list is the merge of the individual leaf clause lists.
//
// *** At the moment it assumes that the individual constructs and their clauses
// *** are a subset of those created by splitting a valid compound construct.
//===----------------------------------------------------------------------===//
15 #ifndef LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
16 #define LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/BitVector.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Frontend/OpenMP/ClauseT.h"
23 #include "llvm/Frontend/OpenMP/OMP.h"
24 
25 #include <iterator>
26 #include <optional>
27 #include <tuple>
28 #include <unordered_map>
29 #include <unordered_set>
30 #include <utility>
31 
32 namespace tomp {
33 template <typename ClauseType> struct ConstructCompositionT {
34   using ClauseTy = ClauseType;
35 
36   using TypeTy = typename ClauseTy::TypeTy;
37   using IdTy = typename ClauseTy::IdTy;
38   using ExprTy = typename ClauseTy::ExprTy;
39 
40   ConstructCompositionT(uint32_t version,
41                         llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> leafs);
42 
43   DirectiveWithClauses<ClauseTy> merged;
44 
45 private:
46   // Use an ordered container, since we beed to maintain the order in which
47   // clauses are added to it. This is to avoid non-deterministic output.
48   using ClauseSet = ListT<ClauseTy>;
49 
50   enum class Presence {
51     All,  // Clause is preesnt on all leaf constructs that allow it.
52     Some, // Clause is present on some, but not on all constructs.
53     None, // Clause is absent on all constructs.
54   };
55 
56   template <typename S>
makeClauseConstructCompositionT57   ClauseTy makeClause(llvm::omp::Clause clauseId, S &&specific) {
58     return typename ClauseTy::BaseT{clauseId, std::move(specific)};
59   }
60 
61   llvm::omp::Directive
62   makeCompound(llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts);
63 
64   Presence checkPresence(llvm::omp::Clause clauseId);
65 
66   // There are clauses that need special handling:
67   // 1. "if": the "directive-name-modifier" on the merged clause may need
68   // to be set appropriately.
69   // 2. "reduction": implies "privateness" of all objects (incompatible
70   // with "shared"); there are rules for merging modifiers
71   void mergeIf();
72   void mergeReduction();
73   void mergeDSA();
74 
75   uint32_t version;
76   llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> leafs;
77 
78   // clause id -> set of leaf constructs that contain it
79   std::unordered_map<llvm::omp::Clause, llvm::BitVector> clausePresence;
80   // clause id -> set of instances of that clause
81   std::unordered_map<llvm::omp::Clause, ClauseSet> clauseSets;
82 };
83 
84 template <typename C>
ConstructCompositionT(uint32_t version,llvm::ArrayRef<DirectiveWithClauses<C>> leafs)85 ConstructCompositionT<C>::ConstructCompositionT(
86     uint32_t version, llvm::ArrayRef<DirectiveWithClauses<C>> leafs)
87     : version(version), leafs(leafs) {
88   // Merge the list of constructs with clauses into a compound construct
89   // with a single list of clauses.
90   // The intended use of this function is in splitting compound constructs,
91   // while preserving composite constituent constructs:
92   // Step 1: split compound construct into leaf constructs.
93   // Step 2: identify composite sub-construct, and merge the constituent leafs.
94   //
95   // *** At the moment it assumes that the individual constructs and their
96   // *** clauses are a subset of those created by splitting a valid compound
97   // *** construct.
98   //
99   // 1. Deduplicate clauses
100   //    - exact duplicates: e.g. shared(x) shared(x) -> shared(x)
101   //    - special cases of clauses differing in modifier:
102   //      (a) reduction: inscan + (none|default) = inscan
103   //      (b) reduction: task + (none|default) = task
104   //      (c) combine repeated "if" clauses if possible
105   // 2. Merge DSA clauses: e.g. private(x) private(y) -> private(x, y).
106   // 3. Resolve potential DSA conflicts (typically due to implied clauses).
107 
108   if (leafs.empty())
109     return;
110 
111   merged.id = makeCompound(leafs);
112 
113   // Populate the two maps:
114   for (const auto &[index, leaf] : llvm::enumerate(leafs)) {
115     for (const auto &clause : leaf.clauses) {
116       // Update clausePresence.
117       auto &pset = clausePresence[clause.id];
118       if (pset.size() < leafs.size())
119         pset.resize(leafs.size());
120       pset.set(index);
121       // Update clauseSets.
122       ClauseSet &cset = clauseSets[clause.id];
123       if (!llvm::is_contained(cset, clause))
124         cset.push_back(clause);
125     }
126   }
127 
128   mergeIf();
129   mergeReduction();
130   mergeDSA();
131 
132   // Fir the rest of the clauses, just copy them.
133   for (auto &[id, clauses] : clauseSets) {
134     // Skip clauses we've already dealt with.
135     switch (id) {
136     case llvm::omp::Clause::OMPC_if:
137     case llvm::omp::Clause::OMPC_reduction:
138     case llvm::omp::Clause::OMPC_shared:
139     case llvm::omp::Clause::OMPC_private:
140     case llvm::omp::Clause::OMPC_firstprivate:
141     case llvm::omp::Clause::OMPC_lastprivate:
142       continue;
143     default:
144       break;
145     }
146     llvm::append_range(merged.clauses, clauses);
147   }
148 }
149 
150 template <typename C>
makeCompound(llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts)151 llvm::omp::Directive ConstructCompositionT<C>::makeCompound(
152     llvm::ArrayRef<DirectiveWithClauses<ClauseTy>> parts) {
153   llvm::SmallVector<llvm::omp::Directive> dirIds;
154   llvm::transform(parts, std::back_inserter(dirIds),
155                   [](auto &&dwc) { return dwc.id; });
156 
157   return llvm::omp::getCompoundConstruct(dirIds);
158 }
159 
160 template <typename C>
161 auto ConstructCompositionT<C>::checkPresence(llvm::omp::Clause clauseId)
162     -> Presence {
163   auto found = clausePresence.find(clauseId);
164   if (found == clausePresence.end())
165     return Presence::None;
166 
167   bool OnAll = true, OnNone = true;
168   for (const auto &[index, leaf] : llvm::enumerate(leafs)) {
169     if (!llvm::omp::isAllowedClauseForDirective(leaf.id, clauseId, version))
170       continue;
171 
172     if (found->second.test(index))
173       OnNone = false;
174     else
175       OnAll = false;
176   }
177 
178   if (OnNone)
179     return Presence::None;
180   if (OnAll)
181     return Presence::All;
182   return Presence::Some;
183 }
184 
mergeIf()185 template <typename C> void ConstructCompositionT<C>::mergeIf() {
186   using IfTy = tomp::clause::IfT<TypeTy, IdTy, ExprTy>;
187   // Deal with the "if" clauses. If it's on all leafs that allow it, then it
188   // will apply to the compound construct. Otherwise it will apply to the
189   // single (assumed) leaf construct.
190   // This assumes that the "if" clauses have the same expression.
191   Presence presence = checkPresence(llvm::omp::Clause::OMPC_if);
192   if (presence == Presence::None)
193     return;
194 
195   const ClauseTy &some = *clauseSets[llvm::omp::Clause::OMPC_if].begin();
196   const auto &someIf = std::get<IfTy>(some.u);
197 
198   if (presence == Presence::All) {
199     // Create "if" without "directive-name-modifier".
200     merged.clauses.emplace_back(
201         makeClause(llvm::omp::Clause::OMPC_if,
202                    IfTy{{/*DirectiveNameModifier=*/std::nullopt,
203                          /*IfExpression=*/std::get<typename IfTy::IfExpression>(
204                              someIf.t)}}));
205   } else {
206     // Find out where it's present and create "if" with the corresponding
207     // "directive-name-modifier".
208     int Idx = clausePresence[llvm::omp::Clause::OMPC_if].find_first();
209     assert(Idx >= 0);
210     merged.clauses.emplace_back(
211         makeClause(llvm::omp::Clause::OMPC_if,
212                    IfTy{{/*DirectiveNameModifier=*/leafs[Idx].id,
213                          /*IfExpression=*/std::get<typename IfTy::IfExpression>(
214                              someIf.t)}}));
215   }
216 }
217 
mergeReduction()218 template <typename C> void ConstructCompositionT<C>::mergeReduction() {
219   Presence presence = checkPresence(llvm::omp::Clause::OMPC_reduction);
220   if (presence == Presence::None)
221     return;
222 
223   using ReductionTy = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
224   using ModifierTy = typename ReductionTy::ReductionModifier;
225   using IdentifiersTy = typename ReductionTy::ReductionIdentifiers;
226   using ListTy = typename ReductionTy::List;
227   // There are exceptions on which constructs "reduction" may appear
228   // (specifically "parallel", and "teams"). Assume that if "reduction"
229   // is present, it can be applied to the compound construct.
230 
231   // What's left is to see if there are any modifiers present. Again,
232   // assume that there are no conflicting modifiers.
233   // There can be, however, multiple reductions on different objects.
234   auto equal = [](const ClauseTy &red1, const ClauseTy &red2) {
235     // Extract actual reductions.
236     const auto r1 = std::get<ReductionTy>(red1.u);
237     const auto r2 = std::get<ReductionTy>(red2.u);
238     // Compare everything except modifiers.
239     if (std::get<IdentifiersTy>(r1.t) != std::get<IdentifiersTy>(r2.t))
240       return false;
241     if (std::get<ListTy>(r1.t) != std::get<ListTy>(r2.t))
242       return false;
243     return true;
244   };
245 
246   auto getModifier = [](const ClauseTy &clause) {
247     const ReductionTy &red = std::get<ReductionTy>(clause.u);
248     return std::get<std::optional<ModifierTy>>(red.t);
249   };
250 
251   const ClauseSet &reductions = clauseSets[llvm::omp::Clause::OMPC_reduction];
252   std::unordered_set<const ClauseTy *> visited;
253   while (reductions.size() != visited.size()) {
254     typename ClauseSet::const_iterator first;
255 
256     // Find first non-visited reduction.
257     for (first = reductions.begin(); first != reductions.end(); ++first) {
258       if (visited.count(&*first))
259         continue;
260       visited.insert(&*first);
261       break;
262     }
263 
264     std::optional<ModifierTy> modifier = getModifier(*first);
265 
266     // Visit all other reductions that are "equal" (with respect to the
267     // definition above) to "first". Collect modifiers.
268     for (auto iter = std::next(first); iter != reductions.end(); ++iter) {
269       if (!equal(*first, *iter))
270         continue;
271       visited.insert(&*iter);
272       if (!modifier || *modifier == ModifierTy::Default)
273         modifier = getModifier(*iter);
274     }
275 
276     const auto &firstRed = std::get<ReductionTy>(first->u);
277     merged.clauses.emplace_back(makeClause(
278         llvm::omp::Clause::OMPC_reduction,
279         ReductionTy{
280             {/*ReductionModifier=*/modifier,
281              /*ReductionIdentifiers=*/std::get<IdentifiersTy>(firstRed.t),
282              /*List=*/std::get<ListTy>(firstRed.t)}}));
283   }
284 }
285 
mergeDSA()286 template <typename C> void ConstructCompositionT<C>::mergeDSA() {
287   using ObjectTy = tomp::type::ObjectT<IdTy, ExprTy>;
288 
289   // Resolve data-sharing attributes.
290   enum DSA : int {
291     None = 0,
292     Shared = 1 << 0,
293     Private = 1 << 1,
294     FirstPrivate = 1 << 2,
295     LastPrivate = 1 << 3,
296     LastPrivateConditional = 1 << 4,
297   };
298 
299   // Use ordered containers to avoid non-deterministic output.
300   llvm::SmallVector<std::pair<ObjectTy, int>, 8> objectDsa;
301 
302   auto getDsa = [&](const ObjectTy &object) -> std::pair<ObjectTy, int> & {
303     auto found = llvm::find_if(objectDsa, [&](std::pair<ObjectTy, int> &p) {
304       return p.first.id() == object.id();
305     });
306     if (found != objectDsa.end())
307       return *found;
308     return objectDsa.emplace_back(object, DSA::None);
309   };
310 
311   using SharedTy = tomp::clause::SharedT<TypeTy, IdTy, ExprTy>;
312   using PrivateTy = tomp::clause::PrivateT<TypeTy, IdTy, ExprTy>;
313   using FirstprivateTy = tomp::clause::FirstprivateT<TypeTy, IdTy, ExprTy>;
314   using LastprivateTy = tomp::clause::LastprivateT<TypeTy, IdTy, ExprTy>;
315 
316   // Visit clauses that affect DSA.
317   for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_shared]) {
318     for (auto &object : std::get<SharedTy>(clause.u).v)
319       getDsa(object).second |= DSA::Shared;
320   }
321 
322   for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_private]) {
323     for (auto &object : std::get<PrivateTy>(clause.u).v)
324       getDsa(object).second |= DSA::Private;
325   }
326 
327   for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_firstprivate]) {
328     for (auto &object : std::get<FirstprivateTy>(clause.u).v)
329       getDsa(object).second |= DSA::FirstPrivate;
330   }
331 
332   for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_lastprivate]) {
333     using ModifierTy = typename LastprivateTy::LastprivateModifier;
334     using ListTy = typename LastprivateTy::List;
335     const auto &lastp = std::get<LastprivateTy>(clause.u);
336     for (auto &object : std::get<ListTy>(lastp.t)) {
337       auto &mod = std::get<std::optional<ModifierTy>>(lastp.t);
338       if (mod && *mod == ModifierTy::Conditional) {
339         getDsa(object).second |= DSA::LastPrivateConditional;
340       } else {
341         getDsa(object).second |= DSA::LastPrivate;
342       }
343     }
344   }
345 
346   // Check other privatizing clauses as well, clear "shared" if set.
347   for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_in_reduction]) {
348     using InReductionTy = tomp::clause::InReductionT<TypeTy, IdTy, ExprTy>;
349     using ListTy = typename InReductionTy::List;
350     for (auto &object : std::get<ListTy>(std::get<InReductionTy>(clause.u).t))
351       getDsa(object).second &= ~DSA::Shared;
352   }
353   for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_linear]) {
354     using LinearTy = tomp::clause::LinearT<TypeTy, IdTy, ExprTy>;
355     using ListTy = typename LinearTy::List;
356     for (auto &object : std::get<ListTy>(std::get<LinearTy>(clause.u).t))
357       getDsa(object).second &= ~DSA::Shared;
358   }
359   for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_reduction]) {
360     using ReductionTy = tomp::clause::ReductionT<TypeTy, IdTy, ExprTy>;
361     using ListTy = typename ReductionTy::List;
362     for (auto &object : std::get<ListTy>(std::get<ReductionTy>(clause.u).t))
363       getDsa(object).second &= ~DSA::Shared;
364   }
365   for (auto &clause : clauseSets[llvm::omp::Clause::OMPC_task_reduction]) {
366     using TaskReductionTy = tomp::clause::TaskReductionT<TypeTy, IdTy, ExprTy>;
367     using ListTy = typename TaskReductionTy::List;
368     for (auto &object : std::get<ListTy>(std::get<TaskReductionTy>(clause.u).t))
369       getDsa(object).second &= ~DSA::Shared;
370   }
371 
372   tomp::ListT<ObjectTy> privateObj, sharedObj, firstpObj, lastpObj, lastpcObj;
373   for (auto &[object, dsa] : objectDsa) {
374     if (dsa &
375         (DSA::FirstPrivate | DSA::LastPrivate | DSA::LastPrivateConditional)) {
376       if (dsa & DSA::FirstPrivate)
377         firstpObj.push_back(object); // no else
378       if (dsa & DSA::LastPrivateConditional)
379         lastpcObj.push_back(object);
380       else if (dsa & DSA::LastPrivate)
381         lastpObj.push_back(object);
382     } else if (dsa & DSA::Private) {
383       privateObj.push_back(object);
384     } else if (dsa & DSA::Shared) {
385       sharedObj.push_back(object);
386     }
387   }
388 
389   // Materialize each clause.
390   if (!privateObj.empty()) {
391     merged.clauses.emplace_back(
392         makeClause(llvm::omp::Clause::OMPC_private,
393                    PrivateTy{/*List=*/std::move(privateObj)}));
394   }
395   if (!sharedObj.empty()) {
396     merged.clauses.emplace_back(
397         makeClause(llvm::omp::Clause::OMPC_shared,
398                    SharedTy{/*List=*/std::move(sharedObj)}));
399   }
400   if (!firstpObj.empty()) {
401     merged.clauses.emplace_back(
402         makeClause(llvm::omp::Clause::OMPC_firstprivate,
403                    FirstprivateTy{/*List=*/std::move(firstpObj)}));
404   }
405   if (!lastpObj.empty()) {
406     merged.clauses.emplace_back(
407         makeClause(llvm::omp::Clause::OMPC_lastprivate,
408                    LastprivateTy{{/*LastprivateModifier=*/std::nullopt,
409                                   /*List=*/std::move(lastpObj)}}));
410   }
411   if (!lastpcObj.empty()) {
412     auto conditional = LastprivateTy::LastprivateModifier::Conditional;
413     merged.clauses.emplace_back(
414         makeClause(llvm::omp::Clause::OMPC_lastprivate,
415                    LastprivateTy{{/*LastprivateModifier=*/conditional,
416                                   /*List=*/std::move(lastpcObj)}}));
417   }
418 }
419 } // namespace tomp
420 
421 #endif // LLVM_FRONTEND_OPENMP_CONSTRUCTCOMPOSITIONT_H
422