]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
Hygiene is an internal implementation detail of the compiler
[rust.git] / crates / ide_assists / src / handlers / extract_struct_from_enum_variant.rs
1 use std::iter;
2
3 use either::Either;
4 use hir::{Module, ModuleDef, Name, Variant};
5 use ide_db::{
6     defs::Definition,
7     helpers::{
8         insert_use::{insert_use, ImportScope},
9         mod_path_to_ast,
10     },
11     search::FileReference,
12     RootDatabase,
13 };
14 use rustc_hash::FxHashSet;
15 use syntax::{
16     algo::{find_node_at_offset, SyntaxRewriter},
17     ast::{self, edit::IndentLevel, make, AstNode, NameOwner, VisibilityOwner},
18     SourceFile, SyntaxElement, SyntaxNode, T,
19 };
20
21 use crate::{AssistContext, AssistId, AssistKind, Assists};
22
23 // Assist: extract_struct_from_enum_variant
24 //
25 // Extracts a struct from enum variant.
26 //
27 // ```
28 // enum A { $0One(u32, u32) }
29 // ```
30 // ->
31 // ```
32 // struct One(pub u32, pub u32);
33 //
34 // enum A { One(One) }
35 // ```
36 pub(crate) fn extract_struct_from_enum_variant(
37     acc: &mut Assists,
38     ctx: &AssistContext,
39 ) -> Option<()> {
40     let variant = ctx.find_node_at_offset::<ast::Variant>()?;
41     let field_list = extract_field_list_if_applicable(&variant)?;
42
43     let variant_name = variant.name()?;
44     let variant_hir = ctx.sema.to_def(&variant)?;
45     if existing_definition(ctx.db(), &variant_name, &variant_hir) {
46         return None;
47     }
48
49     let enum_ast = variant.parent_enum();
50     let enum_hir = ctx.sema.to_def(&enum_ast)?;
51     let target = variant.syntax().text_range();
52     acc.add(
53         AssistId("extract_struct_from_enum_variant", AssistKind::RefactorRewrite),
54         "Extract struct from enum variant",
55         target,
56         |builder| {
57             let variant_hir_name = variant_hir.name(ctx.db());
58             let enum_module_def = ModuleDef::from(enum_hir);
59             let usages =
60                 Definition::ModuleDef(ModuleDef::Variant(variant_hir)).usages(&ctx.sema).all();
61
62             let mut visited_modules_set = FxHashSet::default();
63             let current_module = enum_hir.module(ctx.db());
64             visited_modules_set.insert(current_module);
65             let mut def_rewriter = None;
66             for (file_id, references) in usages {
67                 let mut rewriter = SyntaxRewriter::default();
68                 let source_file = ctx.sema.parse(file_id);
69                 for reference in references {
70                     update_reference(
71                         ctx,
72                         &mut rewriter,
73                         reference,
74                         &source_file,
75                         &enum_module_def,
76                         &variant_hir_name,
77                         &mut visited_modules_set,
78                     );
79                 }
80                 if file_id == ctx.frange.file_id {
81                     def_rewriter = Some(rewriter);
82                     continue;
83                 }
84                 builder.edit_file(file_id);
85                 builder.rewrite(rewriter);
86             }
87             let mut rewriter = def_rewriter.unwrap_or_default();
88             update_variant(&mut rewriter, &variant);
89             extract_struct_def(
90                 &mut rewriter,
91                 &enum_ast,
92                 variant_name.clone(),
93                 &field_list,
94                 &variant.parent_enum().syntax().clone().into(),
95                 enum_ast.visibility(),
96             );
97             builder.edit_file(ctx.frange.file_id);
98             builder.rewrite(rewriter);
99         },
100     )
101 }
102
103 fn extract_field_list_if_applicable(
104     variant: &ast::Variant,
105 ) -> Option<Either<ast::RecordFieldList, ast::TupleFieldList>> {
106     match variant.kind() {
107         ast::StructKind::Record(field_list) if field_list.fields().next().is_some() => {
108             Some(Either::Left(field_list))
109         }
110         ast::StructKind::Tuple(field_list) if field_list.fields().count() > 1 => {
111             Some(Either::Right(field_list))
112         }
113         _ => None,
114     }
115 }
116
117 fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Variant) -> bool {
118     variant
119         .parent_enum(db)
120         .module(db)
121         .scope(db, None)
122         .into_iter()
123         .filter(|(_, def)| match def {
124             // only check type-namespace
125             hir::ScopeDef::ModuleDef(def) => matches!(
126                 def,
127                 ModuleDef::Module(_)
128                     | ModuleDef::Adt(_)
129                     | ModuleDef::Variant(_)
130                     | ModuleDef::Trait(_)
131                     | ModuleDef::TypeAlias(_)
132                     | ModuleDef::BuiltinType(_)
133             ),
134             _ => false,
135         })
136         .any(|(name, _)| name.to_string() == variant_name.to_string())
137 }
138
139 fn insert_import(
140     ctx: &AssistContext,
141     rewriter: &mut SyntaxRewriter,
142     scope_node: &SyntaxNode,
143     module: &Module,
144     enum_module_def: &ModuleDef,
145     variant_hir_name: &Name,
146 ) -> Option<()> {
147     let db = ctx.db();
148     let mod_path = module.find_use_path_prefixed(
149         db,
150         enum_module_def.clone(),
151         ctx.config.insert_use.prefix_kind,
152     );
153     if let Some(mut mod_path) = mod_path {
154         mod_path.pop_segment();
155         mod_path.push_segment(variant_hir_name.clone());
156         let scope = ImportScope::find_insert_use_container(scope_node, &ctx.sema)?;
157         *rewriter += insert_use(&scope, mod_path_to_ast(&mod_path), ctx.config.insert_use);
158     }
159     Some(())
160 }
161
162 fn extract_struct_def(
163     rewriter: &mut SyntaxRewriter,
164     enum_: &ast::Enum,
165     variant_name: ast::Name,
166     field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
167     start_offset: &SyntaxElement,
168     visibility: Option<ast::Visibility>,
169 ) -> Option<()> {
170     let pub_vis = Some(make::visibility_pub());
171     let field_list = match field_list {
172         Either::Left(field_list) => {
173             make::record_field_list(field_list.fields().flat_map(|field| {
174                 Some(make::record_field(pub_vis.clone(), field.name()?, field.ty()?))
175             }))
176             .into()
177         }
178         Either::Right(field_list) => make::tuple_field_list(
179             field_list
180                 .fields()
181                 .flat_map(|field| Some(make::tuple_field(pub_vis.clone(), field.ty()?))),
182         )
183         .into(),
184     };
185
186     rewriter.insert_before(
187         start_offset,
188         make::struct_(visibility, variant_name, None, field_list).syntax(),
189     );
190     rewriter.insert_before(start_offset, &make::tokens::blank_line());
191
192     if let indent_level @ 1..=usize::MAX = IndentLevel::from_node(enum_.syntax()).0 as usize {
193         rewriter
194             .insert_before(start_offset, &make::tokens::whitespace(&" ".repeat(4 * indent_level)));
195     }
196     Some(())
197 }
198
199 fn update_variant(rewriter: &mut SyntaxRewriter, variant: &ast::Variant) -> Option<()> {
200     let name = variant.name()?;
201     let tuple_field = make::tuple_field(None, make::ty(name.text()));
202     let replacement = make::variant(
203         name,
204         Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))),
205     );
206     rewriter.replace(variant.syntax(), replacement.syntax());
207     Some(())
208 }
209
210 fn update_reference(
211     ctx: &AssistContext,
212     rewriter: &mut SyntaxRewriter,
213     reference: FileReference,
214     source_file: &SourceFile,
215     enum_module_def: &ModuleDef,
216     variant_hir_name: &Name,
217     visited_modules_set: &mut FxHashSet<Module>,
218 ) -> Option<()> {
219     let offset = reference.range.start();
220     let (segment, expr) = if let Some(path_expr) =
221         find_node_at_offset::<ast::PathExpr>(source_file.syntax(), offset)
222     {
223         // tuple variant
224         (path_expr.path()?.segment()?, path_expr.syntax().parent()?)
225     } else if let Some(record_expr) =
226         find_node_at_offset::<ast::RecordExpr>(source_file.syntax(), offset)
227     {
228         // record variant
229         (record_expr.path()?.segment()?, record_expr.syntax().clone())
230     } else {
231         return None;
232     };
233
234     let module = ctx.sema.scope(&expr).module()?;
235     if !visited_modules_set.contains(&module) {
236         if insert_import(ctx, rewriter, &expr, &module, enum_module_def, variant_hir_name).is_some()
237         {
238             visited_modules_set.insert(module);
239         }
240     }
241     rewriter.insert_after(segment.syntax(), &make::token(T!['(']));
242     rewriter.insert_after(segment.syntax(), segment.syntax());
243     rewriter.insert_after(&expr, &make::token(T![')']));
244     Some(())
245 }
246
247 #[cfg(test)]
248 mod tests {
249     use ide_db::helpers::FamousDefs;
250
251     use crate::tests::{check_assist, check_assist_not_applicable};
252
253     use super::*;
254
255     #[test]
256     fn test_extract_struct_several_fields_tuple() {
257         check_assist(
258             extract_struct_from_enum_variant,
259             "enum A { $0One(u32, u32) }",
260             r#"struct One(pub u32, pub u32);
261
262 enum A { One(One) }"#,
263         );
264     }
265
266     #[test]
267     fn test_extract_struct_several_fields_named() {
268         check_assist(
269             extract_struct_from_enum_variant,
270             "enum A { $0One { foo: u32, bar: u32 } }",
271             r#"struct One{ pub foo: u32, pub bar: u32 }
272
273 enum A { One(One) }"#,
274         );
275     }
276
277     #[test]
278     fn test_extract_struct_one_field_named() {
279         check_assist(
280             extract_struct_from_enum_variant,
281             "enum A { $0One { foo: u32 } }",
282             r#"struct One{ pub foo: u32 }
283
284 enum A { One(One) }"#,
285         );
286     }
287
288     #[test]
289     fn test_extract_enum_variant_name_value_namespace() {
290         check_assist(
291             extract_struct_from_enum_variant,
292             r#"const One: () = ();
293 enum A { $0One(u32, u32) }"#,
294             r#"const One: () = ();
295 struct One(pub u32, pub u32);
296
297 enum A { One(One) }"#,
298         );
299     }
300
301     #[test]
302     fn test_extract_struct_pub_visibility() {
303         check_assist(
304             extract_struct_from_enum_variant,
305             "pub enum A { $0One(u32, u32) }",
306             r#"pub struct One(pub u32, pub u32);
307
308 pub enum A { One(One) }"#,
309         );
310     }
311
312     #[test]
313     fn test_extract_struct_with_complex_imports() {
314         check_assist(
315             extract_struct_from_enum_variant,
316             r#"mod my_mod {
317     fn another_fn() {
318         let m = my_other_mod::MyEnum::MyField(1, 1);
319     }
320
321     pub mod my_other_mod {
322         fn another_fn() {
323             let m = MyEnum::MyField(1, 1);
324         }
325
326         pub enum MyEnum {
327             $0MyField(u8, u8),
328         }
329     }
330 }
331
332 fn another_fn() {
333     let m = my_mod::my_other_mod::MyEnum::MyField(1, 1);
334 }"#,
335             r#"use my_mod::my_other_mod::MyField;
336
337 mod my_mod {
338     use self::my_other_mod::MyField;
339
340     fn another_fn() {
341         let m = my_other_mod::MyEnum::MyField(MyField(1, 1));
342     }
343
344     pub mod my_other_mod {
345         fn another_fn() {
346             let m = MyEnum::MyField(MyField(1, 1));
347         }
348
349         pub struct MyField(pub u8, pub u8);
350
351         pub enum MyEnum {
352             MyField(MyField),
353         }
354     }
355 }
356
357 fn another_fn() {
358     let m = my_mod::my_other_mod::MyEnum::MyField(MyField(1, 1));
359 }"#,
360         );
361     }
362
363     #[test]
364     fn extract_record_fix_references() {
365         check_assist(
366             extract_struct_from_enum_variant,
367             r#"
368 enum E {
369     $0V { i: i32, j: i32 }
370 }
371
372 fn f() {
373     let e = E::V { i: 9, j: 2 };
374 }
375 "#,
376             r#"
377 struct V{ pub i: i32, pub j: i32 }
378
379 enum E {
380     V(V)
381 }
382
383 fn f() {
384     let e = E::V(V { i: 9, j: 2 });
385 }
386 "#,
387         )
388     }
389
390     #[test]
391     fn test_several_files() {
392         check_assist(
393             extract_struct_from_enum_variant,
394             r#"
395 //- /main.rs
396 enum E {
397     $0V(i32, i32)
398 }
399 mod foo;
400
401 //- /foo.rs
402 use crate::E;
403 fn f() {
404     let e = E::V(9, 2);
405 }
406 "#,
407             r#"
408 //- /main.rs
409 struct V(pub i32, pub i32);
410
411 enum E {
412     V(V)
413 }
414 mod foo;
415
416 //- /foo.rs
417 use crate::{E, V};
418 fn f() {
419     let e = E::V(V(9, 2));
420 }
421 "#,
422         )
423     }
424
425     #[test]
426     fn test_several_files_record() {
427         check_assist(
428             extract_struct_from_enum_variant,
429             r#"
430 //- /main.rs
431 enum E {
432     $0V { i: i32, j: i32 }
433 }
434 mod foo;
435
436 //- /foo.rs
437 use crate::E;
438 fn f() {
439     let e = E::V { i: 9, j: 2 };
440 }
441 "#,
442             r#"
443 //- /main.rs
444 struct V{ pub i: i32, pub j: i32 }
445
446 enum E {
447     V(V)
448 }
449 mod foo;
450
451 //- /foo.rs
452 use crate::{E, V};
453 fn f() {
454     let e = E::V(V { i: 9, j: 2 });
455 }
456 "#,
457         )
458     }
459
460     #[test]
461     fn test_extract_struct_record_nested_call_exp() {
462         check_assist(
463             extract_struct_from_enum_variant,
464             r#"
465 enum A { $0One { a: u32, b: u32 } }
466
467 struct B(A);
468
469 fn foo() {
470     let _ = B(A::One { a: 1, b: 2 });
471 }
472 "#,
473             r#"
474 struct One{ pub a: u32, pub b: u32 }
475
476 enum A { One(One) }
477
478 struct B(A);
479
480 fn foo() {
481     let _ = B(A::One(One { a: 1, b: 2 }));
482 }
483 "#,
484         );
485     }
486
487     fn check_not_applicable(ra_fixture: &str) {
488         let fixture =
489             format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE);
490         check_assist_not_applicable(extract_struct_from_enum_variant, &fixture)
491     }
492
493     #[test]
494     fn test_extract_enum_not_applicable_for_element_with_no_fields() {
495         check_not_applicable("enum A { $0One }");
496     }
497
498     #[test]
499     fn test_extract_enum_not_applicable_if_struct_exists() {
500         check_not_applicable(
501             r#"struct One;
502         enum A { $0One(u8, u32) }"#,
503         );
504     }
505
506     #[test]
507     fn test_extract_not_applicable_one_field() {
508         check_not_applicable(r"enum A { $0One(u32) }");
509     }
510
511     #[test]
512     fn test_extract_not_applicable_no_field_tuple() {
513         check_not_applicable(r"enum A { $0None() }");
514     }
515
516     #[test]
517     fn test_extract_not_applicable_no_field_named() {
518         check_not_applicable(r"enum A { $0None {} }");
519     }
520 }