]>
Commit | Line | Data |
---|---|---|
11fdf7f2 TL |
1 | /* |
2 | * This file is open source software, licensed to you under the terms | |
3 | * of the Apache License, Version 2.0 (the "License"). See the NOTICE file | |
4 | * distributed with this work for additional information regarding copyright | |
5 | * ownership. You may not use this file except in compliance with the License. | |
6 | * | |
7 | * You may obtain a copy of the License at | |
8 | * | |
9 | * http://www.apache.org/licenses/LICENSE-2.0 | |
10 | * | |
11 | * Unless required by applicable law or agreed to in writing, | |
12 | * software distributed under the License is distributed on an | |
13 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | * KIND, either express or implied. See the License for the | |
15 | * specific language governing permissions and limitations | |
16 | * under the License. | |
17 | */ | |
18 | /* | |
19 | * Copyright (C) 2016 ScyllaDB | |
20 | */ | |
21 | ||
22 | #pragma once | |
23 | ||
24 | #include <system_error> | |
25 | #include <seastar/core/iostream.hh> | |
26 | #include <seastar/core/circular_buffer.hh> | |
27 | #include <seastar/core/shared_ptr.hh> | |
28 | #include <seastar/core/queue.hh> | |
f67539c2 | 29 | #include <seastar/core/loop.hh> |
11fdf7f2 TL |
30 | #include <seastar/core/do_with.hh> |
31 | #include <seastar/net/stack.hh> | |
11fdf7f2 TL |
32 | #include <seastar/core/sharded.hh> |
33 | ||
34 | namespace seastar { | |
35 | ||
/// Test hook for forcing failures on loopback connections.
///
/// loopback_buffer::push() consults one of the *_snd_error() hooks and
/// loopback_buffer::pop() one of the *_rcv_error() hooks (the client_* hook for
/// CLIENT_TX buffers, the server_* hook otherwise). Returning true makes that
/// operation fail. The base implementation never injects an error, so tests
/// override only the hooks they need.
struct loopback_error_injector {
    virtual ~loopback_error_injector() = default;

    /// Consulted by pop() on a SERVER_TX buffer.
    virtual bool server_rcv_error() { return false; }
    /// Consulted by push() on a SERVER_TX buffer.
    virtual bool server_snd_error() { return false; }
    /// Consulted by pop() on a CLIENT_TX buffer.
    virtual bool client_rcv_error() { return false; }
    /// Consulted by push() on a CLIENT_TX buffer.
    virtual bool client_snd_error() { return false; }
};
43 | ||
44 | class loopback_buffer { | |
45 | public: | |
46 | enum class type : uint8_t { | |
47 | CLIENT_TX, | |
48 | SERVER_TX | |
49 | }; | |
50 | private: | |
51 | bool _aborted = false; | |
52 | queue<temporary_buffer<char>> _q{1}; | |
53 | loopback_error_injector* _error_injector; | |
54 | type _type; | |
55 | public: | |
56 | loopback_buffer(loopback_error_injector* error_injection, type t) : _error_injector(error_injection), _type(t) {} | |
57 | future<> push(temporary_buffer<char>&& b) { | |
58 | if (_aborted) { | |
59 | return make_exception_future<>(std::system_error(EPIPE, std::system_category())); | |
60 | } | |
61 | bool error = false; | |
62 | if (_error_injector) { | |
63 | error = _type == type::CLIENT_TX ? _error_injector->client_snd_error() : _error_injector->server_snd_error(); | |
64 | } | |
65 | if (error) { | |
66 | shutdown(); | |
67 | return make_exception_future<>(std::runtime_error("test injected error on send")); | |
68 | } | |
69 | return _q.push_eventually(std::move(b)); | |
70 | } | |
71 | future<temporary_buffer<char>> pop() { | |
72 | if (_aborted) { | |
73 | return make_exception_future<temporary_buffer<char>>(std::system_error(EPIPE, std::system_category())); | |
74 | } | |
75 | bool error = false; | |
76 | if (_error_injector) { | |
77 | error = _type == type::CLIENT_TX ? _error_injector->client_rcv_error() : _error_injector->server_rcv_error(); | |
78 | } | |
79 | if (error) { | |
80 | shutdown(); | |
81 | return make_exception_future<temporary_buffer<char>>(std::runtime_error("test injected error on receive")); | |
82 | } | |
83 | return _q.pop_eventually(); | |
84 | } | |
85 | void shutdown() { | |
86 | _aborted = true; | |
87 | _q.abort(std::make_exception_ptr(std::system_error(EPIPE, std::system_category()))); | |
88 | } | |
89 | }; | |
90 | ||
91 | class loopback_data_sink_impl : public data_sink_impl { | |
92 | foreign_ptr<lw_shared_ptr<loopback_buffer>>& _buffer; | |
93 | public: | |
94 | explicit loopback_data_sink_impl(foreign_ptr<lw_shared_ptr<loopback_buffer>>& buffer) | |
95 | : _buffer(buffer) { | |
96 | } | |
97 | future<> put(net::packet data) override { | |
98 | return do_with(data.release(), [this] (std::vector<temporary_buffer<char>>& bufs) { | |
99 | return do_for_each(bufs, [this] (temporary_buffer<char>& buf) { | |
100 | return smp::submit_to(_buffer.get_owner_shard(), [this, b = buf.get(), s = buf.size()] { | |
101 | return _buffer->push(temporary_buffer<char>(b, s)); | |
102 | }); | |
103 | }); | |
104 | }); | |
105 | } | |
106 | future<> close() override { | |
107 | return smp::submit_to(_buffer.get_owner_shard(), [this] { | |
9f95a23c TL |
108 | return _buffer->push({}).handle_exception_type([] (std::system_error& err) { |
109 | if (err.code().value() != EPIPE) { | |
110 | throw err; | |
111 | } | |
112 | }); | |
11fdf7f2 TL |
113 | }); |
114 | } | |
115 | }; | |
116 | ||
117 | class loopback_data_source_impl : public data_source_impl { | |
118 | bool _eof = false; | |
119 | lw_shared_ptr<loopback_buffer> _buffer; | |
120 | public: | |
121 | explicit loopback_data_source_impl(lw_shared_ptr<loopback_buffer> buffer) | |
122 | : _buffer(std::move(buffer)) { | |
123 | } | |
124 | future<temporary_buffer<char>> get() override { | |
125 | return _buffer->pop().then_wrapped([this] (future<temporary_buffer<char>>&& b) { | |
126 | _eof = b.failed(); | |
127 | if (!_eof) { | |
128 | // future::get0() is destructive, so we have to play these games | |
129 | // FIXME: make future::get0() non-destructive | |
130 | auto&& tmp = b.get0(); | |
131 | _eof = tmp.empty(); | |
132 | b = make_ready_future<temporary_buffer<char>>(std::move(tmp)); | |
133 | } | |
134 | return std::move(b); | |
135 | }); | |
136 | } | |
137 | future<> close() override { | |
138 | if (!_eof) { | |
139 | _buffer->shutdown(); | |
140 | } | |
141 | return make_ready_future<>(); | |
142 | } | |
143 | }; | |
144 | ||
145 | ||
146 | class loopback_connected_socket_impl : public net::connected_socket_impl { | |
147 | foreign_ptr<lw_shared_ptr<loopback_buffer>> _tx; | |
148 | lw_shared_ptr<loopback_buffer> _rx; | |
149 | public: | |
150 | loopback_connected_socket_impl(foreign_ptr<lw_shared_ptr<loopback_buffer>> tx, lw_shared_ptr<loopback_buffer> rx) | |
151 | : _tx(std::move(tx)), _rx(std::move(rx)) { | |
152 | } | |
153 | data_source source() override { | |
154 | return data_source(std::make_unique<loopback_data_source_impl>(_rx)); | |
155 | } | |
156 | data_sink sink() override { | |
157 | return data_sink(std::make_unique<loopback_data_sink_impl>(_tx)); | |
158 | } | |
159 | void shutdown_input() override { | |
160 | _rx->shutdown(); | |
161 | } | |
162 | void shutdown_output() override { | |
9f95a23c | 163 | (void)smp::submit_to(_tx.get_owner_shard(), [this] { |
11fdf7f2 TL |
164 | // FIXME: who holds to _tx? |
165 | _tx->shutdown(); | |
166 | }); | |
167 | } | |
168 | void set_nodelay(bool nodelay) override { | |
169 | } | |
170 | bool get_nodelay() const override { | |
171 | return true; | |
172 | } | |
173 | void set_keepalive(bool keepalive) override {} | |
174 | bool get_keepalive() const override { | |
175 | return false; | |
176 | } | |
177 | void set_keepalive_parameters(const net::keepalive_params&) override {} | |
178 | net::keepalive_params get_keepalive_parameters() const override { | |
179 | return net::tcp_keepalive_params {std::chrono::seconds(0), std::chrono::seconds(0), 0}; | |
180 | } | |
f67539c2 TL |
181 | void set_sockopt(int level, int optname, const void* data, size_t len) override { |
182 | throw std::runtime_error("Setting custom socket options is not supported for loopback"); | |
183 | } | |
184 | int get_sockopt(int level, int optname, void* data, size_t len) const override { | |
185 | throw std::runtime_error("Getting custom socket options is not supported for loopback"); | |
186 | } | |
20effc67 TL |
187 | socket_address local_address() const noexcept override { |
188 | // dummy | |
189 | return {}; | |
190 | } | |
11fdf7f2 TL |
191 | }; |
192 | ||
f67539c2 | 193 | class loopback_server_socket_impl : public net::server_socket_impl { |
11fdf7f2 TL |
194 | lw_shared_ptr<queue<connected_socket>> _pending; |
195 | public: | |
196 | explicit loopback_server_socket_impl(lw_shared_ptr<queue<connected_socket>> q) | |
197 | : _pending(std::move(q)) { | |
198 | } | |
9f95a23c | 199 | future<accept_result> accept() override { |
11fdf7f2 | 200 | return _pending->pop_eventually().then([] (connected_socket&& cs) { |
9f95a23c | 201 | return make_ready_future<accept_result>(accept_result{std::move(cs), socket_address()}); |
11fdf7f2 TL |
202 | }); |
203 | } | |
204 | void abort_accept() override { | |
205 | _pending->abort(std::make_exception_ptr(std::system_error(ECONNABORTED, std::system_category()))); | |
206 | } | |
9f95a23c TL |
207 | socket_address local_address() const override { |
208 | // CMH dummy | |
209 | return {}; | |
210 | } | |
11fdf7f2 TL |
211 | }; |
212 | ||
213 | ||
214 | class loopback_connection_factory { | |
215 | unsigned _shard = 0; | |
216 | std::vector<lw_shared_ptr<queue<connected_socket>>> _pending; | |
217 | public: | |
218 | loopback_connection_factory() { | |
219 | _pending.resize(smp::count); | |
220 | } | |
221 | server_socket get_server_socket() { | |
f67539c2 TL |
222 | if (!_pending[this_shard_id()]) { |
223 | _pending[this_shard_id()] = make_lw_shared<queue<connected_socket>>(10); | |
11fdf7f2 | 224 | } |
f67539c2 | 225 | return server_socket(std::make_unique<loopback_server_socket_impl>(_pending[this_shard_id()])); |
11fdf7f2 TL |
226 | } |
227 | future<> make_new_server_connection(foreign_ptr<lw_shared_ptr<loopback_buffer>> b1, lw_shared_ptr<loopback_buffer> b2) { | |
f67539c2 TL |
228 | if (!_pending[this_shard_id()]) { |
229 | _pending[this_shard_id()] = make_lw_shared<queue<connected_socket>>(10); | |
11fdf7f2 | 230 | } |
f67539c2 | 231 | return _pending[this_shard_id()]->push_eventually(connected_socket(std::make_unique<loopback_connected_socket_impl>(std::move(b1), b2))); |
11fdf7f2 TL |
232 | } |
233 | connected_socket make_new_client_connection(lw_shared_ptr<loopback_buffer> b1, foreign_ptr<lw_shared_ptr<loopback_buffer>> b2) { | |
234 | return connected_socket(std::make_unique<loopback_connected_socket_impl>(std::move(b2), b1)); | |
235 | } | |
236 | unsigned next_shard() { | |
237 | return _shard++ % smp::count; | |
238 | } | |
239 | void destroy_shard(unsigned shard) { | |
240 | _pending[shard] = nullptr; | |
241 | } | |
9f95a23c TL |
242 | future<> destroy_all_shards() { |
243 | return smp::invoke_on_all([this] () { | |
f67539c2 | 244 | destroy_shard(this_shard_id()); |
9f95a23c TL |
245 | }); |
246 | } | |
11fdf7f2 TL |
247 | }; |
248 | ||
249 | class loopback_socket_impl : public net::socket_impl { | |
250 | loopback_connection_factory& _factory; | |
251 | loopback_error_injector* _error_injector; | |
252 | lw_shared_ptr<loopback_buffer> _b1; | |
253 | foreign_ptr<lw_shared_ptr<loopback_buffer>> _b2; | |
254 | public: | |
255 | loopback_socket_impl(loopback_connection_factory& factory, loopback_error_injector* error_injector = nullptr) | |
256 | : _factory(factory), _error_injector(error_injector) | |
257 | { } | |
9f95a23c | 258 | future<connected_socket> connect(socket_address sa, socket_address local, seastar::transport proto = seastar::transport::TCP) override { |
11fdf7f2 TL |
259 | auto shard = _factory.next_shard(); |
260 | _b1 = make_lw_shared<loopback_buffer>(_error_injector, loopback_buffer::type::SERVER_TX); | |
261 | return smp::submit_to(shard, [this, b1 = make_foreign(_b1)] () mutable { | |
262 | auto b2 = make_lw_shared<loopback_buffer>(_error_injector, loopback_buffer::type::CLIENT_TX); | |
263 | _b2 = make_foreign(b2); | |
264 | return _factory.make_new_server_connection(std::move(b1), b2).then([b2] { | |
265 | return make_foreign(b2); | |
266 | }); | |
9f95a23c | 267 | }).then([this] (foreign_ptr<lw_shared_ptr<loopback_buffer>> b2) { |
11fdf7f2 TL |
268 | return _factory.make_new_client_connection(_b1, std::move(b2)); |
269 | }); | |
270 | } | |
9f95a23c TL |
271 | virtual void set_reuseaddr(bool reuseaddr) override {} |
272 | virtual bool get_reuseaddr() const override { return false; }; | |
11fdf7f2 | 273 | |
9f95a23c | 274 | void shutdown() override { |
11fdf7f2 | 275 | _b1->shutdown(); |
9f95a23c | 276 | (void)smp::submit_to(_b2.get_owner_shard(), [b2 = std::move(_b2)] { |
11fdf7f2 TL |
277 | b2->shutdown(); |
278 | }); | |
279 | } | |
280 | }; | |
281 | ||
282 | } |