]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
17 | ||
18 | import datetime | |
19 | import pytest | |
20 | ||
21 | import pyarrow as pa | |
22 | ||
23 | ||
@pytest.mark.gandiva
def test_tree_exp_builder():
    """Project ``a if a > b else b`` over two int32 columns.

    Builds the expression tree node by node, checks the types reported by
    the field node and the final expression, compiles a projector and
    compares its output with the expected array.
    """
    import pyarrow.gandiva as gandiva

    int32 = pa.int32()
    tree = gandiva.TreeExprBuilder()

    lhs_field = pa.field('a', int32)
    rhs_field = pa.field('b', int32)
    input_schema = pa.schema([lhs_field, rhs_field])
    out_field = pa.field('res', int32)

    lhs = tree.make_field(lhs_field)
    rhs = tree.make_field(rhs_field)
    assert lhs.return_type() == lhs_field.type

    is_greater = tree.make_function("greater_than", [lhs, rhs], pa.bool_())
    pick = tree.make_if(is_greater, lhs, rhs, int32)
    expression = tree.make_expression(pick, out_field)
    assert expression.result().type == int32

    projector = gandiva.make_projector(
        input_schema, [expression], pa.default_memory_pool())

    # Gandiva generates compute kernel function named `@expr_X`
    assert "@expr_" in projector.llvm_ir

    batch = pa.RecordBatch.from_arrays(
        [pa.array([10, 12, -20, 5], type=int32),
         pa.array([5, 15, 15, 17], type=int32)],
        names=['a', 'b'])
    expected = pa.array([10, 15, 15, 17], type=int32)

    # The projector was built with a single expression, so exactly one
    # output array is unpacked.
    (result,) = projector.evaluate(batch)
    assert result.equals(expected)
63 | ||
64 | ||
@pytest.mark.gandiva
def test_table():
    """Add two float64 columns taken from a Table via a projector."""
    import pyarrow.gandiva as gandiva

    columns = [pa.array([1.0, 2.0]), pa.array([3.0, 4.0])]
    table = pa.Table.from_arrays(columns, ['a', 'b'])
    schema = table.schema

    tree = gandiva.TreeExprBuilder()
    operands = [tree.make_field(schema.field(name)) for name in ("a", "b")]
    total = tree.make_function("add", operands, pa.float64())
    expression = tree.make_expression(total, pa.field("c", pa.float64()))

    projector = gandiva.make_projector(
        schema, [expression], pa.default_memory_pool())

    # TODO: Add .evaluate function which can take Tables instead of
    # RecordBatches
    first_batch = table.to_batches()[0]
    (result,) = projector.evaluate(first_batch)

    assert result.equals(pa.array([4.0, 6.0]))
90 | ||
91 | ||
@pytest.mark.gandiva
def test_filter():
    """Select the rows of a float64 column whose value is below 1000."""
    import pyarrow.gandiva as gandiva

    values = pa.array([1.0 * i for i in range(10000)])
    table = pa.Table.from_arrays([values], ['a'])

    tree = gandiva.TreeExprBuilder()
    column = tree.make_field(table.schema.field("a"))
    limit = tree.make_literal(1000.0, pa.float64())
    below_limit = tree.make_function(
        "less_than", [column, limit], pa.bool_())
    condition = tree.make_condition(below_limit)

    assert condition.result().type == pa.bool_()

    row_filter = gandiva.make_filter(table.schema, condition)
    # Gandiva generates compute kernel function named `@expr_X`
    assert "@expr_" in row_filter.llvm_ir

    selection = row_filter.evaluate(
        table.to_batches()[0], pa.default_memory_pool())
    # The selection holds row indices: the first thousand rows match.
    expected_rows = pa.array(range(1000), type=pa.uint32())
    assert selection.to_array().equals(expected_rows)
113 | ||
114 | ||
@pytest.mark.gandiva
def test_in_expr():
    """Filter with IN-expressions over string, int32 and int64 values."""
    import pyarrow.gandiva as gandiva

    # A single builder is shared by all three cases below.
    tree = gandiva.TreeExprBuilder()

    def selected_rows(column, candidates, dtype):
        # Wrap `column` in a one-column table named "a", filter it on
        # `a IN candidates` and return the selected row indices.
        table = pa.Table.from_arrays([column], ["a"])
        node = tree.make_field(table.schema.field("a"))
        membership = tree.make_in_expression(node, candidates, dtype)
        condition = tree.make_condition(membership)
        row_filter = gandiva.make_filter(table.schema, condition)
        selection = row_filter.evaluate(
            table.to_batches()[0], pa.default_memory_pool())
        return selection.to_array()

    # string
    words = pa.array(["ga", "an", "nd", "di", "iv", "va"])
    assert selected_rows(words, ["an", "nd"], pa.string()).equals(
        pa.array([1, 2], type=pa.uint32()))

    digits = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
    hits = pa.array([1, 3, 4, 8], type=pa.uint32())

    # int32
    assert selected_rows(
        digits.cast(pa.int32()), [1, 5], pa.int32()).equals(hits)

    # int64 (the array is used as created, without a cast)
    assert selected_rows(digits, [1, 5], pa.int64()).equals(hits)
150 | ||
151 | ||
@pytest.mark.skip(reason="Gandiva C++ did not have *real* binary, "
                         "time and date support.")
def test_in_expr_todo():
    """Pending IN-expression cases for binary, timestamp, time and date."""
    import pyarrow.gandiva as gandiva
    # TODO: Implement reasonable support for timestamp, time & date.
    # Current exceptions:
    # pyarrow.lib.ArrowException: ExpressionValidationError:
    # Evaluation expression for IN clause returns XXXX values are of typeXXXX

    def selected_rows(column, candidates, dtype):
        # Each case gets its own table, builder and filter; the selected
        # row indices are returned as an array.
        table = pa.Table.from_arrays([column], ["a"])
        tree = gandiva.TreeExprBuilder()
        node = tree.make_field(table.schema.field("a"))
        membership = tree.make_in_expression(node, candidates, dtype)
        condition = tree.make_condition(membership)
        row_filter = gandiva.make_filter(table.schema, condition)
        selection = row_filter.evaluate(
            table.to_batches()[0], pa.default_memory_pool())
        return selection.to_array()

    # binary
    blobs = pa.array([b"ga", b"an", b"nd", b"di", b"iv", b"va"])
    assert selected_rows(blobs, [b'an', b'nd'], pa.binary()).equals(
        pa.array([1, 2], type=pa.uint32()))

    # timestamp
    stamps = [datetime.datetime.utcfromtimestamp(seconds)
              for seconds in (1542238951.621877,
                              1542238911.621877,
                              1542238051.621877)]
    assert list(selected_rows(
        pa.array(stamps), [stamps[1]], pa.timestamp('ms'))) == [1]

    # time
    clock_times = [stamp.time() for stamp in stamps]
    # NOTE(review): pyarrow documents only 'us' and 'ns' as time64 units;
    # confirm that 'ms' is intended before enabling this test.
    assert list(selected_rows(
        pa.array(clock_times), [clock_times[1]], pa.time64('ms'))) == [1]

    # date
    days = [stamp.date() for stamp in stamps]
    assert list(selected_rows(
        pa.array(days), [days[1]], pa.date32())) == [1]
224 | ||
225 | ||
@pytest.mark.gandiva
def test_boolean():
    """Filter rows with ``(a < 50 AND a > b) OR b < 11``."""
    import pyarrow.gandiva as gandiva

    table = pa.Table.from_arrays(
        [pa.array([1., 31., 46., 3., 57., 44., 22.]),
         pa.array([5., 45., 36., 73., 83., 23., 76.])],
        ['a', 'b'])
    boolean = pa.bool_()

    tree = gandiva.TreeExprBuilder()
    col_a = tree.make_field(table.schema.field("a"))
    col_b = tree.make_field(table.schema.field("b"))
    fifty = tree.make_literal(50.0, pa.float64())
    eleven = tree.make_literal(11.0, pa.float64())

    a_below_fifty = tree.make_function("less_than", [col_a, fifty], boolean)
    a_above_b = tree.make_function("greater_than", [col_a, col_b], boolean)
    b_below_eleven = tree.make_function("less_than", [col_b, eleven], boolean)

    both = tree.make_and([a_below_fifty, a_above_b])
    either = tree.make_or([both, b_below_eleven])
    condition = tree.make_condition(either)

    row_filter = gandiva.make_filter(table.schema, condition)
    selection = row_filter.evaluate(
        table.to_batches()[0], pa.default_memory_pool())

    expected_rows = pa.array([0, 2, 5], type=pa.uint32())
    assert selection.to_array().equals(expected_rows)
251 | ||
252 | ||
@pytest.mark.gandiva
def test_literals():
    """Create literal nodes for a range of types and reject bad input.

    Every sample is built twice: once with a ``DataType`` object and once
    with the type given as a string. A value/type mismatch and a ``None``
    type must both raise ``TypeError``.
    """
    import pyarrow.gandiva as gandiva

    tree = gandiva.TreeExprBuilder()

    # (value, DataType object, type name as string)
    samples = [
        (True, pa.bool_(), "bool"),
        (0, pa.uint8(), "uint8"),
        (1, pa.uint16(), "uint16"),
        (2, pa.uint32(), "uint32"),
        (3, pa.uint64(), "uint64"),
        (4, pa.int8(), "int8"),
        (5, pa.int16(), "int16"),
        (6, pa.int32(), "int32"),
        (7, pa.int64(), "int64"),
        (8.0, pa.float32(), "float32"),
        (9.0, pa.float64(), "float64"),
        ("hello", pa.string(), "string"),
        (b"world", pa.binary(), "binary"),
    ]

    # First pass: DataType objects.
    for value, dtype, _ in samples:
        tree.make_literal(value, dtype)

    # Second pass: the same types spelled as strings.
    for value, _, alias in samples:
        tree.make_literal(value, alias)

    with pytest.raises(TypeError):
        tree.make_literal("hello", pa.int64())
    with pytest.raises(TypeError):
        tree.make_literal(True, None)
291 | ||
292 | ||
@pytest.mark.gandiva
def test_regex():
    """Evaluate ``like`` with the pattern ``%spark%`` on a string column."""
    import pyarrow.gandiva as gandiva

    words = ["park", "sparkle", "bright spark and fire", "spark"]
    table = pa.Table.from_arrays(
        [pa.array(words, type=pa.string())], names=['a'])

    tree = gandiva.TreeExprBuilder()
    column = tree.make_field(table.schema.field("a"))
    pattern = tree.make_literal("%spark%", pa.string())
    matches = tree.make_function("like", [column, pattern], pa.bool_())
    expression = tree.make_expression(matches, pa.field("b", pa.bool_()))

    projector = gandiva.make_projector(
        table.schema, [expression], pa.default_memory_pool())

    (result,) = projector.evaluate(table.to_batches()[0])
    # Only "park" lacks the substring "spark".
    expected = pa.array([False, True, True, True], type=pa.bool_())
    assert result.equals(expected)
315 | ||
316 | ||
@pytest.mark.gandiva
def test_get_registered_function_signatures():
    """Sanity-check the first registered Gandiva function signature.

    Verifies that the registry listing is non-empty and that its first
    entry exposes a ``DataType`` return type, a list of parameter types
    and a ``name`` attribute.
    """
    import pyarrow.gandiva as gandiva

    signatures = gandiva.get_registered_function_signatures()
    # Fail with a readable assertion rather than an IndexError if the
    # registry listing comes back empty.
    assert len(signatures) > 0

    first = signatures[0]
    # isinstance instead of an exact type() match: pyarrow represents some
    # types with DataType subclasses (e.g. TimestampType), so an exact
    # match would make the test depend on which signature is listed first.
    assert isinstance(first.return_type(), pa.DataType)
    assert isinstance(first.param_types(), list)
    assert hasattr(first, "name")
325 | ||
326 | ||
@pytest.mark.gandiva
def test_filter_project():
    """Run a projector over only the rows selected by a filter.

    The filter keeps rows where ``a > b``; the projector computes
    ``b if b < c else c`` for those rows.
    """
    import pyarrow.gandiva as gandiva

    pool = pa.default_memory_pool()
    int32 = pa.int32()

    # Create a table with some sample data
    table = pa.Table.from_arrays(
        [pa.array([10, 12, -20, 5, 21, 29], int32),
         pa.array([5, 15, 15, 17, 12, 3], int32),
         pa.array([1, 25, 11, 30, -21, None], int32)],
        ['a', 'b', 'c'])
    schema = table.schema

    out_field = pa.field("res", int32)

    tree = gandiva.TreeExprBuilder()
    col_a, col_b, col_c = (
        tree.make_field(schema.field(name)) for name in ("a", "b", "c"))

    a_above_b = tree.make_function(
        "greater_than", [col_a, col_b], pa.bool_())
    filter_condition = tree.make_condition(a_above_b)

    b_below_c = tree.make_function("less_than", [col_b, col_c], pa.bool_())
    smaller = tree.make_if(b_below_c, col_b, col_c, int32)
    expression = tree.make_expression(smaller, out_field)

    # Build a filter for the expressions.
    row_filter = gandiva.make_filter(schema, filter_condition)

    # Build a projector for the expressions.
    projector = gandiva.make_projector(schema, [expression], pool, "UINT32")

    batch = table.to_batches()[0]

    # Evaluate filter
    selection = row_filter.evaluate(batch, pool)

    # Evaluate project
    (result,) = projector.evaluate(batch, selection)

    assert result.equals(pa.array([1, -21, None], int32))
372 | ||
373 | ||
@pytest.mark.gandiva
def test_to_string():
    """Check str() of literal, field, function and AND nodes."""
    import pyarrow.gandiva as gandiva

    tree = gandiva.TreeExprBuilder()

    # Literals. Only the prefix is checked for the double literal.
    double_literal = tree.make_literal(2.0, pa.float64())
    assert str(double_literal).startswith('(const double) 2 raw(')

    int_literal = tree.make_literal(2, pa.int64())
    assert str(int_literal) == '(const int64) 2'

    # Fields
    double_field = tree.make_field(pa.field('x', pa.float64()))
    assert str(double_field) == '(double) x'

    string_field = tree.make_field(pa.field('y', pa.string()))
    assert str(string_field) == '(string) y'

    # Function node
    bool_z = tree.make_field(pa.field('z', pa.bool_()))
    negated = tree.make_function('not', [bool_z], pa.bool_())
    assert str(negated) == 'bool not((bool) z)'

    # Boolean AND node
    bool_y = tree.make_field(pa.field('y', pa.bool_()))
    conjunction = tree.make_and([negated, bool_y])
    assert str(conjunction) == 'bool not((bool) z) && (bool) y'