]> git.proxmox.com Git - proxmox-websocket-tunnel.git/blame - src/main.rs
add fingerprint validation
[proxmox-websocket-tunnel.git] / src / main.rs
CommitLineData
c18e63b9
FG
1use anyhow::{bail, format_err, Error};
2
3use std::collections::VecDeque;
4use std::sync::Arc;
5
6use futures::future::FutureExt;
7use futures::select;
8
9use hyper::client::{Client, HttpConnector};
10use hyper::header::{SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE};
11use hyper::upgrade::Upgraded;
12use hyper::{Body, Request, StatusCode};
13
14use openssl::ssl::{SslConnector, SslMethod};
15use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
16
17use serde::{Deserialize, Serialize};
18use serde_json::{Map, Value};
19use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
20use tokio::net::{UnixListener, UnixStream};
21use tokio::sync::{mpsc, oneshot};
22use tokio_stream::wrappers::LinesStream;
23use tokio_stream::StreamExt;
24
25use proxmox_http::client::HttpsConnector;
26use proxmox_http::websocket::{OpCode, WebSocket, WebSocketReader, WebSocketWriter};
27
28#[derive(Serialize, Deserialize, Debug)]
29#[serde(rename_all = "kebab-case")]
30enum CmdType {
31 Connect,
32 Forward,
33 NonControl,
34}
35
36type CmdData = Map<String, Value>;
37
38#[derive(Serialize, Deserialize, Debug)]
39#[serde(rename_all = "kebab-case")]
40struct ConnectCmdData {
41 // target URL for WS connection
42 url: String,
43 // fingerprint of TLS certificate
44 fingerprint: Option<String>,
45 // addition headers such as authorization
46 headers: Option<Vec<(String, String)>>,
47}
48
49#[derive(Serialize, Deserialize, Debug, Clone)]
50#[serde(rename_all = "kebab-case")]
51struct ForwardCmdData {
52 // target URL for WS connection
53 url: String,
54 // addition headers such as authorization
55 headers: Option<Vec<(String, String)>>,
56 // fingerprint of TLS certificate
57 fingerprint: Option<String>,
58 // local UNIX socket path for forwarding
59 unix: String,
60 // request ticket using these parameters
61 ticket: Option<Map<String, Value>>,
62}
63
64struct CtrlTunnel {
65 sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
66}
67
68impl CtrlTunnel {
69 async fn read_cmd_loop(mut self) -> Result<(), Error> {
70 let mut stdin_stream = LinesStream::new(BufReader::new(tokio::io::stdin()).lines());
71 while let Some(res) = stdin_stream.next().await {
72 match res {
73 Ok(line) => {
74 let (cmd_type, data) = Self::parse_cmd(&line)?;
75 match cmd_type {
76 CmdType::Connect => self.handle_connect_cmd(data).await,
77 CmdType::Forward => {
78 let res = self.handle_forward_cmd(data).await;
79 match &res {
80 Ok(()) => println!("{}", serde_json::json!({"success": true})),
81 Err(msg) => println!(
82 "{}",
83 serde_json::json!({"success": false, "msg": msg.to_string()})
84 ),
85 };
86 res
87 }
88 CmdType::NonControl => self
89 .handle_tunnel_cmd(data)
90 .await
91 .map(|res| println!("{}", res)),
92 }
93 }
94 Err(err) => bail!("Failed to read from STDIN - {}", err),
95 }?;
96 }
97
98 Ok(())
99 }
100
101 fn parse_cmd(line: &str) -> Result<(CmdType, CmdData), Error> {
102 let mut json: Map<String, Value> = serde_json::from_str(line)?;
103 match json.remove("control") {
104 Some(Value::Bool(true)) => {
105 match json.remove("cmd").map(serde_json::from_value::<CmdType>) {
106 None => bail!("input has 'control' flag, but no control 'cmd' set.."),
107 Some(Err(e)) => bail!("failed to parse control cmd - {}", e),
108 Some(Ok(cmd_type)) => Ok((cmd_type, json)),
109 }
110 }
111 _ => Ok((CmdType::NonControl, json)),
112 }
113 }
114
115 async fn websocket_connect(
116 url: String,
117 extra_headers: Vec<(String, String)>,
118 fingerprint: Option<String>,
119 ) -> Result<Upgraded, Error> {
120 let ws_key = proxmox_sys::linux::random_data(16)?;
121 let ws_key = base64::encode(&ws_key);
122 let mut req = Request::builder()
123 .uri(url)
124 .header(UPGRADE, "websocket")
125 .header(SEC_WEBSOCKET_VERSION, "13")
126 .header(SEC_WEBSOCKET_KEY, ws_key)
127 .body(Body::empty())?;
128
129 let headers = req.headers_mut();
130 for (name, value) in extra_headers {
131 let name = hyper::header::HeaderName::from_bytes(name.as_bytes())?;
132 let value = hyper::header::HeaderValue::from_str(&value)?;
133 headers.insert(name, value);
134 }
135
136 let mut ssl_connector_builder = SslConnector::builder(SslMethod::tls())?;
838b8aaf
FG
137 if let Some(expected) = fingerprint {
138 ssl_connector_builder.set_verify_callback(
139 openssl::ssl::SslVerifyMode::PEER,
140 move |_valid, ctx| {
141 let cert = match ctx.current_cert() {
142 Some(cert) => cert,
143 None => {
144 // should not happen
145 eprintln!("SSL context lacks current certificate.");
146 return false;
147 }
148 };
149
150 // skip CA certificates, we only care about the peer cert
151 let depth = ctx.error_depth();
152 if depth != 0 {
153 return true;
154 }
155
156 let fp = match cert.digest(openssl::hash::MessageDigest::sha256()) {
157 Ok(fp) => fp,
158 Err(err) => {
159 // should not happen
160 eprintln!("failed to calculate certificate FP - {}", err);
161 return false;
162 }
163 };
164 let fp_string = hex::encode(&fp);
165 let fp_string = fp_string
166 .as_bytes()
167 .chunks(2)
168 .map(|v| std::str::from_utf8(v).unwrap())
169 .collect::<Vec<&str>>()
170 .join(":");
171
172 let expected = expected.to_lowercase();
173 if expected == fp_string {
174 true
175 } else {
176 eprintln!("certificate fingerprint does not match expected fingerprint!");
177 eprintln!("expected: {}", expected);
178 eprintln!("encountered: {}", fp_string);
179 false
180 }
181 },
182 );
c18e63b9
FG
183 } else {
184 ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::PEER);
185 }
186
187 let mut httpc = HttpConnector::new();
188 httpc.enforce_http(false); // we want https...
189 httpc.set_connect_timeout(Some(std::time::Duration::new(10, 0)));
190 let https = HttpsConnector::with_connector(httpc, ssl_connector_builder.build(), 120);
191
192 let client = Client::builder().build::<_, Body>(https);
193 let res = client.request(req).await?;
194 if res.status() != StatusCode::SWITCHING_PROTOCOLS {
195 bail!("server didn't upgrade: {}", res.status());
196 }
197
198 hyper::upgrade::on(res)
199 .await
200 .map_err(|err| format_err!("failed to upgrade - {}", err))
201 }
202
203 async fn handle_connect_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
204 let mut data: ConnectCmdData = data
205 .remove("data")
206 .ok_or_else(|| format_err!("'connect' command missing 'data'"))
207 .map(serde_json::from_value)??;
208
209 if self.sender.is_some() {
210 bail!("already connected!");
211 }
212
213 let upgraded = Self::websocket_connect(
214 data.url.clone(),
215 data.headers.take().unwrap_or_else(Vec::new),
216 data.fingerprint.take(),
217 )
218 .await?;
219
220 let (tx, rx) = mpsc::unbounded_channel();
221 self.sender = Some(tx);
222 tokio::spawn(async move {
223 if let Err(err) = Self::handle_ctrl_tunnel(upgraded, rx).await {
224 eprintln!("Tunnel to {} failed - {}", data.url, err);
225 }
226 });
227
228 Ok(())
229 }
230
231 async fn handle_forward_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
232 let data: ForwardCmdData = data
233 .remove("data")
234 .ok_or_else(|| format_err!("'forward' command missing 'data'"))
235 .map(serde_json::from_value)??;
236
237 if self.sender.is_none() && data.ticket.is_some() {
238 bail!("dynamically requesting ticket requires cmd tunnel connection!");
239 }
240
241 let unix_listener = UnixListener::bind(data.unix.clone())?;
242 let data = Arc::new(data);
243 let cmd_sender = self.sender.clone();
244
245 tokio::spawn(async move {
246 let mut tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new();
247 let data2 = data.clone();
248
249 loop {
250 let data3 = data2.clone();
251
252 match unix_listener.accept().await {
253 Ok((unix_stream, _)) => {
254 eprintln!("accepted new connection on '{}'", data3.unix);
255 let cmd_sender2 = cmd_sender.clone();
256
257 let task = tokio::spawn(async move {
258 if let Err(err) = Self::handle_forward_tunnel(
259 cmd_sender2.clone(),
260 data3.clone(),
261 unix_stream,
262 )
263 .await
264 {
265 eprintln!("Tunnel for {} failed - {}", data3.unix, err);
266 }
267 });
268 tasks.push(task);
269 }
270 Err(err) => eprintln!(
271 "Failed to accept unix connection on {} - {}",
272 data3.unix, err
273 ),
274 };
275 }
276 });
277
278 Ok(())
279 }
280
281 async fn handle_tunnel_cmd(&mut self, data: CmdData) -> Result<String, Error> {
282 match &mut self.sender {
283 None => bail!("not connected!"),
284 Some(sender) => {
285 let data: Value = data.into();
286 let (tx, rx) = oneshot::channel::<String>();
287 if let Some(cmd) = data.get("cmd") {
288 eprintln!("-> sending command {} to remote", cmd);
289 } else {
290 eprintln!("-> sending data line to remote");
291 }
292 sender.send((data, tx))?;
293 let res = rx.await?;
294 eprintln!("<- got reply");
295 Ok(res)
296 }
297 }
298 }
299
300 async fn handle_ctrl_tunnel(
301 websocket: Upgraded,
302 mut cmd_receiver: mpsc::UnboundedReceiver<(Value, oneshot::Sender<String>)>,
303 ) -> Result<(), Error> {
304 let (tunnel_reader, tunnel_writer) = tokio::io::split(websocket);
305 let (ws_close_tx, mut ws_close_rx) = mpsc::unbounded_channel();
306 let ws_reader = WebSocketReader::new(tunnel_reader, ws_close_tx);
307 let mut ws_writer = WebSocketWriter::new(Some([0, 0, 0, 0]), tunnel_writer);
308
309 let mut framed_reader =
310 tokio_util::codec::FramedRead::new(ws_reader, tokio_util::codec::LinesCodec::new());
311
312 let mut resp_tx_queue: VecDeque<oneshot::Sender<String>> = VecDeque::new();
313 let mut shutting_down = false;
314
315 loop {
316 let mut close_future = ws_close_rx.recv().boxed().fuse();
317 let mut frame_future = framed_reader.next().boxed().fuse();
318 let mut cmd_future = cmd_receiver.recv().boxed().fuse();
319
320 select! {
321 res = close_future => {
322 let res = res.ok_or_else(|| format_err!("WS control channel closed"))?;
323 eprintln!("WS: received control message: '{:?}'", res);
324 shutting_down = true;
325 },
326 res = frame_future => {
327 match res {
328 None if shutting_down => {
329 eprintln!("WS closed");
330 break;
331 },
332 None => bail!("WS closed unexpectedly"),
333 Some(Ok(res)) => {
334 resp_tx_queue
335 .pop_front()
336 .ok_or_else(|| format_err!("no response handler"))?
337 .send(res)
338 .map_err(|msg| format_err!("failed to send tunnel response '{}' back to requester - receiver already closed?", msg))?;
339 },
340 Some(Err(err)) => {
341 bail!("reading from control tunnel failed - WS receive failed: {}", err);
342 },
343 }
344 },
345 res = cmd_future => {
346 if shutting_down { continue };
347 match res {
348 None => {
349 eprintln!("CMD channel closed, shutting down");
350 ws_writer.send_control_frame(Some([1,2,3,4]), OpCode::Close, &[]).await?;
351 shutting_down = true;
352 },
353 Some((msg, resp_tx)) => {
354 resp_tx_queue.push_back(resp_tx);
355
356 let line = format!("{}\n", msg);
357 ws_writer.write_all(line.as_bytes()).await?;
358 ws_writer.flush().await?;
359 },
360 }
361 },
362 };
363 }
364
365 Ok(())
366 }
367
368 async fn handle_forward_tunnel(
369 cmd_sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
370 data: Arc<ForwardCmdData>,
371 unix: UnixStream,
372 ) -> Result<(), Error> {
373 let data = match (&cmd_sender, &data.ticket) {
374 (Some(cmd_sender), Some(_)) => Self::get_ticket(cmd_sender, data.clone()).await,
375 _ => Ok(data.clone()),
376 }?;
377
378 let upgraded = Self::websocket_connect(
379 data.url.clone(),
380 data.headers.clone().unwrap_or_else(Vec::new),
381 data.fingerprint.clone(),
382 )
383 .await?;
384
385 let ws = WebSocket {
386 mask: Some([0, 0, 0, 0]),
387 };
388 eprintln!("established new WS for forwarding '{}'", data.unix);
389 ws.serve_connection(upgraded, unix).await?;
390
391 eprintln!("done handling forwarded connection from '{}'", data.unix);
392
393 Ok(())
394 }
395
396 async fn get_ticket(
397 cmd_sender: &mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>,
398 cmd_data: Arc<ForwardCmdData>,
399 ) -> Result<Arc<ForwardCmdData>, Error> {
400 eprintln!("requesting WS ticket via tunnel");
401 let ticket_cmd = match cmd_data.ticket.clone() {
402 Some(mut ticket_cmd) => {
403 ticket_cmd.insert("cmd".to_string(), serde_json::json!("ticket"));
404 ticket_cmd
405 }
406 None => bail!("can't get ticket without ticket parameters"),
407 };
408 let (tx, rx) = oneshot::channel::<String>();
409 cmd_sender.send((serde_json::json!(ticket_cmd), tx))?;
410 let ticket = rx.await?;
411 let mut ticket: Map<String, Value> = serde_json::from_str(&ticket)?;
412 let ticket = ticket
413 .remove("ticket")
414 .ok_or_else(|| format_err!("failed to retrieve ticket via tunnel"))?;
415
416 let ticket = ticket
417 .as_str()
418 .ok_or_else(|| format_err!("failed to format received ticket"))?;
419 let ticket = utf8_percent_encode(ticket, NON_ALPHANUMERIC).to_string();
420
421 let mut data = cmd_data.clone();
422 let mut url = data.url.clone();
423 url.push_str("ticket=");
424 url.push_str(&ticket);
425 let mut d = Arc::make_mut(&mut data);
426 d.url = url;
427 Ok(data)
428 }
429}
430
431#[tokio::main]
432async fn main() -> Result<(), Error> {
433 do_main().await
434}
435
436async fn do_main() -> Result<(), Error> {
437 let tunnel = CtrlTunnel { sender: None };
438 tunnel.read_cmd_loop().await
439}