]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/add_return_type.rs
Rename some assists
[rust.git] / crates / ide_assists / src / handlers / add_return_type.rs
1 use hir::HirDisplay;
2 use syntax::{ast, AstNode, TextRange, TextSize};
3
4 use crate::{AssistContext, AssistId, AssistKind, Assists};
5
6 // Assist: add_return_type
7 //
8 // Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return
9 // type specified. This assists is useable in a functions or closures tail expression or return type position.
10 //
11 // ```
12 // fn foo() { 4$02i32 }
13 // ```
14 // ->
15 // ```
16 // fn foo() -> i32 { 42i32 }
17 // ```
18 pub(crate) fn add_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
19     let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?;
20     let module = ctx.sema.scope(tail_expr.syntax()).module()?;
21     let ty = ctx.sema.type_of_expr(&tail_expr)?.adjusted();
22     if ty.is_unit() {
23         return None;
24     }
25     let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
26
27     acc.add(
28         AssistId("add_return_type", AssistKind::RefactorRewrite),
29         match fn_type {
30             FnType::Function => "Add this function's return type",
31             FnType::Closure { .. } => "Add this closure's return type",
32         },
33         tail_expr.syntax().text_range(),
34         |builder| {
35             match builder_edit_pos {
36                 InsertOrReplace::Insert(insert_pos) => {
37                     builder.insert(insert_pos, &format!("-> {} ", ty))
38                 }
39                 InsertOrReplace::Replace(text_range) => {
40                     builder.replace(text_range, &format!("-> {}", ty))
41                 }
42             }
43             if let FnType::Closure { wrap_expr: true } = fn_type {
44                 cov_mark::hit!(wrap_closure_non_block_expr);
45                 // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block
46                 builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr));
47             }
48         },
49     )
50 }
51
52 enum InsertOrReplace {
53     Insert(TextSize),
54     Replace(TextRange),
55 }
56
57 /// Check the potentially already specified return type and reject it or turn it into a builder command
58 /// if allowed.
59 fn ret_ty_to_action(ret_ty: Option<ast::RetType>, insert_pos: TextSize) -> Option<InsertOrReplace> {
60     match ret_ty {
61         Some(ret_ty) => match ret_ty.ty() {
62             Some(ast::Type::InferType(_)) | None => {
63                 cov_mark::hit!(existing_infer_ret_type);
64                 cov_mark::hit!(existing_infer_ret_type_closure);
65                 Some(InsertOrReplace::Replace(ret_ty.syntax().text_range()))
66             }
67             _ => {
68                 cov_mark::hit!(existing_ret_type);
69                 cov_mark::hit!(existing_ret_type_closure);
70                 None
71             }
72         },
73         None => Some(InsertOrReplace::Insert(insert_pos + TextSize::from(1))),
74     }
75 }
76
77 enum FnType {
78     Function,
79     Closure { wrap_expr: bool },
80 }
81
82 fn extract_tail(ctx: &AssistContext) -> Option<(FnType, ast::Expr, InsertOrReplace)> {
83     let (fn_type, tail_expr, return_type_range, action) =
84         if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
85             let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end();
86             let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?;
87
88             let body = closure.body()?;
89             let body_start = body.syntax().first_token()?.text_range().start();
90             let (tail_expr, wrap_expr) = match body {
91                 ast::Expr::BlockExpr(block) => (block.tail_expr()?, false),
92                 body => (body, true),
93             };
94
95             let ret_range = TextRange::new(rpipe_pos, body_start);
96             (FnType::Closure { wrap_expr }, tail_expr, ret_range, action)
97         } else {
98             let func = ctx.find_node_at_offset::<ast::Fn>()?;
99             let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end();
100             let action = ret_ty_to_action(func.ret_type(), rparen_pos)?;
101
102             let body = func.body()?;
103             let tail_expr = body.tail_expr()?;
104
105             let ret_range_end = body.l_curly_token()?.text_range().start();
106             let ret_range = TextRange::new(rparen_pos, ret_range_end);
107             (FnType::Function, tail_expr, ret_range, action)
108         };
109     let frange = ctx.frange.range;
110     if return_type_range.contains_range(frange) {
111         cov_mark::hit!(cursor_in_ret_position);
112         cov_mark::hit!(cursor_in_ret_position_closure);
113     } else if tail_expr.syntax().text_range().contains_range(frange) {
114         cov_mark::hit!(cursor_on_tail);
115         cov_mark::hit!(cursor_on_tail_closure);
116     } else {
117         return None;
118     }
119     Some((fn_type, tail_expr, action))
120 }
121
122 #[cfg(test)]
123 mod tests {
124     use crate::tests::{check_assist, check_assist_not_applicable};
125
126     use super::*;
127
128     #[test]
129     fn infer_return_type_specified_inferred() {
130         cov_mark::check!(existing_infer_ret_type);
131         check_assist(
132             add_return_type,
133             r#"fn foo() -> $0_ {
134     45
135 }"#,
136             r#"fn foo() -> i32 {
137     45
138 }"#,
139         );
140     }
141
142     #[test]
143     fn infer_return_type_specified_inferred_closure() {
144         cov_mark::check!(existing_infer_ret_type_closure);
145         check_assist(
146             add_return_type,
147             r#"fn foo() {
148     || -> _ {$045};
149 }"#,
150             r#"fn foo() {
151     || -> i32 {45};
152 }"#,
153         );
154     }
155
156     #[test]
157     fn infer_return_type_cursor_at_return_type_pos() {
158         cov_mark::check!(cursor_in_ret_position);
159         check_assist(
160             add_return_type,
161             r#"fn foo() $0{
162     45
163 }"#,
164             r#"fn foo() -> i32 {
165     45
166 }"#,
167         );
168     }
169
170     #[test]
171     fn infer_return_type_cursor_at_return_type_pos_closure() {
172         cov_mark::check!(cursor_in_ret_position_closure);
173         check_assist(
174             add_return_type,
175             r#"fn foo() {
176     || $045
177 }"#,
178             r#"fn foo() {
179     || -> i32 {45}
180 }"#,
181         );
182     }
183
184     #[test]
185     fn infer_return_type() {
186         cov_mark::check!(cursor_on_tail);
187         check_assist(
188             add_return_type,
189             r#"fn foo() {
190     45$0
191 }"#,
192             r#"fn foo() -> i32 {
193     45
194 }"#,
195         );
196     }
197
198     #[test]
199     fn infer_return_type_nested() {
200         check_assist(
201             add_return_type,
202             r#"fn foo() {
203     if true {
204         3$0
205     } else {
206         5
207     }
208 }"#,
209             r#"fn foo() -> i32 {
210     if true {
211         3
212     } else {
213         5
214     }
215 }"#,
216         );
217     }
218
219     #[test]
220     fn not_applicable_ret_type_specified() {
221         cov_mark::check!(existing_ret_type);
222         check_assist_not_applicable(
223             add_return_type,
224             r#"fn foo() -> i32 {
225     ( 45$0 + 32 ) * 123
226 }"#,
227         );
228     }
229
230     #[test]
231     fn not_applicable_non_tail_expr() {
232         check_assist_not_applicable(
233             add_return_type,
234             r#"fn foo() {
235     let x = $03;
236     ( 45 + 32 ) * 123
237 }"#,
238         );
239     }
240
241     #[test]
242     fn not_applicable_unit_return_type() {
243         check_assist_not_applicable(
244             add_return_type,
245             r#"fn foo() {
246     ($0)
247 }"#,
248         );
249     }
250
251     #[test]
252     fn infer_return_type_closure_block() {
253         cov_mark::check!(cursor_on_tail_closure);
254         check_assist(
255             add_return_type,
256             r#"fn foo() {
257     |x: i32| {
258         x$0
259     };
260 }"#,
261             r#"fn foo() {
262     |x: i32| -> i32 {
263         x
264     };
265 }"#,
266         );
267     }
268
269     #[test]
270     fn infer_return_type_closure() {
271         check_assist(
272             add_return_type,
273             r#"fn foo() {
274     |x: i32| { x$0 };
275 }"#,
276             r#"fn foo() {
277     |x: i32| -> i32 { x };
278 }"#,
279         );
280     }
281
282     #[test]
283     fn infer_return_type_closure_wrap() {
284         cov_mark::check!(wrap_closure_non_block_expr);
285         check_assist(
286             add_return_type,
287             r#"fn foo() {
288     |x: i32| x$0;
289 }"#,
290             r#"fn foo() {
291     |x: i32| -> i32 {x};
292 }"#,
293         );
294     }
295
296     #[test]
297     fn infer_return_type_nested_closure() {
298         check_assist(
299             add_return_type,
300             r#"fn foo() {
301     || {
302         if true {
303             3$0
304         } else {
305             5
306         }
307     }
308 }"#,
309             r#"fn foo() {
310     || -> i32 {
311         if true {
312             3
313         } else {
314             5
315         }
316     }
317 }"#,
318         );
319     }
320
321     #[test]
322     fn not_applicable_ret_type_specified_closure() {
323         cov_mark::check!(existing_ret_type_closure);
324         check_assist_not_applicable(
325             add_return_type,
326             r#"fn foo() {
327     || -> i32 { 3$0 }
328 }"#,
329         );
330     }
331
332     #[test]
333     fn not_applicable_non_tail_expr_closure() {
334         check_assist_not_applicable(
335             add_return_type,
336             r#"fn foo() {
337     || -> i32 {
338         let x = 3$0;
339         6
340     }
341 }"#,
342         );
343     }
344 }