]> git.lizzy.rs Git - rust.git/blob - library/std/src/sync/mpmc/list.rs
5bc196995b14e74c3e27d7ac13e7ae93e33669bf
[rust.git] / library / std / src / sync / mpmc / list.rs
1 //! Unbounded channel implemented as a linked list.
2
3 use super::context::Context;
4 use super::error::*;
5 use super::select::{Operation, Selected, Token};
6 use super::utils::{Backoff, CachePadded};
7 use super::waker::SyncWaker;
8
9 use crate::cell::UnsafeCell;
10 use crate::marker::PhantomData;
11 use crate::mem::MaybeUninit;
12 use crate::ptr;
13 use crate::sync::atomic::{self, AtomicPtr, AtomicUsize, Ordering};
14 use crate::time::Instant;
15
16 // Bits indicating the state of a slot:
17 // * If a message has been written into the slot, `WRITE` is set.
18 // * If a message has been read from the slot, `READ` is set.
19 // * If the block is being destroyed, `DESTROY` is set.
20 const WRITE: usize = 1;
21 const READ: usize = 2;
22 const DESTROY: usize = 4;
23
24 // Each block covers one "lap" of indices.
25 const LAP: usize = 32;
26 // The maximum number of messages a block can hold.
27 const BLOCK_CAP: usize = LAP - 1;
28 // How many lower bits are reserved for metadata.
29 const SHIFT: usize = 1;
30 // Has two different purposes:
31 // * If set in head, indicates that the block is not the last one.
32 // * If set in tail, indicates that the channel is disconnected.
33 const MARK_BIT: usize = 1;
34
35 /// A slot in a block.
36 struct Slot<T> {
37     /// The message.
38     msg: UnsafeCell<MaybeUninit<T>>,
39
40     /// The state of the slot.
41     state: AtomicUsize,
42 }
43
44 impl<T> Slot<T> {
45     /// Waits until a message is written into the slot.
46     fn wait_write(&self) {
47         let backoff = Backoff::new();
48         while self.state.load(Ordering::Acquire) & WRITE == 0 {
49             backoff.snooze();
50         }
51     }
52 }
53
54 /// A block in a linked list.
55 ///
56 /// Each block in the list can hold up to `BLOCK_CAP` messages.
57 struct Block<T> {
58     /// The next block in the linked list.
59     next: AtomicPtr<Block<T>>,
60
61     /// Slots for messages.
62     slots: [Slot<T>; BLOCK_CAP],
63 }
64
65 impl<T> Block<T> {
66     /// Creates an empty block.
67     fn new() -> Block<T> {
68         // SAFETY: This is safe because:
69         //  [1] `Block::next` (AtomicPtr) may be safely zero initialized.
70         //  [2] `Block::slots` (Array) may be safely zero initialized because of [3, 4].
71         //  [3] `Slot::msg` (UnsafeCell) may be safely zero initialized because it
72         //       holds a MaybeUninit.
73         //  [4] `Slot::state` (AtomicUsize) may be safely zero initialized.
74         unsafe { MaybeUninit::zeroed().assume_init() }
75     }
76
77     /// Waits until the next pointer is set.
78     fn wait_next(&self) -> *mut Block<T> {
79         let backoff = Backoff::new();
80         loop {
81             let next = self.next.load(Ordering::Acquire);
82             if !next.is_null() {
83                 return next;
84             }
85             backoff.snooze();
86         }
87     }
88
89     /// Sets the `DESTROY` bit in slots starting from `start` and destroys the block.
90     unsafe fn destroy(this: *mut Block<T>, start: usize) {
91         // It is not necessary to set the `DESTROY` bit in the last slot because that slot has
92         // begun destruction of the block.
93         for i in start..BLOCK_CAP - 1 {
94             let slot = (*this).slots.get_unchecked(i);
95
96             // Mark the `DESTROY` bit if a thread is still using the slot.
97             if slot.state.load(Ordering::Acquire) & READ == 0
98                 && slot.state.fetch_or(DESTROY, Ordering::AcqRel) & READ == 0
99             {
100                 // If a thread is still using the slot, it will continue destruction of the block.
101                 return;
102             }
103         }
104
105         // No thread is using the block, now it is safe to destroy it.
106         drop(Box::from_raw(this));
107     }
108 }
109
110 /// A position in a channel.
111 #[derive(Debug)]
112 struct Position<T> {
113     /// The index in the channel.
114     index: AtomicUsize,
115
116     /// The block in the linked list.
117     block: AtomicPtr<Block<T>>,
118 }
119
120 /// The token type for the list flavor.
121 #[derive(Debug)]
122 pub(crate) struct ListToken {
123     /// The block of slots.
124     block: *const u8,
125
126     /// The offset into the block.
127     offset: usize,
128 }
129
130 impl Default for ListToken {
131     #[inline]
132     fn default() -> Self {
133         ListToken { block: ptr::null(), offset: 0 }
134     }
135 }
136
137 /// Unbounded channel implemented as a linked list.
138 ///
139 /// Each message sent into the channel is assigned a sequence number, i.e. an index. Indices are
140 /// represented as numbers of type `usize` and wrap on overflow.
141 ///
142 /// Consecutive messages are grouped into blocks in order to put less pressure on the allocator and
143 /// improve cache efficiency.
144 pub(crate) struct Channel<T> {
145     /// The head of the channel.
146     head: CachePadded<Position<T>>,
147
148     /// The tail of the channel.
149     tail: CachePadded<Position<T>>,
150
151     /// Receivers waiting while the channel is empty and not disconnected.
152     receivers: SyncWaker,
153
154     /// Indicates that dropping a `Channel<T>` may drop messages of type `T`.
155     _marker: PhantomData<T>,
156 }
157
158 impl<T> Channel<T> {
159     /// Creates a new unbounded channel.
160     pub(crate) fn new() -> Self {
161         Channel {
162             head: CachePadded::new(Position {
163                 block: AtomicPtr::new(ptr::null_mut()),
164                 index: AtomicUsize::new(0),
165             }),
166             tail: CachePadded::new(Position {
167                 block: AtomicPtr::new(ptr::null_mut()),
168                 index: AtomicUsize::new(0),
169             }),
170             receivers: SyncWaker::new(),
171             _marker: PhantomData,
172         }
173     }
174
175     /// Attempts to reserve a slot for sending a message.
176     fn start_send(&self, token: &mut Token) -> bool {
177         let backoff = Backoff::new();
178         let mut tail = self.tail.index.load(Ordering::Acquire);
179         let mut block = self.tail.block.load(Ordering::Acquire);
180         let mut next_block = None;
181
182         loop {
183             // Check if the channel is disconnected.
184             if tail & MARK_BIT != 0 {
185                 token.list.block = ptr::null();
186                 return true;
187             }
188
189             // Calculate the offset of the index into the block.
190             let offset = (tail >> SHIFT) % LAP;
191
192             // If we reached the end of the block, wait until the next one is installed.
193             if offset == BLOCK_CAP {
194                 backoff.snooze();
195                 tail = self.tail.index.load(Ordering::Acquire);
196                 block = self.tail.block.load(Ordering::Acquire);
197                 continue;
198             }
199
200             // If we're going to have to install the next block, allocate it in advance in order to
201             // make the wait for other threads as short as possible.
202             if offset + 1 == BLOCK_CAP && next_block.is_none() {
203                 next_block = Some(Box::new(Block::<T>::new()));
204             }
205
206             // If this is the first message to be sent into the channel, we need to allocate the
207             // first block and install it.
208             if block.is_null() {
209                 let new = Box::into_raw(Box::new(Block::<T>::new()));
210
211                 if self
212                     .tail
213                     .block
214                     .compare_exchange(block, new, Ordering::Release, Ordering::Relaxed)
215                     .is_ok()
216                 {
217                     self.head.block.store(new, Ordering::Release);
218                     block = new;
219                 } else {
220                     next_block = unsafe { Some(Box::from_raw(new)) };
221                     tail = self.tail.index.load(Ordering::Acquire);
222                     block = self.tail.block.load(Ordering::Acquire);
223                     continue;
224                 }
225             }
226
227             let new_tail = tail + (1 << SHIFT);
228
229             // Try advancing the tail forward.
230             match self.tail.index.compare_exchange_weak(
231                 tail,
232                 new_tail,
233                 Ordering::SeqCst,
234                 Ordering::Acquire,
235             ) {
236                 Ok(_) => unsafe {
237                     // If we've reached the end of the block, install the next one.
238                     if offset + 1 == BLOCK_CAP {
239                         let next_block = Box::into_raw(next_block.unwrap());
240                         self.tail.block.store(next_block, Ordering::Release);
241                         self.tail.index.fetch_add(1 << SHIFT, Ordering::Release);
242                         (*block).next.store(next_block, Ordering::Release);
243                     }
244
245                     token.list.block = block as *const u8;
246                     token.list.offset = offset;
247                     return true;
248                 },
249                 Err(t) => {
250                     tail = t;
251                     block = self.tail.block.load(Ordering::Acquire);
252                     backoff.spin();
253                 }
254             }
255         }
256     }
257
258     /// Writes a message into the channel.
259     pub(crate) unsafe fn write(&self, token: &mut Token, msg: T) -> Result<(), T> {
260         // If there is no slot, the channel is disconnected.
261         if token.list.block.is_null() {
262             return Err(msg);
263         }
264
265         // Write the message into the slot.
266         let block = token.list.block as *mut Block<T>;
267         let offset = token.list.offset;
268         let slot = (*block).slots.get_unchecked(offset);
269         slot.msg.get().write(MaybeUninit::new(msg));
270         slot.state.fetch_or(WRITE, Ordering::Release);
271
272         // Wake a sleeping receiver.
273         self.receivers.notify();
274         Ok(())
275     }
276
277     /// Attempts to reserve a slot for receiving a message.
278     fn start_recv(&self, token: &mut Token) -> bool {
279         let backoff = Backoff::new();
280         let mut head = self.head.index.load(Ordering::Acquire);
281         let mut block = self.head.block.load(Ordering::Acquire);
282
283         loop {
284             // Calculate the offset of the index into the block.
285             let offset = (head >> SHIFT) % LAP;
286
287             // If we reached the end of the block, wait until the next one is installed.
288             if offset == BLOCK_CAP {
289                 backoff.snooze();
290                 head = self.head.index.load(Ordering::Acquire);
291                 block = self.head.block.load(Ordering::Acquire);
292                 continue;
293             }
294
295             let mut new_head = head + (1 << SHIFT);
296
297             if new_head & MARK_BIT == 0 {
298                 atomic::fence(Ordering::SeqCst);
299                 let tail = self.tail.index.load(Ordering::Relaxed);
300
301                 // If the tail equals the head, that means the channel is empty.
302                 if head >> SHIFT == tail >> SHIFT {
303                     // If the channel is disconnected...
304                     if tail & MARK_BIT != 0 {
305                         // ...then receive an error.
306                         token.list.block = ptr::null();
307                         return true;
308                     } else {
309                         // Otherwise, the receive operation is not ready.
310                         return false;
311                     }
312                 }
313
314                 // If head and tail are not in the same block, set `MARK_BIT` in head.
315                 if (head >> SHIFT) / LAP != (tail >> SHIFT) / LAP {
316                     new_head |= MARK_BIT;
317                 }
318             }
319
320             // The block can be null here only if the first message is being sent into the channel.
321             // In that case, just wait until it gets initialized.
322             if block.is_null() {
323                 backoff.snooze();
324                 head = self.head.index.load(Ordering::Acquire);
325                 block = self.head.block.load(Ordering::Acquire);
326                 continue;
327             }
328
329             // Try moving the head index forward.
330             match self.head.index.compare_exchange_weak(
331                 head,
332                 new_head,
333                 Ordering::SeqCst,
334                 Ordering::Acquire,
335             ) {
336                 Ok(_) => unsafe {
337                     // If we've reached the end of the block, move to the next one.
338                     if offset + 1 == BLOCK_CAP {
339                         let next = (*block).wait_next();
340                         let mut next_index = (new_head & !MARK_BIT).wrapping_add(1 << SHIFT);
341                         if !(*next).next.load(Ordering::Relaxed).is_null() {
342                             next_index |= MARK_BIT;
343                         }
344
345                         self.head.block.store(next, Ordering::Release);
346                         self.head.index.store(next_index, Ordering::Release);
347                     }
348
349                     token.list.block = block as *const u8;
350                     token.list.offset = offset;
351                     return true;
352                 },
353                 Err(h) => {
354                     head = h;
355                     block = self.head.block.load(Ordering::Acquire);
356                     backoff.spin();
357                 }
358             }
359         }
360     }
361
362     /// Reads a message from the channel.
363     pub(crate) unsafe fn read(&self, token: &mut Token) -> Result<T, ()> {
364         if token.list.block.is_null() {
365             // The channel is disconnected.
366             return Err(());
367         }
368
369         // Read the message.
370         let block = token.list.block as *mut Block<T>;
371         let offset = token.list.offset;
372         let slot = (*block).slots.get_unchecked(offset);
373         slot.wait_write();
374         let msg = slot.msg.get().read().assume_init();
375
376         // Destroy the block if we've reached the end, or if another thread wanted to destroy but
377         // couldn't because we were busy reading from the slot.
378         if offset + 1 == BLOCK_CAP {
379             Block::destroy(block, 0);
380         } else if slot.state.fetch_or(READ, Ordering::AcqRel) & DESTROY != 0 {
381             Block::destroy(block, offset + 1);
382         }
383
384         Ok(msg)
385     }
386
387     /// Attempts to send a message into the channel.
388     pub(crate) fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
389         self.send(msg, None).map_err(|err| match err {
390             SendTimeoutError::Disconnected(msg) => TrySendError::Disconnected(msg),
391             SendTimeoutError::Timeout(_) => unreachable!(),
392         })
393     }
394
395     /// Sends a message into the channel.
396     pub(crate) fn send(
397         &self,
398         msg: T,
399         _deadline: Option<Instant>,
400     ) -> Result<(), SendTimeoutError<T>> {
401         let token = &mut Token::default();
402         assert!(self.start_send(token));
403         unsafe { self.write(token, msg).map_err(SendTimeoutError::Disconnected) }
404     }
405
406     /// Attempts to receive a message without blocking.
407     pub(crate) fn try_recv(&self) -> Result<T, TryRecvError> {
408         let token = &mut Token::default();
409
410         if self.start_recv(token) {
411             unsafe { self.read(token).map_err(|_| TryRecvError::Disconnected) }
412         } else {
413             Err(TryRecvError::Empty)
414         }
415     }
416
417     /// Receives a message from the channel.
418     pub(crate) fn recv(&self, deadline: Option<Instant>) -> Result<T, RecvTimeoutError> {
419         let token = &mut Token::default();
420         loop {
421             // Try receiving a message several times.
422             let backoff = Backoff::new();
423             loop {
424                 if self.start_recv(token) {
425                     unsafe {
426                         return self.read(token).map_err(|_| RecvTimeoutError::Disconnected);
427                     }
428                 }
429
430                 if backoff.is_completed() {
431                     break;
432                 } else {
433                     backoff.snooze();
434                 }
435             }
436
437             if let Some(d) = deadline {
438                 if Instant::now() >= d {
439                     return Err(RecvTimeoutError::Timeout);
440                 }
441             }
442
443             // Prepare for blocking until a sender wakes us up.
444             Context::with(|cx| {
445                 let oper = Operation::hook(token);
446                 self.receivers.register(oper, cx);
447
448                 // Has the channel become ready just now?
449                 if !self.is_empty() || self.is_disconnected() {
450                     let _ = cx.try_select(Selected::Aborted);
451                 }
452
453                 // Block the current thread.
454                 let sel = cx.wait_until(deadline);
455
456                 match sel {
457                     Selected::Waiting => unreachable!(),
458                     Selected::Aborted | Selected::Disconnected => {
459                         self.receivers.unregister(oper).unwrap();
460                         // If the channel was disconnected, we still have to check for remaining
461                         // messages.
462                     }
463                     Selected::Operation(_) => {}
464                 }
465             });
466         }
467     }
468
469     /// Returns the current number of messages inside the channel.
470     pub(crate) fn len(&self) -> usize {
471         loop {
472             // Load the tail index, then load the head index.
473             let mut tail = self.tail.index.load(Ordering::SeqCst);
474             let mut head = self.head.index.load(Ordering::SeqCst);
475
476             // If the tail index didn't change, we've got consistent indices to work with.
477             if self.tail.index.load(Ordering::SeqCst) == tail {
478                 // Erase the lower bits.
479                 tail &= !((1 << SHIFT) - 1);
480                 head &= !((1 << SHIFT) - 1);
481
482                 // Fix up indices if they fall onto block ends.
483                 if (tail >> SHIFT) & (LAP - 1) == LAP - 1 {
484                     tail = tail.wrapping_add(1 << SHIFT);
485                 }
486                 if (head >> SHIFT) & (LAP - 1) == LAP - 1 {
487                     head = head.wrapping_add(1 << SHIFT);
488                 }
489
490                 // Rotate indices so that head falls into the first block.
491                 let lap = (head >> SHIFT) / LAP;
492                 tail = tail.wrapping_sub((lap * LAP) << SHIFT);
493                 head = head.wrapping_sub((lap * LAP) << SHIFT);
494
495                 // Remove the lower bits.
496                 tail >>= SHIFT;
497                 head >>= SHIFT;
498
499                 // Return the difference minus the number of blocks between tail and head.
500                 return tail - head - tail / LAP;
501             }
502         }
503     }
504
505     /// Returns the capacity of the channel.
506     pub(crate) fn capacity(&self) -> Option<usize> {
507         None
508     }
509
510     /// Disconnects senders and wakes up all blocked receivers.
511     ///
512     /// Returns `true` if this call disconnected the channel.
513     pub(crate) fn disconnect_senders(&self) -> bool {
514         let tail = self.tail.index.fetch_or(MARK_BIT, Ordering::SeqCst);
515
516         if tail & MARK_BIT == 0 {
517             self.receivers.disconnect();
518             true
519         } else {
520             false
521         }
522     }
523
524     /// Disconnects receivers.
525     ///
526     /// Returns `true` if this call disconnected the channel.
527     pub(crate) fn disconnect_receivers(&self) -> bool {
528         let tail = self.tail.index.fetch_or(MARK_BIT, Ordering::SeqCst);
529
530         if tail & MARK_BIT == 0 {
531             // If receivers are dropped first, discard all messages to free
532             // memory eagerly.
533             self.discard_all_messages();
534             true
535         } else {
536             false
537         }
538     }
539
540     /// Discards all messages.
541     ///
542     /// This method should only be called when all receivers are dropped.
543     fn discard_all_messages(&self) {
544         let backoff = Backoff::new();
545         let mut tail = self.tail.index.load(Ordering::Acquire);
546         loop {
547             let offset = (tail >> SHIFT) % LAP;
548             if offset != BLOCK_CAP {
549                 break;
550             }
551
552             // New updates to tail will be rejected by MARK_BIT and aborted unless it's
553             // at boundary. We need to wait for the updates take affect otherwise there
554             // can be memory leaks.
555             backoff.snooze();
556             tail = self.tail.index.load(Ordering::Acquire);
557         }
558
559         let mut head = self.head.index.load(Ordering::Acquire);
560         let mut block = self.head.block.load(Ordering::Acquire);
561
562         unsafe {
563             // Drop all messages between head and tail and deallocate the heap-allocated blocks.
564             while head >> SHIFT != tail >> SHIFT {
565                 let offset = (head >> SHIFT) % LAP;
566
567                 if offset < BLOCK_CAP {
568                     // Drop the message in the slot.
569                     let slot = (*block).slots.get_unchecked(offset);
570                     slot.wait_write();
571                     let p = &mut *slot.msg.get();
572                     p.as_mut_ptr().drop_in_place();
573                 } else {
574                     (*block).wait_next();
575                     // Deallocate the block and move to the next one.
576                     let next = (*block).next.load(Ordering::Acquire);
577                     drop(Box::from_raw(block));
578                     block = next;
579                 }
580
581                 head = head.wrapping_add(1 << SHIFT);
582             }
583
584             // Deallocate the last remaining block.
585             if !block.is_null() {
586                 drop(Box::from_raw(block));
587             }
588         }
589         head &= !MARK_BIT;
590         self.head.block.store(ptr::null_mut(), Ordering::Release);
591         self.head.index.store(head, Ordering::Release);
592     }
593
594     /// Returns `true` if the channel is disconnected.
595     pub(crate) fn is_disconnected(&self) -> bool {
596         self.tail.index.load(Ordering::SeqCst) & MARK_BIT != 0
597     }
598
599     /// Returns `true` if the channel is empty.
600     pub(crate) fn is_empty(&self) -> bool {
601         let head = self.head.index.load(Ordering::SeqCst);
602         let tail = self.tail.index.load(Ordering::SeqCst);
603         head >> SHIFT == tail >> SHIFT
604     }
605
606     /// Returns `true` if the channel is full.
607     pub(crate) fn is_full(&self) -> bool {
608         false
609     }
610 }
611
612 impl<T> Drop for Channel<T> {
613     fn drop(&mut self) {
614         let mut head = self.head.index.load(Ordering::Relaxed);
615         let mut tail = self.tail.index.load(Ordering::Relaxed);
616         let mut block = self.head.block.load(Ordering::Relaxed);
617
618         // Erase the lower bits.
619         head &= !((1 << SHIFT) - 1);
620         tail &= !((1 << SHIFT) - 1);
621
622         unsafe {
623             // Drop all messages between head and tail and deallocate the heap-allocated blocks.
624             while head != tail {
625                 let offset = (head >> SHIFT) % LAP;
626
627                 if offset < BLOCK_CAP {
628                     // Drop the message in the slot.
629                     let slot = (*block).slots.get_unchecked(offset);
630                     let p = &mut *slot.msg.get();
631                     p.as_mut_ptr().drop_in_place();
632                 } else {
633                     // Deallocate the block and move to the next one.
634                     let next = (*block).next.load(Ordering::Relaxed);
635                     drop(Box::from_raw(block));
636                     block = next;
637                 }
638
639                 head = head.wrapping_add(1 << SHIFT);
640             }
641
642             // Deallocate the last remaining block.
643             if !block.is_null() {
644                 drop(Box::from_raw(block));
645             }
646         }
647     }
648 }