// fork_join.cpp
// Boost.Asio C++11 executors example
// (ceph/src/boost/libs/asio/example/cpp11/executors, v12.2.3 sources import)
#include <boost/asio/dispatch.hpp>
#include <boost/asio/execution_context.hpp>
#include <boost/asio/thread_pool.hpp>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>

using boost::asio::dispatch;
using boost::asio::execution_context;
using boost::asio::thread_pool;

// A fixed-size thread pool used to implement fork/join semantics. Functions
// are scheduled using a simple FIFO queue. Implementing work stealing, or
// using a queue based on atomic operations, is left as an exercise for the
// reader.
class fork_join_pool : public execution_context
{
public:
  // The constructor starts a thread pool with the specified number of threads.
  // Note that the thread_count is not a fixed limit on the pool's concurrency.
  // Additional threads may temporarily be added to the pool if they join a
  // fork_executor.
  explicit fork_join_pool(
      std::size_t thread_count = std::thread::hardware_concurrency() * 2)
    : use_count_(1),
      threads_(thread_count)
  {
    try
    {
      // Ask each thread in the pool to dequeue and execute functions until
      // it is time to shut down, i.e. the use count is zero.
      for (thread_count_ = 0; thread_count_ < thread_count; ++thread_count_)
      {
        dispatch(threads_, [&]
            {
              std::unique_lock<std::mutex> lock(mutex_);
              while (use_count_ > 0)
                if (!execute_next(lock))
                  condition_.wait(lock);
            });
      }
    }
    catch (...)
    {
      stop_threads();
      threads_.join();
      throw;
    }
  }

  // The destructor waits for the pool to finish executing functions.
  ~fork_join_pool()
  {
    stop_threads();
    threads_.join();
  }

private:
  friend class fork_executor;

  // The base for all functions that are queued in the pool.
  struct function_base
  {
    std::shared_ptr<std::size_t> work_count_;
    void (*execute_)(std::shared_ptr<function_base>& p);
  };
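
  // The execute_ function pointer above acts as a single-entry vtable: each
  // fork_executor::function<Func> (defined further below) installs a lambda
  // that knows the concrete callable type, so the pool can store and invoke
  // heterogeneous functions through a plain shared_ptr<function_base>.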

  // Execute the next function from the queue, if any. Returns true if a
  // function was executed, and false if the queue was empty. Must be called
  // with the lock held; the lock is reacquired before this function returns.
  bool execute_next(std::unique_lock<std::mutex>& lock)
  {
    if (queue_.empty())
      return false;
    auto p(queue_.front());
    queue_.pop();
    lock.unlock();
    execute(lock, p);
    return true;
  }

  // Execute a function and decrement the outstanding work. Expects the lock
  // to be unlocked on entry, and reacquires it before returning or rethrowing.
  void execute(std::unique_lock<std::mutex>& lock,
      std::shared_ptr<function_base>& p)
  {
    std::shared_ptr<std::size_t> work_count(std::move(p->work_count_));
    try
    {
      p->execute_(p);
      lock.lock();
      do_work_finished(work_count);
    }
    catch (...)
    {
      lock.lock();
      do_work_finished(work_count);
      throw;
    }
  }

  // Increment outstanding work.
  void do_work_started(const std::shared_ptr<std::size_t>& work_count) noexcept
  {
    if (++(*work_count) == 1)
      ++use_count_;
  }

  // Decrement outstanding work. Notify waiting threads if we run out.
  void do_work_finished(const std::shared_ptr<std::size_t>& work_count) noexcept
  {
    if (--(*work_count) == 0)
    {
      --use_count_;
      condition_.notify_all();
    }
  }
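
  // Note on the counting scheme: each fork_executor owns a shared work count
  // tracking the functions outstanding for that executor, while the pool-wide
  // use_count_ counts the pool itself (the initial 1 taken in the constructor)
  // plus every executor that currently has work in flight. Worker threads keep
  // running for as long as use_count_ > 0.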

  // Dispatch a function, executing it immediately on the calling thread if
  // the queue is already heavily loaded. Otherwise add the function to the
  // queue and wake a worker thread.
  void do_dispatch(std::shared_ptr<function_base> p,
      const std::shared_ptr<std::size_t>& work_count)
  {
    std::unique_lock<std::mutex> lock(mutex_);
    if (queue_.size() > thread_count_ * 16)
    {
      do_work_started(work_count);
      lock.unlock();
      execute(lock, p);
    }
    else
    {
      queue_.push(p);
      do_work_started(work_count);
      condition_.notify_one();
    }
  }
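
  // Design note: the thread_count_ * 16 threshold above is an arbitrary
  // back-pressure limit chosen for this example. Raising it favors queuing
  // (more memory, lower caller latency); lowering it makes callers run work
  // inline sooner when the pool falls behind.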

  // Add a function to the queue and wake a thread.
  void do_post(std::shared_ptr<function_base> p,
      const std::shared_ptr<std::size_t>& work_count)
  {
    std::lock_guard<std::mutex> lock(mutex_);
    queue_.push(p);
    do_work_started(work_count);
    condition_.notify_one();
  }

  // Ask all threads to shut down.
  void stop_threads()
  {
    std::lock_guard<std::mutex> lock(mutex_);
    --use_count_;
    condition_.notify_all();
  }

  std::mutex mutex_;
  std::condition_variable condition_;
  std::queue<std::shared_ptr<function_base>> queue_;
  std::size_t use_count_;
  std::size_t thread_count_;
  thread_pool threads_;
};

// A class that satisfies the Executor requirements. Every function or piece of
// work associated with a fork_executor is part of a single, joinable group.
class fork_executor
{
public:
  fork_executor(fork_join_pool& ctx)
    : context_(ctx),
      work_count_(std::make_shared<std::size_t>(0))
  {
  }

  fork_join_pool& context() const noexcept
  {
    return context_;
  }

  void on_work_started() const noexcept
  {
    std::lock_guard<std::mutex> lock(context_.mutex_);
    context_.do_work_started(work_count_);
  }

  void on_work_finished() const noexcept
  {
    std::lock_guard<std::mutex> lock(context_.mutex_);
    context_.do_work_finished(work_count_);
  }

  template <class Func, class Alloc>
  void dispatch(Func&& f, const Alloc& a) const
  {
    auto p(std::allocate_shared<function<Func>>(
          typename std::allocator_traits<Alloc>::template rebind_alloc<char>(a),
          std::move(f), work_count_));
    context_.do_dispatch(p, work_count_);
  }

  template <class Func, class Alloc>
  void post(Func f, const Alloc& a) const
  {
    auto p(std::allocate_shared<function<Func>>(
          typename std::allocator_traits<Alloc>::template rebind_alloc<char>(a),
          std::move(f), work_count_));
    context_.do_post(p, work_count_);
  }

  template <class Func, class Alloc>
  void defer(Func&& f, const Alloc& a) const
  {
    post(std::forward<Func>(f), a);
  }

  friend bool operator==(const fork_executor& a,
      const fork_executor& b) noexcept
  {
    return a.work_count_ == b.work_count_;
  }

  friend bool operator!=(const fork_executor& a,
      const fork_executor& b) noexcept
  {
    return a.work_count_ != b.work_count_;
  }

  // Block until all work associated with the executor is complete. While it is
  // waiting, the thread may be borrowed to execute functions from the queue.
  void join() const
  {
    std::unique_lock<std::mutex> lock(context_.mutex_);
    while (*work_count_ > 0)
      if (!context_.execute_next(lock))
        context_.condition_.wait(lock);
  }

private:
  template <class Func>
  struct function : fork_join_pool::function_base
  {
    explicit function(Func f, const std::shared_ptr<std::size_t>& w)
      : function_(std::move(f))
    {
      work_count_ = w;
      execute_ = [](std::shared_ptr<fork_join_pool::function_base>& p)
      {
        Func tmp(std::move(static_cast<function*>(p.get())->function_));
        p.reset();
        tmp();
      };
    }

    Func function_;
  };

  fork_join_pool& context_;
  std::shared_ptr<std::size_t> work_count_;
};

// Helper class to automatically join a fork_executor when exiting a scope.
class join_guard
{
public:
  explicit join_guard(const fork_executor& ex) : ex_(ex) {}
  join_guard(const join_guard&) = delete;
  join_guard(join_guard&&) = delete;
  ~join_guard() { ex_.join(); }

private:
  fork_executor ex_;
};
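
// A minimal usage sketch of the fork/join idiom above (illustrative only, not
// exercised by this program): dispatch a group of tasks through a single
// fork_executor and let the join_guard wait for the whole group on scope exit.
//
//   void example(fork_join_pool& pool)
//   {
//     fork_executor fork(pool);
//     join_guard join(fork);
//     dispatch(fork, []{ /* task A */ });
//     dispatch(fork, []{ /* task B */ });
//   } // join_guard blocks here until both A and B have completed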

//------------------------------------------------------------------------------

#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

fork_join_pool pool;

template <class Iterator>
void fork_join_sort(Iterator begin, Iterator end)
{
  std::size_t n = end - begin;
  if (n > 32768)
  {
    {
      fork_executor fork(pool);
      join_guard join(fork);
      dispatch(fork, [=]{ fork_join_sort(begin, begin + n / 2); });
      dispatch(fork, [=]{ fork_join_sort(begin + n / 2, end); });
    }
    std::inplace_merge(begin, begin + n / 2, end);
  }
  else
  {
    std::sort(begin, end);
  }
}
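
// Each level of recursion forks the two halves and joins before merging, so
// this is a parallel merge sort: O(n log n) comparisons, with parallelism
// available down to the 32768-element cutoff. The cutoff is a tuning knob for
// this example; below it, plain std::sort avoids the overhead of scheduling
// tiny tasks.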

int main(int argc, char* argv[])
{
  if (argc != 2)
  {
    std::cerr << "Usage: fork_join <size>\n";
    return 1;
  }

  std::vector<double> vec(std::atoll(argv[1]));
  std::iota(vec.begin(), vec.end(), 0);

  std::random_device rd;
  std::mt19937 g(rd());
  std::shuffle(vec.begin(), vec.end(), g);

  std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();

  fork_join_sort(vec.begin(), vec.end());

  std::chrono::steady_clock::duration elapsed = std::chrono::steady_clock::now() - start;

  std::cout << "sort took ";
  std::cout << std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();
  std::cout << " microseconds" << std::endl;
}
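
// Build and run (one plausible invocation; the exact flags and library names
// are assumptions that depend on your toolchain and Boost installation):
//
//   g++ -std=c++11 -O2 fork_join.cpp -lboost_system -pthread -o fork_join
//   ./fork_join 10000000
//
// The program prints the wall-clock time taken to sort <size> doubles.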