]> git.lizzy.rs Git - rust.git/blob - crates/syntax/src/algo.rs
:arrow_up: rowan
[rust.git] / crates / syntax / src / algo.rs
1 //! FIXME: write short doc here
2
3 use std::{
4     fmt,
5     hash::BuildHasherDefault,
6     ops::{self, RangeInclusive},
7     ptr,
8 };
9
10 use indexmap::IndexMap;
11 use itertools::Itertools;
12 use rustc_hash::FxHashMap;
13 use test_utils::mark;
14 use text_edit::TextEditBuilder;
15
16 use crate::{
17     AstNode, Direction, NodeOrToken, SyntaxElement, SyntaxKind, SyntaxNode, SyntaxNodePtr,
18     SyntaxToken, TextRange, TextSize,
19 };
20
21 /// Returns ancestors of the node at the offset, sorted by length. This should
22 /// do the right thing at an edge, e.g. when searching for expressions at `{
23 /// $0foo }` we will get the name reference instead of the whole block, which
24 /// we would get if we just did `find_token_at_offset(...).flat_map(|t|
25 /// t.parent().ancestors())`.
26 pub fn ancestors_at_offset(
27     node: &SyntaxNode,
28     offset: TextSize,
29 ) -> impl Iterator<Item = SyntaxNode> {
30     node.token_at_offset(offset)
31         .map(|token| token.parent().ancestors())
32         .kmerge_by(|node1, node2| node1.text_range().len() < node2.text_range().len())
33 }
34
35 /// Finds a node of specific Ast type at offset. Note that this is slightly
36 /// imprecise: if the cursor is strictly between two nodes of the desired type,
37 /// as in
38 ///
39 /// ```no_run
40 /// struct Foo {}|struct Bar;
41 /// ```
42 ///
43 /// then the shorter node will be silently preferred.
44 pub fn find_node_at_offset<N: AstNode>(syntax: &SyntaxNode, offset: TextSize) -> Option<N> {
45     ancestors_at_offset(syntax, offset).find_map(N::cast)
46 }
47
48 pub fn find_node_at_range<N: AstNode>(syntax: &SyntaxNode, range: TextRange) -> Option<N> {
49     syntax.covering_element(range).ancestors().find_map(N::cast)
50 }
51
52 /// Skip to next non `trivia` token
53 pub fn skip_trivia_token(mut token: SyntaxToken, direction: Direction) -> Option<SyntaxToken> {
54     while token.kind().is_trivia() {
55         token = match direction {
56             Direction::Next => token.next_token()?,
57             Direction::Prev => token.prev_token()?,
58         }
59     }
60     Some(token)
61 }
62
63 /// Finds the first sibling in the given direction which is not `trivia`
64 pub fn non_trivia_sibling(element: SyntaxElement, direction: Direction) -> Option<SyntaxElement> {
65     return match element {
66         NodeOrToken::Node(node) => node.siblings_with_tokens(direction).skip(1).find(not_trivia),
67         NodeOrToken::Token(token) => token.siblings_with_tokens(direction).skip(1).find(not_trivia),
68     };
69
70     fn not_trivia(element: &SyntaxElement) -> bool {
71         match element {
72             NodeOrToken::Node(_) => true,
73             NodeOrToken::Token(token) => !token.kind().is_trivia(),
74         }
75     }
76 }
77
78 pub fn least_common_ancestor(u: &SyntaxNode, v: &SyntaxNode) -> Option<SyntaxNode> {
79     if u == v {
80         return Some(u.clone());
81     }
82
83     let u_depth = u.ancestors().count();
84     let v_depth = v.ancestors().count();
85     let keep = u_depth.min(v_depth);
86
87     let u_candidates = u.ancestors().skip(u_depth - keep);
88     let v_candidates = v.ancestors().skip(v_depth - keep);
89     let (res, _) = u_candidates.zip(v_candidates).find(|(x, y)| x == y)?;
90     Some(res)
91 }
92
93 pub fn neighbor<T: AstNode>(me: &T, direction: Direction) -> Option<T> {
94     me.syntax().siblings(direction).skip(1).find_map(T::cast)
95 }
96
97 pub fn has_errors(node: &SyntaxNode) -> bool {
98     node.children().any(|it| it.kind() == SyntaxKind::ERROR)
99 }
100
101 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
102 pub enum InsertPosition<T> {
103     First,
104     Last,
105     Before(T),
106     After(T),
107 }
108
109 type FxIndexMap<K, V> = IndexMap<K, V, BuildHasherDefault<rustc_hash::FxHasher>>;
110
111 #[derive(Debug, Hash, PartialEq, Eq)]
112 enum TreeDiffInsertPos {
113     After(SyntaxElement),
114     AsFirstChild(SyntaxElement),
115 }
116
117 #[derive(Debug)]
118 pub struct TreeDiff {
119     replacements: FxHashMap<SyntaxElement, SyntaxElement>,
120     deletions: Vec<SyntaxElement>,
121     // the vec as well as the indexmap are both here to preserve order
122     insertions: FxIndexMap<TreeDiffInsertPos, Vec<SyntaxElement>>,
123 }
124
125 impl TreeDiff {
126     pub fn into_text_edit(&self, builder: &mut TextEditBuilder) {
127         let _p = profile::span("into_text_edit");
128
129         for (anchor, to) in self.insertions.iter() {
130             let offset = match anchor {
131                 TreeDiffInsertPos::After(it) => it.text_range().end(),
132                 TreeDiffInsertPos::AsFirstChild(it) => it.text_range().start(),
133             };
134             to.iter().for_each(|to| builder.insert(offset, to.to_string()));
135         }
136         for (from, to) in self.replacements.iter() {
137             builder.replace(from.text_range(), to.to_string())
138         }
139         for text_range in self.deletions.iter().map(SyntaxElement::text_range) {
140             builder.delete(text_range);
141         }
142     }
143
144     pub fn is_empty(&self) -> bool {
145         self.replacements.is_empty() && self.deletions.is_empty() && self.insertions.is_empty()
146     }
147 }
148
149 /// Finds a (potentially minimal) diff, which, applied to `from`, will result in `to`.
150 ///
151 /// Specifically, returns a structure that consists of a replacements, insertions and deletions
152 /// such that applying this map on `from` will result in `to`.
153 ///
154 /// This function tries to find a fine-grained diff.
155 pub fn diff(from: &SyntaxNode, to: &SyntaxNode) -> TreeDiff {
156     let _p = profile::span("diff");
157
158     let mut diff = TreeDiff {
159         replacements: FxHashMap::default(),
160         insertions: FxIndexMap::default(),
161         deletions: Vec::new(),
162     };
163     let (from, to) = (from.clone().into(), to.clone().into());
164
165     if !syntax_element_eq(&from, &to) {
166         go(&mut diff, from, to);
167     }
168     return diff;
169
170     fn syntax_element_eq(lhs: &SyntaxElement, rhs: &SyntaxElement) -> bool {
171         lhs.kind() == rhs.kind()
172             && lhs.text_range().len() == rhs.text_range().len()
173             && match (&lhs, &rhs) {
174                 (NodeOrToken::Node(lhs), NodeOrToken::Node(rhs)) => {
175                     ptr::eq(lhs.green(), rhs.green()) || lhs.text() == rhs.text()
176                 }
177                 (NodeOrToken::Token(lhs), NodeOrToken::Token(rhs)) => lhs.text() == rhs.text(),
178                 _ => false,
179             }
180     }
181
182     // FIXME: this is horrible inefficient. I bet there's a cool algorithm to diff trees properly.
183     fn go(diff: &mut TreeDiff, lhs: SyntaxElement, rhs: SyntaxElement) {
184         let (lhs, rhs) = match lhs.as_node().zip(rhs.as_node()) {
185             Some((lhs, rhs)) => (lhs, rhs),
186             _ => {
187                 mark::hit!(diff_node_token_replace);
188                 diff.replacements.insert(lhs, rhs);
189                 return;
190             }
191         };
192
193         let mut look_ahead_scratch = Vec::default();
194
195         let mut rhs_children = rhs.children_with_tokens();
196         let mut lhs_children = lhs.children_with_tokens();
197         let mut last_lhs = None;
198         loop {
199             let lhs_child = lhs_children.next();
200             match (lhs_child.clone(), rhs_children.next()) {
201                 (None, None) => break,
202                 (None, Some(element)) => {
203                     let insert_pos = match last_lhs.clone() {
204                         Some(prev) => {
205                             mark::hit!(diff_insert);
206                             TreeDiffInsertPos::After(prev)
207                         }
208                         // first iteration, insert into out parent as the first child
209                         None => {
210                             mark::hit!(diff_insert_as_first_child);
211                             TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
212                         }
213                     };
214                     diff.insertions.entry(insert_pos).or_insert_with(Vec::new).push(element);
215                 }
216                 (Some(element), None) => {
217                     mark::hit!(diff_delete);
218                     diff.deletions.push(element);
219                 }
220                 (Some(ref lhs_ele), Some(ref rhs_ele)) if syntax_element_eq(lhs_ele, rhs_ele) => {}
221                 (Some(lhs_ele), Some(rhs_ele)) => {
222                     // nodes differ, look for lhs_ele in rhs, if its found we can mark everything up
223                     // until that element as insertions. This is important to keep the diff minimal
224                     // in regards to insertions that have been actually done, this is important for
225                     // use insertions as we do not want to replace the entire module node.
226                     look_ahead_scratch.push(rhs_ele.clone());
227                     let mut rhs_children_clone = rhs_children.clone();
228                     let mut insert = false;
229                     while let Some(rhs_child) = rhs_children_clone.next() {
230                         if syntax_element_eq(&lhs_ele, &rhs_child) {
231                             mark::hit!(diff_insertions);
232                             insert = true;
233                             break;
234                         } else {
235                             look_ahead_scratch.push(rhs_child);
236                         }
237                     }
238                     let drain = look_ahead_scratch.drain(..);
239                     if insert {
240                         let insert_pos = if let Some(prev) = last_lhs.clone().filter(|_| insert) {
241                             TreeDiffInsertPos::After(prev)
242                         } else {
243                             mark::hit!(insert_first_child);
244                             TreeDiffInsertPos::AsFirstChild(lhs.clone().into())
245                         };
246
247                         diff.insertions.entry(insert_pos).or_insert_with(Vec::new).extend(drain);
248                         rhs_children = rhs_children_clone;
249                     } else {
250                         go(diff, lhs_ele, rhs_ele)
251                     }
252                 }
253             }
254             last_lhs = lhs_child.or(last_lhs);
255         }
256     }
257 }
258
259 /// Adds specified children (tokens or nodes) to the current node at the
260 /// specific position.
261 ///
262 /// This is a type-unsafe low-level editing API, if you need to use it,
263 /// prefer to create a type-safe abstraction on top of it instead.
264 pub fn insert_children(
265     parent: &SyntaxNode,
266     position: InsertPosition<SyntaxElement>,
267     to_insert: impl IntoIterator<Item = SyntaxElement>,
268 ) -> SyntaxNode {
269     let mut to_insert = to_insert.into_iter();
270     _insert_children(parent, position, &mut to_insert)
271 }
272
273 fn _insert_children(
274     parent: &SyntaxNode,
275     position: InsertPosition<SyntaxElement>,
276     to_insert: &mut dyn Iterator<Item = SyntaxElement>,
277 ) -> SyntaxNode {
278     let mut delta = TextSize::default();
279     let to_insert = to_insert.map(|element| {
280         delta += element.text_range().len();
281         to_green_element(element)
282     });
283
284     let mut old_children = parent.green().children().map(|it| match it {
285         NodeOrToken::Token(it) => NodeOrToken::Token(it.clone()),
286         NodeOrToken::Node(it) => NodeOrToken::Node(it.clone()),
287     });
288
289     let new_children = match &position {
290         InsertPosition::First => to_insert.chain(old_children).collect::<Vec<_>>(),
291         InsertPosition::Last => old_children.chain(to_insert).collect::<Vec<_>>(),
292         InsertPosition::Before(anchor) | InsertPosition::After(anchor) => {
293             let take_anchor = if let InsertPosition::After(_) = position { 1 } else { 0 };
294             let split_at = position_of_child(parent, anchor.clone()) + take_anchor;
295             let before = old_children.by_ref().take(split_at).collect::<Vec<_>>();
296             before.into_iter().chain(to_insert).chain(old_children).collect::<Vec<_>>()
297         }
298     };
299
300     with_children(parent, new_children)
301 }
302
303 /// Replaces all nodes in `to_delete` with nodes from `to_insert`
304 ///
305 /// This is a type-unsafe low-level editing API, if you need to use it,
306 /// prefer to create a type-safe abstraction on top of it instead.
307 pub fn replace_children(
308     parent: &SyntaxNode,
309     to_delete: RangeInclusive<SyntaxElement>,
310     to_insert: impl IntoIterator<Item = SyntaxElement>,
311 ) -> SyntaxNode {
312     let mut to_insert = to_insert.into_iter();
313     _replace_children(parent, to_delete, &mut to_insert)
314 }
315
316 fn _replace_children(
317     parent: &SyntaxNode,
318     to_delete: RangeInclusive<SyntaxElement>,
319     to_insert: &mut dyn Iterator<Item = SyntaxElement>,
320 ) -> SyntaxNode {
321     let start = position_of_child(parent, to_delete.start().clone());
322     let end = position_of_child(parent, to_delete.end().clone());
323     let mut old_children = parent.green().children().map(|it| match it {
324         NodeOrToken::Token(it) => NodeOrToken::Token(it.clone()),
325         NodeOrToken::Node(it) => NodeOrToken::Node(it.clone()),
326     });
327
328     let before = old_children.by_ref().take(start).collect::<Vec<_>>();
329     let new_children = before
330         .into_iter()
331         .chain(to_insert.map(to_green_element))
332         .chain(old_children.skip(end + 1 - start))
333         .collect::<Vec<_>>();
334     with_children(parent, new_children)
335 }
336
337 #[derive(Debug, PartialEq, Eq, Hash)]
338 enum InsertPos {
339     FirstChildOf(SyntaxNode),
340     After(SyntaxElement),
341 }
342
343 #[derive(Default)]
344 pub struct SyntaxRewriter<'a> {
345     f: Option<Box<dyn Fn(&SyntaxElement) -> Option<SyntaxElement> + 'a>>,
346     //FIXME: add debug_assertions that all elements are in fact from the same file.
347     replacements: FxHashMap<SyntaxElement, Replacement>,
348     insertions: IndexMap<InsertPos, Vec<SyntaxElement>>,
349 }
350
351 impl fmt::Debug for SyntaxRewriter<'_> {
352     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
353         f.debug_struct("SyntaxRewriter")
354             .field("replacements", &self.replacements)
355             .field("insertions", &self.insertions)
356             .finish()
357     }
358 }
359
360 impl<'a> SyntaxRewriter<'a> {
361     pub fn from_fn(f: impl Fn(&SyntaxElement) -> Option<SyntaxElement> + 'a) -> SyntaxRewriter<'a> {
362         SyntaxRewriter {
363             f: Some(Box::new(f)),
364             replacements: FxHashMap::default(),
365             insertions: IndexMap::default(),
366         }
367     }
368     pub fn delete<T: Clone + Into<SyntaxElement>>(&mut self, what: &T) {
369         let what = what.clone().into();
370         let replacement = Replacement::Delete;
371         self.replacements.insert(what, replacement);
372     }
373     pub fn insert_before<T: Clone + Into<SyntaxElement>, U: Clone + Into<SyntaxElement>>(
374         &mut self,
375         before: &T,
376         what: &U,
377     ) {
378         let before = before.clone().into();
379         let pos = match before.prev_sibling_or_token() {
380             Some(sibling) => InsertPos::After(sibling),
381             None => match before.parent() {
382                 Some(parent) => InsertPos::FirstChildOf(parent),
383                 None => return,
384             },
385         };
386         self.insertions.entry(pos).or_insert_with(Vec::new).push(what.clone().into());
387     }
388     pub fn insert_after<T: Clone + Into<SyntaxElement>, U: Clone + Into<SyntaxElement>>(
389         &mut self,
390         after: &T,
391         what: &U,
392     ) {
393         self.insertions
394             .entry(InsertPos::After(after.clone().into()))
395             .or_insert_with(Vec::new)
396             .push(what.clone().into());
397     }
398     pub fn insert_as_first_child<T: Clone + Into<SyntaxNode>, U: Clone + Into<SyntaxElement>>(
399         &mut self,
400         parent: &T,
401         what: &U,
402     ) {
403         self.insertions
404             .entry(InsertPos::FirstChildOf(parent.clone().into()))
405             .or_insert_with(Vec::new)
406             .push(what.clone().into());
407     }
408     pub fn insert_many_before<
409         T: Clone + Into<SyntaxElement>,
410         U: IntoIterator<Item = SyntaxElement>,
411     >(
412         &mut self,
413         before: &T,
414         what: U,
415     ) {
416         let before = before.clone().into();
417         let pos = match before.prev_sibling_or_token() {
418             Some(sibling) => InsertPos::After(sibling),
419             None => match before.parent() {
420                 Some(parent) => InsertPos::FirstChildOf(parent),
421                 None => return,
422             },
423         };
424         self.insertions.entry(pos).or_insert_with(Vec::new).extend(what);
425     }
426     pub fn insert_many_after<
427         T: Clone + Into<SyntaxElement>,
428         U: IntoIterator<Item = SyntaxElement>,
429     >(
430         &mut self,
431         after: &T,
432         what: U,
433     ) {
434         self.insertions
435             .entry(InsertPos::After(after.clone().into()))
436             .or_insert_with(Vec::new)
437             .extend(what);
438     }
439     pub fn insert_many_as_first_children<
440         T: Clone + Into<SyntaxNode>,
441         U: IntoIterator<Item = SyntaxElement>,
442     >(
443         &mut self,
444         parent: &T,
445         what: U,
446     ) {
447         self.insertions
448             .entry(InsertPos::FirstChildOf(parent.clone().into()))
449             .or_insert_with(Vec::new)
450             .extend(what)
451     }
452     pub fn replace<T: Clone + Into<SyntaxElement>>(&mut self, what: &T, with: &T) {
453         let what = what.clone().into();
454         let replacement = Replacement::Single(with.clone().into());
455         self.replacements.insert(what, replacement);
456     }
457     pub fn replace_with_many<T: Clone + Into<SyntaxElement>>(
458         &mut self,
459         what: &T,
460         with: Vec<SyntaxElement>,
461     ) {
462         let what = what.clone().into();
463         let replacement = Replacement::Many(with);
464         self.replacements.insert(what, replacement);
465     }
466     pub fn replace_ast<T: AstNode>(&mut self, what: &T, with: &T) {
467         self.replace(what.syntax(), with.syntax())
468     }
469
470     pub fn rewrite(&self, node: &SyntaxNode) -> SyntaxNode {
471         let _p = profile::span("rewrite");
472
473         if self.f.is_none() && self.replacements.is_empty() && self.insertions.is_empty() {
474             return node.clone();
475         }
476         let green = self.rewrite_children(node);
477         with_green(node, green)
478     }
479
480     pub fn rewrite_ast<N: AstNode>(self, node: &N) -> N {
481         N::cast(self.rewrite(node.syntax())).unwrap()
482     }
483
484     /// Returns a node that encompasses all replacements to be done by this rewriter.
485     ///
486     /// Passing the returned node to `rewrite` will apply all replacements queued up in `self`.
487     ///
488     /// Returns `None` when there are no replacements.
489     pub fn rewrite_root(&self) -> Option<SyntaxNode> {
490         let _p = profile::span("rewrite_root");
491         fn element_to_node_or_parent(element: &SyntaxElement) -> SyntaxNode {
492             match element {
493                 SyntaxElement::Node(it) => it.clone(),
494                 SyntaxElement::Token(it) => it.parent(),
495             }
496         }
497
498         assert!(self.f.is_none());
499         self.replacements
500             .keys()
501             .map(element_to_node_or_parent)
502             .chain(self.insertions.keys().map(|pos| match pos {
503                 InsertPos::FirstChildOf(it) => it.clone(),
504                 InsertPos::After(it) => element_to_node_or_parent(it),
505             }))
506             // If we only have one replacement/insertion, we must return its parent node, since `rewrite` does
507             // not replace the node passed to it.
508             .map(|it| it.parent().unwrap_or(it))
509             .fold1(|a, b| least_common_ancestor(&a, &b).unwrap())
510     }
511
512     fn replacement(&self, element: &SyntaxElement) -> Option<Replacement> {
513         if let Some(f) = &self.f {
514             assert!(self.replacements.is_empty());
515             return f(element).map(Replacement::Single);
516         }
517         self.replacements.get(element).cloned()
518     }
519
520     fn insertions(&self, pos: &InsertPos) -> Option<impl Iterator<Item = SyntaxElement> + '_> {
521         self.insertions.get(pos).map(|insertions| insertions.iter().cloned())
522     }
523
524     fn rewrite_children(&self, node: &SyntaxNode) -> rowan::GreenNode {
525         let _p = profile::span("rewrite_children");
526
527         //  FIXME: this could be made much faster.
528         let mut new_children = Vec::new();
529         if let Some(elements) = self.insertions(&InsertPos::FirstChildOf(node.clone())) {
530             new_children.extend(elements.map(element_to_green));
531         }
532         for child in node.children_with_tokens() {
533             self.rewrite_self(&mut new_children, &child);
534         }
535
536         rowan::GreenNode::new(rowan::SyntaxKind(node.kind() as u16), new_children)
537     }
538
539     fn rewrite_self(
540         &self,
541         acc: &mut Vec<NodeOrToken<rowan::GreenNode, rowan::GreenToken>>,
542         element: &SyntaxElement,
543     ) {
544         let _p = profile::span("rewrite_self");
545
546         if let Some(replacement) = self.replacement(&element) {
547             match replacement {
548                 Replacement::Single(element) => acc.push(element_to_green(element)),
549                 Replacement::Many(replacements) => {
550                     acc.extend(replacements.into_iter().map(element_to_green))
551                 }
552                 Replacement::Delete => (),
553             };
554         } else {
555             match element {
556                 NodeOrToken::Token(it) => acc.push(NodeOrToken::Token(it.green().clone())),
557                 NodeOrToken::Node(it) => {
558                     acc.push(NodeOrToken::Node(self.rewrite_children(it)));
559                 }
560             }
561         }
562         if let Some(elements) = self.insertions(&InsertPos::After(element.clone())) {
563             acc.extend(elements.map(element_to_green));
564         }
565     }
566 }
567
568 fn element_to_green(element: SyntaxElement) -> NodeOrToken<rowan::GreenNode, rowan::GreenToken> {
569     match element {
570         NodeOrToken::Node(it) => NodeOrToken::Node(it.green().to_owned()),
571         NodeOrToken::Token(it) => NodeOrToken::Token(it.green().clone()),
572     }
573 }
574
575 impl ops::AddAssign for SyntaxRewriter<'_> {
576     fn add_assign(&mut self, rhs: SyntaxRewriter) {
577         assert!(rhs.f.is_none());
578         self.replacements.extend(rhs.replacements);
579         for (pos, insertions) in rhs.insertions.into_iter() {
580             match self.insertions.entry(pos) {
581                 indexmap::map::Entry::Occupied(mut occupied) => {
582                     occupied.get_mut().extend(insertions)
583                 }
584                 indexmap::map::Entry::Vacant(vacant) => drop(vacant.insert(insertions)),
585             }
586         }
587     }
588 }
589
590 #[derive(Clone, Debug)]
591 enum Replacement {
592     Delete,
593     Single(SyntaxElement),
594     Many(Vec<SyntaxElement>),
595 }
596
597 fn with_children(
598     parent: &SyntaxNode,
599     new_children: Vec<NodeOrToken<rowan::GreenNode, rowan::GreenToken>>,
600 ) -> SyntaxNode {
601     let _p = profile::span("with_children");
602
603     let new_green = rowan::GreenNode::new(rowan::SyntaxKind(parent.kind() as u16), new_children);
604     with_green(parent, new_green)
605 }
606
607 fn with_green(syntax_node: &SyntaxNode, green: rowan::GreenNode) -> SyntaxNode {
608     let len = green.children().map(|it| it.text_len()).sum::<TextSize>();
609     let new_root_node = syntax_node.replace_with(green);
610     let new_root_node = SyntaxNode::new_root(new_root_node);
611
612     // FIXME: use a more elegant way to re-fetch the node (#1185), make
613     // `range` private afterwards
614     let mut ptr = SyntaxNodePtr::new(syntax_node);
615     ptr.range = TextRange::at(ptr.range.start(), len);
616     ptr.to_node(&new_root_node)
617 }
618
619 fn position_of_child(parent: &SyntaxNode, child: SyntaxElement) -> usize {
620     parent
621         .children_with_tokens()
622         .position(|it| it == child)
623         .expect("element is not a child of current element")
624 }
625
626 fn to_green_element(element: SyntaxElement) -> NodeOrToken<rowan::GreenNode, rowan::GreenToken> {
627     match element {
628         NodeOrToken::Node(it) => it.green().to_owned().into(),
629         NodeOrToken::Token(it) => it.green().clone().into(),
630     }
631 }
632
633 #[cfg(test)]
634 mod tests {
635     use expect_test::{expect, Expect};
636     use itertools::Itertools;
637     use parser::SyntaxKind;
638     use test_utils::mark;
639     use text_edit::TextEdit;
640
641     use crate::{AstNode, SyntaxElement};
642
643     #[test]
644     fn replace_node_token() {
645         mark::check!(diff_node_token_replace);
646         check_diff(
647             r#"use node;"#,
648             r#"ident"#,
649             expect![[r#"
650                 insertions:
651
652
653
654                 replacements:
655
656                 Line 0: Token(USE_KW@0..3 "use") -> ident
657
658                 deletions:
659
660                 Line 1: " "
661                 Line 1: node
662                 Line 1: ;
663             "#]],
664         );
665     }
666
667     #[test]
668     fn replace_parent() {
669         mark::check!(diff_insert_as_first_child);
670         check_diff(
671             r#""#,
672             r#"use foo::bar;"#,
673             expect![[r#"
674                 insertions:
675
676                 Line 0: AsFirstChild(Node(SOURCE_FILE@0..0))
677                 -> use foo::bar;
678
679                 replacements:
680
681
682
683                 deletions:
684
685
686             "#]],
687         );
688     }
689
690     #[test]
691     fn insert_last() {
692         mark::check!(diff_insert);
693         check_diff(
694             r#"
695 use foo;
696 use bar;"#,
697             r#"
698 use foo;
699 use bar;
700 use baz;"#,
701             expect![[r#"
702                 insertions:
703
704                 Line 2: After(Node(USE@10..18))
705                 -> "\n"
706                 -> use baz;
707
708                 replacements:
709
710
711
712                 deletions:
713
714
715             "#]],
716         );
717     }
718
719     #[test]
720     fn insert_middle() {
721         check_diff(
722             r#"
723 use foo;
724 use baz;"#,
725             r#"
726 use foo;
727 use bar;
728 use baz;"#,
729             expect![[r#"
730                 insertions:
731
732                 Line 2: After(Token(WHITESPACE@9..10 "\n"))
733                 -> use bar;
734                 -> "\n"
735
736                 replacements:
737
738
739
740                 deletions:
741
742
743             "#]],
744         )
745     }
746
747     #[test]
748     fn insert_first() {
749         check_diff(
750             r#"
751 use bar;
752 use baz;"#,
753             r#"
754 use foo;
755 use bar;
756 use baz;"#,
757             expect![[r#"
758                 insertions:
759
760                 Line 0: After(Token(WHITESPACE@0..1 "\n"))
761                 -> use foo;
762                 -> "\n"
763
764                 replacements:
765
766
767
768                 deletions:
769
770
771             "#]],
772         )
773     }
774
775     #[test]
776     fn first_child_insertion() {
777         mark::check!(insert_first_child);
778         check_diff(
779             r#"fn main() {
780         stdi
781     }"#,
782             r#"use foo::bar;
783
784     fn main() {
785         stdi
786     }"#,
787             expect![[r#"
788                 insertions:
789
790                 Line 0: AsFirstChild(Node(SOURCE_FILE@0..30))
791                 -> use foo::bar;
792                 -> "\n\n    "
793
794                 replacements:
795
796
797
798                 deletions:
799
800
801             "#]],
802         );
803     }
804
805     #[test]
806     fn delete_last() {
807         mark::check!(diff_delete);
808         check_diff(
809             r#"use foo;
810             use bar;"#,
811             r#"use foo;"#,
812             expect![[r#"
813                 insertions:
814
815
816
817                 replacements:
818
819
820
821                 deletions:
822
823                 Line 1: "\n            "
824                 Line 2: use bar;
825             "#]],
826         );
827     }
828
829     #[test]
830     fn delete_middle() {
831         mark::check!(diff_insertions);
832         check_diff(
833             r#"
834 use expect_test::{expect, Expect};
835 use text_edit::TextEdit;
836
837 use crate::AstNode;
838 "#,
839             r#"
840 use expect_test::{expect, Expect};
841
842 use crate::AstNode;
843 "#,
844             expect![[r#"
845                 insertions:
846
847                 Line 1: After(Node(USE@1..35))
848                 -> "\n\n"
849                 -> use crate::AstNode;
850
851                 replacements:
852
853
854
855                 deletions:
856
857                 Line 2: use text_edit::TextEdit;
858                 Line 3: "\n\n"
859                 Line 4: use crate::AstNode;
860                 Line 5: "\n"
861             "#]],
862         )
863     }
864
865     #[test]
866     fn delete_first() {
867         check_diff(
868             r#"
869 use text_edit::TextEdit;
870
871 use crate::AstNode;
872 "#,
873             r#"
874 use crate::AstNode;
875 "#,
876             expect![[r#"
877                 insertions:
878
879
880
881                 replacements:
882
883                 Line 2: Token(IDENT@5..14 "text_edit") -> crate
884                 Line 2: Token(IDENT@16..24 "TextEdit") -> AstNode
885                 Line 2: Token(WHITESPACE@25..27 "\n\n") -> "\n"
886
887                 deletions:
888
889                 Line 3: use crate::AstNode;
890                 Line 4: "\n"
891             "#]],
892         )
893     }
894
895     #[test]
896     fn merge_use() {
897         check_diff(
898             r#"
899 use std::{
900     fmt,
901     hash::BuildHasherDefault,
902     ops::{self, RangeInclusive},
903 };
904 "#,
905             r#"
906 use std::fmt;
907 use std::hash::BuildHasherDefault;
908 use std::ops::{self, RangeInclusive};
909 "#,
910             expect![[r#"
911                 insertions:
912
913                 Line 2: After(Node(PATH_SEGMENT@5..8))
914                 -> ::
915                 -> fmt
916                 Line 6: After(Token(WHITESPACE@86..87 "\n"))
917                 -> use std::hash::BuildHasherDefault;
918                 -> "\n"
919                 -> use std::ops::{self, RangeInclusive};
920                 -> "\n"
921
922                 replacements:
923
924                 Line 2: Token(IDENT@5..8 "std") -> std
925
926                 deletions:
927
928                 Line 2: ::
929                 Line 2: {
930                     fmt,
931                     hash::BuildHasherDefault,
932                     ops::{self, RangeInclusive},
933                 }
934             "#]],
935         )
936     }
937
938     #[test]
939     fn early_return_assist() {
940         check_diff(
941             r#"
942 fn main() {
943     if let Ok(x) = Err(92) {
944         foo(x);
945     }
946 }
947             "#,
948             r#"
949 fn main() {
950     let x = match Err(92) {
951         Ok(it) => it,
952         _ => return,
953     };
954     foo(x);
955 }
956             "#,
957             expect![[r#"
958                 insertions:
959
960                 Line 3: After(Node(BLOCK_EXPR@40..63))
961                 -> " "
962                 -> match Err(92) {
963                         Ok(it) => it,
964                         _ => return,
965                     }
966                 -> ;
967                 Line 3: After(Node(IF_EXPR@17..63))
968                 -> "\n    "
969                 -> foo(x);
970
971                 replacements:
972
973                 Line 3: Token(IF_KW@17..19 "if") -> let
974                 Line 3: Token(LET_KW@20..23 "let") -> x
975                 Line 3: Node(BLOCK_EXPR@40..63) -> =
976
977                 deletions:
978
979                 Line 3: " "
980                 Line 3: Ok(x)
981                 Line 3: " "
982                 Line 3: =
983                 Line 3: " "
984                 Line 3: Err(92)
985             "#]],
986         )
987     }
988
989     fn check_diff(from: &str, to: &str, expected_diff: Expect) {
990         let from_node = crate::SourceFile::parse(from).tree().syntax().clone();
991         let to_node = crate::SourceFile::parse(to).tree().syntax().clone();
992         let diff = super::diff(&from_node, &to_node);
993
994         let line_number =
995             |syn: &SyntaxElement| from[..syn.text_range().start().into()].lines().count();
996
997         let fmt_syntax = |syn: &SyntaxElement| match syn.kind() {
998             SyntaxKind::WHITESPACE => format!("{:?}", syn.to_string()),
999             _ => format!("{}", syn),
1000         };
1001
1002         let insertions =
1003             diff.insertions.iter().format_with("\n", |(k, v), f| -> Result<(), std::fmt::Error> {
1004                 f(&format!(
1005                     "Line {}: {:?}\n-> {}",
1006                     line_number(match k {
1007                         super::TreeDiffInsertPos::After(syn) => syn,
1008                         super::TreeDiffInsertPos::AsFirstChild(syn) => syn,
1009                     }),
1010                     k,
1011                     v.iter().format_with("\n-> ", |v, f| f(&fmt_syntax(v)))
1012                 ))
1013             });
1014
1015         let replacements = diff
1016             .replacements
1017             .iter()
1018             .sorted_by_key(|(syntax, _)| syntax.text_range().start())
1019             .format_with("\n", |(k, v), f| {
1020                 f(&format!("Line {}: {:?} -> {}", line_number(k), k, fmt_syntax(v)))
1021             });
1022
1023         let deletions = diff
1024             .deletions
1025             .iter()
1026             .format_with("\n", |v, f| f(&format!("Line {}: {}", line_number(v), &fmt_syntax(v))));
1027
1028         let actual = format!(
1029             "insertions:\n\n{}\n\nreplacements:\n\n{}\n\ndeletions:\n\n{}\n",
1030             insertions, replacements, deletions
1031         );
1032         expected_diff.assert_eq(&actual);
1033
1034         let mut from = from.to_owned();
1035         let mut text_edit = TextEdit::builder();
1036         diff.into_text_edit(&mut text_edit);
1037         text_edit.finish().apply(&mut from);
1038         assert_eq!(&*from, to, "diff did not turn `from` to `to`");
1039     }
1040 }