]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/reorder_fields.rs
Merge #8317
[rust.git] / crates / ide_assists / src / handlers / reorder_fields.rs
1 use either::Either;
2 use itertools::Itertools;
3 use rustc_hash::FxHashMap;
4
5 use syntax::{ast, ted, AstNode};
6
7 use crate::{AssistContext, AssistId, AssistKind, Assists};
8
9 // Assist: reorder_fields
10 //
11 // Reorder the fields of record literals and record patterns in the same order as in
12 // the definition.
13 //
14 // ```
15 // struct Foo {foo: i32, bar: i32};
16 // const test: Foo = $0Foo {bar: 0, foo: 1}
17 // ```
18 // ->
19 // ```
20 // struct Foo {foo: i32, bar: i32};
21 // const test: Foo = Foo {foo: 1, bar: 0}
22 // ```
23 //
24 pub(crate) fn reorder_fields(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
25     let record = ctx
26         .find_node_at_offset::<ast::RecordExpr>()
27         .map(Either::Left)
28         .or_else(|| ctx.find_node_at_offset::<ast::RecordPat>().map(Either::Right))?;
29
30     let path = record.as_ref().either(|it| it.path(), |it| it.path())?;
31     let ranks = compute_fields_ranks(&path, &ctx)?;
32     let get_rank_of_field =
33         |of: Option<_>| *ranks.get(&of.unwrap_or_default()).unwrap_or(&usize::MAX);
34
35     let field_list = match &record {
36         Either::Left(it) => Either::Left(it.record_expr_field_list()?),
37         Either::Right(it) => Either::Right(it.record_pat_field_list()?),
38     };
39     let fields = match field_list {
40         Either::Left(it) => Either::Left((
41             it.fields()
42                 .sorted_unstable_by_key(|field| {
43                     get_rank_of_field(field.field_name().map(|it| it.to_string()))
44                 })
45                 .collect::<Vec<_>>(),
46             it,
47         )),
48         Either::Right(it) => Either::Right((
49             it.fields()
50                 .sorted_unstable_by_key(|field| {
51                     get_rank_of_field(field.field_name().map(|it| it.to_string()))
52                 })
53                 .collect::<Vec<_>>(),
54             it,
55         )),
56     };
57
58     let is_sorted = fields.as_ref().either(
59         |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b),
60         |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b),
61     );
62     if is_sorted {
63         cov_mark::hit!(reorder_sorted_fields);
64         return None;
65     }
66     let target = record.as_ref().either(AstNode::syntax, AstNode::syntax).text_range();
67     acc.add(
68         AssistId("reorder_fields", AssistKind::RefactorRewrite),
69         "Reorder record fields",
70         target,
71         |builder| match fields {
72             Either::Left((sorted, field_list)) => {
73                 replace(builder.make_ast_mut(field_list).fields(), sorted)
74             }
75             Either::Right((sorted, field_list)) => {
76                 replace(builder.make_ast_mut(field_list).fields(), sorted)
77             }
78         },
79     )
80 }
81
82 fn replace<T: AstNode + PartialEq>(
83     fields: impl Iterator<Item = T>,
84     sorted_fields: impl IntoIterator<Item = T>,
85 ) {
86     fields.zip(sorted_fields).filter(|(field, sorted)| field != sorted).for_each(
87         |(field, sorted_field)| {
88             ted::replace(field.syntax(), sorted_field.syntax().clone_for_update());
89         },
90     );
91 }
92
93 fn compute_fields_ranks(path: &ast::Path, ctx: &AssistContext) -> Option<FxHashMap<String, usize>> {
94     let strukt = match ctx.sema.resolve_path(path) {
95         Some(hir::PathResolution::Def(hir::ModuleDef::Adt(hir::Adt::Struct(it)))) => it,
96         _ => return None,
97     };
98
99     let res = strukt
100         .fields(ctx.db())
101         .into_iter()
102         .enumerate()
103         .map(|(idx, field)| (field.name(ctx.db()).to_string(), idx))
104         .collect();
105
106     Some(res)
107 }
108
109 #[cfg(test)]
110 mod tests {
111     use crate::tests::{check_assist, check_assist_not_applicable};
112
113     use super::*;
114
115     #[test]
116     fn reorder_sorted_fields() {
117         cov_mark::check!(reorder_sorted_fields);
118         check_assist_not_applicable(
119             reorder_fields,
120             r#"
121 struct Foo { foo: i32, bar: i32 }
122 const test: Foo = $0Foo { foo: 0, bar: 0 };
123 "#,
124         )
125     }
126
127     #[test]
128     fn trivial_empty_fields() {
129         check_assist_not_applicable(
130             reorder_fields,
131             r#"
132 struct Foo {}
133 const test: Foo = $0Foo {};
134 "#,
135         )
136     }
137
138     #[test]
139     fn reorder_struct_fields() {
140         check_assist(
141             reorder_fields,
142             r#"
143 struct Foo { foo: i32, bar: i32 }
144 const test: Foo = $0Foo { bar: 0, foo: 1 };
145 "#,
146             r#"
147 struct Foo { foo: i32, bar: i32 }
148 const test: Foo = Foo { foo: 1, bar: 0 };
149 "#,
150         )
151     }
152     #[test]
153     fn reorder_struct_pattern() {
154         check_assist(
155             reorder_fields,
156             r#"
157 struct Foo { foo: i64, bar: i64, baz: i64 }
158
159 fn f(f: Foo) -> {
160     match f {
161         $0Foo { baz: 0, ref mut bar, .. } => (),
162         _ => ()
163     }
164 }
165 "#,
166             r#"
167 struct Foo { foo: i64, bar: i64, baz: i64 }
168
169 fn f(f: Foo) -> {
170     match f {
171         Foo { ref mut bar, baz: 0, .. } => (),
172         _ => ()
173     }
174 }
175 "#,
176         )
177     }
178
179     #[test]
180     fn reorder_with_extra_field() {
181         check_assist(
182             reorder_fields,
183             r#"
184 struct Foo { foo: String, bar: String }
185
186 impl Foo {
187     fn new() -> Foo {
188         let foo = String::new();
189         $0Foo {
190             bar: foo.clone(),
191             extra: "Extra field",
192             foo,
193         }
194     }
195 }
196 "#,
197             r#"
198 struct Foo { foo: String, bar: String }
199
200 impl Foo {
201     fn new() -> Foo {
202         let foo = String::new();
203         Foo {
204             foo,
205             bar: foo.clone(),
206             extra: "Extra field",
207         }
208     }
209 }
210 "#,
211         )
212     }
213 }