]> git.proxmox.com Git - proxmox.git/blame - proxmox-http/src/client/rate_limited_stream.rs
http: clippy fixups
[proxmox.git] / proxmox-http / src / client / rate_limited_stream.rs
CommitLineData
0eeb0dd1 1use std::io::IoSlice;
c94ad247 2use std::marker::Unpin;
0eeb0dd1 3use std::pin::Pin;
c94ad247
DM
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6
7use futures::Future;
0eeb0dd1
TL
8use hyper::client::connect::{Connected, Connection};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
c94ad247
DM
10use tokio::time::Sleep;
11
12use std::task::{Context, Poll};
13
0eeb0dd1 14use super::{RateLimiter, ShareableRateLimit};
8734d0c2
DM
15
16type SharedRateLimit = Arc<dyn ShareableRateLimit>;
c94ad247 17
94d388b9
WB
18pub type RateLimiterCallback =
19 dyn Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send;
20
c94ad247
DM
21/// A rate limited stream using [RateLimiter]
22pub struct RateLimitedStream<S> {
8734d0c2 23 read_limiter: Option<SharedRateLimit>,
c94ad247 24 read_delay: Option<Pin<Box<Sleep>>>,
8734d0c2 25 write_limiter: Option<SharedRateLimit>,
c94ad247 26 write_delay: Option<Pin<Box<Sleep>>>,
94d388b9 27 update_limiter_cb: Option<Box<RateLimiterCallback>>,
e0305f72 28 last_limiter_update: Instant,
c94ad247
DM
29 stream: S,
30}
31
0c27d5da
DM
32impl RateLimitedStream<tokio::net::TcpStream> {
33 pub fn peer_addr(&self) -> std::io::Result<std::net::SocketAddr> {
34 self.stream.peer_addr()
35 }
36}
37
0eeb0dd1 38impl<S> RateLimitedStream<S> {
c94ad247
DM
39 /// Creates a new instance with reads and writes limited to the same `rate`.
40 pub fn new(stream: S, rate: u64, bucket_size: u64) -> Self {
41 let now = Instant::now();
8734d0c2
DM
42 let read_limiter = RateLimiter::with_start_time(rate, bucket_size, now);
43 let read_limiter: SharedRateLimit = Arc::new(Mutex::new(read_limiter));
44 let write_limiter = RateLimiter::with_start_time(rate, bucket_size, now);
45 let write_limiter: SharedRateLimit = Arc::new(Mutex::new(write_limiter));
c94ad247
DM
46 Self::with_limiter(stream, Some(read_limiter), Some(write_limiter))
47 }
48
c609a580 49 /// Creates a new instance with specified [`RateLimiter`s](RateLimiter) for reads and writes.
c94ad247
DM
50 pub fn with_limiter(
51 stream: S,
8734d0c2
DM
52 read_limiter: Option<SharedRateLimit>,
53 write_limiter: Option<SharedRateLimit>,
c94ad247 54 ) -> Self {
0eeb0dd1 55 Self {
c94ad247
DM
56 read_limiter,
57 read_delay: None,
58 write_limiter,
59 write_delay: None,
e0305f72
DM
60 update_limiter_cb: None,
61 last_limiter_update: Instant::now(),
c94ad247
DM
62 stream,
63 }
64 }
e0305f72
DM
65
66 /// Creates a new instance with limiter update callback.
67 ///
68 /// The fuction is called every minute to update/change the used limiters.
69 ///
70 /// Note: This function is called within an async context, so it
71 /// should be fast and must not block.
0eeb0dd1
TL
72 pub fn with_limiter_update_cb<
73 F: Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send + 'static,
74 >(
e0305f72
DM
75 stream: S,
76 update_limiter_cb: F,
77 ) -> Self {
78 let (read_limiter, write_limiter) = update_limiter_cb();
79 Self {
80 read_limiter,
81 read_delay: None,
82 write_limiter,
83 write_delay: None,
84 update_limiter_cb: Some(Box::new(update_limiter_cb)),
85 last_limiter_update: Instant::now(),
86 stream,
87 }
88 }
89
90 fn update_limiters(&mut self) {
91 if let Some(ref update_limiter_cb) = self.update_limiter_cb {
92 if self.last_limiter_update.elapsed().as_secs() >= 5 {
93 self.last_limiter_update = Instant::now();
94 let (read_limiter, write_limiter) = update_limiter_cb();
95 self.read_limiter = read_limiter;
96 self.write_limiter = write_limiter;
97 }
98 }
99 }
c94ad247
DM
100}
101
0eeb0dd1 102fn register_traffic(limiter: &(dyn ShareableRateLimit), count: usize) -> Option<Pin<Box<Sleep>>> {
ded24b3f
DM
103 const MIN_DELAY: Duration = Duration::from_millis(10);
104
105 let now = Instant::now();
8734d0c2 106 let delay = limiter.register_traffic(now, count as u64);
ded24b3f
DM
107 if delay >= MIN_DELAY {
108 let sleep = tokio::time::sleep(delay);
109 Some(Box::pin(sleep))
110 } else {
111 None
112 }
113}
114
115fn delay_is_ready(delay: &mut Option<Pin<Box<Sleep>>>, ctx: &mut Context<'_>) -> bool {
116 match delay {
0eeb0dd1 117 Some(ref mut future) => future.as_mut().poll(ctx).is_ready(),
ded24b3f
DM
118 None => true,
119 }
120}
121
0eeb0dd1 122impl<S: AsyncWrite + Unpin> AsyncWrite for RateLimitedStream<S> {
c94ad247
DM
123 fn poll_write(
124 self: Pin<&mut Self>,
125 ctx: &mut Context<'_>,
0eeb0dd1 126 buf: &[u8],
c94ad247
DM
127 ) -> Poll<Result<usize, std::io::Error>> {
128 let this = self.get_mut();
129
ded24b3f 130 let is_ready = delay_is_ready(&mut this.write_delay, ctx);
c94ad247 131
0eeb0dd1
TL
132 if !is_ready {
133 return Poll::Pending;
134 }
c94ad247
DM
135
136 this.write_delay = None;
137
e0305f72
DM
138 this.update_limiters();
139
c94ad247
DM
140 let result = Pin::new(&mut this.stream).poll_write(ctx, buf);
141
8734d0c2 142 if let Some(ref mut limiter) = this.write_limiter {
ded24b3f 143 if let Poll::Ready(Ok(count)) = result {
8734d0c2 144 this.write_delay = register_traffic(limiter.as_ref(), count);
ded24b3f
DM
145 }
146 }
147
148 result
149 }
150
151 fn is_write_vectored(&self) -> bool {
152 self.stream.is_write_vectored()
153 }
154
155 fn poll_write_vectored(
156 self: Pin<&mut Self>,
157 ctx: &mut Context<'_>,
0eeb0dd1 158 bufs: &[IoSlice<'_>],
ded24b3f
DM
159 ) -> Poll<Result<usize, std::io::Error>> {
160 let this = self.get_mut();
161
162 let is_ready = delay_is_ready(&mut this.write_delay, ctx);
163
0eeb0dd1
TL
164 if !is_ready {
165 return Poll::Pending;
166 }
ded24b3f
DM
167
168 this.write_delay = None;
169
e0305f72
DM
170 this.update_limiters();
171
ded24b3f
DM
172 let result = Pin::new(&mut this.stream).poll_write_vectored(ctx, bufs);
173
174 if let Some(ref limiter) = this.write_limiter {
175 if let Poll::Ready(Ok(count)) = result {
8734d0c2 176 this.write_delay = register_traffic(limiter.as_ref(), count);
c94ad247
DM
177 }
178 }
179
180 result
181 }
182
0eeb0dd1 183 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
c94ad247
DM
184 let this = self.get_mut();
185 Pin::new(&mut this.stream).poll_flush(ctx)
186 }
187
188 fn poll_shutdown(
189 self: Pin<&mut Self>,
0eeb0dd1 190 ctx: &mut Context<'_>,
c94ad247
DM
191 ) -> Poll<Result<(), std::io::Error>> {
192 let this = self.get_mut();
193 Pin::new(&mut this.stream).poll_shutdown(ctx)
194 }
195}
196
0eeb0dd1 197impl<S: AsyncRead + Unpin> AsyncRead for RateLimitedStream<S> {
c94ad247
DM
198 fn poll_read(
199 self: Pin<&mut Self>,
200 ctx: &mut Context<'_>,
201 buf: &mut ReadBuf<'_>,
202 ) -> Poll<Result<(), std::io::Error>> {
203 let this = self.get_mut();
204
ded24b3f 205 let is_ready = delay_is_ready(&mut this.read_delay, ctx);
c94ad247 206
0eeb0dd1
TL
207 if !is_ready {
208 return Poll::Pending;
209 }
c94ad247
DM
210
211 this.read_delay = None;
212
e0305f72
DM
213 this.update_limiters();
214
c94ad247
DM
215 let filled_len = buf.filled().len();
216 let result = Pin::new(&mut this.stream).poll_read(ctx, buf);
217
218 if let Some(ref read_limiter) = this.read_limiter {
219 if let Poll::Ready(Ok(())) = &result {
220 let count = buf.filled().len() - filled_len;
8734d0c2 221 this.read_delay = register_traffic(read_limiter.as_ref(), count);
c94ad247
DM
222 }
223 }
224
225 result
226 }
c94ad247 227}
00ca0b7f
DM
228
229// we need this for the hyper http client
230impl<S: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RateLimitedStream<S> {
231 fn connected(&self) -> Connected {
232 self.stream.connected()
233 }
234}