]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_if_let_with_match.rs
Add empty-body check to replace_match_with_if_let and re-prioritize choices
[rust.git] / crates / ide_assists / src / handlers / replace_if_let_with_match.rs
1 use std::iter::{self, successors};
2
3 use either::Either;
4 use ide_db::{ty_filter::TryEnum, RootDatabase};
5 use syntax::{
6     ast::{
7         self,
8         edit::{AstNodeEdit, IndentLevel},
9         make,
10     },
11     AstNode,
12 };
13
14 use crate::{
15     utils::{does_pat_match_variant, unwrap_trivial_block},
16     AssistContext, AssistId, AssistKind, Assists,
17 };
18
19 // Assist: replace_if_let_with_match
20 //
21 // Replaces a `if let` expression with a `match` expression.
22 //
23 // ```
24 // enum Action { Move { distance: u32 }, Stop }
25 //
26 // fn handle(action: Action) {
27 //     $0if let Action::Move { distance } = action {
28 //         foo(distance)
29 //     } else {
30 //         bar()
31 //     }
32 // }
33 // ```
34 // ->
35 // ```
36 // enum Action { Move { distance: u32 }, Stop }
37 //
38 // fn handle(action: Action) {
39 //     match action {
40 //         Action::Move { distance } => foo(distance),
41 //         _ => bar(),
42 //     }
43 // }
44 // ```
45 pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
46     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
47     let mut else_block = None;
48     let if_exprs = successors(Some(if_expr.clone()), |expr| match expr.else_branch()? {
49         ast::ElseBranch::IfExpr(expr) => Some(expr),
50         ast::ElseBranch::Block(block) => {
51             else_block = Some(block);
52             None
53         }
54     });
55     let scrutinee_to_be_expr = if_expr.condition()?.expr()?;
56
57     let mut pat_seen = false;
58     let mut cond_bodies = Vec::new();
59     for if_expr in if_exprs {
60         let cond = if_expr.condition()?;
61         let expr = cond.expr()?;
62         let cond = match cond.pat() {
63             Some(pat) => {
64                 if scrutinee_to_be_expr.syntax().text() != expr.syntax().text() {
65                     // Only if all condition expressions are equal we can merge them into a match
66                     return None;
67                 }
68                 pat_seen = true;
69                 Either::Left(pat)
70             }
71             None => Either::Right(expr),
72         };
73         let body = if_expr.then_branch()?;
74         cond_bodies.push((cond, body));
75     }
76
77     if !pat_seen {
78         // Don't offer turning an if (chain) without patterns into a match
79         return None;
80     }
81
82     let target = if_expr.syntax().text_range();
83     acc.add(
84         AssistId("replace_if_let_with_match", AssistKind::RefactorRewrite),
85         "Replace if let with match",
86         target,
87         move |edit| {
88             let match_expr = {
89                 let else_arm = make_else_arm(ctx, else_block, &cond_bodies);
90                 let make_match_arm = |(pat, body): (_, ast::BlockExpr)| {
91                     let body = body.reset_indent().indent(IndentLevel(1));
92                     match pat {
93                         Either::Left(pat) => {
94                             make::match_arm(iter::once(pat), None, unwrap_trivial_block(body))
95                         }
96                         Either::Right(expr) => make::match_arm(
97                             iter::once(make::wildcard_pat().into()),
98                             Some(expr),
99                             unwrap_trivial_block(body),
100                         ),
101                     }
102                 };
103                 let arms = cond_bodies.into_iter().map(make_match_arm).chain(iter::once(else_arm));
104                 let match_expr = make::expr_match(scrutinee_to_be_expr, make::match_arm_list(arms));
105                 match_expr.indent(IndentLevel::from_node(if_expr.syntax()))
106             };
107
108             let has_preceding_if_expr =
109                 if_expr.syntax().parent().map_or(false, |it| ast::IfExpr::can_cast(it.kind()));
110             let expr = if has_preceding_if_expr {
111                 // make sure we replace the `else if let ...` with a block so we don't end up with `else expr`
112                 make::block_expr(None, Some(match_expr)).into()
113             } else {
114                 match_expr
115             };
116             edit.replace_ast::<ast::Expr>(if_expr.into(), expr);
117         },
118     )
119 }
120
121 fn make_else_arm(
122     ctx: &AssistContext,
123     else_block: Option<ast::BlockExpr>,
124     conditionals: &[(Either<ast::Pat, ast::Expr>, ast::BlockExpr)],
125 ) -> ast::MatchArm {
126     if let Some(else_block) = else_block {
127         let pattern = if let [(Either::Left(pat), _)] = conditionals {
128             ctx.sema
129                 .type_of_pat(&pat)
130                 .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted()))
131                 .zip(Some(pat))
132         } else {
133             None
134         };
135         let pattern = match pattern {
136             Some((it, pat)) => {
137                 if does_pat_match_variant(&pat, &it.sad_pattern()) {
138                     it.happy_pattern()
139                 } else {
140                     it.sad_pattern()
141                 }
142             }
143             None => make::wildcard_pat().into(),
144         };
145         make::match_arm(iter::once(pattern), None, unwrap_trivial_block(else_block))
146     } else {
147         make::match_arm(iter::once(make::wildcard_pat().into()), None, make::expr_unit().into())
148     }
149 }
150
151 // Assist: replace_match_with_if_let
152 //
153 // Replaces a binary `match` with a wildcard pattern and no guards with an `if let` expression.
154 //
155 // ```
156 // enum Action { Move { distance: u32 }, Stop }
157 //
158 // fn handle(action: Action) {
159 //     $0match action {
160 //         Action::Move { distance } => foo(distance),
161 //         _ => bar(),
162 //     }
163 // }
164 // ```
165 // ->
166 // ```
167 // enum Action { Move { distance: u32 }, Stop }
168 //
169 // fn handle(action: Action) {
170 //     if let Action::Move { distance } = action {
171 //         foo(distance)
172 //     } else {
173 //         bar()
174 //     }
175 // }
176 // ```
177 pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
178     let match_expr: ast::MatchExpr = ctx.find_node_at_offset()?;
179
180     let mut arms = match_expr.match_arm_list()?.arms();
181     let (first_arm, second_arm) = (arms.next()?, arms.next()?);
182     if arms.next().is_some() || first_arm.guard().is_some() || second_arm.guard().is_some() {
183         return None;
184     }
185
186     let (if_let_pat, then_expr, else_expr) = pick_pattern_and_expr_order(
187         &ctx.sema,
188         first_arm.pat()?,
189         second_arm.pat()?,
190         first_arm.expr()?,
191         second_arm.expr()?,
192     )?;
193     let scrutinee = match_expr.expr()?;
194
195     let target = match_expr.syntax().text_range();
196     acc.add(
197         AssistId("replace_match_with_if_let", AssistKind::RefactorRewrite),
198         "Replace match with if let",
199         target,
200         move |edit| {
201             let condition = make::condition(scrutinee, Some(if_let_pat));
202             let then_block = match then_expr.reset_indent() {
203                 ast::Expr::BlockExpr(block) => block,
204                 expr => make::block_expr(iter::empty(), Some(expr)),
205             };
206             let else_expr = match else_expr {
207                 ast::Expr::BlockExpr(block) if block.is_empty() => None,
208                 ast::Expr::TupleExpr(tuple) if tuple.fields().next().is_none() => None,
209                 expr => Some(expr),
210             };
211             let if_let_expr = make::expr_if(
212                 condition,
213                 then_block,
214                 else_expr
215                     .map(|expr| match expr {
216                         ast::Expr::BlockExpr(block) => block,
217                         expr => (make::block_expr(iter::empty(), Some(expr))),
218                     })
219                     .map(ast::ElseBranch::Block),
220             )
221             .indent(IndentLevel::from_node(match_expr.syntax()));
222
223             edit.replace_ast::<ast::Expr>(match_expr.into(), if_let_expr);
224         },
225     )
226 }
227
228 /// Pick the pattern for the if let condition and return the expressions for the `then` body and `else` body in that order.
229 fn pick_pattern_and_expr_order(
230     sema: &hir::Semantics<RootDatabase>,
231     pat: ast::Pat,
232     pat2: ast::Pat,
233     expr: ast::Expr,
234     expr2: ast::Expr,
235 ) -> Option<(ast::Pat, ast::Expr, ast::Expr)> {
236     let res = match (pat, pat2) {
237         (ast::Pat::WildcardPat(_), _) => return None,
238         (pat, _) if expr2.syntax().first_child().is_none() => (pat, expr, expr2),
239         (_, pat) if expr.syntax().first_child().is_none() => (pat, expr2, expr),
240         (pat, pat2) => match (binds_name(&pat), binds_name(&pat2)) {
241             (true, false) => (pat, expr, expr2),
242             (false, true) => (pat2, expr2, expr),
243             _ if is_sad_pat(sema, &pat2) => (pat, expr, expr2),
244             _ if is_sad_pat(sema, &pat) => (pat2, expr2, expr),
245             (true, true) => return None,
246             (false, false) => (pat, expr, expr2),
247         },
248     };
249     Some(res)
250 }
251
252 fn binds_name(pat: &ast::Pat) -> bool {
253     let binds_name_v = |pat| binds_name(&pat);
254     match pat {
255         ast::Pat::IdentPat(_) => true,
256         ast::Pat::MacroPat(_) => true,
257         ast::Pat::OrPat(pat) => pat.pats().any(binds_name_v),
258         ast::Pat::SlicePat(pat) => pat.pats().any(binds_name_v),
259         ast::Pat::TuplePat(it) => it.fields().any(binds_name_v),
260         ast::Pat::TupleStructPat(it) => it.fields().any(binds_name_v),
261         ast::Pat::RecordPat(it) => it
262             .record_pat_field_list()
263             .map_or(false, |rpfl| rpfl.fields().flat_map(|rpf| rpf.pat()).any(binds_name_v)),
264         ast::Pat::RefPat(pat) => pat.pat().map_or(false, binds_name_v),
265         ast::Pat::BoxPat(pat) => pat.pat().map_or(false, binds_name_v),
266         ast::Pat::ParenPat(pat) => pat.pat().map_or(false, binds_name_v),
267         _ => false,
268     }
269 }
270
271 fn is_sad_pat(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
272     sema.type_of_pat(pat)
273         .and_then(|ty| TryEnum::from_ty(sema, &ty.adjusted()))
274         .map_or(false, |it| does_pat_match_variant(pat, &it.sad_pattern()))
275 }
276
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};

    // An `if` chain without any `if let` must not be offered the assist.
    #[test]
    fn test_if_let_with_match_unapplicable_for_simple_ifs() {
        check_assist_not_applicable(
            replace_if_let_with_match,
            r#"
fn main() {
    if $0true {} else if false {} else {}
}
"#,
        )
    }

    // A lone `if let` without `else` gets a `_ => ()` fallback arm.
    #[test]
    fn test_if_let_with_match_no_else() {
        check_assist(
            replace_if_let_with_match,
            r#"
impl VariantData {
    pub fn foo(&self) {
        if $0let VariantData::Struct(..) = *self {
            self.foo();
        }
    }
}
"#,
            r#"
impl VariantData {
    pub fn foo(&self) {
        match *self {
            VariantData::Struct(..) => {
                self.foo();
            }
            _ => (),
        }
    }
}
"#,
        )
    }

    // A mixed chain: `if let` branches become pattern arms, a plain `else if cond()`
    // becomes a guarded wildcard arm, and the final `else` becomes `_`.
    // The expected output pins the current indentation of the multi-line else body.
    #[test]
    fn test_if_let_with_match_basic() {
        check_assist(
            replace_if_let_with_match,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        if $0let VariantData::Struct(..) = *self {
            true
        } else if let VariantData::Tuple(..) = *self {
            false
        } else if cond() {
            true
        } else {
            bar(
                123
            )
        }
    }
}
"#,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        match *self {
            VariantData::Struct(..) => true,
            VariantData::Tuple(..) => false,
            _ if cond() => true,
            _ => {
                    bar(
                        123
                    )
                }
        }
    }
}
"#,
        )
    }

    // Invoking the assist on an `else if let` converts only the tail of the chain and
    // wraps the generated `match` in a block. The expected output pins the current
    // indentation of that block.
    #[test]
    fn test_if_let_with_match_on_tail_if_let() {
        check_assist(
            replace_if_let_with_match,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        if let VariantData::Struct(..) = *self {
            true
        } else if let$0 VariantData::Tuple(..) = *self {
            false
        } else {
            false
        }
    }
}
"#,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        if let VariantData::Struct(..) = *self {
            true
        } else {
    match *self {
            VariantData::Tuple(..) => false,
            _ => false,
        }
}
    }
}
"#,
        )
    }

    // For `Option`, the else arm uses `None` instead of a wildcard.
    #[test]
    fn special_case_option() {
        check_assist(
            replace_if_let_with_match,
            r#"
//- minicore: option
fn foo(x: Option<i32>) {
    $0if let Some(x) = x {
        println!("{}", x)
    } else {
        println!("none")
    }
}
"#,
            r#"
fn foo(x: Option<i32>) {
    match x {
        Some(x) => println!("{}", x),
        None => println!("none"),
    }
}
"#,
        );
    }

    // When the `if let` tests for `None`, the else arm becomes `Some(_)`.
    #[test]
    fn special_case_inverted_option() {
        check_assist(
            replace_if_let_with_match,
            r#"
//- minicore: option
fn foo(x: Option<i32>) {
    $0if let None = x {
        println!("none")
    } else {
        println!("some")
    }
}
"#,
            r#"
fn foo(x: Option<i32>) {
    match x {
        None => println!("none"),
        Some(_) => println!("some"),
    }
}
"#,
        );
    }

    // For `Result`, the else arm uses `Err(_)` instead of a wildcard.
    #[test]
    fn special_case_result() {
        check_assist(
            replace_if_let_with_match,
            r#"
//- minicore: result
fn foo(x: Result<i32, ()>) {
    $0if let Ok(x) = x {
        println!("{}", x)
    } else {
        println!("none")
    }
}
"#,
            r#"
fn foo(x: Result<i32, ()>) {
    match x {
        Ok(x) => println!("{}", x),
        Err(_) => println!("none"),
    }
}
"#,
        );
    }

    // When the `if let` tests for `Err`, the else arm becomes `Ok(_)`.
    #[test]
    fn special_case_inverted_result() {
        check_assist(
            replace_if_let_with_match,
            r#"
//- minicore: result
fn foo(x: Result<i32, ()>) {
    $0if let Err(x) = x {
        println!("{}", x)
    } else {
        println!("ok")
    }
}
"#,
            r#"
fn foo(x: Result<i32, ()>) {
    match x {
        Err(x) => println!("{}", x),
        Ok(_) => println!("ok"),
    }
}
"#,
        );
    }

    // The generated `match` keeps the indentation level of a nested `if let`.
    #[test]
    fn nested_indent() {
        check_assist(
            replace_if_let_with_match,
            r#"
fn main() {
    if true {
        $0if let Ok(rel_path) = path.strip_prefix(root_path) {
            let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
            Some((*id, rel_path))
        } else {
            None
        }
    }
}
"#,
            r#"
fn main() {
    if true {
        match path.strip_prefix(root_path) {
            Ok(rel_path) => {
                let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
                Some((*id, rel_path))
            }
            _ => None,
        }
    }
}
"#,
        )
    }

    // Arm bodies that are plain expressions get wrapped into blocks.
    #[test]
    fn test_replace_match_with_if_let_unwraps_simple_expressions() {
        check_assist(
            replace_match_with_if_let,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        $0match *self {
            VariantData::Struct(..) => true,
            _ => false,
        }
    }
}           "#,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        if let VariantData::Struct(..) = *self {
            true
        } else {
            false
        }
    }
}           "#,
        )
    }

    // An arm body that already is a block is reused as the `then` block.
    #[test]
    fn test_replace_match_with_if_let_doesnt_unwrap_multiline_expressions() {
        check_assist(
            replace_match_with_if_let,
            r#"
fn foo() {
    $0match a {
        VariantData::Struct(..) => {
            bar(
                123
            )
        }
        _ => false,
    }
}           "#,
            r#"
fn foo() {
    if let VariantData::Struct(..) = a {
        bar(
            123
        )
    } else {
        false
    }
}           "#,
        )
    }

    // The assist target is the whole `match` expression.
    #[test]
    fn replace_match_with_if_let_target() {
        check_assist_target(
            replace_match_with_if_let,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        $0match *self {
            VariantData::Struct(..) => true,
            _ => false,
        }
    }
}           "#,
            r#"match *self {
            VariantData::Struct(..) => true,
            _ => false,
        }"#,
        );
    }

    // `Some(x)` / `None`: the `Some(x)` arm becomes the `if let` pattern.
    #[test]
    fn special_case_option_match_to_if_let() {
        check_assist(
            replace_match_with_if_let,
            r#"
//- minicore: option
fn foo(x: Option<i32>) {
    $0match x {
        Some(x) => println!("{}", x),
        None => println!("none"),
    }
}
"#,
            r#"
fn foo(x: Option<i32>) {
    if let Some(x) = x {
        println!("{}", x)
    } else {
        println!("none")
    }
}
"#,
        );
    }

    // `Ok(x)` / `Err(_)`: the `Ok(x)` arm becomes the `if let` pattern.
    #[test]
    fn special_case_result_match_to_if_let() {
        check_assist(
            replace_match_with_if_let,
            r#"
//- minicore: result
fn foo(x: Result<i32, ()>) {
    $0match x {
        Ok(x) => println!("{}", x),
        Err(_) => println!("none"),
    }
}
"#,
            r#"
fn foo(x: Result<i32, ()>) {
    if let Ok(x) = x {
        println!("{}", x)
    } else {
        println!("none")
    }
}
"#,
        );
    }

    // The generated `if let` keeps the indentation level of a nested `match`.
    #[test]
    fn nested_indent_match_to_if_let() {
        check_assist(
            replace_match_with_if_let,
            r#"
fn main() {
    if true {
        $0match path.strip_prefix(root_path) {
            Ok(rel_path) => {
                let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
                Some((*id, rel_path))
            }
            _ => None,
        }
    }
}
"#,
            r#"
fn main() {
    if true {
        if let Ok(rel_path) = path.strip_prefix(root_path) {
            let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
            Some((*id, rel_path))
        } else {
            None
        }
    }
}
"#,
        )
    }

    // A wildcard arm with a `()` body produces no `else` branch at all.
    #[test]
    fn replace_match_with_if_let_empty_wildcard_expr() {
        check_assist(
            replace_match_with_if_let,
            r#"
fn main() {
    $0match path.strip_prefix(root_path) {
        Ok(rel_path) => println!("{}", rel_path),
        _ => (),
    }
}
"#,
            r#"
fn main() {
    if let Ok(rel_path) = path.strip_prefix(root_path) {
        println!("{}", rel_path)
    }
}
"#,
        )
    }

    // Two non-wildcard arms without bindings: the first pattern is used.
    #[test]
    fn replace_match_with_if_let_exhaustive() {
        check_assist(
            replace_match_with_if_let,
            r#"
fn print_source(def_source: ModuleSource) {
    match def_so$0urce {
        ModuleSource::SourceFile(..) => { println!("source file"); }
        ModuleSource::Module(..) => { println!("module"); }
    }
}
"#,
            r#"
fn print_source(def_source: ModuleSource) {
    if let ModuleSource::SourceFile(..) = def_source { println!("source file"); } else { println!("module"); }
}
"#,
        )
    }

    // Regardless of arm order, the arm that binds a name (and has the non-empty body)
    // becomes the `if let` pattern.
    #[test]
    fn replace_match_with_if_let_prefer_name_bind() {
        check_assist(
            replace_match_with_if_let,
            r#"
fn foo() {
    match $0Foo(0) {
        Foo(_) => (),
        Bar(bar) => println!("bar {}", bar),
    }
}
"#,
            r#"
fn foo() {
    if let Bar(bar) = Foo(0) {
        println!("bar {}", bar)
    }
}
"#,
        );
        check_assist(
            replace_match_with_if_let,
            r#"
fn foo() {
    match $0Foo(0) {
        Bar(bar) => println!("bar {}", bar),
        Foo(_) => (),
    }
}
"#,
            r#"
fn foo() {
    if let Bar(bar) = Foo(0) {
        println!("bar {}", bar)
    }
}
"#,
        );
    }

    // Regardless of arm order, the arm with the empty `{}` body is dropped and the
    // other arm becomes the `if let`, even though both patterns bind a name.
    #[test]
    fn replace_match_with_if_let_prefer_nonempty_body() {
        check_assist(
            replace_match_with_if_let,
            r#"
fn foo() {
    match $0Ok(0) {
        Ok(value) => {},
        Err(err) => eprintln!("{}", err),
    }
}
"#,
            r#"
fn foo() {
    if let Err(err) = Ok(0) {
        eprintln!("{}", err)
    }
}
"#,
        );
        check_assist(
            replace_match_with_if_let,
            r#"
fn foo() {
    match $0Ok(0) {
        Err(err) => eprintln!("{}", err),
        Ok(value) => {},
    }
}
"#,
            r#"
fn foo() {
    if let Err(err) = Ok(0) {
        eprintln!("{}", err)
    }
}
"#,
        );
    }

    // Both arms bind a name and have non-empty bodies: the assist is not offered.
    #[test]
    fn replace_match_with_if_let_rejects_double_name_bindings() {
        check_assist_not_applicable(
            replace_match_with_if_let,
            r#"
fn foo() {
    match $0Foo(0) {
        Foo(foo) => println!("bar {}", foo),
        Bar(bar) => println!("bar {}", bar),
    }
}
"#,
        );
    }
}