]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include "arrow/sparse_tensor.h" | |
19 | #include "arrow/tensor/converter.h" | |
20 | ||
21 | #include <algorithm> | |
22 | #include <functional> | |
23 | #include <memory> | |
24 | #include <numeric> | |
25 | ||
26 | #include "arrow/compare.h" | |
27 | #include "arrow/type_traits.h" | |
28 | #include "arrow/util/checked_cast.h" | |
29 | #include "arrow/util/logging.h" | |
30 | #include "arrow/visitor_inline.h" | |
31 | ||
32 | namespace arrow { | |
33 | ||
34 | class MemoryPool; | |
35 | ||
36 | // ---------------------------------------------------------------------- | |
37 | // SparseIndex | |
38 | ||
39 | Status SparseIndex::ValidateShape(const std::vector<int64_t>& shape) const { | |
40 | if (!std::all_of(shape.begin(), shape.end(), [](int64_t x) { return x >= 0; })) { | |
41 | return Status::Invalid("Shape elements must be positive"); | |
42 | } | |
43 | ||
44 | return Status::OK(); | |
45 | } | |
46 | ||
47 | namespace internal { | |
48 | namespace { | |
49 | ||
50 | template <typename IndexValueType> | |
51 | Status CheckSparseIndexMaximumValue(const std::vector<int64_t>& shape) { | |
52 | using c_index_value_type = typename IndexValueType::c_type; | |
53 | constexpr int64_t type_max = | |
54 | static_cast<int64_t>(std::numeric_limits<c_index_value_type>::max()); | |
55 | auto greater_than_type_max = [&](int64_t x) { return x > type_max; }; | |
56 | if (std::any_of(shape.begin(), shape.end(), greater_than_type_max)) { | |
57 | return Status::Invalid("The bit width of the index value type is too small"); | |
58 | } | |
59 | return Status::OK(); | |
60 | } | |
61 | ||
62 | template <> | |
63 | Status CheckSparseIndexMaximumValue<Int64Type>(const std::vector<int64_t>& shape) { | |
64 | return Status::OK(); | |
65 | } | |
66 | ||
67 | template <> | |
68 | Status CheckSparseIndexMaximumValue<UInt64Type>(const std::vector<int64_t>& shape) { | |
69 | return Status::Invalid("UInt64Type cannot be used as IndexValueType of SparseIndex"); | |
70 | } | |
71 | ||
72 | } // namespace | |
73 | ||
// Expands to one `case` label per integer type; combined with
// ARROW_GENERATE_FOR_ALL_INTEGER_TYPES below, it dispatches the runtime
// type id to the compile-time template above.
#define CALL_CHECK_MAXIMUM_VALUE(TYPE_CLASS) \
  case TYPE_CLASS##Type::type_id:            \
    return CheckSparseIndexMaximumValue<TYPE_CLASS##Type>(shape);

// Runtime-dispatch wrapper: verify that every extent in `shape` fits in
// `index_value_type`.  Non-integer types yield a TypeError.
Status CheckSparseIndexMaximumValue(const std::shared_ptr<DataType>& index_value_type,
                                    const std::vector<int64_t>& shape) {
  switch (index_value_type->id()) {
    ARROW_GENERATE_FOR_ALL_INTEGER_TYPES(CALL_CHECK_MAXIMUM_VALUE);
    default:
      return Status::TypeError("Unsupported SparseTensor index value type");
  }
}

#undef CALL_CHECK_MAXIMUM_VALUE
88 | ||
89 | Status MakeSparseTensorFromTensor(const Tensor& tensor, | |
90 | SparseTensorFormat::type sparse_format_id, | |
91 | const std::shared_ptr<DataType>& index_value_type, | |
92 | MemoryPool* pool, | |
93 | std::shared_ptr<SparseIndex>* out_sparse_index, | |
94 | std::shared_ptr<Buffer>* out_data) { | |
95 | switch (sparse_format_id) { | |
96 | case SparseTensorFormat::COO: | |
97 | return MakeSparseCOOTensorFromTensor(tensor, index_value_type, pool, | |
98 | out_sparse_index, out_data); | |
99 | case SparseTensorFormat::CSR: | |
100 | return MakeSparseCSXMatrixFromTensor(SparseMatrixCompressedAxis::ROW, tensor, | |
101 | index_value_type, pool, out_sparse_index, | |
102 | out_data); | |
103 | case SparseTensorFormat::CSC: | |
104 | return MakeSparseCSXMatrixFromTensor(SparseMatrixCompressedAxis::COLUMN, tensor, | |
105 | index_value_type, pool, out_sparse_index, | |
106 | out_data); | |
107 | case SparseTensorFormat::CSF: | |
108 | return MakeSparseCSFTensorFromTensor(tensor, index_value_type, pool, | |
109 | out_sparse_index, out_data); | |
110 | ||
111 | // LCOV_EXCL_START: ignore program failure | |
112 | default: | |
113 | return Status::Invalid("Invalid sparse tensor format"); | |
114 | // LCOV_EXCL_STOP | |
115 | } | |
116 | } | |
117 | ||
118 | } // namespace internal | |
119 | ||
120 | // ---------------------------------------------------------------------- | |
121 | // SparseCOOIndex | |
122 | ||
123 | namespace { | |
124 | ||
125 | inline Status CheckSparseCOOIndexValidity(const std::shared_ptr<DataType>& type, | |
126 | const std::vector<int64_t>& shape, | |
127 | const std::vector<int64_t>& strides) { | |
128 | if (!is_integer(type->id())) { | |
129 | return Status::TypeError("Type of SparseCOOIndex indices must be integer"); | |
130 | } | |
131 | if (shape.size() != 2) { | |
132 | return Status::Invalid("SparseCOOIndex indices must be a matrix"); | |
133 | } | |
134 | ||
135 | RETURN_NOT_OK(internal::CheckSparseIndexMaximumValue(type, shape)); | |
136 | ||
137 | if (!internal::IsTensorStridesContiguous(type, shape, strides)) { | |
138 | return Status::Invalid("SparseCOOIndex indices must be contiguous"); | |
139 | } | |
140 | return Status::OK(); | |
141 | } | |
142 | ||
143 | void GetCOOIndexTensorRow(const std::shared_ptr<Tensor>& coords, const int64_t row, | |
144 | std::vector<int64_t>* out_index) { | |
145 | const auto& fw_index_value_type = | |
146 | internal::checked_cast<const FixedWidthType&>(*coords->type()); | |
147 | const size_t indices_elsize = fw_index_value_type.bit_width() / CHAR_BIT; | |
148 | ||
149 | const auto& shape = coords->shape(); | |
150 | const int64_t non_zero_length = shape[0]; | |
151 | DCHECK(0 <= row && row < non_zero_length); | |
152 | ||
153 | const int64_t ndim = shape[1]; | |
154 | out_index->resize(ndim); | |
155 | ||
156 | switch (indices_elsize) { | |
157 | case 1: // Int8, UInt8 | |
158 | for (int64_t i = 0; i < ndim; ++i) { | |
159 | (*out_index)[i] = static_cast<int64_t>(coords->Value<UInt8Type>({row, i})); | |
160 | } | |
161 | break; | |
162 | case 2: // Int16, UInt16 | |
163 | for (int64_t i = 0; i < ndim; ++i) { | |
164 | (*out_index)[i] = static_cast<int64_t>(coords->Value<UInt16Type>({row, i})); | |
165 | } | |
166 | break; | |
167 | case 4: // Int32, UInt32 | |
168 | for (int64_t i = 0; i < ndim; ++i) { | |
169 | (*out_index)[i] = static_cast<int64_t>(coords->Value<UInt32Type>({row, i})); | |
170 | } | |
171 | break; | |
172 | case 8: // Int64 | |
173 | for (int64_t i = 0; i < ndim; ++i) { | |
174 | (*out_index)[i] = coords->Value<Int64Type>({row, i}); | |
175 | } | |
176 | break; | |
177 | default: | |
178 | DCHECK(false) << "Must not reach here"; | |
179 | break; | |
180 | } | |
181 | } | |
182 | ||
183 | bool DetectSparseCOOIndexCanonicality(const std::shared_ptr<Tensor>& coords) { | |
184 | DCHECK_EQ(coords->ndim(), 2); | |
185 | ||
186 | const auto& shape = coords->shape(); | |
187 | const int64_t non_zero_length = shape[0]; | |
188 | if (non_zero_length <= 1) return true; | |
189 | ||
190 | const int64_t ndim = shape[1]; | |
191 | std::vector<int64_t> last_index, index; | |
192 | GetCOOIndexTensorRow(coords, 0, &last_index); | |
193 | for (int64_t i = 1; i < non_zero_length; ++i) { | |
194 | GetCOOIndexTensorRow(coords, i, &index); | |
195 | int64_t j = 0; | |
196 | while (j < ndim) { | |
197 | if (last_index[j] > index[j]) { | |
198 | // last_index > index, so we can detect non-canonical here | |
199 | return false; | |
200 | } | |
201 | if (last_index[j] < index[j]) { | |
202 | // last_index < index, so we can skip the remaining dimensions | |
203 | break; | |
204 | } | |
205 | ++j; | |
206 | } | |
207 | if (j == ndim) { | |
208 | // last_index == index, so we can detect non-canonical here | |
209 | return false; | |
210 | } | |
211 | swap(last_index, index); | |
212 | } | |
213 | ||
214 | return true; | |
215 | } | |
216 | ||
217 | } // namespace | |
218 | ||
219 | Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make( | |
220 | const std::shared_ptr<Tensor>& coords, bool is_canonical) { | |
221 | RETURN_NOT_OK( | |
222 | CheckSparseCOOIndexValidity(coords->type(), coords->shape(), coords->strides())); | |
223 | return std::make_shared<SparseCOOIndex>(coords, is_canonical); | |
224 | } | |
225 | ||
226 | Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make( | |
227 | const std::shared_ptr<Tensor>& coords) { | |
228 | RETURN_NOT_OK( | |
229 | CheckSparseCOOIndexValidity(coords->type(), coords->shape(), coords->strides())); | |
230 | auto is_canonical = DetectSparseCOOIndexCanonicality(coords); | |
231 | return std::make_shared<SparseCOOIndex>(coords, is_canonical); | |
232 | } | |
233 | ||
234 | Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make( | |
235 | const std::shared_ptr<DataType>& indices_type, | |
236 | const std::vector<int64_t>& indices_shape, | |
237 | const std::vector<int64_t>& indices_strides, std::shared_ptr<Buffer> indices_data, | |
238 | bool is_canonical) { | |
239 | RETURN_NOT_OK( | |
240 | CheckSparseCOOIndexValidity(indices_type, indices_shape, indices_strides)); | |
241 | return std::make_shared<SparseCOOIndex>( | |
242 | std::make_shared<Tensor>(indices_type, indices_data, indices_shape, | |
243 | indices_strides), | |
244 | is_canonical); | |
245 | } | |
246 | ||
247 | Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make( | |
248 | const std::shared_ptr<DataType>& indices_type, | |
249 | const std::vector<int64_t>& indices_shape, | |
250 | const std::vector<int64_t>& indices_strides, std::shared_ptr<Buffer> indices_data) { | |
251 | RETURN_NOT_OK( | |
252 | CheckSparseCOOIndexValidity(indices_type, indices_shape, indices_strides)); | |
253 | auto coords = std::make_shared<Tensor>(indices_type, indices_data, indices_shape, | |
254 | indices_strides); | |
255 | auto is_canonical = DetectSparseCOOIndexCanonicality(coords); | |
256 | return std::make_shared<SparseCOOIndex>(coords, is_canonical); | |
257 | } | |
258 | ||
259 | Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make( | |
260 | const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape, | |
261 | int64_t non_zero_length, std::shared_ptr<Buffer> indices_data, bool is_canonical) { | |
262 | auto ndim = static_cast<int64_t>(shape.size()); | |
263 | if (!is_integer(indices_type->id())) { | |
264 | return Status::TypeError("Type of SparseCOOIndex indices must be integer"); | |
265 | } | |
266 | const int64_t elsize = | |
267 | internal::checked_cast<const IntegerType&>(*indices_type).bit_width() / 8; | |
268 | std::vector<int64_t> indices_shape({non_zero_length, ndim}); | |
269 | std::vector<int64_t> indices_strides({elsize * ndim, elsize}); | |
270 | return Make(indices_type, indices_shape, indices_strides, indices_data, is_canonical); | |
271 | } | |
272 | ||
273 | Result<std::shared_ptr<SparseCOOIndex>> SparseCOOIndex::Make( | |
274 | const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape, | |
275 | int64_t non_zero_length, std::shared_ptr<Buffer> indices_data) { | |
276 | auto ndim = static_cast<int64_t>(shape.size()); | |
277 | if (!is_integer(indices_type->id())) { | |
278 | return Status::TypeError("Type of SparseCOOIndex indices must be integer"); | |
279 | } | |
280 | const int64_t elsize = internal::GetByteWidth(*indices_type); | |
281 | std::vector<int64_t> indices_shape({non_zero_length, ndim}); | |
282 | std::vector<int64_t> indices_strides({elsize * ndim, elsize}); | |
283 | return Make(indices_type, indices_shape, indices_strides, indices_data); | |
284 | } | |
285 | ||
// Constructor with a contiguous NumericTensor.
// Constructors cannot return a Status, so invalid coords abort the process
// via ARROW_CHECK_OK; prefer the Make() factories, which return an error.
SparseCOOIndex::SparseCOOIndex(const std::shared_ptr<Tensor>& coords, bool is_canonical)
    : SparseIndexBase(), coords_(coords), is_canonical_(is_canonical) {
  ARROW_CHECK_OK(
      CheckSparseCOOIndexValidity(coords_->type(), coords_->shape(), coords_->strides()));
}
292 | ||
293 | std::string SparseCOOIndex::ToString() const { return std::string("SparseCOOIndex"); } | |
294 | ||
295 | // ---------------------------------------------------------------------- | |
296 | // SparseCSXIndex | |
297 | ||
298 | namespace internal { | |
299 | ||
300 | Status ValidateSparseCSXIndex(const std::shared_ptr<DataType>& indptr_type, | |
301 | const std::shared_ptr<DataType>& indices_type, | |
302 | const std::vector<int64_t>& indptr_shape, | |
303 | const std::vector<int64_t>& indices_shape, | |
304 | char const* type_name) { | |
305 | if (!is_integer(indptr_type->id())) { | |
306 | return Status::TypeError("Type of ", type_name, " indptr must be integer"); | |
307 | } | |
308 | if (indptr_shape.size() != 1) { | |
309 | return Status::Invalid(type_name, " indptr must be a vector"); | |
310 | } | |
311 | if (!is_integer(indices_type->id())) { | |
312 | return Status::Invalid("Type of ", type_name, " indices must be integer"); | |
313 | } | |
314 | if (indices_shape.size() != 1) { | |
315 | return Status::Invalid(type_name, " indices must be a vector"); | |
316 | } | |
317 | ||
318 | RETURN_NOT_OK(internal::CheckSparseIndexMaximumValue(indptr_type, indptr_shape)); | |
319 | RETURN_NOT_OK(internal::CheckSparseIndexMaximumValue(indices_type, indices_shape)); | |
320 | ||
321 | return Status::OK(); | |
322 | } | |
323 | ||
// Abort-on-failure wrapper around ValidateSparseCSXIndex, for use from
// constructors that cannot return a Status.
void CheckSparseCSXIndexValidity(const std::shared_ptr<DataType>& indptr_type,
                                 const std::shared_ptr<DataType>& indices_type,
                                 const std::vector<int64_t>& indptr_shape,
                                 const std::vector<int64_t>& indices_shape,
                                 char const* type_name) {
  ARROW_CHECK_OK(ValidateSparseCSXIndex(indptr_type, indices_type, indptr_shape,
                                        indices_shape, type_name));
}
332 | ||
333 | } // namespace internal | |
334 | ||
335 | // ---------------------------------------------------------------------- | |
336 | // SparseCSFIndex | |
337 | ||
338 | namespace { | |
339 | ||
340 | inline Status CheckSparseCSFIndexValidity(const std::shared_ptr<DataType>& indptr_type, | |
341 | const std::shared_ptr<DataType>& indices_type, | |
342 | const int64_t num_indptrs, | |
343 | const int64_t num_indices, | |
344 | const int64_t axis_order_size) { | |
345 | if (!is_integer(indptr_type->id())) { | |
346 | return Status::TypeError("Type of SparseCSFIndex indptr must be integer"); | |
347 | } | |
348 | if (!is_integer(indices_type->id())) { | |
349 | return Status::TypeError("Type of SparseCSFIndex indices must be integer"); | |
350 | } | |
351 | if (num_indptrs + 1 != num_indices) { | |
352 | return Status::Invalid( | |
353 | "Length of indices must be equal to length of indptrs + 1 for SparseCSFIndex."); | |
354 | } | |
355 | if (axis_order_size != num_indices) { | |
356 | return Status::Invalid( | |
357 | "Length of indices must be equal to number of dimensions for SparseCSFIndex."); | |
358 | } | |
359 | return Status::OK(); | |
360 | } | |
361 | ||
362 | } // namespace | |
363 | ||
364 | Result<std::shared_ptr<SparseCSFIndex>> SparseCSFIndex::Make( | |
365 | const std::shared_ptr<DataType>& indptr_type, | |
366 | const std::shared_ptr<DataType>& indices_type, | |
367 | const std::vector<int64_t>& indices_shapes, const std::vector<int64_t>& axis_order, | |
368 | const std::vector<std::shared_ptr<Buffer>>& indptr_data, | |
369 | const std::vector<std::shared_ptr<Buffer>>& indices_data) { | |
370 | int64_t ndim = axis_order.size(); | |
371 | std::vector<std::shared_ptr<Tensor>> indptr(ndim - 1); | |
372 | std::vector<std::shared_ptr<Tensor>> indices(ndim); | |
373 | ||
374 | for (int64_t i = 0; i < ndim - 1; ++i) | |
375 | indptr[i] = std::make_shared<Tensor>(indptr_type, indptr_data[i], | |
376 | std::vector<int64_t>({indices_shapes[i] + 1})); | |
377 | for (int64_t i = 0; i < ndim; ++i) | |
378 | indices[i] = std::make_shared<Tensor>(indices_type, indices_data[i], | |
379 | std::vector<int64_t>({indices_shapes[i]})); | |
380 | ||
381 | RETURN_NOT_OK(CheckSparseCSFIndexValidity(indptr_type, indices_type, indptr.size(), | |
382 | indices.size(), axis_order.size())); | |
383 | ||
384 | for (auto tensor : indptr) { | |
385 | RETURN_NOT_OK(internal::CheckSparseIndexMaximumValue(indptr_type, tensor->shape())); | |
386 | } | |
387 | ||
388 | for (auto tensor : indices) { | |
389 | RETURN_NOT_OK(internal::CheckSparseIndexMaximumValue(indices_type, tensor->shape())); | |
390 | } | |
391 | ||
392 | return std::make_shared<SparseCSFIndex>(indptr, indices, axis_order); | |
393 | } | |
394 | ||
// Constructor with two index vectors.
// Constructors cannot return a Status, so invalid inputs abort via
// ARROW_CHECK_OK; prefer the Make() factory, which returns an error.
// NOTE(review): front() presumes indptr_ and indices_ are non-empty —
// verify callers (e.g. SparseCSFIndex::Make) never pass empty vectors.
SparseCSFIndex::SparseCSFIndex(const std::vector<std::shared_ptr<Tensor>>& indptr,
                               const std::vector<std::shared_ptr<Tensor>>& indices,
                               const std::vector<int64_t>& axis_order)
    : SparseIndexBase(), indptr_(indptr), indices_(indices), axis_order_(axis_order) {
  ARROW_CHECK_OK(CheckSparseCSFIndexValidity(indptr_.front()->type(),
                                             indices_.front()->type(), indptr_.size(),
                                             indices_.size(), axis_order_.size()));
}
404 | ||
405 | std::string SparseCSFIndex::ToString() const { return std::string("SparseCSFIndex"); } | |
406 | ||
407 | bool SparseCSFIndex::Equals(const SparseCSFIndex& other) const { | |
408 | for (int64_t i = 0; i < static_cast<int64_t>(indices().size()); ++i) { | |
409 | if (!indices()[i]->Equals(*other.indices()[i])) return false; | |
410 | } | |
411 | for (int64_t i = 0; i < static_cast<int64_t>(indptr().size()); ++i) { | |
412 | if (!indptr()[i]->Equals(*other.indptr()[i])) return false; | |
413 | } | |
414 | return axis_order() == other.axis_order(); | |
415 | } | |
416 | ||
417 | // ---------------------------------------------------------------------- | |
418 | // SparseTensor | |
419 | ||
// Constructor with all attributes.
// Stores the logical value type, the buffer of non-zero values, the dense
// logical shape, the sparse index describing where the values live, and
// optional dimension names (may be empty).
SparseTensor::SparseTensor(const std::shared_ptr<DataType>& type,
                           const std::shared_ptr<Buffer>& data,
                           const std::vector<int64_t>& shape,
                           const std::shared_ptr<SparseIndex>& sparse_index,
                           const std::vector<std::string>& dim_names)
    : type_(type),
      data_(data),
      shape_(shape),
      sparse_index_(sparse_index),
      dim_names_(dim_names) {
  // Aborts when the value type is not usable in a tensor; constructors
  // cannot return a Status.
  ARROW_CHECK(is_tensor_supported(type->id()));
}
433 | ||
434 | const std::string& SparseTensor::dim_name(int i) const { | |
435 | static const std::string kEmpty = ""; | |
436 | if (dim_names_.size() == 0) { | |
437 | return kEmpty; | |
438 | } else { | |
439 | ARROW_CHECK_LT(i, static_cast<int>(dim_names_.size())); | |
440 | return dim_names_[i]; | |
441 | } | |
442 | } | |
443 | ||
444 | int64_t SparseTensor::size() const { | |
445 | return std::accumulate(shape_.begin(), shape_.end(), 1LL, std::multiplies<int64_t>()); | |
446 | } | |
447 | ||
// Delegates full structural and value comparison to SparseTensorEquals
// (declared in arrow/compare.h), honoring the given EqualOptions.
bool SparseTensor::Equals(const SparseTensor& other, const EqualOptions& opts) const {
  return SparseTensorEquals(*this, other, opts);
}
451 | ||
452 | Result<std::shared_ptr<Tensor>> SparseTensor::ToTensor(MemoryPool* pool) const { | |
453 | switch (format_id()) { | |
454 | case SparseTensorFormat::COO: | |
455 | return MakeTensorFromSparseCOOTensor( | |
456 | pool, internal::checked_cast<const SparseCOOTensor*>(this)); | |
457 | break; | |
458 | ||
459 | case SparseTensorFormat::CSR: | |
460 | return MakeTensorFromSparseCSRMatrix( | |
461 | pool, internal::checked_cast<const SparseCSRMatrix*>(this)); | |
462 | break; | |
463 | ||
464 | case SparseTensorFormat::CSC: | |
465 | return MakeTensorFromSparseCSCMatrix( | |
466 | pool, internal::checked_cast<const SparseCSCMatrix*>(this)); | |
467 | break; | |
468 | ||
469 | case SparseTensorFormat::CSF: | |
470 | return MakeTensorFromSparseCSFTensor( | |
471 | pool, internal::checked_cast<const SparseCSFTensor*>(this)); | |
472 | ||
473 | default: | |
474 | return Status::NotImplemented("Unsupported SparseIndex format type"); | |
475 | } | |
476 | } | |
477 | ||
478 | } // namespace arrow |