]>
Commit | Line | Data |
---|---|---|
fc5870be WB |
1 | //! Wrappers between async readers and streams. |
2 | ||
3 | use std::io::{self, Read}; | |
4 | use std::future::Future; | |
5 | use std::pin::Pin; | |
6 | use std::task::{Context, Poll}; | |
7 | ||
8 | use anyhow::{Error, Result}; | |
9 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; | |
10 | use tokio::sync::mpsc::Sender; | |
11 | use futures::ready; | |
12 | use futures::future::FutureExt; | |
13 | use futures::stream::Stream; | |
14 | ||
15 | use proxmox::io_format_err; | |
fc5870be | 16 | use proxmox::sys::error::io_err_other; |
6ef1b649 | 17 | use proxmox_io::ByteBuffer; |
fc5870be WB |
18 | |
19 | use pbs_runtime::block_in_place; | |
20 | ||
/// Wrapper struct to convert a blocking [`Read`] implementation into a `Stream`.
///
/// The stream yields the bytes returned by each successful `read` call as an owned
/// `Vec<u8>` and ends once the reader reports EOF (a read of zero bytes).
pub struct WrappedReaderStream<R: Read + Unpin> {
    /// Source the data is pulled from.
    reader: R,
    /// Scratch buffer handed to `reader.read()`; its length is the maximum chunk size.
    buffer: Vec<u8>,
}

impl<R: Read + Unpin> WrappedReaderStream<R> {
    /// Creates a new stream reading from `reader` through a 64 KiB scratch buffer.
    pub fn new(reader: R) -> Self {
        // The buffer must be initialised: passing uninitialised memory to an arbitrary
        // `Read` implementation (as `Vec::set_len` on fresh capacity would) is undefined
        // behaviour, because the implementation is allowed to read from the slice.
        Self {
            reader,
            buffer: vec![0u8; 64 * 1024],
        }
    }
}
35 | ||
36 | impl<R: Read + Unpin> Stream for WrappedReaderStream<R> { | |
37 | type Item = Result<Vec<u8>, io::Error>; | |
38 | ||
39 | fn poll_next(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Option<Self::Item>> { | |
40 | let this = self.get_mut(); | |
41 | match block_in_place(|| this.reader.read(&mut this.buffer)) { | |
42 | Ok(n) => { | |
43 | if n == 0 { | |
44 | // EOF | |
45 | Poll::Ready(None) | |
46 | } else { | |
47 | Poll::Ready(Some(Ok(this.buffer[..n].to_vec()))) | |
48 | } | |
49 | } | |
50 | Err(err) => Poll::Ready(Some(Err(err))), | |
51 | } | |
52 | } | |
53 | } | |
54 | ||
55 | /// Wrapper struct to convert an AsyncReader into a Stream | |
56 | pub struct AsyncReaderStream<R: AsyncRead + Unpin> { | |
57 | reader: R, | |
58 | buffer: Vec<u8>, | |
59 | } | |
60 | ||
61 | impl <R: AsyncRead + Unpin> AsyncReaderStream<R> { | |
62 | ||
63 | pub fn new(reader: R) -> Self { | |
64 | let mut buffer = Vec::with_capacity(64*1024); | |
65 | unsafe { buffer.set_len(buffer.capacity()); } | |
66 | Self { reader, buffer } | |
67 | } | |
68 | ||
69 | pub fn with_buffer_size(reader: R, buffer_size: usize) -> Self { | |
70 | let mut buffer = Vec::with_capacity(buffer_size); | |
71 | unsafe { buffer.set_len(buffer.capacity()); } | |
72 | Self { reader, buffer } | |
73 | } | |
74 | } | |
75 | ||
76 | impl<R: AsyncRead + Unpin> Stream for AsyncReaderStream<R> { | |
77 | type Item = Result<Vec<u8>, io::Error>; | |
78 | ||
79 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { | |
80 | let this = self.get_mut(); | |
81 | let mut read_buf = ReadBuf::new(&mut this.buffer); | |
82 | match ready!(Pin::new(&mut this.reader).poll_read(cx, &mut read_buf)) { | |
83 | Ok(()) => { | |
84 | let n = read_buf.filled().len(); | |
85 | if n == 0 { | |
86 | // EOF | |
87 | Poll::Ready(None) | |
88 | } else { | |
89 | Poll::Ready(Some(Ok(this.buffer[..n].to_vec()))) | |
90 | } | |
91 | } | |
92 | Err(err) => Poll::Ready(Some(Err(err))), | |
93 | } | |
94 | } | |
95 | } | |
96 | ||
#[cfg(test)]
mod test {
    use std::io;

    use anyhow::Error;
    use futures::stream::TryStreamExt;

    /// Drives the async test body on the project runtime.
    #[test]
    fn test_wrapped_stream_reader() -> Result<(), Error> {
        pbs_runtime::main(async { run_wrapped_stream_reader_test().await })
    }

    /// Reader that fills the whole buffer with zeroes on its first nine reads and
    /// reports EOF from the tenth read on. The field counts the `read` calls so far.
    struct DummyReader(usize);

    impl io::Read for DummyReader {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            self.0 += 1;

            if self.0 >= 10 {
                return Ok(0);
            }

            // safe equivalent of zeroing the buffer through raw pointers
            buf.fill(0);

            Ok(buf.len())
        }
    }

    /// Consumes the stream to its end and checks that exactly the nine non-empty
    /// chunks produced by `DummyReader` arrive before EOF.
    async fn run_wrapped_stream_reader_test() -> Result<(), Error> {
        let mut reader = super::WrappedReaderStream::new(DummyReader(0));
        let mut chunks = 0;
        while let Some(data) = reader.try_next().await? {
            assert!(!data.is_empty());
            chunks += 1;
        }
        assert_eq!(chunks, 9);
        Ok(())
    }
}
137 | ||
138 | /// Wrapper around tokio::sync::mpsc::Sender, which implements Write | |
139 | pub struct AsyncChannelWriter { | |
140 | sender: Option<Sender<Result<Vec<u8>, Error>>>, | |
141 | buf: ByteBuffer, | |
142 | state: WriterState, | |
143 | } | |
144 | ||
145 | type SendResult = io::Result<Sender<Result<Vec<u8>>>>; | |
146 | ||
147 | enum WriterState { | |
148 | Ready, | |
149 | Sending(Pin<Box<dyn Future<Output = SendResult> + Send + 'static>>), | |
150 | } | |
151 | ||
152 | impl AsyncChannelWriter { | |
153 | pub fn new(sender: Sender<Result<Vec<u8>, Error>>, buf_size: usize) -> Self { | |
154 | Self { | |
155 | sender: Some(sender), | |
156 | buf: ByteBuffer::with_capacity(buf_size), | |
157 | state: WriterState::Ready, | |
158 | } | |
159 | } | |
160 | ||
161 | fn poll_write_impl( | |
162 | &mut self, | |
163 | cx: &mut Context, | |
164 | buf: &[u8], | |
165 | flush: bool, | |
166 | ) -> Poll<io::Result<usize>> { | |
167 | loop { | |
168 | match &mut self.state { | |
169 | WriterState::Ready => { | |
170 | if flush { | |
171 | if self.buf.is_empty() { | |
172 | return Poll::Ready(Ok(0)); | |
173 | } | |
174 | } else { | |
175 | let free_size = self.buf.free_size(); | |
176 | if free_size > buf.len() || self.buf.is_empty() { | |
177 | let count = free_size.min(buf.len()); | |
178 | self.buf.get_free_mut_slice()[..count].copy_from_slice(&buf[..count]); | |
179 | self.buf.add_size(count); | |
180 | return Poll::Ready(Ok(count)); | |
181 | } | |
182 | } | |
183 | ||
184 | let sender = match self.sender.take() { | |
185 | Some(sender) => sender, | |
186 | None => return Poll::Ready(Err(io_err_other("no sender"))), | |
187 | }; | |
188 | ||
189 | let data = self.buf.remove_data(self.buf.len()).to_vec(); | |
190 | let future = async move { | |
191 | sender | |
192 | .send(Ok(data)) | |
193 | .await | |
194 | .map(move |_| sender) | |
195 | .map_err(|err| io_format_err!("could not send: {}", err)) | |
196 | }; | |
197 | ||
198 | self.state = WriterState::Sending(future.boxed()); | |
199 | } | |
200 | WriterState::Sending(ref mut future) => match ready!(future.as_mut().poll(cx)) { | |
201 | Ok(sender) => { | |
202 | self.sender = Some(sender); | |
203 | self.state = WriterState::Ready; | |
204 | } | |
205 | Err(err) => return Poll::Ready(Err(err)), | |
206 | }, | |
207 | } | |
208 | } | |
209 | } | |
210 | } | |
211 | ||
212 | impl AsyncWrite for AsyncChannelWriter { | |
213 | fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { | |
214 | let this = self.get_mut(); | |
215 | this.poll_write_impl(cx, buf, false) | |
216 | } | |
217 | ||
218 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { | |
219 | let this = self.get_mut(); | |
220 | match ready!(this.poll_write_impl(cx, &[], true)) { | |
221 | Ok(_) => Poll::Ready(Ok(())), | |
222 | Err(err) => Poll::Ready(Err(err)), | |
223 | } | |
224 | } | |
225 | ||
226 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { | |
227 | self.poll_flush(cx) | |
228 | } | |
229 | } |