]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/reorder_fields.rs
Merge #8639
[rust.git] / crates / ide_assists / src / handlers / reorder_fields.rs
1 use either::Either;
2 use itertools::Itertools;
3 use rustc_hash::FxHashMap;
4
5 use syntax::{ast, ted, AstNode};
6
7 use crate::{AssistContext, AssistId, AssistKind, Assists};
8
9 // Assist: reorder_fields
10 //
11 // Reorder the fields of record literals and record patterns in the same order as in
12 // the definition.
13 //
14 // ```
15 // struct Foo {foo: i32, bar: i32};
16 // const test: Foo = $0Foo {bar: 0, foo: 1}
17 // ```
18 // ->
19 // ```
20 // struct Foo {foo: i32, bar: i32};
21 // const test: Foo = Foo {foo: 1, bar: 0}
22 // ```
23 //
24 pub(crate) fn reorder_fields(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
25     let record = ctx
26         .find_node_at_offset::<ast::RecordExpr>()
27         .map(Either::Left)
28         .or_else(|| ctx.find_node_at_offset::<ast::RecordPat>().map(Either::Right))?;
29
30     let path = record.as_ref().either(|it| it.path(), |it| it.path())?;
31     let ranks = compute_fields_ranks(&path, &ctx)?;
32     let get_rank_of_field =
33         |of: Option<_>| *ranks.get(&of.unwrap_or_default()).unwrap_or(&usize::MAX);
34
35     let field_list = match &record {
36         Either::Left(it) => Either::Left(it.record_expr_field_list()?),
37         Either::Right(it) => Either::Right(it.record_pat_field_list()?),
38     };
39     let fields = match field_list {
40         Either::Left(it) => Either::Left((
41             it.fields()
42                 .sorted_unstable_by_key(|field| {
43                     get_rank_of_field(field.field_name().map(|it| it.to_string()))
44                 })
45                 .collect::<Vec<_>>(),
46             it,
47         )),
48         Either::Right(it) => Either::Right((
49             it.fields()
50                 .sorted_unstable_by_key(|field| {
51                     get_rank_of_field(field.field_name().map(|it| it.to_string()))
52                 })
53                 .collect::<Vec<_>>(),
54             it,
55         )),
56     };
57
58     let is_sorted = fields.as_ref().either(
59         |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b),
60         |(sorted, field_list)| field_list.fields().zip(sorted).all(|(a, b)| a == *b),
61     );
62     if is_sorted {
63         cov_mark::hit!(reorder_sorted_fields);
64         return None;
65     }
66     let target = record.as_ref().either(AstNode::syntax, AstNode::syntax).text_range();
67     acc.add(
68         AssistId("reorder_fields", AssistKind::RefactorRewrite),
69         "Reorder record fields",
70         target,
71         |builder| match fields {
72             Either::Left((sorted, field_list)) => {
73                 replace(builder.make_ast_mut(field_list).fields(), sorted)
74             }
75             Either::Right((sorted, field_list)) => {
76                 replace(builder.make_ast_mut(field_list).fields(), sorted)
77             }
78         },
79     )
80 }
81
82 fn replace<T: AstNode + PartialEq>(
83     fields: impl Iterator<Item = T>,
84     sorted_fields: impl IntoIterator<Item = T>,
85 ) {
86     fields.zip(sorted_fields).for_each(|(field, sorted_field)| {
87         ted::replace(field.syntax(), sorted_field.syntax().clone_for_update())
88     });
89 }
90
91 fn compute_fields_ranks(path: &ast::Path, ctx: &AssistContext) -> Option<FxHashMap<String, usize>> {
92     let strukt = match ctx.sema.resolve_path(path) {
93         Some(hir::PathResolution::Def(hir::ModuleDef::Adt(hir::Adt::Struct(it)))) => it,
94         _ => return None,
95     };
96
97     let res = strukt
98         .fields(ctx.db())
99         .into_iter()
100         .enumerate()
101         .map(|(idx, field)| (field.name(ctx.db()).to_string(), idx))
102         .collect();
103
104     Some(res)
105 }
106
107 #[cfg(test)]
108 mod tests {
109     use crate::tests::{check_assist, check_assist_not_applicable};
110
111     use super::*;
112
113     #[test]
114     fn reorder_sorted_fields() {
115         cov_mark::check!(reorder_sorted_fields);
116         check_assist_not_applicable(
117             reorder_fields,
118             r#"
119 struct Foo { foo: i32, bar: i32 }
120 const test: Foo = $0Foo { foo: 0, bar: 0 };
121 "#,
122         )
123     }
124
125     #[test]
126     fn trivial_empty_fields() {
127         check_assist_not_applicable(
128             reorder_fields,
129             r#"
130 struct Foo {}
131 const test: Foo = $0Foo {};
132 "#,
133         )
134     }
135
136     #[test]
137     fn reorder_struct_fields() {
138         check_assist(
139             reorder_fields,
140             r#"
141 struct Foo { foo: i32, bar: i32 }
142 const test: Foo = $0Foo { bar: 0, foo: 1 };
143 "#,
144             r#"
145 struct Foo { foo: i32, bar: i32 }
146 const test: Foo = Foo { foo: 1, bar: 0 };
147 "#,
148         )
149     }
150     #[test]
151     fn reorder_struct_pattern() {
152         check_assist(
153             reorder_fields,
154             r#"
155 struct Foo { foo: i64, bar: i64, baz: i64 }
156
157 fn f(f: Foo) -> {
158     match f {
159         $0Foo { baz: 0, ref mut bar, .. } => (),
160         _ => ()
161     }
162 }
163 "#,
164             r#"
165 struct Foo { foo: i64, bar: i64, baz: i64 }
166
167 fn f(f: Foo) -> {
168     match f {
169         Foo { ref mut bar, baz: 0, .. } => (),
170         _ => ()
171     }
172 }
173 "#,
174         )
175     }
176
177     #[test]
178     fn reorder_with_extra_field() {
179         check_assist(
180             reorder_fields,
181             r#"
182 struct Foo { foo: String, bar: String }
183
184 impl Foo {
185     fn new() -> Foo {
186         let foo = String::new();
187         $0Foo {
188             bar: foo.clone(),
189             extra: "Extra field",
190             foo,
191         }
192     }
193 }
194 "#,
195             r#"
196 struct Foo { foo: String, bar: String }
197
198 impl Foo {
199     fn new() -> Foo {
200         let foo = String::new();
201         Foo {
202             foo,
203             bar: foo.clone(),
204             extra: "Extra field",
205         }
206     }
207 }
208 "#,
209         )
210     }
211 }