]>
Commit | Line | Data |
---|---|---|
1d09f67e TL |
1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership. The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | // Performance server for benchmarking purposes | |
19 | ||
20 | #include <signal.h> | |
21 | #include <cstdint> | |
22 | #include <fstream> | |
23 | #include <iostream> | |
24 | #include <memory> | |
25 | #include <string> | |
26 | ||
27 | #include <gflags/gflags.h> | |
28 | ||
29 | #include "arrow/array.h" | |
30 | #include "arrow/io/test_common.h" | |
31 | #include "arrow/ipc/writer.h" | |
32 | #include "arrow/record_batch.h" | |
33 | #include "arrow/testing/random.h" | |
34 | #include "arrow/testing/util.h" | |
35 | #include "arrow/util/logging.h" | |
36 | ||
37 | #include "arrow/flight/api.h" | |
38 | #include "arrow/flight/internal.h" | |
39 | #include "arrow/flight/perf.pb.h" | |
40 | #include "arrow/flight/test_util.h" | |
41 | ||
42 | DEFINE_string(server_host, "localhost", "Host where the server is running on"); | |
43 | DEFINE_int32(port, 31337, "Server port to listen on"); | |
44 | DEFINE_string(server_unix, "", "Unix socket path where the server is running on"); | |
45 | DEFINE_string(cert_file, "", "Path to TLS certificate"); | |
46 | DEFINE_string(key_file, "", "Path to TLS private key"); | |
47 | ||
48 | namespace perf = arrow::flight::perf; | |
49 | namespace proto = arrow::flight::protocol; | |
50 | ||
51 | namespace arrow { | |
52 | namespace flight { | |
53 | ||
// Evaluates EXPR (used here with protobuf ParseFromString calls, which return
// bool) and returns Status::Invalid("cannot parse protobuf") from the
// enclosing function when it yields false.
// EXPR is parenthesized so that a compound argument such as `a && b` is
// negated as a whole, not just its first operand.
#define CHECK_PARSE(EXPR)                              \
  do {                                                 \
    if (!(EXPR)) {                                     \
      return Status::Invalid("cannot parse protobuf"); \
    }                                                  \
  } while (0)
60 | ||
61 | // Create record batches with a unique "a" column so we can verify on the | |
62 | // client side that the results are correct | |
63 | class PerfDataStream : public FlightDataStream { | |
64 | public: | |
65 | PerfDataStream(bool verify, const int64_t start, const int64_t total_records, | |
66 | const std::shared_ptr<Schema>& schema, const ArrayVector& arrays) | |
67 | : start_(start), | |
68 | verify_(verify), | |
69 | batch_length_(arrays[0]->length()), | |
70 | total_records_(total_records), | |
71 | records_sent_(0), | |
72 | schema_(schema), | |
73 | mapper_(*schema), | |
74 | arrays_(arrays) { | |
75 | batch_ = RecordBatch::Make(schema, batch_length_, arrays_); | |
76 | } | |
77 | ||
78 | std::shared_ptr<Schema> schema() override { return schema_; } | |
79 | ||
80 | Status GetSchemaPayload(FlightPayload* payload) override { | |
81 | return ipc::GetSchemaPayload(*schema_, ipc_options_, mapper_, &payload->ipc_message); | |
82 | } | |
83 | ||
84 | Status Next(FlightPayload* payload) override { | |
85 | if (records_sent_ >= total_records_) { | |
86 | // Signal that iteration is over | |
87 | payload->ipc_message.metadata = nullptr; | |
88 | return Status::OK(); | |
89 | } | |
90 | ||
91 | if (verify_) { | |
92 | // mutate first array | |
93 | auto data = | |
94 | reinterpret_cast<int64_t*>(arrays_[0]->data()->buffers[1]->mutable_data()); | |
95 | for (int64_t i = 0; i < batch_length_; ++i) { | |
96 | data[i] = start_ + records_sent_ + i; | |
97 | } | |
98 | } | |
99 | ||
100 | auto batch = batch_; | |
101 | ||
102 | // Last partial batch | |
103 | if (records_sent_ + batch_length_ > total_records_) { | |
104 | batch = batch_->Slice(0, total_records_ - records_sent_); | |
105 | records_sent_ += total_records_ - records_sent_; | |
106 | } else { | |
107 | records_sent_ += batch_length_; | |
108 | } | |
109 | return ipc::GetRecordBatchPayload(*batch, ipc_options_, &payload->ipc_message); | |
110 | } | |
111 | ||
112 | private: | |
113 | const int64_t start_; | |
114 | bool verify_; | |
115 | const int64_t batch_length_; | |
116 | const int64_t total_records_; | |
117 | int64_t records_sent_; | |
118 | std::shared_ptr<Schema> schema_; | |
119 | ipc::DictionaryFieldMapper mapper_; | |
120 | ipc::IpcWriteOptions ipc_options_; | |
121 | std::shared_ptr<RecordBatch> batch_; | |
122 | ArrayVector arrays_; | |
123 | }; | |
124 | ||
125 | Status GetPerfBatches(const perf::Token& token, const std::shared_ptr<Schema>& schema, | |
126 | bool use_verifier, std::unique_ptr<FlightDataStream>* data_stream) { | |
127 | std::shared_ptr<ResizableBuffer> buffer; | |
128 | std::vector<std::shared_ptr<Array>> arrays; | |
129 | ||
130 | const int32_t length = token.definition().records_per_batch(); | |
131 | const int32_t ncolumns = 4; | |
132 | for (int i = 0; i < ncolumns; ++i) { | |
133 | RETURN_NOT_OK(MakeRandomByteBuffer(length * sizeof(int64_t), default_memory_pool(), | |
134 | &buffer, static_cast<int32_t>(i) /* seed */)); | |
135 | arrays.push_back(std::make_shared<Int64Array>(length, buffer)); | |
136 | RETURN_NOT_OK(arrays.back()->Validate()); | |
137 | } | |
138 | ||
139 | *data_stream = std::unique_ptr<FlightDataStream>( | |
140 | new PerfDataStream(use_verifier, token.start(), | |
141 | token.definition().records_per_stream(), schema, arrays)); | |
142 | return Status::OK(); | |
143 | } | |
144 | ||
145 | class FlightPerfServer : public FlightServerBase { | |
146 | public: | |
147 | FlightPerfServer() : location_() { | |
148 | perf_schema_ = schema({field("a", int64()), field("b", int64()), field("c", int64()), | |
149 | field("d", int64())}); | |
150 | } | |
151 | ||
152 | void SetLocation(Location location) { location_ = location; } | |
153 | ||
154 | Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, | |
155 | std::unique_ptr<FlightInfo>* info) override { | |
156 | perf::Perf perf_request; | |
157 | CHECK_PARSE(perf_request.ParseFromString(request.cmd)); | |
158 | ||
159 | perf::Token token; | |
160 | token.mutable_definition()->CopyFrom(perf_request); | |
161 | ||
162 | std::vector<FlightEndpoint> endpoints; | |
163 | Ticket tmp_ticket; | |
164 | for (int64_t i = 0; i < perf_request.stream_count(); ++i) { | |
165 | token.set_start(i * perf_request.records_per_stream()); | |
166 | token.set_end((i + 1) * perf_request.records_per_stream()); | |
167 | ||
168 | (void)token.SerializeToString(&tmp_ticket.ticket); | |
169 | ||
170 | // All endpoints same location for now | |
171 | endpoints.push_back(FlightEndpoint{tmp_ticket, {location_}}); | |
172 | } | |
173 | ||
174 | uint64_t total_records = | |
175 | perf_request.stream_count() * perf_request.records_per_stream(); | |
176 | ||
177 | FlightInfo::Data data; | |
178 | RETURN_NOT_OK( | |
179 | MakeFlightInfo(*perf_schema_, request, endpoints, total_records, -1, &data)); | |
180 | *info = std::unique_ptr<FlightInfo>(new FlightInfo(data)); | |
181 | return Status::OK(); | |
182 | } | |
183 | ||
184 | Status DoGet(const ServerCallContext& context, const Ticket& request, | |
185 | std::unique_ptr<FlightDataStream>* data_stream) override { | |
186 | perf::Token token; | |
187 | CHECK_PARSE(token.ParseFromString(request.ticket)); | |
188 | return GetPerfBatches(token, perf_schema_, false, data_stream); | |
189 | } | |
190 | ||
191 | Status DoPut(const ServerCallContext& context, | |
192 | std::unique_ptr<FlightMessageReader> reader, | |
193 | std::unique_ptr<FlightMetadataWriter> writer) override { | |
194 | FlightStreamChunk chunk; | |
195 | while (true) { | |
196 | RETURN_NOT_OK(reader->Next(&chunk)); | |
197 | if (!chunk.data) break; | |
198 | if (chunk.app_metadata) { | |
199 | RETURN_NOT_OK(writer->WriteMetadata(*chunk.app_metadata)); | |
200 | } | |
201 | } | |
202 | return Status::OK(); | |
203 | } | |
204 | ||
205 | Status DoAction(const ServerCallContext& context, const Action& action, | |
206 | std::unique_ptr<ResultStream>* result) override { | |
207 | if (action.type == "ping") { | |
208 | std::shared_ptr<Buffer> buf = Buffer::FromString("ok"); | |
209 | *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}})); | |
210 | return Status::OK(); | |
211 | } | |
212 | return Status::NotImplemented(action.type); | |
213 | } | |
214 | ||
215 | private: | |
216 | Location location_; | |
217 | std::shared_ptr<Schema> perf_schema_; | |
218 | }; | |
219 | ||
220 | } // namespace flight | |
221 | } // namespace arrow | |
222 | ||
223 | std::unique_ptr<arrow::flight::FlightPerfServer> g_server; | |
224 | ||
225 | void Shutdown(int signal) { | |
226 | if (g_server != nullptr) { | |
227 | ARROW_CHECK_OK(g_server->Shutdown()); | |
228 | } | |
229 | } | |
230 | ||
231 | int main(int argc, char** argv) { | |
232 | gflags::ParseCommandLineFlags(&argc, &argv, true); | |
233 | ||
234 | g_server.reset(new arrow::flight::FlightPerfServer); | |
235 | ||
236 | arrow::flight::Location bind_location; | |
237 | arrow::flight::Location connect_location; | |
238 | if (FLAGS_server_unix.empty()) { | |
239 | if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) { | |
240 | if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) { | |
241 | ARROW_CHECK_OK( | |
242 | arrow::flight::Location::ForGrpcTls("0.0.0.0", FLAGS_port, &bind_location)); | |
243 | ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host, FLAGS_port, | |
244 | &connect_location)); | |
245 | } else { | |
246 | std::cerr << "If providing TLS cert/key, must provide both" << std::endl; | |
247 | return 1; | |
248 | } | |
249 | } else { | |
250 | ARROW_CHECK_OK( | |
251 | arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &bind_location)); | |
252 | ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host, FLAGS_port, | |
253 | &connect_location)); | |
254 | } | |
255 | } else { | |
256 | ARROW_CHECK_OK( | |
257 | arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &bind_location)); | |
258 | ARROW_CHECK_OK( | |
259 | arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &connect_location)); | |
260 | } | |
261 | arrow::flight::FlightServerOptions options(bind_location); | |
262 | if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) { | |
263 | std::cout << "Enabling TLS" << std::endl; | |
264 | std::ifstream cert_file(FLAGS_cert_file); | |
265 | std::string cert((std::istreambuf_iterator<char>(cert_file)), | |
266 | (std::istreambuf_iterator<char>())); | |
267 | std::ifstream key_file(FLAGS_key_file); | |
268 | std::string key((std::istreambuf_iterator<char>(key_file)), | |
269 | (std::istreambuf_iterator<char>())); | |
270 | options.tls_certificates.push_back(arrow::flight::CertKeyPair{cert, key}); | |
271 | } | |
272 | ||
273 | ARROW_CHECK_OK(g_server->Init(options)); | |
274 | // Exit with a clean error code (0) on SIGTERM | |
275 | ARROW_CHECK_OK(g_server->SetShutdownOnSignals({SIGTERM})); | |
276 | if (FLAGS_server_unix.empty()) { | |
277 | std::cout << "Server host: " << FLAGS_server_host << std::endl; | |
278 | std::cout << "Server port: " << FLAGS_port << std::endl; | |
279 | } else { | |
280 | std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl; | |
281 | } | |
282 | g_server->SetLocation(connect_location); | |
283 | ARROW_CHECK_OK(g_server->Serve()); | |
284 | return 0; | |
285 | } |