]> git.proxmox.com Git - ceph.git/blob - ceph/src/arrow/cpp/src/arrow/flight/server.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / flight / server.cc
1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 // Platform-specific defines
19 #include "arrow/flight/platform.h"
20
21 #include "arrow/flight/server.h"
22
23 #ifdef _WIN32
24 #include <io.h>
25 #else
26 #include <fcntl.h>
27 #include <unistd.h>
28 #endif
29 #include <atomic>
30 #include <cerrno>
31 #include <cstdint>
32 #include <memory>
33 #include <sstream>
34 #include <string>
35 #include <thread>
36 #include <unordered_map>
37 #include <utility>
38
39 #ifdef GRPCPP_PP_INCLUDE
40 #include <grpcpp/grpcpp.h>
41 #else
42 #include <grpc++/grpc++.h>
43 #endif
44
45 #include "arrow/buffer.h"
46 #include "arrow/ipc/dictionary.h"
47 #include "arrow/ipc/options.h"
48 #include "arrow/ipc/reader.h"
49 #include "arrow/ipc/writer.h"
50 #include "arrow/memory_pool.h"
51 #include "arrow/record_batch.h"
52 #include "arrow/status.h"
53 #include "arrow/util/io_util.h"
54 #include "arrow/util/logging.h"
55 #include "arrow/util/uri.h"
56
57 #include "arrow/flight/internal.h"
58 #include "arrow/flight/middleware.h"
59 #include "arrow/flight/middleware_internal.h"
60 #include "arrow/flight/serialization_internal.h"
61 #include "arrow/flight/server_auth.h"
62 #include "arrow/flight/server_middleware.h"
63 #include "arrow/flight/types.h"
64
65 using FlightService = arrow::flight::protocol::FlightService;
66 using ServerContext = grpc::ServerContext;
67
68 template <typename T>
69 using ServerWriter = grpc::ServerWriter<T>;
70
71 namespace arrow {
72 namespace flight {
73
74 namespace pb = arrow::flight::protocol;
75
// Macro that runs the call's middleware (via CONTEXT.FinishRequest) before
// returning the given status from the enclosing gRPC handler. STATUS is
// evaluated exactly once. The temporary deliberately avoids names containing a
// double underscore, which C++ reserves for the implementation.
#define RETURN_WITH_MIDDLEWARE(CONTEXT, STATUS)            \
  do {                                                     \
    const auto& flight_rwm_status = (STATUS);              \
    return (CONTEXT).FinishRequest(flight_rwm_status);     \
  } while (false)
82
// If VAL is null, finish the request through the middleware with an
// INVALID_ARGUMENT gRPC status carrying MESSAGE and return from the handler.
// Wrapped in do/while(false) so the macro is a single statement (safe inside an
// unbraced if/else and requiring the trailing semicolon); VAL is parenthesized
// to avoid operator-precedence surprises.
#define CHECK_ARG_NOT_NULL(CONTEXT, VAL, MESSAGE)                                \
  do {                                                                           \
    if ((VAL) == nullptr) {                                                      \
      RETURN_WITH_MIDDLEWARE(                                                    \
          CONTEXT, grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE));   \
    }                                                                            \
  } while (false)
88
// Early-return helper for gRPC handlers, analogous to RETURN_NOT_OK. `expr` may
// yield either an Arrow or a gRPC status (anything with .ok()); it is evaluated
// once, and on failure the status is routed through CONTEXT.FinishRequest so
// the middleware observes the outcome before the handler returns.
#define SERVICE_RETURN_NOT_OK(CONTEXT, expr)                  \
  do {                                                        \
    const auto& svc_ret_status = (expr);                      \
    if (ARROW_PREDICT_FALSE(!svc_ret_status.ok())) {          \
      return (CONTEXT).FinishRequest(svc_ret_status);         \
    }                                                         \
  } while (false)
98
99 namespace {
100
101 // A MessageReader implementation that reads from a gRPC ServerReader.
102 // Templated to be generic over DoPut/DoExchange.
103 template <typename Reader>
104 class FlightIpcMessageReader : public ipc::MessageReader {
105 public:
106 explicit FlightIpcMessageReader(
107 std::shared_ptr<internal::PeekableFlightDataReader<Reader*>> peekable_reader,
108 std::shared_ptr<Buffer>* app_metadata)
109 : peekable_reader_(peekable_reader), app_metadata_(app_metadata) {}
110
111 ::arrow::Result<std::unique_ptr<ipc::Message>> ReadNextMessage() override {
112 if (stream_finished_) {
113 return nullptr;
114 }
115 internal::FlightData* data;
116 peekable_reader_->Next(&data);
117 if (!data) {
118 stream_finished_ = true;
119 if (first_message_) {
120 return Status::Invalid(
121 "Client provided malformed message or did not provide message");
122 }
123 return nullptr;
124 }
125 *app_metadata_ = std::move(data->app_metadata);
126 return data->OpenMessage();
127 }
128
129 protected:
130 std::shared_ptr<internal::PeekableFlightDataReader<Reader*>> peekable_reader_;
131 // A reference to FlightMessageReaderImpl.app_metadata_. That class
132 // can't access the app metadata because when it Peek()s the stream,
133 // it may be looking at a dictionary batch, not the record
134 // batch. Updating it here ensures the reader is always updated with
135 // the last metadata message read.
136 std::shared_ptr<Buffer>* app_metadata_;
137 bool first_message_ = true;
138 bool stream_finished_ = false;
139 };
140
141 template <typename WritePayload>
142 class FlightMessageReaderImpl : public FlightMessageReader {
143 public:
144 using GrpcStream = grpc::ServerReaderWriter<WritePayload, pb::FlightData>;
145
146 explicit FlightMessageReaderImpl(GrpcStream* reader)
147 : reader_(reader),
148 peekable_reader_(new internal::PeekableFlightDataReader<GrpcStream*>(reader)) {}
149
150 Status Init() {
151 // Peek the first message to get the descriptor.
152 internal::FlightData* data;
153 peekable_reader_->Peek(&data);
154 if (!data) {
155 return Status::IOError("Stream finished before first message sent");
156 }
157 if (!data->descriptor) {
158 return Status::IOError("Descriptor missing on first message");
159 }
160 descriptor_ = *data->descriptor.get(); // Copy
161 // If there's a schema (=DoPut), also Open().
162 if (data->metadata) {
163 return EnsureDataStarted();
164 }
165 peekable_reader_->Next(&data);
166 return Status::OK();
167 }
168
169 const FlightDescriptor& descriptor() const override { return descriptor_; }
170
171 arrow::Result<std::shared_ptr<Schema>> GetSchema() override {
172 RETURN_NOT_OK(EnsureDataStarted());
173 return batch_reader_->schema();
174 }
175
176 Status Next(FlightStreamChunk* out) override {
177 internal::FlightData* data;
178 peekable_reader_->Peek(&data);
179 if (!data) {
180 out->app_metadata = nullptr;
181 out->data = nullptr;
182 return Status::OK();
183 }
184
185 if (!data->metadata) {
186 // Metadata-only (data->metadata is the IPC header)
187 out->app_metadata = data->app_metadata;
188 out->data = nullptr;
189 peekable_reader_->Next(&data);
190 return Status::OK();
191 }
192
193 if (!batch_reader_) {
194 RETURN_NOT_OK(EnsureDataStarted());
195 // re-peek here since EnsureDataStarted() advances the stream
196 return Next(out);
197 }
198 RETURN_NOT_OK(batch_reader_->ReadNext(&out->data));
199 out->app_metadata = std::move(app_metadata_);
200 return Status::OK();
201 }
202
203 private:
204 /// Ensure we are set up to read data.
205 Status EnsureDataStarted() {
206 if (!batch_reader_) {
207 // peek() until we find the first data message; discard metadata
208 if (!peekable_reader_->SkipToData()) {
209 return Status::IOError("Client never sent a data message");
210 }
211 auto message_reader = std::unique_ptr<ipc::MessageReader>(
212 new FlightIpcMessageReader<GrpcStream>(peekable_reader_, &app_metadata_));
213 ARROW_ASSIGN_OR_RAISE(
214 batch_reader_, ipc::RecordBatchStreamReader::Open(std::move(message_reader)));
215 }
216 return Status::OK();
217 }
218
219 FlightDescriptor descriptor_;
220 GrpcStream* reader_;
221 std::shared_ptr<internal::PeekableFlightDataReader<GrpcStream*>> peekable_reader_;
222 std::shared_ptr<RecordBatchReader> batch_reader_;
223 std::shared_ptr<Buffer> app_metadata_;
224 };
225
226 class GrpcMetadataWriter : public FlightMetadataWriter {
227 public:
228 explicit GrpcMetadataWriter(
229 grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* writer)
230 : writer_(writer) {}
231
232 Status WriteMetadata(const Buffer& buffer) override {
233 pb::PutResult message{};
234 message.set_app_metadata(buffer.data(), buffer.size());
235 if (writer_->Write(message)) {
236 return Status::OK();
237 }
238 return Status::IOError("Unknown error writing metadata.");
239 }
240
241 private:
242 grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* writer_;
243 };
244
245 class GrpcServerAuthReader : public ServerAuthReader {
246 public:
247 explicit GrpcServerAuthReader(
248 grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream)
249 : stream_(stream) {}
250
251 Status Read(std::string* token) override {
252 pb::HandshakeRequest request;
253 if (stream_->Read(&request)) {
254 *token = std::move(*request.mutable_payload());
255 return Status::OK();
256 }
257 return Status::IOError("Stream is closed.");
258 }
259
260 private:
261 grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream_;
262 };
263
264 class GrpcServerAuthSender : public ServerAuthSender {
265 public:
266 explicit GrpcServerAuthSender(
267 grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream)
268 : stream_(stream) {}
269
270 Status Write(const std::string& token) override {
271 pb::HandshakeResponse response;
272 response.set_payload(token);
273 if (stream_->Write(response)) {
274 return Status::OK();
275 }
276 return Status::IOError("Stream was closed.");
277 }
278
279 private:
280 grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream_;
281 };
282
283 /// The implementation of the write side of a bidirectional FlightData
284 /// stream for DoExchange.
285 class DoExchangeMessageWriter : public FlightMessageWriter {
286 public:
287 explicit DoExchangeMessageWriter(
288 grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream)
289 : stream_(stream), ipc_options_(::arrow::ipc::IpcWriteOptions::Defaults()) {}
290
291 Status Begin(const std::shared_ptr<Schema>& schema,
292 const ipc::IpcWriteOptions& options) override {
293 if (started_) {
294 return Status::Invalid("This writer has already been started.");
295 }
296 started_ = true;
297 ipc_options_ = options;
298
299 RETURN_NOT_OK(mapper_.AddSchemaFields(*schema));
300 FlightPayload schema_payload;
301 RETURN_NOT_OK(ipc::GetSchemaPayload(*schema, ipc_options_, mapper_,
302 &schema_payload.ipc_message));
303 return WritePayload(schema_payload);
304 }
305
306 Status WriteRecordBatch(const RecordBatch& batch) override {
307 return WriteWithMetadata(batch, nullptr);
308 }
309
310 Status WriteMetadata(std::shared_ptr<Buffer> app_metadata) override {
311 FlightPayload payload{};
312 payload.app_metadata = app_metadata;
313 return WritePayload(payload);
314 }
315
316 Status WriteWithMetadata(const RecordBatch& batch,
317 std::shared_ptr<Buffer> app_metadata) override {
318 RETURN_NOT_OK(CheckStarted());
319 RETURN_NOT_OK(EnsureDictionariesWritten(batch));
320 FlightPayload payload{};
321 if (app_metadata) {
322 payload.app_metadata = app_metadata;
323 }
324 RETURN_NOT_OK(ipc::GetRecordBatchPayload(batch, ipc_options_, &payload.ipc_message));
325 RETURN_NOT_OK(WritePayload(payload));
326 ++stats_.num_record_batches;
327 return Status::OK();
328 }
329
330 Status Close() override {
331 // It's fine to Close() without writing data
332 return Status::OK();
333 }
334
335 ipc::WriteStats stats() const override { return stats_; }
336
337 private:
338 Status WritePayload(const FlightPayload& payload) {
339 RETURN_NOT_OK(internal::WritePayload(payload, stream_));
340 ++stats_.num_messages;
341 return Status::OK();
342 }
343
344 Status CheckStarted() {
345 if (!started_) {
346 return Status::Invalid("This writer is not started. Call Begin() with a schema");
347 }
348 return Status::OK();
349 }
350
351 Status EnsureDictionariesWritten(const RecordBatch& batch) {
352 if (dictionaries_written_) {
353 return Status::OK();
354 }
355 dictionaries_written_ = true;
356 ARROW_ASSIGN_OR_RAISE(const auto dictionaries,
357 ipc::CollectDictionaries(batch, mapper_));
358 for (const auto& pair : dictionaries) {
359 FlightPayload payload{};
360 RETURN_NOT_OK(ipc::GetDictionaryPayload(pair.first, pair.second, ipc_options_,
361 &payload.ipc_message));
362 RETURN_NOT_OK(WritePayload(payload));
363 ++stats_.num_dictionary_batches;
364 }
365 return Status::OK();
366 }
367
368 grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream_;
369 ::arrow::ipc::IpcWriteOptions ipc_options_;
370 ipc::DictionaryFieldMapper mapper_;
371 ipc::WriteStats stats_;
372 bool started_ = false;
373 bool dictionaries_written_ = false;
374 };
375
376 class FlightServiceImpl;
377 class GrpcServerCallContext : public ServerCallContext {
378 explicit GrpcServerCallContext(grpc::ServerContext* context)
379 : context_(context), peer_(context_->peer()) {}
380
381 const std::string& peer_identity() const override { return peer_identity_; }
382 const std::string& peer() const override { return peer_; }
383 bool is_cancelled() const override { return context_->IsCancelled(); }
384
385 // Helper method that runs interceptors given the result of an RPC,
386 // then returns the final gRPC status to send to the client
387 grpc::Status FinishRequest(const grpc::Status& status) {
388 // Don't double-convert status - return the original one here
389 FinishRequest(internal::FromGrpcStatus(status));
390 return status;
391 }
392
393 grpc::Status FinishRequest(const arrow::Status& status) {
394 for (const auto& instance : middleware_) {
395 instance->CallCompleted(status);
396 }
397
398 // Set custom headers to map the exact Arrow status for clients
399 // who want it.
400 return internal::ToGrpcStatus(status, context_);
401 }
402
403 ServerMiddleware* GetMiddleware(const std::string& key) const override {
404 const auto& instance = middleware_map_.find(key);
405 if (instance == middleware_map_.end()) {
406 return nullptr;
407 }
408 return instance->second.get();
409 }
410
411 private:
412 friend class FlightServiceImpl;
413 ServerContext* context_;
414 std::string peer_;
415 std::string peer_identity_;
416 std::vector<std::shared_ptr<ServerMiddleware>> middleware_;
417 std::unordered_map<std::string, std::shared_ptr<ServerMiddleware>> middleware_map_;
418 };
419
420 class GrpcAddCallHeaders : public AddCallHeaders {
421 public:
422 explicit GrpcAddCallHeaders(grpc::ServerContext* context) : context_(context) {}
423 ~GrpcAddCallHeaders() override = default;
424
425 void AddHeader(const std::string& key, const std::string& value) override {
426 context_->AddInitialMetadata(key, value);
427 }
428
429 private:
430 grpc::ServerContext* context_;
431 };
432
433 // This class glues an implementation of FlightServerBase together with the
434 // gRPC service definition, so the latter is not exposed in the public API
435 class FlightServiceImpl : public FlightService::Service {
436 public:
437 explicit FlightServiceImpl(
438 std::shared_ptr<ServerAuthHandler> auth_handler,
439 std::vector<std::pair<std::string, std::shared_ptr<ServerMiddlewareFactory>>>
440 middleware,
441 FlightServerBase* server)
442 : auth_handler_(auth_handler), middleware_(middleware), server_(server) {}
443
444 template <typename UserType, typename Iterator, typename ProtoType>
445 grpc::Status WriteStream(Iterator* iterator, ServerWriter<ProtoType>* writer) {
446 if (!iterator) {
447 return grpc::Status(grpc::StatusCode::INTERNAL, "No items to iterate");
448 }
449 // Write flight info to stream until listing is exhausted
450 while (true) {
451 ProtoType pb_value;
452 std::unique_ptr<UserType> value;
453 GRPC_RETURN_NOT_OK(iterator->Next(&value));
454 if (!value) {
455 break;
456 }
457 GRPC_RETURN_NOT_OK(internal::ToProto(*value, &pb_value));
458
459 // Blocking write
460 if (!writer->Write(pb_value)) {
461 // Write returns false if the stream is closed
462 break;
463 }
464 }
465 return grpc::Status::OK;
466 }
467
468 template <typename UserType, typename ProtoType>
469 grpc::Status WriteStream(const std::vector<UserType>& values,
470 ServerWriter<ProtoType>* writer) {
471 // Write flight info to stream until listing is exhausted
472 for (const UserType& value : values) {
473 ProtoType pb_value;
474 GRPC_RETURN_NOT_OK(internal::ToProto(value, &pb_value));
475 // Blocking write
476 if (!writer->Write(pb_value)) {
477 // Write returns false if the stream is closed
478 break;
479 }
480 }
481 return grpc::Status::OK;
482 }
483
484 // Authenticate the client (if applicable) and construct the call context
485 grpc::Status CheckAuth(const FlightMethod& method, ServerContext* context,
486 GrpcServerCallContext& flight_context) {
487 if (!auth_handler_) {
488 const auto auth_context = context->auth_context();
489 if (auth_context && auth_context->IsPeerAuthenticated()) {
490 auto peer_identity = auth_context->GetPeerIdentity();
491 flight_context.peer_identity_ =
492 peer_identity.empty()
493 ? ""
494 : std::string(peer_identity.front().begin(), peer_identity.front().end());
495 } else {
496 flight_context.peer_identity_ = "";
497 }
498 } else {
499 const auto client_metadata = context->client_metadata();
500 const auto auth_header = client_metadata.find(internal::kGrpcAuthHeader);
501 std::string token;
502 if (auth_header == client_metadata.end()) {
503 token = "";
504 } else {
505 token = std::string(auth_header->second.data(), auth_header->second.length());
506 }
507 GRPC_RETURN_NOT_OK(auth_handler_->IsValid(token, &flight_context.peer_identity_));
508 }
509
510 return MakeCallContext(method, context, flight_context);
511 }
512
513 // Authenticate the client (if applicable) and construct the call context
514 grpc::Status MakeCallContext(const FlightMethod& method, ServerContext* context,
515 GrpcServerCallContext& flight_context) {
516 // Run server middleware
517 const CallInfo info{method};
518 CallHeaders incoming_headers;
519 for (const auto& entry : context->client_metadata()) {
520 incoming_headers.insert(
521 {util::string_view(entry.first.data(), entry.first.length()),
522 util::string_view(entry.second.data(), entry.second.length())});
523 }
524
525 GrpcAddCallHeaders outgoing_headers(context);
526 for (const auto& factory : middleware_) {
527 std::shared_ptr<ServerMiddleware> instance;
528 Status result = factory.second->StartCall(info, incoming_headers, &instance);
529 if (!result.ok()) {
530 // Interceptor rejected call, end the request on all existing
531 // interceptors
532 return flight_context.FinishRequest(result);
533 }
534 if (instance != nullptr) {
535 flight_context.middleware_.push_back(instance);
536 flight_context.middleware_map_.insert({factory.first, instance});
537 instance->SendingHeaders(&outgoing_headers);
538 }
539 }
540
541 return grpc::Status::OK;
542 }
543
544 grpc::Status Handshake(
545 ServerContext* context,
546 grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream) {
547 GrpcServerCallContext flight_context(context);
548 GRPC_RETURN_NOT_GRPC_OK(
549 MakeCallContext(FlightMethod::Handshake, context, flight_context));
550
551 if (!auth_handler_) {
552 RETURN_WITH_MIDDLEWARE(
553 flight_context,
554 grpc::Status(
555 grpc::StatusCode::UNIMPLEMENTED,
556 "This service does not have an authentication mechanism enabled."));
557 }
558 GrpcServerAuthSender outgoing{stream};
559 GrpcServerAuthReader incoming{stream};
560 RETURN_WITH_MIDDLEWARE(flight_context,
561 auth_handler_->Authenticate(&outgoing, &incoming));
562 }
563
564 grpc::Status ListFlights(ServerContext* context, const pb::Criteria* request,
565 ServerWriter<pb::FlightInfo>* writer) {
566 GrpcServerCallContext flight_context(context);
567 GRPC_RETURN_NOT_GRPC_OK(
568 CheckAuth(FlightMethod::ListFlights, context, flight_context));
569
570 // Retrieve the listing from the implementation
571 std::unique_ptr<FlightListing> listing;
572
573 Criteria criteria;
574 if (request) {
575 SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &criteria));
576 }
577 SERVICE_RETURN_NOT_OK(flight_context,
578 server_->ListFlights(flight_context, &criteria, &listing));
579 if (!listing) {
580 // Treat null listing as no flights available
581 RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
582 }
583 RETURN_WITH_MIDDLEWARE(flight_context,
584 WriteStream<FlightInfo>(listing.get(), writer));
585 }
586
587 grpc::Status GetFlightInfo(ServerContext* context, const pb::FlightDescriptor* request,
588 pb::FlightInfo* response) {
589 GrpcServerCallContext flight_context(context);
590 GRPC_RETURN_NOT_GRPC_OK(
591 CheckAuth(FlightMethod::GetFlightInfo, context, flight_context));
592
593 CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null");
594
595 FlightDescriptor descr;
596 SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &descr));
597
598 std::unique_ptr<FlightInfo> info;
599 SERVICE_RETURN_NOT_OK(flight_context,
600 server_->GetFlightInfo(flight_context, descr, &info));
601
602 if (!info) {
603 // Treat null listing as no flights available
604 RETURN_WITH_MIDDLEWARE(
605 flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND, "Flight not found"));
606 }
607
608 SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*info, response));
609 RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
610 }
611
612 grpc::Status GetSchema(ServerContext* context, const pb::FlightDescriptor* request,
613 pb::SchemaResult* response) {
614 GrpcServerCallContext flight_context(context);
615 GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::GetSchema, context, flight_context));
616
617 CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null");
618
619 FlightDescriptor descr;
620 SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &descr));
621
622 std::unique_ptr<SchemaResult> result;
623 SERVICE_RETURN_NOT_OK(flight_context,
624 server_->GetSchema(flight_context, descr, &result));
625
626 if (!result) {
627 // Treat null listing as no flights available
628 RETURN_WITH_MIDDLEWARE(
629 flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND, "Flight not found"));
630 }
631
632 SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*result, response));
633 RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
634 }
635
636 grpc::Status DoGet(ServerContext* context, const pb::Ticket* request,
637 ServerWriter<pb::FlightData>* writer) {
638 GrpcServerCallContext flight_context(context);
639 GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, flight_context));
640
641 CHECK_ARG_NOT_NULL(flight_context, request, "ticket cannot be null");
642
643 Ticket ticket;
644 SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &ticket));
645
646 std::unique_ptr<FlightDataStream> data_stream;
647 SERVICE_RETURN_NOT_OK(flight_context,
648 server_->DoGet(flight_context, ticket, &data_stream));
649
650 if (!data_stream) {
651 RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND,
652 "No data in this flight"));
653 }
654
655 // Write the schema as the first message in the stream
656 FlightPayload schema_payload;
657 SERVICE_RETURN_NOT_OK(flight_context, data_stream->GetSchemaPayload(&schema_payload));
658 auto status = internal::WritePayload(schema_payload, writer);
659 if (status.IsIOError()) {
660 // gRPC doesn't give any way for us to know why the message
661 // could not be written.
662 RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
663 }
664 SERVICE_RETURN_NOT_OK(flight_context, status);
665
666 // Consume data stream and write out payloads
667 while (true) {
668 FlightPayload payload;
669 SERVICE_RETURN_NOT_OK(flight_context, data_stream->Next(&payload));
670 // End of stream
671 if (payload.ipc_message.metadata == nullptr) break;
672 auto status = internal::WritePayload(payload, writer);
673 // Connection terminated
674 if (status.IsIOError()) break;
675 SERVICE_RETURN_NOT_OK(flight_context, status);
676 }
677 RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
678 }
679
680 grpc::Status DoPut(ServerContext* context,
681 grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader) {
682 GrpcServerCallContext flight_context(context);
683 GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context));
684
685 auto message_reader = std::unique_ptr<FlightMessageReaderImpl<pb::PutResult>>(
686 new FlightMessageReaderImpl<pb::PutResult>(reader));
687 SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init());
688 auto metadata_writer =
689 std::unique_ptr<FlightMetadataWriter>(new GrpcMetadataWriter(reader));
690 RETURN_WITH_MIDDLEWARE(flight_context,
691 server_->DoPut(flight_context, std::move(message_reader),
692 std::move(metadata_writer)));
693 }
694
695 grpc::Status DoExchange(
696 ServerContext* context,
697 grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream) {
698 GrpcServerCallContext flight_context(context);
699 GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoExchange, context, flight_context));
700 auto message_reader = std::unique_ptr<FlightMessageReaderImpl<pb::FlightData>>(
701 new FlightMessageReaderImpl<pb::FlightData>(stream));
702 SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init());
703 auto writer =
704 std::unique_ptr<DoExchangeMessageWriter>(new DoExchangeMessageWriter(stream));
705 RETURN_WITH_MIDDLEWARE(flight_context,
706 server_->DoExchange(flight_context, std::move(message_reader),
707 std::move(writer)));
708 }
709
710 grpc::Status ListActions(ServerContext* context, const pb::Empty* request,
711 ServerWriter<pb::ActionType>* writer) {
712 GrpcServerCallContext flight_context(context);
713 GRPC_RETURN_NOT_GRPC_OK(
714 CheckAuth(FlightMethod::ListActions, context, flight_context));
715 // Retrieve the listing from the implementation
716 std::vector<ActionType> types;
717 SERVICE_RETURN_NOT_OK(flight_context, server_->ListActions(flight_context, &types));
718 RETURN_WITH_MIDDLEWARE(flight_context, WriteStream<ActionType>(types, writer));
719 }
720
721 grpc::Status DoAction(ServerContext* context, const pb::Action* request,
722 ServerWriter<pb::Result>* writer) {
723 GrpcServerCallContext flight_context(context);
724 GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoAction, context, flight_context));
725 CHECK_ARG_NOT_NULL(flight_context, request, "Action cannot be null");
726 Action action;
727 SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &action));
728
729 std::unique_ptr<ResultStream> results;
730 SERVICE_RETURN_NOT_OK(flight_context,
731 server_->DoAction(flight_context, action, &results));
732
733 if (!results) {
734 RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::CANCELLED);
735 }
736
737 while (true) {
738 std::unique_ptr<Result> result;
739 SERVICE_RETURN_NOT_OK(flight_context, results->Next(&result));
740 if (!result) {
741 // No more results
742 break;
743 }
744 pb::Result pb_result;
745 SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*result, &pb_result));
746 if (!writer->Write(pb_result)) {
747 // Stream may be closed
748 break;
749 }
750 }
751 RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK);
752 }
753
754 private:
755 std::shared_ptr<ServerAuthHandler> auth_handler_;
756 std::vector<std::pair<std::string, std::shared_ptr<ServerMiddlewareFactory>>>
757 middleware_;
758 FlightServerBase* server_;
759 };
760
761 } // namespace
762
763 FlightMetadataWriter::~FlightMetadataWriter() = default;
764
765 //
766 // gRPC server lifecycle
767 //
768
769 #if (ATOMIC_INT_LOCK_FREE != 2 || ATOMIC_POINTER_LOCK_FREE != 2)
770 #error "atomic ints and atomic pointers not always lock-free!"
771 #endif
772
773 using ::arrow::internal::SetSignalHandler;
774 using ::arrow::internal::SignalHandler;
775
776 #ifdef WIN32
777 #define PIPE_WRITE _write
778 #define PIPE_READ _read
779 #else
780 #define PIPE_WRITE write
781 #define PIPE_READ read
782 #endif
783
784 /// RAII guard that manages a self-pipe and a thread that listens on
785 /// the self-pipe, shutting down the gRPC server when a signal handler
786 /// writes to the pipe.
787 class ServerSignalHandler {
788 public:
789 ARROW_DISALLOW_COPY_AND_ASSIGN(ServerSignalHandler);
790 ServerSignalHandler() = default;
791
792 /// Create the pipe and handler thread.
793 ///
794 /// \return the fd of the write side of the pipe.
795 template <typename Fn>
796 arrow::Result<int> Init(Fn handler) {
797 ARROW_ASSIGN_OR_RAISE(auto pipe, arrow::internal::CreatePipe());
798 #ifndef WIN32
799 // Make write end nonblocking
800 int flags = fcntl(pipe.wfd, F_GETFL);
801 if (flags == -1) {
802 RETURN_NOT_OK(arrow::internal::FileClose(pipe.rfd));
803 RETURN_NOT_OK(arrow::internal::FileClose(pipe.wfd));
804 return arrow::internal::IOErrorFromErrno(
805 errno, "Could not initialize self-pipe to wait for signals");
806 }
807 flags |= O_NONBLOCK;
808 if (fcntl(pipe.wfd, F_SETFL, flags) == -1) {
809 RETURN_NOT_OK(arrow::internal::FileClose(pipe.rfd));
810 RETURN_NOT_OK(arrow::internal::FileClose(pipe.wfd));
811 return arrow::internal::IOErrorFromErrno(
812 errno, "Could not initialize self-pipe to wait for signals");
813 }
814 #endif
815 self_pipe_ = pipe;
816 handle_signals_ = std::thread(handler, self_pipe_.rfd);
817 return self_pipe_.wfd;
818 }
819
820 Status Shutdown() {
821 if (self_pipe_.rfd == 0) {
822 // Already closed
823 return Status::OK();
824 }
825 if (PIPE_WRITE(self_pipe_.wfd, "0", 1) < 0 && errno != EAGAIN &&
826 errno != EWOULDBLOCK && errno != EINTR) {
827 return arrow::internal::IOErrorFromErrno(errno, "Could not unblock signal thread");
828 }
829 RETURN_NOT_OK(arrow::internal::FileClose(self_pipe_.rfd));
830 RETURN_NOT_OK(arrow::internal::FileClose(self_pipe_.wfd));
831 handle_signals_.join();
832 self_pipe_.rfd = 0;
833 self_pipe_.wfd = 0;
834 return Status::OK();
835 }
836
837 ~ServerSignalHandler() { ARROW_CHECK_OK(Shutdown()); }
838
839 private:
840 arrow::internal::Pipe self_pipe_;
841 std::thread handle_signals_;
842 };
843
844 struct FlightServerBase::Impl {
845 std::unique_ptr<FlightServiceImpl> service_;
846 std::unique_ptr<grpc::Server> server_;
847 int port_;
848
849 // Signal handlers (on Windows) and the shutdown handler (other platforms)
850 // are executed in a separate thread, so getting the current thread instance
851 // wouldn't make sense. This means only a single instance can receive signals.
852 static std::atomic<Impl*> running_instance_;
853 // We'll use the self-pipe trick to notify a thread from the signal
854 // handler. The thread will then shut down the gRPC server.
855 int self_pipe_wfd_;
856
857 // Signal handling
858 std::vector<int> signals_;
859 std::vector<SignalHandler> old_signal_handlers_;
860 std::atomic<int> got_signal_;
861
862 static void HandleSignal(int signum) {
863 auto instance = running_instance_.load();
864 if (instance != nullptr) {
865 instance->DoHandleSignal(signum);
866 }
867 }
868
869 void DoHandleSignal(int signum) {
870 got_signal_ = signum;
871 int saved_errno = errno;
872 // Ignore errors - pipe is nonblocking
873 PIPE_WRITE(self_pipe_wfd_, "0", 1);
874 errno = saved_errno;
875 }
876
877 static void WaitForSignals(int fd) {
878 // Wait for a signal handler to write to the pipe
879 int8_t buf[1];
880 while (PIPE_READ(fd, /*buf=*/buf, /*count=*/1) == -1) {
881 if (errno == EINTR) {
882 continue;
883 }
884 ARROW_CHECK_OK(arrow::internal::IOErrorFromErrno(
885 errno, "Error while waiting for shutdown signal"));
886 }
887 auto instance = running_instance_.load();
888 if (instance != nullptr) {
889 instance->server_->Shutdown();
890 }
891 }
892 };
893
894 std::atomic<FlightServerBase::Impl*> FlightServerBase::Impl::running_instance_;
895
896 FlightServerOptions::FlightServerOptions(const Location& location_)
897 : location(location_),
898 auth_handler(nullptr),
899 tls_certificates(),
900 verify_client(false),
901 root_certificates(),
902 middleware(),
903 builder_hook(nullptr) {}
904
905 FlightServerOptions::~FlightServerOptions() = default;
906
907 FlightServerBase::FlightServerBase() { impl_.reset(new Impl); }
908
909 FlightServerBase::~FlightServerBase() {}
910
911 Status FlightServerBase::Init(const FlightServerOptions& options) {
912 impl_->service_.reset(
913 new FlightServiceImpl(options.auth_handler, options.middleware, this));
914
915 grpc::ServerBuilder builder;
916 // Allow uploading messages of any length
917 builder.SetMaxReceiveMessageSize(-1);
918
919 const Location& location = options.location;
920 const std::string scheme = location.scheme();
921 if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) {
922 std::stringstream address;
923 address << arrow::internal::UriEncodeHost(location.uri_->host()) << ':'
924 << location.uri_->port_text();
925
926 std::shared_ptr<grpc::ServerCredentials> creds;
927 if (scheme == kSchemeGrpcTls) {
928 grpc::SslServerCredentialsOptions ssl_options;
929 for (const auto& pair : options.tls_certificates) {
930 ssl_options.pem_key_cert_pairs.push_back({pair.pem_key, pair.pem_cert});
931 }
932 if (options.verify_client) {
933 ssl_options.client_certificate_request =
934 GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY;
935 }
936 if (!options.root_certificates.empty()) {
937 ssl_options.pem_root_certs = options.root_certificates;
938 }
939 creds = grpc::SslServerCredentials(ssl_options);
940 } else {
941 creds = grpc::InsecureServerCredentials();
942 }
943
944 builder.AddListeningPort(address.str(), creds, &impl_->port_);
945 } else if (scheme == kSchemeGrpcUnix) {
946 std::stringstream address;
947 address << "unix:" << location.uri_->path();
948 builder.AddListeningPort(address.str(), grpc::InsecureServerCredentials());
949 } else {
950 return Status::NotImplemented("Scheme is not supported: " + scheme);
951 }
952
953 builder.RegisterService(impl_->service_.get());
954
955 // Disable SO_REUSEPORT - it makes debugging/testing a pain as
956 // leftover processes can handle requests on accident
957 builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0);
958
959 if (options.builder_hook) {
960 options.builder_hook(&builder);
961 }
962
963 impl_->server_ = builder.BuildAndStart();
964 if (!impl_->server_) {
965 return Status::UnknownError("Server did not start properly");
966 }
967 return Status::OK();
968 }
969
970 int FlightServerBase::port() const { return impl_->port_; }
971
972 Status FlightServerBase::SetShutdownOnSignals(const std::vector<int> sigs) {
973 impl_->signals_ = sigs;
974 impl_->old_signal_handlers_.clear();
975 return Status::OK();
976 }
977
978 Status FlightServerBase::Serve() {
979 if (!impl_->server_) {
980 return Status::UnknownError("Server did not start properly");
981 }
982 impl_->got_signal_ = 0;
983 impl_->old_signal_handlers_.clear();
984 impl_->running_instance_ = impl_.get();
985
986 ServerSignalHandler signal_handler;
987 ARROW_ASSIGN_OR_RAISE(impl_->self_pipe_wfd_,
988 signal_handler.Init(&Impl::WaitForSignals));
989 // Override existing signal handlers with our own handler so as to stop the server.
990 for (size_t i = 0; i < impl_->signals_.size(); ++i) {
991 int signum = impl_->signals_[i];
992 SignalHandler new_handler(&Impl::HandleSignal), old_handler;
993 ARROW_ASSIGN_OR_RAISE(old_handler, SetSignalHandler(signum, new_handler));
994 impl_->old_signal_handlers_.push_back(std::move(old_handler));
995 }
996
997 impl_->server_->Wait();
998 impl_->running_instance_ = nullptr;
999
1000 // Restore signal handlers
1001 for (size_t i = 0; i < impl_->signals_.size(); ++i) {
1002 RETURN_NOT_OK(
1003 SetSignalHandler(impl_->signals_[i], impl_->old_signal_handlers_[i]).status());
1004 }
1005 return Status::OK();
1006 }
1007
1008 int FlightServerBase::GotSignal() const { return impl_->got_signal_; }
1009
1010 Status FlightServerBase::Shutdown() {
1011 auto server = impl_->server_.get();
1012 if (!server) {
1013 return Status::Invalid("Shutdown() on uninitialized FlightServerBase");
1014 }
1015 impl_->server_->Shutdown();
1016 return Status::OK();
1017 }
1018
1019 Status FlightServerBase::Wait() {
1020 impl_->server_->Wait();
1021 impl_->running_instance_ = nullptr;
1022 return Status::OK();
1023 }
1024
1025 Status FlightServerBase::ListFlights(const ServerCallContext& context,
1026 const Criteria* criteria,
1027 std::unique_ptr<FlightListing>* listings) {
1028 return Status::NotImplemented("NYI");
1029 }
1030
1031 Status FlightServerBase::GetFlightInfo(const ServerCallContext& context,
1032 const FlightDescriptor& request,
1033 std::unique_ptr<FlightInfo>* info) {
1034 return Status::NotImplemented("NYI");
1035 }
1036
1037 Status FlightServerBase::DoGet(const ServerCallContext& context, const Ticket& request,
1038 std::unique_ptr<FlightDataStream>* data_stream) {
1039 return Status::NotImplemented("NYI");
1040 }
1041
1042 Status FlightServerBase::DoPut(const ServerCallContext& context,
1043 std::unique_ptr<FlightMessageReader> reader,
1044 std::unique_ptr<FlightMetadataWriter> writer) {
1045 return Status::NotImplemented("NYI");
1046 }
1047
1048 Status FlightServerBase::DoExchange(const ServerCallContext& context,
1049 std::unique_ptr<FlightMessageReader> reader,
1050 std::unique_ptr<FlightMessageWriter> writer) {
1051 return Status::NotImplemented("NYI");
1052 }
1053
1054 Status FlightServerBase::DoAction(const ServerCallContext& context, const Action& action,
1055 std::unique_ptr<ResultStream>* result) {
1056 return Status::NotImplemented("NYI");
1057 }
1058
1059 Status FlightServerBase::ListActions(const ServerCallContext& context,
1060 std::vector<ActionType>* actions) {
1061 return Status::NotImplemented("NYI");
1062 }
1063
1064 Status FlightServerBase::GetSchema(const ServerCallContext& context,
1065 const FlightDescriptor& request,
1066 std::unique_ptr<SchemaResult>* schema) {
1067 return Status::NotImplemented("NYI");
1068 }
1069
1070 // ----------------------------------------------------------------------
1071 // Implement RecordBatchStream
1072
1073 class RecordBatchStream::RecordBatchStreamImpl {
1074 public:
1075 // Stages of the stream when producing payloads
1076 enum class Stage {
1077 NEW, // The stream has been created, but Next has not been called yet
1078 DICTIONARY, // Dictionaries have been collected, and are being sent
1079 RECORD_BATCH // Initial have been sent
1080 };
1081
1082 RecordBatchStreamImpl(const std::shared_ptr<RecordBatchReader>& reader,
1083 const ipc::IpcWriteOptions& options)
1084 : reader_(reader), mapper_(*reader_->schema()), ipc_options_(options) {}
1085
1086 std::shared_ptr<Schema> schema() { return reader_->schema(); }
1087
1088 Status GetSchemaPayload(FlightPayload* payload) {
1089 return ipc::GetSchemaPayload(*reader_->schema(), ipc_options_, mapper_,
1090 &payload->ipc_message);
1091 }
1092
1093 Status Next(FlightPayload* payload) {
1094 if (stage_ == Stage::NEW) {
1095 RETURN_NOT_OK(reader_->ReadNext(&current_batch_));
1096 if (!current_batch_) {
1097 // Signal that iteration is over
1098 payload->ipc_message.metadata = nullptr;
1099 return Status::OK();
1100 }
1101 ARROW_ASSIGN_OR_RAISE(dictionaries_,
1102 ipc::CollectDictionaries(*current_batch_, mapper_));
1103 stage_ = Stage::DICTIONARY;
1104 }
1105
1106 if (stage_ == Stage::DICTIONARY) {
1107 if (dictionary_index_ == static_cast<int>(dictionaries_.size())) {
1108 stage_ = Stage::RECORD_BATCH;
1109 return ipc::GetRecordBatchPayload(*current_batch_, ipc_options_,
1110 &payload->ipc_message);
1111 } else {
1112 return GetNextDictionary(payload);
1113 }
1114 }
1115
1116 RETURN_NOT_OK(reader_->ReadNext(&current_batch_));
1117
1118 // TODO(wesm): Delta dictionaries
1119 if (!current_batch_) {
1120 // Signal that iteration is over
1121 payload->ipc_message.metadata = nullptr;
1122 return Status::OK();
1123 } else {
1124 return ipc::GetRecordBatchPayload(*current_batch_, ipc_options_,
1125 &payload->ipc_message);
1126 }
1127 }
1128
1129 private:
1130 Status GetNextDictionary(FlightPayload* payload) {
1131 const auto& it = dictionaries_[dictionary_index_++];
1132 return ipc::GetDictionaryPayload(it.first, it.second, ipc_options_,
1133 &payload->ipc_message);
1134 }
1135
1136 Stage stage_ = Stage::NEW;
1137 std::shared_ptr<RecordBatchReader> reader_;
1138 ipc::DictionaryFieldMapper mapper_;
1139 ipc::IpcWriteOptions ipc_options_;
1140 std::shared_ptr<RecordBatch> current_batch_;
1141 std::vector<std::pair<int64_t, std::shared_ptr<Array>>> dictionaries_;
1142
1143 // Index of next dictionary to send
1144 int dictionary_index_ = 0;
1145 };
1146
1147 FlightDataStream::~FlightDataStream() {}
1148
1149 RecordBatchStream::RecordBatchStream(const std::shared_ptr<RecordBatchReader>& reader,
1150 const ipc::IpcWriteOptions& options) {
1151 impl_.reset(new RecordBatchStreamImpl(reader, options));
1152 }
1153
1154 RecordBatchStream::~RecordBatchStream() {}
1155
1156 std::shared_ptr<Schema> RecordBatchStream::schema() { return impl_->schema(); }
1157
1158 Status RecordBatchStream::GetSchemaPayload(FlightPayload* payload) {
1159 return impl_->GetSchemaPayload(payload);
1160 }
1161
1162 Status RecordBatchStream::Next(FlightPayload* payload) { return impl_->Next(payload); }
1163
1164 } // namespace flight
1165 } // namespace arrow