]> git.proxmox.com Git - pve-lxc-syscalld.git/commitdiff
drop EventedFd/PolledFd helpers
authorWolfgang Bumiller <w.bumiller@proxmox.com>
Mon, 18 Jul 2022 10:13:59 +0000 (12:13 +0200)
committerWolfgang Bumiller <w.bumiller@proxmox.com>
Mon, 18 Jul 2022 10:13:59 +0000 (12:13 +0200)
And use tokio's AsyncFd correctly.

And restore SOCK_NONBLOCK on the receiver.

Signed-off-by: Wolfgang Bumiller <w.bumiller@proxmox.com>
src/io/mod.rs
src/io/pipe.rs
src/io/polled_fd.rs [deleted file]
src/io/seq_packet.rs

index 589b3c3a2c49530fb4b2c7df6c09d6ef752602e5..d19aea8826d3a7f80631f0d9daa4ff29bd3465b1 100644 (file)
@@ -1,5 +1,47 @@
+use std::io;
+use std::os::unix::io::{AsRawFd, RawFd};
+
+use tokio::io::unix::AsyncFd;
+
+use crate::tools::Fd;
+
 pub mod cmsg;
 pub mod pipe;
-pub mod polled_fd;
 pub mod rw_traits;
 pub mod seq_packet;
+
+pub async fn wrap_read<R, F>(async_fd: &AsyncFd<Fd>, mut call: F) -> io::Result<R>
+where
+    F: FnMut(RawFd) -> io::Result<R>,
+{
+    let fd = async_fd.as_raw_fd();
+    loop {
+        let mut guard = async_fd.readable().await?;
+
+        match call(fd) {
+            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
+                guard.clear_ready();
+                continue;
+            }
+            other => return other,
+        }
+    }
+}
+
+pub async fn wrap_write<R, F>(async_fd: &AsyncFd<Fd>, mut call: F) -> io::Result<R>
+where
+    F: FnMut(RawFd) -> io::Result<R>,
+{
+    let fd = async_fd.as_raw_fd();
+    loop {
+        let mut guard = async_fd.writable().await?;
+
+        match call(fd) {
+            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
+                guard.clear_ready();
+                continue;
+            }
+            other => return other,
+        }
+    }
+}
index b50e0cbfb175f5a781308ba8a8ed955be82e75f1..47ad102ee640c24edfdab45db393d4176ab608dd 100644 (file)
@@ -5,9 +5,9 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
 use std::pin::Pin;
 use std::task::{Context, Poll};
 
+use tokio::io::unix::AsyncFd;
 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
 
-use crate::io::polled_fd::PolledFd;
 use crate::io::rw_traits;
 use crate::tools::Fd;
 
@@ -44,9 +44,9 @@ pub fn pipe_fds() -> io::Result<(PipeFd<rw_traits::Read>, PipeFd<rw_traits::Writ
 }
 
 /// Tokio supported pipe file descriptor. `tokio::fs::File` requires tokio's complete file system
-/// feature gate, so we just use this `PolledFd` wrapper.
+/// feature gate, so we just use this `AsyncFd` wrapper.
 pub struct Pipe<RW> {
-    fd: PolledFd,
+    fd: AsyncFd<Fd>,
     _phantom: PhantomData<RW>,
 }
 
@@ -55,7 +55,7 @@ impl<RW> TryFrom<PipeFd<RW>> for Pipe<RW> {
 
     fn try_from(fd: PipeFd<RW>) -> io::Result<Self> {
         Ok(Self {
-            fd: PolledFd::new(fd.into_fd())?,
+            fd: AsyncFd::new(fd.into_fd())?,
             _phantom: PhantomData,
         })
     }
@@ -71,7 +71,7 @@ impl<RW> AsRawFd for Pipe<RW> {
 impl<RW> IntoRawFd for Pipe<RW> {
     #[inline]
     fn into_raw_fd(self) -> RawFd {
-        self.fd.into_raw_fd()
+        self.fd.into_inner().into_raw_fd()
     }
 }
 
@@ -87,16 +87,28 @@ impl<RW: rw_traits::HasRead> AsyncRead for Pipe<RW> {
         cx: &mut Context<'_>,
         buf: &mut ReadBuf,
     ) -> Poll<io::Result<()>> {
-        self.fd.wrap_read(cx, || {
-            let fd = self.as_raw_fd();
-            let mem = buf.initialize_unfilled();
-            c_result!(unsafe { libc::read(fd, mem.as_mut_ptr() as *mut libc::c_void, mem.len()) })
-                .map(|received| {
-                    if received > 0 {
-                        buf.advance(received as usize)
-                    }
-                })
-        })
+        let mut guard = ready!(self.fd.poll_read_ready(cx))?;
+
+        let fd = self.as_raw_fd();
+        let mem = buf.initialize_unfilled();
+        match c_result!(unsafe { libc::read(fd, mem.as_mut_ptr() as *mut libc::c_void, mem.len()) })
+        {
+            Ok(received) => {
+                if received > 0 {
+                    buf.advance(received as usize)
+                }
+                guard.retain_ready();
+                Poll::Ready(Ok(()))
+            }
+            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
+                guard.clear_ready();
+                Poll::Pending
+            }
+            Err(err) => {
+                guard.retain_ready();
+                Poll::Ready(Err(err))
+            }
+        }
     }
 }
 
@@ -106,11 +118,24 @@ impl<RW: rw_traits::HasWrite> AsyncWrite for Pipe<RW> {
         cx: &mut Context<'_>,
         buf: &[u8],
     ) -> Poll<io::Result<usize>> {
-        self.fd.wrap_write(cx, || {
-            let fd = self.as_raw_fd();
-            c_result!(unsafe { libc::write(fd, buf.as_ptr() as *const libc::c_void, buf.len()) })
-                .map(|res| res as usize)
-        })
+        let mut guard = ready!(self.fd.poll_write_ready(cx))?;
+
+        let fd = self.as_raw_fd();
+        match c_result!(unsafe { libc::write(fd, buf.as_ptr() as *const libc::c_void, buf.len()) })
+        {
+            Ok(res) => {
+                guard.retain_ready();
+                Poll::Ready(Ok(res as usize))
+            }
+            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
+                guard.clear_ready();
+                Poll::Pending
+            }
+            Err(err) => {
+                guard.retain_ready();
+                Poll::Ready(Err(err))
+            }
+        }
     }
 
     fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
diff --git a/src/io/polled_fd.rs b/src/io/polled_fd.rs
deleted file mode 100644 (file)
index 208186f..0000000
+++ /dev/null
@@ -1,101 +0,0 @@
-use std::io;
-use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
-use std::task::{Context, Poll};
-
-use tokio::io::unix::AsyncFd;
-
-use crate::tools::Fd;
-
-#[repr(transparent)]
-pub struct EventedFd {
-    fd: Fd,
-}
-
-impl EventedFd {
-    #[inline]
-    pub fn new(fd: Fd) -> Self {
-        Self { fd }
-    }
-}
-
-impl AsRawFd for EventedFd {
-    #[inline]
-    fn as_raw_fd(&self) -> RawFd {
-        self.fd.as_raw_fd()
-    }
-}
-
-impl FromRawFd for EventedFd {
-    #[inline]
-    unsafe fn from_raw_fd(fd: RawFd) -> Self {
-        Self::new(unsafe { Fd::from_raw_fd(fd) })
-    }
-}
-
-impl IntoRawFd for EventedFd {
-    #[inline]
-    fn into_raw_fd(self) -> RawFd {
-        self.fd.into_raw_fd()
-    }
-}
-
-#[repr(transparent)]
-pub struct PolledFd {
-    fd: AsyncFd<EventedFd>,
-}
-
-impl PolledFd {
-    pub fn new(fd: Fd) -> tokio::io::Result<Self> {
-        Ok(Self {
-            fd: AsyncFd::new(EventedFd::new(fd))?,
-        })
-    }
-
-    pub fn wrap_read<T>(
-        &self,
-        cx: &mut Context,
-        func: impl FnOnce() -> io::Result<T>,
-    ) -> Poll<io::Result<T>> {
-        let mut ready_guard = ready!(self.fd.poll_read_ready(cx))?;
-        match func() {
-            Ok(out) => Poll::Ready(Ok(out)),
-            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
-                ready_guard.clear_ready();
-                Poll::Pending
-            }
-            Err(err) => Poll::Ready(Err(err)),
-        }
-    }
-
-    pub fn wrap_write<T>(
-        &self,
-        cx: &mut Context,
-        func: impl FnOnce() -> io::Result<T>,
-    ) -> Poll<io::Result<T>> {
-        let mut ready_guard = ready!(self.fd.poll_write_ready(cx))?;
-        match func() {
-            Ok(out) => Poll::Ready(Ok(out)),
-            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
-                ready_guard.clear_ready();
-                Poll::Pending
-            }
-            Err(err) => Poll::Ready(Err(err)),
-        }
-    }
-}
-
-impl AsRawFd for PolledFd {
-    #[inline]
-    fn as_raw_fd(&self) -> RawFd {
-        self.fd.get_ref().as_raw_fd()
-    }
-}
-
-impl IntoRawFd for PolledFd {
-    #[inline]
-    fn into_raw_fd(self) -> RawFd {
-        // for the kind of resource we're managing it should always be possible to extract it from
-        // its driver
-        self.fd.into_inner().into_raw_fd()
-    }
-}
index 0455348aca58de599c14fb3e8ebed9772de5ffba..b82b1635d18aa5cc47d672e230c4c5aee46bf89f 100644 (file)
@@ -1,13 +1,11 @@
 use std::io::{self, IoSlice, IoSliceMut};
 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
 use std::ptr;
-use std::task::{Context, Poll};
 
 use anyhow::Error;
 use nix::sys::socket::{self, AddressFamily, SockFlag, SockType, SockaddrLike};
+use tokio::io::unix::AsyncFd;
 
-use crate::io::polled_fd::PolledFd;
-use crate::poll_fn::poll_fn;
 use crate::tools::AssertSendSync;
 use crate::tools::Fd;
 
@@ -22,7 +20,7 @@ fn seq_packet_socket(flags: SockFlag) -> nix::Result<Fd> {
 }
 
 pub struct SeqPacketListener {
-    fd: PolledFd,
+    fd: AsyncFd<Fd>,
 }
 
 impl AsRawFd for SeqPacketListener {
@@ -38,14 +36,13 @@ impl SeqPacketListener {
         socket::bind(fd.as_raw_fd(), address)?;
         socket::listen(fd.as_raw_fd(), 16)?;
 
-        let fd = PolledFd::new(fd)?;
+        let fd = AsyncFd::new(fd)?;
 
         Ok(Self { fd })
     }
 
-    pub fn poll_accept(&mut self, cx: &mut Context) -> Poll<io::Result<SeqPacketSocket>> {
-        let fd = self.as_raw_fd();
-        let res = self.fd.wrap_read(cx, || {
+    pub async fn accept(&mut self) -> io::Result<SeqPacketSocket> {
+        let fd = super::wrap_read(&self.fd, |fd| {
             c_result!(unsafe {
                 libc::accept4(
                     fd,
@@ -54,22 +51,16 @@ impl SeqPacketListener {
                     libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK,
                 )
             })
-            .map(|fd| unsafe { Fd::from_raw_fd(fd as RawFd) })
-        });
-        match res {
-            Poll::Pending => Poll::Pending,
-            Poll::Ready(Ok(fd)) => Poll::Ready(SeqPacketSocket::new(fd)),
-            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
-        }
-    }
+        })
+        .await?;
 
-    pub async fn accept(&mut self) -> io::Result<SeqPacketSocket> {
-        poll_fn(move |cx| self.poll_accept(cx)).await
+        let fd = unsafe { Fd::from_raw_fd(fd as RawFd) };
+        SeqPacketSocket::new(fd)
     }
 }
 
 pub struct SeqPacketSocket {
-    fd: PolledFd,
+    fd: AsyncFd<Fd>,
 }
 
 impl AsRawFd for SeqPacketSocket {
@@ -82,21 +73,16 @@ impl AsRawFd for SeqPacketSocket {
 impl SeqPacketSocket {
     pub fn new(fd: Fd) -> io::Result<Self> {
         Ok(Self {
-            fd: PolledFd::new(fd)?,
+            fd: AsyncFd::new(fd)?,
         })
     }
 
-    pub fn poll_sendmsg(
-        &self,
-        cx: &mut Context,
-        msg: &AssertSendSync<libc::msghdr>,
-    ) -> Poll<io::Result<usize>> {
-        let fd = self.fd.as_raw_fd();
-
-        self.fd.wrap_write(cx, || {
+    async fn sendmsg(&self, msg: &AssertSendSync<libc::msghdr>) -> io::Result<usize> {
+        let rc = super::wrap_write(&self.fd, |fd| {
             c_result!(unsafe { libc::sendmsg(fd, &msg.0 as *const libc::msghdr, 0) })
-                .map(|rc| rc as usize)
         })
+        .await?;
+        Ok(rc as usize)
     }
 
     pub async fn sendmsg_vectored(&self, iov: &[IoSlice<'_>]) -> io::Result<usize> {
@@ -110,20 +96,15 @@ impl SeqPacketSocket {
             msg_flags: 0,
         });
 
-        poll_fn(move |cx| self.poll_sendmsg(cx, &msg)).await
+        self.sendmsg(&msg).await
     }
 
-    pub fn poll_recvmsg(
-        &self,
-        cx: &mut Context,
-        msg: &mut AssertSendSync<libc::msghdr>,
-    ) -> Poll<io::Result<usize>> {
-        let fd = self.fd.as_raw_fd();
-
-        self.fd.wrap_read(cx, || {
+    async fn recvmsg(&self, msg: &mut AssertSendSync<libc::msghdr>) -> io::Result<usize> {
+        let rc = super::wrap_read(&self.fd, move |fd| {
             c_result!(unsafe { libc::recvmsg(fd, &mut msg.0 as *mut libc::msghdr, 0) })
-                .map(|rc| rc as usize)
         })
+        .await?;
+        Ok(rc as usize)
     }
 
     // clippy is wrong about this one
@@ -143,7 +124,7 @@ impl SeqPacketSocket {
             msg_flags: libc::MSG_CMSG_CLOEXEC,
         });
 
-        let data_size = poll_fn(|cx| self.poll_recvmsg(cx, &mut msg)).await?;
+        let data_size = self.recvmsg(&mut msg).await?;
         Ok((data_size, msg.0.msg_controllen as usize))
     }