F: FnMut(&T) -> bool,
{
let len = self.len();
- let mut del = 0;
- {
- let v = &mut **self;
-
- for i in 0..len {
- if !f(&v[i]) {
- del += 1;
- } else if del > 0 {
- v.swap(i - del, i);
+ // Avoid double drop if the drop guard is not executed,
+ // since we may make some holes during the process.
+ unsafe { self.set_len(0) };
+
+ // Vec: [Kept, Kept, Hole, Hole, Hole, Hole, Unchecked, Unchecked]
+ // |<- processed len ->| ^- next to check
+ // |<- deleted cnt ->|
+ // |<- original_len ->|
+ // Kept: Elements which predicate returns true on.
+ // Hole: Moved or dropped element slot.
+ // Unchecked: Unchecked valid elements.
+ //
+ // This drop guard will be invoked when predicate or `drop` of element panicked.
+ // It shifts unchecked elements to cover holes and `set_len` to the correct length.
+ // In cases when predicate and `drop` never panick, it will be optimized out.
+ struct BackshiftOnDrop<'a, T, A: Allocator> {
+ v: &'a mut Vec<T, A>,
+ processed_len: usize,
+ deleted_cnt: usize,
+ original_len: usize,
+ }
+
+ impl<T, A: Allocator> Drop for BackshiftOnDrop<'_, T, A> {
+ fn drop(&mut self) {
+ if self.deleted_cnt > 0 {
+ // SAFETY: Fill the hole of dropped or moved
+ unsafe {
+ ptr::copy(
+ self.v.as_ptr().offset(self.processed_len as isize),
+ self.v
+ .as_mut_ptr()
+ .offset(self.processed_len as isize - self.deleted_cnt as isize),
+ self.original_len - self.processed_len,
+ );
+ self.v.set_len(self.original_len - self.deleted_cnt);
+ }
}
}
}
- if del > 0 {
- self.truncate(len - del);
+
+ let mut guard = BackshiftOnDrop {
+ v: self,
+ processed_len: 0,
+ deleted_cnt: 0,
+ original_len: len,
+ };
+
+ let mut del = 0usize;
+ for i in 0..len {
+ // SAFETY: Unchecked element must be valid.
+ let cur = unsafe { &mut *guard.v.as_mut_ptr().offset(i as isize) };
+ if !f(cur) {
+ del += 1;
+ // Advance early to avoid double drop if `drop_in_place` panicked.
+ guard.processed_len = i + 1;
+ guard.deleted_cnt = del;
+ // SAFETY: We never touch this element again after dropped.
+ unsafe { ptr::drop_in_place(cur) };
+ } else if del > 0 {
+ // SAFETY: `del` > 0 so the hole slot must not overlap with current element.
+ // We use copy for move, and never touch this element again.
+ unsafe {
+ let hole_slot = guard.v.as_mut_ptr().offset(i as isize - del as isize);
+ ptr::copy_nonoverlapping(cur, hole_slot, 1);
+ }
+ guard.processed_len = i + 1;
+ }
}
+
+ // All holes are at the end now. Simply cut them out.
+ unsafe { guard.v.set_len(len - del) };
+ mem::forget(guard);
}
/// Removes all but the first of consecutive elements in the vector that resolve to the same
assert_eq!(vec, [2, 4]);
}
+#[test]
+fn test_retain_pred_panic() {
+ use std::sync::atomic::{AtomicU64, Ordering};
+
+ struct Wrap<'a>(&'a AtomicU64, u64, bool);
+
+ impl Drop for Wrap<'_> {
+ fn drop(&mut self) {
+ self.0.fetch_or(self.1, Ordering::SeqCst);
+ }
+ }
+
+ let dropped = AtomicU64::new(0);
+
+ let ret = std::panic::catch_unwind(|| {
+ let mut v = vec![
+ Wrap(&dropped, 1, false),
+ Wrap(&dropped, 2, false),
+ Wrap(&dropped, 4, false),
+ Wrap(&dropped, 8, false),
+ Wrap(&dropped, 16, false),
+ ];
+ v.retain(|w| match w.1 {
+ 1 => true,
+ 2 => false,
+ 4 => true,
+ _ => panic!(),
+ });
+ });
+ assert!(ret.is_err());
+ // Everything is dropped when predicate panicked.
+ assert_eq!(dropped.load(Ordering::SeqCst), 1 | 2 | 4 | 8 | 16);
+}
+
+#[test]
+fn test_retain_drop_panic() {
+ use std::sync::atomic::{AtomicU64, Ordering};
+
+ struct Wrap<'a>(&'a AtomicU64, u64);
+
+ impl Drop for Wrap<'_> {
+ fn drop(&mut self) {
+ if self.1 == 8 {
+ panic!();
+ }
+ self.0.fetch_or(self.1, Ordering::SeqCst);
+ }
+ }
+
+ let dropped = AtomicU64::new(0);
+
+ let ret = std::panic::catch_unwind(|| {
+ let mut v = vec![
+ Wrap(&dropped, 1),
+ Wrap(&dropped, 2),
+ Wrap(&dropped, 4),
+ Wrap(&dropped, 8),
+ Wrap(&dropped, 16),
+ ];
+ v.retain(|w| match w.1 {
+ 1 => true,
+ 2 => false,
+ 4 => true,
+ 8 => false,
+ _ => true,
+ });
+ });
+ assert!(ret.is_err());
+ // Other elements are dropped when `drop` of one element panicked.
+ assert_eq!(dropped.load(Ordering::SeqCst), 1 | 2 | 4 | 16);
+}
+
#[test]
fn test_dedup() {
fn case(a: Vec<i32>, b: Vec<i32>) {