]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one or more | |
3 | * contributor license agreements. See the NOTICE file distributed with | |
4 | * this work for additional information regarding copyright ownership. | |
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 | |
6 | * (the "License"); you may not use this file except in compliance with | |
7 | * the License. You may obtain a copy of the License at | |
8 | * | |
9 | * http://www.apache.org/licenses/LICENSE-2.0 | |
10 | * | |
11 | * Unless required by applicable law or agreed to in writing, software | |
12 | * distributed under the License is distributed on an "AS IS" BASIS, | |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
14 | * See the License for the specific language governing permissions and | |
15 | * limitations under the License. | |
16 | */ | |
17 | ||
18 | package org.apache.arrow.flight.example.integration; | |
19 | ||
20 | import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; | |
21 | ||
22 | import java.io.File; | |
23 | import java.io.IOException; | |
24 | import java.nio.charset.StandardCharsets; | |
25 | import java.util.List; | |
26 | ||
27 | import org.apache.arrow.flight.AsyncPutListener; | |
28 | import org.apache.arrow.flight.FlightClient; | |
29 | import org.apache.arrow.flight.FlightDescriptor; | |
30 | import org.apache.arrow.flight.FlightEndpoint; | |
31 | import org.apache.arrow.flight.FlightInfo; | |
32 | import org.apache.arrow.flight.FlightStream; | |
33 | import org.apache.arrow.flight.Location; | |
34 | import org.apache.arrow.flight.PutResult; | |
35 | import org.apache.arrow.memory.ArrowBuf; | |
36 | import org.apache.arrow.memory.BufferAllocator; | |
37 | import org.apache.arrow.memory.RootAllocator; | |
38 | import org.apache.arrow.vector.VectorLoader; | |
39 | import org.apache.arrow.vector.VectorSchemaRoot; | |
40 | import org.apache.arrow.vector.VectorUnloader; | |
41 | import org.apache.arrow.vector.ipc.JsonFileReader; | |
42 | import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; | |
43 | import org.apache.arrow.vector.types.pojo.Schema; | |
44 | import org.apache.arrow.vector.util.Validator; | |
45 | import org.apache.commons.cli.CommandLine; | |
46 | import org.apache.commons.cli.CommandLineParser; | |
47 | import org.apache.commons.cli.DefaultParser; | |
48 | import org.apache.commons.cli.Options; | |
49 | import org.apache.commons.cli.ParseException; | |
50 | ||
51 | /** | |
52 | * A Flight client for integration testing. | |
53 | */ | |
54 | class IntegrationTestClient { | |
55 | private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestClient.class); | |
56 | private final Options options; | |
57 | ||
58 | private IntegrationTestClient() { | |
59 | options = new Options(); | |
60 | options.addOption("j", "json", true, "json file"); | |
61 | options.addOption("scenario", true, "The integration test scenario."); | |
62 | options.addOption("host", true, "The host to connect to."); | |
63 | options.addOption("port", true, "The port to connect to."); | |
64 | } | |
65 | ||
66 | public static void main(String[] args) { | |
67 | try { | |
68 | new IntegrationTestClient().run(args); | |
69 | } catch (ParseException e) { | |
70 | fatalError("Invalid parameters", e); | |
71 | } catch (IOException e) { | |
72 | fatalError("Error accessing files", e); | |
73 | } catch (Exception e) { | |
74 | fatalError("Unknown error", e); | |
75 | } | |
76 | } | |
77 | ||
78 | private static void fatalError(String message, Throwable e) { | |
79 | System.err.println(message); | |
80 | System.err.println(e.getMessage()); | |
81 | LOGGER.error(message, e); | |
82 | System.exit(1); | |
83 | } | |
84 | ||
85 | private void run(String[] args) throws Exception { | |
86 | final CommandLineParser parser = new DefaultParser(); | |
87 | final CommandLine cmd = parser.parse(options, args, false); | |
88 | ||
89 | final String host = cmd.getOptionValue("host", "localhost"); | |
90 | final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); | |
91 | ||
92 | final Location defaultLocation = Location.forGrpcInsecure(host, port); | |
93 | try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); | |
94 | final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) { | |
95 | ||
96 | if (cmd.hasOption("scenario")) { | |
97 | Scenarios.getScenario(cmd.getOptionValue("scenario")).client(allocator, defaultLocation, client); | |
98 | } else { | |
99 | final String inputPath = cmd.getOptionValue("j"); | |
100 | testStream(allocator, defaultLocation, client, inputPath); | |
101 | } | |
102 | } catch (InterruptedException e) { | |
103 | throw new RuntimeException(e); | |
104 | } | |
105 | } | |
106 | ||
107 | private static void testStream(BufferAllocator allocator, Location server, FlightClient client, String inputPath) | |
108 | throws IOException { | |
109 | // 1. Read data from JSON and upload to server. | |
110 | FlightDescriptor descriptor = FlightDescriptor.path(inputPath); | |
111 | try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator); | |
112 | VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { | |
113 | FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, reader, | |
114 | new AsyncPutListener() { | |
115 | int counter = 0; | |
116 | ||
117 | @Override | |
118 | public void onNext(PutResult val) { | |
119 | final byte[] metadataRaw = new byte[checkedCastToInt(val.getApplicationMetadata().readableBytes())]; | |
120 | val.getApplicationMetadata().readBytes(metadataRaw); | |
121 | final String metadata = new String(metadataRaw, StandardCharsets.UTF_8); | |
122 | if (!Integer.toString(counter).equals(metadata)) { | |
123 | throw new RuntimeException( | |
124 | String.format("Invalid ACK from server. Expected '%d' but got '%s'.", counter, metadata)); | |
125 | } | |
126 | counter++; | |
127 | } | |
128 | }); | |
129 | int counter = 0; | |
130 | while (reader.read(root)) { | |
131 | final byte[] rawMetadata = Integer.toString(counter).getBytes(StandardCharsets.UTF_8); | |
132 | final ArrowBuf metadata = allocator.buffer(rawMetadata.length); | |
133 | metadata.writeBytes(rawMetadata); | |
134 | // Transfers ownership of the buffer, so do not release it ourselves | |
135 | stream.putNext(metadata); | |
136 | root.clear(); | |
137 | counter++; | |
138 | } | |
139 | stream.completed(); | |
140 | // Need to call this, or exceptions from the server get swallowed | |
141 | stream.getResult(); | |
142 | } | |
143 | ||
144 | // 2. Get the ticket for the data. | |
145 | FlightInfo info = client.getInfo(descriptor); | |
146 | List<FlightEndpoint> endpoints = info.getEndpoints(); | |
147 | if (endpoints.isEmpty()) { | |
148 | throw new RuntimeException("No endpoints returned from Flight server."); | |
149 | } | |
150 | ||
151 | for (FlightEndpoint endpoint : info.getEndpoints()) { | |
152 | // 3. Download the data from the server. | |
153 | List<Location> locations = endpoint.getLocations(); | |
154 | if (locations.isEmpty()) { | |
155 | throw new RuntimeException("No locations returned from Flight server."); | |
156 | } | |
157 | for (Location location : locations) { | |
158 | System.out.println("Verifying location " + location.getUri()); | |
159 | try (FlightClient readClient = FlightClient.builder(allocator, location).build(); | |
160 | FlightStream stream = readClient.getStream(endpoint.getTicket()); | |
161 | VectorSchemaRoot root = stream.getRoot(); | |
162 | VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); | |
163 | JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator)) { | |
164 | VectorLoader loader = new VectorLoader(downloadedRoot); | |
165 | VectorUnloader unloader = new VectorUnloader(root); | |
166 | ||
167 | Schema jsonSchema = reader.start(); | |
168 | Validator.compareSchemas(root.getSchema(), jsonSchema); | |
169 | try (VectorSchemaRoot jsonRoot = VectorSchemaRoot.create(jsonSchema, allocator)) { | |
170 | ||
171 | while (stream.next()) { | |
172 | try (final ArrowRecordBatch arb = unloader.getRecordBatch()) { | |
173 | loader.load(arb); | |
174 | if (reader.read(jsonRoot)) { | |
175 | ||
176 | // 4. Validate the data. | |
177 | Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot); | |
178 | jsonRoot.clear(); | |
179 | } else { | |
180 | throw new RuntimeException("Flight stream has more batches than JSON"); | |
181 | } | |
182 | } | |
183 | } | |
184 | ||
185 | // Verify no more batches with data in JSON | |
186 | // NOTE: Currently the C++ Flight server skips empty batches at end of the stream | |
187 | if (reader.read(jsonRoot) && jsonRoot.getRowCount() > 0) { | |
188 | throw new RuntimeException("JSON has more batches with than Flight stream"); | |
189 | } | |
190 | } | |
191 | } catch (Exception e) { | |
192 | throw new RuntimeException(e); | |
193 | } | |
194 | } | |
195 | } | |
196 | } | |
197 | } |