]> git.proxmox.com Git - proxmox-backup.git/blobdiff - src/server/rest.rs
switch from failure to anyhow
[proxmox-backup.git] / src / server / rest.rs
index d91b26218adebb9655fec33cb87374d3061e7748..27ac93f25fa9e80d024cbe0180588630a90ffa54 100644 (file)
@@ -1,28 +1,34 @@
 use std::collections::HashMap;
+use std::future::Future;
 use std::hash::BuildHasher;
 use std::path::{Path, PathBuf};
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use failure::*;
-use futures::future::{self, Either, FutureExt, TryFutureExt};
+use anyhow::{bail, format_err, Error};
+use futures::future::{self, FutureExt, TryFutureExt};
 use futures::stream::TryStreamExt;
 use hyper::header;
 use hyper::http::request::Parts;
-use hyper::rt::Future;
 use hyper::{Body, Request, Response, StatusCode};
 use serde_json::{json, Value};
 use tokio::fs::File;
+use tokio::time::Instant;
 use url::form_urlencoded;
 
+use proxmox::http_err;
+use proxmox::api::{ApiHandler, ApiMethod, HttpError};
+use proxmox::api::{RpcEnvironment, RpcEnvironmentType, check_api_permission};
+use proxmox::api::schema::{ObjectSchema, parse_simple_value, verify_json_object, parse_parameter_strings};
+
 use super::environment::RestEnvironment;
 use super::formatter::*;
-use crate::api_schema::config::*;
-use crate::api_schema::router::*;
-use crate::api_schema::*;
+use super::ApiConfig;
+
 use crate::auth_helpers::*;
 use crate::tools;
+use crate::config::cached_user_info::CachedUserInfo;
 
 extern "C"  { fn tzset(); }
 
@@ -121,7 +127,7 @@ impl tower_service::Service<Request<Body>> for ApiService {
         let method = req.method().clone();
 
         let peer = self.peer;
-        Pin::from(handle_request(self.api_config.clone(), req))
+        handle_request(self.api_config.clone(), req)
             .map(move |result| match result {
                 Ok(res) => {
                     log_response(&peer, method, &path, &res);
@@ -145,13 +151,44 @@ impl tower_service::Service<Request<Body>> for ApiService {
     }
 }
 
-fn get_request_parameters_async<S: 'static + BuildHasher + Send>(
-    info: &'static ApiMethod,
+fn parse_query_parameters<S: 'static + BuildHasher + Send>(
+    param_schema: &ObjectSchema,
+    form: &str, // x-www-form-urlencoded body data
+    parts: &Parts,
+    uri_param: &HashMap<String, String, S>,
+) -> Result<Value, Error> {
+
+    let mut param_list: Vec<(String, String)> = vec![];
+
+    if !form.is_empty() {
+        for (k, v) in form_urlencoded::parse(form.as_bytes()).into_owned() {
+            param_list.push((k, v));
+        }
+    }
+
+    if let Some(query_str) = parts.uri.query() {
+        for (k, v) in form_urlencoded::parse(query_str.as_bytes()).into_owned() {
+            if k == "_dc" { continue; } // skip extjs "disable cache" parameter
+            param_list.push((k, v));
+        }
+    }
+
+    for (k, v) in uri_param {
+        param_list.push((k.clone(), v.clone()));
+    }
+
+    let params = parse_parameter_strings(&param_list, param_schema, true)?;
+
+    Ok(params)
+}
+
+async fn get_request_parameters<S: 'static + BuildHasher + Send>(
+    param_schema: &ObjectSchema,
     parts: Parts,
     req_body: Body,
     uri_param: HashMap<String, String, S>,
-) -> Box<dyn Future<Output = Result<Value, failure::Error>> + Send>
-{
+) -> Result<Value, Error> {
+
     let mut is_json = false;
 
     if let Some(value) = parts.headers.get(header::CONTENT_TYPE) {
@@ -162,13 +199,11 @@ fn get_request_parameters_async<S: 'static + BuildHasher + Send>(
             Ok(Some("application/json")) => {
                 is_json = true;
             }
-            _ => {
-                return Box::new(future::err(http_err!(BAD_REQUEST, "unsupported content type".to_string())));
-            }
+            _ => bail!("unsupported content type {:?}", value.to_str()),
         }
     }
 
-    let resp = req_body
+    let body = req_body
         .map_err(|err| http_err!(BAD_REQUEST, format!("Promlems reading request body: {}", err)))
         .try_fold(Vec::new(), |mut acc, chunk| async move {
             if acc.len() + chunk.len() < 64*1024 { //fimxe: max request body size?
@@ -177,57 +212,32 @@ fn get_request_parameters_async<S: 'static + BuildHasher + Send>(
             } else {
                 Err(http_err!(BAD_REQUEST, "Request body too large".to_string()))
             }
-        })
-        .and_then(move |body| async move {
-            let utf8 = std::str::from_utf8(&body)?;
-
-            let obj_schema = &info.parameters;
-
-            if is_json {
-                let mut params: Value = serde_json::from_str(utf8)?;
-                for (k, v) in uri_param {
-                    if let Some((_optional, prop_schema)) = obj_schema.properties.get::<str>(&k) {
-                        params[&k] = parse_simple_value(&v, prop_schema)?;
-                    }
-                }
-                verify_json_object(&params, obj_schema)?;
-                return Ok(params);
-            }
-
-            let mut param_list: Vec<(String, String)> = vec![];
+        }).await?;
 
-            if !utf8.is_empty() {
-                for (k, v) in form_urlencoded::parse(utf8.as_bytes()).into_owned() {
-                    param_list.push((k, v));
-                }
-            }
+    let utf8_data = std::str::from_utf8(&body)
+        .map_err(|err| format_err!("Request body not uft8: {}", err))?;
 
-            if let Some(query_str) = parts.uri.query() {
-                for (k, v) in form_urlencoded::parse(query_str.as_bytes()).into_owned() {
-                    if k == "_dc" { continue; } // skip extjs "disable cache" parameter
-                    param_list.push((k, v));
-                }
+    if is_json {
+        let mut params: Value = serde_json::from_str(utf8_data)?;
+        for (k, v) in uri_param {
+            if let Some((_optional, prop_schema)) = param_schema.lookup(&k) {
+                params[&k] = parse_simple_value(&v, prop_schema)?;
             }
-
-            for (k, v) in uri_param {
-                param_list.push((k.clone(), v.clone()));
-            }
-
-            let params = parse_parameter_strings(&param_list, obj_schema, true)?;
-
-            Ok(params)
-        }.boxed());
-
-    Box::new(resp)
+        }
+        verify_json_object(&params, param_schema)?;
+        return Ok(params);
+    } else {
+        parse_query_parameters(param_schema, utf8_data, &parts, &uri_param)
+    }
 }
 
 struct NoLogExtension();
 
-fn proxy_protected_request(
+async fn proxy_protected_request(
     info: &'static ApiMethod,
     mut parts: Parts,
     req_body: Body,
-) -> BoxFut {
+) -> Result<Response<Body>, Error> {
 
     let mut uri_parts = parts.uri.clone().into_parts();
 
@@ -239,110 +249,66 @@ fn proxy_protected_request(
 
     let request = Request::from_parts(parts, req_body);
 
+    let reload_timezone = info.reload_timezone;
+
     let resp = hyper::client::Client::new()
         .request(request)
         .map_err(Error::from)
         .map_ok(|mut resp| {
             resp.extensions_mut().insert(NoLogExtension());
             resp
-        });
+        })
+        .await?;
 
+    if reload_timezone { unsafe { tzset(); } }
 
-    let reload_timezone = info.reload_timezone;
-    Box::new(async move {
-        let result = resp.await;
-        if reload_timezone {
-            unsafe {
-                tzset();
-            }
-        }
-        result
-    })
+    Ok(resp)
 }
 
-pub fn handle_sync_api_request<Env: RpcEnvironment, S: 'static + BuildHasher + Send>(
+pub async fn handle_api_request<Env: RpcEnvironment, S: 'static + BuildHasher + Send>(
     mut rpcenv: Env,
     info: &'static ApiMethod,
     formatter: &'static OutputFormatter,
     parts: Parts,
     req_body: Body,
     uri_param: HashMap<String, String, S>,
-) -> BoxFut
-{
-    let params = get_request_parameters_async(info, parts, req_body, uri_param);
+) -> Result<Response<Body>, Error> {
 
     let delay_unauth_time = std::time::Instant::now() + std::time::Duration::from_millis(3000);
 
-    let resp = Pin::from(params)
-        .and_then(move |params| {
-            let mut delay = false;
-            let resp = match (info.handler.as_ref().unwrap())(params, info, &mut rpcenv) {
-                Ok(data) => (formatter.format_data)(data, &rpcenv),
-                Err(err) => {
-                    if let Some(httperr) = err.downcast_ref::<HttpError>() {
-                        if httperr.code == StatusCode::UNAUTHORIZED {
-                            delay = true;
-                        }
-                    }
-                    (formatter.format_error)(err)
-                }
-            };
-
-            if info.reload_timezone {
-                unsafe { tzset() };
-            }
-
-            if delay {
-                Either::Left(delayed_response(resp, delay_unauth_time))
-            } else {
-                Either::Right(future::ok(resp))
-            }
-        })
-        .or_else(move |err| {
-            future::ok((formatter.format_error)(err))
-        });
-
-    Box::new(resp)
-}
-
-pub fn handle_async_api_request<Env: RpcEnvironment>(
-    rpcenv: Env,
-    info: &'static ApiAsyncMethod,
-    formatter: &'static OutputFormatter,
-    parts: Parts,
-    req_body: Body,
-    uri_param: HashMap<String, String>,
-) -> BoxFut
-{
-    // fixme: convert parameters to Json
-    let mut param_list: Vec<(String, String)> = vec![];
-
-    if let Some(query_str) = parts.uri.query() {
-        for (k, v) in form_urlencoded::parse(query_str.as_bytes()).into_owned() {
-            if k == "_dc" { continue; } // skip extjs "disable cache" parameter
-            param_list.push((k, v));
+    let result = match info.handler {
+        ApiHandler::AsyncHttp(handler) => {
+            let params = parse_query_parameters(info.parameters, "", &parts, &uri_param)?;
+            (handler)(parts, req_body, params, info, Box::new(rpcenv)).await
         }
-    }
-
-    for (k, v) in uri_param {
-        param_list.push((k.clone(), v.clone()));
-    }
-
-    let params = match parse_parameter_strings(&param_list, &info.parameters, true) {
-        Ok(v) => v,
-        Err(err) => {
-            let resp = (formatter.format_error)(Error::from(err));
-            return Box::new(future::ok(resp));
+        ApiHandler::Sync(handler) => {
+            let params = get_request_parameters(info.parameters, parts, req_body, uri_param).await?;
+            (handler)(params, info, &mut rpcenv)
+                .map(|data| (formatter.format_data)(data, &rpcenv))
+        }
+        ApiHandler::Async(handler) => {
+            let params = get_request_parameters(info.parameters, parts, req_body, uri_param).await?;
+            (handler)(params, info, &mut rpcenv)
+                .await
+                .map(|data| (formatter.format_data)(data, &rpcenv))
         }
     };
 
-    match (info.handler)(parts, req_body, params, info, Box::new(rpcenv)) {
-        Ok(future) => future,
+    let resp = match result {
+        Ok(resp) => resp,
         Err(err) => {
-            let resp = (formatter.format_error)(err);
-            Box::new(future::ok(resp))
+            if let Some(httperr) = err.downcast_ref::<HttpError>() {
+                if httperr.code == StatusCode::UNAUTHORIZED {
+                    tokio::time::delay_until(Instant::from_std(delay_unauth_time)).await;
+                }
+            }
+            (formatter.format_error)(err)
         }
-    }
+    };
+
+    if info.reload_timezone { unsafe { tzset(); } }
+
+    Ok(resp)
 }
 
 fn get_index(username: Option<String>, token: Option<String>) ->  Response<Body> {
@@ -372,6 +338,7 @@ fn get_index(username: Option<String>, token: Option<String>) ->  Response<Body>
     <link rel="stylesheet" type="text/css" href="/extjs/theme-crisp/resources/theme-crisp-all.css" />
     <link rel="stylesheet" type="text/css" href="/extjs/crisp/resources/charts-all.css" />
     <link rel="stylesheet" type="text/css" href="/fontawesome/css/font-awesome.css" />
+    <link rel="stylesheet" type="text/css" href="/css/ext6-pbs.css" />
     <script type='text/javascript'> function gettext(buf) {{ return buf; }} </script>
     <script type="text/javascript" src="/extjs/ext-all-debug.js"></script>
     <script type="text/javascript" src="/extjs/charts-debug.js"></script>
@@ -459,8 +426,8 @@ async fn chuncked_static_file_download(filename: PathBuf) -> Result<Response<Bod
         .await
         .map_err(|err| http_err!(BAD_REQUEST, format!("File open failed: {}", err)))?;
 
-    let payload = tokio::codec::FramedRead::new(file, tokio::codec::BytesCodec::new())
-        .map_ok(|bytes| hyper::Chunk::from(bytes.freeze()));
+    let payload = tokio_util::codec::FramedRead::new(file, tokio_util::codec::BytesCodec::new())
+        .map_ok(|bytes| hyper::body::Bytes::from(bytes.freeze()));
     let body = Body::wrap_stream(payload);
 
     // fixme: set other headers ?
@@ -472,19 +439,17 @@ async fn chuncked_static_file_download(filename: PathBuf) -> Result<Response<Bod
     )
 }
 
-fn handle_static_file_download(filename: PathBuf) ->  BoxFut {
+async fn handle_static_file_download(filename: PathBuf) ->  Result<Response<Body>, Error> {
 
-    let response = tokio::fs::metadata(filename.clone())
+    let metadata = tokio::fs::metadata(filename.clone())
         .map_err(|err| http_err!(BAD_REQUEST, format!("File access problems: {}", err)))
-        .and_then(|metadata| async move {
-            if metadata.len() < 1024*32 {
-                simple_static_file_download(filename).await
-            } else {
-                chuncked_static_file_download(filename).await
-            }
-        });
+        .await?;
 
-    Box::new(response)
+    if metadata.len() < 1024*32 {
+        simple_static_file_download(filename).await
+    } else {
+        chuncked_static_file_download(filename).await
+    }
 }
 
 fn extract_auth_data(headers: &http::HeaderMap) -> (Option<String>, Option<String>) {
@@ -504,7 +469,12 @@ fn extract_auth_data(headers: &http::HeaderMap) -> (Option<String>, Option<Strin
     (ticket, token)
 }
 
-fn check_auth(method: &hyper::Method, ticket: &Option<String>, token: &Option<String>) -> Result<String, Error> {
+fn check_auth(
+    method: &hyper::Method,
+    ticket: &Option<String>,
+    token: &Option<String>,
+    user_info: &CachedUserInfo,
+) -> Result<String, Error> {
 
     let ticket_lifetime = tools::ticket::TICKET_LIFETIME;
 
@@ -517,6 +487,10 @@ fn check_auth(method: &hyper::Method, ticket: &Option<String>, token: &Option<St
         None => bail!("missing ticket"),
     };
 
+    if !user_info.is_active_user(&username) {
+        bail!("user account disabled or expired.");
+    }
+
     if method != hyper::Method::GET {
         if let Some(token) = token {
             println!("CSRF prevention token: {:?}", token);
@@ -529,24 +503,12 @@ fn check_auth(method: &hyper::Method, ticket: &Option<String>, token: &Option<St
     Ok(username)
 }
 
-async fn delayed_response(
-    resp: Response<Body>,
-    delay_unauth_time: std::time::Instant,
-) -> Result<Response<Body>, Error> {
-    tokio::timer::delay(delay_unauth_time).await;
-    Ok(resp)
-}
-
-pub fn handle_request(api: Arc<ApiConfig>, req: Request<Body>) -> BoxFut {
+pub async fn handle_request(api: Arc<ApiConfig>, req: Request<Body>) -> Result<Response<Body>, Error> {
 
     let (parts, body) = req.into_parts();
 
     let method = parts.method.clone();
-
-    let (path, components) = match tools::normalize_uri_path(parts.uri.path()) {
-        Ok((p,c)) => (p, c),
-        Err(err) => return Box::new(future::err(http_err!(BAD_REQUEST, err.to_string()))),
-    };
+    let (path, components) = tools::normalize_uri_path(parts.uri.path())?;
 
     let comp_len = components.len();
 
@@ -556,87 +518,98 @@ pub fn handle_request(api: Arc<ApiConfig>, req: Request<Body>) -> BoxFut {
     let env_type = api.env_type();
     let mut rpcenv = RestEnvironment::new(env_type);
 
+    let user_info = CachedUserInfo::new()?;
+
     let delay_unauth_time = std::time::Instant::now() + std::time::Duration::from_millis(3000);
+    let access_forbidden_time = std::time::Instant::now() + std::time::Duration::from_millis(500);
 
     if comp_len >= 1 && components[0] == "api2" {
 
         if comp_len >= 2 {
+
             let format = components[1];
+
             let formatter = match format {
                 "json" => &JSON_FORMATTER,
                 "extjs" => &EXTJS_FORMATTER,
-                _ =>  {
-                    return Box::new(future::err(http_err!(BAD_REQUEST, format!("Unsupported output format '{}'.", format))));
-                }
+                _ =>  bail!("Unsupported output format '{}'.", format),
             };
 
             let mut uri_param = HashMap::new();
 
-            if comp_len == 4 && components[2] == "access" && components[3] == "ticket" {
+            if comp_len == 4 && components[2] == "access" && (
+                (components[3] == "ticket" && method ==  hyper::Method::POST) ||
+                (components[3] == "domains" && method ==  hyper::Method::GET)
+            ) {
                 // explicitly allow those calls without auth
             } else {
                 let (ticket, token) = extract_auth_data(&parts.headers);
-                match check_auth(&method, &ticket, &token) {
-                    Ok(username) => {
-
-                        // fixme: check permissions
-
-                        rpcenv.set_user(Some(username));
-                    }
+                match check_auth(&method, &ticket, &token, &user_info) {
+                    Ok(username) => rpcenv.set_user(Some(username)),
                     Err(err) => {
                         // always delay unauthorized calls by 3 seconds (from start of request)
-                        let err = http_err!(UNAUTHORIZED, format!("permission check failed - {}", err));
-                        return Box::new(
-                            delayed_response((formatter.format_error)(err), delay_unauth_time)
-                        );
+                        let err = http_err!(UNAUTHORIZED, format!("authentication failed - {}", err));
+                        tokio::time::delay_until(Instant::from_std(delay_unauth_time)).await;
+                        return Ok((formatter.format_error)(err));
                     }
                 }
             }
 
             match api.find_method(&components[2..], method, &mut uri_param) {
-                MethodDefinition::None => {
+                None => {
                     let err = http_err!(NOT_FOUND, "Path not found.".to_string());
-                    return Box::new(future::ok((formatter.format_error)(err)));
+                    return Ok((formatter.format_error)(err));
                 }
-                MethodDefinition::Simple(api_method) => {
-                    if api_method.protected && env_type == RpcEnvironmentType::PUBLIC {
-                        return proxy_protected_request(api_method, parts, body);
+                Some(api_method) => {
+                    let user = rpcenv.get_user();
+                    if !check_api_permission(api_method.access.permission, user.as_deref(), &uri_param, &user_info) {
+                        let err = http_err!(FORBIDDEN, format!("permission check failed"));
+                        tokio::time::delay_until(Instant::from_std(access_forbidden_time)).await;
+                        return Ok((formatter.format_error)(err));
+                    }
+
+                    let result = if api_method.protected && env_type == RpcEnvironmentType::PUBLIC {
+                        proxy_protected_request(api_method, parts, body).await
                     } else {
-                        return handle_sync_api_request(rpcenv, api_method, formatter, parts, body, uri_param);
+                        handle_api_request(rpcenv, api_method, formatter, parts, body, uri_param).await
+                    };
+
+                    if let Err(err) = result {
+                        return Ok((formatter.format_error)(err));
                     }
-                }
-                MethodDefinition::Async(async_method) => {
-                    return handle_async_api_request(rpcenv, async_method, formatter, parts, body, uri_param);
+                    return result;
                 }
             }
+
         }
-    } else {
+     } else {
         // not Auth required for accessing files!
 
         if method != hyper::Method::GET {
-            return Box::new(future::err(http_err!(BAD_REQUEST, "Unsupported method".to_string())));
+            bail!("Unsupported HTTP method {}", method);
         }
 
         if comp_len == 0 {
             let (ticket, token) = extract_auth_data(&parts.headers);
             if ticket != None {
-                match check_auth(&method, &ticket, &token) {
+                match check_auth(&method, &ticket, &token, &user_info) {
                     Ok(username) => {
                         let new_token = assemble_csrf_prevention_token(csrf_secret(), &username);
-                        return Box::new(future::ok(get_index(Some(username), Some(new_token))));
+                        return Ok(get_index(Some(username), Some(new_token)));
                     }
                     _ => {
-                        return Box::new(delayed_response(get_index(None, None), delay_unauth_time));
+                        tokio::time::delay_until(Instant::from_std(delay_unauth_time)).await;
+                        return Ok(get_index(None, None));
                     }
                 }
             } else {
-                return Box::new(future::ok(get_index(None, None)));
+                return Ok(get_index(None, None));
             }
         } else {
             let filename = api.find_alias(&components);
-            return handle_static_file_download(filename);
+            return handle_static_file_download(filename).await;
         }
     }
 
-    Box::new(future::err(http_err!(NOT_FOUND, "Path not found.".to_string())))
+    Err(http_err!(NOT_FOUND, "Path not found.".to_string()))
 }