]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
Fix enum debug indent level
[rust.git] / crates / ide_assists / src / handlers / replace_derive_with_manual_impl.rs
1 use hir::ModuleDef;
2 use ide_db::helpers::{import_assets::NameToImport, mod_path_to_ast};
3 use ide_db::items_locator;
4 use itertools::Itertools;
5 use syntax::ast::edit::AstNodeEdit;
6 use syntax::ted;
7 use syntax::{
8     ast::{self, make, AstNode, NameOwner},
9     SyntaxKind::{IDENT, WHITESPACE},
10 };
11
12 use crate::{
13     assist_context::{AssistBuilder, AssistContext, Assists},
14     utils::{
15         add_trait_assoc_items_to_impl, filter_assoc_items, generate_trait_impl_text,
16         render_snippet, Cursor, DefaultMethods,
17     },
18     AssistId, AssistKind,
19 };
20
21 // Assist: replace_derive_with_manual_impl
22 //
23 // Converts a `derive` impl into a manual one.
24 //
25 // ```
26 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
27 // #[derive(Deb$0ug, Display)]
28 // struct S;
29 // ```
30 // ->
31 // ```
32 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
33 // #[derive(Display)]
34 // struct S;
35 //
36 // impl Debug for S {
37 //     $0fn fmt(&self, f: &mut Formatter) -> Result<()> {
38 //         f.debug_struct("S").finish()
39 //     }
40 // }
41 // ```
42 pub(crate) fn replace_derive_with_manual_impl(
43     acc: &mut Assists,
44     ctx: &AssistContext,
45 ) -> Option<()> {
46     let attr = ctx.find_node_at_offset::<ast::Attr>()?;
47     let (name, args) = attr.as_simple_call()?;
48     if name != "derive" {
49         return None;
50     }
51
52     if !args.syntax().text_range().contains(ctx.offset()) {
53         cov_mark::hit!(outside_of_attr_args);
54         return None;
55     }
56
57     let trait_token = args.syntax().token_at_offset(ctx.offset()).find(|t| t.kind() == IDENT)?;
58     let trait_name = trait_token.text();
59
60     let adt = attr.syntax().parent().and_then(ast::Adt::cast)?;
61
62     let current_module = ctx.sema.scope(adt.syntax()).module()?;
63     let current_crate = current_module.krate();
64
65     let found_traits = items_locator::items_with_name(
66         &ctx.sema,
67         current_crate,
68         NameToImport::Exact(trait_name.to_string()),
69         items_locator::AssocItemSearch::Exclude,
70         Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT.inner()),
71     )
72     .filter_map(|item| match item.as_module_def()? {
73         ModuleDef::Trait(trait_) => Some(trait_),
74         _ => None,
75     })
76     .flat_map(|trait_| {
77         current_module
78             .find_use_path(ctx.sema.db, hir::ModuleDef::Trait(trait_))
79             .as_ref()
80             .map(mod_path_to_ast)
81             .zip(Some(trait_))
82     });
83
84     let mut no_traits_found = true;
85     for (trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
86         add_assist(acc, ctx, &attr, &args, &trait_path, Some(trait_), &adt)?;
87     }
88     if no_traits_found {
89         let trait_path = make::ext::ident_path(trait_name);
90         add_assist(acc, ctx, &attr, &args, &trait_path, None, &adt)?;
91     }
92     Some(())
93 }
94
95 fn add_assist(
96     acc: &mut Assists,
97     ctx: &AssistContext,
98     attr: &ast::Attr,
99     input: &ast::TokenTree,
100     trait_path: &ast::Path,
101     trait_: Option<hir::Trait>,
102     adt: &ast::Adt,
103 ) -> Option<()> {
104     let target = attr.syntax().text_range();
105     let annotated_name = adt.name()?;
106     let label = format!("Convert to manual `impl {} for {}`", trait_path, annotated_name);
107     let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?;
108
109     acc.add(
110         AssistId("replace_derive_with_manual_impl", AssistKind::Refactor),
111         label,
112         target,
113         |builder| {
114             let insert_pos = adt.syntax().text_range().end();
115             let impl_def_with_items =
116                 impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, trait_path);
117             update_attribute(builder, input, &trait_name, attr);
118             let trait_path = format!("{}", trait_path);
119             match (ctx.config.snippet_cap, impl_def_with_items) {
120                 (None, _) => {
121                     builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, ""))
122                 }
123                 (Some(cap), None) => builder.insert_snippet(
124                     cap,
125                     insert_pos,
126                     generate_trait_impl_text(adt, &trait_path, "    $0"),
127                 ),
128                 (Some(cap), Some((impl_def, first_assoc_item))) => {
129                     let mut cursor = Cursor::Before(first_assoc_item.syntax());
130                     let placeholder;
131                     if let ast::AssocItem::Fn(ref func) = first_assoc_item {
132                         // need to know what kind of derive this is: if it's Derive Debug, special case it.
133                         // the name of the struct
134                         // list of fields of the struct
135                         if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast)
136                         {
137                             if m.syntax().text() == "todo!()" {
138                                 placeholder = m;
139                                 cursor = Cursor::Replace(placeholder.syntax());
140                             }
141                         }
142                     }
143
144                     builder.insert_snippet(
145                         cap,
146                         insert_pos,
147                         format!("\n\n{}", render_snippet(cap, impl_def.syntax(), cursor)),
148                     )
149                 }
150             };
151         },
152     )
153 }
154
155 fn impl_def_from_trait(
156     sema: &hir::Semantics<ide_db::RootDatabase>,
157     adt: &ast::Adt,
158     annotated_name: &ast::Name,
159     trait_: Option<hir::Trait>,
160     trait_path: &ast::Path,
161 ) -> Option<(ast::Impl, ast::AssocItem)> {
162     let trait_ = trait_?;
163     let target_scope = sema.scope(annotated_name.syntax());
164     let trait_items = filter_assoc_items(sema.db, &trait_.items(sema.db), DefaultMethods::No);
165     if trait_items.is_empty() {
166         return None;
167     }
168     let impl_def =
169         make::impl_trait(trait_path.clone(), make::ext::ident_path(&annotated_name.text()));
170     let (impl_def, first_assoc_item) =
171         add_trait_assoc_items_to_impl(sema, trait_items, trait_, impl_def, target_scope);
172
173     if let ast::AssocItem::Fn(func) = &first_assoc_item {
174         if trait_path.segment().unwrap().name_ref().unwrap().text() == "Debug" {
175             gen_debug_impl(adt, func, annotated_name);
176         }
177     }
178     Some((impl_def, first_assoc_item))
179 }
180
181 fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn, annotated_name: &ast::Name) {
182     match adt {
183         ast::Adt::Union(_) => {} // `Debug` cannot be derived for unions, so no default impl can be provided.
184         ast::Adt::Enum(enum_) => {
185             if let Some(list) = enum_.variant_list() {
186                 let mut arms = vec![];
187                 for variant in list.variants() {
188                     let name = variant.name().unwrap();
189
190                     // => Self::<Variant>
191                     let first = make::ext::ident_path("Self");
192                     let second = make::ext::ident_path(&format!("{}", name));
193                     let pat = make::path_pat(make::path_concat(first, second));
194
195                     // => write!(f, "<Variant>")
196                     let target = make::expr_path(make::ext::ident_path("f").into());
197                     let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into();
198                     let args = make::arg_list(vec![target, fmt_string]);
199                     let target = make::expr_path(make::ext::ident_path("write"));
200                     let expr = make::expr_macro_call(target, args);
201
202                     // => Self::<Variant> => write!(f, "<Variant>"),
203                     arms.push(make::match_arm(Some(pat.into()), None, expr.into()));
204                 }
205
206                 // => match self { ... }
207                 let f_path = make::expr_path(make::ext::ident_path("self"));
208                 let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
209                 let expr = make::expr_match(f_path, list);
210
211                 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
212                 ted::replace(func.body().unwrap().syntax(), body.clone_for_update().syntax());
213             }
214         }
215         ast::Adt::Struct(strukt) => match strukt.field_list() {
216             Some(ast::FieldList::RecordFieldList(field_list)) => {
217                 let name = format!("\"{}\"", annotated_name);
218                 let args = make::arg_list(Some(make::expr_literal(&name).into()));
219                 let target = make::expr_path(make::ext::ident_path("f"));
220                 let mut expr = make::expr_method_call(target, "debug_struct", args);
221                 for field in field_list.fields() {
222                     if let Some(name) = field.name() {
223                         let f_name = make::expr_literal(&(format!("\"{}\"", name))).into();
224                         let f_path = make::expr_path(make::ext::ident_path("self"));
225                         let f_path = make::expr_ref(f_path, false);
226                         let f_path = make::expr_field(f_path, &format!("{}", name)).into();
227                         let args = make::arg_list(vec![f_name, f_path]);
228                         expr = make::expr_method_call(expr, "field", args);
229                     }
230                 }
231                 let expr = make::expr_method_call(expr, "finish", make::arg_list(None));
232                 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
233                 ted::replace(func.body().unwrap().syntax(), body.clone_for_update().syntax());
234             }
235             Some(ast::FieldList::TupleFieldList(field_list)) => {
236                 let name = format!("\"{}\"", annotated_name);
237                 let args = make::arg_list(Some(make::expr_literal(&name).into()));
238                 let target = make::expr_path(make::ext::ident_path("f"));
239                 let mut expr = make::expr_method_call(target, "debug_tuple", args);
240                 for (idx, _) in field_list.fields().enumerate() {
241                     let f_path = make::expr_path(make::ext::ident_path("self"));
242                     let f_path = make::expr_ref(f_path, false);
243                     let f_path = make::expr_field(f_path, &format!("{}", idx)).into();
244                     let args = make::arg_list(Some(f_path));
245                     expr = make::expr_method_call(expr, "field", args);
246                 }
247                 let expr = make::expr_method_call(expr, "finish", make::arg_list(None));
248                 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
249                 ted::replace(func.body().unwrap().syntax(), body.clone_for_update().syntax());
250             }
251             None => {
252                 let name = format!("\"{}\"", annotated_name);
253                 let args = make::arg_list(Some(make::expr_literal(&name).into()));
254                 let target = make::expr_path(make::ext::ident_path("f"));
255                 let expr = make::expr_method_call(target, "debug_struct", args);
256                 let expr = make::expr_method_call(expr, "finish", make::arg_list(None));
257                 let body = make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1));
258                 ted::replace(func.body().unwrap().syntax(), body.clone_for_update().syntax());
259             }
260         },
261     }
262 }
263
264 fn update_attribute(
265     builder: &mut AssistBuilder,
266     input: &ast::TokenTree,
267     trait_name: &ast::NameRef,
268     attr: &ast::Attr,
269 ) {
270     let trait_name = trait_name.text();
271     let new_attr_input = input
272         .syntax()
273         .descendants_with_tokens()
274         .filter(|t| t.kind() == IDENT)
275         .filter_map(|t| t.into_token().map(|t| t.text().to_string()))
276         .filter(|t| t != &trait_name)
277         .collect::<Vec<_>>();
278     let has_more_derives = !new_attr_input.is_empty();
279
280     if has_more_derives {
281         let new_attr_input = format!("({})", new_attr_input.iter().format(", "));
282         builder.replace(input.syntax().text_range(), new_attr_input);
283     } else {
284         let attr_range = attr.syntax().text_range();
285         builder.delete(attr_range);
286
287         if let Some(line_break_range) = attr
288             .syntax()
289             .next_sibling_or_token()
290             .filter(|t| t.kind() == WHITESPACE)
291             .map(|t| t.text_range())
292         {
293             builder.delete(line_break_range);
294         }
295     }
296 }
297
// Fixture-based tests for `replace_derive_with_manual_impl`. `$0` in an input
// fixture marks the cursor; in an expected fixture it marks the snippet cursor.
298 #[cfg(test)]
299 mod tests {
300     use crate::tests::{check_assist, check_assist_not_applicable};
301
302     use super::*;
303
    // Record struct with a resolvable `fmt::Debug`: the expected body is a
    // `debug_struct` chain with one `.field(..)` call per named field.
304     #[test]
305     fn add_custom_impl_debug_record_struct() {
306         check_assist(
307             replace_derive_with_manual_impl,
308             r#"
309 mod fmt {
310     pub struct Error;
311     pub type Result = Result<(), Error>;
312     pub struct Formatter<'a>;
313     pub trait Debug {
314         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
315     }
316 }
317
318 #[derive(Debu$0g)]
319 struct Foo {
320     bar: String,
321 }
322 "#,
323             r#"
324 mod fmt {
325     pub struct Error;
326     pub type Result = Result<(), Error>;
327     pub struct Formatter<'a>;
328     pub trait Debug {
329         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
330     }
331 }
332
333 struct Foo {
334     bar: String,
335 }
336
337 impl fmt::Debug for Foo {
338     $0fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
339         f.debug_struct("Foo").field("bar", &self.bar).finish()
340     }
341 }
342 "#,
343         )
344     }
    // Tuple struct: the expected body uses `debug_tuple` and positional
    // `&self.<index>` fields.
345     #[test]
346     fn add_custom_impl_debug_tuple_struct() {
347         check_assist(
348             replace_derive_with_manual_impl,
349             r#"
350 mod fmt {
351     pub struct Error;
352     pub type Result = Result<(), Error>;
353     pub struct Formatter<'a>;
354     pub trait Debug {
355         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
356     }
357 }
358
359 #[derive(Debu$0g)]
360 struct Foo(String, usize);
361 "#,
362             r#"
363 mod fmt {
364     pub struct Error;
365     pub type Result = Result<(), Error>;
366     pub struct Formatter<'a>;
367     pub trait Debug {
368         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
369     }
370 }
371
372 struct Foo(String, usize);
373
374 impl fmt::Debug for Foo {
375     $0fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376         f.debug_tuple("Foo").field(&self.0).field(&self.1).finish()
377     }
378 }
379 "#,
380         )
381     }
    // Struct without a field list: the expected body is
    // `debug_struct("Foo").finish()` with no `.field(..)` calls.
382     #[test]
383     fn add_custom_impl_debug_empty_struct() {
384         check_assist(
385             replace_derive_with_manual_impl,
386             r#"
387 mod fmt {
388     pub struct Error;
389     pub type Result = Result<(), Error>;
390     pub struct Formatter<'a>;
391     pub trait Debug {
392         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
393     }
394 }
395
396 #[derive(Debu$0g)]
397 struct Foo;
398 "#,
399             r#"
400 mod fmt {
401     pub struct Error;
402     pub type Result = Result<(), Error>;
403     pub struct Formatter<'a>;
404     pub trait Debug {
405         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
406     }
407 }
408
409 struct Foo;
410
411 impl fmt::Debug for Foo {
412     $0fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
413         f.debug_struct("Foo").finish()
414     }
415 }
416 "#,
417         )
418     }
    // Enum with unit variants only: the expected body matches on `self` and
    // writes each variant's name; also pins the indentation of the match arms.
419     #[test]
420     fn add_custom_impl_debug_enum() {
421         check_assist(
422             replace_derive_with_manual_impl,
423             r#"
424 mod fmt {
425     pub struct Error;
426     pub type Result = Result<(), Error>;
427     pub struct Formatter<'a>;
428     pub trait Debug {
429         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
430     }
431 }
432
433 #[derive(Debu$0g)]
434 enum Foo {
435     Bar,
436     Baz,
437 }
438 "#,
439             r#"
440 mod fmt {
441     pub struct Error;
442     pub type Result = Result<(), Error>;
443     pub struct Formatter<'a>;
444     pub trait Debug {
445         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
446     }
447 }
448
449 enum Foo {
450     Bar,
451     Baz,
452 }
453
454 impl fmt::Debug for Foo {
455     $0fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
456         match self {
457             Self::Bar => write!(f, "Bar"),
458             Self::Baz => write!(f, "Baz"),
459         }
460     }
461 }
462 "#,
463         )
464     }
    // Non-`Debug` trait with mixed items: the expected impl lists `Qux`, `Baz`,
    // `Fez` and a `todo!()`-stubbed `foo`, but not the default-bodied `bar`.
465     #[test]
466     fn add_custom_impl_all() {
467         check_assist(
468             replace_derive_with_manual_impl,
469             r#"
470 mod foo {
471     pub trait Bar {
472         type Qux;
473         const Baz: usize = 42;
474         const Fez: usize;
475         fn foo();
476         fn bar() {}
477     }
478 }
479
480 #[derive($0Bar)]
481 struct Foo {
482     bar: String,
483 }
484 "#,
485             r#"
486 mod foo {
487     pub trait Bar {
488         type Qux;
489         const Baz: usize = 42;
490         const Fez: usize;
491         fn foo();
492         fn bar() {}
493     }
494 }
495
496 struct Foo {
497     bar: String,
498 }
499
500 impl foo::Bar for Foo {
501     $0type Qux;
502
503     const Baz: usize = 42;
504
505     const Fez: usize;
506
507     fn foo() {
508         todo!()
509     }
510 }
511 "#,
512         )
513     }
    // No `Debug` trait is defined in this fixture: the expected output is an
    // empty impl holding the cursor, and the sole derive attribute is removed.
514     #[test]
515     fn add_custom_impl_for_unique_input() {
516         check_assist(
517             replace_derive_with_manual_impl,
518             r#"
519 #[derive(Debu$0g)]
520 struct Foo {
521     bar: String,
522 }
523             "#,
524             r#"
525 struct Foo {
526     bar: String,
527 }
528
529 impl Debug for Foo {
530     $0
531 }
532             "#,
533         )
534     }
535
    // Same as above, but the struct is declared `pub`.
536     #[test]
537     fn add_custom_impl_for_with_visibility_modifier() {
538         check_assist(
539             replace_derive_with_manual_impl,
540             r#"
541 #[derive(Debug$0)]
542 pub struct Foo {
543     bar: String,
544 }
545             "#,
546             r#"
547 pub struct Foo {
548     bar: String,
549 }
550
551 impl Debug for Foo {
552     $0
553 }
554             "#,
555         )
556     }
557
    // Several derives: only the one under the cursor is removed from the list.
558     #[test]
559     fn add_custom_impl_when_multiple_inputs() {
560         check_assist(
561             replace_derive_with_manual_impl,
562             r#"
563 #[derive(Display, Debug$0, Serialize)]
564 struct Foo {}
565             "#,
566             r#"
567 #[derive(Display, Serialize)]
568 struct Foo {}
569
570 impl Debug for Foo {
571     $0
572 }
573             "#,
574         )
575     }
576
    // An empty derive list offers no assist.
577     #[test]
578     fn test_ignore_derive_macro_without_input() {
579         check_assist_not_applicable(
580             replace_derive_with_manual_impl,
581             r#"
582 #[derive($0)]
583 struct Foo {}
584             "#,
585         )
586     }
587
    // Cursor just before `(` or just after `)`: no assist is offered.
588     #[test]
589     fn test_ignore_if_cursor_on_param() {
590         check_assist_not_applicable(
591             replace_derive_with_manual_impl,
592             r#"
593 #[derive$0(Debug)]
594 struct Foo {}
595             "#,
596         );
597
598         check_assist_not_applicable(
599             replace_derive_with_manual_impl,
600             r#"
601 #[derive(Debug)$0]
602 struct Foo {}
603             "#,
604         )
605     }
606
    // Attributes other than `derive` offer no assist.
607     #[test]
608     fn test_ignore_if_not_derive() {
609         check_assist_not_applicable(
610             replace_derive_with_manual_impl,
611             r#"
612 #[allow(non_camel_$0case_types)]
613 struct Foo {}
614             "#,
615         )
616     }
617
    // Cursor at offset zero, in front of the attribute: not applicable, and
    // the `outside_of_attr_args` coverage mark must be hit.
618     #[test]
619     fn works_at_start_of_file() {
620         cov_mark::check!(outside_of_attr_args);
621         check_assist_not_applicable(
622             replace_derive_with_manual_impl,
623             r#"
624 $0#[derive(Debug)]
625 struct S;
626             "#,
627         );
628     }
629 }