LLVM 23.0.0git — llvm/Support/Parallel.h (captured Doxygen source listing for this file).
1//===- llvm/Support/Parallel.h - Parallel algorithms ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_SUPPORT_PARALLEL_H
10#define LLVM_SUPPORT_PARALLEL_H
11
12#include "llvm/ADT/STLExtras.h"
13#include "llvm/Config/llvm-config.h"
15#include "llvm/Support/Error.h"
18
19#include <algorithm>
20#include <atomic>
21#include <condition_variable>
22#include <functional>
23#include <mutex>
24
25namespace llvm {
26
27namespace parallel {
28
29// Strategy for the default executor used by the parallel routines provided by
30// this file. It defaults to using all hardware threads and should be
31// initialized before the first use of parallel routines.
33
34#if LLVM_ENABLE_THREADS
35#define GET_THREAD_INDEX_IMPL \
36 if (parallel::strategy.ThreadsRequested == 1) \
37 return 0; \
38 assert((threadIndex != UINT_MAX) && \
39 "getThreadIndex() must be called from a thread created by " \
40 "ThreadPoolExecutor"); \
41 return threadIndex;
42
43#ifdef _WIN32
44// Direct access to thread_local variables from a different DLL isn't
45// possible with Windows Native TLS.
46LLVM_ABI unsigned getThreadIndex();
47#else
48// Don't access this directly, use the getThreadIndex wrapper.
49LLVM_ABI extern thread_local unsigned threadIndex;
50
51inline unsigned getThreadIndex() { GET_THREAD_INDEX_IMPL; }
52#endif
53
55#else
56inline unsigned getThreadIndex() { return 0; }
57inline size_t getThreadCount() { return 1; }
58#endif
59
60namespace detail {
/// A countdown latch: inc() registers an outstanding task, dec() retires one,
/// and sync() blocks until the outstanding count reaches zero.
class Latch {
  std::atomic<uint32_t> Count;
  mutable std::mutex Mutex;
  mutable std::condition_variable Cond;

public:
  explicit Latch(uint32_t Count = 0) : Count(Count) {}
  // The latch must not be destroyed while tasks are still outstanding.
  ~Latch() { assert(Count.load(std::memory_order_relaxed) == 0); }

  // Register one more outstanding task.
  void inc() { Count.fetch_add(1, std::memory_order_relaxed); }

  // dec() must hold Mutex so that sync() cannot observe Count==0 and
  // destroy the Latch while dec() is still running.
  void dec() {
    std::lock_guard<std::mutex> lock(Mutex);
    // fetch_sub returns the previous value; == 1 means Count is now 0.
    if (Count.fetch_sub(1, std::memory_order_acq_rel) == 1)
      Cond.notify_all();
  }

  // Block until the count drops to zero. The predicate re-checks Count under
  // Mutex, so spurious wakeups are handled by condition_variable::wait.
  void sync() const {
    std::unique_lock<std::mutex> lock(Mutex);
    Cond.wait(lock, [&] { return Count.load(std::memory_order_relaxed) == 0; });
  }
};
86} // namespace detail
87
88class TaskGroup {
90 bool Parallel;
91
92public:
95
96 // Spawn a task, but does not wait for it to finish.
97 // Tasks marked with \p Sequential will be executed
98 // exactly in the order which they were spawned.
99 LLVM_ABI void spawn(std::function<void()> f);
100
101 void sync() const { L.sync(); }
102
103 bool isParallel() const { return Parallel; }
104};
105
106namespace detail {
107
108#if LLVM_ENABLE_THREADS
// Ranges shorter than this are not worth splitting across tasks:
// parallel_quick_sort falls back to a sequential llvm::sort below this size.
const ptrdiff_t MinParallelSize = 1024;
110
111/// Inclusive median.
112template <class RandomAccessIterator, class Comparator>
113RandomAccessIterator medianOf3(RandomAccessIterator Start,
114 RandomAccessIterator End,
115 const Comparator &Comp) {
116 RandomAccessIterator Mid = Start + (std::distance(Start, End) / 2);
117 return Comp(*Start, *(End - 1))
118 ? (Comp(*Mid, *(End - 1)) ? (Comp(*Start, *Mid) ? Mid : Start)
119 : End - 1)
120 : (Comp(*Mid, *Start) ? (Comp(*(End - 1), *Mid) ? Mid : End - 1)
121 : Start);
122}
123
// Quicksorts [Start, End) with Comp, spawning the recursion on the left
// partition onto \p TG so it can run concurrently with the right partition.
// \p Depth bounds the recursion; once it reaches zero the remaining range is
// sorted sequentially, which also bounds worst-case (quadratic) behavior.
template <class RandomAccessIterator, class Comparator>
void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End,
                         const Comparator &Comp, TaskGroup &TG, size_t Depth) {
  // Do a sequential sort for small inputs.
  if (std::distance(Start, End) < detail::MinParallelSize || Depth == 0) {
    llvm::sort(Start, End, Comp);
    return;
  }

  // Partition.
  auto Pivot = medianOf3(Start, End, Comp);
  // Move Pivot to End.
  std::swap(*(End - 1), *Pivot);
  // Partition everything but the stashed pivot at End-1.
  Pivot = std::partition(Start, End - 1, [&Comp, End](decltype(*Start) V) {
    return Comp(V, *(End - 1));
  });
  // Move Pivot to middle of partition.
  std::swap(*Pivot, *(End - 1));

  // Recurse. Iterators are captured by value; Comp and TG are captured by
  // reference and must outlive the spawned task (the caller is expected to
  // sync the TaskGroup before they go out of scope).
  TG.spawn([=, &Comp, &TG] {
    parallel_quick_sort(Start, Pivot, Comp, TG, Depth - 1);
  });
  parallel_quick_sort(Pivot + 1, End, Comp, TG, Depth - 1);
}
149
150template <class RandomAccessIterator, class Comparator>
151void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
152 const Comparator &Comp) {
153 TaskGroup TG;
154 parallel_quick_sort(Start, End, Comp, TG,
155 llvm::Log2_64(std::distance(Start, End)) + 1);
156}
157
// TaskGroup has a relatively high overhead, so we want to reduce
// the number of spawn() calls. We'll create up to 1024 tasks here.
// (Note that 1024 is an arbitrary number. This code probably needs
// improving to take the number of available cores into account.)
// Used by parallel_transform_reduce below to cap per-task scheduling cost.
enum { MaxTasksPerGroup = 1024 };
163
// Applies Transform to every element of [Begin, End) and folds the results
// with Reduce, seeded by Init. The input is split into at most
// MaxTasksPerGroup chunks, each reduced eagerly on its own task; the per-task
// partial results are then folded sequentially.
// NOTE(review): Init is folded into every task's partial result (R starts
// from Init, and Results is seeded with copies of Init), so a non-identity
// Init is incorporated once per task; Reduce is presumably expected to be
// associative for the split to match a sequential fold — confirm at call
// sites.
template <class IterTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallel_transform_reduce(IterTy Begin, IterTy End, ResultTy Init,
                                   ReduceFuncTy Reduce,
                                   TransformFuncTy Transform) {
  // Limit the number of tasks to MaxTasksPerGroup to limit job scheduling
  // overhead on large inputs.
  size_t NumInputs = std::distance(Begin, End);
  if (NumInputs == 0)
    return std::move(Init);
  size_t NumTasks = std::min(static_cast<size_t>(MaxTasksPerGroup), NumInputs);
  std::vector<ResultTy> Results(NumTasks, Init);
  {
    // Each task processes either TaskSize or TaskSize+1 inputs. Any inputs
    // remaining after dividing them equally amongst tasks are distributed as
    // one extra input over the first tasks.
    TaskGroup TG;
    size_t TaskSize = NumInputs / NumTasks;
    size_t RemainingInputs = NumInputs % NumTasks;
    IterTy TBegin = Begin;
    for (size_t TaskId = 0; TaskId < NumTasks; ++TaskId) {
      // TBegin/TEnd/TaskId are captured by value since they change on the
      // next iteration; Results outlives the block that contains TG.
      IterTy TEnd = TBegin + TaskSize + (TaskId < RemainingInputs ? 1 : 0);
      TG.spawn([=, &Transform, &Reduce, &Results] {
        // Reduce the result of transformation eagerly within each task.
        ResultTy R = Init;
        for (IterTy It = TBegin; It != TEnd; ++It)
          R = Reduce(R, Transform(*It));
        Results[TaskId] = R;
      });
      TBegin = TEnd;
    }
    assert(TBegin == End);
    // TG goes out of scope here, before Results is read below — presumably
    // TaskGroup's destructor waits for all spawned tasks; confirm in
    // Parallel.cpp.
  }

  // Do a final reduction. There are at most 1024 tasks, so this only adds
  // constant single-threaded overhead for large inputs. Hopefully most
  // reductions are cheaper than the transformation.
  ResultTy FinalResult = std::move(Results.front());
  for (ResultTy &PartialResult :
       MutableArrayRef(Results.data() + 1, Results.size() - 1))
    FinalResult = Reduce(FinalResult, std::move(PartialResult));
  return std::move(FinalResult);
}
207
208#endif
209
210} // namespace detail
211} // namespace parallel
212
213template <class RandomAccessIterator,
214 class Comparator = std::less<
215 typename std::iterator_traits<RandomAccessIterator>::value_type>>
216void parallelSort(RandomAccessIterator Start, RandomAccessIterator End,
217 const Comparator &Comp = Comparator()) {
218#if LLVM_ENABLE_THREADS
219 if (parallel::strategy.ThreadsRequested != 1) {
220 parallel::detail::parallel_sort(Start, End, Comp);
221 return;
222 }
223#endif
224 llvm::sort(Start, End, Comp);
225}
226
/// Runs Fn(I) for every I in [Begin, End). Defined out of line in
/// Parallel.cpp; presumably runs iterations in parallel when the strategy
/// allows — see the definition.
LLVM_ABI void parallelFor(size_t Begin, size_t End,
                          function_ref<void(size_t)> Fn);
229
/// Applies \p Fn to every element of [Begin, End) by delegating to
/// parallelFor over the element indices. Requires random-access iterators.
template <class IterTy, class FuncTy>
void parallelForEach(IterTy Begin, IterTy End, FuncTy Fn) {
  const size_t Count = End - Begin;
  parallelFor(0, Count, [&](size_t Index) { Fn(Begin[Index]); });
}
234
/// Transforms every element of [Begin, End) with \p Transform and folds the
/// results into a single value seeded with \p Init via \p Reduce. Dispatches
/// to the parallel implementation when threading is enabled and more than one
/// thread is requested; otherwise performs a sequential left fold.
template <class IterTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
#if LLVM_ENABLE_THREADS
  if (parallel::strategy.ThreadsRequested != 1)
    return parallel::detail::parallel_transform_reduce(Begin, End, Init, Reduce,
                                                       Transform);
#endif
  // Sequential fallback: fold left-to-right, moving the accumulator through
  // each Reduce call.
  for (; Begin != End; ++Begin)
    Init = Reduce(std::move(Init), Transform(*Begin));
  return std::move(Init);
}
250
251// Range wrappers.
/// Range overload: sorts the whole of \p R with \p Comp.
template <class RangeTy,
          class Comparator = std::less<decltype(*std::begin(RangeTy()))>>
void parallelSort(RangeTy &&R, const Comparator &Comp = Comparator()) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  parallelSort(First, Last, Comp);
}
257
/// Range overload: applies \p Fn to every element of \p R.
template <class RangeTy, class FuncTy>
void parallelForEach(RangeTy &&R, FuncTy Fn) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  parallelForEach(First, Last, Fn);
}
262
/// Range overload: transform-reduce over the whole of \p R.
template <class RangeTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(RangeTy &&R, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  return parallelTransformReduce(First, Last, Init, Reduce, Transform);
}
271
272// Parallel for-each, but with error handling.
273template <class RangeTy, class FuncTy>
274Error parallelForEachError(RangeTy &&R, FuncTy Fn) {
275 // The transform_reduce algorithm requires that the initial value be copyable.
276 // Error objects are uncopyable. We only need to copy initial success values,
277 // so work around this mismatch via the C API. The C API represents success
278 // values with a null pointer. The joinErrors discards null values and joins
279 // multiple errors into an ErrorList.
281 std::begin(R), std::end(R), wrap(Error::success()),
282 [](LLVMErrorRef Lhs, LLVMErrorRef Rhs) {
283 return wrap(joinErrors(unwrap(Lhs), unwrap(Rhs)));
284 },
285 [&Fn](auto &&V) { return wrap(Fn(V)); }));
286}
287
288} // namespace llvm
289
290#endif // LLVM_SUPPORT_PARALLEL_H
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Function Alias Analysis Results
#define LLVM_ABI
Definition Compiler.h:213
#define I(x, y, z)
Definition MD5.cpp:57
This file contains some templates that are useful if you are working with the STL at all.
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
This tells how a thread pool will be used.
Definition Threading.h:115
LLVM_ABI void spawn(std::function< void()> f)
Definition Parallel.cpp:237
bool isParallel() const
Definition Parallel.h:103
Latch(uint32_t Count=0)
Definition Parallel.h:67
struct LLVMOpaqueError * LLVMErrorRef
Opaque reference to an error instance.
Definition Error.h:34
LLVM_ABI ThreadPoolStrategy strategy
Definition Parallel.cpp:27
unsigned getThreadIndex()
Definition Parallel.h:56
size_t getThreadCount()
Definition Parallel.h:57
This is an optimization pass for GlobalISel generic memory operations.
unsigned Log2_64(uint64_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
Definition MathExtras.h:337
void parallelSort(RandomAccessIterator Start, RandomAccessIterator End, const Comparator &Comp=Comparator())
Definition Parallel.h:216
Error joinErrors(Error E1, Error E2)
Concatenate errors.
Definition Error.h:442
void sort(IteratorTy Start, IteratorTy End)
Definition STLExtras.h:1636
MutableArrayRef(T &OneElt) -> MutableArrayRef< T >
Attribute unwrap(LLVMAttributeRef Attr)
Definition Attributes.h:397
LLVMAttributeRef wrap(Attribute Attr)
Definition Attributes.h:392
LLVM_ABI void parallelFor(size_t Begin, size_t End, function_ref< void(size_t)> Fn)
Definition Parallel.cpp:248
ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init, ReduceFuncTy Reduce, TransformFuncTy Transform)
Definition Parallel.h:237
void parallelForEach(IterTy Begin, IterTy End, FuncTy Fn)
Definition Parallel.h:231
Error parallelForEachError(RangeTy &&R, FuncTy Fn)
Definition Parallel.h:274
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872