]> git.lizzy.rs Git - rust.git/commitdiff
Support closures in infer_function_return_type assist
authorLukas Wirth <lukastw97@gmail.com>
Fri, 6 Nov 2020 01:13:29 +0000 (02:13 +0100)
committerLukas Wirth <lukastw97@gmail.com>
Fri, 6 Nov 2020 01:15:13 +0000 (02:15 +0100)
crates/assists/src/handlers/infer_function_return_type.rs

index da60ff9dedf9311088aab34618f689d5bd844aba..f363a56f3220b47445e7b3eb2323286efe967807 100644 (file)
@@ -1,12 +1,12 @@
 use hir::HirDisplay;
-use syntax::{ast, AstNode, TextSize};
+use syntax::{ast, AstNode, SyntaxToken, TextSize};
 use test_utils::mark;
 
 use crate::{AssistContext, AssistId, AssistKind, Assists};
 
 // Assist: infer_function_return_type
 //
-// Adds the return type to a function inferred from its tail expression if it doesn't have a return
+// Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return
 // type specified.
 //
 // ```
 // ```
 pub(crate) fn infer_function_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
     let expr = ctx.find_node_at_offset::<ast::Expr>()?;
-    let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?;
-
-    if func.ret_type().is_some() {
-        mark::hit!(existing_ret_type);
-        return None;
-    }
-    let body = func.body()?;
-    let tail_expr = body.expr()?;
-    // check whether the expr we were at is indeed the tail expression
-    if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) {
-        mark::hit!(not_tail_expr);
-        return None;
-    }
-    let module = ctx.sema.scope(func.syntax()).module()?;
+    let (tail_expr, insert_pos) = extract(expr)?;
+    let module = ctx.sema.scope(tail_expr.syntax()).module()?;
     let ty = ctx.sema.type_of_expr(&tail_expr)?;
     let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
-    let rparen = func.param_list()?.r_paren_token()?;
 
     acc.add(
         AssistId("change_return_type_to_result", AssistKind::RefactorRewrite),
         "Wrap return type in Result",
         tail_expr.syntax().text_range(),
         |builder| {
-            let insert_pos = rparen.text_range().end() + TextSize::from(1);
-
+            let insert_pos = insert_pos.text_range().end() + TextSize::from(1);
             builder.insert(insert_pos, &format!("-> {} ", ty));
         },
     )
 }
 
+fn extract(expr: ast::Expr) -> Option<(ast::Expr, SyntaxToken)> {
+    let (ret_ty, tail_expr, insert_pos) =
+        if let Some(closure) = expr.syntax().ancestors().find_map(ast::ClosureExpr::cast) {
+            let tail_expr = match closure.body()? {
+                ast::Expr::BlockExpr(block) => block.expr()?,
+                body => body,
+            };
+            let ret_ty = closure.ret_type();
+            let rpipe = closure.param_list()?.syntax().last_token()?;
+            (ret_ty, tail_expr, rpipe)
+        } else {
+            let func = expr.syntax().ancestors().find_map(ast::Fn::cast)?;
+            let tail_expr = func.body()?.expr()?;
+            let ret_ty = func.ret_type();
+            let rparen = func.param_list()?.r_paren_token()?;
+            (ret_ty, tail_expr, rparen)
+        };
+    if ret_ty.is_some() {
+        mark::hit!(existing_ret_type);
+        mark::hit!(existing_ret_type_closure);
+        return None;
+    }
+    // check whether the expr we were at is indeed the tail expression
+    if !tail_expr.syntax().text_range().contains_range(expr.syntax().text_range()) {
+        mark::hit!(not_tail_expr);
+        return None;
+    }
+    Some((tail_expr, insert_pos))
+}
+
 #[cfg(test)]
 mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable};
@@ -110,4 +126,83 @@ fn not_applicable_non_tail_expr() {
             }"#,
         );
     }
+
+    #[test]
+    fn infer_return_type_closure_block() {
+        check_assist(
+            infer_function_return_type,
+            r#"fn foo() {
+                |x: i32| {
+                    x<|>
+                };
+            }"#,
+            r#"fn foo() {
+                |x: i32| -> i32 {
+                    x
+                };
+            }"#,
+        );
+    }
+
+    #[test]
+    fn infer_return_type_closure() {
+        check_assist(
+            infer_function_return_type,
+            r#"fn foo() {
+                |x: i32| x<|>;
+            }"#,
+            r#"fn foo() {
+                |x: i32| -> i32 x;
+            }"#,
+        );
+    }
+
+    #[test]
+    fn infer_return_type_nested_closure() {
+        check_assist(
+            infer_function_return_type,
+            r#"fn foo() {
+                || {
+                    if true {
+                        3<|>
+                    } else {
+                        5
+                    }
+                }
+            }"#,
+            r#"fn foo() {
+                || -> i32 {
+                    if true {
+                        3
+                    } else {
+                        5
+                    }
+                }
+            }"#,
+        );
+    }
+
+    #[test]
+    fn not_applicable_ret_type_specified_closure() {
+        mark::check!(existing_ret_type_closure);
+        check_assist_not_applicable(
+            infer_function_return_type,
+            r#"fn foo() {
+                || -> i32 { 3<|> }
+            }"#,
+        );
+    }
+
+    #[test]
+    fn not_applicable_non_tail_expr_closure() {
+        check_assist_not_applicable(
+            infer_function_return_type,
+            r#"fn foo() {
+                || -> i32 {
+                    let x = 3<|>;
+                    6
+                }
+            }"#,
+        );
+    }
 }