]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include <algorithm> | |
19 | #include <cstdint> | |
20 | #include <cstring> | |
21 | #include <memory> | |
22 | #include <sstream> | |
23 | #include <string> | |
24 | ||
25 | #include <gtest/gtest.h> | |
26 | ||
27 | #include "arrow/array/array_nested.h" | |
28 | #include "arrow/array/util.h" | |
29 | #include "arrow/extension_type.h" | |
30 | #include "arrow/io/memory.h" | |
31 | #include "arrow/ipc/options.h" | |
32 | #include "arrow/ipc/reader.h" | |
33 | #include "arrow/ipc/writer.h" | |
34 | #include "arrow/record_batch.h" | |
35 | #include "arrow/status.h" | |
36 | #include "arrow/testing/extension_type.h" | |
37 | #include "arrow/testing/gtest_util.h" | |
38 | #include "arrow/type.h" | |
39 | #include "arrow/util/key_value_metadata.h" | |
40 | #include "arrow/util/logging.h" | |
41 | ||
42 | namespace arrow { | |
43 | ||
44 | class Parametric1Array : public ExtensionArray { | |
45 | public: | |
46 | using ExtensionArray::ExtensionArray; | |
47 | }; | |
48 | ||
49 | class Parametric2Array : public ExtensionArray { | |
50 | public: | |
51 | using ExtensionArray::ExtensionArray; | |
52 | }; | |
53 | ||
54 | // A parametric type where the extension_name() is always the same | |
55 | class Parametric1Type : public ExtensionType { | |
56 | public: | |
57 | explicit Parametric1Type(int32_t parameter) | |
58 | : ExtensionType(int32()), parameter_(parameter) {} | |
59 | ||
60 | int32_t parameter() const { return parameter_; } | |
61 | ||
62 | std::string extension_name() const override { return "parametric-type-1"; } | |
63 | ||
64 | bool ExtensionEquals(const ExtensionType& other) const override { | |
65 | const auto& other_ext = static_cast<const ExtensionType&>(other); | |
66 | if (other_ext.extension_name() != this->extension_name()) { | |
67 | return false; | |
68 | } | |
69 | return this->parameter() == static_cast<const Parametric1Type&>(other).parameter(); | |
70 | } | |
71 | ||
72 | std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override { | |
73 | return std::make_shared<Parametric1Array>(data); | |
74 | } | |
75 | ||
76 | Result<std::shared_ptr<DataType>> Deserialize( | |
77 | std::shared_ptr<DataType> storage_type, | |
78 | const std::string& serialized) const override { | |
79 | DCHECK_EQ(4, serialized.size()); | |
80 | const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data()); | |
81 | DCHECK(storage_type->Equals(int32())); | |
82 | return std::make_shared<Parametric1Type>(parameter); | |
83 | } | |
84 | ||
85 | std::string Serialize() const override { | |
86 | std::string result(" "); | |
87 | memcpy(&result[0], ¶meter_, sizeof(int32_t)); | |
88 | return result; | |
89 | } | |
90 | ||
91 | private: | |
92 | int32_t parameter_; | |
93 | }; | |
94 | ||
95 | // A parametric type where the extension_name() is different for each | |
96 | // parameter, and must be separately registered | |
97 | class Parametric2Type : public ExtensionType { | |
98 | public: | |
99 | explicit Parametric2Type(int32_t parameter) | |
100 | : ExtensionType(int32()), parameter_(parameter) {} | |
101 | ||
102 | int32_t parameter() const { return parameter_; } | |
103 | ||
104 | std::string extension_name() const override { | |
105 | std::stringstream ss; | |
106 | ss << "parametric-type-2<param=" << parameter_ << ">"; | |
107 | return ss.str(); | |
108 | } | |
109 | ||
110 | bool ExtensionEquals(const ExtensionType& other) const override { | |
111 | const auto& other_ext = static_cast<const ExtensionType&>(other); | |
112 | if (other_ext.extension_name() != this->extension_name()) { | |
113 | return false; | |
114 | } | |
115 | return this->parameter() == static_cast<const Parametric2Type&>(other).parameter(); | |
116 | } | |
117 | ||
118 | std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override { | |
119 | return std::make_shared<Parametric2Array>(data); | |
120 | } | |
121 | ||
122 | Result<std::shared_ptr<DataType>> Deserialize( | |
123 | std::shared_ptr<DataType> storage_type, | |
124 | const std::string& serialized) const override { | |
125 | DCHECK_EQ(4, serialized.size()); | |
126 | const int32_t parameter = *reinterpret_cast<const int32_t*>(serialized.data()); | |
127 | DCHECK(storage_type->Equals(int32())); | |
128 | return std::make_shared<Parametric2Type>(parameter); | |
129 | } | |
130 | ||
131 | std::string Serialize() const override { | |
132 | std::string result(" "); | |
133 | memcpy(&result[0], ¶meter_, sizeof(int32_t)); | |
134 | return result; | |
135 | } | |
136 | ||
137 | private: | |
138 | int32_t parameter_; | |
139 | }; | |
140 | ||
141 | // An extension type with a non-primitive storage type | |
142 | class ExtStructArray : public ExtensionArray { | |
143 | public: | |
144 | using ExtensionArray::ExtensionArray; | |
145 | }; | |
146 | ||
147 | class ExtStructType : public ExtensionType { | |
148 | public: | |
149 | ExtStructType() | |
150 | : ExtensionType( | |
151 | struct_({::arrow::field("a", int64()), ::arrow::field("b", float64())})) {} | |
152 | ||
153 | std::string extension_name() const override { return "ext-struct-type"; } | |
154 | ||
155 | bool ExtensionEquals(const ExtensionType& other) const override { | |
156 | const auto& other_ext = static_cast<const ExtensionType&>(other); | |
157 | if (other_ext.extension_name() != this->extension_name()) { | |
158 | return false; | |
159 | } | |
160 | return true; | |
161 | } | |
162 | ||
163 | std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override { | |
164 | return std::make_shared<ExtStructArray>(data); | |
165 | } | |
166 | ||
167 | Result<std::shared_ptr<DataType>> Deserialize( | |
168 | std::shared_ptr<DataType> storage_type, | |
169 | const std::string& serialized) const override { | |
170 | if (serialized != "ext-struct-type-unique-code") { | |
171 | return Status::Invalid("Type identifier did not match"); | |
172 | } | |
173 | return std::make_shared<ExtStructType>(); | |
174 | } | |
175 | ||
176 | std::string Serialize() const override { return "ext-struct-type-unique-code"; } | |
177 | }; | |
178 | ||
179 | class TestExtensionType : public ::testing::Test { | |
180 | public: | |
181 | void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared<UuidType>())); } | |
182 | ||
183 | void TearDown() { | |
184 | if (GetExtensionType("uuid")) { | |
185 | ASSERT_OK(UnregisterExtensionType("uuid")); | |
186 | } | |
187 | } | |
188 | }; | |
189 | ||
190 | TEST_F(TestExtensionType, ExtensionTypeTest) { | |
191 | auto type_not_exist = GetExtensionType("uuid-unknown"); | |
192 | ASSERT_EQ(type_not_exist, nullptr); | |
193 | ||
194 | auto registered_type = GetExtensionType("uuid"); | |
195 | ASSERT_NE(registered_type, nullptr); | |
196 | ||
197 | auto type = uuid(); | |
198 | ASSERT_EQ(type->id(), Type::EXTENSION); | |
199 | ||
200 | const auto& ext_type = static_cast<const ExtensionType&>(*type); | |
201 | std::string serialized = ext_type.Serialize(); | |
202 | ||
203 | ASSERT_OK_AND_ASSIGN(auto deserialized, | |
204 | ext_type.Deserialize(fixed_size_binary(16), serialized)); | |
205 | ASSERT_TRUE(deserialized->Equals(*type)); | |
206 | ASSERT_FALSE(deserialized->Equals(*fixed_size_binary(16))); | |
207 | } | |
208 | ||
209 | auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch, | |
210 | std::shared_ptr<RecordBatch>* out) { | |
211 | ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); | |
212 | ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), | |
213 | out_stream.get())); | |
214 | ||
215 | ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); | |
216 | ||
217 | io::BufferReader reader(complete_ipc_stream); | |
218 | std::shared_ptr<RecordBatchReader> batch_reader; | |
219 | ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); | |
220 | ASSERT_OK(batch_reader->ReadNext(out)); | |
221 | }; | |
222 | ||
223 | TEST_F(TestExtensionType, IpcRoundtrip) { | |
224 | auto ext_arr = ExampleUuid(); | |
225 | auto batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {ext_arr}); | |
226 | ||
227 | std::shared_ptr<RecordBatch> read_batch; | |
228 | RoundtripBatch(batch, &read_batch); | |
229 | CompareBatch(*batch, *read_batch, false /* compare_metadata */); | |
230 | ||
231 | // Wrap type in a ListArray and ensure it also makes it | |
232 | auto offsets_arr = ArrayFromJSON(int32(), "[0, 0, 2, 4]"); | |
233 | ASSERT_OK_AND_ASSIGN(auto list_arr, ListArray::FromArrays(*offsets_arr, *ext_arr)); | |
234 | batch = RecordBatch::Make(schema({field("f0", list(uuid()))}), 3, {list_arr}); | |
235 | RoundtripBatch(batch, &read_batch); | |
236 | CompareBatch(*batch, *read_batch, false /* compare_metadata */); | |
237 | } | |
238 | ||
239 | TEST_F(TestExtensionType, UnrecognizedExtension) { | |
240 | auto ext_arr = ExampleUuid(); | |
241 | auto batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {ext_arr}); | |
242 | ||
243 | auto storage_arr = static_cast<const ExtensionArray&>(*ext_arr).storage(); | |
244 | ||
245 | // Write full IPC stream including schema, then unregister type, then read | |
246 | // and ensure that a plain instance of the storage type is created | |
247 | ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); | |
248 | ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), | |
249 | out_stream.get())); | |
250 | ||
251 | ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); | |
252 | ||
253 | ASSERT_OK(UnregisterExtensionType("uuid")); | |
254 | auto ext_metadata = | |
255 | key_value_metadata({{"ARROW:extension:name", "uuid"}, | |
256 | {"ARROW:extension:metadata", "uuid-serialized"}}); | |
257 | auto ext_field = field("f0", fixed_size_binary(16), true, ext_metadata); | |
258 | auto batch_no_ext = RecordBatch::Make(schema({ext_field}), 4, {storage_arr}); | |
259 | ||
260 | io::BufferReader reader(complete_ipc_stream); | |
261 | std::shared_ptr<RecordBatchReader> batch_reader; | |
262 | ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); | |
263 | std::shared_ptr<RecordBatch> read_batch; | |
264 | ASSERT_OK(batch_reader->ReadNext(&read_batch)); | |
265 | CompareBatch(*batch_no_ext, *read_batch); | |
266 | } | |
267 | ||
268 | std::shared_ptr<Array> ExampleParametric(std::shared_ptr<DataType> type, | |
269 | const std::string& json_data) { | |
270 | auto arr = ArrayFromJSON(int32(), json_data); | |
271 | auto ext_data = arr->data()->Copy(); | |
272 | ext_data->type = type; | |
273 | return MakeArray(ext_data); | |
274 | } | |
275 | ||
276 | TEST_F(TestExtensionType, ParametricTypes) { | |
277 | auto p1_type = std::make_shared<Parametric1Type>(6); | |
278 | auto p1 = ExampleParametric(p1_type, "[null, 1, 2, 3]"); | |
279 | ||
280 | auto p2_type = std::make_shared<Parametric1Type>(12); | |
281 | auto p2 = ExampleParametric(p2_type, "[2, null, 3, 4]"); | |
282 | ||
283 | auto p3_type = std::make_shared<Parametric2Type>(2); | |
284 | auto p3 = ExampleParametric(p3_type, "[5, 6, 7, 8]"); | |
285 | ||
286 | auto p4_type = std::make_shared<Parametric2Type>(3); | |
287 | auto p4 = ExampleParametric(p4_type, "[5, 6, 7, 9]"); | |
288 | ||
289 | ASSERT_OK(RegisterExtensionType(std::make_shared<Parametric1Type>(-1))); | |
290 | ASSERT_OK(RegisterExtensionType(p3_type)); | |
291 | ASSERT_OK(RegisterExtensionType(p4_type)); | |
292 | ||
293 | auto batch = RecordBatch::Make(schema({field("f0", p1_type), field("f1", p2_type), | |
294 | field("f2", p3_type), field("f3", p4_type)}), | |
295 | 4, {p1, p2, p3, p4}); | |
296 | ||
297 | std::shared_ptr<RecordBatch> read_batch; | |
298 | RoundtripBatch(batch, &read_batch); | |
299 | CompareBatch(*batch, *read_batch, false /* compare_metadata */); | |
300 | } | |
301 | ||
302 | TEST_F(TestExtensionType, ParametricEquals) { | |
303 | auto p1_type = std::make_shared<Parametric1Type>(6); | |
304 | auto p2_type = std::make_shared<Parametric1Type>(6); | |
305 | auto p3_type = std::make_shared<Parametric1Type>(3); | |
306 | ||
307 | ASSERT_TRUE(p1_type->Equals(p2_type)); | |
308 | ASSERT_FALSE(p1_type->Equals(p3_type)); | |
309 | ||
310 | ASSERT_EQ(p1_type->fingerprint(), ""); | |
311 | } | |
312 | ||
313 | std::shared_ptr<Array> ExampleStruct() { | |
314 | auto ext_type = std::make_shared<ExtStructType>(); | |
315 | auto storage_type = ext_type->storage_type(); | |
316 | auto arr = ArrayFromJSON(storage_type, "[[1, 0.1], [2, 0.2]]"); | |
317 | ||
318 | auto ext_data = arr->data()->Copy(); | |
319 | ext_data->type = ext_type; | |
320 | return MakeArray(ext_data); | |
321 | } | |
322 | ||
323 | TEST_F(TestExtensionType, ValidateExtensionArray) { | |
324 | auto ext_arr1 = ExampleUuid(); | |
325 | auto p1_type = std::make_shared<Parametric1Type>(6); | |
326 | auto ext_arr2 = ExampleParametric(p1_type, "[null, 1, 2, 3]"); | |
327 | auto ext_arr3 = ExampleStruct(); | |
328 | auto ext_arr4 = ExampleComplex128(); | |
329 | ||
330 | ASSERT_OK(ext_arr1->ValidateFull()); | |
331 | ASSERT_OK(ext_arr2->ValidateFull()); | |
332 | ASSERT_OK(ext_arr3->ValidateFull()); | |
333 | ASSERT_OK(ext_arr4->ValidateFull()); | |
334 | } | |
335 | ||
336 | } // namespace arrow |