]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/merge_match_arms.rs
compiles, but doesn't work yet
[rust.git] / crates / ide_assists / src / handlers / merge_match_arms.rs
1 use std::iter::successors;
2
3 use hir::{TypeInfo, HirDisplay};
4 use itertools::Itertools;
5 use syntax::{
6     algo::neighbor,
7     ast::{self, AstNode},
8     Direction,
9 };
10
11 use crate::{AssistContext, AssistId, AssistKind, Assists, TextRange};
12
13 // Assist: merge_match_arms
14 //
15 // Merges the current match arm with the following if their bodies are identical.
16 //
17 // ```
18 // enum Action { Move { distance: u32 }, Stop }
19 //
20 // fn handle(action: Action) {
21 //     match action {
22 //         $0Action::Move(..) => foo(),
23 //         Action::Stop => foo(),
24 //     }
25 // }
26 // ```
27 // ->
28 // ```
29 // enum Action { Move { distance: u32 }, Stop }
30 //
31 // fn handle(action: Action) {
32 //     match action {
33 //         Action::Move(..) | Action::Stop => foo(),
34 //     }
35 // }
36 // ```
37 pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
38     let current_arm = ctx.find_node_at_offset::<ast::MatchArm>()?;
39     // Don't try to handle arms with guards for now - can add support for this later
40     if current_arm.guard().is_some() {
41         return None;
42     }
43     let current_expr = current_arm.expr()?;
44     let current_text_range = current_arm.syntax().text_range();
45     let current_arm_types = get_arm_types(&ctx, &current_arm);
46
47     // We check if the following match arms match this one. We could, but don't,
48     // compare to the previous match arm as well.
49     let arms_to_merge = successors(Some(current_arm), |it| neighbor(it, Direction::Next))
50         .take_while(|arm| match arm.expr() {
51             Some(expr) if arm.guard().is_none() && arm.pat().is_some() => {
52                 let same_text = expr.syntax().text() == current_expr.syntax().text();
53                 if !same_text {
54                     return false;
55                 }
56
57                 let arm_types = get_arm_types(&ctx, &arm);
58                 for i in 0..arm_types.len() {
59                     let other_arm_type = &arm_types[i].as_ref();
60                     let current_arm_type = current_arm_types[i].as_ref();
61                     if other_arm_type.is_some() && current_arm_type.is_some() {
62                         let other_arm_type = other_arm_type.unwrap().original.clone().as_adt();
63                         let current_arm_type = current_arm_type.unwrap().original.clone().as_adt();
64                         println!("Same types!");
65                         println!("{:?}", other_arm_type);
66                         println!("{:?}", current_arm_type);
67                         return other_arm_type == current_arm_type;
68                     }
69                 }
70
71                 true
72             }
73             _ => false,
74         })
75         .collect::<Vec<_>>();
76
77     if arms_to_merge.len() <= 1 {
78         return None;
79     }
80
81     acc.add(
82         AssistId("merge_match_arms", AssistKind::RefactorRewrite),
83         "Merge match arms",
84         current_text_range,
85         |edit| {
86             let pats = if arms_to_merge.iter().any(contains_placeholder) {
87                 "_".into()
88             } else {
89                 arms_to_merge
90                     .iter()
91                     .filter_map(ast::MatchArm::pat)
92                     .map(|x| x.syntax().to_string())
93                     .collect::<Vec<String>>()
94                     .join(" | ")
95             };
96
97             let arm = format!("{} => {},", pats, current_expr.syntax().text());
98
99             if let [first, .., last] = &*arms_to_merge {
100                 let start = first.syntax().text_range().start();
101                 let end = last.syntax().text_range().end();
102
103                 edit.replace(TextRange::new(start, end), arm);
104             }
105         },
106     )
107 }
108
109 fn contains_placeholder(a: &ast::MatchArm) -> bool {
110     matches!(a.pat(), Some(ast::Pat::WildcardPat(..)))
111 }
112
113 fn get_arm_types(ctx: &AssistContext, arm: &ast::MatchArm) -> Vec<Option<TypeInfo>> {
114     match arm.pat() {
115         Some(ast::Pat::TupleStructPat(tp)) => tp
116             .fields()
117             .into_iter()
118             .map(|field| {
119                 let pat_type = ctx.sema.type_of_pat(&field);
120                 pat_type
121             })
122             .collect_vec(),
123         _ => Vec::new(),
124     }
125 }
126
127 #[cfg(test)]
128 mod tests {
129     use crate::tests::{check_assist, check_assist_not_applicable};
130
131     use super::*;
132
133     #[test]
134     fn merge_match_arms_single_patterns() {
135         check_assist(
136             merge_match_arms,
137             r#"
138 #[derive(Debug)]
139 enum X { A, B, C }
140
141 fn main() {
142     let x = X::A;
143     let y = match x {
144         X::A => { 1i32$0 }
145         X::B => { 1i32 }
146         X::C => { 2i32 }
147     }
148 }
149 "#,
150             r#"
151 #[derive(Debug)]
152 enum X { A, B, C }
153
154 fn main() {
155     let x = X::A;
156     let y = match x {
157         X::A | X::B => { 1i32 },
158         X::C => { 2i32 }
159     }
160 }
161 "#,
162         );
163     }
164
165     #[test]
166     fn merge_match_arms_multiple_patterns() {
167         check_assist(
168             merge_match_arms,
169             r#"
170 #[derive(Debug)]
171 enum X { A, B, C, D, E }
172
173 fn main() {
174     let x = X::A;
175     let y = match x {
176         X::A | X::B => {$0 1i32 },
177         X::C | X::D => { 1i32 },
178         X::E => { 2i32 },
179     }
180 }
181 "#,
182             r#"
183 #[derive(Debug)]
184 enum X { A, B, C, D, E }
185
186 fn main() {
187     let x = X::A;
188     let y = match x {
189         X::A | X::B | X::C | X::D => { 1i32 },
190         X::E => { 2i32 },
191     }
192 }
193 "#,
194         );
195     }
196
197     #[test]
198     fn merge_match_arms_placeholder_pattern() {
199         check_assist(
200             merge_match_arms,
201             r#"
202 #[derive(Debug)]
203 enum X { A, B, C, D, E }
204
205 fn main() {
206     let x = X::A;
207     let y = match x {
208         X::A => { 1i32 },
209         X::B => { 2i$032 },
210         _ => { 2i32 }
211     }
212 }
213 "#,
214             r#"
215 #[derive(Debug)]
216 enum X { A, B, C, D, E }
217
218 fn main() {
219     let x = X::A;
220     let y = match x {
221         X::A => { 1i32 },
222         _ => { 2i32 },
223     }
224 }
225 "#,
226         );
227     }
228
229     #[test]
230     fn merges_all_subsequent_arms() {
231         check_assist(
232             merge_match_arms,
233             r#"
234 enum X { A, B, C, D, E }
235
236 fn main() {
237     match X::A {
238         X::A$0 => 92,
239         X::B => 92,
240         X::C => 92,
241         X::D => 62,
242         _ => panic!(),
243     }
244 }
245 "#,
246             r#"
247 enum X { A, B, C, D, E }
248
249 fn main() {
250     match X::A {
251         X::A | X::B | X::C => 92,
252         X::D => 62,
253         _ => panic!(),
254     }
255 }
256 "#,
257         )
258     }
259
260     #[test]
261     fn merge_match_arms_rejects_guards() {
262         check_assist_not_applicable(
263             merge_match_arms,
264             r#"
265 #[derive(Debug)]
266 enum X {
267     A(i32),
268     B,
269     C
270 }
271
272 fn main() {
273     let x = X::A;
274     let y = match x {
275         X::A(a) if a > 5 => { $01i32 },
276         X::B => { 1i32 },
277         X::C => { 2i32 }
278     }
279 }
280 "#,
281         );
282     }
283
284     #[test]
285     fn merge_match_arms_different_type() {
286         check_assist_not_applicable(
287             merge_match_arms,
288             r#"
289 fn func() {
290     match Result::<i32, f32>::Ok(0) {
291         Ok(x) => $0x.to_string(),
292         Err(x) => x.to_string()
293     };
294 }
295 "#,
296         );
297     }
298 }
299
300 // fn func() {
301 //     match Result::<i32, f32>::Ok(0) {
302 //         Ok(x) => x.to_string(),
303 //         Err(x) => x.to_string()
304 //     };
305 // }