]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
impl Debug def from trait
[rust.git] / crates / ide_assists / src / handlers / replace_derive_with_manual_impl.rs
1 use hir::ModuleDef;
2 use ide_db::helpers::{import_assets::NameToImport, mod_path_to_ast};
3 use ide_db::items_locator;
4 use itertools::Itertools;
5 use syntax::ast::edit::AstNodeEdit;
6 use syntax::ast::Pat;
7 use syntax::ted;
8 use syntax::{
9     ast::{self, make, AstNode, NameOwner},
10     SyntaxKind::{IDENT, WHITESPACE},
11 };
12
13 use crate::{
14     assist_context::{AssistBuilder, AssistContext, Assists},
15     utils::{
16         add_trait_assoc_items_to_impl, filter_assoc_items, generate_trait_impl_text,
17         render_snippet, Cursor, DefaultMethods,
18     },
19     AssistId, AssistKind,
20 };
21
22 // Assist: replace_derive_with_manual_impl
23 //
24 // Converts a `derive` impl into a manual one.
25 //
26 // ```
27 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
28 // #[derive(Deb$0ug, Display)]
29 // struct S;
30 // ```
31 // ->
32 // ```
33 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
34 // #[derive(Display)]
35 // struct S;
36 //
37 // impl Debug for S {
38 //     fn fmt(&self, f: &mut Formatter) -> Result<()> {
39 //         ${0:todo!()}
40 //     }
41 // }
42 // ```
43 pub(crate) fn replace_derive_with_manual_impl(
44     acc: &mut Assists,
45     ctx: &AssistContext,
46 ) -> Option<()> {
47     let attr = ctx.find_node_at_offset::<ast::Attr>()?;
48     let (name, args) = attr.as_simple_call()?;
49     if name != "derive" {
50         return None;
51     }
52
53     if !args.syntax().text_range().contains(ctx.offset()) {
54         cov_mark::hit!(outside_of_attr_args);
55         return None;
56     }
57
58     let trait_token = args.syntax().token_at_offset(ctx.offset()).find(|t| t.kind() == IDENT)?;
59     let trait_name = trait_token.text();
60
61     let adt = attr.syntax().parent().and_then(ast::Adt::cast)?;
62
63     let current_module = ctx.sema.scope(adt.syntax()).module()?;
64     let current_crate = current_module.krate();
65
66     let found_traits = items_locator::items_with_name(
67         &ctx.sema,
68         current_crate,
69         NameToImport::Exact(trait_name.to_string()),
70         items_locator::AssocItemSearch::Exclude,
71         Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT.inner()),
72     )
73     .filter_map(|item| match item.as_module_def()? {
74         ModuleDef::Trait(trait_) => Some(trait_),
75         _ => None,
76     })
77     .flat_map(|trait_| {
78         current_module
79             .find_use_path(ctx.sema.db, hir::ModuleDef::Trait(trait_))
80             .as_ref()
81             .map(mod_path_to_ast)
82             .zip(Some(trait_))
83     });
84
85     let mut no_traits_found = true;
86     for (trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
87         add_assist(acc, ctx, &attr, &args, &trait_path, Some(trait_), &adt)?;
88     }
89     if no_traits_found {
90         let trait_path = make::ext::ident_path(trait_name);
91         add_assist(acc, ctx, &attr, &args, &trait_path, None, &adt)?;
92     }
93     Some(())
94 }
95
96 fn add_assist(
97     acc: &mut Assists,
98     ctx: &AssistContext,
99     attr: &ast::Attr,
100     input: &ast::TokenTree,
101     trait_path: &ast::Path,
102     trait_: Option<hir::Trait>,
103     adt: &ast::Adt,
104 ) -> Option<()> {
105     let target = attr.syntax().text_range();
106     let annotated_name = adt.name()?;
107     let label = format!("Convert to manual `impl {} for {}`", trait_path, annotated_name);
108     let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?;
109
110     acc.add(
111         AssistId("replace_derive_with_manual_impl", AssistKind::Refactor),
112         label,
113         target,
114         |builder| {
115             let insert_pos = adt.syntax().text_range().end();
116             let impl_def_with_items =
117                 impl_def_from_trait(&ctx.sema, &annotated_name, trait_, trait_path);
118             update_attribute(builder, input, &trait_name, attr);
119             let trait_path = format!("{}", trait_path);
120             match (ctx.config.snippet_cap, impl_def_with_items) {
121                 (None, _) => {
122                     builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, ""))
123                 }
124                 (Some(cap), None) => builder.insert_snippet(
125                     cap,
126                     insert_pos,
127                     generate_trait_impl_text(adt, &trait_path, "    $0"),
128                 ),
129                 (Some(cap), Some((impl_def, first_assoc_item))) => {
130                     let mut cursor = Cursor::Before(first_assoc_item.syntax());
131                     let placeholder;
132                     if let ast::AssocItem::Fn(ref func) = first_assoc_item {
133                         // need to know what kind of derive this is: if it's Derive Debug, special case it.
134                         // the name of the struct
135                         // list of fields of the struct
136                         if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast)
137                         {
138                             if m.syntax().text() == "todo!()" {
139                                 placeholder = m;
140                                 cursor = Cursor::Replace(placeholder.syntax());
141                             }
142                         }
143                     }
144
145                     builder.insert_snippet(
146                         cap,
147                         insert_pos,
148                         format!("\n\n{}", render_snippet(cap, impl_def.syntax(), cursor)),
149                     )
150                 }
151             };
152         },
153     )
154 }
155
156 fn impl_def_from_trait(
157     sema: &hir::Semantics<ide_db::RootDatabase>,
158     annotated_name: &ast::Name,
159     trait_: Option<hir::Trait>,
160     trait_path: &ast::Path,
161 ) -> Option<(ast::Impl, ast::AssocItem)> {
162     let trait_ = trait_?;
163     let target_scope = sema.scope(annotated_name.syntax());
164     let trait_items = filter_assoc_items(sema.db, &trait_.items(sema.db), DefaultMethods::No);
165     if trait_items.is_empty() {
166         return None;
167     }
168     let impl_def =
169         make::impl_trait(trait_path.clone(), make::ext::ident_path(&annotated_name.text()));
170     let (impl_def, first_assoc_item) =
171         add_trait_assoc_items_to_impl(sema, trait_items, trait_, impl_def, target_scope);
172     if let ast::AssocItem::Fn(fn_) = &first_assoc_item {
173         if trait_path.segment().unwrap().name_ref().unwrap().text() == "Debug" {
174             let f_expr = make::expr_path(make::ext::ident_path("f"));
175             let args = make::arg_list(Some(make::expr_path(make::ext::ident_path(
176                 annotated_name.text().as_str(),
177             ))));
178             let body =
179                 make::block_expr(None, Some(make::expr_method_call(f_expr, "debug_struct", args)))
180                     .indent(ast::edit::IndentLevel(1));
181
182             ted::replace(
183                 fn_.body().unwrap().tail_expr().unwrap().syntax(),
184                 body.clone_for_update().syntax(),
185             );
186         }
187     }
188     Some((impl_def, first_assoc_item))
189 }
190
191 fn update_attribute(
192     builder: &mut AssistBuilder,
193     input: &ast::TokenTree,
194     trait_name: &ast::NameRef,
195     attr: &ast::Attr,
196 ) {
197     let trait_name = trait_name.text();
198     let new_attr_input = input
199         .syntax()
200         .descendants_with_tokens()
201         .filter(|t| t.kind() == IDENT)
202         .filter_map(|t| t.into_token().map(|t| t.text().to_string()))
203         .filter(|t| t != &trait_name)
204         .collect::<Vec<_>>();
205     let has_more_derives = !new_attr_input.is_empty();
206
207     if has_more_derives {
208         let new_attr_input = format!("({})", new_attr_input.iter().format(", "));
209         builder.replace(input.syntax().text_range(), new_attr_input);
210     } else {
211         let attr_range = attr.syntax().text_range();
212         builder.delete(attr_range);
213
214         if let Some(line_break_range) = attr
215             .syntax()
216             .next_sibling_or_token()
217             .filter(|t| t.kind() == WHITESPACE)
218             .map(|t| t.text_range())
219         {
220             builder.delete(line_break_range);
221         }
222     }
223 }
224
225 #[cfg(test)]
226 mod tests {
227     use crate::tests::{check_assist, check_assist_not_applicable};
228
229     use super::*;
230
231     #[test]
232     fn add_custom_impl_debug() {
233         check_assist(
234             replace_derive_with_manual_impl,
235             r#"
236 mod fmt {
237     pub struct Error;
238     pub type Result = Result<(), Error>;
239     pub struct Formatter<'a>;
240     pub trait Debug {
241         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
242     }
243 }
244
245 #[derive(Debu$0g)]
246 struct Foo {
247     bar: String,
248 }
249 "#,
250             r#"
251 mod fmt {
252     pub struct Error;
253     pub type Result = Result<(), Error>;
254     pub struct Formatter<'a>;
255     pub trait Debug {
256         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
257     }
258 }
259
260 struct Foo {
261     bar: String,
262 }
263
264 impl fmt::Debug for Foo {
265     $0fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266         f.debug_struct("Foo").field("bar", &self.bar).finish()
267     }
268 }
269 "#,
270         )
271     }
272     #[test]
273     fn add_custom_impl_all() {
274         check_assist(
275             replace_derive_with_manual_impl,
276             r#"
277 mod foo {
278     pub trait Bar {
279         type Qux;
280         const Baz: usize = 42;
281         const Fez: usize;
282         fn foo();
283         fn bar() {}
284     }
285 }
286
287 #[derive($0Bar)]
288 struct Foo {
289     bar: String,
290 }
291 "#,
292             r#"
293 mod foo {
294     pub trait Bar {
295         type Qux;
296         const Baz: usize = 42;
297         const Fez: usize;
298         fn foo();
299         fn bar() {}
300     }
301 }
302
303 struct Foo {
304     bar: String,
305 }
306
307 impl foo::Bar for Foo {
308     $0type Qux;
309
310     const Baz: usize = 42;
311
312     const Fez: usize;
313
314     fn foo() {
315         todo!()
316     }
317 }
318 "#,
319         )
320     }
321     #[test]
322     fn add_custom_impl_for_unique_input() {
323         check_assist(
324             replace_derive_with_manual_impl,
325             r#"
326 #[derive(Debu$0g)]
327 struct Foo {
328     bar: String,
329 }
330             "#,
331             r#"
332 struct Foo {
333     bar: String,
334 }
335
336 impl Debug for Foo {
337     $0
338 }
339             "#,
340         )
341     }
342
343     #[test]
344     fn add_custom_impl_for_with_visibility_modifier() {
345         check_assist(
346             replace_derive_with_manual_impl,
347             r#"
348 #[derive(Debug$0)]
349 pub struct Foo {
350     bar: String,
351 }
352             "#,
353             r#"
354 pub struct Foo {
355     bar: String,
356 }
357
358 impl Debug for Foo {
359     $0
360 }
361             "#,
362         )
363     }
364
365     #[test]
366     fn add_custom_impl_when_multiple_inputs() {
367         check_assist(
368             replace_derive_with_manual_impl,
369             r#"
370 #[derive(Display, Debug$0, Serialize)]
371 struct Foo {}
372             "#,
373             r#"
374 #[derive(Display, Serialize)]
375 struct Foo {}
376
377 impl Debug for Foo {
378     $0
379 }
380             "#,
381         )
382     }
383
384     #[test]
385     fn test_ignore_derive_macro_without_input() {
386         check_assist_not_applicable(
387             replace_derive_with_manual_impl,
388             r#"
389 #[derive($0)]
390 struct Foo {}
391             "#,
392         )
393     }
394
395     #[test]
396     fn test_ignore_if_cursor_on_param() {
397         check_assist_not_applicable(
398             replace_derive_with_manual_impl,
399             r#"
400 #[derive$0(Debug)]
401 struct Foo {}
402             "#,
403         );
404
405         check_assist_not_applicable(
406             replace_derive_with_manual_impl,
407             r#"
408 #[derive(Debug)$0]
409 struct Foo {}
410             "#,
411         )
412     }
413
414     #[test]
415     fn test_ignore_if_not_derive() {
416         check_assist_not_applicable(
417             replace_derive_with_manual_impl,
418             r#"
419 #[allow(non_camel_$0case_types)]
420 struct Foo {}
421             "#,
422         )
423     }
424
425     #[test]
426     fn works_at_start_of_file() {
427         cov_mark::check!(outside_of_attr_args);
428         check_assist_not_applicable(
429             replace_derive_with_manual_impl,
430             r#"
431 $0#[derive(Debug)]
432 struct S;
433             "#,
434         );
435     }
436 }