]> git.proxmox.com Git - ceph.git/blob - ceph/src/seastar/tests/unit/loopback_socket.hh
update ceph source to reef 18.1.2
[ceph.git] / ceph / src / seastar / tests / unit / loopback_socket.hh
1 /*
2 * This file is open source software, licensed to you under the terms
3 * of the Apache License, Version 2.0 (the "License"). See the NOTICE file
4 * distributed with this work for additional information regarding copyright
5 * ownership. You may not use this file except in compliance with the License.
6 *
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing,
12 * software distributed under the License is distributed on an
13 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 * KIND, either express or implied. See the License for the
15 * specific language governing permissions and limitations
16 * under the License.
17 */
18 /*
19 * Copyright (C) 2016 ScyllaDB
20 */
21
22 #pragma once
23
24 #include <system_error>
25 #include <seastar/core/iostream.hh>
26 #include <seastar/core/circular_buffer.hh>
27 #include <seastar/core/shared_ptr.hh>
28 #include <seastar/core/queue.hh>
29 #include <seastar/core/loop.hh>
30 #include <seastar/core/do_with.hh>
31 #include <seastar/net/stack.hh>
32 #include <seastar/core/sharded.hh>
33
34 namespace seastar {
35
/// Test hook consulted by the loopback sockets to simulate I/O failures.
///
/// Every hook returns error::none by default; a test derives from this struct
/// and overrides only the hooks it wants to fail. As interpreted by the
/// loopback classes in this header:
///  - none:     the operation proceeds normally;
///  - one_shot: only the current operation fails;
///  - abort:    the affected loopback_buffer is shut down, then the operation fails.
/// A non-none connect_error() makes connect() return a future that stays
/// pending until the socket is shut down.
struct loopback_error_injector {
    enum class error { none, one_shot, abort };

    virtual ~loopback_error_injector() = default;

    // Receive/send hooks for the server side.
    virtual error server_rcv_error() { return error::none; }
    virtual error server_snd_error() { return error::none; }
    // Receive/send hooks for the client side.
    virtual error client_rcv_error() { return error::none; }
    virtual error client_snd_error() { return error::none; }
    // Consulted once at the start of every connect().
    virtual error connect_error() { return error::none; }
};
45
46 class loopback_buffer {
47 public:
48 enum class type : uint8_t {
49 CLIENT_TX,
50 SERVER_TX
51 };
52 private:
53 bool _aborted = false;
54 queue<temporary_buffer<char>> _q{1};
55 loopback_error_injector* _error_injector;
56 type _type;
57 public:
58 loopback_buffer(loopback_error_injector* error_injection, type t) : _error_injector(error_injection), _type(t) {}
59 future<> push(temporary_buffer<char>&& b) {
60 if (_aborted) {
61 return make_exception_future<>(std::system_error(EPIPE, std::system_category()));
62 }
63 if (_error_injector) {
64 auto error = _type == type::CLIENT_TX ? _error_injector->client_snd_error() : _error_injector->server_snd_error();
65 if (error == loopback_error_injector::error::one_shot) {
66 return make_exception_future<>(std::runtime_error("test injected glitch on send"));
67 }
68 if (error == loopback_error_injector::error::abort) {
69 shutdown();
70 return make_exception_future<>(std::runtime_error("test injected error on send"));
71 }
72 }
73 return _q.push_eventually(std::move(b));
74 }
75 future<temporary_buffer<char>> pop() {
76 if (_aborted) {
77 return make_exception_future<temporary_buffer<char>>(std::system_error(EPIPE, std::system_category()));
78 }
79 if (_error_injector) {
80 auto error = _type == type::CLIENT_TX ? _error_injector->client_rcv_error() : _error_injector->server_rcv_error();
81 if (error == loopback_error_injector::error::one_shot) {
82 return make_exception_future<temporary_buffer<char>>(std::runtime_error("test injected glitch on receive"));
83 }
84 if (error == loopback_error_injector::error::abort) {
85 shutdown();
86 return make_exception_future<temporary_buffer<char>>(std::runtime_error("test injected error on receive"));
87 }
88 }
89 return _q.pop_eventually();
90 }
91 void shutdown() noexcept {
92 _aborted = true;
93 _q.abort(std::make_exception_ptr(std::system_error(EPIPE, std::system_category())));
94 }
95 };
96
97 class loopback_data_sink_impl : public data_sink_impl {
98 lw_shared_ptr<foreign_ptr<lw_shared_ptr<loopback_buffer>>> _buffer;
99 public:
100 explicit loopback_data_sink_impl(lw_shared_ptr<foreign_ptr<lw_shared_ptr<loopback_buffer>>> buffer)
101 : _buffer(buffer) {
102 }
103 future<> put(net::packet data) override {
104 return do_with(data.release(), [this] (std::vector<temporary_buffer<char>>& bufs) {
105 return do_for_each(bufs, [this] (temporary_buffer<char>& buf) {
106 return smp::submit_to(_buffer->get_owner_shard(), [this, b = buf.get(), s = buf.size()] {
107 return (*_buffer)->push(temporary_buffer<char>(b, s));
108 });
109 });
110 });
111 }
112 future<> close() override {
113 return smp::submit_to(_buffer->get_owner_shard(), [this] {
114 return (*_buffer)->push({}).handle_exception_type([] (std::system_error& err) {
115 if (err.code().value() != EPIPE) {
116 throw err;
117 }
118 });
119 });
120 }
121 };
122
123 class loopback_data_source_impl : public data_source_impl {
124 bool _eof = false;
125 lw_shared_ptr<loopback_buffer> _buffer;
126 public:
127 explicit loopback_data_source_impl(lw_shared_ptr<loopback_buffer> buffer)
128 : _buffer(std::move(buffer)) {
129 }
130 future<temporary_buffer<char>> get() override {
131 return _buffer->pop().then_wrapped([this] (future<temporary_buffer<char>>&& b) {
132 _eof = b.failed();
133 if (!_eof) {
134 // future::get0() is destructive, so we have to play these games
135 // FIXME: make future::get0() non-destructive
136 auto&& tmp = b.get0();
137 _eof = tmp.empty();
138 b = make_ready_future<temporary_buffer<char>>(std::move(tmp));
139 }
140 return std::move(b);
141 });
142 }
143 future<> close() override {
144 if (!_eof) {
145 _buffer->shutdown();
146 }
147 return make_ready_future<>();
148 }
149 };
150
151
152 class loopback_connected_socket_impl : public net::connected_socket_impl {
153 lw_shared_ptr<foreign_ptr<lw_shared_ptr<loopback_buffer>>> _tx;
154 lw_shared_ptr<loopback_buffer> _rx;
155 public:
156 loopback_connected_socket_impl(foreign_ptr<lw_shared_ptr<loopback_buffer>> tx, lw_shared_ptr<loopback_buffer> rx)
157 : _tx(make_lw_shared(std::move(tx))), _rx(std::move(rx)) {
158 }
159 data_source source() override {
160 return data_source(std::make_unique<loopback_data_source_impl>(_rx));
161 }
162 data_sink sink() override {
163 return data_sink(std::make_unique<loopback_data_sink_impl>(_tx));
164 }
165 void shutdown_input() override {
166 _rx->shutdown();
167 }
168 void shutdown_output() override {
169 (void)smp::submit_to(_tx->get_owner_shard(), [tx = _tx] {
170 (*tx)->shutdown();
171 });
172 }
173 void set_nodelay(bool nodelay) override {
174 }
175 bool get_nodelay() const override {
176 return true;
177 }
178 void set_keepalive(bool keepalive) override {}
179 bool get_keepalive() const override {
180 return false;
181 }
182 void set_keepalive_parameters(const net::keepalive_params&) override {}
183 net::keepalive_params get_keepalive_parameters() const override {
184 return net::tcp_keepalive_params {std::chrono::seconds(0), std::chrono::seconds(0), 0};
185 }
186 void set_sockopt(int level, int optname, const void* data, size_t len) override {
187 throw std::runtime_error("Setting custom socket options is not supported for loopback");
188 }
189 int get_sockopt(int level, int optname, void* data, size_t len) const override {
190 throw std::runtime_error("Getting custom socket options is not supported for loopback");
191 }
192 socket_address local_address() const noexcept override {
193 // dummy
194 return {};
195 }
196 future<> wait_input_shutdown() override {
197 abort(); // No tests use this
198 return make_ready_future<>();
199 }
200 };
201
202 class loopback_server_socket_impl : public net::server_socket_impl {
203 lw_shared_ptr<queue<connected_socket>> _pending;
204 public:
205 explicit loopback_server_socket_impl(lw_shared_ptr<queue<connected_socket>> q)
206 : _pending(std::move(q)) {
207 }
208 future<accept_result> accept() override {
209 return _pending->pop_eventually().then([] (connected_socket&& cs) {
210 return make_ready_future<accept_result>(accept_result{std::move(cs), socket_address()});
211 });
212 }
213 void abort_accept() override {
214 _pending->abort(std::make_exception_ptr(std::system_error(ECONNABORTED, std::system_category())));
215 }
216 socket_address local_address() const override {
217 // CMH dummy
218 return {};
219 }
220 };
221
222
223 class loopback_connection_factory {
224 unsigned _shard = 0;
225 unsigned _shards_count;
226 std::vector<lw_shared_ptr<queue<connected_socket>>> _pending;
227 public:
228 explicit loopback_connection_factory(unsigned shards_count = smp::count)
229 : _shards_count(shards_count)
230 {
231 _pending.resize(shards_count);
232 }
233 server_socket get_server_socket() {
234 assert(this_shard_id() < _shards_count);
235 if (!_pending[this_shard_id()]) {
236 _pending[this_shard_id()] = make_lw_shared<queue<connected_socket>>(10);
237 }
238 return server_socket(std::make_unique<loopback_server_socket_impl>(_pending[this_shard_id()]));
239 }
240 future<> make_new_server_connection(foreign_ptr<lw_shared_ptr<loopback_buffer>> b1, lw_shared_ptr<loopback_buffer> b2) {
241 assert(this_shard_id() < _shards_count);
242 if (!_pending[this_shard_id()]) {
243 _pending[this_shard_id()] = make_lw_shared<queue<connected_socket>>(10);
244 }
245 return _pending[this_shard_id()]->push_eventually(connected_socket(std::make_unique<loopback_connected_socket_impl>(std::move(b1), b2)));
246 }
247 connected_socket make_new_client_connection(lw_shared_ptr<loopback_buffer> b1, foreign_ptr<lw_shared_ptr<loopback_buffer>> b2) {
248 return connected_socket(std::make_unique<loopback_connected_socket_impl>(std::move(b2), b1));
249 }
250 unsigned next_shard() {
251 return _shard++ % _shards_count;
252 }
253 void destroy_shard(unsigned shard) {
254 assert(shard < _shards_count);
255 _pending[shard] = nullptr;
256 }
257 future<> destroy_all_shards() {
258 return parallel_for_each(boost::irange(0u, _shards_count), [this](shard_id shard) {
259 return smp::submit_to(shard, [this] {
260 destroy_shard(this_shard_id());
261 });
262 });
263 }
264 };
265
266 class loopback_socket_impl : public net::socket_impl {
267 loopback_connection_factory& _factory;
268 loopback_error_injector* _error_injector;
269 lw_shared_ptr<loopback_buffer> _b1;
270 foreign_ptr<lw_shared_ptr<loopback_buffer>> _b2;
271 std::optional<promise<connected_socket>> _connect_abort;
272 public:
273 loopback_socket_impl(loopback_connection_factory& factory, loopback_error_injector* error_injector = nullptr)
274 : _factory(factory), _error_injector(error_injector)
275 { }
276 future<connected_socket> connect(socket_address sa, socket_address local, seastar::transport proto = seastar::transport::TCP) override {
277 if (_error_injector) {
278 auto error = _error_injector->connect_error();
279 if (error != loopback_error_injector::error::none) {
280 _connect_abort.emplace();
281 return _connect_abort->get_future();
282 }
283 }
284
285 auto shard = _factory.next_shard();
286 _b1 = make_lw_shared<loopback_buffer>(_error_injector, loopback_buffer::type::SERVER_TX);
287 return smp::submit_to(shard, [this, b1 = make_foreign(_b1)] () mutable {
288 auto b2 = make_lw_shared<loopback_buffer>(_error_injector, loopback_buffer::type::CLIENT_TX);
289 _b2 = make_foreign(b2);
290 return _factory.make_new_server_connection(std::move(b1), b2).then([b2] {
291 return make_foreign(b2);
292 });
293 }).then([this] (foreign_ptr<lw_shared_ptr<loopback_buffer>> b2) {
294 return _factory.make_new_client_connection(_b1, std::move(b2));
295 });
296 }
297 virtual void set_reuseaddr(bool reuseaddr) override {}
298 virtual bool get_reuseaddr() const override { return false; };
299
300 void shutdown() override {
301 if (_connect_abort) {
302 _connect_abort->set_exception(std::make_exception_ptr(std::system_error(ECONNABORTED, std::system_category())));
303 _connect_abort = std::nullopt;
304 } else {
305 _b1->shutdown();
306 (void)smp::submit_to(_b2.get_owner_shard(), [b2 = std::move(_b2)] {
307 b2->shutdown();
308 });
309 }
310 }
311 };
312
313 }