]> git.lizzy.rs Git - rust.git/commitdiff
Correct iterator adaptor Chain
authorUlrik Sverdrup <bluss@users.noreply.github.com>
Tue, 25 Aug 2015 01:56:35 +0000 (03:56 +0200)
committerUlrik Sverdrup <bluss@users.noreply.github.com>
Tue, 25 Aug 2015 17:07:24 +0000 (19:07 +0200)
The iterator protocol specifies that the iteration ends with the return
value `None` from `.next()` (or `.next_back()`) and it is unspecified
what further calls return. The chain adaptor must account for this in
its DoubleEndedIterator implementation.

It uses three states:

- Both `a` and `b` are valid
- Only the Front iterator (`a`) is valid
- Only the Back iterator (`b`) is valid

The fourth state (neither iterator is valid) only occurs after Chain has
returned None once, so we don't need to store this state.

Fixes #26316

src/libcore/iter.rs
src/libcoretest/iter.rs

index ee32999ba8fbac3f03b0e024e69e279e290d1604..98d885e8dd34dfe15d88862e77e579533fe9cb79 100644 (file)
@@ -184,7 +184,7 @@ fn nth(&mut self, mut n: usize) -> Option<Self::Item> where Self: Sized {
     fn chain<U>(self, other: U) -> Chain<Self, U::IntoIter> where
         Self: Sized, U: IntoIterator<Item=Self::Item>,
     {
-        Chain{a: self, b: other.into_iter(), flag: false}
+        Chain{a: self, b: other.into_iter(), state: ChainState::Both}
     }
 
     /// Creates an iterator that iterates over both this and the specified
@@ -1277,7 +1277,30 @@ fn size_hint(&self) -> (usize, Option<usize>) {
 pub struct Chain<A, B> {
     a: A,
     b: B,
-    flag: bool,
+    state: ChainState,
+}
+
+// The iterator protocol specifies that iteration ends with the return value
+// `None` from `.next()` (or `.next_back()`) and it is unspecified what
+// further calls return. The chain adaptor must account for this since it uses
+// two subiterators.
+//
+//  It uses three states:
+//
+//  - Both: `a` and `b` are remaining
+//  - Front: `a` remaining
+//  - Back: `b` remaining
+//
+//  The fourth state (neither iterator is remaining) only occurs after Chain has
+//  returned None once, so we don't need to store this state.
+#[derive(Clone)]
+enum ChainState {
+    // both front and back iterator are remaining
+    Both,
+    // only front is remaining
+    Front,
+    // only back is remaining
+    Back,
 }
 
 #[stable(feature = "rust1", since = "1.0.0")]
@@ -1289,42 +1312,58 @@ impl<A, B> Iterator for Chain<A, B> where
 
     #[inline]
     fn next(&mut self) -> Option<A::Item> {
-        if self.flag {
-            self.b.next()
-        } else {
-            match self.a.next() {
-                Some(x) => return Some(x),
-                _ => ()
-            }
-            self.flag = true;
-            self.b.next()
+        match self.state {
+            ChainState::Both => match self.a.next() {
+                elt @ Some(..) => return elt,
+                None => {
+                    self.state = ChainState::Back;
+                    self.b.next()
+                }
+            },
+            ChainState::Front => self.a.next(),
+            ChainState::Back => self.b.next(),
         }
     }
 
     #[inline]
     fn count(self) -> usize {
-        (if !self.flag { self.a.count() } else { 0 }) + self.b.count()
+        match self.state {
+            ChainState::Both => self.a.count() + self.b.count(),
+            ChainState::Front => self.a.count(),
+            ChainState::Back => self.b.count(),
+        }
     }
 
     #[inline]
     fn nth(&mut self, mut n: usize) -> Option<A::Item> {
-        if !self.flag {
-            for x in self.a.by_ref() {
-                if n == 0 {
-                    return Some(x)
+        match self.state {
+            ChainState::Both | ChainState::Front => {
+                for x in self.a.by_ref() {
+                    if n == 0 {
+                        return Some(x)
+                    }
+                    n -= 1;
+                }
+                if let ChainState::Both = self.state {
+                    self.state = ChainState::Back;
                 }
-                n -= 1;
             }
-            self.flag = true;
+            ChainState::Back => {}
+        }
+        if let ChainState::Back = self.state {
+            self.b.nth(n)
+        } else {
+            None
         }
-        self.b.nth(n)
     }
 
     #[inline]
     fn last(self) -> Option<A::Item> {
-        let a_last = if self.flag { None } else { self.a.last() };
-        let b_last = self.b.last();
-        b_last.or(a_last)
+        match self.state {
+            ChainState::Both => self.b.last().or(self.a.last()),
+            ChainState::Front => self.a.last(),
+            ChainState::Back => self.b.last()
+        }
     }
 
     #[inline]
@@ -1350,9 +1389,16 @@ impl<A, B> DoubleEndedIterator for Chain<A, B> where
 {
     #[inline]
     fn next_back(&mut self) -> Option<A::Item> {
-        match self.b.next_back() {
-            Some(x) => Some(x),
-            None => self.a.next_back()
+        match self.state {
+            ChainState::Both => match self.b.next_back() {
+                elt @ Some(..) => return elt,
+                None => {
+                    self.state = ChainState::Front;
+                    self.a.next_back()
+                }
+            },
+            ChainState::Front => self.a.next_back(),
+            ChainState::Back => self.b.next_back(),
         }
     }
 }
index ea65c118e5e98a8a5a327150a3f989a0c0d9bead..87e69581c54b39b44184896669dda77a94c98d72 100644 (file)
@@ -729,6 +729,26 @@ fn test_double_ended_chain() {
     assert_eq!(it.next_back().unwrap(), &5);
     assert_eq!(it.next_back().unwrap(), &7);
     assert_eq!(it.next_back(), None);
+
+
+    // test that .chain() is well behaved with an unfused iterator
+    struct CrazyIterator(bool);
+    impl CrazyIterator { fn new() -> CrazyIterator { CrazyIterator(false) } }
+    impl Iterator for CrazyIterator {
+        type Item = i32;
+        fn next(&mut self) -> Option<i32> {
+            if self.0 { Some(99) } else { self.0 = true; None }
+        }
+    }
+
+    impl DoubleEndedIterator for CrazyIterator {
+        fn next_back(&mut self) -> Option<i32> {
+            self.next()
+        }
+    }
+
+    assert_eq!(CrazyIterator::new().chain(0..10).rev().last(), Some(0));
+    assert!((0..10).chain(CrazyIterator::new()).rev().any(|i| i == 0));
 }
 
 #[test]