]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/compute/exec/exec_plan.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / exec / exec_plan.cc
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/compute/exec/exec_plan.h"
19
20 #include <sstream>
21 #include <unordered_map>
22 #include <unordered_set>
23
24 #include "arrow/compute/exec.h"
25 #include "arrow/compute/exec/expression.h"
26 #include "arrow/compute/exec_internal.h"
27 #include "arrow/compute/registry.h"
28 #include "arrow/datum.h"
29 #include "arrow/record_batch.h"
30 #include "arrow/result.h"
31 #include "arrow/util/async_generator.h"
32 #include "arrow/util/checked_cast.h"
33 #include "arrow/util/logging.h"
34 #include "arrow/util/optional.h"
35
36 namespace arrow {
37
38 using internal::checked_cast;
39
40 namespace compute {
41
42 namespace {
43
44 struct ExecPlanImpl : public ExecPlan {
45 explicit ExecPlanImpl(ExecContext* exec_context) : ExecPlan(exec_context) {}
46
47 ~ExecPlanImpl() override {
48 if (started_ && !finished_.is_finished()) {
49 ARROW_LOG(WARNING) << "Plan was destroyed before finishing";
50 StopProducing();
51 finished().Wait();
52 }
53 }
54
55 ExecNode* AddNode(std::unique_ptr<ExecNode> node) {
56 if (node->label().empty()) {
57 node->SetLabel(std::to_string(auto_label_counter_++));
58 }
59 if (node->num_inputs() == 0) {
60 sources_.push_back(node.get());
61 }
62 if (node->num_outputs() == 0) {
63 sinks_.push_back(node.get());
64 }
65 nodes_.push_back(std::move(node));
66 return nodes_.back().get();
67 }
68
69 Status Validate() const {
70 if (nodes_.empty()) {
71 return Status::Invalid("ExecPlan has no node");
72 }
73 for (const auto& node : nodes_) {
74 RETURN_NOT_OK(node->Validate());
75 }
76 return Status::OK();
77 }
78
79 Status StartProducing() {
80 if (started_) {
81 return Status::Invalid("restarted ExecPlan");
82 }
83 started_ = true;
84
85 // producers precede consumers
86 sorted_nodes_ = TopoSort();
87
88 std::vector<Future<>> futures;
89
90 Status st = Status::OK();
91
92 using rev_it = std::reverse_iterator<NodeVector::iterator>;
93 for (rev_it it(sorted_nodes_.end()), end(sorted_nodes_.begin()); it != end; ++it) {
94 auto node = *it;
95
96 st = node->StartProducing();
97 if (!st.ok()) {
98 // Stop nodes that successfully started, in reverse order
99 stopped_ = true;
100 StopProducingImpl(it.base(), sorted_nodes_.end());
101 break;
102 }
103
104 futures.push_back(node->finished());
105 }
106
107 finished_ = AllFinished(futures);
108 return st;
109 }
110
111 void StopProducing() {
112 DCHECK(started_) << "stopped an ExecPlan which never started";
113 stopped_ = true;
114
115 StopProducingImpl(sorted_nodes_.begin(), sorted_nodes_.end());
116 }
117
118 template <typename It>
119 void StopProducingImpl(It begin, It end) {
120 for (auto it = begin; it != end; ++it) {
121 auto node = *it;
122 node->StopProducing();
123 }
124 }
125
126 NodeVector TopoSort() const {
127 struct Impl {
128 const std::vector<std::unique_ptr<ExecNode>>& nodes;
129 std::unordered_set<ExecNode*> visited;
130 NodeVector sorted;
131
132 explicit Impl(const std::vector<std::unique_ptr<ExecNode>>& nodes) : nodes(nodes) {
133 visited.reserve(nodes.size());
134 sorted.resize(nodes.size());
135
136 for (const auto& node : nodes) {
137 Visit(node.get());
138 }
139
140 DCHECK_EQ(visited.size(), nodes.size());
141 }
142
143 void Visit(ExecNode* node) {
144 if (visited.count(node) != 0) return;
145
146 for (auto input : node->inputs()) {
147 // Ensure that producers are inserted before this consumer
148 Visit(input);
149 }
150
151 sorted[visited.size()] = node;
152 visited.insert(node);
153 }
154 };
155
156 return std::move(Impl{nodes_}.sorted);
157 }
158
159 std::string ToString() const {
160 std::stringstream ss;
161 ss << "ExecPlan with " << nodes_.size() << " nodes:" << std::endl;
162 for (const auto& node : TopoSort()) {
163 ss << node->ToString() << std::endl;
164 }
165 return ss.str();
166 }
167
168 Future<> finished_ = Future<>::MakeFinished();
169 bool started_ = false, stopped_ = false;
170 std::vector<std::unique_ptr<ExecNode>> nodes_;
171 NodeVector sources_, sinks_;
172 NodeVector sorted_nodes_;
173 uint32_t auto_label_counter_ = 0;
174 };
175
176 ExecPlanImpl* ToDerived(ExecPlan* ptr) { return checked_cast<ExecPlanImpl*>(ptr); }
177
178 const ExecPlanImpl* ToDerived(const ExecPlan* ptr) {
179 return checked_cast<const ExecPlanImpl*>(ptr);
180 }
181
182 util::optional<int> GetNodeIndex(const std::vector<ExecNode*>& nodes,
183 const ExecNode* node) {
184 for (int i = 0; i < static_cast<int>(nodes.size()); ++i) {
185 if (nodes[i] == node) return i;
186 }
187 return util::nullopt;
188 }
189
190 } // namespace
191
192 Result<std::shared_ptr<ExecPlan>> ExecPlan::Make(ExecContext* ctx) {
193 return std::shared_ptr<ExecPlan>(new ExecPlanImpl{ctx});
194 }
195
196 ExecNode* ExecPlan::AddNode(std::unique_ptr<ExecNode> node) {
197 return ToDerived(this)->AddNode(std::move(node));
198 }
199
200 const ExecPlan::NodeVector& ExecPlan::sources() const {
201 return ToDerived(this)->sources_;
202 }
203
204 const ExecPlan::NodeVector& ExecPlan::sinks() const { return ToDerived(this)->sinks_; }
205
206 Status ExecPlan::Validate() { return ToDerived(this)->Validate(); }
207
208 Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); }
209
210 void ExecPlan::StopProducing() { ToDerived(this)->StopProducing(); }
211
212 Future<> ExecPlan::finished() { return ToDerived(this)->finished_; }
213
214 std::string ExecPlan::ToString() const { return ToDerived(this)->ToString(); }
215
216 ExecNode::ExecNode(ExecPlan* plan, NodeVector inputs,
217 std::vector<std::string> input_labels,
218 std::shared_ptr<Schema> output_schema, int num_outputs)
219 : plan_(plan),
220 inputs_(std::move(inputs)),
221 input_labels_(std::move(input_labels)),
222 output_schema_(std::move(output_schema)),
223 num_outputs_(num_outputs) {
224 for (auto input : inputs_) {
225 input->outputs_.push_back(this);
226 }
227 }
228
229 Status ExecNode::Validate() const {
230 if (inputs_.size() != input_labels_.size()) {
231 return Status::Invalid("Invalid number of inputs for '", label(), "' (expected ",
232 num_inputs(), ", actual ", input_labels_.size(), ")");
233 }
234
235 if (static_cast<int>(outputs_.size()) != num_outputs_) {
236 return Status::Invalid("Invalid number of outputs for '", label(), "' (expected ",
237 num_outputs(), ", actual ", outputs_.size(), ")");
238 }
239
240 for (auto out : outputs_) {
241 auto input_index = GetNodeIndex(out->inputs(), this);
242 if (!input_index) {
243 return Status::Invalid("Node '", label(), "' outputs to node '", out->label(),
244 "' but is not listed as an input.");
245 }
246 }
247
248 return Status::OK();
249 }
250
251 std::string ExecNode::ToString() const {
252 std::stringstream ss;
253 ss << kind_name() << "{\"" << label_ << '"';
254 if (!inputs_.empty()) {
255 ss << ", inputs=[";
256 for (size_t i = 0; i < inputs_.size(); i++) {
257 if (i > 0) ss << ", ";
258 ss << input_labels_[i] << ": \"" << inputs_[i]->label() << '"';
259 }
260 ss << ']';
261 }
262
263 if (!outputs_.empty()) {
264 ss << ", outputs=[";
265 for (size_t i = 0; i < outputs_.size(); i++) {
266 if (i > 0) ss << ", ";
267 ss << "\"" << outputs_[i]->label() << "\"";
268 }
269 ss << ']';
270 }
271
272 const std::string extra = ToStringExtra();
273 if (!extra.empty()) ss << ", " << extra;
274
275 ss << '}';
276 return ss.str();
277 }
278
279 std::string ExecNode::ToStringExtra() const { return ""; }
280
281 bool ExecNode::ErrorIfNotOk(Status status) {
282 if (status.ok()) return false;
283
284 for (auto out : outputs_) {
285 out->ErrorReceived(this, out == outputs_.back() ? std::move(status) : status);
286 }
287 return true;
288 }
289
290 MapNode::MapNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
291 std::shared_ptr<Schema> output_schema, bool async_mode)
292 : ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"},
293 std::move(output_schema),
294 /*num_outputs=*/1) {
295 if (async_mode) {
296 executor_ = plan_->exec_context()->executor();
297 } else {
298 executor_ = nullptr;
299 }
300 }
301
302 void MapNode::ErrorReceived(ExecNode* input, Status error) {
303 DCHECK_EQ(input, inputs_[0]);
304 outputs_[0]->ErrorReceived(this, std::move(error));
305 }
306
307 void MapNode::InputFinished(ExecNode* input, int total_batches) {
308 DCHECK_EQ(input, inputs_[0]);
309 outputs_[0]->InputFinished(this, total_batches);
310 if (input_counter_.SetTotal(total_batches)) {
311 this->Finish();
312 }
313 }
314
315 Status MapNode::StartProducing() { return Status::OK(); }
316
317 void MapNode::PauseProducing(ExecNode* output) {}
318
319 void MapNode::ResumeProducing(ExecNode* output) {}
320
321 void MapNode::StopProducing(ExecNode* output) {
322 DCHECK_EQ(output, outputs_[0]);
323 StopProducing();
324 }
325
326 void MapNode::StopProducing() {
327 if (executor_) {
328 this->stop_source_.RequestStop();
329 }
330 if (input_counter_.Cancel()) {
331 this->Finish();
332 }
333 inputs_[0]->StopProducing(this);
334 }
335
336 Future<> MapNode::finished() { return finished_; }
337
338 void MapNode::SubmitTask(std::function<Result<ExecBatch>(ExecBatch)> map_fn,
339 ExecBatch batch) {
340 Status status;
341 if (finished_.is_finished()) {
342 return;
343 }
344 auto task = [this, map_fn, batch]() {
345 auto guarantee = batch.guarantee;
346 auto output_batch = map_fn(std::move(batch));
347 if (ErrorIfNotOk(output_batch.status())) {
348 return output_batch.status();
349 }
350 output_batch->guarantee = guarantee;
351 outputs_[0]->InputReceived(this, output_batch.MoveValueUnsafe());
352 return Status::OK();
353 };
354
355 if (executor_) {
356 status = task_group_.AddTask([this, task]() -> Result<Future<>> {
357 return this->executor_->Submit(this->stop_source_.token(), [this, task]() {
358 auto status = task();
359 if (this->input_counter_.Increment()) {
360 this->Finish(status);
361 }
362 return status;
363 });
364 });
365 } else {
366 status = task();
367 if (input_counter_.Increment()) {
368 this->Finish(status);
369 }
370 }
371 if (!status.ok()) {
372 if (input_counter_.Cancel()) {
373 this->Finish(status);
374 }
375 inputs_[0]->StopProducing(this);
376 return;
377 }
378 }
379
380 void MapNode::Finish(Status finish_st /*= Status::OK()*/) {
381 if (executor_) {
382 task_group_.End().AddCallback([this, finish_st](const Status& st) {
383 Status final_status = finish_st & st;
384 this->finished_.MarkFinished(final_status);
385 });
386 } else {
387 this->finished_.MarkFinished(finish_st);
388 }
389 }
390
391 std::shared_ptr<RecordBatchReader> MakeGeneratorReader(
392 std::shared_ptr<Schema> schema,
393 std::function<Future<util::optional<ExecBatch>>()> gen, MemoryPool* pool) {
394 struct Impl : RecordBatchReader {
395 std::shared_ptr<Schema> schema() const override { return schema_; }
396
397 Status ReadNext(std::shared_ptr<RecordBatch>* record_batch) override {
398 ARROW_ASSIGN_OR_RAISE(auto batch, iterator_.Next());
399 if (batch) {
400 ARROW_ASSIGN_OR_RAISE(*record_batch, batch->ToRecordBatch(schema_, pool_));
401 } else {
402 *record_batch = IterationEnd<std::shared_ptr<RecordBatch>>();
403 }
404 return Status::OK();
405 }
406
407 MemoryPool* pool_;
408 std::shared_ptr<Schema> schema_;
409 Iterator<util::optional<ExecBatch>> iterator_;
410 };
411
412 auto out = std::make_shared<Impl>();
413 out->pool_ = pool;
414 out->schema_ = std::move(schema);
415 out->iterator_ = MakeGeneratorIterator(std::move(gen));
416 return out;
417 }
418
419 Result<ExecNode*> Declaration::AddToPlan(ExecPlan* plan,
420 ExecFactoryRegistry* registry) const {
421 std::vector<ExecNode*> inputs(this->inputs.size());
422
423 size_t i = 0;
424 for (const Input& input : this->inputs) {
425 if (auto node = util::get_if<ExecNode*>(&input)) {
426 inputs[i++] = *node;
427 continue;
428 }
429 ARROW_ASSIGN_OR_RAISE(inputs[i++],
430 util::get<Declaration>(input).AddToPlan(plan, registry));
431 }
432
433 ARROW_ASSIGN_OR_RAISE(
434 auto node, MakeExecNode(this->factory_name, plan, std::move(inputs), *this->options,
435 registry));
436 node->SetLabel(this->label);
437 return node;
438 }
439
440 Declaration Declaration::Sequence(std::vector<Declaration> decls) {
441 DCHECK(!decls.empty());
442
443 Declaration out = std::move(decls.back());
444 decls.pop_back();
445 auto receiver = &out;
446 while (!decls.empty()) {
447 Declaration input = std::move(decls.back());
448 decls.pop_back();
449
450 receiver->inputs.emplace_back(std::move(input));
451 receiver = &util::get<Declaration>(receiver->inputs.front());
452 }
453 return out;
454 }
455
456 namespace internal {
457
458 void RegisterSourceNode(ExecFactoryRegistry*);
459 void RegisterFilterNode(ExecFactoryRegistry*);
460 void RegisterProjectNode(ExecFactoryRegistry*);
461 void RegisterUnionNode(ExecFactoryRegistry*);
462 void RegisterAggregateNode(ExecFactoryRegistry*);
463 void RegisterSinkNode(ExecFactoryRegistry*);
464 void RegisterHashJoinNode(ExecFactoryRegistry*);
465
466 } // namespace internal
467
468 ExecFactoryRegistry* default_exec_factory_registry() {
469 class DefaultRegistry : public ExecFactoryRegistry {
470 public:
471 DefaultRegistry() {
472 internal::RegisterSourceNode(this);
473 internal::RegisterFilterNode(this);
474 internal::RegisterProjectNode(this);
475 internal::RegisterUnionNode(this);
476 internal::RegisterAggregateNode(this);
477 internal::RegisterSinkNode(this);
478 internal::RegisterHashJoinNode(this);
479 }
480
481 Result<Factory> GetFactory(const std::string& factory_name) override {
482 auto it = factories_.find(factory_name);
483 if (it == factories_.end()) {
484 return Status::KeyError("ExecNode factory named ", factory_name,
485 " not present in registry.");
486 }
487 return it->second;
488 }
489
490 Status AddFactory(std::string factory_name, Factory factory) override {
491 auto it_success = factories_.emplace(std::move(factory_name), std::move(factory));
492
493 if (!it_success.second) {
494 const auto& factory_name = it_success.first->first;
495 return Status::KeyError("ExecNode factory named ", factory_name,
496 " already registered.");
497 }
498
499 return Status::OK();
500 }
501
502 private:
503 std::unordered_map<std::string, Factory> factories_;
504 };
505
506 static DefaultRegistry instance;
507 return &instance;
508 }
509
510 Result<std::function<Future<util::optional<ExecBatch>>()>> MakeReaderGenerator(
511 std::shared_ptr<RecordBatchReader> reader, ::arrow::internal::Executor* io_executor,
512 int max_q, int q_restart) {
513 auto batch_it = MakeMapIterator(
514 [](std::shared_ptr<RecordBatch> batch) {
515 return util::make_optional(ExecBatch(*batch));
516 },
517 MakeIteratorFromReader(reader));
518
519 return MakeBackgroundGenerator(std::move(batch_it), io_executor, max_q, q_restart);
520 }
521 } // namespace compute
522
523 } // namespace arrow