]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/python/iterators.h
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / python / iterators.h
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #pragma once
19
20 #include <utility>
21
22 #include "arrow/array/array_primitive.h"
23
24 #include "arrow/python/common.h"
25 #include "arrow/python/numpy_internal.h"
26
27 namespace arrow {
28 namespace py {
29 namespace internal {
30
31 using arrow::internal::checked_cast;
32
33 // Visit the Python sequence, calling the given callable on each element. If
34 // the callable returns a non-OK status, iteration stops and the status is
35 // returned.
36 //
37 // The call signature for Visitor must be
38 //
39 // Visit(PyObject* obj, int64_t index, bool* keep_going)
40 //
41 // If keep_going is set to false, the iteration terminates
42 template <class VisitorFunc>
43 inline Status VisitSequenceGeneric(PyObject* obj, int64_t offset, VisitorFunc&& func) {
44 // VisitorFunc may set to false to terminate iteration
45 bool keep_going = true;
46
47 if (PyArray_Check(obj)) {
48 PyArrayObject* arr_obj = reinterpret_cast<PyArrayObject*>(obj);
49 if (PyArray_NDIM(arr_obj) != 1) {
50 return Status::Invalid("Only 1D arrays accepted");
51 }
52
53 if (PyArray_DESCR(arr_obj)->type_num == NPY_OBJECT) {
54 // It's an array object, we can fetch object pointers directly
55 const Ndarray1DIndexer<PyObject*> objects(arr_obj);
56 for (int64_t i = offset; keep_going && i < objects.size(); ++i) {
57 RETURN_NOT_OK(func(objects[i], i, &keep_going));
58 }
59 return Status::OK();
60 }
61 // It's a non-object array, fall back on regular sequence access.
62 // (note PyArray_GETITEM() is slightly different: it returns standard
63 // Python types, not Numpy scalar types)
64 // This code path is inefficient: callers should implement dedicated
65 // logic for non-object arrays.
66 }
67 if (PySequence_Check(obj)) {
68 if (PyList_Check(obj) || PyTuple_Check(obj)) {
69 // Use fast item access
70 const Py_ssize_t size = PySequence_Fast_GET_SIZE(obj);
71 for (Py_ssize_t i = offset; keep_going && i < size; ++i) {
72 PyObject* value = PySequence_Fast_GET_ITEM(obj, i);
73 RETURN_NOT_OK(func(value, static_cast<int64_t>(i), &keep_going));
74 }
75 } else {
76 // Regular sequence: avoid making a potentially large copy
77 const Py_ssize_t size = PySequence_Size(obj);
78 RETURN_IF_PYERROR();
79 for (Py_ssize_t i = offset; keep_going && i < size; ++i) {
80 OwnedRef value_ref(PySequence_ITEM(obj, i));
81 RETURN_IF_PYERROR();
82 RETURN_NOT_OK(func(value_ref.obj(), static_cast<int64_t>(i), &keep_going));
83 }
84 }
85 } else {
86 return Status::TypeError("Object is not a sequence");
87 }
88 return Status::OK();
89 }
90
91 // Visit sequence with no null mask
92 template <class VisitorFunc>
93 inline Status VisitSequence(PyObject* obj, int64_t offset, VisitorFunc&& func) {
94 return VisitSequenceGeneric(
95 obj, offset, [&func](PyObject* value, int64_t i /* unused */, bool* keep_going) {
96 return func(value, keep_going);
97 });
98 }
99
100 /// Visit sequence with null mask
101 template <class VisitorFunc>
102 inline Status VisitSequenceMasked(PyObject* obj, PyObject* mo, int64_t offset,
103 VisitorFunc&& func) {
104 if (PyArray_Check(mo)) {
105 PyArrayObject* mask = reinterpret_cast<PyArrayObject*>(mo);
106 if (PyArray_NDIM(mask) != 1) {
107 return Status::Invalid("Mask must be 1D array");
108 }
109 if (PyArray_SIZE(mask) != static_cast<int64_t>(PySequence_Size(obj))) {
110 return Status::Invalid("Mask was a different length from sequence being converted");
111 }
112
113 const int dtype = fix_numpy_type_num(PyArray_DESCR(mask)->type_num);
114 if (dtype == NPY_BOOL) {
115 Ndarray1DIndexer<uint8_t> mask_values(mask);
116
117 return VisitSequenceGeneric(
118 obj, offset,
119 [&func, &mask_values](PyObject* value, int64_t i, bool* keep_going) {
120 return func(value, mask_values[i], keep_going);
121 });
122 } else {
123 return Status::TypeError("Mask must be boolean dtype");
124 }
125 } else if (py::is_array(mo)) {
126 auto unwrap_mask_result = unwrap_array(mo);
127 ARROW_RETURN_NOT_OK(unwrap_mask_result);
128 std::shared_ptr<Array> mask_ = unwrap_mask_result.ValueOrDie();
129 if (mask_->type_id() != Type::type::BOOL) {
130 return Status::TypeError("Mask must be an array of booleans");
131 }
132
133 if (mask_->length() != PySequence_Size(obj)) {
134 return Status::Invalid("Mask was a different length from sequence being converted");
135 }
136
137 if (mask_->null_count() != 0) {
138 return Status::TypeError("Mask must be an array of booleans");
139 }
140
141 BooleanArray* boolmask = checked_cast<BooleanArray*>(mask_.get());
142 return VisitSequenceGeneric(
143 obj, offset, [&func, &boolmask](PyObject* value, int64_t i, bool* keep_going) {
144 return func(value, boolmask->Value(i), keep_going);
145 });
146 } else if (PySequence_Check(mo)) {
147 if (PySequence_Size(mo) != PySequence_Size(obj)) {
148 return Status::Invalid("Mask was a different length from sequence being converted");
149 }
150 RETURN_IF_PYERROR();
151
152 return VisitSequenceGeneric(
153 obj, offset, [&func, &mo](PyObject* value, int64_t i, bool* keep_going) {
154 OwnedRef value_ref(PySequence_ITEM(mo, i));
155 if (!PyBool_Check(value_ref.obj()))
156 return Status::TypeError("Mask must be a sequence of booleans");
157 return func(value, value_ref.obj() == Py_True, keep_going);
158 });
159 } else {
160 return Status::Invalid("Null mask must be a NumPy array, Arrow array or a Sequence");
161 }
162
163 return Status::OK();
164 }
165
166 // Like IterateSequence, but accepts any generic iterable (including
167 // non-restartable iterators, e.g. generators).
168 //
169 // The call signature for VisitorFunc must be Visit(PyObject*, bool*
170 // keep_going). If keep_going is set to false, the iteration terminates
171 template <class VisitorFunc>
172 inline Status VisitIterable(PyObject* obj, VisitorFunc&& func) {
173 if (PySequence_Check(obj)) {
174 // Numpy arrays fall here as well
175 return VisitSequence(obj, /*offset=*/0, std::forward<VisitorFunc>(func));
176 }
177 // Fall back on the iterator protocol
178 OwnedRef iter_ref(PyObject_GetIter(obj));
179 PyObject* iter = iter_ref.obj();
180 RETURN_IF_PYERROR();
181 PyObject* value;
182
183 bool keep_going = true;
184 while (keep_going && (value = PyIter_Next(iter))) {
185 OwnedRef value_ref(value);
186 RETURN_NOT_OK(func(value_ref.obj(), &keep_going));
187 }
188 RETURN_IF_PYERROR(); // __next__() might have raised
189 return Status::OK();
190 }
191
192 } // namespace internal
193 } // namespace py
194 } // namespace arrow