]>
Commit | Line | Data |
---|---|---|
f035d41b XL |
1 | use std::cmp; |
2 | use std::collections::hash_map::Entry; | |
3 | use std::collections::{HashMap, HashSet}; | |
4 | use std::fmt; | |
5 | ||
6 | use utf8_ranges::{Utf8Range, Utf8Sequences}; | |
7 | ||
8 | use crate::automaton::Automaton; | |
9 | ||
/// Maximum number of DFA states `DfaBuilder` may create before construction
/// is aborted with `LevenshteinError::TooManyStates`.
///
/// Every state owns a `[Option<usize>; 256]` transition table (4 KiB on
/// 64-bit targets), so reaching this limit means roughly 40 MB of tables.
const STATE_LIMIT: usize = 10_000;
11 | ||
/// The ways building a Levenshtein automaton can fail.
///
/// Only available when the `levenshtein` crate feature is enabled.
#[derive(Debug)]
pub enum LevenshteinError {
    /// Construction grew past a hard-coded cap on the number of automaton
    /// states.
    ///
    /// The payload is the cap that was exceeded.
    TooManyStates(usize),
}

impl fmt::Display for LevenshteinError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LevenshteinError::TooManyStates(limit) => write!(
                f,
                "Levenshtein automaton exceeds size limit of {} states",
                limit
            ),
        }
    }
}

impl std::error::Error for LevenshteinError {}
38 | ||
39 | /// A Unicode aware Levenshtein automaton for running efficient fuzzy queries. | |
40 | /// | |
41 | /// This is only defined when the `levenshtein` crate feature is enabled. | |
42 | /// | |
43 | /// A Levenshtein automata is one way to search any finite state transducer | |
44 | /// for keys that *approximately* match a given query. A Levenshtein automaton | |
45 | /// approximates this by returning all keys within a certain edit distance of | |
46 | /// the query. The edit distance is defined by the number of insertions, | |
47 | /// deletions and substitutions required to turn the query into the key. | |
48 | /// Insertions, deletions and substitutions are based on | |
49 | /// **Unicode characters** (where each character is a single Unicode scalar | |
50 | /// value). | |
51 | /// | |
52 | /// # Example | |
53 | /// | |
54 | /// This example shows how to find all keys within an edit distance of `1` | |
55 | /// from `foo`. | |
56 | /// | |
57 | /// ```rust | |
58 | /// use fst::automaton::Levenshtein; | |
59 | /// use fst::{IntoStreamer, Streamer, Set}; | |
60 | /// | |
61 | /// fn main() { | |
62 | /// let keys = vec!["fa", "fo", "fob", "focus", "foo", "food", "foul"]; | |
63 | /// let set = Set::from_iter(keys).unwrap(); | |
64 | /// | |
65 | /// let lev = Levenshtein::new("foo", 1).unwrap(); | |
66 | /// let mut stream = set.search(&lev).into_stream(); | |
67 | /// | |
68 | /// let mut keys = vec![]; | |
69 | /// while let Some(key) = stream.next() { | |
70 | /// keys.push(key.to_vec()); | |
71 | /// } | |
72 | /// assert_eq!(keys, vec![ | |
73 | /// "fo".as_bytes(), // 1 deletion | |
74 | /// "fob".as_bytes(), // 1 substitution | |
75 | /// "foo".as_bytes(), // 0 insertions/deletions/substitutions | |
76 | /// "food".as_bytes(), // 1 insertion | |
77 | /// ]); | |
78 | /// } | |
79 | /// ``` | |
80 | /// | |
81 | /// This example only uses ASCII characters, but it will work equally well | |
82 | /// on Unicode characters. | |
83 | /// | |
84 | /// # Warning: experimental | |
85 | /// | |
86 | /// While executing this Levenshtein automaton against a finite state | |
87 | /// transducer will be very fast, *constructing* an automaton may not be. | |
88 | /// Namely, this implementation is a proof of concept. While I believe the | |
89 | /// algorithmic complexity is not exponential, the implementation is not speedy | |
90 | /// and it can use enormous amounts of memory (tens of MB before a hard-coded | |
91 | /// limit will cause an error to be returned). | |
92 | /// | |
93 | /// This is important functionality, so one should count on this implementation | |
94 | /// being vastly improved in the future. | |
95 | pub struct Levenshtein { | |
96 | prog: DynamicLevenshtein, | |
97 | dfa: Dfa, | |
98 | } | |
99 | ||
100 | impl Levenshtein { | |
101 | /// Create a new Levenshtein query. | |
102 | /// | |
103 | /// The query finds all matching terms that are at most `distance` | |
104 | /// edit operations from `query`. (An edit operation may be an insertion, | |
105 | /// a deletion or a substitution.) | |
106 | /// | |
107 | /// If the underlying automaton becomes too big, then an error is returned. | |
108 | /// | |
109 | /// A `Levenshtein` value satisfies the `Automaton` trait, which means it | |
110 | /// can be used with the `search` method of any finite state transducer. | |
111 | #[inline] | |
112 | pub fn new( | |
113 | query: &str, | |
114 | distance: u32, | |
115 | ) -> Result<Levenshtein, LevenshteinError> { | |
116 | let lev = DynamicLevenshtein { | |
117 | query: query.to_owned(), | |
118 | dist: distance as usize, | |
119 | }; | |
120 | let dfa = DfaBuilder::new(lev.clone()).build()?; | |
121 | Ok(Levenshtein { prog: lev, dfa }) | |
122 | } | |
123 | } | |
124 | ||
125 | impl fmt::Debug for Levenshtein { | |
126 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | |
127 | write!( | |
128 | f, | |
129 | "Levenshtein(query: {:?}, distance: {:?})", | |
130 | self.prog.query, self.prog.dist | |
131 | ) | |
132 | } | |
133 | } | |
134 | ||
/// A direct simulation of the Levenshtein edit-distance dynamic program.
///
/// A state is one row of the DP table. Entry `i` is the edit distance
/// between the input consumed so far and the first `i` characters of
/// `query`, saturated at `dist + 1`. Any value above `dist` means the same
/// thing ("out of budget"), so saturating keeps equivalent rows identical.
#[derive(Clone)]
struct DynamicLevenshtein {
    /// The string candidate keys are compared against.
    query: String,
    /// The largest edit distance at which a key still matches.
    dist: usize,
}

impl DynamicLevenshtein {
    /// Returns the row for empty input: reaching query prefix `i` costs `i`.
    fn start(&self) -> Vec<usize> {
        (0..=self.query.chars().count()).collect()
    }

    /// Whether the input consumed so far is within `dist` of the whole query.
    fn is_match(&self, state: &[usize]) -> bool {
        state.last().map_or(false, |&n| n <= self.dist)
    }

    /// Whether some continuation of the input could still match.
    fn can_match(&self, state: &[usize]) -> bool {
        state.iter().min().map_or(false, |&n| n <= self.dist)
    }

    /// Advances `state` by one input character. `None` stands for a
    /// character that equals no query character.
    fn accept(&self, state: &[usize], chr: Option<char>) -> Vec<usize> {
        let cap = self.dist + 1;
        let mut next = Vec::with_capacity(state.len());
        // Entry 0 is saturated like every other entry. Left unbounded, rows
        // differing only in an out-of-budget entry 0 would behave the same
        // but get distinct cache keys, and so duplicate DFA states.
        next.push(cmp::min(state[0] + 1, cap));
        for (i, c) in self.query.chars().enumerate() {
            let cost = if Some(c) == chr { 0 } else { 1 };
            let v = cmp::min(
                cmp::min(next[i] + 1, state[i + 1] + 1),
                state[i] + cost,
            );
            next.push(cmp::min(v, cap));
        }
        next
    }
}
167 | ||
168 | impl Automaton for Levenshtein { | |
169 | type State = Option<usize>; | |
170 | ||
171 | #[inline] | |
172 | fn start(&self) -> Option<usize> { | |
173 | Some(0) | |
174 | } | |
175 | ||
176 | #[inline] | |
177 | fn is_match(&self, state: &Option<usize>) -> bool { | |
178 | state.map(|state| self.dfa.states[state].is_match).unwrap_or(false) | |
179 | } | |
180 | ||
181 | #[inline] | |
182 | fn can_match(&self, state: &Option<usize>) -> bool { | |
183 | state.is_some() | |
184 | } | |
185 | ||
186 | #[inline] | |
187 | fn accept(&self, state: &Option<usize>, byte: u8) -> Option<usize> { | |
188 | state.and_then(|state| self.dfa.states[state].next[byte as usize]) | |
189 | } | |
190 | } | |
191 | ||
/// A byte-level deterministic automaton. States are addressed by their
/// index in `states`; index 0 is the start state.
#[derive(Debug)]
struct Dfa {
    states: Vec<State>,
}

/// One DFA state: a dense transition table indexed by input byte, plus an
/// accepting flag.
struct State {
    /// `next[b]` is the state reached on byte `b`, or `None` if there is no
    /// transition.
    next: [Option<usize>; 256],
    /// Whether input ending in this state is accepted.
    is_match: bool,
}

/// Prints the accepting flag followed by every defined transition, one per
/// line.
impl fmt::Debug for State {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "State {{")?;
        writeln!(f, " is_match: {:?}", self.is_match)?;
        let transitions = self
            .next
            .iter()
            .enumerate()
            .filter_map(|(byte, target)| target.map(|si| (byte, si)));
        for (byte, si) in transitions {
            writeln!(f, " {:?}: {:?}", byte, si)?;
        }
        write!(f, "}}")
    }
}
214 | ||
215 | struct DfaBuilder { | |
216 | dfa: Dfa, | |
217 | lev: DynamicLevenshtein, | |
218 | cache: HashMap<Vec<usize>, usize>, | |
219 | } | |
220 | ||
221 | impl DfaBuilder { | |
222 | fn new(lev: DynamicLevenshtein) -> DfaBuilder { | |
223 | DfaBuilder { | |
224 | dfa: Dfa { states: Vec::with_capacity(16) }, | |
225 | lev, | |
226 | cache: HashMap::with_capacity(1024), | |
227 | } | |
228 | } | |
229 | ||
230 | fn build(mut self) -> Result<Dfa, LevenshteinError> { | |
231 | let mut stack = vec![self.lev.start()]; | |
232 | let mut seen = HashSet::new(); | |
233 | let query = self.lev.query.clone(); // temp work around of borrowck | |
234 | while let Some(lev_state) = stack.pop() { | |
235 | let dfa_si = self.cached_state(&lev_state).unwrap(); | |
236 | let mismatch = self.add_mismatch_utf8_states(dfa_si, &lev_state); | |
237 | if let Some((next_si, lev_next)) = mismatch { | |
238 | if !seen.contains(&next_si) { | |
239 | seen.insert(next_si); | |
240 | stack.push(lev_next); | |
241 | } | |
242 | } | |
243 | for (i, c) in query.chars().enumerate() { | |
244 | if lev_state[i] > self.lev.dist { | |
245 | continue; | |
246 | } | |
247 | let lev_next = self.lev.accept(&lev_state, Some(c)); | |
248 | let next_si = self.cached_state(&lev_next); | |
249 | if let Some(next_si) = next_si { | |
250 | self.add_utf8_sequences(true, dfa_si, next_si, c, c); | |
251 | if !seen.contains(&next_si) { | |
252 | seen.insert(next_si); | |
253 | stack.push(lev_next); | |
254 | } | |
255 | } | |
256 | } | |
257 | if self.dfa.states.len() > STATE_LIMIT { | |
258 | return Err(LevenshteinError::TooManyStates(STATE_LIMIT)); | |
259 | } | |
260 | } | |
261 | Ok(self.dfa) | |
262 | } | |
263 | ||
264 | fn cached_state(&mut self, lev_state: &[usize]) -> Option<usize> { | |
265 | self.cached(lev_state).map(|(si, _)| si) | |
266 | } | |
267 | ||
268 | fn cached(&mut self, lev_state: &[usize]) -> Option<(usize, bool)> { | |
269 | if !self.lev.can_match(lev_state) { | |
270 | return None; | |
271 | } | |
272 | Some(match self.cache.entry(lev_state.to_vec()) { | |
273 | Entry::Occupied(v) => (*v.get(), true), | |
274 | Entry::Vacant(v) => { | |
275 | let is_match = self.lev.is_match(lev_state); | |
276 | self.dfa.states.push(State { next: [None; 256], is_match }); | |
277 | (*v.insert(self.dfa.states.len() - 1), false) | |
278 | } | |
279 | }) | |
280 | } | |
281 | ||
282 | fn add_mismatch_utf8_states( | |
283 | &mut self, | |
284 | from_si: usize, | |
285 | lev_state: &[usize], | |
286 | ) -> Option<(usize, Vec<usize>)> { | |
287 | let mismatch_state = self.lev.accept(lev_state, None); | |
288 | let to_si = match self.cached(&mismatch_state) { | |
289 | None => return None, | |
290 | Some((si, _)) => si, | |
291 | // Some((si, true)) => return Some((si, mismatch_state)), | |
292 | // Some((si, false)) => si, | |
293 | }; | |
294 | self.add_utf8_sequences(false, from_si, to_si, '\u{0}', '\u{10FFFF}'); | |
295 | return Some((to_si, mismatch_state)); | |
296 | } | |
297 | ||
298 | fn add_utf8_sequences( | |
299 | &mut self, | |
300 | overwrite: bool, | |
301 | from_si: usize, | |
302 | to_si: usize, | |
303 | from_chr: char, | |
304 | to_chr: char, | |
305 | ) { | |
306 | for seq in Utf8Sequences::new(from_chr, to_chr) { | |
307 | let mut fsi = from_si; | |
308 | for range in &seq.as_slice()[0..seq.len() - 1] { | |
309 | let tsi = self.new_state(false); | |
310 | self.add_utf8_range(overwrite, fsi, tsi, range); | |
311 | fsi = tsi; | |
312 | } | |
313 | self.add_utf8_range( | |
314 | overwrite, | |
315 | fsi, | |
316 | to_si, | |
317 | &seq.as_slice()[seq.len() - 1], | |
318 | ); | |
319 | } | |
320 | } | |
321 | ||
322 | fn add_utf8_range( | |
323 | &mut self, | |
324 | overwrite: bool, | |
325 | from: usize, | |
326 | to: usize, | |
327 | range: &Utf8Range, | |
328 | ) { | |
329 | for b in range.start as usize..range.end as usize + 1 { | |
330 | if overwrite || self.dfa.states[from].next[b].is_none() { | |
331 | self.dfa.states[from].next[b] = Some(to); | |
332 | } | |
333 | } | |
334 | } | |
335 | ||
336 | fn new_state(&mut self, is_match: bool) -> usize { | |
337 | self.dfa.states.push(State { next: [None; 256], is_match }); | |
338 | self.dfa.states.len() - 1 | |
339 | } | |
340 | } |