]> git.proxmox.com Git - ceph.git/blame - ceph/src/arrow/cpp/src/arrow/flight/perf_server.cc
import quincy 17.2.0
[ceph.git] / ceph / src / arrow / cpp / src / arrow / flight / perf_server.cc
CommitLineData
1d09f67e
TL
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18// Performance server for benchmarking purposes
19
20#include <signal.h>
21#include <cstdint>
22#include <fstream>
23#include <iostream>
24#include <memory>
25#include <string>
26
27#include <gflags/gflags.h>
28
29#include "arrow/array.h"
30#include "arrow/io/test_common.h"
31#include "arrow/ipc/writer.h"
32#include "arrow/record_batch.h"
33#include "arrow/testing/random.h"
34#include "arrow/testing/util.h"
35#include "arrow/util/logging.h"
36
37#include "arrow/flight/api.h"
38#include "arrow/flight/internal.h"
39#include "arrow/flight/perf.pb.h"
40#include "arrow/flight/test_util.h"
41
42DEFINE_string(server_host, "localhost", "Host where the server is running on");
43DEFINE_int32(port, 31337, "Server port to listen on");
44DEFINE_string(server_unix, "", "Unix socket path where the server is running on");
45DEFINE_string(cert_file, "", "Path to TLS certificate");
46DEFINE_string(key_file, "", "Path to TLS private key");
47
48namespace perf = arrow::flight::perf;
49namespace proto = arrow::flight::protocol;
50
51namespace arrow {
52namespace flight {
53
// Evaluate EXPR (a protobuf parse call yielding a bool) and return
// Status::Invalid from the enclosing function when it evaluates to false.
//
// EXPR is parenthesized before negation so that a compound argument such as
// `a || b` is negated as a whole rather than as `(!a) || b`.
#define CHECK_PARSE(EXPR)                              \
  do {                                                 \
    if (!(EXPR)) {                                     \
      return Status::Invalid("cannot parse protobuf"); \
    }                                                  \
  } while (0)
60
61// Create record batches with a unique "a" column so we can verify on the
62// client side that the results are correct
63class PerfDataStream : public FlightDataStream {
64 public:
65 PerfDataStream(bool verify, const int64_t start, const int64_t total_records,
66 const std::shared_ptr<Schema>& schema, const ArrayVector& arrays)
67 : start_(start),
68 verify_(verify),
69 batch_length_(arrays[0]->length()),
70 total_records_(total_records),
71 records_sent_(0),
72 schema_(schema),
73 mapper_(*schema),
74 arrays_(arrays) {
75 batch_ = RecordBatch::Make(schema, batch_length_, arrays_);
76 }
77
78 std::shared_ptr<Schema> schema() override { return schema_; }
79
80 Status GetSchemaPayload(FlightPayload* payload) override {
81 return ipc::GetSchemaPayload(*schema_, ipc_options_, mapper_, &payload->ipc_message);
82 }
83
84 Status Next(FlightPayload* payload) override {
85 if (records_sent_ >= total_records_) {
86 // Signal that iteration is over
87 payload->ipc_message.metadata = nullptr;
88 return Status::OK();
89 }
90
91 if (verify_) {
92 // mutate first array
93 auto data =
94 reinterpret_cast<int64_t*>(arrays_[0]->data()->buffers[1]->mutable_data());
95 for (int64_t i = 0; i < batch_length_; ++i) {
96 data[i] = start_ + records_sent_ + i;
97 }
98 }
99
100 auto batch = batch_;
101
102 // Last partial batch
103 if (records_sent_ + batch_length_ > total_records_) {
104 batch = batch_->Slice(0, total_records_ - records_sent_);
105 records_sent_ += total_records_ - records_sent_;
106 } else {
107 records_sent_ += batch_length_;
108 }
109 return ipc::GetRecordBatchPayload(*batch, ipc_options_, &payload->ipc_message);
110 }
111
112 private:
113 const int64_t start_;
114 bool verify_;
115 const int64_t batch_length_;
116 const int64_t total_records_;
117 int64_t records_sent_;
118 std::shared_ptr<Schema> schema_;
119 ipc::DictionaryFieldMapper mapper_;
120 ipc::IpcWriteOptions ipc_options_;
121 std::shared_ptr<RecordBatch> batch_;
122 ArrayVector arrays_;
123};
124
125Status GetPerfBatches(const perf::Token& token, const std::shared_ptr<Schema>& schema,
126 bool use_verifier, std::unique_ptr<FlightDataStream>* data_stream) {
127 std::shared_ptr<ResizableBuffer> buffer;
128 std::vector<std::shared_ptr<Array>> arrays;
129
130 const int32_t length = token.definition().records_per_batch();
131 const int32_t ncolumns = 4;
132 for (int i = 0; i < ncolumns; ++i) {
133 RETURN_NOT_OK(MakeRandomByteBuffer(length * sizeof(int64_t), default_memory_pool(),
134 &buffer, static_cast<int32_t>(i) /* seed */));
135 arrays.push_back(std::make_shared<Int64Array>(length, buffer));
136 RETURN_NOT_OK(arrays.back()->Validate());
137 }
138
139 *data_stream = std::unique_ptr<FlightDataStream>(
140 new PerfDataStream(use_verifier, token.start(),
141 token.definition().records_per_stream(), schema, arrays));
142 return Status::OK();
143}
144
145class FlightPerfServer : public FlightServerBase {
146 public:
147 FlightPerfServer() : location_() {
148 perf_schema_ = schema({field("a", int64()), field("b", int64()), field("c", int64()),
149 field("d", int64())});
150 }
151
152 void SetLocation(Location location) { location_ = location; }
153
154 Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
155 std::unique_ptr<FlightInfo>* info) override {
156 perf::Perf perf_request;
157 CHECK_PARSE(perf_request.ParseFromString(request.cmd));
158
159 perf::Token token;
160 token.mutable_definition()->CopyFrom(perf_request);
161
162 std::vector<FlightEndpoint> endpoints;
163 Ticket tmp_ticket;
164 for (int64_t i = 0; i < perf_request.stream_count(); ++i) {
165 token.set_start(i * perf_request.records_per_stream());
166 token.set_end((i + 1) * perf_request.records_per_stream());
167
168 (void)token.SerializeToString(&tmp_ticket.ticket);
169
170 // All endpoints same location for now
171 endpoints.push_back(FlightEndpoint{tmp_ticket, {location_}});
172 }
173
174 uint64_t total_records =
175 perf_request.stream_count() * perf_request.records_per_stream();
176
177 FlightInfo::Data data;
178 RETURN_NOT_OK(
179 MakeFlightInfo(*perf_schema_, request, endpoints, total_records, -1, &data));
180 *info = std::unique_ptr<FlightInfo>(new FlightInfo(data));
181 return Status::OK();
182 }
183
184 Status DoGet(const ServerCallContext& context, const Ticket& request,
185 std::unique_ptr<FlightDataStream>* data_stream) override {
186 perf::Token token;
187 CHECK_PARSE(token.ParseFromString(request.ticket));
188 return GetPerfBatches(token, perf_schema_, false, data_stream);
189 }
190
191 Status DoPut(const ServerCallContext& context,
192 std::unique_ptr<FlightMessageReader> reader,
193 std::unique_ptr<FlightMetadataWriter> writer) override {
194 FlightStreamChunk chunk;
195 while (true) {
196 RETURN_NOT_OK(reader->Next(&chunk));
197 if (!chunk.data) break;
198 if (chunk.app_metadata) {
199 RETURN_NOT_OK(writer->WriteMetadata(*chunk.app_metadata));
200 }
201 }
202 return Status::OK();
203 }
204
205 Status DoAction(const ServerCallContext& context, const Action& action,
206 std::unique_ptr<ResultStream>* result) override {
207 if (action.type == "ping") {
208 std::shared_ptr<Buffer> buf = Buffer::FromString("ok");
209 *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
210 return Status::OK();
211 }
212 return Status::NotImplemented(action.type);
213 }
214
215 private:
216 Location location_;
217 std::shared_ptr<Schema> perf_schema_;
218};
219
220} // namespace flight
221} // namespace arrow
222
223std::unique_ptr<arrow::flight::FlightPerfServer> g_server;
224
225void Shutdown(int signal) {
226 if (g_server != nullptr) {
227 ARROW_CHECK_OK(g_server->Shutdown());
228 }
229}
230
231int main(int argc, char** argv) {
232 gflags::ParseCommandLineFlags(&argc, &argv, true);
233
234 g_server.reset(new arrow::flight::FlightPerfServer);
235
236 arrow::flight::Location bind_location;
237 arrow::flight::Location connect_location;
238 if (FLAGS_server_unix.empty()) {
239 if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
240 if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
241 ARROW_CHECK_OK(
242 arrow::flight::Location::ForGrpcTls("0.0.0.0", FLAGS_port, &bind_location));
243 ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host, FLAGS_port,
244 &connect_location));
245 } else {
246 std::cerr << "If providing TLS cert/key, must provide both" << std::endl;
247 return 1;
248 }
249 } else {
250 ARROW_CHECK_OK(
251 arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &bind_location));
252 ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_port,
253 &connect_location));
254 }
255 } else {
256 ARROW_CHECK_OK(
257 arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &bind_location));
258 ARROW_CHECK_OK(
259 arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &connect_location));
260 }
261 arrow::flight::FlightServerOptions options(bind_location);
262 if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
263 std::cout << "Enabling TLS" << std::endl;
264 std::ifstream cert_file(FLAGS_cert_file);
265 std::string cert((std::istreambuf_iterator<char>(cert_file)),
266 (std::istreambuf_iterator<char>()));
267 std::ifstream key_file(FLAGS_key_file);
268 std::string key((std::istreambuf_iterator<char>(key_file)),
269 (std::istreambuf_iterator<char>()));
270 options.tls_certificates.push_back(arrow::flight::CertKeyPair{cert, key});
271 }
272
273 ARROW_CHECK_OK(g_server->Init(options));
274 // Exit with a clean error code (0) on SIGTERM
275 ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM}));
276 if (FLAGS_server_unix.empty()) {
277 std::cout << "Server host: " << FLAGS_server_host << std::endl;
278 std::cout << "Server port: " << FLAGS_port << std::endl;
279 } else {
280 std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl;
281 }
282 g_server->SetLocation(connect_location);
283 ARROW_CHECK_OK(g_server->Serve());
284 return 0;
285}