]> git.lizzy.rs Git - rust.git/blob - src/libstd/sys/sgx/waitqueue.rs
1dbf2afbf4987e7b47368d7ecca303b69200922f
[rust.git] / src / libstd / sys / sgx / waitqueue.rs
1 /// A simple queue implementation for synchronization primitives.
2 ///
3 /// This queue is used to implement condition variable and mutexes.
4 ///
5 /// Users of this API are expected to use the `WaitVariable<T>` type. Since
6 /// that type is not `Sync`, it needs to be protected by e.g., a `SpinMutex` to
7 /// allow shared access.
8 ///
9 /// Since userspace may send spurious wake-ups, the wakeup event state is
10 /// recorded in the enclave. The wakeup event state is protected by a spinlock.
11 /// The queue and associated wait state are stored in a `WaitVariable`.
12
13 use crate::ops::{Deref, DerefMut};
14 use crate::num::NonZeroUsize;
15
16 use fortanix_sgx_abi::{Tcs, EV_UNPARK, WAIT_INDEFINITE};
17 use super::abi::usercalls;
18 use super::abi::thread;
19
20 use self::unsafe_list::{UnsafeList, UnsafeListEntry};
21 pub use self::spin_mutex::{SpinMutex, SpinMutexGuard, try_lock_or_false};
22
23 /// An queue entry in a `WaitQueue`.
24 struct WaitEntry {
25     /// TCS address of the thread that is waiting
26     tcs: Tcs,
27     /// Whether this thread has been notified to be awoken
28     wake: bool
29 }
30
31 /// Data stored with a `WaitQueue` alongside it. This ensures accesses to the
32 /// queue and the data are synchronized, since the type itself is not `Sync`.
33 ///
34 /// Consumers of this API should use a synchronization primitive for shared
35 /// access, such as `SpinMutex`.
36 #[derive(Default)]
37 pub struct WaitVariable<T> {
38     queue: WaitQueue,
39     lock: T
40 }
41
42 impl<T> WaitVariable<T> {
43     pub const fn new(var: T) -> Self {
44         WaitVariable {
45             queue: WaitQueue::new(),
46             lock: var
47         }
48     }
49
50     pub fn queue_empty(&self) -> bool {
51         self.queue.is_empty()
52     }
53
54     pub fn lock_var(&self) -> &T {
55         &self.lock
56     }
57
58     pub fn lock_var_mut(&mut self) -> &mut T {
59         &mut self.lock
60     }
61 }
62
63 #[derive(Copy, Clone)]
64 pub enum NotifiedTcs {
65     Single(Tcs),
66     All { count: NonZeroUsize }
67 }
68
69 /// An RAII guard that will notify a set of target threads as well as unlock
70 /// a mutex on drop.
71 pub struct WaitGuard<'a, T: 'a> {
72     mutex_guard: Option<SpinMutexGuard<'a, WaitVariable<T>>>,
73     notified_tcs: NotifiedTcs
74 }
75
76 /// A queue of threads that are waiting on some synchronization primitive.
77 ///
78 /// `UnsafeList` entries are allocated on the waiting thread's stack. This
79 /// avoids any global locking that might happen in the heap allocator. This is
80 /// safe because the waiting thread will not return from that stack frame until
81 /// after it is notified. The notifying thread ensures to clean up any
82 /// references to the list entries before sending the wakeup event.
83 pub struct WaitQueue {
84     // We use an inner Mutex here to protect the data in the face of spurious
85     // wakeups.
86     inner: UnsafeList<SpinMutex<WaitEntry>>,
87 }
88 unsafe impl Send for WaitQueue {}
89
90 impl Default for WaitQueue {
91     fn default() -> Self {
92         Self::new()
93     }
94 }
95
96 impl<'a, T> WaitGuard<'a, T> {
97     /// Returns which TCSes will be notified when this guard drops.
98     pub fn notified_tcs(&self) -> NotifiedTcs {
99         self.notified_tcs
100     }
101 }
102
103 impl<'a, T> Deref for WaitGuard<'a, T> {
104     type Target = SpinMutexGuard<'a, WaitVariable<T>>;
105
106     fn deref(&self) -> &Self::Target {
107         self.mutex_guard.as_ref().unwrap()
108     }
109 }
110
111 impl<'a, T> DerefMut for WaitGuard<'a, T> {
112     fn deref_mut(&mut self) -> &mut Self::Target {
113         self.mutex_guard.as_mut().unwrap()
114     }
115 }
116
117 impl<'a, T> Drop for WaitGuard<'a, T> {
118     fn drop(&mut self) {
119         drop(self.mutex_guard.take());
120         let target_tcs = match self.notified_tcs {
121             NotifiedTcs::Single(tcs) => Some(tcs),
122             NotifiedTcs::All { .. } => None
123         };
124         usercalls::send(EV_UNPARK, target_tcs).unwrap();
125     }
126 }
127
128 impl WaitQueue {
129     pub const fn new() -> Self {
130         WaitQueue {
131             inner: UnsafeList::new()
132         }
133     }
134
135     pub fn is_empty(&self) -> bool {
136         self.inner.is_empty()
137     }
138
139     /// Adds the calling thread to the `WaitVariable`'s wait queue, then wait
140     /// until a wakeup event.
141     ///
142     /// This function does not return until this thread has been awoken.
143     pub fn wait<T>(mut guard: SpinMutexGuard<WaitVariable<T>>) {
144         unsafe {
145             let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
146                 tcs: thread::current(),
147                 wake: false
148             }));
149             let entry = guard.queue.inner.push(&mut entry);
150             drop(guard);
151             while !entry.lock().wake {
152                 assert_eq!(
153                     usercalls::wait(EV_UNPARK, WAIT_INDEFINITE).unwrap() & EV_UNPARK,
154                     EV_UNPARK
155                 );
156             }
157         }
158     }
159
160     /// Either find the next waiter on the wait queue, or return the mutex
161     /// guard unchanged.
162     ///
163     /// If a waiter is found, a `WaitGuard` is returned which will notify the
164     /// waiter when it is dropped.
165     pub fn notify_one<T>(mut guard: SpinMutexGuard<WaitVariable<T>>)
166         -> Result<WaitGuard<T>, SpinMutexGuard<WaitVariable<T>>>
167     {
168         unsafe {
169             if let Some(entry) = guard.queue.inner.pop() {
170                 let mut entry_guard = entry.lock();
171                 let tcs = entry_guard.tcs;
172                 entry_guard.wake = true;
173                 drop(entry);
174                 Ok(WaitGuard {
175                     mutex_guard: Some(guard),
176                     notified_tcs: NotifiedTcs::Single(tcs)
177                 })
178             } else {
179                 Err(guard)
180             }
181         }
182     }
183
184     /// Either find any and all waiters on the wait queue, or return the mutex
185     /// guard unchanged.
186     ///
187     /// If at least one waiter is found, a `WaitGuard` is returned which will
188     /// notify all waiters when it is dropped.
189     pub fn notify_all<T>(mut guard: SpinMutexGuard<WaitVariable<T>>)
190         -> Result<WaitGuard<T>, SpinMutexGuard<WaitVariable<T>>>
191     {
192         unsafe {
193             let mut count = 0;
194             while let Some(entry) = guard.queue.inner.pop() {
195                 count += 1;
196                 let mut entry_guard = entry.lock();
197                 entry_guard.wake = true;
198             }
199             if let Some(count) = NonZeroUsize::new(count) {
200                 Ok(WaitGuard {
201                     mutex_guard: Some(guard),
202                     notified_tcs: NotifiedTcs::All { count }
203                 })
204             } else {
205                 Err(guard)
206             }
207         }
208     }
209 }
210
211 /// A doubly-linked list where callers are in charge of memory allocation
212 /// of the nodes in the list.
213 mod unsafe_list {
214     use crate::ptr::NonNull;
215     use crate::mem;
216
217     pub struct UnsafeListEntry<T> {
218         next: NonNull<UnsafeListEntry<T>>,
219         prev: NonNull<UnsafeListEntry<T>>,
220         value: Option<T>
221     }
222
223     impl<T> UnsafeListEntry<T> {
224         fn dummy() -> Self {
225             UnsafeListEntry {
226                 next: NonNull::dangling(),
227                 prev: NonNull::dangling(),
228                 value: None
229             }
230         }
231
232         pub fn new(value: T) -> Self {
233             UnsafeListEntry {
234                 value: Some(value),
235                 ..Self::dummy()
236             }
237         }
238     }
239
240     pub struct UnsafeList<T> {
241         head_tail: NonNull<UnsafeListEntry<T>>,
242         head_tail_entry: Option<UnsafeListEntry<T>>,
243     }
244
245     impl<T> UnsafeList<T> {
246         pub const fn new() -> Self {
247             unsafe {
248                 UnsafeList {
249                     head_tail: NonNull::new_unchecked(1 as _),
250                     head_tail_entry: None
251                 }
252             }
253         }
254
255         unsafe fn init(&mut self) {
256             if self.head_tail_entry.is_none() {
257                 self.head_tail_entry = Some(UnsafeListEntry::dummy());
258                 self.head_tail = NonNull::new_unchecked(self.head_tail_entry.as_mut().unwrap());
259                 self.head_tail.as_mut().next = self.head_tail;
260                 self.head_tail.as_mut().prev = self.head_tail;
261             }
262         }
263
264         pub fn is_empty(&self) -> bool {
265             unsafe {
266                 if self.head_tail_entry.is_some() {
267                     let first = self.head_tail.as_ref().next;
268                     if first == self.head_tail {
269                         // ,-------> /---------\ next ---,
270                         // |         |head_tail|         |
271                         // `--- prev \---------/ <-------`
272                         assert_eq!(self.head_tail.as_ref().prev, first);
273                         true
274                     } else {
275                         false
276                     }
277                 } else {
278                     true
279                 }
280             }
281         }
282
283         /// Pushes an entry onto the back of the list.
284         ///
285         /// # Safety
286         ///
287         /// The entry must remain allocated until the entry is removed from the
288         /// list AND the caller who popped is done using the entry.
289         pub unsafe fn push<'a>(&mut self, entry: &'a mut UnsafeListEntry<T>) -> &'a T {
290             self.init();
291
292             // BEFORE:
293             //     /---------\ next ---> /---------\
294             // ... |prev_tail|           |head_tail| ...
295             //     \---------/ <--- prev \---------/
296             //
297             // AFTER:
298             //     /---------\ next ---> /-----\ next ---> /---------\
299             // ... |prev_tail|           |entry|           |head_tail| ...
300             //     \---------/ <--- prev \-----/ <--- prev \---------/
301             let mut entry = NonNull::new_unchecked(entry);
302             let mut prev_tail = mem::replace(&mut self.head_tail.as_mut().prev, entry);
303             entry.as_mut().prev = prev_tail;
304             entry.as_mut().next = self.head_tail;
305             prev_tail.as_mut().next = entry;
306             (*entry.as_ptr()).value.as_ref().unwrap()
307         }
308
309         /// Pops an entry from the front of the list.
310         ///
311         /// # Safety
312         ///
313         /// The caller must make sure to synchronize ending the borrow of the
314         /// return value and deallocation of the containing entry.
315         pub unsafe fn pop<'a>(&mut self) -> Option<&'a T> {
316             self.init();
317
318             if self.is_empty() {
319                 None
320             } else {
321                 // BEFORE:
322                 //     /---------\ next ---> /-----\ next ---> /------\
323                 // ... |head_tail|           |first|           |second| ...
324                 //     \---------/ <--- prev \-----/ <--- prev \------/
325                 //
326                 // AFTER:
327                 //     /---------\ next ---> /------\
328                 // ... |head_tail|           |second| ...
329                 //     \---------/ <--- prev \------/
330                 let mut first = self.head_tail.as_mut().next;
331                 let mut second = first.as_mut().next;
332                 self.head_tail.as_mut().next = second;
333                 second.as_mut().prev = self.head_tail;
334                 first.as_mut().next = NonNull::dangling();
335                 first.as_mut().prev = NonNull::dangling();
336                 Some((*first.as_ptr()).value.as_ref().unwrap())
337             }
338         }
339     }
340
341     #[cfg(test)]
342     mod tests {
343         use super::*;
344         use crate::cell::Cell;
345
346         unsafe fn assert_empty<T>(list: &mut UnsafeList<T>) {
347             assert!(list.pop().is_none(), "assertion failed: list is not empty");
348         }
349
350         #[test]
351         fn init_empty() {
352             unsafe {
353                 assert_empty(&mut UnsafeList::<i32>::new());
354             }
355         }
356
357         #[test]
358         fn push_pop() {
359             unsafe {
360                 let mut node = UnsafeListEntry::new(1234);
361                 let mut list = UnsafeList::new();
362                 assert_eq!(list.push(&mut node), &1234);
363                 assert_eq!(list.pop().unwrap(), &1234);
364                 assert_empty(&mut list);
365             }
366         }
367
368         #[test]
369         fn complex_pushes_pops() {
370             unsafe {
371                 let mut node1 = UnsafeListEntry::new(1234);
372                 let mut node2 = UnsafeListEntry::new(4567);
373                 let mut node3 = UnsafeListEntry::new(9999);
374                 let mut node4 = UnsafeListEntry::new(8642);
375                 let mut list = UnsafeList::new();
376                 list.push(&mut node1);
377                 list.push(&mut node2);
378                 assert_eq!(list.pop().unwrap(), &1234);
379                 list.push(&mut node3);
380                 assert_eq!(list.pop().unwrap(), &4567);
381                 assert_eq!(list.pop().unwrap(), &9999);
382                 assert_empty(&mut list);
383                 list.push(&mut node4);
384                 assert_eq!(list.pop().unwrap(), &8642);
385                 assert_empty(&mut list);
386             }
387         }
388
389         #[test]
390         fn cell() {
391             unsafe {
392                 let mut node = UnsafeListEntry::new(Cell::new(0));
393                 let mut list = UnsafeList::new();
394                 let noderef = list.push(&mut node);
395                 assert_eq!(noderef.get(), 0);
396                 list.pop().unwrap().set(1);
397                 assert_empty(&mut list);
398                 assert_eq!(noderef.get(), 1);
399             }
400         }
401     }
402 }
403
404 /// Trivial spinlock-based implementation of `sync::Mutex`.
405 // FIXME: Perhaps use Intel TSX to avoid locking?
406 mod spin_mutex {
407     use crate::cell::UnsafeCell;
408     use crate::sync::atomic::{AtomicBool, Ordering, spin_loop_hint};
409     use crate::ops::{Deref, DerefMut};
410
411     #[derive(Default)]
412     pub struct SpinMutex<T> {
413         value: UnsafeCell<T>,
414         lock: AtomicBool,
415     }
416
417     unsafe impl<T: Send> Send for SpinMutex<T> {}
418     unsafe impl<T: Send> Sync for SpinMutex<T> {}
419
420     pub struct SpinMutexGuard<'a, T: 'a> {
421         mutex: &'a SpinMutex<T>,
422     }
423
424     impl<'a, T> !Send for SpinMutexGuard<'a, T> {}
425     unsafe impl<'a, T: Sync> Sync for SpinMutexGuard<'a, T> {}
426
427     impl<T> SpinMutex<T> {
428         pub const fn new(value: T) -> Self {
429             SpinMutex {
430                 value: UnsafeCell::new(value),
431                 lock: AtomicBool::new(false)
432             }
433         }
434
435         #[inline(always)]
436         pub fn lock(&self) -> SpinMutexGuard<T> {
437             loop {
438                 match self.try_lock() {
439                     None => while self.lock.load(Ordering::Relaxed) {
440                         spin_loop_hint()
441                     },
442                     Some(guard) => return guard
443                 }
444             }
445         }
446
447         #[inline(always)]
448         pub fn try_lock(&self) -> Option<SpinMutexGuard<T>> {
449             if !self.lock.compare_and_swap(false, true, Ordering::Acquire) {
450                 Some(SpinMutexGuard {
451                     mutex: self,
452                 })
453             } else {
454                 None
455             }
456         }
457     }
458
459     /// Lock the Mutex or return false.
460     pub macro try_lock_or_false {
461         ($e:expr) => {
462             if let Some(v) = $e.try_lock() {
463                 v
464             } else {
465                 return false
466             }
467         }
468     }
469
470     impl<'a, T> Deref for SpinMutexGuard<'a, T> {
471         type Target = T;
472
473         fn deref(&self) -> &T {
474             unsafe {
475                 &*self.mutex.value.get()
476             }
477         }
478     }
479
480     impl<'a, T> DerefMut for SpinMutexGuard<'a, T> {
481         fn deref_mut(&mut self) -> &mut T {
482             unsafe {
483                 &mut*self.mutex.value.get()
484             }
485         }
486     }
487
488     impl<'a, T> Drop for SpinMutexGuard<'a, T> {
489         fn drop(&mut self) {
490             self.mutex.lock.store(false, Ordering::Release)
491         }
492     }
493
494     #[cfg(test)]
495     mod tests {
496         #![allow(deprecated)]
497
498         use super::*;
499         use crate::sync::Arc;
500         use crate::thread;
501
502         #[test]
503         fn sleep() {
504             let mutex = Arc::new(SpinMutex::<i32>::default());
505             let mutex2 = mutex.clone();
506             let guard = mutex.lock();
507             let t1 = thread::spawn(move || {
508                 *mutex2.lock() = 1;
509             });
510             thread::sleep_ms(50);
511             assert_eq!(*guard, 0);
512             drop(guard);
513             t1.join().unwrap();
514             assert_eq!(*mutex.lock(), 1);
515         }
516     }
517 }
518
519 #[cfg(test)]
520 mod tests {
521     use super::*;
522     use crate::sync::Arc;
523     use crate::thread;
524
525     #[test]
526     fn queue() {
527         let wq = Arc::new(SpinMutex::<WaitVariable<()>>::default());
528         let wq2 = wq.clone();
529
530         let locked = wq.lock();
531
532         let t1 = thread::spawn(move || {
533             assert!(WaitQueue::notify_one(wq2.lock()).is_none())
534         });
535
536         WaitQueue::wait(locked);
537
538         t1.join().unwrap();
539     }
540 }