]>
git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/python/iterators.h
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
22 #include "arrow/array/array_primitive.h"
24 #include "arrow/python/common.h"
25 #include "arrow/python/numpy_internal.h"
31 using arrow::internal::checked_cast
;
33 // Visit the Python sequence, calling the given callable on each element. If
34 // the callable returns a non-OK status, iteration stops and the status is
37 // The call signature for Visitor must be
39 // Visit(PyObject* obj, int64_t index, bool* keep_going)
41 // If keep_going is set to false, the iteration terminates
42 template <class VisitorFunc
>
43 inline Status
VisitSequenceGeneric(PyObject
* obj
, int64_t offset
, VisitorFunc
&& func
) {
44 // VisitorFunc may set to false to terminate iteration
45 bool keep_going
= true;
47 if (PyArray_Check(obj
)) {
48 PyArrayObject
* arr_obj
= reinterpret_cast<PyArrayObject
*>(obj
);
49 if (PyArray_NDIM(arr_obj
) != 1) {
50 return Status::Invalid("Only 1D arrays accepted");
53 if (PyArray_DESCR(arr_obj
)->type_num
== NPY_OBJECT
) {
54 // It's an array object, we can fetch object pointers directly
55 const Ndarray1DIndexer
<PyObject
*> objects(arr_obj
);
56 for (int64_t i
= offset
; keep_going
&& i
< objects
.size(); ++i
) {
57 RETURN_NOT_OK(func(objects
[i
], i
, &keep_going
));
61 // It's a non-object array, fall back on regular sequence access.
62 // (note PyArray_GETITEM() is slightly different: it returns standard
63 // Python types, not Numpy scalar types)
64 // This code path is inefficient: callers should implement dedicated
65 // logic for non-object arrays.
67 if (PySequence_Check(obj
)) {
68 if (PyList_Check(obj
) || PyTuple_Check(obj
)) {
69 // Use fast item access
70 const Py_ssize_t size
= PySequence_Fast_GET_SIZE(obj
);
71 for (Py_ssize_t i
= offset
; keep_going
&& i
< size
; ++i
) {
72 PyObject
* value
= PySequence_Fast_GET_ITEM(obj
, i
);
73 RETURN_NOT_OK(func(value
, static_cast<int64_t>(i
), &keep_going
));
76 // Regular sequence: avoid making a potentially large copy
77 const Py_ssize_t size
= PySequence_Size(obj
);
79 for (Py_ssize_t i
= offset
; keep_going
&& i
< size
; ++i
) {
80 OwnedRef
value_ref(PySequence_ITEM(obj
, i
));
82 RETURN_NOT_OK(func(value_ref
.obj(), static_cast<int64_t>(i
), &keep_going
));
86 return Status::TypeError("Object is not a sequence");
91 // Visit sequence with no null mask
92 template <class VisitorFunc
>
93 inline Status
VisitSequence(PyObject
* obj
, int64_t offset
, VisitorFunc
&& func
) {
94 return VisitSequenceGeneric(
95 obj
, offset
, [&func
](PyObject
* value
, int64_t i
/* unused */, bool* keep_going
) {
96 return func(value
, keep_going
);
100 /// Visit sequence with null mask
101 template <class VisitorFunc
>
102 inline Status
VisitSequenceMasked(PyObject
* obj
, PyObject
* mo
, int64_t offset
,
103 VisitorFunc
&& func
) {
104 if (PyArray_Check(mo
)) {
105 PyArrayObject
* mask
= reinterpret_cast<PyArrayObject
*>(mo
);
106 if (PyArray_NDIM(mask
) != 1) {
107 return Status::Invalid("Mask must be 1D array");
109 if (PyArray_SIZE(mask
) != static_cast<int64_t>(PySequence_Size(obj
))) {
110 return Status::Invalid("Mask was a different length from sequence being converted");
113 const int dtype
= fix_numpy_type_num(PyArray_DESCR(mask
)->type_num
);
114 if (dtype
== NPY_BOOL
) {
115 Ndarray1DIndexer
<uint8_t> mask_values(mask
);
117 return VisitSequenceGeneric(
119 [&func
, &mask_values
](PyObject
* value
, int64_t i
, bool* keep_going
) {
120 return func(value
, mask_values
[i
], keep_going
);
123 return Status::TypeError("Mask must be boolean dtype");
125 } else if (py::is_array(mo
)) {
126 auto unwrap_mask_result
= unwrap_array(mo
);
127 ARROW_RETURN_NOT_OK(unwrap_mask_result
);
128 std::shared_ptr
<Array
> mask_
= unwrap_mask_result
.ValueOrDie();
129 if (mask_
->type_id() != Type::type::BOOL
) {
130 return Status::TypeError("Mask must be an array of booleans");
133 if (mask_
->length() != PySequence_Size(obj
)) {
134 return Status::Invalid("Mask was a different length from sequence being converted");
137 if (mask_
->null_count() != 0) {
138 return Status::TypeError("Mask must be an array of booleans");
141 BooleanArray
* boolmask
= checked_cast
<BooleanArray
*>(mask_
.get());
142 return VisitSequenceGeneric(
143 obj
, offset
, [&func
, &boolmask
](PyObject
* value
, int64_t i
, bool* keep_going
) {
144 return func(value
, boolmask
->Value(i
), keep_going
);
146 } else if (PySequence_Check(mo
)) {
147 if (PySequence_Size(mo
) != PySequence_Size(obj
)) {
148 return Status::Invalid("Mask was a different length from sequence being converted");
152 return VisitSequenceGeneric(
153 obj
, offset
, [&func
, &mo
](PyObject
* value
, int64_t i
, bool* keep_going
) {
154 OwnedRef
value_ref(PySequence_ITEM(mo
, i
));
155 if (!PyBool_Check(value_ref
.obj()))
156 return Status::TypeError("Mask must be a sequence of booleans");
157 return func(value
, value_ref
.obj() == Py_True
, keep_going
);
160 return Status::Invalid("Null mask must be a NumPy array, Arrow array or a Sequence");
166 // Like IterateSequence, but accepts any generic iterable (including
167 // non-restartable iterators, e.g. generators).
169 // The call signature for VisitorFunc must be Visit(PyObject*, bool*
170 // keep_going). If keep_going is set to false, the iteration terminates
171 template <class VisitorFunc
>
172 inline Status
VisitIterable(PyObject
* obj
, VisitorFunc
&& func
) {
173 if (PySequence_Check(obj
)) {
174 // Numpy arrays fall here as well
175 return VisitSequence(obj
, /*offset=*/0, std::forward
<VisitorFunc
>(func
));
177 // Fall back on the iterator protocol
178 OwnedRef
iter_ref(PyObject_GetIter(obj
));
179 PyObject
* iter
= iter_ref
.obj();
183 bool keep_going
= true;
184 while (keep_going
&& (value
= PyIter_Next(iter
))) {
185 OwnedRef
value_ref(value
);
186 RETURN_NOT_OK(func(value_ref
.obj(), &keep_going
));
188 RETURN_IF_PYERROR(); // __next__() might have raised
192 } // namespace internal