]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/compute/exec/expression.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / exec / expression.cc
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/compute/exec/expression.h"
19
20 #include <unordered_map>
21 #include <unordered_set>
22
23 #include "arrow/chunked_array.h"
24 #include "arrow/compute/api_vector.h"
25 #include "arrow/compute/exec/expression_internal.h"
26 #include "arrow/compute/exec_internal.h"
27 #include "arrow/compute/function_internal.h"
28 #include "arrow/io/memory.h"
29 #include "arrow/ipc/reader.h"
30 #include "arrow/ipc/writer.h"
31 #include "arrow/util/hash_util.h"
32 #include "arrow/util/key_value_metadata.h"
33 #include "arrow/util/logging.h"
34 #include "arrow/util/optional.h"
35 #include "arrow/util/string.h"
36 #include "arrow/util/value_parsing.h"
37
38 namespace arrow {
39
40 using internal::checked_cast;
41 using internal::checked_pointer_cast;
42
43 namespace compute {
44
45 void Expression::Call::ComputeHash() {
46 hash = std::hash<std::string>{}(function_name);
47 for (const auto& arg : arguments) {
48 arrow::internal::hash_combine(hash, arg.hash());
49 }
50 }
51
52 Expression::Expression(Call call) {
53 call.ComputeHash();
54 impl_ = std::make_shared<Impl>(std::move(call));
55 }
56
57 Expression::Expression(Datum literal)
58 : impl_(std::make_shared<Impl>(std::move(literal))) {}
59
60 Expression::Expression(Parameter parameter)
61 : impl_(std::make_shared<Impl>(std::move(parameter))) {}
62
63 Expression literal(Datum lit) { return Expression(std::move(lit)); }
64
65 Expression field_ref(FieldRef ref) {
66 return Expression(Expression::Parameter{std::move(ref), ValueDescr{}, -1});
67 }
68
69 Expression call(std::string function, std::vector<Expression> arguments,
70 std::shared_ptr<compute::FunctionOptions> options) {
71 Expression::Call call;
72 call.function_name = std::move(function);
73 call.arguments = std::move(arguments);
74 call.options = std::move(options);
75 return Expression(std::move(call));
76 }
77
78 const Datum* Expression::literal() const { return util::get_if<Datum>(impl_.get()); }
79
80 const Expression::Parameter* Expression::parameter() const {
81 return util::get_if<Parameter>(impl_.get());
82 }
83
84 const FieldRef* Expression::field_ref() const {
85 if (auto parameter = this->parameter()) {
86 return &parameter->ref;
87 }
88 return nullptr;
89 }
90
91 const Expression::Call* Expression::call() const {
92 return util::get_if<Call>(impl_.get());
93 }
94
95 ValueDescr Expression::descr() const {
96 if (impl_ == nullptr) return {};
97
98 if (auto lit = literal()) {
99 return lit->descr();
100 }
101
102 if (auto parameter = this->parameter()) {
103 return parameter->descr;
104 }
105
106 return CallNotNull(*this)->descr;
107 }
108
109 namespace {
110
111 std::string PrintDatum(const Datum& datum) {
112 if (datum.is_scalar()) {
113 if (!datum.scalar()->is_valid) return "null";
114
115 switch (datum.type()->id()) {
116 case Type::STRING:
117 case Type::LARGE_STRING:
118 return '"' +
119 Escape(util::string_view(*datum.scalar_as<BaseBinaryScalar>().value)) +
120 '"';
121
122 case Type::BINARY:
123 case Type::FIXED_SIZE_BINARY:
124 case Type::LARGE_BINARY:
125 return '"' + datum.scalar_as<BaseBinaryScalar>().value->ToHexString() + '"';
126
127 default:
128 break;
129 }
130
131 return datum.scalar()->ToString();
132 }
133 return datum.ToString();
134 }
135
136 } // namespace
137
138 std::string Expression::ToString() const {
139 if (auto lit = literal()) {
140 return PrintDatum(*lit);
141 }
142
143 if (auto ref = field_ref()) {
144 if (auto name = ref->name()) {
145 return *name;
146 }
147 if (auto path = ref->field_path()) {
148 return path->ToString();
149 }
150 return ref->ToString();
151 }
152
153 auto call = CallNotNull(*this);
154 auto binary = [&](std::string op) {
155 return "(" + call->arguments[0].ToString() + " " + op + " " +
156 call->arguments[1].ToString() + ")";
157 };
158
159 if (auto cmp = Comparison::Get(call->function_name)) {
160 return binary(Comparison::GetOp(*cmp));
161 }
162
163 constexpr util::string_view kleene = "_kleene";
164 if (util::string_view{call->function_name}.ends_with(kleene)) {
165 auto op = call->function_name.substr(0, call->function_name.size() - kleene.size());
166 return binary(std::move(op));
167 }
168
169 if (auto options = GetMakeStructOptions(*call)) {
170 std::string out = "{";
171 auto argument = call->arguments.begin();
172 for (const auto& field_name : options->field_names) {
173 out += field_name + "=" + argument++->ToString() + ", ";
174 }
175 out.resize(out.size() - 1);
176 out.back() = '}';
177 return out;
178 }
179
180 std::string out = call->function_name + "(";
181 for (const auto& arg : call->arguments) {
182 out += arg.ToString() + ", ";
183 }
184
185 if (call->options) {
186 out += call->options->ToString();
187 out.resize(out.size() + 1);
188 } else {
189 out.resize(out.size() - 1);
190 }
191 out.back() = ')';
192 return out;
193 }
194
195 void PrintTo(const Expression& expr, std::ostream* os) {
196 *os << expr.ToString();
197 if (expr.IsBound()) {
198 *os << "[bound]";
199 }
200 }
201
202 bool Expression::Equals(const Expression& other) const {
203 if (Identical(*this, other)) return true;
204
205 if (impl_->index() != other.impl_->index()) {
206 return false;
207 }
208
209 if (auto lit = literal()) {
210 return lit->Equals(*other.literal());
211 }
212
213 if (auto ref = field_ref()) {
214 return ref->Equals(*other.field_ref());
215 }
216
217 auto call = CallNotNull(*this);
218 auto other_call = CallNotNull(other);
219
220 if (call->function_name != other_call->function_name ||
221 call->kernel != other_call->kernel) {
222 return false;
223 }
224
225 for (size_t i = 0; i < call->arguments.size(); ++i) {
226 if (!call->arguments[i].Equals(other_call->arguments[i])) {
227 return false;
228 }
229 }
230
231 if (call->options == other_call->options) return true;
232 if (call->options && other_call->options) {
233 return call->options->Equals(other_call->options);
234 }
235 return false;
236 }
237
238 bool Identical(const Expression& l, const Expression& r) { return l.impl_ == r.impl_; }
239
240 size_t Expression::hash() const {
241 if (auto lit = literal()) {
242 if (lit->is_scalar()) {
243 return lit->scalar()->hash();
244 }
245 return 0;
246 }
247
248 if (auto ref = field_ref()) {
249 return ref->hash();
250 }
251
252 return CallNotNull(*this)->hash;
253 }
254
255 bool Expression::IsBound() const {
256 if (type() == nullptr) return false;
257
258 if (auto call = this->call()) {
259 if (call->kernel == nullptr) return false;
260
261 for (const Expression& arg : call->arguments) {
262 if (!arg.IsBound()) return false;
263 }
264 }
265
266 return true;
267 }
268
269 bool Expression::IsScalarExpression() const {
270 if (auto lit = literal()) {
271 return lit->is_scalar();
272 }
273
274 if (field_ref()) return true;
275
276 auto call = CallNotNull(*this);
277
278 for (const Expression& arg : call->arguments) {
279 if (!arg.IsScalarExpression()) return false;
280 }
281
282 if (call->function) {
283 return call->function->kind() == compute::Function::SCALAR;
284 }
285
286 // this expression is not bound; make a best guess based on
287 // the default function registry
288 if (auto function = compute::GetFunctionRegistry()
289 ->GetFunction(call->function_name)
290 .ValueOr(nullptr)) {
291 return function->kind() == compute::Function::SCALAR;
292 }
293
294 // unknown function or other error; conservatively return false
295 return false;
296 }
297
298 bool Expression::IsNullLiteral() const {
299 if (auto lit = literal()) {
300 if (lit->null_count() == lit->length()) {
301 return true;
302 }
303 }
304
305 return false;
306 }
307
308 bool Expression::IsSatisfiable() const {
309 if (type() && type()->id() == Type::NA) {
310 return false;
311 }
312
313 if (auto lit = literal()) {
314 if (lit->null_count() == lit->length()) {
315 return false;
316 }
317
318 if (lit->is_scalar() && lit->type()->id() == Type::BOOL) {
319 return lit->scalar_as<BooleanScalar>().value;
320 }
321 }
322
323 return true;
324 }
325
326 namespace {
327
328 // Produce a bound Expression from unbound Call and bound arguments.
329 Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_casts,
330 compute::ExecContext* exec_context) {
331 DCHECK(std::all_of(call.arguments.begin(), call.arguments.end(),
332 [](const Expression& argument) { return argument.IsBound(); }));
333
334 auto descrs = GetDescriptors(call.arguments);
335 ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context));
336
337 if (!insert_implicit_casts) {
338 ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchExact(descrs));
339 } else {
340 ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&descrs));
341
342 for (size_t i = 0; i < descrs.size(); ++i) {
343 if (descrs[i] == call.arguments[i].descr()) continue;
344
345 if (descrs[i].shape != call.arguments[i].descr().shape) {
346 return Status::NotImplemented(
347 "Automatic broadcasting of scalars arguments to arrays in ",
348 Expression(std::move(call)).ToString());
349 }
350
351 if (auto lit = call.arguments[i].literal()) {
352 ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, descrs[i].type));
353 call.arguments[i] = literal(std::move(new_lit));
354 continue;
355 }
356
357 // construct an implicit cast Expression with which to replace this argument
358 Expression::Call implicit_cast;
359 implicit_cast.function_name = "cast";
360 implicit_cast.arguments = {std::move(call.arguments[i])};
361 implicit_cast.options = std::make_shared<compute::CastOptions>(
362 compute::CastOptions::Safe(descrs[i].type));
363
364 ARROW_ASSIGN_OR_RAISE(
365 call.arguments[i],
366 BindNonRecursive(std::move(implicit_cast),
367 /*insert_implicit_casts=*/false, exec_context));
368 }
369 }
370
371 compute::KernelContext kernel_context(exec_context);
372 if (call.kernel->init) {
373 ARROW_ASSIGN_OR_RAISE(
374 call.kernel_state,
375 call.kernel->init(&kernel_context, {call.kernel, descrs, call.options.get()}));
376
377 kernel_context.SetState(call.kernel_state.get());
378 }
379
380 ARROW_ASSIGN_OR_RAISE(
381 call.descr, call.kernel->signature->out_type().Resolve(&kernel_context, descrs));
382
383 return Expression(std::move(call));
384 }
385
386 template <typename TypeOrSchema>
387 Result<Expression> BindImpl(Expression expr, const TypeOrSchema& in,
388 ValueDescr::Shape shape, compute::ExecContext* exec_context) {
389 if (exec_context == nullptr) {
390 compute::ExecContext exec_context;
391 return BindImpl(std::move(expr), in, shape, &exec_context);
392 }
393
394 if (expr.literal()) return expr;
395
396 if (auto ref = expr.field_ref()) {
397 if (ref->IsNested()) {
398 return Status::NotImplemented("nested field references");
399 }
400
401 ARROW_ASSIGN_OR_RAISE(auto path, ref->FindOne(in));
402
403 auto bound = *expr.parameter();
404 bound.index = path[0];
405 ARROW_ASSIGN_OR_RAISE(auto field, path.Get(in));
406 bound.descr.type = field->type();
407 bound.descr.shape = shape;
408 return Expression{std::move(bound)};
409 }
410
411 auto call = *CallNotNull(expr);
412 for (auto& argument : call.arguments) {
413 ARROW_ASSIGN_OR_RAISE(argument,
414 BindImpl(std::move(argument), in, shape, exec_context));
415 }
416 return BindNonRecursive(std::move(call),
417 /*insert_implicit_casts=*/true, exec_context);
418 }
419
420 } // namespace
421
422 Result<Expression> Expression::Bind(const ValueDescr& in,
423 compute::ExecContext* exec_context) const {
424 return BindImpl(*this, *in.type, in.shape, exec_context);
425 }
426
427 Result<Expression> Expression::Bind(const Schema& in_schema,
428 compute::ExecContext* exec_context) const {
429 return BindImpl(*this, in_schema, ValueDescr::ARRAY, exec_context);
430 }
431
432 Result<ExecBatch> MakeExecBatch(const Schema& full_schema, const Datum& partial) {
433 ExecBatch out;
434
435 if (partial.kind() == Datum::RECORD_BATCH) {
436 const auto& partial_batch = *partial.record_batch();
437 out.length = partial_batch.num_rows();
438
439 for (const auto& field : full_schema.fields()) {
440 ARROW_ASSIGN_OR_RAISE(auto column,
441 FieldRef(field->name()).GetOneOrNone(partial_batch));
442
443 if (column) {
444 if (!column->type()->Equals(field->type())) {
445 // Referenced field was present but didn't have the expected type.
446 // This *should* be handled by readers, and will just be an error in the future.
447 ARROW_ASSIGN_OR_RAISE(
448 auto converted,
449 compute::Cast(column, field->type(), compute::CastOptions::Safe()));
450 column = converted.make_array();
451 }
452 out.values.emplace_back(std::move(column));
453 } else {
454 out.values.emplace_back(MakeNullScalar(field->type()));
455 }
456 }
457 return out;
458 }
459
460 // wasteful but useful for testing:
461 if (partial.type()->id() == Type::STRUCT) {
462 if (partial.is_array()) {
463 ARROW_ASSIGN_OR_RAISE(auto partial_batch,
464 RecordBatch::FromStructArray(partial.make_array()));
465
466 return MakeExecBatch(full_schema, partial_batch);
467 }
468
469 if (partial.is_scalar()) {
470 ARROW_ASSIGN_OR_RAISE(auto partial_array,
471 MakeArrayFromScalar(*partial.scalar(), 1));
472 ARROW_ASSIGN_OR_RAISE(auto out, MakeExecBatch(full_schema, partial_array));
473
474 for (Datum& value : out.values) {
475 if (value.is_scalar()) continue;
476 ARROW_ASSIGN_OR_RAISE(value, value.make_array()->GetScalar(0));
477 }
478 return out;
479 }
480 }
481
482 return Status::NotImplemented("MakeExecBatch from ", PrintDatum(partial));
483 }
484
485 Result<Datum> ExecuteScalarExpression(const Expression& expr, const Schema& full_schema,
486 const Datum& partial_input,
487 compute::ExecContext* exec_context) {
488 ARROW_ASSIGN_OR_RAISE(auto input, MakeExecBatch(full_schema, partial_input));
489 return ExecuteScalarExpression(expr, input, exec_context);
490 }
491
492 Result<Datum> ExecuteScalarExpression(const Expression& expr, const ExecBatch& input,
493 compute::ExecContext* exec_context) {
494 if (exec_context == nullptr) {
495 compute::ExecContext exec_context;
496 return ExecuteScalarExpression(expr, input, &exec_context);
497 }
498
499 if (!expr.IsBound()) {
500 return Status::Invalid("Cannot Execute unbound expression.");
501 }
502
503 if (!expr.IsScalarExpression()) {
504 return Status::Invalid(
505 "ExecuteScalarExpression cannot Execute non-scalar expression ", expr.ToString());
506 }
507
508 if (auto lit = expr.literal()) return *lit;
509
510 if (auto param = expr.parameter()) {
511 if (param->descr.type->id() == Type::NA) {
512 return MakeNullScalar(null());
513 }
514
515 const Datum& field = input[param->index];
516 if (!field.type()->Equals(param->descr.type)) {
517 return Status::Invalid("Referenced field ", expr.ToString(), " was ",
518 field.type()->ToString(), " but should have been ",
519 param->descr.type->ToString());
520 }
521
522 return field;
523 }
524
525 auto call = CallNotNull(expr);
526
527 std::vector<Datum> arguments(call->arguments.size());
528 for (size_t i = 0; i < arguments.size(); ++i) {
529 ARROW_ASSIGN_OR_RAISE(
530 arguments[i], ExecuteScalarExpression(call->arguments[i], input, exec_context));
531 }
532
533 auto executor = compute::detail::KernelExecutor::MakeScalar();
534
535 compute::KernelContext kernel_context(exec_context);
536 kernel_context.SetState(call->kernel_state.get());
537
538 auto kernel = call->kernel;
539 auto descrs = GetDescriptors(arguments);
540 auto options = call->options.get();
541 RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, descrs, options}));
542
543 compute::detail::DatumAccumulator listener;
544 RETURN_NOT_OK(executor->Execute(arguments, &listener));
545 const auto out = executor->WrapResults(arguments, listener.values());
546 #ifndef NDEBUG
547 DCHECK_OK(executor->CheckResultType(out, call->function_name.c_str()));
548 #endif
549 return out;
550 }
551
552 namespace {
553
554 std::array<std::pair<const Expression&, const Expression&>, 2>
555 ArgumentsAndFlippedArguments(const Expression::Call& call) {
556 DCHECK_EQ(call.arguments.size(), 2);
557 return {std::pair<const Expression&, const Expression&>{call.arguments[0],
558 call.arguments[1]},
559 std::pair<const Expression&, const Expression&>{call.arguments[1],
560 call.arguments[0]}};
561 }
562
563 template <typename BinOp, typename It,
564 typename Out = typename std::iterator_traits<It>::value_type>
565 util::optional<Out> FoldLeft(It begin, It end, const BinOp& bin_op) {
566 if (begin == end) return util::nullopt;
567
568 Out folded = std::move(*begin++);
569 while (begin != end) {
570 folded = bin_op(std::move(folded), std::move(*begin++));
571 }
572 return folded;
573 }
574
575 util::optional<compute::NullHandling::type> GetNullHandling(
576 const Expression::Call& call) {
577 if (call.function && call.function->kind() == compute::Function::SCALAR) {
578 return static_cast<const compute::ScalarKernel*>(call.kernel)->null_handling;
579 }
580 return util::nullopt;
581 }
582
583 } // namespace
584
585 std::vector<FieldRef> FieldsInExpression(const Expression& expr) {
586 if (expr.literal()) return {};
587
588 if (auto ref = expr.field_ref()) {
589 return {*ref};
590 }
591
592 std::vector<FieldRef> fields;
593 for (const Expression& arg : CallNotNull(expr)->arguments) {
594 auto argument_fields = FieldsInExpression(arg);
595 std::move(argument_fields.begin(), argument_fields.end(), std::back_inserter(fields));
596 }
597 return fields;
598 }
599
600 bool ExpressionHasFieldRefs(const Expression& expr) {
601 if (expr.literal()) return false;
602
603 if (expr.field_ref()) return true;
604
605 for (const Expression& arg : CallNotNull(expr)->arguments) {
606 if (ExpressionHasFieldRefs(arg)) return true;
607 }
608 return false;
609 }
610
611 Result<Expression> FoldConstants(Expression expr) {
612 return Modify(
613 std::move(expr), [](Expression expr) { return expr; },
614 [](Expression expr, ...) -> Result<Expression> {
615 auto call = CallNotNull(expr);
616 if (std::all_of(call->arguments.begin(), call->arguments.end(),
617 [](const Expression& argument) { return argument.literal(); })) {
618 // all arguments are literal; we can evaluate this subexpression *now*
619 static const ExecBatch ignored_input = ExecBatch{};
620 ARROW_ASSIGN_OR_RAISE(Datum constant,
621 ExecuteScalarExpression(expr, ignored_input));
622
623 return literal(std::move(constant));
624 }
625
626 // XXX the following should probably be in a registry of passes instead
627 // of inline
628
629 if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) {
630 // kernels which always produce intersected validity can be resolved
631 // to null *now* if any of their inputs is a null literal
632 for (const auto& argument : call->arguments) {
633 if (argument.IsNullLiteral()) {
634 return argument;
635 }
636 }
637 }
638
639 if (call->function_name == "and_kleene") {
640 for (auto args : ArgumentsAndFlippedArguments(*call)) {
641 // true and x == x
642 if (args.first == literal(true)) return args.second;
643
644 // false and x == false
645 if (args.first == literal(false)) return args.first;
646
647 // x and x == x
648 if (args.first == args.second) return args.first;
649 }
650 return expr;
651 }
652
653 if (call->function_name == "or_kleene") {
654 for (auto args : ArgumentsAndFlippedArguments(*call)) {
655 // false or x == x
656 if (args.first == literal(false)) return args.second;
657
658 // true or x == true
659 if (args.first == literal(true)) return args.first;
660
661 // x or x == x
662 if (args.first == args.second) return args.first;
663 }
664 return expr;
665 }
666
667 return expr;
668 });
669 }
670
671 namespace {
672
673 std::vector<Expression> GuaranteeConjunctionMembers(
674 const Expression& guaranteed_true_predicate) {
675 auto guarantee = guaranteed_true_predicate.call();
676 if (!guarantee || guarantee->function_name != "and_kleene") {
677 return {guaranteed_true_predicate};
678 }
679 return FlattenedAssociativeChain(guaranteed_true_predicate).fringe;
680 }
681
682 // Conjunction members which are represented in known_values are erased from
683 // conjunction_members
684 Status ExtractKnownFieldValuesImpl(
685 std::vector<Expression>* conjunction_members,
686 std::unordered_map<FieldRef, Datum, FieldRef::Hash>* known_values) {
687 auto unconsumed_end =
688 std::partition(conjunction_members->begin(), conjunction_members->end(),
689 [](const Expression& expr) {
690 // search for an equality conditions between a field and a literal
691 auto call = expr.call();
692 if (!call) return true;
693
694 if (call->function_name == "equal") {
695 auto ref = call->arguments[0].field_ref();
696 auto lit = call->arguments[1].literal();
697 return !(ref && lit);
698 }
699
700 if (call->function_name == "is_null") {
701 auto ref = call->arguments[0].field_ref();
702 return !ref;
703 }
704
705 return true;
706 });
707
708 for (auto it = unconsumed_end; it != conjunction_members->end(); ++it) {
709 auto call = CallNotNull(*it);
710
711 if (call->function_name == "equal") {
712 auto ref = call->arguments[0].field_ref();
713 auto lit = call->arguments[1].literal();
714 known_values->emplace(*ref, *lit);
715 } else if (call->function_name == "is_null") {
716 auto ref = call->arguments[0].field_ref();
717 known_values->emplace(*ref, Datum(std::make_shared<NullScalar>()));
718 }
719 }
720
721 conjunction_members->erase(unconsumed_end, conjunction_members->end());
722
723 return Status::OK();
724 }
725
726 } // namespace
727
728 Result<KnownFieldValues> ExtractKnownFieldValues(
729 const Expression& guaranteed_true_predicate) {
730 auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);
731 KnownFieldValues known_values;
732 RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map));
733 return known_values;
734 }
735
736 Result<Expression> ReplaceFieldsWithKnownValues(const KnownFieldValues& known_values,
737 Expression expr) {
738 if (!expr.IsBound()) {
739 return Status::Invalid(
740 "ReplaceFieldsWithKnownValues called on an unbound Expression");
741 }
742
743 return Modify(
744 std::move(expr),
745 [&known_values](Expression expr) -> Result<Expression> {
746 if (auto ref = expr.field_ref()) {
747 auto it = known_values.map.find(*ref);
748 if (it != known_values.map.end()) {
749 Datum lit = it->second;
750 if (lit.descr() == expr.descr()) return literal(std::move(lit));
751 // type mismatch, try casting the known value to the correct type
752
753 if (expr.type()->id() == Type::DICTIONARY &&
754 lit.type()->id() != Type::DICTIONARY) {
755 // the known value must be dictionary encoded
756
757 const auto& dict_type = checked_cast<const DictionaryType&>(*expr.type());
758 if (!lit.type()->Equals(dict_type.value_type())) {
759 ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, dict_type.value_type()));
760 }
761
762 if (lit.is_scalar()) {
763 ARROW_ASSIGN_OR_RAISE(auto dictionary,
764 MakeArrayFromScalar(*lit.scalar(), 1));
765
766 lit = Datum{DictionaryScalar::Make(MakeScalar<int32_t>(0),
767 std::move(dictionary))};
768 }
769 }
770
771 ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type()));
772 return literal(std::move(lit));
773 }
774 }
775 return expr;
776 },
777 [](Expression expr, ...) { return expr; });
778 }
779
780 namespace {
781
782 bool IsBinaryAssociativeCommutative(const Expression::Call& call) {
783 static std::unordered_set<std::string> binary_associative_commutative{
784 "and", "or", "and_kleene", "or_kleene", "xor",
785 "multiply", "add", "multiply_checked", "add_checked"};
786
787 auto it = binary_associative_commutative.find(call.function_name);
788 return it != binary_associative_commutative.end();
789 }
790
791 } // namespace
792
793 Result<Expression> Canonicalize(Expression expr, compute::ExecContext* exec_context) {
794 if (exec_context == nullptr) {
795 compute::ExecContext exec_context;
796 return Canonicalize(std::move(expr), &exec_context);
797 }
798
799 // If potentially reconstructing more deeply than a call's immediate arguments
800 // (for example, when reorganizing an associative chain), add expressions to this set to
801 // avoid unnecessary work
802 struct {
803 std::unordered_set<Expression, Expression::Hash> set_;
804
805 bool operator()(const Expression& expr) const {
806 return set_.find(expr) != set_.end();
807 }
808
809 void Add(std::vector<Expression> exprs) {
810 std::move(exprs.begin(), exprs.end(), std::inserter(set_, set_.end()));
811 }
812 } AlreadyCanonicalized;
813
814 return Modify(
815 std::move(expr),
816 [&AlreadyCanonicalized, exec_context](Expression expr) -> Result<Expression> {
817 auto call = expr.call();
818 if (!call) return expr;
819
820 if (AlreadyCanonicalized(expr)) return expr;
821
822 if (IsBinaryAssociativeCommutative(*call)) {
823 struct {
824 int Priority(const Expression& operand) const {
825 // order literals first, starting with nulls
826 if (operand.IsNullLiteral()) return 0;
827 if (operand.literal()) return 1;
828 return 2;
829 }
830 bool operator()(const Expression& l, const Expression& r) const {
831 return Priority(l) < Priority(r);
832 }
833 } CanonicalOrdering;
834
835 FlattenedAssociativeChain chain(expr);
836 if (chain.was_left_folded &&
837 std::is_sorted(chain.fringe.begin(), chain.fringe.end(),
838 CanonicalOrdering)) {
839 AlreadyCanonicalized.Add(std::move(chain.exprs));
840 return expr;
841 }
842
843 std::stable_sort(chain.fringe.begin(), chain.fringe.end(), CanonicalOrdering);
844
845 // fold the chain back up
846 auto folded =
847 FoldLeft(chain.fringe.begin(), chain.fringe.end(),
848 [call, &AlreadyCanonicalized](Expression l, Expression r) {
849 auto canonicalized_call = *call;
850 canonicalized_call.arguments = {std::move(l), std::move(r)};
851 Expression expr(std::move(canonicalized_call));
852 AlreadyCanonicalized.Add({expr});
853 return expr;
854 });
855 return std::move(*folded);
856 }
857
858 if (auto cmp = Comparison::Get(call->function_name)) {
859 if (call->arguments[0].literal() && !call->arguments[1].literal()) {
860 // ensure that literals are on comparisons' RHS
861 auto flipped_call = *call;
862
863 std::swap(flipped_call.arguments[0], flipped_call.arguments[1]);
864 flipped_call.function_name =
865 Comparison::GetName(Comparison::GetFlipped(*cmp));
866
867 return BindNonRecursive(flipped_call,
868 /*insert_implicit_casts=*/false, exec_context);
869 }
870 }
871
872 return expr;
873 },
874 [](Expression expr, ...) { return expr; });
875 }
876
877 namespace {
878
879 Result<Expression> DirectComparisonSimplification(Expression expr,
880 const Expression::Call& guarantee) {
881 return Modify(
882 std::move(expr), [](Expression expr) { return expr; },
883 [&guarantee](Expression expr, ...) -> Result<Expression> {
884 auto call = expr.call();
885 if (!call) return expr;
886
887 // Ensure both calls are comparisons with equal LHS and scalar RHS
888 auto cmp = Comparison::Get(expr);
889 auto cmp_guarantee = Comparison::Get(guarantee.function_name);
890
891 if (!cmp) return expr;
892 if (!cmp_guarantee) return expr;
893
894 const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]);
895 const auto& guarantee_lhs = guarantee.arguments[0];
896 if (lhs != guarantee_lhs) return expr;
897
898 auto rhs = call->arguments[1].literal();
899 auto guarantee_rhs = guarantee.arguments[1].literal();
900
901 if (!rhs) return expr;
902 if (!rhs->is_scalar()) return expr;
903
904 if (!guarantee_rhs) return expr;
905 if (!guarantee_rhs->is_scalar()) return expr;
906
907 ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs,
908 Comparison::Execute(*rhs, *guarantee_rhs));
909 DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA);
910
911 if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) {
912 // RHS of filter is equal to RHS of guarantee
913
914 if ((*cmp & *cmp_guarantee) == *cmp_guarantee) {
915 // guarantee is a subset of filter, so all data will be included
916 // x > 1, x >= 1, x != 1 guaranteed by x > 1
917 return literal(true);
918 }
919
920 if ((*cmp & *cmp_guarantee) == 0) {
921 // guarantee disjoint with filter, so all data will be excluded
922 // x > 1, x >= 1, x != 1 unsatisfiable if x == 1
923 return literal(false);
924 }
925
926 return expr;
927 }
928
929 if (*cmp_guarantee & cmp_rhs_guarantee_rhs) {
930 // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3
931 return expr;
932 }
933
934 if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) {
935 // x > 1, x >= 1, x != 1 guaranteed by x >= 3
936 return literal(true);
937 } else {
938 // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3
939 return literal(false);
940 }
941 });
942 }
943
944 } // namespace
945
946 Result<Expression> SimplifyWithGuarantee(Expression expr,
947 const Expression& guaranteed_true_predicate) {
948 auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);
949
950 KnownFieldValues known_values;
951 RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map));
952
953 ARROW_ASSIGN_OR_RAISE(expr,
954 ReplaceFieldsWithKnownValues(known_values, std::move(expr)));
955
956 auto CanonicalizeAndFoldConstants = [&expr] {
957 ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr)));
958 ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr)));
959 return Status::OK();
960 };
961 RETURN_NOT_OK(CanonicalizeAndFoldConstants());
962
963 for (const auto& guarantee : conjunction_members) {
964 if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) {
965 ARROW_ASSIGN_OR_RAISE(
966 auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee)));
967
968 if (Identical(simplified, expr)) continue;
969
970 expr = std::move(simplified);
971 RETURN_NOT_OK(CanonicalizeAndFoldConstants());
972 }
973 }
974
975 return expr;
976 }
977
978 // Serialization is accomplished by converting expressions to KeyValueMetadata and storing
979 // this in the schema of a RecordBatch. Embedded arrays and scalars are stored in its
980 // columns. Finally, the RecordBatch is written to an IPC file.
981 Result<std::shared_ptr<Buffer>> Serialize(const Expression& expr) {
982 struct {
983 std::shared_ptr<KeyValueMetadata> metadata_ = std::make_shared<KeyValueMetadata>();
984 ArrayVector columns_;
985
986 Result<std::string> AddScalar(const Scalar& scalar) {
987 auto ret = columns_.size();
988 ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(scalar, 1));
989 columns_.push_back(std::move(array));
990 return std::to_string(ret);
991 }
992
993 Status Visit(const Expression& expr) {
994 if (auto lit = expr.literal()) {
995 if (!lit->is_scalar()) {
996 return Status::NotImplemented("Serialization of non-scalar literals");
997 }
998 ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*lit->scalar()));
999 metadata_->Append("literal", std::move(value));
1000 return Status::OK();
1001 }
1002
1003 if (auto ref = expr.field_ref()) {
1004 if (!ref->name()) {
1005 return Status::NotImplemented("Serialization of non-name field_refs");
1006 }
1007 metadata_->Append("field_ref", *ref->name());
1008 return Status::OK();
1009 }
1010
1011 auto call = CallNotNull(expr);
1012 metadata_->Append("call", call->function_name);
1013
1014 for (const auto& argument : call->arguments) {
1015 RETURN_NOT_OK(Visit(argument));
1016 }
1017
1018 if (call->options) {
1019 ARROW_ASSIGN_OR_RAISE(auto options_scalar,
1020 internal::FunctionOptionsToStructScalar(*call->options));
1021 ARROW_ASSIGN_OR_RAISE(auto value, AddScalar(*options_scalar));
1022 metadata_->Append("options", std::move(value));
1023 }
1024
1025 metadata_->Append("end", call->function_name);
1026 return Status::OK();
1027 }
1028
1029 Result<std::shared_ptr<RecordBatch>> operator()(const Expression& expr) {
1030 RETURN_NOT_OK(Visit(expr));
1031 FieldVector fields(columns_.size());
1032 for (size_t i = 0; i < fields.size(); ++i) {
1033 fields[i] = field("", columns_[i]->type());
1034 }
1035 return RecordBatch::Make(schema(std::move(fields), std::move(metadata_)), 1,
1036 std::move(columns_));
1037 }
1038 } ToRecordBatch;
1039
1040 ARROW_ASSIGN_OR_RAISE(auto batch, ToRecordBatch(expr));
1041 ARROW_ASSIGN_OR_RAISE(auto stream, io::BufferOutputStream::Create());
1042 ARROW_ASSIGN_OR_RAISE(auto writer, ipc::MakeFileWriter(stream, batch->schema()));
1043 RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
1044 RETURN_NOT_OK(writer->Close());
1045 return stream->Finish();
1046 }
1047
1048 Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {
1049 io::BufferReader stream(std::move(buffer));
1050 ARROW_ASSIGN_OR_RAISE(auto reader, ipc::RecordBatchFileReader::Open(&stream));
1051 ARROW_ASSIGN_OR_RAISE(auto batch, reader->ReadRecordBatch(0));
1052 if (batch->schema()->metadata() == nullptr) {
1053 return Status::Invalid("serialized Expression's batch repr had null metadata");
1054 }
1055 if (batch->num_rows() != 1) {
1056 return Status::Invalid(
1057 "serialized Expression's batch repr was not a single row - had ",
1058 batch->num_rows());
1059 }
1060
1061 struct FromRecordBatch {
1062 const RecordBatch& batch_;
1063 int index_;
1064
1065 const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); }
1066
1067 Result<std::shared_ptr<Scalar>> GetScalar(const std::string& i) {
1068 int32_t column_index;
1069 if (!::arrow::internal::ParseValue<Int32Type>(i.data(), i.length(),
1070 &column_index)) {
1071 return Status::Invalid("Couldn't parse column_index");
1072 }
1073 if (column_index >= batch_.num_columns()) {
1074 return Status::Invalid("column_index out of bounds");
1075 }
1076 return batch_.column(column_index)->GetScalar(0);
1077 }
1078
1079 Result<Expression> GetOne() {
1080 if (index_ >= metadata().size()) {
1081 return Status::Invalid("unterminated serialized Expression");
1082 }
1083
1084 const std::string& key = metadata().key(index_);
1085 const std::string& value = metadata().value(index_);
1086 ++index_;
1087
1088 if (key == "literal") {
1089 ARROW_ASSIGN_OR_RAISE(auto scalar, GetScalar(value));
1090 return literal(std::move(scalar));
1091 }
1092
1093 if (key == "field_ref") {
1094 return field_ref(value);
1095 }
1096
1097 if (key != "call") {
1098 return Status::Invalid("Unrecognized serialized Expression key ", key);
1099 }
1100
1101 std::vector<Expression> arguments;
1102 while (metadata().key(index_) != "end") {
1103 if (metadata().key(index_) == "options") {
1104 ARROW_ASSIGN_OR_RAISE(auto options_scalar, GetScalar(metadata().value(index_)));
1105 std::shared_ptr<compute::FunctionOptions> options;
1106 if (options_scalar) {
1107 ARROW_ASSIGN_OR_RAISE(
1108 options, internal::FunctionOptionsFromStructScalar(
1109 checked_cast<const StructScalar&>(*options_scalar)));
1110 }
1111 auto expr = call(value, std::move(arguments), std::move(options));
1112 index_ += 2;
1113 return expr;
1114 }
1115
1116 ARROW_ASSIGN_OR_RAISE(auto argument, GetOne());
1117 arguments.push_back(std::move(argument));
1118 }
1119
1120 ++index_;
1121 return call(value, std::move(arguments));
1122 }
1123 };
1124
1125 return FromRecordBatch{*batch, 0}.GetOne();
1126 }
1127
1128 Expression project(std::vector<Expression> values, std::vector<std::string> names) {
1129 return call("make_struct", std::move(values),
1130 compute::MakeStructOptions{std::move(names)});
1131 }
1132
1133 Expression equal(Expression lhs, Expression rhs) {
1134 return call("equal", {std::move(lhs), std::move(rhs)});
1135 }
1136
1137 Expression not_equal(Expression lhs, Expression rhs) {
1138 return call("not_equal", {std::move(lhs), std::move(rhs)});
1139 }
1140
1141 Expression less(Expression lhs, Expression rhs) {
1142 return call("less", {std::move(lhs), std::move(rhs)});
1143 }
1144
1145 Expression less_equal(Expression lhs, Expression rhs) {
1146 return call("less_equal", {std::move(lhs), std::move(rhs)});
1147 }
1148
1149 Expression greater(Expression lhs, Expression rhs) {
1150 return call("greater", {std::move(lhs), std::move(rhs)});
1151 }
1152
1153 Expression greater_equal(Expression lhs, Expression rhs) {
1154 return call("greater_equal", {std::move(lhs), std::move(rhs)});
1155 }
1156
1157 Expression is_null(Expression lhs, bool nan_is_null) {
1158 return call("is_null", {std::move(lhs)}, compute::NullOptions(std::move(nan_is_null)));
1159 }
1160
1161 Expression is_valid(Expression lhs) { return call("is_valid", {std::move(lhs)}); }
1162
1163 Expression and_(Expression lhs, Expression rhs) {
1164 return call("and_kleene", {std::move(lhs), std::move(rhs)});
1165 }
1166
1167 Expression and_(const std::vector<Expression>& operands) {
1168 auto folded = FoldLeft<Expression(Expression, Expression)>(operands.begin(),
1169 operands.end(), and_);
1170 if (folded) {
1171 return std::move(*folded);
1172 }
1173 return literal(true);
1174 }
1175
1176 Expression or_(Expression lhs, Expression rhs) {
1177 return call("or_kleene", {std::move(lhs), std::move(rhs)});
1178 }
1179
1180 Expression or_(const std::vector<Expression>& operands) {
1181 auto folded =
1182 FoldLeft<Expression(Expression, Expression)>(operands.begin(), operands.end(), or_);
1183 if (folded) {
1184 return std::move(*folded);
1185 }
1186 return literal(false);
1187 }
1188
1189 Expression not_(Expression operand) { return call("invert", {std::move(operand)}); }
1190
1191 } // namespace compute
1192 } // namespace arrow