]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/utils/gen_trait_fn_body.rs
implement feedback from review
[rust.git] / crates / ide_assists / src / utils / gen_trait_fn_body.rs
1 //! This module contains functions to generate default trait impl function bodies where possible.
2
3 use syntax::{
4     ast::{self, edit::AstNodeEdit, make, AstNode, NameOwner},
5     ted,
6 };
7
8 /// Generate custom trait bodies where possible.
9 ///
10 /// Returns `Option` so that we can use `?` rather than `if let Some`. Returning
11 /// `None` means that generating a custom trait body failed, and the body will remain
12 /// as `todo!` instead.
13 pub(crate) fn gen_trait_fn_body(
14     func: &ast::Fn,
15     trait_path: &ast::Path,
16     adt: &ast::Adt,
17 ) -> Option<()> {
18     match trait_path.segment()?.name_ref()?.text().as_str() {
19         "Clone" => gen_clone_impl(adt, func),
20         "Debug" => gen_debug_impl(adt, func),
21         "Default" => gen_default_impl(adt, func),
22         "Hash" => gen_hash_impl(adt, func),
23         "PartialEq" => gen_partial_eq(adt, func),
24         _ => None,
25     }
26 }
27
28 /// Generate a `Clone` impl based on the fields and members of the target type.
29 fn gen_clone_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
30     fn gen_clone_call(target: ast::Expr) -> ast::Expr {
31         let method = make::name_ref("clone");
32         make::expr_method_call(target, method, make::arg_list(None))
33     }
34     let expr = match adt {
35         // `Clone` cannot be derived for unions, so no default impl can be provided.
36         ast::Adt::Union(_) => return None,
37         ast::Adt::Enum(enum_) => {
38             let list = enum_.variant_list()?;
39             let mut arms = vec![];
40             for variant in list.variants() {
41                 let variant_name = make_variant_path(&variant)?;
42
43                 match variant.field_list() {
44                     // => match self { Self::Name { x } => Self::Name { x: x.clone() } }
45                     Some(ast::FieldList::RecordFieldList(list)) => {
46                         let mut pats = vec![];
47                         let mut fields = vec![];
48                         for field in list.fields() {
49                             let field_name = field.name()?;
50                             let pat = make::ident_pat(false, false, field_name.clone());
51                             pats.push(pat.into());
52
53                             let path = make::ext::ident_path(&field_name.to_string());
54                             let method_call = gen_clone_call(make::expr_path(path));
55                             let name_ref = make::name_ref(&field_name.to_string());
56                             let field = make::record_expr_field(name_ref, Some(method_call));
57                             fields.push(field);
58                         }
59                         let pat = make::record_pat(variant_name.clone(), pats.into_iter());
60                         let fields = make::record_expr_field_list(fields);
61                         let record_expr = make::record_expr(variant_name, fields).into();
62                         arms.push(make::match_arm(Some(pat.into()), None, record_expr));
63                     }
64
65                     // => match self { Self::Name(arg1) => Self::Name(arg1.clone()) }
66                     Some(ast::FieldList::TupleFieldList(list)) => {
67                         let mut pats = vec![];
68                         let mut fields = vec![];
69                         for (i, _) in list.fields().enumerate() {
70                             let field_name = format!("arg{}", i);
71                             let pat = make::ident_pat(false, false, make::name(&field_name));
72                             pats.push(pat.into());
73
74                             let f_path = make::expr_path(make::ext::ident_path(&field_name));
75                             fields.push(gen_clone_call(f_path));
76                         }
77                         let pat = make::tuple_struct_pat(variant_name.clone(), pats.into_iter());
78                         let struct_name = make::expr_path(variant_name);
79                         let tuple_expr = make::expr_call(struct_name, make::arg_list(fields));
80                         arms.push(make::match_arm(Some(pat.into()), None, tuple_expr));
81                     }
82
83                     // => match self { Self::Name => Self::Name }
84                     None => {
85                         let pattern = make::path_pat(variant_name.clone());
86                         let variant_expr = make::expr_path(variant_name);
87                         arms.push(make::match_arm(Some(pattern.into()), None, variant_expr));
88                     }
89                 }
90             }
91
92             let match_target = make::expr_path(make::ext::ident_path("self"));
93             let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
94             make::expr_match(match_target, list)
95         }
96         ast::Adt::Struct(strukt) => {
97             match strukt.field_list() {
98                 // => Self { name: self.name.clone() }
99                 Some(ast::FieldList::RecordFieldList(field_list)) => {
100                     let mut fields = vec![];
101                     for field in field_list.fields() {
102                         let base = make::expr_path(make::ext::ident_path("self"));
103                         let target = make::expr_field(base, &field.name()?.to_string());
104                         let method_call = gen_clone_call(target);
105                         let name_ref = make::name_ref(&field.name()?.to_string());
106                         let field = make::record_expr_field(name_ref, Some(method_call));
107                         fields.push(field);
108                     }
109                     let struct_name = make::ext::ident_path("Self");
110                     let fields = make::record_expr_field_list(fields);
111                     make::record_expr(struct_name, fields).into()
112                 }
113                 // => Self(self.0.clone(), self.1.clone())
114                 Some(ast::FieldList::TupleFieldList(field_list)) => {
115                     let mut fields = vec![];
116                     for (i, _) in field_list.fields().enumerate() {
117                         let f_path = make::expr_path(make::ext::ident_path("self"));
118                         let target = make::expr_field(f_path, &format!("{}", i)).into();
119                         fields.push(gen_clone_call(target));
120                     }
121                     let struct_name = make::expr_path(make::ext::ident_path("Self"));
122                     make::expr_call(struct_name, make::arg_list(fields))
123                 }
124                 // => Self { }
125                 None => {
126                     let struct_name = make::ext::ident_path("Self");
127                     let fields = make::record_expr_field_list(None);
128                     make::record_expr(struct_name, fields).into()
129                 }
130             }
131         }
132     };
133     let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
134     ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
135     Some(())
136 }
137
138 /// Generate a `Debug` impl based on the fields and members of the target type.
139 fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
140     let annotated_name = adt.name()?;
141     match adt {
142         // `Debug` cannot be derived for unions, so no default impl can be provided.
143         ast::Adt::Union(_) => None,
144
145         // => match self { Self::Variant => write!(f, "Variant") }
146         ast::Adt::Enum(enum_) => {
147             let list = enum_.variant_list()?;
148             let mut arms = vec![];
149             for variant in list.variants() {
150                 let name = variant.name()?;
151                 let variant_name = make::path_pat(make::path_from_text(&format!("Self::{}", name)));
152
153                 let target = make::expr_path(make::ext::ident_path("f").into());
154                 let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into();
155                 let args = make::arg_list(vec![target, fmt_string]);
156                 let macro_name = make::expr_path(make::ext::ident_path("write"));
157                 let macro_call = make::expr_macro_call(macro_name, args);
158
159                 arms.push(make::match_arm(Some(variant_name.into()), None, macro_call.into()));
160             }
161
162             let match_target = make::expr_path(make::ext::ident_path("self"));
163             let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
164             let match_expr = make::expr_match(match_target, list);
165
166             let body = make::block_expr(None, Some(match_expr));
167             let body = body.indent(ast::edit::IndentLevel(1));
168             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
169             Some(())
170         }
171
172         ast::Adt::Struct(strukt) => {
173             let name = format!("\"{}\"", annotated_name);
174             let args = make::arg_list(Some(make::expr_literal(&name).into()));
175             let target = make::expr_path(make::ext::ident_path("f"));
176
177             let expr = match strukt.field_list() {
178                 // => f.debug_struct("Name").finish()
179                 None => make::expr_method_call(target, make::name_ref("debug_struct"), args),
180
181                 // => f.debug_struct("Name").field("foo", &self.foo).finish()
182                 Some(ast::FieldList::RecordFieldList(field_list)) => {
183                     let method = make::name_ref("debug_struct");
184                     let mut expr = make::expr_method_call(target, method, args);
185                     for field in field_list.fields() {
186                         let name = field.name()?;
187                         let f_name = make::expr_literal(&(format!("\"{}\"", name))).into();
188                         let f_path = make::expr_path(make::ext::ident_path("self"));
189                         let f_path = make::expr_ref(f_path, false);
190                         let f_path = make::expr_field(f_path, &format!("{}", name)).into();
191                         let args = make::arg_list(vec![f_name, f_path]);
192                         expr = make::expr_method_call(expr, make::name_ref("field"), args);
193                     }
194                     expr
195                 }
196
197                 // => f.debug_tuple("Name").field(self.0).finish()
198                 Some(ast::FieldList::TupleFieldList(field_list)) => {
199                     let method = make::name_ref("debug_tuple");
200                     let mut expr = make::expr_method_call(target, method, args);
201                     for (i, _) in field_list.fields().enumerate() {
202                         let f_path = make::expr_path(make::ext::ident_path("self"));
203                         let f_path = make::expr_ref(f_path, false);
204                         let f_path = make::expr_field(f_path, &format!("{}", i)).into();
205                         let method = make::name_ref("field");
206                         expr = make::expr_method_call(expr, method, make::arg_list(Some(f_path)));
207                     }
208                     expr
209                 }
210             };
211
212             let method = make::name_ref("finish");
213             let expr = make::expr_method_call(expr, method, make::arg_list(None));
214             let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
215             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
216             Some(())
217         }
218     }
219 }
220
221 /// Generate a `Debug` impl based on the fields and members of the target type.
222 fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
223     fn gen_default_call() -> ast::Expr {
224         let fn_name = make::path_from_text(&"Default::default");
225         make::expr_call(make::expr_path(fn_name), make::arg_list(None))
226     }
227     match adt {
228         // `Debug` cannot be derived for unions, so no default impl can be provided.
229         ast::Adt::Union(_) => None,
230         // Deriving `Debug` for enums is not stable yet.
231         ast::Adt::Enum(_) => None,
232         ast::Adt::Struct(strukt) => {
233             let expr = match strukt.field_list() {
234                 Some(ast::FieldList::RecordFieldList(field_list)) => {
235                     let mut fields = vec![];
236                     for field in field_list.fields() {
237                         let method_call = gen_default_call();
238                         let name_ref = make::name_ref(&field.name()?.to_string());
239                         let field = make::record_expr_field(name_ref, Some(method_call));
240                         fields.push(field);
241                     }
242                     let struct_name = make::ext::ident_path("Self");
243                     let fields = make::record_expr_field_list(fields);
244                     make::record_expr(struct_name, fields).into()
245                 }
246                 Some(ast::FieldList::TupleFieldList(field_list)) => {
247                     let struct_name = make::expr_path(make::ext::ident_path("Self"));
248                     let fields = field_list.fields().map(|_| gen_default_call());
249                     make::expr_call(struct_name, make::arg_list(fields))
250                 }
251                 None => {
252                     let struct_name = make::ext::ident_path("Self");
253                     let fields = make::record_expr_field_list(None);
254                     make::record_expr(struct_name, fields).into()
255                 }
256             };
257             let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
258             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
259             Some(())
260         }
261     }
262 }
263
264 /// Generate a `Hash` impl based on the fields and members of the target type.
265 fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
266     fn gen_hash_call(target: ast::Expr) -> ast::Stmt {
267         let method = make::name_ref("hash");
268         let arg = make::expr_path(make::ext::ident_path("state"));
269         let expr = make::expr_method_call(target, method, make::arg_list(Some(arg)));
270         make::expr_stmt(expr).into()
271     }
272
273     let body = match adt {
274         // `Hash` cannot be derived for unions, so no default impl can be provided.
275         ast::Adt::Union(_) => return None,
276
277         // => std::mem::discriminant(self).hash(state);
278         ast::Adt::Enum(_) => {
279             let fn_name = make_discriminant();
280
281             let arg = make::expr_path(make::ext::ident_path("self"));
282             let fn_call = make::expr_call(fn_name, make::arg_list(Some(arg)));
283             let stmt = gen_hash_call(fn_call);
284
285             make::block_expr(Some(stmt), None).indent(ast::edit::IndentLevel(1))
286         }
287         ast::Adt::Struct(strukt) => match strukt.field_list() {
288             // => self.<field>.hash(state);
289             Some(ast::FieldList::RecordFieldList(field_list)) => {
290                 let mut stmts = vec![];
291                 for field in field_list.fields() {
292                     let base = make::expr_path(make::ext::ident_path("self"));
293                     let target = make::expr_field(base, &field.name()?.to_string());
294                     stmts.push(gen_hash_call(target));
295                 }
296                 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
297             }
298
299             // => self.<field_index>.hash(state);
300             Some(ast::FieldList::TupleFieldList(field_list)) => {
301                 let mut stmts = vec![];
302                 for (i, _) in field_list.fields().enumerate() {
303                     let base = make::expr_path(make::ext::ident_path("self"));
304                     let target = make::expr_field(base, &format!("{}", i));
305                     stmts.push(gen_hash_call(target));
306                 }
307                 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
308             }
309
310             // No fields in the body means there's nothing to hash.
311             None => return None,
312         },
313     };
314
315     ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
316     Some(())
317 }
318
319 /// Generate a `PartialEq` impl based on the fields and members of the target type.
320 fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
321     fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
322         match expr {
323             Some(expr) => Some(make::expr_op(ast::BinOp::BooleanAnd, expr, cmp)),
324             None => Some(cmp),
325         }
326     }
327
328     fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField {
329         let pat = make::ext::simple_ident_pat(make::name(&pat_name));
330         let name_ref = make::name_ref(field_name);
331         make::record_pat_field(name_ref, pat.into())
332     }
333
334     fn gen_record_pat(record_name: ast::Path, fields: Vec<ast::RecordPatField>) -> ast::RecordPat {
335         let list = make::record_pat_field_list(fields);
336         make::record_pat_with_fields(record_name, list)
337     }
338
339     fn gen_tuple_field(field_name: &String) -> ast::Pat {
340         ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name)))
341     }
342
343     // FIXME: return `None` if the trait carries a generic type; we can only
344     // generate this code `Self` for the time being.
345
346     let body = match adt {
347         // `PartialEq` cannot be derived for unions, so no default impl can be provided.
348         ast::Adt::Union(_) => return None,
349
350         ast::Adt::Enum(enum_) => {
351             // => std::mem::discriminant(self) == std::mem::discriminant(other)
352             let lhs_name = make::expr_path(make::ext::ident_path("self"));
353             let lhs = make::expr_call(make_discriminant(), make::arg_list(Some(lhs_name.clone())));
354             let rhs_name = make::expr_path(make::ext::ident_path("other"));
355             let rhs = make::expr_call(make_discriminant(), make::arg_list(Some(rhs_name.clone())));
356             let eq_check = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
357
358             let mut case_count = 0;
359             let mut arms = vec![];
360             for variant in enum_.variant_list()?.variants() {
361                 case_count += 1;
362                 match variant.field_list() {
363                     // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
364                     Some(ast::FieldList::RecordFieldList(list)) => {
365                         let mut expr = None;
366                         let mut l_fields = vec![];
367                         let mut r_fields = vec![];
368
369                         for field in list.fields() {
370                             let field_name = field.name()?.to_string();
371
372                             let l_name = &format!("l_{}", field_name);
373                             l_fields.push(gen_record_pat_field(&field_name, &l_name));
374
375                             let r_name = &format!("r_{}", field_name);
376                             r_fields.push(gen_record_pat_field(&field_name, &r_name));
377
378                             let lhs = make::expr_path(make::ext::ident_path(l_name));
379                             let rhs = make::expr_path(make::ext::ident_path(r_name));
380                             let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
381                             expr = gen_eq_chain(expr, cmp);
382                         }
383
384                         let left = gen_record_pat(make_variant_path(&variant)?, l_fields);
385                         let right = gen_record_pat(make_variant_path(&variant)?, r_fields);
386                         let tuple = make::tuple_pat(vec![left.into(), right.into()]);
387
388                         if let Some(expr) = expr {
389                             arms.push(make::match_arm(Some(tuple.into()), None, expr));
390                         }
391                     }
392
393                     Some(ast::FieldList::TupleFieldList(list)) => {
394                         let mut expr = None;
395                         let mut l_fields = vec![];
396                         let mut r_fields = vec![];
397
398                         for (i, _) in list.fields().enumerate() {
399                             let field_name = format!("{}", i);
400
401                             let l_name = format!("l{}", field_name);
402                             l_fields.push(gen_tuple_field(&l_name));
403
404                             let r_name = format!("r{}", field_name);
405                             r_fields.push(gen_tuple_field(&r_name));
406
407                             let lhs = make::expr_path(make::ext::ident_path(&l_name));
408                             let rhs = make::expr_path(make::ext::ident_path(&r_name));
409                             let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
410                             expr = gen_eq_chain(expr, cmp);
411                         }
412
413                         let left = make::tuple_struct_pat(make_variant_path(&variant)?, l_fields);
414                         let right = make::tuple_struct_pat(make_variant_path(&variant)?, r_fields);
415                         let tuple = make::tuple_pat(vec![left.into(), right.into()]);
416
417                         if let Some(expr) = expr {
418                             arms.push(make::match_arm(Some(tuple.into()), None, expr));
419                         }
420                     }
421                     None => continue,
422                 }
423             }
424
425             let expr = match arms.len() {
426                 0 => eq_check,
427                 _ => {
428                     if case_count > arms.len() {
429                         let lhs = make::wildcard_pat().into();
430                         arms.push(make::match_arm(Some(lhs), None, eq_check));
431                     }
432
433                     let match_target = make::expr_tuple(vec![lhs_name, rhs_name]);
434                     let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
435                     make::expr_match(match_target, list)
436                 }
437             };
438
439             make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
440         }
441         ast::Adt::Struct(strukt) => match strukt.field_list() {
442             Some(ast::FieldList::RecordFieldList(field_list)) => {
443                 let mut expr = None;
444                 for field in field_list.fields() {
445                     let lhs = make::expr_path(make::ext::ident_path("self"));
446                     let lhs = make::expr_field(lhs, &field.name()?.to_string());
447                     let rhs = make::expr_path(make::ext::ident_path("other"));
448                     let rhs = make::expr_field(rhs, &field.name()?.to_string());
449                     let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
450                     expr = gen_eq_chain(expr, cmp);
451                 }
452                 make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
453             }
454
455             Some(ast::FieldList::TupleFieldList(field_list)) => {
456                 let mut expr = None;
457                 for (i, _) in field_list.fields().enumerate() {
458                     let idx = format!("{}", i);
459                     let lhs = make::expr_path(make::ext::ident_path("self"));
460                     let lhs = make::expr_field(lhs, &idx);
461                     let rhs = make::expr_path(make::ext::ident_path("other"));
462                     let rhs = make::expr_field(rhs, &idx);
463                     let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
464                     expr = gen_eq_chain(expr, cmp);
465                 }
466                 make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
467             }
468
469             // No fields in the body means there's nothing to compare.
470             None => {
471                 let expr = make::expr_literal("true").into();
472                 make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
473             }
474         },
475     };
476
477     ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
478     Some(())
479 }
480
481 fn make_discriminant() -> ast::Expr {
482     make::expr_path(make::path_from_text("core::mem::discriminant"))
483 }
484
485 fn make_variant_path(variant: &ast::Variant) -> Option<ast::Path> {
486     Some(make::path_from_text(&format!("Self::{}", &variant.name()?)))
487 }