]> git.lizzy.rs Git - rust.git/blob - library/std/src/sync/mpmc/array.rs
Rollup merge of #107391 - notriddle:notriddle/copy-path-button, r=GuillaumeGomez
[rust.git] / library / std / src / sync / mpmc / array.rs
1 //! Bounded channel based on a preallocated array.
2 //!
3 //! This flavor has a fixed, positive capacity.
4 //!
5 //! The implementation is based on Dmitry Vyukov's bounded MPMC queue.
6 //!
7 //! Source:
8 //!   - <http://www.1024cores.net/home/lock-free-algorithms/queues/bounded-mpmc-queue>
9 //!   - <https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub>
10
11 use super::context::Context;
12 use super::error::*;
13 use super::select::{Operation, Selected, Token};
14 use super::utils::{Backoff, CachePadded};
15 use super::waker::SyncWaker;
16
17 use crate::cell::UnsafeCell;
18 use crate::mem::MaybeUninit;
19 use crate::ptr;
20 use crate::sync::atomic::{self, AtomicUsize, Ordering};
21 use crate::time::Instant;
22
23 /// A slot in a channel.
24 struct Slot<T> {
25     /// The current stamp.
26     stamp: AtomicUsize,
27
28     /// The message in this slot.
29     msg: UnsafeCell<MaybeUninit<T>>,
30 }
31
32 /// The token type for the array flavor.
33 #[derive(Debug)]
34 pub(crate) struct ArrayToken {
35     /// Slot to read from or write to.
36     slot: *const u8,
37
38     /// Stamp to store into the slot after reading or writing.
39     stamp: usize,
40 }
41
42 impl Default for ArrayToken {
43     #[inline]
44     fn default() -> Self {
45         ArrayToken { slot: ptr::null(), stamp: 0 }
46     }
47 }
48
49 /// Bounded channel based on a preallocated array.
50 pub(crate) struct Channel<T> {
51     /// The head of the channel.
52     ///
53     /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but
54     /// packed into a single `usize`. The lower bits represent the index, while the upper bits
55     /// represent the lap. The mark bit in the head is always zero.
56     ///
57     /// Messages are popped from the head of the channel.
58     head: CachePadded<AtomicUsize>,
59
60     /// The tail of the channel.
61     ///
62     /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but
63     /// packed into a single `usize`. The lower bits represent the index, while the upper bits
64     /// represent the lap. The mark bit indicates that the channel is disconnected.
65     ///
66     /// Messages are pushed into the tail of the channel.
67     tail: CachePadded<AtomicUsize>,
68
69     /// The buffer holding slots.
70     buffer: Box<[Slot<T>]>,
71
72     /// The channel capacity.
73     cap: usize,
74
75     /// A stamp with the value of `{ lap: 1, mark: 0, index: 0 }`.
76     one_lap: usize,
77
78     /// If this bit is set in the tail, that means the channel is disconnected.
79     mark_bit: usize,
80
81     /// Senders waiting while the channel is full.
82     senders: SyncWaker,
83
84     /// Receivers waiting while the channel is empty and not disconnected.
85     receivers: SyncWaker,
86 }
87
88 impl<T> Channel<T> {
89     /// Creates a bounded channel of capacity `cap`.
90     pub(crate) fn with_capacity(cap: usize) -> Self {
91         assert!(cap > 0, "capacity must be positive");
92
93         // Compute constants `mark_bit` and `one_lap`.
94         let mark_bit = (cap + 1).next_power_of_two();
95         let one_lap = mark_bit * 2;
96
97         // Head is initialized to `{ lap: 0, mark: 0, index: 0 }`.
98         let head = 0;
99         // Tail is initialized to `{ lap: 0, mark: 0, index: 0 }`.
100         let tail = 0;
101
102         // Allocate a buffer of `cap` slots initialized
103         // with stamps.
104         let buffer: Box<[Slot<T>]> = (0..cap)
105             .map(|i| {
106                 // Set the stamp to `{ lap: 0, mark: 0, index: i }`.
107                 Slot { stamp: AtomicUsize::new(i), msg: UnsafeCell::new(MaybeUninit::uninit()) }
108             })
109             .collect();
110
111         Channel {
112             buffer,
113             cap,
114             one_lap,
115             mark_bit,
116             head: CachePadded::new(AtomicUsize::new(head)),
117             tail: CachePadded::new(AtomicUsize::new(tail)),
118             senders: SyncWaker::new(),
119             receivers: SyncWaker::new(),
120         }
121     }
122
123     /// Attempts to reserve a slot for sending a message.
124     fn start_send(&self, token: &mut Token) -> bool {
125         let backoff = Backoff::new();
126         let mut tail = self.tail.load(Ordering::Relaxed);
127
128         loop {
129             // Check if the channel is disconnected.
130             if tail & self.mark_bit != 0 {
131                 token.array.slot = ptr::null();
132                 token.array.stamp = 0;
133                 return true;
134             }
135
136             // Deconstruct the tail.
137             let index = tail & (self.mark_bit - 1);
138             let lap = tail & !(self.one_lap - 1);
139
140             // Inspect the corresponding slot.
141             debug_assert!(index < self.buffer.len());
142             let slot = unsafe { self.buffer.get_unchecked(index) };
143             let stamp = slot.stamp.load(Ordering::Acquire);
144
145             // If the tail and the stamp match, we may attempt to push.
146             if tail == stamp {
147                 let new_tail = if index + 1 < self.cap {
148                     // Same lap, incremented index.
149                     // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
150                     tail + 1
151                 } else {
152                     // One lap forward, index wraps around to zero.
153                     // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
154                     lap.wrapping_add(self.one_lap)
155                 };
156
157                 // Try moving the tail.
158                 match self.tail.compare_exchange_weak(
159                     tail,
160                     new_tail,
161                     Ordering::SeqCst,
162                     Ordering::Relaxed,
163                 ) {
164                     Ok(_) => {
165                         // Prepare the token for the follow-up call to `write`.
166                         token.array.slot = slot as *const Slot<T> as *const u8;
167                         token.array.stamp = tail + 1;
168                         return true;
169                     }
170                     Err(_) => {
171                         backoff.spin_light();
172                         tail = self.tail.load(Ordering::Relaxed);
173                     }
174                 }
175             } else if stamp.wrapping_add(self.one_lap) == tail + 1 {
176                 atomic::fence(Ordering::SeqCst);
177                 let head = self.head.load(Ordering::Relaxed);
178
179                 // If the head lags one lap behind the tail as well...
180                 if head.wrapping_add(self.one_lap) == tail {
181                     // ...then the channel is full.
182                     return false;
183                 }
184
185                 backoff.spin_light();
186                 tail = self.tail.load(Ordering::Relaxed);
187             } else {
188                 // Snooze because we need to wait for the stamp to get updated.
189                 backoff.spin_heavy();
190                 tail = self.tail.load(Ordering::Relaxed);
191             }
192         }
193     }
194
195     /// Writes a message into the channel.
196     pub(crate) unsafe fn write(&self, token: &mut Token, msg: T) -> Result<(), T> {
197         // If there is no slot, the channel is disconnected.
198         if token.array.slot.is_null() {
199             return Err(msg);
200         }
201
202         let slot: &Slot<T> = &*(token.array.slot as *const Slot<T>);
203
204         // Write the message into the slot and update the stamp.
205         slot.msg.get().write(MaybeUninit::new(msg));
206         slot.stamp.store(token.array.stamp, Ordering::Release);
207
208         // Wake a sleeping receiver.
209         self.receivers.notify();
210         Ok(())
211     }
212
213     /// Attempts to reserve a slot for receiving a message.
214     fn start_recv(&self, token: &mut Token) -> bool {
215         let backoff = Backoff::new();
216         let mut head = self.head.load(Ordering::Relaxed);
217
218         loop {
219             // Deconstruct the head.
220             let index = head & (self.mark_bit - 1);
221             let lap = head & !(self.one_lap - 1);
222
223             // Inspect the corresponding slot.
224             debug_assert!(index < self.buffer.len());
225             let slot = unsafe { self.buffer.get_unchecked(index) };
226             let stamp = slot.stamp.load(Ordering::Acquire);
227
228             // If the stamp is ahead of the head by 1, we may attempt to pop.
229             if head + 1 == stamp {
230                 let new = if index + 1 < self.cap {
231                     // Same lap, incremented index.
232                     // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
233                     head + 1
234                 } else {
235                     // One lap forward, index wraps around to zero.
236                     // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
237                     lap.wrapping_add(self.one_lap)
238                 };
239
240                 // Try moving the head.
241                 match self.head.compare_exchange_weak(
242                     head,
243                     new,
244                     Ordering::SeqCst,
245                     Ordering::Relaxed,
246                 ) {
247                     Ok(_) => {
248                         // Prepare the token for the follow-up call to `read`.
249                         token.array.slot = slot as *const Slot<T> as *const u8;
250                         token.array.stamp = head.wrapping_add(self.one_lap);
251                         return true;
252                     }
253                     Err(_) => {
254                         backoff.spin_light();
255                         head = self.head.load(Ordering::Relaxed);
256                     }
257                 }
258             } else if stamp == head {
259                 atomic::fence(Ordering::SeqCst);
260                 let tail = self.tail.load(Ordering::Relaxed);
261
262                 // If the tail equals the head, that means the channel is empty.
263                 if (tail & !self.mark_bit) == head {
264                     // If the channel is disconnected...
265                     if tail & self.mark_bit != 0 {
266                         // ...then receive an error.
267                         token.array.slot = ptr::null();
268                         token.array.stamp = 0;
269                         return true;
270                     } else {
271                         // Otherwise, the receive operation is not ready.
272                         return false;
273                     }
274                 }
275
276                 backoff.spin_light();
277                 head = self.head.load(Ordering::Relaxed);
278             } else {
279                 // Snooze because we need to wait for the stamp to get updated.
280                 backoff.spin_heavy();
281                 head = self.head.load(Ordering::Relaxed);
282             }
283         }
284     }
285
286     /// Reads a message from the channel.
287     pub(crate) unsafe fn read(&self, token: &mut Token) -> Result<T, ()> {
288         if token.array.slot.is_null() {
289             // The channel is disconnected.
290             return Err(());
291         }
292
293         let slot: &Slot<T> = &*(token.array.slot as *const Slot<T>);
294
295         // Read the message from the slot and update the stamp.
296         let msg = slot.msg.get().read().assume_init();
297         slot.stamp.store(token.array.stamp, Ordering::Release);
298
299         // Wake a sleeping sender.
300         self.senders.notify();
301         Ok(msg)
302     }
303
304     /// Attempts to send a message into the channel.
305     pub(crate) fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
306         let token = &mut Token::default();
307         if self.start_send(token) {
308             unsafe { self.write(token, msg).map_err(TrySendError::Disconnected) }
309         } else {
310             Err(TrySendError::Full(msg))
311         }
312     }
313
314     /// Sends a message into the channel.
315     pub(crate) fn send(
316         &self,
317         msg: T,
318         deadline: Option<Instant>,
319     ) -> Result<(), SendTimeoutError<T>> {
320         let token = &mut Token::default();
321         loop {
322             // Try sending a message.
323             if self.start_send(token) {
324                 let res = unsafe { self.write(token, msg) };
325                 return res.map_err(SendTimeoutError::Disconnected);
326             }
327
328             if let Some(d) = deadline {
329                 if Instant::now() >= d {
330                     return Err(SendTimeoutError::Timeout(msg));
331                 }
332             }
333
334             Context::with(|cx| {
335                 // Prepare for blocking until a receiver wakes us up.
336                 let oper = Operation::hook(token);
337                 self.senders.register(oper, cx);
338
339                 // Has the channel become ready just now?
340                 if !self.is_full() || self.is_disconnected() {
341                     let _ = cx.try_select(Selected::Aborted);
342                 }
343
344                 // Block the current thread.
345                 let sel = cx.wait_until(deadline);
346
347                 match sel {
348                     Selected::Waiting => unreachable!(),
349                     Selected::Aborted | Selected::Disconnected => {
350                         self.senders.unregister(oper).unwrap();
351                     }
352                     Selected::Operation(_) => {}
353                 }
354             });
355         }
356     }
357
358     /// Attempts to receive a message without blocking.
359     pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> {
360         let token = &mut Token::default();
361
362         if self.start_recv(token) {
363             unsafe { self.read(token).map_err(|_| TryRecvError::Disconnected) }
364         } else {
365             Err(TryRecvError::Empty)
366         }
367     }
368
369     /// Receives a message from the channel.
370     pub(crate) fn recv(&self, deadline: Option<Instant>) -> Result<T, RecvTimeoutError> {
371         let token = &mut Token::default();
372         loop {
373             // Try receiving a message.
374             if self.start_recv(token) {
375                 let res = unsafe { self.read(token) };
376                 return res.map_err(|_| RecvTimeoutError::Disconnected);
377             }
378
379             if let Some(d) = deadline {
380                 if Instant::now() >= d {
381                     return Err(RecvTimeoutError::Timeout);
382                 }
383             }
384
385             Context::with(|cx| {
386                 // Prepare for blocking until a sender wakes us up.
387                 let oper = Operation::hook(token);
388                 self.receivers.register(oper, cx);
389
390                 // Has the channel become ready just now?
391                 if !self.is_empty() || self.is_disconnected() {
392                     let _ = cx.try_select(Selected::Aborted);
393                 }
394
395                 // Block the current thread.
396                 let sel = cx.wait_until(deadline);
397
398                 match sel {
399                     Selected::Waiting => unreachable!(),
400                     Selected::Aborted | Selected::Disconnected => {
401                         self.receivers.unregister(oper).unwrap();
402                         // If the channel was disconnected, we still have to check for remaining
403                         // messages.
404                     }
405                     Selected::Operation(_) => {}
406                 }
407             });
408         }
409     }
410
411     /// Returns the current number of messages inside the channel.
412     pub(crate) fn len(&self) -> usize {
413         loop {
414             // Load the tail, then load the head.
415             let tail = self.tail.load(Ordering::SeqCst);
416             let head = self.head.load(Ordering::SeqCst);
417
418             // If the tail didn't change, we've got consistent values to work with.
419             if self.tail.load(Ordering::SeqCst) == tail {
420                 let hix = head & (self.mark_bit - 1);
421                 let tix = tail & (self.mark_bit - 1);
422
423                 return if hix < tix {
424                     tix - hix
425                 } else if hix > tix {
426                     self.cap - hix + tix
427                 } else if (tail & !self.mark_bit) == head {
428                     0
429                 } else {
430                     self.cap
431                 };
432             }
433         }
434     }
435
436     /// Returns the capacity of the channel.
437     #[allow(clippy::unnecessary_wraps)] // This is intentional.
438     pub(crate) fn capacity(&self) -> Option<usize> {
439         Some(self.cap)
440     }
441
442     /// Disconnects the channel and wakes up all blocked senders and receivers.
443     ///
444     /// Returns `true` if this call disconnected the channel.
445     pub(crate) fn disconnect(&self) -> bool {
446         let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
447
448         if tail & self.mark_bit == 0 {
449             self.senders.disconnect();
450             self.receivers.disconnect();
451             true
452         } else {
453             false
454         }
455     }
456
457     /// Returns `true` if the channel is disconnected.
458     pub(crate) fn is_disconnected(&self) -> bool {
459         self.tail.load(Ordering::SeqCst) & self.mark_bit != 0
460     }
461
462     /// Returns `true` if the channel is empty.
463     pub(crate) fn is_empty(&self) -> bool {
464         let head = self.head.load(Ordering::SeqCst);
465         let tail = self.tail.load(Ordering::SeqCst);
466
467         // Is the tail equal to the head?
468         //
469         // Note: If the head changes just before we load the tail, that means there was a moment
470         // when the channel was not empty, so it is safe to just return `false`.
471         (tail & !self.mark_bit) == head
472     }
473
474     /// Returns `true` if the channel is full.
475     pub(crate) fn is_full(&self) -> bool {
476         let tail = self.tail.load(Ordering::SeqCst);
477         let head = self.head.load(Ordering::SeqCst);
478
479         // Is the head lagging one lap behind tail?
480         //
481         // Note: If the tail changes just before we load the head, that means there was a moment
482         // when the channel was not full, so it is safe to just return `false`.
483         head.wrapping_add(self.one_lap) == tail & !self.mark_bit
484     }
485 }
486
487 impl<T> Drop for Channel<T> {
488     fn drop(&mut self) {
489         // Get the index of the head.
490         let hix = self.head.load(Ordering::Relaxed) & (self.mark_bit - 1);
491
492         // Loop over all slots that hold a message and drop them.
493         for i in 0..self.len() {
494             // Compute the index of the next slot holding a message.
495             let index = if hix + i < self.cap { hix + i } else { hix + i - self.cap };
496
497             unsafe {
498                 debug_assert!(index < self.buffer.len());
499                 let slot = self.buffer.get_unchecked_mut(index);
500                 let msg = &mut *slot.msg.get();
501                 msg.as_mut_ptr().drop_in_place();
502             }
503         }
504     }
505 }