]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/util/task_group_test.cc
4913fb9294c2f49d25fd27df725bb04b6d5f113f
[ceph.git] / ceph / src / arrow / cpp / src / arrow / util / task_group_test.cc
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include <atomic>
19 #include <chrono>
20 #include <condition_variable>
21 #include <cstdint>
22 #include <functional>
23 #include <memory>
24 #include <random>
25 #include <thread>
26 #include <utility>
27 #include <vector>
28
29 #include <gtest/gtest.h>
30
31 #include "arrow/status.h"
32 #include "arrow/testing/future_util.h"
33 #include "arrow/testing/gtest_util.h"
34 #include "arrow/util/task_group.h"
35 #include "arrow/util/thread_pool.h"
36
37 namespace arrow {
38 namespace internal {
39
// Generate `nsleeps` random sleep durations in seconds, uniformly
// distributed in [min_seconds, max_seconds).
// NOTE: the engine is default-seeded, so the sequence is deterministic
// for a given standard library implementation.
static std::vector<double> RandomSleepDurations(int nsleeps, double min_seconds,
                                                double max_seconds) {
  std::vector<double> sleeps;
  sleeps.reserve(nsleeps);  // size known up front; avoid reallocations
  std::default_random_engine engine;
  std::uniform_real_distribution<> sleep_dist(min_seconds, max_seconds);
  for (int i = 0; i < nsleeps; ++i) {
    sleeps.push_back(sleep_dist(engine));
  }
  return sleeps;
}
51
52 // Check TaskGroup behaviour with a bunch of all-successful tasks
53 void TestTaskGroupSuccess(std::shared_ptr<TaskGroup> task_group) {
54 const int NTASKS = 10;
55 auto sleeps = RandomSleepDurations(NTASKS, 1e-3, 4e-3);
56
57 // Add NTASKS sleeps
58 std::atomic<int> count(0);
59 for (int i = 0; i < NTASKS; ++i) {
60 task_group->Append([&, i]() {
61 SleepFor(sleeps[i]);
62 count += i;
63 return Status::OK();
64 });
65 }
66 ASSERT_TRUE(task_group->ok());
67
68 ASSERT_OK(task_group->Finish());
69 ASSERT_TRUE(task_group->ok());
70 ASSERT_EQ(count.load(), NTASKS * (NTASKS - 1) / 2);
71 // Finish() is idempotent
72 ASSERT_OK(task_group->Finish());
73 }
74
// Check TaskGroup behaviour with some successful and some failing tasks
void TestTaskGroupErrors(std::shared_ptr<TaskGroup> task_group) {
  const int NSUCCESSES = 2;
  const int NERRORS = 20;

  // Number of tasks that actually ran to completion (success or failure)
  std::atomic<int> count(0);

  auto task_group_was_ok = false;
  // Spawn the children from inside a task so that, in the threaded case,
  // the group's error state can flip while tasks are still being appended
  task_group->Append([&]() -> Status {
    for (int i = 0; i < NSUCCESSES; ++i) {
      task_group->Append([&]() {
        count++;
        return Status::OK();
      });
    }
    // Snapshot the group state before any failing task is added
    task_group_was_ok = task_group->ok();
    for (int i = 0; i < NERRORS; ++i) {
      task_group->Append([&]() {
        // Sleep first so that, in parallel mode, the first error can be
        // recorded before all the other failing tasks get a chance to run
        SleepFor(1e-2);
        count++;
        return Status::Invalid("some message");
      });
    }

    return Status::OK();
  });

  // Task error is propagated
  ASSERT_RAISES(Invalid, task_group->Finish());
  ASSERT_TRUE(task_group_was_ok);
  ASSERT_FALSE(task_group->ok());
  if (task_group->parallelism() == 1) {
    // Serial: exactly two successes and an error
    ASSERT_EQ(count.load(), 3);
  } else {
    // Parallel: at least two successes and an error
    ASSERT_GE(count.load(), 3);
    ASSERT_LE(count.load(), 2 * task_group->parallelism());
  }
  // Finish() is idempotent
  ASSERT_RAISES(Invalid, task_group->Finish());
}
117
// Check TaskGroup behaviour when a StopSource is triggered from inside a task:
// the group's Finish() should report Cancelled even though every task
// itself returns OK.
void TestTaskGroupCancel(std::shared_ptr<TaskGroup> task_group, StopSource* stop_source) {
  const int NSUCCESSES = 2;
  const int NCANCELS = 20;

  // Number of tasks that actually ran to completion
  std::atomic<int> count(0);

  auto task_group_was_ok = false;
  // Spawn the children from inside a task so that, in the threaded case,
  // cancellation can race with tasks still being appended
  task_group->Append([&]() -> Status {
    for (int i = 0; i < NSUCCESSES; ++i) {
      task_group->Append([&]() {
        count++;
        return Status::OK();
      });
    }
    // Snapshot the group state before any cancellation is requested
    task_group_was_ok = task_group->ok();
    for (int i = 0; i < NCANCELS; ++i) {
      task_group->Append([&]() {
        SleepFor(1e-2);
        // Request a stop; note the task still returns OK — the Cancelled
        // status must come from the stop token, not from a task status
        stop_source->RequestStop();
        count++;
        return Status::OK();
      });
    }

    return Status::OK();
  });

  // Cancellation is propagated
  ASSERT_RAISES(Cancelled, task_group->Finish());
  ASSERT_TRUE(task_group_was_ok);
  ASSERT_FALSE(task_group->ok());
  if (task_group->parallelism() == 1) {
    // Serial: exactly three successes
    ASSERT_EQ(count.load(), NSUCCESSES + 1);
  } else {
    // Parallel: at least three successes
    ASSERT_GE(count.load(), NSUCCESSES + 1);
    ASSERT_LE(count.load(), NSUCCESSES * task_group->parallelism());
  }
  // Finish() is idempotent
  ASSERT_RAISES(Cancelled, task_group->Finish());
}
160
// A task functor that records how many times it has been copied; invoking it
// writes the copy count into `target`. Used to verify that TaskGroup::Append
// moves (rather than copies) rvalue tasks.
class CopyCountingTask {
 public:
  explicit CopyCountingTask(std::shared_ptr<uint8_t> target)
      : counter(0), target(std::move(target)) {}

  // Copying bumps the counter...
  CopyCountingTask(const CopyCountingTask& other)
      : counter(other.counter + 1), target(other.target) {}

  CopyCountingTask& operator=(const CopyCountingTask& other) {
    counter = other.counter + 1;
    target = other.target;
    return *this;
  }

  // ...while moving carries the counter along unchanged
  CopyCountingTask(CopyCountingTask&& other) = default;
  CopyCountingTask& operator=(CopyCountingTask&& other) = default;

  // Publish the number of copies this instance went through
  Status operator()() {
    *target = counter;
    return Status::OK();
  }

 private:
  uint8_t counter;
  std::shared_ptr<uint8_t> target;
};
187
188 // Check TaskGroup behaviour with tasks spawning other tasks
189 void TestTasksSpawnTasks(std::shared_ptr<TaskGroup> task_group) {
190 const int N = 6;
191
192 std::atomic<int> count(0);
193 // Make a task that recursively spawns itself
194 std::function<std::function<Status()>(int)> make_task = [&](int i) {
195 return [&, i]() {
196 count++;
197 if (i > 0) {
198 // Exercise parallelism by spawning two tasks at once and then sleeping
199 task_group->Append(make_task(i - 1));
200 task_group->Append(make_task(i - 1));
201 SleepFor(1e-3);
202 }
203 return Status::OK();
204 };
205 };
206
207 task_group->Append(make_task(N));
208
209 ASSERT_OK(task_group->Finish());
210 ASSERT_TRUE(task_group->ok());
211 ASSERT_EQ(count.load(), (1 << (N + 1)) - 1);
212 }
213
// A task that keeps recursing until a barrier is set.
// Using a lambda for this doesn't play well with Thread Sanitizer.
struct BarrierTask {
  // Checked on each invocation; once true, the task stops re-appending itself
  std::atomic<bool>* barrier_;
  // Weak reference so the task does not itself keep the TaskGroup alive
  std::weak_ptr<TaskGroup> weak_group_ptr_;
  // Status to return on the final (non-recursing) invocation
  Status final_status_;

  Status operator()() {
    if (!barrier_->load()) {
      SleepFor(1e-5);
      // Note the TaskGroup should be kept alive by the fact this task
      // is still running...
      weak_group_ptr_.lock()->Append(*this);
    }
    return final_status_;
  }
};
231
232 // Try to replicate subtle lifetime issues when destroying a TaskGroup
233 // where all tasks may not have finished running.
234 void StressTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
235 const int NTASKS = 100;
236 auto task_group = factory();
237 auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
238
239 std::atomic<bool> barrier(false);
240
241 BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
242
243 for (int i = 0; i < NTASKS; ++i) {
244 task_group->Append(task);
245 }
246
247 // Lose strong reference
248 barrier.store(true);
249 task_group.reset();
250
251 // Wait for finish
252 while (!weak_group_ptr.expired()) {
253 SleepFor(1e-5);
254 }
255 }
256
257 // Same, but with also a failing task
258 void StressFailingTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
259 const int NTASKS = 100;
260 auto task_group = factory();
261 auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
262
263 std::atomic<bool> barrier(false);
264
265 BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
266 BarrierTask failing_task{&barrier, weak_group_ptr, Status::Invalid("XXX")};
267
268 for (int i = 0; i < NTASKS; ++i) {
269 task_group->Append(task);
270 }
271 task_group->Append(failing_task);
272
273 // Lose strong reference
274 barrier.store(true);
275 task_group.reset();
276
277 // Wait for finish
278 while (!weak_group_ptr.expired()) {
279 SleepFor(1e-5);
280 }
281 }
282
283 void TestNoCopyTask(std::shared_ptr<TaskGroup> task_group) {
284 auto counter = std::make_shared<uint8_t>(0);
285 CopyCountingTask task(counter);
286 task_group->Append(std::move(task));
287 ASSERT_OK(task_group->Finish());
288 ASSERT_EQ(0, *counter);
289 }
290
// Regression check for a "sticky finish" race in the TaskGroup.
void TestFinishNotSticky(std::function<std::shared_ptr<TaskGroup>()> factory) {
  // If a task is added that runs very quickly it might decrement the task counter back
  // down to 0 and mark the completion future as complete before all tasks are added.
  // The "finished future" of the task group could get stuck to complete.
  //
  // Instead the task group should not allow the finished future to be marked complete
  // until after FinishAsync has been called.
  const int NTASKS = 100;
  for (int i = 0; i < NTASKS; ++i) {
    auto task_group = factory();
    // Add a task and let it complete
    task_group->Append([] { return Status::OK(); });
    // Wait a little bit, if the task group was going to lock the finish hopefully it
    // would do so here while we wait
    SleepFor(1e-2);

    // Add a new task that will still be running
    std::atomic<bool> ready(false);
    std::mutex m;
    std::condition_variable cv;
    task_group->Append([&m, &cv, &ready] {
      std::unique_lock<std::mutex> lk(m);
      cv.wait(lk, [&ready] { return ready.load(); });
      return Status::OK();
    });

    // Ensure task group not finished already
    auto finished = task_group->FinishAsync();
    ASSERT_FALSE(finished.is_finished());

    // Release the blocked task; `ready` is set under the lock so the waiting
    // task cannot miss the wakeup
    std::unique_lock<std::mutex> lk(m);
    ready = true;
    lk.unlock();
    cv.notify_one();

    ASSERT_FINISHES_OK(finished);
  }
}
329
330 void TestFinishNeverStarted(std::shared_ptr<TaskGroup> task_group) {
331 // If we call FinishAsync we are done adding tasks so if we never added any it should be
332 // completed
333 auto finished = task_group->FinishAsync();
334 ASSERT_TRUE(finished.Wait(1));
335 }
336
337 void TestFinishAlreadyCompleted(std::function<std::shared_ptr<TaskGroup>()> factory) {
338 // If we call FinishAsync we are done adding tasks so even if no tasks are running we
339 // should still be completed
340 const int NTASKS = 100;
341 for (int i = 0; i < NTASKS; ++i) {
342 auto task_group = factory();
343 // Add a task and let it complete
344 task_group->Append([] { return Status::OK(); });
345 // Wait a little bit, hopefully enough time for the task to finish on one of these
346 // iterations
347 SleepFor(1e-2);
348 auto finished = task_group->FinishAsync();
349 ASSERT_FINISHES_OK(finished);
350 }
351 }
352
353 TEST(SerialTaskGroup, Success) { TestTaskGroupSuccess(TaskGroup::MakeSerial()); }
354
355 TEST(SerialTaskGroup, Errors) { TestTaskGroupErrors(TaskGroup::MakeSerial()); }
356
357 TEST(SerialTaskGroup, Cancel) {
358 StopSource stop_source;
359 TestTaskGroupCancel(TaskGroup::MakeSerial(stop_source.token()), &stop_source);
360 }
361
362 TEST(SerialTaskGroup, TasksSpawnTasks) { TestTasksSpawnTasks(TaskGroup::MakeSerial()); }
363
364 TEST(SerialTaskGroup, NoCopyTask) { TestNoCopyTask(TaskGroup::MakeSerial()); }
365
366 TEST(SerialTaskGroup, FinishNeverStarted) {
367 TestFinishNeverStarted(TaskGroup::MakeSerial());
368 }
369
370 TEST(SerialTaskGroup, FinishAlreadyCompleted) {
371 TestFinishAlreadyCompleted([] { return TaskGroup::MakeSerial(); });
372 }
373
374 TEST(ThreadedTaskGroup, Success) {
375 auto task_group = TaskGroup::MakeThreaded(GetCpuThreadPool());
376 TestTaskGroupSuccess(task_group);
377 }
378
379 TEST(ThreadedTaskGroup, Errors) {
380 // Limit parallelism to ensure some tasks don't get started
381 // after the first failing ones
382 std::shared_ptr<ThreadPool> thread_pool;
383 ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4));
384
385 TestTaskGroupErrors(TaskGroup::MakeThreaded(thread_pool.get()));
386 }
387
388 TEST(ThreadedTaskGroup, Cancel) {
389 std::shared_ptr<ThreadPool> thread_pool;
390 ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4));
391
392 StopSource stop_source;
393 TestTaskGroupCancel(TaskGroup::MakeThreaded(thread_pool.get(), stop_source.token()),
394 &stop_source);
395 }
396
397 TEST(ThreadedTaskGroup, TasksSpawnTasks) {
398 auto task_group = TaskGroup::MakeThreaded(GetCpuThreadPool());
399 TestTasksSpawnTasks(task_group);
400 }
401
402 TEST(ThreadedTaskGroup, NoCopyTask) {
403 std::shared_ptr<ThreadPool> thread_pool;
404 ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4));
405 TestNoCopyTask(TaskGroup::MakeThreaded(thread_pool.get()));
406 }
407
408 TEST(ThreadedTaskGroup, StressTaskGroupLifetime) {
409 std::shared_ptr<ThreadPool> thread_pool;
410 ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16));
411
412 StressTaskGroupLifetime([&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
413 }
414
415 TEST(ThreadedTaskGroup, StressFailingTaskGroupLifetime) {
416 std::shared_ptr<ThreadPool> thread_pool;
417 ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16));
418
419 StressFailingTaskGroupLifetime(
420 [&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
421 }
422
423 TEST(ThreadedTaskGroup, FinishNotSticky) {
424 std::shared_ptr<ThreadPool> thread_pool;
425 ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16));
426
427 TestFinishNotSticky([&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
428 }
429
430 TEST(ThreadedTaskGroup, FinishNeverStarted) {
431 std::shared_ptr<ThreadPool> thread_pool;
432 ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4));
433 TestFinishNeverStarted(TaskGroup::MakeThreaded(thread_pool.get()));
434 }
435
436 TEST(ThreadedTaskGroup, FinishAlreadyCompleted) {
437 std::shared_ptr<ThreadPool> thread_pool;
438 ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16));
439
440 TestFinishAlreadyCompleted([&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
441 }
442
443 } // namespace internal
444 } // namespace arrow