]> git.lizzy.rs Git - rust.git/blob - crates/syntax/src/ast/node_ext.rs
2aa472fb494a3774259b9c2e056a72d6845449d5
[rust.git] / crates / syntax / src / ast / node_ext.rs
1 //! Various extension methods to ast Nodes, which are hard to code-generate.
2 //! Extensions for various expressions live in a sibling `expr_extensions` module.
3
4 use std::fmt;
5
6 use ast::AttrsOwner;
7 use itertools::Itertools;
8 use parser::SyntaxKind;
9
10 use crate::{
11     ast::{self, support, AstNode, AstToken, NameOwner, SyntaxNode},
12     SmolStr, SyntaxElement, SyntaxToken, T,
13 };
14
15 impl ast::Lifetime {
16     pub fn text(&self) -> &SmolStr {
17         text_of_first_token(self.syntax())
18     }
19 }
20
21 impl ast::Name {
22     pub fn text(&self) -> &SmolStr {
23         text_of_first_token(self.syntax())
24     }
25 }
26
27 impl ast::NameRef {
28     pub fn text(&self) -> &SmolStr {
29         text_of_first_token(self.syntax())
30     }
31
32     pub fn as_tuple_field(&self) -> Option<usize> {
33         self.text().parse().ok()
34     }
35 }
36
37 fn text_of_first_token(node: &SyntaxNode) -> &SmolStr {
38     node.green().children().next().and_then(|it| it.into_token()).unwrap().text()
39 }
40
41 pub enum Macro {
42     MacroRules(ast::MacroRules),
43     MacroDef(ast::MacroDef),
44 }
45
46 impl From<ast::MacroRules> for Macro {
47     fn from(it: ast::MacroRules) -> Self {
48         Macro::MacroRules(it)
49     }
50 }
51
52 impl From<ast::MacroDef> for Macro {
53     fn from(it: ast::MacroDef) -> Self {
54         Macro::MacroDef(it)
55     }
56 }
57
58 impl AstNode for Macro {
59     fn can_cast(kind: SyntaxKind) -> bool {
60         match kind {
61             SyntaxKind::MACRO_RULES | SyntaxKind::MACRO_DEF => true,
62             _ => false,
63         }
64     }
65     fn cast(syntax: SyntaxNode) -> Option<Self> {
66         let res = match syntax.kind() {
67             SyntaxKind::MACRO_RULES => Macro::MacroRules(ast::MacroRules { syntax }),
68             SyntaxKind::MACRO_DEF => Macro::MacroDef(ast::MacroDef { syntax }),
69             _ => return None,
70         };
71         Some(res)
72     }
73     fn syntax(&self) -> &SyntaxNode {
74         match self {
75             Macro::MacroRules(it) => it.syntax(),
76             Macro::MacroDef(it) => it.syntax(),
77         }
78     }
79 }
80
81 impl NameOwner for Macro {
82     fn name(&self) -> Option<ast::Name> {
83         match self {
84             Macro::MacroRules(mac) => mac.name(),
85             Macro::MacroDef(mac) => mac.name(),
86         }
87     }
88 }
89
90 impl AttrsOwner for Macro {}
91
92 #[derive(Debug, Clone, PartialEq, Eq)]
93 pub enum AttrKind {
94     Inner,
95     Outer,
96 }
97
98 impl ast::Attr {
99     pub fn as_simple_atom(&self) -> Option<SmolStr> {
100         if self.eq_token().is_some() || self.token_tree().is_some() {
101             return None;
102         }
103         self.simple_name()
104     }
105
106     pub fn as_simple_call(&self) -> Option<(SmolStr, ast::TokenTree)> {
107         let tt = self.token_tree()?;
108         Some((self.simple_name()?, tt))
109     }
110
111     pub fn as_simple_key_value(&self) -> Option<(SmolStr, SmolStr)> {
112         let lit = self.literal()?;
113         let key = self.simple_name()?;
114         let value_token = lit.syntax().first_token()?;
115
116         let value: SmolStr = ast::String::cast(value_token)?.value()?.into();
117
118         Some((key, value))
119     }
120
121     pub fn simple_name(&self) -> Option<SmolStr> {
122         let path = self.path()?;
123         match (path.segment(), path.qualifier()) {
124             (Some(segment), None) => Some(segment.syntax().first_token()?.text().clone()),
125             _ => None,
126         }
127     }
128
129     pub fn kind(&self) -> AttrKind {
130         let first_token = self.syntax().first_token();
131         let first_token_kind = first_token.as_ref().map(SyntaxToken::kind);
132         let second_token_kind =
133             first_token.and_then(|token| token.next_token()).as_ref().map(SyntaxToken::kind);
134
135         match (first_token_kind, second_token_kind) {
136             (Some(SyntaxKind::POUND), Some(T![!])) => AttrKind::Inner,
137             _ => AttrKind::Outer,
138         }
139     }
140 }
141
142 #[derive(Debug, Clone, PartialEq, Eq)]
143 pub enum PathSegmentKind {
144     Name(ast::NameRef),
145     Type { type_ref: Option<ast::Type>, trait_ref: Option<ast::PathType> },
146     SelfKw,
147     SuperKw,
148     CrateKw,
149 }
150
151 impl ast::PathSegment {
152     pub fn parent_path(&self) -> ast::Path {
153         self.syntax()
154             .parent()
155             .and_then(ast::Path::cast)
156             .expect("segments are always nested in paths")
157     }
158
159     pub fn kind(&self) -> Option<PathSegmentKind> {
160         let res = if let Some(name_ref) = self.name_ref() {
161             PathSegmentKind::Name(name_ref)
162         } else {
163             match self.syntax().first_child_or_token()?.kind() {
164                 T![self] => PathSegmentKind::SelfKw,
165                 T![super] => PathSegmentKind::SuperKw,
166                 T![crate] => PathSegmentKind::CrateKw,
167                 T![<] => {
168                     // <T> or <T as Trait>
169                     // T is any TypeRef, Trait has to be a PathType
170                     let mut type_refs =
171                         self.syntax().children().filter(|node| ast::Type::can_cast(node.kind()));
172                     let type_ref = type_refs.next().and_then(ast::Type::cast);
173                     let trait_ref = type_refs.next().and_then(ast::PathType::cast);
174                     PathSegmentKind::Type { type_ref, trait_ref }
175                 }
176                 _ => return None,
177             }
178         };
179         Some(res)
180     }
181 }
182
183 impl ast::Path {
184     pub fn parent_path(&self) -> Option<ast::Path> {
185         self.syntax().parent().and_then(ast::Path::cast)
186     }
187 }
188
189 impl ast::UseTreeList {
190     pub fn parent_use_tree(&self) -> ast::UseTree {
191         self.syntax()
192             .parent()
193             .and_then(ast::UseTree::cast)
194             .expect("UseTreeLists are always nested in UseTrees")
195     }
196
197     pub fn has_inner_comment(&self) -> bool {
198         self.syntax()
199             .children_with_tokens()
200             .filter_map(|it| it.into_token())
201             .find_map(ast::Comment::cast)
202             .is_some()
203     }
204 }
205
206 impl ast::Impl {
207     pub fn self_ty(&self) -> Option<ast::Type> {
208         match self.target() {
209             (Some(t), None) | (_, Some(t)) => Some(t),
210             _ => None,
211         }
212     }
213
214     pub fn trait_(&self) -> Option<ast::Type> {
215         match self.target() {
216             (Some(t), Some(_)) => Some(t),
217             _ => None,
218         }
219     }
220
221     fn target(&self) -> (Option<ast::Type>, Option<ast::Type>) {
222         let mut types = support::children(self.syntax());
223         let first = types.next();
224         let second = types.next();
225         (first, second)
226     }
227 }
228
229 #[derive(Debug, Clone, PartialEq, Eq)]
230 pub enum StructKind {
231     Record(ast::RecordFieldList),
232     Tuple(ast::TupleFieldList),
233     Unit,
234 }
235
236 impl StructKind {
237     fn from_node<N: AstNode>(node: &N) -> StructKind {
238         if let Some(nfdl) = support::child::<ast::RecordFieldList>(node.syntax()) {
239             StructKind::Record(nfdl)
240         } else if let Some(pfl) = support::child::<ast::TupleFieldList>(node.syntax()) {
241             StructKind::Tuple(pfl)
242         } else {
243             StructKind::Unit
244         }
245     }
246 }
247
248 impl ast::Struct {
249     pub fn kind(&self) -> StructKind {
250         StructKind::from_node(self)
251     }
252 }
253
254 impl ast::RecordExprField {
255     pub fn for_field_name(field_name: &ast::NameRef) -> Option<ast::RecordExprField> {
256         let candidate =
257             field_name.syntax().parent().and_then(ast::RecordExprField::cast).or_else(|| {
258                 field_name.syntax().ancestors().nth(4).and_then(ast::RecordExprField::cast)
259             })?;
260         if candidate.field_name().as_ref() == Some(field_name) {
261             Some(candidate)
262         } else {
263             None
264         }
265     }
266
267     /// Deals with field init shorthand
268     pub fn field_name(&self) -> Option<ast::NameRef> {
269         if let Some(name_ref) = self.name_ref() {
270             return Some(name_ref);
271         }
272         self.expr()?.name_ref()
273     }
274 }
275
276 pub enum NameOrNameRef {
277     Name(ast::Name),
278     NameRef(ast::NameRef),
279 }
280
281 impl fmt::Display for NameOrNameRef {
282     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
283         match self {
284             NameOrNameRef::Name(it) => fmt::Display::fmt(it, f),
285             NameOrNameRef::NameRef(it) => fmt::Display::fmt(it, f),
286         }
287     }
288 }
289
290 impl ast::RecordPatField {
291     /// Deals with field init shorthand
292     pub fn field_name(&self) -> Option<NameOrNameRef> {
293         if let Some(name_ref) = self.name_ref() {
294             return Some(NameOrNameRef::NameRef(name_ref));
295         }
296         if let Some(ast::Pat::IdentPat(pat)) = self.pat() {
297             let name = pat.name()?;
298             return Some(NameOrNameRef::Name(name));
299         }
300         None
301     }
302 }
303
304 impl ast::Variant {
305     pub fn parent_enum(&self) -> ast::Enum {
306         self.syntax()
307             .parent()
308             .and_then(|it| it.parent())
309             .and_then(ast::Enum::cast)
310             .expect("EnumVariants are always nested in Enums")
311     }
312     pub fn kind(&self) -> StructKind {
313         StructKind::from_node(self)
314     }
315 }
316
317 #[derive(Debug, Clone, PartialEq, Eq)]
318 pub enum FieldKind {
319     Name(ast::NameRef),
320     Index(SyntaxToken),
321 }
322
323 impl ast::FieldExpr {
324     pub fn index_token(&self) -> Option<SyntaxToken> {
325         self.syntax
326             .children_with_tokens()
327             // FIXME: Accepting floats here to reject them in validation later
328             .find(|c| c.kind() == SyntaxKind::INT_NUMBER || c.kind() == SyntaxKind::FLOAT_NUMBER)
329             .as_ref()
330             .and_then(SyntaxElement::as_token)
331             .cloned()
332     }
333
334     pub fn field_access(&self) -> Option<FieldKind> {
335         if let Some(nr) = self.name_ref() {
336             Some(FieldKind::Name(nr))
337         } else if let Some(tok) = self.index_token() {
338             Some(FieldKind::Index(tok))
339         } else {
340             None
341         }
342     }
343 }
344
345 pub struct SlicePatComponents {
346     pub prefix: Vec<ast::Pat>,
347     pub slice: Option<ast::Pat>,
348     pub suffix: Vec<ast::Pat>,
349 }
350
351 impl ast::SlicePat {
352     pub fn components(&self) -> SlicePatComponents {
353         let mut args = self.pats().peekable();
354         let prefix = args
355             .peeking_take_while(|p| match p {
356                 ast::Pat::RestPat(_) => false,
357                 ast::Pat::IdentPat(bp) => match bp.pat() {
358                     Some(ast::Pat::RestPat(_)) => false,
359                     _ => true,
360                 },
361                 ast::Pat::RefPat(rp) => match rp.pat() {
362                     Some(ast::Pat::RestPat(_)) => false,
363                     Some(ast::Pat::IdentPat(bp)) => match bp.pat() {
364                         Some(ast::Pat::RestPat(_)) => false,
365                         _ => true,
366                     },
367                     _ => true,
368                 },
369                 _ => true,
370             })
371             .collect();
372         let slice = args.next();
373         let suffix = args.collect();
374
375         SlicePatComponents { prefix, slice, suffix }
376     }
377 }
378
379 #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
380 pub enum SelfParamKind {
381     /// self
382     Owned,
383     /// &self
384     Ref,
385     /// &mut self
386     MutRef,
387 }
388
389 impl ast::SelfParam {
390     pub fn kind(&self) -> SelfParamKind {
391         if self.amp_token().is_some() {
392             if self.mut_token().is_some() {
393                 SelfParamKind::MutRef
394             } else {
395                 SelfParamKind::Ref
396             }
397         } else {
398             SelfParamKind::Owned
399         }
400     }
401 }
402
403 #[derive(Clone, Debug, PartialEq, Eq, Hash)]
404 pub enum TypeBoundKind {
405     /// Trait
406     PathType(ast::PathType),
407     /// for<'a> ...
408     ForType(ast::ForType),
409     /// 'a
410     Lifetime(ast::Lifetime),
411 }
412
413 impl ast::TypeBound {
414     pub fn kind(&self) -> TypeBoundKind {
415         if let Some(path_type) = support::children(self.syntax()).next() {
416             TypeBoundKind::PathType(path_type)
417         } else if let Some(for_type) = support::children(self.syntax()).next() {
418             TypeBoundKind::ForType(for_type)
419         } else if let Some(lifetime) = self.lifetime() {
420             TypeBoundKind::Lifetime(lifetime)
421         } else {
422             unreachable!()
423         }
424     }
425 }
426
427 pub enum VisibilityKind {
428     In(ast::Path),
429     PubCrate,
430     PubSuper,
431     PubSelf,
432     Pub,
433 }
434
435 impl ast::Visibility {
436     pub fn kind(&self) -> VisibilityKind {
437         if let Some(path) = support::children(self.syntax()).next() {
438             VisibilityKind::In(path)
439         } else if self.crate_token().is_some() {
440             VisibilityKind::PubCrate
441         } else if self.super_token().is_some() {
442             VisibilityKind::PubSuper
443         } else if self.self_token().is_some() {
444             VisibilityKind::PubSelf
445         } else {
446             VisibilityKind::Pub
447         }
448     }
449 }
450
451 impl ast::LifetimeParam {
452     pub fn lifetime_bounds(&self) -> impl Iterator<Item = SyntaxToken> {
453         self.syntax()
454             .children_with_tokens()
455             .filter_map(|it| it.into_token())
456             .skip_while(|x| x.kind() != T![:])
457             .filter(|it| it.kind() == T![lifetime_ident])
458     }
459 }
460
461 impl ast::RangePat {
462     pub fn start(&self) -> Option<ast::Pat> {
463         self.syntax()
464             .children_with_tokens()
465             .take_while(|it| !(it.kind() == T![..] || it.kind() == T![..=]))
466             .filter_map(|it| it.into_node())
467             .find_map(ast::Pat::cast)
468     }
469
470     pub fn end(&self) -> Option<ast::Pat> {
471         self.syntax()
472             .children_with_tokens()
473             .skip_while(|it| !(it.kind() == T![..] || it.kind() == T![..=]))
474             .filter_map(|it| it.into_node())
475             .find_map(ast::Pat::cast)
476     }
477 }
478
479 impl ast::TokenTree {
480     pub fn left_delimiter_token(&self) -> Option<SyntaxToken> {
481         self.syntax()
482             .first_child_or_token()?
483             .into_token()
484             .filter(|it| matches!(it.kind(), T!['{'] | T!['('] | T!['[']))
485     }
486
487     pub fn right_delimiter_token(&self) -> Option<SyntaxToken> {
488         self.syntax()
489             .last_child_or_token()?
490             .into_token()
491             .filter(|it| matches!(it.kind(), T!['}'] | T![')'] | T![']']))
492     }
493 }
494
495 impl ast::GenericParamList {
496     pub fn lifetime_params(&self) -> impl Iterator<Item = ast::LifetimeParam> {
497         self.generic_params().filter_map(|param| match param {
498             ast::GenericParam::LifetimeParam(it) => Some(it),
499             ast::GenericParam::TypeParam(_) | ast::GenericParam::ConstParam(_) => None,
500         })
501     }
502     pub fn type_params(&self) -> impl Iterator<Item = ast::TypeParam> {
503         self.generic_params().filter_map(|param| match param {
504             ast::GenericParam::TypeParam(it) => Some(it),
505             ast::GenericParam::LifetimeParam(_) | ast::GenericParam::ConstParam(_) => None,
506         })
507     }
508     pub fn const_params(&self) -> impl Iterator<Item = ast::ConstParam> {
509         self.generic_params().filter_map(|param| match param {
510             ast::GenericParam::ConstParam(it) => Some(it),
511             ast::GenericParam::TypeParam(_) | ast::GenericParam::LifetimeParam(_) => None,
512         })
513     }
514 }
515
516 impl ast::DocCommentsOwner for ast::SourceFile {}
517 impl ast::DocCommentsOwner for ast::Fn {}
518 impl ast::DocCommentsOwner for ast::Struct {}
519 impl ast::DocCommentsOwner for ast::Union {}
520 impl ast::DocCommentsOwner for ast::RecordField {}
521 impl ast::DocCommentsOwner for ast::TupleField {}
522 impl ast::DocCommentsOwner for ast::Enum {}
523 impl ast::DocCommentsOwner for ast::Variant {}
524 impl ast::DocCommentsOwner for ast::Trait {}
525 impl ast::DocCommentsOwner for ast::Module {}
526 impl ast::DocCommentsOwner for ast::Static {}
527 impl ast::DocCommentsOwner for ast::Const {}
528 impl ast::DocCommentsOwner for ast::TypeAlias {}
529 impl ast::DocCommentsOwner for ast::Impl {}
530 impl ast::DocCommentsOwner for ast::MacroRules {}
531 impl ast::DocCommentsOwner for ast::MacroDef {}
532 impl ast::DocCommentsOwner for ast::Macro {}
533 impl ast::DocCommentsOwner for ast::Use {}