]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_if_let_with_match.rs
Rename `*Owner` traits to `Has*`
[rust.git] / crates / ide_assists / src / handlers / replace_if_let_with_match.rs
1 use std::iter::{self, successors};
2
3 use either::Either;
4 use ide_db::{defs::NameClass, ty_filter::TryEnum, RootDatabase};
5 use syntax::{
6     ast::{
7         self,
8         edit::{AstNodeEdit, IndentLevel},
9         make, HasName,
10     },
11     AstNode, TextRange,
12 };
13
14 use crate::{
15     utils::{does_pat_match_variant, unwrap_trivial_block},
16     AssistContext, AssistId, AssistKind, Assists,
17 };
18
19 // Assist: replace_if_let_with_match
20 //
21 // Replaces a `if let` expression with a `match` expression.
22 //
23 // ```
24 // enum Action { Move { distance: u32 }, Stop }
25 //
26 // fn handle(action: Action) {
27 //     $0if let Action::Move { distance } = action {
28 //         foo(distance)
29 //     } else {
30 //         bar()
31 //     }
32 // }
33 // ```
34 // ->
35 // ```
36 // enum Action { Move { distance: u32 }, Stop }
37 //
38 // fn handle(action: Action) {
39 //     match action {
40 //         Action::Move { distance } => foo(distance),
41 //         _ => bar(),
42 //     }
43 // }
44 // ```
45 pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
46     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
47     let available_range = TextRange::new(
48         if_expr.syntax().text_range().start(),
49         if_expr.then_branch()?.syntax().text_range().start(),
50     );
51     let cursor_in_range = available_range.contains_range(ctx.frange.range);
52     if !cursor_in_range {
53         return None;
54     }
55     let mut else_block = None;
56     let if_exprs = successors(Some(if_expr.clone()), |expr| match expr.else_branch()? {
57         ast::ElseBranch::IfExpr(expr) => Some(expr),
58         ast::ElseBranch::Block(block) => {
59             else_block = Some(block);
60             None
61         }
62     });
63     let scrutinee_to_be_expr = if_expr.condition()?.expr()?;
64
65     let mut pat_seen = false;
66     let mut cond_bodies = Vec::new();
67     for if_expr in if_exprs {
68         let cond = if_expr.condition()?;
69         let expr = cond.expr()?;
70         let cond = match cond.pat() {
71             Some(pat) => {
72                 if scrutinee_to_be_expr.syntax().text() != expr.syntax().text() {
73                     // Only if all condition expressions are equal we can merge them into a match
74                     return None;
75                 }
76                 pat_seen = true;
77                 Either::Left(pat)
78             }
79             None => Either::Right(expr),
80         };
81         let body = if_expr.then_branch()?;
82         cond_bodies.push((cond, body));
83     }
84
85     if !pat_seen {
86         // Don't offer turning an if (chain) without patterns into a match
87         return None;
88     }
89
90     acc.add(
91         AssistId("replace_if_let_with_match", AssistKind::RefactorRewrite),
92         "Replace if let with match",
93         available_range,
94         move |edit| {
95             let match_expr = {
96                 let else_arm = make_else_arm(ctx, else_block, &cond_bodies);
97                 let make_match_arm = |(pat, body): (_, ast::BlockExpr)| {
98                     let body = body.reset_indent().indent(IndentLevel(1));
99                     match pat {
100                         Either::Left(pat) => {
101                             make::match_arm(iter::once(pat), None, unwrap_trivial_block(body))
102                         }
103                         Either::Right(expr) => make::match_arm(
104                             iter::once(make::wildcard_pat().into()),
105                             Some(expr),
106                             unwrap_trivial_block(body),
107                         ),
108                     }
109                 };
110                 let arms = cond_bodies.into_iter().map(make_match_arm).chain(iter::once(else_arm));
111                 let match_expr = make::expr_match(scrutinee_to_be_expr, make::match_arm_list(arms));
112                 match_expr.indent(IndentLevel::from_node(if_expr.syntax()))
113             };
114
115             let has_preceding_if_expr =
116                 if_expr.syntax().parent().map_or(false, |it| ast::IfExpr::can_cast(it.kind()));
117             let expr = if has_preceding_if_expr {
118                 // make sure we replace the `else if let ...` with a block so we don't end up with `else expr`
119                 make::block_expr(None, Some(match_expr)).into()
120             } else {
121                 match_expr
122             };
123             edit.replace_ast::<ast::Expr>(if_expr.into(), expr);
124         },
125     )
126 }
127
128 fn make_else_arm(
129     ctx: &AssistContext,
130     else_block: Option<ast::BlockExpr>,
131     conditionals: &[(Either<ast::Pat, ast::Expr>, ast::BlockExpr)],
132 ) -> ast::MatchArm {
133     if let Some(else_block) = else_block {
134         let pattern = if let [(Either::Left(pat), _)] = conditionals {
135             ctx.sema
136                 .type_of_pat(pat)
137                 .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted()))
138                 .zip(Some(pat))
139         } else {
140             None
141         };
142         let pattern = match pattern {
143             Some((it, pat)) => {
144                 if does_pat_match_variant(pat, &it.sad_pattern()) {
145                     it.happy_pattern()
146                 } else {
147                     it.sad_pattern()
148                 }
149             }
150             None => make::wildcard_pat().into(),
151         };
152         make::match_arm(iter::once(pattern), None, unwrap_trivial_block(else_block))
153     } else {
154         make::match_arm(iter::once(make::wildcard_pat().into()), None, make::expr_unit())
155     }
156 }
157
158 // Assist: replace_match_with_if_let
159 //
160 // Replaces a binary `match` with a wildcard pattern and no guards with an `if let` expression.
161 //
162 // ```
163 // enum Action { Move { distance: u32 }, Stop }
164 //
165 // fn handle(action: Action) {
166 //     $0match action {
167 //         Action::Move { distance } => foo(distance),
168 //         _ => bar(),
169 //     }
170 // }
171 // ```
172 // ->
173 // ```
174 // enum Action { Move { distance: u32 }, Stop }
175 //
176 // fn handle(action: Action) {
177 //     if let Action::Move { distance } = action {
178 //         foo(distance)
179 //     } else {
180 //         bar()
181 //     }
182 // }
183 // ```
184 pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
185     let match_expr: ast::MatchExpr = ctx.find_node_at_offset()?;
186
187     let mut arms = match_expr.match_arm_list()?.arms();
188     let (first_arm, second_arm) = (arms.next()?, arms.next()?);
189     if arms.next().is_some() || first_arm.guard().is_some() || second_arm.guard().is_some() {
190         return None;
191     }
192
193     let (if_let_pat, then_expr, else_expr) = pick_pattern_and_expr_order(
194         &ctx.sema,
195         first_arm.pat()?,
196         second_arm.pat()?,
197         first_arm.expr()?,
198         second_arm.expr()?,
199     )?;
200     let scrutinee = match_expr.expr()?;
201
202     let target = match_expr.syntax().text_range();
203     acc.add(
204         AssistId("replace_match_with_if_let", AssistKind::RefactorRewrite),
205         "Replace match with if let",
206         target,
207         move |edit| {
208             let condition = make::condition(scrutinee, Some(if_let_pat));
209             let then_block = match then_expr.reset_indent() {
210                 ast::Expr::BlockExpr(block) => block,
211                 expr => make::block_expr(iter::empty(), Some(expr)),
212             };
213             let else_expr = match else_expr {
214                 ast::Expr::BlockExpr(block) if block.is_empty() => None,
215                 ast::Expr::TupleExpr(tuple) if tuple.fields().next().is_none() => None,
216                 expr => Some(expr),
217             };
218             let if_let_expr = make::expr_if(
219                 condition,
220                 then_block,
221                 else_expr
222                     .map(|expr| match expr {
223                         ast::Expr::BlockExpr(block) => block,
224                         expr => (make::block_expr(iter::empty(), Some(expr))),
225                     })
226                     .map(ast::ElseBranch::Block),
227             )
228             .indent(IndentLevel::from_node(match_expr.syntax()));
229
230             edit.replace_ast::<ast::Expr>(match_expr.into(), if_let_expr);
231         },
232     )
233 }
234
235 /// Pick the pattern for the if let condition and return the expressions for the `then` body and `else` body in that order.
236 fn pick_pattern_and_expr_order(
237     sema: &hir::Semantics<RootDatabase>,
238     pat: ast::Pat,
239     pat2: ast::Pat,
240     expr: ast::Expr,
241     expr2: ast::Expr,
242 ) -> Option<(ast::Pat, ast::Expr, ast::Expr)> {
243     let res = match (pat, pat2) {
244         (ast::Pat::WildcardPat(_), _) => return None,
245         (pat, _) if is_empty_expr(&expr2) => (pat, expr, expr2),
246         (_, pat) if is_empty_expr(&expr) => (pat, expr2, expr),
247         (pat, pat2) => match (binds_name(sema, &pat), binds_name(sema, &pat2)) {
248             (true, true) => return None,
249             (true, false) => (pat, expr, expr2),
250             (false, true) => (pat2, expr2, expr),
251             _ if is_sad_pat(sema, &pat) => (pat2, expr2, expr),
252             (false, false) => (pat, expr, expr2),
253         },
254     };
255     Some(res)
256 }
257
258 fn is_empty_expr(expr: &ast::Expr) -> bool {
259     match expr {
260         ast::Expr::BlockExpr(expr) => expr.is_empty(),
261         ast::Expr::TupleExpr(expr) => expr.fields().next().is_none(),
262         _ => false,
263     }
264 }
265
266 fn binds_name(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
267     let binds_name_v = |pat| binds_name(sema, &pat);
268     match pat {
269         ast::Pat::IdentPat(pat) => !matches!(
270             pat.name().and_then(|name| NameClass::classify(sema, &name)),
271             Some(NameClass::ConstReference(_))
272         ),
273         ast::Pat::MacroPat(_) => true,
274         ast::Pat::OrPat(pat) => pat.pats().any(binds_name_v),
275         ast::Pat::SlicePat(pat) => pat.pats().any(binds_name_v),
276         ast::Pat::TuplePat(it) => it.fields().any(binds_name_v),
277         ast::Pat::TupleStructPat(it) => it.fields().any(binds_name_v),
278         ast::Pat::RecordPat(it) => it
279             .record_pat_field_list()
280             .map_or(false, |rpfl| rpfl.fields().flat_map(|rpf| rpf.pat()).any(binds_name_v)),
281         ast::Pat::RefPat(pat) => pat.pat().map_or(false, binds_name_v),
282         ast::Pat::BoxPat(pat) => pat.pat().map_or(false, binds_name_v),
283         ast::Pat::ParenPat(pat) => pat.pat().map_or(false, binds_name_v),
284         _ => false,
285     }
286 }
287
288 fn is_sad_pat(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
289     sema.type_of_pat(pat)
290         .and_then(|ty| TryEnum::from_ty(sema, &ty.adjusted()))
291         .map_or(false, |it| does_pat_match_variant(pat, &it.sad_pattern()))
292 }
293
294 #[cfg(test)]
295 mod tests {
296     use super::*;
297
298     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
299
300     #[test]
301     fn test_if_let_with_match_unapplicable_for_simple_ifs() {
302         check_assist_not_applicable(
303             replace_if_let_with_match,
304             r#"
305 fn main() {
306     if $0true {} else if false {} else {}
307 }
308 "#,
309         )
310     }
311
312     #[test]
313     fn test_if_let_with_match_no_else() {
314         check_assist(
315             replace_if_let_with_match,
316             r#"
317 impl VariantData {
318     pub fn foo(&self) {
319         if $0let VariantData::Struct(..) = *self {
320             self.foo();
321         }
322     }
323 }
324 "#,
325             r#"
326 impl VariantData {
327     pub fn foo(&self) {
328         match *self {
329             VariantData::Struct(..) => {
330                 self.foo();
331             }
332             _ => (),
333         }
334     }
335 }
336 "#,
337         )
338     }
339
340     #[test]
341     fn test_if_let_with_match_available_range_left() {
342         check_assist_not_applicable(
343             replace_if_let_with_match,
344             r#"
345 impl VariantData {
346     pub fn foo(&self) {
347         $0 if let VariantData::Struct(..) = *self {
348             self.foo();
349         }
350     }
351 }
352 "#,
353         )
354     }
355
356     #[test]
357     fn test_if_let_with_match_available_range_right() {
358         check_assist_not_applicable(
359             replace_if_let_with_match,
360             r#"
361 impl VariantData {
362     pub fn foo(&self) {
363         if let VariantData::Struct(..) = *self {$0
364             self.foo();
365         }
366     }
367 }
368 "#,
369         )
370     }
371
372     #[test]
373     fn test_if_let_with_match_basic() {
374         check_assist(
375             replace_if_let_with_match,
376             r#"
377 impl VariantData {
378     pub fn is_struct(&self) -> bool {
379         if $0let VariantData::Struct(..) = *self {
380             true
381         } else if let VariantData::Tuple(..) = *self {
382             false
383         } else if cond() {
384             true
385         } else {
386             bar(
387                 123
388             )
389         }
390     }
391 }
392 "#,
393             r#"
394 impl VariantData {
395     pub fn is_struct(&self) -> bool {
396         match *self {
397             VariantData::Struct(..) => true,
398             VariantData::Tuple(..) => false,
399             _ if cond() => true,
400             _ => {
401                     bar(
402                         123
403                     )
404                 }
405         }
406     }
407 }
408 "#,
409         )
410     }
411
412     #[test]
413     fn test_if_let_with_match_on_tail_if_let() {
414         check_assist(
415             replace_if_let_with_match,
416             r#"
417 impl VariantData {
418     pub fn is_struct(&self) -> bool {
419         if let VariantData::Struct(..) = *self {
420             true
421         } else if let$0 VariantData::Tuple(..) = *self {
422             false
423         } else {
424             false
425         }
426     }
427 }
428 "#,
429             r#"
430 impl VariantData {
431     pub fn is_struct(&self) -> bool {
432         if let VariantData::Struct(..) = *self {
433             true
434         } else {
435     match *self {
436             VariantData::Tuple(..) => false,
437             _ => false,
438         }
439 }
440     }
441 }
442 "#,
443         )
444     }
445
446     #[test]
447     fn special_case_option() {
448         check_assist(
449             replace_if_let_with_match,
450             r#"
451 //- minicore: option
452 fn foo(x: Option<i32>) {
453     $0if let Some(x) = x {
454         println!("{}", x)
455     } else {
456         println!("none")
457     }
458 }
459 "#,
460             r#"
461 fn foo(x: Option<i32>) {
462     match x {
463         Some(x) => println!("{}", x),
464         None => println!("none"),
465     }
466 }
467 "#,
468         );
469     }
470
471     #[test]
472     fn special_case_inverted_option() {
473         check_assist(
474             replace_if_let_with_match,
475             r#"
476 //- minicore: option
477 fn foo(x: Option<i32>) {
478     $0if let None = x {
479         println!("none")
480     } else {
481         println!("some")
482     }
483 }
484 "#,
485             r#"
486 fn foo(x: Option<i32>) {
487     match x {
488         None => println!("none"),
489         Some(_) => println!("some"),
490     }
491 }
492 "#,
493         );
494     }
495
496     #[test]
497     fn special_case_result() {
498         check_assist(
499             replace_if_let_with_match,
500             r#"
501 //- minicore: result
502 fn foo(x: Result<i32, ()>) {
503     $0if let Ok(x) = x {
504         println!("{}", x)
505     } else {
506         println!("none")
507     }
508 }
509 "#,
510             r#"
511 fn foo(x: Result<i32, ()>) {
512     match x {
513         Ok(x) => println!("{}", x),
514         Err(_) => println!("none"),
515     }
516 }
517 "#,
518         );
519     }
520
521     #[test]
522     fn special_case_inverted_result() {
523         check_assist(
524             replace_if_let_with_match,
525             r#"
526 //- minicore: result
527 fn foo(x: Result<i32, ()>) {
528     $0if let Err(x) = x {
529         println!("{}", x)
530     } else {
531         println!("ok")
532     }
533 }
534 "#,
535             r#"
536 fn foo(x: Result<i32, ()>) {
537     match x {
538         Err(x) => println!("{}", x),
539         Ok(_) => println!("ok"),
540     }
541 }
542 "#,
543         );
544     }
545
546     #[test]
547     fn nested_indent() {
548         check_assist(
549             replace_if_let_with_match,
550             r#"
551 fn main() {
552     if true {
553         $0if let Ok(rel_path) = path.strip_prefix(root_path) {
554             let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
555             Some((*id, rel_path))
556         } else {
557             None
558         }
559     }
560 }
561 "#,
562             r#"
563 fn main() {
564     if true {
565         match path.strip_prefix(root_path) {
566             Ok(rel_path) => {
567                 let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
568                 Some((*id, rel_path))
569             }
570             _ => None,
571         }
572     }
573 }
574 "#,
575         )
576     }
577
578     #[test]
579     fn test_replace_match_with_if_let_unwraps_simple_expressions() {
580         check_assist(
581             replace_match_with_if_let,
582             r#"
583 impl VariantData {
584     pub fn is_struct(&self) -> bool {
585         $0match *self {
586             VariantData::Struct(..) => true,
587             _ => false,
588         }
589     }
590 }           "#,
591             r#"
592 impl VariantData {
593     pub fn is_struct(&self) -> bool {
594         if let VariantData::Struct(..) = *self {
595             true
596         } else {
597             false
598         }
599     }
600 }           "#,
601         )
602     }
603
604     #[test]
605     fn test_replace_match_with_if_let_doesnt_unwrap_multiline_expressions() {
606         check_assist(
607             replace_match_with_if_let,
608             r#"
609 fn foo() {
610     $0match a {
611         VariantData::Struct(..) => {
612             bar(
613                 123
614             )
615         }
616         _ => false,
617     }
618 }           "#,
619             r#"
620 fn foo() {
621     if let VariantData::Struct(..) = a {
622         bar(
623             123
624         )
625     } else {
626         false
627     }
628 }           "#,
629         )
630     }
631
632     #[test]
633     fn replace_match_with_if_let_target() {
634         check_assist_target(
635             replace_match_with_if_let,
636             r#"
637 impl VariantData {
638     pub fn is_struct(&self) -> bool {
639         $0match *self {
640             VariantData::Struct(..) => true,
641             _ => false,
642         }
643     }
644 }           "#,
645             r#"match *self {
646             VariantData::Struct(..) => true,
647             _ => false,
648         }"#,
649         );
650     }
651
652     #[test]
653     fn special_case_option_match_to_if_let() {
654         check_assist(
655             replace_match_with_if_let,
656             r#"
657 //- minicore: option
658 fn foo(x: Option<i32>) {
659     $0match x {
660         Some(x) => println!("{}", x),
661         None => println!("none"),
662     }
663 }
664 "#,
665             r#"
666 fn foo(x: Option<i32>) {
667     if let Some(x) = x {
668         println!("{}", x)
669     } else {
670         println!("none")
671     }
672 }
673 "#,
674         );
675     }
676
677     #[test]
678     fn special_case_result_match_to_if_let() {
679         check_assist(
680             replace_match_with_if_let,
681             r#"
682 //- minicore: result
683 fn foo(x: Result<i32, ()>) {
684     $0match x {
685         Ok(x) => println!("{}", x),
686         Err(_) => println!("none"),
687     }
688 }
689 "#,
690             r#"
691 fn foo(x: Result<i32, ()>) {
692     if let Ok(x) = x {
693         println!("{}", x)
694     } else {
695         println!("none")
696     }
697 }
698 "#,
699         );
700     }
701
702     #[test]
703     fn nested_indent_match_to_if_let() {
704         check_assist(
705             replace_match_with_if_let,
706             r#"
707 fn main() {
708     if true {
709         $0match path.strip_prefix(root_path) {
710             Ok(rel_path) => {
711                 let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
712                 Some((*id, rel_path))
713             }
714             _ => None,
715         }
716     }
717 }
718 "#,
719             r#"
720 fn main() {
721     if true {
722         if let Ok(rel_path) = path.strip_prefix(root_path) {
723             let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
724             Some((*id, rel_path))
725         } else {
726             None
727         }
728     }
729 }
730 "#,
731         )
732     }
733
734     #[test]
735     fn replace_match_with_if_let_empty_wildcard_expr() {
736         check_assist(
737             replace_match_with_if_let,
738             r#"
739 fn main() {
740     $0match path.strip_prefix(root_path) {
741         Ok(rel_path) => println!("{}", rel_path),
742         _ => (),
743     }
744 }
745 "#,
746             r#"
747 fn main() {
748     if let Ok(rel_path) = path.strip_prefix(root_path) {
749         println!("{}", rel_path)
750     }
751 }
752 "#,
753         )
754     }
755
756     #[test]
757     fn replace_match_with_if_let_number_body() {
758         check_assist(
759             replace_match_with_if_let,
760             r#"
761 fn main() {
762     $0match Ok(()) {
763         Ok(()) => {},
764         Err(_) => 0,
765     }
766 }
767 "#,
768             r#"
769 fn main() {
770     if let Err(_) = Ok(()) {
771         0
772     }
773 }
774 "#,
775         )
776     }
777
778     #[test]
779     fn replace_match_with_if_let_exhaustive() {
780         check_assist(
781             replace_match_with_if_let,
782             r#"
783 fn print_source(def_source: ModuleSource) {
784     match def_so$0urce {
785         ModuleSource::SourceFile(..) => { println!("source file"); }
786         ModuleSource::Module(..) => { println!("module"); }
787     }
788 }
789 "#,
790             r#"
791 fn print_source(def_source: ModuleSource) {
792     if let ModuleSource::SourceFile(..) = def_source { println!("source file"); } else { println!("module"); }
793 }
794 "#,
795         )
796     }
797
798     #[test]
799     fn replace_match_with_if_let_prefer_name_bind() {
800         check_assist(
801             replace_match_with_if_let,
802             r#"
803 fn foo() {
804     match $0Foo(0) {
805         Foo(_) => (),
806         Bar(bar) => println!("bar {}", bar),
807     }
808 }
809 "#,
810             r#"
811 fn foo() {
812     if let Bar(bar) = Foo(0) {
813         println!("bar {}", bar)
814     }
815 }
816 "#,
817         );
818         check_assist(
819             replace_match_with_if_let,
820             r#"
821 fn foo() {
822     match $0Foo(0) {
823         Bar(bar) => println!("bar {}", bar),
824         Foo(_) => (),
825     }
826 }
827 "#,
828             r#"
829 fn foo() {
830     if let Bar(bar) = Foo(0) {
831         println!("bar {}", bar)
832     }
833 }
834 "#,
835         );
836     }
837
838     #[test]
839     fn replace_match_with_if_let_prefer_nonempty_body() {
840         check_assist(
841             replace_match_with_if_let,
842             r#"
843 fn foo() {
844     match $0Ok(0) {
845         Ok(value) => {},
846         Err(err) => eprintln!("{}", err),
847     }
848 }
849 "#,
850             r#"
851 fn foo() {
852     if let Err(err) = Ok(0) {
853         eprintln!("{}", err)
854     }
855 }
856 "#,
857         );
858         check_assist(
859             replace_match_with_if_let,
860             r#"
861 fn foo() {
862     match $0Ok(0) {
863         Err(err) => eprintln!("{}", err),
864         Ok(value) => {},
865     }
866 }
867 "#,
868             r#"
869 fn foo() {
870     if let Err(err) = Ok(0) {
871         eprintln!("{}", err)
872     }
873 }
874 "#,
875         );
876     }
877
878     #[test]
879     fn replace_match_with_if_let_rejects_double_name_bindings() {
880         check_assist_not_applicable(
881             replace_match_with_if_let,
882             r#"
883 fn foo() {
884     match $0Foo(0) {
885         Foo(foo) => println!("bar {}", foo),
886         Bar(bar) => println!("bar {}", bar),
887     }
888 }
889 "#,
890         );
891     }
892 }