]> git.lizzy.rs Git - rust.git/commitdiff
Improve BTreeSet::Intersection::size_hint
authorpcpthm <pcpthm@gmail.com>
Mon, 16 Sep 2019 04:37:52 +0000 (04:37 +0000)
committerpcpthm <pcpthm@gmail.com>
Mon, 16 Sep 2019 04:37:52 +0000 (04:37 +0000)
The commented invariant that an iterator is smaller than other iterator
was violated after next is called and two iterators are consumed at
different rates.

src/liballoc/collections/btree/set.rs
src/liballoc/tests/btree/set.rs

index d3af910a82c27939dab4b03e1190ccad6b842a86..0cb91ba4c81da148c9f89dde57ca042162b7e529 100644 (file)
@@ -3,7 +3,7 @@
 
 use core::borrow::Borrow;
 use core::cmp::Ordering::{self, Less, Greater, Equal};
-use core::cmp::max;
+use core::cmp::{max, min};
 use core::fmt::{self, Debug};
 use core::iter::{Peekable, FromIterator, FusedIterator};
 use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds};
@@ -187,8 +187,8 @@ pub struct Intersection<'a, T: 'a> {
 }
 enum IntersectionInner<'a, T: 'a> {
     Stitch {
-        small_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets
-        other_iter: Iter<'a, T>,
+        a: Iter<'a, T>,
+        b: Iter<'a, T>,
     },
     Search {
         small_iter: Iter<'a, T>,
@@ -201,12 +201,12 @@ impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match &self.inner {
             IntersectionInner::Stitch {
-                small_iter,
-                other_iter,
+                a,
+                b,
             } => f
                 .debug_tuple("Intersection")
-                .field(&small_iter)
-                .field(&other_iter)
+                .field(&a)
+                .field(&b)
                 .finish(),
             IntersectionInner::Search {
                 small_iter,
@@ -397,8 +397,8 @@ pub fn intersection<'a>(&'a self, other: &'a BTreeSet<T>) -> Intersection<'a, T>
             // Iterate both sets jointly, spotting matches along the way.
             Intersection {
                 inner: IntersectionInner::Stitch {
-                    small_iter: small.iter(),
-                    other_iter: other.iter(),
+                    a: small.iter(),
+                    b: other.iter(),
                 },
             }
         } else {
@@ -1221,11 +1221,11 @@ fn clone(&self) -> Self {
         Intersection {
             inner: match &self.inner {
                 IntersectionInner::Stitch {
-                    small_iter,
-                    other_iter,
+                    a,
+                    b,
                 } => IntersectionInner::Stitch {
-                    small_iter: small_iter.clone(),
-                    other_iter: other_iter.clone(),
+                    a: a.clone(),
+                    b: b.clone(),
                 },
                 IntersectionInner::Search {
                     small_iter,
@@ -1245,16 +1245,16 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
     fn next(&mut self) -> Option<&'a T> {
         match &mut self.inner {
             IntersectionInner::Stitch {
-                small_iter,
-                other_iter,
+                a,
+                b,
             } => {
-                let mut small_next = small_iter.next()?;
-                let mut other_next = other_iter.next()?;
+                let mut a_next = a.next()?;
+                let mut b_next = b.next()?;
                 loop {
-                    match Ord::cmp(small_next, other_next) {
-                        Less => small_next = small_iter.next()?,
-                        Greater => other_next = other_iter.next()?,
-                        Equal => return Some(small_next),
+                    match Ord::cmp(a_next, b_next) {
+                        Less => a_next = a.next()?,
+                        Greater => b_next = b.next()?,
+                        Equal => return Some(a_next),
                     }
                 }
             }
@@ -1272,7 +1272,7 @@ fn next(&mut self) -> Option<&'a T> {
 
     fn size_hint(&self) -> (usize, Option<usize>) {
         let min_len = match &self.inner {
-            IntersectionInner::Stitch { small_iter, .. } => small_iter.len(),
+            IntersectionInner::Stitch { a, b } => min(a.len(), b.len()),
             IntersectionInner::Search { small_iter, .. } => small_iter.len(),
         };
         (0, Some(min_len))
index 62ccb53fcea18a936ce3f446489a5de54825adfe..35db18c39c83a705be03ec3b6813c03db8ba06fe 100644 (file)
@@ -90,6 +90,17 @@ fn check_intersection(a: &[i32], b: &[i32], expected: &[i32]) {
                        &[1, 3, 11, 77, 103]);
 }
 
+#[test]
+fn test_intersection_size_hint() {
+    let x: BTreeSet<i32> = [3, 4].iter().copied().collect();
+    let y: BTreeSet<i32> = [1, 2, 3].iter().copied().collect();
+    let mut iter = x.intersection(&y);
+    assert_eq!(iter.size_hint(), (0, Some(2)));
+    assert_eq!(iter.next(), Some(&3));
+    assert_eq!(iter.size_hint(), (0, Some(0)));
+    assert_eq!(iter.next(), None);
+}
+
 #[test]
 fn test_difference() {
     fn check_difference(a: &[i32], b: &[i32], expected: &[i32]) {