]> git.proxmox.com Git - rustc.git/blob - vendor/tokio/src/net/tcp/split_owned.rs
New upstream version 1.60.0+dfsg1
[rustc.git] / vendor / tokio / src / net / tcp / split_owned.rs
1 //! `TcpStream` owned split support.
2 //!
//! A `TcpStream` can be split into an `OwnedReadHalf` and an `OwnedWriteHalf`
4 //! with the `TcpStream::into_split` method. `OwnedReadHalf` implements
5 //! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`.
6 //!
7 //! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
8 //! split has no associated overhead and enforces all invariants at the type
9 //! level.
10
11 use crate::future::poll_fn;
12 use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
13 use crate::net::TcpStream;
14
15 use std::error::Error;
16 use std::net::Shutdown;
17 use std::pin::Pin;
18 use std::sync::Arc;
19 use std::task::{Context, Poll};
20 use std::{fmt, io};
21
22 /// Owned read half of a [`TcpStream`], created by [`into_split`].
23 ///
24 /// Reading from an `OwnedReadHalf` is usually done using the convenience methods found
25 /// on the [`AsyncReadExt`] trait.
26 ///
27 /// [`TcpStream`]: TcpStream
28 /// [`into_split`]: TcpStream::into_split()
29 /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
30 #[derive(Debug)]
31 pub struct OwnedReadHalf {
32 inner: Arc<TcpStream>,
33 }
34
35 /// Owned write half of a [`TcpStream`], created by [`into_split`].
36 ///
37 /// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will
38 /// shut down the TCP stream in the write direction. Dropping the write half
39 /// will also shut down the write half of the TCP stream.
40 ///
41 /// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found
42 /// on the [`AsyncWriteExt`] trait.
43 ///
44 /// [`TcpStream`]: TcpStream
45 /// [`into_split`]: TcpStream::into_split()
46 /// [`AsyncWrite`]: trait@crate::io::AsyncWrite
47 /// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown
48 /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
49 #[derive(Debug)]
50 pub struct OwnedWriteHalf {
51 inner: Arc<TcpStream>,
52 shutdown_on_drop: bool,
53 }
54
55 pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
56 let arc = Arc::new(stream);
57 let read = OwnedReadHalf {
58 inner: Arc::clone(&arc),
59 };
60 let write = OwnedWriteHalf {
61 inner: arc,
62 shutdown_on_drop: true,
63 };
64 (read, write)
65 }
66
67 pub(crate) fn reunite(
68 read: OwnedReadHalf,
69 write: OwnedWriteHalf,
70 ) -> Result<TcpStream, ReuniteError> {
71 if Arc::ptr_eq(&read.inner, &write.inner) {
72 write.forget();
73 // This unwrap cannot fail as the api does not allow creating more than two Arcs,
74 // and we just dropped the other half.
75 Ok(Arc::try_unwrap(read.inner).expect("TcpStream: try_unwrap failed in reunite"))
76 } else {
77 Err(ReuniteError(read, write))
78 }
79 }
80
81 /// Error indicating that two halves were not from the same socket, and thus could
82 /// not be reunited.
83 #[derive(Debug)]
84 pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
85
86 impl fmt::Display for ReuniteError {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 write!(
89 f,
90 "tried to reunite halves that are not from the same socket"
91 )
92 }
93 }
94
95 impl Error for ReuniteError {}
96
97 impl OwnedReadHalf {
98 /// Attempts to put the two halves of a `TcpStream` back together and
99 /// recover the original socket. Succeeds only if the two halves
100 /// originated from the same call to [`into_split`].
101 ///
102 /// [`into_split`]: TcpStream::into_split()
103 pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
104 reunite(self, other)
105 }
106
107 /// Attempt to receive data on the socket, without removing that data from
108 /// the queue, registering the current task for wakeup if data is not yet
109 /// available.
110 ///
111 /// Note that on multiple calls to `poll_peek` or `poll_read`, only the
112 /// `Waker` from the `Context` passed to the most recent call is scheduled
113 /// to receive a wakeup.
114 ///
115 /// See the [`TcpStream::poll_peek`] level documentation for more details.
116 ///
117 /// # Examples
118 ///
119 /// ```no_run
120 /// use tokio::io::{self, ReadBuf};
121 /// use tokio::net::TcpStream;
122 ///
123 /// use futures::future::poll_fn;
124 ///
125 /// #[tokio::main]
126 /// async fn main() -> io::Result<()> {
127 /// let stream = TcpStream::connect("127.0.0.1:8000").await?;
128 /// let (mut read_half, _) = stream.into_split();
129 /// let mut buf = [0; 10];
130 /// let mut buf = ReadBuf::new(&mut buf);
131 ///
132 /// poll_fn(|cx| {
133 /// read_half.poll_peek(cx, &mut buf)
134 /// }).await?;
135 ///
136 /// Ok(())
137 /// }
138 /// ```
139 ///
140 /// [`TcpStream::poll_peek`]: TcpStream::poll_peek
141 pub fn poll_peek(
142 &mut self,
143 cx: &mut Context<'_>,
144 buf: &mut ReadBuf<'_>,
145 ) -> Poll<io::Result<usize>> {
146 self.inner.poll_peek(cx, buf)
147 }
148
149 /// Receives data on the socket from the remote address to which it is
150 /// connected, without removing that data from the queue. On success,
151 /// returns the number of bytes peeked.
152 ///
153 /// See the [`TcpStream::peek`] level documentation for more details.
154 ///
155 /// [`TcpStream::peek`]: TcpStream::peek
156 ///
157 /// # Examples
158 ///
159 /// ```no_run
160 /// use tokio::net::TcpStream;
161 /// use tokio::io::AsyncReadExt;
162 /// use std::error::Error;
163 ///
164 /// #[tokio::main]
165 /// async fn main() -> Result<(), Box<dyn Error>> {
166 /// // Connect to a peer
167 /// let stream = TcpStream::connect("127.0.0.1:8080").await?;
168 /// let (mut read_half, _) = stream.into_split();
169 ///
170 /// let mut b1 = [0; 10];
171 /// let mut b2 = [0; 10];
172 ///
173 /// // Peek at the data
174 /// let n = read_half.peek(&mut b1).await?;
175 ///
176 /// // Read the data
177 /// assert_eq!(n, read_half.read(&mut b2[..n]).await?);
178 /// assert_eq!(&b1[..n], &b2[..n]);
179 ///
180 /// Ok(())
181 /// }
182 /// ```
183 ///
184 /// The [`read`] method is defined on the [`AsyncReadExt`] trait.
185 ///
186 /// [`read`]: fn@crate::io::AsyncReadExt::read
187 /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
188 pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
189 let mut buf = ReadBuf::new(buf);
190 poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
191 }
192 }
193
194 impl AsyncRead for OwnedReadHalf {
195 fn poll_read(
196 self: Pin<&mut Self>,
197 cx: &mut Context<'_>,
198 buf: &mut ReadBuf<'_>,
199 ) -> Poll<io::Result<()>> {
200 self.inner.poll_read_priv(cx, buf)
201 }
202 }
203
204 impl OwnedWriteHalf {
205 /// Attempts to put the two halves of a `TcpStream` back together and
206 /// recover the original socket. Succeeds only if the two halves
207 /// originated from the same call to [`into_split`].
208 ///
209 /// [`into_split`]: TcpStream::into_split()
210 pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
211 reunite(other, self)
212 }
213
214 /// Destroy the write half, but don't close the write half of the stream
215 /// until the read half is dropped. If the read half has already been
216 /// dropped, this closes the stream.
217 pub fn forget(mut self) {
218 self.shutdown_on_drop = false;
219 drop(self);
220 }
221 }
222
223 impl Drop for OwnedWriteHalf {
224 fn drop(&mut self) {
225 if self.shutdown_on_drop {
226 let _ = self.inner.shutdown_std(Shutdown::Write);
227 }
228 }
229 }
230
231 impl AsyncWrite for OwnedWriteHalf {
232 fn poll_write(
233 self: Pin<&mut Self>,
234 cx: &mut Context<'_>,
235 buf: &[u8],
236 ) -> Poll<io::Result<usize>> {
237 self.inner.poll_write_priv(cx, buf)
238 }
239
240 fn poll_write_vectored(
241 self: Pin<&mut Self>,
242 cx: &mut Context<'_>,
243 bufs: &[io::IoSlice<'_>],
244 ) -> Poll<io::Result<usize>> {
245 self.inner.poll_write_vectored_priv(cx, bufs)
246 }
247
248 fn is_write_vectored(&self) -> bool {
249 self.inner.is_write_vectored()
250 }
251
252 #[inline]
253 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
254 // tcp flush is a no-op
255 Poll::Ready(Ok(()))
256 }
257
258 // `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
259 fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
260 let res = self.inner.shutdown_std(Shutdown::Write);
261 if res.is_ok() {
262 Pin::into_inner(self).shutdown_on_drop = false;
263 }
264 res.into()
265 }
266 }
267
268 impl AsRef<TcpStream> for OwnedReadHalf {
269 fn as_ref(&self) -> &TcpStream {
270 &*self.inner
271 }
272 }
273
274 impl AsRef<TcpStream> for OwnedWriteHalf {
275 fn as_ref(&self) -> &TcpStream {
276 &*self.inner
277 }
278 }