]> git.lizzy.rs Git - rust.git/commitdiff
Rollup merge of #68770 - ssomers:btree_drain_filter, r=Amanieu
authorDylan DPC <dylan.dpc@gmail.com>
Tue, 31 Mar 2020 22:27:18 +0000 (00:27 +0200)
committerGitHub <noreply@github.com>
Tue, 31 Mar 2020 22:27:18 +0000 (00:27 +0200)
BTreeMap/BTreeSet: implement drain_filter

Provide an implementation of drain_filter for BTreeMap and BTreeSet. Should be optimal when the predicate picks only elements in leaf nodes with at least MIN_LEN remaining elements, which is a common case, at least when draining only a fraction of the map/set, and also when the predicate picks elements stored in internal nodes where the right subtree can easily let go of a replacement element.

The first commit adds benchmarks with an external, naive implementation. to compare how much this claimed optimality-in-some-cases is actually worth.

src/liballoc/benches/btree/set.rs
src/liballoc/benches/lib.rs
src/liballoc/collections/btree/map.rs
src/liballoc/collections/btree/set.rs
src/liballoc/tests/btree/map.rs
src/liballoc/tests/btree/set.rs
src/liballoc/tests/lib.rs

index d9e75ab7fa4ef6cefd68e85afea59bae35e149ad..2518506b9b5f3915996f9bc84fbc0f688c1e05af 100644 (file)
@@ -62,6 +62,22 @@ pub fn clone_100_and_clear(b: &mut Bencher) {
     b.iter(|| src.clone().clear())
 }
 
+#[bench]
+pub fn clone_100_and_drain_all(b: &mut Bencher) {
+    let src = pos(100);
+    b.iter(|| src.clone().drain_filter(|_| true).count())
+}
+
+#[bench]
+pub fn clone_100_and_drain_half(b: &mut Bencher) {
+    let src = pos(100);
+    b.iter(|| {
+        let mut set = src.clone();
+        assert_eq!(set.drain_filter(|i| i % 2 == 0).count(), 100 / 2);
+        assert_eq!(set.len(), 100 / 2);
+    })
+}
+
 #[bench]
 pub fn clone_100_and_into_iter(b: &mut Bencher) {
     let src = pos(100);
@@ -115,6 +131,22 @@ pub fn clone_10k_and_clear(b: &mut Bencher) {
     b.iter(|| src.clone().clear())
 }
 
+#[bench]
+pub fn clone_10k_and_drain_all(b: &mut Bencher) {
+    let src = pos(10_000);
+    b.iter(|| src.clone().drain_filter(|_| true).count())
+}
+
+#[bench]
+pub fn clone_10k_and_drain_half(b: &mut Bencher) {
+    let src = pos(10_000);
+    b.iter(|| {
+        let mut set = src.clone();
+        assert_eq!(set.drain_filter(|i| i % 2 == 0).count(), 10_000 / 2);
+        assert_eq!(set.len(), 10_000 / 2);
+    })
+}
+
 #[bench]
 pub fn clone_10k_and_into_iter(b: &mut Bencher) {
     let src = pos(10_000);
index 951477a24c8ed3453c655a72a9903ffb87b9933e..f31717d9fd517e76b860364f430e7170d2194822 100644 (file)
@@ -1,3 +1,4 @@
+#![feature(btree_drain_filter)]
 #![feature(map_first_last)]
 #![feature(repr_simd)]
 #![feature(test)]
index bde66c406af7f20fd64ae7cbbcf1c3b6f449f448..bbeced1751d1451830e5c3673f1489a6b6953de3 100644 (file)
@@ -1256,6 +1256,48 @@ pub fn split_off<Q: ?Sized + Ord>(&mut self, key: &Q) -> Self
         right
     }
 
+    /// Creates an iterator which uses a closure to determine if an element should be removed.
+    ///
+    /// If the closure returns true, the element is removed from the map and yielded.
+    /// If the closure returns false, or panics, the element remains in the map and will not be
+    /// yielded.
+    ///
+    /// Note that `drain_filter` lets you mutate every value in the filter closure, regardless of
+    /// whether you choose to keep or remove it.
+    ///
+    /// If the iterator is only partially consumed or not consumed at all, each of the remaining
+    /// elements will still be subjected to the closure and removed and dropped if it returns true.
+    ///
+    /// It is unspecified how many more elements will be subjected to the closure
+    /// if a panic occurs in the closure, or a panic occurs while dropping an element,
+    /// or if the `DrainFilter` value is leaked.
+    ///
+    /// # Examples
+    ///
+    /// Splitting a map into even and odd keys, reusing the original map:
+    ///
+    /// ```
+    /// #![feature(btree_drain_filter)]
+    /// use std::collections::BTreeMap;
+    ///
+    /// let mut map: BTreeMap<i32, i32> = (0..8).map(|x| (x, x)).collect();
+    /// let evens: BTreeMap<_, _> = map.drain_filter(|k, _v| k % 2 == 0).collect();
+    /// let odds = map;
+    /// assert_eq!(evens.keys().copied().collect::<Vec<_>>(), vec![0, 2, 4, 6]);
+    /// assert_eq!(odds.keys().copied().collect::<Vec<_>>(), vec![1, 3, 5, 7]);
+    /// ```
+    #[unstable(feature = "btree_drain_filter", issue = "70530")]
+    pub fn drain_filter<F>(&mut self, pred: F) -> DrainFilter<'_, K, V, F>
+    where
+        F: FnMut(&K, &mut V) -> bool,
+    {
+        DrainFilter { pred, inner: self.drain_filter_inner() }
+    }
+    pub(super) fn drain_filter_inner(&mut self) -> DrainFilterInner<'_, K, V> {
+        let front = self.root.as_mut().map(|r| r.as_mut().first_leaf_edge());
+        DrainFilterInner { length: &mut self.length, cur_leaf_edge: front }
+    }
+
     /// Calculates the number of elements if it is incorrect.
     fn recalc_length(&mut self) {
         fn dfs<'a, K, V>(node: NodeRef<marker::Immut<'a>, K, V, marker::LeafOrInternal>) -> usize
@@ -1653,6 +1695,124 @@ fn clone(&self) -> Self {
     }
 }
 
+/// An iterator produced by calling `drain_filter` on BTreeMap.
+#[unstable(feature = "btree_drain_filter", issue = "70530")]
+pub struct DrainFilter<'a, K, V, F>
+where
+    K: 'a + Ord, // This Ord bound should be removed before stabilization.
+    V: 'a,
+    F: 'a + FnMut(&K, &mut V) -> bool,
+{
+    pred: F,
+    inner: DrainFilterInner<'a, K, V>,
+}
+pub(super) struct DrainFilterInner<'a, K, V>
+where
+    K: 'a + Ord,
+    V: 'a,
+{
+    length: &'a mut usize,
+    cur_leaf_edge: Option<Handle<NodeRef<marker::Mut<'a>, K, V, marker::Leaf>, marker::Edge>>,
+}
+
+#[unstable(feature = "btree_drain_filter", issue = "70530")]
+impl<'a, K, V, F> Drop for DrainFilter<'a, K, V, F>
+where
+    K: 'a + Ord,
+    V: 'a,
+    F: 'a + FnMut(&K, &mut V) -> bool,
+{
+    fn drop(&mut self) {
+        self.for_each(drop);
+    }
+}
+
+#[unstable(feature = "btree_drain_filter", issue = "70530")]
+impl<'a, K, V, F> fmt::Debug for DrainFilter<'a, K, V, F>
+where
+    K: 'a + fmt::Debug + Ord,
+    V: 'a + fmt::Debug,
+    F: 'a + FnMut(&K, &mut V) -> bool,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("DrainFilter").field(&self.inner.peek()).finish()
+    }
+}
+
+#[unstable(feature = "btree_drain_filter", issue = "70530")]
+impl<'a, K, V, F> Iterator for DrainFilter<'a, K, V, F>
+where
+    K: 'a + Ord,
+    V: 'a,
+    F: 'a + FnMut(&K, &mut V) -> bool,
+{
+    type Item = (K, V);
+
+    fn next(&mut self) -> Option<(K, V)> {
+        self.inner.next(&mut self.pred)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.inner.size_hint()
+    }
+}
+
+impl<'a, K, V> DrainFilterInner<'a, K, V>
+where
+    K: 'a + Ord,
+    V: 'a,
+{
+    /// Allow Debug implementations to predict the next element.
+    pub(super) fn peek(&self) -> Option<(&K, &V)> {
+        let edge = self.cur_leaf_edge.as_ref()?;
+        edge.reborrow().next_kv().ok().map(|kv| kv.into_kv())
+    }
+
+    unsafe fn next_kv(
+        &mut self,
+    ) -> Option<Handle<NodeRef<marker::Mut<'a>, K, V, marker::LeafOrInternal>, marker::KV>> {
+        let edge = self.cur_leaf_edge.as_ref()?;
+        ptr::read(edge).next_kv().ok()
+    }
+
+    /// Implementation of a typical `DrainFilter::next` method, given the predicate.
+    pub(super) fn next<F>(&mut self, pred: &mut F) -> Option<(K, V)>
+    where
+        F: FnMut(&K, &mut V) -> bool,
+    {
+        while let Some(kv) = unsafe { self.next_kv() } {
+            let (k, v) = unsafe { ptr::read(&kv) }.into_kv_mut();
+            if pred(k, v) {
+                *self.length -= 1;
+                let (k, v, leaf_edge_location) = kv.remove_kv_tracking();
+                // `remove_kv_tracking` has either preserved or invalidated `self.cur_leaf_edge`
+                if let Some(node) = leaf_edge_location {
+                    match search::search_tree(node, &k) {
+                        search::SearchResult::Found(_) => unreachable!(),
+                        search::SearchResult::GoDown(leaf) => self.cur_leaf_edge = Some(leaf),
+                    }
+                };
+                return Some((k, v));
+            }
+            self.cur_leaf_edge = Some(kv.next_leaf_edge());
+        }
+        None
+    }
+
+    /// Implementation of a typical `DrainFilter::size_hint` method.
+    pub(super) fn size_hint(&self) -> (usize, Option<usize>) {
+        (0, Some(*self.length))
+    }
+}
+
+#[unstable(feature = "btree_drain_filter", issue = "70530")]
+impl<K, V, F> FusedIterator for DrainFilter<'_, K, V, F>
+where
+    K: Ord,
+    F: FnMut(&K, &mut V) -> bool,
+{
+}
+
 #[stable(feature = "btree_range", since = "1.17.0")]
 impl<'a, K, V> Iterator for Range<'a, K, V> {
     type Item = (&'a K, &'a V);
@@ -2531,12 +2691,31 @@ pub fn remove(self) -> V {
     fn remove_kv(self) -> (K, V) {
         *self.length -= 1;
 
-        let (small_leaf, old_key, old_val) = match self.handle.force() {
+        let (old_key, old_val, _) = self.handle.remove_kv_tracking();
+        (old_key, old_val)
+    }
+}
+
+impl<'a, K: 'a, V: 'a> Handle<NodeRef<marker::Mut<'a>, K, V, marker::LeafOrInternal>, marker::KV> {
+    /// Removes a key/value-pair from the map, and returns that pair, as well as
+    /// the whereabouts of the leaf edge corresponding to that former pair:
+    /// if None is returned, the leaf edge is still the left leaf edge of the KV handle;
+    /// if a node is returned, it heads the subtree where the leaf edge may be found.
+    fn remove_kv_tracking(
+        self,
+    ) -> (K, V, Option<NodeRef<marker::Mut<'a>, K, V, marker::LeafOrInternal>>) {
+        let mut levels_down_handled: isize;
+        let (small_leaf, old_key, old_val) = match self.force() {
             Leaf(leaf) => {
+                levels_down_handled = 1; // handled at same level, but affects only the right side
                 let (hole, old_key, old_val) = leaf.remove();
                 (hole.into_node(), old_key, old_val)
             }
             Internal(mut internal) => {
+                // Replace the location freed in the internal node with the next KV,
+                // and remove that next KV from its leaf.
+                levels_down_handled = unsafe { ptr::read(&internal).into_node().height() } as isize;
+
                 let key_loc = internal.kv_mut().0 as *mut K;
                 let val_loc = internal.kv_mut().1 as *mut V;
 
@@ -2556,27 +2735,39 @@ fn remove_kv(self) -> (K, V) {
         let mut cur_node = small_leaf.forget_type();
         while cur_node.len() < node::MIN_LEN {
             match handle_underfull_node(cur_node) {
-                AtRoot => break,
+                AtRoot(root) => {
+                    cur_node = root;
+                    break;
+                }
                 EmptyParent(_) => unreachable!(),
                 Merged(parent) => {
+                    levels_down_handled -= 1;
                     if parent.len() == 0 {
                         // We must be at the root
-                        parent.into_root_mut().pop_level();
+                        let root = parent.into_root_mut();
+                        root.pop_level();
+                        cur_node = root.as_mut();
                         break;
                     } else {
                         cur_node = parent.forget_type();
                     }
                 }
-                Stole(_) => break,
+                Stole(internal_node) => {
+                    levels_down_handled -= 1;
+                    cur_node = internal_node.forget_type();
+                    // This internal node might be underfull, but only if it's the root.
+                    break;
+                }
             }
         }
 
-        (old_key, old_val)
+        let leaf_edge_location = if levels_down_handled > 0 { None } else { Some(cur_node) };
+        (old_key, old_val, leaf_edge_location)
     }
 }
 
 enum UnderflowResult<'a, K, V> {
-    AtRoot,
+    AtRoot(NodeRef<marker::Mut<'a>, K, V, marker::LeafOrInternal>),
     EmptyParent(NodeRef<marker::Mut<'a>, K, V, marker::Internal>),
     Merged(NodeRef<marker::Mut<'a>, K, V, marker::Internal>),
     Stole(NodeRef<marker::Mut<'a>, K, V, marker::Internal>),
@@ -2585,10 +2776,9 @@ enum UnderflowResult<'a, K, V> {
 fn handle_underfull_node<K, V>(
     node: NodeRef<marker::Mut<'_>, K, V, marker::LeafOrInternal>,
 ) -> UnderflowResult<'_, K, V> {
-    let parent = if let Ok(parent) = node.ascend() {
-        parent
-    } else {
-        return AtRoot;
+    let parent = match node.ascend() {
+        Ok(parent) => parent,
+        Err(root) => return AtRoot(root),
     };
 
     let (is_left, mut handle) = match parent.left_kv() {
index b100ce754caad589b92089d51c9c89e087faba4c..0b02223def4f85dc76c4b894e2cd5c94d222f69a 100644 (file)
@@ -8,8 +8,8 @@
 use core::iter::{FromIterator, FusedIterator, Peekable};
 use core::ops::{BitAnd, BitOr, BitXor, RangeBounds, Sub};
 
+use super::map::{BTreeMap, Keys};
 use super::Recover;
-use crate::collections::btree_map::{self, BTreeMap, Keys};
 
 // FIXME(conventions): implement bounded iterators
 
@@ -102,7 +102,7 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 #[stable(feature = "rust1", since = "1.0.0")]
 #[derive(Debug)]
 pub struct IntoIter<T> {
-    iter: btree_map::IntoIter<T, ()>,
+    iter: super::map::IntoIter<T, ()>,
 }
 
 /// An iterator over a sub-range of items in a `BTreeSet`.
@@ -115,7 +115,7 @@ pub struct IntoIter<T> {
 #[derive(Debug)]
 #[stable(feature = "btree_range", since = "1.17.0")]
 pub struct Range<'a, T: 'a> {
-    iter: btree_map::Range<'a, T, ()>,
+    iter: super::map::Range<'a, T, ()>,
 }
 
 /// Core of SymmetricDifference and Union.
@@ -944,6 +944,41 @@ pub fn split_off<Q: ?Sized + Ord>(&mut self, key: &Q) -> Self
     {
         BTreeSet { map: self.map.split_off(key) }
     }
+
+    /// Creates an iterator which uses a closure to determine if a value should be removed.
+    ///
+    /// If the closure returns true, then the value is removed and yielded.
+    /// If the closure returns false, the value will remain in the list and will not be yielded
+    /// by the iterator.
+    ///
+    /// If the iterator is only partially consumed or not consumed at all, each of the remaining
+    /// values will still be subjected to the closure and removed and dropped if it returns true.
+    ///
+    /// It is unspecified how many more values will be subjected to the closure
+    /// if a panic occurs in the closure, or if a panic occurs while dropping a value, or if the
+    /// `DrainFilter` itself is leaked.
+    ///
+    /// # Examples
+    ///
+    /// Splitting a set into even and odd values, reusing the original set:
+    ///
+    /// ```
+    /// #![feature(btree_drain_filter)]
+    /// use std::collections::BTreeSet;
+    ///
+    /// let mut set: BTreeSet<i32> = (0..8).collect();
+    /// let evens: BTreeSet<_> = set.drain_filter(|v| v % 2 == 0).collect();
+    /// let odds = set;
+    /// assert_eq!(evens.into_iter().collect::<Vec<_>>(), vec![0, 2, 4, 6]);
+    /// assert_eq!(odds.into_iter().collect::<Vec<_>>(), vec![1, 3, 5, 7]);
+    /// ```
+    #[unstable(feature = "btree_drain_filter", issue = "70530")]
+    pub fn drain_filter<'a, F>(&'a mut self, pred: F) -> DrainFilter<'a, T, F>
+    where
+        F: 'a + FnMut(&T) -> bool,
+    {
+        DrainFilter { pred, inner: self.map.drain_filter_inner() }
+    }
 }
 
 impl<T> BTreeSet<T> {
@@ -1055,6 +1090,66 @@ fn into_iter(self) -> Iter<'a, T> {
     }
 }
 
+/// An iterator produced by calling `drain_filter` on BTreeSet.
+#[unstable(feature = "btree_drain_filter", issue = "70530")]
+pub struct DrainFilter<'a, T, F>
+where
+    T: 'a + Ord,
+    F: 'a + FnMut(&T) -> bool,
+{
+    pred: F,
+    inner: super::map::DrainFilterInner<'a, T, ()>,
+}
+
+#[unstable(feature = "btree_drain_filter", issue = "70530")]
+impl<'a, T, F> Drop for DrainFilter<'a, T, F>
+where
+    T: 'a + Ord,
+    F: 'a + FnMut(&T) -> bool,
+{
+    fn drop(&mut self) {
+        self.for_each(drop);
+    }
+}
+
+#[unstable(feature = "btree_drain_filter", issue = "70530")]
+impl<'a, T, F> fmt::Debug for DrainFilter<'a, T, F>
+where
+    T: 'a + Ord + fmt::Debug,
+    F: 'a + FnMut(&T) -> bool,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("DrainFilter").field(&self.inner.peek().map(|(k, _)| k)).finish()
+    }
+}
+
+#[unstable(feature = "btree_drain_filter", issue = "70530")]
+impl<'a, 'f, T, F> Iterator for DrainFilter<'a, T, F>
+where
+    T: 'a + Ord,
+    F: 'a + 'f + FnMut(&T) -> bool,
+{
+    type Item = T;
+
+    fn next(&mut self) -> Option<T> {
+        let pred = &mut self.pred;
+        let mut mapped_pred = |k: &T, _v: &mut ()| pred(k);
+        self.inner.next(&mut mapped_pred).map(|(k, _)| k)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.inner.size_hint()
+    }
+}
+
+#[unstable(feature = "btree_drain_filter", issue = "70530")]
+impl<'a, T, F> FusedIterator for DrainFilter<'a, T, F>
+where
+    T: 'a + Ord,
+    F: 'a + FnMut(&T) -> bool,
+{
+}
+
 #[stable(feature = "rust1", since = "1.0.0")]
 impl<T: Ord> Extend<T> for BTreeSet<T> {
     #[inline]
index e28b71510ce71c8aa2785f33a46d796f7da3585a..14f12ca2d779a453881489ade0410bfc55c05023 100644 (file)
@@ -5,7 +5,7 @@
 use std::iter::FromIterator;
 use std::ops::Bound::{self, Excluded, Included, Unbounded};
 use std::ops::RangeBounds;
-use std::panic::catch_unwind;
+use std::panic::{catch_unwind, AssertUnwindSafe};
 use std::rc::Rc;
 use std::sync::atomic::{AtomicUsize, Ordering};
 
@@ -609,6 +609,263 @@ fn test_range_mut() {
     }
 }
 
+mod test_drain_filter {
+    use super::*;
+
+    #[test]
+    fn empty() {
+        let mut map: BTreeMap<i32, i32> = BTreeMap::new();
+        map.drain_filter(|_, _| unreachable!("there's nothing to decide on"));
+        assert!(map.is_empty());
+    }
+
+    #[test]
+    fn consuming_nothing() {
+        let pairs = (0..3).map(|i| (i, i));
+        let mut map: BTreeMap<_, _> = pairs.collect();
+        assert!(map.drain_filter(|_, _| false).eq(std::iter::empty()));
+    }
+
+    #[test]
+    fn consuming_all() {
+        let pairs = (0..3).map(|i| (i, i));
+        let mut map: BTreeMap<_, _> = pairs.clone().collect();
+        assert!(map.drain_filter(|_, _| true).eq(pairs));
+    }
+
+    #[test]
+    fn mutating_and_keeping() {
+        let pairs = (0..3).map(|i| (i, i));
+        let mut map: BTreeMap<_, _> = pairs.collect();
+        assert!(
+            map.drain_filter(|_, v| {
+                *v += 6;
+                false
+            })
+            .eq(std::iter::empty())
+        );
+        assert!(map.keys().copied().eq(0..3));
+        assert!(map.values().copied().eq(6..9));
+    }
+
+    #[test]
+    fn mutating_and_removing() {
+        let pairs = (0..3).map(|i| (i, i));
+        let mut map: BTreeMap<_, _> = pairs.collect();
+        assert!(
+            map.drain_filter(|_, v| {
+                *v += 6;
+                true
+            })
+            .eq((0..3).map(|i| (i, i + 6)))
+        );
+        assert!(map.is_empty());
+    }
+
+    #[test]
+    fn underfull_keeping_all() {
+        let pairs = (0..3).map(|i| (i, i));
+        let mut map: BTreeMap<_, _> = pairs.collect();
+        map.drain_filter(|_, _| false);
+        assert!(map.keys().copied().eq(0..3));
+    }
+
+    #[test]
+    fn underfull_removing_one() {
+        let pairs = (0..3).map(|i| (i, i));
+        for doomed in 0..3 {
+            let mut map: BTreeMap<_, _> = pairs.clone().collect();
+            map.drain_filter(|i, _| *i == doomed);
+            assert_eq!(map.len(), 2);
+        }
+    }
+
+    #[test]
+    fn underfull_keeping_one() {
+        let pairs = (0..3).map(|i| (i, i));
+        for sacred in 0..3 {
+            let mut map: BTreeMap<_, _> = pairs.clone().collect();
+            map.drain_filter(|i, _| *i != sacred);
+            assert!(map.keys().copied().eq(sacred..=sacred));
+        }
+    }
+
+    #[test]
+    fn underfull_removing_all() {
+        let pairs = (0..3).map(|i| (i, i));
+        let mut map: BTreeMap<_, _> = pairs.collect();
+        map.drain_filter(|_, _| true);
+        assert!(map.is_empty());
+    }
+
+    #[test]
+    fn height_0_keeping_all() {
+        let pairs = (0..NODE_CAPACITY).map(|i| (i, i));
+        let mut map: BTreeMap<_, _> = pairs.collect();
+        map.drain_filter(|_, _| false);
+        assert!(map.keys().copied().eq(0..NODE_CAPACITY));
+    }
+
+    #[test]
+    fn height_0_removing_one() {
+        let pairs = (0..NODE_CAPACITY).map(|i| (i, i));
+        for doomed in 0..NODE_CAPACITY {
+            let mut map: BTreeMap<_, _> = pairs.clone().collect();
+            map.drain_filter(|i, _| *i == doomed);
+            assert_eq!(map.len(), NODE_CAPACITY - 1);
+        }
+    }
+
+    #[test]
+    fn height_0_keeping_one() {
+        let pairs = (0..NODE_CAPACITY).map(|i| (i, i));
+        for sacred in 0..NODE_CAPACITY {
+            let mut map: BTreeMap<_, _> = pairs.clone().collect();
+            map.drain_filter(|i, _| *i != sacred);
+            assert!(map.keys().copied().eq(sacred..=sacred));
+        }
+    }
+
+    #[test]
+    fn height_0_removing_all() {
+        let pairs = (0..NODE_CAPACITY).map(|i| (i, i));
+        let mut map: BTreeMap<_, _> = pairs.collect();
+        map.drain_filter(|_, _| true);
+        assert!(map.is_empty());
+    }
+
+    #[test]
+    fn height_0_keeping_half() {
+        let mut map: BTreeMap<_, _> = (0..16).map(|i| (i, i)).collect();
+        assert_eq!(map.drain_filter(|i, _| *i % 2 == 0).count(), 8);
+        assert_eq!(map.len(), 8);
+    }
+
+    #[test]
+    fn height_1_removing_all() {
+        let pairs = (0..MIN_INSERTS_HEIGHT_1).map(|i| (i, i));
+        let mut map: BTreeMap<_, _> = pairs.collect();
+        map.drain_filter(|_, _| true);
+        assert!(map.is_empty());
+    }
+
+    #[test]
+    fn height_1_removing_one() {
+        let pairs = (0..MIN_INSERTS_HEIGHT_1).map(|i| (i, i));
+        for doomed in 0..MIN_INSERTS_HEIGHT_1 {
+            let mut map: BTreeMap<_, _> = pairs.clone().collect();
+            map.drain_filter(|i, _| *i == doomed);
+            assert_eq!(map.len(), MIN_INSERTS_HEIGHT_1 - 1);
+        }
+    }
+
+    #[test]
+    fn height_1_keeping_one() {
+        let pairs = (0..MIN_INSERTS_HEIGHT_1).map(|i| (i, i));
+        for sacred in 0..MIN_INSERTS_HEIGHT_1 {
+            let mut map: BTreeMap<_, _> = pairs.clone().collect();
+            map.drain_filter(|i, _| *i != sacred);
+            assert!(map.keys().copied().eq(sacred..=sacred));
+        }
+    }
+
+    #[cfg(not(miri))] // Miri is too slow
+    #[test]
+    fn height_2_removing_one() {
+        let pairs = (0..MIN_INSERTS_HEIGHT_2).map(|i| (i, i));
+        for doomed in (0..MIN_INSERTS_HEIGHT_2).step_by(12) {
+            let mut map: BTreeMap<_, _> = pairs.clone().collect();
+            map.drain_filter(|i, _| *i == doomed);
+            assert_eq!(map.len(), MIN_INSERTS_HEIGHT_2 - 1);
+        }
+    }
+
+    #[cfg(not(miri))] // Miri is too slow
+    #[test]
+    fn height_2_keeping_one() {
+        let pairs = (0..MIN_INSERTS_HEIGHT_2).map(|i| (i, i));
+        for sacred in (0..MIN_INSERTS_HEIGHT_2).step_by(12) {
+            let mut map: BTreeMap<_, _> = pairs.clone().collect();
+            map.drain_filter(|i, _| *i != sacred);
+            assert!(map.keys().copied().eq(sacred..=sacred));
+        }
+    }
+
+    #[test]
+    fn height_2_removing_all() {
+        let pairs = (0..MIN_INSERTS_HEIGHT_2).map(|i| (i, i));
+        let mut map: BTreeMap<_, _> = pairs.collect();
+        map.drain_filter(|_, _| true);
+        assert!(map.is_empty());
+    }
+
+    #[test]
+    fn drop_panic_leak() {
+        static PREDS: AtomicUsize = AtomicUsize::new(0);
+        static DROPS: AtomicUsize = AtomicUsize::new(0);
+
+        struct D;
+        impl Drop for D {
+            fn drop(&mut self) {
+                if DROPS.fetch_add(1, Ordering::SeqCst) == 1 {
+                    panic!("panic in `drop`");
+                }
+            }
+        }
+
+        let mut map = BTreeMap::new();
+        map.insert(0, D);
+        map.insert(4, D);
+        map.insert(8, D);
+
+        catch_unwind(move || {
+            drop(map.drain_filter(|i, _| {
+                PREDS.fetch_add(1usize << i, Ordering::SeqCst);
+                true
+            }))
+        })
+        .ok();
+
+        assert_eq!(PREDS.load(Ordering::SeqCst), 0x011);
+        assert_eq!(DROPS.load(Ordering::SeqCst), 3);
+    }
+
+    #[test]
+    fn pred_panic_leak() {
+        static PREDS: AtomicUsize = AtomicUsize::new(0);
+        static DROPS: AtomicUsize = AtomicUsize::new(0);
+
+        struct D;
+        impl Drop for D {
+            fn drop(&mut self) {
+                DROPS.fetch_add(1, Ordering::SeqCst);
+            }
+        }
+
+        let mut map = BTreeMap::new();
+        map.insert(0, D);
+        map.insert(4, D);
+        map.insert(8, D);
+
+        catch_unwind(AssertUnwindSafe(|| {
+            drop(map.drain_filter(|i, _| {
+                PREDS.fetch_add(1usize << i, Ordering::SeqCst);
+                match i {
+                    0 => true,
+                    _ => panic!(),
+                }
+            }))
+        }))
+        .ok();
+
+        assert_eq!(PREDS.load(Ordering::SeqCst), 0x011);
+        assert_eq!(DROPS.load(Ordering::SeqCst), 1);
+        assert_eq!(map.len(), 2);
+        assert_eq!(map.first_entry().unwrap().key(), &4);
+        assert_eq!(map.last_entry().unwrap().key(), &8);
+    }
+}
+
 #[test]
 fn test_borrow() {
     // make sure these compile -- using the Borrow trait
index 1a2b62d026b2ec969a38ef9186e9eff027a26eeb..136018b9f7df53e1f44985883524fca60aac293f 100644 (file)
@@ -1,5 +1,7 @@
 use std::collections::BTreeSet;
 use std::iter::FromIterator;
+use std::panic::{catch_unwind, AssertUnwindSafe};
+use std::sync::atomic::{AtomicU32, Ordering};
 
 use super::DeterministicRng;
 
@@ -302,6 +304,85 @@ fn is_subset(a: &[i32], b: &[i32]) -> bool {
     assert_eq!(is_subset(&[99, 100], &large), false);
 }
 
+#[test]
+fn test_drain_filter() {
+    let mut x: BTreeSet<_> = [1].iter().copied().collect();
+    let mut y: BTreeSet<_> = [1].iter().copied().collect();
+
+    x.drain_filter(|_| true);
+    y.drain_filter(|_| false);
+    assert_eq!(x.len(), 0);
+    assert_eq!(y.len(), 1);
+}
+
+#[test]
+fn test_drain_filter_drop_panic_leak() {
+    static PREDS: AtomicU32 = AtomicU32::new(0);
+    static DROPS: AtomicU32 = AtomicU32::new(0);
+
+    #[derive(PartialEq, Eq, PartialOrd, Ord)]
+    struct D(i32);
+    impl Drop for D {
+        fn drop(&mut self) {
+            if DROPS.fetch_add(1, Ordering::SeqCst) == 1 {
+                panic!("panic in `drop`");
+            }
+        }
+    }
+
+    let mut set = BTreeSet::new();
+    set.insert(D(0));
+    set.insert(D(4));
+    set.insert(D(8));
+
+    catch_unwind(move || {
+        drop(set.drain_filter(|d| {
+            PREDS.fetch_add(1u32 << d.0, Ordering::SeqCst);
+            true
+        }))
+    })
+    .ok();
+
+    assert_eq!(PREDS.load(Ordering::SeqCst), 0x011);
+    assert_eq!(DROPS.load(Ordering::SeqCst), 3);
+}
+
+#[test]
+fn test_drain_filter_pred_panic_leak() {
+    static PREDS: AtomicU32 = AtomicU32::new(0);
+    static DROPS: AtomicU32 = AtomicU32::new(0);
+
+    #[derive(PartialEq, Eq, PartialOrd, Ord)]
+    struct D(i32);
+    impl Drop for D {
+        fn drop(&mut self) {
+            DROPS.fetch_add(1, Ordering::SeqCst);
+        }
+    }
+
+    let mut set = BTreeSet::new();
+    set.insert(D(0));
+    set.insert(D(4));
+    set.insert(D(8));
+
+    catch_unwind(AssertUnwindSafe(|| {
+        drop(set.drain_filter(|d| {
+            PREDS.fetch_add(1u32 << d.0, Ordering::SeqCst);
+            match d.0 {
+                0 => true,
+                _ => panic!(),
+            }
+        }))
+    }))
+    .ok();
+
+    assert_eq!(PREDS.load(Ordering::SeqCst), 0x011);
+    assert_eq!(DROPS.load(Ordering::SeqCst), 1);
+    assert_eq!(set.len(), 2);
+    assert_eq!(set.first().unwrap().0, 4);
+    assert_eq!(set.last().unwrap().0, 8);
+}
+
 #[test]
 fn test_clear() {
     let mut x = BTreeSet::new();
index ea75f8903c3685fdd73203a18a810889c2e49615..ad6feaeebc67f884208e0fd4721935579fd8f662 100644 (file)
@@ -1,5 +1,6 @@
 #![feature(allocator_api)]
 #![feature(box_syntax)]
+#![feature(btree_drain_filter)]
 #![feature(drain_filter)]
 #![feature(exact_size_is_empty)]
 #![feature(map_first_last)]