]> git.lizzy.rs Git - rust.git/blob - src/libstd/sync/mpsc/shared.rs
Merge pull request #20510 from tshepang/patch-6
[rust.git] / src / libstd / sync / mpsc / shared.rs
1 // Copyright 2014 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 /// Shared channels
12 ///
13 /// This is the flavor of channels which are not necessarily optimized for any
14 /// particular use case, but are the most general in how they are used. Shared
15 /// channels are cloneable allowing for multiple senders.
16 ///
17 /// High level implementation details can be found in the comment of the parent
18 /// module. You'll also note that the implementation of the shared and stream
19 /// channels are quite similar, and this is no coincidence!
20
21 pub use self::Failure::*;
22
23 use core::prelude::*;
24
25 use core::cmp;
26 use core::int;
27
28 use sync::atomic::{AtomicUint, AtomicInt, AtomicBool, Ordering};
29 use sync::mpsc::blocking::{self, SignalToken};
30 use sync::mpsc::mpsc_queue as mpsc;
31 use sync::mpsc::select::StartResult::*;
32 use sync::mpsc::select::StartResult;
33 use sync::{Mutex, MutexGuard};
34 use thread::Thread;
35
36 const DISCONNECTED: int = int::MIN;
37 const FUDGE: int = 1024;
38 #[cfg(test)]
39 const MAX_STEALS: int = 5;
40 #[cfg(not(test))]
41 const MAX_STEALS: int = 1 << 20;
42
43 pub struct Packet<T> {
44     queue: mpsc::Queue<T>,
45     cnt: AtomicInt, // How many items are on this channel
46     steals: int, // How many times has a port received without blocking?
47     to_wake: AtomicUint, // SignalToken for wake up
48
49     // The number of channels which are currently using this packet.
50     channels: AtomicInt,
51
52     // See the discussion in Port::drop and the channel send methods for what
53     // these are used for
54     port_dropped: AtomicBool,
55     sender_drain: AtomicInt,
56
57     // this lock protects various portions of this implementation during
58     // select()
59     select_lock: Mutex<()>,
60 }
61
62 pub enum Failure {
63     Empty,
64     Disconnected,
65 }
66
67 impl<T: Send> Packet<T> {
68     // Creation of a packet *must* be followed by a call to postinit_lock
69     // and later by inherit_blocker
70     pub fn new() -> Packet<T> {
71         let p = Packet {
72             queue: mpsc::Queue::new(),
73             cnt: AtomicInt::new(0),
74             steals: 0,
75             to_wake: AtomicUint::new(0),
76             channels: AtomicInt::new(2),
77             port_dropped: AtomicBool::new(false),
78             sender_drain: AtomicInt::new(0),
79             select_lock: Mutex::new(()),
80         };
81         return p;
82     }
83
84     // This function should be used after newly created Packet
85     // was wrapped with an Arc
86     // In other case mutex data will be duplicated while cloning
87     // and that could cause problems on platforms where it is
88     // represented by opaque data structure
89     pub fn postinit_lock(&self) -> MutexGuard<()> {
90         self.select_lock.lock().unwrap()
91     }
92
93     // This function is used at the creation of a shared packet to inherit a
94     // previously blocked task. This is done to prevent spurious wakeups of
95     // tasks in select().
96     //
97     // This can only be called at channel-creation time
98     pub fn inherit_blocker(&mut self,
99                            token: Option<SignalToken>,
100                            guard: MutexGuard<()>) {
101         token.map(|token| {
102             assert_eq!(self.cnt.load(Ordering::SeqCst), 0);
103             assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
104             self.to_wake.store(unsafe { token.cast_to_uint() }, Ordering::SeqCst);
105             self.cnt.store(-1, Ordering::SeqCst);
106
107             // This store is a little sketchy. What's happening here is that
108             // we're transferring a blocker from a oneshot or stream channel to
109             // this shared channel. In doing so, we never spuriously wake them
110             // up and rather only wake them up at the appropriate time. This
111             // implementation of shared channels assumes that any blocking
112             // recv() will undo the increment of steals performed in try_recv()
113             // once the recv is complete.  This thread that we're inheriting,
114             // however, is not in the middle of recv. Hence, the first time we
115             // wake them up, they're going to wake up from their old port, move
116             // on to the upgraded port, and then call the block recv() function.
117             //
118             // When calling this function, they'll find there's data immediately
119             // available, counting it as a steal. This in fact wasn't a steal
120             // because we appropriately blocked them waiting for data.
121             //
122             // To offset this bad increment, we initially set the steal count to
123             // -1. You'll find some special code in abort_selection() as well to
124             // ensure that this -1 steal count doesn't escape too far.
125             self.steals = -1;
126         });
127
128         // When the shared packet is constructed, we grabbed this lock. The
129         // purpose of this lock is to ensure that abort_selection() doesn't
130         // interfere with this method. After we unlock this lock, we're
131         // signifying that we're done modifying self.cnt and self.to_wake and
132         // the port is ready for the world to continue using it.
133         drop(guard);
134     }
135
136     pub fn send(&mut self, t: T) -> Result<(), T> {
137         // See Port::drop for what's going on
138         if self.port_dropped.load(Ordering::SeqCst) { return Err(t) }
139
140         // Note that the multiple sender case is a little trickier
141         // semantically than the single sender case. The logic for
142         // incrementing is "add and if disconnected store disconnected".
143         // This could end up leading some senders to believe that there
144         // wasn't a disconnect if in fact there was a disconnect. This means
145         // that while one thread is attempting to re-store the disconnected
146         // states, other threads could walk through merrily incrementing
147         // this very-negative disconnected count. To prevent senders from
148         // spuriously attempting to send when the channels is actually
149         // disconnected, the count has a ranged check here.
150         //
151         // This is also done for another reason. Remember that the return
152         // value of this function is:
153         //
154         //  `true` == the data *may* be received, this essentially has no
155         //            meaning
156         //  `false` == the data will *never* be received, this has a lot of
157         //             meaning
158         //
159         // In the SPSC case, we have a check of 'queue.is_empty()' to see
160         // whether the data was actually received, but this same condition
161         // means nothing in a multi-producer context. As a result, this
162         // preflight check serves as the definitive "this will never be
163         // received". Once we get beyond this check, we have permanently
164         // entered the realm of "this may be received"
165         if self.cnt.load(Ordering::SeqCst) < DISCONNECTED + FUDGE {
166             return Err(t)
167         }
168
169         self.queue.push(t);
170         match self.cnt.fetch_add(1, Ordering::SeqCst) {
171             -1 => {
172                 self.take_to_wake().signal();
173             }
174
175             // In this case, we have possibly failed to send our data, and
176             // we need to consider re-popping the data in order to fully
177             // destroy it. We must arbitrate among the multiple senders,
178             // however, because the queues that we're using are
179             // single-consumer queues. In order to do this, all exiting
180             // pushers will use an atomic count in order to count those
181             // flowing through. Pushers who see 0 are required to drain as
182             // much as possible, and then can only exit when they are the
183             // only pusher (otherwise they must try again).
184             n if n < DISCONNECTED + FUDGE => {
185                 // see the comment in 'try' for a shared channel for why this
186                 // window of "not disconnected" is ok.
187                 self.cnt.store(DISCONNECTED, Ordering::SeqCst);
188
189                 if self.sender_drain.fetch_add(1, Ordering::SeqCst) == 0 {
190                     loop {
191                         // drain the queue, for info on the thread yield see the
192                         // discussion in try_recv
193                         loop {
194                             match self.queue.pop() {
195                                 mpsc::Data(..) => {}
196                                 mpsc::Empty => break,
197                                 mpsc::Inconsistent => Thread::yield_now(),
198                             }
199                         }
200                         // maybe we're done, if we're not the last ones
201                         // here, then we need to go try again.
202                         if self.sender_drain.fetch_sub(1, Ordering::SeqCst) == 1 {
203                             break
204                         }
205                     }
206
207                     // At this point, there may still be data on the queue,
208                     // but only if the count hasn't been incremented and
209                     // some other sender hasn't finished pushing data just
210                     // yet. That sender in question will drain its own data.
211                 }
212             }
213
214             // Can't make any assumptions about this case like in the SPSC case.
215             _ => {}
216         }
217
218         Ok(())
219     }
220
221     pub fn recv(&mut self) -> Result<T, Failure> {
222         // This code is essentially the exact same as that found in the stream
223         // case (see stream.rs)
224         match self.try_recv() {
225             Err(Empty) => {}
226             data => return data,
227         }
228
229         let (wait_token, signal_token) = blocking::tokens();
230         if self.decrement(signal_token) == Installed {
231             wait_token.wait()
232         }
233
234         match self.try_recv() {
235             data @ Ok(..) => { self.steals -= 1; data }
236             data => data,
237         }
238     }
239
240     // Essentially the exact same thing as the stream decrement function.
241     // Returns true if blocking should proceed.
242     fn decrement(&mut self, token: SignalToken) -> StartResult {
243         assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
244         let ptr = unsafe { token.cast_to_uint() };
245         self.to_wake.store(ptr, Ordering::SeqCst);
246
247         let steals = self.steals;
248         self.steals = 0;
249
250         match self.cnt.fetch_sub(1 + steals, Ordering::SeqCst) {
251             DISCONNECTED => { self.cnt.store(DISCONNECTED, Ordering::SeqCst); }
252             // If we factor in our steals and notice that the channel has no
253             // data, we successfully sleep
254             n => {
255                 assert!(n >= 0);
256                 if n - steals <= 0 { return Installed }
257             }
258         }
259
260         self.to_wake.store(0, Ordering::SeqCst);
261         drop(unsafe { SignalToken::cast_from_uint(ptr) });
262         Abort
263     }
264
265     pub fn try_recv(&mut self) -> Result<T, Failure> {
266         let ret = match self.queue.pop() {
267             mpsc::Data(t) => Some(t),
268             mpsc::Empty => None,
269
270             // This is a bit of an interesting case. The channel is reported as
271             // having data available, but our pop() has failed due to the queue
272             // being in an inconsistent state.  This means that there is some
273             // pusher somewhere which has yet to complete, but we are guaranteed
274             // that a pop will eventually succeed. In this case, we spin in a
275             // yield loop because the remote sender should finish their enqueue
276             // operation "very quickly".
277             //
278             // Avoiding this yield loop would require a different queue
279             // abstraction which provides the guarantee that after M pushes have
280             // succeeded, at least M pops will succeed. The current queues
281             // guarantee that if there are N active pushes, you can pop N times
282             // once all N have finished.
283             mpsc::Inconsistent => {
284                 let data;
285                 loop {
286                     Thread::yield_now();
287                     match self.queue.pop() {
288                         mpsc::Data(t) => { data = t; break }
289                         mpsc::Empty => panic!("inconsistent => empty"),
290                         mpsc::Inconsistent => {}
291                     }
292                 }
293                 Some(data)
294             }
295         };
296         match ret {
297             // See the discussion in the stream implementation for why we
298             // might decrement steals.
299             Some(data) => {
300                 if self.steals > MAX_STEALS {
301                     match self.cnt.swap(0, Ordering::SeqCst) {
302                         DISCONNECTED => {
303                             self.cnt.store(DISCONNECTED, Ordering::SeqCst);
304                         }
305                         n => {
306                             let m = cmp::min(n, self.steals);
307                             self.steals -= m;
308                             self.bump(n - m);
309                         }
310                     }
311                     assert!(self.steals >= 0);
312                 }
313                 self.steals += 1;
314                 Ok(data)
315             }
316
317             // See the discussion in the stream implementation for why we try
318             // again.
319             None => {
320                 match self.cnt.load(Ordering::SeqCst) {
321                     n if n != DISCONNECTED => Err(Empty),
322                     _ => {
323                         match self.queue.pop() {
324                             mpsc::Data(t) => Ok(t),
325                             mpsc::Empty => Err(Disconnected),
326                             // with no senders, an inconsistency is impossible.
327                             mpsc::Inconsistent => unreachable!(),
328                         }
329                     }
330                 }
331             }
332         }
333     }
334
335     // Prepares this shared packet for a channel clone, essentially just bumping
336     // a refcount.
337     pub fn clone_chan(&mut self) {
338         self.channels.fetch_add(1, Ordering::SeqCst);
339     }
340
341     // Decrement the reference count on a channel. This is called whenever a
342     // Chan is dropped and may end up waking up a receiver. It's the receiver's
343     // responsibility on the other end to figure out that we've disconnected.
344     pub fn drop_chan(&mut self) {
345         match self.channels.fetch_sub(1, Ordering::SeqCst) {
346             1 => {}
347             n if n > 1 => return,
348             n => panic!("bad number of channels left {}", n),
349         }
350
351         match self.cnt.swap(DISCONNECTED, Ordering::SeqCst) {
352             -1 => { self.take_to_wake().signal(); }
353             DISCONNECTED => {}
354             n => { assert!(n >= 0); }
355         }
356     }
357
358     // See the long discussion inside of stream.rs for why the queue is drained,
359     // and why it is done in this fashion.
360     pub fn drop_port(&mut self) {
361         self.port_dropped.store(true, Ordering::SeqCst);
362         let mut steals = self.steals;
363         while {
364             let cnt = self.cnt.compare_and_swap(steals, DISCONNECTED, Ordering::SeqCst);
365             cnt != DISCONNECTED && cnt != steals
366         } {
367             // See the discussion in 'try_recv' for why we yield
368             // control of this thread.
369             loop {
370                 match self.queue.pop() {
371                     mpsc::Data(..) => { steals += 1; }
372                     mpsc::Empty | mpsc::Inconsistent => break,
373                 }
374             }
375         }
376     }
377
378     // Consumes ownership of the 'to_wake' field.
379     fn take_to_wake(&mut self) -> SignalToken {
380         let ptr = self.to_wake.load(Ordering::SeqCst);
381         self.to_wake.store(0, Ordering::SeqCst);
382         assert!(ptr != 0);
383         unsafe { SignalToken::cast_from_uint(ptr) }
384     }
385
386     ////////////////////////////////////////////////////////////////////////////
387     // select implementation
388     ////////////////////////////////////////////////////////////////////////////
389
390     // Helper function for select, tests whether this port can receive without
391     // blocking (obviously not an atomic decision).
392     //
393     // This is different than the stream version because there's no need to peek
394     // at the queue, we can just look at the local count.
395     pub fn can_recv(&mut self) -> bool {
396         let cnt = self.cnt.load(Ordering::SeqCst);
397         cnt == DISCONNECTED || cnt - self.steals > 0
398     }
399
400     // increment the count on the channel (used for selection)
401     fn bump(&mut self, amt: int) -> int {
402         match self.cnt.fetch_add(amt, Ordering::SeqCst) {
403             DISCONNECTED => {
404                 self.cnt.store(DISCONNECTED, Ordering::SeqCst);
405                 DISCONNECTED
406             }
407             n => n
408         }
409     }
410
411     // Inserts the signal token for selection on this port, returning true if
412     // blocking should proceed.
413     //
414     // The code here is the same as in stream.rs, except that it doesn't need to
415     // peek at the channel to see if an upgrade is pending.
416     pub fn start_selection(&mut self, token: SignalToken) -> StartResult {
417         match self.decrement(token) {
418             Installed => Installed,
419             Abort => {
420                 let prev = self.bump(1);
421                 assert!(prev == DISCONNECTED || prev >= 0);
422                 Abort
423             }
424         }
425     }
426
427     // Cancels a previous task waiting on this port, returning whether there's
428     // data on the port.
429     //
430     // This is similar to the stream implementation (hence fewer comments), but
431     // uses a different value for the "steals" variable.
432     pub fn abort_selection(&mut self, _was_upgrade: bool) -> bool {
433         // Before we do anything else, we bounce on this lock. The reason for
434         // doing this is to ensure that any upgrade-in-progress is gone and
435         // done with. Without this bounce, we can race with inherit_blocker
436         // about looking at and dealing with to_wake. Once we have acquired the
437         // lock, we are guaranteed that inherit_blocker is done.
438         {
439             let _guard = self.select_lock.lock().unwrap();
440         }
441
442         // Like the stream implementation, we want to make sure that the count
443         // on the channel goes non-negative. We don't know how negative the
444         // stream currently is, so instead of using a steal value of 1, we load
445         // the channel count and figure out what we should do to make it
446         // positive.
447         let steals = {
448             let cnt = self.cnt.load(Ordering::SeqCst);
449             if cnt < 0 && cnt != DISCONNECTED {-cnt} else {0}
450         };
451         let prev = self.bump(steals + 1);
452
453         if prev == DISCONNECTED {
454             assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
455             true
456         } else {
457             let cur = prev + steals + 1;
458             assert!(cur >= 0);
459             if prev < 0 {
460                 drop(self.take_to_wake());
461             } else {
462                 while self.to_wake.load(Ordering::SeqCst) != 0 {
463                     Thread::yield_now();
464                 }
465             }
466             // if the number of steals is -1, it was the pre-emptive -1 steal
467             // count from when we inherited a blocker. This is fine because
468             // we're just going to overwrite it with a real value.
469             assert!(self.steals == 0 || self.steals == -1);
470             self.steals = steals;
471             prev >= 0
472         }
473     }
474 }
475
476 #[unsafe_destructor]
477 impl<T: Send> Drop for Packet<T> {
478     fn drop(&mut self) {
479         // Note that this load is not only an assert for correctness about
480         // disconnection, but also a proper fence before the read of
481         // `to_wake`, so this assert cannot be removed with also removing
482         // the `to_wake` assert.
483         assert_eq!(self.cnt.load(Ordering::SeqCst), DISCONNECTED);
484         assert_eq!(self.to_wake.load(Ordering::SeqCst), 0);
485         assert_eq!(self.channels.load(Ordering::SeqCst), 0);
486     }
487 }