]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | // Platform-specific defines | |
19 | #include "arrow/flight/platform.h" | |
20 | ||
21 | #include "arrow/flight/server.h" | |
22 | ||
23 | #ifdef _WIN32 | |
24 | #include <io.h> | |
25 | #else | |
26 | #include <fcntl.h> | |
27 | #include <unistd.h> | |
28 | #endif | |
29 | #include <atomic> | |
30 | #include <cerrno> | |
31 | #include <cstdint> | |
32 | #include <memory> | |
33 | #include <sstream> | |
34 | #include <string> | |
35 | #include <thread> | |
36 | #include <unordered_map> | |
37 | #include <utility> | |
38 | ||
39 | #ifdef GRPCPP_PP_INCLUDE | |
40 | #include <grpcpp/grpcpp.h> | |
41 | #else | |
42 | #include <grpc++/grpc++.h> | |
43 | #endif | |
44 | ||
45 | #include "arrow/buffer.h" | |
46 | #include "arrow/ipc/dictionary.h" | |
47 | #include "arrow/ipc/options.h" | |
48 | #include "arrow/ipc/reader.h" | |
49 | #include "arrow/ipc/writer.h" | |
50 | #include "arrow/memory_pool.h" | |
51 | #include "arrow/record_batch.h" | |
52 | #include "arrow/status.h" | |
53 | #include "arrow/util/io_util.h" | |
54 | #include "arrow/util/logging.h" | |
55 | #include "arrow/util/uri.h" | |
56 | ||
57 | #include "arrow/flight/internal.h" | |
58 | #include "arrow/flight/middleware.h" | |
59 | #include "arrow/flight/middleware_internal.h" | |
60 | #include "arrow/flight/serialization_internal.h" | |
61 | #include "arrow/flight/server_auth.h" | |
62 | #include "arrow/flight/server_middleware.h" | |
63 | #include "arrow/flight/types.h" | |
64 | ||
65 | using FlightService = arrow::flight::protocol::FlightService; | |
66 | using ServerContext = grpc::ServerContext; | |
67 | ||
68 | template <typename T> | |
69 | using ServerWriter = grpc::ServerWriter<T>; | |
70 | ||
71 | namespace arrow { | |
72 | namespace flight { | |
73 | ||
74 | namespace pb = arrow::flight::protocol; | |
75 | ||
// Macro that runs interceptors before returning the given status.
// STATUS may be an Arrow or a gRPC status; it is evaluated exactly once and
// handed to CONTEXT.FinishRequest(), whose result is returned from the
// enclosing function.
#define RETURN_WITH_MIDDLEWARE(CONTEXT, STATUS)         \
  do {                                                  \
    const auto& _middleware_status = (STATUS);          \
    return (CONTEXT).FinishRequest(_middleware_status); \
  } while (false)

// Return INVALID_ARGUMENT (after running interceptors) if VAL is null.
// Wrapped in do/while(false) so the expansion is a single statement and can
// never capture an `else` that follows the macro invocation.
#define CHECK_ARG_NOT_NULL(CONTEXT, VAL, MESSAGE)                              \
  do {                                                                         \
    if ((VAL) == nullptr) {                                                    \
      RETURN_WITH_MIDDLEWARE(                                                  \
          CONTEXT, grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, MESSAGE)); \
    }                                                                          \
  } while (false)

// Same as RETURN_NOT_OK, but accepts either Arrow or gRPC status, and
// will run interceptors
#define SERVICE_RETURN_NOT_OK(CONTEXT, expr) \
  do {                                       \
    const auto& _s = (expr);                 \
    if (ARROW_PREDICT_FALSE(!_s.ok())) {     \
      return (CONTEXT).FinishRequest(_s);    \
    }                                        \
  } while (false)
98 | ||
99 | namespace { | |
100 | ||
101 | // A MessageReader implementation that reads from a gRPC ServerReader. | |
102 | // Templated to be generic over DoPut/DoExchange. | |
103 | template <typename Reader> | |
104 | class FlightIpcMessageReader : public ipc::MessageReader { | |
105 | public: | |
106 | explicit FlightIpcMessageReader( | |
107 | std::shared_ptr<internal::PeekableFlightDataReader<Reader*>> peekable_reader, | |
108 | std::shared_ptr<Buffer>* app_metadata) | |
109 | : peekable_reader_(peekable_reader), app_metadata_(app_metadata) {} | |
110 | ||
111 | ::arrow::Result<std::unique_ptr<ipc::Message>> ReadNextMessage() override { | |
112 | if (stream_finished_) { | |
113 | return nullptr; | |
114 | } | |
115 | internal::FlightData* data; | |
116 | peekable_reader_->Next(&data); | |
117 | if (!data) { | |
118 | stream_finished_ = true; | |
119 | if (first_message_) { | |
120 | return Status::Invalid( | |
121 | "Client provided malformed message or did not provide message"); | |
122 | } | |
123 | return nullptr; | |
124 | } | |
125 | *app_metadata_ = std::move(data->app_metadata); | |
126 | return data->OpenMessage(); | |
127 | } | |
128 | ||
129 | protected: | |
130 | std::shared_ptr<internal::PeekableFlightDataReader<Reader*>> peekable_reader_; | |
131 | // A reference to FlightMessageReaderImpl.app_metadata_. That class | |
132 | // can't access the app metadata because when it Peek()s the stream, | |
133 | // it may be looking at a dictionary batch, not the record | |
134 | // batch. Updating it here ensures the reader is always updated with | |
135 | // the last metadata message read. | |
136 | std::shared_ptr<Buffer>* app_metadata_; | |
137 | bool first_message_ = true; | |
138 | bool stream_finished_ = false; | |
139 | }; | |
140 | ||
141 | template <typename WritePayload> | |
142 | class FlightMessageReaderImpl : public FlightMessageReader { | |
143 | public: | |
144 | using GrpcStream = grpc::ServerReaderWriter<WritePayload, pb::FlightData>; | |
145 | ||
146 | explicit FlightMessageReaderImpl(GrpcStream* reader) | |
147 | : reader_(reader), | |
148 | peekable_reader_(new internal::PeekableFlightDataReader<GrpcStream*>(reader)) {} | |
149 | ||
150 | Status Init() { | |
151 | // Peek the first message to get the descriptor. | |
152 | internal::FlightData* data; | |
153 | peekable_reader_->Peek(&data); | |
154 | if (!data) { | |
155 | return Status::IOError("Stream finished before first message sent"); | |
156 | } | |
157 | if (!data->descriptor) { | |
158 | return Status::IOError("Descriptor missing on first message"); | |
159 | } | |
160 | descriptor_ = *data->descriptor.get(); // Copy | |
161 | // If there's a schema (=DoPut), also Open(). | |
162 | if (data->metadata) { | |
163 | return EnsureDataStarted(); | |
164 | } | |
165 | peekable_reader_->Next(&data); | |
166 | return Status::OK(); | |
167 | } | |
168 | ||
169 | const FlightDescriptor& descriptor() const override { return descriptor_; } | |
170 | ||
171 | arrow::Result<std::shared_ptr<Schema>> GetSchema() override { | |
172 | RETURN_NOT_OK(EnsureDataStarted()); | |
173 | return batch_reader_->schema(); | |
174 | } | |
175 | ||
176 | Status Next(FlightStreamChunk* out) override { | |
177 | internal::FlightData* data; | |
178 | peekable_reader_->Peek(&data); | |
179 | if (!data) { | |
180 | out->app_metadata = nullptr; | |
181 | out->data = nullptr; | |
182 | return Status::OK(); | |
183 | } | |
184 | ||
185 | if (!data->metadata) { | |
186 | // Metadata-only (data->metadata is the IPC header) | |
187 | out->app_metadata = data->app_metadata; | |
188 | out->data = nullptr; | |
189 | peekable_reader_->Next(&data); | |
190 | return Status::OK(); | |
191 | } | |
192 | ||
193 | if (!batch_reader_) { | |
194 | RETURN_NOT_OK(EnsureDataStarted()); | |
195 | // re-peek here since EnsureDataStarted() advances the stream | |
196 | return Next(out); | |
197 | } | |
198 | RETURN_NOT_OK(batch_reader_->ReadNext(&out->data)); | |
199 | out->app_metadata = std::move(app_metadata_); | |
200 | return Status::OK(); | |
201 | } | |
202 | ||
203 | private: | |
204 | /// Ensure we are set up to read data. | |
205 | Status EnsureDataStarted() { | |
206 | if (!batch_reader_) { | |
207 | // peek() until we find the first data message; discard metadata | |
208 | if (!peekable_reader_->SkipToData()) { | |
209 | return Status::IOError("Client never sent a data message"); | |
210 | } | |
211 | auto message_reader = std::unique_ptr<ipc::MessageReader>( | |
212 | new FlightIpcMessageReader<GrpcStream>(peekable_reader_, &app_metadata_)); | |
213 | ARROW_ASSIGN_OR_RAISE( | |
214 | batch_reader_, ipc::RecordBatchStreamReader::Open(std::move(message_reader))); | |
215 | } | |
216 | return Status::OK(); | |
217 | } | |
218 | ||
219 | FlightDescriptor descriptor_; | |
220 | GrpcStream* reader_; | |
221 | std::shared_ptr<internal::PeekableFlightDataReader<GrpcStream*>> peekable_reader_; | |
222 | std::shared_ptr<RecordBatchReader> batch_reader_; | |
223 | std::shared_ptr<Buffer> app_metadata_; | |
224 | }; | |
225 | ||
226 | class GrpcMetadataWriter : public FlightMetadataWriter { | |
227 | public: | |
228 | explicit GrpcMetadataWriter( | |
229 | grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* writer) | |
230 | : writer_(writer) {} | |
231 | ||
232 | Status WriteMetadata(const Buffer& buffer) override { | |
233 | pb::PutResult message{}; | |
234 | message.set_app_metadata(buffer.data(), buffer.size()); | |
235 | if (writer_->Write(message)) { | |
236 | return Status::OK(); | |
237 | } | |
238 | return Status::IOError("Unknown error writing metadata."); | |
239 | } | |
240 | ||
241 | private: | |
242 | grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* writer_; | |
243 | }; | |
244 | ||
245 | class GrpcServerAuthReader : public ServerAuthReader { | |
246 | public: | |
247 | explicit GrpcServerAuthReader( | |
248 | grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream) | |
249 | : stream_(stream) {} | |
250 | ||
251 | Status Read(std::string* token) override { | |
252 | pb::HandshakeRequest request; | |
253 | if (stream_->Read(&request)) { | |
254 | *token = std::move(*request.mutable_payload()); | |
255 | return Status::OK(); | |
256 | } | |
257 | return Status::IOError("Stream is closed."); | |
258 | } | |
259 | ||
260 | private: | |
261 | grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream_; | |
262 | }; | |
263 | ||
264 | class GrpcServerAuthSender : public ServerAuthSender { | |
265 | public: | |
266 | explicit GrpcServerAuthSender( | |
267 | grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream) | |
268 | : stream_(stream) {} | |
269 | ||
270 | Status Write(const std::string& token) override { | |
271 | pb::HandshakeResponse response; | |
272 | response.set_payload(token); | |
273 | if (stream_->Write(response)) { | |
274 | return Status::OK(); | |
275 | } | |
276 | return Status::IOError("Stream was closed."); | |
277 | } | |
278 | ||
279 | private: | |
280 | grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream_; | |
281 | }; | |
282 | ||
283 | /// The implementation of the write side of a bidirectional FlightData | |
284 | /// stream for DoExchange. | |
285 | class DoExchangeMessageWriter : public FlightMessageWriter { | |
286 | public: | |
287 | explicit DoExchangeMessageWriter( | |
288 | grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream) | |
289 | : stream_(stream), ipc_options_(::arrow::ipc::IpcWriteOptions::Defaults()) {} | |
290 | ||
291 | Status Begin(const std::shared_ptr<Schema>& schema, | |
292 | const ipc::IpcWriteOptions& options) override { | |
293 | if (started_) { | |
294 | return Status::Invalid("This writer has already been started."); | |
295 | } | |
296 | started_ = true; | |
297 | ipc_options_ = options; | |
298 | ||
299 | RETURN_NOT_OK(mapper_.AddSchemaFields(*schema)); | |
300 | FlightPayload schema_payload; | |
301 | RETURN_NOT_OK(ipc::GetSchemaPayload(*schema, ipc_options_, mapper_, | |
302 | &schema_payload.ipc_message)); | |
303 | return WritePayload(schema_payload); | |
304 | } | |
305 | ||
306 | Status WriteRecordBatch(const RecordBatch& batch) override { | |
307 | return WriteWithMetadata(batch, nullptr); | |
308 | } | |
309 | ||
310 | Status WriteMetadata(std::shared_ptr<Buffer> app_metadata) override { | |
311 | FlightPayload payload{}; | |
312 | payload.app_metadata = app_metadata; | |
313 | return WritePayload(payload); | |
314 | } | |
315 | ||
316 | Status WriteWithMetadata(const RecordBatch& batch, | |
317 | std::shared_ptr<Buffer> app_metadata) override { | |
318 | RETURN_NOT_OK(CheckStarted()); | |
319 | RETURN_NOT_OK(EnsureDictionariesWritten(batch)); | |
320 | FlightPayload payload{}; | |
321 | if (app_metadata) { | |
322 | payload.app_metadata = app_metadata; | |
323 | } | |
324 | RETURN_NOT_OK(ipc::GetRecordBatchPayload(batch, ipc_options_, &payload.ipc_message)); | |
325 | RETURN_NOT_OK(WritePayload(payload)); | |
326 | ++stats_.num_record_batches; | |
327 | return Status::OK(); | |
328 | } | |
329 | ||
330 | Status Close() override { | |
331 | // It's fine to Close() without writing data | |
332 | return Status::OK(); | |
333 | } | |
334 | ||
335 | ipc::WriteStats stats() const override { return stats_; } | |
336 | ||
337 | private: | |
338 | Status WritePayload(const FlightPayload& payload) { | |
339 | RETURN_NOT_OK(internal::WritePayload(payload, stream_)); | |
340 | ++stats_.num_messages; | |
341 | return Status::OK(); | |
342 | } | |
343 | ||
344 | Status CheckStarted() { | |
345 | if (!started_) { | |
346 | return Status::Invalid("This writer is not started. Call Begin() with a schema"); | |
347 | } | |
348 | return Status::OK(); | |
349 | } | |
350 | ||
351 | Status EnsureDictionariesWritten(const RecordBatch& batch) { | |
352 | if (dictionaries_written_) { | |
353 | return Status::OK(); | |
354 | } | |
355 | dictionaries_written_ = true; | |
356 | ARROW_ASSIGN_OR_RAISE(const auto dictionaries, | |
357 | ipc::CollectDictionaries(batch, mapper_)); | |
358 | for (const auto& pair : dictionaries) { | |
359 | FlightPayload payload{}; | |
360 | RETURN_NOT_OK(ipc::GetDictionaryPayload(pair.first, pair.second, ipc_options_, | |
361 | &payload.ipc_message)); | |
362 | RETURN_NOT_OK(WritePayload(payload)); | |
363 | ++stats_.num_dictionary_batches; | |
364 | } | |
365 | return Status::OK(); | |
366 | } | |
367 | ||
368 | grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream_; | |
369 | ::arrow::ipc::IpcWriteOptions ipc_options_; | |
370 | ipc::DictionaryFieldMapper mapper_; | |
371 | ipc::WriteStats stats_; | |
372 | bool started_ = false; | |
373 | bool dictionaries_written_ = false; | |
374 | }; | |
375 | ||
376 | class FlightServiceImpl; | |
377 | class GrpcServerCallContext : public ServerCallContext { | |
378 | explicit GrpcServerCallContext(grpc::ServerContext* context) | |
379 | : context_(context), peer_(context_->peer()) {} | |
380 | ||
381 | const std::string& peer_identity() const override { return peer_identity_; } | |
382 | const std::string& peer() const override { return peer_; } | |
383 | bool is_cancelled() const override { return context_->IsCancelled(); } | |
384 | ||
385 | // Helper method that runs interceptors given the result of an RPC, | |
386 | // then returns the final gRPC status to send to the client | |
387 | grpc::Status FinishRequest(const grpc::Status& status) { | |
388 | // Don't double-convert status - return the original one here | |
389 | FinishRequest(internal::FromGrpcStatus(status)); | |
390 | return status; | |
391 | } | |
392 | ||
393 | grpc::Status FinishRequest(const arrow::Status& status) { | |
394 | for (const auto& instance : middleware_) { | |
395 | instance->CallCompleted(status); | |
396 | } | |
397 | ||
398 | // Set custom headers to map the exact Arrow status for clients | |
399 | // who want it. | |
400 | return internal::ToGrpcStatus(status, context_); | |
401 | } | |
402 | ||
403 | ServerMiddleware* GetMiddleware(const std::string& key) const override { | |
404 | const auto& instance = middleware_map_.find(key); | |
405 | if (instance == middleware_map_.end()) { | |
406 | return nullptr; | |
407 | } | |
408 | return instance->second.get(); | |
409 | } | |
410 | ||
411 | private: | |
412 | friend class FlightServiceImpl; | |
413 | ServerContext* context_; | |
414 | std::string peer_; | |
415 | std::string peer_identity_; | |
416 | std::vector<std::shared_ptr<ServerMiddleware>> middleware_; | |
417 | std::unordered_map<std::string, std::shared_ptr<ServerMiddleware>> middleware_map_; | |
418 | }; | |
419 | ||
420 | class GrpcAddCallHeaders : public AddCallHeaders { | |
421 | public: | |
422 | explicit GrpcAddCallHeaders(grpc::ServerContext* context) : context_(context) {} | |
423 | ~GrpcAddCallHeaders() override = default; | |
424 | ||
425 | void AddHeader(const std::string& key, const std::string& value) override { | |
426 | context_->AddInitialMetadata(key, value); | |
427 | } | |
428 | ||
429 | private: | |
430 | grpc::ServerContext* context_; | |
431 | }; | |
432 | ||
433 | // This class glues an implementation of FlightServerBase together with the | |
434 | // gRPC service definition, so the latter is not exposed in the public API | |
435 | class FlightServiceImpl : public FlightService::Service { | |
436 | public: | |
437 | explicit FlightServiceImpl( | |
438 | std::shared_ptr<ServerAuthHandler> auth_handler, | |
439 | std::vector<std::pair<std::string, std::shared_ptr<ServerMiddlewareFactory>>> | |
440 | middleware, | |
441 | FlightServerBase* server) | |
442 | : auth_handler_(auth_handler), middleware_(middleware), server_(server) {} | |
443 | ||
444 | template <typename UserType, typename Iterator, typename ProtoType> | |
445 | grpc::Status WriteStream(Iterator* iterator, ServerWriter<ProtoType>* writer) { | |
446 | if (!iterator) { | |
447 | return grpc::Status(grpc::StatusCode::INTERNAL, "No items to iterate"); | |
448 | } | |
449 | // Write flight info to stream until listing is exhausted | |
450 | while (true) { | |
451 | ProtoType pb_value; | |
452 | std::unique_ptr<UserType> value; | |
453 | GRPC_RETURN_NOT_OK(iterator->Next(&value)); | |
454 | if (!value) { | |
455 | break; | |
456 | } | |
457 | GRPC_RETURN_NOT_OK(internal::ToProto(*value, &pb_value)); | |
458 | ||
459 | // Blocking write | |
460 | if (!writer->Write(pb_value)) { | |
461 | // Write returns false if the stream is closed | |
462 | break; | |
463 | } | |
464 | } | |
465 | return grpc::Status::OK; | |
466 | } | |
467 | ||
468 | template <typename UserType, typename ProtoType> | |
469 | grpc::Status WriteStream(const std::vector<UserType>& values, | |
470 | ServerWriter<ProtoType>* writer) { | |
471 | // Write flight info to stream until listing is exhausted | |
472 | for (const UserType& value : values) { | |
473 | ProtoType pb_value; | |
474 | GRPC_RETURN_NOT_OK(internal::ToProto(value, &pb_value)); | |
475 | // Blocking write | |
476 | if (!writer->Write(pb_value)) { | |
477 | // Write returns false if the stream is closed | |
478 | break; | |
479 | } | |
480 | } | |
481 | return grpc::Status::OK; | |
482 | } | |
483 | ||
484 | // Authenticate the client (if applicable) and construct the call context | |
485 | grpc::Status CheckAuth(const FlightMethod& method, ServerContext* context, | |
486 | GrpcServerCallContext& flight_context) { | |
487 | if (!auth_handler_) { | |
488 | const auto auth_context = context->auth_context(); | |
489 | if (auth_context && auth_context->IsPeerAuthenticated()) { | |
490 | auto peer_identity = auth_context->GetPeerIdentity(); | |
491 | flight_context.peer_identity_ = | |
492 | peer_identity.empty() | |
493 | ? "" | |
494 | : std::string(peer_identity.front().begin(), peer_identity.front().end()); | |
495 | } else { | |
496 | flight_context.peer_identity_ = ""; | |
497 | } | |
498 | } else { | |
499 | const auto client_metadata = context->client_metadata(); | |
500 | const auto auth_header = client_metadata.find(internal::kGrpcAuthHeader); | |
501 | std::string token; | |
502 | if (auth_header == client_metadata.end()) { | |
503 | token = ""; | |
504 | } else { | |
505 | token = std::string(auth_header->second.data(), auth_header->second.length()); | |
506 | } | |
507 | GRPC_RETURN_NOT_OK(auth_handler_->IsValid(token, &flight_context.peer_identity_)); | |
508 | } | |
509 | ||
510 | return MakeCallContext(method, context, flight_context); | |
511 | } | |
512 | ||
513 | // Authenticate the client (if applicable) and construct the call context | |
514 | grpc::Status MakeCallContext(const FlightMethod& method, ServerContext* context, | |
515 | GrpcServerCallContext& flight_context) { | |
516 | // Run server middleware | |
517 | const CallInfo info{method}; | |
518 | CallHeaders incoming_headers; | |
519 | for (const auto& entry : context->client_metadata()) { | |
520 | incoming_headers.insert( | |
521 | {util::string_view(entry.first.data(), entry.first.length()), | |
522 | util::string_view(entry.second.data(), entry.second.length())}); | |
523 | } | |
524 | ||
525 | GrpcAddCallHeaders outgoing_headers(context); | |
526 | for (const auto& factory : middleware_) { | |
527 | std::shared_ptr<ServerMiddleware> instance; | |
528 | Status result = factory.second->StartCall(info, incoming_headers, &instance); | |
529 | if (!result.ok()) { | |
530 | // Interceptor rejected call, end the request on all existing | |
531 | // interceptors | |
532 | return flight_context.FinishRequest(result); | |
533 | } | |
534 | if (instance != nullptr) { | |
535 | flight_context.middleware_.push_back(instance); | |
536 | flight_context.middleware_map_.insert({factory.first, instance}); | |
537 | instance->SendingHeaders(&outgoing_headers); | |
538 | } | |
539 | } | |
540 | ||
541 | return grpc::Status::OK; | |
542 | } | |
543 | ||
544 | grpc::Status Handshake( | |
545 | ServerContext* context, | |
546 | grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream) { | |
547 | GrpcServerCallContext flight_context(context); | |
548 | GRPC_RETURN_NOT_GRPC_OK( | |
549 | MakeCallContext(FlightMethod::Handshake, context, flight_context)); | |
550 | ||
551 | if (!auth_handler_) { | |
552 | RETURN_WITH_MIDDLEWARE( | |
553 | flight_context, | |
554 | grpc::Status( | |
555 | grpc::StatusCode::UNIMPLEMENTED, | |
556 | "This service does not have an authentication mechanism enabled.")); | |
557 | } | |
558 | GrpcServerAuthSender outgoing{stream}; | |
559 | GrpcServerAuthReader incoming{stream}; | |
560 | RETURN_WITH_MIDDLEWARE(flight_context, | |
561 | auth_handler_->Authenticate(&outgoing, &incoming)); | |
562 | } | |
563 | ||
564 | grpc::Status ListFlights(ServerContext* context, const pb::Criteria* request, | |
565 | ServerWriter<pb::FlightInfo>* writer) { | |
566 | GrpcServerCallContext flight_context(context); | |
567 | GRPC_RETURN_NOT_GRPC_OK( | |
568 | CheckAuth(FlightMethod::ListFlights, context, flight_context)); | |
569 | ||
570 | // Retrieve the listing from the implementation | |
571 | std::unique_ptr<FlightListing> listing; | |
572 | ||
573 | Criteria criteria; | |
574 | if (request) { | |
575 | SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &criteria)); | |
576 | } | |
577 | SERVICE_RETURN_NOT_OK(flight_context, | |
578 | server_->ListFlights(flight_context, &criteria, &listing)); | |
579 | if (!listing) { | |
580 | // Treat null listing as no flights available | |
581 | RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); | |
582 | } | |
583 | RETURN_WITH_MIDDLEWARE(flight_context, | |
584 | WriteStream<FlightInfo>(listing.get(), writer)); | |
585 | } | |
586 | ||
587 | grpc::Status GetFlightInfo(ServerContext* context, const pb::FlightDescriptor* request, | |
588 | pb::FlightInfo* response) { | |
589 | GrpcServerCallContext flight_context(context); | |
590 | GRPC_RETURN_NOT_GRPC_OK( | |
591 | CheckAuth(FlightMethod::GetFlightInfo, context, flight_context)); | |
592 | ||
593 | CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null"); | |
594 | ||
595 | FlightDescriptor descr; | |
596 | SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &descr)); | |
597 | ||
598 | std::unique_ptr<FlightInfo> info; | |
599 | SERVICE_RETURN_NOT_OK(flight_context, | |
600 | server_->GetFlightInfo(flight_context, descr, &info)); | |
601 | ||
602 | if (!info) { | |
603 | // Treat null listing as no flights available | |
604 | RETURN_WITH_MIDDLEWARE( | |
605 | flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND, "Flight not found")); | |
606 | } | |
607 | ||
608 | SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*info, response)); | |
609 | RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); | |
610 | } | |
611 | ||
612 | grpc::Status GetSchema(ServerContext* context, const pb::FlightDescriptor* request, | |
613 | pb::SchemaResult* response) { | |
614 | GrpcServerCallContext flight_context(context); | |
615 | GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::GetSchema, context, flight_context)); | |
616 | ||
617 | CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null"); | |
618 | ||
619 | FlightDescriptor descr; | |
620 | SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &descr)); | |
621 | ||
622 | std::unique_ptr<SchemaResult> result; | |
623 | SERVICE_RETURN_NOT_OK(flight_context, | |
624 | server_->GetSchema(flight_context, descr, &result)); | |
625 | ||
626 | if (!result) { | |
627 | // Treat null listing as no flights available | |
628 | RETURN_WITH_MIDDLEWARE( | |
629 | flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND, "Flight not found")); | |
630 | } | |
631 | ||
632 | SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*result, response)); | |
633 | RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); | |
634 | } | |
635 | ||
636 | grpc::Status DoGet(ServerContext* context, const pb::Ticket* request, | |
637 | ServerWriter<pb::FlightData>* writer) { | |
638 | GrpcServerCallContext flight_context(context); | |
639 | GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, flight_context)); | |
640 | ||
641 | CHECK_ARG_NOT_NULL(flight_context, request, "ticket cannot be null"); | |
642 | ||
643 | Ticket ticket; | |
644 | SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &ticket)); | |
645 | ||
646 | std::unique_ptr<FlightDataStream> data_stream; | |
647 | SERVICE_RETURN_NOT_OK(flight_context, | |
648 | server_->DoGet(flight_context, ticket, &data_stream)); | |
649 | ||
650 | if (!data_stream) { | |
651 | RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status(grpc::StatusCode::NOT_FOUND, | |
652 | "No data in this flight")); | |
653 | } | |
654 | ||
655 | // Write the schema as the first message in the stream | |
656 | FlightPayload schema_payload; | |
657 | SERVICE_RETURN_NOT_OK(flight_context, data_stream->GetSchemaPayload(&schema_payload)); | |
658 | auto status = internal::WritePayload(schema_payload, writer); | |
659 | if (status.IsIOError()) { | |
660 | // gRPC doesn't give any way for us to know why the message | |
661 | // could not be written. | |
662 | RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); | |
663 | } | |
664 | SERVICE_RETURN_NOT_OK(flight_context, status); | |
665 | ||
666 | // Consume data stream and write out payloads | |
667 | while (true) { | |
668 | FlightPayload payload; | |
669 | SERVICE_RETURN_NOT_OK(flight_context, data_stream->Next(&payload)); | |
670 | // End of stream | |
671 | if (payload.ipc_message.metadata == nullptr) break; | |
672 | auto status = internal::WritePayload(payload, writer); | |
673 | // Connection terminated | |
674 | if (status.IsIOError()) break; | |
675 | SERVICE_RETURN_NOT_OK(flight_context, status); | |
676 | } | |
677 | RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); | |
678 | } | |
679 | ||
680 | grpc::Status DoPut(ServerContext* context, | |
681 | grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader) { | |
682 | GrpcServerCallContext flight_context(context); | |
683 | GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context)); | |
684 | ||
685 | auto message_reader = std::unique_ptr<FlightMessageReaderImpl<pb::PutResult>>( | |
686 | new FlightMessageReaderImpl<pb::PutResult>(reader)); | |
687 | SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init()); | |
688 | auto metadata_writer = | |
689 | std::unique_ptr<FlightMetadataWriter>(new GrpcMetadataWriter(reader)); | |
690 | RETURN_WITH_MIDDLEWARE(flight_context, | |
691 | server_->DoPut(flight_context, std::move(message_reader), | |
692 | std::move(metadata_writer))); | |
693 | } | |
694 | ||
695 | grpc::Status DoExchange( | |
696 | ServerContext* context, | |
697 | grpc::ServerReaderWriter<pb::FlightData, pb::FlightData>* stream) { | |
698 | GrpcServerCallContext flight_context(context); | |
699 | GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoExchange, context, flight_context)); | |
700 | auto message_reader = std::unique_ptr<FlightMessageReaderImpl<pb::FlightData>>( | |
701 | new FlightMessageReaderImpl<pb::FlightData>(stream)); | |
702 | SERVICE_RETURN_NOT_OK(flight_context, message_reader->Init()); | |
703 | auto writer = | |
704 | std::unique_ptr<DoExchangeMessageWriter>(new DoExchangeMessageWriter(stream)); | |
705 | RETURN_WITH_MIDDLEWARE(flight_context, | |
706 | server_->DoExchange(flight_context, std::move(message_reader), | |
707 | std::move(writer))); | |
708 | } | |
709 | ||
710 | grpc::Status ListActions(ServerContext* context, const pb::Empty* request, | |
711 | ServerWriter<pb::ActionType>* writer) { | |
712 | GrpcServerCallContext flight_context(context); | |
713 | GRPC_RETURN_NOT_GRPC_OK( | |
714 | CheckAuth(FlightMethod::ListActions, context, flight_context)); | |
715 | // Retrieve the listing from the implementation | |
716 | std::vector<ActionType> types; | |
717 | SERVICE_RETURN_NOT_OK(flight_context, server_->ListActions(flight_context, &types)); | |
718 | RETURN_WITH_MIDDLEWARE(flight_context, WriteStream<ActionType>(types, writer)); | |
719 | } | |
720 | ||
721 | grpc::Status DoAction(ServerContext* context, const pb::Action* request, | |
722 | ServerWriter<pb::Result>* writer) { | |
723 | GrpcServerCallContext flight_context(context); | |
724 | GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoAction, context, flight_context)); | |
725 | CHECK_ARG_NOT_NULL(flight_context, request, "Action cannot be null"); | |
726 | Action action; | |
727 | SERVICE_RETURN_NOT_OK(flight_context, internal::FromProto(*request, &action)); | |
728 | ||
729 | std::unique_ptr<ResultStream> results; | |
730 | SERVICE_RETURN_NOT_OK(flight_context, | |
731 | server_->DoAction(flight_context, action, &results)); | |
732 | ||
733 | if (!results) { | |
734 | RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::CANCELLED); | |
735 | } | |
736 | ||
737 | while (true) { | |
738 | std::unique_ptr<Result> result; | |
739 | SERVICE_RETURN_NOT_OK(flight_context, results->Next(&result)); | |
740 | if (!result) { | |
741 | // No more results | |
742 | break; | |
743 | } | |
744 | pb::Result pb_result; | |
745 | SERVICE_RETURN_NOT_OK(flight_context, internal::ToProto(*result, &pb_result)); | |
746 | if (!writer->Write(pb_result)) { | |
747 | // Stream may be closed | |
748 | break; | |
749 | } | |
750 | } | |
751 | RETURN_WITH_MIDDLEWARE(flight_context, grpc::Status::OK); | |
752 | } | |
753 | ||
754 | private: | |
755 | std::shared_ptr<ServerAuthHandler> auth_handler_; | |
756 | std::vector<std::pair<std::string, std::shared_ptr<ServerMiddlewareFactory>>> | |
757 | middleware_; | |
758 | FlightServerBase* server_; | |
759 | }; | |
760 | ||
761 | } // namespace | |
762 | ||
// Out-of-line, defaulted definition of the FlightMetadataWriter destructor;
// the class itself is declared elsewhere (presumably a public header -- verify).
FlightMetadataWriter::~FlightMetadataWriter() = default;
764 | ||
765 | // | |
766 | // gRPC server lifecycle | |
767 | // | |
768 | ||
769 | #if (ATOMIC_INT_LOCK_FREE != 2 || ATOMIC_POINTER_LOCK_FREE != 2) | |
770 | #error "atomic ints and atomic pointers not always lock-free!" | |
771 | #endif | |
772 | ||
773 | using ::arrow::internal::SetSignalHandler; | |
774 | using ::arrow::internal::SignalHandler; | |
775 | ||
776 | #ifdef WIN32 | |
777 | #define PIPE_WRITE _write | |
778 | #define PIPE_READ _read | |
779 | #else | |
780 | #define PIPE_WRITE write | |
781 | #define PIPE_READ read | |
782 | #endif | |
783 | ||
784 | /// RAII guard that manages a self-pipe and a thread that listens on | |
785 | /// the self-pipe, shutting down the gRPC server when a signal handler | |
786 | /// writes to the pipe. | |
787 | class ServerSignalHandler { | |
788 | public: | |
789 | ARROW_DISALLOW_COPY_AND_ASSIGN(ServerSignalHandler); | |
790 | ServerSignalHandler() = default; | |
791 | ||
792 | /// Create the pipe and handler thread. | |
793 | /// | |
794 | /// \return the fd of the write side of the pipe. | |
795 | template <typename Fn> | |
796 | arrow::Result<int> Init(Fn handler) { | |
797 | ARROW_ASSIGN_OR_RAISE(auto pipe, arrow::internal::CreatePipe()); | |
798 | #ifndef WIN32 | |
799 | // Make write end nonblocking | |
800 | int flags = fcntl(pipe.wfd, F_GETFL); | |
801 | if (flags == -1) { | |
802 | RETURN_NOT_OK(arrow::internal::FileClose(pipe.rfd)); | |
803 | RETURN_NOT_OK(arrow::internal::FileClose(pipe.wfd)); | |
804 | return arrow::internal::IOErrorFromErrno( | |
805 | errno, "Could not initialize self-pipe to wait for signals"); | |
806 | } | |
807 | flags |= O_NONBLOCK; | |
808 | if (fcntl(pipe.wfd, F_SETFL, flags) == -1) { | |
809 | RETURN_NOT_OK(arrow::internal::FileClose(pipe.rfd)); | |
810 | RETURN_NOT_OK(arrow::internal::FileClose(pipe.wfd)); | |
811 | return arrow::internal::IOErrorFromErrno( | |
812 | errno, "Could not initialize self-pipe to wait for signals"); | |
813 | } | |
814 | #endif | |
815 | self_pipe_ = pipe; | |
816 | handle_signals_ = std::thread(handler, self_pipe_.rfd); | |
817 | return self_pipe_.wfd; | |
818 | } | |
819 | ||
820 | Status Shutdown() { | |
821 | if (self_pipe_.rfd == 0) { | |
822 | // Already closed | |
823 | return Status::OK(); | |
824 | } | |
825 | if (PIPE_WRITE(self_pipe_.wfd, "0", 1) < 0 && errno != EAGAIN && | |
826 | errno != EWOULDBLOCK && errno != EINTR) { | |
827 | return arrow::internal::IOErrorFromErrno(errno, "Could not unblock signal thread"); | |
828 | } | |
829 | RETURN_NOT_OK(arrow::internal::FileClose(self_pipe_.rfd)); | |
830 | RETURN_NOT_OK(arrow::internal::FileClose(self_pipe_.wfd)); | |
831 | handle_signals_.join(); | |
832 | self_pipe_.rfd = 0; | |
833 | self_pipe_.wfd = 0; | |
834 | return Status::OK(); | |
835 | } | |
836 | ||
837 | ~ServerSignalHandler() { ARROW_CHECK_OK(Shutdown()); } | |
838 | ||
839 | private: | |
840 | arrow::internal::Pipe self_pipe_; | |
841 | std::thread handle_signals_; | |
842 | }; | |
843 | ||
844 | struct FlightServerBase::Impl { | |
845 | std::unique_ptr<FlightServiceImpl> service_; | |
846 | std::unique_ptr<grpc::Server> server_; | |
847 | int port_; | |
848 | ||
849 | // Signal handlers (on Windows) and the shutdown handler (other platforms) | |
850 | // are executed in a separate thread, so getting the current thread instance | |
851 | // wouldn't make sense. This means only a single instance can receive signals. | |
852 | static std::atomic<Impl*> running_instance_; | |
853 | // We'll use the self-pipe trick to notify a thread from the signal | |
854 | // handler. The thread will then shut down the gRPC server. | |
855 | int self_pipe_wfd_; | |
856 | ||
857 | // Signal handling | |
858 | std::vector<int> signals_; | |
859 | std::vector<SignalHandler> old_signal_handlers_; | |
860 | std::atomic<int> got_signal_; | |
861 | ||
862 | static void HandleSignal(int signum) { | |
863 | auto instance = running_instance_.load(); | |
864 | if (instance != nullptr) { | |
865 | instance->DoHandleSignal(signum); | |
866 | } | |
867 | } | |
868 | ||
869 | void DoHandleSignal(int signum) { | |
870 | got_signal_ = signum; | |
871 | int saved_errno = errno; | |
872 | // Ignore errors - pipe is nonblocking | |
873 | PIPE_WRITE(self_pipe_wfd_, "0", 1); | |
874 | errno = saved_errno; | |
875 | } | |
876 | ||
877 | static void WaitForSignals(int fd) { | |
878 | // Wait for a signal handler to write to the pipe | |
879 | int8_t buf[1]; | |
880 | while (PIPE_READ(fd, /*buf=*/buf, /*count=*/1) == -1) { | |
881 | if (errno == EINTR) { | |
882 | continue; | |
883 | } | |
884 | ARROW_CHECK_OK(arrow::internal::IOErrorFromErrno( | |
885 | errno, "Error while waiting for shutdown signal")); | |
886 | } | |
887 | auto instance = running_instance_.load(); | |
888 | if (instance != nullptr) { | |
889 | instance->server_->Shutdown(); | |
890 | } | |
891 | } | |
892 | }; | |
893 | ||
894 | std::atomic<FlightServerBase::Impl*> FlightServerBase::Impl::running_instance_; | |
895 | ||
896 | FlightServerOptions::FlightServerOptions(const Location& location_) | |
897 | : location(location_), | |
898 | auth_handler(nullptr), | |
899 | tls_certificates(), | |
900 | verify_client(false), | |
901 | root_certificates(), | |
902 | middleware(), | |
903 | builder_hook(nullptr) {} | |
904 | ||
905 | FlightServerOptions::~FlightServerOptions() = default; | |
906 | ||
907 | FlightServerBase::FlightServerBase() { impl_.reset(new Impl); } | |
908 | ||
909 | FlightServerBase::~FlightServerBase() {} | |
910 | ||
911 | Status FlightServerBase::Init(const FlightServerOptions& options) { | |
912 | impl_->service_.reset( | |
913 | new FlightServiceImpl(options.auth_handler, options.middleware, this)); | |
914 | ||
915 | grpc::ServerBuilder builder; | |
916 | // Allow uploading messages of any length | |
917 | builder.SetMaxReceiveMessageSize(-1); | |
918 | ||
919 | const Location& location = options.location; | |
920 | const std::string scheme = location.scheme(); | |
921 | if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) { | |
922 | std::stringstream address; | |
923 | address << arrow::internal::UriEncodeHost(location.uri_->host()) << ':' | |
924 | << location.uri_->port_text(); | |
925 | ||
926 | std::shared_ptr<grpc::ServerCredentials> creds; | |
927 | if (scheme == kSchemeGrpcTls) { | |
928 | grpc::SslServerCredentialsOptions ssl_options; | |
929 | for (const auto& pair : options.tls_certificates) { | |
930 | ssl_options.pem_key_cert_pairs.push_back({pair.pem_key, pair.pem_cert}); | |
931 | } | |
932 | if (options.verify_client) { | |
933 | ssl_options.client_certificate_request = | |
934 | GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY; | |
935 | } | |
936 | if (!options.root_certificates.empty()) { | |
937 | ssl_options.pem_root_certs = options.root_certificates; | |
938 | } | |
939 | creds = grpc::SslServerCredentials(ssl_options); | |
940 | } else { | |
941 | creds = grpc::InsecureServerCredentials(); | |
942 | } | |
943 | ||
944 | builder.AddListeningPort(address.str(), creds, &impl_->port_); | |
945 | } else if (scheme == kSchemeGrpcUnix) { | |
946 | std::stringstream address; | |
947 | address << "unix:" << location.uri_->path(); | |
948 | builder.AddListeningPort(address.str(), grpc::InsecureServerCredentials()); | |
949 | } else { | |
950 | return Status::NotImplemented("Scheme is not supported: " + scheme); | |
951 | } | |
952 | ||
953 | builder.RegisterService(impl_->service_.get()); | |
954 | ||
955 | // Disable SO_REUSEPORT - it makes debugging/testing a pain as | |
956 | // leftover processes can handle requests on accident | |
957 | builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0); | |
958 | ||
959 | if (options.builder_hook) { | |
960 | options.builder_hook(&builder); | |
961 | } | |
962 | ||
963 | impl_->server_ = builder.BuildAndStart(); | |
964 | if (!impl_->server_) { | |
965 | return Status::UnknownError("Server did not start properly"); | |
966 | } | |
967 | return Status::OK(); | |
968 | } | |
969 | ||
970 | int FlightServerBase::port() const { return impl_->port_; } | |
971 | ||
972 | Status FlightServerBase::SetShutdownOnSignals(const std::vector<int> sigs) { | |
973 | impl_->signals_ = sigs; | |
974 | impl_->old_signal_handlers_.clear(); | |
975 | return Status::OK(); | |
976 | } | |
977 | ||
978 | Status FlightServerBase::Serve() { | |
979 | if (!impl_->server_) { | |
980 | return Status::UnknownError("Server did not start properly"); | |
981 | } | |
982 | impl_->got_signal_ = 0; | |
983 | impl_->old_signal_handlers_.clear(); | |
984 | impl_->running_instance_ = impl_.get(); | |
985 | ||
986 | ServerSignalHandler signal_handler; | |
987 | ARROW_ASSIGN_OR_RAISE(impl_->self_pipe_wfd_, | |
988 | signal_handler.Init(&Impl::WaitForSignals)); | |
989 | // Override existing signal handlers with our own handler so as to stop the server. | |
990 | for (size_t i = 0; i < impl_->signals_.size(); ++i) { | |
991 | int signum = impl_->signals_[i]; | |
992 | SignalHandler new_handler(&Impl::HandleSignal), old_handler; | |
993 | ARROW_ASSIGN_OR_RAISE(old_handler, SetSignalHandler(signum, new_handler)); | |
994 | impl_->old_signal_handlers_.push_back(std::move(old_handler)); | |
995 | } | |
996 | ||
997 | impl_->server_->Wait(); | |
998 | impl_->running_instance_ = nullptr; | |
999 | ||
1000 | // Restore signal handlers | |
1001 | for (size_t i = 0; i < impl_->signals_.size(); ++i) { | |
1002 | RETURN_NOT_OK( | |
1003 | SetSignalHandler(impl_->signals_[i], impl_->old_signal_handlers_[i]).status()); | |
1004 | } | |
1005 | return Status::OK(); | |
1006 | } | |
1007 | ||
1008 | int FlightServerBase::GotSignal() const { return impl_->got_signal_; } | |
1009 | ||
1010 | Status FlightServerBase::Shutdown() { | |
1011 | auto server = impl_->server_.get(); | |
1012 | if (!server) { | |
1013 | return Status::Invalid("Shutdown() on uninitialized FlightServerBase"); | |
1014 | } | |
1015 | impl_->server_->Shutdown(); | |
1016 | return Status::OK(); | |
1017 | } | |
1018 | ||
1019 | Status FlightServerBase::Wait() { | |
1020 | impl_->server_->Wait(); | |
1021 | impl_->running_instance_ = nullptr; | |
1022 | return Status::OK(); | |
1023 | } | |
1024 | ||
1025 | Status FlightServerBase::ListFlights(const ServerCallContext& context, | |
1026 | const Criteria* criteria, | |
1027 | std::unique_ptr<FlightListing>* listings) { | |
1028 | return Status::NotImplemented("NYI"); | |
1029 | } | |
1030 | ||
1031 | Status FlightServerBase::GetFlightInfo(const ServerCallContext& context, | |
1032 | const FlightDescriptor& request, | |
1033 | std::unique_ptr<FlightInfo>* info) { | |
1034 | return Status::NotImplemented("NYI"); | |
1035 | } | |
1036 | ||
1037 | Status FlightServerBase::DoGet(const ServerCallContext& context, const Ticket& request, | |
1038 | std::unique_ptr<FlightDataStream>* data_stream) { | |
1039 | return Status::NotImplemented("NYI"); | |
1040 | } | |
1041 | ||
1042 | Status FlightServerBase::DoPut(const ServerCallContext& context, | |
1043 | std::unique_ptr<FlightMessageReader> reader, | |
1044 | std::unique_ptr<FlightMetadataWriter> writer) { | |
1045 | return Status::NotImplemented("NYI"); | |
1046 | } | |
1047 | ||
1048 | Status FlightServerBase::DoExchange(const ServerCallContext& context, | |
1049 | std::unique_ptr<FlightMessageReader> reader, | |
1050 | std::unique_ptr<FlightMessageWriter> writer) { | |
1051 | return Status::NotImplemented("NYI"); | |
1052 | } | |
1053 | ||
1054 | Status FlightServerBase::DoAction(const ServerCallContext& context, const Action& action, | |
1055 | std::unique_ptr<ResultStream>* result) { | |
1056 | return Status::NotImplemented("NYI"); | |
1057 | } | |
1058 | ||
1059 | Status FlightServerBase::ListActions(const ServerCallContext& context, | |
1060 | std::vector<ActionType>* actions) { | |
1061 | return Status::NotImplemented("NYI"); | |
1062 | } | |
1063 | ||
1064 | Status FlightServerBase::GetSchema(const ServerCallContext& context, | |
1065 | const FlightDescriptor& request, | |
1066 | std::unique_ptr<SchemaResult>* schema) { | |
1067 | return Status::NotImplemented("NYI"); | |
1068 | } | |
1069 | ||
1070 | // ---------------------------------------------------------------------- | |
1071 | // Implement RecordBatchStream | |
1072 | ||
1073 | class RecordBatchStream::RecordBatchStreamImpl { | |
1074 | public: | |
1075 | // Stages of the stream when producing payloads | |
1076 | enum class Stage { | |
1077 | NEW, // The stream has been created, but Next has not been called yet | |
1078 | DICTIONARY, // Dictionaries have been collected, and are being sent | |
1079 | RECORD_BATCH // Initial have been sent | |
1080 | }; | |
1081 | ||
1082 | RecordBatchStreamImpl(const std::shared_ptr<RecordBatchReader>& reader, | |
1083 | const ipc::IpcWriteOptions& options) | |
1084 | : reader_(reader), mapper_(*reader_->schema()), ipc_options_(options) {} | |
1085 | ||
1086 | std::shared_ptr<Schema> schema() { return reader_->schema(); } | |
1087 | ||
1088 | Status GetSchemaPayload(FlightPayload* payload) { | |
1089 | return ipc::GetSchemaPayload(*reader_->schema(), ipc_options_, mapper_, | |
1090 | &payload->ipc_message); | |
1091 | } | |
1092 | ||
1093 | Status Next(FlightPayload* payload) { | |
1094 | if (stage_ == Stage::NEW) { | |
1095 | RETURN_NOT_OK(reader_->ReadNext(¤t_batch_)); | |
1096 | if (!current_batch_) { | |
1097 | // Signal that iteration is over | |
1098 | payload->ipc_message.metadata = nullptr; | |
1099 | return Status::OK(); | |
1100 | } | |
1101 | ARROW_ASSIGN_OR_RAISE(dictionaries_, | |
1102 | ipc::CollectDictionaries(*current_batch_, mapper_)); | |
1103 | stage_ = Stage::DICTIONARY; | |
1104 | } | |
1105 | ||
1106 | if (stage_ == Stage::DICTIONARY) { | |
1107 | if (dictionary_index_ == static_cast<int>(dictionaries_.size())) { | |
1108 | stage_ = Stage::RECORD_BATCH; | |
1109 | return ipc::GetRecordBatchPayload(*current_batch_, ipc_options_, | |
1110 | &payload->ipc_message); | |
1111 | } else { | |
1112 | return GetNextDictionary(payload); | |
1113 | } | |
1114 | } | |
1115 | ||
1116 | RETURN_NOT_OK(reader_->ReadNext(¤t_batch_)); | |
1117 | ||
1118 | // TODO(wesm): Delta dictionaries | |
1119 | if (!current_batch_) { | |
1120 | // Signal that iteration is over | |
1121 | payload->ipc_message.metadata = nullptr; | |
1122 | return Status::OK(); | |
1123 | } else { | |
1124 | return ipc::GetRecordBatchPayload(*current_batch_, ipc_options_, | |
1125 | &payload->ipc_message); | |
1126 | } | |
1127 | } | |
1128 | ||
1129 | private: | |
1130 | Status GetNextDictionary(FlightPayload* payload) { | |
1131 | const auto& it = dictionaries_[dictionary_index_++]; | |
1132 | return ipc::GetDictionaryPayload(it.first, it.second, ipc_options_, | |
1133 | &payload->ipc_message); | |
1134 | } | |
1135 | ||
1136 | Stage stage_ = Stage::NEW; | |
1137 | std::shared_ptr<RecordBatchReader> reader_; | |
1138 | ipc::DictionaryFieldMapper mapper_; | |
1139 | ipc::IpcWriteOptions ipc_options_; | |
1140 | std::shared_ptr<RecordBatch> current_batch_; | |
1141 | std::vector<std::pair<int64_t, std::shared_ptr<Array>>> dictionaries_; | |
1142 | ||
1143 | // Index of next dictionary to send | |
1144 | int dictionary_index_ = 0; | |
1145 | }; | |
1146 | ||
1147 | FlightDataStream::~FlightDataStream() {} | |
1148 | ||
1149 | RecordBatchStream::RecordBatchStream(const std::shared_ptr<RecordBatchReader>& reader, | |
1150 | const ipc::IpcWriteOptions& options) { | |
1151 | impl_.reset(new RecordBatchStreamImpl(reader, options)); | |
1152 | } | |
1153 | ||
1154 | RecordBatchStream::~RecordBatchStream() {} | |
1155 | ||
1156 | std::shared_ptr<Schema> RecordBatchStream::schema() { return impl_->schema(); } | |
1157 | ||
1158 | Status RecordBatchStream::GetSchemaPayload(FlightPayload* payload) { | |
1159 | return impl_->GetSchemaPayload(payload); | |
1160 | } | |
1161 | ||
1162 | Status RecordBatchStream::Next(FlightPayload* payload) { return impl_->Next(payload); } | |
1163 | ||
1164 | } // namespace flight | |
1165 | } // namespace arrow |