# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
18 from collections
import namedtuple
, OrderedDict
27 from .util
import frombytes
, tobytes
, random_bytes
, random_utf8
def metadata_key_values(pairs):
    """Render (key, value) pairs as the JSON metadata representation."""
    entries = []
    for key, value in pairs:
        entries.append({'key': key, 'value': value})
    return entries
36 def __init__(self
, name
, *, nullable
=True, metadata
=None):
38 self
.nullable
= nullable
39 self
.metadata
= metadata
or []
44 ('type', self
._get
_type
()),
45 ('nullable', self
.nullable
),
46 ('children', self
._get
_children
()),
49 dct
= self
._get
_dictionary
()
51 entries
.append(('dictionary', dct
))
53 if self
.metadata
is not None and len(self
.metadata
) > 0:
54 entries
.append(('metadata', metadata_key_values(self
.metadata
)))
56 return OrderedDict(entries
)
58 def _get_dictionary(self
):
61 def _make_is_valid(self
, size
, null_probability
=0.4):
63 return (np
.random
.random_sample(size
) > null_probability
66 return np
.ones(size
, dtype
=np
.int8
)
71 def __init__(self
, name
, count
):
78 def _get_children(self
):
81 def _get_buffers(self
):
90 buffers
= self
._get
_buffers
()
91 entries
.extend(buffers
)
93 children
= self
._get
_children
()
95 entries
.append(('children', children
))
97 return OrderedDict(entries
)
100 class PrimitiveField(Field
):
102 def _get_children(self
):
106 class PrimitiveColumn(Column
):
108 def __init__(self
, name
, count
, is_valid
, values
):
109 super().__init
__(name
, count
)
110 self
.is_valid
= is_valid
113 def _encode_value(self
, x
):
116 def _get_buffers(self
):
118 ('VALIDITY', [int(v
) for v
in self
.is_valid
]),
119 ('DATA', list([self
._encode
_value
(x
) for x
in self
.values
]))
123 class NullColumn(Column
):
124 # This subclass is for readability only
128 class NullField(PrimitiveField
):
130 def __init__(self
, name
, metadata
=None):
131 super().__init
__(name
, nullable
=True,
135 return OrderedDict([('name', 'null')])
def generate_column(self, size, name=None):
    """Produce a NullColumn of *size* slots (null columns carry no data)."""
    column_name = name if name else self.name
    return NullColumn(column_name, size)
# Default bounds for generated test integers: the signed 32-bit range.
TEST_INT_MAX = (1 << 31) - 1
TEST_INT_MIN = -(1 << 31)
145 class IntegerField(PrimitiveField
):
147 def __init__(self
, name
, is_signed
, bit_width
, *, nullable
=True,
149 min_value
=TEST_INT_MIN
,
150 max_value
=TEST_INT_MAX
):
151 super().__init
__(name
, nullable
=nullable
,
153 self
.is_signed
= is_signed
154 self
.bit_width
= bit_width
155 self
.min_value
= min_value
156 self
.max_value
= max_value
158 def _get_generated_data_bounds(self
):
160 signed_iinfo
= np
.iinfo('int' + str(self
.bit_width
))
161 min_value
, max_value
= signed_iinfo
.min, signed_iinfo
.max
163 unsigned_iinfo
= np
.iinfo('uint' + str(self
.bit_width
))
164 min_value
, max_value
= 0, unsigned_iinfo
.max
166 lower_bound
= max(min_value
, self
.min_value
)
167 upper_bound
= min(max_value
, self
.max_value
)
168 return lower_bound
, upper_bound
173 ('isSigned', self
.is_signed
),
174 ('bitWidth', self
.bit_width
)
def generate_column(self, size, name=None):
    """Generate random integers within this field's configured bounds.

    The extremes of the range are always included (first two values)
    when there is room for them.
    """
    bounds = self._get_generated_data_bounds()
    return self.generate_range(size, bounds[0], bounds[1],
                               name=name, include_extremes=True)
182 def generate_range(self
, size
, lower
, upper
, name
=None,
183 include_extremes
=False):
184 values
= np
.random
.randint(lower
, upper
, size
=size
, dtype
=np
.int64
)
185 if include_extremes
and size
>= 2:
186 values
[:2] = [lower
, upper
]
187 values
= list(map(int if self
.bit_width
< 64 else str, values
))
189 is_valid
= self
._make
_is
_valid
(size
)
193 return PrimitiveColumn(name
, size
, is_valid
, values
)
196 class DateField(IntegerField
):
201 # 1/1/1 to 12/31/9999
203 DAY
: [-719162, 2932896],
204 MILLISECOND
: [-62135596800000, 253402214400000]
207 def __init__(self
, name
, unit
, *, nullable
=True, metadata
=None):
208 bit_width
= 32 if unit
== self
.DAY
else 64
210 min_value
, max_value
= self
._ranges
[unit
]
212 name
, True, bit_width
,
213 nullable
=nullable
, metadata
=metadata
,
214 min_value
=min_value
, max_value
=max_value
221 ('unit', 'DAY' if self
.unit
== self
.DAY
else 'MILLISECOND')
233 class TimeField(IntegerField
):
245 'us': [0, 86400000000],
246 'ns': [0, 86400000000000]
249 def __init__(self
, name
, unit
='s', *, nullable
=True,
251 min_val
, max_val
= self
._ranges
[unit
]
252 super().__init
__(name
, True, self
.BIT_WIDTHS
[unit
],
253 nullable
=nullable
, metadata
=metadata
,
254 min_value
=min_val
, max_value
=max_val
)
260 ('unit', TIMEUNIT_NAMES
[self
.unit
]),
261 ('bitWidth', self
.bit_width
)
265 class TimestampField(IntegerField
):
267 # 1/1/1 to 12/31/9999
269 's': [-62135596800, 253402214400],
270 'ms': [-62135596800000, 253402214400000],
271 'us': [-62135596800000000, 253402214400000000],
273 # Physical range for int64, ~584 years and change
274 'ns': [np
.iinfo('int64').min, np
.iinfo('int64').max]
277 def __init__(self
, name
, unit
='s', tz
=None, *, nullable
=True,
279 min_val
, max_val
= self
._ranges
[unit
]
280 super().__init
__(name
, True, 64,
290 ('name', 'timestamp'),
291 ('unit', TIMEUNIT_NAMES
[self
.unit
])
294 if self
.tz
is not None:
295 fields
.append(('timezone', self
.tz
))
297 return OrderedDict(fields
)
300 class DurationIntervalField(IntegerField
):
302 def __init__(self
, name
, unit
='s', *, nullable
=True,
304 min_val
, max_val
= np
.iinfo('int64').min, np
.iinfo('int64').max,
307 nullable
=nullable
, metadata
=metadata
,
308 min_value
=min_val
, max_value
=max_val
)
313 ('name', 'duration'),
314 ('unit', TIMEUNIT_NAMES
[self
.unit
])
317 return OrderedDict(fields
)
320 class YearMonthIntervalField(IntegerField
):
321 def __init__(self
, name
, *, nullable
=True, metadata
=None):
322 min_val
, max_val
= [-10000*12, 10000*12] # +/- 10000 years.
325 nullable
=nullable
, metadata
=metadata
,
326 min_value
=min_val
, max_value
=max_val
)
330 ('name', 'interval'),
331 ('unit', 'YEAR_MONTH'),
334 return OrderedDict(fields
)
337 class DayTimeIntervalField(PrimitiveField
):
338 def __init__(self
, name
, *, nullable
=True, metadata
=None):
339 super().__init
__(name
,
344 def numpy_type(self
):
350 ('name', 'interval'),
351 ('unit', 'DAY_TIME'),
354 def generate_column(self
, size
, name
=None):
355 min_day_value
, max_day_value
= -10000*366, 10000*366
356 values
= [{'days': random
.randint(min_day_value
, max_day_value
),
357 'milliseconds': random
.randint(-86400000, +86400000)}
358 for _
in range(size
)]
360 is_valid
= self
._make
_is
_valid
(size
)
363 return PrimitiveColumn(name
, size
, is_valid
, values
)
366 class MonthDayNanoIntervalField(PrimitiveField
):
367 def __init__(self
, name
, *, nullable
=True, metadata
=None):
368 super().__init
__(name
,
373 def numpy_type(self
):
379 ('name', 'interval'),
380 ('unit', 'MONTH_DAY_NANO'),
383 def generate_column(self
, size
, name
=None):
385 min_int_value
, max_int_value
= np
.iinfo(I32
).min, np
.iinfo(I32
).max
387 min_nano_val
, max_nano_val
= np
.iinfo(I64
).min, np
.iinfo(I64
).max,
388 values
= [{'months': random
.randint(min_int_value
, max_int_value
),
389 'days': random
.randint(min_int_value
, max_int_value
),
390 'nanoseconds': random
.randint(min_nano_val
, max_nano_val
)}
391 for _
in range(size
)]
393 is_valid
= self
._make
_is
_valid
(size
)
396 return PrimitiveColumn(name
, size
, is_valid
, values
)
399 class FloatingPointField(PrimitiveField
):
401 def __init__(self
, name
, bit_width
, *, nullable
=True,
403 super().__init
__(name
,
407 self
.bit_width
= bit_width
415 def numpy_type(self
):
416 return 'float' + str(self
.bit_width
)
420 ('name', 'floatingpoint'),
421 ('precision', self
.precision
)
424 def generate_column(self
, size
, name
=None):
425 values
= np
.random
.randn(size
) * 1000
426 values
= np
.round(values
, 3)
428 is_valid
= self
._make
_is
_valid
(size
)
431 return PrimitiveColumn(name
, size
, is_valid
, values
)
434 DECIMAL_PRECISION_TO_VALUE
= {
435 key
: (1 << (8 * i
- 1)) - 1 for i
, key
in enumerate(
436 [1, 3, 5, 7, 10, 12, 15, 17, 19, 22, 24, 27, 29, 32, 34, 36,
437 40, 42, 44, 50, 60, 70],
443 def decimal_range_from_precision(precision
):
444 assert 1 <= precision
<= 76
446 max_value
= DECIMAL_PRECISION_TO_VALUE
[precision
]
448 return decimal_range_from_precision(precision
- 1)
450 return ~max_value
, max_value
453 class DecimalField(PrimitiveField
):
454 def __init__(self
, name
, precision
, scale
, bit_width
, *,
455 nullable
=True, metadata
=None):
456 super().__init
__(name
, nullable
=True,
458 self
.precision
= precision
460 self
.bit_width
= bit_width
463 def numpy_type(self
):
469 ('precision', self
.precision
),
470 ('scale', self
.scale
),
471 ('bitWidth', self
.bit_width
),
474 def generate_column(self
, size
, name
=None):
475 min_value
, max_value
= decimal_range_from_precision(self
.precision
)
476 values
= [random
.randint(min_value
, max_value
) for _
in range(size
)]
478 is_valid
= self
._make
_is
_valid
(size
)
481 return DecimalColumn(name
, size
, is_valid
, values
, self
.bit_width
)
484 class DecimalColumn(PrimitiveColumn
):
486 def __init__(self
, name
, count
, is_valid
, values
, bit_width
):
487 super().__init
__(name
, count
, is_valid
, values
)
488 self
.bit_width
= bit_width
490 def _encode_value(self
, x
):
494 class BooleanField(PrimitiveField
):
498 return OrderedDict([('name', 'bool')])
501 def numpy_type(self
):
504 def generate_column(self
, size
, name
=None):
505 values
= list(map(bool, np
.random
.randint(0, 2, size
=size
)))
506 is_valid
= self
._make
_is
_valid
(size
)
509 return PrimitiveColumn(name
, size
, is_valid
, values
)
512 class FixedSizeBinaryField(PrimitiveField
):
514 def __init__(self
, name
, byte_width
, *, nullable
=True,
516 super().__init
__(name
, nullable
=nullable
,
518 self
.byte_width
= byte_width
521 def numpy_type(self
):
525 def column_class(self
):
526 return FixedSizeBinaryColumn
529 return OrderedDict([('name', 'fixedsizebinary'),
530 ('byteWidth', self
.byte_width
)])
532 def generate_column(self
, size
, name
=None):
533 is_valid
= self
._make
_is
_valid
(size
)
536 for i
in range(size
):
537 values
.append(random_bytes(self
.byte_width
))
541 return self
.column_class(name
, size
, is_valid
, values
)
544 class BinaryField(PrimitiveField
):
547 def numpy_type(self
):
551 def column_class(self
):
555 return OrderedDict([('name', 'binary')])
557 def _random_sizes(self
, size
):
558 return np
.random
.exponential(scale
=4, size
=size
).astype(np
.int32
)
560 def generate_column(self
, size
, name
=None):
561 is_valid
= self
._make
_is
_valid
(size
)
564 sizes
= self
._random
_sizes
(size
)
566 for i
, nbytes
in enumerate(sizes
):
568 values
.append(random_bytes(nbytes
))
574 return self
.column_class(name
, size
, is_valid
, values
)
577 class StringField(BinaryField
):
580 def column_class(self
):
584 return OrderedDict([('name', 'utf8')])
586 def generate_column(self
, size
, name
=None):
588 is_valid
= self
._make
_is
_valid
(size
)
591 for i
in range(size
):
593 values
.append(tobytes(random_utf8(K
)))
599 return self
.column_class(name
, size
, is_valid
, values
)
602 class LargeBinaryField(BinaryField
):
605 def column_class(self
):
606 return LargeBinaryColumn
609 return OrderedDict([('name', 'largebinary')])
612 class LargeStringField(StringField
):
615 def column_class(self
):
616 return LargeStringColumn
619 return OrderedDict([('name', 'largeutf8')])
622 class Schema(object):
624 def __init__(self
, fields
, metadata
=None):
626 self
.metadata
= metadata
630 ('fields', [field
.get_json() for field
in self
.fields
])
633 if self
.metadata
is not None and len(self
.metadata
) > 0:
634 entries
.append(('metadata', metadata_key_values(self
.metadata
)))
636 return OrderedDict(entries
)
639 class _NarrowOffsetsMixin
:
641 def _encode_offsets(self
, offsets
):
642 return list(map(int, offsets
))
645 class _LargeOffsetsMixin
:
647 def _encode_offsets(self
, offsets
):
648 # 64-bit offsets have to be represented as strings to roundtrip
650 return list(map(str, offsets
))
653 class _BaseBinaryColumn(PrimitiveColumn
):
def _encode_value(self, x):
    """Encode a bytes value as an uppercase hex string for the JSON output."""
    hex_repr = binascii.hexlify(x).upper()
    return frombytes(hex_repr)
658 def _get_buffers(self
):
663 for i
, v
in enumerate(self
.values
):
669 offsets
.append(offset
)
670 data
.append(self
._encode
_value
(v
))
673 ('VALIDITY', [int(x
) for x
in self
.is_valid
]),
674 ('OFFSET', self
._encode
_offsets
(offsets
)),
679 class _BaseStringColumn(_BaseBinaryColumn
):
681 def _encode_value(self
, x
):
685 class BinaryColumn(_BaseBinaryColumn
, _NarrowOffsetsMixin
):
689 class StringColumn(_BaseStringColumn
, _NarrowOffsetsMixin
):
693 class LargeBinaryColumn(_BaseBinaryColumn
, _LargeOffsetsMixin
):
697 class LargeStringColumn(_BaseStringColumn
, _LargeOffsetsMixin
):
701 class FixedSizeBinaryColumn(PrimitiveColumn
):
def _encode_value(self, x):
    """Hex-encode a fixed-width bytes value (uppercase) for JSON output."""
    encoded = binascii.hexlify(x)
    return frombytes(encoded.upper())
706 def _get_buffers(self
):
708 for i
, v
in enumerate(self
.values
):
709 data
.append(self
._encode
_value
(v
))
712 ('VALIDITY', [int(x
) for x
in self
.is_valid
]),
717 class ListField(Field
):
719 def __init__(self
, name
, value_field
, *, nullable
=True,
721 super().__init
__(name
, nullable
=nullable
,
723 self
.value_field
= value_field
726 def column_class(self
):
734 def _get_children(self
):
735 return [self
.value_field
.get_json()]
737 def generate_column(self
, size
, name
=None):
740 is_valid
= self
._make
_is
_valid
(size
)
741 list_sizes
= np
.random
.randint(0, MAX_LIST_SIZE
+ 1, size
=size
)
745 for i
in range(size
):
747 offset
+= int(list_sizes
[i
])
748 offsets
.append(offset
)
750 # The offset now is the total number of elements in the child array
751 values
= self
.value_field
.generate_column(offset
)
755 return self
.column_class(name
, size
, is_valid
, offsets
, values
)
758 class LargeListField(ListField
):
761 def column_class(self
):
762 return LargeListColumn
766 ('name', 'largelist')
770 class _BaseListColumn(Column
):
772 def __init__(self
, name
, count
, is_valid
, offsets
, values
):
773 super().__init
__(name
, count
)
774 self
.is_valid
= is_valid
775 self
.offsets
= offsets
778 def _get_buffers(self
):
780 ('VALIDITY', [int(v
) for v
in self
.is_valid
]),
781 ('OFFSET', self
._encode
_offsets
(self
.offsets
))
def _get_children(self):
    """The single child is the flat values column."""
    child_json = self.values.get_json()
    return [child_json]
788 class ListColumn(_BaseListColumn
, _NarrowOffsetsMixin
):
792 class LargeListColumn(_BaseListColumn
, _LargeOffsetsMixin
):
796 class MapField(Field
):
798 def __init__(self
, name
, key_field
, item_field
, *, nullable
=True,
799 metadata
=None, keys_sorted
=False, entries_name
='entries'):
800 super().__init
__(name
, nullable
=nullable
,
803 assert not key_field
.nullable
804 self
.key_field
= key_field
805 self
.item_field
= item_field
806 self
.pair_field
= StructField(entries_name
, [key_field
, item_field
],
808 self
.keys_sorted
= keys_sorted
813 ('keysSorted', self
.keys_sorted
)
816 def _get_children(self
):
817 return [self
.pair_field
.get_json()]
819 def generate_column(self
, size
, name
=None):
822 is_valid
= self
._make
_is
_valid
(size
)
823 map_sizes
= np
.random
.randint(0, MAX_MAP_SIZE
+ 1, size
=size
)
827 for i
in range(size
):
829 offset
+= int(map_sizes
[i
])
830 offsets
.append(offset
)
832 # The offset now is the total number of elements in the child array
833 pairs
= self
.pair_field
.generate_column(offset
)
837 return MapColumn(name
, size
, is_valid
, offsets
, pairs
)
840 class MapColumn(Column
):
842 def __init__(self
, name
, count
, is_valid
, offsets
, pairs
):
843 super().__init
__(name
, count
)
844 self
.is_valid
= is_valid
845 self
.offsets
= offsets
848 def _get_buffers(self
):
850 ('VALIDITY', [int(v
) for v
in self
.is_valid
]),
851 ('OFFSET', list(self
.offsets
))
def _get_children(self):
    """The single child holds the generated key/item pairs."""
    pairs_json = self.pairs.get_json()
    return [pairs_json]
858 class FixedSizeListField(Field
):
860 def __init__(self
, name
, value_field
, list_size
, *, nullable
=True,
862 super().__init
__(name
, nullable
=nullable
,
864 self
.value_field
= value_field
865 self
.list_size
= list_size
869 ('name', 'fixedsizelist'),
870 ('listSize', self
.list_size
)
873 def _get_children(self
):
874 return [self
.value_field
.get_json()]
876 def generate_column(self
, size
, name
=None):
877 is_valid
= self
._make
_is
_valid
(size
)
878 values
= self
.value_field
.generate_column(size
* self
.list_size
)
882 return FixedSizeListColumn(name
, size
, is_valid
, values
)
885 class FixedSizeListColumn(Column
):
887 def __init__(self
, name
, count
, is_valid
, values
):
888 super().__init
__(name
, count
)
889 self
.is_valid
= is_valid
892 def _get_buffers(self
):
894 ('VALIDITY', [int(v
) for v
in self
.is_valid
])
def _get_children(self):
    """The single child is the flat values column."""
    values_json = self.values.get_json()
    return [values_json]
901 class StructField(Field
):
903 def __init__(self
, name
, fields
, *, nullable
=True,
905 super().__init
__(name
, nullable
=nullable
,
914 def _get_children(self
):
915 return [field
.get_json() for field
in self
.fields
]
917 def generate_column(self
, size
, name
=None):
918 is_valid
= self
._make
_is
_valid
(size
)
920 field_values
= [field
.generate_column(size
) for field
in self
.fields
]
923 return StructColumn(name
, size
, is_valid
, field_values
)
926 class _BaseUnionField(Field
):
928 def __init__(self
, name
, fields
, type_ids
=None, *, nullable
=True,
930 super().__init
__(name
, nullable
=nullable
, metadata
=metadata
)
932 type_ids
= list(range(fields
))
934 assert len(fields
) == len(type_ids
)
936 self
.type_ids
= type_ids
937 assert all(x
>= 0 for x
in self
.type_ids
)
943 ('typeIds', self
.type_ids
),
946 def _get_children(self
):
947 return [field
.get_json() for field
in self
.fields
]
949 def _make_type_ids(self
, size
):
950 return np
.random
.choice(self
.type_ids
, size
)
953 class SparseUnionField(_BaseUnionField
):
956 def generate_column(self
, size
, name
=None):
957 array_type_ids
= self
._make
_type
_ids
(size
)
958 field_values
= [field
.generate_column(size
) for field
in self
.fields
]
962 return SparseUnionColumn(name
, size
, array_type_ids
, field_values
)
965 class DenseUnionField(_BaseUnionField
):
968 def generate_column(self
, size
, name
=None):
969 # Reverse mapping {logical type id => physical child id}
970 child_ids
= [None] * (max(self
.type_ids
) + 1)
971 for i
, type_id
in enumerate(self
.type_ids
):
972 child_ids
[type_id
] = i
974 array_type_ids
= self
._make
_type
_ids
(size
)
976 child_sizes
= [0] * len(self
.fields
)
978 for i
in range(size
):
979 child_id
= child_ids
[array_type_ids
[i
]]
980 offset
= child_sizes
[child_id
]
981 offsets
.append(offset
)
982 child_sizes
[child_id
] = offset
+ 1
985 field
.generate_column(child_size
)
986 for field
, child_size
in zip(self
.fields
, child_sizes
)]
990 return DenseUnionColumn(name
, size
, array_type_ids
, offsets
,
994 class Dictionary(object):
996 def __init__(self
, id_
, field
, size
, name
=None, ordered
=False):
999 self
.values
= field
.generate_column(size
=size
, name
=name
)
1000 self
.ordered
= ordered
1003 return len(self
.values
)
1006 dummy_batch
= RecordBatch(len(self
.values
), [self
.values
])
1007 return OrderedDict([
1009 ('data', dummy_batch
.get_json())
1013 class DictionaryField(Field
):
1015 def __init__(self
, name
, index_field
, dictionary
, *, nullable
=True,
1017 super().__init
__(name
, nullable
=nullable
,
1019 assert index_field
.name
== ''
1020 assert isinstance(index_field
, IntegerField
)
1021 assert isinstance(dictionary
, Dictionary
)
1023 self
.index_field
= index_field
1024 self
.dictionary
= dictionary
def _get_type(self):
    """The JSON type comes from the dictionary's value field."""
    value_field = self.dictionary.field
    return value_field._get_type()
def _get_children(self):
    """Children also come from the dictionary's value field."""
    value_field = self.dictionary.field
    return value_field._get_children()
1032 def _get_dictionary(self
):
1033 return OrderedDict([
1034 ('id', self
.dictionary
.id_
),
1035 ('indexType', self
.index_field
._get
_type
()),
1036 ('isOrdered', self
.dictionary
.ordered
)
1039 def generate_column(self
, size
, name
=None):
1042 return self
.index_field
.generate_range(size
, 0, len(self
.dictionary
),
1046 ExtensionType
= namedtuple(
1047 'ExtensionType', ['extension_name', 'serialized', 'storage_field'])
1050 class ExtensionField(Field
):
1052 def __init__(self
, name
, extension_type
, *, nullable
=True, metadata
=None):
1053 metadata
= (metadata
or []) + [
1054 ('ARROW:extension:name', extension_type
.extension_name
),
1055 ('ARROW:extension:metadata', extension_type
.serialized
),
1057 super().__init
__(name
, nullable
=nullable
, metadata
=metadata
)
1058 self
.extension_type
= extension_type
def _get_type(self):
    """The JSON type is that of the underlying storage field."""
    storage = self.extension_type.storage_field
    return storage._get_type()
def _get_children(self):
    """Children are those of the underlying storage field."""
    storage = self.extension_type.storage_field
    return storage._get_children()
def _get_dictionary(self):
    """Dictionary encoding (if any) is that of the storage field."""
    storage = self.extension_type.storage_field
    return storage._get_dictionary()
1069 def generate_column(self
, size
, name
=None):
1072 return self
.extension_type
.storage_field
.generate_column(size
, name
)
1075 class StructColumn(Column
):
def __init__(self, name, count, is_valid, field_values):
    """Struct column data: validity flags plus one child column per field."""
    super().__init__(name, count)
    self.is_valid = is_valid          # per-slot validity flags
    self.field_values = field_values  # child columns, one per struct field
1082 def _get_buffers(self
):
1084 ('VALIDITY', [int(v
) for v
in self
.is_valid
])
def _get_children(self):
    """JSON for each child column, in field order."""
    children = []
    for child in self.field_values:
        children.append(child.get_json())
    return children
1091 class SparseUnionColumn(Column
):
def __init__(self, name, count, type_ids, field_values):
    """Sparse union column: one type id per slot plus the child columns."""
    super().__init__(name, count)
    self.type_ids = type_ids          # per-slot logical type ids
    self.field_values = field_values  # child columns, one per union member
1098 def _get_buffers(self
):
1100 ('TYPE_ID', [int(v
) for v
in self
.type_ids
])
def _get_children(self):
    """JSON representation of every child column."""
    return [child.get_json() for child in self.field_values]
1107 class DenseUnionColumn(Column
):
def __init__(self, name, count, type_ids, offsets, field_values):
    """Dense union column: type ids, per-slot child offsets, child columns."""
    super().__init__(name, count)
    self.type_ids = type_ids          # per-slot logical type ids
    self.offsets = offsets            # per-slot offset into the chosen child
    self.field_values = field_values  # child columns, one per union member
1115 def _get_buffers(self
):
1117 ('TYPE_ID', [int(v
) for v
in self
.type_ids
]),
1118 ('OFFSET', [int(v
) for v
in self
.offsets
]),
def _get_children(self):
    """JSON for each child column, in declaration order."""
    return list(child.get_json() for child in self.field_values)
1125 class RecordBatch(object):
1127 def __init__(self
, count
, columns
):
1129 self
.columns
= columns
1132 return OrderedDict([
1133 ('count', self
.count
),
1134 ('columns', [col
.get_json() for col
in self
.columns
])
1140 def __init__(self
, name
, schema
, batches
, dictionaries
=None,
1141 skip
=None, path
=None):
1143 self
.schema
= schema
1144 self
.dictionaries
= dictionaries
or []
1145 self
.batches
= batches
1149 self
.skip
.update(skip
)
1153 ('schema', self
.schema
.get_json())
1156 if len(self
.dictionaries
) > 0:
1157 entries
.append(('dictionaries',
1158 [dictionary
.get_json()
1159 for dictionary
in self
.dictionaries
]))
1161 entries
.append(('batches', [batch
.get_json()
1162 for batch
in self
.batches
]))
1163 return OrderedDict(entries
)
def write(self, path):
    """Serialize this file's JSON representation to *path* as UTF-8."""
    payload = json.dumps(self.get_json(), indent=2)
    with open(path, 'wb') as sink:
        sink.write(payload.encode('utf-8'))
1170 def skip_category(self
, category
):
1171 """Skip this test for the given category.
1173 Category should be SKIP_ARROW or SKIP_FLIGHT.
1175 self
.skip
.add(category
)
def get_field(name, type_, **kwargs):
    """Build a Field instance for the textual *type_* description.

    Understands the binary/utf8 variants (and their 'large' forms),
    'fixedsizebinary_<byte_width>' strings, and any NumPy-parsable
    integer/float/bool dtype name.  Raises TypeError for anything else.
    """
    if type_ == 'binary':
        return BinaryField(name, **kwargs)
    elif type_ == 'utf8':
        return StringField(name, **kwargs)
    elif type_ == 'largebinary':
        return LargeBinaryField(name, **kwargs)
    elif type_ == 'largeutf8':
        return LargeStringField(name, **kwargs)
    elif type_.startswith('fixedsizebinary_'):
        # The byte width is encoded in the type string after the underscore.
        width = int(type_.split('_')[1])
        return FixedSizeBinaryField(name, byte_width=width, **kwargs)

    dtype = np.dtype(type_)
    kind = dtype.kind
    nbits = dtype.itemsize * 8
    if kind in ('i', 'u'):
        return IntegerField(name, kind == 'i', nbits, **kwargs)
    if kind == 'f':
        return FloatingPointField(name, nbits, **kwargs)
    if kind == 'b':
        return BooleanField(name, **kwargs)

    raise TypeError(dtype)
1207 def _generate_file(name
, fields
, batch_sizes
, dictionaries
=None, skip
=None,
1209 schema
= Schema(fields
, metadata
=metadata
)
1211 for size
in batch_sizes
:
1213 for field
in fields
:
1214 col
= field
.generate_column(size
)
1217 batches
.append(RecordBatch(size
, columns
))
1219 return File(name
, schema
, batches
, dictionaries
, skip
=skip
)
1222 def generate_custom_metadata_case():
1224 # Generate a simple block of metadata where each value is '{}'.
1225 # Keys are delimited by whitespace in `items`.
1226 return [(k
, '{}') for k
in items
.split()]
1229 get_field('sort_of_pandas', 'int8', metadata
=meta('pandas')),
1231 get_field('lots_of_meta', 'int8', metadata
=meta('a b c d .. w x y z')),
1234 'unregistered_extension', 'int8',
1236 ('ARROW:extension:name', '!nonexistent'),
1237 ('ARROW:extension:metadata', ''),
1238 ('ARROW:integration:allow_unregistered_extension', 'true'),
1241 ListField('list_with_odd_values',
1242 get_field('item', 'int32', metadata
=meta('odd_values'))),
1246 return _generate_file('custom_metadata', fields
, batch_sizes
,
1247 metadata
=meta('schema_custom_0 schema_custom_1'))
1250 def generate_duplicate_fieldnames_case():
1252 get_field('ints', 'int8'),
1253 get_field('ints', 'int32'),
1255 StructField('struct', [get_field('', 'int32'), get_field('', 'utf8')]),
1259 return _generate_file('duplicate_fieldnames', fields
, batch_sizes
)
1262 def generate_primitive_case(batch_sizes
, name
='primitive'):
1263 types
= ['bool', 'int8', 'int16', 'int32', 'int64',
1264 'uint8', 'uint16', 'uint32', 'uint64',
1265 'float32', 'float64', 'binary', 'utf8',
1266 'fixedsizebinary_19', 'fixedsizebinary_120']
1271 fields
.append(get_field(type_
+ "_nullable", type_
, nullable
=True))
1272 fields
.append(get_field(type_
+ "_nonnullable", type_
, nullable
=False))
1274 return _generate_file(name
, fields
, batch_sizes
)
1277 def generate_primitive_large_offsets_case(batch_sizes
):
1278 types
= ['largebinary', 'largeutf8']
1283 fields
.append(get_field(type_
+ "_nullable", type_
, nullable
=True))
1284 fields
.append(get_field(type_
+ "_nonnullable", type_
, nullable
=False))
1286 return _generate_file('primitive_large_offsets', fields
, batch_sizes
)
1289 def generate_null_case(batch_sizes
):
1290 # Interleave null with non-null types to ensure the appropriate number of
1291 # buffers (0) is read and written
1293 NullField(name
='f0'),
1294 get_field('f1', 'int32'),
1295 NullField(name
='f2'),
1296 get_field('f3', 'float64'),
1297 NullField(name
='f4')
1299 return _generate_file('null', fields
, batch_sizes
)
1302 def generate_null_trivial_case(batch_sizes
):
1303 # Generate a case with no buffers
1305 NullField(name
='f0'),
1307 return _generate_file('null_trivial', fields
, batch_sizes
)
1310 def generate_decimal128_case():
1312 DecimalField(name
='f{}'.format(i
), precision
=precision
, scale
=2,
1314 for i
, precision
in enumerate(range(3, 39))
1317 possible_batch_sizes
= 7, 10
1318 batch_sizes
= [possible_batch_sizes
[i
% 2] for i
in range(len(fields
))]
1319 # 'decimal' is the original name for the test, and it must match
1320 # provide "gold" files that test backwards compatibility, so they
1321 # can be appropriately skipped.
1322 return _generate_file('decimal', fields
, batch_sizes
)
1325 def generate_decimal256_case():
1327 DecimalField(name
='f{}'.format(i
), precision
=precision
, scale
=5,
1329 for i
, precision
in enumerate(range(37, 70))
1332 possible_batch_sizes
= 7, 10
1333 batch_sizes
= [possible_batch_sizes
[i
% 2] for i
in range(len(fields
))]
1334 return _generate_file('decimal256', fields
, batch_sizes
)
1337 def generate_datetime_case():
1339 DateField('f0', DateField
.DAY
),
1340 DateField('f1', DateField
.MILLISECOND
),
1341 TimeField('f2', 's'),
1342 TimeField('f3', 'ms'),
1343 TimeField('f4', 'us'),
1344 TimeField('f5', 'ns'),
1345 TimestampField('f6', 's'),
1346 TimestampField('f7', 'ms'),
1347 TimestampField('f8', 'us'),
1348 TimestampField('f9', 'ns'),
1349 TimestampField('f10', 'ms', tz
=None),
1350 TimestampField('f11', 's', tz
='UTC'),
1351 TimestampField('f12', 'ms', tz
='US/Eastern'),
1352 TimestampField('f13', 'us', tz
='Europe/Paris'),
1353 TimestampField('f14', 'ns', tz
='US/Pacific'),
1356 batch_sizes
= [7, 10]
1357 return _generate_file("datetime", fields
, batch_sizes
)
1360 def generate_interval_case():
1362 DurationIntervalField('f1', 's'),
1363 DurationIntervalField('f2', 'ms'),
1364 DurationIntervalField('f3', 'us'),
1365 DurationIntervalField('f4', 'ns'),
1366 YearMonthIntervalField('f5'),
1367 DayTimeIntervalField('f6'),
1370 batch_sizes
= [7, 10]
1371 return _generate_file("interval", fields
, batch_sizes
)
1374 def generate_month_day_nano_interval_case():
1376 MonthDayNanoIntervalField('f1'),
1379 batch_sizes
= [7, 10]
1380 return _generate_file("interval_mdn", fields
, batch_sizes
)
1383 def generate_map_case():
1385 MapField('map_nullable', get_field('key', 'utf8', nullable
=False),
1386 get_field('value', 'int32')),
1389 batch_sizes
= [7, 10]
1390 return _generate_file("map", fields
, batch_sizes
)
1393 def generate_non_canonical_map_case():
1395 MapField('map_other_names',
1396 get_field('some_key', 'utf8', nullable
=False),
1397 get_field('some_value', 'int32'),
1398 entries_name
='some_entries'),
1402 return _generate_file("map_non_canonical", fields
, batch_sizes
)
1405 def generate_nested_case():
1407 ListField('list_nullable', get_field('item', 'int32')),
1408 FixedSizeListField('fixedsizelist_nullable',
1409 get_field('item', 'int32'), 4),
1410 StructField('struct_nullable', [get_field('f1', 'int32'),
1411 get_field('f2', 'utf8')]),
1412 # Fails on Go (ARROW-8452)
1413 # ListField('list_nonnullable', get_field('item', 'int32'),
1417 batch_sizes
= [7, 10]
1418 return _generate_file("nested", fields
, batch_sizes
)
1421 def generate_recursive_nested_case():
1423 ListField('lists_list',
1424 ListField('inner_list', get_field('item', 'int16'))),
1425 ListField('structs_list',
1426 StructField('inner_struct',
1427 [get_field('f1', 'int32'),
1428 get_field('f2', 'utf8')])),
1431 batch_sizes
= [7, 10]
1432 return _generate_file("recursive_nested", fields
, batch_sizes
)
1435 def generate_nested_large_offsets_case():
1437 LargeListField('large_list_nullable', get_field('item', 'int32')),
1438 LargeListField('large_list_nonnullable',
1439 get_field('item', 'int32'), nullable
=False),
1440 LargeListField('large_list_nested',
1441 ListField('inner_list', get_field('item', 'int16'))),
1444 batch_sizes
= [0, 13]
1445 return _generate_file("nested_large_offsets", fields
, batch_sizes
)
def generate_unions_case():
    """Generate a golden file covering sparse and dense unions, both
    nullable and non-nullable, with non-contiguous type ids."""
    # NOTE(review): the type_ids lists and the NullField child were garbled
    # in the reviewed source; restored to the expected upstream values.
    fields = [
        SparseUnionField('sparse', [get_field('f1', 'int32'),
                                    get_field('f2', 'utf8')],
                         type_ids=[5, 7]),
        DenseUnionField('dense', [get_field('f1', 'int16'),
                                  get_field('f2', 'binary')],
                        type_ids=[10, 20]),
        SparseUnionField('sparse', [get_field('f1', 'float32', nullable=False),
                                    get_field('f2', 'bool')],
                         type_ids=[5, 7], nullable=False),
        DenseUnionField('dense', [get_field('f1', 'uint8', nullable=False),
                                  get_field('f2', 'uint16'),
                                  NullField('f3')],
                        type_ids=[42, 43, 44], nullable=False),
    ]

    batch_sizes = [0, 11]
    return _generate_file("union", fields, batch_sizes)
def generate_dictionary_case():
    """Generate a golden file with three dictionary-encoded columns using
    signed index types (int8, int32, int16)."""
    dict0 = Dictionary(0, StringField('dictionary1'), size=10, name='DICT0')
    dict1 = Dictionary(1, StringField('dictionary1'), size=5, name='DICT1')
    dict2 = Dictionary(2, get_field('dictionary2', 'int64'),
                       size=50, name='DICT2')

    fields = [
        DictionaryField('dict0', get_field('', 'int8'), dict0),
        DictionaryField('dict1', get_field('', 'int32'), dict1),
        DictionaryField('dict2', get_field('', 'int16'), dict2)
    ]
    batch_sizes = [7, 10]
    return _generate_file("dictionary", fields, batch_sizes,
                          dictionaries=[dict0, dict1, dict2])
def generate_dictionary_unsigned_case():
    """Generate a golden file with dictionary-encoded columns using the
    unsigned index types uint8/uint16/uint32 (uint64 is disabled)."""
    dict0 = Dictionary(0, StringField('dictionary0'), size=5, name='DICT0')
    dict1 = Dictionary(1, StringField('dictionary1'), size=5, name='DICT1')
    dict2 = Dictionary(2, StringField('dictionary2'), size=5, name='DICT2')

    # TODO: JavaScript does not support uint64 dictionary indices, so disabled
    # for now
    # dict3 = Dictionary(3, StringField('dictionary3'), size=5, name='DICT3')
    fields = [
        DictionaryField('f0', get_field('', 'uint8'), dict0),
        DictionaryField('f1', get_field('', 'uint16'), dict1),
        DictionaryField('f2', get_field('', 'uint32'), dict2),
        # DictionaryField('f3', get_field('', 'uint64'), dict3)
    ]
    batch_sizes = [7, 10]
    return _generate_file("dictionary_unsigned", fields, batch_sizes,
                          dictionaries=[dict0, dict1, dict2])
def generate_nested_dictionary_case():
    """Generate a golden file where dictionary values are themselves nested
    types (a list of dictionary-encoded strings, and a struct of two
    dictionary-encoded strings)."""
    dict0 = Dictionary(0, StringField('str'), size=10, name='DICT0')

    # NOTE(review): the ListField child name was garbled in the reviewed
    # source; 'list' is the expected upstream value.
    list_of_dict = ListField(
        'list',
        DictionaryField('str_dict', get_field('', 'int8'), dict0))
    dict1 = Dictionary(1, list_of_dict, size=30, name='DICT1')

    struct_of_dict = StructField('struct', [
        DictionaryField('str_dict_a', get_field('', 'int8'), dict0),
        DictionaryField('str_dict_b', get_field('', 'int8'), dict0)
    ])
    dict2 = Dictionary(2, struct_of_dict, size=30, name='DICT2')

    fields = [
        DictionaryField('list_dict', get_field('', 'int8'), dict1),
        DictionaryField('struct_dict', get_field('', 'int8'), dict2)
    ]

    batch_sizes = [10, 13]
    return _generate_file("nested_dictionary", fields, batch_sizes,
                          dictionaries=[dict0, dict1, dict2])
def generate_extension_case():
    """Generate a golden file with extension types: a uuid extension over
    fixed-size binary, and an extension wrapping a dictionary field."""
    dict0 = Dictionary(0, StringField('dictionary0'), size=5, name='DICT0')

    uuid_type = ExtensionType('uuid', 'uuid-serialized',
                              FixedSizeBinaryField('', 16))
    dict_ext_type = ExtensionType(
        'dict-extension', 'dict-extension-serialized',
        DictionaryField('str_dict', get_field('', 'int8'), dict0))

    fields = [
        ExtensionField('uuids', uuid_type),
        ExtensionField('dict_exts', dict_ext_type),
    ]

    batch_sizes = [0, 13]
    return _generate_file("extension", fields, batch_sizes,
                          dictionaries=[dict0])
def get_generated_json_files(tempdir=None):
    """Run every case generator, write each case's JSON golden file under
    *tempdir* (a fresh temporary directory if None), and return the list of
    written file objects (each generator result carries per-implementation
    skip annotations via ``skip_category``)."""
    tempdir = tempdir or tempfile.mkdtemp(prefix='arrow-integration-')

    # NOTE(review): the list header and the generate_map_case() entry were
    # garbled in the reviewed source; restored from the expected upstream
    # content.
    file_objs = [
        generate_primitive_case([], name='primitive_no_batches'),
        generate_primitive_case([17, 20], name='primitive'),
        generate_primitive_case([0, 0, 0], name='primitive_zerolength'),

        generate_primitive_large_offsets_case([17, 20])
        .skip_category('C#')
        .skip_category('Go')
        .skip_category('JS'),

        generate_null_case([10, 0])
        .skip_category('C#')
        .skip_category('JS'),   # TODO(ARROW-7900)

        generate_null_trivial_case([0, 0])
        .skip_category('C#')
        .skip_category('JS'),   # TODO(ARROW-7900)

        generate_decimal128_case()
        .skip_category('Rust'),

        generate_decimal256_case()
        .skip_category('Go')  # TODO(ARROW-7948): Decimal + Go
        .skip_category('JS')
        .skip_category('Rust'),

        generate_datetime_case()
        .skip_category('C#'),

        generate_interval_case()
        .skip_category('C#')
        .skip_category('JS')  # TODO(ARROW-5239): Intervals + JS
        .skip_category('Rust'),

        generate_month_day_nano_interval_case()
        .skip_category('C#')
        .skip_category('JS')
        .skip_category('Rust'),

        generate_map_case()
        .skip_category('C#')
        .skip_category('Rust'),

        generate_non_canonical_map_case()
        .skip_category('C#')
        .skip_category('Java')  # TODO(ARROW-8715)
        .skip_category('JS')    # TODO(ARROW-8716)
        .skip_category('Rust'),

        generate_nested_case()
        .skip_category('C#'),

        generate_recursive_nested_case()
        .skip_category('C#'),

        generate_nested_large_offsets_case()
        .skip_category('C#')
        .skip_category('Go')
        .skip_category('JS')
        .skip_category('Rust'),

        generate_unions_case()
        .skip_category('C#')
        .skip_category('Go')
        .skip_category('JS')
        .skip_category('Rust'),

        generate_custom_metadata_case()
        .skip_category('C#')
        .skip_category('JS'),

        generate_duplicate_fieldnames_case()
        .skip_category('C#')
        .skip_category('Go')
        .skip_category('JS'),

        # TODO(ARROW-3039, ARROW-5267): Dictionaries in GO
        generate_dictionary_case()
        .skip_category('C#')
        .skip_category('Go'),

        generate_dictionary_unsigned_case()
        .skip_category('C#')
        .skip_category('Go')     # TODO(ARROW-9378)
        .skip_category('Java'),  # TODO(ARROW-9377)

        generate_nested_dictionary_case()
        .skip_category('C#')
        .skip_category('Go')
        .skip_category('Java')  # TODO(ARROW-7779)
        .skip_category('JS')
        .skip_category('Rust'),

        generate_extension_case()
        .skip_category('C#')
        .skip_category('Go')  # TODO(ARROW-3039): requires dictionaries
        .skip_category('JS')
        .skip_category('Rust'),
    ]

    generated_paths = []
    for file_obj in file_objs:
        out_path = os.path.join(tempdir, 'generated_' +
                                file_obj.name + '.json')
        file_obj.write(out_path)
        # The file object itself is returned (not the path): write()
        # records the destination on the object for later consumers.
        generated_paths.append(file_obj)

    return generated_paths