use std::collections::HashMap;
use std::future::Future;
+use std::io;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
+use std::task::{Context, Poll};
use anyhow::{format_err, Error};
-use http::{HeaderMap, Method};
+use http::{HeaderMap, Method, Uri};
use hyper::http::request::Parts;
use hyper::{Body, Response};
+use tower_service::Service;
use proxmox_router::{Router, RpcEnvironmentType, UserInformation};
use proxmox_sys::fs::{create_path, CreateOptions};
handlers: Vec<Handler>,
auth_handler: Option<AuthHandler>,
index_handler: Option<IndexHandler>,
+ pub(crate) privileged_addr: Option<PrivilegedAddr>,
#[cfg(feature = "templates")]
templates: templates::Templates,
handlers: Vec::new(),
auth_handler: None,
index_handler: None,
+ privileged_addr: None,
#[cfg(feature = "templates")]
templates: Default::default(),
self.auth_handler(AuthHandler::from_fn(func))
}
+ /// This is used for `protected` API calls to proxy to a more privileged service.
+ pub fn privileged_addr(mut self, addr: impl Into<PrivilegedAddr>) -> Self {
+ self.privileged_addr = Some(addr.into());
+ self
+ }
+
/// Set the index handler.
pub fn index_handler(mut self, index_handler: IndexHandler) -> Self {
self.index_handler = Some(index_handler);
AuthError::Generic(err)
}
}
+
+#[derive(Clone, Debug)]
+/// For `protected` requests we support TCP or Unix connections.
+pub enum PrivilegedAddr {
+ Tcp(std::net::SocketAddr),
+ Unix(std::os::unix::net::SocketAddr),
+}
+
+impl From<std::net::SocketAddr> for PrivilegedAddr {
+ fn from(addr: std::net::SocketAddr) -> Self {
+ Self::Tcp(addr)
+ }
+}
+
+impl From<std::os::unix::net::SocketAddr> for PrivilegedAddr {
+ fn from(addr: std::os::unix::net::SocketAddr) -> Self {
+ Self::Unix(addr)
+ }
+}
+
+impl Service<Uri> for PrivilegedAddr {
+ type Response = PrivilegedSocket;
+ type Error = io::Error;
+ type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
+
+ fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn call(&mut self, _req: Uri) -> Self::Future {
+ match self {
+ PrivilegedAddr::Tcp(addr) => {
+ let addr = addr.clone();
+ Box::pin(async move {
+ tokio::net::TcpStream::connect(addr)
+ .await
+ .map(PrivilegedSocket::Tcp)
+ })
+ }
+ PrivilegedAddr::Unix(addr) => {
+ let addr = addr.clone();
+ Box::pin(async move {
+ tokio::net::UnixStream::connect(addr.as_pathname().ok_or_else(|| {
+ io::Error::new(io::ErrorKind::Other, "empty path for unix socket")
+ })?)
+ .await
+ .map(PrivilegedSocket::Unix)
+ })
+ }
+ }
+ }
+}
+
+/// A socket which is either a TCP stream or a UNIX stream.
+pub enum PrivilegedSocket {
+ Tcp(tokio::net::TcpStream),
+ Unix(tokio::net::UnixStream),
+}
+
+impl tokio::io::AsyncRead for PrivilegedSocket {
+ fn poll_read(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ match self.get_mut() {
+ Self::Tcp(s) => Pin::new(s).poll_read(cx, buf),
+ Self::Unix(s) => Pin::new(s).poll_read(cx, buf),
+ }
+ }
+}
+
+impl tokio::io::AsyncWrite for PrivilegedSocket {
+ fn poll_write(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ buf: &[u8],
+ ) -> Poll<io::Result<usize>> {
+ match self.get_mut() {
+ Self::Tcp(s) => Pin::new(s).poll_write(cx, buf),
+ Self::Unix(s) => Pin::new(s).poll_write(cx, buf),
+ }
+ }
+
+ fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ match self.get_mut() {
+ Self::Tcp(s) => Pin::new(s).poll_flush(cx),
+ Self::Unix(s) => Pin::new(s).poll_flush(cx),
+ }
+ }
+
+ fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+ match self.get_mut() {
+ Self::Tcp(s) => Pin::new(s).poll_shutdown(cx),
+ Self::Unix(s) => Pin::new(s).poll_shutdown(cx),
+ }
+ }
+
+ fn poll_write_vectored(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ bufs: &[io::IoSlice<'_>],
+ ) -> Poll<io::Result<usize>> {
+ match self.get_mut() {
+ Self::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
+ Self::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
+ }
+ }
+
+ fn is_write_vectored(&self) -> bool {
+ match self {
+ Self::Tcp(s) => s.is_write_vectored(),
+ Self::Unix(s) => s.is_write_vectored(),
+ }
+ }
+}
+
+impl hyper::client::connect::Connection for PrivilegedSocket {
+ fn connected(&self) -> hyper::client::connect::Connected {
+ match self {
+ Self::Tcp(s) => s.connected(),
+ Self::Unix(_) => hyper::client::connect::Connected::new(),
+ }
+ }
+}
+
+/// Implements hyper's `Accept` for `UnixListener`s.
+pub struct UnixAcceptor {
+ listener: tokio::net::UnixListener,
+}
+
+impl From<tokio::net::UnixListener> for UnixAcceptor {
+ fn from(listener: tokio::net::UnixListener) -> Self {
+ Self { listener }
+ }
+}
+
+impl hyper::server::accept::Accept for UnixAcceptor {
+ type Conn = tokio::net::UnixStream;
+ type Error = io::Error;
+
+ fn poll_accept(
+ self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<io::Result<Self::Conn>>> {
+ Pin::new(&mut self.get_mut().listener)
+ .poll_accept(cx)
+ .map(|res| match res {
+ Ok((stream, _addr)) => Some(Ok(stream)),
+ Err(err) => Some(Err(err)),
+ })
+ }
+}
use std::task::{Context, Poll};
use anyhow::{bail, format_err, Error};
-use futures::future::{FutureExt, TryFutureExt};
+use futures::future::FutureExt;
use futures::stream::TryStreamExt;
use hyper::body::HttpBody;
use hyper::header::{self, HeaderMap};
struct NoLogExtension();
async fn proxy_protected_request(
- info: &'static ApiMethod,
+ config: &ApiConfig,
+ info: &ApiMethod,
mut parts: Parts,
req_body: Body,
peer: &std::net::SocketAddr,
let reload_timezone = info.reload_timezone;
- let resp = hyper::client::Client::new()
- .request(request)
- .map_err(Error::from)
- .map_ok(|mut resp| {
- resp.extensions_mut().insert(NoLogExtension());
- resp
- })
- .await?;
+ let mut resp = match config.privileged_addr.clone() {
+ None => hyper::client::Client::new().request(request).await?,
+ Some(addr) => {
+ hyper::client::Client::builder()
+ .build(addr)
+ .request(request)
+ .await?
+ }
+ };
+ resp.extensions_mut().insert(NoLogExtension());
if reload_timezone {
unsafe {
let result = if api_method.protected
&& rpcenv.env_type == RpcEnvironmentType::PUBLIC
{
- proxy_protected_request(api_method, parts, body, peer).await
+ proxy_protected_request(config, api_method, parts, body, peer).await
} else {
handle_api_request(rpcenv, api_method, formatter, parts, body, uri_param).await
};
let result = if api_method.protected
&& rpcenv.env_type == RpcEnvironmentType::PUBLIC
{
- proxy_protected_request(api_method, parts, body, peer).await
+ proxy_protected_request(config, api_method, parts, body, peer).await
} else {
handle_unformatted_api_request(rpcenv, api_method, parts, body, uri_param).await
};