]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_variable.rs
Merge #8398
[rust.git] / crates / ide_assists / src / handlers / extract_variable.rs
1 use stdx::format_to;
2 use syntax::{
3     ast::{self, AstNode},
4     SyntaxKind::{
5         BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD,
6         PATH_EXPR, RETURN_EXPR,
7     },
8     SyntaxNode,
9 };
10
11 use crate::{utils::suggest_name, AssistContext, AssistId, AssistKind, Assists};
12
13 // Assist: extract_variable
14 //
15 // Extracts subexpression into a variable.
16 //
17 // ```
18 // fn main() {
19 //     $0(1 + 2)$0 * 4;
20 // }
21 // ```
22 // ->
23 // ```
24 // fn main() {
25 //     let $0var_name = (1 + 2);
26 //     var_name * 4;
27 // }
28 // ```
29 pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
30     if ctx.frange.range.is_empty() {
31         return None;
32     }
33     let node = ctx.covering_element();
34     if node.kind() == COMMENT {
35         cov_mark::hit!(extract_var_in_comment_is_not_applicable);
36         return None;
37     }
38     let to_extract = node.ancestors().find_map(valid_target_expr)?;
39     let anchor = Anchor::from(&to_extract)?;
40     let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
41     let target = to_extract.syntax().text_range();
42     acc.add(
43         AssistId("extract_variable", AssistKind::RefactorExtract),
44         "Extract into variable",
45         target,
46         move |edit| {
47             let field_shorthand =
48                 match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) {
49                     Some(field) => field.name_ref(),
50                     None => None,
51                 };
52
53             let mut buf = String::new();
54
55             let var_name = match &field_shorthand {
56                 Some(it) => it.to_string(),
57                 None => suggest_name::for_variable(&to_extract, &ctx.sema),
58             };
59             let expr_range = match &field_shorthand {
60                 Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()),
61                 None => to_extract.syntax().text_range(),
62             };
63
64             if let Anchor::WrapInBlock(_) = anchor {
65                 format_to!(buf, "{{ let {} = ", var_name);
66             } else {
67                 format_to!(buf, "let {} = ", var_name);
68             };
69             format_to!(buf, "{}", to_extract.syntax());
70
71             if let Anchor::Replace(stmt) = anchor {
72                 cov_mark::hit!(test_extract_var_expr_stmt);
73                 if stmt.semicolon_token().is_none() {
74                     buf.push_str(";");
75                 }
76                 match ctx.config.snippet_cap {
77                     Some(cap) => {
78                         let snip = buf
79                             .replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
80                         edit.replace_snippet(cap, expr_range, snip)
81                     }
82                     None => edit.replace(expr_range, buf),
83                 }
84                 return;
85             }
86
87             buf.push_str(";");
88
89             // We want to maintain the indent level,
90             // but we do not want to duplicate possible
91             // extra newlines in the indent block
92             let text = indent.text();
93             if text.starts_with('\n') {
94                 buf.push('\n');
95                 buf.push_str(text.trim_start_matches('\n'));
96             } else {
97                 buf.push_str(text);
98             }
99
100             edit.replace(expr_range, var_name.clone());
101             let offset = anchor.syntax().text_range().start();
102             match ctx.config.snippet_cap {
103                 Some(cap) => {
104                     let snip =
105                         buf.replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
106                     edit.insert_snippet(cap, offset, snip)
107                 }
108                 None => edit.insert(offset, buf),
109             }
110
111             if let Anchor::WrapInBlock(_) = anchor {
112                 edit.insert(anchor.syntax().text_range().end(), " }");
113             }
114         },
115     )
116 }
117
118 /// Check whether the node is a valid expression which can be extracted to a variable.
119 /// In general that's true for any expression, but in some cases that would produce invalid code.
120 fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
121     match node.kind() {
122         PATH_EXPR | LOOP_EXPR => None,
123         BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()),
124         RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
125         BLOCK_EXPR => {
126             ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from)
127         }
128         _ => ast::Expr::cast(node),
129     }
130 }
131
132 enum Anchor {
133     Before(SyntaxNode),
134     Replace(ast::ExprStmt),
135     WrapInBlock(SyntaxNode),
136 }
137
138 impl Anchor {
139     fn from(to_extract: &ast::Expr) -> Option<Anchor> {
140         to_extract.syntax().ancestors().find_map(|node| {
141             if let Some(expr) =
142                 node.parent().and_then(ast::BlockExpr::cast).and_then(|it| it.tail_expr())
143             {
144                 if expr.syntax() == &node {
145                     cov_mark::hit!(test_extract_var_last_expr);
146                     return Some(Anchor::Before(node));
147                 }
148             }
149
150             if let Some(parent) = node.parent() {
151                 if parent.kind() == CLOSURE_EXPR {
152                     cov_mark::hit!(test_extract_var_in_closure_no_block);
153                     return Some(Anchor::WrapInBlock(node));
154                 }
155                 if parent.kind() == MATCH_ARM {
156                     if node.kind() == MATCH_GUARD {
157                         cov_mark::hit!(test_extract_var_in_match_guard);
158                     } else {
159                         cov_mark::hit!(test_extract_var_in_match_arm_no_block);
160                         return Some(Anchor::WrapInBlock(node));
161                     }
162                 }
163             }
164
165             if let Some(stmt) = ast::Stmt::cast(node.clone()) {
166                 if let ast::Stmt::ExprStmt(stmt) = stmt {
167                     if stmt.expr().as_ref() == Some(to_extract) {
168                         return Some(Anchor::Replace(stmt));
169                     }
170                 }
171                 return Some(Anchor::Before(node));
172             }
173             None
174         })
175     }
176
177     fn syntax(&self) -> &SyntaxNode {
178         match self {
179             Anchor::Before(it) | Anchor::WrapInBlock(it) => it,
180             Anchor::Replace(stmt) => stmt.syntax(),
181         }
182     }
183 }
184
185 #[cfg(test)]
186 mod tests {
187     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
188
189     use super::*;
190
191     #[test]
192     fn test_extract_var_simple() {
193         check_assist(
194             extract_variable,
195             r#"
196 fn foo() {
197     foo($01 + 1$0);
198 }"#,
199             r#"
200 fn foo() {
201     let $0var_name = 1 + 1;
202     foo(var_name);
203 }"#,
204         );
205     }
206
207     #[test]
208     fn extract_var_in_comment_is_not_applicable() {
209         cov_mark::check!(extract_var_in_comment_is_not_applicable);
210         check_assist_not_applicable(extract_variable, "fn main() { 1 + /* $0comment$0 */ 1; }");
211     }
212
213     #[test]
214     fn test_extract_var_expr_stmt() {
215         cov_mark::check!(test_extract_var_expr_stmt);
216         check_assist(
217             extract_variable,
218             r#"
219 fn foo() {
220     $01 + 1$0;
221 }"#,
222             r#"
223 fn foo() {
224     let $0var_name = 1 + 1;
225 }"#,
226         );
227         check_assist(
228             extract_variable,
229             "
230 fn foo() {
231     $0{ let x = 0; x }$0
232     something_else();
233 }",
234             "
235 fn foo() {
236     let $0var_name = { let x = 0; x };
237     something_else();
238 }",
239         );
240     }
241
242     #[test]
243     fn test_extract_var_part_of_expr_stmt() {
244         check_assist(
245             extract_variable,
246             "
247 fn foo() {
248     $01$0 + 1;
249 }",
250             "
251 fn foo() {
252     let $0var_name = 1;
253     var_name + 1;
254 }",
255         );
256     }
257
258     #[test]
259     fn test_extract_var_last_expr() {
260         cov_mark::check!(test_extract_var_last_expr);
261         check_assist(
262             extract_variable,
263             r#"
264 fn foo() {
265     bar($01 + 1$0)
266 }
267 "#,
268             r#"
269 fn foo() {
270     let $0var_name = 1 + 1;
271     bar(var_name)
272 }
273 "#,
274         );
275         check_assist(
276             extract_variable,
277             r#"
278 fn foo() {
279     $0bar(1 + 1)$0
280 }
281 "#,
282             r#"
283 fn foo() {
284     let $0bar = bar(1 + 1);
285     bar
286 }
287 "#,
288         )
289     }
290
291     #[test]
292     fn test_extract_var_in_match_arm_no_block() {
293         cov_mark::check!(test_extract_var_in_match_arm_no_block);
294         check_assist(
295             extract_variable,
296             r#"
297 fn main() {
298     let x = true;
299     let tuple = match x {
300         true => ($02 + 2$0, true)
301         _ => (0, false)
302     };
303 }
304 "#,
305             r#"
306 fn main() {
307     let x = true;
308     let tuple = match x {
309         true => { let $0var_name = 2 + 2; (var_name, true) }
310         _ => (0, false)
311     };
312 }
313 "#,
314         );
315     }
316
317     #[test]
318     fn test_extract_var_in_match_arm_with_block() {
319         check_assist(
320             extract_variable,
321             r#"
322 fn main() {
323     let x = true;
324     let tuple = match x {
325         true => {
326             let y = 1;
327             ($02 + y$0, true)
328         }
329         _ => (0, false)
330     };
331 }
332 "#,
333             r#"
334 fn main() {
335     let x = true;
336     let tuple = match x {
337         true => {
338             let y = 1;
339             let $0var_name = 2 + y;
340             (var_name, true)
341         }
342         _ => (0, false)
343     };
344 }
345 "#,
346         );
347     }
348
349     #[test]
350     fn test_extract_var_in_match_guard() {
351         cov_mark::check!(test_extract_var_in_match_guard);
352         check_assist(
353             extract_variable,
354             r#"
355 fn main() {
356     match () {
357         () if $010 > 0$0 => 1
358         _ => 2
359     };
360 }
361 "#,
362             r#"
363 fn main() {
364     let $0var_name = 10 > 0;
365     match () {
366         () if var_name => 1
367         _ => 2
368     };
369 }
370 "#,
371         );
372     }
373
374     #[test]
375     fn test_extract_var_in_closure_no_block() {
376         cov_mark::check!(test_extract_var_in_closure_no_block);
377         check_assist(
378             extract_variable,
379             r#"
380 fn main() {
381     let lambda = |x: u32| $0x * 2$0;
382 }
383 "#,
384             r#"
385 fn main() {
386     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
387 }
388 "#,
389         );
390     }
391
392     #[test]
393     fn test_extract_var_in_closure_with_block() {
394         check_assist(
395             extract_variable,
396             r#"
397 fn main() {
398     let lambda = |x: u32| { $0x * 2$0 };
399 }
400 "#,
401             r#"
402 fn main() {
403     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
404 }
405 "#,
406         );
407     }
408
409     #[test]
410     fn test_extract_var_path_simple() {
411         check_assist(
412             extract_variable,
413             "
414 fn main() {
415     let o = $0Some(true)$0;
416 }
417 ",
418             "
419 fn main() {
420     let $0var_name = Some(true);
421     let o = var_name;
422 }
423 ",
424         );
425     }
426
427     #[test]
428     fn test_extract_var_path_method() {
429         check_assist(
430             extract_variable,
431             "
432 fn main() {
433     let v = $0bar.foo()$0;
434 }
435 ",
436             "
437 fn main() {
438     let $0foo = bar.foo();
439     let v = foo;
440 }
441 ",
442         );
443     }
444
445     #[test]
446     fn test_extract_var_return() {
447         check_assist(
448             extract_variable,
449             "
450 fn foo() -> u32 {
451     $0return 2 + 2$0;
452 }
453 ",
454             "
455 fn foo() -> u32 {
456     let $0var_name = 2 + 2;
457     return var_name;
458 }
459 ",
460         );
461     }
462
463     #[test]
464     fn test_extract_var_does_not_add_extra_whitespace() {
465         check_assist(
466             extract_variable,
467             "
468 fn foo() -> u32 {
469
470
471     $0return 2 + 2$0;
472 }
473 ",
474             "
475 fn foo() -> u32 {
476
477
478     let $0var_name = 2 + 2;
479     return var_name;
480 }
481 ",
482         );
483
484         check_assist(
485             extract_variable,
486             "
487 fn foo() -> u32 {
488
489         $0return 2 + 2$0;
490 }
491 ",
492             "
493 fn foo() -> u32 {
494
495         let $0var_name = 2 + 2;
496         return var_name;
497 }
498 ",
499         );
500
501         check_assist(
502             extract_variable,
503             "
504 fn foo() -> u32 {
505     let foo = 1;
506
507     // bar
508
509
510     $0return 2 + 2$0;
511 }
512 ",
513             "
514 fn foo() -> u32 {
515     let foo = 1;
516
517     // bar
518
519
520     let $0var_name = 2 + 2;
521     return var_name;
522 }
523 ",
524         );
525     }
526
527     #[test]
528     fn test_extract_var_break() {
529         check_assist(
530             extract_variable,
531             "
532 fn main() {
533     let result = loop {
534         $0break 2 + 2$0;
535     };
536 }
537 ",
538             "
539 fn main() {
540     let result = loop {
541         let $0var_name = 2 + 2;
542         break var_name;
543     };
544 }
545 ",
546         );
547     }
548
549     #[test]
550     fn test_extract_var_for_cast() {
551         check_assist(
552             extract_variable,
553             "
554 fn main() {
555     let v = $00f32 as u32$0;
556 }
557 ",
558             "
559 fn main() {
560     let $0var_name = 0f32 as u32;
561     let v = var_name;
562 }
563 ",
564         );
565     }
566
567     #[test]
568     fn extract_var_field_shorthand() {
569         check_assist(
570             extract_variable,
571             r#"
572 struct S {
573     foo: i32
574 }
575
576 fn main() {
577     S { foo: $01 + 1$0 }
578 }
579 "#,
580             r#"
581 struct S {
582     foo: i32
583 }
584
585 fn main() {
586     let $0foo = 1 + 1;
587     S { foo }
588 }
589 "#,
590         )
591     }
592
593     #[test]
594     fn extract_var_name_from_type() {
595         check_assist(
596             extract_variable,
597             r#"
598 struct Test(i32);
599
600 fn foo() -> Test {
601     $0{ Test(10) }$0
602 }
603 "#,
604             r#"
605 struct Test(i32);
606
607 fn foo() -> Test {
608     let $0test = { Test(10) };
609     test
610 }
611 "#,
612         )
613     }
614
615     #[test]
616     fn extract_var_name_from_parameter() {
617         check_assist(
618             extract_variable,
619             r#"
620 fn bar(test: u32, size: u32)
621
622 fn foo() {
623     bar(1, $01+1$0);
624 }
625 "#,
626             r#"
627 fn bar(test: u32, size: u32)
628
629 fn foo() {
630     let $0size = 1+1;
631     bar(1, size);
632 }
633 "#,
634         )
635     }
636
637     #[test]
638     fn extract_var_parameter_name_has_precedence_over_type() {
639         check_assist(
640             extract_variable,
641             r#"
642 struct TextSize(u32);
643 fn bar(test: u32, size: TextSize)
644
645 fn foo() {
646     bar(1, $0{ TextSize(1+1) }$0);
647 }
648 "#,
649             r#"
650 struct TextSize(u32);
651 fn bar(test: u32, size: TextSize)
652
653 fn foo() {
654     let $0size = { TextSize(1+1) };
655     bar(1, size);
656 }
657 "#,
658         )
659     }
660
661     #[test]
662     fn extract_var_name_from_function() {
663         check_assist(
664             extract_variable,
665             r#"
666 fn is_required(test: u32, size: u32) -> bool
667
668 fn foo() -> bool {
669     $0is_required(1, 2)$0
670 }
671 "#,
672             r#"
673 fn is_required(test: u32, size: u32) -> bool
674
675 fn foo() -> bool {
676     let $0is_required = is_required(1, 2);
677     is_required
678 }
679 "#,
680         )
681     }
682
683     #[test]
684     fn extract_var_name_from_method() {
685         check_assist(
686             extract_variable,
687             r#"
688 struct S;
689 impl S {
690     fn bar(&self, n: u32) -> u32 { n }
691 }
692
693 fn foo() -> u32 {
694     $0S.bar(1)$0
695 }
696 "#,
697             r#"
698 struct S;
699 impl S {
700     fn bar(&self, n: u32) -> u32 { n }
701 }
702
703 fn foo() -> u32 {
704     let $0bar = S.bar(1);
705     bar
706 }
707 "#,
708         )
709     }
710
711     #[test]
712     fn extract_var_name_from_method_param() {
713         check_assist(
714             extract_variable,
715             r#"
716 struct S;
717 impl S {
718     fn bar(&self, n: u32, size: u32) { n }
719 }
720
721 fn foo() {
722     S.bar($01 + 1$0, 2)
723 }
724 "#,
725             r#"
726 struct S;
727 impl S {
728     fn bar(&self, n: u32, size: u32) { n }
729 }
730
731 fn foo() {
732     let $0n = 1 + 1;
733     S.bar(n, 2)
734 }
735 "#,
736         )
737     }
738
739     #[test]
740     fn extract_var_name_from_ufcs_method_param() {
741         check_assist(
742             extract_variable,
743             r#"
744 struct S;
745 impl S {
746     fn bar(&self, n: u32, size: u32) { n }
747 }
748
749 fn foo() {
750     S::bar(&S, $01 + 1$0, 2)
751 }
752 "#,
753             r#"
754 struct S;
755 impl S {
756     fn bar(&self, n: u32, size: u32) { n }
757 }
758
759 fn foo() {
760     let $0n = 1 + 1;
761     S::bar(&S, n, 2)
762 }
763 "#,
764         )
765     }
766
767     #[test]
768     fn extract_var_parameter_name_has_precedence_over_function() {
769         check_assist(
770             extract_variable,
771             r#"
772 fn bar(test: u32, size: u32)
773
774 fn foo() {
775     bar(1, $0symbol_size(1, 2)$0);
776 }
777 "#,
778             r#"
779 fn bar(test: u32, size: u32)
780
781 fn foo() {
782     let $0size = symbol_size(1, 2);
783     bar(1, size);
784 }
785 "#,
786         )
787     }
788
789     #[test]
790     fn test_extract_var_for_return_not_applicable() {
791         check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } ");
792     }
793
794     #[test]
795     fn test_extract_var_for_break_not_applicable() {
796         check_assist_not_applicable(extract_variable, "fn main() { loop { $0break$0; }; }");
797     }
798
799     // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic
800     #[test]
801     fn extract_var_target() {
802         check_assist_target(extract_variable, "fn foo() -> u32 { $0return 2 + 2$0; }", "2 + 2");
803
804         check_assist_target(
805             extract_variable,
806             "
807 fn main() {
808     let x = true;
809     let tuple = match x {
810         true => ($02 + 2$0, true)
811         _ => (0, false)
812     };
813 }
814 ",
815             "2 + 2",
816         );
817     }
818 }