use std::iter::Step;
use std::marker::PhantomData;
use std::ops::Bound;
use std::ops::RangeBounds;

use crate::vec::Idx;
use crate::vec::IndexVec;
use smallvec::SmallVec;
13 /// Stores a set of intervals on the indices.
14 #[derive(Debug, Clone)]
15 pub struct IntervalSet<I> {
17 map: SmallVec<[(u32, u32); 4]>,
19 _data: PhantomData<I>,
23 fn inclusive_start<T: Idx>(range: impl RangeBounds<T>) -> u32 {
24 match range.start_bound() {
25 Bound::Included(start) => start.index() as u32,
26 Bound::Excluded(start) => start.index() as u32 + 1,
27 Bound::Unbounded => 0,
32 fn inclusive_end<T: Idx>(domain: usize, range: impl RangeBounds<T>) -> Option<u32> {
33 let end = match range.end_bound() {
34 Bound::Included(end) => end.index() as u32,
35 Bound::Excluded(end) => end.index().checked_sub(1)? as u32,
36 Bound::Unbounded => domain.checked_sub(1)? as u32,
41 impl<I: Idx> IntervalSet<I> {
42 pub fn new(domain: usize) -> IntervalSet<I> {
43 IntervalSet { map: SmallVec::new(), domain, _data: PhantomData }
46 pub fn clear(&mut self) {
50 pub fn iter(&self) -> impl Iterator<Item = I> + '_
54 self.iter_intervals().flatten()
57 /// Iterates through intervals stored in the set, in order.
58 pub fn iter_intervals(&self) -> impl Iterator<Item = std::ops::Range<I>> + '_
62 self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1))
65 /// Returns true if we increased the number of elements present.
66 pub fn insert(&mut self, point: I) -> bool {
67 self.insert_range(point..=point)
70 /// Returns true if we increased the number of elements present.
71 pub fn insert_range(&mut self, range: impl RangeBounds<I> + Clone) -> bool {
72 let start = inclusive_start(range.clone());
73 let Some(end) = inclusive_end(self.domain, range) else {
81 // This condition looks a bit weird, but actually makes sense.
83 // if r.0 == end + 1, then we're actually adjacent, so we want to
84 // continue to the next range. We're looking here for the first
85 // range which starts *non-adjacently* to our end.
86 let next = self.map.partition_point(|r| r.0 <= end + 1);
87 if let Some(right) = next.checked_sub(1) {
88 let (prev_start, prev_end) = self.map[right];
89 if prev_end + 1 >= start {
90 // If the start for the inserted range is adjacent to the
91 // end of the previous, we can extend the previous range.
92 if start < prev_start {
93 // The first range which ends *non-adjacently* to our start.
94 // And we can ensure that left <= right.
95 let left = self.map.partition_point(|l| l.1 + 1 < start);
96 let min = std::cmp::min(self.map[left].0, start);
97 let max = std::cmp::max(prev_end, end);
98 self.map[right] = (min, max);
100 self.map.drain(left..right);
104 // We overlap with the previous range, increase it to
107 // Make sure we're actually going to *increase* it though --
108 // it may be that end is just inside the previously existing
110 return if end > prev_end {
111 self.map[right].1 = end;
118 // Otherwise, we don't overlap, so just insert
119 self.map.insert(right + 1, (start, end));
123 if self.map.is_empty() {
124 // Quite common in practice, and expensive to call memcpy
126 self.map.push((start, end));
128 self.map.insert(next, (start, end));
134 pub fn contains(&self, needle: I) -> bool {
135 let needle = needle.index() as u32;
136 let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else {
137 // All ranges in the map start after the new range's end
140 let (_, prev_end) = &self.map[last];
144 pub fn superset(&self, other: &IntervalSet<I>) -> bool
148 // FIXME: Performance here is probably not great. We will be doing a lot
149 // of pointless tree traversals.
150 other.iter().all(|elem| self.contains(elem))
153 pub fn is_empty(&self) -> bool {
157 /// Returns the maximum (last) element present in the set from `range`.
158 pub fn last_set_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> {
159 let start = inclusive_start(range.clone());
160 let Some(end) = inclusive_end(self.domain, range) else {
167 let Some(last) = self.map.partition_point(|r| r.0 <= end).checked_sub(1) else {
168 // All ranges in the map start after the new range's end
171 let (_, prev_end) = &self.map[last];
172 if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None }
175 pub fn insert_all(&mut self) {
177 self.map.push((0, self.domain.try_into().unwrap()));
180 pub fn union(&mut self, other: &IntervalSet<I>) -> bool
184 assert_eq!(self.domain, other.domain);
185 let mut did_insert = false;
186 for range in other.iter_intervals() {
187 did_insert |= self.insert_range(range);
193 /// This data structure optimizes for cases where the stored bits in each row
194 /// are expected to be highly contiguous (long ranges of 1s or 0s), in contrast
195 /// to BitMatrix and SparseBitMatrix which are optimized for
196 /// "random"/non-contiguous bits and cheap(er) point queries at the expense of
199 pub struct SparseIntervalMatrix<R, C>
204 rows: IndexVec<R, IntervalSet<C>>,
208 impl<R: Idx, C: Step + Idx> SparseIntervalMatrix<R, C> {
209 pub fn new(column_size: usize) -> SparseIntervalMatrix<R, C> {
210 SparseIntervalMatrix { rows: IndexVec::new(), column_size }
213 pub fn rows(&self) -> impl Iterator<Item = R> {
217 pub fn row(&self, row: R) -> Option<&IntervalSet<C>> {
221 fn ensure_row(&mut self, row: R) -> &mut IntervalSet<C> {
222 self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size));
226 pub fn union_row(&mut self, row: R, from: &IntervalSet<C>) -> bool
230 self.ensure_row(row).union(from)
233 pub fn union_rows(&mut self, read: R, write: R) -> bool
237 if read == write || self.rows.get(read).is_none() {
240 self.ensure_row(write);
241 let (read_row, write_row) = self.rows.pick2_mut(read, write);
242 write_row.union(read_row)
245 pub fn insert_all_into_row(&mut self, row: R) {
246 self.ensure_row(row).insert_all();
249 pub fn insert_range(&mut self, row: R, range: impl RangeBounds<C> + Clone) {
250 self.ensure_row(row).insert_range(range);
253 pub fn insert(&mut self, row: R, point: C) -> bool {
254 self.ensure_row(row).insert(point)
257 pub fn contains(&self, row: R, point: C) -> bool {
258 self.row(row).map_or(false, |r| r.contains(point))