]> git.lizzy.rs Git - rust.git/blob - crates/ide/src/expand_macro.rs
Merge #10015
[rust.git] / crates / ide / src / expand_macro.rs
1 use std::iter;
2
3 use hir::Semantics;
4 use ide_db::{helpers::pick_best_token, RootDatabase};
5 use itertools::Itertools;
6 use syntax::{ast, ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, WalkEvent, T};
7
8 use crate::FilePosition;
9
10 pub struct ExpandedMacro {
11     pub name: String,
12     pub expansion: String,
13 }
14
15 // Feature: Expand Macro Recursively
16 //
17 // Shows the full macro expansion of the macro at current cursor.
18 //
19 // |===
20 // | Editor  | Action Name
21 //
22 // | VS Code | **Rust Analyzer: Expand macro recursively**
23 // |===
24 //
25 // image::https://user-images.githubusercontent.com/48062697/113020648-b3973180-917a-11eb-84a9-ecb921293dc5.gif[]
26 pub(crate) fn expand_macro(db: &RootDatabase, position: FilePosition) -> Option<ExpandedMacro> {
27     let sema = Semantics::new(db);
28     let file = sema.parse(position.file_id);
29
30     let tok = pick_best_token(file.syntax().token_at_offset(position.offset), |kind| match kind {
31         SyntaxKind::IDENT => 1,
32         _ => 0,
33     })?;
34     let descended = sema.descend_into_macros(tok.clone());
35     if let Some(attr) = descended.ancestors().find_map(ast::Attr::cast) {
36         if let Some((path, tt)) = attr.as_simple_call() {
37             if path == "derive" {
38                 let mut tt = tt.syntax().children_with_tokens().skip(1).join("");
39                 tt.pop();
40                 let expansions = sema.expand_derive_macro(&attr)?;
41                 return Some(ExpandedMacro {
42                     name: tt,
43                     expansion: expansions.into_iter().map(insert_whitespaces).join(""),
44                 });
45             }
46         }
47     }
48     let mut expanded = None;
49     let mut name = None;
50     for node in tok.ancestors() {
51         if let Some(item) = ast::Item::cast(node.clone()) {
52             if let Some(def) = sema.resolve_attr_macro_call(&item) {
53                 name = def.name(db).map(|name| name.to_string());
54                 expanded = expand_attr_macro_recur(&sema, &item);
55                 break;
56             }
57         }
58         if let Some(mac) = ast::MacroCall::cast(node) {
59             name = Some(mac.path()?.segment()?.name_ref()?.to_string());
60             expanded = expand_macro_recur(&sema, &mac);
61             break;
62         }
63     }
64
65     // FIXME:
66     // macro expansion may lose all white space information
67     // But we hope someday we can use ra_fmt for that
68     let expansion = insert_whitespaces(expanded?);
69     Some(ExpandedMacro { name: name.unwrap_or_else(|| "???".to_owned()), expansion })
70 }
71
72 fn expand_macro_recur(
73     sema: &Semantics<RootDatabase>,
74     macro_call: &ast::MacroCall,
75 ) -> Option<SyntaxNode> {
76     let expanded = sema.expand(macro_call)?.clone_for_update();
77     expand(sema, expanded, ast::MacroCall::cast, expand_macro_recur)
78 }
79
80 fn expand_attr_macro_recur(sema: &Semantics<RootDatabase>, item: &ast::Item) -> Option<SyntaxNode> {
81     let expanded = sema.expand_attr_macro(item)?.clone_for_update();
82     expand(sema, expanded, ast::Item::cast, expand_attr_macro_recur)
83 }
84
85 fn expand<T: AstNode>(
86     sema: &Semantics<RootDatabase>,
87     expanded: SyntaxNode,
88     f: impl FnMut(SyntaxNode) -> Option<T>,
89     exp: impl Fn(&Semantics<RootDatabase>, &T) -> Option<SyntaxNode>,
90 ) -> Option<SyntaxNode> {
91     let children = expanded.descendants().filter_map(f);
92     let mut replacements = Vec::new();
93
94     for child in children {
95         if let Some(new_node) = exp(sema, &child) {
96             // check if the whole original syntax is replaced
97             if expanded == *child.syntax() {
98                 return Some(new_node);
99             }
100             replacements.push((child, new_node));
101         }
102     }
103
104     replacements.into_iter().rev().for_each(|(old, new)| ted::replace(old.syntax(), new));
105     Some(expanded)
106 }
107
108 // FIXME: It would also be cool to share logic here and in the mbe tests,
109 // which are pretty unreadable at the moment.
110 fn insert_whitespaces(syn: SyntaxNode) -> String {
111     use SyntaxKind::*;
112     let mut res = String::new();
113
114     let mut indent = 0;
115     let mut last: Option<SyntaxKind> = None;
116
117     for event in syn.preorder_with_tokens() {
118         let token = match event {
119             WalkEvent::Enter(NodeOrToken::Token(token)) => token,
120             WalkEvent::Leave(NodeOrToken::Node(node))
121                 if matches!(node.kind(), ATTR | MATCH_ARM | STRUCT | ENUM | UNION | FN | IMPL) =>
122             {
123                 res.push('\n');
124                 res.extend(iter::repeat(" ").take(2 * indent));
125                 continue;
126             }
127             _ => continue,
128         };
129         let is_next = |f: fn(SyntaxKind) -> bool, default| -> bool {
130             token.next_token().map(|it| f(it.kind())).unwrap_or(default)
131         };
132         let is_last =
133             |f: fn(SyntaxKind) -> bool, default| -> bool { last.map(f).unwrap_or(default) };
134
135         match token.kind() {
136             k if is_text(k) && is_next(|it| !it.is_punct(), true) => {
137                 res.push_str(token.text());
138                 res.push(' ');
139             }
140             L_CURLY if is_next(|it| it != R_CURLY, true) => {
141                 indent += 1;
142                 if is_last(is_text, false) {
143                     res.push(' ');
144                 }
145                 res.push_str("{\n");
146                 res.extend(iter::repeat(" ").take(2 * indent));
147             }
148             R_CURLY if is_last(|it| it != L_CURLY, true) => {
149                 indent = indent.saturating_sub(1);
150                 res.push('\n');
151                 res.extend(iter::repeat(" ").take(2 * indent));
152                 res.push_str("}");
153             }
154             R_CURLY => {
155                 res.push_str("}\n");
156                 res.extend(iter::repeat(" ").take(2 * indent));
157             }
158             LIFETIME_IDENT if is_next(|it| it == IDENT, true) => {
159                 res.push_str(token.text());
160                 res.push(' ');
161             }
162             T![;] => {
163                 res.push_str(";\n");
164                 res.extend(iter::repeat(" ").take(2 * indent));
165             }
166             T![->] => res.push_str(" -> "),
167             T![=] => res.push_str(" = "),
168             T![=>] => res.push_str(" => "),
169             _ => res.push_str(token.text()),
170         }
171
172         last = Some(token.kind());
173     }
174
175     return res;
176
177     fn is_text(k: SyntaxKind) -> bool {
178         k.is_keyword() || k.is_literal() || k == IDENT
179     }
180 }
181
182 #[cfg(test)]
183 mod tests {
184     use expect_test::{expect, Expect};
185
186     use crate::fixture;
187
188     #[track_caller]
189     fn check(ra_fixture: &str, expect: Expect) {
190         let (analysis, pos) = fixture::position(ra_fixture);
191         let expansion = analysis.expand_macro(pos).unwrap().unwrap();
192         let actual = format!("{}\n{}", expansion.name, expansion.expansion);
193         expect.assert_eq(&actual);
194     }
195
196     #[test]
197     fn macro_expand_recursive_expansion() {
198         check(
199             r#"
200 macro_rules! bar {
201     () => { fn  b() {} }
202 }
203 macro_rules! foo {
204     () => { bar!(); }
205 }
206 macro_rules! baz {
207     () => { foo!(); }
208 }
209 f$0oo!();
210 "#,
211             expect![[r#"
212                 foo
213                 fn b(){}
214
215             "#]],
216         );
217     }
218
219     #[test]
220     fn macro_expand_multiple_lines() {
221         check(
222             r#"
223 macro_rules! foo {
224     () => {
225         fn some_thing() -> u32 {
226             let a = 0;
227             a + 10
228         }
229     }
230 }
231 f$0oo!();
232         "#,
233             expect![[r#"
234                 foo
235                 fn some_thing() -> u32 {
236                   let a = 0;
237                   a+10
238                 }
239             "#]],
240         );
241     }
242
243     #[test]
244     fn macro_expand_match_ast() {
245         check(
246             r#"
247 macro_rules! match_ast {
248     (match $node:ident { $($tt:tt)* }) => { match_ast!(match ($node) { $($tt)* }) };
249     (match ($node:expr) {
250         $( ast::$ast:ident($it:ident) => $res:block, )*
251         _ => $catch_all:expr $(,)?
252     }) => {{
253         $( if let Some($it) = ast::$ast::cast($node.clone()) $res else )*
254         { $catch_all }
255     }};
256 }
257
258 fn main() {
259     mat$0ch_ast! {
260         match container {
261             ast::TraitDef(it) => {},
262             ast::ImplDef(it) => {},
263             _ => { continue },
264         }
265     }
266 }
267 "#,
268             expect![[r#"
269        match_ast
270        {
271          if let Some(it) = ast::TraitDef::cast(container.clone()){}
272          else if let Some(it) = ast::ImplDef::cast(container.clone()){}
273          else {
274            {
275              continue
276            }
277          }
278        }"#]],
279         );
280     }
281
282     #[test]
283     fn macro_expand_match_ast_inside_let_statement() {
284         check(
285             r#"
286 macro_rules! match_ast {
287     (match $node:ident { $($tt:tt)* }) => { match_ast!(match ($node) { $($tt)* }) };
288     (match ($node:expr) {}) => {{}};
289 }
290
291 fn main() {
292     let p = f(|it| {
293         let res = mat$0ch_ast! { match c {}};
294         Some(res)
295     })?;
296 }
297 "#,
298             expect![[r#"
299                 match_ast
300                 {}
301             "#]],
302         );
303     }
304
305     #[test]
306     fn macro_expand_inner_macro_fail_to_expand() {
307         check(
308             r#"
309 macro_rules! bar {
310     (BAD) => {};
311 }
312 macro_rules! foo {
313     () => {bar!()};
314 }
315
316 fn main() {
317     let res = fo$0o!();
318 }
319 "#,
320             expect![[r#"
321                 foo
322             "#]],
323         );
324     }
325
326     #[test]
327     fn macro_expand_with_dollar_crate() {
328         check(
329             r#"
330 #[macro_export]
331 macro_rules! bar {
332     () => {0};
333 }
334 macro_rules! foo {
335     () => {$crate::bar!()};
336 }
337
338 fn main() {
339     let res = fo$0o!();
340 }
341 "#,
342             expect![[r#"
343                 foo
344                 0 "#]],
345         );
346     }
347
348     #[test]
349     fn macro_expand_derive() {
350         check(
351             r#"
352 #[rustc_builtin_macro]
353 pub macro Clone {}
354
355 #[derive(C$0lone)]
356 struct Foo {}
357 "#,
358             expect![[r#"
359                 Clone
360                 impl< >crate::clone::Clone for Foo< >{}
361
362             "#]],
363         );
364     }
365
366     #[test]
367     fn macro_expand_derive2() {
368         check(
369             r#"
370 #[rustc_builtin_macro]
371 pub macro Clone {}
372 #[rustc_builtin_macro]
373 pub macro Copy {}
374
375 #[derive(Cop$0y)]
376 #[derive(Clone)]
377 struct Foo {}
378 "#,
379             expect![[r#"
380                 Copy
381                 impl< >crate::marker::Copy for Foo< >{}
382
383             "#]],
384         );
385     }
386
387     #[test]
388     fn macro_expand_derive_multi() {
389         check(
390             r#"
391 #[rustc_builtin_macro]
392 pub macro Clone {}
393 #[rustc_builtin_macro]
394 pub macro Copy {}
395
396 #[derive(Cop$0y, Clone)]
397 struct Foo {}
398 "#,
399             expect![[r#"
400                 Copy, Clone
401                 impl< >crate::marker::Copy for Foo< >{}
402
403                 impl< >crate::clone::Clone for Foo< >{}
404
405             "#]],
406         );
407     }
408 }