]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/pull_assignment_up.rs
minor: add missing test
[rust.git] / crates / ide_assists / src / handlers / pull_assignment_up.rs
1 use syntax::{
2     ast::{self, edit::AstNodeEdit, make},
3     AstNode,
4 };
5
6 use crate::{
7     assist_context::{AssistContext, Assists},
8     AssistId, AssistKind,
9 };
10
11 // Assist: pull_assignment_up
12 //
13 // Extracts variable assignment to outside an if or match statement.
14 //
15 // ```
16 // fn main() {
17 //     let mut foo = 6;
18 //
19 //     if true {
20 //         $0foo = 5;
21 //     } else {
22 //         foo = 4;
23 //     }
24 // }
25 // ```
26 // ->
27 // ```
28 // fn main() {
29 //     let mut foo = 6;
30 //
31 //     foo = if true {
32 //         5
33 //     } else {
34 //         4
35 //     };
36 // }
37 // ```
38 pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
39     let assign_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
40
41     let op_kind = assign_expr.op_kind()?;
42     if op_kind != ast::BinOp::Assignment {
43         cov_mark::hit!(test_cant_pull_non_assignments);
44         return None;
45     }
46
47     let name_expr = assign_expr.lhs()?;
48
49     let old_stmt: ast::Expr;
50     let new_stmt: ast::Expr;
51
52     if let Some(if_expr) = ctx.find_node_at_offset::<ast::IfExpr>() {
53         new_stmt = exprify_if(&if_expr, &ctx.sema, &name_expr)?.indent(if_expr.indent_level());
54         old_stmt = if_expr.into();
55     } else if let Some(match_expr) = ctx.find_node_at_offset::<ast::MatchExpr>() {
56         new_stmt = exprify_match(&match_expr, &ctx.sema, &name_expr)?;
57         old_stmt = match_expr.into()
58     } else {
59         return None;
60     };
61
62     let expr_stmt = make::expr_stmt(new_stmt);
63
64     acc.add(
65         AssistId("pull_assignment_up", AssistKind::RefactorExtract),
66         "Pull assignment up",
67         old_stmt.syntax().text_range(),
68         move |edit| {
69             edit.replace(old_stmt.syntax().text_range(), format!("{} = {};", name_expr, expr_stmt));
70         },
71     )
72 }
73
74 fn exprify_match(
75     match_expr: &ast::MatchExpr,
76     sema: &hir::Semantics<ide_db::RootDatabase>,
77     name: &ast::Expr,
78 ) -> Option<ast::Expr> {
79     let new_arm_list = match_expr
80         .match_arm_list()?
81         .arms()
82         .map(|arm| {
83             if let ast::Expr::BlockExpr(block) = arm.expr()? {
84                 let new_block = exprify_block(&block, sema, name)?.indent(block.indent_level());
85                 Some(arm.replace_descendant(block, new_block))
86             } else {
87                 None
88             }
89         })
90         .collect::<Option<Vec<_>>>()?;
91     let new_arm_list = match_expr
92         .match_arm_list()?
93         .replace_descendants(match_expr.match_arm_list()?.arms().zip(new_arm_list));
94     Some(make::expr_match(match_expr.expr()?, new_arm_list))
95 }
96
97 fn exprify_if(
98     statement: &ast::IfExpr,
99     sema: &hir::Semantics<ide_db::RootDatabase>,
100     name: &ast::Expr,
101 ) -> Option<ast::Expr> {
102     let then_branch = exprify_block(&statement.then_branch()?, sema, name)?;
103     let else_branch = match statement.else_branch()? {
104         ast::ElseBranch::Block(block) => ast::ElseBranch::Block(exprify_block(&block, sema, name)?),
105         ast::ElseBranch::IfExpr(expr) => {
106             cov_mark::hit!(test_pull_assignment_up_chained_if);
107             ast::ElseBranch::IfExpr(ast::IfExpr::cast(
108                 exprify_if(&expr, sema, name)?.syntax().to_owned(),
109             )?)
110         }
111     };
112     Some(make::expr_if(statement.condition()?, then_branch, Some(else_branch)))
113 }
114
115 fn exprify_block(
116     block: &ast::BlockExpr,
117     sema: &hir::Semantics<ide_db::RootDatabase>,
118     name: &ast::Expr,
119 ) -> Option<ast::BlockExpr> {
120     if block.tail_expr().is_some() {
121         return None;
122     }
123
124     let mut stmts: Vec<_> = block.statements().collect();
125     let stmt = stmts.pop()?;
126
127     if let ast::Stmt::ExprStmt(stmt) = stmt {
128         if let ast::Expr::BinExpr(expr) = stmt.expr()? {
129             if expr.op_kind()? == ast::BinOp::Assignment && is_equivalent(sema, &expr.lhs()?, name)
130             {
131                 // The last statement in the block is an assignment to the name we want
132                 return Some(make::block_expr(stmts, Some(expr.rhs()?)));
133             }
134         }
135     }
136     None
137 }
138
139 fn is_equivalent(
140     sema: &hir::Semantics<ide_db::RootDatabase>,
141     expr0: &ast::Expr,
142     expr1: &ast::Expr,
143 ) -> bool {
144     match (expr0, expr1) {
145         (ast::Expr::FieldExpr(field_expr0), ast::Expr::FieldExpr(field_expr1)) => {
146             cov_mark::hit!(test_pull_assignment_up_field_assignment);
147             sema.resolve_field(field_expr0) == sema.resolve_field(field_expr1)
148         }
149         (ast::Expr::PathExpr(path0), ast::Expr::PathExpr(path1)) => {
150             let path0 = path0.path();
151             let path1 = path1.path();
152             if let (Some(path0), Some(path1)) = (path0, path1) {
153                 sema.resolve_path(&path0) == sema.resolve_path(&path1)
154             } else {
155                 false
156             }
157         }
158         (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1))
159             if prefix0.op_kind() == Some(ast::PrefixOp::Deref)
160                 && prefix1.op_kind() == Some(ast::PrefixOp::Deref) =>
161         {
162             cov_mark::hit!(test_pull_assignment_up_deref);
163             if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) {
164                 is_equivalent(sema, &prefix0, &prefix1)
165             } else {
166                 false
167             }
168         }
169         _ => false,
170     }
171 }
172
173 #[cfg(test)]
174 mod tests {
175     use super::*;
176
177     use crate::tests::{check_assist, check_assist_not_applicable};
178
179     #[test]
180     fn test_pull_assignment_up_if() {
181         check_assist(
182             pull_assignment_up,
183             r#"
184 fn foo() {
185     let mut a = 1;
186
187     if true {
188         $0a = 2;
189     } else {
190         a = 3;
191     }
192 }"#,
193             r#"
194 fn foo() {
195     let mut a = 1;
196
197     a = if true {
198         2
199     } else {
200         3
201     };
202 }"#,
203         );
204     }
205
206     #[test]
207     fn test_pull_assignment_up_match() {
208         check_assist(
209             pull_assignment_up,
210             r#"
211 fn foo() {
212     let mut a = 1;
213
214     match 1 {
215         1 => {
216             $0a = 2;
217         },
218         2 => {
219             a = 3;
220         },
221         3 => {
222             a = 4;
223         }
224     }
225 }"#,
226             r#"
227 fn foo() {
228     let mut a = 1;
229
230     a = match 1 {
231         1 => {
232             2
233         },
234         2 => {
235             3
236         },
237         3 => {
238             4
239         }
240     };
241 }"#,
242         );
243     }
244
245     #[test]
246     fn test_pull_assignment_up_not_last_not_applicable() {
247         check_assist_not_applicable(
248             pull_assignment_up,
249             r#"
250 fn foo() {
251     let mut a = 1;
252
253     if true {
254         $0a = 2;
255         b = a;
256     } else {
257         a = 3;
258     }
259 }"#,
260         )
261     }
262
263     #[test]
264     fn test_pull_assignment_up_chained_if() {
265         cov_mark::check!(test_pull_assignment_up_chained_if);
266         check_assist(
267             pull_assignment_up,
268             r#"
269 fn foo() {
270     let mut a = 1;
271
272     if true {
273         $0a = 2;
274     } else if false {
275         a = 3;
276     } else {
277         a = 4;
278     }
279 }"#,
280             r#"
281 fn foo() {
282     let mut a = 1;
283
284     a = if true {
285         2
286     } else if false {
287         3
288     } else {
289         4
290     };
291 }"#,
292         );
293     }
294
295     #[test]
296     fn test_pull_assignment_up_retains_stmts() {
297         check_assist(
298             pull_assignment_up,
299             r#"
300 fn foo() {
301     let mut a = 1;
302
303     if true {
304         let b = 2;
305         $0a = 2;
306     } else {
307         let b = 3;
308         a = 3;
309     }
310 }"#,
311             r#"
312 fn foo() {
313     let mut a = 1;
314
315     a = if true {
316         let b = 2;
317         2
318     } else {
319         let b = 3;
320         3
321     };
322 }"#,
323         )
324     }
325
326     #[test]
327     fn pull_assignment_up_let_stmt_not_applicable() {
328         check_assist_not_applicable(
329             pull_assignment_up,
330             r#"
331 fn foo() {
332     let mut a = 1;
333
334     let b = if true {
335         $0a = 2
336     } else {
337         a = 3
338     };
339 }"#,
340         )
341     }
342
343     #[test]
344     fn pull_assignment_up_if_missing_assigment_not_applicable() {
345         check_assist_not_applicable(
346             pull_assignment_up,
347             r#"
348 fn foo() {
349     let mut a = 1;
350
351     if true {
352         $0a = 2;
353     } else {}
354 }"#,
355         )
356     }
357
358     #[test]
359     fn pull_assignment_up_match_missing_assigment_not_applicable() {
360         check_assist_not_applicable(
361             pull_assignment_up,
362             r#"
363 fn foo() {
364     let mut a = 1;
365
366     match 1 {
367         1 => {
368             $0a = 2;
369         },
370         2 => {
371             a = 3;
372         },
373         3 => {},
374     }
375 }"#,
376         )
377     }
378
379     #[test]
380     fn test_pull_assignment_up_field_assignment() {
381         cov_mark::check!(test_pull_assignment_up_field_assignment);
382         check_assist(
383             pull_assignment_up,
384             r#"
385 struct A(usize);
386
387 fn foo() {
388     let mut a = A(1);
389
390     if true {
391         $0a.0 = 2;
392     } else {
393         a.0 = 3;
394     }
395 }"#,
396             r#"
397 struct A(usize);
398
399 fn foo() {
400     let mut a = A(1);
401
402     a.0 = if true {
403         2
404     } else {
405         3
406     };
407 }"#,
408         )
409     }
410
411     #[test]
412     fn test_pull_assignment_up_deref() {
413         cov_mark::check!(test_pull_assignment_up_deref);
414         check_assist(
415             pull_assignment_up,
416             r#"
417 fn foo() {
418     let mut a = 1;
419     let b = &mut a;
420
421     if true {
422         $0*b = 2;
423     } else {
424         *b = 3;
425     }
426 }
427 "#,
428             r#"
429 fn foo() {
430     let mut a = 1;
431     let b = &mut a;
432
433     *b = if true {
434         2
435     } else {
436         3
437     };
438 }
439 "#,
440         )
441     }
442
443     #[test]
444     fn test_cant_pull_non_assignments() {
445         cov_mark::check!(test_cant_pull_non_assignments);
446         check_assist_not_applicable(
447             pull_assignment_up,
448             r#"
449 fn foo() {
450     let mut a = 1;
451     let b = &mut a;
452
453     if true {
454         $0*b + 2;
455     } else {
456         *b + 3;
457     }
458 }
459 "#,
460         )
461     }
462 }