]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/convert_match_to_let_else.rs
Rollup merge of #103996 - SUPERCILEX:docs, r=RalfJung
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / convert_match_to_let_else.rs
1 use ide_db::defs::{Definition, NameRefClass};
2 use syntax::{
3     ast::{self, HasName},
4     ted, AstNode, SyntaxNode,
5 };
6
7 use crate::{
8     assist_context::{AssistContext, Assists},
9     AssistId, AssistKind,
10 };
11
12 // Assist: convert_match_to_let_else
13 //
14 // Converts let statement with match initializer to let-else statement.
15 //
16 // ```
17 // # //- minicore: option
18 // fn foo(opt: Option<()>) {
19 //     let val = $0match opt {
20 //         Some(it) => it,
21 //         None => return,
22 //     };
23 // }
24 // ```
25 // ->
26 // ```
27 // fn foo(opt: Option<()>) {
28 //     let Some(val) = opt else { return };
29 // }
30 // ```
31 pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
32     let let_stmt: ast::LetStmt = ctx.find_node_at_offset()?;
33     let binding = find_binding(let_stmt.pat()?)?;
34
35     let initializer = match let_stmt.initializer() {
36         Some(ast::Expr::MatchExpr(it)) => it,
37         _ => return None,
38     };
39     let initializer_expr = initializer.expr()?;
40
41     let (extracting_arm, diverging_arm) = match find_arms(ctx, &initializer) {
42         Some(it) => it,
43         None => return None,
44     };
45     if extracting_arm.guard().is_some() {
46         cov_mark::hit!(extracting_arm_has_guard);
47         return None;
48     }
49
50     let diverging_arm_expr = diverging_arm.expr()?;
51     let extracting_arm_pat = extracting_arm.pat()?;
52     let extracted_variable = find_extracted_variable(ctx, &extracting_arm)?;
53
54     acc.add(
55         AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite),
56         "Convert match to let-else",
57         let_stmt.syntax().text_range(),
58         |builder| {
59             let extracting_arm_pat = rename_variable(&extracting_arm_pat, extracted_variable, binding);
60             builder.replace(
61                 let_stmt.syntax().text_range(),
62                 format!("let {extracting_arm_pat} = {initializer_expr} else {{ {diverging_arm_expr} }};")
63             )
64         },
65     )
66 }
67
68 // Given a pattern, find the name introduced to the surrounding scope.
69 fn find_binding(pat: ast::Pat) -> Option<ast::IdentPat> {
70     if let ast::Pat::IdentPat(ident) = pat {
71         Some(ident)
72     } else {
73         None
74     }
75 }
76
77 // Given a match expression, find extracting and diverging arms.
78 fn find_arms(
79     ctx: &AssistContext<'_>,
80     match_expr: &ast::MatchExpr,
81 ) -> Option<(ast::MatchArm, ast::MatchArm)> {
82     let arms = match_expr.match_arm_list()?.arms().collect::<Vec<_>>();
83     if arms.len() != 2 {
84         return None;
85     }
86
87     let mut extracting = None;
88     let mut diverging = None;
89     for arm in arms {
90         if ctx.sema.type_of_expr(&arm.expr().unwrap()).unwrap().original().is_never() {
91             diverging = Some(arm);
92         } else {
93             extracting = Some(arm);
94         }
95     }
96
97     match (extracting, diverging) {
98         (Some(extracting), Some(diverging)) => Some((extracting, diverging)),
99         _ => {
100             cov_mark::hit!(non_diverging_match);
101             None
102         }
103     }
104 }
105
106 // Given an extracting arm, find the extracted variable.
107 fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<ast::Name> {
108     match arm.expr()? {
109         ast::Expr::PathExpr(path) => {
110             let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
111             match NameRefClass::classify(&ctx.sema, &name_ref)? {
112                 NameRefClass::Definition(Definition::Local(local)) => {
113                     let source = local.source(ctx.db()).value.left()?;
114                     Some(source.name()?)
115                 }
116                 _ => None,
117             }
118         }
119         _ => {
120             cov_mark::hit!(extracting_arm_is_not_an_identity_expr);
121             return None;
122         }
123     }
124 }
125
126 // Rename `extracted` with `binding` in `pat`.
127 fn rename_variable(pat: &ast::Pat, extracted: ast::Name, binding: ast::IdentPat) -> SyntaxNode {
128     let syntax = pat.syntax().clone_for_update();
129     let extracted_syntax = syntax.covering_element(extracted.syntax().text_range());
130
131     // If `extracted` variable is a record field, we should rename it to `binding`,
132     // otherwise we just need to replace `extracted` with `binding`.
133
134     if let Some(record_pat_field) = extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
135     {
136         if let Some(name_ref) = record_pat_field.field_name() {
137             ted::replace(
138                 record_pat_field.syntax(),
139                 ast::make::record_pat_field(ast::make::name_ref(&name_ref.text()), binding.into())
140                     .syntax()
141                     .clone_for_update(),
142             );
143         }
144     } else {
145         ted::replace(extracted_syntax, binding.syntax().clone_for_update());
146     }
147
148     syntax
149 }
150
151 #[cfg(test)]
152 mod tests {
153     use crate::tests::{check_assist, check_assist_not_applicable};
154
155     use super::*;
156
157     #[test]
158     fn should_not_be_applicable_for_non_diverging_match() {
159         cov_mark::check!(non_diverging_match);
160         check_assist_not_applicable(
161             convert_match_to_let_else,
162             r#"
163 //- minicore: option
164 fn foo(opt: Option<()>) {
165     let val = $0match opt {
166         Some(it) => it,
167         None => (),
168     };
169 }
170 "#,
171         );
172     }
173
174     #[test]
175     fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
176         cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
177         check_assist_not_applicable(
178             convert_match_to_let_else,
179             r#"
180 //- minicore: option
181 fn foo(opt: Option<i32>) {
182     let val = $0match opt {
183         Some(it) => it + 1,
184         None => return,
185     };
186 }
187 "#,
188         );
189
190         check_assist_not_applicable(
191             convert_match_to_let_else,
192             r#"
193 //- minicore: option
194 fn foo(opt: Option<()>) {
195     let val = $0match opt {
196         Some(it) => {
197             let _ = 1 + 1;
198             it
199         },
200         None => return,
201     };
202 }
203 "#,
204         );
205     }
206
207     #[test]
208     fn should_not_be_applicable_if_extracting_arm_has_guard() {
209         cov_mark::check!(extracting_arm_has_guard);
210         check_assist_not_applicable(
211             convert_match_to_let_else,
212             r#"
213 //- minicore: option
214 fn foo(opt: Option<()>) {
215     let val = $0match opt {
216         Some(it) if 2 > 1 => it,
217         None => return,
218     };
219 }
220 "#,
221         );
222     }
223
224     #[test]
225     fn basic_pattern() {
226         check_assist(
227             convert_match_to_let_else,
228             r#"
229 //- minicore: option
230 fn foo(opt: Option<()>) {
231     let val = $0match opt {
232         Some(it) => it,
233         None => return,
234     };
235 }
236     "#,
237             r#"
238 fn foo(opt: Option<()>) {
239     let Some(val) = opt else { return };
240 }
241     "#,
242         );
243     }
244
245     #[test]
246     fn keeps_modifiers() {
247         check_assist(
248             convert_match_to_let_else,
249             r#"
250 //- minicore: option
251 fn foo(opt: Option<()>) {
252     let ref mut val = $0match opt {
253         Some(it) => it,
254         None => return,
255     };
256 }
257     "#,
258             r#"
259 fn foo(opt: Option<()>) {
260     let Some(ref mut val) = opt else { return };
261 }
262     "#,
263         );
264     }
265
266     #[test]
267     fn nested_pattern() {
268         check_assist(
269             convert_match_to_let_else,
270             r#"
271 //- minicore: option, result
272 fn foo(opt: Option<Result<()>>) {
273     let val = $0match opt {
274         Some(Ok(it)) => it,
275         _ => return,
276     };
277 }
278     "#,
279             r#"
280 fn foo(opt: Option<Result<()>>) {
281     let Some(Ok(val)) = opt else { return };
282 }
283     "#,
284         );
285     }
286
287     #[test]
288     fn works_with_any_diverging_block() {
289         check_assist(
290             convert_match_to_let_else,
291             r#"
292 //- minicore: option
293 fn foo(opt: Option<()>) {
294     loop {
295         let val = $0match opt {
296             Some(it) => it,
297             None => break,
298         };
299     }
300 }
301     "#,
302             r#"
303 fn foo(opt: Option<()>) {
304     loop {
305         let Some(val) = opt else { break };
306     }
307 }
308     "#,
309         );
310
311         check_assist(
312             convert_match_to_let_else,
313             r#"
314 //- minicore: option
315 fn foo(opt: Option<()>) {
316     loop {
317         let val = $0match opt {
318             Some(it) => it,
319             None => continue,
320         };
321     }
322 }
323     "#,
324             r#"
325 fn foo(opt: Option<()>) {
326     loop {
327         let Some(val) = opt else { continue };
328     }
329 }
330     "#,
331         );
332
333         check_assist(
334             convert_match_to_let_else,
335             r#"
336 //- minicore: option
337 fn panic() -> ! {}
338
339 fn foo(opt: Option<()>) {
340     loop {
341         let val = $0match opt {
342             Some(it) => it,
343             None => panic(),
344         };
345     }
346 }
347     "#,
348             r#"
349 fn panic() -> ! {}
350
351 fn foo(opt: Option<()>) {
352     loop {
353         let Some(val) = opt else { panic() };
354     }
355 }
356     "#,
357         );
358     }
359
360     #[test]
361     fn struct_pattern() {
362         check_assist(
363             convert_match_to_let_else,
364             r#"
365 //- minicore: option
366 struct Point {
367     x: i32,
368     y: i32,
369 }
370
371 fn foo(opt: Option<Point>) {
372     let val = $0match opt {
373         Some(Point { x: 0, y }) => y,
374         _ => return,
375     };
376 }
377     "#,
378             r#"
379 struct Point {
380     x: i32,
381     y: i32,
382 }
383
384 fn foo(opt: Option<Point>) {
385     let Some(Point { x: 0, y: val }) = opt else { return };
386 }
387     "#,
388         );
389     }
390
391     #[test]
392     fn renames_whole_binding() {
393         check_assist(
394             convert_match_to_let_else,
395             r#"
396 //- minicore: option
397 fn foo(opt: Option<i32>) -> Option<i32> {
398     let val = $0match opt {
399         it @ Some(42) => it,
400         _ => return None,
401     };
402     val
403 }
404     "#,
405             r#"
406 fn foo(opt: Option<i32>) -> Option<i32> {
407     let val @ Some(42) = opt else { return None };
408     val
409 }
410     "#,
411         );
412     }
413 }