]> git.lizzy.rs Git - rust.git/blob - crates/ra_hir_expand/src/builtin_derive.rs
lsp-types 0.74
[rust.git] / crates / ra_hir_expand / src / builtin_derive.rs
1 //! Builtin derives.
2
3 use log::debug;
4
5 use ra_parser::FragmentKind;
6 use ra_syntax::{
7     ast::{self, AstNode, ModuleItemOwner, NameOwner, TypeParamsOwner},
8     match_ast,
9 };
10
11 use crate::db::AstDatabase;
12 use crate::{name, quote, LazyMacroId, MacroDefId, MacroDefKind};
13
14 macro_rules! register_builtin {
15     ( $($trait:ident => $expand:ident),* ) => {
16         #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
17         pub enum BuiltinDeriveExpander {
18             $($trait),*
19         }
20
21         impl BuiltinDeriveExpander {
22             pub fn expand(
23                 &self,
24                 db: &dyn AstDatabase,
25                 id: LazyMacroId,
26                 tt: &tt::Subtree,
27             ) -> Result<tt::Subtree, mbe::ExpandError> {
28                 let expander = match *self {
29                     $( BuiltinDeriveExpander::$trait => $expand, )*
30                 };
31                 expander(db, id, tt)
32             }
33         }
34
35         pub fn find_builtin_derive(ident: &name::Name) -> Option<MacroDefId> {
36             let kind = match ident {
37                 $( id if id == &name::name![$trait] => BuiltinDeriveExpander::$trait, )*
38                  _ => return None,
39             };
40
41             Some(MacroDefId { krate: None, ast_id: None, kind: MacroDefKind::BuiltInDerive(kind) })
42         }
43     };
44 }
45
46 register_builtin! {
47     Copy => copy_expand,
48     Clone => clone_expand,
49     Default => default_expand,
50     Debug => debug_expand,
51     Hash => hash_expand,
52     Ord => ord_expand,
53     PartialOrd => partial_ord_expand,
54     Eq => eq_expand,
55     PartialEq => partial_eq_expand
56 }
57
58 struct BasicAdtInfo {
59     name: tt::Ident,
60     type_params: usize,
61 }
62
63 fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, mbe::ExpandError> {
64     let (parsed, token_map) = mbe::token_tree_to_syntax_node(tt, FragmentKind::Items)?; // FragmentKind::Items doesn't parse attrs?
65     let macro_items = ast::MacroItems::cast(parsed.syntax_node()).ok_or_else(|| {
66         debug!("derive node didn't parse");
67         mbe::ExpandError::UnexpectedToken
68     })?;
69     let item = macro_items.items().next().ok_or_else(|| {
70         debug!("no module item parsed");
71         mbe::ExpandError::NoMatchingRule
72     })?;
73     let node = item.syntax();
74     let (name, params) = match_ast! {
75         match node {
76             ast::StructDef(it) => (it.name(), it.type_param_list()),
77             ast::EnumDef(it) => (it.name(), it.type_param_list()),
78             ast::UnionDef(it) => (it.name(), it.type_param_list()),
79             _ => {
80                 debug!("unexpected node is {:?}", node);
81                 return Err(mbe::ExpandError::ConversionError)
82             },
83         }
84     };
85     let name = name.ok_or_else(|| {
86         debug!("parsed item has no name");
87         mbe::ExpandError::NoMatchingRule
88     })?;
89     let name_token_id = token_map.token_by_range(name.syntax().text_range()).ok_or_else(|| {
90         debug!("name token not found");
91         mbe::ExpandError::ConversionError
92     })?;
93     let name_token = tt::Ident { id: name_token_id, text: name.text().clone() };
94     let type_params = params.map_or(0, |type_param_list| type_param_list.type_params().count());
95     Ok(BasicAdtInfo { name: name_token, type_params })
96 }
97
98 fn make_type_args(n: usize, bound: Vec<tt::TokenTree>) -> Vec<tt::TokenTree> {
99     let mut result = Vec::<tt::TokenTree>::new();
100     result.push(
101         tt::Leaf::Punct(tt::Punct {
102             char: '<',
103             spacing: tt::Spacing::Alone,
104             id: tt::TokenId::unspecified(),
105         })
106         .into(),
107     );
108     for i in 0..n {
109         if i > 0 {
110             result.push(
111                 tt::Leaf::Punct(tt::Punct {
112                     char: ',',
113                     spacing: tt::Spacing::Alone,
114                     id: tt::TokenId::unspecified(),
115                 })
116                 .into(),
117             );
118         }
119         result.push(
120             tt::Leaf::Ident(tt::Ident {
121                 id: tt::TokenId::unspecified(),
122                 text: format!("T{}", i).into(),
123             })
124             .into(),
125         );
126         result.extend(bound.iter().cloned());
127     }
128     result.push(
129         tt::Leaf::Punct(tt::Punct {
130             char: '>',
131             spacing: tt::Spacing::Alone,
132             id: tt::TokenId::unspecified(),
133         })
134         .into(),
135     );
136     result
137 }
138
139 fn expand_simple_derive(
140     tt: &tt::Subtree,
141     trait_path: tt::Subtree,
142 ) -> Result<tt::Subtree, mbe::ExpandError> {
143     let info = parse_adt(tt)?;
144     let name = info.name;
145     let trait_path_clone = trait_path.token_trees.clone();
146     let bound = (quote! { : ##trait_path_clone }).token_trees;
147     let type_params = make_type_args(info.type_params, bound);
148     let type_args = make_type_args(info.type_params, Vec::new());
149     let trait_path = trait_path.token_trees;
150     let expanded = quote! {
151         impl ##type_params ##trait_path for #name ##type_args {}
152     };
153     Ok(expanded)
154 }
155
156 fn copy_expand(
157     _db: &dyn AstDatabase,
158     _id: LazyMacroId,
159     tt: &tt::Subtree,
160 ) -> Result<tt::Subtree, mbe::ExpandError> {
161     expand_simple_derive(tt, quote! { std::marker::Copy })
162 }
163
164 fn clone_expand(
165     _db: &dyn AstDatabase,
166     _id: LazyMacroId,
167     tt: &tt::Subtree,
168 ) -> Result<tt::Subtree, mbe::ExpandError> {
169     expand_simple_derive(tt, quote! { std::clone::Clone })
170 }
171
172 fn default_expand(
173     _db: &dyn AstDatabase,
174     _id: LazyMacroId,
175     tt: &tt::Subtree,
176 ) -> Result<tt::Subtree, mbe::ExpandError> {
177     expand_simple_derive(tt, quote! { std::default::Default })
178 }
179
180 fn debug_expand(
181     _db: &dyn AstDatabase,
182     _id: LazyMacroId,
183     tt: &tt::Subtree,
184 ) -> Result<tt::Subtree, mbe::ExpandError> {
185     expand_simple_derive(tt, quote! { std::fmt::Debug })
186 }
187
188 fn hash_expand(
189     _db: &dyn AstDatabase,
190     _id: LazyMacroId,
191     tt: &tt::Subtree,
192 ) -> Result<tt::Subtree, mbe::ExpandError> {
193     expand_simple_derive(tt, quote! { std::hash::Hash })
194 }
195
196 fn eq_expand(
197     _db: &dyn AstDatabase,
198     _id: LazyMacroId,
199     tt: &tt::Subtree,
200 ) -> Result<tt::Subtree, mbe::ExpandError> {
201     expand_simple_derive(tt, quote! { std::cmp::Eq })
202 }
203
204 fn partial_eq_expand(
205     _db: &dyn AstDatabase,
206     _id: LazyMacroId,
207     tt: &tt::Subtree,
208 ) -> Result<tt::Subtree, mbe::ExpandError> {
209     expand_simple_derive(tt, quote! { std::cmp::PartialEq })
210 }
211
212 fn ord_expand(
213     _db: &dyn AstDatabase,
214     _id: LazyMacroId,
215     tt: &tt::Subtree,
216 ) -> Result<tt::Subtree, mbe::ExpandError> {
217     expand_simple_derive(tt, quote! { std::cmp::Ord })
218 }
219
220 fn partial_ord_expand(
221     _db: &dyn AstDatabase,
222     _id: LazyMacroId,
223     tt: &tt::Subtree,
224 ) -> Result<tt::Subtree, mbe::ExpandError> {
225     expand_simple_derive(tt, quote! { std::cmp::PartialOrd })
226 }
227
228 #[cfg(test)]
229 mod tests {
230     use super::*;
231     use crate::{test_db::TestDB, AstId, MacroCallId, MacroCallKind, MacroCallLoc};
232     use name::{known, Name};
233     use ra_db::{fixture::WithFixture, SourceDatabase};
234
235     fn expand_builtin_derive(s: &str, name: Name) -> String {
236         let def = find_builtin_derive(&name).unwrap();
237
238         let (db, file_id) = TestDB::with_single_file(&s);
239         let parsed = db.parse(file_id);
240         let items: Vec<_> =
241             parsed.syntax_node().descendants().filter_map(ast::ModuleItem::cast).collect();
242
243         let ast_id_map = db.ast_id_map(file_id.into());
244
245         let attr_id = AstId::new(file_id.into(), ast_id_map.ast_id(&items[0]));
246
247         let loc = MacroCallLoc { def, kind: MacroCallKind::Attr(attr_id, name.to_string()) };
248
249         let id: MacroCallId = db.intern_macro(loc).into();
250         let parsed = db.parse_or_expand(id.as_file()).unwrap();
251
252         // FIXME text() for syntax nodes parsed from token tree looks weird
253         // because there's no whitespace, see below
254         parsed.text().to_string()
255     }
256
257     #[test]
258     fn test_copy_expand_simple() {
259         let expanded = expand_builtin_derive(
260             r#"
261         #[derive(Copy)]
262         struct Foo;
263 "#,
264             known::Copy,
265         );
266
267         assert_eq!(expanded, "impl< >std::marker::CopyforFoo< >{}");
268     }
269
270     #[test]
271     fn test_copy_expand_with_type_params() {
272         let expanded = expand_builtin_derive(
273             r#"
274         #[derive(Copy)]
275         struct Foo<A, B>;
276 "#,
277             known::Copy,
278         );
279
280         assert_eq!(
281             expanded,
282             "impl<T0:std::marker::Copy,T1:std::marker::Copy>std::marker::CopyforFoo<T0,T1>{}"
283         );
284     }
285
286     #[test]
287     fn test_copy_expand_with_lifetimes() {
288         let expanded = expand_builtin_derive(
289             r#"
290         #[derive(Copy)]
291         struct Foo<A, B, 'a, 'b>;
292 "#,
293             known::Copy,
294         );
295
296         // We currently just ignore lifetimes
297
298         assert_eq!(
299             expanded,
300             "impl<T0:std::marker::Copy,T1:std::marker::Copy>std::marker::CopyforFoo<T0,T1>{}"
301         );
302     }
303
304     #[test]
305     fn test_clone_expand() {
306         let expanded = expand_builtin_derive(
307             r#"
308         #[derive(Clone)]
309         struct Foo<A, B>;
310 "#,
311             known::Clone,
312         );
313
314         assert_eq!(
315             expanded,
316             "impl<T0:std::clone::Clone,T1:std::clone::Clone>std::clone::CloneforFoo<T0,T1>{}"
317         );
318     }
319 }