]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/replace_if_let_with_match.rs
parameters.split_last()
[rust.git] / crates / ide_assists / src / handlers / replace_if_let_with_match.rs
index 4c201aee6af5ba223328f88143b1a1eba8cfcee9..77909347927802a2702982a2deb6c9f70d8e319e 100644 (file)
@@ -1,18 +1,18 @@
 use std::iter::{self, successors};
 
 use either::Either;
-use ide_db::{ty_filter::TryEnum, RootDatabase};
+use ide_db::{defs::NameClass, ty_filter::TryEnum, RootDatabase};
 use syntax::{
     ast::{
         self,
         edit::{AstNodeEdit, IndentLevel},
-        make,
+        make, HasName,
     },
-    AstNode,
+    AstNode, TextRange,
 };
 
 use crate::{
-    utils::{does_pat_match_variant, unwrap_trivial_block},
+    utils::{does_nested_pattern, does_pat_match_variant, unwrap_trivial_block},
     AssistContext, AssistId, AssistKind, Assists,
 };
 
 // ```
 pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
+    let available_range = TextRange::new(
+        if_expr.syntax().text_range().start(),
+        if_expr.then_branch()?.syntax().text_range().start(),
+    );
+    let cursor_in_range = available_range.contains_range(ctx.selection_trimmed());
+    if !cursor_in_range {
+        return None;
+    }
     let mut else_block = None;
     let if_exprs = successors(Some(if_expr.clone()), |expr| match expr.else_branch()? {
         ast::ElseBranch::IfExpr(expr) => Some(expr),
@@ -79,14 +87,13 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
         return None;
     }
 
-    let target = if_expr.syntax().text_range();
     acc.add(
         AssistId("replace_if_let_with_match", AssistKind::RefactorRewrite),
         "Replace if let with match",
-        target,
+        available_range,
         move |edit| {
             let match_expr = {
-                let else_arm = make_else_arm(else_block, &cond_bodies, ctx);
+                let else_arm = make_else_arm(ctx, else_block, &cond_bodies);
                 let make_match_arm = |(pat, body): (_, ast::BlockExpr)| {
                     let body = body.reset_indent().indent(IndentLevel(1));
                     match pat {
@@ -119,23 +126,25 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
 }
 
 fn make_else_arm(
-    else_block: Option<ast::BlockExpr>,
-    cond_bodies: &Vec<(Either<ast::Pat, ast::Expr>, ast::BlockExpr)>,
     ctx: &AssistContext,
+    else_block: Option<ast::BlockExpr>,
+    conditionals: &[(Either<ast::Pat, ast::Expr>, ast::BlockExpr)],
 ) -> ast::MatchArm {
     if let Some(else_block) = else_block {
-        let pattern = if let [(Either::Left(pat), _)] = &**cond_bodies {
+        let pattern = if let [(Either::Left(pat), _)] = conditionals {
             ctx.sema
-                .type_of_pat(&pat)
-                .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty))
+                .type_of_pat(pat)
+                .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted()))
                 .zip(Some(pat))
         } else {
             None
         };
         let pattern = match pattern {
             Some((it, pat)) => {
-                if does_pat_match_variant(&pat, &it.sad_pattern()) {
-                    it.happy_pattern()
+                if does_pat_match_variant(pat, &it.sad_pattern()) {
+                    it.happy_pattern_wildcard()
+                } else if does_nested_pattern(pat) {
+                    make::wildcard_pat().into()
                 } else {
                     it.sad_pattern()
                 }
@@ -144,7 +153,7 @@ fn make_else_arm(
         };
         make::match_arm(iter::once(pattern), None, unwrap_trivial_block(else_block))
     } else {
-        make::match_arm(iter::once(make::wildcard_pat().into()), None, make::expr_unit().into())
+        make::match_arm(iter::once(make::wildcard_pat().into()), None, make::expr_unit())
     }
 }
 
@@ -198,25 +207,23 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext)
         "Replace match with if let",
         target,
         move |edit| {
+            fn make_block_expr(expr: ast::Expr) -> ast::BlockExpr {
+                // Blocks with modifiers (unsafe, async, etc.) are parsed as BlockExpr, but are
+                // formatted without enclosing braces. If we encounter such block exprs,
+                // wrap them in another BlockExpr.
+                match expr {
+                    ast::Expr::BlockExpr(block) if block.modifier().is_none() => block,
+                    expr => make::block_expr(iter::empty(), Some(expr)),
+                }
+            }
+
             let condition = make::condition(scrutinee, Some(if_let_pat));
-            let then_block = match then_expr.reset_indent() {
-                ast::Expr::BlockExpr(block) => block,
-                expr => make::block_expr(iter::empty(), Some(expr)),
-            };
-            let else_expr = match else_expr {
-                ast::Expr::BlockExpr(block) if block.is_empty() => None,
-                ast::Expr::TupleExpr(tuple) if tuple.fields().next().is_none() => None,
-                expr => Some(expr),
-            };
+            let then_block = make_block_expr(then_expr.reset_indent());
+            let else_expr = if is_empty_expr(&else_expr) { None } else { Some(else_expr) };
             let if_let_expr = make::expr_if(
                 condition,
                 then_block,
-                else_expr
-                    .map(|expr| match expr {
-                        ast::Expr::BlockExpr(block) => block,
-                        expr => (make::block_expr(iter::empty(), Some(expr))),
-                    })
-                    .map(ast::ElseBranch::Block),
+                else_expr.map(make_block_expr).map(ast::ElseBranch::Block),
             )
             .indent(IndentLevel::from_node(match_expr.syntax()));
 
@@ -235,22 +242,37 @@ fn pick_pattern_and_expr_order(
 ) -> Option<(ast::Pat, ast::Expr, ast::Expr)> {
     let res = match (pat, pat2) {
         (ast::Pat::WildcardPat(_), _) => return None,
-        (pat, sad_pat) if is_sad_pat(sema, &sad_pat) => (pat, expr, expr2),
-        (sad_pat, pat) if is_sad_pat(sema, &sad_pat) => (pat, expr2, expr),
-        (pat, pat2) => match (binds_name(&pat), binds_name(&pat2)) {
+        (pat, _) if is_empty_expr(&expr2) => (pat, expr, expr2),
+        (_, pat) if is_empty_expr(&expr) => (pat, expr2, expr),
+        (pat, pat2) => match (binds_name(sema, &pat), binds_name(sema, &pat2)) {
             (true, true) => return None,
             (true, false) => (pat, expr, expr2),
             (false, true) => (pat2, expr2, expr),
+            _ if is_sad_pat(sema, &pat) => (pat2, expr2, expr),
             (false, false) => (pat, expr, expr2),
         },
     };
     Some(res)
 }
 
-fn binds_name(pat: &ast::Pat) -> bool {
-    let binds_name_v = |pat| binds_name(&pat);
+fn is_empty_expr(expr: &ast::Expr) -> bool {
+    match expr {
+        ast::Expr::BlockExpr(expr) => match expr.stmt_list() {
+            Some(it) => it.statements().next().is_none() && it.tail_expr().is_none(),
+            None => true,
+        },
+        ast::Expr::TupleExpr(expr) => expr.fields().next().is_none(),
+        _ => false,
+    }
+}
+
+fn binds_name(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
+    let binds_name_v = |pat| binds_name(sema, &pat);
     match pat {
-        ast::Pat::IdentPat(_) => true,
+        ast::Pat::IdentPat(pat) => !matches!(
+            pat.name().and_then(|name| NameClass::classify(sema, &name)),
+            Some(NameClass::ConstReference(_))
+        ),
         ast::Pat::MacroPat(_) => true,
         ast::Pat::OrPat(pat) => pat.pats().any(binds_name_v),
         ast::Pat::SlicePat(pat) => pat.pats().any(binds_name_v),
@@ -268,7 +290,7 @@ fn binds_name(pat: &ast::Pat) -> bool {
 
 fn is_sad_pat(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
     sema.type_of_pat(pat)
-        .and_then(|ty| TryEnum::from_ty(sema, &ty))
+        .and_then(|ty| TryEnum::from_ty(sema, &ty.adjusted()))
         .map_or(false, |it| does_pat_match_variant(pat, &it.sad_pattern()))
 }
 
@@ -318,6 +340,38 @@ pub fn foo(&self) {
         )
     }
 
+    #[test]
+    fn test_if_let_with_match_available_range_left() {
+        check_assist_not_applicable(
+            replace_if_let_with_match,
+            r#"
+impl VariantData {
+    pub fn foo(&self) {
+        $0 if let VariantData::Struct(..) = *self {
+            self.foo();
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn test_if_let_with_match_available_range_right() {
+        check_assist_not_applicable(
+            replace_if_let_with_match,
+            r#"
+impl VariantData {
+    pub fn foo(&self) {
+        if let VariantData::Struct(..) = *self {$0
+            self.foo();
+        }
+    }
+}
+"#,
+        )
+    }
+
     #[test]
     fn test_if_let_with_match_basic() {
         check_assist(
@@ -524,6 +578,33 @@ fn main() {
         )
     }
 
+    #[test]
+    fn nested_type() {
+        check_assist(
+            replace_if_let_with_match,
+            r#"
+//- minicore: result
+fn foo(x: Result<i32, ()>) {
+    let bar: Result<_, ()> = Ok(Some(1));
+    $0if let Ok(Some(_)) = bar {
+        ()
+    } else {
+        ()
+    }
+}
+"#,
+            r#"
+fn foo(x: Result<i32, ()>) {
+    let bar: Result<_, ()> = Ok(Some(1));
+    match bar {
+        Ok(Some(_)) => (),
+        _ => (),
+    }
+}
+"#,
+        );
+    }
+
     #[test]
     fn test_replace_match_with_if_let_unwraps_simple_expressions() {
         check_assist(
@@ -702,6 +783,28 @@ fn main() {
         )
     }
 
+    #[test]
+    fn replace_match_with_if_let_number_body() {
+        check_assist(
+            replace_match_with_if_let,
+            r#"
+fn main() {
+    $0match Ok(()) {
+        Ok(()) => {},
+        Err(_) => 0,
+    }
+}
+"#,
+            r#"
+fn main() {
+    if let Err(_) = Ok(()) {
+        0
+    }
+}
+"#,
+        )
+    }
+
     #[test]
     fn replace_match_with_if_let_exhaustive() {
         check_assist(
@@ -762,6 +865,46 @@ fn foo() {
         );
     }
 
+    #[test]
+    fn replace_match_with_if_let_prefer_nonempty_body() {
+        check_assist(
+            replace_match_with_if_let,
+            r#"
+fn foo() {
+    match $0Ok(0) {
+        Ok(value) => {},
+        Err(err) => eprintln!("{}", err),
+    }
+}
+"#,
+            r#"
+fn foo() {
+    if let Err(err) = Ok(0) {
+        eprintln!("{}", err)
+    }
+}
+"#,
+        );
+        check_assist(
+            replace_match_with_if_let,
+            r#"
+fn foo() {
+    match $0Ok(0) {
+        Err(err) => eprintln!("{}", err),
+        Ok(value) => {},
+    }
+}
+"#,
+            r#"
+fn foo() {
+    if let Err(err) = Ok(0) {
+        eprintln!("{}", err)
+    }
+}
+"#,
+        );
+    }
+
     #[test]
     fn replace_match_with_if_let_rejects_double_name_bindings() {
         check_assist_not_applicable(
@@ -776,4 +919,30 @@ fn foo() {
 "#,
         );
     }
+
+    #[test]
+    fn test_replace_match_with_if_let_keeps_unsafe_block() {
+        check_assist(
+            replace_match_with_if_let,
+            r#"
+impl VariantData {
+    pub fn is_struct(&self) -> bool {
+        $0match *self {
+            VariantData::Struct(..) => true,
+            _ => unsafe { unreachable_unchecked() },
+        }
+    }
+}           "#,
+            r#"
+impl VariantData {
+    pub fn is_struct(&self) -> bool {
+        if let VariantData::Struct(..) = *self {
+            true
+        } else {
+            unsafe { unreachable_unchecked() }
+        }
+    }
+}           "#,
+        )
+    }
 }