]> git.proxmox.com Git - pve-xtermjs.git/blobdiff - src/main.rs
termproxy: update clap crate to major version 4
[pve-xtermjs.git] / src / main.rs
index 605a92ea1e9f2ebe82971888a1717aeaa3c61b90..5eb14ee3585c3e1bf0d517f51dcddc869abb5785 100644 (file)
@@ -1,22 +1,21 @@
 use std::cmp::min;
 use std::collections::HashMap;
-use std::ffi::{OsStr, OsString};
-use std::io::{ErrorKind, Result, Write};
+use std::ffi::OsString;
+use std::io::{ErrorKind, Write};
 use std::os::unix::io::{AsRawFd, FromRawFd};
 use std::os::unix::process::CommandExt;
 use std::process::Command;
 use std::time::{Duration, Instant};
 
-use clap::{App, AppSettings, Arg};
-use curl::easy::Easy;
+use anyhow::{bail, format_err, Result};
+use clap::Arg;
 use mio::net::{TcpListener, TcpStream};
 use mio::unix::SourceFd;
 use mio::{Events, Interest, Poll, Token};
 
-use proxmox::sys::error::io_err_other;
-use proxmox::sys::linux::pty::{make_controlling_terminal, PTY};
-use proxmox::tools::byte_buffer::ByteBuffer;
-use proxmox::{io_bail, io_format_err};
+use proxmox_io::ByteBuffer;
+use proxmox_lang::error::io_err_other;
+use proxmox_sys::linux::pty::{make_controlling_terminal, PTY};
 
 const MSG_TYPE_DATA: u8 = 0;
 const MSG_TYPE_RESIZE: u8 = 1;
@@ -106,11 +105,11 @@ fn read_ticket_line(
             match buf.read_from(stream) {
                 Ok(n) => {
                     if n == 0 {
-                        io_bail!("connection closed before authentication");
+                        bail!("connection closed before authentication");
                     }
                 }
                 Err(err) if err.kind() == ErrorKind::WouldBlock => {}
-                Err(err) => return Err(err),
+                Err(err) => return Err(err.into()),
             }
 
             if buf[..].contains(&b'\n') {
@@ -118,13 +117,13 @@ fn read_ticket_line(
             }
 
             if buf.is_full() {
-                io_bail!("authentication data is incomplete: {:?}", &buf[..]);
+                bail!("authentication data is incomplete: {:?}", &buf[..]);
             }
         }
 
         elapsed = now.elapsed();
         if elapsed > timeout {
-            io_bail!("timed out");
+            bail!("timed out");
         }
     }
 
@@ -138,7 +137,7 @@ fn read_ticket_line(
             let (username, ticket) = line.split_at(pos);
             Ok((username.into(), ticket[1..].into()))
         }
-        None => io_bail!("authentication data is invalid"),
+        None => bail!("authentication data is invalid"),
     }
 }
 
@@ -150,41 +149,29 @@ fn authenticate(
     authport: u16,
     port: Option<u16>,
 ) -> Result<()> {
-    let mut curl = Easy::new();
-    curl.url(&format!(
-        "http://localhost:{}/api2/json/access/ticket",
-        authport
-    ))?;
-
-    let username = curl.url_encode(username);
-    let ticket = curl.url_encode(ticket);
-    let path = curl.url_encode(path.as_bytes());
-
-    let mut post_fields = Vec::with_capacity(5);
-    post_fields.push(format!("username={}", username));
-    post_fields.push(format!("password={}", ticket));
-    post_fields.push(format!("path={}", path));
-
+    let mut post_fields: Vec<(&str, &str)> = Vec::with_capacity(5);
+    post_fields.push(("username", std::str::from_utf8(username)?));
+    post_fields.push(("password", std::str::from_utf8(ticket)?));
+    post_fields.push(("path", path));
     if let Some(perm) = perm {
-        let perm = curl.url_encode(perm.as_bytes());
-        post_fields.push(format!("privs={}", perm));
+        post_fields.push(("privs", perm));
     }
-
+    let port_str;
     if let Some(port) = port {
-        post_fields.push(format!("port={}", port));
+        port_str = port.to_string();
+        post_fields.push(("port", &port_str));
     }
 
-    curl.post_fields_copy(post_fields.join("&").as_bytes())?;
-    curl.post(true)?;
-    curl.perform()?;
+    let url = format!("http://localhost:{}/api2/json/access/ticket", authport);
 
-    let response_code = curl.response_code()?;
-
-    if response_code != 200 {
-        io_bail!("invalid authentication, code {}", response_code);
+    match ureq::post(&url).send_form(&post_fields[..]) {
+        Ok(res) if res.status() == 200 => Ok(()),
+        Ok(res) | Err(ureq::Error::Status(_, res)) => {
+            let code = res.status();
+            bail!("invalid authentication - {} {}", code, res.status_text())
+        }
+        Err(err) => bail!("authentication request failed - {}", err),
     }
-
-    Ok(())
 }
 
 fn listen_and_accept(
@@ -220,12 +207,15 @@ fn listen_and_accept(
 
         elapsed = now.elapsed();
         if elapsed > timeout {
-            io_bail!("timed out");
+            bail!("timed out");
         }
     }
 }
 
-fn run_pty(cmd: &OsStr, params: clap::OsValues) -> Result<PTY> {
+fn run_pty<'a>(mut full_cmd: impl Iterator<Item = &'a OsString>) -> Result<PTY> {
+    let cmd_exe = full_cmd.next().unwrap();
+    let params = full_cmd; // rest
+
     let (mut pty, secondary_name) = PTY::new().map_err(io_err_other)?;
 
     let mut filtered_env: HashMap<OsString, OsString> = std::env::vars_os()
@@ -240,7 +230,7 @@ fn run_pty(cmd: &OsStr, params: clap::OsValues) -> Result<PTY> {
         .collect();
     filtered_env.insert("TERM".into(), "xterm-256color".into());
 
-    let mut command = Command::new(cmd);
+    let mut command = Command::new(cmd_exe);
 
     command.args(params).env_clear().envs(&filtered_env);
 
@@ -253,7 +243,7 @@ fn run_pty(cmd: &OsStr, params: clap::OsValues) -> Result<PTY> {
 
     command.spawn()?;
 
-    pty.set_size(80, 20).map_err(|x| x.as_errno().unwrap())?;
+    pty.set_size(80, 20)?;
     Ok(pty)
 }
 
@@ -261,64 +251,56 @@ const TCP: Token = Token(0);
 const PTY: Token = Token(1);
 
 fn do_main() -> Result<()> {
-    let matches = App::new("termproxy")
-        .setting(AppSettings::TrailingVarArg)
-        .arg(Arg::with_name("port").takes_value(true).required(true))
+    let matches = clap::builder::Command::new("termproxy")
+        .trailing_var_arg(true)
         .arg(
-            Arg::with_name("authport")
-                .takes_value(true)
-                .long("authport"),
+            Arg::new("port")
+                .num_args(1)
+                .required(true)
+                .value_parser(clap::value_parser!(u64)),
         )
-        .arg(Arg::with_name("use-port-as-fd").long("port-as-fd"))
+        .arg(Arg::new("authport").num_args(1).long("authport"))
+        .arg(Arg::new("use-port-as-fd").long("port-as-fd"))
+        .arg(Arg::new("path").num_args(1).long("path").required(true))
+        .arg(Arg::new("perm").num_args(1).long("perm"))
         .arg(
-            Arg::with_name("path")
-                .takes_value(true)
-                .long("path")
+            Arg::new("cmd")
+                .value_parser(clap::value_parser!(OsString))
+                .num_args(1..)
                 .required(true),
         )
-        .arg(Arg::with_name("perm").takes_value(true).long("perm"))
-        .arg(Arg::with_name("cmd").multiple(true).required(true))
         .get_matches();
 
-    let port: u64 = matches
-        .value_of("port")
-        .unwrap()
-        .parse()
-        .map_err(io_err_other)?;
-    let path = matches.value_of("path").unwrap();
-    let perm: Option<&str> = matches.value_of("perm");
-    let mut cmdparams = matches.values_of_os("cmd").unwrap();
-    let cmd = cmdparams.next().unwrap();
-    let authport: u16 = matches
-        .value_of("authport")
-        .unwrap_or("85")
-        .parse()
-        .map_err(io_err_other)?;
-    let mut pty_buf = ByteBuffer::new();
-    let mut tcp_buf = ByteBuffer::new();
-
-    let use_port_as_fd = matches.is_present("use-port-as-fd");
+    let port: u64 = *matches.get_one("port").unwrap();
+    let path = matches.get_one::<String>("path").unwrap();
+    let perm = matches.get_one::<String>("perm").map(|x| x.as_str());
+    let full_cmd: clap::parser::ValuesRef<OsString> = matches.get_many("cmd").unwrap();
+    let authport: u16 = *matches.get_one("authport").unwrap_or(&85);
+    let use_port_as_fd = matches.contains_id("use-port-as-fd");
 
     if use_port_as_fd && port > u16::MAX as u64 {
-        return Err(io_format_err!("port too big"));
+        return Err(format_err!("port too big"));
     } else if port > i32::MAX as u64 {
-        return Err(io_format_err!("Invalid FD number"));
+        return Err(format_err!("Invalid FD number"));
     }
 
     let (mut tcp_handle, port) =
         listen_and_accept("localhost", port, use_port_as_fd, Duration::new(10, 0))
-            .map_err(|err| io_format_err!("failed waiting for client: {}", err))?;
+            .map_err(|err| format_err!("failed waiting for client: {}", err))?;
+
+    let mut pty_buf = ByteBuffer::new();
+    let mut tcp_buf = ByteBuffer::new();
 
     let (username, ticket) = read_ticket_line(&mut tcp_handle, &mut pty_buf, Duration::new(10, 0))
-        .map_err(|err| io_format_err!("failed reading ticket: {}", err))?;
+        .map_err(|err| format_err!("failed reading ticket: {}", err))?;
     let port = if use_port_as_fd { Some(port) } else { None };
-    authenticate(&username, &ticket, path, perm, authport, port)?;
+    authenticate(&username, &ticket, &path, perm.as_deref(), authport, port)?;
     tcp_handle.write_all(b"OK").expect("error writing response");
 
     let mut poll = Poll::new()?;
     let mut events = Events::with_capacity(128);
 
-    let mut pty = run_pty(cmd, cmdparams)?;
+    let mut pty = run_pty(full_cmd)?;
 
     poll.registry().register(
         &mut tcp_handle,
@@ -381,7 +363,7 @@ fn do_main() -> Result<()> {
                 }
                 Err(err) => {
                     if !finished {
-                        return Err(io_format_err!("error reading from tcp: {}", err));
+                        return Err(format_err!("error reading from tcp: {}", err));
                     }
                     break;
                 }
@@ -401,7 +383,7 @@ fn do_main() -> Result<()> {
                 }
                 Err(err) => {
                     if !finished {
-                        return Err(io_format_err!("error reading from pty: {}", err));
+                        return Err(format_err!("error reading from pty: {}", err));
                     }
                     break;
                 }
@@ -421,7 +403,7 @@ fn do_main() -> Result<()> {
                 }
                 Err(err) => {
                     if !finished {
-                        return Err(io_format_err!("error writing to tcp : {}", err));
+                        return Err(format_err!("error writing to tcp : {}", err));
                     }
                     break;
                 }
@@ -445,7 +427,7 @@ fn do_main() -> Result<()> {
                 }
                 Err(err) => {
                     if !finished {
-                        return Err(io_format_err!("error writing to pty : {}", err));
+                        return Err(format_err!("error writing to pty : {}", err));
                     }
                     break;
                 }