]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/arrow/go/arrow/flight/client.go
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / go / arrow / flight / client.go
diff --git a/ceph/src/arrow/go/arrow/flight/client.go b/ceph/src/arrow/go/arrow/flight/client.go
new file mode 100644 (file)
index 0000000..735c08b
--- /dev/null
@@ -0,0 +1,325 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flight
+
+import (
+       "context"
+       "encoding/base64"
+       "io"
+       "runtime"
+       "strings"
+       "sync/atomic"
+
+       "golang.org/x/xerrors"
+       "google.golang.org/grpc"
+       "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/metadata"
+       "google.golang.org/grpc/status"
+)
+
+// Client is an interface wrapped around the generated FlightServiceClient which is
+// generated by grpc protobuf definitions. This interface provides a useful hiding
+// of the authentication handshake via calling Authenticate and using the
+// ClientAuthHandler rather than manually having to implement the grpc communication
+// and sending of the auth token.
+type Client interface {
+       // Authenticate uses the ClientAuthHandler that was used when creating the client
+       // in order to use the Handshake endpoints of the service.
+       Authenticate(context.Context, ...grpc.CallOption) error
+       AuthenticateBasicToken(ctx context.Context, username string, password string, opts ...grpc.CallOption) (context.Context, error)
+       Close() error
+       // join the interface from the FlightServiceClient instead of re-defining all
+       // the endpoints here.
+       FlightServiceClient
+}
+
+type CustomClientMiddleware interface {
+       StartCall(ctx context.Context) context.Context
+}
+
+type ClientPostCallMiddleware interface {
+       CallCompleted(ctx context.Context, err error)
+}
+
+type ClientHeadersMiddleware interface {
+       HeadersReceived(ctx context.Context, md metadata.MD)
+}
+
+func CreateClientMiddleware(middleware CustomClientMiddleware) ClientMiddleware {
+       return ClientMiddleware{
+               Unary: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+                       nctx := middleware.StartCall(ctx)
+                       if nctx != nil {
+                               ctx = nctx
+                       }
+
+                       if hdrs, ok := middleware.(ClientHeadersMiddleware); ok {
+                               hdrmd := make(metadata.MD)
+                               trailermd := make(metadata.MD)
+                               opts = append(opts, grpc.Header(&hdrmd), grpc.Trailer(&trailermd))
+                               defer func() {
+                                       hdrs.HeadersReceived(ctx, metadata.Join(hdrmd, trailermd))
+                               }()
+                       }
+
+                       err := invoker(ctx, method, req, reply, cc, opts...)
+                       if post, ok := middleware.(ClientPostCallMiddleware); ok {
+                               post.CallCompleted(ctx, err)
+                       }
+                       return err
+               },
+               Stream: func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+                       nctx := middleware.StartCall(ctx)
+                       if nctx != nil {
+                               ctx = nctx
+                       }
+
+                       cs, err := streamer(ctx, desc, cc, method, opts...)
+                       hdrs, isHdrs := middleware.(ClientHeadersMiddleware)
+                       post, isPostcall := middleware.(ClientPostCallMiddleware)
+                       if !isPostcall && !isHdrs {
+                               return cs, err
+                       }
+
+                       if err != nil {
+                               if isHdrs {
+                                       md, _ := cs.Header()
+                                       hdrs.HeadersReceived(ctx, metadata.Join(md, cs.Trailer()))
+                               }
+                               if isPostcall {
+                                       post.CallCompleted(ctx, err)
+                               }
+                               return cs, err
+                       }
+
+                       // Grab the client stream context because when the finish function or the goroutine below will be
+                       // executed it's not guaranteed cs.Context() will be valid.
+                       csCtx := cs.Context()
+                       finishChan := make(chan struct{})
+                       isFinished := new(int32)
+                       *isFinished = 0
+                       finishFunc := func(err error) {
+
+                               // since there are multiple code paths that could call finishFunc
+                               // we need some sort of synchronization to guard against multiple
+                               // calls to finish
+                               if !atomic.CompareAndSwapInt32(isFinished, 0, 1) {
+                                       return
+                               }
+
+                               close(finishChan)
+                               if isPostcall {
+                                       post.CallCompleted(csCtx, err)
+                               }
+                               if isHdrs {
+                                       hdrmd, _ := cs.Header()
+                                       hdrs.HeadersReceived(csCtx, metadata.Join(hdrmd, cs.Trailer()))
+                               }
+                       }
+                       go func() {
+                               select {
+                               case <-finishChan:
+                                       // finish is being called by something else, no action necessary
+                               case <-csCtx.Done():
+                                       finishFunc(csCtx.Err())
+                               }
+                       }()
+
+                       newCS := &clientStream{
+                               ClientStream: cs,
+                               desc:         desc,
+                               finishFn:     finishFunc,
+                       }
+                       // The `ClientStream` interface allows one to omit calling `Recv` if it's
+                       // known that the result will be `io.EOF`. See
+                       // http://stackoverflow.com/q/42915337
+                       // In such cases, there's nothing that triggers the span to finish. We,
+                       // therefore, set a finalizer so that the span and the context goroutine will
+                       // at least be cleaned up when the garbage collector is run.
+                       runtime.SetFinalizer(newCS, func(newcs *clientStream) {
+                               newcs.finishFn(nil)
+                       })
+                       return newCS, nil
+               },
+       }
+}
+
+type clientStream struct {
+       grpc.ClientStream
+       desc     *grpc.StreamDesc
+       finishFn func(error)
+}
+
+func (cs *clientStream) Header() (metadata.MD, error) {
+       md, err := cs.ClientStream.Header()
+       if err != nil {
+               cs.finishFn(err)
+       }
+       return md, err
+}
+
+func (cs *clientStream) SendMsg(m interface{}) error {
+       err := cs.ClientStream.SendMsg(m)
+       if err != nil {
+               cs.finishFn(err)
+       }
+       return err
+}
+
+func (cs *clientStream) RecvMsg(m interface{}) error {
+       err := cs.ClientStream.RecvMsg(m)
+       if err == io.EOF {
+               cs.finishFn(nil)
+               return err
+       } else if err != nil {
+               cs.finishFn(err)
+               return err
+       }
+
+       if !cs.desc.ServerStreams {
+               cs.finishFn(nil)
+       }
+       return err
+}
+
+func (cs *clientStream) CloseSend() error {
+       err := cs.ClientStream.CloseSend()
+       if err != nil {
+               cs.finishFn(err)
+       }
+       return err
+}
+
+type ClientMiddleware struct {
+       Stream grpc.StreamClientInterceptor
+       Unary  grpc.UnaryClientInterceptor
+}
+
+type client struct {
+       conn        *grpc.ClientConn
+       authHandler ClientAuthHandler
+
+       FlightServiceClient
+}
+
+// NewFlightClient takes in the address of the grpc server and an auth handler for the
+// application-level handshake. If using TLS or other grpc configurations they can still
+// be passed via the grpc.DialOption list just as if connecting manually without this
+// helper function.
+//
+// Alternatively, a grpc client can be constructed as normal without this helper as the
+// grpc generated client code is still exported. This exists to add utility and helpers
+// around the authentication and passing the token with requests.
+//
+// Deprecated: prefer to use NewClientWithMiddleware
+func NewFlightClient(addr string, auth ClientAuthHandler, opts ...grpc.DialOption) (Client, error) {
+       if auth != nil {
+               opts = append([]grpc.DialOption{
+                       grpc.WithChainStreamInterceptor(createClientAuthStreamInterceptor(auth)),
+                       grpc.WithChainUnaryInterceptor(createClientAuthUnaryInterceptor(auth)),
+               }, opts...)
+       }
+
+       conn, err := grpc.Dial(addr, opts...)
+       if err != nil {
+               return nil, err
+       }
+
+       return &client{conn: conn, FlightServiceClient: NewFlightServiceClient(conn), authHandler: auth}, nil
+}
+
+// NewClientWithMiddleware takes a slice of middlewares in addition to the auth and address which will be
+// used by grpc and chained, the first middleware will be the outer most with the last middleware
+// being the inner most wrapper around the actual call. It also passes along the dialoptions passed in such
+// as TLS certs and so on.
+func NewClientWithMiddleware(addr string, auth ClientAuthHandler, middleware []ClientMiddleware, opts ...grpc.DialOption) (Client, error) {
+       unary := make([]grpc.UnaryClientInterceptor, 0, len(middleware))
+       stream := make([]grpc.StreamClientInterceptor, 0, len(middleware))
+       if auth != nil {
+               unary = append(unary, createClientAuthUnaryInterceptor(auth))
+               stream = append(stream, createClientAuthStreamInterceptor(auth))
+       }
+       if len(middleware) > 0 {
+               for _, m := range middleware {
+                       if m.Unary != nil {
+                               unary = append(unary, m.Unary)
+                       }
+                       if m.Stream != nil {
+                               stream = append(stream, m.Stream)
+                       }
+               }
+       }
+       opts = append(opts, grpc.WithChainUnaryInterceptor(unary...), grpc.WithChainStreamInterceptor(stream...))
+       conn, err := grpc.Dial(addr, opts...)
+       if err != nil {
+               return nil, err
+       }
+
+       return &client{conn: conn, FlightServiceClient: NewFlightServiceClient(conn), authHandler: auth}, nil
+}
+
+func (c *client) AuthenticateBasicToken(ctx context.Context, username, password string, opts ...grpc.CallOption) (context.Context, error) {
+       authCtx := metadata.AppendToOutgoingContext(ctx, "Authorization", "Basic "+base64.RawStdEncoding.EncodeToString([]byte(strings.Join([]string{username, password}, ":"))))
+
+       stream, err := c.FlightServiceClient.Handshake(authCtx, opts...)
+       if err != nil {
+               return ctx, err
+       }
+
+       header, err := stream.Header()
+       if err != nil {
+               return ctx, err
+       }
+
+       _, err = stream.Recv()
+       if err != nil && err != io.EOF {
+               return ctx, err
+       }
+
+       err = stream.CloseSend()
+       if err != nil {
+               return ctx, err
+       }
+
+       meta := stream.Trailer()
+       md := metadata.Join(header, meta)
+       for _, token := range md.Get("authorization") {
+               if token != "" {
+                       return metadata.AppendToOutgoingContext(ctx, "Authorization", token), nil
+               }
+       }
+
+       return ctx, xerrors.Errorf("flight: no authorization header on the response")
+}
+
+func (c *client) Authenticate(ctx context.Context, opts ...grpc.CallOption) error {
+       if c.authHandler == nil {
+               return status.Error(codes.NotFound, "cannot authenticate without an auth-handler")
+       }
+
+       stream, err := c.FlightServiceClient.Handshake(ctx, opts...)
+       if err != nil {
+               return err
+       }
+
+       return c.authHandler.Authenticate(ctx, &clientAuthConn{stream})
+}
+
+func (c *client) Close() error {
+       c.FlightServiceClient = nil
+       return c.conn.Close()
+}