]> git.lizzy.rs Git - rust.git/blob - crates/ide_assists/src/handlers/extract_variable.rs
Merge #10211
[rust.git] / crates / ide_assists / src / handlers / extract_variable.rs
1 use stdx::format_to;
2 use syntax::{
3     ast::{self, AstNode},
4     SyntaxKind::{
5         BLOCK_EXPR, BREAK_EXPR, CLOSURE_EXPR, COMMENT, LOOP_EXPR, MATCH_ARM, MATCH_GUARD,
6         PATH_EXPR, RETURN_EXPR,
7     },
8     SyntaxNode,
9 };
10
11 use crate::{utils::suggest_name, AssistContext, AssistId, AssistKind, Assists};
12
13 // Assist: extract_variable
14 //
15 // Extracts subexpression into a variable.
16 //
17 // ```
18 // fn main() {
19 //     $0(1 + 2)$0 * 4;
20 // }
21 // ```
22 // ->
23 // ```
24 // fn main() {
25 //     let $0var_name = (1 + 2);
26 //     var_name * 4;
27 // }
28 // ```
29 pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
30     if ctx.frange.range.is_empty() {
31         return None;
32     }
33     let node = ctx.covering_element();
34     if node.kind() == COMMENT {
35         cov_mark::hit!(extract_var_in_comment_is_not_applicable);
36         return None;
37     }
38     let to_extract = node
39         .ancestors()
40         .take_while(|it| it.text_range().contains_range(ctx.frange.range))
41         .find_map(valid_target_expr)?;
42     if let Some(ty_info) = ctx.sema.type_of_expr(&to_extract) {
43         if ty_info.adjusted().is_unit() {
44             return None;
45         }
46     }
47     let anchor = Anchor::from(&to_extract)?;
48     let indent = anchor.syntax().prev_sibling_or_token()?.as_token()?.clone();
49     let target = to_extract.syntax().text_range();
50     acc.add(
51         AssistId("extract_variable", AssistKind::RefactorExtract),
52         "Extract into variable",
53         target,
54         move |edit| {
55             let field_shorthand =
56                 match to_extract.syntax().parent().and_then(ast::RecordExprField::cast) {
57                     Some(field) => field.name_ref(),
58                     None => None,
59                 };
60
61             let mut buf = String::new();
62
63             let var_name = match &field_shorthand {
64                 Some(it) => it.to_string(),
65                 None => suggest_name::for_variable(&to_extract, &ctx.sema),
66             };
67             let expr_range = match &field_shorthand {
68                 Some(it) => it.syntax().text_range().cover(to_extract.syntax().text_range()),
69                 None => to_extract.syntax().text_range(),
70             };
71
72             if let Anchor::WrapInBlock(_) = anchor {
73                 format_to!(buf, "{{ let {} = ", var_name);
74             } else {
75                 format_to!(buf, "let {} = ", var_name);
76             };
77             format_to!(buf, "{}", to_extract.syntax());
78
79             if let Anchor::Replace(stmt) = anchor {
80                 cov_mark::hit!(test_extract_var_expr_stmt);
81                 if stmt.semicolon_token().is_none() {
82                     buf.push(';');
83                 }
84                 match ctx.config.snippet_cap {
85                     Some(cap) => {
86                         let snip = buf
87                             .replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
88                         edit.replace_snippet(cap, expr_range, snip)
89                     }
90                     None => edit.replace(expr_range, buf),
91                 }
92                 return;
93             }
94
95             buf.push(';');
96
97             // We want to maintain the indent level,
98             // but we do not want to duplicate possible
99             // extra newlines in the indent block
100             let text = indent.text();
101             if text.starts_with('\n') {
102                 buf.push('\n');
103                 buf.push_str(text.trim_start_matches('\n'));
104             } else {
105                 buf.push_str(text);
106             }
107
108             edit.replace(expr_range, var_name.clone());
109             let offset = anchor.syntax().text_range().start();
110             match ctx.config.snippet_cap {
111                 Some(cap) => {
112                     let snip =
113                         buf.replace(&format!("let {}", var_name), &format!("let $0{}", var_name));
114                     edit.insert_snippet(cap, offset, snip)
115                 }
116                 None => edit.insert(offset, buf),
117             }
118
119             if let Anchor::WrapInBlock(_) = anchor {
120                 edit.insert(anchor.syntax().text_range().end(), " }");
121             }
122         },
123     )
124 }
125
126 /// Check whether the node is a valid expression which can be extracted to a variable.
127 /// In general that's true for any expression, but in some cases that would produce invalid code.
128 fn valid_target_expr(node: SyntaxNode) -> Option<ast::Expr> {
129     match node.kind() {
130         PATH_EXPR | LOOP_EXPR => None,
131         BREAK_EXPR => ast::BreakExpr::cast(node).and_then(|e| e.expr()),
132         RETURN_EXPR => ast::ReturnExpr::cast(node).and_then(|e| e.expr()),
133         BLOCK_EXPR => {
134             ast::BlockExpr::cast(node).filter(|it| it.is_standalone()).map(ast::Expr::from)
135         }
136         _ => ast::Expr::cast(node),
137     }
138 }
139
140 enum Anchor {
141     Before(SyntaxNode),
142     Replace(ast::ExprStmt),
143     WrapInBlock(SyntaxNode),
144 }
145
146 impl Anchor {
147     fn from(to_extract: &ast::Expr) -> Option<Anchor> {
148         to_extract.syntax().ancestors().take_while(|it| !ast::Item::can_cast(it.kind())).find_map(
149             |node| {
150                 if let Some(expr) =
151                     node.parent().and_then(ast::BlockExpr::cast).and_then(|it| it.tail_expr())
152                 {
153                     if expr.syntax() == &node {
154                         cov_mark::hit!(test_extract_var_last_expr);
155                         return Some(Anchor::Before(node));
156                     }
157                 }
158
159                 if let Some(parent) = node.parent() {
160                     if parent.kind() == CLOSURE_EXPR {
161                         cov_mark::hit!(test_extract_var_in_closure_no_block);
162                         return Some(Anchor::WrapInBlock(node));
163                     }
164                     if parent.kind() == MATCH_ARM {
165                         if node.kind() == MATCH_GUARD {
166                             cov_mark::hit!(test_extract_var_in_match_guard);
167                         } else {
168                             cov_mark::hit!(test_extract_var_in_match_arm_no_block);
169                             return Some(Anchor::WrapInBlock(node));
170                         }
171                     }
172                 }
173
174                 if let Some(stmt) = ast::Stmt::cast(node.clone()) {
175                     if let ast::Stmt::ExprStmt(stmt) = stmt {
176                         if stmt.expr().as_ref() == Some(to_extract) {
177                             return Some(Anchor::Replace(stmt));
178                         }
179                     }
180                     return Some(Anchor::Before(node));
181                 }
182                 None
183             },
184         )
185     }
186
187     fn syntax(&self) -> &SyntaxNode {
188         match self {
189             Anchor::Before(it) | Anchor::WrapInBlock(it) => it,
190             Anchor::Replace(stmt) => stmt.syntax(),
191         }
192     }
193 }
194
195 #[cfg(test)]
196 mod tests {
197     use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
198
199     use super::*;
200
201     #[test]
202     fn test_extract_var_simple() {
203         check_assist(
204             extract_variable,
205             r#"
206 fn foo() {
207     foo($01 + 1$0);
208 }"#,
209             r#"
210 fn foo() {
211     let $0var_name = 1 + 1;
212     foo(var_name);
213 }"#,
214         );
215     }
216
217     #[test]
218     fn extract_var_in_comment_is_not_applicable() {
219         cov_mark::check!(extract_var_in_comment_is_not_applicable);
220         check_assist_not_applicable(extract_variable, "fn main() { 1 + /* $0comment$0 */ 1; }");
221     }
222
223     #[test]
224     fn test_extract_var_expr_stmt() {
225         cov_mark::check!(test_extract_var_expr_stmt);
226         check_assist(
227             extract_variable,
228             r#"
229 fn foo() {
230     $01 + 1$0;
231 }"#,
232             r#"
233 fn foo() {
234     let $0var_name = 1 + 1;
235 }"#,
236         );
237         check_assist(
238             extract_variable,
239             "
240 fn foo() {
241     $0{ let x = 0; x }$0
242     something_else();
243 }",
244             "
245 fn foo() {
246     let $0var_name = { let x = 0; x };
247     something_else();
248 }",
249         );
250     }
251
252     #[test]
253     fn test_extract_var_part_of_expr_stmt() {
254         check_assist(
255             extract_variable,
256             "
257 fn foo() {
258     $01$0 + 1;
259 }",
260             "
261 fn foo() {
262     let $0var_name = 1;
263     var_name + 1;
264 }",
265         );
266     }
267
268     #[test]
269     fn test_extract_var_last_expr() {
270         cov_mark::check!(test_extract_var_last_expr);
271         check_assist(
272             extract_variable,
273             r#"
274 fn foo() {
275     bar($01 + 1$0)
276 }
277 "#,
278             r#"
279 fn foo() {
280     let $0var_name = 1 + 1;
281     bar(var_name)
282 }
283 "#,
284         );
285         check_assist(
286             extract_variable,
287             r#"
288 fn foo() -> i32 {
289     $0bar(1 + 1)$0
290 }
291
292 fn bar(i: i32) -> i32 {
293     i
294 }
295 "#,
296             r#"
297 fn foo() -> i32 {
298     let $0bar = bar(1 + 1);
299     bar
300 }
301
302 fn bar(i: i32) -> i32 {
303     i
304 }
305 "#,
306         )
307     }
308
309     #[test]
310     fn test_extract_var_in_match_arm_no_block() {
311         cov_mark::check!(test_extract_var_in_match_arm_no_block);
312         check_assist(
313             extract_variable,
314             r#"
315 fn main() {
316     let x = true;
317     let tuple = match x {
318         true => ($02 + 2$0, true)
319         _ => (0, false)
320     };
321 }
322 "#,
323             r#"
324 fn main() {
325     let x = true;
326     let tuple = match x {
327         true => { let $0var_name = 2 + 2; (var_name, true) }
328         _ => (0, false)
329     };
330 }
331 "#,
332         );
333     }
334
335     #[test]
336     fn test_extract_var_in_match_arm_with_block() {
337         check_assist(
338             extract_variable,
339             r#"
340 fn main() {
341     let x = true;
342     let tuple = match x {
343         true => {
344             let y = 1;
345             ($02 + y$0, true)
346         }
347         _ => (0, false)
348     };
349 }
350 "#,
351             r#"
352 fn main() {
353     let x = true;
354     let tuple = match x {
355         true => {
356             let y = 1;
357             let $0var_name = 2 + y;
358             (var_name, true)
359         }
360         _ => (0, false)
361     };
362 }
363 "#,
364         );
365     }
366
367     #[test]
368     fn test_extract_var_in_match_guard() {
369         cov_mark::check!(test_extract_var_in_match_guard);
370         check_assist(
371             extract_variable,
372             r#"
373 fn main() {
374     match () {
375         () if $010 > 0$0 => 1
376         _ => 2
377     };
378 }
379 "#,
380             r#"
381 fn main() {
382     let $0var_name = 10 > 0;
383     match () {
384         () if var_name => 1
385         _ => 2
386     };
387 }
388 "#,
389         );
390     }
391
392     #[test]
393     fn test_extract_var_in_closure_no_block() {
394         cov_mark::check!(test_extract_var_in_closure_no_block);
395         check_assist(
396             extract_variable,
397             r#"
398 fn main() {
399     let lambda = |x: u32| $0x * 2$0;
400 }
401 "#,
402             r#"
403 fn main() {
404     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
405 }
406 "#,
407         );
408     }
409
410     #[test]
411     fn test_extract_var_in_closure_with_block() {
412         check_assist(
413             extract_variable,
414             r#"
415 fn main() {
416     let lambda = |x: u32| { $0x * 2$0 };
417 }
418 "#,
419             r#"
420 fn main() {
421     let lambda = |x: u32| { let $0var_name = x * 2; var_name };
422 }
423 "#,
424         );
425     }
426
427     #[test]
428     fn test_extract_var_path_simple() {
429         check_assist(
430             extract_variable,
431             "
432 fn main() {
433     let o = $0Some(true)$0;
434 }
435 ",
436             "
437 fn main() {
438     let $0var_name = Some(true);
439     let o = var_name;
440 }
441 ",
442         );
443     }
444
445     #[test]
446     fn test_extract_var_path_method() {
447         check_assist(
448             extract_variable,
449             "
450 fn main() {
451     let v = $0bar.foo()$0;
452 }
453 ",
454             "
455 fn main() {
456     let $0foo = bar.foo();
457     let v = foo;
458 }
459 ",
460         );
461     }
462
463     #[test]
464     fn test_extract_var_return() {
465         check_assist(
466             extract_variable,
467             "
468 fn foo() -> u32 {
469     $0return 2 + 2$0;
470 }
471 ",
472             "
473 fn foo() -> u32 {
474     let $0var_name = 2 + 2;
475     return var_name;
476 }
477 ",
478         );
479     }
480
481     #[test]
482     fn test_extract_var_does_not_add_extra_whitespace() {
483         check_assist(
484             extract_variable,
485             "
486 fn foo() -> u32 {
487
488
489     $0return 2 + 2$0;
490 }
491 ",
492             "
493 fn foo() -> u32 {
494
495
496     let $0var_name = 2 + 2;
497     return var_name;
498 }
499 ",
500         );
501
502         check_assist(
503             extract_variable,
504             "
505 fn foo() -> u32 {
506
507         $0return 2 + 2$0;
508 }
509 ",
510             "
511 fn foo() -> u32 {
512
513         let $0var_name = 2 + 2;
514         return var_name;
515 }
516 ",
517         );
518
519         check_assist(
520             extract_variable,
521             "
522 fn foo() -> u32 {
523     let foo = 1;
524
525     // bar
526
527
528     $0return 2 + 2$0;
529 }
530 ",
531             "
532 fn foo() -> u32 {
533     let foo = 1;
534
535     // bar
536
537
538     let $0var_name = 2 + 2;
539     return var_name;
540 }
541 ",
542         );
543     }
544
545     #[test]
546     fn test_extract_var_break() {
547         check_assist(
548             extract_variable,
549             "
550 fn main() {
551     let result = loop {
552         $0break 2 + 2$0;
553     };
554 }
555 ",
556             "
557 fn main() {
558     let result = loop {
559         let $0var_name = 2 + 2;
560         break var_name;
561     };
562 }
563 ",
564         );
565     }
566
567     #[test]
568     fn test_extract_var_for_cast() {
569         check_assist(
570             extract_variable,
571             "
572 fn main() {
573     let v = $00f32 as u32$0;
574 }
575 ",
576             "
577 fn main() {
578     let $0var_name = 0f32 as u32;
579     let v = var_name;
580 }
581 ",
582         );
583     }
584
585     #[test]
586     fn extract_var_field_shorthand() {
587         check_assist(
588             extract_variable,
589             r#"
590 struct S {
591     foo: i32
592 }
593
594 fn main() {
595     S { foo: $01 + 1$0 }
596 }
597 "#,
598             r#"
599 struct S {
600     foo: i32
601 }
602
603 fn main() {
604     let $0foo = 1 + 1;
605     S { foo }
606 }
607 "#,
608         )
609     }
610
611     #[test]
612     fn extract_var_name_from_type() {
613         check_assist(
614             extract_variable,
615             r#"
616 struct Test(i32);
617
618 fn foo() -> Test {
619     $0{ Test(10) }$0
620 }
621 "#,
622             r#"
623 struct Test(i32);
624
625 fn foo() -> Test {
626     let $0test = { Test(10) };
627     test
628 }
629 "#,
630         )
631     }
632
633     #[test]
634     fn extract_var_name_from_parameter() {
635         check_assist(
636             extract_variable,
637             r#"
638 fn bar(test: u32, size: u32)
639
640 fn foo() {
641     bar(1, $01+1$0);
642 }
643 "#,
644             r#"
645 fn bar(test: u32, size: u32)
646
647 fn foo() {
648     let $0size = 1+1;
649     bar(1, size);
650 }
651 "#,
652         )
653     }
654
655     #[test]
656     fn extract_var_parameter_name_has_precedence_over_type() {
657         check_assist(
658             extract_variable,
659             r#"
660 struct TextSize(u32);
661 fn bar(test: u32, size: TextSize)
662
663 fn foo() {
664     bar(1, $0{ TextSize(1+1) }$0);
665 }
666 "#,
667             r#"
668 struct TextSize(u32);
669 fn bar(test: u32, size: TextSize)
670
671 fn foo() {
672     let $0size = { TextSize(1+1) };
673     bar(1, size);
674 }
675 "#,
676         )
677     }
678
679     #[test]
680     fn extract_var_name_from_function() {
681         check_assist(
682             extract_variable,
683             r#"
684 fn is_required(test: u32, size: u32) -> bool
685
686 fn foo() -> bool {
687     $0is_required(1, 2)$0
688 }
689 "#,
690             r#"
691 fn is_required(test: u32, size: u32) -> bool
692
693 fn foo() -> bool {
694     let $0is_required = is_required(1, 2);
695     is_required
696 }
697 "#,
698         )
699     }
700
701     #[test]
702     fn extract_var_name_from_method() {
703         check_assist(
704             extract_variable,
705             r#"
706 struct S;
707 impl S {
708     fn bar(&self, n: u32) -> u32 { n }
709 }
710
711 fn foo() -> u32 {
712     $0S.bar(1)$0
713 }
714 "#,
715             r#"
716 struct S;
717 impl S {
718     fn bar(&self, n: u32) -> u32 { n }
719 }
720
721 fn foo() -> u32 {
722     let $0bar = S.bar(1);
723     bar
724 }
725 "#,
726         )
727     }
728
729     #[test]
730     fn extract_var_name_from_method_param() {
731         check_assist(
732             extract_variable,
733             r#"
734 struct S;
735 impl S {
736     fn bar(&self, n: u32, size: u32) { n }
737 }
738
739 fn foo() {
740     S.bar($01 + 1$0, 2)
741 }
742 "#,
743             r#"
744 struct S;
745 impl S {
746     fn bar(&self, n: u32, size: u32) { n }
747 }
748
749 fn foo() {
750     let $0n = 1 + 1;
751     S.bar(n, 2)
752 }
753 "#,
754         )
755     }
756
757     #[test]
758     fn extract_var_name_from_ufcs_method_param() {
759         check_assist(
760             extract_variable,
761             r#"
762 struct S;
763 impl S {
764     fn bar(&self, n: u32, size: u32) { n }
765 }
766
767 fn foo() {
768     S::bar(&S, $01 + 1$0, 2)
769 }
770 "#,
771             r#"
772 struct S;
773 impl S {
774     fn bar(&self, n: u32, size: u32) { n }
775 }
776
777 fn foo() {
778     let $0n = 1 + 1;
779     S::bar(&S, n, 2)
780 }
781 "#,
782         )
783     }
784
785     #[test]
786     fn extract_var_parameter_name_has_precedence_over_function() {
787         check_assist(
788             extract_variable,
789             r#"
790 fn bar(test: u32, size: u32)
791
792 fn foo() {
793     bar(1, $0symbol_size(1, 2)$0);
794 }
795 "#,
796             r#"
797 fn bar(test: u32, size: u32)
798
799 fn foo() {
800     let $0size = symbol_size(1, 2);
801     bar(1, size);
802 }
803 "#,
804         )
805     }
806
807     #[test]
808     fn test_extract_var_for_return_not_applicable() {
809         check_assist_not_applicable(extract_variable, "fn foo() { $0return$0; } ");
810     }
811
812     #[test]
813     fn test_extract_var_for_break_not_applicable() {
814         check_assist_not_applicable(extract_variable, "fn main() { loop { $0break$0; }; }");
815     }
816
817     #[test]
818     fn test_extract_var_unit_expr_not_applicable() {
819         check_assist_not_applicable(
820             extract_variable,
821             r#"
822 fn foo() {
823     let mut i = 3;
824     $0if i >= 0 {
825         i += 1;
826     } else {
827         i -= 1;
828     }$0
829 }"#,
830         );
831     }
832
833     // FIXME: This is not quite correct, but good enough(tm) for the sorting heuristic
834     #[test]
835     fn extract_var_target() {
836         check_assist_target(extract_variable, "fn foo() -> u32 { $0return 2 + 2$0; }", "2 + 2");
837
838         check_assist_target(
839             extract_variable,
840             "
841 fn main() {
842     let x = true;
843     let tuple = match x {
844         true => ($02 + 2$0, true)
845         _ => (0, false)
846     };
847 }
848 ",
849             "2 + 2",
850         );
851     }
852
853     #[test]
854     fn extract_var_no_block_body() {
855         check_assist_not_applicable(
856             extract_variable,
857             r"
858 const X: usize = $0100$0;
859 ",
860         );
861     }
862 }