]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/convert_to_guarded_return.rs
internal: use standard test style
[rust.git] / crates / ide_assists / src / handlers / convert_to_guarded_return.rs
1 use std::{iter::once, ops::RangeInclusive};
2
3 use syntax::{
4     algo::replace_children,
5     ast::{
6         self,
7         edit::{AstNodeEdit, IndentLevel},
8         make,
9     },
10     AstNode,
11     SyntaxKind::{FN, LOOP_EXPR, L_CURLY, R_CURLY, WHILE_EXPR, WHITESPACE},
12     SyntaxNode,
13 };
14
15 use crate::{
16     assist_context::{AssistContext, Assists},
17     utils::invert_boolean_expression,
18     AssistId, AssistKind,
19 };
20
21 // Assist: convert_to_guarded_return
22 //
23 // Replace a large conditional with a guarded return.
24 //
25 // ```
26 // fn main() {
27 //     $0if cond {
28 //         foo();
29 //         bar();
30 //     }
31 // }
32 // ```
33 // ->
34 // ```
35 // fn main() {
36 //     if !cond {
37 //         return;
38 //     }
39 //     foo();
40 //     bar();
41 // }
42 // ```
43 pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
44     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
45     if if_expr.else_branch().is_some() {
46         return None;
47     }
48
49     let cond = if_expr.condition()?;
50
51     // Check if there is an IfLet that we can handle.
52     let if_let_pat = match cond.pat() {
53         None => None, // No IfLet, supported.
54         Some(ast::Pat::TupleStructPat(pat)) if pat.fields().count() == 1 => {
55             let path = pat.path()?;
56             match path.qualifier() {
57                 None => {
58                     let bound_ident = pat.fields().next().unwrap();
59                     if ast::IdentPat::can_cast(bound_ident.syntax().kind()) {
60                         Some((path, bound_ident))
61                     } else {
62                         return None;
63                     }
64                 }
65                 Some(_) => return None,
66             }
67         }
68         Some(_) => return None, // Unsupported IfLet.
69     };
70
71     let cond_expr = cond.expr()?;
72     let then_block = if_expr.then_branch()?;
73
74     let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
75
76     if parent_block.tail_expr()? != if_expr.clone().into() {
77         return None;
78     }
79
80     // check for early return and continue
81     let first_in_then_block = then_block.syntax().first_child()?;
82     if ast::ReturnExpr::can_cast(first_in_then_block.kind())
83         || ast::ContinueExpr::can_cast(first_in_then_block.kind())
84         || first_in_then_block
85             .children()
86             .any(|x| ast::ReturnExpr::can_cast(x.kind()) || ast::ContinueExpr::can_cast(x.kind()))
87     {
88         return None;
89     }
90
91     let parent_container = parent_block.syntax().parent()?;
92
93     let early_expression: ast::Expr = match parent_container.kind() {
94         WHILE_EXPR | LOOP_EXPR => make::expr_continue(),
95         FN => make::expr_return(None),
96         _ => return None,
97     };
98
99     if then_block.syntax().first_child_or_token().map(|t| t.kind() == L_CURLY).is_none() {
100         return None;
101     }
102
103     then_block.syntax().last_child_or_token().filter(|t| t.kind() == R_CURLY)?;
104
105     let target = if_expr.syntax().text_range();
106     acc.add(
107         AssistId("convert_to_guarded_return", AssistKind::RefactorRewrite),
108         "Convert to guarded return",
109         target,
110         |edit| {
111             let if_indent_level = IndentLevel::from_node(if_expr.syntax());
112             let new_block = match if_let_pat {
113                 None => {
114                     // If.
115                     let new_expr = {
116                         let then_branch =
117                             make::block_expr(once(make::expr_stmt(early_expression).into()), None);
118                         let cond = invert_boolean_expression(cond_expr);
119                         make::expr_if(make::condition(cond, None), then_branch, None)
120                             .indent(if_indent_level)
121                     };
122                     replace(new_expr.syntax(), &then_block, &parent_block, &if_expr)
123                 }
124                 Some((path, bound_ident)) => {
125                     // If-let.
126                     let match_expr = {
127                         let happy_arm = {
128                             let pat = make::tuple_struct_pat(
129                                 path,
130                                 once(make::ext::simple_ident_pat(make::name("it")).into()),
131                             );
132                             let expr = {
133                                 let path = make::ext::ident_path("it");
134                                 make::expr_path(path)
135                             };
136                             make::match_arm(once(pat.into()), None, expr)
137                         };
138
139                         let sad_arm = make::match_arm(
140                             // FIXME: would be cool to use `None` or `Err(_)` if appropriate
141                             once(make::wildcard_pat().into()),
142                             None,
143                             early_expression,
144                         );
145
146                         make::expr_match(cond_expr, make::match_arm_list(vec![happy_arm, sad_arm]))
147                     };
148
149                     let let_stmt = make::let_stmt(bound_ident, None, Some(match_expr));
150                     let let_stmt = let_stmt.indent(if_indent_level);
151                     replace(let_stmt.syntax(), &then_block, &parent_block, &if_expr)
152                 }
153             };
154             edit.replace_ast(parent_block, ast::BlockExpr::cast(new_block).unwrap());
155
156             fn replace(
157                 new_expr: &SyntaxNode,
158                 then_block: &ast::BlockExpr,
159                 parent_block: &ast::BlockExpr,
160                 if_expr: &ast::IfExpr,
161             ) -> SyntaxNode {
162                 let then_block_items = then_block.dedent(IndentLevel(1));
163                 let end_of_then = then_block_items.syntax().last_child_or_token().unwrap();
164                 let end_of_then =
165                     if end_of_then.prev_sibling_or_token().map(|n| n.kind()) == Some(WHITESPACE) {
166                         end_of_then.prev_sibling_or_token().unwrap()
167                     } else {
168                         end_of_then
169                     };
170                 let mut then_statements = new_expr.children_with_tokens().chain(
171                     then_block_items
172                         .syntax()
173                         .children_with_tokens()
174                         .skip(1)
175                         .take_while(|i| *i != end_of_then),
176                 );
177                 replace_children(
178                     parent_block.syntax(),
179                     RangeInclusive::new(
180                         if_expr.clone().syntax().clone().into(),
181                         if_expr.syntax().clone().into(),
182                     ),
183                     &mut then_statements,
184                 )
185             }
186         },
187     )
188 }
189
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    #[test]
    fn convert_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    bar();
    if$0 true {
        foo();

        // comment
        bar();
    }
}
"#,
            r#"
fn main() {
    bar();
    if false {
        return;
    }
    foo();

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_let_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<String>) {
    bar();
    if$0 let Some(n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<String>) {
    bar();
    let n = match n {
        Some(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_if_let_result() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 let Ok(x) = Err(92) {
        foo(x);
    }
}
"#,
            r#"
fn main() {
    let x = match Err(92) {
        Ok(it) => it,
        _ => return,
    };
    foo(x);
}
"#,
        );
    }

    // Previously a verbatim copy of `convert_let_inside_fn`; now actually exercises `Ok(..)`.
    #[test]
    fn convert_let_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Result<String, i32>) {
    bar();
    if$0 let Ok(n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Result<String, i32>) {
    bar();
    let n = match n {
        Ok(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_let_mut_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<String>) {
    bar();
    if$0 let Some(mut n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<String>) {
    bar();
    let mut n = match n {
        Some(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_let_ref_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<&str>) {
    bar();
    if$0 let Some(ref n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<&str>) {
    bar();
    let ref n = match n {
        Some(it) => it,
        _ => return,
    };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_inside_while() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    while true {
        if$0 true {
            foo();
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    while true {
        if false {
            continue;
        }
        foo();
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn convert_let_inside_while() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    while true {
        if$0 let Some(n) = n {
            foo(n);
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    while true {
        let n = match n {
            Some(it) => it,
            _ => continue,
        };
        foo(n);
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn convert_inside_loop() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 true {
            foo();
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    loop {
        if false {
            continue;
        }
        foo();
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn convert_let_inside_loop() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 let Some(n) = n {
            foo(n);
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    loop {
        let n = match n {
            Some(it) => it,
            _ => continue,
        };
        foo(n);
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn ignore_already_converted_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        return;
    }
}
"#,
        );
    }

    #[test]
    fn ignore_already_converted_loop() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 true {
            continue;
        }
    }
}
"#,
        );
    }

    #[test]
    fn ignore_return() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        return
    }
}
"#,
        );
    }

    #[test]
    fn ignore_else_branch() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        foo();
    } else {
        bar()
    }
}
"#,
        );
    }

    #[test]
    fn ignore_statements_after_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        foo();
    }
    bar();
}
"#,
        );
    }

    #[test]
    fn ignore_statements_inside_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if false {
        if$0 true {
            foo();
        }
    }
}
"#,
        );
    }
}