]> git.proxmox.com Git - ceph.git/blob - ceph/src/seastar/src/rpc/rpc.cc
import 15.2.0 Octopus source
[ceph.git] / ceph / src / seastar / src / rpc / rpc.cc
1 #include <seastar/rpc/rpc.hh>
2 #include <seastar/core/print.hh>
3 #include <seastar/util/defer.hh>
4 #include <boost/range/adaptor/map.hpp>
5
6 namespace seastar {
7
8 namespace rpc {
9
10 void logger::operator()(const client_info& info, id_type msg_id, const sstring& str) const {
11 log(format("client {} msg_id {}: {}", info.addr, msg_id, str));
12 }
13
14 void logger::operator()(const client_info& info, const sstring& str) const {
15 (*this)(info.addr, str);
16 }
17
18 void logger::operator()(const socket_address& addr, const sstring& str) const {
19 log(format("client {}: {}", addr, str));
20 }
21
// Marker value used by verb handlers that send no reply (fire-and-forget).
no_wait_type no_wait;

// Out-of-line definition for the in-class constexpr member (required pre-C++17).
constexpr size_t snd_buf::chunk_size;

26 snd_buf::snd_buf(size_t size_) : size(size_) {
27 if (size <= chunk_size) {
28 bufs = temporary_buffer<char>(size);
29 } else {
30 std::vector<temporary_buffer<char>> v;
31 v.reserve(align_up(size_t(size), chunk_size) / chunk_size);
32 while (size_) {
33 v.push_back(temporary_buffer<char>(std::min(chunk_size, size_)));
34 size_ -= v.back().size();
35 }
36 bufs = std::move(v);
37 }
38 }
39
40 temporary_buffer<char>& snd_buf::front() {
41 auto* one = compat::get_if<temporary_buffer<char>>(&bufs);
42 if (one) {
43 return *one;
44 } else {
45 return compat::get<std::vector<temporary_buffer<char>>>(bufs).front();
46 }
47 }
48
// Make a copy of a remote buffer. No data is actually copied, only pointers and
// a deleter of a new buffer takes care of deleting the original buffer
template<typename T> // T is either snd_buf or rcv_buf
T make_shard_local_buffer_copy(foreign_ptr<std::unique_ptr<T>> org) {
    if (org.get_owner_shard() == engine().cpu_id()) {
        // Already on the owner shard: no aliasing needed, just move it out.
        return std::move(*org);
    }
    T buf(org->size);
    auto* one = compat::get_if<temporary_buffer<char>>(&org->bufs);

    if (one) {
        // Single fragment: alias its bytes. The deleter owns the foreign_ptr,
        // so the original is destroyed on its home shard when we are done.
        buf.bufs = temporary_buffer<char>(one->get_write(), one->size(), make_object_deleter(std::move(org)));
    } else {
        // Multiple fragments: every aliasing buffer shares one deleter that
        // keeps the whole original (via its foreign_ptr) alive until all of
        // them are released. Note: orgbufs stays valid after std::move(org)
        // because moving the foreign_ptr does not destroy the pointee.
        auto& orgbufs = compat::get<std::vector<temporary_buffer<char>>>(org->bufs);
        std::vector<temporary_buffer<char>> newbufs;
        newbufs.reserve(orgbufs.size());
        deleter d = make_object_deleter(std::move(org));
        for (auto&& b : orgbufs) {
            newbufs.push_back(temporary_buffer<char>(b.get_write(), b.size(), d.share()));
        }
        buf.bufs = std::move(newbufs);
    }

    return buf;
}

template snd_buf make_shard_local_buffer_copy(foreign_ptr<std::unique_ptr<snd_buf>>);
template rcv_buf make_shard_local_buffer_copy(foreign_ptr<std::unique_ptr<rcv_buf>>);
77
78 snd_buf connection::compress(snd_buf buf) {
79 if (_compressor) {
80 buf = _compressor->compress(4, std::move(buf));
81 static_assert(snd_buf::chunk_size >= 4, "send buffer chunk size is too small");
82 write_le<uint32_t>(buf.front().get_write(), buf.size - 4);
83 return buf;
84 }
85 return buf;
86 }
87
88 future<> connection::send_buffer(snd_buf buf) {
89 auto* b = compat::get_if<temporary_buffer<char>>(&buf.bufs);
90 if (b) {
91 return _write_buf.write(std::move(*b));
92 } else {
93 return do_with(std::move(compat::get<std::vector<temporary_buffer<char>>>(buf.bufs)),
94 [this] (std::vector<temporary_buffer<char>>& ar) {
95 return do_for_each(ar.begin(), ar.end(), [this] (auto& b) {
96 return _write_buf.write(std::move(b));
97 });
98 });
99 }
100 }
101
// Background fiber that drains _outgoing_queue and writes entries to the
// socket until the connection enters the error state. The resulting future
// is stored in _send_loop_stopped so stop_send_loop() can wait for it.
template<connection::outgoing_queue_type QueueType>
void connection::send_loop() {
    _send_loop_stopped = do_until([this] { return _error; }, [this] {
        return _outgoing_queue_cond.wait([this] { return !_outgoing_queue.empty(); }).then([this] {
            // despite using wait with predicated above _outgoing_queue can still be empty here if
            // there is only one entry on the list and its expire timer runs after wait() returned ready future,
            // but before this continuation runs.
            if (_outgoing_queue.empty()) {
                return make_ready_future();
            }
            auto d = std::move(_outgoing_queue.front());
            _outgoing_queue.pop_front();
            d.t.cancel(); // cancel timeout timer
            if (d.pcancel) {
                d.pcancel->cancel_send = std::function<void()>(); // request is no longer cancellable
            }
            if (QueueType == outgoing_queue_type::request) {
                static_assert(snd_buf::chunk_size >= 8, "send buffer chunk size is too small");
                if (_timeout_negotiated) {
                    // Rewrite the first 8 bytes with the milliseconds left
                    // until the request expires (0 means no timeout).
                    auto expire = d.t.get_timeout();
                    uint64_t left = 0;
                    if (expire != typename timer<rpc_clock_type>::time_point()) {
                        left = std::chrono::duration_cast<std::chrono::milliseconds>(expire - timer<rpc_clock_type>::clock::now()).count();
                    }
                    write_le<uint64_t>(d.buf.front().get_write(), left);
                } else {
                    // Peer did not negotiate the timeout field: strip it.
                    d.buf.front().trim_front(8);
                    d.buf.size -= 8;
                }
            }
            d.buf = compress(std::move(d.buf));
            auto f = send_buffer(std::move(d.buf)).then([this] {
                _stats.sent_messages++;
                return _write_buf.flush();
            });
            // Hold the dequeued entry (and whatever it owns) until the write
            // and flush have completed.
            return f.finally([d = std::move(d)] {});
        });
    }).handle_exception([this] (std::exception_ptr eptr) {
        // Any write failure just flags the connection; teardown happens in
        // stop_send_loop().
        _error = true;
    });
}
143
// Stop the send loop and close the write side of the connection. Also waits
// for a concurrent sink close (see stream_close()) so _write_buf.close() is
// not called twice.
future<> connection::stop_send_loop() {
    _error = true;
    if (_connected) {
        // Wake the send loop so it observes _error, and abort any in-flight write.
        _outgoing_queue_cond.broken();
        _fd.shutdown_output();
    }
    return when_all(std::move(_send_loop_stopped), std::move(_sink_closed_future)).then([this] (std::tuple<future<>, future<bool>> res){
        _outgoing_queue.clear();
        // both _send_loop_stopped and _sink_closed_future are never exceptional
        bool sink_closed = std::get<1>(res).get0();
        return _connected && !sink_closed ? _write_buf.close() : make_ready_future();
    });
}
157
158 void connection::set_socket(connected_socket&& fd) {
159 if (_connected) {
160 throw std::runtime_error("already connected");
161 }
162 _fd = std::move(fd);
163 _read_buf =_fd.input();
164 _write_buf = _fd.output();
165 _connected = true;
166 }
167
168 future<> connection::send_negotiation_frame(feature_map features) {
169 auto negotiation_frame_feature_record_size = [] (const feature_map::value_type& e) {
170 return 8 + e.second.size();
171 };
172 auto extra_len = boost::accumulate(
173 features | boost::adaptors::transformed(negotiation_frame_feature_record_size),
174 uint32_t(0));
175 temporary_buffer<char> reply(sizeof(negotiation_frame) + extra_len);
176 auto p = reply.get_write();
177 p = std::copy_n(rpc_magic, 8, p);
178 write_le<uint32_t>(p, extra_len);
179 p += 4;
180 for (auto&& e : features) {
181 write_le<uint32_t>(p, static_cast<uint32_t>(e.first));
182 p += 4;
183 write_le<uint32_t>(p, e.second.size());
184 p += 4;
185 p = std::copy_n(e.second.begin(), e.second.size(), p);
186 }
187 return _write_buf.write(std::move(reply)).then([this] {
188 _stats.sent_messages++;
189 return _write_buf.flush();
190 });
191 }
192
// Queue an outgoing buffer. The returned future resolves once the buffer has
// been written out (or dropped). A timeout that already expired drops the
// message immediately; a later expiry or a cancellation removes the entry
// from the queue before it is sent.
future<> connection::send(snd_buf buf, compat::optional<rpc_clock_type::time_point> timeout, cancellable* cancel) {
    if (!_error) {
        if (timeout && *timeout <= rpc_clock_type::now()) {
            return make_ready_future<>();
        }
        _outgoing_queue.emplace_back(std::move(buf));
        // Removes the just-queued entry; shared by the expiry timer and the
        // cancel hook. Captures a stable list iterator to the entry.
        auto deleter = [this, it = std::prev(_outgoing_queue.cend())] {
            _outgoing_queue.erase(it);
        };
        if (timeout) {
            auto& t = _outgoing_queue.back().t;
            t.set_callback(deleter);
            t.arm(timeout.value());
        }
        if (cancel) {
            // Wire up bidirectional pointers so either side can find the other.
            cancel->cancel_send = std::move(deleter);
            cancel->send_back_pointer = &_outgoing_queue.back().pcancel;
            _outgoing_queue.back().pcancel = cancel;
        }
        // Wake the send loop.
        _outgoing_queue_cond.signal();
        return _outgoing_queue.back().p->get_future();
    } else {
        return make_exception_future<>(closed_error());
    }
}
218
219 void connection::abort() {
220 if (!_error) {
221 _error = true;
222 _fd.shutdown_input();
223 }
224 }
225
// Abort the connection and return a future that resolves once its processing
// has fully stopped.
future<> connection::stop() {
    abort();
    return _stopped.get_future();
}
230
231 template<typename Connection>
232 static bool verify_frame(Connection& c, temporary_buffer<char>& buf, size_t expected, const char* log) {
233 if (buf.size() != expected) {
234 if (buf.size() != 0) {
235 c.get_logger()(c.peer_address(), log);
236 }
237 return false;
238 }
239 return true;
240 }
241
// Read and parse the peer's negotiation frame: verify the protocol magic,
// then decode the feature section as a sequence of
// (feature-id, length, data) records. Fails with closed_error on EOF or
// malformed input.
template<typename Connection>
static
future<feature_map>
receive_negotiation_frame(Connection& c, input_stream<char>& in) {
    return in.read_exactly(sizeof(negotiation_frame)).then([&c, &in] (temporary_buffer<char> neg) {
        if (!verify_frame(c, neg, sizeof(negotiation_frame), "unexpected eof during negotiation frame")) {
            return make_exception_future<feature_map>(closed_error());
        }
        negotiation_frame frame;
        std::copy_n(neg.get_write(), sizeof(frame.magic), frame.magic);
        frame.len = read_le<uint32_t>(neg.get_write() + 8);
        if (std::memcmp(frame.magic, rpc_magic, sizeof(frame.magic)) != 0) {
            c.get_logger()(c.peer_address(), "wrong protocol magic");
            return make_exception_future<feature_map>(closed_error());
        }
        auto len = frame.len;
        return in.read_exactly(len).then([&c, len] (temporary_buffer<char> extra) {
            if (extra.size() != len) {
                c.get_logger()(c.peer_address(), "unexpected eof during negotiation frame");
                return make_exception_future<feature_map>(closed_error());
            }
            feature_map map;
            auto p = extra.get();
            auto end = p + extra.size();
            while (p != end) {
                // Each record needs at least 8 bytes of header (id + length).
                if (end - p < 8) {
                    c.get_logger()(c.peer_address(), "bad feature data format in negotiation frame");
                    return make_exception_future<feature_map>(closed_error());
                }
                auto feature = static_cast<protocol_features>(read_le<uint32_t>(p));
                auto f_len = read_le<uint32_t>(p + 4);
                p += 8;
                if (f_len > end - p) {
                    c.get_logger()(c.peer_address(), "buffer underflow in feature data in negotiation frame");
                    return make_exception_future<feature_map>(closed_error());
                }
                auto data = sstring(p, f_len);
                p += f_len;
                map.emplace(feature, std::move(data));
            }
            return make_ready_future<feature_map>(std::move(map));
        });
    });
}
286
// Read exactly size bytes from in into an rcv_buf, accumulating fragments as
// they arrive. On premature EOF the returned buffer's size is reduced to the
// bytes actually read (or an empty rcv_buf if nothing was read at all).
inline future<rcv_buf>
read_rcv_buf(input_stream<char>& in, uint32_t size) {
    return in.read_up_to(size).then([&, size] (temporary_buffer<char> data) mutable {
        rcv_buf rb(size);
        if (data.size() == 0) {
            // Immediate EOF: nothing read.
            return make_ready_future<rcv_buf>(rcv_buf());
        } else if (data.size() == size) {
            // Fast path: everything arrived in a single read.
            rb.bufs = std::move(data);
            return make_ready_future<rcv_buf>(std::move(rb));
        } else {
            // Partial read: keep reading fragments until complete or EOF.
            size -= data.size();
            std::vector<temporary_buffer<char>> v;
            v.push_back(std::move(data));
            rb.bufs = std::move(v);
            return do_with(std::move(rb), std::move(size), [&in] (rcv_buf& rb, uint32_t& left) {
                return repeat([&] () {
                    return in.read_up_to(left).then([&] (temporary_buffer<char> data) {
                        if (!data.size()) {
                            // EOF mid-frame: record how much actually arrived.
                            rb.size -= left;
                            return stop_iteration::yes;
                        } else {
                            left -= data.size();
                            compat::get<std::vector<temporary_buffer<char>>>(rb.bufs).push_back(std::move(data));
                            return left ? stop_iteration::no : stop_iteration::yes;
                        }
                    });
                }).then([&rb] {
                    return std::move(rb);
                });
            });
        }
    });
}
320
// Read one frame described by the FrameType traits: a fixed-size header
// followed by an optional payload. Short reads (other than a clean EOF at a
// frame boundary) are logged, and yield FrameType::empty_value().
template<typename FrameType>
typename FrameType::return_type
connection::read_frame(socket_address info, input_stream<char>& in) {
    auto header_size = FrameType::header_size();
    return in.read_exactly(header_size).then([this, header_size, info, &in] (temporary_buffer<char> header) {
        if (header.size() != header_size) {
            if (header.size() != 0) {
                _logger(info, format("unexpected eof on a {} while reading header: expected {:d} got {:d}", FrameType::role(), header_size, header.size()));
            }
            return FrameType::empty_value();
        }
        auto h = FrameType::decode_header(header.get());
        auto size = FrameType::get_size(h);
        if (!size) {
            // Zero-length payload: nothing more to read.
            return FrameType::make_value(h, rcv_buf());
        } else {
            return read_rcv_buf(in, size).then([this, info, h = std::move(h), size] (rcv_buf rb) {
                if (rb.size != size) {
                    _logger(info, format("unexpected eof on a {} while reading data: expected {:d} got {:d}", FrameType::role(), size, rb.size));
                    return FrameType::empty_value();
                } else {
                    return FrameType::make_value(h, std::move(rb));
                }
            });
        }
    });
}
348
// Like read_frame(), but when a compressor is active the wire format is a
// 4-byte little-endian compressed-size header followed by compressed data;
// the decompressed bytes are then re-parsed as a regular frame.
template<typename FrameType>
typename FrameType::return_type
connection::read_frame_compressed(socket_address info, std::unique_ptr<compressor>& compressor, input_stream<char>& in) {
    if (compressor) {
        return in.read_exactly(4).then([&] (temporary_buffer<char> compress_header) {
            if (compress_header.size() != 4) {
                if (compress_header.size() != 0) {
                    _logger(info, format("unexpected eof on a {} while reading compression header: expected 4 got {:d}", FrameType::role(), compress_header.size()));
                }
                return FrameType::empty_value();
            }
            auto ptr = compress_header.get();
            auto size = read_le<uint32_t>(ptr);
            return read_rcv_buf(in, size).then([this, size, &compressor, info] (rcv_buf compressed_data) {
                if (compressed_data.size != size) {
                    _logger(info, format("unexpected eof on a {} while reading compressed data: expected {:d} got {:d}", FrameType::role(), size, compressed_data.size));
                    return FrameType::empty_value();
                }
                auto eb = compressor->decompress(std::move(compressed_data));
                // Stitch the decompressed fragment(s) into a packet, then parse
                // the inner frame from an input stream over that packet.
                net::packet p;
                auto* one = compat::get_if<temporary_buffer<char>>(&eb.bufs);
                if (one) {
                    p = net::packet(std::move(p), std::move(*one));
                } else {
                    for (auto&& b : compat::get<std::vector<temporary_buffer<char>>>(eb.bufs)) {
                        p = net::packet(std::move(p), std::move(b));
                    }
                }
                return do_with(as_input_stream(std::move(p)), [this, info] (input_stream<char>& in) {
                    return read_frame<FrameType>(info, in);
                });
            });
        });
    } else {
        return read_frame<FrameType>(info, in);
    }
}
386
387 struct stream_frame {
388 using opt_buf_type = compat::optional<rcv_buf>;
389 using return_type = future<opt_buf_type>;
390 struct header_type {
391 uint32_t size;
392 bool eos;
393 };
394 static size_t header_size() {
395 return 4;
396 }
397 static const char* role() {
398 return "stream";
399 }
400 static future<opt_buf_type> empty_value() {
401 return make_ready_future<opt_buf_type>(compat::nullopt);
402 }
403 static header_type decode_header(const char* ptr) {
404 header_type h{read_le<uint32_t>(ptr), false};
405 if (h.size == -1U) {
406 h.size = 0;
407 h.eos = true;
408 }
409 return h;
410 }
411 static uint32_t get_size(const header_type& t) {
412 return t.size;
413 }
414 static future<opt_buf_type> make_value(const header_type& t, rcv_buf data) {
415 if (t.eos) {
416 data.size = -1U;
417 }
418 return make_ready_future<opt_buf_type>(std::move(data));
419 }
420 };
421
// Read one (possibly compressed) stream frame from the peer.
future<compat::optional<rcv_buf>>
connection::read_stream_frame_compressed(input_stream<char>& in) {
    return read_frame_compressed<stream_frame>(peer_address(), _compressor, in);
}
426
// Close the sink side of a stream connection, then stop the connection.
future<> connection::stream_close() {
    auto f = make_ready_future<>();
    if (!error()) {
        promise<bool> p;
        _sink_closed_future = p.get_future();
        // stop_send_loop(), which also calls _write_buf.close(), and this code can run in parallel.
        // Use _sink_closed_future to serialize them and skip second call to close()
        f = _write_buf.close().finally([p = std::move(p)] () mutable { p.set_value(true);});
    }
    return f.finally([this] () mutable { return stop(); });
}
438
439 future<> connection::stream_process_incoming(rcv_buf&& buf) {
440 // we do not want to dead lock on huge packets, so let them in
441 // but only one at a time
442 auto size = std::min(size_t(buf.size), max_stream_buffers_memory);
443 return get_units(_stream_sem, size).then([this, buf = std::move(buf)] (semaphore_units<>&& su) mutable {
444 buf.su = std::move(su);
445 return _stream_queue.push_eventually(std::move(buf));
446 });
447 }
448
449 future<> connection::handle_stream_frame() {
450 return read_stream_frame_compressed(_read_buf).then([this] (compat::optional<rcv_buf> data) {
451 if (!data) {
452 _error = true;
453 return make_ready_future<>();
454 }
455 return stream_process_incoming(std::move(*data));
456 });
457 }
458
// Move all currently queued stream buffers into bufs. An entry whose size is
// -1U is the end-of-stream marker: consumption stops there, and if any data
// was handed out the marker is re-queued so the next call observes EOF.
future<> connection::stream_receive(circular_buffer<foreign_ptr<std::unique_ptr<rcv_buf>>>& bufs) {
    return _stream_queue.not_empty().then([this, &bufs] {
        bool eof = !_stream_queue.consume([&bufs] (rcv_buf&& b) {
            if (b.size == -1U) { // max fragment length marks an end of a stream
                return false;
            } else {
                bufs.push_back(make_foreign(std::make_unique<rcv_buf>(std::move(b))));
                return true;
            }
        });
        if (eof && !bufs.empty()) {
            assert(_stream_queue.empty());
            _stream_queue.push(rcv_buf(-1U)); // push eof marker back for next read to notice it
        }
    });
}
475
476 void connection::register_stream(connection_id id, xshard_connection_ptr c) {
477 _streams.emplace(id, std::move(c));
478 }
479
480 xshard_connection_ptr connection::get_stream(connection_id id) const {
481 auto it = _streams.find(id);
482 if (it == _streams.end()) {
483 throw std::logic_error(format("rpc stream id {:d} not found", id).c_str());
484 }
485 return it->second;
486 }
487
488 static void log_exception(connection& c, const char* log, std::exception_ptr eptr) {
489 const char* s;
490 try {
491 std::rethrow_exception(eptr);
492 } catch (std::exception& ex) {
493 s = ex.what();
494 } catch (...) {
495 s = "unknown exception";
496 }
497 c.get_logger()(c.peer_address(), format("{}: {}", log, s));
498 }
499
500
501 void
502 client::negotiate(feature_map provided) {
503 // record features returned here
504 for (auto&& e : provided) {
505 auto id = e.first;
506 switch (id) {
507 // supported features go here
508 case protocol_features::COMPRESS:
509 if (_options.compressor_factory) {
510 _compressor = _options.compressor_factory->negotiate(e.second, false);
511 }
512 break;
513 case protocol_features::TIMEOUT:
514 _timeout_negotiated = true;
515 break;
516 case protocol_features::CONNECTION_ID: {
517 _id = deserialize_connection_id(e.second);
518 break;
519 }
520 default:
521 // nothing to do
522 ;
523 }
524 }
525 }
526
527 future<>
528 client::negotiate_protocol(input_stream<char>& in) {
529 return receive_negotiation_frame(*this, in).then([this] (feature_map features) {
530 return negotiate(features);
531 });
532 }
533
534 struct response_frame {
535 using opt_buf_type = compat::optional<rcv_buf>;
536 using header_and_buffer_type = std::tuple<int64_t, opt_buf_type>;
537 using return_type = future<header_and_buffer_type>;
538 using header_type = std::tuple<int64_t, uint32_t>;
539 static size_t header_size() {
540 return 12;
541 }
542 static const char* role() {
543 return "client";
544 }
545 static auto empty_value() {
546 return make_ready_future<header_and_buffer_type>(header_and_buffer_type(0, compat::nullopt));
547 }
548 static header_type decode_header(const char* ptr) {
549 auto msgid = read_le<int64_t>(ptr);
550 auto size = read_le<uint32_t>(ptr + 8);
551 return std::make_tuple(msgid, size);
552 }
553 static uint32_t get_size(const header_type& t) {
554 return std::get<1>(t);
555 }
556 static auto make_value(const header_type& t, rcv_buf data) {
557 return make_ready_future<header_and_buffer_type>(header_and_buffer_type(std::get<0>(t), std::move(data)));
558 }
559 };
560
561
// Read one uncompressed response frame from the server.
future<response_frame::header_and_buffer_type>
client::read_response_frame(input_stream<char>& in) {
    return read_frame<response_frame>(_server_addr, in);
}
566
// Read one response frame, decompressing it first if compression was negotiated.
future<response_frame::header_and_buffer_type>
client::read_response_frame_compressed(input_stream<char>& in) {
    return read_frame_compressed<response_frame>(_server_addr, _compressor, in);
}
571
572 stats client::get_stats() const {
573 stats res = _stats;
574 res.wait_reply = _outstanding.size();
575 res.pending = _outgoing_queue.size();
576 return res;
577 }
578
// Register a reply handler for a message id. Arms the handler's expiry timer
// if a timeout was given, and wires up the cancellable (both directions) so a
// cancel removes the handler from the outstanding map.
void client::wait_for_reply(id_type id, std::unique_ptr<reply_handler_base>&& h, compat::optional<rpc_clock_type::time_point> timeout, cancellable* cancel) {
    if (timeout) {
        h->t.set_callback(std::bind(std::mem_fn(&client::wait_timed_out), this, id));
        h->t.arm(timeout.value());
    }
    if (cancel) {
        cancel->cancel_wait = [this, id] {
            // Cancellation notifies the handler and forgets the entry.
            _outstanding[id]->cancel();
            _outstanding.erase(id);
        };
        h->pcancel = cancel;
        cancel->wait_back_pointer = &h->pcancel;
    }
    _outstanding.emplace(id, std::move(h));
}
594 void client::wait_timed_out(id_type id) {
595 _stats.timeout++;
596 _outstanding[id]->timeout();
597 _outstanding.erase(id);
598 }
599
600 future<> client::stop() {
601 if (!_error) {
602 _error = true;
603 _socket.shutdown();
604 }
605 return _stopped.get_future();
606 }
607
608 void client::abort_all_streams() {
609 while (!_streams.empty()) {
610 auto&& s = _streams.begin();
611 assert(s->second->get_owner_shard() == engine().cpu_id()); // abort can be called only locally
612 s->second->get()->abort();
613 _streams.erase(s);
614 }
615 }
616
617 void client::deregister_this_stream() {
618 if (_parent) {
619 _parent->_streams.erase(_id);
620 }
621 }
622
// Primary client constructor: connects in the background, performs feature
// negotiation, then runs the response-dispatch loop until EOF or error and
// tears everything down. Completion is reported via _stopped; callers
// synchronize with client::stop().
client::client(const logger& l, void* s, client_options ops, socket socket, const socket_address& addr, const socket_address& local)
: rpc::connection(l, s), _socket(std::move(socket)), _server_addr(addr), _options(ops) {
    _socket.set_reuseaddr(ops.reuseaddr);
    // Run client in the background.
    // Communicate result via _stopped.
    // The caller has to call client::stop() to synchronize.
    (void)_socket.connect(addr, local).then([this, ops = std::move(ops)] (connected_socket fd) {
        fd.set_nodelay(ops.tcp_nodelay);
        if (ops.keepalive) {
            fd.set_keepalive(true);
            fd.set_keepalive_parameters(ops.keepalive.value());
        }
        set_socket(std::move(fd));

        // Advertise only the features this client was configured to use.
        feature_map features;
        if (_options.compressor_factory) {
            features[protocol_features::COMPRESS] = _options.compressor_factory->supported();
        }
        if (_options.send_timeout_data) {
            features[protocol_features::TIMEOUT] = "";
        }
        if (_options.stream_parent) {
            features[protocol_features::STREAM_PARENT] = serialize_connection_id(_options.stream_parent);
        }
        if (!_options.isolation_cookie.empty()) {
            features[protocol_features::ISOLATION] = _options.isolation_cookie;
        }

        return send_negotiation_frame(std::move(features)).then([this] {
            return negotiate_protocol(_read_buf);
        }).then([this] () {
            // Handshake done: release anyone waiting for negotiation, start
            // the send loop and enter the response-dispatch loop.
            _client_negotiated->set_value();
            _client_negotiated = compat::nullopt;
            send_loop();
            return do_until([this] { return _read_buf.eof() || _error; }, [this] () mutable {
                if (is_stream()) {
                    return handle_stream_frame();
                }
                return read_response_frame_compressed(_read_buf).then([this] (std::tuple<int64_t, compat::optional<rcv_buf>> msg_id_and_data) {
                    auto& msg_id = std::get<0>(msg_id_and_data);
                    auto& data = std::get<1>(msg_id_and_data);
                    // A negative msg id marks an exceptional reply for |msg_id|.
                    auto it = _outstanding.find(std::abs(msg_id));
                    if (!data) {
                        _error = true;
                    } else if (it != _outstanding.end()) {
                        auto handler = std::move(it->second);
                        _outstanding.erase(it);
                        (*handler)(*this, msg_id, std::move(data.value()));
                    } else if (msg_id < 0) {
                        try {
                            std::rethrow_exception(unmarshal_exception(data.value()));
                        } catch(const unknown_verb_error& ex) {
                            // if this is unknown verb exception with unknown id ignore it
                            // can happen if unknown verb was used by no_wait client
                            get_logger()(peer_address(), format("unknown verb exception {:d} ignored", ex.type));
                        } catch(...) {
                            // We've got error response but handler is no longer waiting, could be timed out.
                            log_exception(*this, "ignoring error response", std::current_exception());
                        }
                    } else {
                        // we get a reply for a message id not in _outstanding
                        // this can happened if the message id is timed out already
                        // FIXME: log it but with low level, currently log levels are not supported
                    }
                });
            });
        });
    }).then_wrapped([this] (future<> f) {
        std::exception_ptr ep;
        if (f.failed()) {
            ep = f.get_exception();
            if (is_stream()) {
                log_exception(*this, _connected ? "client stream connection dropped" : "stream fail to connect", ep);
            } else {
                log_exception(*this, _connected ? "client connection dropped" : "fail to connect", ep);
            }
        }
        _error = true;
        _stream_queue.abort(std::make_exception_ptr(stream_closed()));
        return stop_send_loop().then_wrapped([this] (future<> f) {
            f.ignore_ready_future();
            _outstanding.clear();
            if (is_stream()) {
                deregister_this_stream();
            } else {
                abort_all_streams();
            }
        }).finally([this, ep]{
            // Propagate a connect/negotiation failure to anyone still waiting on it.
            if (_client_negotiated && ep) {
                _client_negotiated->set_exception(ep);
            }
            _stopped.set_value();
        });
    });
}
718
// Convenience constructors delegating to the primary one, filling in default
// options and/or a freshly created TCP socket.
client::client(const logger& l, void* s, const socket_address& addr, const socket_address& local)
: client(l, s, client_options{}, engine().net().socket(), addr, local)
{}

client::client(const logger& l, void* s, client_options options, const socket_address& addr, const socket_address& local)
: client(l, s, options, engine().net().socket(), addr, local)
{}

client::client(const logger& l, void* s, socket socket, const socket_address& addr, const socket_address& local)
: client(l, s, client_options{}, std::move(socket), addr, local)
{}
730
731
// Process the features requested by a client and build the feature map we
// send back. The STREAM_PARENT feature may migrate this connection to a
// parent connection on another shard, hence the future return.
future<feature_map>
server::connection::negotiate(feature_map requested) {
    feature_map ret;
    future<> f = make_ready_future<>();
    for (auto&& e : requested) {
        auto id = e.first;
        switch (id) {
        // supported features go here
        case protocol_features::COMPRESS: {
            if (_server._options.compressor_factory) {
                _compressor = _server._options.compressor_factory->negotiate(e.second, true);
                ret[protocol_features::COMPRESS] = _server._options.compressor_factory->supported();
            }
        }
        break;
        case protocol_features::TIMEOUT:
            _timeout_negotiated = true;
            ret[protocol_features::TIMEOUT] = "";
            break;
        case protocol_features::STREAM_PARENT: {
            if (!_server._options.streaming_domain) {
                f = make_exception_future<>(std::runtime_error("streaming is not configured for the server"));
            } else {
                _parent_id = deserialize_connection_id(e.second);
                _is_stream = true;
                // remove stream connection from rpc connection list
                _server._conns.erase(get_connection_id());
                // Register this stream with its parent connection, which may
                // live on a different shard.
                f = smp::submit_to(_parent_id.shard(), [this, c = make_foreign(static_pointer_cast<rpc::connection>(shared_from_this()))] () mutable {
                    auto sit = _servers.find(*_server._options.streaming_domain);
                    if (sit == _servers.end()) {
                        throw std::logic_error(format("Shard {:d} does not have server with streaming domain {:x}", engine().cpu_id(), *_server._options.streaming_domain).c_str());
                    }
                    auto s = sit->second;
                    auto it = s->_conns.find(_parent_id);
                    if (it == s->_conns.end()) {
                        throw std::logic_error(format("Unknown parent connection {:d} on shard {:d}", _parent_id, engine().cpu_id()).c_str());
                    }
                    auto id = c->get_connection_id();
                    it->second->register_stream(id, make_lw_shared(std::move(c)));
                });
            }
            break;
        }
        case protocol_features::ISOLATION: {
            auto&& isolation_cookie = e.second;
            _isolation_config = _server._limits.isolate_connection(isolation_cookie);
            // Echo the cookie back so the client knows isolation was applied.
            ret.emplace(e);
            break;
        }
        default:
            // nothing to do
            ;
        }
    }
    if (_server._options.streaming_domain) {
        // Tell the client its connection id so it can attach streams later.
        ret[protocol_features::CONNECTION_ID] = serialize_connection_id(_id);
    }
    return f.then([ret = std::move(ret)] {
        return ret;
    });
}
793
794 future<>
795 server::connection::negotiate_protocol(input_stream<char>& in) {
796 return receive_negotiation_frame(*this, in).then([this] (feature_map requested_features) {
797 return negotiate(std::move(requested_features)).then([this] (feature_map returned_features) {
798 return send_negotiation_frame(std::move(returned_features));
799 });
800 });
801 }
802
803 struct request_frame {
804 using opt_buf_type = compat::optional<rcv_buf>;
805 using header_and_buffer_type = std::tuple<compat::optional<uint64_t>, uint64_t, int64_t, opt_buf_type>;
806 using return_type = future<header_and_buffer_type>;
807 using header_type = std::tuple<compat::optional<uint64_t>, uint64_t, int64_t, uint32_t>;
808 static size_t header_size() {
809 return 20;
810 }
811 static const char* role() {
812 return "server";
813 }
814 static auto empty_value() {
815 return make_ready_future<header_and_buffer_type>(header_and_buffer_type(compat::nullopt, uint64_t(0), 0, compat::nullopt));
816 }
817 static header_type decode_header(const char* ptr) {
818 auto type = read_le<uint64_t>(ptr);
819 auto msgid = read_le<int64_t>(ptr + 8);
820 auto size = read_le<uint32_t>(ptr + 16);
821 return std::make_tuple(compat::nullopt, type, msgid, size);
822 }
823 static uint32_t get_size(const header_type& t) {
824 return std::get<3>(t);
825 }
826 static auto make_value(const header_type& t, rcv_buf data) {
827 return make_ready_future<header_and_buffer_type>(header_and_buffer_type(std::get<0>(t), std::get<1>(t), std::get<2>(t), std::move(data)));
828 }
829 };
830
// Request-frame variant with a leading 8-byte timeout field (milliseconds),
// used once the TIMEOUT feature has been negotiated with the client.
struct request_frame_with_timeout : request_frame {
    using super = request_frame;
    static size_t header_size() {
        return 28; // 8 (timeout) + 20 (base request header)
    }
    static typename super::header_type decode_header(const char* ptr) {
        // Decode the base header after the timeout, then fill in the timeout.
        auto h = super::decode_header(ptr + 8);
        std::get<0>(h) = read_le<uint64_t>(ptr);
        return h;
    }
};
842
843 future<request_frame::header_and_buffer_type>
844 server::connection::read_request_frame_compressed(input_stream<char>& in) {
845 if (_timeout_negotiated) {
846 return read_frame_compressed<request_frame_with_timeout>(_info.addr, _compressor, in);
847 } else {
848 return read_frame_compressed<request_frame>(_info.addr, _compressor, in);
849 }
850 }
851
852 future<>
853 server::connection::respond(int64_t msg_id, snd_buf&& data, compat::optional<rpc_clock_type::time_point> timeout) {
854 static_assert(snd_buf::chunk_size >= 12, "send buffer chunk size is too small");
855 auto p = data.front().get_write();
856 write_le<int64_t>(p, msg_id);
857 write_le<uint32_t>(p + 8, data.size - 12);
858 return send(std::move(data), timeout);
859 }
860
// Reply to a request for a verb this server does not implement: marshal an
// UNKNOWN_VERB exception (12-byte response header + 4-byte exception type +
// 4-byte payload length + 8-byte verb type = 28 bytes) and send it in the
// background under the reply gate.
future<> server::connection::send_unknown_verb_reply(compat::optional<rpc_clock_type::time_point> timeout, int64_t msg_id, uint64_t type) {
    return wait_for_resources(28, timeout).then([this, timeout, msg_id, type] (auto permit) {
        // send unknown_verb exception back
        snd_buf data(28);
        static_assert(snd_buf::chunk_size >= 28, "send buffer chunk size is too small");
        auto p = data.front().get_write() + 12;
        write_le<uint32_t>(p, uint32_t(exception_type::UNKNOWN_VERB));
        write_le<uint32_t>(p + 4, uint32_t(8));
        write_le<uint64_t>(p + 8, type);
        try {
            // Send asynchronously.
            // This is safe since connection::stop() will wait for background work.
            (void)with_gate(_server._reply_gate, [this, timeout, msg_id, data = std::move(data), permit = std::move(permit)] () mutable {
                // workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=83268
                auto c = shared_from_this();
                // Negative msg id marks the response as exceptional.
                return respond(-msg_id, std::move(data), timeout).then([c = std::move(c), permit = std::move(permit)] {});
            });
        } catch(gate_closed_exception&) {/* ignore */}
    });
}
881
// Main server-side connection fiber: perform the handshake, then loop reading
// requests (or stream frames) and dispatching them to registered handlers
// until EOF or an error, then tear the connection down.
future<> server::connection::process() {
    return negotiate_protocol(_read_buf).then([this] () mutable {
        // Per-connection isolation (if negotiated) overrides the current
        // scheduling group for the whole loop.
        auto sg = _isolation_config ? _isolation_config->sched_group : current_scheduling_group();
        return with_scheduling_group(sg, [this] {
            send_loop();
            return do_until([this] { return _read_buf.eof() || _error; }, [this] () mutable {
                if (is_stream()) {
                    return handle_stream_frame();
                }
                return read_request_frame_compressed(_read_buf).then([this] (request_frame::header_and_buffer_type header_and_buffer) {
                    auto& expire = std::get<0>(header_and_buffer);
                    auto& type = std::get<1>(header_and_buffer);
                    auto& msg_id = std::get<2>(header_and_buffer);
                    auto& data = std::get<3>(header_and_buffer);
                    if (!data) {
                        // Truncated frame or EOF: drop the connection.
                        _error = true;
                        return make_ready_future<>();
                    } else {
                        compat::optional<rpc_clock_type::time_point> timeout;
                        if (expire && *expire) {
                            timeout = relative_timeout_to_absolute(std::chrono::milliseconds(*expire));
                        }
                        auto h = _server._proto->get_handler(type);
                        if (!h) {
                            return send_unknown_verb_reply(timeout, msg_id, type);
                        }

                        // If the new method of per-connection scheduling group was used, honor it.
                        // Otherwise, use the old per-handler scheduling group.
                        auto sg = _isolation_config ? _isolation_config->sched_group : h->sg;
                        return with_scheduling_group(sg, [this, timeout, msg_id, h, data = std::move(data.value())] () mutable {
                            return h->func(shared_from_this(), timeout, msg_id, std::move(data)).finally([this, h] {
                                // If anything between get_handler() and here throws, we leak put_handler
                                _server._proto->put_handler(h);
                            });
                        });
                    }
                });
            });
        });
    }).then_wrapped([this] (future<> f) {
        if (f.failed()) {
            log_exception(*this, format("server{} connection dropped", is_stream() ? " stream" : "").c_str(), f.get_exception());
        }
        _fd.shutdown_input();
        _error = true;
        _stream_queue.abort(std::make_exception_ptr(stream_closed()));
        return stop_send_loop().then_wrapped([this] (future<> f) {
            f.ignore_ready_future();
            _server._conns.erase(get_connection_id());
            if (is_stream()) {
                return deregister_this_stream();
            } else {
                return make_ready_future<>();
            }
        }).finally([this] {
            _stopped.set_value();
        });
    }).finally([conn_ptr = shared_from_this()] {
        // hold onto connection pointer until do_until() exists
    });
}
944
// Server-side connection: wraps the accepted socket in the base
// rpc::connection and records the peer address in the client_info (_info)
// that is exposed to RPC handlers.
server::connection::connection(server& s, connected_socket&& fd, socket_address&& addr, const logger& l, void* serializer, connection_id id)
    : rpc::connection(std::move(fd), l, serializer, id), _server(s) {
    _info.addr = std::move(addr);
}
949
950 future<> server::connection::deregister_this_stream() {
951 if (!_server._options.streaming_domain) {
952 return make_ready_future<>();
953 }
954 return smp::submit_to(_parent_id.shard(), [this] () mutable {
955 auto sit = server::_servers.find(*_server._options.streaming_domain);
956 if (sit != server::_servers.end()) {
957 auto s = sit->second;
958 auto it = s->_conns.find(_parent_id);
959 if (it != s->_conns.end()) {
960 it->second->_streams.erase(get_connection_id());
961 }
962 }
963 });
964 }
965
966 thread_local std::unordered_map<streaming_domain_type, server*> server::_servers;
967
// Convenience constructor: listens on addr (listen_options{true}) with
// default server options.
server::server(protocol_base* proto, const socket_address& addr, resource_limits limits)
    : server(proto, engine().listen(addr, listen_options{true}), limits, server_options{})
{}
971
// Convenience constructor: listens on addr, propagating the load-balancing
// algorithm from the supplied server options.
server::server(protocol_base* proto, server_options opts, const socket_address& addr, resource_limits limits)
    : server(proto, engine().listen(addr, listen_options{true, opts.load_balancing_algorithm}), limits, opts)
{}
975
976 server::server(protocol_base* proto, server_socket ss, resource_limits limits, server_options opts)
977 : _proto(proto), _ss(std::move(ss)), _limits(limits), _resources_available(limits.max_memory), _options(opts)
978 {
979 if (_options.streaming_domain) {
980 if (_servers.find(*_options.streaming_domain) != _servers.end()) {
981 throw std::runtime_error(format("An RPC server with the streaming domain {} is already exist", *_options.streaming_domain));
982 }
983 _servers[*_options.streaming_domain] = this;
984 }
985 accept();
986 }
987
// Convenience constructor: same as the primary constructor, with the options
// argument accepted before the socket.
server::server(protocol_base* proto, server_options opts, server_socket ss, resource_limits limits)
    : server(proto, std::move(ss), limits, opts)
{}
991
// Background accept loop. For every accepted socket it creates a connection
// object, registers it in _conns and starts its processing loop. The loop
// only ends by exception (e.g. after abort_accept() in server::stop()),
// which is signalled to stop() via the _ss_stopped promise.
void server::accept() {
    // Run asynchronously in background.
    // Communicate result via _ss_stopped.
    // The caller has to call server::stop() to synchronize.
    (void)keep_doing([this] () mutable {
        return _ss.accept().then([this] (accept_result ar) mutable {
            auto fd = std::move(ar.connection);
            auto addr = std::move(ar.remote_address);
            fd.set_nodelay(_options.tcp_nodelay);
            // Streaming servers tag the id with the owning shard; otherwise
            // a shard-local ("invalid") id is sufficient.
            connection_id id = _options.streaming_domain ?
                    connection_id::make_id(_next_client_id++, uint16_t(engine().cpu_id())) :
                    connection_id::make_invalid_id(_next_client_id++);
            auto conn = _proto->make_server_connection(*this, std::move(fd), std::move(addr), id);
            auto r = _conns.emplace(id, conn);
            assert(r.second);
            // Process asynchronously in background.
            (void)conn->process();
        });
    }).then_wrapped([this] (future<>&& f){
        try {
            f.get();
            // keep_doing() can only resolve by failing; reaching here is a bug.
            assert(false);
        } catch (...) {
            _ss_stopped.set_value();
        }
    });
}
1019
1020 future<> server::stop() {
1021 _ss.abort_accept();
1022 _resources_available.broken();
1023 if (_options.streaming_domain) {
1024 _servers.erase(*_options.streaming_domain);
1025 }
1026 return when_all(_ss_stopped.get_future(),
1027 parallel_for_each(_conns | boost::adaptors::map_values, [] (shared_ptr<connection> conn) {
1028 return conn->stop();
1029 }),
1030 _reply_gate.close()
1031 ).discard_result();
1032 }
1033
1034 std::ostream& operator<<(std::ostream& os, const connection_id& id) {
1035 return fmt_print(os, "{:x}", id.id);
1036 }
1037
1038 std::ostream& operator<<(std::ostream& os, const streaming_domain_type& domain) {
1039 return fmt_print(os, "{:d}", domain._id);
1040 }
1041
1042 isolation_config default_isolate_connection(sstring isolation_cookie) {
1043 return isolation_config{};
1044 }
1045
1046 }
1047
1048 }