]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include "arrow/compute/exec/task_util.h" | |
19 | ||
20 | #include <algorithm> | |
21 | #include <mutex> | |
22 | ||
23 | #include "arrow/util/logging.h" | |
24 | ||
25 | namespace arrow { | |
26 | namespace compute { | |
27 | ||
// Scheduler implementation that tracks a set of task groups and hands out
// (group id, task id) pairs to worker threads on demand.
//
// Task claiming is mostly lock-free: per-group started/finished counters are
// atomics, while group state transitions and registration flags are guarded
// by mutex_.
class TaskSchedulerImpl : public TaskScheduler {
 public:
  TaskSchedulerImpl();
  // Registers a new task group and returns its id (its index in task_groups_).
  // Must be called before RegisterEnd().
  int RegisterTaskGroup(TaskImpl task_impl, TaskGroupContinuationImpl cont_impl) override;
  // Marks the end of task group registration (asserted in ScheduleMore()).
  void RegisterEnd() override;
  // Makes a registered group's tasks available for execution and kicks off
  // scheduling; total_num_tasks may be zero (group finishes immediately).
  Status StartTaskGroup(size_t thread_id, int group_id, int64_t total_num_tasks) override;
  // Executes up to num_tasks_to_execute picked tasks inline on the calling
  // thread (all remaining tasks if execute_all is true).
  Status ExecuteMore(size_t thread_id, int num_tasks_to_execute,
                     bool execute_all) override;
  // Stores the scheduling callback and concurrency settings, then schedules
  // the initial batch of tasks.
  Status StartScheduling(size_t thread_id, ScheduleImpl schedule_impl,
                         int num_concurrent_tasks, bool use_sync_execution) override;
  // Requests cancellation; impl is invoked once all in-flight tasks finish.
  void Abort(AbortContinuationImpl impl) override;

 private:
  // Task group state transitions progress one way.
  // Seeing an old version of the state by a thread is a valid situation.
  //
  enum class TaskGroupState : int {
    NOT_READY,
    READY,
    ALL_TASKS_STARTED,
    ALL_TASKS_FINISHED
  };

  // Per-group bookkeeping: the task callback, the continuation to run when
  // the group completes, and atomic started/finished counters.
  struct TaskGroup {
    TaskGroup(TaskImpl task_impl, TaskGroupContinuationImpl cont_impl)
        : task_impl_(std::move(task_impl)),
          cont_impl_(std::move(cont_impl)),
          state_(TaskGroupState::NOT_READY),
          num_tasks_present_(0) {
      num_tasks_started_.value.store(0);
      num_tasks_finished_.value.store(0);
    }
    // Copying is only meaningful before the group has started (e.g. during
    // vector growth in RegisterTaskGroup); counters restart at zero.
    TaskGroup(const TaskGroup& src)
        : task_impl_(src.task_impl_),
          cont_impl_(src.cont_impl_),
          state_(TaskGroupState::NOT_READY),
          num_tasks_present_(0) {
      ARROW_DCHECK(src.state_ == TaskGroupState::NOT_READY);
      num_tasks_started_.value.store(0);
      num_tasks_finished_.value.store(0);
    }
    TaskImpl task_impl_;                   // invoked once per (thread, task id)
    TaskGroupContinuationImpl cont_impl_;  // invoked after the last task finishes

    TaskGroupState state_;       // guarded by mutex_
    int64_t num_tasks_present_;  // total tasks; set once in StartTaskGroup (guarded by mutex_)

    // Counters may overshoot num_tasks_present_ when claims race; padded to
    // avoid false sharing between the two.
    AtomicWithPadding<int64_t> num_tasks_started_;
    AtomicWithPadding<int64_t> num_tasks_finished_;
  };

  // Selects up to num_tasks (group, task) pairs, scanning groups round-robin
  // from start_task_group.
  std::vector<std::pair<int, int64_t>> PickTasks(int num_tasks, int start_task_group = 0);
  // Runs one task (unless aborted) and updates the finished counter;
  // *task_group_finished reports whether this was the group's last task.
  Status ExecuteTask(size_t thread_id, int group_id, int64_t task_id,
                     bool* task_group_finished);
  // Bumps the group's finished counter; returns true when the group completes.
  bool PostExecuteTask(size_t thread_id, int group_id);
  // Marks the group finished and runs its continuation (or the abort
  // continuation if aborted and everything is done).
  Status OnTaskGroupFinished(size_t thread_id, int group_id,
                             bool* all_task_groups_finished);
  // Picks and dispatches new tasks through schedule_impl_ (or runs them
  // inline in sync-execution mode).
  Status ScheduleMore(size_t thread_id, int num_tasks_finished = 0);

  bool use_sync_execution_;
  int num_concurrent_tasks_;
  ScheduleImpl schedule_impl_;
  AbortContinuationImpl abort_cont_impl_;

  std::vector<TaskGroup> task_groups_;
  bool aborted_;
  bool register_finished_;
  std::mutex mutex_;  // Mutex protecting task_groups_ (state_ and num_tasks_present_
                      // fields), aborted_ flag and register_finished_ flag

  // Pool of "slots" available for dispatching new tasks; drained via CAS in
  // ScheduleMore and refilled when tasks finish or cannot be picked.
  AtomicWithPadding<int> num_tasks_to_schedule_;
};
100 | ||
101 | TaskSchedulerImpl::TaskSchedulerImpl() | |
102 | : use_sync_execution_(false), | |
103 | num_concurrent_tasks_(0), | |
104 | aborted_(false), | |
105 | register_finished_(false) { | |
106 | num_tasks_to_schedule_.value.store(0); | |
107 | } | |
108 | ||
109 | int TaskSchedulerImpl::RegisterTaskGroup(TaskImpl task_impl, | |
110 | TaskGroupContinuationImpl cont_impl) { | |
111 | int result = static_cast<int>(task_groups_.size()); | |
112 | task_groups_.emplace_back(std::move(task_impl), std::move(cont_impl)); | |
113 | return result; | |
114 | } | |
115 | ||
116 | void TaskSchedulerImpl::RegisterEnd() { | |
117 | std::lock_guard<std::mutex> lock(mutex_); | |
118 | ||
119 | register_finished_ = true; | |
120 | } | |
121 | ||
122 | Status TaskSchedulerImpl::StartTaskGroup(size_t thread_id, int group_id, | |
123 | int64_t total_num_tasks) { | |
124 | ARROW_DCHECK(group_id >= 0 && group_id < static_cast<int>(task_groups_.size())); | |
125 | TaskGroup& task_group = task_groups_[group_id]; | |
126 | ||
127 | bool aborted = false; | |
128 | bool all_tasks_finished = false; | |
129 | { | |
130 | std::lock_guard<std::mutex> lock(mutex_); | |
131 | ||
132 | aborted = aborted_; | |
133 | ||
134 | if (task_group.state_ == TaskGroupState::NOT_READY) { | |
135 | task_group.num_tasks_present_ = total_num_tasks; | |
136 | if (total_num_tasks == 0) { | |
137 | task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED; | |
138 | all_tasks_finished = true; | |
139 | } | |
140 | task_group.state_ = TaskGroupState::READY; | |
141 | } | |
142 | } | |
143 | ||
144 | if (!aborted && all_tasks_finished) { | |
145 | bool all_task_groups_finished = false; | |
146 | RETURN_NOT_OK(OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished)); | |
147 | if (all_task_groups_finished) { | |
148 | return Status::OK(); | |
149 | } | |
150 | } | |
151 | ||
152 | if (!aborted) { | |
153 | return ScheduleMore(thread_id); | |
154 | } else { | |
155 | return Status::Cancelled("Scheduler cancelled"); | |
156 | } | |
157 | } | |
158 | ||
// Selects up to num_tasks (group id, task id) pairs to run, scanning the
// task groups round-robin starting at start_task_group.
//
// Claiming is lock-free: a contiguous range of task ids is reserved with a
// fetch_add on num_tasks_started_. The counter is allowed to overshoot
// num_tasks_present_; an overshoot simply means the group is exhausted and
// the claim is discarded or clamped below.
std::vector<std::pair<int, int64_t>> TaskSchedulerImpl::PickTasks(int num_tasks,
                                                                  int start_task_group) {
  std::vector<std::pair<int, int64_t>> result;
  for (size_t i = 0; i < task_groups_.size(); ++i) {
    // Round-robin over groups so repeated calls don't always drain group 0 first.
    int task_group_id = static_cast<int>((start_task_group + i) % (task_groups_.size()));
    TaskGroup& task_group = task_groups_[task_group_id];

    {
      // Only READY groups still have unstarted tasks to hand out.
      std::lock_guard<std::mutex> lock(mutex_);
      if (task_group.state_ != TaskGroupState::READY) {
        continue;
      }
    }

    int num_tasks_remaining = num_tasks - static_cast<int>(result.size());
    // Reserve [start_task, start_task + num_tasks_remaining) in one atomic step.
    int64_t start_task =
        task_group.num_tasks_started_.value.fetch_add(num_tasks_remaining);
    if (start_task >= task_group.num_tasks_present_) {
      // Another thread already claimed past the end of this group.
      continue;
    }

    int num_tasks_current_group = num_tasks_remaining;
    if (start_task + num_tasks_current_group >= task_group.num_tasks_present_) {
      {
        // This claim reaches the end of the group: advance the state
        // (re-check under the lock in case another thread beat us to it).
        std::lock_guard<std::mutex> lock(mutex_);
        if (task_group.state_ == TaskGroupState::READY) {
          task_group.state_ = TaskGroupState::ALL_TASKS_STARTED;
        }
      }
      // Clamp the reservation to task ids that actually exist.
      num_tasks_current_group =
          static_cast<int>(task_group.num_tasks_present_ - start_task);
    }

    for (int64_t task_id = start_task; task_id < start_task + num_tasks_current_group;
         ++task_id) {
      result.push_back(std::make_pair(task_group_id, task_id));
    }

    if (static_cast<int>(result.size()) == num_tasks) {
      break;
    }
  }

  return result;
}
204 | ||
205 | Status TaskSchedulerImpl::ExecuteTask(size_t thread_id, int group_id, int64_t task_id, | |
206 | bool* task_group_finished) { | |
207 | if (!aborted_) { | |
208 | RETURN_NOT_OK(task_groups_[group_id].task_impl_(thread_id, task_id)); | |
209 | } | |
210 | *task_group_finished = PostExecuteTask(thread_id, group_id); | |
211 | return Status::OK(); | |
212 | } | |
213 | ||
214 | bool TaskSchedulerImpl::PostExecuteTask(size_t thread_id, int group_id) { | |
215 | int64_t total = task_groups_[group_id].num_tasks_present_; | |
216 | int64_t prev_finished = task_groups_[group_id].num_tasks_finished_.value.fetch_add(1); | |
217 | bool all_tasks_finished = (prev_finished + 1 == total); | |
218 | return all_tasks_finished; | |
219 | } | |
220 | ||
221 | Status TaskSchedulerImpl::OnTaskGroupFinished(size_t thread_id, int group_id, | |
222 | bool* all_task_groups_finished) { | |
223 | bool aborted = false; | |
224 | { | |
225 | std::lock_guard<std::mutex> lock(mutex_); | |
226 | ||
227 | aborted = aborted_; | |
228 | TaskGroup& task_group = task_groups_[group_id]; | |
229 | task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED; | |
230 | *all_task_groups_finished = true; | |
231 | for (size_t i = 0; i < task_groups_.size(); ++i) { | |
232 | if (task_groups_[i].state_ != TaskGroupState::ALL_TASKS_FINISHED) { | |
233 | *all_task_groups_finished = false; | |
234 | break; | |
235 | } | |
236 | } | |
237 | } | |
238 | ||
239 | if (aborted && *all_task_groups_finished) { | |
240 | abort_cont_impl_(); | |
241 | return Status::Cancelled("Scheduler cancelled"); | |
242 | } | |
243 | if (!aborted) { | |
244 | RETURN_NOT_OK(task_groups_[group_id].cont_impl_(thread_id)); | |
245 | } | |
246 | return Status::OK(); | |
247 | } | |
248 | ||
249 | Status TaskSchedulerImpl::ExecuteMore(size_t thread_id, int num_tasks_to_execute, | |
250 | bool execute_all) { | |
251 | num_tasks_to_execute = std::max(1, num_tasks_to_execute); | |
252 | ||
253 | int last_id = 0; | |
254 | for (;;) { | |
255 | if (aborted_) { | |
256 | return Status::Cancelled("Scheduler cancelled"); | |
257 | } | |
258 | ||
259 | // Pick next bundle of tasks | |
260 | const auto& tasks = PickTasks(num_tasks_to_execute, last_id); | |
261 | if (tasks.empty()) { | |
262 | break; | |
263 | } | |
264 | last_id = tasks.back().first; | |
265 | ||
266 | // Execute picked tasks immediately | |
267 | for (size_t i = 0; i < tasks.size(); ++i) { | |
268 | int group_id = tasks[i].first; | |
269 | int64_t task_id = tasks[i].second; | |
270 | bool task_group_finished = false; | |
271 | Status status = ExecuteTask(thread_id, group_id, task_id, &task_group_finished); | |
272 | if (!status.ok()) { | |
273 | // Mark the remaining picked tasks as finished | |
274 | for (size_t j = i + 1; j < tasks.size(); ++j) { | |
275 | if (PostExecuteTask(thread_id, tasks[j].first)) { | |
276 | bool all_task_groups_finished = false; | |
277 | RETURN_NOT_OK( | |
278 | OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished)); | |
279 | if (all_task_groups_finished) { | |
280 | return Status::OK(); | |
281 | } | |
282 | } | |
283 | } | |
284 | return status; | |
285 | } else { | |
286 | if (task_group_finished) { | |
287 | bool all_task_groups_finished = false; | |
288 | RETURN_NOT_OK( | |
289 | OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished)); | |
290 | if (all_task_groups_finished) { | |
291 | return Status::OK(); | |
292 | } | |
293 | } | |
294 | } | |
295 | } | |
296 | ||
297 | if (!execute_all) { | |
298 | num_tasks_to_execute -= static_cast<int>(tasks.size()); | |
299 | if (num_tasks_to_execute == 0) { | |
300 | break; | |
301 | } | |
302 | } | |
303 | } | |
304 | ||
305 | return Status::OK(); | |
306 | } | |
307 | ||
308 | Status TaskSchedulerImpl::StartScheduling(size_t thread_id, ScheduleImpl schedule_impl, | |
309 | int num_concurrent_tasks, | |
310 | bool use_sync_execution) { | |
311 | schedule_impl_ = std::move(schedule_impl); | |
312 | use_sync_execution_ = use_sync_execution; | |
313 | num_concurrent_tasks_ = num_concurrent_tasks; | |
314 | num_tasks_to_schedule_.value += num_concurrent_tasks; | |
315 | return ScheduleMore(thread_id); | |
316 | } | |
317 | ||
// Dispatches new tasks to keep the desired concurrency level.
// num_tasks_finished credits back slots freed by tasks that just completed.
//
// In synchronous mode everything runs inline on the calling thread instead.
Status TaskSchedulerImpl::ScheduleMore(size_t thread_id, int num_tasks_finished) {
  if (aborted_) {
    return Status::Cancelled("Scheduler cancelled");
  }

  ARROW_DCHECK(register_finished_);

  if (use_sync_execution_) {
    // Synchronous mode: execute everything inline, no callbacks scheduled.
    return ExecuteMore(thread_id, 1, true);
  }

  // Atomically drain the slot pool: CAS it to zero and claim whatever was in
  // it (compare_exchange_strong refreshes `expected` on failure, so the loop
  // retries with the freshly observed value).
  int num_new_tasks = num_tasks_finished;
  for (;;) {
    int expected = num_tasks_to_schedule_.value.load();
    if (num_tasks_to_schedule_.value.compare_exchange_strong(expected, 0)) {
      num_new_tasks += expected;
      break;
    }
  }
  if (num_new_tasks == 0) {
    return Status::OK();
  }

  const auto& tasks = PickTasks(num_new_tasks);
  if (static_cast<int>(tasks.size()) < num_new_tasks) {
    // Fewer tasks were available than slots claimed; return the surplus so a
    // later StartTaskGroup/ScheduleMore can use it.
    num_tasks_to_schedule_.value += num_new_tasks - static_cast<int>(tasks.size());
  }

  for (size_t i = 0; i < tasks.size(); ++i) {
    int group_id = tasks[i].first;
    int64_t task_id = tasks[i].second;
    RETURN_NOT_OK(schedule_impl_([this, group_id, task_id](size_t thread_id) -> Status {
      // Before running, try to schedule a replacement task so concurrency is
      // maintained while this one executes.
      RETURN_NOT_OK(ScheduleMore(thread_id, 1));

      bool task_group_finished = false;
      RETURN_NOT_OK(ExecuteTask(thread_id, group_id, task_id, &task_group_finished));

      if (task_group_finished) {
        bool all_task_groups_finished = false;
        return OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished);
      }

      return Status::OK();
    }));
  }

  return Status::OK();
}
366 | ||
// Requests cancellation of all task groups.
//
// Unstarted tasks are claimed and counted as finished so no thread will pick
// them up; tasks already running are left to drain. The abort continuation
// runs immediately if nothing is in flight, otherwise the last finishing
// task triggers it through OnTaskGroupFinished().
void TaskSchedulerImpl::Abort(AbortContinuationImpl impl) {
  bool all_finished = true;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    aborted_ = true;
    abort_cont_impl_ = std::move(impl);
    if (register_finished_) {
      for (size_t i = 0; i < task_groups_.size(); ++i) {
        TaskGroup& task_group = task_groups_[i];
        if (task_group.state_ == TaskGroupState::NOT_READY) {
          // Never started: nothing can be in flight, mark it terminal.
          task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
        } else if (task_group.state_ == TaskGroupState::READY) {
          // Claim every remaining task id by forcing the started counter to
          // the total. compare_exchange_strong refreshes `expected` on each
          // failed attempt, so after the loop `expected` holds the number of
          // tasks that had actually been started.
          int64_t expected = task_group.num_tasks_started_.value.load();
          for (;;) {
            if (task_group.num_tasks_started_.value.compare_exchange_strong(
                    expected, task_group.num_tasks_present_)) {
              break;
            }
          }
          // Count the never-started tasks as finished in one step.
          int64_t before_add = task_group.num_tasks_finished_.value.fetch_add(
              task_group.num_tasks_present_ - expected);
          if (before_add >= expected) {
            // All started tasks had already finished — group is done.
            task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
          } else {
            // Some started tasks are still running; their PostExecuteTask
            // will complete the group and trigger the abort continuation.
            all_finished = false;
            task_group.state_ = TaskGroupState::ALL_TASKS_STARTED;
          }
        }
      }
    }
  }
  if (all_finished) {
    abort_cont_impl_();
  }
}
402 | ||
403 | std::unique_ptr<TaskScheduler> TaskScheduler::Make() { | |
404 | std::unique_ptr<TaskSchedulerImpl> impl{new TaskSchedulerImpl()}; | |
405 | return std::move(impl); | |
406 | } | |
407 | ||
408 | } // namespace compute | |
409 | } // namespace arrow |