]> git.lizzy.rs Git - rust.git/blob - library/std/src/sync/barrier.rs
Rollup merge of #84320 - jsha:details-implementors, r=Manishearth,Nemo157,GuillaumeGomez
[rust.git] / library / std / src / sync / barrier.rs
1 #[cfg(test)]
2 mod tests;
3
4 use crate::fmt;
5 use crate::sync::{Condvar, Mutex};
6
7 /// A barrier enables multiple threads to synchronize the beginning
8 /// of some computation.
9 ///
10 /// # Examples
11 ///
12 /// ```
13 /// use std::sync::{Arc, Barrier};
14 /// use std::thread;
15 ///
16 /// let mut handles = Vec::with_capacity(10);
17 /// let barrier = Arc::new(Barrier::new(10));
18 /// for _ in 0..10 {
19 ///     let c = Arc::clone(&barrier);
20 ///     // The same messages will be printed together.
21 ///     // You will NOT see any interleaving.
22 ///     handles.push(thread::spawn(move|| {
23 ///         println!("before wait");
24 ///         c.wait();
25 ///         println!("after wait");
26 ///     }));
27 /// }
28 /// // Wait for other threads to finish.
29 /// for handle in handles {
30 ///     handle.join().unwrap();
31 /// }
32 /// ```
33 #[stable(feature = "rust1", since = "1.0.0")]
34 pub struct Barrier {
35     lock: Mutex<BarrierState>,
36     cvar: Condvar,
37     num_threads: usize,
38 }
39
40 // The inner state of a double barrier
41 struct BarrierState {
42     count: usize,
43     generation_id: usize,
44 }
45
46 /// A `BarrierWaitResult` is returned by [`Barrier::wait()`] when all threads
47 /// in the [`Barrier`] have rendezvoused.
48 ///
49 /// # Examples
50 ///
51 /// ```
52 /// use std::sync::Barrier;
53 ///
54 /// let barrier = Barrier::new(1);
55 /// let barrier_wait_result = barrier.wait();
56 /// ```
57 #[stable(feature = "rust1", since = "1.0.0")]
58 pub struct BarrierWaitResult(bool);
59
60 #[stable(feature = "std_debug", since = "1.16.0")]
61 impl fmt::Debug for Barrier {
62     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63         f.debug_struct("Barrier").finish_non_exhaustive()
64     }
65 }
66
67 impl Barrier {
68     /// Creates a new barrier that can block a given number of threads.
69     ///
70     /// A barrier will block `n`-1 threads which call [`wait()`] and then wake
71     /// up all threads at once when the `n`th thread calls [`wait()`].
72     ///
73     /// [`wait()`]: Barrier::wait
74     ///
75     /// # Examples
76     ///
77     /// ```
78     /// use std::sync::Barrier;
79     ///
80     /// let barrier = Barrier::new(10);
81     /// ```
82     #[stable(feature = "rust1", since = "1.0.0")]
83     pub fn new(n: usize) -> Barrier {
84         Barrier {
85             lock: Mutex::new(BarrierState { count: 0, generation_id: 0 }),
86             cvar: Condvar::new(),
87             num_threads: n,
88         }
89     }
90
91     /// Blocks the current thread until all threads have rendezvoused here.
92     ///
93     /// Barriers are re-usable after all threads have rendezvoused once, and can
94     /// be used continuously.
95     ///
96     /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that
97     /// returns `true` from [`BarrierWaitResult::is_leader()`] when returning
98     /// from this function, and all other threads will receive a result that
99     /// will return `false` from [`BarrierWaitResult::is_leader()`].
100     ///
101     /// # Examples
102     ///
103     /// ```
104     /// use std::sync::{Arc, Barrier};
105     /// use std::thread;
106     ///
107     /// let mut handles = Vec::with_capacity(10);
108     /// let barrier = Arc::new(Barrier::new(10));
109     /// for _ in 0..10 {
110     ///     let c = Arc::clone(&barrier);
111     ///     // The same messages will be printed together.
112     ///     // You will NOT see any interleaving.
113     ///     handles.push(thread::spawn(move|| {
114     ///         println!("before wait");
115     ///         c.wait();
116     ///         println!("after wait");
117     ///     }));
118     /// }
119     /// // Wait for other threads to finish.
120     /// for handle in handles {
121     ///     handle.join().unwrap();
122     /// }
123     /// ```
124     #[stable(feature = "rust1", since = "1.0.0")]
125     pub fn wait(&self) -> BarrierWaitResult {
126         let mut lock = self.lock.lock().unwrap();
127         let local_gen = lock.generation_id;
128         lock.count += 1;
129         if lock.count < self.num_threads {
130             // We need a while loop to guard against spurious wakeups.
131             // https://en.wikipedia.org/wiki/Spurious_wakeup
132             while local_gen == lock.generation_id && lock.count < self.num_threads {
133                 lock = self.cvar.wait(lock).unwrap();
134             }
135             BarrierWaitResult(false)
136         } else {
137             lock.count = 0;
138             lock.generation_id = lock.generation_id.wrapping_add(1);
139             self.cvar.notify_all();
140             BarrierWaitResult(true)
141         }
142     }
143 }
144
145 #[stable(feature = "std_debug", since = "1.16.0")]
146 impl fmt::Debug for BarrierWaitResult {
147     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148         f.debug_struct("BarrierWaitResult").field("is_leader", &self.is_leader()).finish()
149     }
150 }
151
152 impl BarrierWaitResult {
153     /// Returns `true` if this thread is the "leader thread" for the call to
154     /// [`Barrier::wait()`].
155     ///
156     /// Only one thread will have `true` returned from their result, all other
157     /// threads will have `false` returned.
158     ///
159     /// # Examples
160     ///
161     /// ```
162     /// use std::sync::Barrier;
163     ///
164     /// let barrier = Barrier::new(1);
165     /// let barrier_wait_result = barrier.wait();
166     /// println!("{:?}", barrier_wait_result.is_leader());
167     /// ```
168     #[stable(feature = "rust1", since = "1.0.0")]
169     pub fn is_leader(&self) -> bool {
170         self.0
171     }
172 }