use crate::cell::{Cell, UnsafeCell};
use crate::sync::atomic::{AtomicU8, Ordering};
use crate::sync::{Arc, Condvar, Mutex};
use crate::thread::{self, LocalKey};
use crate::thread_local;

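// A simple one-shot signal built from a `Mutex<bool>` and a `Condvar`, used by the
// tests below to block until a spawned thread (or one of its TLS destructors) has
// made progress.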
#[derive(Clone, Default)]
struct Signal(Arc<(Mutex<bool>, Condvar)>);

impl Signal {
    fn notify(&self) {
        let (set, cvar) = &*self.0;
        *set.lock().unwrap() = true;
        cvar.notify_one();
    }

    fn wait(&self) {
        let (set, cvar) = &*self.0;
        let mut set = set.lock().unwrap();
        while !*set {
            set = cvar.wait(set).unwrap();
        }
    }
}

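// Fires its `Signal` when dropped, letting a test observe that a value stored in a
// thread local was actually destroyed.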
struct NotifyOnDrop(Signal);

impl Drop for NotifyOnDrop {
    fn drop(&mut self) {
        let NotifyOnDrop(ref f) = *self;
        f.notify();
    }
}

#[test]
fn smoke_no_dtor() {
    thread_local!(static FOO: Cell<i32> = Cell::new(1));
    run(&FOO);
    thread_local!(static FOO2: Cell<i32> = const { Cell::new(1) });
    run(&FOO2);

    fn run(key: &'static LocalKey<Cell<i32>>) {
        key.with(|f| {
            assert_eq!(f.get(), 1);
            f.set(2);
        });
        let t = thread::spawn(move || {
            key.with(|f| {
                assert_eq!(f.get(), 1);
            });
        });
        t.join().unwrap();

        key.with(|f| {
            assert_eq!(f.get(), 2);
        });
    }
}

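// The key is usable while its thread is running (`try_with` succeeds), but must
// report itself as inaccessible (`try_with` errors) from within its own value's
// destructor.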
#[test]
fn states() {
    struct Foo(&'static LocalKey<Foo>);
    impl Drop for Foo {
        fn drop(&mut self) {
            assert!(self.0.try_with(|_| ()).is_err());
        }
    }

    thread_local!(static FOO: Foo = Foo(&FOO));
    run(&FOO);
    thread_local!(static FOO2: Foo = const { Foo(&FOO2) });
    run(&FOO2);

    fn run(foo: &'static LocalKey<Foo>) {
        thread::spawn(move || {
            assert!(foo.try_with(|_| ()).is_ok());
        })
        .join()
        .unwrap();
    }
}

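// Store a value with a destructor in the key from a spawned thread and wait for the
// drop notification, so we know the destructor actually ran when the thread exited.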
#[test]
fn smoke_dtor() {
    thread_local!(static FOO: UnsafeCell<Option<NotifyOnDrop>> = UnsafeCell::new(None));
    run(&FOO);
    thread_local!(static FOO2: UnsafeCell<Option<NotifyOnDrop>> = const { UnsafeCell::new(None) });
    run(&FOO2);

    fn run(key: &'static LocalKey<UnsafeCell<Option<NotifyOnDrop>>>) {
        let signal = Signal::default();
        let signal2 = signal.clone();
        let t = thread::spawn(move || unsafe {
            let mut signal = Some(signal2);
            key.with(|f| {
                *f.get() = Some(NotifyOnDrop(signal.take().unwrap()));
            });
        });
        signal.wait();
        t.join().unwrap();
    }
}

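// Destructor chains across two keys: dropping an `S1` stores an `S2` into the second
// key, whose destructor then stores an `S1` into the first key. `HITS` counts the
// destructor runs so the test can check that the whole chain executes, for both
// lazily and const-initialized keys.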
#[test]
fn circular() {
    struct S1(&'static LocalKey<UnsafeCell<Option<S1>>>, &'static LocalKey<UnsafeCell<Option<S2>>>);
    struct S2(&'static LocalKey<UnsafeCell<Option<S1>>>, &'static LocalKey<UnsafeCell<Option<S2>>>);
    thread_local!(static K1: UnsafeCell<Option<S1>> = UnsafeCell::new(None));
    thread_local!(static K2: UnsafeCell<Option<S2>> = UnsafeCell::new(None));
    thread_local!(static K3: UnsafeCell<Option<S1>> = const { UnsafeCell::new(None) });
    thread_local!(static K4: UnsafeCell<Option<S2>> = const { UnsafeCell::new(None) });
    static mut HITS: usize = 0;

    impl Drop for S1 {
        fn drop(&mut self) {
            unsafe {
                HITS += 1;
                if self.1.try_with(|_| ()).is_err() {
                    assert_eq!(HITS, 3);
                } else {
                    if HITS == 1 {
                        self.1.with(|s| *s.get() = Some(S2(self.0, self.1)));
                    } else {
                        assert_eq!(HITS, 3);
                    }
                }
            }
        }
    }
    impl Drop for S2 {
        fn drop(&mut self) {
            unsafe {
                HITS += 1;
                assert!(self.0.try_with(|_| ()).is_ok());
                assert_eq!(HITS, 2);
                self.0.with(|s| *s.get() = Some(S1(self.0, self.1)));
            }
        }
    }

    thread::spawn(move || {
        drop(S1(&K1, &K2));
    })
    .join()
    .unwrap();

    unsafe {
        HITS = 0;
    }

    thread::spawn(move || {
        drop(S1(&K3, &K4));
    })
    .join()
    .unwrap();
}

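// A value holding a reference to its own key: by the time its destructor runs, the
// key must already report itself as destroyed.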
#[test]
fn self_referential() {
    struct S1(&'static LocalKey<UnsafeCell<Option<S1>>>);

    thread_local!(static K1: UnsafeCell<Option<S1>> = UnsafeCell::new(None));
    thread_local!(static K2: UnsafeCell<Option<S1>> = const { UnsafeCell::new(None) });

    impl Drop for S1 {
        fn drop(&mut self) {
            assert!(self.0.try_with(|_| ()).is_err());
        }
    }

    thread::spawn(move || unsafe {
        K1.with(|s| *s.get() = Some(S1(&K1)));
    })
    .join()
    .unwrap();

    thread::spawn(move || unsafe {
        K2.with(|s| *s.get() = Some(S1(&K2)));
    })
    .join()
    .unwrap();
}

// Note that this test will deadlock if TLS destructors aren't run (the test
// relies on the destructor actually running in order to pass).
#[test]
fn dtors_in_dtors_in_dtors() {
    struct S1(Signal);
    thread_local!(static K1: UnsafeCell<Option<S1>> = UnsafeCell::new(None));
    thread_local!(static K2: UnsafeCell<Option<NotifyOnDrop>> = UnsafeCell::new(None));

    impl Drop for S1 {
        fn drop(&mut self) {
            let S1(ref signal) = *self;
            unsafe {
                let _ = K2.try_with(|s| *s.get() = Some(NotifyOnDrop(signal.clone())));
            }
        }
    }

    let signal = Signal::default();
    let signal2 = signal.clone();
    let _t = thread::spawn(move || unsafe {
        let mut signal = Some(signal2);
        K1.with(|s| *s.get() = Some(S1(signal.take().unwrap())));
    });
    signal.wait();
}

#[test]
fn dtors_in_dtors_in_dtors_const_init() {
    struct S1(Signal);
    thread_local!(static K1: UnsafeCell<Option<S1>> = const { UnsafeCell::new(None) });
    thread_local!(static K2: UnsafeCell<Option<NotifyOnDrop>> = const { UnsafeCell::new(None) });

    impl Drop for S1 {
        fn drop(&mut self) {
            let S1(ref signal) = *self;
            unsafe {
                let _ = K2.try_with(|s| *s.get() = Some(NotifyOnDrop(signal.clone())));
            }
        }
    }

    let signal = Signal::default();
    let signal2 = signal.clone();
    let _t = thread::spawn(move || unsafe {
        let mut signal = Some(signal2);
        K1.with(|s| *s.get() = Some(S1(signal.take().unwrap())));
    });
    signal.wait();
}

// This test checks that TLS destructors have run before the thread joins. The
// test has no false positives (meaning: if the test fails, there's actually
// an ordering problem). It may have false negatives, where the test passes but
// join is not guaranteed to be after the TLS destructors. However, false
// negatives should be exceedingly rare due to judicious use of
// thread::yield_now and running the test several times.
#[test]
fn join_orders_after_tls_destructors() {
    // We emulate a synchronous MPSC rendezvous channel using only atomics and
    // thread::yield_now. We can't use std::mpsc as the implementation itself
    // may rely on thread locals.
    //
    // The basic state machine for an SPSC rendezvous channel is:
    //           FRESH -> THREAD1_WAITING -> MAIN_THREAD_RENDEZVOUS
    // where the first transition is done by the “receiving” thread and the 2nd
    // transition is done by the “sending” thread.
    //
    // We add an additional state `THREAD2_LAUNCHED` between `FRESH` and
    // `THREAD1_WAITING` to block until all threads are actually running.
    //
    // A thread that joins on the “receiving” thread completion should never
    // observe the channel in the `THREAD1_WAITING` state. If this does occur,
    // we switch to the “poison” state `THREAD2_JOINED` and panic all around.
    // (This is equivalent to “sending” from an alternate producer thread.)
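    //
    // Putting the pieces together, a successful run goes through:
    //           FRESH -> THREAD2_LAUNCHED -> THREAD1_WAITING -> MAIN_THREAD_RENDEZVOUS
    // with `THREAD2_JOINED` acting as the poison state any participant may
    // observe if a join completes too early.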
    const FRESH: u8 = 0;
    const THREAD2_LAUNCHED: u8 = 1;
    const THREAD1_WAITING: u8 = 2;
    const MAIN_THREAD_RENDEZVOUS: u8 = 3;
    const THREAD2_JOINED: u8 = 4;
    static SYNC_STATE: AtomicU8 = AtomicU8::new(FRESH);

    for _ in 0..10 {
        SYNC_STATE.store(FRESH, Ordering::SeqCst);

        let jh = thread::Builder::new()
            .name("thread1".into())
            .spawn(move || {
                struct TlDrop;

                impl Drop for TlDrop {
                    fn drop(&mut self) {
                        let mut sync_state = SYNC_STATE.swap(THREAD1_WAITING, Ordering::SeqCst);
                        loop {
                            match sync_state {
                                THREAD2_LAUNCHED | THREAD1_WAITING => thread::yield_now(),
                                MAIN_THREAD_RENDEZVOUS => break,
                                THREAD2_JOINED => panic!(
                                    "Thread 1 still running after thread 2 joined on thread 1"
                                ),
                                v => unreachable!("sync state: {}", v),
                            }
                            sync_state = SYNC_STATE.load(Ordering::SeqCst);
                        }
                    }
                }

                thread_local! {
                    static TL_DROP: TlDrop = TlDrop;
                }

                TL_DROP.with(|_| {});

                loop {
                    match SYNC_STATE.load(Ordering::SeqCst) {
                        FRESH => thread::yield_now(),
                        THREAD2_LAUNCHED => break,
                        v => unreachable!("sync state: {}", v),
                    }
                }
            })
            .unwrap();

        let jh2 = thread::Builder::new()
            .name("thread2".into())
            .spawn(move || {
                assert_eq!(SYNC_STATE.swap(THREAD2_LAUNCHED, Ordering::SeqCst), FRESH);
                jh.join().unwrap();
                match SYNC_STATE.swap(THREAD2_JOINED, Ordering::SeqCst) {
                    MAIN_THREAD_RENDEZVOUS => return,
                    THREAD2_LAUNCHED | THREAD1_WAITING => {
                        panic!("Thread 2 running after thread 1 join before main thread rendezvous")
                    }
                    v => unreachable!("sync state: {:?}", v),
                }
            })
            .unwrap();

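        // Main thread: complete the rendezvous once thread1 is waiting in its TLS
        // destructor, yielding until that state is observed. Seeing THREAD2_JOINED
        // here means the join finished before the destructor, which is exactly the
        // ordering bug this test guards against.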
        loop {
            match SYNC_STATE.compare_exchange(
                THREAD1_WAITING,
                MAIN_THREAD_RENDEZVOUS,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => break,
                Err(FRESH) => thread::yield_now(),
                Err(THREAD2_LAUNCHED) => thread::yield_now(),
                Err(THREAD2_JOINED) => {
                    panic!("Main thread rendezvous after thread 2 joined thread 1")
                }
                v => unreachable!("sync state: {:?}", v),
            }
        }
        jh2.join().unwrap();
    }
}