]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one | |
3 | * or more contributor license agreements. See the NOTICE file | |
4 | * distributed with this work for additional information | |
5 | * regarding copyright ownership. The ASF licenses this file | |
6 | * to you under the Apache License, Version 2.0 (the | |
7 | * "License"); you may not use this file except in compliance | |
8 | * with the License. You may obtain a copy of the License at | |
9 | * | |
10 | * http://www.apache.org/licenses/LICENSE-2.0 | |
11 | * | |
12 | * Unless required by applicable law or agreed to in writing, | |
13 | * software distributed under the License is distributed on an | |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
15 | * KIND, either express or implied. See the License for the | |
16 | * specific language governing permissions and limitations | |
17 | * under the License. | |
18 | */ | |
19 | ||
20 | #include "memory-view.hpp" | |
21 | ||
22 | #include <arrow-glib/arrow-glib.hpp> | |
23 | #include <rbgobject.h> | |
24 | ||
25 | #include <ruby/version.h> | |
26 | ||
27 | #if RUBY_API_VERSION_MAJOR >= 3 | |
28 | # define HAVE_MEMORY_VIEW | |
29 | # define private memory_view_private | |
30 | # include <ruby/memory_view.h> | |
31 | # undef private | |
32 | #endif | |
33 | ||
34 | #include <sstream> | |
35 | ||
36 | namespace red_arrow { | |
37 | namespace memory_view { | |
38 | #ifdef HAVE_MEMORY_VIEW | |
39 | // This is workaround for the following rb_memory_view_t problems | |
40 | // in C++: | |
41 | // | |
42 | // * Can't use "private" as member name | |
43 | // * Can't assign a value to "rb_memory_view_t::private" | |
44 | // | |
45 | // This has compatible layout with rb_memory_view_t. | |
46 | struct memory_view { | |
47 | VALUE obj; | |
48 | void *data; | |
49 | ssize_t byte_size; | |
50 | bool readonly; | |
51 | const char *format; | |
52 | ssize_t item_size; | |
53 | struct { | |
54 | const rb_memory_view_item_component_t *components; | |
55 | size_t length; | |
56 | } item_desc; | |
57 | ssize_t ndim; | |
58 | const ssize_t *shape; | |
59 | const ssize_t *strides; | |
60 | const ssize_t *sub_offsets; | |
61 | void *private_data; | |
62 | }; | |
63 | ||
// Per-view state stored behind memory_view::private_data.
struct PrivateData {
  // Backing storage for a format string that is built at run time;
  // memory_view::format may point into this string.
  std::string format{};
};
67 | ||
68 | class PrimitiveArrayGetter : public arrow::ArrayVisitor { | |
69 | public: | |
70 | explicit PrimitiveArrayGetter(memory_view *view) | |
71 | : view_(view) { | |
72 | } | |
73 | ||
74 | arrow::Status Visit(const arrow::BooleanArray& array) override { | |
75 | fill(static_cast<const arrow::Array&>(array)); | |
76 | // Memory view doesn't support bit stream. We use one byte | |
77 | // for 8 elements. Users can't calculate the number of | |
78 | // elements from memory view but it's limitation of memory view. | |
79 | #ifdef ARROW_LITTLE_ENDIAN | |
80 | view_->format = "b8"; | |
81 | #else | |
82 | view_->format = "B8"; | |
83 | #endif | |
84 | view_->item_size = 1; | |
85 | view_->byte_size = (array.length() + 7) / 8; | |
86 | return arrow::Status::OK(); | |
87 | } | |
88 | ||
89 | arrow::Status Visit(const arrow::Int8Array& array) override { | |
90 | fill(static_cast<const arrow::Array&>(array)); | |
91 | view_->format = "c"; | |
92 | return arrow::Status::OK(); | |
93 | } | |
94 | ||
95 | arrow::Status Visit(const arrow::Int16Array& array) override { | |
96 | fill(static_cast<const arrow::Array&>(array)); | |
97 | view_->format = "s"; | |
98 | return arrow::Status::OK(); | |
99 | } | |
100 | ||
101 | arrow::Status Visit(const arrow::Int32Array& array) override { | |
102 | fill(static_cast<const arrow::Array&>(array)); | |
103 | view_->format = "l"; | |
104 | return arrow::Status::OK(); | |
105 | } | |
106 | ||
107 | arrow::Status Visit(const arrow::Int64Array& array) override { | |
108 | fill(static_cast<const arrow::Array&>(array)); | |
109 | view_->format = "q"; | |
110 | return arrow::Status::OK(); | |
111 | } | |
112 | ||
113 | arrow::Status Visit(const arrow::UInt8Array& array) override { | |
114 | fill(static_cast<const arrow::Array&>(array)); | |
115 | view_->format = "C"; | |
116 | return arrow::Status::OK(); | |
117 | } | |
118 | ||
119 | arrow::Status Visit(const arrow::UInt16Array& array) override { | |
120 | fill(static_cast<const arrow::Array&>(array)); | |
121 | view_->format = "S"; | |
122 | return arrow::Status::OK(); | |
123 | } | |
124 | ||
125 | arrow::Status Visit(const arrow::UInt32Array& array) override { | |
126 | fill(static_cast<const arrow::Array&>(array)); | |
127 | view_->format = "L"; | |
128 | return arrow::Status::OK(); | |
129 | } | |
130 | ||
131 | arrow::Status Visit(const arrow::UInt64Array& array) override { | |
132 | fill(static_cast<const arrow::Array&>(array)); | |
133 | view_->format = "Q"; | |
134 | return arrow::Status::OK(); | |
135 | } | |
136 | ||
137 | arrow::Status Visit(const arrow::FloatArray& array) override { | |
138 | fill(static_cast<const arrow::Array&>(array)); | |
139 | view_->format = "f"; | |
140 | return arrow::Status::OK(); | |
141 | } | |
142 | ||
143 | arrow::Status Visit(const arrow::DoubleArray& array) override { | |
144 | fill(static_cast<const arrow::Array&>(array)); | |
145 | view_->format = "d"; | |
146 | return arrow::Status::OK(); | |
147 | } | |
148 | ||
149 | arrow::Status Visit(const arrow::FixedSizeBinaryArray& array) override { | |
150 | fill(static_cast<const arrow::Array&>(array)); | |
151 | auto priv = static_cast<PrivateData *>(view_->private_data); | |
152 | const auto type = | |
153 | std::static_pointer_cast<const arrow::FixedSizeBinaryType>( | |
154 | array.type()); | |
155 | std::ostringstream output; | |
156 | output << "C" << type->byte_width(); | |
157 | priv->format = output.str(); | |
158 | view_->format = priv->format.c_str(); | |
159 | return arrow::Status::OK(); | |
160 | } | |
161 | ||
162 | arrow::Status Visit(const arrow::Date32Array& array) override { | |
163 | fill(static_cast<const arrow::Array&>(array)); | |
164 | view_->format = "l"; | |
165 | return arrow::Status::OK(); | |
166 | } | |
167 | ||
168 | arrow::Status Visit(const arrow::Date64Array& array) override { | |
169 | fill(static_cast<const arrow::Array&>(array)); | |
170 | view_->format = "q"; | |
171 | return arrow::Status::OK(); | |
172 | } | |
173 | ||
174 | arrow::Status Visit(const arrow::Time32Array& array) override { | |
175 | fill(static_cast<const arrow::Array&>(array)); | |
176 | view_->format = "l"; | |
177 | return arrow::Status::OK(); | |
178 | } | |
179 | ||
180 | arrow::Status Visit(const arrow::Time64Array& array) override { | |
181 | fill(static_cast<const arrow::Array&>(array)); | |
182 | view_->format = "q"; | |
183 | return arrow::Status::OK(); | |
184 | } | |
185 | ||
186 | arrow::Status Visit(const arrow::TimestampArray& array) override { | |
187 | fill(static_cast<const arrow::Array&>(array)); | |
188 | view_->format = "q"; | |
189 | return arrow::Status::OK(); | |
190 | } | |
191 | ||
192 | arrow::Status Visit(const arrow::Decimal128Array& array) override { | |
193 | fill(static_cast<const arrow::Array&>(array)); | |
194 | view_->format = "q2"; | |
195 | return arrow::Status::OK(); | |
196 | } | |
197 | ||
198 | arrow::Status Visit(const arrow::Decimal256Array& array) override { | |
199 | fill(static_cast<const arrow::Array&>(array)); | |
200 | view_->format = "q4"; | |
201 | return arrow::Status::OK(); | |
202 | } | |
203 | ||
204 | private: | |
205 | void fill(const arrow::Array& array) { | |
206 | const auto array_data = array.data(); | |
207 | const auto data = array_data->GetValuesSafe<uint8_t>(1); | |
208 | view_->data = const_cast<void *>(reinterpret_cast<const void *>(data)); | |
209 | const auto type = | |
210 | std::static_pointer_cast<const arrow::FixedWidthType>(array.type()); | |
211 | view_->item_size = type->bit_width() / 8; | |
212 | view_->byte_size = view_->item_size * array.length(); | |
213 | } | |
214 | ||
215 | memory_view *view_; | |
216 | }; | |
217 | ||
218 | bool primitive_array_get(VALUE obj, rb_memory_view_t *view, int flags) { | |
219 | if (flags != RUBY_MEMORY_VIEW_SIMPLE) { | |
220 | return false; | |
221 | } | |
222 | auto view_ = reinterpret_cast<memory_view *>(view); | |
223 | view_->obj = obj; | |
224 | view_->private_data = new PrivateData(); | |
225 | auto array = GARROW_ARRAY(RVAL2GOBJ(obj)); | |
226 | auto arrow_array = garrow_array_get_raw(array); | |
227 | PrimitiveArrayGetter getter(view_); | |
228 | auto status = arrow_array->Accept(&getter); | |
229 | if (!status.ok()) { | |
230 | return false; | |
231 | } | |
232 | view_->readonly = true; | |
233 | view_->ndim = 1; | |
234 | view_->shape = NULL; | |
235 | view_->strides = NULL; | |
236 | view_->sub_offsets = NULL; | |
237 | return true; | |
238 | } | |
239 | ||
240 | bool primitive_array_release(VALUE obj, rb_memory_view_t *view) { | |
241 | auto view_ = reinterpret_cast<memory_view *>(view); | |
242 | delete static_cast<PrivateData *>(view_->private_data); | |
243 | return true; | |
244 | } | |
245 | ||
246 | bool primitive_array_available_p(VALUE obj) { | |
247 | return true; | |
248 | } | |
249 | ||
250 | rb_memory_view_entry_t primitive_array_entry = { | |
251 | primitive_array_get, | |
252 | primitive_array_release, | |
253 | primitive_array_available_p, | |
254 | }; | |
255 | ||
256 | bool buffer_get(VALUE obj, rb_memory_view_t *view, int flags) { | |
257 | if (flags != RUBY_MEMORY_VIEW_SIMPLE) { | |
258 | return false; | |
259 | } | |
260 | auto view_ = reinterpret_cast<memory_view *>(view); | |
261 | view_->obj = obj; | |
262 | auto buffer = GARROW_BUFFER(RVAL2GOBJ(obj)); | |
263 | auto arrow_buffer = garrow_buffer_get_raw(buffer); | |
264 | view_->data = | |
265 | const_cast<void *>(reinterpret_cast<const void *>(arrow_buffer->data())); | |
266 | // Memory view doesn't support bit stream. We use one byte | |
267 | // for 8 elements. Users can't calculate the number of | |
268 | // elements from memory view but it's limitation of memory view. | |
269 | #ifdef ARROW_LITTLE_ENDIAN | |
270 | view_->format = "b8"; | |
271 | #else | |
272 | view_->format = "B8"; | |
273 | #endif | |
274 | view_->item_size = 1; | |
275 | view_->byte_size = arrow_buffer->size(); | |
276 | view_->readonly = true; | |
277 | view_->ndim = 1; | |
278 | view_->shape = NULL; | |
279 | view_->strides = NULL; | |
280 | view_->sub_offsets = NULL; | |
281 | return true; | |
282 | } | |
283 | ||
284 | bool buffer_release(VALUE obj, rb_memory_view_t *view) { | |
285 | return true; | |
286 | } | |
287 | ||
288 | bool buffer_available_p(VALUE obj) { | |
289 | return true; | |
290 | } | |
291 | ||
292 | rb_memory_view_entry_t buffer_entry = { | |
293 | buffer_get, | |
294 | buffer_release, | |
295 | buffer_available_p, | |
296 | }; | |
297 | #endif | |
298 | ||
299 | void init(VALUE mArrow) { | |
300 | #ifdef HAVE_MEMORY_VIEW | |
301 | auto cPrimitiveArray = | |
302 | rb_const_get_at(mArrow, rb_intern("PrimitiveArray")); | |
303 | rb_memory_view_register(cPrimitiveArray, | |
304 | &(red_arrow::memory_view::primitive_array_entry)); | |
305 | ||
306 | auto cBuffer = rb_const_get_at(mArrow, rb_intern("Buffer")); | |
307 | rb_memory_view_register(cBuffer, &(red_arrow::memory_view::buffer_entry)); | |
308 | #endif | |
309 | } | |
310 | } | |
311 | } |