]>
Commit | Line | Data |
---|---|---|
6a06907d | 1 | use crate::registry::{Registry, WorkerThread}; |
94b46f34 XL |
2 | use std::fmt; |
3 | use std::ops::Deref; | |
4 | use std::sync::Arc; | |
5 | ||
/// Newtype that places its contents on a 64-byte boundary.
///
/// Because `repr(align(64))` also rounds the size up to a multiple of 64,
/// consecutive elements of a `Vec<CacheAligned<T>>` never share a 64-byte
/// block (a common cache-line size), which keeps per-thread slots apart.
#[derive(Debug)]
#[repr(align(64))]
struct CacheAligned<T>(T);
9 | ||
10 | /// Holds worker-locals values for each thread in a thread pool. | |
11 | /// You can only access the worker local value through the Deref impl | |
12 | /// on the thread pool it was constructed on. It will panic otherwise | |
13 | pub struct WorkerLocal<T> { | |
14 | locals: Vec<CacheAligned<T>>, | |
15 | registry: Arc<Registry>, | |
16 | } | |
17 | ||
6a06907d XL |
18 | /// We prevent concurrent access to the underlying value in the |
19 | /// Deref impl, thus any values safe to send across threads can | |
20 | /// be used with WorkerLocal. | |
21 | unsafe impl<T: Send> Sync for WorkerLocal<T> {} | |
94b46f34 XL |
22 | |
23 | impl<T> WorkerLocal<T> { | |
24 | /// Creates a new worker local where the `initial` closure computes the | |
25 | /// value this worker local should take for each thread in the thread pool. | |
26 | #[inline] | |
27 | pub fn new<F: FnMut(usize) -> T>(mut initial: F) -> WorkerLocal<T> { | |
28 | let registry = Registry::current(); | |
29 | WorkerLocal { | |
30 | locals: (0..registry.num_threads()) | |
31 | .map(|i| CacheAligned(initial(i))) | |
32 | .collect(), | |
33 | registry, | |
34 | } | |
35 | } | |
36 | ||
37 | /// Returns the worker-local value for each thread | |
38 | #[inline] | |
39 | pub fn into_inner(self) -> Vec<T> { | |
40 | self.locals.into_iter().map(|c| c.0).collect() | |
41 | } | |
42 | ||
43 | fn current(&self) -> &T { | |
44 | unsafe { | |
45 | let worker_thread = WorkerThread::current(); | |
46 | if worker_thread.is_null() | |
47 | || &*(*worker_thread).registry as *const _ != &*self.registry as *const _ | |
48 | { | |
49 | panic!("WorkerLocal can only be used on the thread pool it was created on") | |
50 | } | |
51 | &self.locals[(*worker_thread).index].0 | |
52 | } | |
53 | } | |
54 | } | |
55 | ||
56 | impl<T> WorkerLocal<Vec<T>> { | |
57 | /// Joins the elements of all the worker locals into one Vec | |
58 | pub fn join(self) -> Vec<T> { | |
59 | self.into_inner().into_iter().flat_map(|v| v).collect() | |
60 | } | |
61 | } | |
62 | ||
63 | impl<T: fmt::Debug> fmt::Debug for WorkerLocal<T> { | |
64 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | |
6a06907d XL |
65 | f.debug_struct("WorkerLocal") |
66 | .field("registry", &self.registry.id()) | |
67 | .finish() | |
94b46f34 XL |
68 | } |
69 | } | |
70 | ||
71 | impl<T> Deref for WorkerLocal<T> { | |
72 | type Target = T; | |
73 | ||
74 | #[inline(always)] | |
75 | fn deref(&self) -> &T { | |
76 | self.current() | |
77 | } | |
78 | } |