]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
Merge #8583
[rust.git] / crates / ide_assists / src / handlers / replace_derive_with_manual_impl.rs
1 use hir::ModuleDef;
2 use ide_db::helpers::{import_assets::NameToImport, mod_path_to_ast};
3 use ide_db::items_locator;
4 use itertools::Itertools;
5 use syntax::{
6     ast::{self, make, AstNode, NameOwner},
7     SyntaxKind::{IDENT, WHITESPACE},
8 };
9
10 use crate::{
11     assist_context::{AssistBuilder, AssistContext, Assists},
12     utils::{
13         add_trait_assoc_items_to_impl, filter_assoc_items, generate_trait_impl_text,
14         render_snippet, Cursor, DefaultMethods,
15     },
16     AssistId, AssistKind,
17 };
18
19 // Assist: replace_derive_with_manual_impl
20 //
21 // Converts a `derive` impl into a manual one.
22 //
23 // ```
24 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
25 // #[derive(Deb$0ug, Display)]
26 // struct S;
27 // ```
28 // ->
29 // ```
30 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
31 // #[derive(Display)]
32 // struct S;
33 //
34 // impl Debug for S {
35 //     fn fmt(&self, f: &mut Formatter) -> Result<()> {
36 //         ${0:todo!()}
37 //     }
38 // }
39 // ```
40 pub(crate) fn replace_derive_with_manual_impl(
41     acc: &mut Assists,
42     ctx: &AssistContext,
43 ) -> Option<()> {
44     let attr = ctx.find_node_at_offset::<ast::Attr>()?;
45     let (name, args) = attr.as_simple_call()?;
46     if name != "derive" {
47         return None;
48     }
49
50     let trait_token = args.syntax().token_at_offset(ctx.offset()).find(|t| t.kind() == IDENT)?;
51     let trait_name = trait_token.text();
52
53     let adt = attr.syntax().parent().and_then(ast::Adt::cast)?;
54
55     let current_module = ctx.sema.scope(adt.syntax()).module()?;
56     let current_crate = current_module.krate();
57
58     let found_traits = items_locator::items_with_name(
59         &ctx.sema,
60         current_crate,
61         NameToImport::Exact(trait_name.to_string()),
62         items_locator::AssocItemSearch::Exclude,
63         Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT),
64     )
65     .filter_map(|item| match ModuleDef::from(item.as_module_def_id()?) {
66         ModuleDef::Trait(trait_) => Some(trait_),
67         _ => None,
68     })
69     .flat_map(|trait_| {
70         current_module
71             .find_use_path(ctx.sema.db, hir::ModuleDef::Trait(trait_))
72             .as_ref()
73             .map(mod_path_to_ast)
74             .zip(Some(trait_))
75     });
76
77     let mut no_traits_found = true;
78     for (trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
79         add_assist(acc, ctx, &attr, &args, &trait_path, Some(trait_), &adt)?;
80     }
81     if no_traits_found {
82         let trait_path = make::path_unqualified(make::path_segment(make::name_ref(trait_name)));
83         add_assist(acc, ctx, &attr, &args, &trait_path, None, &adt)?;
84     }
85     Some(())
86 }
87
88 fn add_assist(
89     acc: &mut Assists,
90     ctx: &AssistContext,
91     attr: &ast::Attr,
92     input: &ast::TokenTree,
93     trait_path: &ast::Path,
94     trait_: Option<hir::Trait>,
95     adt: &ast::Adt,
96 ) -> Option<()> {
97     let target = attr.syntax().text_range();
98     let annotated_name = adt.name()?;
99     let label = format!("Convert to manual `impl {} for {}`", trait_path, annotated_name);
100     let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?;
101
102     acc.add(
103         AssistId("replace_derive_with_manual_impl", AssistKind::Refactor),
104         label,
105         target,
106         |builder| {
107             let insert_pos = adt.syntax().text_range().end();
108             let impl_def_with_items =
109                 impl_def_from_trait(&ctx.sema, &annotated_name, trait_, trait_path);
110             update_attribute(builder, &input, &trait_name, &attr);
111             let trait_path = format!("{}", trait_path);
112             match (ctx.config.snippet_cap, impl_def_with_items) {
113                 (None, _) => {
114                     builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, ""))
115                 }
116                 (Some(cap), None) => builder.insert_snippet(
117                     cap,
118                     insert_pos,
119                     generate_trait_impl_text(adt, &trait_path, "    $0"),
120                 ),
121                 (Some(cap), Some((impl_def, first_assoc_item))) => {
122                     let mut cursor = Cursor::Before(first_assoc_item.syntax());
123                     let placeholder;
124                     if let ast::AssocItem::Fn(ref func) = first_assoc_item {
125                         if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast)
126                         {
127                             if m.syntax().text() == "todo!()" {
128                                 placeholder = m;
129                                 cursor = Cursor::Replace(placeholder.syntax());
130                             }
131                         }
132                     }
133
134                     builder.insert_snippet(
135                         cap,
136                         insert_pos,
137                         format!("\n\n{}", render_snippet(cap, impl_def.syntax(), cursor)),
138                     )
139                 }
140             };
141         },
142     )
143 }
144
145 fn impl_def_from_trait(
146     sema: &hir::Semantics<ide_db::RootDatabase>,
147     annotated_name: &ast::Name,
148     trait_: Option<hir::Trait>,
149     trait_path: &ast::Path,
150 ) -> Option<(ast::Impl, ast::AssocItem)> {
151     let trait_ = trait_?;
152     let target_scope = sema.scope(annotated_name.syntax());
153     let trait_items = filter_assoc_items(sema.db, &trait_.items(sema.db), DefaultMethods::No);
154     if trait_items.is_empty() {
155         return None;
156     }
157     let impl_def = make::impl_trait(
158         trait_path.clone(),
159         make::path_unqualified(make::path_segment(make::name_ref(&annotated_name.text()))),
160     );
161     let (impl_def, first_assoc_item) =
162         add_trait_assoc_items_to_impl(sema, trait_items, trait_, impl_def, target_scope);
163     Some((impl_def, first_assoc_item))
164 }
165
166 fn update_attribute(
167     builder: &mut AssistBuilder,
168     input: &ast::TokenTree,
169     trait_name: &ast::NameRef,
170     attr: &ast::Attr,
171 ) {
172     let trait_name = trait_name.text();
173     let new_attr_input = input
174         .syntax()
175         .descendants_with_tokens()
176         .filter(|t| t.kind() == IDENT)
177         .filter_map(|t| t.into_token().map(|t| t.text().to_string()))
178         .filter(|t| t != &trait_name)
179         .collect::<Vec<_>>();
180     let has_more_derives = !new_attr_input.is_empty();
181
182     if has_more_derives {
183         let new_attr_input = format!("({})", new_attr_input.iter().format(", "));
184         builder.replace(input.syntax().text_range(), new_attr_input);
185     } else {
186         let attr_range = attr.syntax().text_range();
187         builder.delete(attr_range);
188
189         if let Some(line_break_range) = attr
190             .syntax()
191             .next_sibling_or_token()
192             .filter(|t| t.kind() == WHITESPACE)
193             .map(|t| t.text_range())
194         {
195             builder.delete(line_break_range);
196         }
197     }
198 }
199
200 #[cfg(test)]
201 mod tests {
202     use crate::tests::{check_assist, check_assist_not_applicable};
203
204     use super::*;
205
206     #[test]
207     fn add_custom_impl_debug() {
208         check_assist(
209             replace_derive_with_manual_impl,
210             "
211 mod fmt {
212     pub struct Error;
213     pub type Result = Result<(), Error>;
214     pub struct Formatter<'a>;
215     pub trait Debug {
216         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
217     }
218 }
219
220 #[derive(Debu$0g)]
221 struct Foo {
222     bar: String,
223 }
224 ",
225             "
226 mod fmt {
227     pub struct Error;
228     pub type Result = Result<(), Error>;
229     pub struct Formatter<'a>;
230     pub trait Debug {
231         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
232     }
233 }
234
235 struct Foo {
236     bar: String,
237 }
238
239 impl fmt::Debug for Foo {
240     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
241         ${0:todo!()}
242     }
243 }
244 ",
245         )
246     }
247     #[test]
248     fn add_custom_impl_all() {
249         check_assist(
250             replace_derive_with_manual_impl,
251             "
252 mod foo {
253     pub trait Bar {
254         type Qux;
255         const Baz: usize = 42;
256         const Fez: usize;
257         fn foo();
258         fn bar() {}
259     }
260 }
261
262 #[derive($0Bar)]
263 struct Foo {
264     bar: String,
265 }
266 ",
267             "
268 mod foo {
269     pub trait Bar {
270         type Qux;
271         const Baz: usize = 42;
272         const Fez: usize;
273         fn foo();
274         fn bar() {}
275     }
276 }
277
278 struct Foo {
279     bar: String,
280 }
281
282 impl foo::Bar for Foo {
283     $0type Qux;
284
285     const Baz: usize = 42;
286
287     const Fez: usize;
288
289     fn foo() {
290         todo!()
291     }
292 }
293 ",
294         )
295     }
296     #[test]
297     fn add_custom_impl_for_unique_input() {
298         check_assist(
299             replace_derive_with_manual_impl,
300             "
301 #[derive(Debu$0g)]
302 struct Foo {
303     bar: String,
304 }
305             ",
306             "
307 struct Foo {
308     bar: String,
309 }
310
311 impl Debug for Foo {
312     $0
313 }
314             ",
315         )
316     }
317
318     #[test]
319     fn add_custom_impl_for_with_visibility_modifier() {
320         check_assist(
321             replace_derive_with_manual_impl,
322             "
323 #[derive(Debug$0)]
324 pub struct Foo {
325     bar: String,
326 }
327             ",
328             "
329 pub struct Foo {
330     bar: String,
331 }
332
333 impl Debug for Foo {
334     $0
335 }
336             ",
337         )
338     }
339
340     #[test]
341     fn add_custom_impl_when_multiple_inputs() {
342         check_assist(
343             replace_derive_with_manual_impl,
344             "
345 #[derive(Display, Debug$0, Serialize)]
346 struct Foo {}
347             ",
348             "
349 #[derive(Display, Serialize)]
350 struct Foo {}
351
352 impl Debug for Foo {
353     $0
354 }
355             ",
356         )
357     }
358
359     #[test]
360     fn test_ignore_derive_macro_without_input() {
361         check_assist_not_applicable(
362             replace_derive_with_manual_impl,
363             "
364 #[derive($0)]
365 struct Foo {}
366             ",
367         )
368     }
369
370     #[test]
371     fn test_ignore_if_cursor_on_param() {
372         check_assist_not_applicable(
373             replace_derive_with_manual_impl,
374             "
375 #[derive$0(Debug)]
376 struct Foo {}
377             ",
378         );
379
380         check_assist_not_applicable(
381             replace_derive_with_manual_impl,
382             "
383 #[derive(Debug)$0]
384 struct Foo {}
385             ",
386         )
387     }
388
389     #[test]
390     fn test_ignore_if_not_derive() {
391         check_assist_not_applicable(
392             replace_derive_with_manual_impl,
393             "
394 #[allow(non_camel_$0case_types)]
395 struct Foo {}
396             ",
397         )
398     }
399 }