]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/python/pyarrow/tests/strategies.py
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / python / pyarrow / tests / strategies.py
CommitLineData
1d09f67e
TL
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import datetime
19
20import pytz
21import hypothesis as h
22import hypothesis.strategies as st
23import hypothesis.extra.numpy as npst
24import hypothesis.extra.pytz as tzst
25import numpy as np
26
27import pyarrow as pa
28
29
# TODO(kszucs): alphanum_text, surrogate_text
# Text whose characters are restricted to code points 0x41 through 0x7E.
custom_text = st.text(
    alphabet=st.characters(
        min_codepoint=0x41,
        max_codepoint=0x7E
    )
)

# Each of the strategies below always yields the same single type.
null_type = st.just(pa.null())
bool_type = st.just(pa.bool_())

binary_type = st.just(pa.binary())
string_type = st.just(pa.string())
large_binary_type = st.just(pa.large_binary())
large_string_type = st.just(pa.large_string())
# Calls pa.binary with a drawn length in the range [0, 16].
fixed_size_binary_type = st.builds(
    pa.binary,
    st.integers(min_value=0, max_value=16)
)
# Any of the binary and string type strategies defined above.
binary_like_types = st.one_of(
    binary_type,
    string_type,
    large_binary_type,
    large_string_type,
    fixed_size_binary_type
)
56
# Integer types, grouped by signedness and then combined.
signed_integer_types = st.sampled_from(
    [make() for make in (pa.int8, pa.int16, pa.int32, pa.int64)]
)
unsigned_integer_types = st.sampled_from(
    [make() for make in (pa.uint8, pa.uint16, pa.uint32, pa.uint64)]
)
integer_types = st.one_of(signed_integer_types, unsigned_integer_types)

floating_types = st.sampled_from(
    [make() for make in (pa.float16, pa.float32, pa.float64)]
)
# Precision and scale are drawn independently of each other.
decimal128_type = st.builds(
    pa.decimal128,
    precision=st.integers(min_value=1, max_value=38),
    scale=st.integers(min_value=1, max_value=38),
)
decimal256_type = st.builds(
    pa.decimal256,
    precision=st.integers(min_value=1, max_value=76),
    scale=st.integers(min_value=1, max_value=76),
)
numeric_types = st.one_of(
    integer_types,
    floating_types,
    decimal128_type,
    decimal256_type,
)
88
date_types = st.sampled_from([pa.date32(), pa.date64()])
# time32 is sampled with the 's'/'ms' units, time64 with 'us'/'ns'.
time_types = st.sampled_from(
    [pa.time32(unit) for unit in ('s', 'ms')] +
    [pa.time64(unit) for unit in ('us', 'ns')]
)
# Timestamp types with a drawn unit and a time zone drawn from
# hypothesis.extra.pytz.
timestamp_types = st.builds(
    pa.timestamp,
    unit=st.sampled_from(['s', 'ms', 'us', 'ns']),
    tz=tzst.timezones(),
)
duration_types = st.builds(
    pa.duration,
    st.sampled_from(['s', 'ms', 'us', 'ns']),
)
# There is a single interval type to offer, so it is wrapped with st.just.
# st.sampled_from expects a collection of candidates and must not be handed
# a bare DataType instance.
interval_types = st.just(pa.month_day_nano_interval())
# Any of the date, time, timestamp, duration and interval type strategies.
temporal_types = st.one_of(
    date_types,
    time_types,
    timestamp_types,
    duration_types,
    interval_types
)

# Union of all the non-nested type strategies defined above.
primitive_types = st.one_of(
    null_type,
    bool_type,
    numeric_types,
    temporal_types,
    binary_like_types
)

# Dictionaries mapping text keys to text values; used as field metadata.
metadata = st.dictionaries(st.text(), st.text())
128
129
@st.composite
def fields(draw, type_strategy=primitive_types):
    """Build a ``pa.field`` from drawn components.

    The name is drawn from ``custom_text``, the type from ``type_strategy``
    and the metadata from ``metadata``. A field of null type is always
    nullable; for any other type the nullable flag is drawn as a boolean.
    """
    field_name = draw(custom_text)
    field_type = draw(type_strategy)
    # Short-circuits for null types, so no boolean is drawn in that case.
    is_nullable = pa.types.is_null(field_type) or draw(st.booleans())
    field_meta = draw(metadata)
    return pa.field(
        field_name,
        type=field_type,
        nullable=is_nullable,
        metadata=field_meta,
    )
140
141
def list_types(item_strategy=primitive_types):
    """Return a strategy for list, large list and fixed size list types.

    The item type of every variant is drawn from ``item_strategy``. The
    fixed size variant calls ``pa.list_`` with an additional size drawn
    from the range [0, 16].
    """
    variable_lists = st.builds(pa.list_, item_strategy)
    large_lists = st.builds(pa.large_list, item_strategy)
    fixed_size_lists = st.builds(
        pa.list_,
        item_strategy,
        st.integers(min_value=0, max_value=16),
    )
    return st.one_of(variable_lists, large_lists, fixed_size_lists)
152
153
@st.composite
def struct_types(draw, item_strategy=primitive_types):
    """Build a ``pa.struct`` type from a drawn list of fields.

    Field types come from ``item_strategy``. Draws that contain duplicated
    field names are discarded.
    """
    members = draw(st.lists(fields(item_strategy)))
    names = [member.name for member in members]
    # check that field names are unique, see ARROW-9997
    h.assume(len(names) == len(set(names)))
    return pa.struct(members)
162
163
def dictionary_types(key_strategy=None, value_strategy=None):
    """Return a strategy building ``pa.dictionary`` types.

    Parameters
    ----------
    key_strategy : strategy, optional
        Strategy for the index type. Defaults to ``signed_integer_types``.
    value_strategy : strategy, optional
        Strategy for the value type. Defaults to booleans, integers,
        float32/float64, binary, string and fixed size binary types.
    """
    # Compare against None explicitly instead of relying on the truthiness
    # of a strategy object.
    if key_strategy is None:
        key_strategy = signed_integer_types
    if value_strategy is None:
        value_strategy = st.one_of(
            bool_type,
            integer_types,
            st.sampled_from([pa.float32(), pa.float64()]),
            binary_type,
            string_type,
            fixed_size_binary_type,
        )
    return st.builds(pa.dictionary, key_strategy, value_strategy)
175
176
@st.composite
def map_types(draw, key_strategy=primitive_types,
              item_strategy=primitive_types):
    """Build a ``pa.map_`` type from a drawn key type and item type.

    Draws whose key type is the null type are discarded.
    """
    key_type = draw(key_strategy)
    # The key is validated before the item type is drawn.
    h.assume(not pa.types.is_null(key_type))
    return pa.map_(key_type, draw(item_strategy))
184
185
186# union type
187# extension type
188
189
def schemas(type_strategy=primitive_types, max_fields=None):
    """Return a strategy building ``pa.schema`` objects.

    The schema is built from a list of fields whose types are drawn from
    ``type_strategy``; ``max_fields`` is passed as the list's ``max_size``.
    """
    return st.builds(
        pa.schema,
        st.lists(fields(type_strategy), max_size=max_fields),
    )
193
194
# Deferred so that the definition can refer to itself for nested list and
# struct types.
all_types = st.deferred(
    lambda: st.one_of(
        primitive_types,
        list_types(),
        struct_types(),
        dictionary_types(),
        map_types(),
        list_types(all_types),
        struct_types(all_types),
    )
)
all_fields = fields(all_types)
all_schemas = schemas(all_types)
208
209
# Length (0 to 20) drawn when a caller does not provide an explicit size.
_default_array_sizes = st.integers(min_value=0, max_value=20)
211
212
@st.composite
def _pylist(draw, value_type, size, nullable=True):
    """Draw an array of ``value_type`` and return it as a Python list.

    Note that ``nullable`` is accepted but not forwarded: the underlying
    array is always drawn with ``nullable=False``.
    """
    return draw(arrays(value_type, size=size, nullable=False)).to_pylist()
217
218
@st.composite
def _pymap(draw, key_type, value_type, size, nullable=True):
    """Draw a list of ``(key, value)`` tuples.

    ``size`` is a strategy: the number of pairs is drawn from it, then that
    many keys and that many values are generated and zipped together.
    """
    n_pairs = draw(size)
    drawn_keys = draw(_pylist(key_type, size=n_pairs, nullable=False))
    drawn_values = draw(
        _pylist(value_type, size=n_pairs, nullable=nullable)
    )
    return [pair for pair in zip(drawn_keys, drawn_values)]
225
226
@st.composite
def arrays(draw, type, size=None, nullable=True):
    """Draw a pyarrow array of the requested type.

    Parameters
    ----------
    type : pa.DataType or strategy
        The array type, or a strategy the type is drawn from.
    size : int or strategy, optional
        Length of the array, or a strategy the length is drawn from.
        When omitted the length is drawn from ``_default_array_sizes``.
    nullable : bool, default True
        Whether ``None`` may be generated as an element. Integer, floating
        point and struct arrays are built without consulting this flag.

    Raises
    ------
    TypeError
        If ``type`` is neither a DataType nor a strategy, or if ``size`` is
        not an integer, a strategy or None.
    NotImplementedError
        If no generation rule exists for the type.
    """
    if isinstance(type, st.SearchStrategy):
        ty = draw(type)
    elif isinstance(type, pa.DataType):
        ty = type
    else:
        raise TypeError('Type must be a pyarrow DataType')

    if isinstance(size, st.SearchStrategy):
        size = draw(size)
    elif size is None:
        size = draw(_default_array_sizes)
    elif not isinstance(size, int):
        raise TypeError('Size must be an integer')

    if pa.types.is_null(ty):
        # A null array only holds nulls, so it cannot be non-nullable.
        h.assume(nullable)
        value = st.none()
    elif pa.types.is_boolean(ty):
        value = st.booleans()
    elif pa.types.is_integer(ty):
        values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,)))
        return pa.array(values, type=ty)
    elif pa.types.is_floating(ty):
        values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,)))
        # Workaround ARROW-4952: no easy way to assert array equality
        # in a NaN-tolerant way.
        values[np.isnan(values)] = -42.0
        return pa.array(values, type=ty)
    elif pa.types.is_decimal(ty):
        # TODO(kszucs): properly limit the precision
        # value = st.decimals(places=type.scale, allow_infinity=False)
        h.reject()
    elif pa.types.is_time(ty):
        value = st.times()
    elif pa.types.is_date(ty):
        value = st.dates()
    elif pa.types.is_timestamp(ty):
        min_int64 = -(2**63)
        max_int64 = 2**63 - 1
        min_datetime = datetime.datetime.fromtimestamp(min_int64 // 10**9)
        max_datetime = datetime.datetime.fromtimestamp(max_int64 // 10**9)
        if ty.tz is None:
            # Timezone-naive timestamp type: generate naive datetimes
            # instead of failing on int(None).
            timezones = st.none()
        else:
            try:
                # The tz string may be a plain number of hours ...
                offset_hours = int(ty.tz)
                tz = pytz.FixedOffset(offset_hours * 60)
            except ValueError:
                # ... otherwise it is looked up as a time zone name.
                tz = pytz.timezone(ty.tz)
            timezones = st.just(tz)
        value = st.datetimes(timezones=timezones, min_value=min_datetime,
                             max_value=max_datetime)
    elif pa.types.is_duration(ty):
        value = st.timedeltas()
    elif pa.types.is_binary(ty) or pa.types.is_large_binary(ty):
        value = st.binary()
    elif pa.types.is_string(ty) or pa.types.is_large_string(ty):
        value = st.text()
    elif pa.types.is_fixed_size_binary(ty):
        value = st.binary(min_size=ty.byte_width, max_size=ty.byte_width)
    elif pa.types.is_list(ty) or pa.types.is_large_list(ty):
        value = _pylist(ty.value_type, size=size, nullable=nullable)
    elif pa.types.is_fixed_size_list(ty):
        value = _pylist(ty.value_type, size=ty.list_size, nullable=nullable)
    elif pa.types.is_dictionary(ty):
        values = _pylist(ty.value_type, size=size, nullable=nullable)
        return pa.array(draw(values), type=ty)
    elif pa.types.is_map(ty):
        value = _pymap(ty.key_type, ty.item_type, size=_default_array_sizes,
                       nullable=nullable)
    elif pa.types.is_struct(ty):
        h.assume(len(ty) > 0)
        # One child array of the requested length is drawn per field.
        struct_fields, child_arrays = [], []
        for field in ty:
            struct_fields.append(field)
            child_arrays.append(draw(arrays(field.type, size=size)))
        return pa.StructArray.from_arrays(child_arrays, fields=struct_fields)
    else:
        raise NotImplementedError(ty)

    if nullable:
        value = st.one_of(st.none(), value)
    values = st.lists(value, min_size=size, max_size=size)

    return pa.array(draw(values), type=ty)
312
313
@st.composite
def chunked_arrays(draw, type, min_chunks=0, max_chunks=None, chunk_size=None):
    """Draw a ``pa.chunked_array`` whose chunks all share one type.

    ``type`` may be a DataType or a strategy to draw it from. The number of
    chunks is bounded by ``min_chunks`` and ``max_chunks``; ``chunk_size``
    is passed to ``arrays`` as the size of each chunk. Struct types are
    discarded.
    """
    ty = draw(type) if isinstance(type, st.SearchStrategy) else type

    # TODO(kszucs): remove it, field metadata is not kept
    h.assume(not pa.types.is_struct(ty))

    chunk_lists = st.lists(
        arrays(ty, size=chunk_size),
        min_size=min_chunks,
        max_size=max_chunks,
    )
    return pa.chunked_array(draw(chunk_lists), type=ty)
326
327
@st.composite
def record_batches(draw, type, rows=None, max_fields=None):
    """Draw a ``pa.RecordBatch`` with a drawn schema.

    ``rows`` is the length of every column: an integer, a strategy to draw
    it from, or None to draw it from ``_default_array_sizes``. ``type`` and
    ``max_fields`` are forwarded to ``schemas``.

    Raises TypeError if ``rows`` is none of the accepted kinds.
    """
    if rows is None:
        rows = draw(_default_array_sizes)
    elif isinstance(rows, st.SearchStrategy):
        rows = draw(rows)
    elif not isinstance(rows, int):
        raise TypeError('Rows must be an integer')

    schema = draw(schemas(type, max_fields=max_fields))
    columns = []
    for column_field in schema:
        columns.append(draw(arrays(column_field.type, size=rows)))
    # TODO(kszucs): the names and schema arguments are not consistent with
    #               Table.from_array's arguments
    return pa.RecordBatch.from_arrays(columns, names=schema)
342
343
@st.composite
def tables(draw, type, rows=None, max_fields=None):
    """Draw a ``pa.Table`` with a drawn schema.

    ``rows`` is the length of every column: an integer, a strategy to draw
    it from, or None to draw it from ``_default_array_sizes``. ``type`` and
    ``max_fields`` are forwarded to ``schemas``.

    Raises TypeError if ``rows`` is none of the accepted kinds.
    """
    if rows is None:
        rows = draw(_default_array_sizes)
    elif isinstance(rows, st.SearchStrategy):
        rows = draw(rows)
    elif not isinstance(rows, int):
        raise TypeError('Rows must be an integer')

    schema = draw(schemas(type, max_fields=max_fields))
    columns = []
    for column_field in schema:
        columns.append(draw(arrays(column_field.type, size=rows)))
    return pa.Table.from_arrays(columns, schema=schema)
356
357
# Ready-made strategies drawing their types from ``all_types``.
all_arrays = arrays(all_types)
all_chunked_arrays = chunked_arrays(all_types)
all_record_batches = record_batches(all_types)
all_tables = tables(all_types)
362
363
# Define the same rules as above for pandas tests by excluding certain types
# from the generation because of known issues.

# Compared with ``primitive_types`` this leaves out float16, decimal256,
# fixed size binary, timestamp and duration types.
pandas_compatible_primitive_types = st.one_of(
    null_type,
    bool_type,
    integer_types,
    st.sampled_from([pa.float32(), pa.float64()]),
    decimal128_type,
    date_types,
    time_types,
    # Need to exclude timestamp and duration types otherwise hypothesis
    # discovers ARROW-10210
    # timestamp_types,
    # duration_types
    interval_types,
    binary_type,
    string_type,
    large_binary_type,
    large_string_type,
)

# Need to exclude floating point types otherwise hypothesis discovers
# ARROW-10211
pandas_compatible_dictionary_value_types = st.one_of(
    bool_type,
    integer_types,
    binary_type,
    string_type,
    fixed_size_binary_type,
)
395
396
def pandas_compatible_list_types(
    item_strategy=pandas_compatible_primitive_types
):
    """Return a strategy for list and large list types.

    The item type of both variants is drawn from ``item_strategy``.
    """
    # Need to exclude fixed size list type otherwise hypothesis discovers
    # ARROW-10194
    variable_lists = st.builds(pa.list_, item_strategy)
    large_lists = st.builds(pa.large_list, item_strategy)
    return st.one_of(variable_lists, large_lists)
406
407
# Deferred so that the definition can refer to itself for nested list and
# struct types.
pandas_compatible_types = st.deferred(
    lambda: st.one_of(
        pandas_compatible_primitive_types,
        pandas_compatible_list_types(pandas_compatible_primitive_types),
        struct_types(pandas_compatible_primitive_types),
        dictionary_types(
            value_strategy=pandas_compatible_dictionary_value_types
        ),
        # Recursive variants: lists and structs of any compatible type.
        pandas_compatible_list_types(pandas_compatible_types),
        struct_types(pandas_compatible_types)
    )
)