]> git.lizzy.rs Git - rust.git/commitdiff
Infer result of struct literals, and recurse into their child expressions
authorFlorian Diebold <flodiebold@gmail.com>
Mon, 24 Dec 2018 20:00:14 +0000 (21:00 +0100)
committerFlorian Diebold <flodiebold@gmail.com>
Tue, 25 Dec 2018 14:16:42 +0000 (15:16 +0100)
crates/ra_hir/src/adt.rs
crates/ra_hir/src/ty.rs
crates/ra_hir/src/ty/tests.rs
crates/ra_hir/src/ty/tests/data/0004_struct.txt
crates/ra_syntax/src/ast/generated.rs
crates/ra_syntax/src/grammar.ron

index a2d228593af7f132a52e88924fab6f5aab14c7e5..ee270ac459d7001443aa860af556804d6667314a 100644 (file)
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use ra_syntax::{SmolStr, ast::{self, NameOwner}};
 
 use crate::{
@@ -15,6 +17,14 @@ pub(crate) fn new(def_id: DefId) -> Self {
         Struct { def_id }
     }
 
+    pub fn def_id(&self) -> DefId {
+        self.def_id
+    }
+
+    pub fn struct_data(&self, db: &impl HirDatabase) -> Cancelable<Arc<StructData>> {
+        Ok(db.struct_data(self.def_id)?)
+    }
+
     pub fn name(&self, db: &impl HirDatabase) -> Cancelable<SmolStr> {
         Ok(db.struct_data(self.def_id)?.name.clone())
     }
@@ -23,7 +33,7 @@ pub fn name(&self, db: &impl HirDatabase) -> Cancelable<SmolStr> {
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct StructData {
     name: SmolStr,
-    variant_data: VariantData,
+    variant_data: Arc<VariantData>,
 }
 
 impl StructData {
@@ -33,8 +43,17 @@ pub(crate) fn new(struct_def: ast::StructDef) -> StructData {
             .map(|n| n.text())
             .unwrap_or(SmolStr::new("[error]"));
         let variant_data = VariantData::Unit; // TODO implement this
+        let variant_data = Arc::new(variant_data);
         StructData { name, variant_data }
     }
+
+    pub fn name(&self) -> &SmolStr {
+        &self.name
+    }
+
+    pub fn variant_data(&self) -> &Arc<VariantData> {
+        &self.variant_data
+    }
 }
 
 pub struct Enum {
@@ -46,6 +65,10 @@ pub(crate) fn new(def_id: DefId) -> Self {
         Enum { def_id }
     }
 
+    pub fn def_id(&self) -> DefId {
+        self.def_id
+    }
+
     pub fn name(&self, db: &impl HirDatabase) -> Cancelable<SmolStr> {
         Ok(db.enum_data(self.def_id)?.name.clone())
     }
@@ -54,7 +77,7 @@ pub fn name(&self, db: &impl HirDatabase) -> Cancelable<SmolStr> {
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub struct EnumData {
     name: SmolStr,
-    variants: Vec<(SmolStr, VariantData)>,
+    variants: Vec<(SmolStr, Arc<VariantData>)>,
 }
 
 impl EnumData {
index 429292cfc9460be8745f80fbcc8ef93c17b5abe4..386af8120abf074b63616622c0e88200d40e75f2 100644 (file)
@@ -16,9 +16,9 @@
 };
 
 use crate::{
-    Def, DefId, FnScopes, Module, Function,
-    Path, db::HirDatabase,
-    module::nameres::Namespace
+    Def, DefId, FnScopes, Module, Function, Struct, Path,
+    db::HirDatabase,
+    adt::VariantData,
 };
 
 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
@@ -125,6 +125,37 @@ pub struct FnSig {
 }
 
 impl Ty {
+    pub(crate) fn new_from_ast_path(
+        db: &impl HirDatabase,
+        module: &Module,
+        path: ast::Path,
+    ) -> Cancelable<Self> {
+        let path = if let Some(p) = Path::from_ast(path) {
+            p
+        } else {
+            return Ok(Ty::Unknown);
+        };
+        if path.is_ident() {
+            let name = &path.segments[0];
+            if let Some(int_ty) = primitive::IntTy::from_string(&name) {
+                return Ok(Ty::Int(int_ty));
+            } else if let Some(uint_ty) = primitive::UintTy::from_string(&name) {
+                return Ok(Ty::Uint(uint_ty));
+            } else if let Some(float_ty) = primitive::FloatTy::from_string(&name) {
+                return Ok(Ty::Float(float_ty));
+            }
+        }
+
+        // Resolve in module (in type namespace)
+        let resolved = if let Some(r) = module.resolve_path(db, path)?.take_types() {
+            r
+        } else {
+            return Ok(Ty::Unknown);
+        };
+        let ty = db.type_for_def(resolved)?;
+        Ok(ty)
+    }
+
     pub(crate) fn new(
         db: &impl HirDatabase,
         module: &Module,
@@ -136,31 +167,11 @@ pub(crate) fn new(
             TupleType(_inner) => Ty::Unknown, // TODO
             NeverType(..) => Ty::Never,
             PathType(inner) => {
-                let path = if let Some(p) = inner.path().and_then(Path::from_ast) {
-                    p
+                if let Some(path) = inner.path() {
+                    Ty::new_from_ast_path(db, module, path)?
                 } else {
-                    return Ok(Ty::Unknown);
-                };
-                if path.is_ident() {
-                    let name = &path.segments[0];
-                    if let Some(int_ty) = primitive::IntTy::from_string(&name) {
-                        return Ok(Ty::Int(int_ty));
-                    } else if let Some(uint_ty) = primitive::UintTy::from_string(&name) {
-                        return Ok(Ty::Uint(uint_ty));
-                    } else if let Some(float_ty) = primitive::FloatTy::from_string(&name) {
-                        return Ok(Ty::Float(float_ty));
-                    }
+                    Ty::Unknown
                 }
-
-                // Resolve in module (in type namespace)
-                let resolved =
-                    if let Some(r) = module.resolve_path(db, path)?.take(Namespace::Types) {
-                        r
-                    } else {
-                        return Ok(Ty::Unknown);
-                    };
-                let ty = db.type_for_def(resolved)?;
-                ty
             }
             PointerType(_inner) => Ty::Unknown,     // TODO
             ArrayType(_inner) => Ty::Unknown,       // TODO
@@ -236,6 +247,13 @@ pub fn type_for_fn(db: &impl HirDatabase, f: Function) -> Cancelable<Ty> {
     Ok(Ty::FnPtr(Arc::new(sig)))
 }
 
+pub fn type_for_struct(db: &impl HirDatabase, s: Struct) -> Cancelable<Ty> {
+    Ok(Ty::Adt {
+        def_id: s.def_id(),
+        name: s.name(db)?,
+    })
+}
+
 // TODO this should probably be per namespace (i.e. types vs. values), since for
 // a tuple struct `struct Foo(Bar)`, Foo has function type as a value, but
 // defines the struct type Foo when used in the type namespace. rustc has a
@@ -249,10 +267,7 @@ pub fn type_for_def(db: &impl HirDatabase, def_id: DefId) -> Cancelable<Ty> {
             Ok(Ty::Unknown)
         }
         Def::Function(f) => type_for_fn(db, f),
-        Def::Struct(s) => Ok(Ty::Adt {
-            def_id,
-            name: s.name(db)?,
-        }),
+        Def::Struct(s) => type_for_struct(db, s),
         Def::Enum(e) => Ok(Ty::Adt {
             def_id,
             name: e.name(db)?,
@@ -330,15 +345,36 @@ fn infer_path_expr(&mut self, expr: ast::PathExpr) -> Cancelable<Option<Ty>> {
         };
 
         // resolve in module
-        let resolved = ctry!(self
-            .module
-            .resolve_path(self.db, path)?
-            .take(Namespace::Values));
+        let resolved = ctry!(self.module.resolve_path(self.db, path)?.take_values());
         let ty = self.db.type_for_def(resolved)?;
         // TODO we will need to add type variables for type parameters etc. here
         Ok(Some(ty))
     }
 
+    fn resolve_variant(
+        &self,
+        path: Option<ast::Path>,
+    ) -> Cancelable<(Ty, Option<Arc<VariantData>>)> {
+        let path = if let Some(path) = path.and_then(Path::from_ast) {
+            path
+        } else {
+            return Ok((Ty::Unknown, None));
+        };
+        let def_id = if let Some(def_id) = self.module.resolve_path(self.db, path)?.take_types() {
+            def_id
+        } else {
+            return Ok((Ty::Unknown, None));
+        };
+        Ok(match def_id.resolve(self.db)? {
+            Def::Struct(s) => {
+                let struct_data = self.db.struct_data(def_id)?;
+                let ty = type_for_struct(self.db, s)?;
+                (ty, Some(struct_data.variant_data().clone()))
+            }
+            _ => (Ty::Unknown, None),
+        })
+    }
+
     fn infer_expr(&mut self, expr: ast::Expr) -> Cancelable<Ty> {
         let ty = match expr {
             ast::Expr::IfExpr(e) => {
@@ -488,7 +524,7 @@ fn infer_expr(&mut self, expr: ast::Expr) -> Cancelable<Ty> {
             ast::Expr::Label(_e) => Ty::Unknown,
             ast::Expr::ReturnExpr(e) => {
                 if let Some(e) = e.expr() {
-                    // TODO unify with return type
+                    // TODO unify with / expect return type
                     self.infer_expr(e)?;
                 };
                 Ty::Never
@@ -497,7 +533,18 @@ fn infer_expr(&mut self, expr: ast::Expr) -> Cancelable<Ty> {
                 // Can this even occur outside of a match expression?
                 Ty::Unknown
             }
-            ast::Expr::StructLit(_e) => Ty::Unknown,
+            ast::Expr::StructLit(e) => {
+                let (ty, variant_data) = self.resolve_variant(e.path())?;
+                if let Some(nfl) = e.named_field_list() {
+                    for field in nfl.fields() {
+                        if let Some(e) = field.expr() {
+                            // TODO unify with / expect field type
+                            self.infer_expr(e)?;
+                        }
+                    }
+                }
+                ty
+            }
             ast::Expr::NamedFieldList(_) | ast::Expr::NamedField(_) => {
                 // Can this even occur outside of a struct literal?
                 Ty::Unknown
index 170eef1471b4a9256a297d105235d85ea6d81d1c..9bb58ec850a5d077508deeb701d9efd4e6da6eab 100644 (file)
@@ -82,7 +82,7 @@ struct A {
 fn test() {
     let c = C(1);
     B;
-    let a: A = A { b: B, c: C() };
+    let a: A = A { b: B, c: C(1) };
     a.b;
     a.c;
 }
index a4371c5a531ad7fb56affb572a51520bab5ffcc8..41357749f27fa0aefa9582a673e3c7a1f484e716 100644 (file)
@@ -1,10 +1,14 @@
 [86; 90) 'C(1)': [unknown]
-[72; 153) '{     ...a.c; }': ()
+[121; 122) 'B': [unknown]
 [86; 87) 'C': [unknown]
+[129; 130) '1': [unknown]
 [107; 108) 'a': A
-[114; 132) 'A { b:... C() }': [unknown]
-[138; 141) 'a.b': [unknown]
-[147; 150) 'a.c': [unknown]
+[127; 128) 'C': [unknown]
+[139; 142) 'a.b': [unknown]
+[114; 133) 'A { b:...C(1) }': A
+[148; 151) 'a.c': [unknown]
+[72; 154) '{     ...a.c; }': ()
 [96; 97) 'B': [unknown]
 [88; 89) '1': [unknown]
 [82; 83) 'c': [unknown]
+[127; 131) 'C(1)': [unknown]
index c735338619aa09b3a784b2c1d42538e1c5c341b0..334da67ef038dcc34d3188be2b31eda648fe8914 100644 (file)
@@ -2142,7 +2142,15 @@ pub fn owned(&self) -> NamedFieldNode {
 }
 
 
-impl<'a> NamedField<'a> {}
+impl<'a> NamedField<'a> {
+    pub fn name_ref(self) -> Option<NameRef<'a>> {
+        super::child_opt(self)
+    }
+
+    pub fn expr(self) -> Option<Expr<'a>> {
+        super::child_opt(self)
+    }
+}
 
 // NamedFieldDef
 #[derive(Debug, Clone, Copy,)]
@@ -2218,7 +2226,11 @@ pub fn owned(&self) -> NamedFieldListNode {
 }
 
 
-impl<'a> NamedFieldList<'a> {}
+impl<'a> NamedFieldList<'a> {
+    pub fn fields(self) -> impl Iterator<Item = NamedField<'a>> + 'a {
+        super::children(self)
+    }
+}
 
 // NeverType
 #[derive(Debug, Clone, Copy,)]
@@ -3467,7 +3479,15 @@ pub fn owned(&self) -> StructLitNode {
 }
 
 
-impl<'a> StructLit<'a> {}
+impl<'a> StructLit<'a> {
+    pub fn path(self) -> Option<Path<'a>> {
+        super::child_opt(self)
+    }
+
+    pub fn named_field_list(self) -> Option<NamedFieldList<'a>> {
+        super::child_opt(self)
+    }
+}
 
 // StructPat
 #[derive(Debug, Clone, Copy,)]
index e3b9032a0c6eb45ddf49168efcfc771865fcfd48..0da8b8183d98f514046dc508c16dd440974e5b9c 100644 (file)
@@ -392,9 +392,9 @@ Grammar(
             collections: [ [ "pats", "Pat" ] ]
         ),
         "MatchGuard": (),
-        "StructLit": (),
-        "NamedFieldList": (),
-        "NamedField": (),
+        "StructLit": (options: ["Path", "NamedFieldList"]),
+        "NamedFieldList": (collections: [ ["fields", "NamedField"] ]),
+        "NamedField": (options: ["NameRef", "Expr"]),
         "CallExpr": (
             traits: ["ArgListOwner"],
             options: [ "Expr" ],