]> git.lizzy.rs Git - rust.git/blob - crates/ide/src/expand_macro.rs
Improve expand_macro
[rust.git] / crates / ide / src / expand_macro.rs
1 use std::iter;
2
3 use hir::Semantics;
4 use ide_db::{helpers::pick_best_token, RootDatabase};
5 use itertools::Itertools;
6 use syntax::{ast, ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, WalkEvent, T};
7
8 use crate::FilePosition;
9
10 pub struct ExpandedMacro {
11     pub name: String,
12     pub expansion: String,
13 }
14
15 // Feature: Expand Macro Recursively
16 //
17 // Shows the full macro expansion of the macro at current cursor.
18 //
19 // |===
20 // | Editor  | Action Name
21 //
22 // | VS Code | **Rust Analyzer: Expand macro recursively**
23 // |===
24 //
25 // image::https://user-images.githubusercontent.com/48062697/113020648-b3973180-917a-11eb-84a9-ecb921293dc5.gif[]
26 pub(crate) fn expand_macro(db: &RootDatabase, position: FilePosition) -> Option<ExpandedMacro> {
27     let sema = Semantics::new(db);
28     let file = sema.parse(position.file_id);
29
30     let tok = pick_best_token(file.syntax().token_at_offset(position.offset), |kind| match kind {
31         SyntaxKind::IDENT => 1,
32         _ => 0,
33     })?;
34     let descended = sema.descend_into_macros(tok.clone());
35     if let Some(attr) = descended.ancestors().find_map(ast::Attr::cast) {
36         if let Some((path, tt)) = attr.as_simple_call() {
37             if path == "derive" {
38                 let mut tt = tt.syntax().children_with_tokens().skip(1).join("");
39                 tt.pop();
40                 return sema
41                     .expand_derive_macro(&attr)
42                     .map(insert_whitespaces)
43                     .map(|expansion| ExpandedMacro { name: tt, expansion });
44             }
45         }
46     }
47     let mut expanded = None;
48     let mut name = None;
49     for node in tok.ancestors() {
50         if let Some(item) = ast::Item::cast(node.clone()) {
51             if let Some(def) = sema.resolve_attr_macro_call(&item) {
52                 name = def.name(db).map(|name| name.to_string());
53                 expanded = expand_attr_macro_recur(&sema, &item);
54                 break;
55             }
56         }
57         if let Some(mac) = ast::MacroCall::cast(node) {
58             name = Some(mac.path()?.segment()?.name_ref()?.to_string());
59             expanded = expand_macro_recur(&sema, &mac);
60             break;
61         }
62     }
63
64     // FIXME:
65     // macro expansion may lose all white space information
66     // But we hope someday we can use ra_fmt for that
67     let expansion = insert_whitespaces(expanded?);
68     Some(ExpandedMacro { name: name.unwrap_or_else(|| "???".to_owned()), expansion })
69 }
70
71 fn expand_macro_recur(
72     sema: &Semantics<RootDatabase>,
73     macro_call: &ast::MacroCall,
74 ) -> Option<SyntaxNode> {
75     let expanded = sema.expand(macro_call)?.clone_for_update();
76     expand(sema, expanded, ast::MacroCall::cast, expand_macro_recur)
77 }
78
79 fn expand_attr_macro_recur(sema: &Semantics<RootDatabase>, item: &ast::Item) -> Option<SyntaxNode> {
80     let expanded = sema.expand_attr_macro(item)?.clone_for_update();
81     expand(sema, expanded, ast::Item::cast, expand_attr_macro_recur)
82 }
83
84 fn expand<T: AstNode>(
85     sema: &Semantics<RootDatabase>,
86     expanded: SyntaxNode,
87     f: impl FnMut(SyntaxNode) -> Option<T>,
88     exp: impl Fn(&Semantics<RootDatabase>, &T) -> Option<SyntaxNode>,
89 ) -> Option<SyntaxNode> {
90     let children = expanded.descendants().filter_map(f);
91     let mut replacements = Vec::new();
92
93     for child in children {
94         if let Some(new_node) = exp(sema, &child) {
95             // check if the whole original syntax is replaced
96             if expanded == *child.syntax() {
97                 return Some(new_node);
98             }
99             replacements.push((child, new_node));
100         }
101     }
102
103     replacements.into_iter().rev().for_each(|(old, new)| ted::replace(old.syntax(), new));
104     Some(expanded)
105 }
106
107 // FIXME: It would also be cool to share logic here and in the mbe tests,
108 // which are pretty unreadable at the moment.
109 fn insert_whitespaces(syn: SyntaxNode) -> String {
110     use SyntaxKind::*;
111     let mut res = String::new();
112
113     let mut indent = 0;
114     let mut last: Option<SyntaxKind> = None;
115
116     for event in syn.preorder_with_tokens() {
117         let token = match event {
118             WalkEvent::Enter(NodeOrToken::Token(token)) => token,
119             WalkEvent::Leave(NodeOrToken::Node(node))
120                 if matches!(node.kind(), ATTR | MATCH_ARM | STRUCT | ENUM | UNION | FN | IMPL) =>
121             {
122                 res.push('\n');
123                 res.extend(iter::repeat(" ").take(2 * indent));
124                 continue;
125             }
126             _ => continue,
127         };
128         let is_next = |f: fn(SyntaxKind) -> bool, default| -> bool {
129             token.next_token().map(|it| f(it.kind())).unwrap_or(default)
130         };
131         let is_last =
132             |f: fn(SyntaxKind) -> bool, default| -> bool { last.map(f).unwrap_or(default) };
133
134         match token.kind() {
135             k if is_text(k) && is_next(|it| !it.is_punct(), true) => {
136                 res.push_str(token.text());
137                 res.push(' ');
138             }
139             L_CURLY if is_next(|it| it != R_CURLY, true) => {
140                 indent += 1;
141                 if is_last(is_text, false) {
142                     res.push(' ');
143                 }
144                 res.push_str("{\n");
145                 res.extend(iter::repeat(" ").take(2 * indent));
146             }
147             R_CURLY if is_last(|it| it != L_CURLY, true) => {
148                 indent = indent.saturating_sub(1);
149                 res.push('\n');
150                 res.extend(iter::repeat(" ").take(2 * indent));
151                 res.push_str("}");
152             }
153             R_CURLY => {
154                 res.push_str("}\n");
155                 res.extend(iter::repeat(" ").take(2 * indent));
156             }
157             LIFETIME_IDENT if is_next(|it| it == IDENT, true) => {
158                 res.push_str(token.text());
159                 res.push(' ');
160             }
161             T![;] => {
162                 res.push_str(";\n");
163                 res.extend(iter::repeat(" ").take(2 * indent));
164             }
165             T![->] => res.push_str(" -> "),
166             T![=] => res.push_str(" = "),
167             T![=>] => res.push_str(" => "),
168             _ => res.push_str(token.text()),
169         }
170
171         last = Some(token.kind());
172     }
173
174     return res;
175
176     fn is_text(k: SyntaxKind) -> bool {
177         k.is_keyword() || k.is_literal() || k == IDENT
178     }
179 }
180
181 #[cfg(test)]
182 mod tests {
183     use expect_test::{expect, Expect};
184
185     use crate::fixture;
186
187     #[track_caller]
188     fn check(ra_fixture: &str, expect: Expect) {
189         let (analysis, pos) = fixture::position(ra_fixture);
190         let expansion = analysis.expand_macro(pos).unwrap().unwrap();
191         let actual = format!("{}\n{}", expansion.name, expansion.expansion);
192         expect.assert_eq(&actual);
193     }
194
195     #[test]
196     fn macro_expand_recursive_expansion() {
197         check(
198             r#"
199 macro_rules! bar {
200     () => { fn  b() {} }
201 }
202 macro_rules! foo {
203     () => { bar!(); }
204 }
205 macro_rules! baz {
206     () => { foo!(); }
207 }
208 f$0oo!();
209 "#,
210             expect![[r#"
211                 foo
212                 fn b(){}
213
214             "#]],
215         );
216     }
217
218     #[test]
219     fn macro_expand_multiple_lines() {
220         check(
221             r#"
222 macro_rules! foo {
223     () => {
224         fn some_thing() -> u32 {
225             let a = 0;
226             a + 10
227         }
228     }
229 }
230 f$0oo!();
231         "#,
232             expect![[r#"
233                 foo
234                 fn some_thing() -> u32 {
235                   let a = 0;
236                   a+10
237                 }
238             "#]],
239         );
240     }
241
242     #[test]
243     fn macro_expand_match_ast() {
244         check(
245             r#"
246 macro_rules! match_ast {
247     (match $node:ident { $($tt:tt)* }) => { match_ast!(match ($node) { $($tt)* }) };
248     (match ($node:expr) {
249         $( ast::$ast:ident($it:ident) => $res:block, )*
250         _ => $catch_all:expr $(,)?
251     }) => {{
252         $( if let Some($it) = ast::$ast::cast($node.clone()) $res else )*
253         { $catch_all }
254     }};
255 }
256
257 fn main() {
258     mat$0ch_ast! {
259         match container {
260             ast::TraitDef(it) => {},
261             ast::ImplDef(it) => {},
262             _ => { continue },
263         }
264     }
265 }
266 "#,
267             expect![[r#"
268        match_ast
269        {
270          if let Some(it) = ast::TraitDef::cast(container.clone()){}
271          else if let Some(it) = ast::ImplDef::cast(container.clone()){}
272          else {
273            {
274              continue
275            }
276          }
277        }"#]],
278         );
279     }
280
281     #[test]
282     fn macro_expand_match_ast_inside_let_statement() {
283         check(
284             r#"
285 macro_rules! match_ast {
286     (match $node:ident { $($tt:tt)* }) => { match_ast!(match ($node) { $($tt)* }) };
287     (match ($node:expr) {}) => {{}};
288 }
289
290 fn main() {
291     let p = f(|it| {
292         let res = mat$0ch_ast! { match c {}};
293         Some(res)
294     })?;
295 }
296 "#,
297             expect![[r#"
298                 match_ast
299                 {}
300             "#]],
301         );
302     }
303
304     #[test]
305     fn macro_expand_inner_macro_fail_to_expand() {
306         check(
307             r#"
308 macro_rules! bar {
309     (BAD) => {};
310 }
311 macro_rules! foo {
312     () => {bar!()};
313 }
314
315 fn main() {
316     let res = fo$0o!();
317 }
318 "#,
319             expect![[r#"
320                 foo
321             "#]],
322         );
323     }
324
325     #[test]
326     fn macro_expand_with_dollar_crate() {
327         check(
328             r#"
329 #[macro_export]
330 macro_rules! bar {
331     () => {0};
332 }
333 macro_rules! foo {
334     () => {$crate::bar!()};
335 }
336
337 fn main() {
338     let res = fo$0o!();
339 }
340 "#,
341             expect![[r#"
342                 foo
343                 0 "#]],
344         );
345     }
346
347     #[test]
348     fn macro_expand_derive() {
349         check(
350             r#"
351 #[rustc_builtin_macro]
352 pub macro Clone {}
353
354 #[derive(C$0lone)]
355 struct Foo {}
356 "#,
357             expect![[r#"
358                 Clone
359                 impl< >crate::clone::Clone for Foo< >{}
360
361             "#]],
362         );
363     }
364
365     #[test]
366     fn macro_expand_derive2() {
367         check(
368             r#"
369 #[rustc_builtin_macro]
370 pub macro Clone {}
371 #[rustc_builtin_macro]
372 pub macro Copy {}
373
374 #[derive(Cop$0y)]
375 #[derive(Clone)]
376 struct Foo {}
377 "#,
378             expect![[r#"
379                 Copy
380                 impl< >crate::marker::Copy for Foo< >{}
381
382             "#]],
383         );
384     }
385 }