]> git.lizzy.rs Git - rust.git/blob - library/std/src/sync/mpsc/stream.rs
initial port of crossbeam-channel
[rust.git] / library / std / src / sync / mpsc / stream.rs
1 /// Stream channels
2 ///
3 /// This is the flavor of channels which are optimized for one sender and one
4 /// receiver. The sender will be upgraded to a shared channel if the channel is
5 /// cloned.
6 ///
7 /// High level implementation details can be found in the comment of the parent
8 /// module.
9 pub use self::Failure::*;
10 use self::Message::*;
11 pub use self::UpgradeResult::*;
12
13 use core::cmp;
14
15 use crate::cell::UnsafeCell;
16 use crate::ptr;
17 use crate::thread;
18 use crate::time::Instant;
19
20 use crate::sync::atomic::{AtomicBool, AtomicIsize, AtomicPtr, Ordering};
21 use crate::sync::mpsc::blocking::{self, SignalToken};
22 use crate::sync::mpsc::spsc_queue as spsc;
23 use crate::sync::mpsc::Receiver;
24
25 const DISCONNECTED: isize = isize::MIN;
26 #[cfg(test)]
27 const MAX_STEALS: isize = 5;
28 #[cfg(not(test))]
29 const MAX_STEALS: isize = 1 << 20;
30 const EMPTY: *mut u8 = ptr::null_mut(); // initial state: no data, no blocked receiver
31
32 pub struct Packet<T> {
33     // internal queue for all messages
34     queue: spsc::Queue<Message<T>, ProducerAddition, ConsumerAddition>,
35 }
36
37 struct ProducerAddition {
38     cnt: AtomicIsize,       // How many items are on this channel
39     to_wake: AtomicPtr<u8>, // SignalToken for the blocked thread to wake up
40
41     port_dropped: AtomicBool, // flag if the channel has been destroyed.
42 }
43
44 struct ConsumerAddition {
45     steals: UnsafeCell<isize>, // How many times has a port received without blocking?
46 }
47
48 pub enum Failure<T> {
49     Empty,
50     Disconnected,
51     Upgraded(Receiver<T>),
52 }
53
54 pub enum UpgradeResult {
55     UpSuccess,
56     UpDisconnected,
57     UpWoke(SignalToken),
58 }
59
60 // Any message could contain an "upgrade request" to a new shared port, so the
61 // internal queue it's a queue of T, but rather Message<T>
62 enum Message<T> {
63     Data(T),
64     GoUp(Receiver<T>),
65 }
66
67 impl<T> Packet<T> {
68     pub fn new() -> Packet<T> {
69         Packet {
70             queue: unsafe {
71                 spsc::Queue::with_additions(
72                     128,
73                     ProducerAddition {
74                         cnt: AtomicIsize::new(0),
75                         to_wake: AtomicPtr::new(EMPTY),
76
77                         port_dropped: AtomicBool::new(false),
78                     },
79                     ConsumerAddition { steals: UnsafeCell::new(0) },
80                 )
81             },
82         }
83     }
84
85     pub fn send(&self, t: T) -> Result<(), T> {
86         // If the other port has deterministically gone away, then definitely
87         // must return the data back up the stack. Otherwise, the data is
88         // considered as being sent.
89         if self.queue.producer_addition().port_dropped.load(Ordering::SeqCst) {
90             return Err(t);
91         }
92
93         match self.do_send(Data(t)) {
94             UpSuccess | UpDisconnected => {}
95             UpWoke(token) => {
96                 token.signal();
97             }
98         }
99         Ok(())
100     }
101
102     pub fn upgrade(&self, up: Receiver<T>) -> UpgradeResult {
103         // If the port has gone away, then there's no need to proceed any
104         // further.
105         if self.queue.producer_addition().port_dropped.load(Ordering::SeqCst) {
106             return UpDisconnected;
107         }
108
109         self.do_send(GoUp(up))
110     }
111
112     fn do_send(&self, t: Message<T>) -> UpgradeResult {
113         self.queue.push(t);
114         match self.queue.producer_addition().cnt.fetch_add(1, Ordering::SeqCst) {
115             // As described in the mod's doc comment, -1 == wakeup
116             -1 => UpWoke(self.take_to_wake()),
117             // As described before, SPSC queues must be >= -2
118             -2 => UpSuccess,
119
120             // Be sure to preserve the disconnected state, and the return value
121             // in this case is going to be whether our data was received or not.
122             // This manifests itself on whether we have an empty queue or not.
123             //
124             // Primarily, are required to drain the queue here because the port
125             // will never remove this data. We can only have at most one item to
126             // drain (the port drains the rest).
127             DISCONNECTED => {
128                 self.queue.producer_addition().cnt.store(DISCONNECTED, Ordering::SeqCst);
129                 let first = self.queue.pop();
130                 let second = self.queue.pop();
131                 assert!(second.is_none());
132
133                 match first {
134                     Some(..) => UpSuccess,  // we failed to send the data
135                     None => UpDisconnected, // we successfully sent data
136                 }
137             }
138
139             // Otherwise we just sent some data on a non-waiting queue, so just
140             // make sure the world is sane and carry on!
141             n => {
142                 assert!(n >= 0);
143                 UpSuccess
144             }
145         }
146     }
147
148     // Consumes ownership of the 'to_wake' field.
149     fn take_to_wake(&self) -> SignalToken {
150         let ptr = self.queue.producer_addition().to_wake.load(Ordering::SeqCst);
151         self.queue.producer_addition().to_wake.store(EMPTY, Ordering::SeqCst);
152         assert!(ptr != EMPTY);
153         unsafe { SignalToken::from_raw(ptr) }
154     }
155
156     // Decrements the count on the channel for a sleeper, returning the sleeper
157     // back if it shouldn't sleep. Note that this is the location where we take
158     // steals into account.
159     fn decrement(&self, token: SignalToken) -> Result<(), SignalToken> {
160         assert_eq!(self.queue.producer_addition().to_wake.load(Ordering::SeqCst), EMPTY);
161         let ptr = unsafe { token.to_raw() };
162         self.queue.producer_addition().to_wake.store(ptr, Ordering::SeqCst);
163
164         let steals = unsafe { ptr::replace(self.queue.consumer_addition().steals.get(), 0) };
165
166         match self.queue.producer_addition().cnt.fetch_sub(1 + steals, Ordering::SeqCst) {
167             DISCONNECTED => {
168                 self.queue.producer_addition().cnt.store(DISCONNECTED, Ordering::SeqCst);
169             }
170             // If we factor in our steals and notice that the channel has no
171             // data, we successfully sleep
172             n => {
173                 assert!(n >= 0);
174                 if n - steals <= 0 {
175                     return Ok(());
176                 }
177             }
178         }
179
180         self.queue.producer_addition().to_wake.store(EMPTY, Ordering::SeqCst);
181         Err(unsafe { SignalToken::from_raw(ptr) })
182     }
183
184     pub fn recv(&self, deadline: Option<Instant>) -> Result<T, Failure<T>> {
185         // Optimistic preflight check (scheduling is expensive).
186         match self.try_recv() {
187             Err(Empty) => {}
188             data => return data,
189         }
190
191         // Welp, our channel has no data. Deschedule the current thread and
192         // initiate the blocking protocol.
193         let (wait_token, signal_token) = blocking::tokens();
194         if self.decrement(signal_token).is_ok() {
195             if let Some(deadline) = deadline {
196                 let timed_out = !wait_token.wait_max_until(deadline);
197                 if timed_out {
198                     self.abort_selection(/* was_upgrade = */ false).map_err(Upgraded)?;
199                 }
200             } else {
201                 wait_token.wait();
202             }
203         }
204
205         match self.try_recv() {
206             // Messages which actually popped from the queue shouldn't count as
207             // a steal, so offset the decrement here (we already have our
208             // "steal" factored into the channel count above).
209             data @ (Ok(..) | Err(Upgraded(..))) => unsafe {
210                 *self.queue.consumer_addition().steals.get() -= 1;
211                 data
212             },
213
214             data => data,
215         }
216     }
217
218     pub fn try_recv(&self) -> Result<T, Failure<T>> {
219         match self.queue.pop() {
220             // If we stole some data, record to that effect (this will be
221             // factored into cnt later on).
222             //
223             // Note that we don't allow steals to grow without bound in order to
224             // prevent eventual overflow of either steals or cnt as an overflow
225             // would have catastrophic results. Sometimes, steals > cnt, but
226             // other times cnt > steals, so we don't know the relation between
227             // steals and cnt. This code path is executed only rarely, so we do
228             // a pretty slow operation, of swapping 0 into cnt, taking steals
229             // down as much as possible (without going negative), and then
230             // adding back in whatever we couldn't factor into steals.
231             Some(data) => unsafe {
232                 if *self.queue.consumer_addition().steals.get() > MAX_STEALS {
233                     match self.queue.producer_addition().cnt.swap(0, Ordering::SeqCst) {
234                         DISCONNECTED => {
235                             self.queue
236                                 .producer_addition()
237                                 .cnt
238                                 .store(DISCONNECTED, Ordering::SeqCst);
239                         }
240                         n => {
241                             let m = cmp::min(n, *self.queue.consumer_addition().steals.get());
242                             *self.queue.consumer_addition().steals.get() -= m;
243                             self.bump(n - m);
244                         }
245                     }
246                     assert!(*self.queue.consumer_addition().steals.get() >= 0);
247                 }
248                 *self.queue.consumer_addition().steals.get() += 1;
249                 match data {
250                     Data(t) => Ok(t),
251                     GoUp(up) => Err(Upgraded(up)),
252                 }
253             },
254
255             None => {
256                 match self.queue.producer_addition().cnt.load(Ordering::SeqCst) {
257                     n if n != DISCONNECTED => Err(Empty),
258
259                     // This is a little bit of a tricky case. We failed to pop
260                     // data above, and then we have viewed that the channel is
261                     // disconnected. In this window more data could have been
262                     // sent on the channel. It doesn't really make sense to
263                     // return that the channel is disconnected when there's
264                     // actually data on it, so be extra sure there's no data by
265                     // popping one more time.
266                     //
267                     // We can ignore steals because the other end is
268                     // disconnected and we'll never need to really factor in our
269                     // steals again.
270                     _ => match self.queue.pop() {
271                         Some(Data(t)) => Ok(t),
272                         Some(GoUp(up)) => Err(Upgraded(up)),
273                         None => Err(Disconnected),
274                     },
275                 }
276             }
277         }
278     }
279
280     pub fn drop_chan(&self) {
281         // Dropping a channel is pretty simple, we just flag it as disconnected
282         // and then wakeup a blocker if there is one.
283         match self.queue.producer_addition().cnt.swap(DISCONNECTED, Ordering::SeqCst) {
284             -1 => {
285                 self.take_to_wake().signal();
286             }
287             DISCONNECTED => {}
288             n => {
289                 assert!(n >= 0);
290             }
291         }
292     }
293
294     pub fn drop_port(&self) {
295         // Dropping a port seems like a fairly trivial thing. In theory all we
296         // need to do is flag that we're disconnected and then everything else
297         // can take over (we don't have anyone to wake up).
298         //
299         // The catch for Ports is that we want to drop the entire contents of
300         // the queue. There are multiple reasons for having this property, the
301         // largest of which is that if another chan is waiting in this channel
302         // (but not received yet), then waiting on that port will cause a
303         // deadlock.
304         //
305         // So if we accept that we must now destroy the entire contents of the
306         // queue, this code may make a bit more sense. The tricky part is that
307         // we can't let any in-flight sends go un-dropped, we have to make sure
308         // *everything* is dropped and nothing new will come onto the channel.
309
310         // The first thing we do is set a flag saying that we're done for. All
311         // sends are gated on this flag, so we're immediately guaranteed that
312         // there are a bounded number of active sends that we'll have to deal
313         // with.
314         self.queue.producer_addition().port_dropped.store(true, Ordering::SeqCst);
315
316         // Now that we're guaranteed to deal with a bounded number of senders,
317         // we need to drain the queue. This draining process happens atomically
318         // with respect to the "count" of the channel. If the count is nonzero
319         // (with steals taken into account), then there must be data on the
320         // channel. In this case we drain everything and then try again. We will
321         // continue to fail while active senders send data while we're dropping
322         // data, but eventually we're guaranteed to break out of this loop
323         // (because there is a bounded number of senders).
324         let mut steals = unsafe { *self.queue.consumer_addition().steals.get() };
325         while {
326             match self.queue.producer_addition().cnt.compare_exchange(
327                 steals,
328                 DISCONNECTED,
329                 Ordering::SeqCst,
330                 Ordering::SeqCst,
331             ) {
332                 Ok(_) => false,
333                 Err(old) => old != DISCONNECTED,
334             }
335         } {
336             while self.queue.pop().is_some() {
337                 steals += 1;
338             }
339         }
340
341         // At this point in time, we have gated all future senders from sending,
342         // and we have flagged the channel as being disconnected. The senders
343         // still have some responsibility, however, because some sends might not
344         // complete until after we flag the disconnection. There are more
345         // details in the sending methods that see DISCONNECTED
346     }
347
348     ////////////////////////////////////////////////////////////////////////////
349     // select implementation
350     ////////////////////////////////////////////////////////////////////////////
351
352     // increment the count on the channel (used for selection)
353     fn bump(&self, amt: isize) -> isize {
354         match self.queue.producer_addition().cnt.fetch_add(amt, Ordering::SeqCst) {
355             DISCONNECTED => {
356                 self.queue.producer_addition().cnt.store(DISCONNECTED, Ordering::SeqCst);
357                 DISCONNECTED
358             }
359             n => n,
360         }
361     }
362
363     // Removes a previous thread from being blocked in this port
364     pub fn abort_selection(&self, was_upgrade: bool) -> Result<bool, Receiver<T>> {
365         // If we're aborting selection after upgrading from a oneshot, then
366         // we're guarantee that no one is waiting. The only way that we could
367         // have seen the upgrade is if data was actually sent on the channel
368         // half again. For us, this means that there is guaranteed to be data on
369         // this channel. Furthermore, we're guaranteed that there was no
370         // start_selection previously, so there's no need to modify `self.cnt`
371         // at all.
372         //
373         // Hence, because of these invariants, we immediately return `Ok(true)`.
374         // Note that the data might not actually be sent on the channel just yet.
375         // The other end could have flagged the upgrade but not sent data to
376         // this end. This is fine because we know it's a small bounded windows
377         // of time until the data is actually sent.
378         if was_upgrade {
379             assert_eq!(unsafe { *self.queue.consumer_addition().steals.get() }, 0);
380             assert_eq!(self.queue.producer_addition().to_wake.load(Ordering::SeqCst), EMPTY);
381             return Ok(true);
382         }
383
384         // We want to make sure that the count on the channel goes non-negative,
385         // and in the stream case we can have at most one steal, so just assume
386         // that we had one steal.
387         let steals = 1;
388         let prev = self.bump(steals + 1);
389
390         // If we were previously disconnected, then we know for sure that there
391         // is no thread in to_wake, so just keep going
392         let has_data = if prev == DISCONNECTED {
393             assert_eq!(self.queue.producer_addition().to_wake.load(Ordering::SeqCst), EMPTY);
394             true // there is data, that data is that we're disconnected
395         } else {
396             let cur = prev + steals + 1;
397             assert!(cur >= 0);
398
399             // If the previous count was negative, then we just made things go
400             // positive, hence we passed the -1 boundary and we're responsible
401             // for removing the to_wake() field and trashing it.
402             //
403             // If the previous count was positive then we're in a tougher
404             // situation. A possible race is that a sender just incremented
405             // through -1 (meaning it's going to try to wake a thread up), but it
406             // hasn't yet read the to_wake. In order to prevent a future recv()
407             // from waking up too early (this sender picking up the plastered
408             // over to_wake), we spin loop here waiting for to_wake to be 0.
409             // Note that this entire select() implementation needs an overhaul,
410             // and this is *not* the worst part of it, so this is not done as a
411             // final solution but rather out of necessity for now to get
412             // something working.
413             if prev < 0 {
414                 drop(self.take_to_wake());
415             } else {
416                 while self.queue.producer_addition().to_wake.load(Ordering::SeqCst) != EMPTY {
417                     thread::yield_now();
418                 }
419             }
420             unsafe {
421                 assert_eq!(*self.queue.consumer_addition().steals.get(), 0);
422                 *self.queue.consumer_addition().steals.get() = steals;
423             }
424
425             // if we were previously positive, then there's surely data to
426             // receive
427             prev >= 0
428         };
429
430         // Now that we've determined that this queue "has data", we peek at the
431         // queue to see if the data is an upgrade or not. If it's an upgrade,
432         // then we need to destroy this port and abort selection on the
433         // upgraded port.
434         if has_data {
435             match self.queue.peek() {
436                 Some(&mut GoUp(..)) => match self.queue.pop() {
437                     Some(GoUp(port)) => Err(port),
438                     _ => unreachable!(),
439                 },
440                 _ => Ok(true),
441             }
442         } else {
443             Ok(false)
444         }
445     }
446 }
447
448 impl<T> Drop for Packet<T> {
449     fn drop(&mut self) {
450         // Note that this load is not only an assert for correctness about
451         // disconnection, but also a proper fence before the read of
452         // `to_wake`, so this assert cannot be removed with also removing
453         // the `to_wake` assert.
454         assert_eq!(self.queue.producer_addition().cnt.load(Ordering::SeqCst), DISCONNECTED);
455         assert_eq!(self.queue.producer_addition().to_wake.load(Ordering::SeqCst), EMPTY);
456     }
457 }