]> git.proxmox.com Git - proxmox.git/blame - proxmox-http/src/rate_limited_stream.rs
http: rate limited stream: fix typo in rustdoc comment
[proxmox.git] / proxmox-http / src / rate_limited_stream.rs
CommitLineData
10a3ab22 1use std::future::Future;
0eeb0dd1 2use std::io::IoSlice;
c94ad247 3use std::marker::Unpin;
0eeb0dd1 4use std::pin::Pin;
c94ad247
DM
5use std::sync::{Arc, Mutex};
6use std::time::{Duration, Instant};
7
0eeb0dd1
TL
8use hyper::client::connect::{Connected, Connection};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
c94ad247
DM
10use tokio::time::Sleep;
11
12use std::task::{Context, Poll};
13
0eeb0dd1 14use super::{RateLimiter, ShareableRateLimit};
8734d0c2
DM
15
16type SharedRateLimit = Arc<dyn ShareableRateLimit>;
c94ad247 17
94d388b9
WB
18pub type RateLimiterCallback =
19 dyn Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send;
20
c94ad247
DM
21/// A rate limited stream using [RateLimiter]
22pub struct RateLimitedStream<S> {
8734d0c2 23 read_limiter: Option<SharedRateLimit>,
c94ad247 24 read_delay: Option<Pin<Box<Sleep>>>,
8734d0c2 25 write_limiter: Option<SharedRateLimit>,
c94ad247 26 write_delay: Option<Pin<Box<Sleep>>>,
94d388b9 27 update_limiter_cb: Option<Box<RateLimiterCallback>>,
e0305f72 28 last_limiter_update: Instant,
c94ad247
DM
29 stream: S,
30}
31
0eeb0dd1 32impl<S> RateLimitedStream<S> {
c94ad247
DM
33 /// Creates a new instance with reads and writes limited to the same `rate`.
34 pub fn new(stream: S, rate: u64, bucket_size: u64) -> Self {
35 let now = Instant::now();
8734d0c2
DM
36 let read_limiter = RateLimiter::with_start_time(rate, bucket_size, now);
37 let read_limiter: SharedRateLimit = Arc::new(Mutex::new(read_limiter));
38 let write_limiter = RateLimiter::with_start_time(rate, bucket_size, now);
39 let write_limiter: SharedRateLimit = Arc::new(Mutex::new(write_limiter));
c94ad247
DM
40 Self::with_limiter(stream, Some(read_limiter), Some(write_limiter))
41 }
42
c609a580 43 /// Creates a new instance with specified [`RateLimiter`s](RateLimiter) for reads and writes.
c94ad247
DM
44 pub fn with_limiter(
45 stream: S,
8734d0c2
DM
46 read_limiter: Option<SharedRateLimit>,
47 write_limiter: Option<SharedRateLimit>,
c94ad247 48 ) -> Self {
0eeb0dd1 49 Self {
c94ad247
DM
50 read_limiter,
51 read_delay: None,
52 write_limiter,
53 write_delay: None,
e0305f72
DM
54 update_limiter_cb: None,
55 last_limiter_update: Instant::now(),
c94ad247
DM
56 stream,
57 }
58 }
e0305f72
DM
59
60 /// Creates a new instance with limiter update callback.
61 ///
3ac6f2d9 62 /// The function is called every minute to update/change the used limiters.
e0305f72
DM
63 ///
64 /// Note: This function is called within an async context, so it
65 /// should be fast and must not block.
0eeb0dd1
TL
66 pub fn with_limiter_update_cb<
67 F: Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send + 'static,
68 >(
e0305f72
DM
69 stream: S,
70 update_limiter_cb: F,
71 ) -> Self {
72 let (read_limiter, write_limiter) = update_limiter_cb();
73 Self {
74 read_limiter,
75 read_delay: None,
76 write_limiter,
77 write_delay: None,
78 update_limiter_cb: Some(Box::new(update_limiter_cb)),
79 last_limiter_update: Instant::now(),
80 stream,
81 }
82 }
83
84 fn update_limiters(&mut self) {
85 if let Some(ref update_limiter_cb) = self.update_limiter_cb {
86 if self.last_limiter_update.elapsed().as_secs() >= 5 {
87 self.last_limiter_update = Instant::now();
88 let (read_limiter, write_limiter) = update_limiter_cb();
89 self.read_limiter = read_limiter;
90 self.write_limiter = write_limiter;
91 }
92 }
93 }
d7ed04f8
WB
94
95 pub fn inner(&self) -> &S {
96 &self.stream
97 }
98
99 pub fn inner_mut(&mut self) -> &mut S {
100 &mut self.stream
101 }
c94ad247
DM
102}
103
0eeb0dd1 104fn register_traffic(limiter: &(dyn ShareableRateLimit), count: usize) -> Option<Pin<Box<Sleep>>> {
ded24b3f
DM
105 const MIN_DELAY: Duration = Duration::from_millis(10);
106
107 let now = Instant::now();
8734d0c2 108 let delay = limiter.register_traffic(now, count as u64);
ded24b3f
DM
109 if delay >= MIN_DELAY {
110 let sleep = tokio::time::sleep(delay);
111 Some(Box::pin(sleep))
112 } else {
113 None
114 }
115}
116
117fn delay_is_ready(delay: &mut Option<Pin<Box<Sleep>>>, ctx: &mut Context<'_>) -> bool {
118 match delay {
0eeb0dd1 119 Some(ref mut future) => future.as_mut().poll(ctx).is_ready(),
ded24b3f
DM
120 None => true,
121 }
122}
123
0eeb0dd1 124impl<S: AsyncWrite + Unpin> AsyncWrite for RateLimitedStream<S> {
c94ad247
DM
125 fn poll_write(
126 self: Pin<&mut Self>,
127 ctx: &mut Context<'_>,
0eeb0dd1 128 buf: &[u8],
c94ad247
DM
129 ) -> Poll<Result<usize, std::io::Error>> {
130 let this = self.get_mut();
131
ded24b3f 132 let is_ready = delay_is_ready(&mut this.write_delay, ctx);
c94ad247 133
0eeb0dd1
TL
134 if !is_ready {
135 return Poll::Pending;
136 }
c94ad247
DM
137
138 this.write_delay = None;
139
e0305f72
DM
140 this.update_limiters();
141
c94ad247
DM
142 let result = Pin::new(&mut this.stream).poll_write(ctx, buf);
143
8734d0c2 144 if let Some(ref mut limiter) = this.write_limiter {
ded24b3f 145 if let Poll::Ready(Ok(count)) = result {
8734d0c2 146 this.write_delay = register_traffic(limiter.as_ref(), count);
ded24b3f
DM
147 }
148 }
149
150 result
151 }
152
153 fn is_write_vectored(&self) -> bool {
154 self.stream.is_write_vectored()
155 }
156
157 fn poll_write_vectored(
158 self: Pin<&mut Self>,
159 ctx: &mut Context<'_>,
0eeb0dd1 160 bufs: &[IoSlice<'_>],
ded24b3f
DM
161 ) -> Poll<Result<usize, std::io::Error>> {
162 let this = self.get_mut();
163
164 let is_ready = delay_is_ready(&mut this.write_delay, ctx);
165
0eeb0dd1
TL
166 if !is_ready {
167 return Poll::Pending;
168 }
ded24b3f
DM
169
170 this.write_delay = None;
171
e0305f72
DM
172 this.update_limiters();
173
ded24b3f
DM
174 let result = Pin::new(&mut this.stream).poll_write_vectored(ctx, bufs);
175
176 if let Some(ref limiter) = this.write_limiter {
177 if let Poll::Ready(Ok(count)) = result {
8734d0c2 178 this.write_delay = register_traffic(limiter.as_ref(), count);
c94ad247
DM
179 }
180 }
181
182 result
183 }
184
0eeb0dd1 185 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
c94ad247
DM
186 let this = self.get_mut();
187 Pin::new(&mut this.stream).poll_flush(ctx)
188 }
189
190 fn poll_shutdown(
191 self: Pin<&mut Self>,
0eeb0dd1 192 ctx: &mut Context<'_>,
c94ad247
DM
193 ) -> Poll<Result<(), std::io::Error>> {
194 let this = self.get_mut();
195 Pin::new(&mut this.stream).poll_shutdown(ctx)
196 }
197}
198
0eeb0dd1 199impl<S: AsyncRead + Unpin> AsyncRead for RateLimitedStream<S> {
c94ad247
DM
200 fn poll_read(
201 self: Pin<&mut Self>,
202 ctx: &mut Context<'_>,
203 buf: &mut ReadBuf<'_>,
204 ) -> Poll<Result<(), std::io::Error>> {
205 let this = self.get_mut();
206
ded24b3f 207 let is_ready = delay_is_ready(&mut this.read_delay, ctx);
c94ad247 208
0eeb0dd1
TL
209 if !is_ready {
210 return Poll::Pending;
211 }
c94ad247
DM
212
213 this.read_delay = None;
214
e0305f72
DM
215 this.update_limiters();
216
c94ad247
DM
217 let filled_len = buf.filled().len();
218 let result = Pin::new(&mut this.stream).poll_read(ctx, buf);
219
220 if let Some(ref read_limiter) = this.read_limiter {
221 if let Poll::Ready(Ok(())) = &result {
222 let count = buf.filled().len() - filled_len;
8734d0c2 223 this.read_delay = register_traffic(read_limiter.as_ref(), count);
c94ad247
DM
224 }
225 }
226
227 result
228 }
c94ad247 229}
00ca0b7f
DM
230
231// we need this for the hyper http client
232impl<S: Connection + AsyncRead + AsyncWrite + Unpin> Connection for RateLimitedStream<S> {
233 fn connected(&self) -> Connected {
234 self.stream.connected()
235 }
236}