2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
9 * http://www.apache.org/licenses/LICENSE-2.0
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
18 package org
.apache
.arrow
;
20 import static org
.junit
.Assert
.assertEquals
;
21 import static org
.junit
.Assert
.assertNull
;
22 import static org
.junit
.Assert
.assertTrue
;
25 import java
.io
.FileInputStream
;
26 import java
.io
.FileOutputStream
;
27 import java
.nio
.ByteBuffer
;
28 import java
.nio
.file
.Path
;
29 import java
.nio
.file
.Paths
;
30 import java
.util
.ArrayList
;
31 import java
.util
.List
;
33 import org
.apache
.arrow
.memory
.BufferAllocator
;
34 import org
.apache
.arrow
.memory
.RootAllocator
;
35 import org
.apache
.arrow
.vector
.FieldVector
;
36 import org
.apache
.arrow
.vector
.VectorSchemaRoot
;
37 import org
.apache
.arrow
.vector
.complex
.ListVector
;
38 import org
.apache
.arrow
.vector
.complex
.StructVector
;
39 import org
.apache
.arrow
.vector
.util
.Text
;
40 import org
.apache
.avro
.Schema
;
41 import org
.apache
.avro
.generic
.GenericDatumWriter
;
42 import org
.apache
.avro
.generic
.GenericRecord
;
43 import org
.apache
.avro
.io
.BinaryDecoder
;
44 import org
.apache
.avro
.io
.BinaryEncoder
;
45 import org
.apache
.avro
.io
.DatumWriter
;
46 import org
.apache
.avro
.io
.DecoderFactory
;
47 import org
.apache
.avro
.io
.EncoderFactory
;
48 import org
.junit
.Before
;
49 import org
.junit
.ClassRule
;
50 import org
.junit
.rules
.TemporaryFolder
;
52 public class AvroTestBase
{
55 public static final TemporaryFolder TMP
= new TemporaryFolder();
57 protected AvroToArrowConfig config
;
61 BufferAllocator allocator
= new RootAllocator(Long
.MAX_VALUE
);
62 config
= new AvroToArrowConfigBuilder(allocator
).build();
65 protected Schema
getSchema(String schemaName
) throws Exception
{
66 Path schemaPath
= Paths
.get(TestWriteReadAvroRecord
.class.getResource("/").getPath(),
67 "schema", schemaName
);
68 return new Schema
.Parser().parse(schemaPath
.toFile());
71 protected VectorSchemaRoot
writeAndRead(Schema schema
, List data
) throws Exception
{
72 File dataFile
= TMP
.newFile();
75 encoder
= new EncoderFactory().directBinaryEncoder(new FileOutputStream(dataFile
), null);
76 DatumWriter writer
= new GenericDatumWriter(schema
);
78 decoder
= new DecoderFactory().directBinaryDecoder(new FileInputStream(dataFile
), null);
80 for (Object value
: data
) {
81 writer
.write(value
, encoder
);
84 return AvroToArrow
.avroToArrow(schema
, decoder
, config
);
87 protected void checkArrayResult(List
<List
<?
>> expected
, ListVector vector
) {
88 assertEquals(expected
.size(), vector
.getValueCount());
89 for (int i
= 0; i
< expected
.size(); i
++) {
90 checkArrayElement(expected
.get(i
), vector
.getObject(i
));
94 protected void checkArrayElement(List expected
, List actual
) {
95 assertEquals(expected
.size(), actual
.size());
96 for (int i
= 0; i
< expected
.size(); i
++) {
97 Object value1
= expected
.get(i
);
98 Object value2
= actual
.get(i
);
100 assertTrue(value2
== null);
103 if (value2
instanceof byte[]) {
104 value2
= ByteBuffer
.wrap((byte[]) value2
);
105 } else if (value2
instanceof Text
) {
106 value2
= value2
.toString();
108 assertEquals(value1
, value2
);
112 protected void checkPrimitiveResult(List data
, FieldVector vector
) {
113 assertEquals(data
.size(), vector
.getValueCount());
114 for (int i
= 0; i
< data
.size(); i
++) {
115 Object value1
= data
.get(i
);
116 Object value2
= vector
.getObject(i
);
117 if (value1
== null) {
118 assertTrue(value2
== null);
121 if (value2
instanceof byte[]) {
122 value2
= ByteBuffer
.wrap((byte[]) value2
);
123 if (value1
instanceof byte[]) {
124 value1
= ByteBuffer
.wrap((byte[]) value1
);
126 } else if (value2
instanceof Text
) {
127 value2
= value2
.toString();
128 } else if (value2
instanceof Byte
) {
129 value2
= ((Byte
) value2
).intValue();
131 assertEquals(value1
, value2
);
135 protected void checkRecordResult(Schema schema
, ArrayList
<GenericRecord
> data
, VectorSchemaRoot root
) {
136 assertEquals(data
.size(), root
.getRowCount());
137 assertEquals(schema
.getFields().size(), root
.getFieldVectors().size());
139 for (int i
= 0; i
< schema
.getFields().size(); i
++) {
140 ArrayList fieldData
= new ArrayList();
141 for (GenericRecord record
: data
) {
142 fieldData
.add(record
.get(i
));
145 checkPrimitiveResult(fieldData
, root
.getFieldVectors().get(i
));
150 protected void checkNestedRecordResult(Schema schema
, List
<GenericRecord
> data
, VectorSchemaRoot root
) {
151 assertEquals(data
.size(), root
.getRowCount());
152 assertTrue(schema
.getFields().size() == 1);
154 final Schema nestedSchema
= schema
.getFields().get(0).schema();
155 final StructVector structVector
= (StructVector
) root
.getFieldVectors().get(0);
157 for (int i
= 0; i
< nestedSchema
.getFields().size(); i
++) {
158 ArrayList fieldData
= new ArrayList();
159 for (GenericRecord record
: data
) {
160 GenericRecord nestedRecord
= (GenericRecord
) record
.get(0);
161 fieldData
.add(nestedRecord
.get(i
));
164 checkPrimitiveResult(fieldData
, structVector
.getChildrenFromFields().get(i
));
// The helpers below are for the iterator API.
172 protected void checkArrayResult(List
<List
<?
>> expected
, List
<ListVector
> vectors
) {
173 int valueCount
= vectors
.stream().mapToInt(v
-> v
.getValueCount()).sum();
174 assertEquals(expected
.size(), valueCount
);
177 for (ListVector vector
: vectors
) {
178 for (int i
= 0; i
< vector
.getValueCount(); i
++) {
179 checkArrayElement(expected
.get(index
++), vector
.getObject(i
));
184 protected void checkRecordResult(Schema schema
, ArrayList
<GenericRecord
> data
, List
<VectorSchemaRoot
> roots
) {
185 roots
.forEach(root
-> {
186 assertEquals(schema
.getFields().size(), root
.getFieldVectors().size());
189 for (int i
= 0; i
< schema
.getFields().size(); i
++) {
190 List fieldData
= new ArrayList();
191 List
<FieldVector
> vectors
= new ArrayList
<>();
192 for (GenericRecord record
: data
) {
193 fieldData
.add(record
.get(i
));
195 final int columnIndex
= i
;
196 roots
.forEach(root
-> vectors
.add(root
.getFieldVectors().get(columnIndex
)));
198 checkPrimitiveResult(fieldData
, vectors
);
203 protected void checkPrimitiveResult(List data
, List
<FieldVector
> vectors
) {
204 int valueCount
= vectors
.stream().mapToInt(v
-> v
.getValueCount()).sum();
205 assertEquals(data
.size(), valueCount
);
208 for (FieldVector vector
: vectors
) {
209 for (int i
= 0; i
< vector
.getValueCount(); i
++) {
210 Object value1
= data
.get(index
++);
211 Object value2
= vector
.getObject(i
);
212 if (value1
== null) {
216 if (value2
instanceof byte[]) {
217 value2
= ByteBuffer
.wrap((byte[]) value2
);
218 if (value1
instanceof byte[]) {
219 value1
= ByteBuffer
.wrap((byte[]) value1
);
221 } else if (value2
instanceof Text
) {
222 value2
= value2
.toString();
224 assertEquals(value1
, value2
);