1 //! This module contains functions to generate default trait impl function bodies where possible.
4 ast::{self, edit::AstNodeEdit, make, AstNode, NameOwner},
8 /// Generate custom trait bodies where possible.
10 /// Returns `Option` so that we can use `?` rather than `if let Some`. Returning
11 /// `None` means that generating a custom trait body failed, and the body will remain
12 /// as `todo!` instead.
13 pub(crate) fn gen_trait_fn_body(
15 trait_path: &ast::Path,
18 match trait_path.segment()?.name_ref()?.text().as_str() {
19 "Clone" => gen_clone_impl(adt, func),
20 "Debug" => gen_debug_impl(adt, func),
21 "Default" => gen_default_impl(adt, func),
22 "Hash" => gen_hash_impl(adt, func),
23 "PartialEq" => gen_partial_eq(adt, func),
28 /// Generate a `Clone` impl based on the fields and members of the target type.
29 fn gen_clone_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
30 fn gen_clone_call(target: ast::Expr) -> ast::Expr {
31 let method = make::name_ref("clone");
32 make::expr_method_call(target, method, make::arg_list(None))
34 let expr = match adt {
35 // `Clone` cannot be derived for unions, so no default impl can be provided.
36 ast::Adt::Union(_) => return None,
37 ast::Adt::Enum(enum_) => {
38 let list = enum_.variant_list()?;
39 let mut arms = vec![];
40 for variant in list.variants() {
41 let variant_name = make_variant_path(&variant)?;
43 match variant.field_list() {
44 // => match self { Self::Name { x } => Self::Name { x: x.clone() } }
45 Some(ast::FieldList::RecordFieldList(list)) => {
46 let mut pats = vec![];
47 let mut fields = vec![];
48 for field in list.fields() {
49 let field_name = field.name()?;
50 let pat = make::ident_pat(false, false, field_name.clone());
51 pats.push(pat.into());
53 let path = make::ext::ident_path(&field_name.to_string());
54 let method_call = gen_clone_call(make::expr_path(path));
55 let name_ref = make::name_ref(&field_name.to_string());
56 let field = make::record_expr_field(name_ref, Some(method_call));
59 let pat = make::record_pat(variant_name.clone(), pats.into_iter());
60 let fields = make::record_expr_field_list(fields);
61 let record_expr = make::record_expr(variant_name, fields).into();
62 arms.push(make::match_arm(Some(pat.into()), None, record_expr));
65 // => match self { Self::Name(arg1) => Self::Name(arg1.clone()) }
66 Some(ast::FieldList::TupleFieldList(list)) => {
67 let mut pats = vec![];
68 let mut fields = vec![];
69 for (i, _) in list.fields().enumerate() {
70 let field_name = format!("arg{}", i);
71 let pat = make::ident_pat(false, false, make::name(&field_name));
72 pats.push(pat.into());
74 let f_path = make::expr_path(make::ext::ident_path(&field_name));
75 fields.push(gen_clone_call(f_path));
77 let pat = make::tuple_struct_pat(variant_name.clone(), pats.into_iter());
78 let struct_name = make::expr_path(variant_name);
79 let tuple_expr = make::expr_call(struct_name, make::arg_list(fields));
80 arms.push(make::match_arm(Some(pat.into()), None, tuple_expr));
83 // => match self { Self::Name => Self::Name }
85 let pattern = make::path_pat(variant_name.clone());
86 let variant_expr = make::expr_path(variant_name);
87 arms.push(make::match_arm(Some(pattern.into()), None, variant_expr));
92 let match_target = make::expr_path(make::ext::ident_path("self"));
93 let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
94 make::expr_match(match_target, list)
96 ast::Adt::Struct(strukt) => {
97 match strukt.field_list() {
98 // => Self { name: self.name.clone() }
99 Some(ast::FieldList::RecordFieldList(field_list)) => {
100 let mut fields = vec![];
101 for field in field_list.fields() {
102 let base = make::expr_path(make::ext::ident_path("self"));
103 let target = make::expr_field(base, &field.name()?.to_string());
104 let method_call = gen_clone_call(target);
105 let name_ref = make::name_ref(&field.name()?.to_string());
106 let field = make::record_expr_field(name_ref, Some(method_call));
109 let struct_name = make::ext::ident_path("Self");
110 let fields = make::record_expr_field_list(fields);
111 make::record_expr(struct_name, fields).into()
113 // => Self(self.0.clone(), self.1.clone())
114 Some(ast::FieldList::TupleFieldList(field_list)) => {
115 let mut fields = vec![];
116 for (i, _) in field_list.fields().enumerate() {
117 let f_path = make::expr_path(make::ext::ident_path("self"));
118 let target = make::expr_field(f_path, &format!("{}", i)).into();
119 fields.push(gen_clone_call(target));
121 let struct_name = make::expr_path(make::ext::ident_path("Self"));
122 make::expr_call(struct_name, make::arg_list(fields))
126 let struct_name = make::ext::ident_path("Self");
127 let fields = make::record_expr_field_list(None);
128 make::record_expr(struct_name, fields).into()
133 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
134 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
138 /// Generate a `Debug` impl based on the fields and members of the target type.
139 fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
140 let annotated_name = adt.name()?;
142 // `Debug` cannot be derived for unions, so no default impl can be provided.
143 ast::Adt::Union(_) => None,
145 // => match self { Self::Variant => write!(f, "Variant") }
146 ast::Adt::Enum(enum_) => {
147 let list = enum_.variant_list()?;
148 let mut arms = vec![];
149 for variant in list.variants() {
150 let name = variant.name()?;
151 let variant_name = make::path_pat(make::path_from_text(&format!("Self::{}", name)));
153 let target = make::expr_path(make::ext::ident_path("f").into());
154 let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into();
155 let args = make::arg_list(vec![target, fmt_string]);
156 let macro_name = make::expr_path(make::ext::ident_path("write"));
157 let macro_call = make::expr_macro_call(macro_name, args);
159 arms.push(make::match_arm(Some(variant_name.into()), None, macro_call.into()));
162 let match_target = make::expr_path(make::ext::ident_path("self"));
163 let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
164 let match_expr = make::expr_match(match_target, list);
166 let body = make::block_expr(None, Some(match_expr));
167 let body = body.indent(ast::edit::IndentLevel(1));
168 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
172 ast::Adt::Struct(strukt) => {
173 let name = format!("\"{}\"", annotated_name);
174 let args = make::arg_list(Some(make::expr_literal(&name).into()));
175 let target = make::expr_path(make::ext::ident_path("f"));
177 let expr = match strukt.field_list() {
178 // => f.debug_struct("Name").finish()
179 None => make::expr_method_call(target, make::name_ref("debug_struct"), args),
181 // => f.debug_struct("Name").field("foo", &self.foo).finish()
182 Some(ast::FieldList::RecordFieldList(field_list)) => {
183 let method = make::name_ref("debug_struct");
184 let mut expr = make::expr_method_call(target, method, args);
185 for field in field_list.fields() {
186 let name = field.name()?;
187 let f_name = make::expr_literal(&(format!("\"{}\"", name))).into();
188 let f_path = make::expr_path(make::ext::ident_path("self"));
189 let f_path = make::expr_ref(f_path, false);
190 let f_path = make::expr_field(f_path, &format!("{}", name)).into();
191 let args = make::arg_list(vec![f_name, f_path]);
192 expr = make::expr_method_call(expr, make::name_ref("field"), args);
197 // => f.debug_tuple("Name").field(self.0).finish()
198 Some(ast::FieldList::TupleFieldList(field_list)) => {
199 let method = make::name_ref("debug_tuple");
200 let mut expr = make::expr_method_call(target, method, args);
201 for (i, _) in field_list.fields().enumerate() {
202 let f_path = make::expr_path(make::ext::ident_path("self"));
203 let f_path = make::expr_ref(f_path, false);
204 let f_path = make::expr_field(f_path, &format!("{}", i)).into();
205 let method = make::name_ref("field");
206 expr = make::expr_method_call(expr, method, make::arg_list(Some(f_path)));
212 let method = make::name_ref("finish");
213 let expr = make::expr_method_call(expr, method, make::arg_list(None));
214 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
215 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
221 /// Generate a `Debug` impl based on the fields and members of the target type.
222 fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
223 fn gen_default_call() -> ast::Expr {
224 let fn_name = make::path_from_text(&"Default::default");
225 make::expr_call(make::expr_path(fn_name), make::arg_list(None))
228 // `Debug` cannot be derived for unions, so no default impl can be provided.
229 ast::Adt::Union(_) => None,
230 // Deriving `Debug` for enums is not stable yet.
231 ast::Adt::Enum(_) => None,
232 ast::Adt::Struct(strukt) => {
233 let expr = match strukt.field_list() {
234 Some(ast::FieldList::RecordFieldList(field_list)) => {
235 let mut fields = vec![];
236 for field in field_list.fields() {
237 let method_call = gen_default_call();
238 let name_ref = make::name_ref(&field.name()?.to_string());
239 let field = make::record_expr_field(name_ref, Some(method_call));
242 let struct_name = make::ext::ident_path("Self");
243 let fields = make::record_expr_field_list(fields);
244 make::record_expr(struct_name, fields).into()
246 Some(ast::FieldList::TupleFieldList(field_list)) => {
247 let struct_name = make::expr_path(make::ext::ident_path("Self"));
248 let fields = field_list.fields().map(|_| gen_default_call());
249 make::expr_call(struct_name, make::arg_list(fields))
252 let struct_name = make::ext::ident_path("Self");
253 let fields = make::record_expr_field_list(None);
254 make::record_expr(struct_name, fields).into()
257 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
258 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
264 /// Generate a `Hash` impl based on the fields and members of the target type.
265 fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
266 fn gen_hash_call(target: ast::Expr) -> ast::Stmt {
267 let method = make::name_ref("hash");
268 let arg = make::expr_path(make::ext::ident_path("state"));
269 let expr = make::expr_method_call(target, method, make::arg_list(Some(arg)));
270 make::expr_stmt(expr).into()
273 let body = match adt {
274 // `Hash` cannot be derived for unions, so no default impl can be provided.
275 ast::Adt::Union(_) => return None,
277 // => std::mem::discriminant(self).hash(state);
278 ast::Adt::Enum(_) => {
279 let fn_name = make_discriminant();
281 let arg = make::expr_path(make::ext::ident_path("self"));
282 let fn_call = make::expr_call(fn_name, make::arg_list(Some(arg)));
283 let stmt = gen_hash_call(fn_call);
285 make::block_expr(Some(stmt), None).indent(ast::edit::IndentLevel(1))
287 ast::Adt::Struct(strukt) => match strukt.field_list() {
288 // => self.<field>.hash(state);
289 Some(ast::FieldList::RecordFieldList(field_list)) => {
290 let mut stmts = vec![];
291 for field in field_list.fields() {
292 let base = make::expr_path(make::ext::ident_path("self"));
293 let target = make::expr_field(base, &field.name()?.to_string());
294 stmts.push(gen_hash_call(target));
296 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
299 // => self.<field_index>.hash(state);
300 Some(ast::FieldList::TupleFieldList(field_list)) => {
301 let mut stmts = vec![];
302 for (i, _) in field_list.fields().enumerate() {
303 let base = make::expr_path(make::ext::ident_path("self"));
304 let target = make::expr_field(base, &format!("{}", i));
305 stmts.push(gen_hash_call(target));
307 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
310 // No fields in the body means there's nothing to hash.
315 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
319 /// Generate a `PartialEq` impl based on the fields and members of the target type.
320 fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
321 fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
323 Some(expr) => Some(make::expr_op(ast::BinOp::BooleanAnd, expr, cmp)),
328 fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField {
329 let pat = make::ext::simple_ident_pat(make::name(&pat_name));
330 let name_ref = make::name_ref(field_name);
331 make::record_pat_field(name_ref, pat.into())
334 fn gen_record_pat(record_name: ast::Path, fields: Vec<ast::RecordPatField>) -> ast::RecordPat {
335 let list = make::record_pat_field_list(fields);
336 make::record_pat_with_fields(record_name, list)
339 fn gen_tuple_field(field_name: &String) -> ast::Pat {
340 ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name)))
343 // FIXME: return `None` if the trait carries a generic type; we can only
344 // generate this code `Self` for the time being.
346 let body = match adt {
347 // `PartialEq` cannot be derived for unions, so no default impl can be provided.
348 ast::Adt::Union(_) => return None,
350 ast::Adt::Enum(enum_) => {
351 // => std::mem::discriminant(self) == std::mem::discriminant(other)
352 let lhs_name = make::expr_path(make::ext::ident_path("self"));
353 let lhs = make::expr_call(make_discriminant(), make::arg_list(Some(lhs_name.clone())));
354 let rhs_name = make::expr_path(make::ext::ident_path("other"));
355 let rhs = make::expr_call(make_discriminant(), make::arg_list(Some(rhs_name.clone())));
356 let eq_check = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
358 let mut case_count = 0;
359 let mut arms = vec![];
360 for variant in enum_.variant_list()?.variants() {
362 match variant.field_list() {
363 // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
364 Some(ast::FieldList::RecordFieldList(list)) => {
366 let mut l_fields = vec![];
367 let mut r_fields = vec![];
369 for field in list.fields() {
370 let field_name = field.name()?.to_string();
372 let l_name = &format!("l_{}", field_name);
373 l_fields.push(gen_record_pat_field(&field_name, &l_name));
375 let r_name = &format!("r_{}", field_name);
376 r_fields.push(gen_record_pat_field(&field_name, &r_name));
378 let lhs = make::expr_path(make::ext::ident_path(l_name));
379 let rhs = make::expr_path(make::ext::ident_path(r_name));
380 let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
381 expr = gen_eq_chain(expr, cmp);
384 let left = gen_record_pat(make_variant_path(&variant)?, l_fields);
385 let right = gen_record_pat(make_variant_path(&variant)?, r_fields);
386 let tuple = make::tuple_pat(vec![left.into(), right.into()]);
388 if let Some(expr) = expr {
389 arms.push(make::match_arm(Some(tuple.into()), None, expr));
393 Some(ast::FieldList::TupleFieldList(list)) => {
395 let mut l_fields = vec![];
396 let mut r_fields = vec![];
398 for (i, _) in list.fields().enumerate() {
399 let field_name = format!("{}", i);
401 let l_name = format!("l{}", field_name);
402 l_fields.push(gen_tuple_field(&l_name));
404 let r_name = format!("r{}", field_name);
405 r_fields.push(gen_tuple_field(&r_name));
407 let lhs = make::expr_path(make::ext::ident_path(&l_name));
408 let rhs = make::expr_path(make::ext::ident_path(&r_name));
409 let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
410 expr = gen_eq_chain(expr, cmp);
413 let left = make::tuple_struct_pat(make_variant_path(&variant)?, l_fields);
414 let right = make::tuple_struct_pat(make_variant_path(&variant)?, r_fields);
415 let tuple = make::tuple_pat(vec![left.into(), right.into()]);
417 if let Some(expr) = expr {
418 arms.push(make::match_arm(Some(tuple.into()), None, expr));
425 let expr = match arms.len() {
428 if case_count > arms.len() {
429 let lhs = make::wildcard_pat().into();
430 arms.push(make::match_arm(Some(lhs), None, eq_check));
433 let match_target = make::expr_tuple(vec![lhs_name, rhs_name]);
434 let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
435 make::expr_match(match_target, list)
439 make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
441 ast::Adt::Struct(strukt) => match strukt.field_list() {
442 Some(ast::FieldList::RecordFieldList(field_list)) => {
444 for field in field_list.fields() {
445 let lhs = make::expr_path(make::ext::ident_path("self"));
446 let lhs = make::expr_field(lhs, &field.name()?.to_string());
447 let rhs = make::expr_path(make::ext::ident_path("other"));
448 let rhs = make::expr_field(rhs, &field.name()?.to_string());
449 let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
450 expr = gen_eq_chain(expr, cmp);
452 make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
455 Some(ast::FieldList::TupleFieldList(field_list)) => {
457 for (i, _) in field_list.fields().enumerate() {
458 let idx = format!("{}", i);
459 let lhs = make::expr_path(make::ext::ident_path("self"));
460 let lhs = make::expr_field(lhs, &idx);
461 let rhs = make::expr_path(make::ext::ident_path("other"));
462 let rhs = make::expr_field(rhs, &idx);
463 let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
464 expr = gen_eq_chain(expr, cmp);
466 make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
469 // No fields in the body means there's nothing to compare.
471 let expr = make::expr_literal("true").into();
472 make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
477 ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
481 fn make_discriminant() -> ast::Expr {
482 make::expr_path(make::path_from_text("core::mem::discriminant"))
485 fn make_variant_path(variant: &ast::Variant) -> Option<ast::Path> {
486 Some(make::path_from_text(&format!("Self::{}", &variant.name()?)))