]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / replace_derive_with_manual_impl.rs
index bd0b2028a1a00d1dcf78b96f6025cb801f9ed67e..b3723710a8631cd40a499349eff837d026188f0d 100644 (file)
@@ -1,10 +1,14 @@
 use hir::ModuleDef;
-use ide_db::helpers::{import_assets::NameToImport, mod_path_to_ast};
+use ide_db::helpers::insert_whitespace_into_node::insert_ws_into;
+use ide_db::helpers::{
+    get_path_at_cursor_in_tt, import_assets::NameToImport, mod_path_to_ast,
+    parse_tt_as_comma_sep_paths,
+};
 use ide_db::items_locator;
 use itertools::Itertools;
 use syntax::{
-    ast::{self, make, AstNode, NameOwner},
-    SyntaxKind::{IDENT, WHITESPACE},
+    ast::{self, AstNode, AstToken, HasName},
+    SyntaxKind::WHITESPACE,
 };
 
 use crate::{
@@ -52,9 +56,8 @@ pub(crate) fn replace_derive_with_manual_impl(
         return None;
     }
 
-    let trait_token = args.syntax().token_at_offset(ctx.offset()).find(|t| t.kind() == IDENT)?;
-    let trait_name = trait_token.text();
-
+    let ident = args.syntax().token_at_offset(ctx.offset()).find_map(ast::Ident::cast)?;
+    let trait_path = get_path_at_cursor_in_tt(&ident)?;
     let adt = attr.syntax().parent().and_then(ast::Adt::cast)?;
 
     let current_module = ctx.sema.scope(adt.syntax()).module()?;
@@ -63,7 +66,7 @@ pub(crate) fn replace_derive_with_manual_impl(
     let found_traits = items_locator::items_with_name(
         &ctx.sema,
         current_crate,
-        NameToImport::Exact(trait_name.to_string()),
+        NameToImport::exact_case_sensitive(trait_path.segments().last()?.to_string()),
         items_locator::AssocItemSearch::Exclude,
         Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT.inner()),
     )
@@ -80,12 +83,23 @@ pub(crate) fn replace_derive_with_manual_impl(
     });
 
     let mut no_traits_found = true;
-    for (trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
-        add_assist(acc, ctx, &attr, &args, &trait_path, Some(trait_), &adt)?;
+    let current_derives = parse_tt_as_comma_sep_paths(args.clone())?;
+    let current_derives = current_derives.as_slice();
+    for (replace_trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
+        add_assist(
+            acc,
+            ctx,
+            &attr,
+            &current_derives,
+            &args,
+            &trait_path,
+            &replace_trait_path,
+            Some(trait_),
+            &adt,
+        )?;
     }
     if no_traits_found {
-        let trait_path = make::ext::ident_path(trait_name);
-        add_assist(acc, ctx, &attr, &args, &trait_path, None, &adt)?;
+        add_assist(acc, ctx, &attr, &current_derives, &args, &trait_path, &trait_path, None, &adt)?;
     }
     Some(())
 }
@@ -94,15 +108,16 @@ fn add_assist(
     acc: &mut Assists,
     ctx: &AssistContext,
     attr: &ast::Attr,
-    input: &ast::TokenTree,
-    trait_path: &ast::Path,
+    old_derives: &[ast::Path],
+    old_tree: &ast::TokenTree,
+    old_trait_path: &ast::Path,
+    replace_trait_path: &ast::Path,
     trait_: Option<hir::Trait>,
     adt: &ast::Adt,
 ) -> Option<()> {
     let target = attr.syntax().text_range();
     let annotated_name = adt.name()?;
-    let label = format!("Convert to manual `impl {} for {}`", trait_path, annotated_name);
-    let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?;
+    let label = format!("Convert to manual `impl {} for {}`", replace_trait_path, annotated_name);
 
     acc.add(
         AssistId("replace_derive_with_manual_impl", AssistKind::Refactor),
@@ -111,9 +126,9 @@ fn add_assist(
         |builder| {
             let insert_pos = adt.syntax().text_range().end();
             let impl_def_with_items =
-                impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, trait_path);
-            update_attribute(builder, input, &trait_name, attr);
-            let trait_path = format!("{}", trait_path);
+                impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, replace_trait_path);
+            update_attribute(builder, old_derives, old_tree, old_trait_path, attr);
+            let trait_path = format!("{}", replace_trait_path);
             match (ctx.config.snippet_cap, impl_def_with_items) {
                 (None, _) => {
                     builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, ""))
@@ -156,12 +171,40 @@ fn impl_def_from_trait(
 ) -> Option<(ast::Impl, ast::AssocItem)> {
     let trait_ = trait_?;
     let target_scope = sema.scope(annotated_name.syntax());
-    let trait_items = filter_assoc_items(sema.db, &trait_.items(sema.db), DefaultMethods::No);
+    let trait_items = filter_assoc_items(sema, &trait_.items(sema.db), DefaultMethods::No);
     if trait_items.is_empty() {
         return None;
     }
-    let impl_def =
-        make::impl_trait(trait_path.clone(), make::ext::ident_path(&annotated_name.text()));
+    let impl_def = {
+        use syntax::ast::Impl;
+        let text = generate_trait_impl_text(adt, trait_path.to_string().as_str(), "");
+        let parse = syntax::SourceFile::parse(&text);
+        let node = match parse.tree().syntax().descendants().find_map(Impl::cast) {
+            Some(it) => it,
+            None => {
+                panic!(
+                    "Failed to make ast node `{}` from text {}",
+                    std::any::type_name::<Impl>(),
+                    text
+                )
+            }
+        };
+        let node = node.clone_subtree();
+        assert_eq!(node.syntax().text_range().start(), 0.into());
+        node
+    };
+
+    let trait_items = trait_items
+        .into_iter()
+        .map(|it| {
+            if sema.hir_file_for(it.syntax()).is_macro() {
+                if let Some(it) = ast::AssocItem::cast(insert_ws_into(it.syntax().clone())) {
+                    return it;
+                }
+            }
+            it.clone_for_update()
+        })
+        .collect();
     let (impl_def, first_assoc_item) =
         add_trait_assoc_items_to_impl(sema, trait_items, trait_, impl_def, target_scope);
 
@@ -175,23 +218,20 @@ fn impl_def_from_trait(
 
 fn update_attribute(
     builder: &mut AssistBuilder,
-    input: &ast::TokenTree,
-    trait_name: &ast::NameRef,
+    old_derives: &[ast::Path],
+    old_tree: &ast::TokenTree,
+    old_trait_path: &ast::Path,
     attr: &ast::Attr,
 ) {
-    let trait_name = trait_name.text();
-    let new_attr_input = input
-        .syntax()
-        .descendants_with_tokens()
-        .filter(|t| t.kind() == IDENT)
-        .filter_map(|t| t.into_token().map(|t| t.text().to_string()))
-        .filter(|t| t != &trait_name)
+    let new_derives = old_derives
+        .iter()
+        .filter(|t| t.to_string() != old_trait_path.to_string())
         .collect::<Vec<_>>();
-    let has_more_derives = !new_attr_input.is_empty();
+    let has_more_derives = !new_derives.is_empty();
 
     if has_more_derives {
-        let new_attr_input = format!("({})", new_attr_input.iter().format(", "));
-        builder.replace(input.syntax().text_range(), new_attr_input);
+        let new_derives = format!("({})", new_derives.iter().format(", "));
+        builder.replace(old_tree.syntax().text_range(), new_derives);
     } else {
         let attr_range = attr.syntax().text_range();
         builder.delete(attr_range);
@@ -302,6 +342,71 @@ impl core::fmt::Debug for Foo {
         }
     }
 }
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_debug_tuple_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: fmt
+#[derive(Debu$0g)]
+enum Foo {
+    Bar(usize, usize),
+    Baz,
+}
+"#,
+            r#"
+enum Foo {
+    Bar(usize, usize),
+    Baz,
+}
+
+impl core::fmt::Debug for Foo {
+    $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        match self {
+            Self::Bar(arg0, arg1) => f.debug_tuple("Bar").field(arg0).field(arg1).finish(),
+            Self::Baz => write!(f, "Baz"),
+        }
+    }
+}
+"#,
+        )
+    }
+    #[test]
+    fn add_custom_impl_debug_record_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: fmt
+#[derive(Debu$0g)]
+enum Foo {
+    Bar {
+        baz: usize,
+        qux: usize,
+    },
+    Baz,
+}
+"#,
+            r#"
+enum Foo {
+    Bar {
+        baz: usize,
+        qux: usize,
+    },
+    Baz,
+}
+
+impl core::fmt::Debug for Foo {
+    $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        match self {
+            Self::Bar { baz, qux } => f.debug_struct("Bar").field("baz", baz).field("qux", qux).finish(),
+            Self::Baz => write!(f, "Baz"),
+        }
+    }
+}
 "#,
         )
     }
@@ -443,6 +548,431 @@ impl core::hash::Hash for Foo {
         core::mem::discriminant(self).hash(state);
     }
 }
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_clone_record_struct() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: clone
+#[derive(Clo$0ne)]
+struct Foo {
+    bin: usize,
+    bar: usize,
+}
+"#,
+            r#"
+struct Foo {
+    bin: usize,
+    bar: usize,
+}
+
+impl Clone for Foo {
+    $0fn clone(&self) -> Self {
+        Self { bin: self.bin.clone(), bar: self.bar.clone() }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_clone_tuple_struct() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: clone
+#[derive(Clo$0ne)]
+struct Foo(usize, usize);
+"#,
+            r#"
+struct Foo(usize, usize);
+
+impl Clone for Foo {
+    $0fn clone(&self) -> Self {
+        Self(self.0.clone(), self.1.clone())
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_clone_empty_struct() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: clone
+#[derive(Clo$0ne)]
+struct Foo;
+"#,
+            r#"
+struct Foo;
+
+impl Clone for Foo {
+    $0fn clone(&self) -> Self {
+        Self {  }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_clone_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: clone
+#[derive(Clo$0ne)]
+enum Foo {
+    Bar,
+    Baz,
+}
+"#,
+            r#"
+enum Foo {
+    Bar,
+    Baz,
+}
+
+impl Clone for Foo {
+    $0fn clone(&self) -> Self {
+        match self {
+            Self::Bar => Self::Bar,
+            Self::Baz => Self::Baz,
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_clone_tuple_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: clone
+#[derive(Clo$0ne)]
+enum Foo {
+    Bar(String),
+    Baz,
+}
+"#,
+            r#"
+enum Foo {
+    Bar(String),
+    Baz,
+}
+
+impl Clone for Foo {
+    $0fn clone(&self) -> Self {
+        match self {
+            Self::Bar(arg0) => Self::Bar(arg0.clone()),
+            Self::Baz => Self::Baz,
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_clone_record_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: clone
+#[derive(Clo$0ne)]
+enum Foo {
+    Bar {
+        bin: String,
+    },
+    Baz,
+}
+"#,
+            r#"
+enum Foo {
+    Bar {
+        bin: String,
+    },
+    Baz,
+}
+
+impl Clone for Foo {
+    $0fn clone(&self) -> Self {
+        match self {
+            Self::Bar { bin } => Self::Bar { bin: bin.clone() },
+            Self::Baz => Self::Baz,
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_ord_record_struct() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: ord
+#[derive(Partial$0Ord)]
+struct Foo {
+    bin: usize,
+}
+"#,
+            r#"
+struct Foo {
+    bin: usize,
+}
+
+impl PartialOrd for Foo {
+    $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        self.bin.partial_cmp(&other.bin)
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_ord_record_struct_multi_field() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: ord
+#[derive(Partial$0Ord)]
+struct Foo {
+    bin: usize,
+    bar: usize,
+    baz: usize,
+}
+"#,
+            r#"
+struct Foo {
+    bin: usize,
+    bar: usize,
+    baz: usize,
+}
+
+impl PartialOrd for Foo {
+    $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        match self.bin.partial_cmp(&other.bin) {
+            Some(core::cmp::Ordering::Equal) => {}
+            ord => return ord,
+        }
+        match self.bar.partial_cmp(&other.bar) {
+            Some(core::cmp::Ordering::Equal) => {}
+            ord => return ord,
+        }
+        self.baz.partial_cmp(&other.baz)
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_ord_tuple_struct() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: ord
+#[derive(Partial$0Ord)]
+struct Foo(usize, usize, usize);
+"#,
+            r#"
+struct Foo(usize, usize, usize);
+
+impl PartialOrd for Foo {
+    $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        match self.0.partial_cmp(&other.0) {
+            Some(core::cmp::Ordering::Equal) => {}
+            ord => return ord,
+        }
+        match self.1.partial_cmp(&other.1) {
+            Some(core::cmp::Ordering::Equal) => {}
+            ord => return ord,
+        }
+        self.2.partial_cmp(&other.2)
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_eq_record_struct() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: eq
+#[derive(Partial$0Eq)]
+struct Foo {
+    bin: usize,
+    bar: usize,
+}
+"#,
+            r#"
+struct Foo {
+    bin: usize,
+    bar: usize,
+}
+
+impl PartialEq for Foo {
+    $0fn eq(&self, other: &Self) -> bool {
+        self.bin == other.bin && self.bar == other.bar
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_eq_tuple_struct() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: eq
+#[derive(Partial$0Eq)]
+struct Foo(usize, usize);
+"#,
+            r#"
+struct Foo(usize, usize);
+
+impl PartialEq for Foo {
+    $0fn eq(&self, other: &Self) -> bool {
+        self.0 == other.0 && self.1 == other.1
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_eq_empty_struct() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: eq
+#[derive(Partial$0Eq)]
+struct Foo;
+"#,
+            r#"
+struct Foo;
+
+impl PartialEq for Foo {
+    $0fn eq(&self, other: &Self) -> bool {
+        true
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_eq_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: eq
+#[derive(Partial$0Eq)]
+enum Foo {
+    Bar,
+    Baz,
+}
+"#,
+            r#"
+enum Foo {
+    Bar,
+    Baz,
+}
+
+impl PartialEq for Foo {
+    $0fn eq(&self, other: &Self) -> bool {
+        core::mem::discriminant(self) == core::mem::discriminant(other)
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_eq_tuple_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: eq
+#[derive(Partial$0Eq)]
+enum Foo {
+    Bar(String),
+    Baz,
+}
+"#,
+            r#"
+enum Foo {
+    Bar(String),
+    Baz,
+}
+
+impl PartialEq for Foo {
+    $0fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Self::Bar(l0), Self::Bar(r0)) => l0 == r0,
+            _ => core::mem::discriminant(self) == core::mem::discriminant(other),
+        }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_eq_record_enum() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: eq
+#[derive(Partial$0Eq)]
+enum Foo {
+    Bar {
+        bin: String,
+    },
+    Baz {
+        qux: String,
+        fez: String,
+    },
+    Qux {},
+    Bin,
+}
+"#,
+            r#"
+enum Foo {
+    Bar {
+        bin: String,
+    },
+    Baz {
+        qux: String,
+        fez: String,
+    },
+    Qux {},
+    Bin,
+}
+
+impl PartialEq for Foo {
+    $0fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
+            (Self::Baz { qux: l_qux, fez: l_fez }, Self::Baz { qux: r_qux, fez: r_fez }) => l_qux == r_qux && l_fez == r_fez,
+            _ => core::mem::discriminant(self) == core::mem::discriminant(other),
+        }
+    }
+}
 "#,
         )
     }
@@ -558,6 +1088,54 @@ impl Debug for Foo {
         )
     }
 
+    #[test]
+    fn add_custom_impl_default_generic_record_struct() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: default
+#[derive(Defau$0lt)]
+struct Foo<T, U> {
+    foo: T,
+    bar: U,
+}
+"#,
+            r#"
+struct Foo<T, U> {
+    foo: T,
+    bar: U,
+}
+
+impl<T, U> Default for Foo<T, U> {
+    $0fn default() -> Self {
+        Self { foo: Default::default(), bar: Default::default() }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_clone_generic_tuple_struct_with_bounds() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: clone
+#[derive(Clo$0ne)]
+struct Foo<T: Clone>(T, usize);
+"#,
+            r#"
+struct Foo<T: Clone>(T, usize);
+
+impl<T: Clone> Clone for Foo<T> {
+    $0fn clone(&self) -> Self {
+        Self(self.0.clone(), self.1.clone())
+    }
+}
+"#,
+        )
+    }
+
     #[test]
     fn test_ignore_derive_macro_without_input() {
         check_assist_not_applicable(
@@ -610,4 +1188,48 @@ fn works_at_start_of_file() {
             "#,
         );
     }
+
+    #[test]
+    fn add_custom_impl_keep_path() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: clone
+#[derive(std::fmt::Debug, Clo$0ne)]
+pub struct Foo;
+"#,
+            r#"
+#[derive(std::fmt::Debug)]
+pub struct Foo;
+
+impl Clone for Foo {
+    $0fn clone(&self) -> Self {
+        Self {  }
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_replace_path() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: fmt
+#[derive(core::fmt::Deb$0ug, Clone)]
+pub struct Foo;
+"#,
+            r#"
+#[derive(Clone)]
+pub struct Foo;
+
+impl core::fmt::Debug for Foo {
+    $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.debug_struct("Foo").finish()
+    }
+}
+"#,
+        )
+    }
 }