]> git.lizzy.rs Git - rust.git/blob - library/std/src/thread/local/tests.rs
:arrow_up: rust-analyzer
[rust.git] / library / std / src / thread / local / tests.rs
1 use crate::cell::{Cell, UnsafeCell};
2 use crate::sync::atomic::{AtomicU8, Ordering};
3 use crate::sync::mpsc::{channel, Sender};
4 use crate::thread::{self, LocalKey};
5 use crate::thread_local;
6
/// Test helper that sends a single `()` on the wrapped channel when it is
/// dropped, letting a test observe (through the matching receiver) that a
/// destructor has actually run.
struct Foo(Sender<()>);

impl Drop for Foo {
    fn drop(&mut self) {
        // `unwrap` panics if the receiving end has already been dropped.
        self.0.send(()).unwrap();
    }
}
15
/// Each thread gets its own freshly initialized copy of a thread-local that
/// has no destructor, for both lazy and `const`-initialized declarations.
#[test]
fn smoke_no_dtor() {
    fn run(key: &'static LocalKey<Cell<i32>>) {
        // Mutate this thread's copy.
        key.with(|cell| {
            assert_eq!(cell.get(), 1);
            cell.set(2);
        });

        // A new thread must observe the initial value, not our write.
        thread::spawn(move || key.with(|cell| assert_eq!(cell.get(), 1)))
            .join()
            .unwrap();

        // Our own copy still holds the write.
        key.with(|cell| assert_eq!(cell.get(), 2));
    }

    thread_local!(static FOO: Cell<i32> = Cell::new(1));
    thread_local!(static FOO2: Cell<i32> = const { Cell::new(1) });

    run(&FOO);
    run(&FOO2);
}
40
/// A thread-local can be accessed while its thread is running, and its
/// `try_with` reports an error from inside the value's own destructor.
/// Checked for both lazy and `const`-initialized declarations.
#[test]
fn states() {
    struct Foo(&'static LocalKey<Foo>);

    impl Drop for Foo {
        fn drop(&mut self) {
            // Runs when the owning thread's thread-locals are torn down.
            let during_drop = self.0.try_with(|_| ());
            assert!(during_drop.is_err());
        }
    }

    fn run(key: &'static LocalKey<Foo>) {
        let handle = thread::spawn(move || {
            // While the thread is live, access must succeed.
            let while_live = key.try_with(|_| ());
            assert!(while_live.is_ok());
        });
        handle.join().unwrap();
    }

    thread_local!(static FOO: Foo = Foo(&FOO));
    thread_local!(static FOO2: Foo = const { Foo(&FOO2) });

    run(&FOO);
    run(&FOO2);
}
63
/// A value stored in a thread-local is dropped when its thread exits, for
/// both lazy and `const`-initialized declarations.
#[test]
fn smoke_dtor() {
    fn run(key: &'static LocalKey<UnsafeCell<Option<Foo>>>) {
        let (tx, rx) = channel();
        let worker = thread::spawn(move || {
            // `LocalKey::with` accepts an `FnOnce`, so the sender can be moved
            // straight into the stored `Foo`.
            key.with(|slot| unsafe {
                *slot.get() = Some(Foo(tx));
            });
        });
        // Only returns once `Foo::drop` has sent, i.e. once the worker's
        // thread-local value has been destroyed.
        rx.recv().unwrap();
        worker.join().unwrap();
    }

    thread_local!(static FOO: UnsafeCell<Option<Foo>> = UnsafeCell::new(None));
    thread_local!(static FOO2: UnsafeCell<Option<Foo>> = const { UnsafeCell::new(None) });

    run(&FOO);
    run(&FOO2);
}
83
// Two thread-locals whose values' destructors repopulate each other: an
// explicit drop of an `S1` stores an `S2` into the second key, and dropping
// that `S2` (from the second key's TLS destructor) stores a new `S1` into the
// first key. The drop counter must read 1, 2, 3 along that chain. The third
// drop is accepted both when the second key is no longer accessible and when
// it still is. Checked for both lazy (`K1`/`K2`) and `const`-initialized
// (`K3`/`K4`) keys.
#[test]
fn circular() {
    struct S1(&'static LocalKey<UnsafeCell<Option<S1>>>, &'static LocalKey<UnsafeCell<Option<S2>>>);
    struct S2(&'static LocalKey<UnsafeCell<Option<S1>>>, &'static LocalKey<UnsafeCell<Option<S2>>>);
    thread_local!(static K1: UnsafeCell<Option<S1>> = UnsafeCell::new(None));
    thread_local!(static K2: UnsafeCell<Option<S2>> = UnsafeCell::new(None));
    thread_local!(static K3: UnsafeCell<Option<S1>> = const { UnsafeCell::new(None) });
    thread_local!(static K4: UnsafeCell<Option<S2>> = const { UnsafeCell::new(None) });

    // Number of `S1`/`S2` drops so far in the current run (reset between the
    // two runs below). It never exceeds 3, so a `u8` is enough. This is an
    // atomic rather than a `static mut` because `assert_eq!` borrows its
    // operands, and a reference to a `static mut` trips the `static_mut_refs`
    // lint (deny-by-default in the 2024 edition).
    static HITS: AtomicU8 = AtomicU8::new(0);

    impl Drop for S1 {
        fn drop(&mut self) {
            let hits = HITS.fetch_add(1, Ordering::SeqCst) + 1;
            if self.1.try_with(|_| ()).is_err() {
                // The second key is no longer accessible: this must be the final drop.
                assert_eq!(hits, 3);
            } else if hits == 1 {
                // First drop (explicit, in the spawned thread): hand off to an `S2`.
                unsafe { self.1.with(|s| *s.get() = Some(S2(self.0, self.1))) };
            } else {
                // The second key is still accessible, but this must be the final drop.
                assert_eq!(hits, 3);
            }
        }
    }

    impl Drop for S2 {
        fn drop(&mut self) {
            let hits = HITS.fetch_add(1, Ordering::SeqCst) + 1;
            // The first key must still be usable from the second key's destructor.
            assert!(self.0.try_with(|_| ()).is_ok());
            assert_eq!(hits, 2);
            unsafe { self.0.with(|s| *s.get() = Some(S1(self.0, self.1))) };
        }
    }

    thread::spawn(move || {
        drop(S1(&K1, &K2));
    })
    .join()
    .unwrap();

    // Start the `const`-initialized run from a clean count.
    HITS.store(0, Ordering::SeqCst);

    thread::spawn(move || {
        drop(S1(&K3, &K4));
    })
    .join()
    .unwrap();
}
137
/// A value whose destructor inspects the very thread-local that holds it must
/// find that key inaccessible (`try_with` returns `Err`) while it is being
/// destroyed. Checked for both lazy and `const`-initialized declarations.
#[test]
fn self_referential() {
    struct S1(&'static LocalKey<UnsafeCell<Option<S1>>>);

    impl Drop for S1 {
        fn drop(&mut self) {
            let access = self.0.try_with(|_| ());
            assert!(access.is_err());
        }
    }

    thread_local!(static K1: UnsafeCell<Option<S1>> = UnsafeCell::new(None));
    thread_local!(static K2: UnsafeCell<Option<S1>> = const { UnsafeCell::new(None) });

    // In a fresh thread, store into `key` an `S1` that points back at `key`,
    // then wait for that thread (and its thread-local teardown) to finish.
    fn store_self_in(key: &'static LocalKey<UnsafeCell<Option<S1>>>) {
        thread::spawn(move || {
            key.with(|slot| unsafe { *slot.get() = Some(S1(key)) });
        })
        .join()
        .unwrap();
    }

    store_self_in(&K1);
    store_self_in(&K2);
}
163
// Exercises a TLS destructor that initializes another thread-local: dropping
// the `S1` stored in `K1` stores a `Foo` into `K2`, and dropping that `Foo`
// (when `K2` is destroyed) sends on the channel. If TLS destructors are never
// run, the sender is never dropped, so `rx.recv()` blocks forever and the test
// deadlocks instead of failing.
#[test]
fn dtors_in_dtors_in_dtors() {
    struct S1(Sender<()>);

    impl Drop for S1 {
        fn drop(&mut self) {
            // The result is ignored: if `K2` is no longer accessible, no `Foo`
            // is stored and `rx.recv()` below sees a disconnected channel.
            let _ = K2.try_with(|slot| unsafe { *slot.get() = Some(Foo(self.0.clone())) });
        }
    }

    thread_local!(static K1: UnsafeCell<Option<S1>> = UnsafeCell::new(None));
    thread_local!(static K2: UnsafeCell<Option<Foo>> = UnsafeCell::new(None));

    let (tx, rx) = channel();
    // The handle is never joined; completion is observed through the channel.
    let _t = thread::spawn(move || {
        K1.with(|slot| unsafe { *slot.get() = Some(S1(tx)) });
    });
    rx.recv().unwrap();
}
188
// Same scenario as `dtors_in_dtors_in_dtors`, but with `const`-initialized
// thread-locals: `K1`'s destructor stores a `Foo` into `K2`, whose destruction
// sends on the channel. Without TLS destructors, `rx.recv()` blocks forever.
#[test]
fn dtors_in_dtors_in_dtors_const_init() {
    struct S1(Sender<()>);

    impl Drop for S1 {
        fn drop(&mut self) {
            // The result is ignored: if `K2` is no longer accessible, no `Foo`
            // is stored and `rx.recv()` below sees a disconnected channel.
            let _ = K2.try_with(|slot| unsafe { *slot.get() = Some(Foo(self.0.clone())) });
        }
    }

    thread_local!(static K1: UnsafeCell<Option<S1>> = const { UnsafeCell::new(None) });
    thread_local!(static K2: UnsafeCell<Option<Foo>> = const { UnsafeCell::new(None) });

    let (tx, rx) = channel();
    // The handle is never joined; completion is observed through the channel.
    let _t = thread::spawn(move || {
        K1.with(|slot| unsafe { *slot.get() = Some(S1(tx)) });
    });
    rx.recv().unwrap();
}
211
// Checks that a thread's TLS destructors have finished running before a `join`
// on that thread returns. The test has no false positives (meaning: if the test
// fails, there's actually an ordering problem). It may have false negatives,
// where the test passes but join is not guaranteed to be after the TLS
// destructors. However, false negatives should be exceedingly rare due to
// judicious use of thread::yield_now and running the test several times.
#[test]
fn join_orders_after_tls_destructors() {
    // We emulate a synchronous rendezvous channel using only atomics and
    // thread::yield_now. We can't use std::sync::mpsc as the implementation
    // itself may rely on thread locals.
    //
    // The basic state machine for the rendezvous is:
    //           FRESH -> THREAD1_WAITING -> MAIN_THREAD_RENDEZVOUS
    // where the first transition is done by the “receiving” thread (thread 1,
    // from inside its TLS destructor) and the 2nd transition is done by the
    // “sending” thread (the main thread).
    //
    // We add an additional state `THREAD2_LAUNCHED` between `FRESH` and
    // `THREAD1_WAITING` to block until all threads are actually running.
    //
    // A thread that joins on the “receiving” thread's completion (thread 2)
    // should only ever observe `MAIN_THREAD_RENDEZVOUS`; observing
    // `THREAD2_LAUNCHED` or `THREAD1_WAITING` means join returned before
    // thread 1's TLS destructor finished. After joining, thread 2 always
    // switches to the “poison” state `THREAD2_JOINED`, so any thread still
    // waiting on the rendezvous panics. (This is equivalent to “sending” from
    // an alternate producer thread.)

    // Stored by the main thread at the start of every iteration.
    const FRESH: u8 = 0;
    // Stored by thread 2 as soon as it starts.
    const THREAD2_LAUNCHED: u8 = 1;
    // Stored by thread 1's TLS destructor while it waits for the main thread.
    const THREAD1_WAITING: u8 = 2;
    // Stored by the main thread to release thread 1's TLS destructor.
    const MAIN_THREAD_RENDEZVOUS: u8 = 3;
    // Stored by thread 2 once its `join` on thread 1 has returned.
    const THREAD2_JOINED: u8 = 4;
    static SYNC_STATE: AtomicU8 = AtomicU8::new(FRESH);

    // Repeat to make a false negative less likely.
    for _ in 0..10 {
        SYNC_STATE.store(FRESH, Ordering::SeqCst);

        // Thread 1: owns a thread-local whose destructor performs the
        // receiving side of the rendezvous.
        let jh = thread::Builder::new()
            .name("thread1".into())
            .spawn(move || {
                struct TlDrop;

                impl Drop for TlDrop {
                    // Runs during thread 1's thread-local teardown: announce
                    // that we are waiting, then spin until the main thread
                    // completes the rendezvous, panicking if thread 2 has
                    // already returned from joining us.
                    fn drop(&mut self) {
                        // `swap` yields the previous state, expected to be
                        // `THREAD2_LAUNCHED` (thread 1's body only returns
                        // after observing it).
                        let mut sync_state = SYNC_STATE.swap(THREAD1_WAITING, Ordering::SeqCst);
                        loop {
                            match sync_state {
                                THREAD2_LAUNCHED | THREAD1_WAITING => thread::yield_now(),
                                MAIN_THREAD_RENDEZVOUS => break,
                                THREAD2_JOINED => panic!(
                                    "Thread 1 still running after thread 2 joined on thread 1"
                                ),
                                v => unreachable!("sync state: {}", v),
                            }
                            sync_state = SYNC_STATE.load(Ordering::SeqCst);
                        }
                    }
                }

                thread_local! {
                    static TL_DROP: TlDrop = TlDrop;
                }

                // Access the thread-local so that a `TlDrop` exists in this
                // thread and is dropped when the thread exits.
                TL_DROP.with(|_| {});

                // Don't return (and so don't start thread-local teardown)
                // until thread 2 is running.
                loop {
                    match SYNC_STATE.load(Ordering::SeqCst) {
                        FRESH => thread::yield_now(),
                        THREAD2_LAUNCHED => break,
                        v => unreachable!("sync state: {}", v),
                    }
                }
            })
            .unwrap();

        // Thread 2: joins thread 1, then checks that the rendezvous inside
        // thread 1's TLS destructor had already completed.
        let jh2 = thread::Builder::new()
            .name("thread2".into())
            .spawn(move || {
                assert_eq!(SYNC_STATE.swap(THREAD2_LAUNCHED, Ordering::SeqCst), FRESH);
                jh.join().unwrap();
                match SYNC_STATE.swap(THREAD2_JOINED, Ordering::SeqCst) {
                    MAIN_THREAD_RENDEZVOUS => return,
                    THREAD2_LAUNCHED | THREAD1_WAITING => {
                        panic!("Thread 2 running after thread 1 join before main thread rendezvous")
                    }
                    v => unreachable!("sync state: {:?}", v),
                }
            })
            .unwrap();

        // Main thread: the sending side. Complete the rendezvous once thread
        // 1's destructor has announced it is waiting.
        loop {
            match SYNC_STATE.compare_exchange(
                THREAD1_WAITING,
                MAIN_THREAD_RENDEZVOUS,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => break,
                Err(FRESH) => thread::yield_now(),
                Err(THREAD2_LAUNCHED) => thread::yield_now(),
                Err(THREAD2_JOINED) => {
                    panic!("Main thread rendezvous after thread 2 joined thread 1")
                }
                v => unreachable!("sync state: {:?}", v),
            }
        }
        // Surfaces any panic from thread 2 (which in turn unwraps thread 1's result).
        jh2.join().unwrap();
    }
}