]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/convert_to_guarded_return.rs
Merge #10181
[rust.git] / crates / ide_assists / src / handlers / convert_to_guarded_return.rs
1 use std::iter::once;
2
3 use syntax::{
4     ast::{
5         self,
6         edit::{AstNodeEdit, IndentLevel},
7         make,
8     },
9     ted, AstNode,
10     SyntaxKind::{FN, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
11     T,
12 };
13
14 use crate::{
15     assist_context::{AssistContext, Assists},
16     utils::invert_boolean_expression,
17     AssistId, AssistKind,
18 };
19
20 // Assist: convert_to_guarded_return
21 //
22 // Replace a large conditional with a guarded return.
23 //
24 // ```
25 // fn main() {
26 //     $0if cond {
27 //         foo();
28 //         bar();
29 //     }
30 // }
31 // ```
32 // ->
33 // ```
34 // fn main() {
35 //     if !cond {
36 //         return;
37 //     }
38 //     foo();
39 //     bar();
40 // }
41 // ```
42 pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
43     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
44     if if_expr.else_branch().is_some() {
45         return None;
46     }
47
48     let cond = if_expr.condition()?;
49
50     // Check if there is an IfLet that we can handle.
51     let if_let_pat = match cond.pat() {
52         None => None, // No IfLet, supported.
53         Some(ast::Pat::TupleStructPat(pat)) if pat.fields().count() == 1 => {
54             let path = pat.path()?;
55             if path.qualifier().is_some() {
56                 return None;
57             }
58
59             let bound_ident = pat.fields().next().unwrap();
60             if !ast::IdentPat::can_cast(bound_ident.syntax().kind()) {
61                 return None;
62             }
63
64             Some((path, bound_ident))
65         }
66         Some(_) => return None, // Unsupported IfLet.
67     };
68
69     let cond_expr = cond.expr()?;
70     let then_block = if_expr.then_branch()?;
71     let then_block = then_block.stmt_list()?;
72
73     let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
74
75     if parent_block.tail_expr()? != if_expr.clone().into() {
76         return None;
77     }
78
79     // FIXME: This relies on untyped syntax tree and casts to much. It should be
80     // rewritten to use strongly-typed APIs.
81
82     // check for early return and continue
83     let first_in_then_block = then_block.syntax().first_child()?;
84     if ast::ReturnExpr::can_cast(first_in_then_block.kind())
85         || ast::ContinueExpr::can_cast(first_in_then_block.kind())
86         || first_in_then_block
87             .children()
88             .any(|x| ast::ReturnExpr::can_cast(x.kind()) || ast::ContinueExpr::can_cast(x.kind()))
89     {
90         return None;
91     }
92
93     let parent_container = parent_block.syntax().parent()?;
94
95     let early_expression: ast::Expr = match parent_container.kind() {
96         WHILE_EXPR | LOOP_EXPR => make::expr_continue(),
97         FN => make::expr_return(None),
98         _ => return None,
99     };
100
101     if then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{']).is_none() {
102         return None;
103     }
104
105     then_block.syntax().last_child_or_token().filter(|t| t.kind() == T!['}'])?;
106
107     let target = if_expr.syntax().text_range();
108     acc.add(
109         AssistId("convert_to_guarded_return", AssistKind::RefactorRewrite),
110         "Convert to guarded return",
111         target,
112         |edit| {
113             let if_expr = edit.make_mut(if_expr);
114             let if_indent_level = IndentLevel::from_node(if_expr.syntax());
115             let replacement = match if_let_pat {
116                 None => {
117                     // If.
118                     let new_expr = {
119                         let then_branch =
120                             make::block_expr(once(make::expr_stmt(early_expression).into()), None);
121                         let cond = invert_boolean_expression(cond_expr);
122                         make::expr_if(make::condition(cond, None), then_branch, None)
123                             .indent(if_indent_level)
124                     };
125                     new_expr.syntax().clone_for_update()
126                 }
127                 Some((path, bound_ident)) => {
128                     // If-let.
129                     let match_expr = {
130                         let happy_arm = {
131                             let pat = make::tuple_struct_pat(
132                                 path,
133                                 once(make::ext::simple_ident_pat(make::name("it")).into()),
134                             );
135                             let expr = {
136                                 let path = make::ext::ident_path("it");
137                                 make::expr_path(path)
138                             };
139                             make::match_arm(once(pat.into()), None, expr)
140                         };
141
142                         let sad_arm = make::match_arm(
143                             // FIXME: would be cool to use `None` or `Err(_)` if appropriate
144                             once(make::wildcard_pat().into()),
145                             None,
146                             early_expression,
147                         );
148
149                         make::expr_match(cond_expr, make::match_arm_list(vec![happy_arm, sad_arm]))
150                     };
151
152                     let let_stmt = make::let_stmt(bound_ident, None, Some(match_expr));
153                     let let_stmt = let_stmt.indent(if_indent_level);
154                     let_stmt.syntax().clone_for_update()
155                 }
156             };
157
158             let then_block_items = then_block.dedent(IndentLevel(1)).clone_for_update();
159
160             let end_of_then = then_block_items.syntax().last_child_or_token().unwrap();
161             let end_of_then =
162                 if end_of_then.prev_sibling_or_token().map(|n| n.kind()) == Some(WHITESPACE) {
163                     end_of_then.prev_sibling_or_token().unwrap()
164                 } else {
165                     end_of_then
166                 };
167
168             let then_statements = replacement
169                 .children_with_tokens()
170                 .chain(
171                     then_block_items
172                         .syntax()
173                         .children_with_tokens()
174                         .skip(1)
175                         .take_while(|i| *i != end_of_then),
176                 )
177                 .collect();
178
179             ted::replace_with_many(if_expr.syntax(), then_statements)
180         },
181     )
182 }
183
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    #[test]
    fn convert_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    bar();
    if$0 true {
        foo();

        // comment
        bar();
    }
}
"#,
            r#"
fn main() {
    bar();
    if false {
        return;
    }
    foo();

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_let_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<String>) {
    bar();
    if$0 let Some(n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<String>) {
    bar();
    let n = match n {
        Some(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_if_let_result() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 let Ok(x) = Err(92) {
        foo(x);
    }
}
"#,
            r#"
fn main() {
    let x = match Err(92) {
        Ok(it) => it,
        _ => return,
    };
    foo(x);
}
"#,
        );
    }

    // Same shape as `convert_let_inside_fn`, but on a `Result` binding via `Ok`.
    #[test]
    fn convert_let_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Result<String, ()>) {
    bar();
    if$0 let Ok(n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Result<String, ()>) {
    bar();
    let n = match n {
        Ok(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_let_mut_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<String>) {
    bar();
    if$0 let Some(mut n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<String>) {
    bar();
    let mut n = match n {
        Some(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_let_ref_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<&str>) {
    bar();
    if$0 let Some(ref n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<&str>) {
    bar();
    let ref n = match n {
        Some(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_inside_while() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    while true {
        if$0 true {
            foo();
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    while true {
        if false {
            continue;
        }
        foo();
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn convert_let_inside_while() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    while true {
        if$0 let Some(n) = n {
            foo(n);
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    while true {
        let n = match n {
            Some(it) => it,
            _ => continue,
        };
        foo(n);
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn convert_inside_loop() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 true {
            foo();
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    loop {
        if false {
            continue;
        }
        foo();
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn convert_let_inside_loop() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 let Some(n) = n {
            foo(n);
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    loop {
        let n = match n {
            Some(it) => it,
            _ => continue,
        };
        foo(n);
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn ignore_already_converted_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        return;
    }
}
"#,
        );
    }

    #[test]
    fn ignore_already_converted_loop() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 true {
            continue;
        }
    }
}
"#,
        );
    }

    #[test]
    fn ignore_return() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        return
    }
}
"#,
        );
    }

    #[test]
    fn ignore_else_branch() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        foo();
    } else {
        bar()
    }
}
"#,
        );
    }

    #[test]
    fn ignore_statements_after_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        foo();
    }
    bar();
}
"#,
        );
    }

    #[test]
    fn ignore_statements_inside_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if false {
        if$0 true {
            foo();
        }
    }
}
"#,
        );
    }

    // An empty then-branch has no statement to hoist.
    #[test]
    fn ignore_empty_then_block() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {}
}
"#,
        );
    }

    // Only unqualified tuple-struct paths (`Some`, `Ok`, ...) are supported.
    #[test]
    fn ignore_qualified_path() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main(n: Option<String>) {
    if$0 let Option::Some(n) = n {
        foo(n);
    }
}
"#,
        );
    }

    // The single field must be an identifier pattern.
    #[test]
    fn ignore_non_ident_binding() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main(n: Option<Option<String>>) {
    if$0 let Some(Some(n)) = n {
        foo(n);
    }
}
"#,
        );
    }

    // Tuple-struct patterns with more than one field are not supported.
    #[test]
    fn ignore_multi_field_pattern() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 let Foo(a, b) = foo() {
        bar(a, b);
    }
}
"#,
        );
    }

    // Only `fn`, `while` and `loop` bodies have a known early exit.
    #[test]
    fn ignore_inside_closure() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    let f = || {
        if$0 true {
            foo();
        }
    };
}
"#,
        );
    }
}