]> git.lizzy.rs Git - rust.git/blob - src/librustc_data_structures/bitvec.rs
Rollup merge of #47846 - roblabla:bugfix-ocaml, r=kennytm
[rust.git] / src / librustc_data_structures / bitvec.rs
1 // Copyright 2015 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 use std::iter::FromIterator;
12
13 /// A very simple BitVector type.
14 #[derive(Clone, Debug, PartialEq)]
15 pub struct BitVector {
16     data: Vec<u64>,
17 }
18
19 impl BitVector {
20     #[inline]
21     pub fn new(num_bits: usize) -> BitVector {
22         let num_words = u64s(num_bits);
23         BitVector { data: vec![0; num_words] }
24     }
25
26     #[inline]
27     pub fn clear(&mut self) {
28         for p in &mut self.data {
29             *p = 0;
30         }
31     }
32
33     pub fn count(&self) -> usize {
34         self.data.iter().map(|e| e.count_ones() as usize).sum()
35     }
36
37     #[inline]
38     pub fn contains(&self, bit: usize) -> bool {
39         let (word, mask) = word_mask(bit);
40         (self.data[word] & mask) != 0
41     }
42
43     /// Returns true if the bit has changed.
44     #[inline]
45     pub fn insert(&mut self, bit: usize) -> bool {
46         let (word, mask) = word_mask(bit);
47         let data = &mut self.data[word];
48         let value = *data;
49         let new_value = value | mask;
50         *data = new_value;
51         new_value != value
52     }
53
54     /// Returns true if the bit has changed.
55     #[inline]
56     pub fn remove(&mut self, bit: usize) -> bool {
57         let (word, mask) = word_mask(bit);
58         let data = &mut self.data[word];
59         let value = *data;
60         let new_value = value & !mask;
61         *data = new_value;
62         new_value != value
63     }
64
65     #[inline]
66     pub fn insert_all(&mut self, all: &BitVector) -> bool {
67         assert!(self.data.len() == all.data.len());
68         let mut changed = false;
69         for (i, j) in self.data.iter_mut().zip(&all.data) {
70             let value = *i;
71             *i = value | *j;
72             if value != *i {
73                 changed = true;
74             }
75         }
76         changed
77     }
78
79     #[inline]
80     pub fn grow(&mut self, num_bits: usize) {
81         let num_words = u64s(num_bits);
82         if self.data.len() < num_words {
83             self.data.resize(num_words, 0)
84         }
85     }
86
87     /// Iterates over indexes of set bits in a sorted order
88     #[inline]
89     pub fn iter<'a>(&'a self) -> BitVectorIter<'a> {
90         BitVectorIter {
91             iter: self.data.iter(),
92             current: 0,
93             idx: 0,
94         }
95     }
96 }
97
98 pub struct BitVectorIter<'a> {
99     iter: ::std::slice::Iter<'a, u64>,
100     current: u64,
101     idx: usize,
102 }
103
104 impl<'a> Iterator for BitVectorIter<'a> {
105     type Item = usize;
106     fn next(&mut self) -> Option<usize> {
107         while self.current == 0 {
108             self.current = if let Some(&i) = self.iter.next() {
109                 if i == 0 {
110                     self.idx += 64;
111                     continue;
112                 } else {
113                     self.idx = u64s(self.idx) * 64;
114                     i
115                 }
116             } else {
117                 return None;
118             }
119         }
120         let offset = self.current.trailing_zeros() as usize;
121         self.current >>= offset;
122         self.current >>= 1; // shift otherwise overflows for 0b1000_0000_…_0000
123         self.idx += offset + 1;
124         return Some(self.idx - 1);
125     }
126 }
127
128 impl FromIterator<bool> for BitVector {
129     fn from_iter<I>(iter: I) -> BitVector where I: IntoIterator<Item=bool> {
130         let iter = iter.into_iter();
131         let (len, _) = iter.size_hint();
132         // Make the minimum length for the bitvector 64 bits since that's
133         // the smallest non-zero size anyway.
134         let len = if len < 64 { 64 } else { len };
135         let mut bv = BitVector::new(len);
136         for (idx, val) in iter.enumerate() {
137             if idx > len {
138                 bv.grow(idx);
139             }
140             if val {
141                 bv.insert(idx);
142             }
143         }
144
145         bv
146     }
147 }
148
149 /// A "bit matrix" is basically a matrix of booleans represented as
150 /// one gigantic bitvector. In other words, it is as if you have
151 /// `rows` bitvectors, each of length `columns`.
152 #[derive(Clone, Debug)]
153 pub struct BitMatrix {
154     columns: usize,
155     vector: Vec<u64>,
156 }
157
158 impl BitMatrix {
159     /// Create a new `rows x columns` matrix, initially empty.
160     pub fn new(rows: usize, columns: usize) -> BitMatrix {
161         // For every element, we need one bit for every other
162         // element. Round up to an even number of u64s.
163         let u64s_per_row = u64s(columns);
164         BitMatrix {
165             columns,
166             vector: vec![0; rows * u64s_per_row],
167         }
168     }
169
170     /// The range of bits for a given row.
171     fn range(&self, row: usize) -> (usize, usize) {
172         let u64s_per_row = u64s(self.columns);
173         let start = row * u64s_per_row;
174         (start, start + u64s_per_row)
175     }
176
177     /// Sets the cell at `(row, column)` to true. Put another way, add
178     /// `column` to the bitset for `row`.
179     ///
180     /// Returns true if this changed the matrix, and false otherwies.
181     pub fn add(&mut self, row: usize, column: usize) -> bool {
182         let (start, _) = self.range(row);
183         let (word, mask) = word_mask(column);
184         let vector = &mut self.vector[..];
185         let v1 = vector[start + word];
186         let v2 = v1 | mask;
187         vector[start + word] = v2;
188         v1 != v2
189     }
190
191     /// Do the bits from `row` contain `column`? Put another way, is
192     /// the matrix cell at `(row, column)` true?  Put yet another way,
193     /// if the matrix represents (transitive) reachability, can
194     /// `row` reach `column`?
195     pub fn contains(&self, row: usize, column: usize) -> bool {
196         let (start, _) = self.range(row);
197         let (word, mask) = word_mask(column);
198         (self.vector[start + word] & mask) != 0
199     }
200
201     /// Returns those indices that are true in rows `a` and `b`.  This
202     /// is an O(n) operation where `n` is the number of elements
203     /// (somewhat independent from the actual size of the
204     /// intersection, in particular).
205     pub fn intersection(&self, a: usize, b: usize) -> Vec<usize> {
206         let (a_start, a_end) = self.range(a);
207         let (b_start, b_end) = self.range(b);
208         let mut result = Vec::with_capacity(self.columns);
209         for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() {
210             let mut v = self.vector[i] & self.vector[j];
211             for bit in 0..64 {
212                 if v == 0 {
213                     break;
214                 }
215                 if v & 0x1 != 0 {
216                     result.push(base * 64 + bit);
217                 }
218                 v >>= 1;
219             }
220         }
221         result
222     }
223
224     /// Add the bits from row `read` to the bits from row `write`,
225     /// return true if anything changed.
226     ///
227     /// This is used when computing transitive reachability because if
228     /// you have an edge `write -> read`, because in that case
229     /// `write` can reach everything that `read` can (and
230     /// potentially more).
231     pub fn merge(&mut self, read: usize, write: usize) -> bool {
232         let (read_start, read_end) = self.range(read);
233         let (write_start, write_end) = self.range(write);
234         let vector = &mut self.vector[..];
235         let mut changed = false;
236         for (read_index, write_index) in (read_start..read_end).zip(write_start..write_end) {
237             let v1 = vector[write_index];
238             let v2 = v1 | vector[read_index];
239             vector[write_index] = v2;
240             changed = changed | (v1 != v2);
241         }
242         changed
243     }
244
245     /// Iterates through all the columns set to true in a given row of
246     /// the matrix.
247     pub fn iter<'a>(&'a self, row: usize) -> BitVectorIter<'a> {
248         let (start, end) = self.range(row);
249         BitVectorIter {
250             iter: self.vector[start..end].iter(),
251             current: 0,
252             idx: 0,
253         }
254     }
255 }
256
257 #[inline]
258 fn u64s(elements: usize) -> usize {
259     (elements + 63) / 64
260 }
261
262 #[inline]
263 fn word_mask(index: usize) -> (usize, u64) {
264     let word = index / 64;
265     let mask = 1 << (index % 64);
266     (word, mask)
267 }
268
269 #[test]
270 fn bitvec_iter_works() {
271     let mut bitvec = BitVector::new(100);
272     bitvec.insert(1);
273     bitvec.insert(10);
274     bitvec.insert(19);
275     bitvec.insert(62);
276     bitvec.insert(63);
277     bitvec.insert(64);
278     bitvec.insert(65);
279     bitvec.insert(66);
280     bitvec.insert(99);
281     assert_eq!(bitvec.iter().collect::<Vec<_>>(),
282                [1, 10, 19, 62, 63, 64, 65, 66, 99]);
283 }
284
285
286 #[test]
287 fn bitvec_iter_works_2() {
288     let mut bitvec = BitVector::new(319);
289     bitvec.insert(0);
290     bitvec.insert(127);
291     bitvec.insert(191);
292     bitvec.insert(255);
293     bitvec.insert(319);
294     assert_eq!(bitvec.iter().collect::<Vec<_>>(), [0, 127, 191, 255, 319]);
295 }
296
297 #[test]
298 fn union_two_vecs() {
299     let mut vec1 = BitVector::new(65);
300     let mut vec2 = BitVector::new(65);
301     assert!(vec1.insert(3));
302     assert!(!vec1.insert(3));
303     assert!(vec2.insert(5));
304     assert!(vec2.insert(64));
305     assert!(vec1.insert_all(&vec2));
306     assert!(!vec1.insert_all(&vec2));
307     assert!(vec1.contains(3));
308     assert!(!vec1.contains(4));
309     assert!(vec1.contains(5));
310     assert!(!vec1.contains(63));
311     assert!(vec1.contains(64));
312 }
313
314 #[test]
315 fn grow() {
316     let mut vec1 = BitVector::new(65);
317     for index in 0 .. 65 {
318         assert!(vec1.insert(index));
319         assert!(!vec1.insert(index));
320     }
321     vec1.grow(128);
322
323     // Check if the bits set before growing are still set
324     for index in 0 .. 65 {
325         assert!(vec1.contains(index));
326     }
327
328     // Check if the new bits are all un-set
329     for index in 65 .. 128 {
330         assert!(!vec1.contains(index));
331     }
332
333     // Check that we can set all new bits without running out of bounds
334     for index in 65 .. 128 {
335         assert!(vec1.insert(index));
336         assert!(!vec1.insert(index));
337     }
338 }
339
340 #[test]
341 fn matrix_intersection() {
342     let mut vec1 = BitMatrix::new(200, 200);
343
344     // (*) Elements reachable from both 2 and 65.
345
346     vec1.add(2, 3);
347     vec1.add(2, 6);
348     vec1.add(2, 10); // (*)
349     vec1.add(2, 64); // (*)
350     vec1.add(2, 65);
351     vec1.add(2, 130);
352     vec1.add(2, 160); // (*)
353
354     vec1.add(64, 133);
355
356     vec1.add(65, 2);
357     vec1.add(65, 8);
358     vec1.add(65, 10); // (*)
359     vec1.add(65, 64); // (*)
360     vec1.add(65, 68);
361     vec1.add(65, 133);
362     vec1.add(65, 160); // (*)
363
364     let intersection = vec1.intersection(2, 64);
365     assert!(intersection.is_empty());
366
367     let intersection = vec1.intersection(2, 65);
368     assert_eq!(intersection, &[10, 64, 160]);
369 }
370
371 #[test]
372 fn matrix_iter() {
373     let mut matrix = BitMatrix::new(64, 100);
374     matrix.add(3, 22);
375     matrix.add(3, 75);
376     matrix.add(2, 99);
377     matrix.add(4, 0);
378     matrix.merge(3, 5);
379
380     let expected = [99];
381     let mut iter = expected.iter();
382     for i in matrix.iter(2) {
383         let j = *iter.next().unwrap();
384         assert_eq!(i, j);
385     }
386     assert!(iter.next().is_none());
387
388     let expected = [22, 75];
389     let mut iter = expected.iter();
390     for i in matrix.iter(3) {
391         let j = *iter.next().unwrap();
392         assert_eq!(i, j);
393     }
394     assert!(iter.next().is_none());
395
396     let expected = [0];
397     let mut iter = expected.iter();
398     for i in matrix.iter(4) {
399         let j = *iter.next().unwrap();
400         assert_eq!(i, j);
401     }
402     assert!(iter.next().is_none());
403
404     let expected = [22, 75];
405     let mut iter = expected.iter();
406     for i in matrix.iter(5) {
407         let j = *iter.next().unwrap();
408         assert_eq!(i, j);
409     }
410     assert!(iter.next().is_none());
411 }