]> git.proxmox.com Git - proxmox-backup.git/blobdiff - src/server/rest.rs
move channel/stream helpers to pbs-tools
[proxmox-backup.git] / src / server / rest.rs
index 7922897c2f066dd3babbb6dda7fd231008a24384..62b63a5d430b30c0ee2ce4d0e2d37e93a3f893ff 100644 (file)
@@ -30,17 +30,19 @@ use proxmox::api::{
 };
 use proxmox::http_err;
 
+use pbs_tools::compression::{DeflateEncoder, Level};
+use pbs_tools::stream::AsyncReaderStream;
+
+use super::auth::AuthError;
 use super::environment::RestEnvironment;
 use super::formatter::*;
 use super::ApiConfig;
-use super::auth::{check_auth, extract_auth_data};
 
 use crate::api2::types::{Authid, Userid};
 use crate::auth_helpers::*;
 use crate::config::cached_user_info::CachedUserInfo;
 use crate::tools;
-use crate::tools::compression::{CompressionMethod, DeflateEncoder, Level};
-use crate::tools::AsyncReaderStream;
+use crate::tools::compression::CompressionMethod;
 use crate::tools::FileLogger;
 
 extern "C" {
@@ -152,14 +154,13 @@ fn log_response(
     let path = &path_query[..MAX_URI_QUERY_LENGTH.min(path_query.len())];
 
     let status = resp.status();
-
     if !(status.is_success() || status.is_informational()) {
         let reason = status.canonical_reason().unwrap_or("unknown reason");
 
-        let mut message = "request failed";
-        if let Some(data) = resp.extensions().get::<ErrorMessageExtension>() {
-            message = &data.0;
-        }
+        let message = match resp.extensions().get::<ErrorMessageExtension>() {
+            Some(data) => &data.0,
+            None => "request failed",
+        };
 
         log::error!(
             "{} {}: {} {}: [client {}] {}",
@@ -201,7 +202,7 @@ pub fn auth_logger() -> Result<FileLogger, Error> {
         owned_by_backup: true,
         ..Default::default()
     };
-    FileLogger::new(crate::buildcfg::API_AUTH_LOG_FN, logger_options)
+    FileLogger::new(pbs_buildcfg::API_AUTH_LOG_FN, logger_options)
 }
 
 fn get_proxied_peer(headers: &HeaderMap) -> Option<std::net::SocketAddr> {
@@ -254,7 +255,10 @@ impl tower_service::Service<Request<Body>> for ApiService {
                         Some(apierr) => (apierr.message.clone(), apierr.code),
                         _ => (err.to_string(), StatusCode::BAD_REQUEST),
                     };
-                    Response::builder().status(code).body(err.into())?
+                    Response::builder()
+                        .status(code)
+                        .extension(ErrorMessageExtension(err.to_string()))
+                        .body(err.into())?
                 }
             };
             let logger = config.get_file_log();
@@ -440,7 +444,7 @@ pub async fn handle_api_request<Env: RpcEnvironment, S: 'static + BuildHasher +
             );
             resp.map(|body| {
                 Body::wrap_stream(DeflateEncoder::with_quality(
-                    body.map_err(|err| {
+                    TryStreamExt::map_err(body, |err| {
                         proxmox::io_format_err!("error during compression: {}", err)
                     }),
                     Level::Default,
@@ -497,7 +501,6 @@ fn get_index(
         "CSRFPreventionToken": csrf_token,
         "language": lang,
         "debug": debug,
-        "enableTapeUI": api.enable_tape_ui,
     });
 
     let (ct, index) = match api.render_template(template_file, &data) {
@@ -562,7 +565,8 @@ async fn simple_static_file_download(
     let mut response = match compression {
         Some(CompressionMethod::Deflate) => {
             let mut enc = DeflateEncoder::with_quality(data, Level::Default);
-            enc.compress_vec(&mut file, CHUNK_SIZE_LIMIT as usize).await?;
+            enc.compress_vec(&mut file, CHUNK_SIZE_LIMIT as usize)
+                .await?;
             let mut response = Response::new(enc.into_inner().into());
             response.headers_mut().insert(
                 header::CONTENT_ENCODING,
@@ -635,27 +639,21 @@ async fn handle_static_file_download(
 }
 
 fn extract_lang_header(headers: &http::HeaderMap) -> Option<String> {
-    if let Some(raw_cookie) = headers.get("COOKIE") {
-        if let Ok(cookie) = raw_cookie.to_str() {
-            return tools::extract_cookie(cookie, "PBSLangCookie");
-        }
+    if let Some(Ok(cookie)) = headers.get("COOKIE").map(|v| v.to_str()) {
+        return tools::extract_cookie(cookie, "PBSLangCookie");
     }
-
     None
 }
 
 // FIXME: support handling multiple compression methods
 fn extract_compression_method(headers: &http::HeaderMap) -> Option<CompressionMethod> {
-    if let Some(raw_encoding) = headers.get(header::ACCEPT_ENCODING) {
-        if let Ok(encoding) = raw_encoding.to_str() {
-            for encoding in encoding.split(&[',', ' '][..]) {
-                if let Ok(method) = encoding.parse() {
-                    return Some(method);
-                }
+    if let Some(Ok(encodings)) = headers.get(header::ACCEPT_ENCODING).map(|v| v.to_str()) {
+        for encoding in encodings.split(&[',', ' '][..]) {
+            if let Ok(method) = encoding.parse() {
+                return Some(method);
             }
         }
     }
-
     None
 }
 
@@ -684,6 +682,7 @@ async fn handle_request(
     rpcenv.set_client_ip(Some(*peer));
 
     let user_info = CachedUserInfo::new()?;
+    let auth = &api.api_auth;
 
     let delay_unauth_time = std::time::Instant::now() + std::time::Duration::from_millis(3000);
     let access_forbidden_time = std::time::Instant::now() + std::time::Duration::from_millis(500);
@@ -709,13 +708,15 @@ async fn handle_request(
             }
 
             if auth_required {
-                let auth_result = match extract_auth_data(&parts.headers) {
-                    Some(auth_data) => check_auth(&method, &auth_data, &user_info),
-                    None => Err(format_err!("no authentication credentials provided.")),
-                };
-                match auth_result {
+                match auth.check_auth(&parts.headers, &method, &user_info) {
                     Ok(authid) => rpcenv.set_auth_id(Some(authid.to_string())),
-                    Err(err) => {
+                    Err(auth_err) => {
+                        let err = match auth_err {
+                            AuthError::Generic(err) => err,
+                            AuthError::NoData => {
+                                format_err!("no authentication credentials provided.")
+                            }
+                        };
                         let peer = peer.ip();
                         auth_logger()?.log(format!(
                             "authentication failure; rhost={} msg={}",
@@ -778,9 +779,9 @@ async fn handle_request(
 
         if comp_len == 0 {
             let language = extract_lang_header(&parts.headers);
-            if let Some(auth_data) = extract_auth_data(&parts.headers) {
-                match check_auth(&method, &auth_data, &user_info) {
-                    Ok(auth_id) if !auth_id.is_token() => {
+            match auth.check_auth(&parts.headers, &method, &user_info) {
+                Ok(auth_id) => {
+                    if !auth_id.is_token() {
                         let userid = auth_id.user();
                         let new_csrf_token = assemble_csrf_prevention_token(csrf_secret(), userid);
                         return Ok(get_index(
@@ -791,14 +792,13 @@ async fn handle_request(
                             parts,
                         ));
                     }
-                    _ => {
-                        tokio::time::sleep_until(Instant::from_std(delay_unauth_time)).await;
-                        return Ok(get_index(None, None, language, &api, parts));
-                    }
                 }
-            } else {
-                return Ok(get_index(None, None, language, &api, parts));
+                Err(AuthError::Generic(_)) => {
+                    tokio::time::sleep_until(Instant::from_std(delay_unauth_time)).await;
+                }
+                Err(AuthError::NoData) => {}
             }
+            return Ok(get_index(None, None, language, &api, parts));
         } else {
             let filename = api.find_alias(&components);
             let compression = extract_compression_method(&parts.headers);