]> git.proxmox.com Git - proxmox-backup.git/blame - src/tools/async_io.rs
d/control: add ',' after qrencode dependency
[proxmox-backup.git] / src / tools / async_io.rs
CommitLineData
556eb70e
WB
1//! Generic AsyncRead/AsyncWrite utilities.
2
3use std::io;
db0cb9ce
WB
4use std::mem::MaybeUninit;
5use std::os::unix::io::{AsRawFd, RawFd};
556eb70e
WB
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
db0cb9ce 9use futures::stream::{Stream, TryStream};
556eb70e 10use tokio::io::{AsyncRead, AsyncWrite};
db0cb9ce
WB
11use tokio::net::TcpListener;
12use hyper::client::connect::Connection;
556eb70e
WB
13
/// A stream that is one of two concrete types (for example plain TCP or TLS over TCP).
///
/// The `AsyncRead`/`AsyncWrite` implementations forward every call to whichever variant is
/// currently held.
pub enum EitherStream<L, R> {
    /// The first alternative.
    Left(L),
    /// The second alternative.
    Right(R),
}
18
19impl<L: AsyncRead, R: AsyncRead> AsyncRead for EitherStream<L, R> {
20 fn poll_read(
21 self: Pin<&mut Self>,
22 cx: &mut Context,
23 buf: &mut [u8],
24 ) -> Poll<Result<usize, io::Error>> {
25 match unsafe { self.get_unchecked_mut() } {
26 EitherStream::Left(ref mut s) => {
27 unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf)
28 }
29 EitherStream::Right(ref mut s) => {
30 unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf)
31 }
32 }
33 }
34
db0cb9ce 35 unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
556eb70e
WB
36 match *self {
37 EitherStream::Left(ref s) => s.prepare_uninitialized_buffer(buf),
38 EitherStream::Right(ref s) => s.prepare_uninitialized_buffer(buf),
39 }
40 }
41
42 fn poll_read_buf<B>(
43 self: Pin<&mut Self>,
44 cx: &mut Context,
45 buf: &mut B,
46 ) -> Poll<Result<usize, io::Error>>
47 where
48 B: bytes::BufMut,
49 {
50 match unsafe { self.get_unchecked_mut() } {
51 EitherStream::Left(ref mut s) => {
52 unsafe { Pin::new_unchecked(s) }.poll_read_buf(cx, buf)
53 }
54 EitherStream::Right(ref mut s) => {
55 unsafe { Pin::new_unchecked(s) }.poll_read_buf(cx, buf)
56 }
57 }
58 }
59}
60
61impl<L: AsyncWrite, R: AsyncWrite> AsyncWrite for EitherStream<L, R> {
62 fn poll_write(
63 self: Pin<&mut Self>,
64 cx: &mut Context,
65 buf: &[u8],
66 ) -> Poll<Result<usize, io::Error>> {
67 match unsafe { self.get_unchecked_mut() } {
68 EitherStream::Left(ref mut s) => {
69 unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf)
70 }
71 EitherStream::Right(ref mut s) => {
72 unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf)
73 }
74 }
75 }
76
77 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
78 match unsafe { self.get_unchecked_mut() } {
79 EitherStream::Left(ref mut s) => {
80 unsafe { Pin::new_unchecked(s) }.poll_flush(cx)
81 }
82 EitherStream::Right(ref mut s) => {
83 unsafe { Pin::new_unchecked(s) }.poll_flush(cx)
84 }
85 }
86 }
87
88 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
89 match unsafe { self.get_unchecked_mut() } {
90 EitherStream::Left(ref mut s) => {
91 unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx)
92 }
93 EitherStream::Right(ref mut s) => {
94 unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx)
95 }
96 }
97 }
98
99 fn poll_write_buf<B>(
100 self: Pin<&mut Self>,
101 cx: &mut Context,
102 buf: &mut B,
103 ) -> Poll<Result<usize, io::Error>>
104 where
105 B: bytes::Buf,
106 {
107 match unsafe { self.get_unchecked_mut() } {
108 EitherStream::Left(ref mut s) => {
109 unsafe { Pin::new_unchecked(s) }.poll_write_buf(cx, buf)
110 }
111 EitherStream::Right(ref mut s) => {
112 unsafe { Pin::new_unchecked(s) }.poll_write_buf(cx, buf)
113 }
114 }
115 }
116}
db0cb9ce
WB
117
118// we need this for crate::client::http_client:
119impl Connection for EitherStream<
120 tokio::net::TcpStream,
121 tokio_openssl::SslStream<tokio::net::TcpStream>,
122> {
123 fn connected(&self) -> hyper::client::connect::Connected {
124 match self {
125 EitherStream::Left(s) => s.connected(),
126 EitherStream::Right(s) => s.get_ref().connected(),
127 }
128 }
129}
130
131/// Tokio's `Incoming` now is a reference type and hyper's `AddrIncoming` misses some standard
132/// stuff like `AsRawFd`, so here's something implementing hyper's `Accept` from a `TcpListener`
133pub struct StaticIncoming(TcpListener);
134
135impl From<TcpListener> for StaticIncoming {
136 fn from(inner: TcpListener) -> Self {
137 Self(inner)
138 }
139}
140
141impl AsRawFd for StaticIncoming {
142 fn as_raw_fd(&self) -> RawFd {
143 self.0.as_raw_fd()
144 }
145}
146
147impl hyper::server::accept::Accept for StaticIncoming {
148 type Conn = tokio::net::TcpStream;
149 type Error = std::io::Error;
150
151 fn poll_accept(
152 self: Pin<&mut Self>,
153 cx: &mut Context,
154 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
155 match self.get_mut().0.poll_accept(cx) {
156 Poll::Pending => Poll::Pending,
157 Poll::Ready(Ok((conn, _addr))) => Poll::Ready(Some(Ok(conn))),
158 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
159 }
160 }
161}
162
163/// We also implement TryStream for this, as tokio doesn't do this anymore either and we want to be
164/// able to map connections to then add eg. ssl to them. This support code makes the changes
165/// required for hyper 0.13 a bit less annoying to read.
166impl Stream for StaticIncoming {
167 type Item = std::io::Result<(tokio::net::TcpStream, std::net::SocketAddr)>;
168
169 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
170 match self.get_mut().0.poll_accept(cx) {
171 Poll::Pending => Poll::Pending,
172 Poll::Ready(result) => Poll::Ready(Some(result)),
173 }
174 }
175}
176
/// Adapter that implements hyper's `Accept` for any `TryStream` of sockets.
///
/// The wrapped stream is a public field, so the adapter can be built and unwrapped directly.
pub struct HyperAccept<T>(pub T);
179
180
181impl<T, I> hyper::server::accept::Accept for HyperAccept<T>
182where
183 T: TryStream<Ok = I>,
184 I: AsyncRead + AsyncWrite,
185{
186 type Conn = I;
187 type Error = T::Error;
188
189 fn poll_accept(
190 self: Pin<&mut Self>,
191 cx: &mut Context,
192 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
193 let this = unsafe { self.map_unchecked_mut(|this| &mut this.0) };
194 this.try_poll_next(cx)
195 }
196}