]> git.proxmox.com Git - rustc.git/blame - vendor/fst/src/automaton/levenshtein.rs
New upstream version 1.48.0+dfsg1
[rustc.git] / vendor / fst / src / automaton / levenshtein.rs
CommitLineData
f035d41b
XL
1use std::cmp;
2use std::collections::hash_map::Entry;
3use std::collections::{HashMap, HashSet};
4use std::fmt;
5
6use utf8_ranges::{Utf8Range, Utf8Sequences};
7
8use crate::automaton::Automaton;
9
/// Hard upper bound on the number of DFA states built for one automaton.
///
/// Construction is aborted with `LevenshteinError::TooManyStates` once the
/// DFA has grown past this many states. A DFA of this size currently
/// occupies at least 20MB.
const STATE_LIMIT: usize = 10_000;
11
/// An error that occurred while building a Levenshtein automaton.
///
/// This error is only defined when the `levenshtein` crate feature is enabled.
#[derive(Debug)]
pub enum LevenshteinError {
    /// If construction of the automaton reaches some hard-coded limit
    /// on the number of states, then this error is returned.
    ///
    /// The number given is the limit that was exceeded.
    TooManyStates(usize),
}

impl fmt::Display for LevenshteinError {
    /// Writes a human-readable description that includes the state limit
    /// carried by the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LevenshteinError::TooManyStates(size_limit) => write!(
                f,
                "Levenshtein automaton exceeds size limit of {} states",
                size_limit
            ),
        }
    }
}

impl std::error::Error for LevenshteinError {}
38
39/// A Unicode aware Levenshtein automaton for running efficient fuzzy queries.
40///
41/// This is only defined when the `levenshtein` crate feature is enabled.
42///
43/// A Levenshtein automata is one way to search any finite state transducer
44/// for keys that *approximately* match a given query. A Levenshtein automaton
45/// approximates this by returning all keys within a certain edit distance of
46/// the query. The edit distance is defined by the number of insertions,
47/// deletions and substitutions required to turn the query into the key.
48/// Insertions, deletions and substitutions are based on
49/// **Unicode characters** (where each character is a single Unicode scalar
50/// value).
51///
52/// # Example
53///
54/// This example shows how to find all keys within an edit distance of `1`
55/// from `foo`.
56///
57/// ```rust
58/// use fst::automaton::Levenshtein;
59/// use fst::{IntoStreamer, Streamer, Set};
60///
61/// fn main() {
62/// let keys = vec!["fa", "fo", "fob", "focus", "foo", "food", "foul"];
63/// let set = Set::from_iter(keys).unwrap();
64///
65/// let lev = Levenshtein::new("foo", 1).unwrap();
66/// let mut stream = set.search(&lev).into_stream();
67///
68/// let mut keys = vec![];
69/// while let Some(key) = stream.next() {
70/// keys.push(key.to_vec());
71/// }
72/// assert_eq!(keys, vec![
73/// "fo".as_bytes(), // 1 deletion
74/// "fob".as_bytes(), // 1 substitution
75/// "foo".as_bytes(), // 0 insertions/deletions/substitutions
76/// "food".as_bytes(), // 1 insertion
77/// ]);
78/// }
79/// ```
80///
81/// This example only uses ASCII characters, but it will work equally well
82/// on Unicode characters.
83///
84/// # Warning: experimental
85///
86/// While executing this Levenshtein automaton against a finite state
87/// transducer will be very fast, *constructing* an automaton may not be.
88/// Namely, this implementation is a proof of concept. While I believe the
89/// algorithmic complexity is not exponential, the implementation is not speedy
90/// and it can use enormous amounts of memory (tens of MB before a hard-coded
91/// limit will cause an error to be returned).
92///
93/// This is important functionality, so one should count on this implementation
94/// being vastly improved in the future.
95pub struct Levenshtein {
96 prog: DynamicLevenshtein,
97 dfa: Dfa,
98}
99
100impl Levenshtein {
101 /// Create a new Levenshtein query.
102 ///
103 /// The query finds all matching terms that are at most `distance`
104 /// edit operations from `query`. (An edit operation may be an insertion,
105 /// a deletion or a substitution.)
106 ///
107 /// If the underlying automaton becomes too big, then an error is returned.
108 ///
109 /// A `Levenshtein` value satisfies the `Automaton` trait, which means it
110 /// can be used with the `search` method of any finite state transducer.
111 #[inline]
112 pub fn new(
113 query: &str,
114 distance: u32,
115 ) -> Result<Levenshtein, LevenshteinError> {
116 let lev = DynamicLevenshtein {
117 query: query.to_owned(),
118 dist: distance as usize,
119 };
120 let dfa = DfaBuilder::new(lev.clone()).build()?;
121 Ok(Levenshtein { prog: lev, dfa })
122 }
123}
124
125impl fmt::Debug for Levenshtein {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 write!(
128 f,
129 "Levenshtein(query: {:?}, distance: {:?})",
130 self.prog.query, self.prog.dist
131 )
132 }
133}
134
/// Character-at-a-time Levenshtein matcher used to derive the DFA.
///
/// A state is a row of edit distances with one entry per prefix of `query`
/// (so `query.chars().count() + 1` entries): the first row is `0, 1, 2, ...`
/// and `accept` computes the following row from the previous one.
#[derive(Clone)]
struct DynamicLevenshtein {
    /// The string being searched for.
    query: String,
    /// The maximum number of edits a key may be away from `query`.
    dist: usize,
}

impl DynamicLevenshtein {
    /// Returns the initial row `[0, 1, ..., n]` where `n` is the number of
    /// characters in the query.
    fn start(&self) -> Vec<usize> {
        (0..=self.query.chars().count()).collect()
    }

    /// Returns true if the last entry of `state` (the distance to the whole
    /// query) is within `dist`. An empty row never matches.
    fn is_match(&self, state: &[usize]) -> bool {
        state.last().map_or(false, |&n| n <= self.dist)
    }

    /// Returns true if any entry of `state` is still within `dist`.
    fn can_match(&self, state: &[usize]) -> bool {
        state.iter().any(|&n| n <= self.dist)
    }

    /// Computes the row that follows `state` after reading `chr`.
    ///
    /// `chr == None` stands for "a character that equals no query
    /// character". Every entry of the result, including the first, is capped
    /// at `dist + 1`: any value above `dist` is equally dead, and capping
    /// keeps rows that differ only in dead entries identical so that callers
    /// keying on the row do not treat them as distinct states.
    ///
    /// # Panics
    ///
    /// Panics if `state` has fewer than `query.chars().count() + 1` entries.
    fn accept(&self, state: &[usize], chr: Option<char>) -> Vec<usize> {
        let cap = self.dist + 1;
        let mut next = Vec::with_capacity(state.len());
        next.push(cmp::min(state[0] + 1, cap));
        for (i, c) in self.query.chars().enumerate() {
            let cost = if Some(c) == chr { 0 } else { 1 };
            let v = cmp::min(
                cmp::min(next[i] + 1, state[i + 1] + 1),
                state[i] + cost,
            );
            next.push(cmp::min(v, cap));
        }
        next
    }
}
167
168impl Automaton for Levenshtein {
169 type State = Option<usize>;
170
171 #[inline]
172 fn start(&self) -> Option<usize> {
173 Some(0)
174 }
175
176 #[inline]
177 fn is_match(&self, state: &Option<usize>) -> bool {
178 state.map(|state| self.dfa.states[state].is_match).unwrap_or(false)
179 }
180
181 #[inline]
182 fn can_match(&self, state: &Option<usize>) -> bool {
183 state.is_some()
184 }
185
186 #[inline]
187 fn accept(&self, state: &Option<usize>, byte: u8) -> Option<usize> {
188 state.and_then(|state| self.dfa.states[state].next[byte as usize])
189 }
190}
191
/// A byte-oriented DFA stored as a flat table of states.
#[derive(Debug)]
struct Dfa {
    /// All states; transitions refer to states by their index in this list.
    states: Vec<State>,
}

/// A single DFA state.
struct State {
    /// Target state index for each possible input byte, or `None` when
    /// there is no transition on that byte.
    next: [Option<usize>; 256],
    /// Whether input ending in this state is accepted.
    is_match: bool,
}

impl fmt::Debug for State {
    /// Prints the match flag followed by one `byte: target` line for each
    /// transition that is set; unset transitions are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "State {{")?;
        writeln!(f, " is_match: {:?}", self.is_match)?;
        for (byte, target) in self.next.iter().enumerate() {
            if let Some(si) = target {
                writeln!(f, " {:?}: {:?}", byte, si)?;
            }
        }
        write!(f, "}}")
    }
}
214
215struct DfaBuilder {
216 dfa: Dfa,
217 lev: DynamicLevenshtein,
218 cache: HashMap<Vec<usize>, usize>,
219}
220
221impl DfaBuilder {
222 fn new(lev: DynamicLevenshtein) -> DfaBuilder {
223 DfaBuilder {
224 dfa: Dfa { states: Vec::with_capacity(16) },
225 lev,
226 cache: HashMap::with_capacity(1024),
227 }
228 }
229
230 fn build(mut self) -> Result<Dfa, LevenshteinError> {
231 let mut stack = vec![self.lev.start()];
232 let mut seen = HashSet::new();
233 let query = self.lev.query.clone(); // temp work around of borrowck
234 while let Some(lev_state) = stack.pop() {
235 let dfa_si = self.cached_state(&lev_state).unwrap();
236 let mismatch = self.add_mismatch_utf8_states(dfa_si, &lev_state);
237 if let Some((next_si, lev_next)) = mismatch {
238 if !seen.contains(&next_si) {
239 seen.insert(next_si);
240 stack.push(lev_next);
241 }
242 }
243 for (i, c) in query.chars().enumerate() {
244 if lev_state[i] > self.lev.dist {
245 continue;
246 }
247 let lev_next = self.lev.accept(&lev_state, Some(c));
248 let next_si = self.cached_state(&lev_next);
249 if let Some(next_si) = next_si {
250 self.add_utf8_sequences(true, dfa_si, next_si, c, c);
251 if !seen.contains(&next_si) {
252 seen.insert(next_si);
253 stack.push(lev_next);
254 }
255 }
256 }
257 if self.dfa.states.len() > STATE_LIMIT {
258 return Err(LevenshteinError::TooManyStates(STATE_LIMIT));
259 }
260 }
261 Ok(self.dfa)
262 }
263
264 fn cached_state(&mut self, lev_state: &[usize]) -> Option<usize> {
265 self.cached(lev_state).map(|(si, _)| si)
266 }
267
268 fn cached(&mut self, lev_state: &[usize]) -> Option<(usize, bool)> {
269 if !self.lev.can_match(lev_state) {
270 return None;
271 }
272 Some(match self.cache.entry(lev_state.to_vec()) {
273 Entry::Occupied(v) => (*v.get(), true),
274 Entry::Vacant(v) => {
275 let is_match = self.lev.is_match(lev_state);
276 self.dfa.states.push(State { next: [None; 256], is_match });
277 (*v.insert(self.dfa.states.len() - 1), false)
278 }
279 })
280 }
281
282 fn add_mismatch_utf8_states(
283 &mut self,
284 from_si: usize,
285 lev_state: &[usize],
286 ) -> Option<(usize, Vec<usize>)> {
287 let mismatch_state = self.lev.accept(lev_state, None);
288 let to_si = match self.cached(&mismatch_state) {
289 None => return None,
290 Some((si, _)) => si,
291 // Some((si, true)) => return Some((si, mismatch_state)),
292 // Some((si, false)) => si,
293 };
294 self.add_utf8_sequences(false, from_si, to_si, '\u{0}', '\u{10FFFF}');
295 return Some((to_si, mismatch_state));
296 }
297
298 fn add_utf8_sequences(
299 &mut self,
300 overwrite: bool,
301 from_si: usize,
302 to_si: usize,
303 from_chr: char,
304 to_chr: char,
305 ) {
306 for seq in Utf8Sequences::new(from_chr, to_chr) {
307 let mut fsi = from_si;
308 for range in &seq.as_slice()[0..seq.len() - 1] {
309 let tsi = self.new_state(false);
310 self.add_utf8_range(overwrite, fsi, tsi, range);
311 fsi = tsi;
312 }
313 self.add_utf8_range(
314 overwrite,
315 fsi,
316 to_si,
317 &seq.as_slice()[seq.len() - 1],
318 );
319 }
320 }
321
322 fn add_utf8_range(
323 &mut self,
324 overwrite: bool,
325 from: usize,
326 to: usize,
327 range: &Utf8Range,
328 ) {
329 for b in range.start as usize..range.end as usize + 1 {
330 if overwrite || self.dfa.states[from].next[b].is_none() {
331 self.dfa.states[from].next[b] = Some(to);
332 }
333 }
334 }
335
336 fn new_state(&mut self, is_match: bool) -> usize {
337 self.dfa.states.push(State { next: [None; 256], is_match });
338 self.dfa.states.len() - 1
339 }
340}