]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/merge_match_arms.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / merge_match_arms.rs
index 5c6bb986b91c74dff4b0c97024cc48dfc9337b30..622ead81f105ba325b20a3992f5da7dea36b6665 100644 (file)
@@ -1,8 +1,8 @@
-use std::iter::successors;
-
+use hir::TypeInfo;
+use std::{collections::HashMap, iter::successors};
 use syntax::{
     algo::neighbor,
-    ast::{self, AstNode},
+    ast::{self, AstNode, HasName},
     Direction,
 };
 
@@ -40,13 +40,19 @@ pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option
     }
     let current_expr = current_arm.expr()?;
     let current_text_range = current_arm.syntax().text_range();
+    let current_arm_types = get_arm_types(&ctx, &current_arm);
 
     // We check if the following match arms match this one. We could, but don't,
     // compare to the previous match arm as well.
     let arms_to_merge = successors(Some(current_arm), |it| neighbor(it, Direction::Next))
         .take_while(|arm| match arm.expr() {
             Some(expr) if arm.guard().is_none() => {
-                expr.syntax().text() == current_expr.syntax().text()
+                let same_text = expr.syntax().text() == current_expr.syntax().text();
+                if !same_text {
+                    return false;
+                }
+
+                are_same_types(&current_arm_types, arm, ctx)
             }
             _ => false,
         })
@@ -72,7 +78,7 @@ pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option
                     .join(" | ")
             };
 
-            let arm = format!("{} => {}", pats, current_expr.syntax().text());
+            let arm = format!("{} => {},", pats, current_expr.syntax().text());
 
             if let [first, .., last] = &*arms_to_merge {
                 let start = first.syntax().text_range().start();
@@ -88,6 +94,79 @@ fn contains_placeholder(a: &ast::MatchArm) -> bool {
     matches!(a.pat(), Some(ast::Pat::WildcardPat(..)))
 }
 
+fn are_same_types(
+    current_arm_types: &HashMap<String, Option<TypeInfo>>,
+    arm: &ast::MatchArm,
+    ctx: &AssistContext,
+) -> bool {
+    let arm_types = get_arm_types(&ctx, &arm);
+    for (other_arm_type_name, other_arm_type) in arm_types {
+        match (current_arm_types.get(&other_arm_type_name), other_arm_type) {
+            (Some(Some(current_arm_type)), Some(other_arm_type))
+                if other_arm_type.original == current_arm_type.original =>
+            {
+                ()
+            }
+            _ => return false,
+        }
+    }
+
+    true
+}
+
+fn get_arm_types(
+    context: &AssistContext,
+    arm: &ast::MatchArm,
+) -> HashMap<String, Option<TypeInfo>> {
+    let mut mapping: HashMap<String, Option<TypeInfo>> = HashMap::new();
+
+    fn recurse(
+        map: &mut HashMap<String, Option<TypeInfo>>,
+        ctx: &AssistContext,
+        pat: &Option<ast::Pat>,
+    ) {
+        if let Some(local_pat) = pat {
+            match pat {
+                Some(ast::Pat::TupleStructPat(tuple)) => {
+                    for field in tuple.fields() {
+                        recurse(map, ctx, &Some(field));
+                    }
+                }
+                Some(ast::Pat::TuplePat(tuple)) => {
+                    for field in tuple.fields() {
+                        recurse(map, ctx, &Some(field));
+                    }
+                }
+                Some(ast::Pat::RecordPat(record)) => {
+                    if let Some(field_list) = record.record_pat_field_list() {
+                        for field in field_list.fields() {
+                            recurse(map, ctx, &field.pat());
+                        }
+                    }
+                }
+                Some(ast::Pat::ParenPat(parentheses)) => {
+                    recurse(map, ctx, &parentheses.pat());
+                }
+                Some(ast::Pat::SlicePat(slice)) => {
+                    for slice_pat in slice.pats() {
+                        recurse(map, ctx, &Some(slice_pat));
+                    }
+                }
+                Some(ast::Pat::IdentPat(ident_pat)) => {
+                    if let Some(name) = ident_pat.name() {
+                        let pat_type = ctx.sema.type_of_pat(local_pat);
+                        map.insert(name.text().to_string(), pat_type);
+                    }
+                }
+                _ => (),
+            }
+        }
+    }
+
+    recurse(&mut mapping, &context, &arm.pat());
+    mapping
+}
+
 #[cfg(test)]
 mod tests {
     use crate::tests::{check_assist, check_assist_not_applicable};
@@ -118,7 +197,7 @@ enum X { A, B, C }
 fn main() {
     let x = X::A;
     let y = match x {
-        X::A | X::B => { 1i32 }
+        X::A | X::B => { 1i32 },
         X::C => { 2i32 }
     }
 }
@@ -183,7 +262,7 @@ fn main() {
     let x = X::A;
     let y = match x {
         X::A => { 1i32 },
-        _ => { 2i32 }
+        _ => { 2i32 },
     }
 }
 "#,
@@ -244,4 +323,503 @@ fn main() {
 "#,
         );
     }
+
+    #[test]
+    fn merge_match_arms_different_type() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+//- minicore: result
+fn func() {
+    match Result::<f64, f32>::Ok(0f64) {
+        Ok(x) => $0x.classify(),
+        Err(x) => x.classify()
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn merge_match_arms_different_type_multiple_fields() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+//- minicore: result
+fn func() {
+    match Result::<(f64, f64), (f32, f32)>::Ok((0f64, 0f64)) {
+        Ok(x) => $0x.1.classify(),
+        Err(x) => x.1.classify()
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn merge_match_arms_same_type_multiple_fields() {
+        check_assist(
+            merge_match_arms,
+            r#"
+//- minicore: result
+fn func() {
+    match Result::<(f64, f64), (f64, f64)>::Ok((0f64, 0f64)) {
+        Ok(x) => $0x.1.classify(),
+        Err(x) => x.1.classify()
+    };
+}
+"#,
+            r#"
+fn func() {
+    match Result::<(f64, f64), (f64, f64)>::Ok((0f64, 0f64)) {
+        Ok(x) | Err(x) => x.1.classify(),
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn merge_match_arms_same_type_subsequent_arm_with_different_type_in_other() {
+        check_assist(
+            merge_match_arms,
+            r#"
+enum MyEnum {
+    OptionA(f32),
+    OptionB(f32),
+    OptionC(f64)
+}
+
+fn func(e: MyEnum) {
+    match e {
+        MyEnum::OptionA(x) => $0x.classify(),
+        MyEnum::OptionB(x) => x.classify(),
+        MyEnum::OptionC(x) => x.classify(),
+    };
+}
+"#,
+            r#"
+enum MyEnum {
+    OptionA(f32),
+    OptionB(f32),
+    OptionC(f64)
+}
+
+fn func(e: MyEnum) {
+    match e {
+        MyEnum::OptionA(x) | MyEnum::OptionB(x) => x.classify(),
+        MyEnum::OptionC(x) => x.classify(),
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn merge_match_arms_same_type_skip_arm_with_different_type_in_between() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+enum MyEnum {
+    OptionA(f32),
+    OptionB(f64),
+    OptionC(f32)
+}
+
+fn func(e: MyEnum) {
+    match e {
+        MyEnum::OptionA(x) => $0x.classify(),
+        MyEnum::OptionB(x) => x.classify(),
+        MyEnum::OptionC(x) => x.classify(),
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn merge_match_arms_same_type_different_number_of_fields() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+//- minicore: result
+fn func() {
+    match Result::<(f64, f64, f64), (f64, f64)>::Ok((0f64, 0f64, 0f64)) {
+        Ok(x) => $0x.1.classify(),
+        Err(x) => x.1.classify()
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn merge_match_same_destructuring_different_types() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+struct Point {
+    x: i32,
+    y: i32,
+}
+
+fn func() {
+    let p = Point { x: 0, y: 7 };
+
+    match p {
+        Point { x, y: 0 } => $0"",
+        Point { x: 0, y } => "",
+        Point { x, y } => "",
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn merge_match_arms_range() {
+        check_assist(
+            merge_match_arms,
+            r#"
+fn func() {
+    let x = 'c';
+
+    match x {
+        'a'..='j' => $0"",
+        'c'..='z' => "",
+        _ => "other",
+    };
+}
+"#,
+            r#"
+fn func() {
+    let x = 'c';
+
+    match x {
+        'a'..='j' | 'c'..='z' => "",
+        _ => "other",
+    };
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn merge_match_arms_enum_without_field() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+enum MyEnum {
+    NoField,
+    AField(u8)
+}
+
+fn func(x: MyEnum) {
+    match x {
+        MyEnum::NoField => $0"",
+        MyEnum::AField(x) => ""
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_enum_destructuring_different_types() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+enum MyEnum {
+    Move { x: i32, y: i32 },
+    Write(String),
+}
+
+fn func(x: MyEnum) {
+    match x {
+        MyEnum::Move { x, y } => $0"",
+        MyEnum::Write(text) => "",
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_enum_destructuring_same_types() {
+        check_assist(
+            merge_match_arms,
+            r#"
+enum MyEnum {
+    Move { x: i32, y: i32 },
+    Crawl { x: i32, y: i32 }
+}
+
+fn func(x: MyEnum) {
+    match x {
+        MyEnum::Move { x, y } => $0"",
+        MyEnum::Crawl { x, y } => "",
+    };
+}
+        "#,
+            r#"
+enum MyEnum {
+    Move { x: i32, y: i32 },
+    Crawl { x: i32, y: i32 }
+}
+
+fn func(x: MyEnum) {
+    match x {
+        MyEnum::Move { x, y } | MyEnum::Crawl { x, y } => "",
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_enum_destructuring_same_types_different_name() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+enum MyEnum {
+    Move { x: i32, y: i32 },
+    Crawl { a: i32, b: i32 }
+}
+
+fn func(x: MyEnum) {
+    match x {
+        MyEnum::Move { x, y } => $0"",
+        MyEnum::Crawl { a, b } => "",
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_enum_nested_pattern_different_names() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+enum Color {
+    Rgb(i32, i32, i32),
+    Hsv(i32, i32, i32),
+}
+
+enum Message {
+    Quit,
+    Move { x: i32, y: i32 },
+    Write(String),
+    ChangeColor(Color),
+}
+
+fn main(msg: Message) {
+    match msg {
+        Message::ChangeColor(Color::Rgb(r, g, b)) => $0"",
+        Message::ChangeColor(Color::Hsv(h, s, v)) => "",
+        _ => "other"
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_enum_nested_pattern_same_names() {
+        check_assist(
+            merge_match_arms,
+            r#"
+enum Color {
+    Rgb(i32, i32, i32),
+    Hsv(i32, i32, i32),
+}
+
+enum Message {
+    Quit,
+    Move { x: i32, y: i32 },
+    Write(String),
+    ChangeColor(Color),
+}
+
+fn main(msg: Message) {
+    match msg {
+        Message::ChangeColor(Color::Rgb(a, b, c)) => $0"",
+        Message::ChangeColor(Color::Hsv(a, b, c)) => "",
+        _ => "other"
+    };
+}
+        "#,
+            r#"
+enum Color {
+    Rgb(i32, i32, i32),
+    Hsv(i32, i32, i32),
+}
+
+enum Message {
+    Quit,
+    Move { x: i32, y: i32 },
+    Write(String),
+    ChangeColor(Color),
+}
+
+fn main(msg: Message) {
+    match msg {
+        Message::ChangeColor(Color::Rgb(a, b, c)) | Message::ChangeColor(Color::Hsv(a, b, c)) => "",
+        _ => "other"
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_enum_destructuring_with_ignore() {
+        check_assist(
+            merge_match_arms,
+            r#"
+enum MyEnum {
+    Move { x: i32, a: i32 },
+    Crawl { x: i32, b: i32 }
+}
+
+fn func(x: MyEnum) {
+    match x {
+        MyEnum::Move { x, .. } => $0"",
+        MyEnum::Crawl { x, .. } => "",
+    };
+}
+        "#,
+            r#"
+enum MyEnum {
+    Move { x: i32, a: i32 },
+    Crawl { x: i32, b: i32 }
+}
+
+fn func(x: MyEnum) {
+    match x {
+        MyEnum::Move { x, .. } | MyEnum::Crawl { x, .. } => "",
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_nested_with_conflicting_identifier() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+enum Color {
+    Rgb(i32, i32, i32),
+    Hsv(i32, i32, i32),
+}
+
+enum Message {
+    Move { x: i32, y: i32 },
+    ChangeColor(u8, Color),
+}
+
+fn main(msg: Message) {
+    match msg {
+        Message::ChangeColor(x, Color::Rgb(y, b, c)) => $0"",
+        Message::ChangeColor(y, Color::Hsv(x, b, c)) => "",
+        _ => "other"
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_tuple() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+fn func() {
+    match (0, "boo") {
+        (x, y) => $0"",
+        (y, x) => "",
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_parentheses() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+fn func(x: i32) {
+    let variable = 2;
+    match x {
+        1 => $0"",
+        ((((variable)))) => "",
+        _ => "other"
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_refpat() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+fn func() {
+    let name = Some(String::from(""));
+    let n = String::from("");
+    match name {
+        Some(ref n) => $0"",
+        Some(n) => "",
+        _ => "other",
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_slice() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+fn func(binary: &[u8]) {
+    let space = b' ';
+    match binary {
+        [0x7f, b'E', b'L', b'F', ..] => $0"",
+        [space] => "",
+        _ => "other",
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_slice_identical() {
+        check_assist(
+            merge_match_arms,
+            r#"
+fn func(binary: &[u8]) {
+    let space = b' ';
+    match binary {
+        [space, 5u8] => $0"",
+        [space] => "",
+        _ => "other",
+    };
+}
+        "#,
+            r#"
+fn func(binary: &[u8]) {
+    let space = b' ';
+    match binary {
+        [space, 5u8] | [space] => "",
+        _ => "other",
+    };
+}
+        "#,
+        )
+    }
 }