// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/compute/exec/expression.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "arrow/compute/api_scalar.h"
#include "arrow/compute/cast.h"
#include "arrow/compute/registry.h"
#include "arrow/record_batch.h"
#include "arrow/table.h"
#include "arrow/util/logging.h"

using internal::checked_cast;
37 struct KnownFieldValues
{
38 std::unordered_map
<FieldRef
, Datum
, FieldRef::Hash
> map
;
41 inline const Expression::Call
* CallNotNull(const Expression
& expr
) {
42 auto call
= expr
.call();
43 DCHECK_NE(call
, nullptr);
47 inline std::vector
<ValueDescr
> GetDescriptors(const std::vector
<Expression
>& exprs
) {
48 std::vector
<ValueDescr
> descrs(exprs
.size());
49 for (size_t i
= 0; i
< exprs
.size(); ++i
) {
50 DCHECK(exprs
[i
].IsBound());
51 descrs
[i
] = exprs
[i
].descr();
56 inline std::vector
<ValueDescr
> GetDescriptors(const std::vector
<Datum
>& values
) {
57 std::vector
<ValueDescr
> descrs(values
.size());
58 for (size_t i
= 0; i
< values
.size(); ++i
) {
59 descrs
[i
] = values
[i
].descr();
70 NOT_EQUAL
= LESS
| GREATER
,
71 LESS_EQUAL
= LESS
| EQUAL
,
72 GREATER_EQUAL
= GREATER
| EQUAL
,
75 static const type
* Get(const std::string
& function
) {
76 static std::unordered_map
<std::string
, type
> map
{
77 {"equal", EQUAL
}, {"not_equal", NOT_EQUAL
},
78 {"less", LESS
}, {"less_equal", LESS_EQUAL
},
79 {"greater", GREATER
}, {"greater_equal", GREATER_EQUAL
},
82 auto it
= map
.find(function
);
83 return it
!= map
.end() ? &it
->second
: nullptr;
86 static const type
* Get(const Expression
& expr
) {
87 if (auto call
= expr
.call()) {
88 return Comparison::Get(call
->function_name
);
93 // Execute a simple Comparison between scalars
94 static Result
<type
> Execute(Datum l
, Datum r
) {
95 if (!l
.is_scalar() || !r
.is_scalar()) {
96 return Status::Invalid("Cannot Execute Comparison on non-scalars");
99 std::vector
<Datum
> arguments
{std::move(l
), std::move(r
)};
101 ARROW_ASSIGN_OR_RAISE(auto equal
, compute::CallFunction("equal", arguments
));
103 if (!equal
.scalar()->is_valid
) return NA
;
104 if (equal
.scalar_as
<BooleanScalar
>().value
) return EQUAL
;
106 ARROW_ASSIGN_OR_RAISE(auto less
, compute::CallFunction("less", arguments
));
108 if (!less
.scalar()->is_valid
) return NA
;
109 return less
.scalar_as
<BooleanScalar
>().value
? LESS
: GREATER
;
112 // Given an Expression wrapped in casts which preserve ordering
113 // (for example, cast(field_ref("i16"), to_type=int32())), unwrap the inner Expression.
114 // This is used to destructure implicitly cast field_refs during Expression
116 static const Expression
& StripOrderPreservingCasts(const Expression
& expr
) {
117 auto call
= expr
.call();
118 if (!call
) return expr
;
119 if (call
->function_name
!= "cast") return expr
;
121 const Expression
& from
= call
->arguments
[0];
123 auto from_id
= from
.type()->id();
124 auto to_id
= expr
.type()->id();
126 if (is_floating(to_id
)) {
127 if (is_integer(from_id
) || is_floating(from_id
)) {
128 return StripOrderPreservingCasts(from
);
133 if (is_unsigned_integer(to_id
)) {
134 if (is_unsigned_integer(from_id
) && bit_width(to_id
) >= bit_width(from_id
)) {
135 return StripOrderPreservingCasts(from
);
140 if (is_signed_integer(to_id
)) {
141 if (is_integer(from_id
) && bit_width(to_id
) >= bit_width(from_id
)) {
142 return StripOrderPreservingCasts(from
);
150 static type
GetFlipped(type op
) {
163 return GREATER_EQUAL
;
171 static std::string
GetName(type op
) {
186 return "greater_equal";
191 static std::string
GetOp(type op
) {
194 DCHECK(false) << "unreachable";
214 inline const compute::CastOptions
* GetCastOptions(const Expression::Call
& call
) {
215 if (call
.function_name
!= "cast") return nullptr;
216 return checked_cast
<const compute::CastOptions
*>(call
.options
.get());
/// True iff `function` names one of the set-lookup kernels: "is_in" or "index_in".
inline bool IsSetLookup(const std::string& function) {
  if (function == "is_in") return true;
  return function == "index_in";
}
223 inline const compute::MakeStructOptions
* GetMakeStructOptions(
224 const Expression::Call
& call
) {
225 if (call
.function_name
!= "make_struct") return nullptr;
226 return checked_cast
<const compute::MakeStructOptions
*>(call
.options
.get());
229 /// A helper for unboxing an Expression composed of associative function calls.
230 /// Such expressions can frequently be rearranged to a semantically equivalent
231 /// expression for more optimal execution or more straightforward manipulation.
232 /// For example, (a + ((b + 3) + 4)) is equivalent to (((4 + 3) + a) + b) and the latter
233 /// can be trivially constant-folded to ((7 + a) + b).
234 struct FlattenedAssociativeChain
{
235 /// True if a chain was already a left fold.
236 bool was_left_folded
= true;
238 /// All "branch" expressions in a flattened chain. For example given (a + ((b + 3) + 4))
239 /// exprs would be [(a + ((b + 3) + 4)), ((b + 3) + 4), (b + 3)]
240 std::vector
<Expression
> exprs
;
242 /// All "leaf" expressions in a flattened chain. For example given (a + ((b + 3) + 4))
243 /// the fringe would be [a, b, 3, 4]
244 std::vector
<Expression
> fringe
;
246 explicit FlattenedAssociativeChain(Expression expr
) : exprs
{std::move(expr
)} {
247 auto call
= CallNotNull(exprs
.back());
248 fringe
= call
->arguments
;
250 auto it
= fringe
.begin();
252 while (it
!= fringe
.end()) {
253 auto sub_call
= it
->call();
254 if (!sub_call
|| sub_call
->function_name
!= call
->function_name
) {
259 if (it
!= fringe
.begin()) {
260 was_left_folded
= false;
263 exprs
.push_back(std::move(*it
));
264 it
= fringe
.erase(it
);
266 auto index
= it
- fringe
.begin();
267 fringe
.insert(it
, sub_call
->arguments
.begin(), sub_call
->arguments
.end());
268 it
= fringe
.begin() + index
;
269 // NB: no increment so we hit sub_call's first argument next iteration
272 DCHECK(std::all_of(exprs
.begin(), exprs
.end(), [](const Expression
& expr
) {
273 return CallNotNull(expr
)->options
== nullptr;
278 inline Result
<std::shared_ptr
<compute::Function
>> GetFunction(
279 const Expression::Call
& call
, compute::ExecContext
* exec_context
) {
280 if (call
.function_name
!= "cast") {
281 return exec_context
->func_registry()->GetFunction(call
.function_name
);
283 // XXX this special case is strange; why not make "cast" a ScalarFunction?
284 const auto& to_type
= checked_cast
<const compute::CastOptions
&>(*call
.options
).to_type
;
285 return compute::GetCastFunction(to_type
);
288 /// Modify an Expression with pre-order and post-order visitation.
289 /// `pre` will be invoked on each Expression. `pre` will visit Calls before their
290 /// arguments, `post_call` will visit Calls (and no other Expressions) after their
291 /// arguments. Visitors should return the Identical expression to indicate no change; this
292 /// will prevent unnecessary construction in the common case where a modification is not
293 /// possible/necessary/...
295 /// If an argument was modified, `post_call` visits a reconstructed Call with the modified
296 /// arguments but also receives a pointer to the unmodified Expression as a second
297 /// argument. If no arguments were modified the unmodified Expression* will be nullptr.
298 template <typename PreVisit
, typename PostVisitCall
>
299 Result
<Expression
> Modify(Expression expr
, const PreVisit
& pre
,
300 const PostVisitCall
& post_call
) {
301 ARROW_ASSIGN_OR_RAISE(expr
, Result
<Expression
>(pre(std::move(expr
))));
303 auto call
= expr
.call();
304 if (!call
) return expr
;
306 bool at_least_one_modified
= false;
307 std::vector
<Expression
> modified_arguments
;
309 for (size_t i
= 0; i
< call
->arguments
.size(); ++i
) {
310 ARROW_ASSIGN_OR_RAISE(auto modified_argument
,
311 Modify(call
->arguments
[i
], pre
, post_call
));
313 if (Identical(modified_argument
, call
->arguments
[i
])) {
317 if (!at_least_one_modified
) {
318 modified_arguments
= call
->arguments
;
319 at_least_one_modified
= true;
322 modified_arguments
[i
] = std::move(modified_argument
);
325 if (at_least_one_modified
) {
326 // reconstruct the call expression with the modified arguments
327 auto modified_call
= *call
;
328 modified_call
.arguments
= std::move(modified_arguments
);
329 return post_call(Expression(std::move(modified_call
)), &expr
);
332 return post_call(std::move(expr
), nullptr);
335 } // namespace compute