]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/compute/exec/expression_internal.h
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / compute / exec / expression_internal.h
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/compute/exec/expression.h"
19
20 #include <unordered_map>
21 #include <unordered_set>
22 #include <vector>
23
24 #include "arrow/compute/api_scalar.h"
25 #include "arrow/compute/cast.h"
26 #include "arrow/compute/registry.h"
27 #include "arrow/record_batch.h"
28 #include "arrow/table.h"
29 #include "arrow/util/logging.h"
30
31 namespace arrow {
32
33 using internal::checked_cast;
34
35 namespace compute {
36
37 struct KnownFieldValues {
38 std::unordered_map<FieldRef, Datum, FieldRef::Hash> map;
39 };
40
41 inline const Expression::Call* CallNotNull(const Expression& expr) {
42 auto call = expr.call();
43 DCHECK_NE(call, nullptr);
44 return call;
45 }
46
47 inline std::vector<ValueDescr> GetDescriptors(const std::vector<Expression>& exprs) {
48 std::vector<ValueDescr> descrs(exprs.size());
49 for (size_t i = 0; i < exprs.size(); ++i) {
50 DCHECK(exprs[i].IsBound());
51 descrs[i] = exprs[i].descr();
52 }
53 return descrs;
54 }
55
56 inline std::vector<ValueDescr> GetDescriptors(const std::vector<Datum>& values) {
57 std::vector<ValueDescr> descrs(values.size());
58 for (size_t i = 0; i < values.size(); ++i) {
59 descrs[i] = values[i].descr();
60 }
61 return descrs;
62 }
63
64 struct Comparison {
65 enum type {
66 NA = 0,
67 EQUAL = 1,
68 LESS = 2,
69 GREATER = 4,
70 NOT_EQUAL = LESS | GREATER,
71 LESS_EQUAL = LESS | EQUAL,
72 GREATER_EQUAL = GREATER | EQUAL,
73 };
74
75 static const type* Get(const std::string& function) {
76 static std::unordered_map<std::string, type> map{
77 {"equal", EQUAL}, {"not_equal", NOT_EQUAL},
78 {"less", LESS}, {"less_equal", LESS_EQUAL},
79 {"greater", GREATER}, {"greater_equal", GREATER_EQUAL},
80 };
81
82 auto it = map.find(function);
83 return it != map.end() ? &it->second : nullptr;
84 }
85
86 static const type* Get(const Expression& expr) {
87 if (auto call = expr.call()) {
88 return Comparison::Get(call->function_name);
89 }
90 return nullptr;
91 }
92
93 // Execute a simple Comparison between scalars
94 static Result<type> Execute(Datum l, Datum r) {
95 if (!l.is_scalar() || !r.is_scalar()) {
96 return Status::Invalid("Cannot Execute Comparison on non-scalars");
97 }
98
99 std::vector<Datum> arguments{std::move(l), std::move(r)};
100
101 ARROW_ASSIGN_OR_RAISE(auto equal, compute::CallFunction("equal", arguments));
102
103 if (!equal.scalar()->is_valid) return NA;
104 if (equal.scalar_as<BooleanScalar>().value) return EQUAL;
105
106 ARROW_ASSIGN_OR_RAISE(auto less, compute::CallFunction("less", arguments));
107
108 if (!less.scalar()->is_valid) return NA;
109 return less.scalar_as<BooleanScalar>().value ? LESS : GREATER;
110 }
111
112 // Given an Expression wrapped in casts which preserve ordering
113 // (for example, cast(field_ref("i16"), to_type=int32())), unwrap the inner Expression.
114 // This is used to destructure implicitly cast field_refs during Expression
115 // simplification.
116 static const Expression& StripOrderPreservingCasts(const Expression& expr) {
117 auto call = expr.call();
118 if (!call) return expr;
119 if (call->function_name != "cast") return expr;
120
121 const Expression& from = call->arguments[0];
122
123 auto from_id = from.type()->id();
124 auto to_id = expr.type()->id();
125
126 if (is_floating(to_id)) {
127 if (is_integer(from_id) || is_floating(from_id)) {
128 return StripOrderPreservingCasts(from);
129 }
130 return expr;
131 }
132
133 if (is_unsigned_integer(to_id)) {
134 if (is_unsigned_integer(from_id) && bit_width(to_id) >= bit_width(from_id)) {
135 return StripOrderPreservingCasts(from);
136 }
137 return expr;
138 }
139
140 if (is_signed_integer(to_id)) {
141 if (is_integer(from_id) && bit_width(to_id) >= bit_width(from_id)) {
142 return StripOrderPreservingCasts(from);
143 }
144 return expr;
145 }
146
147 return expr;
148 }
149
150 static type GetFlipped(type op) {
151 switch (op) {
152 case NA:
153 return NA;
154 case EQUAL:
155 return EQUAL;
156 case LESS:
157 return GREATER;
158 case GREATER:
159 return LESS;
160 case NOT_EQUAL:
161 return NOT_EQUAL;
162 case LESS_EQUAL:
163 return GREATER_EQUAL;
164 case GREATER_EQUAL:
165 return LESS_EQUAL;
166 }
167 DCHECK(false);
168 return NA;
169 }
170
171 static std::string GetName(type op) {
172 switch (op) {
173 case NA:
174 break;
175 case EQUAL:
176 return "equal";
177 case LESS:
178 return "less";
179 case GREATER:
180 return "greater";
181 case NOT_EQUAL:
182 return "not_equal";
183 case LESS_EQUAL:
184 return "less_equal";
185 case GREATER_EQUAL:
186 return "greater_equal";
187 }
188 return "na";
189 }
190
191 static std::string GetOp(type op) {
192 switch (op) {
193 case NA:
194 DCHECK(false) << "unreachable";
195 break;
196 case EQUAL:
197 return "==";
198 case LESS:
199 return "<";
200 case GREATER:
201 return ">";
202 case NOT_EQUAL:
203 return "!=";
204 case LESS_EQUAL:
205 return "<=";
206 case GREATER_EQUAL:
207 return ">=";
208 }
209 DCHECK(false);
210 return "";
211 }
212 };
213
214 inline const compute::CastOptions* GetCastOptions(const Expression::Call& call) {
215 if (call.function_name != "cast") return nullptr;
216 return checked_cast<const compute::CastOptions*>(call.options.get());
217 }
218
/// Whether `function` names one of the set lookup functions
/// ("is_in" or "index_in").
inline bool IsSetLookup(const std::string& function) {
  if (function == "is_in") {
    return true;
  }
  return function == "index_in";
}
222
223 inline const compute::MakeStructOptions* GetMakeStructOptions(
224 const Expression::Call& call) {
225 if (call.function_name != "make_struct") return nullptr;
226 return checked_cast<const compute::MakeStructOptions*>(call.options.get());
227 }
228
229 /// A helper for unboxing an Expression composed of associative function calls.
230 /// Such expressions can frequently be rearranged to a semantically equivalent
231 /// expression for more optimal execution or more straightforward manipulation.
232 /// For example, (a + ((b + 3) + 4)) is equivalent to (((4 + 3) + a) + b) and the latter
233 /// can be trivially constant-folded to ((7 + a) + b).
234 struct FlattenedAssociativeChain {
235 /// True if a chain was already a left fold.
236 bool was_left_folded = true;
237
238 /// All "branch" expressions in a flattened chain. For example given (a + ((b + 3) + 4))
239 /// exprs would be [(a + ((b + 3) + 4)), ((b + 3) + 4), (b + 3)]
240 std::vector<Expression> exprs;
241
242 /// All "leaf" expressions in a flattened chain. For example given (a + ((b + 3) + 4))
243 /// the fringe would be [a, b, 3, 4]
244 std::vector<Expression> fringe;
245
246 explicit FlattenedAssociativeChain(Expression expr) : exprs{std::move(expr)} {
247 auto call = CallNotNull(exprs.back());
248 fringe = call->arguments;
249
250 auto it = fringe.begin();
251
252 while (it != fringe.end()) {
253 auto sub_call = it->call();
254 if (!sub_call || sub_call->function_name != call->function_name) {
255 ++it;
256 continue;
257 }
258
259 if (it != fringe.begin()) {
260 was_left_folded = false;
261 }
262
263 exprs.push_back(std::move(*it));
264 it = fringe.erase(it);
265
266 auto index = it - fringe.begin();
267 fringe.insert(it, sub_call->arguments.begin(), sub_call->arguments.end());
268 it = fringe.begin() + index;
269 // NB: no increment so we hit sub_call's first argument next iteration
270 }
271
272 DCHECK(std::all_of(exprs.begin(), exprs.end(), [](const Expression& expr) {
273 return CallNotNull(expr)->options == nullptr;
274 }));
275 }
276 };
277
278 inline Result<std::shared_ptr<compute::Function>> GetFunction(
279 const Expression::Call& call, compute::ExecContext* exec_context) {
280 if (call.function_name != "cast") {
281 return exec_context->func_registry()->GetFunction(call.function_name);
282 }
283 // XXX this special case is strange; why not make "cast" a ScalarFunction?
284 const auto& to_type = checked_cast<const compute::CastOptions&>(*call.options).to_type;
285 return compute::GetCastFunction(to_type);
286 }
287
288 /// Modify an Expression with pre-order and post-order visitation.
289 /// `pre` will be invoked on each Expression. `pre` will visit Calls before their
290 /// arguments, `post_call` will visit Calls (and no other Expressions) after their
291 /// arguments. Visitors should return the Identical expression to indicate no change; this
292 /// will prevent unnecessary construction in the common case where a modification is not
293 /// possible/necessary/...
294 ///
295 /// If an argument was modified, `post_call` visits a reconstructed Call with the modified
296 /// arguments but also receives a pointer to the unmodified Expression as a second
297 /// argument. If no arguments were modified the unmodified Expression* will be nullptr.
298 template <typename PreVisit, typename PostVisitCall>
299 Result<Expression> Modify(Expression expr, const PreVisit& pre,
300 const PostVisitCall& post_call) {
301 ARROW_ASSIGN_OR_RAISE(expr, Result<Expression>(pre(std::move(expr))));
302
303 auto call = expr.call();
304 if (!call) return expr;
305
306 bool at_least_one_modified = false;
307 std::vector<Expression> modified_arguments;
308
309 for (size_t i = 0; i < call->arguments.size(); ++i) {
310 ARROW_ASSIGN_OR_RAISE(auto modified_argument,
311 Modify(call->arguments[i], pre, post_call));
312
313 if (Identical(modified_argument, call->arguments[i])) {
314 continue;
315 }
316
317 if (!at_least_one_modified) {
318 modified_arguments = call->arguments;
319 at_least_one_modified = true;
320 }
321
322 modified_arguments[i] = std::move(modified_argument);
323 }
324
325 if (at_least_one_modified) {
326 // reconstruct the call expression with the modified arguments
327 auto modified_call = *call;
328 modified_call.arguments = std::move(modified_arguments);
329 return post_call(Expression(std::move(modified_call)), &expr);
330 }
331
332 return post_call(std::move(expr), nullptr);
333 }
334
335 } // namespace compute
336 } // namespace arrow