]> git.proxmox.com Git - rustc.git/blobdiff - library/std/src/sys_common/remutex.rs
New upstream version 1.63.0+dfsg1
[rustc.git] / library / std / src / sys_common / remutex.rs
index 162eab2388d554e905c684bcd7ad4b2948f8a6aa..8921af311d4152bb7b6fc7b4c37763ce35e14915 100644 (file)
@@ -1,20 +1,51 @@
 #[cfg(all(test, not(target_os = "emscripten")))]
 mod tests;
 
-use crate::fmt;
-use crate::marker;
+use crate::cell::UnsafeCell;
+use crate::marker::PhantomPinned;
 use crate::ops::Deref;
 use crate::panic::{RefUnwindSafe, UnwindSafe};
-use crate::sys::mutex as sys;
+use crate::pin::Pin;
+use crate::sync::atomic::{AtomicUsize, Ordering::Relaxed};
+use crate::sys::locks as sys;
 
 /// A re-entrant mutual exclusion
 ///
 /// This mutex will block *other* threads waiting for the lock to become
 /// available. The thread which has already locked the mutex can lock it
 /// multiple times without blocking, preventing a common source of deadlocks.
+///
+/// This is used by stdout().lock() and friends.
+///
+/// ## Implementation details
+///
+/// The 'owner' field tracks which thread has locked the mutex.
+///
+/// We use current_thread_unique_ptr() as the thread identifier,
+/// which is just the address of a thread local variable.
+///
+/// If `owner` is set to the identifier of the current thread,
+/// we assume the mutex is already locked and instead of locking it again,
+/// we increment `lock_count`.
+///
+/// When unlocking, we decrement `lock_count`, and only unlock the mutex when
+/// it reaches zero.
+///
+/// `lock_count` is protected by the mutex and only accessed by the thread that has
+/// locked the mutex, so needs no synchronization.
+///
+/// `owner` can be checked by other threads that want to see if they already
+/// hold the lock, so needs to be atomic. If it compares equal, we're on the
+/// same thread that holds the mutex and memory access can use relaxed ordering
+/// since we're not dealing with multiple threads. If it compares unequal,
+/// synchronization is left to the mutex, making relaxed memory ordering for
+/// the `owner` field fine in all cases.
 pub struct ReentrantMutex<T> {
-    inner: sys::ReentrantMutex,
+    mutex: sys::Mutex,
+    owner: AtomicUsize,
+    lock_count: UnsafeCell<u32>,
     data: T,
+    _pinned: PhantomPinned,
 }
 
 unsafe impl<T: Send> Send for ReentrantMutex<T> {}
@@ -37,10 +68,10 @@ impl<T> RefUnwindSafe for ReentrantMutex<T> {}
 /// guarded data.
 #[must_use = "if unused the ReentrantMutex will immediately unlock"]
 pub struct ReentrantMutexGuard<'a, T: 'a> {
-    lock: &'a ReentrantMutex<T>,
+    lock: Pin<&'a ReentrantMutex<T>>,
 }
 
-impl<T> !marker::Send for ReentrantMutexGuard<'_, T> {}
+impl<T> !Send for ReentrantMutexGuard<'_, T> {}
 
 impl<T> ReentrantMutex<T> {
     /// Creates a new reentrant mutex in an unlocked state.
@@ -51,7 +82,13 @@ impl<T> ReentrantMutex<T> {
     /// once this mutex is in its final resting place, and only then are the
     /// lock/unlock methods safe.
     pub const unsafe fn new(t: T) -> ReentrantMutex<T> {
-        ReentrantMutex { inner: sys::ReentrantMutex::uninitialized(), data: t }
+        ReentrantMutex {
+            mutex: sys::Mutex::new(),
+            owner: AtomicUsize::new(0),
+            lock_count: UnsafeCell::new(0),
+            data: t,
+            _pinned: PhantomPinned,
+        }
     }
 
     /// Initializes this mutex so it's ready for use.
@@ -60,8 +97,8 @@ impl<T> ReentrantMutex<T> {
     ///
     /// Unsafe to call more than once, and must be called after this will no
     /// longer move in memory.
-    pub unsafe fn init(&self) {
-        self.inner.init();
+    pub unsafe fn init(self: Pin<&mut Self>) {
+        self.get_unchecked_mut().mutex.init()
     }
 
     /// Acquires a mutex, blocking the current thread until it is able to do so.
@@ -76,9 +113,21 @@ impl<T> ReentrantMutex<T> {
     /// If another user of this mutex panicked while holding the mutex, then
     /// this call will return failure if the mutex would otherwise be
     /// acquired.
-    pub fn lock(&self) -> ReentrantMutexGuard<'_, T> {
-        unsafe { self.inner.lock() }
-        ReentrantMutexGuard::new(&self)
+    pub fn lock(self: Pin<&Self>) -> ReentrantMutexGuard<'_, T> {
+        let this_thread = current_thread_unique_ptr();
+        // Safety: We only touch lock_count when we own the lock,
+        // and since self is pinned we can safely call the lock() on the mutex.
+        unsafe {
+            if self.owner.load(Relaxed) == this_thread {
+                self.increment_lock_count();
+            } else {
+                self.mutex.lock();
+                self.owner.store(this_thread, Relaxed);
+                debug_assert_eq!(*self.lock_count.get(), 0);
+                *self.lock_count.get() = 1;
+            }
+        }
+        ReentrantMutexGuard { lock: self }
     }
 
     /// Attempts to acquire this lock.
@@ -93,41 +142,29 @@ impl<T> ReentrantMutex<T> {
     /// If another user of this mutex panicked while holding the mutex, then
     /// this call will return failure if the mutex would otherwise be
     /// acquired.
-    pub fn try_lock(&self) -> Option<ReentrantMutexGuard<'_, T>> {
-        if unsafe { self.inner.try_lock() } { Some(ReentrantMutexGuard::new(&self)) } else { None }
-    }
-}
-
-impl<T> Drop for ReentrantMutex<T> {
-    fn drop(&mut self) {
-        // This is actually safe b/c we know that there is no further usage of
-        // this mutex (it's up to the user to arrange for a mutex to get
-        // dropped, that's not our job)
-        unsafe { self.inner.destroy() }
-    }
-}
-
-impl<T: fmt::Debug + 'static> fmt::Debug for ReentrantMutex<T> {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match self.try_lock() {
-            Some(guard) => f.debug_struct("ReentrantMutex").field("data", &*guard).finish(),
-            None => {
-                struct LockedPlaceholder;
-                impl fmt::Debug for LockedPlaceholder {
-                    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-                        f.write_str("<locked>")
-                    }
-                }
-
-                f.debug_struct("ReentrantMutex").field("data", &LockedPlaceholder).finish()
+    pub fn try_lock(self: Pin<&Self>) -> Option<ReentrantMutexGuard<'_, T>> {
+        let this_thread = current_thread_unique_ptr();
+        // Safety: We only touch lock_count when we own the lock,
+        // and since self is pinned we can safely call the try_lock on the mutex.
+        unsafe {
+            if self.owner.load(Relaxed) == this_thread {
+                self.increment_lock_count();
+                Some(ReentrantMutexGuard { lock: self })
+            } else if self.mutex.try_lock() {
+                self.owner.store(this_thread, Relaxed);
+                debug_assert_eq!(*self.lock_count.get(), 0);
+                *self.lock_count.get() = 1;
+                Some(ReentrantMutexGuard { lock: self })
+            } else {
+                None
             }
         }
     }
-}
 
-impl<'mutex, T> ReentrantMutexGuard<'mutex, T> {
-    fn new(lock: &'mutex ReentrantMutex<T>) -> ReentrantMutexGuard<'mutex, T> {
-        ReentrantMutexGuard { lock }
+    unsafe fn increment_lock_count(&self) {
+        *self.lock_count.get() = (*self.lock_count.get())
+            .checked_add(1)
+            .expect("lock count overflow in reentrant mutex");
     }
 }
 
@@ -142,8 +179,22 @@ impl<T> Deref for ReentrantMutexGuard<'_, T> {
 impl<T> Drop for ReentrantMutexGuard<'_, T> {
     #[inline]
     fn drop(&mut self) {
+        // Safety: We own the lock, and the lock is pinned.
         unsafe {
-            self.lock.inner.unlock();
+            *self.lock.lock_count.get() -= 1;
+            if *self.lock.lock_count.get() == 0 {
+                self.lock.owner.store(0, Relaxed);
+                self.lock.mutex.unlock();
+            }
         }
     }
 }
+
+/// Get an address that is unique per running thread.
+///
+/// This can be used as a non-null usize-sized ID.
+pub fn current_thread_unique_ptr() -> usize {
+    // Use a non-drop type to make sure it's still available during thread destruction.
+    thread_local! { static X: u8 = const { 0 } }
+    X.with(|x| <*const _>::addr(x))
+}