]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/util/task_group.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / util / task_group.cc
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/util/task_group.h"
19
20 #include <atomic>
21 #include <condition_variable>
22 #include <cstdint>
23 #include <mutex>
24 #include <utility>
25
26 #include "arrow/util/checked_cast.h"
27 #include "arrow/util/logging.h"
28 #include "arrow/util/thread_pool.h"
29
30 namespace arrow {
31 namespace internal {
32
33 namespace {
34
35 ////////////////////////////////////////////////////////////////////////
36 // Serial TaskGroup implementation
37
38 class SerialTaskGroup : public TaskGroup {
39 public:
40 explicit SerialTaskGroup(StopToken stop_token) : stop_token_(std::move(stop_token)) {}
41
42 void AppendReal(FnOnce<Status()> task) override {
43 DCHECK(!finished_);
44 if (stop_token_.IsStopRequested()) {
45 status_ &= stop_token_.Poll();
46 return;
47 }
48 if (status_.ok()) {
49 status_ &= std::move(task)();
50 }
51 }
52
53 Status current_status() override { return status_; }
54
55 bool ok() const override { return status_.ok(); }
56
57 Status Finish() override {
58 if (!finished_) {
59 finished_ = true;
60 }
61 return status_;
62 }
63
64 Future<> FinishAsync() override { return Future<>::MakeFinished(Finish()); }
65
66 int parallelism() override { return 1; }
67
68 StopToken stop_token_;
69 Status status_;
70 bool finished_ = false;
71 };
72
73 ////////////////////////////////////////////////////////////////////////
74 // Threaded TaskGroup implementation
75
76 class ThreadedTaskGroup : public TaskGroup {
77 public:
78 ThreadedTaskGroup(Executor* executor, StopToken stop_token)
79 : executor_(executor),
80 stop_token_(std::move(stop_token)),
81 nremaining_(0),
82 ok_(true) {}
83
84 ~ThreadedTaskGroup() override {
85 // Make sure all pending tasks are finished, so that dangling references
86 // to this don't persist.
87 ARROW_UNUSED(Finish());
88 }
89
90 void AppendReal(FnOnce<Status()> task) override {
91 DCHECK(!finished_);
92 if (stop_token_.IsStopRequested()) {
93 UpdateStatus(stop_token_.Poll());
94 return;
95 }
96
97 // The hot path is unlocked thanks to atomics
98 // Only if an error occurs is the lock taken
99 if (ok_.load(std::memory_order_acquire)) {
100 nremaining_.fetch_add(1, std::memory_order_acquire);
101
102 auto self = checked_pointer_cast<ThreadedTaskGroup>(shared_from_this());
103
104 struct Callable {
105 void operator()() {
106 if (self_->ok_.load(std::memory_order_acquire)) {
107 Status st;
108 if (stop_token_.IsStopRequested()) {
109 st = stop_token_.Poll();
110 } else {
111 // XXX what about exceptions?
112 st = std::move(task_)();
113 }
114 self_->UpdateStatus(std::move(st));
115 }
116 self_->OneTaskDone();
117 }
118
119 std::shared_ptr<ThreadedTaskGroup> self_;
120 FnOnce<Status()> task_;
121 StopToken stop_token_;
122 };
123
124 Status st =
125 executor_->Spawn(Callable{std::move(self), std::move(task), stop_token_});
126 UpdateStatus(std::move(st));
127 }
128 }
129
130 Status current_status() override {
131 std::lock_guard<std::mutex> lock(mutex_);
132 return status_;
133 }
134
135 bool ok() const override { return ok_.load(); }
136
137 Status Finish() override {
138 std::unique_lock<std::mutex> lock(mutex_);
139 if (!finished_) {
140 cv_.wait(lock, [&]() { return nremaining_.load() == 0; });
141 // Current tasks may start other tasks, so only set this when done
142 finished_ = true;
143 }
144 return status_;
145 }
146
147 Future<> FinishAsync() override {
148 std::lock_guard<std::mutex> lock(mutex_);
149 if (!completion_future_.has_value()) {
150 if (nremaining_.load() == 0) {
151 completion_future_ = Future<>::MakeFinished(status_);
152 } else {
153 completion_future_ = Future<>::Make();
154 }
155 }
156 return *completion_future_;
157 }
158
159 int parallelism() override { return executor_->GetCapacity(); }
160
161 protected:
162 void UpdateStatus(Status&& st) {
163 // Must be called unlocked, only locks on error
164 if (ARROW_PREDICT_FALSE(!st.ok())) {
165 std::lock_guard<std::mutex> lock(mutex_);
166 ok_.store(false, std::memory_order_release);
167 status_ &= std::move(st);
168 }
169 }
170
171 void OneTaskDone() {
172 // Can be called unlocked thanks to atomics
173 auto nremaining = nremaining_.fetch_sub(1, std::memory_order_release) - 1;
174 DCHECK_GE(nremaining, 0);
175 if (nremaining == 0) {
176 // Take the lock so that ~ThreadedTaskGroup cannot destroy cv
177 // before cv.notify_one() has returned
178 std::unique_lock<std::mutex> lock(mutex_);
179 cv_.notify_one();
180 if (completion_future_.has_value()) {
181 // MarkFinished could be slow. We don't want to call it while we are holding
182 // the lock.
183 auto& future = *completion_future_;
184 const auto finished = completion_future_->is_finished();
185 const auto& status = status_;
186 // This will be redundant if the user calls Finish and not FinishAsync
187 if (!finished && !finished_) {
188 finished_ = true;
189 lock.unlock();
190 future.MarkFinished(status);
191 } else {
192 lock.unlock();
193 }
194 }
195 }
196 }
197
198 // These members are usable unlocked
199 Executor* executor_;
200 StopToken stop_token_;
201 std::atomic<int32_t> nremaining_;
202 std::atomic<bool> ok_;
203
204 // These members use locking
205 std::mutex mutex_;
206 std::condition_variable cv_;
207 Status status_;
208 bool finished_ = false;
209 util::optional<Future<>> completion_future_;
210 };
211
212 } // namespace
213
214 std::shared_ptr<TaskGroup> TaskGroup::MakeSerial(StopToken stop_token) {
215 return std::shared_ptr<TaskGroup>(new SerialTaskGroup{stop_token});
216 }
217
218 std::shared_ptr<TaskGroup> TaskGroup::MakeThreaded(Executor* thread_pool,
219 StopToken stop_token) {
220 return std::shared_ptr<TaskGroup>(new ThreadedTaskGroup{thread_pool, stop_token});
221 }
222
223 } // namespace internal
224 } // namespace arrow