]> git.lizzy.rs Git - rust.git/blob - src/librustc_index/bit_set.rs
Auto merge of #68358 - matthewjasper:spec-fix, r=nikomatsakis
[rust.git] / src / librustc_index / bit_set.rs
1 use crate::vec::{Idx, IndexVec};
2 use smallvec::SmallVec;
3 use std::fmt;
4 use std::iter;
5 use std::marker::PhantomData;
6 use std::mem;
7 use std::slice;
8
9 #[cfg(test)]
10 mod tests;
11
12 pub type Word = u64;
13 pub const WORD_BYTES: usize = mem::size_of::<Word>();
14 pub const WORD_BITS: usize = WORD_BYTES * 8;
15
16 /// A fixed-size bitset type with a dense representation.
17 ///
18 /// NOTE: Use [`GrowableBitSet`] if you need support for resizing after creation.
19 ///
20 /// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
21 /// just be `usize`.
22 ///
23 /// All operations that involve an element will panic if the element is equal
24 /// to or greater than the domain size. All operations that involve two bitsets
25 /// will panic if the bitsets have differing domain sizes.
26 ///
27 /// [`GrowableBitSet`]: struct.GrowableBitSet.html
28 #[derive(Clone, Eq, PartialEq, RustcDecodable, RustcEncodable)]
29 pub struct BitSet<T: Idx> {
30     domain_size: usize,
31     words: Vec<Word>,
32     marker: PhantomData<T>,
33 }
34
35 impl<T: Idx> BitSet<T> {
36     /// Creates a new, empty bitset with a given `domain_size`.
37     #[inline]
38     pub fn new_empty(domain_size: usize) -> BitSet<T> {
39         let num_words = num_words(domain_size);
40         BitSet { domain_size, words: vec![0; num_words], marker: PhantomData }
41     }
42
43     /// Creates a new, filled bitset with a given `domain_size`.
44     #[inline]
45     pub fn new_filled(domain_size: usize) -> BitSet<T> {
46         let num_words = num_words(domain_size);
47         let mut result = BitSet { domain_size, words: vec![!0; num_words], marker: PhantomData };
48         result.clear_excess_bits();
49         result
50     }
51
52     /// Gets the domain size.
53     pub fn domain_size(&self) -> usize {
54         self.domain_size
55     }
56
57     /// Clear all elements.
58     #[inline]
59     pub fn clear(&mut self) {
60         for word in &mut self.words {
61             *word = 0;
62         }
63     }
64
65     /// Clear excess bits in the final word.
66     fn clear_excess_bits(&mut self) {
67         let num_bits_in_final_word = self.domain_size % WORD_BITS;
68         if num_bits_in_final_word > 0 {
69             let mask = (1 << num_bits_in_final_word) - 1;
70             let final_word_idx = self.words.len() - 1;
71             self.words[final_word_idx] &= mask;
72         }
73     }
74
75     /// Efficiently overwrite `self` with `other`.
76     pub fn overwrite(&mut self, other: &BitSet<T>) {
77         assert!(self.domain_size == other.domain_size);
78         self.words.clone_from_slice(&other.words);
79     }
80
81     /// Count the number of set bits in the set.
82     pub fn count(&self) -> usize {
83         self.words.iter().map(|e| e.count_ones() as usize).sum()
84     }
85
86     /// Returns `true` if `self` contains `elem`.
87     #[inline]
88     pub fn contains(&self, elem: T) -> bool {
89         assert!(elem.index() < self.domain_size);
90         let (word_index, mask) = word_index_and_mask(elem);
91         (self.words[word_index] & mask) != 0
92     }
93
94     /// Is `self` is a (non-strict) superset of `other`?
95     #[inline]
96     pub fn superset(&self, other: &BitSet<T>) -> bool {
97         assert_eq!(self.domain_size, other.domain_size);
98         self.words.iter().zip(&other.words).all(|(a, b)| (a & b) == *b)
99     }
100
101     /// Is the set empty?
102     #[inline]
103     pub fn is_empty(&self) -> bool {
104         self.words.iter().all(|a| *a == 0)
105     }
106
107     /// Insert `elem`. Returns whether the set has changed.
108     #[inline]
109     pub fn insert(&mut self, elem: T) -> bool {
110         assert!(elem.index() < self.domain_size);
111         let (word_index, mask) = word_index_and_mask(elem);
112         let word_ref = &mut self.words[word_index];
113         let word = *word_ref;
114         let new_word = word | mask;
115         *word_ref = new_word;
116         new_word != word
117     }
118
119     /// Sets all bits to true.
120     pub fn insert_all(&mut self) {
121         for word in &mut self.words {
122             *word = !0;
123         }
124         self.clear_excess_bits();
125     }
126
127     /// Returns `true` if the set has changed.
128     #[inline]
129     pub fn remove(&mut self, elem: T) -> bool {
130         assert!(elem.index() < self.domain_size);
131         let (word_index, mask) = word_index_and_mask(elem);
132         let word_ref = &mut self.words[word_index];
133         let word = *word_ref;
134         let new_word = word & !mask;
135         *word_ref = new_word;
136         new_word != word
137     }
138
139     /// Sets `self = self | other` and returns `true` if `self` changed
140     /// (i.e., if new bits were added).
141     pub fn union(&mut self, other: &impl UnionIntoBitSet<T>) -> bool {
142         other.union_into(self)
143     }
144
145     /// Sets `self = self - other` and returns `true` if `self` changed.
146     /// (i.e., if any bits were removed).
147     pub fn subtract(&mut self, other: &impl SubtractFromBitSet<T>) -> bool {
148         other.subtract_from(self)
149     }
150
151     /// Sets `self = self & other` and return `true` if `self` changed.
152     /// (i.e., if any bits were removed).
153     pub fn intersect(&mut self, other: &BitSet<T>) -> bool {
154         assert_eq!(self.domain_size, other.domain_size);
155         bitwise(&mut self.words, &other.words, |a, b| a & b)
156     }
157
158     /// Gets a slice of the underlying words.
159     pub fn words(&self) -> &[Word] {
160         &self.words
161     }
162
163     /// Iterates over the indices of set bits in a sorted order.
164     #[inline]
165     pub fn iter(&self) -> BitIter<'_, T> {
166         BitIter::new(&self.words)
167     }
168
169     /// Duplicates the set as a hybrid set.
170     pub fn to_hybrid(&self) -> HybridBitSet<T> {
171         // Note: we currently don't bother trying to make a Sparse set.
172         HybridBitSet::Dense(self.to_owned())
173     }
174
175     /// Set `self = self | other`. In contrast to `union` returns `true` if the set contains at
176     /// least one bit that is not in `other` (i.e. `other` is not a superset of `self`).
177     ///
178     /// This is an optimization for union of a hybrid bitset.
179     fn reverse_union_sparse(&mut self, sparse: &SparseBitSet<T>) -> bool {
180         assert!(sparse.domain_size == self.domain_size);
181         self.clear_excess_bits();
182
183         let mut not_already = false;
184         // Index of the current word not yet merged.
185         let mut current_index = 0;
186         // Mask of bits that came from the sparse set in the current word.
187         let mut new_bit_mask = 0;
188         for (word_index, mask) in sparse.iter().map(|x| word_index_and_mask(*x)) {
189             // Next bit is in a word not inspected yet.
190             if word_index > current_index {
191                 self.words[current_index] |= new_bit_mask;
192                 // Were there any bits in the old word that did not occur in the sparse set?
193                 not_already |= (self.words[current_index] ^ new_bit_mask) != 0;
194                 // Check all words we skipped for any set bit.
195                 not_already |= self.words[current_index + 1..word_index].iter().any(|&x| x != 0);
196                 // Update next word.
197                 current_index = word_index;
198                 // Reset bit mask, no bits have been merged yet.
199                 new_bit_mask = 0;
200             }
201             // Add bit and mark it as coming from the sparse set.
202             // self.words[word_index] |= mask;
203             new_bit_mask |= mask;
204         }
205         self.words[current_index] |= new_bit_mask;
206         // Any bits in the last inspected word that were not in the sparse set?
207         not_already |= (self.words[current_index] ^ new_bit_mask) != 0;
208         // Any bits in the tail? Note `clear_excess_bits` before.
209         not_already |= self.words[current_index + 1..].iter().any(|&x| x != 0);
210
211         not_already
212     }
213 }
214
215 /// This is implemented by all the bitsets so that BitSet::union() can be
216 /// passed any type of bitset.
217 pub trait UnionIntoBitSet<T: Idx> {
218     // Performs `other = other | self`.
219     fn union_into(&self, other: &mut BitSet<T>) -> bool;
220 }
221
222 /// This is implemented by all the bitsets so that BitSet::subtract() can be
223 /// passed any type of bitset.
224 pub trait SubtractFromBitSet<T: Idx> {
225     // Performs `other = other - self`.
226     fn subtract_from(&self, other: &mut BitSet<T>) -> bool;
227 }
228
229 impl<T: Idx> UnionIntoBitSet<T> for BitSet<T> {
230     fn union_into(&self, other: &mut BitSet<T>) -> bool {
231         assert_eq!(self.domain_size, other.domain_size);
232         bitwise(&mut other.words, &self.words, |a, b| a | b)
233     }
234 }
235
236 impl<T: Idx> SubtractFromBitSet<T> for BitSet<T> {
237     fn subtract_from(&self, other: &mut BitSet<T>) -> bool {
238         assert_eq!(self.domain_size, other.domain_size);
239         bitwise(&mut other.words, &self.words, |a, b| a & !b)
240     }
241 }
242
243 impl<T: Idx> fmt::Debug for BitSet<T> {
244     fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
245         w.debug_list().entries(self.iter()).finish()
246     }
247 }
248
249 impl<T: Idx> ToString for BitSet<T> {
250     fn to_string(&self) -> String {
251         let mut result = String::new();
252         let mut sep = '[';
253
254         // Note: this is a little endian printout of bytes.
255
256         // i tracks how many bits we have printed so far.
257         let mut i = 0;
258         for word in &self.words {
259             let mut word = *word;
260             for _ in 0..WORD_BYTES {
261                 // for each byte in `word`:
262                 let remain = self.domain_size - i;
263                 // If less than a byte remains, then mask just that many bits.
264                 let mask = if remain <= 8 { (1 << remain) - 1 } else { 0xFF };
265                 assert!(mask <= 0xFF);
266                 let byte = word & mask;
267
268                 result.push_str(&format!("{}{:02x}", sep, byte));
269
270                 if remain <= 8 {
271                     break;
272                 }
273                 word >>= 8;
274                 i += 8;
275                 sep = '-';
276             }
277             sep = '|';
278         }
279         result.push(']');
280
281         result
282     }
283 }
284
285 pub struct BitIter<'a, T: Idx> {
286     /// A copy of the current word, but with any already-visited bits cleared.
287     /// (This lets us use `trailing_zeros()` to find the next set bit.) When it
288     /// is reduced to 0, we move onto the next word.
289     word: Word,
290
291     /// The offset (measured in bits) of the current word.
292     offset: usize,
293
294     /// Underlying iterator over the words.
295     iter: slice::Iter<'a, Word>,
296
297     marker: PhantomData<T>,
298 }
299
300 impl<'a, T: Idx> BitIter<'a, T> {
301     #[inline]
302     fn new(words: &'a [Word]) -> BitIter<'a, T> {
303         // We initialize `word` and `offset` to degenerate values. On the first
304         // call to `next()` we will fall through to getting the first word from
305         // `iter`, which sets `word` to the first word (if there is one) and
306         // `offset` to 0. Doing it this way saves us from having to maintain
307         // additional state about whether we have started.
308         BitIter {
309             word: 0,
310             offset: std::usize::MAX - (WORD_BITS - 1),
311             iter: words.iter(),
312             marker: PhantomData,
313         }
314     }
315 }
316
317 impl<'a, T: Idx> Iterator for BitIter<'a, T> {
318     type Item = T;
319     fn next(&mut self) -> Option<T> {
320         loop {
321             if self.word != 0 {
322                 // Get the position of the next set bit in the current word,
323                 // then clear the bit.
324                 let bit_pos = self.word.trailing_zeros() as usize;
325                 let bit = 1 << bit_pos;
326                 self.word ^= bit;
327                 return Some(T::new(bit_pos + self.offset));
328             }
329
330             // Move onto the next word. `wrapping_add()` is needed to handle
331             // the degenerate initial value given to `offset` in `new()`.
332             let word = self.iter.next()?;
333             self.word = *word;
334             self.offset = self.offset.wrapping_add(WORD_BITS);
335         }
336     }
337 }
338
339 #[inline]
340 fn bitwise<Op>(out_vec: &mut [Word], in_vec: &[Word], op: Op) -> bool
341 where
342     Op: Fn(Word, Word) -> Word,
343 {
344     assert_eq!(out_vec.len(), in_vec.len());
345     let mut changed = false;
346     for (out_elem, in_elem) in out_vec.iter_mut().zip(in_vec.iter()) {
347         let old_val = *out_elem;
348         let new_val = op(old_val, *in_elem);
349         *out_elem = new_val;
350         changed |= old_val != new_val;
351     }
352     changed
353 }
354
355 const SPARSE_MAX: usize = 8;
356
357 /// A fixed-size bitset type with a sparse representation and a maximum of
358 /// `SPARSE_MAX` elements. The elements are stored as a sorted `SmallVec` with
359 /// no duplicates; although `SmallVec` can spill its elements to the heap, that
360 /// never happens within this type because of the `SPARSE_MAX` limit.
361 ///
362 /// This type is used by `HybridBitSet`; do not use directly.
363 #[derive(Clone, Debug)]
364 pub struct SparseBitSet<T: Idx> {
365     domain_size: usize,
366     elems: SmallVec<[T; SPARSE_MAX]>,
367 }
368
369 impl<T: Idx> SparseBitSet<T> {
370     fn new_empty(domain_size: usize) -> Self {
371         SparseBitSet { domain_size, elems: SmallVec::new() }
372     }
373
374     fn len(&self) -> usize {
375         self.elems.len()
376     }
377
378     fn is_empty(&self) -> bool {
379         self.elems.len() == 0
380     }
381
382     fn contains(&self, elem: T) -> bool {
383         assert!(elem.index() < self.domain_size);
384         self.elems.contains(&elem)
385     }
386
387     fn insert(&mut self, elem: T) -> bool {
388         assert!(elem.index() < self.domain_size);
389         let changed = if let Some(i) = self.elems.iter().position(|&e| e >= elem) {
390             if self.elems[i] == elem {
391                 // `elem` is already in the set.
392                 false
393             } else {
394                 // `elem` is smaller than one or more existing elements.
395                 self.elems.insert(i, elem);
396                 true
397             }
398         } else {
399             // `elem` is larger than all existing elements.
400             self.elems.push(elem);
401             true
402         };
403         assert!(self.len() <= SPARSE_MAX);
404         changed
405     }
406
407     fn remove(&mut self, elem: T) -> bool {
408         assert!(elem.index() < self.domain_size);
409         if let Some(i) = self.elems.iter().position(|&e| e == elem) {
410             self.elems.remove(i);
411             true
412         } else {
413             false
414         }
415     }
416
417     fn to_dense(&self) -> BitSet<T> {
418         let mut dense = BitSet::new_empty(self.domain_size);
419         for elem in self.elems.iter() {
420             dense.insert(*elem);
421         }
422         dense
423     }
424
425     fn iter(&self) -> slice::Iter<'_, T> {
426         self.elems.iter()
427     }
428 }
429
430 impl<T: Idx> UnionIntoBitSet<T> for SparseBitSet<T> {
431     fn union_into(&self, other: &mut BitSet<T>) -> bool {
432         assert_eq!(self.domain_size, other.domain_size);
433         let mut changed = false;
434         for elem in self.iter() {
435             changed |= other.insert(*elem);
436         }
437         changed
438     }
439 }
440
441 impl<T: Idx> SubtractFromBitSet<T> for SparseBitSet<T> {
442     fn subtract_from(&self, other: &mut BitSet<T>) -> bool {
443         assert_eq!(self.domain_size, other.domain_size);
444         let mut changed = false;
445         for elem in self.iter() {
446             changed |= other.remove(*elem);
447         }
448         changed
449     }
450 }
451
452 /// A fixed-size bitset type with a hybrid representation: sparse when there
453 /// are up to a `SPARSE_MAX` elements in the set, but dense when there are more
454 /// than `SPARSE_MAX`.
455 ///
456 /// This type is especially efficient for sets that typically have a small
457 /// number of elements, but a large `domain_size`, and are cleared frequently.
458 ///
459 /// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
460 /// just be `usize`.
461 ///
462 /// All operations that involve an element will panic if the element is equal
463 /// to or greater than the domain size. All operations that involve two bitsets
464 /// will panic if the bitsets have differing domain sizes.
465 #[derive(Clone, Debug)]
466 pub enum HybridBitSet<T: Idx> {
467     Sparse(SparseBitSet<T>),
468     Dense(BitSet<T>),
469 }
470
471 impl<T: Idx> HybridBitSet<T> {
472     pub fn new_empty(domain_size: usize) -> Self {
473         HybridBitSet::Sparse(SparseBitSet::new_empty(domain_size))
474     }
475
476     fn domain_size(&self) -> usize {
477         match self {
478             HybridBitSet::Sparse(sparse) => sparse.domain_size,
479             HybridBitSet::Dense(dense) => dense.domain_size,
480         }
481     }
482
483     pub fn clear(&mut self) {
484         let domain_size = self.domain_size();
485         *self = HybridBitSet::new_empty(domain_size);
486     }
487
488     pub fn contains(&self, elem: T) -> bool {
489         match self {
490             HybridBitSet::Sparse(sparse) => sparse.contains(elem),
491             HybridBitSet::Dense(dense) => dense.contains(elem),
492         }
493     }
494
495     pub fn superset(&self, other: &HybridBitSet<T>) -> bool {
496         match (self, other) {
497             (HybridBitSet::Dense(self_dense), HybridBitSet::Dense(other_dense)) => {
498                 self_dense.superset(other_dense)
499             }
500             _ => {
501                 assert!(self.domain_size() == other.domain_size());
502                 other.iter().all(|elem| self.contains(elem))
503             }
504         }
505     }
506
507     pub fn is_empty(&self) -> bool {
508         match self {
509             HybridBitSet::Sparse(sparse) => sparse.is_empty(),
510             HybridBitSet::Dense(dense) => dense.is_empty(),
511         }
512     }
513
514     pub fn insert(&mut self, elem: T) -> bool {
515         // No need to check `elem` against `self.domain_size` here because all
516         // the match cases check it, one way or another.
517         match self {
518             HybridBitSet::Sparse(sparse) if sparse.len() < SPARSE_MAX => {
519                 // The set is sparse and has space for `elem`.
520                 sparse.insert(elem)
521             }
522             HybridBitSet::Sparse(sparse) if sparse.contains(elem) => {
523                 // The set is sparse and does not have space for `elem`, but
524                 // that doesn't matter because `elem` is already present.
525                 false
526             }
527             HybridBitSet::Sparse(sparse) => {
528                 // The set is sparse and full. Convert to a dense set.
529                 let mut dense = sparse.to_dense();
530                 let changed = dense.insert(elem);
531                 assert!(changed);
532                 *self = HybridBitSet::Dense(dense);
533                 changed
534             }
535             HybridBitSet::Dense(dense) => dense.insert(elem),
536         }
537     }
538
539     pub fn insert_all(&mut self) {
540         let domain_size = self.domain_size();
541         match self {
542             HybridBitSet::Sparse(_) => {
543                 *self = HybridBitSet::Dense(BitSet::new_filled(domain_size));
544             }
545             HybridBitSet::Dense(dense) => dense.insert_all(),
546         }
547     }
548
549     pub fn remove(&mut self, elem: T) -> bool {
550         // Note: we currently don't bother going from Dense back to Sparse.
551         match self {
552             HybridBitSet::Sparse(sparse) => sparse.remove(elem),
553             HybridBitSet::Dense(dense) => dense.remove(elem),
554         }
555     }
556
557     pub fn union(&mut self, other: &HybridBitSet<T>) -> bool {
558         match self {
559             HybridBitSet::Sparse(self_sparse) => {
560                 match other {
561                     HybridBitSet::Sparse(other_sparse) => {
562                         // Both sets are sparse. Add the elements in
563                         // `other_sparse` to `self` one at a time. This
564                         // may or may not cause `self` to be densified.
565                         assert_eq!(self.domain_size(), other.domain_size());
566                         let mut changed = false;
567                         for elem in other_sparse.iter() {
568                             changed |= self.insert(*elem);
569                         }
570                         changed
571                     }
572                     HybridBitSet::Dense(other_dense) => {
573                         // `self` is sparse and `other` is dense. To
574                         // merge them, we have two available strategies:
575                         // * Densify `self` then merge other
576                         // * Clone other then integrate bits from `self`
577                         // The second strategy requires dedicated method
578                         // since the usual `union` returns the wrong
579                         // result. In the dedicated case the computation
580                         // is slightly faster if the bits of the sparse
581                         // bitset map to only few words of the dense
582                         // representation, i.e. indices are near each
583                         // other.
584                         //
585                         // Benchmarking seems to suggest that the second
586                         // option is worth it.
587                         let mut new_dense = other_dense.clone();
588                         let changed = new_dense.reverse_union_sparse(self_sparse);
589                         *self = HybridBitSet::Dense(new_dense);
590                         changed
591                     }
592                 }
593             }
594
595             HybridBitSet::Dense(self_dense) => self_dense.union(other),
596         }
597     }
598
599     /// Converts to a dense set, consuming itself in the process.
600     pub fn to_dense(self) -> BitSet<T> {
601         match self {
602             HybridBitSet::Sparse(sparse) => sparse.to_dense(),
603             HybridBitSet::Dense(dense) => dense,
604         }
605     }
606
607     pub fn iter(&self) -> HybridIter<'_, T> {
608         match self {
609             HybridBitSet::Sparse(sparse) => HybridIter::Sparse(sparse.iter()),
610             HybridBitSet::Dense(dense) => HybridIter::Dense(dense.iter()),
611         }
612     }
613 }
614
615 impl<T: Idx> UnionIntoBitSet<T> for HybridBitSet<T> {
616     fn union_into(&self, other: &mut BitSet<T>) -> bool {
617         match self {
618             HybridBitSet::Sparse(sparse) => sparse.union_into(other),
619             HybridBitSet::Dense(dense) => dense.union_into(other),
620         }
621     }
622 }
623
624 impl<T: Idx> SubtractFromBitSet<T> for HybridBitSet<T> {
625     fn subtract_from(&self, other: &mut BitSet<T>) -> bool {
626         match self {
627             HybridBitSet::Sparse(sparse) => sparse.subtract_from(other),
628             HybridBitSet::Dense(dense) => dense.subtract_from(other),
629         }
630     }
631 }
632
633 pub enum HybridIter<'a, T: Idx> {
634     Sparse(slice::Iter<'a, T>),
635     Dense(BitIter<'a, T>),
636 }
637
638 impl<'a, T: Idx> Iterator for HybridIter<'a, T> {
639     type Item = T;
640
641     fn next(&mut self) -> Option<T> {
642         match self {
643             HybridIter::Sparse(sparse) => sparse.next().copied(),
644             HybridIter::Dense(dense) => dense.next(),
645         }
646     }
647 }
648
649 /// A resizable bitset type with a dense representation.
650 ///
651 /// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
652 /// just be `usize`.
653 ///
654 /// All operations that involve an element will panic if the element is equal
655 /// to or greater than the domain size.
656 #[derive(Clone, Debug, PartialEq)]
657 pub struct GrowableBitSet<T: Idx> {
658     bit_set: BitSet<T>,
659 }
660
661 impl<T: Idx> GrowableBitSet<T> {
662     /// Ensure that the set can hold at least `min_domain_size` elements.
663     pub fn ensure(&mut self, min_domain_size: usize) {
664         if self.bit_set.domain_size < min_domain_size {
665             self.bit_set.domain_size = min_domain_size;
666         }
667
668         let min_num_words = num_words(min_domain_size);
669         if self.bit_set.words.len() < min_num_words {
670             self.bit_set.words.resize(min_num_words, 0)
671         }
672     }
673
674     pub fn new_empty() -> GrowableBitSet<T> {
675         GrowableBitSet { bit_set: BitSet::new_empty(0) }
676     }
677
678     pub fn with_capacity(capacity: usize) -> GrowableBitSet<T> {
679         GrowableBitSet { bit_set: BitSet::new_empty(capacity) }
680     }
681
682     /// Returns `true` if the set has changed.
683     #[inline]
684     pub fn insert(&mut self, elem: T) -> bool {
685         self.ensure(elem.index() + 1);
686         self.bit_set.insert(elem)
687     }
688
689     #[inline]
690     pub fn contains(&self, elem: T) -> bool {
691         let (word_index, mask) = word_index_and_mask(elem);
692         if let Some(word) = self.bit_set.words.get(word_index) { (word & mask) != 0 } else { false }
693     }
694 }
695
696 /// A fixed-size 2D bit matrix type with a dense representation.
697 ///
698 /// `R` and `C` are index types used to identify rows and columns respectively;
699 /// typically newtyped `usize` wrappers, but they can also just be `usize`.
700 ///
701 /// All operations that involve a row and/or column index will panic if the
702 /// index exceeds the relevant bound.
703 #[derive(Clone, Debug, Eq, PartialEq, RustcDecodable, RustcEncodable)]
704 pub struct BitMatrix<R: Idx, C: Idx> {
705     num_rows: usize,
706     num_columns: usize,
707     words: Vec<Word>,
708     marker: PhantomData<(R, C)>,
709 }
710
711 impl<R: Idx, C: Idx> BitMatrix<R, C> {
712     /// Creates a new `rows x columns` matrix, initially empty.
713     pub fn new(num_rows: usize, num_columns: usize) -> BitMatrix<R, C> {
714         // For every element, we need one bit for every other
715         // element. Round up to an even number of words.
716         let words_per_row = num_words(num_columns);
717         BitMatrix {
718             num_rows,
719             num_columns,
720             words: vec![0; num_rows * words_per_row],
721             marker: PhantomData,
722         }
723     }
724
725     /// Creates a new matrix, with `row` used as the value for every row.
726     pub fn from_row_n(row: &BitSet<C>, num_rows: usize) -> BitMatrix<R, C> {
727         let num_columns = row.domain_size();
728         let words_per_row = num_words(num_columns);
729         assert_eq!(words_per_row, row.words().len());
730         BitMatrix {
731             num_rows,
732             num_columns,
733             words: iter::repeat(row.words()).take(num_rows).flatten().cloned().collect(),
734             marker: PhantomData,
735         }
736     }
737
738     pub fn rows(&self) -> impl Iterator<Item = R> {
739         (0..self.num_rows).map(R::new)
740     }
741
742     /// The range of bits for a given row.
743     fn range(&self, row: R) -> (usize, usize) {
744         let words_per_row = num_words(self.num_columns);
745         let start = row.index() * words_per_row;
746         (start, start + words_per_row)
747     }
748
749     /// Sets the cell at `(row, column)` to true. Put another way, insert
750     /// `column` to the bitset for `row`.
751     ///
752     /// Returns `true` if this changed the matrix.
753     pub fn insert(&mut self, row: R, column: C) -> bool {
754         assert!(row.index() < self.num_rows && column.index() < self.num_columns);
755         let (start, _) = self.range(row);
756         let (word_index, mask) = word_index_and_mask(column);
757         let words = &mut self.words[..];
758         let word = words[start + word_index];
759         let new_word = word | mask;
760         words[start + word_index] = new_word;
761         word != new_word
762     }
763
764     /// Do the bits from `row` contain `column`? Put another way, is
765     /// the matrix cell at `(row, column)` true?  Put yet another way,
766     /// if the matrix represents (transitive) reachability, can
767     /// `row` reach `column`?
768     pub fn contains(&self, row: R, column: C) -> bool {
769         assert!(row.index() < self.num_rows && column.index() < self.num_columns);
770         let (start, _) = self.range(row);
771         let (word_index, mask) = word_index_and_mask(column);
772         (self.words[start + word_index] & mask) != 0
773     }
774
775     /// Returns those indices that are true in rows `a` and `b`. This
776     /// is an O(n) operation where `n` is the number of elements
777     /// (somewhat independent from the actual size of the
778     /// intersection, in particular).
779     pub fn intersect_rows(&self, row1: R, row2: R) -> Vec<C> {
780         assert!(row1.index() < self.num_rows && row2.index() < self.num_rows);
781         let (row1_start, row1_end) = self.range(row1);
782         let (row2_start, row2_end) = self.range(row2);
783         let mut result = Vec::with_capacity(self.num_columns);
784         for (base, (i, j)) in (row1_start..row1_end).zip(row2_start..row2_end).enumerate() {
785             let mut v = self.words[i] & self.words[j];
786             for bit in 0..WORD_BITS {
787                 if v == 0 {
788                     break;
789                 }
790                 if v & 0x1 != 0 {
791                     result.push(C::new(base * WORD_BITS + bit));
792                 }
793                 v >>= 1;
794             }
795         }
796         result
797     }
798
799     /// Adds the bits from row `read` to the bits from row `write`, and
800     /// returns `true` if anything changed.
801     ///
802     /// This is used when computing transitive reachability because if
803     /// you have an edge `write -> read`, because in that case
804     /// `write` can reach everything that `read` can (and
805     /// potentially more).
806     pub fn union_rows(&mut self, read: R, write: R) -> bool {
807         assert!(read.index() < self.num_rows && write.index() < self.num_rows);
808         let (read_start, read_end) = self.range(read);
809         let (write_start, write_end) = self.range(write);
810         let words = &mut self.words[..];
811         let mut changed = false;
812         for (read_index, write_index) in (read_start..read_end).zip(write_start..write_end) {
813             let word = words[write_index];
814             let new_word = word | words[read_index];
815             words[write_index] = new_word;
816             changed |= word != new_word;
817         }
818         changed
819     }
820
821     /// Adds the bits from `with` to the bits from row `write`, and
822     /// returns `true` if anything changed.
823     pub fn union_row_with(&mut self, with: &BitSet<C>, write: R) -> bool {
824         assert!(write.index() < self.num_rows);
825         assert_eq!(with.domain_size(), self.num_columns);
826         let (write_start, write_end) = self.range(write);
827         let mut changed = false;
828         for (read_index, write_index) in (0..with.words().len()).zip(write_start..write_end) {
829             let word = self.words[write_index];
830             let new_word = word | with.words()[read_index];
831             self.words[write_index] = new_word;
832             changed |= word != new_word;
833         }
834         changed
835     }
836
837     /// Sets every cell in `row` to true.
838     pub fn insert_all_into_row(&mut self, row: R) {
839         assert!(row.index() < self.num_rows);
840         let (start, end) = self.range(row);
841         let words = &mut self.words[..];
842         for index in start..end {
843             words[index] = !0;
844         }
845         self.clear_excess_bits(row);
846     }
847
848     /// Clear excess bits in the final word of the row.
849     fn clear_excess_bits(&mut self, row: R) {
850         let num_bits_in_final_word = self.num_columns % WORD_BITS;
851         if num_bits_in_final_word > 0 {
852             let mask = (1 << num_bits_in_final_word) - 1;
853             let (_, end) = self.range(row);
854             let final_word_idx = end - 1;
855             self.words[final_word_idx] &= mask;
856         }
857     }
858
859     /// Gets a slice of the underlying words.
860     pub fn words(&self) -> &[Word] {
861         &self.words
862     }
863
864     /// Iterates through all the columns set to true in a given row of
865     /// the matrix.
866     pub fn iter(&self, row: R) -> BitIter<'_, C> {
867         assert!(row.index() < self.num_rows);
868         let (start, end) = self.range(row);
869         BitIter::new(&self.words[start..end])
870     }
871
872     /// Returns the number of elements in `row`.
873     pub fn count(&self, row: R) -> usize {
874         let (start, end) = self.range(row);
875         self.words[start..end].iter().map(|e| e.count_ones() as usize).sum()
876     }
877 }
878
879 /// A fixed-column-size, variable-row-size 2D bit matrix with a moderately
880 /// sparse representation.
881 ///
882 /// Initially, every row has no explicit representation. If any bit within a
883 /// row is set, the entire row is instantiated as `Some(<HybridBitSet>)`.
884 /// Furthermore, any previously uninstantiated rows prior to it will be
885 /// instantiated as `None`. Those prior rows may themselves become fully
886 /// instantiated later on if any of their bits are set.
887 ///
888 /// `R` and `C` are index types used to identify rows and columns respectively;
889 /// typically newtyped `usize` wrappers, but they can also just be `usize`.
890 #[derive(Clone, Debug)]
891 pub struct SparseBitMatrix<R, C>
892 where
893     R: Idx,
894     C: Idx,
895 {
896     num_columns: usize,
897     rows: IndexVec<R, Option<HybridBitSet<C>>>,
898 }
899
900 impl<R: Idx, C: Idx> SparseBitMatrix<R, C> {
901     /// Creates a new empty sparse bit matrix with no rows or columns.
902     pub fn new(num_columns: usize) -> Self {
903         Self { num_columns, rows: IndexVec::new() }
904     }
905
906     fn ensure_row(&mut self, row: R) -> &mut HybridBitSet<C> {
907         // Instantiate any missing rows up to and including row `row` with an
908         // empty HybridBitSet.
909         self.rows.ensure_contains_elem(row, || None);
910
911         // Then replace row `row` with a full HybridBitSet if necessary.
912         let num_columns = self.num_columns;
913         self.rows[row].get_or_insert_with(|| HybridBitSet::new_empty(num_columns))
914     }
915
916     /// Sets the cell at `(row, column)` to true. Put another way, insert
917     /// `column` to the bitset for `row`.
918     ///
919     /// Returns `true` if this changed the matrix.
920     pub fn insert(&mut self, row: R, column: C) -> bool {
921         self.ensure_row(row).insert(column)
922     }
923
924     /// Do the bits from `row` contain `column`? Put another way, is
925     /// the matrix cell at `(row, column)` true?  Put yet another way,
926     /// if the matrix represents (transitive) reachability, can
927     /// `row` reach `column`?
928     pub fn contains(&self, row: R, column: C) -> bool {
929         self.row(row).map_or(false, |r| r.contains(column))
930     }
931
932     /// Adds the bits from row `read` to the bits from row `write`, and
933     /// returns `true` if anything changed.
934     ///
935     /// This is used when computing transitive reachability because if
936     /// you have an edge `write -> read`, because in that case
937     /// `write` can reach everything that `read` can (and
938     /// potentially more).
939     pub fn union_rows(&mut self, read: R, write: R) -> bool {
940         if read == write || self.row(read).is_none() {
941             return false;
942         }
943
944         self.ensure_row(write);
945         if let (Some(read_row), Some(write_row)) = self.rows.pick2_mut(read, write) {
946             write_row.union(read_row)
947         } else {
948             unreachable!()
949         }
950     }
951
952     /// Union a row, `from`, into the `into` row.
953     pub fn union_into_row(&mut self, into: R, from: &HybridBitSet<C>) -> bool {
954         self.ensure_row(into).union(from)
955     }
956
957     /// Insert all bits in the given row.
958     pub fn insert_all_into_row(&mut self, row: R) {
959         self.ensure_row(row).insert_all();
960     }
961
962     pub fn rows(&self) -> impl Iterator<Item = R> {
963         self.rows.indices()
964     }
965
966     /// Iterates through all the columns set to true in a given row of
967     /// the matrix.
968     pub fn iter<'a>(&'a self, row: R) -> impl Iterator<Item = C> + 'a {
969         self.row(row).into_iter().flat_map(|r| r.iter())
970     }
971
972     pub fn row(&self, row: R) -> Option<&HybridBitSet<C>> {
973         if let Some(Some(row)) = self.rows.get(row) { Some(row) } else { None }
974     }
975 }
976
977 #[inline]
978 fn num_words<T: Idx>(domain_size: T) -> usize {
979     (domain_size.index() + WORD_BITS - 1) / WORD_BITS
980 }
981
982 #[inline]
983 fn word_index_and_mask<T: Idx>(elem: T) -> (usize, Word) {
984     let elem = elem.index();
985     let word_index = elem / WORD_BITS;
986     let mask = 1 << (elem % WORD_BITS);
987     (word_index, mask)
988 }