]> git.lizzy.rs Git - rust.git/blob - library/std/src/sync/mpmc/array.rs
review comment: add test case
[rust.git] / library / std / src / sync / mpmc / array.rs
1 //! Bounded channel based on a preallocated array.
2 //!
3 //! This flavor has a fixed, positive capacity.
4 //!
5 //! The implementation is based on Dmitry Vyukov's bounded MPMC queue.
6 //!
7 //! Source:
8 //!   - <http://www.1024cores.net/home/lock-free-algorithms/queues/bounded-mpmc-queue>
9 //!   - <https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub>
10
11 use super::context::Context;
12 use super::error::*;
13 use super::select::{Operation, Selected, Token};
14 use super::utils::{Backoff, CachePadded};
15 use super::waker::SyncWaker;
16
17 use crate::cell::UnsafeCell;
18 use crate::mem::MaybeUninit;
19 use crate::ptr;
20 use crate::sync::atomic::{self, AtomicUsize, Ordering};
21 use crate::time::Instant;
22
23 /// A slot in a channel.
24 struct Slot<T> {
25     /// The current stamp.
26     stamp: AtomicUsize,
27
28     /// The message in this slot.
29     msg: UnsafeCell<MaybeUninit<T>>,
30 }
31
32 /// The token type for the array flavor.
33 #[derive(Debug)]
34 pub(crate) struct ArrayToken {
35     /// Slot to read from or write to.
36     slot: *const u8,
37
38     /// Stamp to store into the slot after reading or writing.
39     stamp: usize,
40 }
41
42 impl Default for ArrayToken {
43     #[inline]
44     fn default() -> Self {
45         ArrayToken { slot: ptr::null(), stamp: 0 }
46     }
47 }
48
49 /// Bounded channel based on a preallocated array.
50 pub(crate) struct Channel<T> {
51     /// The head of the channel.
52     ///
53     /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but
54     /// packed into a single `usize`. The lower bits represent the index, while the upper bits
55     /// represent the lap. The mark bit in the head is always zero.
56     ///
57     /// Messages are popped from the head of the channel.
58     head: CachePadded<AtomicUsize>,
59
60     /// The tail of the channel.
61     ///
62     /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but
63     /// packed into a single `usize`. The lower bits represent the index, while the upper bits
64     /// represent the lap. The mark bit indicates that the channel is disconnected.
65     ///
66     /// Messages are pushed into the tail of the channel.
67     tail: CachePadded<AtomicUsize>,
68
69     /// The buffer holding slots.
70     buffer: Box<[Slot<T>]>,
71
72     /// The channel capacity.
73     cap: usize,
74
75     /// A stamp with the value of `{ lap: 1, mark: 0, index: 0 }`.
76     one_lap: usize,
77
78     /// If this bit is set in the tail, that means the channel is disconnected.
79     mark_bit: usize,
80
81     /// Senders waiting while the channel is full.
82     senders: SyncWaker,
83
84     /// Receivers waiting while the channel is empty and not disconnected.
85     receivers: SyncWaker,
86 }
87
88 impl<T> Channel<T> {
89     /// Creates a bounded channel of capacity `cap`.
90     pub(crate) fn with_capacity(cap: usize) -> Self {
91         assert!(cap > 0, "capacity must be positive");
92
93         // Compute constants `mark_bit` and `one_lap`.
94         let mark_bit = (cap + 1).next_power_of_two();
95         let one_lap = mark_bit * 2;
96
97         // Head is initialized to `{ lap: 0, mark: 0, index: 0 }`.
98         let head = 0;
99         // Tail is initialized to `{ lap: 0, mark: 0, index: 0 }`.
100         let tail = 0;
101
102         // Allocate a buffer of `cap` slots initialized
103         // with stamps.
104         let buffer: Box<[Slot<T>]> = (0..cap)
105             .map(|i| {
106                 // Set the stamp to `{ lap: 0, mark: 0, index: i }`.
107                 Slot { stamp: AtomicUsize::new(i), msg: UnsafeCell::new(MaybeUninit::uninit()) }
108             })
109             .collect();
110
111         Channel {
112             buffer,
113             cap,
114             one_lap,
115             mark_bit,
116             head: CachePadded::new(AtomicUsize::new(head)),
117             tail: CachePadded::new(AtomicUsize::new(tail)),
118             senders: SyncWaker::new(),
119             receivers: SyncWaker::new(),
120         }
121     }
122
123     /// Attempts to reserve a slot for sending a message.
124     fn start_send(&self, token: &mut Token) -> bool {
125         let backoff = Backoff::new();
126         let mut tail = self.tail.load(Ordering::Relaxed);
127
128         loop {
129             // Check if the channel is disconnected.
130             if tail & self.mark_bit != 0 {
131                 token.array.slot = ptr::null();
132                 token.array.stamp = 0;
133                 return true;
134             }
135
136             // Deconstruct the tail.
137             let index = tail & (self.mark_bit - 1);
138             let lap = tail & !(self.one_lap - 1);
139
140             // Inspect the corresponding slot.
141             debug_assert!(index < self.buffer.len());
142             let slot = unsafe { self.buffer.get_unchecked(index) };
143             let stamp = slot.stamp.load(Ordering::Acquire);
144
145             // If the tail and the stamp match, we may attempt to push.
146             if tail == stamp {
147                 let new_tail = if index + 1 < self.cap {
148                     // Same lap, incremented index.
149                     // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
150                     tail + 1
151                 } else {
152                     // One lap forward, index wraps around to zero.
153                     // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
154                     lap.wrapping_add(self.one_lap)
155                 };
156
157                 // Try moving the tail.
158                 match self.tail.compare_exchange_weak(
159                     tail,
160                     new_tail,
161                     Ordering::SeqCst,
162                     Ordering::Relaxed,
163                 ) {
164                     Ok(_) => {
165                         // Prepare the token for the follow-up call to `write`.
166                         token.array.slot = slot as *const Slot<T> as *const u8;
167                         token.array.stamp = tail + 1;
168                         return true;
169                     }
170                     Err(_) => {
171                         backoff.spin();
172                         tail = self.tail.load(Ordering::Relaxed);
173                     }
174                 }
175             } else if stamp.wrapping_add(self.one_lap) == tail + 1 {
176                 atomic::fence(Ordering::SeqCst);
177                 let head = self.head.load(Ordering::Relaxed);
178
179                 // If the head lags one lap behind the tail as well...
180                 if head.wrapping_add(self.one_lap) == tail {
181                     // ...then the channel is full.
182                     return false;
183                 }
184
185                 backoff.spin();
186                 tail = self.tail.load(Ordering::Relaxed);
187             } else {
188                 // Snooze because we need to wait for the stamp to get updated.
189                 backoff.snooze();
190                 tail = self.tail.load(Ordering::Relaxed);
191             }
192         }
193     }
194
195     /// Writes a message into the channel.
196     pub(crate) unsafe fn write(&self, token: &mut Token, msg: T) -> Result<(), T> {
197         // If there is no slot, the channel is disconnected.
198         if token.array.slot.is_null() {
199             return Err(msg);
200         }
201
202         let slot: &Slot<T> = &*(token.array.slot as *const Slot<T>);
203
204         // Write the message into the slot and update the stamp.
205         slot.msg.get().write(MaybeUninit::new(msg));
206         slot.stamp.store(token.array.stamp, Ordering::Release);
207
208         // Wake a sleeping receiver.
209         self.receivers.notify();
210         Ok(())
211     }
212
213     /// Attempts to reserve a slot for receiving a message.
214     fn start_recv(&self, token: &mut Token) -> bool {
215         let backoff = Backoff::new();
216         let mut head = self.head.load(Ordering::Relaxed);
217
218         loop {
219             // Deconstruct the head.
220             let index = head & (self.mark_bit - 1);
221             let lap = head & !(self.one_lap - 1);
222
223             // Inspect the corresponding slot.
224             debug_assert!(index < self.buffer.len());
225             let slot = unsafe { self.buffer.get_unchecked(index) };
226             let stamp = slot.stamp.load(Ordering::Acquire);
227
228             // If the the stamp is ahead of the head by 1, we may attempt to pop.
229             if head + 1 == stamp {
230                 let new = if index + 1 < self.cap {
231                     // Same lap, incremented index.
232                     // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
233                     head + 1
234                 } else {
235                     // One lap forward, index wraps around to zero.
236                     // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
237                     lap.wrapping_add(self.one_lap)
238                 };
239
240                 // Try moving the head.
241                 match self.head.compare_exchange_weak(
242                     head,
243                     new,
244                     Ordering::SeqCst,
245                     Ordering::Relaxed,
246                 ) {
247                     Ok(_) => {
248                         // Prepare the token for the follow-up call to `read`.
249                         token.array.slot = slot as *const Slot<T> as *const u8;
250                         token.array.stamp = head.wrapping_add(self.one_lap);
251                         return true;
252                     }
253                     Err(_) => {
254                         backoff.spin();
255                         head = self.head.load(Ordering::Relaxed);
256                     }
257                 }
258             } else if stamp == head {
259                 atomic::fence(Ordering::SeqCst);
260                 let tail = self.tail.load(Ordering::Relaxed);
261
262                 // If the tail equals the head, that means the channel is empty.
263                 if (tail & !self.mark_bit) == head {
264                     // If the channel is disconnected...
265                     if tail & self.mark_bit != 0 {
266                         // ...then receive an error.
267                         token.array.slot = ptr::null();
268                         token.array.stamp = 0;
269                         return true;
270                     } else {
271                         // Otherwise, the receive operation is not ready.
272                         return false;
273                     }
274                 }
275
276                 backoff.spin();
277                 head = self.head.load(Ordering::Relaxed);
278             } else {
279                 // Snooze because we need to wait for the stamp to get updated.
280                 backoff.snooze();
281                 head = self.head.load(Ordering::Relaxed);
282             }
283         }
284     }
285
286     /// Reads a message from the channel.
287     pub(crate) unsafe fn read(&self, token: &mut Token) -> Result<T, ()> {
288         if token.array.slot.is_null() {
289             // The channel is disconnected.
290             return Err(());
291         }
292
293         let slot: &Slot<T> = &*(token.array.slot as *const Slot<T>);
294
295         // Read the message from the slot and update the stamp.
296         let msg = slot.msg.get().read().assume_init();
297         slot.stamp.store(token.array.stamp, Ordering::Release);
298
299         // Wake a sleeping sender.
300         self.senders.notify();
301         Ok(msg)
302     }
303
304     /// Attempts to send a message into the channel.
305     pub(crate) fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
306         let token = &mut Token::default();
307         if self.start_send(token) {
308             unsafe { self.write(token, msg).map_err(TrySendError::Disconnected) }
309         } else {
310             Err(TrySendError::Full(msg))
311         }
312     }
313
314     /// Sends a message into the channel.
315     pub(crate) fn send(
316         &self,
317         msg: T,
318         deadline: Option<Instant>,
319     ) -> Result<(), SendTimeoutError<T>> {
320         let token = &mut Token::default();
321         loop {
322             // Try sending a message several times.
323             let backoff = Backoff::new();
324             loop {
325                 if self.start_send(token) {
326                     let res = unsafe { self.write(token, msg) };
327                     return res.map_err(SendTimeoutError::Disconnected);
328                 }
329
330                 if backoff.is_completed() {
331                     break;
332                 } else {
333                     backoff.spin();
334                 }
335             }
336
337             if let Some(d) = deadline {
338                 if Instant::now() >= d {
339                     return Err(SendTimeoutError::Timeout(msg));
340                 }
341             }
342
343             Context::with(|cx| {
344                 // Prepare for blocking until a receiver wakes us up.
345                 let oper = Operation::hook(token);
346                 self.senders.register(oper, cx);
347
348                 // Has the channel become ready just now?
349                 if !self.is_full() || self.is_disconnected() {
350                     let _ = cx.try_select(Selected::Aborted);
351                 }
352
353                 // Block the current thread.
354                 let sel = cx.wait_until(deadline);
355
356                 match sel {
357                     Selected::Waiting => unreachable!(),
358                     Selected::Aborted | Selected::Disconnected => {
359                         self.senders.unregister(oper).unwrap();
360                     }
361                     Selected::Operation(_) => {}
362                 }
363             });
364         }
365     }
366
367     /// Attempts to receive a message without blocking.
368     pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> {
369         let token = &mut Token::default();
370
371         if self.start_recv(token) {
372             unsafe { self.read(token).map_err(|_| TryRecvError::Disconnected) }
373         } else {
374             Err(TryRecvError::Empty)
375         }
376     }
377
378     /// Receives a message from the channel.
379     pub(crate) fn recv(&self, deadline: Option<Instant>) -> Result<T, RecvTimeoutError> {
380         let token = &mut Token::default();
381         loop {
382             if self.start_recv(token) {
383                 let res = unsafe { self.read(token) };
384                 return res.map_err(|_| RecvTimeoutError::Disconnected);
385             }
386
387             if let Some(d) = deadline {
388                 if Instant::now() >= d {
389                     return Err(RecvTimeoutError::Timeout);
390                 }
391             }
392
393             Context::with(|cx| {
394                 // Prepare for blocking until a sender wakes us up.
395                 let oper = Operation::hook(token);
396                 self.receivers.register(oper, cx);
397
398                 // Has the channel become ready just now?
399                 if !self.is_empty() || self.is_disconnected() {
400                     let _ = cx.try_select(Selected::Aborted);
401                 }
402
403                 // Block the current thread.
404                 let sel = cx.wait_until(deadline);
405
406                 match sel {
407                     Selected::Waiting => unreachable!(),
408                     Selected::Aborted | Selected::Disconnected => {
409                         self.receivers.unregister(oper).unwrap();
410                         // If the channel was disconnected, we still have to check for remaining
411                         // messages.
412                     }
413                     Selected::Operation(_) => {}
414                 }
415             });
416         }
417     }
418
419     /// Returns the current number of messages inside the channel.
420     pub(crate) fn len(&self) -> usize {
421         loop {
422             // Load the tail, then load the head.
423             let tail = self.tail.load(Ordering::SeqCst);
424             let head = self.head.load(Ordering::SeqCst);
425
426             // If the tail didn't change, we've got consistent values to work with.
427             if self.tail.load(Ordering::SeqCst) == tail {
428                 let hix = head & (self.mark_bit - 1);
429                 let tix = tail & (self.mark_bit - 1);
430
431                 return if hix < tix {
432                     tix - hix
433                 } else if hix > tix {
434                     self.cap - hix + tix
435                 } else if (tail & !self.mark_bit) == head {
436                     0
437                 } else {
438                     self.cap
439                 };
440             }
441         }
442     }
443
444     /// Returns the capacity of the channel.
445     #[allow(clippy::unnecessary_wraps)] // This is intentional.
446     pub(crate) fn capacity(&self) -> Option<usize> {
447         Some(self.cap)
448     }
449
450     /// Disconnects the channel and wakes up all blocked senders and receivers.
451     ///
452     /// Returns `true` if this call disconnected the channel.
453     pub(crate) fn disconnect(&self) -> bool {
454         let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
455
456         if tail & self.mark_bit == 0 {
457             self.senders.disconnect();
458             self.receivers.disconnect();
459             true
460         } else {
461             false
462         }
463     }
464
465     /// Returns `true` if the channel is disconnected.
466     pub(crate) fn is_disconnected(&self) -> bool {
467         self.tail.load(Ordering::SeqCst) & self.mark_bit != 0
468     }
469
470     /// Returns `true` if the channel is empty.
471     pub(crate) fn is_empty(&self) -> bool {
472         let head = self.head.load(Ordering::SeqCst);
473         let tail = self.tail.load(Ordering::SeqCst);
474
475         // Is the tail equal to the head?
476         //
477         // Note: If the head changes just before we load the tail, that means there was a moment
478         // when the channel was not empty, so it is safe to just return `false`.
479         (tail & !self.mark_bit) == head
480     }
481
482     /// Returns `true` if the channel is full.
483     pub(crate) fn is_full(&self) -> bool {
484         let tail = self.tail.load(Ordering::SeqCst);
485         let head = self.head.load(Ordering::SeqCst);
486
487         // Is the head lagging one lap behind tail?
488         //
489         // Note: If the tail changes just before we load the head, that means there was a moment
490         // when the channel was not full, so it is safe to just return `false`.
491         head.wrapping_add(self.one_lap) == tail & !self.mark_bit
492     }
493 }
494
495 impl<T> Drop for Channel<T> {
496     fn drop(&mut self) {
497         // Get the index of the head.
498         let hix = self.head.load(Ordering::Relaxed) & (self.mark_bit - 1);
499
500         // Loop over all slots that hold a message and drop them.
501         for i in 0..self.len() {
502             // Compute the index of the next slot holding a message.
503             let index = if hix + i < self.cap { hix + i } else { hix + i - self.cap };
504
505             unsafe {
506                 debug_assert!(index < self.buffer.len());
507                 let slot = self.buffer.get_unchecked_mut(index);
508                 let msg = &mut *slot.msg.get();
509                 msg.as_mut_ptr().drop_in_place();
510             }
511         }
512     }
513 }