git.proxmox.com Git - pve-lxc-syscalld.git/commitdiff
foo
author Wolfgang Bumiller <w.bumiller@errno.eu>
Thu, 24 Oct 2019 16:56:32 +0000 (18:56 +0200)
committer Wolfgang Bumiller <w.bumiller@errno.eu>
Thu, 24 Oct 2019 16:56:32 +0000 (18:56 +0200)
Signed-off-by: Wolfgang Bumiller <w.bumiller@errno.eu>
src/executor/ring.rs
src/executor/slot_list.rs
src/executor/thread_pool.rs

diff --git a/src/executor/ring.rs b/src/executor/ring.rs
index 0fe3e2295ff18f461b64a233905b7fde94c3f4bc..f56c0de9dabb2f37db092e8122fe77a05d37630b 100644 (file)
--- a/src/executor/ring.rs
+++ b/src/executor/ring.rs
@@ -1,12 +1,39 @@
 use std::mem::MaybeUninit;
 use std::ptr;
-use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::atomic::{fence, AtomicBool, AtomicUsize, Ordering};
+
+// We only perform a handful of memory read/writes in push()/pop(), so we use spin locks for
+// performance reasons:
+
+struct SpinLock(AtomicBool);
+struct SpinLockGuard<'a>(&'a AtomicBool);
+
+impl SpinLock {
+    const fn new() -> Self {
+        Self(AtomicBool::new(false))
+    }
+
+    fn lock(&self) -> SpinLockGuard {
+        while self.0.compare_and_swap(false, true, Ordering::Acquire) {
+            // spin
+        }
+        SpinLockGuard(&self.0)
+    }
+}
+
+impl Drop for SpinLockGuard<'_> {
+    fn drop(&mut self) {
+        self.0.store(false, Ordering::Release);
+    }
+}
 
 pub struct Ring<T> {
     head: usize,
     tail: usize,
     mask: usize,
     data: Box<[MaybeUninit<T>]>,
+    push_lock: SpinLock,
+    pop_lock: SpinLock,
 }
 
 impl<T> Ring<T> {
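
[Editor's note: the spin lock added above is a plain test-and-set lock. The following standalone sketch shows the same idea, substituting compare_exchange_weak (the non-deprecated successor of the compare_and_swap call in the hunk) and a spin-loop hint; it is an illustration, not part of the commit.]

use std::sync::atomic::{AtomicBool, Ordering};

struct SpinLock(AtomicBool);
struct SpinLockGuard<'a>(&'a AtomicBool);

impl SpinLock {
    const fn new() -> Self {
        Self(AtomicBool::new(false))
    }

    fn lock(&self) -> SpinLockGuard {
        // Spin until we win the false -> true transition; Acquire here
        // pairs with the Release store in Drop.
        while self
            .0
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            std::hint::spin_loop();
        }
        SpinLockGuard(&self.0)
    }
}

impl Drop for SpinLockGuard<'_> {
    fn drop(&mut self) {
        self.0.store(false, Ordering::Release);
    }
}

fn main() {
    static LOCK: SpinLock = SpinLock::new();
    let _guard = LOCK.lock();
    // ... critical section: at most one thread can be here at a time ...
}   // _guard dropped here, releasing the lock
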
@@ -16,8 +43,8 @@ impl<T> Ring<T> {
         }
 
         let mut data = Vec::with_capacity(size);
-        for _ in 0..size {
-            data.push(MaybeUninit::uninit())
+        unsafe {
+            data.set_len(size);
         }
 
         Self {
@@ -25,9 +52,16 @@ impl<T> Ring<T> {
             tail: 0,
             mask: size - 1,
             data: data.into_boxed_slice(),
+            push_lock: SpinLock::new(),
+            pop_lock: SpinLock::new(),
         }
     }
 
+    pub fn len(&self) -> usize {
+        fence(Ordering::Acquire);
+        self.tail - self.head
+    }
+
     #[inline]
     fn atomic_tail(&self) -> &AtomicUsize {
         unsafe { &*(&self.tail as *const usize as *const AtomicUsize) }
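
[Editor's note: the ring keeps head and tail as monotonically increasing counters and maps them to slots with index & mask, which agrees with index % size only when the size is a power of two (Ring::new apparently enforces this in a check outside the visible hunk). A quick, self-contained illustration of that identity and of len() as a counter difference:]

fn main() {
    let size = 32usize;
    assert!(size.is_power_of_two());
    let mask = size - 1;
    // For power-of-two sizes, masking and modulo agree for every index.
    for index in 0..10 * size {
        assert_eq!(index & mask, index % size);
    }
    // Occupancy is simply tail minus head, since both only ever grow.
    let (head, tail) = (37usize, 42usize);
    assert_eq!(tail - head, 5);
}
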
@@ -39,33 +73,36 @@ impl<T> Ring<T> {
     }
 
     pub fn try_push(&self, data: T) -> bool {
-        let head = self.head;
+        let _guard = self.push_lock.lock();
+
         let tail = self.atomic_tail().load(Ordering::Acquire);
+        let head = self.head;
 
         if tail - head == self.data.len() {
             return false;
         }
 
         unsafe {
-            ptr::write(self.data[tail & self.mask].as_ptr() as *mut _, data);
+            ptr::write(self.data[tail & self.mask].as_ptr() as *mut T, data);
         }
-
-        self.atomic_tail().fetch_add(1, Ordering::Release);
+        self.atomic_tail().store(tail + 1, Ordering::Release);
 
         true
     }
 
     pub fn try_pop(&self) -> Option<T> {
-        let tail = self.tail;
+        let _guard = self.pop_lock.lock();
+
         let head = self.atomic_head().load(Ordering::Acquire);
+        let tail = self.tail;
 
         if tail - head == 0 {
             return None;
         }
 
-        let data = unsafe { std::ptr::read(self.data[head & self.mask].as_ptr()) };
+        let data = unsafe { ptr::read(self.data[head & self.mask].as_ptr()) };
 
-        self.atomic_head().fetch_add(1, Ordering::Release);
+        self.atomic_head().store(head + 1, Ordering::Release);
 
         Some(data)
     }
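
[Editor's note: with push_lock serializing producers, the locked pusher is the only writer of tail, so the plain store of tail + 1 suffices; the fetch_add it replaces was only needed while concurrent pushers raced on the counter. An intended-usage sketch, assuming the Ring type above is in scope:]

fn main() {
    let ring: Ring<u32> = Ring::new(32); // size must be a power of two
    assert!(ring.try_push(1));           // returns false instead of blocking when full
    assert!(ring.try_push(2));
    assert_eq!(ring.try_pop(), Some(1)); // FIFO order
    assert_eq!(ring.try_pop(), Some(2));
    assert_eq!(ring.try_pop(), None);    // returns None instead of blocking when empty
}
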
diff --git a/src/executor/slot_list.rs b/src/executor/slot_list.rs
index aa3ddf7a0f1a51db5db66388d8adef5c13616958..ae7f084fa9eae033015b202ba10f4ba81402a7fa 100644 (file)
--- a/src/executor/slot_list.rs
+++ b/src/executor/slot_list.rs
@@ -28,4 +28,8 @@ impl<T> SlotList<T> {
         self.free_slots.push(id);
         entry
     }
+
+    pub fn get(&self, id: usize) -> Option<&T> {
+        self.tasks[id].as_ref()
+    }
 }
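
[Editor's note: only the tail of slot_list.rs is visible in this hunk. For context, here is a hypothetical reconstruction of the full type, consistent with the lines above and with how the pool uses it; the real file may differ.]

pub struct SlotList<T> {
    tasks: Vec<Option<T>>,
    free_slots: Vec<usize>,
}

impl<T> SlotList<T> {
    pub fn new() -> Self {
        Self {
            tasks: Vec::new(),
            free_slots: Vec::new(),
        }
    }

    // Store `entry` in a slot (reusing a freed one if possible), returning its id.
    pub fn add(&mut self, entry: T) -> usize {
        match self.free_slots.pop() {
            Some(id) => {
                self.tasks[id] = Some(entry);
                id
            }
            None => {
                self.tasks.push(Some(entry));
                self.tasks.len() - 1
            }
        }
    }

    // Take the entry out of slot `id` and mark the slot reusable.
    pub fn remove(&mut self, id: usize) -> Option<T> {
        let entry = self.tasks[id].take();
        self.free_slots.push(id);
        entry
    }

    // The accessor added by this commit; panics if `id` was never allocated.
    pub fn get(&self, id: usize) -> Option<&T> {
        self.tasks[id].as_ref()
    }
}
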
diff --git a/src/executor/thread_pool.rs b/src/executor/thread_pool.rs
index 67c228ab9c15380c24de55301a0461d062e64322..1d62220a33bd703c76f2194ebeb0b3901fd625ce 100644 (file)
--- a/src/executor/thread_pool.rs
+++ b/src/executor/thread_pool.rs
@@ -1,6 +1,9 @@
+use std::cell::RefCell;
 use std::future::Future;
 use std::io;
-use std::sync::{mpsc, Arc, Mutex};
+use std::pin::Pin;
+use std::sync::{Arc, Mutex, RwLock};
+use std::task::{Context, Poll};
 use std::thread::JoinHandle;
 
 use super::num_cpus;
@@ -10,43 +13,31 @@ use super::slot_list::SlotList;
 type BoxFut = Box<dyn Future<Output = ()> + Send + 'static>;
 type TaskId = usize;
 
-pub struct ThreadPool {
-    inner: Arc<ThreadPoolInner>,
-}
-
-pub struct ThreadPoolInner {
-    threads: Mutex<Vec<Thread>>,
-    tasks: Mutex<SlotList<BoxFut>>,
+struct Task {
+    id: TaskId,
+    pool: Arc<ThreadPool>,
+    future: Option<(BoxFut, std::task::Waker)>,
 }
 
-pub struct Thread {
-    handle: JoinHandle<()>,
-    inner: Arc<ThreadInner>,
-    queue_sender: mpsc::Sender<Work>,
-}
-
-pub struct ThreadInner {
-    id: usize,
-    ring: Ring<TaskId>,
+pub struct ThreadPool {
+    inner: Arc<ThreadPoolInner>,
 }
 
-pub struct Work {}
-
 impl ThreadPool {
     pub fn new() -> io::Result<Self> {
         let count = num_cpus()?;
 
         let inner = Arc::new(ThreadPoolInner {
             threads: Mutex::new(Vec::new()),
-            tasks: Mutex::new(SlotList::new()),
+            tasks: RwLock::new(SlotList::new()),
+            overflow: RwLock::new(Vec::new()),
         });
 
-        let mut threads = Vec::with_capacity(count);
+        let mut threads = inner.threads.lock().unwrap();
         for thread_id in 0..count {
             threads.push(Thread::new(Arc::clone(&inner), thread_id));
         }
-
-        *inner.threads.lock().unwrap() = threads;
+        drop(threads);
 
         Ok(ThreadPool { inner })
     }
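
[Editor's note: as restructured above, new() spawns one worker per CPU while holding the threads lock, then publishes the pool. A usage sketch follows; it assumes ThreadPool exposes a spawn wrapper over ThreadPoolInner::spawn, which is plausible but sits outside the visible hunks.]

fn run() -> std::io::Result<()> {
    let pool = ThreadPool::new()?;
    // Futures are boxed trait objects (BoxFut), so they must be Send + 'static.
    pool.spawn(Box::new(async {
        // polled to completion on one of the worker threads
    }));
    Ok(())
}
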
@@ -59,9 +50,17 @@ impl ThreadPool {
     }
 }
 
+struct ThreadPoolInner {
+    threads: Mutex<Vec<Thread>>,
+    tasks: RwLock<SlotList<BoxFut>>,
+    overflow: RwLock<Vec<TaskId>>,
+}
+
+unsafe impl Sync for ThreadPoolInner {}
+
 impl ThreadPoolInner {
     fn create_task(&self, future: BoxFut) -> TaskId {
-        self.tasks.lock().unwrap().add(future)
+        self.tasks.write().unwrap().add(future)
     }
 
     fn spawn(&self, future: BoxFut) {
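
[Editor's note: the unsafe impl Sync above is most likely needed because BoxFut is Send but not Sync, and std's containers only propagate Sync when the payload has it, so ThreadPoolInner loses the auto impl. A compile-time illustration (the commented line fails to compile):]

use std::future::Future;
use std::sync::RwLock;

type BoxFut = Box<dyn Future<Output = ()> + Send + 'static>;

fn assert_sync<T: Sync>() {}

fn main() {
    assert_sync::<RwLock<Vec<u8>>>(); // ok: Vec<u8> is Send + Sync
    // assert_sync::<RwLock<Vec<BoxFut>>>(); // error: BoxFut is Send but not
    // Sync, so the RwLock is not Sync either; hence the manual
    // `unsafe impl Sync for ThreadPoolInner` above.
}
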
@@ -70,36 +69,173 @@ impl ThreadPoolInner {
 
     fn queue_task(&self, task: TaskId) {
         let threads = self.threads.lock().unwrap();
-        //let shortest = threads
-        //    .iter()
-        //    .min_by(|a, b| a.task_count().cmp(b.task_count()))
-        //    .expect("thread pool should not be empty");
+
+        let shortest = threads
+            .iter()
+            .min_by(|a, b| a.task_count().cmp(&b.task_count()))
+            .expect("thread pool should not be empty");
+
+        if !shortest.try_queue(task) {
+            drop(threads);
+            self.overflow.write().unwrap().push(task);
+        }
     }
+
+    fn create_waker(self: Arc<Self>, task_id: TaskId) -> std::task::RawWaker {
+        let waker = Box::new(Waker {
+            pool: self,
+            task_id,
+        });
+        std::task::RawWaker::new(Box::leak(waker) as *mut Waker as *mut (), &WAKER_VTABLE)
+    }
+}
+
+struct Thread {
+    handle: JoinHandle<()>,
+    inner: Arc<ThreadInner>,
 }
 
 impl Thread {
     fn new(pool: Arc<ThreadPoolInner>, id: usize) -> Self {
-        let (queue_sender, queue_receiver) = mpsc::channel();
-
         let inner = Arc::new(ThreadInner {
             id,
             ring: Ring::new(32),
+            pool,
         });
 
         let handle = std::thread::spawn({
             let inner = Arc::clone(&inner);
-            move || inner.thread_main(queue_receiver)
+            move || inner.thread_main()
         });
-        Thread {
-            handle,
-            inner,
-            queue_sender,
-        }
+        Thread { handle, inner }
     }
+
+    fn task_count(&self) -> usize {
+        self.inner.task_count()
+    }
+
+    fn try_queue(&self, task: TaskId) -> bool {
+        self.inner.try_queue(task)
+    }
+}
+
+struct ThreadInner {
+    id: usize,
+    ring: Ring<TaskId>,
+    pool: Arc<ThreadPoolInner>,
+}
+
+thread_local! {
+    static THREAD_INNER: RefCell<*const ThreadInner> = RefCell::new(std::ptr::null());
 }
 
 impl ThreadInner {
-    fn thread_main(self: Arc<Self>, queue: mpsc::Receiver<Work>) {
-        let _ = queue;
+    fn thread_main(self: Arc<Self>) {
+        THREAD_INNER.with(|inner| {
+            *inner.borrow_mut() = self.as_ref() as *const Self;
+        });
+        loop {
+            if let Some(task_id) = self.ring.try_pop() {
+                self.poll_task(task_id);
+            }
+        }
+    }
+
+    fn poll_task(&self, task_id: TaskId) {
+        //let future = {
+        //    let task = self.pool.tasks.read().unwrap().get(task_id).unwrap();
+        //    if let Some(future) = task.future.as_ref() {
+        //        future.as_ref() as *const (dyn Future<Output = ()> + Send + 'static)
+        //            as *mut (dyn Future<Output = ()> + Send + 'static)
+        //    } else {
+        //        return;
+        //    }
+        //};
+        let waker = unsafe {
+            std::task::Waker::from_raw(std::task::RawWaker::new(
+                task_id as *const (),
+                &std::task::RawWakerVTable::new(
+                    local_waker_clone_fn,
+                    local_waker_wake_fn,
+                    local_waker_wake_by_ref_fn,
+                    local_waker_drop_fn,
+                ),
+            ))
+        };
+
+        let mut context = Context::from_waker(&waker);
+
+        let future = {
+            self.pool
+                .tasks
+                .read()
+                .unwrap()
+                .get(task_id)
+                .unwrap()
+                .as_ref() as *const (dyn Future<Output = ()> + Send + 'static)
+                as *mut (dyn Future<Output = ()> + Send + 'static)
+        };
+        if let Poll::Ready(value) = unsafe { Pin::new_unchecked(&mut *future) }.poll(&mut context) {
+            let task = self.pool.tasks.write().unwrap().remove(task_id);
+        }
     }
+
+    fn task_count(&self) -> usize {
+        self.ring.len()
+    }
+
+    fn try_queue(&self, task: TaskId) -> bool {
+        self.ring.try_push(task)
+    }
+}
+
+struct RefWaker<'a> {
+    pool: &'a ThreadPoolInner,
+    task_id: TaskId,
+}
+
+struct Waker {
+    pool: Arc<ThreadPoolInner>,
+    task_id: TaskId,
+}
+
+pub struct Work {}
+
+const WAKER_VTABLE: std::task::RawWakerVTable = std::task::RawWakerVTable::new(
+    waker_clone_fn,
+    waker_wake_fn,
+    waker_wake_by_ref_fn,
+    waker_drop_fn,
+);
+
+unsafe fn waker_clone_fn(_this: *const ()) -> std::task::RawWaker {
+    panic!("TODO");
+}
+
+unsafe fn waker_wake_fn(_this: *const ()) {
+    panic!("TODO");
+}
+
+unsafe fn waker_wake_by_ref_fn(_this: *const ()) {
+    panic!("TODO");
+}
+
+unsafe fn waker_drop_fn(_this: *const ()) {
+    panic!("TODO");
+}
+
+unsafe fn local_waker_clone_fn(_this: *const ()) -> std::task::RawWaker {
+    panic!("TODO");
+}
+
+unsafe fn local_waker_wake_fn(_this: *const ()) {
+    panic!("TODO");
+}
+
+unsafe fn local_waker_wake_by_ref_fn(_this: *const ()) {
+    panic!("TODO");
+}
+
+unsafe fn local_waker_drop_fn(_this: *const ()) {
+    panic!("TODO");
 }
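
[Editor's note: all eight vtable functions are left as panic!("TODO") stubs in this commit. Given that create_waker leaks a Box<Waker> as the RawWaker data pointer, the non-local set could plausibly be completed as below. This is a sketch of one way to fill them in, not the commit's eventual implementation.]

unsafe fn waker_clone_fn(this: *const ()) -> std::task::RawWaker {
    let this = &*(this as *const Waker);
    // Clone the boxed state and leak the clone, mirroring create_waker().
    let clone = Box::new(Waker {
        pool: Arc::clone(&this.pool),
        task_id: this.task_id,
    });
    std::task::RawWaker::new(Box::leak(clone) as *mut Waker as *mut (), &WAKER_VTABLE)
}

unsafe fn waker_wake_fn(this: *const ()) {
    // Waking by value consumes the waker: re-box it, queue the task, drop the box.
    let this = Box::from_raw(this as *mut Waker);
    this.pool.queue_task(this.task_id);
}

unsafe fn waker_wake_by_ref_fn(this: *const ()) {
    let this = &*(this as *const Waker);
    this.pool.queue_task(this.task_id);
}

unsafe fn waker_drop_fn(this: *const ()) {
    // Reclaim the leaked box so the Arc refcount is decremented.
    drop(Box::from_raw(this as *mut Waker));
}

The local_waker_* set is different: poll_task stores the bare task_id in the data pointer, so a completed version would have to recover the pool through the THREAD_INNER thread-local rather than through the pointer itself.
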