]>
Commit | Line | Data |
---|---|---|
a2a8927a XL |
1 | use std::iter::Step; |
2 | use std::marker::PhantomData; | |
3 | use std::ops::Bound; | |
4 | use std::ops::RangeBounds; | |
5 | ||
6 | use crate::vec::Idx; | |
7 | use crate::vec::IndexVec; | |
8 | use smallvec::SmallVec; | |
9 | ||
10 | #[cfg(test)] | |
11 | mod tests; | |
12 | ||
13 | /// Stores a set of intervals on the indices. | |
14 | #[derive(Debug, Clone)] | |
15 | pub struct IntervalSet<I> { | |
16 | // Start, end | |
17 | map: SmallVec<[(u32, u32); 4]>, | |
18 | domain: usize, | |
19 | _data: PhantomData<I>, | |
20 | } | |
21 | ||
22 | #[inline] | |
23 | fn inclusive_start<T: Idx>(range: impl RangeBounds<T>) -> u32 { | |
24 | match range.start_bound() { | |
25 | Bound::Included(start) => start.index() as u32, | |
26 | Bound::Excluded(start) => start.index() as u32 + 1, | |
27 | Bound::Unbounded => 0, | |
28 | } | |
29 | } | |
30 | ||
31 | #[inline] | |
32 | fn inclusive_end<T: Idx>(domain: usize, range: impl RangeBounds<T>) -> Option<u32> { | |
33 | let end = match range.end_bound() { | |
34 | Bound::Included(end) => end.index() as u32, | |
35 | Bound::Excluded(end) => end.index().checked_sub(1)? as u32, | |
36 | Bound::Unbounded => domain.checked_sub(1)? as u32, | |
37 | }; | |
38 | Some(end) | |
39 | } | |
40 | ||
41 | impl<I: Idx> IntervalSet<I> { | |
42 | pub fn new(domain: usize) -> IntervalSet<I> { | |
43 | IntervalSet { map: SmallVec::new(), domain, _data: PhantomData } | |
44 | } | |
45 | ||
46 | pub fn clear(&mut self) { | |
47 | self.map.clear(); | |
48 | } | |
49 | ||
50 | pub fn iter(&self) -> impl Iterator<Item = I> + '_ | |
51 | where | |
52 | I: Step, | |
53 | { | |
54 | self.iter_intervals().flatten() | |
55 | } | |
56 | ||
57 | /// Iterates through intervals stored in the set, in order. | |
58 | pub fn iter_intervals(&self) -> impl Iterator<Item = std::ops::Range<I>> + '_ | |
59 | where | |
60 | I: Step, | |
61 | { | |
62 | self.map.iter().map(|&(start, end)| I::new(start as usize)..I::new(end as usize + 1)) | |
63 | } | |
64 | ||
65 | /// Returns true if we increased the number of elements present. | |
66 | pub fn insert(&mut self, point: I) -> bool { | |
67 | self.insert_range(point..=point) | |
68 | } | |
69 | ||
70 | /// Returns true if we increased the number of elements present. | |
71 | pub fn insert_range(&mut self, range: impl RangeBounds<I> + Clone) -> bool { | |
72 | let start = inclusive_start(range.clone()); | |
73 | let Some(mut end) = inclusive_end(self.domain, range) else { | |
74 | // empty range | |
75 | return false; | |
76 | }; | |
77 | if start > end { | |
78 | return false; | |
79 | } | |
80 | ||
81 | loop { | |
82 | // This condition looks a bit weird, but actually makes sense. | |
83 | // | |
84 | // if r.0 == end + 1, then we're actually adjacent, so we want to | |
85 | // continue to the next range. We're looking here for the first | |
86 | // range which starts *non-adjacently* to our end. | |
87 | let next = self.map.partition_point(|r| r.0 <= end + 1); | |
88 | if let Some(last) = next.checked_sub(1) { | |
89 | let (prev_start, prev_end) = &mut self.map[last]; | |
90 | if *prev_end + 1 >= start { | |
91 | // If the start for the inserted range is adjacent to the | |
92 | // end of the previous, we can extend the previous range. | |
93 | if start < *prev_start { | |
94 | // Our range starts before the one we found. We'll need | |
95 | // to *remove* it, and then try again. | |
96 | // | |
97 | // FIXME: This is not so efficient; we may need to | |
98 | // recurse a bunch of times here. Instead, it's probably | |
99 | // better to do something like drain_filter(...) on the | |
100 | // map to be able to delete or modify all the ranges in | |
101 | // start..=end and then potentially re-insert a new | |
102 | // range. | |
103 | end = std::cmp::max(end, *prev_end); | |
104 | self.map.remove(last); | |
105 | } else { | |
106 | // We overlap with the previous range, increase it to | |
107 | // include us. | |
108 | // | |
109 | // Make sure we're actually going to *increase* it though -- | |
110 | // it may be that end is just inside the previously existing | |
111 | // set. | |
112 | return if end > *prev_end { | |
113 | *prev_end = end; | |
114 | true | |
115 | } else { | |
116 | false | |
117 | }; | |
118 | } | |
119 | } else { | |
120 | // Otherwise, we don't overlap, so just insert | |
121 | self.map.insert(last + 1, (start, end)); | |
122 | return true; | |
123 | } | |
124 | } else { | |
125 | if self.map.is_empty() { | |
126 | // Quite common in practice, and expensive to call memcpy | |
127 | // with length zero. | |
128 | self.map.push((start, end)); | |
129 | } else { | |
130 | self.map.insert(next, (start, end)); | |
131 | } | |
132 | return true; | |
133 | } | |
134 | } | |
135 | } | |
136 | ||
137 | pub fn contains(&self, needle: I) -> bool { | |
138 | let needle = needle.index() as u32; | |
5e7ed085 FG |
139 | let Some(last) = self.map.partition_point(|r| r.0 <= needle).checked_sub(1) else { |
140 | // All ranges in the map start after the new range's end | |
141 | return false; | |
a2a8927a XL |
142 | }; |
143 | let (_, prev_end) = &self.map[last]; | |
144 | needle <= *prev_end | |
145 | } | |
146 | ||
147 | pub fn superset(&self, other: &IntervalSet<I>) -> bool | |
148 | where | |
149 | I: Step, | |
150 | { | |
151 | // FIXME: Performance here is probably not great. We will be doing a lot | |
152 | // of pointless tree traversals. | |
153 | other.iter().all(|elem| self.contains(elem)) | |
154 | } | |
155 | ||
156 | pub fn is_empty(&self) -> bool { | |
157 | self.map.is_empty() | |
158 | } | |
159 | ||
160 | /// Returns the maximum (last) element present in the set from `range`. | |
161 | pub fn last_set_in(&self, range: impl RangeBounds<I> + Clone) -> Option<I> { | |
162 | let start = inclusive_start(range.clone()); | |
163 | let Some(end) = inclusive_end(self.domain, range) else { | |
164 | // empty range | |
165 | return None; | |
166 | }; | |
167 | if start > end { | |
168 | return None; | |
169 | } | |
5e7ed085 FG |
170 | let Some(last) = self.map.partition_point(|r| r.0 <= end).checked_sub(1) else { |
171 | // All ranges in the map start after the new range's end | |
172 | return None; | |
a2a8927a XL |
173 | }; |
174 | let (_, prev_end) = &self.map[last]; | |
175 | if start <= *prev_end { Some(I::new(std::cmp::min(*prev_end, end) as usize)) } else { None } | |
176 | } | |
177 | ||
178 | pub fn insert_all(&mut self) { | |
179 | self.clear(); | |
180 | self.map.push((0, self.domain.try_into().unwrap())); | |
181 | } | |
182 | ||
183 | pub fn union(&mut self, other: &IntervalSet<I>) -> bool | |
184 | where | |
185 | I: Step, | |
186 | { | |
187 | assert_eq!(self.domain, other.domain); | |
188 | let mut did_insert = false; | |
189 | for range in other.iter_intervals() { | |
190 | did_insert |= self.insert_range(range); | |
191 | } | |
192 | did_insert | |
193 | } | |
194 | } | |
195 | ||
196 | /// This data structure optimizes for cases where the stored bits in each row | |
197 | /// are expected to be highly contiguous (long ranges of 1s or 0s), in contrast | |
198 | /// to BitMatrix and SparseBitMatrix which are optimized for | |
199 | /// "random"/non-contiguous bits and cheap(er) point queries at the expense of | |
200 | /// memory usage. | |
201 | #[derive(Clone)] | |
202 | pub struct SparseIntervalMatrix<R, C> | |
203 | where | |
204 | R: Idx, | |
205 | C: Idx, | |
206 | { | |
207 | rows: IndexVec<R, IntervalSet<C>>, | |
208 | column_size: usize, | |
209 | } | |
210 | ||
211 | impl<R: Idx, C: Step + Idx> SparseIntervalMatrix<R, C> { | |
212 | pub fn new(column_size: usize) -> SparseIntervalMatrix<R, C> { | |
213 | SparseIntervalMatrix { rows: IndexVec::new(), column_size } | |
214 | } | |
215 | ||
216 | pub fn rows(&self) -> impl Iterator<Item = R> { | |
217 | self.rows.indices() | |
218 | } | |
219 | ||
220 | pub fn row(&self, row: R) -> Option<&IntervalSet<C>> { | |
221 | self.rows.get(row) | |
222 | } | |
223 | ||
224 | fn ensure_row(&mut self, row: R) -> &mut IntervalSet<C> { | |
225 | self.rows.ensure_contains_elem(row, || IntervalSet::new(self.column_size)); | |
226 | &mut self.rows[row] | |
227 | } | |
228 | ||
229 | pub fn union_row(&mut self, row: R, from: &IntervalSet<C>) -> bool | |
230 | where | |
231 | C: Step, | |
232 | { | |
233 | self.ensure_row(row).union(from) | |
234 | } | |
235 | ||
236 | pub fn union_rows(&mut self, read: R, write: R) -> bool | |
237 | where | |
238 | C: Step, | |
239 | { | |
240 | if read == write || self.rows.get(read).is_none() { | |
241 | return false; | |
242 | } | |
243 | self.ensure_row(write); | |
244 | let (read_row, write_row) = self.rows.pick2_mut(read, write); | |
245 | write_row.union(read_row) | |
246 | } | |
247 | ||
248 | pub fn insert_all_into_row(&mut self, row: R) { | |
249 | self.ensure_row(row).insert_all(); | |
250 | } | |
251 | ||
252 | pub fn insert_range(&mut self, row: R, range: impl RangeBounds<C> + Clone) { | |
253 | self.ensure_row(row).insert_range(range); | |
254 | } | |
255 | ||
256 | pub fn insert(&mut self, row: R, point: C) -> bool { | |
257 | self.ensure_row(row).insert(point) | |
258 | } | |
259 | ||
260 | pub fn contains(&self, row: R, point: C) -> bool { | |
261 | self.row(row).map_or(false, |r| r.contains(point)) | |
262 | } | |
263 | } |