use std::mem::MaybeUninit;
use std::ptr;
-use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+
+// push() and pop() only perform a handful of memory reads and writes, so a
+// spin lock should be cheaper here than a full mutex:
+
+struct SpinLock(AtomicBool);
+struct SpinLockGuard<'a>(&'a AtomicBool);
+
+impl SpinLock {
+ const fn new() -> Self {
+ Self(AtomicBool::new(false))
+ }
+
+    fn lock(&self) -> SpinLockGuard {
+        // `compare_and_swap` is deprecated; `compare_exchange_weak` is its
+        // replacement, and a spurious failure just costs one more spin.
+        while self
+            .0
+            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
+            .is_err()
+        {
+            std::hint::spin_loop();
+        }
+        SpinLockGuard(&self.0)
+    }
+}
+
+impl Drop for SpinLockGuard<'_> {
+ fn drop(&mut self) {
+ self.0.store(false, Ordering::Release);
+ }
+}
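+
+// Usage is RAII-style: the lock is held for as long as the guard lives.
+// Illustrative only (`SpinLock::new()` is const, so a static works):
+//
+//     static LOCK: SpinLock = SpinLock::new();
+//
+//     let _guard = LOCK.lock(); // spins until acquired
+//     /* short critical section */
+//     // released when `_guard` is dropped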
pub struct Ring<T> {
head: usize,
tail: usize,
mask: usize,
data: Box<[MaybeUninit<T>]>,
+ push_lock: SpinLock,
+ pop_lock: SpinLock,
}
impl<T> Ring<T> {
    pub fn new(size: usize) -> Self {
let mut data = Vec::with_capacity(size);
- for _ in 0..size {
- data.push(MaybeUninit::uninit())
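+    // Sound for `Vec<MaybeUninit<T>>`: `MaybeUninit` imposes no validity
+    // requirements, so the uninitialized capacity can be exposed directly.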
+ unsafe {
+ data.set_len(size);
}
        Self {
            head: 0,
            tail: 0,
            mask: size - 1,
data: data.into_boxed_slice(),
+ push_lock: SpinLock::new(),
+ pop_lock: SpinLock::new(),
}
}
+    pub fn len(&self) -> usize {
+        // Plain reads of `head`/`tail` would race with the atomic writes in
+        // try_push()/try_pop(), so go through the atomic views. Loading
+        // `head` first keeps the subtraction from underflowing.
+        let head = self.atomic_head().load(Ordering::Acquire);
+        let tail = self.atomic_tail().load(Ordering::Acquire);
+        tail - head
+    }
+
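+    // `AtomicUsize` has the same in-memory representation as `usize`, which
+    // is what makes this reinterpreting cast workable; cross-thread access
+    // still has to go through the atomic view it returns.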
#[inline]
fn atomic_tail(&self) -> &AtomicUsize {
unsafe { &*(&self.tail as *const usize as *const AtomicUsize) }
}
pub fn try_push(&self, data: T) -> bool {
- let head = self.head;
+ let _guard = self.push_lock.lock();
+
let tail = self.atomic_tail().load(Ordering::Acquire);
+    // `head` is advanced by concurrent poppers, so it must be read
+    // atomically; `tail` is only written while `push_lock` is held.
+    let head = self.atomic_head().load(Ordering::Acquire);
if tail - head == self.data.len() {
return false;
}
unsafe {
- ptr::write(self.data[tail & self.mask].as_ptr() as *mut _, data);
+ ptr::write(self.data[tail & self.mask].as_ptr() as *mut T, data);
}
-
- self.atomic_tail().fetch_add(1, Ordering::Release);
+ self.atomic_tail().store(tail + 1, Ordering::Release);
true
}
pub fn try_pop(&self) -> Option<T> {
- let tail = self.tail;
+ let _guard = self.pop_lock.lock();
+
let head = self.atomic_head().load(Ordering::Acquire);
+    // `tail` is advanced by concurrent pushers, so it must be read
+    // atomically; `head` is only written while `pop_lock` is held.
+    let tail = self.atomic_tail().load(Ordering::Acquire);
if tail - head == 0 {
return None;
}
- let data = unsafe { std::ptr::read(self.data[head & self.mask].as_ptr()) };
+ let data = unsafe { ptr::read(self.data[head & self.mask].as_ptr()) };
- self.atomic_head().fetch_add(1, Ordering::Release);
+ self.atomic_head().store(head + 1, Ordering::Release);
Some(data)
}
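+
+// A quick single-threaded sanity check of capacity, FIFO order, and len().
+// The test module and names here are ours, purely illustrative:
+
+#[cfg(test)]
+mod ring_tests {
+    use super::Ring;
+
+    #[test]
+    fn fills_and_drains_in_order() {
+        let ring = Ring::new(4);
+        for i in 0..4 {
+            assert!(ring.try_push(i)); // fill to capacity
+        }
+        assert!(!ring.try_push(4)); // a full ring rejects the push
+        assert_eq!(ring.len(), 4);
+        for i in 0..4 {
+            assert_eq!(ring.try_pop(), Some(i)); // FIFO order
+        }
+        assert_eq!(ring.try_pop(), None); // empty again
+    }
+}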
+use std::cell::RefCell;
use std::future::Future;
use std::io;
-use std::sync::{mpsc, Arc, Mutex};
+use std::pin::Pin;
+use std::sync::{Arc, Mutex, RwLock};
+use std::task::{Context, Poll};
use std::thread::JoinHandle;
use super::num_cpus;
type BoxFut = Box<dyn Future<Output = ()> + Send + 'static>;
type TaskId = usize;
-pub struct ThreadPool {
- inner: Arc<ThreadPoolInner>,
-}
-
-pub struct ThreadPoolInner {
- threads: Mutex<Vec<Thread>>,
- tasks: Mutex<SlotList<BoxFut>>,
+struct Task {
+ id: TaskId,
+ pool: Arc<ThreadPool>,
+ future: Option<(BoxFut, std::task::Waker)>,
}
-pub struct Thread {
- handle: JoinHandle<()>,
- inner: Arc<ThreadInner>,
- queue_sender: mpsc::Sender<Work>,
-}
-
-pub struct ThreadInner {
- id: usize,
- ring: Ring<TaskId>,
+pub struct ThreadPool {
+ inner: Arc<ThreadPoolInner>,
}
-pub struct Work {}
-
impl ThreadPool {
pub fn new() -> io::Result<Self> {
let count = num_cpus()?;
let inner = Arc::new(ThreadPoolInner {
threads: Mutex::new(Vec::new()),
- tasks: Mutex::new(SlotList::new()),
+ tasks: RwLock::new(SlotList::new()),
+ overflow: RwLock::new(Vec::new()),
});
- let mut threads = Vec::with_capacity(count);
+ let mut threads = inner.threads.lock().unwrap();
for thread_id in 0..count {
threads.push(Thread::new(Arc::clone(&inner), thread_id));
}
-
- *inner.threads.lock().unwrap() = threads;
+ drop(threads);
Ok(ThreadPool { inner })
}
}
+struct ThreadPoolInner {
+ threads: Mutex<Vec<Thread>>,
+ tasks: RwLock<SlotList<BoxFut>>,
+ overflow: RwLock<Vec<TaskId>>,
+}
+
+unsafe impl Sync for ThreadPoolInner {}
+
impl ThreadPoolInner {
fn create_task(&self, future: BoxFut) -> TaskId {
- self.tasks.lock().unwrap().add(future)
+ self.tasks.write().unwrap().add(future)
}
    fn spawn(&self, future: BoxFut) {
        let task = self.create_task(future);
        self.queue_task(task);
    }

    fn queue_task(&self, task: TaskId) {
let threads = self.threads.lock().unwrap();
- //let shortest = threads
- // .iter()
- // .min_by(|a, b| a.task_count().cmp(b.task_count()))
- // .expect("thread pool should not be empty");
+
+        let shortest = threads
+            .iter()
+            .min_by_key(|thread| thread.task_count())
+            .expect("thread pool should not be empty");
+
+ if !shortest.try_queue(task) {
+ drop(threads);
+ self.overflow.write().unwrap().push(task);
+ }
}
+
+ fn create_waker(self: Arc<Self>, task_id: TaskId) -> std::task::RawWaker {
+ let waker = Box::new(Waker {
+ pool: self,
+ task_id,
+ });
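+        // `Box::leak` hands the allocation to the RawWaker; the vtable's
+        // wake/drop callbacks are responsible for reclaiming it (see the
+        // sketch next to the vtable below).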
+ std::task::RawWaker::new(Box::leak(waker) as *mut Waker as *mut (), &WAKER_VTABLE)
+ }
+}
+
+struct Thread {
+ handle: JoinHandle<()>,
+ inner: Arc<ThreadInner>,
}
impl Thread {
fn new(pool: Arc<ThreadPoolInner>, id: usize) -> Self {
- let (queue_sender, queue_receiver) = mpsc::channel();
-
let inner = Arc::new(ThreadInner {
id,
ring: Ring::new(32),
+ pool,
});
let handle = std::thread::spawn({
let inner = Arc::clone(&inner);
- move || inner.thread_main(queue_receiver)
+ move || inner.thread_main()
});
- Thread {
- handle,
- inner,
- queue_sender,
- }
+ Thread { handle, inner }
}
+
+ fn task_count(&self) -> usize {
+ self.inner.task_count()
+ }
+
+ fn try_queue(&self, task: TaskId) -> bool {
+ self.inner.try_queue(task)
+ }
+}
+
+struct ThreadInner {
+ id: usize,
+ ring: Ring<TaskId>,
+ pool: Arc<ThreadPoolInner>,
+}
+
+thread_local! {
+ static THREAD_INNER: RefCell<*const ThreadInner> = RefCell::new(std::ptr::null());
}
impl ThreadInner {
- fn thread_main(self: Arc<Self>, queue: mpsc::Receiver<Work>) {
- let _ = queue;
+ fn thread_main(self: Arc<Self>) {
+ THREAD_INNER.with(|inner| {
+ *inner.borrow_mut() = self.as_ref() as *const Self;
+ });
+ loop {
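+            // Busy-spins for now: idle workers never park, and the pool's
+            // `overflow` list is not drained yet.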
+ if let Some(task_id) = self.ring.try_pop() {
+ self.poll_task(task_id);
+ }
+ }
+ }
+
+ fn poll_task(&self, task_id: TaskId) {
+ let waker = unsafe {
+ std::task::Waker::from_raw(std::task::RawWaker::new(
+ task_id as *const (),
+ &std::task::RawWakerVTable::new(
+ local_waker_clone_fn,
+ local_waker_wake_fn,
+ local_waker_wake_by_ref_fn,
+ local_waker_drop_fn,
+ ),
+ ))
+ };
+
+ let mut context = Context::from_waker(&waker);
+
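+        // Take a raw pointer to the boxed future: the Box keeps the future's
+        // address stable even if the SlotList reallocates, and only the
+        // worker that dequeued the task polls it. The read guard is dropped
+        // before the (potentially long) poll below.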
+ let future = {
+ self.pool
+ .tasks
+ .read()
+ .unwrap()
+ .get(task_id)
+ .unwrap()
+ .as_ref() as *const (dyn Future<Output = ()> + Send + 'static)
+ as *mut (dyn Future<Output = ()> + Send + 'static)
+ };
+        if let Poll::Ready(()) = unsafe { Pin::new_unchecked(&mut *future) }.poll(&mut context) {
+            // Completed: remove the future from the task list so it is dropped.
+            self.pool.tasks.write().unwrap().remove(task_id);
+        }
}
+
+ fn task_count(&self) -> usize {
+ self.ring.len()
+ }
+
+ fn try_queue(&self, task: TaskId) -> bool {
+ self.ring.try_push(task)
+ }
+}
+
+struct RefWaker<'a> {
+ pool: &'a ThreadPoolInner,
+ task_id: TaskId,
+}
+
+struct Waker {
+ pool: Arc<ThreadPoolInner>,
+ task_id: TaskId,
+}
+
+pub struct Work {}
+
+const WAKER_VTABLE: std::task::RawWakerVTable = std::task::RawWakerVTable::new(
+ waker_clone_fn,
+ waker_wake_fn,
+ waker_wake_by_ref_fn,
+ waker_drop_fn,
+);
+
+unsafe fn waker_clone_fn(_this: *const ()) -> std::task::RawWaker {
+ panic!("TODO");
+}
+
+unsafe fn waker_wake_fn(_this: *const ()) {
+ panic!("TODO");
+}
+
+unsafe fn waker_wake_by_ref_fn(_this: *const ()) {
+ panic!("TODO");
+}
+
+unsafe fn waker_drop_fn(_this: *const ()) {
+ panic!("TODO");
+}
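+
+// A sketch of what these four callbacks are meant to do with the `Box<Waker>`
+// leaked in `create_waker()`: clone allocates a fresh box, wake consumes and
+// frees ours, wake_by_ref only borrows, drop balances the leak. Illustrative
+// only; nothing references this module yet:
+
+#[allow(dead_code)]
+mod waker_sketch {
+    use super::Waker;
+    use std::sync::Arc;
+
+    const VTABLE: std::task::RawWakerVTable =
+        std::task::RawWakerVTable::new(clone_fn, wake_fn, wake_by_ref_fn, drop_fn);
+
+    unsafe fn clone_fn(this: *const ()) -> std::task::RawWaker {
+        let this = &*(this as *const Waker);
+        // Each clone owns its own allocation (and its own Arc to the pool).
+        let cloned = Box::new(Waker {
+            pool: Arc::clone(&this.pool),
+            task_id: this.task_id,
+        });
+        std::task::RawWaker::new(Box::into_raw(cloned) as *const (), &VTABLE)
+    }
+
+    unsafe fn wake_fn(this: *const ()) {
+        // Waking by value consumes the waker: re-queue the task, free the box.
+        let this = Box::from_raw(this as *mut Waker);
+        this.pool.queue_task(this.task_id);
+    }
+
+    unsafe fn wake_by_ref_fn(this: *const ()) {
+        let this = &*(this as *const Waker);
+        this.pool.queue_task(this.task_id);
+    }
+
+    unsafe fn drop_fn(this: *const ()) {
+        // Reconstitute the box so the allocation (and its Arc) are released.
+        drop(Box::from_raw(this as *mut Waker));
+    }
+}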
+
+unsafe fn local_waker_clone_fn(_this: *const ()) -> std::task::RawWaker {
+ panic!("TODO");
+}
+
+unsafe fn local_waker_wake_fn(_this: *const ()) {
+ panic!("TODO");
+}
+
+unsafe fn local_waker_wake_by_ref_fn(_this: *const ()) {
+ panic!("TODO");
+}
+
+unsafe fn local_waker_drop_fn(_this: *const ()) {
+ panic!("TODO");
}
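+
+// Likewise, a sketch of where the task-id waker is headed, assuming `data`
+// keeps carrying the TaskId and that a wake re-queues the task on the current
+// worker's ring, spilling into the pool's overflow list when the ring is
+// full. Illustrative only; `poll_task()` still points at the stubs above:
+
+#[allow(dead_code)]
+mod local_waker_sketch {
+    use super::{TaskId, ThreadInner, THREAD_INNER};
+
+    const VTABLE: std::task::RawWakerVTable =
+        std::task::RawWakerVTable::new(clone_fn, wake_fn, wake_by_ref_fn, drop_fn);
+
+    unsafe fn clone_fn(this: *const ()) -> std::task::RawWaker {
+        // The data pointer is a plain task id, so cloning is just a copy.
+        std::task::RawWaker::new(this, &VTABLE)
+    }
+
+    unsafe fn wake_fn(this: *const ()) {
+        wake_by_ref_fn(this);
+    }
+
+    unsafe fn wake_by_ref_fn(this: *const ()) {
+        let task_id = this as TaskId;
+        THREAD_INNER.with(|inner| {
+            let inner: *const ThreadInner = *inner.borrow();
+            assert!(!inner.is_null(), "local waker used off a worker thread");
+            unsafe {
+                if !(*inner).try_queue(task_id) {
+                    (*inner).pool.overflow.write().unwrap().push(task_id);
+                }
+            }
+        });
+    }
+
+    unsafe fn drop_fn(_this: *const ()) {
+        // Nothing to free: no allocation backs this waker.
+    }
+}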