]> git.lizzy.rs Git - rust.git/commitdiff
Infer correct expected type for generic struct fields
authorFlorian Diebold <flodiebold@gmail.com>
Sun, 23 May 2021 16:10:40 +0000 (18:10 +0200)
committerFlorian Diebold <flodiebold@gmail.com>
Sun, 23 May 2021 16:45:44 +0000 (18:45 +0200)
crates/hir/src/lib.rs
crates/hir_def/src/lib.rs
crates/ide_completion/src/context.rs
crates/ide_completion/src/render.rs

index a7c42ca1e5fe78a8cd600cc18fa74454f24cb756..edee99356b1b66c7322a73f369657c0d7dc43e89 100644 (file)
@@ -513,9 +513,9 @@ pub fn name(&self, db: &dyn HirDatabase) -> Name {
     }
 
     /// Returns the type as in the signature of the struct (i.e., with
-    /// placeholder types for type parameters). This is good for showing
-    /// signature help, but not so good to actually get the type of the field
-    /// when you actually have a variable of the struct.
+    /// placeholder types for type parameters). Only use this in the context of
+    /// the field *definition*; if you've already got a variable of the struct
+    /// type, use `Type::field_type` to get to the field type.
     pub fn ty(&self, db: &dyn HirDatabase) -> Type {
         let var_id = self.parent.into();
         let generic_def_id: GenericDefId = match self.parent {
@@ -1944,6 +1944,18 @@ fn go(ty: &Ty) -> bool {
         }
     }
 
+    pub fn field_type(&self, db: &dyn HirDatabase, field: Field) -> Option<Type> {
+        let (adt_id, substs) = self.ty.as_adt()?;
+        let variant_id: hir_def::VariantId = field.parent.into();
+        if variant_id.adt_id() != adt_id {
+            return None;
+        }
+
+        let ty = db.field_types(variant_id).get(field.id)?.clone();
+        let ty = ty.substitute(&Interner, substs);
+        Some(self.derived(ty))
+    }
+
     pub fn fields(&self, db: &dyn HirDatabase) -> Vec<(Field, Type)> {
         let (variant_id, substs) = match self.ty.kind(&Interner) {
             &TyKind::Adt(hir_ty::AdtId(AdtId::StructId(s)), ref substs) => (s.into(), substs),
index a82ea5957e9943e12d6ae4388b43f8b796784984..70001cac86668e57071d2ef66744cbc5fda18f40 100644 (file)
@@ -485,6 +485,14 @@ pub fn file_id(self, db: &dyn db::DefDatabase) -> HirFileId {
             VariantId::UnionId(it) => it.lookup(db).id.file_id(),
         }
     }
+
+    pub fn adt_id(self) -> AdtId {
+        match self {
+            VariantId::EnumVariantId(it) => it.parent.into(),
+            VariantId::StructId(it) => it.into(),
+            VariantId::UnionId(it) => it.into(),
+        }
+    }
 }
 
 trait Intern {
index c929d73949369bfab5d22fa7f254e418c37a4668..4a88a6e88b21f21302228a9e5ad75cbf8c7028f9 100644 (file)
@@ -337,25 +337,25 @@ fn expected_type_and_name(&self) -> (Option<Type>, Option<NameOrNameRef>) {
                     },
                     ast::RecordExprFieldList(_it) => {
                         cov_mark::hit!(expected_type_struct_field_without_leading_char);
-                        self.token.prev_sibling_or_token()
-                            .and_then(|se| se.into_node())
-                            .and_then(|node| ast::RecordExprField::cast(node))
-                            .and_then(|rf| self.sema.resolve_record_field(&rf).zip(Some(rf)))
-                            .map(|(f, rf)|(
-                                Some(f.0.ty(self.db)),
-                                rf.field_name().map(NameOrNameRef::NameRef),
+                        // wouldn't try {} be nice...
+                        (|| {
+                            let record_ty = self.sema.type_of_expr(&ast::Expr::cast(node.parent()?)?)?;
+                            let expr_field = self.token.prev_sibling_or_token()?
+                            .into_node()
+                                      .and_then(|node| ast::RecordExprField::cast(node))?;
+                            let field = self.sema.resolve_record_field(&expr_field)?.0;
+                            Some((
+                                record_ty.field_type(self.db, field),
+                                expr_field.field_name().map(NameOrNameRef::NameRef),
                             ))
-                            .unwrap_or((None, None))
+                        })().unwrap_or((None, None))
                     },
                     ast::RecordExprField(it) => {
                         cov_mark::hit!(expected_type_struct_field_with_leading_char);
-                        self.sema
-                            .resolve_record_field(&it)
-                            .map(|f|(
-                                Some(f.0.ty(self.db)),
-                                it.field_name().map(NameOrNameRef::NameRef),
-                            ))
-                            .unwrap_or((None, None))
+                        (
+                            it.expr().as_ref().and_then(|e| self.sema.type_of_expr(e)),
+                            it.field_name().map(NameOrNameRef::NameRef),
+                        )
                     },
                     ast::MatchExpr(it) => {
                         cov_mark::hit!(expected_type_match_arm_without_leading_char);
@@ -910,7 +910,7 @@ fn foo() -> u32 {
     }
 
     #[test]
-    fn expected_type_closure_param() {
+    fn expected_type_closure_param_return() {
         check_expected_type_and_name(
             r#"
 fn foo() {
index 6b04ee1648717d1176c82414e15c868696e389b7..d7f96b8645bfe5ff5e6e84bac32feab749bbea03 100644 (file)
@@ -667,6 +667,13 @@ fn foo() { A { the$0 } }
                         ),
                         detail: "u32",
                         deprecated: true,
+                        relevance: CompletionRelevance {
+                            exact_name_match: false,
+                            type_match: Some(
+                                CouldUnify,
+                            ),
+                            is_local: false,
+                        },
                     },
                 ]
             "#]],