]> git.lizzy.rs Git - rust.git/commitdiff
Add support for fill match arms of boolean values
authorComonad <comonad@foxmail.com>
Wed, 21 Apr 2021 11:33:45 +0000 (19:33 +0800)
committerComonad <comonad@foxmail.com>
Wed, 21 Apr 2021 11:33:45 +0000 (19:33 +0800)
- Add support for boolean inside tuple

crates/ide_assists/src/handlers/fill_match_arms.rs
crates/syntax/src/ast/make.rs

index a30c4d04ee5b09cb74d2618ca34ad2ffccc2e1e0..be927cc1c4728b6d006fb3ce5e40f810f5ea0f53 100644 (file)
@@ -53,7 +53,7 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<
         .iter()
         .filter_map(ast::MatchArm::pat)
         .flat_map(|pat| match pat {
-            // Special casee OrPat as separate top-level pats
+            // Special case OrPat as separate top-level pats
             Pat::OrPat(or_pat) => Either::Left(or_pat.pats()),
             _ => Either::Right(iter::once(pat)),
         })
@@ -72,7 +72,11 @@ pub(crate) fn fill_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<
             .filter(|variant_pat| is_variant_missing(&top_lvl_pats, variant_pat))
             .map(|pat| make::match_arm(iter::once(pat), make::expr_empty_block()))
             .collect::<Vec<_>>();
-        if Some(enum_def) == FamousDefs(&ctx.sema, Some(module.krate())).core_option_Option() {
+        if Some(enum_def)
+            == FamousDefs(&ctx.sema, Some(module.krate()))
+                .core_option_Option()
+                .map(|x| lift_enum(x))
+        {
             // Match `Some` variant first.
             cov_mark::hit!(option_order);
             variants.reverse()
@@ -151,49 +155,99 @@ fn does_pat_match_variant(pat: &Pat, var: &Pat) -> bool {
     }
 }
 
-fn resolve_enum_def(sema: &Semantics<RootDatabase>, expr: &ast::Expr) -> Option<hir::Enum> {
+#[derive(Eq, PartialEq, Clone)]
+enum ExtendedEnum {
+    Bool,
+    Enum(hir::Enum),
+}
+
+#[derive(Eq, PartialEq, Clone)]
+enum ExtendedVariant {
+    True,
+    False,
+    Variant(hir::Variant),
+}
+
+fn lift_enum(e: hir::Enum) -> ExtendedEnum {
+    ExtendedEnum::Enum(e)
+}
+
+impl ExtendedEnum {
+    fn variants(&self, db: &RootDatabase) -> Vec<ExtendedVariant> {
+        match self {
+            ExtendedEnum::Enum(e) => {
+                e.variants(db).into_iter().map(|x| ExtendedVariant::Variant(x)).collect::<Vec<_>>()
+            }
+            ExtendedEnum::Bool => {
+                Vec::<ExtendedVariant>::from([ExtendedVariant::True, ExtendedVariant::False])
+            }
+        }
+    }
+}
+
+fn resolve_enum_def(sema: &Semantics<RootDatabase>, expr: &ast::Expr) -> Option<ExtendedEnum> {
     sema.type_of_expr(&expr)?.autoderef(sema.db).find_map(|ty| match ty.as_adt() {
-        Some(Adt::Enum(e)) => Some(e),
-        _ => None,
+        Some(Adt::Enum(e)) => Some(ExtendedEnum::Enum(e)),
+        _ => {
+            if ty.is_bool() {
+                Some(ExtendedEnum::Bool)
+            } else {
+                None
+            }
+        }
     })
 }
 
 fn resolve_tuple_of_enum_def(
     sema: &Semantics<RootDatabase>,
     expr: &ast::Expr,
-) -> Option<Vec<hir::Enum>> {
+) -> Option<Vec<ExtendedEnum>> {
     sema.type_of_expr(&expr)?
         .tuple_fields(sema.db)
         .iter()
         .map(|ty| {
             ty.autoderef(sema.db).find_map(|ty| match ty.as_adt() {
-                Some(Adt::Enum(e)) => Some(e),
+                Some(Adt::Enum(e)) => Some(lift_enum(e)),
                 // For now we only handle expansion for a tuple of enums. Here
                 // we map non-enum items to None and rely on `collect` to
                 // convert Vec<Option<hir::Enum>> into Option<Vec<hir::Enum>>.
-                _ => None,
+                _ => {
+                    if ty.is_bool() {
+                        Some(ExtendedEnum::Bool)
+                    } else {
+                        None
+                    }
+                }
             })
         })
         .collect()
 }
 
-fn build_pat(db: &RootDatabase, module: hir::Module, var: hir::Variant) -> Option<ast::Pat> {
-    let path = mod_path_to_ast(&module.find_use_path(db, ModuleDef::from(var))?);
+fn build_pat(db: &RootDatabase, module: hir::Module, var: ExtendedVariant) -> Option<ast::Pat> {
+    match var {
+        ExtendedVariant::Variant(var) => {
+            let path = mod_path_to_ast(&module.find_use_path(db, ModuleDef::from(var))?);
+
+            // FIXME: use HIR for this; it doesn't currently expose struct vs. tuple vs. unit variants though
+            let pat: ast::Pat = match var.source(db)?.value.kind() {
+                ast::StructKind::Tuple(field_list) => {
+                    let pats =
+                        iter::repeat(make::wildcard_pat().into()).take(field_list.fields().count());
+                    make::tuple_struct_pat(path, pats).into()
+                }
+                ast::StructKind::Record(field_list) => {
+                    let pats =
+                        field_list.fields().map(|f| make::ident_pat(f.name().unwrap()).into());
+                    make::record_pat(path, pats).into()
+                }
+                ast::StructKind::Unit => make::path_pat(path),
+            };
 
-    // FIXME: use HIR for this; it doesn't currently expose struct vs. tuple vs. unit variants though
-    let pat: ast::Pat = match var.source(db)?.value.kind() {
-        ast::StructKind::Tuple(field_list) => {
-            let pats = iter::repeat(make::wildcard_pat().into()).take(field_list.fields().count());
-            make::tuple_struct_pat(path, pats).into()
-        }
-        ast::StructKind::Record(field_list) => {
-            let pats = field_list.fields().map(|f| make::ident_pat(f.name().unwrap()).into());
-            make::record_pat(path, pats).into()
+            Some(pat)
         }
-        ast::StructKind::Unit => make::path_pat(path),
-    };
-
-    Some(pat)
+        ExtendedVariant::True => Some(ast::Pat::from(make::literal_pat("true"))),
+        ExtendedVariant::False => Some(ast::Pat::from(make::literal_pat("false"))),
+    }
 }
 
 #[cfg(test)]
@@ -225,6 +279,21 @@ fn main() {
         );
     }
 
+    #[test]
+    fn all_boolean_match_arms_provided() {
+        check_assist_not_applicable(
+            fill_match_arms,
+            r#"
+            fn foo(a: bool) {
+                match a$0 {
+                    true => {}
+                    false => {}
+                }
+            }
+            "#,
+        )
+    }
+
     #[test]
     fn tuple_of_non_enum() {
         // for now this case is not handled, although it potentially could be
@@ -240,6 +309,113 @@ fn main() {
         );
     }
 
+    #[test]
+    fn fill_match_arms_boolean() {
+        check_assist(
+            fill_match_arms,
+            r#"
+            fn foo(a: bool) {
+                match a$0 {
+                }
+            }
+            "#,
+            r#"
+            fn foo(a: bool) {
+                match a {
+                    $0true => {}
+                    false => {}
+                }
+            }
+            "#,
+        )
+    }
+
+    #[test]
+    fn partial_fill_boolean() {
+        check_assist(
+            fill_match_arms,
+            r#"
+            fn foo(a: bool) {
+                match a$0 {
+                    true => {}
+                }
+            }
+            "#,
+            r#"
+            fn foo(a: bool) {
+                match a {
+                    true => {}
+                    $0false => {}
+                }
+            }
+            "#,
+        )
+    }
+
+    #[test]
+    fn all_boolean_tuple_arms_provided() {
+        check_assist_not_applicable(
+            fill_match_arms,
+            r#"
+            fn foo(a: bool) {
+                match (a, a)$0 {
+                    (true, true) => {}
+                    (true, false) => {}
+                    (false, true) => {}
+                    (false, false) => {}
+                }
+            }
+            "#,
+        )
+    }
+
+    #[test]
+    fn fill_boolean_tuple() {
+        check_assist(
+            fill_match_arms,
+            r#"
+            fn foo(a: bool) {
+                match (a, a)$0 {
+                }
+            }
+            "#,
+            r#"
+            fn foo(a: bool) {
+                match (a, a) {
+                    $0(true, true) => {}
+                    (true, false) => {}
+                    (false, true) => {}
+                    (false, false) => {}
+                }
+            }
+            "#,
+        )
+    }
+
+    #[test]
+    fn partial_fill_boolean_tuple() {
+        check_assist(
+            fill_match_arms,
+            r#"
+            fn foo(a: bool) {
+                match (a, a)$0 {
+                    (false, true) => {}
+                }
+            }
+            "#,
+            r#"
+            fn foo(a: bool) {
+                match (a, a) {
+                    (false, true) => {}
+                    $0(true, true) => {}
+                    (true, false) => {}
+                    (false, false) => {}
+                }
+            }
+            "#,
+        )
+    }
+
     #[test]
     fn partial_fill_record_tuple() {
         check_assist(
index 94d4f2cf0b986659f1cbfbea4212d87941fd38bf..4cf6f871e2c92a9cc3cd7001bea8654da9ef41b7 100644 (file)
@@ -294,6 +294,14 @@ fn from_text(text: &str) -> ast::WildcardPat {
     }
 }
 
+pub fn literal_pat(lit: &str) -> ast::LiteralPat {
+    return from_text(lit);
+
+    fn from_text(text: &str) -> ast::LiteralPat {
+        ast_from_text(&format!("fn f() {{ match x {{ {} => {{}} }} }}", text))
+    }
+}
+
 /// Creates a tuple of patterns from an iterator of patterns.
 ///
 /// Invariant: `pats` must be length > 0