2 use std
::marker
::Unpin
;
3 use std
::sync
::{Arc, Mutex}
;
4 use std
::time
::{Duration, Instant}
;
8 use tokio
::io
::{ReadBuf, AsyncRead, AsyncWrite}
;
9 use tokio
::time
::Sleep
;
10 use hyper
::client
::connect
::{Connection, Connected}
;
12 use std
::task
::{Context, Poll}
;
14 use super::{ShareableRateLimit, RateLimiter}
;
16 type SharedRateLimit = Arc<dyn ShareableRateLimit>; // type-erased limiter behind an `Arc`, so one limiter instance can be shared (e.g. handed out again by the limiter update callback)
18 /// A rate limited stream using [RateLimiter]
19 pub struct RateLimitedStream
<S
> {
20 read_limiter
: Option
<SharedRateLimit
>,
21 read_delay
: Option
<Pin
<Box
<Sleep
>>>,
22 write_limiter
: Option
<SharedRateLimit
>,
23 write_delay
: Option
<Pin
<Box
<Sleep
>>>,
24 update_limiter_cb
: Option
<Box
<dyn Fn() -> (Option
<SharedRateLimit
>, Option
<SharedRateLimit
>) + Send
>>,
25 last_limiter_update
: Instant
,
29 impl RateLimitedStream
<tokio
::net
::TcpStream
> {
30 pub fn peer_addr(&self) -> std
::io
::Result
<std
::net
::SocketAddr
> {
31 self.stream
.peer_addr()
35 impl <S
> RateLimitedStream
<S
> {
37 /// Creates a new instance with reads and writes limited to the same `rate`.
38 pub fn new(stream
: S
, rate
: u64, bucket_size
: u64) -> Self {
39 let now
= Instant
::now();
40 let read_limiter
= RateLimiter
::with_start_time(rate
, bucket_size
, now
);
41 let read_limiter
: SharedRateLimit
= Arc
::new(Mutex
::new(read_limiter
));
42 let write_limiter
= RateLimiter
::with_start_time(rate
, bucket_size
, now
);
43 let write_limiter
: SharedRateLimit
= Arc
::new(Mutex
::new(write_limiter
));
44 Self::with_limiter(stream
, Some(read_limiter
), Some(write_limiter
))
47 /// Creates a new instance with the specified (optional) rate limiters for reads and writes.
50 read_limiter
: Option
<SharedRateLimit
>,
51 write_limiter
: Option
<SharedRateLimit
>,
58 update_limiter_cb
: None
,
59 last_limiter_update
: Instant
::now(),
64 /// Creates a new instance with limiter update callback.
66 /// The function is called once on construction and then at most every 5 seconds (checked when reading or writing) to update/change the used limiters.
68 /// Note: This function is called within an async context, so it
69 /// should be fast and must not block.
70 pub fn with_limiter_update_cb
<F
: Fn() -> (Option
<SharedRateLimit
>, Option
<SharedRateLimit
>) + Send
+ '
static>(
74 let (read_limiter
, write_limiter
) = update_limiter_cb();
80 update_limiter_cb
: Some(Box
::new(update_limiter_cb
)),
81 last_limiter_update
: Instant
::now(),
86 fn update_limiters(&mut self) {
87 if let Some(ref update_limiter_cb
) = self.update_limiter_cb
{
88 if self.last_limiter_update
.elapsed().as_secs() >= 5 {
89 self.last_limiter_update
= Instant
::now();
90 let (read_limiter
, write_limiter
) = update_limiter_cb();
91 self.read_limiter
= read_limiter
;
92 self.write_limiter
= write_limiter
;
99 limiter
: &(dyn ShareableRateLimit
),
101 ) -> Option
<Pin
<Box
<Sleep
>>>{
103 const MIN_DELAY
: Duration
= Duration
::from_millis(10);
105 let now
= Instant
::now();
106 let delay
= limiter
.register_traffic(now
, count
as u64);
107 if delay
>= MIN_DELAY
{
108 let sleep
= tokio
::time
::sleep(delay
);
109 Some(Box
::pin(sleep
))
115 fn delay_is_ready(delay
: &mut Option
<Pin
<Box
<Sleep
>>>, ctx
: &mut Context
<'_
>) -> bool
{
117 Some(ref mut future
) => {
118 future
.as_mut().poll(ctx
).is_ready()
124 impl <S
: AsyncWrite
+ Unpin
> AsyncWrite
for RateLimitedStream
<S
> {
127 self: Pin
<&mut Self>,
128 ctx
: &mut Context
<'_
>,
130 ) -> Poll
<Result
<usize, std
::io
::Error
>> {
131 let this
= self.get_mut();
133 let is_ready
= delay_is_ready(&mut this
.write_delay
, ctx
);
135 if !is_ready { return Poll::Pending; }
137 this
.write_delay
= None
;
139 this
.update_limiters();
141 let result
= Pin
::new(&mut this
.stream
).poll_write(ctx
, buf
);
143 if let Some(ref mut limiter
) = this
.write_limiter
{
144 if let Poll
::Ready(Ok(count
)) = result
{
145 this
.write_delay
= register_traffic(limiter
.as_ref(), count
);
152 fn is_write_vectored(&self) -> bool
{
153 self.stream
.is_write_vectored()
156 fn poll_write_vectored(
157 self: Pin
<&mut Self>,
158 ctx
: &mut Context
<'_
>,
160 ) -> Poll
<Result
<usize, std
::io
::Error
>> {
161 let this
= self.get_mut();
163 let is_ready
= delay_is_ready(&mut this
.write_delay
, ctx
);
165 if !is_ready { return Poll::Pending; }
167 this
.write_delay
= None
;
169 this
.update_limiters();
171 let result
= Pin
::new(&mut this
.stream
).poll_write_vectored(ctx
, bufs
);
173 if let Some(ref limiter
) = this
.write_limiter
{
174 if let Poll
::Ready(Ok(count
)) = result
{
175 this
.write_delay
= register_traffic(limiter
.as_ref(), count
);
183 self: Pin
<&mut Self>,
184 ctx
: &mut Context
<'_
>
185 ) -> Poll
<Result
<(), std
::io
::Error
>> {
186 let this
= self.get_mut();
187 Pin
::new(&mut this
.stream
).poll_flush(ctx
)
191 self: Pin
<&mut Self>,
192 ctx
: &mut Context
<'_
>
193 ) -> Poll
<Result
<(), std
::io
::Error
>> {
194 let this
= self.get_mut();
195 Pin
::new(&mut this
.stream
).poll_shutdown(ctx
)
199 impl <S
: AsyncRead
+ Unpin
> AsyncRead
for RateLimitedStream
<S
> {
202 self: Pin
<&mut Self>,
203 ctx
: &mut Context
<'_
>,
204 buf
: &mut ReadBuf
<'_
>,
205 ) -> Poll
<Result
<(), std
::io
::Error
>> {
206 let this
= self.get_mut();
208 let is_ready
= delay_is_ready(&mut this
.read_delay
, ctx
);
210 if !is_ready { return Poll::Pending; }
212 this
.read_delay
= None
;
214 this
.update_limiters();
216 let filled_len
= buf
.filled().len();
217 let result
= Pin
::new(&mut this
.stream
).poll_read(ctx
, buf
);
219 if let Some(ref read_limiter
) = this
.read_limiter
{
220 if let Poll
::Ready(Ok(())) = &result
{
221 let count
= buf
.filled().len() - filled_len
;
222 this
.read_delay
= register_traffic(read_limiter
.as_ref(), count
);
231 // we need this for the hyper http client
232 impl<S
: Connection
+ AsyncRead
+ AsyncWrite
+ Unpin
> Connection
for RateLimitedStream
<S
> {
233 fn connected(&self) -> Connected
{
234 self.stream
.connected()