]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/handlers/reorder_fields.rs
Merge #11481
[rust.git] / crates / ide_assists / src / handlers / reorder_fields.rs
index 383ca6c473fc18019695535a9f0e567b889b9af1..cd4eb7c15e9085ff1ee83b6e9bf3b5b7404d45fd 100644 (file)
@@ -1,6 +1,8 @@
+use either::Either;
+use itertools::Itertools;
 use rustc_hash::FxHashMap;
 
-use syntax::{algo, ast, match_ast, AstNode, SyntaxKind::*, SyntaxNode};
+use syntax::{ast, ted, AstNode};
 
 use crate::{AssistContext, AssistId, AssistKind, Assists};
 
 // struct Foo {foo: i32, bar: i32};
 // const test: Foo = Foo {foo: 1, bar: 0}
 // ```
-//
 pub(crate) fn reorder_fields(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
     let record = ctx
         .find_node_at_offset::<ast::RecordExpr>()
-        .map(|it| it.syntax().clone())
-        .or_else(|| ctx.find_node_at_offset::<ast::RecordPat>().map(|it| it.syntax().clone()))?;
-
-    let path = record.children().find_map(ast::Path::cast)?;
+        .map(Either::Left)
+        .or_else(|| ctx.find_node_at_offset::<ast::RecordPat>().map(Either::Right))?;
 
-    let ranks = compute_fields_ranks(&path, &ctx)?;
+    let path = record.as_ref().either(|it| it.path(), |it| it.path())?;
+    let ranks = compute_fields_ranks(&path, ctx)?;
+    let get_rank_of_field =
+        |of: Option<_>| *ranks.get(&of.unwrap_or_default()).unwrap_or(&usize::MAX);
 
-    let fields: Vec<SyntaxNode> = {
-        let field_kind = match record.kind() {
-            RECORD_EXPR => RECORD_EXPR_FIELD,
-            RECORD_PAT => RECORD_PAT_FIELD,
-            _ => {
-                stdx::never!();
-                return None;
-            }
-        };
-        record.children().flat_map(|n| n.children()).filter(|n| n.kind() == field_kind).collect()
+    let field_list = match &record {
+        Either::Left(it) => Either::Left(it.record_expr_field_list()?),
+        Either::Right(it) => Either::Right(it.record_pat_field_list()?),
     };
-
-    let sorted_fields = {
-        let mut fields = fields.clone();
-        fields.sort_by_key(|node| *ranks.get(&get_field_name(node)).unwrap_or(&usize::max_value()));
-        fields
+    let fields = match field_list {
+        Either::Left(it) => Either::Left((
+            it.fields()
+                .sorted_unstable_by_key(|field| {
+                    get_rank_of_field(field.field_name().map(|it| it.to_string()))
+                })
+                .collect::<Vec<_>>(),
+            it,
+        )),
+        Either::Right(it) => Either::Right((
+            it.fields()
+                .sorted_unstable_by_key(|field| {
+                    get_rank_of_field(field.field_name().map(|it| it.to_string()))
+                })
+                .collect::<Vec<_>>(),
+            it,
+        )),
     };
 
-    if sorted_fields == fields {
+    let is_sorted = fields.as_ref().either(
+        |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b),
+        |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b),
+    );
+    if is_sorted {
         cov_mark::hit!(reorder_sorted_fields);
         return None;
     }
-
-    let target = record.text_range();
+    let target = record.as_ref().either(AstNode::syntax, AstNode::syntax).text_range();
     acc.add(
         AssistId("reorder_fields", AssistKind::RefactorRewrite),
         "Reorder record fields",
         target,
-        |edit| {
-            let mut rewriter = algo::SyntaxRewriter::default();
-            for (old, new) in fields.iter().zip(&sorted_fields) {
-                rewriter.replace(old, new);
+        |builder| match fields {
+            Either::Left((sorted, field_list)) => {
+                replace(builder.make_mut(field_list).fields(), sorted)
+            }
+            Either::Right((sorted, field_list)) => {
+                replace(builder.make_mut(field_list).fields(), sorted)
             }
-            edit.rewrite(rewriter);
         },
     )
 }
 
-fn get_field_name(node: &SyntaxNode) -> String {
-    let res = match_ast! {
-        match node {
-            ast::RecordExprField(field) => field.field_name().map(|it| it.to_string()),
-            ast::RecordPatField(field) => field.field_name().map(|it| it.to_string()),
-            _ => None,
-        }
-    };
-    res.unwrap_or_default()
+fn replace<T: AstNode + PartialEq>(
+    fields: impl Iterator<Item = T>,
+    sorted_fields: impl IntoIterator<Item = T>,
+) {
+    fields.zip(sorted_fields).for_each(|(field, sorted_field)| {
+        ted::replace(field.syntax(), sorted_field.syntax().clone_for_update())
+    });
 }
 
 fn compute_fields_ranks(path: &ast::Path, ctx: &AssistContext) -> Option<FxHashMap<String, usize>> {
@@ -86,7 +95,7 @@ fn compute_fields_ranks(path: &ast::Path, ctx: &AssistContext) -> Option<FxHashM
 
     let res = strukt
         .fields(ctx.db())
-        .iter()
+        .into_iter()
         .enumerate()
         .map(|(idx, field)| (field.name(ctx.db()).to_string(), idx))
         .collect();
@@ -137,7 +146,6 @@ struct Foo { foo: i32, bar: i32 }
 "#,
         )
     }
-
     #[test]
     fn reorder_struct_pattern() {
         check_assist(