]> git.lizzy.rs Git - rust.git/blob - src/librustc_data_structures/bit_set.rs
Rollup merge of #60187 - tmandry:generator-optimization, r=eddyb
[rust.git] / src / librustc_data_structures / bit_set.rs
1 use crate::indexed_vec::{Idx, IndexVec};
2 use smallvec::SmallVec;
3 use std::fmt;
4 use std::iter;
5 use std::marker::PhantomData;
6 use std::mem;
7 use std::slice;
8
9 pub type Word = u64;
10 pub const WORD_BYTES: usize = mem::size_of::<Word>();
11 pub const WORD_BITS: usize = WORD_BYTES * 8;
12
13 /// A fixed-size bitset type with a dense representation. It does not support
14 /// resizing after creation; use `GrowableBitSet` for that.
15 ///
16 /// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
17 /// just be `usize`.
18 ///
19 /// All operations that involve an element will panic if the element is equal
20 /// to or greater than the domain size. All operations that involve two bitsets
21 /// will panic if the bitsets have differing domain sizes.
22 #[derive(Clone, Eq, PartialEq, RustcDecodable, RustcEncodable)]
23 pub struct BitSet<T: Idx> {
24     domain_size: usize,
25     words: Vec<Word>,
26     marker: PhantomData<T>,
27 }
28
29 impl<T: Idx> BitSet<T> {
30     /// Creates a new, empty bitset with a given `domain_size`.
31     #[inline]
32     pub fn new_empty(domain_size: usize) -> BitSet<T> {
33         let num_words = num_words(domain_size);
34         BitSet {
35             domain_size,
36             words: vec![0; num_words],
37             marker: PhantomData,
38         }
39     }
40
41     /// Creates a new, filled bitset with a given `domain_size`.
42     #[inline]
43     pub fn new_filled(domain_size: usize) -> BitSet<T> {
44         let num_words = num_words(domain_size);
45         let mut result = BitSet {
46             domain_size,
47             words: vec![!0; num_words],
48             marker: PhantomData,
49         };
50         result.clear_excess_bits();
51         result
52     }
53
54     /// Gets the domain size.
55     pub fn domain_size(&self) -> usize {
56         self.domain_size
57     }
58
59     /// Clear all elements.
60     #[inline]
61     pub fn clear(&mut self) {
62         for word in &mut self.words {
63             *word = 0;
64         }
65     }
66
67     /// Clear excess bits in the final word.
68     fn clear_excess_bits(&mut self) {
69         let num_bits_in_final_word = self.domain_size % WORD_BITS;
70         if num_bits_in_final_word > 0 {
71             let mask = (1 << num_bits_in_final_word) - 1;
72             let final_word_idx = self.words.len() - 1;
73             self.words[final_word_idx] &= mask;
74         }
75     }
76
77     /// Efficiently overwrite `self` with `other`.
78     pub fn overwrite(&mut self, other: &BitSet<T>) {
79         assert!(self.domain_size == other.domain_size);
80         self.words.clone_from_slice(&other.words);
81     }
82
83     /// Count the number of set bits in the set.
84     pub fn count(&self) -> usize {
85         self.words.iter().map(|e| e.count_ones() as usize).sum()
86     }
87
88     /// Returns `true` if `self` contains `elem`.
89     #[inline]
90     pub fn contains(&self, elem: T) -> bool {
91         assert!(elem.index() < self.domain_size);
92         let (word_index, mask) = word_index_and_mask(elem);
93         (self.words[word_index] & mask) != 0
94     }
95
96     /// Is `self` is a (non-strict) superset of `other`?
97     #[inline]
98     pub fn superset(&self, other: &BitSet<T>) -> bool {
99         assert_eq!(self.domain_size, other.domain_size);
100         self.words.iter().zip(&other.words).all(|(a, b)| (a & b) == *b)
101     }
102
103     /// Is the set empty?
104     #[inline]
105     pub fn is_empty(&self) -> bool {
106         self.words.iter().all(|a| *a == 0)
107     }
108
109     /// Insert `elem`. Returns whether the set has changed.
110     #[inline]
111     pub fn insert(&mut self, elem: T) -> bool {
112         assert!(elem.index() < self.domain_size);
113         let (word_index, mask) = word_index_and_mask(elem);
114         let word_ref = &mut self.words[word_index];
115         let word = *word_ref;
116         let new_word = word | mask;
117         *word_ref = new_word;
118         new_word != word
119     }
120
121     /// Sets all bits to true.
122     pub fn insert_all(&mut self) {
123         for word in &mut self.words {
124             *word = !0;
125         }
126         self.clear_excess_bits();
127     }
128
129     /// Returns `true` if the set has changed.
130     #[inline]
131     pub fn remove(&mut self, elem: T) -> bool {
132         assert!(elem.index() < self.domain_size);
133         let (word_index, mask) = word_index_and_mask(elem);
134         let word_ref = &mut self.words[word_index];
135         let word = *word_ref;
136         let new_word = word & !mask;
137         *word_ref = new_word;
138         new_word != word
139     }
140
141     /// Sets `self = self | other` and returns `true` if `self` changed
142     /// (i.e., if new bits were added).
143     pub fn union(&mut self, other: &impl UnionIntoBitSet<T>) -> bool {
144         other.union_into(self)
145     }
146
147     /// Sets `self = self - other` and returns `true` if `self` changed.
148     /// (i.e., if any bits were removed).
149     pub fn subtract(&mut self, other: &impl SubtractFromBitSet<T>) -> bool {
150         other.subtract_from(self)
151     }
152
153     /// Sets `self = self & other` and return `true` if `self` changed.
154     /// (i.e., if any bits were removed).
155     pub fn intersect(&mut self, other: &BitSet<T>) -> bool {
156         assert_eq!(self.domain_size, other.domain_size);
157         bitwise(&mut self.words, &other.words, |a, b| { a & b })
158     }
159
160     /// Gets a slice of the underlying words.
161     pub fn words(&self) -> &[Word] {
162         &self.words
163     }
164
165     /// Iterates over the indices of set bits in a sorted order.
166     #[inline]
167     pub fn iter<'a>(&'a self) -> BitIter<'a, T> {
168         BitIter {
169             cur: None,
170             iter: self.words.iter().enumerate(),
171             marker: PhantomData,
172         }
173     }
174
175     /// Duplicates the set as a hybrid set.
176     pub fn to_hybrid(&self) -> HybridBitSet<T> {
177         // Note: we currently don't bother trying to make a Sparse set.
178         HybridBitSet::Dense(self.to_owned())
179     }
180 }
181
182 /// This is implemented by all the bitsets so that BitSet::union() can be
183 /// passed any type of bitset.
184 pub trait UnionIntoBitSet<T: Idx> {
185     // Performs `other = other | self`.
186     fn union_into(&self, other: &mut BitSet<T>) -> bool;
187 }
188
189 /// This is implemented by all the bitsets so that BitSet::subtract() can be
190 /// passed any type of bitset.
191 pub trait SubtractFromBitSet<T: Idx> {
192     // Performs `other = other - self`.
193     fn subtract_from(&self, other: &mut BitSet<T>) -> bool;
194 }
195
196 impl<T: Idx> UnionIntoBitSet<T> for BitSet<T> {
197     fn union_into(&self, other: &mut BitSet<T>) -> bool {
198         assert_eq!(self.domain_size, other.domain_size);
199         bitwise(&mut other.words, &self.words, |a, b| { a | b })
200     }
201 }
202
203 impl<T: Idx> SubtractFromBitSet<T> for BitSet<T> {
204     fn subtract_from(&self, other: &mut BitSet<T>) -> bool {
205         assert_eq!(self.domain_size, other.domain_size);
206         bitwise(&mut other.words, &self.words, |a, b| { a & !b })
207     }
208 }
209
210 impl<T: Idx> fmt::Debug for BitSet<T> {
211     fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
212         w.debug_list()
213          .entries(self.iter())
214          .finish()
215     }
216 }
217
218 impl<T: Idx> ToString for BitSet<T> {
219     fn to_string(&self) -> String {
220         let mut result = String::new();
221         let mut sep = '[';
222
223         // Note: this is a little endian printout of bytes.
224
225         // i tracks how many bits we have printed so far.
226         let mut i = 0;
227         for word in &self.words {
228             let mut word = *word;
229             for _ in 0..WORD_BYTES { // for each byte in `word`:
230                 let remain = self.domain_size - i;
231                 // If less than a byte remains, then mask just that many bits.
232                 let mask = if remain <= 8 { (1 << remain) - 1 } else { 0xFF };
233                 assert!(mask <= 0xFF);
234                 let byte = word & mask;
235
236                 result.push_str(&format!("{}{:02x}", sep, byte));
237
238                 if remain <= 8 { break; }
239                 word >>= 8;
240                 i += 8;
241                 sep = '-';
242             }
243             sep = '|';
244         }
245         result.push(']');
246
247         result
248     }
249 }
250
251 pub struct BitIter<'a, T: Idx> {
252     cur: Option<(Word, usize)>,
253     iter: iter::Enumerate<slice::Iter<'a, Word>>,
254     marker: PhantomData<T>
255 }
256
257 impl<'a, T: Idx> Iterator for BitIter<'a, T> {
258     type Item = T;
259     fn next(&mut self) -> Option<T> {
260         loop {
261             if let Some((ref mut word, offset)) = self.cur {
262                 let bit_pos = word.trailing_zeros() as usize;
263                 if bit_pos != WORD_BITS {
264                     let bit = 1 << bit_pos;
265                     *word ^= bit;
266                     return Some(T::new(bit_pos + offset))
267                 }
268             }
269
270             let (i, word) = self.iter.next()?;
271             self.cur = Some((*word, WORD_BITS * i));
272         }
273     }
274 }
275
276 pub trait BitSetOperator {
277     /// Combine one bitset into another.
278     fn join<T: Idx>(&self, inout_set: &mut BitSet<T>, in_set: &BitSet<T>) -> bool;
279 }
280
281 #[inline]
282 fn bitwise<Op>(out_vec: &mut [Word], in_vec: &[Word], op: Op) -> bool
283     where Op: Fn(Word, Word) -> Word
284 {
285     assert_eq!(out_vec.len(), in_vec.len());
286     let mut changed = false;
287     for (out_elem, in_elem) in out_vec.iter_mut().zip(in_vec.iter()) {
288         let old_val = *out_elem;
289         let new_val = op(old_val, *in_elem);
290         *out_elem = new_val;
291         changed |= old_val != new_val;
292     }
293     changed
294 }
295
296 const SPARSE_MAX: usize = 8;
297
298 /// A fixed-size bitset type with a sparse representation and a maximum of
299 /// `SPARSE_MAX` elements. The elements are stored as a sorted `SmallVec` with
300 /// no duplicates; although `SmallVec` can spill its elements to the heap, that
301 /// never happens within this type because of the `SPARSE_MAX` limit.
302 ///
303 /// This type is used by `HybridBitSet`; do not use directly.
304 #[derive(Clone, Debug)]
305 pub struct SparseBitSet<T: Idx> {
306     domain_size: usize,
307     elems: SmallVec<[T; SPARSE_MAX]>,
308 }
309
310 impl<T: Idx> SparseBitSet<T> {
311     fn new_empty(domain_size: usize) -> Self {
312         SparseBitSet {
313             domain_size,
314             elems: SmallVec::new()
315         }
316     }
317
318     fn len(&self) -> usize {
319         self.elems.len()
320     }
321
322     fn is_empty(&self) -> bool {
323         self.elems.len() == 0
324     }
325
326     fn contains(&self, elem: T) -> bool {
327         assert!(elem.index() < self.domain_size);
328         self.elems.contains(&elem)
329     }
330
331     fn insert(&mut self, elem: T) -> bool {
332         assert!(elem.index() < self.domain_size);
333         let changed = if let Some(i) = self.elems.iter().position(|&e| e >= elem) {
334             if self.elems[i] == elem {
335                 // `elem` is already in the set.
336                 false
337             } else {
338                 // `elem` is smaller than one or more existing elements.
339                 self.elems.insert(i, elem);
340                 true
341             }
342         } else {
343             // `elem` is larger than all existing elements.
344             self.elems.push(elem);
345             true
346         };
347         assert!(self.len() <= SPARSE_MAX);
348         changed
349     }
350
351     fn remove(&mut self, elem: T) -> bool {
352         assert!(elem.index() < self.domain_size);
353         if let Some(i) = self.elems.iter().position(|&e| e == elem) {
354             self.elems.remove(i);
355             true
356         } else {
357             false
358         }
359     }
360
361     fn to_dense(&self) -> BitSet<T> {
362         let mut dense = BitSet::new_empty(self.domain_size);
363         for elem in self.elems.iter() {
364             dense.insert(*elem);
365         }
366         dense
367     }
368
369     fn iter(&self) -> slice::Iter<'_, T> {
370         self.elems.iter()
371     }
372 }
373
374 impl<T: Idx> UnionIntoBitSet<T> for SparseBitSet<T> {
375     fn union_into(&self, other: &mut BitSet<T>) -> bool {
376         assert_eq!(self.domain_size, other.domain_size);
377         let mut changed = false;
378         for elem in self.iter() {
379             changed |= other.insert(*elem);
380         }
381         changed
382     }
383 }
384
385 impl<T: Idx> SubtractFromBitSet<T> for SparseBitSet<T> {
386     fn subtract_from(&self, other: &mut BitSet<T>) -> bool {
387         assert_eq!(self.domain_size, other.domain_size);
388         let mut changed = false;
389         for elem in self.iter() {
390             changed |= other.remove(*elem);
391         }
392         changed
393     }
394 }
395
396 /// A fixed-size bitset type with a hybrid representation: sparse when there
397 /// are up to a `SPARSE_MAX` elements in the set, but dense when there are more
398 /// than `SPARSE_MAX`.
399 ///
400 /// This type is especially efficient for sets that typically have a small
401 /// number of elements, but a large `domain_size`, and are cleared frequently.
402 ///
403 /// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
404 /// just be `usize`.
405 ///
406 /// All operations that involve an element will panic if the element is equal
407 /// to or greater than the domain size. All operations that involve two bitsets
408 /// will panic if the bitsets have differing domain sizes.
409 #[derive(Clone, Debug)]
410 pub enum HybridBitSet<T: Idx> {
411     Sparse(SparseBitSet<T>),
412     Dense(BitSet<T>),
413 }
414
415 impl<T: Idx> HybridBitSet<T> {
416     pub fn new_empty(domain_size: usize) -> Self {
417         HybridBitSet::Sparse(SparseBitSet::new_empty(domain_size))
418     }
419
420     fn domain_size(&self) -> usize {
421         match self {
422             HybridBitSet::Sparse(sparse) => sparse.domain_size,
423             HybridBitSet::Dense(dense) => dense.domain_size,
424         }
425     }
426
427     pub fn clear(&mut self) {
428         let domain_size = self.domain_size();
429         *self = HybridBitSet::new_empty(domain_size);
430     }
431
432     pub fn contains(&self, elem: T) -> bool {
433         match self {
434             HybridBitSet::Sparse(sparse) => sparse.contains(elem),
435             HybridBitSet::Dense(dense) => dense.contains(elem),
436         }
437     }
438
439     pub fn superset(&self, other: &HybridBitSet<T>) -> bool {
440         match (self, other) {
441             (HybridBitSet::Dense(self_dense), HybridBitSet::Dense(other_dense)) => {
442                 self_dense.superset(other_dense)
443             }
444             _ => {
445                 assert!(self.domain_size() == other.domain_size());
446                 other.iter().all(|elem| self.contains(elem))
447             }
448         }
449     }
450
451     pub fn is_empty(&self) -> bool {
452         match self {
453             HybridBitSet::Sparse(sparse) => sparse.is_empty(),
454             HybridBitSet::Dense(dense) => dense.is_empty(),
455         }
456     }
457
458     pub fn insert(&mut self, elem: T) -> bool {
459         // No need to check `elem` against `self.domain_size` here because all
460         // the match cases check it, one way or another.
461         match self {
462             HybridBitSet::Sparse(sparse) if sparse.len() < SPARSE_MAX => {
463                 // The set is sparse and has space for `elem`.
464                 sparse.insert(elem)
465             }
466             HybridBitSet::Sparse(sparse) if sparse.contains(elem) => {
467                 // The set is sparse and does not have space for `elem`, but
468                 // that doesn't matter because `elem` is already present.
469                 false
470             }
471             HybridBitSet::Sparse(sparse) => {
472                 // The set is sparse and full. Convert to a dense set.
473                 let mut dense = sparse.to_dense();
474                 let changed = dense.insert(elem);
475                 assert!(changed);
476                 *self = HybridBitSet::Dense(dense);
477                 changed
478             }
479             HybridBitSet::Dense(dense) => dense.insert(elem),
480         }
481     }
482
483     pub fn insert_all(&mut self) {
484         let domain_size = self.domain_size();
485         match self {
486             HybridBitSet::Sparse(_) => {
487                 *self = HybridBitSet::Dense(BitSet::new_filled(domain_size));
488             }
489             HybridBitSet::Dense(dense) => dense.insert_all(),
490         }
491     }
492
493     pub fn remove(&mut self, elem: T) -> bool {
494         // Note: we currently don't bother going from Dense back to Sparse.
495         match self {
496             HybridBitSet::Sparse(sparse) => sparse.remove(elem),
497             HybridBitSet::Dense(dense) => dense.remove(elem),
498         }
499     }
500
501     pub fn union(&mut self, other: &HybridBitSet<T>) -> bool {
502         match self {
503             HybridBitSet::Sparse(self_sparse) => {
504                 match other {
505                     HybridBitSet::Sparse(other_sparse) => {
506                         // Both sets are sparse. Add the elements in
507                         // `other_sparse` to `self` one at a time. This
508                         // may or may not cause `self` to be densified.
509                         assert_eq!(self.domain_size(), other.domain_size());
510                         let mut changed = false;
511                         for elem in other_sparse.iter() {
512                             changed |= self.insert(*elem);
513                         }
514                         changed
515                     }
516                     HybridBitSet::Dense(other_dense) => {
517                         // `self` is sparse and `other` is dense. Densify
518                         // `self` and then do the bitwise union.
519                         let mut new_dense = self_sparse.to_dense();
520                         let changed = new_dense.union(other_dense);
521                         *self = HybridBitSet::Dense(new_dense);
522                         changed
523                     }
524                 }
525             }
526
527             HybridBitSet::Dense(self_dense) => self_dense.union(other),
528         }
529     }
530
531     /// Converts to a dense set, consuming itself in the process.
532     pub fn to_dense(self) -> BitSet<T> {
533         match self {
534             HybridBitSet::Sparse(sparse) => sparse.to_dense(),
535             HybridBitSet::Dense(dense) => dense,
536         }
537     }
538
539     pub fn iter(&self) -> HybridIter<'_, T> {
540         match self {
541             HybridBitSet::Sparse(sparse) => HybridIter::Sparse(sparse.iter()),
542             HybridBitSet::Dense(dense) => HybridIter::Dense(dense.iter()),
543         }
544     }
545 }
546
547 impl<T: Idx> UnionIntoBitSet<T> for HybridBitSet<T> {
548     fn union_into(&self, other: &mut BitSet<T>) -> bool {
549         match self {
550             HybridBitSet::Sparse(sparse) => sparse.union_into(other),
551             HybridBitSet::Dense(dense) => dense.union_into(other),
552         }
553     }
554 }
555
556 impl<T: Idx> SubtractFromBitSet<T> for HybridBitSet<T> {
557     fn subtract_from(&self, other: &mut BitSet<T>) -> bool {
558         match self {
559             HybridBitSet::Sparse(sparse) => sparse.subtract_from(other),
560             HybridBitSet::Dense(dense) => dense.subtract_from(other),
561         }
562     }
563 }
564
565 pub enum HybridIter<'a, T: Idx> {
566     Sparse(slice::Iter<'a, T>),
567     Dense(BitIter<'a, T>),
568 }
569
570 impl<'a, T: Idx> Iterator for HybridIter<'a, T> {
571     type Item = T;
572
573     fn next(&mut self) -> Option<T> {
574         match self {
575             HybridIter::Sparse(sparse) => sparse.next().map(|e| *e),
576             HybridIter::Dense(dense) => dense.next(),
577         }
578     }
579 }
580
581 /// A resizable bitset type with a dense representation.
582 ///
583 /// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
584 /// just be `usize`.
585 ///
586 /// All operations that involve an element will panic if the element is equal
587 /// to or greater than the domain size.
588 #[derive(Clone, Debug, PartialEq)]
589 pub struct GrowableBitSet<T: Idx> {
590     bit_set: BitSet<T>,
591 }
592
593 impl<T: Idx> GrowableBitSet<T> {
594     /// Ensure that the set can hold at least `min_domain_size` elements.
595     pub fn ensure(&mut self, min_domain_size: usize) {
596         if self.bit_set.domain_size < min_domain_size {
597             self.bit_set.domain_size = min_domain_size;
598         }
599
600         let min_num_words = num_words(min_domain_size);
601         if self.bit_set.words.len() < min_num_words {
602             self.bit_set.words.resize(min_num_words, 0)
603         }
604     }
605
606     pub fn new_empty() -> GrowableBitSet<T> {
607         GrowableBitSet { bit_set: BitSet::new_empty(0) }
608     }
609
610     pub fn with_capacity(capacity: usize) -> GrowableBitSet<T> {
611         GrowableBitSet { bit_set: BitSet::new_empty(capacity) }
612     }
613
614     /// Returns `true` if the set has changed.
615     #[inline]
616     pub fn insert(&mut self, elem: T) -> bool {
617         self.ensure(elem.index() + 1);
618         self.bit_set.insert(elem)
619     }
620
621     #[inline]
622     pub fn contains(&self, elem: T) -> bool {
623         let (word_index, mask) = word_index_and_mask(elem);
624         if let Some(word) = self.bit_set.words.get(word_index) {
625             (word & mask) != 0
626         } else {
627             false
628         }
629     }
630 }
631
632 /// A fixed-size 2D bit matrix type with a dense representation.
633 ///
634 /// `R` and `C` are index types used to identify rows and columns respectively;
635 /// typically newtyped `usize` wrappers, but they can also just be `usize`.
636 ///
637 /// All operations that involve a row and/or column index will panic if the
638 /// index exceeds the relevant bound.
639 #[derive(Clone, Debug, Eq, PartialEq, RustcDecodable, RustcEncodable)]
640 pub struct BitMatrix<R: Idx, C: Idx> {
641     num_rows: usize,
642     num_columns: usize,
643     words: Vec<Word>,
644     marker: PhantomData<(R, C)>,
645 }
646
647 impl<R: Idx, C: Idx> BitMatrix<R, C> {
648     /// Creates a new `rows x columns` matrix, initially empty.
649     pub fn new(num_rows: usize, num_columns: usize) -> BitMatrix<R, C> {
650         // For every element, we need one bit for every other
651         // element. Round up to an even number of words.
652         let words_per_row = num_words(num_columns);
653         BitMatrix {
654             num_rows,
655             num_columns,
656             words: vec![0; num_rows * words_per_row],
657             marker: PhantomData,
658         }
659     }
660
661     /// Creates a new matrix, with `row` used as the value for every row.
662     pub fn from_row_n(row: &BitSet<C>, num_rows: usize) -> BitMatrix<R, C> {
663         let num_columns = row.domain_size();
664         let words_per_row = num_words(num_columns);
665         assert_eq!(words_per_row, row.words().len());
666         BitMatrix {
667             num_rows,
668             num_columns,
669             words: iter::repeat(row.words()).take(num_rows).flatten().cloned().collect(),
670             marker: PhantomData,
671         }
672     }
673
674     pub fn rows(&self) -> impl Iterator<Item = R> {
675         (0..self.num_rows).map(R::new)
676     }
677
678     /// The range of bits for a given row.
679     fn range(&self, row: R) -> (usize, usize) {
680         let words_per_row = num_words(self.num_columns);
681         let start = row.index() * words_per_row;
682         (start, start + words_per_row)
683     }
684
685     /// Sets the cell at `(row, column)` to true. Put another way, insert
686     /// `column` to the bitset for `row`.
687     ///
688     /// Returns `true` if this changed the matrix.
689     pub fn insert(&mut self, row: R, column: C) -> bool {
690         assert!(row.index() < self.num_rows && column.index() < self.num_columns);
691         let (start, _) = self.range(row);
692         let (word_index, mask) = word_index_and_mask(column);
693         let words = &mut self.words[..];
694         let word = words[start + word_index];
695         let new_word = word | mask;
696         words[start + word_index] = new_word;
697         word != new_word
698     }
699
700     /// Do the bits from `row` contain `column`? Put another way, is
701     /// the matrix cell at `(row, column)` true?  Put yet another way,
702     /// if the matrix represents (transitive) reachability, can
703     /// `row` reach `column`?
704     pub fn contains(&self, row: R, column: C) -> bool {
705         assert!(row.index() < self.num_rows && column.index() < self.num_columns);
706         let (start, _) = self.range(row);
707         let (word_index, mask) = word_index_and_mask(column);
708         (self.words[start + word_index] & mask) != 0
709     }
710
711     /// Returns those indices that are true in rows `a` and `b`. This
712     /// is an O(n) operation where `n` is the number of elements
713     /// (somewhat independent from the actual size of the
714     /// intersection, in particular).
715     pub fn intersect_rows(&self, row1: R, row2: R) -> Vec<C> {
716         assert!(row1.index() < self.num_rows && row2.index() < self.num_rows);
717         let (row1_start, row1_end) = self.range(row1);
718         let (row2_start, row2_end) = self.range(row2);
719         let mut result = Vec::with_capacity(self.num_columns);
720         for (base, (i, j)) in (row1_start..row1_end).zip(row2_start..row2_end).enumerate() {
721             let mut v = self.words[i] & self.words[j];
722             for bit in 0..WORD_BITS {
723                 if v == 0 {
724                     break;
725                 }
726                 if v & 0x1 != 0 {
727                     result.push(C::new(base * WORD_BITS + bit));
728                 }
729                 v >>= 1;
730             }
731         }
732         result
733     }
734
735     /// Adds the bits from row `read` to the bits from row `write`, and
736     /// returns `true` if anything changed.
737     ///
738     /// This is used when computing transitive reachability because if
739     /// you have an edge `write -> read`, because in that case
740     /// `write` can reach everything that `read` can (and
741     /// potentially more).
742     pub fn union_rows(&mut self, read: R, write: R) -> bool {
743         assert!(read.index() < self.num_rows && write.index() < self.num_rows);
744         let (read_start, read_end) = self.range(read);
745         let (write_start, write_end) = self.range(write);
746         let words = &mut self.words[..];
747         let mut changed = false;
748         for (read_index, write_index) in (read_start..read_end).zip(write_start..write_end) {
749             let word = words[write_index];
750             let new_word = word | words[read_index];
751             words[write_index] = new_word;
752             changed |= word != new_word;
753         }
754         changed
755     }
756
757     /// Adds the bits from `with` to the bits from row `write`, and
758     /// returns `true` if anything changed.
759     pub fn union_row_with(&mut self, with: &BitSet<C>, write: R) -> bool {
760         assert!(write.index() < self.num_rows);
761         assert_eq!(with.domain_size(), self.num_columns);
762         let (write_start, write_end) = self.range(write);
763         let mut changed = false;
764         for (read_index, write_index) in (0..with.words().len()).zip(write_start..write_end) {
765             let word = self.words[write_index];
766             let new_word = word | with.words()[read_index];
767             self.words[write_index] = new_word;
768             changed |= word != new_word;
769         }
770         changed
771     }
772
773     /// Sets every cell in `row` to true.
774     pub fn insert_all_into_row(&mut self, row: R) {
775         assert!(row.index() < self.num_rows);
776         let (start, end) = self.range(row);
777         let words = &mut self.words[..];
778         for index in start..end {
779             words[index] = !0;
780         }
781         self.clear_excess_bits(row);
782     }
783
784     /// Clear excess bits in the final word of the row.
785     fn clear_excess_bits(&mut self, row: R) {
786         let num_bits_in_final_word = self.num_columns % WORD_BITS;
787         if num_bits_in_final_word > 0 {
788             let mask = (1 << num_bits_in_final_word) - 1;
789             let (_, end) = self.range(row);
790             let final_word_idx = end - 1;
791             self.words[final_word_idx] &= mask;
792         }
793     }
794
795     /// Gets a slice of the underlying words.
796     pub fn words(&self) -> &[Word] {
797         &self.words
798     }
799
800     /// Iterates through all the columns set to true in a given row of
801     /// the matrix.
802     pub fn iter<'a>(&'a self, row: R) -> BitIter<'a, C> {
803         assert!(row.index() < self.num_rows);
804         let (start, end) = self.range(row);
805         BitIter {
806             cur: None,
807             iter: self.words[start..end].iter().enumerate(),
808             marker: PhantomData,
809         }
810     }
811
812     /// Returns the number of elements in `row`.
813     pub fn count(&self, row: R) -> usize {
814         let (start, end) = self.range(row);
815         self.words[start..end].iter().map(|e| e.count_ones() as usize).sum()
816     }
817 }
818
819 /// A fixed-column-size, variable-row-size 2D bit matrix with a moderately
820 /// sparse representation.
821 ///
822 /// Initially, every row has no explicit representation. If any bit within a
823 /// row is set, the entire row is instantiated as `Some(<HybridBitSet>)`.
824 /// Furthermore, any previously uninstantiated rows prior to it will be
825 /// instantiated as `None`. Those prior rows may themselves become fully
826 /// instantiated later on if any of their bits are set.
827 ///
828 /// `R` and `C` are index types used to identify rows and columns respectively;
829 /// typically newtyped `usize` wrappers, but they can also just be `usize`.
830 #[derive(Clone, Debug)]
831 pub struct SparseBitMatrix<R, C>
832 where
833     R: Idx,
834     C: Idx,
835 {
836     num_columns: usize,
837     rows: IndexVec<R, Option<HybridBitSet<C>>>,
838 }
839
840 impl<R: Idx, C: Idx> SparseBitMatrix<R, C> {
841     /// Creates a new empty sparse bit matrix with no rows or columns.
842     pub fn new(num_columns: usize) -> Self {
843         Self {
844             num_columns,
845             rows: IndexVec::new(),
846         }
847     }
848
849     fn ensure_row(&mut self, row: R) -> &mut HybridBitSet<C> {
850         // Instantiate any missing rows up to and including row `row` with an
851         // empty HybridBitSet.
852         self.rows.ensure_contains_elem(row, || None);
853
854         // Then replace row `row` with a full HybridBitSet if necessary.
855         let num_columns = self.num_columns;
856         self.rows[row].get_or_insert_with(|| HybridBitSet::new_empty(num_columns))
857     }
858
859     /// Sets the cell at `(row, column)` to true. Put another way, insert
860     /// `column` to the bitset for `row`.
861     ///
862     /// Returns `true` if this changed the matrix.
863     pub fn insert(&mut self, row: R, column: C) -> bool {
864         self.ensure_row(row).insert(column)
865     }
866
867     /// Do the bits from `row` contain `column`? Put another way, is
868     /// the matrix cell at `(row, column)` true?  Put yet another way,
869     /// if the matrix represents (transitive) reachability, can
870     /// `row` reach `column`?
871     pub fn contains(&self, row: R, column: C) -> bool {
872         self.row(row).map_or(false, |r| r.contains(column))
873     }
874
875     /// Adds the bits from row `read` to the bits from row `write`, and
876     /// returns `true` if anything changed.
877     ///
878     /// This is used when computing transitive reachability because if
879     /// you have an edge `write -> read`, because in that case
880     /// `write` can reach everything that `read` can (and
881     /// potentially more).
882     pub fn union_rows(&mut self, read: R, write: R) -> bool {
883         if read == write || self.row(read).is_none() {
884             return false;
885         }
886
887         self.ensure_row(write);
888         if let (Some(read_row), Some(write_row)) = self.rows.pick2_mut(read, write) {
889             write_row.union(read_row)
890         } else {
891             unreachable!()
892         }
893     }
894
895     /// Union a row, `from`, into the `into` row.
896     pub fn union_into_row(&mut self, into: R, from: &HybridBitSet<C>) -> bool {
897         self.ensure_row(into).union(from)
898     }
899
900     /// Insert all bits in the given row.
901     pub fn insert_all_into_row(&mut self, row: R) {
902         self.ensure_row(row).insert_all();
903     }
904
905     pub fn rows(&self) -> impl Iterator<Item = R> {
906         self.rows.indices()
907     }
908
909     /// Iterates through all the columns set to true in a given row of
910     /// the matrix.
911     pub fn iter<'a>(&'a self, row: R) -> impl Iterator<Item = C> + 'a {
912         self.row(row).into_iter().flat_map(|r| r.iter())
913     }
914
915     pub fn row(&self, row: R) -> Option<&HybridBitSet<C>> {
916         if let Some(Some(row)) = self.rows.get(row) {
917             Some(row)
918         } else {
919             None
920         }
921     }
922 }
923
924 #[inline]
925 fn num_words<T: Idx>(domain_size: T) -> usize {
926     (domain_size.index() + WORD_BITS - 1) / WORD_BITS
927 }
928
929 #[inline]
930 fn word_index_and_mask<T: Idx>(elem: T) -> (usize, Word) {
931     let elem = elem.index();
932     let word_index = elem / WORD_BITS;
933     let mask = 1 << (elem % WORD_BITS);
934     (word_index, mask)
935 }
936
937 #[test]
938 fn test_new_filled() {
939     for i in 0..128 {
940         let idx_buf = BitSet::new_filled(i);
941         let elems: Vec<usize> = idx_buf.iter().collect();
942         let expected: Vec<usize> = (0..i).collect();
943         assert_eq!(elems, expected);
944     }
945 }
946
947 #[test]
948 fn bitset_iter_works() {
949     let mut bitset: BitSet<usize> = BitSet::new_empty(100);
950     bitset.insert(1);
951     bitset.insert(10);
952     bitset.insert(19);
953     bitset.insert(62);
954     bitset.insert(63);
955     bitset.insert(64);
956     bitset.insert(65);
957     bitset.insert(66);
958     bitset.insert(99);
959     assert_eq!(
960         bitset.iter().collect::<Vec<_>>(),
961         [1, 10, 19, 62, 63, 64, 65, 66, 99]
962     );
963 }
964
965 #[test]
966 fn bitset_iter_works_2() {
967     let mut bitset: BitSet<usize> = BitSet::new_empty(320);
968     bitset.insert(0);
969     bitset.insert(127);
970     bitset.insert(191);
971     bitset.insert(255);
972     bitset.insert(319);
973     assert_eq!(bitset.iter().collect::<Vec<_>>(), [0, 127, 191, 255, 319]);
974 }
975
976 #[test]
977 fn union_two_sets() {
978     let mut set1: BitSet<usize> = BitSet::new_empty(65);
979     let mut set2: BitSet<usize> = BitSet::new_empty(65);
980     assert!(set1.insert(3));
981     assert!(!set1.insert(3));
982     assert!(set2.insert(5));
983     assert!(set2.insert(64));
984     assert!(set1.union(&set2));
985     assert!(!set1.union(&set2));
986     assert!(set1.contains(3));
987     assert!(!set1.contains(4));
988     assert!(set1.contains(5));
989     assert!(!set1.contains(63));
990     assert!(set1.contains(64));
991 }
992
993 #[test]
994 fn hybrid_bitset() {
995     let mut sparse038: HybridBitSet<usize> = HybridBitSet::new_empty(256);
996     assert!(sparse038.is_empty());
997     assert!(sparse038.insert(0));
998     assert!(sparse038.insert(1));
999     assert!(sparse038.insert(8));
1000     assert!(sparse038.insert(3));
1001     assert!(!sparse038.insert(3));
1002     assert!(sparse038.remove(1));
1003     assert!(!sparse038.is_empty());
1004     assert_eq!(sparse038.iter().collect::<Vec<_>>(), [0, 3, 8]);
1005
1006     for i in 0..256 {
1007         if i == 0 || i == 3 || i == 8 {
1008             assert!(sparse038.contains(i));
1009         } else {
1010             assert!(!sparse038.contains(i));
1011         }
1012     }
1013
1014     let mut sparse01358 = sparse038.clone();
1015     assert!(sparse01358.insert(1));
1016     assert!(sparse01358.insert(5));
1017     assert_eq!(sparse01358.iter().collect::<Vec<_>>(), [0, 1, 3, 5, 8]);
1018
1019     let mut dense10 = HybridBitSet::new_empty(256);
1020     for i in 0..10 {
1021         assert!(dense10.insert(i));
1022     }
1023     assert!(!dense10.is_empty());
1024     assert_eq!(dense10.iter().collect::<Vec<_>>(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
1025
1026     let mut dense256 = HybridBitSet::new_empty(256);
1027     assert!(dense256.is_empty());
1028     dense256.insert_all();
1029     assert!(!dense256.is_empty());
1030     for i in 0..256 {
1031         assert!(dense256.contains(i));
1032     }
1033
1034     assert!(sparse038.superset(&sparse038));    // sparse + sparse (self)
1035     assert!(sparse01358.superset(&sparse038));  // sparse + sparse
1036     assert!(dense10.superset(&sparse038));      // dense + sparse
1037     assert!(dense10.superset(&dense10));        // dense + dense (self)
1038     assert!(dense256.superset(&dense10));       // dense + dense
1039
1040     let mut hybrid = sparse038;
1041     assert!(!sparse01358.union(&hybrid));       // no change
1042     assert!(hybrid.union(&sparse01358));
1043     assert!(hybrid.superset(&sparse01358) && sparse01358.superset(&hybrid));
1044     assert!(!dense10.union(&sparse01358));
1045     assert!(!dense256.union(&dense10));
1046     let mut dense = dense10;
1047     assert!(dense.union(&dense256));
1048     assert!(dense.superset(&dense256) && dense256.superset(&dense));
1049     assert!(hybrid.union(&dense256));
1050     assert!(hybrid.superset(&dense256) && dense256.superset(&hybrid));
1051
1052     assert_eq!(dense256.iter().count(), 256);
1053     let mut dense0 = dense256;
1054     for i in 0..256 {
1055         assert!(dense0.remove(i));
1056     }
1057     assert!(!dense0.remove(0));
1058     assert!(dense0.is_empty());
1059 }
1060
1061 #[test]
1062 fn grow() {
1063     let mut set: GrowableBitSet<usize> = GrowableBitSet::with_capacity(65);
1064     for index in 0..65 {
1065         assert!(set.insert(index));
1066         assert!(!set.insert(index));
1067     }
1068     set.ensure(128);
1069
1070     // Check if the bits set before growing are still set
1071     for index in 0..65 {
1072         assert!(set.contains(index));
1073     }
1074
1075     // Check if the new bits are all un-set
1076     for index in 65..128 {
1077         assert!(!set.contains(index));
1078     }
1079
1080     // Check that we can set all new bits without running out of bounds
1081     for index in 65..128 {
1082         assert!(set.insert(index));
1083         assert!(!set.insert(index));
1084     }
1085 }
1086
1087 #[test]
1088 fn matrix_intersection() {
1089     let mut matrix: BitMatrix<usize, usize> = BitMatrix::new(200, 200);
1090
1091     // (*) Elements reachable from both 2 and 65.
1092
1093     matrix.insert(2, 3);
1094     matrix.insert(2, 6);
1095     matrix.insert(2, 10); // (*)
1096     matrix.insert(2, 64); // (*)
1097     matrix.insert(2, 65);
1098     matrix.insert(2, 130);
1099     matrix.insert(2, 160); // (*)
1100
1101     matrix.insert(64, 133);
1102
1103     matrix.insert(65, 2);
1104     matrix.insert(65, 8);
1105     matrix.insert(65, 10); // (*)
1106     matrix.insert(65, 64); // (*)
1107     matrix.insert(65, 68);
1108     matrix.insert(65, 133);
1109     matrix.insert(65, 160); // (*)
1110
1111     let intersection = matrix.intersect_rows(2, 64);
1112     assert!(intersection.is_empty());
1113
1114     let intersection = matrix.intersect_rows(2, 65);
1115     assert_eq!(intersection, &[10, 64, 160]);
1116 }
1117
1118 #[test]
1119 fn matrix_iter() {
1120     let mut matrix: BitMatrix<usize, usize> = BitMatrix::new(64, 100);
1121     matrix.insert(3, 22);
1122     matrix.insert(3, 75);
1123     matrix.insert(2, 99);
1124     matrix.insert(4, 0);
1125     matrix.union_rows(3, 5);
1126     matrix.insert_all_into_row(6);
1127
1128     let expected = [99];
1129     let mut iter = expected.iter();
1130     for i in matrix.iter(2) {
1131         let j = *iter.next().unwrap();
1132         assert_eq!(i, j);
1133     }
1134     assert!(iter.next().is_none());
1135
1136     let expected = [22, 75];
1137     let mut iter = expected.iter();
1138     assert_eq!(matrix.count(3), expected.len());
1139     for i in matrix.iter(3) {
1140         let j = *iter.next().unwrap();
1141         assert_eq!(i, j);
1142     }
1143     assert!(iter.next().is_none());
1144
1145     let expected = [0];
1146     let mut iter = expected.iter();
1147     assert_eq!(matrix.count(4), expected.len());
1148     for i in matrix.iter(4) {
1149         let j = *iter.next().unwrap();
1150         assert_eq!(i, j);
1151     }
1152     assert!(iter.next().is_none());
1153
1154     let expected = [22, 75];
1155     let mut iter = expected.iter();
1156     assert_eq!(matrix.count(5), expected.len());
1157     for i in matrix.iter(5) {
1158         let j = *iter.next().unwrap();
1159         assert_eq!(i, j);
1160     }
1161     assert!(iter.next().is_none());
1162
1163     assert_eq!(matrix.count(6), 100);
1164     let mut count = 0;
1165     for (idx, i) in matrix.iter(6).enumerate() {
1166         assert_eq!(idx, i);
1167         count += 1;
1168     }
1169     assert_eq!(count, 100);
1170
1171     if let Some(i) = matrix.iter(7).next() {
1172         panic!("expected no elements in row, but contains element {:?}", i);
1173     }
1174 }
1175
1176 #[test]
1177 fn sparse_matrix_iter() {
1178     let mut matrix: SparseBitMatrix<usize, usize> = SparseBitMatrix::new(100);
1179     matrix.insert(3, 22);
1180     matrix.insert(3, 75);
1181     matrix.insert(2, 99);
1182     matrix.insert(4, 0);
1183     matrix.union_rows(3, 5);
1184
1185     let expected = [99];
1186     let mut iter = expected.iter();
1187     for i in matrix.iter(2) {
1188         let j = *iter.next().unwrap();
1189         assert_eq!(i, j);
1190     }
1191     assert!(iter.next().is_none());
1192
1193     let expected = [22, 75];
1194     let mut iter = expected.iter();
1195     for i in matrix.iter(3) {
1196         let j = *iter.next().unwrap();
1197         assert_eq!(i, j);
1198     }
1199     assert!(iter.next().is_none());
1200
1201     let expected = [0];
1202     let mut iter = expected.iter();
1203     for i in matrix.iter(4) {
1204         let j = *iter.next().unwrap();
1205         assert_eq!(i, j);
1206     }
1207     assert!(iter.next().is_none());
1208
1209     let expected = [22, 75];
1210     let mut iter = expected.iter();
1211     for i in matrix.iter(5) {
1212         let j = *iter.next().unwrap();
1213         assert_eq!(i, j);
1214     }
1215     assert!(iter.next().is_none());
1216 }