1 #include <boost/asio/dispatch.hpp>
2 #include <boost/asio/execution_context.hpp>
3 #include <boost/asio/thread_pool.hpp>
4 #include <condition_variable>
10 using boost::asio::dispatch
;
11 using boost::asio::execution_context
;
12 using boost::asio::thread_pool
;
14 // A fixed-size thread pool used to implement fork/join semantics. Functions
15 // are scheduled using a simple FIFO queue. Implementing work stealing, or
16 // using a queue based on atomic operations, are left as tasks for the reader.
17 class fork_join_pool
: public execution_context
// The constructor starts a thread pool with the specified number of threads.
// Note that the thread_count is not a fixed limit on the pool's concurrency.
// Additional threads may temporarily be added to the pool if they join a
// fork_executor.
24 explicit fork_join_pool(
25 std::size_t thread_count
= std::thread::hardware_concurrency() * 2)
27 threads_(thread_count
)
31 // Ask each thread in the pool to dequeue and execute functions until
32 // it is time to shut down, i.e. the use count is zero.
33 for (thread_count_
= 0; thread_count_
< thread_count
; ++thread_count_
)
35 dispatch(threads_
, [&]
37 std::unique_lock
<std::mutex
> lock(mutex_
);
38 while (use_count_
> 0)
39 if (!execute_next(lock
))
40 condition_
.wait(lock
);
52 // The destructor waits for the pool to finish executing functions.
60 friend class fork_executor
;
62 // The base for all functions that are queued in the pool.
65 std::shared_ptr
<std::size_t> work_count_
;
66 void (*execute_
)(std::shared_ptr
<function_base
>& p
);
69 // Execute the next function from the queue, if any. Returns true if a
70 // function was executed, and false if the queue was empty.
71 bool execute_next(std::unique_lock
<std::mutex
>& lock
)
75 auto p(queue_
.front());
82 // Execute a function and decrement the outstanding work.
83 void execute(std::unique_lock
<std::mutex
>& lock
,
84 std::shared_ptr
<function_base
>& p
)
86 std::shared_ptr
<std::size_t> work_count(std::move(p
->work_count_
));
91 do_work_finished(work_count
);
96 do_work_finished(work_count
);
101 // Increment outstanding work.
102 void do_work_started(const std::shared_ptr
<std::size_t>& work_count
) noexcept
104 if (++(*work_count
) == 1)
108 // Decrement outstanding work. Notify waiting threads if we run out.
109 void do_work_finished(const std::shared_ptr
<std::size_t>& work_count
) noexcept
111 if (--(*work_count
) == 0)
114 condition_
.notify_all();
118 // Dispatch a function, executing it immediately if the queue is already
119 // loaded. Otherwise adds the function to the queue and wakes a thread.
120 void do_dispatch(std::shared_ptr
<function_base
> p
,
121 const std::shared_ptr
<std::size_t>& work_count
)
123 std::unique_lock
<std::mutex
> lock(mutex_
);
124 if (queue_
.size() > thread_count_
* 16)
126 do_work_started(work_count
);
133 do_work_started(work_count
);
134 condition_
.notify_one();
138 // Add a function to the queue and wake a thread.
139 void do_post(std::shared_ptr
<function_base
> p
,
140 const std::shared_ptr
<std::size_t>& work_count
)
142 std::lock_guard
<std::mutex
> lock(mutex_
);
144 do_work_started(work_count
);
145 condition_
.notify_one();
148 // Ask all threads to shut down.
151 std::lock_guard
<std::mutex
> lock(mutex_
);
153 condition_
.notify_all();
157 std::condition_variable condition_
;
158 std::queue
<std::shared_ptr
<function_base
>> queue_
;
159 std::size_t use_count_
;
160 std::size_t thread_count_
;
161 thread_pool threads_
;
164 // A class that satisfies the Executor requirements. Every function or piece of
165 // work associated with a fork_executor is part of a single, joinable group.
169 fork_executor(fork_join_pool
& ctx
)
171 work_count_(std::make_shared
<std::size_t>(0))
175 fork_join_pool
& context() const noexcept
180 void on_work_started() const noexcept
182 std::lock_guard
<std::mutex
> lock(context_
.mutex_
);
183 context_
.do_work_started(work_count_
);
186 void on_work_finished() const noexcept
188 std::lock_guard
<std::mutex
> lock(context_
.mutex_
);
189 context_
.do_work_finished(work_count_
);
192 template <class Func
, class Alloc
>
193 void dispatch(Func
&& f
, const Alloc
& a
) const
195 auto p(std::allocate_shared
<function
<Func
>>(
196 typename
std::allocator_traits
<Alloc
>::template rebind_alloc
<char>(a
),
197 std::move(f
), work_count_
));
198 context_
.do_dispatch(p
, work_count_
);
201 template <class Func
, class Alloc
>
202 void post(Func f
, const Alloc
& a
) const
204 auto p(std::allocate_shared
<function
<Func
>>(
205 typename
std::allocator_traits
<Alloc
>::template rebind_alloc
<char>(a
),
206 std::move(f
), work_count_
));
207 context_
.do_post(p
, work_count_
);
210 template <class Func
, class Alloc
>
211 void defer(Func
&& f
, const Alloc
& a
) const
213 post(std::forward
<Func
>(f
), a
);
216 friend bool operator==(const fork_executor
& a
,
217 const fork_executor
& b
) noexcept
219 return a
.work_count_
== b
.work_count_
;
222 friend bool operator!=(const fork_executor
& a
,
223 const fork_executor
& b
) noexcept
225 return a
.work_count_
!= b
.work_count_
;
228 // Block until all work associated with the executor is complete. While it is
229 // waiting, the thread may be borrowed to execute functions from the queue.
232 std::unique_lock
<std::mutex
> lock(context_
.mutex_
);
233 while (*work_count_
> 0)
234 if (!context_
.execute_next(lock
))
235 context_
.condition_
.wait(lock
);
239 template <class Func
>
240 struct function
: fork_join_pool::function_base
242 explicit function(Func f
, const std::shared_ptr
<std::size_t>& w
)
243 : function_(std::move(f
))
246 execute_
= [](std::shared_ptr
<fork_join_pool::function_base
>& p
)
248 Func
tmp(std::move(static_cast<function
*>(p
.get())->function_
));
257 fork_join_pool
& context_
;
258 std::shared_ptr
<std::size_t> work_count_
;
261 // Helper class to automatically join a fork_executor when exiting a scope.
265 explicit join_guard(const fork_executor
& ex
) : ex_(ex
) {}
266 join_guard(const join_guard
&) = delete;
267 join_guard(join_guard
&&) = delete;
268 ~join_guard() { ex_
.join(); }
274 //------------------------------------------------------------------------------
283 template <class Iterator
>
284 void fork_join_sort(Iterator begin
, Iterator end
)
286 std::size_t n
= end
- begin
;
290 fork_executor
fork(pool
);
291 join_guard
join(fork
);
292 dispatch(fork
, [=]{ fork_join_sort(begin
, begin
+ n
/ 2); });
293 dispatch(fork
, [=]{ fork_join_sort(begin
+ n
/ 2, end
); });
295 std::inplace_merge(begin
, begin
+ n
/ 2, end
);
299 std::sort(begin
, end
);
303 int main(int argc
, char* argv
[])
307 std::cerr
<< "Usage: fork_join <size>\n";
311 std::vector
<double> vec(std::atoll(argv
[1]));
312 std::iota(vec
.begin(), vec
.end(), 0);
314 std::random_device rd
;
315 std::mt19937
g(rd());
316 std::shuffle(vec
.begin(), vec
.end(), g
);
318 std::chrono::steady_clock::time_point start
= std::chrono::steady_clock::now();
320 fork_join_sort(vec
.begin(), vec
.end());
322 std::chrono::steady_clock::duration elapsed
= std::chrono::steady_clock::now() - start
;
324 std::cout
<< "sort took ";
325 std::cout
<< std::chrono::duration_cast
<std::chrono::microseconds
>(elapsed
).count();
326 std::cout
<< " microseconds" << std::endl
;