]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_to_guarded_return.rs
Rollup merge of #103996 - SUPERCILEX:docs, r=RalfJung
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / convert_to_guarded_return.rs
1 use std::iter::once;
2
3 use ide_db::syntax_helpers::node_ext::{is_pattern_cond, single_let};
4 use syntax::{
5     ast::{
6         self,
7         edit::{AstNodeEdit, IndentLevel},
8         make,
9     },
10     ted, AstNode,
11     SyntaxKind::{FN, LOOP_EXPR, WHILE_EXPR, WHITESPACE},
12     T,
13 };
14
15 use crate::{
16     assist_context::{AssistContext, Assists},
17     utils::invert_boolean_expression,
18     AssistId, AssistKind,
19 };
20
21 // Assist: convert_to_guarded_return
22 //
23 // Replace a large conditional with a guarded return.
24 //
25 // ```
26 // fn main() {
27 //     $0if cond {
28 //         foo();
29 //         bar();
30 //     }
31 // }
32 // ```
33 // ->
34 // ```
35 // fn main() {
36 //     if !cond {
37 //         return;
38 //     }
39 //     foo();
40 //     bar();
41 // }
42 // ```
43 pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
44     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
45     if if_expr.else_branch().is_some() {
46         return None;
47     }
48
49     let cond = if_expr.condition()?;
50
51     // Check if there is an IfLet that we can handle.
52     let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) {
53         let let_ = single_let(cond)?;
54         match let_.pat() {
55             Some(ast::Pat::TupleStructPat(pat)) if pat.fields().count() == 1 => {
56                 let path = pat.path()?;
57                 if path.qualifier().is_some() {
58                     return None;
59                 }
60
61                 let bound_ident = pat.fields().next().unwrap();
62                 if !ast::IdentPat::can_cast(bound_ident.syntax().kind()) {
63                     return None;
64                 }
65
66                 (Some((path, bound_ident)), let_.expr()?)
67             }
68             _ => return None, // Unsupported IfLet.
69         }
70     } else {
71         (None, cond)
72     };
73
74     let then_block = if_expr.then_branch()?;
75     let then_block = then_block.stmt_list()?;
76
77     let parent_block = if_expr.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?;
78
79     if parent_block.tail_expr()? != if_expr.clone().into() {
80         return None;
81     }
82
83     // FIXME: This relies on untyped syntax tree and casts to much. It should be
84     // rewritten to use strongly-typed APIs.
85
86     // check for early return and continue
87     let first_in_then_block = then_block.syntax().first_child()?;
88     if ast::ReturnExpr::can_cast(first_in_then_block.kind())
89         || ast::ContinueExpr::can_cast(first_in_then_block.kind())
90         || first_in_then_block
91             .children()
92             .any(|x| ast::ReturnExpr::can_cast(x.kind()) || ast::ContinueExpr::can_cast(x.kind()))
93     {
94         return None;
95     }
96
97     let parent_container = parent_block.syntax().parent()?;
98
99     let early_expression: ast::Expr = match parent_container.kind() {
100         WHILE_EXPR | LOOP_EXPR => make::expr_continue(None),
101         FN => make::expr_return(None),
102         _ => return None,
103     };
104
105     if then_block.syntax().first_child_or_token().map(|t| t.kind() == T!['{']).is_none() {
106         return None;
107     }
108
109     then_block.syntax().last_child_or_token().filter(|t| t.kind() == T!['}'])?;
110
111     let target = if_expr.syntax().text_range();
112     acc.add(
113         AssistId("convert_to_guarded_return", AssistKind::RefactorRewrite),
114         "Convert to guarded return",
115         target,
116         |edit| {
117             let if_expr = edit.make_mut(if_expr);
118             let if_indent_level = IndentLevel::from_node(if_expr.syntax());
119             let replacement = match if_let_pat {
120                 None => {
121                     // If.
122                     let new_expr = {
123                         let then_branch =
124                             make::block_expr(once(make::expr_stmt(early_expression).into()), None);
125                         let cond = invert_boolean_expression(cond_expr);
126                         make::expr_if(cond, then_branch, None).indent(if_indent_level)
127                     };
128                     new_expr.syntax().clone_for_update()
129                 }
130                 Some((path, bound_ident)) => {
131                     // If-let.
132                     let pat = make::tuple_struct_pat(path, once(bound_ident));
133                     let let_else_stmt = make::let_else_stmt(
134                         pat.into(),
135                         None,
136                         cond_expr,
137                         ast::make::tail_only_block_expr(early_expression),
138                     );
139                     let let_else_stmt = let_else_stmt.indent(if_indent_level);
140                     let_else_stmt.syntax().clone_for_update()
141                 }
142             };
143
144             let then_block_items = then_block.dedent(IndentLevel(1)).clone_for_update();
145
146             let end_of_then = then_block_items.syntax().last_child_or_token().unwrap();
147             let end_of_then =
148                 if end_of_then.prev_sibling_or_token().map(|n| n.kind()) == Some(WHITESPACE) {
149                     end_of_then.prev_sibling_or_token().unwrap()
150                 } else {
151                     end_of_then
152                 };
153
154             let then_statements = replacement
155                 .children_with_tokens()
156                 .chain(
157                     then_block_items
158                         .syntax()
159                         .children_with_tokens()
160                         .skip(1)
161                         .take_while(|i| *i != end_of_then),
162                 )
163                 .collect();
164
165             ted::replace_with_many(if_expr.syntax(), then_statements)
166         },
167     )
168 }
169
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    #[test]
    fn convert_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    bar();
    if$0 true {
        foo();

        // comment
        bar();
    }
}
"#,
            r#"
fn main() {
    bar();
    if false {
        return;
    }
    foo();

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_let_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<String>) {
    bar();
    if$0 let Some(n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<String>) {
    bar();
    let Some(n) = n else { return };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_if_let_result() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 let Ok(x) = Err(92) {
        foo(x);
    }
}
"#,
            r#"
fn main() {
    let Ok(x) = Err(92) else { return };
    foo(x);
}
"#,
        );
    }

    // Previously an exact duplicate of `convert_let_inside_fn`; now actually covers `Ok`.
    #[test]
    fn convert_let_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Result<String, ()>) {
    bar();
    if$0 let Ok(n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Result<String, ()>) {
    bar();
    let Ok(n) = n else { return };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_let_mut_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<String>) {
    bar();
    if$0 let Some(mut n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<String>) {
    bar();
    let Some(mut n) = n else { return };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_let_ref_ok_inside_fn() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main(n: Option<&str>) {
    bar();
    if$0 let Some(ref n) = n {
        foo(n);

        // comment
        bar();
    }
}
"#,
            r#"
fn main(n: Option<&str>) {
    bar();
    let Some(ref n) = n else { return };
    foo(n);

    // comment
    bar();
}
"#,
        );
    }

    #[test]
    fn convert_inside_while() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    while true {
        if$0 true {
            foo();
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    while true {
        if false {
            continue;
        }
        foo();
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn convert_let_inside_while() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    while true {
        if$0 let Some(n) = n {
            foo(n);
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    while true {
        let Some(n) = n else { continue };
        foo(n);
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn convert_inside_loop() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 true {
            foo();
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    loop {
        if false {
            continue;
        }
        foo();
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn convert_let_inside_loop() {
        check_assist(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 let Some(n) = n {
            foo(n);
            bar();
        }
    }
}
"#,
            r#"
fn main() {
    loop {
        let Some(n) = n else { continue };
        foo(n);
        bar();
    }
}
"#,
        );
    }

    #[test]
    fn ignore_already_converted_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        return;
    }
}
"#,
        );
    }

    #[test]
    fn ignore_already_converted_loop() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    loop {
        if$0 true {
            continue;
        }
    }
}
"#,
        );
    }

    #[test]
    fn ignore_return() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        return
    }
}
"#,
        );
    }

    #[test]
    fn ignore_else_branch() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        foo();
    } else {
        bar()
    }
}
"#,
        );
    }

    #[test]
    fn ignore_statements_after_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if$0 true {
        foo();
    }
    bar();
}
"#,
        );
    }

    #[test]
    fn ignore_statements_inside_if() {
        check_assist_not_applicable(
            convert_to_guarded_return,
            r#"
fn main() {
    if false {
        if$0 true {
            foo();
        }
    }
}
"#,
        );
    }
}