]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/replace_if_let_with_match.rs
parameters.split_last()
[rust.git] / crates / ide_assists / src / handlers / replace_if_let_with_match.rs
index 5b766ecbeae2e2c3ba95a6ee16704017e42b6d15..77909347927802a2702982a2deb6c9f70d8e319e 100644 (file)
@@ -12,7 +12,7 @@
 };
 
 use crate::{
-    utils::{does_pat_match_variant, unwrap_trivial_block},
+    utils::{does_nested_pattern, does_pat_match_variant, unwrap_trivial_block},
     AssistContext, AssistId, AssistKind, Assists,
 };
 
@@ -48,7 +48,7 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
         if_expr.syntax().text_range().start(),
         if_expr.then_branch()?.syntax().text_range().start(),
     );
-    let cursor_in_range = available_range.contains_range(ctx.frange.range);
+    let cursor_in_range = available_range.contains_range(ctx.selection_trimmed());
     if !cursor_in_range {
         return None;
     }
@@ -93,7 +93,7 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
         available_range,
         move |edit| {
             let match_expr = {
-                let else_arm = make_else_arm(else_block);
+                let else_arm = make_else_arm(ctx, else_block, &cond_bodies);
                 let make_match_arm = |(pat, body): (_, ast::BlockExpr)| {
                     let body = body.reset_indent().indent(IndentLevel(1));
                     match pat {
@@ -125,9 +125,32 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
     )
 }
 
-fn make_else_arm(else_block: Option<ast::BlockExpr>) -> ast::MatchArm {
+fn make_else_arm(
+    ctx: &AssistContext,
+    else_block: Option<ast::BlockExpr>,
+    conditionals: &[(Either<ast::Pat, ast::Expr>, ast::BlockExpr)],
+) -> ast::MatchArm {
     if let Some(else_block) = else_block {
-        let pattern = make::wildcard_pat().into();
+        let pattern = if let [(Either::Left(pat), _)] = conditionals {
+            ctx.sema
+                .type_of_pat(pat)
+                .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted()))
+                .zip(Some(pat))
+        } else {
+            None
+        };
+        let pattern = match pattern {
+            Some((it, pat)) => {
+                if does_pat_match_variant(pat, &it.sad_pattern()) {
+                    it.happy_pattern_wildcard()
+                } else if does_nested_pattern(pat) {
+                    make::wildcard_pat().into()
+                } else {
+                    it.sad_pattern()
+                }
+            }
+            None => make::wildcard_pat().into(),
+        };
         make::match_arm(iter::once(pattern), None, unwrap_trivial_block(else_block))
     } else {
         make::match_arm(iter::once(make::wildcard_pat().into()), None, make::expr_unit())
@@ -184,21 +207,23 @@ pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext)
         "Replace match with if let",
         target,
         move |edit| {
+            fn make_block_expr(expr: ast::Expr) -> ast::BlockExpr {
+                // Blocks with modifiers (unsafe, async, etc.) are parsed as BlockExpr, but are
+                // formatted without enclosing braces. If we encounter such block exprs,
+                // wrap them in another BlockExpr.
+                match expr {
+                    ast::Expr::BlockExpr(block) if block.modifier().is_none() => block,
+                    expr => make::block_expr(iter::empty(), Some(expr)),
+                }
+            }
+
             let condition = make::condition(scrutinee, Some(if_let_pat));
-            let then_block = match then_expr.reset_indent() {
-                ast::Expr::BlockExpr(block) => block,
-                expr => make::block_expr(iter::empty(), Some(expr)),
-            };
+            let then_block = make_block_expr(then_expr.reset_indent());
             let else_expr = if is_empty_expr(&else_expr) { None } else { Some(else_expr) };
             let if_let_expr = make::expr_if(
                 condition,
                 then_block,
-                else_expr
-                    .map(|expr| match expr {
-                        ast::Expr::BlockExpr(block) => block,
-                        expr => (make::block_expr(iter::empty(), Some(expr))),
-                    })
-                    .map(ast::ElseBranch::Block),
+                else_expr.map(make_block_expr).map(ast::ElseBranch::Block),
             )
             .indent(IndentLevel::from_node(match_expr.syntax()));
 
@@ -439,7 +464,7 @@ fn foo(x: Option<i32>) {
 fn foo(x: Option<i32>) {
     match x {
         Some(x) => println!("{}", x),
-        _ => println!("none"),
+        None => println!("none"),
     }
 }
 "#,
@@ -464,7 +489,7 @@ fn foo(x: Option<i32>) {
 fn foo(x: Option<i32>) {
     match x {
         None => println!("none"),
-        _ => println!("some"),
+        Some(_) => println!("some"),
     }
 }
 "#,
@@ -489,7 +514,7 @@ fn foo(x: Result<i32, ()>) {
 fn foo(x: Result<i32, ()>) {
     match x {
         Ok(x) => println!("{}", x),
-        _ => println!("none"),
+        Err(_) => println!("none"),
     }
 }
 "#,
@@ -514,7 +539,7 @@ fn foo(x: Result<i32, ()>) {
 fn foo(x: Result<i32, ()>) {
     match x {
         Err(x) => println!("{}", x),
-        _ => println!("ok"),
+        Ok(_) => println!("ok"),
     }
 }
 "#,
@@ -554,7 +579,7 @@ fn main() {
     }
 
     #[test]
-    fn replace_if_let_with_match_nested_type() {
+    fn nested_type() {
         check_assist(
             replace_if_let_with_match,
             r#"
@@ -894,4 +919,30 @@ fn foo() {
 "#,
         );
     }
+
+    #[test]
+    fn test_replace_match_with_if_let_keeps_unsafe_block() {
+        check_assist(
+            replace_match_with_if_let,
+            r#"
+impl VariantData {
+    pub fn is_struct(&self) -> bool {
+        $0match *self {
+            VariantData::Struct(..) => true,
+            _ => unsafe { unreachable_unchecked() },
+        }
+    }
+}           "#,
+            r#"
+impl VariantData {
+    pub fn is_struct(&self) -> bool {
+        if let VariantData::Struct(..) = *self {
+            true
+        } else {
+            unsafe { unreachable_unchecked() }
+        }
+    }
+}           "#,
+        )
+    }
 }