]> git.lizzy.rs Git - rust.git/blob - crates/syntax/src/algo.rs
internal: cleanup adt parsing
[rust.git] / crates / syntax / src / algo.rs
1 //! Collection of assorted algorithms for syntax trees.
2
3 use std::hash::BuildHasherDefault;
4
5 use indexmap::IndexMap;
6 use itertools::Itertools;
7 use rustc_hash::FxHashMap;
8 use text_edit::TextEditBuilder;
9
10 use crate::{
11     AstNode, Direction, NodeOrToken, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, TextRange,
12     TextSize,
13 };
14
15 /// Returns ancestors of the node at the offset, sorted by length. This should
16 /// do the right thing at an edge, e.g. when searching for expressions at `{
17 /// $0foo }` we will get the name reference instead of the whole block, which
18 /// we would get if we just did `find_token_at_offset(...).flat_map(|t|
19 /// t.parent().ancestors())`.
20 pub fn ancestors_at_offset(
21     node: &SyntaxNode,
22     offset: TextSize,
23 ) -> impl Iterator<Item = SyntaxNode> {
24     node.token_at_offset(offset)
25         .map(|token| token.ancestors())
26         .kmerge_by(|node1, node2| node1.text_range().len() < node2.text_range().len())
27 }
28
29 /// Finds a node of specific Ast type at offset. Note that this is slightly
30 /// imprecise: if the cursor is strictly between two nodes of the desired type,
31 /// as in
32 ///
33 /// ```no_run
34 /// struct Foo {}|struct Bar;
35 /// ```
36 ///
37 /// then the shorter node will be silently preferred.
38 pub fn find_node_at_offset<N: AstNode>(syntax: &SyntaxNode, offset: TextSize) -> Option<N> {
39     ancestors_at_offset(syntax, offset).find_map(N::cast)
40 }
41
42 pub fn find_node_at_range<N: AstNode>(syntax: &SyntaxNode, range: TextRange) -> Option<N> {
43     syntax.covering_element(range).ancestors().find_map(N::cast)
44 }
45
46 /// Skip to next non `trivia` token
47 pub fn skip_trivia_token(mut token: SyntaxToken, direction: Direction) -> Option<SyntaxToken> {
48     while token.kind().is_trivia() {
49         token = match direction {
50             Direction::Next => token.next_token()?,
51             Direction::Prev => token.prev_token()?,
52         }
53     }
54     Some(token)
55 }
56
57 /// Finds the first sibling in the given direction which is not `trivia`
58 pub fn non_trivia_sibling(element: SyntaxElement, direction: Direction) -> Option<SyntaxElement> {
59     return match element {
60         NodeOrToken::Node(node) => node.siblings_with_tokens(direction).skip(1).find(not_trivia),
61         NodeOrToken::Token(token) => token.siblings_with_tokens(direction).skip(1).find(not_trivia),
62     };
63
64     fn not_trivia(element: &SyntaxElement) -> bool {
65         match element {
66             NodeOrToken::Node(_) => true,
67             NodeOrToken::Token(token) => !token.kind().is_trivia(),
68         }
69     }
70 }
71
72 pub fn least_common_ancestor(u: &SyntaxNode, v: &SyntaxNode) -> Option<SyntaxNode> {
73     if u == v {
74         return Some(u.clone());
75     }
76
77     let u_depth = u.ancestors().count();
78     let v_depth = v.ancestors().count();
79     let keep = u_depth.min(v_depth);
80
81     let u_candidates = u.ancestors().skip(u_depth - keep);
82     let v_candidates = v.ancestors().skip(v_depth - keep);
83     let (res, _) = u_candidates.zip(v_candidates).find(|(x, y)| x == y)?;
84     Some(res)
85 }
86
87 pub fn neighbor<T: AstNode>(me: &T, direction: Direction) -> Option<T> {
88     me.syntax().siblings(direction).skip(1).find_map(T::cast)
89 }
90
91 pub fn has_errors(node: &SyntaxNode) -> bool {
92     node.children().any(|it| it.kind() == SyntaxKind::ERROR)
93 }
94
/// An [`IndexMap`] hashed with `rustc_hash::FxHasher`. It is used where
/// iteration must follow insertion order.
type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<rustc_hash::FxHasher>>;

/// Anchor in the old tree at which a run of new elements is inserted.
///
/// `Hash`/`Eq` are derived so a position can key `TreeDiff::insertions`.
#[derive(Debug, Hash, PartialEq, Eq)]
enum TreeDiffInsertPos {
    /// Insert directly after this element, at the end of its text range.
    After(SyntaxElement),
    /// Insert as the first child of this node, at the start of its text range.
    AsFirstChild(SyntaxElement),
}
102
103 #[derive(Debug)]
104 pub struct TreeDiff {
105     replacements: FxHashMap<SyntaxElement, SyntaxElement>,
106     deletions: Vec<SyntaxElement>,
107     // the vec as well as the indexmap are both here to preserve order
108     insertions: FxIndexMap<TreeDiffInsertPos, Vec<SyntaxElement>>,
109 }
110
111 impl TreeDiff {
112     pub fn into_text_edit(&self, builder: &mut TextEditBuilder) {
113         let _p = profile::span("into_text_edit");
114
115         for (anchor, to) in self.insertions.iter() {
116             let offset = match anchor {
117                 TreeDiffInsertPos::After(it) => it.text_range().end(),
118                 TreeDiffInsertPos::AsFirstChild(it) => it.text_range().start(),
119             };
120             to.iter().for_each(|to| builder.insert(offset, to.to_string()));
121         }
122         for (from, to) in self.replacements.iter() {
123             builder.replace(from.text_range(), to.to_string())
124         }
125         for text_range in self.deletions.iter().map(SyntaxElement::text_range) {
126             builder.delete(text_range);
127         }
128     }
129
130     pub fn is_empty(&self) -> bool {
131         self.replacements.is_empty() && self.deletions.is_empty() && self.insertions.is_empty()
132     }
133 }
134
135 /// Finds a (potentially minimal) diff, which, applied to `from`, will result in `to`.
136 ///
137 /// Specifically, returns a structure that consists of a replacements, insertions and deletions
138 /// such that applying this map on `from` will result in `to`.
139 ///
140 /// This function tries to find a fine-grained diff.
141 pub fn diff(from: &SyntaxNode, to: &SyntaxNode) -> TreeDiff {
142     let _p = profile::span("diff");
143
144     let mut diff = TreeDiff {
145         replacements: FxHashMap::default(),
146         insertions: FxIndexMap::default(),
147         deletions: Vec::new(),
148     };
149     let (from, to) = (from.clone().into(), to.clone().into());
150
151     if !syntax_element_eq(&from, &to) {
152         go(&mut diff, from, to);
153     }
154     return diff;
155
156     fn syntax_element_eq(lhs: &SyntaxElement, rhs: &SyntaxElement) -> bool {
157         lhs.kind() == rhs.kind()
158             && lhs.text_range().len() == rhs.text_range().len()
159             && match (&lhs, &rhs) {
160                 (NodeOrToken::Node(lhs), NodeOrToken::Node(rhs)) => {
161                     lhs == rhs || lhs.text() == rhs.text()
162                 }
163                 (NodeOrToken::Token(lhs), NodeOrToken::Token(rhs)) => lhs.text() == rhs.text(),
164                 _ => false,
165             }
166     }
167
168     // FIXME: this is horribly inefficient. I bet there's a cool algorithm to diff trees properly.
169     fn go(diff: &mut TreeDiff, lhs: SyntaxElement, rhs: SyntaxElement) {
170         let (lhs, rhs) = match lhs.as_node().zip(rhs.as_node()) {
171             Some((lhs, rhs)) => (lhs, rhs),
172             _ => {
173                 cov_mark::hit!(diff_node_token_replace);
174                 diff.replacements.insert(lhs, rhs);
175                 return;
176             }
177         };
178
179         let mut look_ahead_scratch = Vec::default();
180
181         let mut rhs_children = rhs.children_with_tokens();
182         let mut lhs_children = lhs.children_with_tokens();
183         let mut last_lhs = None;
184         loop {
185             let lhs_child = lhs_children.next();
186             match (lhs_child.clone(), rhs_children.next()) {
187                 (None, None) => break,
188                 (None, Some(element)) => {
189                     let insert_pos = match last_lhs.clone() {
190                         Some(prev) => {
191                             cov_mark::hit!(diff_insert);
192                             TreeDiffInsertPos::After(prev)
193                         }
194                         // first iteration, insert into out parent as the first child
195                         None => {
196                             cov_mark::hit!(diff_insert_as_first_child);
197                             TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
198                         }
199                     };
200                     diff.insertions.entry(insert_pos).or_insert_with(Vec::new).push(element);
201                 }
202                 (Some(element), None) => {
203                     cov_mark::hit!(diff_delete);
204                     diff.deletions.push(element);
205                 }
206                 (Some(ref lhs_ele), Some(ref rhs_ele)) if syntax_element_eq(lhs_ele, rhs_ele) => {}
207                 (Some(lhs_ele), Some(rhs_ele)) => {
208                     // nodes differ, look for lhs_ele in rhs, if its found we can mark everything up
209                     // until that element as insertions. This is important to keep the diff minimal
210                     // in regards to insertions that have been actually done, this is important for
211                     // use insertions as we do not want to replace the entire module node.
212                     look_ahead_scratch.push(rhs_ele.clone());
213                     let mut rhs_children_clone = rhs_children.clone();
214                     let mut insert = false;
215                     while let Some(rhs_child) = rhs_children_clone.next() {
216                         if syntax_element_eq(&lhs_ele, &rhs_child) {
217                             cov_mark::hit!(diff_insertions);
218                             insert = true;
219                             break;
220                         } else {
221                             look_ahead_scratch.push(rhs_child);
222                         }
223                     }
224                     let drain = look_ahead_scratch.drain(..);
225                     if insert {
226                         let insert_pos = if let Some(prev) = last_lhs.clone().filter(|_| insert) {
227                             TreeDiffInsertPos::After(prev)
228                         } else {
229                             cov_mark::hit!(insert_first_child);
230                             TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
231                         };
232
233                         diff.insertions.entry(insert_pos).or_insert_with(Vec::new).extend(drain);
234                         rhs_children = rhs_children_clone;
235                     } else {
236                         go(diff, lhs_ele, rhs_ele)
237                     }
238                 }
239             }
240             last_lhs = lhs_child.or(last_lhs);
241         }
242     }
243 }
244
#[cfg(test)]
mod tests {
    // Snapshot tests for `diff`. Each test parses `from` and `to`, renders the
    // resulting TreeDiff and compares it against an `expect!` snapshot. It then
    // checks that applying the diff as a text edit turns `from` into `to`
    // (see `check_diff`).
    use expect_test::{expect, Expect};
    use itertools::Itertools;
    use parser::SyntaxKind;
    use text_edit::TextEdit;

    use crate::{AstNode, SyntaxElement};

    // Covers the node/token replacement path in `go`.
    #[test]
    fn replace_node_token() {
        cov_mark::check!(diff_node_token_replace);
        check_diff(
            r#"use node;"#,
            r#"ident"#,
            expect![[r#"
                insertions:



                replacements:

                Line 0: Token(USE_KW@0..3 "use") -> ident

                deletions:

                Line 1: " "
                Line 1: node
                Line 1: ;
            "#]],
        );
    }

    // Inserting into an empty file anchors on the parent as its first child.
    #[test]
    fn replace_parent() {
        cov_mark::check!(diff_insert_as_first_child);
        check_diff(
            r#""#,
            r#"use foo::bar;"#,
            expect![[r#"
                insertions:

                Line 0: AsFirstChild(Node(SOURCE_FILE@0..0))
                -> use foo::bar;

                replacements:



                deletions:


            "#]],
        );
    }

    #[test]
    fn insert_last() {
        cov_mark::check!(diff_insert);
        check_diff(
            r#"
use foo;
use bar;"#,
            r#"
use foo;
use bar;
use baz;"#,
            expect![[r#"
                insertions:

                Line 2: After(Node(USE@10..18))
                -> "\n"
                -> use baz;

                replacements:



                deletions:


            "#]],
        );
    }

    #[test]
    fn insert_middle() {
        check_diff(
            r#"
use foo;
use baz;"#,
            r#"
use foo;
use bar;
use baz;"#,
            expect![[r#"
                insertions:

                Line 2: After(Token(WHITESPACE@9..10 "\n"))
                -> use bar;
                -> "\n"

                replacements:



                deletions:


            "#]],
        )
    }

    #[test]
    fn insert_first() {
        check_diff(
            r#"
use bar;
use baz;"#,
            r#"
use foo;
use bar;
use baz;"#,
            expect![[r#"
                insertions:

                Line 0: After(Token(WHITESPACE@0..1 "\n"))
                -> use foo;
                -> "\n"

                replacements:



                deletions:


            "#]],
        )
    }

    // Covers the look-ahead insertion path when there is no previous lhs child.
    #[test]
    fn first_child_insertion() {
        cov_mark::check!(insert_first_child);
        check_diff(
            r#"fn main() {
        stdi
    }"#,
            r#"use foo::bar;

    fn main() {
        stdi
    }"#,
            expect![[r#"
                insertions:

                Line 0: AsFirstChild(Node(SOURCE_FILE@0..30))
                -> use foo::bar;
                -> "\n\n    "

                replacements:



                deletions:


            "#]],
        );
    }

    #[test]
    fn delete_last() {
        cov_mark::check!(diff_delete);
        check_diff(
            r#"use foo;
            use bar;"#,
            r#"use foo;"#,
            expect![[r#"
                insertions:



                replacements:



                deletions:

                Line 1: "\n            "
                Line 2: use bar;
            "#]],
        );
    }

    // Covers the look-ahead path (`diff_insertions`). The snapshot shows the
    // algorithm is not minimal here, but the round-trip check still holds.
    #[test]
    fn delete_middle() {
        cov_mark::check!(diff_insertions);
        check_diff(
            r#"
use expect_test::{expect, Expect};
use text_edit::TextEdit;

use crate::AstNode;
"#,
            r#"
use expect_test::{expect, Expect};

use crate::AstNode;
"#,
            expect![[r#"
                insertions:

                Line 1: After(Node(USE@1..35))
                -> "\n\n"
                -> use crate::AstNode;

                replacements:



                deletions:

                Line 2: use text_edit::TextEdit;
                Line 3: "\n\n"
                Line 4: use crate::AstNode;
                Line 5: "\n"
            "#]],
        )
    }

    #[test]
    fn delete_first() {
        check_diff(
            r#"
use text_edit::TextEdit;

use crate::AstNode;
"#,
            r#"
use crate::AstNode;
"#,
            expect![[r#"
                insertions:



                replacements:

                Line 2: Token(IDENT@5..14 "text_edit") -> crate
                Line 2: Token(IDENT@16..24 "TextEdit") -> AstNode
                Line 2: Token(WHITESPACE@25..27 "\n\n") -> "\n"

                deletions:

                Line 3: use crate::AstNode;
                Line 4: "\n"
            "#]],
        )
    }

    #[test]
    fn merge_use() {
        check_diff(
            r#"
use std::{
    fmt,
    hash::BuildHasherDefault,
    ops::{self, RangeInclusive},
};
"#,
            r#"
use std::fmt;
use std::hash::BuildHasherDefault;
use std::ops::{self, RangeInclusive};
"#,
            expect![[r#"
                insertions:

                Line 2: After(Node(PATH_SEGMENT@5..8))
                -> ::
                -> fmt
                Line 6: After(Token(WHITESPACE@86..87 "\n"))
                -> use std::hash::BuildHasherDefault;
                -> "\n"
                -> use std::ops::{self, RangeInclusive};
                -> "\n"

                replacements:

                Line 2: Token(IDENT@5..8 "std") -> std

                deletions:

                Line 2: ::
                Line 2: {
                    fmt,
                    hash::BuildHasherDefault,
                    ops::{self, RangeInclusive},
                }
            "#]],
        )
    }

    #[test]
    fn early_return_assist() {
        check_diff(
            r#"
fn main() {
    if let Ok(x) = Err(92) {
        foo(x);
    }
}
            "#,
            r#"
fn main() {
    let x = match Err(92) {
        Ok(it) => it,
        _ => return,
    };
    foo(x);
}
            "#,
            expect![[r#"
                insertions:

                Line 3: After(Node(BLOCK_EXPR@40..63))
                -> " "
                -> match Err(92) {
                        Ok(it) => it,
                        _ => return,
                    }
                -> ;
                Line 3: After(Node(IF_EXPR@17..63))
                -> "\n    "
                -> foo(x);

                replacements:

                Line 3: Token(IF_KW@17..19 "if") -> let
                Line 3: Token(LET_KW@20..23 "let") -> x
                Line 3: Node(BLOCK_EXPR@40..63) -> =

                deletions:

                Line 3: " "
                Line 3: Ok(x)
                Line 3: " "
                Line 3: =
                Line 3: " "
                Line 3: Err(92)
            "#]],
        )
    }

    /// Diffs the parse trees of `from` and `to` and checks two things.
    /// First, the rendered diff (insertions, replacements, deletions) matches
    /// `expected_diff`. Second, applying it as a text edit to `from` yields `to`.
    fn check_diff(from: &str, to: &str, expected_diff: Expect) {
        let from_node = crate::SourceFile::parse(from).tree().syntax().clone();
        let to_node = crate::SourceFile::parse(to).tree().syntax().clone();
        let diff = super::diff(&from_node, &to_node);

        // Human-readable locator: the number of lines (per `str::lines`) in the
        // text of `from` that precedes the element.
        let line_number =
            |syn: &SyntaxElement| from[..syn.text_range().start().into()].lines().count();

        // Whitespace is rendered with `{:?}` so newlines and spaces are visible in snapshots.
        let fmt_syntax = |syn: &SyntaxElement| match syn.kind() {
            SyntaxKind::WHITESPACE => format!("{:?}", syn.to_string()),
            _ => format!("{}", syn),
        };

        let insertions =
            diff.insertions.iter().format_with("\n", |(k, v), f| -> Result<(), std::fmt::Error> {
                f(&format!(
                    "Line {}: {:?}\n-> {}",
                    line_number(match k {
                        super::TreeDiffInsertPos::After(syn) => syn,
                        super::TreeDiffInsertPos::AsFirstChild(syn) => syn,
                    }),
                    k,
                    v.iter().format_with("\n-> ", |v, f| f(&fmt_syntax(v)))
                ))
            });

        // Replacements live in a hash map; sort by position for a stable snapshot.
        let replacements = diff
            .replacements
            .iter()
            .sorted_by_key(|(syntax, _)| syntax.text_range().start())
            .format_with("\n", |(k, v), f| {
                f(&format!("Line {}: {:?} -> {}", line_number(k), k, fmt_syntax(v)))
            });

        let deletions = diff
            .deletions
            .iter()
            .format_with("\n", |v, f| f(&format!("Line {}: {}", line_number(v), &fmt_syntax(v))));

        let actual = format!(
            "insertions:\n\n{}\n\nreplacements:\n\n{}\n\ndeletions:\n\n{}\n",
            insertions, replacements, deletions
        );
        expected_diff.assert_eq(&actual);

        // Round-trip check: the diff must actually transform `from` into `to`.
        let mut from = from.to_owned();
        let mut text_edit = TextEdit::builder();
        diff.into_text_edit(&mut text_edit);
        text_edit.finish().apply(&mut from);
        assert_eq!(&*from, to, "diff did not turn `from` to `to`");
    }
}