]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/pull_assignment_up.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / pull_assignment_up.rs
index 543b1dfe98e04f7078cdcc26b8056031ead3abea..d142397c24f0b525df9277698d3f1afbf5e26f73 100644 (file)
@@ -1,6 +1,6 @@
 use syntax::{
-    ast::{self, edit::AstNodeEdit, make},
-    AstNode,
+    ast::{self, make},
+    ted, AstNode,
 };
 
 use crate::{
@@ -39,101 +39,115 @@ pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext) -> Opti
     let assign_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
 
     let op_kind = assign_expr.op_kind()?;
-    if op_kind != ast::BinOp::Assignment {
+    if op_kind != (ast::BinaryOp::Assignment { op: None }) {
         cov_mark::hit!(test_cant_pull_non_assignments);
         return None;
     }
 
-    let name_expr = assign_expr.lhs()?;
-
-    let old_stmt: ast::Expr;
-    let new_stmt: ast::Expr;
+    let mut collector = AssignmentsCollector {
+        sema: &ctx.sema,
+        common_lhs: assign_expr.lhs()?,
+        assignments: Vec::new(),
+    };
 
-    if let Some(if_expr) = ctx.find_node_at_offset::<ast::IfExpr>() {
-        new_stmt = exprify_if(&if_expr, &ctx.sema, &name_expr)?.indent(if_expr.indent_level());
-        old_stmt = if_expr.into();
+    let tgt: ast::Expr = if let Some(if_expr) = ctx.find_node_at_offset::<ast::IfExpr>() {
+        collector.collect_if(&if_expr)?;
+        if_expr.into()
     } else if let Some(match_expr) = ctx.find_node_at_offset::<ast::MatchExpr>() {
-        new_stmt = exprify_match(&match_expr, &ctx.sema, &name_expr)?;
-        old_stmt = match_expr.into()
+        collector.collect_match(&match_expr)?;
+        match_expr.into()
     } else {
         return None;
     };
 
-    let expr_stmt = make::expr_stmt(new_stmt);
+    if let Some(parent) = tgt.syntax().parent() {
+        if matches!(parent.kind(), syntax::SyntaxKind::BIN_EXPR | syntax::SyntaxKind::LET_STMT) {
+            return None;
+        }
+    }
 
     acc.add(
         AssistId("pull_assignment_up", AssistKind::RefactorExtract),
         "Pull assignment up",
-        old_stmt.syntax().text_range(),
+        tgt.syntax().text_range(),
         move |edit| {
-            edit.replace(old_stmt.syntax().text_range(), format!("{} = {};", name_expr, expr_stmt));
+            let assignments: Vec<_> = collector
+                .assignments
+                .into_iter()
+                .map(|(stmt, rhs)| (edit.make_mut(stmt), rhs.clone_for_update()))
+                .collect();
+
+            let tgt = edit.make_mut(tgt);
+
+            for (stmt, rhs) in assignments {
+                let mut stmt = stmt.syntax().clone();
+                if let Some(parent) = stmt.parent() {
+                    if ast::ExprStmt::cast(parent.clone()).is_some() {
+                        stmt = parent.clone();
+                    }
+                }
+                ted::replace(stmt, rhs.syntax());
+            }
+            let assign_expr = make::expr_assignment(collector.common_lhs, tgt.clone());
+            let assign_stmt = make::expr_stmt(assign_expr);
+
+            ted::replace(tgt.syntax(), assign_stmt.syntax().clone_for_update());
         },
     )
 }
 
-fn exprify_match(
-    match_expr: &ast::MatchExpr,
-    sema: &hir::Semantics<ide_db::RootDatabase>,
-    name: &ast::Expr,
-) -> Option<ast::Expr> {
-    let new_arm_list = match_expr
-        .match_arm_list()?
-        .arms()
-        .map(|arm| {
-            if let ast::Expr::BlockExpr(block) = arm.expr()? {
-                let new_block = exprify_block(&block, sema, name)?.indent(block.indent_level());
-                Some(arm.replace_descendant(block, new_block))
-            } else {
-                None
-            }
-        })
-        .collect::<Option<Vec<_>>>()?;
-    let new_arm_list = match_expr
-        .match_arm_list()?
-        .replace_descendants(match_expr.match_arm_list()?.arms().zip(new_arm_list));
-    Some(make::expr_match(match_expr.expr()?, new_arm_list))
+struct AssignmentsCollector<'a> {
+    sema: &'a hir::Semantics<'a, ide_db::RootDatabase>,
+    common_lhs: ast::Expr,
+    assignments: Vec<(ast::BinExpr, ast::Expr)>,
 }
 
-fn exprify_if(
-    statement: &ast::IfExpr,
-    sema: &hir::Semantics<ide_db::RootDatabase>,
-    name: &ast::Expr,
-) -> Option<ast::Expr> {
-    let then_branch = exprify_block(&statement.then_branch()?, sema, name)?;
-    let else_branch = match statement.else_branch()? {
-        ast::ElseBranch::Block(block) => ast::ElseBranch::Block(exprify_block(&block, sema, name)?),
-        ast::ElseBranch::IfExpr(expr) => {
-            cov_mark::hit!(test_pull_assignment_up_chained_if);
-            ast::ElseBranch::IfExpr(ast::IfExpr::cast(
-                exprify_if(&expr, sema, name)?.syntax().to_owned(),
-            )?)
+impl<'a> AssignmentsCollector<'a> {
+    fn collect_match(&mut self, match_expr: &ast::MatchExpr) -> Option<()> {
+        for arm in match_expr.match_arm_list()?.arms() {
+            match arm.expr()? {
+                ast::Expr::BlockExpr(block) => self.collect_block(&block)?,
+                ast::Expr::BinExpr(expr) => self.collect_expr(&expr)?,
+                _ => return None,
+            }
         }
-    };
-    Some(make::expr_if(statement.condition()?, then_branch, Some(else_branch)))
-}
 
-fn exprify_block(
-    block: &ast::BlockExpr,
-    sema: &hir::Semantics<ide_db::RootDatabase>,
-    name: &ast::Expr,
-) -> Option<ast::BlockExpr> {
-    if block.tail_expr().is_some() {
-        return None;
+        Some(())
+    }
+    fn collect_if(&mut self, if_expr: &ast::IfExpr) -> Option<()> {
+        let then_branch = if_expr.then_branch()?;
+        self.collect_block(&then_branch)?;
+
+        match if_expr.else_branch()? {
+            ast::ElseBranch::Block(block) => self.collect_block(&block),
+            ast::ElseBranch::IfExpr(expr) => {
+                cov_mark::hit!(test_pull_assignment_up_chained_if);
+                self.collect_if(&expr)
+            }
+        }
     }
+    fn collect_block(&mut self, block: &ast::BlockExpr) -> Option<()> {
+        let last_expr = block.tail_expr().or_else(|| match block.statements().last()? {
+            ast::Stmt::ExprStmt(stmt) => stmt.expr(),
+            ast::Stmt::Item(_) | ast::Stmt::LetStmt(_) => None,
+        })?;
+
+        if let ast::Expr::BinExpr(expr) = last_expr {
+            return self.collect_expr(&expr);
+        }
 
-    let mut stmts: Vec<_> = block.statements().collect();
-    let stmt = stmts.pop()?;
+        None
+    }
 
-    if let ast::Stmt::ExprStmt(stmt) = stmt {
-        if let ast::Expr::BinExpr(expr) = stmt.expr()? {
-            if expr.op_kind()? == ast::BinOp::Assignment && is_equivalent(sema, &expr.lhs()?, name)
-            {
-                // The last statement in the block is an assignment to the name we want
-                return Some(make::block_expr(stmts, Some(expr.rhs()?)));
-            }
+    fn collect_expr(&mut self, expr: &ast::BinExpr) -> Option<()> {
+        if expr.op_kind()? == (ast::BinaryOp::Assignment { op: None })
+            && is_equivalent(self.sema, &expr.lhs()?, &self.common_lhs)
+        {
+            self.assignments.push((expr.clone(), expr.rhs()?));
+            return Some(());
         }
+        None
     }
-    None
 }
 
 fn is_equivalent(
@@ -156,8 +170,8 @@ fn is_equivalent(
             }
         }
         (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1))
-            if prefix0.op_kind() == Some(ast::PrefixOp::Deref)
-                && prefix1.op_kind() == Some(ast::PrefixOp::Deref) =>
+            if prefix0.op_kind() == Some(ast::UnaryOp::Deref)
+                && prefix1.op_kind() == Some(ast::UnaryOp::Deref) =>
         {
             cov_mark::hit!(test_pull_assignment_up_deref);
             if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) {
@@ -242,6 +256,37 @@ fn foo() {
         );
     }
 
+    #[test]
+    fn test_pull_assignment_up_assignment_expressions() {
+        check_assist(
+            pull_assignment_up,
+            r#"
+fn foo() {
+    let mut a = 1;
+
+    match 1 {
+        1 => { $0a = 2; },
+        2 => a = 3,
+        3 => {
+            a = 4
+        }
+    }
+}"#,
+            r#"
+fn foo() {
+    let mut a = 1;
+
+    a = match 1 {
+        1 => { 2 },
+        2 => 3,
+        3 => {
+            4
+        }
+    };
+}"#,
+        );
+    }
+
     #[test]
     fn test_pull_assignment_up_not_last_not_applicable() {
         check_assist_not_applicable(