// Source: git.proxmox.com, pve-lxc-syscalld.git, blame view of src/executor.rs
// at commit "also use pidfd_open for explicit pids" (earliest blamed commit: 64d527ab, WB).
use std::cell::RefCell;
use std::collections::VecDeque;
use std::future::Future;
use std::io;
use std::mem::ManuallyDrop;
use std::panic::{self, AssertUnwindSafe};
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex, Weak};
use std::task::{Context, Poll};
use std::thread::JoinHandle;

/// Type-erased, heap-allocated future run by the pool. It produces no value and
/// must be `Send`, since whichever worker thread dequeues it will poll it.
type BoxFut = Box<dyn Future<Output = ()> + Send + 'static>;

12#[derive(Clone)]
13struct Task(Arc<TaskInner>);
14
15impl Task {
16 fn into_raw(this: Task) -> *const TaskInner {
17 Arc::into_raw(this.0)
18 }
19
20 unsafe fn from_raw(ptr: *const TaskInner) -> Self {
21 Self(Arc::from_raw(ptr))
22 }
23
24 fn wake(self) {
25 if let Some(queue) = self.0.queue.upgrade() {
26 queue.queue(self);
27 }
28 }
29
30 fn into_raw_waker(this: Task) -> std::task::RawWaker {
31 std::task::RawWaker::new(
32 Task::into_raw(this) as *const (),
33 &std::task::RawWakerVTable::new(
34 waker_clone_fn,
35 waker_wake_fn,
36 waker_wake_by_ref_fn,
37 waker_drop_fn,
38 ),
39 )
40 }
41}
42
43struct TaskInner {
44 future: Mutex<Option<BoxFut>>,
45 queue: Weak<TaskQueue>,
46}
47
48struct TaskQueue {
49 queue: Mutex<VecDeque<Task>>,
50 queue_cv: Condvar,
51}
52
53impl TaskQueue {
54 fn new() -> Self {
55 Self {
56 queue: Mutex::new(VecDeque::with_capacity(32)),
57 queue_cv: Condvar::new(),
58 }
59 }
60
61 fn new_task(self: Arc<TaskQueue>, future: BoxFut) {
62 let task = Task(Arc::new(TaskInner {
63 future: Mutex::new(Some(future)),
64 queue: Arc::downgrade(&self),
65 }));
66
67 self.queue(task);
68 }
69
70 fn queue(&self, task: Task) {
71 let mut queue = self.queue.lock().unwrap();
72 queue.push_back(task);
73 self.queue_cv.notify_one();
74 }
75
76 /// Blocks until a task is available
77 fn get_task(&self) -> Task {
78 let mut queue = self.queue.lock().unwrap();
79 loop {
80 if let Some(task) = queue.pop_front() {
81 return task;
82 } else {
83 queue = self.queue_cv.wait(queue).unwrap();
84 }
85 }
86 }
87}
88
89pub struct ThreadPool {
6f911968 90 _threads: Mutex<Vec<JoinHandle<()>>>,
64d527ab
WB
91 queue: Arc<TaskQueue>,
92}
93
94impl ThreadPool {
95 pub fn new() -> io::Result<Self> {
725170f2 96 let count = num_cpus()?;
64d527ab
WB
97
98 let queue = Arc::new(TaskQueue::new());
99
100 let mut threads = Vec::new();
101 for thread_id in 0..count {
102 threads.push(std::thread::spawn({
103 let queue = Arc::clone(&queue);
104 move || thread_main(queue, thread_id)
105 }));
106 }
107
108 Ok(Self {
6f911968 109 _threads: Mutex::new(threads),
64d527ab
WB
110 queue,
111 })
112 }
113
114 pub fn spawn_ok<T>(&self, future: T)
115 where
116 T: Future<Output = ()> + Send + 'static,
117 {
118 self.do_spawn(Box::new(future));
119 }
120
121 fn do_spawn(&self, future: BoxFut) {
122 Arc::clone(&self.queue).new_task(future);
123 }
124
125 pub fn run<R, T>(&self, future: T) -> R
126 where
127 T: Future<Output = R> + Send + 'static,
128 R: Send + 'static,
129 {
130 let mutex: Arc<Mutex<Option<R>>> = Arc::new(Mutex::new(None));
131 let cv = Arc::new(Condvar::new());
132 let mut guard = mutex.lock().unwrap();
133 self.spawn_ok({
134 let mutex = Arc::clone(&mutex);
135 let cv = Arc::clone(&cv);
136 async move {
137 let result = future.await;
138 *(mutex.lock().unwrap()) = Some(result);
139 cv.notify_all();
140 }
141 });
142 loop {
143 guard = cv.wait(guard).unwrap();
144 if let Some(result) = guard.take() {
145 return result;
146 }
147 }
148 }
149}
150
151thread_local! {
152 static CURRENT_QUEUE: RefCell<*const TaskQueue> = RefCell::new(std::ptr::null());
153 static CURRENT_TASK: RefCell<*const Task> = RefCell::new(std::ptr::null());
154}
155
156fn thread_main(task_queue: Arc<TaskQueue>, _thread_id: usize) {
157 CURRENT_QUEUE.with(|q| *q.borrow_mut() = task_queue.as_ref() as *const TaskQueue);
158
159 let local_waker = unsafe {
160 std::task::Waker::from_raw(std::task::RawWaker::new(
161 std::ptr::null(),
162 &std::task::RawWakerVTable::new(
163 local_waker_clone_fn,
164 local_waker_wake_fn,
165 local_waker_wake_fn,
166 local_waker_drop_fn,
167 ),
168 ))
169 };
170
171 let mut context = Context::from_waker(&local_waker);
172
173 loop {
174 let task: Task = task_queue.get_task();
175 let task: Pin<&Task> = Pin::new(&task);
176 let task = task.get_ref();
177 CURRENT_TASK.with(|c| *c.borrow_mut() = task as *const Task);
178
179 let mut task_future = task.0.future.lock().unwrap();
180 match task_future.take() {
181 Some(mut future) => {
725170f2 182 //eprintln!("Thread {} has some work!", thread_id);
64d527ab
WB
183 let pin = unsafe { Pin::new_unchecked(&mut *future) };
184 match pin.poll(&mut context) {
185 Poll::Ready(()) => (), // done with that task
186 Poll::Pending => {
187 *task_future = Some(future);
188 }
189 }
190 }
191 None => eprintln!("task polled after ready"),
192 }
193 }
194}
195
196unsafe fn local_waker_clone_fn(_: *const ()) -> std::task::RawWaker {
197 let task: Task = CURRENT_TASK.with(|t| Task::clone(&**t.borrow()));
198 Task::into_raw_waker(task)
199}
200
201unsafe fn local_waker_wake_fn(_: *const ()) {
202 let task: Task = CURRENT_TASK.with(|t| Task::clone(&**t.borrow()));
203 CURRENT_QUEUE.with(|q| (**q.borrow()).queue(task));
204}
205
/// `drop` entry of the data-less worker waker: it owns no reference, so there is
/// nothing to release.
unsafe fn local_waker_drop_fn(_data: *const ()) {
    // intentionally empty
}

208unsafe fn waker_clone_fn(this: *const ()) -> std::task::RawWaker {
209 let this = Task::from_raw(this as *const TaskInner);
210 let clone = this.clone();
211 let _ = Task::into_raw(this);
212 Task::into_raw_waker(clone)
213}
214
215unsafe fn waker_wake_fn(this: *const ()) {
216 let this = Task::from_raw(this as *const TaskInner);
217 this.wake();
218}
219
220unsafe fn waker_wake_by_ref_fn(this: *const ()) {
221 let this = Task::from_raw(this as *const TaskInner);
222 this.clone().wake();
223 let _ = Task::into_raw(this);
224}
225
226unsafe fn waker_drop_fn(this: *const ()) {
227 let _this = Task::from_raw(this as *const TaskInner);
228}
229
725170f2 230fn num_cpus() -> io::Result<usize> {
64d527ab
WB
231 let rc = unsafe { libc::sysconf(libc::_SC_NPROCESSORS_ONLN) };
232 if rc < 0 {
233 Err(io::Error::last_os_error())
234 } else {
235 Ok(rc as usize)
236 }
237}