]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_variable.rs
Merge #10440
[rust.git] / crates / ide_assists / src / handlers / extract_variable.rs
1 use stdx::format_to;
2 use syntax::{
3     ast::{self, AstNode},
4     NodeOrToken,
5     SyntaxKind::{
6         BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD,
7         PATH_EXPR, RETURN_EXPR,
8     },
9     SyntaxNode,
10 };
11
12 use crate::{utils::suggest_name, AssistContext, AssistId, AssistKind, Assists};
13
14 // Assist: extract_variable
15 //
16 // Extracts subexpression into a variable.
17 //
18 // ```
19 // fn main() {
20 //     $0(1 + 2)$0 * 4;
21 // }
22 // ```
23 // ->
24 // ```
25 // fn main() {
26 //     let $0var_name = (1 + 2);
27 //     var_name * 4;
28 // }
29 // ```
30 pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
31     if ctx.frange.range.is_empty() {
32         return None;
33     }
34     let node = match ctx.covering_element() {
35         NodeOrToken::Node(it) => it,
36         NodeOrToken::Token(it) if it.kind() == COMMENT => {
37             cov_mark::hit!(extract_var_in_comment_is_not_applicable);
38             return None;
39         }
40         NodeOrToken::Token(it) => it.parent()?,
41     };
42     let node = node.ancestors().take_while(|anc| anc.text_range() == node.text_range()).last()?;
43     let to_extract = node
44         .descendants()
45         .take_while(|it| ctx.frange.range.contains_range(it.text_range()))
46         .find_map(valid_target_expr)?;
47
48     if let Some(ty_info) = ctx.sema.type_of_expr(&to_extract) {
49         if ty_info.adjusted().is_unit() {
50             return None;
51         }
52     }
53
54     let anchor = Anchor::from(&to_extract)?;
55     let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
56     let target = to_extract.syntax().text_range();
57     acc.add(
58         AssistId("extract_variable", AssistKind::RefactorExtract),
59         "Extract into variable",
60         target,
61         move |edit| {
62             let field_shorthand =
63                 match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) {
64                     Some(field) => field.name_ref(),
65                     None => None,
66                 };
67
68             let mut buf = String::new();
69
70             let var_name = match &field_shorthand {
71                 Some(it) => it.to_string(),
72                 None => suggest_name::for_variable(&to_extract, &ctx.sema),
73             };
74             let expr_range = match &field_shorthand {
75                 Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()),
76                 None => to_extract.syntax().text_range(),
77             };
78
79             match anchor {
80                 Anchor::Before(_) | Anchor::Replace(_) => {
81                     format_to!(buf, "let {} = ", var_name)
82                 }
83                 Anchor::WrapInBlock(_) => format_to!(buf, "{{ let {} = ", var_name),
84             };
85             format_to!(buf, "{}", to_extract.syntax());
86
87             if let Anchor::Replace(stmt) = anchor {
88                 cov_mark::hit!(test_extract_var_expr_stmt);
89                 if stmt.semicolon_token().is_none() {
90                     buf.push(';');
91                 }
92                 match ctx.config.snippet_cap {
93                     Some(cap) => {
94                         let snip = buf
95                             .replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
96                         edit.replace_snippet(cap, expr_range, snip)
97                     }
98                     None => edit.replace(expr_range, buf),
99                 }
100                 return;
101             }
102
103             buf.push(';');
104
105             // We want to maintain the indent level,
106             // but we do not want to duplicate possible
107             // extra newlines in the indent block
108             let text = indent.text();
109             if text.starts_with('\n') {
110                 buf.push('\n');
111                 buf.push_str(text.trim_start_matches('\n'));
112             } else {
113                 buf.push_str(text);
114             }
115
116             edit.replace(expr_range, var_name.clone());
117             let offset = anchor.syntax().text_range().start();
118             match ctx.config.snippet_cap {
119                 Some(cap) => {
120                     let snip =
121                         buf.replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
122                     edit.insert_snippet(cap, offset, snip)
123                 }
124                 None => edit.insert(offset, buf),
125             }
126
127             if let Anchor::WrapInBlock(_) = anchor {
128                 edit.insert(anchor.syntax().text_range().end(), " }");
129             }
130         },
131     )
132 }
133
134 /// Check whether the node is a valid expression which can be extracted to a variable.
135 /// In general that's true for any expression, but in some cases that would produce invalid code.
136 fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
137     match node.kind() {
138         PATH_EXPR | LOOP_EXPR => None,
139         BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()),
140         RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
141         BLOCK_EXPR => {
142             ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from)
143         }
144         _ => ast::Expr::cast(node),
145     }
146 }
147
148 #[derive(Debug)]
149 enum Anchor {
150     Before(SyntaxNode),
151     Replace(ast::ExprStmt),
152     WrapInBlock(SyntaxNode),
153 }
154
155 impl Anchor {
156     fn from(to_extract: &ast::Expr) -> Option<Anchor> {
157         to_extract
158             .syntax()
159             .ancestors()
160             .take_while(|it| !ast::Item::can_cast(it.kind()) || ast::MacroCall::can_cast(it.kind()))
161             .find_map(|node| {
162                 if let Some(expr) =
163                     node.parent().and_then(ast::StmtList::cast).and_then(|it| it.tail_expr())
164                 {
165                     if expr.syntax() == &node {
166                         cov_mark::hit!(test_extract_var_last_expr);
167                         return Some(Anchor::Before(node));
168                     }
169                 }
170
171                 if let Some(parent) = node.parent() {
172                     if parent.kind() == CLOSURE_EXPR {
173                         cov_mark::hit!(test_extract_var_in_closure_no_block);
174                         return Some(Anchor::WrapInBlock(node));
175                     }
176                     if parent.kind() == MATCH_ARM {
177                         if node.kind() == MATCH_GUARD {
178                             cov_mark::hit!(test_extract_var_in_match_guard);
179                         } else {
180                             cov_mark::hit!(test_extract_var_in_match_arm_no_block);
181                             return Some(Anchor::WrapInBlock(node));
182                         }
183                     }
184                 }
185
186                 if let Some(stmt) = ast::Stmt::cast(node.clone()) {
187                     if let ast::Stmt::ExprStmt(stmt) = stmt {
188                         if stmt.expr().as_ref() == Some(to_extract) {
189                             return Some(Anchor::Replace(stmt));
190                         }
191                     }
192                     return Some(Anchor::Before(node));
193                 }
194                 None
195             })
196     }
197
198     fn syntax(&self) -> &SyntaxNode {
199         match self {
200             Anchor::Before(it) | Anchor::WrapInBlock(it) => it,
201             Anchor::Replace(stmt) => stmt.syntax(),
202         }
203     }
204 }
205
206 #[cfg(test)]
207 mod tests {
208     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
209
210     use super::*;
211
212     #[test]
213     fn test_extract_var_simple() {
214         check_assist(
215             extract_variable,
216             r#"
217 fn foo() {
218     foo($01 + 1$0);
219 }"#,
220             r#"
221 fn foo() {
222     let $0var_name = 1 + 1;
223     foo(var_name);
224 }"#,
225         );
226     }
227
228     #[test]
229     fn extract_var_in_comment_is_not_applicable() {
230         cov_mark::check!(extract_var_in_comment_is_not_applicable);
231         check_assist_not_applicable(extract_variable, "fn main() { 1 + /* $0comment$0 */ 1; }");
232     }
233
234     #[test]
235     fn test_extract_var_expr_stmt() {
236         cov_mark::check!(test_extract_var_expr_stmt);
237         check_assist(
238             extract_variable,
239             r#"
240 fn foo() {
241     $01 + 1$0;
242 }"#,
243             r#"
244 fn foo() {
245     let $0var_name = 1 + 1;
246 }"#,
247         );
248         check_assist(
249             extract_variable,
250             "
251 fn foo() {
252     $0{ let x = 0; x }$0
253     something_else();
254 }",
255             "
256 fn foo() {
257     let $0var_name = { let x = 0; x };
258     something_else();
259 }",
260         );
261     }
262
263     #[test]
264     fn test_extract_var_part_of_expr_stmt() {
265         check_assist(
266             extract_variable,
267             "
268 fn foo() {
269     $01$0 + 1;
270 }",
271             "
272 fn foo() {
273     let $0var_name = 1;
274     var_name + 1;
275 }",
276         );
277     }
278
279     #[test]
280     fn test_extract_var_last_expr() {
281         cov_mark::check!(test_extract_var_last_expr);
282         check_assist(
283             extract_variable,
284             r#"
285 fn foo() {
286     bar($01 + 1$0)
287 }
288 "#,
289             r#"
290 fn foo() {
291     let $0var_name = 1 + 1;
292     bar(var_name)
293 }
294 "#,
295         );
296         check_assist(
297             extract_variable,
298             r#"
299 fn foo() -> i32 {
300     $0bar(1 + 1)$0
301 }
302
303 fn bar(i: i32) -> i32 {
304     i
305 }
306 "#,
307             r#"
308 fn foo() -> i32 {
309     let $0bar = bar(1 + 1);
310     bar
311 }
312
313 fn bar(i: i32) -> i32 {
314     i
315 }
316 "#,
317         )
318     }
319
320     #[test]
321     fn test_extract_var_in_match_arm_no_block() {
322         cov_mark::check!(test_extract_var_in_match_arm_no_block);
323         check_assist(
324             extract_variable,
325             r#"
326 fn main() {
327     let x = true;
328     let tuple = match x {
329         true => ($02 + 2$0, true)
330         _ => (0, false)
331     };
332 }
333 "#,
334             r#"
335 fn main() {
336     let x = true;
337     let tuple = match x {
338         true => { let $0var_name = 2 + 2; (var_name, true) }
339         _ => (0, false)
340     };
341 }
342 "#,
343         );
344     }
345
346     #[test]
347     fn test_extract_var_in_match_arm_with_block() {
348         check_assist(
349             extract_variable,
350             r#"
351 fn main() {
352     let x = true;
353     let tuple = match x {
354         true => {
355             let y = 1;
356             ($02 + y$0, true)
357         }
358         _ => (0, false)
359     };
360 }
361 "#,
362             r#"
363 fn main() {
364     let x = true;
365     let tuple = match x {
366         true => {
367             let y = 1;
368             let $0var_name = 2 + y;
369             (var_name, true)
370         }
371         _ => (0, false)
372     };
373 }
374 "#,
375         );
376     }
377
378     #[test]
379     fn test_extract_var_in_match_guard() {
380         cov_mark::check!(test_extract_var_in_match_guard);
381         check_assist(
382             extract_variable,
383             r#"
384 fn main() {
385     match () {
386         () if $010 > 0$0 => 1
387         _ => 2
388     };
389 }
390 "#,
391             r#"
392 fn main() {
393     let $0var_name = 10 > 0;
394     match () {
395         () if var_name => 1
396         _ => 2
397     };
398 }
399 "#,
400         );
401     }
402
403     #[test]
404     fn test_extract_var_in_closure_no_block() {
405         cov_mark::check!(test_extract_var_in_closure_no_block);
406         check_assist(
407             extract_variable,
408             r#"
409 fn main() {
410     let lambda = |x: u32| $0x * 2$0;
411 }
412 "#,
413             r#"
414 fn main() {
415     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
416 }
417 "#,
418         );
419     }
420
421     #[test]
422     fn test_extract_var_in_closure_with_block() {
423         check_assist(
424             extract_variable,
425             r#"
426 fn main() {
427     let lambda = |x: u32| { $0x * 2$0 };
428 }
429 "#,
430             r#"
431 fn main() {
432     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
433 }
434 "#,
435         );
436     }
437
438     #[test]
439     fn test_extract_var_path_simple() {
440         check_assist(
441             extract_variable,
442             "
443 fn main() {
444     let o = $0Some(true)$0;
445 }
446 ",
447             "
448 fn main() {
449     let $0var_name = Some(true);
450     let o = var_name;
451 }
452 ",
453         );
454     }
455
456     #[test]
457     fn test_extract_var_path_method() {
458         check_assist(
459             extract_variable,
460             "
461 fn main() {
462     let v = $0bar.foo()$0;
463 }
464 ",
465             "
466 fn main() {
467     let $0foo = bar.foo();
468     let v = foo;
469 }
470 ",
471         );
472     }
473
474     #[test]
475     fn test_extract_var_return() {
476         check_assist(
477             extract_variable,
478             "
479 fn foo() -> u32 {
480     $0return 2 + 2$0;
481 }
482 ",
483             "
484 fn foo() -> u32 {
485     let $0var_name = 2 + 2;
486     return var_name;
487 }
488 ",
489         );
490     }
491
492     #[test]
493     fn test_extract_var_does_not_add_extra_whitespace() {
494         check_assist(
495             extract_variable,
496             "
497 fn foo() -> u32 {
498
499
500     $0return 2 + 2$0;
501 }
502 ",
503             "
504 fn foo() -> u32 {
505
506
507     let $0var_name = 2 + 2;
508     return var_name;
509 }
510 ",
511         );
512
513         check_assist(
514             extract_variable,
515             "
516 fn foo() -> u32 {
517
518         $0return 2 + 2$0;
519 }
520 ",
521             "
522 fn foo() -> u32 {
523
524         let $0var_name = 2 + 2;
525         return var_name;
526 }
527 ",
528         );
529
530         check_assist(
531             extract_variable,
532             "
533 fn foo() -> u32 {
534     let foo = 1;
535
536     // bar
537
538
539     $0return 2 + 2$0;
540 }
541 ",
542             "
543 fn foo() -> u32 {
544     let foo = 1;
545
546     // bar
547
548
549     let $0var_name = 2 + 2;
550     return var_name;
551 }
552 ",
553         );
554     }
555
556     #[test]
557     fn test_extract_var_break() {
558         check_assist(
559             extract_variable,
560             "
561 fn main() {
562     let result = loop {
563         $0break 2 + 2$0;
564     };
565 }
566 ",
567             "
568 fn main() {
569     let result = loop {
570         let $0var_name = 2 + 2;
571         break var_name;
572     };
573 }
574 ",
575         );
576     }
577
578     #[test]
579     fn test_extract_var_for_cast() {
580         check_assist(
581             extract_variable,
582             "
583 fn main() {
584     let v = $00f32 as u32$0;
585 }
586 ",
587             "
588 fn main() {
589     let $0var_name = 0f32 as u32;
590     let v = var_name;
591 }
592 ",
593         );
594     }
595
596     #[test]
597     fn extract_var_field_shorthand() {
598         check_assist(
599             extract_variable,
600             r#"
601 struct S {
602     foo: i32
603 }
604
605 fn main() {
606     S { foo: $01 + 1$0 }
607 }
608 "#,
609             r#"
610 struct S {
611     foo: i32
612 }
613
614 fn main() {
615     let $0foo = 1 + 1;
616     S { foo }
617 }
618 "#,
619         )
620     }
621
622     #[test]
623     fn extract_var_name_from_type() {
624         check_assist(
625             extract_variable,
626             r#"
627 struct Test(i32);
628
629 fn foo() -> Test {
630     $0{ Test(10) }$0
631 }
632 "#,
633             r#"
634 struct Test(i32);
635
636 fn foo() -> Test {
637     let $0test = { Test(10) };
638     test
639 }
640 "#,
641         )
642     }
643
644     #[test]
645     fn extract_var_name_from_parameter() {
646         check_assist(
647             extract_variable,
648             r#"
649 fn bar(test: u32, size: u32)
650
651 fn foo() {
652     bar(1, $01+1$0);
653 }
654 "#,
655             r#"
656 fn bar(test: u32, size: u32)
657
658 fn foo() {
659     let $0size = 1+1;
660     bar(1, size);
661 }
662 "#,
663         )
664     }
665
666     #[test]
667     fn extract_var_parameter_name_has_precedence_over_type() {
668         check_assist(
669             extract_variable,
670             r#"
671 struct TextSize(u32);
672 fn bar(test: u32, size: TextSize)
673
674 fn foo() {
675     bar(1, $0{ TextSize(1+1) }$0);
676 }
677 "#,
678             r#"
679 struct TextSize(u32);
680 fn bar(test: u32, size: TextSize)
681
682 fn foo() {
683     let $0size = { TextSize(1+1) };
684     bar(1, size);
685 }
686 "#,
687         )
688     }
689
690     #[test]
691     fn extract_var_name_from_function() {
692         check_assist(
693             extract_variable,
694             r#"
695 fn is_required(test: u32, size: u32) -> bool
696
697 fn foo() -> bool {
698     $0is_required(1, 2)$0
699 }
700 "#,
701             r#"
702 fn is_required(test: u32, size: u32) -> bool
703
704 fn foo() -> bool {
705     let $0is_required = is_required(1, 2);
706     is_required
707 }
708 "#,
709         )
710     }
711
712     #[test]
713     fn extract_var_name_from_method() {
714         check_assist(
715             extract_variable,
716             r#"
717 struct S;
718 impl S {
719     fn bar(&self, n: u32) -> u32 { n }
720 }
721
722 fn foo() -> u32 {
723     $0S.bar(1)$0
724 }
725 "#,
726             r#"
727 struct S;
728 impl S {
729     fn bar(&self, n: u32) -> u32 { n }
730 }
731
732 fn foo() -> u32 {
733     let $0bar = S.bar(1);
734     bar
735 }
736 "#,
737         )
738     }
739
740     #[test]
741     fn extract_var_name_from_method_param() {
742         check_assist(
743             extract_variable,
744             r#"
745 struct S;
746 impl S {
747     fn bar(&self, n: u32, size: u32) { n }
748 }
749
750 fn foo() {
751     S.bar($01 + 1$0, 2)
752 }
753 "#,
754             r#"
755 struct S;
756 impl S {
757     fn bar(&self, n: u32, size: u32) { n }
758 }
759
760 fn foo() {
761     let $0n = 1 + 1;
762     S.bar(n, 2)
763 }
764 "#,
765         )
766     }
767
768     #[test]
769     fn extract_var_name_from_ufcs_method_param() {
770         check_assist(
771             extract_variable,
772             r#"
773 struct S;
774 impl S {
775     fn bar(&self, n: u32, size: u32) { n }
776 }
777
778 fn foo() {
779     S::bar(&S, $01 + 1$0, 2)
780 }
781 "#,
782             r#"
783 struct S;
784 impl S {
785     fn bar(&self, n: u32, size: u32) { n }
786 }
787
788 fn foo() {
789     let $0n = 1 + 1;
790     S::bar(&S, n, 2)
791 }
792 "#,
793         )
794     }
795
796     #[test]
797     fn extract_var_parameter_name_has_precedence_over_function() {
798         check_assist(
799             extract_variable,
800             r#"
801 fn bar(test: u32, size: u32)
802
803 fn foo() {
804     bar(1, $0symbol_size(1, 2)$0);
805 }
806 "#,
807             r#"
808 fn bar(test: u32, size: u32)
809
810 fn foo() {
811     let $0size = symbol_size(1, 2);
812     bar(1, size);
813 }
814 "#,
815         )
816     }
817
818     #[test]
819     fn test_extract_var_for_return_not_applicable() {
820         check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } ");
821     }
822
823     #[test]
824     fn test_extract_var_for_break_not_applicable() {
825         check_assist_not_applicable(extract_variable, "fn main() { loop { $0break$0; }; }");
826     }
827
828     #[test]
829     fn test_extract_var_unit_expr_not_applicable() {
830         check_assist_not_applicable(
831             extract_variable,
832             r#"
833 fn foo() {
834     let mut i = 3;
835     $0if i >= 0 {
836         i += 1;
837     } else {
838         i -= 1;
839     }$0
840 }"#,
841         );
842     }
843
844     // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic
845     #[test]
846     fn extract_var_target() {
847         check_assist_target(extract_variable, "fn foo() -> u32 { $0return 2 + 2$0; }", "2 + 2");
848
849         check_assist_target(
850             extract_variable,
851             "
852 fn main() {
853     let x = true;
854     let tuple = match x {
855         true => ($02 + 2$0, true)
856         _ => (0, false)
857     };
858 }
859 ",
860             "2 + 2",
861         );
862     }
863
864     #[test]
865     fn extract_var_no_block_body() {
866         check_assist_not_applicable(
867             extract_variable,
868             r"
869 const X: usize = $0100$0;
870 ",
871         );
872     }
873 }