]> git.lizzy.rs Git - rust.git/blob - crates/mbe/src/lib.rs
Merge #11369
[rust.git] / crates / mbe / src / lib.rs
1 //! `mbe` (short for Macro By Example) crate contains code for handling
2 //! `macro_rules` macros. It uses `TokenTree` (from `tt` package) as the
3 //! interface, although it contains some code to bridge `SyntaxNode`s and
4 //! `TokenTree`s as well!
5 //!
6 //! The tes for this functionality live in another crate:
7 //! `hir_def::macro_expansion_tests::mbe`.
8
9 mod parser;
10 mod expander;
11 mod syntax_bridge;
12 mod tt_iter;
13 mod to_parser_input;
14
15 #[cfg(test)]
16 mod benchmark;
17 mod token_map;
18
19 use std::fmt;
20
21 use crate::{
22     parser::{MetaTemplate, Op},
23     tt_iter::TtIter,
24 };
25
26 // FIXME: we probably should re-think  `token_tree_to_syntax_node` interfaces
27 pub use ::parser::TopEntryPoint;
28 pub use tt::{Delimiter, DelimiterKind, Punct};
29
30 pub use crate::{
31     syntax_bridge::{
32         parse_exprs_with_sep, parse_to_token_tree, syntax_node_to_token_tree,
33         syntax_node_to_token_tree_with_modifications, token_tree_to_syntax_node, SyntheticToken,
34         SyntheticTokenId,
35     },
36     token_map::TokenMap,
37 };
38
39 #[derive(Debug, PartialEq, Eq, Clone)]
40 pub enum ParseError {
41     UnexpectedToken(Box<str>),
42     Expected(Box<str>),
43     InvalidRepeat,
44     RepetitionEmptyTokenTree,
45 }
46
47 impl ParseError {
48     fn expected(e: &str) -> ParseError {
49         ParseError::Expected(e.into())
50     }
51
52     fn unexpected(e: &str) -> ParseError {
53         ParseError::UnexpectedToken(e.into())
54     }
55 }
56
57 impl fmt::Display for ParseError {
58     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59         match self {
60             ParseError::UnexpectedToken(it) => f.write_str(it),
61             ParseError::Expected(it) => f.write_str(it),
62             ParseError::InvalidRepeat => f.write_str("invalid repeat"),
63             ParseError::RepetitionEmptyTokenTree => f.write_str("empty token tree in repetition"),
64         }
65     }
66 }
67
68 #[derive(Debug, PartialEq, Eq, Clone)]
69 pub enum ExpandError {
70     NoMatchingRule,
71     UnexpectedToken,
72     BindingError(Box<str>),
73     ConversionError,
74     // FIXME: no way mbe should know about proc macros.
75     UnresolvedProcMacro,
76     Other(Box<str>),
77 }
78
79 impl ExpandError {
80     fn binding_error(e: &str) -> ExpandError {
81         ExpandError::BindingError(e.into())
82     }
83 }
84
85 impl fmt::Display for ExpandError {
86     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87         match self {
88             ExpandError::NoMatchingRule => f.write_str("no rule matches input tokens"),
89             ExpandError::UnexpectedToken => f.write_str("unexpected token in input"),
90             ExpandError::BindingError(e) => f.write_str(e),
91             ExpandError::ConversionError => f.write_str("could not convert tokens"),
92             ExpandError::UnresolvedProcMacro => f.write_str("unresolved proc macro"),
93             ExpandError::Other(e) => f.write_str(e),
94         }
95     }
96 }
97
98 /// This struct contains AST for a single `macro_rules` definition. What might
99 /// be very confusing is that AST has almost exactly the same shape as
100 /// `tt::TokenTree`, but there's a crucial difference: in macro rules, `$ident`
101 /// and `$()*` have special meaning (see `Var` and `Repeat` data structures)
102 #[derive(Clone, Debug, PartialEq, Eq)]
103 pub struct DeclarativeMacro {
104     rules: Vec<Rule>,
105     /// Highest id of the token we have in TokenMap
106     shift: Shift,
107 }
108
109 #[derive(Clone, Debug, PartialEq, Eq)]
110 struct Rule {
111     lhs: MetaTemplate,
112     rhs: MetaTemplate,
113 }
114
115 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
116 pub struct Shift(u32);
117
118 impl Shift {
119     pub fn new(tt: &tt::Subtree) -> Shift {
120         // Note that TokenId is started from zero,
121         // We have to add 1 to prevent duplication.
122         let value = max_id(tt).map_or(0, |it| it + 1);
123         return Shift(value);
124
125         // Find the max token id inside a subtree
126         fn max_id(subtree: &tt::Subtree) -> Option<u32> {
127             let filter = |tt: &_| match tt {
128                 tt::TokenTree::Subtree(subtree) => {
129                     let tree_id = max_id(subtree);
130                     match subtree.delimiter {
131                         Some(it) if it.id != tt::TokenId::unspecified() => {
132                             Some(tree_id.map_or(it.id.0, |t| t.max(it.id.0)))
133                         }
134                         _ => tree_id,
135                     }
136                 }
137                 tt::TokenTree::Leaf(leaf) => {
138                     let &(tt::Leaf::Ident(tt::Ident { id, .. })
139                     | tt::Leaf::Punct(tt::Punct { id, .. })
140                     | tt::Leaf::Literal(tt::Literal { id, .. })) = leaf;
141
142                     (id != tt::TokenId::unspecified()).then(|| id.0)
143                 }
144             };
145             subtree.token_trees.iter().filter_map(filter).max()
146         }
147     }
148
149     /// Shift given TokenTree token id
150     pub fn shift_all(self, tt: &mut tt::Subtree) {
151         for t in &mut tt.token_trees {
152             match t {
153                 tt::TokenTree::Leaf(
154                     tt::Leaf::Ident(tt::Ident { id, .. })
155                     | tt::Leaf::Punct(tt::Punct { id, .. })
156                     | tt::Leaf::Literal(tt::Literal { id, .. }),
157                 ) => *id = self.shift(*id),
158                 tt::TokenTree::Subtree(tt) => {
159                     if let Some(it) = tt.delimiter.as_mut() {
160                         it.id = self.shift(it.id);
161                     }
162                     self.shift_all(tt)
163                 }
164             }
165         }
166     }
167
168     pub fn shift(self, id: tt::TokenId) -> tt::TokenId {
169         if id == tt::TokenId::unspecified() {
170             id
171         } else {
172             tt::TokenId(id.0 + self.0)
173         }
174     }
175
176     pub fn unshift(self, id: tt::TokenId) -> Option<tt::TokenId> {
177         id.0.checked_sub(self.0).map(tt::TokenId)
178     }
179 }
180
181 #[derive(Debug, Eq, PartialEq)]
182 pub enum Origin {
183     Def,
184     Call,
185 }
186
187 impl DeclarativeMacro {
188     /// The old, `macro_rules! m {}` flavor.
189     pub fn parse_macro_rules(tt: &tt::Subtree) -> Result<DeclarativeMacro, ParseError> {
190         // Note: this parsing can be implemented using mbe machinery itself, by
191         // matching against `$($lhs:tt => $rhs:tt);*` pattern, but implementing
192         // manually seems easier.
193         let mut src = TtIter::new(tt);
194         let mut rules = Vec::new();
195         while src.len() > 0 {
196             let rule = Rule::parse(&mut src, true)?;
197             rules.push(rule);
198             if let Err(()) = src.expect_char(';') {
199                 if src.len() > 0 {
200                     return Err(ParseError::expected("expected `;`"));
201                 }
202                 break;
203             }
204         }
205
206         for Rule { lhs, .. } in &rules {
207             validate(lhs)?;
208         }
209
210         Ok(DeclarativeMacro { rules, shift: Shift::new(tt) })
211     }
212
213     /// The new, unstable `macro m {}` flavor.
214     pub fn parse_macro2(tt: &tt::Subtree) -> Result<DeclarativeMacro, ParseError> {
215         let mut src = TtIter::new(tt);
216         let mut rules = Vec::new();
217
218         if Some(tt::DelimiterKind::Brace) == tt.delimiter_kind() {
219             cov_mark::hit!(parse_macro_def_rules);
220             while src.len() > 0 {
221                 let rule = Rule::parse(&mut src, true)?;
222                 rules.push(rule);
223                 if let Err(()) = src.expect_any_char(&[';', ',']) {
224                     if src.len() > 0 {
225                         return Err(ParseError::expected("expected `;` or `,` to delimit rules"));
226                     }
227                     break;
228                 }
229             }
230         } else {
231             cov_mark::hit!(parse_macro_def_simple);
232             let rule = Rule::parse(&mut src, false)?;
233             if src.len() != 0 {
234                 return Err(ParseError::expected("remaining tokens in macro def"));
235             }
236             rules.push(rule);
237         }
238
239         for Rule { lhs, .. } in &rules {
240             validate(lhs)?;
241         }
242
243         Ok(DeclarativeMacro { rules, shift: Shift::new(tt) })
244     }
245
246     pub fn expand(&self, tt: &tt::Subtree) -> ExpandResult<tt::Subtree> {
247         // apply shift
248         let mut tt = tt.clone();
249         self.shift.shift_all(&mut tt);
250         expander::expand_rules(&self.rules, &tt)
251     }
252
253     pub fn map_id_down(&self, id: tt::TokenId) -> tt::TokenId {
254         self.shift.shift(id)
255     }
256
257     pub fn map_id_up(&self, id: tt::TokenId) -> (tt::TokenId, Origin) {
258         match self.shift.unshift(id) {
259             Some(id) => (id, Origin::Call),
260             None => (id, Origin::Def),
261         }
262     }
263
264     pub fn shift(&self) -> Shift {
265         self.shift
266     }
267 }
268
269 impl Rule {
270     fn parse(src: &mut TtIter, expect_arrow: bool) -> Result<Self, ParseError> {
271         let lhs = src.expect_subtree().map_err(|()| ParseError::expected("expected subtree"))?;
272         if expect_arrow {
273             src.expect_char('=').map_err(|()| ParseError::expected("expected `=`"))?;
274             src.expect_char('>').map_err(|()| ParseError::expected("expected `>`"))?;
275         }
276         let rhs = src.expect_subtree().map_err(|()| ParseError::expected("expected subtree"))?;
277
278         let lhs = MetaTemplate::parse_pattern(lhs)?;
279         let rhs = MetaTemplate::parse_template(rhs)?;
280
281         Ok(crate::Rule { lhs, rhs })
282     }
283 }
284
285 fn validate(pattern: &MetaTemplate) -> Result<(), ParseError> {
286     for op in pattern.iter() {
287         match op {
288             Op::Subtree { tokens, .. } => validate(tokens)?,
289             Op::Repeat { tokens: subtree, separator, .. } => {
290                 // Checks that no repetition which could match an empty token
291                 // https://github.com/rust-lang/rust/blob/a58b1ed44f5e06976de2bdc4d7dc81c36a96934f/src/librustc_expand/mbe/macro_rules.rs#L558
292                 let lsh_is_empty_seq = separator.is_none() && subtree.iter().all(|child_op| {
293                     match child_op {
294                         // vis is optional
295                         Op::Var { kind: Some(kind), .. } => kind == "vis",
296                         Op::Repeat {
297                             kind: parser::RepeatKind::ZeroOrMore | parser::RepeatKind::ZeroOrOne,
298                             ..
299                         } => true,
300                         _ => false,
301                     }
302                 });
303                 if lsh_is_empty_seq {
304                     return Err(ParseError::RepetitionEmptyTokenTree);
305                 }
306                 validate(subtree)?
307             }
308             _ => (),
309         }
310     }
311     Ok(())
312 }
313
314 #[derive(Debug, Clone, Eq, PartialEq)]
315 pub struct ExpandResult<T> {
316     pub value: T,
317     pub err: Option<ExpandError>,
318 }
319
320 impl<T> ExpandResult<T> {
321     pub fn ok(value: T) -> Self {
322         Self { value, err: None }
323     }
324
325     pub fn only_err(err: ExpandError) -> Self
326     where
327         T: Default,
328     {
329         Self { value: Default::default(), err: Some(err) }
330     }
331
332     pub fn str_err(err: String) -> Self
333     where
334         T: Default,
335     {
336         Self::only_err(ExpandError::Other(err.into()))
337     }
338
339     pub fn map<U>(self, f: impl FnOnce(T) -> U) -> ExpandResult<U> {
340         ExpandResult { value: f(self.value), err: self.err }
341     }
342
343     pub fn result(self) -> Result<T, ExpandError> {
344         self.err.map_or(Ok(self.value), Err)
345     }
346 }
347
348 impl<T: Default> From<Result<T, ExpandError>> for ExpandResult<T> {
349     fn from(result: Result<T, ExpandError>) -> Self {
350         result.map_or_else(Self::only_err, Self::ok)
351     }
352 }