]> git.lizzy.rs Git - rust.git/blob - src/range_map.rs
Merge pull request #374 from solson/cleanups
[rust.git] / src / range_map.rs
1 //! Implements a map from integer indices to data.
2 //! Rather than storing data for every index, internally, this maps entire ranges to the data.
3 //! To this end, the APIs all work on ranges, not on individual integers. Ranges are split as
4 //! necessary (e.g. when [0,5) is first associated with X, and then [1,2) is mutated).
5 //! Users must not depend on whether a range is coalesced or not, even though this is observable
6 //! via the iteration APIs.
7 use std::collections::BTreeMap;
8 use std::ops;
9
10 #[derive(Clone, Debug)]
11 pub struct RangeMap<T> {
12     map: BTreeMap<Range, T>,
13 }
14
15 // The derived `Ord` impl sorts first by the first field, then, if the fields are the same,
16 // by the second field.
17 // This is exactly what we need for our purposes, since a range query on a BTReeSet/BTreeMap will give us all
18 // `MemoryRange`s whose `start` is <= than the one we're looking for, but not > the end of the range we're checking.
19 // At the same time the `end` is irrelevant for the sorting and range searching, but used for the check.
20 // This kind of search breaks, if `end < start`, so don't do that!
21 #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
22 struct Range {
23     start: u64,
24     end: u64, // Invariant: end > start
25 }
26
27 impl Range {
28     fn range(offset: u64, len: u64) -> ops::Range<Range> {
29         assert!(len > 0);
30         // We select all elements that are within
31         // the range given by the offset into the allocation and the length.
32         // This is sound if all ranges that intersect with the argument range, are in the
33         // resulting range of ranges.
34         let left = Range {
35             // lowest range to include `offset`
36             start: 0,
37             end: offset + 1,
38         };
39         let right = Range {
40             // lowest (valid) range not to include `offset+len`
41             start: offset + len,
42             end: offset + len + 1,
43         };
44         left..right
45     }
46
47     /// Tests if all of [offset, offset+len) are contained in this range.
48     fn overlaps(&self, offset: u64, len: u64) -> bool {
49         assert!(len > 0);
50         offset < self.end && offset + len >= self.start
51     }
52 }
53
54 impl<T> RangeMap<T> {
55     pub fn new() -> RangeMap<T> {
56         RangeMap { map: BTreeMap::new() }
57     }
58
59     fn iter_with_range<'a>(
60         &'a self,
61         offset: u64,
62         len: u64,
63     ) -> impl Iterator<Item = (&'a Range, &'a T)> + 'a {
64         assert!(len > 0);
65         self.map.range(Range::range(offset, len)).filter_map(
66             move |(range,
67                    data)| {
68                 if range.overlaps(offset, len) {
69                     Some((range, data))
70                 } else {
71                     None
72                 }
73             },
74         )
75     }
76
77     pub fn iter<'a>(&'a self, offset: u64, len: u64) -> impl Iterator<Item = &'a T> + 'a {
78         self.iter_with_range(offset, len).map(|(_, data)| data)
79     }
80
81     fn split_entry_at(&mut self, offset: u64)
82     where
83         T: Clone,
84     {
85         let range = match self.iter_with_range(offset, 1).next() {
86             Some((&range, _)) => range,
87             None => return,
88         };
89         assert!(
90             range.start <= offset && range.end > offset,
91             "We got a range that doesn't even contain what we asked for."
92         );
93         // There is an entry overlapping this position, see if we have to split it
94         if range.start < offset {
95             let data = self.map.remove(&range).unwrap();
96             let old = self.map.insert(
97                 Range {
98                     start: range.start,
99                     end: offset,
100                 },
101                 data.clone(),
102             );
103             assert!(old.is_none());
104             let old = self.map.insert(
105                 Range {
106                     start: offset,
107                     end: range.end,
108                 },
109                 data,
110             );
111             assert!(old.is_none());
112         }
113     }
114
115     pub fn iter_mut_all<'a>(&'a mut self) -> impl Iterator<Item = &'a mut T> + 'a {
116         self.map.values_mut()
117     }
118
119     /// Provide mutable iteration over everything in the given range.  As a side-effect,
120     /// this will split entries in the map that are only partially hit by the given range,
121     /// to make sure that when they are mutated, the effect is constrained to the given range.
122     pub fn iter_mut_with_gaps<'a>(
123         &'a mut self,
124         offset: u64,
125         len: u64,
126     ) -> impl Iterator<Item = &'a mut T> + 'a
127     where
128         T: Clone,
129     {
130         assert!(len > 0);
131         // Preparation: Split first and last entry as needed.
132         self.split_entry_at(offset);
133         self.split_entry_at(offset + len);
134         // Now we can provide a mutable iterator
135         self.map.range_mut(Range::range(offset, len)).filter_map(
136             move |(&range, data)| {
137                 if range.overlaps(offset, len) {
138                     assert!(
139                         offset <= range.start && offset + len >= range.end,
140                         "The splitting went wrong"
141                     );
142                     Some(data)
143                 } else {
144                     // Skip this one
145                     None
146                 }
147             },
148         )
149     }
150
151     /// Provide a mutable iterator over everything in the given range, with the same side-effects as
152     /// iter_mut_with_gaps.  Furthermore, if there are gaps between ranges, fill them with the given default.
153     /// This is also how you insert.
154     pub fn iter_mut<'a>(&'a mut self, offset: u64, len: u64) -> impl Iterator<Item = &'a mut T> + 'a
155     where
156         T: Clone + Default,
157     {
158         // Do a first iteration to collect the gaps
159         let mut gaps = Vec::new();
160         let mut last_end = offset;
161         for (range, _) in self.iter_with_range(offset, len) {
162             if last_end < range.start {
163                 gaps.push(Range {
164                     start: last_end,
165                     end: range.start,
166                 });
167             }
168             last_end = range.end;
169         }
170         if last_end < offset + len {
171             gaps.push(Range {
172                 start: last_end,
173                 end: offset + len,
174             });
175         }
176
177         // Add default for all gaps
178         for gap in gaps {
179             let old = self.map.insert(gap, Default::default());
180             assert!(old.is_none());
181         }
182
183         // Now provide mutable iteration
184         self.iter_mut_with_gaps(offset, len)
185     }
186
187     pub fn retain<F>(&mut self, mut f: F)
188     where
189         F: FnMut(&T) -> bool,
190     {
191         let mut remove = Vec::new();
192         for (range, data) in self.map.iter() {
193             if !f(data) {
194                 remove.push(*range);
195             }
196         }
197
198         for range in remove {
199             self.map.remove(&range);
200         }
201     }
202 }
203
204 #[cfg(test)]
205 mod tests {
206     use super::*;
207
208     /// Query the map at every offset in the range and collect the results.
209     fn to_vec<T: Copy>(map: &RangeMap<T>, offset: u64, len: u64) -> Vec<T> {
210         (offset..offset + len)
211             .into_iter()
212             .map(|i| *map.iter(i, 1).next().unwrap())
213             .collect()
214     }
215
216     #[test]
217     fn basic_insert() {
218         let mut map = RangeMap::<i32>::new();
219         // Insert
220         for x in map.iter_mut(10, 1) {
221             *x = 42;
222         }
223         // Check
224         assert_eq!(to_vec(&map, 10, 1), vec![42]);
225     }
226
227     #[test]
228     fn gaps() {
229         let mut map = RangeMap::<i32>::new();
230         for x in map.iter_mut(11, 1) {
231             *x = 42;
232         }
233         for x in map.iter_mut(15, 1) {
234             *x = 42;
235         }
236
237         // Now request a range that needs three gaps filled
238         for x in map.iter_mut(10, 10) {
239             if *x != 42 {
240                 *x = 23;
241             }
242         }
243
244         assert_eq!(
245             to_vec(&map, 10, 10),
246             vec![23, 42, 23, 23, 23, 42, 23, 23, 23, 23]
247         );
248         assert_eq!(to_vec(&map, 13, 5), vec![23, 23, 42, 23, 23]);
249     }
250 }