]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/cpp/src/arrow/flight/types.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / flight / types.cc
CommitLineData
1d09f67e
TL
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#include "arrow/flight/types.h"
19
20#include <memory>
21#include <sstream>
22#include <utility>
23
24#include "arrow/buffer.h"
25#include "arrow/flight/serialization_internal.h"
26#include "arrow/io/memory.h"
27#include "arrow/ipc/dictionary.h"
28#include "arrow/ipc/reader.h"
29#include "arrow/status.h"
30#include "arrow/table.h"
31#include "arrow/util/uri.h"
32
33namespace arrow {
34namespace flight {
35
// URI scheme names. The Location::ForGrpc* helpers in this file build URIs
// that start with the same "grpc+tcp", "grpc+tls" and "grpc+unix" prefixes.
const char* kSchemeGrpc = "grpc";
const char* kSchemeGrpcTcp = "grpc+tcp";
const char* kSchemeGrpcUnix = "grpc+unix";
const char* kSchemeGrpcTls = "grpc+tls";

// Identifier returned by FlightStatusDetail::type_id() and compared against in
// FlightStatusDetail::UnwrapStatus().
const char* kErrorDetailTypeId = "flight::FlightStatusDetail";
42
43const char* FlightStatusDetail::type_id() const { return kErrorDetailTypeId; }
44
45std::string FlightStatusDetail::ToString() const { return CodeAsString(); }
46
47FlightStatusCode FlightStatusDetail::code() const { return code_; }
48
49std::string FlightStatusDetail::extra_info() const { return extra_info_; }
50
51void FlightStatusDetail::set_extra_info(std::string extra_info) {
52 extra_info_ = std::move(extra_info);
53}
54
55std::string FlightStatusDetail::CodeAsString() const {
56 switch (code()) {
57 case FlightStatusCode::Internal:
58 return "Internal";
59 case FlightStatusCode::TimedOut:
60 return "TimedOut";
61 case FlightStatusCode::Cancelled:
62 return "Cancelled";
63 case FlightStatusCode::Unauthenticated:
64 return "Unauthenticated";
65 case FlightStatusCode::Unauthorized:
66 return "Unauthorized";
67 case FlightStatusCode::Unavailable:
68 return "Unavailable";
69 default:
70 return "Unknown";
71 }
72}
73
74std::shared_ptr<FlightStatusDetail> FlightStatusDetail::UnwrapStatus(
75 const arrow::Status& status) {
76 if (!status.detail() || status.detail()->type_id() != kErrorDetailTypeId) {
77 return nullptr;
78 }
79 return std::dynamic_pointer_cast<FlightStatusDetail>(status.detail());
80}
81
82Status MakeFlightError(FlightStatusCode code, std::string message,
83 std::string extra_info) {
84 StatusCode arrow_code = arrow::StatusCode::IOError;
85 return arrow::Status(arrow_code, std::move(message),
86 std::make_shared<FlightStatusDetail>(code, std::move(extra_info)));
87}
88
89bool FlightDescriptor::Equals(const FlightDescriptor& other) const {
90 if (type != other.type) {
91 return false;
92 }
93 switch (type) {
94 case PATH:
95 return path == other.path;
96 case CMD:
97 return cmd == other.cmd;
98 default:
99 return false;
100 }
101}
102
103std::string FlightDescriptor::ToString() const {
104 std::stringstream ss;
105 ss << "FlightDescriptor<";
106 switch (type) {
107 case PATH: {
108 bool first = true;
109 ss << "path = '";
110 for (const auto& p : path) {
111 if (!first) {
112 ss << "/";
113 }
114 first = false;
115 ss << p;
116 }
117 ss << "'";
118 break;
119 }
120 case CMD:
121 ss << "cmd = '" << cmd << "'";
122 break;
123 default:
124 break;
125 }
126 ss << ">";
127 return ss.str();
128}
129
130Status FlightPayload::Validate() const {
131 static constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();
132 if (descriptor && descriptor->size() > kInt32Max) {
133 return Status::CapacityError("Descriptor size overflow (>= 2**31)");
134 }
135 if (app_metadata && app_metadata->size() > kInt32Max) {
136 return Status::CapacityError("app_metadata size overflow (>= 2**31)");
137 }
138 if (ipc_message.body_length > kInt32Max) {
139 return Status::Invalid("Cannot send record batches exceeding 2GiB yet");
140 }
141 return Status::OK();
142}
143
144Status SchemaResult::GetSchema(ipc::DictionaryMemo* dictionary_memo,
145 std::shared_ptr<Schema>* out) const {
146 io::BufferReader schema_reader(raw_schema_);
147 return ipc::ReadSchema(&schema_reader, dictionary_memo).Value(out);
148}
149
150Status FlightDescriptor::SerializeToString(std::string* out) const {
151 pb::FlightDescriptor pb_descriptor;
152 RETURN_NOT_OK(internal::ToProto(*this, &pb_descriptor));
153
154 if (!pb_descriptor.SerializeToString(out)) {
155 return Status::IOError("Serialized descriptor exceeded 2 GiB limit");
156 }
157 return Status::OK();
158}
159
160Status FlightDescriptor::Deserialize(const std::string& serialized,
161 FlightDescriptor* out) {
162 pb::FlightDescriptor pb_descriptor;
163 if (!pb_descriptor.ParseFromString(serialized)) {
164 return Status::Invalid("Not a valid descriptor");
165 }
166 return internal::FromProto(pb_descriptor, out);
167}
168
169bool Ticket::Equals(const Ticket& other) const { return ticket == other.ticket; }
170
171Status Ticket::SerializeToString(std::string* out) const {
172 pb::Ticket pb_ticket;
173 internal::ToProto(*this, &pb_ticket);
174
175 if (!pb_ticket.SerializeToString(out)) {
176 return Status::IOError("Serialized ticket exceeded 2 GiB limit");
177 }
178 return Status::OK();
179}
180
181Status Ticket::Deserialize(const std::string& serialized, Ticket* out) {
182 pb::Ticket pb_ticket;
183 if (!pb_ticket.ParseFromString(serialized)) {
184 return Status::Invalid("Not a valid ticket");
185 }
186 return internal::FromProto(pb_ticket, out);
187}
188
189arrow::Result<FlightInfo> FlightInfo::Make(const Schema& schema,
190 const FlightDescriptor& descriptor,
191 const std::vector<FlightEndpoint>& endpoints,
192 int64_t total_records, int64_t total_bytes) {
193 FlightInfo::Data data;
194 data.descriptor = descriptor;
195 data.endpoints = endpoints;
196 data.total_records = total_records;
197 data.total_bytes = total_bytes;
198 RETURN_NOT_OK(internal::SchemaToString(schema, &data.schema));
199 return FlightInfo(data);
200}
201
202Status FlightInfo::GetSchema(ipc::DictionaryMemo* dictionary_memo,
203 std::shared_ptr<Schema>* out) const {
204 if (reconstructed_schema_) {
205 *out = schema_;
206 return Status::OK();
207 }
208 io::BufferReader schema_reader(data_.schema);
209 RETURN_NOT_OK(ipc::ReadSchema(&schema_reader, dictionary_memo).Value(&schema_));
210 reconstructed_schema_ = true;
211 *out = schema_;
212 return Status::OK();
213}
214
215Status FlightInfo::SerializeToString(std::string* out) const {
216 pb::FlightInfo pb_info;
217 RETURN_NOT_OK(internal::ToProto(*this, &pb_info));
218
219 if (!pb_info.SerializeToString(out)) {
220 return Status::IOError("Serialized FlightInfo exceeded 2 GiB limit");
221 }
222 return Status::OK();
223}
224
225Status FlightInfo::Deserialize(const std::string& serialized,
226 std::unique_ptr<FlightInfo>* out) {
227 pb::FlightInfo pb_info;
228 if (!pb_info.ParseFromString(serialized)) {
229 return Status::Invalid("Not a valid FlightInfo");
230 }
231 FlightInfo::Data data;
232 RETURN_NOT_OK(internal::FromProto(pb_info, &data));
233 out->reset(new FlightInfo(data));
234 return Status::OK();
235}
236
237Location::Location() { uri_ = std::make_shared<arrow::internal::Uri>(); }
238
239Status Location::Parse(const std::string& uri_string, Location* location) {
240 return location->uri_->Parse(uri_string);
241}
242
243Status Location::ForGrpcTcp(const std::string& host, const int port, Location* location) {
244 std::stringstream uri_string;
245 uri_string << "grpc+tcp://" << host << ':' << port;
246 return Location::Parse(uri_string.str(), location);
247}
248
249Status Location::ForGrpcTls(const std::string& host, const int port, Location* location) {
250 std::stringstream uri_string;
251 uri_string << "grpc+tls://" << host << ':' << port;
252 return Location::Parse(uri_string.str(), location);
253}
254
255Status Location::ForGrpcUnix(const std::string& path, Location* location) {
256 std::stringstream uri_string;
257 uri_string << "grpc+unix://" << path;
258 return Location::Parse(uri_string.str(), location);
259}
260
261std::string Location::ToString() const { return uri_->ToString(); }
262std::string Location::scheme() const {
263 std::string scheme = uri_->scheme();
264 if (scheme.empty()) {
265 // Default to grpc+tcp
266 return "grpc+tcp";
267 }
268 return scheme;
269}
270
271bool Location::Equals(const Location& other) const {
272 return ToString() == other.ToString();
273}
274
275bool FlightEndpoint::Equals(const FlightEndpoint& other) const {
276 return ticket == other.ticket && locations == other.locations;
277}
278
279Status MetadataRecordBatchReader::ReadAll(
280 std::vector<std::shared_ptr<RecordBatch>>* batches) {
281 FlightStreamChunk chunk;
282
283 while (true) {
284 RETURN_NOT_OK(Next(&chunk));
285 if (!chunk.data) break;
286 batches->emplace_back(std::move(chunk.data));
287 }
288 return Status::OK();
289}
290
291Status MetadataRecordBatchReader::ReadAll(std::shared_ptr<Table>* table) {
292 std::vector<std::shared_ptr<RecordBatch>> batches;
293 RETURN_NOT_OK(ReadAll(&batches));
294 ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema());
295 return Table::FromRecordBatches(schema, std::move(batches)).Value(table);
296}
297
298Status MetadataRecordBatchWriter::Begin(const std::shared_ptr<Schema>& schema) {
299 return Begin(schema, ipc::IpcWriteOptions::Defaults());
300}
301
302namespace {
303class MetadataRecordBatchReaderAdapter : public RecordBatchReader {
304 public:
305 explicit MetadataRecordBatchReaderAdapter(
306 std::shared_ptr<Schema> schema, std::shared_ptr<MetadataRecordBatchReader> delegate)
307 : schema_(std::move(schema)), delegate_(std::move(delegate)) {}
308 std::shared_ptr<Schema> schema() const override { return schema_; }
309 Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
310 FlightStreamChunk next;
311 while (true) {
312 RETURN_NOT_OK(delegate_->Next(&next));
313 if (!next.data && !next.app_metadata) {
314 // EOS
315 *batch = nullptr;
316 return Status::OK();
317 } else if (next.data) {
318 *batch = std::move(next.data);
319 return Status::OK();
320 }
321 // Got metadata, but no data (which is valid) - read the next message
322 }
323 }
324
325 private:
326 std::shared_ptr<Schema> schema_;
327 std::shared_ptr<MetadataRecordBatchReader> delegate_;
328};
329}; // namespace
330
331arrow::Result<std::shared_ptr<RecordBatchReader>> MakeRecordBatchReader(
332 std::shared_ptr<MetadataRecordBatchReader> reader) {
333 ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
334 return std::make_shared<MetadataRecordBatchReaderAdapter>(std::move(schema),
335 std::move(reader));
336}
337
338SimpleFlightListing::SimpleFlightListing(const std::vector<FlightInfo>& flights)
339 : position_(0), flights_(flights) {}
340
341SimpleFlightListing::SimpleFlightListing(std::vector<FlightInfo>&& flights)
342 : position_(0), flights_(std::move(flights)) {}
343
344Status SimpleFlightListing::Next(std::unique_ptr<FlightInfo>* info) {
345 if (position_ >= static_cast<int>(flights_.size())) {
346 *info = nullptr;
347 return Status::OK();
348 }
349 *info = std::unique_ptr<FlightInfo>(new FlightInfo(std::move(flights_[position_++])));
350 return Status::OK();
351}
352
353SimpleResultStream::SimpleResultStream(std::vector<Result>&& results)
354 : results_(std::move(results)), position_(0) {}
355
356Status SimpleResultStream::Next(std::unique_ptr<Result>* result) {
357 if (position_ >= results_.size()) {
358 *result = nullptr;
359 return Status::OK();
360 }
361 *result = std::unique_ptr<Result>(new Result(std::move(results_[position_++])));
362 return Status::OK();
363}
364
365Status BasicAuth::Deserialize(const std::string& serialized, BasicAuth* out) {
366 pb::BasicAuth pb_result;
367 pb_result.ParseFromString(serialized);
368 return internal::FromProto(pb_result, out);
369}
370
371Status BasicAuth::Serialize(const BasicAuth& basic_auth, std::string* out) {
372 pb::BasicAuth pb_result;
373 RETURN_NOT_OK(internal::ToProto(basic_auth, &pb_result));
374 *out = pb_result.SerializeAsString();
375 return Status::OK();
376}
377} // namespace flight
378} // namespace arrow