]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
Merge #8213
[rust.git] / crates / ide_assists / src / handlers / replace_derive_with_manual_impl.rs
1 use hir::ModuleDef;
2 use ide_db::helpers::{import_assets::NameToImport, mod_path_to_ast};
3 use ide_db::items_locator;
4 use itertools::Itertools;
5 use syntax::{
6     ast::{self, make, AstNode, NameOwner},
7     SyntaxKind::{IDENT, WHITESPACE},
8     TextSize,
9 };
10
11 use crate::{
12     assist_context::{AssistBuilder, AssistContext, Assists},
13     utils::{
14         add_trait_assoc_items_to_impl, filter_assoc_items, generate_trait_impl_text,
15         render_snippet, Cursor, DefaultMethods,
16     },
17     AssistId, AssistKind,
18 };
19
20 // Assist: replace_derive_with_manual_impl
21 //
22 // Converts a `derive` impl into a manual one.
23 //
24 // ```
25 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
26 // #[derive(Deb$0ug, Display)]
27 // struct S;
28 // ```
29 // ->
30 // ```
31 // # trait Debug { fn fmt(&self, f: &mut Formatter) -> Result<()>; }
32 // #[derive(Display)]
33 // struct S;
34 //
35 // impl Debug for S {
36 //     fn fmt(&self, f: &mut Formatter) -> Result<()> {
37 //         ${0:todo!()}
38 //     }
39 // }
40 // ```
41 pub(crate) fn replace_derive_with_manual_impl(
42     acc: &mut Assists,
43     ctx: &AssistContext,
44 ) -> Option<()> {
45     let attr = ctx.find_node_at_offset::<ast::Attr>()?;
46
47     let has_derive = attr
48         .syntax()
49         .descendants_with_tokens()
50         .filter(|t| t.kind() == IDENT)
51         .find_map(syntax::NodeOrToken::into_token)
52         .filter(|t| t.text() == "derive")
53         .is_some();
54     if !has_derive {
55         return None;
56     }
57
58     let trait_token = ctx.token_at_offset().find(|t| t.kind() == IDENT && t.text() != "derive")?;
59     let trait_path = make::path_unqualified(make::path_segment(make::name_ref(trait_token.text())));
60
61     let adt = attr.syntax().parent().and_then(ast::Adt::cast)?;
62     let annotated_name = adt.name()?;
63     let insert_pos = adt.syntax().text_range().end();
64
65     let current_module = ctx.sema.scope(annotated_name.syntax()).module()?;
66     let current_crate = current_module.krate();
67
68     let found_traits = items_locator::items_with_name(
69         &ctx.sema,
70         current_crate,
71         NameToImport::Exact(trait_token.text().to_string()),
72         items_locator::AssocItemSearch::Exclude,
73         Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT),
74     )
75     .filter_map(|item| match ModuleDef::from(item.as_module_def_id()?) {
76         ModuleDef::Trait(trait_) => Some(trait_),
77         _ => None,
78     })
79     .flat_map(|trait_| {
80         current_module
81             .find_use_path(ctx.sema.db, hir::ModuleDef::Trait(trait_))
82             .as_ref()
83             .map(mod_path_to_ast)
84             .zip(Some(trait_))
85     });
86
87     let mut no_traits_found = true;
88     for (trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
89         add_assist(acc, ctx, &attr, &trait_path, Some(trait_), &adt, &annotated_name, insert_pos)?;
90     }
91     if no_traits_found {
92         add_assist(acc, ctx, &attr, &trait_path, None, &adt, &annotated_name, insert_pos)?;
93     }
94     Some(())
95 }
96
97 fn add_assist(
98     acc: &mut Assists,
99     ctx: &AssistContext,
100     attr: &ast::Attr,
101     trait_path: &ast::Path,
102     trait_: Option<hir::Trait>,
103     adt: &ast::Adt,
104     annotated_name: &ast::Name,
105     insert_pos: TextSize,
106 ) -> Option<()> {
107     let target = attr.syntax().text_range();
108     let input = attr.token_tree()?;
109     let label = format!("Convert to manual  `impl {} for {}`", trait_path, annotated_name);
110     let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?;
111
112     acc.add(
113         AssistId("replace_derive_with_manual_impl", AssistKind::Refactor),
114         label,
115         target,
116         |builder| {
117             let impl_def_with_items =
118                 impl_def_from_trait(&ctx.sema, annotated_name, trait_, trait_path);
119             update_attribute(builder, &input, &trait_name, &attr);
120             let trait_path = format!("{}", trait_path);
121             match (ctx.config.snippet_cap, impl_def_with_items) {
122                 (None, _) => {
123                     builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, ""))
124                 }
125                 (Some(cap), None) => builder.insert_snippet(
126                     cap,
127                     insert_pos,
128                     generate_trait_impl_text(adt, &trait_path, "    $0"),
129                 ),
130                 (Some(cap), Some((impl_def, first_assoc_item))) => {
131                     let mut cursor = Cursor::Before(first_assoc_item.syntax());
132                     let placeholder;
133                     if let ast::AssocItem::Fn(ref func) = first_assoc_item {
134                         if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast)
135                         {
136                             if m.syntax().text() == "todo!()" {
137                                 placeholder = m;
138                                 cursor = Cursor::Replace(placeholder.syntax());
139                             }
140                         }
141                     }
142
143                     builder.insert_snippet(
144                         cap,
145                         insert_pos,
146                         format!("\n\n{}", render_snippet(cap, impl_def.syntax(), cursor)),
147                     )
148                 }
149             };
150         },
151     )
152 }
153
154 fn impl_def_from_trait(
155     sema: &hir::Semantics<ide_db::RootDatabase>,
156     annotated_name: &ast::Name,
157     trait_: Option<hir::Trait>,
158     trait_path: &ast::Path,
159 ) -> Option<(ast::Impl, ast::AssocItem)> {
160     let trait_ = trait_?;
161     let target_scope = sema.scope(annotated_name.syntax());
162     let trait_items = filter_assoc_items(sema.db, &trait_.items(sema.db), DefaultMethods::No);
163     if trait_items.is_empty() {
164         return None;
165     }
166     let impl_def = make::impl_trait(
167         trait_path.clone(),
168         make::path_unqualified(make::path_segment(make::name_ref(&annotated_name.text()))),
169     );
170     let (impl_def, first_assoc_item) =
171         add_trait_assoc_items_to_impl(sema, trait_items, trait_, impl_def, target_scope);
172     Some((impl_def, first_assoc_item))
173 }
174
175 fn update_attribute(
176     builder: &mut AssistBuilder,
177     input: &ast::TokenTree,
178     trait_name: &ast::NameRef,
179     attr: &ast::Attr,
180 ) {
181     let trait_name = trait_name.text();
182     let new_attr_input = input
183         .syntax()
184         .descendants_with_tokens()
185         .filter(|t| t.kind() == IDENT)
186         .filter_map(|t| t.into_token().map(|t| t.text().to_string()))
187         .filter(|t| t != &trait_name)
188         .collect::<Vec<_>>();
189     let has_more_derives = !new_attr_input.is_empty();
190
191     if has_more_derives {
192         let new_attr_input = format!("({})", new_attr_input.iter().format(", "));
193         builder.replace(input.syntax().text_range(), new_attr_input);
194     } else {
195         let attr_range = attr.syntax().text_range();
196         builder.delete(attr_range);
197
198         if let Some(line_break_range) = attr
199             .syntax()
200             .next_sibling_or_token()
201             .filter(|t| t.kind() == WHITESPACE)
202             .map(|t| t.text_range())
203         {
204             builder.delete(line_break_range);
205         }
206     }
207 }
208
209 #[cfg(test)]
210 mod tests {
211     use crate::tests::{check_assist, check_assist_not_applicable};
212
213     use super::*;
214
215     #[test]
216     fn add_custom_impl_debug() {
217         check_assist(
218             replace_derive_with_manual_impl,
219             "
220 mod fmt {
221     pub struct Error;
222     pub type Result = Result<(), Error>;
223     pub struct Formatter<'a>;
224     pub trait Debug {
225         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
226     }
227 }
228
229 #[derive(Debu$0g)]
230 struct Foo {
231     bar: String,
232 }
233 ",
234             "
235 mod fmt {
236     pub struct Error;
237     pub type Result = Result<(), Error>;
238     pub struct Formatter<'a>;
239     pub trait Debug {
240         fn fmt(&self, f: &mut Formatter<'_>) -> Result;
241     }
242 }
243
244 struct Foo {
245     bar: String,
246 }
247
248 impl fmt::Debug for Foo {
249     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250         ${0:todo!()}
251     }
252 }
253 ",
254         )
255     }
256     #[test]
257     fn add_custom_impl_all() {
258         check_assist(
259             replace_derive_with_manual_impl,
260             "
261 mod foo {
262     pub trait Bar {
263         type Qux;
264         const Baz: usize = 42;
265         const Fez: usize;
266         fn foo();
267         fn bar() {}
268     }
269 }
270
271 #[derive($0Bar)]
272 struct Foo {
273     bar: String,
274 }
275 ",
276             "
277 mod foo {
278     pub trait Bar {
279         type Qux;
280         const Baz: usize = 42;
281         const Fez: usize;
282         fn foo();
283         fn bar() {}
284     }
285 }
286
287 struct Foo {
288     bar: String,
289 }
290
291 impl foo::Bar for Foo {
292     $0type Qux;
293
294     const Baz: usize = 42;
295
296     const Fez: usize;
297
298     fn foo() {
299         todo!()
300     }
301 }
302 ",
303         )
304     }
305     #[test]
306     fn add_custom_impl_for_unique_input() {
307         check_assist(
308             replace_derive_with_manual_impl,
309             "
310 #[derive(Debu$0g)]
311 struct Foo {
312     bar: String,
313 }
314             ",
315             "
316 struct Foo {
317     bar: String,
318 }
319
320 impl Debug for Foo {
321     $0
322 }
323             ",
324         )
325     }
326
327     #[test]
328     fn add_custom_impl_for_with_visibility_modifier() {
329         check_assist(
330             replace_derive_with_manual_impl,
331             "
332 #[derive(Debug$0)]
333 pub struct Foo {
334     bar: String,
335 }
336             ",
337             "
338 pub struct Foo {
339     bar: String,
340 }
341
342 impl Debug for Foo {
343     $0
344 }
345             ",
346         )
347     }
348
349     #[test]
350     fn add_custom_impl_when_multiple_inputs() {
351         check_assist(
352             replace_derive_with_manual_impl,
353             "
354 #[derive(Display, Debug$0, Serialize)]
355 struct Foo {}
356             ",
357             "
358 #[derive(Display, Serialize)]
359 struct Foo {}
360
361 impl Debug for Foo {
362     $0
363 }
364             ",
365         )
366     }
367
368     #[test]
369     fn test_ignore_derive_macro_without_input() {
370         check_assist_not_applicable(
371             replace_derive_with_manual_impl,
372             "
373 #[derive($0)]
374 struct Foo {}
375             ",
376         )
377     }
378
379     #[test]
380     fn test_ignore_if_cursor_on_param() {
381         check_assist_not_applicable(
382             replace_derive_with_manual_impl,
383             "
384 #[derive$0(Debug)]
385 struct Foo {}
386             ",
387         );
388
389         check_assist_not_applicable(
390             replace_derive_with_manual_impl,
391             "
392 #[derive(Debug)$0]
393 struct Foo {}
394             ",
395         )
396     }
397
398     #[test]
399     fn test_ignore_if_not_derive() {
400         check_assist_not_applicable(
401             replace_derive_with_manual_impl,
402             "
403 #[allow(non_camel_$0case_types)]
404 struct Foo {}
405             ",
406         )
407     }
408 }