git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/compute/exec/filter_node.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / exec / filter_node.cc
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/compute/exec/exec_plan.h"
19
20 #include "arrow/compute/api_vector.h"
21 #include "arrow/compute/exec.h"
22 #include "arrow/compute/exec/expression.h"
23 #include "arrow/compute/exec/options.h"
24 #include "arrow/datum.h"
25 #include "arrow/result.h"
26 #include "arrow/util/checked_cast.h"
27 #include "arrow/util/future.h"
28 #include "arrow/util/logging.h"
29 namespace arrow {
30
31 using internal::checked_cast;
32
33 namespace compute {
34 namespace {
35
36 class FilterNode : public MapNode {
37 public:
38 FilterNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
39 std::shared_ptr<Schema> output_schema, Expression filter, bool async_mode)
40 : MapNode(plan, std::move(inputs), std::move(output_schema), async_mode),
41 filter_(std::move(filter)) {}
42
43 static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
44 const ExecNodeOptions& options) {
45 RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "FilterNode"));
46 auto schema = inputs[0]->output_schema();
47
48 const auto& filter_options = checked_cast<const FilterNodeOptions&>(options);
49
50 auto filter_expression = filter_options.filter_expression;
51 if (!filter_expression.IsBound()) {
52 ARROW_ASSIGN_OR_RAISE(filter_expression, filter_expression.Bind(*schema));
53 }
54
55 if (filter_expression.type()->id() != Type::BOOL) {
56 return Status::TypeError("Filter expression must evaluate to bool, but ",
57 filter_expression.ToString(), " evaluates to ",
58 filter_expression.type()->ToString());
59 }
60 return plan->EmplaceNode<FilterNode>(plan, std::move(inputs), std::move(schema),
61 std::move(filter_expression),
62 filter_options.async_mode);
63 }
64
65 const char* kind_name() const override { return "FilterNode"; }
66
67 Result<ExecBatch> DoFilter(const ExecBatch& target) {
68 ARROW_ASSIGN_OR_RAISE(Expression simplified_filter,
69 SimplifyWithGuarantee(filter_, target.guarantee));
70
71 ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(simplified_filter, target,
72 plan()->exec_context()));
73
74 if (mask.is_scalar()) {
75 const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
76 if (mask_scalar.is_valid && mask_scalar.value) {
77 return target;
78 }
79
80 return target.Slice(0, 0);
81 }
82
83 // if the values are all scalar then the mask must also be
84 DCHECK(!std::all_of(target.values.begin(), target.values.end(),
85 [](const Datum& value) { return value.is_scalar(); }));
86
87 auto values = target.values;
88 for (auto& value : values) {
89 if (value.is_scalar()) continue;
90 ARROW_ASSIGN_OR_RAISE(value, Filter(value, mask, FilterOptions::Defaults()));
91 }
92 return ExecBatch::Make(std::move(values));
93 }
94
95 void InputReceived(ExecNode* input, ExecBatch batch) override {
96 DCHECK_EQ(input, inputs_[0]);
97 auto func = [this](ExecBatch batch) { return DoFilter(std::move(batch)); };
98 this->SubmitTask(std::move(func), std::move(batch));
99 }
100
101 protected:
102 std::string ToStringExtra() const override { return "filter=" + filter_.ToString(); }
103
104 private:
105 Expression filter_;
106 };
107 } // namespace
108
109 namespace internal {
110 void RegisterFilterNode(ExecFactoryRegistry* registry) {
111 DCHECK_OK(registry->AddFactory("filter", FilterNode::Make));
112 }
113
114 } // namespace internal
115 } // namespace compute
116 } // namespace arrow