]> git.lizzy.rs Git - rust.git/blob - src/librustc_data_structures/bitvec.rs
Auto merge of #52712 - oli-obk:const_eval_cleanups, r=RalfJung
[rust.git] / src / librustc_data_structures / bitvec.rs
1 // Copyright 2015 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 use indexed_vec::{Idx, IndexVec};
12 use std::marker::PhantomData;
13
14 type Word = u128;
15 const WORD_BITS: usize = 128;
16
17 /// A very simple BitArray type.
18 ///
19 /// It does not support resizing after creation; use `BitVector` for that.
20 #[derive(Clone, Debug, PartialEq)]
21 pub struct BitArray<C: Idx> {
22     data: Vec<Word>,
23     marker: PhantomData<C>,
24 }
25
26 #[derive(Clone, Debug, PartialEq)]
27 pub struct BitVector<C: Idx> {
28     data: BitArray<C>,
29 }
30
31 impl<C: Idx> BitVector<C> {
32     pub fn grow(&mut self, num_bits: C) {
33         self.data.grow(num_bits)
34     }
35
36     pub fn new() -> BitVector<C> {
37         BitVector {
38             data: BitArray::new(0),
39         }
40     }
41
42     pub fn with_capacity(bits: usize) -> BitVector<C> {
43         BitVector {
44             data: BitArray::new(bits),
45         }
46     }
47
48     /// Returns true if the bit has changed.
49     #[inline]
50     pub fn insert(&mut self, bit: C) -> bool {
51         self.grow(bit);
52         self.data.insert(bit)
53     }
54
55     #[inline]
56     pub fn contains(&self, bit: C) -> bool {
57         let (word, mask) = word_mask(bit);
58         if let Some(word) = self.data.data.get(word) {
59             (word & mask) != 0
60         } else {
61             false
62         }
63     }
64 }
65
66 impl<C: Idx> BitArray<C> {
67     // Do not make this method public, instead switch your use case to BitVector.
68     #[inline]
69     fn grow(&mut self, num_bits: C) {
70         let num_words = words(num_bits);
71         if self.data.len() <= num_words {
72             self.data.resize(num_words + 1, 0)
73         }
74     }
75
76     #[inline]
77     pub fn new(num_bits: usize) -> BitArray<C> {
78         let num_words = words(num_bits);
79         BitArray {
80             data: vec![0; num_words],
81             marker: PhantomData,
82         }
83     }
84
85     #[inline]
86     pub fn clear(&mut self) {
87         for p in &mut self.data {
88             *p = 0;
89         }
90     }
91
92     pub fn count(&self) -> usize {
93         self.data.iter().map(|e| e.count_ones() as usize).sum()
94     }
95
96     /// True if `self` contains the bit `bit`.
97     #[inline]
98     pub fn contains(&self, bit: C) -> bool {
99         let (word, mask) = word_mask(bit);
100         (self.data[word] & mask) != 0
101     }
102
103     /// True if `self` contains all the bits in `other`.
104     ///
105     /// The two vectors must have the same length.
106     #[inline]
107     pub fn contains_all(&self, other: &BitArray<C>) -> bool {
108         assert_eq!(self.data.len(), other.data.len());
109         self.data.iter().zip(&other.data).all(|(a, b)| (a & b) == *b)
110     }
111
112     #[inline]
113     pub fn is_empty(&self) -> bool {
114         self.data.iter().all(|a| *a == 0)
115     }
116
117     /// Returns true if the bit has changed.
118     #[inline]
119     pub fn insert(&mut self, bit: C) -> bool {
120         let (word, mask) = word_mask(bit);
121         let data = &mut self.data[word];
122         let value = *data;
123         let new_value = value | mask;
124         *data = new_value;
125         new_value != value
126     }
127
128     /// Sets all bits to true.
129     pub fn insert_all(&mut self) {
130         for data in &mut self.data {
131             *data = u128::max_value();
132         }
133     }
134
135     /// Returns true if the bit has changed.
136     #[inline]
137     pub fn remove(&mut self, bit: C) -> bool {
138         let (word, mask) = word_mask(bit);
139         let data = &mut self.data[word];
140         let value = *data;
141         let new_value = value & !mask;
142         *data = new_value;
143         new_value != value
144     }
145
146     #[inline]
147     pub fn merge(&mut self, all: &BitArray<C>) -> bool {
148         assert!(self.data.len() == all.data.len());
149         let mut changed = false;
150         for (i, j) in self.data.iter_mut().zip(&all.data) {
151             let value = *i;
152             *i = value | *j;
153             if value != *i {
154                 changed = true;
155             }
156         }
157         changed
158     }
159
160     /// Iterates over indexes of set bits in a sorted order
161     #[inline]
162     pub fn iter<'a>(&'a self) -> BitIter<'a, C> {
163         BitIter {
164             iter: self.data.iter(),
165             current: 0,
166             idx: 0,
167             marker: PhantomData,
168         }
169     }
170 }
171
172 pub struct BitIter<'a, C: Idx> {
173     iter: ::std::slice::Iter<'a, Word>,
174     current: Word,
175     idx: usize,
176     marker: PhantomData<C>
177 }
178
179 impl<'a, C: Idx> Iterator for BitIter<'a, C> {
180     type Item = C;
181     fn next(&mut self) -> Option<C> {
182         while self.current == 0 {
183             self.current = if let Some(&i) = self.iter.next() {
184                 if i == 0 {
185                     self.idx += WORD_BITS;
186                     continue;
187                 } else {
188                     self.idx = words(self.idx) * WORD_BITS;
189                     i
190                 }
191             } else {
192                 return None;
193             }
194         }
195         let offset = self.current.trailing_zeros() as usize;
196         self.current >>= offset;
197         self.current >>= 1; // shift otherwise overflows for 0b1000_0000_…_0000
198         self.idx += offset + 1;
199         return Some(C::new(self.idx - 1));
200     }
201
202     fn size_hint(&self) -> (usize, Option<usize>) {
203         let (_, upper) = self.iter.size_hint();
204         (0, upper)
205     }
206 }
207
208 /// A "bit matrix" is basically a matrix of booleans represented as
209 /// one gigantic bitvector. In other words, it is as if you have
210 /// `rows` bitvectors, each of length `columns`.
211 #[derive(Clone, Debug)]
212 pub struct BitMatrix<R: Idx, C: Idx> {
213     columns: usize,
214     vector: Vec<Word>,
215     phantom: PhantomData<(R, C)>,
216 }
217
218 impl<R: Idx, C: Idx> BitMatrix<R, C> {
219     /// Create a new `rows x columns` matrix, initially empty.
220     pub fn new(rows: usize, columns: usize) -> BitMatrix<R, C> {
221         // For every element, we need one bit for every other
222         // element. Round up to an even number of words.
223         let words_per_row = words(columns);
224         BitMatrix {
225             columns,
226             vector: vec![0; rows * words_per_row],
227             phantom: PhantomData,
228         }
229     }
230
231     /// The range of bits for a given row.
232     fn range(&self, row: R) -> (usize, usize) {
233         let row = row.index();
234         let words_per_row = words(self.columns);
235         let start = row * words_per_row;
236         (start, start + words_per_row)
237     }
238
239     /// Sets the cell at `(row, column)` to true. Put another way, add
240     /// `column` to the bitset for `row`.
241     ///
242     /// Returns true if this changed the matrix, and false otherwise.
243     pub fn add(&mut self, row: R, column: R) -> bool {
244         let (start, _) = self.range(row);
245         let (word, mask) = word_mask(column);
246         let vector = &mut self.vector[..];
247         let v1 = vector[start + word];
248         let v2 = v1 | mask;
249         vector[start + word] = v2;
250         v1 != v2
251     }
252
253     /// Do the bits from `row` contain `column`? Put another way, is
254     /// the matrix cell at `(row, column)` true?  Put yet another way,
255     /// if the matrix represents (transitive) reachability, can
256     /// `row` reach `column`?
257     pub fn contains(&self, row: R, column: R) -> bool {
258         let (start, _) = self.range(row);
259         let (word, mask) = word_mask(column);
260         (self.vector[start + word] & mask) != 0
261     }
262
263     /// Returns those indices that are true in rows `a` and `b`.  This
264     /// is an O(n) operation where `n` is the number of elements
265     /// (somewhat independent from the actual size of the
266     /// intersection, in particular).
267     pub fn intersection(&self, a: R, b: R) -> Vec<C> {
268         let (a_start, a_end) = self.range(a);
269         let (b_start, b_end) = self.range(b);
270         let mut result = Vec::with_capacity(self.columns);
271         for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() {
272             let mut v = self.vector[i] & self.vector[j];
273             for bit in 0..WORD_BITS {
274                 if v == 0 {
275                     break;
276                 }
277                 if v & 0x1 != 0 {
278                     result.push(C::new(base * WORD_BITS + bit));
279                 }
280                 v >>= 1;
281             }
282         }
283         result
284     }
285
286     /// Add the bits from row `read` to the bits from row `write`,
287     /// return true if anything changed.
288     ///
289     /// This is used when computing transitive reachability because if
290     /// you have an edge `write -> read`, because in that case
291     /// `write` can reach everything that `read` can (and
292     /// potentially more).
293     pub fn merge(&mut self, read: R, write: R) -> bool {
294         let (read_start, read_end) = self.range(read);
295         let (write_start, write_end) = self.range(write);
296         let vector = &mut self.vector[..];
297         let mut changed = false;
298         for (read_index, write_index) in (read_start..read_end).zip(write_start..write_end) {
299             let v1 = vector[write_index];
300             let v2 = v1 | vector[read_index];
301             vector[write_index] = v2;
302             changed = changed | (v1 != v2);
303         }
304         changed
305     }
306
307     /// Iterates through all the columns set to true in a given row of
308     /// the matrix.
309     pub fn iter<'a>(&'a self, row: R) -> BitIter<'a, C> {
310         let (start, end) = self.range(row);
311         BitIter {
312             iter: self.vector[start..end].iter(),
313             current: 0,
314             idx: 0,
315             marker: PhantomData,
316         }
317     }
318 }
319
320 /// A moderately sparse bit matrix: rows are appended lazily, but columns
321 /// within appended rows are instantiated fully upon creation.
322 #[derive(Clone, Debug)]
323 pub struct SparseBitMatrix<R, C>
324 where
325     R: Idx,
326     C: Idx,
327 {
328     columns: usize,
329     vector: IndexVec<R, BitArray<C>>,
330 }
331
332 impl<R: Idx, C: Idx> SparseBitMatrix<R, C> {
333     /// Create a new empty sparse bit matrix with no rows or columns.
334     pub fn new(columns: usize) -> Self {
335         Self {
336             columns,
337             vector: IndexVec::new(),
338         }
339     }
340
341     fn ensure_row(&mut self, row: R) {
342         let columns = self.columns;
343         self.vector
344             .ensure_contains_elem(row, || BitArray::new(columns));
345     }
346
347     /// Sets the cell at `(row, column)` to true. Put another way, insert
348     /// `column` to the bitset for `row`.
349     ///
350     /// Returns true if this changed the matrix, and false otherwise.
351     pub fn add(&mut self, row: R, column: C) -> bool {
352         self.ensure_row(row);
353         self.vector[row].insert(column)
354     }
355
356     /// Do the bits from `row` contain `column`? Put another way, is
357     /// the matrix cell at `(row, column)` true?  Put yet another way,
358     /// if the matrix represents (transitive) reachability, can
359     /// `row` reach `column`?
360     pub fn contains(&self, row: R, column: C) -> bool {
361         self.vector.get(row).map_or(false, |r| r.contains(column))
362     }
363
364     /// Add the bits from row `read` to the bits from row `write`,
365     /// return true if anything changed.
366     ///
367     /// This is used when computing transitive reachability because if
368     /// you have an edge `write -> read`, because in that case
369     /// `write` can reach everything that `read` can (and
370     /// potentially more).
371     pub fn merge(&mut self, read: R, write: R) -> bool {
372         if read == write || self.vector.get(read).is_none() {
373             return false;
374         }
375
376         self.ensure_row(write);
377         let (bitvec_read, bitvec_write) = self.vector.pick2_mut(read, write);
378         bitvec_write.merge(bitvec_read)
379     }
380
381     /// Merge a row, `from`, into the `into` row.
382     pub fn merge_into(&mut self, into: R, from: &BitArray<C>) -> bool {
383         self.ensure_row(into);
384         self.vector[into].merge(from)
385     }
386
387     /// Add all bits to the given row.
388     pub fn add_all(&mut self, row: R) {
389         self.ensure_row(row);
390         self.vector[row].insert_all();
391     }
392
393     /// Number of elements in the matrix.
394     pub fn len(&self) -> usize {
395         self.vector.len()
396     }
397
398     pub fn rows(&self) -> impl Iterator<Item = R> {
399         self.vector.indices()
400     }
401
402     /// Iterates through all the columns set to true in a given row of
403     /// the matrix.
404     pub fn iter<'a>(&'a self, row: R) -> impl Iterator<Item = C> + 'a {
405         self.vector.get(row).into_iter().flat_map(|r| r.iter())
406     }
407
408     /// Iterates through each row and the accompanying bit set.
409     pub fn iter_enumerated<'a>(&'a self) -> impl Iterator<Item = (R, &'a BitArray<C>)> + 'a {
410         self.vector.iter_enumerated()
411     }
412
413     pub fn row(&self, row: R) -> Option<&BitArray<C>> {
414         self.vector.get(row)
415     }
416 }
417
418 #[inline]
419 fn words<C: Idx>(elements: C) -> usize {
420     (elements.index() + WORD_BITS - 1) / WORD_BITS
421 }
422
423 #[inline]
424 fn word_mask<C: Idx>(index: C) -> (usize, Word) {
425     let index = index.index();
426     let word = index / WORD_BITS;
427     let mask = 1 << (index % WORD_BITS);
428     (word, mask)
429 }
430
431 #[test]
432 fn bitvec_iter_works() {
433     let mut bitvec: BitArray<usize> = BitArray::new(100);
434     bitvec.insert(1);
435     bitvec.insert(10);
436     bitvec.insert(19);
437     bitvec.insert(62);
438     bitvec.insert(63);
439     bitvec.insert(64);
440     bitvec.insert(65);
441     bitvec.insert(66);
442     bitvec.insert(99);
443     assert_eq!(
444         bitvec.iter().collect::<Vec<_>>(),
445         [1, 10, 19, 62, 63, 64, 65, 66, 99]
446     );
447 }
448
449 #[test]
450 fn bitvec_iter_works_2() {
451     let mut bitvec: BitArray<usize> = BitArray::new(319);
452     bitvec.insert(0);
453     bitvec.insert(127);
454     bitvec.insert(191);
455     bitvec.insert(255);
456     bitvec.insert(319);
457     assert_eq!(bitvec.iter().collect::<Vec<_>>(), [0, 127, 191, 255, 319]);
458 }
459
460 #[test]
461 fn union_two_vecs() {
462     let mut vec1: BitArray<usize> = BitArray::new(65);
463     let mut vec2: BitArray<usize> = BitArray::new(65);
464     assert!(vec1.insert(3));
465     assert!(!vec1.insert(3));
466     assert!(vec2.insert(5));
467     assert!(vec2.insert(64));
468     assert!(vec1.merge(&vec2));
469     assert!(!vec1.merge(&vec2));
470     assert!(vec1.contains(3));
471     assert!(!vec1.contains(4));
472     assert!(vec1.contains(5));
473     assert!(!vec1.contains(63));
474     assert!(vec1.contains(64));
475 }
476
477 #[test]
478 fn grow() {
479     let mut vec1: BitVector<usize> = BitVector::with_capacity(65);
480     for index in 0..65 {
481         assert!(vec1.insert(index));
482         assert!(!vec1.insert(index));
483     }
484     vec1.grow(128);
485
486     // Check if the bits set before growing are still set
487     for index in 0..65 {
488         assert!(vec1.contains(index));
489     }
490
491     // Check if the new bits are all un-set
492     for index in 65..128 {
493         assert!(!vec1.contains(index));
494     }
495
496     // Check that we can set all new bits without running out of bounds
497     for index in 65..128 {
498         assert!(vec1.insert(index));
499         assert!(!vec1.insert(index));
500     }
501 }
502
503 #[test]
504 fn matrix_intersection() {
505     let mut vec1: BitMatrix<usize, usize> = BitMatrix::new(200, 200);
506
507     // (*) Elements reachable from both 2 and 65.
508
509     vec1.add(2, 3);
510     vec1.add(2, 6);
511     vec1.add(2, 10); // (*)
512     vec1.add(2, 64); // (*)
513     vec1.add(2, 65);
514     vec1.add(2, 130);
515     vec1.add(2, 160); // (*)
516
517     vec1.add(64, 133);
518
519     vec1.add(65, 2);
520     vec1.add(65, 8);
521     vec1.add(65, 10); // (*)
522     vec1.add(65, 64); // (*)
523     vec1.add(65, 68);
524     vec1.add(65, 133);
525     vec1.add(65, 160); // (*)
526
527     let intersection = vec1.intersection(2, 64);
528     assert!(intersection.is_empty());
529
530     let intersection = vec1.intersection(2, 65);
531     assert_eq!(intersection, &[10, 64, 160]);
532 }
533
534 #[test]
535 fn matrix_iter() {
536     let mut matrix: BitMatrix<usize, usize> = BitMatrix::new(64, 100);
537     matrix.add(3, 22);
538     matrix.add(3, 75);
539     matrix.add(2, 99);
540     matrix.add(4, 0);
541     matrix.merge(3, 5);
542
543     let expected = [99];
544     let mut iter = expected.iter();
545     for i in matrix.iter(2) {
546         let j = *iter.next().unwrap();
547         assert_eq!(i, j);
548     }
549     assert!(iter.next().is_none());
550
551     let expected = [22, 75];
552     let mut iter = expected.iter();
553     for i in matrix.iter(3) {
554         let j = *iter.next().unwrap();
555         assert_eq!(i, j);
556     }
557     assert!(iter.next().is_none());
558
559     let expected = [0];
560     let mut iter = expected.iter();
561     for i in matrix.iter(4) {
562         let j = *iter.next().unwrap();
563         assert_eq!(i, j);
564     }
565     assert!(iter.next().is_none());
566
567     let expected = [22, 75];
568     let mut iter = expected.iter();
569     for i in matrix.iter(5) {
570         let j = *iter.next().unwrap();
571         assert_eq!(i, j);
572     }
573     assert!(iter.next().is_none());
574 }
575
576 #[test]
577 fn sparse_matrix_iter() {
578     let mut matrix: SparseBitMatrix<usize, usize> = SparseBitMatrix::new(100);
579     matrix.add(3, 22);
580     matrix.add(3, 75);
581     matrix.add(2, 99);
582     matrix.add(4, 0);
583     matrix.merge(3, 5);
584
585     let expected = [99];
586     let mut iter = expected.iter();
587     for i in matrix.iter(2) {
588         let j = *iter.next().unwrap();
589         assert_eq!(i, j);
590     }
591     assert!(iter.next().is_none());
592
593     let expected = [22, 75];
594     let mut iter = expected.iter();
595     for i in matrix.iter(3) {
596         let j = *iter.next().unwrap();
597         assert_eq!(i, j);
598     }
599     assert!(iter.next().is_none());
600
601     let expected = [0];
602     let mut iter = expected.iter();
603     for i in matrix.iter(4) {
604         let j = *iter.next().unwrap();
605         assert_eq!(i, j);
606     }
607     assert!(iter.next().is_none());
608
609     let expected = [22, 75];
610     let mut iter = expected.iter();
611     for i in matrix.iter(5) {
612         let j = *iter.next().unwrap();
613         assert_eq!(i, j);
614     }
615     assert!(iter.next().is_none());
616 }