]> git.lizzy.rs Git - rust.git/commitdiff
Record method call substs and use them in call info
authorFlorian Diebold <flodiebold@gmail.com>
Sun, 23 May 2021 14:59:23 +0000 (16:59 +0200)
committerFlorian Diebold <flodiebold@gmail.com>
Sun, 23 May 2021 16:24:21 +0000 (18:24 +0200)
crates/hir/src/semantics.rs
crates/hir/src/source_analyzer.rs
crates/hir_ty/src/diagnostics/expr.rs
crates/hir_ty/src/diagnostics/unsafe_check.rs
crates/hir_ty/src/infer.rs
crates/hir_ty/src/infer/expr.rs
crates/hir_ty/src/infer/unify.rs
crates/ide_completion/src/context.rs
crates/ide_db/src/call_info/tests.rs

index 1b5064b5a20d6203149ec7ca935db60c25596d6a..d65dd7df08de2ef28e86f0474a0692cd3ae1ad73 100644 (file)
@@ -11,7 +11,7 @@
     AsMacroCall, FunctionId, TraitId, VariantId,
 };
 use hir_expand::{name::AsName, ExpansionInfo};
-use hir_ty::associated_type_shorthand_candidates;
+use hir_ty::{associated_type_shorthand_candidates, Interner};
 use itertools::Itertools;
 use rustc_hash::{FxHashMap, FxHashSet};
 use syntax::{
@@ -501,14 +501,12 @@ fn type_of_self(&self, param: &ast::SelfParam) -> Option<Type> {
     }
 
     fn resolve_method_call(&self, call: &ast::MethodCallExpr) -> Option<FunctionId> {
-        self.analyze(call.syntax()).resolve_method_call(self.db, call)
+        self.analyze(call.syntax()).resolve_method_call(self.db, call).map(|(id, _)| id)
     }
 
     fn resolve_method_call_as_callable(&self, call: &ast::MethodCallExpr) -> Option<Callable> {
-        // FIXME: this erases Substs, we should instead record the correct
-        // substitution during inference and use that
-        let func = self.resolve_method_call(call)?;
-        let ty = hir_ty::TyBuilder::value_ty(self.db, func.into()).fill_with_unknown().build();
+        let (func, subst) = self.analyze(call.syntax()).resolve_method_call(self.db, call)?;
+        let ty = self.db.value_ty(func.into()).substitute(&Interner, &subst);
         let resolver = self.analyze(call.syntax()).resolver;
         let ty = Type::new_with_resolver(self.db, &resolver, ty)?;
         let mut res = ty.as_callable(self.db)?;
index b5c65808e60a9f9ba22597051e95f84eaf089b07..a1a9c727a6788b90d9cdea1674da468701829d94 100644 (file)
@@ -143,7 +143,7 @@ pub(crate) fn resolve_method_call(
         &self,
         db: &dyn HirDatabase,
         call: &ast::MethodCallExpr,
-    ) -> Option<FunctionId> {
+    ) -> Option<(FunctionId, Substitution)> {
         let expr_id = self.expr_id(db, &call.clone().into())?;
         self.infer.as_ref()?.method_resolution(expr_id)
     }
index 53c4ee9da9aac613260c2529f5661447b6ebaebc..d1f113e7ff9789f2dec08e4b5b1d306f8f3ddfd7 100644 (file)
@@ -181,7 +181,7 @@ fn check_for_filter_map_next(&mut self, db: &dyn HirDatabase) {
         for (id, expr) in body.exprs.iter() {
             if let Expr::MethodCall { receiver, .. } = expr {
                 let function_id = match self.infer.method_resolution(id) {
-                    Some(id) => id,
+                    Some((id, _)) => id,
                     None => continue,
                 };
 
@@ -239,15 +239,11 @@ fn validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr)
                     return;
                 }
 
-                // FIXME: note that we erase information about substs here. This
-                // is not right, but, luckily, doesn't matter as we care only
-                // about the number of params
-                let callee = match self.infer.method_resolution(call_id) {
-                    Some(callee) => callee,
+                let (callee, subst) = match self.infer.method_resolution(call_id) {
+                    Some(it) => it,
                     None => return,
                 };
-                let sig =
-                    db.callable_item_signature(callee.into()).into_value_and_skipped_binders().0;
+                let sig = db.callable_item_signature(callee.into()).substitute(&Interner, &subst);
 
                 (sig, args)
             }
index ed97dc0e3f37de367b00a3087cc5663f05625682..5d13bddea3598e3720018347568e4cb95d8f4dfe 100644 (file)
@@ -105,7 +105,7 @@ fn walk_unsafe(
         Expr::MethodCall { .. } => {
             if infer
                 .method_resolution(current)
-                .map(|func| db.function_data(func).is_unsafe())
+                .map(|(func, _)| db.function_data(func).is_unsafe())
                 .unwrap_or(false)
             {
                 unsafe_exprs.push(UnsafeExpr { expr: current, inside_unsafe_block });
index f1cebbdb983f70c21c20f7d26de02030ff839b68..db3c937ff5e780f576a6d59d119b7e73397cc82a 100644 (file)
@@ -37,8 +37,8 @@
 use super::{DomainGoal, InEnvironment, ProjectionTy, TraitEnvironment, TraitRef, Ty};
 use crate::{
     db::HirDatabase, fold_tys, infer::diagnostics::InferenceDiagnostic,
-    lower::ImplTraitLoweringMode, to_assoc_type_id, AliasEq, AliasTy, Goal, Interner, TyBuilder,
-    TyExt, TyKind,
+    lower::ImplTraitLoweringMode, to_assoc_type_id, AliasEq, AliasTy, Goal, Interner, Substitution,
+    TyBuilder, TyExt, TyKind,
 };
 
 // This lint has a false positive here. See the link below for details.
@@ -132,7 +132,7 @@ fn default() -> Self {
 #[derive(Clone, PartialEq, Eq, Debug, Default)]
 pub struct InferenceResult {
     /// For each method call expr, records the function it resolves to.
-    method_resolutions: FxHashMap<ExprId, FunctionId>,
+    method_resolutions: FxHashMap<ExprId, (FunctionId, Substitution)>,
     /// For each field access expr, records the field it resolves to.
     field_resolutions: FxHashMap<ExprId, FieldId>,
     /// For each struct literal or pattern, records the variant it resolves to.
@@ -152,8 +152,8 @@ pub struct InferenceResult {
 }
 
 impl InferenceResult {
-    pub fn method_resolution(&self, expr: ExprId) -> Option<FunctionId> {
-        self.method_resolutions.get(&expr).copied()
+    pub fn method_resolution(&self, expr: ExprId) -> Option<(FunctionId, Substitution)> {
+        self.method_resolutions.get(&expr).cloned()
     }
     pub fn field_resolution(&self, expr: ExprId) -> Option<FieldId> {
         self.field_resolutions.get(&expr).copied()
@@ -284,14 +284,17 @@ fn resolve_all(mut self) -> InferenceResult {
         self.table.propagate_diverging_flag();
         let mut result = std::mem::take(&mut self.result);
         for ty in result.type_of_expr.values_mut() {
-            *ty = self.table.resolve_ty_completely(ty.clone());
+            *ty = self.table.resolve_completely(ty.clone());
         }
         for ty in result.type_of_pat.values_mut() {
-            *ty = self.table.resolve_ty_completely(ty.clone());
+            *ty = self.table.resolve_completely(ty.clone());
         }
         for mismatch in result.type_mismatches.values_mut() {
-            mismatch.expected = self.table.resolve_ty_completely(mismatch.expected.clone());
-            mismatch.actual = self.table.resolve_ty_completely(mismatch.actual.clone());
+            mismatch.expected = self.table.resolve_completely(mismatch.expected.clone());
+            mismatch.actual = self.table.resolve_completely(mismatch.actual.clone());
+        }
+        for (_, subst) in result.method_resolutions.values_mut() {
+            *subst = self.table.resolve_completely(subst.clone());
         }
         result
     }
@@ -300,8 +303,8 @@ fn write_expr_ty(&mut self, expr: ExprId, ty: Ty) {
         self.result.type_of_expr.insert(expr, ty);
     }
 
-    fn write_method_resolution(&mut self, expr: ExprId, func: FunctionId) {
-        self.result.method_resolutions.insert(expr, func);
+    fn write_method_resolution(&mut self, expr: ExprId, func: FunctionId, subst: Substitution) {
+        self.result.method_resolutions.insert(expr, (func, subst));
     }
 
     fn write_field_resolution(&mut self, expr: ExprId, field: FieldId) {
index 08c05c67cc906261d022b6cbee3c1fcd5e9bf135..eab8fac910ed42505279b3d6b55ce7fcbff017ed 100644 (file)
@@ -891,17 +891,21 @@ fn infer_method_call(
                 method_name,
             )
         });
-        let (derefed_receiver_ty, method_ty, def_generics) = match resolved {
+        let (derefed_receiver_ty, method_ty, substs) = match resolved {
             Some((ty, func)) => {
                 let ty = canonicalized_receiver.decanonicalize_ty(ty);
-                self.write_method_resolution(tgt_expr, func);
-                (ty, self.db.value_ty(func.into()), Some(generics(self.db.upcast(), func.into())))
+                let generics = generics(self.db.upcast(), func.into());
+                let substs = self.substs_for_method_call(generics, generic_args, &ty);
+                self.write_method_resolution(tgt_expr, func, substs.clone());
+                (ty, self.db.value_ty(func.into()), substs)
             }
-            None => (receiver_ty, Binders::empty(&Interner, self.err_ty()), None),
+            None => (
+                receiver_ty,
+                Binders::empty(&Interner, self.err_ty()),
+                Substitution::empty(&Interner),
+            ),
         };
-        let substs = self.substs_for_method_call(def_generics, generic_args, &derefed_receiver_ty);
         let method_ty = method_ty.substitute(&Interner, &substs);
-        let method_ty = self.insert_type_vars(method_ty);
         self.register_obligations_for_call(&method_ty);
         let (expected_receiver_ty, param_tys, ret_ty) = match method_ty.callable_sig(self.db) {
             Some(sig) => {
@@ -950,23 +954,21 @@ fn check_call_arguments(&mut self, args: &[ExprId], param_tys: &[Ty]) {
 
     fn substs_for_method_call(
         &mut self,
-        def_generics: Option<Generics>,
+        def_generics: Generics,
         generic_args: Option<&GenericArgs>,
         receiver_ty: &Ty,
     ) -> Substitution {
         let (parent_params, self_params, type_params, impl_trait_params) =
-            def_generics.as_ref().map_or((0, 0, 0, 0), |g| g.provenance_split());
+            def_generics.provenance_split();
         assert_eq!(self_params, 0); // method shouldn't have another Self param
         let total_len = parent_params + type_params + impl_trait_params;
         let mut substs = Vec::with_capacity(total_len);
         // Parent arguments are unknown, except for the receiver type
-        if let Some(parent_generics) = def_generics.as_ref().map(|p| p.iter_parent()) {
-            for (_id, param) in parent_generics {
-                if param.provenance == hir_def::generics::TypeParamProvenance::TraitSelf {
-                    substs.push(receiver_ty.clone());
-                } else {
-                    substs.push(self.err_ty());
-                }
+        for (_id, param) in def_generics.iter_parent() {
+            if param.provenance == hir_def::generics::TypeParamProvenance::TraitSelf {
+                substs.push(receiver_ty.clone());
+            } else {
+                substs.push(self.table.new_type_var());
             }
         }
         // handle provided type arguments
@@ -989,7 +991,7 @@ fn substs_for_method_call(
         };
         let supplied_params = substs.len();
         for _ in supplied_params..total_len {
-            substs.push(self.err_ty());
+            substs.push(self.table.new_type_var());
         }
         assert_eq!(substs.len(), total_len);
         Substitution::from_iter(&Interner, substs)
index f8233cac393985af6f838ee35466e8befe52de2b..ea5684229f0840e68312ed7ed9bb20a01afc92e5 100644 (file)
@@ -295,8 +295,11 @@ fn resolve_with_fallback_inner<T>(
         .expect("fold failed unexpectedly")
     }
 
-    pub(crate) fn resolve_ty_completely(&mut self, ty: Ty) -> Ty {
-        self.resolve_with_fallback(ty, |_, _, d, _| d)
+    pub(crate) fn resolve_completely<T>(&mut self, t: T) -> T::Result
+    where
+        T: HasInterner<Interner = Interner> + Fold<Interner>,
+    {
+        self.resolve_with_fallback(t, |_, _, d, _| d)
     }
 
     /// Unify two types and register new trait goals that arise from that.
index 787eb2fd3f36cdad3ca17d324e713df0d2c72793..c929d73949369bfab5d22fa7f254e418c37a4668 100644 (file)
@@ -784,6 +784,19 @@ fn foo() {
         )
     }
 
+    #[test]
+    fn expected_type_generic_struct_field() {
+        check_expected_type_and_name(
+            r#"
+struct Foo<T> { a: T }
+fn foo() -> Foo<u32> {
+    Foo { a: $0 }
+}
+"#,
+            expect![[r#"ty: u32, name: a"#]],
+        )
+    }
+
     #[test]
     fn expected_type_struct_field_with_leading_char() {
         cov_mark::check!(expected_type_struct_field_with_leading_char);
@@ -895,4 +908,51 @@ fn foo() -> u32 {
             expect![[r#"ty: u32, name: ?"#]],
         )
     }
+
+    #[test]
+    fn expected_type_closure_param() {
+        check_expected_type_and_name(
+            r#"
+fn foo() {
+    bar(|| $0);
+}
+
+fn bar(f: impl FnOnce() -> u32) {}
+#[lang = "fn_once"]
+trait FnOnce { type Output; }
+"#,
+            expect![[r#"ty: u32, name: ?"#]],
+        );
+    }
+
+    #[test]
+    fn expected_type_generic_function() {
+        check_expected_type_and_name(
+            r#"
+fn foo() {
+    bar::<u32>($0);
+}
+
+fn bar<T>(t: T) {}
+"#,
+            expect![[r#"ty: u32, name: t"#]],
+        );
+    }
+
+    #[test]
+    fn expected_type_generic_method() {
+        check_expected_type_and_name(
+            r#"
+fn foo() {
+    S(1u32).bar($0);
+}
+
+struct S<T>(T);
+impl<T> S<T> {
+    fn bar(self, t: T) {}
+}
+"#,
+            expect![[r#"ty: u32, name: t"#]],
+        );
+    }
 }
index be1cc12de15bf51b7dde6cef97e08875ca384e2b..1aeda08e5f58c5806489b774cc0ab669ca9eb7d3 100644 (file)
@@ -188,6 +188,24 @@ fn foo(&self, x: i32)
     );
 }
 
+#[test]
+fn test_fn_signature_for_generic_method() {
+    check(
+        r#"
+struct S<T>(T);
+impl<T> S<T> {
+    fn foo(&self, x: T) {}
+}
+
+fn main() { S(1u32).foo($0); }
+"#,
+        expect![[r#"
+                fn foo(&self, x: u32)
+                (<x: u32>)
+            "#]],
+    );
+}
+
 #[test]
 fn test_fn_signature_for_method_with_arg_as_assoc_fn() {
     check(