]> git.proxmox.com Git - proxmox-websocket-tunnel.git/blob - src/main.rs
582214c7f9436caa7717f9dd1498cd582be289b6
[proxmox-websocket-tunnel.git] / src / main.rs
1 use anyhow::{bail, format_err, Error};
2
3 use std::collections::VecDeque;
4 use std::sync::Arc;
5
6 use futures::future::FutureExt;
7 use futures::select;
8
9 use hyper::client::{Client, HttpConnector};
10 use hyper::header::{SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE};
11 use hyper::upgrade::Upgraded;
12 use hyper::{Body, Request, StatusCode};
13
14 use openssl::ssl::{SslConnector, SslMethod};
15 use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
16
17 use serde::{Deserialize, Serialize};
18 use serde_json::{Map, Value};
19 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
20 use tokio::net::{UnixListener, UnixStream};
21 use tokio::sync::{mpsc, oneshot};
22 use tokio_stream::wrappers::LinesStream;
23 use tokio_stream::StreamExt;
24
25 use proxmox_http::client::HttpsConnector;
26 use proxmox_http::websocket::{OpCode, WebSocket, WebSocketReader, WebSocketWriter};
27
28 #[derive(Serialize, Deserialize, Debug)]
29 #[serde(rename_all = "kebab-case")]
30 enum CmdType {
31 Connect,
32 Forward,
33 NonControl,
34 }
35
36 type CmdData = Map<String, Value>;
37
38 #[derive(Serialize, Deserialize, Debug)]
39 #[serde(rename_all = "kebab-case")]
40 struct ConnectCmdData {
41 // target URL for WS connection
42 url: String,
43 // fingerprint of TLS certificate
44 fingerprint: Option<String>,
45 // addition headers such as authorization
46 headers: Option<Vec<(String, String)>>,
47 }
48
49 #[derive(Serialize, Deserialize, Debug, Clone)]
50 #[serde(rename_all = "kebab-case")]
51 struct ForwardCmdData {
52 // target URL for WS connection
53 url: String,
54 // addition headers such as authorization
55 headers: Option<Vec<(String, String)>>,
56 // fingerprint of TLS certificate
57 fingerprint: Option<String>,
58 // local UNIX socket path for forwarding
59 unix: String,
60 // request ticket using these parameters
61 ticket: Option<Map<String, Value>>,
62 }
63
64 struct CtrlTunnel {
65 sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
66 }
67
68 impl CtrlTunnel {
69 async fn read_cmd_loop(mut self) -> Result<(), Error> {
70 let mut stdin_stream = LinesStream::new(BufReader::new(tokio::io::stdin()).lines());
71 while let Some(res) = stdin_stream.next().await {
72 match res {
73 Ok(line) => {
74 let (cmd_type, data) = Self::parse_cmd(&line)?;
75 match cmd_type {
76 CmdType::Connect => self.handle_connect_cmd(data).await,
77 CmdType::Forward => {
78 let res = self.handle_forward_cmd(data).await;
79 match &res {
80 Ok(()) => println!("{}", serde_json::json!({"success": true})),
81 Err(msg) => println!(
82 "{}",
83 serde_json::json!({"success": false, "msg": msg.to_string()})
84 ),
85 };
86 res
87 }
88 CmdType::NonControl => self
89 .handle_tunnel_cmd(data)
90 .await
91 .map(|res| println!("{}", res)),
92 }
93 }
94 Err(err) => bail!("Failed to read from STDIN - {}", err),
95 }?;
96 }
97
98 Ok(())
99 }
100
101 fn parse_cmd(line: &str) -> Result<(CmdType, CmdData), Error> {
102 let mut json: Map<String, Value> = serde_json::from_str(line)?;
103 match json.remove("control") {
104 Some(Value::Bool(true)) => {
105 match json.remove("cmd").map(serde_json::from_value::<CmdType>) {
106 None => bail!("input has 'control' flag, but no control 'cmd' set.."),
107 Some(Err(e)) => bail!("failed to parse control cmd - {}", e),
108 Some(Ok(cmd_type)) => Ok((cmd_type, json)),
109 }
110 }
111 _ => Ok((CmdType::NonControl, json)),
112 }
113 }
114
115 async fn websocket_connect(
116 url: String,
117 extra_headers: Vec<(String, String)>,
118 fingerprint: Option<String>,
119 ) -> Result<Upgraded, Error> {
120 let ws_key = proxmox_sys::linux::random_data(16)?;
121 let ws_key = base64::encode(&ws_key);
122 let mut req = Request::builder()
123 .uri(url)
124 .header(UPGRADE, "websocket")
125 .header(SEC_WEBSOCKET_VERSION, "13")
126 .header(SEC_WEBSOCKET_KEY, ws_key)
127 .body(Body::empty())?;
128
129 let headers = req.headers_mut();
130 for (name, value) in extra_headers {
131 let name = hyper::header::HeaderName::from_bytes(name.as_bytes())?;
132 let value = hyper::header::HeaderValue::from_str(&value)?;
133 headers.insert(name, value);
134 }
135
136 let mut ssl_connector_builder = SslConnector::builder(SslMethod::tls())?;
137 if fingerprint.is_some() {
138 // FIXME actually verify fingerprint via callback!
139 ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
140 } else {
141 ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::PEER);
142 }
143
144 let mut httpc = HttpConnector::new();
145 httpc.enforce_http(false); // we want https...
146 httpc.set_connect_timeout(Some(std::time::Duration::new(10, 0)));
147 let https = HttpsConnector::with_connector(httpc, ssl_connector_builder.build(), 120);
148
149 let client = Client::builder().build::<_, Body>(https);
150 let res = client.request(req).await?;
151 if res.status() != StatusCode::SWITCHING_PROTOCOLS {
152 bail!("server didn't upgrade: {}", res.status());
153 }
154
155 hyper::upgrade::on(res)
156 .await
157 .map_err(|err| format_err!("failed to upgrade - {}", err))
158 }
159
160 async fn handle_connect_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
161 let mut data: ConnectCmdData = data
162 .remove("data")
163 .ok_or_else(|| format_err!("'connect' command missing 'data'"))
164 .map(serde_json::from_value)??;
165
166 if self.sender.is_some() {
167 bail!("already connected!");
168 }
169
170 let upgraded = Self::websocket_connect(
171 data.url.clone(),
172 data.headers.take().unwrap_or_else(Vec::new),
173 data.fingerprint.take(),
174 )
175 .await?;
176
177 let (tx, rx) = mpsc::unbounded_channel();
178 self.sender = Some(tx);
179 tokio::spawn(async move {
180 if let Err(err) = Self::handle_ctrl_tunnel(upgraded, rx).await {
181 eprintln!("Tunnel to {} failed - {}", data.url, err);
182 }
183 });
184
185 Ok(())
186 }
187
188 async fn handle_forward_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
189 let data: ForwardCmdData = data
190 .remove("data")
191 .ok_or_else(|| format_err!("'forward' command missing 'data'"))
192 .map(serde_json::from_value)??;
193
194 if self.sender.is_none() && data.ticket.is_some() {
195 bail!("dynamically requesting ticket requires cmd tunnel connection!");
196 }
197
198 let unix_listener = UnixListener::bind(data.unix.clone())?;
199 let data = Arc::new(data);
200 let cmd_sender = self.sender.clone();
201
202 tokio::spawn(async move {
203 let mut tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new();
204 let data2 = data.clone();
205
206 loop {
207 let data3 = data2.clone();
208
209 match unix_listener.accept().await {
210 Ok((unix_stream, _)) => {
211 eprintln!("accepted new connection on '{}'", data3.unix);
212 let cmd_sender2 = cmd_sender.clone();
213
214 let task = tokio::spawn(async move {
215 if let Err(err) = Self::handle_forward_tunnel(
216 cmd_sender2.clone(),
217 data3.clone(),
218 unix_stream,
219 )
220 .await
221 {
222 eprintln!("Tunnel for {} failed - {}", data3.unix, err);
223 }
224 });
225 tasks.push(task);
226 }
227 Err(err) => eprintln!(
228 "Failed to accept unix connection on {} - {}",
229 data3.unix, err
230 ),
231 };
232 }
233 });
234
235 Ok(())
236 }
237
238 async fn handle_tunnel_cmd(&mut self, data: CmdData) -> Result<String, Error> {
239 match &mut self.sender {
240 None => bail!("not connected!"),
241 Some(sender) => {
242 let data: Value = data.into();
243 let (tx, rx) = oneshot::channel::<String>();
244 if let Some(cmd) = data.get("cmd") {
245 eprintln!("-> sending command {} to remote", cmd);
246 } else {
247 eprintln!("-> sending data line to remote");
248 }
249 sender.send((data, tx))?;
250 let res = rx.await?;
251 eprintln!("<- got reply");
252 Ok(res)
253 }
254 }
255 }
256
257 async fn handle_ctrl_tunnel(
258 websocket: Upgraded,
259 mut cmd_receiver: mpsc::UnboundedReceiver<(Value, oneshot::Sender<String>)>,
260 ) -> Result<(), Error> {
261 let (tunnel_reader, tunnel_writer) = tokio::io::split(websocket);
262 let (ws_close_tx, mut ws_close_rx) = mpsc::unbounded_channel();
263 let ws_reader = WebSocketReader::new(tunnel_reader, ws_close_tx);
264 let mut ws_writer = WebSocketWriter::new(Some([0, 0, 0, 0]), tunnel_writer);
265
266 let mut framed_reader =
267 tokio_util::codec::FramedRead::new(ws_reader, tokio_util::codec::LinesCodec::new());
268
269 let mut resp_tx_queue: VecDeque<oneshot::Sender<String>> = VecDeque::new();
270 let mut shutting_down = false;
271
272 loop {
273 let mut close_future = ws_close_rx.recv().boxed().fuse();
274 let mut frame_future = framed_reader.next().boxed().fuse();
275 let mut cmd_future = cmd_receiver.recv().boxed().fuse();
276
277 select! {
278 res = close_future => {
279 let res = res.ok_or_else(|| format_err!("WS control channel closed"))?;
280 eprintln!("WS: received control message: '{:?}'", res);
281 shutting_down = true;
282 },
283 res = frame_future => {
284 match res {
285 None if shutting_down => {
286 eprintln!("WS closed");
287 break;
288 },
289 None => bail!("WS closed unexpectedly"),
290 Some(Ok(res)) => {
291 resp_tx_queue
292 .pop_front()
293 .ok_or_else(|| format_err!("no response handler"))?
294 .send(res)
295 .map_err(|msg| format_err!("failed to send tunnel response '{}' back to requester - receiver already closed?", msg))?;
296 },
297 Some(Err(err)) => {
298 bail!("reading from control tunnel failed - WS receive failed: {}", err);
299 },
300 }
301 },
302 res = cmd_future => {
303 if shutting_down { continue };
304 match res {
305 None => {
306 eprintln!("CMD channel closed, shutting down");
307 ws_writer.send_control_frame(Some([1,2,3,4]), OpCode::Close, &[]).await?;
308 shutting_down = true;
309 },
310 Some((msg, resp_tx)) => {
311 resp_tx_queue.push_back(resp_tx);
312
313 let line = format!("{}\n", msg);
314 ws_writer.write_all(line.as_bytes()).await?;
315 ws_writer.flush().await?;
316 },
317 }
318 },
319 };
320 }
321
322 Ok(())
323 }
324
325 async fn handle_forward_tunnel(
326 cmd_sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
327 data: Arc<ForwardCmdData>,
328 unix: UnixStream,
329 ) -> Result<(), Error> {
330 let data = match (&cmd_sender, &data.ticket) {
331 (Some(cmd_sender), Some(_)) => Self::get_ticket(cmd_sender, data.clone()).await,
332 _ => Ok(data.clone()),
333 }?;
334
335 let upgraded = Self::websocket_connect(
336 data.url.clone(),
337 data.headers.clone().unwrap_or_else(Vec::new),
338 data.fingerprint.clone(),
339 )
340 .await?;
341
342 let ws = WebSocket {
343 mask: Some([0, 0, 0, 0]),
344 };
345 eprintln!("established new WS for forwarding '{}'", data.unix);
346 ws.serve_connection(upgraded, unix).await?;
347
348 eprintln!("done handling forwarded connection from '{}'", data.unix);
349
350 Ok(())
351 }
352
353 async fn get_ticket(
354 cmd_sender: &mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>,
355 cmd_data: Arc<ForwardCmdData>,
356 ) -> Result<Arc<ForwardCmdData>, Error> {
357 eprintln!("requesting WS ticket via tunnel");
358 let ticket_cmd = match cmd_data.ticket.clone() {
359 Some(mut ticket_cmd) => {
360 ticket_cmd.insert("cmd".to_string(), serde_json::json!("ticket"));
361 ticket_cmd
362 }
363 None => bail!("can't get ticket without ticket parameters"),
364 };
365 let (tx, rx) = oneshot::channel::<String>();
366 cmd_sender.send((serde_json::json!(ticket_cmd), tx))?;
367 let ticket = rx.await?;
368 let mut ticket: Map<String, Value> = serde_json::from_str(&ticket)?;
369 let ticket = ticket
370 .remove("ticket")
371 .ok_or_else(|| format_err!("failed to retrieve ticket via tunnel"))?;
372
373 let ticket = ticket
374 .as_str()
375 .ok_or_else(|| format_err!("failed to format received ticket"))?;
376 let ticket = utf8_percent_encode(ticket, NON_ALPHANUMERIC).to_string();
377
378 let mut data = cmd_data.clone();
379 let mut url = data.url.clone();
380 url.push_str("ticket=");
381 url.push_str(&ticket);
382 let mut d = Arc::make_mut(&mut data);
383 d.url = url;
384 Ok(data)
385 }
386 }
387
388 #[tokio::main]
389 async fn main() -> Result<(), Error> {
390 do_main().await
391 }
392
393 async fn do_main() -> Result<(), Error> {
394 let tunnel = CtrlTunnel { sender: None };
395 tunnel.read_cmd_loop().await
396 }