]> git.lizzy.rs Git - rust.git/commitdiff
BTreeSet symmetric_difference & union optimized, cleaned
authorStein Somers <git@steinsomers.be>
Tue, 8 Oct 2019 23:07:57 +0000 (01:07 +0200)
committerStein Somers <git@steinsomers.be>
Thu, 17 Oct 2019 22:11:32 +0000 (00:11 +0200)
src/liballoc/collections/btree/set.rs
src/liballoc/tests/btree/set.rs

index 8250fc38ccd1c7a0b4d556294b128d5d8d699714..f0796354e00c384076f8a47a59f62b849556cd1e 100644 (file)
@@ -2,7 +2,7 @@
 // to TreeMap
 
 use core::borrow::Borrow;
-use core::cmp::Ordering::{self, Less, Greater, Equal};
+use core::cmp::Ordering::{Less, Greater, Equal};
 use core::cmp::{max, min};
 use core::fmt::{self, Debug};
 use core::iter::{Peekable, FromIterator, FusedIterator};
@@ -109,6 +109,77 @@ pub struct Range<'a, T: 'a> {
     iter: btree_map::Range<'a, T, ()>,
 }
 
+/// Core of SymmetricDifference and Union.
+/// More efficient than btree.map.MergeIter,
+/// and crucially for SymmetricDifference, nexts() reports on both sides.
+#[derive(Clone)]
+struct MergeIterInner<I>
+    where I: Iterator,
+          I::Item: Copy,
+{
+    a: I,
+    b: I,
+    peeked: Option<MergeIterPeeked<I>>,
+}
+
+#[derive(Copy, Clone, Debug)]
+enum MergeIterPeeked<I: Iterator> {
+    A(I::Item),
+    B(I::Item),
+}
+
+impl<I> MergeIterInner<I>
+    where I: ExactSizeIterator + FusedIterator,
+          I::Item: Copy + Ord,
+{
+    fn new(a: I, b: I) -> Self {
+        MergeIterInner { a, b, peeked: None }
+    }
+
+    fn nexts(&mut self) -> (Option<I::Item>, Option<I::Item>) {
+        let mut a_next = match self.peeked {
+            Some(MergeIterPeeked::A(next)) => Some(next),
+            _ => self.a.next(),
+        };
+        let mut b_next = match self.peeked {
+            Some(MergeIterPeeked::B(next)) => Some(next),
+            _ => self.b.next(),
+        };
+        let ord = match (a_next, b_next) {
+            (None, None) => Equal,
+            (_, None) => Less,
+            (None, _) => Greater,
+            (Some(a1), Some(b1)) => a1.cmp(&b1),
+        };
+        self.peeked = match ord {
+            Less => b_next.take().map(MergeIterPeeked::B),
+            Equal => None,
+            Greater => a_next.take().map(MergeIterPeeked::A),
+        };
+        (a_next, b_next)
+    }
+
+    fn lens(&self) -> (usize, usize) {
+        match self.peeked {
+            Some(MergeIterPeeked::A(_)) => (1 + self.a.len(), self.b.len()),
+            Some(MergeIterPeeked::B(_)) => (self.a.len(), 1 + self.b.len()),
+            _ => (self.a.len(), self.b.len()),
+        }
+    }
+}
+
+impl<I> Debug for MergeIterInner<I>
+    where I: Iterator + Debug,
+          I::Item: Copy + Debug,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("MergeIterInner")
+            .field(&self.a)
+            .field(&self.b)
+            .finish()
+    }
+}
+
 /// A lazy iterator producing elements in the difference of `BTreeSet`s.
 ///
 /// This `struct` is created by the [`difference`] method on [`BTreeSet`].
@@ -120,6 +191,7 @@ pub struct Range<'a, T: 'a> {
 pub struct Difference<'a, T: 'a> {
     inner: DifferenceInner<'a, T>,
 }
+#[derive(Debug)]
 enum DifferenceInner<'a, T: 'a> {
     Stitch {
         // iterate all of self and some of other, spotting matches along the way
@@ -137,21 +209,7 @@ enum DifferenceInner<'a, T: 'a> {
 #[stable(feature = "collection_debug", since = "1.17.0")]
 impl<T: fmt::Debug> fmt::Debug for Difference<'_, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match &self.inner {
-            DifferenceInner::Stitch {
-                self_iter,
-                other_iter,
-            } => f
-                .debug_tuple("Difference")
-                .field(&self_iter)
-                .field(&other_iter)
-                .finish(),
-            DifferenceInner::Search {
-                self_iter,
-                other_set: _,
-            } => f.debug_tuple("Difference").field(&self_iter).finish(),
-            DifferenceInner::Iterate(iter) => f.debug_tuple("Difference").field(&iter).finish(),
-        }
+        f.debug_tuple("Difference").field(&self.inner).finish()
     }
 }
 
@@ -163,18 +221,12 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 /// [`BTreeSet`]: struct.BTreeSet.html
 /// [`symmetric_difference`]: struct.BTreeSet.html#method.symmetric_difference
 #[stable(feature = "rust1", since = "1.0.0")]
-pub struct SymmetricDifference<'a, T: 'a> {
-    a: Peekable<Iter<'a, T>>,
-    b: Peekable<Iter<'a, T>>,
-}
+pub struct SymmetricDifference<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);
 
 #[stable(feature = "collection_debug", since = "1.17.0")]
 impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        f.debug_tuple("SymmetricDifference")
-         .field(&self.a)
-         .field(&self.b)
-         .finish()
+        f.debug_tuple("SymmetricDifference").field(&self.0).finish()
     }
 }
 
@@ -189,6 +241,7 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 pub struct Intersection<'a, T: 'a> {
     inner: IntersectionInner<'a, T>,
 }
+#[derive(Debug)]
 enum IntersectionInner<'a, T: 'a> {
     Stitch {
         // iterate similarly sized sets jointly, spotting matches along the way
@@ -206,23 +259,7 @@ enum IntersectionInner<'a, T: 'a> {
 #[stable(feature = "collection_debug", since = "1.17.0")]
 impl<T: fmt::Debug> fmt::Debug for Intersection<'_, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match &self.inner {
-            IntersectionInner::Stitch {
-                a,
-                b,
-            } => f
-                .debug_tuple("Intersection")
-                .field(&a)
-                .field(&b)
-                .finish(),
-            IntersectionInner::Search {
-                small_iter,
-                large_set: _,
-            } => f.debug_tuple("Intersection").field(&small_iter).finish(),
-            IntersectionInner::Answer(answer) => {
-                f.debug_tuple("Intersection").field(&answer).finish()
-            }
-        }
+        f.debug_tuple("Intersection").field(&self.inner).finish()
     }
 }
 
@@ -234,18 +271,12 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 /// [`BTreeSet`]: struct.BTreeSet.html
 /// [`union`]: struct.BTreeSet.html#method.union
 #[stable(feature = "rust1", since = "1.0.0")]
-pub struct Union<'a, T: 'a> {
-    a: Peekable<Iter<'a, T>>,
-    b: Peekable<Iter<'a, T>>,
-}
+pub struct Union<'a, T: 'a>(MergeIterInner<Iter<'a, T>>);
 
 #[stable(feature = "collection_debug", since = "1.17.0")]
 impl<T: fmt::Debug> fmt::Debug for Union<'_, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        f.debug_tuple("Union")
-         .field(&self.a)
-         .field(&self.b)
-         .finish()
+        f.debug_tuple("Union").field(&self.0).finish()
     }
 }
 
@@ -355,19 +386,16 @@ pub fn difference<'a>(&'a self, other: &'a BTreeSet<T>) -> Difference<'a, T> {
                     self_iter.next_back();
                     DifferenceInner::Iterate(self_iter)
                 }
-                _ => {
-                    if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
-                        DifferenceInner::Search {
-                            self_iter: self.iter(),
-                            other_set: other,
-                        }
-                    } else {
-                        DifferenceInner::Stitch {
-                            self_iter: self.iter(),
-                            other_iter: other.iter().peekable(),
-                        }
+                _ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
+                    DifferenceInner::Search {
+                        self_iter: self.iter(),
+                        other_set: other,
                     }
                 }
+                _ => DifferenceInner::Stitch {
+                    self_iter: self.iter(),
+                    other_iter: other.iter().peekable(),
+                },
             },
         }
     }
@@ -396,10 +424,7 @@ pub fn difference<'a>(&'a self, other: &'a BTreeSet<T>) -> Difference<'a, T> {
     pub fn symmetric_difference<'a>(&'a self,
                                     other: &'a BTreeSet<T>)
                                     -> SymmetricDifference<'a, T> {
-        SymmetricDifference {
-            a: self.iter().peekable(),
-            b: other.iter().peekable(),
-        }
+        SymmetricDifference(MergeIterInner::new(self.iter(), other.iter()))
     }
 
     /// Visits the values representing the intersection,
@@ -447,24 +472,22 @@ pub fn intersection<'a>(&'a self, other: &'a BTreeSet<T>) -> Intersection<'a, T>
                 (Greater, _) | (_, Less) => IntersectionInner::Answer(None),
                 (Equal, _) => IntersectionInner::Answer(Some(self_min)),
                 (_, Equal) => IntersectionInner::Answer(Some(self_max)),
-                _ => {
-                    if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
-                        IntersectionInner::Search {
-                            small_iter: self.iter(),
-                            large_set: other,
-                        }
-                    } else if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF {
-                        IntersectionInner::Search {
-                            small_iter: other.iter(),
-                            large_set: self,
-                        }
-                    } else {
-                        IntersectionInner::Stitch {
-                            a: self.iter(),
-                            b: other.iter(),
-                        }
+                _ if self.len() <= other.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
+                    IntersectionInner::Search {
+                        small_iter: self.iter(),
+                        large_set: other,
+                    }
+                }
+                _ if other.len() <= self.len() / ITER_PERFORMANCE_TIPPING_SIZE_DIFF => {
+                    IntersectionInner::Search {
+                        small_iter: other.iter(),
+                        large_set: self,
                     }
                 }
+                _ => IntersectionInner::Stitch {
+                    a: self.iter(),
+                    b: other.iter(),
+                },
             },
         }
     }
@@ -489,10 +512,7 @@ pub fn intersection<'a>(&'a self, other: &'a BTreeSet<T>) -> Intersection<'a, T>
     /// ```
     #[stable(feature = "rust1", since = "1.0.0")]
     pub fn union<'a>(&'a self, other: &'a BTreeSet<T>) -> Union<'a, T> {
-        Union {
-            a: self.iter().peekable(),
-            b: other.iter().peekable(),
-        }
+        Union(MergeIterInner::new(self.iter(), other.iter()))
     }
 
     /// Clears the set, removing all values.
@@ -1166,15 +1186,6 @@ fn next_back(&mut self) -> Option<&'a T> {
 #[stable(feature = "fused", since = "1.26.0")]
 impl<T> FusedIterator for Range<'_, T> {}
 
-/// Compares `x` and `y`, but return `short` if x is None and `long` if y is None
-fn cmp_opt<T: Ord>(x: Option<&T>, y: Option<&T>, short: Ordering, long: Ordering) -> Ordering {
-    match (x, y) {
-        (None, _) => short,
-        (_, None) => long,
-        (Some(x1), Some(y1)) => x1.cmp(y1),
-    }
-}
-
 #[stable(feature = "rust1", since = "1.0.0")]
 impl<T> Clone for Difference<'_, T> {
     fn clone(&self) -> Self {
@@ -1261,10 +1272,7 @@ impl<T: Ord> FusedIterator for Difference<'_, T> {}
 #[stable(feature = "rust1", since = "1.0.0")]
 impl<T> Clone for SymmetricDifference<'_, T> {
     fn clone(&self) -> Self {
-        SymmetricDifference {
-            a: self.a.clone(),
-            b: self.b.clone(),
-        }
+        SymmetricDifference(self.0.clone())
     }
 }
 #[stable(feature = "rust1", since = "1.0.0")]
@@ -1273,19 +1281,19 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {
 
     fn next(&mut self) -> Option<&'a T> {
         loop {
-            match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
-                Less => return self.a.next(),
-                Equal => {
-                    self.a.next();
-                    self.b.next();
-                }
-                Greater => return self.b.next(),
+            let (a_next, b_next) = self.0.nexts();
+            if a_next.and(b_next).is_none() {
+                return a_next.or(b_next);
             }
         }
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
-        (0, Some(self.a.len() + self.b.len()))
+        let (a_len, b_len) = self.0.lens();
+        // No checked_add, because even if a and b refer to the same set,
+        // and T is an empty type, the storage overhead of sets limits
+        // the number of elements to less than half the range of usize.
+        (0, Some(a_len + b_len))
     }
 }
 
@@ -1311,7 +1319,7 @@ fn clone(&self) -> Self {
                     small_iter: small_iter.clone(),
                     large_set,
                 },
-                IntersectionInner::Answer(answer) => IntersectionInner::Answer(answer.clone()),
+                IntersectionInner::Answer(answer) => IntersectionInner::Answer(*answer),
             },
         }
     }
@@ -1365,10 +1373,7 @@ impl<T: Ord> FusedIterator for Intersection<'_, T> {}
 #[stable(feature = "rust1", since = "1.0.0")]
 impl<T> Clone for Union<'_, T> {
     fn clone(&self) -> Self {
-        Union {
-            a: self.a.clone(),
-            b: self.b.clone(),
-        }
+        Union(self.0.clone())
     }
 }
 #[stable(feature = "rust1", since = "1.0.0")]
@@ -1376,19 +1381,13 @@ impl<'a, T: Ord> Iterator for Union<'a, T> {
     type Item = &'a T;
 
     fn next(&mut self) -> Option<&'a T> {
-        match cmp_opt(self.a.peek(), self.b.peek(), Greater, Less) {
-            Less => self.a.next(),
-            Equal => {
-                self.b.next();
-                self.a.next()
-            }
-            Greater => self.b.next(),
-        }
+        let (a_next, b_next) = self.0.nexts();
+        a_next.or(b_next)
     }
 
     fn size_hint(&self) -> (usize, Option<usize>) {
-        let a_len = self.a.len();
-        let b_len = self.b.len();
+        let (a_len, b_len) = self.0.lens();
+        // No checked_add - see SymmetricDifference::size_hint.
         (max(a_len, b_len), Some(a_len + b_len))
     }
 }
index 5c611fd21d21bedf9584bca4119adf662ecd17f6..e4883abc8b56c3362d1fd4726f0d73746f430971 100644 (file)
@@ -221,6 +221,18 @@ fn check_symmetric_difference(a: &[i32], b: &[i32], expected: &[i32]) {
                                &[-2, 1, 5, 11, 14, 22]);
 }
 
+#[test]
+fn test_symmetric_difference_size_hint() {
+    let x: BTreeSet<i32> = [2, 4].iter().copied().collect();
+    let y: BTreeSet<i32> = [1, 2, 3].iter().copied().collect();
+    let mut iter = x.symmetric_difference(&y);
+    assert_eq!(iter.size_hint(), (0, Some(5)));
+    assert_eq!(iter.next(), Some(&1));
+    assert_eq!(iter.size_hint(), (0, Some(4)));
+    assert_eq!(iter.next(), Some(&3));
+    assert_eq!(iter.size_hint(), (0, Some(1)));
+}
+
 #[test]
 fn test_union() {
     fn check_union(a: &[i32], b: &[i32], expected: &[i32]) {
@@ -235,6 +247,18 @@ fn check_union(a: &[i32], b: &[i32], expected: &[i32]) {
                 &[-2, 1, 3, 5, 9, 11, 13, 16, 19, 24]);
 }
 
+#[test]
+fn test_union_size_hint() {
+    let x: BTreeSet<i32> = [2, 4].iter().copied().collect();
+    let y: BTreeSet<i32> = [1, 2, 3].iter().copied().collect();
+    let mut iter = x.union(&y);
+    assert_eq!(iter.size_hint(), (3, Some(5)));
+    assert_eq!(iter.next(), Some(&1));
+    assert_eq!(iter.size_hint(), (2, Some(4)));
+    assert_eq!(iter.next(), Some(&2));
+    assert_eq!(iter.size_hint(), (1, Some(2)));
+}
+
 #[test]
 // Only tests the simple function definition with respect to intersection
 fn test_is_disjoint() {
@@ -244,7 +268,7 @@ fn test_is_disjoint() {
 }
 
 #[test]
-// Also tests the trivial function definition of is_superset
+// Also implicitly tests the trivial function definition of is_superset
 fn test_is_subset() {
     fn is_subset(a: &[i32], b: &[i32]) -> bool {
         let set_a = a.iter().collect::<BTreeSet<_>>();