]>
Commit | Line | Data |
---|---|---|
10a3ab22 | 1 | use std::future::Future; |
0eeb0dd1 | 2 | use std::io::IoSlice; |
c94ad247 | 3 | use std::marker::Unpin; |
0eeb0dd1 | 4 | use std::pin::Pin; |
c94ad247 DM |
5 | use std::sync::{Arc, Mutex}; |
6 | use std::time::{Duration, Instant}; | |
7 | ||
0eeb0dd1 TL |
8 | use hyper::client::connect::{Connected, Connection}; |
9 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; | |
c94ad247 DM |
10 | use tokio::time::Sleep; |
11 | ||
12 | use std::task::{Context, Poll}; | |
13 | ||
0eeb0dd1 | 14 | use super::{RateLimiter, ShareableRateLimit}; |
8734d0c2 DM |
15 | |
16 | type SharedRateLimit = Arc<dyn ShareableRateLimit>; | |
c94ad247 | 17 | |
94d388b9 WB |
18 | pub type RateLimiterCallback = |
19 | dyn Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send; | |
20 | ||
c94ad247 DM |
21 | /// A rate limited stream using [RateLimiter] |
22 | pub struct RateLimitedStream<S> { | |
8734d0c2 | 23 | read_limiter: Option<SharedRateLimit>, |
c94ad247 | 24 | read_delay: Option<Pin<Box<Sleep>>>, |
8734d0c2 | 25 | write_limiter: Option<SharedRateLimit>, |
c94ad247 | 26 | write_delay: Option<Pin<Box<Sleep>>>, |
94d388b9 | 27 | update_limiter_cb: Option<Box<RateLimiterCallback>>, |
e0305f72 | 28 | last_limiter_update: Instant, |
c94ad247 DM |
29 | stream: S, |
30 | } | |
31 | ||
0eeb0dd1 | 32 | impl<S> RateLimitedStream<S> { |
c94ad247 DM |
33 | /// Creates a new instance with reads and writes limited to the same `rate`. |
34 | pub fn new(stream: S, rate: u64, bucket_size: u64) -> Self { | |
35 | let now = Instant::now(); | |
8734d0c2 DM |
36 | let read_limiter = RateLimiter::with_start_time(rate, bucket_size, now); |
37 | let read_limiter: SharedRateLimit = Arc::new(Mutex::new(read_limiter)); | |
38 | let write_limiter = RateLimiter::with_start_time(rate, bucket_size, now); | |
39 | let write_limiter: SharedRateLimit = Arc::new(Mutex::new(write_limiter)); | |
c94ad247 DM |
40 | Self::with_limiter(stream, Some(read_limiter), Some(write_limiter)) |
41 | } | |
42 | ||
c609a580 | 43 | /// Creates a new instance with specified [`RateLimiter`s](RateLimiter) for reads and writes. |
c94ad247 DM |
44 | pub fn with_limiter( |
45 | stream: S, | |
8734d0c2 DM |
46 | read_limiter: Option<SharedRateLimit>, |
47 | write_limiter: Option<SharedRateLimit>, | |
c94ad247 | 48 | ) -> Self { |
0eeb0dd1 | 49 | Self { |
c94ad247 DM |
50 | read_limiter, |
51 | read_delay: None, | |
52 | write_limiter, | |
53 | write_delay: None, | |
e0305f72 DM |
54 | update_limiter_cb: None, |
55 | last_limiter_update: Instant::now(), | |
c94ad247 DM |
56 | stream, |
57 | } | |
58 | } | |
e0305f72 DM |
59 | |
60 | /// Creates a new instance with limiter update callback. | |
61 | /// | |
3ac6f2d9 | 62 | /// The function is called every minute to update/change the used limiters. |
e0305f72 DM |
63 | /// |
64 | /// Note: This function is called within an async context, so it | |
65 | /// should be fast and must not block. | |
0eeb0dd1 TL |
66 | pub fn with_limiter_update_cb< |
67 | F: Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send + 'static, | |
68 | >( | |
e0305f72 DM |
69 | stream: S, |
70 | update_limiter_cb: F, | |
71 | ) -> Self { | |
72 | let (read_limiter, write_limiter) = update_limiter_cb(); | |
73 | Self { | |
74 | read_limiter, | |
75 | read_delay: None, | |
76 | write_limiter, | |
77 | write_delay: None, | |
78 | update_limiter_cb: Some(Box::new(update_limiter_cb)), | |
79 | last_limiter_update: Instant::now(), | |
80 | stream, | |
81 | } | |
82 | } | |
83 | ||
84 | fn update_limiters(&mut self) { | |
85 | if let Some(ref update_limiter_cb) = self.update_limiter_cb { | |
86 | if self.last_limiter_update.elapsed().as_secs() >= 5 { | |
87 | self.last_limiter_update = Instant::now(); | |
88 | let (read_limiter, write_limiter) = update_limiter_cb(); | |
89 | self.read_limiter = read_limiter; | |
90 | self.write_limiter = write_limiter; | |
91 | } | |
92 | } | |
93 | } | |
d7ed04f8 WB |
94 | |
95 | pub fn inner(&self) -> &S { | |
96 | &self.stream | |
97 | } | |
98 | ||
99 | pub fn inner_mut(&mut self) -> &mut S { | |
100 | &mut self.stream | |
101 | } | |
c94ad247 DM |
102 | } |
103 | ||
0eeb0dd1 | 104 | fn register_traffic(limiter: &(dyn ShareableRateLimit), count: usize) -> Option<Pin<Box<Sleep>>> { |
ded24b3f DM |
105 | const MIN_DELAY: Duration = Duration::from_millis(10); |
106 | ||
107 | let now = Instant::now(); | |
8734d0c2 | 108 | let delay = limiter.register_traffic(now, count as u64); |
ded24b3f DM |
109 | if delay >= MIN_DELAY { |
110 | let sleep = tokio::time::sleep(delay); | |
111 | Some(Box::pin(sleep)) | |
112 | } else { | |
113 | None | |
114 | } | |
115 | } | |
116 | ||
117 | fn delay_is_ready(delay: &mut Option<Pin<Box<Sleep>>>, ctx: &mut Context<'_>) -> bool { | |
118 | match delay { | |
0eeb0dd1 | 119 | Some(ref mut future) => future.as_mut().poll(ctx).is_ready(), |
ded24b3f DM |
120 | None => true, |
121 | } | |
122 | } | |
123 | ||
0eeb0dd1 | 124 | impl<S: AsyncWrite + Unpin> AsyncWrite for RateLimitedStream<S> { |
c94ad247 DM |
125 | fn poll_write( |
126 | self: Pin<&mut Self>, | |
127 | ctx: &mut Context<'_>, | |
0eeb0dd1 | 128 | buf: &[u8], |
c94ad247 DM |
129 | ) -> Poll<Result<usize, std::io::Error>> { |
130 | let this = self.get_mut(); | |
131 | ||
ded24b3f | 132 | let is_ready = delay_is_ready(&mut this.write_delay, ctx); |
c94ad247 | 133 | |
0eeb0dd1 TL |
134 | if !is_ready { |
135 | return Poll::Pending; | |
136 | } | |
c94ad247 DM |
137 | |
138 | this.write_delay = None; | |
139 | ||
e0305f72 DM |
140 | this.update_limiters(); |
141 | ||
c94ad247 DM |
142 | let result = Pin::new(&mut this.stream).poll_write(ctx, buf); |
143 | ||
8734d0c2 | 144 | if let Some(ref mut limiter) = this.write_limiter { |
ded24b3f | 145 | if let Poll::Ready(Ok(count)) = result { |
8734d0c2 | 146 | this.write_delay = register_traffic(limiter.as_ref(), count); |
ded24b3f DM |
147 | } |
148 | } | |
149 | ||
150 | result | |
151 | } | |
152 | ||
153 | fn is_write_vectored(&self) -> bool { | |
154 | self.stream.is_write_vectored() | |
155 | } | |
156 | ||
157 | fn poll_write_vectored( | |
158 | self: Pin<&mut Self>, | |
159 | ctx: &mut Context<'_>, | |
0eeb0dd1 | 160 | bufs: &[IoSlice<'_>], |
ded24b3f DM |
161 | ) -> Poll<Result<usize, std::io::Error>> { |
162 | let this = self.get_mut(); | |
163 | ||
164 | let is_ready = delay_is_ready(&mut this.write_delay, ctx); | |
165 | ||
0eeb0dd1 TL |
166 | if !is_ready { |
167 | return Poll::Pending; | |
168 | } | |
ded24b3f DM |
169 | |
170 | this.write_delay = None; | |
171 | ||
e0305f72 DM |
172 | this.update_limiters(); |
173 | ||
ded24b3f DM |
174 | let result = Pin::new(&mut this.stream).poll_write_vectored(ctx, bufs); |
175 | ||
176 | if let Some(ref limiter) = this.write_limiter { | |
177 | if let Poll::Ready(Ok(count)) = result { | |
8734d0c2 | 178 | this.write_delay = register_traffic(limiter.as_ref(), count); |
c94ad247 DM |
179 | } |
180 | } | |
181 | ||
182 | result | |
183 | } | |
184 | ||
0eeb0dd1 | 185 | fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> { |
c94ad247 DM |
186 | let this = self.get_mut(); |
187 | Pin::new(&mut this.stream).poll_flush(ctx) | |
188 | } | |
189 | ||
190 | fn poll_shutdown( | |
191 | self: Pin<&mut Self>, | |
0eeb0dd1 | 192 | ctx: &mut Context<'_>, |
c94ad247 DM |
193 | ) -> Poll<Result<(), std::io::Error>> { |
194 | let this = self.get_mut(); | |
195 | Pin::new(&mut this.stream).poll_shutdown(ctx) | |
196 | } | |
197 | } | |
198 | ||
0eeb0dd1 | 199 | impl<S: AsyncRead + Unpin> AsyncRead for RateLimitedStream<S> { |
c94ad247 DM |
200 | fn poll_read( |
201 | self: Pin<&mut Self>, | |
202 | ctx: &mut Context<'_>, | |
203 | buf: &mut ReadBuf<'_>, | |
204 | ) -> Poll<Result<(), std::io::Error>> { | |
205 | let this = self.get_mut(); | |
206 | ||
ded24b3f | 207 | let is_ready = delay_is_ready(&mut this.read_delay, ctx); |
c94ad247 | 208 | |
0eeb0dd1 TL |
209 | if !is_ready { |
210 | return Poll::Pending; | |
211 | } | |
c94ad247 DM |
212 | |
213 | this.read_delay = None; | |
214 | ||
e0305f72 DM |
215 | this.update_limiters(); |
216 | ||
c94ad247 DM |
217 | let filled_len = buf.filled().len(); |
218 | let result = Pin::new(&mut this.stream).poll_read(ctx, buf); | |
219 | ||
220 | if let Some(ref read_limiter) = this.read_limiter { | |
221 | if let Poll::Ready(Ok(())) = &result { | |
222 | let count = buf.filled().len() - filled_len; | |
8734d0c2 | 223 | this.read_delay = register_traffic(read_limiter.as_ref(), count); |
c94ad247 DM |
224 | } |
225 | } | |
226 | ||
227 | result | |
228 | } | |
c94ad247 | 229 | } |
00ca0b7f DM |
230 | |
231 | // we need this for the hyper http client | |
232 | impl<S: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RateLimitedStream<S> { | |
233 | fn connected(&self) -> Connected { | |
234 | self.stream.connected() | |
235 | } | |
236 | } |