]> git.lizzy.rs Git - rust.git/commitdiff
internal: use mutable trees when filling match arms
authorAleksey Kladov <aleksey.kladov@gmail.com>
Sun, 16 May 2021 12:10:18 +0000 (15:10 +0300)
committerAleksey Kladov <aleksey.kladov@gmail.com>
Sun, 16 May 2021 12:10:18 +0000 (15:10 +0300)
crates/ide_assists/src/handlers/fill_match_arms.rs
crates/syntax/src/ast/edit.rs
crates/syntax/src/ast/edit_in_place.rs

index be927cc1c4728b6d006fb3ce5e40f810f5ea0f53..f66a9b54bc7f55e77a52cd19f99e6dbc0d77fe23 100644 (file)
@@ -71,6 +71,7 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<
             .filter_map(|variant| build_pat(ctx.db(), module, variant))
             .filter(|variant_pat| is_variant_missing(&top_lvl_pats, variant_pat))
             .map(|pat| make::match_arm(iter::once(pat), make::expr_empty_block()))
+            .map(|it| it.clone_for_update())
             .collect::<Vec<_>>();
         if Some(enum_def)
             == FamousDefs(&ctx.sema, Some(module.krate()))
@@ -99,6 +100,7 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<
             })
             .filter(|variant_pat| is_variant_missing(&top_lvl_pats, variant_pat))
             .map(|pat| make::match_arm(iter::once(pat), make::expr_empty_block()))
+            .map(|it| it.clone_for_update())
             .collect()
     } else {
         return None;
@@ -114,10 +116,20 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<
         "Fill match arms",
         target,
         |builder| {
-            let new_arm_list = match_arm_list.remove_placeholder();
-            let n_old_arms = new_arm_list.arms().count();
-            let new_arm_list = new_arm_list.append_arms(missing_arms);
-            let first_new_arm = new_arm_list.arms().nth(n_old_arms);
+            let new_match_arm_list = match_arm_list.clone_for_update();
+
+            let catch_all_arm = new_match_arm_list
+                .arms()
+                .find(|arm| matches!(arm.pat(), Some(ast::Pat::WildcardPat(_))));
+            if let Some(arm) = catch_all_arm {
+                arm.remove()
+            }
+            let mut first_new_arm = None;
+            for arm in missing_arms {
+                first_new_arm.get_or_insert_with(|| arm.clone());
+                new_match_arm_list.add_arm(arm);
+            }
+
             let old_range = ctx.sema.original_range(match_arm_list.syntax()).range;
             match (first_new_arm, ctx.config.snippet_cap) {
                 (Some(first_new_arm), Some(cap)) => {
@@ -131,10 +143,10 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<
                             }
                             None => Cursor::Before(first_new_arm.syntax()),
                         };
-                    let snippet = render_snippet(cap, new_arm_list.syntax(), cursor);
+                    let snippet = render_snippet(cap, new_match_arm_list.syntax(), cursor);
                     builder.replace_snippet(cap, old_range, snippet);
                 }
-                _ => builder.replace(old_range, new_arm_list.to_string()),
+                _ => builder.replace(old_range, new_match_arm_list.to_string()),
             }
         },
     )
@@ -919,8 +931,8 @@ fn foo(a: A) {
                 match a {
                     // foo bar baz
                     A::One => {}
-                    // This is where the rest should be
                     $0A::Two => {}
+                    // This is where the rest should be
                 }
             }
             "#,
@@ -943,9 +955,9 @@ fn foo(a: A) {
             enum A { One, Two }
             fn foo(a: A) {
                 match a {
-                    // foo bar baz
                     $0A::One => {}
                     A::Two => {}
+                    // foo bar baz
                 }
             }
             "#,
index 4b5f5c5717f67c3c6e22fb92ff082fe68b53444d..5e6c1d44e8c7e9b969adae71340248050895d16c 100644 (file)
@@ -29,38 +29,6 @@ pub fn replace_op(&self, op: SyntaxKind) -> Option<ast::BinExpr> {
     }
 }
 
-fn make_multiline<N>(node: N) -> N
-where
-    N: AstNode + Clone,
-{
-    let l_curly = match node.syntax().children_with_tokens().find(|it| it.kind() == T!['{']) {
-        Some(it) => it,
-        None => return node,
-    };
-    let sibling = match l_curly.next_sibling_or_token() {
-        Some(it) => it,
-        None => return node,
-    };
-    let existing_ws = match sibling.as_token() {
-        None => None,
-        Some(tok) if tok.kind() != WHITESPACE => None,
-        Some(ws) => {
-            if ws.text().contains('\n') {
-                return node;
-            }
-            Some(ws.clone())
-        }
-    };
-
-    let indent = leading_indent(node.syntax()).unwrap_or_default();
-    let ws = tokens::WsBuilder::new(&format!("\n{}", indent));
-    let to_insert = iter::once(ws.ws().into());
-    match existing_ws {
-        None => node.insert_children(InsertPosition::After(l_curly), to_insert),
-        Some(ws) => node.replace_children(single_node(ws), to_insert),
-    }
-}
-
 impl ast::RecordExprFieldList {
     #[must_use]
     pub fn append_field(&self, field: &ast::RecordExprField) -> ast::RecordExprFieldList {
@@ -214,79 +182,6 @@ fn split_path_prefix(prefix: &ast::Path) -> Option<ast::Path> {
     }
 }
 
-impl ast::MatchArmList {
-    #[must_use]
-    pub fn append_arms(&self, items: impl IntoIterator<Item = ast::MatchArm>) -> ast::MatchArmList {
-        let mut res = self.clone();
-        res = res.strip_if_only_whitespace();
-        if !res.syntax().text().contains_char('\n') {
-            res = make_multiline(res);
-        }
-        items.into_iter().for_each(|it| res = res.append_arm(it));
-        res
-    }
-
-    fn strip_if_only_whitespace(&self) -> ast::MatchArmList {
-        let mut iter = self.syntax().children_with_tokens().skip_while(|it| it.kind() != T!['{']);
-        iter.next(); // Eat the curly
-        let mut inner = iter.take_while(|it| it.kind() != T!['}']);
-        if !inner.clone().all(|it| it.kind() == WHITESPACE) {
-            return self.clone();
-        }
-        let start = match inner.next() {
-            Some(s) => s,
-            None => return self.clone(),
-        };
-        let end = match inner.last() {
-            Some(s) => s,
-            None => start.clone(),
-        };
-        self.replace_children(start..=end, &mut iter::empty())
-    }
-
-    #[must_use]
-    pub fn remove_placeholder(&self) -> ast::MatchArmList {
-        let placeholder =
-            self.arms().find(|arm| matches!(arm.pat(), Some(ast::Pat::WildcardPat(_))));
-        if let Some(placeholder) = placeholder {
-            self.remove_arm(&placeholder)
-        } else {
-            self.clone()
-        }
-    }
-
-    #[must_use]
-    fn remove_arm(&self, arm: &ast::MatchArm) -> ast::MatchArmList {
-        let start = arm.syntax().clone();
-        let end = if let Some(comma) = start
-            .siblings_with_tokens(Direction::Next)
-            .skip(1)
-            .find(|it| !it.kind().is_trivia())
-            .filter(|it| it.kind() == T![,])
-        {
-            comma
-        } else {
-            start.clone().into()
-        };
-        self.replace_children(start.into()..=end, None)
-    }
-
-    #[must_use]
-    pub fn append_arm(&self, item: ast::MatchArm) -> ast::MatchArmList {
-        let r_curly = match self.syntax().children_with_tokens().find(|it| it.kind() == T!['}']) {
-            Some(t) => t,
-            None => return self.clone(),
-        };
-        let position = InsertPosition::Before(r_curly);
-        let arm_ws = tokens::WsBuilder::new("    ");
-        let match_indent = &leading_indent(self.syntax()).unwrap_or_default();
-        let match_ws = tokens::WsBuilder::new(&format!("\n{}", match_indent));
-        let to_insert: ArrayVec<SyntaxElement, 3> =
-            [arm_ws.ws().into(), item.syntax().clone().into(), match_ws.ws().into()].into();
-        self.insert_children(position, to_insert)
-    }
-}
-
 #[must_use]
 pub fn remove_attrs_and_docs<N: ast::AttrsOwner>(node: &N) -> N {
     N::cast(remove_attrs_and_docs_inner(node.syntax().clone())).unwrap()
index ca777d057da0aa96fa7ba5c6687bdfc1b55e2f0b..abab0269a027531f1dc58159b41c1c8cd08c6a7c 100644 (file)
@@ -13,7 +13,7 @@
         make, GenericParamsOwner,
     },
     ted::{self, Position},
-    AstNode, AstToken, Direction,
+    AstNode, AstToken, Direction, SyntaxNode,
 };
 
 use super::NameOwner;
@@ -297,7 +297,7 @@ pub fn add_item(&self, item: ast::AssocItem) {
             ),
             None => match self.l_curly_token() {
                 Some(l_curly) => {
-                    self.normalize_ws_between_braces();
+                    normalize_ws_between_braces(self.syntax());
                     (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly), "\n")
                 }
                 None => (IndentLevel::single(), Position::last_child_of(self.syntax()), "\n"),
@@ -309,25 +309,6 @@ pub fn add_item(&self, item: ast::AssocItem) {
         ];
         ted::insert_all(position, elements);
     }
-
-    fn normalize_ws_between_braces(&self) -> Option<()> {
-        let l = self.l_curly_token()?;
-        let r = self.r_curly_token()?;
-        let indent = IndentLevel::from_node(self.syntax());
-
-        match l.next_sibling_or_token() {
-            Some(ws) if ws.kind() == SyntaxKind::WHITESPACE => {
-                if ws.next_sibling_or_token()?.into_token()? == r {
-                    ted::replace(ws, make::tokens::whitespace(&format!("\n{}", indent)));
-                }
-            }
-            Some(ws) if ws.kind() == T!['}'] => {
-                ted::insert(Position::after(l), make::tokens::whitespace(&format!("\n{}", indent)));
-            }
-            _ => (),
-        }
-        Some(())
-    }
 }
 
 impl ast::Fn {
@@ -346,6 +327,73 @@ pub fn get_or_create_body(&self) -> ast::BlockExpr {
     }
 }
 
+impl ast::MatchArm {
+    pub fn remove(&self) {
+        if let Some(sibling) = self.syntax().prev_sibling_or_token() {
+            if sibling.kind() == SyntaxKind::WHITESPACE {
+                ted::remove(sibling);
+            }
+        }
+        if let Some(sibling) = self.syntax().next_sibling_or_token() {
+            if sibling.kind() == T![,] {
+                ted::remove(sibling);
+            }
+        }
+        ted::remove(self.syntax());
+    }
+}
+
+impl ast::MatchArmList {
+    pub fn add_arm(&self, arm: ast::MatchArm) {
+        normalize_ws_between_braces(self.syntax());
+        let position = match self.arms().last() {
+            Some(last_arm) => {
+                let curly = last_arm
+                    .syntax()
+                    .siblings_with_tokens(Direction::Next)
+                    .find(|it| it.kind() == T![,]);
+                Position::after(curly.unwrap_or_else(|| last_arm.syntax().clone().into()))
+            }
+            None => match self.l_curly_token() {
+                Some(it) => Position::after(it),
+                None => Position::last_child_of(self.syntax()),
+            },
+        };
+        let indent = IndentLevel::from_node(self.syntax()) + 1;
+        let elements = vec![
+            make::tokens::whitespace(&format!("\n{}", indent)).into(),
+            arm.syntax().clone().into(),
+        ];
+        ted::insert_all(position, elements);
+    }
+}
+
+fn normalize_ws_between_braces(node: &SyntaxNode) -> Option<()> {
+    let l = node
+        .children_with_tokens()
+        .filter_map(|it| it.into_token())
+        .find(|it| it.kind() == T!['{'])?;
+    let r = node
+        .children_with_tokens()
+        .filter_map(|it| it.into_token())
+        .find(|it| it.kind() == T!['}'])?;
+
+    let indent = IndentLevel::from_node(node);
+
+    match l.next_sibling_or_token() {
+        Some(ws) if ws.kind() == SyntaxKind::WHITESPACE => {
+            if ws.next_sibling_or_token()?.into_token()? == r {
+                ted::replace(ws, make::tokens::whitespace(&format!("\n{}", indent)));
+            }
+        }
+        Some(ws) if ws.kind() == T!['}'] => {
+            ted::insert(Position::after(l), make::tokens::whitespace(&format!("\n{}", indent)));
+        }
+        _ => (),
+    }
+    Some(())
+}
+
 #[cfg(test)]
 mod tests {
     use std::fmt;