]> git.lizzy.rs Git - rust.git/blob - library/core/src/iter/adapters/take.rs
Rollup merge of #102854 - semarie:openbsd-immutablestack, r=m-ou-se
[rust.git] / library / core / src / iter / adapters / take.rs
1 use crate::cmp;
2 use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen};
3 use crate::ops::{ControlFlow, Try};
4
5 /// An iterator that only iterates over the first `n` iterations of `iter`.
6 ///
7 /// This `struct` is created by the [`take`] method on [`Iterator`]. See its
8 /// documentation for more.
9 ///
10 /// [`take`]: Iterator::take
11 /// [`Iterator`]: trait.Iterator.html
12 #[derive(Clone, Debug)]
13 #[must_use = "iterators are lazy and do nothing unless consumed"]
14 #[stable(feature = "rust1", since = "1.0.0")]
15 pub struct Take<I> {
16     iter: I,
17     n: usize,
18 }
19
20 impl<I> Take<I> {
21     pub(in crate::iter) fn new(iter: I, n: usize) -> Take<I> {
22         Take { iter, n }
23     }
24 }
25
26 #[stable(feature = "rust1", since = "1.0.0")]
27 impl<I> Iterator for Take<I>
28 where
29     I: Iterator,
30 {
31     type Item = <I as Iterator>::Item;
32
33     #[inline]
34     fn next(&mut self) -> Option<<I as Iterator>::Item> {
35         if self.n != 0 {
36             self.n -= 1;
37             self.iter.next()
38         } else {
39             None
40         }
41     }
42
43     #[inline]
44     fn nth(&mut self, n: usize) -> Option<I::Item> {
45         if self.n > n {
46             self.n -= n + 1;
47             self.iter.nth(n)
48         } else {
49             if self.n > 0 {
50                 self.iter.nth(self.n - 1);
51                 self.n = 0;
52             }
53             None
54         }
55     }
56
57     #[inline]
58     fn size_hint(&self) -> (usize, Option<usize>) {
59         if self.n == 0 {
60             return (0, Some(0));
61         }
62
63         let (lower, upper) = self.iter.size_hint();
64
65         let lower = cmp::min(lower, self.n);
66
67         let upper = match upper {
68             Some(x) if x < self.n => Some(x),
69             _ => Some(self.n),
70         };
71
72         (lower, upper)
73     }
74
75     #[inline]
76     fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
77     where
78         Self: Sized,
79         Fold: FnMut(Acc, Self::Item) -> R,
80         R: Try<Output = Acc>,
81     {
82         fn check<'a, T, Acc, R: Try<Output = Acc>>(
83             n: &'a mut usize,
84             mut fold: impl FnMut(Acc, T) -> R + 'a,
85         ) -> impl FnMut(Acc, T) -> ControlFlow<R, Acc> + 'a {
86             move |acc, x| {
87                 *n -= 1;
88                 let r = fold(acc, x);
89                 if *n == 0 { ControlFlow::Break(r) } else { ControlFlow::from_try(r) }
90             }
91         }
92
93         if self.n == 0 {
94             try { init }
95         } else {
96             let n = &mut self.n;
97             self.iter.try_fold(init, check(n, fold)).into_try()
98         }
99     }
100
101     impl_fold_via_try_fold! { fold -> try_fold }
102
103     #[inline]
104     #[rustc_inherit_overflow_checks]
105     fn advance_by(&mut self, n: usize) -> Result<(), usize> {
106         let min = self.n.min(n);
107         match self.iter.advance_by(min) {
108             Ok(_) => {
109                 self.n -= min;
110                 if min < n { Err(min) } else { Ok(()) }
111             }
112             ret @ Err(advanced) => {
113                 self.n -= advanced;
114                 ret
115             }
116         }
117     }
118 }
119
120 #[unstable(issue = "none", feature = "inplace_iteration")]
121 unsafe impl<I> SourceIter for Take<I>
122 where
123     I: SourceIter,
124 {
125     type Source = I::Source;
126
127     #[inline]
128     unsafe fn as_inner(&mut self) -> &mut I::Source {
129         // SAFETY: unsafe function forwarding to unsafe function with the same requirements
130         unsafe { SourceIter::as_inner(&mut self.iter) }
131     }
132 }
133
134 #[unstable(issue = "none", feature = "inplace_iteration")]
135 unsafe impl<I: InPlaceIterable> InPlaceIterable for Take<I> {}
136
137 #[stable(feature = "double_ended_take_iterator", since = "1.38.0")]
138 impl<I> DoubleEndedIterator for Take<I>
139 where
140     I: DoubleEndedIterator + ExactSizeIterator,
141 {
142     #[inline]
143     fn next_back(&mut self) -> Option<Self::Item> {
144         if self.n == 0 {
145             None
146         } else {
147             let n = self.n;
148             self.n -= 1;
149             self.iter.nth_back(self.iter.len().saturating_sub(n))
150         }
151     }
152
153     #[inline]
154     fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
155         let len = self.iter.len();
156         if self.n > n {
157             let m = len.saturating_sub(self.n) + n;
158             self.n -= n + 1;
159             self.iter.nth_back(m)
160         } else {
161             if len > 0 {
162                 self.iter.nth_back(len - 1);
163             }
164             None
165         }
166     }
167
168     #[inline]
169     fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
170     where
171         Self: Sized,
172         Fold: FnMut(Acc, Self::Item) -> R,
173         R: Try<Output = Acc>,
174     {
175         if self.n == 0 {
176             try { init }
177         } else {
178             let len = self.iter.len();
179             if len > self.n && self.iter.nth_back(len - self.n - 1).is_none() {
180                 try { init }
181             } else {
182                 self.iter.try_rfold(init, fold)
183             }
184         }
185     }
186
187     #[inline]
188     fn rfold<Acc, Fold>(mut self, init: Acc, fold: Fold) -> Acc
189     where
190         Self: Sized,
191         Fold: FnMut(Acc, Self::Item) -> Acc,
192     {
193         if self.n == 0 {
194             init
195         } else {
196             let len = self.iter.len();
197             if len > self.n && self.iter.nth_back(len - self.n - 1).is_none() {
198                 init
199             } else {
200                 self.iter.rfold(init, fold)
201             }
202         }
203     }
204
205     #[inline]
206     #[rustc_inherit_overflow_checks]
207     fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
208         // The amount by which the inner iterator needs to be shortened for it to be
209         // at most as long as the take() amount.
210         let trim_inner = self.iter.len().saturating_sub(self.n);
211         // The amount we need to advance inner to fulfill the caller's request.
212         // take(), advance_by() and len() all can be at most usize, so we don't have to worry
213         // about having to advance more than usize::MAX here.
214         let advance_by = trim_inner.saturating_add(n);
215
216         let advanced = match self.iter.advance_back_by(advance_by) {
217             Ok(_) => advance_by - trim_inner,
218             Err(advanced) => advanced - trim_inner,
219         };
220         self.n -= advanced;
221         return if advanced < n { Err(advanced) } else { Ok(()) };
222     }
223 }
224
225 #[stable(feature = "rust1", since = "1.0.0")]
226 impl<I> ExactSizeIterator for Take<I> where I: ExactSizeIterator {}
227
228 #[stable(feature = "fused", since = "1.26.0")]
229 impl<I> FusedIterator for Take<I> where I: FusedIterator {}
230
231 #[unstable(feature = "trusted_len", issue = "37572")]
232 unsafe impl<I: TrustedLen> TrustedLen for Take<I> {}