]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_if_let_with_match.rs
Merge #9179
[rust.git] / crates / ide_assists / src / handlers / replace_if_let_with_match.rs
1 use std::iter::{self, successors};
2
3 use either::Either;
4 use ide_db::{ty_filter::TryEnum, RootDatabase};
5 use syntax::{
6     ast::{
7         self,
8         edit::{AstNodeEdit, IndentLevel},
9         make,
10     },
11     AstNode,
12 };
13
14 use crate::{
15     utils::{does_pat_match_variant, unwrap_trivial_block},
16     AssistContext, AssistId, AssistKind, Assists,
17 };
18
19 // Assist: replace_if_let_with_match
20 //
21 // Replaces an `if let` expression with a `match` expression.
22 //
23 // ```
24 // enum Action { Move { distance: u32 }, Stop }
25 //
26 // fn handle(action: Action) {
27 //     $0if let Action::Move { distance } = action {
28 //         foo(distance)
29 //     } else {
30 //         bar()
31 //     }
32 // }
33 // ```
34 // ->
35 // ```
36 // enum Action { Move { distance: u32 }, Stop }
37 //
38 // fn handle(action: Action) {
39 //     match action {
40 //         Action::Move { distance } => foo(distance),
41 //         _ => bar(),
42 //     }
43 // }
44 // ```
45 pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
46     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
47     let mut else_block = None;
48     let if_exprs = successors(Some(if_expr.clone()), |expr| match expr.else_branch()? {
49         ast::ElseBranch::IfExpr(expr) => Some(expr),
50         ast::ElseBranch::Block(block) => {
51             else_block = Some(block);
52             None
53         }
54     });
55     let scrutinee_to_be_expr = if_expr.condition()?.expr()?;
56
57     let mut pat_seen = false;
58     let mut cond_bodies = Vec::new();
59     for if_expr in if_exprs {
60         let cond = if_expr.condition()?;
61         let expr = cond.expr()?;
62         let cond = match cond.pat() {
63             Some(pat) => {
64                 if scrutinee_to_be_expr.syntax().text() != expr.syntax().text() {
65                     // Only if all condition expressions are equal we can merge them into a match
66                     return None;
67                 }
68                 pat_seen = true;
69                 Either::Left(pat)
70             }
71             None => Either::Right(expr),
72         };
73         let body = if_expr.then_branch()?;
74         cond_bodies.push((cond, body));
75     }
76
77     if !pat_seen {
78         // Don't offer turning an if (chain) without patterns into a match
79         return None;
80     }
81
82     let target = if_expr.syntax().text_range();
83     acc.add(
84         AssistId("replace_if_let_with_match", AssistKind::RefactorRewrite),
85         "Replace if let with match",
86         target,
87         move |edit| {
88             let match_expr = {
89                 let else_arm = make_else_arm(else_block, &cond_bodies, ctx);
90                 let make_match_arm = |(pat, body): (_, ast::BlockExpr)| {
91                     let body = body.reset_indent().indent(IndentLevel(1));
92                     match pat {
93                         Either::Left(pat) => {
94                             make::match_arm(iter::once(pat), None, unwrap_trivial_block(body))
95                         }
96                         Either::Right(expr) => make::match_arm(
97                             iter::once(make::wildcard_pat().into()),
98                             Some(expr),
99                             unwrap_trivial_block(body),
100                         ),
101                     }
102                 };
103                 let arms = cond_bodies.into_iter().map(make_match_arm).chain(iter::once(else_arm));
104                 let match_expr = make::expr_match(scrutinee_to_be_expr, make::match_arm_list(arms));
105                 match_expr.indent(IndentLevel::from_node(if_expr.syntax()))
106             };
107
108             let has_preceding_if_expr =
109                 if_expr.syntax().parent().map_or(false, |it| ast::IfExpr::can_cast(it.kind()));
110             let expr = if has_preceding_if_expr {
111                 // make sure we replace the `else if let ...` with a block so we don't end up with `else expr`
112                 make::block_expr(None, Some(match_expr)).into()
113             } else {
114                 match_expr
115             };
116             edit.replace_ast::<ast::Expr>(if_expr.into(), expr);
117         },
118     )
119 }
120
121 fn make_else_arm(
122     else_block: Option<ast::BlockExpr>,
123     cond_bodies: &Vec<(Either<ast::Pat, ast::Expr>, ast::BlockExpr)>,
124     ctx: &AssistContext,
125 ) -> ast::MatchArm {
126     if let Some(else_block) = else_block {
127         let pattern = if let [(Either::Left(pat), _)] = &**cond_bodies {
128             ctx.sema
129                 .type_of_pat(&pat)
130                 .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty))
131                 .zip(Some(pat))
132         } else {
133             None
134         };
135         let pattern = match pattern {
136             Some((it, pat)) => {
137                 if does_pat_match_variant(&pat, &it.sad_pattern()) {
138                     it.happy_pattern()
139                 } else {
140                     it.sad_pattern()
141                 }
142             }
143             None => make::wildcard_pat().into(),
144         };
145         make::match_arm(iter::once(pattern), None, unwrap_trivial_block(else_block))
146     } else {
147         make::match_arm(iter::once(make::wildcard_pat().into()), None, make::expr_unit().into())
148     }
149 }
150
151 // Assist: replace_match_with_if_let
152 //
153 // Replaces a binary `match` with a wildcard pattern and no guards with an `if let` expression.
154 //
155 // ```
156 // enum Action { Move { distance: u32 }, Stop }
157 //
158 // fn handle(action: Action) {
159 //     $0match action {
160 //         Action::Move { distance } => foo(distance),
161 //         _ => bar(),
162 //     }
163 // }
164 // ```
165 // ->
166 // ```
167 // enum Action { Move { distance: u32 }, Stop }
168 //
169 // fn handle(action: Action) {
170 //     if let Action::Move { distance } = action {
171 //         foo(distance)
172 //     } else {
173 //         bar()
174 //     }
175 // }
176 // ```
177 pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
178     let match_expr: ast::MatchExpr = ctx.find_node_at_offset()?;
179
180     let mut arms = match_expr.match_arm_list()?.arms();
181     let (first_arm, second_arm) = (arms.next()?, arms.next()?);
182     if arms.next().is_some() || first_arm.guard().is_some() || second_arm.guard().is_some() {
183         return None;
184     }
185
186     let (if_let_pat, then_expr, else_expr) = pick_pattern_and_expr_order(
187         &ctx.sema,
188         first_arm.pat()?,
189         second_arm.pat()?,
190         first_arm.expr()?,
191         second_arm.expr()?,
192     )?;
193     let scrutinee = match_expr.expr()?;
194
195     let target = match_expr.syntax().text_range();
196     acc.add(
197         AssistId("replace_match_with_if_let", AssistKind::RefactorRewrite),
198         "Replace match with if let",
199         target,
200         move |edit| {
201             let condition = make::condition(scrutinee, Some(if_let_pat));
202             let then_block = match then_expr.reset_indent() {
203                 ast::Expr::BlockExpr(block) => block,
204                 expr => make::block_expr(iter::empty(), Some(expr)),
205             };
206             let else_expr = match else_expr {
207                 ast::Expr::BlockExpr(block) if block.is_empty() => None,
208                 ast::Expr::TupleExpr(tuple) if tuple.fields().next().is_none() => None,
209                 expr => Some(expr),
210             };
211             let if_let_expr = make::expr_if(
212                 condition,
213                 then_block,
214                 else_expr
215                     .map(|expr| match expr {
216                         ast::Expr::BlockExpr(block) => block,
217                         expr => (make::block_expr(iter::empty(), Some(expr))),
218                     })
219                     .map(ast::ElseBranch::Block),
220             )
221             .indent(IndentLevel::from_node(match_expr.syntax()));
222
223             edit.replace_ast::<ast::Expr>(match_expr.into(), if_let_expr);
224         },
225     )
226 }
227
228 /// Pick the pattern for the if let condition and return the expressions for the `then` body and `else` body in that order.
229 fn pick_pattern_and_expr_order(
230     sema: &hir::Semantics<RootDatabase>,
231     pat: ast::Pat,
232     pat2: ast::Pat,
233     expr: ast::Expr,
234     expr2: ast::Expr,
235 ) -> Option<(ast::Pat, ast::Expr, ast::Expr)> {
236     let res = match (pat, pat2) {
237         (ast::Pat::WildcardPat(_), _) => return None,
238         (pat, sad_pat) if is_sad_pat(sema, &sad_pat) => (pat, expr, expr2),
239         (sad_pat, pat) if is_sad_pat(sema, &sad_pat) => (pat, expr2, expr),
240         (pat, pat2) => match (binds_name(&pat), binds_name(&pat2)) {
241             (true, true) => return None,
242             (true, false) => (pat, expr, expr2),
243             (false, true) => (pat2, expr2, expr),
244             (false, false) => (pat, expr, expr2),
245         },
246     };
247     Some(res)
248 }
249
250 fn binds_name(pat: &ast::Pat) -> bool {
251     let binds_name_v = |pat| binds_name(&pat);
252     match pat {
253         ast::Pat::IdentPat(_) => true,
254         ast::Pat::MacroPat(_) => true,
255         ast::Pat::OrPat(pat) => pat.pats().any(binds_name_v),
256         ast::Pat::SlicePat(pat) => pat.pats().any(binds_name_v),
257         ast::Pat::TuplePat(it) => it.fields().any(binds_name_v),
258         ast::Pat::TupleStructPat(it) => it.fields().any(binds_name_v),
259         ast::Pat::RecordPat(it) => it
260             .record_pat_field_list()
261             .map_or(false, |rpfl| rpfl.fields().flat_map(|rpf| rpf.pat()).any(binds_name_v)),
262         ast::Pat::RefPat(pat) => pat.pat().map_or(false, binds_name_v),
263         ast::Pat::BoxPat(pat) => pat.pat().map_or(false, binds_name_v),
264         ast::Pat::ParenPat(pat) => pat.pat().map_or(false, binds_name_v),
265         _ => false,
266     }
267 }
268
269 fn is_sad_pat(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
270     sema.type_of_pat(pat)
271         .and_then(|ty| TryEnum::from_ty(sema, &ty))
272         .map_or(false, |it| does_pat_match_variant(pat, &it.sad_pattern()))
273 }
274
275 #[cfg(test)]
276 mod tests {
    // Fixture-based tests for both assists. In fixtures, `$0` presumably marks the cursor
    // offset used by `find_node_at_offset`, and `//- minicore:` lines presumably pull in a
    // minimal core (`Option`/`Result`) so types resolve. Confirm against `crate::tests`.
277     use super::*;
278
279     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
280
    // `replace_if_let_with_match` is not offered on an `if` chain without any `if let`.
281     #[test]
282     fn test_if_let_with_match_unapplicable_for_simple_ifs() {
283         check_assist_not_applicable(
284             replace_if_let_with_match,
285             r#"
286 fn main() {
287     if $0true {} else if false {} else {}
288 }
289 "#,
290         )
291     }
292
    // Without an `else`, the catch-all arm is `_ => ()`.
293     #[test]
294     fn test_if_let_with_match_no_else() {
295         check_assist(
296             replace_if_let_with_match,
297             r#"
298 impl VariantData {
299     pub fn foo(&self) {
300         if $0let VariantData::Struct(..) = *self {
301             self.foo();
302         }
303     }
304 }
305 "#,
306             r#"
307 impl VariantData {
308     pub fn foo(&self) {
309         match *self {
310             VariantData::Struct(..) => {
311                 self.foo();
312             }
313             _ => (),
314         }
315     }
316 }
317 "#,
318         )
319     }
320
    // A whole chain: `if let` links become pattern arms, a plain `else if cond` becomes an
    // `_ if cond` guard arm, and the final `else` becomes `_`.
321     #[test]
322     fn test_if_let_with_match_basic() {
323         check_assist(
324             replace_if_let_with_match,
325             r#"
326 impl VariantData {
327     pub fn is_struct(&self) -> bool {
328         if $0let VariantData::Struct(..) = *self {
329             true
330         } else if let VariantData::Tuple(..) = *self {
331             false
332         } else if cond() {
333             true
334         } else {
335             bar(
336                 123
337             )
338         }
339     }
340 }
341 "#,
342             r#"
343 impl VariantData {
344     pub fn is_struct(&self) -> bool {
345         match *self {
346             VariantData::Struct(..) => true,
347             VariantData::Tuple(..) => false,
348             _ if cond() => true,
349             _ => {
350                     bar(
351                         123
352                     )
353                 }
354         }
355     }
356 }
357 "#,
358         )
359     }
360
    // With the cursor on an inner `else if let`, only that tail is converted, and the
    // resulting `match` is wrapped in a block so the outer `else` stays valid.
361     #[test]
362     fn test_if_let_with_match_on_tail_if_let() {
363         check_assist(
364             replace_if_let_with_match,
365             r#"
366 impl VariantData {
367     pub fn is_struct(&self) -> bool {
368         if let VariantData::Struct(..) = *self {
369             true
370         } else if let$0 VariantData::Tuple(..) = *self {
371             false
372         } else {
373             false
374         }
375     }
376 }
377 "#,
378             r#"
379 impl VariantData {
380     pub fn is_struct(&self) -> bool {
381         if let VariantData::Struct(..) = *self {
382             true
383         } else {
384     match *self {
385             VariantData::Tuple(..) => false,
386             _ => false,
387         }
388 }
389     }
390 }
391 "#,
392         )
393     }
394
    // A single `if let Some(..)` on an `Option` gets `None` as its else arm instead of `_`.
395     #[test]
396     fn special_case_option() {
397         check_assist(
398             replace_if_let_with_match,
399             r#"
400 //- minicore: option
401 fn foo(x: Option<i32>) {
402     $0if let Some(x) = x {
403         println!("{}", x)
404     } else {
405         println!("none")
406     }
407 }
408 "#,
409             r#"
410 fn foo(x: Option<i32>) {
411     match x {
412         Some(x) => println!("{}", x),
413         None => println!("none"),
414     }
415 }
416 "#,
417         );
418     }
419
    // `if let None` gets `Some(_)` as its else arm.
420     #[test]
421     fn special_case_inverted_option() {
422         check_assist(
423             replace_if_let_with_match,
424             r#"
425 //- minicore: option
426 fn foo(x: Option<i32>) {
427     $0if let None = x {
428         println!("none")
429     } else {
430         println!("some")
431     }
432 }
433 "#,
434             r#"
435 fn foo(x: Option<i32>) {
436     match x {
437         None => println!("none"),
438         Some(_) => println!("some"),
439     }
440 }
441 "#,
442         );
443     }
444
    // `if let Ok(..)` on a `Result` gets `Err(_)` as its else arm.
445     #[test]
446     fn special_case_result() {
447         check_assist(
448             replace_if_let_with_match,
449             r#"
450 //- minicore: result
451 fn foo(x: Result<i32, ()>) {
452     $0if let Ok(x) = x {
453         println!("{}", x)
454     } else {
455         println!("none")
456     }
457 }
458 "#,
459             r#"
460 fn foo(x: Result<i32, ()>) {
461     match x {
462         Ok(x) => println!("{}", x),
463         Err(_) => println!("none"),
464     }
465 }
466 "#,
467         );
468     }
469
    // `if let Err(..)` on a `Result` gets `Ok(_)` as its else arm.
470     #[test]
471     fn special_case_inverted_result() {
472         check_assist(
473             replace_if_let_with_match,
474             r#"
475 //- minicore: result
476 fn foo(x: Result<i32, ()>) {
477     $0if let Err(x) = x {
478         println!("{}", x)
479     } else {
480         println!("ok")
481     }
482 }
483 "#,
484             r#"
485 fn foo(x: Result<i32, ()>) {
486     match x {
487         Err(x) => println!("{}", x),
488         Ok(_) => println!("ok"),
489     }
490 }
491 "#,
492         );
493     }
494
    // The generated `match` keeps the indentation of the nested `if` it replaces, and the
    // multi-line arm body is re-indented under its arm.
495     #[test]
496     fn nested_indent() {
497         check_assist(
498             replace_if_let_with_match,
499             r#"
500 fn main() {
501     if true {
502         $0if let Ok(rel_path) = path.strip_prefix(root_path) {
503             let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
504             Some((*id, rel_path))
505         } else {
506             None
507         }
508     }
509 }
510 "#,
511             r#"
512 fn main() {
513     if true {
514         match path.strip_prefix(root_path) {
515             Ok(rel_path) => {
516                 let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
517                 Some((*id, rel_path))
518             }
519             _ => None,
520         }
521     }
522 }
523 "#,
524         )
525     }
526
    // `replace_match_with_if_let`: single-expression arm bodies become the `if`/`else` bodies.
527     #[test]
528     fn test_replace_match_with_if_let_unwraps_simple_expressions() {
529         check_assist(
530             replace_match_with_if_let,
531             r#"
532 impl VariantData {
533     pub fn is_struct(&self) -> bool {
534         $0match *self {
535             VariantData::Struct(..) => true,
536             _ => false,
537         }
538     }
539 }           "#,
540             r#"
541 impl VariantData {
542     pub fn is_struct(&self) -> bool {
543         if let VariantData::Struct(..) = *self {
544             true
545         } else {
546             false
547         }
548     }
549 }           "#,
550         )
551     }
552
    // A block arm body is reused as the `then` block rather than being wrapped again.
553     #[test]
554     fn test_replace_match_with_if_let_doesnt_unwrap_multiline_expressions() {
555         check_assist(
556             replace_match_with_if_let,
557             r#"
558 fn foo() {
559     $0match a {
560         VariantData::Struct(..) => {
561             bar(
562                 123
563             )
564         }
565         _ => false,
566     }
567 }           "#,
568             r#"
569 fn foo() {
570     if let VariantData::Struct(..) = a {
571         bar(
572             123
573         )
574     } else {
575         false
576     }
577 }           "#,
578         )
579     }
580
    // The assist's target range is the whole `match` expression.
581     #[test]
582     fn replace_match_with_if_let_target() {
583         check_assist_target(
584             replace_match_with_if_let,
585             r#"
586 impl VariantData {
587     pub fn is_struct(&self) -> bool {
588         $0match *self {
589             VariantData::Struct(..) => true,
590             _ => false,
591         }
592     }
593 }           "#,
594             r#"match *self {
595             VariantData::Struct(..) => true,
596             _ => false,
597         }"#,
598         );
599     }
600
    // `Some(x) => .., None => ..` becomes `if let Some(x) .. else ..`.
601     #[test]
602     fn special_case_option_match_to_if_let() {
603         check_assist(
604             replace_match_with_if_let,
605             r#"
606 //- minicore: option
607 fn foo(x: Option<i32>) {
608     $0match x {
609         Some(x) => println!("{}", x),
610         None => println!("none"),
611     }
612 }
613 "#,
614             r#"
615 fn foo(x: Option<i32>) {
616     if let Some(x) = x {
617         println!("{}", x)
618     } else {
619         println!("none")
620     }
621 }
622 "#,
623         );
624     }
625
    // `Ok(x) => .., Err(_) => ..` becomes `if let Ok(x) .. else ..`.
626     #[test]
627     fn special_case_result_match_to_if_let() {
628         check_assist(
629             replace_match_with_if_let,
630             r#"
631 //- minicore: result
632 fn foo(x: Result<i32, ()>) {
633     $0match x {
634         Ok(x) => println!("{}", x),
635         Err(_) => println!("none"),
636     }
637 }
638 "#,
639             r#"
640 fn foo(x: Result<i32, ()>) {
641     if let Ok(x) = x {
642         println!("{}", x)
643     } else {
644         println!("none")
645     }
646 }
647 "#,
648         );
649     }
650
    // The generated `if let` keeps the indentation of the nested `match` it replaces.
651     #[test]
652     fn nested_indent_match_to_if_let() {
653         check_assist(
654             replace_match_with_if_let,
655             r#"
656 fn main() {
657     if true {
658         $0match path.strip_prefix(root_path) {
659             Ok(rel_path) => {
660                 let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
661                 Some((*id, rel_path))
662             }
663             _ => None,
664         }
665     }
666 }
667 "#,
668             r#"
669 fn main() {
670     if true {
671         if let Ok(rel_path) = path.strip_prefix(root_path) {
672             let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
673             Some((*id, rel_path))
674         } else {
675             None
676         }
677     }
678 }
679 "#,
680         )
681     }
682
    // An arm whose body is `()` produces no `else` branch.
683     #[test]
684     fn replace_match_with_if_let_empty_wildcard_expr() {
685         check_assist(
686             replace_match_with_if_let,
687             r#"
688 fn main() {
689     $0match path.strip_prefix(root_path) {
690         Ok(rel_path) => println!("{}", rel_path),
691         _ => (),
692     }
693 }
694 "#,
695             r#"
696 fn main() {
697     if let Ok(rel_path) = path.strip_prefix(root_path) {
698         println!("{}", rel_path)
699     }
700 }
701 "#,
702         )
703     }
704
    // Two non-wildcard, non-binding arms: the first becomes the `if let`, the second the `else`.
705     #[test]
706     fn replace_match_with_if_let_exhaustive() {
707         check_assist(
708             replace_match_with_if_let,
709             r#"
710 fn print_source(def_source: ModuleSource) {
711     match def_so$0urce {
712         ModuleSource::SourceFile(..) => { println!("source file"); }
713         ModuleSource::Module(..) => { println!("module"); }
714     }
715 }
716 "#,
717             r#"
718 fn print_source(def_source: ModuleSource) {
719     if let ModuleSource::SourceFile(..) = def_source { println!("source file"); } else { println!("module"); }
720 }
721 "#,
722         )
723     }
724
    // When exactly one arm binds a name, that arm becomes the `if let`, whatever the arm order.
725     #[test]
726     fn replace_match_with_if_let_prefer_name_bind() {
727         check_assist(
728             replace_match_with_if_let,
729             r#"
730 fn foo() {
731     match $0Foo(0) {
732         Foo(_) => (),
733         Bar(bar) => println!("bar {}", bar),
734     }
735 }
736 "#,
737             r#"
738 fn foo() {
739     if let Bar(bar) = Foo(0) {
740         println!("bar {}", bar)
741     }
742 }
743 "#,
744         );
745         check_assist(
746             replace_match_with_if_let,
747             r#"
748 fn foo() {
749     match $0Foo(0) {
750         Bar(bar) => println!("bar {}", bar),
751         Foo(_) => (),
752     }
753 }
754 "#,
755             r#"
756 fn foo() {
757     if let Bar(bar) = Foo(0) {
758         println!("bar {}", bar)
759     }
760 }
761 "#,
762         );
763     }
764
    // Not offered when both arms bind names: the arm turned into `else` would lose its bindings.
765     #[test]
766     fn replace_match_with_if_let_rejects_double_name_bindings() {
767         check_assist_not_applicable(
768             replace_match_with_if_let,
769             r#"
770 fn foo() {
771     match $0Foo(0) {
772         Foo(foo) => println!("bar {}", foo),
773         Bar(bar) => println!("bar {}", bar),
774     }
775 }
776 "#,
777         );
778     }
779 }