2 use syntax::{ast, match_ast, AstNode, SyntaxKind, SyntaxToken, TextRange, TextSize};
4 use crate::{AssistContext, AssistId, AssistKind, Assists};
6 // Assist: add_return_type
// Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return
// type specified. This assist is usable in a function's or closure's tail expression or return type position.
12 // fn foo() { 4$02i32 }
16 // fn foo() -> i32 { 42i32 }
18 pub(crate) fn add_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
19 let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?;
20 let module = ctx.sema.scope(tail_expr.syntax())?.module();
21 let ty = ctx.sema.type_of_expr(&peel_blocks(tail_expr.clone()))?.original();
25 let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
28 AssistId("add_return_type", AssistKind::RefactorRewrite),
30 FnType::Function => "Add this function's return type",
31 FnType::Closure { .. } => "Add this closure's return type",
33 tail_expr.syntax().text_range(),
35 match builder_edit_pos {
36 InsertOrReplace::Insert(insert_pos, needs_whitespace) => {
37 let preceeding_whitespace = if needs_whitespace { " " } else { "" };
38 builder.insert(insert_pos, &format!("{}-> {} ", preceeding_whitespace, ty))
40 InsertOrReplace::Replace(text_range) => {
41 builder.replace(text_range, &format!("-> {}", ty))
44 if let FnType::Closure { wrap_expr: true } = fn_type {
45 cov_mark::hit!(wrap_closure_non_block_expr);
46 // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block
47 builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr));
53 enum InsertOrReplace {
54 Insert(TextSize, bool),
58 /// Check the potentially already specified return type and reject it or turn it into a builder command
61 ret_ty: Option<ast::RetType>,
62 insert_after: SyntaxToken,
63 ) -> Option<InsertOrReplace> {
65 Some(ret_ty) => match ret_ty.ty() {
66 Some(ast::Type::InferType(_)) | None => {
67 cov_mark::hit!(existing_infer_ret_type);
68 cov_mark::hit!(existing_infer_ret_type_closure);
69 Some(InsertOrReplace::Replace(ret_ty.syntax().text_range()))
72 cov_mark::hit!(existing_ret_type);
73 cov_mark::hit!(existing_ret_type_closure);
78 let insert_after_pos = insert_after.text_range().end();
79 let (insert_pos, needs_whitespace) = match insert_after.next_token() {
80 Some(it) if it.kind() == SyntaxKind::WHITESPACE => {
81 (insert_after_pos + TextSize::from(1), false)
83 _ => (insert_after_pos, true),
86 Some(InsertOrReplace::Insert(insert_pos, needs_whitespace))
/// The kind of item whose return type is being added.
enum FnType {
    /// A `fn` item.
    Function,
    /// A closure expression. `wrap_expr` is `true` when the closure body is not a block
    /// expression, so it has to be wrapped in braces once a return type is written.
    Closure { wrap_expr: bool },
}
96 /// If we're looking at a block that is supposed to return `()`, type inference
97 /// will just tell us it has type `()`. We have to look at the tail expression
98 /// to see the mismatched actual type. This 'unpeels' the various blocks to
99 /// hopefully let us see the type the user intends. (This still doesn't handle
100 /// all situations fully correctly; the 'ideal' way to handle this would be to
101 /// run type inference on the function again, but with a variable as the return
103 fn peel_blocks(mut expr: ast::Expr) -> ast::Expr {
106 match (expr.syntax()) {
107 ast::BlockExpr(it) => {
108 if let Some(tail) = it.tail_expr() {
115 if let Some(then_branch) = it.then_branch() {
116 expr = ast::Expr::BlockExpr(then_branch.clone());
121 ast::MatchExpr(it) => {
122 if let Some(arm_expr) = it.match_arm_list().and_then(|l| l.arms().next()).and_then(|a| a.expr()) {
135 fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrReplace)> {
136 let (fn_type, tail_expr, return_type_range, action) =
137 if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
138 let rpipe = closure.param_list()?.syntax().last_token()?;
139 let rpipe_pos = rpipe.text_range().end();
141 let action = ret_ty_to_action(closure.ret_type(), rpipe)?;
143 let body = closure.body()?;
144 let body_start = body.syntax().first_token()?.text_range().start();
145 let (tail_expr, wrap_expr) = match body {
146 ast::Expr::BlockExpr(block) => (block.tail_expr()?, false),
147 body => (body, true),
150 let ret_range = TextRange::new(rpipe_pos, body_start);
151 (FnType::Closure { wrap_expr }, tail_expr, ret_range, action)
153 let func = ctx.find_node_at_offset::<ast::Fn>()?;
155 let rparen = func.param_list()?.r_paren_token()?;
156 let rparen_pos = rparen.text_range().end();
157 let action = ret_ty_to_action(func.ret_type(), rparen)?;
159 let body = func.body()?;
160 let stmt_list = body.stmt_list()?;
161 let tail_expr = stmt_list.tail_expr()?;
163 let ret_range_end = stmt_list.l_curly_token()?.text_range().start();
164 let ret_range = TextRange::new(rparen_pos, ret_range_end);
165 (FnType::Function, tail_expr, ret_range, action)
167 let range = ctx.selection_trimmed();
168 if return_type_range.contains_range(range) {
169 cov_mark::hit!(cursor_in_ret_position);
170 cov_mark::hit!(cursor_in_ret_position_closure);
171 } else if tail_expr.syntax().text_range().contains_range(range) {
172 cov_mark::hit!(cursor_on_tail);
173 cov_mark::hit!(cursor_on_tail_closure);
177 Some((fn_type, tail_expr, action))
182 use crate::tests::{check_assist, check_assist_not_applicable};
187 fn infer_return_type_specified_inferred() {
188 cov_mark::check!(existing_infer_ret_type);
201 fn infer_return_type_specified_inferred_closure() {
202 cov_mark::check!(existing_infer_ret_type_closure);
215 fn infer_return_type_cursor_at_return_type_pos() {
216 cov_mark::check!(cursor_in_ret_position);
229 fn infer_return_type_cursor_at_return_type_pos_closure() {
230 cov_mark::check!(cursor_in_ret_position_closure);
243 fn infer_return_type() {
244 cov_mark::check!(cursor_on_tail);
257 fn infer_return_type_no_whitespace() {
270 fn infer_return_type_nested() {
291 fn infer_return_type_nested_match() {
310 fn not_applicable_ret_type_specified() {
311 cov_mark::check!(existing_ret_type);
312 check_assist_not_applicable(
321 fn not_applicable_non_tail_expr() {
322 check_assist_not_applicable(
332 fn not_applicable_unit_return_type() {
333 check_assist_not_applicable(
342 fn infer_return_type_closure_block() {
343 cov_mark::check!(cursor_on_tail_closure);
360 fn infer_return_type_closure() {
367 |x: i32| -> i32 { x };
373 fn infer_return_type_closure_no_whitespace() {
380 |x: i32| -> i32 { x };
386 fn infer_return_type_closure_wrap() {
387 cov_mark::check!(wrap_closure_non_block_expr);
400 fn infer_return_type_nested_closure() {
425 fn not_applicable_ret_type_specified_closure() {
426 cov_mark::check!(existing_ret_type_closure);
427 check_assist_not_applicable(
436 fn not_applicable_non_tail_expr_closure() {
437 check_assist_not_applicable(