]>
Commit | Line | Data |
---|---|---|
d9579d0f AL |
1 | // Copyright 2015 The Rust Project Developers. See the COPYRIGHT |
2 | // file at the top-level directory of this distribution and at | |
3 | // http://rust-lang.org/COPYRIGHT. | |
4 | // | |
5 | // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | |
6 | // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license | |
7 | // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your | |
8 | // option. This file may not be copied, modified, or distributed | |
9 | // except according to those terms. | |
10 | ||
54a0048b SL |
11 | use std::iter::FromIterator; |
12 | ||
/// A simple, densely packed vector of bits backed by a `Vec<u64>`.
#[derive(Clone)]
pub struct BitVector {
    // Bit `i` is stored in word `i / 64`, at bit position `i % 64`.
    data: Vec<u64>,
}
18 | ||
19 | impl BitVector { | |
20 | pub fn new(num_bits: usize) -> BitVector { | |
e9174d1e | 21 | let num_words = u64s(num_bits); |
c1a9b12d | 22 | BitVector { data: vec![0; num_words] } |
d9579d0f AL |
23 | } |
24 | ||
d9579d0f | 25 | pub fn contains(&self, bit: usize) -> bool { |
e9174d1e | 26 | let (word, mask) = word_mask(bit); |
d9579d0f AL |
27 | (self.data[word] & mask) != 0 |
28 | } | |
29 | ||
7453a54e | 30 | /// Returns true if the bit has changed. |
d9579d0f | 31 | pub fn insert(&mut self, bit: usize) -> bool { |
e9174d1e | 32 | let (word, mask) = word_mask(bit); |
d9579d0f AL |
33 | let data = &mut self.data[word]; |
34 | let value = *data; | |
7453a54e SL |
35 | let new_value = value | mask; |
36 | *data = new_value; | |
37 | new_value != value | |
d9579d0f | 38 | } |
e9174d1e SL |
39 | |
40 | pub fn insert_all(&mut self, all: &BitVector) -> bool { | |
41 | assert!(self.data.len() == all.data.len()); | |
42 | let mut changed = false; | |
43 | for (i, j) in self.data.iter_mut().zip(&all.data) { | |
44 | let value = *i; | |
45 | *i = value | *j; | |
54a0048b SL |
46 | if value != *i { |
47 | changed = true; | |
48 | } | |
e9174d1e SL |
49 | } |
50 | changed | |
51 | } | |
52 | ||
53 | pub fn grow(&mut self, num_bits: usize) { | |
54 | let num_words = u64s(num_bits); | |
55 | let extra_words = self.data.len() - num_words; | |
54a0048b SL |
56 | if extra_words > 0 { |
57 | self.data.extend((0..extra_words).map(|_| 0)); | |
58 | } | |
e9174d1e | 59 | } |
7453a54e SL |
60 | |
61 | /// Iterates over indexes of set bits in a sorted order | |
62 | pub fn iter<'a>(&'a self) -> BitVectorIter<'a> { | |
63 | BitVectorIter { | |
64 | iter: self.data.iter(), | |
65 | current: 0, | |
54a0048b | 66 | idx: 0, |
7453a54e SL |
67 | } |
68 | } | |
69 | } | |
70 | ||
/// Iterator over the indexes of set bits in a `BitVector`, created by
/// `BitVector::iter`. Indexes are produced in ascending order.
pub struct BitVectorIter<'a> {
    /// Words of the bit vector that have not been loaded yet.
    iter: ::std::slice::Iter<'a, u64>,
    /// Bits of the currently loaded word that have not been yielded yet,
    /// shifted down so that (while non-zero) bit 0 corresponds to index `idx`.
    current: u64,
    /// Bit index tracking the scan position within the vector.
    idx: usize,
}
76 | ||
77 | impl<'a> Iterator for BitVectorIter<'a> { | |
78 | type Item = usize; | |
79 | fn next(&mut self) -> Option<usize> { | |
80 | while self.current == 0 { | |
81 | self.current = if let Some(&i) = self.iter.next() { | |
82 | if i == 0 { | |
83 | self.idx += 64; | |
84 | continue; | |
85 | } else { | |
86 | self.idx = u64s(self.idx) * 64; | |
87 | i | |
88 | } | |
89 | } else { | |
90 | return None; | |
91 | } | |
92 | } | |
93 | let offset = self.current.trailing_zeros() as usize; | |
94 | self.current >>= offset; | |
95 | self.current >>= 1; // shift otherwise overflows for 0b1000_0000_…_0000 | |
96 | self.idx += offset + 1; | |
97 | return Some(self.idx - 1); | |
98 | } | |
e9174d1e SL |
99 | } |
100 | ||
54a0048b SL |
101 | impl FromIterator<bool> for BitVector { |
102 | fn from_iter<I>(iter: I) -> BitVector where I: IntoIterator<Item=bool> { | |
103 | let iter = iter.into_iter(); | |
104 | let (len, _) = iter.size_hint(); | |
105 | // Make the minimum length for the bitvector 64 bits since that's | |
106 | // the smallest non-zero size anyway. | |
107 | let len = if len < 64 { 64 } else { len }; | |
108 | let mut bv = BitVector::new(len); | |
109 | for (idx, val) in iter.enumerate() { | |
110 | if idx > len { | |
111 | bv.grow(idx); | |
112 | } | |
113 | if val { | |
114 | bv.insert(idx); | |
115 | } | |
116 | } | |
117 | ||
118 | bv | |
119 | } | |
120 | } | |
121 | ||
e9174d1e SL |
/// A "bit matrix" is basically a square matrix of booleans
/// represented as one gigantic bitvector. In other words, it is as if
/// you have N bitvectors, each of length N. Note that `elements` here is `N`.
///
/// Each row occupies `u64s(N)` consecutive words of `vector`, so every row
/// starts on a word boundary; bits past column `N - 1` in a row's last word
/// are padding.
#[derive(Clone)]
pub struct BitMatrix {
    /// The matrix dimension `N` (number of rows and of columns).
    elements: usize,
    /// Row-major storage of `N * u64s(N)` words.
    vector: Vec<u64>,
}
130 | ||
131 | impl BitMatrix { | |
132 | // Create a new `elements x elements` matrix, initially empty. | |
133 | pub fn new(elements: usize) -> BitMatrix { | |
134 | // For every element, we need one bit for every other | |
135 | // element. Round up to an even number of u64s. | |
136 | let u64s_per_elem = u64s(elements); | |
137 | BitMatrix { | |
138 | elements: elements, | |
54a0048b | 139 | vector: vec![0; elements * u64s_per_elem], |
e9174d1e SL |
140 | } |
141 | } | |
142 | ||
143 | /// The range of bits for a given element. | |
144 | fn range(&self, element: usize) -> (usize, usize) { | |
145 | let u64s_per_elem = u64s(self.elements); | |
146 | let start = element * u64s_per_elem; | |
147 | (start, start + u64s_per_elem) | |
148 | } | |
149 | ||
150 | pub fn add(&mut self, source: usize, target: usize) -> bool { | |
151 | let (start, _) = self.range(source); | |
152 | let (word, mask) = word_mask(target); | |
153 | let mut vector = &mut self.vector[..]; | |
54a0048b | 154 | let v1 = vector[start + word]; |
e9174d1e | 155 | let v2 = v1 | mask; |
54a0048b | 156 | vector[start + word] = v2; |
e9174d1e SL |
157 | v1 != v2 |
158 | } | |
159 | ||
160 | /// Do the bits from `source` contain `target`? | |
161 | /// | |
162 | /// Put another way, if the matrix represents (transitive) | |
163 | /// reachability, can `source` reach `target`? | |
164 | pub fn contains(&self, source: usize, target: usize) -> bool { | |
165 | let (start, _) = self.range(source); | |
166 | let (word, mask) = word_mask(target); | |
54a0048b | 167 | (self.vector[start + word] & mask) != 0 |
e9174d1e SL |
168 | } |
169 | ||
170 | /// Returns those indices that are reachable from both `a` and | |
171 | /// `b`. This is an O(n) operation where `n` is the number of | |
172 | /// elements (somewhat independent from the actual size of the | |
173 | /// intersection, in particular). | |
174 | pub fn intersection(&self, a: usize, b: usize) -> Vec<usize> { | |
175 | let (a_start, a_end) = self.range(a); | |
176 | let (b_start, b_end) = self.range(b); | |
177 | let mut result = Vec::with_capacity(self.elements); | |
178 | for (base, (i, j)) in (a_start..a_end).zip(b_start..b_end).enumerate() { | |
179 | let mut v = self.vector[i] & self.vector[j]; | |
180 | for bit in 0..64 { | |
54a0048b SL |
181 | if v == 0 { |
182 | break; | |
183 | } | |
184 | if v & 0x1 != 0 { | |
185 | result.push(base * 64 + bit); | |
186 | } | |
e9174d1e SL |
187 | v >>= 1; |
188 | } | |
189 | } | |
190 | result | |
191 | } | |
192 | ||
193 | /// Add the bits from `read` to the bits from `write`, | |
194 | /// return true if anything changed. | |
195 | /// | |
196 | /// This is used when computing transitive reachability because if | |
197 | /// you have an edge `write -> read`, because in that case | |
198 | /// `write` can reach everything that `read` can (and | |
199 | /// potentially more). | |
200 | pub fn merge(&mut self, read: usize, write: usize) -> bool { | |
201 | let (read_start, read_end) = self.range(read); | |
202 | let (write_start, write_end) = self.range(write); | |
203 | let vector = &mut self.vector[..]; | |
204 | let mut changed = false; | |
54a0048b | 205 | for (read_index, write_index) in (read_start..read_end).zip(write_start..write_end) { |
e9174d1e SL |
206 | let v1 = vector[write_index]; |
207 | let v2 = v1 | vector[read_index]; | |
208 | vector[write_index] = v2; | |
209 | changed = changed | (v1 != v2); | |
210 | } | |
211 | changed | |
212 | } | |
213 | } | |
214 | ||
/// Number of `u64` words needed to hold `elements` bits (`elements / 64`,
/// rounded up).
fn u64s(elements: usize) -> usize {
    // Same result as `(elements + 63) / 64`, but the addition there overflows
    // for `elements` close to `usize::MAX`; this form cannot.
    elements / 64 + if elements % 64 == 0 { 0 } else { 1 }
}
218 | ||
/// Splits a bit index into the index of the `u64` word holding it and a
/// single-bit mask selecting that bit within the word.
fn word_mask(index: usize) -> (usize, u64) {
    let bit_in_word = index % 64;
    (index / 64, 1u64 << bit_in_word)
}
224 | ||
7453a54e SL |
#[test]
fn bitvec_iter_works() {
    let mut bitvec = BitVector::new(100);
    let expected = [1, 10, 19, 62, 63, 64, 65, 66, 99];
    for &bit in expected.iter() {
        bitvec.insert(bit);
    }
    let collected: Vec<usize> = bitvec.iter().collect();
    assert_eq!(collected, expected);
}
240 | ||
#[test]
fn bitvec_iter_works_2() {
    let mut bitvec = BitVector::new(300);
    let expected = [1, 10, 19, 62, 66, 99, 299];
    for &bit in expected.iter() {
        bitvec.insert(bit);
    }
    let collected: Vec<usize> = bitvec.iter().collect();
    assert_eq!(collected, expected);
}
255 | ||
#[test]
fn bitvec_iter_works_3() {
    // Every word boundary bit (and the very last addressable bit) is covered.
    let mut bitvec = BitVector::new(319);
    let expected = [0, 127, 191, 255, 319];
    for &bit in expected.iter() {
        bitvec.insert(bit);
    }
    let collected: Vec<usize> = bitvec.iter().collect();
    assert_eq!(collected, expected);
}
266 | ||
e9174d1e SL |
#[test]
fn union_two_vecs() {
    let mut left = BitVector::new(65);
    let mut right = BitVector::new(65);
    // The first insert reports a change; a repeated insert does not.
    assert!(left.insert(3));
    assert!(!left.insert(3));
    assert!(right.insert(5));
    assert!(right.insert(64));
    // The union changes `left` once and is idempotent afterwards.
    assert!(left.insert_all(&right));
    assert!(!left.insert_all(&right));
    for &bit in [3, 5, 64].iter() {
        assert!(left.contains(bit));
    }
    for &bit in [4, 63].iter() {
        assert!(!left.contains(bit));
    }
}
283 | ||
#[test]
fn grow() {
    let mut bits = BitVector::new(65);
    assert!(bits.insert(3));
    assert!(!bits.insert(3));
    assert!(bits.insert(5));
    assert!(bits.insert(64));
    bits.grow(128);
    // Bits set before the grow are still set; untouched ones stay clear.
    for &set in [3, 5, 64].iter() {
        assert!(bits.contains(set));
    }
    assert!(!bits.contains(126));
}
297 | ||
#[test]
fn matrix_intersection() {
    let mut matrix = BitMatrix::new(200);

    // Entries marked (*) are reachable from both 2 and 65.
    let edges = [
        (2, 3),
        (2, 6),
        (2, 10), // (*)
        (2, 64), // (*)
        (2, 65),
        (2, 130),
        (2, 160), // (*)
        (64, 133),
        (65, 2),
        (65, 8),
        (65, 10), // (*)
        (65, 64), // (*)
        (65, 68),
        (65, 133),
        (65, 160), // (*)
    ];
    for &(source, target) in edges.iter() {
        matrix.add(source, target);
    }

    assert!(matrix.intersection(2, 64).is_empty());
    assert_eq!(matrix.intersection(2, 65), &[10, 64, 160]);
}