]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_variable.rs
Merge #10437
[rust.git] / crates / ide_assists / src / handlers / extract_variable.rs
1 use stdx::format_to;
2 use syntax::{
3     ast::{self, AstNode},
4     NodeOrToken,
5     SyntaxKind::{
6         BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD,
7         PATH_EXPR, RETURN_EXPR,
8     },
9     SyntaxNode,
10 };
11
12 use crate::{utils::suggest_name, AssistContext, AssistId, AssistKind, Assists};
13
14 // Assist: extract_variable
15 //
16 // Extracts subexpression into a variable.
17 //
18 // ```
19 // fn main() {
20 //     $0(1 + 2)$0 * 4;
21 // }
22 // ```
23 // ->
24 // ```
25 // fn main() {
26 //     let $0var_name = (1 + 2);
27 //     var_name * 4;
28 // }
29 // ```
30 pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
31     if ctx.frange.range.is_empty() {
32         return None;
33     }
34     let node = match ctx.covering_element() {
35         NodeOrToken::Node(it) => it,
36         NodeOrToken::Token(it) if it.kind() == COMMENT => {
37             cov_mark::hit!(extract_var_in_comment_is_not_applicable);
38             return None;
39         }
40         NodeOrToken::Token(it) => it.parent()?,
41     };
42     let node = node.ancestors().take_while(|anc| anc.text_range() == node.text_range()).last()?;
43     let to_extract = node
44         .descendants()
45         .take_while(|it| ctx.frange.range.contains_range(it.text_range()))
46         .find_map(valid_target_expr)?;
47
48     if let Some(ty_info) = ctx.sema.type_of_expr(&to_extract) {
49         if ty_info.adjusted().is_unit() {
50             return None;
51         }
52     }
53
54     let anchor = Anchor::from(&to_extract)?;
55     let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
56     let target = to_extract.syntax().text_range();
57     acc.add(
58         AssistId("extract_variable", AssistKind::RefactorExtract),
59         "Extract into variable",
60         target,
61         move |edit| {
62             let field_shorthand =
63                 match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) {
64                     Some(field) => field.name_ref(),
65                     None => None,
66                 };
67
68             let mut buf = String::new();
69
70             let var_name = match &field_shorthand {
71                 Some(it) => it.to_string(),
72                 None => suggest_name::for_variable(&to_extract, &ctx.sema),
73             };
74             let expr_range = match &field_shorthand {
75                 Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()),
76                 None => to_extract.syntax().text_range(),
77             };
78
79             if let Anchor::WrapInBlock(_) = anchor {
80                 format_to!(buf, "{{ let {} = ", var_name);
81             } else {
82                 format_to!(buf, "let {} = ", var_name);
83             };
84             format_to!(buf, "{}", to_extract.syntax());
85
86             if let Anchor::Replace(stmt) = anchor {
87                 cov_mark::hit!(test_extract_var_expr_stmt);
88                 if stmt.semicolon_token().is_none() {
89                     buf.push(';');
90                 }
91                 match ctx.config.snippet_cap {
92                     Some(cap) => {
93                         let snip = buf
94                             .replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
95                         edit.replace_snippet(cap, expr_range, snip)
96                     }
97                     None => edit.replace(expr_range, buf),
98                 }
99                 return;
100             }
101
102             buf.push(';');
103
104             // We want to maintain the indent level,
105             // but we do not want to duplicate possible
106             // extra newlines in the indent block
107             let text = indent.text();
108             if text.starts_with('\n') {
109                 buf.push('\n');
110                 buf.push_str(text.trim_start_matches('\n'));
111             } else {
112                 buf.push_str(text);
113             }
114
115             edit.replace(expr_range, var_name.clone());
116             let offset = anchor.syntax().text_range().start();
117             match ctx.config.snippet_cap {
118                 Some(cap) => {
119                     let snip =
120                         buf.replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
121                     edit.insert_snippet(cap, offset, snip)
122                 }
123                 None => edit.insert(offset, buf),
124             }
125
126             if let Anchor::WrapInBlock(_) = anchor {
127                 edit.insert(anchor.syntax().text_range().end(), " }");
128             }
129         },
130     )
131 }
132
133 /// Check whether the node is a valid expression which can be extracted to a variable.
134 /// In general that's true for any expression, but in some cases that would produce invalid code.
135 fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
136     match node.kind() {
137         PATH_EXPR | LOOP_EXPR => None,
138         BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()),
139         RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
140         BLOCK_EXPR => {
141             ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from)
142         }
143         _ => ast::Expr::cast(node),
144     }
145 }
146
147 #[derive(Debug)]
148 enum Anchor {
149     Before(SyntaxNode),
150     Replace(ast::ExprStmt),
151     WrapInBlock(SyntaxNode),
152 }
153
154 impl Anchor {
155     fn from(to_extract: &ast::Expr) -> Option<Anchor> {
156         to_extract
157             .syntax()
158             .ancestors()
159             .take_while(|it| !ast::Item::can_cast(it.kind()) || ast::MacroCall::can_cast(it.kind()))
160             .find_map(|node| {
161                 if let Some(expr) =
162                     node.parent().and_then(ast::StmtList::cast).and_then(|it| it.tail_expr())
163                 {
164                     if expr.syntax() == &node {
165                         cov_mark::hit!(test_extract_var_last_expr);
166                         return Some(Anchor::Before(node));
167                     }
168                 }
169
170                 if let Some(parent) = node.parent() {
171                     if parent.kind() == CLOSURE_EXPR {
172                         cov_mark::hit!(test_extract_var_in_closure_no_block);
173                         return Some(Anchor::WrapInBlock(node));
174                     }
175                     if parent.kind() == MATCH_ARM {
176                         if node.kind() == MATCH_GUARD {
177                             cov_mark::hit!(test_extract_var_in_match_guard);
178                         } else {
179                             cov_mark::hit!(test_extract_var_in_match_arm_no_block);
180                             return Some(Anchor::WrapInBlock(node));
181                         }
182                     }
183                 }
184
185                 if let Some(stmt) = ast::Stmt::cast(node.clone()) {
186                     if let ast::Stmt::ExprStmt(stmt) = stmt {
187                         if stmt.expr().as_ref() == Some(to_extract) {
188                             return Some(Anchor::Replace(stmt));
189                         }
190                     }
191                     return Some(Anchor::Before(node));
192                 }
193                 None
194             })
195     }
196
197     fn syntax(&self) -> &SyntaxNode {
198         match self {
199             Anchor::Before(it) | Anchor::WrapInBlock(it) => it,
200             Anchor::Replace(stmt) => stmt.syntax(),
201         }
202     }
203 }
204
205 #[cfg(test)]
206 mod tests {
207     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
208
209     use super::*;
210
211     #[test]
212     fn test_extract_var_simple() {
213         check_assist(
214             extract_variable,
215             r#"
216 fn foo() {
217     foo($01 + 1$0);
218 }"#,
219             r#"
220 fn foo() {
221     let $0var_name = 1 + 1;
222     foo(var_name);
223 }"#,
224         );
225     }
226
227     #[test]
228     fn extract_var_in_comment_is_not_applicable() {
229         cov_mark::check!(extract_var_in_comment_is_not_applicable);
230         check_assist_not_applicable(extract_variable, "fn main() { 1 + /* $0comment$0 */ 1; }");
231     }
232
233     #[test]
234     fn test_extract_var_expr_stmt() {
235         cov_mark::check!(test_extract_var_expr_stmt);
236         check_assist(
237             extract_variable,
238             r#"
239 fn foo() {
240     $01 + 1$0;
241 }"#,
242             r#"
243 fn foo() {
244     let $0var_name = 1 + 1;
245 }"#,
246         );
247         check_assist(
248             extract_variable,
249             "
250 fn foo() {
251     $0{ let x = 0; x }$0
252     something_else();
253 }",
254             "
255 fn foo() {
256     let $0var_name = { let x = 0; x };
257     something_else();
258 }",
259         );
260     }
261
262     #[test]
263     fn test_extract_var_part_of_expr_stmt() {
264         check_assist(
265             extract_variable,
266             "
267 fn foo() {
268     $01$0 + 1;
269 }",
270             "
271 fn foo() {
272     let $0var_name = 1;
273     var_name + 1;
274 }",
275         );
276     }
277
278     #[test]
279     fn test_extract_var_last_expr() {
280         cov_mark::check!(test_extract_var_last_expr);
281         check_assist(
282             extract_variable,
283             r#"
284 fn foo() {
285     bar($01 + 1$0)
286 }
287 "#,
288             r#"
289 fn foo() {
290     let $0var_name = 1 + 1;
291     bar(var_name)
292 }
293 "#,
294         );
295         check_assist(
296             extract_variable,
297             r#"
298 fn foo() -> i32 {
299     $0bar(1 + 1)$0
300 }
301
302 fn bar(i: i32) -> i32 {
303     i
304 }
305 "#,
306             r#"
307 fn foo() -> i32 {
308     let $0bar = bar(1 + 1);
309     bar
310 }
311
312 fn bar(i: i32) -> i32 {
313     i
314 }
315 "#,
316         )
317     }
318
319     #[test]
320     fn test_extract_var_in_match_arm_no_block() {
321         cov_mark::check!(test_extract_var_in_match_arm_no_block);
322         check_assist(
323             extract_variable,
324             r#"
325 fn main() {
326     let x = true;
327     let tuple = match x {
328         true => ($02 + 2$0, true)
329         _ => (0, false)
330     };
331 }
332 "#,
333             r#"
334 fn main() {
335     let x = true;
336     let tuple = match x {
337         true => { let $0var_name = 2 + 2; (var_name, true) }
338         _ => (0, false)
339     };
340 }
341 "#,
342         );
343     }
344
345     #[test]
346     fn test_extract_var_in_match_arm_with_block() {
347         check_assist(
348             extract_variable,
349             r#"
350 fn main() {
351     let x = true;
352     let tuple = match x {
353         true => {
354             let y = 1;
355             ($02 + y$0, true)
356         }
357         _ => (0, false)
358     };
359 }
360 "#,
361             r#"
362 fn main() {
363     let x = true;
364     let tuple = match x {
365         true => {
366             let y = 1;
367             let $0var_name = 2 + y;
368             (var_name, true)
369         }
370         _ => (0, false)
371     };
372 }
373 "#,
374         );
375     }
376
377     #[test]
378     fn test_extract_var_in_match_guard() {
379         cov_mark::check!(test_extract_var_in_match_guard);
380         check_assist(
381             extract_variable,
382             r#"
383 fn main() {
384     match () {
385         () if $010 > 0$0 => 1
386         _ => 2
387     };
388 }
389 "#,
390             r#"
391 fn main() {
392     let $0var_name = 10 > 0;
393     match () {
394         () if var_name => 1
395         _ => 2
396     };
397 }
398 "#,
399         );
400     }
401
402     #[test]
403     fn test_extract_var_in_closure_no_block() {
404         cov_mark::check!(test_extract_var_in_closure_no_block);
405         check_assist(
406             extract_variable,
407             r#"
408 fn main() {
409     let lambda = |x: u32| $0x * 2$0;
410 }
411 "#,
412             r#"
413 fn main() {
414     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
415 }
416 "#,
417         );
418     }
419
420     #[test]
421     fn test_extract_var_in_closure_with_block() {
422         check_assist(
423             extract_variable,
424             r#"
425 fn main() {
426     let lambda = |x: u32| { $0x * 2$0 };
427 }
428 "#,
429             r#"
430 fn main() {
431     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
432 }
433 "#,
434         );
435     }
436
437     #[test]
438     fn test_extract_var_path_simple() {
439         check_assist(
440             extract_variable,
441             "
442 fn main() {
443     let o = $0Some(true)$0;
444 }
445 ",
446             "
447 fn main() {
448     let $0var_name = Some(true);
449     let o = var_name;
450 }
451 ",
452         );
453     }
454
455     #[test]
456     fn test_extract_var_path_method() {
457         check_assist(
458             extract_variable,
459             "
460 fn main() {
461     let v = $0bar.foo()$0;
462 }
463 ",
464             "
465 fn main() {
466     let $0foo = bar.foo();
467     let v = foo;
468 }
469 ",
470         );
471     }
472
473     #[test]
474     fn test_extract_var_return() {
475         check_assist(
476             extract_variable,
477             "
478 fn foo() -> u32 {
479     $0return 2 + 2$0;
480 }
481 ",
482             "
483 fn foo() -> u32 {
484     let $0var_name = 2 + 2;
485     return var_name;
486 }
487 ",
488         );
489     }
490
491     #[test]
492     fn test_extract_var_does_not_add_extra_whitespace() {
493         check_assist(
494             extract_variable,
495             "
496 fn foo() -> u32 {
497
498
499     $0return 2 + 2$0;
500 }
501 ",
502             "
503 fn foo() -> u32 {
504
505
506     let $0var_name = 2 + 2;
507     return var_name;
508 }
509 ",
510         );
511
512         check_assist(
513             extract_variable,
514             "
515 fn foo() -> u32 {
516
517         $0return 2 + 2$0;
518 }
519 ",
520             "
521 fn foo() -> u32 {
522
523         let $0var_name = 2 + 2;
524         return var_name;
525 }
526 ",
527         );
528
529         check_assist(
530             extract_variable,
531             "
532 fn foo() -> u32 {
533     let foo = 1;
534
535     // bar
536
537
538     $0return 2 + 2$0;
539 }
540 ",
541             "
542 fn foo() -> u32 {
543     let foo = 1;
544
545     // bar
546
547
548     let $0var_name = 2 + 2;
549     return var_name;
550 }
551 ",
552         );
553     }
554
555     #[test]
556     fn test_extract_var_break() {
557         check_assist(
558             extract_variable,
559             "
560 fn main() {
561     let result = loop {
562         $0break 2 + 2$0;
563     };
564 }
565 ",
566             "
567 fn main() {
568     let result = loop {
569         let $0var_name = 2 + 2;
570         break var_name;
571     };
572 }
573 ",
574         );
575     }
576
577     #[test]
578     fn test_extract_var_for_cast() {
579         check_assist(
580             extract_variable,
581             "
582 fn main() {
583     let v = $00f32 as u32$0;
584 }
585 ",
586             "
587 fn main() {
588     let $0var_name = 0f32 as u32;
589     let v = var_name;
590 }
591 ",
592         );
593     }
594
595     #[test]
596     fn extract_var_field_shorthand() {
597         check_assist(
598             extract_variable,
599             r#"
600 struct S {
601     foo: i32
602 }
603
604 fn main() {
605     S { foo: $01 + 1$0 }
606 }
607 "#,
608             r#"
609 struct S {
610     foo: i32
611 }
612
613 fn main() {
614     let $0foo = 1 + 1;
615     S { foo }
616 }
617 "#,
618         )
619     }
620
621     #[test]
622     fn extract_var_name_from_type() {
623         check_assist(
624             extract_variable,
625             r#"
626 struct Test(i32);
627
628 fn foo() -> Test {
629     $0{ Test(10) }$0
630 }
631 "#,
632             r#"
633 struct Test(i32);
634
635 fn foo() -> Test {
636     let $0test = { Test(10) };
637     test
638 }
639 "#,
640         )
641     }
642
643     #[test]
644     fn extract_var_name_from_parameter() {
645         check_assist(
646             extract_variable,
647             r#"
648 fn bar(test: u32, size: u32)
649
650 fn foo() {
651     bar(1, $01+1$0);
652 }
653 "#,
654             r#"
655 fn bar(test: u32, size: u32)
656
657 fn foo() {
658     let $0size = 1+1;
659     bar(1, size);
660 }
661 "#,
662         )
663     }
664
665     #[test]
666     fn extract_var_parameter_name_has_precedence_over_type() {
667         check_assist(
668             extract_variable,
669             r#"
670 struct TextSize(u32);
671 fn bar(test: u32, size: TextSize)
672
673 fn foo() {
674     bar(1, $0{ TextSize(1+1) }$0);
675 }
676 "#,
677             r#"
678 struct TextSize(u32);
679 fn bar(test: u32, size: TextSize)
680
681 fn foo() {
682     let $0size = { TextSize(1+1) };
683     bar(1, size);
684 }
685 "#,
686         )
687     }
688
689     #[test]
690     fn extract_var_name_from_function() {
691         check_assist(
692             extract_variable,
693             r#"
694 fn is_required(test: u32, size: u32) -> bool
695
696 fn foo() -> bool {
697     $0is_required(1, 2)$0
698 }
699 "#,
700             r#"
701 fn is_required(test: u32, size: u32) -> bool
702
703 fn foo() -> bool {
704     let $0is_required = is_required(1, 2);
705     is_required
706 }
707 "#,
708         )
709     }
710
711     #[test]
712     fn extract_var_name_from_method() {
713         check_assist(
714             extract_variable,
715             r#"
716 struct S;
717 impl S {
718     fn bar(&self, n: u32) -> u32 { n }
719 }
720
721 fn foo() -> u32 {
722     $0S.bar(1)$0
723 }
724 "#,
725             r#"
726 struct S;
727 impl S {
728     fn bar(&self, n: u32) -> u32 { n }
729 }
730
731 fn foo() -> u32 {
732     let $0bar = S.bar(1);
733     bar
734 }
735 "#,
736         )
737     }
738
739     #[test]
740     fn extract_var_name_from_method_param() {
741         check_assist(
742             extract_variable,
743             r#"
744 struct S;
745 impl S {
746     fn bar(&self, n: u32, size: u32) { n }
747 }
748
749 fn foo() {
750     S.bar($01 + 1$0, 2)
751 }
752 "#,
753             r#"
754 struct S;
755 impl S {
756     fn bar(&self, n: u32, size: u32) { n }
757 }
758
759 fn foo() {
760     let $0n = 1 + 1;
761     S.bar(n, 2)
762 }
763 "#,
764         )
765     }
766
767     #[test]
768     fn extract_var_name_from_ufcs_method_param() {
769         check_assist(
770             extract_variable,
771             r#"
772 struct S;
773 impl S {
774     fn bar(&self, n: u32, size: u32) { n }
775 }
776
777 fn foo() {
778     S::bar(&S, $01 + 1$0, 2)
779 }
780 "#,
781             r#"
782 struct S;
783 impl S {
784     fn bar(&self, n: u32, size: u32) { n }
785 }
786
787 fn foo() {
788     let $0n = 1 + 1;
789     S::bar(&S, n, 2)
790 }
791 "#,
792         )
793     }
794
795     #[test]
796     fn extract_var_parameter_name_has_precedence_over_function() {
797         check_assist(
798             extract_variable,
799             r#"
800 fn bar(test: u32, size: u32)
801
802 fn foo() {
803     bar(1, $0symbol_size(1, 2)$0);
804 }
805 "#,
806             r#"
807 fn bar(test: u32, size: u32)
808
809 fn foo() {
810     let $0size = symbol_size(1, 2);
811     bar(1, size);
812 }
813 "#,
814         )
815     }
816
817     #[test]
818     fn test_extract_var_for_return_not_applicable() {
819         check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } ");
820     }
821
822     #[test]
823     fn test_extract_var_for_break_not_applicable() {
824         check_assist_not_applicable(extract_variable, "fn main() { loop { $0break$0; }; }");
825     }
826
827     #[test]
828     fn test_extract_var_unit_expr_not_applicable() {
829         check_assist_not_applicable(
830             extract_variable,
831             r#"
832 fn foo() {
833     let mut i = 3;
834     $0if i >= 0 {
835         i += 1;
836     } else {
837         i -= 1;
838     }$0
839 }"#,
840         );
841     }
842
843     // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic
844     #[test]
845     fn extract_var_target() {
846         check_assist_target(extract_variable, "fn foo() -> u32 { $0return 2 + 2$0; }", "2 + 2");
847
848         check_assist_target(
849             extract_variable,
850             "
851 fn main() {
852     let x = true;
853     let tuple = match x {
854         true => ($02 + 2$0, true)
855         _ => (0, false)
856     };
857 }
858 ",
859             "2 + 2",
860         );
861     }
862
863     #[test]
864     fn extract_var_no_block_body() {
865         check_assist_not_applicable(
866             extract_variable,
867             r"
868 const X: usize = $0100$0;
869 ",
870         );
871     }
872 }