]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_variable.rs
Merge #10596
[rust.git] / crates / ide_assists / src / handlers / extract_variable.rs
1 use stdx::format_to;
2 use syntax::{
3     ast::{self, AstNode},
4     NodeOrToken,
5     SyntaxKind::{
6         BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD,
7         PATH_EXPR, RETURN_EXPR,
8     },
9     SyntaxNode,
10 };
11
12 use crate::{utils::suggest_name, AssistContext, AssistId, AssistKind, Assists};
13
14 // Assist: extract_variable
15 //
16 // Extracts subexpression into a variable.
17 //
18 // ```
19 // fn main() {
20 //     $0(1 + 2)$0 * 4;
21 // }
22 // ```
23 // ->
24 // ```
25 // fn main() {
26 //     let $0var_name = (1 + 2);
27 //     var_name * 4;
28 // }
29 // ```
30 pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
31     if ctx.has_empty_selection() {
32         return None;
33     }
34
35     let node = match ctx.covering_element() {
36         NodeOrToken::Node(it) => it,
37         NodeOrToken::Token(it) if it.kind() == COMMENT => {
38             cov_mark::hit!(extract_var_in_comment_is_not_applicable);
39             return None;
40         }
41         NodeOrToken::Token(it) => it.parent()?,
42     };
43     let node = node.ancestors().take_while(|anc| anc.text_range() == node.text_range()).last()?;
44     let to_extract = node
45         .descendants()
46         .take_while(|it| ctx.selection_trimmed().contains_range(it.text_range()))
47         .find_map(valid_target_expr)?;
48
49     if let Some(ty_info) = ctx.sema.type_of_expr(&to_extract) {
50         if ty_info.adjusted().is_unit() {
51             return None;
52         }
53     }
54
55     let anchor = Anchor::from(&to_extract)?;
56     let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
57     let target = to_extract.syntax().text_range();
58     acc.add(
59         AssistId("extract_variable", AssistKind::RefactorExtract),
60         "Extract into variable",
61         target,
62         move |edit| {
63             let field_shorthand =
64                 match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) {
65                     Some(field) => field.name_ref(),
66                     None => None,
67                 };
68
69             let mut buf = String::new();
70
71             let var_name = match &field_shorthand {
72                 Some(it) => it.to_string(),
73                 None => suggest_name::for_variable(&to_extract, &ctx.sema),
74             };
75             let expr_range = match &field_shorthand {
76                 Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()),
77                 None => to_extract.syntax().text_range(),
78             };
79
80             match anchor {
81                 Anchor::Before(_) | Anchor::Replace(_) => {
82                     format_to!(buf, "let {} = ", var_name)
83                 }
84                 Anchor::WrapInBlock(_) => format_to!(buf, "{{ let {} = ", var_name),
85             };
86             format_to!(buf, "{}", to_extract.syntax());
87
88             if let Anchor::Replace(stmt) = anchor {
89                 cov_mark::hit!(test_extract_var_expr_stmt);
90                 if stmt.semicolon_token().is_none() {
91                     buf.push(';');
92                 }
93                 match ctx.config.snippet_cap {
94                     Some(cap) => {
95                         let snip = buf
96                             .replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
97                         edit.replace_snippet(cap, expr_range, snip)
98                     }
99                     None => edit.replace(expr_range, buf),
100                 }
101                 return;
102             }
103
104             buf.push(';');
105
106             // We want to maintain the indent level,
107             // but we do not want to duplicate possible
108             // extra newlines in the indent block
109             let text = indent.text();
110             if text.starts_with('\n') {
111                 buf.push('\n');
112                 buf.push_str(text.trim_start_matches('\n'));
113             } else {
114                 buf.push_str(text);
115             }
116
117             edit.replace(expr_range, var_name.clone());
118             let offset = anchor.syntax().text_range().start();
119             match ctx.config.snippet_cap {
120                 Some(cap) => {
121                     let snip =
122                         buf.replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
123                     edit.insert_snippet(cap, offset, snip)
124                 }
125                 None => edit.insert(offset, buf),
126             }
127
128             if let Anchor::WrapInBlock(_) = anchor {
129                 edit.insert(anchor.syntax().text_range().end(), " }");
130             }
131         },
132     )
133 }
134
135 /// Check whether the node is a valid expression which can be extracted to a variable.
136 /// In general that's true for any expression, but in some cases that would produce invalid code.
137 fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
138     match node.kind() {
139         PATH_EXPR | LOOP_EXPR => None,
140         BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()),
141         RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
142         BLOCK_EXPR => {
143             ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from)
144         }
145         _ => ast::Expr::cast(node),
146     }
147 }
148
149 #[derive(Debug)]
150 enum Anchor {
151     Before(SyntaxNode),
152     Replace(ast::ExprStmt),
153     WrapInBlock(SyntaxNode),
154 }
155
156 impl Anchor {
157     fn from(to_extract: &ast::Expr) -> Option<Anchor> {
158         to_extract
159             .syntax()
160             .ancestors()
161             .take_while(|it| !ast::Item::can_cast(it.kind()) || ast::MacroCall::can_cast(it.kind()))
162             .find_map(|node| {
163                 if ast::MacroCall::can_cast(node.kind()) {
164                     return None;
165                 }
166                 if let Some(expr) =
167                     node.parent().and_then(ast::StmtList::cast).and_then(|it| it.tail_expr())
168                 {
169                     if expr.syntax() == &node {
170                         cov_mark::hit!(test_extract_var_last_expr);
171                         return Some(Anchor::Before(node));
172                     }
173                 }
174
175                 if let Some(parent) = node.parent() {
176                     if parent.kind() == CLOSURE_EXPR {
177                         cov_mark::hit!(test_extract_var_in_closure_no_block);
178                         return Some(Anchor::WrapInBlock(node));
179                     }
180                     if parent.kind() == MATCH_ARM {
181                         if node.kind() == MATCH_GUARD {
182                             cov_mark::hit!(test_extract_var_in_match_guard);
183                         } else {
184                             cov_mark::hit!(test_extract_var_in_match_arm_no_block);
185                             return Some(Anchor::WrapInBlock(node));
186                         }
187                     }
188                 }
189
190                 if let Some(stmt) = ast::Stmt::cast(node.clone()) {
191                     if let ast::Stmt::ExprStmt(stmt) = stmt {
192                         if stmt.expr().as_ref() == Some(to_extract) {
193                             return Some(Anchor::Replace(stmt));
194                         }
195                     }
196                     return Some(Anchor::Before(node));
197                 }
198                 None
199             })
200     }
201
202     fn syntax(&self) -> &SyntaxNode {
203         match self {
204             Anchor::Before(it) | Anchor::WrapInBlock(it) => it,
205             Anchor::Replace(stmt) => stmt.syntax(),
206         }
207     }
208 }
209
210 #[cfg(test)]
211 mod tests {
212     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
213
214     use super::*;
215
216     #[test]
217     fn test_extract_var_simple() {
218         check_assist(
219             extract_variable,
220             r#"
221 fn foo() {
222     foo($01 + 1$0);
223 }"#,
224             r#"
225 fn foo() {
226     let $0var_name = 1 + 1;
227     foo(var_name);
228 }"#,
229         );
230     }
231
232     #[test]
233     fn extract_var_in_comment_is_not_applicable() {
234         cov_mark::check!(extract_var_in_comment_is_not_applicable);
235         check_assist_not_applicable(extract_variable, "fn main() { 1 + /* $0comment$0 */ 1; }");
236     }
237
238     #[test]
239     fn test_extract_var_expr_stmt() {
240         cov_mark::check!(test_extract_var_expr_stmt);
241         check_assist(
242             extract_variable,
243             r#"
244 fn foo() {
245   $0  1 + 1$0;
246 }"#,
247             r#"
248 fn foo() {
249     let $0var_name = 1 + 1;
250 }"#,
251         );
252         check_assist(
253             extract_variable,
254             r"
255 fn foo() {
256     $0{ let x = 0; x }$0
257     something_else();
258 }",
259             r"
260 fn foo() {
261     let $0var_name = { let x = 0; x };
262     something_else();
263 }",
264         );
265     }
266
267     #[test]
268     fn test_extract_var_part_of_expr_stmt() {
269         check_assist(
270             extract_variable,
271             r"
272 fn foo() {
273     $01$0 + 1;
274 }",
275             r"
276 fn foo() {
277     let $0var_name = 1;
278     var_name + 1;
279 }",
280         );
281     }
282
283     #[test]
284     fn test_extract_var_last_expr() {
285         cov_mark::check!(test_extract_var_last_expr);
286         check_assist(
287             extract_variable,
288             r#"
289 fn foo() {
290     bar($01 + 1$0)
291 }
292 "#,
293             r#"
294 fn foo() {
295     let $0var_name = 1 + 1;
296     bar(var_name)
297 }
298 "#,
299         );
300         check_assist(
301             extract_variable,
302             r#"
303 fn foo() -> i32 {
304     $0bar(1 + 1)$0
305 }
306
307 fn bar(i: i32) -> i32 {
308     i
309 }
310 "#,
311             r#"
312 fn foo() -> i32 {
313     let $0bar = bar(1 + 1);
314     bar
315 }
316
317 fn bar(i: i32) -> i32 {
318     i
319 }
320 "#,
321         )
322     }
323
324     #[test]
325     fn test_extract_var_in_match_arm_no_block() {
326         cov_mark::check!(test_extract_var_in_match_arm_no_block);
327         check_assist(
328             extract_variable,
329             r#"
330 fn main() {
331     let x = true;
332     let tuple = match x {
333         true => ($02 + 2$0, true)
334         _ => (0, false)
335     };
336 }
337 "#,
338             r#"
339 fn main() {
340     let x = true;
341     let tuple = match x {
342         true => { let $0var_name = 2 + 2; (var_name, true) }
343         _ => (0, false)
344     };
345 }
346 "#,
347         );
348     }
349
350     #[test]
351     fn test_extract_var_in_match_arm_with_block() {
352         check_assist(
353             extract_variable,
354             r#"
355 fn main() {
356     let x = true;
357     let tuple = match x {
358         true => {
359             let y = 1;
360             ($02 + y$0, true)
361         }
362         _ => (0, false)
363     };
364 }
365 "#,
366             r#"
367 fn main() {
368     let x = true;
369     let tuple = match x {
370         true => {
371             let y = 1;
372             let $0var_name = 2 + y;
373             (var_name, true)
374         }
375         _ => (0, false)
376     };
377 }
378 "#,
379         );
380     }
381
382     #[test]
383     fn test_extract_var_in_match_guard() {
384         cov_mark::check!(test_extract_var_in_match_guard);
385         check_assist(
386             extract_variable,
387             r#"
388 fn main() {
389     match () {
390         () if $010 > 0$0 => 1
391         _ => 2
392     };
393 }
394 "#,
395             r#"
396 fn main() {
397     let $0var_name = 10 > 0;
398     match () {
399         () if var_name => 1
400         _ => 2
401     };
402 }
403 "#,
404         );
405     }
406
407     #[test]
408     fn test_extract_var_in_closure_no_block() {
409         cov_mark::check!(test_extract_var_in_closure_no_block);
410         check_assist(
411             extract_variable,
412             r#"
413 fn main() {
414     let lambda = |x: u32| $0x * 2$0;
415 }
416 "#,
417             r#"
418 fn main() {
419     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
420 }
421 "#,
422         );
423     }
424
425     #[test]
426     fn test_extract_var_in_closure_with_block() {
427         check_assist(
428             extract_variable,
429             r#"
430 fn main() {
431     let lambda = |x: u32| { $0x * 2$0 };
432 }
433 "#,
434             r#"
435 fn main() {
436     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
437 }
438 "#,
439         );
440     }
441
442     #[test]
443     fn test_extract_var_path_simple() {
444         check_assist(
445             extract_variable,
446             "
447 fn main() {
448     let o = $0Some(true)$0;
449 }
450 ",
451             "
452 fn main() {
453     let $0var_name = Some(true);
454     let o = var_name;
455 }
456 ",
457         );
458     }
459
460     #[test]
461     fn test_extract_var_path_method() {
462         check_assist(
463             extract_variable,
464             "
465 fn main() {
466     let v = $0bar.foo()$0;
467 }
468 ",
469             "
470 fn main() {
471     let $0foo = bar.foo();
472     let v = foo;
473 }
474 ",
475         );
476     }
477
478     #[test]
479     fn test_extract_var_return() {
480         check_assist(
481             extract_variable,
482             "
483 fn foo() -> u32 {
484     $0return 2 + 2$0;
485 }
486 ",
487             "
488 fn foo() -> u32 {
489     let $0var_name = 2 + 2;
490     return var_name;
491 }
492 ",
493         );
494     }
495
496     #[test]
497     fn test_extract_var_does_not_add_extra_whitespace() {
498         check_assist(
499             extract_variable,
500             "
501 fn foo() -> u32 {
502
503
504     $0return 2 + 2$0;
505 }
506 ",
507             "
508 fn foo() -> u32 {
509
510
511     let $0var_name = 2 + 2;
512     return var_name;
513 }
514 ",
515         );
516
517         check_assist(
518             extract_variable,
519             "
520 fn foo() -> u32 {
521
522         $0return 2 + 2$0;
523 }
524 ",
525             "
526 fn foo() -> u32 {
527
528         let $0var_name = 2 + 2;
529         return var_name;
530 }
531 ",
532         );
533
534         check_assist(
535             extract_variable,
536             "
537 fn foo() -> u32 {
538     let foo = 1;
539
540     // bar
541
542
543     $0return 2 + 2$0;
544 }
545 ",
546             "
547 fn foo() -> u32 {
548     let foo = 1;
549
550     // bar
551
552
553     let $0var_name = 2 + 2;
554     return var_name;
555 }
556 ",
557         );
558     }
559
560     #[test]
561     fn test_extract_var_break() {
562         check_assist(
563             extract_variable,
564             "
565 fn main() {
566     let result = loop {
567         $0break 2 + 2$0;
568     };
569 }
570 ",
571             "
572 fn main() {
573     let result = loop {
574         let $0var_name = 2 + 2;
575         break var_name;
576     };
577 }
578 ",
579         );
580     }
581
582     #[test]
583     fn test_extract_var_for_cast() {
584         check_assist(
585             extract_variable,
586             "
587 fn main() {
588     let v = $00f32 as u32$0;
589 }
590 ",
591             "
592 fn main() {
593     let $0var_name = 0f32 as u32;
594     let v = var_name;
595 }
596 ",
597         );
598     }
599
600     #[test]
601     fn extract_var_field_shorthand() {
602         check_assist(
603             extract_variable,
604             r#"
605 struct S {
606     foo: i32
607 }
608
609 fn main() {
610     S { foo: $01 + 1$0 }
611 }
612 "#,
613             r#"
614 struct S {
615     foo: i32
616 }
617
618 fn main() {
619     let $0foo = 1 + 1;
620     S { foo }
621 }
622 "#,
623         )
624     }
625
626     #[test]
627     fn extract_var_name_from_type() {
628         check_assist(
629             extract_variable,
630             r#"
631 struct Test(i32);
632
633 fn foo() -> Test {
634     $0{ Test(10) }$0
635 }
636 "#,
637             r#"
638 struct Test(i32);
639
640 fn foo() -> Test {
641     let $0test = { Test(10) };
642     test
643 }
644 "#,
645         )
646     }
647
648     #[test]
649     fn extract_var_name_from_parameter() {
650         check_assist(
651             extract_variable,
652             r#"
653 fn bar(test: u32, size: u32)
654
655 fn foo() {
656     bar(1, $01+1$0);
657 }
658 "#,
659             r#"
660 fn bar(test: u32, size: u32)
661
662 fn foo() {
663     let $0size = 1+1;
664     bar(1, size);
665 }
666 "#,
667         )
668     }
669
670     #[test]
671     fn extract_var_parameter_name_has_precedence_over_type() {
672         check_assist(
673             extract_variable,
674             r#"
675 struct TextSize(u32);
676 fn bar(test: u32, size: TextSize)
677
678 fn foo() {
679     bar(1, $0{ TextSize(1+1) }$0);
680 }
681 "#,
682             r#"
683 struct TextSize(u32);
684 fn bar(test: u32, size: TextSize)
685
686 fn foo() {
687     let $0size = { TextSize(1+1) };
688     bar(1, size);
689 }
690 "#,
691         )
692     }
693
694     #[test]
695     fn extract_var_name_from_function() {
696         check_assist(
697             extract_variable,
698             r#"
699 fn is_required(test: u32, size: u32) -> bool
700
701 fn foo() -> bool {
702     $0is_required(1, 2)$0
703 }
704 "#,
705             r#"
706 fn is_required(test: u32, size: u32) -> bool
707
708 fn foo() -> bool {
709     let $0is_required = is_required(1, 2);
710     is_required
711 }
712 "#,
713         )
714     }
715
716     #[test]
717     fn extract_var_name_from_method() {
718         check_assist(
719             extract_variable,
720             r#"
721 struct S;
722 impl S {
723     fn bar(&self, n: u32) -> u32 { n }
724 }
725
726 fn foo() -> u32 {
727     $0S.bar(1)$0
728 }
729 "#,
730             r#"
731 struct S;
732 impl S {
733     fn bar(&self, n: u32) -> u32 { n }
734 }
735
736 fn foo() -> u32 {
737     let $0bar = S.bar(1);
738     bar
739 }
740 "#,
741         )
742     }
743
744     #[test]
745     fn extract_var_name_from_method_param() {
746         check_assist(
747             extract_variable,
748             r#"
749 struct S;
750 impl S {
751     fn bar(&self, n: u32, size: u32) { n }
752 }
753
754 fn foo() {
755     S.bar($01 + 1$0, 2)
756 }
757 "#,
758             r#"
759 struct S;
760 impl S {
761     fn bar(&self, n: u32, size: u32) { n }
762 }
763
764 fn foo() {
765     let $0n = 1 + 1;
766     S.bar(n, 2)
767 }
768 "#,
769         )
770     }
771
772     #[test]
773     fn extract_var_name_from_ufcs_method_param() {
774         check_assist(
775             extract_variable,
776             r#"
777 struct S;
778 impl S {
779     fn bar(&self, n: u32, size: u32) { n }
780 }
781
782 fn foo() {
783     S::bar(&S, $01 + 1$0, 2)
784 }
785 "#,
786             r#"
787 struct S;
788 impl S {
789     fn bar(&self, n: u32, size: u32) { n }
790 }
791
792 fn foo() {
793     let $0n = 1 + 1;
794     S::bar(&S, n, 2)
795 }
796 "#,
797         )
798     }
799
800     #[test]
801     fn extract_var_parameter_name_has_precedence_over_function() {
802         check_assist(
803             extract_variable,
804             r#"
805 fn bar(test: u32, size: u32)
806
807 fn foo() {
808     bar(1, $0symbol_size(1, 2)$0);
809 }
810 "#,
811             r#"
812 fn bar(test: u32, size: u32)
813
814 fn foo() {
815     let $0size = symbol_size(1, 2);
816     bar(1, size);
817 }
818 "#,
819         )
820     }
821
822     #[test]
823     fn extract_macro_call() {
824         check_assist(
825             extract_variable,
826             r"
827 struct Vec;
828 macro_rules! vec {
829     () => {Vec}
830 }
831 fn main() {
832     let _ = $0vec![]$0;
833 }
834 ",
835             r"
836 struct Vec;
837 macro_rules! vec {
838     () => {Vec}
839 }
840 fn main() {
841     let $0vec = vec![];
842     let _ = vec;
843 }
844 ",
845         );
846     }
847
848     #[test]
849     fn test_extract_var_for_return_not_applicable() {
850         check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } ");
851     }
852
853     #[test]
854     fn test_extract_var_for_break_not_applicable() {
855         check_assist_not_applicable(extract_variable, "fn main() { loop { $0break$0; }; }");
856     }
857
858     #[test]
859     fn test_extract_var_unit_expr_not_applicable() {
860         check_assist_not_applicable(
861             extract_variable,
862             r#"
863 fn foo() {
864     let mut i = 3;
865     $0if i >= 0 {
866         i += 1;
867     } else {
868         i -= 1;
869     }$0
870 }"#,
871         );
872     }
873
874     // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic
875     #[test]
876     fn extract_var_target() {
877         check_assist_target(extract_variable, "fn foo() -> u32 { $0return 2 + 2$0; }", "2 + 2");
878
879         check_assist_target(
880             extract_variable,
881             "
882 fn main() {
883     let x = true;
884     let tuple = match x {
885         true => ($02 + 2$0, true)
886         _ => (0, false)
887     };
888 }
889 ",
890             "2 + 2",
891         );
892     }
893
894     #[test]
895     fn extract_var_no_block_body() {
896         check_assist_not_applicable(
897             extract_variable,
898             r"
899 const X: usize = $0100$0;
900 ",
901         );
902     }
903 }