]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/utils/gen_trait_fn_body.rs
Merge #9856
[rust.git] / crates / ide_assists / src / utils / gen_trait_fn_body.rs
1 //! This module contains functions to generate default trait impl function bodies where possible.
2
3 use syntax::{
4     ast::{self, edit::AstNodeEdit, make, AstNode, NameOwner},
5     ted,
6 };
7
8 /// Generate custom trait bodies where possible.
9 ///
10 /// Returns `Option` so that we can use `?` rather than `if let Some`. Returning
11 /// `None` means that generating a custom trait body failed, and the body will remain
12 /// as `todo!` instead.
13 pub(crate) fn gen_trait_fn_body(
14     func: &ast::Fn,
15     trait_path: &ast::Path,
16     adt: &ast::Adt,
17 ) -> Option<()> {
18     match trait_path.segment()?.name_ref()?.text().as_str() {
19         "Clone" => gen_clone_impl(adt, func),
20         "Debug" => gen_debug_impl(adt, func),
21         "Default" => gen_default_impl(adt, func),
22         "Hash" => gen_hash_impl(adt, func),
23         _ => None,
24     }
25 }
26
27 /// Generate a `Clone` impl based on the fields and members of the target type.
28 fn gen_clone_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
29     fn gen_clone_call(target: ast::Expr) -> ast::Expr {
30         let method = make::name_ref("clone");
31         make::expr_method_call(target, method, make::arg_list(None))
32     }
33     let expr = match adt {
34         // `Clone` cannot be derived for unions, so no default impl can be provided.
35         ast::Adt::Union(_) => return None,
36         ast::Adt::Enum(enum_) => {
37             let list = enum_.variant_list()?;
38             let mut arms = vec![];
39             for variant in list.variants() {
40                 let name = variant.name()?;
41                 let left = make::ext::ident_path("Self");
42                 let right = make::ext::ident_path(&format!("{}", name));
43                 let variant_name = make::path_concat(left, right);
44
45                 match variant.field_list() {
46                     // => match self { Self::Name { x } => Self::Name { x: x.clone() } }
47                     Some(ast::FieldList::RecordFieldList(list)) => {
48                         let mut pats = vec![];
49                         let mut fields = vec![];
50                         for field in list.fields() {
51                             let field_name = field.name()?;
52                             let pat = make::ident_pat(false, false, field_name.clone());
53                             pats.push(pat.into());
54
55                             let path = make::ext::ident_path(&field_name.to_string());
56                             let method_call = gen_clone_call(make::expr_path(path));
57                             let name_ref = make::name_ref(&field_name.to_string());
58                             let field = make::record_expr_field(name_ref, Some(method_call));
59                             fields.push(field);
60                         }
61                         let pat = make::record_pat(variant_name.clone(), pats.into_iter());
62                         let fields = make::record_expr_field_list(fields);
63                         let record_expr = make::record_expr(variant_name, fields).into();
64                         arms.push(make::match_arm(Some(pat.into()), None, record_expr));
65                     }
66
67                     // => match self { Self::Name(arg1) => Self::Name(arg1.clone()) }
68                     Some(ast::FieldList::TupleFieldList(list)) => {
69                         let mut pats = vec![];
70                         let mut fields = vec![];
71                         for (i, _) in list.fields().enumerate() {
72                             let field_name = format!("arg{}", i);
73                             let pat = make::ident_pat(false, false, make::name(&field_name));
74                             pats.push(pat.into());
75
76                             let f_path = make::expr_path(make::ext::ident_path(&field_name));
77                             fields.push(gen_clone_call(f_path));
78                         }
79                         let pat = make::tuple_struct_pat(variant_name.clone(), pats.into_iter());
80                         let struct_name = make::expr_path(variant_name);
81                         let tuple_expr = make::expr_call(struct_name, make::arg_list(fields));
82                         arms.push(make::match_arm(Some(pat.into()), None, tuple_expr));
83                     }
84
85                     // => match self { Self::Name => Self::Name }
86                     None => {
87                         let pattern = make::path_pat(variant_name.clone());
88                         let variant_expr = make::expr_path(variant_name);
89                         arms.push(make::match_arm(Some(pattern.into()), None, variant_expr));
90                     }
91                 }
92             }
93
94             let match_target = make::expr_path(make::ext::ident_path("self"));
95             let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
96             make::expr_match(match_target, list)
97         }
98         ast::Adt::Struct(strukt) => {
99             match strukt.field_list() {
100                 // => Self { name: self.name.clone() }
101                 Some(ast::FieldList::RecordFieldList(field_list)) => {
102                     let mut fields = vec![];
103                     for field in field_list.fields() {
104                         let base = make::expr_path(make::ext::ident_path("self"));
105                         let target = make::expr_field(base, &field.name()?.to_string());
106                         let method_call = gen_clone_call(target);
107                         let name_ref = make::name_ref(&field.name()?.to_string());
108                         let field = make::record_expr_field(name_ref, Some(method_call));
109                         fields.push(field);
110                     }
111                     let struct_name = make::ext::ident_path("Self");
112                     let fields = make::record_expr_field_list(fields);
113                     make::record_expr(struct_name, fields).into()
114                 }
115                 // => Self(self.0.clone(), self.1.clone())
116                 Some(ast::FieldList::TupleFieldList(field_list)) => {
117                     let mut fields = vec![];
118                     for (i, _) in field_list.fields().enumerate() {
119                         let f_path = make::expr_path(make::ext::ident_path("self"));
120                         let target = make::expr_field(f_path, &format!("{}", i)).into();
121                         fields.push(gen_clone_call(target));
122                     }
123                     let struct_name = make::expr_path(make::ext::ident_path("Self"));
124                     make::expr_call(struct_name, make::arg_list(fields))
125                 }
126                 // => Self { }
127                 None => {
128                     let struct_name = make::ext::ident_path("Self");
129                     let fields = make::record_expr_field_list(None);
130                     make::record_expr(struct_name, fields).into()
131                 }
132             }
133         }
134     };
135     let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
136     ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
137     Some(())
138 }
139
140 /// Generate a `Debug` impl based on the fields and members of the target type.
141 fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
142     let annotated_name = adt.name()?;
143     match adt {
144         // `Debug` cannot be derived for unions, so no default impl can be provided.
145         ast::Adt::Union(_) => None,
146
147         // => match self { Self::Variant => write!(f, "Variant") }
148         ast::Adt::Enum(enum_) => {
149             let list = enum_.variant_list()?;
150             let mut arms = vec![];
151             for variant in list.variants() {
152                 let name = variant.name()?;
153                 let left = make::ext::ident_path("Self");
154                 let right = make::ext::ident_path(&format!("{}", name));
155                 let variant_name = make::path_pat(make::path_concat(left, right));
156
157                 let target = make::expr_path(make::ext::ident_path("f").into());
158                 let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into();
159                 let args = make::arg_list(vec![target, fmt_string]);
160                 let macro_name = make::expr_path(make::ext::ident_path("write"));
161                 let macro_call = make::expr_macro_call(macro_name, args);
162
163                 arms.push(make::match_arm(Some(variant_name.into()), None, macro_call.into()));
164             }
165
166             let match_target = make::expr_path(make::ext::ident_path("self"));
167             let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
168             let match_expr = make::expr_match(match_target, list);
169
170             let body = make::block_expr(None, Some(match_expr));
171             let body = body.indent(ast::edit::IndentLevel(1));
172             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
173             Some(())
174         }
175
176         ast::Adt::Struct(strukt) => {
177             let name = format!("\"{}\"", annotated_name);
178             let args = make::arg_list(Some(make::expr_literal(&name).into()));
179             let target = make::expr_path(make::ext::ident_path("f"));
180
181             let expr = match strukt.field_list() {
182                 // => f.debug_struct("Name").finish()
183                 None => make::expr_method_call(target, make::name_ref("debug_struct"), args),
184
185                 // => f.debug_struct("Name").field("foo", &self.foo).finish()
186                 Some(ast::FieldList::RecordFieldList(field_list)) => {
187                     let method = make::name_ref("debug_struct");
188                     let mut expr = make::expr_method_call(target, method, args);
189                     for field in field_list.fields() {
190                         let name = field.name()?;
191                         let f_name = make::expr_literal(&(format!("\"{}\"", name))).into();
192                         let f_path = make::expr_path(make::ext::ident_path("self"));
193                         let f_path = make::expr_ref(f_path, false);
194                         let f_path = make::expr_field(f_path, &format!("{}", name)).into();
195                         let args = make::arg_list(vec![f_name, f_path]);
196                         expr = make::expr_method_call(expr, make::name_ref("field"), args);
197                     }
198                     expr
199                 }
200
201                 // => f.debug_tuple("Name").field(self.0).finish()
202                 Some(ast::FieldList::TupleFieldList(field_list)) => {
203                     let method = make::name_ref("debug_tuple");
204                     let mut expr = make::expr_method_call(target, method, args);
205                     for (i, _) in field_list.fields().enumerate() {
206                         let f_path = make::expr_path(make::ext::ident_path("self"));
207                         let f_path = make::expr_ref(f_path, false);
208                         let f_path = make::expr_field(f_path, &format!("{}", i)).into();
209                         let method = make::name_ref("field");
210                         expr = make::expr_method_call(expr, method, make::arg_list(Some(f_path)));
211                     }
212                     expr
213                 }
214             };
215
216             let method = make::name_ref("finish");
217             let expr = make::expr_method_call(expr, method, make::arg_list(None));
218             let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
219             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
220             Some(())
221         }
222     }
223 }
224
225 /// Generate a `Debug` impl based on the fields and members of the target type.
226 fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
227     fn gen_default_call() -> ast::Expr {
228         let trait_name = make::ext::ident_path("Default");
229         let method_name = make::ext::ident_path("default");
230         let fn_name = make::expr_path(make::path_concat(trait_name, method_name));
231         make::expr_call(fn_name, make::arg_list(None))
232     }
233     match adt {
234         // `Debug` cannot be derived for unions, so no default impl can be provided.
235         ast::Adt::Union(_) => None,
236         // Deriving `Debug` for enums is not stable yet.
237         ast::Adt::Enum(_) => None,
238         ast::Adt::Struct(strukt) => {
239             let expr = match strukt.field_list() {
240                 Some(ast::FieldList::RecordFieldList(field_list)) => {
241                     let mut fields = vec![];
242                     for field in field_list.fields() {
243                         let method_call = gen_default_call();
244                         let name_ref = make::name_ref(&field.name()?.to_string());
245                         let field = make::record_expr_field(name_ref, Some(method_call));
246                         fields.push(field);
247                     }
248                     let struct_name = make::ext::ident_path("Self");
249                     let fields = make::record_expr_field_list(fields);
250                     make::record_expr(struct_name, fields).into()
251                 }
252                 Some(ast::FieldList::TupleFieldList(field_list)) => {
253                     let struct_name = make::expr_path(make::ext::ident_path("Self"));
254                     let fields = field_list.fields().map(|_| gen_default_call());
255                     make::expr_call(struct_name, make::arg_list(fields))
256                 }
257                 None => {
258                     let struct_name = make::ext::ident_path("Self");
259                     let fields = make::record_expr_field_list(None);
260                     make::record_expr(struct_name, fields).into()
261                 }
262             };
263             let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
264             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
265             Some(())
266         }
267     }
268 }
269
270 /// Generate a `Hash` impl based on the fields and members of the target type.
271 fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
272     fn gen_hash_call(target: ast::Expr) -> ast::Stmt {
273         let method = make::name_ref("hash");
274         let arg = make::expr_path(make::ext::ident_path("state"));
275         let expr = make::expr_method_call(target, method, make::arg_list(Some(arg)));
276         let stmt = make::expr_stmt(expr);
277         stmt.into()
278     }
279
280     let body = match adt {
281         // `Hash` cannot be derived for unions, so no default impl can be provided.
282         ast::Adt::Union(_) => return None,
283
284         // => std::mem::discriminant(self).hash(state);
285         ast::Adt::Enum(_) => {
286             let root = make::ext::ident_path("core");
287             let submodule = make::ext::ident_path("mem");
288             let fn_name = make::ext::ident_path("discriminant");
289             let fn_name = make::path_concat(submodule, fn_name);
290             let fn_name = make::expr_path(make::path_concat(root, fn_name));
291
292             let arg = make::expr_path(make::ext::ident_path("self"));
293             let fn_call = make::expr_call(fn_name, make::arg_list(Some(arg)));
294             let stmt = gen_hash_call(fn_call);
295
296             make::block_expr(Some(stmt), None).indent(ast::edit::IndentLevel(1))
297         }
298         ast::Adt::Struct(strukt) => match strukt.field_list() {
299             // => self.<field>.hash(state);
300             Some(ast::FieldList::RecordFieldList(field_list)) => {
301                 let mut stmts = vec![];
302                 for field in field_list.fields() {
303                     let base = make::expr_path(make::ext::ident_path("self"));
304                     let target = make::expr_field(base, &field.name()?.to_string());
305                     stmts.push(gen_hash_call(target));
306                 }
307                 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
308             }
309
310             // => self.<field_index>.hash(state);
311             Some(ast::FieldList::TupleFieldList(field_list)) => {
312                 let mut stmts = vec![];
313                 for (i, _) in field_list.fields().enumerate() {
314                     let base = make::expr_path(make::ext::ident_path("self"));
315                     let target = make::expr_field(base, &format!("{}", i));
316                     stmts.push(gen_hash_call(target));
317                 }
318                 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
319             }
320
321             // No fields in the body means there's nothing to hash.
322             None => return None,
323         },
324     };
325
326     ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
327     Some(())
328 }