]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/merge_match_arms.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / merge_match_arms.rs
index ecb7d4bf0799fe6d048cf561f77a2798a1062809..622ead81f105ba325b20a3992f5da7dea36b6665 100644 (file)
@@ -2,7 +2,7 @@
 use std::{collections::HashMap, iter::successors};
 use syntax::{
     algo::neighbor,
-    ast::{self, AstNode, HasName, MatchArm, Pat},
+    ast::{self, AstNode, HasName},
     Direction,
 };
 
@@ -52,7 +52,7 @@ pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option
                     return false;
                 }
 
-                return are_same_types(&current_arm_types, arm, ctx);
+                are_same_types(&current_arm_types, arm, ctx)
             }
             _ => false,
         })
@@ -90,7 +90,7 @@ pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option
     )
 }
 
-fn contains_placeholder(a: &MatchArm) -> bool {
+fn contains_placeholder(a: &ast::MatchArm) -> bool {
     matches!(a.pat(), Some(ast::Pat::WildcardPat(..)))
 }
 
@@ -101,50 +101,56 @@ fn are_same_types(
 ) -> bool {
     let arm_types = get_arm_types(&ctx, &arm);
     for (other_arm_type_name, other_arm_type) in arm_types {
-        if let (Some(Some(current_arm_type)), Some(other_arm_type)) =
-            (current_arm_types.get(&other_arm_type_name), other_arm_type)
-        {
-            if other_arm_type.original != current_arm_type.original {
-                return false;
+        match (current_arm_types.get(&other_arm_type_name), other_arm_type) {
+            (Some(Some(current_arm_type)), Some(other_arm_type))
+                if other_arm_type.original == current_arm_type.original =>
+            {
+                ()
             }
-        } else {
-            // No corresponding field found
-            return false;
+            _ => return false,
         }
     }
 
-    return true;
+    true
 }
 
-fn get_arm_types(context: &AssistContext, arm: &MatchArm) -> HashMap<String, Option<TypeInfo>> {
+fn get_arm_types(
+    context: &AssistContext,
+    arm: &ast::MatchArm,
+) -> HashMap<String, Option<TypeInfo>> {
     let mut mapping: HashMap<String, Option<TypeInfo>> = HashMap::new();
 
     fn recurse(
-        pat: &Option<Pat>,
         map: &mut HashMap<String, Option<TypeInfo>>,
         ctx: &AssistContext,
+        pat: &Option<ast::Pat>,
     ) {
         if let Some(local_pat) = pat {
             match pat {
                 Some(ast::Pat::TupleStructPat(tuple)) => {
                     for field in tuple.fields() {
-                        recurse(&Some(field), map, ctx);
+                        recurse(map, ctx, &Some(field));
                     }
                 }
                 Some(ast::Pat::TuplePat(tuple)) => {
                     for field in tuple.fields() {
-                        recurse(&Some(field), map, ctx);
+                        recurse(map, ctx, &Some(field));
                     }
                 }
                 Some(ast::Pat::RecordPat(record)) => {
                     if let Some(field_list) = record.record_pat_field_list() {
                         for field in field_list.fields() {
-                            recurse(&field.pat(), map, ctx);
+                            recurse(map, ctx, &field.pat());
                         }
                     }
                 }
                 Some(ast::Pat::ParenPat(parentheses)) => {
-                    recurse(&parentheses.pat(), map, ctx);
+                    recurse(map, ctx, &parentheses.pat());
+                }
+                Some(ast::Pat::SlicePat(slice)) => {
+                    for slice_pat in slice.pats() {
+                        recurse(map, ctx, &Some(slice_pat));
+                    }
                 }
                 Some(ast::Pat::IdentPat(ident_pat)) => {
                     if let Some(name) = ident_pat.name() {
@@ -157,8 +163,8 @@ fn recurse(
         }
     }
 
-    recurse(&arm.pat(), &mut mapping, &context);
-    return mapping;
+    recurse(&mut mapping, &context, &arm.pat());
+    mapping
 }
 
 #[cfg(test)]
@@ -322,7 +328,8 @@ fn main() {
     fn merge_match_arms_different_type() {
         check_assist_not_applicable(
             merge_match_arms,
-            r#"//- minicore: result
+            r#"
+//- minicore: result
 fn func() {
     match Result::<f64, f32>::Ok(0f64) {
         Ok(x) => $0x.classify(),
@@ -337,7 +344,8 @@ fn func() {
     fn merge_match_arms_different_type_multiple_fields() {
         check_assist_not_applicable(
             merge_match_arms,
-            r#"//- minicore: result
+            r#"
+//- minicore: result
 fn func() {
     match Result::<(f64, f64), (f32, f32)>::Ok((0f64, 0f64)) {
         Ok(x) => $0x.1.classify(),
@@ -352,7 +360,8 @@ fn func() {
     fn merge_match_arms_same_type_multiple_fields() {
         check_assist(
             merge_match_arms,
-            r#"//- minicore: result
+            r#"
+//- minicore: result
 fn func() {
     match Result::<(f64, f64), (f64, f64)>::Ok((0f64, 0f64)) {
         Ok(x) => $0x.1.classify(),
@@ -432,7 +441,8 @@ fn func(e: MyEnum) {
     fn merge_match_arms_same_type_different_number_of_fields() {
         check_assist_not_applicable(
             merge_match_arms,
-            r#"//- minicore: result
+            r#"
+//- minicore: result
 fn func() {
     match Result::<(f64, f64, f64), (f64, f64)>::Ok((0f64, 0f64, 0f64)) {
         Ok(x) => $0x.1.classify(),
@@ -747,6 +757,67 @@ fn func(x: i32) {
         ((((variable)))) => "",
         _ => "other"
     };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_refpat() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+fn func() {
+    let name = Some(String::from(""));
+    let n = String::from("");
+    match name {
+        Some(ref n) => $0"",
+        Some(n) => "",
+        _ => "other",
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_slice() {
+        check_assist_not_applicable(
+            merge_match_arms,
+            r#"
+fn func(binary: &[u8]) {
+    let space = b' ';
+    match binary {
+        [0x7f, b'E', b'L', b'F', ..] => $0"",
+        [space] => "",
+        _ => "other",
+    };
+}
+        "#,
+        )
+    }
+
+    #[test]
+    fn merge_match_arms_slice_identical() {
+        check_assist(
+            merge_match_arms,
+            r#"
+fn func(binary: &[u8]) {
+    let space = b' ';
+    match binary {
+        [space, 5u8] => $0"",
+        [space] => "",
+        _ => "other",
+    };
+}
+        "#,
+            r#"
+fn func(binary: &[u8]) {
+    let space = b' ';
+    match binary {
+        [space, 5u8] | [space] => "",
+        _ => "other",
+    };
 }
         "#,
         )