]>
Commit | Line | Data |
---|---|---|
c18e63b9 FG |
1 | use anyhow::{bail, format_err, Error}; |
2 | ||
3 | use std::collections::VecDeque; | |
4 | use std::sync::Arc; | |
5 | ||
6 | use futures::future::FutureExt; | |
7 | use futures::select; | |
8 | ||
9 | use hyper::client::{Client, HttpConnector}; | |
10 | use hyper::header::{SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE}; | |
11 | use hyper::upgrade::Upgraded; | |
12 | use hyper::{Body, Request, StatusCode}; | |
13 | ||
14 | use openssl::ssl::{SslConnector, SslMethod}; | |
15 | use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; | |
16 | ||
17 | use serde::{Deserialize, Serialize}; | |
18 | use serde_json::{Map, Value}; | |
19 | use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; | |
20 | use tokio::net::{UnixListener, UnixStream}; | |
21 | use tokio::sync::{mpsc, oneshot}; | |
22 | use tokio_stream::wrappers::LinesStream; | |
23 | use tokio_stream::StreamExt; | |
24 | ||
25 | use proxmox_http::client::HttpsConnector; | |
26 | use proxmox_http::websocket::{OpCode, WebSocket, WebSocketReader, WebSocketWriter}; | |
27 | ||
28 | #[derive(Serialize, Deserialize, Debug)] | |
29 | #[serde(rename_all = "kebab-case")] | |
30 | enum CmdType { | |
31 | Connect, | |
32 | Forward, | |
33 | NonControl, | |
34 | } | |
35 | ||
36 | type CmdData = Map<String, Value>; | |
37 | ||
38 | #[derive(Serialize, Deserialize, Debug)] | |
39 | #[serde(rename_all = "kebab-case")] | |
40 | struct ConnectCmdData { | |
41 | // target URL for WS connection | |
42 | url: String, | |
43 | // fingerprint of TLS certificate | |
44 | fingerprint: Option<String>, | |
45 | // addition headers such as authorization | |
46 | headers: Option<Vec<(String, String)>>, | |
47 | } | |
48 | ||
49 | #[derive(Serialize, Deserialize, Debug, Clone)] | |
50 | #[serde(rename_all = "kebab-case")] | |
51 | struct ForwardCmdData { | |
52 | // target URL for WS connection | |
53 | url: String, | |
54 | // addition headers such as authorization | |
55 | headers: Option<Vec<(String, String)>>, | |
56 | // fingerprint of TLS certificate | |
57 | fingerprint: Option<String>, | |
58 | // local UNIX socket path for forwarding | |
59 | unix: String, | |
60 | // request ticket using these parameters | |
61 | ticket: Option<Map<String, Value>>, | |
62 | } | |
63 | ||
64 | struct CtrlTunnel { | |
65 | sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>, | |
66 | } | |
67 | ||
68 | impl CtrlTunnel { | |
69 | async fn read_cmd_loop(mut self) -> Result<(), Error> { | |
70 | let mut stdin_stream = LinesStream::new(BufReader::new(tokio::io::stdin()).lines()); | |
71 | while let Some(res) = stdin_stream.next().await { | |
72 | match res { | |
73 | Ok(line) => { | |
74 | let (cmd_type, data) = Self::parse_cmd(&line)?; | |
75 | match cmd_type { | |
76 | CmdType::Connect => self.handle_connect_cmd(data).await, | |
77 | CmdType::Forward => { | |
78 | let res = self.handle_forward_cmd(data).await; | |
79 | match &res { | |
80 | Ok(()) => println!("{}", serde_json::json!({"success": true})), | |
81 | Err(msg) => println!( | |
82 | "{}", | |
83 | serde_json::json!({"success": false, "msg": msg.to_string()}) | |
84 | ), | |
85 | }; | |
86 | res | |
87 | } | |
88 | CmdType::NonControl => self | |
89 | .handle_tunnel_cmd(data) | |
90 | .await | |
91 | .map(|res| println!("{}", res)), | |
92 | } | |
93 | } | |
94 | Err(err) => bail!("Failed to read from STDIN - {}", err), | |
95 | }?; | |
96 | } | |
97 | ||
98 | Ok(()) | |
99 | } | |
100 | ||
101 | fn parse_cmd(line: &str) -> Result<(CmdType, CmdData), Error> { | |
102 | let mut json: Map<String, Value> = serde_json::from_str(line)?; | |
103 | match json.remove("control") { | |
104 | Some(Value::Bool(true)) => { | |
105 | match json.remove("cmd").map(serde_json::from_value::<CmdType>) { | |
106 | None => bail!("input has 'control' flag, but no control 'cmd' set.."), | |
107 | Some(Err(e)) => bail!("failed to parse control cmd - {}", e), | |
108 | Some(Ok(cmd_type)) => Ok((cmd_type, json)), | |
109 | } | |
110 | } | |
111 | _ => Ok((CmdType::NonControl, json)), | |
112 | } | |
113 | } | |
114 | ||
115 | async fn websocket_connect( | |
116 | url: String, | |
117 | extra_headers: Vec<(String, String)>, | |
118 | fingerprint: Option<String>, | |
119 | ) -> Result<Upgraded, Error> { | |
120 | let ws_key = proxmox_sys::linux::random_data(16)?; | |
121 | let ws_key = base64::encode(&ws_key); | |
122 | let mut req = Request::builder() | |
123 | .uri(url) | |
124 | .header(UPGRADE, "websocket") | |
125 | .header(SEC_WEBSOCKET_VERSION, "13") | |
126 | .header(SEC_WEBSOCKET_KEY, ws_key) | |
127 | .body(Body::empty())?; | |
128 | ||
129 | let headers = req.headers_mut(); | |
130 | for (name, value) in extra_headers { | |
131 | let name = hyper::header::HeaderName::from_bytes(name.as_bytes())?; | |
132 | let value = hyper::header::HeaderValue::from_str(&value)?; | |
133 | headers.insert(name, value); | |
134 | } | |
135 | ||
136 | let mut ssl_connector_builder = SslConnector::builder(SslMethod::tls())?; | |
838b8aaf FG |
137 | if let Some(expected) = fingerprint { |
138 | ssl_connector_builder.set_verify_callback( | |
139 | openssl::ssl::SslVerifyMode::PEER, | |
140 | move |_valid, ctx| { | |
141 | let cert = match ctx.current_cert() { | |
142 | Some(cert) => cert, | |
143 | None => { | |
144 | // should not happen | |
145 | eprintln!("SSL context lacks current certificate."); | |
146 | return false; | |
147 | } | |
148 | }; | |
149 | ||
150 | // skip CA certificates, we only care about the peer cert | |
151 | let depth = ctx.error_depth(); | |
152 | if depth != 0 { | |
153 | return true; | |
154 | } | |
155 | ||
156 | let fp = match cert.digest(openssl::hash::MessageDigest::sha256()) { | |
157 | Ok(fp) => fp, | |
158 | Err(err) => { | |
159 | // should not happen | |
160 | eprintln!("failed to calculate certificate FP - {}", err); | |
161 | return false; | |
162 | } | |
163 | }; | |
164 | let fp_string = hex::encode(&fp); | |
165 | let fp_string = fp_string | |
166 | .as_bytes() | |
167 | .chunks(2) | |
168 | .map(|v| std::str::from_utf8(v).unwrap()) | |
169 | .collect::<Vec<&str>>() | |
170 | .join(":"); | |
171 | ||
172 | let expected = expected.to_lowercase(); | |
173 | if expected == fp_string { | |
174 | true | |
175 | } else { | |
176 | eprintln!("certificate fingerprint does not match expected fingerprint!"); | |
177 | eprintln!("expected: {}", expected); | |
178 | eprintln!("encountered: {}", fp_string); | |
179 | false | |
180 | } | |
181 | }, | |
182 | ); | |
c18e63b9 FG |
183 | } else { |
184 | ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::PEER); | |
185 | } | |
186 | ||
187 | let mut httpc = HttpConnector::new(); | |
188 | httpc.enforce_http(false); // we want https... | |
189 | httpc.set_connect_timeout(Some(std::time::Duration::new(10, 0))); | |
190 | let https = HttpsConnector::with_connector(httpc, ssl_connector_builder.build(), 120); | |
191 | ||
192 | let client = Client::builder().build::<_, Body>(https); | |
193 | let res = client.request(req).await?; | |
194 | if res.status() != StatusCode::SWITCHING_PROTOCOLS { | |
195 | bail!("server didn't upgrade: {}", res.status()); | |
196 | } | |
197 | ||
198 | hyper::upgrade::on(res) | |
199 | .await | |
200 | .map_err(|err| format_err!("failed to upgrade - {}", err)) | |
201 | } | |
202 | ||
203 | async fn handle_connect_cmd(&mut self, mut data: CmdData) -> Result<(), Error> { | |
204 | let mut data: ConnectCmdData = data | |
205 | .remove("data") | |
206 | .ok_or_else(|| format_err!("'connect' command missing 'data'")) | |
207 | .map(serde_json::from_value)??; | |
208 | ||
209 | if self.sender.is_some() { | |
210 | bail!("already connected!"); | |
211 | } | |
212 | ||
213 | let upgraded = Self::websocket_connect( | |
214 | data.url.clone(), | |
215 | data.headers.take().unwrap_or_else(Vec::new), | |
216 | data.fingerprint.take(), | |
217 | ) | |
218 | .await?; | |
219 | ||
220 | let (tx, rx) = mpsc::unbounded_channel(); | |
221 | self.sender = Some(tx); | |
222 | tokio::spawn(async move { | |
223 | if let Err(err) = Self::handle_ctrl_tunnel(upgraded, rx).await { | |
224 | eprintln!("Tunnel to {} failed - {}", data.url, err); | |
225 | } | |
226 | }); | |
227 | ||
228 | Ok(()) | |
229 | } | |
230 | ||
231 | async fn handle_forward_cmd(&mut self, mut data: CmdData) -> Result<(), Error> { | |
232 | let data: ForwardCmdData = data | |
233 | .remove("data") | |
234 | .ok_or_else(|| format_err!("'forward' command missing 'data'")) | |
235 | .map(serde_json::from_value)??; | |
236 | ||
237 | if self.sender.is_none() && data.ticket.is_some() { | |
238 | bail!("dynamically requesting ticket requires cmd tunnel connection!"); | |
239 | } | |
240 | ||
241 | let unix_listener = UnixListener::bind(data.unix.clone())?; | |
242 | let data = Arc::new(data); | |
243 | let cmd_sender = self.sender.clone(); | |
244 | ||
245 | tokio::spawn(async move { | |
246 | let mut tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new(); | |
247 | let data2 = data.clone(); | |
248 | ||
249 | loop { | |
250 | let data3 = data2.clone(); | |
251 | ||
252 | match unix_listener.accept().await { | |
253 | Ok((unix_stream, _)) => { | |
254 | eprintln!("accepted new connection on '{}'", data3.unix); | |
255 | let cmd_sender2 = cmd_sender.clone(); | |
256 | ||
257 | let task = tokio::spawn(async move { | |
258 | if let Err(err) = Self::handle_forward_tunnel( | |
259 | cmd_sender2.clone(), | |
260 | data3.clone(), | |
261 | unix_stream, | |
262 | ) | |
263 | .await | |
264 | { | |
265 | eprintln!("Tunnel for {} failed - {}", data3.unix, err); | |
266 | } | |
267 | }); | |
268 | tasks.push(task); | |
269 | } | |
270 | Err(err) => eprintln!( | |
271 | "Failed to accept unix connection on {} - {}", | |
272 | data3.unix, err | |
273 | ), | |
274 | }; | |
275 | } | |
276 | }); | |
277 | ||
278 | Ok(()) | |
279 | } | |
280 | ||
281 | async fn handle_tunnel_cmd(&mut self, data: CmdData) -> Result<String, Error> { | |
282 | match &mut self.sender { | |
283 | None => bail!("not connected!"), | |
284 | Some(sender) => { | |
285 | let data: Value = data.into(); | |
286 | let (tx, rx) = oneshot::channel::<String>(); | |
287 | if let Some(cmd) = data.get("cmd") { | |
288 | eprintln!("-> sending command {} to remote", cmd); | |
289 | } else { | |
290 | eprintln!("-> sending data line to remote"); | |
291 | } | |
292 | sender.send((data, tx))?; | |
293 | let res = rx.await?; | |
294 | eprintln!("<- got reply"); | |
295 | Ok(res) | |
296 | } | |
297 | } | |
298 | } | |
299 | ||
300 | async fn handle_ctrl_tunnel( | |
301 | websocket: Upgraded, | |
302 | mut cmd_receiver: mpsc::UnboundedReceiver<(Value, oneshot::Sender<String>)>, | |
303 | ) -> Result<(), Error> { | |
304 | let (tunnel_reader, tunnel_writer) = tokio::io::split(websocket); | |
305 | let (ws_close_tx, mut ws_close_rx) = mpsc::unbounded_channel(); | |
306 | let ws_reader = WebSocketReader::new(tunnel_reader, ws_close_tx); | |
307 | let mut ws_writer = WebSocketWriter::new(Some([0, 0, 0, 0]), tunnel_writer); | |
308 | ||
309 | let mut framed_reader = | |
310 | tokio_util::codec::FramedRead::new(ws_reader, tokio_util::codec::LinesCodec::new()); | |
311 | ||
312 | let mut resp_tx_queue: VecDeque<oneshot::Sender<String>> = VecDeque::new(); | |
313 | let mut shutting_down = false; | |
314 | ||
315 | loop { | |
316 | let mut close_future = ws_close_rx.recv().boxed().fuse(); | |
317 | let mut frame_future = framed_reader.next().boxed().fuse(); | |
318 | let mut cmd_future = cmd_receiver.recv().boxed().fuse(); | |
319 | ||
320 | select! { | |
321 | res = close_future => { | |
322 | let res = res.ok_or_else(|| format_err!("WS control channel closed"))?; | |
323 | eprintln!("WS: received control message: '{:?}'", res); | |
324 | shutting_down = true; | |
325 | }, | |
326 | res = frame_future => { | |
327 | match res { | |
328 | None if shutting_down => { | |
329 | eprintln!("WS closed"); | |
330 | break; | |
331 | }, | |
332 | None => bail!("WS closed unexpectedly"), | |
333 | Some(Ok(res)) => { | |
334 | resp_tx_queue | |
335 | .pop_front() | |
336 | .ok_or_else(|| format_err!("no response handler"))? | |
337 | .send(res) | |
338 | .map_err(|msg| format_err!("failed to send tunnel response '{}' back to requester - receiver already closed?", msg))?; | |
339 | }, | |
340 | Some(Err(err)) => { | |
341 | bail!("reading from control tunnel failed - WS receive failed: {}", err); | |
342 | }, | |
343 | } | |
344 | }, | |
345 | res = cmd_future => { | |
346 | if shutting_down { continue }; | |
347 | match res { | |
348 | None => { | |
349 | eprintln!("CMD channel closed, shutting down"); | |
350 | ws_writer.send_control_frame(Some([1,2,3,4]), OpCode::Close, &[]).await?; | |
351 | shutting_down = true; | |
352 | }, | |
353 | Some((msg, resp_tx)) => { | |
354 | resp_tx_queue.push_back(resp_tx); | |
355 | ||
356 | let line = format!("{}\n", msg); | |
357 | ws_writer.write_all(line.as_bytes()).await?; | |
358 | ws_writer.flush().await?; | |
359 | }, | |
360 | } | |
361 | }, | |
362 | }; | |
363 | } | |
364 | ||
365 | Ok(()) | |
366 | } | |
367 | ||
368 | async fn handle_forward_tunnel( | |
369 | cmd_sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>, | |
370 | data: Arc<ForwardCmdData>, | |
371 | unix: UnixStream, | |
372 | ) -> Result<(), Error> { | |
373 | let data = match (&cmd_sender, &data.ticket) { | |
374 | (Some(cmd_sender), Some(_)) => Self::get_ticket(cmd_sender, data.clone()).await, | |
375 | _ => Ok(data.clone()), | |
376 | }?; | |
377 | ||
378 | let upgraded = Self::websocket_connect( | |
379 | data.url.clone(), | |
380 | data.headers.clone().unwrap_or_else(Vec::new), | |
381 | data.fingerprint.clone(), | |
382 | ) | |
383 | .await?; | |
384 | ||
385 | let ws = WebSocket { | |
386 | mask: Some([0, 0, 0, 0]), | |
387 | }; | |
388 | eprintln!("established new WS for forwarding '{}'", data.unix); | |
389 | ws.serve_connection(upgraded, unix).await?; | |
390 | ||
391 | eprintln!("done handling forwarded connection from '{}'", data.unix); | |
392 | ||
393 | Ok(()) | |
394 | } | |
395 | ||
396 | async fn get_ticket( | |
397 | cmd_sender: &mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>, | |
398 | cmd_data: Arc<ForwardCmdData>, | |
399 | ) -> Result<Arc<ForwardCmdData>, Error> { | |
400 | eprintln!("requesting WS ticket via tunnel"); | |
401 | let ticket_cmd = match cmd_data.ticket.clone() { | |
402 | Some(mut ticket_cmd) => { | |
403 | ticket_cmd.insert("cmd".to_string(), serde_json::json!("ticket")); | |
404 | ticket_cmd | |
405 | } | |
406 | None => bail!("can't get ticket without ticket parameters"), | |
407 | }; | |
408 | let (tx, rx) = oneshot::channel::<String>(); | |
409 | cmd_sender.send((serde_json::json!(ticket_cmd), tx))?; | |
410 | let ticket = rx.await?; | |
411 | let mut ticket: Map<String, Value> = serde_json::from_str(&ticket)?; | |
412 | let ticket = ticket | |
413 | .remove("ticket") | |
414 | .ok_or_else(|| format_err!("failed to retrieve ticket via tunnel"))?; | |
415 | ||
416 | let ticket = ticket | |
417 | .as_str() | |
418 | .ok_or_else(|| format_err!("failed to format received ticket"))?; | |
419 | let ticket = utf8_percent_encode(ticket, NON_ALPHANUMERIC).to_string(); | |
420 | ||
421 | let mut data = cmd_data.clone(); | |
422 | let mut url = data.url.clone(); | |
423 | url.push_str("ticket="); | |
424 | url.push_str(&ticket); | |
425 | let mut d = Arc::make_mut(&mut data); | |
426 | d.url = url; | |
427 | Ok(data) | |
428 | } | |
429 | } | |
430 | ||
431 | #[tokio::main] | |
432 | async fn main() -> Result<(), Error> { | |
433 | do_main().await | |
434 | } | |
435 | ||
436 | async fn do_main() -> Result<(), Error> { | |
437 | let tunnel = CtrlTunnel { sender: None }; | |
438 | tunnel.read_cmd_loop().await | |
439 | } |