]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/convert_to_guarded_return.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / convert_to_guarded_return.rs
index b499287a5a1cdb0aa2e8dbd2079090501732b0ca..193d1cdfb2439cc4ce3a31491a27a328b562cb34 100644 (file)
@@ -1,15 +1,15 @@
-use std::{iter::once, ops::RangeInclusive};
+use std::iter::once;
 
+use ide_db::helpers::node_ext::{is_pattern_cond, single_let};
 use syntax::{
-    algo::replace_children,
     ast::{
         self,
         edit::{AstNodeEdit, IndentLevel},
         make,
     },
-    AstNode,
-    SyntaxKind::{FN, LOOP_EXPR, L_CURLY, R_CURLY, WHILE_EXPR, WHITESPACE},
-    SyntaxNode,
+    ted, AstNode,
+    SyntaxKind::{FN, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
+    T,
 };
 
 use crate::{
@@ -49,27 +49,30 @@ pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext)
     let cond = if_expr.condition()?;
 
     // Check if there is an IfLet that we can handle.
-    let if_let_pat = match cond.pat() {
-        None => None, // No IfLet, supported.
-        Some(ast::Pat::TupleStructPat(pat)) if pat.fields().count() == 1 => {
-            let path = pat.path()?;
-            match path.qualifier() {
-                None => {
-                    let bound_ident = pat.fields().next().unwrap();
-                    if ast::IdentPat::can_cast(bound_ident.syntax().kind()) {
-                        Some((path, bound_ident))
-                    } else {
-                        return None;
-                    }
+    let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) {
+        let let_ = single_let(cond)?;
+        match let_.pat() {
+            Some(ast::Pat::TupleStructPat(pat)) if pat.fields().count() == 1 => {
+                let path = pat.path()?;
+                if path.qualifier().is_some() {
+                    return None;
+                }
+
+                let bound_ident = pat.fields().next().unwrap();
+                if !ast::IdentPat::can_cast(bound_ident.syntax().kind()) {
+                    return None;
                 }
-                Some(_) => return None,
+
+                (Some((path, bound_ident)), let_.expr()?)
             }
+            _ => return None, // Unsupported IfLet.
         }
-        Some(_) => return None, // Unsupported IfLet.
+    } else {
+        (None, cond)
     };
 
-    let cond_expr = cond.expr()?;
     let then_block = if_expr.then_branch()?;
+    let then_block = then_block.stmt_list()?;
 
     let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
 
@@ -77,6 +80,9 @@ pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext)
         return None;
     }
 
+    // FIXME: This relies on untyped syntax tree and casts to much. It should be
+    // rewritten to use strongly-typed APIs.
+
     // check for early return and continue
     let first_in_then_block = then_block.syntax().first_child()?;
     if ast::ReturnExpr::can_cast(first_in_then_block.kind())
@@ -96,11 +102,11 @@ pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext)
         _ => return None,
     };
 
-    if then_block.syntax().first_child_or_token().map(|t| t.kind() == L_CURLY).is_none() {
+    if then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{']).is_none() {
         return None;
     }
 
-    then_block.syntax().last_child_or_token().filter(|t| t.kind() == R_CURLY)?;
+    then_block.syntax().last_child_or_token().filter(|t| t.kind() == T!['}'])?;
 
     let target = if_expr.syntax().text_range();
     acc.add(
@@ -108,18 +114,18 @@ pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext)
         "Convert to guarded return",
         target,
         |edit| {
+            let if_expr = edit.make_mut(if_expr);
             let if_indent_level = IndentLevel::from_node(if_expr.syntax());
-            let new_block = match if_let_pat {
+            let replacement = match if_let_pat {
                 None => {
                     // If.
                     let new_expr = {
                         let then_branch =
                             make::block_expr(once(make::expr_stmt(early_expression).into()), None);
                         let cond = invert_boolean_expression(cond_expr);
-                        make::expr_if(make::condition(cond, None), then_branch, None)
-                            .indent(if_indent_level)
+                        make::expr_if(cond, then_branch, None).indent(if_indent_level)
                     };
-                    replace(new_expr.syntax(), &then_block, &parent_block, &if_expr)
+                    new_expr.syntax().clone_for_update()
                 }
                 Some((path, bound_ident)) => {
                     // If-let.
@@ -148,41 +154,32 @@ pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext)
 
                     let let_stmt = make::let_stmt(bound_ident, None, Some(match_expr));
                     let let_stmt = let_stmt.indent(if_indent_level);
-                    replace(let_stmt.syntax(), &then_block, &parent_block, &if_expr)
+                    let_stmt.syntax().clone_for_update()
                 }
             };
-            edit.replace_ast(parent_block, ast::BlockExpr::cast(new_block).unwrap());
-
-            fn replace(
-                new_expr: &SyntaxNode,
-                then_block: &ast::BlockExpr,
-                parent_block: &ast::BlockExpr,
-                if_expr: &ast::IfExpr,
-            ) -> SyntaxNode {
-                let then_block_items = then_block.dedent(IndentLevel(1));
-                let end_of_then = then_block_items.syntax().last_child_or_token().unwrap();
-                let end_of_then =
-                    if end_of_then.prev_sibling_or_token().map(|n| n.kind()) == Some(WHITESPACE) {
-                        end_of_then.prev_sibling_or_token().unwrap()
-                    } else {
-                        end_of_then
-                    };
-                let mut then_statements = new_expr.children_with_tokens().chain(
+
+            let then_block_items = then_block.dedent(IndentLevel(1)).clone_for_update();
+
+            let end_of_then = then_block_items.syntax().last_child_or_token().unwrap();
+            let end_of_then =
+                if end_of_then.prev_sibling_or_token().map(|n| n.kind()) == Some(WHITESPACE) {
+                    end_of_then.prev_sibling_or_token().unwrap()
+                } else {
+                    end_of_then
+                };
+
+            let then_statements = replacement
+                .children_with_tokens()
+                .chain(
                     then_block_items
                         .syntax()
                         .children_with_tokens()
                         .skip(1)
                         .take_while(|i| *i != end_of_then),
-                );
-                replace_children(
-                    parent_block.syntax(),
-                    RangeInclusive::new(
-                        if_expr.clone().syntax().clone().into(),
-                        if_expr.syntax().clone().into(),
-                    ),
-                    &mut then_statements,
                 )
-            }
+                .collect();
+
+            ted::replace_with_many(if_expr.syntax(), then_statements)
         },
     )
 }