]>
Commit | Line | Data |
---|---|---|
416331ca XL |
1 | use std::fmt; |
2 | use std::sync::{Arc, Condvar, Mutex}; | |
3 | ||
4 | /// Enables threads to synchronize the beginning or end of some computation. | |
5 | /// | |
6 | /// # Wait groups vs barriers | |
7 | /// | |
8 | /// `WaitGroup` is very similar to [`Barrier`], but there are a few differences: | |
9 | /// | |
10 | /// * `Barrier` needs to know the number of threads at construction, while `WaitGroup` is cloned to | |
11 | /// register more threads. | |
12 | /// | |
13 | /// * A `Barrier` can be reused even after all threads have synchronized, while a `WaitGroup` | |
14 | /// synchronizes threads only once. | |
15 | /// | |
16 | /// * All threads wait for others to reach the `Barrier`. With `WaitGroup`, each thread can choose | |
17 | /// to either wait for other threads or to continue without blocking. | |
18 | /// | |
19 | /// # Examples | |
20 | /// | |
21 | /// ``` | |
22 | /// use crossbeam_utils::sync::WaitGroup; | |
23 | /// use std::thread; | |
24 | /// | |
25 | /// // Create a new wait group. | |
26 | /// let wg = WaitGroup::new(); | |
27 | /// | |
28 | /// for _ in 0..4 { | |
29 | /// // Create another reference to the wait group. | |
30 | /// let wg = wg.clone(); | |
31 | /// | |
32 | /// thread::spawn(move || { | |
33 | /// // Do some work. | |
34 | /// | |
35 | /// // Drop the reference to the wait group. | |
36 | /// drop(wg); | |
37 | /// }); | |
38 | /// } | |
39 | /// | |
40 | /// // Block until all threads have finished their work. | |
41 | /// wg.wait(); | |
42 | /// ``` | |
43 | /// | |
44 | /// [`Barrier`]: https://doc.rust-lang.org/std/sync/struct.Barrier.html | |
45 | pub struct WaitGroup { | |
46 | inner: Arc<Inner>, | |
47 | } | |
48 | ||
/// Shared state behind every handle of one `WaitGroup`.
struct Inner {
    /// Signalled with `notify_all` when `count` drops to zero.
    cvar: Condvar,
    /// Number of live `WaitGroup` handles; starts at 1 in `WaitGroup::new`.
    count: Mutex<usize>,
}
54 | ||
55 | impl WaitGroup { | |
56 | /// Creates a new wait group and returns the single reference to it. | |
57 | /// | |
58 | /// # Examples | |
59 | /// | |
60 | /// ``` | |
61 | /// use crossbeam_utils::sync::WaitGroup; | |
62 | /// | |
63 | /// let wg = WaitGroup::new(); | |
64 | /// ``` | |
65 | pub fn new() -> WaitGroup { | |
66 | WaitGroup { | |
67 | inner: Arc::new(Inner { | |
68 | cvar: Condvar::new(), | |
69 | count: Mutex::new(1), | |
70 | }), | |
71 | } | |
72 | } | |
73 | ||
74 | /// Drops this reference and waits until all other references are dropped. | |
75 | /// | |
76 | /// # Examples | |
77 | /// | |
78 | /// ``` | |
79 | /// use crossbeam_utils::sync::WaitGroup; | |
80 | /// use std::thread; | |
81 | /// | |
82 | /// let wg = WaitGroup::new(); | |
83 | /// | |
84 | /// thread::spawn({ | |
85 | /// let wg = wg.clone(); | |
86 | /// move || { | |
87 | /// // Block until both threads have reached `wait()`. | |
88 | /// wg.wait(); | |
89 | /// } | |
90 | /// }); | |
91 | /// | |
92 | /// // Block until both threads have reached `wait()`. | |
93 | /// wg.wait(); | |
94 | /// ``` | |
95 | pub fn wait(self) { | |
96 | if *self.inner.count.lock().unwrap() == 1 { | |
97 | return; | |
98 | } | |
99 | ||
100 | let inner = self.inner.clone(); | |
101 | drop(self); | |
102 | ||
103 | let mut count = inner.count.lock().unwrap(); | |
104 | while *count > 0 { | |
105 | count = inner.cvar.wait(count).unwrap(); | |
106 | } | |
107 | } | |
108 | } | |
109 | ||
110 | impl Drop for WaitGroup { | |
111 | fn drop(&mut self) { | |
112 | let mut count = self.inner.count.lock().unwrap(); | |
113 | *count -= 1; | |
114 | ||
115 | if *count == 0 { | |
116 | self.inner.cvar.notify_all(); | |
117 | } | |
118 | } | |
119 | } | |
120 | ||
121 | impl Clone for WaitGroup { | |
122 | fn clone(&self) -> WaitGroup { | |
123 | let mut count = self.inner.count.lock().unwrap(); | |
124 | *count += 1; | |
125 | ||
126 | WaitGroup { | |
127 | inner: self.inner.clone(), | |
128 | } | |
129 | } | |
130 | } | |
131 | ||
132 | impl fmt::Debug for WaitGroup { | |
133 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | |
134 | let count: &usize = &*self.inner.count.lock().unwrap(); | |
135 | f.debug_struct("WaitGroup") | |
136 | .field("count", count) | |
137 | .finish() | |
138 | } | |
139 | } |