LLVM 23.0.0git
Parallel.h
Go to the documentation of this file.
1//===- llvm/Support/Parallel.h - Parallel algorithms ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_SUPPORT_PARALLEL_H
10#define LLVM_SUPPORT_PARALLEL_H
11
12#include "llvm/ADT/STLExtras.h"
13#include "llvm/Config/llvm-config.h"
15#include "llvm/Support/Error.h"
18
19#include <algorithm>
20#include <atomic>
21#include <condition_variable>
22#include <functional>
23#include <mutex>
24
25namespace llvm {
26
27namespace parallel {
28
29// Strategy for the default executor used by the parallel routines provided by
30// this file. It defaults to using all hardware threads and should be
31// initialized before the first use of parallel routines.
33
#if LLVM_ENABLE_THREADS
// Shared body of getThreadIndex(): with a single requested thread the index
// is always 0; otherwise threadIndex must have been set by ThreadPoolExecutor.
#define GET_THREAD_INDEX_IMPL \
  if (parallel::strategy.ThreadsRequested == 1) \
    return 0; \
  assert((threadIndex != UINT_MAX) && \
         "getThreadIndex() must be called from a thread created by " \
         "ThreadPoolExecutor"); \
  return threadIndex;

#ifdef _WIN32
// Direct access to thread_local variables from a different DLL isn't
// possible with Windows Native TLS.
LLVM_ABI unsigned getThreadIndex();
#else
// Don't access this directly, use the getThreadIndex wrapper.
LLVM_ABI extern thread_local unsigned threadIndex;

inline unsigned getThreadIndex() { GET_THREAD_INDEX_IMPL; }
#endif

// Number of threads used by the executor. (Declaration restored: the
// extraction dropped this line; the #else branch below defines the matching
// single-threaded version. Out-of-line definition is in Parallel.cpp.)
LLVM_ABI size_t getThreadCount();
#else
// Threads disabled at build time: everything runs on the one main thread.
inline unsigned getThreadIndex() { return 0; }
inline size_t getThreadCount() { return 1; }
#endif
59
60namespace detail {
// A simple count-up/count-down latch: inc() registers an outstanding work
// item, dec() retires one, and sync() blocks until the count reaches zero.
class Latch {
  std::atomic<uint32_t> Count;
  mutable std::mutex Mutex;
  mutable std::condition_variable Cond;

public:
  // Starts with \p Count outstanding work items.
  explicit Latch(uint32_t Count = 0) : Count(Count) {}

  // Every registered work item must have been retired before destruction.
  ~Latch() { assert(Count.load(std::memory_order_relaxed) == 0); }

  // Registers one more outstanding work item.
  void inc() { Count.fetch_add(1, std::memory_order_relaxed); }

  // Retires one work item, waking waiters when the last one finishes.
  // Holding Mutex for the whole decrement keeps sync() from observing
  // Count == 0 and letting the Latch be destroyed while dec() still runs.
  void dec() {
    std::lock_guard<std::mutex> Guard(Mutex);
    // fetch_sub yields the pre-decrement value; 1 means we just reached 0.
    if (Count.fetch_sub(1, std::memory_order_acq_rel) == 1)
      Cond.notify_all();
  }

  // Snapshot of the number of still-outstanding work items.
  uint32_t getCount() const { return Count.load(std::memory_order_acquire); }

  // Blocks until all registered work items have been retired.
  void sync() const {
    std::unique_lock<std::mutex> Guard(Mutex);
    Cond.wait(Guard,
              [this] { return Count.load(std::memory_order_relaxed) == 0; });
  }
};
88} // namespace detail
89
90class TaskGroup {
92 bool Parallel;
93
94public:
97
98 // Spawn a task, but does not wait for it to finish.
99 LLVM_ABI void spawn(std::function<void()> f);
100
101 bool isParallel() const { return Parallel; }
102};
103
104namespace detail {
105
106#if LLVM_ENABLE_THREADS
// Ranges smaller than this are sorted sequentially by parallel_quick_sort:
// below this size the task-spawning overhead outweighs any parallel speedup.
const ptrdiff_t MinParallelSize = 1024;
108
109/// Inclusive median.
110template <class RandomAccessIterator, class Comparator>
111RandomAccessIterator medianOf3(RandomAccessIterator Start,
112 RandomAccessIterator End,
113 const Comparator &Comp) {
114 RandomAccessIterator Mid = Start + (std::distance(Start, End) / 2);
115 return Comp(*Start, *(End - 1))
116 ? (Comp(*Mid, *(End - 1)) ? (Comp(*Start, *Mid) ? Mid : Start)
117 : End - 1)
118 : (Comp(*Mid, *Start) ? (Comp(*(End - 1), *Mid) ? Mid : End - 1)
119 : Start);
120}
121
// Parallel quicksort of [Start, End): partitions around a median-of-3 pivot
// and sorts the two halves concurrently via TG. Depth bounds the parallel
// recursion; once exhausted (or for small ranges) it falls back to a
// sequential llvm::sort.
template <class RandomAccessIterator, class Comparator>
void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End,
                         const Comparator &Comp, TaskGroup &TG, size_t Depth) {
  // Do a sequential sort for small inputs.
  if (std::distance(Start, End) < detail::MinParallelSize || Depth == 0) {
    llvm::sort(Start, End, Comp);
    return;
  }

  // Partition.
  auto Pivot = medianOf3(Start, End, Comp);
  // Move Pivot to End.
  std::swap(*(End - 1), *Pivot);
  // Everything strictly less than the pivot value goes before Pivot.
  Pivot = std::partition(Start, End - 1, [&Comp, End](decltype(*Start) V) {
    return Comp(V, *(End - 1));
  });
  // Move Pivot to middle of partition.
  std::swap(*Pivot, *(End - 1));

  // Recurse. Iterators are captured by value; Comp and TG by reference,
  // so both must outlive the spawned task — presumably guaranteed by
  // TaskGroup's destructor joining its tasks (confirm in Parallel.cpp).
  TG.spawn([=, &Comp, &TG] {
    parallel_quick_sort(Start, Pivot, Comp, TG, Depth - 1);
  });
  parallel_quick_sort(Pivot + 1, End, Comp, TG, Depth - 1);
}
147
148template <class RandomAccessIterator, class Comparator>
149void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
150 const Comparator &Comp) {
151 TaskGroup TG;
152 parallel_quick_sort(Start, End, Comp, TG,
153 llvm::Log2_64(std::distance(Start, End)) + 1);
154}
155
// TaskGroup has a relatively high overhead, so we want to reduce
// the number of spawn() calls. We'll create up to 1024 tasks here.
// (Note that 1024 is an arbitrary number. This code probably needs
// improving to take the number of available cores into account.)
// Used by parallel_transform_reduce below to cap the task count.
enum { MaxTasksPerGroup = 1024 };
161
// Maps Transform over [Begin, End) and folds the results with Reduce,
// splitting the input across up to MaxTasksPerGroup concurrent tasks, then
// reducing the per-task partial results sequentially.
template <class IterTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallel_transform_reduce(IterTy Begin, IterTy End, ResultTy Init,
                                   ReduceFuncTy Reduce,
                                   TransformFuncTy Transform) {
  // Limit the number of tasks to MaxTasksPerGroup to limit job scheduling
  // overhead on large inputs.
  size_t NumInputs = std::distance(Begin, End);
  if (NumInputs == 0)
    return std::move(Init);
  size_t NumTasks = std::min(static_cast<size_t>(MaxTasksPerGroup), NumInputs);
  std::vector<ResultTy> Results(NumTasks, Init);
  {
    // Each task processes either TaskSize or TaskSize+1 inputs. Any inputs
    // remaining after dividing them equally amongst tasks are distributed as
    // one extra input over the first tasks.
    TaskGroup TG;
    size_t TaskSize = NumInputs / NumTasks;
    size_t RemainingInputs = NumInputs % NumTasks;
    IterTy TBegin = Begin;
    for (size_t TaskId = 0; TaskId < NumTasks; ++TaskId) {
      IterTy TEnd = TBegin + TaskSize + (TaskId < RemainingInputs ? 1 : 0);
      // Iterators/TaskId are captured by value; Transform, Reduce and
      // Results by reference, which is only safe because TG's destructor
      // at the end of this scope presumably joins all spawned tasks
      // (confirm against TaskGroup in Parallel.cpp).
      TG.spawn([=, &Transform, &Reduce, &Results] {
        // Reduce the result of transformation eagerly within each task.
        // Note: each task seeds its fold with its own copy of Init, so Init
        // is incorporated once per task — it should behave as an identity
        // value for Reduce.
        ResultTy R = Init;
        for (IterTy It = TBegin; It != TEnd; ++It)
          R = Reduce(R, Transform(*It));
        Results[TaskId] = R;
      });
      TBegin = TEnd;
    }
    assert(TBegin == End);
  } // All tasks have completed once TG is destroyed here.

  // Do a final reduction. There are at most 1024 tasks, so this only adds
  // constant single-threaded overhead for large inputs. Hopefully most
  // reductions are cheaper than the transformation.
  ResultTy FinalResult = std::move(Results.front());
  for (ResultTy &PartialResult :
       MutableArrayRef(Results.data() + 1, Results.size() - 1))
    FinalResult = Reduce(FinalResult, std::move(PartialResult));
  return std::move(FinalResult);
}
205
206#endif
207
208} // namespace detail
209} // namespace parallel
210
211template <class RandomAccessIterator,
212 class Comparator = std::less<
213 typename std::iterator_traits<RandomAccessIterator>::value_type>>
214void parallelSort(RandomAccessIterator Start, RandomAccessIterator End,
215 const Comparator &Comp = Comparator()) {
216#if LLVM_ENABLE_THREADS
217 if (parallel::strategy.ThreadsRequested != 1) {
218 parallel::detail::parallel_sort(Start, End, Comp);
219 return;
220 }
221#endif
222 llvm::sort(Start, End, Comp);
223}
224
// Calls Fn(I) for every I in [Begin, End) and returns after all calls have
// completed; the implementation (Parallel.cpp) may run them concurrently.
LLVM_ABI void parallelFor(size_t Begin, size_t End,
                          function_ref<void(size_t)> Fn);
227
// Applies Fn to every element of [Begin, End), potentially in parallel.
// Elements are addressed by index, so random-access iterators are required.
template <class IterTy, class FuncTy>
void parallelForEach(IterTy Begin, IterTy End, FuncTy Fn) {
  const size_t NumItems = End - Begin;
  parallelFor(0, NumItems, [&](size_t Idx) { Fn(*(Begin + Idx)); });
}
232
// Maps Transform over [Begin, End) and folds the results with Reduce,
// starting from Init. Dispatches to the parallel implementation when the
// default strategy allows more than one thread; otherwise folds sequentially.
template <class IterTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
#if LLVM_ENABLE_THREADS
  if (parallel::strategy.ThreadsRequested != 1)
    return parallel::detail::parallel_transform_reduce(Begin, End, Init,
                                                       Reduce, Transform);
#endif
  // Sequential fallback: left fold of the transformed elements.
  for (; Begin != End; ++Begin)
    Init = Reduce(std::move(Init), Transform(*Begin));
  return std::move(Init);
}
248
249// Range wrappers.
// Range overload: sorts the whole of R, potentially in parallel.
template <class RangeTy,
          class Comparator = std::less<decltype(*std::begin(RangeTy()))>>
void parallelSort(RangeTy &&R, const Comparator &Comp = Comparator()) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  parallelSort(First, Last, Comp);
}
255
// Range overload: applies Fn to every element of R, potentially in parallel.
template <class RangeTy, class FuncTy>
void parallelForEach(RangeTy &&R, FuncTy Fn) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  parallelForEach(First, Last, Fn);
}
260
// Range overload: reduces Transform over all of R starting from Init.
template <class RangeTy, class ResultTy, class ReduceFuncTy,
          class TransformFuncTy>
ResultTy parallelTransformReduce(RangeTy &&R, ResultTy Init,
                                 ReduceFuncTy Reduce,
                                 TransformFuncTy Transform) {
  auto First = std::begin(R);
  auto Last = std::end(R);
  return parallelTransformReduce(First, Last, Init, Reduce, Transform);
}
269
270// Parallel for-each, but with error handling.
271template <class RangeTy, class FuncTy>
272Error parallelForEachError(RangeTy &&R, FuncTy Fn) {
273 // The transform_reduce algorithm requires that the initial value be copyable.
274 // Error objects are uncopyable. We only need to copy initial success values,
275 // so work around this mismatch via the C API. The C API represents success
276 // values with a null pointer. The joinErrors discards null values and joins
277 // multiple errors into an ErrorList.
279 std::begin(R), std::end(R), wrap(Error::success()),
280 [](LLVMErrorRef Lhs, LLVMErrorRef Rhs) {
281 return wrap(joinErrors(unwrap(Lhs), unwrap(Rhs)));
282 },
283 [&Fn](auto &&V) { return wrap(Fn(V)); }));
284}
285
286} // namespace llvm
287
288#endif // LLVM_SUPPORT_PARALLEL_H
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Function Alias Analysis Results
#define LLVM_ABI
Definition Compiler.h:213
#define I(x, y, z)
Definition MD5.cpp:57
This file contains some templates that are useful if you are working with the STL at all.
Lightweight error class with error context and mandatory checking.
Definition Error.h:159
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
This tells how a thread pool will be used.
Definition Threading.h:115
LLVM_ABI void spawn(std::function< void()> f)
Definition Parallel.cpp:257
bool isParallel() const
Definition Parallel.h:101
Latch(uint32_t Count=0)
Definition Parallel.h:67
uint32_t getCount() const
Definition Parallel.h:81
struct LLVMOpaqueError * LLVMErrorRef
Opaque reference to an error instance.
Definition Error.h:34
LLVM_ABI ThreadPoolStrategy strategy
Definition Parallel.cpp:27
unsigned getThreadIndex()
Definition Parallel.h:56
size_t getThreadCount()
Definition Parallel.h:57
This is an optimization pass for GlobalISel generic memory operations.
unsigned Log2_64(uint64_t Value)
Return the floor log base 2 of the specified value, -1 if the value is zero.
Definition MathExtras.h:337
void parallelSort(RandomAccessIterator Start, RandomAccessIterator End, const Comparator &Comp=Comparator())
Definition Parallel.h:214
Error joinErrors(Error E1, Error E2)
Concatenate errors.
Definition Error.h:442
void sort(IteratorTy Start, IteratorTy End)
Definition STLExtras.h:1636
MutableArrayRef(T &OneElt) -> MutableArrayRef< T >
Attribute unwrap(LLVMAttributeRef Attr)
Definition Attributes.h:397
LLVMAttributeRef wrap(Attribute Attr)
Definition Attributes.h:392
LLVM_ABI void parallelFor(size_t Begin, size_t End, function_ref< void(size_t)> Fn)
Definition Parallel.cpp:268
ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init, ReduceFuncTy Reduce, TransformFuncTy Transform)
Definition Parallel.h:235
void parallelForEach(IterTy Begin, IterTy End, FuncTy Fn)
Definition Parallel.h:229
Error parallelForEachError(RangeTy &&R, FuncTy Fn)
Definition Parallel.h:272
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872