]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/pull_assignment_up.rs
Merge #8795
[rust.git] / crates / ide_assists / src / handlers / pull_assignment_up.rs
1 use syntax::{
2     ast::{self, make},
3     ted, AstNode,
4 };
5
6 use crate::{
7     assist_context::{AssistContext, Assists},
8     AssistId, AssistKind,
9 };
10
11 // Assist: pull_assignment_up
12 //
13 // Extracts variable assignment to outside an if or match statement.
14 //
15 // ```
16 // fn main() {
17 //     let mut foo = 6;
18 //
19 //     if true {
20 //         $0foo = 5;
21 //     } else {
22 //         foo = 4;
23 //     }
24 // }
25 // ```
26 // ->
27 // ```
28 // fn main() {
29 //     let mut foo = 6;
30 //
31 //     foo = if true {
32 //         5
33 //     } else {
34 //         4
35 //     };
36 // }
37 // ```
38 pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
39     let assign_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
40
41     let op_kind = assign_expr.op_kind()?;
42     if op_kind != ast::BinOp::Assignment {
43         cov_mark::hit!(test_cant_pull_non_assignments);
44         return None;
45     }
46
47     let mut collector = AssignmentsCollector {
48         sema: &ctx.sema,
49         common_lhs: assign_expr.lhs()?,
50         assignments: Vec::new(),
51     };
52
53     let tgt: ast::Expr = if let Some(if_expr) = ctx.find_node_at_offset::<ast::IfExpr>() {
54         collector.collect_if(&if_expr)?;
55         if_expr.into()
56     } else if let Some(match_expr) = ctx.find_node_at_offset::<ast::MatchExpr>() {
57         collector.collect_match(&match_expr)?;
58         match_expr.into()
59     } else {
60         return None;
61     };
62
63     if let Some(parent) = tgt.syntax().parent() {
64         if matches!(parent.kind(), syntax::SyntaxKind::BIN_EXPR | syntax::SyntaxKind::LET_STMT) {
65             return None;
66         }
67     }
68
69     acc.add(
70         AssistId("pull_assignment_up", AssistKind::RefactorExtract),
71         "Pull assignment up",
72         tgt.syntax().text_range(),
73         move |edit| {
74             let assignments: Vec<_> = collector
75                 .assignments
76                 .into_iter()
77                 .map(|(stmt, rhs)| (edit.make_mut(stmt), rhs.clone_for_update()))
78                 .collect();
79
80             let tgt = edit.make_mut(tgt);
81
82             for (stmt, rhs) in assignments {
83                 let mut stmt = stmt.syntax().clone();
84                 if let Some(parent) = stmt.parent() {
85                     if ast::ExprStmt::cast(parent.clone()).is_some() {
86                         stmt = parent.clone();
87                     }
88                 }
89                 ted::replace(stmt, rhs.syntax());
90             }
91             let assign_expr = make::expr_assignment(collector.common_lhs, tgt.clone());
92             let assign_stmt = make::expr_stmt(assign_expr);
93
94             ted::replace(tgt.syntax(), assign_stmt.syntax().clone_for_update());
95         },
96     )
97 }
98
99 struct AssignmentsCollector<'a> {
100     sema: &'a hir::Semantics<'a, ide_db::RootDatabase>,
101     common_lhs: ast::Expr,
102     assignments: Vec<(ast::BinExpr, ast::Expr)>,
103 }
104
105 impl<'a> AssignmentsCollector<'a> {
106     fn collect_match(&mut self, match_expr: &ast::MatchExpr) -> Option<()> {
107         for arm in match_expr.match_arm_list()?.arms() {
108             match arm.expr()? {
109                 ast::Expr::BlockExpr(block) => self.collect_block(&block)?,
110                 ast::Expr::BinExpr(expr) => self.collect_expr(&expr)?,
111                 _ => return None,
112             }
113         }
114
115         Some(())
116     }
117     fn collect_if(&mut self, if_expr: &ast::IfExpr) -> Option<()> {
118         let then_branch = if_expr.then_branch()?;
119         self.collect_block(&then_branch)?;
120
121         match if_expr.else_branch()? {
122             ast::ElseBranch::Block(block) => self.collect_block(&block),
123             ast::ElseBranch::IfExpr(expr) => {
124                 cov_mark::hit!(test_pull_assignment_up_chained_if);
125                 self.collect_if(&expr)
126             }
127         }
128     }
129     fn collect_block(&mut self, block: &ast::BlockExpr) -> Option<()> {
130         let last_expr = block.tail_expr().or_else(|| {
131             if let ast::Stmt::ExprStmt(stmt) = block.statements().last()? {
132                 stmt.expr()
133             } else {
134                 None
135             }
136         })?;
137
138         if let ast::Expr::BinExpr(expr) = last_expr {
139             return self.collect_expr(&expr);
140         }
141
142         None
143     }
144
145     fn collect_expr(&mut self, expr: &ast::BinExpr) -> Option<()> {
146         if expr.op_kind()? == ast::BinOp::Assignment
147             && is_equivalent(self.sema, &expr.lhs()?, &self.common_lhs)
148         {
149             self.assignments.push((expr.clone(), expr.rhs()?));
150             return Some(());
151         }
152         None
153     }
154 }
155
156 fn is_equivalent(
157     sema: &hir::Semantics<ide_db::RootDatabase>,
158     expr0: &ast::Expr,
159     expr1: &ast::Expr,
160 ) -> bool {
161     match (expr0, expr1) {
162         (ast::Expr::FieldExpr(field_expr0), ast::Expr::FieldExpr(field_expr1)) => {
163             cov_mark::hit!(test_pull_assignment_up_field_assignment);
164             sema.resolve_field(field_expr0) == sema.resolve_field(field_expr1)
165         }
166         (ast::Expr::PathExpr(path0), ast::Expr::PathExpr(path1)) => {
167             let path0 = path0.path();
168             let path1 = path1.path();
169             if let (Some(path0), Some(path1)) = (path0, path1) {
170                 sema.resolve_path(&path0) == sema.resolve_path(&path1)
171             } else {
172                 false
173             }
174         }
175         (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1))
176             if prefix0.op_kind() == Some(ast::PrefixOp::Deref)
177                 && prefix1.op_kind() == Some(ast::PrefixOp::Deref) =>
178         {
179             cov_mark::hit!(test_pull_assignment_up_deref);
180             if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) {
181                 is_equivalent(sema, &prefix0, &prefix1)
182             } else {
183                 false
184             }
185         }
186         _ => false,
187     }
188 }
189
190 #[cfg(test)]
191 mod tests {
192     use super::*;
193
194     use crate::tests::{check_assist, check_assist_not_applicable};
195
196     #[test]
197     fn test_pull_assignment_up_if() {
198         check_assist(
199             pull_assignment_up,
200             r#"
201 fn foo() {
202     let mut a = 1;
203
204     if true {
205         $0a = 2;
206     } else {
207         a = 3;
208     }
209 }"#,
210             r#"
211 fn foo() {
212     let mut a = 1;
213
214     a = if true {
215         2
216     } else {
217         3
218     };
219 }"#,
220         );
221     }
222
223     #[test]
224     fn test_pull_assignment_up_match() {
225         check_assist(
226             pull_assignment_up,
227             r#"
228 fn foo() {
229     let mut a = 1;
230
231     match 1 {
232         1 => {
233             $0a = 2;
234         },
235         2 => {
236             a = 3;
237         },
238         3 => {
239             a = 4;
240         }
241     }
242 }"#,
243             r#"
244 fn foo() {
245     let mut a = 1;
246
247     a = match 1 {
248         1 => {
249             2
250         },
251         2 => {
252             3
253         },
254         3 => {
255             4
256         }
257     };
258 }"#,
259         );
260     }
261
262     #[test]
263     fn test_pull_assignment_up_assignment_expressions() {
264         check_assist(
265             pull_assignment_up,
266             r#"
267 fn foo() {
268     let mut a = 1;
269
270     match 1 {
271         1 => { $0a = 2; },
272         2 => a = 3,
273         3 => {
274             a = 4
275         }
276     }
277 }"#,
278             r#"
279 fn foo() {
280     let mut a = 1;
281
282     a = match 1 {
283         1 => { 2 },
284         2 => 3,
285         3 => {
286             4
287         }
288     };
289 }"#,
290         );
291     }
292
293     #[test]
294     fn test_pull_assignment_up_not_last_not_applicable() {
295         check_assist_not_applicable(
296             pull_assignment_up,
297             r#"
298 fn foo() {
299     let mut a = 1;
300
301     if true {
302         $0a = 2;
303         b = a;
304     } else {
305         a = 3;
306     }
307 }"#,
308         )
309     }
310
311     #[test]
312     fn test_pull_assignment_up_chained_if() {
313         cov_mark::check!(test_pull_assignment_up_chained_if);
314         check_assist(
315             pull_assignment_up,
316             r#"
317 fn foo() {
318     let mut a = 1;
319
320     if true {
321         $0a = 2;
322     } else if false {
323         a = 3;
324     } else {
325         a = 4;
326     }
327 }"#,
328             r#"
329 fn foo() {
330     let mut a = 1;
331
332     a = if true {
333         2
334     } else if false {
335         3
336     } else {
337         4
338     };
339 }"#,
340         );
341     }
342
343     #[test]
344     fn test_pull_assignment_up_retains_stmts() {
345         check_assist(
346             pull_assignment_up,
347             r#"
348 fn foo() {
349     let mut a = 1;
350
351     if true {
352         let b = 2;
353         $0a = 2;
354     } else {
355         let b = 3;
356         a = 3;
357     }
358 }"#,
359             r#"
360 fn foo() {
361     let mut a = 1;
362
363     a = if true {
364         let b = 2;
365         2
366     } else {
367         let b = 3;
368         3
369     };
370 }"#,
371         )
372     }
373
374     #[test]
375     fn pull_assignment_up_let_stmt_not_applicable() {
376         check_assist_not_applicable(
377             pull_assignment_up,
378             r#"
379 fn foo() {
380     let mut a = 1;
381
382     let b = if true {
383         $0a = 2
384     } else {
385         a = 3
386     };
387 }"#,
388         )
389     }
390
391     #[test]
392     fn pull_assignment_up_if_missing_assigment_not_applicable() {
393         check_assist_not_applicable(
394             pull_assignment_up,
395             r#"
396 fn foo() {
397     let mut a = 1;
398
399     if true {
400         $0a = 2;
401     } else {}
402 }"#,
403         )
404     }
405
406     #[test]
407     fn pull_assignment_up_match_missing_assigment_not_applicable() {
408         check_assist_not_applicable(
409             pull_assignment_up,
410             r#"
411 fn foo() {
412     let mut a = 1;
413
414     match 1 {
415         1 => {
416             $0a = 2;
417         },
418         2 => {
419             a = 3;
420         },
421         3 => {},
422     }
423 }"#,
424         )
425     }
426
427     #[test]
428     fn test_pull_assignment_up_field_assignment() {
429         cov_mark::check!(test_pull_assignment_up_field_assignment);
430         check_assist(
431             pull_assignment_up,
432             r#"
433 struct A(usize);
434
435 fn foo() {
436     let mut a = A(1);
437
438     if true {
439         $0a.0 = 2;
440     } else {
441         a.0 = 3;
442     }
443 }"#,
444             r#"
445 struct A(usize);
446
447 fn foo() {
448     let mut a = A(1);
449
450     a.0 = if true {
451         2
452     } else {
453         3
454     };
455 }"#,
456         )
457     }
458
459     #[test]
460     fn test_pull_assignment_up_deref() {
461         cov_mark::check!(test_pull_assignment_up_deref);
462         check_assist(
463             pull_assignment_up,
464             r#"
465 fn foo() {
466     let mut a = 1;
467     let b = &mut a;
468
469     if true {
470         $0*b = 2;
471     } else {
472         *b = 3;
473     }
474 }
475 "#,
476             r#"
477 fn foo() {
478     let mut a = 1;
479     let b = &mut a;
480
481     *b = if true {
482         2
483     } else {
484         3
485     };
486 }
487 "#,
488         )
489     }
490
491     #[test]
492     fn test_cant_pull_non_assignments() {
493         cov_mark::check!(test_cant_pull_non_assignments);
494         check_assist_not_applicable(
495             pull_assignment_up,
496             r#"
497 fn foo() {
498     let mut a = 1;
499     let b = &mut a;
500
501     if true {
502         $0*b + 2;
503     } else {
504         *b + 3;
505     }
506 }
507 "#,
508         )
509     }
510 }