]>
Commit | Line | Data |
---|---|---|
a2a8927a XL |
1 | use std::iter::Step; |
2 | use std::marker::PhantomData; | |
a2a8927a | 3 | use std::ops::RangeBounds; |
923072b8 | 4 | use std::ops::{Bound, Range}; |
a2a8927a | 5 | |
a2a8927a XL |
6 | use smallvec::SmallVec; |
7 | ||
49aad941 FG |
8 | use crate::idx::Idx; |
9 | use crate::vec::IndexVec; | |
10 | ||
a2a8927a XL |
11 | #[cfg(test)] |
12 | mod tests; | |
13 | ||
14 | /// Stores a set of intervals on the indices. | |
923072b8 FG |
15 | /// |
16 | /// The elements in `map` are sorted and non-adjacent, which means | |
17 | /// the second value of the previous element is *greater* than the | |
18 | /// first value of the following element. | |
a2a8927a XL |
19 | #[derive(Debug, Clone)] |
20 | pub struct IntervalSet<I> { | |
21 | // Start, end | |
22 | map: SmallVec<[(u32, u32); 4]>, | |
23 | domain: usize, | |
24 | _data: PhantomData<I>, | |
25 | } | |
26 | ||
27 | #[inline] | |
28 | fn inclusive_start<T: Idx>(range: impl RangeBounds<T>) -> u32 { | |
29 | match range.start_bound() { | |
30 | Bound::Included(start) => start.index() as u32, | |
31 | Bound::Excluded(start) => start.index() as u32 + 1, | |
32 | Bound::Unbounded => 0, | |
33 | } | |
34 | } | |
35 | ||
36 | #[inline] | |
37 | fn inclusive_end<T: Idx>(domain: usize, range: impl RangeBounds<T>) -> Option<u32> { | |
38 | let end = match range.end_bound() { | |
39 | Bound::Included(end) => end.index() as u32, | |
40 | Bound::Excluded(end) => end.index().checked_sub(1)? as u32, | |
41 | Bound::Unbounded => domain.checked_sub(1)? as u32, | |
42 | }; | |
43 | Some(end) | |
44 | } | |
45 | ||
46 | impl<I: Idx> IntervalSet<I> { | |
47 | pub fn new(domain: usize) -> IntervalSet<I> { | |
48 | IntervalSet { map: SmallVec::new(), domain, _data: PhantomData } | |
49 | } | |
50 | ||
51 | pub fn clear(&mut self) { | |
52 | self.map.clear(); | |
53 | } | |
54 | ||
55 | pub fn iter(&self) -> impl Iterator<Item = I> + '_ | |
56 | where | |
57 | I: Step, | |
58 | { | |
59 | self.iter_intervals().flatten() | |
60 | } | |
61 | ||
62 | /// Iterates through intervals stored in the set, in order. | |
63 | pub fn iter_intervals(&self) -> impl Iterator<Item = std::ops::Range<I>> + '_ | |
64 | where | |
65 | I: Step, | |
66 | { | |
67 | self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1)) | |
68 | } | |
69 | ||
70 | /// Returns true if we increased the number of elements present. | |
71 | pub fn insert(&mut self, point: I) -> bool { | |
72 | self.insert_range(point..=point) | |
73 | } | |
74 | ||
75 | /// Returns true if we increased the number of elements present. | |
76 | pub fn insert_range(&mut self, range: impl RangeBounds<I> + Clone) -> bool { | |
77 | let start = inclusive_start(range.clone()); | |
923072b8 | 78 | let Some(end) = inclusive_end(self.domain, range) else { |
a2a8927a XL |
79 | // empty range |
80 | return false; | |
81 | }; | |
82 | if start > end { | |
83 | return false; | |
84 | } | |
85 | ||
923072b8 FG |
86 | // This condition looks a bit weird, but actually makes sense. |
87 | // | |
88 | // if r.0 == end + 1, then we're actually adjacent, so we want to | |
89 | // continue to the next range. We're looking here for the first | |
90 | // range which starts *non-adjacently* to our end. | |
91 | let next = self.map.partition_point(|r| r.0 <= end + 1); | |
92 | let result = if let Some(right) = next.checked_sub(1) { | |
93 | let (prev_start, prev_end) = self.map[right]; | |
94 | if prev_end + 1 >= start { | |
95 | // If the start for the inserted range is adjacent to the | |
96 | // end of the previous, we can extend the previous range. | |
97 | if start < prev_start { | |
98 | // The first range which ends *non-adjacently* to our start. | |
99 | // And we can ensure that left <= right. | |
100 | let left = self.map.partition_point(|l| l.1 + 1 < start); | |
101 | let min = std::cmp::min(self.map[left].0, start); | |
102 | let max = std::cmp::max(prev_end, end); | |
103 | self.map[right] = (min, max); | |
104 | if left != right { | |
105 | self.map.drain(left..right); | |
a2a8927a | 106 | } |
923072b8 | 107 | true |
a2a8927a | 108 | } else { |
923072b8 FG |
109 | // We overlap with the previous range, increase it to |
110 | // include us. | |
111 | // | |
112 | // Make sure we're actually going to *increase* it though -- | |
113 | // it may be that end is just inside the previously existing | |
114 | // set. | |
115 | if end > prev_end { | |
116 | self.map[right].1 = end; | |
117 | true | |
118 | } else { | |
119 | false | |
120 | } | |
a2a8927a XL |
121 | } |
122 | } else { | |
923072b8 FG |
123 | // Otherwise, we don't overlap, so just insert |
124 | self.map.insert(right + 1, (start, end)); | |
125 | true | |
a2a8927a | 126 | } |
923072b8 FG |
127 | } else { |
128 | if self.map.is_empty() { | |
129 | // Quite common in practice, and expensive to call memcpy | |
130 | // with length zero. | |
131 | self.map.push((start, end)); | |
132 | } else { | |
133 | self.map.insert(next, (start, end)); | |
134 | } | |
135 | true | |
136 | }; | |
137 | debug_assert!( | |
138 | self.check_invariants(), | |
9c376795 | 139 | "wrong intervals after insert {start:?}..={end:?} to {self:?}" |
923072b8 FG |
140 | ); |
141 | result | |
a2a8927a XL |
142 | } |
143 | ||
144 | pub fn contains(&self, needle: I) -> bool { | |
145 | let needle = needle.index() as u32; | |
5e7ed085 FG |
146 | let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else { |
147 | // All ranges in the map start after the new range's end | |
148 | return false; | |
a2a8927a XL |
149 | }; |
150 | let (_, prev_end) = &self.map[last]; | |
151 | needle <= *prev_end | |
152 | } | |
153 | ||
154 | pub fn superset(&self, other: &IntervalSet<I>) -> bool | |
155 | where | |
156 | I: Step, | |
157 | { | |
923072b8 FG |
158 | let mut sup_iter = self.iter_intervals(); |
159 | let mut current = None; | |
160 | let contains = |sup: Range<I>, sub: Range<I>, current: &mut Option<Range<I>>| { | |
161 | if sup.end < sub.start { | |
162 | // if `sup.end == sub.start`, the next sup doesn't contain `sub.start` | |
163 | None // continue to the next sup | |
164 | } else if sup.end >= sub.end && sup.start <= sub.start { | |
165 | *current = Some(sup); // save the current sup | |
166 | Some(true) | |
167 | } else { | |
168 | Some(false) | |
169 | } | |
170 | }; | |
171 | other.iter_intervals().all(|sub| { | |
172 | current | |
173 | .take() | |
174 | .and_then(|sup| contains(sup, sub.clone(), &mut current)) | |
175 | .or_else(|| sup_iter.find_map(|sup| contains(sup, sub.clone(), &mut current))) | |
176 | .unwrap_or(false) | |
177 | }) | |
a2a8927a XL |
178 | } |
179 | ||
180 | pub fn is_empty(&self) -> bool { | |
181 | self.map.is_empty() | |
182 | } | |
183 | ||
49aad941 FG |
184 | /// Equivalent to `range.iter().find(|i| !self.contains(i))`. |
185 | pub fn first_unset_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> { | |
186 | let start = inclusive_start(range.clone()); | |
187 | let Some(end) = inclusive_end(self.domain, range) else { | |
188 | // empty range | |
189 | return None; | |
190 | }; | |
191 | if start > end { | |
192 | return None; | |
193 | } | |
194 | let Some(last) = self.map.partition_point(|r| r.0 <= start).checked_sub(1) else { | |
195 | // All ranges in the map start after the new range's end | |
196 | return Some(I::new(start as usize)); | |
197 | }; | |
198 | let (_, prev_end) = self.map[last]; | |
199 | if start > prev_end { | |
200 | Some(I::new(start as usize)) | |
201 | } else if prev_end < end { | |
202 | Some(I::new(prev_end as usize + 1)) | |
203 | } else { | |
204 | None | |
205 | } | |
206 | } | |
207 | ||
a2a8927a XL |
208 | /// Returns the maximum (last) element present in the set from `range`. |
209 | pub fn last_set_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> { | |
210 | let start = inclusive_start(range.clone()); | |
211 | let Some(end) = inclusive_end(self.domain, range) else { | |
212 | // empty range | |
213 | return None; | |
214 | }; | |
215 | if start > end { | |
216 | return None; | |
217 | } | |
5e7ed085 FG |
218 | let Some(last) = self.map.partition_point(|r| r.0 <= end).checked_sub(1) else { |
219 | // All ranges in the map start after the new range's end | |
220 | return None; | |
a2a8927a XL |
221 | }; |
222 | let (_, prev_end) = &self.map[last]; | |
223 | if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None } | |
224 | } | |
225 | ||
226 | pub fn insert_all(&mut self) { | |
227 | self.clear(); | |
923072b8 FG |
228 | if let Some(end) = self.domain.checked_sub(1) { |
229 | self.map.push((0, end.try_into().unwrap())); | |
230 | } | |
231 | debug_assert!(self.check_invariants()); | |
a2a8927a XL |
232 | } |
233 | ||
234 | pub fn union(&mut self, other: &IntervalSet<I>) -> bool | |
235 | where | |
236 | I: Step, | |
237 | { | |
238 | assert_eq!(self.domain, other.domain); | |
239 | let mut did_insert = false; | |
240 | for range in other.iter_intervals() { | |
241 | did_insert |= self.insert_range(range); | |
242 | } | |
923072b8 | 243 | debug_assert!(self.check_invariants()); |
a2a8927a XL |
244 | did_insert |
245 | } | |
923072b8 FG |
246 | |
247 | // Check the intervals are valid, sorted and non-adjacent | |
248 | fn check_invariants(&self) -> bool { | |
249 | let mut current: Option<u32> = None; | |
250 | for (start, end) in &self.map { | |
49aad941 | 251 | if start > end || current.is_some_and(|x| x + 1 >= *start) { |
923072b8 FG |
252 | return false; |
253 | } | |
254 | current = Some(*end); | |
255 | } | |
256 | current.map_or(true, |x| x < self.domain as u32) | |
257 | } | |
a2a8927a XL |
258 | } |
259 | ||
260 | /// This data structure optimizes for cases where the stored bits in each row | |
261 | /// are expected to be highly contiguous (long ranges of 1s or 0s), in contrast | |
262 | /// to BitMatrix and SparseBitMatrix which are optimized for | |
263 | /// "random"/non-contiguous bits and cheap(er) point queries at the expense of | |
264 | /// memory usage. | |
265 | #[derive(Clone)] | |
266 | pub struct SparseIntervalMatrix<R, C> | |
267 | where | |
268 | R: Idx, | |
269 | C: Idx, | |
270 | { | |
271 | rows: IndexVec<R, IntervalSet<C>>, | |
272 | column_size: usize, | |
273 | } | |
274 | ||
275 | impl<R: Idx, C: Step + Idx> SparseIntervalMatrix<R, C> { | |
276 | pub fn new(column_size: usize) -> SparseIntervalMatrix<R, C> { | |
277 | SparseIntervalMatrix { rows: IndexVec::new(), column_size } | |
278 | } | |
279 | ||
280 | pub fn rows(&self) -> impl Iterator<Item = R> { | |
281 | self.rows.indices() | |
282 | } | |
283 | ||
284 | pub fn row(&self, row: R) -> Option<&IntervalSet<C>> { | |
285 | self.rows.get(row) | |
286 | } | |
287 | ||
288 | fn ensure_row(&mut self, row: R) -> &mut IntervalSet<C> { | |
49aad941 | 289 | self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size)) |
a2a8927a XL |
290 | } |
291 | ||
292 | pub fn union_row(&mut self, row: R, from: &IntervalSet<C>) -> bool | |
293 | where | |
294 | C: Step, | |
295 | { | |
296 | self.ensure_row(row).union(from) | |
297 | } | |
298 | ||
299 | pub fn union_rows(&mut self, read: R, write: R) -> bool | |
300 | where | |
301 | C: Step, | |
302 | { | |
303 | if read == write || self.rows.get(read).is_none() { | |
304 | return false; | |
305 | } | |
306 | self.ensure_row(write); | |
307 | let (read_row, write_row) = self.rows.pick2_mut(read, write); | |
308 | write_row.union(read_row) | |
309 | } | |
310 | ||
311 | pub fn insert_all_into_row(&mut self, row: R) { | |
312 | self.ensure_row(row).insert_all(); | |
313 | } | |
314 | ||
315 | pub fn insert_range(&mut self, row: R, range: impl RangeBounds<C> + Clone) { | |
316 | self.ensure_row(row).insert_range(range); | |
317 | } | |
318 | ||
319 | pub fn insert(&mut self, row: R, point: C) -> bool { | |
320 | self.ensure_row(row).insert(point) | |
321 | } | |
322 | ||
323 | pub fn contains(&self, row: R, point: C) -> bool { | |
49aad941 | 324 | self.row(row).is_some_and(|r| r.contains(point)) |
a2a8927a XL |
325 | } |
326 | } |