1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
18 // Platform-specific defines
19 #include "arrow/flight/platform.h"
21 #include "arrow/flight/server.h"
36 #include <unordered_map>
39 #ifdef GRPCPP_PP_INCLUDE
40 #include <grpcpp/grpcpp.h>
42 #include <grpc++/grpc++.h>
45 #include "arrow/buffer.h"
46 #include "arrow/ipc/dictionary.h"
47 #include "arrow/ipc/options.h"
48 #include "arrow/ipc/reader.h"
49 #include "arrow/ipc/writer.h"
50 #include "arrow/memory_pool.h"
51 #include "arrow/record_batch.h"
52 #include "arrow/status.h"
53 #include "arrow/util/io_util.h"
54 #include "arrow/util/logging.h"
55 #include "arrow/util/uri.h"
57 #include "arrow/flight/internal.h"
58 #include "arrow/flight/middleware.h"
59 #include "arrow/flight/middleware_internal.h"
60 #include "arrow/flight/serialization_internal.h"
61 #include "arrow/flight/server_auth.h"
62 #include "arrow/flight/server_middleware.h"
63 #include "arrow/flight/types.h"
65 using FlightService
= arrow::flight::protocol::FlightService
;
66 using ServerContext
= grpc::ServerContext
;
69 using ServerWriter
= grpc::ServerWriter
<T
>;
74 namespace pb
= arrow::flight::protocol
;
// Macro that runs interceptors before returning the given status.
// Wrapped in do/while so it behaves as a single statement (safe in if/else).
#define RETURN_WITH_MIDDLEWARE(CONTEXT, STATUS)   \
  do {                                            \
    const auto& _flight_status = (STATUS);        \
    return CONTEXT.FinishRequest(_flight_status); \
  } while (false)

// Reject a null argument with INVALID_ARGUMENT, running interceptors first.
#define CHECK_ARG_NOT_NULL(CONTEXT, VAL, MESSAGE)                                   \
  do {                                                                              \
    if ((VAL) == nullptr) {                                                         \
      RETURN_WITH_MIDDLEWARE(                                                       \
          CONTEXT, grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE));      \
    }                                                                               \
  } while (false)

// Same as RETURN_NOT_OK, but accepts either Arrow or gRPC status, and
// will run interceptors
#define SERVICE_RETURN_NOT_OK(CONTEXT, expr) \
  do {                                       \
    const auto& _s = (expr);                 \
    if (ARROW_PREDICT_FALSE(!_s.ok())) {     \
      return CONTEXT.FinishRequest(_s);      \
    }                                        \
  } while (false)
101 // A MessageReader implementation that reads from a gRPC ServerReader.
102 // Templated to be generic over DoPut/DoExchange.
103 template <typename Reader
>
104 class FlightIpcMessageReader
: public ipc::MessageReader
{
106 explicit FlightIpcMessageReader(
107 std::shared_ptr
<internal::PeekableFlightDataReader
<Reader
*>> peekable_reader
,
108 std::shared_ptr
<Buffer
>* app_metadata
)
109 : peekable_reader_(peekable_reader
), app_metadata_(app_metadata
) {}
111 ::arrow::Result
<std::unique_ptr
<ipc::Message
>> ReadNextMessage() override
{
112 if (stream_finished_
) {
115 internal::FlightData
* data
;
116 peekable_reader_
->Next(&data
);
118 stream_finished_
= true;
119 if (first_message_
) {
120 return Status::Invalid(
121 "Client provided malformed message or did not provide message");
125 *app_metadata_
= std::move(data
->app_metadata
);
126 return data
->OpenMessage();
130 std::shared_ptr
<internal::PeekableFlightDataReader
<Reader
*>> peekable_reader_
;
131 // A reference to FlightMessageReaderImpl.app_metadata_. That class
132 // can't access the app metadata because when it Peek()s the stream,
133 // it may be looking at a dictionary batch, not the record
134 // batch. Updating it here ensures the reader is always updated with
135 // the last metadata message read.
136 std::shared_ptr
<Buffer
>* app_metadata_
;
137 bool first_message_
= true;
138 bool stream_finished_
= false;
141 template <typename WritePayload
>
142 class FlightMessageReaderImpl
: public FlightMessageReader
{
144 using GrpcStream
= grpc::ServerReaderWriter
<WritePayload
, pb::FlightData
>;
146 explicit FlightMessageReaderImpl(GrpcStream
* reader
)
148 peekable_reader_(new internal::PeekableFlightDataReader
<GrpcStream
*>(reader
)) {}
151 // Peek the first message to get the descriptor.
152 internal::FlightData
* data
;
153 peekable_reader_
->Peek(&data
);
155 return Status::IOError("Stream finished before first message sent");
157 if (!data
->descriptor
) {
158 return Status::IOError("Descriptor missing on first message");
160 descriptor_
= *data
->descriptor
.get(); // Copy
161 // If there's a schema (=DoPut), also Open().
162 if (data
->metadata
) {
163 return EnsureDataStarted();
165 peekable_reader_
->Next(&data
);
169 const FlightDescriptor
& descriptor() const override
{ return descriptor_
; }
171 arrow::Result
<std::shared_ptr
<Schema
>> GetSchema() override
{
172 RETURN_NOT_OK(EnsureDataStarted());
173 return batch_reader_
->schema();
176 Status
Next(FlightStreamChunk
* out
) override
{
177 internal::FlightData
* data
;
178 peekable_reader_
->Peek(&data
);
180 out
->app_metadata
= nullptr;
185 if (!data
->metadata
) {
186 // Metadata-only (data->metadata is the IPC header)
187 out
->app_metadata
= data
->app_metadata
;
189 peekable_reader_
->Next(&data
);
193 if (!batch_reader_
) {
194 RETURN_NOT_OK(EnsureDataStarted());
195 // re-peek here since EnsureDataStarted() advances the stream
198 RETURN_NOT_OK(batch_reader_
->ReadNext(&out
->data
));
199 out
->app_metadata
= std::move(app_metadata_
);
204 /// Ensure we are set up to read data.
205 Status
EnsureDataStarted() {
206 if (!batch_reader_
) {
207 // peek() until we find the first data message; discard metadata
208 if (!peekable_reader_
->SkipToData()) {
209 return Status::IOError("Client never sent a data message");
211 auto message_reader
= std::unique_ptr
<ipc::MessageReader
>(
212 new FlightIpcMessageReader
<GrpcStream
>(peekable_reader_
, &app_metadata_
));
213 ARROW_ASSIGN_OR_RAISE(
214 batch_reader_
, ipc::RecordBatchStreamReader::Open(std::move(message_reader
)));
219 FlightDescriptor descriptor_
;
221 std::shared_ptr
<internal::PeekableFlightDataReader
<GrpcStream
*>> peekable_reader_
;
222 std::shared_ptr
<RecordBatchReader
> batch_reader_
;
223 std::shared_ptr
<Buffer
> app_metadata_
;
226 class GrpcMetadataWriter
: public FlightMetadataWriter
{
228 explicit GrpcMetadataWriter(
229 grpc::ServerReaderWriter
<pb::PutResult
, pb::FlightData
>* writer
)
232 Status
WriteMetadata(const Buffer
& buffer
) override
{
233 pb::PutResult message
{};
234 message
.set_app_metadata(buffer
.data(), buffer
.size());
235 if (writer_
->Write(message
)) {
238 return Status::IOError("Unknown error writing metadata.");
242 grpc::ServerReaderWriter
<pb::PutResult
, pb::FlightData
>* writer_
;
245 class GrpcServerAuthReader
: public ServerAuthReader
{
247 explicit GrpcServerAuthReader(
248 grpc::ServerReaderWriter
<pb::HandshakeResponse
, pb::HandshakeRequest
>* stream
)
251 Status
Read(std::string
* token
) override
{
252 pb::HandshakeRequest request
;
253 if (stream_
->Read(&request
)) {
254 *token
= std::move(*request
.mutable_payload());
257 return Status::IOError("Stream is closed.");
261 grpc::ServerReaderWriter
<pb::HandshakeResponse
, pb::HandshakeRequest
>* stream_
;
264 class GrpcServerAuthSender
: public ServerAuthSender
{
266 explicit GrpcServerAuthSender(
267 grpc::ServerReaderWriter
<pb::HandshakeResponse
, pb::HandshakeRequest
>* stream
)
270 Status
Write(const std::string
& token
) override
{
271 pb::HandshakeResponse response
;
272 response
.set_payload(token
);
273 if (stream_
->Write(response
)) {
276 return Status::IOError("Stream was closed.");
280 grpc::ServerReaderWriter
<pb::HandshakeResponse
, pb::HandshakeRequest
>* stream_
;
283 /// The implementation of the write side of a bidirectional FlightData
284 /// stream for DoExchange.
285 class DoExchangeMessageWriter
: public FlightMessageWriter
{
287 explicit DoExchangeMessageWriter(
288 grpc::ServerReaderWriter
<pb::FlightData
, pb::FlightData
>* stream
)
289 : stream_(stream
), ipc_options_(::arrow::ipc::IpcWriteOptions::Defaults()) {}
291 Status
Begin(const std::shared_ptr
<Schema
>& schema
,
292 const ipc::IpcWriteOptions
& options
) override
{
294 return Status::Invalid("This writer has already been started.");
297 ipc_options_
= options
;
299 RETURN_NOT_OK(mapper_
.AddSchemaFields(*schema
));
300 FlightPayload schema_payload
;
301 RETURN_NOT_OK(ipc::GetSchemaPayload(*schema
, ipc_options_
, mapper_
,
302 &schema_payload
.ipc_message
));
303 return WritePayload(schema_payload
);
306 Status
WriteRecordBatch(const RecordBatch
& batch
) override
{
307 return WriteWithMetadata(batch
, nullptr);
310 Status
WriteMetadata(std::shared_ptr
<Buffer
> app_metadata
) override
{
311 FlightPayload payload
{};
312 payload
.app_metadata
= app_metadata
;
313 return WritePayload(payload
);
316 Status
WriteWithMetadata(const RecordBatch
& batch
,
317 std::shared_ptr
<Buffer
> app_metadata
) override
{
318 RETURN_NOT_OK(CheckStarted());
319 RETURN_NOT_OK(EnsureDictionariesWritten(batch
));
320 FlightPayload payload
{};
322 payload
.app_metadata
= app_metadata
;
324 RETURN_NOT_OK(ipc::GetRecordBatchPayload(batch
, ipc_options_
, &payload
.ipc_message
));
325 RETURN_NOT_OK(WritePayload(payload
));
326 ++stats_
.num_record_batches
;
330 Status
Close() override
{
331 // It's fine to Close() without writing data
335 ipc::WriteStats
stats() const override
{ return stats_
; }
338 Status
WritePayload(const FlightPayload
& payload
) {
339 RETURN_NOT_OK(internal::WritePayload(payload
, stream_
));
340 ++stats_
.num_messages
;
344 Status
CheckStarted() {
346 return Status::Invalid("This writer is not started. Call Begin() with a schema");
351 Status
EnsureDictionariesWritten(const RecordBatch
& batch
) {
352 if (dictionaries_written_
) {
355 dictionaries_written_
= true;
356 ARROW_ASSIGN_OR_RAISE(const auto dictionaries
,
357 ipc::CollectDictionaries(batch
, mapper_
));
358 for (const auto& pair
: dictionaries
) {
359 FlightPayload payload
{};
360 RETURN_NOT_OK(ipc::GetDictionaryPayload(pair
.first
, pair
.second
, ipc_options_
,
361 &payload
.ipc_message
));
362 RETURN_NOT_OK(WritePayload(payload
));
363 ++stats_
.num_dictionary_batches
;
368 grpc::ServerReaderWriter
<pb::FlightData
, pb::FlightData
>* stream_
;
369 ::arrow::ipc::IpcWriteOptions ipc_options_
;
370 ipc::DictionaryFieldMapper mapper_
;
371 ipc::WriteStats stats_
;
372 bool started_
= false;
373 bool dictionaries_written_
= false;
376 class FlightServiceImpl
;
377 class GrpcServerCallContext
: public ServerCallContext
{
378 explicit GrpcServerCallContext(grpc::ServerContext
* context
)
379 : context_(context
), peer_(context_
->peer()) {}
381 const std::string
& peer_identity() const override
{ return peer_identity_
; }
382 const std::string
& peer() const override
{ return peer_
; }
383 bool is_cancelled() const override
{ return context_
->IsCancelled(); }
385 // Helper method that runs interceptors given the result of an RPC,
386 // then returns the final gRPC status to send to the client
387 grpc::Status
FinishRequest(const grpc::Status
& status
) {
388 // Don't double-convert status - return the original one here
389 FinishRequest(internal::FromGrpcStatus(status
));
393 grpc::Status
FinishRequest(const arrow::Status
& status
) {
394 for (const auto& instance
: middleware_
) {
395 instance
->CallCompleted(status
);
398 // Set custom headers to map the exact Arrow status for clients
400 return internal::ToGrpcStatus(status
, context_
);
403 ServerMiddleware
* GetMiddleware(const std::string
& key
) const override
{
404 const auto& instance
= middleware_map_
.find(key
);
405 if (instance
== middleware_map_
.end()) {
408 return instance
->second
.get();
412 friend class FlightServiceImpl
;
413 ServerContext
* context_
;
415 std::string peer_identity_
;
416 std::vector
<std::shared_ptr
<ServerMiddleware
>> middleware_
;
417 std::unordered_map
<std::string
, std::shared_ptr
<ServerMiddleware
>> middleware_map_
;
420 class GrpcAddCallHeaders
: public AddCallHeaders
{
422 explicit GrpcAddCallHeaders(grpc::ServerContext
* context
) : context_(context
) {}
423 ~GrpcAddCallHeaders() override
= default;
425 void AddHeader(const std::string
& key
, const std::string
& value
) override
{
426 context_
->AddInitialMetadata(key
, value
);
430 grpc::ServerContext
* context_
;
433 // This class glues an implementation of FlightServerBase together with the
434 // gRPC service definition, so the latter is not exposed in the public API
435 class FlightServiceImpl
: public FlightService::Service
{
437 explicit FlightServiceImpl(
438 std::shared_ptr
<ServerAuthHandler
> auth_handler
,
439 std::vector
<std::pair
<std::string
, std::shared_ptr
<ServerMiddlewareFactory
>>>
441 FlightServerBase
* server
)
442 : auth_handler_(auth_handler
), middleware_(middleware
), server_(server
) {}
444 template <typename UserType
, typename Iterator
, typename ProtoType
>
445 grpc::Status
WriteStream(Iterator
* iterator
, ServerWriter
<ProtoType
>* writer
) {
447 return grpc::Status(grpc::StatusCode::INTERNAL
, "No items to iterate");
449 // Write flight info to stream until listing is exhausted
452 std::unique_ptr
<UserType
> value
;
453 GRPC_RETURN_NOT_OK(iterator
->Next(&value
));
457 GRPC_RETURN_NOT_OK(internal::ToProto(*value
, &pb_value
));
460 if (!writer
->Write(pb_value
)) {
461 // Write returns false if the stream is closed
465 return grpc::Status::OK
;
468 template <typename UserType
, typename ProtoType
>
469 grpc::Status
WriteStream(const std::vector
<UserType
>& values
,
470 ServerWriter
<ProtoType
>* writer
) {
471 // Write flight info to stream until listing is exhausted
472 for (const UserType
& value
: values
) {
474 GRPC_RETURN_NOT_OK(internal::ToProto(value
, &pb_value
));
476 if (!writer
->Write(pb_value
)) {
477 // Write returns false if the stream is closed
481 return grpc::Status::OK
;
484 // Authenticate the client (if applicable) and construct the call context
485 grpc::Status
CheckAuth(const FlightMethod
& method
, ServerContext
* context
,
486 GrpcServerCallContext
& flight_context
) {
487 if (!auth_handler_
) {
488 const auto auth_context
= context
->auth_context();
489 if (auth_context
&& auth_context
->IsPeerAuthenticated()) {
490 auto peer_identity
= auth_context
->GetPeerIdentity();
491 flight_context
.peer_identity_
=
492 peer_identity
.empty()
494 : std::string(peer_identity
.front().begin(), peer_identity
.front().end());
496 flight_context
.peer_identity_
= "";
499 const auto client_metadata
= context
->client_metadata();
500 const auto auth_header
= client_metadata
.find(internal::kGrpcAuthHeader
);
502 if (auth_header
== client_metadata
.end()) {
505 token
= std::string(auth_header
->second
.data(), auth_header
->second
.length());
507 GRPC_RETURN_NOT_OK(auth_handler_
->IsValid(token
, &flight_context
.peer_identity_
));
510 return MakeCallContext(method
, context
, flight_context
);
513 // Authenticate the client (if applicable) and construct the call context
514 grpc::Status
MakeCallContext(const FlightMethod
& method
, ServerContext
* context
,
515 GrpcServerCallContext
& flight_context
) {
516 // Run server middleware
517 const CallInfo info
{method
};
518 CallHeaders incoming_headers
;
519 for (const auto& entry
: context
->client_metadata()) {
520 incoming_headers
.insert(
521 {util::string_view(entry
.first
.data(), entry
.first
.length()),
522 util::string_view(entry
.second
.data(), entry
.second
.length())});
525 GrpcAddCallHeaders
outgoing_headers(context
);
526 for (const auto& factory
: middleware_
) {
527 std::shared_ptr
<ServerMiddleware
> instance
;
528 Status result
= factory
.second
->StartCall(info
, incoming_headers
, &instance
);
530 // Interceptor rejected call, end the request on all existing
532 return flight_context
.FinishRequest(result
);
534 if (instance
!= nullptr) {
535 flight_context
.middleware_
.push_back(instance
);
536 flight_context
.middleware_map_
.insert({factory
.first
, instance
});
537 instance
->SendingHeaders(&outgoing_headers
);
541 return grpc::Status::OK
;
544 grpc::Status
Handshake(
545 ServerContext
* context
,
546 grpc::ServerReaderWriter
<pb::HandshakeResponse
, pb::HandshakeRequest
>* stream
) {
547 GrpcServerCallContext
flight_context(context
);
548 GRPC_RETURN_NOT_GRPC_OK(
549 MakeCallContext(FlightMethod::Handshake
, context
, flight_context
));
551 if (!auth_handler_
) {
552 RETURN_WITH_MIDDLEWARE(
555 grpc::StatusCode::UNIMPLEMENTED
,
556 "This service does not have an authentication mechanism enabled."));
558 GrpcServerAuthSender outgoing
{stream
};
559 GrpcServerAuthReader incoming
{stream
};
560 RETURN_WITH_MIDDLEWARE(flight_context
,
561 auth_handler_
->Authenticate(&outgoing
, &incoming
));
564 grpc::Status
ListFlights(ServerContext
* context
, const pb::Criteria
* request
,
565 ServerWriter
<pb::FlightInfo
>* writer
) {
566 GrpcServerCallContext
flight_context(context
);
567 GRPC_RETURN_NOT_GRPC_OK(
568 CheckAuth(FlightMethod::ListFlights
, context
, flight_context
));
570 // Retrieve the listing from the implementation
571 std::unique_ptr
<FlightListing
> listing
;
575 SERVICE_RETURN_NOT_OK(flight_context
, internal::FromProto(*request
, &criteria
));
577 SERVICE_RETURN_NOT_OK(flight_context
,
578 server_
->ListFlights(flight_context
, &criteria
, &listing
));
580 // Treat null listing as no flights available
581 RETURN_WITH_MIDDLEWARE(flight_context
, grpc::Status::OK
);
583 RETURN_WITH_MIDDLEWARE(flight_context
,
584 WriteStream
<FlightInfo
>(listing
.get(), writer
));
587 grpc::Status
GetFlightInfo(ServerContext
* context
, const pb::FlightDescriptor
* request
,
588 pb::FlightInfo
* response
) {
589 GrpcServerCallContext
flight_context(context
);
590 GRPC_RETURN_NOT_GRPC_OK(
591 CheckAuth(FlightMethod::GetFlightInfo
, context
, flight_context
));
593 CHECK_ARG_NOT_NULL(flight_context
, request
, "FlightDescriptor cannot be null");
595 FlightDescriptor descr
;
596 SERVICE_RETURN_NOT_OK(flight_context
, internal::FromProto(*request
, &descr
));
598 std::unique_ptr
<FlightInfo
> info
;
599 SERVICE_RETURN_NOT_OK(flight_context
,
600 server_
->GetFlightInfo(flight_context
, descr
, &info
));
603 // Treat null listing as no flights available
604 RETURN_WITH_MIDDLEWARE(
605 flight_context
, grpc::Status(grpc::StatusCode::NOT_FOUND
, "Flight not found"));
608 SERVICE_RETURN_NOT_OK(flight_context
, internal::ToProto(*info
, response
));
609 RETURN_WITH_MIDDLEWARE(flight_context
, grpc::Status::OK
);
612 grpc::Status
GetSchema(ServerContext
* context
, const pb::FlightDescriptor
* request
,
613 pb::SchemaResult
* response
) {
614 GrpcServerCallContext
flight_context(context
);
615 GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::GetSchema
, context
, flight_context
));
617 CHECK_ARG_NOT_NULL(flight_context
, request
, "FlightDescriptor cannot be null");
619 FlightDescriptor descr
;
620 SERVICE_RETURN_NOT_OK(flight_context
, internal::FromProto(*request
, &descr
));
622 std::unique_ptr
<SchemaResult
> result
;
623 SERVICE_RETURN_NOT_OK(flight_context
,
624 server_
->GetSchema(flight_context
, descr
, &result
));
627 // Treat null listing as no flights available
628 RETURN_WITH_MIDDLEWARE(
629 flight_context
, grpc::Status(grpc::StatusCode::NOT_FOUND
, "Flight not found"));
632 SERVICE_RETURN_NOT_OK(flight_context
, internal::ToProto(*result
, response
));
633 RETURN_WITH_MIDDLEWARE(flight_context
, grpc::Status::OK
);
636 grpc::Status
DoGet(ServerContext
* context
, const pb::Ticket
* request
,
637 ServerWriter
<pb::FlightData
>* writer
) {
638 GrpcServerCallContext
flight_context(context
);
639 GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet
, context
, flight_context
));
641 CHECK_ARG_NOT_NULL(flight_context
, request
, "ticket cannot be null");
644 SERVICE_RETURN_NOT_OK(flight_context
, internal::FromProto(*request
, &ticket
));
646 std::unique_ptr
<FlightDataStream
> data_stream
;
647 SERVICE_RETURN_NOT_OK(flight_context
,
648 server_
->DoGet(flight_context
, ticket
, &data_stream
));
651 RETURN_WITH_MIDDLEWARE(flight_context
, grpc::Status(grpc::StatusCode::NOT_FOUND
,
652 "No data in this flight"));
655 // Write the schema as the first message in the stream
656 FlightPayload schema_payload
;
657 SERVICE_RETURN_NOT_OK(flight_context
, data_stream
->GetSchemaPayload(&schema_payload
));
658 auto status
= internal::WritePayload(schema_payload
, writer
);
659 if (status
.IsIOError()) {
660 // gRPC doesn't give any way for us to know why the message
661 // could not be written.
662 RETURN_WITH_MIDDLEWARE(flight_context
, grpc::Status::OK
);
664 SERVICE_RETURN_NOT_OK(flight_context
, status
);
666 // Consume data stream and write out payloads
668 FlightPayload payload
;
669 SERVICE_RETURN_NOT_OK(flight_context
, data_stream
->Next(&payload
));
671 if (payload
.ipc_message
.metadata
== nullptr) break;
672 auto status
= internal::WritePayload(payload
, writer
);
673 // Connection terminated
674 if (status
.IsIOError()) break;
675 SERVICE_RETURN_NOT_OK(flight_context
, status
);
677 RETURN_WITH_MIDDLEWARE(flight_context
, grpc::Status::OK
);
680 grpc::Status
DoPut(ServerContext
* context
,
681 grpc::ServerReaderWriter
<pb::PutResult
, pb::FlightData
>* reader
) {
682 GrpcServerCallContext
flight_context(context
);
683 GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut
, context
, flight_context
));
685 auto message_reader
= std::unique_ptr
<FlightMessageReaderImpl
<pb::PutResult
>>(
686 new FlightMessageReaderImpl
<pb::PutResult
>(reader
));
687 SERVICE_RETURN_NOT_OK(flight_context
, message_reader
->Init());
688 auto metadata_writer
=
689 std::unique_ptr
<FlightMetadataWriter
>(new GrpcMetadataWriter(reader
));
690 RETURN_WITH_MIDDLEWARE(flight_context
,
691 server_
->DoPut(flight_context
, std::move(message_reader
),
692 std::move(metadata_writer
)));
695 grpc::Status
DoExchange(
696 ServerContext
* context
,
697 grpc::ServerReaderWriter
<pb::FlightData
, pb::FlightData
>* stream
) {
698 GrpcServerCallContext
flight_context(context
);
699 GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoExchange
, context
, flight_context
));
700 auto message_reader
= std::unique_ptr
<FlightMessageReaderImpl
<pb::FlightData
>>(
701 new FlightMessageReaderImpl
<pb::FlightData
>(stream
));
702 SERVICE_RETURN_NOT_OK(flight_context
, message_reader
->Init());
704 std::unique_ptr
<DoExchangeMessageWriter
>(new DoExchangeMessageWriter(stream
));
705 RETURN_WITH_MIDDLEWARE(flight_context
,
706 server_
->DoExchange(flight_context
, std::move(message_reader
),
710 grpc::Status
ListActions(ServerContext
* context
, const pb::Empty
* request
,
711 ServerWriter
<pb::ActionType
>* writer
) {
712 GrpcServerCallContext
flight_context(context
);
713 GRPC_RETURN_NOT_GRPC_OK(
714 CheckAuth(FlightMethod::ListActions
, context
, flight_context
));
715 // Retrieve the listing from the implementation
716 std::vector
<ActionType
> types
;
717 SERVICE_RETURN_NOT_OK(flight_context
, server_
->ListActions(flight_context
, &types
));
718 RETURN_WITH_MIDDLEWARE(flight_context
, WriteStream
<ActionType
>(types
, writer
));
721 grpc::Status
DoAction(ServerContext
* context
, const pb::Action
* request
,
722 ServerWriter
<pb::Result
>* writer
) {
723 GrpcServerCallContext
flight_context(context
);
724 GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoAction
, context
, flight_context
));
725 CHECK_ARG_NOT_NULL(flight_context
, request
, "Action cannot be null");
727 SERVICE_RETURN_NOT_OK(flight_context
, internal::FromProto(*request
, &action
));
729 std::unique_ptr
<ResultStream
> results
;
730 SERVICE_RETURN_NOT_OK(flight_context
,
731 server_
->DoAction(flight_context
, action
, &results
));
734 RETURN_WITH_MIDDLEWARE(flight_context
, grpc::Status::CANCELLED
);
738 std::unique_ptr
<Result
> result
;
739 SERVICE_RETURN_NOT_OK(flight_context
, results
->Next(&result
));
744 pb::Result pb_result
;
745 SERVICE_RETURN_NOT_OK(flight_context
, internal::ToProto(*result
, &pb_result
));
746 if (!writer
->Write(pb_result
)) {
747 // Stream may be closed
751 RETURN_WITH_MIDDLEWARE(flight_context
, grpc::Status::OK
);
755 std::shared_ptr
<ServerAuthHandler
> auth_handler_
;
756 std::vector
<std::pair
<std::string
, std::shared_ptr
<ServerMiddlewareFactory
>>>
758 FlightServerBase
* server_
;
763 FlightMetadataWriter::~FlightMetadataWriter() = default;
766 // gRPC server lifecycle
769 #if (ATOMIC_INT_LOCK_FREE != 2 || ATOMIC_POINTER_LOCK_FREE != 2)
770 #error "atomic ints and atomic pointers not always lock-free!"
773 using ::arrow::internal::SetSignalHandler
;
774 using ::arrow::internal::SignalHandler
;
777 #define PIPE_WRITE _write
778 #define PIPE_READ _read
780 #define PIPE_WRITE write
781 #define PIPE_READ read
784 /// RAII guard that manages a self-pipe and a thread that listens on
785 /// the self-pipe, shutting down the gRPC server when a signal handler
786 /// writes to the pipe.
787 class ServerSignalHandler
{
789 ARROW_DISALLOW_COPY_AND_ASSIGN(ServerSignalHandler
);
790 ServerSignalHandler() = default;
792 /// Create the pipe and handler thread.
794 /// \return the fd of the write side of the pipe.
795 template <typename Fn
>
796 arrow::Result
<int> Init(Fn handler
) {
797 ARROW_ASSIGN_OR_RAISE(auto pipe
, arrow::internal::CreatePipe());
799 // Make write end nonblocking
800 int flags
= fcntl(pipe
.wfd
, F_GETFL
);
802 RETURN_NOT_OK(arrow::internal::FileClose(pipe
.rfd
));
803 RETURN_NOT_OK(arrow::internal::FileClose(pipe
.wfd
));
804 return arrow::internal::IOErrorFromErrno(
805 errno
, "Could not initialize self-pipe to wait for signals");
808 if (fcntl(pipe
.wfd
, F_SETFL
, flags
) == -1) {
809 RETURN_NOT_OK(arrow::internal::FileClose(pipe
.rfd
));
810 RETURN_NOT_OK(arrow::internal::FileClose(pipe
.wfd
));
811 return arrow::internal::IOErrorFromErrno(
812 errno
, "Could not initialize self-pipe to wait for signals");
816 handle_signals_
= std::thread(handler
, self_pipe_
.rfd
);
817 return self_pipe_
.wfd
;
821 if (self_pipe_
.rfd
== 0) {
825 if (PIPE_WRITE(self_pipe_
.wfd
, "0", 1) < 0 && errno
!= EAGAIN
&&
826 errno
!= EWOULDBLOCK
&& errno
!= EINTR
) {
827 return arrow::internal::IOErrorFromErrno(errno
, "Could not unblock signal thread");
829 RETURN_NOT_OK(arrow::internal::FileClose(self_pipe_
.rfd
));
830 RETURN_NOT_OK(arrow::internal::FileClose(self_pipe_
.wfd
));
831 handle_signals_
.join();
837 ~ServerSignalHandler() { ARROW_CHECK_OK(Shutdown()); }
840 arrow::internal::Pipe self_pipe_
;
841 std::thread handle_signals_
;
844 struct FlightServerBase::Impl
{
845 std::unique_ptr
<FlightServiceImpl
> service_
;
846 std::unique_ptr
<grpc::Server
> server_
;
849 // Signal handlers (on Windows) and the shutdown handler (other platforms)
850 // are executed in a separate thread, so getting the current thread instance
851 // wouldn't make sense. This means only a single instance can receive signals.
852 static std::atomic
<Impl
*> running_instance_
;
853 // We'll use the self-pipe trick to notify a thread from the signal
854 // handler. The thread will then shut down the gRPC server.
858 std::vector
<int> signals_
;
859 std::vector
<SignalHandler
> old_signal_handlers_
;
860 std::atomic
<int> got_signal_
;
862 static void HandleSignal(int signum
) {
863 auto instance
= running_instance_
.load();
864 if (instance
!= nullptr) {
865 instance
->DoHandleSignal(signum
);
869 void DoHandleSignal(int signum
) {
870 got_signal_
= signum
;
871 int saved_errno
= errno
;
872 // Ignore errors - pipe is nonblocking
873 PIPE_WRITE(self_pipe_wfd_
, "0", 1);
877 static void WaitForSignals(int fd
) {
878 // Wait for a signal handler to write to the pipe
880 while (PIPE_READ(fd
, /*buf=*/buf
, /*count=*/1) == -1) {
881 if (errno
== EINTR
) {
884 ARROW_CHECK_OK(arrow::internal::IOErrorFromErrno(
885 errno
, "Error while waiting for shutdown signal"));
887 auto instance
= running_instance_
.load();
888 if (instance
!= nullptr) {
889 instance
->server_
->Shutdown();
894 std::atomic
<FlightServerBase::Impl
*> FlightServerBase::Impl::running_instance_
;
896 FlightServerOptions::FlightServerOptions(const Location
& location_
)
897 : location(location_
),
898 auth_handler(nullptr),
900 verify_client(false),
903 builder_hook(nullptr) {}
905 FlightServerOptions::~FlightServerOptions() = default;
907 FlightServerBase::FlightServerBase() { impl_
.reset(new Impl
); }
909 FlightServerBase::~FlightServerBase() {}
911 Status
FlightServerBase::Init(const FlightServerOptions
& options
) {
912 impl_
->service_
.reset(
913 new FlightServiceImpl(options
.auth_handler
, options
.middleware
, this));
915 grpc::ServerBuilder builder
;
916 // Allow uploading messages of any length
917 builder
.SetMaxReceiveMessageSize(-1);
919 const Location
& location
= options
.location
;
920 const std::string scheme
= location
.scheme();
921 if (scheme
== kSchemeGrpc
|| scheme
== kSchemeGrpcTcp
|| scheme
== kSchemeGrpcTls
) {
922 std::stringstream address
;
923 address
<< arrow::internal::UriEncodeHost(location
.uri_
->host()) << ':'
924 << location
.uri_
->port_text();
926 std::shared_ptr
<grpc::ServerCredentials
> creds
;
927 if (scheme
== kSchemeGrpcTls
) {
928 grpc::SslServerCredentialsOptions ssl_options
;
929 for (const auto& pair
: options
.tls_certificates
) {
930 ssl_options
.pem_key_cert_pairs
.push_back({pair
.pem_key
, pair
.pem_cert
});
932 if (options
.verify_client
) {
933 ssl_options
.client_certificate_request
=
934 GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY
;
936 if (!options
.root_certificates
.empty()) {
937 ssl_options
.pem_root_certs
= options
.root_certificates
;
939 creds
= grpc::SslServerCredentials(ssl_options
);
941 creds
= grpc::InsecureServerCredentials();
944 builder
.AddListeningPort(address
.str(), creds
, &impl_
->port_
);
945 } else if (scheme
== kSchemeGrpcUnix
) {
946 std::stringstream address
;
947 address
<< "unix:" << location
.uri_
->path();
948 builder
.AddListeningPort(address
.str(), grpc::InsecureServerCredentials());
950 return Status::NotImplemented("Scheme is not supported: " + scheme
);
953 builder
.RegisterService(impl_
->service_
.get());
955 // Disable SO_REUSEPORT - it makes debugging/testing a pain as
956 // leftover processes can handle requests on accident
957 builder
.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT
, 0);
959 if (options
.builder_hook
) {
960 options
.builder_hook(&builder
);
963 impl_
->server_
= builder
.BuildAndStart();
964 if (!impl_
->server_
) {
965 return Status::UnknownError("Server did not start properly");
970 int FlightServerBase::port() const { return impl_
->port_
; }
972 Status
FlightServerBase::SetShutdownOnSignals(const std::vector
<int> sigs
) {
973 impl_
->signals_
= sigs
;
974 impl_
->old_signal_handlers_
.clear();
978 Status
FlightServerBase::Serve() {
979 if (!impl_
->server_
) {
980 return Status::UnknownError("Server did not start properly");
982 impl_
->got_signal_
= 0;
983 impl_
->old_signal_handlers_
.clear();
984 impl_
->running_instance_
= impl_
.get();
986 ServerSignalHandler signal_handler
;
987 ARROW_ASSIGN_OR_RAISE(impl_
->self_pipe_wfd_
,
988 signal_handler
.Init(&Impl::WaitForSignals
));
989 // Override existing signal handlers with our own handler so as to stop the server.
990 for (size_t i
= 0; i
< impl_
->signals_
.size(); ++i
) {
991 int signum
= impl_
->signals_
[i
];
992 SignalHandler
new_handler(&Impl::HandleSignal
), old_handler
;
993 ARROW_ASSIGN_OR_RAISE(old_handler
, SetSignalHandler(signum
, new_handler
));
994 impl_
->old_signal_handlers_
.push_back(std::move(old_handler
));
997 impl_
->server_
->Wait();
998 impl_
->running_instance_
= nullptr;
1000 // Restore signal handlers
1001 for (size_t i
= 0; i
< impl_
->signals_
.size(); ++i
) {
1003 SetSignalHandler(impl_
->signals_
[i
], impl_
->old_signal_handlers_
[i
]).status());
1005 return Status::OK();
1008 int FlightServerBase::GotSignal() const { return impl_
->got_signal_
; }
1010 Status
FlightServerBase::Shutdown() {
1011 auto server
= impl_
->server_
.get();
1013 return Status::Invalid("Shutdown() on uninitialized FlightServerBase");
1015 impl_
->server_
->Shutdown();
1016 return Status::OK();
1019 Status
FlightServerBase::Wait() {
1020 impl_
->server_
->Wait();
1021 impl_
->running_instance_
= nullptr;
1022 return Status::OK();
1025 Status
FlightServerBase::ListFlights(const ServerCallContext
& context
,
1026 const Criteria
* criteria
,
1027 std::unique_ptr
<FlightListing
>* listings
) {
1028 return Status::NotImplemented("NYI");
1031 Status
FlightServerBase::GetFlightInfo(const ServerCallContext
& context
,
1032 const FlightDescriptor
& request
,
1033 std::unique_ptr
<FlightInfo
>* info
) {
1034 return Status::NotImplemented("NYI");
1037 Status
FlightServerBase::DoGet(const ServerCallContext
& context
, const Ticket
& request
,
1038 std::unique_ptr
<FlightDataStream
>* data_stream
) {
1039 return Status::NotImplemented("NYI");
1042 Status
FlightServerBase::DoPut(const ServerCallContext
& context
,
1043 std::unique_ptr
<FlightMessageReader
> reader
,
1044 std::unique_ptr
<FlightMetadataWriter
> writer
) {
1045 return Status::NotImplemented("NYI");
1048 Status
FlightServerBase::DoExchange(const ServerCallContext
& context
,
1049 std::unique_ptr
<FlightMessageReader
> reader
,
1050 std::unique_ptr
<FlightMessageWriter
> writer
) {
1051 return Status::NotImplemented("NYI");
1054 Status
FlightServerBase::DoAction(const ServerCallContext
& context
, const Action
& action
,
1055 std::unique_ptr
<ResultStream
>* result
) {
1056 return Status::NotImplemented("NYI");
1059 Status
FlightServerBase::ListActions(const ServerCallContext
& context
,
1060 std::vector
<ActionType
>* actions
) {
1061 return Status::NotImplemented("NYI");
1064 Status
FlightServerBase::GetSchema(const ServerCallContext
& context
,
1065 const FlightDescriptor
& request
,
1066 std::unique_ptr
<SchemaResult
>* schema
) {
1067 return Status::NotImplemented("NYI");
1070 // ----------------------------------------------------------------------
1071 // Implement RecordBatchStream
1073 class RecordBatchStream::RecordBatchStreamImpl
{
1075 // Stages of the stream when producing payloads
1077 NEW
, // The stream has been created, but Next has not been called yet
1078 DICTIONARY
, // Dictionaries have been collected, and are being sent
1079 RECORD_BATCH
// Initial have been sent
1082 RecordBatchStreamImpl(const std::shared_ptr
<RecordBatchReader
>& reader
,
1083 const ipc::IpcWriteOptions
& options
)
1084 : reader_(reader
), mapper_(*reader_
->schema()), ipc_options_(options
) {}
1086 std::shared_ptr
<Schema
> schema() { return reader_
->schema(); }
1088 Status
GetSchemaPayload(FlightPayload
* payload
) {
1089 return ipc::GetSchemaPayload(*reader_
->schema(), ipc_options_
, mapper_
,
1090 &payload
->ipc_message
);
1093 Status
Next(FlightPayload
* payload
) {
1094 if (stage_
== Stage::NEW
) {
1095 RETURN_NOT_OK(reader_
->ReadNext(¤t_batch_
));
1096 if (!current_batch_
) {
1097 // Signal that iteration is over
1098 payload
->ipc_message
.metadata
= nullptr;
1099 return Status::OK();
1101 ARROW_ASSIGN_OR_RAISE(dictionaries_
,
1102 ipc::CollectDictionaries(*current_batch_
, mapper_
));
1103 stage_
= Stage::DICTIONARY
;
1106 if (stage_
== Stage::DICTIONARY
) {
1107 if (dictionary_index_
== static_cast<int>(dictionaries_
.size())) {
1108 stage_
= Stage::RECORD_BATCH
;
1109 return ipc::GetRecordBatchPayload(*current_batch_
, ipc_options_
,
1110 &payload
->ipc_message
);
1112 return GetNextDictionary(payload
);
1116 RETURN_NOT_OK(reader_
->ReadNext(¤t_batch_
));
1118 // TODO(wesm): Delta dictionaries
1119 if (!current_batch_
) {
1120 // Signal that iteration is over
1121 payload
->ipc_message
.metadata
= nullptr;
1122 return Status::OK();
1124 return ipc::GetRecordBatchPayload(*current_batch_
, ipc_options_
,
1125 &payload
->ipc_message
);
1130 Status
GetNextDictionary(FlightPayload
* payload
) {
1131 const auto& it
= dictionaries_
[dictionary_index_
++];
1132 return ipc::GetDictionaryPayload(it
.first
, it
.second
, ipc_options_
,
1133 &payload
->ipc_message
);
1136 Stage stage_
= Stage::NEW
;
1137 std::shared_ptr
<RecordBatchReader
> reader_
;
1138 ipc::DictionaryFieldMapper mapper_
;
1139 ipc::IpcWriteOptions ipc_options_
;
1140 std::shared_ptr
<RecordBatch
> current_batch_
;
1141 std::vector
<std::pair
<int64_t, std::shared_ptr
<Array
>>> dictionaries_
;
1143 // Index of next dictionary to send
1144 int dictionary_index_
= 0;
1147 FlightDataStream::~FlightDataStream() {}
1149 RecordBatchStream::RecordBatchStream(const std::shared_ptr
<RecordBatchReader
>& reader
,
1150 const ipc::IpcWriteOptions
& options
) {
1151 impl_
.reset(new RecordBatchStreamImpl(reader
, options
));
1154 RecordBatchStream::~RecordBatchStream() {}
1156 std::shared_ptr
<Schema
> RecordBatchStream::schema() { return impl_
->schema(); }
1158 Status
RecordBatchStream::GetSchemaPayload(FlightPayload
* payload
) {
1159 return impl_
->GetSchemaPayload(payload
);
1162 Status
RecordBatchStream::Next(FlightPayload
* payload
) { return impl_
->Next(payload
); }
1164 } // namespace flight
1165 } // namespace arrow