]>
Commit | Line | Data |
---|---|---|
556eb70e WB |
1 | //! Generic AsyncRead/AsyncWrite utilities. |
2 | ||
3 | use std::io; | |
db0cb9ce WB |
4 | use std::mem::MaybeUninit; |
5 | use std::os::unix::io::{AsRawFd, RawFd}; | |
556eb70e WB |
6 | use std::pin::Pin; |
7 | use std::task::{Context, Poll}; | |
8 | ||
db0cb9ce | 9 | use futures::stream::{Stream, TryStream}; |
556eb70e | 10 | use tokio::io::{AsyncRead, AsyncWrite}; |
db0cb9ce WB |
11 | use tokio::net::TcpListener; |
12 | use hyper::client::connect::Connection; | |
556eb70e WB |
13 | |
/// A stream that is one of two concrete stream types.
///
/// The `AsyncRead`/`AsyncWrite` implementations in this module forward every call to whichever
/// variant is active, so a value of this type can be used wherever either inner stream could.
pub enum EitherStream<L, R> {
    /// The stream of the first ("left") type.
    Left(L),
    /// The stream of the second ("right") type.
    Right(R),
}
18 | ||
19 | impl<L: AsyncRead, R: AsyncRead> AsyncRead for EitherStream<L, R> { | |
20 | fn poll_read( | |
21 | self: Pin<&mut Self>, | |
22 | cx: &mut Context, | |
23 | buf: &mut [u8], | |
24 | ) -> Poll<Result<usize, io::Error>> { | |
25 | match unsafe { self.get_unchecked_mut() } { | |
26 | EitherStream::Left(ref mut s) => { | |
27 | unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf) | |
28 | } | |
29 | EitherStream::Right(ref mut s) => { | |
30 | unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf) | |
31 | } | |
32 | } | |
33 | } | |
34 | ||
db0cb9ce | 35 | unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool { |
556eb70e WB |
36 | match *self { |
37 | EitherStream::Left(ref s) => s.prepare_uninitialized_buffer(buf), | |
38 | EitherStream::Right(ref s) => s.prepare_uninitialized_buffer(buf), | |
39 | } | |
40 | } | |
41 | ||
42 | fn poll_read_buf<B>( | |
43 | self: Pin<&mut Self>, | |
44 | cx: &mut Context, | |
45 | buf: &mut B, | |
46 | ) -> Poll<Result<usize, io::Error>> | |
47 | where | |
48 | B: bytes::BufMut, | |
49 | { | |
50 | match unsafe { self.get_unchecked_mut() } { | |
51 | EitherStream::Left(ref mut s) => { | |
52 | unsafe { Pin::new_unchecked(s) }.poll_read_buf(cx, buf) | |
53 | } | |
54 | EitherStream::Right(ref mut s) => { | |
55 | unsafe { Pin::new_unchecked(s) }.poll_read_buf(cx, buf) | |
56 | } | |
57 | } | |
58 | } | |
59 | } | |
60 | ||
61 | impl<L: AsyncWrite, R: AsyncWrite> AsyncWrite for EitherStream<L, R> { | |
62 | fn poll_write( | |
63 | self: Pin<&mut Self>, | |
64 | cx: &mut Context, | |
65 | buf: &[u8], | |
66 | ) -> Poll<Result<usize, io::Error>> { | |
67 | match unsafe { self.get_unchecked_mut() } { | |
68 | EitherStream::Left(ref mut s) => { | |
69 | unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf) | |
70 | } | |
71 | EitherStream::Right(ref mut s) => { | |
72 | unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf) | |
73 | } | |
74 | } | |
75 | } | |
76 | ||
77 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> { | |
78 | match unsafe { self.get_unchecked_mut() } { | |
79 | EitherStream::Left(ref mut s) => { | |
80 | unsafe { Pin::new_unchecked(s) }.poll_flush(cx) | |
81 | } | |
82 | EitherStream::Right(ref mut s) => { | |
83 | unsafe { Pin::new_unchecked(s) }.poll_flush(cx) | |
84 | } | |
85 | } | |
86 | } | |
87 | ||
88 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> { | |
89 | match unsafe { self.get_unchecked_mut() } { | |
90 | EitherStream::Left(ref mut s) => { | |
91 | unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx) | |
92 | } | |
93 | EitherStream::Right(ref mut s) => { | |
94 | unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx) | |
95 | } | |
96 | } | |
97 | } | |
98 | ||
99 | fn poll_write_buf<B>( | |
100 | self: Pin<&mut Self>, | |
101 | cx: &mut Context, | |
102 | buf: &mut B, | |
103 | ) -> Poll<Result<usize, io::Error>> | |
104 | where | |
105 | B: bytes::Buf, | |
106 | { | |
107 | match unsafe { self.get_unchecked_mut() } { | |
108 | EitherStream::Left(ref mut s) => { | |
109 | unsafe { Pin::new_unchecked(s) }.poll_write_buf(cx, buf) | |
110 | } | |
111 | EitherStream::Right(ref mut s) => { | |
112 | unsafe { Pin::new_unchecked(s) }.poll_write_buf(cx, buf) | |
113 | } | |
114 | } | |
115 | } | |
116 | } | |
db0cb9ce WB |
117 | |
118 | // we need this for crate::client::http_client: | |
119 | impl Connection for EitherStream< | |
120 | tokio::net::TcpStream, | |
121 | tokio_openssl::SslStream<tokio::net::TcpStream>, | |
122 | > { | |
123 | fn connected(&self) -> hyper::client::connect::Connected { | |
124 | match self { | |
125 | EitherStream::Left(s) => s.connected(), | |
126 | EitherStream::Right(s) => s.get_ref().connected(), | |
127 | } | |
128 | } | |
129 | } | |
130 | ||
131 | /// Tokio's `Incoming` now is a reference type and hyper's `AddrIncoming` misses some standard | |
132 | /// stuff like `AsRawFd`, so here's something implementing hyper's `Accept` from a `TcpListener` | |
133 | pub struct StaticIncoming(TcpListener); | |
134 | ||
135 | impl From<TcpListener> for StaticIncoming { | |
136 | fn from(inner: TcpListener) -> Self { | |
137 | Self(inner) | |
138 | } | |
139 | } | |
140 | ||
141 | impl AsRawFd for StaticIncoming { | |
142 | fn as_raw_fd(&self) -> RawFd { | |
143 | self.0.as_raw_fd() | |
144 | } | |
145 | } | |
146 | ||
147 | impl hyper::server::accept::Accept for StaticIncoming { | |
148 | type Conn = tokio::net::TcpStream; | |
149 | type Error = std::io::Error; | |
150 | ||
151 | fn poll_accept( | |
152 | self: Pin<&mut Self>, | |
153 | cx: &mut Context, | |
154 | ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { | |
155 | match self.get_mut().0.poll_accept(cx) { | |
156 | Poll::Pending => Poll::Pending, | |
157 | Poll::Ready(Ok((conn, _addr))) => Poll::Ready(Some(Ok(conn))), | |
158 | Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))), | |
159 | } | |
160 | } | |
161 | } | |
162 | ||
163 | /// We also implement TryStream for this, as tokio doesn't do this anymore either and we want to be | |
164 | /// able to map connections to then add eg. ssl to them. This support code makes the changes | |
165 | /// required for hyper 0.13 a bit less annoying to read. | |
166 | impl Stream for StaticIncoming { | |
167 | type Item = std::io::Result<(tokio::net::TcpStream, std::net::SocketAddr)>; | |
168 | ||
169 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { | |
170 | match self.get_mut().0.poll_accept(cx) { | |
171 | Poll::Pending => Poll::Pending, | |
172 | Poll::Ready(result) => Poll::Ready(Some(result)), | |
173 | } | |
174 | } | |
175 | } | |
176 | ||
/// Adapter providing hyper's `Accept` for any `TryStream` of sockets.
///
/// The wrapped stream is public, so the adapter is built as `HyperAccept(stream)`.
pub struct HyperAccept<T>(pub T);
179 | ||
180 | ||
181 | impl<T, I> hyper::server::accept::Accept for HyperAccept<T> | |
182 | where | |
183 | T: TryStream<Ok = I>, | |
184 | I: AsyncRead + AsyncWrite, | |
185 | { | |
186 | type Conn = I; | |
187 | type Error = T::Error; | |
188 | ||
189 | fn poll_accept( | |
190 | self: Pin<&mut Self>, | |
191 | cx: &mut Context, | |
192 | ) -> Poll<Option<Result<Self::Conn, Self::Error>>> { | |
193 | let this = unsafe { self.map_unchecked_mut(|this| &mut this.0) }; | |
194 | this.try_poll_next(cx) | |
195 | } | |
196 | } |