]> git.lizzy.rs Git - rust.git/blob - src/libstd/sys/sgx/waitqueue.rs
SGX target: convert a bunch of panics to aborts
[rust.git] / src / libstd / sys / sgx / waitqueue.rs
1 /// A simple queue implementation for synchronization primitives.
2 ///
3 /// This queue is used to implement condition variable and mutexes.
4 ///
5 /// Users of this API are expected to use the `WaitVariable<T>` type. Since
6 /// that type is not `Sync`, it needs to be protected by e.g., a `SpinMutex` to
7 /// allow shared access.
8 ///
9 /// Since userspace may send spurious wake-ups, the wakeup event state is
10 /// recorded in the enclave. The wakeup event state is protected by a spinlock.
11 /// The queue and associated wait state are stored in a `WaitVariable`.
12
13 use crate::ops::{Deref, DerefMut};
14 use crate::num::NonZeroUsize;
15
16 use fortanix_sgx_abi::{Tcs, EV_UNPARK, WAIT_INDEFINITE};
17 use super::abi::usercalls;
18 use super::abi::thread;
19
20 use self::unsafe_list::{UnsafeList, UnsafeListEntry};
21 pub use self::spin_mutex::{SpinMutex, SpinMutexGuard, try_lock_or_false};
22
23 /// An queue entry in a `WaitQueue`.
24 struct WaitEntry {
25     /// TCS address of the thread that is waiting
26     tcs: Tcs,
27     /// Whether this thread has been notified to be awoken
28     wake: bool
29 }
30
31 /// Data stored with a `WaitQueue` alongside it. This ensures accesses to the
32 /// queue and the data are synchronized, since the type itself is not `Sync`.
33 ///
34 /// Consumers of this API should use a synchronization primitive for shared
35 /// access, such as `SpinMutex`.
36 #[derive(Default)]
37 pub struct WaitVariable<T> {
38     queue: WaitQueue,
39     lock: T
40 }
41
42 impl<T> WaitVariable<T> {
43     pub const fn new(var: T) -> Self {
44         WaitVariable {
45             queue: WaitQueue::new(),
46             lock: var
47         }
48     }
49
50     pub fn queue_empty(&self) -> bool {
51         self.queue.is_empty()
52     }
53
54     pub fn lock_var(&self) -> &T {
55         &self.lock
56     }
57
58     pub fn lock_var_mut(&mut self) -> &mut T {
59         &mut self.lock
60     }
61 }
62
63 #[derive(Copy, Clone)]
64 pub enum NotifiedTcs {
65     Single(Tcs),
66     All { count: NonZeroUsize }
67 }
68
69 /// An RAII guard that will notify a set of target threads as well as unlock
70 /// a mutex on drop.
71 pub struct WaitGuard<'a, T: 'a> {
72     mutex_guard: Option<SpinMutexGuard<'a, WaitVariable<T>>>,
73     notified_tcs: NotifiedTcs
74 }
75
76 /// A queue of threads that are waiting on some synchronization primitive.
77 ///
78 /// `UnsafeList` entries are allocated on the waiting thread's stack. This
79 /// avoids any global locking that might happen in the heap allocator. This is
80 /// safe because the waiting thread will not return from that stack frame until
81 /// after it is notified. The notifying thread ensures to clean up any
82 /// references to the list entries before sending the wakeup event.
83 pub struct WaitQueue {
84     // We use an inner Mutex here to protect the data in the face of spurious
85     // wakeups.
86     inner: UnsafeList<SpinMutex<WaitEntry>>,
87 }
88 unsafe impl Send for WaitQueue {}
89
90 impl Default for WaitQueue {
91     fn default() -> Self {
92         Self::new()
93     }
94 }
95
96 impl<'a, T> WaitGuard<'a, T> {
97     /// Returns which TCSes will be notified when this guard drops.
98     pub fn notified_tcs(&self) -> NotifiedTcs {
99         self.notified_tcs
100     }
101 }
102
103 impl<'a, T> Deref for WaitGuard<'a, T> {
104     type Target = SpinMutexGuard<'a, WaitVariable<T>>;
105
106     fn deref(&self) -> &Self::Target {
107         self.mutex_guard.as_ref().unwrap()
108     }
109 }
110
111 impl<'a, T> DerefMut for WaitGuard<'a, T> {
112     fn deref_mut(&mut self) -> &mut Self::Target {
113         self.mutex_guard.as_mut().unwrap()
114     }
115 }
116
117 impl<'a, T> Drop for WaitGuard<'a, T> {
118     fn drop(&mut self) {
119         drop(self.mutex_guard.take());
120         let target_tcs = match self.notified_tcs {
121             NotifiedTcs::Single(tcs) => Some(tcs),
122             NotifiedTcs::All { .. } => None
123         };
124         rtunwrap!(Ok, usercalls::send(EV_UNPARK, target_tcs));
125     }
126 }
127
128 impl WaitQueue {
129     pub const fn new() -> Self {
130         WaitQueue {
131             inner: UnsafeList::new()
132         }
133     }
134
135     pub fn is_empty(&self) -> bool {
136         self.inner.is_empty()
137     }
138
139     /// Adds the calling thread to the `WaitVariable`'s wait queue, then wait
140     /// until a wakeup event.
141     ///
142     /// This function does not return until this thread has been awoken.
143     pub fn wait<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>) {
144         // very unsafe: check requirements of UnsafeList::push
145         unsafe {
146             let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
147                 tcs: thread::current(),
148                 wake: false
149             }));
150             let entry = guard.queue.inner.push(&mut entry);
151             drop(guard);
152             while !entry.lock().wake {
153                 // don't panic, this would invalidate `entry` during unwinding
154                 let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE));
155                 rtassert!(eventset & EV_UNPARK == EV_UNPARK);
156             }
157         }
158     }
159
160     /// Either find the next waiter on the wait queue, or return the mutex
161     /// guard unchanged.
162     ///
163     /// If a waiter is found, a `WaitGuard` is returned which will notify the
164     /// waiter when it is dropped.
165     pub fn notify_one<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>)
166         -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>>
167     {
168         unsafe {
169             if let Some(entry) = guard.queue.inner.pop() {
170                 let mut entry_guard = entry.lock();
171                 let tcs = entry_guard.tcs;
172                 entry_guard.wake = true;
173                 drop(entry);
174                 Ok(WaitGuard {
175                     mutex_guard: Some(guard),
176                     notified_tcs: NotifiedTcs::Single(tcs)
177                 })
178             } else {
179                 Err(guard)
180             }
181         }
182     }
183
184     /// Either find any and all waiters on the wait queue, or return the mutex
185     /// guard unchanged.
186     ///
187     /// If at least one waiter is found, a `WaitGuard` is returned which will
188     /// notify all waiters when it is dropped.
189     pub fn notify_all<T>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>)
190         -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>>
191     {
192         unsafe {
193             let mut count = 0;
194             while let Some(entry) = guard.queue.inner.pop() {
195                 count += 1;
196                 let mut entry_guard = entry.lock();
197                 entry_guard.wake = true;
198             }
199             if let Some(count) = NonZeroUsize::new(count) {
200                 Ok(WaitGuard {
201                     mutex_guard: Some(guard),
202                     notified_tcs: NotifiedTcs::All { count }
203                 })
204             } else {
205                 Err(guard)
206             }
207         }
208     }
209 }
210
211 /// A doubly-linked list where callers are in charge of memory allocation
212 /// of the nodes in the list.
213 mod unsafe_list {
214     use crate::ptr::NonNull;
215     use crate::mem;
216
217     pub struct UnsafeListEntry<T> {
218         next: NonNull<UnsafeListEntry<T>>,
219         prev: NonNull<UnsafeListEntry<T>>,
220         value: Option<T>
221     }
222
223     impl<T> UnsafeListEntry<T> {
224         fn dummy() -> Self {
225             UnsafeListEntry {
226                 next: NonNull::dangling(),
227                 prev: NonNull::dangling(),
228                 value: None
229             }
230         }
231
232         pub fn new(value: T) -> Self {
233             UnsafeListEntry {
234                 value: Some(value),
235                 ..Self::dummy()
236             }
237         }
238     }
239
240     pub struct UnsafeList<T> {
241         head_tail: NonNull<UnsafeListEntry<T>>,
242         head_tail_entry: Option<UnsafeListEntry<T>>,
243     }
244
245     impl<T> UnsafeList<T> {
246         pub const fn new() -> Self {
247             unsafe {
248                 UnsafeList {
249                     head_tail: NonNull::new_unchecked(1 as _),
250                     head_tail_entry: None
251                 }
252             }
253         }
254
255         unsafe fn init(&mut self) {
256             if self.head_tail_entry.is_none() {
257                 self.head_tail_entry = Some(UnsafeListEntry::dummy());
258                 self.head_tail = NonNull::new_unchecked(self.head_tail_entry.as_mut().unwrap());
259                 self.head_tail.as_mut().next = self.head_tail;
260                 self.head_tail.as_mut().prev = self.head_tail;
261             }
262         }
263
264         pub fn is_empty(&self) -> bool {
265             unsafe {
266                 if self.head_tail_entry.is_some() {
267                     let first = self.head_tail.as_ref().next;
268                     if first == self.head_tail {
269                         // ,-------> /---------\ next ---,
270                         // |         |head_tail|         |
271                         // `--- prev \---------/ <-------`
272                         rtassert!(self.head_tail.as_ref().prev == first);
273                         true
274                     } else {
275                         false
276                     }
277                 } else {
278                     true
279                 }
280             }
281         }
282
283         /// Pushes an entry onto the back of the list.
284         ///
285         /// # Safety
286         ///
287         /// The entry must remain allocated until the entry is removed from the
288         /// list AND the caller who popped is done using the entry. Special
289         /// care must be taken in the caller of `push` to ensure unwinding does
290         /// not destroy the stack frame containing the entry.
291         pub unsafe fn push<'a>(&mut self, entry: &'a mut UnsafeListEntry<T>) -> &'a T {
292             self.init();
293
294             // BEFORE:
295             //     /---------\ next ---> /---------\
296             // ... |prev_tail|           |head_tail| ...
297             //     \---------/ <--- prev \---------/
298             //
299             // AFTER:
300             //     /---------\ next ---> /-----\ next ---> /---------\
301             // ... |prev_tail|           |entry|           |head_tail| ...
302             //     \---------/ <--- prev \-----/ <--- prev \---------/
303             let mut entry = NonNull::new_unchecked(entry);
304             let mut prev_tail = mem::replace(&mut self.head_tail.as_mut().prev, entry);
305             entry.as_mut().prev = prev_tail;
306             entry.as_mut().next = self.head_tail;
307             prev_tail.as_mut().next = entry;
308             // unwrap ok: always `Some` on non-dummy entries
309             (*entry.as_ptr()).value.as_ref().unwrap()
310         }
311
312         /// Pops an entry from the front of the list.
313         ///
314         /// # Safety
315         ///
316         /// The caller must make sure to synchronize ending the borrow of the
317         /// return value and deallocation of the containing entry.
318         pub unsafe fn pop<'a>(&mut self) -> Option<&'a T> {
319             self.init();
320
321             if self.is_empty() {
322                 None
323             } else {
324                 // BEFORE:
325                 //     /---------\ next ---> /-----\ next ---> /------\
326                 // ... |head_tail|           |first|           |second| ...
327                 //     \---------/ <--- prev \-----/ <--- prev \------/
328                 //
329                 // AFTER:
330                 //     /---------\ next ---> /------\
331                 // ... |head_tail|           |second| ...
332                 //     \---------/ <--- prev \------/
333                 let mut first = self.head_tail.as_mut().next;
334                 let mut second = first.as_mut().next;
335                 self.head_tail.as_mut().next = second;
336                 second.as_mut().prev = self.head_tail;
337                 first.as_mut().next = NonNull::dangling();
338                 first.as_mut().prev = NonNull::dangling();
339                 // unwrap ok: always `Some` on non-dummy entries
340                 Some((*first.as_ptr()).value.as_ref().unwrap())
341             }
342         }
343     }
344
345     #[cfg(test)]
346     mod tests {
347         use super::*;
348         use crate::cell::Cell;
349
350         unsafe fn assert_empty<T>(list: &mut UnsafeList<T>) {
351             assert!(list.pop().is_none(), "assertion failed: list is not empty");
352         }
353
354         #[test]
355         fn init_empty() {
356             unsafe {
357                 assert_empty(&mut UnsafeList::<i32>::new());
358             }
359         }
360
361         #[test]
362         fn push_pop() {
363             unsafe {
364                 let mut node = UnsafeListEntry::new(1234);
365                 let mut list = UnsafeList::new();
366                 assert_eq!(list.push(&mut node), &1234);
367                 assert_eq!(list.pop().unwrap(), &1234);
368                 assert_empty(&mut list);
369             }
370         }
371
372         #[test]
373         fn complex_pushes_pops() {
374             unsafe {
375                 let mut node1 = UnsafeListEntry::new(1234);
376                 let mut node2 = UnsafeListEntry::new(4567);
377                 let mut node3 = UnsafeListEntry::new(9999);
378                 let mut node4 = UnsafeListEntry::new(8642);
379                 let mut list = UnsafeList::new();
380                 list.push(&mut node1);
381                 list.push(&mut node2);
382                 assert_eq!(list.pop().unwrap(), &1234);
383                 list.push(&mut node3);
384                 assert_eq!(list.pop().unwrap(), &4567);
385                 assert_eq!(list.pop().unwrap(), &9999);
386                 assert_empty(&mut list);
387                 list.push(&mut node4);
388                 assert_eq!(list.pop().unwrap(), &8642);
389                 assert_empty(&mut list);
390             }
391         }
392
393         #[test]
394         fn cell() {
395             unsafe {
396                 let mut node = UnsafeListEntry::new(Cell::new(0));
397                 let mut list = UnsafeList::new();
398                 let noderef = list.push(&mut node);
399                 assert_eq!(noderef.get(), 0);
400                 list.pop().unwrap().set(1);
401                 assert_empty(&mut list);
402                 assert_eq!(noderef.get(), 1);
403             }
404         }
405     }
406 }
407
408 /// Trivial spinlock-based implementation of `sync::Mutex`.
409 // FIXME: Perhaps use Intel TSX to avoid locking?
410 mod spin_mutex {
411     use crate::cell::UnsafeCell;
412     use crate::sync::atomic::{AtomicBool, Ordering, spin_loop_hint};
413     use crate::ops::{Deref, DerefMut};
414
415     #[derive(Default)]
416     pub struct SpinMutex<T> {
417         value: UnsafeCell<T>,
418         lock: AtomicBool,
419     }
420
421     unsafe impl<T: Send> Send for SpinMutex<T> {}
422     unsafe impl<T: Send> Sync for SpinMutex<T> {}
423
424     pub struct SpinMutexGuard<'a, T: 'a> {
425         mutex: &'a SpinMutex<T>,
426     }
427
428     impl<'a, T> !Send for SpinMutexGuard<'a, T> {}
429     unsafe impl<'a, T: Sync> Sync for SpinMutexGuard<'a, T> {}
430
431     impl<T> SpinMutex<T> {
432         pub const fn new(value: T) -> Self {
433             SpinMutex {
434                 value: UnsafeCell::new(value),
435                 lock: AtomicBool::new(false)
436             }
437         }
438
439         #[inline(always)]
440         pub fn lock(&self) -> SpinMutexGuard<'_, T> {
441             loop {
442                 match self.try_lock() {
443                     None => while self.lock.load(Ordering::Relaxed) {
444                         spin_loop_hint()
445                     },
446                     Some(guard) => return guard
447                 }
448             }
449         }
450
451         #[inline(always)]
452         pub fn try_lock(&self) -> Option<SpinMutexGuard<'_, T>> {
453             if !self.lock.compare_and_swap(false, true, Ordering::Acquire) {
454                 Some(SpinMutexGuard {
455                     mutex: self,
456                 })
457             } else {
458                 None
459             }
460         }
461     }
462
463     /// Lock the Mutex or return false.
464     pub macro try_lock_or_false {
465         ($e:expr) => {
466             if let Some(v) = $e.try_lock() {
467                 v
468             } else {
469                 return false
470             }
471         }
472     }
473
474     impl<'a, T> Deref for SpinMutexGuard<'a, T> {
475         type Target = T;
476
477         fn deref(&self) -> &T {
478             unsafe {
479                 &*self.mutex.value.get()
480             }
481         }
482     }
483
484     impl<'a, T> DerefMut for SpinMutexGuard<'a, T> {
485         fn deref_mut(&mut self) -> &mut T {
486             unsafe {
487                 &mut*self.mutex.value.get()
488             }
489         }
490     }
491
492     impl<'a, T> Drop for SpinMutexGuard<'a, T> {
493         fn drop(&mut self) {
494             self.mutex.lock.store(false, Ordering::Release)
495         }
496     }
497
498     #[cfg(test)]
499     mod tests {
500         #![allow(deprecated)]
501
502         use super::*;
503         use crate::sync::Arc;
504         use crate::thread;
505         use crate::time::{SystemTime, Duration};
506
507         #[test]
508         fn sleep() {
509             let mutex = Arc::new(SpinMutex::<i32>::default());
510             let mutex2 = mutex.clone();
511             let guard = mutex.lock();
512             let t1 = thread::spawn(move || {
513                 *mutex2.lock() = 1;
514             });
515
516             // "sleep" for 50ms
517             // FIXME: https://github.com/fortanix/rust-sgx/issues/31
518             let start = SystemTime::now();
519             let max = Duration::from_millis(50);
520             while start.elapsed().unwrap() < max {}
521
522             assert_eq!(*guard, 0);
523             drop(guard);
524             t1.join().unwrap();
525             assert_eq!(*mutex.lock(), 1);
526         }
527     }
528 }
529
530 #[cfg(test)]
531 mod tests {
532     use super::*;
533     use crate::sync::Arc;
534     use crate::thread;
535
536     #[test]
537     fn queue() {
538         let wq = Arc::new(SpinMutex::<WaitVariable<()>>::default());
539         let wq2 = wq.clone();
540
541         let locked = wq.lock();
542
543         let t1 = thread::spawn(move || {
544             // if we obtain the lock, the main thread should be waiting
545             assert!(WaitQueue::notify_one(wq2.lock()).is_ok());
546         });
547
548         WaitQueue::wait(locked);
549
550         t1.join().unwrap();
551     }
552 }