1 //! Structural editing for ast.
9 ast::{self, edit::AstNodeEdit, make, GenericParamsOwner, WhereClause},
10 ted::{self, Position},
11 AstNode, AstToken, Direction,
16 pub trait GenericParamsOwnerEdit: ast::GenericParamsOwner + AstNodeEdit {
17 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList;
18 fn get_or_create_where_clause(&self) -> ast::WhereClause;
21 impl GenericParamsOwnerEdit for ast::Fn {
22 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
23 match self.generic_param_list() {
26 let position = if let Some(name) = self.name() {
27 Position::after(name.syntax)
28 } else if let Some(fn_token) = self.fn_token() {
29 Position::after(fn_token)
30 } else if let Some(param_list) = self.param_list() {
31 Position::before(param_list.syntax)
33 Position::last_child_of(self.syntax())
35 create_generic_param_list(position)
40 fn get_or_create_where_clause(&self) -> WhereClause {
41 if self.where_clause().is_none() {
42 let position = if let Some(ty) = self.ret_type() {
43 Position::after(ty.syntax())
44 } else if let Some(param_list) = self.param_list() {
45 Position::after(param_list.syntax())
47 Position::last_child_of(self.syntax())
49 create_where_clause(position)
51 self.where_clause().unwrap()
55 impl GenericParamsOwnerEdit for ast::Impl {
56 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
57 match self.generic_param_list() {
60 let position = if let Some(imp_token) = self.impl_token() {
61 Position::after(imp_token)
63 Position::last_child_of(self.syntax())
65 create_generic_param_list(position)
70 fn get_or_create_where_clause(&self) -> WhereClause {
71 if self.where_clause().is_none() {
72 let position = if let Some(items) = self.assoc_item_list() {
73 Position::before(items.syntax())
75 Position::last_child_of(self.syntax())
77 create_where_clause(position)
79 self.where_clause().unwrap()
83 impl GenericParamsOwnerEdit for ast::Trait {
84 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
85 match self.generic_param_list() {
88 let position = if let Some(name) = self.name() {
89 Position::after(name.syntax)
90 } else if let Some(trait_token) = self.trait_token() {
91 Position::after(trait_token)
93 Position::last_child_of(self.syntax())
95 create_generic_param_list(position)
100 fn get_or_create_where_clause(&self) -> WhereClause {
101 if self.where_clause().is_none() {
102 let position = if let Some(items) = self.assoc_item_list() {
103 Position::before(items.syntax())
105 Position::last_child_of(self.syntax())
107 create_where_clause(position)
109 self.where_clause().unwrap()
113 impl GenericParamsOwnerEdit for ast::Struct {
114 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
115 match self.generic_param_list() {
118 let position = if let Some(name) = self.name() {
119 Position::after(name.syntax)
120 } else if let Some(struct_token) = self.struct_token() {
121 Position::after(struct_token)
123 Position::last_child_of(self.syntax())
125 create_generic_param_list(position)
130 fn get_or_create_where_clause(&self) -> WhereClause {
131 if self.where_clause().is_none() {
132 let tfl = self.field_list().and_then(|fl| match fl {
133 ast::FieldList::RecordFieldList(_) => None,
134 ast::FieldList::TupleFieldList(it) => Some(it),
136 let position = if let Some(tfl) = tfl {
137 Position::after(tfl.syntax())
138 } else if let Some(gpl) = self.generic_param_list() {
139 Position::after(gpl.syntax())
140 } else if let Some(name) = self.name() {
141 Position::after(name.syntax())
143 Position::last_child_of(self.syntax())
145 create_where_clause(position)
147 self.where_clause().unwrap()
151 impl GenericParamsOwnerEdit for ast::Enum {
152 fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
153 match self.generic_param_list() {
156 let position = if let Some(name) = self.name() {
157 Position::after(name.syntax)
158 } else if let Some(enum_token) = self.enum_token() {
159 Position::after(enum_token)
161 Position::last_child_of(self.syntax())
163 create_generic_param_list(position)
168 fn get_or_create_where_clause(&self) -> WhereClause {
169 if self.where_clause().is_none() {
170 let position = if let Some(gpl) = self.generic_param_list() {
171 Position::after(gpl.syntax())
172 } else if let Some(name) = self.name() {
173 Position::after(name.syntax())
175 Position::last_child_of(self.syntax())
177 create_where_clause(position)
179 self.where_clause().unwrap()
183 fn create_where_clause(position: Position) {
184 let where_clause = make::where_clause(empty()).clone_for_update();
185 ted::insert(position, where_clause.syntax());
188 fn create_generic_param_list(position: Position) -> ast::GenericParamList {
189 let gpl = make::generic_param_list(empty()).clone_for_update();
190 ted::insert_raw(position, gpl.syntax());
194 impl ast::GenericParamList {
195 pub fn add_generic_param(&self, generic_param: ast::GenericParam) {
196 match self.generic_params().last() {
197 Some(last_param) => {
198 let mut elems = Vec::new();
201 .siblings_with_tokens(Direction::Next)
202 .any(|it| it.kind() == T![,])
204 elems.push(make::token(T![,]).into());
205 elems.push(make::tokens::single_space().into());
207 elems.push(generic_param.syntax().clone().into());
208 let after_last_param = Position::after(last_param.syntax());
209 ted::insert_all(after_last_param, elems);
212 let after_l_angle = Position::after(self.l_angle_token().unwrap());
213 ted::insert(after_l_angle, generic_param.syntax())
219 impl ast::WhereClause {
220 pub fn add_predicate(&self, predicate: ast::WherePred) {
221 if let Some(pred) = self.predicates().last() {
222 if !pred.syntax().siblings_with_tokens(Direction::Next).any(|it| it.kind() == T![,]) {
223 ted::append_child_raw(self.syntax(), make::token(T![,]));
226 ted::append_child(self.syntax(), predicate.syntax())
230 impl ast::TypeBoundList {
231 pub fn remove(&self) {
233 self.syntax().siblings_with_tokens(Direction::Prev).find(|it| it.kind() == T![:])
235 ted::remove_all(colon..=self.syntax().clone().into())
237 ted::remove(self.syntax())
/// Removes this use tree from the syntax tree, together with the separator
/// elements between it and one neighboring use tree.
///
/// NOTE(review): some lines of this method are missing here (e.g. the call
/// between `self` and `.siblings_with_tokens(dir)`); these comments describe only
/// the lines that are present. Confirm against the full source.
243 pub fn remove(&self) {
// Look for a neighboring use tree: the following side first, then the preceding one.
244 for &dir in [Direction::Next, Direction::Prev].iter() {
245 if let Some(next_use_tree) = neighbor(self, dir) {
// Siblings walking from this tree towards the neighbor, stopping at (and not
// including) the neighbor node. Presumably these are the `,` and whitespace.
246 let separators = self
248 .siblings_with_tokens(dir)
250 .take_while(|it| it.as_node() != Some(next_use_tree.syntax()));
251 ted::remove_all_iter(separators);
// NOTE(review): presumably the loop stops after the first side that has a
// neighbor; otherwise separators on both sides would be removed. Confirm.
// Finally detach the use tree itself.
255 ted::remove(self.syntax())
/// Removes this node from the tree after trimming a newline from the whitespace
/// token that directly follows it.
///
/// Only a following whitespace token whose text starts with `\n` is touched. That
/// leading newline is dropped, either by removing the token or by replacing it with
/// the remaining whitespace. Any other following token is left as it is.
260 pub fn remove(&self) {
// The element right after this node, kept only if it is a whitespace token.
263 .next_sibling_or_token()
264 .and_then(|it| it.into_token())
265 .and_then(ast::Whitespace::cast);
266 if let Some(next_ws) = next_ws {
267 let ws_text = next_ws.syntax().text();
268 if let Some(rest) = ws_text.strip_prefix('\n') {
// NOTE(review): the condition choosing between the two calls below is missing
// here. Presumably the token is removed when `rest` is empty and replaced by
// `rest` otherwise. Confirm.
270 ted::remove(next_ws.syntax())
272 ted::replace(next_ws.syntax(), make::tokens::whitespace(rest))
// Finally detach the node itself.
276 ted::remove(self.syntax())
284 use crate::SourceFile;
288 fn ast_mut_from_text<N: AstNode>(text: &str) -> N {
289 let parse = SourceFile::parse(text);
290 parse.tree().syntax().descendants().find_map(N::cast).unwrap().clone_for_update()
294 fn test_create_generic_param_list() {
295 fn check_create_gpl<N: GenericParamsOwnerEdit + fmt::Display>(before: &str, after: &str) {
296 let gpl_owner = ast_mut_from_text::<N>(before);
297 gpl_owner.get_or_create_generic_param_list();
298 assert_eq!(gpl_owner.to_string(), after);
301 check_create_gpl::<ast::Fn>("fn foo", "fn foo<>");
302 check_create_gpl::<ast::Fn>("fn foo() {}", "fn foo<>() {}");
304 check_create_gpl::<ast::Impl>("impl", "impl<>");
305 check_create_gpl::<ast::Impl>("impl Struct {}", "impl<> Struct {}");
306 check_create_gpl::<ast::Impl>("impl Trait for Struct {}", "impl<> Trait for Struct {}");
308 check_create_gpl::<ast::Trait>("trait Trait<>", "trait Trait<>");
309 check_create_gpl::<ast::Trait>("trait Trait<> {}", "trait Trait<> {}");
311 check_create_gpl::<ast::Struct>("struct A", "struct A<>");
312 check_create_gpl::<ast::Struct>("struct A;", "struct A<>;");
313 check_create_gpl::<ast::Struct>("struct A();", "struct A<>();");
314 check_create_gpl::<ast::Struct>("struct A {}", "struct A<> {}");
316 check_create_gpl::<ast::Enum>("enum E", "enum E<>");
317 check_create_gpl::<ast::Enum>("enum E {", "enum E<> {");