]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_if_let_with_match.rs
minor: reduce duplication
[rust.git] / crates / ide_assists / src / handlers / replace_if_let_with_match.rs
1 use std::iter::{self, successors};
2
3 use either::Either;
4 use ide_db::{defs::NameClass, ty_filter::TryEnum, RootDatabase};
5 use syntax::{
6     ast::{
7         self,
8         edit::{AstNodeEdit, IndentLevel},
9         make, HasName,
10     },
11     AstNode, TextRange,
12 };
13
14 use crate::{
15     utils::{does_pat_match_variant, unwrap_trivial_block},
16     AssistContext, AssistId, AssistKind, Assists,
17 };
18
19 // Assist: replace_if_let_with_match
20 //
21 // Replaces a `if let` expression with a `match` expression.
22 //
23 // ```
24 // enum Action { Move { distance: u32 }, Stop }
25 //
26 // fn handle(action: Action) {
27 //     $0if let Action::Move { distance } = action {
28 //         foo(distance)
29 //     } else {
30 //         bar()
31 //     }
32 // }
33 // ```
34 // ->
35 // ```
36 // enum Action { Move { distance: u32 }, Stop }
37 //
38 // fn handle(action: Action) {
39 //     match action {
40 //         Action::Move { distance } => foo(distance),
41 //         _ => bar(),
42 //     }
43 // }
44 // ```
45 pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
46     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
47     let available_range = TextRange::new(
48         if_expr.syntax().text_range().start(),
49         if_expr.then_branch()?.syntax().text_range().start(),
50     );
51     let cursor_in_range = available_range.contains_range(ctx.frange.range);
52     if !cursor_in_range {
53         return None;
54     }
55     let mut else_block = None;
56     let if_exprs = successors(Some(if_expr.clone()), |expr| match expr.else_branch()? {
57         ast::ElseBranch::IfExpr(expr) => Some(expr),
58         ast::ElseBranch::Block(block) => {
59             else_block = Some(block);
60             None
61         }
62     });
63     let scrutinee_to_be_expr = if_expr.condition()?.expr()?;
64
65     let mut pat_seen = false;
66     let mut cond_bodies = Vec::new();
67     for if_expr in if_exprs {
68         let cond = if_expr.condition()?;
69         let expr = cond.expr()?;
70         let cond = match cond.pat() {
71             Some(pat) => {
72                 if scrutinee_to_be_expr.syntax().text() != expr.syntax().text() {
73                     // Only if all condition expressions are equal we can merge them into a match
74                     return None;
75                 }
76                 pat_seen = true;
77                 Either::Left(pat)
78             }
79             None => Either::Right(expr),
80         };
81         let body = if_expr.then_branch()?;
82         cond_bodies.push((cond, body));
83     }
84
85     if !pat_seen {
86         // Don't offer turning an if (chain) without patterns into a match
87         return None;
88     }
89
90     acc.add(
91         AssistId("replace_if_let_with_match", AssistKind::RefactorRewrite),
92         "Replace if let with match",
93         available_range,
94         move |edit| {
95             let match_expr = {
96                 let else_arm = make_else_arm(ctx, else_block, &cond_bodies);
97                 let make_match_arm = |(pat, body): (_, ast::BlockExpr)| {
98                     let body = body.reset_indent().indent(IndentLevel(1));
99                     match pat {
100                         Either::Left(pat) => {
101                             make::match_arm(iter::once(pat), None, unwrap_trivial_block(body))
102                         }
103                         Either::Right(expr) => make::match_arm(
104                             iter::once(make::wildcard_pat().into()),
105                             Some(expr),
106                             unwrap_trivial_block(body),
107                         ),
108                     }
109                 };
110                 let arms = cond_bodies.into_iter().map(make_match_arm).chain(iter::once(else_arm));
111                 let match_expr = make::expr_match(scrutinee_to_be_expr, make::match_arm_list(arms));
112                 match_expr.indent(IndentLevel::from_node(if_expr.syntax()))
113             };
114
115             let has_preceding_if_expr =
116                 if_expr.syntax().parent().map_or(false, |it| ast::IfExpr::can_cast(it.kind()));
117             let expr = if has_preceding_if_expr {
118                 // make sure we replace the `else if let ...` with a block so we don't end up with `else expr`
119                 make::block_expr(None, Some(match_expr)).into()
120             } else {
121                 match_expr
122             };
123             edit.replace_ast::<ast::Expr>(if_expr.into(), expr);
124         },
125     )
126 }
127
128 fn make_else_arm(
129     ctx: &AssistContext,
130     else_block: Option<ast::BlockExpr>,
131     conditionals: &[(Either<ast::Pat, ast::Expr>, ast::BlockExpr)],
132 ) -> ast::MatchArm {
133     if let Some(else_block) = else_block {
134         let pattern = if let [(Either::Left(pat), _)] = conditionals {
135             ctx.sema
136                 .type_of_pat(pat)
137                 .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted()))
138                 .zip(Some(pat))
139         } else {
140             None
141         };
142         let pattern = match pattern {
143             Some((it, pat)) => {
144                 if does_pat_match_variant(pat, &it.sad_pattern()) {
145                     it.happy_pattern()
146                 } else {
147                     it.sad_pattern()
148                 }
149             }
150             None => make::wildcard_pat().into(),
151         };
152         make::match_arm(iter::once(pattern), None, unwrap_trivial_block(else_block))
153     } else {
154         make::match_arm(iter::once(make::wildcard_pat().into()), None, make::expr_unit())
155     }
156 }
157
158 // Assist: replace_match_with_if_let
159 //
160 // Replaces a binary `match` with a wildcard pattern and no guards with an `if let` expression.
161 //
162 // ```
163 // enum Action { Move { distance: u32 }, Stop }
164 //
165 // fn handle(action: Action) {
166 //     $0match action {
167 //         Action::Move { distance } => foo(distance),
168 //         _ => bar(),
169 //     }
170 // }
171 // ```
172 // ->
173 // ```
174 // enum Action { Move { distance: u32 }, Stop }
175 //
176 // fn handle(action: Action) {
177 //     if let Action::Move { distance } = action {
178 //         foo(distance)
179 //     } else {
180 //         bar()
181 //     }
182 // }
183 // ```
184 pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
185     let match_expr: ast::MatchExpr = ctx.find_node_at_offset()?;
186
187     let mut arms = match_expr.match_arm_list()?.arms();
188     let (first_arm, second_arm) = (arms.next()?, arms.next()?);
189     if arms.next().is_some() || first_arm.guard().is_some() || second_arm.guard().is_some() {
190         return None;
191     }
192
193     let (if_let_pat, then_expr, else_expr) = pick_pattern_and_expr_order(
194         &ctx.sema,
195         first_arm.pat()?,
196         second_arm.pat()?,
197         first_arm.expr()?,
198         second_arm.expr()?,
199     )?;
200     let scrutinee = match_expr.expr()?;
201
202     let target = match_expr.syntax().text_range();
203     acc.add(
204         AssistId("replace_match_with_if_let", AssistKind::RefactorRewrite),
205         "Replace match with if let",
206         target,
207         move |edit| {
208             let condition = make::condition(scrutinee, Some(if_let_pat));
209             let then_block = match then_expr.reset_indent() {
210                 ast::Expr::BlockExpr(block) => block,
211                 expr => make::block_expr(iter::empty(), Some(expr)),
212             };
213             let else_expr = if is_empty_expr(&else_expr) { None } else { Some(else_expr) };
214             let if_let_expr = make::expr_if(
215                 condition,
216                 then_block,
217                 else_expr
218                     .map(|expr| match expr {
219                         ast::Expr::BlockExpr(block) => block,
220                         expr => (make::block_expr(iter::empty(), Some(expr))),
221                     })
222                     .map(ast::ElseBranch::Block),
223             )
224             .indent(IndentLevel::from_node(match_expr.syntax()));
225
226             edit.replace_ast::<ast::Expr>(match_expr.into(), if_let_expr);
227         },
228     )
229 }
230
231 /// Pick the pattern for the if let condition and return the expressions for the `then` body and `else` body in that order.
232 fn pick_pattern_and_expr_order(
233     sema: &hir::Semantics<RootDatabase>,
234     pat: ast::Pat,
235     pat2: ast::Pat,
236     expr: ast::Expr,
237     expr2: ast::Expr,
238 ) -> Option<(ast::Pat, ast::Expr, ast::Expr)> {
239     let res = match (pat, pat2) {
240         (ast::Pat::WildcardPat(_), _) => return None,
241         (pat, _) if is_empty_expr(&expr2) => (pat, expr, expr2),
242         (_, pat) if is_empty_expr(&expr) => (pat, expr2, expr),
243         (pat, pat2) => match (binds_name(sema, &pat), binds_name(sema, &pat2)) {
244             (true, true) => return None,
245             (true, false) => (pat, expr, expr2),
246             (false, true) => (pat2, expr2, expr),
247             _ if is_sad_pat(sema, &pat) => (pat2, expr2, expr),
248             (false, false) => (pat, expr, expr2),
249         },
250     };
251     Some(res)
252 }
253
254 fn is_empty_expr(expr: &ast::Expr) -> bool {
255     match expr {
256         ast::Expr::BlockExpr(expr) => expr.is_empty(),
257         ast::Expr::TupleExpr(expr) => expr.fields().next().is_none(),
258         _ => false,
259     }
260 }
261
262 fn binds_name(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
263     let binds_name_v = |pat| binds_name(sema, &pat);
264     match pat {
265         ast::Pat::IdentPat(pat) => !matches!(
266             pat.name().and_then(|name| NameClass::classify(sema, &name)),
267             Some(NameClass::ConstReference(_))
268         ),
269         ast::Pat::MacroPat(_) => true,
270         ast::Pat::OrPat(pat) => pat.pats().any(binds_name_v),
271         ast::Pat::SlicePat(pat) => pat.pats().any(binds_name_v),
272         ast::Pat::TuplePat(it) => it.fields().any(binds_name_v),
273         ast::Pat::TupleStructPat(it) => it.fields().any(binds_name_v),
274         ast::Pat::RecordPat(it) => it
275             .record_pat_field_list()
276             .map_or(false, |rpfl| rpfl.fields().flat_map(|rpf| rpf.pat()).any(binds_name_v)),
277         ast::Pat::RefPat(pat) => pat.pat().map_or(false, binds_name_v),
278         ast::Pat::BoxPat(pat) => pat.pat().map_or(false, binds_name_v),
279         ast::Pat::ParenPat(pat) => pat.pat().map_or(false, binds_name_v),
280         _ => false,
281     }
282 }
283
284 fn is_sad_pat(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
285     sema.type_of_pat(pat)
286         .and_then(|ty| TryEnum::from_ty(sema, &ty.adjusted()))
287         .map_or(false, |it| does_pat_match_variant(pat, &it.sad_pattern()))
288 }
289
290 #[cfg(test)]
291 mod tests {
292     use super::*;
293
294     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
295
296     #[test]
297     fn test_if_let_with_match_unapplicable_for_simple_ifs() {
298         check_assist_not_applicable(
299             replace_if_let_with_match,
300             r#"
301 fn main() {
302     if $0true {} else if false {} else {}
303 }
304 "#,
305         )
306     }
307
308     #[test]
309     fn test_if_let_with_match_no_else() {
310         check_assist(
311             replace_if_let_with_match,
312             r#"
313 impl VariantData {
314     pub fn foo(&self) {
315         if $0let VariantData::Struct(..) = *self {
316             self.foo();
317         }
318     }
319 }
320 "#,
321             r#"
322 impl VariantData {
323     pub fn foo(&self) {
324         match *self {
325             VariantData::Struct(..) => {
326                 self.foo();
327             }
328             _ => (),
329         }
330     }
331 }
332 "#,
333         )
334     }
335
336     #[test]
337     fn test_if_let_with_match_available_range_left() {
338         check_assist_not_applicable(
339             replace_if_let_with_match,
340             r#"
341 impl VariantData {
342     pub fn foo(&self) {
343         $0 if let VariantData::Struct(..) = *self {
344             self.foo();
345         }
346     }
347 }
348 "#,
349         )
350     }
351
352     #[test]
353     fn test_if_let_with_match_available_range_right() {
354         check_assist_not_applicable(
355             replace_if_let_with_match,
356             r#"
357 impl VariantData {
358     pub fn foo(&self) {
359         if let VariantData::Struct(..) = *self {$0
360             self.foo();
361         }
362     }
363 }
364 "#,
365         )
366     }
367
368     #[test]
369     fn test_if_let_with_match_basic() {
370         check_assist(
371             replace_if_let_with_match,
372             r#"
373 impl VariantData {
374     pub fn is_struct(&self) -> bool {
375         if $0let VariantData::Struct(..) = *self {
376             true
377         } else if let VariantData::Tuple(..) = *self {
378             false
379         } else if cond() {
380             true
381         } else {
382             bar(
383                 123
384             )
385         }
386     }
387 }
388 "#,
389             r#"
390 impl VariantData {
391     pub fn is_struct(&self) -> bool {
392         match *self {
393             VariantData::Struct(..) => true,
394             VariantData::Tuple(..) => false,
395             _ if cond() => true,
396             _ => {
397                     bar(
398                         123
399                     )
400                 }
401         }
402     }
403 }
404 "#,
405         )
406     }
407
408     #[test]
409     fn test_if_let_with_match_on_tail_if_let() {
410         check_assist(
411             replace_if_let_with_match,
412             r#"
413 impl VariantData {
414     pub fn is_struct(&self) -> bool {
415         if let VariantData::Struct(..) = *self {
416             true
417         } else if let$0 VariantData::Tuple(..) = *self {
418             false
419         } else {
420             false
421         }
422     }
423 }
424 "#,
425             r#"
426 impl VariantData {
427     pub fn is_struct(&self) -> bool {
428         if let VariantData::Struct(..) = *self {
429             true
430         } else {
431     match *self {
432             VariantData::Tuple(..) => false,
433             _ => false,
434         }
435 }
436     }
437 }
438 "#,
439         )
440     }
441
442     #[test]
443     fn special_case_option() {
444         check_assist(
445             replace_if_let_with_match,
446             r#"
447 //- minicore: option
448 fn foo(x: Option<i32>) {
449     $0if let Some(x) = x {
450         println!("{}", x)
451     } else {
452         println!("none")
453     }
454 }
455 "#,
456             r#"
457 fn foo(x: Option<i32>) {
458     match x {
459         Some(x) => println!("{}", x),
460         None => println!("none"),
461     }
462 }
463 "#,
464         );
465     }
466
467     #[test]
468     fn special_case_inverted_option() {
469         check_assist(
470             replace_if_let_with_match,
471             r#"
472 //- minicore: option
473 fn foo(x: Option<i32>) {
474     $0if let None = x {
475         println!("none")
476     } else {
477         println!("some")
478     }
479 }
480 "#,
481             r#"
482 fn foo(x: Option<i32>) {
483     match x {
484         None => println!("none"),
485         Some(_) => println!("some"),
486     }
487 }
488 "#,
489         );
490     }
491
492     #[test]
493     fn special_case_result() {
494         check_assist(
495             replace_if_let_with_match,
496             r#"
497 //- minicore: result
498 fn foo(x: Result<i32, ()>) {
499     $0if let Ok(x) = x {
500         println!("{}", x)
501     } else {
502         println!("none")
503     }
504 }
505 "#,
506             r#"
507 fn foo(x: Result<i32, ()>) {
508     match x {
509         Ok(x) => println!("{}", x),
510         Err(_) => println!("none"),
511     }
512 }
513 "#,
514         );
515     }
516
517     #[test]
518     fn special_case_inverted_result() {
519         check_assist(
520             replace_if_let_with_match,
521             r#"
522 //- minicore: result
523 fn foo(x: Result<i32, ()>) {
524     $0if let Err(x) = x {
525         println!("{}", x)
526     } else {
527         println!("ok")
528     }
529 }
530 "#,
531             r#"
532 fn foo(x: Result<i32, ()>) {
533     match x {
534         Err(x) => println!("{}", x),
535         Ok(_) => println!("ok"),
536     }
537 }
538 "#,
539         );
540     }
541
542     #[test]
543     fn nested_indent() {
544         check_assist(
545             replace_if_let_with_match,
546             r#"
547 fn main() {
548     if true {
549         $0if let Ok(rel_path) = path.strip_prefix(root_path) {
550             let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
551             Some((*id, rel_path))
552         } else {
553             None
554         }
555     }
556 }
557 "#,
558             r#"
559 fn main() {
560     if true {
561         match path.strip_prefix(root_path) {
562             Ok(rel_path) => {
563                 let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
564                 Some((*id, rel_path))
565             }
566             _ => None,
567         }
568     }
569 }
570 "#,
571         )
572     }
573
574     #[test]
575     fn test_replace_match_with_if_let_unwraps_simple_expressions() {
576         check_assist(
577             replace_match_with_if_let,
578             r#"
579 impl VariantData {
580     pub fn is_struct(&self) -> bool {
581         $0match *self {
582             VariantData::Struct(..) => true,
583             _ => false,
584         }
585     }
586 }           "#,
587             r#"
588 impl VariantData {
589     pub fn is_struct(&self) -> bool {
590         if let VariantData::Struct(..) = *self {
591             true
592         } else {
593             false
594         }
595     }
596 }           "#,
597         )
598     }
599
600     #[test]
601     fn test_replace_match_with_if_let_doesnt_unwrap_multiline_expressions() {
602         check_assist(
603             replace_match_with_if_let,
604             r#"
605 fn foo() {
606     $0match a {
607         VariantData::Struct(..) => {
608             bar(
609                 123
610             )
611         }
612         _ => false,
613     }
614 }           "#,
615             r#"
616 fn foo() {
617     if let VariantData::Struct(..) = a {
618         bar(
619             123
620         )
621     } else {
622         false
623     }
624 }           "#,
625         )
626     }
627
628     #[test]
629     fn replace_match_with_if_let_target() {
630         check_assist_target(
631             replace_match_with_if_let,
632             r#"
633 impl VariantData {
634     pub fn is_struct(&self) -> bool {
635         $0match *self {
636             VariantData::Struct(..) => true,
637             _ => false,
638         }
639     }
640 }           "#,
641             r#"match *self {
642             VariantData::Struct(..) => true,
643             _ => false,
644         }"#,
645         );
646     }
647
648     #[test]
649     fn special_case_option_match_to_if_let() {
650         check_assist(
651             replace_match_with_if_let,
652             r#"
653 //- minicore: option
654 fn foo(x: Option<i32>) {
655     $0match x {
656         Some(x) => println!("{}", x),
657         None => println!("none"),
658     }
659 }
660 "#,
661             r#"
662 fn foo(x: Option<i32>) {
663     if let Some(x) = x {
664         println!("{}", x)
665     } else {
666         println!("none")
667     }
668 }
669 "#,
670         );
671     }
672
673     #[test]
674     fn special_case_result_match_to_if_let() {
675         check_assist(
676             replace_match_with_if_let,
677             r#"
678 //- minicore: result
679 fn foo(x: Result<i32, ()>) {
680     $0match x {
681         Ok(x) => println!("{}", x),
682         Err(_) => println!("none"),
683     }
684 }
685 "#,
686             r#"
687 fn foo(x: Result<i32, ()>) {
688     if let Ok(x) = x {
689         println!("{}", x)
690     } else {
691         println!("none")
692     }
693 }
694 "#,
695         );
696     }
697
698     #[test]
699     fn nested_indent_match_to_if_let() {
700         check_assist(
701             replace_match_with_if_let,
702             r#"
703 fn main() {
704     if true {
705         $0match path.strip_prefix(root_path) {
706             Ok(rel_path) => {
707                 let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
708                 Some((*id, rel_path))
709             }
710             _ => None,
711         }
712     }
713 }
714 "#,
715             r#"
716 fn main() {
717     if true {
718         if let Ok(rel_path) = path.strip_prefix(root_path) {
719             let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
720             Some((*id, rel_path))
721         } else {
722             None
723         }
724     }
725 }
726 "#,
727         )
728     }
729
730     #[test]
731     fn replace_match_with_if_let_empty_wildcard_expr() {
732         check_assist(
733             replace_match_with_if_let,
734             r#"
735 fn main() {
736     $0match path.strip_prefix(root_path) {
737         Ok(rel_path) => println!("{}", rel_path),
738         _ => (),
739     }
740 }
741 "#,
742             r#"
743 fn main() {
744     if let Ok(rel_path) = path.strip_prefix(root_path) {
745         println!("{}", rel_path)
746     }
747 }
748 "#,
749         )
750     }
751
752     #[test]
753     fn replace_match_with_if_let_number_body() {
754         check_assist(
755             replace_match_with_if_let,
756             r#"
757 fn main() {
758     $0match Ok(()) {
759         Ok(()) => {},
760         Err(_) => 0,
761     }
762 }
763 "#,
764             r#"
765 fn main() {
766     if let Err(_) = Ok(()) {
767         0
768     }
769 }
770 "#,
771         )
772     }
773
774     #[test]
775     fn replace_match_with_if_let_exhaustive() {
776         check_assist(
777             replace_match_with_if_let,
778             r#"
779 fn print_source(def_source: ModuleSource) {
780     match def_so$0urce {
781         ModuleSource::SourceFile(..) => { println!("source file"); }
782         ModuleSource::Module(..) => { println!("module"); }
783     }
784 }
785 "#,
786             r#"
787 fn print_source(def_source: ModuleSource) {
788     if let ModuleSource::SourceFile(..) = def_source { println!("source file"); } else { println!("module"); }
789 }
790 "#,
791         )
792     }
793
794     #[test]
795     fn replace_match_with_if_let_prefer_name_bind() {
796         check_assist(
797             replace_match_with_if_let,
798             r#"
799 fn foo() {
800     match $0Foo(0) {
801         Foo(_) => (),
802         Bar(bar) => println!("bar {}", bar),
803     }
804 }
805 "#,
806             r#"
807 fn foo() {
808     if let Bar(bar) = Foo(0) {
809         println!("bar {}", bar)
810     }
811 }
812 "#,
813         );
814         check_assist(
815             replace_match_with_if_let,
816             r#"
817 fn foo() {
818     match $0Foo(0) {
819         Bar(bar) => println!("bar {}", bar),
820         Foo(_) => (),
821     }
822 }
823 "#,
824             r#"
825 fn foo() {
826     if let Bar(bar) = Foo(0) {
827         println!("bar {}", bar)
828     }
829 }
830 "#,
831         );
832     }
833
834     #[test]
835     fn replace_match_with_if_let_prefer_nonempty_body() {
836         check_assist(
837             replace_match_with_if_let,
838             r#"
839 fn foo() {
840     match $0Ok(0) {
841         Ok(value) => {},
842         Err(err) => eprintln!("{}", err),
843     }
844 }
845 "#,
846             r#"
847 fn foo() {
848     if let Err(err) = Ok(0) {
849         eprintln!("{}", err)
850     }
851 }
852 "#,
853         );
854         check_assist(
855             replace_match_with_if_let,
856             r#"
857 fn foo() {
858     match $0Ok(0) {
859         Err(err) => eprintln!("{}", err),
860         Ok(value) => {},
861     }
862 }
863 "#,
864             r#"
865 fn foo() {
866     if let Err(err) = Ok(0) {
867         eprintln!("{}", err)
868     }
869 }
870 "#,
871         );
872     }
873
874     #[test]
875     fn replace_match_with_if_let_rejects_double_name_bindings() {
876         check_assist_not_applicable(
877             replace_match_with_if_let,
878             r#"
879 fn foo() {
880     match $0Foo(0) {
881         Foo(foo) => println!("bar {}", foo),
882         Bar(bar) => println!("bar {}", bar),
883     }
884 }
885 "#,
886         );
887     }
888 }