]> git.proxmox.com Git - proxmox-backup.git/blob - src/tools/async_channel_writer.rs
cleanup: remove unnecessary 'mut' and '.clone()'
[proxmox-backup.git] / src / tools / async_channel_writer.rs
1 use std::future::Future;
2 use std::io;
3 use std::pin::Pin;
4 use std::task::{Context, Poll};
5
6 use anyhow::{Error, Result};
7 use futures::{future::FutureExt, ready};
8 use tokio::io::AsyncWrite;
9 use tokio::sync::mpsc::Sender;
10
11 use proxmox::io_format_err;
12 use proxmox::tools::byte_buffer::ByteBuffer;
13 use proxmox::sys::error::io_err_other;
14
15 /// Wrapper around tokio::sync::mpsc::Sender, which implements Write
16 pub struct AsyncChannelWriter {
17 sender: Option<Sender<Result<Vec<u8>, Error>>>,
18 buf: ByteBuffer,
19 state: WriterState,
20 }
21
22 type SendResult = io::Result<Sender<Result<Vec<u8>>>>;
23
24 enum WriterState {
25 Ready,
26 Sending(Pin<Box<dyn Future<Output = SendResult> + Send + 'static>>),
27 }
28
29 impl AsyncChannelWriter {
30 pub fn new(sender: Sender<Result<Vec<u8>, Error>>, buf_size: usize) -> Self {
31 Self {
32 sender: Some(sender),
33 buf: ByteBuffer::with_capacity(buf_size),
34 state: WriterState::Ready,
35 }
36 }
37
38 fn poll_write_impl(
39 &mut self,
40 cx: &mut Context,
41 buf: &[u8],
42 flush: bool,
43 ) -> Poll<io::Result<usize>> {
44 loop {
45 match &mut self.state {
46 WriterState::Ready => {
47 if flush {
48 if self.buf.is_empty() {
49 return Poll::Ready(Ok(0));
50 }
51 } else {
52 let free_size = self.buf.free_size();
53 if free_size > buf.len() || self.buf.is_empty() {
54 let count = free_size.min(buf.len());
55 self.buf.get_free_mut_slice()[..count].copy_from_slice(&buf[..count]);
56 self.buf.add_size(count);
57 return Poll::Ready(Ok(count));
58 }
59 }
60
61 let sender = match self.sender.take() {
62 Some(sender) => sender,
63 None => return Poll::Ready(Err(io_err_other("no sender"))),
64 };
65
66 let data = self.buf.remove_data(self.buf.len()).to_vec();
67 let future = async move {
68 sender
69 .send(Ok(data))
70 .await
71 .map(move |_| sender)
72 .map_err(|err| io_format_err!("could not send: {}", err))
73 };
74
75 self.state = WriterState::Sending(future.boxed());
76 }
77 WriterState::Sending(ref mut future) => match ready!(future.as_mut().poll(cx)) {
78 Ok(sender) => {
79 self.sender = Some(sender);
80 self.state = WriterState::Ready;
81 }
82 Err(err) => return Poll::Ready(Err(err)),
83 },
84 }
85 }
86 }
87 }
88
89 impl AsyncWrite for AsyncChannelWriter {
90 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
91 let this = self.get_mut();
92 this.poll_write_impl(cx, buf, false)
93 }
94
95 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
96 let this = self.get_mut();
97 match ready!(this.poll_write_impl(cx, &[], true)) {
98 Ok(_) => Poll::Ready(Ok(())),
99 Err(err) => Poll::Ready(Err(err)),
100 }
101 }
102
103 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
104 self.poll_flush(cx)
105 }
106 }