// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <random>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

#include "arrow/status.h"
#include "arrow/testing/future_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/task_group.h"
#include "arrow/util/thread_pool.h"
// Generate `nsleeps` random sleep durations (in seconds), uniformly drawn
// from [min_seconds, max_seconds].  The engine is default-seeded, so the
// sequence is deterministic from one run to the next.
static std::vector<double> RandomSleepDurations(int nsleeps, double min_seconds,
                                                double max_seconds) {
  std::vector<double> sleeps;
  std::default_random_engine engine;
  std::uniform_real_distribution<> sleep_dist(min_seconds, max_seconds);
  for (int i = 0; i < nsleeps; ++i) {
    sleeps.push_back(sleep_dist(engine));
  }
  return sleeps;
}
52 // Check TaskGroup behaviour with a bunch of all-successful tasks
53 void TestTaskGroupSuccess(std::shared_ptr
<TaskGroup
> task_group
) {
54 const int NTASKS
= 10;
55 auto sleeps
= RandomSleepDurations(NTASKS
, 1e-3, 4e-3);
58 std::atomic
<int> count(0);
59 for (int i
= 0; i
< NTASKS
; ++i
) {
60 task_group
->Append([&, i
]() {
66 ASSERT_TRUE(task_group
->ok());
68 ASSERT_OK(task_group
->Finish());
69 ASSERT_TRUE(task_group
->ok());
70 ASSERT_EQ(count
.load(), NTASKS
* (NTASKS
- 1) / 2);
71 // Finish() is idempotent
72 ASSERT_OK(task_group
->Finish());
75 // Check TaskGroup behaviour with some successful and some failing tasks
76 void TestTaskGroupErrors(std::shared_ptr
<TaskGroup
> task_group
) {
77 const int NSUCCESSES
= 2;
78 const int NERRORS
= 20;
80 std::atomic
<int> count(0);
82 auto task_group_was_ok
= false;
83 task_group
->Append([&]() -> Status
{
84 for (int i
= 0; i
< NSUCCESSES
; ++i
) {
85 task_group
->Append([&]() {
90 task_group_was_ok
= task_group
->ok();
91 for (int i
= 0; i
< NERRORS
; ++i
) {
92 task_group
->Append([&]() {
95 return Status::Invalid("some message");
102 // Task error is propagated
103 ASSERT_RAISES(Invalid
, task_group
->Finish());
104 ASSERT_TRUE(task_group_was_ok
);
105 ASSERT_FALSE(task_group
->ok());
106 if (task_group
->parallelism() == 1) {
107 // Serial: exactly two successes and an error
108 ASSERT_EQ(count
.load(), 3);
110 // Parallel: at least two successes and an error
111 ASSERT_GE(count
.load(), 3);
112 ASSERT_LE(count
.load(), 2 * task_group
->parallelism());
114 // Finish() is idempotent
115 ASSERT_RAISES(Invalid
, task_group
->Finish());
118 void TestTaskGroupCancel(std::shared_ptr
<TaskGroup
> task_group
, StopSource
* stop_source
) {
119 const int NSUCCESSES
= 2;
120 const int NCANCELS
= 20;
122 std::atomic
<int> count(0);
124 auto task_group_was_ok
= false;
125 task_group
->Append([&]() -> Status
{
126 for (int i
= 0; i
< NSUCCESSES
; ++i
) {
127 task_group
->Append([&]() {
132 task_group_was_ok
= task_group
->ok();
133 for (int i
= 0; i
< NCANCELS
; ++i
) {
134 task_group
->Append([&]() {
136 stop_source
->RequestStop();
145 // Cancellation is propagated
146 ASSERT_RAISES(Cancelled
, task_group
->Finish());
147 ASSERT_TRUE(task_group_was_ok
);
148 ASSERT_FALSE(task_group
->ok());
149 if (task_group
->parallelism() == 1) {
150 // Serial: exactly three successes
151 ASSERT_EQ(count
.load(), NSUCCESSES
+ 1);
153 // Parallel: at least three successes
154 ASSERT_GE(count
.load(), NSUCCESSES
+ 1);
155 ASSERT_LE(count
.load(), NSUCCESSES
* task_group
->parallelism());
157 // Finish() is idempotent
158 ASSERT_RAISES(Cancelled
, task_group
->Finish());
161 class CopyCountingTask
{
163 explicit CopyCountingTask(std::shared_ptr
<uint8_t> target
)
164 : counter(0), target(std::move(target
)) {}
166 CopyCountingTask(const CopyCountingTask
& other
)
167 : counter(other
.counter
+ 1), target(other
.target
) {}
169 CopyCountingTask
& operator=(const CopyCountingTask
& other
) {
170 counter
= other
.counter
+ 1;
171 target
= other
.target
;
175 CopyCountingTask(CopyCountingTask
&& other
) = default;
176 CopyCountingTask
& operator=(CopyCountingTask
&& other
) = default;
178 Status
operator()() {
185 std::shared_ptr
<uint8_t> target
;
188 // Check TaskGroup behaviour with tasks spawning other tasks
189 void TestTasksSpawnTasks(std::shared_ptr
<TaskGroup
> task_group
) {
192 std::atomic
<int> count(0);
193 // Make a task that recursively spawns itself
194 std::function
<std::function
<Status()>(int)> make_task
= [&](int i
) {
198 // Exercise parallelism by spawning two tasks at once and then sleeping
199 task_group
->Append(make_task(i
- 1));
200 task_group
->Append(make_task(i
- 1));
207 task_group
->Append(make_task(N
));
209 ASSERT_OK(task_group
->Finish());
210 ASSERT_TRUE(task_group
->ok());
211 ASSERT_EQ(count
.load(), (1 << (N
+ 1)) - 1);
214 // A task that keeps recursing until a barrier is set.
215 // Using a lambda for this doesn't play well with Thread Sanitizer.
217 std::atomic
<bool>* barrier_
;
218 std::weak_ptr
<TaskGroup
> weak_group_ptr_
;
219 Status final_status_
;
221 Status
operator()() {
222 if (!barrier_
->load()) {
224 // Note the TaskGroup should be kept alive by the fact this task
225 // is still running...
226 weak_group_ptr_
.lock()->Append(*this);
228 return final_status_
;
232 // Try to replicate subtle lifetime issues when destroying a TaskGroup
233 // where all tasks may not have finished running.
234 void StressTaskGroupLifetime(std::function
<std::shared_ptr
<TaskGroup
>()> factory
) {
235 const int NTASKS
= 100;
236 auto task_group
= factory();
237 auto weak_group_ptr
= std::weak_ptr
<TaskGroup
>(task_group
);
239 std::atomic
<bool> barrier(false);
241 BarrierTask task
{&barrier
, weak_group_ptr
, Status::OK()};
243 for (int i
= 0; i
< NTASKS
; ++i
) {
244 task_group
->Append(task
);
247 // Lose strong reference
252 while (!weak_group_ptr
.expired()) {
257 // Same, but with also a failing task
258 void StressFailingTaskGroupLifetime(std::function
<std::shared_ptr
<TaskGroup
>()> factory
) {
259 const int NTASKS
= 100;
260 auto task_group
= factory();
261 auto weak_group_ptr
= std::weak_ptr
<TaskGroup
>(task_group
);
263 std::atomic
<bool> barrier(false);
265 BarrierTask task
{&barrier
, weak_group_ptr
, Status::OK()};
266 BarrierTask failing_task
{&barrier
, weak_group_ptr
, Status::Invalid("XXX")};
268 for (int i
= 0; i
< NTASKS
; ++i
) {
269 task_group
->Append(task
);
271 task_group
->Append(failing_task
);
273 // Lose strong reference
278 while (!weak_group_ptr
.expired()) {
283 void TestNoCopyTask(std::shared_ptr
<TaskGroup
> task_group
) {
284 auto counter
= std::make_shared
<uint8_t>(0);
285 CopyCountingTask
task(counter
);
286 task_group
->Append(std::move(task
));
287 ASSERT_OK(task_group
->Finish());
288 ASSERT_EQ(0, *counter
);
291 void TestFinishNotSticky(std::function
<std::shared_ptr
<TaskGroup
>()> factory
) {
292 // If a task is added that runs very quickly it might decrement the task counter back
293 // down to 0 and mark the completion future as complete before all tasks are added.
294 // The "finished future" of the task group could get stuck to complete.
296 // Instead the task group should not allow the finished future to be marked complete
297 // until after FinishAsync has been called.
298 const int NTASKS
= 100;
299 for (int i
= 0; i
< NTASKS
; ++i
) {
300 auto task_group
= factory();
301 // Add a task and let it complete
302 task_group
->Append([] { return Status::OK(); });
303 // Wait a little bit, if the task group was going to lock the finish hopefully it
304 // would do so here while we wait
307 // Add a new task that will still be running
308 std::atomic
<bool> ready(false);
310 std::condition_variable cv
;
311 task_group
->Append([&m
, &cv
, &ready
] {
312 std::unique_lock
<std::mutex
> lk(m
);
313 cv
.wait(lk
, [&ready
] { return ready
.load(); });
317 // Ensure task group not finished already
318 auto finished
= task_group
->FinishAsync();
319 ASSERT_FALSE(finished
.is_finished());
321 std::unique_lock
<std::mutex
> lk(m
);
326 ASSERT_FINISHES_OK(finished
);
330 void TestFinishNeverStarted(std::shared_ptr
<TaskGroup
> task_group
) {
331 // If we call FinishAsync we are done adding tasks so if we never added any it should be
333 auto finished
= task_group
->FinishAsync();
334 ASSERT_TRUE(finished
.Wait(1));
337 void TestFinishAlreadyCompleted(std::function
<std::shared_ptr
<TaskGroup
>()> factory
) {
338 // If we call FinishAsync we are done adding tasks so even if no tasks are running we
339 // should still be completed
340 const int NTASKS
= 100;
341 for (int i
= 0; i
< NTASKS
; ++i
) {
342 auto task_group
= factory();
343 // Add a task and let it complete
344 task_group
->Append([] { return Status::OK(); });
345 // Wait a little bit, hopefully enough time for the task to finish on one of these
348 auto finished
= task_group
->FinishAsync();
349 ASSERT_FINISHES_OK(finished
);
353 TEST(SerialTaskGroup
, Success
) { TestTaskGroupSuccess(TaskGroup::MakeSerial()); }
355 TEST(SerialTaskGroup
, Errors
) { TestTaskGroupErrors(TaskGroup::MakeSerial()); }
357 TEST(SerialTaskGroup
, Cancel
) {
358 StopSource stop_source
;
359 TestTaskGroupCancel(TaskGroup::MakeSerial(stop_source
.token()), &stop_source
);
362 TEST(SerialTaskGroup
, TasksSpawnTasks
) { TestTasksSpawnTasks(TaskGroup::MakeSerial()); }
364 TEST(SerialTaskGroup
, NoCopyTask
) { TestNoCopyTask(TaskGroup::MakeSerial()); }
366 TEST(SerialTaskGroup
, FinishNeverStarted
) {
367 TestFinishNeverStarted(TaskGroup::MakeSerial());
370 TEST(SerialTaskGroup
, FinishAlreadyCompleted
) {
371 TestFinishAlreadyCompleted([] { return TaskGroup::MakeSerial(); });
374 TEST(ThreadedTaskGroup
, Success
) {
375 auto task_group
= TaskGroup::MakeThreaded(GetCpuThreadPool());
376 TestTaskGroupSuccess(task_group
);
379 TEST(ThreadedTaskGroup
, Errors
) {
380 // Limit parallelism to ensure some tasks don't get started
381 // after the first failing ones
382 std::shared_ptr
<ThreadPool
> thread_pool
;
383 ASSERT_OK_AND_ASSIGN(thread_pool
, ThreadPool::Make(4));
385 TestTaskGroupErrors(TaskGroup::MakeThreaded(thread_pool
.get()));
388 TEST(ThreadedTaskGroup
, Cancel
) {
389 std::shared_ptr
<ThreadPool
> thread_pool
;
390 ASSERT_OK_AND_ASSIGN(thread_pool
, ThreadPool::Make(4));
392 StopSource stop_source
;
393 TestTaskGroupCancel(TaskGroup::MakeThreaded(thread_pool
.get(), stop_source
.token()),
397 TEST(ThreadedTaskGroup
, TasksSpawnTasks
) {
398 auto task_group
= TaskGroup::MakeThreaded(GetCpuThreadPool());
399 TestTasksSpawnTasks(task_group
);
402 TEST(ThreadedTaskGroup
, NoCopyTask
) {
403 std::shared_ptr
<ThreadPool
> thread_pool
;
404 ASSERT_OK_AND_ASSIGN(thread_pool
, ThreadPool::Make(4));
405 TestNoCopyTask(TaskGroup::MakeThreaded(thread_pool
.get()));
408 TEST(ThreadedTaskGroup
, StressTaskGroupLifetime
) {
409 std::shared_ptr
<ThreadPool
> thread_pool
;
410 ASSERT_OK_AND_ASSIGN(thread_pool
, ThreadPool::Make(16));
412 StressTaskGroupLifetime([&] { return TaskGroup::MakeThreaded(thread_pool
.get()); });
415 TEST(ThreadedTaskGroup
, StressFailingTaskGroupLifetime
) {
416 std::shared_ptr
<ThreadPool
> thread_pool
;
417 ASSERT_OK_AND_ASSIGN(thread_pool
, ThreadPool::Make(16));
419 StressFailingTaskGroupLifetime(
420 [&] { return TaskGroup::MakeThreaded(thread_pool
.get()); });
423 TEST(ThreadedTaskGroup
, FinishNotSticky
) {
424 std::shared_ptr
<ThreadPool
> thread_pool
;
425 ASSERT_OK_AND_ASSIGN(thread_pool
, ThreadPool::Make(16));
427 TestFinishNotSticky([&] { return TaskGroup::MakeThreaded(thread_pool
.get()); });
430 TEST(ThreadedTaskGroup
, FinishNeverStarted
) {
431 std::shared_ptr
<ThreadPool
> thread_pool
;
432 ASSERT_OK_AND_ASSIGN(thread_pool
, ThreadPool::Make(4));
433 TestFinishNeverStarted(TaskGroup::MakeThreaded(thread_pool
.get()));
436 TEST(ThreadedTaskGroup
, FinishAlreadyCompleted
) {
437 std::shared_ptr
<ThreadPool
> thread_pool
;
438 ASSERT_OK_AND_ASSIGN(thread_pool
, ThreadPool::Make(16));
440 TestFinishAlreadyCompleted([&] { return TaskGroup::MakeThreaded(thread_pool
.get()); });
443 } // namespace internal