1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
18 #include "arrow/compute/exec/expression.h"
20 #include <unordered_map>
21 #include <unordered_set>
23 #include "arrow/chunked_array.h"
24 #include "arrow/compute/api_vector.h"
25 #include "arrow/compute/exec/expression_internal.h"
26 #include "arrow/compute/exec_internal.h"
27 #include "arrow/compute/function_internal.h"
28 #include "arrow/io/memory.h"
29 #include "arrow/ipc/reader.h"
30 #include "arrow/ipc/writer.h"
31 #include "arrow/util/hash_util.h"
32 #include "arrow/util/key_value_metadata.h"
33 #include "arrow/util/logging.h"
34 #include "arrow/util/optional.h"
35 #include "arrow/util/string.h"
36 #include "arrow/util/value_parsing.h"
40 using internal::checked_cast
;
41 using internal::checked_pointer_cast
;
// Computes and caches this call's hash in `hash`: std::hash of
// `function_name`, then each argument's hash() mixed in with
// arrow::internal::hash_combine, in argument order.
// NOTE(review): in the visible lines `options` does not contribute to the
// hash; calls differing only in options collide (still consistent with
// Equals, which only adds constraints).
45 void Expression::Call::ComputeHash() {
hash
= std::hash
<std::string
>{}(function_name
);
47 for (const auto& arg
: arguments
) {
48 arrow::internal::hash_combine(hash
, arg
.hash());
// Constructs a call expression; `call` is moved into a newly allocated,
// shared Impl.
// NOTE(review): Expression::hash() returns the cached Call::hash for calls,
// but no ComputeHash() call is visible in this excerpt — confirm the hash is
// computed before `call` is stored.
52 Expression::Expression(Call call
) {
54 impl_
= std::make_shared
<Impl
>(std::move(call
));
// Constructs a literal expression; `literal` is moved into a shared Impl.
57 Expression::Expression(Datum literal
)
58 : impl_(std::make_shared
<Impl
>(std::move(literal
))) {}
// Constructs a parameter (field reference) expression; `parameter` is moved
// into a shared Impl.
60 Expression::Expression(Parameter parameter
)
61 : impl_(std::make_shared
<Impl
>(std::move(parameter
))) {}
// Factory: wraps `lit` in a literal Expression.
63 Expression
literal(Datum lit
) { return Expression(std::move(lit
)); }
// Factory: an unbound field reference. The parameter starts with an empty
// ValueDescr and index -1; BindImpl later fills in the index and descr.
65 Expression
field_ref(FieldRef ref
) {
66 return Expression(Expression::Parameter
{std::move(ref
), ValueDescr
{}, -1});
// Factory: an unbound call of `function` on `arguments` with optional
// `options`. Call::function and Call::kernel are not set here; they are
// resolved when the expression is bound (see BindNonRecursive).
69 Expression
call(std::string function
, std::vector
<Expression
> arguments
,
70 std::shared_ptr
<compute::FunctionOptions
> options
) {
71 Expression::Call call
;
72 call
.function_name
= std::move(function
);
73 call
.arguments
= std::move(arguments
);
74 call
.options
= std::move(options
);
75 return Expression(std::move(call
));
// Returns the held Datum if this expression is a literal, otherwise nullptr.
78 const Datum
* Expression::literal() const { return util::get_if
<Datum
>(impl_
.get()); }
// Returns the held Parameter if this expression is a field reference,
// otherwise nullptr.
80 const Expression::Parameter
* Expression::parameter() const {
81 return util::get_if
<Parameter
>(impl_
.get());
// Returns the FieldRef held by this expression when it is a parameter (a
// bound or unbound field reference). The fall-through return for other
// expression kinds is not visible in this excerpt; callers such as
// ToString() and Equals() treat a null result as "not a field reference".
84 const FieldRef
* Expression::field_ref() const {
85 if (auto parameter
= this->parameter()) {
86 return &parameter
->ref
;
// Returns the held Call if this expression is a function call, otherwise
// nullptr.
91 const Expression::Call
* Expression::call() const {
92 return util::get_if
<Call
>(impl_
.get());
// Returns the value descriptor (type + shape) of this expression:
// - default-constructed expression (null impl_): an empty ValueDescr
// - literal: derived from the Datum (branch body not visible in this excerpt)
// - parameter: the descr recorded at bind time (empty while unbound)
// - call: the descr resolved by BindNonRecursive
95 ValueDescr
Expression::descr() const {
96 if (impl_
== nullptr) return {};
98 if (auto lit
= literal()) {
102 if (auto parameter
= this->parameter()) {
103 return parameter
->descr
;
106 return CallNotNull(*this)->descr
;
// Renders a Datum for Expression::ToString():
// - invalid (null) scalars print as "null"
// - string-like scalars are passed through Escape() (the surrounding quoting
//   lines are not fully visible in this excerpt)
// - binary-like scalars print as a double-quoted hex string
// - other scalars use Scalar::ToString(); non-scalars use Datum::ToString()
111 std::string
PrintDatum(const Datum
& datum
) {
112 if (datum
.is_scalar()) {
113 if (!datum
.scalar()->is_valid
) return "null";
115 switch (datum
.type()->id()) {
117 case Type::LARGE_STRING
:
119 Escape(util::string_view(*datum
.scalar_as
<BaseBinaryScalar
>().value
)) +
123 case Type::FIXED_SIZE_BINARY
:
124 case Type::LARGE_BINARY
:
125 return '"' + datum
.scalar_as
<BaseBinaryScalar
>().value
->ToHexString() + '"';
131 return datum
.scalar()->ToString();
133 return datum
.ToString();
// Renders the expression as human-readable text:
// - literals via PrintDatum()
// - field refs by name (branch body not visible), else FieldPath, else
//   FieldRef::ToString()
// - comparison calls infix using Comparison::GetOp, e.g. "(a > 1)"
// - "<op>_kleene" calls infix with the suffix stripped
// - make_struct calls as "{name=value, ...}"
// - anything else as "function_name(arg, ..., options)"
// NOTE(review): the trailing-separator fixups use resize() plus lines not
// visible here. For make_struct with no fields, "{" is shrunk to an empty
// string. For a call with no arguments and no options, "f(" is shrunk to "f"
// before the closing character is written. Verify both edge cases.
138 std::string
Expression::ToString() const {
139 if (auto lit
= literal()) {
140 return PrintDatum(*lit
);
143 if (auto ref
= field_ref()) {
144 if (auto name
= ref
->name()) {
147 if (auto path
= ref
->field_path()) {
148 return path
->ToString();
150 return ref
->ToString();
153 auto call
= CallNotNull(*this);
// binary() reads arguments[0] and arguments[1]; it assumes the comparison
// and *_kleene calls it is used for have exactly two arguments.
154 auto binary
= [&](std::string op
) {
155 return "(" + call
->arguments
[0].ToString() + " " + op
+ " " +
156 call
->arguments
[1].ToString() + ")";
159 if (auto cmp
= Comparison::Get(call
->function_name
)) {
160 return binary(Comparison::GetOp(*cmp
));
163 constexpr util::string_view kleene
= "_kleene";
164 if (util::string_view
{call
->function_name
}.ends_with(kleene
)) {
165 auto op
= call
->function_name
.substr(0, call
->function_name
.size() - kleene
.size());
166 return binary(std::move(op
));
// make_struct: pair each field name with the argument at the same position.
// NOTE(review): assumes field_names.size() <= arguments.size().
169 if (auto options
= GetMakeStructOptions(*call
)) {
170 std::string out
= "{";
171 auto argument
= call
->arguments
.begin();
172 for (const auto& field_name
: options
->field_names
) {
173 out
+= field_name
+ "=" + argument
++->ToString() + ", ";
175 out
.resize(out
.size() - 1);
// Generic form: "name(arg, arg, " then options (if any) or separator trim.
180 std::string out
= call
->function_name
+ "(";
181 for (const auto& arg
: call
->arguments
) {
182 out
+= arg
.ToString() + ", ";
186 out
+= call
->options
->ToString();
187 out
.resize(out
.size() + 1);
189 out
.resize(out
.size() - 1);
// Stream printer hook (presumably the googletest PrintTo customization
// point): writes ToString(). The extra output emitted for bound expressions
// is on lines not visible in this excerpt.
195 void PrintTo(const Expression
& expr
, std::ostream
* os
) {
196 *os
<< expr
.ToString();
197 if (expr
.IsBound()) {
// Structural equality:
// - Identical() expressions (same impl_) are equal
// - expressions of different kinds (literal / parameter / call) are not
// - literals compare with Datum::Equals, field refs with FieldRef::Equals
// - calls must share function_name and kernel pointer, have pairwise-Equal
//   arguments, and have equal options (same pointer, or both non-null and
//   FunctionOptions::Equals)
// NOTE(review): the argument loop indexes other_call->arguments[i] up to this
// call's arity with no visible arity check. Two unbound calls (both kernels
// null) with the same name but different argument counts could read out of
// bounds. impl_ is also dereferenced below when only one side is
// default-constructed. Confirm both cases.
202 bool Expression::Equals(const Expression
& other
) const {
203 if (Identical(*this, other
)) return true;
205 if (impl_
->index() != other
.impl_
->index()) {
209 if (auto lit
= literal()) {
210 return lit
->Equals(*other
.literal());
213 if (auto ref
= field_ref()) {
214 return ref
->Equals(*other
.field_ref());
217 auto call
= CallNotNull(*this);
218 auto other_call
= CallNotNull(other
);
220 if (call
->function_name
!= other_call
->function_name
||
221 call
->kernel
!= other_call
->kernel
) {
225 for (size_t i
= 0; i
< call
->arguments
.size(); ++i
) {
226 if (!call
->arguments
[i
].Equals(other_call
->arguments
[i
])) {
231 if (call
->options
== other_call
->options
) return true;
232 if (call
->options
&& other_call
->options
) {
233 return call
->options
->Equals(other_call
->options
);
// True when both expressions share the same Impl object (pointer identity).
// Expression::Equals short-circuits to true in this case.
238 bool Identical(const Expression
& l
, const Expression
& r
) { return l
.impl_
== r
.impl_
; }
// Hash of the expression. Scalar literals use Scalar::hash(); calls return
// the value cached in Call::hash (see Call::ComputeHash). The branches for
// non-scalar literals and field refs are not fully visible in this excerpt.
240 size_t Expression::hash() const {
241 if (auto lit
= literal()) {
242 if (lit
->is_scalar()) {
243 return lit
->scalar()->hash();
248 if (auto ref
= field_ref()) {
252 return CallNotNull(*this)->hash
;
// An expression is bound when it has a resolved type. A call additionally
// needs a dispatched kernel and arguments that are all bound (recursively).
255 bool Expression::IsBound() const {
256 if (type() == nullptr) return false;
258 if (auto call
= this->call()) {
259 if (call
->kernel
== nullptr) return false;
261 for (const Expression
& arg
: call
->arguments
) {
262 if (!arg
.IsBound()) return false;
// True for scalar literals, field refs, and calls whose arguments are all
// scalar expressions and whose function is of SCALAR kind.
// ExecuteScalarExpression rejects expressions for which this returns false.
// For unbound calls the function kind is looked up in the default registry.
269 bool Expression::IsScalarExpression() const {
270 if (auto lit
= literal()) {
271 return lit
->is_scalar();
274 if (field_ref()) return true;
276 auto call
= CallNotNull(*this);
278 for (const Expression
& arg
: call
->arguments
) {
279 if (!arg
.IsScalarExpression()) return false;
282 if (call
->function
) {
283 return call
->function
->kind() == compute::Function::SCALAR
;
286 // this expression is not bound; make a best guess based on
287 // the default function registry
288 if (auto function
= compute::GetFunctionRegistry()
289 ->GetFunction(call
->function_name
)
291 return function
->kind() == compute::Function::SCALAR
;
294 // unknown function or other error; conservatively return false
// True if this is a literal in which every value is null
// (null_count() == length()).
298 bool Expression::IsNullLiteral() const {
299 if (auto lit
= literal()) {
300 if (lit
->null_count() == lit
->length()) {
// Conservative satisfiability check for a filter expression.
// - Null-typed expressions and all-null literals are special-cased; their
//   branch bodies are not visible in this excerpt.
// - A boolean scalar literal is satisfiable exactly when its value is true.
// Presumably everything else is reported satisfiable (not visible — confirm).
308 bool Expression::IsSatisfiable() const {
309 if (type() && type()->id() == Type::NA
) {
313 if (auto lit
= literal()) {
314 if (lit
->null_count() == lit
->length()) {
318 if (lit
->is_scalar() && lit
->type()->id() == Type::BOOL
) {
319 return lit
->scalar_as
<BooleanScalar
>().value
;
328 // Produce a bound Expression from unbound Call and bound arguments.
//
// Resolves `call.function` via GetFunction(), then dispatches a kernel:
// - insert_implicit_casts == false: DispatchExact on the argument descrs.
// - insert_implicit_casts == true: DispatchBest, which may rewrite `descrs`.
//   Every argument whose descr changed is converted to the new type: literals
//   are cast eagerly, other arguments are wrapped in a bound "cast" call with
//   Safe CastOptions. A change of shape (scalar vs array) is rejected with
//   NotImplemented.
// Finally runs the kernel's init (if any) to create kernel_state, and
// resolves the output descr from the kernel signature.
329 Result
<Expression
> BindNonRecursive(Expression::Call call
, bool insert_implicit_casts
,
330 compute::ExecContext
* exec_context
) {
331 DCHECK(std::all_of(call
.arguments
.begin(), call
.arguments
.end(),
332 [](const Expression
& argument
) { return argument
.IsBound(); }));
334 auto descrs
= GetDescriptors(call
.arguments
);
335 ARROW_ASSIGN_OR_RAISE(call
.function
, GetFunction(call
, exec_context
));
337 if (!insert_implicit_casts
) {
338 ARROW_ASSIGN_OR_RAISE(call
.kernel
, call
.function
->DispatchExact(descrs
));
340 ARROW_ASSIGN_OR_RAISE(call
.kernel
, call
.function
->DispatchBest(&descrs
));
// DispatchBest may have changed descrs[i]; reconcile each argument with it.
342 for (size_t i
= 0; i
< descrs
.size(); ++i
) {
343 if (descrs
[i
] == call
.arguments
[i
].descr()) continue;
345 if (descrs
[i
].shape
!= call
.arguments
[i
].descr().shape
) {
346 return Status::NotImplemented(
347 "Automatic broadcasting of scalars arguments to arrays in ",
348 Expression(std::move(call
)).ToString());
// literals can be cast immediately instead of via a cast call
351 if (auto lit
= call
.arguments
[i
].literal()) {
352 ARROW_ASSIGN_OR_RAISE(Datum new_lit
, compute::Cast(*lit
, descrs
[i
].type
));
353 call
.arguments
[i
] = literal(std::move(new_lit
));
357 // construct an implicit cast Expression with which to replace this argument
358 Expression::Call implicit_cast
;
359 implicit_cast
.function_name
= "cast";
360 implicit_cast
.arguments
= {std::move(call
.arguments
[i
])};
361 implicit_cast
.options
= std::make_shared
<compute::CastOptions
>(
362 compute::CastOptions::Safe(descrs
[i
].type
));
364 ARROW_ASSIGN_OR_RAISE(
366 BindNonRecursive(std::move(implicit_cast
),
367 /*insert_implicit_casts=*/false, exec_context
));
// Kernel state (if the kernel needs any) is created once here at bind time.
371 compute::KernelContext
kernel_context(exec_context
);
372 if (call
.kernel
->init
) {
373 ARROW_ASSIGN_OR_RAISE(
375 call
.kernel
->init(&kernel_context
, {call
.kernel
, descrs
, call
.options
.get()}));
377 kernel_context
.SetState(call
.kernel_state
.get());
380 ARROW_ASSIGN_OR_RAISE(
381 call
.descr
, call
.kernel
->signature
->out_type().Resolve(&kernel_context
, descrs
));
383 return Expression(std::move(call
));
// Recursively binds `expr` against `in` (a DataType or a Schema):
// - a null exec_context is replaced by a default ExecContext living on this
//   stack frame for the duration of the call
// - literals are returned unchanged
// - field refs are resolved with FieldRef::FindOne; nested refs are rejected.
//   The bound parameter records the top-level index (path[0]), the field's
//   type, and the requested `shape`
// - calls bind every argument first, then BindNonRecursive with implicit
//   casts enabled
386 template <typename TypeOrSchema
>
387 Result
<Expression
> BindImpl(Expression expr
, const TypeOrSchema
& in
,
388 ValueDescr::Shape shape
, compute::ExecContext
* exec_context
) {
389 if (exec_context
== nullptr) {
390 compute::ExecContext exec_context
;
391 return BindImpl(std::move(expr
), in
, shape
, &exec_context
);
394 if (expr
.literal()) return expr
;
396 if (auto ref
= expr
.field_ref()) {
397 if (ref
->IsNested()) {
398 return Status::NotImplemented("nested field references");
401 ARROW_ASSIGN_OR_RAISE(auto path
, ref
->FindOne(in
));
403 auto bound
= *expr
.parameter();
404 bound
.index
= path
[0];
405 ARROW_ASSIGN_OR_RAISE(auto field
, path
.Get(in
));
406 bound
.descr
.type
= field
->type();
407 bound
.descr
.shape
= shape
;
408 return Expression
{std::move(bound
)};
411 auto call
= *CallNotNull(expr
);
412 for (auto& argument
: call
.arguments
) {
413 ARROW_ASSIGN_OR_RAISE(argument
,
414 BindImpl(std::move(argument
), in
, shape
, exec_context
));
416 return BindNonRecursive(std::move(call
),
417 /*insert_implicit_casts=*/true, exec_context
);
// Binds against a single value type, using the shape given in `in`.
// NOTE(review): dereferences in.type without a null check.
422 Result
<Expression
> Expression::Bind(const ValueDescr
& in
,
423 compute::ExecContext
* exec_context
) const {
424 return BindImpl(*this, *in
.type
, in
.shape
, exec_context
);
// Binds against a schema. Field references resolve to schema fields and are
// given ARRAY shape.
427 Result
<Expression
> Expression::Bind(const Schema
& in_schema
,
428 compute::ExecContext
* exec_context
) const {
429 return BindImpl(*this, in_schema
, ValueDescr::ARRAY
, exec_context
);
// Builds an ExecBatch with one value per field of `full_schema` from a
// possibly partial input:
// - RecordBatch: columns are matched by field name (GetOneOrNone). A column
//   whose type differs from the schema is cast with Safe options. Otherwise
//   (presumably when no column was found) a null scalar of the field's type
//   is used.
// - StructArray: converted to a RecordBatch and handled as above.
// - StructScalar: expanded to a length-1 array and batched; every resulting
//   array value is then converted back to a scalar.
// Any other input returns NotImplemented.
// NOTE(review): Datum::type() may be null for Datum kinds without a single
// type; `partial.type()->id()` below would then dereference null. Consider
// guarding.
432 Result
<ExecBatch
> MakeExecBatch(const Schema
& full_schema
, const Datum
& partial
) {
435 if (partial
.kind() == Datum::RECORD_BATCH
) {
436 const auto& partial_batch
= *partial
.record_batch();
437 out
.length
= partial_batch
.num_rows();
439 for (const auto& field
: full_schema
.fields()) {
440 ARROW_ASSIGN_OR_RAISE(auto column
,
441 FieldRef(field
->name()).GetOneOrNone(partial_batch
));
444 if (!column
->type()->Equals(field
->type())) {
445 // Referenced field was present but didn't have the expected type.
446 // This *should* be handled by readers, and will just be an error in the future.
447 ARROW_ASSIGN_OR_RAISE(
449 compute::Cast(column
, field
->type(), compute::CastOptions::Safe()));
450 column
= converted
.make_array();
452 out
.values
.emplace_back(std::move(column
));
454 out
.values
.emplace_back(MakeNullScalar(field
->type()));
460 // wasteful but useful for testing:
461 if (partial
.type()->id() == Type::STRUCT
) {
462 if (partial
.is_array()) {
463 ARROW_ASSIGN_OR_RAISE(auto partial_batch
,
464 RecordBatch::FromStructArray(partial
.make_array()));
466 return MakeExecBatch(full_schema
, partial_batch
);
469 if (partial
.is_scalar()) {
470 ARROW_ASSIGN_OR_RAISE(auto partial_array
,
471 MakeArrayFromScalar(*partial
.scalar(), 1));
472 ARROW_ASSIGN_OR_RAISE(auto out
, MakeExecBatch(full_schema
, partial_array
));
474 for (Datum
& value
: out
.values
) {
475 if (value
.is_scalar()) continue;
476 ARROW_ASSIGN_OR_RAISE(value
, value
.make_array()->GetScalar(0));
482 return Status::NotImplemented("MakeExecBatch from ", PrintDatum(partial
));
// Convenience overload: expands `partial_input` to a full batch for
// `full_schema` with MakeExecBatch, then evaluates `expr` on that batch.
485 Result
<Datum
> ExecuteScalarExpression(const Expression
& expr
, const Schema
& full_schema
,
486 const Datum
& partial_input
,
487 compute::ExecContext
* exec_context
) {
488 ARROW_ASSIGN_OR_RAISE(auto input
, MakeExecBatch(full_schema
, partial_input
));
489 return ExecuteScalarExpression(expr
, input
, exec_context
);
// Evaluates a bound scalar expression against `input`:
// - a null exec_context is replaced by a default ExecContext for this call
// - unbound or non-scalar expressions are rejected with Status::Invalid
// - literals evaluate to themselves
// - field refs read input[param->index]. Null-typed parameters yield a null
//   scalar; a type that differs from the bound type is an error
// - calls evaluate their arguments recursively, then run the bound kernel
//   through a scalar KernelExecutor, using the kernel_state created at bind
//   time
// NOTE(review): IsBound() and IsScalarExpression() are themselves recursive
// and are re-run at every recursion level (quadratic in expression depth).
// param->index is not range-checked against the batch's value count here.
492 Result
<Datum
> ExecuteScalarExpression(const Expression
& expr
, const ExecBatch
& input
,
493 compute::ExecContext
* exec_context
) {
494 if (exec_context
== nullptr) {
495 compute::ExecContext exec_context
;
496 return ExecuteScalarExpression(expr
, input
, &exec_context
);
499 if (!expr
.IsBound()) {
500 return Status::Invalid("Cannot Execute unbound expression.");
503 if (!expr
.IsScalarExpression()) {
504 return Status::Invalid(
505 "ExecuteScalarExpression cannot Execute non-scalar expression ", expr
.ToString());
508 if (auto lit
= expr
.literal()) return *lit
;
510 if (auto param
= expr
.parameter()) {
511 if (param
->descr
.type
->id() == Type::NA
) {
512 return MakeNullScalar(null());
515 const Datum
& field
= input
[param
->index
];
516 if (!field
.type()->Equals(param
->descr
.type
)) {
517 return Status::Invalid("Referenced field ", expr
.ToString(), " was ",
518 field
.type()->ToString(), " but should have been ",
519 param
->descr
.type
->ToString());
525 auto call
= CallNotNull(expr
);
527 std::vector
<Datum
> arguments(call
->arguments
.size());
528 for (size_t i
= 0; i
< arguments
.size(); ++i
) {
529 ARROW_ASSIGN_OR_RAISE(
530 arguments
[i
], ExecuteScalarExpression(call
->arguments
[i
], input
, exec_context
));
533 auto executor
= compute::detail::KernelExecutor::MakeScalar();
535 compute::KernelContext
kernel_context(exec_context
);
536 kernel_context
.SetState(call
->kernel_state
.get());
538 auto kernel
= call
->kernel
;
539 auto descrs
= GetDescriptors(arguments
);
540 auto options
= call
->options
.get();
541 RETURN_NOT_OK(executor
->Init(&kernel_context
, {kernel
, descrs
, options
}));
543 compute::detail::DatumAccumulator listener
;
544 RETURN_NOT_OK(executor
->Execute(arguments
, &listener
));
545 const auto out
= executor
->WrapResults(arguments
, listener
.values());
547 DCHECK_OK(executor
->CheckResultType(out
, call
->function_name
.c_str()));
// For a two-argument call, returns two pairs of references: one starting
// with arguments[0] and one starting with arguments[1]. Presumably each is
// paired with the other argument (the second elements are not visible here).
// This lets simplification rules be written once for both operand orders.
// The references point into `call.arguments`, which must outlive the result.
554 std::array
<std::pair
<const Expression
&, const Expression
&>, 2>
555 ArgumentsAndFlippedArguments(const Expression::Call
& call
) {
556 DCHECK_EQ(call
.arguments
.size(), 2);
557 return {std::pair
<const Expression
&, const Expression
&>{call
.arguments
[0],
559 std::pair
<const Expression
&, const Expression
&>{call
.arguments
[1],
// Left fold over [begin, end): returns nullopt for an empty range, otherwise
// bin_op(...bin_op(bin_op(e0, e1), e2)..., eN).
// Elements are std::move'd, so they are moved from when the iterator yields
// mutable references. For const ranges (as in and_/or_) they are copied.
563 template <typename BinOp
, typename It
,
564 typename Out
= typename
std::iterator_traits
<It
>::value_type
>
565 util::optional
<Out
> FoldLeft(It begin
, It end
, const BinOp
& bin_op
) {
566 if (begin
== end
) return util::nullopt
;
568 Out folded
= std::move(*begin
++);
569 while (begin
!= end
) {
570 folded
= bin_op(std::move(folded
), std::move(*begin
++));
// Returns the bound kernel's null-handling mode for SCALAR functions, and
// nullopt for unbound calls or non-scalar functions.
// NOTE(review): assumes `kernel` is non-null whenever `function` is set.
575 util::optional
<compute::NullHandling::type
> GetNullHandling(
576 const Expression::Call
& call
) {
577 if (call
.function
&& call
.function
->kind() == compute::Function::SCALAR
) {
578 return static_cast<const compute::ScalarKernel
*>(call
.kernel
)->null_handling
;
580 return util::nullopt
;
// Collects the field references in `expr`, visiting call arguments
// depth-first, left to right. No deduplication is performed in the visible
// lines. The field-ref branch body is not visible here.
585 std::vector
<FieldRef
> FieldsInExpression(const Expression
& expr
) {
586 if (expr
.literal()) return {};
588 if (auto ref
= expr
.field_ref()) {
592 std::vector
<FieldRef
> fields
;
593 for (const Expression
& arg
: CallNotNull(expr
)->arguments
) {
594 auto argument_fields
= FieldsInExpression(arg
);
595 std::move(argument_fields
.begin(), argument_fields
.end(), std::back_inserter(fields
));
// True if any field reference occurs anywhere in `expr`. Stops at the first
// one found.
600 bool ExpressionHasFieldRefs(const Expression
& expr
) {
601 if (expr
.literal()) return false;
603 if (expr
.field_ref()) return true;
605 for (const Expression
& arg
: CallNotNull(expr
)->arguments
) {
606 if (ExpressionHasFieldRefs(arg
)) return true;
// Constant folding applied to each call in `expr`:
// - calls whose arguments are all literals are evaluated immediately
// - calls with INTERSECTION null handling and a null-literal argument are
//   resolved early (the replacement is on lines not visible here)
// - and_kleene: true & x -> x, false & x -> false, x & x -> x
// - or_kleene:  false | x -> x, true | x -> true,  x | x -> x
// The boolean rules are tried for both operand orders via
// ArgumentsAndFlippedArguments.
611 Result
<Expression
> FoldConstants(Expression expr
) {
613 std::move(expr
), [](Expression expr
) { return expr
; },
614 [](Expression expr
, ...) -> Result
<Expression
> {
615 auto call
= CallNotNull(expr
);
616 if (std::all_of(call
->arguments
.begin(), call
->arguments
.end(),
617 [](const Expression
& argument
) { return argument
.literal(); })) {
618 // all arguments are literal; we can evaluate this subexpression *now*
619 static const ExecBatch ignored_input
= ExecBatch
{};
620 ARROW_ASSIGN_OR_RAISE(Datum constant
,
621 ExecuteScalarExpression(expr
, ignored_input
));
623 return literal(std::move(constant
));
626 // XXX the following should probably be in a registry of passes instead
629 if (GetNullHandling(*call
) == compute::NullHandling::INTERSECTION
) {
630 // kernels which always produce intersected validity can be resolved
631 // to null *now* if any of their inputs is a null literal
632 for (const auto& argument
: call
->arguments
) {
633 if (argument
.IsNullLiteral()) {
639 if (call
->function_name
== "and_kleene") {
640 for (auto args
: ArgumentsAndFlippedArguments(*call
)) {
// true and x == x
642 if (args
.first
== literal(true)) return args
.second
;
644 // false and x == false
645 if (args
.first
== literal(false)) return args
.first
;
// x and x == x
648 if (args
.first
== args
.second
) return args
.first
;
653 if (call
->function_name
== "or_kleene") {
654 for (auto args
: ArgumentsAndFlippedArguments(*call
)) {
// false or x == x
656 if (args
.first
== literal(false)) return args
.second
;
// true or x == true
659 if (args
.first
== literal(true)) return args
.first
;
// x or x == x
662 if (args
.first
== args
.second
) return args
.first
;
// Splits a guarantee into its conjunction members. An and_kleene call is
// flattened (FlattenedAssociativeChain) into its fringe of operands; any
// other expression is returned as the single member.
673 std::vector
<Expression
> GuaranteeConjunctionMembers(
674 const Expression
& guaranteed_true_predicate
) {
675 auto guarantee
= guaranteed_true_predicate
.call();
676 if (!guarantee
|| guarantee
->function_name
!= "and_kleene") {
677 return {guaranteed_true_predicate
};
679 return FlattenedAssociativeChain(guaranteed_true_predicate
).fringe
;
682 // Conjunction members which are represented in known_values are erased from
683 // conjunction_members
//
// Recognized members:
// - equal(field_ref, literal)  -> field mapped to the literal
// - is_null(field_ref)         -> field mapped to a NullScalar
// std::partition moves all unrecognized members to the front, so the
// recognized ones form the tail [unconsumed_end, end), which is recorded and
// then erased. known_values->emplace() keeps the first value seen for a field,
// so contradictory guarantees (x == 1 and x == 2) silently keep the first.
// NOTE(review): part of the is_null predicate is not visible in this excerpt;
// confirm whether NullOptions::nan_is_null is taken into account.
684 Status
ExtractKnownFieldValuesImpl(
685 std::vector
<Expression
>* conjunction_members
,
686 std::unordered_map
<FieldRef
, Datum
, FieldRef::Hash
>* known_values
) {
687 auto unconsumed_end
=
688 std::partition(conjunction_members
->begin(), conjunction_members
->end(),
689 [](const Expression
& expr
) {
690 // search for equality conditions between a field and a literal
691 auto call
= expr
.call();
692 if (!call
) return true;
694 if (call
->function_name
== "equal") {
695 auto ref
= call
->arguments
[0].field_ref();
696 auto lit
= call
->arguments
[1].literal();
697 return !(ref
&& lit
);
700 if (call
->function_name
== "is_null") {
701 auto ref
= call
->arguments
[0].field_ref();
708 for (auto it
= unconsumed_end
; it
!= conjunction_members
->end(); ++it
) {
709 auto call
= CallNotNull(*it
);
711 if (call
->function_name
== "equal") {
712 auto ref
= call
->arguments
[0].field_ref();
713 auto lit
= call
->arguments
[1].literal();
714 known_values
->emplace(*ref
, *lit
);
715 } else if (call
->function_name
== "is_null") {
716 auto ref
= call
->arguments
[0].field_ref();
717 known_values
->emplace(*ref
, Datum(std::make_shared
<NullScalar
>()));
721 conjunction_members
->erase(unconsumed_end
, conjunction_members
->end());
// Public entry point: collects field -> value knowledge from the conjunction
// members of `guaranteed_true_predicate`. Only equal(field, literal) and
// is_null(field) members are recognized; see ExtractKnownFieldValuesImpl.
728 Result
<KnownFieldValues
> ExtractKnownFieldValues(
729 const Expression
& guaranteed_true_predicate
) {
730 auto conjunction_members
= GuaranteeConjunctionMembers(guaranteed_true_predicate
);
731 KnownFieldValues known_values
;
732 RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members
, &known_values
.map
));
// Replaces bound field references that have an entry in `known_values` with
// literals.
// - If the known value's descr differs from the field's, it is cast to the
//   field type.
// - For dictionary-typed fields, a non-dictionary value is first cast to the
//   dictionary's value type. A scalar value is then wrapped in a one-entry
//   dictionary scalar with int32 index 0 before the final cast.
// Returns Invalid for unbound expressions.
736 Result
<Expression
> ReplaceFieldsWithKnownValues(const KnownFieldValues
& known_values
,
738 if (!expr
.IsBound()) {
739 return Status::Invalid(
740 "ReplaceFieldsWithKnownValues called on an unbound Expression");
745 [&known_values
](Expression expr
) -> Result
<Expression
> {
746 if (auto ref
= expr
.field_ref()) {
747 auto it
= known_values
.map
.find(*ref
);
748 if (it
!= known_values
.map
.end()) {
749 Datum lit
= it
->second
;
750 if (lit
.descr() == expr
.descr()) return literal(std::move(lit
));
751 // type mismatch, try casting the known value to the correct type
753 if (expr
.type()->id() == Type::DICTIONARY
&&
754 lit
.type()->id() != Type::DICTIONARY
) {
755 // the known value must be dictionary encoded
757 const auto& dict_type
= checked_cast
<const DictionaryType
&>(*expr
.type());
758 if (!lit
.type()->Equals(dict_type
.value_type())) {
759 ARROW_ASSIGN_OR_RAISE(lit
, compute::Cast(lit
, dict_type
.value_type()));
762 if (lit
.is_scalar()) {
763 ARROW_ASSIGN_OR_RAISE(auto dictionary
,
764 MakeArrayFromScalar(*lit
.scalar(), 1));
766 lit
= Datum
{DictionaryScalar::Make(MakeScalar
<int32_t>(0),
767 std::move(dictionary
))};
771 ARROW_ASSIGN_OR_RAISE(lit
, compute::Cast(lit
, expr
.type()));
772 return literal(std::move(lit
));
777 [](Expression expr
, ...) { return expr
; });
// True if `call` names a binary function treated as associative and
// commutative. Canonicalize flattens, reorders and re-folds chains of these.
// NOTE(review): floating-point "add"/"multiply" are not strictly associative,
// and reordering "*_checked" operands can change whether an intermediate
// overflow is reported. Confirm reassociation is acceptable for those types.
// The set is never modified and could be declared const.
782 bool IsBinaryAssociativeCommutative(const Expression::Call
& call
) {
783 static std::unordered_set
<std::string
> binary_associative_commutative
{
784 "and", "or", "and_kleene", "or_kleene", "xor",
785 "multiply", "add", "multiply_checked", "add_checked"};
787 auto it
= binary_associative_commutative
.find(call
.function_name
);
788 return it
!= binary_associative_commutative
.end();
// Rewrites `expr` into a canonical form so that equivalent spellings compare
// Equal:
// - chains of a binary associative+commutative function are flattened,
//   stable-sorted so null literals come first and other literals next (the
//   priority of non-literals is on a line not visible here), then re-folded
//   left-deep
// - a comparison with a literal LHS and a non-literal RHS has its operands
//   swapped and is replaced by Comparison::GetFlipped, then re-bound, so
//   literals end up on the RHS (DirectComparisonSimplification expects that)
// A null exec_context is replaced by a default ExecContext for this call.
// NOTE(review): `flipped_call` is passed by copy to BindNonRecursive; it could
// be std::move'd since it is not used afterwards.
793 Result
<Expression
> Canonicalize(Expression expr
, compute::ExecContext
* exec_context
) {
794 if (exec_context
== nullptr) {
795 compute::ExecContext exec_context
;
796 return Canonicalize(std::move(expr
), &exec_context
);
799 // If potentially reconstructing more deeply than a call's immediate arguments
800 // (for example, when reorganizing an associative chain), add expressions to this set to
801 // avoid unnecessary work
803 std::unordered_set
<Expression
, Expression::Hash
> set_
;
805 bool operator()(const Expression
& expr
) const {
806 return set_
.find(expr
) != set_
.end();
809 void Add(std::vector
<Expression
> exprs
) {
810 std::move(exprs
.begin(), exprs
.end(), std::inserter(set_
, set_
.end()));
812 } AlreadyCanonicalized
;
816 [&AlreadyCanonicalized
, exec_context
](Expression expr
) -> Result
<Expression
> {
817 auto call
= expr
.call();
818 if (!call
) return expr
;
820 if (AlreadyCanonicalized(expr
)) return expr
;
822 if (IsBinaryAssociativeCommutative(*call
)) {
824 int Priority(const Expression
& operand
) const {
825 // order literals first, starting with nulls
826 if (operand
.IsNullLiteral()) return 0;
827 if (operand
.literal()) return 1;
830 bool operator()(const Expression
& l
, const Expression
& r
) const {
831 return Priority(l
) < Priority(r
);
// Already left-folded and in canonical order: remember the sub-expressions
// so they are not revisited.
835 FlattenedAssociativeChain
chain(expr
);
836 if (chain
.was_left_folded
&&
837 std::is_sorted(chain
.fringe
.begin(), chain
.fringe
.end(),
838 CanonicalOrdering
)) {
839 AlreadyCanonicalized
.Add(std::move(chain
.exprs
));
843 std::stable_sort(chain
.fringe
.begin(), chain
.fringe
.end(), CanonicalOrdering
);
845 // fold the chain back up
847 FoldLeft(chain
.fringe
.begin(), chain
.fringe
.end(),
848 [call
, &AlreadyCanonicalized
](Expression l
, Expression r
) {
849 auto canonicalized_call
= *call
;
850 canonicalized_call
.arguments
= {std::move(l
), std::move(r
)};
851 Expression
expr(std::move(canonicalized_call
));
852 AlreadyCanonicalized
.Add({expr
});
855 return std::move(*folded
);
858 if (auto cmp
= Comparison::Get(call
->function_name
)) {
859 if (call
->arguments
[0].literal() && !call
->arguments
[1].literal()) {
860 // ensure that literals are on comparisons' RHS
861 auto flipped_call
= *call
;
863 std::swap(flipped_call
.arguments
[0], flipped_call
.arguments
[1]);
864 flipped_call
.function_name
=
865 Comparison::GetName(Comparison::GetFlipped(*cmp
));
867 return BindNonRecursive(flipped_call
,
868 /*insert_implicit_casts=*/false, exec_context
);
874 [](Expression expr
, ...) { return expr
; });
// Simplifies comparisons in `expr` using one comparison guarantee of the form
// `lhs <cmp> literal`. A comparison call in `expr` is eligible when its LHS
// (with order-preserving casts stripped) equals the guarantee's LHS and both
// RHS values are scalar literals. The two RHS scalars are then compared:
// - equal RHS: the call becomes true if the guarantee's comparison set is a
//   subset of the call's, and false if the two sets are disjoint
// - different RHS: depending on the ordering, the guarantee may imply the
//   call (true), contradict it (false), or be unusable
// Calls that cannot be decided are returned unchanged.
// NOTE(review): the NA result of Comparison::Execute (e.g. null RHS values)
// is only DCHECKed. In release builds it would fall through to the ordering
// logic; confirm null literals cannot reach this point.
879 Result
<Expression
> DirectComparisonSimplification(Expression expr
,
880 const Expression::Call
& guarantee
) {
882 std::move(expr
), [](Expression expr
) { return expr
; },
883 [&guarantee
](Expression expr
, ...) -> Result
<Expression
> {
884 auto call
= expr
.call();
885 if (!call
) return expr
;
887 // Ensure both calls are comparisons with equal LHS and scalar RHS
888 auto cmp
= Comparison::Get(expr
);
889 auto cmp_guarantee
= Comparison::Get(guarantee
.function_name
);
891 if (!cmp
) return expr
;
892 if (!cmp_guarantee
) return expr
;
894 const auto& lhs
= Comparison::StripOrderPreservingCasts(call
->arguments
[0]);
895 const auto& guarantee_lhs
= guarantee
.arguments
[0];
896 if (lhs
!= guarantee_lhs
) return expr
;
898 auto rhs
= call
->arguments
[1].literal();
899 auto guarantee_rhs
= guarantee
.arguments
[1].literal();
901 if (!rhs
) return expr
;
902 if (!rhs
->is_scalar()) return expr
;
904 if (!guarantee_rhs
) return expr
;
905 if (!guarantee_rhs
->is_scalar()) return expr
;
907 ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs
,
908 Comparison::Execute(*rhs
, *guarantee_rhs
));
909 DCHECK_NE(cmp_rhs_guarantee_rhs
, Comparison::NA
);
911 if (cmp_rhs_guarantee_rhs
== Comparison::EQUAL
) {
912 // RHS of filter is equal to RHS of guarantee
914 if ((*cmp
& *cmp_guarantee
) == *cmp_guarantee
) {
915 // guarantee is a subset of filter, so all data will be included
916 // x > 1, x >= 1, x != 1 guaranteed by x > 1
917 return literal(true);
920 if ((*cmp
& *cmp_guarantee
) == 0) {
921 // guarantee disjoint with filter, so all data will be excluded
922 // x > 1, x >= 1, x != 1 unsatisfiable if x == 1
923 return literal(false);
929 if (*cmp_guarantee
& cmp_rhs_guarantee_rhs
) {
930 // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3
934 if (*cmp
& Comparison::GetFlipped(cmp_rhs_guarantee_rhs
)) {
935 // x > 1, x >= 1, x != 1 guaranteed by x >= 3
936 return literal(true);
938 // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3
939 return literal(false);
// Simplifies `expr` assuming `guaranteed_true_predicate` holds:
// 1. fields with known values (equal/is_null members of the guarantee's
//    conjunction) are replaced by literals
// 2. the result is canonicalized and constant-folded
// 3. each remaining comparison member whose RHS is a literal is applied with
//    DirectComparisonSimplification; after any change the expression is
//    canonicalized and folded again
// NOTE(review): guarantee members are not canonicalized here, so a member
// written with its literal on the LHS (e.g. `3 < x`) fails the arguments[1]
// literal check and is skipped.
946 Result
<Expression
> SimplifyWithGuarantee(Expression expr
,
947 const Expression
& guaranteed_true_predicate
) {
948 auto conjunction_members
= GuaranteeConjunctionMembers(guaranteed_true_predicate
);
950 KnownFieldValues known_values
;
951 RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members
, &known_values
.map
));
953 ARROW_ASSIGN_OR_RAISE(expr
,
954 ReplaceFieldsWithKnownValues(known_values
, std::move(expr
)));
956 auto CanonicalizeAndFoldConstants
= [&expr
] {
957 ARROW_ASSIGN_OR_RAISE(expr
, Canonicalize(std::move(expr
)));
958 ARROW_ASSIGN_OR_RAISE(expr
, FoldConstants(std::move(expr
)));
961 RETURN_NOT_OK(CanonicalizeAndFoldConstants());
963 for (const auto& guarantee
: conjunction_members
) {
964 if (Comparison::Get(guarantee
) && guarantee
.call()->arguments
[1].literal()) {
965 ARROW_ASSIGN_OR_RAISE(
966 auto simplified
, DirectComparisonSimplification(expr
, *CallNotNull(guarantee
)));
968 if (Identical(simplified
, expr
)) continue;
970 expr
= std::move(simplified
);
971 RETURN_NOT_OK(CanonicalizeAndFoldConstants());
978 // Serialization is accomplished by converting expressions to KeyValueMetadata and storing
979 // this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its
980 // columns. Finally, the RecordBatch is written to an IPC file.
//
// The metadata is a flat, pre-order sequence of (key, value) pairs:
//   "literal"   -> column index of the scalar (stored as a length-1 array)
//   "field_ref" -> field name
//   "call"      -> function name, followed by its arguments, then optionally
//   "options"   -> column index of the StructScalar from
//                  FunctionOptionsToStructScalar, and finally
//   "end"       -> function name
// Non-scalar literals and field refs that are not plain names return
// NotImplemented.
981 Result
<std::shared_ptr
<Buffer
>> Serialize(const Expression
& expr
) {
983 std::shared_ptr
<KeyValueMetadata
> metadata_
= std::make_shared
<KeyValueMetadata
>();
984 ArrayVector columns_
;
986 Result
<std::string
> AddScalar(const Scalar
& scalar
) {
987 auto ret
= columns_
.size();
988 ARROW_ASSIGN_OR_RAISE(auto array
, MakeArrayFromScalar(scalar
, 1));
989 columns_
.push_back(std::move(array
));
990 return std::to_string(ret
);
993 Status
Visit(const Expression
& expr
) {
994 if (auto lit
= expr
.literal()) {
995 if (!lit
->is_scalar()) {
996 return Status::NotImplemented("Serialization of non-scalar literals");
998 ARROW_ASSIGN_OR_RAISE(auto value
, AddScalar(*lit
->scalar()));
999 metadata_
->Append("literal", std::move(value
));
1000 return Status::OK();
1003 if (auto ref
= expr
.field_ref()) {
1005 return Status::NotImplemented("Serialization of non-name field_refs");
1007 metadata_
->Append("field_ref", *ref
->name());
1008 return Status::OK();
1011 auto call
= CallNotNull(expr
);
1012 metadata_
->Append("call", call
->function_name
);
1014 for (const auto& argument
: call
->arguments
) {
1015 RETURN_NOT_OK(Visit(argument
));
1018 if (call
->options
) {
1019 ARROW_ASSIGN_OR_RAISE(auto options_scalar
,
1020 internal::FunctionOptionsToStructScalar(*call
->options
));
1021 ARROW_ASSIGN_OR_RAISE(auto value
, AddScalar(*options_scalar
));
1022 metadata_
->Append("options", std::move(value
));
1025 metadata_
->Append("end", call
->function_name
);
1026 return Status::OK();
1029 Result
<std::shared_ptr
<RecordBatch
>> operator()(const Expression
& expr
) {
1030 RETURN_NOT_OK(Visit(expr
));
1031 FieldVector
fields(columns_
.size());
1032 for (size_t i
= 0; i
< fields
.size(); ++i
) {
1033 fields
[i
] = field("", columns_
[i
]->type());
1035 return RecordBatch::Make(schema(std::move(fields
), std::move(metadata_
)), 1,
1036 std::move(columns_
));
1040 ARROW_ASSIGN_OR_RAISE(auto batch
, ToRecordBatch(expr
));
1041 ARROW_ASSIGN_OR_RAISE(auto stream
, io::BufferOutputStream::Create());
1042 ARROW_ASSIGN_OR_RAISE(auto writer
, ipc::MakeFileWriter(stream
, batch
->schema()));
1043 RETURN_NOT_OK(writer
->WriteRecordBatch(*batch
));
1044 RETURN_NOT_OK(writer
->Close());
1045 return stream
->Finish();
// Inverse of Serialize(). Reads the single-row RecordBatch from the IPC file
// in `buffer` and rebuilds the expression from its schema metadata with a
// recursive-descent reader (FromRecordBatch::GetOne).
// Returns Invalid for:
// - missing metadata
// - a batch that does not have exactly one row
// - unparseable or out-of-range column indices
// - unknown keys
// - a sequence that ends before GetOne starts
// NOTE(review): these matter if the buffer is untrusted input.
// - The argument loop reads metadata().key(index_) without first checking
//   index_ < metadata().size(), so a call missing its "end" may index past
//   the end.
// - GetScalar rejects column_index >= num_columns() but not negative values
//   (if ParseValue accepts a leading '-').
// - The options column is checked_cast to StructScalar without verifying its
//   type.
1048 Result
<Expression
> Deserialize(std::shared_ptr
<Buffer
> buffer
) {
1049 io::BufferReader
stream(std::move(buffer
));
1050 ARROW_ASSIGN_OR_RAISE(auto reader
, ipc::RecordBatchFileReader::Open(&stream
));
1051 ARROW_ASSIGN_OR_RAISE(auto batch
, reader
->ReadRecordBatch(0));
1052 if (batch
->schema()->metadata() == nullptr) {
1053 return Status::Invalid("serialized Expression's batch repr had null metadata");
1055 if (batch
->num_rows() != 1) {
1056 return Status::Invalid(
1057 "serialized Expression's batch repr was not a single row - had ",
1061 struct FromRecordBatch
{
1062 const RecordBatch
& batch_
;
1065 const KeyValueMetadata
& metadata() { return *batch_
.schema()->metadata(); }
// Parses a metadata value as a column index and returns that column's
// single value as a Scalar.
1067 Result
<std::shared_ptr
<Scalar
>> GetScalar(const std::string
& i
) {
1068 int32_t column_index
;
1069 if (!::arrow::internal::ParseValue
<Int32Type
>(i
.data(), i
.length(),
1071 return Status::Invalid("Couldn't parse column_index");
1073 if (column_index
>= batch_
.num_columns()) {
1074 return Status::Invalid("column_index out of bounds");
1076 return batch_
.column(column_index
)->GetScalar(0);
// Decodes one expression starting at index_, recursing for call arguments
// until the matching "end" (or "options") key.
1079 Result
<Expression
> GetOne() {
1080 if (index_
>= metadata().size()) {
1081 return Status::Invalid("unterminated serialized Expression");
1084 const std::string
& key
= metadata().key(index_
);
1085 const std::string
& value
= metadata().value(index_
);
1088 if (key
== "literal") {
1089 ARROW_ASSIGN_OR_RAISE(auto scalar
, GetScalar(value
));
1090 return literal(std::move(scalar
));
1093 if (key
== "field_ref") {
1094 return field_ref(value
);
1097 if (key
!= "call") {
1098 return Status::Invalid("Unrecognized serialized Expression key ", key
);
1101 std::vector
<Expression
> arguments
;
1102 while (metadata().key(index_
) != "end") {
1103 if (metadata().key(index_
) == "options") {
1104 ARROW_ASSIGN_OR_RAISE(auto options_scalar
, GetScalar(metadata().value(index_
)));
1105 std::shared_ptr
<compute::FunctionOptions
> options
;
1106 if (options_scalar
) {
1107 ARROW_ASSIGN_OR_RAISE(
1108 options
, internal::FunctionOptionsFromStructScalar(
1109 checked_cast
<const StructScalar
&>(*options_scalar
)));
1111 auto expr
= call(value
, std::move(arguments
), std::move(options
));
1116 ARROW_ASSIGN_OR_RAISE(auto argument
, GetOne());
1117 arguments
.push_back(std::move(argument
));
1121 return call(value
, std::move(arguments
));
1125 return FromRecordBatch
{*batch
, 0}.GetOne();
// Builds a make_struct call whose fields are `values`, named by `names`.
// NOTE(review): the two vectors are expected to have the same length; this
// is not checked here, and ToString() walks field_names while advancing
// through the arguments.
1128 Expression
project(std::vector
<Expression
> values
, std::vector
<std::string
> names
) {
1129 return call("make_struct", std::move(values
),
1130 compute::MakeStructOptions
{std::move(names
)});
// Builds an unbound "equal" comparison call on (lhs, rhs).
1133 Expression
equal(Expression lhs
, Expression rhs
) {
1134 return call("equal", {std::move(lhs
), std::move(rhs
)});
// Builds an unbound "not_equal" comparison call on (lhs, rhs).
1137 Expression
not_equal(Expression lhs
, Expression rhs
) {
1138 return call("not_equal", {std::move(lhs
), std::move(rhs
)});
// Builds an unbound "less" comparison call on (lhs, rhs).
1141 Expression
less(Expression lhs
, Expression rhs
) {
1142 return call("less", {std::move(lhs
), std::move(rhs
)});
// Builds an unbound "less_equal" comparison call on (lhs, rhs).
1145 Expression
less_equal(Expression lhs
, Expression rhs
) {
1146 return call("less_equal", {std::move(lhs
), std::move(rhs
)});
// Builds an unbound "greater" comparison call on (lhs, rhs).
1149 Expression
greater(Expression lhs
, Expression rhs
) {
1150 return call("greater", {std::move(lhs
), std::move(rhs
)});
// Builds an unbound "greater_equal" comparison call on (lhs, rhs).
1153 Expression
greater_equal(Expression lhs
, Expression rhs
) {
1154 return call("greater_equal", {std::move(lhs
), std::move(rhs
)});
// Builds an unbound "is_null" call; `nan_is_null` is passed in NullOptions.
// NOTE(review): std::move on a bool is just a copy and could be dropped.
1157 Expression
is_null(Expression lhs
, bool nan_is_null
) {
1158 return call("is_null", {std::move(lhs
)}, compute::NullOptions(std::move(nan_is_null
)));
// Builds an unbound "is_valid" call on `lhs`.
1161 Expression
is_valid(Expression lhs
) { return call("is_valid", {std::move(lhs
)}); }
// Builds an unbound "and_kleene" call on (lhs, rhs).
1163 Expression
and_(Expression lhs
, Expression rhs
) {
1164 return call("and_kleene", {std::move(lhs
), std::move(rhs
)});
// Conjunction of all `operands`, left-folded with and_() into nested
// and_kleene calls. An empty list yields literal(true).
1167 Expression
and_(const std::vector
<Expression
>& operands
) {
1168 auto folded
= FoldLeft
<Expression(Expression
, Expression
)>(operands
.begin(),
1169 operands
.end(), and_
);
1171 return std::move(*folded
);
1173 return literal(true);
// Builds an unbound "or_kleene" call on (lhs, rhs).
1176 Expression
or_(Expression lhs
, Expression rhs
) {
1177 return call("or_kleene", {std::move(lhs
), std::move(rhs
)});
// Disjunction of all `operands`, left-folded with or_() into nested
// or_kleene calls. An empty list yields literal(false).
1180 Expression
or_(const std::vector
<Expression
>& operands
) {
1182 FoldLeft
<Expression(Expression
, Expression
)>(operands
.begin(), operands
.end(), or_
);
1184 return std::move(*folded
);
1186 return literal(false);
// Builds an unbound "invert" call on `operand`; this is the factory behind
// not_().
1189 Expression
not_(Expression operand
) { return call("invert", {std::move(operand
)}); }
1191 } // namespace compute
1192 } // namespace arrow