]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, software | |
12 | // distributed under the License is distributed on an "AS IS" BASIS, | |
13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
14 | // See the License for the specific language governing permissions and | |
15 | // limitations under the License. | |
16 | ||
17 | package flight_test | |
18 | ||
19 | import ( | |
20 | "context" | |
21 | "io" | |
22 | "testing" | |
23 | ||
24 | "github.com/apache/arrow/go/v6/arrow/flight" | |
25 | "google.golang.org/grpc" | |
26 | "google.golang.org/grpc/codes" | |
27 | "google.golang.org/grpc/metadata" | |
28 | status "google.golang.org/grpc/status" | |
29 | ) | |
30 | ||
// Username/password pair that validator.Validate accepts.
const (
	validUsername = "flight_username"
	validPassword = "flight_password"
)

// Username/password pair that validator.Validate must reject.
const (
	invalidUsername = "invalid_flight_username"
	invalidPassword = "invalid_flight_password"
)

// Bearer tokens: validBearer is what validator.Validate hands out for the
// valid credentials and the only token validator.IsValid accepts;
// invalidBearer is never accepted.
const (
	validBearer   = "CAREBARESTARE"
	invalidBearer = "PANDABEAR"
)
39 | ||
40 | type HeaderAuthTestFlight struct{} | |
41 | ||
42 | func (h *HeaderAuthTestFlight) ListFlights(c *flight.Criteria, fs flight.FlightService_ListFlightsServer) error { | |
43 | fs.Send(&flight.FlightInfo{ | |
44 | Schema: []byte("foobar"), | |
45 | }) | |
46 | return nil | |
47 | } | |
48 | ||
49 | func (h *HeaderAuthTestFlight) GetSchema(ctx context.Context, in *flight.FlightDescriptor) (*flight.SchemaResult, error) { | |
50 | return &flight.SchemaResult{Schema: []byte(flight.AuthFromContext(ctx).(string))}, nil | |
51 | } | |
52 | ||
53 | type validator struct{} | |
54 | ||
55 | func (*validator) Validate(username, password string) (string, error) { | |
56 | if username == validUsername && password == validPassword { | |
57 | return validBearer, nil | |
58 | } | |
59 | return "", status.Errorf(codes.Unauthenticated, "invalid user/password") | |
60 | } | |
61 | ||
62 | func (*validator) IsValid(bearerToken string) (interface{}, error) { | |
63 | if bearerToken == validBearer { | |
64 | return "carebears", nil | |
65 | } | |
66 | return "", status.Errorf(codes.Unauthenticated, "invalid authentication") | |
67 | } | |
68 | ||
69 | func TestErrorAuths(t *testing.T) { | |
70 | unary, stream := flight.CreateServerBearerTokenAuthInterceptors(&validator{}) | |
71 | s := flight.NewFlightServer(nil, grpc.UnaryInterceptor(unary), grpc.StreamInterceptor(stream)) | |
72 | s.Init("localhost:0") | |
73 | f := &HeaderAuthTestFlight{} | |
74 | s.RegisterFlightService(&flight.FlightServiceService{ | |
75 | ListFlights: f.ListFlights, | |
76 | GetSchema: f.GetSchema, | |
77 | }) | |
78 | ||
79 | go s.Serve() | |
80 | defer s.Shutdown() | |
81 | ||
82 | client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure()) | |
83 | if err != nil { | |
84 | t.Fatal(err) | |
85 | } | |
86 | ||
87 | t.Run("non basic auth", func(t *testing.T) { | |
88 | fc, err := client.Handshake(metadata.NewOutgoingContext(context.Background(), metadata.New(map[string]string{"authorization": "Foobar ****"}))) | |
89 | if err != nil { | |
90 | t.Fatal(err) | |
91 | } | |
92 | ||
93 | _, err = fc.Recv() | |
94 | if err == nil { | |
95 | t.Fatal("should have failed") | |
96 | } | |
97 | }) | |
98 | ||
99 | t.Run("invalid auth", func(t *testing.T) { | |
100 | _, err := client.AuthenticateBasicToken(context.Background(), invalidUsername, invalidPassword) | |
101 | if err == nil { | |
102 | t.Fatal("should have failed") | |
103 | } | |
104 | }) | |
105 | ||
106 | t.Run("invalid base64", func(t *testing.T) { | |
107 | fc, err := client.Handshake(metadata.NewOutgoingContext(context.Background(), metadata.New(map[string]string{"authorization": "Basic ****"}))) | |
108 | if err != nil { | |
109 | t.Fatal(err) | |
110 | } | |
111 | ||
112 | _, err = fc.Recv() | |
113 | if err == nil { | |
114 | t.Fatal("should have failed") | |
115 | } | |
116 | }) | |
117 | ||
118 | t.Run("invalid bearer token", func(t *testing.T) { | |
119 | fs, _ := client.ListFlights(metadata.NewOutgoingContext(context.Background(), metadata.New(map[string]string{"authorization": "Bearer " + invalidBearer})), &flight.Criteria{}) | |
120 | _, err = fs.Recv() | |
121 | if err == nil { | |
122 | t.Fatal("should have errored with invalid bearer token") | |
123 | } | |
124 | }) | |
125 | ||
126 | t.Run("invalid auth type", func(t *testing.T) { | |
127 | fs, _ := client.ListFlights(metadata.NewOutgoingContext(context.Background(), metadata.New(map[string]string{"authorization": "FunnyStuff " + invalidBearer})), &flight.Criteria{}) | |
128 | _, err = fs.Recv() | |
129 | if err == nil { | |
130 | t.Fatal("should have errored with invalid bearer token") | |
131 | } | |
132 | }) | |
133 | ||
134 | t.Run("test no auth, unary", func(t *testing.T) { | |
135 | _, err := client.GetSchema(context.Background(), &flight.FlightDescriptor{}) | |
136 | if err == nil { | |
137 | t.Fatal("should have errored") | |
138 | } | |
139 | }) | |
140 | ||
141 | t.Run("test invalid auth, unary", func(t *testing.T) { | |
142 | _, err := client.GetSchema(metadata.NewOutgoingContext(context.Background(), metadata.New(map[string]string{"authorization": "Bearer Foobarmoo"})), &flight.FlightDescriptor{}) | |
143 | if err == nil { | |
144 | t.Fatal("should have errored") | |
145 | } | |
146 | }) | |
147 | } | |
148 | ||
149 | func TestBasicAuthHelpers(t *testing.T) { | |
150 | s := flight.NewServerWithMiddleware(nil, []flight.ServerMiddleware{flight.CreateServerBasicAuthMiddleware(&validator{})}) | |
151 | s.Init("localhost:0") | |
152 | f := &HeaderAuthTestFlight{} | |
153 | s.RegisterFlightService(&flight.FlightServiceService{ | |
154 | ListFlights: f.ListFlights, | |
155 | GetSchema: f.GetSchema, | |
156 | }) | |
157 | ||
158 | go s.Serve() | |
159 | defer s.Shutdown() | |
160 | ||
161 | client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure()) | |
162 | if err != nil { | |
163 | t.Fatal(err) | |
164 | } | |
165 | ||
166 | ctx := context.Background() | |
167 | fs, err := client.ListFlights(ctx, &flight.Criteria{}) | |
168 | if err != nil { | |
169 | t.Fatal(err) | |
170 | } | |
171 | ||
172 | _, err = fs.Recv() | |
173 | if err == nil || err == io.EOF { | |
174 | t.Fatal("Should have failed with unauthenticated error") | |
175 | } | |
176 | ||
177 | ctx, err = client.AuthenticateBasicToken(ctx, validUsername, validPassword) | |
178 | if err != nil { | |
179 | t.Fatal(err) | |
180 | } | |
181 | ||
182 | fs, err = client.ListFlights(ctx, &flight.Criteria{}) | |
183 | if err != nil { | |
184 | t.Fatal(err) | |
185 | } | |
186 | ||
187 | info, err := fs.Recv() | |
188 | if err != nil { | |
189 | t.Fatal(err) | |
190 | } | |
191 | ||
192 | if "foobar" != string(info.Schema) { | |
193 | t.Fatal("should have received 'foobar'") | |
194 | } | |
195 | ||
196 | sc, err := client.GetSchema(ctx, &flight.FlightDescriptor{}) | |
197 | if err != nil { | |
198 | t.Fatal(err) | |
199 | } | |
200 | ||
201 | if "carebears" != string(sc.Schema) { | |
202 | t.Fatal("should have received carebears") | |
203 | } | |
204 | } |