]> git.lizzy.rs Git - rust.git/blob - crates/syntax/src/algo.rs
827ae78f95db12a32f559080b145fe1ebf1afb44
[rust.git] / crates / syntax / src / algo.rs
1 //! FIXME: write short doc here
2
3 use std::{
4     fmt,
5     hash::BuildHasherDefault,
6     ops::{self, RangeInclusive},
7 };
8
9 use indexmap::IndexMap;
10 use itertools::Itertools;
11 use rustc_hash::FxHashMap;
12 use test_utils::mark;
13 use text_edit::TextEditBuilder;
14
15 use crate::{
16     AstNode, Direction, NodeOrToken, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxNodePtr,
17     SyntaxToken, TextRange, TextSize,
18 };
19
20 /// Returns ancestors of the node at the offset, sorted by length. This should
21 /// do the right thing at an edge, e.g. when searching for expressions at `{
22 /// $0foo }` we will get the name reference instead of the whole block, which
23 /// we would get if we just did `find_token_at_offset(...).flat_map(|t|
24 /// t.parent().ancestors())`.
25 pub fn ancestors_at_offset(
26     node: &SyntaxNode,
27     offset: TextSize,
28 ) -> impl Iterator<Item = SyntaxNode> {
29     node.token_at_offset(offset)
30         .map(|token| token.parent().ancestors())
31         .kmerge_by(|node1, node2| node1.text_range().len() < node2.text_range().len())
32 }
33
34 /// Finds a node of specific Ast type at offset. Note that this is slightly
35 /// imprecise: if the cursor is strictly between two nodes of the desired type,
36 /// as in
37 ///
38 /// ```no_run
39 /// struct Foo {}|struct Bar;
40 /// ```
41 ///
42 /// then the shorter node will be silently preferred.
43 pub fn find_node_at_offset<N: AstNode>(syntax: &SyntaxNode, offset: TextSize) -> Option<N> {
44     ancestors_at_offset(syntax, offset).find_map(N::cast)
45 }
46
47 pub fn find_node_at_range<N: AstNode>(syntax: &SyntaxNode, range: TextRange) -> Option<N> {
48     syntax.covering_element(range).ancestors().find_map(N::cast)
49 }
50
51 /// Skip to next non `trivia` token
52 pub fn skip_trivia_token(mut token: SyntaxToken, direction: Direction) -> Option<SyntaxToken> {
53     while token.kind().is_trivia() {
54         token = match direction {
55             Direction::Next => token.next_token()?,
56             Direction::Prev => token.prev_token()?,
57         }
58     }
59     Some(token)
60 }
61
62 /// Finds the first sibling in the given direction which is not `trivia`
63 pub fn non_trivia_sibling(element: SyntaxElement, direction: Direction) -> Option<SyntaxElement> {
64     return match element {
65         NodeOrToken::Node(node) => node.siblings_with_tokens(direction).skip(1).find(not_trivia),
66         NodeOrToken::Token(token) => token.siblings_with_tokens(direction).skip(1).find(not_trivia),
67     };
68
69     fn not_trivia(element: &SyntaxElement) -> bool {
70         match element {
71             NodeOrToken::Node(_) => true,
72             NodeOrToken::Token(token) => !token.kind().is_trivia(),
73         }
74     }
75 }
76
77 pub fn least_common_ancestor(u: &SyntaxNode, v: &SyntaxNode) -> Option<SyntaxNode> {
78     if u == v {
79         return Some(u.clone());
80     }
81
82     let u_depth = u.ancestors().count();
83     let v_depth = v.ancestors().count();
84     let keep = u_depth.min(v_depth);
85
86     let u_candidates = u.ancestors().skip(u_depth - keep);
87     let v_candidates = v.ancestors().skip(v_depth - keep);
88     let (res, _) = u_candidates.zip(v_candidates).find(|(x, y)| x == y)?;
89     Some(res)
90 }
91
92 pub fn neighbor<T: AstNode>(me: &T, direction: Direction) -> Option<T> {
93     me.syntax().siblings(direction).skip(1).find_map(T::cast)
94 }
95
96 pub fn has_errors(node: &SyntaxNode) -> bool {
97     node.children().any(|it| it.kind() == SyntaxKind::ERROR)
98 }
99
100 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
101 pub enum InsertPosition<T> {
102     First,
103     Last,
104     Before(T),
105     After(T),
106 }
107
108 type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<rustc_hash::FxHasher>>;
109
110 #[derive(Debug, Hash, PartialEq, Eq)]
111 enum TreeDiffInsertPos {
112     After(SyntaxElement),
113     AsFirstChild(SyntaxElement),
114 }
115
116 #[derive(Debug)]
117 pub struct TreeDiff {
118     replacements: FxHashMap<SyntaxElement, SyntaxElement>,
119     deletions: Vec<SyntaxElement>,
120     // the vec as well as the indexmap are both here to preserve order
121     insertions: FxIndexMap<TreeDiffInsertPos, Vec<SyntaxElement>>,
122 }
123
124 impl TreeDiff {
125     pub fn into_text_edit(&self, builder: &mut TextEditBuilder) {
126         let _p = profile::span("into_text_edit");
127
128         for (anchor, to) in self.insertions.iter() {
129             let offset = match anchor {
130                 TreeDiffInsertPos::After(it) => it.text_range().end(),
131                 TreeDiffInsertPos::AsFirstChild(it) => it.text_range().start(),
132             };
133             to.iter().for_each(|to| builder.insert(offset, to.to_string()));
134         }
135         for (from, to) in self.replacements.iter() {
136             builder.replace(from.text_range(), to.to_string())
137         }
138         for text_range in self.deletions.iter().map(SyntaxElement::text_range) {
139             builder.delete(text_range);
140         }
141     }
142
143     pub fn is_empty(&self) -> bool {
144         self.replacements.is_empty() && self.deletions.is_empty() && self.insertions.is_empty()
145     }
146 }
147
148 /// Finds a (potentially minimal) diff, which, applied to `from`, will result in `to`.
149 ///
150 /// Specifically, returns a structure that consists of a replacements, insertions and deletions
151 /// such that applying this map on `from` will result in `to`.
152 ///
153 /// This function tries to find a fine-grained diff.
154 pub fn diff(from: &SyntaxNode, to: &SyntaxNode) -> TreeDiff {
155     let _p = profile::span("diff");
156
157     let mut diff = TreeDiff {
158         replacements: FxHashMap::default(),
159         insertions: FxIndexMap::default(),
160         deletions: Vec::new(),
161     };
162     let (from, to) = (from.clone().into(), to.clone().into());
163
164     if !syntax_element_eq(&from, &to) {
165         go(&mut diff, from, to);
166     }
167     return diff;
168
169     fn syntax_element_eq(lhs: &SyntaxElement, rhs: &SyntaxElement) -> bool {
170         lhs.kind() == rhs.kind()
171             && lhs.text_range().len() == rhs.text_range().len()
172             && match (&lhs, &rhs) {
173                 (NodeOrToken::Node(lhs), NodeOrToken::Node(rhs)) => {
174                     lhs.green() == rhs.green() || lhs.text() == rhs.text()
175                 }
176                 (NodeOrToken::Token(lhs), NodeOrToken::Token(rhs)) => lhs.text() == rhs.text(),
177                 _ => false,
178             }
179     }
180
181     // FIXME: this is horrible inefficient. I bet there's a cool algorithm to diff trees properly.
182     fn go(diff: &mut TreeDiff, lhs: SyntaxElement, rhs: SyntaxElement) {
183         let (lhs, rhs) = match lhs.as_node().zip(rhs.as_node()) {
184             Some((lhs, rhs)) => (lhs, rhs),
185             _ => {
186                 mark::hit!(diff_node_token_replace);
187                 diff.replacements.insert(lhs, rhs);
188                 return;
189             }
190         };
191
192         let mut look_ahead_scratch = Vec::default();
193
194         let mut rhs_children = rhs.children_with_tokens();
195         let mut lhs_children = lhs.children_with_tokens();
196         let mut last_lhs = None;
197         loop {
198             let lhs_child = lhs_children.next();
199             match (lhs_child.clone(), rhs_children.next()) {
200                 (None, None) => break,
201                 (None, Some(element)) => {
202                     let insert_pos = match last_lhs.clone() {
203                         Some(prev) => {
204                             mark::hit!(diff_insert);
205                             TreeDiffInsertPos::After(prev)
206                         }
207                         // first iteration, insert into out parent as the first child
208                         None => {
209                             mark::hit!(diff_insert_as_first_child);
210                             TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
211                         }
212                     };
213                     diff.insertions.entry(insert_pos).or_insert_with(Vec::new).push(element);
214                 }
215                 (Some(element), None) => {
216                     mark::hit!(diff_delete);
217                     diff.deletions.push(element);
218                 }
219                 (Some(ref lhs_ele), Some(ref rhs_ele)) if syntax_element_eq(lhs_ele, rhs_ele) => {}
220                 (Some(lhs_ele), Some(rhs_ele)) => {
221                     // nodes differ, look for lhs_ele in rhs, if its found we can mark everything up
222                     // until that element as insertions. This is important to keep the diff minimal
223                     // in regards to insertions that have been actually done, this is important for
224                     // use insertions as we do not want to replace the entire module node.
225                     look_ahead_scratch.push(rhs_ele.clone());
226                     let mut rhs_children_clone = rhs_children.clone();
227                     let mut insert = false;
228                     while let Some(rhs_child) = rhs_children_clone.next() {
229                         if syntax_element_eq(&lhs_ele, &rhs_child) {
230                             mark::hit!(diff_insertions);
231                             insert = true;
232                             break;
233                         } else {
234                             look_ahead_scratch.push(rhs_child);
235                         }
236                     }
237                     let drain = look_ahead_scratch.drain(..);
238                     if insert {
239                         let insert_pos = if let Some(prev) = last_lhs.clone().filter(|_| insert) {
240                             TreeDiffInsertPos::After(prev)
241                         } else {
242                             mark::hit!(insert_first_child);
243                             TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
244                         };
245
246                         diff.insertions.entry(insert_pos).or_insert_with(Vec::new).extend(drain);
247                         rhs_children = rhs_children_clone;
248                     } else {
249                         go(diff, lhs_ele, rhs_ele)
250                     }
251                 }
252             }
253             last_lhs = lhs_child.or(last_lhs);
254         }
255     }
256 }
257
258 /// Adds specified children (tokens or nodes) to the current node at the
259 /// specific position.
260 ///
261 /// This is a type-unsafe low-level editing API, if you need to use it,
262 /// prefer to create a type-safe abstraction on top of it instead.
263 pub fn insert_children(
264     parent: &SyntaxNode,
265     position: InsertPosition<SyntaxElement>,
266     to_insert: impl IntoIterator<Item = SyntaxElement>,
267 ) -> SyntaxNode {
268     let mut to_insert = to_insert.into_iter();
269     _insert_children(parent, position, &mut to_insert)
270 }
271
272 fn _insert_children(
273     parent: &SyntaxNode,
274     position: InsertPosition<SyntaxElement>,
275     to_insert: &mut dyn Iterator<Item = SyntaxElement>,
276 ) -> SyntaxNode {
277     let mut delta = TextSize::default();
278     let to_insert = to_insert.map(|element| {
279         delta += element.text_range().len();
280         to_green_element(element)
281     });
282
283     let mut old_children = parent.green().children().map(|it| match it {
284         NodeOrToken::Token(it) => NodeOrToken::Token(it.clone()),
285         NodeOrToken::Node(it) => NodeOrToken::Node(it.clone()),
286     });
287
288     let new_children = match &position {
289         InsertPosition::First => to_insert.chain(old_children).collect::<Vec<_>>(),
290         InsertPosition::Last => old_children.chain(to_insert).collect::<Vec<_>>(),
291         InsertPosition::Before(anchor) | InsertPosition::After(anchor) => {
292             let take_anchor = if let InsertPosition::After(_) = position { 1 } else { 0 };
293             let split_at = position_of_child(parent, anchor.clone()) + take_anchor;
294             let before = old_children.by_ref().take(split_at).collect::<Vec<_>>();
295             before.into_iter().chain(to_insert).chain(old_children).collect::<Vec<_>>()
296         }
297     };
298
299     with_children(parent, new_children)
300 }
301
302 /// Replaces all nodes in `to_delete` with nodes from `to_insert`
303 ///
304 /// This is a type-unsafe low-level editing API, if you need to use it,
305 /// prefer to create a type-safe abstraction on top of it instead.
306 pub fn replace_children(
307     parent: &SyntaxNode,
308     to_delete: RangeInclusive<SyntaxElement>,
309     to_insert: impl IntoIterator<Item = SyntaxElement>,
310 ) -> SyntaxNode {
311     let mut to_insert = to_insert.into_iter();
312     _replace_children(parent, to_delete, &mut to_insert)
313 }
314
315 fn _replace_children(
316     parent: &SyntaxNode,
317     to_delete: RangeInclusive<SyntaxElement>,
318     to_insert: &mut dyn Iterator<Item = SyntaxElement>,
319 ) -> SyntaxNode {
320     let start = position_of_child(parent, to_delete.start().clone());
321     let end = position_of_child(parent, to_delete.end().clone());
322     let mut old_children = parent.green().children().map(|it| match it {
323         NodeOrToken::Token(it) => NodeOrToken::Token(it.clone()),
324         NodeOrToken::Node(it) => NodeOrToken::Node(it.clone()),
325     });
326
327     let before = old_children.by_ref().take(start).collect::<Vec<_>>();
328     let new_children = before
329         .into_iter()
330         .chain(to_insert.map(to_green_element))
331         .chain(old_children.skip(end + 1 - start))
332         .collect::<Vec<_>>();
333     with_children(parent, new_children)
334 }
335
336 #[derive(Debug, PartialEq, Eq, Hash)]
337 enum InsertPos {
338     FirstChildOf(SyntaxNode),
339     After(SyntaxElement),
340 }
341
342 #[derive(Default)]
343 pub struct SyntaxRewriter<'a> {
344     f: Option<Box<dyn Fn(&SyntaxElement) -> Option<SyntaxElement> + 'a>>,
345     //FIXME: add debug_assertions that all elements are in fact from the same file.
346     replacements: FxHashMap<SyntaxElement, Replacement>,
347     insertions: IndexMap<InsertPos, Vec<SyntaxElement>>,
348 }
349
350 impl fmt::Debug for SyntaxRewriter<'_> {
351     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352         f.debug_struct("SyntaxRewriter")
353             .field("replacements", &self.replacements)
354             .field("insertions", &self.insertions)
355             .finish()
356     }
357 }
358
359 impl<'a> SyntaxRewriter<'a> {
360     pub fn from_fn(f: impl Fn(&SyntaxElement) -> Option<SyntaxElement> + 'a) -> SyntaxRewriter<'a> {
361         SyntaxRewriter {
362             f: Some(Box::new(f)),
363             replacements: FxHashMap::default(),
364             insertions: IndexMap::default(),
365         }
366     }
367     pub fn delete<T: Clone + Into<SyntaxElement>>(&mut self, what: &T) {
368         let what = what.clone().into();
369         let replacement = Replacement::Delete;
370         self.replacements.insert(what, replacement);
371     }
372     pub fn insert_before<T: Clone + Into<SyntaxElement>, U: Clone + Into<SyntaxElement>>(
373         &mut self,
374         before: &T,
375         what: &U,
376     ) {
377         let before = before.clone().into();
378         let pos = match before.prev_sibling_or_token() {
379             Some(sibling) => InsertPos::After(sibling),
380             None => match before.parent() {
381                 Some(parent) => InsertPos::FirstChildOf(parent),
382                 None => return,
383             },
384         };
385         self.insertions.entry(pos).or_insert_with(Vec::new).push(what.clone().into());
386     }
387     pub fn insert_after<T: Clone + Into<SyntaxElement>, U: Clone + Into<SyntaxElement>>(
388         &mut self,
389         after: &T,
390         what: &U,
391     ) {
392         self.insertions
393             .entry(InsertPos::After(after.clone().into()))
394             .or_insert_with(Vec::new)
395             .push(what.clone().into());
396     }
397     pub fn insert_as_first_child<T: Clone + Into<SyntaxNode>, U: Clone + Into<SyntaxElement>>(
398         &mut self,
399         parent: &T,
400         what: &U,
401     ) {
402         self.insertions
403             .entry(InsertPos::FirstChildOf(parent.clone().into()))
404             .or_insert_with(Vec::new)
405             .push(what.clone().into());
406     }
407     pub fn insert_many_before<
408         T: Clone + Into<SyntaxElement>,
409         U: IntoIterator<Item = SyntaxElement>,
410     >(
411         &mut self,
412         before: &T,
413         what: U,
414     ) {
415         let before = before.clone().into();
416         let pos = match before.prev_sibling_or_token() {
417             Some(sibling) => InsertPos::After(sibling),
418             None => match before.parent() {
419                 Some(parent) => InsertPos::FirstChildOf(parent),
420                 None => return,
421             },
422         };
423         self.insertions.entry(pos).or_insert_with(Vec::new).extend(what);
424     }
425     pub fn insert_many_after<
426         T: Clone + Into<SyntaxElement>,
427         U: IntoIterator<Item = SyntaxElement>,
428     >(
429         &mut self,
430         after: &T,
431         what: U,
432     ) {
433         self.insertions
434             .entry(InsertPos::After(after.clone().into()))
435             .or_insert_with(Vec::new)
436             .extend(what);
437     }
438     pub fn insert_many_as_first_children<
439         T: Clone + Into<SyntaxNode>,
440         U: IntoIterator<Item = SyntaxElement>,
441     >(
442         &mut self,
443         parent: &T,
444         what: U,
445     ) {
446         self.insertions
447             .entry(InsertPos::FirstChildOf(parent.clone().into()))
448             .or_insert_with(Vec::new)
449             .extend(what)
450     }
451     pub fn replace<T: Clone + Into<SyntaxElement>>(&mut self, what: &T, with: &T) {
452         let what = what.clone().into();
453         let replacement = Replacement::Single(with.clone().into());
454         self.replacements.insert(what, replacement);
455     }
456     pub fn replace_with_many<T: Clone + Into<SyntaxElement>>(
457         &mut self,
458         what: &T,
459         with: Vec<SyntaxElement>,
460     ) {
461         let what = what.clone().into();
462         let replacement = Replacement::Many(with);
463         self.replacements.insert(what, replacement);
464     }
465     pub fn replace_ast<T: AstNode>(&mut self, what: &T, with: &T) {
466         self.replace(what.syntax(), with.syntax())
467     }
468
469     pub fn rewrite(&self, node: &SyntaxNode) -> SyntaxNode {
470         let _p = profile::span("rewrite");
471
472         if self.f.is_none() && self.replacements.is_empty() && self.insertions.is_empty() {
473             return node.clone();
474         }
475         let green = self.rewrite_children(node);
476         with_green(node, green)
477     }
478
479     pub fn rewrite_ast<N: AstNode>(self, node: &N) -> N {
480         N::cast(self.rewrite(node.syntax())).unwrap()
481     }
482
483     /// Returns a node that encompasses all replacements to be done by this rewriter.
484     ///
485     /// Passing the returned node to `rewrite` will apply all replacements queued up in `self`.
486     ///
487     /// Returns `None` when there are no replacements.
488     pub fn rewrite_root(&self) -> Option<SyntaxNode> {
489         let _p = profile::span("rewrite_root");
490         fn element_to_node_or_parent(element: &SyntaxElement) -> SyntaxNode {
491             match element {
492                 SyntaxElement::Node(it) => it.clone(),
493                 SyntaxElement::Token(it) => it.parent(),
494             }
495         }
496
497         assert!(self.f.is_none());
498         self.replacements
499             .keys()
500             .map(element_to_node_or_parent)
501             .chain(self.insertions.keys().map(|pos| match pos {
502                 InsertPos::FirstChildOf(it) => it.clone(),
503                 InsertPos::After(it) => element_to_node_or_parent(it),
504             }))
505             // If we only have one replacement/insertion, we must return its parent node, since `rewrite` does
506             // not replace the node passed to it.
507             .map(|it| it.parent().unwrap_or(it))
508             .fold1(|a, b| least_common_ancestor(&a, &b).unwrap())
509     }
510
511     fn replacement(&self, element: &SyntaxElement) -> Option<Replacement> {
512         if let Some(f) = &self.f {
513             assert!(self.replacements.is_empty());
514             return f(element).map(Replacement::Single);
515         }
516         self.replacements.get(element).cloned()
517     }
518
519     fn insertions(&self, pos: &InsertPos) -> Option<impl Iterator<Item = SyntaxElement> + '_> {
520         self.insertions.get(pos).map(|insertions| insertions.iter().cloned())
521     }
522
523     fn rewrite_children(&self, node: &SyntaxNode) -> rowan::GreenNode {
524         let _p = profile::span("rewrite_children");
525
526         //  FIXME: this could be made much faster.
527         let mut new_children = Vec::new();
528         if let Some(elements) = self.insertions(&InsertPos::FirstChildOf(node.clone())) {
529             new_children.extend(elements.map(element_to_green));
530         }
531         for child in node.children_with_tokens() {
532             self.rewrite_self(&mut new_children, &child);
533         }
534
535         rowan::GreenNode::new(rowan::SyntaxKind(node.kind() as u16), new_children)
536     }
537
538     fn rewrite_self(
539         &self,
540         acc: &mut Vec<NodeOrToken<rowan::GreenNode, rowan::GreenToken>>,
541         element: &SyntaxElement,
542     ) {
543         let _p = profile::span("rewrite_self");
544
545         if let Some(replacement) = self.replacement(&element) {
546             match replacement {
547                 Replacement::Single(element) => acc.push(element_to_green(element)),
548                 Replacement::Many(replacements) => {
549                     acc.extend(replacements.into_iter().map(element_to_green))
550                 }
551                 Replacement::Delete => (),
552             };
553         } else {
554             match element {
555                 NodeOrToken::Token(it) => acc.push(NodeOrToken::Token(it.green().clone())),
556                 NodeOrToken::Node(it) => {
557                     acc.push(NodeOrToken::Node(self.rewrite_children(it)));
558                 }
559             }
560         }
561         if let Some(elements) = self.insertions(&InsertPos::After(element.clone())) {
562             acc.extend(elements.map(element_to_green));
563         }
564     }
565 }
566
567 fn element_to_green(element: SyntaxElement) -> NodeOrToken<rowan::GreenNode, rowan::GreenToken> {
568     match element {
569         NodeOrToken::Node(it) => NodeOrToken::Node(it.green().clone()),
570         NodeOrToken::Token(it) => NodeOrToken::Token(it.green().clone()),
571     }
572 }
573
574 impl ops::AddAssign for SyntaxRewriter<'_> {
575     fn add_assign(&mut self, rhs: SyntaxRewriter) {
576         assert!(rhs.f.is_none());
577         self.replacements.extend(rhs.replacements);
578         for (pos, insertions) in rhs.insertions.into_iter() {
579             match self.insertions.entry(pos) {
580                 indexmap::map::Entry::Occupied(mut occupied) => {
581                     occupied.get_mut().extend(insertions)
582                 }
583                 indexmap::map::Entry::Vacant(vacant) => drop(vacant.insert(insertions)),
584             }
585         }
586     }
587 }
588
589 #[derive(Clone, Debug)]
590 enum Replacement {
591     Delete,
592     Single(SyntaxElement),
593     Many(Vec<SyntaxElement>),
594 }
595
596 fn with_children(
597     parent: &SyntaxNode,
598     new_children: Vec<NodeOrToken<rowan::GreenNode, rowan::GreenToken>>,
599 ) -> SyntaxNode {
600     let _p = profile::span("with_children");
601
602     let new_green = rowan::GreenNode::new(rowan::SyntaxKind(parent.kind() as u16), new_children);
603     with_green(parent, new_green)
604 }
605
606 fn with_green(syntax_node: &SyntaxNode, green: rowan::GreenNode) -> SyntaxNode {
607     let len = green.children().map(|it| it.text_len()).sum::<TextSize>();
608     let new_root_node = syntax_node.replace_with(green);
609     let new_root_node = SyntaxNode::new_root(new_root_node);
610
611     // FIXME: use a more elegant way to re-fetch the node (#1185), make
612     // `range` private afterwards
613     let mut ptr = SyntaxNodePtr::new(syntax_node);
614     ptr.range = TextRange::at(ptr.range.start(), len);
615     ptr.to_node(&new_root_node)
616 }
617
618 fn position_of_child(parent: &SyntaxNode, child: SyntaxElement) -> usize {
619     parent
620         .children_with_tokens()
621         .position(|it| it == child)
622         .expect("element is not a child of current element")
623 }
624
625 fn to_green_element(element: SyntaxElement) -> NodeOrToken<rowan::GreenNode, rowan::GreenToken> {
626     match element {
627         NodeOrToken::Node(it) => it.green().clone().into(),
628         NodeOrToken::Token(it) => it.green().clone().into(),
629     }
630 }
631
632 #[cfg(test)]
633 mod tests {
634     use expect_test::{expect, Expect};
635     use itertools::Itertools;
636     use parser::SyntaxKind;
637     use test_utils::mark;
638     use text_edit::TextEdit;
639
640     use crate::{AstNode, SyntaxElement};
641
642     #[test]
643     fn replace_node_token() {
644         mark::check!(diff_node_token_replace);
645         check_diff(
646             r#"use node;"#,
647             r#"ident"#,
648             expect![[r#"
649                 insertions:
650
651
652
653                 replacements:
654
655                 Line 0: Token(USE_KW@0..3 "use") -> ident
656
657                 deletions:
658
659                 Line 1: " "
660                 Line 1: node
661                 Line 1: ;
662             "#]],
663         );
664     }
665
666     #[test]
667     fn replace_parent() {
668         mark::check!(diff_insert_as_first_child);
669         check_diff(
670             r#""#,
671             r#"use foo::bar;"#,
672             expect![[r#"
673                 insertions:
674
675                 Line 0: AsFirstChild(Node(SOURCE_FILE@0..0))
676                 -> use foo::bar;
677
678                 replacements:
679
680
681
682                 deletions:
683
684
685             "#]],
686         );
687     }
688
689     #[test]
690     fn insert_last() {
691         mark::check!(diff_insert);
692         check_diff(
693             r#"
694 use foo;
695 use bar;"#,
696             r#"
697 use foo;
698 use bar;
699 use baz;"#,
700             expect![[r#"
701                 insertions:
702
703                 Line 2: After(Node(USE@10..18))
704                 -> "\n"
705                 -> use baz;
706
707                 replacements:
708
709
710
711                 deletions:
712
713
714             "#]],
715         );
716     }
717
718     #[test]
719     fn insert_middle() {
720         check_diff(
721             r#"
722 use foo;
723 use baz;"#,
724             r#"
725 use foo;
726 use bar;
727 use baz;"#,
728             expect![[r#"
729                 insertions:
730
731                 Line 2: After(Token(WHITESPACE@9..10 "\n"))
732                 -> use bar;
733                 -> "\n"
734
735                 replacements:
736
737
738
739                 deletions:
740
741
742             "#]],
743         )
744     }
745
746     #[test]
747     fn insert_first() {
748         check_diff(
749             r#"
750 use bar;
751 use baz;"#,
752             r#"
753 use foo;
754 use bar;
755 use baz;"#,
756             expect![[r#"
757                 insertions:
758
759                 Line 0: After(Token(WHITESPACE@0..1 "\n"))
760                 -> use foo;
761                 -> "\n"
762
763                 replacements:
764
765
766
767                 deletions:
768
769
770             "#]],
771         )
772     }
773
774     #[test]
775     fn first_child_insertion() {
776         mark::check!(insert_first_child);
777         check_diff(
778             r#"fn main() {
779         stdi
780     }"#,
781             r#"use foo::bar;
782
783     fn main() {
784         stdi
785     }"#,
786             expect![[r#"
787                 insertions:
788
789                 Line 0: AsFirstChild(Node(SOURCE_FILE@0..30))
790                 -> use foo::bar;
791                 -> "\n\n    "
792
793                 replacements:
794
795
796
797                 deletions:
798
799
800             "#]],
801         );
802     }
803
804     #[test]
805     fn delete_last() {
806         mark::check!(diff_delete);
807         check_diff(
808             r#"use foo;
809             use bar;"#,
810             r#"use foo;"#,
811             expect![[r#"
812                 insertions:
813
814
815
816                 replacements:
817
818
819
820                 deletions:
821
822                 Line 1: "\n            "
823                 Line 2: use bar;
824             "#]],
825         );
826     }
827
828     #[test]
829     fn delete_middle() {
830         mark::check!(diff_insertions);
831         check_diff(
832             r#"
833 use expect_test::{expect, Expect};
834 use text_edit::TextEdit;
835
836 use crate::AstNode;
837 "#,
838             r#"
839 use expect_test::{expect, Expect};
840
841 use crate::AstNode;
842 "#,
843             expect![[r#"
844                 insertions:
845
846                 Line 1: After(Node(USE@1..35))
847                 -> "\n\n"
848                 -> use crate::AstNode;
849
850                 replacements:
851
852
853
854                 deletions:
855
856                 Line 2: use text_edit::TextEdit;
857                 Line 3: "\n\n"
858                 Line 4: use crate::AstNode;
859                 Line 5: "\n"
860             "#]],
861         )
862     }
863
864     #[test]
865     fn delete_first() {
866         check_diff(
867             r#"
868 use text_edit::TextEdit;
869
870 use crate::AstNode;
871 "#,
872             r#"
873 use crate::AstNode;
874 "#,
875             expect![[r#"
876                 insertions:
877
878
879
880                 replacements:
881
882                 Line 2: Token(IDENT@5..14 "text_edit") -> crate
883                 Line 2: Token(IDENT@16..24 "TextEdit") -> AstNode
884                 Line 2: Token(WHITESPACE@25..27 "\n\n") -> "\n"
885
886                 deletions:
887
888                 Line 3: use crate::AstNode;
889                 Line 4: "\n"
890             "#]],
891         )
892     }
893
894     #[test]
895     fn merge_use() {
896         check_diff(
897             r#"
898 use std::{
899     fmt,
900     hash::BuildHasherDefault,
901     ops::{self, RangeInclusive},
902 };
903 "#,
904             r#"
905 use std::fmt;
906 use std::hash::BuildHasherDefault;
907 use std::ops::{self, RangeInclusive};
908 "#,
909             expect![[r#"
910                 insertions:
911
912                 Line 2: After(Node(PATH_SEGMENT@5..8))
913                 -> ::
914                 -> fmt
915                 Line 6: After(Token(WHITESPACE@86..87 "\n"))
916                 -> use std::hash::BuildHasherDefault;
917                 -> "\n"
918                 -> use std::ops::{self, RangeInclusive};
919                 -> "\n"
920
921                 replacements:
922
923                 Line 2: Token(IDENT@5..8 "std") -> std
924
925                 deletions:
926
927                 Line 2: ::
928                 Line 2: {
929                     fmt,
930                     hash::BuildHasherDefault,
931                     ops::{self, RangeInclusive},
932                 }
933             "#]],
934         )
935     }
936
937     #[test]
938     fn early_return_assist() {
939         check_diff(
940             r#"
941 fn main() {
942     if let Ok(x) = Err(92) {
943         foo(x);
944     }
945 }
946             "#,
947             r#"
948 fn main() {
949     let x = match Err(92) {
950         Ok(it) => it,
951         _ => return,
952     };
953     foo(x);
954 }
955             "#,
956             expect![[r#"
957                 insertions:
958
959                 Line 3: After(Node(BLOCK_EXPR@40..63))
960                 -> " "
961                 -> match Err(92) {
962                         Ok(it) => it,
963                         _ => return,
964                     }
965                 -> ;
966                 Line 3: After(Node(IF_EXPR@17..63))
967                 -> "\n    "
968                 -> foo(x);
969
970                 replacements:
971
972                 Line 3: Token(IF_KW@17..19 "if") -> let
973                 Line 3: Token(LET_KW@20..23 "let") -> x
974                 Line 3: Node(BLOCK_EXPR@40..63) -> =
975
976                 deletions:
977
978                 Line 3: " "
979                 Line 3: Ok(x)
980                 Line 3: " "
981                 Line 3: =
982                 Line 3: " "
983                 Line 3: Err(92)
984             "#]],
985         )
986     }
987
988     fn check_diff(from: &str, to: &str, expected_diff: Expect) {
989         let from_node = crate::SourceFile::parse(from).tree().syntax().clone();
990         let to_node = crate::SourceFile::parse(to).tree().syntax().clone();
991         let diff = super::diff(&from_node, &to_node);
992
993         let line_number =
994             |syn: &SyntaxElement| from[..syn.text_range().start().into()].lines().count();
995
996         let fmt_syntax = |syn: &SyntaxElement| match syn.kind() {
997             SyntaxKind::WHITESPACE => format!("{:?}", syn.to_string()),
998             _ => format!("{}", syn),
999         };
1000
1001         let insertions =
1002             diff.insertions.iter().format_with("\n", |(k, v), f| -> Result<(), std::fmt::Error> {
1003                 f(&format!(
1004                     "Line {}: {:?}\n-> {}",
1005                     line_number(match k {
1006                         super::TreeDiffInsertPos::After(syn) => syn,
1007                         super::TreeDiffInsertPos::AsFirstChild(syn) => syn,
1008                     }),
1009                     k,
1010                     v.iter().format_with("\n-> ", |v, f| f(&fmt_syntax(v)))
1011                 ))
1012             });
1013
1014         let replacements = diff
1015             .replacements
1016             .iter()
1017             .sorted_by_key(|(syntax, _)| syntax.text_range().start())
1018             .format_with("\n", |(k, v), f| {
1019                 f(&format!("Line {}: {:?} -> {}", line_number(k), k, fmt_syntax(v)))
1020             });
1021
1022         let deletions = diff
1023             .deletions
1024             .iter()
1025             .format_with("\n", |v, f| f(&format!("Line {}: {}", line_number(v), &fmt_syntax(v))));
1026
1027         let actual = format!(
1028             "insertions:\n\n{}\n\nreplacements:\n\n{}\n\ndeletions:\n\n{}\n",
1029             insertions, replacements, deletions
1030         );
1031         expected_diff.assert_eq(&actual);
1032
1033         let mut from = from.to_owned();
1034         let mut text_edit = TextEdit::builder();
1035         diff.into_text_edit(&mut text_edit);
1036         text_edit.finish().apply(&mut from);
1037         assert_eq!(&*from, to, "diff did not turn `from` to `to`");
1038     }
1039 }