]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/utils/gen_trait_fn_body.rs
gen partialeq for tuple enums
[rust.git] / crates / ide_assists / src / utils / gen_trait_fn_body.rs
1 //! This module contains functions to generate default trait impl function bodies where possible.
2
3 use syntax::{
4     ast::{self, edit::AstNodeEdit, make, AstNode, NameOwner},
5     ted,
6 };
7
8 /// Generate custom trait bodies where possible.
9 ///
10 /// Returns `Option` so that we can use `?` rather than `if let Some`. Returning
11 /// `None` means that generating a custom trait body failed, and the body will remain
12 /// as `todo!` instead.
13 pub(crate) fn gen_trait_fn_body(
14     func: &ast::Fn,
15     trait_path: &ast::Path,
16     adt: &ast::Adt,
17 ) -> Option<()> {
18     match trait_path.segment()?.name_ref()?.text().as_str() {
19         "Clone" => gen_clone_impl(adt, func),
20         "Debug" => gen_debug_impl(adt, func),
21         "Default" => gen_default_impl(adt, func),
22         "Hash" => gen_hash_impl(adt, func),
23         "PartialEq" => gen_partial_eq(adt, func),
24         _ => None,
25     }
26 }
27
28 /// Generate a `Clone` impl based on the fields and members of the target type.
29 fn gen_clone_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
30     fn gen_clone_call(target: ast::Expr) -> ast::Expr {
31         let method = make::name_ref("clone");
32         make::expr_method_call(target, method, make::arg_list(None))
33     }
34     let expr = match adt {
35         // `Clone` cannot be derived for unions, so no default impl can be provided.
36         ast::Adt::Union(_) => return None,
37         ast::Adt::Enum(enum_) => {
38             let list = enum_.variant_list()?;
39             let mut arms = vec![];
40             for variant in list.variants() {
41                 let name = variant.name()?;
42                 let left = make::ext::ident_path("Self");
43                 let right = make::ext::ident_path(&format!("{}", name));
44                 let variant_name = make::path_concat(left, right);
45
46                 match variant.field_list() {
47                     // => match self { Self::Name { x } => Self::Name { x: x.clone() } }
48                     Some(ast::FieldList::RecordFieldList(list)) => {
49                         let mut pats = vec![];
50                         let mut fields = vec![];
51                         for field in list.fields() {
52                             let field_name = field.name()?;
53                             let pat = make::ident_pat(false, false, field_name.clone());
54                             pats.push(pat.into());
55
56                             let path = make::ext::ident_path(&field_name.to_string());
57                             let method_call = gen_clone_call(make::expr_path(path));
58                             let name_ref = make::name_ref(&field_name.to_string());
59                             let field = make::record_expr_field(name_ref, Some(method_call));
60                             fields.push(field);
61                         }
62                         let pat = make::record_pat(variant_name.clone(), pats.into_iter());
63                         let fields = make::record_expr_field_list(fields);
64                         let record_expr = make::record_expr(variant_name, fields).into();
65                         arms.push(make::match_arm(Some(pat.into()), None, record_expr));
66                     }
67
68                     // => match self { Self::Name(arg1) => Self::Name(arg1.clone()) }
69                     Some(ast::FieldList::TupleFieldList(list)) => {
70                         let mut pats = vec![];
71                         let mut fields = vec![];
72                         for (i, _) in list.fields().enumerate() {
73                             let field_name = format!("arg{}", i);
74                             let pat = make::ident_pat(false, false, make::name(&field_name));
75                             pats.push(pat.into());
76
77                             let f_path = make::expr_path(make::ext::ident_path(&field_name));
78                             fields.push(gen_clone_call(f_path));
79                         }
80                         let pat = make::tuple_struct_pat(variant_name.clone(), pats.into_iter());
81                         let struct_name = make::expr_path(variant_name);
82                         let tuple_expr = make::expr_call(struct_name, make::arg_list(fields));
83                         arms.push(make::match_arm(Some(pat.into()), None, tuple_expr));
84                     }
85
86                     // => match self { Self::Name => Self::Name }
87                     None => {
88                         let pattern = make::path_pat(variant_name.clone());
89                         let variant_expr = make::expr_path(variant_name);
90                         arms.push(make::match_arm(Some(pattern.into()), None, variant_expr));
91                     }
92                 }
93             }
94
95             let match_target = make::expr_path(make::ext::ident_path("self"));
96             let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
97             make::expr_match(match_target, list)
98         }
99         ast::Adt::Struct(strukt) => {
100             match strukt.field_list() {
101                 // => Self { name: self.name.clone() }
102                 Some(ast::FieldList::RecordFieldList(field_list)) => {
103                     let mut fields = vec![];
104                     for field in field_list.fields() {
105                         let base = make::expr_path(make::ext::ident_path("self"));
106                         let target = make::expr_field(base, &field.name()?.to_string());
107                         let method_call = gen_clone_call(target);
108                         let name_ref = make::name_ref(&field.name()?.to_string());
109                         let field = make::record_expr_field(name_ref, Some(method_call));
110                         fields.push(field);
111                     }
112                     let struct_name = make::ext::ident_path("Self");
113                     let fields = make::record_expr_field_list(fields);
114                     make::record_expr(struct_name, fields).into()
115                 }
116                 // => Self(self.0.clone(), self.1.clone())
117                 Some(ast::FieldList::TupleFieldList(field_list)) => {
118                     let mut fields = vec![];
119                     for (i, _) in field_list.fields().enumerate() {
120                         let f_path = make::expr_path(make::ext::ident_path("self"));
121                         let target = make::expr_field(f_path, &format!("{}", i)).into();
122                         fields.push(gen_clone_call(target));
123                     }
124                     let struct_name = make::expr_path(make::ext::ident_path("Self"));
125                     make::expr_call(struct_name, make::arg_list(fields))
126                 }
127                 // => Self { }
128                 None => {
129                     let struct_name = make::ext::ident_path("Self");
130                     let fields = make::record_expr_field_list(None);
131                     make::record_expr(struct_name, fields).into()
132                 }
133             }
134         }
135     };
136     let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
137     ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
138     Some(())
139 }
140
141 /// Generate a `Debug` impl based on the fields and members of the target type.
142 fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
143     let annotated_name = adt.name()?;
144     match adt {
145         // `Debug` cannot be derived for unions, so no default impl can be provided.
146         ast::Adt::Union(_) => None,
147
148         // => match self { Self::Variant => write!(f, "Variant") }
149         ast::Adt::Enum(enum_) => {
150             let list = enum_.variant_list()?;
151             let mut arms = vec![];
152             for variant in list.variants() {
153                 let name = variant.name()?;
154                 let left = make::ext::ident_path("Self");
155                 let right = make::ext::ident_path(&format!("{}", name));
156                 let variant_name = make::path_pat(make::path_concat(left, right));
157
158                 let target = make::expr_path(make::ext::ident_path("f").into());
159                 let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into();
160                 let args = make::arg_list(vec![target, fmt_string]);
161                 let macro_name = make::expr_path(make::ext::ident_path("write"));
162                 let macro_call = make::expr_macro_call(macro_name, args);
163
164                 arms.push(make::match_arm(Some(variant_name.into()), None, macro_call.into()));
165             }
166
167             let match_target = make::expr_path(make::ext::ident_path("self"));
168             let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
169             let match_expr = make::expr_match(match_target, list);
170
171             let body = make::block_expr(None, Some(match_expr));
172             let body = body.indent(ast::edit::IndentLevel(1));
173             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
174             Some(())
175         }
176
177         ast::Adt::Struct(strukt) => {
178             let name = format!("\"{}\"", annotated_name);
179             let args = make::arg_list(Some(make::expr_literal(&name).into()));
180             let target = make::expr_path(make::ext::ident_path("f"));
181
182             let expr = match strukt.field_list() {
183                 // => f.debug_struct("Name").finish()
184                 None => make::expr_method_call(target, make::name_ref("debug_struct"), args),
185
186                 // => f.debug_struct("Name").field("foo", &self.foo).finish()
187                 Some(ast::FieldList::RecordFieldList(field_list)) => {
188                     let method = make::name_ref("debug_struct");
189                     let mut expr = make::expr_method_call(target, method, args);
190                     for field in field_list.fields() {
191                         let name = field.name()?;
192                         let f_name = make::expr_literal(&(format!("\"{}\"", name))).into();
193                         let f_path = make::expr_path(make::ext::ident_path("self"));
194                         let f_path = make::expr_ref(f_path, false);
195                         let f_path = make::expr_field(f_path, &format!("{}", name)).into();
196                         let args = make::arg_list(vec![f_name, f_path]);
197                         expr = make::expr_method_call(expr, make::name_ref("field"), args);
198                     }
199                     expr
200                 }
201
202                 // => f.debug_tuple("Name").field(self.0).finish()
203                 Some(ast::FieldList::TupleFieldList(field_list)) => {
204                     let method = make::name_ref("debug_tuple");
205                     let mut expr = make::expr_method_call(target, method, args);
206                     for (i, _) in field_list.fields().enumerate() {
207                         let f_path = make::expr_path(make::ext::ident_path("self"));
208                         let f_path = make::expr_ref(f_path, false);
209                         let f_path = make::expr_field(f_path, &format!("{}", i)).into();
210                         let method = make::name_ref("field");
211                         expr = make::expr_method_call(expr, method, make::arg_list(Some(f_path)));
212                     }
213                     expr
214                 }
215             };
216
217             let method = make::name_ref("finish");
218             let expr = make::expr_method_call(expr, method, make::arg_list(None));
219             let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
220             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
221             Some(())
222         }
223     }
224 }
225
226 /// Generate a `Debug` impl based on the fields and members of the target type.
227 fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
228     fn gen_default_call() -> ast::Expr {
229         let trait_name = make::ext::ident_path("Default");
230         let method_name = make::ext::ident_path("default");
231         let fn_name = make::expr_path(make::path_concat(trait_name, method_name));
232         make::expr_call(fn_name, make::arg_list(None))
233     }
234     match adt {
235         // `Debug` cannot be derived for unions, so no default impl can be provided.
236         ast::Adt::Union(_) => None,
237         // Deriving `Debug` for enums is not stable yet.
238         ast::Adt::Enum(_) => None,
239         ast::Adt::Struct(strukt) => {
240             let expr = match strukt.field_list() {
241                 Some(ast::FieldList::RecordFieldList(field_list)) => {
242                     let mut fields = vec![];
243                     for field in field_list.fields() {
244                         let method_call = gen_default_call();
245                         let name_ref = make::name_ref(&field.name()?.to_string());
246                         let field = make::record_expr_field(name_ref, Some(method_call));
247                         fields.push(field);
248                     }
249                     let struct_name = make::ext::ident_path("Self");
250                     let fields = make::record_expr_field_list(fields);
251                     make::record_expr(struct_name, fields).into()
252                 }
253                 Some(ast::FieldList::TupleFieldList(field_list)) => {
254                     let struct_name = make::expr_path(make::ext::ident_path("Self"));
255                     let fields = field_list.fields().map(|_| gen_default_call());
256                     make::expr_call(struct_name, make::arg_list(fields))
257                 }
258                 None => {
259                     let struct_name = make::ext::ident_path("Self");
260                     let fields = make::record_expr_field_list(None);
261                     make::record_expr(struct_name, fields).into()
262                 }
263             };
264             let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
265             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
266             Some(())
267         }
268     }
269 }
270
271 /// Generate a `Hash` impl based on the fields and members of the target type.
272 fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
273     fn gen_hash_call(target: ast::Expr) -> ast::Stmt {
274         let method = make::name_ref("hash");
275         let arg = make::expr_path(make::ext::ident_path("state"));
276         let expr = make::expr_method_call(target, method, make::arg_list(Some(arg)));
277         let stmt = make::expr_stmt(expr);
278         stmt.into()
279     }
280
281     let body = match adt {
282         // `Hash` cannot be derived for unions, so no default impl can be provided.
283         ast::Adt::Union(_) => return None,
284
285         // => std::mem::discriminant(self).hash(state);
286         ast::Adt::Enum(_) => {
287             let root = make::ext::ident_path("core");
288             let submodule = make::ext::ident_path("mem");
289             let fn_name = make::ext::ident_path("discriminant");
290             let fn_name = make::path_concat(submodule, fn_name);
291             let fn_name = make::expr_path(make::path_concat(root, fn_name));
292
293             let arg = make::expr_path(make::ext::ident_path("self"));
294             let fn_call = make::expr_call(fn_name, make::arg_list(Some(arg)));
295             let stmt = gen_hash_call(fn_call);
296
297             make::block_expr(Some(stmt), None).indent(ast::edit::IndentLevel(1))
298         }
299         ast::Adt::Struct(strukt) => match strukt.field_list() {
300             // => self.<field>.hash(state);
301             Some(ast::FieldList::RecordFieldList(field_list)) => {
302                 let mut stmts = vec![];
303                 for field in field_list.fields() {
304                     let base = make::expr_path(make::ext::ident_path("self"));
305                     let target = make::expr_field(base, &field.name()?.to_string());
306                     stmts.push(gen_hash_call(target));
307                 }
308                 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
309             }
310
311             // => self.<field_index>.hash(state);
312             Some(ast::FieldList::TupleFieldList(field_list)) => {
313                 let mut stmts = vec![];
314                 for (i, _) in field_list.fields().enumerate() {
315                     let base = make::expr_path(make::ext::ident_path("self"));
316                     let target = make::expr_field(base, &format!("{}", i));
317                     stmts.push(gen_hash_call(target));
318                 }
319                 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
320             }
321
322             // No fields in the body means there's nothing to hash.
323             None => return None,
324         },
325     };
326
327     ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
328     Some(())
329 }
330
331 /// Generate a `PartialEq` impl based on the fields and members of the target type.
332 fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
333     fn gen_discriminant() -> ast::Expr {
334         let root = make::ext::ident_path("core");
335         let submodule = make::ext::ident_path("mem");
336         let fn_name = make::ext::ident_path("discriminant");
337         let fn_name = make::path_concat(submodule, fn_name);
338         let fn_name = make::expr_path(make::path_concat(root, fn_name));
339         fn_name
340     }
341
342     fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
343         match expr {
344             Some(expr) => Some(make::expr_op(ast::BinOp::BooleanAnd, expr, cmp)),
345             None => Some(cmp),
346         }
347     }
348
349     fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField {
350         let pat = make::ext::simple_ident_pat(make::name(&pat_name));
351         let name_ref = make::name_ref(field_name);
352         let field = make::record_pat_field(name_ref, pat.into());
353         field
354     }
355
356     fn gen_record_pat(
357         record_name: ast::Path,
358         r_fields: Vec<ast::RecordPatField>,
359     ) -> ast::RecordPat {
360         let list = make::record_pat_field_list(r_fields);
361         make::record_pat_with_fields(record_name, list)
362     }
363
364     fn gen_variant_path(variant: &ast::Variant) -> Option<ast::Path> {
365         let first = make::ext::ident_path("Self");
366         let second = make::path_from_text(&variant.name()?.to_string());
367         let record_name = make::path_concat(first, second);
368         Some(record_name)
369     }
370
371     fn gen_tuple_field(field_name: &String) -> ast::Pat {
372         ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name)))
373     }
374     // FIXME: return `None` if the trait carries a generic type; we can only
375     // generate this code `Self` for the time being.
376
377     let body = match adt {
378         // `Hash` cannot be derived for unions, so no default impl can be provided.
379         ast::Adt::Union(_) => return None,
380
381         ast::Adt::Enum(enum_) => {
382             // => std::mem::discriminant(self) == std::mem::discriminant(other)
383             let self_name = make::expr_path(make::ext::ident_path("self"));
384             let lhs = make::expr_call(gen_discriminant(), make::arg_list(Some(self_name.clone())));
385             let other_name = make::expr_path(make::ext::ident_path("other"));
386             let rhs = make::expr_call(gen_discriminant(), make::arg_list(Some(other_name.clone())));
387             let eq_check = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
388
389             let mut case_count = 0;
390             let mut arms = vec![];
391             for variant in enum_.variant_list()?.variants() {
392                 case_count += 1;
393                 match variant.field_list() {
394                     // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
395                     Some(ast::FieldList::RecordFieldList(list)) => {
396                         let mut expr = None;
397                         let mut l_fields = vec![];
398                         let mut r_fields = vec![];
399
400                         for field in list.fields() {
401                             let field_name = field.name()?.to_string();
402
403                             let l_name = &format!("l_{}", field_name);
404                             l_fields.push(gen_record_pat_field(&field_name, &l_name));
405
406                             let r_name = &format!("r_{}", field_name);
407                             r_fields.push(gen_record_pat_field(&field_name, &r_name));
408
409                             let lhs = make::expr_path(make::ext::ident_path(l_name));
410                             let rhs = make::expr_path(make::ext::ident_path(r_name));
411                             let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
412                             expr = gen_eq_chain(expr, cmp);
413                         }
414
415                         let left = gen_record_pat(gen_variant_path(&variant)?, l_fields);
416                         let right = gen_record_pat(gen_variant_path(&variant)?, r_fields);
417                         let tuple = make::tuple_pat(vec![left.into(), right.into()]);
418
419                         if let Some(expr) = expr {
420                             arms.push(make::match_arm(Some(tuple.into()), None, expr));
421                         }
422                     }
423
424                     // todo!("implement tuple record iteration")
425                     Some(ast::FieldList::TupleFieldList(list)) => {
426                         let mut expr = None;
427                         let mut l_fields = vec![];
428                         let mut r_fields = vec![];
429
430                         for (i, _) in list.fields().enumerate() {
431                             let field_name = format!("{}", i);
432
433                             let l_name = format!("l{}", field_name);
434                             l_fields.push(gen_tuple_field(&l_name));
435
436                             let r_name = format!("r{}", field_name);
437                             r_fields.push(gen_tuple_field(&r_name));
438
439                             let lhs = make::expr_path(make::ext::ident_path(&l_name));
440                             let rhs = make::expr_path(make::ext::ident_path(&r_name));
441                             let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
442                             expr = gen_eq_chain(expr, cmp);
443                         }
444
445                         let left = make::tuple_struct_pat(gen_variant_path(&variant)?, l_fields);
446                         let right = make::tuple_struct_pat(gen_variant_path(&variant)?, r_fields);
447                         let tuple = make::tuple_pat(vec![left.into(), right.into()]);
448
449                         if let Some(expr) = expr {
450                             arms.push(make::match_arm(Some(tuple.into()), None, expr));
451                         }
452                     }
453                     None => continue,
454                 }
455             }
456
457             if !arms.is_empty() && case_count > arms.len() {
458                 let lhs = make::wildcard_pat().into();
459                 arms.push(make::match_arm(Some(lhs), None, make::expr_literal("true").into()));
460             }
461
462             let expr = match arms.len() {
463                 0 => eq_check,
464                 _ => {
465                     let condition = make::condition(eq_check, None);
466
467                     let match_target = make::expr_tuple(vec![self_name, other_name]);
468                     let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
469                     let match_expr = Some(make::expr_match(match_target, list));
470                     let then_branch = make::block_expr(None, match_expr);
471                     let then_branch = then_branch.indent(ast::edit::IndentLevel(1));
472
473                     let else_branche = make::expr_literal("false");
474                     let else_branche = make::block_expr(None, Some(else_branche.into()))
475                         .indent(ast::edit::IndentLevel(1));
476
477                     make::expr_if(condition, then_branch, Some(else_branche.into()))
478                 }
479             };
480
481             make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
482         }
483         ast::Adt::Struct(strukt) => match strukt.field_list() {
484             Some(ast::FieldList::RecordFieldList(field_list)) => {
485                 let mut expr = None;
486                 for field in field_list.fields() {
487                     let lhs = make::expr_path(make::ext::ident_path("self"));
488                     let lhs = make::expr_field(lhs, &field.name()?.to_string());
489                     let rhs = make::expr_path(make::ext::ident_path("other"));
490                     let rhs = make::expr_field(rhs, &field.name()?.to_string());
491                     let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
492                     expr = gen_eq_chain(expr, cmp);
493                 }
494                 make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
495             }
496
497             Some(ast::FieldList::TupleFieldList(field_list)) => {
498                 let mut expr = None;
499                 for (i, _) in field_list.fields().enumerate() {
500                     let idx = format!("{}", i);
501                     let lhs = make::expr_path(make::ext::ident_path("self"));
502                     let lhs = make::expr_field(lhs, &idx);
503                     let rhs = make::expr_path(make::ext::ident_path("other"));
504                     let rhs = make::expr_field(rhs, &idx);
505                     let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
506                     expr = gen_eq_chain(expr, cmp);
507                 }
508                 make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
509             }
510
511             // No fields in the body means there's nothing to hash.
512             None => {
513                 let expr = make::expr_literal("true").into();
514                 make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
515             }
516         },
517     };
518
519     ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
520     Some(())
521 }