]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | #include "arrow/flight/types.h" | |
19 | ||
20 | #include <memory> | |
21 | #include <sstream> | |
22 | #include <utility> | |
23 | ||
24 | #include "arrow/buffer.h" | |
25 | #include "arrow/flight/serialization_internal.h" | |
26 | #include "arrow/io/memory.h" | |
27 | #include "arrow/ipc/dictionary.h" | |
28 | #include "arrow/ipc/reader.h" | |
29 | #include "arrow/status.h" | |
30 | #include "arrow/table.h" | |
31 | #include "arrow/util/uri.h" | |
32 | ||
33 | namespace arrow { | |
34 | namespace flight { | |
35 | ||
// Scheme strings. "grpc+tcp", "grpc+tls" and "grpc+unix" match the URI
// prefixes built by Location::ForGrpcTcp/ForGrpcTls/ForGrpcUnix.
const char* kSchemeGrpc = "grpc";
const char* kSchemeGrpcTcp = "grpc+tcp";
const char* kSchemeGrpcUnix = "grpc+unix";
const char* kSchemeGrpcTls = "grpc+tls";

// Identifier returned by FlightStatusDetail::type_id() and compared against in
// FlightStatusDetail::UnwrapStatus().
const char* kErrorDetailTypeId = "flight::FlightStatusDetail";
42 | ||
43 | const char* FlightStatusDetail::type_id() const { return kErrorDetailTypeId; } | |
44 | ||
45 | std::string FlightStatusDetail::ToString() const { return CodeAsString(); } | |
46 | ||
47 | FlightStatusCode FlightStatusDetail::code() const { return code_; } | |
48 | ||
49 | std::string FlightStatusDetail::extra_info() const { return extra_info_; } | |
50 | ||
51 | void FlightStatusDetail::set_extra_info(std::string extra_info) { | |
52 | extra_info_ = std::move(extra_info); | |
53 | } | |
54 | ||
55 | std::string FlightStatusDetail::CodeAsString() const { | |
56 | switch (code()) { | |
57 | case FlightStatusCode::Internal: | |
58 | return "Internal"; | |
59 | case FlightStatusCode::TimedOut: | |
60 | return "TimedOut"; | |
61 | case FlightStatusCode::Cancelled: | |
62 | return "Cancelled"; | |
63 | case FlightStatusCode::Unauthenticated: | |
64 | return "Unauthenticated"; | |
65 | case FlightStatusCode::Unauthorized: | |
66 | return "Unauthorized"; | |
67 | case FlightStatusCode::Unavailable: | |
68 | return "Unavailable"; | |
69 | default: | |
70 | return "Unknown"; | |
71 | } | |
72 | } | |
73 | ||
74 | std::shared_ptr<FlightStatusDetail> FlightStatusDetail::UnwrapStatus( | |
75 | const arrow::Status& status) { | |
76 | if (!status.detail() || status.detail()->type_id() != kErrorDetailTypeId) { | |
77 | return nullptr; | |
78 | } | |
79 | return std::dynamic_pointer_cast<FlightStatusDetail>(status.detail()); | |
80 | } | |
81 | ||
82 | Status MakeFlightError(FlightStatusCode code, std::string message, | |
83 | std::string extra_info) { | |
84 | StatusCode arrow_code = arrow::StatusCode::IOError; | |
85 | return arrow::Status(arrow_code, std::move(message), | |
86 | std::make_shared<FlightStatusDetail>(code, std::move(extra_info))); | |
87 | } | |
88 | ||
89 | bool FlightDescriptor::Equals(const FlightDescriptor& other) const { | |
90 | if (type != other.type) { | |
91 | return false; | |
92 | } | |
93 | switch (type) { | |
94 | case PATH: | |
95 | return path == other.path; | |
96 | case CMD: | |
97 | return cmd == other.cmd; | |
98 | default: | |
99 | return false; | |
100 | } | |
101 | } | |
102 | ||
103 | std::string FlightDescriptor::ToString() const { | |
104 | std::stringstream ss; | |
105 | ss << "FlightDescriptor<"; | |
106 | switch (type) { | |
107 | case PATH: { | |
108 | bool first = true; | |
109 | ss << "path = '"; | |
110 | for (const auto& p : path) { | |
111 | if (!first) { | |
112 | ss << "/"; | |
113 | } | |
114 | first = false; | |
115 | ss << p; | |
116 | } | |
117 | ss << "'"; | |
118 | break; | |
119 | } | |
120 | case CMD: | |
121 | ss << "cmd = '" << cmd << "'"; | |
122 | break; | |
123 | default: | |
124 | break; | |
125 | } | |
126 | ss << ">"; | |
127 | return ss.str(); | |
128 | } | |
129 | ||
130 | Status FlightPayload::Validate() const { | |
131 | static constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max(); | |
132 | if (descriptor && descriptor->size() > kInt32Max) { | |
133 | return Status::CapacityError("Descriptor size overflow (>= 2**31)"); | |
134 | } | |
135 | if (app_metadata && app_metadata->size() > kInt32Max) { | |
136 | return Status::CapacityError("app_metadata size overflow (>= 2**31)"); | |
137 | } | |
138 | if (ipc_message.body_length > kInt32Max) { | |
139 | return Status::Invalid("Cannot send record batches exceeding 2GiB yet"); | |
140 | } | |
141 | return Status::OK(); | |
142 | } | |
143 | ||
144 | Status SchemaResult::GetSchema(ipc::DictionaryMemo* dictionary_memo, | |
145 | std::shared_ptr<Schema>* out) const { | |
146 | io::BufferReader schema_reader(raw_schema_); | |
147 | return ipc::ReadSchema(&schema_reader, dictionary_memo).Value(out); | |
148 | } | |
149 | ||
150 | Status FlightDescriptor::SerializeToString(std::string* out) const { | |
151 | pb::FlightDescriptor pb_descriptor; | |
152 | RETURN_NOT_OK(internal::ToProto(*this, &pb_descriptor)); | |
153 | ||
154 | if (!pb_descriptor.SerializeToString(out)) { | |
155 | return Status::IOError("Serialized descriptor exceeded 2 GiB limit"); | |
156 | } | |
157 | return Status::OK(); | |
158 | } | |
159 | ||
160 | Status FlightDescriptor::Deserialize(const std::string& serialized, | |
161 | FlightDescriptor* out) { | |
162 | pb::FlightDescriptor pb_descriptor; | |
163 | if (!pb_descriptor.ParseFromString(serialized)) { | |
164 | return Status::Invalid("Not a valid descriptor"); | |
165 | } | |
166 | return internal::FromProto(pb_descriptor, out); | |
167 | } | |
168 | ||
169 | bool Ticket::Equals(const Ticket& other) const { return ticket == other.ticket; } | |
170 | ||
171 | Status Ticket::SerializeToString(std::string* out) const { | |
172 | pb::Ticket pb_ticket; | |
173 | internal::ToProto(*this, &pb_ticket); | |
174 | ||
175 | if (!pb_ticket.SerializeToString(out)) { | |
176 | return Status::IOError("Serialized ticket exceeded 2 GiB limit"); | |
177 | } | |
178 | return Status::OK(); | |
179 | } | |
180 | ||
181 | Status Ticket::Deserialize(const std::string& serialized, Ticket* out) { | |
182 | pb::Ticket pb_ticket; | |
183 | if (!pb_ticket.ParseFromString(serialized)) { | |
184 | return Status::Invalid("Not a valid ticket"); | |
185 | } | |
186 | return internal::FromProto(pb_ticket, out); | |
187 | } | |
188 | ||
189 | arrow::Result<FlightInfo> FlightInfo::Make(const Schema& schema, | |
190 | const FlightDescriptor& descriptor, | |
191 | const std::vector<FlightEndpoint>& endpoints, | |
192 | int64_t total_records, int64_t total_bytes) { | |
193 | FlightInfo::Data data; | |
194 | data.descriptor = descriptor; | |
195 | data.endpoints = endpoints; | |
196 | data.total_records = total_records; | |
197 | data.total_bytes = total_bytes; | |
198 | RETURN_NOT_OK(internal::SchemaToString(schema, &data.schema)); | |
199 | return FlightInfo(data); | |
200 | } | |
201 | ||
202 | Status FlightInfo::GetSchema(ipc::DictionaryMemo* dictionary_memo, | |
203 | std::shared_ptr<Schema>* out) const { | |
204 | if (reconstructed_schema_) { | |
205 | *out = schema_; | |
206 | return Status::OK(); | |
207 | } | |
208 | io::BufferReader schema_reader(data_.schema); | |
209 | RETURN_NOT_OK(ipc::ReadSchema(&schema_reader, dictionary_memo).Value(&schema_)); | |
210 | reconstructed_schema_ = true; | |
211 | *out = schema_; | |
212 | return Status::OK(); | |
213 | } | |
214 | ||
215 | Status FlightInfo::SerializeToString(std::string* out) const { | |
216 | pb::FlightInfo pb_info; | |
217 | RETURN_NOT_OK(internal::ToProto(*this, &pb_info)); | |
218 | ||
219 | if (!pb_info.SerializeToString(out)) { | |
220 | return Status::IOError("Serialized FlightInfo exceeded 2 GiB limit"); | |
221 | } | |
222 | return Status::OK(); | |
223 | } | |
224 | ||
225 | Status FlightInfo::Deserialize(const std::string& serialized, | |
226 | std::unique_ptr<FlightInfo>* out) { | |
227 | pb::FlightInfo pb_info; | |
228 | if (!pb_info.ParseFromString(serialized)) { | |
229 | return Status::Invalid("Not a valid FlightInfo"); | |
230 | } | |
231 | FlightInfo::Data data; | |
232 | RETURN_NOT_OK(internal::FromProto(pb_info, &data)); | |
233 | out->reset(new FlightInfo(data)); | |
234 | return Status::OK(); | |
235 | } | |
236 | ||
237 | Location::Location() { uri_ = std::make_shared<arrow::internal::Uri>(); } | |
238 | ||
239 | Status Location::Parse(const std::string& uri_string, Location* location) { | |
240 | return location->uri_->Parse(uri_string); | |
241 | } | |
242 | ||
243 | Status Location::ForGrpcTcp(const std::string& host, const int port, Location* location) { | |
244 | std::stringstream uri_string; | |
245 | uri_string << "grpc+tcp://" << host << ':' << port; | |
246 | return Location::Parse(uri_string.str(), location); | |
247 | } | |
248 | ||
249 | Status Location::ForGrpcTls(const std::string& host, const int port, Location* location) { | |
250 | std::stringstream uri_string; | |
251 | uri_string << "grpc+tls://" << host << ':' << port; | |
252 | return Location::Parse(uri_string.str(), location); | |
253 | } | |
254 | ||
255 | Status Location::ForGrpcUnix(const std::string& path, Location* location) { | |
256 | std::stringstream uri_string; | |
257 | uri_string << "grpc+unix://" << path; | |
258 | return Location::Parse(uri_string.str(), location); | |
259 | } | |
260 | ||
261 | std::string Location::ToString() const { return uri_->ToString(); } | |
262 | std::string Location::scheme() const { | |
263 | std::string scheme = uri_->scheme(); | |
264 | if (scheme.empty()) { | |
265 | // Default to grpc+tcp | |
266 | return "grpc+tcp"; | |
267 | } | |
268 | return scheme; | |
269 | } | |
270 | ||
271 | bool Location::Equals(const Location& other) const { | |
272 | return ToString() == other.ToString(); | |
273 | } | |
274 | ||
275 | bool FlightEndpoint::Equals(const FlightEndpoint& other) const { | |
276 | return ticket == other.ticket && locations == other.locations; | |
277 | } | |
278 | ||
279 | Status MetadataRecordBatchReader::ReadAll( | |
280 | std::vector<std::shared_ptr<RecordBatch>>* batches) { | |
281 | FlightStreamChunk chunk; | |
282 | ||
283 | while (true) { | |
284 | RETURN_NOT_OK(Next(&chunk)); | |
285 | if (!chunk.data) break; | |
286 | batches->emplace_back(std::move(chunk.data)); | |
287 | } | |
288 | return Status::OK(); | |
289 | } | |
290 | ||
291 | Status MetadataRecordBatchReader::ReadAll(std::shared_ptr<Table>* table) { | |
292 | std::vector<std::shared_ptr<RecordBatch>> batches; | |
293 | RETURN_NOT_OK(ReadAll(&batches)); | |
294 | ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema()); | |
295 | return Table::FromRecordBatches(schema, std::move(batches)).Value(table); | |
296 | } | |
297 | ||
298 | Status MetadataRecordBatchWriter::Begin(const std::shared_ptr<Schema>& schema) { | |
299 | return Begin(schema, ipc::IpcWriteOptions::Defaults()); | |
300 | } | |
301 | ||
302 | namespace { | |
303 | class MetadataRecordBatchReaderAdapter : public RecordBatchReader { | |
304 | public: | |
305 | explicit MetadataRecordBatchReaderAdapter( | |
306 | std::shared_ptr<Schema> schema, std::shared_ptr<MetadataRecordBatchReader> delegate) | |
307 | : schema_(std::move(schema)), delegate_(std::move(delegate)) {} | |
308 | std::shared_ptr<Schema> schema() const override { return schema_; } | |
309 | Status ReadNext(std::shared_ptr<RecordBatch>* batch) override { | |
310 | FlightStreamChunk next; | |
311 | while (true) { | |
312 | RETURN_NOT_OK(delegate_->Next(&next)); | |
313 | if (!next.data && !next.app_metadata) { | |
314 | // EOS | |
315 | *batch = nullptr; | |
316 | return Status::OK(); | |
317 | } else if (next.data) { | |
318 | *batch = std::move(next.data); | |
319 | return Status::OK(); | |
320 | } | |
321 | // Got metadata, but no data (which is valid) - read the next message | |
322 | } | |
323 | } | |
324 | ||
325 | private: | |
326 | std::shared_ptr<Schema> schema_; | |
327 | std::shared_ptr<MetadataRecordBatchReader> delegate_; | |
328 | }; | |
329 | }; // namespace | |
330 | ||
331 | arrow::Result<std::shared_ptr<RecordBatchReader>> MakeRecordBatchReader( | |
332 | std::shared_ptr<MetadataRecordBatchReader> reader) { | |
333 | ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); | |
334 | return std::make_shared<MetadataRecordBatchReaderAdapter>(std::move(schema), | |
335 | std::move(reader)); | |
336 | } | |
337 | ||
338 | SimpleFlightListing::SimpleFlightListing(const std::vector<FlightInfo>& flights) | |
339 | : position_(0), flights_(flights) {} | |
340 | ||
341 | SimpleFlightListing::SimpleFlightListing(std::vector<FlightInfo>&& flights) | |
342 | : position_(0), flights_(std::move(flights)) {} | |
343 | ||
344 | Status SimpleFlightListing::Next(std::unique_ptr<FlightInfo>* info) { | |
345 | if (position_ >= static_cast<int>(flights_.size())) { | |
346 | *info = nullptr; | |
347 | return Status::OK(); | |
348 | } | |
349 | *info = std::unique_ptr<FlightInfo>(new FlightInfo(std::move(flights_[position_++]))); | |
350 | return Status::OK(); | |
351 | } | |
352 | ||
353 | SimpleResultStream::SimpleResultStream(std::vector<Result>&& results) | |
354 | : results_(std::move(results)), position_(0) {} | |
355 | ||
356 | Status SimpleResultStream::Next(std::unique_ptr<Result>* result) { | |
357 | if (position_ >= results_.size()) { | |
358 | *result = nullptr; | |
359 | return Status::OK(); | |
360 | } | |
361 | *result = std::unique_ptr<Result>(new Result(std::move(results_[position_++]))); | |
362 | return Status::OK(); | |
363 | } | |
364 | ||
365 | Status BasicAuth::Deserialize(const std::string& serialized, BasicAuth* out) { | |
366 | pb::BasicAuth pb_result; | |
367 | pb_result.ParseFromString(serialized); | |
368 | return internal::FromProto(pb_result, out); | |
369 | } | |
370 | ||
371 | Status BasicAuth::Serialize(const BasicAuth& basic_auth, std::string* out) { | |
372 | pb::BasicAuth pb_result; | |
373 | RETURN_NOT_OK(internal::ToProto(basic_auth, &pb_result)); | |
374 | *out = pb_result.SerializeAsString(); | |
375 | return Status::OK(); | |
376 | } | |
377 | } // namespace flight | |
378 | } // namespace arrow |