]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_index/src/interval.rs
Auto merge of #106910 - aliemjay:alias-ty-in-regionck, r=oli-obk
[rust.git] / compiler / rustc_index / src / interval.rs
1 use std::iter::Step;
2 use std::marker::PhantomData;
3 use std::ops::RangeBounds;
4 use std::ops::{Bound, Range};
5
6 use crate::vec::Idx;
7 use crate::vec::IndexVec;
8 use smallvec::SmallVec;
9
10 #[cfg(test)]
11 mod tests;
12
13 /// Stores a set of intervals on the indices.
14 ///
15 /// The elements in `map` are sorted and non-adjacent, which means
16 /// the second value of the previous element is *greater* than the
17 /// first value of the following element.
18 #[derive(Debug, Clone)]
19 pub struct IntervalSet<I> {
20     // Start, end
21     map: SmallVec<[(u32, u32); 4]>,
22     domain: usize,
23     _data: PhantomData<I>,
24 }
25
26 #[inline]
27 fn inclusive_start<T: Idx>(range: impl RangeBounds<T>) -> u32 {
28     match range.start_bound() {
29         Bound::Included(start) => start.index() as u32,
30         Bound::Excluded(start) => start.index() as u32 + 1,
31         Bound::Unbounded => 0,
32     }
33 }
34
35 #[inline]
36 fn inclusive_end<T: Idx>(domain: usize, range: impl RangeBounds<T>) -> Option<u32> {
37     let end = match range.end_bound() {
38         Bound::Included(end) => end.index() as u32,
39         Bound::Excluded(end) => end.index().checked_sub(1)? as u32,
40         Bound::Unbounded => domain.checked_sub(1)? as u32,
41     };
42     Some(end)
43 }
44
45 impl<I: Idx> IntervalSet<I> {
46     pub fn new(domain: usize) -> IntervalSet<I> {
47         IntervalSet { map: SmallVec::new(), domain, _data: PhantomData }
48     }
49
50     pub fn clear(&mut self) {
51         self.map.clear();
52     }
53
54     pub fn iter(&self) -> impl Iterator<Item = I> + '_
55     where
56         I: Step,
57     {
58         self.iter_intervals().flatten()
59     }
60
61     /// Iterates through intervals stored in the set, in order.
62     pub fn iter_intervals(&self) -> impl Iterator<Item = std::ops::Range<I>> + '_
63     where
64         I: Step,
65     {
66         self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1))
67     }
68
69     /// Returns true if we increased the number of elements present.
70     pub fn insert(&mut self, point: I) -> bool {
71         self.insert_range(point..=point)
72     }
73
74     /// Returns true if we increased the number of elements present.
75     pub fn insert_range(&mut self, range: impl RangeBounds<I> + Clone) -> bool {
76         let start = inclusive_start(range.clone());
77         let Some(end) = inclusive_end(self.domain, range) else {
78             // empty range
79             return false;
80         };
81         if start > end {
82             return false;
83         }
84
85         // This condition looks a bit weird, but actually makes sense.
86         //
87         // if r.0 == end + 1, then we're actually adjacent, so we want to
88         // continue to the next range. We're looking here for the first
89         // range which starts *non-adjacently* to our end.
90         let next = self.map.partition_point(|r| r.0 <= end + 1);
91         let result = if let Some(right) = next.checked_sub(1) {
92             let (prev_start, prev_end) = self.map[right];
93             if prev_end + 1 >= start {
94                 // If the start for the inserted range is adjacent to the
95                 // end of the previous, we can extend the previous range.
96                 if start < prev_start {
97                     // The first range which ends *non-adjacently* to our start.
98                     // And we can ensure that left <= right.
99                     let left = self.map.partition_point(|l| l.1 + 1 < start);
100                     let min = std::cmp::min(self.map[left].0, start);
101                     let max = std::cmp::max(prev_end, end);
102                     self.map[right] = (min, max);
103                     if left != right {
104                         self.map.drain(left..right);
105                     }
106                     true
107                 } else {
108                     // We overlap with the previous range, increase it to
109                     // include us.
110                     //
111                     // Make sure we're actually going to *increase* it though --
112                     // it may be that end is just inside the previously existing
113                     // set.
114                     if end > prev_end {
115                         self.map[right].1 = end;
116                         true
117                     } else {
118                         false
119                     }
120                 }
121             } else {
122                 // Otherwise, we don't overlap, so just insert
123                 self.map.insert(right + 1, (start, end));
124                 true
125             }
126         } else {
127             if self.map.is_empty() {
128                 // Quite common in practice, and expensive to call memcpy
129                 // with length zero.
130                 self.map.push((start, end));
131             } else {
132                 self.map.insert(next, (start, end));
133             }
134             true
135         };
136         debug_assert!(
137             self.check_invariants(),
138             "wrong intervals after insert {start:?}..={end:?} to {self:?}"
139         );
140         result
141     }
142
143     pub fn contains(&self, needle: I) -> bool {
144         let needle = needle.index() as u32;
145         let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else {
146             // All ranges in the map start after the new range's end
147             return false;
148         };
149         let (_, prev_end) = &self.map[last];
150         needle <= *prev_end
151     }
152
153     pub fn superset(&self, other: &IntervalSet<I>) -> bool
154     where
155         I: Step,
156     {
157         let mut sup_iter = self.iter_intervals();
158         let mut current = None;
159         let contains = |sup: Range<I>, sub: Range<I>, current: &mut Option<Range<I>>| {
160             if sup.end < sub.start {
161                 // if `sup.end == sub.start`, the next sup doesn't contain `sub.start`
162                 None // continue to the next sup
163             } else if sup.end >= sub.end && sup.start <= sub.start {
164                 *current = Some(sup); // save the current sup
165                 Some(true)
166             } else {
167                 Some(false)
168             }
169         };
170         other.iter_intervals().all(|sub| {
171             current
172                 .take()
173                 .and_then(|sup| contains(sup, sub.clone(), &mut current))
174                 .or_else(|| sup_iter.find_map(|sup| contains(sup, sub.clone(), &mut current)))
175                 .unwrap_or(false)
176         })
177     }
178
179     pub fn is_empty(&self) -> bool {
180         self.map.is_empty()
181     }
182
183     /// Returns the maximum (last) element present in the set from `range`.
184     pub fn last_set_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
185         let start = inclusive_start(range.clone());
186         let Some(end) = inclusive_end(self.domain, range) else {
187             // empty range
188             return None;
189         };
190         if start > end {
191             return None;
192         }
193         let Some(last) = self.map.partition_point(|r| r.0 <= end).checked_sub(1) else {
194             // All ranges in the map start after the new range's end
195             return None;
196         };
197         let (_, prev_end) = &self.map[last];
198         if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None }
199     }
200
201     pub fn insert_all(&mut self) {
202         self.clear();
203         if let Some(end) = self.domain.checked_sub(1) {
204             self.map.push((0, end.try_into().unwrap()));
205         }
206         debug_assert!(self.check_invariants());
207     }
208
209     pub fn union(&mut self, other: &IntervalSet<I>) -> bool
210     where
211         I: Step,
212     {
213         assert_eq!(self.domain, other.domain);
214         let mut did_insert = false;
215         for range in other.iter_intervals() {
216             did_insert |= self.insert_range(range);
217         }
218         debug_assert!(self.check_invariants());
219         did_insert
220     }
221
222     // Check the intervals are valid, sorted and non-adjacent
223     fn check_invariants(&self) -> bool {
224         let mut current: Option<u32> = None;
225         for (start, end) in &self.map {
226             if start > end || current.map_or(false, |x| x + 1 >= *start) {
227                 return false;
228             }
229             current = Some(*end);
230         }
231         current.map_or(true, |x| x < self.domain as u32)
232     }
233 }
234
235 /// This data structure optimizes for cases where the stored bits in each row
236 /// are expected to be highly contiguous (long ranges of 1s or 0s), in contrast
237 /// to BitMatrix and SparseBitMatrix which are optimized for
238 /// "random"/non-contiguous bits and cheap(er) point queries at the expense of
239 /// memory usage.
240 #[derive(Clone)]
241 pub struct SparseIntervalMatrix<R, C>
242 where
243     R: Idx,
244     C: Idx,
245 {
246     rows: IndexVec<R, IntervalSet<C>>,
247     column_size: usize,
248 }
249
250 impl<R: Idx, C: Step + Idx> SparseIntervalMatrix<R, C> {
251     pub fn new(column_size: usize) -> SparseIntervalMatrix<R, C> {
252         SparseIntervalMatrix { rows: IndexVec::new(), column_size }
253     }
254
255     pub fn rows(&self) -> impl Iterator<Item = R> {
256         self.rows.indices()
257     }
258
259     pub fn row(&self, row: R) -> Option<&IntervalSet<C>> {
260         self.rows.get(row)
261     }
262
263     fn ensure_row(&mut self, row: R) -> &mut IntervalSet<C> {
264         self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size));
265         &mut self.rows[row]
266     }
267
268     pub fn union_row(&mut self, row: R, from: &IntervalSet<C>) -> bool
269     where
270         C: Step,
271     {
272         self.ensure_row(row).union(from)
273     }
274
275     pub fn union_rows(&mut self, read: R, write: R) -> bool
276     where
277         C: Step,
278     {
279         if read == write || self.rows.get(read).is_none() {
280             return false;
281         }
282         self.ensure_row(write);
283         let (read_row, write_row) = self.rows.pick2_mut(read, write);
284         write_row.union(read_row)
285     }
286
287     pub fn insert_all_into_row(&mut self, row: R) {
288         self.ensure_row(row).insert_all();
289     }
290
291     pub fn insert_range(&mut self, row: R, range: impl RangeBounds<C> + Clone) {
292         self.ensure_row(row).insert_range(range);
293     }
294
295     pub fn insert(&mut self, row: R, point: C) -> bool {
296         self.ensure_row(row).insert(point)
297     }
298
299     pub fn contains(&self, row: R, point: C) -> bool {
300         self.row(row).map_or(false, |r| r.contains(point))
301     }
302 }