]>
Commit | Line | Data |
---|---|---|
c18e63b9 FG |
1 | use anyhow::{bail, format_err, Error}; |
2 | ||
3 | use std::collections::VecDeque; | |
4 | use std::sync::Arc; | |
5 | ||
6 | use futures::future::FutureExt; | |
7 | use futures::select; | |
8 | ||
9 | use hyper::client::{Client, HttpConnector}; | |
10 | use hyper::header::{SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE}; | |
11 | use hyper::upgrade::Upgraded; | |
12 | use hyper::{Body, Request, StatusCode}; | |
13 | ||
14 | use openssl::ssl::{SslConnector, SslMethod}; | |
15 | use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; | |
16 | ||
17 | use serde::{Deserialize, Serialize}; | |
18 | use serde_json::{Map, Value}; | |
19 | use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; | |
20 | use tokio::net::{UnixListener, UnixStream}; | |
21 | use tokio::sync::{mpsc, oneshot}; | |
22 | use tokio_stream::wrappers::LinesStream; | |
23 | use tokio_stream::StreamExt; | |
24 | ||
25 | use proxmox_http::client::HttpsConnector; | |
26 | use proxmox_http::websocket::{OpCode, WebSocket, WebSocketReader, WebSocketWriter}; | |
27 | ||
28 | #[derive(Serialize, Deserialize, Debug)] | |
29 | #[serde(rename_all = "kebab-case")] | |
30 | enum CmdType { | |
31 | Connect, | |
32 | Forward, | |
33 | NonControl, | |
34 | } | |
35 | ||
36 | type CmdData = Map<String, Value>; | |
37 | ||
38 | #[derive(Serialize, Deserialize, Debug)] | |
39 | #[serde(rename_all = "kebab-case")] | |
40 | struct ConnectCmdData { | |
941f0305 | 41 | /// target URL for WS connection |
c18e63b9 | 42 | url: String, |
941f0305 WB |
43 | |
44 | /// fingerprint of TLS certificate | |
c18e63b9 | 45 | fingerprint: Option<String>, |
941f0305 WB |
46 | |
47 | /// addition headers such as authorization | |
c18e63b9 FG |
48 | headers: Option<Vec<(String, String)>>, |
49 | } | |
50 | ||
51 | #[derive(Serialize, Deserialize, Debug, Clone)] | |
52 | #[serde(rename_all = "kebab-case")] | |
53 | struct ForwardCmdData { | |
941f0305 | 54 | /// target URL for WS connection |
c18e63b9 | 55 | url: String, |
941f0305 WB |
56 | |
57 | /// addition headers such as authorization | |
c18e63b9 | 58 | headers: Option<Vec<(String, String)>>, |
941f0305 WB |
59 | |
60 | /// fingerprint of TLS certificate | |
c18e63b9 | 61 | fingerprint: Option<String>, |
941f0305 WB |
62 | |
63 | /// local UNIX socket path for forwarding | |
c18e63b9 | 64 | unix: String, |
941f0305 WB |
65 | |
66 | /// request ticket using these parameters | |
c18e63b9 FG |
67 | ticket: Option<Map<String, Value>>, |
68 | } | |
69 | ||
70 | struct CtrlTunnel { | |
71 | sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>, | |
72 | } | |
73 | ||
74 | impl CtrlTunnel { | |
75 | async fn read_cmd_loop(mut self) -> Result<(), Error> { | |
76 | let mut stdin_stream = LinesStream::new(BufReader::new(tokio::io::stdin()).lines()); | |
77 | while let Some(res) = stdin_stream.next().await { | |
78 | match res { | |
79 | Ok(line) => { | |
80 | let (cmd_type, data) = Self::parse_cmd(&line)?; | |
81 | match cmd_type { | |
82 | CmdType::Connect => self.handle_connect_cmd(data).await, | |
83 | CmdType::Forward => { | |
84 | let res = self.handle_forward_cmd(data).await; | |
85 | match &res { | |
86 | Ok(()) => println!("{}", serde_json::json!({"success": true})), | |
87 | Err(msg) => println!( | |
88 | "{}", | |
89 | serde_json::json!({"success": false, "msg": msg.to_string()}) | |
90 | ), | |
91 | }; | |
92 | res | |
93 | } | |
94 | CmdType::NonControl => self | |
95 | .handle_tunnel_cmd(data) | |
96 | .await | |
97 | .map(|res| println!("{}", res)), | |
98 | } | |
99 | } | |
100 | Err(err) => bail!("Failed to read from STDIN - {}", err), | |
101 | }?; | |
102 | } | |
103 | ||
104 | Ok(()) | |
105 | } | |
106 | ||
107 | fn parse_cmd(line: &str) -> Result<(CmdType, CmdData), Error> { | |
108 | let mut json: Map<String, Value> = serde_json::from_str(line)?; | |
109 | match json.remove("control") { | |
110 | Some(Value::Bool(true)) => { | |
111 | match json.remove("cmd").map(serde_json::from_value::<CmdType>) { | |
112 | None => bail!("input has 'control' flag, but no control 'cmd' set.."), | |
113 | Some(Err(e)) => bail!("failed to parse control cmd - {}", e), | |
114 | Some(Ok(cmd_type)) => Ok((cmd_type, json)), | |
115 | } | |
116 | } | |
117 | _ => Ok((CmdType::NonControl, json)), | |
118 | } | |
119 | } | |
120 | ||
121 | async fn websocket_connect( | |
122 | url: String, | |
123 | extra_headers: Vec<(String, String)>, | |
124 | fingerprint: Option<String>, | |
125 | ) -> Result<Upgraded, Error> { | |
126 | let ws_key = proxmox_sys::linux::random_data(16)?; | |
127 | let ws_key = base64::encode(&ws_key); | |
128 | let mut req = Request::builder() | |
129 | .uri(url) | |
130 | .header(UPGRADE, "websocket") | |
131 | .header(SEC_WEBSOCKET_VERSION, "13") | |
132 | .header(SEC_WEBSOCKET_KEY, ws_key) | |
133 | .body(Body::empty())?; | |
134 | ||
135 | let headers = req.headers_mut(); | |
136 | for (name, value) in extra_headers { | |
137 | let name = hyper::header::HeaderName::from_bytes(name.as_bytes())?; | |
138 | let value = hyper::header::HeaderValue::from_str(&value)?; | |
139 | headers.insert(name, value); | |
140 | } | |
141 | ||
142 | let mut ssl_connector_builder = SslConnector::builder(SslMethod::tls())?; | |
838b8aaf FG |
143 | if let Some(expected) = fingerprint { |
144 | ssl_connector_builder.set_verify_callback( | |
145 | openssl::ssl::SslVerifyMode::PEER, | |
146 | move |_valid, ctx| { | |
147 | let cert = match ctx.current_cert() { | |
148 | Some(cert) => cert, | |
149 | None => { | |
150 | // should not happen | |
151 | eprintln!("SSL context lacks current certificate."); | |
152 | return false; | |
153 | } | |
154 | }; | |
155 | ||
156 | // skip CA certificates, we only care about the peer cert | |
157 | let depth = ctx.error_depth(); | |
158 | if depth != 0 { | |
159 | return true; | |
160 | } | |
161 | ||
162 | let fp = match cert.digest(openssl::hash::MessageDigest::sha256()) { | |
163 | Ok(fp) => fp, | |
164 | Err(err) => { | |
165 | // should not happen | |
166 | eprintln!("failed to calculate certificate FP - {}", err); | |
167 | return false; | |
168 | } | |
169 | }; | |
170 | let fp_string = hex::encode(&fp); | |
171 | let fp_string = fp_string | |
172 | .as_bytes() | |
173 | .chunks(2) | |
174 | .map(|v| std::str::from_utf8(v).unwrap()) | |
175 | .collect::<Vec<&str>>() | |
176 | .join(":"); | |
177 | ||
178 | let expected = expected.to_lowercase(); | |
179 | if expected == fp_string { | |
180 | true | |
181 | } else { | |
182 | eprintln!("certificate fingerprint does not match expected fingerprint!"); | |
183 | eprintln!("expected: {}", expected); | |
184 | eprintln!("encountered: {}", fp_string); | |
185 | false | |
186 | } | |
187 | }, | |
188 | ); | |
c18e63b9 FG |
189 | } else { |
190 | ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::PEER); | |
191 | } | |
192 | ||
193 | let mut httpc = HttpConnector::new(); | |
194 | httpc.enforce_http(false); // we want https... | |
195 | httpc.set_connect_timeout(Some(std::time::Duration::new(10, 0))); | |
196 | let https = HttpsConnector::with_connector(httpc, ssl_connector_builder.build(), 120); | |
197 | ||
198 | let client = Client::builder().build::<_, Body>(https); | |
199 | let res = client.request(req).await?; | |
200 | if res.status() != StatusCode::SWITCHING_PROTOCOLS { | |
201 | bail!("server didn't upgrade: {}", res.status()); | |
202 | } | |
203 | ||
204 | hyper::upgrade::on(res) | |
205 | .await | |
206 | .map_err(|err| format_err!("failed to upgrade - {}", err)) | |
207 | } | |
208 | ||
209 | async fn handle_connect_cmd(&mut self, mut data: CmdData) -> Result<(), Error> { | |
210 | let mut data: ConnectCmdData = data | |
211 | .remove("data") | |
212 | .ok_or_else(|| format_err!("'connect' command missing 'data'")) | |
213 | .map(serde_json::from_value)??; | |
214 | ||
215 | if self.sender.is_some() { | |
216 | bail!("already connected!"); | |
217 | } | |
218 | ||
219 | let upgraded = Self::websocket_connect( | |
220 | data.url.clone(), | |
221 | data.headers.take().unwrap_or_else(Vec::new), | |
222 | data.fingerprint.take(), | |
223 | ) | |
224 | .await?; | |
225 | ||
226 | let (tx, rx) = mpsc::unbounded_channel(); | |
227 | self.sender = Some(tx); | |
228 | tokio::spawn(async move { | |
229 | if let Err(err) = Self::handle_ctrl_tunnel(upgraded, rx).await { | |
230 | eprintln!("Tunnel to {} failed - {}", data.url, err); | |
231 | } | |
232 | }); | |
233 | ||
234 | Ok(()) | |
235 | } | |
236 | ||
237 | async fn handle_forward_cmd(&mut self, mut data: CmdData) -> Result<(), Error> { | |
238 | let data: ForwardCmdData = data | |
239 | .remove("data") | |
240 | .ok_or_else(|| format_err!("'forward' command missing 'data'")) | |
241 | .map(serde_json::from_value)??; | |
242 | ||
243 | if self.sender.is_none() && data.ticket.is_some() { | |
244 | bail!("dynamically requesting ticket requires cmd tunnel connection!"); | |
245 | } | |
246 | ||
247 | let unix_listener = UnixListener::bind(data.unix.clone())?; | |
248 | let data = Arc::new(data); | |
249 | let cmd_sender = self.sender.clone(); | |
250 | ||
251 | tokio::spawn(async move { | |
252 | let mut tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new(); | |
253 | let data2 = data.clone(); | |
254 | ||
255 | loop { | |
256 | let data3 = data2.clone(); | |
257 | ||
258 | match unix_listener.accept().await { | |
259 | Ok((unix_stream, _)) => { | |
260 | eprintln!("accepted new connection on '{}'", data3.unix); | |
261 | let cmd_sender2 = cmd_sender.clone(); | |
262 | ||
263 | let task = tokio::spawn(async move { | |
264 | if let Err(err) = Self::handle_forward_tunnel( | |
265 | cmd_sender2.clone(), | |
266 | data3.clone(), | |
267 | unix_stream, | |
268 | ) | |
269 | .await | |
270 | { | |
271 | eprintln!("Tunnel for {} failed - {}", data3.unix, err); | |
272 | } | |
273 | }); | |
274 | tasks.push(task); | |
275 | } | |
276 | Err(err) => eprintln!( | |
277 | "Failed to accept unix connection on {} - {}", | |
278 | data3.unix, err | |
279 | ), | |
280 | }; | |
281 | } | |
282 | }); | |
283 | ||
284 | Ok(()) | |
285 | } | |
286 | ||
287 | async fn handle_tunnel_cmd(&mut self, data: CmdData) -> Result<String, Error> { | |
288 | match &mut self.sender { | |
289 | None => bail!("not connected!"), | |
290 | Some(sender) => { | |
291 | let data: Value = data.into(); | |
292 | let (tx, rx) = oneshot::channel::<String>(); | |
293 | if let Some(cmd) = data.get("cmd") { | |
294 | eprintln!("-> sending command {} to remote", cmd); | |
295 | } else { | |
296 | eprintln!("-> sending data line to remote"); | |
297 | } | |
298 | sender.send((data, tx))?; | |
299 | let res = rx.await?; | |
300 | eprintln!("<- got reply"); | |
301 | Ok(res) | |
302 | } | |
303 | } | |
304 | } | |
305 | ||
306 | async fn handle_ctrl_tunnel( | |
307 | websocket: Upgraded, | |
308 | mut cmd_receiver: mpsc::UnboundedReceiver<(Value, oneshot::Sender<String>)>, | |
309 | ) -> Result<(), Error> { | |
310 | let (tunnel_reader, tunnel_writer) = tokio::io::split(websocket); | |
311 | let (ws_close_tx, mut ws_close_rx) = mpsc::unbounded_channel(); | |
312 | let ws_reader = WebSocketReader::new(tunnel_reader, ws_close_tx); | |
313 | let mut ws_writer = WebSocketWriter::new(Some([0, 0, 0, 0]), tunnel_writer); | |
314 | ||
315 | let mut framed_reader = | |
316 | tokio_util::codec::FramedRead::new(ws_reader, tokio_util::codec::LinesCodec::new()); | |
317 | ||
318 | let mut resp_tx_queue: VecDeque<oneshot::Sender<String>> = VecDeque::new(); | |
319 | let mut shutting_down = false; | |
320 | ||
321 | loop { | |
322 | let mut close_future = ws_close_rx.recv().boxed().fuse(); | |
323 | let mut frame_future = framed_reader.next().boxed().fuse(); | |
324 | let mut cmd_future = cmd_receiver.recv().boxed().fuse(); | |
325 | ||
326 | select! { | |
327 | res = close_future => { | |
328 | let res = res.ok_or_else(|| format_err!("WS control channel closed"))?; | |
329 | eprintln!("WS: received control message: '{:?}'", res); | |
330 | shutting_down = true; | |
331 | }, | |
332 | res = frame_future => { | |
333 | match res { | |
334 | None if shutting_down => { | |
335 | eprintln!("WS closed"); | |
336 | break; | |
337 | }, | |
338 | None => bail!("WS closed unexpectedly"), | |
339 | Some(Ok(res)) => { | |
340 | resp_tx_queue | |
341 | .pop_front() | |
342 | .ok_or_else(|| format_err!("no response handler"))? | |
343 | .send(res) | |
344 | .map_err(|msg| format_err!("failed to send tunnel response '{}' back to requester - receiver already closed?", msg))?; | |
345 | }, | |
346 | Some(Err(err)) => { | |
347 | bail!("reading from control tunnel failed - WS receive failed: {}", err); | |
348 | }, | |
349 | } | |
350 | }, | |
351 | res = cmd_future => { | |
352 | if shutting_down { continue }; | |
353 | match res { | |
354 | None => { | |
355 | eprintln!("CMD channel closed, shutting down"); | |
356 | ws_writer.send_control_frame(Some([1,2,3,4]), OpCode::Close, &[]).await?; | |
357 | shutting_down = true; | |
358 | }, | |
359 | Some((msg, resp_tx)) => { | |
360 | resp_tx_queue.push_back(resp_tx); | |
361 | ||
362 | let line = format!("{}\n", msg); | |
363 | ws_writer.write_all(line.as_bytes()).await?; | |
364 | ws_writer.flush().await?; | |
365 | }, | |
366 | } | |
367 | }, | |
368 | }; | |
369 | } | |
370 | ||
371 | Ok(()) | |
372 | } | |
373 | ||
374 | async fn handle_forward_tunnel( | |
375 | cmd_sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>, | |
376 | data: Arc<ForwardCmdData>, | |
377 | unix: UnixStream, | |
378 | ) -> Result<(), Error> { | |
379 | let data = match (&cmd_sender, &data.ticket) { | |
380 | (Some(cmd_sender), Some(_)) => Self::get_ticket(cmd_sender, data.clone()).await, | |
381 | _ => Ok(data.clone()), | |
382 | }?; | |
383 | ||
384 | let upgraded = Self::websocket_connect( | |
385 | data.url.clone(), | |
386 | data.headers.clone().unwrap_or_else(Vec::new), | |
387 | data.fingerprint.clone(), | |
388 | ) | |
389 | .await?; | |
390 | ||
391 | let ws = WebSocket { | |
392 | mask: Some([0, 0, 0, 0]), | |
393 | }; | |
394 | eprintln!("established new WS for forwarding '{}'", data.unix); | |
395 | ws.serve_connection(upgraded, unix).await?; | |
396 | ||
397 | eprintln!("done handling forwarded connection from '{}'", data.unix); | |
398 | ||
399 | Ok(()) | |
400 | } | |
401 | ||
402 | async fn get_ticket( | |
403 | cmd_sender: &mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>, | |
404 | cmd_data: Arc<ForwardCmdData>, | |
405 | ) -> Result<Arc<ForwardCmdData>, Error> { | |
406 | eprintln!("requesting WS ticket via tunnel"); | |
407 | let ticket_cmd = match cmd_data.ticket.clone() { | |
408 | Some(mut ticket_cmd) => { | |
409 | ticket_cmd.insert("cmd".to_string(), serde_json::json!("ticket")); | |
410 | ticket_cmd | |
411 | } | |
412 | None => bail!("can't get ticket without ticket parameters"), | |
413 | }; | |
414 | let (tx, rx) = oneshot::channel::<String>(); | |
415 | cmd_sender.send((serde_json::json!(ticket_cmd), tx))?; | |
416 | let ticket = rx.await?; | |
417 | let mut ticket: Map<String, Value> = serde_json::from_str(&ticket)?; | |
418 | let ticket = ticket | |
419 | .remove("ticket") | |
420 | .ok_or_else(|| format_err!("failed to retrieve ticket via tunnel"))?; | |
421 | ||
422 | let ticket = ticket | |
423 | .as_str() | |
424 | .ok_or_else(|| format_err!("failed to format received ticket"))?; | |
425 | let ticket = utf8_percent_encode(ticket, NON_ALPHANUMERIC).to_string(); | |
426 | ||
427 | let mut data = cmd_data.clone(); | |
428 | let mut url = data.url.clone(); | |
429 | url.push_str("ticket="); | |
430 | url.push_str(&ticket); | |
431 | let mut d = Arc::make_mut(&mut data); | |
432 | d.url = url; | |
433 | Ok(data) | |
434 | } | |
435 | } | |
436 | ||
437 | #[tokio::main] | |
438 | async fn main() -> Result<(), Error> { | |
c18e63b9 FG |
439 | let tunnel = CtrlTunnel { sender: None }; |
440 | tunnel.read_cmd_loop().await | |
441 | } |