// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/compute/exec/task_util.h"

#include <algorithm>
#include <mutex>

#include "arrow/util/logging.h"

namespace arrow {
namespace compute {

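// Scheduler for a fixed set of task groups registered up front.  Each task
// group pairs a per-task callback (TaskImpl) with a continuation
// (TaskGroupContinuationImpl) that runs once every task in the group has
// finished.  Tasks are claimed with atomic counters and are either handed to
// a caller-provided ScheduleImpl (asynchronous execution) or executed inline
// on the calling thread (synchronous execution).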
class TaskSchedulerImpl : public TaskScheduler {
 public:
  TaskSchedulerImpl();
  int RegisterTaskGroup(TaskImpl task_impl, TaskGroupContinuationImpl cont_impl) override;
  void RegisterEnd() override;
  Status StartTaskGroup(size_t thread_id, int group_id, int64_t total_num_tasks) override;
  Status ExecuteMore(size_t thread_id, int num_tasks_to_execute,
                     bool execute_all) override;
  Status StartScheduling(size_t thread_id, ScheduleImpl schedule_impl,
                         int num_concurrent_tasks, bool use_sync_execution) override;
  void Abort(AbortContinuationImpl impl) override;

 private:
  // Task group state transitions progress one way.
  // It is valid for a thread to observe a stale (older) state.
  //
  enum class TaskGroupState : int {
    NOT_READY,
    READY,
    ALL_TASKS_STARTED,
    ALL_TASKS_FINISHED
  };

  struct TaskGroup {
    TaskGroup(TaskImpl task_impl, TaskGroupContinuationImpl cont_impl)
        : task_impl_(std::move(task_impl)),
          cont_impl_(std::move(cont_impl)),
          state_(TaskGroupState::NOT_READY),
          num_tasks_present_(0) {
      num_tasks_started_.value.store(0);
      num_tasks_finished_.value.store(0);
    }
    TaskGroup(const TaskGroup& src)
        : task_impl_(src.task_impl_),
          cont_impl_(src.cont_impl_),
          state_(TaskGroupState::NOT_READY),
          num_tasks_present_(0) {
      ARROW_DCHECK(src.state_ == TaskGroupState::NOT_READY);
      num_tasks_started_.value.store(0);
      num_tasks_finished_.value.store(0);
    }
    TaskImpl task_impl_;
    TaskGroupContinuationImpl cont_impl_;

    TaskGroupState state_;
    int64_t num_tasks_present_;

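    // Hot counters updated concurrently by worker threads; the padding in
    // AtomicWithPadding keeps each counter on its own cache line to reduce
    // false sharing.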
    AtomicWithPadding<int64_t> num_tasks_started_;
    AtomicWithPadding<int64_t> num_tasks_finished_;
  };

  std::vector<std::pair<int, int64_t>> PickTasks(int num_tasks, int start_task_group = 0);
  Status ExecuteTask(size_t thread_id, int group_id, int64_t task_id,
                     bool* task_group_finished);
  bool PostExecuteTask(size_t thread_id, int group_id);
  Status OnTaskGroupFinished(size_t thread_id, int group_id,
                             bool* all_task_groups_finished);
  Status ScheduleMore(size_t thread_id, int num_tasks_finished = 0);

  bool use_sync_execution_;
  int num_concurrent_tasks_;
  ScheduleImpl schedule_impl_;
  AbortContinuationImpl abort_cont_impl_;

  std::vector<TaskGroup> task_groups_;
  bool aborted_;
  bool register_finished_;
  std::mutex mutex_;  // Mutex protecting task_groups_ (state_ and num_tasks_present_
                      // fields), aborted_ flag and register_finished_ flag

  AtomicWithPadding<int> num_tasks_to_schedule_;
};

TaskSchedulerImpl::TaskSchedulerImpl()
    : use_sync_execution_(false),
      num_concurrent_tasks_(0),
      aborted_(false),
      register_finished_(false) {
  num_tasks_to_schedule_.value.store(0);
}

int TaskSchedulerImpl::RegisterTaskGroup(TaskImpl task_impl,
                                         TaskGroupContinuationImpl cont_impl) {
  int result = static_cast<int>(task_groups_.size());
  task_groups_.emplace_back(std::move(task_impl), std::move(cont_impl));
  return result;
}

void TaskSchedulerImpl::RegisterEnd() {
  std::lock_guard<std::mutex> lock(mutex_);

  register_finished_ = true;
}

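// Publishes the number of tasks in a group and moves it from NOT_READY to
// READY (an empty group counts as finished right away), then asks the
// scheduler to dispatch more work unless an abort is already in progress.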
Status TaskSchedulerImpl::StartTaskGroup(size_t thread_id, int group_id,
                                         int64_t total_num_tasks) {
  ARROW_DCHECK(group_id >= 0 && group_id < static_cast<int>(task_groups_.size()));
  TaskGroup& task_group = task_groups_[group_id];

  bool aborted = false;
  bool all_tasks_finished = false;
  {
    std::lock_guard<std::mutex> lock(mutex_);

    aborted = aborted_;

    if (task_group.state_ == TaskGroupState::NOT_READY) {
      task_group.num_tasks_present_ = total_num_tasks;
      if (total_num_tasks == 0) {
        task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
        all_tasks_finished = true;
      }
      task_group.state_ = TaskGroupState::READY;
    }
  }

  if (!aborted && all_tasks_finished) {
    bool all_task_groups_finished = false;
    RETURN_NOT_OK(OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished));
    if (all_task_groups_finished) {
      return Status::OK();
    }
  }

  if (!aborted) {
    return ScheduleMore(thread_id);
  } else {
    return Status::Cancelled("Scheduler cancelled");
  }
}

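// Collects up to num_tasks (group id, task id) pairs, scanning the task
// groups round-robin starting at start_task_group.  Tasks are claimed with a
// fetch_add on num_tasks_started_, so a claim may overshoot the end of a
// group; the overshoot is harmless because claims past num_tasks_present_
// yield no tasks.  The group that hands out its last task is moved to
// ALL_TASKS_STARTED here.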
std::vector<std::pair<int, int64_t>> TaskSchedulerImpl::PickTasks(int num_tasks,
                                                                  int start_task_group) {
  std::vector<std::pair<int, int64_t>> result;
  for (size_t i = 0; i < task_groups_.size(); ++i) {
    int task_group_id = static_cast<int>((start_task_group + i) % (task_groups_.size()));
    TaskGroup& task_group = task_groups_[task_group_id];

    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (task_group.state_ != TaskGroupState::READY) {
        continue;
      }
    }

    int num_tasks_remaining = num_tasks - static_cast<int>(result.size());
    int64_t start_task =
        task_group.num_tasks_started_.value.fetch_add(num_tasks_remaining);
    if (start_task >= task_group.num_tasks_present_) {
      continue;
    }

    int num_tasks_current_group = num_tasks_remaining;
    if (start_task + num_tasks_current_group >= task_group.num_tasks_present_) {
      {
        std::lock_guard<std::mutex> lock(mutex_);
        if (task_group.state_ == TaskGroupState::READY) {
          task_group.state_ = TaskGroupState::ALL_TASKS_STARTED;
        }
      }
      num_tasks_current_group =
          static_cast<int>(task_group.num_tasks_present_ - start_task);
    }

    for (int64_t task_id = start_task; task_id < start_task + num_tasks_current_group;
         ++task_id) {
      result.push_back(std::make_pair(task_group_id, task_id));
    }

    if (static_cast<int>(result.size()) == num_tasks) {
      break;
    }
  }

  return result;
}

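// Runs a single task (skipped when the scheduler has been aborted) and
// reports through task_group_finished whether this call finished the last
// task of its group.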
Status TaskSchedulerImpl::ExecuteTask(size_t thread_id, int group_id, int64_t task_id,
                                      bool* task_group_finished) {
  if (!aborted_) {
    RETURN_NOT_OK(task_groups_[group_id].task_impl_(thread_id, task_id));
  }
  *task_group_finished = PostExecuteTask(thread_id, group_id);
  return Status::OK();
}

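// Bumps the group's finished-task counter; returns true exactly once per
// group, for the call that accounts for the group's last task.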
bool TaskSchedulerImpl::PostExecuteTask(size_t thread_id, int group_id) {
  int64_t total = task_groups_[group_id].num_tasks_present_;
  int64_t prev_finished = task_groups_[group_id].num_tasks_finished_.value.fetch_add(1);
  bool all_tasks_finished = (prev_finished + 1 == total);
  return all_tasks_finished;
}

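// Marks the group as finished and checks, under the lock, whether every
// registered group is now finished.  During an abort the abort continuation
// fires once the last group drains; otherwise the group's own continuation
// runs.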
Status TaskSchedulerImpl::OnTaskGroupFinished(size_t thread_id, int group_id,
                                              bool* all_task_groups_finished) {
  bool aborted = false;
  {
    std::lock_guard<std::mutex> lock(mutex_);

    aborted = aborted_;
    TaskGroup& task_group = task_groups_[group_id];
    task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
    *all_task_groups_finished = true;
    for (size_t i = 0; i < task_groups_.size(); ++i) {
      if (task_groups_[i].state_ != TaskGroupState::ALL_TASKS_FINISHED) {
        *all_task_groups_finished = false;
        break;
      }
    }
  }

  if (aborted && *all_task_groups_finished) {
    abort_cont_impl_();
    return Status::Cancelled("Scheduler cancelled");
  }
  if (!aborted) {
    RETURN_NOT_OK(task_groups_[group_id].cont_impl_(thread_id));
  }
  return Status::OK();
}

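// Synchronous execution path: repeatedly picks batches of tasks and runs
// them on the calling thread, either until num_tasks_to_execute tasks have
// run or, with execute_all set, until no tasks are left.  If a task fails,
// the remaining already-picked tasks are still counted as finished so the
// group bookkeeping stays balanced before the error is returned.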
Status TaskSchedulerImpl::ExecuteMore(size_t thread_id, int num_tasks_to_execute,
                                      bool execute_all) {
  num_tasks_to_execute = std::max(1, num_tasks_to_execute);

  int last_id = 0;
  for (;;) {
    if (aborted_) {
      return Status::Cancelled("Scheduler cancelled");
    }

    // Pick next bundle of tasks
    const auto& tasks = PickTasks(num_tasks_to_execute, last_id);
    if (tasks.empty()) {
      break;
    }
    last_id = tasks.back().first;

    // Execute picked tasks immediately
    for (size_t i = 0; i < tasks.size(); ++i) {
      int group_id = tasks[i].first;
      int64_t task_id = tasks[i].second;
      bool task_group_finished = false;
      Status status = ExecuteTask(thread_id, group_id, task_id, &task_group_finished);
      if (!status.ok()) {
        // Mark the remaining picked tasks as finished
        for (size_t j = i + 1; j < tasks.size(); ++j) {
          if (PostExecuteTask(thread_id, tasks[j].first)) {
            bool all_task_groups_finished = false;
            RETURN_NOT_OK(
                OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished));
            if (all_task_groups_finished) {
              return Status::OK();
            }
          }
        }
        return status;
      } else {
        if (task_group_finished) {
          bool all_task_groups_finished = false;
          RETURN_NOT_OK(
              OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished));
          if (all_task_groups_finished) {
            return Status::OK();
          }
        }
      }
    }

    if (!execute_all) {
      num_tasks_to_execute -= static_cast<int>(tasks.size());
      if (num_tasks_to_execute == 0) {
        break;
      }
    }
  }

  return Status::OK();
}

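// Records the scheduling callback and concurrency settings, credits
// num_concurrent_tasks to the scheduling budget and kicks off the first
// round of scheduling.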
Status TaskSchedulerImpl::StartScheduling(size_t thread_id, ScheduleImpl schedule_impl,
                                          int num_concurrent_tasks,
                                          bool use_sync_execution) {
  schedule_impl_ = std::move(schedule_impl);
  use_sync_execution_ = use_sync_execution;
  num_concurrent_tasks_ = num_concurrent_tasks;
  num_tasks_to_schedule_.value += num_concurrent_tasks;
  return ScheduleMore(thread_id);
}

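// In synchronous mode this simply executes everything inline.  Otherwise it
// atomically drains the pending-schedule budget (plus the tasks that just
// finished), picks that many tasks and submits each one through
// schedule_impl_; any budget that cannot be filled is returned to the
// counter.  Each submitted task, when it runs, first calls ScheduleMore
// again to replace itself, keeping roughly num_concurrent_tasks_ tasks in
// flight.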
Status TaskSchedulerImpl::ScheduleMore(size_t thread_id, int num_tasks_finished) {
  if (aborted_) {
    return Status::Cancelled("Scheduler cancelled");
  }

  ARROW_DCHECK(register_finished_);

  if (use_sync_execution_) {
    return ExecuteMore(thread_id, 1, true);
  }

  int num_new_tasks = num_tasks_finished;
  for (;;) {
    int expected = num_tasks_to_schedule_.value.load();
    if (num_tasks_to_schedule_.value.compare_exchange_strong(expected, 0)) {
      num_new_tasks += expected;
      break;
    }
  }
  if (num_new_tasks == 0) {
    return Status::OK();
  }

  const auto& tasks = PickTasks(num_new_tasks);
  if (static_cast<int>(tasks.size()) < num_new_tasks) {
    num_tasks_to_schedule_.value += num_new_tasks - static_cast<int>(tasks.size());
  }

  for (size_t i = 0; i < tasks.size(); ++i) {
    int group_id = tasks[i].first;
    int64_t task_id = tasks[i].second;
    RETURN_NOT_OK(schedule_impl_([this, group_id, task_id](size_t thread_id) -> Status {
      RETURN_NOT_OK(ScheduleMore(thread_id, 1));

      bool task_group_finished = false;
      RETURN_NOT_OK(ExecuteTask(thread_id, group_id, task_id, &task_group_finished));

      if (task_group_finished) {
        bool all_task_groups_finished = false;
        return OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished);
      }

      return Status::OK();
    }));
  }

  return Status::OK();
}

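// Records the abort and the continuation to invoke once everything in
// flight has drained.  Groups that never started are marked finished; for
// READY groups the remaining unstarted tasks are claimed and counted as
// finished, so only tasks already executing stay outstanding.  If nothing is
// outstanding the continuation runs immediately, otherwise the last
// finishing task triggers it via OnTaskGroupFinished.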
void TaskSchedulerImpl::Abort(AbortContinuationImpl impl) {
  bool all_finished = true;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    aborted_ = true;
    abort_cont_impl_ = std::move(impl);
    if (register_finished_) {
      for (size_t i = 0; i < task_groups_.size(); ++i) {
        TaskGroup& task_group = task_groups_[i];
        if (task_group.state_ == TaskGroupState::NOT_READY) {
          task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
        } else if (task_group.state_ == TaskGroupState::READY) {
          int64_t expected = task_group.num_tasks_started_.value.load();
          for (;;) {
            if (task_group.num_tasks_started_.value.compare_exchange_strong(
                    expected, task_group.num_tasks_present_)) {
              break;
            }
          }
          int64_t before_add = task_group.num_tasks_finished_.value.fetch_add(
              task_group.num_tasks_present_ - expected);
          if (before_add >= expected) {
            task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
          } else {
            all_finished = false;
            task_group.state_ = TaskGroupState::ALL_TASKS_STARTED;
          }
        }
      }
    }
  }
  if (all_finished) {
    abort_cont_impl_();
  }
}

std::unique_ptr<TaskScheduler> TaskScheduler::Make() {
  std::unique_ptr<TaskSchedulerImpl> impl{new TaskSchedulerImpl()};
  return std::move(impl);
}

}  // namespace compute
}  // namespace arrow