1 #include <seastar/rpc/rpc.hh>
2 #include <seastar/core/print.hh>
3 #include <seastar/util/defer.hh>
4 #include <boost/range/adaptor/map.hpp>
// NOTE(review): the embedded original line numbers jump (11 -> 14), so closing
// braces between these overloads were elided by the extraction; do not assume
// the bodies shown here are complete.
// Log a message tagged with the client's address and the RPC message id.
10 void logger::operator()(const client_info
& info
, id_type msg_id
, const sstring
& str
) const {
11 log(format("client {} msg_id {}: {}", info
.addr
, msg_id
, str
));
// Log a message tagged with the client's address only; forwards to the
// socket_address overload below.
14 void logger::operator()(const client_info
& info
, const sstring
& str
) const {
15 (*this)(info
.addr
, str
);
// Log a message tagged with a raw socket address.
18 void logger::operator()(const socket_address
& addr
, const sstring
& str
) const {
19 log(format("client {}: {}", addr
, str
));
// Out-of-line definition for the in-class constexpr static (required pre-C++17 ODR-use).
24 constexpr size_t snd_buf::chunk_size
;
// Build a send buffer of `size_` bytes: a single temporary_buffer when it fits
// in one chunk, otherwise a vector of chunk_size-sized buffers.
// NOTE(review): lines 29, 32 and 35+ of the original (the `else` branch header,
// the chunking loop header and the final assignment/closing braces) are elided
// from this extraction.
26 snd_buf::snd_buf(size_t size_
) : size(size_
) {
27 if (size
<= chunk_size
) {
28 bufs
= temporary_buffer
<char>(size
);
30 std::vector
<temporary_buffer
<char>> v
;
31 v
.reserve(align_up(size_t(size
), chunk_size
) / chunk_size
);
33 v
.push_back(temporary_buffer
<char>(std::min(chunk_size
, size_
)));
34 size_
-= v
.back().size();
// Return the first fragment: the single buffer itself, or the vector's front.
// NOTE(review): lines 42-44 (the single-buffer early return) are elided here.
40 temporary_buffer
<char>& snd_buf::front() {
41 auto* one
= compat::get_if
<temporary_buffer
<char>>(&bufs
);
45 return compat::get
<std::vector
<temporary_buffer
<char>>>(bufs
).front();
49 // Make a copy of a remote buffer. No data is actually copied, only pointers and
50 // a deleter of a new buffer takes care of deleting the original buffer
// NOTE(review): several original lines (55-56, 58-59, 61, 68, 70-74 — the
// local-shard fast path close, declaration of `buf`, branch boundaries and the
// final return) are elided from this extraction.
51 template<typename T
> // T is either snd_buf or rcv_buf
52 T
make_shard_local_buffer_copy(foreign_ptr
<std::unique_ptr
<T
>> org
) {
// Fast path: buffer already lives on this shard — just move it out.
53 if (org
.get_owner_shard() == engine().cpu_id()) {
54 return std::move(*org
);
57 auto* one
= compat::get_if
<temporary_buffer
<char>>(&org
->bufs
);
// Single-fragment case: alias the original bytes; the object deleter keeps the
// foreign_ptr (and thus the remote buffer) alive until this copy dies.
60 buf
.bufs
= temporary_buffer
<char>(one
->get_write(), one
->size(), make_object_deleter(std::move(org
)));
// Multi-fragment case: alias each fragment and share one deleter across all.
62 auto& orgbufs
= compat::get
<std::vector
<temporary_buffer
<char>>>(org
->bufs
);
63 std::vector
<temporary_buffer
<char>> newbufs
;
64 newbufs
.reserve(orgbufs
.size());
65 deleter d
= make_object_deleter(std::move(org
));
66 for (auto&& b
: orgbufs
) {
67 newbufs
.push_back(temporary_buffer
<char>(b
.get_write(), b
.size(), d
.share()));
69 buf
.bufs
= std::move(newbufs
);
// Explicit instantiations for the two buffer types used by the RPC layer.
75 template snd_buf
make_shard_local_buffer_copy(foreign_ptr
<std::unique_ptr
<snd_buf
>>);
76 template rcv_buf
make_shard_local_buffer_copy(foreign_ptr
<std::unique_ptr
<rcv_buf
>>);
// Compress `buf`, reserving 4 bytes at the front into which the compressed
// payload length (size - 4) is written little-endian.
// NOTE(review): original lines 79, 83-87 (compressor presence check / uncompressed
// fallback and closing braces) are elided from this extraction.
78 snd_buf
connection::compress(snd_buf buf
) {
80 buf
= _compressor
->compress(4, std::move(buf
));
81 static_assert(snd_buf::chunk_size
>= 4, "send buffer chunk size is too small");
82 write_le
<uint32_t>(buf
.front().get_write(), buf
.size
- 4);
// Write a snd_buf to the output stream: single fragment directly, otherwise
// each fragment of the vector in order via do_for_each.
88 future
<> connection::send_buffer(snd_buf buf
) {
89 auto* b
= compat::get_if
<temporary_buffer
<char>>(&buf
.bufs
);
91 return _write_buf
.write(std::move(*b
));
93 return do_with(std::move(compat::get
<std::vector
<temporary_buffer
<char>>>(buf
.bufs
)),
94 [this] (std::vector
<temporary_buffer
<char>>& ar
) {
95 return do_for_each(ar
.begin(), ar
.end(), [this] (auto& b
) {
96 return _write_buf
.write(std::move(b
));
// Background loop draining _outgoing_queue until _error is set: waits on the
// queue condition, pops one entry, cancels its expiry timer, fills in the
// timeout header for request queues when negotiated, compresses and sends.
// NOTE(review): many original lines are elided (e.g. 111, 115, 117, 122, 125,
// 127, 129-131, 136, 138, 140-143), including branch boundaries and the
// exception-handler body — treat this as an incomplete view of the loop.
102 template<connection::outgoing_queue_type QueueType
>
103 void connection::send_loop() {
104 _send_loop_stopped
= do_until([this] { return _error
; }, [this] {
105 return _outgoing_queue_cond
.wait([this] { return !_outgoing_queue
.empty(); }).then([this] {
106 // despite using wait with predicated above _outgoing_queue can still be empty here if
107 // there is only one entry on the list and its expire timer runs after wait() returned ready future,
108 // but before this continuation runs.
109 if (_outgoing_queue
.empty()) {
110 return make_ready_future();
112 auto d
= std::move(_outgoing_queue
.front());
113 _outgoing_queue
.pop_front();
114 d
.t
.cancel(); // cancel timeout timer
116 d
.pcancel
->cancel_send
= std::function
<void()>(); // request is no longer cancellable
118 if (QueueType
== outgoing_queue_type::request
) {
119 static_assert(snd_buf::chunk_size
>= 8, "send buffer chunk size is too small");
120 if (_timeout_negotiated
) {
121 auto expire
= d
.t
.get_timeout();
// Convert the absolute expiry into a relative milliseconds value for the wire.
123 if (expire
!= typename timer
<rpc_clock_type
>::time_point()) {
124 left
= std::chrono::duration_cast
<std::chrono::milliseconds
>(expire
- timer
<rpc_clock_type
>::clock::now()).count();
126 write_le
<uint64_t>(d
.buf
.front().get_write(), left
);
// Timeout not negotiated: drop the 8-byte timeout slot from the frame.
128 d
.buf
.front().trim_front(8);
132 d
.buf
= compress(std::move(d
.buf
));
133 auto f
= send_buffer(std::move(d
.buf
)).then([this] {
134 _stats
.sent_messages
++;
135 return _write_buf
.flush();
// Keep `d` alive until the write+flush completes.
137 return f
.finally([d
= std::move(d
)] {});
139 }).handle_exception([this] (std::exception_ptr eptr
) {
// Tear down the send side: break the queue condition variable, shut down the
// socket's output, then wait for the send loop and the sink-close future
// before clearing the queue and (conditionally) closing the write buffer.
// NOTE(review): original lines 145-146, 149, 155-157 are elided (likely a
// _connected guard and closing braces) — incomplete view.
144 future
<> connection::stop_send_loop() {
147 _outgoing_queue_cond
.broken();
148 _fd
.shutdown_output();
150 return when_all(std::move(_send_loop_stopped
), std::move(_sink_closed_future
)).then([this] (std::tuple
<future
<>, future
<bool>> res
){
151 _outgoing_queue
.clear();
152 // both _send_loop_stopped and _sink_closed_future are never exceptional
153 bool sink_closed
= std::get
<1>(res
).get0();
// Only close the write buffer if still connected and the sink did not already close it.
154 return _connected
&& !sink_closed
? _write_buf
.close() : make_ready_future();
// Adopt the connected socket and set up buffered input/output streams.
// NOTE(review): lines 159, 161-162, 165+ are elided (the already-connected
// check surrounding the throw, and the assignment of _fd).
158 void connection::set_socket(connected_socket
&& fd
) {
160 throw std::runtime_error("already connected");
163 _read_buf
=_fd
.input();
164 _write_buf
= _fd
.output();
// Serialize and send the negotiation frame: 8-byte magic, little-endian extra
// length, then one (feature, len, value) record per feature map entry.
// NOTE(review): lines 171, 174, 179, 182, 184, 186, 190+ of the original are
// elided (lambda close, accumulate initial value, pointer advances, closing
// braces) — incomplete view.
168 future
<> connection::send_negotiation_frame(feature_map features
) {
// Each record is 8 bytes of header (feature id + value length) plus the value bytes.
169 auto negotiation_frame_feature_record_size
= [] (const feature_map::value_type
& e
) {
170 return 8 + e
.second
.size();
172 auto extra_len
= boost::accumulate(
173 features
| boost::adaptors::transformed(negotiation_frame_feature_record_size
),
175 temporary_buffer
<char> reply(sizeof(negotiation_frame
) + extra_len
);
176 auto p
= reply
.get_write();
177 p
= std::copy_n(rpc_magic
, 8, p
);
178 write_le
<uint32_t>(p
, extra_len
);
180 for (auto&& e
: features
) {
181 write_le
<uint32_t>(p
, static_cast<uint32_t>(e
.first
));
183 write_le
<uint32_t>(p
, e
.second
.size());
185 p
= std::copy_n(e
.second
.begin(), e
.second
.size(), p
);
187 return _write_buf
.write(std::move(reply
)).then([this] {
188 _stats
.sent_messages
++;
189 return _write_buf
.flush();
// Queue an outgoing buffer. Expired timeouts are dropped immediately; otherwise
// the entry is queued with an expiry timer (the timer callback removes it) and
// optional cancellation hooks, then the send loop is signalled.
// NOTE(review): original lines 194, 197, 201-202, 206-207, 211, 214, 216-218
// are elided (the _error/connected guard, timer guard, cancel guard, closing
// braces) — incomplete view.
193 future
<> connection::send(snd_buf buf
, compat::optional
<rpc_clock_type::time_point
> timeout
, cancellable
* cancel
) {
195 if (timeout
&& *timeout
<= rpc_clock_type::now()) {
196 return make_ready_future
<>();
198 _outgoing_queue
.emplace_back(std::move(buf
));
// Removes the just-queued entry; used both by the expiry timer and by cancel.
199 auto deleter
= [this, it
= std::prev(_outgoing_queue
.cend())] {
200 _outgoing_queue
.erase(it
);
203 auto& t
= _outgoing_queue
.back().t
;
204 t
.set_callback(deleter
);
205 t
.arm(timeout
.value());
// Wire up bidirectional cancellation pointers between caller and queue entry.
208 cancel
->cancel_send
= std::move(deleter
);
209 cancel
->send_back_pointer
= &_outgoing_queue
.back().pcancel
;
210 _outgoing_queue
.back().pcancel
= cancel
;
212 _outgoing_queue_cond
.signal();
213 return _outgoing_queue
.back().p
->get_future();
// Connection is closed: report the error to the caller.
215 return make_exception_future
<>(closed_error());
// Abort the connection by shutting down the socket input.
// NOTE(review): lines 220-221 elided (presumably an error-flag guard — confirm).
219 void connection::abort() {
222 _fd
.shutdown_input();
226 future
<> connection::stop() {
228 return _stopped
.get_future();
// Check that `buf` holds exactly `expected` bytes; logs `log` on a short
// non-empty read. A zero-size read (clean EOF) is not logged.
// NOTE(review): lines 236-241 elided (return statements / closing braces).
231 template<typename Connection
>
232 static bool verify_frame(Connection
& c
, temporary_buffer
<char>& buf
, size_t expected
, const char* log
) {
233 if (buf
.size() != expected
) {
234 if (buf
.size() != 0) {
235 c
.get_logger()(c
.peer_address(), log
);
// Read and validate a negotiation frame from `in`: fixed header (magic + len),
// then `len` bytes of (feature, len, value) records parsed into a feature_map.
// Errors (short reads, bad magic, malformed records) resolve to closed_error.
// NOTE(review): original lines 243-244, 249, 256, 262-263, 266-267, 270, 273,
// 277, 279, 281, 283-286 are elided (return type, map declaration, the record
// iteration loop header, pointer advances and closing braces) — incomplete view.
242 template<typename Connection
>
245 receive_negotiation_frame(Connection
& c
, input_stream
<char>& in
) {
246 return in
.read_exactly(sizeof(negotiation_frame
)).then([&c
, &in
] (temporary_buffer
<char> neg
) {
247 if (!verify_frame(c
, neg
, sizeof(negotiation_frame
), "unexpected eof during negotiation frame")) {
248 return make_exception_future
<feature_map
>(closed_error());
250 negotiation_frame frame
;
251 std::copy_n(neg
.get_write(), sizeof(frame
.magic
), frame
.magic
);
252 frame
.len
= read_le
<uint32_t>(neg
.get_write() + 8);
253 if (std::memcmp(frame
.magic
, rpc_magic
, sizeof(frame
.magic
)) != 0) {
254 c
.get_logger()(c
.peer_address(), "wrong protocol magic");
255 return make_exception_future
<feature_map
>(closed_error());
257 auto len
= frame
.len
;
258 return in
.read_exactly(len
).then([&c
, len
] (temporary_buffer
<char> extra
) {
259 if (extra
.size() != len
) {
260 c
.get_logger()(c
.peer_address(), "unexpected eof during negotiation frame");
261 return make_exception_future
<feature_map
>(closed_error());
264 auto p
= extra
.get();
265 auto end
= p
+ extra
.size();
// Record header must fit in the remaining bytes (guard elided above).
268 c
.get_logger()(c
.peer_address(), "bad feature data format in negotiation frame");
269 return make_exception_future
<feature_map
>(closed_error());
271 auto feature
= static_cast<protocol_features
>(read_le
<uint32_t>(p
));
272 auto f_len
= read_le
<uint32_t>(p
+ 4);
274 if (f_len
> end
- p
) {
275 c
.get_logger()(c
.peer_address(), "buffer underflow in feature data in negotiation frame");
276 return make_exception_future
<feature_map
>(closed_error());
278 auto data
= sstring(p
, f_len
);
280 map
.emplace(feature
, std::move(data
));
282 return make_ready_future
<feature_map
>(std::move(map
));
// Read exactly `size` bytes into an rcv_buf. A single read_up_to that returns
// everything yields a single-fragment buffer; otherwise fragments are
// accumulated in a vector via repeat() until `left` reaches zero or EOF.
// NOTE(review): original lines 290, 296-297, 304-305, 307-308, 311-313, 315+
// are elided (rcv_buf declarations, the short-read/EOF branch inside repeat,
// the `left` bookkeeping, closing braces) — incomplete view.
287 inline future
<rcv_buf
>
288 read_rcv_buf(input_stream
<char>& in
, uint32_t size
) {
289 return in
.read_up_to(size
).then([&, size
] (temporary_buffer
<char> data
) mutable {
291 if (data
.size() == 0) {
292 return make_ready_future
<rcv_buf
>(rcv_buf());
293 } else if (data
.size() == size
) {
294 rb
.bufs
= std::move(data
);
295 return make_ready_future
<rcv_buf
>(std::move(rb
));
// Partial read: collect fragments until the full frame has arrived.
298 std::vector
<temporary_buffer
<char>> v
;
299 v
.push_back(std::move(data
));
300 rb
.bufs
= std::move(v
);
301 return do_with(std::move(rb
), std::move(size
), [&in
] (rcv_buf
& rb
, uint32_t& left
) {
302 return repeat([&] () {
303 return in
.read_up_to(left
).then([&] (temporary_buffer
<char> data
) {
306 return stop_iteration::yes
;
309 compat::get
<std::vector
<temporary_buffer
<char>>>(rb
.bufs
).push_back(std::move(data
));
310 return left
? stop_iteration::no
: stop_iteration::yes
;
314 return std::move(rb
);
// Generic frame reader parameterized by FrameType policy (header_size, role,
// decode_header, get_size, make_value, empty_value). Reads the header, then
// the payload via read_rcv_buf, logging and returning empty_value on short reads.
// NOTE(review): original lines 329, 331, 334, 336, 341, 343+ are elided
// (return after the eof log, the zero-size fast-path guard, closing braces).
321 template<typename FrameType
>
322 typename
FrameType::return_type
323 connection::read_frame(socket_address info
, input_stream
<char>& in
) {
324 auto header_size
= FrameType::header_size();
325 return in
.read_exactly(header_size
).then([this, header_size
, info
, &in
] (temporary_buffer
<char> header
) {
326 if (header
.size() != header_size
) {
327 if (header
.size() != 0) {
328 _logger(info
, format("unexpected eof on a {} while reading header: expected {:d} got {:d}", FrameType::role(), header_size
, header
.size()));
330 return FrameType::empty_value();
332 auto h
= FrameType::decode_header(header
.get());
333 auto size
= FrameType::get_size(h
);
// Zero-length payload: no body read needed.
335 return FrameType::make_value(h
, rcv_buf());
337 return read_rcv_buf(in
, size
).then([this, info
, h
= std::move(h
), size
] (rcv_buf rb
) {
338 if (rb
.size
!= size
) {
339 _logger(info
, format("unexpected eof on a {} while reading data: expected {:d} got {:d}", FrameType::role(), size
, rb
.size
));
340 return FrameType::empty_value();
342 return FrameType::make_value(h
, std::move(rb
));
// Like read_frame, but when a compressor was negotiated: reads a 4-byte
// little-endian compressed length, the compressed payload, decompresses it,
// re-packages the fragments into a net::packet, and parses the frame from an
// in-memory input stream over that packet. Without a compressor, falls back
// to plain read_frame.
// NOTE(review): original lines 352, 357, 359, 366, 368, 370, 372, 375-376,
// 379-382, 384+ are elided (the compressor presence check, packet declaration,
// single/multi fragment branch boundaries, closing braces) — incomplete view.
349 template<typename FrameType
>
350 typename
FrameType::return_type
351 connection::read_frame_compressed(socket_address info
, std::unique_ptr
<compressor
>& compressor
, input_stream
<char>& in
) {
353 return in
.read_exactly(4).then([&] (temporary_buffer
<char> compress_header
) {
354 if (compress_header
.size() != 4) {
355 if (compress_header
.size() != 0) {
356 _logger(info
, format("unexpected eof on a {} while reading compression header: expected 4 got {:d}", FrameType::role(), compress_header
.size()));
358 return FrameType::empty_value();
360 auto ptr
= compress_header
.get();
361 auto size
= read_le
<uint32_t>(ptr
);
362 return read_rcv_buf(in
, size
).then([this, size
, &compressor
, info
] (rcv_buf compressed_data
) {
363 if (compressed_data
.size
!= size
) {
364 _logger(info
, format("unexpected eof on a {} while reading compressed data: expected {:d} got {:d}", FrameType::role(), size
, compressed_data
.size
));
365 return FrameType::empty_value();
367 auto eb
= compressor
->decompress(std::move(compressed_data
));
369 auto* one
= compat::get_if
<temporary_buffer
<char>>(&eb
.bufs
);
// Single decompressed fragment: append it to the packet.
371 p
= net::packet(std::move(p
), std::move(*one
));
// Multiple fragments: append each in order.
373 for (auto&& b
: compat::get
<std::vector
<temporary_buffer
<char>>>(eb
.bufs
)) {
374 p
= net::packet(std::move(p
), std::move(b
));
377 return do_with(as_input_stream(std::move(p
)), [this, info
] (input_stream
<char>& in
) {
378 return read_frame
<FrameType
>(info
, in
);
// No compressor negotiated: read the frame directly from the socket stream.
383 return read_frame
<FrameType
>(info
, in
);
// Frame policy used by read_frame/read_frame_compressed for stream frames:
// header is a 32-bit length; value is an optional rcv_buf (nullopt on EOF).
// NOTE(review): many member bodies are elided from this extraction (original
// lines 390-393, 395-396, 398-399, 402, 405-410, 412-413, 415-417, 419-421),
// including header_type's definition, header_size()/role()/get_size() returns
// and the EOF handling in make_value — incomplete view.
387 struct stream_frame
{
388 using opt_buf_type
= compat::optional
<rcv_buf
>;
389 using return_type
= future
<opt_buf_type
>;
394 static size_t header_size() {
397 static const char* role() {
400 static future
<opt_buf_type
> empty_value() {
401 return make_ready_future
<opt_buf_type
>(compat::nullopt
);
403 static header_type
decode_header(const char* ptr
) {
404 header_type h
{read_le
<uint32_t>(ptr
), false};
411 static uint32_t get_size(const header_type
& t
) {
414 static future
<opt_buf_type
> make_value(const header_type
& t
, rcv_buf data
) {
418 return make_ready_future
<opt_buf_type
>(std::move(data
));
// Read one (possibly compressed) stream frame from `in`.
422 future
<compat::optional
<rcv_buf
>>
423 connection::read_stream_frame_compressed(input_stream
<char>& in
) {
424 return read_frame_compressed
<stream_frame
>(peer_address(), _compressor
, in
);
// Close the stream's sink side, coordinating with stop_send_loop() through
// _sink_closed_future so _write_buf.close() runs at most once, then stop().
// NOTE(review): lines 429-430, 435, 437-438 elided (the error guard and the
// promise declaration for `p`) — incomplete view.
427 future
<> connection::stream_close() {
428 auto f
= make_ready_future
<>();
431 _sink_closed_future
= p
.get_future();
432 // stop_send_loop(), which also calls _write_buf.close(), and this code can run in parallel.
433 // Use _sink_closed_future to serialize them and skip second call to close()
434 f
= _write_buf
.close().finally([p
= std::move(p
)] () mutable { p
.set_value(true);});
436 return f
.finally([this] () mutable { return stop(); });
// Admit an incoming stream buffer into _stream_queue, charging at most
// max_stream_buffers_memory units against _stream_sem to bound memory while
// still letting oversized packets through one at a time.
439 future
<> connection::stream_process_incoming(rcv_buf
&& buf
) {
440 // we do not want to dead lock on huge packets, so let them in
441 // but only one at a time
442 auto size
= std::min(size_t(buf
.size
), max_stream_buffers_memory
);
443 return get_units(_stream_sem
, size
).then([this, buf
= std::move(buf
)] (semaphore_units
<>&& su
) mutable {
// The entry owns its semaphore units; they are released when the buffer dies.
444 buf
.su
= std::move(su
);
445 return _stream_queue
.push_eventually(std::move(buf
));
// Read one stream frame and enqueue it; nullopt (EOF) resolves immediately.
// NOTE(review): lines 451-452, 454, 456-458 elided (the nullopt branch header).
449 future
<> connection::handle_stream_frame() {
450 return read_stream_frame_compressed(_read_buf
).then([this] (compat::optional
<rcv_buf
> data
) {
453 return make_ready_future
<>();
455 return stream_process_incoming(std::move(*data
));
// Drain available stream buffers into `bufs` as foreign_ptrs. A buffer whose
// size is -1U is the end-of-stream marker; if EOF is seen while data was also
// consumed, the marker is pushed back so the next read observes it.
// NOTE(review): lines 463-464, 466-468, 472-475 elided (the consume lambda's
// return values and closing braces) — incomplete view.
459 future
<> connection::stream_receive(circular_buffer
<foreign_ptr
<std::unique_ptr
<rcv_buf
>>>& bufs
) {
460 return _stream_queue
.not_empty().then([this, &bufs
] {
461 bool eof
= !_stream_queue
.consume([&bufs
] (rcv_buf
&& b
) {
462 if (b
.size
== -1U) { // max fragment length marks an end of a stream
465 bufs
.push_back(make_foreign(std::make_unique
<rcv_buf
>(std::move(b
))));
469 if (eof
&& !bufs
.empty()) {
470 assert(_stream_queue
.empty());
471 _stream_queue
.push(rcv_buf(-1U)); // push eof marker back for next read to notice it
// Associate a child stream connection with this connection.
476 void connection::register_stream(connection_id id
, xshard_connection_ptr c
) {
477 _streams
.emplace(id
, std::move(c
));
// Look up a registered stream; throws std::logic_error if unknown.
// NOTE(review): lines 484-487 elided (the successful-lookup return).
480 xshard_connection_ptr
connection::get_stream(connection_id id
) const {
481 auto it
= _streams
.find(id
);
482 if (it
== _streams
.end()) {
483 throw std::logic_error(format("rpc stream id {:d} not found", id
).c_str());
// Log `eptr` through the connection's logger, extracting what() when the
// exception derives from std::exception, otherwise a generic message.
// NOTE(review): lines 489-490, 493-494, 496 elided (the sstring `s`
// declaration, try block header, what() assignment and catch-all header).
488 static void log_exception(connection
& c
, const char* log
, std::exception_ptr eptr
) {
491 std::rethrow_exception(eptr
);
492 } catch (std::exception
& ex
) {
495 s
= "unknown exception";
497 c
.get_logger()(c
.peer_address(), format("{}: {}", log
, s
));
// Apply the features the server accepted: set up compression, record timeout
// support, and adopt a server-assigned connection id.
// NOTE(review): original lines 505-506, 511-512, 515, 518-527, 531-533 are
// elided (the switch header, break statements, default case and closing
// braces) — incomplete view of the switch.
502 client::negotiate(feature_map provided
) {
503 // record features returned here
504 for (auto&& e
: provided
) {
507 // supported features go here
508 case protocol_features::COMPRESS
:
509 if (_options
.compressor_factory
) {
510 _compressor
= _options
.compressor_factory
->negotiate(e
.second
, false);
513 case protocol_features::TIMEOUT
:
514 _timeout_negotiated
= true;
516 case protocol_features::CONNECTION_ID
: {
517 _id
= deserialize_connection_id(e
.second
);
// Receive the server's negotiation frame and apply it.
528 client::negotiate_protocol(input_stream
<char>& in
) {
529 return receive_negotiation_frame(*this, in
).then([this] (feature_map features
) {
530 return negotiate(features
);
// Frame policy for client-side responses: header is (msg_id:int64, size:u32),
// value is (msg_id, optional payload). Negative msg_id on the wire indicates
// an error response (handled by the caller).
// NOTE(review): lines 540-541, 543-544, 547, 552, 555, 558-561 elided
// (header_size()/role() return values and closing braces) — incomplete view.
534 struct response_frame
{
535 using opt_buf_type
= compat::optional
<rcv_buf
>;
536 using header_and_buffer_type
= std::tuple
<int64_t, opt_buf_type
>;
537 using return_type
= future
<header_and_buffer_type
>;
538 using header_type
= std::tuple
<int64_t, uint32_t>;
539 static size_t header_size() {
542 static const char* role() {
545 static auto empty_value() {
546 return make_ready_future
<header_and_buffer_type
>(header_and_buffer_type(0, compat::nullopt
));
548 static header_type
decode_header(const char* ptr
) {
549 auto msgid
= read_le
<int64_t>(ptr
);
550 auto size
= read_le
<uint32_t>(ptr
+ 8);
551 return std::make_tuple(msgid
, size
);
553 static uint32_t get_size(const header_type
& t
) {
554 return std::get
<1>(t
);
556 static auto make_value(const header_type
& t
, rcv_buf data
) {
557 return make_ready_future
<header_and_buffer_type
>(header_and_buffer_type(std::get
<0>(t
), std::move(data
)));
// Read one uncompressed response frame from the server.
562 future
<response_frame::header_and_buffer_type
>
563 client::read_response_frame(input_stream
<char>& in
) {
564 return read_frame
<response_frame
>(_server_addr
, in
);
// Read one response frame, decompressing if a compressor was negotiated.
567 future
<response_frame::header_and_buffer_type
>
568 client::read_response_frame_compressed(input_stream
<char>& in
) {
569 return read_frame_compressed
<response_frame
>(_server_addr
, _compressor
, in
);
// Snapshot of client statistics; augments the base counters with the number
// of replies being waited for and requests still queued.
// NOTE(review): lines 573, 576-578 elided (the `res` declaration and return).
572 stats
client::get_stats() const {
574 res
.wait_reply
= _outstanding
.size();
575 res
.pending
= _outgoing_queue
.size();
// Register a reply handler for message `id`, arming its expiry timer (which
// fires wait_timed_out) and, when a cancellable is supplied, wiring the
// cancel-wait hook and back-pointer before storing the handler in _outstanding.
// NOTE(review): lines 580, 583-584, 588-589, 591, 593 elided (timeout/cancel
// guards and closing braces) — incomplete view.
579 void client::wait_for_reply(id_type id
, std::unique_ptr
<reply_handler_base
>&& h
, compat::optional
<rpc_clock_type::time_point
> timeout
, cancellable
* cancel
) {
581 h
->t
.set_callback(std::bind(std::mem_fn(&client::wait_timed_out
), this, id
));
582 h
->t
.arm(timeout
.value());
585 cancel
->cancel_wait
= [this, id
] {
586 _outstanding
[id
]->cancel();
587 _outstanding
.erase(id
);
590 cancel
->wait_back_pointer
= &h
->pcancel
;
592 _outstanding
.emplace(id
, std::move(h
));
// Expiry-timer callback: notify the handler of the timeout and drop it.
594 void client::wait_timed_out(id_type id
) {
596 _outstanding
[id
]->timeout();
597 _outstanding
.erase(id
);
// Initiate shutdown and return the future resolved when the client stops.
// NOTE(review): lines 601-604 elided (likely the abort/error path — confirm).
600 future
<> client::stop() {
605 return _stopped
.get_future();
// Abort every registered stream; only legal on the owner shard (asserted).
608 void client::abort_all_streams() {
609 while (!_streams
.empty()) {
610 auto&& s
= _streams
.begin();
611 assert(s
->second
->get_owner_shard() == engine().cpu_id()); // abort can be called only locally
612 s
->second
->get()->abort();
// Remove this stream connection from its parent's stream map.
617 void client::deregister_this_stream() {
619 _parent
->_streams
.erase(_id
);
// Main client constructor: connects in the background, negotiates protocol
// features (compression, timeout, stream parent, isolation), then loops
// reading responses (or stream frames) and dispatching them to the handlers
// registered in _outstanding. Shutdown/cleanup runs in then_wrapped/finally.
// NOTE(review): many original lines are elided throughout (631, 634, 636,
// 640, 643, 646, 649-650, 653, 656, 658, 660, 665-666, 672, 678, 681-682,
// 686-689, 692, 694, 696, 698-700, 705, 707-709, 713, 715-718) — branch
// boundaries, send-loop startup and closing braces are missing; treat this
// as an incomplete view of the constructor body.
623 client::client(const logger
& l
, void* s
, client_options ops
, socket socket
, const socket_address
& addr
, const socket_address
& local
)
624 : rpc::connection(l
, s
), _socket(std::move(socket
)), _server_addr(addr
), _options(ops
) {
625 _socket
.set_reuseaddr(ops
.reuseaddr
);
626 // Run client in the background.
627 // Communicate result via _stopped.
628 // The caller has to call client::stop() to synchronize.
629 (void)_socket
.connect(addr
, local
).then([this, ops
= std::move(ops
)] (connected_socket fd
) {
630 fd
.set_nodelay(ops
.tcp_nodelay
);
632 fd
.set_keepalive(true);
633 fd
.set_keepalive_parameters(ops
.keepalive
.value());
635 set_socket(std::move(fd
));
// Advertise the features this client supports/requests.
637 feature_map features
;
638 if (_options
.compressor_factory
) {
639 features
[protocol_features::COMPRESS
] = _options
.compressor_factory
->supported();
641 if (_options
.send_timeout_data
) {
642 features
[protocol_features::TIMEOUT
] = "";
644 if (_options
.stream_parent
) {
645 features
[protocol_features::STREAM_PARENT
] = serialize_connection_id(_options
.stream_parent
);
647 if (!_options
.isolation_cookie
.empty()) {
648 features
[protocol_features::ISOLATION
] = _options
.isolation_cookie
;
651 return send_negotiation_frame(std::move(features
)).then([this] {
652 return negotiate_protocol(_read_buf
);
// Negotiation done: release anyone waiting on _client_negotiated.
654 _client_negotiated
->set_value();
655 _client_negotiated
= compat::nullopt
;
657 return do_until([this] { return _read_buf
.eof() || _error
; }, [this] () mutable {
659 return handle_stream_frame();
661 return read_response_frame_compressed(_read_buf
).then([this] (std::tuple
<int64_t, compat::optional
<rcv_buf
>> msg_id_and_data
) {
662 auto& msg_id
= std::get
<0>(msg_id_and_data
);
663 auto& data
= std::get
<1>(msg_id_and_data
);
// Error responses arrive with a negated id; look up by absolute value.
664 auto it
= _outstanding
.find(std::abs(msg_id
));
667 } else if (it
!= _outstanding
.end()) {
668 auto handler
= std::move(it
->second
);
669 _outstanding
.erase(it
);
670 (*handler
)(*this, msg_id
, std::move(data
.value()));
671 } else if (msg_id
< 0) {
673 std::rethrow_exception(unmarshal_exception(data
.value()));
674 } catch(const unknown_verb_error
& ex
) {
675 // if this is unknown verb exception with unknown id ignore it
676 // can happen if unknown verb was used by no_wait client
677 get_logger()(peer_address(), format("unknown verb exception {:d} ignored", ex
.type
));
679 // We've got error response but handler is no longer waiting, could be timed out.
680 log_exception(*this, "ignoring error response", std::current_exception());
683 // we get a reply for a message id not in _outstanding
684 // this can happened if the message id is timed out already
685 // FIXME: log it but with low level, currently log levels are not supported
690 }).then_wrapped([this] (future
<> f
) {
691 std::exception_ptr ep
;
693 ep
= f
.get_exception();
695 log_exception(*this, _connected
? "client stream connection dropped" : "stream fail to connect", ep
);
697 log_exception(*this, _connected
? "client connection dropped" : "fail to connect", ep
);
// Tear down: abort the stream queue, stop the send loop, clear handlers.
701 _stream_queue
.abort(std::make_exception_ptr(stream_closed()));
702 return stop_send_loop().then_wrapped([this] (future
<> f
) {
703 f
.ignore_ready_future();
704 _outstanding
.clear();
706 deregister_this_stream();
710 }).finally([this, ep
]{
// Propagate a connection failure to waiters of _client_negotiated.
711 if (_client_negotiated
&& ep
) {
712 _client_negotiated
->set_exception(ep
);
714 _stopped
.set_value();
// Convenience constructor: default options, reactor-provided socket.
719 client::client(const logger
& l
, void* s
, const socket_address
& addr
, const socket_address
& local
)
720 : client(l
, s
, client_options
{}, engine().net().socket(), addr
, local
)
// Convenience constructor: caller-supplied options, reactor-provided socket.
723 client::client(const logger
& l
, void* s
, client_options options
, const socket_address
& addr
, const socket_address
& local
)
724 : client(l
, s
, options
, engine().net().socket(), addr
, local
)
// Convenience constructor: default options, caller-supplied socket.
727 client::client(const logger
& l
, void* s
, socket socket
, const socket_address
& addr
, const socket_address
& local
)
728 : client(l
, s
, client_options
{}, std::move(socket
), addr
, local
)
// Server-side feature negotiation: for each requested feature, configure the
// connection (compression, timeout, stream parent re-registration on the
// parent's shard, isolation) and build the `ret` feature map answered to the
// client. STREAM_PARENT moves this connection out of the server's connection
// list and registers it as a stream of its parent connection, which may live
// on another shard (hence smp::submit_to).
// NOTE(review): many original lines are elided (734, 737-738, 744-746, 750,
// 754, 756, 763, 768, 771-774, 778-785, 788, 790-794, 799+): the `ret`
// declaration, switch header, break statements, unknown-feature default, and
// the tail of the function — incomplete view.
733 server::connection::negotiate(feature_map requested
) {
735 future
<> f
= make_ready_future
<>();
736 for (auto&& e
: requested
) {
739 // supported features go here
740 case protocol_features::COMPRESS
: {
741 if (_server
._options
.compressor_factory
) {
742 _compressor
= _server
._options
.compressor_factory
->negotiate(e
.second
, true);
743 ret
[protocol_features::COMPRESS
] = _server
._options
.compressor_factory
->supported();
747 case protocol_features::TIMEOUT
:
748 _timeout_negotiated
= true;
749 ret
[protocol_features::TIMEOUT
] = "";
751 case protocol_features::STREAM_PARENT
: {
752 if (!_server
._options
.streaming_domain
) {
753 f
= make_exception_future
<>(std::runtime_error("streaming is not configured for the server"));
755 _parent_id
= deserialize_connection_id(e
.second
);
757 // remove stream connection from rpc connection list
758 _server
._conns
.erase(get_connection_id());
759 f
= smp::submit_to(_parent_id
.shard(), [this, c
= make_foreign(static_pointer_cast
<rpc::connection
>(shared_from_this()))] () mutable {
760 auto sit
= _servers
.find(*_server
._options
.streaming_domain
);
761 if (sit
== _servers
.end()) {
762 throw std::logic_error(format("Shard {:d} does not have server with streaming domain {:x}", engine().cpu_id(), *_server
._options
.streaming_domain
).c_str());
764 auto s
= sit
->second
;
765 auto it
= s
->_conns
.find(_parent_id
);
766 if (it
== s
->_conns
.end()) {
767 throw std::logic_error(format("Unknown parent connection {:d} on shard {:d}", _parent_id
, engine().cpu_id()).c_str());
769 auto id
= c
->get_connection_id();
770 it
->second
->register_stream(id
, make_lw_shared(std::move(c
)));
775 case protocol_features::ISOLATION
: {
776 auto&& isolation_cookie
= e
.second
;
777 _isolation_config
= _server
._limits
.isolate_connection(isolation_cookie
);
// When streaming is enabled, tell the client its server-assigned connection id.
786 if (_server
._options
.streaming_domain
) {
787 ret
[protocol_features::CONNECTION_ID
] = serialize_connection_id(_id
);
789 return f
.then([ret
= std::move(ret
)] {
// Receive the client's requested features, negotiate, and answer.
795 server::connection::negotiate_protocol(input_stream
<char>& in
) {
796 return receive_negotiation_frame(*this, in
).then([this] (feature_map requested_features
) {
797 return negotiate(std::move(requested_features
)).then([this] (feature_map returned_features
) {
798 return send_negotiation_frame(std::move(returned_features
));
// Frame policy for server-side requests without the timeout extension:
// header is (type:u64, msg_id:i64, size:u32); the optional expiry slot of the
// tuple stays nullopt.
// NOTE(review): lines 809-810, 812-813, 816, 822, 825, 828-830 elided
// (header_size()/role() returns and closing braces) — incomplete view.
803 struct request_frame
{
804 using opt_buf_type
= compat::optional
<rcv_buf
>;
805 using header_and_buffer_type
= std::tuple
<compat::optional
<uint64_t>, uint64_t, int64_t, opt_buf_type
>;
806 using return_type
= future
<header_and_buffer_type
>;
807 using header_type
= std::tuple
<compat::optional
<uint64_t>, uint64_t, int64_t, uint32_t>;
808 static size_t header_size() {
811 static const char* role() {
814 static auto empty_value() {
815 return make_ready_future
<header_and_buffer_type
>(header_and_buffer_type(compat::nullopt
, uint64_t(0), 0, compat::nullopt
));
817 static header_type
decode_header(const char* ptr
) {
818 auto type
= read_le
<uint64_t>(ptr
);
819 auto msgid
= read_le
<int64_t>(ptr
+ 8);
820 auto size
= read_le
<uint32_t>(ptr
+ 16);
821 return std::make_tuple(compat::nullopt
, type
, msgid
, size
);
823 static uint32_t get_size(const header_type
& t
) {
824 return std::get
<3>(t
);
826 static auto make_value(const header_type
& t
, rcv_buf data
) {
827 return make_ready_future
<header_and_buffer_type
>(header_and_buffer_type(std::get
<0>(t
), std::get
<1>(t
), std::get
<2>(t
), std::move(data
)));
// Variant used when the timeout feature was negotiated: an extra leading
// 8-byte relative-expiry field precedes the base header.
// NOTE(review): lines 834-835, 839-842 elided (header_size() return, the
// return of `h` and closing braces).
831 struct request_frame_with_timeout
: request_frame
{
832 using super
= request_frame
;
833 static size_t header_size() {
836 static typename
super::header_type
decode_header(const char* ptr
) {
837 auto h
= super::decode_header(ptr
+ 8);
838 std::get
<0>(h
) = read_le
<uint64_t>(ptr
);
// Read a request frame, choosing the timeout-aware decoder when the TIMEOUT
// feature was negotiated on this connection.
// NOTE(review): lines 847, 849-852 elided (else/closing braces).
843 future
<request_frame::header_and_buffer_type
>
844 server::connection::read_request_frame_compressed(input_stream
<char>& in
) {
845 if (_timeout_negotiated
) {
846 return read_frame_compressed
<request_frame_with_timeout
>(_info
.addr
, _compressor
, in
);
848 return read_frame_compressed
<request_frame
>(_info
.addr
, _compressor
, in
);
// Fill in the 12-byte response header (msg_id + payload length) at the front
// of `data` and queue it for sending.
853 server::connection::respond(int64_t msg_id
, snd_buf
&& data
, compat::optional
<rpc_clock_type::time_point
> timeout
) {
854 static_assert(snd_buf::chunk_size
>= 12, "send buffer chunk size is too small");
855 auto p
= data
.front().get_write();
856 write_le
<int64_t>(p
, msg_id
);
857 write_le
<uint32_t>(p
+ 8, data
.size
- 12);
858 return send(std::move(data
), timeout
);
// Reply with an UNKNOWN_VERB error (negated msg_id) carrying the unknown
// verb's type id, after reserving 28 bytes of send resources. The reply is
// sent in the background under the server's reply gate.
// NOTE(review): lines 864, 870, 877, 879-881 elided (the snd_buf `data`
// declaration, the try around with_gate, closing braces) — incomplete view.
861 future
<> server::connection::send_unknown_verb_reply(compat::optional
<rpc_clock_type::time_point
> timeout
, int64_t msg_id
, uint64_t type
) {
862 return wait_for_resources(28, timeout
).then([this, timeout
, msg_id
, type
] (auto permit
) {
863 // send unknown_verb exception back
865 static_assert(snd_buf::chunk_size
>= 28, "send buffer chunk size is too small");
866 auto p
= data
.front().get_write() + 12;
867 write_le
<uint32_t>(p
, uint32_t(exception_type::UNKNOWN_VERB
));
868 write_le
<uint32_t>(p
+ 4, uint32_t(8));
869 write_le
<uint64_t>(p
+ 8, type
);
871 // Send asynchronously.
872 // This is safe since connection::stop() will wait for background work.
873 (void)with_gate(_server
._reply_gate
, [this, timeout
, msg_id
, data
= std::move(data
), permit
= std::move(permit
)] () mutable {
874 // workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=83268
875 auto c
= shared_from_this();
876 return respond(-msg_id
, std::move(data
), timeout
).then([c
= std::move(c
), permit
= std::move(permit
)] {});
878 } catch(gate_closed_exception
&) {/* ignore */}
// Main server-side connection loop: negotiate the protocol, then under the
// isolated (or current) scheduling group read request frames until EOF/error,
// dispatching each to its handler (or replying UNKNOWN_VERB), followed by
// shutdown/cleanup in then_wrapped.
// NOTE(review): many original lines are elided (886, 888, 890, 896-897, 899,
// 903, 905, 907-908, 916-921, 923, 925, 927, 932, 934, 936-937, 939, 942-944,
// 948-949) — stream/request branch headers, the missing-handler guard,
// continuation boundaries and closing braces are not visible; incomplete view.
882 future
<> server::connection::process() {
883 return negotiate_protocol(_read_buf
).then([this] () mutable {
884 auto sg
= _isolation_config
? _isolation_config
->sched_group
: current_scheduling_group();
885 return with_scheduling_group(sg
, [this] {
887 return do_until([this] { return _read_buf
.eof() || _error
; }, [this] () mutable {
889 return handle_stream_frame();
891 return read_request_frame_compressed(_read_buf
).then([this] (request_frame::header_and_buffer_type header_and_buffer
) {
892 auto& expire
= std::get
<0>(header_and_buffer
);
893 auto& type
= std::get
<1>(header_and_buffer
);
894 auto& msg_id
= std::get
<2>(header_and_buffer
);
895 auto& data
= std::get
<3>(header_and_buffer
);
898 return make_ready_future
<>();
// Convert the relative on-wire expiry (ms) into an absolute time point.
900 compat::optional
<rpc_clock_type::time_point
> timeout
;
901 if (expire
&& *expire
) {
902 timeout
= relative_timeout_to_absolute(std::chrono::milliseconds(*expire
));
904 auto h
= _server
._proto
->get_handler(type
);
906 return send_unknown_verb_reply(timeout
, msg_id
, type
);
909 // If the new method of per-connection scheduling group was used, honor it.
910 // Otherwise, use the old per-handler scheduling group.
911 auto sg
= _isolation_config
? _isolation_config
->sched_group
: h
->sg
;
912 return with_scheduling_group(sg
, [this, timeout
, msg_id
, h
, data
= std::move(data
.value())] () mutable {
913 return h
->func(shared_from_this(), timeout
, msg_id
, std::move(data
)).finally([this, h
] {
914 // If anything between get_handler() and here throws, we leak put_handler
915 _server
._proto
->put_handler(h
);
922 }).then_wrapped([this] (future
<> f
) {
924 log_exception(*this, format("server{} connection dropped", is_stream() ? " stream" : "").c_str(), f
.get_exception());
926 _fd
.shutdown_input();
928 _stream_queue
.abort(std::make_exception_ptr(stream_closed()));
929 return stop_send_loop().then_wrapped([this] (future
<> f
) {
930 f
.ignore_ready_future();
931 _server
._conns
.erase(get_connection_id());
933 return deregister_this_stream();
935 return make_ready_future
<>();
938 _stopped
.set_value();
940 }).finally([conn_ptr
= shared_from_this()] {
941 // hold onto connection pointer until do_until() exists
// Server-side connection constructor: adopts the socket and records the
// peer's address in _info.
945 server::connection::connection(server
& s
, connected_socket
&& fd
, socket_address
&& addr
, const logger
& l
, void* serializer
, connection_id id
)
946 : rpc::connection(std::move(fd
), l
, serializer
, id
), _server(s
) {
947 _info
.addr
= std::move(addr
);
// Remove this stream connection from its parent connection's _streams map,
// hopping to the parent's shard. No-op when streaming is not configured;
// missing server/parent entries are tolerated silently.
// NOTE(review): lines 953, 961-965 elided (closing braces) — incomplete view.
950 future
<> server::connection::deregister_this_stream() {
951 if (!_server
._options
.streaming_domain
) {
952 return make_ready_future
<>();
954 return smp::submit_to(_parent_id
.shard(), [this] () mutable {
955 auto sit
= server::_servers
.find(*_server
._options
.streaming_domain
);
956 if (sit
!= server::_servers
.end()) {
957 auto s
= sit
->second
;
958 auto it
= s
->_conns
.find(_parent_id
);
959 if (it
!= s
->_conns
.end()) {
960 it
->second
->_streams
.erase(get_connection_id());
// Per-shard registry mapping streaming domains to their server instance.
966 thread_local
std::unordered_map
<streaming_domain_type
, server
*> server::_servers
;
// Convenience constructor: listen on `addr` with default server options.
968 server::server(protocol_base
* proto
, const socket_address
& addr
, resource_limits limits
)
969 : server(proto
, engine().listen(addr
, listen_options
{true}), limits
, server_options
{})
// Convenience constructor: listen on `addr` honoring the options' load
// balancing algorithm.
972 server::server(protocol_base
* proto
, server_options opts
, const socket_address
& addr
, resource_limits limits
)
973 : server(proto
, engine().listen(addr
, listen_options
{true, opts
.load_balancing_algorithm
}), limits
, opts
)
// Primary constructor: registers this server in the per-shard streaming
// domain registry (rejecting duplicates) when a streaming domain is set.
// NOTE(review): lines 978, 982, 984-987 elided (ctor body brace, accept()
// startup presumably, closing braces) — incomplete view.
976 server::server(protocol_base
* proto
, server_socket ss
, resource_limits limits
, server_options opts
)
977 : _proto(proto
), _ss(std::move(ss
)), _limits(limits
), _resources_available(limits
.max_memory
), _options(opts
)
979 if (_options
.streaming_domain
) {
980 if (_servers
.find(*_options
.streaming_domain
) != _servers
.end()) {
981 throw std::runtime_error(format("An RPC server with the streaming domain {} is already exist", *_options
.streaming_domain
));
983 _servers
[*_options
.streaming_domain
] = this;
// Convenience constructor: caller-supplied socket plus options.
988 server::server(protocol_base
* proto
, server_options opts
, server_socket ss
, resource_limits limits
)
989 : server(proto
, std::move(ss
), limits
, opts
)
// Background accept loop: for each accepted socket, assign a connection id
// (shard-tagged when streaming is enabled), create the protocol's server
// connection, record it in _conns, and start processing it in the background.
// NOTE(review): lines 1006, 1009, 1011-1014, 1016-1019 elided (use of the
// emplace result `r`, loop/continuation boundaries, the stop-condition
// handling of `f`) — incomplete view.
992 void server::accept() {
993 // Run asynchronously in background.
994 // Communicate result via __ss_stopped.
995 // The caller has to call server::stop() to synchronize.
996 (void)keep_doing([this] () mutable {
997 return _ss
.accept().then([this] (accept_result ar
) mutable {
998 auto fd
= std::move(ar
.connection
);
999 auto addr
= std::move(ar
.remote_address
);
1000 fd
.set_nodelay(_options
.tcp_nodelay
);
1001 connection_id id
= _options
.streaming_domain
?
1002 connection_id::make_id(_next_client_id
++, uint16_t(engine().cpu_id())) :
1003 connection_id::make_invalid_id(_next_client_id
++);
1004 auto conn
= _proto
->make_server_connection(*this, std::move(fd
), std::move(addr
), id
);
1005 auto r
= _conns
.emplace(id
, conn
);
1007 // Process asynchronously in background.
1008 (void)conn
->process();
1010 }).then_wrapped([this] (future
<>&& f
){
1015 _ss_stopped
.set_value();
// Stop the server: break resource waiters, deregister the streaming domain,
// and wait for the accept loop plus every live connection to stop.
// NOTE(review): lines 1021, 1025, 1029-1033 elided (listen-socket abort
// presumably, discard of when_all's result, closing braces).
1020 future
<> server::stop() {
1022 _resources_available
.broken();
1023 if (_options
.streaming_domain
) {
1024 _servers
.erase(*_options
.streaming_domain
);
1026 return when_all(_ss_stopped
.get_future(),
1027 parallel_for_each(_conns
| boost::adaptors::map_values
, [] (shared_ptr
<connection
> conn
) {
1028 return conn
->stop();
// Print a connection id as hexadecimal.
1034 std::ostream
& operator<<(std::ostream
& os
, const connection_id
& id
) {
1035 return fmt_print(os
, "{:x}", id
.id
);
// Print a streaming domain as its decimal id.
1038 std::ostream
& operator<<(std::ostream
& os
, const streaming_domain_type
& domain
) {
1039 return fmt_print(os
, "{:d}", domain
._id
);
// Default isolation policy: every cookie maps to the default isolation
// configuration (no per-connection scheduling group).
1042 isolation_config
default_isolate_connection(sstring isolation_cookie
) {
1043 return isolation_config
{};