]> git.lizzy.rs Git - rust.git/blob - src/librustc_index/bit_set.rs
mir: use `FiniteBitSet<u32>` in polymorphization
[rust.git] / src / librustc_index / bit_set.rs
1 use crate::vec::{Idx, IndexVec};
2 use arrayvec::ArrayVec;
3 use std::fmt;
4 use std::iter;
5 use std::marker::PhantomData;
6 use std::mem;
7 use std::ops::{BitAnd, BitAndAssign, BitOrAssign, Not, Range, Shl};
8 use std::slice;
9
10 #[cfg(test)]
11 mod tests;
12
13 pub type Word = u64;
14 pub const WORD_BYTES: usize = mem::size_of::<Word>();
15 pub const WORD_BITS: usize = WORD_BYTES * 8;
16
17 /// A fixed-size bitset type with a dense representation.
18 ///
19 /// NOTE: Use [`GrowableBitSet`] if you need support for resizing after creation.
20 ///
21 /// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
22 /// just be `usize`.
23 ///
24 /// All operations that involve an element will panic if the element is equal
25 /// to or greater than the domain size. All operations that involve two bitsets
26 /// will panic if the bitsets have differing domain sizes.
27 ///
28 /// [`GrowableBitSet`]: struct.GrowableBitSet.html
29 #[derive(Clone, Eq, PartialEq, RustcDecodable, RustcEncodable)]
30 pub struct BitSet<T: Idx> {
31     domain_size: usize,
32     words: Vec<Word>,
33     marker: PhantomData<T>,
34 }
35
36 impl<T: Idx> BitSet<T> {
37     /// Creates a new, empty bitset with a given `domain_size`.
38     #[inline]
39     pub fn new_empty(domain_size: usize) -> BitSet<T> {
40         let num_words = num_words(domain_size);
41         BitSet { domain_size, words: vec![0; num_words], marker: PhantomData }
42     }
43
44     /// Creates a new, filled bitset with a given `domain_size`.
45     #[inline]
46     pub fn new_filled(domain_size: usize) -> BitSet<T> {
47         let num_words = num_words(domain_size);
48         let mut result = BitSet { domain_size, words: vec![!0; num_words], marker: PhantomData };
49         result.clear_excess_bits();
50         result
51     }
52
53     /// Gets the domain size.
54     pub fn domain_size(&self) -> usize {
55         self.domain_size
56     }
57
58     /// Clear all elements.
59     #[inline]
60     pub fn clear(&mut self) {
61         for word in &mut self.words {
62             *word = 0;
63         }
64     }
65
66     /// Clear excess bits in the final word.
67     fn clear_excess_bits(&mut self) {
68         let num_bits_in_final_word = self.domain_size % WORD_BITS;
69         if num_bits_in_final_word > 0 {
70             let mask = (1 << num_bits_in_final_word) - 1;
71             let final_word_idx = self.words.len() - 1;
72             self.words[final_word_idx] &= mask;
73         }
74     }
75
76     /// Efficiently overwrite `self` with `other`.
77     pub fn overwrite(&mut self, other: &BitSet<T>) {
78         assert!(self.domain_size == other.domain_size);
79         self.words.clone_from_slice(&other.words);
80     }
81
82     /// Count the number of set bits in the set.
83     pub fn count(&self) -> usize {
84         self.words.iter().map(|e| e.count_ones() as usize).sum()
85     }
86
87     /// Returns `true` if `self` contains `elem`.
88     #[inline]
89     pub fn contains(&self, elem: T) -> bool {
90         assert!(elem.index() < self.domain_size);
91         let (word_index, mask) = word_index_and_mask(elem);
92         (self.words[word_index] & mask) != 0
93     }
94
95     /// Is `self` is a (non-strict) superset of `other`?
96     #[inline]
97     pub fn superset(&self, other: &BitSet<T>) -> bool {
98         assert_eq!(self.domain_size, other.domain_size);
99         self.words.iter().zip(&other.words).all(|(a, b)| (a & b) == *b)
100     }
101
102     /// Is the set empty?
103     #[inline]
104     pub fn is_empty(&self) -> bool {
105         self.words.iter().all(|a| *a == 0)
106     }
107
108     /// Insert `elem`. Returns whether the set has changed.
109     #[inline]
110     pub fn insert(&mut self, elem: T) -> bool {
111         assert!(elem.index() < self.domain_size);
112         let (word_index, mask) = word_index_and_mask(elem);
113         let word_ref = &mut self.words[word_index];
114         let word = *word_ref;
115         let new_word = word | mask;
116         *word_ref = new_word;
117         new_word != word
118     }
119
120     /// Sets all bits to true.
121     pub fn insert_all(&mut self) {
122         for word in &mut self.words {
123             *word = !0;
124         }
125         self.clear_excess_bits();
126     }
127
128     /// Returns `true` if the set has changed.
129     #[inline]
130     pub fn remove(&mut self, elem: T) -> bool {
131         assert!(elem.index() < self.domain_size);
132         let (word_index, mask) = word_index_and_mask(elem);
133         let word_ref = &mut self.words[word_index];
134         let word = *word_ref;
135         let new_word = word & !mask;
136         *word_ref = new_word;
137         new_word != word
138     }
139
140     /// Sets `self = self | other` and returns `true` if `self` changed
141     /// (i.e., if new bits were added).
142     pub fn union(&mut self, other: &impl UnionIntoBitSet<T>) -> bool {
143         other.union_into(self)
144     }
145
146     /// Sets `self = self - other` and returns `true` if `self` changed.
147     /// (i.e., if any bits were removed).
148     pub fn subtract(&mut self, other: &impl SubtractFromBitSet<T>) -> bool {
149         other.subtract_from(self)
150     }
151
152     /// Sets `self = self & other` and return `true` if `self` changed.
153     /// (i.e., if any bits were removed).
154     pub fn intersect(&mut self, other: &BitSet<T>) -> bool {
155         assert_eq!(self.domain_size, other.domain_size);
156         bitwise(&mut self.words, &other.words, |a, b| a & b)
157     }
158
159     /// Gets a slice of the underlying words.
160     pub fn words(&self) -> &[Word] {
161         &self.words
162     }
163
164     /// Iterates over the indices of set bits in a sorted order.
165     #[inline]
166     pub fn iter(&self) -> BitIter<'_, T> {
167         BitIter::new(&self.words)
168     }
169
170     /// Duplicates the set as a hybrid set.
171     pub fn to_hybrid(&self) -> HybridBitSet<T> {
172         // Note: we currently don't bother trying to make a Sparse set.
173         HybridBitSet::Dense(self.to_owned())
174     }
175
176     /// Set `self = self | other`. In contrast to `union` returns `true` if the set contains at
177     /// least one bit that is not in `other` (i.e. `other` is not a superset of `self`).
178     ///
179     /// This is an optimization for union of a hybrid bitset.
180     fn reverse_union_sparse(&mut self, sparse: &SparseBitSet<T>) -> bool {
181         assert!(sparse.domain_size == self.domain_size);
182         self.clear_excess_bits();
183
184         let mut not_already = false;
185         // Index of the current word not yet merged.
186         let mut current_index = 0;
187         // Mask of bits that came from the sparse set in the current word.
188         let mut new_bit_mask = 0;
189         for (word_index, mask) in sparse.iter().map(|x| word_index_and_mask(*x)) {
190             // Next bit is in a word not inspected yet.
191             if word_index > current_index {
192                 self.words[current_index] |= new_bit_mask;
193                 // Were there any bits in the old word that did not occur in the sparse set?
194                 not_already |= (self.words[current_index] ^ new_bit_mask) != 0;
195                 // Check all words we skipped for any set bit.
196                 not_already |= self.words[current_index + 1..word_index].iter().any(|&x| x != 0);
197                 // Update next word.
198                 current_index = word_index;
199                 // Reset bit mask, no bits have been merged yet.
200                 new_bit_mask = 0;
201             }
202             // Add bit and mark it as coming from the sparse set.
203             // self.words[word_index] |= mask;
204             new_bit_mask |= mask;
205         }
206         self.words[current_index] |= new_bit_mask;
207         // Any bits in the last inspected word that were not in the sparse set?
208         not_already |= (self.words[current_index] ^ new_bit_mask) != 0;
209         // Any bits in the tail? Note `clear_excess_bits` before.
210         not_already |= self.words[current_index + 1..].iter().any(|&x| x != 0);
211
212         not_already
213     }
214 }
215
216 /// This is implemented by all the bitsets so that BitSet::union() can be
217 /// passed any type of bitset.
218 pub trait UnionIntoBitSet<T: Idx> {
219     // Performs `other = other | self`.
220     fn union_into(&self, other: &mut BitSet<T>) -> bool;
221 }
222
223 /// This is implemented by all the bitsets so that BitSet::subtract() can be
224 /// passed any type of bitset.
225 pub trait SubtractFromBitSet<T: Idx> {
226     // Performs `other = other - self`.
227     fn subtract_from(&self, other: &mut BitSet<T>) -> bool;
228 }
229
230 impl<T: Idx> UnionIntoBitSet<T> for BitSet<T> {
231     fn union_into(&self, other: &mut BitSet<T>) -> bool {
232         assert_eq!(self.domain_size, other.domain_size);
233         bitwise(&mut other.words, &self.words, |a, b| a | b)
234     }
235 }
236
237 impl<T: Idx> SubtractFromBitSet<T> for BitSet<T> {
238     fn subtract_from(&self, other: &mut BitSet<T>) -> bool {
239         assert_eq!(self.domain_size, other.domain_size);
240         bitwise(&mut other.words, &self.words, |a, b| a & !b)
241     }
242 }
243
244 impl<T: Idx> fmt::Debug for BitSet<T> {
245     fn fmt(&self, w: &mut fmt::Formatter<'_>) -> fmt::Result {
246         w.debug_list().entries(self.iter()).finish()
247     }
248 }
249
250 impl<T: Idx> ToString for BitSet<T> {
251     fn to_string(&self) -> String {
252         let mut result = String::new();
253         let mut sep = '[';
254
255         // Note: this is a little endian printout of bytes.
256
257         // i tracks how many bits we have printed so far.
258         let mut i = 0;
259         for word in &self.words {
260             let mut word = *word;
261             for _ in 0..WORD_BYTES {
262                 // for each byte in `word`:
263                 let remain = self.domain_size - i;
264                 // If less than a byte remains, then mask just that many bits.
265                 let mask = if remain <= 8 { (1 << remain) - 1 } else { 0xFF };
266                 assert!(mask <= 0xFF);
267                 let byte = word & mask;
268
269                 result.push_str(&format!("{}{:02x}", sep, byte));
270
271                 if remain <= 8 {
272                     break;
273                 }
274                 word >>= 8;
275                 i += 8;
276                 sep = '-';
277             }
278             sep = '|';
279         }
280         result.push(']');
281
282         result
283     }
284 }
285
286 pub struct BitIter<'a, T: Idx> {
287     /// A copy of the current word, but with any already-visited bits cleared.
288     /// (This lets us use `trailing_zeros()` to find the next set bit.) When it
289     /// is reduced to 0, we move onto the next word.
290     word: Word,
291
292     /// The offset (measured in bits) of the current word.
293     offset: usize,
294
295     /// Underlying iterator over the words.
296     iter: slice::Iter<'a, Word>,
297
298     marker: PhantomData<T>,
299 }
300
301 impl<'a, T: Idx> BitIter<'a, T> {
302     #[inline]
303     fn new(words: &'a [Word]) -> BitIter<'a, T> {
304         // We initialize `word` and `offset` to degenerate values. On the first
305         // call to `next()` we will fall through to getting the first word from
306         // `iter`, which sets `word` to the first word (if there is one) and
307         // `offset` to 0. Doing it this way saves us from having to maintain
308         // additional state about whether we have started.
309         BitIter {
310             word: 0,
311             offset: usize::MAX - (WORD_BITS - 1),
312             iter: words.iter(),
313             marker: PhantomData,
314         }
315     }
316 }
317
318 impl<'a, T: Idx> Iterator for BitIter<'a, T> {
319     type Item = T;
320     fn next(&mut self) -> Option<T> {
321         loop {
322             if self.word != 0 {
323                 // Get the position of the next set bit in the current word,
324                 // then clear the bit.
325                 let bit_pos = self.word.trailing_zeros() as usize;
326                 let bit = 1 << bit_pos;
327                 self.word ^= bit;
328                 return Some(T::new(bit_pos + self.offset));
329             }
330
331             // Move onto the next word. `wrapping_add()` is needed to handle
332             // the degenerate initial value given to `offset` in `new()`.
333             let word = self.iter.next()?;
334             self.word = *word;
335             self.offset = self.offset.wrapping_add(WORD_BITS);
336         }
337     }
338 }
339
340 #[inline]
341 fn bitwise<Op>(out_vec: &mut [Word], in_vec: &[Word], op: Op) -> bool
342 where
343     Op: Fn(Word, Word) -> Word,
344 {
345     assert_eq!(out_vec.len(), in_vec.len());
346     let mut changed = false;
347     for (out_elem, in_elem) in out_vec.iter_mut().zip(in_vec.iter()) {
348         let old_val = *out_elem;
349         let new_val = op(old_val, *in_elem);
350         *out_elem = new_val;
351         changed |= old_val != new_val;
352     }
353     changed
354 }
355
356 const SPARSE_MAX: usize = 8;
357
358 /// A fixed-size bitset type with a sparse representation and a maximum of
359 /// `SPARSE_MAX` elements. The elements are stored as a sorted `ArrayVec` with
360 /// no duplicates.
361 ///
362 /// This type is used by `HybridBitSet`; do not use directly.
363 #[derive(Clone, Debug)]
364 pub struct SparseBitSet<T: Idx> {
365     domain_size: usize,
366     elems: ArrayVec<[T; SPARSE_MAX]>,
367 }
368
369 impl<T: Idx> SparseBitSet<T> {
370     fn new_empty(domain_size: usize) -> Self {
371         SparseBitSet { domain_size, elems: ArrayVec::new() }
372     }
373
374     fn len(&self) -> usize {
375         self.elems.len()
376     }
377
378     fn is_empty(&self) -> bool {
379         self.elems.len() == 0
380     }
381
382     fn contains(&self, elem: T) -> bool {
383         assert!(elem.index() < self.domain_size);
384         self.elems.contains(&elem)
385     }
386
387     fn insert(&mut self, elem: T) -> bool {
388         assert!(elem.index() < self.domain_size);
389         let changed = if let Some(i) = self.elems.iter().position(|&e| e >= elem) {
390             if self.elems[i] == elem {
391                 // `elem` is already in the set.
392                 false
393             } else {
394                 // `elem` is smaller than one or more existing elements.
395                 self.elems.insert(i, elem);
396                 true
397             }
398         } else {
399             // `elem` is larger than all existing elements.
400             self.elems.push(elem);
401             true
402         };
403         assert!(self.len() <= SPARSE_MAX);
404         changed
405     }
406
407     fn remove(&mut self, elem: T) -> bool {
408         assert!(elem.index() < self.domain_size);
409         if let Some(i) = self.elems.iter().position(|&e| e == elem) {
410             self.elems.remove(i);
411             true
412         } else {
413             false
414         }
415     }
416
417     fn to_dense(&self) -> BitSet<T> {
418         let mut dense = BitSet::new_empty(self.domain_size);
419         for elem in self.elems.iter() {
420             dense.insert(*elem);
421         }
422         dense
423     }
424
425     fn iter(&self) -> slice::Iter<'_, T> {
426         self.elems.iter()
427     }
428 }
429
430 impl<T: Idx> UnionIntoBitSet<T> for SparseBitSet<T> {
431     fn union_into(&self, other: &mut BitSet<T>) -> bool {
432         assert_eq!(self.domain_size, other.domain_size);
433         let mut changed = false;
434         for elem in self.iter() {
435             changed |= other.insert(*elem);
436         }
437         changed
438     }
439 }
440
441 impl<T: Idx> SubtractFromBitSet<T> for SparseBitSet<T> {
442     fn subtract_from(&self, other: &mut BitSet<T>) -> bool {
443         assert_eq!(self.domain_size, other.domain_size);
444         let mut changed = false;
445         for elem in self.iter() {
446             changed |= other.remove(*elem);
447         }
448         changed
449     }
450 }
451
452 /// A fixed-size bitset type with a hybrid representation: sparse when there
453 /// are up to a `SPARSE_MAX` elements in the set, but dense when there are more
454 /// than `SPARSE_MAX`.
455 ///
456 /// This type is especially efficient for sets that typically have a small
457 /// number of elements, but a large `domain_size`, and are cleared frequently.
458 ///
459 /// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
460 /// just be `usize`.
461 ///
462 /// All operations that involve an element will panic if the element is equal
463 /// to or greater than the domain size. All operations that involve two bitsets
464 /// will panic if the bitsets have differing domain sizes.
465 #[derive(Clone, Debug)]
466 pub enum HybridBitSet<T: Idx> {
467     Sparse(SparseBitSet<T>),
468     Dense(BitSet<T>),
469 }
470
471 impl<T: Idx> HybridBitSet<T> {
472     pub fn new_empty(domain_size: usize) -> Self {
473         HybridBitSet::Sparse(SparseBitSet::new_empty(domain_size))
474     }
475
476     fn domain_size(&self) -> usize {
477         match self {
478             HybridBitSet::Sparse(sparse) => sparse.domain_size,
479             HybridBitSet::Dense(dense) => dense.domain_size,
480         }
481     }
482
483     pub fn clear(&mut self) {
484         let domain_size = self.domain_size();
485         *self = HybridBitSet::new_empty(domain_size);
486     }
487
488     pub fn contains(&self, elem: T) -> bool {
489         match self {
490             HybridBitSet::Sparse(sparse) => sparse.contains(elem),
491             HybridBitSet::Dense(dense) => dense.contains(elem),
492         }
493     }
494
495     pub fn superset(&self, other: &HybridBitSet<T>) -> bool {
496         match (self, other) {
497             (HybridBitSet::Dense(self_dense), HybridBitSet::Dense(other_dense)) => {
498                 self_dense.superset(other_dense)
499             }
500             _ => {
501                 assert!(self.domain_size() == other.domain_size());
502                 other.iter().all(|elem| self.contains(elem))
503             }
504         }
505     }
506
507     pub fn is_empty(&self) -> bool {
508         match self {
509             HybridBitSet::Sparse(sparse) => sparse.is_empty(),
510             HybridBitSet::Dense(dense) => dense.is_empty(),
511         }
512     }
513
514     pub fn insert(&mut self, elem: T) -> bool {
515         // No need to check `elem` against `self.domain_size` here because all
516         // the match cases check it, one way or another.
517         match self {
518             HybridBitSet::Sparse(sparse) if sparse.len() < SPARSE_MAX => {
519                 // The set is sparse and has space for `elem`.
520                 sparse.insert(elem)
521             }
522             HybridBitSet::Sparse(sparse) if sparse.contains(elem) => {
523                 // The set is sparse and does not have space for `elem`, but
524                 // that doesn't matter because `elem` is already present.
525                 false
526             }
527             HybridBitSet::Sparse(sparse) => {
528                 // The set is sparse and full. Convert to a dense set.
529                 let mut dense = sparse.to_dense();
530                 let changed = dense.insert(elem);
531                 assert!(changed);
532                 *self = HybridBitSet::Dense(dense);
533                 changed
534             }
535             HybridBitSet::Dense(dense) => dense.insert(elem),
536         }
537     }
538
539     pub fn insert_all(&mut self) {
540         let domain_size = self.domain_size();
541         match self {
542             HybridBitSet::Sparse(_) => {
543                 *self = HybridBitSet::Dense(BitSet::new_filled(domain_size));
544             }
545             HybridBitSet::Dense(dense) => dense.insert_all(),
546         }
547     }
548
549     pub fn remove(&mut self, elem: T) -> bool {
550         // Note: we currently don't bother going from Dense back to Sparse.
551         match self {
552             HybridBitSet::Sparse(sparse) => sparse.remove(elem),
553             HybridBitSet::Dense(dense) => dense.remove(elem),
554         }
555     }
556
557     pub fn union(&mut self, other: &HybridBitSet<T>) -> bool {
558         match self {
559             HybridBitSet::Sparse(self_sparse) => {
560                 match other {
561                     HybridBitSet::Sparse(other_sparse) => {
562                         // Both sets are sparse. Add the elements in
563                         // `other_sparse` to `self` one at a time. This
564                         // may or may not cause `self` to be densified.
565                         assert_eq!(self.domain_size(), other.domain_size());
566                         let mut changed = false;
567                         for elem in other_sparse.iter() {
568                             changed |= self.insert(*elem);
569                         }
570                         changed
571                     }
572                     HybridBitSet::Dense(other_dense) => {
573                         // `self` is sparse and `other` is dense. To
574                         // merge them, we have two available strategies:
575                         // * Densify `self` then merge other
576                         // * Clone other then integrate bits from `self`
577                         // The second strategy requires dedicated method
578                         // since the usual `union` returns the wrong
579                         // result. In the dedicated case the computation
580                         // is slightly faster if the bits of the sparse
581                         // bitset map to only few words of the dense
582                         // representation, i.e. indices are near each
583                         // other.
584                         //
585                         // Benchmarking seems to suggest that the second
586                         // option is worth it.
587                         let mut new_dense = other_dense.clone();
588                         let changed = new_dense.reverse_union_sparse(self_sparse);
589                         *self = HybridBitSet::Dense(new_dense);
590                         changed
591                     }
592                 }
593             }
594
595             HybridBitSet::Dense(self_dense) => self_dense.union(other),
596         }
597     }
598
599     /// Converts to a dense set, consuming itself in the process.
600     pub fn to_dense(self) -> BitSet<T> {
601         match self {
602             HybridBitSet::Sparse(sparse) => sparse.to_dense(),
603             HybridBitSet::Dense(dense) => dense,
604         }
605     }
606
607     pub fn iter(&self) -> HybridIter<'_, T> {
608         match self {
609             HybridBitSet::Sparse(sparse) => HybridIter::Sparse(sparse.iter()),
610             HybridBitSet::Dense(dense) => HybridIter::Dense(dense.iter()),
611         }
612     }
613 }
614
615 impl<T: Idx> UnionIntoBitSet<T> for HybridBitSet<T> {
616     fn union_into(&self, other: &mut BitSet<T>) -> bool {
617         match self {
618             HybridBitSet::Sparse(sparse) => sparse.union_into(other),
619             HybridBitSet::Dense(dense) => dense.union_into(other),
620         }
621     }
622 }
623
624 impl<T: Idx> SubtractFromBitSet<T> for HybridBitSet<T> {
625     fn subtract_from(&self, other: &mut BitSet<T>) -> bool {
626         match self {
627             HybridBitSet::Sparse(sparse) => sparse.subtract_from(other),
628             HybridBitSet::Dense(dense) => dense.subtract_from(other),
629         }
630     }
631 }
632
633 pub enum HybridIter<'a, T: Idx> {
634     Sparse(slice::Iter<'a, T>),
635     Dense(BitIter<'a, T>),
636 }
637
638 impl<'a, T: Idx> Iterator for HybridIter<'a, T> {
639     type Item = T;
640
641     fn next(&mut self) -> Option<T> {
642         match self {
643             HybridIter::Sparse(sparse) => sparse.next().copied(),
644             HybridIter::Dense(dense) => dense.next(),
645         }
646     }
647 }
648
649 /// A resizable bitset type with a dense representation.
650 ///
651 /// `T` is an index type, typically a newtyped `usize` wrapper, but it can also
652 /// just be `usize`.
653 ///
654 /// All operations that involve an element will panic if the element is equal
655 /// to or greater than the domain size.
656 #[derive(Clone, Debug, PartialEq)]
657 pub struct GrowableBitSet<T: Idx> {
658     bit_set: BitSet<T>,
659 }
660
661 impl<T: Idx> GrowableBitSet<T> {
662     /// Ensure that the set can hold at least `min_domain_size` elements.
663     pub fn ensure(&mut self, min_domain_size: usize) {
664         if self.bit_set.domain_size < min_domain_size {
665             self.bit_set.domain_size = min_domain_size;
666         }
667
668         let min_num_words = num_words(min_domain_size);
669         if self.bit_set.words.len() < min_num_words {
670             self.bit_set.words.resize(min_num_words, 0)
671         }
672     }
673
674     pub fn new_empty() -> GrowableBitSet<T> {
675         GrowableBitSet { bit_set: BitSet::new_empty(0) }
676     }
677
678     pub fn with_capacity(capacity: usize) -> GrowableBitSet<T> {
679         GrowableBitSet { bit_set: BitSet::new_empty(capacity) }
680     }
681
682     /// Returns `true` if the set has changed.
683     #[inline]
684     pub fn insert(&mut self, elem: T) -> bool {
685         self.ensure(elem.index() + 1);
686         self.bit_set.insert(elem)
687     }
688
689     #[inline]
690     pub fn contains(&self, elem: T) -> bool {
691         let (word_index, mask) = word_index_and_mask(elem);
692         if let Some(word) = self.bit_set.words.get(word_index) { (word & mask) != 0 } else { false }
693     }
694 }
695
696 /// A fixed-size 2D bit matrix type with a dense representation.
697 ///
698 /// `R` and `C` are index types used to identify rows and columns respectively;
699 /// typically newtyped `usize` wrappers, but they can also just be `usize`.
700 ///
701 /// All operations that involve a row and/or column index will panic if the
702 /// index exceeds the relevant bound.
703 #[derive(Clone, Eq, PartialEq, RustcDecodable, RustcEncodable)]
704 pub struct BitMatrix<R: Idx, C: Idx> {
705     num_rows: usize,
706     num_columns: usize,
707     words: Vec<Word>,
708     marker: PhantomData<(R, C)>,
709 }
710
711 impl<R: Idx, C: Idx> BitMatrix<R, C> {
712     /// Creates a new `rows x columns` matrix, initially empty.
713     pub fn new(num_rows: usize, num_columns: usize) -> BitMatrix<R, C> {
714         // For every element, we need one bit for every other
715         // element. Round up to an even number of words.
716         let words_per_row = num_words(num_columns);
717         BitMatrix {
718             num_rows,
719             num_columns,
720             words: vec![0; num_rows * words_per_row],
721             marker: PhantomData,
722         }
723     }
724
725     /// Creates a new matrix, with `row` used as the value for every row.
726     pub fn from_row_n(row: &BitSet<C>, num_rows: usize) -> BitMatrix<R, C> {
727         let num_columns = row.domain_size();
728         let words_per_row = num_words(num_columns);
729         assert_eq!(words_per_row, row.words().len());
730         BitMatrix {
731             num_rows,
732             num_columns,
733             words: iter::repeat(row.words()).take(num_rows).flatten().cloned().collect(),
734             marker: PhantomData,
735         }
736     }
737
738     pub fn rows(&self) -> impl Iterator<Item = R> {
739         (0..self.num_rows).map(R::new)
740     }
741
742     /// The range of bits for a given row.
743     fn range(&self, row: R) -> (usize, usize) {
744         let words_per_row = num_words(self.num_columns);
745         let start = row.index() * words_per_row;
746         (start, start + words_per_row)
747     }
748
749     /// Sets the cell at `(row, column)` to true. Put another way, insert
750     /// `column` to the bitset for `row`.
751     ///
752     /// Returns `true` if this changed the matrix.
753     pub fn insert(&mut self, row: R, column: C) -> bool {
754         assert!(row.index() < self.num_rows && column.index() < self.num_columns);
755         let (start, _) = self.range(row);
756         let (word_index, mask) = word_index_and_mask(column);
757         let words = &mut self.words[..];
758         let word = words[start + word_index];
759         let new_word = word | mask;
760         words[start + word_index] = new_word;
761         word != new_word
762     }
763
764     /// Do the bits from `row` contain `column`? Put another way, is
765     /// the matrix cell at `(row, column)` true?  Put yet another way,
766     /// if the matrix represents (transitive) reachability, can
767     /// `row` reach `column`?
768     pub fn contains(&self, row: R, column: C) -> bool {
769         assert!(row.index() < self.num_rows && column.index() < self.num_columns);
770         let (start, _) = self.range(row);
771         let (word_index, mask) = word_index_and_mask(column);
772         (self.words[start + word_index] & mask) != 0
773     }
774
775     /// Returns those indices that are true in rows `a` and `b`. This
776     /// is an *O*(*n*) operation where *n* is the number of elements
777     /// (somewhat independent from the actual size of the
778     /// intersection, in particular).
779     pub fn intersect_rows(&self, row1: R, row2: R) -> Vec<C> {
780         assert!(row1.index() < self.num_rows && row2.index() < self.num_rows);
781         let (row1_start, row1_end) = self.range(row1);
782         let (row2_start, row2_end) = self.range(row2);
783         let mut result = Vec::with_capacity(self.num_columns);
784         for (base, (i, j)) in (row1_start..row1_end).zip(row2_start..row2_end).enumerate() {
785             let mut v = self.words[i] & self.words[j];
786             for bit in 0..WORD_BITS {
787                 if v == 0 {
788                     break;
789                 }
790                 if v & 0x1 != 0 {
791                     result.push(C::new(base * WORD_BITS + bit));
792                 }
793                 v >>= 1;
794             }
795         }
796         result
797     }
798
799     /// Adds the bits from row `read` to the bits from row `write`, and
800     /// returns `true` if anything changed.
801     ///
802     /// This is used when computing transitive reachability because if
803     /// you have an edge `write -> read`, because in that case
804     /// `write` can reach everything that `read` can (and
805     /// potentially more).
806     pub fn union_rows(&mut self, read: R, write: R) -> bool {
807         assert!(read.index() < self.num_rows && write.index() < self.num_rows);
808         let (read_start, read_end) = self.range(read);
809         let (write_start, write_end) = self.range(write);
810         let words = &mut self.words[..];
811         let mut changed = false;
812         for (read_index, write_index) in (read_start..read_end).zip(write_start..write_end) {
813             let word = words[write_index];
814             let new_word = word | words[read_index];
815             words[write_index] = new_word;
816             changed |= word != new_word;
817         }
818         changed
819     }
820
821     /// Adds the bits from `with` to the bits from row `write`, and
822     /// returns `true` if anything changed.
823     pub fn union_row_with(&mut self, with: &BitSet<C>, write: R) -> bool {
824         assert!(write.index() < self.num_rows);
825         assert_eq!(with.domain_size(), self.num_columns);
826         let (write_start, write_end) = self.range(write);
827         let mut changed = false;
828         for (read_index, write_index) in (0..with.words().len()).zip(write_start..write_end) {
829             let word = self.words[write_index];
830             let new_word = word | with.words()[read_index];
831             self.words[write_index] = new_word;
832             changed |= word != new_word;
833         }
834         changed
835     }
836
837     /// Sets every cell in `row` to true.
838     pub fn insert_all_into_row(&mut self, row: R) {
839         assert!(row.index() < self.num_rows);
840         let (start, end) = self.range(row);
841         let words = &mut self.words[..];
842         for index in start..end {
843             words[index] = !0;
844         }
845         self.clear_excess_bits(row);
846     }
847
848     /// Clear excess bits in the final word of the row.
849     fn clear_excess_bits(&mut self, row: R) {
850         let num_bits_in_final_word = self.num_columns % WORD_BITS;
851         if num_bits_in_final_word > 0 {
852             let mask = (1 << num_bits_in_final_word) - 1;
853             let (_, end) = self.range(row);
854             let final_word_idx = end - 1;
855             self.words[final_word_idx] &= mask;
856         }
857     }
858
859     /// Gets a slice of the underlying words.
860     pub fn words(&self) -> &[Word] {
861         &self.words
862     }
863
864     /// Iterates through all the columns set to true in a given row of
865     /// the matrix.
866     pub fn iter(&self, row: R) -> BitIter<'_, C> {
867         assert!(row.index() < self.num_rows);
868         let (start, end) = self.range(row);
869         BitIter::new(&self.words[start..end])
870     }
871
872     /// Returns the number of elements in `row`.
873     pub fn count(&self, row: R) -> usize {
874         let (start, end) = self.range(row);
875         self.words[start..end].iter().map(|e| e.count_ones() as usize).sum()
876     }
877 }
878
879 impl<R: Idx, C: Idx> fmt::Debug for BitMatrix<R, C> {
880     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
881         /// Forces its contents to print in regular mode instead of alternate mode.
882         struct OneLinePrinter<T>(T);
883         impl<T: fmt::Debug> fmt::Debug for OneLinePrinter<T> {
884             fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
885                 write!(fmt, "{:?}", self.0)
886             }
887         }
888
889         write!(fmt, "BitMatrix({}x{}) ", self.num_rows, self.num_columns)?;
890         let items = self.rows().flat_map(|r| self.iter(r).map(move |c| (r, c)));
891         fmt.debug_set().entries(items.map(OneLinePrinter)).finish()
892     }
893 }
894
895 /// A fixed-column-size, variable-row-size 2D bit matrix with a moderately
896 /// sparse representation.
897 ///
898 /// Initially, every row has no explicit representation. If any bit within a
899 /// row is set, the entire row is instantiated as `Some(<HybridBitSet>)`.
900 /// Furthermore, any previously uninstantiated rows prior to it will be
901 /// instantiated as `None`. Those prior rows may themselves become fully
902 /// instantiated later on if any of their bits are set.
903 ///
904 /// `R` and `C` are index types used to identify rows and columns respectively;
905 /// typically newtyped `usize` wrappers, but they can also just be `usize`.
906 #[derive(Clone, Debug)]
907 pub struct SparseBitMatrix<R, C>
908 where
909     R: Idx,
910     C: Idx,
911 {
912     num_columns: usize,
913     rows: IndexVec<R, Option<HybridBitSet<C>>>,
914 }
915
916 impl<R: Idx, C: Idx> SparseBitMatrix<R, C> {
917     /// Creates a new empty sparse bit matrix with no rows or columns.
918     pub fn new(num_columns: usize) -> Self {
919         Self { num_columns, rows: IndexVec::new() }
920     }
921
922     fn ensure_row(&mut self, row: R) -> &mut HybridBitSet<C> {
923         // Instantiate any missing rows up to and including row `row` with an
924         // empty HybridBitSet.
925         self.rows.ensure_contains_elem(row, || None);
926
927         // Then replace row `row` with a full HybridBitSet if necessary.
928         let num_columns = self.num_columns;
929         self.rows[row].get_or_insert_with(|| HybridBitSet::new_empty(num_columns))
930     }
931
932     /// Sets the cell at `(row, column)` to true. Put another way, insert
933     /// `column` to the bitset for `row`.
934     ///
935     /// Returns `true` if this changed the matrix.
936     pub fn insert(&mut self, row: R, column: C) -> bool {
937         self.ensure_row(row).insert(column)
938     }
939
940     /// Do the bits from `row` contain `column`? Put another way, is
941     /// the matrix cell at `(row, column)` true?  Put yet another way,
942     /// if the matrix represents (transitive) reachability, can
943     /// `row` reach `column`?
944     pub fn contains(&self, row: R, column: C) -> bool {
945         self.row(row).map_or(false, |r| r.contains(column))
946     }
947
948     /// Adds the bits from row `read` to the bits from row `write`, and
949     /// returns `true` if anything changed.
950     ///
951     /// This is used when computing transitive reachability because if
952     /// you have an edge `write -> read`, because in that case
953     /// `write` can reach everything that `read` can (and
954     /// potentially more).
955     pub fn union_rows(&mut self, read: R, write: R) -> bool {
956         if read == write || self.row(read).is_none() {
957             return false;
958         }
959
960         self.ensure_row(write);
961         if let (Some(read_row), Some(write_row)) = self.rows.pick2_mut(read, write) {
962             write_row.union(read_row)
963         } else {
964             unreachable!()
965         }
966     }
967
968     /// Union a row, `from`, into the `into` row.
969     pub fn union_into_row(&mut self, into: R, from: &HybridBitSet<C>) -> bool {
970         self.ensure_row(into).union(from)
971     }
972
973     /// Insert all bits in the given row.
974     pub fn insert_all_into_row(&mut self, row: R) {
975         self.ensure_row(row).insert_all();
976     }
977
978     pub fn rows(&self) -> impl Iterator<Item = R> {
979         self.rows.indices()
980     }
981
982     /// Iterates through all the columns set to true in a given row of
983     /// the matrix.
984     pub fn iter<'a>(&'a self, row: R) -> impl Iterator<Item = C> + 'a {
985         self.row(row).into_iter().flat_map(|r| r.iter())
986     }
987
988     pub fn row(&self, row: R) -> Option<&HybridBitSet<C>> {
989         if let Some(Some(row)) = self.rows.get(row) { Some(row) } else { None }
990     }
991 }
992
993 #[inline]
994 fn num_words<T: Idx>(domain_size: T) -> usize {
995     (domain_size.index() + WORD_BITS - 1) / WORD_BITS
996 }
997
998 #[inline]
999 fn word_index_and_mask<T: Idx>(elem: T) -> (usize, Word) {
1000     let elem = elem.index();
1001     let word_index = elem / WORD_BITS;
1002     let mask = 1 << (elem % WORD_BITS);
1003     (word_index, mask)
1004 }
1005
1006 /// Integral type used to represent the bit set.
1007 pub trait FiniteBitSetTy:
1008     BitAnd<Output = Self>
1009     + BitAndAssign
1010     + BitOrAssign
1011     + Clone
1012     + Copy
1013     + Shl
1014     + Not<Output = Self>
1015     + PartialEq
1016     + Sized
1017 {
1018     /// Size of the domain representable by this type, e.g. 64 for `u64`.
1019     const DOMAIN_SIZE: u32;
1020
1021     /// Value which represents the `FiniteBitSet` having every bit set.
1022     const FILLED: Self;
1023     /// Value which represents the `FiniteBitSet` having no bits set.
1024     const EMPTY: Self;
1025
1026     /// Value for one as the integral type.
1027     const ONE: Self;
1028     /// Value for zero as the integral type.
1029     const ZERO: Self;
1030
1031     /// Perform a checked left shift on the integral type.
1032     fn checked_shl(self, rhs: u32) -> Option<Self>;
1033     /// Perform a checked right shift on the integral type.
1034     fn checked_shr(self, rhs: u32) -> Option<Self>;
1035 }
1036
1037 impl FiniteBitSetTy for u32 {
1038     const DOMAIN_SIZE: u32 = 32;
1039
1040     const FILLED: Self = Self::MAX;
1041     const EMPTY: Self = Self::MIN;
1042
1043     const ONE: Self = 1u32;
1044     const ZERO: Self = 0u32;
1045
1046     fn checked_shl(self, rhs: u32) -> Option<Self> {
1047         self.checked_shl(rhs)
1048     }
1049
1050     fn checked_shr(self, rhs: u32) -> Option<Self> {
1051         self.checked_shr(rhs)
1052     }
1053 }
1054
1055 impl std::fmt::Debug for FiniteBitSet<u32> {
1056     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1057         write!(f, "{:032b}", self.0)
1058     }
1059 }
1060
1061 impl FiniteBitSetTy for u64 {
1062     const DOMAIN_SIZE: u32 = 64;
1063
1064     const FILLED: Self = Self::MAX;
1065     const EMPTY: Self = Self::MIN;
1066
1067     const ONE: Self = 1u64;
1068     const ZERO: Self = 0u64;
1069
1070     fn checked_shl(self, rhs: u32) -> Option<Self> {
1071         self.checked_shl(rhs)
1072     }
1073
1074     fn checked_shr(self, rhs: u32) -> Option<Self> {
1075         self.checked_shr(rhs)
1076     }
1077 }
1078
1079 impl std::fmt::Debug for FiniteBitSet<u64> {
1080     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1081         write!(f, "{:064b}", self.0)
1082     }
1083 }
1084
1085 impl FiniteBitSetTy for u128 {
1086     const DOMAIN_SIZE: u32 = 128;
1087
1088     const FILLED: Self = Self::MAX;
1089     const EMPTY: Self = Self::MIN;
1090
1091     const ONE: Self = 1u128;
1092     const ZERO: Self = 0u128;
1093
1094     fn checked_shl(self, rhs: u32) -> Option<Self> {
1095         self.checked_shl(rhs)
1096     }
1097
1098     fn checked_shr(self, rhs: u32) -> Option<Self> {
1099         self.checked_shr(rhs)
1100     }
1101 }
1102
1103 impl std::fmt::Debug for FiniteBitSet<u128> {
1104     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1105         write!(f, "{:0128b}", self.0)
1106     }
1107 }
1108
1109 /// A fixed-sized bitset type represented by an integer type. Indices outwith than the range
1110 /// representable by `T` are considered set.
1111 #[derive(Copy, Clone, Eq, PartialEq, RustcDecodable, RustcEncodable)]
1112 pub struct FiniteBitSet<T: FiniteBitSetTy>(pub T);
1113
1114 impl<T: FiniteBitSetTy> FiniteBitSet<T> {
1115     /// Creates a new, empty bitset.
1116     pub fn new_empty() -> Self {
1117         Self(T::EMPTY)
1118     }
1119
1120     /// Sets the `index`th bit.
1121     pub fn set(&mut self, index: u32) {
1122         self.0 |= T::ONE.checked_shl(index).unwrap_or(T::ZERO);
1123     }
1124
1125     /// Unsets the `index`th bit.
1126     pub fn clear(&mut self, index: u32) {
1127         self.0 &= !T::ONE.checked_shl(index).unwrap_or(T::ZERO);
1128     }
1129
1130     /// Sets the `i`th to `j`th bits.
1131     pub fn set_range(&mut self, range: Range<u32>) {
1132         let bits = T::FILLED
1133             .checked_shl(range.end - range.start)
1134             .unwrap_or(T::ZERO)
1135             .not()
1136             .checked_shl(range.start)
1137             .unwrap_or(T::ZERO);
1138         self.0 |= bits;
1139     }
1140
1141     /// Is the set empty?
1142     pub fn is_empty(&self) -> bool {
1143         self.0 == T::EMPTY
1144     }
1145
1146     /// Returns the domain size of the bitset.
1147     pub fn within_domain(&self, index: u32) -> bool {
1148         index < T::DOMAIN_SIZE
1149     }
1150
1151     /// Returns if the `index`th bit is set.
1152     pub fn contains(&self, index: u32) -> Option<bool> {
1153         self.within_domain(index)
1154             .then(|| ((self.0.checked_shr(index).unwrap_or(T::ONE)) & T::ONE) == T::ONE)
1155     }
1156 }
1157
1158 impl<T: FiniteBitSetTy> Default for FiniteBitSet<T> {
1159     fn default() -> Self {
1160         Self::new_empty()
1161     }
1162 }