// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
18 #include "arrow/compute/exec/exec_plan.h"
20 #include "arrow/compute/api_vector.h"
21 #include "arrow/compute/exec.h"
22 #include "arrow/compute/exec/expression.h"
23 #include "arrow/compute/exec/options.h"
24 #include "arrow/datum.h"
25 #include "arrow/result.h"
26 #include "arrow/util/checked_cast.h"
27 #include "arrow/util/future.h"
28 #include "arrow/util/logging.h"
31 using internal::checked_cast
;
36 class FilterNode
: public MapNode
{
38 FilterNode(ExecPlan
* plan
, std::vector
<ExecNode
*> inputs
,
39 std::shared_ptr
<Schema
> output_schema
, Expression filter
, bool async_mode
)
40 : MapNode(plan
, std::move(inputs
), std::move(output_schema
), async_mode
),
41 filter_(std::move(filter
)) {}
43 static Result
<ExecNode
*> Make(ExecPlan
* plan
, std::vector
<ExecNode
*> inputs
,
44 const ExecNodeOptions
& options
) {
45 RETURN_NOT_OK(ValidateExecNodeInputs(plan
, inputs
, 1, "FilterNode"));
46 auto schema
= inputs
[0]->output_schema();
48 const auto& filter_options
= checked_cast
<const FilterNodeOptions
&>(options
);
50 auto filter_expression
= filter_options
.filter_expression
;
51 if (!filter_expression
.IsBound()) {
52 ARROW_ASSIGN_OR_RAISE(filter_expression
, filter_expression
.Bind(*schema
));
55 if (filter_expression
.type()->id() != Type::BOOL
) {
56 return Status::TypeError("Filter expression must evaluate to bool, but ",
57 filter_expression
.ToString(), " evaluates to ",
58 filter_expression
.type()->ToString());
60 return plan
->EmplaceNode
<FilterNode
>(plan
, std::move(inputs
), std::move(schema
),
61 std::move(filter_expression
),
62 filter_options
.async_mode
);
65 const char* kind_name() const override
{ return "FilterNode"; }
67 Result
<ExecBatch
> DoFilter(const ExecBatch
& target
) {
68 ARROW_ASSIGN_OR_RAISE(Expression simplified_filter
,
69 SimplifyWithGuarantee(filter_
, target
.guarantee
));
71 ARROW_ASSIGN_OR_RAISE(Datum mask
, ExecuteScalarExpression(simplified_filter
, target
,
72 plan()->exec_context()));
74 if (mask
.is_scalar()) {
75 const auto& mask_scalar
= mask
.scalar_as
<BooleanScalar
>();
76 if (mask_scalar
.is_valid
&& mask_scalar
.value
) {
80 return target
.Slice(0, 0);
83 // if the values are all scalar then the mask must also be
84 DCHECK(!std::all_of(target
.values
.begin(), target
.values
.end(),
85 [](const Datum
& value
) { return value
.is_scalar(); }));
87 auto values
= target
.values
;
88 for (auto& value
: values
) {
89 if (value
.is_scalar()) continue;
90 ARROW_ASSIGN_OR_RAISE(value
, Filter(value
, mask
, FilterOptions::Defaults()));
92 return ExecBatch::Make(std::move(values
));
95 void InputReceived(ExecNode
* input
, ExecBatch batch
) override
{
96 DCHECK_EQ(input
, inputs_
[0]);
97 auto func
= [this](ExecBatch batch
) { return DoFilter(std::move(batch
)); };
98 this->SubmitTask(std::move(func
), std::move(batch
));
102 std::string
ToStringExtra() const override
{ return "filter=" + filter_
.ToString(); }
110 void RegisterFilterNode(ExecFactoryRegistry
* registry
) {
111 DCHECK_OK(registry
->AddFactory("filter", FilterNode::Make
));
114 } // namespace internal
115 } // namespace compute