]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/add_return_type.rs
internal: more reasonable grammar for blocks
[rust.git] / crates / ide_assists / src / handlers / add_return_type.rs
1 use hir::HirDisplay;
2 use syntax::{ast, AstNode, TextRange, TextSize};
3
4 use crate::{AssistContext, AssistId, AssistKind, Assists};
5
6 // Assist: add_return_type
7 //
8 // Adds the return type to a function or closure inferred from its tail expression if it doesn't have a return
9 // type specified. This assists is useable in a functions or closures tail expression or return type position.
10 //
11 // ```
12 // fn foo() { 4$02i32 }
13 // ```
14 // ->
15 // ```
16 // fn foo() -> i32 { 42i32 }
17 // ```
18 pub(crate) fn add_return_type(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
19     let (fn_type, tail_expr, builder_edit_pos) = extract_tail(ctx)?;
20     let module = ctx.sema.scope(tail_expr.syntax()).module()?;
21     let ty = ctx.sema.type_of_expr(&tail_expr)?.adjusted();
22     if ty.is_unit() {
23         return None;
24     }
25     let ty = ty.display_source_code(ctx.db(), module.into()).ok()?;
26
27     acc.add(
28         AssistId("add_return_type", AssistKind::RefactorRewrite),
29         match fn_type {
30             FnType::Function => "Add this function's return type",
31             FnType::Closure { .. } => "Add this closure's return type",
32         },
33         tail_expr.syntax().text_range(),
34         |builder| {
35             match builder_edit_pos {
36                 InsertOrReplace::Insert(insert_pos) => {
37                     builder.insert(insert_pos, &format!("-> {} ", ty))
38                 }
39                 InsertOrReplace::Replace(text_range) => {
40                     builder.replace(text_range, &format!("-> {}", ty))
41                 }
42             }
43             if let FnType::Closure { wrap_expr: true } = fn_type {
44                 cov_mark::hit!(wrap_closure_non_block_expr);
45                 // `|x| x` becomes `|x| -> T x` which is invalid, so wrap it in a block
46                 builder.replace(tail_expr.syntax().text_range(), &format!("{{{}}}", tail_expr));
47             }
48         },
49     )
50 }
51
52 enum InsertOrReplace {
53     Insert(TextSize),
54     Replace(TextRange),
55 }
56
57 /// Check the potentially already specified return type and reject it or turn it into a builder command
58 /// if allowed.
59 fn ret_ty_to_action(ret_ty: Option<ast::RetType>, insert_pos: TextSize) -> Option<InsertOrReplace> {
60     match ret_ty {
61         Some(ret_ty) => match ret_ty.ty() {
62             Some(ast::Type::InferType(_)) | None => {
63                 cov_mark::hit!(existing_infer_ret_type);
64                 cov_mark::hit!(existing_infer_ret_type_closure);
65                 Some(InsertOrReplace::Replace(ret_ty.syntax().text_range()))
66             }
67             _ => {
68                 cov_mark::hit!(existing_ret_type);
69                 cov_mark::hit!(existing_ret_type_closure);
70                 None
71             }
72         },
73         None => Some(InsertOrReplace::Insert(insert_pos + TextSize::from(1))),
74     }
75 }
76
77 enum FnType {
78     Function,
79     Closure { wrap_expr: bool },
80 }
81
82 fn extract_tail(ctx: &AssistContext) -> Option<(FnType, ast::Expr, InsertOrReplace)> {
83     let (fn_type, tail_expr, return_type_range, action) =
84         if let Some(closure) = ctx.find_node_at_offset::<ast::ClosureExpr>() {
85             let rpipe_pos = closure.param_list()?.syntax().last_token()?.text_range().end();
86             let action = ret_ty_to_action(closure.ret_type(), rpipe_pos)?;
87
88             let body = closure.body()?;
89             let body_start = body.syntax().first_token()?.text_range().start();
90             let (tail_expr, wrap_expr) = match body {
91                 ast::Expr::BlockExpr(block) => (block.tail_expr()?, false),
92                 body => (body, true),
93             };
94
95             let ret_range = TextRange::new(rpipe_pos, body_start);
96             (FnType::Closure { wrap_expr }, tail_expr, ret_range, action)
97         } else {
98             let func = ctx.find_node_at_offset::<ast::Fn>()?;
99             let rparen_pos = func.param_list()?.r_paren_token()?.text_range().end();
100             let action = ret_ty_to_action(func.ret_type(), rparen_pos)?;
101
102             let body = func.body()?;
103             let stmt_list = body.stmt_list()?;
104             let tail_expr = stmt_list.tail_expr()?;
105
106             let ret_range_end = stmt_list.l_curly_token()?.text_range().start();
107             let ret_range = TextRange::new(rparen_pos, ret_range_end);
108             (FnType::Function, tail_expr, ret_range, action)
109         };
110     let frange = ctx.frange.range;
111     if return_type_range.contains_range(frange) {
112         cov_mark::hit!(cursor_in_ret_position);
113         cov_mark::hit!(cursor_in_ret_position_closure);
114     } else if tail_expr.syntax().text_range().contains_range(frange) {
115         cov_mark::hit!(cursor_on_tail);
116         cov_mark::hit!(cursor_on_tail_closure);
117     } else {
118         return None;
119     }
120     Some((fn_type, tail_expr, action))
121 }
122
123 #[cfg(test)]
124 mod tests {
125     use crate::tests::{check_assist, check_assist_not_applicable};
126
127     use super::*;
128
129     #[test]
130     fn infer_return_type_specified_inferred() {
131         cov_mark::check!(existing_infer_ret_type);
132         check_assist(
133             add_return_type,
134             r#"fn foo() -> $0_ {
135     45
136 }"#,
137             r#"fn foo() -> i32 {
138     45
139 }"#,
140         );
141     }
142
143     #[test]
144     fn infer_return_type_specified_inferred_closure() {
145         cov_mark::check!(existing_infer_ret_type_closure);
146         check_assist(
147             add_return_type,
148             r#"fn foo() {
149     || -> _ {$045};
150 }"#,
151             r#"fn foo() {
152     || -> i32 {45};
153 }"#,
154         );
155     }
156
157     #[test]
158     fn infer_return_type_cursor_at_return_type_pos() {
159         cov_mark::check!(cursor_in_ret_position);
160         check_assist(
161             add_return_type,
162             r#"fn foo() $0{
163     45
164 }"#,
165             r#"fn foo() -> i32 {
166     45
167 }"#,
168         );
169     }
170
171     #[test]
172     fn infer_return_type_cursor_at_return_type_pos_closure() {
173         cov_mark::check!(cursor_in_ret_position_closure);
174         check_assist(
175             add_return_type,
176             r#"fn foo() {
177     || $045
178 }"#,
179             r#"fn foo() {
180     || -> i32 {45}
181 }"#,
182         );
183     }
184
185     #[test]
186     fn infer_return_type() {
187         cov_mark::check!(cursor_on_tail);
188         check_assist(
189             add_return_type,
190             r#"fn foo() {
191     45$0
192 }"#,
193             r#"fn foo() -> i32 {
194     45
195 }"#,
196         );
197     }
198
199     #[test]
200     fn infer_return_type_nested() {
201         check_assist(
202             add_return_type,
203             r#"fn foo() {
204     if true {
205         3$0
206     } else {
207         5
208     }
209 }"#,
210             r#"fn foo() -> i32 {
211     if true {
212         3
213     } else {
214         5
215     }
216 }"#,
217         );
218     }
219
220     #[test]
221     fn not_applicable_ret_type_specified() {
222         cov_mark::check!(existing_ret_type);
223         check_assist_not_applicable(
224             add_return_type,
225             r#"fn foo() -> i32 {
226     ( 45$0 + 32 ) * 123
227 }"#,
228         );
229     }
230
231     #[test]
232     fn not_applicable_non_tail_expr() {
233         check_assist_not_applicable(
234             add_return_type,
235             r#"fn foo() {
236     let x = $03;
237     ( 45 + 32 ) * 123
238 }"#,
239         );
240     }
241
242     #[test]
243     fn not_applicable_unit_return_type() {
244         check_assist_not_applicable(
245             add_return_type,
246             r#"fn foo() {
247     ($0)
248 }"#,
249         );
250     }
251
252     #[test]
253     fn infer_return_type_closure_block() {
254         cov_mark::check!(cursor_on_tail_closure);
255         check_assist(
256             add_return_type,
257             r#"fn foo() {
258     |x: i32| {
259         x$0
260     };
261 }"#,
262             r#"fn foo() {
263     |x: i32| -> i32 {
264         x
265     };
266 }"#,
267         );
268     }
269
270     #[test]
271     fn infer_return_type_closure() {
272         check_assist(
273             add_return_type,
274             r#"fn foo() {
275     |x: i32| { x$0 };
276 }"#,
277             r#"fn foo() {
278     |x: i32| -> i32 { x };
279 }"#,
280         );
281     }
282
283     #[test]
284     fn infer_return_type_closure_wrap() {
285         cov_mark::check!(wrap_closure_non_block_expr);
286         check_assist(
287             add_return_type,
288             r#"fn foo() {
289     |x: i32| x$0;
290 }"#,
291             r#"fn foo() {
292     |x: i32| -> i32 {x};
293 }"#,
294         );
295     }
296
297     #[test]
298     fn infer_return_type_nested_closure() {
299         check_assist(
300             add_return_type,
301             r#"fn foo() {
302     || {
303         if true {
304             3$0
305         } else {
306             5
307         }
308     }
309 }"#,
310             r#"fn foo() {
311     || -> i32 {
312         if true {
313             3
314         } else {
315             5
316         }
317     }
318 }"#,
319         );
320     }
321
322     #[test]
323     fn not_applicable_ret_type_specified_closure() {
324         cov_mark::check!(existing_ret_type_closure);
325         check_assist_not_applicable(
326             add_return_type,
327             r#"fn foo() {
328     || -> i32 { 3$0 }
329 }"#,
330         );
331     }
332
333     #[test]
334     fn not_applicable_non_tail_expr_closure() {
335         check_assist_not_applicable(
336             add_return_type,
337             r#"fn foo() {
338     || -> i32 {
339         let x = 3$0;
340         6
341     }
342 }"#,
343         );
344     }
345 }