]>
Commit | Line | Data |
---|---|---|
0eeb0dd1 | 1 | use std::io::IoSlice; |
c94ad247 | 2 | use std::marker::Unpin; |
0eeb0dd1 | 3 | use std::pin::Pin; |
c94ad247 DM |
4 | use std::sync::{Arc, Mutex}; |
5 | use std::time::{Duration, Instant}; | |
6 | ||
7 | use futures::Future; | |
0eeb0dd1 TL |
8 | use hyper::client::connect::{Connected, Connection}; |
9 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; | |
c94ad247 DM |
10 | use tokio::time::Sleep; |
11 | ||
12 | use std::task::{Context, Poll}; | |
13 | ||
0eeb0dd1 | 14 | use super::{RateLimiter, ShareableRateLimit}; |
8734d0c2 DM |
15 | |
16 | type SharedRateLimit = Arc<dyn ShareableRateLimit>; | |
c94ad247 | 17 | |
94d388b9 WB |
18 | pub type RateLimiterCallback = |
19 | dyn Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send; | |
20 | ||
c94ad247 DM |
21 | /// A rate limited stream using [RateLimiter] |
22 | pub struct RateLimitedStream<S> { | |
8734d0c2 | 23 | read_limiter: Option<SharedRateLimit>, |
c94ad247 | 24 | read_delay: Option<Pin<Box<Sleep>>>, |
8734d0c2 | 25 | write_limiter: Option<SharedRateLimit>, |
c94ad247 | 26 | write_delay: Option<Pin<Box<Sleep>>>, |
94d388b9 | 27 | update_limiter_cb: Option<Box<RateLimiterCallback>>, |
e0305f72 | 28 | last_limiter_update: Instant, |
c94ad247 DM |
29 | stream: S, |
30 | } | |
31 | ||
0c27d5da DM |
32 | impl RateLimitedStream<tokio::net::TcpStream> { |
33 | pub fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> { | |
34 | self.stream.peer_addr() | |
35 | } | |
36 | } | |
37 | ||
0eeb0dd1 | 38 | impl<S> RateLimitedStream<S> { |
c94ad247 DM |
39 | /// Creates a new instance with reads and writes limited to the same `rate`. |
40 | pub fn new(stream: S, rate: u64, bucket_size: u64) -> Self { | |
41 | let now = Instant::now(); | |
8734d0c2 DM |
42 | let read_limiter = RateLimiter::with_start_time(rate, bucket_size, now); |
43 | let read_limiter: SharedRateLimit = Arc::new(Mutex::new(read_limiter)); | |
44 | let write_limiter = RateLimiter::with_start_time(rate, bucket_size, now); | |
45 | let write_limiter: SharedRateLimit = Arc::new(Mutex::new(write_limiter)); | |
c94ad247 DM |
46 | Self::with_limiter(stream, Some(read_limiter), Some(write_limiter)) |
47 | } | |
48 | ||
c609a580 | 49 | /// Creates a new instance with specified [`RateLimiter`s](RateLimiter) for reads and writes. |
c94ad247 DM |
50 | pub fn with_limiter( |
51 | stream: S, | |
8734d0c2 DM |
52 | read_limiter: Option<SharedRateLimit>, |
53 | write_limiter: Option<SharedRateLimit>, | |
c94ad247 | 54 | ) -> Self { |
0eeb0dd1 | 55 | Self { |
c94ad247 DM |
56 | read_limiter, |
57 | read_delay: None, | |
58 | write_limiter, | |
59 | write_delay: None, | |
e0305f72 DM |
60 | update_limiter_cb: None, |
61 | last_limiter_update: Instant::now(), | |
c94ad247 DM |
62 | stream, |
63 | } | |
64 | } | |
e0305f72 DM |
65 | |
66 | /// Creates a new instance with limiter update callback. | |
67 | /// | |
68 | /// The fuction is called every minute to update/change the used limiters. | |
69 | /// | |
70 | /// Note: This function is called within an async context, so it | |
71 | /// should be fast and must not block. | |
0eeb0dd1 TL |
72 | pub fn with_limiter_update_cb< |
73 | F: Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send + 'static, | |
74 | >( | |
e0305f72 DM |
75 | stream: S, |
76 | update_limiter_cb: F, | |
77 | ) -> Self { | |
78 | let (read_limiter, write_limiter) = update_limiter_cb(); | |
79 | Self { | |
80 | read_limiter, | |
81 | read_delay: None, | |
82 | write_limiter, | |
83 | write_delay: None, | |
84 | update_limiter_cb: Some(Box::new(update_limiter_cb)), | |
85 | last_limiter_update: Instant::now(), | |
86 | stream, | |
87 | } | |
88 | } | |
89 | ||
90 | fn update_limiters(&mut self) { | |
91 | if let Some(ref update_limiter_cb) = self.update_limiter_cb { | |
92 | if self.last_limiter_update.elapsed().as_secs() >= 5 { | |
93 | self.last_limiter_update = Instant::now(); | |
94 | let (read_limiter, write_limiter) = update_limiter_cb(); | |
95 | self.read_limiter = read_limiter; | |
96 | self.write_limiter = write_limiter; | |
97 | } | |
98 | } | |
99 | } | |
c94ad247 DM |
100 | } |
101 | ||
0eeb0dd1 | 102 | fn register_traffic(limiter: &(dyn ShareableRateLimit), count: usize) -> Option<Pin<Box<Sleep>>> { |
ded24b3f DM |
103 | const MIN_DELAY: Duration = Duration::from_millis(10); |
104 | ||
105 | let now = Instant::now(); | |
8734d0c2 | 106 | let delay = limiter.register_traffic(now, count as u64); |
ded24b3f DM |
107 | if delay >= MIN_DELAY { |
108 | let sleep = tokio::time::sleep(delay); | |
109 | Some(Box::pin(sleep)) | |
110 | } else { | |
111 | None | |
112 | } | |
113 | } | |
114 | ||
115 | fn delay_is_ready(delay: &mut Option<Pin<Box<Sleep>>>, ctx: &mut Context<'_>) -> bool { | |
116 | match delay { | |
0eeb0dd1 | 117 | Some(ref mut future) => future.as_mut().poll(ctx).is_ready(), |
ded24b3f DM |
118 | None => true, |
119 | } | |
120 | } | |
121 | ||
0eeb0dd1 | 122 | impl<S: AsyncWrite + Unpin> AsyncWrite for RateLimitedStream<S> { |
c94ad247 DM |
123 | fn poll_write( |
124 | self: Pin<&mut Self>, | |
125 | ctx: &mut Context<'_>, | |
0eeb0dd1 | 126 | buf: &[u8], |
c94ad247 DM |
127 | ) -> Poll<Result<usize, std::io::Error>> { |
128 | let this = self.get_mut(); | |
129 | ||
ded24b3f | 130 | let is_ready = delay_is_ready(&mut this.write_delay, ctx); |
c94ad247 | 131 | |
0eeb0dd1 TL |
132 | if !is_ready { |
133 | return Poll::Pending; | |
134 | } | |
c94ad247 DM |
135 | |
136 | this.write_delay = None; | |
137 | ||
e0305f72 DM |
138 | this.update_limiters(); |
139 | ||
c94ad247 DM |
140 | let result = Pin::new(&mut this.stream).poll_write(ctx, buf); |
141 | ||
8734d0c2 | 142 | if let Some(ref mut limiter) = this.write_limiter { |
ded24b3f | 143 | if let Poll::Ready(Ok(count)) = result { |
8734d0c2 | 144 | this.write_delay = register_traffic(limiter.as_ref(), count); |
ded24b3f DM |
145 | } |
146 | } | |
147 | ||
148 | result | |
149 | } | |
150 | ||
151 | fn is_write_vectored(&self) -> bool { | |
152 | self.stream.is_write_vectored() | |
153 | } | |
154 | ||
155 | fn poll_write_vectored( | |
156 | self: Pin<&mut Self>, | |
157 | ctx: &mut Context<'_>, | |
0eeb0dd1 | 158 | bufs: &[IoSlice<'_>], |
ded24b3f DM |
159 | ) -> Poll<Result<usize, std::io::Error>> { |
160 | let this = self.get_mut(); | |
161 | ||
162 | let is_ready = delay_is_ready(&mut this.write_delay, ctx); | |
163 | ||
0eeb0dd1 TL |
164 | if !is_ready { |
165 | return Poll::Pending; | |
166 | } | |
ded24b3f DM |
167 | |
168 | this.write_delay = None; | |
169 | ||
e0305f72 DM |
170 | this.update_limiters(); |
171 | ||
ded24b3f DM |
172 | let result = Pin::new(&mut this.stream).poll_write_vectored(ctx, bufs); |
173 | ||
174 | if let Some(ref limiter) = this.write_limiter { | |
175 | if let Poll::Ready(Ok(count)) = result { | |
8734d0c2 | 176 | this.write_delay = register_traffic(limiter.as_ref(), count); |
c94ad247 DM |
177 | } |
178 | } | |
179 | ||
180 | result | |
181 | } | |
182 | ||
0eeb0dd1 | 183 | fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { |
c94ad247 DM |
184 | let this = self.get_mut(); |
185 | Pin::new(&mut this.stream).poll_flush(ctx) | |
186 | } | |
187 | ||
188 | fn poll_shutdown( | |
189 | self: Pin<&mut Self>, | |
0eeb0dd1 | 190 | ctx: &mut Context<'_>, |
c94ad247 DM |
191 | ) -> Poll<Result<(), std::io::Error>> { |
192 | let this = self.get_mut(); | |
193 | Pin::new(&mut this.stream).poll_shutdown(ctx) | |
194 | } | |
195 | } | |
196 | ||
0eeb0dd1 | 197 | impl<S: AsyncRead + Unpin> AsyncRead for RateLimitedStream<S> { |
c94ad247 DM |
198 | fn poll_read( |
199 | self: Pin<&mut Self>, | |
200 | ctx: &mut Context<'_>, | |
201 | buf: &mut ReadBuf<'_>, | |
202 | ) -> Poll<Result<(), std::io::Error>> { | |
203 | let this = self.get_mut(); | |
204 | ||
ded24b3f | 205 | let is_ready = delay_is_ready(&mut this.read_delay, ctx); |
c94ad247 | 206 | |
0eeb0dd1 TL |
207 | if !is_ready { |
208 | return Poll::Pending; | |
209 | } | |
c94ad247 DM |
210 | |
211 | this.read_delay = None; | |
212 | ||
e0305f72 DM |
213 | this.update_limiters(); |
214 | ||
c94ad247 DM |
215 | let filled_len = buf.filled().len(); |
216 | let result = Pin::new(&mut this.stream).poll_read(ctx, buf); | |
217 | ||
218 | if let Some(ref read_limiter) = this.read_limiter { | |
219 | if let Poll::Ready(Ok(())) = &result { | |
220 | let count = buf.filled().len() - filled_len; | |
8734d0c2 | 221 | this.read_delay = register_traffic(read_limiter.as_ref(), count); |
c94ad247 DM |
222 | } |
223 | } | |
224 | ||
225 | result | |
226 | } | |
c94ad247 | 227 | } |
00ca0b7f DM |
228 | |
229 | // we need this for the hyper http client | |
230 | impl<S: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RateLimitedStream<S> { | |
231 | fn connected(&self) -> Connected { | |
232 | self.stream.connected() | |
233 | } | |
234 | } |