1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
#include "arrow/tensor/converter.h"

#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "arrow/buffer.h"
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/util/checked_cast.h"
#include "arrow/visitor_inline.h"
30 | ||
31 | namespace arrow { | |
32 | ||
33 | class MemoryPool; | |
34 | ||
35 | namespace internal { | |
36 | namespace { | |
37 | ||
38 | // ---------------------------------------------------------------------- | |
39 | // SparseTensorConverter for SparseCSRIndex | |
40 | ||
41 | class SparseCSXMatrixConverter : private SparseTensorConverterMixin { | |
42 | using SparseTensorConverterMixin::AssignIndex; | |
43 | using SparseTensorConverterMixin::IsNonZero; | |
44 | ||
45 | public: | |
46 | SparseCSXMatrixConverter(SparseMatrixCompressedAxis axis, const Tensor& tensor, | |
47 | const std::shared_ptr<DataType>& index_value_type, | |
48 | MemoryPool* pool) | |
49 | : axis_(axis), tensor_(tensor), index_value_type_(index_value_type), pool_(pool) {} | |
50 | ||
51 | Status Convert() { | |
52 | RETURN_NOT_OK(::arrow::internal::CheckSparseIndexMaximumValue(index_value_type_, | |
53 | tensor_.shape())); | |
54 | ||
55 | const int index_elsize = GetByteWidth(*index_value_type_); | |
56 | const int value_elsize = GetByteWidth(*tensor_.type()); | |
57 | ||
58 | const int64_t ndim = tensor_.ndim(); | |
59 | if (ndim > 2) { | |
60 | return Status::Invalid("Invalid tensor dimension"); | |
61 | } | |
62 | ||
63 | const int major_axis = static_cast<int>(axis_); | |
64 | const int64_t n_major = tensor_.shape()[major_axis]; | |
65 | const int64_t n_minor = tensor_.shape()[1 - major_axis]; | |
66 | ARROW_ASSIGN_OR_RAISE(int64_t nonzero_count, tensor_.CountNonZero()); | |
67 | ||
68 | std::shared_ptr<Buffer> indptr_buffer; | |
69 | std::shared_ptr<Buffer> indices_buffer; | |
70 | ||
71 | ARROW_ASSIGN_OR_RAISE(auto values_buffer, | |
72 | AllocateBuffer(value_elsize * nonzero_count, pool_)); | |
73 | auto* values = values_buffer->mutable_data(); | |
74 | ||
75 | const auto* tensor_data = tensor_.raw_data(); | |
76 | ||
77 | if (ndim <= 1) { | |
78 | return Status::NotImplemented("TODO for ndim <= 1"); | |
79 | } else { | |
80 | ARROW_ASSIGN_OR_RAISE(indptr_buffer, | |
81 | AllocateBuffer(index_elsize * (n_major + 1), pool_)); | |
82 | auto* indptr = indptr_buffer->mutable_data(); | |
83 | ||
84 | ARROW_ASSIGN_OR_RAISE(indices_buffer, | |
85 | AllocateBuffer(index_elsize * nonzero_count, pool_)); | |
86 | auto* indices = indices_buffer->mutable_data(); | |
87 | ||
88 | std::vector<int64_t> coords(2); | |
89 | int64_t k = 0; | |
90 | std::fill_n(indptr, index_elsize, 0); | |
91 | indptr += index_elsize; | |
92 | for (int64_t i = 0; i < n_major; ++i) { | |
93 | for (int64_t j = 0; j < n_minor; ++j) { | |
94 | if (axis_ == SparseMatrixCompressedAxis::ROW) { | |
95 | coords = {i, j}; | |
96 | } else { | |
97 | coords = {j, i}; | |
98 | } | |
99 | const int64_t offset = tensor_.CalculateValueOffset(coords); | |
100 | if (std::any_of(tensor_data + offset, tensor_data + offset + value_elsize, | |
101 | IsNonZero)) { | |
102 | std::copy_n(tensor_data + offset, value_elsize, values); | |
103 | values += value_elsize; | |
104 | ||
105 | AssignIndex(indices, j, index_elsize); | |
106 | indices += index_elsize; | |
107 | ||
108 | k++; | |
109 | } | |
110 | } | |
111 | AssignIndex(indptr, k, index_elsize); | |
112 | indptr += index_elsize; | |
113 | } | |
114 | } | |
115 | ||
116 | std::vector<int64_t> indptr_shape({n_major + 1}); | |
117 | std::shared_ptr<Tensor> indptr_tensor = | |
118 | std::make_shared<Tensor>(index_value_type_, indptr_buffer, indptr_shape); | |
119 | ||
120 | std::vector<int64_t> indices_shape({nonzero_count}); | |
121 | std::shared_ptr<Tensor> indices_tensor = | |
122 | std::make_shared<Tensor>(index_value_type_, indices_buffer, indices_shape); | |
123 | ||
124 | if (axis_ == SparseMatrixCompressedAxis::ROW) { | |
125 | sparse_index = std::make_shared<SparseCSRIndex>(indptr_tensor, indices_tensor); | |
126 | } else { | |
127 | sparse_index = std::make_shared<SparseCSCIndex>(indptr_tensor, indices_tensor); | |
128 | } | |
129 | data = std::move(values_buffer); | |
130 | ||
131 | return Status::OK(); | |
132 | } | |
133 | ||
134 | std::shared_ptr<SparseIndex> sparse_index; | |
135 | std::shared_ptr<Buffer> data; | |
136 | ||
137 | private: | |
138 | SparseMatrixCompressedAxis axis_; | |
139 | const Tensor& tensor_; | |
140 | const std::shared_ptr<DataType>& index_value_type_; | |
141 | MemoryPool* pool_; | |
142 | }; | |
143 | ||
144 | } // namespace | |
145 | ||
146 | Status MakeSparseCSXMatrixFromTensor(SparseMatrixCompressedAxis axis, | |
147 | const Tensor& tensor, | |
148 | const std::shared_ptr<DataType>& index_value_type, | |
149 | MemoryPool* pool, | |
150 | std::shared_ptr<SparseIndex>* out_sparse_index, | |
151 | std::shared_ptr<Buffer>* out_data) { | |
152 | SparseCSXMatrixConverter converter(axis, tensor, index_value_type, pool); | |
153 | RETURN_NOT_OK(converter.Convert()); | |
154 | ||
155 | *out_sparse_index = converter.sparse_index; | |
156 | *out_data = converter.data; | |
157 | return Status::OK(); | |
158 | } | |
159 | ||
// Reconstructs a dense row-major Tensor from CSR/CSC components.
//
// \param axis axis the input is compressed along (ROW = CSR, COLUMN = CSC)
// \param pool memory pool used to allocate the dense values buffer
// \param indptr compressed index pointer vector (length n_major + 1)
// \param indices minor-axis coordinates of the stored values
// \param non_zero_length number of stored non-zero values
//     (NOTE(review): not referenced in this body — the value count is
//     derived from indptr instead; presumably kept for API symmetry)
// \param value_type fixed-width value type of the result
// \param shape dense shape; shape[1] is read as the column count, so this
//     expects a 2-D shape
// \param tensor_size total number of elements in the dense result
// \param raw_data packed non-zero values, consumed sequentially in
//     indptr/indices order
// \param dim_names dimension names forwarded to the resulting Tensor
Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCSXMatrix(
    SparseMatrixCompressedAxis axis, MemoryPool* pool,
    const std::shared_ptr<Tensor>& indptr, const std::shared_ptr<Tensor>& indices,
    const int64_t non_zero_length, const std::shared_ptr<DataType>& value_type,
    const std::vector<int64_t>& shape, const int64_t tensor_size, const uint8_t* raw_data,
    const std::vector<std::string>& dim_names) {
  const auto* indptr_data = indptr->raw_data();
  const auto* indices_data = indices->raw_data();

  // Byte widths of the two integer index types and the value type.
  const int indptr_elsize = GetByteWidth(*indptr->type());
  const int indices_elsize = GetByteWidth(*indices->type());

  const auto& fw_value_type = checked_cast<const FixedWidthType&>(*value_type);
  const int value_elsize = GetByteWidth(fw_value_type);
  // Zero-fill the dense output; only non-zero positions are written below.
  ARROW_ASSIGN_OR_RAISE(auto values_buffer,
                        AllocateBuffer(value_elsize * tensor_size, pool));
  auto values = values_buffer->mutable_data();
  std::fill_n(values, value_elsize * tensor_size, 0);

  // The result is always materialized in row-major order, regardless of axis.
  std::vector<int64_t> strides;
  RETURN_NOT_OK(ComputeRowMajorStrides(fw_value_type, shape, &strides));

  // Number of columns of the dense matrix.
  const auto nc = shape[1];

  int64_t offset = 0;
  // i walks the compressed (major) axis; indptr[i]..indptr[i+1] delimits the
  // run of stored values belonging to major line i.
  for (int64_t i = 0; i < indptr->size() - 1; ++i) {
    const auto start =
        SparseTensorConverterMixin::GetIndexValue(indptr_data, indptr_elsize);
    const auto stop = SparseTensorConverterMixin::GetIndexValue(
        indptr_data + indptr_elsize, indptr_elsize);

    for (int64_t j = start; j < stop; ++j) {
      // Minor-axis coordinate of the j-th stored value.
      const auto index = SparseTensorConverterMixin::GetIndexValue(
          indices_data + j * indices_elsize, indices_elsize);
      switch (axis) {
        case SparseMatrixCompressedAxis::ROW:
          // CSR: i is the row, index is the column.
          offset = (index + i * nc) * value_elsize;
          break;
        case SparseMatrixCompressedAxis::COLUMN:
          // CSC: i is the column, index is the row.
          offset = (i + index * nc) * value_elsize;
          break;
      }

      // raw_data is consumed strictly sequentially, in storage order.
      std::copy_n(raw_data, value_elsize, values + offset);
      raw_data += value_elsize;
    }

    // Advance to the next indptr entry (stop becomes the next start).
    indptr_data += indptr_elsize;
  }

  return std::make_shared<Tensor>(value_type, std::move(values_buffer), shape, strides,
                                  dim_names);
}
213 | ||
214 | Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCSRMatrix( | |
215 | MemoryPool* pool, const SparseCSRMatrix* sparse_tensor) { | |
216 | const auto& sparse_index = | |
217 | internal::checked_cast<const SparseCSRIndex&>(*sparse_tensor->sparse_index()); | |
218 | const auto& indptr = sparse_index.indptr(); | |
219 | const auto& indices = sparse_index.indices(); | |
220 | const auto non_zero_length = sparse_tensor->non_zero_length(); | |
221 | return MakeTensorFromSparseCSXMatrix( | |
222 | SparseMatrixCompressedAxis::ROW, pool, indptr, indices, non_zero_length, | |
223 | sparse_tensor->type(), sparse_tensor->shape(), sparse_tensor->size(), | |
224 | sparse_tensor->raw_data(), sparse_tensor->dim_names()); | |
225 | } | |
226 | ||
227 | Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCSCMatrix( | |
228 | MemoryPool* pool, const SparseCSCMatrix* sparse_tensor) { | |
229 | const auto& sparse_index = | |
230 | internal::checked_cast<const SparseCSCIndex&>(*sparse_tensor->sparse_index()); | |
231 | const auto& indptr = sparse_index.indptr(); | |
232 | const auto& indices = sparse_index.indices(); | |
233 | const auto non_zero_length = sparse_tensor->non_zero_length(); | |
234 | return MakeTensorFromSparseCSXMatrix( | |
235 | SparseMatrixCompressedAxis::COLUMN, pool, indptr, indices, non_zero_length, | |
236 | sparse_tensor->type(), sparse_tensor->shape(), sparse_tensor->size(), | |
237 | sparse_tensor->raw_data(), sparse_tensor->dim_names()); | |
238 | } | |
239 | ||
240 | } // namespace internal | |
241 | } // namespace arrow |