]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/pull_assignment_up.rs
Rollup merge of #103996 - SUPERCILEX:docs, r=RalfJung
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / pull_assignment_up.rs
1 use syntax::{
2     ast::{self, make},
3     ted, AstNode,
4 };
5
6 use crate::{
7     assist_context::{AssistContext, Assists},
8     AssistId, AssistKind,
9 };
10
11 // Assist: pull_assignment_up
12 //
13 // Extracts variable assignment to outside an if or match statement.
14 //
15 // ```
16 // fn main() {
17 //     let mut foo = 6;
18 //
19 //     if true {
20 //         $0foo = 5;
21 //     } else {
22 //         foo = 4;
23 //     }
24 // }
25 // ```
26 // ->
27 // ```
28 // fn main() {
29 //     let mut foo = 6;
30 //
31 //     foo = if true {
32 //         5
33 //     } else {
34 //         4
35 //     };
36 // }
37 // ```
38 pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
39     let assign_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
40
41     let op_kind = assign_expr.op_kind()?;
42     if op_kind != (ast::BinaryOp::Assignment { op: None }) {
43         cov_mark::hit!(test_cant_pull_non_assignments);
44         return None;
45     }
46
47     let mut collector = AssignmentsCollector {
48         sema: &ctx.sema,
49         common_lhs: assign_expr.lhs()?,
50         assignments: Vec::new(),
51     };
52
53     let tgt: ast::Expr = if let Some(if_expr) = ctx.find_node_at_offset::<ast::IfExpr>() {
54         collector.collect_if(&if_expr)?;
55         if_expr.into()
56     } else if let Some(match_expr) = ctx.find_node_at_offset::<ast::MatchExpr>() {
57         collector.collect_match(&match_expr)?;
58         match_expr.into()
59     } else {
60         return None;
61     };
62
63     if let Some(parent) = tgt.syntax().parent() {
64         if matches!(parent.kind(), syntax::SyntaxKind::BIN_EXPR | syntax::SyntaxKind::LET_STMT) {
65             return None;
66         }
67     }
68
69     acc.add(
70         AssistId("pull_assignment_up", AssistKind::RefactorExtract),
71         "Pull assignment up",
72         tgt.syntax().text_range(),
73         move |edit| {
74             let assignments: Vec<_> = collector
75                 .assignments
76                 .into_iter()
77                 .map(|(stmt, rhs)| (edit.make_mut(stmt), rhs.clone_for_update()))
78                 .collect();
79
80             let tgt = edit.make_mut(tgt);
81
82             for (stmt, rhs) in assignments {
83                 let mut stmt = stmt.syntax().clone();
84                 if let Some(parent) = stmt.parent() {
85                     if ast::ExprStmt::cast(parent.clone()).is_some() {
86                         stmt = parent.clone();
87                     }
88                 }
89                 ted::replace(stmt, rhs.syntax());
90             }
91             let assign_expr = make::expr_assignment(collector.common_lhs, tgt.clone());
92             let assign_stmt = make::expr_stmt(assign_expr);
93
94             ted::replace(tgt.syntax(), assign_stmt.syntax().clone_for_update());
95         },
96     )
97 }
98
99 struct AssignmentsCollector<'a> {
100     sema: &'a hir::Semantics<'a, ide_db::RootDatabase>,
101     common_lhs: ast::Expr,
102     assignments: Vec<(ast::BinExpr, ast::Expr)>,
103 }
104
105 impl<'a> AssignmentsCollector<'a> {
106     fn collect_match(&mut self, match_expr: &ast::MatchExpr) -> Option<()> {
107         for arm in match_expr.match_arm_list()?.arms() {
108             match arm.expr()? {
109                 ast::Expr::BlockExpr(block) => self.collect_block(&block)?,
110                 ast::Expr::BinExpr(expr) => self.collect_expr(&expr)?,
111                 _ => return None,
112             }
113         }
114
115         Some(())
116     }
117     fn collect_if(&mut self, if_expr: &ast::IfExpr) -> Option<()> {
118         let then_branch = if_expr.then_branch()?;
119         self.collect_block(&then_branch)?;
120
121         match if_expr.else_branch()? {
122             ast::ElseBranch::Block(block) => self.collect_block(&block),
123             ast::ElseBranch::IfExpr(expr) => {
124                 cov_mark::hit!(test_pull_assignment_up_chained_if);
125                 self.collect_if(&expr)
126             }
127         }
128     }
129     fn collect_block(&mut self, block: &ast::BlockExpr) -> Option<()> {
130         let last_expr = block.tail_expr().or_else(|| match block.statements().last()? {
131             ast::Stmt::ExprStmt(stmt) => stmt.expr(),
132             ast::Stmt::Item(_) | ast::Stmt::LetStmt(_) => None,
133         })?;
134
135         if let ast::Expr::BinExpr(expr) = last_expr {
136             return self.collect_expr(&expr);
137         }
138
139         None
140     }
141
142     fn collect_expr(&mut self, expr: &ast::BinExpr) -> Option<()> {
143         if expr.op_kind()? == (ast::BinaryOp::Assignment { op: None })
144             && is_equivalent(self.sema, &expr.lhs()?, &self.common_lhs)
145         {
146             self.assignments.push((expr.clone(), expr.rhs()?));
147             return Some(());
148         }
149         None
150     }
151 }
152
153 fn is_equivalent(
154     sema: &hir::Semantics<'_, ide_db::RootDatabase>,
155     expr0: &ast::Expr,
156     expr1: &ast::Expr,
157 ) -> bool {
158     match (expr0, expr1) {
159         (ast::Expr::FieldExpr(field_expr0), ast::Expr::FieldExpr(field_expr1)) => {
160             cov_mark::hit!(test_pull_assignment_up_field_assignment);
161             sema.resolve_field(field_expr0) == sema.resolve_field(field_expr1)
162         }
163         (ast::Expr::PathExpr(path0), ast::Expr::PathExpr(path1)) => {
164             let path0 = path0.path();
165             let path1 = path1.path();
166             if let (Some(path0), Some(path1)) = (path0, path1) {
167                 sema.resolve_path(&path0) == sema.resolve_path(&path1)
168             } else {
169                 false
170             }
171         }
172         (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1))
173             if prefix0.op_kind() == Some(ast::UnaryOp::Deref)
174                 && prefix1.op_kind() == Some(ast::UnaryOp::Deref) =>
175         {
176             cov_mark::hit!(test_pull_assignment_up_deref);
177             if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) {
178                 is_equivalent(sema, &prefix0, &prefix1)
179             } else {
180                 false
181             }
182         }
183         _ => false,
184     }
185 }
186
187 #[cfg(test)]
188 mod tests {
189     use super::*;
190
191     use crate::tests::{check_assist, check_assist_not_applicable};
192
193     #[test]
194     fn test_pull_assignment_up_if() {
195         check_assist(
196             pull_assignment_up,
197             r#"
198 fn foo() {
199     let mut a = 1;
200
201     if true {
202         $0a = 2;
203     } else {
204         a = 3;
205     }
206 }"#,
207             r#"
208 fn foo() {
209     let mut a = 1;
210
211     a = if true {
212         2
213     } else {
214         3
215     };
216 }"#,
217         );
218     }
219
220     #[test]
221     fn test_pull_assignment_up_match() {
222         check_assist(
223             pull_assignment_up,
224             r#"
225 fn foo() {
226     let mut a = 1;
227
228     match 1 {
229         1 => {
230             $0a = 2;
231         },
232         2 => {
233             a = 3;
234         },
235         3 => {
236             a = 4;
237         }
238     }
239 }"#,
240             r#"
241 fn foo() {
242     let mut a = 1;
243
244     a = match 1 {
245         1 => {
246             2
247         },
248         2 => {
249             3
250         },
251         3 => {
252             4
253         }
254     };
255 }"#,
256         );
257     }
258
259     #[test]
260     fn test_pull_assignment_up_assignment_expressions() {
261         check_assist(
262             pull_assignment_up,
263             r#"
264 fn foo() {
265     let mut a = 1;
266
267     match 1 {
268         1 => { $0a = 2; },
269         2 => a = 3,
270         3 => {
271             a = 4
272         }
273     }
274 }"#,
275             r#"
276 fn foo() {
277     let mut a = 1;
278
279     a = match 1 {
280         1 => { 2 },
281         2 => 3,
282         3 => {
283             4
284         }
285     };
286 }"#,
287         );
288     }
289
290     #[test]
291     fn test_pull_assignment_up_not_last_not_applicable() {
292         check_assist_not_applicable(
293             pull_assignment_up,
294             r#"
295 fn foo() {
296     let mut a = 1;
297
298     if true {
299         $0a = 2;
300         b = a;
301     } else {
302         a = 3;
303     }
304 }"#,
305         )
306     }
307
308     #[test]
309     fn test_pull_assignment_up_chained_if() {
310         cov_mark::check!(test_pull_assignment_up_chained_if);
311         check_assist(
312             pull_assignment_up,
313             r#"
314 fn foo() {
315     let mut a = 1;
316
317     if true {
318         $0a = 2;
319     } else if false {
320         a = 3;
321     } else {
322         a = 4;
323     }
324 }"#,
325             r#"
326 fn foo() {
327     let mut a = 1;
328
329     a = if true {
330         2
331     } else if false {
332         3
333     } else {
334         4
335     };
336 }"#,
337         );
338     }
339
340     #[test]
341     fn test_pull_assignment_up_retains_stmts() {
342         check_assist(
343             pull_assignment_up,
344             r#"
345 fn foo() {
346     let mut a = 1;
347
348     if true {
349         let b = 2;
350         $0a = 2;
351     } else {
352         let b = 3;
353         a = 3;
354     }
355 }"#,
356             r#"
357 fn foo() {
358     let mut a = 1;
359
360     a = if true {
361         let b = 2;
362         2
363     } else {
364         let b = 3;
365         3
366     };
367 }"#,
368         )
369     }
370
371     #[test]
372     fn pull_assignment_up_let_stmt_not_applicable() {
373         check_assist_not_applicable(
374             pull_assignment_up,
375             r#"
376 fn foo() {
377     let mut a = 1;
378
379     let b = if true {
380         $0a = 2
381     } else {
382         a = 3
383     };
384 }"#,
385         )
386     }
387
388     #[test]
389     fn pull_assignment_up_if_missing_assigment_not_applicable() {
390         check_assist_not_applicable(
391             pull_assignment_up,
392             r#"
393 fn foo() {
394     let mut a = 1;
395
396     if true {
397         $0a = 2;
398     } else {}
399 }"#,
400         )
401     }
402
403     #[test]
404     fn pull_assignment_up_match_missing_assigment_not_applicable() {
405         check_assist_not_applicable(
406             pull_assignment_up,
407             r#"
408 fn foo() {
409     let mut a = 1;
410
411     match 1 {
412         1 => {
413             $0a = 2;
414         },
415         2 => {
416             a = 3;
417         },
418         3 => {},
419     }
420 }"#,
421         )
422     }
423
424     #[test]
425     fn test_pull_assignment_up_field_assignment() {
426         cov_mark::check!(test_pull_assignment_up_field_assignment);
427         check_assist(
428             pull_assignment_up,
429             r#"
430 struct A(usize);
431
432 fn foo() {
433     let mut a = A(1);
434
435     if true {
436         $0a.0 = 2;
437     } else {
438         a.0 = 3;
439     }
440 }"#,
441             r#"
442 struct A(usize);
443
444 fn foo() {
445     let mut a = A(1);
446
447     a.0 = if true {
448         2
449     } else {
450         3
451     };
452 }"#,
453         )
454     }
455
456     #[test]
457     fn test_pull_assignment_up_deref() {
458         cov_mark::check!(test_pull_assignment_up_deref);
459         check_assist(
460             pull_assignment_up,
461             r#"
462 fn foo() {
463     let mut a = 1;
464     let b = &mut a;
465
466     if true {
467         $0*b = 2;
468     } else {
469         *b = 3;
470     }
471 }
472 "#,
473             r#"
474 fn foo() {
475     let mut a = 1;
476     let b = &mut a;
477
478     *b = if true {
479         2
480     } else {
481         3
482     };
483 }
484 "#,
485         )
486     }
487
488     #[test]
489     fn test_cant_pull_non_assignments() {
490         cov_mark::check!(test_cant_pull_non_assignments);
491         check_assist_not_applicable(
492             pull_assignment_up,
493             r#"
494 fn foo() {
495     let mut a = 1;
496     let b = &mut a;
497
498     if true {
499         $0*b + 2;
500     } else {
501         *b + 3;
502     }
503 }
504 "#,
505         )
506     }
507 }