]> git.lizzy.rs Git - rust.git/blob - src/libstd/sys/sgx/waitqueue.rs
Fix unlock ordering in SGX synchronization primitives
[rust.git] / src / libstd / sys / sgx / waitqueue.rs
1 /// A simple queue implementation for synchronization primitives.
2 ///
3 /// This queue is used to implement condition variable and mutexes.
4 ///
5 /// Users of this API are expected to use the `WaitVariable<T>` type. Since
6 /// that type is not `Sync`, it needs to be protected by e.g., a `SpinMutex` to
7 /// allow shared access.
8 ///
9 /// Since userspace may send spurious wake-ups, the wakeup event state is
10 /// recorded in the enclave. The wakeup event state is protected by a spinlock.
11 /// The queue and associated wait state are stored in a `WaitVariable`.
12
13 use crate::ops::{Deref, DerefMut};
14 use crate::num::NonZeroUsize;
15
16 use fortanix_sgx_abi::{Tcs, EV_UNPARK, WAIT_INDEFINITE};
17 use super::abi::usercalls;
18 use super::abi::thread;
19
20 use self::unsafe_list::{UnsafeList, UnsafeListEntry};
21 pub use self::spin_mutex::{SpinMutex, SpinMutexGuard, try_lock_or_false};
22
23 /// An queue entry in a `WaitQueue`.
24 struct WaitEntry {
25     /// TCS address of the thread that is waiting
26     tcs: Tcs,
27     /// Whether this thread has been notified to be awoken
28     wake: bool
29 }
30
31 /// Data stored with a `WaitQueue` alongside it. This ensures accesses to the
32 /// queue and the data are synchronized, since the type itself is not `Sync`.
33 ///
34 /// Consumers of this API should use a synchronization primitive for shared
35 /// access, such as `SpinMutex`.
36 #[derive(Default)]
37 pub struct WaitVariable<T> {
38     queue: WaitQueue,
39     lock: T
40 }
41
42 impl<T> WaitVariable<T> {
43     pub const fn new(var: T) -> Self {
44         WaitVariable {
45             queue: WaitQueue::new(),
46             lock: var
47         }
48     }
49
50     pub fn queue_empty(&self) -> bool {
51         self.queue.is_empty()
52     }
53
54     pub fn lock_var(&self) -> &T {
55         &self.lock
56     }
57
58     pub fn lock_var_mut(&mut self) -> &mut T {
59         &mut self.lock
60     }
61 }
62
63 #[derive(Copy, Clone)]
64 pub enum NotifiedTcs {
65     Single(Tcs),
66     All { count: NonZeroUsize }
67 }
68
69 /// An RAII guard that will notify a set of target threads as well as unlock
70 /// a mutex on drop.
71 pub struct WaitGuard<'a, T: 'a> {
72     mutex_guard: Option<SpinMutexGuard<'a, WaitVariable<T>>>,
73     notified_tcs: NotifiedTcs
74 }
75
76 /// A queue of threads that are waiting on some synchronization primitive.
77 ///
78 /// `UnsafeList` entries are allocated on the waiting thread's stack. This
79 /// avoids any global locking that might happen in the heap allocator. This is
80 /// safe because the waiting thread will not return from that stack frame until
81 /// after it is notified. The notifying thread ensures to clean up any
82 /// references to the list entries before sending the wakeup event.
83 pub struct WaitQueue {
84     // We use an inner Mutex here to protect the data in the face of spurious
85     // wakeups.
86     inner: UnsafeList<SpinMutex<WaitEntry>>,
87 }
88 unsafe impl Send for WaitQueue {}
89
90 impl Default for WaitQueue {
91     fn default() -> Self {
92         Self::new()
93     }
94 }
95
96 impl<'a, T> WaitGuard<'a, T> {
97     /// Returns which TCSes will be notified when this guard drops.
98     pub fn notified_tcs(&self) -> NotifiedTcs {
99         self.notified_tcs
100     }
101
102     /// Drop this `WaitGuard`, after dropping another `guard`.
103     pub fn drop_after<U>(self, guard: U) {
104         drop(guard);
105         drop(self);
106     }
107 }
108
109 impl<'a, T> Deref for WaitGuard<'a, T> {
110     type Target = SpinMutexGuard<'a, WaitVariable<T>>;
111
112     fn deref(&self) -> &Self::Target {
113         self.mutex_guard.as_ref().unwrap()
114     }
115 }
116
117 impl<'a, T> DerefMut for WaitGuard<'a, T> {
118     fn deref_mut(&mut self) -> &mut Self::Target {
119         self.mutex_guard.as_mut().unwrap()
120     }
121 }
122
123 impl<'a, T> Drop for WaitGuard<'a, T> {
124     fn drop(&mut self) {
125         drop(self.mutex_guard.take());
126         let target_tcs = match self.notified_tcs {
127             NotifiedTcs::Single(tcs) => Some(tcs),
128             NotifiedTcs::All { .. } => None
129         };
130         rtunwrap!(Ok, usercalls::send(EV_UNPARK, target_tcs));
131     }
132 }
133
134 impl WaitQueue {
135     pub const fn new() -> Self {
136         WaitQueue {
137             inner: UnsafeList::new()
138         }
139     }
140
141     pub fn is_empty(&self) -> bool {
142         self.inner.is_empty()
143     }
144
145     /// Adds the calling thread to the `WaitVariable`'s wait queue, then wait
146     /// until a wakeup event.
147     ///
148     /// This function does not return until this thread has been awoken.
149     pub fn wait<T, F: FnOnce()>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>, before_wait: F) {
150         // very unsafe: check requirements of UnsafeList::push
151         unsafe {
152             let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
153                 tcs: thread::current(),
154                 wake: false
155             }));
156             let entry = guard.queue.inner.push(&mut entry);
157             drop(guard);
158             before_wait();
159             while !entry.lock().wake {
160                 // don't panic, this would invalidate `entry` during unwinding
161                 let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE));
162                 rtassert!(eventset & EV_UNPARK == EV_UNPARK);
163             }
164         }
165     }
166
167     /// Either find the next waiter on the wait queue, or return the mutex
168     /// guard unchanged.
169     ///
170     /// If a waiter is found, a `WaitGuard` is returned which will notify the
171     /// waiter when it is dropped.
172     pub fn notify_one<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>)
173         -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>>
174     {
175         unsafe {
176             if let Some(entry) = guard.queue.inner.pop() {
177                 let mut entry_guard = entry.lock();
178                 let tcs = entry_guard.tcs;
179                 entry_guard.wake = true;
180                 drop(entry);
181                 Ok(WaitGuard {
182                     mutex_guard: Some(guard),
183                     notified_tcs: NotifiedTcs::Single(tcs)
184                 })
185             } else {
186                 Err(guard)
187             }
188         }
189     }
190
191     /// Either find any and all waiters on the wait queue, or return the mutex
192     /// guard unchanged.
193     ///
194     /// If at least one waiter is found, a `WaitGuard` is returned which will
195     /// notify all waiters when it is dropped.
196     pub fn notify_all<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>)
197         -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>>
198     {
199         unsafe {
200             let mut count = 0;
201             while let Some(entry) = guard.queue.inner.pop() {
202                 count += 1;
203                 let mut entry_guard = entry.lock();
204                 entry_guard.wake = true;
205             }
206             if let Some(count) = NonZeroUsize::new(count) {
207                 Ok(WaitGuard {
208                     mutex_guard: Some(guard),
209                     notified_tcs: NotifiedTcs::All { count }
210                 })
211             } else {
212                 Err(guard)
213             }
214         }
215     }
216 }
217
218 /// A doubly-linked list where callers are in charge of memory allocation
219 /// of the nodes in the list.
220 mod unsafe_list {
221     use crate::ptr::NonNull;
222     use crate::mem;
223
224     pub struct UnsafeListEntry<T> {
225         next: NonNull<UnsafeListEntry<T>>,
226         prev: NonNull<UnsafeListEntry<T>>,
227         value: Option<T>
228     }
229
230     impl<T> UnsafeListEntry<T> {
231         fn dummy() -> Self {
232             UnsafeListEntry {
233                 next: NonNull::dangling(),
234                 prev: NonNull::dangling(),
235                 value: None
236             }
237         }
238
239         pub fn new(value: T) -> Self {
240             UnsafeListEntry {
241                 value: Some(value),
242                 ..Self::dummy()
243             }
244         }
245     }
246
247     pub struct UnsafeList<T> {
248         head_tail: NonNull<UnsafeListEntry<T>>,
249         head_tail_entry: Option<UnsafeListEntry<T>>,
250     }
251
252     impl<T> UnsafeList<T> {
253         pub const fn new() -> Self {
254             unsafe {
255                 UnsafeList {
256                     head_tail: NonNull::new_unchecked(1 as _),
257                     head_tail_entry: None
258                 }
259             }
260         }
261
262         unsafe fn init(&mut self) {
263             if self.head_tail_entry.is_none() {
264                 self.head_tail_entry = Some(UnsafeListEntry::dummy());
265                 self.head_tail = NonNull::new_unchecked(self.head_tail_entry.as_mut().unwrap());
266                 self.head_tail.as_mut().next = self.head_tail;
267                 self.head_tail.as_mut().prev = self.head_tail;
268             }
269         }
270
271         pub fn is_empty(&self) -> bool {
272             unsafe {
273                 if self.head_tail_entry.is_some() {
274                     let first = self.head_tail.as_ref().next;
275                     if first == self.head_tail {
276                         // ,-------> /---------\ next ---,
277                         // |         |head_tail|         |
278                         // `--- prev \---------/ <-------`
279                         rtassert!(self.head_tail.as_ref().prev == first);
280                         true
281                     } else {
282                         false
283                     }
284                 } else {
285                     true
286                 }
287             }
288         }
289
290         /// Pushes an entry onto the back of the list.
291         ///
292         /// # Safety
293         ///
294         /// The entry must remain allocated until the entry is removed from the
295         /// list AND the caller who popped is done using the entry. Special
296         /// care must be taken in the caller of `push` to ensure unwinding does
297         /// not destroy the stack frame containing the entry.
298         pub unsafe fn push<'a>(&mut self, entry: &'a mut UnsafeListEntry<T>) -> &'a T {
299             self.init();
300
301             // BEFORE:
302             //     /---------\ next ---> /---------\
303             // ... |prev_tail|           |head_tail| ...
304             //     \---------/ <--- prev \---------/
305             //
306             // AFTER:
307             //     /---------\ next ---> /-----\ next ---> /---------\
308             // ... |prev_tail|           |entry|           |head_tail| ...
309             //     \---------/ <--- prev \-----/ <--- prev \---------/
310             let mut entry = NonNull::new_unchecked(entry);
311             let mut prev_tail = mem::replace(&mut self.head_tail.as_mut().prev, entry);
312             entry.as_mut().prev = prev_tail;
313             entry.as_mut().next = self.head_tail;
314             prev_tail.as_mut().next = entry;
315             // unwrap ok: always `Some` on non-dummy entries
316             (*entry.as_ptr()).value.as_ref().unwrap()
317         }
318
319         /// Pops an entry from the front of the list.
320         ///
321         /// # Safety
322         ///
323         /// The caller must make sure to synchronize ending the borrow of the
324         /// return value and deallocation of the containing entry.
325         pub unsafe fn pop<'a>(&mut self) -> Option<&'a T> {
326             self.init();
327
328             if self.is_empty() {
329                 None
330             } else {
331                 // BEFORE:
332                 //     /---------\ next ---> /-----\ next ---> /------\
333                 // ... |head_tail|           |first|           |second| ...
334                 //     \---------/ <--- prev \-----/ <--- prev \------/
335                 //
336                 // AFTER:
337                 //     /---------\ next ---> /------\
338                 // ... |head_tail|           |second| ...
339                 //     \---------/ <--- prev \------/
340                 let mut first = self.head_tail.as_mut().next;
341                 let mut second = first.as_mut().next;
342                 self.head_tail.as_mut().next = second;
343                 second.as_mut().prev = self.head_tail;
344                 first.as_mut().next = NonNull::dangling();
345                 first.as_mut().prev = NonNull::dangling();
346                 // unwrap ok: always `Some` on non-dummy entries
347                 Some((*first.as_ptr()).value.as_ref().unwrap())
348             }
349         }
350     }
351
352     #[cfg(test)]
353     mod tests {
354         use super::*;
355         use crate::cell::Cell;
356
357         unsafe fn assert_empty<T>(list: &mut UnsafeList<T>) {
358             assert!(list.pop().is_none(), "assertion failed: list is not empty");
359         }
360
361         #[test]
362         fn init_empty() {
363             unsafe {
364                 assert_empty(&mut UnsafeList::<i32>::new());
365             }
366         }
367
368         #[test]
369         fn push_pop() {
370             unsafe {
371                 let mut node = UnsafeListEntry::new(1234);
372                 let mut list = UnsafeList::new();
373                 assert_eq!(list.push(&mut node), &1234);
374                 assert_eq!(list.pop().unwrap(), &1234);
375                 assert_empty(&mut list);
376             }
377         }
378
379         #[test]
380         fn complex_pushes_pops() {
381             unsafe {
382                 let mut node1 = UnsafeListEntry::new(1234);
383                 let mut node2 = UnsafeListEntry::new(4567);
384                 let mut node3 = UnsafeListEntry::new(9999);
385                 let mut node4 = UnsafeListEntry::new(8642);
386                 let mut list = UnsafeList::new();
387                 list.push(&mut node1);
388                 list.push(&mut node2);
389                 assert_eq!(list.pop().unwrap(), &1234);
390                 list.push(&mut node3);
391                 assert_eq!(list.pop().unwrap(), &4567);
392                 assert_eq!(list.pop().unwrap(), &9999);
393                 assert_empty(&mut list);
394                 list.push(&mut node4);
395                 assert_eq!(list.pop().unwrap(), &8642);
396                 assert_empty(&mut list);
397             }
398         }
399
400         #[test]
401         fn cell() {
402             unsafe {
403                 let mut node = UnsafeListEntry::new(Cell::new(0));
404                 let mut list = UnsafeList::new();
405                 let noderef = list.push(&mut node);
406                 assert_eq!(noderef.get(), 0);
407                 list.pop().unwrap().set(1);
408                 assert_empty(&mut list);
409                 assert_eq!(noderef.get(), 1);
410             }
411         }
412     }
413 }
414
415 /// Trivial spinlock-based implementation of `sync::Mutex`.
416 // FIXME: Perhaps use Intel TSX to avoid locking?
417 mod spin_mutex {
418     use crate::cell::UnsafeCell;
419     use crate::sync::atomic::{AtomicBool, Ordering, spin_loop_hint};
420     use crate::ops::{Deref, DerefMut};
421
422     #[derive(Default)]
423     pub struct SpinMutex<T> {
424         value: UnsafeCell<T>,
425         lock: AtomicBool,
426     }
427
428     unsafe impl<T: Send> Send for SpinMutex<T> {}
429     unsafe impl<T: Send> Sync for SpinMutex<T> {}
430
431     pub struct SpinMutexGuard<'a, T: 'a> {
432         mutex: &'a SpinMutex<T>,
433     }
434
435     impl<'a, T> !Send for SpinMutexGuard<'a, T> {}
436     unsafe impl<'a, T: Sync> Sync for SpinMutexGuard<'a, T> {}
437
438     impl<T> SpinMutex<T> {
439         pub const fn new(value: T) -> Self {
440             SpinMutex {
441                 value: UnsafeCell::new(value),
442                 lock: AtomicBool::new(false)
443             }
444         }
445
446         #[inline(always)]
447         pub fn lock(&self) -> SpinMutexGuard<'_, T> {
448             loop {
449                 match self.try_lock() {
450                     None => while self.lock.load(Ordering::Relaxed) {
451                         spin_loop_hint()
452                     },
453                     Some(guard) => return guard
454                 }
455             }
456         }
457
458         #[inline(always)]
459         pub fn try_lock(&self) -> Option<SpinMutexGuard<'_, T>> {
460             if !self.lock.compare_and_swap(false, true, Ordering::Acquire) {
461                 Some(SpinMutexGuard {
462                     mutex: self,
463                 })
464             } else {
465                 None
466             }
467         }
468     }
469
470     /// Lock the Mutex or return false.
471     pub macro try_lock_or_false {
472         ($e:expr) => {
473             if let Some(v) = $e.try_lock() {
474                 v
475             } else {
476                 return false
477             }
478         }
479     }
480
481     impl<'a, T> Deref for SpinMutexGuard<'a, T> {
482         type Target = T;
483
484         fn deref(&self) -> &T {
485             unsafe {
486                 &*self.mutex.value.get()
487             }
488         }
489     }
490
491     impl<'a, T> DerefMut for SpinMutexGuard<'a, T> {
492         fn deref_mut(&mut self) -> &mut T {
493             unsafe {
494                 &mut*self.mutex.value.get()
495             }
496         }
497     }
498
499     impl<'a, T> Drop for SpinMutexGuard<'a, T> {
500         fn drop(&mut self) {
501             self.mutex.lock.store(false, Ordering::Release)
502         }
503     }
504
505     #[cfg(test)]
506     mod tests {
507         #![allow(deprecated)]
508
509         use super::*;
510         use crate::sync::Arc;
511         use crate::thread;
512         use crate::time::{SystemTime, Duration};
513
514         #[test]
515         fn sleep() {
516             let mutex = Arc::new(SpinMutex::<i32>::default());
517             let mutex2 = mutex.clone();
518             let guard = mutex.lock();
519             let t1 = thread::spawn(move || {
520                 *mutex2.lock() = 1;
521             });
522
523             // "sleep" for 50ms
524             // FIXME: https://github.com/fortanix/rust-sgx/issues/31
525             let start = SystemTime::now();
526             let max = Duration::from_millis(50);
527             while start.elapsed().unwrap() < max {}
528
529             assert_eq!(*guard, 0);
530             drop(guard);
531             t1.join().unwrap();
532             assert_eq!(*mutex.lock(), 1);
533         }
534     }
535 }
536
537 #[cfg(test)]
538 mod tests {
539     use super::*;
540     use crate::sync::Arc;
541     use crate::thread;
542
543     #[test]
544     fn queue() {
545         let wq = Arc::new(SpinMutex::<WaitVariable<()>>::default());
546         let wq2 = wq.clone();
547
548         let locked = wq.lock();
549
550         let t1 = thread::spawn(move || {
551             // if we obtain the lock, the main thread should be waiting
552             assert!(WaitQueue::notify_one(wq2.lock()).is_ok());
553         });
554
555         WaitQueue::wait(locked, ||{});
556
557         t1.join().unwrap();
558     }
559 }