]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide/src/expand_macro.rs
fix: Fix expand_macro always expanding the first listed derive
[rust.git] / crates / ide / src / expand_macro.rs
index 32dbd9070b9eb7f1d98e3379b90ce018c5d1027f..7bb6b24a23deefa1dd1cf56ffc505fc9af55fa1d 100644 (file)
@@ -3,8 +3,7 @@
     helpers::{insert_whitespace_into_node::insert_ws_into, pick_best_token},
     RootDatabase,
 };
-use itertools::Itertools;
-use syntax::{ast, ted, AstNode, SyntaxKind, SyntaxNode};
+use syntax::{ast, ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T};
 
 use crate::FilePosition;
 
@@ -41,20 +40,28 @@ pub(crate) fn expand_macro(db: &RootDatabase, position: FilePosition) -> Option<
     // struct Bar;
     // ```
 
-    let derive = sema.descend_into_macros(tok.clone()).iter().find_map(|descended| {
-        let attr = descended.ancestors().find_map(ast::Attr::cast)?;
-        let (path, tt) = attr.as_simple_call()?;
-        if path == "derive" {
-            let mut tt = tt.syntax().children_with_tokens().skip(1).join("");
-            tt.pop();
-            let expansions = sema.expand_derive_macro(&attr)?;
-            Some(ExpandedMacro {
-                name: tt,
-                expansion: expansions.into_iter().map(insert_ws_into).join(""),
-            })
-        } else {
-            None
+    let derive = sema.descend_into_macros(tok.clone()).into_iter().find_map(|descended| {
+        let hir_file = sema.hir_file_for(&descended.parent()?);
+        if !hir_file.is_derive_attr_pseudo_expansion(db) {
+            return None;
         }
+
+        let name = descended.ancestors().filter_map(ast::Path::cast).last()?.to_string();
+        // up map out of the #[derive] expansion
+        let token = hir::InFile::new(hir_file, descended).upmap(db)?.value;
+        let attr = token.ancestors().find_map(ast::Attr::cast)?;
+        let expansions = sema.expand_derive_macro(&attr)?;
+        let idx = attr
+            .token_tree()?
+            .token_trees_and_tokens()
+            .filter_map(NodeOrToken::into_token)
+            .take_while(|it| it != &token)
+            .filter(|it| it.kind() == T![,])
+            .count();
+        Some(ExpandedMacro {
+            name,
+            expansion: expansions.get(idx).cloned().map(insert_ws_into)?.to_string(),
+        })
     });
 
     if derive.is_some() {
@@ -372,9 +379,20 @@ fn macro_expand_derive_multi() {
 struct Foo {}
 "#,
             expect![[r#"
-                Copy, Clone
+                Copy
                 impl < >core::marker::Copy for Foo< >{}
 
+            "#]],
+        );
+        check(
+            r#"
+//- minicore: copy, clone, derive
+
+#[derive(Copy, Cl$0one)]
+struct Foo {}
+"#,
+            expect![[r#"
+                Clone
                 impl < >core::clone::Clone for Foo< >{}
 
             "#]],