]> git.proxmox.com Git - proxmox.git/commitdiff
rest-server: support configuring the privileged connection
authorWolfgang Bumiller <w.bumiller@proxmox.com>
Thu, 9 Mar 2023 14:50:55 +0000 (15:50 +0100)
committerWolfgang Bumiller <w.bumiller@proxmox.com>
Thu, 1 Feb 2024 13:00:22 +0000 (14:00 +0100)
Adds a privileged_addr to ApiConfig, and some helpers for
hyper (both server and client)

Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
Reviewed-by: Lukas Wagner <l.wagner@proxmox.com>
proxmox-rest-server/src/api_config.rs
proxmox-rest-server/src/lib.rs
proxmox-rest-server/src/rest.rs

index ad9a811138ec02309b27a4c66069bf2f48a887e4..805894466503be865807c64b07afca6284ac8b73 100644 (file)
@@ -1,13 +1,16 @@
 use std::collections::HashMap;
 use std::future::Future;
+use std::io;
 use std::path::PathBuf;
 use std::pin::Pin;
 use std::sync::{Arc, Mutex};
+use std::task::{Context, Poll};
 
 use anyhow::{format_err, Error};
-use http::{HeaderMap, Method};
+use http::{HeaderMap, Method, Uri};
 use hyper::http::request::Parts;
 use hyper::{Body, Response};
+use tower_service::Service;
 
 use proxmox_router::{Router, RpcEnvironmentType, UserInformation};
 use proxmox_sys::fs::{create_path, CreateOptions};
@@ -25,6 +28,7 @@ pub struct ApiConfig {
     handlers: Vec<Handler>,
     auth_handler: Option<AuthHandler>,
     index_handler: Option<IndexHandler>,
+    pub(crate) privileged_addr: Option<PrivilegedAddr>,
 
     #[cfg(feature = "templates")]
     templates: templates::Templates,
@@ -53,6 +57,7 @@ impl ApiConfig {
             handlers: Vec::new(),
             auth_handler: None,
             index_handler: None,
+            privileged_addr: None,
 
             #[cfg(feature = "templates")]
             templates: Default::default(),
@@ -73,6 +78,12 @@ impl ApiConfig {
         self.auth_handler(AuthHandler::from_fn(func))
     }
 
+    /// This is used for `protected` API calls to proxy to a more privileged service.
+    pub fn privileged_addr(mut self, addr: impl Into<PrivilegedAddr>) -> Self {
+        self.privileged_addr = Some(addr.into());
+        self
+    }
+
     /// Set the index handler.
     pub fn index_handler(mut self, index_handler: IndexHandler) -> Self {
         self.index_handler = Some(index_handler);
@@ -452,3 +463,156 @@ impl From<Error> for AuthError {
         AuthError::Generic(err)
     }
 }
+
+#[derive(Clone, Debug)]
+/// For `protected` requests we support TCP or Unix connections.
+pub enum PrivilegedAddr {
+    Tcp(std::net::SocketAddr),
+    Unix(std::os::unix::net::SocketAddr),
+}
+
+impl From<std::net::SocketAddr> for PrivilegedAddr {
+    fn from(addr: std::net::SocketAddr) -> Self {
+        Self::Tcp(addr)
+    }
+}
+
+impl From<std::os::unix::net::SocketAddr> for PrivilegedAddr {
+    fn from(addr: std::os::unix::net::SocketAddr) -> Self {
+        Self::Unix(addr)
+    }
+}
+
+impl Service<Uri> for PrivilegedAddr {
+    type Response = PrivilegedSocket;
+    type Error = io::Error;
+    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn call(&mut self, _req: Uri) -> Self::Future {
+        match self {
+            PrivilegedAddr::Tcp(addr) => {
+                let addr = addr.clone();
+                Box::pin(async move {
+                    tokio::net::TcpStream::connect(addr)
+                        .await
+                        .map(PrivilegedSocket::Tcp)
+                })
+            }
+            PrivilegedAddr::Unix(addr) => {
+                let addr = addr.clone();
+                Box::pin(async move {
+                    tokio::net::UnixStream::connect(addr.as_pathname().ok_or_else(|| {
+                        io::Error::new(io::ErrorKind::Other, "empty path for unix socket")
+                    })?)
+                    .await
+                    .map(PrivilegedSocket::Unix)
+                })
+            }
+        }
+    }
+}
+
+/// A socket which is either a TCP stream or a UNIX stream.
+pub enum PrivilegedSocket {
+    Tcp(tokio::net::TcpStream),
+    Unix(tokio::net::UnixStream),
+}
+
+impl tokio::io::AsyncRead for PrivilegedSocket {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut tokio::io::ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        match self.get_mut() {
+            Self::Tcp(s) => Pin::new(s).poll_read(cx, buf),
+            Self::Unix(s) => Pin::new(s).poll_read(cx, buf),
+        }
+    }
+}
+
+impl tokio::io::AsyncWrite for PrivilegedSocket {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        match self.get_mut() {
+            Self::Tcp(s) => Pin::new(s).poll_write(cx, buf),
+            Self::Unix(s) => Pin::new(s).poll_write(cx, buf),
+        }
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        match self.get_mut() {
+            Self::Tcp(s) => Pin::new(s).poll_flush(cx),
+            Self::Unix(s) => Pin::new(s).poll_flush(cx),
+        }
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        match self.get_mut() {
+            Self::Tcp(s) => Pin::new(s).poll_shutdown(cx),
+            Self::Unix(s) => Pin::new(s).poll_shutdown(cx),
+        }
+    }
+
+    fn poll_write_vectored(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        bufs: &[io::IoSlice<'_>],
+    ) -> Poll<io::Result<usize>> {
+        match self.get_mut() {
+            Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
+            Self::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
+        }
+    }
+
+    fn is_write_vectored(&self) -> bool {
+        match self {
+            Self::Tcp(s) => s.is_write_vectored(),
+            Self::Unix(s) => s.is_write_vectored(),
+        }
+    }
+}
+
+impl hyper::client::connect::Connection for PrivilegedSocket {
+    fn connected(&self) -> hyper::client::connect::Connected {
+        match self {
+            Self::Tcp(s) => s.connected(),
+            Self::Unix(_) => hyper::client::connect::Connected::new(),
+        }
+    }
+}
+
+/// Implements hyper's `Accept` for `UnixListener`s.
+pub struct UnixAcceptor {
+    listener: tokio::net::UnixListener,
+}
+
+impl From<tokio::net::UnixListener> for UnixAcceptor {
+    fn from(listener: tokio::net::UnixListener) -> Self {
+        Self { listener }
+    }
+}
+
+impl hyper::server::accept::Accept for UnixAcceptor {
+    type Conn = tokio::net::UnixStream;
+    type Error = io::Error;
+
+    fn poll_accept(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<io::Result<Self::Conn>>> {
+        Pin::new(&mut self.get_mut().listener)
+            .poll_accept(cx)
+            .map(|res| match res {
+                Ok((stream, _addr)) => Some(Ok(stream)),
+                Err(err) => Some(Err(err)),
+            })
+    }
+}
index 1c64ffb4ae336877f08b5fcea8f7f483461a133f..ce9e4f15e4ea56a2f8b8fefceb355ce30f85f06d 100644 (file)
@@ -45,7 +45,7 @@ mod file_logger;
 pub use file_logger::{FileLogOptions, FileLogger};
 
 mod api_config;
-pub use api_config::{ApiConfig, AuthError, AuthHandler, IndexHandler};
+pub use api_config::{ApiConfig, AuthError, AuthHandler, IndexHandler, UnixAcceptor};
 
 mod rest;
 pub use rest::{Redirector, RestServer};
index 39f98e55f0302ce0dcc541baf5eaa04ee49d9b3f..4900592d2d61fa284375791b8a3183fc99fc613f 100644 (file)
@@ -8,7 +8,7 @@ use std::sync::{Arc, Mutex};
 use std::task::{Context, Poll};
 
 use anyhow::{bail, format_err, Error};
-use futures::future::{FutureExt, TryFutureExt};
+use futures::future::FutureExt;
 use futures::stream::TryStreamExt;
 use hyper::body::HttpBody;
 use hyper::header::{self, HeaderMap};
@@ -443,7 +443,8 @@ async fn get_request_parameters<S: 'static + BuildHasher + Send>(
 struct NoLogExtension();
 
 async fn proxy_protected_request(
-    info: &'static ApiMethod,
+    config: &ApiConfig,
+    info: &ApiMethod,
     mut parts: Parts,
     req_body: Body,
     peer: &std::net::SocketAddr,
@@ -464,14 +465,16 @@ async fn proxy_protected_request(
 
     let reload_timezone = info.reload_timezone;
 
-    let resp = hyper::client::Client::new()
-        .request(request)
-        .map_err(Error::from)
-        .map_ok(|mut resp| {
-            resp.extensions_mut().insert(NoLogExtension());
-            resp
-        })
-        .await?;
+    let mut resp = match config.privileged_addr.clone() {
+        None => hyper::client::Client::new().request(request).await?,
+        Some(addr) => {
+            hyper::client::Client::builder()
+                .build(addr)
+                .request(request)
+                .await?
+        }
+    };
+    resp.extensions_mut().insert(NoLogExtension());
 
     if reload_timezone {
         unsafe {
@@ -1024,7 +1027,7 @@ impl Formatted {
                 let result = if api_method.protected
                     && rpcenv.env_type == RpcEnvironmentType::PUBLIC
                 {
-                    proxy_protected_request(api_method, parts, body, peer).await
+                    proxy_protected_request(config, api_method, parts, body, peer).await
                 } else {
                     handle_api_request(rpcenv, api_method, formatter, parts, body, uri_param).await
                 };
@@ -1129,7 +1132,7 @@ impl Unformatted {
                 let result = if api_method.protected
                     && rpcenv.env_type == RpcEnvironmentType::PUBLIC
                 {
-                    proxy_protected_request(api_method, parts, body, peer).await
+                    proxy_protected_request(config, api_method, parts, body, peer).await
                 } else {
                     handle_unformatted_api_request(rpcenv, api_method, parts, body, uri_param).await
                 };