]> git.lizzy.rs Git - rust.git/commitdiff
Fix unlock ordering in SGX synchronization primitives
authorJethro Beekman <jethro@fortanix.com>
Sat, 31 Aug 2019 03:35:27 +0000 (20:35 -0700)
committerJethro Beekman <jethro@fortanix.com>
Sat, 31 Aug 2019 03:35:27 +0000 (20:35 -0700)
src/libstd/sys/sgx/condvar.rs
src/libstd/sys/sgx/mutex.rs
src/libstd/sys/sgx/rwlock.rs
src/libstd/sys/sgx/waitqueue.rs

index 000bb19f2692ae8807551a5f7ba94271f41db1b8..cc1c04a83e752859343e4afcf8c6ebb281ac84c0 100644 (file)
@@ -27,8 +27,7 @@ pub unsafe fn notify_all(&self) {
 
     pub unsafe fn wait(&self, mutex: &Mutex) {
         let guard = self.inner.lock();
-        mutex.unlock();
-        WaitQueue::wait(guard);
+        WaitQueue::wait(guard, || mutex.unlock());
         mutex.lock()
     }
 
index f325fb1dd582f5a7dd8d580ca57b52be6d1756cd..662da8b3f66850f1a72add0674ad5006e0c84e57 100644 (file)
@@ -22,7 +22,7 @@ pub unsafe fn lock(&self) {
         let mut guard = self.inner.lock();
         if *guard.lock_var() {
             // Another thread has the lock, wait
-            WaitQueue::wait(guard)
+            WaitQueue::wait(guard, ||{})
             // Another thread has passed the lock to us
         } else {
             // We are just now obtaining the lock
@@ -83,7 +83,7 @@ pub unsafe fn lock(&self) {
         match guard.lock_var().owner {
             Some(tcs) if tcs != thread::current() => {
                 // Another thread has the lock, wait
-                WaitQueue::wait(guard);
+                WaitQueue::wait(guard, ||{});
                 // Another thread has passed the lock to us
             },
             _ => {
index 30c47e44eef8ecfaa794238a95c0a3d2de13ea8c..e2f94b1d928e10db38416f8c874fd88fb08c652d 100644 (file)
@@ -31,7 +31,7 @@ pub unsafe fn read(&self) {
         if *wguard.lock_var() || !wguard.queue_empty() {
             // Another thread has or is waiting for the write lock, wait
             drop(wguard);
-            WaitQueue::wait(rguard);
+            WaitQueue::wait(rguard, ||{});
             // Another thread has passed the lock to us
         } else {
             // No waiting writers, acquire the read lock
@@ -62,7 +62,7 @@ pub unsafe fn write(&self) {
         if *wguard.lock_var() || rguard.lock_var().is_some() {
             // Another thread has the lock, wait
             drop(rguard);
-            WaitQueue::wait(wguard);
+            WaitQueue::wait(wguard, ||{});
             // Another thread has passed the lock to us
         } else {
             // We are just now obtaining the lock
@@ -97,6 +97,7 @@ unsafe fn __read_unlock(
             if let Ok(mut wguard) = WaitQueue::notify_one(wguard) {
                 // A writer was waiting, pass the lock
                 *wguard.lock_var_mut() = true;
+                wguard.drop_after(rguard);
             } else {
                 // No writers were waiting, the lock is released
                 rtassert!(rguard.queue_empty());
@@ -117,21 +118,26 @@ unsafe fn __write_unlock(
         rguard: SpinMutexGuard<'_, WaitVariable<Option<NonZeroUsize>>>,
         wguard: SpinMutexGuard<'_, WaitVariable<bool>>,
     ) {
-        if let Err(mut wguard) = WaitQueue::notify_one(wguard) {
-            // No writers waiting, release the write lock
-            *wguard.lock_var_mut() = false;
-            if let Ok(mut rguard) = WaitQueue::notify_all(rguard) {
-                // One or more readers were waiting, pass the lock to them
-                if let NotifiedTcs::All { count } = rguard.notified_tcs() {
-                    *rguard.lock_var_mut() = Some(count)
+        match WaitQueue::notify_one(wguard) {
+            Err(mut wguard) => {
+                // No writers waiting, release the write lock
+                *wguard.lock_var_mut() = false;
+                if let Ok(mut rguard) = WaitQueue::notify_all(rguard) {
+                    // One or more readers were waiting, pass the lock to them
+                    if let NotifiedTcs::All { count } = rguard.notified_tcs() {
+                        *rguard.lock_var_mut() = Some(count)
+                    } else {
+                        unreachable!() // called notify_all
+                    }
+                    rguard.drop_after(wguard);
                 } else {
-                    unreachable!() // called notify_all
+                    // No readers waiting, the lock is released
                 }
-            } else {
-                // No readers waiting, the lock is released
+            },
+            Ok(wguard) => {
+                // There was a thread waiting for write, just pass the lock
+                wguard.drop_after(rguard);
             }
-        } else {
-            // There was a thread waiting for write, just pass the lock
         }
     }
 
index d542f9b41012793bc4b90d62e827c901dde9515a..3cb40e509b6b2973b93d0a8c1fa0a08df3b6f8a1 100644 (file)
@@ -98,6 +98,12 @@ impl<'a, T> WaitGuard<'a, T> {
     pub fn notified_tcs(&self) -> NotifiedTcs {
         self.notified_tcs
     }
+
+    /// Drop this `WaitGuard`, after dropping another `guard`.
+    pub fn drop_after<U>(self, guard: U) {
+        drop(guard);
+        drop(self);
+    }
 }
 
 impl<'a, T> Deref for WaitGuard<'a, T> {
@@ -140,7 +146,7 @@ pub fn is_empty(&self) -> bool {
     /// until a wakeup event.
     ///
     /// This function does not return until this thread has been awoken.
-    pub fn wait<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>) {
+    pub fn wait<T, F: FnOnce()>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>, before_wait: F) {
         // very unsafe: check requirements of UnsafeList::push
         unsafe {
             let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
@@ -149,6 +155,7 @@ pub fn wait<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>) {
             }));
             let entry = guard.queue.inner.push(&mut entry);
             drop(guard);
+            before_wait();
             while !entry.lock().wake {
                 // don't panic, this would invalidate `entry` during unwinding
                 let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE));
@@ -545,7 +552,7 @@ fn queue() {
             assert!(WaitQueue::notify_one(wq2.lock()).is_ok());
         });
 
-        WaitQueue::wait(locked);
+        WaitQueue::wait(locked, ||{});
 
         t1.join().unwrap();
     }