]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/cpp/src/arrow/record_batch_test.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / record_batch_test.cc
CommitLineData
1d09f67e
TL
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
17
18#include "arrow/record_batch.h"
19
20#include <gmock/gmock.h>
21#include <gtest/gtest.h>
22
23#include <cstdint>
24#include <memory>
25#include <vector>
26
27#include "arrow/array/array_base.h"
28#include "arrow/array/data.h"
29#include "arrow/array/util.h"
30#include "arrow/chunked_array.h"
31#include "arrow/status.h"
32#include "arrow/table.h"
33#include "arrow/testing/gtest_common.h"
34#include "arrow/testing/gtest_util.h"
35#include "arrow/type.h"
36#include "arrow/util/key_value_metadata.h"
37
38namespace arrow {
39
40class TestRecordBatch : public TestBase {};
41
42TEST_F(TestRecordBatch, Equals) {
43 const int length = 10;
44
45 auto f0 = field("f0", int32());
46 auto f1 = field("f1", uint8());
47 auto f2 = field("f2", int16());
48
49 auto metadata = key_value_metadata({"foo"}, {"bar"});
50
51 std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
52 auto schema = ::arrow::schema({f0, f1, f2});
53 auto schema2 = ::arrow::schema({f0, f1});
54 auto schema3 = ::arrow::schema({f0, f1, f2}, metadata);
55
56 auto a0 = MakeRandomArray<Int32Array>(length);
57 auto a1 = MakeRandomArray<UInt8Array>(length);
58 auto a2 = MakeRandomArray<Int16Array>(length);
59
60 auto b1 = RecordBatch::Make(schema, length, {a0, a1, a2});
61 auto b2 = RecordBatch::Make(schema3, length, {a0, a1, a2});
62 auto b3 = RecordBatch::Make(schema2, length, {a0, a1});
63 auto b4 = RecordBatch::Make(schema, length, {a0, a1, a1});
64
65 ASSERT_TRUE(b1->Equals(*b1));
66 ASSERT_FALSE(b1->Equals(*b3));
67 ASSERT_FALSE(b1->Equals(*b4));
68
69 // Different metadata
70 ASSERT_TRUE(b1->Equals(*b2));
71 ASSERT_FALSE(b1->Equals(*b2, /*check_metadata=*/true));
72}
73
74TEST_F(TestRecordBatch, Validate) {
75 const int length = 10;
76
77 auto f0 = field("f0", int32());
78 auto f1 = field("f1", uint8());
79 auto f2 = field("f2", int16());
80
81 auto schema = ::arrow::schema({f0, f1, f2});
82
83 auto a0 = MakeRandomArray<Int32Array>(length);
84 auto a1 = MakeRandomArray<UInt8Array>(length);
85 auto a2 = MakeRandomArray<Int16Array>(length);
86 auto a3 = MakeRandomArray<Int16Array>(5);
87
88 auto b1 = RecordBatch::Make(schema, length, {a0, a1, a2});
89
90 ASSERT_OK(b1->ValidateFull());
91
92 // Length mismatch
93 auto b2 = RecordBatch::Make(schema, length, {a0, a1, a3});
94 ASSERT_RAISES(Invalid, b2->ValidateFull());
95
96 // Type mismatch
97 auto b3 = RecordBatch::Make(schema, length, {a0, a1, a0});
98 ASSERT_RAISES(Invalid, b3->ValidateFull());
99}
100
101TEST_F(TestRecordBatch, Slice) {
102 const int length = 7;
103
104 auto f0 = field("f0", int32());
105 auto f1 = field("f1", uint8());
106 auto f2 = field("f2", int8());
107
108 std::vector<std::shared_ptr<Field>> fields = {f0, f1, f2};
109 auto schema = ::arrow::schema(fields);
110
111 auto a0 = MakeRandomArray<Int32Array>(length);
112 auto a1 = MakeRandomArray<UInt8Array>(length);
113 auto a2 = ArrayFromJSON(int8(), "[0, 1, 2, 3, 4, 5, 6]");
114
115 auto batch = RecordBatch::Make(schema, length, {a0, a1, a2});
116
117 auto batch_slice = batch->Slice(2);
118 auto batch_slice2 = batch->Slice(1, 5);
119
120 ASSERT_EQ(batch_slice->num_rows(), batch->num_rows() - 2);
121
122 for (int i = 0; i < batch->num_columns(); ++i) {
123 ASSERT_EQ(2, batch_slice->column(i)->offset());
124 ASSERT_EQ(length - 2, batch_slice->column(i)->length());
125
126 ASSERT_EQ(1, batch_slice2->column(i)->offset());
127 ASSERT_EQ(5, batch_slice2->column(i)->length());
128 }
129
130 // ARROW-9143: RecordBatch::Slice was incorrectly setting a2's
131 // ArrayData::null_count to kUnknownNullCount
132 ASSERT_EQ(batch_slice->column(2)->data()->null_count, 0);
133 ASSERT_EQ(batch_slice2->column(2)->data()->null_count, 0);
134}
135
136TEST_F(TestRecordBatch, AddColumn) {
137 const int length = 10;
138
139 auto field1 = field("f1", int32());
140 auto field2 = field("f2", uint8());
141 auto field3 = field("f3", int16());
142
143 auto schema1 = ::arrow::schema({field1, field2});
144 auto schema2 = ::arrow::schema({field2, field3});
145 auto schema3 = ::arrow::schema({field2});
146
147 auto array1 = MakeRandomArray<Int32Array>(length);
148 auto array2 = MakeRandomArray<UInt8Array>(length);
149 auto array3 = MakeRandomArray<Int16Array>(length);
150
151 auto batch1 = RecordBatch::Make(schema1, length, {array1, array2});
152 auto batch2 = RecordBatch::Make(schema2, length, {array2, array3});
153 auto batch3 = RecordBatch::Make(schema3, length, {array2});
154
155 const RecordBatch& batch = *batch3;
156
157 // Negative tests with invalid index
158 ASSERT_RAISES(Invalid, batch.AddColumn(5, field1, array1));
159 ASSERT_RAISES(Invalid, batch.AddColumn(2, field1, array1));
160 ASSERT_RAISES(Invalid, batch.AddColumn(-1, field1, array1));
161
162 // Negative test with wrong length
163 auto longer_col = MakeRandomArray<Int32Array>(length + 1);
164 ASSERT_RAISES(Invalid, batch.AddColumn(0, field1, longer_col));
165
166 // Negative test with mismatch type
167 ASSERT_RAISES(TypeError, batch.AddColumn(0, field1, array2));
168
169 ASSERT_OK_AND_ASSIGN(auto new_batch, batch.AddColumn(0, field1, array1));
170 AssertBatchesEqual(*new_batch, *batch1);
171
172 ASSERT_OK_AND_ASSIGN(new_batch, batch.AddColumn(1, field3, array3));
173 AssertBatchesEqual(*new_batch, *batch2);
174
175 ASSERT_OK_AND_ASSIGN(auto new_batch2, batch.AddColumn(1, "f3", array3));
176 AssertBatchesEqual(*new_batch2, *new_batch);
177
178 ASSERT_TRUE(new_batch2->schema()->field(1)->nullable());
179}
180
181TEST_F(TestRecordBatch, SetColumn) {
182 const int length = 10;
183
184 auto field1 = field("f1", int32());
185 auto field2 = field("f2", uint8());
186 auto field3 = field("f3", int16());
187
188 auto schema1 = ::arrow::schema({field1, field2});
189 auto schema2 = ::arrow::schema({field1, field3});
190 auto schema3 = ::arrow::schema({field3, field2});
191
192 auto array1 = MakeRandomArray<Int32Array>(length);
193 auto array2 = MakeRandomArray<UInt8Array>(length);
194 auto array3 = MakeRandomArray<Int16Array>(length);
195
196 auto batch1 = RecordBatch::Make(schema1, length, {array1, array2});
197 auto batch2 = RecordBatch::Make(schema2, length, {array1, array3});
198 auto batch3 = RecordBatch::Make(schema3, length, {array3, array2});
199
200 const RecordBatch& batch = *batch1;
201
202 // Negative tests with invalid index
203 ASSERT_RAISES(Invalid, batch.SetColumn(5, field1, array1));
204 ASSERT_RAISES(Invalid, batch.SetColumn(-1, field1, array1));
205
206 // Negative test with wrong length
207 auto longer_col = MakeRandomArray<Int32Array>(length + 1);
208 ASSERT_RAISES(Invalid, batch.SetColumn(0, field1, longer_col));
209
210 // Negative test with mismatch type
211 ASSERT_RAISES(TypeError, batch.SetColumn(0, field1, array2));
212
213 ASSERT_OK_AND_ASSIGN(auto new_batch, batch.SetColumn(1, field3, array3));
214 AssertBatchesEqual(*new_batch, *batch2);
215
216 ASSERT_OK_AND_ASSIGN(new_batch, batch.SetColumn(0, field3, array3));
217 AssertBatchesEqual(*new_batch, *batch3);
218}
219
220TEST_F(TestRecordBatch, RemoveColumn) {
221 const int length = 10;
222
223 auto field1 = field("f1", int32());
224 auto field2 = field("f2", uint8());
225 auto field3 = field("f3", int16());
226
227 auto schema1 = ::arrow::schema({field1, field2, field3});
228 auto schema2 = ::arrow::schema({field2, field3});
229 auto schema3 = ::arrow::schema({field1, field3});
230 auto schema4 = ::arrow::schema({field1, field2});
231
232 auto array1 = MakeRandomArray<Int32Array>(length);
233 auto array2 = MakeRandomArray<UInt8Array>(length);
234 auto array3 = MakeRandomArray<Int16Array>(length);
235
236 auto batch1 = RecordBatch::Make(schema1, length, {array1, array2, array3});
237 auto batch2 = RecordBatch::Make(schema2, length, {array2, array3});
238 auto batch3 = RecordBatch::Make(schema3, length, {array1, array3});
239 auto batch4 = RecordBatch::Make(schema4, length, {array1, array2});
240
241 const RecordBatch& batch = *batch1;
242 std::shared_ptr<RecordBatch> result;
243
244 // Negative tests with invalid index
245 ASSERT_RAISES(Invalid, batch.RemoveColumn(3));
246 ASSERT_RAISES(Invalid, batch.RemoveColumn(-1));
247
248 ASSERT_OK_AND_ASSIGN(auto new_batch, batch.RemoveColumn(0));
249 AssertBatchesEqual(*new_batch, *batch2);
250
251 ASSERT_OK_AND_ASSIGN(new_batch, batch.RemoveColumn(1));
252 AssertBatchesEqual(*new_batch, *batch3);
253
254 ASSERT_OK_AND_ASSIGN(new_batch, batch.RemoveColumn(2));
255 AssertBatchesEqual(*new_batch, *batch4);
256}
257
258TEST_F(TestRecordBatch, SelectColumns) {
259 const int length = 10;
260
261 auto field1 = field("f1", int32());
262 auto field2 = field("f2", uint8());
263 auto field3 = field("f3", int16());
264
265 auto schema1 = ::arrow::schema({field1, field2, field3});
266
267 auto array1 = MakeRandomArray<Int32Array>(length);
268 auto array2 = MakeRandomArray<UInt8Array>(length);
269 auto array3 = MakeRandomArray<Int16Array>(length);
270
271 auto batch = RecordBatch::Make(schema1, length, {array1, array2, array3});
272
273 ASSERT_OK_AND_ASSIGN(auto subset, batch->SelectColumns({0, 2}));
274 ASSERT_OK(subset->ValidateFull());
275
276 auto expected_schema = ::arrow::schema({schema1->field(0), schema1->field(2)});
277 auto expected =
278 RecordBatch::Make(expected_schema, length, {batch->column(0), batch->column(2)});
279 ASSERT_TRUE(subset->Equals(*expected));
280
281 // Out of bounds indices
282 ASSERT_RAISES(Invalid, batch->SelectColumns({0, 3}));
283 ASSERT_RAISES(Invalid, batch->SelectColumns({-1}));
284}
285
286TEST_F(TestRecordBatch, RemoveColumnEmpty) {
287 const int length = 10;
288
289 auto field1 = field("f1", int32());
290 auto schema1 = ::arrow::schema({field1});
291 auto array1 = MakeRandomArray<Int32Array>(length);
292 auto batch1 = RecordBatch::Make(schema1, length, {array1});
293
294 ASSERT_OK_AND_ASSIGN(auto empty, batch1->RemoveColumn(0));
295 ASSERT_EQ(batch1->num_rows(), empty->num_rows());
296
297 ASSERT_OK_AND_ASSIGN(auto added, empty->AddColumn(0, field1, array1));
298 AssertBatchesEqual(*added, *batch1);
299}
300
301TEST_F(TestRecordBatch, ToFromEmptyStructArray) {
302 auto batch1 =
303 RecordBatch::Make(::arrow::schema({}), 10, std::vector<std::shared_ptr<Array>>{});
304 ASSERT_OK_AND_ASSIGN(auto struct_array, batch1->ToStructArray());
305 ASSERT_EQ(10, struct_array->length());
306 ASSERT_OK_AND_ASSIGN(auto batch2, RecordBatch::FromStructArray(struct_array));
307 ASSERT_TRUE(batch1->Equals(*batch2));
308}
309
310TEST_F(TestRecordBatch, FromStructArrayInvalidType) {
311 ASSERT_RAISES(TypeError, RecordBatch::FromStructArray(MakeRandomArray<Int32Array>(10)));
312}
313
314TEST_F(TestRecordBatch, FromStructArrayInvalidNullCount) {
315 auto struct_array =
316 ArrayFromJSON(struct_({field("f1", int32())}), R"([{"f1": 1}, null])");
317 ASSERT_RAISES(Invalid, RecordBatch::FromStructArray(struct_array));
318}
319
320} // namespace arrow