]> git.lizzy.rs Git - rust.git/commitdiff
Feat: inline generics in const and func trait completions
authorrdambrosio <rdambrosio016@gmail.com>
Thu, 17 Jun 2021 23:54:28 +0000 (19:54 -0400)
committerrdambrosio <rdambrosio016@gmail.com>
Thu, 17 Jun 2021 23:54:28 +0000 (19:54 -0400)
crates/ide_assists/src/lib.rs
crates/ide_assists/src/path_transform.rs [deleted file]
crates/ide_assists/src/utils.rs
crates/ide_completion/src/completions/trait_impl.rs
crates/ide_db/src/lib.rs
crates/ide_db/src/path_transform.rs [new file with mode: 0644]

index fa378a622dc2e6f5e9fab0fc0d58dcb9890bce66..86a57ce5dcaa9d8bddc6f9405bf001e0cfae8479 100644 (file)
@@ -15,7 +15,6 @@ macro_rules! eprintln {
 #[cfg(test)]
 mod tests;
 pub mod utils;
-pub mod path_transform;
 
 use hir::Semantics;
 use ide_db::{base_db::FileRange, RootDatabase};
diff --git a/crates/ide_assists/src/path_transform.rs b/crates/ide_assists/src/path_transform.rs
deleted file mode 100644 (file)
index 48a7fa0..0000000
+++ /dev/null
@@ -1,160 +0,0 @@
-//! See [`PathTransform`].
-
-use hir::{HirDisplay, SemanticsScope};
-use ide_db::helpers::mod_path_to_ast;
-use rustc_hash::FxHashMap;
-use syntax::{
-    ast::{self, AstNode},
-    ted,
-};
-
-/// `PathTransform` substitutes path in SyntaxNodes in bulk.
-///
-/// This is mostly useful for IDE code generation. If you paste some existing
-/// code into a new context (for example, to add method overrides to an `impl`
-/// block), you generally want to appropriately qualify the names, and sometimes
-/// you might want to substitute generic parameters as well:
-///
-/// ```
-/// mod x {
-///   pub struct A<V>;
-///   pub trait T<U> { fn foo(&self, _: U) -> A<U>; }
-/// }
-///
-/// mod y {
-///   use x::T;
-///
-///   impl T<()> for () {
-///      // If we invoke **Add Missing Members** here, we want to copy-paste `foo`.
-///      // But we want a slightly-modified version of it:
-///      fn foo(&self, _: ()) -> x::A<()> {}
-///   }
-/// }
-/// ```
-pub(crate) struct PathTransform<'a> {
-    pub(crate) subst: (hir::Trait, ast::Impl),
-    pub(crate) target_scope: &'a SemanticsScope<'a>,
-    pub(crate) source_scope: &'a SemanticsScope<'a>,
-}
-
-impl<'a> PathTransform<'a> {
-    pub(crate) fn apply(&self, item: ast::AssocItem) {
-        if let Some(ctx) = self.build_ctx() {
-            ctx.apply(item)
-        }
-    }
-    fn build_ctx(&self) -> Option<Ctx<'a>> {
-        let db = self.source_scope.db;
-        let target_module = self.target_scope.module()?;
-        let source_module = self.source_scope.module()?;
-
-        let substs = get_syntactic_substs(self.subst.1.clone()).unwrap_or_default();
-        let generic_def: hir::GenericDef = self.subst.0.into();
-        let substs_by_param: FxHashMap<_, _> = generic_def
-            .type_params(db)
-            .into_iter()
-            // this is a trait impl, so we need to skip the first type parameter -- this is a bit hacky
-            .skip(1)
-            // The actual list of trait type parameters may be longer than the one
-            // used in the `impl` block due to trailing default type parameters.
-            // For that case we extend the `substs` with an empty iterator so we
-            // can still hit those trailing values and check if they actually have
-            // a default type. If they do, go for that type from `hir` to `ast` so
-            // the resulting change can be applied correctly.
-            .zip(substs.into_iter().map(Some).chain(std::iter::repeat(None)))
-            .filter_map(|(k, v)| match v {
-                Some(v) => Some((k, v)),
-                None => {
-                    let default = k.default(db)?;
-                    Some((
-                        k,
-                        ast::make::ty(&default.display_source_code(db, source_module.into()).ok()?),
-                    ))
-                }
-            })
-            .collect();
-
-        let res = Ctx { substs: substs_by_param, target_module, source_scope: self.source_scope };
-        Some(res)
-    }
-}
-
-struct Ctx<'a> {
-    substs: FxHashMap<hir::TypeParam, ast::Type>,
-    target_module: hir::Module,
-    source_scope: &'a SemanticsScope<'a>,
-}
-
-impl<'a> Ctx<'a> {
-    fn apply(&self, item: ast::AssocItem) {
-        for event in item.syntax().preorder() {
-            let node = match event {
-                syntax::WalkEvent::Enter(_) => continue,
-                syntax::WalkEvent::Leave(it) => it,
-            };
-            if let Some(path) = ast::Path::cast(node.clone()) {
-                self.transform_path(path);
-            }
-        }
-    }
-    fn transform_path(&self, path: ast::Path) -> Option<()> {
-        if path.qualifier().is_some() {
-            return None;
-        }
-        if path.segment().and_then(|s| s.param_list()).is_some() {
-            // don't try to qualify `Fn(Foo) -> Bar` paths, they are in prelude anyway
-            return None;
-        }
-
-        let resolution = self.source_scope.speculative_resolve(&path)?;
-
-        match resolution {
-            hir::PathResolution::TypeParam(tp) => {
-                if let Some(subst) = self.substs.get(&tp) {
-                    ted::replace(path.syntax(), subst.clone_subtree().clone_for_update().syntax())
-                }
-            }
-            hir::PathResolution::Def(def) => {
-                let found_path =
-                    self.target_module.find_use_path(self.source_scope.db.upcast(), def)?;
-                let res = mod_path_to_ast(&found_path).clone_for_update();
-                if let Some(args) = path.segment().and_then(|it| it.generic_arg_list()) {
-                    if let Some(segment) = res.segment() {
-                        let old = segment.get_or_create_generic_arg_list();
-                        ted::replace(old.syntax(), args.clone_subtree().syntax().clone_for_update())
-                    }
-                }
-                ted::replace(path.syntax(), res.syntax())
-            }
-            hir::PathResolution::Local(_)
-            | hir::PathResolution::ConstParam(_)
-            | hir::PathResolution::SelfType(_)
-            | hir::PathResolution::Macro(_)
-            | hir::PathResolution::AssocItem(_) => (),
-        }
-        Some(())
-    }
-}
-
-// FIXME: It would probably be nicer if we could get this via HIR (i.e. get the
-// trait ref, and then go from the types in the substs back to the syntax).
-fn get_syntactic_substs(impl_def: ast::Impl) -> Option<Vec<ast::Type>> {
-    let target_trait = impl_def.trait_()?;
-    let path_type = match target_trait {
-        ast::Type::PathType(path) => path,
-        _ => return None,
-    };
-    let generic_arg_list = path_type.path()?.segment()?.generic_arg_list()?;
-
-    let mut result = Vec::new();
-    for generic_arg in generic_arg_list.generic_args() {
-        match generic_arg {
-            ast::GenericArg::TypeArg(type_arg) => result.push(type_arg.ty()?),
-            ast::GenericArg::AssocTypeArg(_)
-            | ast::GenericArg::LifetimeArg(_)
-            | ast::GenericArg::ConstArg(_) => (),
-        }
-    }
-
-    Some(result)
-}
index 068df005bfaa218cc8e9343cfc1210c8b354628b..0ec236aa0c7232e3efa08c367e0533e472201aa9 100644 (file)
@@ -8,6 +8,7 @@
 use hir::{Adt, HasSource, Semantics};
 use ide_db::{
     helpers::{FamousDefs, SnippetCap},
+    path_transform::PathTransform,
     RootDatabase,
 };
 use itertools::Itertools;
     SyntaxNode, TextSize, T,
 };
 
-use crate::{
-    assist_context::{AssistBuilder, AssistContext},
-    path_transform::PathTransform,
-};
+use crate::assist_context::{AssistBuilder, AssistContext};
 
 pub(crate) fn unwrap_trivial_block(block: ast::BlockExpr) -> ast::Expr {
     extract_trivial_expression(&block)
index dc1d198cc23825c92ca77b91b2c4761cc9c18eae..1f6b959af772c77830b952d40808e56a3f130ee4 100644 (file)
@@ -32,7 +32,7 @@
 //! ```
 
 use hir::{self, HasAttrs, HasSource};
-use ide_db::{traits::get_missing_assoc_items, SymbolKind};
+use ide_db::{path_transform::PathTransform, traits::get_missing_assoc_items, SymbolKind};
 use syntax::{
     ast::{self, edit},
     display::function_declaration,
@@ -56,7 +56,9 @@ pub(crate) fn complete_trait_impl(acc: &mut Completions, ctx: &CompletionContext
             hir::AssocItem::Function(fn_item)
                 if kind == ImplCompletionKind::All || kind == ImplCompletionKind::Fn =>
             {
-                add_function_impl(&trigger, acc, ctx, fn_item)
+                if let Some(impl_def) = ctx.sema.to_def(&impl_def) {
+                    add_function_impl(&trigger, acc, ctx, fn_item, impl_def)
+                }
             }
             hir::AssocItem::TypeAlias(type_item)
                 if kind == ImplCompletionKind::All || kind == ImplCompletionKind::TypeAlias =>
@@ -66,7 +68,9 @@ pub(crate) fn complete_trait_impl(acc: &mut Completions, ctx: &CompletionContext
             hir::AssocItem::Const(const_item)
                 if kind == ImplCompletionKind::All || kind == ImplCompletionKind::Const =>
             {
-                add_const_impl(&trigger, acc, ctx, const_item)
+                if let Some(impl_def) = ctx.sema.to_def(&impl_def) {
+                    add_const_impl(&trigger, acc, ctx, const_item, impl_def)
+                }
             }
             _ => {}
         });
@@ -129,6 +133,7 @@ fn add_function_impl(
     acc: &mut Completions,
     ctx: &CompletionContext,
     func: hir::Function,
+    impl_def: hir::Impl,
 ) {
     let fn_name = func.name(ctx.db).to_string();
 
@@ -147,23 +152,55 @@ fn add_function_impl(
         CompletionItemKind::SymbolKind(SymbolKind::Function)
     };
     let range = replacement_range(ctx, fn_def_node);
-    if let Some(src) = func.source(ctx.db) {
-        let function_decl = function_declaration(&src.value);
-        match ctx.config.snippet_cap {
-            Some(cap) => {
-                let snippet = format!("{} {{\n    $0\n}}", function_decl);
-                item.snippet_edit(cap, TextEdit::replace(range, snippet));
-            }
-            None => {
-                let header = format!("{} {{", function_decl);
-                item.text_edit(TextEdit::replace(range, header));
-            }
-        };
-        item.kind(completion_kind);
-        item.add_to(acc);
+
+    if let Some(source) = func.source(ctx.db) {
+        let assoc_item = ast::AssocItem::Fn(source.value);
+        if let Some(transformed_item) = get_transformed_assoc_item(ctx, assoc_item, impl_def) {
+            let transformed_fn = match transformed_item {
+                ast::AssocItem::Fn(func) => func,
+                _ => unreachable!(),
+            };
+
+            let function_decl = function_declaration(&transformed_fn);
+            match ctx.config.snippet_cap {
+                Some(cap) => {
+                    let snippet = format!("{} {{\n    $0\n}}", function_decl);
+                    item.snippet_edit(cap, TextEdit::replace(range, snippet));
+                }
+                None => {
+                    let header = format!("{} {{", function_decl);
+                    item.text_edit(TextEdit::replace(range, header));
+                }
+            };
+            item.kind(completion_kind);
+            item.add_to(acc);
+        }
     }
 }
 
+/// Transform a relevant associated item to inline generics from the impl, remove attrs and docs, etc.
+fn get_transformed_assoc_item(
+    ctx: &CompletionContext,
+    assoc_item: ast::AssocItem,
+    impl_def: hir::Impl,
+) -> Option<ast::AssocItem> {
+    let assoc_item = assoc_item.clone_for_update();
+    let trait_ = impl_def.trait_(ctx.db)?;
+    let source_scope = &ctx.sema.scope_for_def(trait_);
+    let target_scope = &ctx.sema.scope(impl_def.source(ctx.db)?.syntax().value);
+    let transform = PathTransform {
+        subst: (trait_, impl_def.source(ctx.db)?.value),
+        source_scope,
+        target_scope,
+    };
+
+    transform.apply(assoc_item.clone());
+    Some(match assoc_item {
+        ast::AssocItem::Fn(func) => ast::AssocItem::Fn(edit::remove_attrs_and_docs(&func)),
+        _ => assoc_item,
+    })
+}
+
 fn add_type_alias_impl(
     type_def_node: &SyntaxNode,
     acc: &mut Completions,
@@ -188,21 +225,30 @@ fn add_const_impl(
     acc: &mut Completions,
     ctx: &CompletionContext,
     const_: hir::Const,
+    impl_def: hir::Impl,
 ) {
     let const_name = const_.name(ctx.db).map(|n| n.to_string());
 
     if let Some(const_name) = const_name {
         if let Some(source) = const_.source(ctx.db) {
-            let snippet = make_const_compl_syntax(&source.value);
-
-            let range = replacement_range(ctx, const_def_node);
-            let mut item =
-                CompletionItem::new(CompletionKind::Magic, ctx.source_range(), snippet.clone());
-            item.text_edit(TextEdit::replace(range, snippet))
-                .lookup_by(const_name)
-                .kind(SymbolKind::Const)
-                .set_documentation(const_.docs(ctx.db));
-            item.add_to(acc);
+            let assoc_item = ast::AssocItem::Const(source.value);
+            if let Some(transformed_item) = get_transformed_assoc_item(ctx, assoc_item, impl_def) {
+                let transformed_const = match transformed_item {
+                    ast::AssocItem::Const(const_) => const_,
+                    _ => unreachable!(),
+                };
+
+                let snippet = make_const_compl_syntax(&transformed_const);
+
+                let range = replacement_range(ctx, const_def_node);
+                let mut item =
+                    CompletionItem::new(CompletionKind::Magic, ctx.source_range(), snippet.clone());
+                item.text_edit(TextEdit::replace(range, snippet))
+                    .lookup_by(const_name)
+                    .kind(SymbolKind::Const)
+                    .set_documentation(const_.docs(ctx.db));
+                item.add_to(acc);
+            }
         }
     }
 }
@@ -779,4 +825,183 @@ impl Foo for T {{
         test("Type", "type T$0", "type Type = ");
         test("CONST", "const C$0", "const CONST: i32 = ");
     }
+
+    #[test]
+    fn generics_are_inlined_in_return_type() {
+        check_edit(
+            "function",
+            r#"
+trait Foo<T> {
+    fn function() -> T;
+}
+struct Bar;
+
+impl Foo<u32> for Bar {
+    fn f$0
+}
+"#,
+            r#"
+trait Foo<T> {
+    fn function() -> T;
+}
+struct Bar;
+
+impl Foo<u32> for Bar {
+    fn function() -> u32 {
+    $0
+}
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn generics_are_inlined_in_parameter() {
+        check_edit(
+            "function",
+            r#"
+trait Foo<T> {
+    fn function(bar: T);
+}
+struct Bar;
+
+impl Foo<u32> for Bar {
+    fn f$0
+}
+"#,
+            r#"
+trait Foo<T> {
+    fn function(bar: T);
+}
+struct Bar;
+
+impl Foo<u32> for Bar {
+    fn function(bar: u32) {
+    $0
+}
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn generics_are_inlined_when_part_of_other_types() {
+        check_edit(
+            "function",
+            r#"
+trait Foo<T> {
+    fn function(bar: Vec<T>);
+}
+struct Bar;
+
+impl Foo<u32> for Bar {
+    fn f$0
+}
+"#,
+            r#"
+trait Foo<T> {
+    fn function(bar: Vec<T>);
+}
+struct Bar;
+
+impl Foo<u32> for Bar {
+    fn function(bar: Vec<u32>) {
+    $0
+}
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn generics_are_inlined_complex() {
+        check_edit(
+            "function",
+            r#"
+trait Foo<T, U, V> {
+    fn function(bar: Vec<T>, baz: U) -> Arc<Vec<V>>;
+}
+struct Bar;
+
+impl Foo<u32, Vec<usize>, u8> for Bar {
+    fn f$0
+}
+"#,
+            r#"
+trait Foo<T, U, V> {
+    fn function(bar: Vec<T>, baz: U) -> Arc<Vec<V>>;
+}
+struct Bar;
+
+impl Foo<u32, Vec<usize>, u8> for Bar {
+    fn function(bar: Vec<u32>, baz: Vec<usize>) -> Arc<Vec<u8>> {
+    $0
+}
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn generics_are_inlined_in_associated_const() {
+        check_edit(
+            "BAR",
+            r#"
+trait Foo<T> {
+    const BAR: T;
+}
+struct Bar;
+
+impl Foo<u32> for Bar {
+    const B$0
+}
+"#,
+            r#"
+trait Foo<T> {
+    const BAR: T;
+}
+struct Bar;
+
+impl Foo<u32> for Bar {
+    const BAR: u32 = 
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn generics_are_inlined_in_where_clause() {
+        check_edit(
+            "function",
+            r#"
+trait SomeTrait<T> {}
+
+trait Foo<T> {
+    fn function()
+    where Self: SomeTrait<T>;
+}
+struct Bar;
+
+impl Foo<u32> for Bar {
+    fn f$0
+}
+"#,
+            r#"
+trait SomeTrait<T> {}
+
+trait Foo<T> {
+    fn function()
+    where Self: SomeTrait<T>;
+}
+struct Bar;
+
+impl Foo<u32> for Bar {
+    fn function()
+where Self: SomeTrait<u32> {
+    $0
+}
+}
+"#,
+        )
+    }
 }
index 7bbd08d6f1764300125fd722f02d080e4f983aff..bde8767dd68fdc3c6997aa57bc7b1249740f2098 100644 (file)
@@ -14,6 +14,7 @@
 pub mod traits;
 pub mod call_info;
 pub mod helpers;
+pub mod path_transform;
 
 pub mod search;
 pub mod rename;
diff --git a/crates/ide_db/src/path_transform.rs b/crates/ide_db/src/path_transform.rs
new file mode 100644 (file)
index 0000000..f3d7aa9
--- /dev/null
@@ -0,0 +1,160 @@
+//! See [`PathTransform`].
+
+use crate::helpers::mod_path_to_ast;
+use hir::{HirDisplay, SemanticsScope};
+use rustc_hash::FxHashMap;
+use syntax::{
+    ast::{self, AstNode},
+    ted,
+};
+
+/// `PathTransform` substitutes path in SyntaxNodes in bulk.
+///
+/// This is mostly useful for IDE code generation. If you paste some existing
+/// code into a new context (for example, to add method overrides to an `impl`
+/// block), you generally want to appropriately qualify the names, and sometimes
+/// you might want to substitute generic parameters as well:
+///
+/// ```
+/// mod x {
+///   pub struct A<V>;
+///   pub trait T<U> { fn foo(&self, _: U) -> A<U>; }
+/// }
+///
+/// mod y {
+///   use x::T;
+///
+///   impl T<()> for () {
+///      // If we invoke **Add Missing Members** here, we want to copy-paste `foo`.
+///      // But we want a slightly-modified version of it:
+///      fn foo(&self, _: ()) -> x::A<()> {}
+///   }
+/// }
+/// ```
+pub struct PathTransform<'a> {
+    pub subst: (hir::Trait, ast::Impl),
+    pub target_scope: &'a SemanticsScope<'a>,
+    pub source_scope: &'a SemanticsScope<'a>,
+}
+
+impl<'a> PathTransform<'a> {
+    pub fn apply(&self, item: ast::AssocItem) {
+        if let Some(ctx) = self.build_ctx() {
+            ctx.apply(item)
+        }
+    }
+    fn build_ctx(&self) -> Option<Ctx<'a>> {
+        let db = self.source_scope.db;
+        let target_module = self.target_scope.module()?;
+        let source_module = self.source_scope.module()?;
+
+        let substs = get_syntactic_substs(self.subst.1.clone()).unwrap_or_default();
+        let generic_def: hir::GenericDef = self.subst.0.into();
+        let substs_by_param: FxHashMap<_, _> = generic_def
+            .type_params(db)
+            .into_iter()
+            // this is a trait impl, so we need to skip the first type parameter -- this is a bit hacky
+            .skip(1)
+            // The actual list of trait type parameters may be longer than the one
+            // used in the `impl` block due to trailing default type parameters.
+            // For that case we extend the `substs` with an empty iterator so we
+            // can still hit those trailing values and check if they actually have
+            // a default type. If they do, go for that type from `hir` to `ast` so
+            // the resulting change can be applied correctly.
+            .zip(substs.into_iter().map(Some).chain(std::iter::repeat(None)))
+            .filter_map(|(k, v)| match v {
+                Some(v) => Some((k, v)),
+                None => {
+                    let default = k.default(db)?;
+                    Some((
+                        k,
+                        ast::make::ty(&default.display_source_code(db, source_module.into()).ok()?),
+                    ))
+                }
+            })
+            .collect();
+
+        let res = Ctx { substs: substs_by_param, target_module, source_scope: self.source_scope };
+        Some(res)
+    }
+}
+
+struct Ctx<'a> {
+    substs: FxHashMap<hir::TypeParam, ast::Type>,
+    target_module: hir::Module,
+    source_scope: &'a SemanticsScope<'a>,
+}
+
+impl<'a> Ctx<'a> {
+    fn apply(&self, item: ast::AssocItem) {
+        for event in item.syntax().preorder() {
+            let node = match event {
+                syntax::WalkEvent::Enter(_) => continue,
+                syntax::WalkEvent::Leave(it) => it,
+            };
+            if let Some(path) = ast::Path::cast(node.clone()) {
+                self.transform_path(path);
+            }
+        }
+    }
+    fn transform_path(&self, path: ast::Path) -> Option<()> {
+        if path.qualifier().is_some() {
+            return None;
+        }
+        if path.segment().and_then(|s| s.param_list()).is_some() {
+            // don't try to qualify `Fn(Foo) -> Bar` paths, they are in prelude anyway
+            return None;
+        }
+
+        let resolution = self.source_scope.speculative_resolve(&path)?;
+
+        match resolution {
+            hir::PathResolution::TypeParam(tp) => {
+                if let Some(subst) = self.substs.get(&tp) {
+                    ted::replace(path.syntax(), subst.clone_subtree().clone_for_update().syntax())
+                }
+            }
+            hir::PathResolution::Def(def) => {
+                let found_path =
+                    self.target_module.find_use_path(self.source_scope.db.upcast(), def)?;
+                let res = mod_path_to_ast(&found_path).clone_for_update();
+                if let Some(args) = path.segment().and_then(|it| it.generic_arg_list()) {
+                    if let Some(segment) = res.segment() {
+                        let old = segment.get_or_create_generic_arg_list();
+                        ted::replace(old.syntax(), args.clone_subtree().syntax().clone_for_update())
+                    }
+                }
+                ted::replace(path.syntax(), res.syntax())
+            }
+            hir::PathResolution::Local(_)
+            | hir::PathResolution::ConstParam(_)
+            | hir::PathResolution::SelfType(_)
+            | hir::PathResolution::Macro(_)
+            | hir::PathResolution::AssocItem(_) => (),
+        }
+        Some(())
+    }
+}
+
+// FIXME: It would probably be nicer if we could get this via HIR (i.e. get the
+// trait ref, and then go from the types in the substs back to the syntax).
+fn get_syntactic_substs(impl_def: ast::Impl) -> Option<Vec<ast::Type>> {
+    let target_trait = impl_def.trait_()?;
+    let path_type = match target_trait {
+        ast::Type::PathType(path) => path,
+        _ => return None,
+    };
+    let generic_arg_list = path_type.path()?.segment()?.generic_arg_list()?;
+
+    let mut result = Vec::new();
+    for generic_arg in generic_arg_list.generic_args() {
+        match generic_arg {
+            ast::GenericArg::TypeArg(type_arg) => result.push(type_arg.ty()?),
+            ast::GenericArg::AssocTypeArg(_)
+            | ast::GenericArg::LifetimeArg(_)
+            | ast::GenericArg::ConstArg(_) => (),
+        }
+    }
+
+    Some(result)
+}