1 use std
::cell
::RefCell
;
2 use std
::collections
::VecDeque
;
3 use std
::future
::Future
;
6 use std
::sync
::{Arc, Condvar, Mutex, Weak}
;
7 use std
::task
::{Context, Poll}
;
8 use std
::thread
::JoinHandle
;
/// Type-erased, heap-allocated future accepted by the pool.
///
/// It resolves to `()`, is `Send` so that any worker thread may poll it, and is
/// `'static` so it can outlive the call that spawned it. The type itself is *not*
/// pinned: the worker pins the boxed future in place (via `Pin::new_unchecked`)
/// each time it polls it, so the future must never be moved out of its `Box`.
type BoxFut = Box<dyn Future<Output = ()> + Send + 'static>;
13 struct Task(Arc
<TaskInner
>);
16 fn into_raw(this
: Task
) -> *const TaskInner
{
20 unsafe fn from_raw(ptr
: *const TaskInner
) -> Self {
21 Self(Arc
::from_raw(ptr
))
25 if let Some(queue
) = self.0.queue
.upgrade() {
30 fn into_raw_waker(this
: Task
) -> std
::task
::RawWaker
{
31 std
::task
::RawWaker
::new(
32 Task
::into_raw(this
) as *const (),
33 &std
::task
::RawWakerVTable
::new(
44 future
: Mutex
<Option
<BoxFut
>>,
45 queue
: Weak
<TaskQueue
>,
49 queue
: Mutex
<VecDeque
<Task
>>,
56 queue
: Mutex
::new(VecDeque
::with_capacity(32)),
57 queue_cv
: Condvar
::new(),
61 fn new_task(self: Arc
<TaskQueue
>, future
: BoxFut
) {
62 let task
= Task(Arc
::new(TaskInner
{
63 future
: Mutex
::new(Some(future
)),
64 queue
: Arc
::downgrade(&self),
70 fn queue(&self, task
: Task
) {
71 let mut queue
= self.queue
.lock().unwrap();
72 queue
.push_back(task
);
73 self.queue_cv
.notify_one();
76 /// Blocks until a task is available
77 fn get_task(&self) -> Task
{
78 let mut queue
= self.queue
.lock().unwrap();
80 if let Some(task
) = queue
.pop_front() {
83 queue
= self.queue_cv
.wait(queue
).unwrap();
89 pub struct ThreadPool
{
90 _threads
: Mutex
<Vec
<JoinHandle
<()>>>,
91 queue
: Arc
<TaskQueue
>,
95 pub fn new() -> io
::Result
<Self> {
96 let count
= num_cpus()?
;
98 let queue
= Arc
::new(TaskQueue
::new());
100 let mut threads
= Vec
::new();
101 for thread_id
in 0..count
{
102 threads
.push(std
::thread
::spawn({
103 let queue
= Arc
::clone(&queue
);
104 move || thread_main(queue
, thread_id
)
109 _threads
: Mutex
::new(threads
),
114 pub fn spawn_ok
<T
>(&self, future
: T
)
116 T
: Future
<Output
= ()> + Send
+ '
static,
118 self.do_spawn(Box
::new(future
));
121 fn do_spawn(&self, future
: BoxFut
) {
122 Arc
::clone(&self.queue
).new_task(future
);
125 pub fn run
<R
, T
>(&self, future
: T
) -> R
127 T
: Future
<Output
= R
> + Send
+ '
static,
130 let mutex
: Arc
<Mutex
<Option
<R
>>> = Arc
::new(Mutex
::new(None
));
131 let cv
= Arc
::new(Condvar
::new());
132 let mut guard
= mutex
.lock().unwrap();
134 let mutex
= Arc
::clone(&mutex
);
135 let cv
= Arc
::clone(&cv
);
137 let result
= future
.await
;
138 *(mutex
.lock().unwrap()) = Some(result
);
143 guard
= cv
.wait(guard
).unwrap();
144 if let Some(result
) = guard
.take() {
152 static CURRENT_QUEUE
: RefCell
<*const TaskQueue
> = RefCell
::new(std
::ptr
::null());
153 static CURRENT_TASK
: RefCell
<*const Task
> = RefCell
::new(std
::ptr
::null());
156 fn thread_main(task_queue
: Arc
<TaskQueue
>, _thread_id
: usize) {
157 CURRENT_QUEUE
.with(|q
| *q
.borrow_mut() = task_queue
.as_ref() as *const TaskQueue
);
159 let local_waker
= unsafe {
160 std
::task
::Waker
::from_raw(std
::task
::RawWaker
::new(
162 &std
::task
::RawWakerVTable
::new(
163 local_waker_clone_fn
,
171 let mut context
= Context
::from_waker(&local_waker
);
174 let task
: Task
= task_queue
.get_task();
175 let task
: Pin
<&Task
> = Pin
::new(&task
);
176 let task
= task
.get_ref();
177 CURRENT_TASK
.with(|c
| *c
.borrow_mut() = task
as *const Task
);
179 let mut task_future
= task
.0.future
.lock().unwrap();
180 match task_future
.take() {
181 Some(mut future
) => {
182 //eprintln!("Thread {} has some work!", thread_id);
183 let pin
= unsafe { Pin::new_unchecked(&mut *future) }
;
184 match pin
.poll(&mut context
) {
185 Poll
::Ready(()) => (), // done with that task
187 *task_future
= Some(future
);
191 None
=> eprintln
!("task polled after ready"),
196 unsafe fn local_waker_clone_fn(_
: *const ()) -> std
::task
::RawWaker
{
197 let task
: Task
= CURRENT_TASK
.with(|t
| Task
::clone(&**t
.borrow()));
198 Task
::into_raw_waker(task
)
201 unsafe fn local_waker_wake_fn(_
: *const ()) {
202 let task
: Task
= CURRENT_TASK
.with(|t
| Task
::clone(&**t
.borrow()));
203 CURRENT_QUEUE
.with(|q
| (**q
.borrow()).queue(task
));
/// Drop hook for the per-thread waker that `thread_main` builds around the
/// thread-local `CURRENT_TASK` (its `clone` sibling is registered in that vtable).
///
/// Deliberately a no-op. The local waker's data pointer is ignored: the clone and wake
/// hooks look up the current task through `CURRENT_TASK` instead. So there is nothing
/// to release here.
///
/// # Safety
///
/// Intended to be called only through a `RawWakerVTable`. The argument is never
/// dereferenced, so any pointer value (including null) is accepted.
unsafe fn local_waker_drop_fn(_data: *const ()) {
    // Nothing to release: the local waker owns no resources.
}
208 unsafe fn waker_clone_fn(this
: *const ()) -> std
::task
::RawWaker
{
209 let this
= Task
::from_raw(this
as *const TaskInner
);
210 let clone
= this
.clone();
211 let _
= Task
::into_raw(this
);
212 Task
::into_raw_waker(clone
)
215 unsafe fn waker_wake_fn(this
: *const ()) {
216 let this
= Task
::from_raw(this
as *const TaskInner
);
220 unsafe fn waker_wake_by_ref_fn(this
: *const ()) {
221 let this
= Task
::from_raw(this
as *const TaskInner
);
223 let _
= Task
::into_raw(this
);
226 unsafe fn waker_drop_fn(this
: *const ()) {
227 let _this
= Task
::from_raw(this
as *const TaskInner
);
230 fn num_cpus() -> io
::Result
<usize> {
231 let rc
= unsafe { libc::sysconf(libc::_SC_NPROCESSORS_ONLN) }
;
233 Err(io
::Error
::last_os_error())