]> git.proxmox.com Git - rustc.git/blob - vendor/rayon-core/src/broadcast/mod.rs
New upstream version 1.71.1+dfsg1
[rustc.git] / vendor / rayon-core / src / broadcast / mod.rs
1 use crate::job::{ArcJob, StackJob};
2 use crate::latch::LatchRef;
3 use crate::registry::{Registry, WorkerThread};
4 use crate::scope::ScopeLatch;
5 use std::fmt;
6 use std::marker::PhantomData;
7 use std::sync::Arc;
8
9 mod test;
10
11 /// Executes `op` within every thread in the current threadpool. If this is
12 /// called from a non-Rayon thread, it will execute in the global threadpool.
13 /// Any attempts to use `join`, `scope`, or parallel iterators will then operate
14 /// within that threadpool. When the call has completed on each thread, returns
15 /// a vector containing all of their return values.
16 ///
17 /// For more information, see the [`ThreadPool::broadcast()`][m] method.
18 ///
19 /// [m]: struct.ThreadPool.html#method.broadcast
20 pub fn broadcast<OP, R>(op: OP) -> Vec<R>
21 where
22 OP: Fn(BroadcastContext<'_>) -> R + Sync,
23 R: Send,
24 {
25 // We assert that current registry has not terminated.
26 unsafe { broadcast_in(op, &Registry::current()) }
27 }
28
29 /// Spawns an asynchronous task on every thread in this thread-pool. This task
30 /// will run in the implicit, global scope, which means that it may outlast the
31 /// current stack frame -- therefore, it cannot capture any references onto the
32 /// stack (you will likely need a `move` closure).
33 ///
34 /// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method.
35 ///
36 /// [m]: struct.ThreadPool.html#method.spawn_broadcast
37 pub fn spawn_broadcast<OP>(op: OP)
38 where
39 OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
40 {
41 // We assert that current registry has not terminated.
42 unsafe { spawn_broadcast_in(op, &Registry::current()) }
43 }
44
/// Provides context to a closure called by `broadcast`.
///
/// Borrows the worker thread the closure is running on for `'a`, and exposes
/// that thread's index and the number of threads in its pool.
pub struct BroadcastContext<'a> {
    /// The worker thread currently running the broadcast closure.
    worker: &'a WorkerThread,

    /// Make sure to prevent auto-traits like `Send` and `Sync`.
    ///
    /// `dyn Fn()` (with no `+ Send`/`+ Sync`) is neither `Send` nor `Sync`, so
    /// holding this marker keeps the context tied to the thread that created it.
    _marker: PhantomData<&'a mut dyn Fn()>,
}
52
53 impl<'a> BroadcastContext<'a> {
54 pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
55 let worker_thread = WorkerThread::current();
56 assert!(!worker_thread.is_null());
57 f(BroadcastContext {
58 worker: unsafe { &*worker_thread },
59 _marker: PhantomData,
60 })
61 }
62
63 /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`).
64 #[inline]
65 pub fn index(&self) -> usize {
66 self.worker.index()
67 }
68
69 /// The number of threads receiving the broadcast in the thread pool.
70 ///
71 /// # Future compatibility note
72 ///
73 /// Future versions of Rayon might vary the number of threads over time, but
74 /// this method will always return the number of threads which are actually
75 /// receiving your particular `broadcast` call.
76 #[inline]
77 pub fn num_threads(&self) -> usize {
78 self.worker.registry().num_threads()
79 }
80 }
81
82 impl<'a> fmt::Debug for BroadcastContext<'a> {
83 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
84 fmt.debug_struct("BroadcastContext")
85 .field("index", &self.index())
86 .field("num_threads", &self.num_threads())
87 .field("pool_id", &self.worker.registry().id())
88 .finish()
89 }
90 }
91
/// Execute `op` on every thread in the pool. It will be executed on each
/// thread when they have nothing else to do locally, before they try to
/// steal work from other threads. This function will not return until all
/// threads have completed the `op`.
///
/// Returns one result per thread in the registry (`registry.num_threads()`
/// jobs are created and their results collected).
///
/// # Safety
///
/// Unsafe because `registry` must not yet have terminated.
pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
where
    OP: Fn(BroadcastContext<'_>) -> R + Sync,
    R: Send,
{
    // Job body shared by every thread. The job machinery passes an `injected`
    // flag; broadcast jobs are expected to always see `true` (checked only in
    // debug builds).
    let f = move |injected: bool| {
        debug_assert!(injected);
        BroadcastContext::with(&op)
    };

    let n_threads = registry.num_threads();
    // `as_ref` maps a null current-worker pointer to `None` (presumably the
    // caller is not a Rayon worker thread).
    let current_thread = WorkerThread::current().as_ref();
    // A single latch counting down from `n_threads`; every job signals it.
    let latch = ScopeLatch::with_count(n_threads, current_thread);
    // One `StackJob` per thread, all borrowing the same closure `f` and latch.
    let jobs: Vec<_> = (0..n_threads)
        .map(|_| StackJob::new(&f, LatchRef::new(&latch)))
        .collect();
    // Lazily produce a `JobRef` for each job; consumed by `inject_broadcast`.
    let job_refs = jobs.iter().map(|job| job.as_job_ref());

    registry.inject_broadcast(job_refs);

    // NOTE: `jobs`, `f` and `latch` live in this stack frame. Presumably the
    // injected `JobRef`s point at them, so the wait below must come before
    // `jobs` is consumed or anything here is dropped -- do not reorder.
    // Wait for all jobs to complete, then collect the results, maybe propagating a panic.
    latch.wait(current_thread);
    jobs.into_iter().map(|job| job.into_result()).collect()
}
122
/// Execute `op` on every thread in the pool. It will be executed on each
/// thread when they have nothing else to do locally, before they try to
/// steal work from other threads. This function returns immediately after
/// injecting the jobs.
///
/// # Safety
///
/// Unsafe because `registry` must not yet have terminated.
pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
where
    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
{
    // One shared `ArcJob` for all threads. The closure owns its own clone of
    // the registry `Arc`, so it does not borrow anything from this call.
    let job = ArcJob::new({
        let registry = Arc::clone(registry);
        move || {
            // Panics from `op` go through the registry's `catch_unwind` instead
            // of reaching a caller; nothing waits on these jobs.
            registry.catch_unwind(|| BroadcastContext::with(&op));
            registry.terminate(); // (*) permit registry to terminate now
        }
    });

    let n_threads = registry.num_threads();
    // This iterator is lazy. Each terminate-count increment happens when
    // `inject_broadcast` pulls the corresponding `JobRef`, i.e. before that
    // job can have been handed to any thread.
    let job_refs = (0..n_threads).map(|_| {
        // Ensure that registry cannot terminate until this job has executed
        // on each thread. This ref is decremented at the (*) above.
        registry.increment_terminate_count();

        ArcJob::as_static_job_ref(&job)
    });

    registry.inject_broadcast(job_refs);
}