]> git.lizzy.rs Git - rust.git/blob - crates/mbe/src/expander/matcher.rs
5b58458509fa42446e4da4881bd67985820619e2
[rust.git] / crates / mbe / src / expander / matcher.rs
1 //! FIXME: write short doc here
2
3 use crate::{
4     expander::{Binding, Bindings, Fragment},
5     parser::{Op, RepeatKind, Separator},
6     subtree_source::SubtreeTokenSource,
7     tt_iter::TtIter,
8     ExpandError, MetaTemplate,
9 };
10
11 use super::ExpandResult;
12 use parser::{FragmentKind::*, TreeSink};
13 use syntax::{SmolStr, SyntaxKind};
14 use tt::buffer::{Cursor, TokenBuffer};
15
16 impl Bindings {
17     fn push_optional(&mut self, name: &SmolStr) {
18         // FIXME: Do we have a better way to represent an empty token ?
19         // Insert an empty subtree for empty token
20         let tt = tt::Subtree::default().into();
21         self.inner.insert(name.clone(), Binding::Fragment(Fragment::Tokens(tt)));
22     }
23
24     fn push_empty(&mut self, name: &SmolStr) {
25         self.inner.insert(name.clone(), Binding::Empty);
26     }
27
28     fn push_nested(&mut self, idx: usize, nested: Bindings) -> Result<(), ExpandError> {
29         for (key, value) in nested.inner {
30             if !self.inner.contains_key(&key) {
31                 self.inner.insert(key.clone(), Binding::Nested(Vec::new()));
32             }
33             match self.inner.get_mut(&key) {
34                 Some(Binding::Nested(it)) => {
35                     // insert empty nested bindings before this one
36                     while it.len() < idx {
37                         it.push(Binding::Nested(vec![]));
38                     }
39                     it.push(value);
40                 }
41                 _ => {
42                     return Err(ExpandError::BindingError(format!(
43                         "could not find binding `{}`",
44                         key
45                     )));
46                 }
47             }
48         }
49         Ok(())
50     }
51 }
52
53 macro_rules! err {
54     () => {
55         ExpandError::BindingError(format!(""))
56     };
57     ($($tt:tt)*) => {
58         ExpandError::BindingError(format!($($tt)*))
59     };
60 }
61
62 #[derive(Debug, Default)]
63 pub(super) struct Match {
64     pub(super) bindings: Bindings,
65     /// We currently just keep the first error and count the rest to compare matches.
66     pub(super) err: Option<ExpandError>,
67     pub(super) err_count: usize,
68     /// How many top-level token trees were left to match.
69     pub(super) unmatched_tts: usize,
70 }
71
72 impl Match {
73     pub(super) fn add_err(&mut self, err: ExpandError) {
74         let prev_err = self.err.take();
75         self.err = prev_err.or(Some(err));
76         self.err_count += 1;
77     }
78 }
79
80 // General note: These functions have two channels to return errors, a `Result`
81 // return value and the `&mut Match`. The returned Result is for pattern parsing
82 // errors; if a branch of the macro definition doesn't parse, it doesn't make
83 // sense to try using it. Matching errors are added to the `Match`. It might
84 // make sense to make pattern parsing a separate step?
85
86 pub(super) fn match_(pattern: &MetaTemplate, src: &tt::Subtree) -> Result<Match, ExpandError> {
87     assert!(pattern.delimiter == None);
88
89     let mut res = Match::default();
90     let mut src = TtIter::new(src);
91
92     match_subtree(&mut res, pattern, &mut src)?;
93
94     if src.len() > 0 {
95         res.unmatched_tts += src.len();
96         res.add_err(err!("leftover tokens"));
97     }
98
99     Ok(res)
100 }
101
102 fn match_subtree(
103     res: &mut Match,
104     pattern: &MetaTemplate,
105     src: &mut TtIter,
106 ) -> Result<(), ExpandError> {
107     for op in pattern.iter() {
108         match op.as_ref().map_err(|err| err.clone())? {
109             Op::Leaf(lhs) => {
110                 let rhs = match src.expect_leaf() {
111                     Ok(l) => l,
112                     Err(()) => {
113                         res.add_err(err!("expected leaf: `{}`", lhs));
114                         continue;
115                     }
116                 };
117                 match (lhs, rhs) {
118                     (
119                         tt::Leaf::Punct(tt::Punct { char: lhs, .. }),
120                         tt::Leaf::Punct(tt::Punct { char: rhs, .. }),
121                     ) if lhs == rhs => (),
122                     (
123                         tt::Leaf::Ident(tt::Ident { text: lhs, .. }),
124                         tt::Leaf::Ident(tt::Ident { text: rhs, .. }),
125                     ) if lhs == rhs => (),
126                     (
127                         tt::Leaf::Literal(tt::Literal { text: lhs, .. }),
128                         tt::Leaf::Literal(tt::Literal { text: rhs, .. }),
129                     ) if lhs == rhs => (),
130                     _ => {
131                         res.add_err(ExpandError::UnexpectedToken);
132                     }
133                 }
134             }
135             Op::Subtree(lhs) => {
136                 let rhs = match src.expect_subtree() {
137                     Ok(s) => s,
138                     Err(()) => {
139                         res.add_err(err!("expected subtree"));
140                         continue;
141                     }
142                 };
143                 if lhs.delimiter_kind() != rhs.delimiter_kind() {
144                     res.add_err(err!("mismatched delimiter"));
145                     continue;
146                 }
147                 let mut src = TtIter::new(rhs);
148                 match_subtree(res, lhs, &mut src)?;
149                 if src.len() > 0 {
150                     res.add_err(err!("leftover tokens"));
151                 }
152             }
153             Op::Var { name, kind, .. } => {
154                 let kind = match kind {
155                     Some(k) => k,
156                     None => {
157                         res.add_err(ExpandError::UnexpectedToken);
158                         continue;
159                     }
160                 };
161                 let ExpandResult { value: matched, err: match_err } =
162                     match_meta_var(kind.as_str(), src);
163                 match matched {
164                     Some(fragment) => {
165                         res.bindings.inner.insert(name.clone(), Binding::Fragment(fragment));
166                     }
167                     None if match_err.is_none() => res.bindings.push_optional(name),
168                     _ => {}
169                 }
170                 if let Some(err) = match_err {
171                     res.add_err(err);
172                 }
173             }
174             Op::Repeat { subtree, kind, separator } => {
175                 match_repeat(res, subtree, *kind, separator, src)?;
176             }
177         }
178     }
179     Ok(())
180 }
181
182 impl<'a> TtIter<'a> {
183     fn eat_separator(&mut self, separator: &Separator) -> bool {
184         let mut fork = self.clone();
185         let ok = match separator {
186             Separator::Ident(lhs) => match fork.expect_ident() {
187                 Ok(rhs) => rhs.text == lhs.text,
188                 _ => false,
189             },
190             Separator::Literal(lhs) => match fork.expect_literal() {
191                 Ok(rhs) => match rhs {
192                     tt::Leaf::Literal(rhs) => rhs.text == lhs.text,
193                     tt::Leaf::Ident(rhs) => rhs.text == lhs.text,
194                     tt::Leaf::Punct(_) => false,
195                 },
196                 _ => false,
197             },
198             Separator::Puncts(lhss) => lhss.iter().all(|lhs| match fork.expect_punct() {
199                 Ok(rhs) => rhs.char == lhs.char,
200                 _ => false,
201             }),
202         };
203         if ok {
204             *self = fork;
205         }
206         ok
207     }
208
209     pub(crate) fn expect_tt(&mut self) -> Result<tt::TokenTree, ()> {
210         match self.peek_n(0) {
211             Some(tt::TokenTree::Leaf(tt::Leaf::Punct(punct))) if punct.char == '\'' => {
212                 return self.expect_lifetime();
213             }
214             _ => (),
215         }
216
217         let tt = self.next().ok_or_else(|| ())?.clone();
218         let punct = match tt {
219             tt::TokenTree::Leaf(tt::Leaf::Punct(punct)) if punct.spacing == tt::Spacing::Joint => {
220                 punct
221             }
222             _ => return Ok(tt),
223         };
224
225         let (second, third) = match (self.peek_n(0), self.peek_n(1)) {
226             (
227                 Some(tt::TokenTree::Leaf(tt::Leaf::Punct(p2))),
228                 Some(tt::TokenTree::Leaf(tt::Leaf::Punct(p3))),
229             ) if p2.spacing == tt::Spacing::Joint => (p2.char, Some(p3.char)),
230             (Some(tt::TokenTree::Leaf(tt::Leaf::Punct(p2))), _) => (p2.char, None),
231             _ => return Ok(tt),
232         };
233
234         match (punct.char, second, third) {
235             ('.', '.', Some('.'))
236             | ('.', '.', Some('='))
237             | ('<', '<', Some('='))
238             | ('>', '>', Some('=')) => {
239                 let tt2 = self.next().unwrap().clone();
240                 let tt3 = self.next().unwrap().clone();
241                 Ok(tt::Subtree { delimiter: None, token_trees: vec![tt, tt2, tt3] }.into())
242             }
243             ('-', '=', _)
244             | ('-', '>', _)
245             | (':', ':', _)
246             | ('!', '=', _)
247             | ('.', '.', _)
248             | ('*', '=', _)
249             | ('/', '=', _)
250             | ('&', '&', _)
251             | ('&', '=', _)
252             | ('%', '=', _)
253             | ('^', '=', _)
254             | ('+', '=', _)
255             | ('<', '<', _)
256             | ('<', '=', _)
257             | ('=', '=', _)
258             | ('=', '>', _)
259             | ('>', '=', _)
260             | ('>', '>', _)
261             | ('|', '=', _)
262             | ('|', '|', _) => {
263                 let tt2 = self.next().unwrap().clone();
264                 Ok(tt::Subtree { delimiter: None, token_trees: vec![tt, tt2] }.into())
265             }
266             _ => Ok(tt),
267         }
268     }
269
270     pub(crate) fn expect_lifetime(&mut self) -> Result<tt::TokenTree, ()> {
271         let punct = self.expect_punct()?;
272         if punct.char != '\'' {
273             return Err(());
274         }
275         let ident = self.expect_ident()?;
276
277         Ok(tt::Subtree {
278             delimiter: None,
279             token_trees: vec![
280                 tt::Leaf::Punct(*punct).into(),
281                 tt::Leaf::Ident(ident.clone()).into(),
282             ],
283         }
284         .into())
285     }
286
287     pub(crate) fn expect_fragment(
288         &mut self,
289         fragment_kind: parser::FragmentKind,
290     ) -> ExpandResult<Option<tt::TokenTree>> {
291         pub(crate) struct OffsetTokenSink<'a> {
292             pub(crate) cursor: Cursor<'a>,
293             pub(crate) error: bool,
294         }
295
296         impl<'a> TreeSink for OffsetTokenSink<'a> {
297             fn token(&mut self, kind: SyntaxKind, mut n_tokens: u8) {
298                 if kind == SyntaxKind::LIFETIME_IDENT {
299                     n_tokens = 2;
300                 }
301                 for _ in 0..n_tokens {
302                     self.cursor = self.cursor.bump_subtree();
303                 }
304             }
305             fn start_node(&mut self, _kind: SyntaxKind) {}
306             fn finish_node(&mut self) {}
307             fn error(&mut self, _error: parser::ParseError) {
308                 self.error = true;
309             }
310         }
311
312         let buffer = TokenBuffer::from_tokens(&self.inner.as_slice());
313         let mut src = SubtreeTokenSource::new(&buffer);
314         let mut sink = OffsetTokenSink { cursor: buffer.begin(), error: false };
315
316         parser::parse_fragment(&mut src, &mut sink, fragment_kind);
317
318         let mut err = None;
319         if !sink.cursor.is_root() || sink.error {
320             err = Some(err!("expected {:?}", fragment_kind));
321         }
322
323         let mut curr = buffer.begin();
324         let mut res = vec![];
325
326         if sink.cursor.is_root() {
327             while curr != sink.cursor {
328                 if let Some(token) = curr.token_tree() {
329                     res.push(token);
330                 }
331                 curr = curr.bump();
332             }
333         }
334         self.inner = self.inner.as_slice()[res.len()..].iter();
335         if res.len() == 0 && err.is_none() {
336             err = Some(err!("no tokens consumed"));
337         }
338         let res = match res.len() {
339             1 => Some(res[0].cloned()),
340             0 => None,
341             _ => Some(tt::TokenTree::Subtree(tt::Subtree {
342                 delimiter: None,
343                 token_trees: res.into_iter().map(|it| it.cloned()).collect(),
344             })),
345         };
346         ExpandResult { value: res, err }
347     }
348
349     pub(crate) fn eat_vis(&mut self) -> Option<tt::TokenTree> {
350         let mut fork = self.clone();
351         match fork.expect_fragment(Visibility) {
352             ExpandResult { value: tt, err: None } => {
353                 *self = fork;
354                 tt
355             }
356             ExpandResult { value: _, err: Some(_) } => None,
357         }
358     }
359
360     pub(crate) fn eat_char(&mut self, c: char) -> Option<tt::TokenTree> {
361         let mut fork = self.clone();
362         match fork.expect_char(c) {
363             Ok(_) => {
364                 let tt = self.next().cloned();
365                 *self = fork;
366                 tt
367             }
368             Err(_) => None,
369         }
370     }
371 }
372
373 pub(super) fn match_repeat(
374     res: &mut Match,
375     pattern: &MetaTemplate,
376     kind: RepeatKind,
377     separator: &Option<Separator>,
378     src: &mut TtIter,
379 ) -> Result<(), ExpandError> {
380     // Dirty hack to make macro-expansion terminate.
381     // This should be replaced by a proper macro-by-example implementation
382     let mut limit = 65536;
383     let mut counter = 0;
384
385     for i in 0.. {
386         let mut fork = src.clone();
387
388         if let Some(separator) = &separator {
389             if i != 0 && !fork.eat_separator(separator) {
390                 break;
391             }
392         }
393
394         let mut nested = Match::default();
395         match_subtree(&mut nested, pattern, &mut fork)?;
396         if nested.err.is_none() {
397             limit -= 1;
398             if limit == 0 {
399                 log::warn!(
400                     "match_lhs exceeded repeat pattern limit => {:#?}\n{:#?}\n{:#?}\n{:#?}",
401                     pattern,
402                     src,
403                     kind,
404                     separator
405                 );
406                 break;
407             }
408             *src = fork;
409
410             if let Err(err) = res.bindings.push_nested(counter, nested.bindings) {
411                 res.add_err(err);
412             }
413             counter += 1;
414             if counter == 1 {
415                 if let RepeatKind::ZeroOrOne = kind {
416                     break;
417                 }
418             }
419         } else {
420             break;
421         }
422     }
423
424     match (kind, counter) {
425         (RepeatKind::OneOrMore, 0) => {
426             res.add_err(ExpandError::UnexpectedToken);
427         }
428         (_, 0) => {
429             // Collect all empty variables in subtrees
430             let mut vars = Vec::new();
431             collect_vars(&mut vars, pattern)?;
432             for var in vars {
433                 res.bindings.push_empty(&var)
434             }
435         }
436         _ => (),
437     }
438     Ok(())
439 }
440
441 fn match_meta_var(kind: &str, input: &mut TtIter) -> ExpandResult<Option<Fragment>> {
442     let fragment = match kind {
443         "path" => Path,
444         "expr" => Expr,
445         "ty" => Type,
446         "pat" => Pattern,
447         "stmt" => Statement,
448         "block" => Block,
449         "meta" => MetaItem,
450         "item" => Item,
451         _ => {
452             let tt_result = match kind {
453                 "ident" => input
454                     .expect_ident()
455                     .map(|ident| Some(tt::Leaf::from(ident.clone()).into()))
456                     .map_err(|()| err!("expected ident")),
457                 "tt" => input.expect_tt().map(Some).map_err(|()| err!()),
458                 "lifetime" => input
459                     .expect_lifetime()
460                     .map(|tt| Some(tt))
461                     .map_err(|()| err!("expected lifetime")),
462                 "literal" => {
463                     let neg = input.eat_char('-');
464                     input
465                         .expect_literal()
466                         .map(|literal| {
467                             let lit = tt::Leaf::from(literal.clone());
468                             match neg {
469                                 None => Some(lit.into()),
470                                 Some(neg) => Some(tt::TokenTree::Subtree(tt::Subtree {
471                                     delimiter: None,
472                                     token_trees: vec![neg, lit.into()],
473                                 })),
474                             }
475                         })
476                         .map_err(|()| err!())
477                 }
478                 // `vis` is optional
479                 "vis" => match input.eat_vis() {
480                     Some(vis) => Ok(Some(vis)),
481                     None => Ok(None),
482                 },
483                 _ => Err(ExpandError::UnexpectedToken),
484             };
485             return tt_result.map(|it| it.map(Fragment::Tokens)).into();
486         }
487     };
488     let result = input.expect_fragment(fragment);
489     result.map(|tt| if kind == "expr" { tt.map(Fragment::Ast) } else { tt.map(Fragment::Tokens) })
490 }
491
492 fn collect_vars(buf: &mut Vec<SmolStr>, pattern: &MetaTemplate) -> Result<(), ExpandError> {
493     for op in pattern.iter() {
494         match op.as_ref().map_err(|e| e.clone())? {
495             Op::Var { name, .. } => buf.push(name.clone()),
496             Op::Leaf(_) => (),
497             Op::Subtree(subtree) => collect_vars(buf, subtree)?,
498             Op::Repeat { subtree, .. } => collect_vars(buf, subtree)?,
499         }
500     }
501     Ok(())
502 }