1 //! This module contains functions to generate default trait impl function bodies where possible.
4 ast::{self, edit::AstNodeEdit, make, AstNode, NameOwner},
8 /// Generate custom trait bodies where possible.
10 /// Returns `Option` so that we can use `?` rather than `if let Some`. Returning
11 /// `None` means that generating a custom trait body failed, and the body will remain
12 /// as `todo!` instead.
///
/// The generator is chosen by the text of `trait_path`'s segment name (e.g. `Clone`,
/// `PartialEq`); each generator replaces the body of `func` in place when it succeeds.
13 pub(crate) fn gen_trait_fn_body(
15 trait_path: &ast::Path,
18 match trait_path.segment()?.name_ref()?.text().as_str() {
// One dedicated generator per supported trait name.
19 "Clone" => gen_clone_impl(adt, func),
20 "Debug" => gen_debug_impl(adt, func),
21 "Default" => gen_default_impl(adt, func),
22 "Hash" => gen_hash_impl(adt, func),
23 "PartialEq" => gen_partial_eq(adt, func),
28 /// Generate a `Clone` impl based on the fields and members of the target type.
///
/// Unions are rejected (`None`). Enums get a `match self` that rebuilds every
/// variant from cloned bindings; structs get a constructor expression that clones
/// each field of `self`.
29 fn gen_clone_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
// Builds the method-call expression `<target>.clone()`.
30 fn gen_clone_call(target: ast::Expr) -> ast::Expr {
31 let method = make::name_ref("clone");
32 make::expr_method_call(target, method, make::arg_list(None))
34 let expr = match adt {
35 // `Clone` cannot be derived for unions, so no default impl can be provided.
36 ast::Adt::Union(_) => return None,
37 ast::Adt::Enum(enum_) => {
38 let list = enum_.variant_list()?;
39 let mut arms = vec![];
40 for variant in list.variants() {
41 let name = variant.name()?;
// Qualified path `Self::<Variant>`, used both in the arm pattern and in the
// expression that rebuilds the variant.
42 let left = make::ext::ident_path("Self");
43 let right = make::ext::ident_path(&format!("{}", name));
44 let variant_name = make::path_concat(left, right);
46 match variant.field_list() {
47 // => match self { Self::Name { x } => Self::Name { x: x.clone() } }
48 Some(ast::FieldList::RecordFieldList(list)) => {
49 let mut pats = vec![];
50 let mut fields = vec![];
51 for field in list.fields() {
52 let field_name = field.name()?;
// Bind the field under its own name in the pattern (`{ x }`) ...
53 let pat = make::ident_pat(false, false, field_name.clone());
54 pats.push(pat.into());
// ... and rebuild it as `x: x.clone()` in the record expression.
56 let path = make::ext::ident_path(&field_name.to_string());
57 let method_call = gen_clone_call(make::expr_path(path));
58 let name_ref = make::name_ref(&field_name.to_string());
59 let field = make::record_expr_field(name_ref, Some(method_call));
62 let pat = make::record_pat(variant_name.clone(), pats.into_iter());
63 let fields = make::record_expr_field_list(fields);
64 let record_expr = make::record_expr(variant_name, fields).into();
65 arms.push(make::match_arm(Some(pat.into()), None, record_expr));
68 // => match self { Self::Name(arg0) => Self::Name(arg0.clone()) }
69 Some(ast::FieldList::TupleFieldList(list)) => {
70 let mut pats = vec![];
71 let mut fields = vec![];
// Positional fields are bound as `arg0`, `arg1`, ... in declaration order.
72 for (i, _) in list.fields().enumerate() {
73 let field_name = format!("arg{}", i);
74 let pat = make::ident_pat(false, false, make::name(&field_name));
75 pats.push(pat.into());
77 let f_path = make::expr_path(make::ext::ident_path(&field_name));
78 fields.push(gen_clone_call(f_path));
80 let pat = make::tuple_struct_pat(variant_name.clone(), pats.into_iter());
81 let struct_name = make::expr_path(variant_name);
82 let tuple_expr = make::expr_call(struct_name, make::arg_list(fields));
83 arms.push(make::match_arm(Some(pat.into()), None, tuple_expr));
86 // => match self { Self::Name => Self::Name }
88 let pattern = make::path_pat(variant_name.clone());
89 let variant_expr = make::expr_path(variant_name);
90 arms.push(make::match_arm(Some(pattern.into()), None, variant_expr));
// All arms are collected; emit `match self { ... }` with the arm list indented.
95 let match_target = make::expr_path(make::ext::ident_path("self"));
96 let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
97 make::expr_match(match_target, list)
99 ast::Adt::Struct(strukt) => {
100 match strukt.field_list() {
101 // => Self { name: self.name.clone() }
102 Some(ast::FieldList::RecordFieldList(field_list)) => {
103 let mut fields = vec![];
104 for field in field_list.fields() {
105 let base = make::expr_path(make::ext::ident_path("self"));
106 let target = make::expr_field(base, &field.name()?.to_string());
107 let method_call = gen_clone_call(target);
108 let name_ref = make::name_ref(&field.name()?.to_string());
109 let field = make::record_expr_field(name_ref, Some(method_call));
112 let struct_name = make::ext::ident_path("Self");
113 let fields = make::record_expr_field_list(fields);
114 make::record_expr(struct_name, fields).into()
116 // => Self(self.0.clone(), self.1.clone())
117 Some(ast::FieldList::TupleFieldList(field_list)) => {
118 let mut fields = vec![];
119 for (i, _) in field_list.fields().enumerate() {
120 let f_path = make::expr_path(make::ext::ident_path("self"));
121 let target = make::expr_field(f_path, &format!("{}", i)).into();
122 fields.push(gen_clone_call(target));
124 let struct_name = make::expr_path(make::ext::ident_path("Self"));
125 make::expr_call(struct_name, make::arg_list(fields))
// Fallback: an empty record expression `Self {}`.
129 let struct_name = make::ext::ident_path("Self");
130 let fields = make::record_expr_field_list(None);
131 make::record_expr(struct_name, fields).into()
// Wrap the expression as the function's block body and splice it over the
// existing body; `?` bails out with `None` when `func` has no body.
136 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
137 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
141 /// Generate a `Debug` impl based on the fields and members of the target type.
///
/// Unions are rejected (`None`). Enums get a `match self` whose arms `write!` the
/// variant name; structs are rendered with `f.debug_struct(..)` / `f.debug_tuple(..)`
/// followed by one `.field(..)` call per field and a final `.finish()`.
142 fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
// The type's own name; for structs it becomes the string-literal label passed to
// `debug_struct` / `debug_tuple`.
143 let annotated_name = adt.name()?;
145 // `Debug` cannot be derived for unions, so no default impl can be provided.
146 ast::Adt::Union(_) => None,
148 // => match self { Self::Variant => write!(f, "Variant") }
149 ast::Adt::Enum(enum_) => {
150 let list = enum_.variant_list()?;
151 let mut arms = vec![];
152 for variant in list.variants() {
153 let name = variant.name()?;
// NOTE(review): the variant's field list is never inspected — only its name is
// written, and the arm pattern is the bare path `Self::Variant`, which presumably
// does not match tuple or record variants. Confirm before relying on this for enums
// with data-carrying variants.
154 let left = make::ext::ident_path("Self");
155 let right = make::ext::ident_path(&format!("{}", name));
156 let variant_name = make::path_pat(make::path_concat(left, right));
// Arm body: `write!(f, "<Variant>")`.
158 let target = make::expr_path(make::ext::ident_path("f").into());
159 let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into();
160 let args = make::arg_list(vec![target, fmt_string]);
161 let macro_name = make::expr_path(make::ext::ident_path("write"));
162 let macro_call = make::expr_macro_call(macro_name, args);
164 arms.push(make::match_arm(Some(variant_name.into()), None, macro_call.into()));
167 let match_target = make::expr_path(make::ext::ident_path("self"));
168 let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
169 let match_expr = make::expr_match(match_target, list);
// The enum case splices its own body here instead of falling through to the
// struct path's `.finish()` handling.
171 let body = make::block_expr(None, Some(match_expr));
172 let body = body.indent(ast::edit::IndentLevel(1));
173 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
177 ast::Adt::Struct(strukt) => {
// `args` is the single quoted-name argument, e.g. `("Name")`.
178 let name = format!("\"{}\"", annotated_name);
179 let args = make::arg_list(Some(make::expr_literal(&name).into()));
180 let target = make::expr_path(make::ext::ident_path("f"));
182 let expr = match strukt.field_list() {
183 // => f.debug_struct("Name").finish()
184 None => make::expr_method_call(target, make::name_ref("debug_struct"), args),
186 // => f.debug_struct("Name").field("foo", &self.foo).finish()
187 Some(ast::FieldList::RecordFieldList(field_list)) => {
188 let method = make::name_ref("debug_struct");
189 let mut expr = make::expr_method_call(target, method, args);
190 for field in field_list.fields() {
191 let name = field.name()?;
192 let f_name = make::expr_literal(&(format!("\"{}\"", name))).into();
// Builds the `&self.<field>` argument.
193 let f_path = make::expr_path(make::ext::ident_path("self"));
194 let f_path = make::expr_ref(f_path, false);
195 let f_path = make::expr_field(f_path, &format!("{}", name)).into();
196 let args = make::arg_list(vec![f_name, f_path]);
197 expr = make::expr_method_call(expr, make::name_ref("field"), args);
202 // => f.debug_tuple("Name").field(&self.0).finish()
203 Some(ast::FieldList::TupleFieldList(field_list)) => {
204 let method = make::name_ref("debug_tuple");
205 let mut expr = make::expr_method_call(target, method, args);
206 for (i, _) in field_list.fields().enumerate() {
// Builds the `&self.<index>` argument.
207 let f_path = make::expr_path(make::ext::ident_path("self"));
208 let f_path = make::expr_ref(f_path, false);
209 let f_path = make::expr_field(f_path, &format!("{}", i)).into();
210 let method = make::name_ref("field");
211 expr = make::expr_method_call(expr, method, make::arg_list(Some(f_path)));
// Every struct shape is terminated with `.finish()` before being spliced in.
217 let method = make::name_ref("finish");
218 let expr = make::expr_method_call(expr, method, make::arg_list(None));
219 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
220 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
226 /// Generate a `Default` impl based on the fields and members of the target type.
///
/// Only structs are handled: every field is initialized with `Default::default()`.
/// Unions and enums yield `None`, leaving the body untouched.
227 fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
// Builds the call expression `Default::default()`.
228 fn gen_default_call() -> ast::Expr {
229 let trait_name = make::ext::ident_path("Default");
230 let method_name = make::ext::ident_path("default");
231 let fn_name = make::expr_path(make::path_concat(trait_name, method_name));
232 make::expr_call(fn_name, make::arg_list(None))
235 // `Default` cannot be derived for unions, so no default impl can be provided.
236 ast::Adt::Union(_) => None,
237 // Deriving `Default` for enums was not stable when this was written, so enums are skipped.
238 ast::Adt::Enum(_) => None,
239 ast::Adt::Struct(strukt) => {
240 let expr = match strukt.field_list() {
// => Self { name: Default::default() }
241 Some(ast::FieldList::RecordFieldList(field_list)) => {
242 let mut fields = vec![];
243 for field in field_list.fields() {
244 let method_call = gen_default_call();
245 let name_ref = make::name_ref(&field.name()?.to_string());
246 let field = make::record_expr_field(name_ref, Some(method_call));
249 let struct_name = make::ext::ident_path("Self");
250 let fields = make::record_expr_field_list(fields);
251 make::record_expr(struct_name, fields).into()
// => Self(Default::default(), Default::default())
253 Some(ast::FieldList::TupleFieldList(field_list)) => {
254 let struct_name = make::expr_path(make::ext::ident_path("Self"));
255 let fields = field_list.fields().map(|_| gen_default_call());
256 make::expr_call(struct_name, make::arg_list(fields))
// Fallback: an empty record expression `Self {}`.
259 let struct_name = make::ext::ident_path("Self");
260 let fields = make::record_expr_field_list(None);
261 make::record_expr(struct_name, fields).into()
// Splice the generated expression over the existing function body.
264 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
265 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
271 /// Generate a `Hash` impl based on the fields and members of the target type.
///
/// The generated body is a sequence of `<value>.hash(state);` statements. Unions are
/// rejected (`None`).
272 fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
// Builds the statement `<target>.hash(state);`.
273 fn gen_hash_call(target: ast::Expr) -> ast::Stmt {
274 let method = make::name_ref("hash");
275 let arg = make::expr_path(make::ext::ident_path("state"));
276 let expr = make::expr_method_call(target, method, make::arg_list(Some(arg)));
277 let stmt = make::expr_stmt(expr);
281 let body = match adt {
282 // `Hash` cannot be derived for unions, so no default impl can be provided.
283 ast::Adt::Union(_) => return None,
285 // => core::mem::discriminant(self).hash(state);
// Only the discriminant is hashed; the variants' fields are not inspected.
286 ast::Adt::Enum(_) => {
287 let root = make::ext::ident_path("core");
288 let submodule = make::ext::ident_path("mem");
289 let fn_name = make::ext::ident_path("discriminant");
290 let fn_name = make::path_concat(submodule, fn_name);
291 let fn_name = make::expr_path(make::path_concat(root, fn_name));
293 let arg = make::expr_path(make::ext::ident_path("self"));
294 let fn_call = make::expr_call(fn_name, make::arg_list(Some(arg)));
295 let stmt = gen_hash_call(fn_call);
297 make::block_expr(Some(stmt), None).indent(ast::edit::IndentLevel(1))
299 ast::Adt::Struct(strukt) => match strukt.field_list() {
300 // => self.<field>.hash(state);
301 Some(ast::FieldList::RecordFieldList(field_list)) => {
302 let mut stmts = vec![];
303 for field in field_list.fields() {
304 let base = make::expr_path(make::ext::ident_path("self"));
305 let target = make::expr_field(base, &field.name()?.to_string());
306 stmts.push(gen_hash_call(target));
308 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
311 // => self.<field_index>.hash(state);
312 Some(ast::FieldList::TupleFieldList(field_list)) => {
313 let mut stmts = vec![];
314 for (i, _) in field_list.fields().enumerate() {
315 let base = make::expr_path(make::ext::ident_path("self"));
316 let target = make::expr_field(base, &format!("{}", i));
317 stmts.push(gen_hash_call(target));
319 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
322 // No fields in the body means there's nothing to hash.
// Splice the generated block over the existing function body.
327 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
331 /// Generate a `PartialEq` impl based on the fields and members of the target type.
///
/// Structs compare every field of `self` and `other` joined with `&&`. Enums first
/// compare discriminants, then `match (self, other)` to compare the fields of the
/// matching variant pair. Unions are rejected (`None`).
332 fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
// Builds the path expression `core::mem::discriminant`.
333 fn gen_discriminant() -> ast::Expr {
334 let root = make::ext::ident_path("core");
335 let submodule = make::ext::ident_path("mem");
336 let fn_name = make::ext::ident_path("discriminant");
337 let fn_name = make::path_concat(submodule, fn_name);
338 let fn_name = make::expr_path(make::path_concat(root, fn_name));
// Extends an `&&` chain: when a previous comparison exists, yields `expr && cmp`.
342 fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
344 Some(expr) => Some(make::expr_op(ast::BinOp::BooleanAnd, expr, cmp)),
// Builds the record-pattern field `field_name: pat_name`.
349 fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField {
350 let pat = make::ext::simple_ident_pat(make::name(&pat_name));
351 let name_ref = make::name_ref(field_name);
352 let field = make::record_pat_field(name_ref, pat.into());
357 record_name: ast::Path,
358 r_fields: Vec<ast::RecordPatField>,
359 ) -> ast::RecordPat {
// Builds the record pattern `record_name { <fields> }`.
360 let list = make::record_pat_field_list(r_fields);
361 make::record_pat_with_fields(record_name, list)
// Builds the path `Self::<Variant>`; `None` if the variant has no name.
364 fn gen_variant_path(variant: &ast::Variant) -> Option<ast::Path> {
365 let first = make::ext::ident_path("Self");
366 let second = make::path_from_text(&variant.name()?.to_string());
367 let record_name = make::path_concat(first, second);
// Builds a plain identifier pattern, used for tuple-variant bindings like `l0`.
371 fn gen_tuple_field(field_name: &String) -> ast::Pat {
372 ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name)))
374 // FIXME: return `None` if the trait carries a generic type; we can only
375 // generate this code for `Self` for the time being.
377 let body = match adt {
378 // `PartialEq` cannot be derived for unions, so no default impl can be provided.
379 ast::Adt::Union(_) => return None,
381 ast::Adt::Enum(enum_) => {
382 // => core::mem::discriminant(self) == core::mem::discriminant(other)
383 let self_name = make::expr_path(make::ext::ident_path("self"));
384 let lhs = make::expr_call(gen_discriminant(), make::arg_list(Some(self_name.clone())));
385 let other_name = make::expr_path(make::ext::ident_path("other"));
386 let rhs = make::expr_call(gen_discriminant(), make::arg_list(Some(other_name.clone())));
387 let eq_check = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
// `case_count` presumably counts every variant visited, while `arms` only grows for
// variants that have at least one field to compare; the difference decides whether
// a wildcard arm is needed below.
389 let mut case_count = 0;
390 let mut arms = vec![];
391 for variant in enum_.variant_list()?.variants() {
393 match variant.field_list() {
394 // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
395 Some(ast::FieldList::RecordFieldList(list)) => {
397 let mut l_fields = vec![];
398 let mut r_fields = vec![];
// Each field is bound as `l_<name>` on the left and `r_<name>` on the right,
// and the pair is appended to the `&&` chain.
400 for field in list.fields() {
401 let field_name = field.name()?.to_string();
403 let l_name = &format!("l_{}", field_name);
404 l_fields.push(gen_record_pat_field(&field_name, &l_name));
406 let r_name = &format!("r_{}", field_name);
407 r_fields.push(gen_record_pat_field(&field_name, &r_name));
409 let lhs = make::expr_path(make::ext::ident_path(l_name));
410 let rhs = make::expr_path(make::ext::ident_path(r_name));
411 let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
412 expr = gen_eq_chain(expr, cmp);
415 let left = gen_record_pat(gen_variant_path(&variant)?, l_fields);
416 let right = gen_record_pat(gen_variant_path(&variant)?, r_fields);
417 let tuple = make::tuple_pat(vec![left.into(), right.into()]);
// A variant with no fields produces no comparison and therefore no arm.
419 if let Some(expr) = expr {
420 arms.push(make::match_arm(Some(tuple.into()), None, expr));
424 // => (Self::Bar(l0), Self::Bar(r0)) => l0 == r0,
425 Some(ast::FieldList::TupleFieldList(list)) => {
427 let mut l_fields = vec![];
428 let mut r_fields = vec![];
430 for (i, _) in list.fields().enumerate() {
431 let field_name = format!("{}", i);
433 let l_name = format!("l{}", field_name);
434 l_fields.push(gen_tuple_field(&l_name));
436 let r_name = format!("r{}", field_name);
437 r_fields.push(gen_tuple_field(&r_name));
439 let lhs = make::expr_path(make::ext::ident_path(&l_name));
440 let rhs = make::expr_path(make::ext::ident_path(&r_name));
441 let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
442 expr = gen_eq_chain(expr, cmp);
445 let left = make::tuple_struct_pat(gen_variant_path(&variant)?, l_fields);
446 let right = make::tuple_struct_pat(gen_variant_path(&variant)?, r_fields);
447 let tuple = make::tuple_pat(vec![left.into(), right.into()]);
449 if let Some(expr) = expr {
450 arms.push(make::match_arm(Some(tuple.into()), None, expr));
// Some variants produced no arm (nothing to compare). The match only runs once the
// discriminants are known to be equal (see the `if` built below), so a `_ => true`
// arm covers those variants.
457 if !arms.is_empty() && case_count > arms.len() {
458 let lhs = make::wildcard_pat().into();
459 arms.push(make::match_arm(Some(lhs), None, make::expr_literal("true").into()));
462 let expr = match arms.len() {
// Build `if <discriminants equal> { match (self, other) { .. } } else { false }`.
465 let condition = make::condition(eq_check, None);
467 let match_target = make::expr_tuple(vec![self_name, other_name]);
468 let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
469 let match_expr = Some(make::expr_match(match_target, list));
470 let then_branch = make::block_expr(None, match_expr);
471 let then_branch = then_branch.indent(ast::edit::IndentLevel(1));
473 let else_branche = make::expr_literal("false");
474 let else_branche = make::block_expr(None, Some(else_branche.into()))
475 .indent(ast::edit::IndentLevel(1));
477 make::expr_if(condition, then_branch, Some(else_branche.into()))
481 make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
483 ast::Adt::Struct(strukt) => match strukt.field_list() {
// => self.a == other.a && self.b == other.b
484 Some(ast::FieldList::RecordFieldList(field_list)) => {
486 for field in field_list.fields() {
487 let lhs = make::expr_path(make::ext::ident_path("self"));
488 let lhs = make::expr_field(lhs, &field.name()?.to_string());
489 let rhs = make::expr_path(make::ext::ident_path("other"));
490 let rhs = make::expr_field(rhs, &field.name()?.to_string());
491 let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
492 expr = gen_eq_chain(expr, cmp);
// NOTE(review): with an empty `{}` field list `expr` presumably stays `None`, so the
// body becomes an empty block rather than `true` — confirm and handle if needed.
494 make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
// => self.0 == other.0 && self.1 == other.1
497 Some(ast::FieldList::TupleFieldList(field_list)) => {
499 for (i, _) in field_list.fields().enumerate() {
500 let idx = format!("{}", i);
501 let lhs = make::expr_path(make::ext::ident_path("self"));
502 let lhs = make::expr_field(lhs, &idx);
503 let rhs = make::expr_path(make::ext::ident_path("other"));
504 let rhs = make::expr_field(rhs, &idx);
505 let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
506 expr = gen_eq_chain(expr, cmp);
// NOTE(review): same empty-field-list concern as the record case above (`()`).
508 make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
511 // No fields in the body means there's nothing to compare, so the values are always equal.
513 let expr = make::expr_literal("true").into();
514 make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
// Splice the generated block over the existing function body.
519 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());