]> git.lizzy.rs Git - rust.git/commitdiff
Document BinaryHeap unsafe functions
authorGiacomo Stevanato <giaco.stevanato@gmail.com>
Wed, 3 Feb 2021 11:45:55 +0000 (12:45 +0100)
committerMark Rousskov <mark.simulacrum@gmail.com>
Sat, 20 Feb 2021 20:44:17 +0000 (15:44 -0500)
library/alloc/src/collections/binary_heap.rs

index 8a36b2af765227f96ad51fd44a9a93aac5e4292d..87184f90d571a89897ff6b500602fc80e984be3f 100644 (file)
@@ -275,7 +275,8 @@ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 impl<T: Ord> Drop for PeekMut<'_, T> {
     fn drop(&mut self) {
         if self.sift {
-            self.heap.sift_down(0);
+            // SAFETY: PeekMut is only instantiated for non-empty heaps.
+            unsafe { self.heap.sift_down(0) };
         }
     }
 }
@@ -431,7 +432,8 @@ pub fn pop(&mut self) -> Option<T> {
         self.data.pop().map(|mut item| {
             if !self.is_empty() {
                 swap(&mut item, &mut self.data[0]);
-                self.sift_down_to_bottom(0);
+                // SAFETY: !self.is_empty() means that self.len() > 0
+                unsafe { self.sift_down_to_bottom(0) };
             }
             item
         })
@@ -473,7 +475,9 @@ pub fn pop(&mut self) -> Option<T> {
     pub fn push(&mut self, item: T) {
         let old_len = self.len();
         self.data.push(item);
-        self.sift_up(0, old_len);
+        // SAFETY: Since we pushed a new item it means that
+        //  old_len = self.len() - 1 < self.len()
+        unsafe { self.sift_up(0, old_len) };
     }
 
     /// Consumes the `BinaryHeap` and returns a vector in sorted
@@ -506,7 +510,10 @@ pub fn into_sorted_vec(mut self) -> Vec<T> {
                 let ptr = self.data.as_mut_ptr();
                 ptr::swap(ptr, ptr.add(end));
             }
-            self.sift_down_range(0, end);
+            // SAFETY: `end` goes from `self.len() - 1` to 1 (both included) so:
+            //  0 < 1 <= end <= self.len() - 1 < self.len()
+            //  Which means 0 < end and end < self.len().
+            unsafe { self.sift_down_range(0, end) };
         }
         self.into_vec()
     }
@@ -519,47 +526,82 @@ pub fn into_sorted_vec(mut self) -> Vec<T> {
     // the hole is filled back at the end of its scope, even on panic.
     // Using a hole reduces the constant factor compared to using swaps,
     // which involves twice as many moves.
-    fn sift_up(&mut self, start: usize, pos: usize) -> usize {
-        unsafe {
-            // Take out the value at `pos` and create a hole.
-            let mut hole = Hole::new(&mut self.data, pos);
-
-            while hole.pos() > start {
-                let parent = (hole.pos() - 1) / 2;
-                if hole.element() <= hole.get(parent) {
-                    break;
-                }
-                hole.move_to(parent);
+
+    /// # Safety
+    ///
+    /// The caller must guarantee that `pos < self.len()`.
+    unsafe fn sift_up(&mut self, start: usize, pos: usize) -> usize {
+        // Take out the value at `pos` and create a hole.
+        // SAFETY: The caller guarantees that pos < self.len()
+        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
+
+        while hole.pos() > start {
+            let parent = (hole.pos() - 1) / 2;
+
+            // SAFETY: hole.pos() > start >= 0, which means hole.pos() > 0
+            //  and so hole.pos() - 1 can't underflow.
+            //  This guarantees that parent < hole.pos() so
+            //  it's a valid index and also != hole.pos().
+            if hole.element() <= unsafe { hole.get(parent) } {
+                break;
             }
-            hole.pos()
+
+            // SAFETY: Same as above
+            unsafe { hole.move_to(parent) };
         }
+
+        hole.pos()
     }
 
     /// Take an element at `pos` and move it down the heap,
     /// while its children are larger.
-    fn sift_down_range(&mut self, pos: usize, end: usize) {
-        unsafe {
-            let mut hole = Hole::new(&mut self.data, pos);
-            let mut child = 2 * pos + 1;
-            while child < end - 1 {
-                // compare with the greater of the two children
-                child += (hole.get(child) <= hole.get(child + 1)) as usize;
-                // if we are already in order, stop.
-                if hole.element() >= hole.get(child) {
-                    return;
-                }
-                hole.move_to(child);
-                child = 2 * hole.pos() + 1;
-            }
-            if child == end - 1 && hole.element() < hole.get(child) {
-                hole.move_to(child);
+    ///
+    /// # Safety
+    ///
+    /// The caller must guarantee that `pos < end <= self.len()`.
+    unsafe fn sift_down_range(&mut self, pos: usize, end: usize) {
+        // SAFETY: The caller guarantees that pos < end <= self.len().
+        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
+        let mut child = 2 * hole.pos() + 1;
+
+        // Loop invariant: child == 2 * hole.pos() + 1.
+        while child < end - 1 {
+            // compare with the greater of the two children
+            // SAFETY: child < end - 1 < self.len() and
+            //  child + 1 < end <= self.len(), so they're valid indexes.
+            //  child == 2 * hole.pos() + 1 != hole.pos() and
+            //  child + 1 == 2 * hole.pos() + 2 != hole.pos().
+            child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
+
+            // if we are already in order, stop.
+            // SAFETY: child is now either the old child or the old child+1
+            //  We already proven that both are < self.len() and != hole.pos()
+            if hole.element() >= unsafe { hole.get(child) } {
+                return;
             }
+
+            // SAFETY: same as above.
+            unsafe { hole.move_to(child) };
+            child = 2 * hole.pos() + 1;
+        }
+
+        // SAFETY: && short circuit, which means that in the
+        //  second condition it's already true that child == end - 1 < self.len().
+        if child == end - 1 && hole.element() < unsafe { hole.get(child) } {
+            // SAFETY: child is already proven to be a valid index and
+            //  child == 2 * hole.pos() + 1 != hole.pos().
+            unsafe { hole.move_to(child) };
         }
     }
 
-    fn sift_down(&mut self, pos: usize) {
+    /// # Safety
+    ///
+    /// The caller must guarantee that `pos < self.len()`.
+    unsafe fn sift_down(&mut self, pos: usize) {
         let len = self.len();
-        self.sift_down_range(pos, len);
+        // SAFETY: pos < len is guaranteed by the caller and
+        //  obviously len = self.len() <= self.len().
+        unsafe { self.sift_down_range(pos, len) };
     }
 
     /// Take an element at `pos` and move it all the way down the heap,
@@ -567,30 +609,52 @@ fn sift_down(&mut self, pos: usize) {
     ///
     /// Note: This is faster when the element is known to be large / should
     /// be closer to the bottom.
-    fn sift_down_to_bottom(&mut self, mut pos: usize) {
+    ///
+    /// # Safety
+    ///
+    /// The caller must guarantee that `pos < self.len()`.
+    unsafe fn sift_down_to_bottom(&mut self, mut pos: usize) {
         let end = self.len();
         let start = pos;
-        unsafe {
-            let mut hole = Hole::new(&mut self.data, pos);
-            let mut child = 2 * pos + 1;
-            while child < end - 1 {
-                child += (hole.get(child) <= hole.get(child + 1)) as usize;
-                hole.move_to(child);
-                child = 2 * hole.pos() + 1;
-            }
-            if child == end - 1 {
-                hole.move_to(child);
-            }
-            pos = hole.pos;
+
+        // SAFETY: The caller guarantees that pos < self.len().
+        let mut hole = unsafe { Hole::new(&mut self.data, pos) };
+        let mut child = 2 * hole.pos() + 1;
+
+        // Loop invariant: child == 2 * hole.pos() + 1.
+        while child < end - 1 {
+            // SAFETY: child < end - 1 < self.len() and
+            //  child + 1 < end <= self.len(), so they're valid indexes.
+            //  child == 2 * hole.pos() + 1 != hole.pos() and
+            //  child + 1 == 2 * hole.pos() + 2 != hole.pos().
+            child += unsafe { hole.get(child) <= hole.get(child + 1) } as usize;
+
+            // SAFETY: Same as above
+            unsafe { hole.move_to(child) };
+            child = 2 * hole.pos() + 1;
         }
-        self.sift_up(start, pos);
+
+        if child == end - 1 {
+            // SAFETY: child == end - 1 < self.len(), so it's a valid index
+            //  and child == 2 * hole.pos() + 1 != hole.pos().
+            unsafe { hole.move_to(child) };
+        }
+        pos = hole.pos();
+        drop(hole);
+
+        // SAFETY: pos is the position in the hole and was already proven
+        //  to be a valid index.
+        unsafe { self.sift_up(start, pos) };
     }
 
     fn rebuild(&mut self) {
         let mut n = self.len() / 2;
         while n > 0 {
             n -= 1;
-            self.sift_down(n);
+            // SAFETY: n starts from self.len() / 2 and goes down to 0.
+            //  The only case when !(n < self.len()) is if
+            //  self.len() == 0, but it's ruled out by the loop condition.
+            unsafe { self.sift_down(n) };
         }
     }