]> git.proxmox.com Git - pve-lxc-syscalld.git/blob - src/fork.rs
replace custom Fd with std OwnedFd
[pve-lxc-syscalld.git] / src / fork.rs
//! Fork helper.
//!
//! Note that forking in Rust can be dangerous. A forked child must consider all mutexes to be in
//! a broken state and cannot rely on the lifetimes of any of its references, so we must be
//! careful about what kind of data we continue to work with.
6
7 use std::convert::TryInto;
8 use std::io;
9 use std::os::raw::c_int;
10 use std::os::unix::io::{FromRawFd, IntoRawFd};
11 use std::panic::UnwindSafe;
12
13 use tokio::io::AsyncReadExt;
14
15 use crate::io::pipe::{self, Pipe};
16 use crate::syscall::SyscallStatus;
17
18 pub async fn forking_syscall<F>(func: F) -> io::Result<SyscallStatus>
19 where
20 F: FnOnce() -> io::Result<SyscallStatus> + UnwindSafe,
21 {
22 let mut fork = Fork::new(func)?;
23 let result = fork.get_result().await?;
24 fork.wait()?;
25 Ok(result)
26 }
27
28 pub struct Fork {
29 pid: Option<libc::pid_t>,
30 // FIXME: abuse! tokio-fs is not updated to futures@0.3 yet, but a TcpStream does the same
31 // thing as a file when it's already open anyway...
32 out: Pipe<pipe::Read>,
33 }
34
35 impl Drop for Fork {
36 fn drop(&mut self) {
37 if self.pid.is_some() {
38 let _ = self.wait();
39 }
40 }
41 }
42
/// Fixed-size record that the child writes to the pipe and the parent reads back.
///
/// Both sides treat the struct's raw bytes as the message. `C` ordering together with `packed`
/// (no padding) means the three fields occupy 16 consecutive bytes, in declaration order.
#[repr(C)]
#[repr(packed)]
struct Data {
    /// Return value of the syscall; `-1` unless the call succeeded.
    val: i64,
    /// `0` on success, otherwise the error from `SyscallStatus::Err`; `-1` when `failure` is set.
    error: i32,
    /// OS error code (or `EFAULT`) when the child's closure returned an `io::Error`; `0` otherwise.
    failure: i32,
}
49
50 impl Fork {
51 pub fn new<F>(func: F) -> io::Result<Self>
52 where
53 F: FnOnce() -> io::Result<SyscallStatus> + UnwindSafe,
54 {
55 let (pipe_r, pipe_w) = pipe::pipe_fds()?;
56
57 let pid = c_try!(unsafe { libc::fork() });
58 if pid == 0 {
59 drop(pipe_r);
60 let pipe_w = pipe_w.into_fd();
61 let _ = std::panic::catch_unwind(move || {
62 crate::tools::set_fd_nonblocking(&pipe_w, false).unwrap();
63 let mut pipe_w = unsafe { std::fs::File::from_raw_fd(pipe_w.into_raw_fd()) };
64 let out = match func() {
65 Ok(SyscallStatus::Ok(val)) => Data {
66 val,
67 error: 0,
68 failure: 0,
69 },
70 Ok(SyscallStatus::Err(error)) => Data {
71 val: -1,
72 error: error as _,
73 failure: 0,
74 },
75 Err(err) => Data {
76 val: -1,
77 error: -1,
78 failure: err.raw_os_error().unwrap_or(libc::EFAULT),
79 },
80 };
81
82 let slice = unsafe {
83 std::slice::from_raw_parts(
84 &out as *const Data as *const u8,
85 std::mem::size_of::<Data>(),
86 )
87 };
88
89 use std::io::Write;
90 match pipe_w.write_all(slice) {
91 Ok(()) => unsafe { libc::_exit(0) },
92 Err(_) => unsafe { libc::_exit(1) },
93 }
94 });
95 unsafe {
96 libc::_exit(-1);
97 }
98 }
99 drop(pipe_w);
100
101 let pipe_r = match pipe_r.try_into() {
102 Ok(p) => p,
103 Err(err) => {
104 unsafe {
105 libc::kill(pid, 9);
106 }
107 return Err(err);
108 }
109 };
110
111 Ok(Self {
112 pid: Some(pid),
113 out: pipe_r,
114 })
115 }
116
117 pub fn wait(&mut self) -> io::Result<()> {
118 let my_pid = self.pid.take().unwrap();
119 let mut status: c_int = -1;
120
121 loop {
122 match c_result!(unsafe { libc::waitpid(my_pid, &mut status, 0) }) {
123 Ok(pid) if pid == my_pid => break,
124 Ok(_other) => continue,
125 Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
126 Err(other) => return Err(other),
127 }
128 }
129
130 if status != 0 {
131 Err(io::Error::new(
132 io::ErrorKind::Other,
133 "error in child process",
134 ))
135 } else {
136 Ok(())
137 }
138 }
139
140 pub async fn get_result(&mut self) -> io::Result<SyscallStatus> {
141 let mut data: Data = unsafe { std::mem::zeroed() };
142 // Compiler bug: we currently need to put the slice into a temporary variable...
143 let dataslice: &mut [u8] = unsafe {
144 std::slice::from_raw_parts_mut(
145 &mut data as *mut Data as *mut u8,
146 std::mem::size_of::<Data>(),
147 )
148 };
149 self.out.read_exact(dataslice).await?;
150 //self.read_exact(unsafe {
151 // std::slice::from_raw_parts_mut(
152 // &mut data as *mut Data as *mut u8,
153 // std::mem::size_of::<Data>(),
154 // )
155 //})
156 //.await?;
157 if data.failure != 0 {
158 Err(io::Error::from_raw_os_error(data.failure))
159 } else if data.error == 0 {
160 Ok(SyscallStatus::Ok(data.val))
161 } else {
162 Ok(SyscallStatus::Err(data.error))
163 }
164 }
165 }