]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include "arrow/record_batch.h" | |
19 | ||
20 | #include <gmock/gmock.h> | |
21 | #include <gtest/gtest.h> | |
22 | ||
23 | #include <cstdint> | |
24 | #include <memory> | |
25 | #include <vector> | |
26 | ||
27 | #include "arrow/array/array_base.h" | |
28 | #include "arrow/array/data.h" | |
29 | #include "arrow/array/util.h" | |
30 | #include "arrow/chunked_array.h" | |
31 | #include "arrow/status.h" | |
32 | #include "arrow/table.h" | |
33 | #include "arrow/testing/gtest_common.h" | |
34 | #include "arrow/testing/gtest_util.h" | |
35 | #include "arrow/type.h" | |
36 | #include "arrow/util/key_value_metadata.h" | |
37 | ||
38 | namespace arrow { | |
39 | ||
40 | class TestRecordBatch : public TestBase {}; | |
41 | ||
42 | TEST_F(TestRecordBatch, Equals) { | |
43 | const int length = 10; | |
44 | ||
45 | auto f0 = field("f0", int32()); | |
46 | auto f1 = field("f1", uint8()); | |
47 | auto f2 = field("f2", int16()); | |
48 | ||
49 | auto metadata = key_value_metadata({"foo"}, {"bar"}); | |
50 | ||
51 | std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2}; | |
52 | auto schema = ::arrow::schema({f0, f1, f2}); | |
53 | auto schema2 = ::arrow::schema({f0, f1}); | |
54 | auto schema3 = ::arrow::schema({f0, f1, f2}, metadata); | |
55 | ||
56 | auto a0 = MakeRandomArray<Int32Array>(length); | |
57 | auto a1 = MakeRandomArray<UInt8Array>(length); | |
58 | auto a2 = MakeRandomArray<Int16Array>(length); | |
59 | ||
60 | auto b1 = RecordBatch::Make(schema, length, {a0, a1, a2}); | |
61 | auto b2 = RecordBatch::Make(schema3, length, {a0, a1, a2}); | |
62 | auto b3 = RecordBatch::Make(schema2, length, {a0, a1}); | |
63 | auto b4 = RecordBatch::Make(schema, length, {a0, a1, a1}); | |
64 | ||
65 | ASSERT_TRUE(b1->Equals(*b1)); | |
66 | ASSERT_FALSE(b1->Equals(*b3)); | |
67 | ASSERT_FALSE(b1->Equals(*b4)); | |
68 | ||
69 | // Different metadata | |
70 | ASSERT_TRUE(b1->Equals(*b2)); | |
71 | ASSERT_FALSE(b1->Equals(*b2, /*check_metadata=*/true)); | |
72 | } | |
73 | ||
74 | TEST_F(TestRecordBatch, Validate) { | |
75 | const int length = 10; | |
76 | ||
77 | auto f0 = field("f0", int32()); | |
78 | auto f1 = field("f1", uint8()); | |
79 | auto f2 = field("f2", int16()); | |
80 | ||
81 | auto schema = ::arrow::schema({f0, f1, f2}); | |
82 | ||
83 | auto a0 = MakeRandomArray<Int32Array>(length); | |
84 | auto a1 = MakeRandomArray<UInt8Array>(length); | |
85 | auto a2 = MakeRandomArray<Int16Array>(length); | |
86 | auto a3 = MakeRandomArray<Int16Array>(5); | |
87 | ||
88 | auto b1 = RecordBatch::Make(schema, length, {a0, a1, a2}); | |
89 | ||
90 | ASSERT_OK(b1->ValidateFull()); | |
91 | ||
92 | // Length mismatch | |
93 | auto b2 = RecordBatch::Make(schema, length, {a0, a1, a3}); | |
94 | ASSERT_RAISES(Invalid, b2->ValidateFull()); | |
95 | ||
96 | // Type mismatch | |
97 | auto b3 = RecordBatch::Make(schema, length, {a0, a1, a0}); | |
98 | ASSERT_RAISES(Invalid, b3->ValidateFull()); | |
99 | } | |
100 | ||
101 | TEST_F(TestRecordBatch, Slice) { | |
102 | const int length = 7; | |
103 | ||
104 | auto f0 = field("f0", int32()); | |
105 | auto f1 = field("f1", uint8()); | |
106 | auto f2 = field("f2", int8()); | |
107 | ||
108 | std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2}; | |
109 | auto schema = ::arrow::schema(fields); | |
110 | ||
111 | auto a0 = MakeRandomArray<Int32Array>(length); | |
112 | auto a1 = MakeRandomArray<UInt8Array>(length); | |
113 | auto a2 = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6]"); | |
114 | ||
115 | auto batch = RecordBatch::Make(schema, length, {a0, a1, a2}); | |
116 | ||
117 | auto batch_slice = batch->Slice(2); | |
118 | auto batch_slice2 = batch->Slice(1, 5); | |
119 | ||
120 | ASSERT_EQ(batch_slice->num_rows(), batch->num_rows() - 2); | |
121 | ||
122 | for (int i = 0; i < batch->num_columns(); ++i) { | |
123 | ASSERT_EQ(2, batch_slice->column(i)->offset()); | |
124 | ASSERT_EQ(length - 2, batch_slice->column(i)->length()); | |
125 | ||
126 | ASSERT_EQ(1, batch_slice2->column(i)->offset()); | |
127 | ASSERT_EQ(5, batch_slice2->column(i)->length()); | |
128 | } | |
129 | ||
130 | // ARROW-9143: RecordBatch::Slice was incorrectly setting a2's | |
131 | // ArrayData::null_count to kUnknownNullCount | |
132 | ASSERT_EQ(batch_slice->column(2)->data()->null_count, 0); | |
133 | ASSERT_EQ(batch_slice2->column(2)->data()->null_count, 0); | |
134 | } | |
135 | ||
136 | TEST_F(TestRecordBatch, AddColumn) { | |
137 | const int length = 10; | |
138 | ||
139 | auto field1 = field("f1", int32()); | |
140 | auto field2 = field("f2", uint8()); | |
141 | auto field3 = field("f3", int16()); | |
142 | ||
143 | auto schema1 = ::arrow::schema({field1, field2}); | |
144 | auto schema2 = ::arrow::schema({field2, field3}); | |
145 | auto schema3 = ::arrow::schema({field2}); | |
146 | ||
147 | auto array1 = MakeRandomArray<Int32Array>(length); | |
148 | auto array2 = MakeRandomArray<UInt8Array>(length); | |
149 | auto array3 = MakeRandomArray<Int16Array>(length); | |
150 | ||
151 | auto batch1 = RecordBatch::Make(schema1, length, {array1, array2}); | |
152 | auto batch2 = RecordBatch::Make(schema2, length, {array2, array3}); | |
153 | auto batch3 = RecordBatch::Make(schema3, length, {array2}); | |
154 | ||
155 | const RecordBatch& batch = *batch3; | |
156 | ||
157 | // Negative tests with invalid index | |
158 | ASSERT_RAISES(Invalid, batch.AddColumn(5, field1, array1)); | |
159 | ASSERT_RAISES(Invalid, batch.AddColumn(2, field1, array1)); | |
160 | ASSERT_RAISES(Invalid, batch.AddColumn(-1, field1, array1)); | |
161 | ||
162 | // Negative test with wrong length | |
163 | auto longer_col = MakeRandomArray<Int32Array>(length + 1); | |
164 | ASSERT_RAISES(Invalid, batch.AddColumn(0, field1, longer_col)); | |
165 | ||
166 | // Negative test with mismatch type | |
167 | ASSERT_RAISES(TypeError, batch.AddColumn(0, field1, array2)); | |
168 | ||
169 | ASSERT_OK_AND_ASSIGN(auto new_batch, batch.AddColumn(0, field1, array1)); | |
170 | AssertBatchesEqual(*new_batch, *batch1); | |
171 | ||
172 | ASSERT_OK_AND_ASSIGN(new_batch, batch.AddColumn(1, field3, array3)); | |
173 | AssertBatchesEqual(*new_batch, *batch2); | |
174 | ||
175 | ASSERT_OK_AND_ASSIGN(auto new_batch2, batch.AddColumn(1, "f3", array3)); | |
176 | AssertBatchesEqual(*new_batch2, *new_batch); | |
177 | ||
178 | ASSERT_TRUE(new_batch2->schema()->field(1)->nullable()); | |
179 | } | |
180 | ||
181 | TEST_F(TestRecordBatch, SetColumn) { | |
182 | const int length = 10; | |
183 | ||
184 | auto field1 = field("f1", int32()); | |
185 | auto field2 = field("f2", uint8()); | |
186 | auto field3 = field("f3", int16()); | |
187 | ||
188 | auto schema1 = ::arrow::schema({field1, field2}); | |
189 | auto schema2 = ::arrow::schema({field1, field3}); | |
190 | auto schema3 = ::arrow::schema({field3, field2}); | |
191 | ||
192 | auto array1 = MakeRandomArray<Int32Array>(length); | |
193 | auto array2 = MakeRandomArray<UInt8Array>(length); | |
194 | auto array3 = MakeRandomArray<Int16Array>(length); | |
195 | ||
196 | auto batch1 = RecordBatch::Make(schema1, length, {array1, array2}); | |
197 | auto batch2 = RecordBatch::Make(schema2, length, {array1, array3}); | |
198 | auto batch3 = RecordBatch::Make(schema3, length, {array3, array2}); | |
199 | ||
200 | const RecordBatch& batch = *batch1; | |
201 | ||
202 | // Negative tests with invalid index | |
203 | ASSERT_RAISES(Invalid, batch.SetColumn(5, field1, array1)); | |
204 | ASSERT_RAISES(Invalid, batch.SetColumn(-1, field1, array1)); | |
205 | ||
206 | // Negative test with wrong length | |
207 | auto longer_col = MakeRandomArray<Int32Array>(length + 1); | |
208 | ASSERT_RAISES(Invalid, batch.SetColumn(0, field1, longer_col)); | |
209 | ||
210 | // Negative test with mismatch type | |
211 | ASSERT_RAISES(TypeError, batch.SetColumn(0, field1, array2)); | |
212 | ||
213 | ASSERT_OK_AND_ASSIGN(auto new_batch, batch.SetColumn(1, field3, array3)); | |
214 | AssertBatchesEqual(*new_batch, *batch2); | |
215 | ||
216 | ASSERT_OK_AND_ASSIGN(new_batch, batch.SetColumn(0, field3, array3)); | |
217 | AssertBatchesEqual(*new_batch, *batch3); | |
218 | } | |
219 | ||
220 | TEST_F(TestRecordBatch, RemoveColumn) { | |
221 | const int length = 10; | |
222 | ||
223 | auto field1 = field("f1", int32()); | |
224 | auto field2 = field("f2", uint8()); | |
225 | auto field3 = field("f3", int16()); | |
226 | ||
227 | auto schema1 = ::arrow::schema({field1, field2, field3}); | |
228 | auto schema2 = ::arrow::schema({field2, field3}); | |
229 | auto schema3 = ::arrow::schema({field1, field3}); | |
230 | auto schema4 = ::arrow::schema({field1, field2}); | |
231 | ||
232 | auto array1 = MakeRandomArray<Int32Array>(length); | |
233 | auto array2 = MakeRandomArray<UInt8Array>(length); | |
234 | auto array3 = MakeRandomArray<Int16Array>(length); | |
235 | ||
236 | auto batch1 = RecordBatch::Make(schema1, length, {array1, array2, array3}); | |
237 | auto batch2 = RecordBatch::Make(schema2, length, {array2, array3}); | |
238 | auto batch3 = RecordBatch::Make(schema3, length, {array1, array3}); | |
239 | auto batch4 = RecordBatch::Make(schema4, length, {array1, array2}); | |
240 | ||
241 | const RecordBatch& batch = *batch1; | |
242 | std::shared_ptr<RecordBatch> result; | |
243 | ||
244 | // Negative tests with invalid index | |
245 | ASSERT_RAISES(Invalid, batch.RemoveColumn(3)); | |
246 | ASSERT_RAISES(Invalid, batch.RemoveColumn(-1)); | |
247 | ||
248 | ASSERT_OK_AND_ASSIGN(auto new_batch, batch.RemoveColumn(0)); | |
249 | AssertBatchesEqual(*new_batch, *batch2); | |
250 | ||
251 | ASSERT_OK_AND_ASSIGN(new_batch, batch.RemoveColumn(1)); | |
252 | AssertBatchesEqual(*new_batch, *batch3); | |
253 | ||
254 | ASSERT_OK_AND_ASSIGN(new_batch, batch.RemoveColumn(2)); | |
255 | AssertBatchesEqual(*new_batch, *batch4); | |
256 | } | |
257 | ||
258 | TEST_F(TestRecordBatch, SelectColumns) { | |
259 | const int length = 10; | |
260 | ||
261 | auto field1 = field("f1", int32()); | |
262 | auto field2 = field("f2", uint8()); | |
263 | auto field3 = field("f3", int16()); | |
264 | ||
265 | auto schema1 = ::arrow::schema({field1, field2, field3}); | |
266 | ||
267 | auto array1 = MakeRandomArray<Int32Array>(length); | |
268 | auto array2 = MakeRandomArray<UInt8Array>(length); | |
269 | auto array3 = MakeRandomArray<Int16Array>(length); | |
270 | ||
271 | auto batch = RecordBatch::Make(schema1, length, {array1, array2, array3}); | |
272 | ||
273 | ASSERT_OK_AND_ASSIGN(auto subset, batch->SelectColumns({0, 2})); | |
274 | ASSERT_OK(subset->ValidateFull()); | |
275 | ||
276 | auto expected_schema = ::arrow::schema({schema1->field(0), schema1->field(2)}); | |
277 | auto expected = | |
278 | RecordBatch::Make(expected_schema, length, {batch->column(0), batch->column(2)}); | |
279 | ASSERT_TRUE(subset->Equals(*expected)); | |
280 | ||
281 | // Out of bounds indices | |
282 | ASSERT_RAISES(Invalid, batch->SelectColumns({0, 3})); | |
283 | ASSERT_RAISES(Invalid, batch->SelectColumns({-1})); | |
284 | } | |
285 | ||
286 | TEST_F(TestRecordBatch, RemoveColumnEmpty) { | |
287 | const int length = 10; | |
288 | ||
289 | auto field1 = field("f1", int32()); | |
290 | auto schema1 = ::arrow::schema({field1}); | |
291 | auto array1 = MakeRandomArray<Int32Array>(length); | |
292 | auto batch1 = RecordBatch::Make(schema1, length, {array1}); | |
293 | ||
294 | ASSERT_OK_AND_ASSIGN(auto empty, batch1->RemoveColumn(0)); | |
295 | ASSERT_EQ(batch1->num_rows(), empty->num_rows()); | |
296 | ||
297 | ASSERT_OK_AND_ASSIGN(auto added, empty->AddColumn(0, field1, array1)); | |
298 | AssertBatchesEqual(*added, *batch1); | |
299 | } | |
300 | ||
301 | TEST_F(TestRecordBatch, ToFromEmptyStructArray) { | |
302 | auto batch1 = | |
303 | RecordBatch::Make(::arrow::schema({}), 10, std::vector<std::shared_ptr<Array>>{}); | |
304 | ASSERT_OK_AND_ASSIGN(auto struct_array, batch1->ToStructArray()); | |
305 | ASSERT_EQ(10, struct_array->length()); | |
306 | ASSERT_OK_AND_ASSIGN(auto batch2, RecordBatch::FromStructArray(struct_array)); | |
307 | ASSERT_TRUE(batch1->Equals(*batch2)); | |
308 | } | |
309 | ||
310 | TEST_F(TestRecordBatch, FromStructArrayInvalidType) { | |
311 | ASSERT_RAISES(TypeError, RecordBatch::FromStructArray(MakeRandomArray<Int32Array>(10))); | |
312 | } | |
313 | ||
314 | TEST_F(TestRecordBatch, FromStructArrayInvalidNullCount) { | |
315 | auto struct_array = | |
316 | ArrayFromJSON(struct_({field("f1", int32())}), R"([{"f1": 1}, null])"); | |
317 | ASSERT_RAISES(Invalid, RecordBatch::FromStructArray(struct_array)); | |
318 | } | |
319 | ||
320 | } // namespace arrow |