]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_if_let_with_match.rs
Use NameClass::classify to check for ConstReference
[rust.git] / crates / ide_assists / src / handlers / replace_if_let_with_match.rs
1 use std::iter::{self, successors};
2
3 use either::Either;
4 use ide_db::{defs::NameClass, ty_filter::TryEnum, RootDatabase};
5 use syntax::{
6     ast::{
7         self,
8         edit::{AstNodeEdit, IndentLevel},
9         make, NameOwner,
10     },
11     AstNode,
12 };
13
14 use crate::{
15     utils::{does_pat_match_variant, unwrap_trivial_block},
16     AssistContext, AssistId, AssistKind, Assists,
17 };
18
19 // Assist: replace_if_let_with_match
20 //
21 // Replaces a `if let` expression with a `match` expression.
22 //
23 // ```
24 // enum Action { Move { distance: u32 }, Stop }
25 //
26 // fn handle(action: Action) {
27 //     $0if let Action::Move { distance } = action {
28 //         foo(distance)
29 //     } else {
30 //         bar()
31 //     }
32 // }
33 // ```
34 // ->
35 // ```
36 // enum Action { Move { distance: u32 }, Stop }
37 //
38 // fn handle(action: Action) {
39 //     match action {
40 //         Action::Move { distance } => foo(distance),
41 //         _ => bar(),
42 //     }
43 // }
44 // ```
45 pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
46     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
47     let mut else_block = None;
48     let if_exprs = successors(Some(if_expr.clone()), |expr| match expr.else_branch()? {
49         ast::ElseBranch::IfExpr(expr) => Some(expr),
50         ast::ElseBranch::Block(block) => {
51             else_block = Some(block);
52             None
53         }
54     });
55     let scrutinee_to_be_expr = if_expr.condition()?.expr()?;
56
57     let mut pat_seen = false;
58     let mut cond_bodies = Vec::new();
59     for if_expr in if_exprs {
60         let cond = if_expr.condition()?;
61         let expr = cond.expr()?;
62         let cond = match cond.pat() {
63             Some(pat) => {
64                 if scrutinee_to_be_expr.syntax().text() != expr.syntax().text() {
65                     // Only if all condition expressions are equal we can merge them into a match
66                     return None;
67                 }
68                 pat_seen = true;
69                 Either::Left(pat)
70             }
71             None => Either::Right(expr),
72         };
73         let body = if_expr.then_branch()?;
74         cond_bodies.push((cond, body));
75     }
76
77     if !pat_seen {
78         // Don't offer turning an if (chain) without patterns into a match
79         return None;
80     }
81
82     let target = if_expr.syntax().text_range();
83     acc.add(
84         AssistId("replace_if_let_with_match", AssistKind::RefactorRewrite),
85         "Replace if let with match",
86         target,
87         move |edit| {
88             let match_expr = {
89                 let else_arm = make_else_arm(ctx, else_block, &cond_bodies);
90                 let make_match_arm = |(pat, body): (_, ast::BlockExpr)| {
91                     let body = body.reset_indent().indent(IndentLevel(1));
92                     match pat {
93                         Either::Left(pat) => {
94                             make::match_arm(iter::once(pat), None, unwrap_trivial_block(body))
95                         }
96                         Either::Right(expr) => make::match_arm(
97                             iter::once(make::wildcard_pat().into()),
98                             Some(expr),
99                             unwrap_trivial_block(body),
100                         ),
101                     }
102                 };
103                 let arms = cond_bodies.into_iter().map(make_match_arm).chain(iter::once(else_arm));
104                 let match_expr = make::expr_match(scrutinee_to_be_expr, make::match_arm_list(arms));
105                 match_expr.indent(IndentLevel::from_node(if_expr.syntax()))
106             };
107
108             let has_preceding_if_expr =
109                 if_expr.syntax().parent().map_or(false, |it| ast::IfExpr::can_cast(it.kind()));
110             let expr = if has_preceding_if_expr {
111                 // make sure we replace the `else if let ...` with a block so we don't end up with `else expr`
112                 make::block_expr(None, Some(match_expr)).into()
113             } else {
114                 match_expr
115             };
116             edit.replace_ast::<ast::Expr>(if_expr.into(), expr);
117         },
118     )
119 }
120
121 fn make_else_arm(
122     ctx: &AssistContext,
123     else_block: Option<ast::BlockExpr>,
124     conditionals: &[(Either<ast::Pat, ast::Expr>, ast::BlockExpr)],
125 ) -> ast::MatchArm {
126     if let Some(else_block) = else_block {
127         let pattern = if let [(Either::Left(pat), _)] = conditionals {
128             ctx.sema
129                 .type_of_pat(&pat)
130                 .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted()))
131                 .zip(Some(pat))
132         } else {
133             None
134         };
135         let pattern = match pattern {
136             Some((it, pat)) => {
137                 if does_pat_match_variant(&pat, &it.sad_pattern()) {
138                     it.happy_pattern()
139                 } else {
140                     it.sad_pattern()
141                 }
142             }
143             None => make::wildcard_pat().into(),
144         };
145         make::match_arm(iter::once(pattern), None, unwrap_trivial_block(else_block))
146     } else {
147         make::match_arm(iter::once(make::wildcard_pat().into()), None, make::expr_unit().into())
148     }
149 }
150
151 // Assist: replace_match_with_if_let
152 //
153 // Replaces a binary `match` with a wildcard pattern and no guards with an `if let` expression.
154 //
155 // ```
156 // enum Action { Move { distance: u32 }, Stop }
157 //
158 // fn handle(action: Action) {
159 //     $0match action {
160 //         Action::Move { distance } => foo(distance),
161 //         _ => bar(),
162 //     }
163 // }
164 // ```
165 // ->
166 // ```
167 // enum Action { Move { distance: u32 }, Stop }
168 //
169 // fn handle(action: Action) {
170 //     if let Action::Move { distance } = action {
171 //         foo(distance)
172 //     } else {
173 //         bar()
174 //     }
175 // }
176 // ```
177 pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
178     let match_expr: ast::MatchExpr = ctx.find_node_at_offset()?;
179
180     let mut arms = match_expr.match_arm_list()?.arms();
181     let (first_arm, second_arm) = (arms.next()?, arms.next()?);
182     if arms.next().is_some() || first_arm.guard().is_some() || second_arm.guard().is_some() {
183         return None;
184     }
185
186     let (if_let_pat, then_expr, else_expr) = pick_pattern_and_expr_order(
187         &ctx.sema,
188         first_arm.pat()?,
189         second_arm.pat()?,
190         first_arm.expr()?,
191         second_arm.expr()?,
192     )?;
193     let scrutinee = match_expr.expr()?;
194
195     let target = match_expr.syntax().text_range();
196     acc.add(
197         AssistId("replace_match_with_if_let", AssistKind::RefactorRewrite),
198         "Replace match with if let",
199         target,
200         move |edit| {
201             let condition = make::condition(scrutinee, Some(if_let_pat));
202             let then_block = match then_expr.reset_indent() {
203                 ast::Expr::BlockExpr(block) => block,
204                 expr => make::block_expr(iter::empty(), Some(expr)),
205             };
206             let else_expr = match else_expr {
207                 ast::Expr::BlockExpr(block) if block.is_empty() => None,
208                 ast::Expr::TupleExpr(tuple) if tuple.fields().next().is_none() => None,
209                 expr => Some(expr),
210             };
211             let if_let_expr = make::expr_if(
212                 condition,
213                 then_block,
214                 else_expr
215                     .map(|expr| match expr {
216                         ast::Expr::BlockExpr(block) => block,
217                         expr => (make::block_expr(iter::empty(), Some(expr))),
218                     })
219                     .map(ast::ElseBranch::Block),
220             )
221             .indent(IndentLevel::from_node(match_expr.syntax()));
222
223             edit.replace_ast::<ast::Expr>(match_expr.into(), if_let_expr);
224         },
225     )
226 }
227
228 /// Pick the pattern for the if let condition and return the expressions for the `then` body and `else` body in that order.
229 fn pick_pattern_and_expr_order(
230     sema: &hir::Semantics<RootDatabase>,
231     pat: ast::Pat,
232     pat2: ast::Pat,
233     expr: ast::Expr,
234     expr2: ast::Expr,
235 ) -> Option<(ast::Pat, ast::Expr, ast::Expr)> {
236     let res = match (pat, pat2) {
237         (ast::Pat::WildcardPat(_), _) => return None,
238         (pat, _) if is_empty_expr(&expr2) => (pat, expr, expr2),
239         (_, pat) if is_empty_expr(&expr) => (pat, expr2, expr),
240         (pat, pat2) => match (binds_name(sema, &pat), binds_name(sema, &pat2)) {
241             (true, true) => return None,
242             (true, false) => (pat, expr, expr2),
243             (false, true) => (pat2, expr2, expr),
244             _ if is_sad_pat(sema, &pat2) => (pat, expr, expr2),
245             _ if is_sad_pat(sema, &pat) => (pat2, expr2, expr),
246             (false, false) => (pat, expr, expr2),
247         },
248     };
249     Some(res)
250 }
251
252 fn is_empty_expr(expr: &ast::Expr) -> bool {
253     match expr {
254         ast::Expr::BlockExpr(expr) => {
255             expr.statements().next().is_none() && expr.tail_expr().is_none()
256         }
257         ast::Expr::TupleExpr(expr) => expr.fields().next().is_none(),
258         _ => false,
259     }
260 }
261
262 fn binds_name(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
263     let binds_name_v = |pat| binds_name(&sema, &pat);
264     match pat {
265         ast::Pat::IdentPat(pat) => {
266             match pat.name().and_then(|name| NameClass::classify(sema, &name)) {
267                 Some(NameClass::ConstReference(_)) => false,
268                 _ => true,
269             }
270         }
271         ast::Pat::MacroPat(_) => true,
272         ast::Pat::OrPat(pat) => pat.pats().any(binds_name_v),
273         ast::Pat::SlicePat(pat) => pat.pats().any(binds_name_v),
274         ast::Pat::TuplePat(it) => it.fields().any(binds_name_v),
275         ast::Pat::TupleStructPat(it) => it.fields().any(binds_name_v),
276         ast::Pat::RecordPat(it) => it
277             .record_pat_field_list()
278             .map_or(false, |rpfl| rpfl.fields().flat_map(|rpf| rpf.pat()).any(binds_name_v)),
279         ast::Pat::RefPat(pat) => pat.pat().map_or(false, binds_name_v),
280         ast::Pat::BoxPat(pat) => pat.pat().map_or(false, binds_name_v),
281         ast::Pat::ParenPat(pat) => pat.pat().map_or(false, binds_name_v),
282         _ => false,
283     }
284 }
285
286 fn is_sad_pat(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
287     sema.type_of_pat(pat)
288         .and_then(|ty| TryEnum::from_ty(sema, &ty.adjusted()))
289         .map_or(false, |it| does_pat_match_variant(pat, &it.sad_pattern()))
290 }
291
292 #[cfg(test)]
293 mod tests {
294     use super::*;
295
296     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
297
298     #[test]
299     fn test_if_let_with_match_unapplicable_for_simple_ifs() {
300         check_assist_not_applicable(
301             replace_if_let_with_match,
302             r#"
303 fn main() {
304     if $0true {} else if false {} else {}
305 }
306 "#,
307         )
308     }
309
310     #[test]
311     fn test_if_let_with_match_no_else() {
312         check_assist(
313             replace_if_let_with_match,
314             r#"
315 impl VariantData {
316     pub fn foo(&self) {
317         if $0let VariantData::Struct(..) = *self {
318             self.foo();
319         }
320     }
321 }
322 "#,
323             r#"
324 impl VariantData {
325     pub fn foo(&self) {
326         match *self {
327             VariantData::Struct(..) => {
328                 self.foo();
329             }
330             _ => (),
331         }
332     }
333 }
334 "#,
335         )
336     }
337
338     #[test]
339     fn test_if_let_with_match_basic() {
340         check_assist(
341             replace_if_let_with_match,
342             r#"
343 impl VariantData {
344     pub fn is_struct(&self) -> bool {
345         if $0let VariantData::Struct(..) = *self {
346             true
347         } else if let VariantData::Tuple(..) = *self {
348             false
349         } else if cond() {
350             true
351         } else {
352             bar(
353                 123
354             )
355         }
356     }
357 }
358 "#,
359             r#"
360 impl VariantData {
361     pub fn is_struct(&self) -> bool {
362         match *self {
363             VariantData::Struct(..) => true,
364             VariantData::Tuple(..) => false,
365             _ if cond() => true,
366             _ => {
367                     bar(
368                         123
369                     )
370                 }
371         }
372     }
373 }
374 "#,
375         )
376     }
377
378     #[test]
379     fn test_if_let_with_match_on_tail_if_let() {
380         check_assist(
381             replace_if_let_with_match,
382             r#"
383 impl VariantData {
384     pub fn is_struct(&self) -> bool {
385         if let VariantData::Struct(..) = *self {
386             true
387         } else if let$0 VariantData::Tuple(..) = *self {
388             false
389         } else {
390             false
391         }
392     }
393 }
394 "#,
395             r#"
396 impl VariantData {
397     pub fn is_struct(&self) -> bool {
398         if let VariantData::Struct(..) = *self {
399             true
400         } else {
401     match *self {
402             VariantData::Tuple(..) => false,
403             _ => false,
404         }
405 }
406     }
407 }
408 "#,
409         )
410     }
411
412     #[test]
413     fn special_case_option() {
414         check_assist(
415             replace_if_let_with_match,
416             r#"
417 //- minicore: option
418 fn foo(x: Option<i32>) {
419     $0if let Some(x) = x {
420         println!("{}", x)
421     } else {
422         println!("none")
423     }
424 }
425 "#,
426             r#"
427 fn foo(x: Option<i32>) {
428     match x {
429         Some(x) => println!("{}", x),
430         None => println!("none"),
431     }
432 }
433 "#,
434         );
435     }
436
437     #[test]
438     fn special_case_inverted_option() {
439         check_assist(
440             replace_if_let_with_match,
441             r#"
442 //- minicore: option
443 fn foo(x: Option<i32>) {
444     $0if let None = x {
445         println!("none")
446     } else {
447         println!("some")
448     }
449 }
450 "#,
451             r#"
452 fn foo(x: Option<i32>) {
453     match x {
454         None => println!("none"),
455         Some(_) => println!("some"),
456     }
457 }
458 "#,
459         );
460     }
461
462     #[test]
463     fn special_case_result() {
464         check_assist(
465             replace_if_let_with_match,
466             r#"
467 //- minicore: result
468 fn foo(x: Result<i32, ()>) {
469     $0if let Ok(x) = x {
470         println!("{}", x)
471     } else {
472         println!("none")
473     }
474 }
475 "#,
476             r#"
477 fn foo(x: Result<i32, ()>) {
478     match x {
479         Ok(x) => println!("{}", x),
480         Err(_) => println!("none"),
481     }
482 }
483 "#,
484         );
485     }
486
487     #[test]
488     fn special_case_inverted_result() {
489         check_assist(
490             replace_if_let_with_match,
491             r#"
492 //- minicore: result
493 fn foo(x: Result<i32, ()>) {
494     $0if let Err(x) = x {
495         println!("{}", x)
496     } else {
497         println!("ok")
498     }
499 }
500 "#,
501             r#"
502 fn foo(x: Result<i32, ()>) {
503     match x {
504         Err(x) => println!("{}", x),
505         Ok(_) => println!("ok"),
506     }
507 }
508 "#,
509         );
510     }
511
512     #[test]
513     fn nested_indent() {
514         check_assist(
515             replace_if_let_with_match,
516             r#"
517 fn main() {
518     if true {
519         $0if let Ok(rel_path) = path.strip_prefix(root_path) {
520             let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
521             Some((*id, rel_path))
522         } else {
523             None
524         }
525     }
526 }
527 "#,
528             r#"
529 fn main() {
530     if true {
531         match path.strip_prefix(root_path) {
532             Ok(rel_path) => {
533                 let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
534                 Some((*id, rel_path))
535             }
536             _ => None,
537         }
538     }
539 }
540 "#,
541         )
542     }
543
544     #[test]
545     fn test_replace_match_with_if_let_unwraps_simple_expressions() {
546         check_assist(
547             replace_match_with_if_let,
548             r#"
549 impl VariantData {
550     pub fn is_struct(&self) -> bool {
551         $0match *self {
552             VariantData::Struct(..) => true,
553             _ => false,
554         }
555     }
556 }           "#,
557             r#"
558 impl VariantData {
559     pub fn is_struct(&self) -> bool {
560         if let VariantData::Struct(..) = *self {
561             true
562         } else {
563             false
564         }
565     }
566 }           "#,
567         )
568     }
569
570     #[test]
571     fn test_replace_match_with_if_let_doesnt_unwrap_multiline_expressions() {
572         check_assist(
573             replace_match_with_if_let,
574             r#"
575 fn foo() {
576     $0match a {
577         VariantData::Struct(..) => {
578             bar(
579                 123
580             )
581         }
582         _ => false,
583     }
584 }           "#,
585             r#"
586 fn foo() {
587     if let VariantData::Struct(..) = a {
588         bar(
589             123
590         )
591     } else {
592         false
593     }
594 }           "#,
595         )
596     }
597
598     #[test]
599     fn replace_match_with_if_let_target() {
600         check_assist_target(
601             replace_match_with_if_let,
602             r#"
603 impl VariantData {
604     pub fn is_struct(&self) -> bool {
605         $0match *self {
606             VariantData::Struct(..) => true,
607             _ => false,
608         }
609     }
610 }           "#,
611             r#"match *self {
612             VariantData::Struct(..) => true,
613             _ => false,
614         }"#,
615         );
616     }
617
618     #[test]
619     fn special_case_option_match_to_if_let() {
620         check_assist(
621             replace_match_with_if_let,
622             r#"
623 //- minicore: option
624 fn foo(x: Option<i32>) {
625     $0match x {
626         Some(x) => println!("{}", x),
627         None => println!("none"),
628     }
629 }
630 "#,
631             r#"
632 fn foo(x: Option<i32>) {
633     if let Some(x) = x {
634         println!("{}", x)
635     } else {
636         println!("none")
637     }
638 }
639 "#,
640         );
641     }
642
643     #[test]
644     fn special_case_result_match_to_if_let() {
645         check_assist(
646             replace_match_with_if_let,
647             r#"
648 //- minicore: result
649 fn foo(x: Result<i32, ()>) {
650     $0match x {
651         Ok(x) => println!("{}", x),
652         Err(_) => println!("none"),
653     }
654 }
655 "#,
656             r#"
657 fn foo(x: Result<i32, ()>) {
658     if let Ok(x) = x {
659         println!("{}", x)
660     } else {
661         println!("none")
662     }
663 }
664 "#,
665         );
666     }
667
668     #[test]
669     fn nested_indent_match_to_if_let() {
670         check_assist(
671             replace_match_with_if_let,
672             r#"
673 fn main() {
674     if true {
675         $0match path.strip_prefix(root_path) {
676             Ok(rel_path) => {
677                 let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
678                 Some((*id, rel_path))
679             }
680             _ => None,
681         }
682     }
683 }
684 "#,
685             r#"
686 fn main() {
687     if true {
688         if let Ok(rel_path) = path.strip_prefix(root_path) {
689             let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
690             Some((*id, rel_path))
691         } else {
692             None
693         }
694     }
695 }
696 "#,
697         )
698     }
699
700     #[test]
701     fn replace_match_with_if_let_empty_wildcard_expr() {
702         check_assist(
703             replace_match_with_if_let,
704             r#"
705 fn main() {
706     $0match path.strip_prefix(root_path) {
707         Ok(rel_path) => println!("{}", rel_path),
708         _ => (),
709     }
710 }
711 "#,
712             r#"
713 fn main() {
714     if let Ok(rel_path) = path.strip_prefix(root_path) {
715         println!("{}", rel_path)
716     }
717 }
718 "#,
719         )
720     }
721
722     #[test]
723     fn replace_match_with_if_let_number_body() {
724         check_assist(
725             replace_match_with_if_let,
726             r#"
727 fn main() {
728     $0match Ok(()) {
729         Ok(()) => {},
730         Err(_) => 0,
731     }
732 }
733 "#,
734             r#"
735 fn main() {
736     if let Err(_) = Ok(()) {
737         0
738     }
739 }
740 "#,
741         )
742     }
743
744     #[test]
745     fn replace_match_with_if_let_exhaustive() {
746         check_assist(
747             replace_match_with_if_let,
748             r#"
749 fn print_source(def_source: ModuleSource) {
750     match def_so$0urce {
751         ModuleSource::SourceFile(..) => { println!("source file"); }
752         ModuleSource::Module(..) => { println!("module"); }
753     }
754 }
755 "#,
756             r#"
757 fn print_source(def_source: ModuleSource) {
758     if let ModuleSource::SourceFile(..) = def_source { println!("source file"); } else { println!("module"); }
759 }
760 "#,
761         )
762     }
763
764     #[test]
765     fn replace_match_with_if_let_prefer_name_bind() {
766         check_assist(
767             replace_match_with_if_let,
768             r#"
769 fn foo() {
770     match $0Foo(0) {
771         Foo(_) => (),
772         Bar(bar) => println!("bar {}", bar),
773     }
774 }
775 "#,
776             r#"
777 fn foo() {
778     if let Bar(bar) = Foo(0) {
779         println!("bar {}", bar)
780     }
781 }
782 "#,
783         );
784         check_assist(
785             replace_match_with_if_let,
786             r#"
787 fn foo() {
788     match $0Foo(0) {
789         Bar(bar) => println!("bar {}", bar),
790         Foo(_) => (),
791     }
792 }
793 "#,
794             r#"
795 fn foo() {
796     if let Bar(bar) = Foo(0) {
797         println!("bar {}", bar)
798     }
799 }
800 "#,
801         );
802     }
803
804     #[test]
805     fn replace_match_with_if_let_prefer_nonempty_body() {
806         check_assist(
807             replace_match_with_if_let,
808             r#"
809 fn foo() {
810     match $0Ok(0) {
811         Ok(value) => {},
812         Err(err) => eprintln!("{}", err),
813     }
814 }
815 "#,
816             r#"
817 fn foo() {
818     if let Err(err) = Ok(0) {
819         eprintln!("{}", err)
820     }
821 }
822 "#,
823         );
824         check_assist(
825             replace_match_with_if_let,
826             r#"
827 fn foo() {
828     match $0Ok(0) {
829         Err(err) => eprintln!("{}", err),
830         Ok(value) => {},
831     }
832 }
833 "#,
834             r#"
835 fn foo() {
836     if let Err(err) = Ok(0) {
837         eprintln!("{}", err)
838     }
839 }
840 "#,
841         );
842     }
843
844     #[test]
845     fn replace_match_with_if_let_rejects_double_name_bindings() {
846         check_assist_not_applicable(
847             replace_match_with_if_let,
848             r#"
849 fn foo() {
850     match $0Foo(0) {
851         Foo(foo) => println!("bar {}", foo),
852         Bar(bar) => println!("bar {}", bar),
853     }
854 }
855 "#,
856         );
857     }
858 }