]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
Add trait codegen to `add_missing_impl_members` assist
[rust.git] / crates / ide_assists / src / handlers / replace_derive_with_manual_impl.rs
1 use hir::ModuleDef;
2 use ide_db::helpers::{import_assets::NameToImport, mod_path_to_ast};
3 use ide_db::items_locator;
4 use itertools::Itertools;
5 use syntax::{
6     ast::{self, make, AstNode, NameOwner},
7     SyntaxKind::{IDENT, WHITESPACE},
8 };
9
10 use crate::utils::gen_trait_body;
11 use crate::{
12     assist_context::{AssistBuilder, AssistContext, Assists},
13     utils::{
14         add_trait_assoc_items_to_impl, filter_assoc_items, generate_trait_impl_text,
15         render_snippet, Cursor, DefaultMethods,
16     },
17     AssistId, AssistKind,
18 };
19
20 // Assist: replace_derive_with_manual_impl
21 //
22 // Converts a `derive` impl into a manual one.
23 //
24 // ```
25 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
26 // #[derive(Deb$0ug, Display)]
27 // struct S;
28 // ```
29 // ->
30 // ```
31 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
32 // #[derive(Display)]
33 // struct S;
34 //
35 // impl Debug for S {
36 //     $0fn fmt(&self, f: &mut Formatter) -> Result<()> {
37 //         f.debug_struct("S").finish()
38 //     }
39 // }
40 // ```
41 pub(crate) fn replace_derive_with_manual_impl(
42     acc: &mut Assists,
43     ctx: &AssistContext,
44 ) -> Option<()> {
45     let attr = ctx.find_node_at_offset::<ast::Attr>()?;
46     let (name, args) = attr.as_simple_call()?;
47     if name != "derive" {
48         return None;
49     }
50
51     if !args.syntax().text_range().contains(ctx.offset()) {
52         cov_mark::hit!(outside_of_attr_args);
53         return None;
54     }
55
56     let trait_token = args.syntax().token_at_offset(ctx.offset()).find(|t| t.kind() == IDENT)?;
57     let trait_name = trait_token.text();
58
59     let adt = attr.syntax().parent().and_then(ast::Adt::cast)?;
60
61     let current_module = ctx.sema.scope(adt.syntax()).module()?;
62     let current_crate = current_module.krate();
63
64     let found_traits = items_locator::items_with_name(
65         &ctx.sema,
66         current_crate,
67         NameToImport::Exact(trait_name.to_string()),
68         items_locator::AssocItemSearch::Exclude,
69         Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT.inner()),
70     )
71     .filter_map(|item| match item.as_module_def()? {
72         ModuleDef::Trait(trait_) => Some(trait_),
73         _ => None,
74     })
75     .flat_map(|trait_| {
76         current_module
77             .find_use_path(ctx.sema.db, hir::ModuleDef::Trait(trait_))
78             .as_ref()
79             .map(mod_path_to_ast)
80             .zip(Some(trait_))
81     });
82
83     let mut no_traits_found = true;
84     for (trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
85         add_assist(acc, ctx, &attr, &args, &trait_path, Some(trait_), &adt)?;
86     }
87     if no_traits_found {
88         let trait_path = make::ext::ident_path(trait_name);
89         add_assist(acc, ctx, &attr, &args, &trait_path, None, &adt)?;
90     }
91     Some(())
92 }
93
94 fn add_assist(
95     acc: &mut Assists,
96     ctx: &AssistContext,
97     attr: &ast::Attr,
98     input: &ast::TokenTree,
99     trait_path: &ast::Path,
100     trait_: Option<hir::Trait>,
101     adt: &ast::Adt,
102 ) -> Option<()> {
103     let target = attr.syntax().text_range();
104     let annotated_name = adt.name()?;
105     let label = format!("Convert to manual `impl {} for {}`", trait_path, annotated_name);
106     let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?;
107
108     acc.add(
109         AssistId("replace_derive_with_manual_impl", AssistKind::Refactor),
110         label,
111         target,
112         |builder| {
113             let insert_pos = adt.syntax().text_range().end();
114             let impl_def_with_items =
115                 impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, trait_path);
116             update_attribute(builder, input, &trait_name, attr);
117             let trait_path = format!("{}", trait_path);
118             match (ctx.config.snippet_cap, impl_def_with_items) {
119                 (None, _) => {
120                     builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, ""))
121                 }
122                 (Some(cap), None) => builder.insert_snippet(
123                     cap,
124                     insert_pos,
125                     generate_trait_impl_text(adt, &trait_path, "    $0"),
126                 ),
127                 (Some(cap), Some((impl_def, first_assoc_item))) => {
128                     let mut cursor = Cursor::Before(first_assoc_item.syntax());
129                     let placeholder;
130                     if let ast::AssocItem::Fn(ref func) = first_assoc_item {
131                         if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast)
132                         {
133                             if m.syntax().text() == "todo!()" {
134                                 placeholder = m;
135                                 cursor = Cursor::Replace(placeholder.syntax());
136                             }
137                         }
138                     }
139
140                     builder.insert_snippet(
141                         cap,
142                         insert_pos,
143                         format!("\n\n{}", render_snippet(cap, impl_def.syntax(), cursor)),
144                     )
145                 }
146             };
147         },
148     )
149 }
150
151 fn impl_def_from_trait(
152     sema: &hir::Semantics<ide_db::RootDatabase>,
153     adt: &ast::Adt,
154     annotated_name: &ast::Name,
155     trait_: Option<hir::Trait>,
156     trait_path: &ast::Path,
157 ) -> Option<(ast::Impl, ast::AssocItem)> {
158     let trait_ = trait_?;
159     let target_scope = sema.scope(annotated_name.syntax());
160     let trait_items = filter_assoc_items(sema.db, &trait_.items(sema.db), DefaultMethods::No);
161     if trait_items.is_empty() {
162         return None;
163     }
164     let impl_def =
165         make::impl_trait(trait_path.clone(), make::ext::ident_path(&annotated_name.text()));
166     let (impl_def, first_assoc_item) =
167         add_trait_assoc_items_to_impl(sema, trait_items, trait_, impl_def, target_scope);
168
169     // Generate a default `impl` function body for the derived trait.
170     if let ast::AssocItem::Fn(ref func) = first_assoc_item {
171         let _ = gen_trait_body(func, trait_path, adt);
172     };
173
174     Some((impl_def, first_assoc_item))
175 }
176
177 fn update_attribute(
178     builder: &mut AssistBuilder,
179     input: &ast::TokenTree,
180     trait_name: &ast::NameRef,
181     attr: &ast::Attr,
182 ) {
183     let trait_name = trait_name.text();
184     let new_attr_input = input
185         .syntax()
186         .descendants_with_tokens()
187         .filter(|t| t.kind() == IDENT)
188         .filter_map(|t| t.into_token().map(|t| t.text().to_string()))
189         .filter(|t| t != &trait_name)
190         .collect::<Vec<_>>();
191     let has_more_derives = !new_attr_input.is_empty();
192
193     if has_more_derives {
194         let new_attr_input = format!("({})", new_attr_input.iter().format(", "));
195         builder.replace(input.syntax().text_range(), new_attr_input);
196     } else {
197         let attr_range = attr.syntax().text_range();
198         builder.delete(attr_range);
199
200         if let Some(line_break_range) = attr
201             .syntax()
202             .next_sibling_or_token()
203             .filter(|t| t.kind() == WHITESPACE)
204             .map(|t| t.text_range())
205         {
206             builder.delete(line_break_range);
207         }
208     }
209 }
210
211 #[cfg(test)]
212 mod tests {
213     use crate::tests::{check_assist, check_assist_not_applicable};
214
215     use super::*;
216
217     #[test]
218     fn add_custom_impl_debug_record_struct() {
219         check_assist(
220             replace_derive_with_manual_impl,
221             r#"
222 //- minicore: fmt
223 #[derive(Debu$0g)]
224 struct Foo {
225     bar: String,
226 }
227 "#,
228             r#"
229 struct Foo {
230     bar: String,
231 }
232
233 impl core::fmt::Debug for Foo {
234     $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
235         f.debug_struct("Foo").field("bar", &self.bar).finish()
236     }
237 }
238 "#,
239         )
240     }
241     #[test]
242     fn add_custom_impl_debug_tuple_struct() {
243         check_assist(
244             replace_derive_with_manual_impl,
245             r#"
246 //- minicore: fmt
247 #[derive(Debu$0g)]
248 struct Foo(String, usize);
249 "#,
250             r#"struct Foo(String, usize);
251
252 impl core::fmt::Debug for Foo {
253     $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
254         f.debug_tuple("Foo").field(&self.0).field(&self.1).finish()
255     }
256 }
257 "#,
258         )
259     }
260     #[test]
261     fn add_custom_impl_debug_empty_struct() {
262         check_assist(
263             replace_derive_with_manual_impl,
264             r#"
265 //- minicore: fmt
266 #[derive(Debu$0g)]
267 struct Foo;
268 "#,
269             r#"
270 struct Foo;
271
272 impl core::fmt::Debug for Foo {
273     $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
274         f.debug_struct("Foo").finish()
275     }
276 }
277 "#,
278         )
279     }
280     #[test]
281     fn add_custom_impl_debug_enum() {
282         check_assist(
283             replace_derive_with_manual_impl,
284             r#"
285 //- minicore: fmt
286 #[derive(Debu$0g)]
287 enum Foo {
288     Bar,
289     Baz,
290 }
291 "#,
292             r#"
293 enum Foo {
294     Bar,
295     Baz,
296 }
297
298 impl core::fmt::Debug for Foo {
299     $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
300         match self {
301             Self::Bar => write!(f, "Bar"),
302             Self::Baz => write!(f, "Baz"),
303         }
304     }
305 }
306 "#,
307         )
308     }
309     #[test]
310     fn add_custom_impl_default_record_struct() {
311         check_assist(
312             replace_derive_with_manual_impl,
313             r#"
314 //- minicore: default
315 #[derive(Defau$0lt)]
316 struct Foo {
317     foo: usize,
318 }
319 "#,
320             r#"
321 struct Foo {
322     foo: usize,
323 }
324
325 impl Default for Foo {
326     $0fn default() -> Self {
327         Self { foo: Default::default() }
328     }
329 }
330 "#,
331         )
332     }
333     #[test]
334     fn add_custom_impl_default_tuple_struct() {
335         check_assist(
336             replace_derive_with_manual_impl,
337             r#"
338 //- minicore: default
339 #[derive(Defau$0lt)]
340 struct Foo(usize);
341 "#,
342             r#"
343 struct Foo(usize);
344
345 impl Default for Foo {
346     $0fn default() -> Self {
347         Self(Default::default())
348     }
349 }
350 "#,
351         )
352     }
353     #[test]
354     fn add_custom_impl_default_empty_struct() {
355         check_assist(
356             replace_derive_with_manual_impl,
357             r#"
358 //- minicore: default
359 #[derive(Defau$0lt)]
360 struct Foo;
361 "#,
362             r#"
363 struct Foo;
364
365 impl Default for Foo {
366     $0fn default() -> Self {
367         Self {  }
368     }
369 }
370 "#,
371         )
372     }
373     #[test]
374     fn add_custom_impl_all() {
375         check_assist(
376             replace_derive_with_manual_impl,
377             r#"
378 mod foo {
379     pub trait Bar {
380         type Qux;
381         const Baz: usize = 42;
382         const Fez: usize;
383         fn foo();
384         fn bar() {}
385     }
386 }
387
388 #[derive($0Bar)]
389 struct Foo {
390     bar: String,
391 }
392 "#,
393             r#"
394 mod foo {
395     pub trait Bar {
396         type Qux;
397         const Baz: usize = 42;
398         const Fez: usize;
399         fn foo();
400         fn bar() {}
401     }
402 }
403
404 struct Foo {
405     bar: String,
406 }
407
408 impl foo::Bar for Foo {
409     $0type Qux;
410
411     const Baz: usize = 42;
412
413     const Fez: usize;
414
415     fn foo() {
416         todo!()
417     }
418 }
419 "#,
420         )
421     }
422     #[test]
423     fn add_custom_impl_for_unique_input() {
424         check_assist(
425             replace_derive_with_manual_impl,
426             r#"
427 #[derive(Debu$0g)]
428 struct Foo {
429     bar: String,
430 }
431             "#,
432             r#"
433 struct Foo {
434     bar: String,
435 }
436
437 impl Debug for Foo {
438     $0
439 }
440             "#,
441         )
442     }
443
444     #[test]
445     fn add_custom_impl_for_with_visibility_modifier() {
446         check_assist(
447             replace_derive_with_manual_impl,
448             r#"
449 #[derive(Debug$0)]
450 pub struct Foo {
451     bar: String,
452 }
453             "#,
454             r#"
455 pub struct Foo {
456     bar: String,
457 }
458
459 impl Debug for Foo {
460     $0
461 }
462             "#,
463         )
464     }
465
466     #[test]
467     fn add_custom_impl_when_multiple_inputs() {
468         check_assist(
469             replace_derive_with_manual_impl,
470             r#"
471 #[derive(Display, Debug$0, Serialize)]
472 struct Foo {}
473             "#,
474             r#"
475 #[derive(Display, Serialize)]
476 struct Foo {}
477
478 impl Debug for Foo {
479     $0
480 }
481             "#,
482         )
483     }
484
485     #[test]
486     fn test_ignore_derive_macro_without_input() {
487         check_assist_not_applicable(
488             replace_derive_with_manual_impl,
489             r#"
490 #[derive($0)]
491 struct Foo {}
492             "#,
493         )
494     }
495
496     #[test]
497     fn test_ignore_if_cursor_on_param() {
498         check_assist_not_applicable(
499             replace_derive_with_manual_impl,
500             r#"
501 #[derive$0(Debug)]
502 struct Foo {}
503             "#,
504         );
505
506         check_assist_not_applicable(
507             replace_derive_with_manual_impl,
508             r#"
509 #[derive(Debug)$0]
510 struct Foo {}
511             "#,
512         )
513     }
514
515     #[test]
516     fn test_ignore_if_not_derive() {
517         check_assist_not_applicable(
518             replace_derive_with_manual_impl,
519             r#"
520 #[allow(non_camel_$0case_types)]
521 struct Foo {}
522             "#,
523         )
524     }
525
526     #[test]
527     fn works_at_start_of_file() {
528         cov_mark::check!(outside_of_attr_args);
529         check_assist_not_applicable(
530             replace_derive_with_manual_impl,
531             r#"
532 $0#[derive(Debug)]
533 struct S;
534             "#,
535         );
536     }
537 }