]> git.proxmox.com Git - proxmox.git/commitdiff
proxmox-http: use SharedRateLimit trait object for RateLimitedStream
authorDietmar Maurer <dietmar@proxmox.com>
Sat, 13 Nov 2021 14:16:18 +0000 (15:16 +0100)
committerDietmar Maurer <dietmar@proxmox.com>
Sat, 13 Nov 2021 16:38:10 +0000 (17:38 +0100)
Signed-off-by: Dietmar Maurer <dietmar@proxmox.com>
proxmox-http/src/client/connector.rs
proxmox-http/src/client/mod.rs
proxmox-http/src/client/rate_limited_stream.rs
proxmox-http/src/client/rate_limiter.rs

index 71704d56b514bf70125a178cf589a86cdc7e2db5..1bcee7a4448894d3b0e4f6662765f8a516ffe8da 100644 (file)
@@ -1,7 +1,7 @@
 use anyhow::{bail, format_err, Error};
 use std::os::unix::io::AsRawFd;
 use std::pin::Pin;
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
 use std::task::{Context, Poll};
 
 use futures::*;
@@ -18,7 +18,9 @@ use crate::proxy_config::ProxyConfig;
 use crate::tls::MaybeTlsStream;
 use crate::uri::build_authority;
 
-use super::{RateLimiter, RateLimitedStream};
+use super::{RateLimitedStream, ShareableRateLimit};
+
+type SharedRateLimit = Arc<dyn ShareableRateLimit>;
 
 #[derive(Clone)]
 pub struct HttpsConnector {
@@ -26,8 +28,8 @@ pub struct HttpsConnector {
     ssl_connector: Arc<SslConnector>,
     proxy: Option<ProxyConfig>,
     tcp_keepalive: u32,
-    read_limiter: Option<Arc<Mutex<RateLimiter>>>,
-    write_limiter: Option<Arc<Mutex<RateLimiter>>>,
+    read_limiter: Option<SharedRateLimit>,
+    write_limiter: Option<SharedRateLimit>,
 }
 
 impl HttpsConnector {
@@ -51,11 +53,11 @@ impl HttpsConnector {
         self.proxy = Some(proxy);
     }
 
-    pub fn set_read_limiter(&mut self, limiter: Option<Arc<Mutex<RateLimiter>>>) {
+    pub fn set_read_limiter(&mut self, limiter: Option<SharedRateLimit>) {
         self.read_limiter = limiter;
     }
 
-    pub fn set_write_limiter(&mut self, limiter: Option<Arc<Mutex<RateLimiter>>>) {
+    pub fn set_write_limiter(&mut self, limiter: Option<SharedRateLimit>) {
         self.write_limiter = limiter;
     }
 
index 5ef81000ba3ed8cb2299233e40f419ea69a45bcb..fa57408e3647d57fb0157c5369391c94f5acb904 100644 (file)
@@ -3,7 +3,7 @@
 //! Contains a lightweight wrapper around `hyper` with support for TLS connections.
 
 mod rate_limiter;
-pub use rate_limiter::{RateLimit, RateLimiter};
+pub use rate_limiter::{RateLimit, RateLimiter, ShareableRateLimit};
 
 mod rate_limited_stream;
 pub use rate_limited_stream::RateLimitedStream;
index c288849a4434753fd78a585e5cbbf85eef8717ef..3a0eabd8462a4c26a16d5a8db869ea573170388c 100644 (file)
@@ -11,15 +11,17 @@ use hyper::client::connect::{Connection, Connected};
 
 use std::task::{Context, Poll};
 
-use super::{RateLimit, RateLimiter};
+use super::{ShareableRateLimit, RateLimiter};
+
+type SharedRateLimit = Arc<dyn ShareableRateLimit>;
 
 /// A rate limited stream using [RateLimiter]
 pub struct RateLimitedStream<S> {
-    read_limiter: Option<Arc<Mutex<RateLimiter>>>,
+    read_limiter: Option<SharedRateLimit>,
     read_delay: Option<Pin<Box<Sleep>>>,
-    write_limiter: Option<Arc<Mutex<RateLimiter>>>,
+    write_limiter: Option<SharedRateLimit>,
     write_delay: Option<Pin<Box<Sleep>>>,
-    update_limiter_cb: Option<Box<dyn Fn() -> (Option<Arc<Mutex<RateLimiter>>>, Option<Arc<Mutex<RateLimiter>>>) + Send>>,
+    update_limiter_cb: Option<Box<dyn Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send>>,
     last_limiter_update: Instant,
     stream: S,
 }
@@ -35,18 +37,20 @@ impl <S> RateLimitedStream<S> {
     /// Creates a new instance with reads and writes limited to the same `rate`.
     pub fn new(stream: S, rate: u64, bucket_size: u64) -> Self {
         let now = Instant::now();
-        let read_limiter = Arc::new(Mutex::new(RateLimiter::with_start_time(rate, bucket_size, now)));
-        let write_limiter = Arc::new(Mutex::new(RateLimiter::with_start_time(rate, bucket_size, now)));
+        let read_limiter = RateLimiter::with_start_time(rate, bucket_size, now);
+        let read_limiter: SharedRateLimit = Arc::new(Mutex::new(read_limiter));
+        let write_limiter = RateLimiter::with_start_time(rate, bucket_size, now);
+        let write_limiter: SharedRateLimit = Arc::new(Mutex::new(write_limiter));
         Self::with_limiter(stream, Some(read_limiter), Some(write_limiter))
     }
 
     /// Creates a new instance with specified [RateLimiters] for reads and writes.
     pub fn with_limiter(
         stream: S,
-        read_limiter: Option<Arc<Mutex<RateLimiter>>>,
-        write_limiter: Option<Arc<Mutex<RateLimiter>>>,
+        read_limiter: Option<SharedRateLimit>,
+        write_limiter: Option<SharedRateLimit>,
     ) -> Self {
-        Self {
+       Self {
             read_limiter,
             read_delay: None,
             write_limiter,
@@ -63,7 +67,7 @@ impl <S> RateLimitedStream<S> {
     ///
     /// Note: This function is called within an async context, so it
     /// should be fast and must not block.
-    pub fn with_limiter_update_cb<F: Fn() -> (Option<Arc<Mutex<RateLimiter>>>, Option<Arc<Mutex<RateLimiter>>>) + Send + 'static>(
+    pub fn with_limiter_update_cb<F: Fn() -> (Option<SharedRateLimit>, Option<SharedRateLimit>) + Send + 'static>(
         stream: S,
         update_limiter_cb: F,
     ) -> Self {
@@ -92,15 +96,14 @@ impl <S> RateLimitedStream<S> {
 }
 
 fn register_traffic(
-    limiter: &Mutex<RateLimiter>,
+    limiter: &(dyn ShareableRateLimit),
     count: usize,
 ) -> Option<Pin<Box<Sleep>>>{
 
     const MIN_DELAY: Duration = Duration::from_millis(10);
 
     let now = Instant::now();
-    let delay = limiter.lock().unwrap()
-        .register_traffic(now, count as u64);
+    let delay = limiter.register_traffic(now, count as u64);
     if delay >= MIN_DELAY {
         let sleep = tokio::time::sleep(delay);
         Some(Box::pin(sleep))
@@ -137,9 +140,9 @@ impl <S: AsyncWrite + Unpin> AsyncWrite for RateLimitedStream<S> {
 
         let result = Pin::new(&mut this.stream).poll_write(ctx, buf);
 
-        if let Some(ref limiter) = this.write_limiter {
+        if let Some(ref mut limiter) = this.write_limiter {
             if let Poll::Ready(Ok(count)) = result {
-                this.write_delay = register_traffic(limiter, count);
+                this.write_delay = register_traffic(limiter.as_ref(), count);
             }
         }
 
@@ -169,7 +172,7 @@ impl <S: AsyncWrite + Unpin> AsyncWrite for RateLimitedStream<S> {
 
         if let Some(ref limiter) = this.write_limiter {
             if let Poll::Ready(Ok(count)) = result {
-                this.write_delay = register_traffic(limiter, count);
+                this.write_delay = register_traffic(limiter.as_ref(), count);
             }
         }
 
@@ -216,7 +219,7 @@ impl <S: AsyncRead + Unpin> AsyncRead for RateLimitedStream<S> {
         if let Some(ref read_limiter) = this.read_limiter {
             if let Poll::Ready(Ok(())) = &result {
                 let count = buf.filled().len() - filled_len;
-                this.read_delay = register_traffic(read_limiter, count);
+                this.read_delay = register_traffic(read_limiter.as_ref(), count);
             }
         }
 
index 72605ca26ab2eb7ddd7ad6e6854b153d1b2baf23..84856b19806c9c98bc56dc438382ca96b7a0a9b9 100644 (file)
@@ -14,6 +14,15 @@ pub trait RateLimit {
     fn register_traffic(&mut self, current_time: Instant, data_len: u64) -> Duration;
 }
 
+/// Like [RateLimit], but does not require self to be mutable.
+///
+/// This is useful for types providing internal mutability (Mutex).
+pub trait ShareableRateLimit: Send + Sync {
+    fn update_rate(&self, rate: u64, bucket_size: u64);
+    fn average_rate(&self, current_time: Instant) -> f64;
+    fn register_traffic(&self, current_time: Instant, data_len: u64) -> Duration;
+}
+
 /// Token bucket based rate limiter
 pub struct RateLimiter {
     rate: u64, // tokens/second
@@ -100,3 +109,18 @@ impl RateLimit for RateLimiter {
         Duration::from_nanos((self.consumed_tokens - self.bucket_size).saturating_mul(1_000_000_000)/ self.rate)
     }
 }
+
+impl <R: RateLimit + Send> ShareableRateLimit for std::sync::Mutex<R> {
+
+    fn update_rate(&self, rate: u64, bucket_size: u64) {
+        self.lock().unwrap().update_rate(rate, bucket_size);
+    }
+
+    fn average_rate(&self, current_time: Instant) -> f64 {
+        self.lock().unwrap().average_rate(current_time)
+    }
+
+    fn register_traffic(&self, current_time: Instant, data_len: u64) -> Duration {
+        self.lock().unwrap().register_traffic(current_time, data_len)
+    }
+}