1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
26 "github.com/apache/arrow/go/v6/arrow/array"
27 "github.com/apache/arrow/go/v6/arrow/flight"
28 "github.com/apache/arrow/go/v6/arrow/internal/arrdata"
29 "github.com/apache/arrow/go/v6/arrow/ipc"
30 "github.com/apache/arrow/go/v6/arrow/memory"
31 "google.golang.org/grpc"
32 "google.golang.org/grpc/codes"
33 "google.golang.org/grpc/status"
36 type flightServer struct {
40 func (f *flightServer) getmem() memory.Allocator {
42 f.mem = memory.NewGoAllocator()
48 func (f *flightServer) ListFlights(c *flight.Criteria, fs flight.FlightService_ListFlightsServer) error {
49 expr := string(c.GetExpression())
52 authVal := flight.AuthFromContext(fs.Context())
54 auth = authVal.(string)
57 for _, name := range arrdata.RecordNames {
58 if expr != "" && expr != name {
62 recs := arrdata.Records[name]
64 for _, r := range recs {
65 totalRows += r.NumRows()
68 fs.Send(&flight.FlightInfo{
69 Schema: flight.SerializeSchema(recs[0].Schema(), f.getmem()),
70 FlightDescriptor: &flight.FlightDescriptor{
71 Type: flight.FlightDescriptor_PATH,
72 Path: []string{name, auth},
74 TotalRecords: totalRows,
82 func (f *flightServer) GetSchema(_ context.Context, in *flight.FlightDescriptor) (*flight.SchemaResult, error) {
84 return nil, status.Error(codes.InvalidArgument, "invalid flight descriptor")
87 recs, ok := arrdata.Records[in.Path[0]]
89 return nil, status.Error(codes.NotFound, "flight not found")
92 return &flight.SchemaResult{Schema: flight.SerializeSchema(recs[0].Schema(), f.getmem())}, nil
95 func (f *flightServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
96 recs := arrdata.Records[string(tkt.GetTicket())]
98 w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema()))
99 for _, r := range recs {
// servAuth is the server-side auth handler that TestServer passes to
// flight.NewFlightServer. Its Authenticate method rejects any handshake
// token other than "foobar" with a "novalid" error and answers an accepted
// handshake by sending "baz" back; IsValid can likewise reject a token
// with a "novalid" error.
106 type servAuth struct{}
108 func (a *servAuth) Authenticate(c flight.AuthConn) error {
114 if string(tok) != "foobar" {
115 return errors.New("novalid")
122 return c.Send([]byte("baz"))
125 func (a *servAuth) IsValid(token string) (interface{}, error) {
129 return "", errors.New("novalid")
// ctxauth is the context key under which the tests store the value that
// clientAuth should use: a []byte handshake payload for Authenticate, or a
// string token for GetToken. Using an unexported empty struct type as the
// key keeps it from colliding with context keys defined in other packages.
132 type ctxauth struct{}
// clientAuth is the stateless client-side auth handler passed to
// flight.NewFlightClient in TestServer. Authenticate sends the []byte stored
// under ctxauth{} in its context, and GetToken returns the string stored
// there. Both use single-value type assertions, so they panic if the caller
// did not set ctxauth{} to a value of the expected type.
134 type clientAuth struct{}
136 func (a *clientAuth) Authenticate(ctx context.Context, c flight.AuthConn) error {
137 if err := c.Send(ctx.Value(ctxauth{}).([]byte)); err != nil {
145 func (a *clientAuth) GetToken(ctx context.Context) (string, error) {
146 return ctx.Value(ctxauth{}).(string), nil
149 func TestListFlights(t *testing.T) {
150 s := flight.NewFlightServer(nil)
151 s.Init("localhost:0")
153 s.RegisterFlightService(&flight.FlightServiceService{
154 ListFlights: f.ListFlights,
160 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
166 flightStream, err := client.ListFlights(context.Background(), &flight.Criteria{})
172 info, err := flightStream.Recv()
175 } else if err != nil {
179 fname := info.GetFlightDescriptor().GetPath()[0]
180 recs, ok := arrdata.Records[fname]
182 t.Fatalf("got unknown flight info: %s", fname)
185 sc, err := flight.DeserializeSchema(info.GetSchema(), f.mem)
190 if !recs[0].Schema().Equal(sc) {
191 t.Fatalf("flight info schema transfer failed: \ngot = %#v\nwant = %#v\n", sc, recs[0].Schema())
195 for _, r := range recs {
199 if info.TotalRecords != total {
200 t.Fatalf("got wrong number of total records: got = %d, wanted = %d", info.TotalRecords, total)
205 func TestGetSchema(t *testing.T) {
206 s := flight.NewFlightServer(nil)
207 s.Init("localhost:0")
209 s.RegisterFlightService(&flight.FlightServiceService{
210 GetSchema: f.GetSchema,
216 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
222 for name, testrecs := range arrdata.Records {
223 t.Run("flight get schema: "+name, func(t *testing.T) {
224 res, err := client.GetSchema(context.Background(), &flight.FlightDescriptor{Path: []string{name}})
229 schema, err := flight.DeserializeSchema(res.GetSchema(), f.getmem())
234 if !testrecs[0].Schema().Equal(schema) {
235 t.Fatalf("schema not match: \ngot = %#v\nwant = %#v\n", schema, testrecs[0].Schema())
241 func TestServer(t *testing.T) {
243 service := &flight.FlightServiceService{
244 ListFlights: f.ListFlights,
248 s := flight.NewFlightServer(&servAuth{})
249 s.Init("localhost:0")
250 s.RegisterFlightService(service)
255 client, err := flight.NewFlightClient(s.Addr().String(), &clientAuth{}, grpc.WithInsecure())
261 err = client.Authenticate(context.WithValue(context.Background(), ctxauth{}, []byte("foobar")))
266 ctx := context.WithValue(context.Background(), ctxauth{}, "baz")
268 fistream, err := client.ListFlights(ctx, &flight.Criteria{Expression: []byte("decimal128")})
273 fi, err := fistream.Recv()
278 if len(fi.FlightDescriptor.GetPath()) != 2 || fi.FlightDescriptor.GetPath()[1] != "bar" {
279 t.Fatalf("path should have auth info: want %s got %s", "bar", fi.FlightDescriptor.GetPath()[1])
282 fdata, err := client.DoGet(ctx, &flight.Ticket{Ticket: []byte("decimal128")})
287 r, err := flight.NewRecordReader(fdata)
292 expected := arrdata.Records["decimal128"]
294 var numRows int64 = 0
304 numRows += rec.NumRows()
305 if !array.RecordEqual(expected[idx], rec) {
306 t.Errorf("flight data stream records don't match: \ngot = %#v\nwant = %#v", rec, expected[idx])
311 if numRows != fi.TotalRecords {
312 t.Fatalf("got %d, want %d", numRows, fi.TotalRecords)
// flightMetadataWriterServer is a test Flight server whose DoGet streams the
// arrdata records named by the ticket. Each record carries application
// metadata of the form "<record index>_<ticket>", so that
// TestFlightWithAppMetadata can verify the metadata round-trip.
316 type flightMetadataWriterServer struct{}
318 func (f *flightMetadataWriterServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
319 recs := arrdata.Records[string(tkt.GetTicket())]
321 w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema()))
323 for idx, r := range recs {
324 w.WriteWithAppMetadata(r, []byte(fmt.Sprintf("%d_%s", idx, string(tkt.GetTicket()))) /*metadata*/)
329 func TestFlightWithAppMetadata(t *testing.T) {
330 f := &flightMetadataWriterServer{}
331 s := flight.NewFlightServer(nil)
332 s.RegisterFlightService(&flight.FlightServiceService{DoGet: f.DoGet})
333 s.Init("localhost:0")
338 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
344 fdata, err := client.DoGet(context.Background(), &flight.Ticket{Ticket: []byte("primitives")})
349 r, err := flight.NewRecordReader(fdata)
354 expected := arrdata.Records["primitives"]
365 appMeta := r.LatestAppMetadata()
366 if !array.RecordEqual(expected[idx], rec) {
367 t.Errorf("flight data stream records for idx: %d don't match: \ngot = %#v\nwant = %#v", idx, rec, expected[idx])
370 exMeta := fmt.Sprintf("%d_primitives", idx)
371 if string(appMeta) != exMeta {
372 t.Errorf("flight data stream application metadata mismatch: got: %v, want: %v\n", string(appMeta), exMeta)
// flightErrorReturn is a test Flight server whose DoGet always fails with a
// codes.NotFound status. TestReaderError uses it to check that
// flight.NewRecordReader reports the server-side error.
378 type flightErrorReturn struct{}
380 func (f *flightErrorReturn) DoGet(_ *flight.Ticket, _ flight.FlightService_DoGetServer) error {
381 return status.Error(codes.NotFound, "nofound")
384 func TestReaderError(t *testing.T) {
385 f := &flightErrorReturn{}
386 s := flight.NewFlightServer(nil)
387 s.RegisterFlightService(&flight.FlightServiceService{DoGet: f.DoGet})
388 s.Init("localhost:0")
393 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
399 fdata, err := client.DoGet(context.Background(), &flight.Ticket{})
404 _, err = flight.NewRecordReader(fdata)
406 t.Fatal("should have errored")