]> git.proxmox.com Git - rustc.git/blob - src/tools/rust-analyzer/crates/syntax/src/algo.rs
New upstream version 1.64.0+dfsg1
[rustc.git] / src / tools / rust-analyzer / crates / syntax / src / algo.rs
1 //! Collection of assorted algorithms for syntax trees.
2
3 use std::hash::BuildHasherDefault;
4
5 use indexmap::IndexMap;
6 use itertools::Itertools;
7 use rustc_hash::FxHashMap;
8 use text_edit::TextEditBuilder;
9
10 use crate::{
11 AstNode, Direction, NodeOrToken, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, TextRange,
12 TextSize,
13 };
14
15 /// Returns ancestors of the node at the offset, sorted by length. This should
16 /// do the right thing at an edge, e.g. when searching for expressions at `{
17 /// $0foo }` we will get the name reference instead of the whole block, which
18 /// we would get if we just did `find_token_at_offset(...).flat_map(|t|
19 /// t.parent().ancestors())`.
20 pub fn ancestors_at_offset(
21 node: &SyntaxNode,
22 offset: TextSize,
23 ) -> impl Iterator<Item = SyntaxNode> {
24 node.token_at_offset(offset)
25 .map(|token| token.parent_ancestors())
26 .kmerge_by(|node1, node2| node1.text_range().len() < node2.text_range().len())
27 }
28
29 /// Finds a node of specific Ast type at offset. Note that this is slightly
30 /// imprecise: if the cursor is strictly between two nodes of the desired type,
31 /// as in
32 ///
33 /// ```no_run
34 /// struct Foo {}|struct Bar;
35 /// ```
36 ///
37 /// then the shorter node will be silently preferred.
38 pub fn find_node_at_offset<N: AstNode>(syntax: &SyntaxNode, offset: TextSize) -> Option<N> {
39 ancestors_at_offset(syntax, offset).find_map(N::cast)
40 }
41
42 pub fn find_node_at_range<N: AstNode>(syntax: &SyntaxNode, range: TextRange) -> Option<N> {
43 syntax.covering_element(range).ancestors().find_map(N::cast)
44 }
45
46 /// Skip to next non `trivia` token
47 pub fn skip_trivia_token(mut token: SyntaxToken, direction: Direction) -> Option<SyntaxToken> {
48 while token.kind().is_trivia() {
49 token = match direction {
50 Direction::Next => token.next_token()?,
51 Direction::Prev => token.prev_token()?,
52 }
53 }
54 Some(token)
55 }
56 /// Skip to next non `whitespace` token
57 pub fn skip_whitespace_token(mut token: SyntaxToken, direction: Direction) -> Option<SyntaxToken> {
58 while token.kind() == SyntaxKind::WHITESPACE {
59 token = match direction {
60 Direction::Next => token.next_token()?,
61 Direction::Prev => token.prev_token()?,
62 }
63 }
64 Some(token)
65 }
66
67 /// Finds the first sibling in the given direction which is not `trivia`
68 pub fn non_trivia_sibling(element: SyntaxElement, direction: Direction) -> Option<SyntaxElement> {
69 return match element {
70 NodeOrToken::Node(node) => node.siblings_with_tokens(direction).skip(1).find(not_trivia),
71 NodeOrToken::Token(token) => token.siblings_with_tokens(direction).skip(1).find(not_trivia),
72 };
73
74 fn not_trivia(element: &SyntaxElement) -> bool {
75 match element {
76 NodeOrToken::Node(_) => true,
77 NodeOrToken::Token(token) => !token.kind().is_trivia(),
78 }
79 }
80 }
81
82 pub fn least_common_ancestor(u: &SyntaxNode, v: &SyntaxNode) -> Option<SyntaxNode> {
83 if u == v {
84 return Some(u.clone());
85 }
86
87 let u_depth = u.ancestors().count();
88 let v_depth = v.ancestors().count();
89 let keep = u_depth.min(v_depth);
90
91 let u_candidates = u.ancestors().skip(u_depth - keep);
92 let v_candidates = v.ancestors().skip(v_depth - keep);
93 let (res, _) = u_candidates.zip(v_candidates).find(|(x, y)| x == y)?;
94 Some(res)
95 }
96
97 pub fn neighbor<T: AstNode>(me: &T, direction: Direction) -> Option<T> {
98 me.syntax().siblings(direction).skip(1).find_map(T::cast)
99 }
100
101 pub fn has_errors(node: &SyntaxNode) -> bool {
102 node.children().any(|it| it.kind() == SyntaxKind::ERROR)
103 }
104
105 type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<rustc_hash::FxHasher>>;
106
107 #[derive(Debug, Hash, PartialEq, Eq)]
108 enum TreeDiffInsertPos {
109 After(SyntaxElement),
110 AsFirstChild(SyntaxElement),
111 }
112
113 #[derive(Debug)]
114 pub struct TreeDiff {
115 replacements: FxHashMap<SyntaxElement, SyntaxElement>,
116 deletions: Vec<SyntaxElement>,
117 // the vec as well as the indexmap are both here to preserve order
118 insertions: FxIndexMap<TreeDiffInsertPos, Vec<SyntaxElement>>,
119 }
120
121 impl TreeDiff {
122 pub fn into_text_edit(&self, builder: &mut TextEditBuilder) {
123 let _p = profile::span("into_text_edit");
124
125 for (anchor, to) in &self.insertions {
126 let offset = match anchor {
127 TreeDiffInsertPos::After(it) => it.text_range().end(),
128 TreeDiffInsertPos::AsFirstChild(it) => it.text_range().start(),
129 };
130 to.iter().for_each(|to| builder.insert(offset, to.to_string()));
131 }
132 for (from, to) in &self.replacements {
133 builder.replace(from.text_range(), to.to_string());
134 }
135 for text_range in self.deletions.iter().map(SyntaxElement::text_range) {
136 builder.delete(text_range);
137 }
138 }
139
140 pub fn is_empty(&self) -> bool {
141 self.replacements.is_empty() && self.deletions.is_empty() && self.insertions.is_empty()
142 }
143 }
144
145 /// Finds a (potentially minimal) diff, which, applied to `from`, will result in `to`.
146 ///
147 /// Specifically, returns a structure that consists of a replacements, insertions and deletions
148 /// such that applying this map on `from` will result in `to`.
149 ///
150 /// This function tries to find a fine-grained diff.
151 pub fn diff(from: &SyntaxNode, to: &SyntaxNode) -> TreeDiff {
152 let _p = profile::span("diff");
153
154 let mut diff = TreeDiff {
155 replacements: FxHashMap::default(),
156 insertions: FxIndexMap::default(),
157 deletions: Vec::new(),
158 };
159 let (from, to) = (from.clone().into(), to.clone().into());
160
161 if !syntax_element_eq(&from, &to) {
162 go(&mut diff, from, to);
163 }
164 return diff;
165
166 fn syntax_element_eq(lhs: &SyntaxElement, rhs: &SyntaxElement) -> bool {
167 lhs.kind() == rhs.kind()
168 && lhs.text_range().len() == rhs.text_range().len()
169 && match (&lhs, &rhs) {
170 (NodeOrToken::Node(lhs), NodeOrToken::Node(rhs)) => {
171 lhs == rhs || lhs.text() == rhs.text()
172 }
173 (NodeOrToken::Token(lhs), NodeOrToken::Token(rhs)) => lhs.text() == rhs.text(),
174 _ => false,
175 }
176 }
177
178 // FIXME: this is horribly inefficient. I bet there's a cool algorithm to diff trees properly.
179 fn go(diff: &mut TreeDiff, lhs: SyntaxElement, rhs: SyntaxElement) {
180 let (lhs, rhs) = match lhs.as_node().zip(rhs.as_node()) {
181 Some((lhs, rhs)) => (lhs, rhs),
182 _ => {
183 cov_mark::hit!(diff_node_token_replace);
184 diff.replacements.insert(lhs, rhs);
185 return;
186 }
187 };
188
189 let mut look_ahead_scratch = Vec::default();
190
191 let mut rhs_children = rhs.children_with_tokens();
192 let mut lhs_children = lhs.children_with_tokens();
193 let mut last_lhs = None;
194 loop {
195 let lhs_child = lhs_children.next();
196 match (lhs_child.clone(), rhs_children.next()) {
197 (None, None) => break,
198 (None, Some(element)) => {
199 let insert_pos = match last_lhs.clone() {
200 Some(prev) => {
201 cov_mark::hit!(diff_insert);
202 TreeDiffInsertPos::After(prev)
203 }
204 // first iteration, insert into out parent as the first child
205 None => {
206 cov_mark::hit!(diff_insert_as_first_child);
207 TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
208 }
209 };
210 diff.insertions.entry(insert_pos).or_insert_with(Vec::new).push(element);
211 }
212 (Some(element), None) => {
213 cov_mark::hit!(diff_delete);
214 diff.deletions.push(element);
215 }
216 (Some(ref lhs_ele), Some(ref rhs_ele)) if syntax_element_eq(lhs_ele, rhs_ele) => {}
217 (Some(lhs_ele), Some(rhs_ele)) => {
218 // nodes differ, look for lhs_ele in rhs, if its found we can mark everything up
219 // until that element as insertions. This is important to keep the diff minimal
220 // in regards to insertions that have been actually done, this is important for
221 // use insertions as we do not want to replace the entire module node.
222 look_ahead_scratch.push(rhs_ele.clone());
223 let mut rhs_children_clone = rhs_children.clone();
224 let mut insert = false;
225 for rhs_child in &mut rhs_children_clone {
226 if syntax_element_eq(&lhs_ele, &rhs_child) {
227 cov_mark::hit!(diff_insertions);
228 insert = true;
229 break;
230 }
231 look_ahead_scratch.push(rhs_child);
232 }
233 let drain = look_ahead_scratch.drain(..);
234 if insert {
235 let insert_pos = if let Some(prev) = last_lhs.clone().filter(|_| insert) {
236 TreeDiffInsertPos::After(prev)
237 } else {
238 cov_mark::hit!(insert_first_child);
239 TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
240 };
241
242 diff.insertions.entry(insert_pos).or_insert_with(Vec::new).extend(drain);
243 rhs_children = rhs_children_clone;
244 } else {
245 go(diff, lhs_ele, rhs_ele);
246 }
247 }
248 }
249 last_lhs = lhs_child.or(last_lhs);
250 }
251 }
252 }
253
// Snapshot tests for `diff`: each test diffs two parsed source texts, compares the
// formatted diff against an `expect!` snapshot, and (via `check_diff`) checks that
// applying the diff as a text edit turns the first text into the second.
254 #[cfg(test)]
255 mod tests {
256 use expect_test::{expect, Expect};
257 use itertools::Itertools;
258 use parser::SyntaxKind;
259 use text_edit::TextEdit;
260
261 use crate::{AstNode, SyntaxElement};
262
263 #[test]
264 fn replace_node_token() {
265 cov_mark::check!(diff_node_token_replace);
266 check_diff(
267 r#"use node;"#,
268 r#"ident"#,
269 expect![[r#"
270 insertions:
271
272
273
274 replacements:
275
276 Line 0: Token(USE_KW@0..3 "use") -> ident
277
278 deletions:
279
280 Line 1: " "
281 Line 1: node
282 Line 1: ;
283 "#]],
284 );
285 }
286
287 #[test]
288 fn replace_parent() {
289 cov_mark::check!(diff_insert_as_first_child);
290 check_diff(
291 r#""#,
292 r#"use foo::bar;"#,
293 expect![[r#"
294 insertions:
295
296 Line 0: AsFirstChild(Node(SOURCE_FILE@0..0))
297 -> use foo::bar;
298
299 replacements:
300
301
302
303 deletions:
304
305
306 "#]],
307 );
308 }
309
310 #[test]
311 fn insert_last() {
312 cov_mark::check!(diff_insert);
313 check_diff(
314 r#"
315 use foo;
316 use bar;"#,
317 r#"
318 use foo;
319 use bar;
320 use baz;"#,
321 expect![[r#"
322 insertions:
323
324 Line 2: After(Node(USE@10..18))
325 -> "\n"
326 -> use baz;
327
328 replacements:
329
330
331
332 deletions:
333
334
335 "#]],
336 );
337 }
338
339 #[test]
340 fn insert_middle() {
341 check_diff(
342 r#"
343 use foo;
344 use baz;"#,
345 r#"
346 use foo;
347 use bar;
348 use baz;"#,
349 expect![[r#"
350 insertions:
351
352 Line 2: After(Token(WHITESPACE@9..10 "\n"))
353 -> use bar;
354 -> "\n"
355
356 replacements:
357
358
359
360 deletions:
361
362
363 "#]],
364 )
365 }
366
367 #[test]
368 fn insert_first() {
369 check_diff(
370 r#"
371 use bar;
372 use baz;"#,
373 r#"
374 use foo;
375 use bar;
376 use baz;"#,
377 expect![[r#"
378 insertions:
379
380 Line 0: After(Token(WHITESPACE@0..1 "\n"))
381 -> use foo;
382 -> "\n"
383
384 replacements:
385
386
387
388 deletions:
389
390
391 "#]],
392 )
393 }
394
395 #[test]
396 fn first_child_insertion() {
397 cov_mark::check!(insert_first_child);
398 check_diff(
399 r#"fn main() {
400 stdi
401 }"#,
402 r#"use foo::bar;
403
404 fn main() {
405 stdi
406 }"#,
407 expect![[r#"
408 insertions:
409
410 Line 0: AsFirstChild(Node(SOURCE_FILE@0..30))
411 -> use foo::bar;
412 -> "\n\n "
413
414 replacements:
415
416
417
418 deletions:
419
420
421 "#]],
422 );
423 }
424
425 #[test]
426 fn delete_last() {
427 cov_mark::check!(diff_delete);
428 check_diff(
429 r#"use foo;
430 use bar;"#,
431 r#"use foo;"#,
432 expect![[r#"
433 insertions:
434
435
436
437 replacements:
438
439
440
441 deletions:
442
443 Line 1: "\n "
444 Line 2: use bar;
445 "#]],
446 );
447 }
448
449 #[test]
450 fn delete_middle() {
451 cov_mark::check!(diff_insertions);
452 check_diff(
453 r#"
454 use expect_test::{expect, Expect};
455 use text_edit::TextEdit;
456
457 use crate::AstNode;
458 "#,
459 r#"
460 use expect_test::{expect, Expect};
461
462 use crate::AstNode;
463 "#,
464 expect![[r#"
465 insertions:
466
467 Line 1: After(Node(USE@1..35))
468 -> "\n\n"
469 -> use crate::AstNode;
470
471 replacements:
472
473
474
475 deletions:
476
477 Line 2: use text_edit::TextEdit;
478 Line 3: "\n\n"
479 Line 4: use crate::AstNode;
480 Line 5: "\n"
481 "#]],
482 )
483 }
484
485 #[test]
486 fn delete_first() {
487 check_diff(
488 r#"
489 use text_edit::TextEdit;
490
491 use crate::AstNode;
492 "#,
493 r#"
494 use crate::AstNode;
495 "#,
496 expect![[r#"
497 insertions:
498
499
500
501 replacements:
502
503 Line 2: Token(IDENT@5..14 "text_edit") -> crate
504 Line 2: Token(IDENT@16..24 "TextEdit") -> AstNode
505 Line 2: Token(WHITESPACE@25..27 "\n\n") -> "\n"
506
507 deletions:
508
509 Line 3: use crate::AstNode;
510 Line 4: "\n"
511 "#]],
512 )
513 }
514
515 #[test]
516 fn merge_use() {
517 check_diff(
518 r#"
519 use std::{
520 fmt,
521 hash::BuildHasherDefault,
522 ops::{self, RangeInclusive},
523 };
524 "#,
525 r#"
526 use std::fmt;
527 use std::hash::BuildHasherDefault;
528 use std::ops::{self, RangeInclusive};
529 "#,
530 expect![[r#"
531 insertions:
532
533 Line 2: After(Node(PATH_SEGMENT@5..8))
534 -> ::
535 -> fmt
536 Line 6: After(Token(WHITESPACE@86..87 "\n"))
537 -> use std::hash::BuildHasherDefault;
538 -> "\n"
539 -> use std::ops::{self, RangeInclusive};
540 -> "\n"
541
542 replacements:
543
544 Line 2: Token(IDENT@5..8 "std") -> std
545
546 deletions:
547
548 Line 2: ::
549 Line 2: {
550 fmt,
551 hash::BuildHasherDefault,
552 ops::{self, RangeInclusive},
553 }
554 "#]],
555 )
556 }
557
558 #[test]
559 fn early_return_assist() {
560 check_diff(
561 r#"
562 fn main() {
563 if let Ok(x) = Err(92) {
564 foo(x);
565 }
566 }
567 "#,
568 r#"
569 fn main() {
570 let x = match Err(92) {
571 Ok(it) => it,
572 _ => return,
573 };
574 foo(x);
575 }
576 "#,
577 expect![[r#"
578 insertions:
579
580 Line 3: After(Node(BLOCK_EXPR@40..63))
581 -> " "
582 -> match Err(92) {
583 Ok(it) => it,
584 _ => return,
585 }
586 -> ;
587 Line 3: After(Node(IF_EXPR@17..63))
588 -> "\n "
589 -> foo(x);
590
591 replacements:
592
593 Line 3: Token(IF_KW@17..19 "if") -> let
594 Line 3: Token(LET_KW@20..23 "let") -> x
595 Line 3: Node(BLOCK_EXPR@40..63) -> =
596
597 deletions:
598
599 Line 3: " "
600 Line 3: Ok(x)
601 Line 3: " "
602 Line 3: =
603 Line 3: " "
604 Line 3: Err(92)
605 "#]],
606 )
607 }
608
/// Diffs the parse trees of `from` and `to`, asserts that the rendered diff matches
/// `expected_diff`, then applies the diff as a text edit to `from` and asserts the
/// result equals `to`.
609 fn check_diff(from: &str, to: &str, expected_diff: Expect) {
610 let from_node = crate::SourceFile::parse(from).tree().syntax().clone();
611 let to_node = crate::SourceFile::parse(to).tree().syntax().clone();
612 let diff = super::diff(&from_node, &to_node);
613
// "Line N" label for the snapshots: the number of lines `str::lines` yields for
// the part of `from` before the element's start offset.
614 let line_number =
615 |syn: &SyntaxElement| from[..syn.text_range().start().into()].lines().count();
616
// Whitespace tokens are Debug-formatted so spaces and newlines stay visible.
617 let fmt_syntax = |syn: &SyntaxElement| match syn.kind() {
618 SyntaxKind::WHITESPACE => format!("{:?}", syn.to_string()),
619 _ => format!("{}", syn),
620 };
621
622 let insertions =
623 diff.insertions.iter().format_with("\n", |(k, v), f| -> Result<(), std::fmt::Error> {
624 f(&format!(
625 "Line {}: {:?}\n-> {}",
626 line_number(match k {
627 super::TreeDiffInsertPos::After(syn) => syn,
628 super::TreeDiffInsertPos::AsFirstChild(syn) => syn,
629 }),
630 k,
631 v.iter().format_with("\n-> ", |v, f| f(&fmt_syntax(v)))
632 ))
633 });
634
// Replacements live in a hash map, so sort them by position for a stable snapshot.
635 let replacements = diff
636 .replacements
637 .iter()
638 .sorted_by_key(|(syntax, _)| syntax.text_range().start())
639 .format_with("\n", |(k, v), f| {
640 f(&format!("Line {}: {:?} -> {}", line_number(k), k, fmt_syntax(v)))
641 });
642
643 let deletions = diff
644 .deletions
645 .iter()
646 .format_with("\n", |v, f| f(&format!("Line {}: {}", line_number(v), &fmt_syntax(v))));
647
648 let actual = format!(
649 "insertions:\n\n{}\n\nreplacements:\n\n{}\n\ndeletions:\n\n{}\n",
650 insertions, replacements, deletions
651 );
652 expected_diff.assert_eq(&actual);
653
// Round-trip check: the diff, applied to the original text, must reproduce `to`.
654 let mut from = from.to_owned();
655 let mut text_edit = TextEdit::builder();
656 diff.into_text_edit(&mut text_edit);
657 text_edit.finish().apply(&mut from);
658 assert_eq!(&*from, to, "diff did not turn `from` to `to`");
659 }
660 }