]> git.lizzy.rs Git - rust.git/blobdiff - src/libstd/sys/sgx/waitqueue.rs
Fix unlock ordering in SGX synchronization primitives
[rust.git] / src / libstd / sys / sgx / waitqueue.rs
index 3f5e03ddad69eee59068aa76735f8a98ba1cc1a0..3cb40e509b6b2973b93d0a8c1fa0a08df3b6f8a1 100644 (file)
@@ -98,6 +98,12 @@ impl<'a, T> WaitGuard<'a, T> {
     pub fn notified_tcs(&self) -> NotifiedTcs {
         self.notified_tcs
     }
+
+    /// Drop this `WaitGuard`, after dropping another `guard`.
+    pub fn drop_after<U>(self, guard: U) {
+        drop(guard);
+        drop(self);
+    }
 }
 
 impl<'a, T> Deref for WaitGuard<'a, T> {
@@ -121,7 +127,7 @@ fn drop(&mut self) {
             NotifiedTcs::Single(tcs) => Some(tcs),
             NotifiedTcs::All { .. } => None
         };
-        usercalls::send(EV_UNPARK, target_tcs).unwrap();
+        rtunwrap!(Ok, usercalls::send(EV_UNPARK, target_tcs));
     }
 }
 
@@ -140,7 +146,8 @@ pub fn is_empty(&self) -> bool {
     /// until a wakeup event.
     ///
     /// This function does not return until this thread has been awoken.
-    pub fn wait<T>(mut guard: SpinMutexGuard<WaitVariable<T>>) {
+    pub fn wait<T, F: FnOnce()>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>, before_wait: F) {
+        // very unsafe: check requirements of UnsafeList::push
         unsafe {
             let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
                 tcs: thread::current(),
@@ -148,11 +155,11 @@ pub fn wait<T>(mut guard: SpinMutexGuard<WaitVariable<T>>) {
             }));
             let entry = guard.queue.inner.push(&mut entry);
             drop(guard);
+            before_wait();
             while !entry.lock().wake {
-                assert_eq!(
-                    usercalls::wait(EV_UNPARK, WAIT_INDEFINITE).unwrap() & EV_UNPARK,
-                    EV_UNPARK
-                );
+                // don't panic, this would invalidate `entry` during unwinding
+                let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE));
+                rtassert!(eventset & EV_UNPARK == EV_UNPARK);
             }
         }
     }
@@ -162,8 +169,8 @@ pub fn wait<T>(mut guard: SpinMutexGuard<WaitVariable<T>>) {
     ///
     /// If a waiter is found, a `WaitGuard` is returned which will notify the
     /// waiter when it is dropped.
-    pub fn notify_one<T>(mut guard: SpinMutexGuard<WaitVariable<T>>)
-        -> Result<WaitGuard<T>, SpinMutexGuard<WaitVariable<T>>>
+    pub fn notify_one<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>)
+        -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>>
     {
         unsafe {
             if let Some(entry) = guard.queue.inner.pop() {
@@ -186,8 +193,8 @@ pub fn notify_one<T>(mut guard: SpinMutexGuard<WaitVariable<T>>)
     ///
     /// If at least one waiter is found, a `WaitGuard` is returned which will
     /// notify all waiters when it is dropped.
-    pub fn notify_all<T>(mut guard: SpinMutexGuard<WaitVariable<T>>)
-        -> Result<WaitGuard<T>, SpinMutexGuard<WaitVariable<T>>>
+    pub fn notify_all<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>)
+        -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>>
     {
         unsafe {
             let mut count = 0;
@@ -269,7 +276,7 @@ pub fn is_empty(&self) -> bool {
                         // ,-------> /---------\ next ---,
                         // |         |head_tail|         |
                         // `--- prev \---------/ <-------`
-                        assert_eq!(self.head_tail.as_ref().prev, first);
+                        rtassert!(self.head_tail.as_ref().prev == first);
                         true
                     } else {
                         false
@@ -285,7 +292,9 @@ pub fn is_empty(&self) -> bool {
         /// # Safety
         ///
         /// The entry must remain allocated until the entry is removed from the
-        /// list AND the caller who popped is done using the entry.
+        /// list AND the caller who popped is done using the entry. Special
+        /// care must be taken in the caller of `push` to ensure unwinding does
+        /// not destroy the stack frame containing the entry.
         pub unsafe fn push<'a>(&mut self, entry: &'a mut UnsafeListEntry<T>) -> &'a T {
             self.init();
 
@@ -303,6 +312,7 @@ pub unsafe fn push<'a>(&mut self, entry: &'a mut UnsafeListEntry<T>) -> &'a T {
             entry.as_mut().prev = prev_tail;
             entry.as_mut().next = self.head_tail;
             prev_tail.as_mut().next = entry;
+            // unwrap ok: always `Some` on non-dummy entries
             (*entry.as_ptr()).value.as_ref().unwrap()
         }
 
@@ -333,6 +343,7 @@ pub unsafe fn pop<'a>(&mut self) -> Option<&'a T> {
                 second.as_mut().prev = self.head_tail;
                 first.as_mut().next = NonNull::dangling();
                 first.as_mut().prev = NonNull::dangling();
+                // unwrap ok: always `Some` on non-dummy entries
                 Some((*first.as_ptr()).value.as_ref().unwrap())
             }
         }
@@ -433,7 +444,7 @@ pub const fn new(value: T) -> Self {
         }
 
         #[inline(always)]
-        pub fn lock(&self) -> SpinMutexGuard<T> {
+        pub fn lock(&self) -> SpinMutexGuard<'_, T> {
             loop {
                 match self.try_lock() {
                     None => while self.lock.load(Ordering::Relaxed) {
@@ -445,7 +456,7 @@ pub fn lock(&self) -> SpinMutexGuard<T> {
         }
 
         #[inline(always)]
-        pub fn try_lock(&self) -> Option<SpinMutexGuard<T>> {
+        pub fn try_lock(&self) -> Option<SpinMutexGuard<'_, T>> {
             if !self.lock.compare_and_swap(false, true, Ordering::Acquire) {
                 Some(SpinMutexGuard {
                     mutex: self,
@@ -541,7 +552,7 @@ fn queue() {
             assert!(WaitQueue::notify_one(wq2.lock()).is_ok());
         });
 
-        WaitQueue::wait(locked);
+        WaitQueue::wait(locked, ||{});
 
         t1.join().unwrap();
     }