]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
Improve naming and add comments
[rust.git] / crates / ide_assists / src / handlers / replace_derive_with_manual_impl.rs
1 use hir::ModuleDef;
2 use ide_db::helpers::{import_assets::NameToImport, mod_path_to_ast};
3 use ide_db::items_locator;
4 use itertools::Itertools;
5 use syntax::ast::edit::AstNodeEdit;
6 use syntax::ted;
7 use syntax::{
8     ast::{self, make, AstNode, NameOwner},
9     SyntaxKind::{IDENT, WHITESPACE},
10 };
11
12 use crate::{
13     assist_context::{AssistBuilder, AssistContext, Assists},
14     utils::{
15         add_trait_assoc_items_to_impl, filter_assoc_items, generate_trait_impl_text,
16         render_snippet, Cursor, DefaultMethods,
17     },
18     AssistId, AssistKind,
19 };
20
21 // Assist: replace_derive_with_manual_impl
22 //
23 // Converts a `derive` impl into a manual one.
24 //
25 // ```
26 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
27 // #[derive(Deb$0ug, Display)]
28 // struct S;
29 // ```
30 // ->
31 // ```
32 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
33 // #[derive(Display)]
34 // struct S;
35 //
36 // impl Debug for S {
37 //     $0fn fmt(&self, f: &mut Formatter) -> Result<()> {
38 //         f.debug_struct("S").finish()
39 //     }
40 // }
41 // ```
42 pub(crate) fn replace_derive_with_manual_impl(
43     acc: &mut Assists,
44     ctx: &AssistContext,
45 ) -> Option<()> {
46     let attr = ctx.find_node_at_offset::<ast::Attr>()?;
47     let (name, args) = attr.as_simple_call()?;
48     if name != "derive" {
49         return None;
50     }
51
52     if !args.syntax().text_range().contains(ctx.offset()) {
53         cov_mark::hit!(outside_of_attr_args);
54         return None;
55     }
56
57     let trait_token = args.syntax().token_at_offset(ctx.offset()).find(|t| t.kind() == IDENT)?;
58     let trait_name = trait_token.text();
59
60     let adt = attr.syntax().parent().and_then(ast::Adt::cast)?;
61
62     let current_module = ctx.sema.scope(adt.syntax()).module()?;
63     let current_crate = current_module.krate();
64
65     let found_traits = items_locator::items_with_name(
66         &ctx.sema,
67         current_crate,
68         NameToImport::Exact(trait_name.to_string()),
69         items_locator::AssocItemSearch::Exclude,
70         Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT.inner()),
71     )
72     .filter_map(|item| match item.as_module_def()? {
73         ModuleDef::Trait(trait_) => Some(trait_),
74         _ => None,
75     })
76     .flat_map(|trait_| {
77         current_module
78             .find_use_path(ctx.sema.db, hir::ModuleDef::Trait(trait_))
79             .as_ref()
80             .map(mod_path_to_ast)
81             .zip(Some(trait_))
82     });
83
84     let mut no_traits_found = true;
85     for (trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
86         add_assist(acc, ctx, &attr, &args, &trait_path, Some(trait_), &adt)?;
87     }
88     if no_traits_found {
89         let trait_path = make::ext::ident_path(trait_name);
90         add_assist(acc, ctx, &attr, &args, &trait_path, None, &adt)?;
91     }
92     Some(())
93 }
94
95 fn add_assist(
96     acc: &mut Assists,
97     ctx: &AssistContext,
98     attr: &ast::Attr,
99     input: &ast::TokenTree,
100     trait_path: &ast::Path,
101     trait_: Option<hir::Trait>,
102     adt: &ast::Adt,
103 ) -> Option<()> {
104     let target = attr.syntax().text_range();
105     let annotated_name = adt.name()?;
106     let label = format!("Convert to manual `impl {} for {}`", trait_path, annotated_name);
107     let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?;
108
109     acc.add(
110         AssistId("replace_derive_with_manual_impl", AssistKind::Refactor),
111         label,
112         target,
113         |builder| {
114             let insert_pos = adt.syntax().text_range().end();
115             let impl_def_with_items =
116                 impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, trait_path);
117             update_attribute(builder, input, &trait_name, attr);
118             let trait_path = format!("{}", trait_path);
119             match (ctx.config.snippet_cap, impl_def_with_items) {
120                 (None, _) => {
121                     builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, ""))
122                 }
123                 (Some(cap), None) => builder.insert_snippet(
124                     cap,
125                     insert_pos,
126                     generate_trait_impl_text(adt, &trait_path, "    $0"),
127                 ),
128                 (Some(cap), Some((impl_def, first_assoc_item))) => {
129                     let mut cursor = Cursor::Before(first_assoc_item.syntax());
130                     let placeholder;
131                     if let ast::AssocItem::Fn(ref func) = first_assoc_item {
132                         if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast)
133                         {
134                             if m.syntax().text() == "todo!()" {
135                                 placeholder = m;
136                                 cursor = Cursor::Replace(placeholder.syntax());
137                             }
138                         }
139                     }
140
141                     builder.insert_snippet(
142                         cap,
143                         insert_pos,
144                         format!("\n\n{}", render_snippet(cap, impl_def.syntax(), cursor)),
145                     )
146                 }
147             };
148         },
149     )
150 }
151
152 fn impl_def_from_trait(
153     sema: &hir::Semantics<ide_db::RootDatabase>,
154     adt: &ast::Adt,
155     annotated_name: &ast::Name,
156     trait_: Option<hir::Trait>,
157     trait_path: &ast::Path,
158 ) -> Option<(ast::Impl, ast::AssocItem)> {
159     let trait_ = trait_?;
160     let target_scope = sema.scope(annotated_name.syntax());
161     let trait_items = filter_assoc_items(sema.db, &trait_.items(sema.db), DefaultMethods::No);
162     if trait_items.is_empty() {
163         return None;
164     }
165     let impl_def =
166         make::impl_trait(trait_path.clone(), make::ext::ident_path(&annotated_name.text()));
167     let (impl_def, first_assoc_item) =
168         add_trait_assoc_items_to_impl(sema, trait_items, trait_, impl_def, target_scope);
169
170     // Generate a default `impl` function body for the derived trait.
171     if let ast::AssocItem::Fn(func) = &first_assoc_item {
172         match trait_path.segment().unwrap().name_ref().unwrap().text().as_str() {
173             "Debug" => gen_debug_impl(adt, func, annotated_name),
174             _ => {} // => If we don't know about the trait, the function body is left as `todo!`.
175         };
176     }
177     Some((impl_def, first_assoc_item))
178 }
179
180 /// Generate a `Debug` impl based on the fields and members of the target type.
181 fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn, annotated_name: &ast::Name) {
182     match adt {
183         ast::Adt::Union(_) => {} // `Debug` cannot be derived for unions, so no default impl can be provided.
184         ast::Adt::Enum(enum_) => {
185             // => match self { Self::Variant => write!(f, "Variant") }
186             if let Some(list) = enum_.variant_list() {
187                 let mut arms = vec![];
188                 for variant in list.variants() {
189                     let name = variant.name().unwrap();
190
191                     let left = make::ext::ident_path("Self");
192                     let right = make::ext::ident_path(&format!("{}", name));
193                     let variant_name = make::path_pat(make::path_concat(left, right));
194
195                     let target = make::expr_path(make::ext::ident_path("f").into());
196                     let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into();
197                     let args = make::arg_list(vec![target, fmt_string]);
198                     let macro_name = make::expr_path(make::ext::ident_path("write"));
199                     let macro_call = make::expr_macro_call(macro_name, args);
200
201                     arms.push(make::match_arm(Some(variant_name.into()), None, macro_call.into()));
202                 }
203
204                 let match_target = make::expr_path(make::ext::ident_path("self"));
205                 let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
206                 let match_expr = make::expr_match(match_target, list);
207
208                 let body = make::block_expr(None, Some(match_expr));
209                 let body = body.indent(ast::edit::IndentLevel(1));
210                 ted::replace(func.body().unwrap().syntax(), body.clone_for_update().syntax());
211             }
212         }
213         ast::Adt::Struct(strukt) => {
214             let name = format!("\"{}\"", annotated_name);
215             let args = make::arg_list(Some(make::expr_literal(&name).into()));
216             let target = make::expr_path(make::ext::ident_path("f"));
217
218             let expr = match strukt.field_list() {
219                 None => {
220                     // => f.debug_struct("Name").finish()
221                     make::expr_method_call(target, "debug_struct", args)
222                 }
223                 Some(ast::FieldList::RecordFieldList(field_list)) => {
224                     // => f.debug_struct("Name").field("foo", &self.foo).finish()
225                     let mut expr = make::expr_method_call(target, "debug_struct", args);
226                     for field in field_list.fields() {
227                         if let Some(name) = field.name() {
228                             let f_name = make::expr_literal(&(format!("\"{}\"", name))).into();
229                             let f_path = make::expr_path(make::ext::ident_path("self"));
230                             let f_path = make::expr_ref(f_path, false);
231                             let f_path = make::expr_field(f_path, &format!("{}", name)).into();
232                             let args = make::arg_list(vec![f_name, f_path]);
233                             expr = make::expr_method_call(expr, "field", args);
234                         }
235                     }
236                     expr
237                 }
238                 Some(ast::FieldList::TupleFieldList(field_list)) => {
239                     // => f.debug_tuple("Name").field(self.0).finish()
240                     let mut expr = make::expr_method_call(target, "debug_tuple", args);
241                     for (idx, _) in field_list.fields().enumerate() {
242                         let f_path = make::expr_path(make::ext::ident_path("self"));
243                         let f_path = make::expr_ref(f_path, false);
244                         let f_path = make::expr_field(f_path, &format!("{}", idx)).into();
245                         expr = make::expr_method_call(expr, "field", make::arg_list(Some(f_path)));
246                     }
247                     expr
248                 }
249             };
250
251             let expr = make::expr_method_call(expr, "finish", make::arg_list(None));
252             let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
253             ted::replace(func.body().unwrap().syntax(), body.clone_for_update().syntax());
254         }
255     }
256 }
257
258 fn update_attribute(
259     builder: &mut AssistBuilder,
260     input: &ast::TokenTree,
261     trait_name: &ast::NameRef,
262     attr: &ast::Attr,
263 ) {
264     let trait_name = trait_name.text();
265     let new_attr_input = input
266         .syntax()
267         .descendants_with_tokens()
268         .filter(|t| t.kind() == IDENT)
269         .filter_map(|t| t.into_token().map(|t| t.text().to_string()))
270         .filter(|t| t != &trait_name)
271         .collect::<Vec<_>>();
272     let has_more_derives = !new_attr_input.is_empty();
273
274     if has_more_derives {
275         let new_attr_input = format!("({})", new_attr_input.iter().format(", "));
276         builder.replace(input.syntax().text_range(), new_attr_input);
277     } else {
278         let attr_range = attr.syntax().text_range();
279         builder.delete(attr_range);
280
281         if let Some(line_break_range) = attr
282             .syntax()
283             .next_sibling_or_token()
284             .filter(|t| t.kind() == WHITESPACE)
285             .map(|t| t.text_range())
286         {
287             builder.delete(line_break_range);
288         }
289     }
290 }
291
292 #[cfg(test)]
293 mod tests {
294     use crate::tests::{check_assist, check_assist_not_applicable};
295
296     use super::*;
297
298     #[test]
299     fn add_custom_impl_debug_record_struct() {
300         check_assist(
301             replace_derive_with_manual_impl,
302             r#"
303 mod fmt {
304     pub struct Error;
305     pub type Result = Result<(), Error>;
306     pub struct Formatter<'a>;
307     pub trait Debug {
308         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
309     }
310 }
311
312 #[derive(Debu$0g)]
313 struct Foo {
314     bar: String,
315 }
316 "#,
317             r#"
318 mod fmt {
319     pub struct Error;
320     pub type Result = Result<(), Error>;
321     pub struct Formatter<'a>;
322     pub trait Debug {
323         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
324     }
325 }
326
327 struct Foo {
328     bar: String,
329 }
330
331 impl fmt::Debug for Foo {
332     $0fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333         f.debug_struct("Foo").field("bar", &self.bar).finish()
334     }
335 }
336 "#,
337         )
338     }
339     #[test]
340     fn add_custom_impl_debug_tuple_struct() {
341         check_assist(
342             replace_derive_with_manual_impl,
343             r#"
344 mod fmt {
345     pub struct Error;
346     pub type Result = Result<(), Error>;
347     pub struct Formatter<'a>;
348     pub trait Debug {
349         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
350     }
351 }
352
353 #[derive(Debu$0g)]
354 struct Foo(String, usize);
355 "#,
356             r#"
357 mod fmt {
358     pub struct Error;
359     pub type Result = Result<(), Error>;
360     pub struct Formatter<'a>;
361     pub trait Debug {
362         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
363     }
364 }
365
366 struct Foo(String, usize);
367
368 impl fmt::Debug for Foo {
369     $0fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370         f.debug_tuple("Foo").field(&self.0).field(&self.1).finish()
371     }
372 }
373 "#,
374         )
375     }
376     #[test]
377     fn add_custom_impl_debug_empty_struct() {
378         check_assist(
379             replace_derive_with_manual_impl,
380             r#"
381 mod fmt {
382     pub struct Error;
383     pub type Result = Result<(), Error>;
384     pub struct Formatter<'a>;
385     pub trait Debug {
386         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
387     }
388 }
389
390 #[derive(Debu$0g)]
391 struct Foo;
392 "#,
393             r#"
394 mod fmt {
395     pub struct Error;
396     pub type Result = Result<(), Error>;
397     pub struct Formatter<'a>;
398     pub trait Debug {
399         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
400     }
401 }
402
403 struct Foo;
404
405 impl fmt::Debug for Foo {
406     $0fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407         f.debug_struct("Foo").finish()
408     }
409 }
410 "#,
411         )
412     }
413     #[test]
414     fn add_custom_impl_debug_enum() {
415         check_assist(
416             replace_derive_with_manual_impl,
417             r#"
418 mod fmt {
419     pub struct Error;
420     pub type Result = Result<(), Error>;
421     pub struct Formatter<'a>;
422     pub trait Debug {
423         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
424     }
425 }
426
427 #[derive(Debu$0g)]
428 enum Foo {
429     Bar,
430     Baz,
431 }
432 "#,
433             r#"
434 mod fmt {
435     pub struct Error;
436     pub type Result = Result<(), Error>;
437     pub struct Formatter<'a>;
438     pub trait Debug {
439         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
440     }
441 }
442
443 enum Foo {
444     Bar,
445     Baz,
446 }
447
448 impl fmt::Debug for Foo {
449     $0fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
450         match self {
451             Self::Bar => write!(f, "Bar"),
452             Self::Baz => write!(f, "Baz"),
453         }
454     }
455 }
456 "#,
457         )
458     }
459     #[test]
460     fn add_custom_impl_all() {
461         check_assist(
462             replace_derive_with_manual_impl,
463             r#"
464 mod foo {
465     pub trait Bar {
466         type Qux;
467         const Baz: usize = 42;
468         const Fez: usize;
469         fn foo();
470         fn bar() {}
471     }
472 }
473
474 #[derive($0Bar)]
475 struct Foo {
476     bar: String,
477 }
478 "#,
479             r#"
480 mod foo {
481     pub trait Bar {
482         type Qux;
483         const Baz: usize = 42;
484         const Fez: usize;
485         fn foo();
486         fn bar() {}
487     }
488 }
489
490 struct Foo {
491     bar: String,
492 }
493
494 impl foo::Bar for Foo {
495     $0type Qux;
496
497     const Baz: usize = 42;
498
499     const Fez: usize;
500
501     fn foo() {
502         todo!()
503     }
504 }
505 "#,
506         )
507     }
508     #[test]
509     fn add_custom_impl_for_unique_input() {
510         check_assist(
511             replace_derive_with_manual_impl,
512             r#"
513 #[derive(Debu$0g)]
514 struct Foo {
515     bar: String,
516 }
517             "#,
518             r#"
519 struct Foo {
520     bar: String,
521 }
522
523 impl Debug for Foo {
524     $0
525 }
526             "#,
527         )
528     }
529
530     #[test]
531     fn add_custom_impl_for_with_visibility_modifier() {
532         check_assist(
533             replace_derive_with_manual_impl,
534             r#"
535 #[derive(Debug$0)]
536 pub struct Foo {
537     bar: String,
538 }
539             "#,
540             r#"
541 pub struct Foo {
542     bar: String,
543 }
544
545 impl Debug for Foo {
546     $0
547 }
548             "#,
549         )
550     }
551
552     #[test]
553     fn add_custom_impl_when_multiple_inputs() {
554         check_assist(
555             replace_derive_with_manual_impl,
556             r#"
557 #[derive(Display, Debug$0, Serialize)]
558 struct Foo {}
559             "#,
560             r#"
561 #[derive(Display, Serialize)]
562 struct Foo {}
563
564 impl Debug for Foo {
565     $0
566 }
567             "#,
568         )
569     }
570
571     #[test]
572     fn test_ignore_derive_macro_without_input() {
573         check_assist_not_applicable(
574             replace_derive_with_manual_impl,
575             r#"
576 #[derive($0)]
577 struct Foo {}
578             "#,
579         )
580     }
581
582     #[test]
583     fn test_ignore_if_cursor_on_param() {
584         check_assist_not_applicable(
585             replace_derive_with_manual_impl,
586             r#"
587 #[derive$0(Debug)]
588 struct Foo {}
589             "#,
590         );
591
592         check_assist_not_applicable(
593             replace_derive_with_manual_impl,
594             r#"
595 #[derive(Debug)$0]
596 struct Foo {}
597             "#,
598         )
599     }
600
601     #[test]
602     fn test_ignore_if_not_derive() {
603         check_assist_not_applicable(
604             replace_derive_with_manual_impl,
605             r#"
606 #[allow(non_camel_$0case_types)]
607 struct Foo {}
608             "#,
609         )
610     }
611
612     #[test]
613     fn works_at_start_of_file() {
614         cov_mark::check!(outside_of_attr_args);
615         check_assist_not_applicable(
616             replace_derive_with_manual_impl,
617             r#"
618 $0#[derive(Debug)]
619 struct S;
620             "#,
621         );
622     }
623 }