]> git.proxmox.com Git - proxmox-websocket-tunnel.git/blobdiff - src/main.rs
cleanup
[proxmox-websocket-tunnel.git] / src / main.rs
index 582214c7f9436caa7717f9dd1498cd582be289b6..cefe565fd1b1290825539bb88ec1c4f6a8a15f77 100644 (file)
@@ -38,26 +38,32 @@ type CmdData = Map<String, Value>;
 #[derive(Serialize, Deserialize, Debug)]
 #[serde(rename_all = "kebab-case")]
 struct ConnectCmdData {
-    // target URL for WS connection
+    /// target URL for WS connection
     url: String,
-    // fingerprint of TLS certificate
+
+    /// fingerprint of TLS certificate
     fingerprint: Option<String>,
-    // addition headers such as authorization
+
+    /// addition headers such as authorization
     headers: Option<Vec<(String, String)>>,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
 #[serde(rename_all = "kebab-case")]
 struct ForwardCmdData {
-    // target URL for WS connection
+    /// target URL for WS connection
     url: String,
-    // addition headers such as authorization
+
+    /// addition headers such as authorization
     headers: Option<Vec<(String, String)>>,
-    // fingerprint of TLS certificate
+
+    /// fingerprint of TLS certificate
     fingerprint: Option<String>,
-    // local UNIX socket path for forwarding
+
+    /// local UNIX socket path for forwarding
     unix: String,
-    // request ticket using these parameters
+
+    /// request ticket using these parameters
     ticket: Option<Map<String, Value>>,
 }
 
@@ -134,9 +140,52 @@ impl CtrlTunnel {
         }
 
         let mut ssl_connector_builder = SslConnector::builder(SslMethod::tls())?;
-        if fingerprint.is_some() {
-            // FIXME actually verify fingerprint via callback!
-            ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
+        if let Some(expected) = fingerprint {
+            ssl_connector_builder.set_verify_callback(
+                openssl::ssl::SslVerifyMode::PEER,
+                move |_valid, ctx| {
+                    let cert = match ctx.current_cert() {
+                        Some(cert) => cert,
+                        None => {
+                            // should not happen
+                            eprintln!("SSL context lacks current certificate.");
+                            return false;
+                        }
+                    };
+
+                    // skip CA certificates, we only care about the peer cert
+                    let depth = ctx.error_depth();
+                    if depth != 0 {
+                        return true;
+                    }
+
+                    let fp = match cert.digest(openssl::hash::MessageDigest::sha256()) {
+                        Ok(fp) => fp,
+                        Err(err) => {
+                            // should not happen
+                            eprintln!("failed to calculate certificate FP - {}", err);
+                            return false;
+                        }
+                    };
+                    let fp_string = hex::encode(&fp);
+                    let fp_string = fp_string
+                        .as_bytes()
+                        .chunks(2)
+                        .map(|v| std::str::from_utf8(v).unwrap())
+                        .collect::<Vec<&str>>()
+                        .join(":");
+
+                    let expected = expected.to_lowercase();
+                    if expected == fp_string {
+                        true
+                    } else {
+                        eprintln!("certificate fingerprint does not match expected fingerprint!");
+                        eprintln!("expected:    {}", expected);
+                        eprintln!("encountered: {}", fp_string);
+                        false
+                    }
+                },
+            );
         } else {
             ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::PEER);
         }
@@ -387,10 +436,6 @@ impl CtrlTunnel {
 
 #[tokio::main]
 async fn main() -> Result<(), Error> {
-    do_main().await
-}
-
-async fn do_main() -> Result<(), Error> {
     let tunnel = CtrlTunnel { sender: None };
     tunnel.read_cmd_loop().await
 }