]> git.lizzy.rs Git - rust.git/commitdiff
Add references to fn args during completion
authoradamrk <ark.email@gmail.com>
Sat, 22 Aug 2020 18:11:37 +0000 (20:11 +0200)
committeradamrk <ark.email@gmail.com>
Sun, 30 Aug 2020 10:34:32 +0000 (12:34 +0200)
crates/hir/src/code_model.rs
crates/hir_ty/src/db.rs
crates/hir_ty/src/infer.rs
crates/hir_ty/src/lib.rs
crates/ide/src/completion/presentation.rs

index c2fc819e764aa8669001a8ca9745a67e58af05db..f182ab228fe37d85596e620c8c86f7333bc8aa9f 100644 (file)
@@ -708,12 +708,24 @@ pub fn self_param(self, db: &dyn HirDatabase) -> Option<SelfParam> {
         Some(SelfParam { func: self.id })
     }
 
-    pub fn params(self, db: &dyn HirDatabase) -> Vec<Param> {
+    pub fn params(self, db: &dyn HirDatabase) -> Vec<Type> {
+        let resolver = self.id.resolver(db.upcast());
+        let ctx = hir_ty::TyLoweringContext::new(db, &resolver);
+        let environment = TraitEnvironment::lower(db, &resolver);
         db.function_data(self.id)
             .params
             .iter()
             .skip(if self.self_param(db).is_some() { 1 } else { 0 })
-            .map(|_| Param { _ty: () })
+            .map(|type_ref| {
+                let ty = Type {
+                    krate: self.id.lookup(db.upcast()).container.module(db.upcast()).krate,
+                    ty: InEnvironment {
+                        value: Ty::from_hir_ext(&ctx, type_ref).0,
+                        environment: environment.clone(),
+                    },
+                };
+                ty
+            })
             .collect()
     }
 
@@ -747,10 +759,6 @@ pub struct SelfParam {
     func: FunctionId,
 }
 
-pub struct Param {
-    _ty: (),
-}
-
 impl SelfParam {
     pub fn access(self, db: &dyn HirDatabase) -> Access {
         let func_data = db.function_data(self.func);
@@ -1100,6 +1108,12 @@ pub fn source(self, db: &dyn HirDatabase) -> InFile<Either<ast::IdentPat, ast::S
             ast.map_left(|it| it.cast().unwrap().to_node(&root)).map_right(|it| it.to_node(&root))
         })
     }
+
+    pub fn can_unify(self, other: Type, db: &dyn HirDatabase) -> bool {
+        let def = DefWithBodyId::from(self.parent);
+        let infer = db.infer(def);
+        db.can_unify(def, infer[self.pat_id].clone(), other.ty.value)
+    }
 }
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -1276,6 +1290,14 @@ pub fn is_mutable_reference(&self) -> bool {
         )
     }
 
+    pub fn remove_ref(&self) -> Option<Type> {
+        if let Ty::Apply(ApplicationTy { ctor: TypeCtor::Ref(_), .. }) = self.ty.value {
+            self.ty.value.substs().map(|substs| self.derived(substs[0].clone()))
+        } else {
+            None
+        }
+    }
+
     pub fn is_unknown(&self) -> bool {
         matches!(self.ty.value, Ty::Unknown)
     }
index 25cf9eb7f1279099a2e7de51ad08e8b7e5a7add0..57e60c53b3759ec7aa288c986eb711b46539d9ca 100644 (file)
@@ -26,6 +26,9 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
     #[salsa::invoke(crate::infer::infer_query)]
     fn infer_query(&self, def: DefWithBodyId) -> Arc<InferenceResult>;
 
+    #[salsa::invoke(crate::infer::can_unify)]
+    fn can_unify(&self, def: DefWithBodyId, ty1: Ty, ty2: Ty) -> bool;
+
     #[salsa::invoke(crate::lower::ty_query)]
     #[salsa::cycle(crate::lower::ty_recover)]
     fn ty(&self, def: TyDefId) -> Binders<Ty>;
index 03b00b101c2314788f795ad6f7ae4edd3b64d40a..d461e077b2032b976f778626202ff2a7ac9bd0b5 100644 (file)
@@ -55,7 +55,7 @@ macro_rules! ty_app {
     };
 }
 
-mod unify;
+pub mod unify;
 mod path;
 mod expr;
 mod pat;
@@ -78,6 +78,19 @@ pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<Infer
     Arc::new(ctx.resolve_all())
 }
 
+pub(crate) fn can_unify(db: &dyn HirDatabase, def: DefWithBodyId, ty1: Ty, ty2: Ty) -> bool {
+    let resolver = def.resolver(db.upcast());
+    let mut ctx = InferenceContext::new(db, def, resolver);
+
+    let ty1 = ctx.canonicalizer().canonicalize_ty(ty1).value;
+    let ty2 = ctx.canonicalizer().canonicalize_ty(ty2).value;
+    let mut kinds = Vec::from(ty1.kinds.to_vec());
+    kinds.extend_from_slice(ty2.kinds.as_ref());
+    let tys = crate::Canonical::new((ty1.value, ty2.value), kinds);
+
+    unify(&tys).is_some()
+}
+
 #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
 enum ExprOrPatId {
     ExprId(ExprId),
index 1e748476ac1946a6afefce3b91056136424889f3..681f98bde86be744bf0b86a36f558e8a3b4ae251 100644 (file)
@@ -43,7 +43,7 @@ macro_rules! eprintln {
 };
 
 pub use autoderef::autoderef;
-pub use infer::{InferTy, InferenceResult};
+pub use infer::{unify, InferTy, InferenceResult};
 pub use lower::CallableDefId;
 pub use lower::{
     associated_type_shorthand_candidates, callable_item_sig, ImplTraitLoweringMode, TyDefId,
index 24c507f9b23e6434afaf918fe0b17d70dc73eccf..cfcb6dfa190e1b8e37208d50fea0e84bd0d48efc 100644 (file)
@@ -191,6 +191,22 @@ pub(crate) fn add_function(
         func: hir::Function,
         local_name: Option<String>,
     ) {
+        fn add_arg(arg: &str, ty: &Type, ctx: &CompletionContext) -> String {
+            let mut prefix = "";
+            if let Some(derefed_ty) = ty.remove_ref() {
+                ctx.scope.process_all_names(&mut |name, scope| {
+                    if prefix != "" {
+                        return;
+                    }
+                    if let ScopeDef::Local(local) = scope {
+                        if name.to_string() == arg && local.can_unify(derefed_ty.clone(), ctx.db) {
+                            prefix = if ty.is_mutable_reference() { "&mut " } else { "&" };
+                        }
+                    }
+                });
+            }
+            prefix.to_string() + arg
+        };
         let name = local_name.unwrap_or_else(|| func.name(ctx.db).to_string());
         let ast_node = func.source(ctx.db).value;
 
@@ -205,12 +221,20 @@ pub(crate) fn add_function(
                 .set_deprecated(is_deprecated(func, ctx.db))
                 .detail(function_declaration(&ast_node));
 
+        let params_ty = func.params(ctx.db);
         let params = ast_node
             .param_list()
             .into_iter()
             .flat_map(|it| it.params())
-            .flat_map(|it| it.pat())
-            .map(|pat| pat.to_string().trim_start_matches('_').into())
+            .zip(params_ty)
+            .flat_map(|(it, param_ty)| {
+                if let Some(pat) = it.pat() {
+                    let name = pat.to_string();
+                    let arg = name.trim_start_matches('_');
+                    return Some(add_arg(arg, &param_ty, ctx));
+                }
+                None
+            })
             .collect();
 
         builder = builder.add_call_parens(ctx, name, Params::Named(params));
@@ -863,6 +887,85 @@ fn main() { foo(${1:foo}, ${2:bar}, ${3:ho_ge_})$0 }
         );
     }
 
+    #[test]
+    fn insert_ref_when_matching_local_in_scope() {
+        check_edit(
+            "ref_arg",
+            r#"
+struct Foo {}
+fn ref_arg(x: &Foo) {}
+fn main() {
+    let x = Foo {};
+    ref_ar<|>
+}
+"#,
+            r#"
+struct Foo {}
+fn ref_arg(x: &Foo) {}
+fn main() {
+    let x = Foo {};
+    ref_arg(${1:&x})$0
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn insert_mut_ref_when_matching_local_in_scope() {
+        check_edit(
+            "ref_arg",
+            r#"
+struct Foo {}
+fn ref_arg(x: &mut Foo) {}
+fn main() {
+    let x = Foo {};
+    ref_ar<|>
+}
+"#,
+            r#"
+struct Foo {}
+fn ref_arg(x: &mut Foo) {}
+fn main() {
+    let x = Foo {};
+    ref_arg(${1:&mut x})$0
+}
+"#,
+        );
+    }
+
+    #[test]
+    fn insert_ref_when_matching_local_in_scope_for_method() {
+        check_edit(
+            "apply_foo",
+            r#"
+struct Foo {}
+struct Bar {}
+impl Bar {
+    fn apply_foo(&self, x: &Foo) {}
+}
+
+fn main() {
+    let x = Foo {};
+    let y = Bar {};
+    y.<|>
+}
+"#,
+            r#"
+struct Foo {}
+struct Bar {}
+impl Bar {
+    fn apply_foo(&self, x: &Foo) {}
+}
+
+fn main() {
+    let x = Foo {};
+    let y = Bar {};
+    y.apply_foo(${1:&x})$0
+}
+"#,
+        );
+    }
+
     #[test]
     fn inserts_parens_for_tuple_enums() {
         mark::check!(inserts_parens_for_tuple_enums);