]>
Commit | Line | Data |
---|---|---|
64d527ab WB |
1 | use std::cell::RefCell; |
2 | use std::collections::VecDeque; | |
3 | use std::future::Future; | |
4 | use std::io; | |
5 | use std::pin::Pin; | |
6 | use std::sync::{Arc, Condvar, Mutex, Weak}; | |
7 | use std::task::{Context, Poll}; | |
8 | use std::thread::JoinHandle; | |
9 | ||
// A boxed, type-erased future producing no output, safe to move across
// threads ('static + Send) — the unit of work stored inside a Task.
type BoxFut = Box<dyn Future<Output = ()> + Send + 'static>;
11 | ||
/// Handle to a spawned unit of work. Cloning bumps the inner `Arc`'s
/// refcount; the `waker_*_fn` vtable functions below convert these
/// handles to and from raw pointers.
#[derive(Clone)]
struct Task(Arc<TaskInner>);
14 | ||
impl Task {
    /// Consumes the handle and leaks its `Arc` as a raw pointer,
    /// transferring one strong reference to the caller (paired with
    /// `from_raw`).
    fn into_raw(this: Task) -> *const TaskInner {
        Arc::into_raw(this.0)
    }

    /// Rebuilds a handle from a pointer produced by `into_raw`.
    ///
    /// # Safety
    /// `ptr` must originate from `Task::into_raw`, and the strong
    /// reference it carries must not have been reclaimed already.
    unsafe fn from_raw(ptr: *const TaskInner) -> Self {
        Self(Arc::from_raw(ptr))
    }

    /// Re-queues this task for another poll. A no-op if the owning queue
    /// (held weakly, avoiding a reference cycle) has been dropped.
    fn wake(self) {
        if let Some(queue) = self.0.queue.upgrade() {
            queue.queue(self);
        }
    }

    /// Wraps this task in a `RawWaker`; the vtable round-trips the `Arc`
    /// through raw pointers (see the `waker_*_fn` free functions).
    fn into_raw_waker(this: Task) -> std::task::RawWaker {
        std::task::RawWaker::new(
            Task::into_raw(this) as *const (),
            &std::task::RawWakerVTable::new(
                waker_clone_fn,
                waker_wake_fn,
                waker_wake_by_ref_fn,
                waker_drop_fn,
            ),
        )
    }
}
42 | ||
// Shared state behind a `Task` handle.
struct TaskInner {
    // The pending future; `None` once it has completed. The Mutex
    // serializes polls of the same task across worker threads.
    future: Mutex<Option<BoxFut>>,
    // Weak back-reference to the owning queue so a task that outlives
    // its pool does not keep the queue alive (its wake becomes a no-op).
    queue: Weak<TaskQueue>,
}
47 | ||
// A blocking run queue shared by all workers: threads park on the
// condvar inside `get_task` until a task is pushed.
struct TaskQueue {
    queue: Mutex<VecDeque<Task>>,
    queue_cv: Condvar,
}
52 | ||
53 | impl TaskQueue { | |
54 | fn new() -> Self { | |
55 | Self { | |
56 | queue: Mutex::new(VecDeque::with_capacity(32)), | |
57 | queue_cv: Condvar::new(), | |
58 | } | |
59 | } | |
60 | ||
61 | fn new_task(self: Arc<TaskQueue>, future: BoxFut) { | |
62 | let task = Task(Arc::new(TaskInner { | |
63 | future: Mutex::new(Some(future)), | |
64 | queue: Arc::downgrade(&self), | |
65 | })); | |
66 | ||
67 | self.queue(task); | |
68 | } | |
69 | ||
70 | fn queue(&self, task: Task) { | |
71 | let mut queue = self.queue.lock().unwrap(); | |
72 | queue.push_back(task); | |
73 | self.queue_cv.notify_one(); | |
74 | } | |
75 | ||
76 | /// Blocks until a task is available | |
77 | fn get_task(&self) -> Task { | |
78 | let mut queue = self.queue.lock().unwrap(); | |
79 | loop { | |
80 | if let Some(task) = queue.pop_front() { | |
81 | return task; | |
82 | } else { | |
83 | queue = self.queue_cv.wait(queue).unwrap(); | |
84 | } | |
85 | } | |
86 | } | |
87 | } | |
88 | ||
/// A minimal multi-threaded future executor.
pub struct ThreadPool {
    // Worker join handles, kept alive for the pool's lifetime.
    // NOTE(review): there is no shutdown path — `thread_main` loops
    // forever and these handles are never joined.
    _threads: Mutex<Vec<JoinHandle<()>>>,
    queue: Arc<TaskQueue>,
}
93 | ||
94 | impl ThreadPool { | |
95 | pub fn new() -> io::Result<Self> { | |
725170f2 | 96 | let count = num_cpus()?; |
64d527ab WB |
97 | |
98 | let queue = Arc::new(TaskQueue::new()); | |
99 | ||
100 | let mut threads = Vec::new(); | |
101 | for thread_id in 0..count { | |
102 | threads.push(std::thread::spawn({ | |
103 | let queue = Arc::clone(&queue); | |
104 | move || thread_main(queue, thread_id) | |
105 | })); | |
106 | } | |
107 | ||
108 | Ok(Self { | |
6f911968 | 109 | _threads: Mutex::new(threads), |
64d527ab WB |
110 | queue, |
111 | }) | |
112 | } | |
113 | ||
114 | pub fn spawn_ok<T>(&self, future: T) | |
115 | where | |
116 | T: Future<Output = ()> + Send + 'static, | |
117 | { | |
118 | self.do_spawn(Box::new(future)); | |
119 | } | |
120 | ||
121 | fn do_spawn(&self, future: BoxFut) { | |
122 | Arc::clone(&self.queue).new_task(future); | |
123 | } | |
124 | ||
125 | pub fn run<R, T>(&self, future: T) -> R | |
126 | where | |
127 | T: Future<Output = R> + Send + 'static, | |
128 | R: Send + 'static, | |
129 | { | |
130 | let mutex: Arc<Mutex<Option<R>>> = Arc::new(Mutex::new(None)); | |
131 | let cv = Arc::new(Condvar::new()); | |
132 | let mut guard = mutex.lock().unwrap(); | |
133 | self.spawn_ok({ | |
134 | let mutex = Arc::clone(&mutex); | |
135 | let cv = Arc::clone(&cv); | |
136 | async move { | |
137 | let result = future.await; | |
138 | *(mutex.lock().unwrap()) = Some(result); | |
139 | cv.notify_all(); | |
140 | } | |
141 | }); | |
142 | loop { | |
143 | guard = cv.wait(guard).unwrap(); | |
144 | if let Some(result) = guard.take() { | |
145 | return result; | |
146 | } | |
147 | } | |
148 | } | |
149 | } | |
150 | ||
thread_local! {
    // Raw pointer to the queue owned by this worker thread; set once in
    // `thread_main` and read by `local_waker_wake_fn`. Null on any
    // thread that is not a pool worker.
    static CURRENT_QUEUE: RefCell<*const TaskQueue> = RefCell::new(std::ptr::null());
    // Points at the stack-local `Task` currently being polled in
    // `thread_main`; only valid while that poll is in progress.
    static CURRENT_TASK: RefCell<*const Task> = RefCell::new(std::ptr::null());
}
155 | ||
/// Worker loop: publishes this thread's queue/current-task pointers to
/// the thread-locals above, then repeatedly pops a task and polls it.
fn thread_main(task_queue: Arc<TaskQueue>, _thread_id: usize) {
    // Publish the queue pointer for the local waker functions. The Arc
    // is owned by this function, which never returns, so the raw pointer
    // stays valid for the thread's lifetime.
    CURRENT_QUEUE.with(|q| *q.borrow_mut() = task_queue.as_ref() as *const TaskQueue);

    // An allocation-free "current task" waker: its data pointer is null
    // and its callbacks consult CURRENT_TASK / CURRENT_QUEUE instead.
    // Cloning it (local_waker_clone_fn) upgrades to a real, Arc-owning
    // task waker that remains valid after this poll ends.
    let local_waker = unsafe {
        std::task::Waker::from_raw(std::task::RawWaker::new(
            std::ptr::null(),
            &std::task::RawWakerVTable::new(
                local_waker_clone_fn,
                local_waker_wake_fn,
                local_waker_wake_fn, // wake and wake_by_ref are identical: nothing is consumed
                local_waker_drop_fn,
            ),
        ))
    };

    let mut context = Context::from_waker(&local_waker);

    loop {
        let task: Task = task_queue.get_task();
        // NOTE(review): this Pin round-trip is a no-op (`Pin::new` on a
        // shared reference, immediately unwrapped); `task` is just `&Task`.
        let task: Pin<&Task> = Pin::new(&task);
        let task = task.get_ref();
        // Valid only for the duration of the poll below; overwritten on
        // the next loop iteration before it could dangle into a read.
        CURRENT_TASK.with(|c| *c.borrow_mut() = task as *const Task);

        // The future's mutex is held across poll(); if a waker re-queues
        // this task mid-poll, the other worker blocks on this lock until
        // the Pending future has been stored back.
        let mut task_future = task.0.future.lock().unwrap();
        match task_future.take() {
            Some(mut future) => {
                //eprintln!("Thread {} has some work!", thread_id);
                // SAFETY: the future never leaves its box; it is only
                // polled through this pinned reference until completion.
                let pin = unsafe { Pin::new_unchecked(&mut *future) };
                match pin.poll(&mut context) {
                    Poll::Ready(()) => (), // done with that task
                    Poll::Pending => {
                        // Put the future back so a later wake can resume it.
                        *task_future = Some(future);
                    }
                }
            }
            None => eprintln!("task polled after ready"),
        }
    }
}
195 | ||
// Cloning the thread-local waker promotes it to an owning waker: grab
// the task currently being polled (CURRENT_TASK is only non-null inside
// thread_main's poll) and build an Arc-backed RawWaker from a clone.
unsafe fn local_waker_clone_fn(_: *const ()) -> std::task::RawWaker {
    let task: Task = CURRENT_TASK.with(|t| Task::clone(&**t.borrow()));
    Task::into_raw_waker(task)
}
200 | ||
// Wake issued from inside poll(): re-queue the task currently being
// polled onto this worker's queue. Serves as both wake and wake_by_ref
// since the null data pointer means there is nothing to consume.
unsafe fn local_waker_wake_fn(_: *const ()) {
    let task: Task = CURRENT_TASK.with(|t| Task::clone(&**t.borrow()));
    CURRENT_QUEUE.with(|q| (**q.borrow()).queue(task));
}
205 | ||
// The local waker owns nothing (null data pointer): dropping is a no-op.
unsafe fn local_waker_drop_fn(_: *const ()) {}
207 | ||
// Vtable clone for Arc-backed wakers: temporarily rematerialize the Task
// from the raw pointer, clone it (refcount +1), then leak the original
// back so the source waker's reference is preserved.
unsafe fn waker_clone_fn(this: *const ()) -> std::task::RawWaker {
    let this = Task::from_raw(this as *const TaskInner);
    let clone = this.clone();
    let _ = Task::into_raw(this);
    Task::into_raw_waker(clone)
}
214 | ||
// Vtable wake (by value): consumes the waker's reference — from_raw
// takes ownership and `wake` drops the Task after re-queuing it.
unsafe fn waker_wake_fn(this: *const ()) {
    let this = Task::from_raw(this as *const TaskInner);
    this.wake();
}
219 | ||
// Vtable wake_by_ref: wake a clone and leak the original back, leaving
// the caller's reference count untouched.
unsafe fn waker_wake_by_ref_fn(this: *const ()) {
    let this = Task::from_raw(this as *const TaskInner);
    this.clone().wake();
    let _ = Task::into_raw(this);
}
225 | ||
// Vtable drop: reclaim the strong reference leaked by into_raw
// (refcount -1) by letting the rebuilt Task fall out of scope.
unsafe fn waker_drop_fn(this: *const ()) {
    let _this = Task::from_raw(this as *const TaskInner);
}
229 | ||
/// Number of CPUs available to this process, used to size the pool.
///
/// # Errors
/// Propagates the OS error if the degree of parallelism cannot be
/// determined.
fn num_cpus() -> io::Result<usize> {
    // std::thread::available_parallelism replaces the former
    // libc::sysconf(_SC_NPROCESSORS_ONLN) call: same io::Result
    // interface, portable beyond unix, and no external-crate dependency.
    Ok(std::thread::available_parallelism()?.get())
}