]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include "arrow/record_batch.h" | |
19 | ||
20 | #include <algorithm> | |
21 | #include <cstdlib> | |
22 | #include <memory> | |
23 | #include <sstream> | |
24 | #include <string> | |
25 | #include <utility> | |
26 | ||
27 | #include "arrow/array.h" | |
28 | #include "arrow/array/validate.h" | |
29 | #include "arrow/pretty_print.h" | |
30 | #include "arrow/status.h" | |
31 | #include "arrow/table.h" | |
32 | #include "arrow/type.h" | |
33 | #include "arrow/util/atomic_shared_ptr.h" | |
34 | #include "arrow/util/iterator.h" | |
35 | #include "arrow/util/logging.h" | |
36 | #include "arrow/util/vector.h" | |
37 | ||
38 | namespace arrow { | |
39 | ||
40 | Result<std::shared_ptr<RecordBatch>> RecordBatch::AddColumn( | |
41 | int i, std::string field_name, const std::shared_ptr<Array>& column) const { | |
42 | auto field = ::arrow::field(std::move(field_name), column->type()); | |
43 | return AddColumn(i, field, column); | |
44 | } | |
45 | ||
46 | std::shared_ptr<Array> RecordBatch::GetColumnByName(const std::string& name) const { | |
47 | auto i = schema_->GetFieldIndex(name); | |
48 | return i == -1 ? NULLPTR : column(i); | |
49 | } | |
50 | ||
51 | int RecordBatch::num_columns() const { return schema_->num_fields(); } | |
52 | ||
53 | /// \class SimpleRecordBatch | |
54 | /// \brief A basic, non-lazy in-memory record batch | |
55 | class SimpleRecordBatch : public RecordBatch { | |
56 | public: | |
57 | SimpleRecordBatch(std::shared_ptr<Schema> schema, int64_t num_rows, | |
58 | std::vector<std::shared_ptr<Array>> columns) | |
59 | : RecordBatch(std::move(schema), num_rows), boxed_columns_(std::move(columns)) { | |
60 | columns_.resize(boxed_columns_.size()); | |
61 | for (size_t i = 0; i < columns_.size(); ++i) { | |
62 | columns_[i] = boxed_columns_[i]->data(); | |
63 | } | |
64 | } | |
65 | ||
66 | SimpleRecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows, | |
67 | std::vector<std::shared_ptr<ArrayData>> columns) | |
68 | : RecordBatch(std::move(schema), num_rows), columns_(std::move(columns)) { | |
69 | boxed_columns_.resize(schema_->num_fields()); | |
70 | } | |
71 | ||
72 | const std::vector<std::shared_ptr<Array>>& columns() const override { | |
73 | for (int i = 0; i < num_columns(); ++i) { | |
74 | // Force all columns to be boxed | |
75 | column(i); | |
76 | } | |
77 | return boxed_columns_; | |
78 | } | |
79 | ||
80 | std::shared_ptr<Array> column(int i) const override { | |
81 | std::shared_ptr<Array> result = internal::atomic_load(&boxed_columns_[i]); | |
82 | if (!result) { | |
83 | result = MakeArray(columns_[i]); | |
84 | internal::atomic_store(&boxed_columns_[i], result); | |
85 | } | |
86 | return result; | |
87 | } | |
88 | ||
89 | std::shared_ptr<ArrayData> column_data(int i) const override { return columns_[i]; } | |
90 | ||
91 | const ArrayDataVector& column_data() const override { return columns_; } | |
92 | ||
93 | Result<std::shared_ptr<RecordBatch>> AddColumn( | |
94 | int i, const std::shared_ptr<Field>& field, | |
95 | const std::shared_ptr<Array>& column) const override { | |
96 | ARROW_CHECK(field != nullptr); | |
97 | ARROW_CHECK(column != nullptr); | |
98 | ||
99 | if (!field->type()->Equals(column->type())) { | |
100 | return Status::TypeError("Column data type ", field->type()->name(), | |
101 | " does not match field data type ", | |
102 | column->type()->name()); | |
103 | } | |
104 | if (column->length() != num_rows_) { | |
105 | return Status::Invalid( | |
106 | "Added column's length must match record batch's length. Expected length ", | |
107 | num_rows_, " but got length ", column->length()); | |
108 | } | |
109 | ||
110 | ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->AddField(i, field)); | |
111 | return RecordBatch::Make(std::move(new_schema), num_rows_, | |
112 | internal::AddVectorElement(columns_, i, column->data())); | |
113 | } | |
114 | ||
115 | Result<std::shared_ptr<RecordBatch>> SetColumn( | |
116 | int i, const std::shared_ptr<Field>& field, | |
117 | const std::shared_ptr<Array>& column) const override { | |
118 | ARROW_CHECK(field != nullptr); | |
119 | ARROW_CHECK(column != nullptr); | |
120 | ||
121 | if (!field->type()->Equals(column->type())) { | |
122 | return Status::TypeError("Column data type ", field->type()->name(), | |
123 | " does not match field data type ", | |
124 | column->type()->name()); | |
125 | } | |
126 | if (column->length() != num_rows_) { | |
127 | return Status::Invalid( | |
128 | "Added column's length must match record batch's length. Expected length ", | |
129 | num_rows_, " but got length ", column->length()); | |
130 | } | |
131 | ||
132 | ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->SetField(i, field)); | |
133 | return RecordBatch::Make(std::move(new_schema), num_rows_, | |
134 | internal::ReplaceVectorElement(columns_, i, column->data())); | |
135 | } | |
136 | ||
137 | Result<std::shared_ptr<RecordBatch>> RemoveColumn(int i) const override { | |
138 | ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->RemoveField(i)); | |
139 | return RecordBatch::Make(std::move(new_schema), num_rows_, | |
140 | internal::DeleteVectorElement(columns_, i)); | |
141 | } | |
142 | ||
143 | std::shared_ptr<RecordBatch> ReplaceSchemaMetadata( | |
144 | const std::shared_ptr<const KeyValueMetadata>& metadata) const override { | |
145 | auto new_schema = schema_->WithMetadata(metadata); | |
146 | return RecordBatch::Make(std::move(new_schema), num_rows_, columns_); | |
147 | } | |
148 | ||
149 | std::shared_ptr<RecordBatch> Slice(int64_t offset, int64_t length) const override { | |
150 | std::vector<std::shared_ptr<ArrayData>> arrays; | |
151 | arrays.reserve(num_columns()); | |
152 | for (const auto& field : columns_) { | |
153 | arrays.emplace_back(field->Slice(offset, length)); | |
154 | } | |
155 | int64_t num_rows = std::min(num_rows_ - offset, length); | |
156 | return std::make_shared<SimpleRecordBatch>(schema_, num_rows, std::move(arrays)); | |
157 | } | |
158 | ||
159 | Status Validate() const override { | |
160 | if (static_cast<int>(columns_.size()) != schema_->num_fields()) { | |
161 | return Status::Invalid("Number of columns did not match schema"); | |
162 | } | |
163 | return RecordBatch::Validate(); | |
164 | } | |
165 | ||
166 | private: | |
167 | std::vector<std::shared_ptr<ArrayData>> columns_; | |
168 | ||
169 | // Caching boxed array data | |
170 | mutable std::vector<std::shared_ptr<Array>> boxed_columns_; | |
171 | }; | |
172 | ||
173 | RecordBatch::RecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows) | |
174 | : schema_(schema), num_rows_(num_rows) {} | |
175 | ||
176 | std::shared_ptr<RecordBatch> RecordBatch::Make( | |
177 | std::shared_ptr<Schema> schema, int64_t num_rows, | |
178 | std::vector<std::shared_ptr<Array>> columns) { | |
179 | DCHECK_EQ(schema->num_fields(), static_cast<int>(columns.size())); | |
180 | return std::make_shared<SimpleRecordBatch>(std::move(schema), num_rows, columns); | |
181 | } | |
182 | ||
183 | std::shared_ptr<RecordBatch> RecordBatch::Make( | |
184 | std::shared_ptr<Schema> schema, int64_t num_rows, | |
185 | std::vector<std::shared_ptr<ArrayData>> columns) { | |
186 | DCHECK_EQ(schema->num_fields(), static_cast<int>(columns.size())); | |
187 | return std::make_shared<SimpleRecordBatch>(std::move(schema), num_rows, | |
188 | std::move(columns)); | |
189 | } | |
190 | ||
191 | Result<std::shared_ptr<RecordBatch>> RecordBatch::FromStructArray( | |
192 | const std::shared_ptr<Array>& array) { | |
193 | if (array->type_id() != Type::STRUCT) { | |
194 | return Status::TypeError("Cannot construct record batch from array of type ", | |
195 | *array->type()); | |
196 | } | |
197 | if (array->null_count() != 0) { | |
198 | return Status::Invalid( | |
199 | "Unable to construct record batch from a StructArray with non-zero nulls."); | |
200 | } | |
201 | return Make(arrow::schema(array->type()->fields()), array->length(), | |
202 | array->data()->child_data); | |
203 | } | |
204 | ||
205 | Result<std::shared_ptr<StructArray>> RecordBatch::ToStructArray() const { | |
206 | if (num_columns() != 0) { | |
207 | return StructArray::Make(columns(), schema()->fields()); | |
208 | } | |
209 | return std::make_shared<StructArray>(arrow::struct_({}), num_rows_, | |
210 | std::vector<std::shared_ptr<Array>>{}, | |
211 | /*null_bitmap=*/nullptr, | |
212 | /*null_count=*/0, | |
213 | /*offset=*/0); | |
214 | } | |
215 | ||
216 | const std::string& RecordBatch::column_name(int i) const { | |
217 | return schema_->field(i)->name(); | |
218 | } | |
219 | ||
220 | bool RecordBatch::Equals(const RecordBatch& other, bool check_metadata) const { | |
221 | if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) { | |
222 | return false; | |
223 | } | |
224 | ||
225 | if (check_metadata) { | |
226 | if (!schema_->Equals(*other.schema(), /*check_metadata=*/true)) { | |
227 | return false; | |
228 | } | |
229 | } | |
230 | ||
231 | for (int i = 0; i < num_columns(); ++i) { | |
232 | if (!column(i)->Equals(other.column(i))) { | |
233 | return false; | |
234 | } | |
235 | } | |
236 | ||
237 | return true; | |
238 | } | |
239 | ||
240 | bool RecordBatch::ApproxEquals(const RecordBatch& other) const { | |
241 | if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) { | |
242 | return false; | |
243 | } | |
244 | ||
245 | for (int i = 0; i < num_columns(); ++i) { | |
246 | if (!column(i)->ApproxEquals(other.column(i))) { | |
247 | return false; | |
248 | } | |
249 | } | |
250 | ||
251 | return true; | |
252 | } | |
253 | ||
254 | Result<std::shared_ptr<RecordBatch>> RecordBatch::SelectColumns( | |
255 | const std::vector<int>& indices) const { | |
256 | int n = static_cast<int>(indices.size()); | |
257 | ||
258 | FieldVector fields(n); | |
259 | ArrayVector columns(n); | |
260 | ||
261 | for (int i = 0; i < n; i++) { | |
262 | int pos = indices[i]; | |
263 | if (pos < 0 || pos > num_columns() - 1) { | |
264 | return Status::Invalid("Invalid column index ", pos, " to select columns."); | |
265 | } | |
266 | fields[i] = schema()->field(pos); | |
267 | columns[i] = column(pos); | |
268 | } | |
269 | ||
270 | auto new_schema = | |
271 | std::make_shared<arrow::Schema>(std::move(fields), schema()->metadata()); | |
272 | return RecordBatch::Make(std::move(new_schema), num_rows(), std::move(columns)); | |
273 | } | |
274 | ||
275 | std::shared_ptr<RecordBatch> RecordBatch::Slice(int64_t offset) const { | |
276 | return Slice(offset, this->num_rows() - offset); | |
277 | } | |
278 | ||
279 | std::string RecordBatch::ToString() const { | |
280 | std::stringstream ss; | |
281 | ARROW_CHECK_OK(PrettyPrint(*this, 0, &ss)); | |
282 | return ss.str(); | |
283 | } | |
284 | ||
285 | Status RecordBatch::Validate() const { | |
286 | for (int i = 0; i < num_columns(); ++i) { | |
287 | const auto& array = *this->column(i); | |
288 | if (array.length() != num_rows_) { | |
289 | return Status::Invalid("Number of rows in column ", i, | |
290 | " did not match batch: ", array.length(), " vs ", num_rows_); | |
291 | } | |
292 | const auto& schema_type = *schema_->field(i)->type(); | |
293 | if (!array.type()->Equals(schema_type)) { | |
294 | return Status::Invalid("Column ", i, | |
295 | " type not match schema: ", array.type()->ToString(), " vs ", | |
296 | schema_type.ToString()); | |
297 | } | |
298 | RETURN_NOT_OK(internal::ValidateArray(array)); | |
299 | } | |
300 | return Status::OK(); | |
301 | } | |
302 | ||
303 | Status RecordBatch::ValidateFull() const { | |
304 | RETURN_NOT_OK(Validate()); | |
305 | for (int i = 0; i < num_columns(); ++i) { | |
306 | const auto& array = *this->column(i); | |
307 | RETURN_NOT_OK(internal::ValidateArrayFull(array)); | |
308 | } | |
309 | return Status::OK(); | |
310 | } | |
311 | ||
312 | // ---------------------------------------------------------------------- | |
313 | // Base record batch reader | |
314 | ||
315 | Status RecordBatchReader::ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches) { | |
316 | while (true) { | |
317 | std::shared_ptr<RecordBatch> batch; | |
318 | RETURN_NOT_OK(ReadNext(&batch)); | |
319 | if (!batch) { | |
320 | break; | |
321 | } | |
322 | batches->emplace_back(std::move(batch)); | |
323 | } | |
324 | return Status::OK(); | |
325 | } | |
326 | ||
327 | Status RecordBatchReader::ReadAll(std::shared_ptr<Table>* table) { | |
328 | std::vector<std::shared_ptr<RecordBatch>> batches; | |
329 | RETURN_NOT_OK(ReadAll(&batches)); | |
330 | return Table::FromRecordBatches(schema(), std::move(batches)).Value(table); | |
331 | } | |
332 | ||
333 | class SimpleRecordBatchReader : public RecordBatchReader { | |
334 | public: | |
335 | SimpleRecordBatchReader(Iterator<std::shared_ptr<RecordBatch>> it, | |
336 | std::shared_ptr<Schema> schema) | |
337 | : schema_(std::move(schema)), it_(std::move(it)) {} | |
338 | ||
339 | SimpleRecordBatchReader(std::vector<std::shared_ptr<RecordBatch>> batches, | |
340 | std::shared_ptr<Schema> schema) | |
341 | : schema_(std::move(schema)), it_(MakeVectorIterator(std::move(batches))) {} | |
342 | ||
343 | Status ReadNext(std::shared_ptr<RecordBatch>* batch) override { | |
344 | return it_.Next().Value(batch); | |
345 | } | |
346 | ||
347 | std::shared_ptr<Schema> schema() const override { return schema_; } | |
348 | ||
349 | protected: | |
350 | std::shared_ptr<Schema> schema_; | |
351 | Iterator<std::shared_ptr<RecordBatch>> it_; | |
352 | }; | |
353 | ||
354 | Result<std::shared_ptr<RecordBatchReader>> RecordBatchReader::Make( | |
355 | std::vector<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema> schema) { | |
356 | if (schema == nullptr) { | |
357 | if (batches.size() == 0 || batches[0] == nullptr) { | |
358 | return Status::Invalid("Cannot infer schema from empty vector or nullptr"); | |
359 | } | |
360 | ||
361 | schema = batches[0]->schema(); | |
362 | } | |
363 | ||
364 | return std::make_shared<SimpleRecordBatchReader>(std::move(batches), schema); | |
365 | } | |
366 | ||
367 | } // namespace arrow |