]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/java/adapter/avro/src/test/java/org/apache/arrow/AvroTestBase.java
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / java / adapter / avro / src / test / java / org / apache / arrow / AvroTestBase.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.arrow;
19
20 import static org.junit.Assert.assertEquals;
21 import static org.junit.Assert.assertNull;
22 import static org.junit.Assert.assertTrue;
23
24 import java.io.File;
25 import java.io.FileInputStream;
26 import java.io.FileOutputStream;
27 import java.nio.ByteBuffer;
28 import java.nio.file.Path;
29 import java.nio.file.Paths;
30 import java.util.ArrayList;
31 import java.util.List;
32
33 import org.apache.arrow.memory.BufferAllocator;
34 import org.apache.arrow.memory.RootAllocator;
35 import org.apache.arrow.vector.FieldVector;
36 import org.apache.arrow.vector.VectorSchemaRoot;
37 import org.apache.arrow.vector.complex.ListVector;
38 import org.apache.arrow.vector.complex.StructVector;
39 import org.apache.arrow.vector.util.Text;
40 import org.apache.avro.Schema;
41 import org.apache.avro.generic.GenericDatumWriter;
42 import org.apache.avro.generic.GenericRecord;
43 import org.apache.avro.io.BinaryDecoder;
44 import org.apache.avro.io.BinaryEncoder;
45 import org.apache.avro.io.DatumWriter;
46 import org.apache.avro.io.DecoderFactory;
47 import org.apache.avro.io.EncoderFactory;
48 import org.junit.Before;
49 import org.junit.ClassRule;
50 import org.junit.rules.TemporaryFolder;
51
52 public class AvroTestBase {
53
54 @ClassRule
55 public static final TemporaryFolder TMP = new TemporaryFolder();
56
57 protected AvroToArrowConfig config;
58
59 @Before
60 public void init() {
61 BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
62 config = new AvroToArrowConfigBuilder(allocator).build();
63 }
64
65 protected Schema getSchema(String schemaName) throws Exception {
66 Path schemaPath = Paths.get(TestWriteReadAvroRecord.class.getResource("/").getPath(),
67 "schema", schemaName);
68 return new Schema.Parser().parse(schemaPath.toFile());
69 }
70
71 protected VectorSchemaRoot writeAndRead(Schema schema, List data) throws Exception {
72 File dataFile = TMP.newFile();
73
74 BinaryEncoder
75 encoder = new EncoderFactory().directBinaryEncoder(new FileOutputStream(dataFile), null);
76 DatumWriter writer = new GenericDatumWriter(schema);
77 BinaryDecoder
78 decoder = new DecoderFactory().directBinaryDecoder(new FileInputStream(dataFile), null);
79
80 for (Object value : data) {
81 writer.write(value, encoder);
82 }
83
84 return AvroToArrow.avroToArrow(schema, decoder, config);
85 }
86
87 protected void checkArrayResult(List<List<?>> expected, ListVector vector) {
88 assertEquals(expected.size(), vector.getValueCount());
89 for (int i = 0; i < expected.size(); i++) {
90 checkArrayElement(expected.get(i), vector.getObject(i));
91 }
92 }
93
94 protected void checkArrayElement(List expected, List actual) {
95 assertEquals(expected.size(), actual.size());
96 for (int i = 0; i < expected.size(); i++) {
97 Object value1 = expected.get(i);
98 Object value2 = actual.get(i);
99 if (value1 == null) {
100 assertTrue(value2 == null);
101 continue;
102 }
103 if (value2 instanceof byte[]) {
104 value2 = ByteBuffer.wrap((byte[]) value2);
105 } else if (value2 instanceof Text) {
106 value2 = value2.toString();
107 }
108 assertEquals(value1, value2);
109 }
110 }
111
112 protected void checkPrimitiveResult(List data, FieldVector vector) {
113 assertEquals(data.size(), vector.getValueCount());
114 for (int i = 0; i < data.size(); i++) {
115 Object value1 = data.get(i);
116 Object value2 = vector.getObject(i);
117 if (value1 == null) {
118 assertTrue(value2 == null);
119 continue;
120 }
121 if (value2 instanceof byte[]) {
122 value2 = ByteBuffer.wrap((byte[]) value2);
123 if (value1 instanceof byte[]) {
124 value1 = ByteBuffer.wrap((byte[]) value1);
125 }
126 } else if (value2 instanceof Text) {
127 value2 = value2.toString();
128 } else if (value2 instanceof Byte) {
129 value2 = ((Byte) value2).intValue();
130 }
131 assertEquals(value1, value2);
132 }
133 }
134
135 protected void checkRecordResult(Schema schema, ArrayList<GenericRecord> data, VectorSchemaRoot root) {
136 assertEquals(data.size(), root.getRowCount());
137 assertEquals(schema.getFields().size(), root.getFieldVectors().size());
138
139 for (int i = 0; i < schema.getFields().size(); i++) {
140 ArrayList fieldData = new ArrayList();
141 for (GenericRecord record : data) {
142 fieldData.add(record.get(i));
143 }
144
145 checkPrimitiveResult(fieldData, root.getFieldVectors().get(i));
146 }
147
148 }
149
150 protected void checkNestedRecordResult(Schema schema, List<GenericRecord> data, VectorSchemaRoot root) {
151 assertEquals(data.size(), root.getRowCount());
152 assertTrue(schema.getFields().size() == 1);
153
154 final Schema nestedSchema = schema.getFields().get(0).schema();
155 final StructVector structVector = (StructVector) root.getFieldVectors().get(0);
156
157 for (int i = 0; i < nestedSchema.getFields().size(); i++) {
158 ArrayList fieldData = new ArrayList();
159 for (GenericRecord record : data) {
160 GenericRecord nestedRecord = (GenericRecord) record.get(0);
161 fieldData.add(nestedRecord.get(i));
162 }
163
164 checkPrimitiveResult(fieldData, structVector.getChildrenFromFields().get(i));
165 }
166
167 }
168
169
170 // belows are for iterator api
171
172 protected void checkArrayResult(List<List<?>> expected, List<ListVector> vectors) {
173 int valueCount = vectors.stream().mapToInt(v -> v.getValueCount()).sum();
174 assertEquals(expected.size(), valueCount);
175
176 int index = 0;
177 for (ListVector vector : vectors) {
178 for (int i = 0; i < vector.getValueCount(); i++) {
179 checkArrayElement(expected.get(index++), vector.getObject(i));
180 }
181 }
182 }
183
184 protected void checkRecordResult(Schema schema, ArrayList<GenericRecord> data, List<VectorSchemaRoot> roots) {
185 roots.forEach(root -> {
186 assertEquals(schema.getFields().size(), root.getFieldVectors().size());
187 });
188
189 for (int i = 0; i < schema.getFields().size(); i++) {
190 List fieldData = new ArrayList();
191 List<FieldVector> vectors = new ArrayList<>();
192 for (GenericRecord record : data) {
193 fieldData.add(record.get(i));
194 }
195 final int columnIndex = i;
196 roots.forEach(root -> vectors.add(root.getFieldVectors().get(columnIndex)));
197
198 checkPrimitiveResult(fieldData, vectors);
199 }
200
201 }
202
203 protected void checkPrimitiveResult(List data, List<FieldVector> vectors) {
204 int valueCount = vectors.stream().mapToInt(v -> v.getValueCount()).sum();
205 assertEquals(data.size(), valueCount);
206
207 int index = 0;
208 for (FieldVector vector : vectors) {
209 for (int i = 0; i < vector.getValueCount(); i++) {
210 Object value1 = data.get(index++);
211 Object value2 = vector.getObject(i);
212 if (value1 == null) {
213 assertNull(value2);
214 continue;
215 }
216 if (value2 instanceof byte[]) {
217 value2 = ByteBuffer.wrap((byte[]) value2);
218 if (value1 instanceof byte[]) {
219 value1 = ByteBuffer.wrap((byte[]) value1);
220 }
221 } else if (value2 instanceof Text) {
222 value2 = value2.toString();
223 }
224 assertEquals(value1, value2);
225 }
226 }
227 }
228 }