]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/utils/gen_trait_fn_body.rs
Merge #9835
[rust.git] / crates / ide_assists / src / utils / gen_trait_fn_body.rs
1 //! This module contains functions to generate default trait impl function bodies where possible.
2
3 use syntax::{
4     ast::{self, edit::AstNodeEdit, make, AstNode, NameOwner},
5     ted,
6 };
7
8 /// Generate custom trait bodies where possible.
9 ///
10 /// Returns `Option` so that we can use `?` rather than `if let Some`. Returning
11 /// `None` means that generating a custom trait body failed, and the body will remain
12 /// as `todo!` instead.
13 pub(crate) fn gen_trait_fn_body(
14     func: &ast::Fn,
15     trait_path: &ast::Path,
16     adt: &ast::Adt,
17 ) -> Option<()> {
18     match trait_path.segment()?.name_ref()?.text().as_str() {
19         "Debug" => gen_debug_impl(adt, func),
20         "Default" => gen_default_impl(adt, func),
21         "Hash" => gen_hash_impl(adt, func),
22         _ => None,
23     }
24 }
25
26 /// Generate a `Debug` impl based on the fields and members of the target type.
27 fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
28     let annotated_name = adt.name()?;
29     match adt {
30         // `Debug` cannot be derived for unions, so no default impl can be provided.
31         ast::Adt::Union(_) => None,
32
33         // => match self { Self::Variant => write!(f, "Variant") }
34         ast::Adt::Enum(enum_) => {
35             let list = enum_.variant_list()?;
36             let mut arms = vec![];
37             for variant in list.variants() {
38                 let name = variant.name()?;
39                 let left = make::ext::ident_path("Self");
40                 let right = make::ext::ident_path(&format!("{}", name));
41                 let variant_name = make::path_pat(make::path_concat(left, right));
42
43                 let target = make::expr_path(make::ext::ident_path("f").into());
44                 let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into();
45                 let args = make::arg_list(vec![target, fmt_string]);
46                 let macro_name = make::expr_path(make::ext::ident_path("write"));
47                 let macro_call = make::expr_macro_call(macro_name, args);
48
49                 arms.push(make::match_arm(Some(variant_name.into()), None, macro_call.into()));
50             }
51
52             let match_target = make::expr_path(make::ext::ident_path("self"));
53             let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
54             let match_expr = make::expr_match(match_target, list);
55
56             let body = make::block_expr(None, Some(match_expr));
57             let body = body.indent(ast::edit::IndentLevel(1));
58             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
59             Some(())
60         }
61
62         ast::Adt::Struct(strukt) => {
63             let name = format!("\"{}\"", annotated_name);
64             let args = make::arg_list(Some(make::expr_literal(&name).into()));
65             let target = make::expr_path(make::ext::ident_path("f"));
66
67             let expr = match strukt.field_list() {
68                 // => f.debug_struct("Name").finish()
69                 None => make::expr_method_call(target, make::name_ref("debug_struct"), args),
70
71                 // => f.debug_struct("Name").field("foo", &self.foo).finish()
72                 Some(ast::FieldList::RecordFieldList(field_list)) => {
73                     let method = make::name_ref("debug_struct");
74                     let mut expr = make::expr_method_call(target, method, args);
75                     for field in field_list.fields() {
76                         let name = field.name()?;
77                         let f_name = make::expr_literal(&(format!("\"{}\"", name))).into();
78                         let f_path = make::expr_path(make::ext::ident_path("self"));
79                         let f_path = make::expr_ref(f_path, false);
80                         let f_path = make::expr_field(f_path, &format!("{}", name)).into();
81                         let args = make::arg_list(vec![f_name, f_path]);
82                         expr = make::expr_method_call(expr, make::name_ref("field"), args);
83                     }
84                     expr
85                 }
86
87                 // => f.debug_tuple("Name").field(self.0).finish()
88                 Some(ast::FieldList::TupleFieldList(field_list)) => {
89                     let method = make::name_ref("debug_tuple");
90                     let mut expr = make::expr_method_call(target, method, args);
91                     for (idx, _) in field_list.fields().enumerate() {
92                         let f_path = make::expr_path(make::ext::ident_path("self"));
93                         let f_path = make::expr_ref(f_path, false);
94                         let f_path = make::expr_field(f_path, &format!("{}", idx)).into();
95                         let method = make::name_ref("field");
96                         expr = make::expr_method_call(expr, method, make::arg_list(Some(f_path)));
97                     }
98                     expr
99                 }
100             };
101
102             let method = make::name_ref("finish");
103             let expr = make::expr_method_call(expr, method, make::arg_list(None));
104             let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
105             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
106             Some(())
107         }
108     }
109 }
110
111 /// Generate a `Debug` impl based on the fields and members of the target type.
112 fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
113     fn gen_default_call() -> ast::Expr {
114         let trait_name = make::ext::ident_path("Default");
115         let method_name = make::ext::ident_path("default");
116         let fn_name = make::expr_path(make::path_concat(trait_name, method_name));
117         make::expr_call(fn_name, make::arg_list(None))
118     }
119     match adt {
120         // `Debug` cannot be derived for unions, so no default impl can be provided.
121         ast::Adt::Union(_) => None,
122         // Deriving `Debug` for enums is not stable yet.
123         ast::Adt::Enum(_) => None,
124         ast::Adt::Struct(strukt) => {
125             let expr = match strukt.field_list() {
126                 Some(ast::FieldList::RecordFieldList(field_list)) => {
127                     let mut fields = vec![];
128                     for field in field_list.fields() {
129                         let method_call = gen_default_call();
130                         let name_ref = make::name_ref(&field.name()?.to_string());
131                         let field = make::record_expr_field(name_ref, Some(method_call));
132                         fields.push(field);
133                     }
134                     let struct_name = make::ext::ident_path("Self");
135                     let fields = make::record_expr_field_list(fields);
136                     make::record_expr(struct_name, fields).into()
137                 }
138                 Some(ast::FieldList::TupleFieldList(field_list)) => {
139                     let struct_name = make::expr_path(make::ext::ident_path("Self"));
140                     let fields = field_list.fields().map(|_| gen_default_call());
141                     make::expr_call(struct_name, make::arg_list(fields))
142                 }
143                 None => {
144                     let struct_name = make::ext::ident_path("Self");
145                     let fields = make::record_expr_field_list(None);
146                     make::record_expr(struct_name, fields).into()
147                 }
148             };
149             let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
150             ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
151             Some(())
152         }
153     }
154 }
155
156 /// Generate a `Hash` impl based on the fields and members of the target type.
157 fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
158     fn gen_hash_call(target: ast::Expr) -> ast::Stmt {
159         let method = make::name_ref("hash");
160         let arg = make::expr_path(make::ext::ident_path("state"));
161         let expr = make::expr_method_call(target, method, make::arg_list(Some(arg)));
162         let stmt = make::expr_stmt(expr);
163         stmt.into()
164     }
165
166     let body = match adt {
167         // `Hash` cannot be derived for unions, so no default impl can be provided.
168         ast::Adt::Union(_) => return None,
169
170         // => std::mem::discriminant(self).hash(state);
171         ast::Adt::Enum(_) => {
172             let root = make::ext::ident_path("core");
173             let submodule = make::ext::ident_path("mem");
174             let fn_name = make::ext::ident_path("discriminant");
175             let fn_name = make::path_concat(submodule, fn_name);
176             let fn_name = make::expr_path(make::path_concat(root, fn_name));
177
178             let arg = make::expr_path(make::ext::ident_path("self"));
179             let fn_call = make::expr_call(fn_name, make::arg_list(Some(arg)));
180             let stmt = gen_hash_call(fn_call);
181
182             make::block_expr(Some(stmt), None).indent(ast::edit::IndentLevel(1))
183         }
184         ast::Adt::Struct(strukt) => match strukt.field_list() {
185             // => self.<field>.hash(state);*
186             Some(ast::FieldList::RecordFieldList(field_list)) => {
187                 let mut stmts = vec![];
188                 for field in field_list.fields() {
189                     let base = make::expr_path(make::ext::ident_path("self"));
190                     let target = make::expr_field(base, &field.name()?.to_string());
191                     stmts.push(gen_hash_call(target));
192                 }
193                 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
194             }
195
196             // => self.<field_index>.hash(state);*
197             Some(ast::FieldList::TupleFieldList(field_list)) => {
198                 let mut stmts = vec![];
199                 for (i, _) in field_list.fields().enumerate() {
200                     let base = make::expr_path(make::ext::ident_path("self"));
201                     let target = make::expr_field(base, &format!("{}", i));
202                     stmts.push(gen_hash_call(target));
203                 }
204                 make::block_expr(stmts, None).indent(ast::edit::IndentLevel(1))
205             }
206
207             // No fields in the body means there's nothing to hash.
208             None => return None,
209         },
210     };
211
212     ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
213     Some(())
214 }