]> git.proxmox.com Git - proxmox.git/blob - proxmox-http/src/client/rate_limited_stream.rs
http: clippy fixups
[proxmox.git] / proxmox-http / src / client / rate_limited_stream.rs
1 use std::io::IoSlice;
2 use std::marker::Unpin;
3 use std::pin::Pin;
4 use std::sync::{Arc, Mutex};
5 use std::time::{Duration, Instant};
6
7 use futures::Future;
8 use hyper::client::connect::{Connected, Connection};
9 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10 use tokio::time::Sleep;
11
12 use std::task::{Context, Poll};
13
14 use super::{RateLimiter, ShareableRateLimit};
15
16 type SharedRateLimit = Arc<dyn ShareableRateLimit>;
17
18 pub type RateLimiterCallback =
19 dyn Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send;
20
21 /// A rate limited stream using [RateLimiter]
22 pub struct RateLimitedStream<S> {
23 read_limiter: Option<SharedRateLimit>,
24 read_delay: Option<Pin<Box<Sleep>>>,
25 write_limiter: Option<SharedRateLimit>,
26 write_delay: Option<Pin<Box<Sleep>>>,
27 update_limiter_cb: Option<Box<RateLimiterCallback>>,
28 last_limiter_update: Instant,
29 stream: S,
30 }
31
32 impl RateLimitedStream<tokio::net::TcpStream> {
33 pub fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
34 self.stream.peer_addr()
35 }
36 }
37
38 impl<S> RateLimitedStream<S> {
39 /// Creates a new instance with reads and writes limited to the same `rate`.
40 pub fn new(stream: S, rate: u64, bucket_size: u64) -> Self {
41 let now = Instant::now();
42 let read_limiter = RateLimiter::with_start_time(rate, bucket_size, now);
43 let read_limiter: SharedRateLimit = Arc::new(Mutex::new(read_limiter));
44 let write_limiter = RateLimiter::with_start_time(rate, bucket_size, now);
45 let write_limiter: SharedRateLimit = Arc::new(Mutex::new(write_limiter));
46 Self::with_limiter(stream, Some(read_limiter), Some(write_limiter))
47 }
48
49 /// Creates a new instance with specified [`RateLimiter`s](RateLimiter) for reads and writes.
50 pub fn with_limiter(
51 stream: S,
52 read_limiter: Option<SharedRateLimit>,
53 write_limiter: Option<SharedRateLimit>,
54 ) -> Self {
55 Self {
56 read_limiter,
57 read_delay: None,
58 write_limiter,
59 write_delay: None,
60 update_limiter_cb: None,
61 last_limiter_update: Instant::now(),
62 stream,
63 }
64 }
65
66 /// Creates a new instance with limiter update callback.
67 ///
68 /// The fuction is called every minute to update/change the used limiters.
69 ///
70 /// Note: This function is called within an async context, so it
71 /// should be fast and must not block.
72 pub fn with_limiter_update_cb<
73 F: Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send + 'static,
74 >(
75 stream: S,
76 update_limiter_cb: F,
77 ) -> Self {
78 let (read_limiter, write_limiter) = update_limiter_cb();
79 Self {
80 read_limiter,
81 read_delay: None,
82 write_limiter,
83 write_delay: None,
84 update_limiter_cb: Some(Box::new(update_limiter_cb)),
85 last_limiter_update: Instant::now(),
86 stream,
87 }
88 }
89
90 fn update_limiters(&mut self) {
91 if let Some(ref update_limiter_cb) = self.update_limiter_cb {
92 if self.last_limiter_update.elapsed().as_secs() >= 5 {
93 self.last_limiter_update = Instant::now();
94 let (read_limiter, write_limiter) = update_limiter_cb();
95 self.read_limiter = read_limiter;
96 self.write_limiter = write_limiter;
97 }
98 }
99 }
100 }
101
102 fn register_traffic(limiter: &(dyn ShareableRateLimit), count: usize) -> Option<Pin<Box<Sleep>>> {
103 const MIN_DELAY: Duration = Duration::from_millis(10);
104
105 let now = Instant::now();
106 let delay = limiter.register_traffic(now, count as u64);
107 if delay >= MIN_DELAY {
108 let sleep = tokio::time::sleep(delay);
109 Some(Box::pin(sleep))
110 } else {
111 None
112 }
113 }
114
115 fn delay_is_ready(delay: &mut Option<Pin<Box<Sleep>>>, ctx: &mut Context<'_>) -> bool {
116 match delay {
117 Some(ref mut future) => future.as_mut().poll(ctx).is_ready(),
118 None => true,
119 }
120 }
121
122 impl<S: AsyncWrite + Unpin> AsyncWrite for RateLimitedStream<S> {
123 fn poll_write(
124 self: Pin<&mut Self>,
125 ctx: &mut Context<'_>,
126 buf: &[u8],
127 ) -> Poll<Result<usize, std::io::Error>> {
128 let this = self.get_mut();
129
130 let is_ready = delay_is_ready(&mut this.write_delay, ctx);
131
132 if !is_ready {
133 return Poll::Pending;
134 }
135
136 this.write_delay = None;
137
138 this.update_limiters();
139
140 let result = Pin::new(&mut this.stream).poll_write(ctx, buf);
141
142 if let Some(ref mut limiter) = this.write_limiter {
143 if let Poll::Ready(Ok(count)) = result {
144 this.write_delay = register_traffic(limiter.as_ref(), count);
145 }
146 }
147
148 result
149 }
150
151 fn is_write_vectored(&self) -> bool {
152 self.stream.is_write_vectored()
153 }
154
155 fn poll_write_vectored(
156 self: Pin<&mut Self>,
157 ctx: &mut Context<'_>,
158 bufs: &[IoSlice<'_>],
159 ) -> Poll<Result<usize, std::io::Error>> {
160 let this = self.get_mut();
161
162 let is_ready = delay_is_ready(&mut this.write_delay, ctx);
163
164 if !is_ready {
165 return Poll::Pending;
166 }
167
168 this.write_delay = None;
169
170 this.update_limiters();
171
172 let result = Pin::new(&mut this.stream).poll_write_vectored(ctx, bufs);
173
174 if let Some(ref limiter) = this.write_limiter {
175 if let Poll::Ready(Ok(count)) = result {
176 this.write_delay = register_traffic(limiter.as_ref(), count);
177 }
178 }
179
180 result
181 }
182
183 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
184 let this = self.get_mut();
185 Pin::new(&mut this.stream).poll_flush(ctx)
186 }
187
188 fn poll_shutdown(
189 self: Pin<&mut Self>,
190 ctx: &mut Context<'_>,
191 ) -> Poll<Result<(), std::io::Error>> {
192 let this = self.get_mut();
193 Pin::new(&mut this.stream).poll_shutdown(ctx)
194 }
195 }
196
197 impl<S: AsyncRead + Unpin> AsyncRead for RateLimitedStream<S> {
198 fn poll_read(
199 self: Pin<&mut Self>,
200 ctx: &mut Context<'_>,
201 buf: &mut ReadBuf<'_>,
202 ) -> Poll<Result<(), std::io::Error>> {
203 let this = self.get_mut();
204
205 let is_ready = delay_is_ready(&mut this.read_delay, ctx);
206
207 if !is_ready {
208 return Poll::Pending;
209 }
210
211 this.read_delay = None;
212
213 this.update_limiters();
214
215 let filled_len = buf.filled().len();
216 let result = Pin::new(&mut this.stream).poll_read(ctx, buf);
217
218 if let Some(ref read_limiter) = this.read_limiter {
219 if let Poll::Ready(Ok(())) = &result {
220 let count = buf.filled().len() - filled_len;
221 this.read_delay = register_traffic(read_limiter.as_ref(), count);
222 }
223 }
224
225 result
226 }
227 }
228
229 // we need this for the hyper http client
230 impl<S: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RateLimitedStream<S> {
231 fn connected(&self) -> Connected {
232 self.stream.connected()
233 }
234 }