]> git.lizzy.rs Git - rust.git/commitdiff
Check structs for match exhaustiveness
authorLukas Wirth <lukastw97@gmail.com>
Tue, 24 Nov 2020 17:28:55 +0000 (18:28 +0100)
committerLukas Wirth <lukastw97@gmail.com>
Tue, 24 Nov 2020 17:50:56 +0000 (18:50 +0100)
crates/hir_ty/src/diagnostics/match_check.rs

index a52f41764be355e45656a343693aa0842a1ff367..62c329731d4802ad0ed7e0009c125f578e8409c7 100644 (file)
     adt::VariantData,
     body::Body,
     expr::{Expr, Literal, Pat, PatId},
-    AdtId, EnumVariantId, VariantId,
+    AdtId, EnumVariantId, StructId, VariantId,
 };
 use smallvec::{smallvec, SmallVec};
 
@@ -391,21 +391,28 @@ fn specialize_constructor(
                 }
             }
             (Pat::Wild, constructor) => Some(self.expand_wildcard(cx, constructor)?),
-            (Pat::Path(_), Constructor::Enum(constructor)) => {
+            (Pat::Path(_), constructor) => {
                 // unit enum variants become `Pat::Path`
                 let pat_id = head.as_id().expect("we know this isn't a wild");
-                if !enum_variant_matches(cx, pat_id, *constructor) {
+                let variant_id: VariantId = match constructor {
+                    &Constructor::Enum(e) => e.into(),
+                    &Constructor::Struct(s) => s.into(),
+                    _ => return Err(MatchCheckErr::NotImplemented),
+                };
+                if Some(variant_id) != cx.infer.variant_resolution_for_pat(pat_id) {
                     None
                 } else {
                     Some(self.to_tail())
                 }
             }
-            (
-                Pat::TupleStruct { args: ref pat_ids, ellipsis, .. },
-                Constructor::Enum(enum_constructor),
-            ) => {
+            (Pat::TupleStruct { args: ref pat_ids, ellipsis, .. }, constructor) => {
                 let pat_id = head.as_id().expect("we know this isn't a wild");
-                if !enum_variant_matches(cx, pat_id, *enum_constructor) {
+                let variant_id: VariantId = match constructor {
+                    &Constructor::Enum(e) => e.into(),
+                    &Constructor::Struct(s) => s.into(),
+                    _ => return Err(MatchCheckErr::MalformedMatchArm),
+                };
+                if Some(variant_id) != cx.infer.variant_resolution_for_pat(pat_id) {
                     None
                 } else {
                     let constructor_arity = constructor.arity(cx)?;
@@ -443,12 +450,22 @@ fn specialize_constructor(
                     }
                 }
             }
-            (Pat::Record { args: ref arg_patterns, .. }, Constructor::Enum(e)) => {
+            (Pat::Record { args: ref arg_patterns, .. }, constructor) => {
                 let pat_id = head.as_id().expect("we know this isn't a wild");
-                if !enum_variant_matches(cx, pat_id, *e) {
+                let (variant_id, variant_data) = match constructor {
+                    &Constructor::Enum(e) => (
+                        e.into(),
+                        cx.db.enum_data(e.parent).variants[e.local_id].variant_data.clone(),
+                    ),
+                    &Constructor::Struct(s) => {
+                        (s.into(), cx.db.struct_data(s).variant_data.clone())
+                    }
+                    _ => return Err(MatchCheckErr::MalformedMatchArm),
+                };
+                if Some(variant_id) != cx.infer.variant_resolution_for_pat(pat_id) {
                     None
                 } else {
-                    match cx.db.enum_data(e.parent).variants[e.local_id].variant_data.as_ref() {
+                    match variant_data.as_ref() {
                         VariantData::Record(struct_field_arena) => {
                             // Here we treat any missing fields in the record as the wild pattern, as
                             // if the record has ellipsis. We want to do this here even if the
@@ -727,6 +744,7 @@ enum Constructor {
     Bool(bool),
     Tuple { arity: usize },
     Enum(EnumVariantId),
+    Struct(StructId),
 }
 
 impl Constructor {
@@ -741,6 +759,11 @@ fn arity(&self, cx: &MatchCheckCtx) -> MatchCheckResult<usize> {
                     VariantData::Unit => 0,
                 }
             }
+            &Constructor::Struct(s) => match cx.db.struct_data(s).variant_data.as_ref() {
+                VariantData::Tuple(struct_field_data) => struct_field_data.len(),
+                VariantData::Record(struct_field_data) => struct_field_data.len(),
+                VariantData::Unit => 0,
+            },
         };
 
         Ok(arity)
@@ -749,7 +772,7 @@ fn arity(&self, cx: &MatchCheckCtx) -> MatchCheckResult<usize> {
     fn all_constructors(&self, cx: &MatchCheckCtx) -> Vec<Constructor> {
         match self {
             Constructor::Bool(_) => vec![Constructor::Bool(true), Constructor::Bool(false)],
-            Constructor::Tuple { .. } => vec![*self],
+            Constructor::Tuple { .. } | Constructor::Struct(_) => vec![*self],
             Constructor::Enum(e) => cx
                 .db
                 .enum_data(e.parent)
@@ -786,6 +809,7 @@ fn pat_constructor(cx: &MatchCheckCtx, pat: PatIdOrWild) -> MatchCheckResult<Opt
                 VariantId::EnumVariantId(enum_variant_id) => {
                     Some(Constructor::Enum(enum_variant_id))
                 }
+                VariantId::StructId(struct_id) => Some(Constructor::Struct(struct_id)),
                 _ => return Err(MatchCheckErr::NotImplemented),
             }
         }
@@ -830,13 +854,13 @@ fn all_constructors_covered(
 
             false
         }),
+        &Constructor::Struct(s) => used_constructors.iter().any(|constructor| match constructor {
+            &Constructor::Struct(sid) => sid == s,
+            _ => false,
+        }),
     }
 }
 
-fn enum_variant_matches(cx: &MatchCheckCtx, pat_id: PatId, enum_variant_id: EnumVariantId) -> bool {
-    Some(enum_variant_id.into()) == cx.infer.variant_resolution_for_pat(pat_id)
-}
-
 #[cfg(test)]
 mod tests {
     use crate::diagnostics::tests::check_diagnostics;
@@ -848,8 +872,8 @@ fn empty_tuple() {
 fn main() {
     match () { }
         //^^ Missing match arm
-   match (()) { }
-       //^^^^ Missing match arm
+    match (()) { }
+        //^^^^ Missing match arm
 
     match () { _ => (), }
     match () { () => (), }
@@ -1393,6 +1417,84 @@ fn main() {
         );
     }
 
+    #[test]
+    fn record_struct() {
+        check_diagnostics(
+            r#"struct Foo { a: bool }
+fn main(f: Foo) {
+    match f {}
+        //^ Missing match arm
+    match f { Foo { a: true } => () }
+        //^ Missing match arm
+    match &f { Foo { a: true } => () }
+        //^^ Missing match arm
+    match f { Foo { a: _ } => () }
+    match f {
+        Foo { a: true } => (),
+        Foo { a: false } => (),
+    }
+    match &f {
+        Foo { a: true } => (),
+        Foo { a: false } => (),
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn tuple_struct() {
+        check_diagnostics(
+            r#"struct Foo(bool);
+fn main(f: Foo) {
+    match f {}
+        //^ Missing match arm
+    match f { Foo(true) => () }
+        //^ Missing match arm
+    match f {
+        Foo(true) => (),
+        Foo(false) => (),
+    }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn unit_struct() {
+        check_diagnostics(
+            r#"struct Foo;
+fn main(f: Foo) {
+    match f {}
+        //^ Missing match arm
+    match f { Foo => () }
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn record_struct_ellipsis() {
+        check_diagnostics(
+            r#"struct Foo { foo: bool, bar: bool }
+fn main(f: Foo) {
+    match f { Foo { foo: true, .. } => () }
+        //^ Missing match arm
+    match f {
+        //^ Missing match arm
+        Foo { foo: true, .. } => (),
+        Foo { bar: false, .. } => ()
+    }
+    match f { Foo { .. } => () }
+    match f {
+        Foo { foo: true, .. } => (),
+        Foo { foo: false, .. } => ()
+    }
+}
+"#,
+        );
+    }
+
     mod false_negatives {
         //! The implementation of match checking here is a work in progress. As we roll this out, we
         //! prefer false negatives to false positives (ideally there would be no false positives). This
@@ -1431,19 +1533,6 @@ enum Either { A(bool), B }
         Either::A(true | false) => (),
     }
 }
-"#,
-            );
-        }
-
-        #[test]
-        fn struct_missing_arm() {
-            // We don't currently handle structs.
-            check_diagnostics(
-                r#"
-struct Foo { a: bool }
-fn main(f: Foo) {
-    match f { Foo { a: true } => () }
-}
 "#,
             );
         }