]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/convert_to_guarded_return.rs
Merge #9962
[rust.git] / crates / ide_assists / src / handlers / convert_to_guarded_return.rs
1 use std::iter::once;
2
3 use syntax::{
4     ast::{
5         self,
6         edit::{AstNodeEdit, IndentLevel},
7         make,
8     },
9     ted, AstNode,
10     SyntaxKind::{FN, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
11     T,
12 };
13
14 use crate::{
15     assist_context::{AssistContext, Assists},
16     utils::invert_boolean_expression,
17     AssistId, AssistKind,
18 };
19
20 // Assist: convert_to_guarded_return
21 //
22 // Replace a large conditional with a guarded return.
23 //
24 // ```
25 // fn main() {
26 //     $0if cond {
27 //         foo();
28 //         bar();
29 //     }
30 // }
31 // ```
32 // ->
33 // ```
34 // fn main() {
35 //     if !cond {
36 //         return;
37 //     }
38 //     foo();
39 //     bar();
40 // }
41 // ```
42 pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
43     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
44     if if_expr.else_branch().is_some() {
45         return None;
46     }
47
48     let cond = if_expr.condition()?;
49
50     // Check if there is an IfLet that we can handle.
51     let if_let_pat = match cond.pat() {
52         None => None, // No IfLet, supported.
53         Some(ast::Pat::TupleStructPat(pat)) if pat.fields().count() == 1 => {
54             let path = pat.path()?;
55             if path.qualifier().is_some() {
56                 return None;
57             }
58
59             let bound_ident = pat.fields().next().unwrap();
60             if !ast::IdentPat::can_cast(bound_ident.syntax().kind()) {
61                 return None;
62             }
63
64             Some((path, bound_ident))
65         }
66         Some(_) => return None, // Unsupported IfLet.
67     };
68
69     let cond_expr = cond.expr()?;
70     let then_block = if_expr.then_branch()?;
71
72     let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
73
74     if parent_block.tail_expr()? != if_expr.clone().into() {
75         return None;
76     }
77
78     // check for early return and continue
79     let first_in_then_block = then_block.syntax().first_child()?;
80     if ast::ReturnExpr::can_cast(first_in_then_block.kind())
81         || ast::ContinueExpr::can_cast(first_in_then_block.kind())
82         || first_in_then_block
83             .children()
84             .any(|x| ast::ReturnExpr::can_cast(x.kind()) || ast::ContinueExpr::can_cast(x.kind()))
85     {
86         return None;
87     }
88
89     let parent_container = parent_block.syntax().parent()?;
90
91     let early_expression: ast::Expr = match parent_container.kind() {
92         WHILE_EXPR | LOOP_EXPR => make::expr_continue(),
93         FN => make::expr_return(None),
94         _ => return None,
95     };
96
97     if then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{']).is_none() {
98         return None;
99     }
100
101     then_block.syntax().last_child_or_token().filter(|t| t.kind() == T!['}'])?;
102
103     let target = if_expr.syntax().text_range();
104     acc.add(
105         AssistId("convert_to_guarded_return", AssistKind::RefactorRewrite),
106         "Convert to guarded return",
107         target,
108         |edit| {
109             let if_expr = edit.make_mut(if_expr);
110             let if_indent_level = IndentLevel::from_node(if_expr.syntax());
111             let replacement = match if_let_pat {
112                 None => {
113                     // If.
114                     let new_expr = {
115                         let then_branch =
116                             make::block_expr(once(make::expr_stmt(early_expression).into()), None);
117                         let cond = invert_boolean_expression(cond_expr);
118                         make::expr_if(make::condition(cond, None), then_branch, None)
119                             .indent(if_indent_level)
120                     };
121                     new_expr.syntax().clone_for_update()
122                 }
123                 Some((path, bound_ident)) => {
124                     // If-let.
125                     let match_expr = {
126                         let happy_arm = {
127                             let pat = make::tuple_struct_pat(
128                                 path,
129                                 once(make::ext::simple_ident_pat(make::name("it")).into()),
130                             );
131                             let expr = {
132                                 let path = make::ext::ident_path("it");
133                                 make::expr_path(path)
134                             };
135                             make::match_arm(once(pat.into()), None, expr)
136                         };
137
138                         let sad_arm = make::match_arm(
139                             // FIXME: would be cool to use `None` or `Err(_)` if appropriate
140                             once(make::wildcard_pat().into()),
141                             None,
142                             early_expression,
143                         );
144
145                         make::expr_match(cond_expr, make::match_arm_list(vec![happy_arm, sad_arm]))
146                     };
147
148                     let let_stmt = make::let_stmt(bound_ident, None, Some(match_expr));
149                     let let_stmt = let_stmt.indent(if_indent_level);
150                     let_stmt.syntax().clone_for_update()
151                 }
152             };
153
154             let then_block_items = then_block.dedent(IndentLevel(1)).clone_for_update();
155
156             let end_of_then = then_block_items.syntax().last_child_or_token().unwrap();
157             let end_of_then =
158                 if end_of_then.prev_sibling_or_token().map(|n| n.kind()) == Some(WHITESPACE) {
159                     end_of_then.prev_sibling_or_token().unwrap()
160                 } else {
161                     end_of_then
162                 };
163
164             let then_statements = replacement
165                 .children_with_tokens()
166                 .chain(
167                     then_block_items
168                         .syntax()
169                         .children_with_tokens()
170                         .skip(1)
171                         .take_while(|i| *i != end_of_then),
172                 )
173                 .collect();
174
175             ted::replace_with_many(if_expr.syntax(), then_statements)
176         },
177     )
178 }
179
#[cfg(test)]
mod tests {
    // Fixture-based tests: `$0` marks the cursor position in the input; `check_assist`
    // compares the rewritten source with the second fixture, `check_assist_not_applicable`
    // expects the assist not to be offered.
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // Plain `if` as the tail of a function body becomes an inverted guard with `return`.
    #[test]
    fn convert_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    bar();
    if$0 true {
        foo();

        // comment
        bar();
    }
}
"#,
            r#"
fn main() {
    bar();
    if false {
        return;
    }
    foo();

    // comment
    bar();
}
"#,
        );
    }

    // `if let Some(..)` becomes `let .. = match .. { Some(it) => it, _ => return }`.
    #[test]
    fn convert_let_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<String>) {
    bar();
    if$0 let Some(n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<String>) {
    bar();
    let n = match n {
        Some(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    // `Ok(..)` pattern with an arbitrary scrutinee expression.
    #[test]
    fn convert_if_let_result() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 let Ok(x) = Err(92) {
        foo(x);
    }
}
"#,
            r#"
fn main() {
    let x = match Err(92) {
        Ok(it) => it,
        _ => return,
    };
    foo(x);
}
"#,
        );
    }

    // `Ok(..)` pattern on a `Result` parameter, with preceding statements and comments.
    #[test]
    fn convert_let_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Result<String, String>) {
    bar();
    if$0 let Ok(n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Result<String, String>) {
    bar();
    let n = match n {
        Ok(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    // A `mut` binding is kept on the generated `let`.
    #[test]
    fn convert_let_mut_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<String>) {
    bar();
    if$0 let Some(mut n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<String>) {
    bar();
    let mut n = match n {
        Some(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    // A `ref` binding is kept on the generated `let`.
    #[test]
    fn convert_let_ref_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<&str>) {
    bar();
    if$0 let Some(ref n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<&str>) {
    bar();
    let ref n = match n {
        Some(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    // Inside a `while` body the guard exits with `continue`.
    #[test]
    fn convert_inside_while() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    while true {
        if$0 true {
            foo();
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    while true {
        if false {
            continue;
        }
        foo();
        bar();
    }
}
"#,
        );
    }

    // `if let` inside a `while` body: the fallback arm is `continue`.
    #[test]
    fn convert_let_inside_while() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    while true {
        if$0 let Some(n) = n {
            foo(n);
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    while true {
        let n = match n {
            Some(it) => it,
            _ => continue,
        };
        foo(n);
        bar();
    }
}
"#,
        );
    }

    // Inside a `loop` body the guard exits with `continue`.
    #[test]
    fn convert_inside_loop() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 true {
            foo();
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    loop {
        if false {
            continue;
        }
        foo();
        bar();
    }
}
"#,
        );
    }

    // `if let` inside a `loop` body: the fallback arm is `continue`.
    #[test]
    fn convert_let_inside_loop() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 let Some(n) = n {
            foo(n);
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    loop {
        let n = match n {
            Some(it) => it,
            _ => continue,
        };
        foo(n);
        bar();
    }
}
"#,
        );
    }

    // An `if` whose body already starts with `return;` is left alone.
    #[test]
    fn ignore_already_converted_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        return;
    }
}
"#,
        );
    }

    // An `if` whose body already starts with `continue;` is left alone.
    #[test]
    fn ignore_already_converted_loop() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 true {
            continue;
        }
    }
}
"#,
        );
    }

    // Same as above, with `return` as the tail expression (no semicolon).
    #[test]
    fn ignore_return() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        return
    }
}
"#,
        );
    }

    // An `if` with an `else` branch is not handled.
    #[test]
    fn ignore_else_branch() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        foo();
    } else {
        bar()
    }
}
"#,
        );
    }

    // The `if` must be the tail expression of its block.
    #[test]
    fn ignore_statements_after_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        foo();
    }
    bar();
}
"#,
        );
    }

    // The enclosing block must belong to a function, `while` or `loop`, not another `if`.
    #[test]
    fn ignore_statements_inside_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if false {
        if$0 true {
            foo();
        }
    }
}
"#,
        );
    }
}