From e420f6f97f8033ac9126fc4344023c512c288433 Mon Sep 17 00:00:00 2001 From: Wolfgang Bumiller Date: Sun, 7 Jul 2019 18:29:18 +0200 Subject: [PATCH 1/1] Whole bunch of async code and preparation to fork. The GenericStream should not be necessary once tokio-fs is updated to futures@0.3 tools.rs needs to be split up... Signed-off-by: Wolfgang Bumiller --- src/client.rs | 109 ++++++++++++++++++++++ src/fork.rs | 155 +++++++++++++++++++++++++++++++ src/lxcseccomp.rs | 18 ++-- src/main.rs | 61 ++++--------- src/nsfd.rs | 79 ++++++++++++++++ src/pidfd.rs | 44 +++++++++ src/socket.rs | 20 ++-- src/sys_mknod.rs | 14 +++ src/tools.rs | 227 ++++++++++++++++++++++++++++++++++++++++------ 9 files changed, 641 insertions(+), 86 deletions(-) create mode 100644 src/client.rs create mode 100644 src/fork.rs create mode 100644 src/nsfd.rs create mode 100644 src/pidfd.rs create mode 100644 src/sys_mknod.rs diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..ccdef32 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,109 @@ +use std::os::unix::io::{FromRawFd, IntoRawFd}; +use std::sync::Arc; + +use failure::{format_err, Error}; + +use crate::lxcseccomp::ProxyMessageBuffer; +use crate::socket::AsyncSeqPacketSocket; +use crate::{SyscallMeta, SyscallStatus}; + +pub struct Client { + socket: AsyncSeqPacketSocket, +} + +impl Client { + pub fn new(socket: AsyncSeqPacketSocket) -> Arc { + Arc::new(Self { socket }) + } + + /// Wrapp futures returning a `Result` so if they fail we `shutdown()` the socket to drop the + /// client. + async fn wrap_error(self: Arc, fut: F) + where + F: std::future::Future>, + { + if let Err(err) = fut.await { + eprintln!("client error, dropping connection: {}", err); + if let Err(err) = self.socket.shutdown(nix::sys::socket::Shutdown::Both) { + eprintln!(" (error shutting down client socket: {})", err); + } + } + } + + pub async fn main(self: Arc) { + self.clone().wrap_error(self.main_do()).await + } + + async fn main_do(self: Arc) -> Result<(), Error> { + loop { + let mut msg = ProxyMessageBuffer::new(64); + + let (size, mut fds) = { + let mut iovec = msg.io_vec_mut(); + self.socket.recv_fds_vectored(&mut iovec, 1).await? + }; + + if size == 0 { + eprintln!("client disconnected"); + break Ok(()); + } + + msg.set_len(size)?; + + let mut fds = fds.drain(..); + let memory = fds + .next() + .ok_or_else(|| format_err!("did not receive memory file descriptor from liblxc"))?; + + std::mem::drop(fds); + + let meta = SyscallMeta { + memory: unsafe { std::fs::File::from_raw_fd(memory.into_raw_fd()) }, + }; + + // Note: our spawned tasks here must not access our socket, as we cannot guarantee + // they'll be woken up if another task errors into `wrap_error()`. + tokio::spawn( + self.clone() + .wrap_error(self.clone().__handle_syscall(msg, meta)), + ); + } + } + + // Note: we must not use the socket for anything other than sending the result! + async fn __handle_syscall( + self: Arc, + mut msg: ProxyMessageBuffer, + meta: SyscallMeta, + ) -> Result<(), Error> { + let result = Self::handle_syscall(&msg, meta).await?; + + let resp = msg.response_mut(); + match result { + SyscallStatus::Ok(val) => { + resp.val = val; + resp.error = 0; + } + SyscallStatus::Err(err) => { + resp.val = -1; + resp.error = -err; + } + } + + let iovec = msg.io_vec_no_cookie(); + self.socket.sendmsg_vectored(&iovec).await?; + + Ok(()) + } + + async fn handle_syscall( + msg: &ProxyMessageBuffer, + meta: SyscallMeta, + ) -> Result { + match msg.request().data.nr as i64 { + libc::SYS_mknod => crate::sys_mknod::mknod(msg, meta).await, + libc::SYS_mknodat => crate::sys_mknod::mknodat(msg, meta).await, + _ => Ok(SyscallStatus::Err(libc::ENOSYS)), + } + } +} diff --git a/src/fork.rs b/src/fork.rs new file mode 100644 index 0000000..29fcc4b --- /dev/null +++ b/src/fork.rs @@ -0,0 +1,155 @@ +//! Fork helper. +//! +//! Note that forking in rust can be dangerous. A fork must consider all mutexes to be in a broken +//! state, and cannot rely on any of its reference life times, so we be careful what kind of data +//! we continue to work with. + +use std::io; +use std::os::raw::c_int; +use std::os::unix::io::{FromRawFd, IntoRawFd}; +use std::panic::UnwindSafe; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::future::poll_fn; +use futures::io::AsyncRead; + +use crate::tools::Fd; +use crate::SyscallStatus; +use crate::{libc_try, libc_wrap}; + +pub async fn forking_syscall(func: F) -> io::Result +where + F: FnOnce() -> io::Result + UnwindSafe, +{ + let mut fork = Fork::new(func)?; + let mut buf = [0u8; 10]; + + use futures::io::AsyncReadExt; + fork.read_exact(&mut buf).await?; + fork.wait()?; + + Ok(SyscallStatus::Err(libc::ENOENT)) +} + +pub struct Fork { + pid: Option, + // FIXME: abuse! tokio-fs is not updated to futures@0.3 yet, but a TcpStream does the same + // thing as a file when it's already open anyway... + out: crate::tools::GenericStream, +} + +impl Drop for Fork { + fn drop(&mut self) { + if self.pid.is_some() { + let _ = self.wait(); + } + } +} + +impl Fork { + pub fn new(func: F) -> io::Result + where + F: FnOnce() -> io::Result + UnwindSafe, + { + let mut pipe: [c_int; 2] = [0, 0]; + libc_try!(unsafe { libc::pipe2(pipe.as_mut_ptr(), libc::O_CLOEXEC | libc::O_NONBLOCK) }); + let (pipe_r, pipe_w) = (Fd(pipe[0]), Fd(pipe[1])); + + let pipe_r = match crate::tools::GenericStream::from_fd(pipe_r) { + Ok(o) => o, + Err(err) => return Err(io::Error::new(io::ErrorKind::Other, err.to_string())), + }; + + let pid = libc_try!(unsafe { libc::fork() }); + if pid == 0 { + std::mem::drop(pipe_r); + let mut pipe_w = unsafe { std::fs::File::from_raw_fd(pipe_w.into_raw_fd()) }; + + let _ = std::panic::catch_unwind(move || { + let mut buf = [0u8; 10]; + + match func() { + Ok(SyscallStatus::Ok(value)) => unsafe { + std::ptr::write(buf.as_mut_ptr().add(1) as *mut i64, value); + }, + Ok(SyscallStatus::Err(value)) => { + buf[0] = 1; + unsafe { + std::ptr::write(buf.as_mut_ptr().add(1) as *mut i32, value); + } + } + Err(err) => match err.raw_os_error() { + Some(err) => { + buf[0] = 2; + unsafe { + std::ptr::write(buf.as_mut_ptr().add(1) as *mut i32, err); + } + } + None => { + buf[0] = 3; + } + }, + } + + use std::io::Write; + match pipe_w.write_all(&buf) { + Ok(()) => unsafe { libc::_exit(0) }, + Err(_) => unsafe { libc::_exit(1) }, + } + }); + unsafe { + libc::_exit(-1); + } + } + + Ok(Self { + pid: Some(pid), + out: pipe_r, + }) + } + + pub fn wait(&mut self) -> io::Result<()> { + let my_pid = self.pid.take().unwrap(); + + loop { + let mut status: c_int = -1; + match libc_wrap!(unsafe { libc::waitpid(my_pid, &mut status, 0) }) { + Ok(pid) if pid == my_pid => break, + Ok(_other) => continue, + Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue, + Err(other) => return Err(other), + } + } + + Ok(()) + } + + + pub async fn async_read(&mut self, buf: &mut [u8]) -> io::Result { + poll_fn(|cx| Pin::new(&mut *self).poll_read(cx, buf)).await + } +} + +// default impl will work +impl AsyncRead for Fork { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { + unsafe { self.map_unchecked_mut(|this| &mut this.out) }.poll_read(cx, buf) + } + + unsafe fn initializer(&self) -> futures::io::Initializer { + self.out.initializer() + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context, + bufs: &mut [futures::io::IoSliceMut], + ) -> Poll> { + unsafe { self.map_unchecked_mut(|this| &mut this.out) }.poll_read_vectored(cx, bufs) + } +} diff --git a/src/lxcseccomp.rs b/src/lxcseccomp.rs index 99472a2..f1b6ed3 100644 --- a/src/lxcseccomp.rs +++ b/src/lxcseccomp.rs @@ -1,14 +1,14 @@ //! Module for LXC specific related seccomp handling. use std::convert::TryFrom; -use std::{io, mem}; +use std::mem; use failure::{bail, Error}; use lazy_static::lazy_static; use libc::pid_t; -use super::seccomp::{SeccompNotif, SeccompNotifResp, SeccompNotifSizes}; -use super::tools::{IoVec, IoVecMut}; +use crate::seccomp::{SeccompNotif, SeccompNotifResp, SeccompNotifSizes}; +use crate::tools::{IoVec, IoVecMut}; /// Seccomp notification proxy message sent by the lxc monitor. /// @@ -46,7 +46,6 @@ pub struct SeccompNotifyProxyMsg { } /// Helper to receive and verify proxy notification messages. -#[repr(C)] pub struct ProxyMessageBuffer { proxy_msg: SeccompNotifyProxyMsg, seccomp_notif: SeccompNotif, @@ -72,27 +71,28 @@ unsafe fn io_vec(value: &T) -> IoVec { } lazy_static! { - static ref SECCOMP_SIZES: SeccompNotifSizes = - SeccompNotifSizes::get_checked().map_err(|e| panic!("{}\nrefusing to run", e)).unwrap(); + static ref SECCOMP_SIZES: SeccompNotifSizes = SeccompNotifSizes::get_checked() + .map_err(|e| panic!("{}\nrefusing to run", e)) + .unwrap(); } impl ProxyMessageBuffer { /// Allocate a new proxy message buffer with a specific maximum cookie size. - pub fn new(max_cookie: usize) -> io::Result { + pub fn new(max_cookie: usize) -> Self { let sizes = SECCOMP_SIZES.clone(); let seccomp_packet_size = mem::size_of::() + sizes.notif as usize + sizes.notif_resp as usize; - Ok(Self { + Self { proxy_msg: unsafe { mem::zeroed() }, seccomp_notif: unsafe { mem::zeroed() }, seccomp_resp: unsafe { mem::zeroed() }, cookie_buf: unsafe { super::tools::vec::uninitialized(max_cookie) }, sizes, seccomp_packet_size, - }) + } } /// Resets the buffer capacity and returns an IoVecMut used to fill the buffer. diff --git a/src/main.rs b/src/main.rs index ded493e..12f06b9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,17 +2,31 @@ use std::ffi::OsString; use std::io; -use std::sync::Arc; use failure::{bail, format_err, Error}; use nix::sys::socket::SockAddr; +pub mod client; +pub mod fork; pub mod lxcseccomp; +pub mod nsfd; +pub mod pidfd; pub mod seccomp; pub mod socket; +pub mod sys_mknod; pub mod tools; -use socket::{AsyncSeqPacketSocket, SeqPacketListener}; +use socket::SeqPacketListener; + +pub enum SyscallStatus { + Ok(i64), + Err(i32), +} + +pub struct SyscallMeta { + //pid: pidfd::PidFd, + memory: std::fs::File, +} fn main() { if let Err(err) = run() { @@ -52,46 +66,7 @@ async fn async_run_do(socket_path: OsString) -> Result<(), Error> { .map_err(|e| format_err!("failed to create listening socket: {}", e))?; loop { let client = listener.accept().await?; - tokio::spawn(handle_client(Arc::new(client))); - } -} - -async fn handle_client(client: Arc) { - if let Err(err) = handle_client_do(client).await { - eprintln!( - "error communicating with client, dropping connection: {}", - err - ); + let client = client::Client::new(client); + tokio::spawn(client.main()); } } - -async fn handle_client_do(client: Arc) -> Result<(), Error> { - let mut msgbuf = lxcseccomp::ProxyMessageBuffer::new(64) - .map_err(|e| format_err!("failed to allocate proxy message buffer: {}", e))?; - - loop { - let (size, _fds) = { - let mut iovec = msgbuf.io_vec_mut(); - client.recv_fds_vectored(&mut iovec, 1).await? - }; - - if size == 0 { - println!("client disconnected"); - break; - } - - msgbuf.set_len(size)?; - - let req = msgbuf.request(); - println!("Received request for syscall {}", req.data.nr); - - let resp = msgbuf.response_mut(); - resp.val = 0; - resp.error = -libc::ENOENT; - - let iovec = msgbuf.io_vec_no_cookie(); - client.sendmsg_vectored(&iovec).await?; - } - - Ok(()) -} diff --git a/src/nsfd.rs b/src/nsfd.rs new file mode 100644 index 0000000..5c7f98a --- /dev/null +++ b/src/nsfd.rs @@ -0,0 +1,79 @@ +use std::io; +use std::marker::PhantomData; +use std::os::raw::c_int; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use std::path::Path; + +use crate::tools::path_ptr; +use crate::{file_descriptor_type, libc_try}; + +pub mod ns_type { + pub trait NsType { + const TYPE: libc::c_int; + } + + macro_rules! define_ns_type { + ($name:ident, $number:expr) => { + pub struct $name; + impl NsType for $name { + const TYPE: libc::c_int = $number; + } + }; + } + + define_ns_type!(Mount, libc::CLONE_NEWNS); + define_ns_type!(User, libc::CLONE_NEWUSER); + define_ns_type!(Cgroup, libc::CLONE_NEWCGROUP); +} + +pub use ns_type::NsType; + +file_descriptor_type!(RawNsFd); + +impl RawNsFd { + pub fn open>(path: P) -> io::Result { + Self::openat(libc::AT_FDCWD, path.as_ref()) + } + + pub fn openat>(fd: RawFd, path: P) -> io::Result { + let fd = libc_try!(unsafe { + libc::openat( + fd, + path_ptr(path.as_ref()), + libc::O_RDONLY | libc::O_CLOEXEC, + ) + }); + + Ok(Self(fd)) + } + + pub fn setns(&self, ns_type: c_int) -> io::Result<()> { + libc_try!(unsafe { libc::setns(self.0, ns_type) }); + Ok(()) + } +} + +#[repr(transparent)] +pub struct NsFd(RawNsFd, PhantomData); + +impl std::ops::Deref for NsFd { + type Target = RawNsFd; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl NsFd { + pub fn open>(path: P) -> io::Result { + Ok(Self(RawNsFd::open(path.as_ref())?, PhantomData)) + } + + pub fn openat>(fd: RawFd, path: P) -> io::Result { + Ok(Self(RawNsFd::openat(fd, path.as_ref())?, PhantomData)) + } + + pub fn setns(&self) -> io::Result<()> { + self.0.setns(T::TYPE) + } +} diff --git a/src/pidfd.rs b/src/pidfd.rs new file mode 100644 index 0000000..0782027 --- /dev/null +++ b/src/pidfd.rs @@ -0,0 +1,44 @@ +//! pidfd helper functionality + +use std::ffi::CString; +use std::io; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; + +use crate::nsfd::{ns_type, NsFd}; +use crate::tools::Fd; +use crate::{file_descriptor_type, libc_try}; + +file_descriptor_type!(PidFd); + +impl PidFd { + pub fn open(pid: libc::pid_t) -> io::Result { + let path = CString::new(format!("/proc/{}", pid)).unwrap(); + + let fd = + libc_try!(unsafe { libc::open(path.as_ptr(), libc::O_DIRECTORY | libc::O_CLOEXEC) }); + + Ok(Self(fd)) + } + + pub fn mount_namespace(&self) -> io::Result> { + NsFd::openat(self.0, "ns/mnt") + } + + pub fn cgroup_namespace(&self) -> io::Result> { + NsFd::openat(self.0, "ns/cgroup") + } + + pub fn user_namespace(&self) -> io::Result> { + NsFd::openat(self.0, "ns/user") + } + + pub fn cwd_fd(&self) -> io::Result { + Ok(Fd(libc_try!(unsafe { + libc::openat( + self.as_raw_fd(), + b"cwd".as_ptr() as *const _, + libc::O_DIRECTORY, + ) + }))) + } +} diff --git a/src/socket.rs b/src/socket.rs index 46ddb2e..7fc507c 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -10,7 +10,7 @@ use futures::future::poll_fn; use futures::ready; use nix::sys::socket::{AddressFamily, SockAddr, SockFlag, SockType}; -use super::tools::{vec, Fd, IoVec, IoVecMut}; +use crate::tools::{vec, Fd, IoVec, IoVecMut}; pub struct SeqPacketSocket(Fd); @@ -95,6 +95,12 @@ impl SeqPacketSocket { fn as_fd(&self) -> &Fd { &self.0 } + + /// Shutdown parts of the socket. + #[inline] + pub fn shutdown(&self, how: nix::sys::socket::Shutdown) -> nix::Result<()> { + nix::sys::socket::shutdown(self.as_raw_fd(), how) + } } impl AsRawFd for SeqPacketSocket { @@ -205,6 +211,12 @@ impl AsyncSeqPacketSocket { }) } + /// Shutdown parts of the socket. + #[inline] + pub fn shutdown(&self, how: nix::sys::socket::Shutdown) -> nix::Result<()> { + self.socket.shutdown(how) + } + pub fn poll_recv_fds_vectored( &self, iov: &mut [IoVecMut], @@ -233,11 +245,7 @@ impl AsyncSeqPacketSocket { poll_fn(move |cx| self.poll_recv_fds_vectored(iov, num_fds, cx)).await } - pub fn poll_sendmsg_vectored( - &self, - data: &[IoVec], - cx: &mut Context, - ) -> Poll> { + pub fn poll_sendmsg_vectored(&self, data: &[IoVec], cx: &mut Context) -> Poll> { loop { match self.socket.sendmsg_vectored(data) { Ok(res) => break Poll::Ready(Ok(res)), diff --git a/src/sys_mknod.rs b/src/sys_mknod.rs new file mode 100644 index 0000000..5fdcf1a --- /dev/null +++ b/src/sys_mknod.rs @@ -0,0 +1,14 @@ +use failure::Error; + +use crate::lxcseccomp::ProxyMessageBuffer; +use crate::{SyscallMeta, SyscallStatus}; + +pub async fn mknod(_msg: &ProxyMessageBuffer, _meta: SyscallMeta) -> Result { + println!("Responding with ENOENT"); + Ok(SyscallStatus::Err(libc::ENOENT)) +} + +pub async fn mknodat(_msg: &ProxyMessageBuffer, _meta: SyscallMeta) -> Result { + println!("Responding with ENOENT"); + Ok(SyscallStatus::Err(libc::ENOENT)) +} diff --git a/src/tools.rs b/src/tools.rs index 4e0780d..d62ec8c 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -6,46 +6,66 @@ use std::io; use std::marker::PhantomData; use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use failure::{bail, Error}; +use futures::io::{AsyncRead, AsyncWrite}; +use futures::ready; use mio::unix::EventedFd; use mio::{PollOpt, Ready, Token}; -/// Guard a raw file descriptor with a drop handler. This is mostly useful when access to an owned -/// `RawFd` is required without the corresponding handler object (such as when only the file -/// descriptor number is required in a closure which may be dropped instead of being executed). -#[repr(transparent)] -pub struct Fd(pub RawFd); +#[macro_export] +macro_rules! file_descriptor_type { + ($type:ident) => { + #[repr(transparent)] + pub struct $type(RawFd); -impl Drop for Fd { - fn drop(&mut self) { - if self.0 != -1 { - unsafe { - libc::close(self.0); + crate::file_descriptor_impl!($type); + }; +} + +#[macro_export] +macro_rules! file_descriptor_impl { + ($type:ty) => { + impl Drop for $type { + fn drop(&mut self) { + unsafe { + libc::close(self.0); + } } } - } -} -impl AsRawFd for Fd { - fn as_raw_fd(&self) -> RawFd { - self.0 - } -} + impl AsRawFd for $type { + fn as_raw_fd(&self) -> RawFd { + self.0 + } + } -impl IntoRawFd for Fd { - fn into_raw_fd(mut self) -> RawFd { - let fd = self.0; - self.0 = -1; - fd - } -} + impl IntoRawFd for $type { + fn into_raw_fd(mut self) -> RawFd { + let fd = self.0; + self.0 = -libc::EBADF; + fd + } + } -impl FromRawFd for Fd { - unsafe fn from_raw_fd(fd: RawFd) -> Self { - Self(fd) - } + impl FromRawFd for $type { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + Self(fd) + } + } + }; } +/// Guard a raw file descriptor with a drop handler. This is mostly useful when access to an owned +/// `RawFd` is required without the corresponding handler object (such as when only the file +/// descriptor number is required in a closure which may be dropped instead of being executed). +#[repr(transparent)] +pub struct Fd(pub RawFd); + +file_descriptor_impl!(Fd); + impl mio::Evented for Fd { fn register( &self, @@ -72,6 +92,127 @@ impl mio::Evented for Fd { } } +pub struct AsyncFd { + fd: Fd, + registration: tokio::reactor::Registration, +} + +impl Drop for AsyncFd { + fn drop(&mut self) { + if let Err(err) = self.registration.deregister(&self.fd) { + eprintln!("failed to deregister I/O resource with reactor: {}", err); + } + } +} + +impl AsyncFd { + pub fn new(fd: Fd) -> Result { + let registration = tokio::reactor::Registration::new(); + if !registration.register(&fd)? { + bail!("duplicate file descriptor registration?"); + } + + Ok(Self { fd, registration }) + } + + pub fn poll_read_ready(&self, cx: &mut Context) -> Poll> { + self.registration.poll_read_ready(cx) + } + + pub fn poll_write_ready(&self, cx: &mut Context) -> Poll> { + self.registration.poll_write_ready(cx) + } +} + +impl AsRawFd for AsyncFd { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} + +// At the time of writing, tokio-fs in master was disabled as it wasn't updated to futures@0.3 yet. +pub struct GenericStream(Option); + +impl GenericStream { + pub fn from_fd(fd: Fd) -> Result { + AsyncFd::new(fd).map(|fd| Self(Some(fd))) + } + + fn raw_fd(&self) -> RawFd { + self.0 + .as_ref() + .map(|fd| fd.as_raw_fd()) + .unwrap_or(-libc::EBADF) + } + + pub fn poll_read_ready(&self, cx: &mut Context) -> Poll> { + match self.0 { + Some(ref fd) => fd.poll_read_ready(cx), + None => Poll::Ready(Err(io::ErrorKind::InvalidInput.into())), + } + } + + pub fn poll_write_ready(&self, cx: &mut Context) -> Poll> { + match self.0 { + Some(ref fd) => fd.poll_write_ready(cx), + None => Poll::Ready(Err(io::ErrorKind::InvalidInput.into())), + } + } +} + +impl AsyncRead for GenericStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { + loop { + let res = unsafe { libc::read(self.raw_fd(), buf.as_mut_ptr() as *mut _, buf.len()) }; + if res >= 0 { + return Poll::Ready(Ok(res as usize)); + } + + let err = io::Error::last_os_error(); + if err.kind() == io::ErrorKind::WouldBlock { + match ready!(self.poll_read_ready(cx)) { + Ok(_) => continue, + Err(err) => return Poll::Ready(Err(err)), + } + } + return Poll::Ready(Err(err)); + } + } +} + +impl AsyncWrite for GenericStream { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + loop { + let res = unsafe { libc::write(self.raw_fd(), buf.as_ptr() as *const _, buf.len()) }; + if res >= 0 { + return Poll::Ready(Ok(res as usize)); + } + + let err = io::Error::last_os_error(); + if err.kind() == io::ErrorKind::WouldBlock { + match ready!(self.poll_write_ready(cx)) { + Ok(_) => continue, + Err(err) => return Poll::Ready(Err(err)), + } + } + return Poll::Ready(Err(err)); + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { + std::mem::drop(self.get_mut().0.take()); + Poll::Ready(Ok(())) + } +} + /// Byte vector utilities. pub mod vec { /// Create an uninitialized byte vector of a specific size. @@ -132,3 +273,33 @@ impl IoVecMut<'_> { } } } + +#[macro_export] +macro_rules! libc_wrap { + ($expr:expr) => {{ + let res = $expr; + if res == -1 { + Err(io::Error::last_os_error()) + } else { + Ok::<_, io::Error>(res) + } + }}; +} + +#[macro_export] +macro_rules! libc_try { + ($expr:expr) => {{ + let res = $expr; + if res == -1 { + return Err(io::Error::last_os_error()); + } else { + res + } + }}; +} + +pub fn path_ptr(path: &std::path::Path) -> *const libc::c_char { + use std::os::unix::ffi::OsStrExt; + + path.as_os_str().as_bytes().as_ptr() as *const libc::c_char +} -- 2.39.5