//===- llvm/Support/Parallel.h - Parallel algorithms ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9 #ifndef LLVM_SUPPORT_PARALLEL_H
10 #define LLVM_SUPPORT_PARALLEL_H
11
12 #include "llvm/ADT/STLExtras.h"
13 #include "llvm/Config/llvm-config.h"
14 #include "llvm/Support/Error.h"
15 #include "llvm/Support/MathExtras.h"
16 #include "llvm/Support/Threading.h"
17
18 #include <algorithm>
19 #include <condition_variable>
20 #include <functional>
21 #include <mutex>
22
23 namespace llvm {
24
25 namespace parallel {
26
// Strategy for the default executor used by the parallel routines provided by
// this file. It defaults to using all hardware threads and should be
// initialized before the first use of parallel routines.
extern ThreadPoolStrategy strategy;

#if LLVM_ENABLE_THREADS
// Shared body of the getThreadIndex() flavors below. With a single-threaded
// strategy there is no pool thread, so the index is always 0; otherwise the
// thread-local threadIndex must already have been initialized by
// ThreadPoolExecutor (enforced by the assert).
#define GET_THREAD_INDEX_IMPL                                                  \
  if (parallel::strategy.ThreadsRequested == 1)                                \
    return 0;                                                                  \
  assert((threadIndex != UINT_MAX) &&                                          \
         "getThreadIndex() must be called from a thread created by "           \
         "ThreadPoolExecutor");                                                \
  return threadIndex;

#ifdef _WIN32
// Direct access to thread_local variables from a different DLL isn't
// possible with Windows Native TLS.
unsigned getThreadIndex();
#else
// Don't access this directly, use the getThreadIndex wrapper.
extern thread_local unsigned threadIndex;

// Index of the calling thread within the executor's pool (0 when the
// strategy requests a single thread).
inline unsigned getThreadIndex() { GET_THREAD_INDEX_IMPL; }
#endif

// Number of threads available to the parallel routines (defined out of line).
size_t getThreadCount();
#else
// Threads disabled at configure time: only the main thread exists.
inline unsigned getThreadIndex() { return 0; }
inline size_t getThreadCount() { return 1; }
#endif
57
58 namespace detail {
/// A simple counting latch: counts outstanding work items and lets callers
/// block until the count returns to zero.
class Latch {
  uint32_t Count;
  mutable std::mutex Mutex;
  mutable std::condition_variable Cond;

public:
  /// Start with \p Count outstanding items (default: none).
  explicit Latch(uint32_t Count = 0) : Count(Count) {}

  ~Latch() {
    // Ensure at least that sync() was called.
    assert(Count == 0);
  }

  /// Register one more outstanding item.
  void inc() {
    std::lock_guard<std::mutex> Guard(Mutex);
    ++Count;
  }

  /// Retire one item; wakes every waiter once none remain.
  void dec() {
    std::lock_guard<std::mutex> Guard(Mutex);
    if (--Count == 0)
      Cond.notify_all();
  }

  /// Block until all registered items have been retired.
  void sync() const {
    std::unique_lock<std::mutex> Guard(Mutex);
    Cond.wait(Guard, [&] { return Count == 0; });
  }
};
87 } // namespace detail
88
89 class TaskGroup {
90 detail::Latch L;
91 bool Parallel;
92
93 public:
94 TaskGroup();
95 ~TaskGroup();
96
97 // Spawn a task, but does not wait for it to finish.
98 // Tasks marked with \p Sequential will be executed
99 // exactly in the order which they were spawned.
100 void spawn(std::function<void()> f);
101
sync()102 void sync() const { L.sync(); }
103
isParallel()104 bool isParallel() const { return Parallel; }
105 };
106
107 namespace detail {
108
109 #if LLVM_ENABLE_THREADS
// Ranges smaller than this are sorted sequentially.
const ptrdiff_t MinParallelSize = 1024;

/// Inclusive median.
///
/// Returns an iterator to whichever of the first, middle and last element of
/// [Start, End) is ordered between the other two under \p Comp.
template <class RandomAccessIterator, class Comparator>
RandomAccessIterator medianOf3(RandomAccessIterator Start,
                               RandomAccessIterator End,
                               const Comparator &Comp) {
  RandomAccessIterator Last = End - 1;
  RandomAccessIterator Mid = Start + (std::distance(Start, End) / 2);
  if (Comp(*Start, *Last)) {
    // First < last: the median is whichever of the three is not an extreme.
    if (!Comp(*Mid, *Last))
      return Last;
    return Comp(*Start, *Mid) ? Mid : Start;
  }
  // Last <= first (mirror case).
  if (!Comp(*Mid, *Start))
    return Start;
  return Comp(*Last, *Mid) ? Mid : Last;
}
124
125 template <class RandomAccessIterator, class Comparator>
parallel_quick_sort(RandomAccessIterator Start,RandomAccessIterator End,const Comparator & Comp,TaskGroup & TG,size_t Depth)126 void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End,
127 const Comparator &Comp, TaskGroup &TG, size_t Depth) {
128 // Do a sequential sort for small inputs.
129 if (std::distance(Start, End) < detail::MinParallelSize || Depth == 0) {
130 llvm::sort(Start, End, Comp);
131 return;
132 }
133
134 // Partition.
135 auto Pivot = medianOf3(Start, End, Comp);
136 // Move Pivot to End.
137 std::swap(*(End - 1), *Pivot);
138 Pivot = std::partition(Start, End - 1, [&Comp, End](decltype(*Start) V) {
139 return Comp(V, *(End - 1));
140 });
141 // Move Pivot to middle of partition.
142 std::swap(*Pivot, *(End - 1));
143
144 // Recurse.
145 TG.spawn([=, &Comp, &TG] {
146 parallel_quick_sort(Start, Pivot, Comp, TG, Depth - 1);
147 });
148 parallel_quick_sort(Pivot + 1, End, Comp, TG, Depth - 1);
149 }
150
151 template <class RandomAccessIterator, class Comparator>
parallel_sort(RandomAccessIterator Start,RandomAccessIterator End,const Comparator & Comp)152 void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
153 const Comparator &Comp) {
154 TaskGroup TG;
155 parallel_quick_sort(Start, End, Comp, TG,
156 llvm::Log2_64(std::distance(Start, End)) + 1);
157 }
158
159 // TaskGroup has a relatively high overhead, so we want to reduce
160 // the number of spawn() calls. We'll create up to 1024 tasks here.
161 // (Note that 1024 is an arbitrary number. This code probably needs
162 // improving to take the number of available cores into account.)
163 enum { MaxTasksPerGroup = 1024 };
164
165 template <class IterTy, class ResultTy, class ReduceFuncTy,
166 class TransformFuncTy>
parallel_transform_reduce(IterTy Begin,IterTy End,ResultTy Init,ReduceFuncTy Reduce,TransformFuncTy Transform)167 ResultTy parallel_transform_reduce(IterTy Begin, IterTy End, ResultTy Init,
168 ReduceFuncTy Reduce,
169 TransformFuncTy Transform) {
170 // Limit the number of tasks to MaxTasksPerGroup to limit job scheduling
171 // overhead on large inputs.
172 size_t NumInputs = std::distance(Begin, End);
173 if (NumInputs == 0)
174 return std::move(Init);
175 size_t NumTasks = std::min(static_cast<size_t>(MaxTasksPerGroup), NumInputs);
176 std::vector<ResultTy> Results(NumTasks, Init);
177 {
178 // Each task processes either TaskSize or TaskSize+1 inputs. Any inputs
179 // remaining after dividing them equally amongst tasks are distributed as
180 // one extra input over the first tasks.
181 TaskGroup TG;
182 size_t TaskSize = NumInputs / NumTasks;
183 size_t RemainingInputs = NumInputs % NumTasks;
184 IterTy TBegin = Begin;
185 for (size_t TaskId = 0; TaskId < NumTasks; ++TaskId) {
186 IterTy TEnd = TBegin + TaskSize + (TaskId < RemainingInputs ? 1 : 0);
187 TG.spawn([=, &Transform, &Reduce, &Results] {
188 // Reduce the result of transformation eagerly within each task.
189 ResultTy R = Init;
190 for (IterTy It = TBegin; It != TEnd; ++It)
191 R = Reduce(R, Transform(*It));
192 Results[TaskId] = R;
193 });
194 TBegin = TEnd;
195 }
196 assert(TBegin == End);
197 }
198
199 // Do a final reduction. There are at most 1024 tasks, so this only adds
200 // constant single-threaded overhead for large inputs. Hopefully most
201 // reductions are cheaper than the transformation.
202 ResultTy FinalResult = std::move(Results.front());
203 for (ResultTy &PartialResult :
204 MutableArrayRef(Results.data() + 1, Results.size() - 1))
205 FinalResult = Reduce(FinalResult, std::move(PartialResult));
206 return std::move(FinalResult);
207 }
208
209 #endif
210
211 } // namespace detail
212 } // namespace parallel
213
/// Sorts [Start, End) with \p Comp — in parallel when threading is enabled
/// and the default executor strategy requests more than one thread,
/// sequentially otherwise.
template <class RandomAccessIterator,
          class Comparator = std::less<
              typename std::iterator_traits<RandomAccessIterator>::value_type>>
void parallelSort(RandomAccessIterator Start, RandomAccessIterator End,
                  const Comparator &Comp = Comparator()) {
#if LLVM_ENABLE_THREADS
  if (parallel::strategy.ThreadsRequested != 1) {
    parallel::detail::parallel_sort(Start, End, Comp);
    return;
  }
#endif
  // Sequential fallback.
  llvm::sort(Start, End, Comp);
}
227
228 void parallelFor(size_t Begin, size_t End, function_ref<void(size_t)> Fn);
229
230 template <class IterTy, class FuncTy>
parallelForEach(IterTy Begin,IterTy End,FuncTy Fn)231 void parallelForEach(IterTy Begin, IterTy End, FuncTy Fn) {
232 parallelFor(0, End - Begin, [&](size_t I) { Fn(Begin[I]); });
233 }
234
/// Applies \p Transform to every element of [Begin, End) and folds the
/// results into \p Init with \p Reduce; runs in parallel when the default
/// executor allows more than one thread, sequentially otherwise.
template <class IterTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
#if LLVM_ENABLE_THREADS
  if (parallel::strategy.ThreadsRequested != 1) {
    return parallel::detail::parallel_transform_reduce(Begin, End, Init, Reduce,
                                                       Transform);
  }
#endif
  // Sequential fallback.
  for (IterTy It = Begin; It != End; ++It)
    Init = Reduce(std::move(Init), Transform(*It));
  return std::move(Init);
}
250
// Range wrappers.

/// Range overload of parallelSort: sorts the whole of \p R.
template <class RangeTy,
          class Comparator = std::less<decltype(*std::begin(RangeTy()))>>
void parallelSort(RangeTy &&R, const Comparator &Comp = Comparator()) {
  parallelSort(std::begin(R), std::end(R), Comp);
}
257
/// Range overload of parallelForEach: applies \p Fn to every element of \p R.
template <class RangeTy, class FuncTy>
void parallelForEach(RangeTy &&R, FuncTy Fn) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  parallelForEach(First, Last, Fn);
}
262
/// Range overload of parallelTransformReduce over all elements of \p R.
template <class RangeTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(RangeTy &&R, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  return parallelTransformReduce(First, Last, Init, Reduce, Transform);
}
271
272 // Parallel for-each, but with error handling.
273 template <class RangeTy, class FuncTy>
parallelForEachError(RangeTy && R,FuncTy Fn)274 Error parallelForEachError(RangeTy &&R, FuncTy Fn) {
275 // The transform_reduce algorithm requires that the initial value be copyable.
276 // Error objects are uncopyable. We only need to copy initial success values,
277 // so work around this mismatch via the C API. The C API represents success
278 // values with a null pointer. The joinErrors discards null values and joins
279 // multiple errors into an ErrorList.
280 return unwrap(parallelTransformReduce(
281 std::begin(R), std::end(R), wrap(Error::success()),
282 [](LLVMErrorRef Lhs, LLVMErrorRef Rhs) {
283 return wrap(joinErrors(unwrap(Lhs), unwrap(Rhs)));
284 },
285 [&Fn](auto &&V) { return wrap(Fn(V)); }));
286 }
287
288 } // namespace llvm
289
290 #endif // LLVM_SUPPORT_PARALLEL_H
291