]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/python/pyarrow/tests/test_gandiva.py
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / python / pyarrow / tests / test_gandiva.py
CommitLineData
1d09f67e
TL
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
17
18import datetime
19import pytest
20
21import pyarrow as pa
22
23
@pytest.mark.gandiva
def test_tree_exp_builder():
    """Build ``if(a > b, a, b)`` with TreeExprBuilder, compile it into a
    projector, and check both the generated IR and the evaluated result."""
    import pyarrow.gandiva as gandiva

    builder = gandiva.TreeExprBuilder()

    field_a = pa.field('a', pa.int32())
    field_b = pa.field('b', pa.int32())

    schema = pa.schema([field_a, field_b])

    field_result = pa.field('res', pa.int32())

    node_a = builder.make_field(field_a)
    node_b = builder.make_field(field_b)

    assert node_a.return_type() == field_a.type

    condition = builder.make_function("greater_than", [node_a, node_b],
                                      pa.bool_())
    if_node = builder.make_if(condition, node_a, node_b, pa.int32())

    expr = builder.make_expression(if_node, field_result)

    assert expr.result().type == pa.int32()

    projector = gandiva.make_projector(
        schema, [expr], pa.default_memory_pool())

    # Gandiva generates compute kernel function named `@expr_X`
    assert "@expr_" in projector.llvm_ir

    a = pa.array([10, 12, -20, 5], type=pa.int32())
    b = pa.array([5, 15, 15, 17], type=pa.int32())
    # The if-node picks `a` where a > b and `b` otherwise.
    e = pa.array([10, 15, 15, 17], type=pa.int32())
    input_batch = pa.RecordBatch.from_arrays([a, b], names=['a', 'b'])

    r, = projector.evaluate(input_batch)
    assert r.equals(e)
63
64
@pytest.mark.gandiva
def test_table():
    """Project ``c = a + b`` over a two-column float64 table."""
    import pyarrow.gandiva as gandiva

    table = pa.Table.from_arrays([pa.array([1.0, 2.0]), pa.array([3.0, 4.0])],
                                 ['a', 'b'])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field("a"))
    node_b = builder.make_field(table.schema.field("b"))

    # Named `sum_node` rather than `sum` so the builtin is not shadowed.
    sum_node = builder.make_function("add", [node_a, node_b], pa.float64())

    field_result = pa.field("c", pa.float64())
    expr = builder.make_expression(sum_node, field_result)

    projector = gandiva.make_projector(
        table.schema, [expr], pa.default_memory_pool())

    # TODO: Add .evaluate function which can take Tables instead of
    # RecordBatches
    r, = projector.evaluate(table.to_batches()[0])

    e = pa.array([4.0, 6.0])
    assert r.equals(e)
90
91
@pytest.mark.gandiva
def test_filter():
    """Filter rows where ``a < 1000.0`` and check the selection vector."""
    import pyarrow.gandiva as gandiva

    table = pa.Table.from_arrays([pa.array([1.0 * i for i in range(10000)])],
                                 ['a'])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field("a"))
    thousand = builder.make_literal(1000.0, pa.float64())
    cond = builder.make_function("less_than", [node_a, thousand], pa.bool_())
    condition = builder.make_condition(cond)

    assert condition.result().type == pa.bool_()

    # Trailing underscore avoids shadowing the builtin `filter`.
    filter_ = gandiva.make_filter(table.schema, condition)
    # Gandiva generates compute kernel function named `@expr_X`
    assert "@expr_" in filter_.llvm_ir

    result = filter_.evaluate(table.to_batches()[0], pa.default_memory_pool())
    # The selection vector holds the uint32 indices 0..999 of matching rows.
    assert result.to_array().equals(pa.array(range(1000), type=pa.uint32()))
113
114
@pytest.mark.gandiva
def test_in_expr():
    """Filter with ``a IN (...)`` over string, int32 and int64 columns."""
    import pyarrow.gandiva as gandiva

    builder = gandiva.TreeExprBuilder()

    def selected_rows(table, values, value_type):
        # Build `a IN values` against column "a" of `table`, evaluate it as a
        # filter over the table's first batch and return the selection
        # vector as an array of row indices.
        node_a = builder.make_field(table.schema.field("a"))
        cond = builder.make_in_expression(node_a, values, value_type)
        condition = builder.make_condition(cond)
        filter_ = gandiva.make_filter(table.schema, condition)
        result = filter_.evaluate(table.to_batches()[0],
                                  pa.default_memory_pool())
        return result.to_array()

    # string
    arr = pa.array(["ga", "an", "nd", "di", "iv", "va"])
    table = pa.Table.from_arrays([arr], ["a"])
    result = selected_rows(table, ["an", "nd"], pa.string())
    assert result.equals(pa.array([1, 2], type=pa.uint32()))

    # int32
    arr = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
    table = pa.Table.from_arrays([arr.cast(pa.int32())], ["a"])
    result = selected_rows(table, [1, 5], pa.int32())
    assert result.equals(pa.array([1, 3, 4, 8], type=pa.uint32()))

    # int64: the same values, uncast (pyarrow infers int64 for Python ints)
    table = pa.Table.from_arrays([arr], ["a"])
    result = selected_rows(table, [1, 5], pa.int64())
    assert result.equals(pa.array([1, 3, 4, 8], type=pa.uint32()))
150
151
@pytest.mark.skip(reason="Gandiva C++ did not have *real* binary, "
                  "time and date support.")
def test_in_expr_todo():
    """IN-expression filters over binary, timestamp, time and date columns.

    Skipped: per the skip reason and TODO below, Gandiva does not yet
    support these types in an IN clause.
    """
    import pyarrow.gandiva as gandiva
    # TODO: Implement reasonable support for timestamp, time & date.
    # Current exceptions:
    # pyarrow.lib.ArrowException: ExpressionValidationError:
    # Evaluation expression for IN clause returns XXXX values are of typeXXXX

    def utc_naive(timestamp):
        # Naive datetime holding the UTC wall-clock time; same value as the
        # deprecated (Python 3.12+) datetime.datetime.utcfromtimestamp().
        return datetime.datetime.fromtimestamp(
            timestamp, tz=datetime.timezone.utc).replace(tzinfo=None)

    # binary
    arr = pa.array([b"ga", b"an", b"nd", b"di", b"iv", b"va"])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field("a"))
    cond = builder.make_in_expression(node_a, [b'an', b'nd'], pa.binary())
    condition = builder.make_condition(cond)

    filter_ = gandiva.make_filter(table.schema, condition)
    result = filter_.evaluate(table.to_batches()[0], pa.default_memory_pool())
    assert result.to_array().equals(pa.array([1, 2], type=pa.uint32()))

    # timestamp
    datetime_1 = utc_naive(1542238951.621877)
    datetime_2 = utc_naive(1542238911.621877)
    datetime_3 = utc_naive(1542238051.621877)

    arr = pa.array([datetime_1, datetime_2, datetime_3])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field("a"))
    cond = builder.make_in_expression(node_a, [datetime_2], pa.timestamp('ms'))
    condition = builder.make_condition(cond)

    filter_ = gandiva.make_filter(table.schema, condition)
    result = filter_.evaluate(table.to_batches()[0], pa.default_memory_pool())
    assert list(result.to_array()) == [1]

    # time
    time_1 = datetime_1.time()
    time_2 = datetime_2.time()
    time_3 = datetime_3.time()

    arr = pa.array([time_1, time_2, time_3])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field("a"))
    cond = builder.make_in_expression(node_a, [time_2], pa.time64('ms'))
    condition = builder.make_condition(cond)

    filter_ = gandiva.make_filter(table.schema, condition)
    result = filter_.evaluate(table.to_batches()[0], pa.default_memory_pool())
    assert list(result.to_array()) == [1]

    # date
    date_1 = datetime_1.date()
    date_2 = datetime_2.date()
    date_3 = datetime_3.date()

    arr = pa.array([date_1, date_2, date_3])
    table = pa.Table.from_arrays([arr], ["a"])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field("a"))
    cond = builder.make_in_expression(node_a, [date_2], pa.date32())
    condition = builder.make_condition(cond)

    filter_ = gandiva.make_filter(table.schema, condition)
    result = filter_.evaluate(table.to_batches()[0], pa.default_memory_pool())
    assert list(result.to_array()) == [1]
224
225
@pytest.mark.gandiva
def test_boolean():
    """Filter with a compound predicate ``(a < 50 and a > b) or b < 11``."""
    import pyarrow.gandiva as gandiva

    table = pa.Table.from_arrays([
        pa.array([1., 31., 46., 3., 57., 44., 22.]),
        pa.array([5., 45., 36., 73., 83., 23., 76.])],
        ['a', 'b'])

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field("a"))
    node_b = builder.make_field(table.schema.field("b"))
    fifty = builder.make_literal(50.0, pa.float64())
    eleven = builder.make_literal(11.0, pa.float64())

    cond_1 = builder.make_function("less_than", [node_a, fifty], pa.bool_())
    cond_2 = builder.make_function("greater_than", [node_a, node_b],
                                   pa.bool_())
    cond_3 = builder.make_function("less_than", [node_b, eleven], pa.bool_())
    cond = builder.make_or([builder.make_and([cond_1, cond_2]), cond_3])
    condition = builder.make_condition(cond)

    # Trailing underscore avoids shadowing the builtin `filter`.
    filter_ = gandiva.make_filter(table.schema, condition)
    result = filter_.evaluate(table.to_batches()[0], pa.default_memory_pool())
    # Row 0 passes via b < 11; rows 2 and 5 pass via a < 50 and a > b.
    assert result.to_array().equals(pa.array([0, 2, 5], type=pa.uint32()))
251
252
@pytest.mark.gandiva
def test_literals():
    """make_literal accepts every supported type, given either as a pyarrow
    DataType or as its string alias, and raises TypeError on bad input."""
    import pyarrow.gandiva as gandiva

    builder = gandiva.TreeExprBuilder()

    # (python value, pyarrow type, equivalent type name)
    cases = [
        (True, pa.bool_(), "bool"),
        (0, pa.uint8(), "uint8"),
        (1, pa.uint16(), "uint16"),
        (2, pa.uint32(), "uint32"),
        (3, pa.uint64(), "uint64"),
        (4, pa.int8(), "int8"),
        (5, pa.int16(), "int16"),
        (6, pa.int32(), "int32"),
        (7, pa.int64(), "int64"),
        (8.0, pa.float32(), "float32"),
        (9.0, pa.float64(), "float64"),
        ("hello", pa.string(), "string"),
        (b"world", pa.binary(), "binary"),
    ]

    # Every DataType-based literal first, then the string-alias variants.
    for value, arrow_type, _ in cases:
        builder.make_literal(value, arrow_type)
    for value, _, type_name in cases:
        builder.make_literal(value, type_name)

    # A value mismatching its type, and a missing type, are both rejected.
    with pytest.raises(TypeError):
        builder.make_literal("hello", pa.int64())
    with pytest.raises(TypeError):
        builder.make_literal(True, None)
291
292
@pytest.mark.gandiva
def test_regex():
    """Project a boolean column from a ``like`` match against '%spark%'."""
    import pyarrow.gandiva as gandiva

    words = pa.array(["park", "sparkle", "bright spark and fire", "spark"],
                     type=pa.string())
    table = pa.Table.from_arrays([words], names=['a'])

    builder = gandiva.TreeExprBuilder()
    column = builder.make_field(table.schema.field("a"))
    pattern = builder.make_literal("%spark%", pa.string())
    like_node = builder.make_function("like", [column, pattern], pa.bool_())

    output_field = pa.field("b", pa.bool_())
    projector = gandiva.make_projector(
        table.schema,
        [builder.make_expression(like_node, output_field)],
        pa.default_memory_pool())

    matched, = projector.evaluate(table.to_batches()[0])
    # Only "park" lacks the substring "spark".
    expected = pa.array([False, True, True, True], type=pa.bool_())
    assert matched.equals(expected)
315
316
@pytest.mark.gandiva
def test_get_registered_function_signatures():
    """Check the shape of entries returned by Gandiva's function registry."""
    import pyarrow.gandiva as gandiva
    signatures = gandiva.get_registered_function_signatures()

    # Fail with a clear message instead of an IndexError on an empty registry.
    assert signatures, "no registered Gandiva function signatures"
    first = signatures[0]

    assert isinstance(first.return_type(), pa.DataType)
    assert isinstance(first.param_types(), list)
    assert hasattr(first, "name")
325
326
@pytest.mark.gandiva
def test_filter_project():
    """Filter with ``a > b``, then project ``if(b < c, b, c)`` over only the
    selected rows via a UINT32 selection vector."""
    import pyarrow.gandiva as gandiva
    mpool = pa.default_memory_pool()
    # Create a table with some sample data
    array0 = pa.array([10, 12, -20, 5, 21, 29], pa.int32())
    array1 = pa.array([5, 15, 15, 17, 12, 3], pa.int32())
    array2 = pa.array([1, 25, 11, 30, -21, None], pa.int32())

    table = pa.Table.from_arrays([array0, array1, array2], ['a', 'b', 'c'])
    # The filter and the projector must run over the same batch.
    batch = table.to_batches()[0]

    field_result = pa.field("res", pa.int32())

    builder = gandiva.TreeExprBuilder()
    node_a = builder.make_field(table.schema.field("a"))
    node_b = builder.make_field(table.schema.field("b"))
    node_c = builder.make_field(table.schema.field("c"))

    greater_than_function = builder.make_function("greater_than",
                                                  [node_a, node_b], pa.bool_())
    filter_condition = builder.make_condition(
        greater_than_function)

    project_condition = builder.make_function("less_than",
                                              [node_b, node_c], pa.bool_())
    if_node = builder.make_if(project_condition,
                              node_b, node_c, pa.int32())
    expr = builder.make_expression(if_node, field_result)

    # Build a filter for the expressions.
    # Trailing underscore avoids shadowing the builtin `filter`.
    filter_ = gandiva.make_filter(table.schema, filter_condition)

    # Build a projector for the expressions.
    projector = gandiva.make_projector(
        table.schema, [expr], mpool, "UINT32")

    # Evaluate filter
    selection_vector = filter_.evaluate(batch, mpool)

    # Evaluate project
    r, = projector.evaluate(batch, selection_vector)

    # Rows 0, 4 and 5 satisfy a > b; row 5 has a null `c`, giving null.
    exp = pa.array([1, -21, None], pa.int32())
    assert r.equals(exp)
372
373
@pytest.mark.gandiva
def test_to_string():
    """Check str() rendering of literal, field, function and AND nodes."""
    import pyarrow.gandiva as gandiva
    builder = gandiva.TreeExprBuilder()

    # Only the prefix of a double literal's rendering is checked.
    double_literal = builder.make_literal(2.0, pa.float64())
    assert str(double_literal).startswith('(const double) 2 raw(')

    exact_renderings = [
        (builder.make_literal(2, pa.int64()), '(const int64) 2'),
        (builder.make_field(pa.field('x', pa.float64())), '(double) x'),
        (builder.make_field(pa.field('y', pa.string())), '(string) y'),
    ]
    for node, rendered in exact_renderings:
        assert str(node) == rendered

    z_node = builder.make_field(pa.field('z', pa.bool_()))
    not_z = builder.make_function('not', [z_node], pa.bool_())
    assert str(not_z) == 'bool not((bool) z)'

    y_node = builder.make_field(pa.field('y', pa.bool_()))
    not_z_and_y = builder.make_and([not_z, y_node])
    assert str(not_z_and_y) == 'bool not((bool) z) && (bool) y'
386 func_node = builder.make_function('not', [field_z], pa.bool_())
387 assert str(func_node) == 'bool not((bool) z)'
388
389 field_y = builder.make_field(pa.field('y', pa.bool_()))
390 and_node = builder.make_and([func_node, field_y])
391 assert str(and_node) == 'bool not((bool) z) && (bool) y'