]> git.lizzy.rs Git - rust.git/blobdiff - src/range_map.rs
fix reborrowing of tagged ZST references
[rust.git] / src / range_map.rs
index 118be32a2993b861cc5ba9f14a8cd41b02ef1c0e..607c830530e1f15f63fb6a093c6b01fe8466f34d 100644 (file)
 //! Implements a map from integer indices to data.
 //! Rather than storing data for every index, internally, this maps entire ranges to the data.
 //! To this end, the APIs all work on ranges, not on individual integers. Ranges are split as
-//! necessary (e.g. when [0,5) is first associated with X, and then [1,2) is mutated).
+//! necessary (e.g., when [0,5) is first associated with X, and then [1,2) is mutated).
 //! Users must not depend on whether a range is coalesced or not, even though this is observable
 //! via the iteration APIs.
-use std::collections::BTreeMap;
+
 use std::ops;
 
-#[derive(Clone, Debug, PartialEq, Eq)]
-pub struct RangeMap<T> {
-    map: BTreeMap<Range, T>,
-}
+use rustc_target::abi::Size;
 
-// The derived `Ord` impl sorts first by the first field, then, if the fields are the same,
-// by the second field.
-// This is exactly what we need for our purposes, since a range query on a BTReeSet/BTreeMap will give us all
-// `MemoryRange`s whose `start` is <= than the one we're looking for, but not > the end of the range we're checking.
-// At the same time the `end` is irrelevant for the sorting and range searching, but used for the check.
-// This kind of search breaks, if `end < start`, so don't do that!
-#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
-struct Range {
-    start: u64,
-    end: u64, // Invariant: end > start
+#[derive(Clone, Debug)]
+struct Elem<T> {
+    /// The range covered by this element; never empty.
+    range: ops::Range<u64>,
+    /// The data stored for this element.
+    data: T,
 }
-
-impl Range {
-    fn range(offset: u64, len: u64) -> ops::Range<Range> {
-        assert!(len > 0);
-        // We select all elements that are within
-        // the range given by the offset into the allocation and the length.
-        // This is sound if all ranges that intersect with the argument range, are in the
-        // resulting range of ranges.
-        let left = Range {
-            // lowest range to include `offset`
-            start: 0,
-            end: offset + 1,
-        };
-        let right = Range {
-            // lowest (valid) range not to include `offset+len`
-            start: offset + len,
-            end: offset + len + 1,
-        };
-        left..right
-    }
-
-    /// Tests if all of [offset, offset+len) are contained in this range.
-    fn overlaps(&self, offset: u64, len: u64) -> bool {
-        assert!(len > 0);
-        offset < self.end && offset + len >= self.start
-    }
+#[derive(Clone, Debug)]
+pub struct RangeMap<T> {
+    v: Vec<Elem<T>>,
 }
 
 impl<T> RangeMap<T> {
-    pub fn new() -> RangeMap<T> {
-        RangeMap { map: BTreeMap::new() }
-    }
-
-    fn iter_with_range<'a>(
-        &'a self,
-        offset: u64,
-        len: u64,
-    ) -> impl Iterator<Item = (&'a Range, &'a T)> + 'a {
-        assert!(len > 0);
-        self.map.range(Range::range(offset, len)).filter_map(
-            move |(range,
-                   data)| {
-                if range.overlaps(offset, len) {
-                    Some((range, data))
-                } else {
-                    None
-                }
-            },
-        )
+    /// Creates a new `RangeMap` for the given size, and with the given initial value used for
+    /// the entire range.
+    #[inline(always)]
+    pub fn new(size: Size, init: T) -> RangeMap<T> {
+        let size = size.bytes();
+        let mut map = RangeMap { v: Vec::new() };
+        if size > 0 {
+            map.v.push(Elem { range: 0..size, data: init });
+        }
+        map
     }
 
-    pub fn iter<'a>(&'a self, offset: u64, len: u64) -> impl Iterator<Item = &'a T> + 'a {
-        self.iter_with_range(offset, len).map(|(_, data)| data)
+    /// Finds the index containing the given offset.
+    fn find_offset(&self, offset: u64) -> usize {
+        // We do a binary search.
+        let mut left = 0usize; // inclusive
+        let mut right = self.v.len(); // exclusive
+        loop {
+            debug_assert!(left < right, "find_offset: offset {} is out-of-bounds", offset);
+            let candidate = left.checked_add(right).unwrap() / 2;
+            let elem = &self.v[candidate];
+            if offset < elem.range.start {
+                // We are too far right (offset is further left).
+                debug_assert!(candidate < right); // we are making progress
+                right = candidate;
+            } else if offset >= elem.range.end {
+                // We are too far left (offset is further right).
+                debug_assert!(candidate >= left); // we are making progress
+                left = candidate + 1;
+            } else {
+                // This is it!
+                return candidate;
+            }
+        }
     }
 
-    fn split_entry_at(&mut self, offset: u64)
-    where
-        T: Clone,
-    {
-        let range = match self.iter_with_range(offset, 1).next() {
-            Some((&range, _)) => range,
-            None => return,
+    /// Provides read-only iteration over everything in the given range. This does
+    /// *not* split items if they overlap with the edges. Do not use this to mutate
+    /// through interior mutability.
+    ///
+    /// The iterator also provides the offset of the given element.
+    pub fn iter<'a>(&'a self, offset: Size, len: Size) -> impl Iterator<Item = (Size, &'a T)> + 'a {
+        let offset = offset.bytes();
+        let len = len.bytes();
+        // Compute a slice starting with the elements we care about.
+        let slice: &[Elem<T>] = if len == 0 {
+            // We just need any empty iterator. We don't even want to
+            // yield the element that surrounds this position.
+            &[]
+        } else {
+            let first_idx = self.find_offset(offset);
+            &self.v[first_idx..]
         };
-        assert!(
-            range.start <= offset && range.end > offset,
-            "We got a range that doesn't even contain what we asked for."
-        );
-        // There is an entry overlapping this position, see if we have to split it
-        if range.start < offset {
-            let data = self.map.remove(&range).unwrap();
-            let old = self.map.insert(
-                Range {
-                    start: range.start,
-                    end: offset,
-                },
-                data.clone(),
-            );
-            assert!(old.is_none());
-            let old = self.map.insert(
-                Range {
-                    start: offset,
-                    end: range.end,
-                },
-                data,
-            );
-            assert!(old.is_none());
-        }
+        // The first offset that is not included any more.
+        let end = offset + len;
+        slice.iter().take_while(move |elem| elem.range.start < end).map(|elem| (Size::from_bytes(elem.range.start), &elem.data))
     }
 
     pub fn iter_mut_all<'a>(&'a mut self) -> impl Iterator<Item = &'a mut T> + 'a {
-        self.map.values_mut()
+        self.v.iter_mut().map(|elem| &mut elem.data)
     }
 
-    /// Provide mutable iteration over everything in the given range.  As a side-effect,
-    /// this will split entries in the map that are only partially hit by the given range,
-    /// to make sure that when they are mutated, the effect is constrained to the given range.
-    pub fn iter_mut_with_gaps<'a>(
-        &'a mut self,
-        offset: u64,
-        len: u64,
-    ) -> impl Iterator<Item = &'a mut T> + 'a
+    // Splits the element situated at the given `index`, such that the 2nd one starts at offset
+    // `split_offset`. Do nothing if the element already starts there.
+    // Returns whether a split was necessary.
+    fn split_index(&mut self, index: usize, split_offset: u64) -> bool
     where
         T: Clone,
     {
-        assert!(len > 0);
-        // Preparation: Split first and last entry as needed.
-        self.split_entry_at(offset);
-        self.split_entry_at(offset + len);
-        // Now we can provide a mutable iterator
-        self.map.range_mut(Range::range(offset, len)).filter_map(
-            move |(&range, data)| {
-                if range.overlaps(offset, len) {
-                    assert!(
-                        offset <= range.start && offset + len >= range.end,
-                        "The splitting went wrong"
-                    );
-                    Some(data)
-                } else {
-                    // Skip this one
-                    None
-                }
-            },
-        )
-    }
-
-    /// Provide a mutable iterator over everything in the given range, with the same side-effects as
-    /// iter_mut_with_gaps.  Furthermore, if there are gaps between ranges, fill them with the given default.
-    /// This is also how you insert.
-    pub fn iter_mut<'a>(&'a mut self, offset: u64, len: u64) -> impl Iterator<Item = &'a mut T> + 'a
-    where
-        T: Clone + Default,
-    {
-        // Do a first iteration to collect the gaps
-        let mut gaps = Vec::new();
-        let mut last_end = offset;
-        for (range, _) in self.iter_with_range(offset, len) {
-            if last_end < range.start {
-                gaps.push(Range {
-                    start: last_end,
-                    end: range.start,
-                });
-            }
-            last_end = range.end;
-        }
-        if last_end < offset + len {
-            gaps.push(Range {
-                start: last_end,
-                end: offset + len,
-            });
-        }
-
-        // Add default for all gaps
-        for gap in gaps {
-            let old = self.map.insert(gap, Default::default());
-            assert!(old.is_none());
+        let elem = &mut self.v[index];
+        if split_offset == elem.range.start || split_offset == elem.range.end {
+            // Nothing to do.
+            return false;
         }
+        debug_assert!(
+            elem.range.contains(&split_offset),
+            "the `split_offset` is not in the element to be split"
+        );
 
-        // Now provide mutable iteration
-        self.iter_mut_with_gaps(offset, len)
+        // Now we really have to split. Reduce length of first element.
+        let second_range = split_offset..elem.range.end;
+        elem.range.end = split_offset;
+        // Copy the data, and insert second element.
+        let second = Elem { range: second_range, data: elem.data.clone() };
+        self.v.insert(index + 1, second);
+        return true;
     }
 
-    pub fn retain<F>(&mut self, mut f: F)
+    /// Provides mutable iteration over everything in the given range. As a side-effect,
+    /// this will split entries in the map that are only partially hit by the given range,
+    /// to make sure that when they are mutated, the effect is constrained to the given range.
+    /// Moreover, this will opportunistically merge neighbouring equal blocks.
+    ///
+    /// The iterator also provides the offset of the given element.
+    pub fn iter_mut<'a>(
+        &'a mut self,
+        offset: Size,
+        len: Size,
+    ) -> impl Iterator<Item = (Size, &'a mut T)> + 'a
     where
-        F: FnMut(&T) -> bool,
+        T: Clone + PartialEq,
     {
-        let mut remove = Vec::new();
-        for (range, data) in self.map.iter() {
-            if !f(data) {
-                remove.push(*range);
+        let offset = offset.bytes();
+        let len = len.bytes();
+        // Compute a slice containing exactly the elements we care about
+        let slice: &mut [Elem<T>] = if len == 0 {
+            // We just need any empty iterator. We don't even want to
+            // yield the element that surrounds this position, nor do
+            // any splitting.
+            &mut []
+        } else {
+            // Make sure we got a clear beginning
+            let mut first_idx = self.find_offset(offset);
+            if self.split_index(first_idx, offset) {
+                // The newly created 2nd element is ours
+                first_idx += 1;
             }
-        }
-
-        for range in remove {
-            self.map.remove(&range);
-        }
+            // No more mutation.
+            let first_idx = first_idx;
+            // Find our end. Linear scan, but that's ok because the iteration
+            // is doing the same linear scan anyway -- no increase in complexity.
+            // We combine this scan with a scan for duplicates that we can merge, to reduce
+            // the number of elements.
+            // We stop searching after the first "block" of size 1, to avoid spending excessive
+            // amounts of time on the merging.
+            let mut equal_since_idx = first_idx;
+            // Once we see too many non-mergeable blocks, we stop.
+            // The initial value is chosen via... magic. Benchmarking and magic.
+            let mut successful_merge_count = 3usize;
+            // When the loop is done, this is the first excluded element.
+            let mut end_idx = first_idx;
+            loop {
+                // Compute if `end` is the last element we need to look at.
+                let done = self.v[end_idx].range.end >= offset + len;
+                // We definitely need to include `end`, so move the index.
+                end_idx += 1;
+                debug_assert!(
+                    done || end_idx < self.v.len(),
+                    "iter_mut: end-offset {} is out-of-bounds",
+                    offset + len
+                );
+                // see if we want to merge everything in `equal_since..end` (exclusive at the end!)
+                if successful_merge_count > 0 {
+                    if done || self.v[end_idx].data != self.v[equal_since_idx].data {
+                        // Everything in `equal_since..end` was equal. Make them just one element covering
+                        // the entire range.
+                        let removed_elems = end_idx - equal_since_idx - 1; // number of elements that we would remove
+                        if removed_elems > 0 {
+                            // Adjust the range of the first element to cover all of them.
+                            let equal_until = self.v[end_idx - 1].range.end; // end of range of last of the equal elements
+                            self.v[equal_since_idx].range.end = equal_until;
+                            // Delete the rest of them.
+                            self.v.splice(equal_since_idx + 1..end_idx, std::iter::empty());
+                            // Adjust `end_idx` because we made the list shorter.
+                            end_idx -= removed_elems;
+                            // Adjust the count for the cutoff.
+                            successful_merge_count += removed_elems;
+                        } else {
+                            // Adjust the count for the cutoff.
+                            successful_merge_count -= 1;
+                        }
+                        // Go on scanning for the next block starting here.
+                        equal_since_idx = end_idx;
+                    }
+                }
+                // Leave loop if this is the last element.
+                if done {
+                    break;
+                }
+            }
+            // Move to last included instead of first excluded index.
+            let end_idx = end_idx - 1;
+            // We need to split the end as well. Even if this performs a
+            // split, we don't have to adjust our index as we only care about
+            // the first part of the split.
+            self.split_index(end_idx, offset + len);
+            // Now we yield the slice. `end` is inclusive.
+            &mut self.v[first_idx..=end_idx]
+        };
+        slice.iter_mut().map(|elem| (Size::from_bytes(elem.range.start), &mut elem.data))
     }
 }
 
@@ -209,42 +213,67 @@ mod tests {
     fn to_vec<T: Copy>(map: &RangeMap<T>, offset: u64, len: u64) -> Vec<T> {
         (offset..offset + len)
             .into_iter()
-            .map(|i| *map.iter(i, 1).next().unwrap())
+            .map(|i| map.iter(Size::from_bytes(i), Size::from_bytes(1)).next().map(|(_, &t)| t).unwrap())
             .collect()
     }
 
     #[test]
     fn basic_insert() {
-        let mut map = RangeMap::<i32>::new();
-        // Insert
-        for x in map.iter_mut(10, 1) {
+        let mut map = RangeMap::<i32>::new(Size::from_bytes(20), -1);
+        // Insert.
+        for (_, x) in map.iter_mut(Size::from_bytes(10), Size::from_bytes(1)) {
             *x = 42;
         }
-        // Check
+        // Check.
         assert_eq!(to_vec(&map, 10, 1), vec![42]);
+        assert_eq!(map.v.len(), 3);
+
+        // Insert with size 0.
+        for (_, x) in map.iter_mut(Size::from_bytes(10), Size::from_bytes(0)) {
+            *x = 19;
+        }
+        for (_, x) in map.iter_mut(Size::from_bytes(11), Size::from_bytes(0)) {
+            *x = 19;
+        }
+        assert_eq!(to_vec(&map, 10, 2), vec![42, -1]);
+        assert_eq!(map.v.len(), 3);
     }
 
     #[test]
     fn gaps() {
-        let mut map = RangeMap::<i32>::new();
-        for x in map.iter_mut(11, 1) {
+        let mut map = RangeMap::<i32>::new(Size::from_bytes(20), -1);
+        for (_, x) in map.iter_mut(Size::from_bytes(11), Size::from_bytes(1)) {
             *x = 42;
         }
-        for x in map.iter_mut(15, 1) {
-            *x = 42;
+        for (_, x) in map.iter_mut(Size::from_bytes(15), Size::from_bytes(1)) {
+            *x = 43;
         }
+        assert_eq!(map.v.len(), 5);
+        assert_eq!(to_vec(&map, 10, 10), vec![-1, 42, -1, -1, -1, 43, -1, -1, -1, -1]);
 
-        // Now request a range that needs three gaps filled
-        for x in map.iter_mut(10, 10) {
-            if *x != 42 {
+        for (_, x) in map.iter_mut(Size::from_bytes(10), Size::from_bytes(10)) {
+            if *x < 42 {
                 *x = 23;
             }
         }
+        assert_eq!(map.v.len(), 6);
+        assert_eq!(to_vec(&map, 10, 10), vec![23, 42, 23, 23, 23, 43, 23, 23, 23, 23]);
+        assert_eq!(to_vec(&map, 13, 5), vec![23, 23, 43, 23, 23]);
 
+        for (_, x) in map.iter_mut(Size::from_bytes(15), Size::from_bytes(5)) {
+            *x = 19;
+        }
+        assert_eq!(map.v.len(), 6);
+        assert_eq!(to_vec(&map, 10, 10), vec![23, 42, 23, 23, 23, 19, 19, 19, 19, 19]);
+        // Should be seeing two blocks with 19.
         assert_eq!(
-            to_vec(&map, 10, 10),
-            vec![23, 42, 23, 23, 23, 42, 23, 23, 23, 23]
+            map.iter(Size::from_bytes(15), Size::from_bytes(2)).map(|(_, &t)| t).collect::<Vec<_>>(),
+            vec![19, 19]
         );
-        assert_eq!(to_vec(&map, 13, 5), vec![23, 23, 42, 23, 23]);
+
+        // A NOP `iter_mut` should trigger merging.
+        for _ in map.iter_mut(Size::from_bytes(15), Size::from_bytes(5)) {}
+        assert_eq!(map.v.len(), 5);
+        assert_eq!(to_vec(&map, 10, 10), vec![23, 42, 23, 23, 23, 19, 19, 19, 19, 19]);
     }
 }