]> git.lizzy.rs Git - rust.git/blob - library/std/src/sync/mpsc/shared.rs
619a79f437b8a73805b8cccf45d19685903c91fa
[rust.git] / library / std / src / sync / mpsc / shared.rs
1 /// Shared channels.
2 ///
3 /// This is the flavor of channels which are not necessarily optimized for any
4 /// particular use case, but are the most general in how they are used. Shared
5 /// channels are cloneable allowing for multiple senders.
6 ///
7 /// High level implementation details can be found in the comment of the parent
8 /// module. You'll also note that the implementation of the shared and stream
9 /// channels are quite similar, and this is no coincidence!
10 pub use self::Failure::*;
11 use self::StartResult::*;
12
13 use core::cmp;
14 use core::intrinsics::abort;
15
16 use crate::cell::UnsafeCell;
17 use crate::ptr;
18 use crate::sync::atomic::{AtomicBool, AtomicIsize, AtomicUsize, Ordering};
19 use crate::sync::mpsc::blocking::{self, SignalToken};
20 use crate::sync::mpsc::mpsc_queue as mpsc;
21 use crate::sync::{Mutex, MutexGuard};
22 use crate::thread;
23 use crate::time::Instant;
24
25 const DISCONNECTED: isize = isize::MIN;
26 const FUDGE: isize = 1024;
27 const MAX_REFCOUNT: usize = (isize::MAX) as usize;
28 #[cfg(test)]
29 const MAX_STEALS: isize = 5;
30 #[cfg(not(test))]
31 const MAX_STEALS: isize = 1 << 20;
32
33 pub struct Packet<T> {
34     queue: mpsc::Queue<T>,
35     cnt: AtomicIsize,          // How many items are on this channel
36     steals: UnsafeCell<isize>, // How many times has a port received without blocking?
37     to_wake: AtomicUsize,      // SignalToken for wake up
38
39     // The number of channels which are currently using this packet.
40     channels: AtomicUsize,
41
42     // See the discussion in Port::drop and the channel send methods for what
43     // these are used for
44     port_dropped: AtomicBool,
45     sender_drain: AtomicIsize,
46
47     // this lock protects various portions of this implementation during
48     // select()
49     select_lock: Mutex<()>,
50 }
51
52 pub enum Failure {
53     Empty,
54     Disconnected,
55 }
56
57 #[derive(PartialEq, Eq)]
58 enum StartResult {
59     Installed,
60     Abort,
61 }
62
63 impl<T> Packet<T> {
64     // Creation of a packet *must* be followed by a call to postinit_lock
65     // and later by inherit_blocker
66     pub fn new() -> Packet<T> {
67         Packet {
68             queue: mpsc::Queue::new(),
69             cnt: AtomicIsize::new(0),
70             steals: UnsafeCell::new(0),
71             to_wake: AtomicUsize::new(0),
72             channels: AtomicUsize::new(2),
73             port_dropped: AtomicBool::new(false),
74             sender_drain: AtomicIsize::new(0),
75             select_lock: Mutex::new(()),
76         }
77     }
78
79     // This function should be used after newly created Packet
80     // was wrapped with an Arc
81     // In other case mutex data will be duplicated while cloning
82     // and that could cause problems on platforms where it is
83     // represented by opaque data structure
84     pub fn postinit_lock(&self) -> MutexGuard<'_, ()> {
85         self.select_lock.lock().unwrap()
86     }
87
88     // This function is used at the creation of a shared packet to inherit a
89     // previously blocked thread. This is done to prevent spurious wakeups of
90     // threads in select().
91     //
92     // This can only be called at channel-creation time
93     pub fn inherit_blocker(&self, token: Option<SignalToken>, guard: MutexGuard<'_, ()>) {
94         if let Some(token) = token {
95             assert_eq!(self.cnt.load(Ordering::SeqCst), 0);
96             assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
97             self.to_wake.store(unsafe { token.cast_to_usize() }, Ordering::SeqCst);
98             self.cnt.store(-1, Ordering::SeqCst);
99
100             // This store is a little sketchy. What's happening here is that
101             // we're transferring a blocker from a oneshot or stream channel to
102             // this shared channel. In doing so, we never spuriously wake them
103             // up and rather only wake them up at the appropriate time. This
104             // implementation of shared channels assumes that any blocking
105             // recv() will undo the increment of steals performed in try_recv()
106             // once the recv is complete.  This thread that we're inheriting,
107             // however, is not in the middle of recv. Hence, the first time we
108             // wake them up, they're going to wake up from their old port, move
109             // on to the upgraded port, and then call the block recv() function.
110             //
111             // When calling this function, they'll find there's data immediately
112             // available, counting it as a steal. This in fact wasn't a steal
113             // because we appropriately blocked them waiting for data.
114             //
115             // To offset this bad increment, we initially set the steal count to
116             // -1. You'll find some special code in abort_selection() as well to
117             // ensure that this -1 steal count doesn't escape too far.
118             unsafe {
119                 *self.steals.get() = -1;
120             }
121         }
122
123         // When the shared packet is constructed, we grabbed this lock. The
124         // purpose of this lock is to ensure that abort_selection() doesn't
125         // interfere with this method. After we unlock this lock, we're
126         // signifying that we're done modifying self.cnt and self.to_wake and
127         // the port is ready for the world to continue using it.
128         drop(guard);
129     }
130
131     pub fn send(&self, t: T) -> Result<(), T> {
132         // See Port::drop for what's going on
133         if self.port_dropped.load(Ordering::SeqCst) {
134             return Err(t);
135         }
136
137         // Note that the multiple sender case is a little trickier
138         // semantically than the single sender case. The logic for
139         // incrementing is "add and if disconnected store disconnected".
140         // This could end up leading some senders to believe that there
141         // wasn't a disconnect if in fact there was a disconnect. This means
142         // that while one thread is attempting to re-store the disconnected
143         // states, other threads could walk through merrily incrementing
144         // this very-negative disconnected count. To prevent senders from
145         // spuriously attempting to send when the channels is actually
146         // disconnected, the count has a ranged check here.
147         //
148         // This is also done for another reason. Remember that the return
149         // value of this function is:
150         //
151         //  `true` == the data *may* be received, this essentially has no
152         //            meaning
153         //  `false` == the data will *never* be received, this has a lot of
154         //             meaning
155         //
156         // In the SPSC case, we have a check of 'queue.is_empty()' to see
157         // whether the data was actually received, but this same condition
158         // means nothing in a multi-producer context. As a result, this
159         // preflight check serves as the definitive "this will never be
160         // received". Once we get beyond this check, we have permanently
161         // entered the realm of "this may be received"
162         if self.cnt.load(Ordering::SeqCst) < DISCONNECTED + FUDGE {
163             return Err(t);
164         }
165
166         self.queue.push(t);
167         match self.cnt.fetch_add(1, Ordering::SeqCst) {
168             -1 => {
169                 self.take_to_wake().signal();
170             }
171
172             // In this case, we have possibly failed to send our data, and
173             // we need to consider re-popping the data in order to fully
174             // destroy it. We must arbitrate among the multiple senders,
175             // however, because the queues that we're using are
176             // single-consumer queues. In order to do this, all exiting
177             // pushers will use an atomic count in order to count those
178             // flowing through. Pushers who see 0 are required to drain as
179             // much as possible, and then can only exit when they are the
180             // only pusher (otherwise they must try again).
181             n if n < DISCONNECTED + FUDGE => {
182                 // see the comment in 'try' for a shared channel for why this
183                 // window of "not disconnected" is ok.
184                 self.cnt.store(DISCONNECTED, Ordering::SeqCst);
185
186                 if self.sender_drain.fetch_add(1, Ordering::SeqCst) == 0 {
187                     loop {
188                         // drain the queue, for info on the thread yield see the
189                         // discussion in try_recv
190                         loop {
191                             match self.queue.pop() {
192                                 mpsc::Data(..) => {}
193                                 mpsc::Empty => break,
194                                 mpsc::Inconsistent => thread::yield_now(),
195                             }
196                         }
197                         // maybe we're done, if we're not the last ones
198                         // here, then we need to go try again.
199                         if self.sender_drain.fetch_sub(1, Ordering::SeqCst) == 1 {
200                             break;
201                         }
202                     }
203
204                     // At this point, there may still be data on the queue,
205                     // but only if the count hasn't been incremented and
206                     // some other sender hasn't finished pushing data just
207                     // yet. That sender in question will drain its own data.
208                 }
209             }
210
211             // Can't make any assumptions about this case like in the SPSC case.
212             _ => {}
213         }
214
215         Ok(())
216     }
217
218     pub fn recv(&self, deadline: Option<Instant>) -> Result<T, Failure> {
219         // This code is essentially the exact same as that found in the stream
220         // case (see stream.rs)
221         match self.try_recv() {
222             Err(Empty) => {}
223             data => return data,
224         }
225
226         let (wait_token, signal_token) = blocking::tokens();
227         if self.decrement(signal_token) == Installed {
228             if let Some(deadline) = deadline {
229                 let timed_out = !wait_token.wait_max_until(deadline);
230                 if timed_out {
231                     self.abort_selection(false);
232                 }
233             } else {
234                 wait_token.wait();
235             }
236         }
237
238         match self.try_recv() {
239             data @ Ok(..) => unsafe {
240                 *self.steals.get() -= 1;
241                 data
242             },
243             data => data,
244         }
245     }
246
247     // Essentially the exact same thing as the stream decrement function.
248     // Returns true if blocking should proceed.
249     fn decrement(&self, token: SignalToken) -> StartResult {
250         unsafe {
251             assert_eq!(
252                 self.to_wake.load(Ordering::SeqCst),
253                 0,
254                 "This is a known bug in rust. See https://github.com/rust-lang/rust/issues/39364"
255             );
256             let ptr = token.cast_to_usize();
257             self.to_wake.store(ptr, Ordering::SeqCst);
258
259             let steals = ptr::replace(self.steals.get(), 0);
260
261             match self.cnt.fetch_sub(1 + steals, Ordering::SeqCst) {
262                 DISCONNECTED => {
263                     self.cnt.store(DISCONNECTED, Ordering::SeqCst);
264                 }
265                 // If we factor in our steals and notice that the channel has no
266                 // data, we successfully sleep
267                 n => {
268                     assert!(n >= 0);
269                     if n - steals <= 0 {
270                         return Installed;
271                     }
272                 }
273             }
274
275             self.to_wake.store(0, Ordering::SeqCst);
276             drop(SignalToken::cast_from_usize(ptr));
277             Abort
278         }
279     }
280
281     pub fn try_recv(&self) -> Result<T, Failure> {
282         let ret = match self.queue.pop() {
283             mpsc::Data(t) => Some(t),
284             mpsc::Empty => None,
285
286             // This is a bit of an interesting case. The channel is reported as
287             // having data available, but our pop() has failed due to the queue
288             // being in an inconsistent state.  This means that there is some
289             // pusher somewhere which has yet to complete, but we are guaranteed
290             // that a pop will eventually succeed. In this case, we spin in a
291             // yield loop because the remote sender should finish their enqueue
292             // operation "very quickly".
293             //
294             // Avoiding this yield loop would require a different queue
295             // abstraction which provides the guarantee that after M pushes have
296             // succeeded, at least M pops will succeed. The current queues
297             // guarantee that if there are N active pushes, you can pop N times
298             // once all N have finished.
299             mpsc::Inconsistent => {
300                 let data;
301                 loop {
302                     thread::yield_now();
303                     match self.queue.pop() {
304                         mpsc::Data(t) => {
305                             data = t;
306                             break;
307                         }
308                         mpsc::Empty => panic!("inconsistent => empty"),
309                         mpsc::Inconsistent => {}
310                     }
311                 }
312                 Some(data)
313             }
314         };
315         match ret {
316             // See the discussion in the stream implementation for why we
317             // might decrement steals.
318             Some(data) => unsafe {
319                 if *self.steals.get() > MAX_STEALS {
320                     match self.cnt.swap(0, Ordering::SeqCst) {
321                         DISCONNECTED => {
322                             self.cnt.store(DISCONNECTED, Ordering::SeqCst);
323                         }
324                         n => {
325                             let m = cmp::min(n, *self.steals.get());
326                             *self.steals.get() -= m;
327                             self.bump(n - m);
328                         }
329                     }
330                     assert!(*self.steals.get() >= 0);
331                 }
332                 *self.steals.get() += 1;
333                 Ok(data)
334             },
335
336             // See the discussion in the stream implementation for why we try
337             // again.
338             None => {
339                 match self.cnt.load(Ordering::SeqCst) {
340                     n if n != DISCONNECTED => Err(Empty),
341                     _ => {
342                         match self.queue.pop() {
343                             mpsc::Data(t) => Ok(t),
344                             mpsc::Empty => Err(Disconnected),
345                             // with no senders, an inconsistency is impossible.
346                             mpsc::Inconsistent => unreachable!(),
347                         }
348                     }
349                 }
350             }
351         }
352     }
353
354     // Prepares this shared packet for a channel clone, essentially just bumping
355     // a refcount.
356     pub fn clone_chan(&self) {
357         let old_count = self.channels.fetch_add(1, Ordering::SeqCst);
358
359         // See comments on Arc::clone() on why we do this (for `mem::forget`).
360         if old_count > MAX_REFCOUNT {
361             abort();
362         }
363     }
364
365     // Decrement the reference count on a channel. This is called whenever a
366     // Chan is dropped and may end up waking up a receiver. It's the receiver's
367     // responsibility on the other end to figure out that we've disconnected.
368     pub fn drop_chan(&self) {
369         match self.channels.fetch_sub(1, Ordering::SeqCst) {
370             1 => {}
371             n if n > 1 => return,
372             n => panic!("bad number of channels left {}", n),
373         }
374
375         match self.cnt.swap(DISCONNECTED, Ordering::SeqCst) {
376             -1 => {
377                 self.take_to_wake().signal();
378             }
379             DISCONNECTED => {}
380             n => {
381                 assert!(n >= 0);
382             }
383         }
384     }
385
386     // See the long discussion inside of stream.rs for why the queue is drained,
387     // and why it is done in this fashion.
388     pub fn drop_port(&self) {
389         self.port_dropped.store(true, Ordering::SeqCst);
390         let mut steals = unsafe { *self.steals.get() };
391         while {
392             match self.cnt.compare_exchange(
393                 steals,
394                 DISCONNECTED,
395                 Ordering::SeqCst,
396                 Ordering::SeqCst,
397             ) {
398                 Ok(_) => false,
399                 Err(old) => old != DISCONNECTED,
400             }
401         } {
402             // See the discussion in 'try_recv' for why we yield
403             // control of this thread.
404             loop {
405                 match self.queue.pop() {
406                     mpsc::Data(..) => {
407                         steals += 1;
408                     }
409                     mpsc::Empty | mpsc::Inconsistent => break,
410                 }
411             }
412         }
413     }
414
415     // Consumes ownership of the 'to_wake' field.
416     fn take_to_wake(&self) -> SignalToken {
417         let ptr = self.to_wake.load(Ordering::SeqCst);
418         self.to_wake.store(0, Ordering::SeqCst);
419         assert!(ptr != 0);
420         unsafe { SignalToken::cast_from_usize(ptr) }
421     }
422
423     ////////////////////////////////////////////////////////////////////////////
424     // select implementation
425     ////////////////////////////////////////////////////////////////////////////
426
427     // increment the count on the channel (used for selection)
428     fn bump(&self, amt: isize) -> isize {
429         match self.cnt.fetch_add(amt, Ordering::SeqCst) {
430             DISCONNECTED => {
431                 self.cnt.store(DISCONNECTED, Ordering::SeqCst);
432                 DISCONNECTED
433             }
434             n => n,
435         }
436     }
437
438     // Cancels a previous thread waiting on this port, returning whether there's
439     // data on the port.
440     //
441     // This is similar to the stream implementation (hence fewer comments), but
442     // uses a different value for the "steals" variable.
443     pub fn abort_selection(&self, _was_upgrade: bool) -> bool {
444         // Before we do anything else, we bounce on this lock. The reason for
445         // doing this is to ensure that any upgrade-in-progress is gone and
446         // done with. Without this bounce, we can race with inherit_blocker
447         // about looking at and dealing with to_wake. Once we have acquired the
448         // lock, we are guaranteed that inherit_blocker is done.
449         {
450             let _guard = self.select_lock.lock().unwrap();
451         }
452
453         // Like the stream implementation, we want to make sure that the count
454         // on the channel goes non-negative. We don't know how negative the
455         // stream currently is, so instead of using a steal value of 1, we load
456         // the channel count and figure out what we should do to make it
457         // positive.
458         let steals = {
459             let cnt = self.cnt.load(Ordering::SeqCst);
460             if cnt < 0 && cnt != DISCONNECTED { -cnt } else { 0 }
461         };
462         let prev = self.bump(steals + 1);
463
464         if prev == DISCONNECTED {
465             assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
466             true
467         } else {
468             let cur = prev + steals + 1;
469             assert!(cur >= 0);
470             if prev < 0 {
471                 drop(self.take_to_wake());
472             } else {
473                 while self.to_wake.load(Ordering::SeqCst) != 0 {
474                     thread::yield_now();
475                 }
476             }
477             unsafe {
478                 // if the number of steals is -1, it was the pre-emptive -1 steal
479                 // count from when we inherited a blocker. This is fine because
480                 // we're just going to overwrite it with a real value.
481                 let old = self.steals.get();
482                 assert!(*old == 0 || *old == -1);
483                 *old = steals;
484                 prev >= 0
485             }
486         }
487     }
488 }
489
490 impl<T> Drop for Packet<T> {
491     fn drop(&mut self) {
492         // Note that this load is not only an assert for correctness about
493         // disconnection, but also a proper fence before the read of
494         // `to_wake`, so this assert cannot be removed with also removing
495         // the `to_wake` assert.
496         assert_eq!(self.cnt.load(Ordering::SeqCst), DISCONNECTED);
497         assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
498         assert_eq!(self.channels.load(Ordering::SeqCst), 0);
499     }
500 }