]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/pull_assignment_up.rs
internal: pull_assignment_up uses mutable trees
[rust.git] / crates / ide_assists / src / handlers / pull_assignment_up.rs
1 use syntax::{
2     ast::{self, make},
3     ted, AstNode,
4 };
5
6 use crate::{
7     assist_context::{AssistContext, Assists},
8     AssistId, AssistKind,
9 };
10
11 // Assist: pull_assignment_up
12 //
13 // Extracts variable assignment to outside an if or match statement.
14 //
15 // ```
16 // fn main() {
17 //     let mut foo = 6;
18 //
19 //     if true {
20 //         $0foo = 5;
21 //     } else {
22 //         foo = 4;
23 //     }
24 // }
25 // ```
26 // ->
27 // ```
28 // fn main() {
29 //     let mut foo = 6;
30 //
31 //     foo = if true {
32 //         5
33 //     } else {
34 //         4
35 //     };
36 // }
37 // ```
38 pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
39     let assign_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
40
41     let op_kind = assign_expr.op_kind()?;
42     if op_kind != ast::BinOp::Assignment {
43         cov_mark::hit!(test_cant_pull_non_assignments);
44         return None;
45     }
46
47     let mut collector = AssignmentsCollector {
48         sema: &ctx.sema,
49         common_lhs: assign_expr.lhs()?,
50         assignments: Vec::new(),
51     };
52
53     let tgt: ast::Expr = if let Some(if_expr) = ctx.find_node_at_offset::<ast::IfExpr>() {
54         collector.collect_if(&if_expr)?;
55         if_expr.into()
56     } else if let Some(match_expr) = ctx.find_node_at_offset::<ast::MatchExpr>() {
57         collector.collect_match(&match_expr)?;
58         match_expr.into()
59     } else {
60         return None;
61     };
62
63     acc.add(
64         AssistId("pull_assignment_up", AssistKind::RefactorExtract),
65         "Pull assignment up",
66         tgt.syntax().text_range(),
67         move |edit| {
68             let assignments: Vec<_> = collector
69                 .assignments
70                 .into_iter()
71                 .map(|(stmt, rhs)| (edit.make_ast_mut(stmt), rhs.clone_for_update()))
72                 .collect();
73
74             let tgt = edit.make_ast_mut(tgt);
75
76             for (stmt, rhs) in assignments {
77                 ted::replace(stmt.syntax(), rhs.syntax());
78             }
79             let assign_expr = make::expr_assignment(collector.common_lhs, tgt.clone());
80             let assign_stmt = make::expr_stmt(assign_expr);
81
82             ted::replace(tgt.syntax(), assign_stmt.syntax().clone_for_update());
83         },
84     )
85 }
86
87 struct AssignmentsCollector<'a> {
88     sema: &'a hir::Semantics<'a, ide_db::RootDatabase>,
89     common_lhs: ast::Expr,
90     assignments: Vec<(ast::ExprStmt, ast::Expr)>,
91 }
92
93 impl<'a> AssignmentsCollector<'a> {
94     fn collect_match(&mut self, match_expr: &ast::MatchExpr) -> Option<()> {
95         for arm in match_expr.match_arm_list()?.arms() {
96             match arm.expr()? {
97                 ast::Expr::BlockExpr(block) => self.collect_block(&block)?,
98                 // TODO: Handle this while we are at it?
99                 _ => return None,
100             }
101         }
102
103         Some(())
104     }
105     fn collect_if(&mut self, if_expr: &ast::IfExpr) -> Option<()> {
106         let then_branch = if_expr.then_branch()?;
107         self.collect_block(&then_branch)?;
108
109         match if_expr.else_branch()? {
110             ast::ElseBranch::Block(block) => self.collect_block(&block),
111             ast::ElseBranch::IfExpr(expr) => {
112                 cov_mark::hit!(test_pull_assignment_up_chained_if);
113                 self.collect_if(&expr)
114             }
115         }
116     }
117     fn collect_block(&mut self, block: &ast::BlockExpr) -> Option<()> {
118         if block.tail_expr().is_some() {
119             return None;
120         }
121
122         let last_stmt = block.statements().last()?;
123         if let ast::Stmt::ExprStmt(stmt) = last_stmt {
124             if let ast::Expr::BinExpr(expr) = stmt.expr()? {
125                 if expr.op_kind()? == ast::BinOp::Assignment
126                     && is_equivalent(self.sema, &expr.lhs()?, &self.common_lhs)
127                 {
128                     self.assignments.push((stmt, expr.rhs()?));
129                     return Some(());
130                 }
131             }
132         }
133
134         None
135     }
136 }
137
138 fn is_equivalent(
139     sema: &hir::Semantics<ide_db::RootDatabase>,
140     expr0: &ast::Expr,
141     expr1: &ast::Expr,
142 ) -> bool {
143     match (expr0, expr1) {
144         (ast::Expr::FieldExpr(field_expr0), ast::Expr::FieldExpr(field_expr1)) => {
145             cov_mark::hit!(test_pull_assignment_up_field_assignment);
146             sema.resolve_field(field_expr0) == sema.resolve_field(field_expr1)
147         }
148         (ast::Expr::PathExpr(path0), ast::Expr::PathExpr(path1)) => {
149             let path0 = path0.path();
150             let path1 = path1.path();
151             if let (Some(path0), Some(path1)) = (path0, path1) {
152                 sema.resolve_path(&path0) == sema.resolve_path(&path1)
153             } else {
154                 false
155             }
156         }
157         (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1))
158             if prefix0.op_kind() == Some(ast::PrefixOp::Deref)
159                 && prefix1.op_kind() == Some(ast::PrefixOp::Deref) =>
160         {
161             cov_mark::hit!(test_pull_assignment_up_deref);
162             if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) {
163                 is_equivalent(sema, &prefix0, &prefix1)
164             } else {
165                 false
166             }
167         }
168         _ => false,
169     }
170 }
171
172 #[cfg(test)]
173 mod tests {
174     use super::*;
175
176     use crate::tests::{check_assist, check_assist_not_applicable};
177
178     #[test]
179     fn test_pull_assignment_up_if() {
180         check_assist(
181             pull_assignment_up,
182             r#"
183 fn foo() {
184     let mut a = 1;
185
186     if true {
187         $0a = 2;
188     } else {
189         a = 3;
190     }
191 }"#,
192             r#"
193 fn foo() {
194     let mut a = 1;
195
196     a = if true {
197         2
198     } else {
199         3
200     };
201 }"#,
202         );
203     }
204
205     #[test]
206     fn test_pull_assignment_up_match() {
207         check_assist(
208             pull_assignment_up,
209             r#"
210 fn foo() {
211     let mut a = 1;
212
213     match 1 {
214         1 => {
215             $0a = 2;
216         },
217         2 => {
218             a = 3;
219         },
220         3 => {
221             a = 4;
222         }
223     }
224 }"#,
225             r#"
226 fn foo() {
227     let mut a = 1;
228
229     a = match 1 {
230         1 => {
231             2
232         },
233         2 => {
234             3
235         },
236         3 => {
237             4
238         }
239     };
240 }"#,
241         );
242     }
243
244     #[test]
245     fn test_pull_assignment_up_not_last_not_applicable() {
246         check_assist_not_applicable(
247             pull_assignment_up,
248             r#"
249 fn foo() {
250     let mut a = 1;
251
252     if true {
253         $0a = 2;
254         b = a;
255     } else {
256         a = 3;
257     }
258 }"#,
259         )
260     }
261
262     #[test]
263     fn test_pull_assignment_up_chained_if() {
264         cov_mark::check!(test_pull_assignment_up_chained_if);
265         check_assist(
266             pull_assignment_up,
267             r#"
268 fn foo() {
269     let mut a = 1;
270
271     if true {
272         $0a = 2;
273     } else if false {
274         a = 3;
275     } else {
276         a = 4;
277     }
278 }"#,
279             r#"
280 fn foo() {
281     let mut a = 1;
282
283     a = if true {
284         2
285     } else if false {
286         3
287     } else {
288         4
289     };
290 }"#,
291         );
292     }
293
294     #[test]
295     fn test_pull_assignment_up_retains_stmts() {
296         check_assist(
297             pull_assignment_up,
298             r#"
299 fn foo() {
300     let mut a = 1;
301
302     if true {
303         let b = 2;
304         $0a = 2;
305     } else {
306         let b = 3;
307         a = 3;
308     }
309 }"#,
310             r#"
311 fn foo() {
312     let mut a = 1;
313
314     a = if true {
315         let b = 2;
316         2
317     } else {
318         let b = 3;
319         3
320     };
321 }"#,
322         )
323     }
324
325     #[test]
326     fn pull_assignment_up_let_stmt_not_applicable() {
327         check_assist_not_applicable(
328             pull_assignment_up,
329             r#"
330 fn foo() {
331     let mut a = 1;
332
333     let b = if true {
334         $0a = 2
335     } else {
336         a = 3
337     };
338 }"#,
339         )
340     }
341
342     #[test]
343     fn pull_assignment_up_if_missing_assigment_not_applicable() {
344         check_assist_not_applicable(
345             pull_assignment_up,
346             r#"
347 fn foo() {
348     let mut a = 1;
349
350     if true {
351         $0a = 2;
352     } else {}
353 }"#,
354         )
355     }
356
357     #[test]
358     fn pull_assignment_up_match_missing_assigment_not_applicable() {
359         check_assist_not_applicable(
360             pull_assignment_up,
361             r#"
362 fn foo() {
363     let mut a = 1;
364
365     match 1 {
366         1 => {
367             $0a = 2;
368         },
369         2 => {
370             a = 3;
371         },
372         3 => {},
373     }
374 }"#,
375         )
376     }
377
378     #[test]
379     fn test_pull_assignment_up_field_assignment() {
380         cov_mark::check!(test_pull_assignment_up_field_assignment);
381         check_assist(
382             pull_assignment_up,
383             r#"
384 struct A(usize);
385
386 fn foo() {
387     let mut a = A(1);
388
389     if true {
390         $0a.0 = 2;
391     } else {
392         a.0 = 3;
393     }
394 }"#,
395             r#"
396 struct A(usize);
397
398 fn foo() {
399     let mut a = A(1);
400
401     a.0 = if true {
402         2
403     } else {
404         3
405     };
406 }"#,
407         )
408     }
409
410     #[test]
411     fn test_pull_assignment_up_deref() {
412         cov_mark::check!(test_pull_assignment_up_deref);
413         check_assist(
414             pull_assignment_up,
415             r#"
416 fn foo() {
417     let mut a = 1;
418     let b = &mut a;
419
420     if true {
421         $0*b = 2;
422     } else {
423         *b = 3;
424     }
425 }
426 "#,
427             r#"
428 fn foo() {
429     let mut a = 1;
430     let b = &mut a;
431
432     *b = if true {
433         2
434     } else {
435         3
436     };
437 }
438 "#,
439         )
440     }
441
442     #[test]
443     fn test_cant_pull_non_assignments() {
444         cov_mark::check!(test_cant_pull_non_assignments);
445         check_assist_not_applicable(
446             pull_assignment_up,
447             r#"
448 fn foo() {
449     let mut a = 1;
450     let b = &mut a;
451
452     if true {
453         $0*b + 2;
454     } else {
455         *b + 3;
456     }
457 }
458 "#,
459         )
460     }
461 }