]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / java / flight / flight-core / src / test / java / org / apache / arrow / flight / perf / TestPerf.java
diff --git a/ceph/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java b/ceph/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
new file mode 100644 (file)
index 0000000..9e2d7cc
--- /dev/null
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.perf;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.FlightTestUtil;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.flight.perf.impl.PerfOuterClass.Perf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.Types.MinorType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Test;
+
+import com.google.common.base.MoreObjects;
+import com.google.common.base.Stopwatch;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.protobuf.ByteString;
+
+@org.junit.Ignore
+public class TestPerf {
+
+  public static final boolean VALIDATE = false;
+
+  public static FlightDescriptor getPerfFlightDescriptor(long recordCount, int recordsPerBatch, int streamCount) {
+    final Schema pojoSchema = new Schema(ImmutableList.of(
+        Field.nullable("a", MinorType.BIGINT.getType()),
+        Field.nullable("b", MinorType.BIGINT.getType()),
+        Field.nullable("c", MinorType.BIGINT.getType()),
+        Field.nullable("d", MinorType.BIGINT.getType())
+    ));
+
+    ByteString serializedSchema = ByteString.copyFrom(pojoSchema.toByteArray());
+
+    return FlightDescriptor.command(Perf.newBuilder()
+        .setRecordsPerStream(recordCount)
+        .setRecordsPerBatch(recordsPerBatch)
+        .setSchema(serializedSchema)
+        .setStreamCount(streamCount)
+        .build()
+        .toByteArray());
+  }
+
+  public static void main(String[] args) throws Exception {
+    new TestPerf().throughput();
+  }
+
+  @Test
+  public void throughput() throws Exception {
+    final int numRuns = 10;
+    ListeningExecutorService pool = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(4));
+    double [] throughPuts = new double[numRuns];
+
+    for (int i = 0; i < numRuns; i++) {
+      try (
+          final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+          final PerformanceTestServer server =
+              FlightTestUtil.getStartedServer((location) -> new PerformanceTestServer(a, location));
+          final FlightClient client = FlightClient.builder(a, server.getLocation()).build();
+      ) {
+        final FlightInfo info = client.getInfo(getPerfFlightDescriptor(50_000_000L, 4095, 2));
+        List<ListenableFuture<Result>> results = info.getEndpoints()
+            .stream()
+            .map(t -> new Consumer(client, t.getTicket()))
+            .map(t -> pool.submit(t))
+            .collect(Collectors.toList());
+
+        final Result r = Futures.whenAllSucceed(results).call(() -> {
+          Result res = new Result();
+          for (ListenableFuture<Result> f : results) {
+            res.add(f.get());
+          }
+          return res;
+        }, pool).get();
+
+        double seconds = r.nanos * 1.0d / 1000 / 1000 / 1000;
+        throughPuts[i] = (r.bytes * 1.0d / 1024 / 1024) / seconds;
+        System.out.println(String.format(
+            "Transferred %d records totaling %s bytes at %f MiB/s. %f record/s. %f batch/s.",
+            r.rows,
+            r.bytes,
+            throughPuts[i],
+            (r.rows * 1.0d) / seconds,
+            (r.batches * 1.0d) / seconds
+        ));
+      }
+    }
+    pool.shutdown();
+
+    System.out.println("Summary: ");
+    double average = Arrays.stream(throughPuts).sum() / numRuns;
+    double sqrSum = Arrays.stream(throughPuts).map(val -> val - average).map(val -> val * val).sum();
+    double stddev = Math.sqrt(sqrSum / numRuns);
+    System.out.println(String.format("Average throughput: %f MiB/s, standard deviation: %f MiB/s",
+            average, stddev));
+  }
+
+  private final class Consumer implements Callable<Result> {
+
+    private final FlightClient client;
+    private final Ticket ticket;
+
+    public Consumer(FlightClient client, Ticket ticket) {
+      super();
+      this.client = client;
+      this.ticket = ticket;
+    }
+
+    @Override
+    public Result call() throws Exception {
+      final Result r = new Result();
+      Stopwatch watch = Stopwatch.createStarted();
+      try (final FlightStream stream = client.getStream(ticket)) {
+        final VectorSchemaRoot root = stream.getRoot();
+        try {
+          BigIntVector a = (BigIntVector) root.getVector("a");
+          while (stream.next()) {
+            int rows = root.getRowCount();
+            long aSum = r.aSum;
+            for (int i = 0; i < rows; i++) {
+              if (VALIDATE) {
+                aSum += a.get(i);
+              }
+            }
+            r.bytes += rows * 32;
+            r.rows += rows;
+            r.aSum = aSum;
+            r.batches++;
+          }
+
+          r.nanos = watch.elapsed(TimeUnit.NANOSECONDS);
+          return r;
+        } finally {
+          root.clear();
+        }
+      }
+    }
+
+  }
+
+  private final class Result {
+    private long rows;
+    private long aSum;
+    private long bytes;
+    private long nanos;
+    private long batches;
+
+    public void add(Result r) {
+      rows += r.rows;
+      aSum += r.aSum;
+      bytes += r.bytes;
+      batches += r.batches;
+      nanos = Math.max(nanos, r.nanos);
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("rows", rows)
+          .add("aSum", aSum)
+          .add("batches", batches)
+          .add("bytes", bytes)
+          .add("nanos", nanos)
+          .toString();
+    }
+  }
+}