]> git.proxmox.com Git - proxmox.git/blob - proxmox-http/src/client/rate_limited_stream.rs
proxmox-http: use SharedRateLimit trait object for RateLimitedStream
[proxmox.git] / proxmox-http / src / client / rate_limited_stream.rs
1 use std::pin::Pin;
2 use std::marker::Unpin;
3 use std::sync::{Arc, Mutex};
4 use std::time::{Duration, Instant};
5 use std::io::IoSlice;
6
7 use futures::Future;
8 use tokio::io::{ReadBuf, AsyncRead, AsyncWrite};
9 use tokio::time::Sleep;
10 use hyper::client::connect::{Connection, Connected};
11
12 use std::task::{Context, Poll};
13
14 use super::{ShareableRateLimit, RateLimiter};
15
16 type SharedRateLimit = Arc<dyn ShareableRateLimit>;
17
18 /// A rate limited stream using [RateLimiter]
19 pub struct RateLimitedStream<S> {
20 read_limiter: Option<SharedRateLimit>,
21 read_delay: Option<Pin<Box<Sleep>>>,
22 write_limiter: Option<SharedRateLimit>,
23 write_delay: Option<Pin<Box<Sleep>>>,
24 update_limiter_cb: Option<Box<dyn Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send>>,
25 last_limiter_update: Instant,
26 stream: S,
27 }
28
29 impl RateLimitedStream<tokio::net::TcpStream> {
30 pub fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
31 self.stream.peer_addr()
32 }
33 }
34
35 impl <S> RateLimitedStream<S> {
36
37 /// Creates a new instance with reads and writes limited to the same `rate`.
38 pub fn new(stream: S, rate: u64, bucket_size: u64) -> Self {
39 let now = Instant::now();
40 let read_limiter = RateLimiter::with_start_time(rate, bucket_size, now);
41 let read_limiter: SharedRateLimit = Arc::new(Mutex::new(read_limiter));
42 let write_limiter = RateLimiter::with_start_time(rate, bucket_size, now);
43 let write_limiter: SharedRateLimit = Arc::new(Mutex::new(write_limiter));
44 Self::with_limiter(stream, Some(read_limiter), Some(write_limiter))
45 }
46
47 /// Creates a new instance with specified [RateLimiters] for reads and writes.
48 pub fn with_limiter(
49 stream: S,
50 read_limiter: Option<SharedRateLimit>,
51 write_limiter: Option<SharedRateLimit>,
52 ) -> Self {
53 Self {
54 read_limiter,
55 read_delay: None,
56 write_limiter,
57 write_delay: None,
58 update_limiter_cb: None,
59 last_limiter_update: Instant::now(),
60 stream,
61 }
62 }
63
64 /// Creates a new instance with limiter update callback.
65 ///
66 /// The fuction is called every minute to update/change the used limiters.
67 ///
68 /// Note: This function is called within an async context, so it
69 /// should be fast and must not block.
70 pub fn with_limiter_update_cb<F: Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send + 'static>(
71 stream: S,
72 update_limiter_cb: F,
73 ) -> Self {
74 let (read_limiter, write_limiter) = update_limiter_cb();
75 Self {
76 read_limiter,
77 read_delay: None,
78 write_limiter,
79 write_delay: None,
80 update_limiter_cb: Some(Box::new(update_limiter_cb)),
81 last_limiter_update: Instant::now(),
82 stream,
83 }
84 }
85
86 fn update_limiters(&mut self) {
87 if let Some(ref update_limiter_cb) = self.update_limiter_cb {
88 if self.last_limiter_update.elapsed().as_secs() >= 5 {
89 self.last_limiter_update = Instant::now();
90 let (read_limiter, write_limiter) = update_limiter_cb();
91 self.read_limiter = read_limiter;
92 self.write_limiter = write_limiter;
93 }
94 }
95 }
96 }
97
98 fn register_traffic(
99 limiter: &(dyn ShareableRateLimit),
100 count: usize,
101 ) -> Option<Pin<Box<Sleep>>>{
102
103 const MIN_DELAY: Duration = Duration::from_millis(10);
104
105 let now = Instant::now();
106 let delay = limiter.register_traffic(now, count as u64);
107 if delay >= MIN_DELAY {
108 let sleep = tokio::time::sleep(delay);
109 Some(Box::pin(sleep))
110 } else {
111 None
112 }
113 }
114
115 fn delay_is_ready(delay: &mut Option<Pin<Box<Sleep>>>, ctx: &mut Context<'_>) -> bool {
116 match delay {
117 Some(ref mut future) => {
118 future.as_mut().poll(ctx).is_ready()
119 }
120 None => true,
121 }
122 }
123
124 impl <S: AsyncWrite + Unpin> AsyncWrite for RateLimitedStream<S> {
125
126 fn poll_write(
127 self: Pin<&mut Self>,
128 ctx: &mut Context<'_>,
129 buf: &[u8]
130 ) -> Poll<Result<usize, std::io::Error>> {
131 let this = self.get_mut();
132
133 let is_ready = delay_is_ready(&mut this.write_delay, ctx);
134
135 if !is_ready { return Poll::Pending; }
136
137 this.write_delay = None;
138
139 this.update_limiters();
140
141 let result = Pin::new(&mut this.stream).poll_write(ctx, buf);
142
143 if let Some(ref mut limiter) = this.write_limiter {
144 if let Poll::Ready(Ok(count)) = result {
145 this.write_delay = register_traffic(limiter.as_ref(), count);
146 }
147 }
148
149 result
150 }
151
152 fn is_write_vectored(&self) -> bool {
153 self.stream.is_write_vectored()
154 }
155
156 fn poll_write_vectored(
157 self: Pin<&mut Self>,
158 ctx: &mut Context<'_>,
159 bufs: &[IoSlice<'_>]
160 ) -> Poll<Result<usize, std::io::Error>> {
161 let this = self.get_mut();
162
163 let is_ready = delay_is_ready(&mut this.write_delay, ctx);
164
165 if !is_ready { return Poll::Pending; }
166
167 this.write_delay = None;
168
169 this.update_limiters();
170
171 let result = Pin::new(&mut this.stream).poll_write_vectored(ctx, bufs);
172
173 if let Some(ref limiter) = this.write_limiter {
174 if let Poll::Ready(Ok(count)) = result {
175 this.write_delay = register_traffic(limiter.as_ref(), count);
176 }
177 }
178
179 result
180 }
181
182 fn poll_flush(
183 self: Pin<&mut Self>,
184 ctx: &mut Context<'_>
185 ) -> Poll<Result<(), std::io::Error>> {
186 let this = self.get_mut();
187 Pin::new(&mut this.stream).poll_flush(ctx)
188 }
189
190 fn poll_shutdown(
191 self: Pin<&mut Self>,
192 ctx: &mut Context<'_>
193 ) -> Poll<Result<(), std::io::Error>> {
194 let this = self.get_mut();
195 Pin::new(&mut this.stream).poll_shutdown(ctx)
196 }
197 }
198
199 impl <S: AsyncRead + Unpin> AsyncRead for RateLimitedStream<S> {
200
201 fn poll_read(
202 self: Pin<&mut Self>,
203 ctx: &mut Context<'_>,
204 buf: &mut ReadBuf<'_>,
205 ) -> Poll<Result<(), std::io::Error>> {
206 let this = self.get_mut();
207
208 let is_ready = delay_is_ready(&mut this.read_delay, ctx);
209
210 if !is_ready { return Poll::Pending; }
211
212 this.read_delay = None;
213
214 this.update_limiters();
215
216 let filled_len = buf.filled().len();
217 let result = Pin::new(&mut this.stream).poll_read(ctx, buf);
218
219 if let Some(ref read_limiter) = this.read_limiter {
220 if let Poll::Ready(Ok(())) = &result {
221 let count = buf.filled().len() - filled_len;
222 this.read_delay = register_traffic(read_limiter.as_ref(), count);
223 }
224 }
225
226 result
227 }
228
229 }
230
231 // we need this for the hyper http client
232 impl<S: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RateLimitedStream<S> {
233 fn connected(&self) -> Connected {
234 self.stream.connected()
235 }
236 }