1 // Licensed to the Apache Software Foundation (ASF) under one or more
2 // contributor license agreements. See the NOTICE file distributed with
3 // this work for additional information regarding copyright ownership.
4 // The ASF licenses this file to You under the Apache License, Version 2.0
5 // (the "License"); you may not use this file except in compliance with
6 // the License. You may obtain a copy of the License at
8 // http://www.apache.org/licenses/LICENSE-2.0
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
17 using System.Collections.Generic;
19 using System.Threading.Tasks;
20 using Apache.Arrow.Flight.Client;
21 using Apache.Arrow.Flight.TestWeb;
22 using Apache.Arrow.Tests;
23 using Google.Protobuf;
24 using Grpc.Core.Utils;
27 namespace Apache.Arrow.Flight.Tests
29 public class FlightTests : IDisposable
/// <summary>In-memory flight store shared with the test web host; tests seed and inspect it directly.</summary>
private readonly FlightStore _flightStore;

/// <summary>Factory hosting the test Flight server; supplies the channel and address used by the tests.</summary>
private readonly TestWebFactory _testWebFactory;

/// <summary>Flight client connected to the test server's channel.</summary>
private readonly FlightClient _flightClient;
/// <summary>
/// Creates a fresh, empty <see cref="FlightStore"/>, starts a test web host over it,
/// and connects a <see cref="FlightClient"/> to that host's channel.
/// </summary>
public FlightTests()
{
    var store = new FlightStore();
    var webFactory = new TestWebFactory(store);

    _flightStore = store;
    _testWebFactory = webFactory;
    _flightClient = new FlightClient(webFactory.GetChannel());
}
/// <summary>Disposes the test web factory created in the constructor.</summary>
public void Dispose() => _testWebFactory.Dispose();
/// <summary>
/// Builds a record batch with a single nullable Int32 column named "test" holding the
/// consecutive values <paramref name="startValue"/> .. <paramref name="startValue"/> + <paramref name="length"/> - 1.
/// </summary>
/// <param name="startValue">First value written to the column.</param>
/// <param name="length">Number of rows; zero or negative yields an empty column.</param>
/// <returns>The built record batch.</returns>
private RecordBatch CreateTestBatch(int startValue, int length)
{
    var valuesBuilder = new Int32Array.Builder();
    int offset = 0;
    while (offset < length)
    {
        valuesBuilder.Append(startValue + offset);
        offset++;
    }
    Int32Array values = valuesBuilder.Build();

    var recordBatchBuilder = new RecordBatch.Builder();
    recordBatchBuilder.Append("test", true, values);
    return recordBatchBuilder.Build();
}
/// <summary>
/// Asserts that <paramref name="flightDescriptor"/> is registered in the in-memory store
/// and returns the batches held for it.
/// </summary>
/// <param name="flightDescriptor">Descriptor of the flight to look up.</param>
/// <returns>The stored batches (with their metadata) for that flight.</returns>
private IEnumerable<RecordBatchWithMetadata> GetStoreBatch(FlightDescriptor flightDescriptor)
{
    // The explicit interface view picks the IReadOnlyDictionary overload of Assert.Contains.
    var flights = (IReadOnlyDictionary<FlightDescriptor, FlightHolder>)_flightStore.Flights;
    Assert.Contains(flightDescriptor, flights);

    FlightHolder holder = flights[flightDescriptor];
    return holder.GetRecordBatches();
}
/// <summary>
/// Seeds the in-memory flight store with <paramref name="batches"/> under
/// <paramref name="flightDescriptor"/> and returns the <see cref="FlightInfo"/> for that flight.
/// </summary>
/// <param name="flightDescriptor">Descriptor to register the flight under.</param>
/// <param name="batches">Batches to store, in order. At least one is required because the flight's schema is taken from the first batch.</param>
/// <returns>The flight info produced by the new <see cref="FlightHolder"/>.</returns>
/// <exception cref="ArgumentException">Thrown when <paramref name="batches"/> is null or empty.</exception>
private FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, params RecordBatchWithMetadata[] batches)
{
    // Fail fast with a clear message: the holder needs a schema, and it can only come from a batch.
    if (batches == null || batches.Length == 0)
    {
        throw new ArgumentException("At least one batch is required to derive the flight schema.", nameof(batches));
    }

    var schema = batches[0].RecordBatch.Schema;
    var flightHolder = new FlightHolder(flightDescriptor, schema, _testWebFactory.GetAddress());

    foreach (var batch in batches)
    {
        flightHolder.AddBatch(batch);
    }

    _flightStore.Flights.Add(flightDescriptor, flightHolder);

    return flightHolder.GetFlightInfo();
}
/// <summary>
/// Puts one record batch through the client and verifies the server acknowledges it once
/// and the store holds an identical batch.
/// </summary>
public async Task TestPutSingleRecordBatch()
{
    var descriptor = FlightDescriptor.CreatePathDescriptor("test");
    var batchToSend = CreateTestBatch(0, 100);

    var putCall = _flightClient.StartPut(descriptor);
    await putCall.RequestStream.WriteAsync(batchToSend);
    await putCall.RequestStream.CompleteAsync();
    var putResults = await putCall.ResponseStream.ToListAsync();

    // Exactly one put result is expected for the single batch written.
    Assert.Single(putResults);

    var storedBatch = Assert.Single(GetStoreBatch(descriptor));
    ArrowReaderVerifier.CompareBatches(batchToSend, storedBatch.RecordBatch);
}
/// <summary>
/// Puts two record batches through the client and verifies the server acknowledges both
/// and the store holds them in the order they were written.
/// </summary>
public async Task TestPutTwoRecordBatches()
{
    var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test");

    // Distinct value ranges let the comparisons below detect batches that were
    // reordered or duplicated; identical batches would make both checks pass regardless.
    var expectedBatch1 = CreateTestBatch(0, 100);
    var expectedBatch2 = CreateTestBatch(100, 100);

    var putStream = _flightClient.StartPut(flightDescriptor);
    await putStream.RequestStream.WriteAsync(expectedBatch1);
    await putStream.RequestStream.WriteAsync(expectedBatch2);
    await putStream.RequestStream.CompleteAsync();
    var putResults = await putStream.ResponseStream.ToListAsync();

    Assert.Equal(2, putResults.Count);

    var actualBatches = GetStoreBatch(flightDescriptor).ToList();
    Assert.Equal(2, actualBatches.Count);

    ArrowReaderVerifier.CompareBatches(expectedBatch1, actualBatches[0].RecordBatch);
    ArrowReaderVerifier.CompareBatches(expectedBatch2, actualBatches[1].RecordBatch);
}
/// <summary>
/// Seeds the store with one batch, resolves its ticket via GetInfo, and verifies that
/// DoGet streams back exactly that batch.
/// </summary>
public async Task TestGetSingleRecordBatch()
{
    var descriptor = FlightDescriptor.CreatePathDescriptor("test");
    var storedBatch = CreateTestBatch(0, 100);

    // Seed the in-memory store so the server has something to serve.
    GivenStoreBatches(descriptor, new RecordBatchWithMetadata(storedBatch));

    // Resolve the ticket through the flight info; exactly one endpoint is expected.
    var info = await _flightClient.GetInfo(descriptor);
    var endpoint = Assert.Single(info.Endpoints);

    var getCall = _flightClient.GetStream(endpoint.Ticket);
    var received = await getCall.ResponseStream.ToListAsync();

    var receivedBatch = Assert.Single(received);
    ArrowReaderVerifier.CompareBatches(storedBatch, receivedBatch);
}
/// <summary>
/// Seeds the store with two distinct batches and verifies that DoGet streams both back in order.
/// </summary>
public async Task TestGetTwoRecordBatch()
{
    var descriptor = FlightDescriptor.CreatePathDescriptor("test");
    var firstBatch = CreateTestBatch(0, 100);
    var secondBatch = CreateTestBatch(100, 100);

    // Seed the in-memory store with both batches.
    GivenStoreBatches(
        descriptor,
        new RecordBatchWithMetadata(firstBatch),
        new RecordBatchWithMetadata(secondBatch));

    // Resolve the ticket through the flight info; exactly one endpoint is expected.
    var info = await _flightClient.GetInfo(descriptor);
    var endpoint = Assert.Single(info.Endpoints);

    var getCall = _flightClient.GetStream(endpoint.Ticket);
    var received = await getCall.ResponseStream.ToListAsync();

    Assert.Equal(2, received.Count);
    ArrowReaderVerifier.CompareBatches(firstBatch, received[0]);
    ArrowReaderVerifier.CompareBatches(secondBatch, received[1]);
}
/// <summary>
/// Seeds the store with one batch carrying application metadata and verifies the metadata
/// is surfaced on the DoGet response stream.
/// </summary>
public async Task TestGetFlightMetadata()
{
    var descriptor = FlightDescriptor.CreatePathDescriptor("test");
    var batch = CreateTestBatch(0, 100);
    var metadata = ByteString.CopyFromUtf8("test metadata");
    List<ByteString> expectedMetadata = new List<ByteString> { metadata };

    // Seed the in-memory store with a batch that has metadata attached.
    GivenStoreBatches(descriptor, new RecordBatchWithMetadata(batch, metadata));

    // Resolve the ticket through the flight info; exactly one endpoint is expected.
    var info = await _flightClient.GetInfo(descriptor);
    var endpoint = Assert.Single(info.Endpoints);

    var getCall = _flightClient.GetStream(endpoint.Ticket);

    // Step the reader manually so the application metadata exposed after each
    // MoveNext can be collected.
    var receivedMetadata = new List<ByteString>();
    while (await getCall.ResponseStream.MoveNext(default))
    {
        receivedMetadata.AddRange(getCall.ResponseStream.ApplicationMetadata);
    }

    Assert.Equal(expectedMetadata, receivedMetadata);
}
/// <summary>
/// Puts one batch with application metadata and verifies that both the batch and its
/// metadata are stored.
/// </summary>
public async Task TestPutWithMetadata()
{
    var descriptor = FlightDescriptor.CreatePathDescriptor("test");
    var batchToSend = CreateTestBatch(0, 100);
    var metadataToSend = ByteString.CopyFromUtf8("test metadata");

    var putCall = _flightClient.StartPut(descriptor);
    await putCall.RequestStream.WriteAsync(batchToSend, metadataToSend);
    await putCall.RequestStream.CompleteAsync();
    var putResults = await putCall.ResponseStream.ToListAsync();

    Assert.Single(putResults);

    var stored = Assert.Single(GetStoreBatch(descriptor));
    ArrowReaderVerifier.CompareBatches(batchToSend, stored.RecordBatch);
    Assert.Equal(metadataToSend, stored.Metadata);
}
/// <summary>
/// Seeds the store with one batch and verifies that GetSchema returns that batch's schema.
/// </summary>
public async Task TestGetSchema()
{
    var descriptor = FlightDescriptor.CreatePathDescriptor("test");
    var batch = CreateTestBatch(0, 100);

    GivenStoreBatches(descriptor, new RecordBatchWithMetadata(batch));

    var schemaFromServer = await _flightClient.GetSchema(descriptor);

    SchemaComparer.Compare(batch.Schema, schemaFromServer);
}
/// <summary>
/// Invokes the "test" action and verifies the server replies with a single "test data" result.
/// </summary>
public async Task TestDoAction()
{
    List<FlightResult> expected = new List<FlightResult> { new FlightResult("test data") };

    var actionCall = _flightClient.DoAction(new FlightAction("test"));
    var actual = await actionCall.ResponseStream.ToListAsync();

    Assert.Equal(expected, actual);
}
/// <summary>
/// Verifies that ListActions returns the server's advertised action types, in order.
/// </summary>
public async Task TestListActions()
{
    List<FlightActionType> expectedActions = new List<FlightActionType>
    {
        new FlightActionType("get", "get a flight"),
        new FlightActionType("put", "add a flight"),
        new FlightActionType("delete", "delete a flight"),
        new FlightActionType("test", "test action"),
    };

    var listCall = _flightClient.ListActions();
    var actualActions = await listCall.ResponseStream.ToListAsync();

    Assert.Equal(expectedActions, actualActions);
}
/// <summary>
/// Registers two flights and verifies that ListFlights returns exactly those flights
/// with matching flight info.
/// </summary>
public async Task TestListFlights()
{
    var flightDescriptor1 = FlightDescriptor.CreatePathDescriptor("test1");
    var flightDescriptor2 = FlightDescriptor.CreatePathDescriptor("test2");
    var expectedBatch = CreateTestBatch(0, 100);

    // Collection initializers evaluate in order, so flights are registered test1 then test2.
    var expectedFlightInfo = new List<FlightInfo>
    {
        GivenStoreBatches(flightDescriptor1, new RecordBatchWithMetadata(expectedBatch)),
        GivenStoreBatches(flightDescriptor2, new RecordBatchWithMetadata(expectedBatch)),
    };

    var listFlightStream = _flightClient.ListFlights();
    var actualFlights = await listFlightStream.ResponseStream.ToListAsync();

    // Check the count first: otherwise extra flights would pass unnoticed and missing
    // ones would surface as an index exception instead of an assertion failure.
    Assert.Equal(expectedFlightInfo.Count, actualFlights.Count);

    // NOTE(review): index-wise comparison assumes ListFlights returns flights in
    // registration order — confirm against the test server's store implementation.
    for (int i = 0; i < expectedFlightInfo.Count; i++)
    {
        FlightInfoComparer.Compare(expectedFlightInfo[i], actualFlights[i]);
    }
}
/// <summary>
/// Seeds the store with two batches and verifies they can be consumed from the DoGet
/// response stream with <c>await foreach</c>, in order.
/// </summary>
public async Task TestGetBatchesWithAsyncEnumerable()
{
    var descriptor = FlightDescriptor.CreatePathDescriptor("test");
    var firstBatch = CreateTestBatch(0, 100);
    var secondBatch = CreateTestBatch(100, 100);

    // Seed the in-memory store with both batches.
    GivenStoreBatches(
        descriptor,
        new RecordBatchWithMetadata(firstBatch),
        new RecordBatchWithMetadata(secondBatch));

    // Resolve the ticket through the flight info; exactly one endpoint is expected.
    var info = await _flightClient.GetInfo(descriptor);
    var endpoint = Assert.Single(info.Endpoints);

    var getCall = _flightClient.GetStream(endpoint.Ticket);

    // Consume the response stream through its async-enumerable surface (the feature under test).
    var received = new List<RecordBatch>();
    await foreach (var batch in getCall.ResponseStream)
    {
        received.Add(batch);
    }

    Assert.Equal(2, received.Count);
    ArrowReaderVerifier.CompareBatches(firstBatch, received[0]);
    ArrowReaderVerifier.CompareBatches(secondBatch, received[1]);
}