]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/add_return_type.rs
Rollup merge of #103996 - SUPERCILEX:docs, r=RalfJung
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / add_return_type.rs
1 use hir::HirDisplay;
2 use syntax::{ast, match_ast, AstNode, SyntaxKind, SyntaxToken, TextRange, TextSize};
3
4 use crate::{AssistContext, AssistId, AssistKind, Assists};
5
6 // Assist: add_return_type
7 //
8 // Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return
9 // type specified. This assists is useable in a functions or closures tail expression or return type position.
10 //
11 // ```
12 // fn foo() { 4$02i32 }
13 // ```
14 // ->
15 // ```
16 // fn foo() -> i32 { 42i32 }
17 // ```
18 pub(crate) fn add_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
19     let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?;
20     let module = ctx.sema.scope(tail_expr.syntax())?.module();
21     let ty = ctx.sema.type_of_expr(&peel_blocks(tail_expr.clone()))?.original();
22     if ty.is_unit() {
23         return None;
24     }
25     let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
26
27     acc.add(
28         AssistId("add_return_type", AssistKind::RefactorRewrite),
29         match fn_type {
30             FnType::Function => "Add this function's return type",
31             FnType::Closure { .. } => "Add this closure's return type",
32         },
33         tail_expr.syntax().text_range(),
34         |builder| {
35             match builder_edit_pos {
36                 InsertOrReplace::Insert(insert_pos, needs_whitespace) => {
37                     let preceeding_whitespace = if needs_whitespace { " " } else { "" };
38                     builder.insert(insert_pos, &format!("{preceeding_whitespace}-> {ty} "))
39                 }
40                 InsertOrReplace::Replace(text_range) => {
41                     builder.replace(text_range, &format!("-> {ty}"))
42                 }
43             }
44             if let FnType::Closure { wrap_expr: true } = fn_type {
45                 cov_mark::hit!(wrap_closure_non_block_expr);
46                 // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block
47                 builder.replace(tail_expr.syntax().text_range(), &format!("{{{tail_expr}}}"));
48             }
49         },
50     )
51 }
52
53 enum InsertOrReplace {
54     Insert(TextSize, bool),
55     Replace(TextRange),
56 }
57
58 /// Check the potentially already specified return type and reject it or turn it into a builder command
59 /// if allowed.
60 fn ret_ty_to_action(
61     ret_ty: Option<ast::RetType>,
62     insert_after: SyntaxToken,
63 ) -> Option<InsertOrReplace> {
64     match ret_ty {
65         Some(ret_ty) => match ret_ty.ty() {
66             Some(ast::Type::InferType(_)) | None => {
67                 cov_mark::hit!(existing_infer_ret_type);
68                 cov_mark::hit!(existing_infer_ret_type_closure);
69                 Some(InsertOrReplace::Replace(ret_ty.syntax().text_range()))
70             }
71             _ => {
72                 cov_mark::hit!(existing_ret_type);
73                 cov_mark::hit!(existing_ret_type_closure);
74                 None
75             }
76         },
77         None => {
78             let insert_after_pos = insert_after.text_range().end();
79             let (insert_pos, needs_whitespace) = match insert_after.next_token() {
80                 Some(it) if it.kind() == SyntaxKind::WHITESPACE => {
81                     (insert_after_pos + TextSize::from(1), false)
82                 }
83                 _ => (insert_after_pos, true),
84             };
85
86             Some(InsertOrReplace::Insert(insert_pos, needs_whitespace))
87         }
88     }
89 }
90
91 enum FnType {
92     Function,
93     Closure { wrap_expr: bool },
94 }
95
96 /// If we're looking at a block that is supposed to return `()`, type inference
97 /// will just tell us it has type `()`. We have to look at the tail expression
98 /// to see the mismatched actual type. This 'unpeels' the various blocks to
99 /// hopefully let us see the type the user intends. (This still doesn't handle
100 /// all situations fully correctly; the 'ideal' way to handle this would be to
101 /// run type inference on the function again, but with a variable as the return
102 /// type.)
103 fn peel_blocks(mut expr: ast::Expr) -> ast::Expr {
104     loop {
105         match_ast! {
106             match (expr.syntax()) {
107                 ast::BlockExpr(it) => {
108                     if let Some(tail) = it.tail_expr() {
109                         expr = tail.clone();
110                     } else {
111                         break;
112                     }
113                 },
114                 ast::IfExpr(it) => {
115                     if let Some(then_branch) = it.then_branch() {
116                         expr = ast::Expr::BlockExpr(then_branch.clone());
117                     } else {
118                         break;
119                     }
120                 },
121                 ast::MatchExpr(it) => {
122                     if let Some(arm_expr) = it.match_arm_list().and_then(|l| l.arms().next()).and_then(|a| a.expr()) {
123                         expr = arm_expr;
124                     } else {
125                         break;
126                     }
127                 },
128                 _ => break,
129             }
130         }
131     }
132     expr
133 }
134
135 fn extract_tail(ctx: &AssistContext<'_>) -> Option<(FnType, ast::Expr, InsertOrReplace)> {
136     let (fn_type, tail_expr, return_type_range, action) =
137         if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
138             let rpipe = closure.param_list()?.syntax().last_token()?;
139             let rpipe_pos = rpipe.text_range().end();
140
141             let action = ret_ty_to_action(closure.ret_type(), rpipe)?;
142
143             let body = closure.body()?;
144             let body_start = body.syntax().first_token()?.text_range().start();
145             let (tail_expr, wrap_expr) = match body {
146                 ast::Expr::BlockExpr(block) => (block.tail_expr()?, false),
147                 body => (body, true),
148             };
149
150             let ret_range = TextRange::new(rpipe_pos, body_start);
151             (FnType::Closure { wrap_expr }, tail_expr, ret_range, action)
152         } else {
153             let func = ctx.find_node_at_offset::<ast::Fn>()?;
154
155             let rparen = func.param_list()?.r_paren_token()?;
156             let rparen_pos = rparen.text_range().end();
157             let action = ret_ty_to_action(func.ret_type(), rparen)?;
158
159             let body = func.body()?;
160             let stmt_list = body.stmt_list()?;
161             let tail_expr = stmt_list.tail_expr()?;
162
163             let ret_range_end = stmt_list.l_curly_token()?.text_range().start();
164             let ret_range = TextRange::new(rparen_pos, ret_range_end);
165             (FnType::Function, tail_expr, ret_range, action)
166         };
167     let range = ctx.selection_trimmed();
168     if return_type_range.contains_range(range) {
169         cov_mark::hit!(cursor_in_ret_position);
170         cov_mark::hit!(cursor_in_ret_position_closure);
171     } else if tail_expr.syntax().text_range().contains_range(range) {
172         cov_mark::hit!(cursor_on_tail);
173         cov_mark::hit!(cursor_on_tail_closure);
174     } else {
175         return None;
176     }
177     Some((fn_type, tail_expr, action))
178 }
179
180 #[cfg(test)]
181 mod tests {
182     use crate::tests::{check_assist, check_assist_not_applicable};
183
184     use super::*;
185
186     #[test]
187     fn infer_return_type_specified_inferred() {
188         cov_mark::check!(existing_infer_ret_type);
189         check_assist(
190             add_return_type,
191             r#"fn foo() -> $0_ {
192     45
193 }"#,
194             r#"fn foo() -> i32 {
195     45
196 }"#,
197         );
198     }
199
200     #[test]
201     fn infer_return_type_specified_inferred_closure() {
202         cov_mark::check!(existing_infer_ret_type_closure);
203         check_assist(
204             add_return_type,
205             r#"fn foo() {
206     || -> _ {$045};
207 }"#,
208             r#"fn foo() {
209     || -> i32 {45};
210 }"#,
211         );
212     }
213
214     #[test]
215     fn infer_return_type_cursor_at_return_type_pos() {
216         cov_mark::check!(cursor_in_ret_position);
217         check_assist(
218             add_return_type,
219             r#"fn foo() $0{
220     45
221 }"#,
222             r#"fn foo() -> i32 {
223     45
224 }"#,
225         );
226     }
227
228     #[test]
229     fn infer_return_type_cursor_at_return_type_pos_closure() {
230         cov_mark::check!(cursor_in_ret_position_closure);
231         check_assist(
232             add_return_type,
233             r#"fn foo() {
234     || $045
235 }"#,
236             r#"fn foo() {
237     || -> i32 {45}
238 }"#,
239         );
240     }
241
242     #[test]
243     fn infer_return_type() {
244         cov_mark::check!(cursor_on_tail);
245         check_assist(
246             add_return_type,
247             r#"fn foo() {
248     45$0
249 }"#,
250             r#"fn foo() -> i32 {
251     45
252 }"#,
253         );
254     }
255
256     #[test]
257     fn infer_return_type_no_whitespace() {
258         check_assist(
259             add_return_type,
260             r#"fn foo(){
261     45$0
262 }"#,
263             r#"fn foo() -> i32 {
264     45
265 }"#,
266         );
267     }
268
269     #[test]
270     fn infer_return_type_nested() {
271         check_assist(
272             add_return_type,
273             r#"fn foo() {
274     if true {
275         3$0
276     } else {
277         5
278     }
279 }"#,
280             r#"fn foo() -> i32 {
281     if true {
282         3
283     } else {
284         5
285     }
286 }"#,
287         );
288     }
289
290     #[test]
291     fn infer_return_type_nested_match() {
292         check_assist(
293             add_return_type,
294             r#"fn foo() {
295     match true {
296         true => { 3$0 },
297         false => { 5 },
298     }
299 }"#,
300             r#"fn foo() -> i32 {
301     match true {
302         true => { 3 },
303         false => { 5 },
304     }
305 }"#,
306         );
307     }
308
309     #[test]
310     fn not_applicable_ret_type_specified() {
311         cov_mark::check!(existing_ret_type);
312         check_assist_not_applicable(
313             add_return_type,
314             r#"fn foo() -> i32 {
315     ( 45$0 + 32 ) * 123
316 }"#,
317         );
318     }
319
320     #[test]
321     fn not_applicable_non_tail_expr() {
322         check_assist_not_applicable(
323             add_return_type,
324             r#"fn foo() {
325     let x = $03;
326     ( 45 + 32 ) * 123
327 }"#,
328         );
329     }
330
331     #[test]
332     fn not_applicable_unit_return_type() {
333         check_assist_not_applicable(
334             add_return_type,
335             r#"fn foo() {
336     ($0)
337 }"#,
338         );
339     }
340
341     #[test]
342     fn infer_return_type_closure_block() {
343         cov_mark::check!(cursor_on_tail_closure);
344         check_assist(
345             add_return_type,
346             r#"fn foo() {
347     |x: i32| {
348         x$0
349     };
350 }"#,
351             r#"fn foo() {
352     |x: i32| -> i32 {
353         x
354     };
355 }"#,
356         );
357     }
358
359     #[test]
360     fn infer_return_type_closure() {
361         check_assist(
362             add_return_type,
363             r#"fn foo() {
364     |x: i32| { x$0 };
365 }"#,
366             r#"fn foo() {
367     |x: i32| -> i32 { x };
368 }"#,
369         );
370     }
371
372     #[test]
373     fn infer_return_type_closure_no_whitespace() {
374         check_assist(
375             add_return_type,
376             r#"fn foo() {
377     |x: i32|{ x$0 };
378 }"#,
379             r#"fn foo() {
380     |x: i32| -> i32 { x };
381 }"#,
382         );
383     }
384
385     #[test]
386     fn infer_return_type_closure_wrap() {
387         cov_mark::check!(wrap_closure_non_block_expr);
388         check_assist(
389             add_return_type,
390             r#"fn foo() {
391     |x: i32| x$0;
392 }"#,
393             r#"fn foo() {
394     |x: i32| -> i32 {x};
395 }"#,
396         );
397     }
398
399     #[test]
400     fn infer_return_type_nested_closure() {
401         check_assist(
402             add_return_type,
403             r#"fn foo() {
404     || {
405         if true {
406             3$0
407         } else {
408             5
409         }
410     }
411 }"#,
412             r#"fn foo() {
413     || -> i32 {
414         if true {
415             3
416         } else {
417             5
418         }
419     }
420 }"#,
421         );
422     }
423
424     #[test]
425     fn not_applicable_ret_type_specified_closure() {
426         cov_mark::check!(existing_ret_type_closure);
427         check_assist_not_applicable(
428             add_return_type,
429             r#"fn foo() {
430     || -> i32 { 3$0 }
431 }"#,
432         );
433     }
434
435     #[test]
436     fn not_applicable_non_tail_expr_closure() {
437         check_assist_not_applicable(
438             add_return_type,
439             r#"fn foo() {
440     || -> i32 {
441         let x = 3$0;
442         6
443     }
444 }"#,
445         );
446     }
447 }