]> git.lizzy.rs Git - rust.git/blob - library/std/src/sync/mpsc/shared.rs
Rollup merge of #96609 - ibraheemdev:arc-downcast-unchecked, r=m-ou-se
[rust.git] / library / std / src / sync / mpsc / shared.rs
1 /// Shared channels.
2 ///
3 /// This is the flavor of channels which are not necessarily optimized for any
4 /// particular use case, but are the most general in how they are used. Shared
5 /// channels are cloneable allowing for multiple senders.
6 ///
7 /// High level implementation details can be found in the comment of the parent
8 /// module. You'll also note that the implementation of the shared and stream
9 /// channels are quite similar, and this is no coincidence!
10 pub use self::Failure::*;
11 use self::StartResult::*;
12
13 use core::cmp;
14 use core::intrinsics::abort;
15
16 use crate::cell::UnsafeCell;
17 use crate::ptr;
18 use crate::sync::atomic::{AtomicBool, AtomicIsize, AtomicPtr, AtomicUsize, Ordering};
19 use crate::sync::mpsc::blocking::{self, SignalToken};
20 use crate::sync::mpsc::mpsc_queue as mpsc;
21 use crate::sync::{Mutex, MutexGuard};
22 use crate::thread;
23 use crate::time::Instant;
24
25 const DISCONNECTED: isize = isize::MIN;
26 const FUDGE: isize = 1024;
27 const MAX_REFCOUNT: usize = (isize::MAX) as usize;
28 #[cfg(test)]
29 const MAX_STEALS: isize = 5;
30 #[cfg(not(test))]
31 const MAX_STEALS: isize = 1 << 20;
32 const EMPTY: *mut u8 = ptr::null_mut(); // initial state: no data, no blocked receiver
33
34 pub struct Packet<T> {
35     queue: mpsc::Queue<T>,
36     cnt: AtomicIsize,          // How many items are on this channel
37     steals: UnsafeCell<isize>, // How many times has a port received without blocking?
38     to_wake: AtomicPtr<u8>,    // SignalToken for wake up
39
40     // The number of channels which are currently using this packet.
41     channels: AtomicUsize,
42
43     // See the discussion in Port::drop and the channel send methods for what
44     // these are used for
45     port_dropped: AtomicBool,
46     sender_drain: AtomicIsize,
47
48     // this lock protects various portions of this implementation during
49     // select()
50     select_lock: Mutex<()>,
51 }
52
53 pub enum Failure {
54     Empty,
55     Disconnected,
56 }
57
58 #[derive(PartialEq, Eq)]
59 enum StartResult {
60     Installed,
61     Abort,
62 }
63
64 impl<T> Packet<T> {
65     // Creation of a packet *must* be followed by a call to postinit_lock
66     // and later by inherit_blocker
67     pub fn new() -> Packet<T> {
68         Packet {
69             queue: mpsc::Queue::new(),
70             cnt: AtomicIsize::new(0),
71             steals: UnsafeCell::new(0),
72             to_wake: AtomicPtr::new(EMPTY),
73             channels: AtomicUsize::new(2),
74             port_dropped: AtomicBool::new(false),
75             sender_drain: AtomicIsize::new(0),
76             select_lock: Mutex::new(()),
77         }
78     }
79
80     // This function should be used after newly created Packet
81     // was wrapped with an Arc
82     // In other case mutex data will be duplicated while cloning
83     // and that could cause problems on platforms where it is
84     // represented by opaque data structure
85     pub fn postinit_lock(&self) -> MutexGuard<'_, ()> {
86         self.select_lock.lock().unwrap()
87     }
88
89     // This function is used at the creation of a shared packet to inherit a
90     // previously blocked thread. This is done to prevent spurious wakeups of
91     // threads in select().
92     //
93     // This can only be called at channel-creation time
94     pub fn inherit_blocker(&self, token: Option<SignalToken>, guard: MutexGuard<'_, ()>) {
95         if let Some(token) = token {
96             assert_eq!(self.cnt.load(Ordering::SeqCst), 0);
97             assert_eq!(self.to_wake.load(Ordering::SeqCst), EMPTY);
98             self.to_wake.store(unsafe { token.to_raw() }, Ordering::SeqCst);
99             self.cnt.store(-1, Ordering::SeqCst);
100
101             // This store is a little sketchy. What's happening here is that
102             // we're transferring a blocker from a oneshot or stream channel to
103             // this shared channel. In doing so, we never spuriously wake them
104             // up and rather only wake them up at the appropriate time. This
105             // implementation of shared channels assumes that any blocking
106             // recv() will undo the increment of steals performed in try_recv()
107             // once the recv is complete.  This thread that we're inheriting,
108             // however, is not in the middle of recv. Hence, the first time we
109             // wake them up, they're going to wake up from their old port, move
110             // on to the upgraded port, and then call the block recv() function.
111             //
112             // When calling this function, they'll find there's data immediately
113             // available, counting it as a steal. This in fact wasn't a steal
114             // because we appropriately blocked them waiting for data.
115             //
116             // To offset this bad increment, we initially set the steal count to
117             // -1. You'll find some special code in abort_selection() as well to
118             // ensure that this -1 steal count doesn't escape too far.
119             unsafe {
120                 *self.steals.get() = -1;
121             }
122         }
123
124         // When the shared packet is constructed, we grabbed this lock. The
125         // purpose of this lock is to ensure that abort_selection() doesn't
126         // interfere with this method. After we unlock this lock, we're
127         // signifying that we're done modifying self.cnt and self.to_wake and
128         // the port is ready for the world to continue using it.
129         drop(guard);
130     }
131
132     pub fn send(&self, t: T) -> Result<(), T> {
133         // See Port::drop for what's going on
134         if self.port_dropped.load(Ordering::SeqCst) {
135             return Err(t);
136         }
137
138         // Note that the multiple sender case is a little trickier
139         // semantically than the single sender case. The logic for
140         // incrementing is "add and if disconnected store disconnected".
141         // This could end up leading some senders to believe that there
142         // wasn't a disconnect if in fact there was a disconnect. This means
143         // that while one thread is attempting to re-store the disconnected
144         // states, other threads could walk through merrily incrementing
145         // this very-negative disconnected count. To prevent senders from
146         // spuriously attempting to send when the channels is actually
147         // disconnected, the count has a ranged check here.
148         //
149         // This is also done for another reason. Remember that the return
150         // value of this function is:
151         //
152         //  `true` == the data *may* be received, this essentially has no
153         //            meaning
154         //  `false` == the data will *never* be received, this has a lot of
155         //             meaning
156         //
157         // In the SPSC case, we have a check of 'queue.is_empty()' to see
158         // whether the data was actually received, but this same condition
159         // means nothing in a multi-producer context. As a result, this
160         // preflight check serves as the definitive "this will never be
161         // received". Once we get beyond this check, we have permanently
162         // entered the realm of "this may be received"
163         if self.cnt.load(Ordering::SeqCst) < DISCONNECTED + FUDGE {
164             return Err(t);
165         }
166
167         self.queue.push(t);
168         match self.cnt.fetch_add(1, Ordering::SeqCst) {
169             -1 => {
170                 self.take_to_wake().signal();
171             }
172
173             // In this case, we have possibly failed to send our data, and
174             // we need to consider re-popping the data in order to fully
175             // destroy it. We must arbitrate among the multiple senders,
176             // however, because the queues that we're using are
177             // single-consumer queues. In order to do this, all exiting
178             // pushers will use an atomic count in order to count those
179             // flowing through. Pushers who see 0 are required to drain as
180             // much as possible, and then can only exit when they are the
181             // only pusher (otherwise they must try again).
182             n if n < DISCONNECTED + FUDGE => {
183                 // see the comment in 'try' for a shared channel for why this
184                 // window of "not disconnected" is ok.
185                 self.cnt.store(DISCONNECTED, Ordering::SeqCst);
186
187                 if self.sender_drain.fetch_add(1, Ordering::SeqCst) == 0 {
188                     loop {
189                         // drain the queue, for info on the thread yield see the
190                         // discussion in try_recv
191                         loop {
192                             match self.queue.pop() {
193                                 mpsc::Data(..) => {}
194                                 mpsc::Empty => break,
195                                 mpsc::Inconsistent => thread::yield_now(),
196                             }
197                         }
198                         // maybe we're done, if we're not the last ones
199                         // here, then we need to go try again.
200                         if self.sender_drain.fetch_sub(1, Ordering::SeqCst) == 1 {
201                             break;
202                         }
203                     }
204
205                     // At this point, there may still be data on the queue,
206                     // but only if the count hasn't been incremented and
207                     // some other sender hasn't finished pushing data just
208                     // yet. That sender in question will drain its own data.
209                 }
210             }
211
212             // Can't make any assumptions about this case like in the SPSC case.
213             _ => {}
214         }
215
216         Ok(())
217     }
218
219     pub fn recv(&self, deadline: Option<Instant>) -> Result<T, Failure> {
220         // This code is essentially the exact same as that found in the stream
221         // case (see stream.rs)
222         match self.try_recv() {
223             Err(Empty) => {}
224             data => return data,
225         }
226
227         let (wait_token, signal_token) = blocking::tokens();
228         if self.decrement(signal_token) == Installed {
229             if let Some(deadline) = deadline {
230                 let timed_out = !wait_token.wait_max_until(deadline);
231                 if timed_out {
232                     self.abort_selection(false);
233                 }
234             } else {
235                 wait_token.wait();
236             }
237         }
238
239         match self.try_recv() {
240             data @ Ok(..) => unsafe {
241                 *self.steals.get() -= 1;
242                 data
243             },
244             data => data,
245         }
246     }
247
248     // Essentially the exact same thing as the stream decrement function.
249     // Returns true if blocking should proceed.
250     fn decrement(&self, token: SignalToken) -> StartResult {
251         unsafe {
252             assert_eq!(
253                 self.to_wake.load(Ordering::SeqCst),
254                 EMPTY,
255                 "This is a known bug in the Rust standard library. See https://github.com/rust-lang/rust/issues/39364"
256             );
257             let ptr = token.to_raw();
258             self.to_wake.store(ptr, Ordering::SeqCst);
259
260             let steals = ptr::replace(self.steals.get(), 0);
261
262             match self.cnt.fetch_sub(1 + steals, Ordering::SeqCst) {
263                 DISCONNECTED => {
264                     self.cnt.store(DISCONNECTED, Ordering::SeqCst);
265                 }
266                 // If we factor in our steals and notice that the channel has no
267                 // data, we successfully sleep
268                 n => {
269                     assert!(n >= 0);
270                     if n - steals <= 0 {
271                         return Installed;
272                     }
273                 }
274             }
275
276             self.to_wake.store(EMPTY, Ordering::SeqCst);
277             drop(SignalToken::from_raw(ptr));
278             Abort
279         }
280     }
281
282     pub fn try_recv(&self) -> Result<T, Failure> {
283         let ret = match self.queue.pop() {
284             mpsc::Data(t) => Some(t),
285             mpsc::Empty => None,
286
287             // This is a bit of an interesting case. The channel is reported as
288             // having data available, but our pop() has failed due to the queue
289             // being in an inconsistent state.  This means that there is some
290             // pusher somewhere which has yet to complete, but we are guaranteed
291             // that a pop will eventually succeed. In this case, we spin in a
292             // yield loop because the remote sender should finish their enqueue
293             // operation "very quickly".
294             //
295             // Avoiding this yield loop would require a different queue
296             // abstraction which provides the guarantee that after M pushes have
297             // succeeded, at least M pops will succeed. The current queues
298             // guarantee that if there are N active pushes, you can pop N times
299             // once all N have finished.
300             mpsc::Inconsistent => {
301                 let data;
302                 loop {
303                     thread::yield_now();
304                     match self.queue.pop() {
305                         mpsc::Data(t) => {
306                             data = t;
307                             break;
308                         }
309                         mpsc::Empty => panic!("inconsistent => empty"),
310                         mpsc::Inconsistent => {}
311                     }
312                 }
313                 Some(data)
314             }
315         };
316         match ret {
317             // See the discussion in the stream implementation for why we
318             // might decrement steals.
319             Some(data) => unsafe {
320                 if *self.steals.get() > MAX_STEALS {
321                     match self.cnt.swap(0, Ordering::SeqCst) {
322                         DISCONNECTED => {
323                             self.cnt.store(DISCONNECTED, Ordering::SeqCst);
324                         }
325                         n => {
326                             let m = cmp::min(n, *self.steals.get());
327                             *self.steals.get() -= m;
328                             self.bump(n - m);
329                         }
330                     }
331                     assert!(*self.steals.get() >= 0);
332                 }
333                 *self.steals.get() += 1;
334                 Ok(data)
335             },
336
337             // See the discussion in the stream implementation for why we try
338             // again.
339             None => {
340                 match self.cnt.load(Ordering::SeqCst) {
341                     n if n != DISCONNECTED => Err(Empty),
342                     _ => {
343                         match self.queue.pop() {
344                             mpsc::Data(t) => Ok(t),
345                             mpsc::Empty => Err(Disconnected),
346                             // with no senders, an inconsistency is impossible.
347                             mpsc::Inconsistent => unreachable!(),
348                         }
349                     }
350                 }
351             }
352         }
353     }
354
355     // Prepares this shared packet for a channel clone, essentially just bumping
356     // a refcount.
357     pub fn clone_chan(&self) {
358         let old_count = self.channels.fetch_add(1, Ordering::SeqCst);
359
360         // See comments on Arc::clone() on why we do this (for `mem::forget`).
361         if old_count > MAX_REFCOUNT {
362             abort();
363         }
364     }
365
366     // Decrement the reference count on a channel. This is called whenever a
367     // Chan is dropped and may end up waking up a receiver. It's the receiver's
368     // responsibility on the other end to figure out that we've disconnected.
369     pub fn drop_chan(&self) {
370         match self.channels.fetch_sub(1, Ordering::SeqCst) {
371             1 => {}
372             n if n > 1 => return,
373             n => panic!("bad number of channels left {n}"),
374         }
375
376         match self.cnt.swap(DISCONNECTED, Ordering::SeqCst) {
377             -1 => {
378                 self.take_to_wake().signal();
379             }
380             DISCONNECTED => {}
381             n => {
382                 assert!(n >= 0);
383             }
384         }
385     }
386
387     // See the long discussion inside of stream.rs for why the queue is drained,
388     // and why it is done in this fashion.
389     pub fn drop_port(&self) {
390         self.port_dropped.store(true, Ordering::SeqCst);
391         let mut steals = unsafe { *self.steals.get() };
392         while {
393             match self.cnt.compare_exchange(
394                 steals,
395                 DISCONNECTED,
396                 Ordering::SeqCst,
397                 Ordering::SeqCst,
398             ) {
399                 Ok(_) => false,
400                 Err(old) => old != DISCONNECTED,
401             }
402         } {
403             // See the discussion in 'try_recv' for why we yield
404             // control of this thread.
405             loop {
406                 match self.queue.pop() {
407                     mpsc::Data(..) => {
408                         steals += 1;
409                     }
410                     mpsc::Empty | mpsc::Inconsistent => break,
411                 }
412             }
413         }
414     }
415
416     // Consumes ownership of the 'to_wake' field.
417     fn take_to_wake(&self) -> SignalToken {
418         let ptr = self.to_wake.load(Ordering::SeqCst);
419         self.to_wake.store(EMPTY, Ordering::SeqCst);
420         assert!(ptr != EMPTY);
421         unsafe { SignalToken::from_raw(ptr) }
422     }
423
424     ////////////////////////////////////////////////////////////////////////////
425     // select implementation
426     ////////////////////////////////////////////////////////////////////////////
427
428     // increment the count on the channel (used for selection)
429     fn bump(&self, amt: isize) -> isize {
430         match self.cnt.fetch_add(amt, Ordering::SeqCst) {
431             DISCONNECTED => {
432                 self.cnt.store(DISCONNECTED, Ordering::SeqCst);
433                 DISCONNECTED
434             }
435             n => n,
436         }
437     }
438
439     // Cancels a previous thread waiting on this port, returning whether there's
440     // data on the port.
441     //
442     // This is similar to the stream implementation (hence fewer comments), but
443     // uses a different value for the "steals" variable.
444     pub fn abort_selection(&self, _was_upgrade: bool) -> bool {
445         // Before we do anything else, we bounce on this lock. The reason for
446         // doing this is to ensure that any upgrade-in-progress is gone and
447         // done with. Without this bounce, we can race with inherit_blocker
448         // about looking at and dealing with to_wake. Once we have acquired the
449         // lock, we are guaranteed that inherit_blocker is done.
450         {
451             let _guard = self.select_lock.lock().unwrap();
452         }
453
454         // Like the stream implementation, we want to make sure that the count
455         // on the channel goes non-negative. We don't know how negative the
456         // stream currently is, so instead of using a steal value of 1, we load
457         // the channel count and figure out what we should do to make it
458         // positive.
459         let steals = {
460             let cnt = self.cnt.load(Ordering::SeqCst);
461             if cnt < 0 && cnt != DISCONNECTED { -cnt } else { 0 }
462         };
463         let prev = self.bump(steals + 1);
464
465         if prev == DISCONNECTED {
466             assert_eq!(self.to_wake.load(Ordering::SeqCst), EMPTY);
467             true
468         } else {
469             let cur = prev + steals + 1;
470             assert!(cur >= 0);
471             if prev < 0 {
472                 drop(self.take_to_wake());
473             } else {
474                 while self.to_wake.load(Ordering::SeqCst) != EMPTY {
475                     thread::yield_now();
476                 }
477             }
478             unsafe {
479                 // if the number of steals is -1, it was the pre-emptive -1 steal
480                 // count from when we inherited a blocker. This is fine because
481                 // we're just going to overwrite it with a real value.
482                 let old = self.steals.get();
483                 assert!(*old == 0 || *old == -1);
484                 *old = steals;
485                 prev >= 0
486             }
487         }
488     }
489 }
490
491 impl<T> Drop for Packet<T> {
492     fn drop(&mut self) {
493         // Note that this load is not only an assert for correctness about
494         // disconnection, but also a proper fence before the read of
495         // `to_wake`, so this assert cannot be removed with also removing
496         // the `to_wake` assert.
497         assert_eq!(self.cnt.load(Ordering::SeqCst), DISCONNECTED);
498         assert_eq!(self.to_wake.load(Ordering::SeqCst), EMPTY);
499         assert_eq!(self.channels.load(Ordering::SeqCst), 0);
500     }
501 }