]> git.lizzy.rs Git - rust.git/commitdiff
Optimize Vec::retain
authoroxalica <oxalicc@pm.me>
Sun, 17 Jan 2021 16:51:37 +0000 (00:51 +0800)
committeroxalica <oxalicc@pm.me>
Sun, 17 Jan 2021 17:48:50 +0000 (01:48 +0800)
library/alloc/src/vec/mod.rs
library/alloc/tests/vec.rs

index ccc4f03a1e5058193c220f209d8c1e18fa28f949..2c510b8b2ae61de27fc422433b60a23574110466 100644 (file)
@@ -1371,21 +1371,78 @@ pub fn retain<F>(&mut self, mut f: F)
         F: FnMut(&T) -> bool,
     {
         let len = self.len();
-        let mut del = 0;
-        {
-            let v = &mut **self;
-
-            for i in 0..len {
-                if !f(&v[i]) {
-                    del += 1;
-                } else if del > 0 {
-                    v.swap(i - del, i);
+        // Avoid double drop if the drop guard is not executed,
+        // since we may make some holes during the process.
+        unsafe { self.set_len(0) };
+
+        // Vec: [Kept, Kept, Hole, Hole, Hole, Hole, Unchecked, Unchecked]
+        //      |<-              processed len   ->| ^- next to check
+        //                  |<-  deleted cnt     ->|
+        //      |<-              original_len                          ->|
+        // Kept: Elements which predicate returns true on.
+        // Hole: Moved or dropped element slot.
+        // Unchecked: Unchecked valid elements.
+        //
+        // This drop guard will be invoked when predicate or `drop` of element panicked.
+        // It shifts unchecked elements to cover holes and `set_len` to the correct length.
+        // In cases when predicate and `drop` never panick, it will be optimized out.
+        struct BackshiftOnDrop<'a, T, A: Allocator> {
+            v: &'a mut Vec<T, A>,
+            processed_len: usize,
+            deleted_cnt: usize,
+            original_len: usize,
+        }
+
+        impl<T, A: Allocator> Drop for BackshiftOnDrop<'_, T, A> {
+            fn drop(&mut self) {
+                if self.deleted_cnt > 0 {
+                    // SAFETY: Fill the hole of dropped or moved
+                    unsafe {
+                        ptr::copy(
+                            self.v.as_ptr().offset(self.processed_len as isize),
+                            self.v
+                                .as_mut_ptr()
+                                .offset(self.processed_len as isize - self.deleted_cnt as isize),
+                            self.original_len - self.processed_len,
+                        );
+                        self.v.set_len(self.original_len - self.deleted_cnt);
+                    }
                 }
             }
         }
-        if del > 0 {
-            self.truncate(len - del);
+
+        let mut guard = BackshiftOnDrop {
+            v: self,
+            processed_len: 0,
+            deleted_cnt: 0,
+            original_len: len,
+        };
+
+        let mut del = 0usize;
+        for i in 0..len {
+            // SAFETY: Unchecked element must be valid.
+            let cur = unsafe { &mut *guard.v.as_mut_ptr().offset(i as isize) };
+            if !f(cur) {
+                del += 1;
+                // Advance early to avoid double drop if `drop_in_place` panicked.
+                guard.processed_len = i + 1;
+                guard.deleted_cnt = del;
+                // SAFETY: We never touch this element again after dropped.
+                unsafe { ptr::drop_in_place(cur) };
+            } else if del > 0 {
+                // SAFETY: `del` > 0 so the hole slot must not overlap with current element.
+                // We use copy for move, and never touch this element again.
+                unsafe {
+                    let hole_slot = guard.v.as_mut_ptr().offset(i as isize - del as isize);
+                    ptr::copy_nonoverlapping(cur, hole_slot, 1);
+                }
+                guard.processed_len = i + 1;
+            }
         }
+
+        // All holes are at the end now. Simply cut them out.
+        unsafe { guard.v.set_len(len - del) };
+        mem::forget(guard);
     }
 
     /// Removes all but the first of consecutive elements in the vector that resolve to the same
index e19406d7a069737a2019de90db0003bad46d6952..856efb1d3a98e0e1bc6329978c2800a20ee0771f 100644 (file)
@@ -287,6 +287,78 @@ fn test_retain() {
     assert_eq!(vec, [2, 4]);
 }
 
+#[test]
+fn test_retain_pred_panic() {
+    use std::sync::atomic::{AtomicU64, Ordering};
+
+    struct Wrap<'a>(&'a AtomicU64, u64, bool);
+
+    impl Drop for Wrap<'_> {
+        fn drop(&mut self) {
+            self.0.fetch_or(self.1, Ordering::SeqCst);
+        }
+    }
+
+    let dropped = AtomicU64::new(0);
+
+    let ret = std::panic::catch_unwind(|| {
+        let mut v = vec![
+            Wrap(&dropped, 1, false),
+            Wrap(&dropped, 2, false),
+            Wrap(&dropped, 4, false),
+            Wrap(&dropped, 8, false),
+            Wrap(&dropped, 16, false),
+        ];
+        v.retain(|w| match w.1 {
+            1 => true,
+            2 => false,
+            4 => true,
+            _ => panic!(),
+        });
+    });
+    assert!(ret.is_err());
+    // Everything is dropped when predicate panicked.
+    assert_eq!(dropped.load(Ordering::SeqCst), 1 | 2 | 4 | 8 | 16);
+}
+
+#[test]
+fn test_retain_drop_panic() {
+    use std::sync::atomic::{AtomicU64, Ordering};
+
+    struct Wrap<'a>(&'a AtomicU64, u64);
+
+    impl Drop for Wrap<'_> {
+        fn drop(&mut self) {
+            if self.1 == 8 {
+                panic!();
+            }
+            self.0.fetch_or(self.1, Ordering::SeqCst);
+        }
+    }
+
+    let dropped = AtomicU64::new(0);
+
+    let ret = std::panic::catch_unwind(|| {
+        let mut v = vec![
+            Wrap(&dropped, 1),
+            Wrap(&dropped, 2),
+            Wrap(&dropped, 4),
+            Wrap(&dropped, 8),
+            Wrap(&dropped, 16),
+        ];
+        v.retain(|w| match w.1 {
+            1 => true,
+            2 => false,
+            4 => true,
+            8 => false,
+            _ => true,
+        });
+    });
+    assert!(ret.is_err());
+    // Other elements are dropped when `drop` of one element panicked.
+    assert_eq!(dropped.load(Ordering::SeqCst), 1 | 2 | 4 | 16);
+}
+
 #[test]
 fn test_dedup() {
     fn case(a: Vec<i32>, b: Vec<i32>) {