]> git.lizzy.rs Git - rust.git/commitdiff
Add convert tuple struct to named struct assist
authorunexge <unexge@gmail.com>
Sat, 3 Apr 2021 21:04:31 +0000 (00:04 +0300)
committerunexge <unexge@gmail.com>
Sun, 4 Apr 2021 17:52:43 +0000 (20:52 +0300)
crates/ide_assists/src/handlers/convert_tuple_struct_to_named_struct.rs [new file with mode: 0644]
crates/ide_assists/src/lib.rs
crates/ide_assists/src/tests/generated.rs
crates/ide_db/src/search.rs
crates/syntax/src/ast/make.rs

diff --git a/crates/ide_assists/src/handlers/convert_tuple_struct_to_named_struct.rs b/crates/ide_assists/src/handlers/convert_tuple_struct_to_named_struct.rs
new file mode 100644 (file)
index 0000000..b2f7be0
--- /dev/null
@@ -0,0 +1,345 @@
+use hir::{Adt, ModuleDef};
+use ide_db::defs::Definition;
+use syntax::{
+    ast::{self, AstNode, GenericParamsOwner, VisibilityOwner},
+    match_ast,
+};
+
+use crate::{assist_context::AssistBuilder, AssistContext, AssistId, AssistKind, Assists};
+
+// Assist: convert_tuple_struct_to_named_struct
+//
+// Converts tuple struct to struct with named fields.
+//
+// ```
+// struct Inner;
+// struct A$0(Inner);
+// ```
+// ->
+// ```
+// struct Inner;
+// struct A { field1: Inner }
+// ```
+pub(crate) fn convert_tuple_struct_to_named_struct(
+    acc: &mut Assists,
+    ctx: &AssistContext,
+) -> Option<()> {
+    let strukt = ctx.find_node_at_offset::<ast::Struct>()?;
+    let tuple_fields = match strukt.field_list()? {
+        ast::FieldList::TupleFieldList(it) => it,
+        ast::FieldList::RecordFieldList(_) => return None,
+    };
+
+    let target = strukt.syntax().text_range();
+    acc.add(
+        AssistId("convert_tuple_struct_to_named_struct", AssistKind::RefactorRewrite),
+        "Convert to named struct",
+        target,
+        |edit| {
+            let names = generate_names(tuple_fields.fields());
+            edit_field_references(ctx, edit, tuple_fields.fields(), &names);
+            edit_struct_references(ctx, edit, &strukt, &names);
+            edit_struct_def(ctx, edit, &strukt, tuple_fields, names);
+        },
+    )
+}
+
+fn edit_struct_def(
+    ctx: &AssistContext,
+    edit: &mut AssistBuilder,
+    strukt: &ast::Struct,
+    tuple_fields: ast::TupleFieldList,
+    names: Vec<ast::Name>,
+) {
+    let record_fields = tuple_fields
+        .fields()
+        .zip(names)
+        .map(|(f, name)| ast::make::record_field(f.visibility(), name, f.ty().unwrap()));
+    let record_fields = ast::make::record_field_list(record_fields);
+    let tuple_fields_text_range = tuple_fields.syntax().text_range();
+
+    edit.edit_file(ctx.frange.file_id);
+
+    if let Some(w) = strukt.where_clause() {
+        edit.delete(w.syntax().text_range());
+        edit.insert(tuple_fields_text_range.start(), ast::make::tokens::single_newline().text());
+        edit.insert(tuple_fields_text_range.start(), w.syntax().text());
+        edit.insert(tuple_fields_text_range.start(), ",");
+        edit.insert(tuple_fields_text_range.start(), ast::make::tokens::single_newline().text());
+    } else {
+        edit.insert(tuple_fields_text_range.start(), ast::make::tokens::single_space().text());
+    }
+
+    edit.replace(tuple_fields_text_range, record_fields.to_string());
+    strukt.semicolon_token().map(|t| edit.delete(t.text_range()));
+}
+
+fn edit_struct_references(
+    ctx: &AssistContext,
+    edit: &mut AssistBuilder,
+    strukt: &ast::Struct,
+    names: &[ast::Name],
+) {
+    let strukt_def = ctx.sema.to_def(strukt).unwrap();
+    let usages = Definition::ModuleDef(ModuleDef::Adt(Adt::Struct(strukt_def)))
+        .usages(&ctx.sema)
+        .include_self_kw_refs(true)
+        .all();
+
+    for (file_id, refs) in usages {
+        edit.edit_file(file_id);
+        for r in refs {
+            for node in r.name.syntax().ancestors() {
+                match_ast! {
+                    match node {
+                        ast::TupleStructPat(tuple_struct_pat) => {
+                            edit.replace(
+                                tuple_struct_pat.syntax().text_range(),
+                                ast::make::record_pat_with_fields(
+                                    tuple_struct_pat.path().unwrap(),
+                                    ast::make::record_pat_field_list(tuple_struct_pat.fields().zip(names).map(
+                                        |(pat, name)| {
+                                            ast::make::record_pat_field(
+                                                ast::make::name_ref(&name.to_string()),
+                                                pat,
+                                            )
+                                        },
+                                    )),
+                                )
+                                .to_string(),
+                            );
+                        },
+                        // for tuple struct creations like: Foo(42)
+                        ast::CallExpr(call_expr) => {
+                            let path = call_expr.syntax().descendants().find_map(ast::PathExpr::cast).unwrap();
+                            let arg_list =
+                                call_expr.syntax().descendants().find_map(ast::ArgList::cast).unwrap();
+
+                            edit.replace(
+                                call_expr.syntax().text_range(),
+                                ast::make::record_expr(
+                                    path.path().unwrap(),
+                                    ast::make::record_expr_field_list(arg_list.args().zip(names).map(
+                                        |(expr, name)| {
+                                            ast::make::record_expr_field(
+                                                ast::make::name_ref(&name.to_string()),
+                                                Some(expr),
+                                            )
+                                        },
+                                    )),
+                                )
+                                .to_string(),
+                            );
+                        },
+                        _ => ()
+                    }
+                }
+            }
+        }
+    }
+}
+
+fn edit_field_references(
+    ctx: &AssistContext,
+    edit: &mut AssistBuilder,
+    fields: impl Iterator<Item = ast::TupleField>,
+    names: &[ast::Name],
+) {
+    for (field, name) in fields.zip(names) {
+        let field = match ctx.sema.to_def(&field) {
+            Some(it) => it,
+            None => continue,
+        };
+        let def = Definition::Field(field);
+        let usages = def.usages(&ctx.sema).all();
+        for (file_id, refs) in usages {
+            edit.edit_file(file_id);
+            for r in refs {
+                if let Some(name_ref) = r.name.as_name_ref() {
+                    edit.replace(name_ref.syntax().text_range(), name.text());
+                }
+            }
+        }
+    }
+}
+
+fn generate_names(fields: impl Iterator<Item = ast::TupleField>) -> Vec<ast::Name> {
+    fields.enumerate().map(|(i, _)| ast::make::name(&format!("field{}", i + 1))).collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::tests::{check_assist, check_assist_not_applicable};
+
+    use super::*;
+
+    #[test]
+    fn not_applicable_other_than_tuple_struct() {
+        check_assist_not_applicable(
+            convert_tuple_struct_to_named_struct,
+            r#"struct Foo$0 { bar: u32 };"#,
+        );
+        check_assist_not_applicable(convert_tuple_struct_to_named_struct, r#"struct Foo$0;"#);
+    }
+
+    #[test]
+    fn convert_simple_struct() {
+        check_assist(
+            convert_tuple_struct_to_named_struct,
+            r#"
+struct Inner;
+struct A$0(Inner);
+
+impl A {
+    fn new() -> A {
+        A(Inner)
+    }
+
+    fn into_inner(self) -> Inner {
+        self.0
+    }
+}"#,
+            r#"
+struct Inner;
+struct A { field1: Inner }
+
+impl A {
+    fn new() -> A {
+        A { field1: Inner }
+    }
+
+    fn into_inner(self) -> Inner {
+        self.field1
+    }
+}"#,
+        );
+    }
+
+    #[test]
+    fn convert_struct_referenced_via_self_kw() {
+        check_assist(
+            convert_tuple_struct_to_named_struct,
+            r#"
+struct Inner;
+struct A$0(Inner);
+
+impl A {
+    fn new() -> Self {
+        Self(Inner)
+    }
+
+    fn into_inner(self) -> Inner {
+        self.0
+    }
+}"#,
+            r#"
+struct Inner;
+struct A { field1: Inner }
+
+impl A {
+    fn new() -> Self {
+        Self { field1: Inner }
+    }
+
+    fn into_inner(self) -> Inner {
+        self.field1
+    }
+}"#,
+        );
+    }
+
+    #[test]
+    fn convert_destructured_struct() {
+        check_assist(
+            convert_tuple_struct_to_named_struct,
+            r#"
+struct Inner;
+struct A$0(Inner);
+
+impl A {
+    fn into_inner(self) -> Inner {
+        let A(first) = self;
+        first
+    }
+
+    fn into_inner_via_self(self) -> Inner {
+        let Self(first) = self;
+        first
+    }
+}"#,
+            r#"
+struct Inner;
+struct A { field1: Inner }
+
+impl A {
+    fn into_inner(self) -> Inner {
+        let A { field1: first } = self;
+        first
+    }
+
+    fn into_inner_via_self(self) -> Inner {
+        let Self { field1: first } = self;
+        first
+    }
+}"#,
+        );
+    }
+
+    #[test]
+    fn convert_struct_with_visibility() {
+        check_assist(
+            convert_tuple_struct_to_named_struct,
+            r#"
+struct A$0(pub u32, pub(crate) u64);
+
+impl A {
+    fn new() -> A {
+        A(42, 42)
+    }
+
+    fn into_first(self) -> u32 {
+        self.0
+    }
+
+    fn into_second(self) -> u64 {
+        self.1
+    }
+}"#,
+            r#"
+struct A { pub field1: u32, pub(crate) field2: u64 }
+
+impl A {
+    fn new() -> A {
+        A { field1: 42, field2: 42 }
+    }
+
+    fn into_first(self) -> u32 {
+        self.field1
+    }
+
+    fn into_second(self) -> u64 {
+        self.field2
+    }
+}"#,
+        );
+    }
+
+    #[test]
+    fn convert_struct_with_where_clause() {
+        check_assist(
+            convert_tuple_struct_to_named_struct,
+            r#"
+struct Wrap$0<T>(T)
+where
+    T: Display;
+"#,
+            r#"
+struct Wrap<T>
+where
+    T: Display,
+{ field1: T }
+
+"#,
+        );
+    }
+}
index 3e2c82dace0a718347afd4317a799a9966bc9d6f..1c55b9fbf967928278049501213454fd68bb2df5 100644 (file)
@@ -118,6 +118,7 @@ mod handlers {
     mod convert_comment_block;
     mod convert_iter_for_each_to_for;
     mod convert_into_to_from;
+    mod convert_tuple_struct_to_named_struct;
     mod early_return;
     mod expand_glob_import;
     mod extract_function;
@@ -187,6 +188,7 @@ pub(crate) fn all() -> &'static [Handler] {
             convert_comment_block::convert_comment_block,
             convert_iter_for_each_to_for::convert_iter_for_each_to_for,
             convert_into_to_from::convert_into_to_from,
+            convert_tuple_struct_to_named_struct::convert_tuple_struct_to_named_struct,
             early_return::convert_to_guarded_return,
             expand_glob_import::expand_glob_import,
             extract_struct_from_enum_variant::extract_struct_from_enum_variant,
index 27a22ca10c18566f7908f8163f77521762df693d..53f455adf9def6ef1846c6c9bdaadde0cd409a8c 100644 (file)
@@ -291,6 +291,21 @@ fn main() {
     )
 }
 
+#[test]
+fn doctest_convert_tuple_struct_to_named_struct() {
+    check_doc_test(
+        "convert_tuple_struct_to_named_struct",
+        r#####"
+struct Inner;
+struct A$0(Inner);
+"#####,
+        r#####"
+struct Inner;
+struct A { field1: Inner }
+"#####,
+    )
+}
+
 #[test]
 fn doctest_expand_glob_import() {
     check_doc_test(
index 02f5e514b53adfb002d6d7bfc03acdc492219058..90e4e7b03735772bc4d92b5b97b8767ecfceac57 100644 (file)
@@ -430,6 +430,15 @@ fn found_name_ref(
         sink: &mut dyn FnMut(FileId, FileReference) -> bool,
     ) -> bool {
         match NameRefClass::classify(self.sema, &name_ref) {
+            Some(NameRefClass::Definition(def)) if &def == self.def => {
+                let FileRange { file_id, range } = self.sema.original_range(name_ref.syntax());
+                let reference = FileReference {
+                    range,
+                    name: ast::NameLike::NameRef(name_ref.clone()),
+                    access: reference_access(&def, &name_ref),
+                };
+                sink(file_id, reference)
+            }
             Some(NameRefClass::Definition(Definition::SelfType(impl_))) => {
                 let ty = impl_.self_ty(self.sema.db);
 
@@ -448,15 +457,6 @@ fn found_name_ref(
 
                 false
             }
-            Some(NameRefClass::Definition(def)) if &def == self.def => {
-                let FileRange { file_id, range } = self.sema.original_range(name_ref.syntax());
-                let reference = FileReference {
-                    range,
-                    name: ast::NameLike::NameRef(name_ref.clone()),
-                    access: reference_access(&def, &name_ref),
-                };
-                sink(file_id, reference)
-            }
             Some(NameRefClass::FieldShorthand { local_ref: local, field_ref: field }) => {
                 let FileRange { file_id, range } = self.sema.original_range(name_ref.syntax());
                 let reference = match self.def {
index c6a7b99b7c9ab37f2a37b424836690dd31d448c1..3a588e540033a5fbc5c191b0cae4b3e2348c2815 100644 (file)
@@ -133,6 +133,17 @@ pub fn use_(visibility: Option<ast::Visibility>, use_tree: ast::UseTree) -> ast:
     ast_from_text(&format!("{}use {};", visibility, use_tree))
 }
 
+pub fn record_expr(path: ast::Path, fields: ast::RecordExprFieldList) -> ast::RecordExpr {
+    ast_from_text(&format!("fn f() {{ {} {} }}", path, fields))
+}
+
+pub fn record_expr_field_list(
+    fields: impl IntoIterator<Item = ast::RecordExprField>,
+) -> ast::RecordExprFieldList {
+    let fields = fields.into_iter().join(", ");
+    ast_from_text(&format!("fn f() {{ S {{ {} }} }}", fields))
+}
+
 pub fn record_expr_field(name: ast::NameRef, expr: Option<ast::Expr>) -> ast::RecordExprField {
     return match expr {
         Some(expr) => from_text(&format!("{}: {}", name, expr)),
@@ -325,6 +336,21 @@ fn from_text(text: &str) -> ast::RecordPat {
     }
 }
 
+pub fn record_pat_with_fields(path: ast::Path, fields: ast::RecordPatFieldList) -> ast::RecordPat {
+    ast_from_text(&format!("fn f({} {}: ()))", path, fields))
+}
+
+pub fn record_pat_field_list(
+    fields: impl IntoIterator<Item = ast::RecordPatField>,
+) -> ast::RecordPatFieldList {
+    let fields = fields.into_iter().join(", ");
+    ast_from_text(&format!("fn f(S {{ {} }}: ()))", fields))
+}
+
+pub fn record_pat_field(name_ref: ast::NameRef, pat: ast::Pat) -> ast::RecordPatField {
+    ast_from_text(&format!("fn f(S {{ {}: {} }}: ()))", name_ref, pat))
+}
+
 /// Returns a `BindPat` if the path has just one segment, a `PathPat` otherwise.
 pub fn path_pat(path: ast::Path) -> ast::Pat {
     return from_text(&path.to_string());