]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_variable.rs
Merge #9108
[rust.git] / crates / ide_assists / src / handlers / extract_variable.rs
1 use stdx::format_to;
2 use syntax::{
3     ast::{self, AstNode},
4     SyntaxKind::{
5         BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD,
6         PATH_EXPR, RETURN_EXPR,
7     },
8     SyntaxNode,
9 };
10
11 use crate::{utils::suggest_name, AssistContext, AssistId, AssistKind, Assists};
12
13 // Assist: extract_variable
14 //
15 // Extracts subexpression into a variable.
16 //
17 // ```
18 // fn main() {
19 //     $0(1 + 2)$0 * 4;
20 // }
21 // ```
22 // ->
23 // ```
24 // fn main() {
25 //     let $0var_name = (1 + 2);
26 //     var_name * 4;
27 // }
28 // ```
29 pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
30     if ctx.frange.range.is_empty() {
31         return None;
32     }
33     let node = ctx.covering_element();
34     if node.kind() == COMMENT {
35         cov_mark::hit!(extract_var_in_comment_is_not_applicable);
36         return None;
37     }
38     let to_extract = node.ancestors().find_map(valid_target_expr)?;
39     if let Some(ty) = ctx.sema.type_of_expr(&to_extract) {
40         if ty.is_unit() {
41             return None;
42         }
43     }
44     let anchor = Anchor::from(&to_extract)?;
45     let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
46     let target = to_extract.syntax().text_range();
47     acc.add(
48         AssistId("extract_variable", AssistKind::RefactorExtract),
49         "Extract into variable",
50         target,
51         move |edit| {
52             let field_shorthand =
53                 match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) {
54                     Some(field) => field.name_ref(),
55                     None => None,
56                 };
57
58             let mut buf = String::new();
59
60             let var_name = match &field_shorthand {
61                 Some(it) => it.to_string(),
62                 None => suggest_name::for_variable(&to_extract, &ctx.sema),
63             };
64             let expr_range = match &field_shorthand {
65                 Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()),
66                 None => to_extract.syntax().text_range(),
67             };
68
69             if let Anchor::WrapInBlock(_) = anchor {
70                 format_to!(buf, "{{ let {} = ", var_name);
71             } else {
72                 format_to!(buf, "let {} = ", var_name);
73             };
74             format_to!(buf, "{}", to_extract.syntax());
75
76             if let Anchor::Replace(stmt) = anchor {
77                 cov_mark::hit!(test_extract_var_expr_stmt);
78                 if stmt.semicolon_token().is_none() {
79                     buf.push_str(";");
80                 }
81                 match ctx.config.snippet_cap {
82                     Some(cap) => {
83                         let snip = buf
84                             .replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
85                         edit.replace_snippet(cap, expr_range, snip)
86                     }
87                     None => edit.replace(expr_range, buf),
88                 }
89                 return;
90             }
91
92             buf.push_str(";");
93
94             // We want to maintain the indent level,
95             // but we do not want to duplicate possible
96             // extra newlines in the indent block
97             let text = indent.text();
98             if text.starts_with('\n') {
99                 buf.push('\n');
100                 buf.push_str(text.trim_start_matches('\n'));
101             } else {
102                 buf.push_str(text);
103             }
104
105             edit.replace(expr_range, var_name.clone());
106             let offset = anchor.syntax().text_range().start();
107             match ctx.config.snippet_cap {
108                 Some(cap) => {
109                     let snip =
110                         buf.replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
111                     edit.insert_snippet(cap, offset, snip)
112                 }
113                 None => edit.insert(offset, buf),
114             }
115
116             if let Anchor::WrapInBlock(_) = anchor {
117                 edit.insert(anchor.syntax().text_range().end(), " }");
118             }
119         },
120     )
121 }
122
123 /// Check whether the node is a valid expression which can be extracted to a variable.
124 /// In general that's true for any expression, but in some cases that would produce invalid code.
125 fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
126     match node.kind() {
127         PATH_EXPR | LOOP_EXPR => None,
128         BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()),
129         RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
130         BLOCK_EXPR => {
131             ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from)
132         }
133         _ => ast::Expr::cast(node),
134     }
135 }
136
137 enum Anchor {
138     Before(SyntaxNode),
139     Replace(ast::ExprStmt),
140     WrapInBlock(SyntaxNode),
141 }
142
143 impl Anchor {
144     fn from(to_extract: &ast::Expr) -> Option<Anchor> {
145         to_extract.syntax().ancestors().find_map(|node| {
146             if let Some(expr) =
147                 node.parent().and_then(ast::BlockExpr::cast).and_then(|it| it.tail_expr())
148             {
149                 if expr.syntax() == &node {
150                     cov_mark::hit!(test_extract_var_last_expr);
151                     return Some(Anchor::Before(node));
152                 }
153             }
154
155             if let Some(parent) = node.parent() {
156                 if parent.kind() == CLOSURE_EXPR {
157                     cov_mark::hit!(test_extract_var_in_closure_no_block);
158                     return Some(Anchor::WrapInBlock(node));
159                 }
160                 if parent.kind() == MATCH_ARM {
161                     if node.kind() == MATCH_GUARD {
162                         cov_mark::hit!(test_extract_var_in_match_guard);
163                     } else {
164                         cov_mark::hit!(test_extract_var_in_match_arm_no_block);
165                         return Some(Anchor::WrapInBlock(node));
166                     }
167                 }
168             }
169
170             if let Some(stmt) = ast::Stmt::cast(node.clone()) {
171                 if let ast::Stmt::ExprStmt(stmt) = stmt {
172                     if stmt.expr().as_ref() == Some(to_extract) {
173                         return Some(Anchor::Replace(stmt));
174                     }
175                 }
176                 return Some(Anchor::Before(node));
177             }
178             None
179         })
180     }
181
182     fn syntax(&self) -> &SyntaxNode {
183         match self {
184             Anchor::Before(it) | Anchor::WrapInBlock(it) => it,
185             Anchor::Replace(stmt) => stmt.syntax(),
186         }
187     }
188 }
189
190 #[cfg(test)]
191 mod tests {
192     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
193
194     use super::*;
195
196     #[test]
197     fn test_extract_var_simple() {
198         check_assist(
199             extract_variable,
200             r#"
201 fn foo() {
202     foo($01 + 1$0);
203 }"#,
204             r#"
205 fn foo() {
206     let $0var_name = 1 + 1;
207     foo(var_name);
208 }"#,
209         );
210     }
211
212     #[test]
213     fn extract_var_in_comment_is_not_applicable() {
214         cov_mark::check!(extract_var_in_comment_is_not_applicable);
215         check_assist_not_applicable(extract_variable, "fn main() { 1 + /* $0comment$0 */ 1; }");
216     }
217
218     #[test]
219     fn test_extract_var_expr_stmt() {
220         cov_mark::check!(test_extract_var_expr_stmt);
221         check_assist(
222             extract_variable,
223             r#"
224 fn foo() {
225     $01 + 1$0;
226 }"#,
227             r#"
228 fn foo() {
229     let $0var_name = 1 + 1;
230 }"#,
231         );
232         check_assist(
233             extract_variable,
234             "
235 fn foo() {
236     $0{ let x = 0; x }$0
237     something_else();
238 }",
239             "
240 fn foo() {
241     let $0var_name = { let x = 0; x };
242     something_else();
243 }",
244         );
245     }
246
247     #[test]
248     fn test_extract_var_part_of_expr_stmt() {
249         check_assist(
250             extract_variable,
251             "
252 fn foo() {
253     $01$0 + 1;
254 }",
255             "
256 fn foo() {
257     let $0var_name = 1;
258     var_name + 1;
259 }",
260         );
261     }
262
263     #[test]
264     fn test_extract_var_last_expr() {
265         cov_mark::check!(test_extract_var_last_expr);
266         check_assist(
267             extract_variable,
268             r#"
269 fn foo() {
270     bar($01 + 1$0)
271 }
272 "#,
273             r#"
274 fn foo() {
275     let $0var_name = 1 + 1;
276     bar(var_name)
277 }
278 "#,
279         );
280         check_assist(
281             extract_variable,
282             r#"
283 fn foo() -> i32 {
284     $0bar(1 + 1)$0
285 }
286
287 fn bar(i: i32) -> i32 {
288     i
289 }
290 "#,
291             r#"
292 fn foo() -> i32 {
293     let $0bar = bar(1 + 1);
294     bar
295 }
296
297 fn bar(i: i32) -> i32 {
298     i
299 }
300 "#,
301         )
302     }
303
304     #[test]
305     fn test_extract_var_in_match_arm_no_block() {
306         cov_mark::check!(test_extract_var_in_match_arm_no_block);
307         check_assist(
308             extract_variable,
309             r#"
310 fn main() {
311     let x = true;
312     let tuple = match x {
313         true => ($02 + 2$0, true)
314         _ => (0, false)
315     };
316 }
317 "#,
318             r#"
319 fn main() {
320     let x = true;
321     let tuple = match x {
322         true => { let $0var_name = 2 + 2; (var_name, true) }
323         _ => (0, false)
324     };
325 }
326 "#,
327         );
328     }
329
330     #[test]
331     fn test_extract_var_in_match_arm_with_block() {
332         check_assist(
333             extract_variable,
334             r#"
335 fn main() {
336     let x = true;
337     let tuple = match x {
338         true => {
339             let y = 1;
340             ($02 + y$0, true)
341         }
342         _ => (0, false)
343     };
344 }
345 "#,
346             r#"
347 fn main() {
348     let x = true;
349     let tuple = match x {
350         true => {
351             let y = 1;
352             let $0var_name = 2 + y;
353             (var_name, true)
354         }
355         _ => (0, false)
356     };
357 }
358 "#,
359         );
360     }
361
362     #[test]
363     fn test_extract_var_in_match_guard() {
364         cov_mark::check!(test_extract_var_in_match_guard);
365         check_assist(
366             extract_variable,
367             r#"
368 fn main() {
369     match () {
370         () if $010 > 0$0 => 1
371         _ => 2
372     };
373 }
374 "#,
375             r#"
376 fn main() {
377     let $0var_name = 10 > 0;
378     match () {
379         () if var_name => 1
380         _ => 2
381     };
382 }
383 "#,
384         );
385     }
386
387     #[test]
388     fn test_extract_var_in_closure_no_block() {
389         cov_mark::check!(test_extract_var_in_closure_no_block);
390         check_assist(
391             extract_variable,
392             r#"
393 fn main() {
394     let lambda = |x: u32| $0x * 2$0;
395 }
396 "#,
397             r#"
398 fn main() {
399     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
400 }
401 "#,
402         );
403     }
404
405     #[test]
406     fn test_extract_var_in_closure_with_block() {
407         check_assist(
408             extract_variable,
409             r#"
410 fn main() {
411     let lambda = |x: u32| { $0x * 2$0 };
412 }
413 "#,
414             r#"
415 fn main() {
416     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
417 }
418 "#,
419         );
420     }
421
422     #[test]
423     fn test_extract_var_path_simple() {
424         check_assist(
425             extract_variable,
426             "
427 fn main() {
428     let o = $0Some(true)$0;
429 }
430 ",
431             "
432 fn main() {
433     let $0var_name = Some(true);
434     let o = var_name;
435 }
436 ",
437         );
438     }
439
440     #[test]
441     fn test_extract_var_path_method() {
442         check_assist(
443             extract_variable,
444             "
445 fn main() {
446     let v = $0bar.foo()$0;
447 }
448 ",
449             "
450 fn main() {
451     let $0foo = bar.foo();
452     let v = foo;
453 }
454 ",
455         );
456     }
457
458     #[test]
459     fn test_extract_var_return() {
460         check_assist(
461             extract_variable,
462             "
463 fn foo() -> u32 {
464     $0return 2 + 2$0;
465 }
466 ",
467             "
468 fn foo() -> u32 {
469     let $0var_name = 2 + 2;
470     return var_name;
471 }
472 ",
473         );
474     }
475
476     #[test]
477     fn test_extract_var_does_not_add_extra_whitespace() {
478         check_assist(
479             extract_variable,
480             "
481 fn foo() -> u32 {
482
483
484     $0return 2 + 2$0;
485 }
486 ",
487             "
488 fn foo() -> u32 {
489
490
491     let $0var_name = 2 + 2;
492     return var_name;
493 }
494 ",
495         );
496
497         check_assist(
498             extract_variable,
499             "
500 fn foo() -> u32 {
501
502         $0return 2 + 2$0;
503 }
504 ",
505             "
506 fn foo() -> u32 {
507
508         let $0var_name = 2 + 2;
509         return var_name;
510 }
511 ",
512         );
513
514         check_assist(
515             extract_variable,
516             "
517 fn foo() -> u32 {
518     let foo = 1;
519
520     // bar
521
522
523     $0return 2 + 2$0;
524 }
525 ",
526             "
527 fn foo() -> u32 {
528     let foo = 1;
529
530     // bar
531
532
533     let $0var_name = 2 + 2;
534     return var_name;
535 }
536 ",
537         );
538     }
539
540     #[test]
541     fn test_extract_var_break() {
542         check_assist(
543             extract_variable,
544             "
545 fn main() {
546     let result = loop {
547         $0break 2 + 2$0;
548     };
549 }
550 ",
551             "
552 fn main() {
553     let result = loop {
554         let $0var_name = 2 + 2;
555         break var_name;
556     };
557 }
558 ",
559         );
560     }
561
562     #[test]
563     fn test_extract_var_for_cast() {
564         check_assist(
565             extract_variable,
566             "
567 fn main() {
568     let v = $00f32 as u32$0;
569 }
570 ",
571             "
572 fn main() {
573     let $0var_name = 0f32 as u32;
574     let v = var_name;
575 }
576 ",
577         );
578     }
579
580     #[test]
581     fn extract_var_field_shorthand() {
582         check_assist(
583             extract_variable,
584             r#"
585 struct S {
586     foo: i32
587 }
588
589 fn main() {
590     S { foo: $01 + 1$0 }
591 }
592 "#,
593             r#"
594 struct S {
595     foo: i32
596 }
597
598 fn main() {
599     let $0foo = 1 + 1;
600     S { foo }
601 }
602 "#,
603         )
604     }
605
606     #[test]
607     fn extract_var_name_from_type() {
608         check_assist(
609             extract_variable,
610             r#"
611 struct Test(i32);
612
613 fn foo() -> Test {
614     $0{ Test(10) }$0
615 }
616 "#,
617             r#"
618 struct Test(i32);
619
620 fn foo() -> Test {
621     let $0test = { Test(10) };
622     test
623 }
624 "#,
625         )
626     }
627
628     #[test]
629     fn extract_var_name_from_parameter() {
630         check_assist(
631             extract_variable,
632             r#"
633 fn bar(test: u32, size: u32)
634
635 fn foo() {
636     bar(1, $01+1$0);
637 }
638 "#,
639             r#"
640 fn bar(test: u32, size: u32)
641
642 fn foo() {
643     let $0size = 1+1;
644     bar(1, size);
645 }
646 "#,
647         )
648     }
649
650     #[test]
651     fn extract_var_parameter_name_has_precedence_over_type() {
652         check_assist(
653             extract_variable,
654             r#"
655 struct TextSize(u32);
656 fn bar(test: u32, size: TextSize)
657
658 fn foo() {
659     bar(1, $0{ TextSize(1+1) }$0);
660 }
661 "#,
662             r#"
663 struct TextSize(u32);
664 fn bar(test: u32, size: TextSize)
665
666 fn foo() {
667     let $0size = { TextSize(1+1) };
668     bar(1, size);
669 }
670 "#,
671         )
672     }
673
674     #[test]
675     fn extract_var_name_from_function() {
676         check_assist(
677             extract_variable,
678             r#"
679 fn is_required(test: u32, size: u32) -> bool
680
681 fn foo() -> bool {
682     $0is_required(1, 2)$0
683 }
684 "#,
685             r#"
686 fn is_required(test: u32, size: u32) -> bool
687
688 fn foo() -> bool {
689     let $0is_required = is_required(1, 2);
690     is_required
691 }
692 "#,
693         )
694     }
695
696     #[test]
697     fn extract_var_name_from_method() {
698         check_assist(
699             extract_variable,
700             r#"
701 struct S;
702 impl S {
703     fn bar(&self, n: u32) -> u32 { n }
704 }
705
706 fn foo() -> u32 {
707     $0S.bar(1)$0
708 }
709 "#,
710             r#"
711 struct S;
712 impl S {
713     fn bar(&self, n: u32) -> u32 { n }
714 }
715
716 fn foo() -> u32 {
717     let $0bar = S.bar(1);
718     bar
719 }
720 "#,
721         )
722     }
723
724     #[test]
725     fn extract_var_name_from_method_param() {
726         check_assist(
727             extract_variable,
728             r#"
729 struct S;
730 impl S {
731     fn bar(&self, n: u32, size: u32) { n }
732 }
733
734 fn foo() {
735     S.bar($01 + 1$0, 2)
736 }
737 "#,
738             r#"
739 struct S;
740 impl S {
741     fn bar(&self, n: u32, size: u32) { n }
742 }
743
744 fn foo() {
745     let $0n = 1 + 1;
746     S.bar(n, 2)
747 }
748 "#,
749         )
750     }
751
752     #[test]
753     fn extract_var_name_from_ufcs_method_param() {
754         check_assist(
755             extract_variable,
756             r#"
757 struct S;
758 impl S {
759     fn bar(&self, n: u32, size: u32) { n }
760 }
761
762 fn foo() {
763     S::bar(&S, $01 + 1$0, 2)
764 }
765 "#,
766             r#"
767 struct S;
768 impl S {
769     fn bar(&self, n: u32, size: u32) { n }
770 }
771
772 fn foo() {
773     let $0n = 1 + 1;
774     S::bar(&S, n, 2)
775 }
776 "#,
777         )
778     }
779
780     #[test]
781     fn extract_var_parameter_name_has_precedence_over_function() {
782         check_assist(
783             extract_variable,
784             r#"
785 fn bar(test: u32, size: u32)
786
787 fn foo() {
788     bar(1, $0symbol_size(1, 2)$0);
789 }
790 "#,
791             r#"
792 fn bar(test: u32, size: u32)
793
794 fn foo() {
795     let $0size = symbol_size(1, 2);
796     bar(1, size);
797 }
798 "#,
799         )
800     }
801
802     #[test]
803     fn test_extract_var_for_return_not_applicable() {
804         check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } ");
805     }
806
807     #[test]
808     fn test_extract_var_for_break_not_applicable() {
809         check_assist_not_applicable(extract_variable, "fn main() { loop { $0break$0; }; }");
810     }
811
812     #[test]
813     fn test_extract_var_unit_expr_not_applicable() {
814         check_assist_not_applicable(
815             extract_variable,
816             r#"
817 fn foo() {
818     let mut i = 3;
819     $0if i >= 0 {
820         i += 1;
821     } else {
822         i -= 1;
823     }$0
824 }"#,
825         );
826     }
827
828     // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic
829     #[test]
830     fn extract_var_target() {
831         check_assist_target(extract_variable, "fn foo() -> u32 { $0return 2 + 2$0; }", "2 + 2");
832
833         check_assist_target(
834             extract_variable,
835             "
836 fn main() {
837     let x = true;
838     let tuple = match x {
839         true => ($02 + 2$0, true)
840         _ => (0, false)
841     };
842 }
843 ",
844             "2 + 2",
845         );
846     }
847 }