// Source: git.proxmox.com, ceph.git — ceph/src/arrow/cpp/src/arrow/compute/exec/exec_plan.cc
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/compute/exec/exec_plan.h"

#include <sstream>
#include <unordered_map>
#include <unordered_set>

#include "arrow/compute/exec.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/registry.h"
#include "arrow/datum.h"
#include "arrow/record_batch.h"
#include "arrow/result.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/util/optional.h"
38 using internal::checked_cast
;
44 struct ExecPlanImpl
: public ExecPlan
{
45 explicit ExecPlanImpl(ExecContext
* exec_context
) : ExecPlan(exec_context
) {}
47 ~ExecPlanImpl() override
{
48 if (started_
&& !finished_
.is_finished()) {
49 ARROW_LOG(WARNING
) << "Plan was destroyed before finishing";
55 ExecNode
* AddNode(std::unique_ptr
<ExecNode
> node
) {
56 if (node
->label().empty()) {
57 node
->SetLabel(std::to_string(auto_label_counter_
++));
59 if (node
->num_inputs() == 0) {
60 sources_
.push_back(node
.get());
62 if (node
->num_outputs() == 0) {
63 sinks_
.push_back(node
.get());
65 nodes_
.push_back(std::move(node
));
66 return nodes_
.back().get();
69 Status
Validate() const {
71 return Status::Invalid("ExecPlan has no node");
73 for (const auto& node
: nodes_
) {
74 RETURN_NOT_OK(node
->Validate());
79 Status
StartProducing() {
81 return Status::Invalid("restarted ExecPlan");
85 // producers precede consumers
86 sorted_nodes_
= TopoSort();
88 std::vector
<Future
<>> futures
;
90 Status st
= Status::OK();
92 using rev_it
= std::reverse_iterator
<NodeVector::iterator
>;
93 for (rev_it
it(sorted_nodes_
.end()), end(sorted_nodes_
.begin()); it
!= end
; ++it
) {
96 st
= node
->StartProducing();
98 // Stop nodes that successfully started, in reverse order
100 StopProducingImpl(it
.base(), sorted_nodes_
.end());
104 futures
.push_back(node
->finished());
107 finished_
= AllFinished(futures
);
111 void StopProducing() {
112 DCHECK(started_
) << "stopped an ExecPlan which never started";
115 StopProducingImpl(sorted_nodes_
.begin(), sorted_nodes_
.end());
118 template <typename It
>
119 void StopProducingImpl(It begin
, It end
) {
120 for (auto it
= begin
; it
!= end
; ++it
) {
122 node
->StopProducing();
126 NodeVector
TopoSort() const {
128 const std::vector
<std::unique_ptr
<ExecNode
>>& nodes
;
129 std::unordered_set
<ExecNode
*> visited
;
132 explicit Impl(const std::vector
<std::unique_ptr
<ExecNode
>>& nodes
) : nodes(nodes
) {
133 visited
.reserve(nodes
.size());
134 sorted
.resize(nodes
.size());
136 for (const auto& node
: nodes
) {
140 DCHECK_EQ(visited
.size(), nodes
.size());
143 void Visit(ExecNode
* node
) {
144 if (visited
.count(node
) != 0) return;
146 for (auto input
: node
->inputs()) {
147 // Ensure that producers are inserted before this consumer
151 sorted
[visited
.size()] = node
;
152 visited
.insert(node
);
156 return std::move(Impl
{nodes_
}.sorted
);
159 std::string
ToString() const {
160 std::stringstream ss
;
161 ss
<< "ExecPlan with " << nodes_
.size() << " nodes:" << std::endl
;
162 for (const auto& node
: TopoSort()) {
163 ss
<< node
->ToString() << std::endl
;
168 Future
<> finished_
= Future
<>::MakeFinished();
169 bool started_
= false, stopped_
= false;
170 std::vector
<std::unique_ptr
<ExecNode
>> nodes_
;
171 NodeVector sources_
, sinks_
;
172 NodeVector sorted_nodes_
;
173 uint32_t auto_label_counter_
= 0;
176 ExecPlanImpl
* ToDerived(ExecPlan
* ptr
) { return checked_cast
<ExecPlanImpl
*>(ptr
); }
178 const ExecPlanImpl
* ToDerived(const ExecPlan
* ptr
) {
179 return checked_cast
<const ExecPlanImpl
*>(ptr
);
182 util::optional
<int> GetNodeIndex(const std::vector
<ExecNode
*>& nodes
,
183 const ExecNode
* node
) {
184 for (int i
= 0; i
< static_cast<int>(nodes
.size()); ++i
) {
185 if (nodes
[i
] == node
) return i
;
187 return util::nullopt
;
192 Result
<std::shared_ptr
<ExecPlan
>> ExecPlan::Make(ExecContext
* ctx
) {
193 return std::shared_ptr
<ExecPlan
>(new ExecPlanImpl
{ctx
});
196 ExecNode
* ExecPlan::AddNode(std::unique_ptr
<ExecNode
> node
) {
197 return ToDerived(this)->AddNode(std::move(node
));
200 const ExecPlan::NodeVector
& ExecPlan::sources() const {
201 return ToDerived(this)->sources_
;
204 const ExecPlan::NodeVector
& ExecPlan::sinks() const { return ToDerived(this)->sinks_
; }
206 Status
ExecPlan::Validate() { return ToDerived(this)->Validate(); }
208 Status
ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); }
210 void ExecPlan::StopProducing() { ToDerived(this)->StopProducing(); }
212 Future
<> ExecPlan::finished() { return ToDerived(this)->finished_
; }
214 std::string
ExecPlan::ToString() const { return ToDerived(this)->ToString(); }
216 ExecNode::ExecNode(ExecPlan
* plan
, NodeVector inputs
,
217 std::vector
<std::string
> input_labels
,
218 std::shared_ptr
<Schema
> output_schema
, int num_outputs
)
220 inputs_(std::move(inputs
)),
221 input_labels_(std::move(input_labels
)),
222 output_schema_(std::move(output_schema
)),
223 num_outputs_(num_outputs
) {
224 for (auto input
: inputs_
) {
225 input
->outputs_
.push_back(this);
229 Status
ExecNode::Validate() const {
230 if (inputs_
.size() != input_labels_
.size()) {
231 return Status::Invalid("Invalid number of inputs for '", label(), "' (expected ",
232 num_inputs(), ", actual ", input_labels_
.size(), ")");
235 if (static_cast<int>(outputs_
.size()) != num_outputs_
) {
236 return Status::Invalid("Invalid number of outputs for '", label(), "' (expected ",
237 num_outputs(), ", actual ", outputs_
.size(), ")");
240 for (auto out
: outputs_
) {
241 auto input_index
= GetNodeIndex(out
->inputs(), this);
243 return Status::Invalid("Node '", label(), "' outputs to node '", out
->label(),
244 "' but is not listed as an input.");
251 std::string
ExecNode::ToString() const {
252 std::stringstream ss
;
253 ss
<< kind_name() << "{\"" << label_
<< '"';
254 if (!inputs_
.empty()) {
256 for (size_t i
= 0; i
< inputs_
.size(); i
++) {
257 if (i
> 0) ss
<< ", ";
258 ss
<< input_labels_
[i
] << ": \"" << inputs_
[i
]->label() << '"';
263 if (!outputs_
.empty()) {
265 for (size_t i
= 0; i
< outputs_
.size(); i
++) {
266 if (i
> 0) ss
<< ", ";
267 ss
<< "\"" << outputs_
[i
]->label() << "\"";
272 const std::string extra
= ToStringExtra();
273 if (!extra
.empty()) ss
<< ", " << extra
;
279 std::string
ExecNode::ToStringExtra() const { return ""; }
281 bool ExecNode::ErrorIfNotOk(Status status
) {
282 if (status
.ok()) return false;
284 for (auto out
: outputs_
) {
285 out
->ErrorReceived(this, out
== outputs_
.back() ? std::move(status
) : status
);
290 MapNode::MapNode(ExecPlan
* plan
, std::vector
<ExecNode
*> inputs
,
291 std::shared_ptr
<Schema
> output_schema
, bool async_mode
)
292 : ExecNode(plan
, std::move(inputs
), /*input_labels=*/{"target"},
293 std::move(output_schema
),
296 executor_
= plan_
->exec_context()->executor();
302 void MapNode::ErrorReceived(ExecNode
* input
, Status error
) {
303 DCHECK_EQ(input
, inputs_
[0]);
304 outputs_
[0]->ErrorReceived(this, std::move(error
));
307 void MapNode::InputFinished(ExecNode
* input
, int total_batches
) {
308 DCHECK_EQ(input
, inputs_
[0]);
309 outputs_
[0]->InputFinished(this, total_batches
);
310 if (input_counter_
.SetTotal(total_batches
)) {
315 Status
MapNode::StartProducing() { return Status::OK(); }
317 void MapNode::PauseProducing(ExecNode
* output
) {}
319 void MapNode::ResumeProducing(ExecNode
* output
) {}
321 void MapNode::StopProducing(ExecNode
* output
) {
322 DCHECK_EQ(output
, outputs_
[0]);
326 void MapNode::StopProducing() {
328 this->stop_source_
.RequestStop();
330 if (input_counter_
.Cancel()) {
333 inputs_
[0]->StopProducing(this);
336 Future
<> MapNode::finished() { return finished_
; }
338 void MapNode::SubmitTask(std::function
<Result
<ExecBatch
>(ExecBatch
)> map_fn
,
341 if (finished_
.is_finished()) {
344 auto task
= [this, map_fn
, batch
]() {
345 auto guarantee
= batch
.guarantee
;
346 auto output_batch
= map_fn(std::move(batch
));
347 if (ErrorIfNotOk(output_batch
.status())) {
348 return output_batch
.status();
350 output_batch
->guarantee
= guarantee
;
351 outputs_
[0]->InputReceived(this, output_batch
.MoveValueUnsafe());
356 status
= task_group_
.AddTask([this, task
]() -> Result
<Future
<>> {
357 return this->executor_
->Submit(this->stop_source_
.token(), [this, task
]() {
358 auto status
= task();
359 if (this->input_counter_
.Increment()) {
360 this->Finish(status
);
367 if (input_counter_
.Increment()) {
368 this->Finish(status
);
372 if (input_counter_
.Cancel()) {
373 this->Finish(status
);
375 inputs_
[0]->StopProducing(this);
380 void MapNode::Finish(Status finish_st
/*= Status::OK()*/) {
382 task_group_
.End().AddCallback([this, finish_st
](const Status
& st
) {
383 Status final_status
= finish_st
& st
;
384 this->finished_
.MarkFinished(final_status
);
387 this->finished_
.MarkFinished(finish_st
);
391 std::shared_ptr
<RecordBatchReader
> MakeGeneratorReader(
392 std::shared_ptr
<Schema
> schema
,
393 std::function
<Future
<util::optional
<ExecBatch
>>()> gen
, MemoryPool
* pool
) {
394 struct Impl
: RecordBatchReader
{
395 std::shared_ptr
<Schema
> schema() const override
{ return schema_
; }
397 Status
ReadNext(std::shared_ptr
<RecordBatch
>* record_batch
) override
{
398 ARROW_ASSIGN_OR_RAISE(auto batch
, iterator_
.Next());
400 ARROW_ASSIGN_OR_RAISE(*record_batch
, batch
->ToRecordBatch(schema_
, pool_
));
402 *record_batch
= IterationEnd
<std::shared_ptr
<RecordBatch
>>();
408 std::shared_ptr
<Schema
> schema_
;
409 Iterator
<util::optional
<ExecBatch
>> iterator_
;
412 auto out
= std::make_shared
<Impl
>();
414 out
->schema_
= std::move(schema
);
415 out
->iterator_
= MakeGeneratorIterator(std::move(gen
));
419 Result
<ExecNode
*> Declaration::AddToPlan(ExecPlan
* plan
,
420 ExecFactoryRegistry
* registry
) const {
421 std::vector
<ExecNode
*> inputs(this->inputs
.size());
424 for (const Input
& input
: this->inputs
) {
425 if (auto node
= util::get_if
<ExecNode
*>(&input
)) {
429 ARROW_ASSIGN_OR_RAISE(inputs
[i
++],
430 util::get
<Declaration
>(input
).AddToPlan(plan
, registry
));
433 ARROW_ASSIGN_OR_RAISE(
434 auto node
, MakeExecNode(this->factory_name
, plan
, std::move(inputs
), *this->options
,
436 node
->SetLabel(this->label
);
440 Declaration
Declaration::Sequence(std::vector
<Declaration
> decls
) {
441 DCHECK(!decls
.empty());
443 Declaration out
= std::move(decls
.back());
445 auto receiver
= &out
;
446 while (!decls
.empty()) {
447 Declaration input
= std::move(decls
.back());
450 receiver
->inputs
.emplace_back(std::move(input
));
451 receiver
= &util::get
<Declaration
>(receiver
->inputs
.front());
458 void RegisterSourceNode(ExecFactoryRegistry
*);
459 void RegisterFilterNode(ExecFactoryRegistry
*);
460 void RegisterProjectNode(ExecFactoryRegistry
*);
461 void RegisterUnionNode(ExecFactoryRegistry
*);
462 void RegisterAggregateNode(ExecFactoryRegistry
*);
463 void RegisterSinkNode(ExecFactoryRegistry
*);
464 void RegisterHashJoinNode(ExecFactoryRegistry
*);
466 } // namespace internal
468 ExecFactoryRegistry
* default_exec_factory_registry() {
469 class DefaultRegistry
: public ExecFactoryRegistry
{
472 internal::RegisterSourceNode(this);
473 internal::RegisterFilterNode(this);
474 internal::RegisterProjectNode(this);
475 internal::RegisterUnionNode(this);
476 internal::RegisterAggregateNode(this);
477 internal::RegisterSinkNode(this);
478 internal::RegisterHashJoinNode(this);
481 Result
<Factory
> GetFactory(const std::string
& factory_name
) override
{
482 auto it
= factories_
.find(factory_name
);
483 if (it
== factories_
.end()) {
484 return Status::KeyError("ExecNode factory named ", factory_name
,
485 " not present in registry.");
490 Status
AddFactory(std::string factory_name
, Factory factory
) override
{
491 auto it_success
= factories_
.emplace(std::move(factory_name
), std::move(factory
));
493 if (!it_success
.second
) {
494 const auto& factory_name
= it_success
.first
->first
;
495 return Status::KeyError("ExecNode factory named ", factory_name
,
496 " already registered.");
503 std::unordered_map
<std::string
, Factory
> factories_
;
506 static DefaultRegistry instance
;
510 Result
<std::function
<Future
<util::optional
<ExecBatch
>>()>> MakeReaderGenerator(
511 std::shared_ptr
<RecordBatchReader
> reader
, ::arrow::internal::Executor
* io_executor
,
512 int max_q
, int q_restart
) {
513 auto batch_it
= MakeMapIterator(
514 [](std::shared_ptr
<RecordBatch
> batch
) {
515 return util::make_optional(ExecBatch(*batch
));
517 MakeIteratorFromReader(reader
));
519 return MakeBackgroundGenerator(std::move(batch_it
), io_executor
, max_q
, q_restart
);
521 } // namespace compute