]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/generate_delegate_methods.rs
233f26ed63707181b342316bd32c0114f2c5bded
[rust.git] / crates / ide_assists / src / handlers / generate_delegate_methods.rs
1 use hir::{self, HasCrate, HasSource};
2 use syntax::ast::{self, make, AstNode, HasGenericParams, HasName, HasVisibility};
3
4 use crate::{
5     utils::{convert_param_list_to_arg_list, find_struct_impl, render_snippet, Cursor},
6     AssistContext, AssistId, AssistKind, Assists, GroupLabel,
7 };
8 use syntax::ast::edit::AstNodeEdit;
9
10 // Assist: generate_delegate_methods
11 //
12 // Generate delegate methods.
13 //
14 // ```
15 // struct Age(u8);
16 // impl Age {
17 //     fn age(&self) -> u8 {
18 //         self.0
19 //     }
20 // }
21 //
22 // struct Person {
23 //     ag$0e: Age,
24 // }
25 // ```
26 // ->
27 // ```
28 // struct Age(u8);
29 // impl Age {
30 //     fn age(&self) -> u8 {
31 //         self.0
32 //     }
33 // }
34 //
35 // struct Person {
36 //     age: Age,
37 // }
38 //
39 // impl Person {
40 //     $0fn age(&self) -> u8 {
41 //         self.age.age()
42 //     }
43 // }
44 // ```
45 pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
46     let strukt = ctx.find_node_at_offset::<ast::Struct>()?;
47     let strukt_name = strukt.name()?;
48
49     let (field_name, field_ty) = match ctx.find_node_at_offset::<ast::RecordField>() {
50         Some(field) => {
51             let field_name = field.name()?;
52             let field_ty = field.ty()?;
53             (format!("{}", field_name), field_ty)
54         }
55         None => {
56             let field = ctx.find_node_at_offset::<ast::TupleField>()?;
57             let field_list = ctx.find_node_at_offset::<ast::TupleFieldList>()?;
58             let field_list_index = field_list.fields().position(|it| it == field)?;
59             let field_ty = field.ty()?;
60             (format!("{}", field_list_index), field_ty)
61         }
62     };
63
64     let sema_field_ty = ctx.sema.resolve_type(&field_ty)?;
65     let krate = sema_field_ty.krate(ctx.db());
66     let mut methods = vec![];
67     sema_field_ty.iterate_assoc_items(ctx.db(), krate, |item| {
68         if let hir::AssocItem::Function(f) = item {
69             if f.self_param(ctx.db()).is_some() {
70                 methods.push(f)
71             }
72         }
73         Option::<()>::None
74     });
75
76     let target = field_ty.syntax().text_range();
77     for method in methods {
78         let adt = ast::Adt::Struct(strukt.clone());
79         let name = method.name(ctx.db()).to_string();
80         let impl_def = find_struct_impl(ctx, &adt, &name).flatten();
81         acc.add_group(
82             &GroupLabel("Generate delegate methods…".to_owned()),
83             AssistId("generate_delegate_methods", AssistKind::Generate),
84             format!("Generate delegate for `{}.{}()`", field_name, method.name(ctx.db())),
85             target,
86             |builder| {
87                 // Create the function
88                 let method_source = match method.source(ctx.db()) {
89                     Some(source) => source.value,
90                     None => return,
91                 };
92                 let method_name = method.name(ctx.db());
93                 let vis = method_source.visibility();
94                 let name = make::name(&method.name(ctx.db()).to_string());
95                 let params =
96                     method_source.param_list().unwrap_or_else(|| make::param_list(None, []));
97                 let type_params = method_source.generic_param_list();
98                 let arg_list = match method_source.param_list() {
99                     Some(list) => convert_param_list_to_arg_list(list),
100                     None => make::arg_list([]),
101                 };
102                 let tail_expr = make::expr_method_call(
103                     make::ext::field_from_idents(["self", &field_name]).unwrap(), // This unwrap is ok because we have at least 1 arg in the list
104                     make::name_ref(&method_name.to_string()),
105                     arg_list,
106                 );
107                 let body = make::block_expr([], Some(tail_expr));
108                 let ret_type = method_source.ret_type();
109                 let is_async = method_source.async_token().is_some();
110                 let f = make::fn_(vis, name, type_params, params, body, ret_type, is_async)
111                     .indent(ast::edit::IndentLevel(1))
112                     .clone_for_update();
113
114                 let cursor = Cursor::Before(f.syntax());
115
116                 // Create or update an impl block, attach the function to it,
117                 // then insert into our code.
118                 match impl_def {
119                     Some(impl_def) => {
120                         // Remember where in our source our `impl` block lives.
121                         let impl_def = impl_def.clone_for_update();
122                         let old_range = impl_def.syntax().text_range();
123
124                         // Attach the function to the impl block
125                         let assoc_items = impl_def.get_or_create_assoc_item_list();
126                         assoc_items.add_item(f.clone().into());
127
128                         // Update the impl block.
129                         match ctx.config.snippet_cap {
130                             Some(cap) => {
131                                 let snippet = render_snippet(cap, impl_def.syntax(), cursor);
132                                 builder.replace_snippet(cap, old_range, snippet);
133                             }
134                             None => {
135                                 builder.replace(old_range, impl_def.syntax().to_string());
136                             }
137                         }
138                     }
139                     None => {
140                         // Attach the function to the impl block
141                         let name = &strukt_name.to_string();
142                         let params = strukt.generic_param_list();
143                         let ty_params = params.clone();
144                         let impl_def = make::impl_(make::ext::ident_path(name), params, ty_params)
145                             .clone_for_update();
146                         let assoc_items = impl_def.get_or_create_assoc_item_list();
147                         assoc_items.add_item(f.clone().into());
148
149                         // Insert the impl block.
150                         match ctx.config.snippet_cap {
151                             Some(cap) => {
152                                 let offset = strukt.syntax().text_range().end();
153                                 let snippet = render_snippet(cap, impl_def.syntax(), cursor);
154                                 let snippet = format!("\n\n{}", snippet);
155                                 builder.insert_snippet(cap, offset, snippet);
156                             }
157                             None => {
158                                 let offset = strukt.syntax().text_range().end();
159                                 let snippet = format!("\n\n{}", impl_def.syntax().to_string());
160                                 builder.insert(offset, snippet);
161                             }
162                         }
163                     }
164                 }
165             },
166         )?;
167     }
168     Some(())
169 }
170
171 #[cfg(test)]
172 mod tests {
173     use crate::tests::check_assist;
174
175     use super::*;
176
177     #[test]
178     fn test_generate_delegate_create_impl_block() {
179         check_assist(
180             generate_delegate_methods,
181             r#"
182 struct Age(u8);
183 impl Age {
184     fn age(&self) -> u8 {
185         self.0
186     }
187 }
188
189 struct Person {
190     ag$0e: Age,
191 }"#,
192             r#"
193 struct Age(u8);
194 impl Age {
195     fn age(&self) -> u8 {
196         self.0
197     }
198 }
199
200 struct Person {
201     age: Age,
202 }
203
204 impl Person {
205     $0fn age(&self) -> u8 {
206         self.age.age()
207     }
208 }"#,
209         );
210     }
211
212     #[test]
213     fn test_generate_delegate_update_impl_block() {
214         check_assist(
215             generate_delegate_methods,
216             r#"
217 struct Age(u8);
218 impl Age {
219     fn age(&self) -> u8 {
220         self.0
221     }
222 }
223
224 struct Person {
225     ag$0e: Age,
226 }
227
228 impl Person {}"#,
229             r#"
230 struct Age(u8);
231 impl Age {
232     fn age(&self) -> u8 {
233         self.0
234     }
235 }
236
237 struct Person {
238     age: Age,
239 }
240
241 impl Person {
242     $0fn age(&self) -> u8 {
243         self.age.age()
244     }
245 }"#,
246         );
247     }
248
249     #[test]
250     fn test_generate_delegate_tuple_struct() {
251         check_assist(
252             generate_delegate_methods,
253             r#"
254 struct Age(u8);
255 impl Age {
256     fn age(&self) -> u8 {
257         self.0
258     }
259 }
260
261 struct Person(A$0ge);"#,
262             r#"
263 struct Age(u8);
264 impl Age {
265     fn age(&self) -> u8 {
266         self.0
267     }
268 }
269
270 struct Person(Age);
271
272 impl Person {
273     $0fn age(&self) -> u8 {
274         self.0.age()
275     }
276 }"#,
277         );
278     }
279
280     #[test]
281     fn test_generate_delegate_enable_all_attributes() {
282         check_assist(
283             generate_delegate_methods,
284             r#"
285 struct Age<T>(T);
286 impl<T> Age<T> {
287     pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
288         self.0
289     }
290 }
291
292 struct Person<T> {
293     ag$0e: Age<T>,
294 }"#,
295             r#"
296 struct Age<T>(T);
297 impl<T> Age<T> {
298     pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
299         self.0
300     }
301 }
302
303 struct Person<T> {
304     age: Age<T>,
305 }
306
307 impl<T> Person<T> {
308     $0pub(crate) async fn age<J, 'a>(&'a mut self, ty: T, arg: J) -> T {
309         self.age.age(ty, arg)
310     }
311 }"#,
312         );
313     }
314 }