]> git.lizzy.rs Git - rust.git/commitdiff
Implement type inference for enum variants
authorMarcus Klaas de Vries <mail@marcusklaas.nl>
Tue, 8 Jan 2019 15:01:19 +0000 (16:01 +0100)
committerMarcus Klaas de Vries <mail@marcusklaas.nl>
Thu, 10 Jan 2019 13:32:56 +0000 (14:32 +0100)
12 files changed:
crates/ra_hir/src/adt.rs
crates/ra_hir/src/code_model_api.rs
crates/ra_hir/src/code_model_impl/module.rs
crates/ra_hir/src/db.rs
crates/ra_hir/src/ids.rs
crates/ra_hir/src/lib.rs
crates/ra_hir/src/mock.rs
crates/ra_hir/src/ty.rs
crates/ra_hir/src/ty/tests.rs
crates/ra_hir/src/ty/tests/data/enum.txt [new file with mode: 0644]
crates/ra_ide_api/src/completion/complete_path.rs
crates/ra_ide_api/src/db.rs

index d30390f25bfc01dcbfdf64de25f76770ba4b8fb2..f1b98cdd76ef0fce202981dd83d8f31e75b5f97d 100644 (file)
@@ -1,10 +1,19 @@
 use std::sync::Arc;
 
 use ra_db::Cancelable;
-use ra_syntax::ast::{self, NameOwner, StructFlavor, AstNode};
+use ra_syntax::{
+    SyntaxNode,
+    ast::{self, NameOwner, StructFlavor, AstNode}
+};
 
 use crate::{
+<<<<<<< HEAD
     DefId, Name, AsName, Struct, Enum, HirDatabase, DefKind,
+=======
+    DefId, DefLoc, Name, AsName, Struct, Enum, EnumVariant,
+    VariantData, StructField, HirDatabase, DefKind,
+    SourceItemId,
+>>>>>>> 95ac72a3... Implement type inference for enum variants
     type_ref::TypeRef,
 };
 
@@ -45,33 +54,39 @@ pub(crate) fn struct_data_query(
     }
 }
 
-impl Enum {
-    pub(crate) fn new(def_id: DefId) -> Self {
-        Enum { def_id }
-    }
+fn get_def_id(
+    db: &impl HirDatabase,
+    same_file_loc: &DefLoc,
+    node: &SyntaxNode,
+    expected_kind: DefKind,
+) -> DefId {
+    let file_id = same_file_loc.source_item_id.file_id;
+    let file_items = db.file_items(file_id);
+
+    let item_id = file_items.id_of(file_id, node);
+    let source_item_id = SourceItemId {
+        item_id: Some(item_id),
+        ..same_file_loc.source_item_id
+    };
+    let loc = DefLoc {
+        kind: expected_kind,
+        source_item_id: source_item_id,
+        ..*same_file_loc
+    };
+    loc.id(db)
 }
 
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct EnumData {
     pub(crate) name: Option<Name>,
-    pub(crate) variants: Vec<(Name, Arc<VariantData>)>,
+    // TODO: keep track of names also since we already have them?
+    // then we won't need additional db lookups
+    pub(crate) variants: Option<Vec<EnumVariant>>,
 }
 
 impl EnumData {
-    fn new(enum_def: &ast::EnumDef) -> Self {
+    fn new(enum_def: &ast::EnumDef, variants: Option<Vec<EnumVariant>>) -> Self {
         let name = enum_def.name().map(|n| n.as_name());
-        let variants = if let Some(evl) = enum_def.variant_list() {
-            evl.variants()
-                .map(|v| {
-                    (
-                        v.name().map(|n| n.as_name()).unwrap_or_else(Name::missing),
-                        Arc::new(VariantData::new(v.flavor())),
-                    )
-                })
-                .collect()
-        } else {
-            Vec::new()
-        };
         EnumData { name, variants }
     }
 
@@ -83,7 +98,57 @@ pub(crate) fn enum_data_query(
         assert!(def_loc.kind == DefKind::Enum);
         let syntax = db.file_item(def_loc.source_item_id);
         let enum_def = ast::EnumDef::cast(&syntax).expect("enum def should point to EnumDef node");
-        Ok(Arc::new(EnumData::new(enum_def)))
+        let variants = enum_def.variant_list().map(|vl| {
+            vl.variants()
+                .map(|ev| {
+                    let def_id = get_def_id(db, &def_loc, ev.syntax(), DefKind::EnumVariant);
+                    EnumVariant::new(def_id)
+                })
+                .collect()
+        });
+        Ok(Arc::new(EnumData::new(enum_def, variants)))
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct EnumVariantData {
+    pub(crate) name: Option<Name>,
+    pub(crate) variant_data: Arc<VariantData>,
+    pub(crate) parent_enum: Enum,
+}
+
+impl EnumVariantData {
+    fn new(variant_def: &ast::EnumVariant, parent_enum: Enum) -> EnumVariantData {
+        let name = variant_def.name().map(|n| n.as_name());
+        let variant_data = VariantData::new(variant_def.flavor());
+        let variant_data = Arc::new(variant_data);
+        EnumVariantData {
+            name,
+            variant_data,
+            parent_enum,
+        }
+    }
+
+    pub(crate) fn enum_variant_data_query(
+        db: &impl HirDatabase,
+        def_id: DefId,
+    ) -> Cancelable<Arc<EnumVariantData>> {
+        let def_loc = def_id.loc(db);
+        assert!(def_loc.kind == DefKind::EnumVariant);
+        let syntax = db.file_item(def_loc.source_item_id);
+        let variant_def = ast::EnumVariant::cast(&syntax)
+            .expect("enum variant def should point to EnumVariant node");
+        let enum_node = syntax
+            .parent()
+            .expect("enum variant should have enum variant list ancestor")
+            .parent()
+            .expect("enum variant list should have enum ancestor");
+        let enum_def_id = get_def_id(db, &def_loc, enum_node, DefKind::Enum);
+
+        Ok(Arc::new(EnumVariantData::new(
+            variant_def,
+            Enum::new(enum_def_id),
+        )))
     }
 }
 
index fa3e4baa7d2ba23c0dbe3beec5420c379a480434..c7d1bf0a625802f3631f270968e9fbeb5339347d 100644 (file)
@@ -44,6 +44,7 @@ pub enum Def {
     Module(Module),
     Struct(Struct),
     Enum(Enum),
+    EnumVariant(EnumVariant),
     Function(Function),
     Item,
 }
@@ -188,6 +189,10 @@ pub struct Enum {
 }
 
 impl Enum {
+    pub(crate) fn new(def_id: DefId) -> Self {
+        Enum { def_id }
+    }
+
     pub fn def_id(&self) -> DefId {
         self.def_id
     }
@@ -196,11 +201,38 @@ pub fn name(&self, db: &impl HirDatabase) -> Cancelable<Option<Name>> {
         Ok(db.enum_data(self.def_id)?.name.clone())
     }
 
-    pub fn variants(&self, db: &impl HirDatabase) -> Cancelable<Vec<(Name, Arc<VariantData>)>> {
+    pub fn variants(&self, db: &impl HirDatabase) -> Cancelable<Option<Vec<EnumVariant>>> {
         Ok(db.enum_data(self.def_id)?.variants.clone())
     }
 }
 
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct EnumVariant {
+    pub(crate) def_id: DefId,
+}
+
+impl EnumVariant {
+    pub(crate) fn new(def_id: DefId) -> Self {
+        EnumVariant { def_id }
+    }
+
+    pub fn def_id(&self) -> DefId {
+        self.def_id
+    }
+
+    pub fn parent_enum(&self, db: &impl HirDatabase) -> Cancelable<Enum> {
+        Ok(db.enum_variant_data(self.def_id)?.parent_enum.clone())
+    }
+
+    pub fn name(&self, db: &impl HirDatabase) -> Cancelable<Option<Name>> {
+        Ok(db.enum_variant_data(self.def_id)?.name.clone())
+    }
+
+    pub fn variant_data(&self, db: &impl HirDatabase) -> Cancelable<Arc<VariantData>> {
+        Ok(db.enum_variant_data(self.def_id)?.variant_data.clone())
+    }
+}
+
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub struct Function {
     pub(crate) def_id: DefId,
index 1cb408cff6c947e22c2025bd0e467506665c9196..d7d62e8634b5b9b26a21a8c54abe5a9fb60d0a6e 100644 (file)
@@ -13,6 +13,7 @@ impl Module {
     pub(crate) fn new(def_id: DefId) -> Self {
         crate::code_model_api::Module { def_id }
     }
+
     pub(crate) fn from_module_id(
         db: &impl HirDatabase,
         source_root_id: SourceRootId,
@@ -85,6 +86,7 @@ pub(crate) fn crate_root_impl(&self, db: &impl HirDatabase) -> Cancelable<Module
         let module_id = loc.module_id.crate_root(&module_tree);
         Module::from_module_id(db, loc.source_root_id, module_id)
     }
+
     /// Finds a child module with the specified name.
     pub fn child_impl(&self, db: &impl HirDatabase, name: &Name) -> Cancelable<Option<Module>> {
         let loc = self.def_id.loc(db);
@@ -92,12 +94,14 @@ pub fn child_impl(&self, db: &impl HirDatabase, name: &Name) -> Cancelable<Optio
         let child_id = ctry!(loc.module_id.child(&module_tree, name));
         Module::from_module_id(db, loc.source_root_id, child_id).map(Some)
     }
+
     pub fn parent_impl(&self, db: &impl HirDatabase) -> Cancelable<Option<Module>> {
         let loc = self.def_id.loc(db);
         let module_tree = db.module_tree(loc.source_root_id)?;
         let parent_id = ctry!(loc.module_id.parent(&module_tree));
         Module::from_module_id(db, loc.source_root_id, parent_id).map(Some)
     }
+
     /// Returns a `ModuleScope`: a set of items, visible in this module.
     pub fn scope_impl(&self, db: &impl HirDatabase) -> Cancelable<ModuleScope> {
         let loc = self.def_id.loc(db);
@@ -105,6 +109,7 @@ pub fn scope_impl(&self, db: &impl HirDatabase) -> Cancelable<ModuleScope> {
         let res = item_map.per_module[&loc.module_id].clone();
         Ok(res)
     }
+
     pub fn resolve_path_impl(
         &self,
         db: &impl HirDatabase,
@@ -126,7 +131,7 @@ pub fn resolve_path_impl(
         );
 
         let segments = &path.segments;
-        for name in segments.iter() {
+        for (idx, name) in segments.iter().enumerate() {
             let curr = if let Some(r) = curr_per_ns.as_ref().take_types() {
                 r
             } else {
@@ -134,7 +139,35 @@ pub fn resolve_path_impl(
             };
             let module = match curr.resolve(db)? {
                 Def::Module(it) => it,
-                // TODO here would be the place to handle enum variants...
+                Def::Enum(e) => {
+                    if segments.len() == idx + 1 {
+                        // enum variant
+                        let matching_variant = e.variants(db)?.map(|variants| {
+                            variants
+                                .into_iter()
+                                // FIXME: replace by match lol
+                                .find(|variant| {
+                                    variant
+                                        .name(db)
+                                        .map(|o| o.map(|ref n| n == name))
+                                        .unwrap_or(Some(false))
+                                        .unwrap_or(false)
+                                })
+                        });
+
+                        if let Some(Some(variant)) = matching_variant {
+                            return Ok(PerNs::both(variant.def_id(), e.def_id()));
+                        } else {
+                            return Ok(PerNs::none());
+                        }
+                    } else if segments.len() == idx {
+                        // enum
+                        return Ok(PerNs::types(e.def_id()));
+                    } else {
+                        // malformed enum?
+                        return Ok(PerNs::none());
+                    }
+                }
                 _ => return Ok(PerNs::none()),
             };
             let scope = module.scope(db)?;
@@ -146,6 +179,7 @@ pub fn resolve_path_impl(
         }
         Ok(curr_per_ns)
     }
+
     pub fn problems_impl(
         &self,
         db: &impl HirDatabase,
index 7dbe93f2bd34eb9d9c34153668ec7473f408e83a..9a6ef8083c87a5689e968f9dd78d560896c3142d 100644 (file)
@@ -12,7 +12,7 @@
     module_tree::{ModuleId, ModuleTree},
     nameres::{ItemMap, InputModuleItems},
     ty::{InferenceResult, Ty},
-    adt::{StructData, EnumData},
+    adt::{StructData, EnumData, EnumVariantData},
     impl_block::ModuleImplBlocks,
 };
 
@@ -47,6 +47,11 @@ fn enum_data(def_id: DefId) -> Cancelable<Arc<EnumData>> {
         use fn crate::adt::EnumData::enum_data_query;
     }
 
+    fn enum_variant_data(def_id: DefId) -> Cancelable<Arc<EnumVariantData>> {
+        type EnumVariantDataQuery;
+        use fn crate::adt::EnumVariantData::enum_variant_data_query;
+    }
+
     fn infer(def_id: DefId) -> Cancelable<Arc<InferenceResult>> {
         type InferQuery;
         use fn crate::ty::infer;
index 0aa687a08ed52ab6e48564177a750f9ebbb9cc28..db0107e53bc0a6136fc5e37b252e6ec07f43c92b 100644 (file)
@@ -3,7 +3,7 @@
 use ra_arena::{Arena, RawId, impl_arena_id};
 
 use crate::{
-    HirDatabase, PerNs, Def, Function, Struct, Enum, ImplBlock, Crate,
+    HirDatabase, PerNs, Def, Function, Struct, Enum, EnumVariant, ImplBlock, Crate,
     module_tree::ModuleId,
 };
 
@@ -145,6 +145,7 @@ pub(crate) enum DefKind {
     Function,
     Struct,
     Enum,
+    EnumVariant,
     Item,
 
     StructCtor,
@@ -170,10 +171,8 @@ pub fn resolve(self, db: &impl HirDatabase) -> Cancelable<Def> {
                 let struct_def = Struct::new(self);
                 Def::Struct(struct_def)
             }
-            DefKind::Enum => {
-                let enum_def = Enum::new(self);
-                Def::Enum(enum_def)
-            }
+            DefKind::Enum => Def::Enum(Enum::new(self)),
+            DefKind::EnumVariant => Def::EnumVariant(EnumVariant::new(self)),
             DefKind::StructCtor => Def::Item,
             DefKind::Item => Def::Item,
         };
@@ -258,7 +257,9 @@ fn init(&mut self, source_file: &SourceFile) {
         // change parent's id. This means that, say, adding a new function to a
         // trait does not chage ids of top-level items, which helps caching.
         bfs(source_file.syntax(), |it| {
-            if let Some(module_item) = ast::ModuleItem::cast(it) {
+            if let Some(enum_variant) = ast::EnumVariant::cast(it) {
+                self.alloc(enum_variant.syntax().to_owned());
+            } else if let Some(module_item) = ast::ModuleItem::cast(it) {
                 self.alloc(module_item.syntax().to_owned());
             } else if let Some(macro_call) = ast::MacroCall::cast(it) {
                 self.alloc(macro_call.syntax().to_owned());
index 1b6b72c98e98a75404d712f3f5feaee832c5f3d9..74957ffc90353b46f2d1cbcde2a05fda6fceb3b9 100644 (file)
@@ -56,6 +56,6 @@ macro_rules! ctry {
     Crate, CrateDependency,
     Def,
     Module, ModuleSource, Problem,
-    Struct, Enum,
+    Struct, Enum, EnumVariant,
     Function, FnSignature,
 };
index 7a0301648a68ba115d465c4bb6a0f3b355083dbd..6f93bb59de92b38084b212ea4cfcfb65f8e9d221 100644 (file)
@@ -233,6 +233,7 @@ impl db::HirDatabase {
             fn type_for_field() for db::TypeForFieldQuery;
             fn struct_data() for db::StructDataQuery;
             fn enum_data() for db::EnumDataQuery;
+            fn enum_variant_data() for db::EnumVariantDataQuery;
             fn impls_in_module() for db::ImplsInModuleQuery;
             fn body_hir() for db::BodyHirQuery;
             fn body_syntax_mapping() for db::BodySyntaxMappingQuery;
index eb7764f652c5dc0ab060dc95c0aff9547bcbc77d..18c41a0155772acd83007906038edd103f61f292 100644 (file)
@@ -30,7 +30,7 @@
 use ra_db::Cancelable;
 
 use crate::{
-    Def, DefId, Module, Function, Struct, Enum, Path, Name, ImplBlock,
+    Def, DefId, Module, Function, Struct, Enum, EnumVariant, Path, Name, ImplBlock,
     FnSignature, FnScopes,
     db::HirDatabase,
     type_ref::{TypeRef, Mutability},
@@ -453,6 +453,12 @@ pub fn type_for_enum(db: &impl HirDatabase, s: Enum) -> Cancelable<Ty> {
     })
 }
 
+pub fn type_for_enum_variant(db: &impl HirDatabase, ev: EnumVariant) -> Cancelable<Ty> {
+    let enum_parent = ev.parent_enum(db)?;
+
+    type_for_enum(db, enum_parent)
+}
+
 pub(super) fn type_for_def(db: &impl HirDatabase, def_id: DefId) -> Cancelable<Ty> {
     let def = def_id.resolve(db)?;
     match def {
@@ -463,6 +469,7 @@ pub(super) fn type_for_def(db: &impl HirDatabase, def_id: DefId) -> Cancelable<T
         Def::Function(f) => type_for_fn(db, f),
         Def::Struct(s) => type_for_struct(db, s),
         Def::Enum(e) => type_for_enum(db, e),
+        Def::EnumVariant(ev) => type_for_enum_variant(db, ev),
         Def::Item => {
             log::debug!("trying to get type for item of unknown type {:?}", def_id);
             Ok(Ty::Unknown)
@@ -477,12 +484,9 @@ pub(super) fn type_for_field(
 ) -> Cancelable<Option<Ty>> {
     let def = def_id.resolve(db)?;
     let variant_data = match def {
-        Def::Struct(s) => {
-            let variant_data = s.variant_data(db)?;
-            variant_data
-        }
+        Def::Struct(s) => s.variant_data(db)?,
+        Def::EnumVariant(ev) => ev.variant_data(db)?,
         // TODO: unions
-        // TODO: enum variants
         _ => panic!(
             "trying to get type for field in non-struct/variant {:?}",
             def_id
@@ -788,6 +792,10 @@ fn resolve_variant(&self, path: Option<&Path>) -> Cancelable<(Ty, Option<DefId>)
                 let ty = type_for_struct(self.db, s)?;
                 (ty, Some(def_id))
             }
+            Def::EnumVariant(ev) => {
+                let ty = type_for_enum_variant(self.db, ev)?;
+                (ty, Some(def_id))
+            }
             _ => (Ty::Unknown, None),
         })
     }
index ba2a444743c54656b416cb1ad9bc21dc3595f2ed..d8c0af32682d71fc0bbde62f01dca67bc0eb888b 100644 (file)
@@ -94,6 +94,22 @@ fn test() {
     );
 }
 
+#[test]
+fn infer_enum() {
+    check_inference(
+        r#"
+enum E {
+  V1 { field: u32 },
+  V2
+}
+fn test() {
+  E::V1 { field: 1 };
+  E::V2;
+}"#,
+        "enum.txt",
+    );
+}
+
 #[test]
 fn infer_refs() {
     check_inference(
diff --git a/crates/ra_hir/src/ty/tests/data/enum.txt b/crates/ra_hir/src/ty/tests/data/enum.txt
new file mode 100644 (file)
index 0000000..481eb0b
--- /dev/null
@@ -0,0 +1,4 @@
+[48; 82) '{   E:...:V2; }': ()
+[52; 70) 'E::V1 ...d: 1 }': E
+[67; 68) '1': u32
+[74; 79) 'E::V2': E
index 4723a65a6b0645f197233ab397b7001fce944fb7..6a55670d192cb271fbe6e72cc14651cf7110942a 100644 (file)
@@ -21,14 +21,20 @@ pub(super) fn complete_path(acc: &mut Completions, ctx: &CompletionContext) -> C
                     .add_to(acc)
             });
         }
-        hir::Def::Enum(e) => e
-            .variants(ctx.db)?
-            .into_iter()
-            .for_each(|(name, _variant)| {
-                CompletionItem::new(CompletionKind::Reference, name.to_string())
-                    .kind(CompletionItemKind::EnumVariant)
-                    .add_to(acc)
-            }),
+        hir::Def::Enum(e) => {
+            e.variants(ctx.db)?
+                .unwrap_or(vec![])
+                .into_iter()
+                .for_each(|variant| {
+                    let variant_name = variant.name(ctx.db);
+
+                    if let Ok(Some(name)) = variant_name {
+                        CompletionItem::new(CompletionKind::Reference, name.to_string())
+                            .kind(CompletionItemKind::EnumVariant)
+                            .add_to(acc)
+                    }
+                })
+        }
         _ => return Ok(()),
     };
     Ok(())
index a2e06f5db3e94644393657538a227b6ee72a683f..efdf261bef8c8d1a55b9c2c40d5d2615b9a4dc6e 100644 (file)
@@ -122,6 +122,7 @@ impl hir::db::HirDatabase {
             fn type_for_field() for hir::db::TypeForFieldQuery;
             fn struct_data() for hir::db::StructDataQuery;
             fn enum_data() for hir::db::EnumDataQuery;
+            fn enum_variant_data() for hir::db::EnumVariantDataQuery;
             fn impls_in_module() for hir::db::ImplsInModuleQuery;
             fn body_hir() for hir::db::BodyHirQuery;
             fn body_syntax_mapping() for hir::db::BodySyntaxMappingQuery;