]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_variable.rs
correctly handle mutable references
[rust.git] / crates / ide_assists / src / handlers / extract_variable.rs
1 use stdx::format_to;
2 use syntax::{
3     ast::{self, AstNode},
4     NodeOrToken,
5     SyntaxKind::{
6         BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD,
7         PATH_EXPR, RETURN_EXPR,
8     },
9     SyntaxNode,
10 };
11
12 use crate::{utils::suggest_name, AssistContext, AssistId, AssistKind, Assists};
13
14 // Assist: extract_variable
15 //
16 // Extracts subexpression into a variable.
17 //
18 // ```
19 // fn main() {
20 //     $0(1 + 2)$0 * 4;
21 // }
22 // ```
23 // ->
24 // ```
25 // fn main() {
26 //     let $0var_name = (1 + 2);
27 //     var_name * 4;
28 // }
29 // ```
30 pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
31     if ctx.has_empty_selection() {
32         return None;
33     }
34
35     let node = match ctx.covering_element() {
36         NodeOrToken::Node(it) => it,
37         NodeOrToken::Token(it) if it.kind() == COMMENT => {
38             cov_mark::hit!(extract_var_in_comment_is_not_applicable);
39             return None;
40         }
41         NodeOrToken::Token(it) => it.parent()?,
42     };
43     let node = node.ancestors().take_while(|anc| anc.text_range() == node.text_range()).last()?;
44     let to_extract = node
45         .descendants()
46         .take_while(|it| ctx.selection_trimmed().contains_range(it.text_range()))
47         .find_map(valid_target_expr)?;
48
49     if let Some(ty_info) = ctx.sema.type_of_expr(&to_extract) {
50         if ty_info.adjusted().is_unit() {
51             return None;
52         }
53     }
54
55     let is_mutable_reference = if let Some(receiver_type) = get_receiver_type(&ctx, &to_extract) {
56         receiver_type.is_mutable_reference()
57     } else {
58         false
59     };
60
61     let anchor = Anchor::from(&to_extract)?;
62     let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
63     let target = to_extract.syntax().text_range();
64     acc.add(
65         AssistId("extract_variable", AssistKind::RefactorExtract),
66         "Extract into variable",
67         target,
68         move |edit| {
69             let field_shorthand =
70                 match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) {
71                     Some(field) => field.name_ref(),
72                     None => None,
73                 };
74
75             let mut buf = String::new();
76
77             let var_name = match &field_shorthand {
78                 Some(it) => it.to_string(),
79                 None => suggest_name::for_variable(&to_extract, &ctx.sema),
80             };
81             let expr_range = match &field_shorthand {
82                 Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()),
83                 None => to_extract.syntax().text_range(),
84             };
85
86             let reference_modifier = if is_mutable_reference { "&mut " } else { "" };
87
88             match anchor {
89                 Anchor::Before(_) | Anchor::Replace(_) => {
90                     format_to!(buf, "let {} = {}", var_name, reference_modifier)
91                 }
92                 Anchor::WrapInBlock(_) => {
93                     format_to!(buf, "{{ let {} = {}", var_name, reference_modifier)
94                 }
95             };
96             format_to!(buf, "{}", to_extract.syntax());
97
98             if let Anchor::Replace(stmt) = anchor {
99                 cov_mark::hit!(test_extract_var_expr_stmt);
100                 if stmt.semicolon_token().is_none() {
101                     buf.push(';');
102                 }
103                 match ctx.config.snippet_cap {
104                     Some(cap) => {
105                         let snip = buf
106                             .replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
107                         edit.replace_snippet(cap, expr_range, snip)
108                     }
109                     None => edit.replace(expr_range, buf),
110                 }
111                 return;
112             }
113
114             buf.push(';');
115
116             // We want to maintain the indent level,
117             // but we do not want to duplicate possible
118             // extra newlines in the indent block
119             let text = indent.text();
120             if text.starts_with('\n') {
121                 buf.push('\n');
122                 buf.push_str(text.trim_start_matches('\n'));
123             } else {
124                 buf.push_str(text);
125             }
126
127             edit.replace(expr_range, var_name.clone());
128             let offset = anchor.syntax().text_range().start();
129             match ctx.config.snippet_cap {
130                 Some(cap) => {
131                     let snip =
132                         buf.replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
133                     edit.insert_snippet(cap, offset, snip)
134                 }
135                 None => edit.insert(offset, buf),
136             }
137
138             if let Anchor::WrapInBlock(_) = anchor {
139                 edit.insert(anchor.syntax().text_range().end(), " }");
140             }
141         },
142     )
143 }
144
145 /// Check whether the node is a valid expression which can be extracted to a variable.
146 /// In general that's true for any expression, but in some cases that would produce invalid code.
147 fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
148     match node.kind() {
149         PATH_EXPR | LOOP_EXPR => None,
150         BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()),
151         RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
152         BLOCK_EXPR => {
153             ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from)
154         }
155         _ => ast::Expr::cast(node),
156     }
157 }
158
159 fn get_receiver_type(ctx: &AssistContext, expression: &ast::Expr) -> Option<hir::Type> {
160     let receiver = get_receiver(expression.to_owned())?;
161     Some(ctx.sema.type_of_expr(&receiver)?.original())
162 }
163
164 fn get_receiver(expression: ast::Expr) -> Option<ast::Expr> {
165     match expression {
166         ast::Expr::FieldExpr(field) if field.expr().is_some() => {
167             let nested_expression = &field.expr()?;
168             get_receiver(nested_expression.to_owned())
169         }
170         ast::Expr::PathExpr(_) => Some(expression),
171         _ => None,
172     }
173 }
174
175 #[derive(Debug)]
176 enum Anchor {
177     Before(SyntaxNode),
178     Replace(ast::ExprStmt),
179     WrapInBlock(SyntaxNode),
180 }
181
182 impl Anchor {
183     fn from(to_extract: &ast::Expr) -> Option<Anchor> {
184         to_extract
185             .syntax()
186             .ancestors()
187             .take_while(|it| !ast::Item::can_cast(it.kind()) || ast::MacroCall::can_cast(it.kind()))
188             .find_map(|node| {
189                 if ast::MacroCall::can_cast(node.kind()) {
190                     return None;
191                 }
192                 if let Some(expr) =
193                     node.parent().and_then(ast::StmtList::cast).and_then(|it| it.tail_expr())
194                 {
195                     if expr.syntax() == &node {
196                         cov_mark::hit!(test_extract_var_last_expr);
197                         return Some(Anchor::Before(node));
198                     }
199                 }
200
201                 if let Some(parent) = node.parent() {
202                     if parent.kind() == CLOSURE_EXPR {
203                         cov_mark::hit!(test_extract_var_in_closure_no_block);
204                         return Some(Anchor::WrapInBlock(node));
205                     }
206                     if parent.kind() == MATCH_ARM {
207                         if node.kind() == MATCH_GUARD {
208                             cov_mark::hit!(test_extract_var_in_match_guard);
209                         } else {
210                             cov_mark::hit!(test_extract_var_in_match_arm_no_block);
211                             return Some(Anchor::WrapInBlock(node));
212                         }
213                     }
214                 }
215
216                 if let Some(stmt) = ast::Stmt::cast(node.clone()) {
217                     if let ast::Stmt::ExprStmt(stmt) = stmt {
218                         if stmt.expr().as_ref() == Some(to_extract) {
219                             return Some(Anchor::Replace(stmt));
220                         }
221                     }
222                     return Some(Anchor::Before(node));
223                 }
224                 None
225             })
226     }
227
228     fn syntax(&self) -> &SyntaxNode {
229         match self {
230             Anchor::Before(it) | Anchor::WrapInBlock(it) => it,
231             Anchor::Replace(stmt) => stmt.syntax(),
232         }
233     }
234 }
235
236 #[cfg(test)]
237 mod tests {
238     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
239
240     use super::*;
241
242     #[test]
243     fn test_extract_var_simple() {
244         check_assist(
245             extract_variable,
246             r#"
247 fn foo() {
248     foo($01 + 1$0);
249 }"#,
250             r#"
251 fn foo() {
252     let $0var_name = 1 + 1;
253     foo(var_name);
254 }"#,
255         );
256     }
257
258     #[test]
259     fn extract_var_in_comment_is_not_applicable() {
260         cov_mark::check!(extract_var_in_comment_is_not_applicable);
261         check_assist_not_applicable(extract_variable, "fn main() { 1 + /* $0comment$0 */ 1; }");
262     }
263
264     #[test]
265     fn test_extract_var_expr_stmt() {
266         cov_mark::check!(test_extract_var_expr_stmt);
267         check_assist(
268             extract_variable,
269             r#"
270 fn foo() {
271   $0  1 + 1$0;
272 }"#,
273             r#"
274 fn foo() {
275     let $0var_name = 1 + 1;
276 }"#,
277         );
278         check_assist(
279             extract_variable,
280             r"
281 fn foo() {
282     $0{ let x = 0; x }$0
283     something_else();
284 }",
285             r"
286 fn foo() {
287     let $0var_name = { let x = 0; x };
288     something_else();
289 }",
290         );
291     }
292
293     #[test]
294     fn test_extract_var_part_of_expr_stmt() {
295         check_assist(
296             extract_variable,
297             r"
298 fn foo() {
299     $01$0 + 1;
300 }",
301             r"
302 fn foo() {
303     let $0var_name = 1;
304     var_name + 1;
305 }",
306         );
307     }
308
309     #[test]
310     fn test_extract_var_last_expr() {
311         cov_mark::check!(test_extract_var_last_expr);
312         check_assist(
313             extract_variable,
314             r#"
315 fn foo() {
316     bar($01 + 1$0)
317 }
318 "#,
319             r#"
320 fn foo() {
321     let $0var_name = 1 + 1;
322     bar(var_name)
323 }
324 "#,
325         );
326         check_assist(
327             extract_variable,
328             r#"
329 fn foo() -> i32 {
330     $0bar(1 + 1)$0
331 }
332
333 fn bar(i: i32) -> i32 {
334     i
335 }
336 "#,
337             r#"
338 fn foo() -> i32 {
339     let $0bar = bar(1 + 1);
340     bar
341 }
342
343 fn bar(i: i32) -> i32 {
344     i
345 }
346 "#,
347         )
348     }
349
350     #[test]
351     fn test_extract_var_in_match_arm_no_block() {
352         cov_mark::check!(test_extract_var_in_match_arm_no_block);
353         check_assist(
354             extract_variable,
355             r#"
356 fn main() {
357     let x = true;
358     let tuple = match x {
359         true => ($02 + 2$0, true)
360         _ => (0, false)
361     };
362 }
363 "#,
364             r#"
365 fn main() {
366     let x = true;
367     let tuple = match x {
368         true => { let $0var_name = 2 + 2; (var_name, true) }
369         _ => (0, false)
370     };
371 }
372 "#,
373         );
374     }
375
376     #[test]
377     fn test_extract_var_in_match_arm_with_block() {
378         check_assist(
379             extract_variable,
380             r#"
381 fn main() {
382     let x = true;
383     let tuple = match x {
384         true => {
385             let y = 1;
386             ($02 + y$0, true)
387         }
388         _ => (0, false)
389     };
390 }
391 "#,
392             r#"
393 fn main() {
394     let x = true;
395     let tuple = match x {
396         true => {
397             let y = 1;
398             let $0var_name = 2 + y;
399             (var_name, true)
400         }
401         _ => (0, false)
402     };
403 }
404 "#,
405         );
406     }
407
408     #[test]
409     fn test_extract_var_in_match_guard() {
410         cov_mark::check!(test_extract_var_in_match_guard);
411         check_assist(
412             extract_variable,
413             r#"
414 fn main() {
415     match () {
416         () if $010 > 0$0 => 1
417         _ => 2
418     };
419 }
420 "#,
421             r#"
422 fn main() {
423     let $0var_name = 10 > 0;
424     match () {
425         () if var_name => 1
426         _ => 2
427     };
428 }
429 "#,
430         );
431     }
432
433     #[test]
434     fn test_extract_var_in_closure_no_block() {
435         cov_mark::check!(test_extract_var_in_closure_no_block);
436         check_assist(
437             extract_variable,
438             r#"
439 fn main() {
440     let lambda = |x: u32| $0x * 2$0;
441 }
442 "#,
443             r#"
444 fn main() {
445     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
446 }
447 "#,
448         );
449     }
450
451     #[test]
452     fn test_extract_var_in_closure_with_block() {
453         check_assist(
454             extract_variable,
455             r#"
456 fn main() {
457     let lambda = |x: u32| { $0x * 2$0 };
458 }
459 "#,
460             r#"
461 fn main() {
462     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
463 }
464 "#,
465         );
466     }
467
468     #[test]
469     fn test_extract_var_path_simple() {
470         check_assist(
471             extract_variable,
472             "
473 fn main() {
474     let o = $0Some(true)$0;
475 }
476 ",
477             "
478 fn main() {
479     let $0var_name = Some(true);
480     let o = var_name;
481 }
482 ",
483         );
484     }
485
486     #[test]
487     fn test_extract_var_path_method() {
488         check_assist(
489             extract_variable,
490             "
491 fn main() {
492     let v = $0bar.foo()$0;
493 }
494 ",
495             "
496 fn main() {
497     let $0foo = bar.foo();
498     let v = foo;
499 }
500 ",
501         );
502     }
503
504     #[test]
505     fn test_extract_var_return() {
506         check_assist(
507             extract_variable,
508             "
509 fn foo() -> u32 {
510     $0return 2 + 2$0;
511 }
512 ",
513             "
514 fn foo() -> u32 {
515     let $0var_name = 2 + 2;
516     return var_name;
517 }
518 ",
519         );
520     }
521
522     #[test]
523     fn test_extract_var_does_not_add_extra_whitespace() {
524         check_assist(
525             extract_variable,
526             "
527 fn foo() -> u32 {
528
529
530     $0return 2 + 2$0;
531 }
532 ",
533             "
534 fn foo() -> u32 {
535
536
537     let $0var_name = 2 + 2;
538     return var_name;
539 }
540 ",
541         );
542
543         check_assist(
544             extract_variable,
545             "
546 fn foo() -> u32 {
547
548         $0return 2 + 2$0;
549 }
550 ",
551             "
552 fn foo() -> u32 {
553
554         let $0var_name = 2 + 2;
555         return var_name;
556 }
557 ",
558         );
559
560         check_assist(
561             extract_variable,
562             "
563 fn foo() -> u32 {
564     let foo = 1;
565
566     // bar
567
568
569     $0return 2 + 2$0;
570 }
571 ",
572             "
573 fn foo() -> u32 {
574     let foo = 1;
575
576     // bar
577
578
579     let $0var_name = 2 + 2;
580     return var_name;
581 }
582 ",
583         );
584     }
585
586     #[test]
587     fn test_extract_var_break() {
588         check_assist(
589             extract_variable,
590             "
591 fn main() {
592     let result = loop {
593         $0break 2 + 2$0;
594     };
595 }
596 ",
597             "
598 fn main() {
599     let result = loop {
600         let $0var_name = 2 + 2;
601         break var_name;
602     };
603 }
604 ",
605         );
606     }
607
608     #[test]
609     fn test_extract_var_for_cast() {
610         check_assist(
611             extract_variable,
612             "
613 fn main() {
614     let v = $00f32 as u32$0;
615 }
616 ",
617             "
618 fn main() {
619     let $0var_name = 0f32 as u32;
620     let v = var_name;
621 }
622 ",
623         );
624     }
625
626     #[test]
627     fn extract_var_field_shorthand() {
628         check_assist(
629             extract_variable,
630             r#"
631 struct S {
632     foo: i32
633 }
634
635 fn main() {
636     S { foo: $01 + 1$0 }
637 }
638 "#,
639             r#"
640 struct S {
641     foo: i32
642 }
643
644 fn main() {
645     let $0foo = 1 + 1;
646     S { foo }
647 }
648 "#,
649         )
650     }
651
652     #[test]
653     fn extract_var_name_from_type() {
654         check_assist(
655             extract_variable,
656             r#"
657 struct Test(i32);
658
659 fn foo() -> Test {
660     $0{ Test(10) }$0
661 }
662 "#,
663             r#"
664 struct Test(i32);
665
666 fn foo() -> Test {
667     let $0test = { Test(10) };
668     test
669 }
670 "#,
671         )
672     }
673
674     #[test]
675     fn extract_var_name_from_parameter() {
676         check_assist(
677             extract_variable,
678             r#"
679 fn bar(test: u32, size: u32)
680
681 fn foo() {
682     bar(1, $01+1$0);
683 }
684 "#,
685             r#"
686 fn bar(test: u32, size: u32)
687
688 fn foo() {
689     let $0size = 1+1;
690     bar(1, size);
691 }
692 "#,
693         )
694     }
695
696     #[test]
697     fn extract_var_parameter_name_has_precedence_over_type() {
698         check_assist(
699             extract_variable,
700             r#"
701 struct TextSize(u32);
702 fn bar(test: u32, size: TextSize)
703
704 fn foo() {
705     bar(1, $0{ TextSize(1+1) }$0);
706 }
707 "#,
708             r#"
709 struct TextSize(u32);
710 fn bar(test: u32, size: TextSize)
711
712 fn foo() {
713     let $0size = { TextSize(1+1) };
714     bar(1, size);
715 }
716 "#,
717         )
718     }
719
720     #[test]
721     fn extract_var_name_from_function() {
722         check_assist(
723             extract_variable,
724             r#"
725 fn is_required(test: u32, size: u32) -> bool
726
727 fn foo() -> bool {
728     $0is_required(1, 2)$0
729 }
730 "#,
731             r#"
732 fn is_required(test: u32, size: u32) -> bool
733
734 fn foo() -> bool {
735     let $0is_required = is_required(1, 2);
736     is_required
737 }
738 "#,
739         )
740     }
741
742     #[test]
743     fn extract_var_name_from_method() {
744         check_assist(
745             extract_variable,
746             r#"
747 struct S;
748 impl S {
749     fn bar(&self, n: u32) -> u32 { n }
750 }
751
752 fn foo() -> u32 {
753     $0S.bar(1)$0
754 }
755 "#,
756             r#"
757 struct S;
758 impl S {
759     fn bar(&self, n: u32) -> u32 { n }
760 }
761
762 fn foo() -> u32 {
763     let $0bar = S.bar(1);
764     bar
765 }
766 "#,
767         )
768     }
769
770     #[test]
771     fn extract_var_name_from_method_param() {
772         check_assist(
773             extract_variable,
774             r#"
775 struct S;
776 impl S {
777     fn bar(&self, n: u32, size: u32) { n }
778 }
779
780 fn foo() {
781     S.bar($01 + 1$0, 2)
782 }
783 "#,
784             r#"
785 struct S;
786 impl S {
787     fn bar(&self, n: u32, size: u32) { n }
788 }
789
790 fn foo() {
791     let $0n = 1 + 1;
792     S.bar(n, 2)
793 }
794 "#,
795         )
796     }
797
798     #[test]
799     fn extract_var_name_from_ufcs_method_param() {
800         check_assist(
801             extract_variable,
802             r#"
803 struct S;
804 impl S {
805     fn bar(&self, n: u32, size: u32) { n }
806 }
807
808 fn foo() {
809     S::bar(&S, $01 + 1$0, 2)
810 }
811 "#,
812             r#"
813 struct S;
814 impl S {
815     fn bar(&self, n: u32, size: u32) { n }
816 }
817
818 fn foo() {
819     let $0n = 1 + 1;
820     S::bar(&S, n, 2)
821 }
822 "#,
823         )
824     }
825
826     #[test]
827     fn extract_var_parameter_name_has_precedence_over_function() {
828         check_assist(
829             extract_variable,
830             r#"
831 fn bar(test: u32, size: u32)
832
833 fn foo() {
834     bar(1, $0symbol_size(1, 2)$0);
835 }
836 "#,
837             r#"
838 fn bar(test: u32, size: u32)
839
840 fn foo() {
841     let $0size = symbol_size(1, 2);
842     bar(1, size);
843 }
844 "#,
845         )
846     }
847
848     #[test]
849     fn extract_macro_call() {
850         check_assist(
851             extract_variable,
852             r"
853 struct Vec;
854 macro_rules! vec {
855     () => {Vec}
856 }
857 fn main() {
858     let _ = $0vec![]$0;
859 }
860 ",
861             r"
862 struct Vec;
863 macro_rules! vec {
864     () => {Vec}
865 }
866 fn main() {
867     let $0vec = vec![];
868     let _ = vec;
869 }
870 ",
871         );
872     }
873
874     #[test]
875     fn test_extract_var_for_return_not_applicable() {
876         check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } ");
877     }
878
879     #[test]
880     fn test_extract_var_for_break_not_applicable() {
881         check_assist_not_applicable(extract_variable, "fn main() { loop { $0break$0; }; }");
882     }
883
884     #[test]
885     fn test_extract_var_unit_expr_not_applicable() {
886         check_assist_not_applicable(
887             extract_variable,
888             r#"
889 fn foo() {
890     let mut i = 3;
891     $0if i >= 0 {
892         i += 1;
893     } else {
894         i -= 1;
895     }$0
896 }"#,
897         );
898     }
899
900     // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic
901     #[test]
902     fn extract_var_target() {
903         check_assist_target(extract_variable, "fn foo() -> u32 { $0return 2 + 2$0; }", "2 + 2");
904
905         check_assist_target(
906             extract_variable,
907             "
908 fn main() {
909     let x = true;
910     let tuple = match x {
911         true => ($02 + 2$0, true)
912         _ => (0, false)
913     };
914 }
915 ",
916             "2 + 2",
917         );
918     }
919
920     #[test]
921     fn extract_var_no_block_body() {
922         check_assist_not_applicable(
923             extract_variable,
924             r"
925 const X: usize = $0100$0;
926 ",
927         );
928     }
929
930     #[test]
931     fn test_extract_var_mutable_reference_parameter() {
932         check_assist(
933             extract_variable,
934             r#"
935 struct S {
936     vec: Vec<u8>
937 }
938
939 fn foo(s: &mut S) {
940     $0s.vec$0.push(0);
941 }"#,
942             r#"
943 struct S {
944     vec: Vec<u8>
945 }
946
947 fn foo(s: &mut S) {
948     let $0var_name = &mut s.vec;
949     var_name.push(0);
950 }"#,
951         );
952     }
953
954     #[test]
955     fn test_extract_var_mutable_reference_parameter_deep_nesting() {
956         check_assist(
957             extract_variable,
958             r#"
959 struct Y {
960     field: X
961 }
962 struct X {
963     field: S
964 }
965 struct S {
966     vec: Vec<u8>
967 }
968
969 fn foo(f: &mut Y) {
970     $0f.field.field.vec$0.push(0);
971 }"#,
972             r#"
973 struct Y {
974     field: X
975 }
976 struct X {
977     field: S
978 }
979 struct S {
980     vec: Vec<u8>
981 }
982
983 fn foo(f: &mut Y) {
984     let $0var_name = &mut f.field.field.vec;
985     var_name.push(0);
986 }"#,
987         );
988     }
989 }