]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide/src/expand_macro.rs
fix: insert whitespaces into assoc items for assist when macro generated
[rust.git] / crates / ide / src / expand_macro.rs
index eebae5ebea50bd15e73c1e7f67c6155fec89d1d7..949744c01b2ae73a79c342bea85f5cb9834776b5 100644 (file)
@@ -1,11 +1,10 @@
-use std::iter;
-
 use hir::Semantics;
-use ide_db::RootDatabase;
-use syntax::{
-    algo::find_node_at_offset, ast, ted, AstNode, NodeOrToken, SyntaxKind, SyntaxKind::*,
-    SyntaxNode, WalkEvent, T,
+use ide_db::{
+    helpers::{insert_whitespace_into_node::insert_ws_into, pick_best_token},
+    RootDatabase,
 };
+use itertools::Itertools;
+use syntax::{ast, ted, AstNode, SyntaxKind, SyntaxNode};
 
 use crate::FilePosition;
 
@@ -28,16 +27,64 @@ pub struct ExpandedMacro {
 pub(crate) fn expand_macro(db: &RootDatabase, position: FilePosition) -> Option<ExpandedMacro> {
     let sema = Semantics::new(db);
     let file = sema.parse(position.file_id);
-    let name_ref = find_node_at_offset::<ast::NameRef>(file.syntax(), position.offset)?;
-    let mac = name_ref.syntax().ancestors().find_map(ast::MacroCall::cast)?;
 
-    let expanded = expand_macro_recur(&sema, &mac)?;
+    let tok = pick_best_token(file.syntax().token_at_offset(position.offset), |kind| match kind {
+        SyntaxKind::IDENT => 1,
+        _ => 0,
+    })?;
+
+    // due to how Rust Analyzer works internally, we need to special case derive attributes,
+    // otherwise they might not get found, e.g. here with the cursor at $0 `#[attr]` would expand:
+    // ```
+    // #[attr]
+    // #[derive($0Foo)]
+    // struct Bar;
+    // ```
+
+    let derive = sema.descend_into_macros(tok.clone()).iter().find_map(|descended| {
+        let attr = descended.ancestors().find_map(ast::Attr::cast)?;
+        let (path, tt) = attr.as_simple_call()?;
+        if path == "derive" {
+            let mut tt = tt.syntax().children_with_tokens().skip(1).join("");
+            tt.pop();
+            let expansions = sema.expand_derive_macro(&attr)?;
+            Some(ExpandedMacro {
+                name: tt,
+                expansion: expansions.into_iter().map(insert_ws_into).join(""),
+            })
+        } else {
+            None
+        }
+    });
+
+    if derive.is_some() {
+        return derive;
+    }
+
+    // FIXME: Intermix attribute and bang! expansions
+    // currently we only recursively expand one of the two types
+    let mut expanded = None;
+    let mut name = None;
+    for node in tok.ancestors() {
+        if let Some(item) = ast::Item::cast(node.clone()) {
+            if let Some(def) = sema.resolve_attr_macro_call(&item) {
+                name = def.name(db).map(|name| name.to_string());
+                expanded = expand_attr_macro_recur(&sema, &item);
+                break;
+            }
+        }
+        if let Some(mac) = ast::MacroCall::cast(node) {
+            name = Some(mac.path()?.segment()?.name_ref()?.to_string());
+            expanded = expand_macro_recur(&sema, &mac);
+            break;
+        }
+    }
 
     // FIXME:
     // macro expansion may lose all white space information
     // But we hope someday we can use ra_fmt for that
-    let expansion = insert_whitespaces(expanded);
-    Some(ExpandedMacro { name: name_ref.text().to_string(), expansion })
+    let expansion = insert_ws_into(expanded?).to_string();
+    Some(ExpandedMacro { name: name.unwrap_or_else(|| "???".to_owned()), expansion })
 }
 
 fn expand_macro_recur(
@@ -45,12 +92,25 @@ fn expand_macro_recur(
     macro_call: &ast::MacroCall,
 ) -> Option<SyntaxNode> {
     let expanded = sema.expand(macro_call)?.clone_for_update();
+    expand(sema, expanded, ast::MacroCall::cast, expand_macro_recur)
+}
 
-    let children = expanded.descendants().filter_map(ast::MacroCall::cast);
+fn expand_attr_macro_recur(sema: &Semantics<RootDatabase>, item: &ast::Item) -> Option<SyntaxNode> {
+    let expanded = sema.expand_attr_macro(item)?.clone_for_update();
+    expand(sema, expanded, ast::Item::cast, expand_attr_macro_recur)
+}
+
+fn expand<T: AstNode>(
+    sema: &Semantics<RootDatabase>,
+    expanded: SyntaxNode,
+    f: impl FnMut(SyntaxNode) -> Option<T>,
+    exp: impl Fn(&Semantics<RootDatabase>, &T) -> Option<SyntaxNode>,
+) -> Option<SyntaxNode> {
+    let children = expanded.descendants().filter_map(f);
     let mut replacements = Vec::new();
 
     for child in children {
-        if let Some(new_node) = expand_macro_recur(sema, &child) {
+        if let Some(new_node) = exp(sema, &child) {
             // check if the whole original syntax is replaced
             if expanded == *child.syntax() {
                 return Some(new_node);
@@ -63,84 +123,13 @@ fn expand_macro_recur(
     Some(expanded)
 }
 
-// FIXME: It would also be cool to share logic here and in the mbe tests,
-// which are pretty unreadable at the moment.
-fn insert_whitespaces(syn: SyntaxNode) -> String {
-    let mut res = String::new();
-    let mut token_iter = syn
-        .preorder_with_tokens()
-        .filter_map(|event| {
-            if let WalkEvent::Enter(NodeOrToken::Token(token)) = event {
-                Some(token)
-            } else {
-                None
-            }
-        })
-        .peekable();
-
-    let mut indent = 0;
-    let mut last: Option<SyntaxKind> = None;
-
-    while let Some(token) = token_iter.next() {
-        let mut is_next = |f: fn(SyntaxKind) -> bool, default| -> bool {
-            token_iter.peek().map(|it| f(it.kind())).unwrap_or(default)
-        };
-        let is_last =
-            |f: fn(SyntaxKind) -> bool, default| -> bool { last.map(f).unwrap_or(default) };
-
-        match token.kind() {
-            k if is_text(k) && is_next(|it| !it.is_punct(), true) => {
-                res.push_str(token.text());
-                res.push(' ');
-            }
-            L_CURLY if is_next(|it| it != R_CURLY, true) => {
-                indent += 1;
-                if is_last(is_text, false) {
-                    res.push(' ');
-                }
-                res.push_str("{\n");
-                res.extend(iter::repeat(" ").take(2 * indent));
-            }
-            R_CURLY if is_last(|it| it != L_CURLY, true) => {
-                indent = indent.saturating_sub(1);
-                res.push('\n');
-                res.extend(iter::repeat(" ").take(2 * indent));
-                res.push_str("}");
-            }
-            R_CURLY => {
-                res.push_str("}\n");
-                res.extend(iter::repeat(" ").take(2 * indent));
-            }
-            LIFETIME_IDENT if is_next(|it| it == IDENT, true) => {
-                res.push_str(token.text());
-                res.push(' ');
-            }
-            T![;] => {
-                res.push_str(";\n");
-                res.extend(iter::repeat(" ").take(2 * indent));
-            }
-            T![->] => res.push_str(" -> "),
-            T![=] => res.push_str(" = "),
-            T![=>] => res.push_str(" => "),
-            _ => res.push_str(token.text()),
-        }
-
-        last = Some(token.kind());
-    }
-
-    return res;
-
-    fn is_text(k: SyntaxKind) -> bool {
-        k.is_keyword() || k.is_literal() || k == IDENT
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use expect_test::{expect, Expect};
 
     use crate::fixture;
 
+    #[track_caller]
     fn check(ra_fixture: &str, expect: Expect) {
         let (analysis, pos) = fixture::position(ra_fixture);
         let expansion = analysis.expand_macro(pos).unwrap().unwrap();
@@ -148,6 +137,23 @@ fn check(ra_fixture: &str, expect: Expect) {
         expect.assert_eq(&actual);
     }
 
+    #[test]
+    fn macro_expand_as_keyword() {
+        check(
+            r#"
+macro_rules! bar {
+    ($i:tt) => { $i as _ }
+}
+fn main() {
+    let x: u64 = ba$0r!(5i64);
+}
+"#,
+            expect![[r#"
+                bar
+                5i64 as _"#]],
+        );
+    }
+
     #[test]
     fn macro_expand_recursive_expansion() {
         check(
@@ -166,6 +172,7 @@ macro_rules! baz {
             expect![[r#"
                 foo
                 fn b(){}
+
             "#]],
         );
     }
@@ -185,11 +192,12 @@ fn some_thing() -> u32 {
 f$0oo!();
         "#,
             expect![[r#"
-            foo
-            fn some_thing() -> u32 {
-              let a = 0;
-              a+10
-            }"#]],
+                foo
+                fn some_thing() -> u32 {
+                  let a = 0;
+                  a+10
+                }
+            "#]],
         );
     }
 
@@ -297,4 +305,60 @@ fn main() {
                 0 "#]],
         );
     }
+
+    #[test]
+    fn macro_expand_derive() {
+        check(
+            r#"
+//- proc_macros: identity
+//- minicore: clone, derive
+
+#[proc_macros::identity]
+#[derive(C$0lone)]
+struct Foo {}
+"#,
+            expect![[r#"
+                Clone
+                impl< >core::clone::Clone for Foo< >{}
+
+            "#]],
+        );
+    }
+
+    #[test]
+    fn macro_expand_derive2() {
+        check(
+            r#"
+//- minicore: copy, clone, derive
+
+#[derive(Cop$0y)]
+#[derive(Clone)]
+struct Foo {}
+"#,
+            expect![[r#"
+                Copy
+                impl< >core::marker::Copy for Foo< >{}
+
+            "#]],
+        );
+    }
+
+    #[test]
+    fn macro_expand_derive_multi() {
+        check(
+            r#"
+//- minicore: copy, clone, derive
+
+#[derive(Cop$0y, Clone)]
+struct Foo {}
+"#,
+            expect![[r#"
+                Copy, Clone
+                impl< >core::marker::Copy for Foo< >{}
+
+                impl< >core::clone::Clone for Foo< >{}
+
+            "#]],
+        );
+    }
 }