]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/arrow/java/adapter/avro/src/test/java/org/apache/arrow/AvroTestBase.java
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / java / adapter / avro / src / test / java / org / apache / arrow / AvroTestBase.java
diff --git a/ceph/src/arrow/java/adapter/avro/src/test/java/org/apache/arrow/AvroTestBase.java b/ceph/src/arrow/java/adapter/avro/src/test/java/org/apache/arrow/AvroTestBase.java
new file mode 100644 (file)
index 0000000..a00cd77
--- /dev/null
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.nio.ByteBuffer;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.util.Text;
+import org.apache.avro.Schema;
+import org.apache.avro.generic.GenericDatumWriter;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.avro.io.BinaryDecoder;
+import org.apache.avro.io.BinaryEncoder;
+import org.apache.avro.io.DatumWriter;
+import org.apache.avro.io.DecoderFactory;
+import org.apache.avro.io.EncoderFactory;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.rules.TemporaryFolder;
+
+public class AvroTestBase {
+
+  @ClassRule
+  public static final TemporaryFolder TMP = new TemporaryFolder();
+
+  protected AvroToArrowConfig config;
+
+  @Before
+  public void init() {
+    BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+    config = new AvroToArrowConfigBuilder(allocator).build();
+  }
+
+  protected Schema getSchema(String schemaName) throws Exception {
+    Path schemaPath = Paths.get(TestWriteReadAvroRecord.class.getResource("/").getPath(),
+        "schema", schemaName);
+    return new Schema.Parser().parse(schemaPath.toFile());
+  }
+
+  protected VectorSchemaRoot writeAndRead(Schema schema, List data) throws Exception {
+    File dataFile = TMP.newFile();
+
+    BinaryEncoder
+        encoder = new EncoderFactory().directBinaryEncoder(new FileOutputStream(dataFile), null);
+    DatumWriter writer = new GenericDatumWriter(schema);
+    BinaryDecoder
+        decoder = new DecoderFactory().directBinaryDecoder(new FileInputStream(dataFile), null);
+
+    for (Object value : data) {
+      writer.write(value, encoder);
+    }
+
+    return AvroToArrow.avroToArrow(schema, decoder, config);
+  }
+
+  protected void checkArrayResult(List<List<?>> expected, ListVector vector) {
+    assertEquals(expected.size(), vector.getValueCount());
+    for (int i = 0; i < expected.size(); i++) {
+      checkArrayElement(expected.get(i), vector.getObject(i));
+    }
+  }
+
+  protected void checkArrayElement(List expected, List actual) {
+    assertEquals(expected.size(), actual.size());
+    for (int i = 0; i < expected.size(); i++) {
+      Object value1 = expected.get(i);
+      Object value2 = actual.get(i);
+      if (value1 == null) {
+        assertTrue(value2 == null);
+        continue;
+      }
+      if (value2 instanceof byte[]) {
+        value2 = ByteBuffer.wrap((byte[]) value2);
+      } else if (value2 instanceof Text) {
+        value2 = value2.toString();
+      }
+      assertEquals(value1, value2);
+    }
+  }
+
+  protected void checkPrimitiveResult(List data, FieldVector vector) {
+    assertEquals(data.size(), vector.getValueCount());
+    for (int i = 0; i < data.size(); i++) {
+      Object value1 = data.get(i);
+      Object value2 = vector.getObject(i);
+      if (value1 == null) {
+        assertTrue(value2 == null);
+        continue;
+      }
+      if (value2 instanceof byte[]) {
+        value2 = ByteBuffer.wrap((byte[]) value2);
+        if (value1 instanceof byte[]) {
+          value1 = ByteBuffer.wrap((byte[]) value1);
+        }
+      } else if (value2 instanceof Text) {
+        value2 = value2.toString();
+      } else if (value2 instanceof Byte) {
+        value2 = ((Byte) value2).intValue();
+      }
+      assertEquals(value1, value2);
+    }
+  }
+
+  protected void checkRecordResult(Schema schema, ArrayList<GenericRecord> data, VectorSchemaRoot root) {
+    assertEquals(data.size(), root.getRowCount());
+    assertEquals(schema.getFields().size(), root.getFieldVectors().size());
+
+    for (int i = 0; i < schema.getFields().size(); i++) {
+      ArrayList fieldData = new ArrayList();
+      for (GenericRecord record : data) {
+        fieldData.add(record.get(i));
+      }
+
+      checkPrimitiveResult(fieldData, root.getFieldVectors().get(i));
+    }
+
+  }
+
+  protected void checkNestedRecordResult(Schema schema, List<GenericRecord> data, VectorSchemaRoot root) {
+    assertEquals(data.size(), root.getRowCount());
+    assertTrue(schema.getFields().size() == 1);
+
+    final Schema nestedSchema = schema.getFields().get(0).schema();
+    final StructVector structVector = (StructVector) root.getFieldVectors().get(0);
+
+    for (int i = 0; i < nestedSchema.getFields().size(); i++) {
+      ArrayList fieldData = new ArrayList();
+      for (GenericRecord record : data) {
+        GenericRecord nestedRecord = (GenericRecord) record.get(0);
+        fieldData.add(nestedRecord.get(i));
+      }
+
+      checkPrimitiveResult(fieldData, structVector.getChildrenFromFields().get(i));
+    }
+
+  }
+
+
+  // belows are for iterator api
+
+  protected void checkArrayResult(List<List<?>> expected, List<ListVector> vectors) {
+    int valueCount = vectors.stream().mapToInt(v -> v.getValueCount()).sum();
+    assertEquals(expected.size(), valueCount);
+
+    int index = 0;
+    for (ListVector vector : vectors) {
+      for (int i = 0; i < vector.getValueCount(); i++) {
+        checkArrayElement(expected.get(index++), vector.getObject(i));
+      }
+    }
+  }
+
+  protected void checkRecordResult(Schema schema, ArrayList<GenericRecord> data, List<VectorSchemaRoot> roots) {
+    roots.forEach(root -> {
+      assertEquals(schema.getFields().size(), root.getFieldVectors().size());
+    });
+
+    for (int i = 0; i < schema.getFields().size(); i++) {
+      List fieldData = new ArrayList();
+      List<FieldVector> vectors = new ArrayList<>();
+      for (GenericRecord record : data) {
+        fieldData.add(record.get(i));
+      }
+      final int columnIndex = i;
+      roots.forEach(root -> vectors.add(root.getFieldVectors().get(columnIndex)));
+
+      checkPrimitiveResult(fieldData, vectors);
+    }
+
+  }
+
+  protected void checkPrimitiveResult(List data, List<FieldVector> vectors) {
+    int valueCount = vectors.stream().mapToInt(v -> v.getValueCount()).sum();
+    assertEquals(data.size(), valueCount);
+
+    int index = 0;
+    for (FieldVector vector : vectors) {
+      for (int i = 0; i < vector.getValueCount(); i++) {
+        Object value1 = data.get(index++);
+        Object value2 = vector.getObject(i);
+        if (value1 == null) {
+          assertNull(value2);
+          continue;
+        }
+        if (value2 instanceof byte[]) {
+          value2 = ByteBuffer.wrap((byte[]) value2);
+          if (value1 instanceof byte[]) {
+            value1 = ByteBuffer.wrap((byte[]) value1);
+          }
+        } else if (value2 instanceof Text) {
+          value2 = value2.toString();
+        }
+        assertEquals(value1, value2);
+      }
+    }
+  }
+}