3 //! Provides methods to read and write from websockets The reader and writer take a reader/writer
4 //! with AsyncRead/AsyncWrite respectively and provides the same
7 use std
::future
::Future
;
10 use std
::task
::{Context, Poll}
;
12 use anyhow
::{bail, format_err, Error}
;
15 HeaderMap
, HeaderValue
, CONNECTION
, SEC_WEBSOCKET_ACCEPT
, SEC_WEBSOCKET_KEY
,
16 SEC_WEBSOCKET_PROTOCOL
, SEC_WEBSOCKET_VERSION
, UPGRADE
,
18 use hyper
::{Body, Response, StatusCode}
;
19 use tokio
::io
::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}
;
20 use tokio
::sync
::mpsc
;
22 use futures
::future
::FutureExt
;
25 use proxmox
::sys
::error
::io_err_other
;
26 use proxmox
::tools
::byte_buffer
::ByteBuffer
;
28 // see RFC6455 section 7.4.1
29 #[derive(Debug, Clone, Copy)]
31 pub enum WebSocketErrorKind
{
39 impl WebSocketErrorKind
{
41 pub fn to_be_bytes(self) -> [u8; 2] {
42 (self as u16).to_be_bytes()
46 impl std
::fmt
::Display
for WebSocketErrorKind
{
47 fn fmt(&self, f
: &mut std
::fmt
::Formatter
) -> Result
<(), std
::fmt
::Error
> {
48 write
!(f
, "{}", *self as u16)
52 #[derive(Debug, Clone)]
53 pub struct WebSocketError
{
54 kind
: WebSocketErrorKind
,
59 pub fn new(kind
: WebSocketErrorKind
, message
: &str) -> Self {
62 message
: message
.to_string(),
66 pub fn generate_frame_payload(&self) -> Vec
<u8> {
67 let msglen
= self.message
.len().min(125);
68 let code
= self.kind
.to_be_bytes();
69 let mut data
= Vec
::with_capacity(msglen
+ 2);
70 data
.extend_from_slice(&code
);
71 data
.extend_from_slice(&self.message
.as_bytes()[..msglen
]);
76 impl std
::fmt
::Display
for WebSocketError
{
77 fn fmt(&self, f
: &mut std
::fmt
::Formatter
) -> Result
<(), std
::fmt
::Error
> {
78 write
!(f
, "{} (Code: {})", self.message
, self.kind
)
82 impl std
::error
::Error
for WebSocketError {}
85 #[derive(Debug, PartialEq, PartialOrd, Copy, Clone)]
86 /// Represents an OpCode of a websocket frame
88 /// A fragmented frame
90 /// A non-fragmented text frame
92 /// A non-fragmented binary frame
103 /// Tells whether it is a control frame or not
104 pub fn is_control(self) -> bool
{
105 (self as u8 & 0b1000) > 0
109 fn mask_bytes(mask
: Option
<[u8; 4]>, data
: &mut [u8]) {
110 let mask
= match mask
{
111 Some([0, 0, 0, 0]) | None
=> return,
116 for i
in 0..data
.len() {
117 data
[i
] ^
= mask
[i
% 4];
122 let mut newmask
: u32 = u32::from_le_bytes(mask
);
124 let (prefix
, middle
, suffix
) = unsafe { data.align_to_mut::<u32>() }
;
128 newmask
= newmask
.rotate_right(8);
137 newmask
= newmask
.rotate_right(8);
141 /// Can be used to create a complete WebSocket Frame.
143 /// Takes an optional mask, the data and the frame type
149 /// # use proxmox_http::websocket::*;
151 /// # fn main() -> Result<(), WebSocketError> {
152 /// let data = vec![0,1,2,3,4];
153 /// let frame = create_frame(None, &data, OpCode::Text)?;
154 /// assert_eq!(frame, vec![0b10000001, 5, 0, 1, 2, 3, 4]);
162 /// # use proxmox_http::websocket::*;
164 /// # fn main() -> Result<(), WebSocketError> {
165 /// let data = vec![0,1,2,3,4];
166 /// let frame = create_frame(Some([0u8, 1u8, 2u8, 3u8]), &data, OpCode::Text)?;
167 /// assert_eq!(frame, vec![0b10000001, 0b10000101, 0, 1, 2, 3, 0, 0, 0, 0, 4]);
175 /// # use proxmox_http::websocket::*;
177 /// # fn main() -> Result<(), WebSocketError> {
178 /// let data = vec![0,1,2,3,4];
179 /// let frame = create_frame(None, &data, OpCode::Ping)?;
180 /// assert_eq!(frame, vec![0b10001001, 0b00000101, 0, 1, 2, 3, 4]);
186 mask
: Option
<[u8; 4]>,
189 ) -> Result
<Vec
<u8>, WebSocketError
> {
190 let first_byte
= 0b10000000 | (frametype
as u8);
191 let len
= data
.len();
192 if (frametype
as u8) & 0b00001000 > 0 && len
> 125 {
193 return Err(WebSocketError
::new(
194 WebSocketErrorKind
::Unexpected
,
195 "Control frames cannot have data longer than 125 bytes",
199 let mask_bit
= if mask
.is_some() {
205 let mut buf
= vec
![first_byte
];
208 buf
.push(mask_bit
| (len
as u8));
209 } else if len
< u16::MAX
as usize {
210 buf
.push(mask_bit
| 126);
211 buf
.extend_from_slice(&(len
as u16).to_be_bytes());
213 buf
.push(mask_bit
| 127);
214 buf
.extend_from_slice(&(len
as u64).to_be_bytes());
217 if let Some(mask
) = mask
{
218 buf
.extend_from_slice(&mask
);
220 let mut data
= data
.to_vec().into_boxed_slice();
221 mask_bytes(mask
, &mut data
);
223 buf
.append(&mut data
.into_vec());
227 /// Wraps a writer that implements AsyncWrite
229 /// Can be used to send websocket frames to any writer that implements
230 /// AsyncWrite. Every write to it gets encoded as a seperate websocket frame,
231 /// without fragmentation.
235 /// # use proxmox_http::websocket::*;
237 /// # use tokio::io::{AsyncWrite, AsyncWriteExt};
238 /// async fn code<I: AsyncWrite + Unpin>(writer: I) -> io::Result<()> {
239 /// let mut ws = WebSocketWriter::new(None, writer);
240 /// ws.write(&[1u8,2u8,3u8]).await?;
244 pub struct WebSocketWriter
<W
: AsyncWrite
+ Unpin
> {
246 mask
: Option
<[u8; 4]>,
247 frame
: Option
<(Vec
<u8>, usize, usize)>,
250 impl<W
: AsyncWrite
+ Unpin
> WebSocketWriter
<W
> {
251 /// Creates a new WebSocketWriter which will use the given mask (if any),
252 /// and mark the frames as either 'Text' or 'Binary'
253 pub fn new(mask
: Option
<[u8; 4]>, writer
: W
) -> WebSocketWriter
<W
> {
261 pub async
fn send_control_frame(
263 mask
: Option
<[u8; 4]>,
266 ) -> Result
<(), Error
> {
267 let frame
= create_frame(mask
, data
, opcode
).map_err(Error
::from
)?
;
268 self.writer
.write_all(&frame
).await
.map_err(Error
::from
)
272 impl<W
: AsyncWrite
+ Unpin
> AsyncWrite
for WebSocketWriter
<W
> {
273 fn poll_write(self: Pin
<&mut Self>, cx
: &mut Context
, buf
: &[u8]) -> Poll
<io
::Result
<usize>> {
274 let this
= Pin
::get_mut(self);
276 let frametype
= OpCode
::Binary
;
278 if this
.frame
.is_none() {
280 let frame
= match create_frame(this
.mask
, buf
, frametype
) {
283 return Poll
::Ready(Err(io_err_other(e
)));
286 this
.frame
= Some((frame
, 0, buf
.len()));
289 // we have a frame in any case, so unwrap is ok
290 let (buf
, pos
, origsize
) = this
.frame
.as_mut().unwrap();
292 match ready
!(Pin
::new(&mut this
.writer
).poll_write(cx
, &buf
[*pos
..])) {
295 if *pos
== buf
.len() {
296 let size
= *origsize
;
298 return Poll
::Ready(Ok(size
));
302 eprintln
!("error in writer: {}", err
);
303 return Poll
::Ready(Err(err
));
309 fn poll_flush(self: Pin
<&mut Self>, cx
: &mut Context
) -> Poll
<io
::Result
<()>> {
310 let this
= Pin
::get_mut(self);
311 Pin
::new(&mut this
.writer
).poll_flush(cx
)
314 fn poll_shutdown(self: Pin
<&mut Self>, cx
: &mut Context
) -> Poll
<io
::Result
<()>> {
315 let this
= Pin
::get_mut(self);
316 Pin
::new(&mut this
.writer
).poll_shutdown(cx
)
320 #[derive(Debug, PartialEq)]
321 /// Represents the header of a websocket Frame
322 pub struct FrameHeader
{
323 /// True if the frame is either non-fragmented, or the last fragment
325 /// The optional mask of the frame
326 pub mask
: Option
<[u8; 4]>,
328 pub frametype
: OpCode
,
329 /// The length of the header (without payload).
331 /// The length of the payload.
332 pub payload_len
: usize,
336 /// Returns true if the frame is a control frame.
337 pub fn is_control_frame(&self) -> bool
{
338 self.frametype
.is_control()
341 /// Tries to parse a FrameHeader from bytes.
343 /// When there are not enough bytes to completely parse the header,
348 /// # use proxmox_http::websocket::*;
350 /// # fn main() -> Result<(), WebSocketError> {
351 /// let frame = create_frame(None, &[0,1,2,3], OpCode::Ping)?;
352 /// let header = FrameHeader::try_from_bytes(&frame[..1])?;
354 /// Some(_) => unreachable!(),
357 /// let header = FrameHeader::try_from_bytes(&frame[..2])?;
359 /// None => unreachable!(),
360 /// Some(header) => assert_eq!(header, FrameHeader{
363 /// frametype: OpCode::Ping,
371 pub fn try_from_bytes(data
: &[u8]) -> Result
<Option
<FrameHeader
>, WebSocketError
> {
372 let len
= data
.len();
379 // we do not support extensions
380 if data
[0] & 0b01110000 > 0 {
381 return Err(WebSocketError
::new(
382 WebSocketErrorKind
::ProtocolError
,
383 "Extensions not supported",
387 let fin
= data
[0] & 0b10000000 != 0;
388 let frametype
= match data
[0] & 0b1111 {
389 0 => OpCode
::Continuation
,
396 return Err(WebSocketError
::new(
397 WebSocketErrorKind
::ProtocolError
,
398 &format
!("Unknown OpCode {}", other
),
403 if !fin
&& frametype
.is_control() {
404 return Err(WebSocketError
::new(
405 WebSocketErrorKind
::ProtocolError
,
406 "Control frames cannot be fragmented",
410 let mask_bit
= data
[1] & 0b10000000 != 0;
411 let mut mask_offset
= 2;
412 let mut payload_offset
= 2;
417 let mut payload_len
: usize = (data
[1] & 0b01111111).into();
418 if payload_len
== 126 {
422 payload_len
= u16::from_be_bytes([data
[2], data
[3]]) as usize;
425 } else if payload_len
== 127 {
429 payload_len
= u64::from_be_bytes([
430 data
[2], data
[3], data
[4], data
[5], data
[6], data
[7], data
[8], data
[9],
436 if payload_len
> 125 && frametype
.is_control() {
437 return Err(WebSocketError
::new(
438 WebSocketErrorKind
::ProtocolError
,
439 "Control frames cannot carry more than 125 bytes of data",
443 let mask
= if mask_bit
{
444 if len
< mask_offset
+ 4 {
447 let mut mask
= [0u8; 4];
448 mask
.copy_from_slice(&data
[mask_offset
as usize..payload_offset
as usize]);
454 Ok(Some(FrameHeader
{
459 header_len
: payload_offset
,
464 type WebSocketReadResult
= Result
<(OpCode
, Box
<[u8]>), WebSocketError
>;
466 /// Wraps a reader that implements AsyncRead and implements it itself.
468 /// On read, reads the underlying reader and tries to decode the frames and
469 /// simply returns the data stream.
470 /// When it encounters a control frame, calls the given callback.
472 /// Has an internal Buffer for storing incomplete headers.
473 pub struct WebSocketReader
<R
: AsyncRead
> {
475 sender
: mpsc
::UnboundedSender
<WebSocketReadResult
>,
476 read_buffer
: Option
<ByteBuffer
>,
477 header
: Option
<FrameHeader
>,
478 state
: ReaderState
<R
>,
481 impl<R
: AsyncRead
> WebSocketReader
<R
> {
482 /// Creates a new WebSocketReader with the given CallBack for control frames
483 /// and a default buffer size of 4096.
486 sender
: mpsc
::UnboundedSender
<WebSocketReadResult
>,
487 ) -> WebSocketReader
<R
> {
488 Self::with_capacity(reader
, 4096, sender
)
491 pub fn with_capacity(
494 sender
: mpsc
::UnboundedSender
<WebSocketReadResult
>,
495 ) -> WebSocketReader
<R
> {
497 reader
: Some(reader
),
499 read_buffer
: Some(ByteBuffer
::with_capacity(capacity
)),
501 state
: ReaderState
::NoData
,
506 struct ReadResult
<R
> {
512 enum ReaderState
<R
> {
514 Receiving(Pin
<Box
<dyn Future
<Output
= io
::Result
<ReadResult
<R
>>> + Send
+ '
static>>),
518 unsafe impl<R
: Sync
> Sync
for ReaderState
<R
> {}
520 impl<R
: AsyncRead
+ Unpin
+ Send
+ '
static> AsyncRead
for WebSocketReader
<R
> {
522 self: Pin
<&mut Self>,
525 ) -> Poll
<io
::Result
<()>> {
526 let this
= Pin
::get_mut(self);
529 match &mut this
.state
{
530 ReaderState
::NoData
=> {
531 let mut reader
= match this
.reader
.take() {
532 Some(reader
) => reader
,
533 None
=> return Poll
::Ready(Err(io_err_other("no reader"))),
536 let mut buffer
= match this
.read_buffer
.take() {
537 Some(buffer
) => buffer
,
538 None
=> return Poll
::Ready(Err(io_err_other("no buffer"))),
541 let future
= async
move {
543 .read_from_async(&mut reader
)
545 .map(move |len
| ReadResult
{
552 this
.state
= ReaderState
::Receiving(future
.boxed());
554 ReaderState
::Receiving(ref mut future
) => match ready
!(future
.as_mut().poll(cx
)) {
560 this
.reader
= Some(reader
);
561 this
.read_buffer
= Some(buffer
);
562 this
.state
= ReaderState
::HaveData
;
564 return Poll
::Ready(Ok(()));
567 Err(err
) => return Poll
::Ready(Err(err
)),
569 ReaderState
::HaveData
=> {
570 let mut read_buffer
= match this
.read_buffer
.take() {
571 Some(read_buffer
) => read_buffer
,
572 None
=> return Poll
::Ready(Err(io_err_other("no buffer"))),
575 let mut header
= match this
.header
.take() {
576 Some(header
) => header
,
578 let header
= match FrameHeader
::try_from_bytes(&read_buffer
[..]) {
579 Ok(Some(header
)) => header
,
581 this
.state
= ReaderState
::NoData
;
582 this
.read_buffer
= Some(read_buffer
);
586 if let Err(err
) = this
.sender
.send(Err(err
.clone())) {
587 return Poll
::Ready(Err(io_err_other(err
)));
589 return Poll
::Ready(Err(io_err_other(err
)));
593 read_buffer
.consume(header
.header_len
as usize);
598 if header
.is_control_frame() {
599 if read_buffer
.len() >= header
.payload_len
{
600 let mut data
= read_buffer
.remove_data(header
.payload_len
);
601 mask_bytes(header
.mask
, &mut data
);
602 if let Err(err
) = this
.sender
.send(Ok((header
.frametype
, data
))) {
603 eprintln
!("error sending control frame: {}", err
);
606 this
.state
= if read_buffer
.is_empty() {
609 ReaderState
::HaveData
611 this
.read_buffer
= Some(read_buffer
);
614 this
.header
= Some(header
);
615 this
.read_buffer
= Some(read_buffer
);
616 this
.state
= ReaderState
::NoData
;
621 let len
= min(buf
.remaining(), min(header
.payload_len
, read_buffer
.len()));
623 let mut data
= read_buffer
.remove_data(len
);
624 mask_bytes(header
.mask
, &mut data
);
625 buf
.put_slice(&data
);
627 header
.payload_len
-= len
;
629 if header
.payload_len
> 0 {
630 this
.header
= Some(header
);
633 this
.state
= if read_buffer
.is_empty() {
636 ReaderState
::HaveData
638 this
.read_buffer
= Some(read_buffer
);
641 return Poll
::Ready(Ok(()));
649 /// Global Identifier for WebSockets, see RFC6455
650 pub const MAGIC_WEBSOCKET_GUID
: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
652 /// Provides methods for connecting a WebSocket endpoint with another
653 pub struct WebSocket
;
656 /// Returns a new WebSocket instance and the generates the correct
657 /// WebSocket response from request headers
658 pub fn new(headers
: HeaderMap
<HeaderValue
>) -> Result
<(Self, Response
<Body
>), Error
> {
659 let protocols
= headers
661 .ok_or_else(|| format_err
!("missing Upgrade header"))?
664 let version
= headers
665 .get(SEC_WEBSOCKET_VERSION
)
666 .ok_or_else(|| format_err
!("missing websocket version"))?
670 .get(SEC_WEBSOCKET_KEY
)
671 .ok_or_else(|| format_err
!("missing websocket key"))?
674 if protocols
!= "websocket" {
675 bail
!("invalid protocol name");
679 bail
!("invalid websocket version");
682 // we ignore extensions
684 let mut sha1
= openssl
::sha
::Sha1
::new();
685 let data
= format
!("{}{}", key
, MAGIC_WEBSOCKET_GUID
);
686 sha1
.update(data
.as_bytes());
687 let response_key
= base64
::encode(sha1
.finish());
689 let mut response
= Response
::builder()
690 .status(StatusCode
::SWITCHING_PROTOCOLS
)
691 .header(UPGRADE
, HeaderValue
::from_static("websocket"))
692 .header(CONNECTION
, HeaderValue
::from_static("Upgrade"))
693 .header(SEC_WEBSOCKET_ACCEPT
, response_key
);
695 // FIXME: remove compat in PBS 3.x
697 // We currently do not support any subprotocols and we always send binary frames,
698 // but for backwards compatibilty we need to reply the requested protocols
699 if let Some(ws_proto
) = headers
.get(SEC_WEBSOCKET_PROTOCOL
) {
700 response
= response
.header(SEC_WEBSOCKET_PROTOCOL
, ws_proto
)
703 let response
= response
.body(Body
::empty())?
;
708 async
fn handle_channel_message
<W
>(
709 result
: WebSocketReadResult
,
710 writer
: &mut WebSocketWriter
<W
>,
711 ) -> Result
<OpCode
, Error
>
713 W
: AsyncWrite
+ Unpin
+ Send
,
716 Ok((OpCode
::Ping
, msg
)) => {
717 writer
.send_control_frame(None
, OpCode
::Pong
, &msg
).await?
;
720 Ok((OpCode
::Close
, msg
)) => {
721 writer
.send_control_frame(None
, OpCode
::Close
, &msg
).await?
;
725 // ignore other frames
730 .send_control_frame(None
, OpCode
::Close
, &err
.generate_frame_payload())
732 Err(Error
::from(err
))
737 async
fn copy_to_websocket
<R
, W
>(
739 mut writer
: &mut WebSocketWriter
<W
>,
740 receiver
: &mut mpsc
::UnboundedReceiver
<WebSocketReadResult
>,
741 ) -> Result
<bool
, Error
>
743 R
: AsyncRead
+ Unpin
+ Send
,
744 W
: AsyncWrite
+ Unpin
+ Send
,
746 let mut buf
= ByteBuffer
::new();
750 let bytes
= select
! {
751 res
= buf
.read_from_async(&mut reader
).fuse() => res?
,
752 res
= receiver
.recv().fuse() => {
753 let res
= res
.ok_or_else(|| format_err
!("control channel closed"))?
;
754 match Self::handle_channel_message(res
, &mut writer
).await?
{
755 OpCode
::Close
=> return Ok(true),
766 let bytes
= writer
.write(&buf
).await?
;
773 if eof
&& buf
.is_empty() {
779 /// Takes two endpoints and connects them via a websocket, where the
780 /// 'upstream' endpoint sends and receives WebSocket frames, while
781 /// 'downstream' only expects and sends raw data.
782 /// This method takes care of copying the data between endpoints, and
783 /// sending correct responses for control frames (e.g. a Pont to a Ping).
784 pub async
fn serve_connection
<S
, L
>(&self, upstream
: S
, downstream
: L
) -> Result
<(), Error
>
786 S
: AsyncRead
+ AsyncWrite
+ Unpin
+ Send
+ '
static,
787 L
: AsyncRead
+ AsyncWrite
+ Unpin
+ Send
,
789 let (usreader
, uswriter
) = tokio
::io
::split(upstream
);
790 let (mut dsreader
, mut dswriter
) = tokio
::io
::split(downstream
);
792 let (tx
, mut rx
) = mpsc
::unbounded_channel();
793 let mut wsreader
= WebSocketReader
::new(usreader
, tx
);
794 let mut wswriter
= WebSocketWriter
::new(None
, uswriter
);
796 let ws_future
= tokio
::io
::copy(&mut wsreader
, &mut dswriter
);
797 let term_future
= Self::copy_to_websocket(&mut dsreader
, &mut wswriter
, &mut rx
);
800 res
= ws_future
.fuse() => match res
{
802 Err(err
) => Err(Error
::from(err
)),
804 res
= term_future
.fuse() => match res
{
805 Ok(sent_close
) if !sent_close
=> {
806 // status code 1000 => 0x03E8
807 wswriter
.send_control_frame(None
, OpCode
::Close
, &WebSocketErrorKind
::Normal
.to_be_bytes()).await?
;
811 Err(err
) => Err(err
),