]>
Commit | Line | Data |
---|---|---|
e1599b0c XL |
1 | // There's a lot of scary concurrent code in this module, but it is copied from |
2 | // `std::sync::Once` with two changes: | |
3 | // * no poisoning | |
4 | // * init function can fail | |
5 | ||
6 | use std::{ | |
f035d41b XL |
7 | cell::{Cell, UnsafeCell}, |
8 | hint::unreachable_unchecked, | |
e1599b0c XL |
9 | marker::PhantomData, |
10 | panic::{RefUnwindSafe, UnwindSafe}, | |
e1599b0c XL |
11 | sync::atomic::{AtomicBool, AtomicUsize, Ordering}, |
12 | thread::{self, Thread}, | |
13 | }; | |
14 | ||
6a06907d XL |
15 | use crate::take_unchecked; |
16 | ||
e1599b0c XL |
17 | #[derive(Debug)] |
18 | pub(crate) struct OnceCell<T> { | |
19 | // This `state` word is actually an encoded version of just a pointer to a | |
20 | // `Waiter`, so we add the `PhantomData` appropriately. | |
f035d41b | 21 | state_and_queue: AtomicUsize, |
e1599b0c | 22 | _marker: PhantomData<*mut Waiter>, |
f035d41b | 23 | value: UnsafeCell<Option<T>>, |
e1599b0c XL |
24 | } |
25 | ||
26 | // Why do we need `T: Send`? | |
27 | // Thread A creates a `OnceCell` and shares it with | |
28 | // scoped thread B, which fills the cell, which is | |
29 | // then destroyed by A. That is, destructor observes | |
30 | // a sent value. | |
31 | unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} | |
32 | unsafe impl<T: Send> Send for OnceCell<T> {} | |
33 | ||
34 | impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {} | |
35 | impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {} | |
36 | ||
// A `OnceCell` is always in exactly one of three states, encoded in the low
// bits of its `state_and_queue` word.

/// No value yet, and no thread is running the initializer.
const INCOMPLETE: usize = 0b00;
/// A thread is running the initializer; the high bits hold the waiter queue.
const RUNNING: usize = 0b01;
/// The value has been stored.
const COMPLETE: usize = 0b10;

/// Selects the state bits of `state_and_queue`. While the state is `RUNNING`,
/// every other bit belongs to the pointer to the head of the waiter queue.
const STATE_MASK: usize = 0b11;
46 | ||
/// One node of the linked list of threads waiting for the initializer to
/// finish, used while the cell is `RUNNING`. Each node lives on the stack of
/// the waiting thread.
///
/// `align(4)` guarantees that the two low bits of a `Waiter` address are zero,
/// so they are free to carry the state bits.
#[repr(align(4))]
struct Waiter {
    /// Handle of the waiting thread; taken out by the waker before signaling.
    thread: Cell<Option<Thread>>,
    /// Set once the waiting thread may stop parking.
    signaled: AtomicBool,
    /// The previously enqueued waiter, or null at the end of the list.
    next: *const Waiter,
}
54 | ||
f035d41b XL |
/// Guard held by the thread running the initializer. It refers to the cell's
/// state word, whose high bits are the head of a linked list of waiters. Every
/// node of that list is a struct on the stack of a waiting thread.
///
/// Dropping the guard, either normally or while unwinding from a panic,
/// publishes `set_state_on_drop_to` and wakes up the waiters.
struct WaiterQueue<'a> {
    /// State word of the cell being initialized.
    state_and_queue: &'a AtomicUsize,
    /// State to publish when the guard is dropped.
    set_state_on_drop_to: usize,
}
62 | ||
63 | impl<T> OnceCell<T> { | |
64 | pub(crate) const fn new() -> OnceCell<T> { | |
65 | OnceCell { | |
f035d41b | 66 | state_and_queue: AtomicUsize::new(INCOMPLETE), |
e1599b0c XL |
67 | _marker: PhantomData, |
68 | value: UnsafeCell::new(None), | |
69 | } | |
70 | } | |
71 | ||
72 | /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst). | |
73 | #[inline] | |
74 | pub(crate) fn is_initialized(&self) -> bool { | |
75 | // An `Acquire` load is enough because that makes all the initialization | |
76 | // operations visible to us, and, this being a fast path, weaker | |
77 | // ordering helps with performance. This `Acquire` synchronizes with | |
78 | // `SeqCst` operations on the slow path. | |
f035d41b | 79 | self.state_and_queue.load(Ordering::Acquire) == COMPLETE |
e1599b0c XL |
80 | } |
81 | ||
82 | /// Safety: synchronizes with store to value via SeqCst read from state, | |
83 | /// writes value only once because we never get to INCOMPLETE state after a | |
84 | /// successful write. | |
85 | #[cold] | |
86 | pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E> | |
87 | where | |
88 | F: FnOnce() -> Result<T, E>, | |
89 | { | |
90 | let mut f = Some(f); | |
91 | let mut res: Result<(), E> = Ok(()); | |
f035d41b XL |
92 | let slot: *mut Option<T> = self.value.get(); |
93 | initialize_inner(&self.state_and_queue, &mut || { | |
6a06907d | 94 | let f = unsafe { take_unchecked(&mut f) }; |
e1599b0c XL |
95 | match f() { |
96 | Ok(value) => { | |
f035d41b | 97 | unsafe { *slot = Some(value) }; |
e1599b0c XL |
98 | true |
99 | } | |
f035d41b XL |
100 | Err(err) => { |
101 | res = Err(err); | |
e1599b0c XL |
102 | false |
103 | } | |
104 | } | |
105 | }); | |
106 | res | |
107 | } | |
f035d41b XL |
108 | |
109 | /// Get the reference to the underlying value, without checking if the cell | |
110 | /// is initialized. | |
111 | /// | |
112 | /// # Safety | |
113 | /// | |
114 | /// Caller must ensure that the cell is in initialized state, and that | |
115 | /// the contents are acquired by (synchronized to) this thread. | |
116 | pub(crate) unsafe fn get_unchecked(&self) -> &T { | |
117 | debug_assert!(self.is_initialized()); | |
118 | let slot: &Option<T> = &*self.value.get(); | |
119 | match slot { | |
120 | Some(value) => value, | |
121 | // This unsafe does improve performance, see `examples/bench`. | |
122 | None => { | |
123 | debug_assert!(false); | |
124 | unreachable_unchecked() | |
125 | } | |
126 | } | |
127 | } | |
128 | ||
129 | /// Gets the mutable reference to the underlying value. | |
130 | /// Returns `None` if the cell is empty. | |
131 | pub(crate) fn get_mut(&mut self) -> Option<&mut T> { | |
132 | // Safe b/c we have a unique access. | |
133 | unsafe { &mut *self.value.get() }.as_mut() | |
134 | } | |
135 | ||
136 | /// Consumes this `OnceCell`, returning the wrapped value. | |
137 | /// Returns `None` if the cell was empty. | |
138 | #[inline] | |
139 | pub(crate) fn into_inner(self) -> Option<T> { | |
140 | // Because `into_inner` takes `self` by value, the compiler statically | |
141 | // verifies that it is not currently borrowed. | |
142 | // So, it is safe to move out `Option<T>`. | |
143 | self.value.into_inner() | |
144 | } | |
e1599b0c XL |
145 | } |
146 | ||
f035d41b | 147 | // Corresponds to `std::sync::Once::call_inner` |
e1599b0c | 148 | // Note: this is intentionally monomorphic |
6a06907d | 149 | #[inline(never)] |
f035d41b XL |
150 | fn initialize_inner(my_state_and_queue: &AtomicUsize, init: &mut dyn FnMut() -> bool) -> bool { |
151 | let mut state_and_queue = my_state_and_queue.load(Ordering::Acquire); | |
e1599b0c | 152 | |
f035d41b XL |
153 | loop { |
154 | match state_and_queue { | |
155 | COMPLETE => return true, | |
e1599b0c | 156 | INCOMPLETE => { |
6a06907d | 157 | let exchange = my_state_and_queue.compare_exchange( |
f035d41b XL |
158 | state_and_queue, |
159 | RUNNING, | |
160 | Ordering::Acquire, | |
6a06907d | 161 | Ordering::Acquire, |
f035d41b | 162 | ); |
6a06907d | 163 | if let Err(old) = exchange { |
f035d41b | 164 | state_and_queue = old; |
e1599b0c XL |
165 | continue; |
166 | } | |
f035d41b XL |
167 | let mut waiter_queue = WaiterQueue { |
168 | state_and_queue: my_state_and_queue, | |
169 | set_state_on_drop_to: INCOMPLETE, // Difference, std uses `POISONED` | |
170 | }; | |
e1599b0c | 171 | let success = init(); |
f035d41b XL |
172 | |
173 | // Difference, std always uses `COMPLETE` | |
174 | waiter_queue.set_state_on_drop_to = if success { COMPLETE } else { INCOMPLETE }; | |
e1599b0c XL |
175 | return success; |
176 | } | |
e1599b0c | 177 | _ => { |
f035d41b XL |
178 | assert!(state_and_queue & STATE_MASK == RUNNING); |
179 | wait(&my_state_and_queue, state_and_queue); | |
180 | state_and_queue = my_state_and_queue.load(Ordering::Acquire); | |
e1599b0c XL |
181 | } |
182 | } | |
183 | } | |
184 | } | |
185 | ||
f035d41b XL |
// Copy-pasted from std exactly.
//
// Blocks the calling thread until the state in `state_and_queue` is no longer
// `RUNNING`. It pushes a `Waiter` node, which lives in this stack frame, onto
// the head of the waiter list encoded in the high bits, then parks until the
// thread dropping the `WaiterQueue` signals it.
//
// `current_state` is the caller's most recent observation of `state_and_queue`.
fn wait(state_and_queue: &AtomicUsize, mut current_state: usize) {
    loop {
        // Don't enqueue if the state is no longer `RUNNING`. Otherwise nobody
        // would ever wake this thread up.
        if current_state & STATE_MASK != RUNNING {
            return;
        }

        // Node for the current thread. It links to the current head of the
        // list. `Waiter` is `align(4)`, so its address leaves the state bits free.
        let node = Waiter {
            thread: Cell::new(Some(thread::current())),
            signaled: AtomicBool::new(false),
            next: (current_state & !STATE_MASK) as *const Waiter,
        };
        let me = &node as *const Waiter as usize;

        // Try to make `node` the new head of the list. This fails if another
        // thread changed the word since we read it. On success, `Release`
        // publishes the node's contents to the thread that later swaps the
        // queue out with `AcqRel` in `WaiterQueue::drop`.
        let exchange = state_and_queue.compare_exchange(
            current_state,
            me | RUNNING,
            Ordering::Release,
            Ordering::Relaxed,
        );
        if let Err(old) = exchange {
            // Retry with the freshly observed state.
            current_state = old;
            continue;
        }

        // We are enqueued. We must not return before being signaled: that
        // would drop `node` while it is still linked into the list, leaving a
        // dangling pointer. `park` can wake up spuriously, hence the loop. An
        // `unpark` that arrives before `park` makes `park` return immediately,
        // so the wake-up is not lost.
        while !node.signaled.load(Ordering::Acquire) {
            thread::park();
        }
        break;
    }
}
217 | ||
// Copy-pasted from std exactly.
//
// Publishes the final state and wakes every queued waiter. This runs when the
// initializing thread returns from `initialize_inner` and also when `init`
// panics.
impl Drop for WaiterQueue<'_> {
    fn drop(&mut self) {
        // Swap in the final state, detaching the whole waiter list in one
        // step. `AcqRel`: Release publishes the initializer's writes to
        // threads that later Acquire-load the new state. Acquire makes the
        // waiters' nodes, enqueued with a Release CAS in `wait`, visible here.
        let state_and_queue =
            self.state_and_queue.swap(self.set_state_on_drop_to, Ordering::AcqRel);

        // Only the thread that moved the state to `RUNNING` builds a
        // `WaiterQueue`, so the previous state must still be `RUNNING`.
        assert_eq!(state_and_queue & STATE_MASK, RUNNING);

        // Walk the list and wake every waiter. Waiters are pushed at the
        // head, so the most recently enqueued one is woken first.
        //
        // SAFETY: each node is a `Waiter` on the stack of a thread blocked in
        // `wait`. That thread does not return, and so keeps the node alive,
        // until `signaled` is set. Once `signaled` is stored, the node may be
        // freed at any moment. So `next` and `thread` are read out first, and
        // the node is not touched after the store.
        unsafe {
            let mut queue = (state_and_queue & !STATE_MASK) as *const Waiter;
            while !queue.is_null() {
                let next = (*queue).next;
                let thread = (*queue).thread.replace(None).unwrap();
                (*queue).signaled.store(true, Ordering::Release);
                queue = next;
                thread.unpark();
            }
        }
    }
}
238 | ||
// These tests are snatched from std as well.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        /// Infallible initialization helper for the tests below.
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }

    #[test]
    #[cfg(not(miri))]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }
        // Drop the original sender. If a spawned thread panics before sending,
        // `recv` below then fails instead of blocking forever.
        drop(tx);

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // A panicking initializer would poison a `std::sync::Once`.
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // This cell has no poisoning, so the next initializer runs.
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // Once initialization succeeds, later calls are no-ops.
        O.init(|| {});
    }

    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // A panicking initializer leaves the cell uninitialized.
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // Make sure someone is currently running an initializer.
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // Put another waiter on the cell; it must not run its own initializer.
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}