]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/utils/gen_trait_fn_body.rs
add clone generation for structs and bare enums
[rust.git] / crates / ide_assists / src / utils / gen_trait_fn_body.rs
1 //! This module contains functions to generate default trait impl function bodies where possible.
2
3 use syntax::{
4     ast::{self, edit::AstNodeEdit, make, AstNode, NameOwner},
5     ted,
6 };
7
8 /// Generate custom trait bodies where possible.
9 ///
10 /// Returns `Option` so that we can use `?` rather than `if let Some`. Returning
11 /// `None` means that generating a custom trait body failed, and the body will remain
12 /// as `todo!` instead.
13 pub(crate) fn gen_trait_fn_body(
14     func: &ast::Fn,
15     trait_path: &ast::Path,
16     adt: &ast::Adt,
17 ) -> Option<()> {
18     match trait_path.segment()?.name_ref()?.text().as_str() {
19         "Clone" => gen_clone_impl(adt, func),
20         "Debug" => gen_debug_impl(adt, func),
21         "Default" => gen_default_impl(adt, func),
22         "Hash" => gen_hash_impl(adt, func),
23         _ => None,
24     }
25 }
26
27 /// Generate a `Clone` impl based on the fields and members of the target type.
28 fn gen_clone_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
29     fn gen_clone_call(target: ast::Expr) -> ast::Expr {
30         let method = make::name_ref("clone");
31         make::expr_method_call(target, method, make::arg_list(None))
32     }
33     let expr = match adt {
34         // `Clone` cannot be derived for unions, so no default impl can be provided.
35         ast::Adt::Union(_) => return None,
36         ast::Adt::Enum(enum_) => {
37             let list = enum_.variant_list()?;
38             let mut arms = vec![];
39             for variant in list.variants() {
40                 let name = variant.name()?;
41                 let left = make::ext::ident_path("Self");
42                 let right = make::ext::ident_path(&format!("{}", name));
43                 let variant_name = make::path_concat(left, right);
44
45                 let pattern = make::path_pat(variant_name.clone());
46                 let variant_expr = make::expr_path(variant_name);
47                 arms.push(make::match_arm(Some(pattern.into()), None, variant_expr));
48             }
49
50             let match_target = make::expr_path(make::ext::ident_path("self"));
51             let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
52             make::expr_match(match_target, list)
53         }
54         ast::Adt::Struct(strukt) => {
55             match strukt.field_list() {
56                 // => Self { name: self.name.clone() }
57                 Some(ast::FieldList::RecordFieldList(field_list)) => {
58                     let mut fields = vec![];
59                     for field in field_list.fields() {
60                         let base = make::expr_path(make::ext::ident_path("self"));
61                         let target = make::expr_field(base, &field.name()?.to_string());
62                         let method_call = gen_clone_call(target);
63                         let name_ref = make::name_ref(&field.name()?.to_string());
64                         let field = make::record_expr_field(name_ref, Some(method_call));
65                         fields.push(field);
66                     }
67                     let struct_name = make::ext::ident_path("Self");
68                     let fields = make::record_expr_field_list(fields);
69                     make::record_expr(struct_name, fields).into()
70                 }
71                 // => Self(self.0.clone(), self.1.clone())
72                 Some(ast::FieldList::TupleFieldList(field_list)) => {
73                     let mut fields = vec![];
74                     for (i, _) in field_list.fields().enumerate() {
75                         let f_path = make::expr_path(make::ext::ident_path("self"));
76                         let target = make::expr_field(f_path, &format!("{}", i)).into();
77                         fields.push(gen_clone_call(target));
78                     }
79                     let struct_name = make::expr_path(make::ext::ident_path("Self"));
80                     make::expr_call(struct_name, make::arg_list(fields))
81                 }
82                 // => Self { }
83                 None => {
84                     let struct_name = make::ext::ident_path("Self");
85                     let fields = make::record_expr_field_list(None);
86                     make::record_expr(struct_name, fields).into()
87                 }
88             }
89         }
90     };
91     let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
92     ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
93     Some(())
94 }
95
96 /// Generate a `Debug` impl based on the fields and members of the target type.
97 fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
98     let annotated_name = adt.name()?;
99     match adt {
100         // `Debug` cannot be derived for unions, so no default impl can be provided.
101         ast::Adt::Union(_) => None,
102
103         // => match self { Self::Variant => write!(f, "Variant") }
104         ast::Adt::Enum(enum_) => {
105             let list = enum_.variant_list()?;
106             let mut arms = vec![];
107             for variant in list.variants() {
108                 let name = variant.name()?;
109                 let left = make::ext::ident_path("Self");
110                 let right = make::ext::ident_path(&format!("{}", name));
111                 let variant_name = make::path_pat(make::path_concat(left, right));
112
113                 let target = make::expr_path(make::ext::ident_path("f").into());
114                 let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into();
115                 let args = make::arg_list(vec![target, fmt_string]);
116                 let macro_name = make::expr_path(make::ext::ident_path("write"));
117                 let macro_call = make::expr_macro_call(macro_name, args);
118
119                 arms.push(make::match_arm(Some(variant_name.into()), None, macro_call.into()));
120             }
121
122             let match_target = make::expr_path(make::ext::ident_path("self"));
123             let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
124             let match_expr = make::expr_match(match_target, list);
125
126             let body = make::block_expr(None, Some(match_expr));
127             let body = body.indent(ast::edit::IndentLevel(1));
128             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
129             Some(())
130         }
131
132         ast::Adt::Struct(strukt) => {
133             let name = format!("\"{}\"", annotated_name);
134             let args = make::arg_list(Some(make::expr_literal(&name).into()));
135             let target = make::expr_path(make::ext::ident_path("f"));
136
137             let expr = match strukt.field_list() {
138                 // => f.debug_struct("Name").finish()
139                 None => make::expr_method_call(target, make::name_ref("debug_struct"), args),
140
141                 // => f.debug_struct("Name").field("foo", &self.foo).finish()
142                 Some(ast::FieldList::RecordFieldList(field_list)) => {
143                     let method = make::name_ref("debug_struct");
144                     let mut expr = make::expr_method_call(target, method, args);
145                     for field in field_list.fields() {
146                         let name = field.name()?;
147                         let f_name = make::expr_literal(&(format!("\"{}\"", name))).into();
148                         let f_path = make::expr_path(make::ext::ident_path("self"));
149                         let f_path = make::expr_ref(f_path, false);
150                         let f_path = make::expr_field(f_path, &format!("{}", name)).into();
151                         let args = make::arg_list(vec![f_name, f_path]);
152                         expr = make::expr_method_call(expr, make::name_ref("field"), args);
153                     }
154                     expr
155                 }
156
157                 // => f.debug_tuple("Name").field(self.0).finish()
158                 Some(ast::FieldList::TupleFieldList(field_list)) => {
159                     let method = make::name_ref("debug_tuple");
160                     let mut expr = make::expr_method_call(target, method, args);
161                     for (i, _) in field_list.fields().enumerate() {
162                         let f_path = make::expr_path(make::ext::ident_path("self"));
163                         let f_path = make::expr_ref(f_path, false);
164                         let f_path = make::expr_field(f_path, &format!("{}", i)).into();
165                         let method = make::name_ref("field");
166                         expr = make::expr_method_call(expr, method, make::arg_list(Some(f_path)));
167                     }
168                     expr
169                 }
170             };
171
172             let method = make::name_ref("finish");
173             let expr = make::expr_method_call(expr, method, make::arg_list(None));
174             let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
175             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
176             Some(())
177         }
178     }
179 }
180
181 /// Generate a `Debug` impl based on the fields and members of the target type.
182 fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
183     fn gen_default_call() -> ast::Expr {
184         let trait_name = make::ext::ident_path("Default");
185         let method_name = make::ext::ident_path("default");
186         let fn_name = make::expr_path(make::path_concat(trait_name, method_name));
187         make::expr_call(fn_name, make::arg_list(None))
188     }
189     match adt {
190         // `Debug` cannot be derived for unions, so no default impl can be provided.
191         ast::Adt::Union(_) => None,
192         // Deriving `Debug` for enums is not stable yet.
193         ast::Adt::Enum(_) => None,
194         ast::Adt::Struct(strukt) => {
195             let expr = match strukt.field_list() {
196                 Some(ast::FieldList::RecordFieldList(field_list)) => {
197                     let mut fields = vec![];
198                     for field in field_list.fields() {
199                         let method_call = gen_default_call();
200                         let name_ref = make::name_ref(&field.name()?.to_string());
201                         let field = make::record_expr_field(name_ref, Some(method_call));
202                         fields.push(field);
203                     }
204                     let struct_name = make::ext::ident_path("Self");
205                     let fields = make::record_expr_field_list(fields);
206                     make::record_expr(struct_name, fields).into()
207                 }
208                 Some(ast::FieldList::TupleFieldList(field_list)) => {
209                     let struct_name = make::expr_path(make::ext::ident_path("Self"));
210                     let fields = field_list.fields().map(|_| gen_default_call());
211                     make::expr_call(struct_name, make::arg_list(fields))
212                 }
213                 None => {
214                     let struct_name = make::ext::ident_path("Self");
215                     let fields = make::record_expr_field_list(None);
216                     make::record_expr(struct_name, fields).into()
217                 }
218             };
219             let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
220             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
221             Some(())
222         }
223     }
224 }
225
226 /// Generate a `Hash` impl based on the fields and members of the target type.
227 fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
228     fn gen_hash_call(target: ast::Expr) -> ast::Stmt {
229         let method = make::name_ref("hash");
230         let arg = make::expr_path(make::ext::ident_path("state"));
231         let expr = make::expr_method_call(target, method, make::arg_list(Some(arg)));
232         let stmt = make::expr_stmt(expr);
233         stmt.into()
234     }
235
236     let body = match adt {
237         // `Hash` cannot be derived for unions, so no default impl can be provided.
238         ast::Adt::Union(_) => return None,
239
240         // => std::mem::discriminant(self).hash(state);
241         ast::Adt::Enum(_) => {
242             let root = make::ext::ident_path("core");
243             let submodule = make::ext::ident_path("mem");
244             let fn_name = make::ext::ident_path("discriminant");
245             let fn_name = make::path_concat(submodule, fn_name);
246             let fn_name = make::expr_path(make::path_concat(root, fn_name));
247
248             let arg = make::expr_path(make::ext::ident_path("self"));
249             let fn_call = make::expr_call(fn_name, make::arg_list(Some(arg)));
250             let stmt = gen_hash_call(fn_call);
251
252             make::block_expr(Some(stmt), None).indent(ast::edit::IndentLevel(1))
253         }
254         ast::Adt::Struct(strukt) => match strukt.field_list() {
255             // => self.<field>.hash(state);
256             Some(ast::FieldList::RecordFieldList(field_list)) => {
257                 let mut stmts = vec![];
258                 for field in field_list.fields() {
259                     let base = make::expr_path(make::ext::ident_path("self"));
260                     let target = make::expr_field(base, &field.name()?.to_string());
261                     stmts.push(gen_hash_call(target));
262                 }
263                 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
264             }
265
266             // => self.<field_index>.hash(state);
267             Some(ast::FieldList::TupleFieldList(field_list)) => {
268                 let mut stmts = vec![];
269                 for (i, _) in field_list.fields().enumerate() {
270                     let base = make::expr_path(make::ext::ident_path("self"));
271                     let target = make::expr_field(base, &format!("{}", i));
272                     stmts.push(gen_hash_call(target));
273                 }
274                 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
275             }
276
277             // No fields in the body means there's nothing to hash.
278             None => return None,
279         },
280     };
281
282     ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
283     Some(())
284 }