]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/add_return_type.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / add_return_type.rs
1 use hir::HirDisplay;
2 use syntax::{ast, AstNode, SyntaxKind, SyntaxToken, TextRange, TextSize};
3
4 use crate::{AssistContext, AssistId, AssistKind, Assists};
5
6 // Assist: add_return_type
7 //
8 // Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return
9 // type specified. This assists is useable in a functions or closures tail expression or return type position.
10 //
11 // ```
12 // fn foo() { 4$02i32 }
13 // ```
14 // ->
15 // ```
16 // fn foo() -> i32 { 42i32 }
17 // ```
18 pub(crate) fn add_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
19     let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?;
20     let module = ctx.sema.scope(tail_expr.syntax()).module()?;
21     let ty = ctx.sema.type_of_expr(&tail_expr)?.adjusted();
22     if ty.is_unit() {
23         return None;
24     }
25     let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
26
27     acc.add(
28         AssistId("add_return_type", AssistKind::RefactorRewrite),
29         match fn_type {
30             FnType::Function => "Add this function's return type",
31             FnType::Closure { .. } => "Add this closure's return type",
32         },
33         tail_expr.syntax().text_range(),
34         |builder| {
35             match builder_edit_pos {
36                 InsertOrReplace::Insert(insert_pos, needs_whitespace) => {
37                     let preceeding_whitespace = if needs_whitespace { " " } else { "" };
38                     builder.insert(insert_pos, &format!("{}-> {} ", preceeding_whitespace, ty))
39                 }
40                 InsertOrReplace::Replace(text_range) => {
41                     builder.replace(text_range, &format!("-> {}", ty))
42                 }
43             }
44             if let FnType::Closure { wrap_expr: true } = fn_type {
45                 cov_mark::hit!(wrap_closure_non_block_expr);
46                 // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block
47                 builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr));
48             }
49         },
50     )
51 }
52
53 enum InsertOrReplace {
54     Insert(TextSize, bool),
55     Replace(TextRange),
56 }
57
58 /// Check the potentially already specified return type and reject it or turn it into a builder command
59 /// if allowed.
60 fn ret_ty_to_action(
61     ret_ty: Option<ast::RetType>,
62     insert_after: SyntaxToken,
63 ) -> Option<InsertOrReplace> {
64     match ret_ty {
65         Some(ret_ty) => match ret_ty.ty() {
66             Some(ast::Type::InferType(_)) | None => {
67                 cov_mark::hit!(existing_infer_ret_type);
68                 cov_mark::hit!(existing_infer_ret_type_closure);
69                 Some(InsertOrReplace::Replace(ret_ty.syntax().text_range()))
70             }
71             _ => {
72                 cov_mark::hit!(existing_ret_type);
73                 cov_mark::hit!(existing_ret_type_closure);
74                 None
75             }
76         },
77         None => {
78             let insert_after_pos = insert_after.text_range().end();
79             let (insert_pos, needs_whitespace) = match insert_after.next_token() {
80                 Some(it) if it.kind() == SyntaxKind::WHITESPACE => {
81                     (insert_after_pos + TextSize::from(1), false)
82                 }
83                 _ => (insert_after_pos, true),
84             };
85
86             Some(InsertOrReplace::Insert(insert_pos, needs_whitespace))
87         }
88     }
89 }
90
91 enum FnType {
92     Function,
93     Closure { wrap_expr: bool },
94 }
95
96 fn extract_tail(ctx: &AssistContext) -> Option<(FnType, ast::Expr, InsertOrReplace)> {
97     let (fn_type, tail_expr, return_type_range, action) =
98         if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
99             let rpipe = closure.param_list()?.syntax().last_token()?;
100             let rpipe_pos = rpipe.text_range().end();
101
102             let action = ret_ty_to_action(closure.ret_type(), rpipe)?;
103
104             let body = closure.body()?;
105             let body_start = body.syntax().first_token()?.text_range().start();
106             let (tail_expr, wrap_expr) = match body {
107                 ast::Expr::BlockExpr(block) => (block.tail_expr()?, false),
108                 body => (body, true),
109             };
110
111             let ret_range = TextRange::new(rpipe_pos, body_start);
112             (FnType::Closure { wrap_expr }, tail_expr, ret_range, action)
113         } else {
114             let func = ctx.find_node_at_offset::<ast::Fn>()?;
115
116             let rparen = func.param_list()?.r_paren_token()?;
117             let rparen_pos = rparen.text_range().end();
118             let action = ret_ty_to_action(func.ret_type(), rparen)?;
119
120             let body = func.body()?;
121             let stmt_list = body.stmt_list()?;
122             let tail_expr = stmt_list.tail_expr()?;
123
124             let ret_range_end = stmt_list.l_curly_token()?.text_range().start();
125             let ret_range = TextRange::new(rparen_pos, ret_range_end);
126             (FnType::Function, tail_expr, ret_range, action)
127         };
128     let range = ctx.selection_trimmed();
129     if return_type_range.contains_range(range) {
130         cov_mark::hit!(cursor_in_ret_position);
131         cov_mark::hit!(cursor_in_ret_position_closure);
132     } else if tail_expr.syntax().text_range().contains_range(range) {
133         cov_mark::hit!(cursor_on_tail);
134         cov_mark::hit!(cursor_on_tail_closure);
135     } else {
136         return None;
137     }
138     Some((fn_type, tail_expr, action))
139 }
140
141 #[cfg(test)]
142 mod tests {
143     use crate::tests::{check_assist, check_assist_not_applicable};
144
145     use super::*;
146
147     #[test]
148     fn infer_return_type_specified_inferred() {
149         cov_mark::check!(existing_infer_ret_type);
150         check_assist(
151             add_return_type,
152             r#"fn foo() -> $0_ {
153     45
154 }"#,
155             r#"fn foo() -> i32 {
156     45
157 }"#,
158         );
159     }
160
161     #[test]
162     fn infer_return_type_specified_inferred_closure() {
163         cov_mark::check!(existing_infer_ret_type_closure);
164         check_assist(
165             add_return_type,
166             r#"fn foo() {
167     || -> _ {$045};
168 }"#,
169             r#"fn foo() {
170     || -> i32 {45};
171 }"#,
172         );
173     }
174
175     #[test]
176     fn infer_return_type_cursor_at_return_type_pos() {
177         cov_mark::check!(cursor_in_ret_position);
178         check_assist(
179             add_return_type,
180             r#"fn foo() $0{
181     45
182 }"#,
183             r#"fn foo() -> i32 {
184     45
185 }"#,
186         );
187     }
188
189     #[test]
190     fn infer_return_type_cursor_at_return_type_pos_closure() {
191         cov_mark::check!(cursor_in_ret_position_closure);
192         check_assist(
193             add_return_type,
194             r#"fn foo() {
195     || $045
196 }"#,
197             r#"fn foo() {
198     || -> i32 {45}
199 }"#,
200         );
201     }
202
203     #[test]
204     fn infer_return_type() {
205         cov_mark::check!(cursor_on_tail);
206         check_assist(
207             add_return_type,
208             r#"fn foo() {
209     45$0
210 }"#,
211             r#"fn foo() -> i32 {
212     45
213 }"#,
214         );
215     }
216
217     #[test]
218     fn infer_return_type_no_whitespace() {
219         check_assist(
220             add_return_type,
221             r#"fn foo(){
222     45$0
223 }"#,
224             r#"fn foo() -> i32 {
225     45
226 }"#,
227         );
228     }
229
230     #[test]
231     fn infer_return_type_nested() {
232         check_assist(
233             add_return_type,
234             r#"fn foo() {
235     if true {
236         3$0
237     } else {
238         5
239     }
240 }"#,
241             r#"fn foo() -> i32 {
242     if true {
243         3
244     } else {
245         5
246     }
247 }"#,
248         );
249     }
250
251     #[test]
252     fn not_applicable_ret_type_specified() {
253         cov_mark::check!(existing_ret_type);
254         check_assist_not_applicable(
255             add_return_type,
256             r#"fn foo() -> i32 {
257     ( 45$0 + 32 ) * 123
258 }"#,
259         );
260     }
261
262     #[test]
263     fn not_applicable_non_tail_expr() {
264         check_assist_not_applicable(
265             add_return_type,
266             r#"fn foo() {
267     let x = $03;
268     ( 45 + 32 ) * 123
269 }"#,
270         );
271     }
272
273     #[test]
274     fn not_applicable_unit_return_type() {
275         check_assist_not_applicable(
276             add_return_type,
277             r#"fn foo() {
278     ($0)
279 }"#,
280         );
281     }
282
283     #[test]
284     fn infer_return_type_closure_block() {
285         cov_mark::check!(cursor_on_tail_closure);
286         check_assist(
287             add_return_type,
288             r#"fn foo() {
289     |x: i32| {
290         x$0
291     };
292 }"#,
293             r#"fn foo() {
294     |x: i32| -> i32 {
295         x
296     };
297 }"#,
298         );
299     }
300
301     #[test]
302     fn infer_return_type_closure() {
303         check_assist(
304             add_return_type,
305             r#"fn foo() {
306     |x: i32| { x$0 };
307 }"#,
308             r#"fn foo() {
309     |x: i32| -> i32 { x };
310 }"#,
311         );
312     }
313
314     #[test]
315     fn infer_return_type_closure_no_whitespace() {
316         check_assist(
317             add_return_type,
318             r#"fn foo() {
319     |x: i32|{ x$0 };
320 }"#,
321             r#"fn foo() {
322     |x: i32| -> i32 { x };
323 }"#,
324         );
325     }
326
327     #[test]
328     fn infer_return_type_closure_wrap() {
329         cov_mark::check!(wrap_closure_non_block_expr);
330         check_assist(
331             add_return_type,
332             r#"fn foo() {
333     |x: i32| x$0;
334 }"#,
335             r#"fn foo() {
336     |x: i32| -> i32 {x};
337 }"#,
338         );
339     }
340
341     #[test]
342     fn infer_return_type_nested_closure() {
343         check_assist(
344             add_return_type,
345             r#"fn foo() {
346     || {
347         if true {
348             3$0
349         } else {
350             5
351         }
352     }
353 }"#,
354             r#"fn foo() {
355     || -> i32 {
356         if true {
357             3
358         } else {
359             5
360         }
361     }
362 }"#,
363         );
364     }
365
366     #[test]
367     fn not_applicable_ret_type_specified_closure() {
368         cov_mark::check!(existing_ret_type_closure);
369         check_assist_not_applicable(
370             add_return_type,
371             r#"fn foo() {
372     || -> i32 { 3$0 }
373 }"#,
374         );
375     }
376
377     #[test]
378     fn not_applicable_non_tail_expr_closure() {
379         check_assist_not_applicable(
380             add_return_type,
381             r#"fn foo() {
382     || -> i32 {
383         let x = 3$0;
384         6
385     }
386 }"#,
387         );
388     }
389 }