]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
Rollup merge of #99460 - JanBeh:PR_asref_asmut_docs, r=joshtriplett
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / convert_to_guarded_return.rs
1 use std::iter::once;
2
3 use ide_db::syntax_helpers::node_ext::{is_pattern_cond, single_let};
4 use syntax::{
5     ast::{
6         self,
7         edit::{AstNodeEdit, IndentLevel},
8         make,
9     },
10     ted, AstNode,
11     SyntaxKind::{FN, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
12     T,
13 };
14
15 use crate::{
16     assist_context::{AssistContext, Assists},
17     utils::invert_boolean_expression,
18     AssistId, AssistKind,
19 };
20
21 // Assist: convert_to_guarded_return
22 //
23 // Replace a large conditional with a guarded return.
24 //
25 // ```
26 // fn main() {
27 //     $0if cond {
28 //         foo();
29 //         bar();
30 //     }
31 // }
32 // ```
33 // ->
34 // ```
35 // fn main() {
36 //     if !cond {
37 //         return;
38 //     }
39 //     foo();
40 //     bar();
41 // }
42 // ```
43 pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
44     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
45     if if_expr.else_branch().is_some() {
46         return None;
47     }
48
49     let cond = if_expr.condition()?;
50
51     // Check if there is an IfLet that we can handle.
52     let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) {
53         let let_ = single_let(cond)?;
54         match let_.pat() {
55             Some(ast::Pat::TupleStructPat(pat)) if pat.fields().count() == 1 => {
56                 let path = pat.path()?;
57                 if path.qualifier().is_some() {
58                     return None;
59                 }
60
61                 let bound_ident = pat.fields().next().unwrap();
62                 if !ast::IdentPat::can_cast(bound_ident.syntax().kind()) {
63                     return None;
64                 }
65
66                 (Some((path, bound_ident)), let_.expr()?)
67             }
68             _ => return None, // Unsupported IfLet.
69         }
70     } else {
71         (None, cond)
72     };
73
74     let then_block = if_expr.then_branch()?;
75     let then_block = then_block.stmt_list()?;
76
77     let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
78
79     if parent_block.tail_expr()? != if_expr.clone().into() {
80         return None;
81     }
82
83     // FIXME: This relies on untyped syntax tree and casts to much. It should be
84     // rewritten to use strongly-typed APIs.
85
86     // check for early return and continue
87     let first_in_then_block = then_block.syntax().first_child()?;
88     if ast::ReturnExpr::can_cast(first_in_then_block.kind())
89         || ast::ContinueExpr::can_cast(first_in_then_block.kind())
90         || first_in_then_block
91             .children()
92             .any(|x| ast::ReturnExpr::can_cast(x.kind()) || ast::ContinueExpr::can_cast(x.kind()))
93     {
94         return None;
95     }
96
97     let parent_container = parent_block.syntax().parent()?;
98
99     let early_expression: ast::Expr = match parent_container.kind() {
100         WHILE_EXPR | LOOP_EXPR => make::expr_continue(None),
101         FN => make::expr_return(None),
102         _ => return None,
103     };
104
105     if then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{']).is_none() {
106         return None;
107     }
108
109     then_block.syntax().last_child_or_token().filter(|t| t.kind() == T!['}'])?;
110
111     let target = if_expr.syntax().text_range();
112     acc.add(
113         AssistId("convert_to_guarded_return", AssistKind::RefactorRewrite),
114         "Convert to guarded return",
115         target,
116         |edit| {
117             let if_expr = edit.make_mut(if_expr);
118             let if_indent_level = IndentLevel::from_node(if_expr.syntax());
119             let replacement = match if_let_pat {
120                 None => {
121                     // If.
122                     let new_expr = {
123                         let then_branch =
124                             make::block_expr(once(make::expr_stmt(early_expression).into()), None);
125                         let cond = invert_boolean_expression(cond_expr);
126                         make::expr_if(cond, then_branch, None).indent(if_indent_level)
127                     };
128                     new_expr.syntax().clone_for_update()
129                 }
130                 Some((path, bound_ident)) => {
131                     // If-let.
132                     let match_expr = {
133                         let happy_arm = {
134                             let pat = make::tuple_struct_pat(
135                                 path,
136                                 once(make::ext::simple_ident_pat(make::name("it")).into()),
137                             );
138                             let expr = {
139                                 let path = make::ext::ident_path("it");
140                                 make::expr_path(path)
141                             };
142                             make::match_arm(once(pat.into()), None, expr)
143                         };
144
145                         let sad_arm = make::match_arm(
146                             // FIXME: would be cool to use `None` or `Err(_)` if appropriate
147                             once(make::wildcard_pat().into()),
148                             None,
149                             early_expression,
150                         );
151
152                         make::expr_match(cond_expr, make::match_arm_list(vec![happy_arm, sad_arm]))
153                     };
154
155                     let let_stmt = make::let_stmt(bound_ident, None, Some(match_expr));
156                     let let_stmt = let_stmt.indent(if_indent_level);
157                     let_stmt.syntax().clone_for_update()
158                 }
159             };
160
161             let then_block_items = then_block.dedent(IndentLevel(1)).clone_for_update();
162
163             let end_of_then = then_block_items.syntax().last_child_or_token().unwrap();
164             let end_of_then =
165                 if end_of_then.prev_sibling_or_token().map(|n| n.kind()) == Some(WHITESPACE) {
166                     end_of_then.prev_sibling_or_token().unwrap()
167                 } else {
168                     end_of_then
169                 };
170
171             let then_statements = replacement
172                 .children_with_tokens()
173                 .chain(
174                     then_block_items
175                         .syntax()
176                         .children_with_tokens()
177                         .skip(1)
178                         .take_while(|i| *i != end_of_then),
179                 )
180                 .collect();
181
182             ted::replace_with_many(if_expr.syntax(), then_statements)
183         },
184     )
185 }
186
187 #[cfg(test)]
188 mod tests {
189     use crate::tests::{check_assist, check_assist_not_applicable};
190
191     use super::*;
192
193     #[test]
194     fn convert_inside_fn() {
195         check_assist(
196             convert_to_guarded_return,
197             r#"
198 fn main() {
199     bar();
200     if$0 true {
201         foo();
202
203         // comment
204         bar();
205     }
206 }
207 "#,
208             r#"
209 fn main() {
210     bar();
211     if false {
212         return;
213     }
214     foo();
215
216     // comment
217     bar();
218 }
219 "#,
220         );
221     }
222
223     #[test]
224     fn convert_let_inside_fn() {
225         check_assist(
226             convert_to_guarded_return,
227             r#"
228 fn main(n: Option<String>) {
229     bar();
230     if$0 let Some(n) = n {
231         foo(n);
232
233         // comment
234         bar();
235     }
236 }
237 "#,
238             r#"
239 fn main(n: Option<String>) {
240     bar();
241     let n = match n {
242         Some(it) => it,
243         _ => return,
244     };
245     foo(n);
246
247     // comment
248     bar();
249 }
250 "#,
251         );
252     }
253
254     #[test]
255     fn convert_if_let_result() {
256         check_assist(
257             convert_to_guarded_return,
258             r#"
259 fn main() {
260     if$0 let Ok(x) = Err(92) {
261         foo(x);
262     }
263 }
264 "#,
265             r#"
266 fn main() {
267     let x = match Err(92) {
268         Ok(it) => it,
269         _ => return,
270     };
271     foo(x);
272 }
273 "#,
274         );
275     }
276
277     #[test]
278     fn convert_let_ok_inside_fn() {
279         check_assist(
280             convert_to_guarded_return,
281             r#"
282 fn main(n: Option<String>) {
283     bar();
284     if$0 let Some(n) = n {
285         foo(n);
286
287         // comment
288         bar();
289     }
290 }
291 "#,
292             r#"
293 fn main(n: Option<String>) {
294     bar();
295     let n = match n {
296         Some(it) => it,
297         _ => return,
298     };
299     foo(n);
300
301     // comment
302     bar();
303 }
304 "#,
305         );
306     }
307
308     #[test]
309     fn convert_let_mut_ok_inside_fn() {
310         check_assist(
311             convert_to_guarded_return,
312             r#"
313 fn main(n: Option<String>) {
314     bar();
315     if$0 let Some(mut n) = n {
316         foo(n);
317
318         // comment
319         bar();
320     }
321 }
322 "#,
323             r#"
324 fn main(n: Option<String>) {
325     bar();
326     let mut n = match n {
327         Some(it) => it,
328         _ => return,
329     };
330     foo(n);
331
332     // comment
333     bar();
334 }
335 "#,
336         );
337     }
338
339     #[test]
340     fn convert_let_ref_ok_inside_fn() {
341         check_assist(
342             convert_to_guarded_return,
343             r#"
344 fn main(n: Option<&str>) {
345     bar();
346     if$0 let Some(ref n) = n {
347         foo(n);
348
349         // comment
350         bar();
351     }
352 }
353 "#,
354             r#"
355 fn main(n: Option<&str>) {
356     bar();
357     let ref n = match n {
358         Some(it) => it,
359         _ => return,
360     };
361     foo(n);
362
363     // comment
364     bar();
365 }
366 "#,
367         );
368     }
369
370     #[test]
371     fn convert_inside_while() {
372         check_assist(
373             convert_to_guarded_return,
374             r#"
375 fn main() {
376     while true {
377         if$0 true {
378             foo();
379             bar();
380         }
381     }
382 }
383 "#,
384             r#"
385 fn main() {
386     while true {
387         if false {
388             continue;
389         }
390         foo();
391         bar();
392     }
393 }
394 "#,
395         );
396     }
397
398     #[test]
399     fn convert_let_inside_while() {
400         check_assist(
401             convert_to_guarded_return,
402             r#"
403 fn main() {
404     while true {
405         if$0 let Some(n) = n {
406             foo(n);
407             bar();
408         }
409     }
410 }
411 "#,
412             r#"
413 fn main() {
414     while true {
415         let n = match n {
416             Some(it) => it,
417             _ => continue,
418         };
419         foo(n);
420         bar();
421     }
422 }
423 "#,
424         );
425     }
426
427     #[test]
428     fn convert_inside_loop() {
429         check_assist(
430             convert_to_guarded_return,
431             r#"
432 fn main() {
433     loop {
434         if$0 true {
435             foo();
436             bar();
437         }
438     }
439 }
440 "#,
441             r#"
442 fn main() {
443     loop {
444         if false {
445             continue;
446         }
447         foo();
448         bar();
449     }
450 }
451 "#,
452         );
453     }
454
455     #[test]
456     fn convert_let_inside_loop() {
457         check_assist(
458             convert_to_guarded_return,
459             r#"
460 fn main() {
461     loop {
462         if$0 let Some(n) = n {
463             foo(n);
464             bar();
465         }
466     }
467 }
468 "#,
469             r#"
470 fn main() {
471     loop {
472         let n = match n {
473             Some(it) => it,
474             _ => continue,
475         };
476         foo(n);
477         bar();
478     }
479 }
480 "#,
481         );
482     }
483
484     #[test]
485     fn ignore_already_converted_if() {
486         check_assist_not_applicable(
487             convert_to_guarded_return,
488             r#"
489 fn main() {
490     if$0 true {
491         return;
492     }
493 }
494 "#,
495         );
496     }
497
498     #[test]
499     fn ignore_already_converted_loop() {
500         check_assist_not_applicable(
501             convert_to_guarded_return,
502             r#"
503 fn main() {
504     loop {
505         if$0 true {
506             continue;
507         }
508     }
509 }
510 "#,
511         );
512     }
513
514     #[test]
515     fn ignore_return() {
516         check_assist_not_applicable(
517             convert_to_guarded_return,
518             r#"
519 fn main() {
520     if$0 true {
521         return
522     }
523 }
524 "#,
525         );
526     }
527
528     #[test]
529     fn ignore_else_branch() {
530         check_assist_not_applicable(
531             convert_to_guarded_return,
532             r#"
533 fn main() {
534     if$0 true {
535         foo();
536     } else {
537         bar()
538     }
539 }
540 "#,
541         );
542     }
543
544     #[test]
545     fn ignore_statements_aftert_if() {
546         check_assist_not_applicable(
547             convert_to_guarded_return,
548             r#"
549 fn main() {
550     if$0 true {
551         foo();
552     }
553     bar();
554 }
555 "#,
556         );
557     }
558
559     #[test]
560     fn ignore_statements_inside_if() {
561         check_assist_not_applicable(
562             convert_to_guarded_return,
563             r#"
564 fn main() {
565     if false {
566         if$0 true {
567             foo();
568         }
569     }
570 }
571 "#,
572         );
573     }
574 }