git.lizzy.rs Git - rust.git/blob - crates/syntax/src/ast/edit_in_place.rs
Finish GenericParamsOwnerEdit impls
[rust.git] / crates / syntax / src / ast / edit_in_place.rs
1 //! Structural editing for ast.
2
3 use std::iter::empty;
4
5 use parser::T;
6
7 use crate::{
8     algo::neighbor,
9     ast::{self, edit::AstNodeEdit, make, GenericParamsOwner, WhereClause},
10     ted::{self, Position},
11     AstNode, AstToken, Direction,
12 };
13
14 use super::NameOwner;
15
16 pub trait GenericParamsOwnerEdit: ast::GenericParamsOwner + AstNodeEdit {
17     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList;
18     fn get_or_create_where_clause(&self) -> ast::WhereClause;
19 }
20
21 impl GenericParamsOwnerEdit for ast::Fn {
22     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
23         match self.generic_param_list() {
24             Some(it) => it,
25             None => {
26                 let position = if let Some(name) = self.name() {
27                     Position::after(name.syntax)
28                 } else if let Some(fn_token) = self.fn_token() {
29                     Position::after(fn_token)
30                 } else if let Some(param_list) = self.param_list() {
31                     Position::before(param_list.syntax)
32                 } else {
33                     Position::last_child_of(self.syntax())
34                 };
35                 create_generic_param_list(position)
36             }
37         }
38     }
39
40     fn get_or_create_where_clause(&self) -> WhereClause {
41         if self.where_clause().is_none() {
42             let position = if let Some(ty) = self.ret_type() {
43                 Position::after(ty.syntax())
44             } else if let Some(param_list) = self.param_list() {
45                 Position::after(param_list.syntax())
46             } else {
47                 Position::last_child_of(self.syntax())
48             };
49             create_where_clause(position)
50         }
51         self.where_clause().unwrap()
52     }
53 }
54
55 impl GenericParamsOwnerEdit for ast::Impl {
56     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
57         match self.generic_param_list() {
58             Some(it) => it,
59             None => {
60                 let position = if let Some(imp_token) = self.impl_token() {
61                     Position::after(imp_token)
62                 } else {
63                     Position::last_child_of(self.syntax())
64                 };
65                 create_generic_param_list(position)
66             }
67         }
68     }
69
70     fn get_or_create_where_clause(&self) -> WhereClause {
71         if self.where_clause().is_none() {
72             let position = if let Some(items) = self.assoc_item_list() {
73                 Position::before(items.syntax())
74             } else {
75                 Position::last_child_of(self.syntax())
76             };
77             create_where_clause(position)
78         }
79         self.where_clause().unwrap()
80     }
81 }
82
83 impl GenericParamsOwnerEdit for ast::Trait {
84     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
85         match self.generic_param_list() {
86             Some(it) => it,
87             None => {
88                 let position = if let Some(name) = self.name() {
89                     Position::after(name.syntax)
90                 } else if let Some(trait_token) = self.trait_token() {
91                     Position::after(trait_token)
92                 } else {
93                     Position::last_child_of(self.syntax())
94                 };
95                 create_generic_param_list(position)
96             }
97         }
98     }
99
100     fn get_or_create_where_clause(&self) -> WhereClause {
101         if self.where_clause().is_none() {
102             let position = if let Some(items) = self.assoc_item_list() {
103                 Position::before(items.syntax())
104             } else {
105                 Position::last_child_of(self.syntax())
106             };
107             create_where_clause(position)
108         }
109         self.where_clause().unwrap()
110     }
111 }
112
113 impl GenericParamsOwnerEdit for ast::Struct {
114     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
115         match self.generic_param_list() {
116             Some(it) => it,
117             None => {
118                 let position = if let Some(name) = self.name() {
119                     Position::after(name.syntax)
120                 } else if let Some(struct_token) = self.struct_token() {
121                     Position::after(struct_token)
122                 } else {
123                     Position::last_child_of(self.syntax())
124                 };
125                 create_generic_param_list(position)
126             }
127         }
128     }
129
130     fn get_or_create_where_clause(&self) -> WhereClause {
131         if self.where_clause().is_none() {
132             let tfl = self.field_list().and_then(|fl| match fl {
133                 ast::FieldList::RecordFieldList(_) => None,
134                 ast::FieldList::TupleFieldList(it) => Some(it),
135             });
136             let position = if let Some(tfl) = tfl {
137                 Position::after(tfl.syntax())
138             } else if let Some(gpl) = self.generic_param_list() {
139                 Position::after(gpl.syntax())
140             } else if let Some(name) = self.name() {
141                 Position::after(name.syntax())
142             } else {
143                 Position::last_child_of(self.syntax())
144             };
145             create_where_clause(position)
146         }
147         self.where_clause().unwrap()
148     }
149 }
150
151 impl GenericParamsOwnerEdit for ast::Enum {
152     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
153         match self.generic_param_list() {
154             Some(it) => it,
155             None => {
156                 let position = if let Some(name) = self.name() {
157                     Position::after(name.syntax)
158                 } else if let Some(enum_token) = self.enum_token() {
159                     Position::after(enum_token)
160                 } else {
161                     Position::last_child_of(self.syntax())
162                 };
163                 create_generic_param_list(position)
164             }
165         }
166     }
167
168     fn get_or_create_where_clause(&self) -> WhereClause {
169         if self.where_clause().is_none() {
170             let position = if let Some(gpl) = self.generic_param_list() {
171                 Position::after(gpl.syntax())
172             } else if let Some(name) = self.name() {
173                 Position::after(name.syntax())
174             } else {
175                 Position::last_child_of(self.syntax())
176             };
177             create_where_clause(position)
178         }
179         self.where_clause().unwrap()
180     }
181 }
182
183 fn create_where_clause(position: Position) {
184     let where_clause = make::where_clause(empty()).clone_for_update();
185     ted::insert(position, where_clause.syntax());
186 }
187
188 fn create_generic_param_list(position: Position) -> ast::GenericParamList {
189     let gpl = make::generic_param_list(empty()).clone_for_update();
190     ted::insert_raw(position, gpl.syntax());
191     gpl
192 }
193
194 impl ast::GenericParamList {
195     pub fn add_generic_param(&self, generic_param: ast::GenericParam) {
196         match self.generic_params().last() {
197             Some(last_param) => {
198                 let mut elems = Vec::new();
199                 if !last_param
200                     .syntax()
201                     .siblings_with_tokens(Direction::Next)
202                     .any(|it| it.kind() == T![,])
203                 {
204                     elems.push(make::token(T![,]).into());
205                     elems.push(make::tokens::single_space().into());
206                 };
207                 elems.push(generic_param.syntax().clone().into());
208                 let after_last_param = Position::after(last_param.syntax());
209                 ted::insert_all(after_last_param, elems);
210             }
211             None => {
212                 let after_l_angle = Position::after(self.l_angle_token().unwrap());
213                 ted::insert(after_l_angle, generic_param.syntax())
214             }
215         }
216     }
217 }
218
219 impl ast::WhereClause {
220     pub fn add_predicate(&self, predicate: ast::WherePred) {
221         if let Some(pred) = self.predicates().last() {
222             if !pred.syntax().siblings_with_tokens(Direction::Next).any(|it| it.kind() == T![,]) {
223                 ted::append_child_raw(self.syntax(), make::token(T![,]));
224             }
225         }
226         ted::append_child(self.syntax(), predicate.syntax())
227     }
228 }
229
230 impl ast::TypeBoundList {
231     pub fn remove(&self) {
232         if let Some(colon) =
233             self.syntax().siblings_with_tokens(Direction::Prev).find(|it| it.kind() == T![:])
234         {
235             ted::remove_all(colon..=self.syntax().clone().into())
236         } else {
237             ted::remove(self.syntax())
238         }
239     }
240 }
241
242 impl ast::UseTree {
243     pub fn remove(&self) {
244         for &dir in [Direction::Next, Direction::Prev].iter() {
245             if let Some(next_use_tree) = neighbor(self, dir) {
246                 let separators = self
247                     .syntax()
248                     .siblings_with_tokens(dir)
249                     .skip(1)
250                     .take_while(|it| it.as_node() != Some(next_use_tree.syntax()));
251                 ted::remove_all_iter(separators);
252                 break;
253             }
254         }
255         ted::remove(self.syntax())
256     }
257 }
258
259 impl ast::Use {
260     pub fn remove(&self) {
261         let next_ws = self
262             .syntax()
263             .next_sibling_or_token()
264             .and_then(|it| it.into_token())
265             .and_then(ast::Whitespace::cast);
266         if let Some(next_ws) = next_ws {
267             let ws_text = next_ws.syntax().text();
268             if let Some(rest) = ws_text.strip_prefix('\n') {
269                 if rest.is_empty() {
270                     ted::remove(next_ws.syntax())
271                 } else {
272                     ted::replace(next_ws.syntax(), make::tokens::whitespace(rest))
273                 }
274             }
275         }
276         ted::remove(self.syntax())
277     }
278 }
279
#[cfg(test)]
mod tests {
    use std::fmt;

    use crate::SourceFile;

    use super::*;

    /// Parses `text` and returns a mutable (`clone_for_update`) copy of the
    /// first descendant node that casts to `N`. Panics if there is none.
    fn ast_mut_from_text<N: AstNode>(text: &str) -> N {
        let parse = SourceFile::parse(text);
        parse.tree().syntax().descendants().find_map(N::cast).unwrap().clone_for_update()
    }

    #[test]
    fn test_create_generic_param_list() {
        fn check_create_gpl<N: GenericParamsOwnerEdit + fmt::Display>(before: &str, after: &str) {
            let gpl_owner = ast_mut_from_text::<N>(before);
            gpl_owner.get_or_create_generic_param_list();
            assert_eq!(gpl_owner.to_string(), after);
        }

        check_create_gpl::<ast::Fn>("fn foo", "fn foo<>");
        check_create_gpl::<ast::Fn>("fn foo() {}", "fn foo<>() {}");

        check_create_gpl::<ast::Impl>("impl", "impl<>");
        check_create_gpl::<ast::Impl>("impl Struct {}", "impl<> Struct {}");
        check_create_gpl::<ast::Impl>("impl Trait for Struct {}", "impl<> Trait for Struct {}");

        // Creation path: the trait has no list yet, so one is inserted after its name.
        check_create_gpl::<ast::Trait>("trait Trait", "trait Trait<>");
        check_create_gpl::<ast::Trait>("trait Trait {}", "trait Trait<> {}");
        // Lookup path: an existing list is returned and the text is left unchanged.
        check_create_gpl::<ast::Trait>("trait Trait<>", "trait Trait<>");
        check_create_gpl::<ast::Trait>("trait Trait<> {}", "trait Trait<> {}");

        check_create_gpl::<ast::Struct>("struct A", "struct A<>");
        check_create_gpl::<ast::Struct>("struct A;", "struct A<>;");
        check_create_gpl::<ast::Struct>("struct A();", "struct A<>();");
        check_create_gpl::<ast::Struct>("struct A {}", "struct A<> {}");

        check_create_gpl::<ast::Enum>("enum E", "enum E<>");
        check_create_gpl::<ast::Enum>("enum E {", "enum E<> {");
    }
}