]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
Merge #8207
[rust.git] / crates / ide_assists / src / handlers / extract_struct_from_enum_variant.rs
1 use std::iter;
2
3 use either::Either;
4 use hir::{Module, ModuleDef, Name, Variant};
5 use ide_db::{
6     defs::Definition,
7     helpers::{
8         insert_use::{insert_use, ImportScope},
9         mod_path_to_ast,
10     },
11     search::FileReference,
12     RootDatabase,
13 };
14 use rustc_hash::FxHashSet;
15 use syntax::{
16     algo::{find_node_at_offset, SyntaxRewriter},
17     ast::{self, edit::IndentLevel, make, AstNode, NameOwner, VisibilityOwner},
18     SourceFile, SyntaxElement, SyntaxNode, T,
19 };
20
21 use crate::{AssistContext, AssistId, AssistKind, Assists};
22
23 // Assist: extract_struct_from_enum_variant
24 //
25 // Extracts a struct from enum variant.
26 //
27 // ```
28 // enum A { $0One(u32, u32) }
29 // ```
30 // ->
31 // ```
32 // struct One(pub u32, pub u32);
33 //
34 // enum A { One(One) }
35 // ```
36 pub(crate) fn extract_struct_from_enum_variant(
37     acc: &mut Assists,
38     ctx: &AssistContext,
39 ) -> Option<()> {
40     let variant = ctx.find_node_at_offset::<ast::Variant>()?;
41     let field_list = extract_field_list_if_applicable(&variant)?;
42
43     let variant_name = variant.name()?;
44     let variant_hir = ctx.sema.to_def(&variant)?;
45     if existing_definition(ctx.db(), &variant_name, &variant_hir) {
46         return None;
47     }
48
49     let enum_ast = variant.parent_enum();
50     let enum_hir = ctx.sema.to_def(&enum_ast)?;
51     let target = variant.syntax().text_range();
52     acc.add(
53         AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite),
54         "Extract struct from enum variant",
55         target,
56         |builder| {
57             let variant_hir_name = variant_hir.name(ctx.db());
58             let enum_module_def = ModuleDef::from(enum_hir);
59             let usages =
60                 Definition::ModuleDef(ModuleDef::Variant(variant_hir)).usages(&ctx.sema).all();
61
62             let mut visited_modules_set = FxHashSet::default();
63             let current_module = enum_hir.module(ctx.db());
64             visited_modules_set.insert(current_module);
65             let mut def_rewriter = None;
66             for (file_id, references) in usages {
67                 let mut rewriter = SyntaxRewriter::default();
68                 let source_file = ctx.sema.parse(file_id);
69                 for reference in references {
70                     update_reference(
71                         ctx,
72                         &mut rewriter,
73                         reference,
74                         &source_file,
75                         &enum_module_def,
76                         &variant_hir_name,
77                         &mut visited_modules_set,
78                     );
79                 }
80                 if file_id == ctx.frange.file_id {
81                     def_rewriter = Some(rewriter);
82                     continue;
83                 }
84                 builder.edit_file(file_id);
85                 builder.rewrite(rewriter);
86             }
87             let mut rewriter = def_rewriter.unwrap_or_default();
88             update_variant(&mut rewriter, &variant);
89             extract_struct_def(
90                 &mut rewriter,
91                 &enum_ast,
92                 variant_name.clone(),
93                 &field_list,
94                 &variant.parent_enum().syntax().clone().into(),
95                 enum_ast.visibility(),
96             );
97             builder.edit_file(ctx.frange.file_id);
98             builder.rewrite(rewriter);
99         },
100     )
101 }
102
103 fn extract_field_list_if_applicable(
104     variant: &ast::Variant,
105 ) -> Option<Either<ast::RecordFieldList, ast::TupleFieldList>> {
106     match variant.kind() {
107         ast::StructKind::Record(field_list) if field_list.fields().next().is_some() => {
108             Some(Either::Left(field_list))
109         }
110         ast::StructKind::Tuple(field_list) if field_list.fields().count() > 1 => {
111             Some(Either::Right(field_list))
112         }
113         _ => None,
114     }
115 }
116
117 fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Variant) -> bool {
118     variant
119         .parent_enum(db)
120         .module(db)
121         .scope(db, None)
122         .into_iter()
123         .filter(|(_, def)| match def {
124             // only check type-namespace
125             hir::ScopeDef::ModuleDef(def) => matches!(
126                 def,
127                 ModuleDef::Module(_)
128                     | ModuleDef::Adt(_)
129                     | ModuleDef::Variant(_)
130                     | ModuleDef::Trait(_)
131                     | ModuleDef::TypeAlias(_)
132                     | ModuleDef::BuiltinType(_)
133             ),
134             _ => false,
135         })
136         .any(|(name, _)| name.to_string() == variant_name.to_string())
137 }
138
139 fn insert_import(
140     ctx: &AssistContext,
141     rewriter: &mut SyntaxRewriter,
142     scope_node: &SyntaxNode,
143     module: &Module,
144     enum_module_def: &ModuleDef,
145     variant_hir_name: &Name,
146 ) -> Option<()> {
147     let db = ctx.db();
148     let mod_path =
149         module.find_use_path_prefixed(db, *enum_module_def, ctx.config.insert_use.prefix_kind);
150     if let Some(mut mod_path) = mod_path {
151         mod_path.pop_segment();
152         mod_path.push_segment(variant_hir_name.clone());
153         let scope = ImportScope::find_insert_use_container(scope_node, &ctx.sema)?;
154         *rewriter += insert_use(&scope, mod_path_to_ast(&mod_path), ctx.config.insert_use);
155     }
156     Some(())
157 }
158
159 fn extract_struct_def(
160     rewriter: &mut SyntaxRewriter,
161     enum_: &ast::Enum,
162     variant_name: ast::Name,
163     field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
164     start_offset: &SyntaxElement,
165     visibility: Option<ast::Visibility>,
166 ) -> Option<()> {
167     let pub_vis = Some(make::visibility_pub());
168     let field_list = match field_list {
169         Either::Left(field_list) => {
170             make::record_field_list(field_list.fields().flat_map(|field| {
171                 Some(make::record_field(pub_vis.clone(), field.name()?, field.ty()?))
172             }))
173             .into()
174         }
175         Either::Right(field_list) => make::tuple_field_list(
176             field_list
177                 .fields()
178                 .flat_map(|field| Some(make::tuple_field(pub_vis.clone(), field.ty()?))),
179         )
180         .into(),
181     };
182
183     rewriter.insert_before(
184         start_offset,
185         make::struct_(visibility, variant_name, None, field_list).syntax(),
186     );
187     rewriter.insert_before(start_offset, &make::tokens::blank_line());
188
189     if let indent_level @ 1..=usize::MAX = IndentLevel::from_node(enum_.syntax()).0 as usize {
190         rewriter
191             .insert_before(start_offset, &make::tokens::whitespace(&" ".repeat(4 * indent_level)));
192     }
193     Some(())
194 }
195
196 fn update_variant(rewriter: &mut SyntaxRewriter, variant: &ast::Variant) -> Option<()> {
197     let name = variant.name()?;
198     let tuple_field = make::tuple_field(None, make::ty(&name.text()));
199     let replacement = make::variant(
200         name,
201         Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))),
202     );
203     rewriter.replace(variant.syntax(), replacement.syntax());
204     Some(())
205 }
206
207 fn update_reference(
208     ctx: &AssistContext,
209     rewriter: &mut SyntaxRewriter,
210     reference: FileReference,
211     source_file: &SourceFile,
212     enum_module_def: &ModuleDef,
213     variant_hir_name: &Name,
214     visited_modules_set: &mut FxHashSet<Module>,
215 ) -> Option<()> {
216     let offset = reference.range.start();
217     let (segment, expr) = if let Some(path_expr) =
218         find_node_at_offset::<ast::PathExpr>(source_file.syntax(), offset)
219     {
220         // tuple variant
221         (path_expr.path()?.segment()?, path_expr.syntax().parent()?)
222     } else if let Some(record_expr) =
223         find_node_at_offset::<ast::RecordExpr>(source_file.syntax(), offset)
224     {
225         // record variant
226         (record_expr.path()?.segment()?, record_expr.syntax().clone())
227     } else {
228         return None;
229     };
230
231     let module = ctx.sema.scope(&expr).module()?;
232     if !visited_modules_set.contains(&module) {
233         if insert_import(ctx, rewriter, &expr, &module, enum_module_def, variant_hir_name).is_some()
234         {
235             visited_modules_set.insert(module);
236         }
237     }
238     rewriter.insert_after(segment.syntax(), &make::token(T!['(']));
239     rewriter.insert_after(segment.syntax(), segment.syntax());
240     rewriter.insert_after(&expr, &make::token(T![')']));
241     Some(())
242 }
243
244 #[cfg(test)]
245 mod tests {
246     use ide_db::helpers::FamousDefs;
247
248     use crate::tests::{check_assist, check_assist_not_applicable};
249
250     use super::*;
251
252     #[test]
253     fn test_extract_struct_several_fields_tuple() {
254         check_assist(
255             extract_struct_from_enum_variant,
256             "enum A { $0One(u32, u32) }",
257             r#"struct One(pub u32, pub u32);
258
259 enum A { One(One) }"#,
260         );
261     }
262
263     #[test]
264     fn test_extract_struct_several_fields_named() {
265         check_assist(
266             extract_struct_from_enum_variant,
267             "enum A { $0One { foo: u32, bar: u32 } }",
268             r#"struct One{ pub foo: u32, pub bar: u32 }
269
270 enum A { One(One) }"#,
271         );
272     }
273
274     #[test]
275     fn test_extract_struct_one_field_named() {
276         check_assist(
277             extract_struct_from_enum_variant,
278             "enum A { $0One { foo: u32 } }",
279             r#"struct One{ pub foo: u32 }
280
281 enum A { One(One) }"#,
282         );
283     }
284
285     #[test]
286     fn test_extract_enum_variant_name_value_namespace() {
287         check_assist(
288             extract_struct_from_enum_variant,
289             r#"const One: () = ();
290 enum A { $0One(u32, u32) }"#,
291             r#"const One: () = ();
292 struct One(pub u32, pub u32);
293
294 enum A { One(One) }"#,
295         );
296     }
297
298     #[test]
299     fn test_extract_struct_pub_visibility() {
300         check_assist(
301             extract_struct_from_enum_variant,
302             "pub enum A { $0One(u32, u32) }",
303             r#"pub struct One(pub u32, pub u32);
304
305 pub enum A { One(One) }"#,
306         );
307     }
308
309     #[test]
310     fn test_extract_struct_with_complex_imports() {
311         check_assist(
312             extract_struct_from_enum_variant,
313             r#"mod my_mod {
314     fn another_fn() {
315         let m = my_other_mod::MyEnum::MyField(1, 1);
316     }
317
318     pub mod my_other_mod {
319         fn another_fn() {
320             let m = MyEnum::MyField(1, 1);
321         }
322
323         pub enum MyEnum {
324             $0MyField(u8, u8),
325         }
326     }
327 }
328
329 fn another_fn() {
330     let m = my_mod::my_other_mod::MyEnum::MyField(1, 1);
331 }"#,
332             r#"use my_mod::my_other_mod::MyField;
333
334 mod my_mod {
335     use self::my_other_mod::MyField;
336
337     fn another_fn() {
338         let m = my_other_mod::MyEnum::MyField(MyField(1, 1));
339     }
340
341     pub mod my_other_mod {
342         fn another_fn() {
343             let m = MyEnum::MyField(MyField(1, 1));
344         }
345
346         pub struct MyField(pub u8, pub u8);
347
348         pub enum MyEnum {
349             MyField(MyField),
350         }
351     }
352 }
353
354 fn another_fn() {
355     let m = my_mod::my_other_mod::MyEnum::MyField(MyField(1, 1));
356 }"#,
357         );
358     }
359
360     #[test]
361     fn extract_record_fix_references() {
362         check_assist(
363             extract_struct_from_enum_variant,
364             r#"
365 enum E {
366     $0V { i: i32, j: i32 }
367 }
368
369 fn f() {
370     let e = E::V { i: 9, j: 2 };
371 }
372 "#,
373             r#"
374 struct V{ pub i: i32, pub j: i32 }
375
376 enum E {
377     V(V)
378 }
379
380 fn f() {
381     let e = E::V(V { i: 9, j: 2 });
382 }
383 "#,
384         )
385     }
386
387     #[test]
388     fn test_several_files() {
389         check_assist(
390             extract_struct_from_enum_variant,
391             r#"
392 //- /main.rs
393 enum E {
394     $0V(i32, i32)
395 }
396 mod foo;
397
398 //- /foo.rs
399 use crate::E;
400 fn f() {
401     let e = E::V(9, 2);
402 }
403 "#,
404             r#"
405 //- /main.rs
406 struct V(pub i32, pub i32);
407
408 enum E {
409     V(V)
410 }
411 mod foo;
412
413 //- /foo.rs
414 use crate::{E, V};
415 fn f() {
416     let e = E::V(V(9, 2));
417 }
418 "#,
419         )
420     }
421
422     #[test]
423     fn test_several_files_record() {
424         check_assist(
425             extract_struct_from_enum_variant,
426             r#"
427 //- /main.rs
428 enum E {
429     $0V { i: i32, j: i32 }
430 }
431 mod foo;
432
433 //- /foo.rs
434 use crate::E;
435 fn f() {
436     let e = E::V { i: 9, j: 2 };
437 }
438 "#,
439             r#"
440 //- /main.rs
441 struct V{ pub i: i32, pub j: i32 }
442
443 enum E {
444     V(V)
445 }
446 mod foo;
447
448 //- /foo.rs
449 use crate::{E, V};
450 fn f() {
451     let e = E::V(V { i: 9, j: 2 });
452 }
453 "#,
454         )
455     }
456
457     #[test]
458     fn test_extract_struct_record_nested_call_exp() {
459         check_assist(
460             extract_struct_from_enum_variant,
461             r#"
462 enum A { $0One { a: u32, b: u32 } }
463
464 struct B(A);
465
466 fn foo() {
467     let _ = B(A::One { a: 1, b: 2 });
468 }
469 "#,
470             r#"
471 struct One{ pub a: u32, pub b: u32 }
472
473 enum A { One(One) }
474
475 struct B(A);
476
477 fn foo() {
478     let _ = B(A::One(One { a: 1, b: 2 }));
479 }
480 "#,
481         );
482     }
483
484     fn check_not_applicable(ra_fixture: &str) {
485         let fixture =
486             format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE);
487         check_assist_not_applicable(extract_struct_from_enum_variant, &fixture)
488     }
489
490     #[test]
491     fn test_extract_enum_not_applicable_for_element_with_no_fields() {
492         check_not_applicable("enum A { $0One }");
493     }
494
495     #[test]
496     fn test_extract_enum_not_applicable_if_struct_exists() {
497         check_not_applicable(
498             r#"struct One;
499         enum A { $0One(u8, u32) }"#,
500         );
501     }
502
503     #[test]
504     fn test_extract_not_applicable_one_field() {
505         check_not_applicable(r"enum A { $0One(u32) }");
506     }
507
508     #[test]
509     fn test_extract_not_applicable_no_field_tuple() {
510         check_not_applicable(r"enum A { $0None() }");
511     }
512
513     #[test]
514     fn test_extract_not_applicable_no_field_named() {
515         check_not_applicable(r"enum A { $0None {} }");
516     }
517 }