]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_variable.rs
Merge #10423
[rust.git] / crates / ide_assists / src / handlers / extract_variable.rs
1 use stdx::format_to;
2 use syntax::{
3     ast::{self, AstNode},
4     NodeOrToken,
5     SyntaxKind::{
6         BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD,
7         PATH_EXPR, RETURN_EXPR,
8     },
9     SyntaxNode,
10 };
11
12 use crate::{utils::suggest_name, AssistContext, AssistId, AssistKind, Assists};
13
14 // Assist: extract_variable
15 //
16 // Extracts subexpression into a variable.
17 //
18 // ```
19 // fn main() {
20 //     $0(1 + 2)$0 * 4;
21 // }
22 // ```
23 // ->
24 // ```
25 // fn main() {
26 //     let $0var_name = (1 + 2);
27 //     var_name * 4;
28 // }
29 // ```
30 pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
31     if ctx.frange.range.is_empty() {
32         return None;
33     }
34
35     let node = match ctx.covering_element() {
36         NodeOrToken::Node(it) => it,
37         NodeOrToken::Token(it) if it.kind() == COMMENT => {
38             cov_mark::hit!(extract_var_in_comment_is_not_applicable);
39             return None;
40         }
41         NodeOrToken::Token(it) => it.parent()?,
42     };
43     let node = node.ancestors().take_while(|anc| anc.text_range() == node.text_range()).last()?;
44     let to_extract = node
45         .descendants()
46         .take_while(|it| ctx.frange.range.contains_range(it.text_range()))
47         .find_map(valid_target_expr)?;
48
49     if let Some(ty_info) = ctx.sema.type_of_expr(&to_extract) {
50         if ty_info.adjusted().is_unit() {
51             return None;
52         }
53     }
54
55     let anchor = Anchor::from(&to_extract)?;
56     let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
57     let target = to_extract.syntax().text_range();
58     acc.add(
59         AssistId("extract_variable", AssistKind::RefactorExtract),
60         "Extract into variable",
61         target,
62         move |edit| {
63             let field_shorthand =
64                 match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) {
65                     Some(field) => field.name_ref(),
66                     None => None,
67                 };
68
69             let mut buf = String::new();
70
71             let var_name = match &field_shorthand {
72                 Some(it) => it.to_string(),
73                 None => suggest_name::for_variable(&to_extract, &ctx.sema),
74             };
75             let expr_range = match &field_shorthand {
76                 Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()),
77                 None => to_extract.syntax().text_range(),
78             };
79
80             match anchor {
81                 Anchor::Before(_) | Anchor::Replace(_) => {
82                     format_to!(buf, "let {} = ", var_name)
83                 }
84                 Anchor::WrapInBlock(_) => format_to!(buf, "{{ let {} = ", var_name),
85             };
86             format_to!(buf, "{}", to_extract.syntax());
87
88             if let Anchor::Replace(stmt) = anchor {
89                 cov_mark::hit!(test_extract_var_expr_stmt);
90                 if stmt.semicolon_token().is_none() {
91                     buf.push(';');
92                 }
93                 match ctx.config.snippet_cap {
94                     Some(cap) => {
95                         let snip = buf
96                             .replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
97                         edit.replace_snippet(cap, expr_range, snip)
98                     }
99                     None => edit.replace(expr_range, buf),
100                 }
101                 return;
102             }
103
104             buf.push(';');
105
106             // We want to maintain the indent level,
107             // but we do not want to duplicate possible
108             // extra newlines in the indent block
109             let text = indent.text();
110             if text.starts_with('\n') {
111                 buf.push('\n');
112                 buf.push_str(text.trim_start_matches('\n'));
113             } else {
114                 buf.push_str(text);
115             }
116
117             edit.replace(expr_range, var_name.clone());
118             let offset = anchor.syntax().text_range().start();
119             match ctx.config.snippet_cap {
120                 Some(cap) => {
121                     let snip =
122                         buf.replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
123                     edit.insert_snippet(cap, offset, snip)
124                 }
125                 None => edit.insert(offset, buf),
126             }
127
128             if let Anchor::WrapInBlock(_) = anchor {
129                 edit.insert(anchor.syntax().text_range().end(), " }");
130             }
131         },
132     )
133 }
134
135 /// Check whether the node is a valid expression which can be extracted to a variable.
136 /// In general that's true for any expression, but in some cases that would produce invalid code.
137 fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
138     match node.kind() {
139         PATH_EXPR | LOOP_EXPR => None,
140         BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()),
141         RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
142         BLOCK_EXPR => {
143             ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from)
144         }
145         _ => ast::Expr::cast(node),
146     }
147 }
148
149 #[derive(Debug)]
150 enum Anchor {
151     Before(SyntaxNode),
152     Replace(ast::ExprStmt),
153     WrapInBlock(SyntaxNode),
154 }
155
156 impl Anchor {
157     fn from(to_extract: &ast::Expr) -> Option<Anchor> {
158         to_extract
159             .syntax()
160             .ancestors()
161             .take_while(|it| !ast::Item::can_cast(it.kind()) || ast::MacroCall::can_cast(it.kind()))
162             .find_map(|node| {
163                 if let Some(expr) =
164                     node.parent().and_then(ast::StmtList::cast).and_then(|it| it.tail_expr())
165                 {
166                     if expr.syntax() == &node {
167                         cov_mark::hit!(test_extract_var_last_expr);
168                         return Some(Anchor::Before(node));
169                     }
170                 }
171
172                 if let Some(parent) = node.parent() {
173                     if parent.kind() == CLOSURE_EXPR {
174                         cov_mark::hit!(test_extract_var_in_closure_no_block);
175                         return Some(Anchor::WrapInBlock(node));
176                     }
177                     if parent.kind() == MATCH_ARM {
178                         if node.kind() == MATCH_GUARD {
179                             cov_mark::hit!(test_extract_var_in_match_guard);
180                         } else {
181                             cov_mark::hit!(test_extract_var_in_match_arm_no_block);
182                             return Some(Anchor::WrapInBlock(node));
183                         }
184                     }
185                 }
186
187                 if let Some(stmt) = ast::Stmt::cast(node.clone()) {
188                     if let ast::Stmt::ExprStmt(stmt) = stmt {
189                         if stmt.expr().as_ref() == Some(to_extract) {
190                             return Some(Anchor::Replace(stmt));
191                         }
192                     }
193                     return Some(Anchor::Before(node));
194                 }
195                 None
196             })
197     }
198
199     fn syntax(&self) -> &SyntaxNode {
200         match self {
201             Anchor::Before(it) | Anchor::WrapInBlock(it) => it,
202             Anchor::Replace(stmt) => stmt.syntax(),
203         }
204     }
205 }
206
207 #[cfg(test)]
208 mod tests {
209     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
210
211     use super::*;
212
213     #[test]
214     fn test_extract_var_simple() {
215         check_assist(
216             extract_variable,
217             r#"
218 fn foo() {
219     foo($01 + 1$0);
220 }"#,
221             r#"
222 fn foo() {
223     let $0var_name = 1 + 1;
224     foo(var_name);
225 }"#,
226         );
227     }
228
229     #[test]
230     fn extract_var_in_comment_is_not_applicable() {
231         cov_mark::check!(extract_var_in_comment_is_not_applicable);
232         check_assist_not_applicable(extract_variable, "fn main() { 1 + /* $0comment$0 */ 1; }");
233     }
234
235     #[test]
236     fn test_extract_var_expr_stmt() {
237         cov_mark::check!(test_extract_var_expr_stmt);
238         check_assist(
239             extract_variable,
240             r#"
241 fn foo() {
242   $0  1 + 1$0;
243 }"#,
244             r#"
245 fn foo() {
246     let $0var_name = 1 + 1;
247 }"#,
248         );
249         check_assist(
250             extract_variable,
251             r"
252 fn foo() {
253     $0{ let x = 0; x }$0
254     something_else();
255 }",
256             r"
257 fn foo() {
258     let $0var_name = { let x = 0; x };
259     something_else();
260 }",
261         );
262     }
263
264     #[test]
265     fn test_extract_var_part_of_expr_stmt() {
266         check_assist(
267             extract_variable,
268             r"
269 fn foo() {
270     $01$0 + 1;
271 }",
272             r"
273 fn foo() {
274     let $0var_name = 1;
275     var_name + 1;
276 }",
277         );
278     }
279
280     #[test]
281     fn test_extract_var_last_expr() {
282         cov_mark::check!(test_extract_var_last_expr);
283         check_assist(
284             extract_variable,
285             r#"
286 fn foo() {
287     bar($01 + 1$0)
288 }
289 "#,
290             r#"
291 fn foo() {
292     let $0var_name = 1 + 1;
293     bar(var_name)
294 }
295 "#,
296         );
297         check_assist(
298             extract_variable,
299             r#"
300 fn foo() -> i32 {
301     $0bar(1 + 1)$0
302 }
303
304 fn bar(i: i32) -> i32 {
305     i
306 }
307 "#,
308             r#"
309 fn foo() -> i32 {
310     let $0bar = bar(1 + 1);
311     bar
312 }
313
314 fn bar(i: i32) -> i32 {
315     i
316 }
317 "#,
318         )
319     }
320
321     #[test]
322     fn test_extract_var_in_match_arm_no_block() {
323         cov_mark::check!(test_extract_var_in_match_arm_no_block);
324         check_assist(
325             extract_variable,
326             r#"
327 fn main() {
328     let x = true;
329     let tuple = match x {
330         true => ($02 + 2$0, true)
331         _ => (0, false)
332     };
333 }
334 "#,
335             r#"
336 fn main() {
337     let x = true;
338     let tuple = match x {
339         true => { let $0var_name = 2 + 2; (var_name, true) }
340         _ => (0, false)
341     };
342 }
343 "#,
344         );
345     }
346
347     #[test]
348     fn test_extract_var_in_match_arm_with_block() {
349         check_assist(
350             extract_variable,
351             r#"
352 fn main() {
353     let x = true;
354     let tuple = match x {
355         true => {
356             let y = 1;
357             ($02 + y$0, true)
358         }
359         _ => (0, false)
360     };
361 }
362 "#,
363             r#"
364 fn main() {
365     let x = true;
366     let tuple = match x {
367         true => {
368             let y = 1;
369             let $0var_name = 2 + y;
370             (var_name, true)
371         }
372         _ => (0, false)
373     };
374 }
375 "#,
376         );
377     }
378
379     #[test]
380     fn test_extract_var_in_match_guard() {
381         cov_mark::check!(test_extract_var_in_match_guard);
382         check_assist(
383             extract_variable,
384             r#"
385 fn main() {
386     match () {
387         () if $010 > 0$0 => 1
388         _ => 2
389     };
390 }
391 "#,
392             r#"
393 fn main() {
394     let $0var_name = 10 > 0;
395     match () {
396         () if var_name => 1
397         _ => 2
398     };
399 }
400 "#,
401         );
402     }
403
404     #[test]
405     fn test_extract_var_in_closure_no_block() {
406         cov_mark::check!(test_extract_var_in_closure_no_block);
407         check_assist(
408             extract_variable,
409             r#"
410 fn main() {
411     let lambda = |x: u32| $0x * 2$0;
412 }
413 "#,
414             r#"
415 fn main() {
416     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
417 }
418 "#,
419         );
420     }
421
422     #[test]
423     fn test_extract_var_in_closure_with_block() {
424         check_assist(
425             extract_variable,
426             r#"
427 fn main() {
428     let lambda = |x: u32| { $0x * 2$0 };
429 }
430 "#,
431             r#"
432 fn main() {
433     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
434 }
435 "#,
436         );
437     }
438
439     #[test]
440     fn test_extract_var_path_simple() {
441         check_assist(
442             extract_variable,
443             "
444 fn main() {
445     let o = $0Some(true)$0;
446 }
447 ",
448             "
449 fn main() {
450     let $0var_name = Some(true);
451     let o = var_name;
452 }
453 ",
454         );
455     }
456
457     #[test]
458     fn test_extract_var_path_method() {
459         check_assist(
460             extract_variable,
461             "
462 fn main() {
463     let v = $0bar.foo()$0;
464 }
465 ",
466             "
467 fn main() {
468     let $0foo = bar.foo();
469     let v = foo;
470 }
471 ",
472         );
473     }
474
475     #[test]
476     fn test_extract_var_return() {
477         check_assist(
478             extract_variable,
479             "
480 fn foo() -> u32 {
481     $0return 2 + 2$0;
482 }
483 ",
484             "
485 fn foo() -> u32 {
486     let $0var_name = 2 + 2;
487     return var_name;
488 }
489 ",
490         );
491     }
492
493     #[test]
494     fn test_extract_var_does_not_add_extra_whitespace() {
495         check_assist(
496             extract_variable,
497             "
498 fn foo() -> u32 {
499
500
501     $0return 2 + 2$0;
502 }
503 ",
504             "
505 fn foo() -> u32 {
506
507
508     let $0var_name = 2 + 2;
509     return var_name;
510 }
511 ",
512         );
513
514         check_assist(
515             extract_variable,
516             "
517 fn foo() -> u32 {
518
519         $0return 2 + 2$0;
520 }
521 ",
522             "
523 fn foo() -> u32 {
524
525         let $0var_name = 2 + 2;
526         return var_name;
527 }
528 ",
529         );
530
531         check_assist(
532             extract_variable,
533             "
534 fn foo() -> u32 {
535     let foo = 1;
536
537     // bar
538
539
540     $0return 2 + 2$0;
541 }
542 ",
543             "
544 fn foo() -> u32 {
545     let foo = 1;
546
547     // bar
548
549
550     let $0var_name = 2 + 2;
551     return var_name;
552 }
553 ",
554         );
555     }
556
557     #[test]
558     fn test_extract_var_break() {
559         check_assist(
560             extract_variable,
561             "
562 fn main() {
563     let result = loop {
564         $0break 2 + 2$0;
565     };
566 }
567 ",
568             "
569 fn main() {
570     let result = loop {
571         let $0var_name = 2 + 2;
572         break var_name;
573     };
574 }
575 ",
576         );
577     }
578
579     #[test]
580     fn test_extract_var_for_cast() {
581         check_assist(
582             extract_variable,
583             "
584 fn main() {
585     let v = $00f32 as u32$0;
586 }
587 ",
588             "
589 fn main() {
590     let $0var_name = 0f32 as u32;
591     let v = var_name;
592 }
593 ",
594         );
595     }
596
597     #[test]
598     fn extract_var_field_shorthand() {
599         check_assist(
600             extract_variable,
601             r#"
602 struct S {
603     foo: i32
604 }
605
606 fn main() {
607     S { foo: $01 + 1$0 }
608 }
609 "#,
610             r#"
611 struct S {
612     foo: i32
613 }
614
615 fn main() {
616     let $0foo = 1 + 1;
617     S { foo }
618 }
619 "#,
620         )
621     }
622
623     #[test]
624     fn extract_var_name_from_type() {
625         check_assist(
626             extract_variable,
627             r#"
628 struct Test(i32);
629
630 fn foo() -> Test {
631     $0{ Test(10) }$0
632 }
633 "#,
634             r#"
635 struct Test(i32);
636
637 fn foo() -> Test {
638     let $0test = { Test(10) };
639     test
640 }
641 "#,
642         )
643     }
644
645     #[test]
646     fn extract_var_name_from_parameter() {
647         check_assist(
648             extract_variable,
649             r#"
650 fn bar(test: u32, size: u32)
651
652 fn foo() {
653     bar(1, $01+1$0);
654 }
655 "#,
656             r#"
657 fn bar(test: u32, size: u32)
658
659 fn foo() {
660     let $0size = 1+1;
661     bar(1, size);
662 }
663 "#,
664         )
665     }
666
667     #[test]
668     fn extract_var_parameter_name_has_precedence_over_type() {
669         check_assist(
670             extract_variable,
671             r#"
672 struct TextSize(u32);
673 fn bar(test: u32, size: TextSize)
674
675 fn foo() {
676     bar(1, $0{ TextSize(1+1) }$0);
677 }
678 "#,
679             r#"
680 struct TextSize(u32);
681 fn bar(test: u32, size: TextSize)
682
683 fn foo() {
684     let $0size = { TextSize(1+1) };
685     bar(1, size);
686 }
687 "#,
688         )
689     }
690
691     #[test]
692     fn extract_var_name_from_function() {
693         check_assist(
694             extract_variable,
695             r#"
696 fn is_required(test: u32, size: u32) -> bool
697
698 fn foo() -> bool {
699     $0is_required(1, 2)$0
700 }
701 "#,
702             r#"
703 fn is_required(test: u32, size: u32) -> bool
704
705 fn foo() -> bool {
706     let $0is_required = is_required(1, 2);
707     is_required
708 }
709 "#,
710         )
711     }
712
713     #[test]
714     fn extract_var_name_from_method() {
715         check_assist(
716             extract_variable,
717             r#"
718 struct S;
719 impl S {
720     fn bar(&self, n: u32) -> u32 { n }
721 }
722
723 fn foo() -> u32 {
724     $0S.bar(1)$0
725 }
726 "#,
727             r#"
728 struct S;
729 impl S {
730     fn bar(&self, n: u32) -> u32 { n }
731 }
732
733 fn foo() -> u32 {
734     let $0bar = S.bar(1);
735     bar
736 }
737 "#,
738         )
739     }
740
741     #[test]
742     fn extract_var_name_from_method_param() {
743         check_assist(
744             extract_variable,
745             r#"
746 struct S;
747 impl S {
748     fn bar(&self, n: u32, size: u32) { n }
749 }
750
751 fn foo() {
752     S.bar($01 + 1$0, 2)
753 }
754 "#,
755             r#"
756 struct S;
757 impl S {
758     fn bar(&self, n: u32, size: u32) { n }
759 }
760
761 fn foo() {
762     let $0n = 1 + 1;
763     S.bar(n, 2)
764 }
765 "#,
766         )
767     }
768
769     #[test]
770     fn extract_var_name_from_ufcs_method_param() {
771         check_assist(
772             extract_variable,
773             r#"
774 struct S;
775 impl S {
776     fn bar(&self, n: u32, size: u32) { n }
777 }
778
779 fn foo() {
780     S::bar(&S, $01 + 1$0, 2)
781 }
782 "#,
783             r#"
784 struct S;
785 impl S {
786     fn bar(&self, n: u32, size: u32) { n }
787 }
788
789 fn foo() {
790     let $0n = 1 + 1;
791     S::bar(&S, n, 2)
792 }
793 "#,
794         )
795     }
796
797     #[test]
798     fn extract_var_parameter_name_has_precedence_over_function() {
799         check_assist(
800             extract_variable,
801             r#"
802 fn bar(test: u32, size: u32)
803
804 fn foo() {
805     bar(1, $0symbol_size(1, 2)$0);
806 }
807 "#,
808             r#"
809 fn bar(test: u32, size: u32)
810
811 fn foo() {
812     let $0size = symbol_size(1, 2);
813     bar(1, size);
814 }
815 "#,
816         )
817     }
818
819     #[test]
820     fn test_extract_var_for_return_not_applicable() {
821         check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } ");
822     }
823
824     #[test]
825     fn test_extract_var_for_break_not_applicable() {
826         check_assist_not_applicable(extract_variable, "fn main() { loop { $0break$0; }; }");
827     }
828
829     #[test]
830     fn test_extract_var_unit_expr_not_applicable() {
831         check_assist_not_applicable(
832             extract_variable,
833             r#"
834 fn foo() {
835     let mut i = 3;
836     $0if i >= 0 {
837         i += 1;
838     } else {
839         i -= 1;
840     }$0
841 }"#,
842         );
843     }
844
845     // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic
846     #[test]
847     fn extract_var_target() {
848         check_assist_target(extract_variable, "fn foo() -> u32 { $0return 2 + 2$0; }", "2 + 2");
849
850         check_assist_target(
851             extract_variable,
852             "
853 fn main() {
854     let x = true;
855     let tuple = match x {
856         true => ($02 + 2$0, true)
857         _ => (0, false)
858     };
859 }
860 ",
861             "2 + 2",
862         );
863     }
864
865     #[test]
866     fn extract_var_no_block_body() {
867         check_assist_not_applicable(
868             extract_variable,
869             r"
870 const X: usize = $0100$0;
871 ",
872         );
873     }
874 }