]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/cpp/src/arrow/extension_type_test.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / extension_type_test.cc
CommitLineData
1d09f67e
TL
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#include <algorithm>
19#include <cstdint>
20#include <cstring>
21#include <memory>
22#include <sstream>
23#include <string>
24
25#include <gtest/gtest.h>
26
27#include "arrow/array/array_nested.h"
28#include "arrow/array/util.h"
29#include "arrow/extension_type.h"
30#include "arrow/io/memory.h"
31#include "arrow/ipc/options.h"
32#include "arrow/ipc/reader.h"
33#include "arrow/ipc/writer.h"
34#include "arrow/record_batch.h"
35#include "arrow/status.h"
36#include "arrow/testing/extension_type.h"
37#include "arrow/testing/gtest_util.h"
38#include "arrow/type.h"
39#include "arrow/util/key_value_metadata.h"
40#include "arrow/util/logging.h"
41
42namespace arrow {
43
44class Parametric1Array : public ExtensionArray {
45 public:
46 using ExtensionArray::ExtensionArray;
47};
48
49class Parametric2Array : public ExtensionArray {
50 public:
51 using ExtensionArray::ExtensionArray;
52};
53
54// A parametric type where the extension_name() is always the same
55class Parametric1Type : public ExtensionType {
56 public:
57 explicit Parametric1Type(int32_t parameter)
58 : ExtensionType(int32()), parameter_(parameter) {}
59
60 int32_t parameter() const { return parameter_; }
61
62 std::string extension_name() const override { return "parametric-type-1"; }
63
64 bool ExtensionEquals(const ExtensionType& other) const override {
65 const auto& other_ext = static_cast<const ExtensionType&>(other);
66 if (other_ext.extension_name() != this->extension_name()) {
67 return false;
68 }
69 return this->parameter() == static_cast<const Parametric1Type&>(other).parameter();
70 }
71
72 std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
73 return std::make_shared<Parametric1Array>(data);
74 }
75
76 Result<std::shared_ptr<DataType>> Deserialize(
77 std::shared_ptr<DataType> storage_type,
78 const std::string& serialized) const override {
79 DCHECK_EQ(4, serialized.size());
80 const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data());
81 DCHECK(storage_type->Equals(int32()));
82 return std::make_shared<Parametric1Type>(parameter);
83 }
84
85 std::string Serialize() const override {
86 std::string result(" ");
87 memcpy(&result[0], &parameter_, sizeof(int32_t));
88 return result;
89 }
90
91 private:
92 int32_t parameter_;
93};
94
95// A parametric type where the extension_name() is different for each
96// parameter, and must be separately registered
97class Parametric2Type : public ExtensionType {
98 public:
99 explicit Parametric2Type(int32_t parameter)
100 : ExtensionType(int32()), parameter_(parameter) {}
101
102 int32_t parameter() const { return parameter_; }
103
104 std::string extension_name() const override {
105 std::stringstream ss;
106 ss << "parametric-type-2<param=" << parameter_ << ">";
107 return ss.str();
108 }
109
110 bool ExtensionEquals(const ExtensionType& other) const override {
111 const auto& other_ext = static_cast<const ExtensionType&>(other);
112 if (other_ext.extension_name() != this->extension_name()) {
113 return false;
114 }
115 return this->parameter() == static_cast<const Parametric2Type&>(other).parameter();
116 }
117
118 std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
119 return std::make_shared<Parametric2Array>(data);
120 }
121
122 Result<std::shared_ptr<DataType>> Deserialize(
123 std::shared_ptr<DataType> storage_type,
124 const std::string& serialized) const override {
125 DCHECK_EQ(4, serialized.size());
126 const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data());
127 DCHECK(storage_type->Equals(int32()));
128 return std::make_shared<Parametric2Type>(parameter);
129 }
130
131 std::string Serialize() const override {
132 std::string result(" ");
133 memcpy(&result[0], &parameter_, sizeof(int32_t));
134 return result;
135 }
136
137 private:
138 int32_t parameter_;
139};
140
141// An extension type with a non-primitive storage type
142class ExtStructArray : public ExtensionArray {
143 public:
144 using ExtensionArray::ExtensionArray;
145};
146
147class ExtStructType : public ExtensionType {
148 public:
149 ExtStructType()
150 : ExtensionType(
151 struct_({::arrow::field("a", int64()), ::arrow::field("b", float64())})) {}
152
153 std::string extension_name() const override { return "ext-struct-type"; }
154
155 bool ExtensionEquals(const ExtensionType& other) const override {
156 const auto& other_ext = static_cast<const ExtensionType&>(other);
157 if (other_ext.extension_name() != this->extension_name()) {
158 return false;
159 }
160 return true;
161 }
162
163 std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override {
164 return std::make_shared<ExtStructArray>(data);
165 }
166
167 Result<std::shared_ptr<DataType>> Deserialize(
168 std::shared_ptr<DataType> storage_type,
169 const std::string& serialized) const override {
170 if (serialized != "ext-struct-type-unique-code") {
171 return Status::Invalid("Type identifier did not match");
172 }
173 return std::make_shared<ExtStructType>();
174 }
175
176 std::string Serialize() const override { return "ext-struct-type-unique-code"; }
177};
178
179class TestExtensionType : public ::testing::Test {
180 public:
181 void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared<UuidType>())); }
182
183 void TearDown() {
184 if (GetExtensionType("uuid")) {
185 ASSERT_OK(UnregisterExtensionType("uuid"));
186 }
187 }
188};
189
190TEST_F(TestExtensionType, ExtensionTypeTest) {
191 auto type_not_exist = GetExtensionType("uuid-unknown");
192 ASSERT_EQ(type_not_exist, nullptr);
193
194 auto registered_type = GetExtensionType("uuid");
195 ASSERT_NE(registered_type, nullptr);
196
197 auto type = uuid();
198 ASSERT_EQ(type->id(), Type::EXTENSION);
199
200 const auto& ext_type = static_cast<const ExtensionType&>(*type);
201 std::string serialized = ext_type.Serialize();
202
203 ASSERT_OK_AND_ASSIGN(auto deserialized,
204 ext_type.Deserialize(fixed_size_binary(16), serialized));
205 ASSERT_TRUE(deserialized->Equals(*type));
206 ASSERT_FALSE(deserialized->Equals(*fixed_size_binary(16)));
207}
208
209auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch,
210 std::shared_ptr<RecordBatch>* out) {
211 ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
212 ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
213 out_stream.get()));
214
215 ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());
216
217 io::BufferReader reader(complete_ipc_stream);
218 std::shared_ptr<RecordBatchReader> batch_reader;
219 ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
220 ASSERT_OK(batch_reader->ReadNext(out));
221};
222
223TEST_F(TestExtensionType, IpcRoundtrip) {
224 auto ext_arr = ExampleUuid();
225 auto batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {ext_arr});
226
227 std::shared_ptr<RecordBatch> read_batch;
228 RoundtripBatch(batch, &read_batch);
229 CompareBatch(*batch, *read_batch, false /* compare_metadata */);
230
231 // Wrap type in a ListArray and ensure it also makes it
232 auto offsets_arr = ArrayFromJSON(int32(), "[0, 0, 2, 4]");
233 ASSERT_OK_AND_ASSIGN(auto list_arr, ListArray::FromArrays(*offsets_arr, *ext_arr));
234 batch = RecordBatch::Make(schema({field("f0", list(uuid()))}), 3, {list_arr});
235 RoundtripBatch(batch, &read_batch);
236 CompareBatch(*batch, *read_batch, false /* compare_metadata */);
237}
238
239TEST_F(TestExtensionType, UnrecognizedExtension) {
240 auto ext_arr = ExampleUuid();
241 auto batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {ext_arr});
242
243 auto storage_arr = static_cast<const ExtensionArray&>(*ext_arr).storage();
244
245 // Write full IPC stream including schema, then unregister type, then read
246 // and ensure that a plain instance of the storage type is created
247 ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
248 ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
249 out_stream.get()));
250
251 ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());
252
253 ASSERT_OK(UnregisterExtensionType("uuid"));
254 auto ext_metadata =
255 key_value_metadata({{"ARROW:extension:name", "uuid"},
256 {"ARROW:extension:metadata", "uuid-serialized"}});
257 auto ext_field = field("f0", fixed_size_binary(16), true, ext_metadata);
258 auto batch_no_ext = RecordBatch::Make(schema({ext_field}), 4, {storage_arr});
259
260 io::BufferReader reader(complete_ipc_stream);
261 std::shared_ptr<RecordBatchReader> batch_reader;
262 ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
263 std::shared_ptr<RecordBatch> read_batch;
264 ASSERT_OK(batch_reader->ReadNext(&read_batch));
265 CompareBatch(*batch_no_ext, *read_batch);
266}
267
268std::shared_ptr<Array> ExampleParametric(std::shared_ptr<DataType> type,
269 const std::string& json_data) {
270 auto arr = ArrayFromJSON(int32(), json_data);
271 auto ext_data = arr->data()->Copy();
272 ext_data->type = type;
273 return MakeArray(ext_data);
274}
275
276TEST_F(TestExtensionType, ParametricTypes) {
277 auto p1_type = std::make_shared<Parametric1Type>(6);
278 auto p1 = ExampleParametric(p1_type, "[null, 1, 2, 3]");
279
280 auto p2_type = std::make_shared<Parametric1Type>(12);
281 auto p2 = ExampleParametric(p2_type, "[2, null, 3, 4]");
282
283 auto p3_type = std::make_shared<Parametric2Type>(2);
284 auto p3 = ExampleParametric(p3_type, "[5, 6, 7, 8]");
285
286 auto p4_type = std::make_shared<Parametric2Type>(3);
287 auto p4 = ExampleParametric(p4_type, "[5, 6, 7, 9]");
288
289 ASSERT_OK(RegisterExtensionType(std::make_shared<Parametric1Type>(-1)));
290 ASSERT_OK(RegisterExtensionType(p3_type));
291 ASSERT_OK(RegisterExtensionType(p4_type));
292
293 auto batch = RecordBatch::Make(schema({field("f0", p1_type), field("f1", p2_type),
294 field("f2", p3_type), field("f3", p4_type)}),
295 4, {p1, p2, p3, p4});
296
297 std::shared_ptr<RecordBatch> read_batch;
298 RoundtripBatch(batch, &read_batch);
299 CompareBatch(*batch, *read_batch, false /* compare_metadata */);
300}
301
302TEST_F(TestExtensionType, ParametricEquals) {
303 auto p1_type = std::make_shared<Parametric1Type>(6);
304 auto p2_type = std::make_shared<Parametric1Type>(6);
305 auto p3_type = std::make_shared<Parametric1Type>(3);
306
307 ASSERT_TRUE(p1_type->Equals(p2_type));
308 ASSERT_FALSE(p1_type->Equals(p3_type));
309
310 ASSERT_EQ(p1_type->fingerprint(), "");
311}
312
313std::shared_ptr<Array> ExampleStruct() {
314 auto ext_type = std::make_shared<ExtStructType>();
315 auto storage_type = ext_type->storage_type();
316 auto arr = ArrayFromJSON(storage_type, "[[1, 0.1], [2, 0.2]]");
317
318 auto ext_data = arr->data()->Copy();
319 ext_data->type = ext_type;
320 return MakeArray(ext_data);
321}
322
323TEST_F(TestExtensionType, ValidateExtensionArray) {
324 auto ext_arr1 = ExampleUuid();
325 auto p1_type = std::make_shared<Parametric1Type>(6);
326 auto ext_arr2 = ExampleParametric(p1_type, "[null, 1, 2, 3]");
327 auto ext_arr3 = ExampleStruct();
328 auto ext_arr4 = ExampleComplex128();
329
330 ASSERT_OK(ext_arr1->ValidateFull());
331 ASSERT_OK(ext_arr2->ValidateFull());
332 ASSERT_OK(ext_arr3->ValidateFull());
333 ASSERT_OK(ext_arr4->ValidateFull());
334}
335
336} // namespace arrow