]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/add_return_type.rs
Merge #11842
[rust.git] / crates / ide_assists / src / handlers / add_return_type.rs
index 2c5b61eddb769c06ddfa3c2eb4ff9ad477723433..c7172741e46de527b0f111336768fe7286bd5a6a 100644 (file)
@@ -1,5 +1,5 @@
 use hir::HirDisplay;
-use syntax::{ast, AstNode, SyntaxKind, SyntaxToken, TextRange, TextSize};
+use syntax::{ast, match_ast, AstNode, SyntaxKind, SyntaxToken, TextRange, TextSize};
 
 use crate::{AssistContext, AssistId, AssistKind, Assists};
 
@@ -18,7 +18,7 @@
 pub(crate) fn add_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
     let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?;
     let module = ctx.sema.scope(tail_expr.syntax()).module()?;
-    let ty = ctx.sema.type_of_expr(&tail_expr)?.adjusted();
+    let ty = ctx.sema.type_of_expr(&peel_blocks(tail_expr.clone()))?.original();
     if ty.is_unit() {
         return None;
     }
@@ -93,6 +93,45 @@ enum FnType {
     Closure { wrap_expr: bool },
 }
 
+/// If we're looking at a block that is supposed to return `()`, type inference
+/// will just tell us it has type `()`. We have to look at the tail expression
+/// to see the mismatched actual type. This 'unpeels' the various blocks to
+/// hopefully let us see the type the user intends. (This still doesn't handle
+/// all situations fully correctly; the 'ideal' way to handle this would be to
+/// run type inference on the function again, but with a variable as the return
+/// type.)
+fn peel_blocks(mut expr: ast::Expr) -> ast::Expr {
+    loop {
+        match_ast! {
+            match (expr.syntax()) {
+                ast::BlockExpr(it) => {
+                    if let Some(tail) = it.tail_expr() {
+                        expr = tail.clone();
+                    } else {
+                        break;
+                    }
+                },
+                ast::IfExpr(it) => {
+                    if let Some(then_branch) = it.then_branch() {
+                        expr = ast::Expr::BlockExpr(then_branch.clone());
+                    } else {
+                        break;
+                    }
+                },
+                ast::MatchExpr(it) => {
+                    if let Some(arm_expr) = it.match_arm_list().and_then(|l| l.arms().next()).and_then(|a| a.expr()) {
+                        expr = arm_expr;
+                    } else {
+                        break;
+                    }
+                },
+                _ => break,
+            }
+        }
+    }
+    expr
+}
+
 fn extract_tail(ctx: &AssistContext) -> Option<(FnType, ast::Expr, InsertOrReplace)> {
     let (fn_type, tail_expr, return_type_range, action) =
         if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
@@ -248,6 +287,25 @@ fn infer_return_type_nested() {
         );
     }
 
+    #[test]
+    fn infer_return_type_nested_match() {
+        check_assist(
+            add_return_type,
+            r#"fn foo() {
+    match true {
+        true => { 3$0 },
+        false => { 5 },
+    }
+}"#,
+            r#"fn foo() -> i32 {
+    match true {
+        true => { 3 },
+        false => { 5 },
+    }
+}"#,
+        );
+    }
+
     #[test]
     fn not_applicable_ret_type_specified() {
         cov_mark::check!(existing_ret_type);