]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/reorder_fields.rs
Rollup merge of #97963 - devnexen:net_listener_neg, r=the8472
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / reorder_fields.rs
1 use either::Either;
2 use ide_db::FxHashMap;
3 use itertools::Itertools;
4 use syntax::{ast, ted, AstNode};
5
6 use crate::{AssistContext, AssistId, AssistKind, Assists};
7
8 // Assist: reorder_fields
9 //
10 // Reorder the fields of record literals and record patterns in the same order as in
11 // the definition.
12 //
13 // ```
14 // struct Foo {foo: i32, bar: i32};
15 // const test: Foo = $0Foo {bar: 0, foo: 1}
16 // ```
17 // ->
18 // ```
19 // struct Foo {foo: i32, bar: i32};
20 // const test: Foo = Foo {foo: 1, bar: 0}
21 // ```
22 pub(crate) fn reorder_fields(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
23     let record = ctx
24         .find_node_at_offset::<ast::RecordExpr>()
25         .map(Either::Left)
26         .or_else(|| ctx.find_node_at_offset::<ast::RecordPat>().map(Either::Right))?;
27
28     let path = record.as_ref().either(|it| it.path(), |it| it.path())?;
29     let ranks = compute_fields_ranks(&path, ctx)?;
30     let get_rank_of_field =
31         |of: Option<_>| *ranks.get(&of.unwrap_or_default()).unwrap_or(&usize::MAX);
32
33     let field_list = match &record {
34         Either::Left(it) => Either::Left(it.record_expr_field_list()?),
35         Either::Right(it) => Either::Right(it.record_pat_field_list()?),
36     };
37     let fields = match field_list {
38         Either::Left(it) => Either::Left((
39             it.fields()
40                 .sorted_unstable_by_key(|field| {
41                     get_rank_of_field(field.field_name().map(|it| it.to_string()))
42                 })
43                 .collect::<Vec<_>>(),
44             it,
45         )),
46         Either::Right(it) => Either::Right((
47             it.fields()
48                 .sorted_unstable_by_key(|field| {
49                     get_rank_of_field(field.field_name().map(|it| it.to_string()))
50                 })
51                 .collect::<Vec<_>>(),
52             it,
53         )),
54     };
55
56     let is_sorted = fields.as_ref().either(
57         |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b),
58         |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b),
59     );
60     if is_sorted {
61         cov_mark::hit!(reorder_sorted_fields);
62         return None;
63     }
64     let target = record.as_ref().either(AstNode::syntax, AstNode::syntax).text_range();
65     acc.add(
66         AssistId("reorder_fields", AssistKind::RefactorRewrite),
67         "Reorder record fields",
68         target,
69         |builder| match fields {
70             Either::Left((sorted, field_list)) => {
71                 replace(builder.make_mut(field_list).fields(), sorted)
72             }
73             Either::Right((sorted, field_list)) => {
74                 replace(builder.make_mut(field_list).fields(), sorted)
75             }
76         },
77     )
78 }
79
80 fn replace<T: AstNode + PartialEq>(
81     fields: impl Iterator<Item = T>,
82     sorted_fields: impl IntoIterator<Item = T>,
83 ) {
84     fields.zip(sorted_fields).for_each(|(field, sorted_field)| {
85         ted::replace(field.syntax(), sorted_field.syntax().clone_for_update())
86     });
87 }
88
89 fn compute_fields_ranks(
90     path: &ast::Path,
91     ctx: &AssistContext<'_>,
92 ) -> Option<FxHashMap<String, usize>> {
93     let strukt = match ctx.sema.resolve_path(path) {
94         Some(hir::PathResolution::Def(hir::ModuleDef::Adt(hir::Adt::Struct(it)))) => it,
95         _ => return None,
96     };
97
98     let res = strukt
99         .fields(ctx.db())
100         .into_iter()
101         .enumerate()
102         .map(|(idx, field)| (field.name(ctx.db()).to_string(), idx))
103         .collect();
104
105     Some(res)
106 }
107
108 #[cfg(test)]
109 mod tests {
110     use crate::tests::{check_assist, check_assist_not_applicable};
111
112     use super::*;
113
114     #[test]
115     fn reorder_sorted_fields() {
116         cov_mark::check!(reorder_sorted_fields);
117         check_assist_not_applicable(
118             reorder_fields,
119             r#"
120 struct Foo { foo: i32, bar: i32 }
121 const test: Foo = $0Foo { foo: 0, bar: 0 };
122 "#,
123         )
124     }
125
126     #[test]
127     fn trivial_empty_fields() {
128         check_assist_not_applicable(
129             reorder_fields,
130             r#"
131 struct Foo {}
132 const test: Foo = $0Foo {};
133 "#,
134         )
135     }
136
137     #[test]
138     fn reorder_struct_fields() {
139         check_assist(
140             reorder_fields,
141             r#"
142 struct Foo { foo: i32, bar: i32 }
143 const test: Foo = $0Foo { bar: 0, foo: 1 };
144 "#,
145             r#"
146 struct Foo { foo: i32, bar: i32 }
147 const test: Foo = Foo { foo: 1, bar: 0 };
148 "#,
149         )
150     }
151     #[test]
152     fn reorder_struct_pattern() {
153         check_assist(
154             reorder_fields,
155             r#"
156 struct Foo { foo: i64, bar: i64, baz: i64 }
157
158 fn f(f: Foo) -> {
159     match f {
160         $0Foo { baz: 0, ref mut bar, .. } => (),
161         _ => ()
162     }
163 }
164 "#,
165             r#"
166 struct Foo { foo: i64, bar: i64, baz: i64 }
167
168 fn f(f: Foo) -> {
169     match f {
170         Foo { ref mut bar, baz: 0, .. } => (),
171         _ => ()
172     }
173 }
174 "#,
175         )
176     }
177
178     #[test]
179     fn reorder_with_extra_field() {
180         check_assist(
181             reorder_fields,
182             r#"
183 struct Foo { foo: String, bar: String }
184
185 impl Foo {
186     fn new() -> Foo {
187         let foo = String::new();
188         $0Foo {
189             bar: foo.clone(),
190             extra: "Extra field",
191             foo,
192         }
193     }
194 }
195 "#,
196             r#"
197 struct Foo { foo: String, bar: String }
198
199 impl Foo {
200     fn new() -> Foo {
201         let foo = String::new();
202         Foo {
203             foo,
204             bar: foo.clone(),
205             extra: "Extra field",
206         }
207     }
208 }
209 "#,
210         )
211     }
212 }