]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_variable.rs
Merge #10181
[rust.git] / crates / ide_assists / src / handlers / extract_variable.rs
1 use stdx::format_to;
2 use syntax::{
3     ast::{self, AstNode},
4     SyntaxKind::{
5         BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD,
6         PATH_EXPR, RETURN_EXPR,
7     },
8     SyntaxNode,
9 };
10
11 use crate::{utils::suggest_name, AssistContext, AssistId, AssistKind, Assists};
12
13 // Assist: extract_variable
14 //
15 // Extracts subexpression into a variable.
16 //
17 // ```
18 // fn main() {
19 //     $0(1 + 2)$0 * 4;
20 // }
21 // ```
22 // ->
23 // ```
24 // fn main() {
25 //     let $0var_name = (1 + 2);
26 //     var_name * 4;
27 // }
28 // ```
29 pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
30     if ctx.frange.range.is_empty() {
31         return None;
32     }
33     let node = ctx.covering_element();
34     if node.kind() == COMMENT {
35         cov_mark::hit!(extract_var_in_comment_is_not_applicable);
36         return None;
37     }
38     let to_extract = node
39         .ancestors()
40         .take_while(|it| it.text_range().contains_range(ctx.frange.range))
41         .find_map(valid_target_expr)?;
42     if let Some(ty_info) = ctx.sema.type_of_expr(&to_extract) {
43         if ty_info.adjusted().is_unit() {
44             return None;
45         }
46     }
47     let anchor = Anchor::from(&to_extract)?;
48     let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
49     let target = to_extract.syntax().text_range();
50     acc.add(
51         AssistId("extract_variable", AssistKind::RefactorExtract),
52         "Extract into variable",
53         target,
54         move |edit| {
55             let field_shorthand =
56                 match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) {
57                     Some(field) => field.name_ref(),
58                     None => None,
59                 };
60
61             let mut buf = String::new();
62
63             let var_name = match &field_shorthand {
64                 Some(it) => it.to_string(),
65                 None => suggest_name::for_variable(&to_extract, &ctx.sema),
66             };
67             let expr_range = match &field_shorthand {
68                 Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()),
69                 None => to_extract.syntax().text_range(),
70             };
71
72             if let Anchor::WrapInBlock(_) = anchor {
73                 format_to!(buf, "{{ let {} = ", var_name);
74             } else {
75                 format_to!(buf, "let {} = ", var_name);
76             };
77             format_to!(buf, "{}", to_extract.syntax());
78
79             if let Anchor::Replace(stmt) = anchor {
80                 cov_mark::hit!(test_extract_var_expr_stmt);
81                 if stmt.semicolon_token().is_none() {
82                     buf.push(';');
83                 }
84                 match ctx.config.snippet_cap {
85                     Some(cap) => {
86                         let snip = buf
87                             .replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
88                         edit.replace_snippet(cap, expr_range, snip)
89                     }
90                     None => edit.replace(expr_range, buf),
91                 }
92                 return;
93             }
94
95             buf.push(';');
96
97             // We want to maintain the indent level,
98             // but we do not want to duplicate possible
99             // extra newlines in the indent block
100             let text = indent.text();
101             if text.starts_with('\n') {
102                 buf.push('\n');
103                 buf.push_str(text.trim_start_matches('\n'));
104             } else {
105                 buf.push_str(text);
106             }
107
108             edit.replace(expr_range, var_name.clone());
109             let offset = anchor.syntax().text_range().start();
110             match ctx.config.snippet_cap {
111                 Some(cap) => {
112                     let snip =
113                         buf.replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
114                     edit.insert_snippet(cap, offset, snip)
115                 }
116                 None => edit.insert(offset, buf),
117             }
118
119             if let Anchor::WrapInBlock(_) = anchor {
120                 edit.insert(anchor.syntax().text_range().end(), " }");
121             }
122         },
123     )
124 }
125
126 /// Check whether the node is a valid expression which can be extracted to a variable.
127 /// In general that's true for any expression, but in some cases that would produce invalid code.
128 fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
129     match node.kind() {
130         PATH_EXPR | LOOP_EXPR => None,
131         BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()),
132         RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
133         BLOCK_EXPR => {
134             ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from)
135         }
136         _ => ast::Expr::cast(node),
137     }
138 }
139
140 #[derive(Debug)]
141 enum Anchor {
142     Before(SyntaxNode),
143     Replace(ast::ExprStmt),
144     WrapInBlock(SyntaxNode),
145 }
146
147 impl Anchor {
148     fn from(to_extract: &ast::Expr) -> Option<Anchor> {
149         to_extract.syntax().ancestors().take_while(|it| !ast::Item::can_cast(it.kind())).find_map(
150             |node| {
151                 if let Some(expr) =
152                     node.parent().and_then(ast::StmtList::cast).and_then(|it| it.tail_expr())
153                 {
154                     if expr.syntax() == &node {
155                         cov_mark::hit!(test_extract_var_last_expr);
156                         return Some(Anchor::Before(node));
157                     }
158                 }
159
160                 if let Some(parent) = node.parent() {
161                     if parent.kind() == CLOSURE_EXPR {
162                         cov_mark::hit!(test_extract_var_in_closure_no_block);
163                         return Some(Anchor::WrapInBlock(node));
164                     }
165                     if parent.kind() == MATCH_ARM {
166                         if node.kind() == MATCH_GUARD {
167                             cov_mark::hit!(test_extract_var_in_match_guard);
168                         } else {
169                             cov_mark::hit!(test_extract_var_in_match_arm_no_block);
170                             return Some(Anchor::WrapInBlock(node));
171                         }
172                     }
173                 }
174
175                 if let Some(stmt) = ast::Stmt::cast(node.clone()) {
176                     if let ast::Stmt::ExprStmt(stmt) = stmt {
177                         if stmt.expr().as_ref() == Some(to_extract) {
178                             return Some(Anchor::Replace(stmt));
179                         }
180                     }
181                     return Some(Anchor::Before(node));
182                 }
183                 None
184             },
185         )
186     }
187
188     fn syntax(&self) -> &SyntaxNode {
189         match self {
190             Anchor::Before(it) | Anchor::WrapInBlock(it) => it,
191             Anchor::Replace(stmt) => stmt.syntax(),
192         }
193     }
194 }
195
196 #[cfg(test)]
197 mod tests {
198     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
199
200     use super::*;
201
202     #[test]
203     fn test_extract_var_simple() {
204         check_assist(
205             extract_variable,
206             r#"
207 fn foo() {
208     foo($01 + 1$0);
209 }"#,
210             r#"
211 fn foo() {
212     let $0var_name = 1 + 1;
213     foo(var_name);
214 }"#,
215         );
216     }
217
218     #[test]
219     fn extract_var_in_comment_is_not_applicable() {
220         cov_mark::check!(extract_var_in_comment_is_not_applicable);
221         check_assist_not_applicable(extract_variable, "fn main() { 1 + /* $0comment$0 */ 1; }");
222     }
223
224     #[test]
225     fn test_extract_var_expr_stmt() {
226         cov_mark::check!(test_extract_var_expr_stmt);
227         check_assist(
228             extract_variable,
229             r#"
230 fn foo() {
231     $01 + 1$0;
232 }"#,
233             r#"
234 fn foo() {
235     let $0var_name = 1 + 1;
236 }"#,
237         );
238         check_assist(
239             extract_variable,
240             "
241 fn foo() {
242     $0{ let x = 0; x }$0
243     something_else();
244 }",
245             "
246 fn foo() {
247     let $0var_name = { let x = 0; x };
248     something_else();
249 }",
250         );
251     }
252
253     #[test]
254     fn test_extract_var_part_of_expr_stmt() {
255         check_assist(
256             extract_variable,
257             "
258 fn foo() {
259     $01$0 + 1;
260 }",
261             "
262 fn foo() {
263     let $0var_name = 1;
264     var_name + 1;
265 }",
266         );
267     }
268
269     #[test]
270     fn test_extract_var_last_expr() {
271         cov_mark::check!(test_extract_var_last_expr);
272         check_assist(
273             extract_variable,
274             r#"
275 fn foo() {
276     bar($01 + 1$0)
277 }
278 "#,
279             r#"
280 fn foo() {
281     let $0var_name = 1 + 1;
282     bar(var_name)
283 }
284 "#,
285         );
286         check_assist(
287             extract_variable,
288             r#"
289 fn foo() -> i32 {
290     $0bar(1 + 1)$0
291 }
292
293 fn bar(i: i32) -> i32 {
294     i
295 }
296 "#,
297             r#"
298 fn foo() -> i32 {
299     let $0bar = bar(1 + 1);
300     bar
301 }
302
303 fn bar(i: i32) -> i32 {
304     i
305 }
306 "#,
307         )
308     }
309
310     #[test]
311     fn test_extract_var_in_match_arm_no_block() {
312         cov_mark::check!(test_extract_var_in_match_arm_no_block);
313         check_assist(
314             extract_variable,
315             r#"
316 fn main() {
317     let x = true;
318     let tuple = match x {
319         true => ($02 + 2$0, true)
320         _ => (0, false)
321     };
322 }
323 "#,
324             r#"
325 fn main() {
326     let x = true;
327     let tuple = match x {
328         true => { let $0var_name = 2 + 2; (var_name, true) }
329         _ => (0, false)
330     };
331 }
332 "#,
333         );
334     }
335
336     #[test]
337     fn test_extract_var_in_match_arm_with_block() {
338         check_assist(
339             extract_variable,
340             r#"
341 fn main() {
342     let x = true;
343     let tuple = match x {
344         true => {
345             let y = 1;
346             ($02 + y$0, true)
347         }
348         _ => (0, false)
349     };
350 }
351 "#,
352             r#"
353 fn main() {
354     let x = true;
355     let tuple = match x {
356         true => {
357             let y = 1;
358             let $0var_name = 2 + y;
359             (var_name, true)
360         }
361         _ => (0, false)
362     };
363 }
364 "#,
365         );
366     }
367
368     #[test]
369     fn test_extract_var_in_match_guard() {
370         cov_mark::check!(test_extract_var_in_match_guard);
371         check_assist(
372             extract_variable,
373             r#"
374 fn main() {
375     match () {
376         () if $010 > 0$0 => 1
377         _ => 2
378     };
379 }
380 "#,
381             r#"
382 fn main() {
383     let $0var_name = 10 > 0;
384     match () {
385         () if var_name => 1
386         _ => 2
387     };
388 }
389 "#,
390         );
391     }
392
393     #[test]
394     fn test_extract_var_in_closure_no_block() {
395         cov_mark::check!(test_extract_var_in_closure_no_block);
396         check_assist(
397             extract_variable,
398             r#"
399 fn main() {
400     let lambda = |x: u32| $0x * 2$0;
401 }
402 "#,
403             r#"
404 fn main() {
405     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
406 }
407 "#,
408         );
409     }
410
411     #[test]
412     fn test_extract_var_in_closure_with_block() {
413         check_assist(
414             extract_variable,
415             r#"
416 fn main() {
417     let lambda = |x: u32| { $0x * 2$0 };
418 }
419 "#,
420             r#"
421 fn main() {
422     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
423 }
424 "#,
425         );
426     }
427
428     #[test]
429     fn test_extract_var_path_simple() {
430         check_assist(
431             extract_variable,
432             "
433 fn main() {
434     let o = $0Some(true)$0;
435 }
436 ",
437             "
438 fn main() {
439     let $0var_name = Some(true);
440     let o = var_name;
441 }
442 ",
443         );
444     }
445
446     #[test]
447     fn test_extract_var_path_method() {
448         check_assist(
449             extract_variable,
450             "
451 fn main() {
452     let v = $0bar.foo()$0;
453 }
454 ",
455             "
456 fn main() {
457     let $0foo = bar.foo();
458     let v = foo;
459 }
460 ",
461         );
462     }
463
464     #[test]
465     fn test_extract_var_return() {
466         check_assist(
467             extract_variable,
468             "
469 fn foo() -> u32 {
470     $0return 2 + 2$0;
471 }
472 ",
473             "
474 fn foo() -> u32 {
475     let $0var_name = 2 + 2;
476     return var_name;
477 }
478 ",
479         );
480     }
481
482     #[test]
483     fn test_extract_var_does_not_add_extra_whitespace() {
484         check_assist(
485             extract_variable,
486             "
487 fn foo() -> u32 {
488
489
490     $0return 2 + 2$0;
491 }
492 ",
493             "
494 fn foo() -> u32 {
495
496
497     let $0var_name = 2 + 2;
498     return var_name;
499 }
500 ",
501         );
502
503         check_assist(
504             extract_variable,
505             "
506 fn foo() -> u32 {
507
508         $0return 2 + 2$0;
509 }
510 ",
511             "
512 fn foo() -> u32 {
513
514         let $0var_name = 2 + 2;
515         return var_name;
516 }
517 ",
518         );
519
520         check_assist(
521             extract_variable,
522             "
523 fn foo() -> u32 {
524     let foo = 1;
525
526     // bar
527
528
529     $0return 2 + 2$0;
530 }
531 ",
532             "
533 fn foo() -> u32 {
534     let foo = 1;
535
536     // bar
537
538
539     let $0var_name = 2 + 2;
540     return var_name;
541 }
542 ",
543         );
544     }
545
546     #[test]
547     fn test_extract_var_break() {
548         check_assist(
549             extract_variable,
550             "
551 fn main() {
552     let result = loop {
553         $0break 2 + 2$0;
554     };
555 }
556 ",
557             "
558 fn main() {
559     let result = loop {
560         let $0var_name = 2 + 2;
561         break var_name;
562     };
563 }
564 ",
565         );
566     }
567
568     #[test]
569     fn test_extract_var_for_cast() {
570         check_assist(
571             extract_variable,
572             "
573 fn main() {
574     let v = $00f32 as u32$0;
575 }
576 ",
577             "
578 fn main() {
579     let $0var_name = 0f32 as u32;
580     let v = var_name;
581 }
582 ",
583         );
584     }
585
586     #[test]
587     fn extract_var_field_shorthand() {
588         check_assist(
589             extract_variable,
590             r#"
591 struct S {
592     foo: i32
593 }
594
595 fn main() {
596     S { foo: $01 + 1$0 }
597 }
598 "#,
599             r#"
600 struct S {
601     foo: i32
602 }
603
604 fn main() {
605     let $0foo = 1 + 1;
606     S { foo }
607 }
608 "#,
609         )
610     }
611
612     #[test]
613     fn extract_var_name_from_type() {
614         check_assist(
615             extract_variable,
616             r#"
617 struct Test(i32);
618
619 fn foo() -> Test {
620     $0{ Test(10) }$0
621 }
622 "#,
623             r#"
624 struct Test(i32);
625
626 fn foo() -> Test {
627     let $0test = { Test(10) };
628     test
629 }
630 "#,
631         )
632     }
633
634     #[test]
635     fn extract_var_name_from_parameter() {
636         check_assist(
637             extract_variable,
638             r#"
639 fn bar(test: u32, size: u32)
640
641 fn foo() {
642     bar(1, $01+1$0);
643 }
644 "#,
645             r#"
646 fn bar(test: u32, size: u32)
647
648 fn foo() {
649     let $0size = 1+1;
650     bar(1, size);
651 }
652 "#,
653         )
654     }
655
656     #[test]
657     fn extract_var_parameter_name_has_precedence_over_type() {
658         check_assist(
659             extract_variable,
660             r#"
661 struct TextSize(u32);
662 fn bar(test: u32, size: TextSize)
663
664 fn foo() {
665     bar(1, $0{ TextSize(1+1) }$0);
666 }
667 "#,
668             r#"
669 struct TextSize(u32);
670 fn bar(test: u32, size: TextSize)
671
672 fn foo() {
673     let $0size = { TextSize(1+1) };
674     bar(1, size);
675 }
676 "#,
677         )
678     }
679
680     #[test]
681     fn extract_var_name_from_function() {
682         check_assist(
683             extract_variable,
684             r#"
685 fn is_required(test: u32, size: u32) -> bool
686
687 fn foo() -> bool {
688     $0is_required(1, 2)$0
689 }
690 "#,
691             r#"
692 fn is_required(test: u32, size: u32) -> bool
693
694 fn foo() -> bool {
695     let $0is_required = is_required(1, 2);
696     is_required
697 }
698 "#,
699         )
700     }
701
702     #[test]
703     fn extract_var_name_from_method() {
704         check_assist(
705             extract_variable,
706             r#"
707 struct S;
708 impl S {
709     fn bar(&self, n: u32) -> u32 { n }
710 }
711
712 fn foo() -> u32 {
713     $0S.bar(1)$0
714 }
715 "#,
716             r#"
717 struct S;
718 impl S {
719     fn bar(&self, n: u32) -> u32 { n }
720 }
721
722 fn foo() -> u32 {
723     let $0bar = S.bar(1);
724     bar
725 }
726 "#,
727         )
728     }
729
730     #[test]
731     fn extract_var_name_from_method_param() {
732         check_assist(
733             extract_variable,
734             r#"
735 struct S;
736 impl S {
737     fn bar(&self, n: u32, size: u32) { n }
738 }
739
740 fn foo() {
741     S.bar($01 + 1$0, 2)
742 }
743 "#,
744             r#"
745 struct S;
746 impl S {
747     fn bar(&self, n: u32, size: u32) { n }
748 }
749
750 fn foo() {
751     let $0n = 1 + 1;
752     S.bar(n, 2)
753 }
754 "#,
755         )
756     }
757
758     #[test]
759     fn extract_var_name_from_ufcs_method_param() {
760         check_assist(
761             extract_variable,
762             r#"
763 struct S;
764 impl S {
765     fn bar(&self, n: u32, size: u32) { n }
766 }
767
768 fn foo() {
769     S::bar(&S, $01 + 1$0, 2)
770 }
771 "#,
772             r#"
773 struct S;
774 impl S {
775     fn bar(&self, n: u32, size: u32) { n }
776 }
777
778 fn foo() {
779     let $0n = 1 + 1;
780     S::bar(&S, n, 2)
781 }
782 "#,
783         )
784     }
785
786     #[test]
787     fn extract_var_parameter_name_has_precedence_over_function() {
788         check_assist(
789             extract_variable,
790             r#"
791 fn bar(test: u32, size: u32)
792
793 fn foo() {
794     bar(1, $0symbol_size(1, 2)$0);
795 }
796 "#,
797             r#"
798 fn bar(test: u32, size: u32)
799
800 fn foo() {
801     let $0size = symbol_size(1, 2);
802     bar(1, size);
803 }
804 "#,
805         )
806     }
807
808     #[test]
809     fn test_extract_var_for_return_not_applicable() {
810         check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } ");
811     }
812
813     #[test]
814     fn test_extract_var_for_break_not_applicable() {
815         check_assist_not_applicable(extract_variable, "fn main() { loop { $0break$0; }; }");
816     }
817
818     #[test]
819     fn test_extract_var_unit_expr_not_applicable() {
820         check_assist_not_applicable(
821             extract_variable,
822             r#"
823 fn foo() {
824     let mut i = 3;
825     $0if i >= 0 {
826         i += 1;
827     } else {
828         i -= 1;
829     }$0
830 }"#,
831         );
832     }
833
834     // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic
835     #[test]
836     fn extract_var_target() {
837         check_assist_target(extract_variable, "fn foo() -> u32 { $0return 2 + 2$0; }", "2 + 2");
838
839         check_assist_target(
840             extract_variable,
841             "
842 fn main() {
843     let x = true;
844     let tuple = match x {
845         true => ($02 + 2$0, true)
846         _ => (0, false)
847     };
848 }
849 ",
850             "2 + 2",
851         );
852     }
853
854     #[test]
855     fn extract_var_no_block_body() {
856         check_assist_not_applicable(
857             extract_variable,
858             r"
859 const X: usize = $0100$0;
860 ",
861         );
862     }
863 }