]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/go/arrow/flight/flight_test.go
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / go / arrow / flight / flight_test.go
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 package flight_test
18
19 import (
20 "context"
21 "errors"
22 "fmt"
23 "io"
24 "testing"
25
26 "github.com/apache/arrow/go/v6/arrow/array"
27 "github.com/apache/arrow/go/v6/arrow/flight"
28 "github.com/apache/arrow/go/v6/arrow/internal/arrdata"
29 "github.com/apache/arrow/go/v6/arrow/ipc"
30 "github.com/apache/arrow/go/v6/arrow/memory"
31 "google.golang.org/grpc"
32 "google.golang.org/grpc/codes"
33 "google.golang.org/grpc/status"
34 )
35
36 type flightServer struct {
37 mem memory.Allocator
38 }
39
40 func (f *flightServer) getmem() memory.Allocator {
41 if f.mem == nil {
42 f.mem = memory.NewGoAllocator()
43 }
44
45 return f.mem
46 }
47
48 func (f *flightServer) ListFlights(c *flight.Criteria, fs flight.FlightService_ListFlightsServer) error {
49 expr := string(c.GetExpression())
50
51 auth := ""
52 authVal := flight.AuthFromContext(fs.Context())
53 if authVal != nil {
54 auth = authVal.(string)
55 }
56
57 for _, name := range arrdata.RecordNames {
58 if expr != "" && expr != name {
59 continue
60 }
61
62 recs := arrdata.Records[name]
63 totalRows := int64(0)
64 for _, r := range recs {
65 totalRows += r.NumRows()
66 }
67
68 fs.Send(&flight.FlightInfo{
69 Schema: flight.SerializeSchema(recs[0].Schema(), f.getmem()),
70 FlightDescriptor: &flight.FlightDescriptor{
71 Type: flight.FlightDescriptor_PATH,
72 Path: []string{name, auth},
73 },
74 TotalRecords: totalRows,
75 TotalBytes: -1,
76 })
77 }
78
79 return nil
80 }
81
82 func (f *flightServer) GetSchema(_ context.Context, in *flight.FlightDescriptor) (*flight.SchemaResult, error) {
83 if in == nil {
84 return nil, status.Error(codes.InvalidArgument, "invalid flight descriptor")
85 }
86
87 recs, ok := arrdata.Records[in.Path[0]]
88 if !ok {
89 return nil, status.Error(codes.NotFound, "flight not found")
90 }
91
92 return &flight.SchemaResult{Schema: flight.SerializeSchema(recs[0].Schema(), f.getmem())}, nil
93 }
94
95 func (f *flightServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
96 recs := arrdata.Records[string(tkt.GetTicket())]
97
98 w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema()))
99 for _, r := range recs {
100 w.Write(r)
101 }
102
103 return nil
104 }
105
106 type servAuth struct{}
107
108 func (a *servAuth) Authenticate(c flight.AuthConn) error {
109 tok, err := c.Read()
110 if err == io.EOF {
111 return nil
112 }
113
114 if string(tok) != "foobar" {
115 return errors.New("novalid")
116 }
117
118 if err != nil {
119 return err
120 }
121
122 return c.Send([]byte("baz"))
123 }
124
125 func (a *servAuth) IsValid(token string) (interface{}, error) {
126 if token == "baz" {
127 return "bar", nil
128 }
129 return "", errors.New("novalid")
130 }
131
132 type ctxauth struct{}
133
134 type clientAuth struct{}
135
136 func (a *clientAuth) Authenticate(ctx context.Context, c flight.AuthConn) error {
137 if err := c.Send(ctx.Value(ctxauth{}).([]byte)); err != nil {
138 return err
139 }
140
141 _, err := c.Read()
142 return err
143 }
144
145 func (a *clientAuth) GetToken(ctx context.Context) (string, error) {
146 return ctx.Value(ctxauth{}).(string), nil
147 }
148
149 func TestListFlights(t *testing.T) {
150 s := flight.NewFlightServer(nil)
151 s.Init("localhost:0")
152 f := &flightServer{}
153 s.RegisterFlightService(&flight.FlightServiceService{
154 ListFlights: f.ListFlights,
155 })
156
157 go s.Serve()
158 defer s.Shutdown()
159
160 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
161 if err != nil {
162 t.Error(err)
163 }
164 defer client.Close()
165
166 flightStream, err := client.ListFlights(context.Background(), &flight.Criteria{})
167 if err != nil {
168 t.Error(err)
169 }
170
171 for {
172 info, err := flightStream.Recv()
173 if err == io.EOF {
174 break
175 } else if err != nil {
176 t.Error(err)
177 }
178
179 fname := info.GetFlightDescriptor().GetPath()[0]
180 recs, ok := arrdata.Records[fname]
181 if !ok {
182 t.Fatalf("got unknown flight info: %s", fname)
183 }
184
185 sc, err := flight.DeserializeSchema(info.GetSchema(), f.mem)
186 if err != nil {
187 t.Fatal(err)
188 }
189
190 if !recs[0].Schema().Equal(sc) {
191 t.Fatalf("flight info schema transfer failed: \ngot = %#v\nwant = %#v\n", sc, recs[0].Schema())
192 }
193
194 var total int64 = 0
195 for _, r := range recs {
196 total += r.NumRows()
197 }
198
199 if info.TotalRecords != total {
200 t.Fatalf("got wrong number of total records: got = %d, wanted = %d", info.TotalRecords, total)
201 }
202 }
203 }
204
205 func TestGetSchema(t *testing.T) {
206 s := flight.NewFlightServer(nil)
207 s.Init("localhost:0")
208 f := &flightServer{}
209 s.RegisterFlightService(&flight.FlightServiceService{
210 GetSchema: f.GetSchema,
211 })
212
213 go s.Serve()
214 defer s.Shutdown()
215
216 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
217 if err != nil {
218 t.Error(err)
219 }
220 defer client.Close()
221
222 for name, testrecs := range arrdata.Records {
223 t.Run("flight get schema: "+name, func(t *testing.T) {
224 res, err := client.GetSchema(context.Background(), &flight.FlightDescriptor{Path: []string{name}})
225 if err != nil {
226 t.Fatal(err)
227 }
228
229 schema, err := flight.DeserializeSchema(res.GetSchema(), f.getmem())
230 if err != nil {
231 t.Fatal(err)
232 }
233
234 if !testrecs[0].Schema().Equal(schema) {
235 t.Fatalf("schema not match: \ngot = %#v\nwant = %#v\n", schema, testrecs[0].Schema())
236 }
237 })
238 }
239 }
240
241 func TestServer(t *testing.T) {
242 f := &flightServer{}
243 service := &flight.FlightServiceService{
244 ListFlights: f.ListFlights,
245 DoGet: f.DoGet,
246 }
247
248 s := flight.NewFlightServer(&servAuth{})
249 s.Init("localhost:0")
250 s.RegisterFlightService(service)
251
252 go s.Serve()
253 defer s.Shutdown()
254
255 client, err := flight.NewFlightClient(s.Addr().String(), &clientAuth{}, grpc.WithInsecure())
256 if err != nil {
257 t.Error(err)
258 }
259 defer client.Close()
260
261 err = client.Authenticate(context.WithValue(context.Background(), ctxauth{}, []byte("foobar")))
262 if err != nil {
263 t.Error(err)
264 }
265
266 ctx := context.WithValue(context.Background(), ctxauth{}, "baz")
267
268 fistream, err := client.ListFlights(ctx, &flight.Criteria{Expression: []byte("decimal128")})
269 if err != nil {
270 t.Error(err)
271 }
272
273 fi, err := fistream.Recv()
274 if err != nil {
275 t.Fatal(err)
276 }
277
278 if len(fi.FlightDescriptor.GetPath()) != 2 || fi.FlightDescriptor.GetPath()[1] != "bar" {
279 t.Fatalf("path should have auth info: want %s got %s", "bar", fi.FlightDescriptor.GetPath()[1])
280 }
281
282 fdata, err := client.DoGet(ctx, &flight.Ticket{Ticket: []byte("decimal128")})
283 if err != nil {
284 t.Error(err)
285 }
286
287 r, err := flight.NewRecordReader(fdata)
288 if err != nil {
289 t.Error(err)
290 }
291
292 expected := arrdata.Records["decimal128"]
293 idx := 0
294 var numRows int64 = 0
295 for {
296 rec, err := r.Read()
297 if err != nil {
298 if err == io.EOF {
299 break
300 }
301 t.Error(err)
302 }
303
304 numRows += rec.NumRows()
305 if !array.RecordEqual(expected[idx], rec) {
306 t.Errorf("flight data stream records don't match: \ngot = %#v\nwant = %#v", rec, expected[idx])
307 }
308 idx++
309 }
310
311 if numRows != fi.TotalRecords {
312 t.Fatalf("got %d, want %d", numRows, fi.TotalRecords)
313 }
314 }
315
316 type flightMetadataWriterServer struct{}
317
318 func (f *flightMetadataWriterServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
319 recs := arrdata.Records[string(tkt.GetTicket())]
320
321 w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema()))
322 defer w.Close()
323 for idx, r := range recs {
324 w.WriteWithAppMetadata(r, []byte(fmt.Sprintf("%d_%s", idx, string(tkt.GetTicket()))) /*metadata*/)
325 }
326 return nil
327 }
328
329 func TestFlightWithAppMetadata(t *testing.T) {
330 f := &flightMetadataWriterServer{}
331 s := flight.NewFlightServer(nil)
332 s.RegisterFlightService(&flight.FlightServiceService{DoGet: f.DoGet})
333 s.Init("localhost:0")
334
335 go s.Serve()
336 defer s.Shutdown()
337
338 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
339 if err != nil {
340 t.Fatal(err)
341 }
342 defer client.Close()
343
344 fdata, err := client.DoGet(context.Background(), &flight.Ticket{Ticket: []byte("primitives")})
345 if err != nil {
346 t.Fatal(err)
347 }
348
349 r, err := flight.NewRecordReader(fdata)
350 if err != nil {
351 t.Fatal(err)
352 }
353
354 expected := arrdata.Records["primitives"]
355 idx := 0
356 for {
357 rec, err := r.Read()
358 if err != nil {
359 if err == io.EOF {
360 break
361 }
362 t.Fatal(err)
363 }
364
365 appMeta := r.LatestAppMetadata()
366 if !array.RecordEqual(expected[idx], rec) {
367 t.Errorf("flight data stream records for idx: %d don't match: \ngot = %#v\nwant = %#v", idx, rec, expected[idx])
368 }
369
370 exMeta := fmt.Sprintf("%d_primitives", idx)
371 if string(appMeta) != exMeta {
372 t.Errorf("flight data stream application metadata mismatch: got: %v, want: %v\n", string(appMeta), exMeta)
373 }
374 idx++
375 }
376 }
377
378 type flightErrorReturn struct{}
379
380 func (f *flightErrorReturn) DoGet(_ *flight.Ticket, _ flight.FlightService_DoGetServer) error {
381 return status.Error(codes.NotFound, "nofound")
382 }
383
384 func TestReaderError(t *testing.T) {
385 f := &flightErrorReturn{}
386 s := flight.NewFlightServer(nil)
387 s.RegisterFlightService(&flight.FlightServiceService{DoGet: f.DoGet})
388 s.Init("localhost:0")
389
390 go s.Serve()
391 defer s.Shutdown()
392
393 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
394 if err != nil {
395 t.Fatal(err)
396 }
397 defer client.Close()
398
399 fdata, err := client.DoGet(context.Background(), &flight.Ticket{})
400 if err != nil {
401 t.Fatal(err)
402 }
403
404 _, err = flight.NewRecordReader(fdata)
405 if err == nil {
406 t.Fatal("should have errored")
407 }
408 }