3 //! Provides methods to read from and write to websockets. The reader and writer take a reader/writer
4 //! with AsyncRead/AsyncWrite respectively and provide the same
7 use std
::future
::Future
;
10 use std
::task
::{Context, Poll}
;
12 use anyhow
::{bail, format_err, Error}
;
15 HeaderMap
, HeaderValue
, CONNECTION
, SEC_WEBSOCKET_ACCEPT
, SEC_WEBSOCKET_KEY
,
16 SEC_WEBSOCKET_PROTOCOL
, SEC_WEBSOCKET_VERSION
, UPGRADE
,
18 use hyper
::{Body, Response, StatusCode}
;
19 use tokio
::io
::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}
;
20 use tokio
::sync
::mpsc
;
22 use futures
::future
::FutureExt
;
25 use crate::sys
::error
::io_err_other
;
26 use crate::tools
::byte_buffer
::ByteBuffer
;
28 // see RFC6455 section 7.4.1
29 #[derive(Debug, Clone, Copy)]
31 pub enum WebSocketErrorKind
{
39 impl WebSocketErrorKind
{
41 pub fn to_be_bytes(self) -> [u8; 2] {
42 (self as u16).to_be_bytes()
46 impl std
::fmt
::Display
for WebSocketErrorKind
{
47 fn fmt(&self, f
: &mut std
::fmt
::Formatter
) -> Result
<(), std
::fmt
::Error
> {
48 write
!(f
, "{}", *self as u16)
52 #[derive(Debug, Clone)]
53 pub struct WebSocketError
{
54 kind
: WebSocketErrorKind
,
59 pub fn new(kind
: WebSocketErrorKind
, message
: &str) -> Self {
62 message
: message
.to_string(),
66 pub fn generate_frame_payload(&self) -> Vec
<u8> {
67 let msglen
= self.message
.len().min(125);
68 let code
= self.kind
.to_be_bytes();
69 let mut data
= Vec
::with_capacity(msglen
+ 2);
70 data
.extend_from_slice(&code
);
71 data
.extend_from_slice(&self.message
.as_bytes()[..msglen
]);
76 impl std
::fmt
::Display
for WebSocketError
{
77 fn fmt(&self, f
: &mut std
::fmt
::Formatter
) -> Result
<(), std
::fmt
::Error
> {
78 write
!(f
, "{} (Code: {})", self.message
, self.kind
)
82 impl std
::error
::Error
for WebSocketError {}
85 #[derive(Debug, PartialEq, PartialOrd, Copy, Clone)]
86 /// Represents an OpCode of a websocket frame
88 /// A fragmented frame
90 /// A non-fragmented text frame
92 /// A non-fragmented binary frame
103 /// Tells whether it is a control frame or not
104 pub fn is_control(self) -> bool
{
105 (self as u8 & 0b1000) > 0
109 fn mask_bytes(mask
: Option
<[u8; 4]>, data
: &mut [u8]) {
110 let mask
= match mask
{
111 Some([0, 0, 0, 0]) | None
=> return,
116 for i
in 0..data
.len() {
117 data
[i
] ^
= mask
[i
% 4];
122 let mut newmask
: u32 = u32::from_le_bytes(mask
);
124 let (prefix
, middle
, suffix
) = unsafe { data.align_to_mut::<u32>() }
;
128 newmask
= newmask
.rotate_right(8);
137 newmask
= newmask
.rotate_right(8);
141 /// Can be used to create a complete WebSocket Frame.
143 /// Takes an optional mask, the data and the frame type
149 /// # use proxmox::tools::websocket::*;
151 /// # fn main() -> Result<(), WebSocketError> {
152 /// let data = vec![0,1,2,3,4];
153 /// let frame = create_frame(None, &data, OpCode::Text)?;
154 /// assert_eq!(frame, vec![0b10000001, 5, 0, 1, 2, 3, 4]);
162 /// # use proxmox::tools::websocket::*;
164 /// # fn main() -> Result<(), WebSocketError> {
165 /// let data = vec![0,1,2,3,4];
166 /// let frame = create_frame(Some([0u8, 1u8, 2u8, 3u8]), &data, OpCode::Text)?;
167 /// assert_eq!(frame, vec![0b10000001, 0b10000101, 0, 1, 2, 3, 0, 0, 0, 0, 4]);
175 /// # use proxmox::tools::websocket::*;
177 /// # fn main() -> Result<(), WebSocketError> {
178 /// let data = vec![0,1,2,3,4];
179 /// let frame = create_frame(None, &data, OpCode::Ping)?;
180 /// assert_eq!(frame, vec![0b10001001, 0b00000101, 0, 1, 2, 3, 4]);
186 mask
: Option
<[u8; 4]>,
189 ) -> Result
<Vec
<u8>, WebSocketError
> {
190 let first_byte
= 0b10000000 | (frametype
as u8);
191 let len
= data
.len();
192 if (frametype
as u8) & 0b00001000 > 0 && len
> 125 {
193 return Err(WebSocketError
::new(
194 WebSocketErrorKind
::Unexpected
,
195 "Control frames cannot have data longer than 125 bytes",
199 let mask_bit
= if mask
.is_some() {
205 let mut buf
= Vec
::new();
206 buf
.push(first_byte
);
209 buf
.push(mask_bit
| (len
as u8));
210 } else if len
< std
::u16::MAX
as usize {
211 buf
.push(mask_bit
| 126);
212 buf
.extend_from_slice(&(len
as u16).to_be_bytes());
214 buf
.push(mask_bit
| 127);
215 buf
.extend_from_slice(&(len
as u64).to_be_bytes());
218 if let Some(mask
) = mask
{
219 buf
.extend_from_slice(&mask
);
221 let mut data
= data
.to_vec().into_boxed_slice();
222 mask_bytes(mask
, &mut data
);
224 buf
.append(&mut data
.into_vec());
228 /// Wraps a writer that implements AsyncWrite
230 /// Can be used to send websocket frames to any writer that implements
231 /// AsyncWrite. Every write to it gets encoded as a separate websocket frame,
232 /// without fragmentation.
236 /// # use proxmox::tools::websocket::*;
238 /// # use tokio::io::{AsyncWrite, AsyncWriteExt};
239 /// async fn code<I: AsyncWrite + Unpin>(writer: I) -> io::Result<()> {
240 /// let mut ws = WebSocketWriter::new(None, false, writer);
241 /// ws.write(&[1u8,2u8,3u8]).await?;
245 pub struct WebSocketWriter
<W
: AsyncWrite
+ Unpin
> {
248 mask
: Option
<[u8; 4]>,
249 frame
: Option
<(Vec
<u8>, usize, usize)>,
252 impl<W
: AsyncWrite
+ Unpin
> WebSocketWriter
<W
> {
253 /// Creates a new WebSocketWriter which will use the given mask (if any),
254 /// and mark the frames as either 'Text' or 'Binary'
255 pub fn new(mask
: Option
<[u8; 4]>, text
: bool
, writer
: W
) -> WebSocketWriter
<W
> {
264 pub async
fn send_control_frame(
266 mask
: Option
<[u8; 4]>,
269 ) -> Result
<(), Error
> {
270 let frame
= create_frame(mask
, data
, opcode
).map_err(Error
::from
)?
;
271 self.writer
.write_all(&frame
).await
.map_err(Error
::from
)
275 impl<W
: AsyncWrite
+ Unpin
> AsyncWrite
for WebSocketWriter
<W
> {
276 fn poll_write(self: Pin
<&mut Self>, cx
: &mut Context
, buf
: &[u8]) -> Poll
<io
::Result
<usize>> {
277 let this
= Pin
::get_mut(self);
279 let frametype
= if this
.text
{
285 if this
.frame
.is_none() {
287 let frame
= match create_frame(this
.mask
, buf
, frametype
) {
290 return Poll
::Ready(Err(io_err_other(e
)));
293 this
.frame
= Some((frame
, 0, buf
.len()));
296 // we have a frame in any case, so unwrap is ok
297 let (buf
, pos
, origsize
) = this
.frame
.as_mut().unwrap();
299 match ready
!(Pin
::new(&mut this
.writer
).poll_write(cx
, &buf
[*pos
..])) {
302 if *pos
== buf
.len() {
303 let size
= *origsize
;
305 return Poll
::Ready(Ok(size
));
309 eprintln
!("error in writer: {}", err
);
310 return Poll
::Ready(Err(err
));
316 fn poll_flush(self: Pin
<&mut Self>, cx
: &mut Context
) -> Poll
<io
::Result
<()>> {
317 let this
= Pin
::get_mut(self);
318 Pin
::new(&mut this
.writer
).poll_flush(cx
)
321 fn poll_shutdown(self: Pin
<&mut Self>, cx
: &mut Context
) -> Poll
<io
::Result
<()>> {
322 let this
= Pin
::get_mut(self);
323 Pin
::new(&mut this
.writer
).poll_shutdown(cx
)
327 #[derive(Debug, PartialEq)]
328 /// Represents the header of a websocket Frame
329 pub struct FrameHeader
{
330 /// True if the frame is either non-fragmented, or the last fragment
332 /// The optional mask of the frame
333 pub mask
: Option
<[u8; 4]>,
335 pub frametype
: OpCode
,
336 /// The length of the header (without payload).
338 /// The length of the payload.
339 pub payload_len
: usize,
343 /// Returns true if the frame is a control frame.
344 pub fn is_control_frame(&self) -> bool
{
345 self.frametype
.is_control()
348 /// Tries to parse a FrameHeader from bytes.
350 /// When there are not enough bytes to completely parse the header,
355 /// # use proxmox::tools::websocket::*;
357 /// # fn main() -> Result<(), WebSocketError> {
358 /// let frame = create_frame(None, &[0,1,2,3], OpCode::Ping)?;
359 /// let header = FrameHeader::try_from_bytes(&frame[..1])?;
361 /// Some(_) => unreachable!(),
364 /// let header = FrameHeader::try_from_bytes(&frame[..2])?;
366 /// None => unreachable!(),
367 /// Some(header) => assert_eq!(header, FrameHeader{
370 /// frametype: OpCode::Ping,
378 pub fn try_from_bytes(data
: &[u8]) -> Result
<Option
<FrameHeader
>, WebSocketError
> {
379 let len
= data
.len();
386 // we do not support extensions
387 if data
[0] & 0b01110000 > 0 {
388 return Err(WebSocketError
::new(
389 WebSocketErrorKind
::ProtocolError
,
390 "Extensions not supported",
394 let fin
= data
[0] & 0b10000000 != 0;
395 let frametype
= match data
[0] & 0b1111 {
396 0 => OpCode
::Continuation
,
403 return Err(WebSocketError
::new(
404 WebSocketErrorKind
::ProtocolError
,
405 &format
!("Unknown OpCode {}", other
),
410 if !fin
&& frametype
.is_control() {
411 return Err(WebSocketError
::new(
412 WebSocketErrorKind
::ProtocolError
,
413 "Control frames cannot be fragmented",
417 let mask_bit
= data
[1] & 0b10000000 != 0;
418 let mut mask_offset
= 2;
419 let mut payload_offset
= 2;
424 let mut payload_len
: usize = (data
[1] & 0b01111111).into();
425 if payload_len
== 126 {
429 payload_len
= u16::from_be_bytes([data
[2], data
[3]]) as usize;
432 } else if payload_len
== 127 {
436 payload_len
= u64::from_be_bytes([
437 data
[2], data
[3], data
[4], data
[5], data
[6], data
[7], data
[8], data
[9],
443 if payload_len
> 125 && frametype
.is_control() {
444 return Err(WebSocketError
::new(
445 WebSocketErrorKind
::ProtocolError
,
446 "Control frames cannot carry more than 125 bytes of data",
450 let mask
= if mask_bit
{
451 if len
< mask_offset
+ 4 {
454 let mut mask
= [0u8; 4];
455 mask
.copy_from_slice(&data
[mask_offset
as usize..payload_offset
as usize]);
461 Ok(Some(FrameHeader
{
466 header_len
: payload_offset
,
471 type WebSocketReadResult
= Result
<(OpCode
, Box
<[u8]>), WebSocketError
>;
473 /// Wraps a reader that implements AsyncRead and implements it itself.
475 /// On read, reads the underlying reader and tries to decode the frames and
476 /// simply returns the data stream.
477 /// When it encounters a control frame, sends it over the given channel.
479 /// Has an internal Buffer for storing incomplete headers.
480 pub struct WebSocketReader
<R
: AsyncRead
> {
482 sender
: mpsc
::UnboundedSender
<WebSocketReadResult
>,
483 read_buffer
: Option
<ByteBuffer
>,
484 header
: Option
<FrameHeader
>,
485 state
: ReaderState
<R
>,
488 impl<R
: AsyncReadExt
> WebSocketReader
<R
> {
489 /// Creates a new WebSocketReader that sends control frames over the given channel,
490 /// with a default buffer size of 4096.
493 sender
: mpsc
::UnboundedSender
<WebSocketReadResult
>,
494 ) -> WebSocketReader
<R
> {
495 Self::with_capacity(reader
, 4096, sender
)
498 pub fn with_capacity(
501 sender
: mpsc
::UnboundedSender
<WebSocketReadResult
>,
502 ) -> WebSocketReader
<R
> {
504 reader
: Some(reader
),
506 read_buffer
: Some(ByteBuffer
::with_capacity(capacity
)),
508 state
: ReaderState
::NoData
,
513 struct ReadResult
<R
> {
519 enum ReaderState
<R
> {
521 Receiving(Pin
<Box
<dyn Future
<Output
= io
::Result
<ReadResult
<R
>>> + Send
+ '
static>>),
525 unsafe impl<R
: Sync
> Sync
for ReaderState
<R
> {}
527 impl<R
: AsyncReadExt
+ Unpin
+ Send
+ '
static> AsyncRead
for WebSocketReader
<R
> {
529 self: Pin
<&mut Self>,
532 ) -> Poll
<io
::Result
<usize>> {
533 let this
= Pin
::get_mut(self);
537 match &mut this
.state
{
538 ReaderState
::NoData
=> {
539 let mut reader
= match this
.reader
.take() {
540 Some(reader
) => reader
,
541 None
=> return Poll
::Ready(Err(io_err_other("no reader"))),
544 let mut buffer
= match this
.read_buffer
.take() {
545 Some(buffer
) => buffer
,
546 None
=> return Poll
::Ready(Err(io_err_other("no buffer"))),
549 let future
= async
move {
551 .read_from_async(&mut reader
)
553 .map(move |len
| ReadResult
{
560 this
.state
= ReaderState
::Receiving(future
.boxed());
562 ReaderState
::Receiving(ref mut future
) => match ready
!(future
.as_mut().poll(cx
)) {
568 this
.reader
= Some(reader
);
569 this
.read_buffer
= Some(buffer
);
570 this
.state
= ReaderState
::HaveData
;
572 return Poll
::Ready(Ok(0));
575 Err(err
) => return Poll
::Ready(Err(err
)),
577 ReaderState
::HaveData
=> {
578 let mut read_buffer
= match this
.read_buffer
.take() {
579 Some(read_buffer
) => read_buffer
,
580 None
=> return Poll
::Ready(Err(io_err_other("no buffer"))),
583 let mut header
= match this
.header
.take() {
584 Some(header
) => header
,
586 let header
= match FrameHeader
::try_from_bytes(&read_buffer
[..]) {
587 Ok(Some(header
)) => header
,
589 this
.state
= ReaderState
::NoData
;
590 this
.read_buffer
= Some(read_buffer
);
594 if let Err(err
) = this
.sender
.send(Err(err
.clone())) {
595 return Poll
::Ready(Err(io_err_other(err
)));
597 return Poll
::Ready(Err(io_err_other(err
)));
601 read_buffer
.consume(header
.header_len
as usize);
606 if header
.is_control_frame() {
607 if read_buffer
.len() >= header
.payload_len
{
608 let mut data
= read_buffer
.remove_data(header
.payload_len
);
609 mask_bytes(header
.mask
, &mut data
);
610 if let Err(err
) = this
.sender
.send(Ok((header
.frametype
, data
))) {
611 eprintln
!("error sending control frame: {}", err
);
614 this
.state
= if read_buffer
.is_empty() {
617 ReaderState
::HaveData
619 this
.read_buffer
= Some(read_buffer
);
622 this
.header
= Some(header
);
623 this
.read_buffer
= Some(read_buffer
);
624 this
.state
= ReaderState
::NoData
;
631 min(header
.payload_len
, read_buffer
.len()),
634 let mut data
= read_buffer
.remove_data(len
);
635 mask_bytes(header
.mask
, &mut data
);
636 buf
[offset
..offset
+ len
].copy_from_slice(&data
);
639 header
.payload_len
-= len
;
641 if header
.payload_len
> 0 {
642 this
.header
= Some(header
);
645 this
.state
= if read_buffer
.is_empty() {
648 ReaderState
::HaveData
650 this
.read_buffer
= Some(read_buffer
);
653 return Poll
::Ready(Ok(offset
));
/// Fixed GUID defined by RFC 6455. During the opening handshake it is appended
/// to the client's `Sec-WebSocket-Key`; the SHA-1 of that string, base64-encoded,
/// becomes the `Sec-WebSocket-Accept` response header.
pub const MAGIC_WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
664 /// Provides methods for connecting a WebSocket endpoint with another
665 pub struct WebSocket
{
670 /// Returns a new WebSocket instance and generates the correct
671 /// WebSocket response from request headers
672 pub fn new(headers
: HeaderMap
<HeaderValue
>) -> Result
<(Self, Response
<Body
>), Error
> {
673 let protocols
= headers
675 .ok_or_else(|| format_err
!("missing Upgrade header"))?
678 let version
= headers
679 .get(SEC_WEBSOCKET_VERSION
)
680 .ok_or_else(|| format_err
!("missing websocket version"))?
684 .get(SEC_WEBSOCKET_KEY
)
685 .ok_or_else(|| format_err
!("missing websocket key"))?
688 let ws_proto
= headers
689 .get(SEC_WEBSOCKET_PROTOCOL
)
690 .ok_or_else(|| format_err
!("missing websocket key"))?
693 let text
= ws_proto
== "text";
695 if protocols
!= "websocket" {
696 bail
!("invalid protocol name");
700 bail
!("invalid websocket version");
703 // we ignore extensions
705 let mut sha1
= openssl
::sha
::Sha1
::new();
706 let data
= format
!("{}{}", key
, MAGIC_WEBSOCKET_GUID
);
707 sha1
.update(data
.as_bytes());
708 let response_key
= base64
::encode(sha1
.finish());
710 let response
= Response
::builder()
711 .status(StatusCode
::SWITCHING_PROTOCOLS
)
712 .header(UPGRADE
, HeaderValue
::from_static("websocket"))
713 .header(CONNECTION
, HeaderValue
::from_static("Upgrade"))
714 .header(SEC_WEBSOCKET_ACCEPT
, response_key
)
715 .header(SEC_WEBSOCKET_PROTOCOL
, ws_proto
)
716 .body(Body
::empty())?
;
718 Ok((Self { text }
, response
))
721 async
fn handle_channel_message
<W
>(
722 result
: WebSocketReadResult
,
723 writer
: &mut WebSocketWriter
<W
>,
724 ) -> Result
<OpCode
, Error
>
726 W
: AsyncWrite
+ Unpin
+ Send
,
729 Ok((OpCode
::Ping
, msg
)) => {
730 writer
.send_control_frame(None
, OpCode
::Pong
, &msg
).await?
;
733 Ok((OpCode
::Close
, msg
)) => {
734 writer
.send_control_frame(None
, OpCode
::Close
, &msg
).await?
;
738 // ignore other frames
743 .send_control_frame(None
, OpCode
::Close
, &err
.generate_frame_payload())
745 Err(Error
::from(err
))
750 async
fn copy_to_websocket
<R
, W
>(
752 mut writer
: &mut WebSocketWriter
<W
>,
753 receiver
: &mut mpsc
::UnboundedReceiver
<WebSocketReadResult
>,
754 ) -> Result
<bool
, Error
>
756 R
: AsyncRead
+ Unpin
+ Send
,
757 W
: AsyncWrite
+ Unpin
+ Send
,
759 let mut buf
= ByteBuffer
::new();
763 let bytes
= select
! {
764 res
= buf
.read_from_async(&mut reader
).fuse() => res?
,
765 res
= receiver
.recv().fuse() => {
766 let res
= res
.ok_or_else(|| format_err
!("control channel closed"))?
;
767 match Self::handle_channel_message(res
, &mut writer
).await?
{
768 OpCode
::Close
=> return Ok(true),
779 let bytes
= writer
.write(&buf
).await?
;
786 if eof
&& buf
.is_empty() {
792 /// Takes two endpoints and connects them via a websocket, where the
793 /// 'upstream' endpoint sends and receives WebSocket frames, while
794 /// 'downstream' only expects and sends raw data.
795 /// This method takes care of copying the data between endpoints, and
796 /// sending correct responses for control frames (e.g. a Pong to a Ping).
797 pub async
fn serve_connection
<S
, L
>(&self, upstream
: S
, downstream
: L
) -> Result
<(), Error
>
799 S
: AsyncRead
+ AsyncWrite
+ Unpin
+ Send
+ '
static,
800 L
: AsyncRead
+ AsyncWrite
+ Unpin
+ Send
,
802 let (usreader
, uswriter
) = tokio
::io
::split(upstream
);
803 let (mut dsreader
, mut dswriter
) = tokio
::io
::split(downstream
);
805 let (tx
, mut rx
) = mpsc
::unbounded_channel();
806 let mut wsreader
= WebSocketReader
::new(usreader
, tx
);
807 let mut wswriter
= WebSocketWriter
::new(None
, self.text
, uswriter
);
809 let ws_future
= tokio
::io
::copy(&mut wsreader
, &mut dswriter
);
810 let term_future
= Self::copy_to_websocket(&mut dsreader
, &mut wswriter
, &mut rx
);
813 res
= ws_future
.fuse() => match res
{
815 Err(err
) => Err(Error
::from(err
)),
817 res
= term_future
.fuse() => match res
{
818 Ok(sent_close
) if !sent_close
=> {
819 // status code 1000 => 0x03E8
820 wswriter
.send_control_frame(None
, OpCode
::Close
, &WebSocketErrorKind
::Normal
.to_be_bytes()).await?
;
824 Err(err
) => Err(err
),