use std::future::Future;
use std::io::IoSlice;
use std::marker::Unpin;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use hyper::client::connect::{Connected, Connection};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::time::Sleep;

use super::{RateLimiter, ShareableRateLimit};
16 type SharedRateLimit
= Arc
<dyn ShareableRateLimit
>;
18 pub type RateLimiterCallback
=
19 dyn Fn() -> (Option
<SharedRateLimit
>, Option
<SharedRateLimit
>) + Send
;
21 /// A rate limited stream using [RateLimiter]
22 pub struct RateLimitedStream
<S
> {
23 read_limiter
: Option
<SharedRateLimit
>,
24 read_delay
: Option
<Pin
<Box
<Sleep
>>>,
25 write_limiter
: Option
<SharedRateLimit
>,
26 write_delay
: Option
<Pin
<Box
<Sleep
>>>,
27 update_limiter_cb
: Option
<Box
<RateLimiterCallback
>>,
28 last_limiter_update
: Instant
,
32 impl RateLimitedStream
<tokio
::net
::TcpStream
> {
33 pub fn peer_addr(&self) -> std
::io
::Result
<std
::net
::SocketAddr
> {
34 self.stream
.peer_addr()
38 impl<S
> RateLimitedStream
<S
> {
39 /// Creates a new instance with reads and writes limited to the same `rate`.
40 pub fn new(stream
: S
, rate
: u64, bucket_size
: u64) -> Self {
41 let now
= Instant
::now();
42 let read_limiter
= RateLimiter
::with_start_time(rate
, bucket_size
, now
);
43 let read_limiter
: SharedRateLimit
= Arc
::new(Mutex
::new(read_limiter
));
44 let write_limiter
= RateLimiter
::with_start_time(rate
, bucket_size
, now
);
45 let write_limiter
: SharedRateLimit
= Arc
::new(Mutex
::new(write_limiter
));
46 Self::with_limiter(stream
, Some(read_limiter
), Some(write_limiter
))
49 /// Creates a new instance with specified [`RateLimiter`s](RateLimiter) for reads and writes.
52 read_limiter
: Option
<SharedRateLimit
>,
53 write_limiter
: Option
<SharedRateLimit
>,
60 update_limiter_cb
: None
,
61 last_limiter_update
: Instant
::now(),
66 /// Creates a new instance with limiter update callback.
68 /// The fuction is called every minute to update/change the used limiters.
70 /// Note: This function is called within an async context, so it
71 /// should be fast and must not block.
72 pub fn with_limiter_update_cb
<
73 F
: Fn() -> (Option
<SharedRateLimit
>, Option
<SharedRateLimit
>) + Send
+ '
static,
78 let (read_limiter
, write_limiter
) = update_limiter_cb();
84 update_limiter_cb
: Some(Box
::new(update_limiter_cb
)),
85 last_limiter_update
: Instant
::now(),
90 fn update_limiters(&mut self) {
91 if let Some(ref update_limiter_cb
) = self.update_limiter_cb
{
92 if self.last_limiter_update
.elapsed().as_secs() >= 5 {
93 self.last_limiter_update
= Instant
::now();
94 let (read_limiter
, write_limiter
) = update_limiter_cb();
95 self.read_limiter
= read_limiter
;
96 self.write_limiter
= write_limiter
;
102 fn register_traffic(limiter
: &(dyn ShareableRateLimit
), count
: usize) -> Option
<Pin
<Box
<Sleep
>>> {
103 const MIN_DELAY
: Duration
= Duration
::from_millis(10);
105 let now
= Instant
::now();
106 let delay
= limiter
.register_traffic(now
, count
as u64);
107 if delay
>= MIN_DELAY
{
108 let sleep
= tokio
::time
::sleep(delay
);
109 Some(Box
::pin(sleep
))
115 fn delay_is_ready(delay
: &mut Option
<Pin
<Box
<Sleep
>>>, ctx
: &mut Context
<'_
>) -> bool
{
117 Some(ref mut future
) => future
.as_mut().poll(ctx
).is_ready(),
122 impl<S
: AsyncWrite
+ Unpin
> AsyncWrite
for RateLimitedStream
<S
> {
124 self: Pin
<&mut Self>,
125 ctx
: &mut Context
<'_
>,
127 ) -> Poll
<Result
<usize, std
::io
::Error
>> {
128 let this
= self.get_mut();
130 let is_ready
= delay_is_ready(&mut this
.write_delay
, ctx
);
133 return Poll
::Pending
;
136 this
.write_delay
= None
;
138 this
.update_limiters();
140 let result
= Pin
::new(&mut this
.stream
).poll_write(ctx
, buf
);
142 if let Some(ref mut limiter
) = this
.write_limiter
{
143 if let Poll
::Ready(Ok(count
)) = result
{
144 this
.write_delay
= register_traffic(limiter
.as_ref(), count
);
151 fn is_write_vectored(&self) -> bool
{
152 self.stream
.is_write_vectored()
155 fn poll_write_vectored(
156 self: Pin
<&mut Self>,
157 ctx
: &mut Context
<'_
>,
158 bufs
: &[IoSlice
<'_
>],
159 ) -> Poll
<Result
<usize, std
::io
::Error
>> {
160 let this
= self.get_mut();
162 let is_ready
= delay_is_ready(&mut this
.write_delay
, ctx
);
165 return Poll
::Pending
;
168 this
.write_delay
= None
;
170 this
.update_limiters();
172 let result
= Pin
::new(&mut this
.stream
).poll_write_vectored(ctx
, bufs
);
174 if let Some(ref limiter
) = this
.write_limiter
{
175 if let Poll
::Ready(Ok(count
)) = result
{
176 this
.write_delay
= register_traffic(limiter
.as_ref(), count
);
183 fn poll_flush(self: Pin
<&mut Self>, ctx
: &mut Context
<'_
>) -> Poll
<Result
<(), std
::io
::Error
>> {
184 let this
= self.get_mut();
185 Pin
::new(&mut this
.stream
).poll_flush(ctx
)
189 self: Pin
<&mut Self>,
190 ctx
: &mut Context
<'_
>,
191 ) -> Poll
<Result
<(), std
::io
::Error
>> {
192 let this
= self.get_mut();
193 Pin
::new(&mut this
.stream
).poll_shutdown(ctx
)
197 impl<S
: AsyncRead
+ Unpin
> AsyncRead
for RateLimitedStream
<S
> {
199 self: Pin
<&mut Self>,
200 ctx
: &mut Context
<'_
>,
201 buf
: &mut ReadBuf
<'_
>,
202 ) -> Poll
<Result
<(), std
::io
::Error
>> {
203 let this
= self.get_mut();
205 let is_ready
= delay_is_ready(&mut this
.read_delay
, ctx
);
208 return Poll
::Pending
;
211 this
.read_delay
= None
;
213 this
.update_limiters();
215 let filled_len
= buf
.filled().len();
216 let result
= Pin
::new(&mut this
.stream
).poll_read(ctx
, buf
);
218 if let Some(ref read_limiter
) = this
.read_limiter
{
219 if let Poll
::Ready(Ok(())) = &result
{
220 let count
= buf
.filled().len() - filled_len
;
221 this
.read_delay
= register_traffic(read_limiter
.as_ref(), count
);
229 // we need this for the hyper http client
230 impl<S
: Connection
+ AsyncRead
+ AsyncWrite
+ Unpin
> Connection
for RateLimitedStream
<S
> {
231 fn connected(&self) -> Connected
{
232 self.stream
.connected()