]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_if_let_with_match.rs
Merge #9954
[rust.git] / crates / ide_assists / src / handlers / replace_if_let_with_match.rs
1 use std::iter::{self, successors};
2
3 use either::Either;
4 use ide_db::{defs::NameClass, ty_filter::TryEnum, RootDatabase};
5 use syntax::{
6     ast::{
7         self,
8         edit::{AstNodeEdit, IndentLevel},
9         make, NameOwner,
10     },
11     AstNode,
12 };
13
14 use crate::{
15     utils::{does_pat_match_variant, unwrap_trivial_block},
16     AssistContext, AssistId, AssistKind, Assists,
17 };
18
19 // Assist: replace_if_let_with_match
20 //
21 // Replaces a `if let` expression with a `match` expression.
22 //
23 // ```
24 // enum Action { Move { distance: u32 }, Stop }
25 //
26 // fn handle(action: Action) {
27 //     $0if let Action::Move { distance } = action {
28 //         foo(distance)
29 //     } else {
30 //         bar()
31 //     }
32 // }
33 // ```
34 // ->
35 // ```
36 // enum Action { Move { distance: u32 }, Stop }
37 //
38 // fn handle(action: Action) {
39 //     match action {
40 //         Action::Move { distance } => foo(distance),
41 //         _ => bar(),
42 //     }
43 // }
44 // ```
45 pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
46     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
47     let mut else_block = None;
48     let if_exprs = successors(Some(if_expr.clone()), |expr| match expr.else_branch()? {
49         ast::ElseBranch::IfExpr(expr) => Some(expr),
50         ast::ElseBranch::Block(block) => {
51             else_block = Some(block);
52             None
53         }
54     });
55     let scrutinee_to_be_expr = if_expr.condition()?.expr()?;
56
57     let mut pat_seen = false;
58     let mut cond_bodies = Vec::new();
59     for if_expr in if_exprs {
60         let cond = if_expr.condition()?;
61         let expr = cond.expr()?;
62         let cond = match cond.pat() {
63             Some(pat) => {
64                 if scrutinee_to_be_expr.syntax().text() != expr.syntax().text() {
65                     // Only if all condition expressions are equal we can merge them into a match
66                     return None;
67                 }
68                 pat_seen = true;
69                 Either::Left(pat)
70             }
71             None => Either::Right(expr),
72         };
73         let body = if_expr.then_branch()?;
74         cond_bodies.push((cond, body));
75     }
76
77     if !pat_seen {
78         // Don't offer turning an if (chain) without patterns into a match
79         return None;
80     }
81
82     let target = if_expr.syntax().text_range();
83     acc.add(
84         AssistId("replace_if_let_with_match", AssistKind::RefactorRewrite),
85         "Replace if let with match",
86         target,
87         move |edit| {
88             let match_expr = {
89                 let else_arm = make_else_arm(ctx, else_block, &cond_bodies);
90                 let make_match_arm = |(pat, body): (_, ast::BlockExpr)| {
91                     let body = body.reset_indent().indent(IndentLevel(1));
92                     match pat {
93                         Either::Left(pat) => {
94                             make::match_arm(iter::once(pat), None, unwrap_trivial_block(body))
95                         }
96                         Either::Right(expr) => make::match_arm(
97                             iter::once(make::wildcard_pat().into()),
98                             Some(expr),
99                             unwrap_trivial_block(body),
100                         ),
101                     }
102                 };
103                 let arms = cond_bodies.into_iter().map(make_match_arm).chain(iter::once(else_arm));
104                 let match_expr = make::expr_match(scrutinee_to_be_expr, make::match_arm_list(arms));
105                 match_expr.indent(IndentLevel::from_node(if_expr.syntax()))
106             };
107
108             let has_preceding_if_expr =
109                 if_expr.syntax().parent().map_or(false, |it| ast::IfExpr::can_cast(it.kind()));
110             let expr = if has_preceding_if_expr {
111                 // make sure we replace the `else if let ...` with a block so we don't end up with `else expr`
112                 make::block_expr(None, Some(match_expr)).into()
113             } else {
114                 match_expr
115             };
116             edit.replace_ast::<ast::Expr>(if_expr.into(), expr);
117         },
118     )
119 }
120
121 fn make_else_arm(
122     ctx: &AssistContext,
123     else_block: Option<ast::BlockExpr>,
124     conditionals: &[(Either<ast::Pat, ast::Expr>, ast::BlockExpr)],
125 ) -> ast::MatchArm {
126     if let Some(else_block) = else_block {
127         let pattern = if let [(Either::Left(pat), _)] = conditionals {
128             ctx.sema
129                 .type_of_pat(&pat)
130                 .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted()))
131                 .zip(Some(pat))
132         } else {
133             None
134         };
135         let pattern = match pattern {
136             Some((it, pat)) => {
137                 if does_pat_match_variant(&pat, &it.sad_pattern()) {
138                     it.happy_pattern()
139                 } else {
140                     it.sad_pattern()
141                 }
142             }
143             None => make::wildcard_pat().into(),
144         };
145         make::match_arm(iter::once(pattern), None, unwrap_trivial_block(else_block))
146     } else {
147         make::match_arm(iter::once(make::wildcard_pat().into()), None, make::expr_unit().into())
148     }
149 }
150
151 // Assist: replace_match_with_if_let
152 //
153 // Replaces a binary `match` with a wildcard pattern and no guards with an `if let` expression.
154 //
155 // ```
156 // enum Action { Move { distance: u32 }, Stop }
157 //
158 // fn handle(action: Action) {
159 //     $0match action {
160 //         Action::Move { distance } => foo(distance),
161 //         _ => bar(),
162 //     }
163 // }
164 // ```
165 // ->
166 // ```
167 // enum Action { Move { distance: u32 }, Stop }
168 //
169 // fn handle(action: Action) {
170 //     if let Action::Move { distance } = action {
171 //         foo(distance)
172 //     } else {
173 //         bar()
174 //     }
175 // }
176 // ```
177 pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
178     let match_expr: ast::MatchExpr = ctx.find_node_at_offset()?;
179
180     let mut arms = match_expr.match_arm_list()?.arms();
181     let (first_arm, second_arm) = (arms.next()?, arms.next()?);
182     if arms.next().is_some() || first_arm.guard().is_some() || second_arm.guard().is_some() {
183         return None;
184     }
185
186     let (if_let_pat, then_expr, else_expr) = pick_pattern_and_expr_order(
187         &ctx.sema,
188         first_arm.pat()?,
189         second_arm.pat()?,
190         first_arm.expr()?,
191         second_arm.expr()?,
192     )?;
193     let scrutinee = match_expr.expr()?;
194
195     let target = match_expr.syntax().text_range();
196     acc.add(
197         AssistId("replace_match_with_if_let", AssistKind::RefactorRewrite),
198         "Replace match with if let",
199         target,
200         move |edit| {
201             let condition = make::condition(scrutinee, Some(if_let_pat));
202             let then_block = match then_expr.reset_indent() {
203                 ast::Expr::BlockExpr(block) => block,
204                 expr => make::block_expr(iter::empty(), Some(expr)),
205             };
206             let else_expr = match else_expr {
207                 ast::Expr::BlockExpr(block) if block.is_empty() => None,
208                 ast::Expr::TupleExpr(tuple) if tuple.fields().next().is_none() => None,
209                 expr => Some(expr),
210             };
211             let if_let_expr = make::expr_if(
212                 condition,
213                 then_block,
214                 else_expr
215                     .map(|expr| match expr {
216                         ast::Expr::BlockExpr(block) => block,
217                         expr => (make::block_expr(iter::empty(), Some(expr))),
218                     })
219                     .map(ast::ElseBranch::Block),
220             )
221             .indent(IndentLevel::from_node(match_expr.syntax()));
222
223             edit.replace_ast::<ast::Expr>(match_expr.into(), if_let_expr);
224         },
225     )
226 }
227
228 /// Pick the pattern for the if let condition and return the expressions for the `then` body and `else` body in that order.
229 fn pick_pattern_and_expr_order(
230     sema: &hir::Semantics<RootDatabase>,
231     pat: ast::Pat,
232     pat2: ast::Pat,
233     expr: ast::Expr,
234     expr2: ast::Expr,
235 ) -> Option<(ast::Pat, ast::Expr, ast::Expr)> {
236     let res = match (pat, pat2) {
237         (ast::Pat::WildcardPat(_), _) => return None,
238         (pat, _) if is_empty_expr(&expr2) => (pat, expr, expr2),
239         (_, pat) if is_empty_expr(&expr) => (pat, expr2, expr),
240         (pat, pat2) => match (binds_name(sema, &pat), binds_name(sema, &pat2)) {
241             (true, true) => return None,
242             (true, false) => (pat, expr, expr2),
243             (false, true) => (pat2, expr2, expr),
244             _ if is_sad_pat(sema, &pat) => (pat2, expr2, expr),
245             (false, false) => (pat, expr, expr2),
246         },
247     };
248     Some(res)
249 }
250
251 fn is_empty_expr(expr: &ast::Expr) -> bool {
252     match expr {
253         ast::Expr::BlockExpr(expr) => expr.is_empty(),
254         ast::Expr::TupleExpr(expr) => expr.fields().next().is_none(),
255         _ => false,
256     }
257 }
258
259 fn binds_name(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
260     let binds_name_v = |pat| binds_name(&sema, &pat);
261     match pat {
262         ast::Pat::IdentPat(pat) => !matches!(
263             pat.name().and_then(|name| NameClass::classify(sema, &name)),
264             Some(NameClass::ConstReference(_))
265         ),
266         ast::Pat::MacroPat(_) => true,
267         ast::Pat::OrPat(pat) => pat.pats().any(binds_name_v),
268         ast::Pat::SlicePat(pat) => pat.pats().any(binds_name_v),
269         ast::Pat::TuplePat(it) => it.fields().any(binds_name_v),
270         ast::Pat::TupleStructPat(it) => it.fields().any(binds_name_v),
271         ast::Pat::RecordPat(it) => it
272             .record_pat_field_list()
273             .map_or(false, |rpfl| rpfl.fields().flat_map(|rpf| rpf.pat()).any(binds_name_v)),
274         ast::Pat::RefPat(pat) => pat.pat().map_or(false, binds_name_v),
275         ast::Pat::BoxPat(pat) => pat.pat().map_or(false, binds_name_v),
276         ast::Pat::ParenPat(pat) => pat.pat().map_or(false, binds_name_v),
277         _ => false,
278     }
279 }
280
281 fn is_sad_pat(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
282     sema.type_of_pat(pat)
283         .and_then(|ty| TryEnum::from_ty(sema, &ty.adjusted()))
284         .map_or(false, |it| does_pat_match_variant(pat, &it.sad_pattern()))
285 }
286
287 #[cfg(test)]
288 mod tests {
289     use super::*;
290
291     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
292
293     #[test]
294     fn test_if_let_with_match_unapplicable_for_simple_ifs() {
295         check_assist_not_applicable(
296             replace_if_let_with_match,
297             r#"
298 fn main() {
299     if $0true {} else if false {} else {}
300 }
301 "#,
302         )
303     }
304
305     #[test]
306     fn test_if_let_with_match_no_else() {
307         check_assist(
308             replace_if_let_with_match,
309             r#"
310 impl VariantData {
311     pub fn foo(&self) {
312         if $0let VariantData::Struct(..) = *self {
313             self.foo();
314         }
315     }
316 }
317 "#,
318             r#"
319 impl VariantData {
320     pub fn foo(&self) {
321         match *self {
322             VariantData::Struct(..) => {
323                 self.foo();
324             }
325             _ => (),
326         }
327     }
328 }
329 "#,
330         )
331     }
332
333     #[test]
334     fn test_if_let_with_match_basic() {
335         check_assist(
336             replace_if_let_with_match,
337             r#"
338 impl VariantData {
339     pub fn is_struct(&self) -> bool {
340         if $0let VariantData::Struct(..) = *self {
341             true
342         } else if let VariantData::Tuple(..) = *self {
343             false
344         } else if cond() {
345             true
346         } else {
347             bar(
348                 123
349             )
350         }
351     }
352 }
353 "#,
354             r#"
355 impl VariantData {
356     pub fn is_struct(&self) -> bool {
357         match *self {
358             VariantData::Struct(..) => true,
359             VariantData::Tuple(..) => false,
360             _ if cond() => true,
361             _ => {
362                     bar(
363                         123
364                     )
365                 }
366         }
367     }
368 }
369 "#,
370         )
371     }
372
373     #[test]
374     fn test_if_let_with_match_on_tail_if_let() {
375         check_assist(
376             replace_if_let_with_match,
377             r#"
378 impl VariantData {
379     pub fn is_struct(&self) -> bool {
380         if let VariantData::Struct(..) = *self {
381             true
382         } else if let$0 VariantData::Tuple(..) = *self {
383             false
384         } else {
385             false
386         }
387     }
388 }
389 "#,
390             r#"
391 impl VariantData {
392     pub fn is_struct(&self) -> bool {
393         if let VariantData::Struct(..) = *self {
394             true
395         } else {
396     match *self {
397             VariantData::Tuple(..) => false,
398             _ => false,
399         }
400 }
401     }
402 }
403 "#,
404         )
405     }
406
407     #[test]
408     fn special_case_option() {
409         check_assist(
410             replace_if_let_with_match,
411             r#"
412 //- minicore: option
413 fn foo(x: Option<i32>) {
414     $0if let Some(x) = x {
415         println!("{}", x)
416     } else {
417         println!("none")
418     }
419 }
420 "#,
421             r#"
422 fn foo(x: Option<i32>) {
423     match x {
424         Some(x) => println!("{}", x),
425         None => println!("none"),
426     }
427 }
428 "#,
429         );
430     }
431
432     #[test]
433     fn special_case_inverted_option() {
434         check_assist(
435             replace_if_let_with_match,
436             r#"
437 //- minicore: option
438 fn foo(x: Option<i32>) {
439     $0if let None = x {
440         println!("none")
441     } else {
442         println!("some")
443     }
444 }
445 "#,
446             r#"
447 fn foo(x: Option<i32>) {
448     match x {
449         None => println!("none"),
450         Some(_) => println!("some"),
451     }
452 }
453 "#,
454         );
455     }
456
457     #[test]
458     fn special_case_result() {
459         check_assist(
460             replace_if_let_with_match,
461             r#"
462 //- minicore: result
463 fn foo(x: Result<i32, ()>) {
464     $0if let Ok(x) = x {
465         println!("{}", x)
466     } else {
467         println!("none")
468     }
469 }
470 "#,
471             r#"
472 fn foo(x: Result<i32, ()>) {
473     match x {
474         Ok(x) => println!("{}", x),
475         Err(_) => println!("none"),
476     }
477 }
478 "#,
479         );
480     }
481
482     #[test]
483     fn special_case_inverted_result() {
484         check_assist(
485             replace_if_let_with_match,
486             r#"
487 //- minicore: result
488 fn foo(x: Result<i32, ()>) {
489     $0if let Err(x) = x {
490         println!("{}", x)
491     } else {
492         println!("ok")
493     }
494 }
495 "#,
496             r#"
497 fn foo(x: Result<i32, ()>) {
498     match x {
499         Err(x) => println!("{}", x),
500         Ok(_) => println!("ok"),
501     }
502 }
503 "#,
504         );
505     }
506
507     #[test]
508     fn nested_indent() {
509         check_assist(
510             replace_if_let_with_match,
511             r#"
512 fn main() {
513     if true {
514         $0if let Ok(rel_path) = path.strip_prefix(root_path) {
515             let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
516             Some((*id, rel_path))
517         } else {
518             None
519         }
520     }
521 }
522 "#,
523             r#"
524 fn main() {
525     if true {
526         match path.strip_prefix(root_path) {
527             Ok(rel_path) => {
528                 let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
529                 Some((*id, rel_path))
530             }
531             _ => None,
532         }
533     }
534 }
535 "#,
536         )
537     }
538
539     #[test]
540     fn test_replace_match_with_if_let_unwraps_simple_expressions() {
541         check_assist(
542             replace_match_with_if_let,
543             r#"
544 impl VariantData {
545     pub fn is_struct(&self) -> bool {
546         $0match *self {
547             VariantData::Struct(..) => true,
548             _ => false,
549         }
550     }
551 }           "#,
552             r#"
553 impl VariantData {
554     pub fn is_struct(&self) -> bool {
555         if let VariantData::Struct(..) = *self {
556             true
557         } else {
558             false
559         }
560     }
561 }           "#,
562         )
563     }
564
565     #[test]
566     fn test_replace_match_with_if_let_doesnt_unwrap_multiline_expressions() {
567         check_assist(
568             replace_match_with_if_let,
569             r#"
570 fn foo() {
571     $0match a {
572         VariantData::Struct(..) => {
573             bar(
574                 123
575             )
576         }
577         _ => false,
578     }
579 }           "#,
580             r#"
581 fn foo() {
582     if let VariantData::Struct(..) = a {
583         bar(
584             123
585         )
586     } else {
587         false
588     }
589 }           "#,
590         )
591     }
592
593     #[test]
594     fn replace_match_with_if_let_target() {
595         check_assist_target(
596             replace_match_with_if_let,
597             r#"
598 impl VariantData {
599     pub fn is_struct(&self) -> bool {
600         $0match *self {
601             VariantData::Struct(..) => true,
602             _ => false,
603         }
604     }
605 }           "#,
606             r#"match *self {
607             VariantData::Struct(..) => true,
608             _ => false,
609         }"#,
610         );
611     }
612
613     #[test]
614     fn special_case_option_match_to_if_let() {
615         check_assist(
616             replace_match_with_if_let,
617             r#"
618 //- minicore: option
619 fn foo(x: Option<i32>) {
620     $0match x {
621         Some(x) => println!("{}", x),
622         None => println!("none"),
623     }
624 }
625 "#,
626             r#"
627 fn foo(x: Option<i32>) {
628     if let Some(x) = x {
629         println!("{}", x)
630     } else {
631         println!("none")
632     }
633 }
634 "#,
635         );
636     }
637
638     #[test]
639     fn special_case_result_match_to_if_let() {
640         check_assist(
641             replace_match_with_if_let,
642             r#"
643 //- minicore: result
644 fn foo(x: Result<i32, ()>) {
645     $0match x {
646         Ok(x) => println!("{}", x),
647         Err(_) => println!("none"),
648     }
649 }
650 "#,
651             r#"
652 fn foo(x: Result<i32, ()>) {
653     if let Ok(x) = x {
654         println!("{}", x)
655     } else {
656         println!("none")
657     }
658 }
659 "#,
660         );
661     }
662
663     #[test]
664     fn nested_indent_match_to_if_let() {
665         check_assist(
666             replace_match_with_if_let,
667             r#"
668 fn main() {
669     if true {
670         $0match path.strip_prefix(root_path) {
671             Ok(rel_path) => {
672                 let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
673                 Some((*id, rel_path))
674             }
675             _ => None,
676         }
677     }
678 }
679 "#,
680             r#"
681 fn main() {
682     if true {
683         if let Ok(rel_path) = path.strip_prefix(root_path) {
684             let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
685             Some((*id, rel_path))
686         } else {
687             None
688         }
689     }
690 }
691 "#,
692         )
693     }
694
695     #[test]
696     fn replace_match_with_if_let_empty_wildcard_expr() {
697         check_assist(
698             replace_match_with_if_let,
699             r#"
700 fn main() {
701     $0match path.strip_prefix(root_path) {
702         Ok(rel_path) => println!("{}", rel_path),
703         _ => (),
704     }
705 }
706 "#,
707             r#"
708 fn main() {
709     if let Ok(rel_path) = path.strip_prefix(root_path) {
710         println!("{}", rel_path)
711     }
712 }
713 "#,
714         )
715     }
716
717     #[test]
718     fn replace_match_with_if_let_number_body() {
719         check_assist(
720             replace_match_with_if_let,
721             r#"
722 fn main() {
723     $0match Ok(()) {
724         Ok(()) => {},
725         Err(_) => 0,
726     }
727 }
728 "#,
729             r#"
730 fn main() {
731     if let Err(_) = Ok(()) {
732         0
733     }
734 }
735 "#,
736         )
737     }
738
739     #[test]
740     fn replace_match_with_if_let_exhaustive() {
741         check_assist(
742             replace_match_with_if_let,
743             r#"
744 fn print_source(def_source: ModuleSource) {
745     match def_so$0urce {
746         ModuleSource::SourceFile(..) => { println!("source file"); }
747         ModuleSource::Module(..) => { println!("module"); }
748     }
749 }
750 "#,
751             r#"
752 fn print_source(def_source: ModuleSource) {
753     if let ModuleSource::SourceFile(..) = def_source { println!("source file"); } else { println!("module"); }
754 }
755 "#,
756         )
757     }
758
759     #[test]
760     fn replace_match_with_if_let_prefer_name_bind() {
761         check_assist(
762             replace_match_with_if_let,
763             r#"
764 fn foo() {
765     match $0Foo(0) {
766         Foo(_) => (),
767         Bar(bar) => println!("bar {}", bar),
768     }
769 }
770 "#,
771             r#"
772 fn foo() {
773     if let Bar(bar) = Foo(0) {
774         println!("bar {}", bar)
775     }
776 }
777 "#,
778         );
779         check_assist(
780             replace_match_with_if_let,
781             r#"
782 fn foo() {
783     match $0Foo(0) {
784         Bar(bar) => println!("bar {}", bar),
785         Foo(_) => (),
786     }
787 }
788 "#,
789             r#"
790 fn foo() {
791     if let Bar(bar) = Foo(0) {
792         println!("bar {}", bar)
793     }
794 }
795 "#,
796         );
797     }
798
799     #[test]
800     fn replace_match_with_if_let_prefer_nonempty_body() {
801         check_assist(
802             replace_match_with_if_let,
803             r#"
804 fn foo() {
805     match $0Ok(0) {
806         Ok(value) => {},
807         Err(err) => eprintln!("{}", err),
808     }
809 }
810 "#,
811             r#"
812 fn foo() {
813     if let Err(err) = Ok(0) {
814         eprintln!("{}", err)
815     }
816 }
817 "#,
818         );
819         check_assist(
820             replace_match_with_if_let,
821             r#"
822 fn foo() {
823     match $0Ok(0) {
824         Err(err) => eprintln!("{}", err),
825         Ok(value) => {},
826     }
827 }
828 "#,
829             r#"
830 fn foo() {
831     if let Err(err) = Ok(0) {
832         eprintln!("{}", err)
833     }
834 }
835 "#,
836         );
837     }
838
839     #[test]
840     fn replace_match_with_if_let_rejects_double_name_bindings() {
841         check_assist_not_applicable(
842             replace_match_with_if_let,
843             r#"
844 fn foo() {
845     match $0Foo(0) {
846         Foo(foo) => println!("bar {}", foo),
847         Bar(bar) => println!("bar {}", bar),
848     }
849 }
850 "#,
851         );
852     }
853 }