]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/pull_assignment_up.rs
Merge #7887
[rust.git] / crates / ide_assists / src / handlers / pull_assignment_up.rs
1 use syntax::{
2     ast::{self, edit::AstNodeEdit, make},
3     AstNode,
4 };
5 use test_utils::mark;
6
7 use crate::{
8     assist_context::{AssistContext, Assists},
9     AssistId, AssistKind,
10 };
11
12 // Assist: pull_assignment_up
13 //
14 // Extracts variable assignment to outside an if or match statement.
15 //
16 // ```
17 // fn main() {
18 //     let mut foo = 6;
19 //
20 //     if true {
21 //         $0foo = 5;
22 //     } else {
23 //         foo = 4;
24 //     }
25 // }
26 // ```
27 // ->
28 // ```
29 // fn main() {
30 //     let mut foo = 6;
31 //
32 //     foo = if true {
33 //         5
34 //     } else {
35 //         4
36 //     };
37 // }
38 // ```
39 pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
40     let assign_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
41     let name_expr = if assign_expr.op_kind()? == ast::BinOp::Assignment {
42         assign_expr.lhs()?
43     } else {
44         return None;
45     };
46
47     let (old_stmt, new_stmt) = if let Some(if_expr) = ctx.find_node_at_offset::<ast::IfExpr>() {
48         (
49             ast::Expr::cast(if_expr.syntax().to_owned())?,
50             exprify_if(&if_expr, &ctx.sema, &name_expr)?.indent(if_expr.indent_level()),
51         )
52     } else if let Some(match_expr) = ctx.find_node_at_offset::<ast::MatchExpr>() {
53         (
54             ast::Expr::cast(match_expr.syntax().to_owned())?,
55             exprify_match(&match_expr, &ctx.sema, &name_expr)?,
56         )
57     } else {
58         return None;
59     };
60
61     let expr_stmt = make::expr_stmt(new_stmt);
62
63     acc.add(
64         AssistId("pull_assignment_up", AssistKind::RefactorExtract),
65         "Pull assignment up",
66         old_stmt.syntax().text_range(),
67         move |edit| {
68             edit.replace(old_stmt.syntax().text_range(), format!("{} = {};", name_expr, expr_stmt));
69         },
70     )
71 }
72
73 fn exprify_match(
74     match_expr: &ast::MatchExpr,
75     sema: &hir::Semantics<ide_db::RootDatabase>,
76     name: &ast::Expr,
77 ) -> Option<ast::Expr> {
78     let new_arm_list = match_expr
79         .match_arm_list()?
80         .arms()
81         .map(|arm| {
82             if let ast::Expr::BlockExpr(block) = arm.expr()? {
83                 let new_block = exprify_block(&block, sema, name)?.indent(block.indent_level());
84                 Some(arm.replace_descendant(block, new_block))
85             } else {
86                 None
87             }
88         })
89         .collect::<Option<Vec<_>>>()?;
90     let new_arm_list = match_expr
91         .match_arm_list()?
92         .replace_descendants(match_expr.match_arm_list()?.arms().zip(new_arm_list));
93     Some(make::expr_match(match_expr.expr()?, new_arm_list))
94 }
95
96 fn exprify_if(
97     statement: &ast::IfExpr,
98     sema: &hir::Semantics<ide_db::RootDatabase>,
99     name: &ast::Expr,
100 ) -> Option<ast::Expr> {
101     let then_branch = exprify_block(&statement.then_branch()?, sema, name)?;
102     let else_branch = match statement.else_branch()? {
103         ast::ElseBranch::Block(ref block) => {
104             ast::ElseBranch::Block(exprify_block(block, sema, name)?)
105         }
106         ast::ElseBranch::IfExpr(expr) => {
107             mark::hit!(test_pull_assignment_up_chained_if);
108             ast::ElseBranch::IfExpr(ast::IfExpr::cast(
109                 exprify_if(&expr, sema, name)?.syntax().to_owned(),
110             )?)
111         }
112     };
113     Some(make::expr_if(statement.condition()?, then_branch, Some(else_branch)))
114 }
115
116 fn exprify_block(
117     block: &ast::BlockExpr,
118     sema: &hir::Semantics<ide_db::RootDatabase>,
119     name: &ast::Expr,
120 ) -> Option<ast::BlockExpr> {
121     if block.tail_expr().is_some() {
122         return None;
123     }
124
125     let mut stmts: Vec<_> = block.statements().collect();
126     let stmt = stmts.pop()?;
127
128     if let ast::Stmt::ExprStmt(stmt) = stmt {
129         if let ast::Expr::BinExpr(expr) = stmt.expr()? {
130             if expr.op_kind()? == ast::BinOp::Assignment && is_equivalent(sema, &expr.lhs()?, name)
131             {
132                 // The last statement in the block is an assignment to the name we want
133                 return Some(make::block_expr(stmts, Some(expr.rhs()?)));
134             }
135         }
136     }
137     None
138 }
139
140 fn is_equivalent(
141     sema: &hir::Semantics<ide_db::RootDatabase>,
142     expr0: &ast::Expr,
143     expr1: &ast::Expr,
144 ) -> bool {
145     match (expr0, expr1) {
146         (ast::Expr::FieldExpr(field_expr0), ast::Expr::FieldExpr(field_expr1)) => {
147             mark::hit!(test_pull_assignment_up_field_assignment);
148             sema.resolve_field(field_expr0) == sema.resolve_field(field_expr1)
149         }
150         (ast::Expr::PathExpr(path0), ast::Expr::PathExpr(path1)) => {
151             let path0 = path0.path();
152             let path1 = path1.path();
153             if let (Some(path0), Some(path1)) = (path0, path1) {
154                 sema.resolve_path(&path0) == sema.resolve_path(&path1)
155             } else {
156                 false
157             }
158         }
159         (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1))
160             if prefix0.op_kind() == Some(ast::PrefixOp::Deref)
161                 && prefix1.op_kind() == Some(ast::PrefixOp::Deref) =>
162         {
163             mark::hit!(test_pull_assignment_up_deref);
164             if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) {
165                 is_equivalent(sema, &prefix0, &prefix1)
166             } else {
167                 false
168             }
169         }
170         _ => false,
171     }
172 }
173
174 #[cfg(test)]
175 mod tests {
176     use super::*;
177
178     use crate::tests::{check_assist, check_assist_not_applicable};
179
180     #[test]
181     fn test_pull_assignment_up_if() {
182         check_assist(
183             pull_assignment_up,
184             r#"
185 fn foo() {
186     let mut a = 1;
187
188     if true {
189         $0a = 2;
190     } else {
191         a = 3;
192     }
193 }"#,
194             r#"
195 fn foo() {
196     let mut a = 1;
197
198     a = if true {
199         2
200     } else {
201         3
202     };
203 }"#,
204         );
205     }
206
207     #[test]
208     fn test_pull_assignment_up_match() {
209         check_assist(
210             pull_assignment_up,
211             r#"
212 fn foo() {
213     let mut a = 1;
214
215     match 1 {
216         1 => {
217             $0a = 2;
218         },
219         2 => {
220             a = 3;
221         },
222         3 => {
223             a = 4;
224         }
225     }
226 }"#,
227             r#"
228 fn foo() {
229     let mut a = 1;
230
231     a = match 1 {
232         1 => {
233             2
234         },
235         2 => {
236             3
237         },
238         3 => {
239             4
240         }
241     };
242 }"#,
243         );
244     }
245
246     #[test]
247     fn test_pull_assignment_up_not_last_not_applicable() {
248         check_assist_not_applicable(
249             pull_assignment_up,
250             r#"
251 fn foo() {
252     let mut a = 1;
253
254     if true {
255         $0a = 2;
256         b = a;
257     } else {
258         a = 3;
259     }
260 }"#,
261         )
262     }
263
264     #[test]
265     fn test_pull_assignment_up_chained_if() {
266         mark::check!(test_pull_assignment_up_chained_if);
267         check_assist(
268             pull_assignment_up,
269             r#"
270 fn foo() {
271     let mut a = 1;
272
273     if true {
274         $0a = 2;
275     } else if false {
276         a = 3;
277     } else {
278         a = 4;
279     }
280 }"#,
281             r#"
282 fn foo() {
283     let mut a = 1;
284
285     a = if true {
286         2
287     } else if false {
288         3
289     } else {
290         4
291     };
292 }"#,
293         );
294     }
295
296     #[test]
297     fn test_pull_assignment_up_retains_stmts() {
298         check_assist(
299             pull_assignment_up,
300             r#"
301 fn foo() {
302     let mut a = 1;
303
304     if true {
305         let b = 2;
306         $0a = 2;
307     } else {
308         let b = 3;
309         a = 3;
310     }
311 }"#,
312             r#"
313 fn foo() {
314     let mut a = 1;
315
316     a = if true {
317         let b = 2;
318         2
319     } else {
320         let b = 3;
321         3
322     };
323 }"#,
324         )
325     }
326
327     #[test]
328     fn pull_assignment_up_let_stmt_not_applicable() {
329         check_assist_not_applicable(
330             pull_assignment_up,
331             r#"
332 fn foo() {
333     let mut a = 1;
334
335     let b = if true {
336         $0a = 2
337     } else {
338         a = 3
339     };
340 }"#,
341         )
342     }
343
344     #[test]
345     fn pull_assignment_up_if_missing_assigment_not_applicable() {
346         check_assist_not_applicable(
347             pull_assignment_up,
348             r#"
349 fn foo() {
350     let mut a = 1;
351
352     if true {
353         $0a = 2;
354     } else {}
355 }"#,
356         )
357     }
358
359     #[test]
360     fn pull_assignment_up_match_missing_assigment_not_applicable() {
361         check_assist_not_applicable(
362             pull_assignment_up,
363             r#"
364 fn foo() {
365     let mut a = 1;
366
367     match 1 {
368         1 => {
369             $0a = 2;
370         },
371         2 => {
372             a = 3;
373         },
374         3 => {},
375     }
376 }"#,
377         )
378     }
379
380     #[test]
381     fn test_pull_assignment_up_field_assignment() {
382         mark::check!(test_pull_assignment_up_field_assignment);
383         check_assist(
384             pull_assignment_up,
385             r#"
386 struct A(usize);
387
388 fn foo() {
389     let mut a = A(1);
390
391     if true {
392         $0a.0 = 2;
393     } else {
394         a.0 = 3;
395     }
396 }"#,
397             r#"
398 struct A(usize);
399
400 fn foo() {
401     let mut a = A(1);
402
403     a.0 = if true {
404         2
405     } else {
406         3
407     };
408 }"#,
409         )
410     }
411
412     #[test]
413     fn test_pull_assignment_up_deref() {
414         mark::check!(test_pull_assignment_up_deref);
415         check_assist(
416             pull_assignment_up,
417             r#"
418 fn foo() {
419     let mut a = 1;
420     let b = &mut a;
421
422     if true {
423         $0*b = 2;
424     } else {
425         *b = 3;
426     }
427 }
428 "#,
429             r#"
430 fn foo() {
431     let mut a = 1;
432     let b = &mut a;
433
434     *b = if true {
435         2
436     } else {
437         3
438     };
439 }
440 "#,
441         )
442     }
443 }