]> git.lizzy.rs Git - rust.git/blob - crates/ra_assists/src/handlers/fill_match_arms.rs
Merge #3623
[rust.git] / crates / ra_assists / src / handlers / fill_match_arms.rs
1 //! FIXME: write short doc here
2
3 use std::iter;
4
5 use hir::{Adt, HasSource, Semantics};
6 use ra_ide_db::RootDatabase;
7
8 use crate::{Assist, AssistCtx, AssistId};
9 use ra_syntax::ast::{self, edit::IndentLevel, make, AstNode, NameOwner};
10
11 use ast::{MatchArm, Pat};
12
13 // Assist: fill_match_arms
14 //
15 // Adds missing clauses to a `match` expression.
16 //
17 // ```
18 // enum Action { Move { distance: u32 }, Stop }
19 //
20 // fn handle(action: Action) {
21 //     match action {
22 //         <|>
23 //     }
24 // }
25 // ```
26 // ->
27 // ```
28 // enum Action { Move { distance: u32 }, Stop }
29 //
30 // fn handle(action: Action) {
31 //     match action {
32 //         Action::Move { distance } => (),
33 //         Action::Stop => (),
34 //     }
35 // }
36 // ```
37 pub(crate) fn fill_match_arms(ctx: AssistCtx) -> Option<Assist> {
38     let match_expr = ctx.find_node_at_offset::<ast::MatchExpr>()?;
39     let match_arm_list = match_expr.match_arm_list()?;
40
41     let expr = match_expr.expr()?;
42     let enum_def = resolve_enum_def(&ctx.sema, &expr)?;
43     let module = ctx.sema.scope(expr.syntax()).module()?;
44
45     let variants = enum_def.variants(ctx.db);
46     if variants.is_empty() {
47         return None;
48     }
49
50     let mut arms: Vec<MatchArm> = match_arm_list.arms().collect();
51     if arms.len() == 1 {
52         if let Some(Pat::PlaceholderPat(..)) = arms[0].pat() {
53             arms.clear();
54         }
55     }
56
57     let db = ctx.db;
58     let missing_arms: Vec<MatchArm> = variants
59         .into_iter()
60         .filter_map(|variant| build_pat(db, module, variant))
61         .filter(|variant_pat| is_variant_missing(&mut arms, variant_pat))
62         .map(|pat| make::match_arm(iter::once(pat), make::expr_unit()))
63         .collect();
64
65     if missing_arms.is_empty() {
66         return None;
67     }
68
69     ctx.add_assist(AssistId("fill_match_arms"), "Fill match arms", |edit| {
70         arms.extend(missing_arms);
71
72         let indent_level = IndentLevel::from_node(match_arm_list.syntax());
73         let new_arm_list = indent_level.increase_indent(make::match_arm_list(arms));
74
75         edit.target(match_expr.syntax().text_range());
76         edit.set_cursor(expr.syntax().text_range().start());
77         edit.replace_ast(match_arm_list, new_arm_list);
78     })
79 }
80
81 fn is_variant_missing(existing_arms: &mut Vec<MatchArm>, var: &Pat) -> bool {
82     existing_arms.iter().filter_map(|arm| arm.pat()).all(|pat| {
83         // Special casee OrPat as separate top-level pats
84         let top_level_pats: Vec<Pat> = match pat {
85             Pat::OrPat(pats) => pats.pats().collect::<Vec<_>>(),
86             _ => vec![pat],
87         };
88
89         !top_level_pats.iter().any(|pat| does_pat_match_variant(pat, var))
90     })
91 }
92
93 fn does_pat_match_variant(pat: &Pat, var: &Pat) -> bool {
94     let pat_head = pat.syntax().first_child().map(|node| node.text());
95     let var_head = var.syntax().first_child().map(|node| node.text());
96
97     pat_head == var_head
98 }
99
100 fn resolve_enum_def(sema: &Semantics<RootDatabase>, expr: &ast::Expr) -> Option<hir::Enum> {
101     sema.type_of_expr(&expr)?.autoderef(sema.db).find_map(|ty| match ty.as_adt() {
102         Some(Adt::Enum(e)) => Some(e),
103         _ => None,
104     })
105 }
106
107 fn build_pat(db: &RootDatabase, module: hir::Module, var: hir::EnumVariant) -> Option<ast::Pat> {
108     let path = crate::ast_transform::path_to_ast(module.find_use_path(db, var.into())?);
109
110     // FIXME: use HIR for this; it doesn't currently expose struct vs. tuple vs. unit variants though
111     let pat: ast::Pat = match var.source(db).value.kind() {
112         ast::StructKind::Tuple(field_list) => {
113             let pats =
114                 iter::repeat(make::placeholder_pat().into()).take(field_list.fields().count());
115             make::tuple_struct_pat(path, pats).into()
116         }
117         ast::StructKind::Record(field_list) => {
118             let pats = field_list.fields().map(|f| make::bind_pat(f.name().unwrap()).into());
119             make::record_pat(path, pats).into()
120         }
121         ast::StructKind::Unit => make::path_pat(path),
122     };
123
124     Some(pat)
125 }
126
127 #[cfg(test)]
128 mod tests {
129     use crate::helpers::{check_assist, check_assist_not_applicable, check_assist_target};
130
131     use super::fill_match_arms;
132
133     #[test]
134     fn all_match_arms_provided() {
135         check_assist_not_applicable(
136             fill_match_arms,
137             r#"
138             enum A {
139                 As,
140                 Bs{x:i32, y:Option<i32>},
141                 Cs(i32, Option<i32>),
142             }
143             fn main() {
144                 match A::As<|> {
145                     A::As,
146                     A::Bs{x,y:Some(_)} => (),
147                     A::Cs(_, Some(_)) => (),
148                 }
149             }
150             "#,
151         );
152     }
153
154     #[test]
155     fn partial_fill_record_tuple() {
156         check_assist(
157             fill_match_arms,
158             r#"
159             enum A {
160                 As,
161                 Bs{x:i32, y:Option<i32>},
162                 Cs(i32, Option<i32>),
163             }
164             fn main() {
165                 match A::As<|> {
166                     A::Bs{x,y:Some(_)} => (),
167                     A::Cs(_, Some(_)) => (),
168                 }
169             }
170             "#,
171             r#"
172             enum A {
173                 As,
174                 Bs{x:i32, y:Option<i32>},
175                 Cs(i32, Option<i32>),
176             }
177             fn main() {
178                 match <|>A::As {
179                     A::Bs{x,y:Some(_)} => (),
180                     A::Cs(_, Some(_)) => (),
181                     A::As => (),
182                 }
183             }
184             "#,
185         );
186     }
187
188     #[test]
189     fn partial_fill_or_pat() {
190         check_assist(
191             fill_match_arms,
192             r#"
193             enum A {
194                 As,
195                 Bs,
196                 Cs(Option<i32>),
197             }
198             fn main() {
199                 match A::As<|> {
200                     A::Cs(_) | A::Bs => (),
201                 }
202             }
203             "#,
204             r#"
205             enum A {
206                 As,
207                 Bs,
208                 Cs(Option<i32>),
209             }
210             fn main() {
211                 match <|>A::As {
212                     A::Cs(_) | A::Bs => (),
213                     A::As => (),
214                 }
215             }
216             "#,
217         );
218     }
219
220     #[test]
221     fn partial_fill() {
222         check_assist(
223             fill_match_arms,
224             r#"
225             enum A {
226                 As,
227                 Bs,
228                 Cs,
229                 Ds(String),
230                 Es(B),
231             }
232             enum B {
233                 Xs,
234                 Ys,
235             }
236             fn main() {
237                 match A::As<|> {
238                     A::Bs if 0 < 1 => (),
239                     A::Ds(_value) => (),
240                     A::Es(B::Xs) => (),
241                 }
242             }
243             "#,
244             r#"
245             enum A {
246                 As,
247                 Bs,
248                 Cs,
249                 Ds(String),
250                 Es(B),
251             }
252             enum B {
253                 Xs,
254                 Ys,
255             }
256             fn main() {
257                 match <|>A::As {
258                     A::Bs if 0 < 1 => (),
259                     A::Ds(_value) => (),
260                     A::Es(B::Xs) => (),
261                     A::As => (),
262                     A::Cs => (),
263                 }
264             }
265             "#,
266         );
267     }
268
269     #[test]
270     fn fill_match_arms_empty_body() {
271         check_assist(
272             fill_match_arms,
273             r#"
274             enum A {
275                 As,
276                 Bs,
277                 Cs(String),
278                 Ds(String, String),
279                 Es{ x: usize, y: usize }
280             }
281
282             fn main() {
283                 let a = A::As;
284                 match a<|> {}
285             }
286             "#,
287             r#"
288             enum A {
289                 As,
290                 Bs,
291                 Cs(String),
292                 Ds(String, String),
293                 Es{ x: usize, y: usize }
294             }
295
296             fn main() {
297                 let a = A::As;
298                 match <|>a {
299                     A::As => (),
300                     A::Bs => (),
301                     A::Cs(_) => (),
302                     A::Ds(_, _) => (),
303                     A::Es { x, y } => (),
304                 }
305             }
306             "#,
307         );
308     }
309
310     #[test]
311     fn test_fill_match_arm_refs() {
312         check_assist(
313             fill_match_arms,
314             r#"
315             enum A {
316                 As,
317             }
318
319             fn foo(a: &A) {
320                 match a<|> {
321                 }
322             }
323             "#,
324             r#"
325             enum A {
326                 As,
327             }
328
329             fn foo(a: &A) {
330                 match <|>a {
331                     A::As => (),
332                 }
333             }
334             "#,
335         );
336
337         check_assist(
338             fill_match_arms,
339             r#"
340             enum A {
341                 Es{ x: usize, y: usize }
342             }
343
344             fn foo(a: &mut A) {
345                 match a<|> {
346                 }
347             }
348             "#,
349             r#"
350             enum A {
351                 Es{ x: usize, y: usize }
352             }
353
354             fn foo(a: &mut A) {
355                 match <|>a {
356                     A::Es { x, y } => (),
357                 }
358             }
359             "#,
360         );
361     }
362
363     #[test]
364     fn fill_match_arms_target() {
365         check_assist_target(
366             fill_match_arms,
367             r#"
368             enum E { X, Y }
369
370             fn main() {
371                 match E::X<|> {}
372             }
373             "#,
374             "match E::X {}",
375         );
376     }
377
378     #[test]
379     fn fill_match_arms_trivial_arm() {
380         check_assist(
381             fill_match_arms,
382             r#"
383             enum E { X, Y }
384
385             fn main() {
386                 match E::X {
387                     <|>_ => {},
388                 }
389             }
390             "#,
391             r#"
392             enum E { X, Y }
393
394             fn main() {
395                 match <|>E::X {
396                     E::X => (),
397                     E::Y => (),
398                 }
399             }
400             "#,
401         );
402     }
403
404     #[test]
405     fn fill_match_arms_qualifies_path() {
406         check_assist(
407             fill_match_arms,
408             r#"
409             mod foo { pub enum E { X, Y } }
410             use foo::E::X;
411
412             fn main() {
413                 match X {
414                     <|>
415                 }
416             }
417             "#,
418             r#"
419             mod foo { pub enum E { X, Y } }
420             use foo::E::X;
421
422             fn main() {
423                 match <|>X {
424                     X => (),
425                     foo::E::Y => (),
426                 }
427             }
428             "#,
429         );
430     }
431 }