]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/replace_if_let_with_match.rs
aee880625f8d1a34d5813955d417716df4697a45
[rust.git] / crates / ide_assists / src / handlers / replace_if_let_with_match.rs
1 use std::iter;
2
3 use ide_db::{ty_filter::TryEnum, RootDatabase};
4 use syntax::{
5     ast::{
6         self,
7         edit::{AstNodeEdit, IndentLevel},
8         make,
9     },
10     AstNode,
11 };
12
13 use crate::{
14     utils::{does_pat_match_variant, unwrap_trivial_block},
15     AssistContext, AssistId, AssistKind, Assists,
16 };
17
18 // Assist: replace_if_let_with_match
19 //
20 // Replaces `if let` with an else branch with a `match` expression.
21 //
22 // ```
23 // enum Action { Move { distance: u32 }, Stop }
24 //
25 // fn handle(action: Action) {
26 //     $0if let Action::Move { distance } = action {
27 //         foo(distance)
28 //     } else {
29 //         bar()
30 //     }
31 // }
32 // ```
33 // ->
34 // ```
35 // enum Action { Move { distance: u32 }, Stop }
36 //
37 // fn handle(action: Action) {
38 //     match action {
39 //         Action::Move { distance } => foo(distance),
40 //         _ => bar(),
41 //     }
42 // }
43 // ```
44 pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
45     let if_expr: ast::IfExpr = ctx.find_node_at_offset()?;
46     let cond = if_expr.condition()?;
47     let pat = cond.pat()?;
48     let expr = cond.expr()?;
49     let then_block = if_expr.then_branch()?;
50     let else_block = match if_expr.else_branch()? {
51         ast::ElseBranch::Block(it) => it,
52         ast::ElseBranch::IfExpr(_) => return None,
53     };
54
55     let target = if_expr.syntax().text_range();
56     acc.add(
57         AssistId("replace_if_let_with_match", AssistKind::RefactorRewrite),
58         "Replace with match",
59         target,
60         move |edit| {
61             let match_expr = {
62                 let then_arm = {
63                     let then_block = then_block.reset_indent().indent(IndentLevel(1));
64                     let then_expr = unwrap_trivial_block(then_block);
65                     make::match_arm(vec![pat.clone()], then_expr)
66                 };
67                 let else_arm = {
68                     let pattern = ctx
69                         .sema
70                         .type_of_pat(&pat)
71                         .and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty))
72                         .map(|it| {
73                             if does_pat_match_variant(&pat, &it.sad_pattern()) {
74                                 it.happy_pattern()
75                             } else {
76                                 it.sad_pattern()
77                             }
78                         })
79                         .unwrap_or_else(|| make::wildcard_pat().into());
80                     let else_expr = unwrap_trivial_block(else_block);
81                     make::match_arm(vec![pattern], else_expr)
82                 };
83                 let match_expr =
84                     make::expr_match(expr, make::match_arm_list(vec![then_arm, else_arm]));
85                 match_expr.indent(IndentLevel::from_node(if_expr.syntax()))
86             };
87
88             edit.replace_ast::<ast::Expr>(if_expr.into(), match_expr);
89         },
90     )
91 }
92
93 // Assist: replace_match_with_if_let
94 //
95 // Replaces a binary `match` with a wildcard pattern and no guards with an `if let` expression.
96 //
97 // ```
98 // enum Action { Move { distance: u32 }, Stop }
99 //
100 // fn handle(action: Action) {
101 //     $0match action {
102 //         Action::Move { distance } => foo(distance),
103 //         _ => bar(),
104 //     }
105 // }
106 // ```
107 // ->
108 // ```
109 // enum Action { Move { distance: u32 }, Stop }
110 //
111 // fn handle(action: Action) {
112 //     if let Action::Move { distance } = action {
113 //         foo(distance)
114 //     } else {
115 //         bar()
116 //     }
117 // }
118 // ```
119 pub(crate) fn replace_match_with_if_let(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
120     let match_expr: ast::MatchExpr = ctx.find_node_at_offset()?;
121     let mut arms = match_expr.match_arm_list()?.arms();
122     let first_arm = arms.next()?;
123     let second_arm = arms.next()?;
124     if arms.next().is_some() || first_arm.guard().is_some() || second_arm.guard().is_some() {
125         return None;
126     }
127     let condition_expr = match_expr.expr()?;
128     let (if_let_pat, then_expr, else_expr) = if is_pat_wildcard_or_sad(&ctx.sema, &first_arm.pat()?)
129     {
130         (second_arm.pat()?, second_arm.expr()?, first_arm.expr()?)
131     } else if is_pat_wildcard_or_sad(&ctx.sema, &second_arm.pat()?) {
132         (first_arm.pat()?, first_arm.expr()?, second_arm.expr()?)
133     } else {
134         return None;
135     };
136
137     let target = match_expr.syntax().text_range();
138     acc.add(
139         AssistId("replace_match_with_if_let", AssistKind::RefactorRewrite),
140         "Replace with if let",
141         target,
142         move |edit| {
143             let condition = make::condition(condition_expr, Some(if_let_pat));
144             let then_block = match then_expr.reset_indent() {
145                 ast::Expr::BlockExpr(block) => block,
146                 expr => make::block_expr(iter::empty(), Some(expr)),
147             };
148             let else_expr = match else_expr {
149                 ast::Expr::BlockExpr(block)
150                     if block.statements().count() == 0 && block.tail_expr().is_none() =>
151                 {
152                     None
153                 }
154                 ast::Expr::TupleExpr(tuple) if tuple.fields().count() == 0 => None,
155                 expr => Some(expr),
156             };
157             let if_let_expr = make::expr_if(
158                 condition,
159                 then_block,
160                 else_expr.map(|else_expr| {
161                     ast::ElseBranch::Block(make::block_expr(iter::empty(), Some(else_expr)))
162                 }),
163             )
164             .indent(IndentLevel::from_node(match_expr.syntax()));
165
166             edit.replace_ast::<ast::Expr>(match_expr.into(), if_let_expr);
167         },
168     )
169 }
170
171 fn is_pat_wildcard_or_sad(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -> bool {
172     sema.type_of_pat(&pat)
173         .and_then(|ty| TryEnum::from_ty(sema, &ty))
174         .map(|it| it.sad_pattern().syntax().text() == pat.syntax().text())
175         .unwrap_or_else(|| matches!(pat, ast::Pat::WildcardPat(_)))
176 }
177
#[cfg(test)]
// Fixture-based tests for both directions of the rewrite. `$0` marks the cursor position;
// `check_assist` compares the rewritten fixture text, and `check_assist_target` compares
// the highlighted target range.
mod tests {
    use super::*;

    use crate::tests::{check_assist, check_assist_target};

    // ---- `if let` -> `match` (`replace_if_let_with_match`) ----

    // `VariantData` is not `Option`/`Result`, so the else arm falls back to `_`, and
    // single-line branch bodies are unwrapped from their blocks.
    #[test]
    fn test_replace_if_let_with_match_unwraps_simple_expressions() {
        check_assist(
            replace_if_let_with_match,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        if $0let VariantData::Struct(..) = *self {
            true
        } else {
            false
        }
    }
}           "#,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        match *self {
            VariantData::Struct(..) => true,
            _ => false,
        }
    }
}           "#,
        )
    }

    // A then-body spanning several lines keeps its braces inside the match arm.
    #[test]
    fn test_replace_if_let_with_match_doesnt_unwrap_multiline_expressions() {
        check_assist(
            replace_if_let_with_match,
            r#"
fn foo() {
    if $0let VariantData::Struct(..) = a {
        bar(
            123
        )
    } else {
        false
    }
}           "#,
            r#"
fn foo() {
    match a {
        VariantData::Struct(..) => {
            bar(
                123
            )
        }
        _ => false,
    }
}           "#,
        )
    }

    // The assist's target is the whole `if let` expression.
    #[test]
    fn replace_if_let_with_match_target() {
        check_assist_target(
            replace_if_let_with_match,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        if $0let VariantData::Struct(..) = *self {
            true
        } else {
            false
        }
    }
}           "#,
            "if let VariantData::Struct(..) = *self {
            true
        } else {
            false
        }",
        );
    }

    // With an `Option` scrutinee, the else arm names the complementary variant (`None`).
    #[test]
    fn special_case_option() {
        check_assist(
            replace_if_let_with_match,
            r#"
enum Option<T> { Some(T), None }
use Option::*;

fn foo(x: Option<i32>) {
    $0if let Some(x) = x {
        println!("{}", x)
    } else {
        println!("none")
    }
}
           "#,
            r#"
enum Option<T> { Some(T), None }
use Option::*;

fn foo(x: Option<i32>) {
    match x {
        Some(x) => println!("{}", x),
        None => println!("none"),
    }
}
           "#,
        );
    }

    // When the `if let` tests the sad variant, the else arm becomes `Some(_)`.
    #[test]
    fn special_case_inverted_option() {
        check_assist(
            replace_if_let_with_match,
            r#"
enum Option<T> { Some(T), None }
use Option::*;

fn foo(x: Option<i32>) {
    $0if let None = x {
        println!("none")
    } else {
        println!("some")
    }
}
           "#,
            r#"
enum Option<T> { Some(T), None }
use Option::*;

fn foo(x: Option<i32>) {
    match x {
        None => println!("none"),
        Some(_) => println!("some"),
    }
}
           "#,
        );
    }

    // `Result` analogue of `special_case_option`: the else arm becomes `Err(_)`.
    #[test]
    fn special_case_result() {
        check_assist(
            replace_if_let_with_match,
            r#"
enum Result<T, E> { Ok(T), Err(E) }
use Result::*;

fn foo(x: Result<i32, ()>) {
    $0if let Ok(x) = x {
        println!("{}", x)
    } else {
        println!("none")
    }
}
           "#,
            r#"
enum Result<T, E> { Ok(T), Err(E) }
use Result::*;

fn foo(x: Result<i32, ()>) {
    match x {
        Ok(x) => println!("{}", x),
        Err(_) => println!("none"),
    }
}
           "#,
        );
    }

    // `Result` analogue of `special_case_inverted_option`: the else arm becomes `Ok(_)`.
    #[test]
    fn special_case_inverted_result() {
        check_assist(
            replace_if_let_with_match,
            r#"
enum Result<T, E> { Ok(T), Err(E) }
use Result::*;

fn foo(x: Result<i32, ()>) {
    $0if let Err(x) = x {
        println!("{}", x)
    } else {
        println!("ok")
    }
}
           "#,
            r#"
enum Result<T, E> { Ok(T), Err(E) }
use Result::*;

fn foo(x: Result<i32, ()>) {
    match x {
        Err(x) => println!("{}", x),
        Ok(_) => println!("ok"),
    }
}
           "#,
        );
    }

    // The generated match keeps the indentation of an `if let` nested inside another block.
    // The scrutinee's type is presumably unresolved in this fixture, hence the `_` arm.
    #[test]
    fn nested_indent() {
        check_assist(
            replace_if_let_with_match,
            r#"
fn main() {
    if true {
        $0if let Ok(rel_path) = path.strip_prefix(root_path) {
            let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
            Some((*id, rel_path))
        } else {
            None
        }
    }
}
"#,
            r#"
fn main() {
    if true {
        match path.strip_prefix(root_path) {
            Ok(rel_path) => {
                let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
                Some((*id, rel_path))
            }
            _ => None,
        }
    }
}
"#,
        )
    }

    // ---- `match` -> `if let` (`replace_match_with_if_let`) ----

    // The `_` arm becomes the `else` branch.
    #[test]
    fn test_replace_match_with_if_let_unwraps_simple_expressions() {
        check_assist(
            replace_match_with_if_let,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        $0match *self {
            VariantData::Struct(..) => true,
            _ => false,
        }
    }
}           "#,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        if let VariantData::Struct(..) = *self {
            true
        } else {
            false
        }
    }
}           "#,
        )
    }

    // A block arm body is reused as the `if let` block (no extra braces).
    #[test]
    fn test_replace_match_with_if_let_doesnt_unwrap_multiline_expressions() {
        check_assist(
            replace_match_with_if_let,
            r#"
fn foo() {
    $0match a {
        VariantData::Struct(..) => {
            bar(
                123
            )
        }
        _ => false,
    }
}           "#,
            r#"
fn foo() {
    if let VariantData::Struct(..) = a {
        bar(
            123
        )
    } else {
        false
    }
}           "#,
        )
    }

    // The assist's target is the whole `match` expression.
    #[test]
    fn replace_match_with_if_let_target() {
        check_assist_target(
            replace_match_with_if_let,
            r#"
impl VariantData {
    pub fn is_struct(&self) -> bool {
        $0match *self {
            VariantData::Struct(..) => true,
            _ => false,
        }
    }
}           "#,
            r#"match *self {
            VariantData::Struct(..) => true,
            _ => false,
        }"#,
        );
    }

    // A `None` arm is recognised as the sad variant and becomes the `else` branch.
    #[test]
    fn special_case_option_match_to_if_let() {
        check_assist(
            replace_match_with_if_let,
            r#"
enum Option<T> { Some(T), None }
use Option::*;

fn foo(x: Option<i32>) {
    $0match x {
        Some(x) => println!("{}", x),
        None => println!("none"),
    }
}
           "#,
            r#"
enum Option<T> { Some(T), None }
use Option::*;

fn foo(x: Option<i32>) {
    if let Some(x) = x {
        println!("{}", x)
    } else {
        println!("none")
    }
}
           "#,
        );
    }

    // An `Err(_)` arm is recognised as the sad variant and becomes the `else` branch.
    #[test]
    fn special_case_result_match_to_if_let() {
        check_assist(
            replace_match_with_if_let,
            r#"
enum Result<T, E> { Ok(T), Err(E) }
use Result::*;

fn foo(x: Result<i32, ()>) {
    $0match x {
        Ok(x) => println!("{}", x),
        Err(_) => println!("none"),
    }
}
           "#,
            r#"
enum Result<T, E> { Ok(T), Err(E) }
use Result::*;

fn foo(x: Result<i32, ()>) {
    if let Ok(x) = x {
        println!("{}", x)
    } else {
        println!("none")
    }
}
           "#,
        );
    }

    // The generated `if let` keeps the indentation of a `match` nested inside another block.
    #[test]
    fn nested_indent_match_to_if_let() {
        check_assist(
            replace_match_with_if_let,
            r#"
fn main() {
    if true {
        $0match path.strip_prefix(root_path) {
            Ok(rel_path) => {
                let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
                Some((*id, rel_path))
            }
            _ => None,
        }
    }
}
"#,
            r#"
fn main() {
    if true {
        if let Ok(rel_path) = path.strip_prefix(root_path) {
            let rel_path = RelativePathBuf::from_path(rel_path).ok()?;
            Some((*id, rel_path))
        } else {
            None
        }
    }
}
"#,
        )
    }

    // A `_ => ()` arm produces no `else` branch at all.
    #[test]
    fn replace_match_with_if_let_empty_wildcard_expr() {
        check_assist(
            replace_match_with_if_let,
            r#"
fn main() {
    $0match path.strip_prefix(root_path) {
        Ok(rel_path) => println!("{}", rel_path),
        _ => (),
    }
}
"#,
            r#"
fn main() {
    if let Ok(rel_path) = path.strip_prefix(root_path) {
        println!("{}", rel_path)
    }
}
"#,
        )
    }
}