]> git.proxmox.com Git - proxmox.git/blob - proxmox/src/tools/websocket.rs
metrics: bump version to 0.3.1-1
[proxmox.git] / proxmox / src / tools / websocket.rs
1 //! Websocket helpers
2 //!
//! Provides methods to read from and write to websockets. The reader and writer wrap a
//! reader/writer implementing AsyncRead/AsyncWrite respectively, and implement the same trait
//! themselves.
5
6 use std::cmp::min;
7 use std::future::Future;
8 use std::io;
9 use std::pin::Pin;
10 use std::task::{Context, Poll};
11
12 use anyhow::{bail, format_err, Error};
13 use futures::select;
14 use hyper::header::{
15 HeaderMap, HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY,
16 SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE,
17 };
18 use hyper::{Body, Response, StatusCode};
19 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
20 use tokio::sync::mpsc;
21
22 use futures::future::FutureExt;
23 use futures::ready;
24
25 use crate::sys::error::io_err_other;
26 use crate::tools::byte_buffer::ByteBuffer;
27
// see RFC6455 section 7.4.1
/// Status codes sent in the payload of websocket Close frames.
#[derive(Debug, Clone, Copy)]
#[repr(u16)]
pub enum WebSocketErrorKind {
    /// 1000: normal closure
    Normal = 1000,
    /// 1002: the connection is terminated due to a protocol error
    ProtocolError = 1002,
    /// 1003: received a type of data that cannot be accepted
    InvalidData = 1003,
    /// 1008: generic status code (policy violation)
    Other = 1008,
    /// 1011: an unexpected condition prevented fulfilling the request
    Unexpected = 1011,
}

impl WebSocketErrorKind {
    /// Returns the status code in network byte order, as used in a Close frame payload.
    #[inline]
    pub fn to_be_bytes(self) -> [u8; 2] {
        (self as u16).to_be_bytes()
    }
}

impl std::fmt::Display for WebSocketErrorKind {
    /// Formats the numeric status code, e.g. `1000`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
        write!(f, "{}", *self as u16)
    }
}

/// An error which can be reported to the peer in a websocket Close frame.
#[derive(Debug, Clone)]
pub struct WebSocketError {
    kind: WebSocketErrorKind,
    message: String,
}

impl WebSocketError {
    /// Creates a new error with the given status code and human readable message.
    pub fn new(kind: WebSocketErrorKind, message: &str) -> Self {
        Self {
            kind,
            message: message.to_string(),
        }
    }

    /// Builds the payload of a Close frame for this error.
    ///
    /// The payload is the 2-byte status code (big endian) followed by the message as the close
    /// reason. Control frames may carry at most 125 bytes (RFC 6455 section 5.5), so the reason
    /// is truncated to 123 bytes. Truncation never splits a UTF-8 sequence, because the
    /// reason must be valid UTF-8.
    pub fn generate_frame_payload(&self) -> Vec<u8> {
        // 125 bytes maximum control frame payload minus the 2-byte status code
        const MAX_REASON_LEN: usize = 125 - 2;

        let mut msglen = self.message.len().min(MAX_REASON_LEN);
        // index 0 is always a char boundary, so this terminates
        while !self.message.is_char_boundary(msglen) {
            msglen -= 1;
        }

        let mut data = Vec::with_capacity(msglen + 2);
        data.extend_from_slice(&self.kind.to_be_bytes());
        data.extend_from_slice(&self.message.as_bytes()[..msglen]);
        data
    }
}

impl std::fmt::Display for WebSocketError {
    /// Formats as `<message> (Code: <status code>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
        write!(f, "{} (Code: {})", self.message, self.kind)
    }
}

impl std::error::Error for WebSocketError {}
83
/// Represents an OpCode of a websocket frame
///
/// The discriminants are the 4-bit opcode values used on the wire (RFC 6455 section 5.2).
#[derive(Debug, PartialEq, PartialOrd, Copy, Clone)]
#[repr(u8)]
pub enum OpCode {
    /// Continuation of a fragmented message
    Continuation = 0x0,
    /// A text frame
    Text = 0x1,
    /// A binary frame
    Binary = 0x2,
    /// A closing frame
    Close = 0x8,
    /// A ping frame
    Ping = 0x9,
    /// A pong frame
    Pong = 0xA,
}

impl OpCode {
    /// Tells whether it is a control frame or not.
    ///
    /// Control opcodes are exactly the ones with the highest opcode bit (`0x8`) set.
    pub fn is_control(self) -> bool {
        const CONTROL_BIT: u8 = 0b1000;
        (self as u8) & CONTROL_BIT == CONTROL_BIT
    }
}
108
/// XORs `data` in place with a 4-byte websocket masking key (RFC 6455 section 5.3).
///
/// Byte `i` of `data` is combined with `mask[i % 4]`, so the key phase always starts at the
/// first byte of `data`. XOR is its own inverse, so the same call masks and unmasks.
/// `None` and the all-zero key leave `data` untouched.
fn mask_bytes(mask: Option<[u8; 4]>, data: &mut [u8]) {
    let mask = match mask {
        Some([0, 0, 0, 0]) | None => return,
        Some(mask) => mask,
    };

    // for short inputs the word-wise path below is not worth its setup
    if data.len() < 32 {
        for (i, byte) in data.iter_mut().enumerate() {
            *byte ^= mask[i % 4];
        }
        return;
    }

    // `rolling` holds the key in its current phase. Its lowest byte is the key byte for the
    // next data byte, and rotating right by 8 bits advances the phase by one byte.
    let mut rolling = u32::from_le_bytes(mask);

    // SAFETY: every bit pattern is a valid `u8` and a valid `u32`, so viewing the aligned
    // middle part of the byte slice as `u32` words is sound.
    let (prefix, middle, suffix) = unsafe { data.align_to_mut::<u32>() };

    for byte in prefix {
        *byte ^= rolling as u8;
        rolling = rolling.rotate_right(8);
    }

    // The words are loaded in native byte order, so the key bytes must be laid out in memory
    // in the same order (this matters on big-endian hosts). Whole words keep the phase unchanged.
    let word_mask = u32::from_ne_bytes(rolling.to_le_bytes());
    for word in middle {
        *word ^= word_mask;
    }

    for byte in suffix {
        *byte ^= rolling as u8;
        rolling = rolling.rotate_right(8);
    }
}
140
141 /// Can be used to create a complete WebSocket Frame.
142 ///
143 /// Takes an optional mask, the data and the frame type
144 ///
145 /// Examples:
146 ///
147 /// A normal Frame
148 /// ```
149 /// # use proxmox::tools::websocket::*;
150 /// # use std::io;
151 /// # fn main() -> Result<(), WebSocketError> {
152 /// let data = vec![0,1,2,3,4];
153 /// let frame = create_frame(None, &data, OpCode::Text)?;
154 /// assert_eq!(frame, vec![0b10000001, 5, 0, 1, 2, 3, 4]);
155 /// # Ok(())
156 /// # }
157 ///
158 /// ```
159 ///
160 /// A masked Frame
161 /// ```
162 /// # use proxmox::tools::websocket::*;
163 /// # use std::io;
164 /// # fn main() -> Result<(), WebSocketError> {
165 /// let data = vec![0,1,2,3,4];
166 /// let frame = create_frame(Some([0u8, 1u8, 2u8, 3u8]), &data, OpCode::Text)?;
167 /// assert_eq!(frame, vec![0b10000001, 0b10000101, 0, 1, 2, 3, 0, 0, 0, 0, 4]);
168 /// # Ok(())
169 /// # }
170 ///
171 /// ```
172 ///
173 /// A ping Frame
174 /// ```
175 /// # use proxmox::tools::websocket::*;
176 /// # use std::io;
177 /// # fn main() -> Result<(), WebSocketError> {
178 /// let data = vec![0,1,2,3,4];
179 /// let frame = create_frame(None, &data, OpCode::Ping)?;
180 /// assert_eq!(frame, vec![0b10001001, 0b00000101, 0, 1, 2, 3, 4]);
181 /// # Ok(())
182 /// # }
183 ///
184 /// ```
185 pub fn create_frame(
186 mask: Option<[u8; 4]>,
187 data: &[u8],
188 frametype: OpCode,
189 ) -> Result<Vec<u8>, WebSocketError> {
190 let first_byte = 0b10000000 | (frametype as u8);
191 let len = data.len();
192 if (frametype as u8) & 0b00001000 > 0 && len > 125 {
193 return Err(WebSocketError::new(
194 WebSocketErrorKind::Unexpected,
195 "Control frames cannot have data longer than 125 bytes",
196 ));
197 }
198
199 let mask_bit = if mask.is_some() {
200 0b10000000
201 } else {
202 0b00000000
203 };
204
205 let mut buf = Vec::new();
206 buf.push(first_byte);
207
208 if len < 126 {
209 buf.push(mask_bit | (len as u8));
210 } else if len < std::u16::MAX as usize {
211 buf.push(mask_bit | 126);
212 buf.extend_from_slice(&(len as u16).to_be_bytes());
213 } else {
214 buf.push(mask_bit | 127);
215 buf.extend_from_slice(&(len as u64).to_be_bytes());
216 }
217
218 if let Some(mask) = mask {
219 buf.extend_from_slice(&mask);
220 }
221 let mut data = data.to_vec().into_boxed_slice();
222 mask_bytes(mask, &mut data);
223
224 buf.append(&mut data.into_vec());
225 Ok(buf)
226 }
227
228 /// Wraps a writer that implements AsyncWrite
229 ///
230 /// Can be used to send websocket frames to any writer that implements
231 /// AsyncWrite. Every write to it gets encoded as a seperate websocket frame,
232 /// without fragmentation.
233 ///
234 /// Example usage:
235 /// ```
236 /// # use proxmox::tools::websocket::*;
237 /// # use std::io;
238 /// # use tokio::io::{AsyncWrite, AsyncWriteExt};
239 /// async fn code<I: AsyncWrite + Unpin>(writer: I) -> io::Result<()> {
240 /// let mut ws = WebSocketWriter::new(None, false, writer);
241 /// ws.write(&[1u8,2u8,3u8]).await?;
242 /// Ok(())
243 /// }
244 /// ```
245 pub struct WebSocketWriter<W: AsyncWrite + Unpin> {
246 writer: W,
247 text: bool,
248 mask: Option<[u8; 4]>,
249 frame: Option<(Vec<u8>, usize, usize)>,
250 }
251
252 impl<W: AsyncWrite + Unpin> WebSocketWriter<W> {
253 /// Creates a new WebSocketWriter which will use the given mask (if any),
254 /// and mark the frames as either 'Text' or 'Binary'
255 pub fn new(mask: Option<[u8; 4]>, text: bool, writer: W) -> WebSocketWriter<W> {
256 WebSocketWriter {
257 writer,
258 text,
259 mask,
260 frame: None,
261 }
262 }
263
264 pub async fn send_control_frame(
265 &mut self,
266 mask: Option<[u8; 4]>,
267 opcode: OpCode,
268 data: &[u8],
269 ) -> Result<(), Error> {
270 let frame = create_frame(mask, data, opcode).map_err(Error::from)?;
271 self.writer.write_all(&frame).await.map_err(Error::from)
272 }
273 }
274
275 impl<W: AsyncWrite + Unpin> AsyncWrite for WebSocketWriter<W> {
276 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
277 let this = Pin::get_mut(self);
278
279 let frametype = if this.text {
280 OpCode::Text
281 } else {
282 OpCode::Binary
283 };
284
285 if this.frame.is_none() {
286 // create frame buf
287 let frame = match create_frame(this.mask, buf, frametype) {
288 Ok(f) => f,
289 Err(e) => {
290 return Poll::Ready(Err(io_err_other(e)));
291 }
292 };
293 this.frame = Some((frame, 0, buf.len()));
294 }
295
296 // we have a frame in any case, so unwrap is ok
297 let (buf, pos, origsize) = this.frame.as_mut().unwrap();
298 loop {
299 match ready!(Pin::new(&mut this.writer).poll_write(cx, &buf[*pos..])) {
300 Ok(size) => {
301 *pos += size;
302 if *pos == buf.len() {
303 let size = *origsize;
304 this.frame = None;
305 return Poll::Ready(Ok(size));
306 }
307 }
308 Err(err) => {
309 eprintln!("error in writer: {}", err);
310 return Poll::Ready(Err(err));
311 }
312 }
313 }
314 }
315
316 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
317 let this = Pin::get_mut(self);
318 Pin::new(&mut this.writer).poll_flush(cx)
319 }
320
321 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
322 let this = Pin::get_mut(self);
323 Pin::new(&mut this.writer).poll_shutdown(cx)
324 }
325 }
326
327 #[derive(Debug, PartialEq)]
328 /// Represents the header of a websocket Frame
329 pub struct FrameHeader {
330 /// True if the frame is either non-fragmented, or the last fragment
331 pub fin: bool,
332 /// The optional mask of the frame
333 pub mask: Option<[u8; 4]>,
334 /// The frametype
335 pub frametype: OpCode,
336 /// The length of the header (without payload).
337 pub header_len: u8,
338 /// The length of the payload.
339 pub payload_len: usize,
340 }
341
342 impl FrameHeader {
343 /// Returns true if the frame is a control frame.
344 pub fn is_control_frame(&self) -> bool {
345 self.frametype.is_control()
346 }
347
348 /// Tries to parse a FrameHeader from bytes.
349 ///
350 /// When there are not enough bytes to completely parse the header,
351 /// returns Ok(None)
352 ///
353 /// Example:
354 /// ```
355 /// # use proxmox::tools::websocket::*;
356 /// # use std::io;
357 /// # fn main() -> Result<(), WebSocketError> {
358 /// let frame = create_frame(None, &[0,1,2,3], OpCode::Ping)?;
359 /// let header = FrameHeader::try_from_bytes(&frame[..1])?;
360 /// match header {
361 /// Some(_) => unreachable!(),
362 /// None => {},
363 /// }
364 /// let header = FrameHeader::try_from_bytes(&frame[..2])?;
365 /// match header {
366 /// None => unreachable!(),
367 /// Some(header) => assert_eq!(header, FrameHeader{
368 /// fin: true,
369 /// mask: None,
370 /// frametype: OpCode::Ping,
371 /// header_len: 2,
372 /// payload_len: 4,
373 /// }),
374 /// }
375 /// # Ok(())
376 /// # }
377 /// ```
378 pub fn try_from_bytes(data: &[u8]) -> Result<Option<FrameHeader>, WebSocketError> {
379 let len = data.len();
380 if len < 2 {
381 return Ok(None);
382 }
383
384 let data = data;
385
386 // we do not support extensions
387 if data[0] & 0b01110000 > 0 {
388 return Err(WebSocketError::new(
389 WebSocketErrorKind::ProtocolError,
390 "Extensions not supported",
391 ));
392 }
393
394 let fin = data[0] & 0b10000000 != 0;
395 let frametype = match data[0] & 0b1111 {
396 0 => OpCode::Continuation,
397 1 => OpCode::Text,
398 2 => OpCode::Binary,
399 8 => OpCode::Close,
400 9 => OpCode::Ping,
401 10 => OpCode::Pong,
402 other => {
403 return Err(WebSocketError::new(
404 WebSocketErrorKind::ProtocolError,
405 &format!("Unknown OpCode {}", other),
406 ));
407 }
408 };
409
410 if !fin && frametype.is_control() {
411 return Err(WebSocketError::new(
412 WebSocketErrorKind::ProtocolError,
413 "Control frames cannot be fragmented",
414 ));
415 }
416
417 let mask_bit = data[1] & 0b10000000 != 0;
418 let mut mask_offset = 2;
419 let mut payload_offset = 2;
420 if mask_bit {
421 payload_offset += 4;
422 }
423
424 let mut payload_len: usize = (data[1] & 0b01111111).into();
425 if payload_len == 126 {
426 if len < 4 {
427 return Ok(None);
428 }
429 payload_len = u16::from_be_bytes([data[2], data[3]]) as usize;
430 mask_offset += 2;
431 payload_offset += 2;
432 } else if payload_len == 127 {
433 if len < 10 {
434 return Ok(None);
435 }
436 payload_len = u64::from_be_bytes([
437 data[2], data[3], data[4], data[5], data[6], data[7], data[8], data[9],
438 ]) as usize;
439 mask_offset += 8;
440 payload_offset += 8;
441 }
442
443 if payload_len > 125 && frametype.is_control() {
444 return Err(WebSocketError::new(
445 WebSocketErrorKind::ProtocolError,
446 "Control frames cannot carry more than 125 bytes of data",
447 ));
448 }
449
450 let mask = if mask_bit {
451 if len < mask_offset + 4 {
452 return Ok(None);
453 }
454 let mut mask = [0u8; 4];
455 mask.copy_from_slice(&data[mask_offset as usize..payload_offset as usize]);
456 Some(mask)
457 } else {
458 None
459 };
460
461 Ok(Some(FrameHeader {
462 fin,
463 mask,
464 frametype,
465 payload_len,
466 header_len: payload_offset,
467 }))
468 }
469 }
470
471 type WebSocketReadResult = Result<(OpCode, Box<[u8]>), WebSocketError>;
472
473 /// Wraps a reader that implements AsyncRead and implements it itself.
474 ///
475 /// On read, reads the underlying reader and tries to decode the frames and
476 /// simply returns the data stream.
477 /// When it encounters a control frame, calls the given callback.
478 ///
479 /// Has an internal Buffer for storing incomplete headers.
480 pub struct WebSocketReader<R: AsyncRead> {
481 reader: Option<R>,
482 sender: mpsc::UnboundedSender<WebSocketReadResult>,
483 read_buffer: Option<ByteBuffer>,
484 header: Option<FrameHeader>,
485 state: ReaderState<R>,
486 }
487
488 impl<R: AsyncReadExt> WebSocketReader<R> {
489 /// Creates a new WebSocketReader with the given CallBack for control frames
490 /// and a default buffer size of 4096.
491 pub fn new(
492 reader: R,
493 sender: mpsc::UnboundedSender<WebSocketReadResult>,
494 ) -> WebSocketReader<R> {
495 Self::with_capacity(reader, 4096, sender)
496 }
497
498 pub fn with_capacity(
499 reader: R,
500 capacity: usize,
501 sender: mpsc::UnboundedSender<WebSocketReadResult>,
502 ) -> WebSocketReader<R> {
503 WebSocketReader {
504 reader: Some(reader),
505 sender,
506 read_buffer: Some(ByteBuffer::with_capacity(capacity)),
507 header: None,
508 state: ReaderState::NoData,
509 }
510 }
511 }
512
513 struct ReadResult<R> {
514 len: usize,
515 reader: R,
516 buffer: ByteBuffer,
517 }
518
519 enum ReaderState<R> {
520 NoData,
521 Receiving(Pin<Box<dyn Future<Output = io::Result<ReadResult<R>>> + Send + 'static>>),
522 HaveData,
523 }
524
525 unsafe impl<R: Sync> Sync for ReaderState<R> {}
526
527 impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R> {
528 fn poll_read(
529 self: Pin<&mut Self>,
530 cx: &mut Context,
531 buf: &mut [u8],
532 ) -> Poll<io::Result<usize>> {
533 let this = Pin::get_mut(self);
534 let mut offset = 0;
535
536 loop {
537 match &mut this.state {
538 ReaderState::NoData => {
539 let mut reader = match this.reader.take() {
540 Some(reader) => reader,
541 None => return Poll::Ready(Err(io_err_other("no reader"))),
542 };
543
544 let mut buffer = match this.read_buffer.take() {
545 Some(buffer) => buffer,
546 None => return Poll::Ready(Err(io_err_other("no buffer"))),
547 };
548
549 let future = async move {
550 buffer
551 .read_from_async(&mut reader)
552 .await
553 .map(move |len| ReadResult {
554 len,
555 reader,
556 buffer,
557 })
558 };
559
560 this.state = ReaderState::Receiving(future.boxed());
561 }
562 ReaderState::Receiving(ref mut future) => match ready!(future.as_mut().poll(cx)) {
563 Ok(ReadResult {
564 len,
565 reader,
566 buffer,
567 }) => {
568 this.reader = Some(reader);
569 this.read_buffer = Some(buffer);
570 this.state = ReaderState::HaveData;
571 if len == 0 {
572 return Poll::Ready(Ok(0));
573 }
574 }
575 Err(err) => return Poll::Ready(Err(err)),
576 },
577 ReaderState::HaveData => {
578 let mut read_buffer = match this.read_buffer.take() {
579 Some(read_buffer) => read_buffer,
580 None => return Poll::Ready(Err(io_err_other("no buffer"))),
581 };
582
583 let mut header = match this.header.take() {
584 Some(header) => header,
585 None => {
586 let header = match FrameHeader::try_from_bytes(&read_buffer[..]) {
587 Ok(Some(header)) => header,
588 Ok(None) => {
589 this.state = ReaderState::NoData;
590 this.read_buffer = Some(read_buffer);
591 continue;
592 }
593 Err(err) => {
594 if let Err(err) = this.sender.send(Err(err.clone())) {
595 return Poll::Ready(Err(io_err_other(err)));
596 }
597 return Poll::Ready(Err(io_err_other(err)));
598 }
599 };
600
601 read_buffer.consume(header.header_len as usize);
602 header
603 }
604 };
605
606 if header.is_control_frame() {
607 if read_buffer.len() >= header.payload_len {
608 let mut data = read_buffer.remove_data(header.payload_len);
609 mask_bytes(header.mask, &mut data);
610 if let Err(err) = this.sender.send(Ok((header.frametype, data))) {
611 eprintln!("error sending control frame: {}", err);
612 }
613
614 this.state = if read_buffer.is_empty() {
615 ReaderState::NoData
616 } else {
617 ReaderState::HaveData
618 };
619 this.read_buffer = Some(read_buffer);
620 continue;
621 } else {
622 this.header = Some(header);
623 this.read_buffer = Some(read_buffer);
624 this.state = ReaderState::NoData;
625 continue;
626 }
627 }
628
629 let len = min(
630 buf.len() - offset,
631 min(header.payload_len, read_buffer.len()),
632 );
633
634 let mut data = read_buffer.remove_data(len);
635 mask_bytes(header.mask, &mut data);
636 buf[offset..offset + len].copy_from_slice(&data);
637 offset += len;
638
639 header.payload_len -= len;
640
641 if header.payload_len > 0 {
642 this.header = Some(header);
643 }
644
645 this.state = if read_buffer.is_empty() {
646 ReaderState::NoData
647 } else {
648 ReaderState::HaveData
649 };
650 this.read_buffer = Some(read_buffer);
651
652 if offset > 0 {
653 return Poll::Ready(Ok(offset));
654 }
655 }
656 }
657 }
658 }
659 }
660
/// Global Identifier for WebSockets, see RFC6455
///
/// Appended to the client's `Sec-WebSocket-Key` before hashing, to compute the
/// `Sec-WebSocket-Accept` response header.
pub const MAGIC_WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Provides methods for connecting a WebSocket endpoint with another
pub struct WebSocket {
    /// Send downstream data as text frames (`true`) or binary frames (`false`).
    text: bool,
}
668
669 impl WebSocket {
670 /// Returns a new WebSocket instance and the generates the correct
671 /// WebSocket response from request headers
672 pub fn new(headers: HeaderMap<HeaderValue>) -> Result<(Self, Response<Body>), Error> {
673 let protocols = headers
674 .get(UPGRADE)
675 .ok_or_else(|| format_err!("missing Upgrade header"))?
676 .to_str()?;
677
678 let version = headers
679 .get(SEC_WEBSOCKET_VERSION)
680 .ok_or_else(|| format_err!("missing websocket version"))?
681 .to_str()?;
682
683 let key = headers
684 .get(SEC_WEBSOCKET_KEY)
685 .ok_or_else(|| format_err!("missing websocket key"))?
686 .to_str()?;
687
688 let ws_proto = headers
689 .get(SEC_WEBSOCKET_PROTOCOL)
690 .ok_or_else(|| format_err!("missing websocket key"))?
691 .to_str()?;
692
693 let text = ws_proto == "text";
694
695 if protocols != "websocket" {
696 bail!("invalid protocol name");
697 }
698
699 if version != "13" {
700 bail!("invalid websocket version");
701 }
702
703 // we ignore extensions
704
705 let mut sha1 = openssl::sha::Sha1::new();
706 let data = format!("{}{}", key, MAGIC_WEBSOCKET_GUID);
707 sha1.update(data.as_bytes());
708 let response_key = base64::encode(sha1.finish());
709
710 let response = Response::builder()
711 .status(StatusCode::SWITCHING_PROTOCOLS)
712 .header(UPGRADE, HeaderValue::from_static("websocket"))
713 .header(CONNECTION, HeaderValue::from_static("Upgrade"))
714 .header(SEC_WEBSOCKET_ACCEPT, response_key)
715 .header(SEC_WEBSOCKET_PROTOCOL, ws_proto)
716 .body(Body::empty())?;
717
718 Ok((Self { text }, response))
719 }
720
721 async fn handle_channel_message<W>(
722 result: WebSocketReadResult,
723 writer: &mut WebSocketWriter<W>,
724 ) -> Result<OpCode, Error>
725 where
726 W: AsyncWrite + Unpin + Send,
727 {
728 match result {
729 Ok((OpCode::Ping, msg)) => {
730 writer.send_control_frame(None, OpCode::Pong, &msg).await?;
731 Ok(OpCode::Pong)
732 }
733 Ok((OpCode::Close, msg)) => {
734 writer.send_control_frame(None, OpCode::Close, &msg).await?;
735 Ok(OpCode::Close)
736 }
737 Ok((opcode, _)) => {
738 // ignore other frames
739 Ok(opcode)
740 }
741 Err(err) => {
742 writer
743 .send_control_frame(None, OpCode::Close, &err.generate_frame_payload())
744 .await?;
745 Err(Error::from(err))
746 }
747 }
748 }
749
750 async fn copy_to_websocket<R, W>(
751 mut reader: &mut R,
752 mut writer: &mut WebSocketWriter<W>,
753 receiver: &mut mpsc::UnboundedReceiver<WebSocketReadResult>,
754 ) -> Result<bool, Error>
755 where
756 R: AsyncRead + Unpin + Send,
757 W: AsyncWrite + Unpin + Send,
758 {
759 let mut buf = ByteBuffer::new();
760 let mut eof = false;
761 loop {
762 if !buf.is_full() {
763 let bytes = select! {
764 res = buf.read_from_async(&mut reader).fuse() => res?,
765 res = receiver.recv().fuse() => {
766 let res = res.ok_or_else(|| format_err!("control channel closed"))?;
767 match Self::handle_channel_message(res, &mut writer).await? {
768 OpCode::Close => return Ok(true),
769 _ => { continue; },
770 }
771 }
772 };
773
774 if bytes == 0 {
775 eof = true;
776 }
777 }
778 if buf.len() > 0 {
779 let bytes = writer.write(&buf).await?;
780 if bytes == 0 {
781 eof = true;
782 }
783 buf.consume(bytes);
784 }
785
786 if eof && buf.is_empty() {
787 return Ok(false);
788 }
789 }
790 }
791
792 /// Takes two endpoints and connects them via a websocket, where the
793 /// 'upstream' endpoint sends and receives WebSocket frames, while
794 /// 'downstream' only expects and sends raw data.
795 /// This method takes care of copying the data between endpoints, and
796 /// sending correct responses for control frames (e.g. a Pont to a Ping).
797 pub async fn serve_connection<S, L>(&self, upstream: S, downstream: L) -> Result<(), Error>
798 where
799 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
800 L: AsyncRead + AsyncWrite + Unpin + Send,
801 {
802 let (usreader, uswriter) = tokio::io::split(upstream);
803 let (mut dsreader, mut dswriter) = tokio::io::split(downstream);
804
805 let (tx, mut rx) = mpsc::unbounded_channel();
806 let mut wsreader = WebSocketReader::new(usreader, tx);
807 let mut wswriter = WebSocketWriter::new(None, self.text, uswriter);
808
809 let ws_future = tokio::io::copy(&mut wsreader, &mut dswriter);
810 let term_future = Self::copy_to_websocket(&mut dsreader, &mut wswriter, &mut rx);
811
812 let res = select! {
813 res = ws_future.fuse() => match res {
814 Ok(_) => Ok(()),
815 Err(err) => Err(Error::from(err)),
816 },
817 res = term_future.fuse() => match res {
818 Ok(sent_close) if !sent_close => {
819 // status code 1000 => 0x03E8
820 wswriter.send_control_frame(None, OpCode::Close, &WebSocketErrorKind::Normal.to_be_bytes()).await?;
821 Ok(())
822 }
823 Ok(_) => Ok(()),
824 Err(err) => Err(err),
825 }
826 };
827
828 res
829 }
830 }