]> git.lizzy.rs Git - rust.git/commitdiff
Account for generics in extract_struct_from_enum_variant
authorLukas Wirth <lukastw97@gmail.com>
Wed, 2 Jun 2021 15:44:00 +0000 (17:44 +0200)
committerLukas Wirth <lukastw97@gmail.com>
Wed, 2 Jun 2021 15:44:00 +0000 (17:44 +0200)
crates/ide_assists/src/handlers/extract_struct_from_enum_variant.rs
crates/syntax/src/ast/make.rs

index 007aba23d216abdc5e4322074f061ec5779f15c9..730fc28bfcb7a1166fd8e43c9f368e80bd07b3f9 100644 (file)
     search::FileReference,
     RootDatabase,
 };
+use itertools::Itertools;
 use rustc_hash::FxHashSet;
 use syntax::{
     algo::find_node_at_offset,
-    ast::{self, make, AstNode, NameOwner, VisibilityOwner},
+    ast::{self, make, AstNode, GenericParamsOwner, NameOwner, TypeBoundsOwner, VisibilityOwner},
     ted, SyntaxNode, T,
 };
 
@@ -100,12 +101,12 @@ pub(crate) fn extract_struct_from_enum_variant(
                 });
             }
 
-            let def = create_struct_def(variant_name.clone(), &field_list, enum_ast.visibility());
+            let def = create_struct_def(variant_name.clone(), &field_list, &enum_ast);
             let start_offset = &variant.parent_enum().syntax().clone();
             ted::insert_raw(ted::Position::before(start_offset), def.syntax());
             ted::insert_raw(ted::Position::before(start_offset), &make::tokens::blank_line());
 
-            update_variant(&variant);
+            update_variant(&variant, enum_ast.generic_param_list());
         },
     )
 }
@@ -149,7 +150,7 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Va
 fn create_struct_def(
     variant_name: ast::Name,
     field_list: &Either<ast::RecordFieldList, ast::TupleFieldList>,
-    visibility: Option<ast::Visibility>,
+    enum_: &ast::Enum,
 ) -> ast::Struct {
     let pub_vis = make::visibility_pub();
 
@@ -184,12 +185,30 @@ fn create_struct_def(
         }
     };
 
-    make::struct_(visibility, variant_name, None, field_list).clone_for_update()
+    // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
+    make::struct_(enum_.visibility(), variant_name, enum_.generic_param_list(), field_list)
+        .clone_for_update()
 }
 
-fn update_variant(variant: &ast::Variant) -> Option<()> {
+fn update_variant(variant: &ast::Variant, generic: Option<ast::GenericParamList>) -> Option<()> {
     let name = variant.name()?;
-    let tuple_field = make::tuple_field(None, make::ty(&name.text()));
+    let ty = match generic {
+        // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
+        Some(gpl) => {
+            let gpl = gpl.clone_for_update();
+            gpl.generic_params().for_each(|gp| {
+                match gp {
+                    ast::GenericParam::LifetimeParam(it) => it.type_bound_list(),
+                    ast::GenericParam::TypeParam(it) => it.type_bound_list(),
+                    ast::GenericParam::ConstParam(_) => return,
+                }
+                .map(|it| it.remove());
+            });
+            make::ty(&format!("{}<{}>", name.text(), gpl.generic_params().join(", ")))
+        }
+        None => make::ty(&name.text()),
+    };
+    let tuple_field = make::tuple_field(None, ty);
     let replacement = make::variant(
         name,
         Some(ast::FieldList::TupleFieldList(make::tuple_field_list(iter::once(tuple_field)))),
@@ -208,10 +227,9 @@ fn apply_references(
     if let Some((scope, path)) = import {
         insert_use(&scope, mod_path_to_ast(&path), insert_use_cfg);
     }
-    ted::insert_raw(
-        ted::Position::before(segment.syntax()),
-        make::path_from_text(&format!("{}", segment)).clone_for_update().syntax(),
-    );
+    // deep clone to prevent cycle
+    let path = make::path_from_segments(iter::once(segment.clone_subtree()), false);
+    ted::insert_raw(ted::Position::before(segment.syntax()), path.clone_for_update().syntax());
     ted::insert_raw(ted::Position::before(segment.syntax()), make::token(T!['(']));
     ted::insert_raw(ted::Position::after(&node), make::token(T![')']));
 }
@@ -278,6 +296,12 @@ mod tests {
 
     use super::*;
 
+    fn check_not_applicable(ra_fixture: &str) {
+        let fixture =
+            format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE);
+        check_assist_not_applicable(extract_struct_from_enum_variant, &fixture)
+    }
+
     #[test]
     fn test_extract_struct_several_fields_tuple() {
         check_assist(
@@ -311,6 +335,17 @@ enum A { One(One) }"#,
         );
     }
 
+    #[test]
+    fn test_extract_struct_keeps_generics() {
+        check_assist(
+            extract_struct_from_enum_variant,
+            r"enum En<T> { Var { a: T$0 } }",
+            r#"struct Var<T>{ pub a: T }
+
+enum En<T> { Var(Var<T>) }"#,
+        );
+    }
+
     #[test]
     fn test_extract_struct_keep_comments_and_attrs_one_field_named() {
         check_assist(
@@ -610,12 +645,6 @@ fn foo() {
         );
     }
 
-    fn check_not_applicable(ra_fixture: &str) {
-        let fixture =
-            format!("//- /main.rs crate:main deps:core\n{}\n{}", ra_fixture, FamousDefs::FIXTURE);
-        check_assist_not_applicable(extract_struct_from_enum_variant, &fixture)
-    }
-
     #[test]
     fn test_extract_enum_not_applicable_for_element_with_no_fields() {
         check_not_applicable("enum A { $0One }");
index 0cf17062610996e5079b39d9c38e255215cd623e..4c3c9661d44d2ac3d3a31b89fbbdca11ee1f9913 100644 (file)
@@ -580,12 +580,11 @@ pub fn fn_(
 pub fn struct_(
     visibility: Option<ast::Visibility>,
     strukt_name: ast::Name,
-    type_params: Option<ast::GenericParamList>,
+    generic_param_list: Option<ast::GenericParamList>,
     field_list: ast::FieldList,
 ) -> ast::Struct {
     let semicolon = if matches!(field_list, ast::FieldList::TupleFieldList(_)) { ";" } else { "" };
-    let type_params =
-        if let Some(type_params) = type_params { format!("<{}>", type_params) } else { "".into() };
+    let type_params = generic_param_list.map_or_else(String::new, |it| it.to_string());
     let visibility = match visibility {
         None => String::new(),
         Some(it) => format!("{} ", it),