]> git.proxmox.com Git - proxmox.git/blob - proxmox-http/src/websocket/mod.rs
fix deprecated use of std::u16 module
[proxmox.git] / proxmox-http / src / websocket / mod.rs
1 //! Websocket helpers
2 //!
//! Provides methods to read from and write to websockets. The reader and writer wrap a
//! reader/writer implementing AsyncRead/AsyncWrite respectively and implement that same trait.
5
6 use std::cmp::min;
7 use std::future::Future;
8 use std::io;
9 use std::pin::Pin;
10 use std::task::{Context, Poll};
11
12 use anyhow::{bail, format_err, Error};
13 use futures::select;
14 use hyper::header::{
15 HeaderMap, HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY,
16 SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE,
17 };
18 use hyper::{Body, Response, StatusCode};
19 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
20 use tokio::sync::mpsc;
21
22 use futures::future::FutureExt;
23 use futures::ready;
24
25 use proxmox::sys::error::io_err_other;
26 use proxmox::tools::byte_buffer::ByteBuffer;
27
/// Status codes used in websocket close frames, see RFC6455 section 7.4.1.
///
/// The discriminants are the status code values sent on the wire.
#[derive(Debug, Clone, Copy)]
#[repr(u16)]
pub enum WebSocketErrorKind {
    /// 1000: normal closure
    Normal = 1000,
    /// 1002: the peer violated the websocket protocol
    ProtocolError = 1002,
    /// 1003: received a type of data that cannot be accepted
    InvalidData = 1003,
    /// 1008: generic policy violation
    Other = 1008,
    /// 1011: an unexpected condition prevented fulfilling the request
    Unexpected = 1011,
}

impl WebSocketErrorKind {
    /// Returns the status code in network byte order, as used in close frames.
    #[inline]
    pub fn to_be_bytes(self) -> [u8; 2] {
        (self as u16).to_be_bytes()
    }
}

impl std::fmt::Display for WebSocketErrorKind {
    /// Formats the numeric status code.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
        write!(f, "{}", *self as u16)
    }
}

/// An error that can be reported to the peer in a close frame: a status code
/// plus a human readable reason.
#[derive(Debug, Clone)]
pub struct WebSocketError {
    kind: WebSocketErrorKind,
    message: String,
}

impl WebSocketError {
    /// Creates a new error with the close status `kind` and the reason `message`.
    pub fn new(kind: WebSocketErrorKind, message: &str) -> Self {
        Self {
            kind,
            message: message.to_string(),
        }
    }

    /// Generates the payload of a close frame reporting this error.
    ///
    /// The payload is the 2-byte big-endian status code followed by the message.
    /// Control frame payloads are limited to 125 bytes, so the message is cut to
    /// at most 123 bytes. The cut happens on a UTF-8 character boundary so the
    /// close reason stays valid UTF-8, as RFC6455 section 5.5.1 requires.
    pub fn generate_frame_payload(&self) -> Vec<u8> {
        const MAX_CONTROL_PAYLOAD: usize = 125;

        let code = self.kind.to_be_bytes();
        let mut msglen = self.message.len().min(MAX_CONTROL_PAYLOAD - code.len());
        // never split a multi-byte character; index 0 is always a boundary
        while !self.message.is_char_boundary(msglen) {
            msglen -= 1;
        }

        let mut data = Vec::with_capacity(code.len() + msglen);
        data.extend_from_slice(&code);
        data.extend_from_slice(&self.message.as_bytes()[..msglen]);
        data
    }
}

impl std::fmt::Display for WebSocketError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
        write!(f, "{} (Code: {})", self.message, self.kind)
    }
}

impl std::error::Error for WebSocketError {}
83
/// The opcode of a websocket frame (RFC6455 section 5.2).
///
/// Each discriminant is the 4-bit opcode value used on the wire.
#[repr(u8)]
#[derive(Debug, PartialEq, PartialOrd, Copy, Clone)]
pub enum OpCode {
    /// A continuation fragment of a fragmented message
    Continuation = 0,
    /// A text frame
    Text = 1,
    /// A binary frame
    Binary = 2,
    /// A closing frame
    Close = 8,
    /// A ping frame
    Ping = 9,
    /// A pong frame
    Pong = 10,
}

impl OpCode {
    /// Returns `true` for the control opcodes: close, ping and pong.
    pub fn is_control(self) -> bool {
        matches!(self, OpCode::Close | OpCode::Ping | OpCode::Pong)
    }
}
108
/// Applies the websocket masking key to `data` in place (RFC6455 section 5.3).
///
/// Byte `i` of `data` is XORed with `mask[i % 4]`. XOR is its own inverse, so
/// this both masks and unmasks. `None` and the all-zero mask leave `data`
/// untouched.
fn mask_bytes(mask: Option<[u8; 4]>, data: &mut [u8]) {
    let mask = match mask {
        Some([0, 0, 0, 0]) | None => return,
        Some(mask) => mask,
    };

    // short buffers are not worth the word-wise processing below
    if data.len() < 32 {
        for (byte, key) in data.iter_mut().zip(mask.iter().cycle()) {
            *byte ^= key;
        }
        return;
    }

    // `rolling` holds the mask rotated so that its lowest byte (in little-endian
    // order) is the key byte for the next unprocessed byte of `data`
    let mut rolling: u32 = u32::from_le_bytes(mask);

    // SAFETY: every bit pattern is a valid `u32`, so viewing the aligned middle
    // part of a `u8` slice as `u32`s is sound; `align_to_mut` handles alignment.
    let (prefix, middle, suffix) = unsafe { data.align_to_mut::<u32>() };

    for byte in prefix {
        *byte ^= rolling as u8;
        rolling = rolling.rotate_right(8);
    }

    // The middle words are accessed in native byte order, so the mask word must
    // have the key bytes laid out in memory in the same order as the data bytes.
    // This keeps the result correct on big-endian hosts too.
    let word_mask = u32::from_ne_bytes(rolling.to_le_bytes());
    for word in middle {
        *word ^= word_mask;
    }

    // whole words do not change the mask phase, so `rolling` is still correct
    for byte in suffix {
        *byte ^= rolling as u8;
        rolling = rolling.rotate_right(8);
    }
}
140
141 /// Can be used to create a complete WebSocket Frame.
142 ///
143 /// Takes an optional mask, the data and the frame type
144 ///
145 /// Examples:
146 ///
147 /// A normal Frame
148 /// ```
149 /// # use proxmox_http::websocket::*;
150 /// # use std::io;
151 /// # fn main() -> Result<(), WebSocketError> {
152 /// let data = vec![0,1,2,3,4];
153 /// let frame = create_frame(None, &data, OpCode::Text)?;
154 /// assert_eq!(frame, vec![0b10000001, 5, 0, 1, 2, 3, 4]);
155 /// # Ok(())
156 /// # }
157 ///
158 /// ```
159 ///
160 /// A masked Frame
161 /// ```
162 /// # use proxmox_http::websocket::*;
163 /// # use std::io;
164 /// # fn main() -> Result<(), WebSocketError> {
165 /// let data = vec![0,1,2,3,4];
166 /// let frame = create_frame(Some([0u8, 1u8, 2u8, 3u8]), &data, OpCode::Text)?;
167 /// assert_eq!(frame, vec![0b10000001, 0b10000101, 0, 1, 2, 3, 0, 0, 0, 0, 4]);
168 /// # Ok(())
169 /// # }
170 ///
171 /// ```
172 ///
173 /// A ping Frame
174 /// ```
175 /// # use proxmox_http::websocket::*;
176 /// # use std::io;
177 /// # fn main() -> Result<(), WebSocketError> {
178 /// let data = vec![0,1,2,3,4];
179 /// let frame = create_frame(None, &data, OpCode::Ping)?;
180 /// assert_eq!(frame, vec![0b10001001, 0b00000101, 0, 1, 2, 3, 4]);
181 /// # Ok(())
182 /// # }
183 ///
184 /// ```
185 pub fn create_frame(
186 mask: Option<[u8; 4]>,
187 data: &[u8],
188 frametype: OpCode,
189 ) -> Result<Vec<u8>, WebSocketError> {
190 let first_byte = 0b10000000 | (frametype as u8);
191 let len = data.len();
192 if (frametype as u8) & 0b00001000 > 0 && len > 125 {
193 return Err(WebSocketError::new(
194 WebSocketErrorKind::Unexpected,
195 "Control frames cannot have data longer than 125 bytes",
196 ));
197 }
198
199 let mask_bit = if mask.is_some() {
200 0b10000000
201 } else {
202 0b00000000
203 };
204
205 let mut buf = vec![first_byte];
206
207 if len < 126 {
208 buf.push(mask_bit | (len as u8));
209 } else if len < u16::MAX as usize {
210 buf.push(mask_bit | 126);
211 buf.extend_from_slice(&(len as u16).to_be_bytes());
212 } else {
213 buf.push(mask_bit | 127);
214 buf.extend_from_slice(&(len as u64).to_be_bytes());
215 }
216
217 if let Some(mask) = mask {
218 buf.extend_from_slice(&mask);
219 }
220 let mut data = data.to_vec().into_boxed_slice();
221 mask_bytes(mask, &mut data);
222
223 buf.append(&mut data.into_vec());
224 Ok(buf)
225 }
226
227 /// Wraps a writer that implements AsyncWrite
228 ///
229 /// Can be used to send websocket frames to any writer that implements
230 /// AsyncWrite. Every write to it gets encoded as a seperate websocket frame,
231 /// without fragmentation.
232 ///
233 /// Example usage:
234 /// ```
235 /// # use proxmox_http::websocket::*;
236 /// # use std::io;
237 /// # use tokio::io::{AsyncWrite, AsyncWriteExt};
238 /// async fn code<I: AsyncWrite + Unpin>(writer: I) -> io::Result<()> {
239 /// let mut ws = WebSocketWriter::new(None, writer);
240 /// ws.write(&[1u8,2u8,3u8]).await?;
241 /// Ok(())
242 /// }
243 /// ```
244 pub struct WebSocketWriter<W: AsyncWrite + Unpin> {
245 writer: W,
246 mask: Option<[u8; 4]>,
247 frame: Option<(Vec<u8>, usize, usize)>,
248 }
249
250 impl<W: AsyncWrite + Unpin> WebSocketWriter<W> {
251 /// Creates a new WebSocketWriter which will use the given mask (if any),
252 /// and mark the frames as either 'Text' or 'Binary'
253 pub fn new(mask: Option<[u8; 4]>, writer: W) -> WebSocketWriter<W> {
254 WebSocketWriter {
255 writer,
256 mask,
257 frame: None,
258 }
259 }
260
261 pub async fn send_control_frame(
262 &mut self,
263 mask: Option<[u8; 4]>,
264 opcode: OpCode,
265 data: &[u8],
266 ) -> Result<(), Error> {
267 let frame = create_frame(mask, data, opcode).map_err(Error::from)?;
268 self.writer.write_all(&frame).await.map_err(Error::from)
269 }
270 }
271
272 impl<W: AsyncWrite + Unpin> AsyncWrite for WebSocketWriter<W> {
273 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
274 let this = Pin::get_mut(self);
275
276 let frametype = OpCode::Binary;
277
278 if this.frame.is_none() {
279 // create frame buf
280 let frame = match create_frame(this.mask, buf, frametype) {
281 Ok(f) => f,
282 Err(e) => {
283 return Poll::Ready(Err(io_err_other(e)));
284 }
285 };
286 this.frame = Some((frame, 0, buf.len()));
287 }
288
289 // we have a frame in any case, so unwrap is ok
290 let (buf, pos, origsize) = this.frame.as_mut().unwrap();
291 loop {
292 match ready!(Pin::new(&mut this.writer).poll_write(cx, &buf[*pos..])) {
293 Ok(size) => {
294 *pos += size;
295 if *pos == buf.len() {
296 let size = *origsize;
297 this.frame = None;
298 return Poll::Ready(Ok(size));
299 }
300 }
301 Err(err) => {
302 eprintln!("error in writer: {}", err);
303 return Poll::Ready(Err(err));
304 }
305 }
306 }
307 }
308
309 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
310 let this = Pin::get_mut(self);
311 Pin::new(&mut this.writer).poll_flush(cx)
312 }
313
314 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
315 let this = Pin::get_mut(self);
316 Pin::new(&mut this.writer).poll_shutdown(cx)
317 }
318 }
319
320 #[derive(Debug, PartialEq)]
321 /// Represents the header of a websocket Frame
322 pub struct FrameHeader {
323 /// True if the frame is either non-fragmented, or the last fragment
324 pub fin: bool,
325 /// The optional mask of the frame
326 pub mask: Option<[u8; 4]>,
327 /// The frametype
328 pub frametype: OpCode,
329 /// The length of the header (without payload).
330 pub header_len: u8,
331 /// The length of the payload.
332 pub payload_len: usize,
333 }
334
335 impl FrameHeader {
336 /// Returns true if the frame is a control frame.
337 pub fn is_control_frame(&self) -> bool {
338 self.frametype.is_control()
339 }
340
341 /// Tries to parse a FrameHeader from bytes.
342 ///
343 /// When there are not enough bytes to completely parse the header,
344 /// returns Ok(None)
345 ///
346 /// Example:
347 /// ```
348 /// # use proxmox_http::websocket::*;
349 /// # use std::io;
350 /// # fn main() -> Result<(), WebSocketError> {
351 /// let frame = create_frame(None, &[0,1,2,3], OpCode::Ping)?;
352 /// let header = FrameHeader::try_from_bytes(&frame[..1])?;
353 /// match header {
354 /// Some(_) => unreachable!(),
355 /// None => {},
356 /// }
357 /// let header = FrameHeader::try_from_bytes(&frame[..2])?;
358 /// match header {
359 /// None => unreachable!(),
360 /// Some(header) => assert_eq!(header, FrameHeader{
361 /// fin: true,
362 /// mask: None,
363 /// frametype: OpCode::Ping,
364 /// header_len: 2,
365 /// payload_len: 4,
366 /// }),
367 /// }
368 /// # Ok(())
369 /// # }
370 /// ```
371 pub fn try_from_bytes(data: &[u8]) -> Result<Option<FrameHeader>, WebSocketError> {
372 let len = data.len();
373 if len < 2 {
374 return Ok(None);
375 }
376
377 let data = data;
378
379 // we do not support extensions
380 if data[0] & 0b01110000 > 0 {
381 return Err(WebSocketError::new(
382 WebSocketErrorKind::ProtocolError,
383 "Extensions not supported",
384 ));
385 }
386
387 let fin = data[0] & 0b10000000 != 0;
388 let frametype = match data[0] & 0b1111 {
389 0 => OpCode::Continuation,
390 1 => OpCode::Text,
391 2 => OpCode::Binary,
392 8 => OpCode::Close,
393 9 => OpCode::Ping,
394 10 => OpCode::Pong,
395 other => {
396 return Err(WebSocketError::new(
397 WebSocketErrorKind::ProtocolError,
398 &format!("Unknown OpCode {}", other),
399 ));
400 }
401 };
402
403 if !fin && frametype.is_control() {
404 return Err(WebSocketError::new(
405 WebSocketErrorKind::ProtocolError,
406 "Control frames cannot be fragmented",
407 ));
408 }
409
410 let mask_bit = data[1] & 0b10000000 != 0;
411 let mut mask_offset = 2;
412 let mut payload_offset = 2;
413 if mask_bit {
414 payload_offset += 4;
415 }
416
417 let mut payload_len: usize = (data[1] & 0b01111111).into();
418 if payload_len == 126 {
419 if len < 4 {
420 return Ok(None);
421 }
422 payload_len = u16::from_be_bytes([data[2], data[3]]) as usize;
423 mask_offset += 2;
424 payload_offset += 2;
425 } else if payload_len == 127 {
426 if len < 10 {
427 return Ok(None);
428 }
429 payload_len = u64::from_be_bytes([
430 data[2], data[3], data[4], data[5], data[6], data[7], data[8], data[9],
431 ]) as usize;
432 mask_offset += 8;
433 payload_offset += 8;
434 }
435
436 if payload_len > 125 && frametype.is_control() {
437 return Err(WebSocketError::new(
438 WebSocketErrorKind::ProtocolError,
439 "Control frames cannot carry more than 125 bytes of data",
440 ));
441 }
442
443 let mask = if mask_bit {
444 if len < mask_offset + 4 {
445 return Ok(None);
446 }
447 let mut mask = [0u8; 4];
448 mask.copy_from_slice(&data[mask_offset as usize..payload_offset as usize]);
449 Some(mask)
450 } else {
451 None
452 };
453
454 Ok(Some(FrameHeader {
455 fin,
456 mask,
457 frametype,
458 payload_len,
459 header_len: payload_offset,
460 }))
461 }
462 }
463
464 type WebSocketReadResult = Result<(OpCode, Box<[u8]>), WebSocketError>;
465
466 /// Wraps a reader that implements AsyncRead and implements it itself.
467 ///
468 /// On read, reads the underlying reader and tries to decode the frames and
469 /// simply returns the data stream.
470 /// When it encounters a control frame, calls the given callback.
471 ///
472 /// Has an internal Buffer for storing incomplete headers.
473 pub struct WebSocketReader<R: AsyncRead> {
474 reader: Option<R>,
475 sender: mpsc::UnboundedSender<WebSocketReadResult>,
476 read_buffer: Option<ByteBuffer>,
477 header: Option<FrameHeader>,
478 state: ReaderState<R>,
479 }
480
481 impl<R: AsyncRead> WebSocketReader<R> {
482 /// Creates a new WebSocketReader with the given CallBack for control frames
483 /// and a default buffer size of 4096.
484 pub fn new(
485 reader: R,
486 sender: mpsc::UnboundedSender<WebSocketReadResult>,
487 ) -> WebSocketReader<R> {
488 Self::with_capacity(reader, 4096, sender)
489 }
490
491 pub fn with_capacity(
492 reader: R,
493 capacity: usize,
494 sender: mpsc::UnboundedSender<WebSocketReadResult>,
495 ) -> WebSocketReader<R> {
496 WebSocketReader {
497 reader: Some(reader),
498 sender,
499 read_buffer: Some(ByteBuffer::with_capacity(capacity)),
500 header: None,
501 state: ReaderState::NoData,
502 }
503 }
504 }
505
506 struct ReadResult<R> {
507 len: usize,
508 reader: R,
509 buffer: ByteBuffer,
510 }
511
512 enum ReaderState<R> {
513 NoData,
514 Receiving(Pin<Box<dyn Future<Output = io::Result<ReadResult<R>>> + Send + 'static>>),
515 HaveData,
516 }
517
518 unsafe impl<R: Sync> Sync for ReaderState<R> {}
519
520 impl<R: AsyncRead + Unpin + Send + 'static> AsyncRead for WebSocketReader<R> {
521 fn poll_read(
522 self: Pin<&mut Self>,
523 cx: &mut Context,
524 buf: &mut ReadBuf,
525 ) -> Poll<io::Result<()>> {
526 let this = Pin::get_mut(self);
527
528 loop {
529 match &mut this.state {
530 ReaderState::NoData => {
531 let mut reader = match this.reader.take() {
532 Some(reader) => reader,
533 None => return Poll::Ready(Err(io_err_other("no reader"))),
534 };
535
536 let mut buffer = match this.read_buffer.take() {
537 Some(buffer) => buffer,
538 None => return Poll::Ready(Err(io_err_other("no buffer"))),
539 };
540
541 let future = async move {
542 buffer
543 .read_from_async(&mut reader)
544 .await
545 .map(move |len| ReadResult {
546 len,
547 reader,
548 buffer,
549 })
550 };
551
552 this.state = ReaderState::Receiving(future.boxed());
553 }
554 ReaderState::Receiving(ref mut future) => match ready!(future.as_mut().poll(cx)) {
555 Ok(ReadResult {
556 len,
557 reader,
558 buffer,
559 }) => {
560 this.reader = Some(reader);
561 this.read_buffer = Some(buffer);
562 this.state = ReaderState::HaveData;
563 if len == 0 {
564 return Poll::Ready(Ok(()));
565 }
566 }
567 Err(err) => return Poll::Ready(Err(err)),
568 },
569 ReaderState::HaveData => {
570 let mut read_buffer = match this.read_buffer.take() {
571 Some(read_buffer) => read_buffer,
572 None => return Poll::Ready(Err(io_err_other("no buffer"))),
573 };
574
575 let mut header = match this.header.take() {
576 Some(header) => header,
577 None => {
578 let header = match FrameHeader::try_from_bytes(&read_buffer[..]) {
579 Ok(Some(header)) => header,
580 Ok(None) => {
581 this.state = ReaderState::NoData;
582 this.read_buffer = Some(read_buffer);
583 continue;
584 }
585 Err(err) => {
586 if let Err(err) = this.sender.send(Err(err.clone())) {
587 return Poll::Ready(Err(io_err_other(err)));
588 }
589 return Poll::Ready(Err(io_err_other(err)));
590 }
591 };
592
593 read_buffer.consume(header.header_len as usize);
594 header
595 }
596 };
597
598 if header.is_control_frame() {
599 if read_buffer.len() >= header.payload_len {
600 let mut data = read_buffer.remove_data(header.payload_len);
601 mask_bytes(header.mask, &mut data);
602 if let Err(err) = this.sender.send(Ok((header.frametype, data))) {
603 eprintln!("error sending control frame: {}", err);
604 }
605
606 this.state = if read_buffer.is_empty() {
607 ReaderState::NoData
608 } else {
609 ReaderState::HaveData
610 };
611 this.read_buffer = Some(read_buffer);
612 continue;
613 } else {
614 this.header = Some(header);
615 this.read_buffer = Some(read_buffer);
616 this.state = ReaderState::NoData;
617 continue;
618 }
619 }
620
621 let len = min(buf.remaining(), min(header.payload_len, read_buffer.len()));
622
623 let mut data = read_buffer.remove_data(len);
624 mask_bytes(header.mask, &mut data);
625 buf.put_slice(&data);
626
627 header.payload_len -= len;
628
629 if header.payload_len > 0 {
630 this.header = Some(header);
631 }
632
633 this.state = if read_buffer.is_empty() {
634 ReaderState::NoData
635 } else {
636 ReaderState::HaveData
637 };
638 this.read_buffer = Some(read_buffer);
639
640 if len > 0 {
641 return Poll::Ready(Ok(()));
642 }
643 }
644 }
645 }
646 }
647 }
648
/// The globally unique identifier from RFC6455 that is appended to the client's
/// `Sec-WebSocket-Key` when computing the `Sec-WebSocket-Accept` response.
pub const MAGIC_WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Entry point for accepting a WebSocket handshake and relaying the resulting
/// connection to another endpoint.
pub struct WebSocket;
654
655 impl WebSocket {
656 /// Returns a new WebSocket instance and the generates the correct
657 /// WebSocket response from request headers
658 pub fn new(headers: HeaderMap<HeaderValue>) -> Result<(Self, Response<Body>), Error> {
659 let protocols = headers
660 .get(UPGRADE)
661 .ok_or_else(|| format_err!("missing Upgrade header"))?
662 .to_str()?;
663
664 let version = headers
665 .get(SEC_WEBSOCKET_VERSION)
666 .ok_or_else(|| format_err!("missing websocket version"))?
667 .to_str()?;
668
669 let key = headers
670 .get(SEC_WEBSOCKET_KEY)
671 .ok_or_else(|| format_err!("missing websocket key"))?
672 .to_str()?;
673
674 if protocols != "websocket" {
675 bail!("invalid protocol name");
676 }
677
678 if version != "13" {
679 bail!("invalid websocket version");
680 }
681
682 // we ignore extensions
683
684 let mut sha1 = openssl::sha::Sha1::new();
685 let data = format!("{}{}", key, MAGIC_WEBSOCKET_GUID);
686 sha1.update(data.as_bytes());
687 let response_key = base64::encode(sha1.finish());
688
689 let mut response = Response::builder()
690 .status(StatusCode::SWITCHING_PROTOCOLS)
691 .header(UPGRADE, HeaderValue::from_static("websocket"))
692 .header(CONNECTION, HeaderValue::from_static("Upgrade"))
693 .header(SEC_WEBSOCKET_ACCEPT, response_key);
694
695 // FIXME: remove compat in PBS 3.x
696 //
697 // We currently do not support any subprotocols and we always send binary frames,
698 // but for backwards compatibilty we need to reply the requested protocols
699 if let Some(ws_proto) = headers.get(SEC_WEBSOCKET_PROTOCOL) {
700 response = response.header(SEC_WEBSOCKET_PROTOCOL, ws_proto)
701 }
702
703 let response = response.body(Body::empty())?;
704
705 Ok((Self, response))
706 }
707
708 async fn handle_channel_message<W>(
709 result: WebSocketReadResult,
710 writer: &mut WebSocketWriter<W>,
711 ) -> Result<OpCode, Error>
712 where
713 W: AsyncWrite + Unpin + Send,
714 {
715 match result {
716 Ok((OpCode::Ping, msg)) => {
717 writer.send_control_frame(None, OpCode::Pong, &msg).await?;
718 Ok(OpCode::Pong)
719 }
720 Ok((OpCode::Close, msg)) => {
721 writer.send_control_frame(None, OpCode::Close, &msg).await?;
722 Ok(OpCode::Close)
723 }
724 Ok((opcode, _)) => {
725 // ignore other frames
726 Ok(opcode)
727 }
728 Err(err) => {
729 writer
730 .send_control_frame(None, OpCode::Close, &err.generate_frame_payload())
731 .await?;
732 Err(Error::from(err))
733 }
734 }
735 }
736
737 async fn copy_to_websocket<R, W>(
738 mut reader: &mut R,
739 mut writer: &mut WebSocketWriter<W>,
740 receiver: &mut mpsc::UnboundedReceiver<WebSocketReadResult>,
741 ) -> Result<bool, Error>
742 where
743 R: AsyncRead + Unpin + Send,
744 W: AsyncWrite + Unpin + Send,
745 {
746 let mut buf = ByteBuffer::new();
747 let mut eof = false;
748 loop {
749 if !buf.is_full() {
750 let bytes = select! {
751 res = buf.read_from_async(&mut reader).fuse() => res?,
752 res = receiver.recv().fuse() => {
753 let res = res.ok_or_else(|| format_err!("control channel closed"))?;
754 match Self::handle_channel_message(res, &mut writer).await? {
755 OpCode::Close => return Ok(true),
756 _ => { continue; },
757 }
758 }
759 };
760
761 if bytes == 0 {
762 eof = true;
763 }
764 }
765 if buf.len() > 0 {
766 let bytes = writer.write(&buf).await?;
767 if bytes == 0 {
768 eof = true;
769 }
770 buf.consume(bytes);
771 }
772
773 if eof && buf.is_empty() {
774 return Ok(false);
775 }
776 }
777 }
778
779 /// Takes two endpoints and connects them via a websocket, where the
780 /// 'upstream' endpoint sends and receives WebSocket frames, while
781 /// 'downstream' only expects and sends raw data.
782 /// This method takes care of copying the data between endpoints, and
783 /// sending correct responses for control frames (e.g. a Pont to a Ping).
784 pub async fn serve_connection<S, L>(&self, upstream: S, downstream: L) -> Result<(), Error>
785 where
786 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
787 L: AsyncRead + AsyncWrite + Unpin + Send,
788 {
789 let (usreader, uswriter) = tokio::io::split(upstream);
790 let (mut dsreader, mut dswriter) = tokio::io::split(downstream);
791
792 let (tx, mut rx) = mpsc::unbounded_channel();
793 let mut wsreader = WebSocketReader::new(usreader, tx);
794 let mut wswriter = WebSocketWriter::new(None, uswriter);
795
796 let ws_future = tokio::io::copy(&mut wsreader, &mut dswriter);
797 let term_future = Self::copy_to_websocket(&mut dsreader, &mut wswriter, &mut rx);
798
799 let res = select! {
800 res = ws_future.fuse() => match res {
801 Ok(_) => Ok(()),
802 Err(err) => Err(Error::from(err)),
803 },
804 res = term_future.fuse() => match res {
805 Ok(sent_close) if !sent_close => {
806 // status code 1000 => 0x03E8
807 wswriter.send_control_frame(None, OpCode::Close, &WebSocketErrorKind::Normal.to_be_bytes()).await?;
808 Ok(())
809 }
810 Ok(_) => Ok(()),
811 Err(err) => Err(err),
812 }
813 };
814
815 res
816 }
817 }