1 //! This module contains functions to generate default trait impl function bodies where possible.
4 ast::{self, edit::AstNodeEdit, make, AstNode, BinaryOp, CmpOp, LogicOp, NameOwner},
8 /// Generate custom trait bodies where possible.
10 /// Returns `Option` so that we can use `?` rather than `if let Some`. Returning
11 /// `None` means that generating a custom trait body failed, and the body will remain
12 /// as `todo!` instead.
13 pub(crate) fn gen_trait_fn_body(
15 trait_path: &ast::Path,
18 match trait_path.segment()?.name_ref()?.text().as_str() {
19 "Clone" => gen_clone_impl(adt, func),
20 "Debug" => gen_debug_impl(adt, func),
21 "Default" => gen_default_impl(adt, func),
22 "Hash" => gen_hash_impl(adt, func),
23 "PartialEq" => gen_partial_eq(adt, func),
28 /// Generate a `Clone` impl based on the fields and members of the target type.
29 fn gen_clone_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
30 fn gen_clone_call(target: ast::Expr) -> ast::Expr {
31 let method = make::name_ref("clone");
32 make::expr_method_call(target, method, make::arg_list(None))
34 let expr = match adt {
35 // `Clone` cannot be derived for unions, so no default impl can be provided.
36 ast::Adt::Union(_) => return None,
37 ast::Adt::Enum(enum_) => {
38 let list = enum_.variant_list()?;
39 let mut arms = vec![];
40 for variant in list.variants() {
41 let name = variant.name()?;
42 let variant_name = make::ext::path_from_idents(["Self", &format!("{}", name)])?;
44 match variant.field_list() {
45 // => match self { Self::Name { x } => Self::Name { x: x.clone() } }
46 Some(ast::FieldList::RecordFieldList(list)) => {
47 let mut pats = vec![];
48 let mut fields = vec![];
49 for field in list.fields() {
50 let field_name = field.name()?;
51 let pat = make::ident_pat(false, false, field_name.clone());
52 pats.push(pat.into());
54 let path = make::ext::ident_path(&field_name.to_string());
55 let method_call = gen_clone_call(make::expr_path(path));
56 let name_ref = make::name_ref(&field_name.to_string());
57 let field = make::record_expr_field(name_ref, Some(method_call));
60 let pat = make::record_pat(variant_name.clone(), pats.into_iter());
61 let fields = make::record_expr_field_list(fields);
62 let record_expr = make::record_expr(variant_name, fields).into();
63 arms.push(make::match_arm(Some(pat.into()), None, record_expr));
66 // => match self { Self::Name(arg1) => Self::Name(arg1.clone()) }
67 Some(ast::FieldList::TupleFieldList(list)) => {
68 let mut pats = vec![];
69 let mut fields = vec![];
70 for (i, _) in list.fields().enumerate() {
71 let field_name = format!("arg{}", i);
72 let pat = make::ident_pat(false, false, make::name(&field_name));
73 pats.push(pat.into());
75 let f_path = make::expr_path(make::ext::ident_path(&field_name));
76 fields.push(gen_clone_call(f_path));
78 let pat = make::tuple_struct_pat(variant_name.clone(), pats.into_iter());
79 let struct_name = make::expr_path(variant_name);
80 let tuple_expr = make::expr_call(struct_name, make::arg_list(fields));
81 arms.push(make::match_arm(Some(pat.into()), None, tuple_expr));
84 // => match self { Self::Name => Self::Name }
86 let pattern = make::path_pat(variant_name.clone());
87 let variant_expr = make::expr_path(variant_name);
88 arms.push(make::match_arm(Some(pattern.into()), None, variant_expr));
93 let match_target = make::expr_path(make::ext::ident_path("self"));
94 let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
95 make::expr_match(match_target, list)
97 ast::Adt::Struct(strukt) => {
98 match strukt.field_list() {
99 // => Self { name: self.name.clone() }
100 Some(ast::FieldList::RecordFieldList(field_list)) => {
101 let mut fields = vec![];
102 for field in field_list.fields() {
103 let base = make::expr_path(make::ext::ident_path("self"));
104 let target = make::expr_field(base, &field.name()?.to_string());
105 let method_call = gen_clone_call(target);
106 let name_ref = make::name_ref(&field.name()?.to_string());
107 let field = make::record_expr_field(name_ref, Some(method_call));
110 let struct_name = make::ext::ident_path("Self");
111 let fields = make::record_expr_field_list(fields);
112 make::record_expr(struct_name, fields).into()
114 // => Self(self.0.clone(), self.1.clone())
115 Some(ast::FieldList::TupleFieldList(field_list)) => {
116 let mut fields = vec![];
117 for (i, _) in field_list.fields().enumerate() {
118 let f_path = make::expr_path(make::ext::ident_path("self"));
119 let target = make::expr_field(f_path, &format!("{}", i)).into();
120 fields.push(gen_clone_call(target));
122 let struct_name = make::expr_path(make::ext::ident_path("Self"));
123 make::expr_call(struct_name, make::arg_list(fields))
127 let struct_name = make::ext::ident_path("Self");
128 let fields = make::record_expr_field_list(None);
129 make::record_expr(struct_name, fields).into()
134 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
135 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
139 /// Generate a `Debug` impl based on the fields and members of the target type.
140 fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
141 let annotated_name = adt.name()?;
143 // `Debug` cannot be derived for unions, so no default impl can be provided.
144 ast::Adt::Union(_) => None,
146 // => match self { Self::Variant => write!(f, "Variant") }
147 ast::Adt::Enum(enum_) => {
148 let list = enum_.variant_list()?;
149 let mut arms = vec![];
150 for variant in list.variants() {
151 let name = variant.name()?;
152 let variant_name = make::ext::path_from_idents(["Self", &format!("{}", name)])?;
153 let target = make::expr_path(make::ext::ident_path("f").into());
155 match variant.field_list() {
156 Some(ast::FieldList::RecordFieldList(list)) => {
157 // => f.debug_struct(name)
158 let target = make::expr_path(make::ext::ident_path("f"));
159 let method = make::name_ref("debug_struct");
160 let struct_name = format!("\"{}\"", name);
161 let args = make::arg_list(Some(make::expr_literal(&struct_name).into()));
162 let mut expr = make::expr_method_call(target, method, args);
164 let mut pats = vec![];
165 for field in list.fields() {
166 let field_name = field.name()?;
168 // create a field pattern for use in `MyStruct { fields.. }`
169 let pat = make::ident_pat(false, false, field_name.clone());
170 pats.push(pat.into());
172 // => <expr>.field("field_name", field)
173 let method_name = make::name_ref("field");
174 let name = make::expr_literal(&(format!("\"{}\"", field_name))).into();
175 let path = &format!("{}", field_name);
176 let path = make::expr_path(make::ext::ident_path(path));
177 let args = make::arg_list(vec![name, path]);
178 expr = make::expr_method_call(expr, method_name, args);
181 // => <expr>.finish()
182 let method = make::name_ref("finish");
183 let expr = make::expr_method_call(expr, method, make::arg_list(None));
185 // => MyStruct { fields.. } => f.debug_struct("MyStruct")...finish(),
186 let pat = make::record_pat(variant_name.clone(), pats.into_iter());
187 arms.push(make::match_arm(Some(pat.into()), None, expr));
189 Some(ast::FieldList::TupleFieldList(list)) => {
190 // => f.debug_tuple(name)
191 let target = make::expr_path(make::ext::ident_path("f"));
192 let method = make::name_ref("debug_tuple");
193 let struct_name = format!("\"{}\"", name);
194 let args = make::arg_list(Some(make::expr_literal(&struct_name).into()));
195 let mut expr = make::expr_method_call(target, method, args);
197 let mut pats = vec![];
198 for (i, _) in list.fields().enumerate() {
199 let name = format!("arg{}", i);
201 // create a field pattern for use in `MyStruct(fields..)`
202 let field_name = make::name(&name);
203 let pat = make::ident_pat(false, false, field_name.clone());
204 pats.push(pat.into());
206 // => <expr>.field(field)
207 let method_name = make::name_ref("field");
208 let field_path = &format!("{}", name);
209 let field_path = make::expr_path(make::ext::ident_path(field_path));
210 let args = make::arg_list(vec![field_path]);
211 expr = make::expr_method_call(expr, method_name, args);
214 // => <expr>.finish()
215 let method = make::name_ref("finish");
216 let expr = make::expr_method_call(expr, method, make::arg_list(None));
218 // => MyStruct (fields..) => f.debug_tuple("MyStruct")...finish(),
219 let pat = make::tuple_struct_pat(variant_name.clone(), pats.into_iter());
220 arms.push(make::match_arm(Some(pat.into()), None, expr));
223 let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into();
224 let args = make::arg_list([target, fmt_string]);
225 let macro_name = make::expr_path(make::ext::ident_path("write"));
226 let macro_call = make::expr_macro_call(macro_name, args);
228 let variant_name = make::path_pat(variant_name);
229 arms.push(make::match_arm(
230 Some(variant_name.into()),
238 let match_target = make::expr_path(make::ext::ident_path("self"));
239 let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
240 let match_expr = make::expr_match(match_target, list);
242 let body = make::block_expr(None, Some(match_expr));
243 let body = body.indent(ast::edit::IndentLevel(1));
244 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
248 ast::Adt::Struct(strukt) => {
249 let name = format!("\"{}\"", annotated_name);
250 let args = make::arg_list(Some(make::expr_literal(&name).into()));
251 let target = make::expr_path(make::ext::ident_path("f"));
253 let expr = match strukt.field_list() {
254 // => f.debug_struct("Name").finish()
255 None => make::expr_method_call(target, make::name_ref("debug_struct"), args),
257 // => f.debug_struct("Name").field("foo", &self.foo).finish()
258 Some(ast::FieldList::RecordFieldList(field_list)) => {
259 let method = make::name_ref("debug_struct");
260 let mut expr = make::expr_method_call(target, method, args);
261 for field in field_list.fields() {
262 let name = field.name()?;
263 let f_name = make::expr_literal(&(format!("\"{}\"", name))).into();
264 let f_path = make::expr_path(make::ext::ident_path("self"));
265 let f_path = make::expr_ref(f_path, false);
266 let f_path = make::expr_field(f_path, &format!("{}", name)).into();
267 let args = make::arg_list([f_name, f_path]);
268 expr = make::expr_method_call(expr, make::name_ref("field"), args);
273 // => f.debug_tuple("Name").field(self.0).finish()
274 Some(ast::FieldList::TupleFieldList(field_list)) => {
275 let method = make::name_ref("debug_tuple");
276 let mut expr = make::expr_method_call(target, method, args);
277 for (i, _) in field_list.fields().enumerate() {
278 let f_path = make::expr_path(make::ext::ident_path("self"));
279 let f_path = make::expr_ref(f_path, false);
280 let f_path = make::expr_field(f_path, &format!("{}", i)).into();
281 let method = make::name_ref("field");
282 expr = make::expr_method_call(expr, method, make::arg_list(Some(f_path)));
288 let method = make::name_ref("finish");
289 let expr = make::expr_method_call(expr, method, make::arg_list(None));
290 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
291 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
297 /// Generate a `Debug` impl based on the fields and members of the target type.
298 fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
299 fn gen_default_call() -> Option<ast::Expr> {
300 let fn_name = make::ext::path_from_idents(["Default", "default"])?;
301 Some(make::expr_call(make::expr_path(fn_name), make::arg_list(None)))
304 // `Debug` cannot be derived for unions, so no default impl can be provided.
305 ast::Adt::Union(_) => None,
306 // Deriving `Debug` for enums is not stable yet.
307 ast::Adt::Enum(_) => None,
308 ast::Adt::Struct(strukt) => {
309 let expr = match strukt.field_list() {
310 Some(ast::FieldList::RecordFieldList(field_list)) => {
311 let mut fields = vec![];
312 for field in field_list.fields() {
313 let method_call = gen_default_call()?;
314 let name_ref = make::name_ref(&field.name()?.to_string());
315 let field = make::record_expr_field(name_ref, Some(method_call));
318 let struct_name = make::ext::ident_path("Self");
319 let fields = make::record_expr_field_list(fields);
320 make::record_expr(struct_name, fields).into()
322 Some(ast::FieldList::TupleFieldList(field_list)) => {
323 let struct_name = make::expr_path(make::ext::ident_path("Self"));
324 let fields = field_list
326 .map(|_| gen_default_call())
327 .collect::<Option<Vec<ast::Expr>>>()?;
328 make::expr_call(struct_name, make::arg_list(fields))
331 let struct_name = make::ext::ident_path("Self");
332 let fields = make::record_expr_field_list(None);
333 make::record_expr(struct_name, fields).into()
336 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
337 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
343 /// Generate a `Hash` impl based on the fields and members of the target type.
344 fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
345 fn gen_hash_call(target: ast::Expr) -> ast::Stmt {
346 let method = make::name_ref("hash");
347 let arg = make::expr_path(make::ext::ident_path("state"));
348 let expr = make::expr_method_call(target, method, make::arg_list(Some(arg)));
349 make::expr_stmt(expr).into()
352 let body = match adt {
353 // `Hash` cannot be derived for unions, so no default impl can be provided.
354 ast::Adt::Union(_) => return None,
356 // => std::mem::discriminant(self).hash(state);
357 ast::Adt::Enum(_) => {
358 let fn_name = make_discriminant()?;
360 let arg = make::expr_path(make::ext::ident_path("self"));
361 let fn_call = make::expr_call(fn_name, make::arg_list(Some(arg)));
362 let stmt = gen_hash_call(fn_call);
364 make::block_expr(Some(stmt), None).indent(ast::edit::IndentLevel(1))
366 ast::Adt::Struct(strukt) => match strukt.field_list() {
367 // => self.<field>.hash(state);
368 Some(ast::FieldList::RecordFieldList(field_list)) => {
369 let mut stmts = vec![];
370 for field in field_list.fields() {
371 let base = make::expr_path(make::ext::ident_path("self"));
372 let target = make::expr_field(base, &field.name()?.to_string());
373 stmts.push(gen_hash_call(target));
375 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
378 // => self.<field_index>.hash(state);
379 Some(ast::FieldList::TupleFieldList(field_list)) => {
380 let mut stmts = vec![];
381 for (i, _) in field_list.fields().enumerate() {
382 let base = make::expr_path(make::ext::ident_path("self"));
383 let target = make::expr_field(base, &format!("{}", i));
384 stmts.push(gen_hash_call(target));
386 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
389 // No fields in the body means there's nothing to hash.
394 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
398 /// Generate a `PartialEq` impl based on the fields and members of the target type.
399 fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
400 fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
402 Some(expr) => Some(make::expr_bin_op(expr, BinaryOp::LogicOp(LogicOp::And), cmp)),
407 fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField {
408 let pat = make::ext::simple_ident_pat(make::name(&pat_name));
409 let name_ref = make::name_ref(field_name);
410 make::record_pat_field(name_ref, pat.into())
413 fn gen_record_pat(record_name: ast::Path, fields: Vec<ast::RecordPatField>) -> ast::RecordPat {
414 let list = make::record_pat_field_list(fields);
415 make::record_pat_with_fields(record_name, list)
418 fn gen_variant_path(variant: &ast::Variant) -> Option<ast::Path> {
419 make::ext::path_from_idents(["Self", &variant.name()?.to_string()])
422 fn gen_tuple_field(field_name: &String) -> ast::Pat {
423 ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name)))
426 // FIXME: return `None` if the trait carries a generic type; we can only
427 // generate this code `Self` for the time being.
429 let body = match adt {
430 // `Hash` cannot be derived for unions, so no default impl can be provided.
431 ast::Adt::Union(_) => return None,
433 ast::Adt::Enum(enum_) => {
434 // => std::mem::discriminant(self) == std::mem::discriminant(other)
435 let lhs_name = make::expr_path(make::ext::ident_path("self"));
436 let lhs = make::expr_call(make_discriminant()?, make::arg_list(Some(lhs_name.clone())));
437 let rhs_name = make::expr_path(make::ext::ident_path("other"));
438 let rhs = make::expr_call(make_discriminant()?, make::arg_list(Some(rhs_name.clone())));
440 make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
443 let mut arms = vec![];
444 for variant in enum_.variant_list()?.variants() {
446 match variant.field_list() {
447 // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
448 Some(ast::FieldList::RecordFieldList(list)) => {
450 let mut l_fields = vec![];
451 let mut r_fields = vec![];
453 for field in list.fields() {
454 let field_name = field.name()?.to_string();
456 let l_name = &format!("l_{}", field_name);
457 l_fields.push(gen_record_pat_field(&field_name, &l_name));
459 let r_name = &format!("r_{}", field_name);
460 r_fields.push(gen_record_pat_field(&field_name, &r_name));
462 let lhs = make::expr_path(make::ext::ident_path(l_name));
463 let rhs = make::expr_path(make::ext::ident_path(r_name));
464 let cmp = make::expr_bin_op(
466 BinaryOp::CmpOp(CmpOp::Eq { negated: false }),
469 expr = gen_eq_chain(expr, cmp);
472 let left = gen_record_pat(gen_variant_path(&variant)?, l_fields);
473 let right = gen_record_pat(gen_variant_path(&variant)?, r_fields);
474 let tuple = make::tuple_pat(vec![left.into(), right.into()]);
476 if let Some(expr) = expr {
477 arms.push(make::match_arm(Some(tuple.into()), None, expr));
481 Some(ast::FieldList::TupleFieldList(list)) => {
483 let mut l_fields = vec![];
484 let mut r_fields = vec![];
486 for (i, _) in list.fields().enumerate() {
487 let field_name = format!("{}", i);
489 let l_name = format!("l{}", field_name);
490 l_fields.push(gen_tuple_field(&l_name));
492 let r_name = format!("r{}", field_name);
493 r_fields.push(gen_tuple_field(&r_name));
495 let lhs = make::expr_path(make::ext::ident_path(&l_name));
496 let rhs = make::expr_path(make::ext::ident_path(&r_name));
497 let cmp = make::expr_bin_op(
499 BinaryOp::CmpOp(CmpOp::Eq { negated: false }),
502 expr = gen_eq_chain(expr, cmp);
505 let left = make::tuple_struct_pat(gen_variant_path(&variant)?, l_fields);
506 let right = make::tuple_struct_pat(gen_variant_path(&variant)?, r_fields);
507 let tuple = make::tuple_pat(vec![left.into(), right.into()]);
509 if let Some(expr) = expr {
510 arms.push(make::match_arm(Some(tuple.into()), None, expr));
517 let expr = match arms.len() {
520 if n_cases > arms.len() {
521 let lhs = make::wildcard_pat().into();
522 arms.push(make::match_arm(Some(lhs), None, eq_check));
525 let match_target = make::expr_tuple(vec![lhs_name, rhs_name]);
526 let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
527 make::expr_match(match_target, list)
531 make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
533 ast::Adt::Struct(strukt) => match strukt.field_list() {
534 Some(ast::FieldList::RecordFieldList(field_list)) => {
536 for field in field_list.fields() {
537 let lhs = make::expr_path(make::ext::ident_path("self"));
538 let lhs = make::expr_field(lhs, &field.name()?.to_string());
539 let rhs = make::expr_path(make::ext::ident_path("other"));
540 let rhs = make::expr_field(rhs, &field.name()?.to_string());
542 make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
543 expr = gen_eq_chain(expr, cmp);
545 make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
548 Some(ast::FieldList::TupleFieldList(field_list)) => {
550 for (i, _) in field_list.fields().enumerate() {
551 let idx = format!("{}", i);
552 let lhs = make::expr_path(make::ext::ident_path("self"));
553 let lhs = make::expr_field(lhs, &idx);
554 let rhs = make::expr_path(make::ext::ident_path("other"));
555 let rhs = make::expr_field(rhs, &idx);
557 make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
558 expr = gen_eq_chain(expr, cmp);
560 make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
563 // No fields in the body means there's nothing to hash.
565 let expr = make::expr_literal("true").into();
566 make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
571 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
575 fn make_discriminant() -> Option<ast::Expr> {
576 Some(make::expr_path(make::ext::path_from_idents(["core", "mem", "discriminant"])?))