]> git.lizzy.rs Git - rust.git/blob - src/libstd/sys/sgx/waitqueue.rs
SGX target: implement synchronization primitives and threading
[rust.git] / src / libstd / sys / sgx / waitqueue.rs
1 // Copyright 2018 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 /// A simple queue implementation for synchronization primitives.
12 ///
13 /// This queue is used to implement condition variable and mutexes.
14 ///
15 /// Users of this API are expected to use the `WaitVariable<T>` type. Since
16 /// that type is not `Sync`, it needs to be protected by e.g. a `SpinMutex` to
17 /// allow shared access.
18 ///
19 /// Since userspace may send spurious wake-ups, the wakeup event state is
20 /// recorded in the enclave. The wakeup event state is protected by a spinlock.
21 /// The queue and associated wait state are stored in a `WaitVariable`.
22
23 use ops::{Deref, DerefMut};
24 use num::NonZeroUsize;
25
26 use fortanix_sgx_abi::{Tcs, EV_UNPARK, WAIT_INDEFINITE};
27 use super::abi::usercalls;
28 use super::abi::thread;
29
30 use self::unsafe_list::{UnsafeList, UnsafeListEntry};
31 pub use self::spin_mutex::{SpinMutex, SpinMutexGuard, try_lock_or_false};
32
33 /// An queue entry in a `WaitQueue`.
34 struct WaitEntry {
35     /// TCS address of the thread that is waiting
36     tcs: Tcs,
37     /// Whether this thread has been notified to be awoken
38     wake: bool
39 }
40
41 /// Data stored with a `WaitQueue` alongside it. This ensures accesses to the
42 /// queue and the data are synchronized, since the type itself is not `Sync`.
43 ///
44 /// Consumers of this API should use a synchronization primitive for shared
45 /// access, such as `SpinMutex`.
46 #[derive(Default)]
47 pub struct WaitVariable<T> {
48     queue: WaitQueue,
49     lock: T
50 }
51
52 impl<T> WaitVariable<T> {
53     #[unstable(feature = "sgx_internals", issue = "0")] // FIXME: min_const_fn
54     pub const fn new(var: T) -> Self {
55         WaitVariable {
56             queue: WaitQueue::new(),
57             lock: var
58         }
59     }
60
61     pub fn queue_empty(&self) -> bool {
62         self.queue.is_empty()
63     }
64
65     pub fn lock_var(&self) -> &T {
66         &self.lock
67     }
68
69     pub fn lock_var_mut(&mut self) -> &mut T {
70         &mut self.lock
71     }
72 }
73
74 #[derive(Copy, Clone)]
75 pub enum NotifiedTcs {
76     Single(Tcs),
77     All { count: NonZeroUsize }
78 }
79
80 /// An RAII guard that will notify a set of target threads as well as unlock
81 /// a mutex on drop.
82 pub struct WaitGuard<'a, T: 'a> {
83     mutex_guard: Option<SpinMutexGuard<'a, WaitVariable<T>>>,
84     notified_tcs: NotifiedTcs
85 }
86
87 /// A queue of threads that are waiting on some synchronization primitive.
88 ///
89 /// `UnsafeList` entries are allocated on the waiting thread's stack. This
90 /// avoids any global locking that might happen in the heap allocator. This is
91 /// safe because the waiting thread will not return from that stack frame until
92 /// after it is notified. The notifying thread ensures to clean up any
93 /// references to the list entries before sending the wakeup event.
94 pub struct WaitQueue {
95     // We use an inner Mutex here to protect the data in the face of spurious
96     // wakeups.
97     inner: UnsafeList<SpinMutex<WaitEntry>>,
98 }
99 unsafe impl Send for WaitQueue {}
100
101 impl Default for WaitQueue {
102     fn default() -> Self {
103         Self::new()
104     }
105 }
106
107 impl<'a, T> WaitGuard<'a, T> {
108     /// Returns which TCSes will be notified when this guard drops.
109     pub fn notified_tcs(&self) -> NotifiedTcs {
110         self.notified_tcs
111     }
112 }
113
114 impl<'a, T> Deref for WaitGuard<'a, T> {
115     type Target = SpinMutexGuard<'a, WaitVariable<T>>;
116
117     fn deref(&self) -> &Self::Target {
118         self.mutex_guard.as_ref().unwrap()
119     }
120 }
121
122 impl<'a, T> DerefMut for WaitGuard<'a, T> {
123     fn deref_mut(&mut self) -> &mut Self::Target {
124         self.mutex_guard.as_mut().unwrap()
125     }
126 }
127
128 impl<'a, T> Drop for WaitGuard<'a, T> {
129     fn drop(&mut self) {
130         drop(self.mutex_guard.take());
131         let target_tcs = match self.notified_tcs {
132             NotifiedTcs::Single(tcs) => Some(tcs),
133             NotifiedTcs::All { .. } => None
134         };
135         usercalls::send(EV_UNPARK, target_tcs).unwrap();
136     }
137 }
138
139 impl WaitQueue {
140     #[unstable(feature = "sgx_internals", issue = "0")] // FIXME: min_const_fn
141     pub const fn new() -> Self {
142         WaitQueue {
143             inner: UnsafeList::new()
144         }
145     }
146
147     pub fn is_empty(&self) -> bool {
148         self.inner.is_empty()
149     }
150
151     /// Add the calling thread to the WaitVariable's wait queue, then wait
152     /// until a wakeup event.
153     ///
154     /// This function does not return until this thread has been awoken.
155     pub fn wait<T>(mut guard: SpinMutexGuard<WaitVariable<T>>) {
156         unsafe {
157             let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
158                 tcs: thread::current(),
159                 wake: false
160             }));
161             let entry = guard.queue.inner.push(&mut entry);
162             drop(guard);
163             while !entry.lock().wake {
164                 assert_eq!(
165                     usercalls::wait(EV_UNPARK, WAIT_INDEFINITE).unwrap() & EV_UNPARK,
166                     EV_UNPARK
167                 );
168             }
169         }
170     }
171
172     /// Either find the next waiter on the wait queue, or return the mutex
173     /// guard unchanged.
174     ///
175     /// If a waiter is found, a `WaitGuard` is returned which will notify the
176     /// waiter when it is dropped.
177     pub fn notify_one<T>(mut guard: SpinMutexGuard<WaitVariable<T>>)
178         -> Result<WaitGuard<T>, SpinMutexGuard<WaitVariable<T>>>
179     {
180         unsafe {
181             if let Some(entry) = guard.queue.inner.pop() {
182                 let mut entry_guard = entry.lock();
183                 let tcs = entry_guard.tcs;
184                 entry_guard.wake = true;
185                 drop(entry);
186                 Ok(WaitGuard {
187                     mutex_guard: Some(guard),
188                     notified_tcs: NotifiedTcs::Single(tcs)
189                 })
190             } else {
191                 Err(guard)
192             }
193         }
194     }
195
196     /// Either find any and all waiters on the wait queue, or return the mutex
197     /// guard unchanged.
198     ///
199     /// If at least one waiter is found, a `WaitGuard` is returned which will
200     /// notify all waiters when it is dropped.
201     pub fn notify_all<T>(mut guard: SpinMutexGuard<WaitVariable<T>>)
202         -> Result<WaitGuard<T>, SpinMutexGuard<WaitVariable<T>>>
203     {
204         unsafe {
205             let mut count = 0;
206             while let Some(entry) = guard.queue.inner.pop() {
207                 count += 1;
208                 let mut entry_guard = entry.lock();
209                 entry_guard.wake = true;
210             }
211             if let Some(count) = NonZeroUsize::new(count) {
212                 Ok(WaitGuard {
213                     mutex_guard: Some(guard),
214                     notified_tcs: NotifiedTcs::All { count }
215                 })
216             } else {
217                 Err(guard)
218             }
219         }
220     }
221 }
222
223 /// A doubly-linked list where callers are in charge of memory allocation
224 /// of the nodes in the list.
225 mod unsafe_list {
226     use ptr::NonNull;
227     use mem;
228
229     pub struct UnsafeListEntry<T> {
230         next: NonNull<UnsafeListEntry<T>>,
231         prev: NonNull<UnsafeListEntry<T>>,
232         value: Option<T>
233     }
234
235     impl<T> UnsafeListEntry<T> {
236         fn dummy() -> Self {
237             UnsafeListEntry {
238                 next: NonNull::dangling(),
239                 prev: NonNull::dangling(),
240                 value: None
241             }
242         }
243
244         pub fn new(value: T) -> Self {
245             UnsafeListEntry {
246                 value: Some(value),
247                 ..Self::dummy()
248             }
249         }
250     }
251
252     pub struct UnsafeList<T> {
253         head_tail: NonNull<UnsafeListEntry<T>>,
254         head_tail_entry: Option<UnsafeListEntry<T>>,
255     }
256
257     impl<T> UnsafeList<T> {
258         #[unstable(feature = "sgx_internals", issue = "0")] // FIXME: min_const_fn
259         pub const fn new() -> Self {
260             unsafe {
261                 UnsafeList {
262                     head_tail: NonNull::new_unchecked(1 as _),
263                     head_tail_entry: None
264                 }
265             }
266         }
267
268         unsafe fn init(&mut self) {
269             if self.head_tail_entry.is_none() {
270                 self.head_tail_entry = Some(UnsafeListEntry::dummy());
271                 self.head_tail = NonNull::new_unchecked(self.head_tail_entry.as_mut().unwrap());
272                 self.head_tail.as_mut().next = self.head_tail;
273                 self.head_tail.as_mut().prev = self.head_tail;
274             }
275         }
276
277         pub fn is_empty(&self) -> bool {
278             unsafe {
279                 if self.head_tail_entry.is_some() {
280                     let first = self.head_tail.as_ref().next;
281                     if first == self.head_tail {
282                         // ,-------> /---------\ next ---,
283                         // |         |head_tail|         |
284                         // `--- prev \---------/ <-------`
285                         assert_eq!(self.head_tail.as_ref().prev, first);
286                         true
287                     } else {
288                         false
289                     }
290                 } else {
291                     true
292                 }
293             }
294         }
295
296         /// Pushes an entry onto the back of the list.
297         ///
298         /// # Safety
299         ///
300         /// The entry must remain allocated until the entry is removed from the
301         /// list AND the caller who popped is done using the entry.
302         pub unsafe fn push<'a>(&mut self, entry: &'a mut UnsafeListEntry<T>) -> &'a T {
303             self.init();
304
305             // BEFORE:
306             //     /---------\ next ---> /---------\
307             // ... |prev_tail|           |head_tail| ...
308             //     \---------/ <--- prev \---------/
309             //
310             // AFTER:
311             //     /---------\ next ---> /-----\ next ---> /---------\
312             // ... |prev_tail|           |entry|           |head_tail| ...
313             //     \---------/ <--- prev \-----/ <--- prev \---------/
314             let mut entry = NonNull::new_unchecked(entry);
315             let mut prev_tail = mem::replace(&mut self.head_tail.as_mut().prev, entry);
316             entry.as_mut().prev = prev_tail;
317             entry.as_mut().next = self.head_tail;
318             prev_tail.as_mut().next = entry;
319             (*entry.as_ptr()).value.as_ref().unwrap()
320         }
321
322         /// Pops an entry from the front of the list.
323         ///
324         /// # Safety
325         ///
326         /// The caller must make sure to synchronize ending the borrow of the
327         /// return value and deallocation of the containing entry.
328         pub unsafe fn pop<'a>(&mut self) -> Option<&'a T> {
329             self.init();
330
331             if self.is_empty() {
332                 None
333             } else {
334                 // BEFORE:
335                 //     /---------\ next ---> /-----\ next ---> /------\
336                 // ... |head_tail|           |first|           |second| ...
337                 //     \---------/ <--- prev \-----/ <--- prev \------/
338                 //
339                 // AFTER:
340                 //     /---------\ next ---> /------\
341                 // ... |head_tail|           |second| ...
342                 //     \---------/ <--- prev \------/
343                 let mut first = self.head_tail.as_mut().next;
344                 let mut second = first.as_mut().next;
345                 self.head_tail.as_mut().next = second;
346                 second.as_mut().prev = self.head_tail;
347                 first.as_mut().next = NonNull::dangling();
348                 first.as_mut().prev = NonNull::dangling();
349                 Some((*first.as_ptr()).value.as_ref().unwrap())
350             }
351         }
352     }
353
354     #[cfg(test)]
355     mod tests {
356         use super::*;
357         use cell::Cell;
358
359         unsafe fn assert_empty<T>(list: &mut UnsafeList<T>) {
360             assert!(list.pop().is_none(), "assertion failed: list is not empty");
361         }
362
363         #[test]
364         fn init_empty() {
365             unsafe {
366                 assert_empty(&mut UnsafeList::<i32>::new());
367             }
368         }
369
370         #[test]
371         fn push_pop() {
372             unsafe {
373                 let mut node = UnsafeListEntry::new(1234);
374                 let mut list = UnsafeList::new();
375                 assert_eq!(list.push(&mut node), &1234);
376                 assert_eq!(list.pop().unwrap(), &1234);
377                 assert_empty(&mut list);
378             }
379         }
380
381         #[test]
382         fn complex_pushes_pops() {
383             unsafe {
384                 let mut node1 = UnsafeListEntry::new(1234);
385                 let mut node2 = UnsafeListEntry::new(4567);
386                 let mut node3 = UnsafeListEntry::new(9999);
387                 let mut node4 = UnsafeListEntry::new(8642);
388                 let mut list = UnsafeList::new();
389                 list.push(&mut node1);
390                 list.push(&mut node2);
391                 assert_eq!(list.pop().unwrap(), &1234);
392                 list.push(&mut node3);
393                 assert_eq!(list.pop().unwrap(), &4567);
394                 assert_eq!(list.pop().unwrap(), &9999);
395                 assert_empty(&mut list);
396                 list.push(&mut node4);
397                 assert_eq!(list.pop().unwrap(), &8642);
398                 assert_empty(&mut list);
399             }
400         }
401
402         #[test]
403         fn cell() {
404             unsafe {
405                 let mut node = UnsafeListEntry::new(Cell::new(0));
406                 let mut list = UnsafeList::new();
407                 let noderef = list.push(&mut node);
408                 assert_eq!(noderef.get(), 0);
409                 list.pop().unwrap().set(1);
410                 assert_empty(&mut list);
411                 assert_eq!(noderef.get(), 1);
412             }
413         }
414     }
415 }
416
417 /// Trivial spinlock-based implementation of `sync::Mutex`.
418 // FIXME: Perhaps use Intel TSX to avoid locking?
419 mod spin_mutex {
420     use cell::UnsafeCell;
421     use sync::atomic::{AtomicBool, Ordering, spin_loop_hint};
422     use ops::{Deref, DerefMut};
423
424     #[derive(Default)]
425     pub struct SpinMutex<T> {
426         value: UnsafeCell<T>,
427         lock: AtomicBool,
428     }
429
430     unsafe impl<T: Send> Send for SpinMutex<T> {}
431     unsafe impl<T: Send> Sync for SpinMutex<T> {}
432
433     pub struct SpinMutexGuard<'a, T: 'a> {
434         mutex: &'a SpinMutex<T>,
435     }
436
437     impl<'a, T> !Send for SpinMutexGuard<'a, T> {}
438     unsafe impl<'a, T: Sync> Sync for SpinMutexGuard<'a, T> {}
439
440     impl<T> SpinMutex<T> {
441         pub const fn new(value: T) -> Self {
442             SpinMutex {
443                 value: UnsafeCell::new(value),
444                 lock: AtomicBool::new(false)
445             }
446         }
447
448         #[inline(always)]
449         pub fn lock(&self) -> SpinMutexGuard<T> {
450             loop {
451                 match self.try_lock() {
452                     None => while self.lock.load(Ordering::Relaxed) {
453                         spin_loop_hint()
454                     },
455                     Some(guard) => return guard
456                 }
457             }
458         }
459
460         #[inline(always)]
461         pub fn try_lock(&self) -> Option<SpinMutexGuard<T>> {
462             if !self.lock.compare_and_swap(false, true, Ordering::Acquire) {
463                 Some(SpinMutexGuard {
464                     mutex: self,
465                 })
466             } else {
467                 None
468             }
469         }
470     }
471
472     pub macro try_lock_or_false {
473         ($e:expr) => {
474             if let Some(v) = $e.try_lock() {
475                 v
476             } else {
477                 return false
478             }
479         }
480     }
481
482     impl<'a, T> Deref for SpinMutexGuard<'a, T> {
483         type Target = T;
484
485         fn deref(&self) -> &T {
486             unsafe {
487                 &*self.mutex.value.get()
488             }
489         }
490     }
491
492     impl<'a, T> DerefMut for SpinMutexGuard<'a, T> {
493         fn deref_mut(&mut self) -> &mut T {
494             unsafe {
495                 &mut*self.mutex.value.get()
496             }
497         }
498     }
499
500     impl<'a, T> Drop for SpinMutexGuard<'a, T> {
501         fn drop(&mut self) {
502             self.mutex.lock.store(false, Ordering::Release)
503         }
504     }
505
506     #[cfg(test)]
507     mod tests {
508         #![allow(deprecated)]
509
510         use super::*;
511         use sync::Arc;
512         use thread;
513
514         #[test]
515         fn sleep() {
516             let mutex = Arc::new(SpinMutex::<i32>::default());
517             let mutex2 = mutex.clone();
518             let guard = mutex.lock();
519             let t1 = thread::spawn(move || {
520                 *mutex2.lock() = 1;
521             });
522             thread::sleep_ms(50);
523             assert_eq!(*guard, 0);
524             drop(guard);
525             t1.join().unwrap();
526             assert_eq!(*mutex.lock(), 1);
527         }
528     }
529 }
530
531 #[cfg(test)]
532 mod tests {
533     use super::*;
534     use sync::Arc;
535     use thread;
536
537     #[test]
538     fn queue() {
539         let wq = Arc::new(SpinMutex::<WaitVariable<()>>::default());
540         let wq2 = wq.clone();
541
542         let locked = wq.lock();
543
544         let t1 = thread::spawn(move || {
545             assert!(WaitQueue::notify_one(wq2.lock()).is_none())
546         });
547
548         WaitQueue::wait(locked);
549
550         t1.join().unwrap();
551     }
552 }