]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
Merge #9338
[rust.git] / crates / ide_assists / src / handlers / extract_struct_from_enum_variant.rs
1 use std::iter;
2
3 use either::Either;
4 use hir::{Module, ModuleDef, Name, Variant};
5 use ide_db::{
6     defs::Definition,
7     helpers::{
8         insert_use::{insert_use, ImportScope, InsertUseConfig},
9         mod_path_to_ast,
10     },
11     search::FileReference,
12     RootDatabase,
13 };
14 use itertools::Itertools;
15 use rustc_hash::FxHashSet;
16 use syntax::{
17     ast::{
18         self, make, AstNode, AttrsOwner, GenericParamsOwner, NameOwner, TypeBoundsOwner,
19         VisibilityOwner,
20     },
21     match_ast,
22     ted::{self, Position},
23     SyntaxNode, T,
24 };
25
26 use crate::{assist_context::AssistBuilder, AssistContext, AssistId, AssistKind, Assists};
27
28 // Assist: extract_struct_from_enum_variant
29 //
30 // Extracts a struct from enum variant.
31 //
32 // ```
33 // enum A { $0One(u32, u32) }
34 // ```
35 // ->
36 // ```
37 // struct One(pub u32, pub u32);
38 //
39 // enum A { One(One) }
40 // ```
41 pub(crate) fn extract_struct_from_enum_variant(
42     acc: &mut Assists,
43     ctx: &AssistContext,
44 ) -> Option<()> {
45     let variant = ctx.find_node_at_offset::<ast::Variant>()?;
46     let field_list = extract_field_list_if_applicable(&variant)?;
47
48     let variant_name = variant.name()?;
49     let variant_hir = ctx.sema.to_def(&variant)?;
50     if existing_definition(ctx.db(), &variant_name, &variant_hir) {
51         cov_mark::hit!(test_extract_enum_not_applicable_if_struct_exists);
52         return None;
53     }
54
55     let enum_ast = variant.parent_enum();
56     let enum_hir = ctx.sema.to_def(&enum_ast)?;
57     let target = variant.syntax().text_range();
58     acc.add(
59         AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite),
60         "Extract struct from enum variant",
61         target,
62         |builder| {
63             let variant_hir_name = variant_hir.name(ctx.db());
64             let enum_module_def = ModuleDef::from(enum_hir);
65             let usages =
66                 Definition::ModuleDef(ModuleDef::Variant(variant_hir)).usages(&ctx.sema).all();
67
68             let mut visited_modules_set = FxHashSet::default();
69             let current_module = enum_hir.module(ctx.db());
70             visited_modules_set.insert(current_module);
71             // record file references of the file the def resides in, we only want to swap to the edited file in the builder once
72             let mut def_file_references = None;
73             for (file_id, references) in usages {
74                 if file_id == ctx.frange.file_id {
75                     def_file_references = Some(references);
76                     continue;
77                 }
78                 builder.edit_file(file_id);
79                 let processed = process_references(
80                     ctx,
81                     builder,
82                     &mut visited_modules_set,
83                     &enum_module_def,
84                     &variant_hir_name,
85                     references,
86                 );
87                 processed.into_iter().for_each(|(path, node, import)| {
88                     apply_references(ctx.config.insert_use, path, node, import)
89                 });
90             }
91             builder.edit_file(ctx.frange.file_id);
92             let variant = builder.make_mut(variant.clone());
93             if let Some(references) = def_file_references {
94                 let processed = process_references(
95                     ctx,
96                     builder,
97                     &mut visited_modules_set,
98                     &enum_module_def,
99                     &variant_hir_name,
100                     references,
101                 );
102                 processed.into_iter().for_each(|(path, node, import)| {
103                     apply_references(ctx.config.insert_use, path, node, import)
104                 });
105             }
106
107             let def = create_struct_def(variant_name.clone(), &field_list, &enum_ast);
108             let start_offset = &variant.parent_enum().syntax().clone();
109             ted::insert_raw(ted::Position::before(start_offset), def.syntax());
110             ted::insert_raw(ted::Position::before(start_offset), &make::tokens::blank_line());
111
112             update_variant(&variant, enum_ast.generic_param_list());
113         },
114     )
115 }
116
117 fn extract_field_list_if_applicable(
118     variant: &ast::Variant,
119 ) -> Option<Either<ast::RecordFieldList, ast::TupleFieldList>> {
120     match variant.kind() {
121         ast::StructKind::Record(field_list) if field_list.fields().next().is_some() => {
122             Some(Either::Left(field_list))
123         }
124         ast::StructKind::Tuple(field_list) if field_list.fields().count() > 1 => {
125             Some(Either::Right(field_list))
126         }
127         _ => None,
128     }
129 }
130
131 fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Variant) -> bool {
132     variant
133         .parent_enum(db)
134         .module(db)
135         .scope(db, None)
136         .into_iter()
137         .filter(|(_, def)| match def {
138             // only check type-namespace
139             hir::ScopeDef::ModuleDef(def) => matches!(
140                 def,
141                 ModuleDef::Module(_)
142                     | ModuleDef::Adt(_)
143                     | ModuleDef::Variant(_)
144                     | ModuleDef::Trait(_)
145                     | ModuleDef::TypeAlias(_)
146                     | ModuleDef::BuiltinType(_)
147             ),
148             _ => false,
149         })
150         .any(|(name, _)| name.to_string() == variant_name.to_string())
151 }
152
153 fn create_struct_def(
154     variant_name: ast::Name,
155     field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
156     enum_: &ast::Enum,
157 ) -> ast::Struct {
158     let pub_vis = make::visibility_pub();
159
160     let insert_pub = |node: &'_ SyntaxNode| {
161         let pub_vis = pub_vis.clone_for_update();
162         ted::insert(ted::Position::before(node), pub_vis.syntax());
163     };
164
165     // for fields without any existing visibility, use pub visibility
166     let field_list = match field_list {
167         Either::Left(field_list) => {
168             let field_list = field_list.clone_for_update();
169
170             field_list
171                 .fields()
172                 .filter(|field| field.visibility().is_none())
173                 .filter_map(|field| field.name())
174                 .for_each(|it| insert_pub(it.syntax()));
175
176             field_list.into()
177         }
178         Either::Right(field_list) => {
179             let field_list = field_list.clone_for_update();
180
181             field_list
182                 .fields()
183                 .filter(|field| field.visibility().is_none())
184                 .filter_map(|field| field.ty())
185                 .for_each(|it| insert_pub(it.syntax()));
186
187             field_list.into()
188         }
189     };
190
191     // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
192     let strukt =
193         make::struct_(enum_.visibility(), variant_name, enum_.generic_param_list(), field_list)
194             .clone_for_update();
195
196     // copy attributes
197     ted::insert_all(
198         Position::first_child_of(strukt.syntax()),
199         enum_.attrs().map(|it| it.syntax().clone_for_update().into()).collect(),
200     );
201     strukt
202 }
203
204 fn update_variant(variant: &ast::Variant, generic: Option<ast::GenericParamList>) -> Option<()> {
205     let name = variant.name()?;
206     let ty = match generic {
207         // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
208         Some(gpl) => {
209             let gpl = gpl.clone_for_update();
210             gpl.generic_params().for_each(|gp| {
211                 match gp {
212                     ast::GenericParam::LifetimeParam(it) => it.type_bound_list(),
213                     ast::GenericParam::TypeParam(it) => it.type_bound_list(),
214                     ast::GenericParam::ConstParam(_) => return,
215                 }
216                 .map(|it| it.remove());
217             });
218             make::ty(&format!("{}<{}>", name.text(), gpl.generic_params().join(", ")))
219         }
220         None => make::ty(&name.text()),
221     };
222     let tuple_field = make::tuple_field(None, ty);
223     let replacement = make::variant(
224         name,
225         Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))),
226     )
227     .clone_for_update();
228     ted::replace(variant.syntax(), replacement.syntax());
229     Some(())
230 }
231
232 fn apply_references(
233     insert_use_cfg: InsertUseConfig,
234     segment: ast::PathSegment,
235     node: SyntaxNode,
236     import: Option<(ImportScope, hir::ModPath)>,
237 ) {
238     if let Some((scope, path)) = import {
239         insert_use(&scope, mod_path_to_ast(&path), &insert_use_cfg);
240     }
241     // deep clone to prevent cycle
242     let path = make::path_from_segments(iter::once(segment.clone_subtree()), false);
243     ted::insert_raw(ted::Position::before(segment.syntax()), path.clone_for_update().syntax());
244     ted::insert_raw(ted::Position::before(segment.syntax()), make::token(T!['(']));
245     ted::insert_raw(ted::Position::after(&node), make::token(T![')']));
246 }
247
248 fn process_references(
249     ctx: &AssistContext,
250     builder: &mut AssistBuilder,
251     visited_modules: &mut FxHashSet<Module>,
252     enum_module_def: &ModuleDef,
253     variant_hir_name: &Name,
254     refs: Vec<FileReference>,
255 ) -> Vec<(ast::PathSegment, SyntaxNode, Option<(ImportScope, hir::ModPath)>)> {
256     // we have to recollect here eagerly as we are about to edit the tree we need to calculate the changes
257     // and corresponding nodes up front
258     refs.into_iter()
259         .flat_map(|reference| {
260             let (segment, scope_node, module) = reference_to_node(&ctx.sema, reference)?;
261             let segment = builder.make_mut(segment);
262             let scope_node = builder.make_syntax_mut(scope_node);
263             if !visited_modules.contains(&module) {
264                 let mod_path = module.find_use_path_prefixed(
265                     ctx.sema.db,
266                     *enum_module_def,
267                     ctx.config.insert_use.prefix_kind,
268                 );
269                 if let Some(mut mod_path) = mod_path {
270                     mod_path.pop_segment();
271                     mod_path.push_segment(variant_hir_name.clone());
272                     let scope = ImportScope::find_insert_use_container(&scope_node)?;
273                     visited_modules.insert(module);
274                     return Some((segment, scope_node, Some((scope, mod_path))));
275                 }
276             }
277             Some((segment, scope_node, None))
278         })
279         .collect()
280 }
281
282 fn reference_to_node(
283     sema: &hir::Semantics<RootDatabase>,
284     reference: FileReference,
285 ) -> Option<(ast::PathSegment, SyntaxNode, hir::Module)> {
286     let segment =
287         reference.name.as_name_ref()?.syntax().parent().and_then(ast::PathSegment::cast)?;
288     let parent = segment.parent_path().syntax().parent()?;
289     let expr_or_pat = match_ast! {
290         match parent {
291             ast::PathExpr(_it) => parent.parent()?,
292             ast::RecordExpr(_it) => parent,
293             ast::TupleStructPat(_it) => parent,
294             ast::RecordPat(_it) => parent,
295             _ => return None,
296         }
297     };
298     let module = sema.scope(&expr_or_pat).module()?;
299     Some((segment, expr_or_pat, module))
300 }
301
302 #[cfg(test)]
303 mod tests {
304     use crate::tests::{check_assist, check_assist_not_applicable};
305
306     use super::*;
307
308     #[test]
309     fn test_extract_struct_several_fields_tuple() {
310         check_assist(
311             extract_struct_from_enum_variant,
312             "enum A { $0One(u32, u32) }",
313             r#"struct One(pub u32, pub u32);
314
315 enum A { One(One) }"#,
316         );
317     }
318
319     #[test]
320     fn test_extract_struct_several_fields_named() {
321         check_assist(
322             extract_struct_from_enum_variant,
323             "enum A { $0One { foo: u32, bar: u32 } }",
324             r#"struct One{ pub foo: u32, pub bar: u32 }
325
326 enum A { One(One) }"#,
327         );
328     }
329
330     #[test]
331     fn test_extract_struct_one_field_named() {
332         check_assist(
333             extract_struct_from_enum_variant,
334             "enum A { $0One { foo: u32 } }",
335             r#"struct One{ pub foo: u32 }
336
337 enum A { One(One) }"#,
338         );
339     }
340
341     #[test]
342     fn test_extract_struct_carries_over_generics() {
343         check_assist(
344             extract_struct_from_enum_variant,
345             r"enum En<T> { Var { a: T$0 } }",
346             r#"struct Var<T>{ pub a: T }
347
348 enum En<T> { Var(Var<T>) }"#,
349         );
350     }
351
352     #[test]
353     fn test_extract_struct_carries_over_attributes() {
354         check_assist(
355             extract_struct_from_enum_variant,
356             r#"#[derive(Debug)]
357 #[derive(Clone)]
358 enum Enum { Variant{ field: u32$0 } }"#,
359             r#"#[derive(Debug)]#[derive(Clone)] struct Variant{ pub field: u32 }
360
361 #[derive(Debug)]
362 #[derive(Clone)]
363 enum Enum { Variant(Variant) }"#,
364         );
365     }
366
367     #[test]
368     fn test_extract_struct_keep_comments_and_attrs_one_field_named() {
369         check_assist(
370             extract_struct_from_enum_variant,
371             r#"
372 enum A {
373     $0One {
374         // leading comment
375         /// doc comment
376         #[an_attr]
377         foo: u32
378         // trailing comment
379     }
380 }"#,
381             r#"
382 struct One{
383         // leading comment
384         /// doc comment
385         #[an_attr]
386         pub foo: u32
387         // trailing comment
388     }
389
390 enum A {
391     One(One)
392 }"#,
393         );
394     }
395
396     #[test]
397     fn test_extract_struct_keep_comments_and_attrs_several_fields_named() {
398         check_assist(
399             extract_struct_from_enum_variant,
400             r#"
401 enum A {
402     $0One {
403         // comment
404         /// doc
405         #[attr]
406         foo: u32,
407         // comment
408         #[attr]
409         /// doc
410         bar: u32
411     }
412 }"#,
413             r#"
414 struct One{
415         // comment
416         /// doc
417         #[attr]
418         pub foo: u32,
419         // comment
420         #[attr]
421         /// doc
422         pub bar: u32
423     }
424
425 enum A {
426     One(One)
427 }"#,
428         );
429     }
430
431     #[test]
432     fn test_extract_struct_keep_comments_and_attrs_several_fields_tuple() {
433         check_assist(
434             extract_struct_from_enum_variant,
435             "enum A { $0One(/* comment */ #[attr] u32, /* another */ u32 /* tail */) }",
436             r#"
437 struct One(/* comment */ #[attr] pub u32, /* another */ pub u32 /* tail */);
438
439 enum A { One(One) }"#,
440         );
441     }
442
443     #[test]
444     fn test_extract_struct_keep_existing_visibility_named() {
445         check_assist(
446             extract_struct_from_enum_variant,
447             "enum A { $0One{ pub a: u32, pub(crate) b: u32, pub(super) c: u32, d: u32 } }",
448             r#"
449 struct One{ pub a: u32, pub(crate) b: u32, pub(super) c: u32, pub d: u32 }
450
451 enum A { One(One) }"#,
452         );
453     }
454
455     #[test]
456     fn test_extract_struct_keep_existing_visibility_tuple() {
457         check_assist(
458             extract_struct_from_enum_variant,
459             "enum A { $0One(pub u32, pub(crate) u32, pub(super) u32, u32) }",
460             r#"
461 struct One(pub u32, pub(crate) u32, pub(super) u32, pub u32);
462
463 enum A { One(One) }"#,
464         );
465     }
466
467     #[test]
468     fn test_extract_enum_variant_name_value_namespace() {
469         check_assist(
470             extract_struct_from_enum_variant,
471             r#"const One: () = ();
472 enum A { $0One(u32, u32) }"#,
473             r#"const One: () = ();
474 struct One(pub u32, pub u32);
475
476 enum A { One(One) }"#,
477         );
478     }
479
480     #[test]
481     fn test_extract_struct_pub_visibility() {
482         check_assist(
483             extract_struct_from_enum_variant,
484             "pub enum A { $0One(u32, u32) }",
485             r#"pub struct One(pub u32, pub u32);
486
487 pub enum A { One(One) }"#,
488         );
489     }
490
491     #[test]
492     fn test_extract_struct_with_complex_imports() {
493         check_assist(
494             extract_struct_from_enum_variant,
495             r#"mod my_mod {
496     fn another_fn() {
497         let m = my_other_mod::MyEnum::MyField(1, 1);
498     }
499
500     pub mod my_other_mod {
501         fn another_fn() {
502             let m = MyEnum::MyField(1, 1);
503         }
504
505         pub enum MyEnum {
506             $0MyField(u8, u8),
507         }
508     }
509 }
510
511 fn another_fn() {
512     let m = my_mod::my_other_mod::MyEnum::MyField(1, 1);
513 }"#,
514             r#"use my_mod::my_other_mod::MyField;
515
516 mod my_mod {
517     use self::my_other_mod::MyField;
518
519     fn another_fn() {
520         let m = my_other_mod::MyEnum::MyField(MyField(1, 1));
521     }
522
523     pub mod my_other_mod {
524         fn another_fn() {
525             let m = MyEnum::MyField(MyField(1, 1));
526         }
527
528         pub struct MyField(pub u8, pub u8);
529
530 pub enum MyEnum {
531             MyField(MyField),
532         }
533     }
534 }
535
536 fn another_fn() {
537     let m = my_mod::my_other_mod::MyEnum::MyField(MyField(1, 1));
538 }"#,
539         );
540     }
541
542     #[test]
543     fn extract_record_fix_references() {
544         check_assist(
545             extract_struct_from_enum_variant,
546             r#"
547 enum E {
548     $0V { i: i32, j: i32 }
549 }
550
551 fn f() {
552     let E::V { i, j } = E::V { i: 9, j: 2 };
553 }
554 "#,
555             r#"
556 struct V{ pub i: i32, pub j: i32 }
557
558 enum E {
559     V(V)
560 }
561
562 fn f() {
563     let E::V(V { i, j }) = E::V(V { i: 9, j: 2 });
564 }
565 "#,
566         )
567     }
568
569     #[test]
570     fn extract_record_fix_references2() {
571         check_assist(
572             extract_struct_from_enum_variant,
573             r#"
574 enum E {
575     $0V(i32, i32)
576 }
577
578 fn f() {
579     let E::V(i, j) = E::V(9, 2);
580 }
581 "#,
582             r#"
583 struct V(pub i32, pub i32);
584
585 enum E {
586     V(V)
587 }
588
589 fn f() {
590     let E::V(V(i, j)) = E::V(V(9, 2));
591 }
592 "#,
593         )
594     }
595
596     #[test]
597     fn test_several_files() {
598         check_assist(
599             extract_struct_from_enum_variant,
600             r#"
601 //- /main.rs
602 enum E {
603     $0V(i32, i32)
604 }
605 mod foo;
606
607 //- /foo.rs
608 use crate::E;
609 fn f() {
610     let e = E::V(9, 2);
611 }
612 "#,
613             r#"
614 //- /main.rs
615 struct V(pub i32, pub i32);
616
617 enum E {
618     V(V)
619 }
620 mod foo;
621
622 //- /foo.rs
623 use crate::{E, V};
624 fn f() {
625     let e = E::V(V(9, 2));
626 }
627 "#,
628         )
629     }
630
631     #[test]
632     fn test_several_files_record() {
633         check_assist(
634             extract_struct_from_enum_variant,
635             r#"
636 //- /main.rs
637 enum E {
638     $0V { i: i32, j: i32 }
639 }
640 mod foo;
641
642 //- /foo.rs
643 use crate::E;
644 fn f() {
645     let e = E::V { i: 9, j: 2 };
646 }
647 "#,
648             r#"
649 //- /main.rs
650 struct V{ pub i: i32, pub j: i32 }
651
652 enum E {
653     V(V)
654 }
655 mod foo;
656
657 //- /foo.rs
658 use crate::{E, V};
659 fn f() {
660     let e = E::V(V { i: 9, j: 2 });
661 }
662 "#,
663         )
664     }
665
666     #[test]
667     fn test_extract_struct_record_nested_call_exp() {
668         check_assist(
669             extract_struct_from_enum_variant,
670             r#"
671 enum A { $0One { a: u32, b: u32 } }
672
673 struct B(A);
674
675 fn foo() {
676     let _ = B(A::One { a: 1, b: 2 });
677 }
678 "#,
679             r#"
680 struct One{ pub a: u32, pub b: u32 }
681
682 enum A { One(One) }
683
684 struct B(A);
685
686 fn foo() {
687     let _ = B(A::One(One { a: 1, b: 2 }));
688 }
689 "#,
690         );
691     }
692
693     #[test]
694     fn test_extract_enum_not_applicable_for_element_with_no_fields() {
695         check_assist_not_applicable(extract_struct_from_enum_variant, r#"enum A { $0One }"#);
696     }
697
698     #[test]
699     fn test_extract_enum_not_applicable_if_struct_exists() {
700         cov_mark::check!(test_extract_enum_not_applicable_if_struct_exists);
701         check_assist_not_applicable(
702             extract_struct_from_enum_variant,
703             r#"
704 struct One;
705 enum A { $0One(u8, u32) }
706 "#,
707         );
708     }
709
710     #[test]
711     fn test_extract_not_applicable_one_field() {
712         check_assist_not_applicable(extract_struct_from_enum_variant, r"enum A { $0One(u32) }");
713     }
714
715     #[test]
716     fn test_extract_not_applicable_no_field_tuple() {
717         check_assist_not_applicable(extract_struct_from_enum_variant, r"enum A { $0None() }");
718     }
719
720     #[test]
721     fn test_extract_not_applicable_no_field_named() {
722         check_assist_not_applicable(extract_struct_from_enum_variant, r"enum A { $0None {} }");
723     }
724 }