git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/early_return.rs
Invert boolean literals in assist negation logic
[rust.git] / crates / ide_assists / src / handlers / early_return.rs
1 use std::{iter::once, ops::RangeInclusive};
2
3 use syntax::{
4     algo::replace_children,
5     ast::{
6         self,
7         edit::{AstNodeEdit, IndentLevel},
8         make,
9     },
10     AstNode,
11     SyntaxKind::{FN, LOOP_EXPR, L_CURLY, R_CURLY, WHILE_EXPR, WHITESPACE},
12     SyntaxNode,
13 };
14
15 use crate::{
16     assist_context::{AssistContext, Assists},
17     utils::invert_boolean_expression,
18     AssistId, AssistKind,
19 };
20
21 // Assist: convert_to_guarded_return
22 //
23 // Replace a large conditional with a guarded return.
24 //
25 // ```
26 // fn main() {
27 //     $0if cond {
28 //         foo();
29 //         bar();
30 //     }
31 // }
32 // ```
33 // ->
34 // ```
35 // fn main() {
36 //     if !cond {
37 //         return;
38 //     }
39 //     foo();
40 //     bar();
41 // }
42 // ```
43 pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
44     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
45     if if_expr.else_branch().is_some() {
46         return None;
47     }
48
49     let cond = if_expr.condition()?;
50
51     // Check if there is an IfLet that we can handle.
52     let if_let_pat = match cond.pat() {
53         None => None, // No IfLet, supported.
54         Some(ast::Pat::TupleStructPat(pat)) if pat.fields().count() == 1 => {
55             let path = pat.path()?;
56             match path.qualifier() {
57                 None => {
58                     let bound_ident = pat.fields().next().unwrap();
59                     if ast::IdentPat::can_cast(bound_ident.syntax().kind()) {
60                         Some((path, bound_ident))
61                     } else {
62                         return None;
63                     }
64                 }
65                 Some(_) => return None,
66             }
67         }
68         Some(_) => return None, // Unsupported IfLet.
69     };
70
71     let cond_expr = cond.expr()?;
72     let then_block = if_expr.then_branch()?;
73
74     let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
75
76     if parent_block.tail_expr()? != if_expr.clone().into() {
77         return None;
78     }
79
80     // check for early return and continue
81     let first_in_then_block = then_block.syntax().first_child()?;
82     if ast::ReturnExpr::can_cast(first_in_then_block.kind())
83         || ast::ContinueExpr::can_cast(first_in_then_block.kind())
84         || first_in_then_block
85             .children()
86             .any(|x| ast::ReturnExpr::can_cast(x.kind()) || ast::ContinueExpr::can_cast(x.kind()))
87     {
88         return None;
89     }
90
91     let parent_container = parent_block.syntax().parent()?;
92
93     let early_expression: ast::Expr = match parent_container.kind() {
94         WHILE_EXPR | LOOP_EXPR => make::expr_continue(),
95         FN => make::expr_return(None),
96         _ => return None,
97     };
98
99     if then_block.syntax().first_child_or_token().map(|t| t.kind() == L_CURLY).is_none() {
100         return None;
101     }
102
103     then_block.syntax().last_child_or_token().filter(|t| t.kind() == R_CURLY)?;
104
105     let target = if_expr.syntax().text_range();
106     acc.add(
107         AssistId("convert_to_guarded_return", AssistKind::RefactorRewrite),
108         "Convert to guarded return",
109         target,
110         |edit| {
111             let if_indent_level = IndentLevel::from_node(if_expr.syntax());
112             let new_block = match if_let_pat {
113                 None => {
114                     // If.
115                     let new_expr = {
116                         let then_branch =
117                             make::block_expr(once(make::expr_stmt(early_expression).into()), None);
118                         let cond = invert_boolean_expression(&ctx.sema, cond_expr);
119                         make::expr_if(make::condition(cond, None), then_branch, None)
120                             .indent(if_indent_level)
121                     };
122                     replace(new_expr.syntax(), &then_block, &parent_block, &if_expr)
123                 }
124                 Some((path, bound_ident)) => {
125                     // If-let.
126                     let match_expr = {
127                         let happy_arm = {
128                             let pat = make::tuple_struct_pat(
129                                 path,
130                                 once(make::ext::simple_ident_pat(make::name("it")).into()),
131                             );
132                             let expr = {
133                                 let path = make::ext::ident_path("it");
134                                 make::expr_path(path)
135                             };
136                             make::match_arm(once(pat.into()), None, expr)
137                         };
138
139                         let sad_arm = make::match_arm(
140                             // FIXME: would be cool to use `None` or `Err(_)` if appropriate
141                             once(make::wildcard_pat().into()),
142                             None,
143                             early_expression,
144                         );
145
146                         make::expr_match(cond_expr, make::match_arm_list(vec![happy_arm, sad_arm]))
147                     };
148
149                     let let_stmt = make::let_stmt(bound_ident, Some(match_expr));
150                     let let_stmt = let_stmt.indent(if_indent_level);
151                     replace(let_stmt.syntax(), &then_block, &parent_block, &if_expr)
152                 }
153             };
154             edit.replace_ast(parent_block, ast::BlockExpr::cast(new_block).unwrap());
155
156             fn replace(
157                 new_expr: &SyntaxNode,
158                 then_block: &ast::BlockExpr,
159                 parent_block: &ast::BlockExpr,
160                 if_expr: &ast::IfExpr,
161             ) -> SyntaxNode {
162                 let then_block_items = then_block.dedent(IndentLevel(1));
163                 let end_of_then = then_block_items.syntax().last_child_or_token().unwrap();
164                 let end_of_then =
165                     if end_of_then.prev_sibling_or_token().map(|n| n.kind()) == Some(WHITESPACE) {
166                         end_of_then.prev_sibling_or_token().unwrap()
167                     } else {
168                         end_of_then
169                     };
170                 let mut then_statements = new_expr.children_with_tokens().chain(
171                     then_block_items
172                         .syntax()
173                         .children_with_tokens()
174                         .skip(1)
175                         .take_while(|i| *i != end_of_then),
176                 );
177                 replace_children(
178                     parent_block.syntax(),
179                     RangeInclusive::new(
180                         if_expr.clone().syntax().clone().into(),
181                         if_expr.syntax().clone().into(),
182                     ),
183                     &mut then_statements,
184                 )
185             }
186         },
187     )
188 }
189
#[cfg(test)]
mod tests {
    // Fixture-based tests: `$0` marks the cursor, and `check_assist` compares the
    // rewritten file against the expected text.
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    #[test]
    fn convert_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
            fn main() {
                bar();
                if$0 true {
                    foo();

                    // comment
                    bar();
                }
            }
            "#,
            r#"
            fn main() {
                bar();
                if false {
                    return;
                }
                foo();

                // comment
                bar();
            }
            "#,
        );
    }

    #[test]
    fn convert_let_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
            fn main(n: Option<String>) {
                bar();
                if$0 let Some(n) = n {
                    foo(n);

                    // comment
                    bar();
                }
            }
            "#,
            r#"
            fn main(n: Option<String>) {
                bar();
                let n = match n {
                    Some(it) => it,
                    _ => return,
                };
                foo(n);

                // comment
                bar();
            }
            "#,
        );
    }

    #[test]
    fn convert_if_let_result() {
        check_assist(
            convert_to_guarded_return,
            r#"
            fn main() {
                if$0 let Ok(x) = Err(92) {
                    foo(x);
                }
            }
            "#,
            r#"
            fn main() {
                let x = match Err(92) {
                    Ok(it) => it,
                    _ => return,
                };
                foo(x);
            }
            "#,
        );
    }

    #[test]
    fn convert_let_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
            fn main(n: Result<String, ()>) {
                bar();
                if$0 let Ok(n) = n {
                    foo(n);

                    // comment
                    bar();
                }
            }
            "#,
            r#"
            fn main(n: Result<String, ()>) {
                bar();
                let n = match n {
                    Ok(it) => it,
                    _ => return,
                };
                foo(n);

                // comment
                bar();
            }
            "#,
        );
    }

    #[test]
    fn convert_let_mut_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
            fn main(n: Option<String>) {
                bar();
                if$0 let Some(mut n) = n {
                    foo(n);

                    // comment
                    bar();
                }
            }
            "#,
            r#"
            fn main(n: Option<String>) {
                bar();
                let mut n = match n {
                    Some(it) => it,
                    _ => return,
                };
                foo(n);

                // comment
                bar();
            }
            "#,
        );
    }

    #[test]
    fn convert_let_ref_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
            fn main(n: Option<&str>) {
                bar();
                if$0 let Some(ref n) = n {
                    foo(n);

                    // comment
                    bar();
                }
            }
            "#,
            r#"
            fn main(n: Option<&str>) {
                bar();
                let ref n = match n {
                    Some(it) => it,
                    _ => return,
                };
                foo(n);

                // comment
                bar();
            }
            "#,
        );
    }

    #[test]
    fn convert_inside_while() {
        check_assist(
            convert_to_guarded_return,
            r#"
            fn main() {
                while true {
                    if$0 true {
                        foo();
                        bar();
                    }
                }
            }
            "#,
            r#"
            fn main() {
                while true {
                    if false {
                        continue;
                    }
                    foo();
                    bar();
                }
            }
            "#,
        );
    }

    #[test]
    fn convert_let_inside_while() {
        check_assist(
            convert_to_guarded_return,
            r#"
            fn main() {
                while true {
                    if$0 let Some(n) = n {
                        foo(n);
                        bar();
                    }
                }
            }
            "#,
            r#"
            fn main() {
                while true {
                    let n = match n {
                        Some(it) => it,
                        _ => continue,
                    };
                    foo(n);
                    bar();
                }
            }
            "#,
        );
    }

    #[test]
    fn convert_inside_loop() {
        check_assist(
            convert_to_guarded_return,
            r#"
            fn main() {
                loop {
                    if$0 true {
                        foo();
                        bar();
                    }
                }
            }
            "#,
            r#"
            fn main() {
                loop {
                    if false {
                        continue;
                    }
                    foo();
                    bar();
                }
            }
            "#,
        );
    }

    #[test]
    fn convert_let_inside_loop() {
        check_assist(
            convert_to_guarded_return,
            r#"
            fn main() {
                loop {
                    if$0 let Some(n) = n {
                        foo(n);
                        bar();
                    }
                }
            }
            "#,
            r#"
            fn main() {
                loop {
                    let n = match n {
                        Some(it) => it,
                        _ => continue,
                    };
                    foo(n);
                    bar();
                }
            }
            "#,
        );
    }

    #[test]
    fn ignore_already_converted_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
            fn main() {
                if$0 true {
                    return;
                }
            }
            "#,
        );
    }

    #[test]
    fn ignore_already_converted_loop() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
            fn main() {
                loop {
                    if$0 true {
                        continue;
                    }
                }
            }
            "#,
        );
    }

    #[test]
    fn ignore_return() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
            fn main() {
                if$0 true {
                    return
                }
            }
            "#,
        );
    }

    #[test]
    fn ignore_else_branch() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
            fn main() {
                if$0 true {
                    foo();
                } else {
                    bar()
                }
            }
            "#,
        );
    }

    #[test]
    fn ignore_statements_after_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
            fn main() {
                if$0 true {
                    foo();
                }
                bar();
            }
            "#,
        );
    }

    #[test]
    fn ignore_statements_inside_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
            fn main() {
                if false {
                    if$0 true {
                        foo();
                    }
                }
            }
            "#,
        );
    }

    // A path qualifier (`Option::Some`) is rejected by the pattern check.
    #[test]
    fn ignore_qualified_tuple_struct_pattern() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
            fn main(n: Option<String>) {
                if$0 let Option::Some(n) = n {
                    foo(n);
                }
            }
            "#,
        );
    }

    // The single field must be a plain identifier binding, not a nested pattern.
    #[test]
    fn ignore_non_ident_binding() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
            fn main(n: Option<(i32, i32)>) {
                if$0 let Some((a, b)) = n {
                    foo(a, b);
                }
            }
            "#,
        );
    }

    // Only function and `while`/`loop` bodies have a known early exit.
    #[test]
    fn ignore_inside_closure() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
            fn main() {
                let f = || {
                    if$0 true {
                        foo();
                    }
                };
            }
            "#,
        );
    }
}