]> git.proxmox.com Git - proxmox-websocket-tunnel.git/blame - src/main.rs
cleanup
[proxmox-websocket-tunnel.git] / src / main.rs
CommitLineData
c18e63b9
FG
1use anyhow::{bail, format_err, Error};
2
3use std::collections::VecDeque;
4use std::sync::Arc;
5
6use futures::future::FutureExt;
7use futures::select;
8
9use hyper::client::{Client, HttpConnector};
10use hyper::header::{SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE};
11use hyper::upgrade::Upgraded;
12use hyper::{Body, Request, StatusCode};
13
14use openssl::ssl::{SslConnector, SslMethod};
15use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
16
17use serde::{Deserialize, Serialize};
18use serde_json::{Map, Value};
19use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
20use tokio::net::{UnixListener, UnixStream};
21use tokio::sync::{mpsc, oneshot};
22use tokio_stream::wrappers::LinesStream;
23use tokio_stream::StreamExt;
24
25use proxmox_http::client::HttpsConnector;
26use proxmox_http::websocket::{OpCode, WebSocket, WebSocketReader, WebSocketWriter};
27
28#[derive(Serialize, Deserialize, Debug)]
29#[serde(rename_all = "kebab-case")]
30enum CmdType {
31 Connect,
32 Forward,
33 NonControl,
34}
35
36type CmdData = Map<String, Value>;
37
38#[derive(Serialize, Deserialize, Debug)]
39#[serde(rename_all = "kebab-case")]
40struct ConnectCmdData {
941f0305 41 /// target URL for WS connection
c18e63b9 42 url: String,
941f0305
WB
43
44 /// fingerprint of TLS certificate
c18e63b9 45 fingerprint: Option<String>,
941f0305
WB
46
47 /// addition headers such as authorization
c18e63b9
FG
48 headers: Option<Vec<(String, String)>>,
49}
50
51#[derive(Serialize, Deserialize, Debug, Clone)]
52#[serde(rename_all = "kebab-case")]
53struct ForwardCmdData {
941f0305 54 /// target URL for WS connection
c18e63b9 55 url: String,
941f0305
WB
56
57 /// addition headers such as authorization
c18e63b9 58 headers: Option<Vec<(String, String)>>,
941f0305
WB
59
60 /// fingerprint of TLS certificate
c18e63b9 61 fingerprint: Option<String>,
941f0305
WB
62
63 /// local UNIX socket path for forwarding
c18e63b9 64 unix: String,
941f0305
WB
65
66 /// request ticket using these parameters
c18e63b9
FG
67 ticket: Option<Map<String, Value>>,
68}
69
70struct CtrlTunnel {
71 sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
72}
73
74impl CtrlTunnel {
75 async fn read_cmd_loop(mut self) -> Result<(), Error> {
76 let mut stdin_stream = LinesStream::new(BufReader::new(tokio::io::stdin()).lines());
77 while let Some(res) = stdin_stream.next().await {
78 match res {
79 Ok(line) => {
80 let (cmd_type, data) = Self::parse_cmd(&line)?;
81 match cmd_type {
82 CmdType::Connect => self.handle_connect_cmd(data).await,
83 CmdType::Forward => {
84 let res = self.handle_forward_cmd(data).await;
85 match &res {
86 Ok(()) => println!("{}", serde_json::json!({"success": true})),
87 Err(msg) => println!(
88 "{}",
89 serde_json::json!({"success": false, "msg": msg.to_string()})
90 ),
91 };
92 res
93 }
94 CmdType::NonControl => self
95 .handle_tunnel_cmd(data)
96 .await
97 .map(|res| println!("{}", res)),
98 }
99 }
100 Err(err) => bail!("Failed to read from STDIN - {}", err),
101 }?;
102 }
103
104 Ok(())
105 }
106
107 fn parse_cmd(line: &str) -> Result<(CmdType, CmdData), Error> {
108 let mut json: Map<String, Value> = serde_json::from_str(line)?;
109 match json.remove("control") {
110 Some(Value::Bool(true)) => {
111 match json.remove("cmd").map(serde_json::from_value::<CmdType>) {
112 None => bail!("input has 'control' flag, but no control 'cmd' set.."),
113 Some(Err(e)) => bail!("failed to parse control cmd - {}", e),
114 Some(Ok(cmd_type)) => Ok((cmd_type, json)),
115 }
116 }
117 _ => Ok((CmdType::NonControl, json)),
118 }
119 }
120
121 async fn websocket_connect(
122 url: String,
123 extra_headers: Vec<(String, String)>,
124 fingerprint: Option<String>,
125 ) -> Result<Upgraded, Error> {
126 let ws_key = proxmox_sys::linux::random_data(16)?;
127 let ws_key = base64::encode(&ws_key);
128 let mut req = Request::builder()
129 .uri(url)
130 .header(UPGRADE, "websocket")
131 .header(SEC_WEBSOCKET_VERSION, "13")
132 .header(SEC_WEBSOCKET_KEY, ws_key)
133 .body(Body::empty())?;
134
135 let headers = req.headers_mut();
136 for (name, value) in extra_headers {
137 let name = hyper::header::HeaderName::from_bytes(name.as_bytes())?;
138 let value = hyper::header::HeaderValue::from_str(&value)?;
139 headers.insert(name, value);
140 }
141
142 let mut ssl_connector_builder = SslConnector::builder(SslMethod::tls())?;
838b8aaf
FG
143 if let Some(expected) = fingerprint {
144 ssl_connector_builder.set_verify_callback(
145 openssl::ssl::SslVerifyMode::PEER,
146 move |_valid, ctx| {
147 let cert = match ctx.current_cert() {
148 Some(cert) => cert,
149 None => {
150 // should not happen
151 eprintln!("SSL context lacks current certificate.");
152 return false;
153 }
154 };
155
156 // skip CA certificates, we only care about the peer cert
157 let depth = ctx.error_depth();
158 if depth != 0 {
159 return true;
160 }
161
162 let fp = match cert.digest(openssl::hash::MessageDigest::sha256()) {
163 Ok(fp) => fp,
164 Err(err) => {
165 // should not happen
166 eprintln!("failed to calculate certificate FP - {}", err);
167 return false;
168 }
169 };
170 let fp_string = hex::encode(&fp);
171 let fp_string = fp_string
172 .as_bytes()
173 .chunks(2)
174 .map(|v| std::str::from_utf8(v).unwrap())
175 .collect::<Vec<&str>>()
176 .join(":");
177
178 let expected = expected.to_lowercase();
179 if expected == fp_string {
180 true
181 } else {
182 eprintln!("certificate fingerprint does not match expected fingerprint!");
183 eprintln!("expected: {}", expected);
184 eprintln!("encountered: {}", fp_string);
185 false
186 }
187 },
188 );
c18e63b9
FG
189 } else {
190 ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::PEER);
191 }
192
193 let mut httpc = HttpConnector::new();
194 httpc.enforce_http(false); // we want https...
195 httpc.set_connect_timeout(Some(std::time::Duration::new(10, 0)));
196 let https = HttpsConnector::with_connector(httpc, ssl_connector_builder.build(), 120);
197
198 let client = Client::builder().build::<_, Body>(https);
199 let res = client.request(req).await?;
200 if res.status() != StatusCode::SWITCHING_PROTOCOLS {
201 bail!("server didn't upgrade: {}", res.status());
202 }
203
204 hyper::upgrade::on(res)
205 .await
206 .map_err(|err| format_err!("failed to upgrade - {}", err))
207 }
208
209 async fn handle_connect_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
210 let mut data: ConnectCmdData = data
211 .remove("data")
212 .ok_or_else(|| format_err!("'connect' command missing 'data'"))
213 .map(serde_json::from_value)??;
214
215 if self.sender.is_some() {
216 bail!("already connected!");
217 }
218
219 let upgraded = Self::websocket_connect(
220 data.url.clone(),
221 data.headers.take().unwrap_or_else(Vec::new),
222 data.fingerprint.take(),
223 )
224 .await?;
225
226 let (tx, rx) = mpsc::unbounded_channel();
227 self.sender = Some(tx);
228 tokio::spawn(async move {
229 if let Err(err) = Self::handle_ctrl_tunnel(upgraded, rx).await {
230 eprintln!("Tunnel to {} failed - {}", data.url, err);
231 }
232 });
233
234 Ok(())
235 }
236
237 async fn handle_forward_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
238 let data: ForwardCmdData = data
239 .remove("data")
240 .ok_or_else(|| format_err!("'forward' command missing 'data'"))
241 .map(serde_json::from_value)??;
242
243 if self.sender.is_none() && data.ticket.is_some() {
244 bail!("dynamically requesting ticket requires cmd tunnel connection!");
245 }
246
247 let unix_listener = UnixListener::bind(data.unix.clone())?;
248 let data = Arc::new(data);
249 let cmd_sender = self.sender.clone();
250
251 tokio::spawn(async move {
252 let mut tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new();
253 let data2 = data.clone();
254
255 loop {
256 let data3 = data2.clone();
257
258 match unix_listener.accept().await {
259 Ok((unix_stream, _)) => {
260 eprintln!("accepted new connection on '{}'", data3.unix);
261 let cmd_sender2 = cmd_sender.clone();
262
263 let task = tokio::spawn(async move {
264 if let Err(err) = Self::handle_forward_tunnel(
265 cmd_sender2.clone(),
266 data3.clone(),
267 unix_stream,
268 )
269 .await
270 {
271 eprintln!("Tunnel for {} failed - {}", data3.unix, err);
272 }
273 });
274 tasks.push(task);
275 }
276 Err(err) => eprintln!(
277 "Failed to accept unix connection on {} - {}",
278 data3.unix, err
279 ),
280 };
281 }
282 });
283
284 Ok(())
285 }
286
287 async fn handle_tunnel_cmd(&mut self, data: CmdData) -> Result<String, Error> {
288 match &mut self.sender {
289 None => bail!("not connected!"),
290 Some(sender) => {
291 let data: Value = data.into();
292 let (tx, rx) = oneshot::channel::<String>();
293 if let Some(cmd) = data.get("cmd") {
294 eprintln!("-> sending command {} to remote", cmd);
295 } else {
296 eprintln!("-> sending data line to remote");
297 }
298 sender.send((data, tx))?;
299 let res = rx.await?;
300 eprintln!("<- got reply");
301 Ok(res)
302 }
303 }
304 }
305
306 async fn handle_ctrl_tunnel(
307 websocket: Upgraded,
308 mut cmd_receiver: mpsc::UnboundedReceiver<(Value, oneshot::Sender<String>)>,
309 ) -> Result<(), Error> {
310 let (tunnel_reader, tunnel_writer) = tokio::io::split(websocket);
311 let (ws_close_tx, mut ws_close_rx) = mpsc::unbounded_channel();
312 let ws_reader = WebSocketReader::new(tunnel_reader, ws_close_tx);
313 let mut ws_writer = WebSocketWriter::new(Some([0, 0, 0, 0]), tunnel_writer);
314
315 let mut framed_reader =
316 tokio_util::codec::FramedRead::new(ws_reader, tokio_util::codec::LinesCodec::new());
317
318 let mut resp_tx_queue: VecDeque<oneshot::Sender<String>> = VecDeque::new();
319 let mut shutting_down = false;
320
321 loop {
322 let mut close_future = ws_close_rx.recv().boxed().fuse();
323 let mut frame_future = framed_reader.next().boxed().fuse();
324 let mut cmd_future = cmd_receiver.recv().boxed().fuse();
325
326 select! {
327 res = close_future => {
328 let res = res.ok_or_else(|| format_err!("WS control channel closed"))?;
329 eprintln!("WS: received control message: '{:?}'", res);
330 shutting_down = true;
331 },
332 res = frame_future => {
333 match res {
334 None if shutting_down => {
335 eprintln!("WS closed");
336 break;
337 },
338 None => bail!("WS closed unexpectedly"),
339 Some(Ok(res)) => {
340 resp_tx_queue
341 .pop_front()
342 .ok_or_else(|| format_err!("no response handler"))?
343 .send(res)
344 .map_err(|msg| format_err!("failed to send tunnel response '{}' back to requester - receiver already closed?", msg))?;
345 },
346 Some(Err(err)) => {
347 bail!("reading from control tunnel failed - WS receive failed: {}", err);
348 },
349 }
350 },
351 res = cmd_future => {
352 if shutting_down { continue };
353 match res {
354 None => {
355 eprintln!("CMD channel closed, shutting down");
356 ws_writer.send_control_frame(Some([1,2,3,4]), OpCode::Close, &[]).await?;
357 shutting_down = true;
358 },
359 Some((msg, resp_tx)) => {
360 resp_tx_queue.push_back(resp_tx);
361
362 let line = format!("{}\n", msg);
363 ws_writer.write_all(line.as_bytes()).await?;
364 ws_writer.flush().await?;
365 },
366 }
367 },
368 };
369 }
370
371 Ok(())
372 }
373
374 async fn handle_forward_tunnel(
375 cmd_sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
376 data: Arc<ForwardCmdData>,
377 unix: UnixStream,
378 ) -> Result<(), Error> {
379 let data = match (&cmd_sender, &data.ticket) {
380 (Some(cmd_sender), Some(_)) => Self::get_ticket(cmd_sender, data.clone()).await,
381 _ => Ok(data.clone()),
382 }?;
383
384 let upgraded = Self::websocket_connect(
385 data.url.clone(),
386 data.headers.clone().unwrap_or_else(Vec::new),
387 data.fingerprint.clone(),
388 )
389 .await?;
390
391 let ws = WebSocket {
392 mask: Some([0, 0, 0, 0]),
393 };
394 eprintln!("established new WS for forwarding '{}'", data.unix);
395 ws.serve_connection(upgraded, unix).await?;
396
397 eprintln!("done handling forwarded connection from '{}'", data.unix);
398
399 Ok(())
400 }
401
402 async fn get_ticket(
403 cmd_sender: &mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>,
404 cmd_data: Arc<ForwardCmdData>,
405 ) -> Result<Arc<ForwardCmdData>, Error> {
406 eprintln!("requesting WS ticket via tunnel");
407 let ticket_cmd = match cmd_data.ticket.clone() {
408 Some(mut ticket_cmd) => {
409 ticket_cmd.insert("cmd".to_string(), serde_json::json!("ticket"));
410 ticket_cmd
411 }
412 None => bail!("can't get ticket without ticket parameters"),
413 };
414 let (tx, rx) = oneshot::channel::<String>();
415 cmd_sender.send((serde_json::json!(ticket_cmd), tx))?;
416 let ticket = rx.await?;
417 let mut ticket: Map<String, Value> = serde_json::from_str(&ticket)?;
418 let ticket = ticket
419 .remove("ticket")
420 .ok_or_else(|| format_err!("failed to retrieve ticket via tunnel"))?;
421
422 let ticket = ticket
423 .as_str()
424 .ok_or_else(|| format_err!("failed to format received ticket"))?;
425 let ticket = utf8_percent_encode(ticket, NON_ALPHANUMERIC).to_string();
426
427 let mut data = cmd_data.clone();
428 let mut url = data.url.clone();
429 url.push_str("ticket=");
430 url.push_str(&ticket);
431 let mut d = Arc::make_mut(&mut data);
432 d.url = url;
433 Ok(data)
434 }
435}
436
437#[tokio::main]
438async fn main() -> Result<(), Error> {
c18e63b9
FG
439 let tunnel = CtrlTunnel { sender: None };
440 tunnel.read_cmd_loop().await
441}