]> git.proxmox.com Git - proxmox-backup.git/blobdiff - src/client/http_client.rs
http_client: set connect timeout to 10 seconds
[proxmox-backup.git] / src / client / http_client.rs
index 7472330ce679b05e427bcc1521de35161345271c..e18b8fab316fef133a87b97aceb2435637af226d 100644 (file)
@@ -1,9 +1,9 @@
 use std::io::Write;
 use std::task::{Context, Poll};
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, RwLock};
+use std::time::Duration;
 
-use chrono::Utc;
-use failure::*;
+use anyhow::{bail, format_err, Error};
 use futures::*;
 use http::Uri;
 use http::header::HeaderValue;
@@ -15,17 +15,22 @@ use serde_json::{json, Value};
 use percent_encoding::percent_encode;
 use xdg::BaseDirectories;
 
-use proxmox::tools::{
-    fs::{file_get_json, replace_file, CreateOptions},
+use proxmox::{
+    api::error::HttpError,
+    sys::linux::tty,
+    tools::{
+        fs::{file_get_json, replace_file, CreateOptions},
+    }
 };
 
 use super::pipe_to_stream::PipeToSendStream;
+use crate::api2::types::Userid;
 use crate::tools::async_io::EitherStream;
-use crate::tools::{self, tty, BroadcastFuture, DEFAULT_ENCODE_SET};
+use crate::tools::{self, BroadcastFuture, DEFAULT_ENCODE_SET};
 
 #[derive(Clone)]
 pub struct AuthInfo {
-    pub username: String,
+    pub userid: Userid,
     pub ticket: String,
     pub token: String,
 }
@@ -33,7 +38,6 @@ pub struct AuthInfo {
 pub struct HttpClientOptions {
     prefix: Option<String>,
     password: Option<String>,
-    password_env: Option<String>,
     fingerprint: Option<String>,
     interactive: bool,
     ticket_cache: bool,
@@ -47,7 +51,6 @@ impl HttpClientOptions {
         Self {
             prefix: None,
             password: None,
-            password_env: None,
             fingerprint: None,
             interactive: false,
             ticket_cache: false,
@@ -66,11 +69,6 @@ impl HttpClientOptions {
         self
     }
 
-    pub fn password_env(mut self, password_env: Option<String>) -> Self {
-        self.password_env = password_env;
-        self
-    }
-
     pub fn fingerprint(mut self, fingerprint: Option<String>) -> Self {
         self.fingerprint = fingerprint;
         self
@@ -101,13 +99,16 @@ impl HttpClientOptions {
 pub struct HttpClient {
     client: Client<HttpsConnector>,
     server: String,
+    port: u16,
     fingerprint: Arc<Mutex<Option<String>>>,
-    auth: BroadcastFuture<AuthInfo>,
+    first_auth: BroadcastFuture<()>,
+    auth: Arc<RwLock<AuthInfo>>,
+    ticket_abort: futures::future::AbortHandle,
     _options: HttpClientOptions,
 }
 
 /// Delete stored ticket data (logout)
-pub fn delete_ticket_info(prefix: &str, server: &str, username: &str) -> Result<(), Error> {
+pub fn delete_ticket_info(prefix: &str, server: &str, username: &Userid) -> Result<(), Error> {
 
     let base = BaseDirectories::with_prefix(prefix)?;
 
@@ -119,7 +120,7 @@ pub fn delete_ticket_info(prefix: &str, server: &str, username: &str) -> Result<
     let mut data = file_get_json(&path, Some(json!({})))?;
 
     if let Some(map) = data[server].as_object_mut() {
-        map.remove(username);
+        map.remove(username.as_str());
     }
 
     replace_file(path, data.to_string().as_bytes(), CreateOptions::new().perm(mode))?;
@@ -180,10 +181,8 @@ fn load_fingerprint(prefix: &str, server: &str) -> Option<String> {
 
     for line in raw.split('\n') {
         let items: Vec<String> = line.split_whitespace().map(String::from).collect();
-        if items.len() == 2 {
-            if &items[0] == server {
-                return Some(items[1].clone());
-            }
+        if items.len() == 2 && &items[0] == server {
+            return Some(items[1].clone());
         }
     }
 
@@ -201,7 +200,7 @@ fn store_ticket_info(prefix: &str, server: &str, username: &str, ticket: &str, t
 
     let mut data = file_get_json(&path, Some(json!({})))?;
 
-    let now = Utc::now().timestamp();
+    let now = proxmox::tools::time::epoch_i64();
 
     data[server][username] = json!({ "timestamp": now, "ticket": ticket, "token": token});
 
@@ -226,15 +225,15 @@ fn store_ticket_info(prefix: &str, server: &str, username: &str, ticket: &str, t
     Ok(())
 }
 
-fn load_ticket_info(prefix: &str, server: &str, username: &str) -> Option<(String, String)> {
+fn load_ticket_info(prefix: &str, server: &str, userid: &Userid) -> Option<(String, String)> {
     let base = BaseDirectories::with_prefix(prefix).ok()?;
 
     // usually /run/user/<uid>/...
     let path = base.place_runtime_file("tickets").ok()?;
     let data = file_get_json(&path, None).ok()?;
-    let now = Utc::now().timestamp();
+    let now = proxmox::tools::time::epoch_i64();
     let ticket_lifetime = tools::ticket::TICKET_LIFETIME - 60;
-    let uinfo = data[server][username].as_object()?;
+    let uinfo = data[server][userid.as_str()].as_object()?;
     let timestamp = uinfo["timestamp"].as_i64()?;
     let age = now - timestamp;
 
@@ -248,13 +247,21 @@ fn load_ticket_info(prefix: &str, server: &str, username: &str) -> Option<(Strin
 }
 
 impl HttpClient {
-
-    pub fn new(server: &str, username: &str, mut options: HttpClientOptions) -> Result<Self, Error> {
+    pub fn new(
+        server: &str,
+        port: u16,
+        userid: &Userid,
+        mut options: HttpClientOptions,
+    ) -> Result<Self, Error> {
 
         let verified_fingerprint = Arc::new(Mutex::new(None));
 
         let mut fingerprint = options.fingerprint.take();
-        if options.fingerprint_cache && fingerprint.is_none() && options.prefix.is_some() {
+
+        if fingerprint.is_some() {
+            // do not store fingerprints passed via options in cache
+            options.fingerprint_cache = false;
+        } else if options.fingerprint_cache && options.prefix.is_some() {
             fingerprint = load_fingerprint(options.prefix.as_ref().unwrap(), server);
         }
 
@@ -287,9 +294,9 @@ impl HttpClient {
 
         let mut httpc = hyper::client::HttpConnector::new();
         httpc.set_nodelay(true); // important for h2 download performance!
-        httpc.set_recv_buffer_size(Some(1024*1024)); //important for h2 download performance!
         httpc.enforce_http(false); // we want https...
 
+        httpc.set_connect_timeout(Some(std::time::Duration::new(10, 0)));
         let https = HttpsConnector::with_connector(httpc, ssl_connector_builder.build());
 
         let client = Client::builder()
@@ -305,48 +312,90 @@ impl HttpClient {
         } else {
             let mut ticket_info = None;
             if use_ticket_cache {
-                ticket_info = load_ticket_info(options.prefix.as_ref().unwrap(), server, username);
+                ticket_info = load_ticket_info(options.prefix.as_ref().unwrap(), server, userid);
             }
             if let Some((ticket, _token)) = ticket_info {
                 ticket
             } else {
-                Self::get_password(&username, options.interactive, options.password_env.clone())?
+                Self::get_password(userid, options.interactive)?
+            }
+        };
+
+        let auth = Arc::new(RwLock::new(AuthInfo {
+            userid: userid.clone(),
+            ticket: password.clone(),
+            token: "".to_string(),
+        }));
+
+        let server2 = server.to_string();
+        let client2 = client.clone();
+        let auth2 = auth.clone();
+        let prefix2 = options.prefix.clone();
+
+        let renewal_future = async move {
+            loop {
+                tokio::time::delay_for(Duration::new(60*15,  0)).await; // 15 minutes
+                let (userid, ticket) = {
+                    let authinfo = auth2.read().unwrap().clone();
+                    (authinfo.userid, authinfo.ticket)
+                };
+                match Self::credentials(client2.clone(), server2.clone(), port, userid, ticket).await {
+                    Ok(auth) => {
+                        if use_ticket_cache & &prefix2.is_some() {
+                            let _ = store_ticket_info(prefix2.as_ref().unwrap(), &server2, &auth.userid.to_string(), &auth.ticket, &auth.token);
+                        }
+                        *auth2.write().unwrap() = auth;
+                    },
+                    Err(err) => {
+                        eprintln!("re-authentication failed: {}", err);
+                        return;
+                    }
+                }
             }
         };
 
+        let (renewal_future, ticket_abort) = futures::future::abortable(renewal_future);
+
         let login_future = Self::credentials(
             client.clone(),
             server.to_owned(),
-            username.to_owned(),
-            password,
+            port,
+            userid.to_owned(),
+            password.to_owned(),
         ).map_ok({
             let server = server.to_string();
             let prefix = options.prefix.clone();
+            let authinfo = auth.clone();
 
             move |auth| {
                 if use_ticket_cache & &prefix.is_some() {
-                    let _ = store_ticket_info(prefix.as_ref().unwrap(), &server, &auth.username, &auth.ticket, &auth.token);
+                    let _ = store_ticket_info(prefix.as_ref().unwrap(), &server, &auth.userid.to_string(), &auth.ticket, &auth.token);
                 }
-
-                auth
+                *authinfo.write().unwrap() = auth;
+                tokio::spawn(renewal_future);
             }
         });
 
         Ok(Self {
             client,
             server: String::from(server),
+            port,
             fingerprint: verified_fingerprint,
-            auth: BroadcastFuture::new(Box::new(login_future)),
+            auth,
+            ticket_abort,
+            first_auth: BroadcastFuture::new(Box::new(login_future)),
             _options: options,
         })
     }
 
     /// Login
     ///
-    /// Login is done on demand, so this is onyl required if you need
+    /// Login is done on demand, so this is only required if you need
     /// access to authentication data in 'AuthInfo'.
     pub async fn login(&self) -> Result<AuthInfo, Error> {
-        self.auth.listen().await
+        self.first_auth.listen().await?;
+        let authinfo = self.auth.read().unwrap();
+        Ok(authinfo.clone())
     }
 
     /// Returns the optional fingerprint passed to the new() constructor.
@@ -354,18 +403,7 @@ impl HttpClient {
         (*self.fingerprint.lock().unwrap()).clone()
     }
 
-    fn get_password(username: &str, interactive: bool, password_env: Option<String>) -> Result<String, Error> {
-        if let Some(password_env) = password_env {
-            use std::env::VarError::*;
-            match std::env::var(&password_env) {
-                Ok(p) => return Ok(p),
-                Err(NotUnicode(_)) => bail!(format!("{} contains bad characters", password_env)),
-                Err(NotPresent) => {
-                    // Try another method
-                }
-            }
-        }
-
+    fn get_password(username: &Userid, interactive: bool) -> Result<String, Error> {
         // If we're on a TTY, query the user for a password
         if interactive && tty::stdin_isatty() {
             let msg = format!("Password for \"{}\": ", username);
@@ -400,7 +438,7 @@ impl HttpClient {
             .collect::<Vec<&str>>().join(":");
 
         if let Some(expected_fingerprint) = expected_fingerprint {
-            if expected_fingerprint == fp_string {
+            if expected_fingerprint.to_lowercase() == fp_string {
                 return (true, Some(fp_string));
             } else {
                 return (false, None);
@@ -411,21 +449,22 @@ impl HttpClient {
         if interactive && tty::stdin_isatty() {
             println!("fingerprint: {}", fp_string);
             loop {
-                print!("Want to trust? (y/n): ");
+                print!("Are you sure you want to continue connecting? (y/n): ");
                 let _ = std::io::stdout().flush();
-                let mut buf = [0u8; 1];
-                use std::io::Read;
-                match std::io::stdin().read_exact(&mut buf) {
-                    Ok(()) => {
-                        if buf[0] == b'y' || buf[0] == b'Y' {
+                use std::io::{BufRead, BufReader};
+                let mut line = String::new();
+                match BufReader::new(std::io::stdin()).read_line(&mut line) {
+                    Ok(_) => {
+                        let trimmed = line.trim();
+                        if trimmed == "y" || trimmed == "Y" {
                             return (true, Some(fp_string));
-                        } else if buf[0] == b'n' || buf[0] == b'N' {
+                        } else if trimmed == "n" || trimmed == "N" {
                             return (false, None);
+                        } else {
+                            continue;
                         }
                     }
-                    Err(_) => {
-                        return (false, None);
-                    }
+                    Err(_) => return (false, None),
                 }
             }
         }
@@ -450,7 +489,7 @@ impl HttpClient {
         path: &str,
         data: Option<Value>,
     ) -> Result<Value, Error> {
-        let req = Self::request_builder(&self.server, "GET", path, data).unwrap();
+        let req = Self::request_builder(&self.server, self.port, "GET", path, data)?;
         self.request(req).await
     }
 
@@ -459,7 +498,7 @@ impl HttpClient {
         path: &str,
         data: Option<Value>,
     ) -> Result<Value, Error> {
-        let req = Self::request_builder(&self.server, "DELETE", path, data).unwrap();
+        let req = Self::request_builder(&self.server, self.port, "DELETE", path, data)?;
         self.request(req).await
     }
 
@@ -468,7 +507,7 @@ impl HttpClient {
         path: &str,
         data: Option<Value>,
     ) -> Result<Value, Error> {
-        let req = Self::request_builder(&self.server, "POST", path, data).unwrap();
+        let req = Self::request_builder(&self.server, self.port, "POST", path, data)?;
         self.request(req).await
     }
 
@@ -476,8 +515,8 @@ impl HttpClient {
         &mut self,
         path: &str,
         output: &mut (dyn Write + Send),
-    ) ->  Result<(), Error> {
-        let mut req = Self::request_builder(&self.server, "GET", path, None).unwrap();
+    ) -> Result<(), Error> {
+        let mut req = Self::request_builder(&self.server, self.port, "GET", path, None)?;
 
         let client = self.client.clone();
 
@@ -513,7 +552,7 @@ impl HttpClient {
     ) -> Result<Value, Error> {
 
         let path = path.trim_matches('/');
-        let mut url = format!("https://{}:8007/{}", &self.server, path);
+        let mut url = format!("https://{}:{}/{}", &self.server, self.port, path);
 
         if let Some(data) = data {
             let query = tools::json_object_to_query(data).unwrap();
@@ -588,14 +627,15 @@ impl HttpClient {
     async fn credentials(
         client: Client<HttpsConnector>,
         server: String,
-        username: String,
+        port: u16,
+        username: Userid,
         password: String,
     ) -> Result<AuthInfo, Error> {
         let data = json!({ "username": username, "password": password });
-        let req = Self::request_builder(&server, "POST", "/api2/json/access/ticket", Some(data)).unwrap();
+        let req = Self::request_builder(&server, port, "POST", "/api2/json/access/ticket", Some(data))?;
         let cred = Self::api_request(client, req).await?;
         let auth = AuthInfo {
-            username: cred["data"]["username"].as_str().unwrap().to_owned(),
+            userid: cred["data"]["username"].as_str().unwrap().parse()?,
             ticket: cred["data"]["ticket"].as_str().unwrap().to_owned(),
             token: cred["data"]["CSRFPreventionToken"].as_str().unwrap().to_owned(),
         };
@@ -616,7 +656,7 @@ impl HttpClient {
                 Ok(value)
             }
         } else {
-            bail!("HTTP Error {}: {}", status, text);
+            Err(Error::from(HttpError::new(status, text)))
         }
     }
 
@@ -636,9 +676,13 @@ impl HttpClient {
         &self.server
     }
 
-    pub fn request_builder(server: &str, method: &str, path: &str, data: Option<Value>) -> Result<Request<Body>, Error> {
+    pub fn port(&self) -> u16 {
+        self.port
+    }
+
+    pub fn request_builder(server: &str, port: u16, method: &str, path: &str, data: Option<Value>) -> Result<Request<Body>, Error> {
         let path = path.trim_matches('/');
-        let url: Uri = format!("https://{}:8007/{}", server, path).parse()?;
+        let url: Uri = format!("https://{}:{}/{}", server, port, path).parse()?;
 
         if let Some(data) = data {
             if method == "POST" {
@@ -651,7 +695,7 @@ impl HttpClient {
                 return Ok(request);
             } else {
                 let query = tools::json_object_to_query(data)?;
-                let url: Uri = format!("https://{}:8007/{}?{}", server, path, query).parse()?;
+                let url: Uri = format!("https://{}:{}/{}?{}", server, port, path, query).parse()?;
                 let request = Request::builder()
                     .method(method)
                     .uri(url)
@@ -673,6 +717,12 @@ impl HttpClient {
     }
 }
 
+impl Drop for HttpClient {
+    fn drop(&mut self) {
+        self.ticket_abort.abort();
+    }
+}
+
 
 #[derive(Clone)]
 pub struct H2Client {
@@ -717,7 +767,7 @@ impl H2Client {
         path: &str,
         param: Option<Value>,
         mut output: W,
-    ) -> Result<W, Error> {
+    ) -> Result<(), Error> {
         let request = Self::request_builder("localhost", "GET", path, param, None).unwrap();
 
         let response_future = self.send_request(request, None).await?;
@@ -737,7 +787,7 @@ impl H2Client {
             output.write_all(&chunk)?;
         }
 
-        Ok(output)
+        Ok(())
     }
 
     pub async fn upload(
@@ -829,7 +879,7 @@ impl H2Client {
                 bail!("got result without data property");
             }
         } else {
-            bail!("HTTP Error {}: {}", status, text);
+            Err(Error::from(HttpError::new(status, text)))
         }
     }