]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | # Licensed to the Apache Software Foundation (ASF) under one |
2 | # or more contributor license agreements. See the NOTICE file | |
3 | # distributed with this work for additional information | |
4 | # regarding copyright ownership. The ASF licenses this file | |
5 | # to you under the Apache License, Version 2.0 (the | |
6 | # "License"); you may not use this file except in compliance | |
7 | # with the License. You may obtain a copy of the License at | |
8 | # | |
9 | # http://www.apache.org/licenses/LICENSE-2.0 | |
10 | # | |
11 | # Unless required by applicable law or agreed to in writing, | |
12 | # software distributed under the License is distributed on an | |
13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | # KIND, either express or implied. See the License for the | |
15 | # specific language governing permissions and limitations | |
16 | # under the License. | |
17 | ||
18 | import datetime | |
19 | ||
20 | import pytz | |
21 | import hypothesis as h | |
22 | import hypothesis.strategies as st | |
23 | import hypothesis.extra.numpy as npst | |
24 | import hypothesis.extra.pytz as tzst | |
25 | import numpy as np | |
26 | ||
27 | import pyarrow as pa | |
28 | ||
29 | ||
# TODO(kszucs): alphanum_text, surrogate_text
# Text used for field names: characters from 'A' (0x41) through '~' (0x7E).
custom_text = st.text(
    alphabet=st.characters(min_codepoint=0x41, max_codepoint=0x7E),
)
37 | ||
# Single-valued strategies for the trivial types.
null_type = st.just(pa.null())
bool_type = st.just(pa.bool_())

# Binary-like types: variable width, large (64-bit offset) and fixed width.
binary_type = st.just(pa.binary())
string_type = st.just(pa.string())
large_binary_type = st.just(pa.large_binary())
large_string_type = st.just(pa.large_string())
# pa.binary(n) with an explicit byte width n in [0, 16] gives a fixed-size
# binary type.
fixed_size_binary_type = st.integers(min_value=0, max_value=16).map(pa.binary)
binary_like_types = st.one_of(
    binary_type, string_type, large_binary_type, large_string_type,
    fixed_size_binary_type,
)
56 | ||
# Fixed-width integer types of 8, 16, 32 and 64 bits, signed and unsigned.
signed_integer_types = st.sampled_from(
    [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
)
unsigned_integer_types = st.sampled_from(
    [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
)
integer_types = st.one_of(
    signed_integer_types,
    unsigned_integer_types,
)
70 | ||
floating_types = st.sampled_from([pa.float16(), pa.float32(), pa.float64()])

# Precision and scale are drawn independently of each other, each from
# 1 up to the type's maximum precision (38 for decimal128, 76 for
# decimal256).
decimal128_type = st.builds(
    pa.decimal128,
    precision=st.integers(min_value=1, max_value=38),
    scale=st.integers(min_value=1, max_value=38),
)
decimal256_type = st.builds(
    pa.decimal256,
    precision=st.integers(min_value=1, max_value=76),
    scale=st.integers(min_value=1, max_value=76),
)
numeric_types = st.one_of(
    integer_types,
    floating_types,
    decimal128_type,
    decimal256_type,
)
88 | ||
date_types = st.sampled_from([
    pa.date32(),
    pa.date64()
])
time_types = st.sampled_from([
    pa.time32('s'),
    pa.time32('ms'),
    pa.time64('us'),
    pa.time64('ns')
])
# Timestamps of every unit, always carrying a timezone drawn from pytz.
timestamp_types = st.builds(
    pa.timestamp,
    unit=st.sampled_from(['s', 'ms', 'us', 'ns']),
    tz=tzst.timezones()
)
duration_types = st.builds(
    pa.duration,
    st.sampled_from(['s', 'ms', 'us', 'ns'])
)
# st.sampled_from() requires an ordered collection; a single DataType is not
# one, so drawing from the previous sampled_from(...) raised InvalidArgument
# and broke every strategy built on temporal_types.  st.just() is the
# strategy for exactly one value.
interval_types = st.just(pa.month_day_nano_interval())
temporal_types = st.one_of(
    date_types,
    time_types,
    timestamp_types,
    duration_types,
    interval_types
)
118 | ||
# Every non-nested type the module knows how to generate.
primitive_types = st.one_of(
    null_type, bool_type, numeric_types, temporal_types, binary_like_types
)

# Arbitrary str -> str mappings attached to generated fields.
metadata = st.dictionaries(keys=st.text(), values=st.text())
128 | ||
129 | ||
@st.composite
def fields(draw, type_strategy=primitive_types):
    """Draw a ``pa.field`` whose type comes from *type_strategy*.

    The name is drawn from ``custom_text`` and the metadata from
    ``metadata``.  Null-typed fields are always nullable; any other type
    gets a randomly drawn nullability flag.
    """
    field_name = draw(custom_text)
    field_type = draw(type_strategy)
    # A null-typed field can only hold nulls, so it is forced nullable.
    nullable = True if pa.types.is_null(field_type) else draw(st.booleans())
    return pa.field(
        field_name,
        type=field_type,
        nullable=nullable,
        metadata=draw(metadata),
    )
140 | ||
141 | ||
def list_types(item_strategy=primitive_types):
    """Return a strategy for list, large_list and fixed-size list types.

    Item types are drawn from *item_strategy*; fixed-size lists get a
    list size between 0 and 16.
    """
    variable_size = st.builds(pa.list_, item_strategy)
    large = st.builds(pa.large_list, item_strategy)
    # pa.list_(value_type, list_size) with a size builds a fixed-size list.
    fixed_size = st.builds(
        pa.list_, item_strategy, st.integers(min_value=0, max_value=16)
    )
    return st.one_of(variable_size, large, fixed_size)
152 | ||
153 | ||
@st.composite
def struct_types(draw, item_strategy=primitive_types):
    """Draw a struct type whose children come from ``fields(item_strategy)``.

    Child field names are kept unique (see ARROW-9997).  Uniqueness is
    enforced while the field list is drawn (``unique_by``), instead of
    rejecting the whole example afterwards with ``h.assume``.  That approach
    threw away every struct whose generated names happened to collide.
    """
    child_fields = draw(
        st.lists(fields(item_strategy), unique_by=lambda field: field.name)
    )
    return pa.struct(child_fields)
162 | ||
163 | ||
def dictionary_types(key_strategy=None, value_strategy=None):
    """Return a strategy for dictionary types.

    Index types default to the signed integer types.  Value types default
    to bool, integers, float32/float64, binary, string and fixed-size
    binary.
    """
    if not key_strategy:
        key_strategy = signed_integer_types
    if not value_strategy:
        value_strategy = st.one_of(
            bool_type,
            integer_types,
            st.sampled_from([pa.float32(), pa.float64()]),
            binary_type,
            string_type,
            fixed_size_binary_type,
        )
    return st.builds(pa.dictionary, key_strategy, value_strategy)
175 | ||
176 | ||
@st.composite
def map_types(draw, key_strategy=primitive_types,
              item_strategy=primitive_types):
    """Draw a ``pa.map_`` type; examples with a null key type are rejected."""
    key_type = draw(key_strategy)
    # Null key types are discarded; map keys are generated as non-null
    # values elsewhere in this module.
    h.assume(not pa.types.is_null(key_type))
    return pa.map_(key_type, draw(item_strategy))
184 | ||
185 | ||
# Not generated yet: union types, extension types.


def schemas(type_strategy=primitive_types, max_fields=None):
    """Return a strategy for schemas built from ``fields(type_strategy)``.

    ``max_fields`` caps the number of fields; ``None`` leaves it unbounded.
    """
    field_lists = st.lists(fields(type_strategy), max_size=max_fields)
    return field_lists.map(pa.schema)
193 | ||
194 | ||
# Recursive strategy: the nested list/struct alternatives draw their children
# from all_types itself, which is why it is wrapped in st.deferred.
all_types = st.deferred(
    lambda: st.one_of(
        primitive_types,
        list_types(),
        struct_types(),
        dictionary_types(),
        map_types(),
        list_types(all_types),
        struct_types(all_types),
    )
)
all_fields = fields(type_strategy=all_types)
all_schemas = schemas(type_strategy=all_types)
208 | ||
209 | ||
# Array/row length used when callers do not specify one: 0..20 inclusive.
_default_array_sizes = st.integers(0, 20)
211 | ||
212 | ||
@st.composite
def _pylist(draw, value_type, size, nullable=True):
    """Draw a Python list of *size* values of arrow type *value_type*.

    The values are generated through ``arrays`` and converted with
    ``to_pylist``.  *nullable* is forwarded to ``arrays``, so callers that
    need non-null values (e.g. map keys) get them and nullable callers may
    receive ``None`` entries.  Previously the flag was ignored and
    ``nullable=False`` was always passed.  Because ``arrays`` rejects null
    types unless nullable, that made child types such as ``list<null>``
    impossible to generate.
    """
    arr = draw(arrays(value_type, size=size, nullable=nullable))
    return arr.to_pylist()
217 | ||
218 | ||
@st.composite
def _pymap(draw, key_type, value_type, size, nullable=True):
    """Draw a list of ``(key, value)`` pairs representing one map value.

    *size* is a strategy giving the number of pairs.  Keys are requested
    with ``nullable=False``; *nullable* is passed through for the values.
    """
    n_entries = draw(size)
    keys = draw(_pylist(key_type, size=n_entries, nullable=False))
    values = draw(_pylist(value_type, size=n_entries, nullable=nullable))
    return [(key, value) for key, value in zip(keys, values)]
225 | ||
226 | ||
def _timestamp_values(ty):
    """Return a strategy of Python datetimes that fit timestamp type *ty*."""
    # The nanosecond unit has the narrowest range, so values are bounded by
    # what an int64 count of nanoseconds since the epoch can hold; coarser
    # units cover that range too.  The bounds are computed from a naive UTC
    # epoch rather than ``fromtimestamp`` (which uses the local timezone and
    # can fail for pre-1970 values on some platforms).  They are pulled in by
    # one day so that neither flooring to whole seconds nor applying the
    # drawn timezone's UTC offset can leave the representable range.
    min_int64 = -(2**63)
    max_int64 = 2**63 - 1
    epoch = datetime.datetime(1970, 1, 1)
    margin = datetime.timedelta(days=1)
    min_datetime = (epoch + datetime.timedelta(seconds=min_int64 // 10**9)
                    + margin)
    max_datetime = (epoch + datetime.timedelta(seconds=max_int64 // 10**9)
                    - margin)
    if ty.tz is None:
        # Timezone-naive timestamp type: generate naive datetimes
        # (``int(None)`` used to raise TypeError here).
        timezones = st.none()
    else:
        try:
            # An integer tz string is treated as a fixed offset in hours.
            tz = pytz.FixedOffset(int(ty.tz) * 60)
        except ValueError:
            tz = pytz.timezone(ty.tz)
        timezones = st.just(tz)
    return st.datetimes(timezones=timezones, min_value=min_datetime,
                        max_value=max_datetime)


@st.composite
def arrays(draw, type, size=None, nullable=True):
    """Draw a ``pa.Array`` of arrow type *type*.

    Parameters
    ----------
    type : pa.DataType or strategy of pa.DataType
        The array type, or a strategy that draws it.
    size : int, strategy of int, or None
        Array length; ``None`` draws it from ``_default_array_sizes``.
    nullable : bool
        Whether top-level values may be ``None``.  A null-typed array is
        rejected (``h.assume``) when this is False.

    Raises
    ------
    TypeError
        If *type* is neither a DataType nor a strategy, or *size* is not an
        int.
    NotImplementedError
        For types this strategy cannot populate.
    """
    if isinstance(type, st.SearchStrategy):
        ty = draw(type)
    elif isinstance(type, pa.DataType):
        ty = type
    else:
        raise TypeError('Type must be a pyarrow DataType')

    if isinstance(size, st.SearchStrategy):
        size = draw(size)
    elif size is None:
        size = draw(_default_array_sizes)
    elif not isinstance(size, int):
        raise TypeError('Size must be an integer')

    if pa.types.is_null(ty):
        h.assume(nullable)
        value = st.none()
    elif pa.types.is_boolean(ty):
        value = st.booleans()
    elif pa.types.is_integer(ty):
        values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,)))
        return pa.array(values, type=ty)
    elif pa.types.is_floating(ty):
        values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,)))
        # Workaround ARROW-4952: no easy way to assert array equality
        # in a NaN-tolerant way.
        values[np.isnan(values)] = -42.0
        return pa.array(values, type=ty)
    elif pa.types.is_decimal(ty):
        # TODO(kszucs): properly limit the precision
        # value = st.decimals(places=ty.scale, allow_infinity=False)
        h.reject()
    elif pa.types.is_time(ty):
        value = st.times()
    elif pa.types.is_date(ty):
        value = st.dates()
    elif pa.types.is_timestamp(ty):
        value = _timestamp_values(ty)
    elif pa.types.is_duration(ty):
        value = st.timedeltas()
    elif pa.types.is_interval(ty):
        # month_day_nano_interval values as (months, days, nanoseconds)
        # tuples within their int32/int32/int64 storage ranges.  Without
        # this branch interval types (part of temporal_types) always hit
        # NotImplementedError.
        int32 = st.integers(min_value=-(2**31), max_value=2**31 - 1)
        int64 = st.integers(min_value=-(2**63), max_value=2**63 - 1)
        value = st.tuples(int32, int32, int64)
    elif pa.types.is_binary(ty) or pa.types.is_large_binary(ty):
        value = st.binary()
    elif pa.types.is_string(ty) or pa.types.is_large_string(ty):
        value = st.text()
    elif pa.types.is_fixed_size_binary(ty):
        value = st.binary(min_size=ty.byte_width, max_size=ty.byte_width)
    elif pa.types.is_list(ty) or pa.types.is_large_list(ty):
        value = _pylist(ty.value_type, size=size, nullable=nullable)
    elif pa.types.is_fixed_size_list(ty):
        value = _pylist(ty.value_type, size=ty.list_size, nullable=nullable)
    elif pa.types.is_dictionary(ty):
        values = _pylist(ty.value_type, size=size, nullable=nullable)
        return pa.array(draw(values), type=ty)
    elif pa.types.is_map(ty):
        value = _pymap(ty.key_type, ty.item_type, size=_default_array_sizes,
                       nullable=nullable)
    elif pa.types.is_struct(ty):
        h.assume(len(ty) > 0)
        # Named struct_fields so it does not shadow the module-level
        # ``fields`` strategy.
        struct_fields = list(ty)
        child_arrays = [draw(arrays(field.type, size=size))
                        for field in struct_fields]
        return pa.StructArray.from_arrays(child_arrays, fields=struct_fields)
    else:
        raise NotImplementedError(ty)

    if nullable:
        value = st.one_of(st.none(), value)
    values = st.lists(value, min_size=size, max_size=size)

    return pa.array(draw(values), type=ty)
312 | ||
313 | ||
@st.composite
def chunked_arrays(draw, type, min_chunks=0, max_chunks=None, chunk_size=None):
    """Draw a ``pa.ChunkedArray`` whose chunks are drawn from ``arrays``.

    *type* may be a DataType or a strategy of DataTypes; struct types are
    rejected.
    """
    chunk_type = draw(type) if isinstance(type, st.SearchStrategy) else type

    # TODO(kszucs): remove it, field metadata is not kept
    h.assume(not pa.types.is_struct(chunk_type))

    chunk_list = draw(
        st.lists(
            arrays(chunk_type, size=chunk_size),
            min_size=min_chunks,
            max_size=max_chunks,
        )
    )
    return pa.chunked_array(chunk_list, type=chunk_type)
326 | ||
327 | ||
@st.composite
def record_batches(draw, type, rows=None, max_fields=None):
    """Draw a ``pa.RecordBatch`` whose columns are typed by *type*.

    Parameters
    ----------
    type : strategy of pa.DataType
        Strategy for the column types (passed to ``schemas``).
    rows : int, strategy of int, or None
        Number of rows; ``None`` draws it from ``_default_array_sizes``.
    max_fields : int or None
        Upper bound on the number of columns.

    Raises
    ------
    TypeError
        If *rows* is not an int, a strategy or None.
    """
    if isinstance(rows, st.SearchStrategy):
        rows = draw(rows)
    elif rows is None:
        rows = draw(_default_array_sizes)
    elif not isinstance(rows, int):
        raise TypeError('Rows must be an integer')

    schema = draw(schemas(type, max_fields=max_fields))
    children = [draw(arrays(field.type, size=rows)) for field in schema]
    # Pass the drawn schema through ``schema=``, as ``tables`` does.
    # ``names=`` is documented to take a list of column names, not a Schema.
    return pa.RecordBatch.from_arrays(children, schema=schema)
342 | ||
343 | ||
@st.composite
def tables(draw, type, rows=None, max_fields=None):
    """Draw a ``pa.Table`` whose columns are typed by *type*.

    *rows* may be an int, a strategy of ints, or ``None`` (drawn from
    ``_default_array_sizes``); anything else raises TypeError.
    *max_fields* caps the number of columns.
    """
    if isinstance(rows, st.SearchStrategy):
        n_rows = draw(rows)
    elif rows is None:
        n_rows = draw(_default_array_sizes)
    elif isinstance(rows, int):
        n_rows = rows
    else:
        raise TypeError('Rows must be an integer')

    schema = draw(schemas(type, max_fields=max_fields))
    columns = [draw(arrays(col.type, size=n_rows)) for col in schema]
    return pa.Table.from_arrays(columns, schema=schema)
356 | ||
357 | ||
# Ready-made container strategies over every type produced by all_types.
all_arrays = arrays(type=all_types)
all_chunked_arrays = chunked_arrays(type=all_types)
all_record_batches = record_batches(type=all_types)
all_tables = tables(type=all_types)
362 | ||
363 | ||
# Variants of the strategies above for the pandas tests: types that trigger
# known issues are excluded from generation.

pandas_compatible_primitive_types = st.one_of(
    null_type,
    bool_type,
    integer_types,
    st.sampled_from([pa.float32(), pa.float64()]),
    decimal128_type,
    date_types,
    time_types,
    # timestamp_types and duration_types are excluded; otherwise hypothesis
    # discovers ARROW-10210.
    interval_types,
    binary_type,
    string_type,
    large_binary_type,
    large_string_type,
)

# Floating point value types are excluded; otherwise hypothesis discovers
# ARROW-10211.
pandas_compatible_dictionary_value_types = st.one_of(
    bool_type, integer_types, binary_type, string_type,
    fixed_size_binary_type,
)
395 | ||
396 | ||
def pandas_compatible_list_types(
    item_strategy=pandas_compatible_primitive_types
):
    """Return a strategy for list and large_list types of *item_strategy*.

    Fixed-size lists are excluded; otherwise hypothesis discovers
    ARROW-10194.
    """
    return st.one_of(
        st.builds(pa.list_, item_strategy),
        st.builds(pa.large_list, item_strategy),
    )
406 | ||
407 | ||
# Recursive strategy: the last two alternatives draw their children from
# pandas_compatible_types itself, hence st.deferred.
pandas_compatible_types = st.deferred(
    lambda: st.one_of(
        pandas_compatible_primitive_types,
        pandas_compatible_list_types(
            item_strategy=pandas_compatible_primitive_types
        ),
        struct_types(item_strategy=pandas_compatible_primitive_types),
        dictionary_types(
            value_strategy=pandas_compatible_dictionary_value_types
        ),
        pandas_compatible_list_types(item_strategy=pandas_compatible_types),
        struct_types(item_strategy=pandas_compatible_types),
    )
)