]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/pull_assignment_up.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / pull_assignment_up.rs
index 28d14b9c3f14e9c4091a0efeaa25c37daaa97de1..d142397c24f0b525df9277698d3f1afbf5e26f73 100644 (file)
@@ -39,7 +39,7 @@ pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext) -> Opti
     let assign_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
 
     let op_kind = assign_expr.op_kind()?;
-    if op_kind != ast::BinOp::Assignment {
+    if op_kind != (ast::BinaryOp::Assignment { op: None }) {
         cov_mark::hit!(test_cant_pull_non_assignments);
         return None;
     }
@@ -60,6 +60,12 @@ pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext) -> Opti
         return None;
     };
 
+    if let Some(parent) = tgt.syntax().parent() {
+        if matches!(parent.kind(), syntax::SyntaxKind::BIN_EXPR | syntax::SyntaxKind::LET_STMT) {
+            return None;
+        }
+    }
+
     acc.add(
         AssistId("pull_assignment_up", AssistKind::RefactorExtract),
         "Pull assignment up",
@@ -68,13 +74,19 @@ pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext) -> Opti
             let assignments: Vec<_> = collector
                 .assignments
                 .into_iter()
-                .map(|(stmt, rhs)| (edit.make_ast_mut(stmt), rhs.clone_for_update()))
+                .map(|(stmt, rhs)| (edit.make_mut(stmt), rhs.clone_for_update()))
                 .collect();
 
-            let tgt = edit.make_ast_mut(tgt);
+            let tgt = edit.make_mut(tgt);
 
             for (stmt, rhs) in assignments {
-                ted::replace(stmt.syntax(), rhs.syntax());
+                let mut stmt = stmt.syntax().clone();
+                if let Some(parent) = stmt.parent() {
+                    if ast::ExprStmt::cast(parent.clone()).is_some() {
+                        stmt = parent.clone();
+                    }
+                }
+                ted::replace(stmt, rhs.syntax());
             }
             let assign_expr = make::expr_assignment(collector.common_lhs, tgt.clone());
             let assign_stmt = make::expr_stmt(assign_expr);
@@ -87,7 +99,7 @@ pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext) -> Opti
 struct AssignmentsCollector<'a> {
     sema: &'a hir::Semantics<'a, ide_db::RootDatabase>,
     common_lhs: ast::Expr,
-    assignments: Vec<(ast::ExprStmt, ast::Expr)>,
+    assignments: Vec<(ast::BinExpr, ast::Expr)>,
 }
 
 impl<'a> AssignmentsCollector<'a> {
@@ -95,6 +107,7 @@ fn collect_match(&mut self, match_expr: &ast::MatchExpr) -> Option<()> {
         for arm in match_expr.match_arm_list()?.arms() {
             match arm.expr()? {
                 ast::Expr::BlockExpr(block) => self.collect_block(&block)?,
+                ast::Expr::BinExpr(expr) => self.collect_expr(&expr)?,
                 _ => return None,
             }
         }
@@ -114,24 +127,27 @@ fn collect_if(&mut self, if_expr: &ast::IfExpr) -> Option<()> {
         }
     }
     fn collect_block(&mut self, block: &ast::BlockExpr) -> Option<()> {
-        if block.tail_expr().is_some() {
-            return None;
-        }
+        let last_expr = block.tail_expr().or_else(|| match block.statements().last()? {
+            ast::Stmt::ExprStmt(stmt) => stmt.expr(),
+            ast::Stmt::Item(_) | ast::Stmt::LetStmt(_) => None,
+        })?;
 
-        let last_stmt = block.statements().last()?;
-        if let ast::Stmt::ExprStmt(stmt) = last_stmt {
-            if let ast::Expr::BinExpr(expr) = stmt.expr()? {
-                if expr.op_kind()? == ast::BinOp::Assignment
-                    && is_equivalent(self.sema, &expr.lhs()?, &self.common_lhs)
-                {
-                    self.assignments.push((stmt, expr.rhs()?));
-                    return Some(());
-                }
-            }
+        if let ast::Expr::BinExpr(expr) = last_expr {
+            return self.collect_expr(&expr);
         }
 
         None
     }
+
+    fn collect_expr(&mut self, expr: &ast::BinExpr) -> Option<()> {
+        if expr.op_kind()? == (ast::BinaryOp::Assignment { op: None })
+            && is_equivalent(self.sema, &expr.lhs()?, &self.common_lhs)
+        {
+            self.assignments.push((expr.clone(), expr.rhs()?));
+            return Some(());
+        }
+        None
+    }
 }
 
 fn is_equivalent(
@@ -154,8 +170,8 @@ fn is_equivalent(
             }
         }
         (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1))
-            if prefix0.op_kind() == Some(ast::PrefixOp::Deref)
-                && prefix1.op_kind() == Some(ast::PrefixOp::Deref) =>
+            if prefix0.op_kind() == Some(ast::UnaryOp::Deref)
+                && prefix1.op_kind() == Some(ast::UnaryOp::Deref) =>
         {
             cov_mark::hit!(test_pull_assignment_up_deref);
             if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) {
@@ -241,7 +257,6 @@ fn foo() {
     }
 
     #[test]
-    #[ignore]
     fn test_pull_assignment_up_assignment_expressions() {
         check_assist(
             pull_assignment_up,