]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/inline_local_variable.rs
Auto merge of #103913 - Neutron3529:patch-1, r=thomcc
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / inline_local_variable.rs
1 use either::Either;
2 use hir::{PathResolution, Semantics};
3 use ide_db::{
4     base_db::FileId,
5     defs::Definition,
6     search::{FileReference, UsageSearchResult},
7     RootDatabase,
8 };
9 use syntax::{
10     ast::{self, AstNode, AstToken, HasName},
11     SyntaxElement, TextRange,
12 };
13
14 use crate::{
15     assist_context::{AssistContext, Assists},
16     AssistId, AssistKind,
17 };
18
19 // Assist: inline_local_variable
20 //
21 // Inlines a local variable.
22 //
23 // ```
24 // fn main() {
25 //     let x$0 = 1 + 2;
26 //     x * 4;
27 // }
28 // ```
29 // ->
30 // ```
31 // fn main() {
32 //     (1 + 2) * 4;
33 // }
34 // ```
35 pub(crate) fn inline_local_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
36     let file_id = ctx.file_id();
37     let range = ctx.selection_trimmed();
38     let InlineData { let_stmt, delete_let, references, target } =
39         if let Some(path_expr) = ctx.find_node_at_offset::<ast::PathExpr>() {
40             inline_usage(&ctx.sema, path_expr, range, file_id)
41         } else if let Some(let_stmt) = ctx.find_node_at_offset::<ast::LetStmt>() {
42             inline_let(&ctx.sema, let_stmt, range, file_id)
43         } else {
44             None
45         }?;
46     let initializer_expr = let_stmt.initializer()?;
47
48     let delete_range = delete_let.then(|| {
49         if let Some(whitespace) = let_stmt
50             .syntax()
51             .next_sibling_or_token()
52             .and_then(SyntaxElement::into_token)
53             .and_then(ast::Whitespace::cast)
54         {
55             TextRange::new(
56                 let_stmt.syntax().text_range().start(),
57                 whitespace.syntax().text_range().end(),
58             )
59         } else {
60             let_stmt.syntax().text_range()
61         }
62     });
63
64     let wrap_in_parens = references
65         .into_iter()
66         .filter_map(|FileReference { range, name, .. }| match name {
67             ast::NameLike::NameRef(name) => Some((range, name)),
68             _ => None,
69         })
70         .map(|(range, name_ref)| {
71             if range != name_ref.syntax().text_range() {
72                 // Do not rename inside macros
73                 // FIXME: This feels like a bad heuristic for macros
74                 return None;
75             }
76             let usage_node =
77                 name_ref.syntax().ancestors().find(|it| ast::PathExpr::can_cast(it.kind()));
78             let usage_parent_option =
79                 usage_node.and_then(|it| it.parent()).and_then(ast::Expr::cast);
80             let usage_parent = match usage_parent_option {
81                 Some(u) => u,
82                 None => return Some((range, name_ref, false)),
83             };
84             let initializer = matches!(
85                 initializer_expr,
86                 ast::Expr::CallExpr(_)
87                     | ast::Expr::IndexExpr(_)
88                     | ast::Expr::MethodCallExpr(_)
89                     | ast::Expr::FieldExpr(_)
90                     | ast::Expr::TryExpr(_)
91                     | ast::Expr::Literal(_)
92                     | ast::Expr::TupleExpr(_)
93                     | ast::Expr::ArrayExpr(_)
94                     | ast::Expr::ParenExpr(_)
95                     | ast::Expr::PathExpr(_)
96                     | ast::Expr::BlockExpr(_),
97             );
98             let parent = matches!(
99                 usage_parent,
100                 ast::Expr::CallExpr(_)
101                     | ast::Expr::TupleExpr(_)
102                     | ast::Expr::ArrayExpr(_)
103                     | ast::Expr::ParenExpr(_)
104                     | ast::Expr::ForExpr(_)
105                     | ast::Expr::WhileExpr(_)
106                     | ast::Expr::BreakExpr(_)
107                     | ast::Expr::ReturnExpr(_)
108                     | ast::Expr::MatchExpr(_)
109                     | ast::Expr::BlockExpr(_)
110             );
111             Some((range, name_ref, !(initializer || parent)))
112         })
113         .collect::<Option<Vec<_>>>()?;
114
115     let init_str = initializer_expr.syntax().text().to_string();
116     let init_in_paren = format!("({init_str})");
117
118     let target = match target {
119         ast::NameOrNameRef::Name(it) => it.syntax().text_range(),
120         ast::NameOrNameRef::NameRef(it) => it.syntax().text_range(),
121     };
122
123     acc.add(
124         AssistId("inline_local_variable", AssistKind::RefactorInline),
125         "Inline variable",
126         target,
127         move |builder| {
128             if let Some(range) = delete_range {
129                 builder.delete(range);
130             }
131             for (range, name, should_wrap) in wrap_in_parens {
132                 let replacement = if should_wrap { &init_in_paren } else { &init_str };
133                 if ast::RecordExprField::for_field_name(&name).is_some() {
134                     cov_mark::hit!(inline_field_shorthand);
135                     builder.insert(range.end(), format!(": {replacement}"));
136                 } else {
137                     builder.replace(range, replacement.clone())
138                 }
139             }
140         },
141     )
142 }
143
144 struct InlineData {
145     let_stmt: ast::LetStmt,
146     delete_let: bool,
147     target: ast::NameOrNameRef,
148     references: Vec<FileReference>,
149 }
150
151 fn inline_let(
152     sema: &Semantics<'_, RootDatabase>,
153     let_stmt: ast::LetStmt,
154     range: TextRange,
155     file_id: FileId,
156 ) -> Option<InlineData> {
157     let bind_pat = match let_stmt.pat()? {
158         ast::Pat::IdentPat(pat) => pat,
159         _ => return None,
160     };
161     if bind_pat.mut_token().is_some() {
162         cov_mark::hit!(test_not_inline_mut_variable);
163         return None;
164     }
165     if !bind_pat.syntax().text_range().contains_range(range) {
166         cov_mark::hit!(not_applicable_outside_of_bind_pat);
167         return None;
168     }
169
170     let local = sema.to_def(&bind_pat)?;
171     let UsageSearchResult { mut references } = Definition::Local(local).usages(sema).all();
172     match references.remove(&file_id) {
173         Some(references) => Some(InlineData {
174             let_stmt,
175             delete_let: true,
176             target: ast::NameOrNameRef::Name(bind_pat.name()?),
177             references,
178         }),
179         None => {
180             cov_mark::hit!(test_not_applicable_if_variable_unused);
181             None
182         }
183     }
184 }
185
186 fn inline_usage(
187     sema: &Semantics<'_, RootDatabase>,
188     path_expr: ast::PathExpr,
189     range: TextRange,
190     file_id: FileId,
191 ) -> Option<InlineData> {
192     let path = path_expr.path()?;
193     let name = path.as_single_name_ref()?;
194     if !name.syntax().text_range().contains_range(range) {
195         cov_mark::hit!(test_not_inline_selection_too_broad);
196         return None;
197     }
198
199     let local = match sema.resolve_path(&path)? {
200         PathResolution::Local(local) => local,
201         _ => return None,
202     };
203     if local.is_mut(sema.db) {
204         cov_mark::hit!(test_not_inline_mut_variable_use);
205         return None;
206     }
207
208     // FIXME: Handle multiple local definitions
209     let bind_pat = match local.source(sema.db).value {
210         Either::Left(ident) => ident,
211         _ => return None,
212     };
213
214     let let_stmt = ast::LetStmt::cast(bind_pat.syntax().parent()?)?;
215
216     let UsageSearchResult { mut references } = Definition::Local(local).usages(sema).all();
217     let mut references = references.remove(&file_id)?;
218     let delete_let = references.len() == 1;
219     references.retain(|fref| fref.name.as_name_ref() == Some(&name));
220
221     Some(InlineData { let_stmt, delete_let, target: ast::NameOrNameRef::NameRef(name), references })
222 }
223
224 #[cfg(test)]
225 mod tests {
226     use crate::tests::{check_assist, check_assist_not_applicable};
227
228     use super::*;
229
230     #[test]
231     fn test_inline_let_bind_literal_expr() {
232         check_assist(
233             inline_local_variable,
234             r"
235 fn bar(a: usize) {}
236 fn foo() {
237     let a$0 = 1;
238     a + 1;
239     if a > 10 {
240     }
241
242     while a > 10 {
243
244     }
245     let b = a * 10;
246     bar(a);
247 }",
248             r"
249 fn bar(a: usize) {}
250 fn foo() {
251     1 + 1;
252     if 1 > 10 {
253     }
254
255     while 1 > 10 {
256
257     }
258     let b = 1 * 10;
259     bar(1);
260 }",
261         );
262     }
263
264     #[test]
265     fn test_inline_let_bind_bin_expr() {
266         check_assist(
267             inline_local_variable,
268             r"
269 fn bar(a: usize) {}
270 fn foo() {
271     let a$0 = 1 + 1;
272     a + 1;
273     if a > 10 {
274     }
275
276     while a > 10 {
277
278     }
279     let b = a * 10;
280     bar(a);
281 }",
282             r"
283 fn bar(a: usize) {}
284 fn foo() {
285     (1 + 1) + 1;
286     if (1 + 1) > 10 {
287     }
288
289     while (1 + 1) > 10 {
290
291     }
292     let b = (1 + 1) * 10;
293     bar(1 + 1);
294 }",
295         );
296     }
297
298     #[test]
299     fn test_inline_let_bind_function_call_expr() {
300         check_assist(
301             inline_local_variable,
302             r"
303 fn bar(a: usize) {}
304 fn foo() {
305     let a$0 = bar(1);
306     a + 1;
307     if a > 10 {
308     }
309
310     while a > 10 {
311
312     }
313     let b = a * 10;
314     bar(a);
315 }",
316             r"
317 fn bar(a: usize) {}
318 fn foo() {
319     bar(1) + 1;
320     if bar(1) > 10 {
321     }
322
323     while bar(1) > 10 {
324
325     }
326     let b = bar(1) * 10;
327     bar(bar(1));
328 }",
329         );
330     }
331
332     #[test]
333     fn test_inline_let_bind_cast_expr() {
334         check_assist(
335             inline_local_variable,
336             r"
337 fn bar(a: usize): usize { a }
338 fn foo() {
339     let a$0 = bar(1) as u64;
340     a + 1;
341     if a > 10 {
342     }
343
344     while a > 10 {
345
346     }
347     let b = a * 10;
348     bar(a);
349 }",
350             r"
351 fn bar(a: usize): usize { a }
352 fn foo() {
353     (bar(1) as u64) + 1;
354     if (bar(1) as u64) > 10 {
355     }
356
357     while (bar(1) as u64) > 10 {
358
359     }
360     let b = (bar(1) as u64) * 10;
361     bar(bar(1) as u64);
362 }",
363         );
364     }
365
366     #[test]
367     fn test_inline_let_bind_block_expr() {
368         check_assist(
369             inline_local_variable,
370             r"
371 fn foo() {
372     let a$0 = { 10 + 1 };
373     a + 1;
374     if a > 10 {
375     }
376
377     while a > 10 {
378
379     }
380     let b = a * 10;
381     bar(a);
382 }",
383             r"
384 fn foo() {
385     { 10 + 1 } + 1;
386     if { 10 + 1 } > 10 {
387     }
388
389     while { 10 + 1 } > 10 {
390
391     }
392     let b = { 10 + 1 } * 10;
393     bar({ 10 + 1 });
394 }",
395         );
396     }
397
398     #[test]
399     fn test_inline_let_bind_paren_expr() {
400         check_assist(
401             inline_local_variable,
402             r"
403 fn foo() {
404     let a$0 = ( 10 + 1 );
405     a + 1;
406     if a > 10 {
407     }
408
409     while a > 10 {
410
411     }
412     let b = a * 10;
413     bar(a);
414 }",
415             r"
416 fn foo() {
417     ( 10 + 1 ) + 1;
418     if ( 10 + 1 ) > 10 {
419     }
420
421     while ( 10 + 1 ) > 10 {
422
423     }
424     let b = ( 10 + 1 ) * 10;
425     bar(( 10 + 1 ));
426 }",
427         );
428     }
429
430     #[test]
431     fn test_not_inline_mut_variable() {
432         cov_mark::check!(test_not_inline_mut_variable);
433         check_assist_not_applicable(
434             inline_local_variable,
435             r"
436 fn foo() {
437     let mut a$0 = 1 + 1;
438     a + 1;
439 }",
440         );
441     }
442
443     #[test]
444     fn test_not_inline_mut_variable_use() {
445         cov_mark::check!(test_not_inline_mut_variable_use);
446         check_assist_not_applicable(
447             inline_local_variable,
448             r"
449 fn foo() {
450     let mut a = 1 + 1;
451     a$0 + 1;
452 }",
453         );
454     }
455
456     #[test]
457     fn test_call_expr() {
458         check_assist(
459             inline_local_variable,
460             r"
461 fn foo() {
462     let a$0 = bar(10 + 1);
463     let b = a * 10;
464     let c = a as usize;
465 }",
466             r"
467 fn foo() {
468     let b = bar(10 + 1) * 10;
469     let c = bar(10 + 1) as usize;
470 }",
471         );
472     }
473
474     #[test]
475     fn test_index_expr() {
476         check_assist(
477             inline_local_variable,
478             r"
479 fn foo() {
480     let x = vec![1, 2, 3];
481     let a$0 = x[0];
482     let b = a * 10;
483     let c = a as usize;
484 }",
485             r"
486 fn foo() {
487     let x = vec![1, 2, 3];
488     let b = x[0] * 10;
489     let c = x[0] as usize;
490 }",
491         );
492     }
493
494     #[test]
495     fn test_method_call_expr() {
496         check_assist(
497             inline_local_variable,
498             r"
499 fn foo() {
500     let bar = vec![1];
501     let a$0 = bar.len();
502     let b = a * 10;
503     let c = a as usize;
504 }",
505             r"
506 fn foo() {
507     let bar = vec![1];
508     let b = bar.len() * 10;
509     let c = bar.len() as usize;
510 }",
511         );
512     }
513
514     #[test]
515     fn test_field_expr() {
516         check_assist(
517             inline_local_variable,
518             r"
519 struct Bar {
520     foo: usize
521 }
522
523 fn foo() {
524     let bar = Bar { foo: 1 };
525     let a$0 = bar.foo;
526     let b = a * 10;
527     let c = a as usize;
528 }",
529             r"
530 struct Bar {
531     foo: usize
532 }
533
534 fn foo() {
535     let bar = Bar { foo: 1 };
536     let b = bar.foo * 10;
537     let c = bar.foo as usize;
538 }",
539         );
540     }
541
542     #[test]
543     fn test_try_expr() {
544         check_assist(
545             inline_local_variable,
546             r"
547 fn foo() -> Option<usize> {
548     let bar = Some(1);
549     let a$0 = bar?;
550     let b = a * 10;
551     let c = a as usize;
552     None
553 }",
554             r"
555 fn foo() -> Option<usize> {
556     let bar = Some(1);
557     let b = bar? * 10;
558     let c = bar? as usize;
559     None
560 }",
561         );
562     }
563
564     #[test]
565     fn test_ref_expr() {
566         check_assist(
567             inline_local_variable,
568             r"
569 fn foo() {
570     let bar = 10;
571     let a$0 = &bar;
572     let b = a * 10;
573 }",
574             r"
575 fn foo() {
576     let bar = 10;
577     let b = (&bar) * 10;
578 }",
579         );
580     }
581
582     #[test]
583     fn test_tuple_expr() {
584         check_assist(
585             inline_local_variable,
586             r"
587 fn foo() {
588     let a$0 = (10, 20);
589     let b = a[0];
590 }",
591             r"
592 fn foo() {
593     let b = (10, 20)[0];
594 }",
595         );
596     }
597
598     #[test]
599     fn test_array_expr() {
600         check_assist(
601             inline_local_variable,
602             r"
603 fn foo() {
604     let a$0 = [1, 2, 3];
605     let b = a.len();
606 }",
607             r"
608 fn foo() {
609     let b = [1, 2, 3].len();
610 }",
611         );
612     }
613
614     #[test]
615     fn test_paren() {
616         check_assist(
617             inline_local_variable,
618             r"
619 fn foo() {
620     let a$0 = (10 + 20);
621     let b = a * 10;
622     let c = a as usize;
623 }",
624             r"
625 fn foo() {
626     let b = (10 + 20) * 10;
627     let c = (10 + 20) as usize;
628 }",
629         );
630     }
631
632     #[test]
633     fn test_path_expr() {
634         check_assist(
635             inline_local_variable,
636             r"
637 fn foo() {
638     let d = 10;
639     let a$0 = d;
640     let b = a * 10;
641     let c = a as usize;
642 }",
643             r"
644 fn foo() {
645     let d = 10;
646     let b = d * 10;
647     let c = d as usize;
648 }",
649         );
650     }
651
652     #[test]
653     fn test_block_expr() {
654         check_assist(
655             inline_local_variable,
656             r"
657 fn foo() {
658     let a$0 = { 10 };
659     let b = a * 10;
660     let c = a as usize;
661 }",
662             r"
663 fn foo() {
664     let b = { 10 } * 10;
665     let c = { 10 } as usize;
666 }",
667         );
668     }
669
670     #[test]
671     fn test_used_in_different_expr1() {
672         check_assist(
673             inline_local_variable,
674             r"
675 fn foo() {
676     let a$0 = 10 + 20;
677     let b = a * 10;
678     let c = (a, 20);
679     let d = [a, 10];
680     let e = (a);
681 }",
682             r"
683 fn foo() {
684     let b = (10 + 20) * 10;
685     let c = (10 + 20, 20);
686     let d = [10 + 20, 10];
687     let e = (10 + 20);
688 }",
689         );
690     }
691
692     #[test]
693     fn test_used_in_for_expr() {
694         check_assist(
695             inline_local_variable,
696             r"
697 fn foo() {
698     let a$0 = vec![10, 20];
699     for i in a {}
700 }",
701             r"
702 fn foo() {
703     for i in vec![10, 20] {}
704 }",
705         );
706     }
707
708     #[test]
709     fn test_used_in_while_expr() {
710         check_assist(
711             inline_local_variable,
712             r"
713 fn foo() {
714     let a$0 = 1 > 0;
715     while a {}
716 }",
717             r"
718 fn foo() {
719     while 1 > 0 {}
720 }",
721         );
722     }
723
724     #[test]
725     fn test_used_in_break_expr() {
726         check_assist(
727             inline_local_variable,
728             r"
729 fn foo() {
730     let a$0 = 1 + 1;
731     loop {
732         break a;
733     }
734 }",
735             r"
736 fn foo() {
737     loop {
738         break 1 + 1;
739     }
740 }",
741         );
742     }
743
744     #[test]
745     fn test_used_in_return_expr() {
746         check_assist(
747             inline_local_variable,
748             r"
749 fn foo() {
750     let a$0 = 1 > 0;
751     return a;
752 }",
753             r"
754 fn foo() {
755     return 1 > 0;
756 }",
757         );
758     }
759
760     #[test]
761     fn test_used_in_match_expr() {
762         check_assist(
763             inline_local_variable,
764             r"
765 fn foo() {
766     let a$0 = 1 > 0;
767     match a {}
768 }",
769             r"
770 fn foo() {
771     match 1 > 0 {}
772 }",
773         );
774     }
775
776     #[test]
777     fn inline_field_shorthand() {
778         cov_mark::check!(inline_field_shorthand);
779         check_assist(
780             inline_local_variable,
781             r"
782 struct S { foo: i32}
783 fn main() {
784     let $0foo = 92;
785     S { foo }
786 }
787 ",
788             r"
789 struct S { foo: i32}
790 fn main() {
791     S { foo: 92 }
792 }
793 ",
794         );
795     }
796
797     #[test]
798     fn test_not_applicable_if_variable_unused() {
799         cov_mark::check!(test_not_applicable_if_variable_unused);
800         check_assist_not_applicable(
801             inline_local_variable,
802             r"
803 fn foo() {
804     let $0a = 0;
805 }
806             ",
807         )
808     }
809
810     #[test]
811     fn not_applicable_outside_of_bind_pat() {
812         cov_mark::check!(not_applicable_outside_of_bind_pat);
813         check_assist_not_applicable(
814             inline_local_variable,
815             r"
816 fn main() {
817     let x = $01 + 2;
818     x * 4;
819 }
820 ",
821         )
822     }
823
824     #[test]
825     fn works_on_local_usage() {
826         check_assist(
827             inline_local_variable,
828             r#"
829 fn f() {
830     let xyz = 0;
831     xyz$0;
832 }
833 "#,
834             r#"
835 fn f() {
836     0;
837 }
838 "#,
839         );
840     }
841
842     #[test]
843     fn does_not_remove_let_when_multiple_usages() {
844         check_assist(
845             inline_local_variable,
846             r#"
847 fn f() {
848     let xyz = 0;
849     xyz$0;
850     xyz;
851 }
852 "#,
853             r#"
854 fn f() {
855     let xyz = 0;
856     0;
857     xyz;
858 }
859 "#,
860         );
861     }
862
863     #[test]
864     fn not_applicable_with_non_ident_pattern() {
865         check_assist_not_applicable(
866             inline_local_variable,
867             r#"
868 fn main() {
869     let (x, y) = (0, 1);
870     x$0;
871 }
872 "#,
873         );
874     }
875
876     #[test]
877     fn not_applicable_on_local_usage_in_macro() {
878         check_assist_not_applicable(
879             inline_local_variable,
880             r#"
881 macro_rules! m {
882     ($i:ident) => { $i }
883 }
884 fn f() {
885     let xyz = 0;
886     m!(xyz$0); // replacing it would break the macro
887 }
888 "#,
889         );
890         check_assist_not_applicable(
891             inline_local_variable,
892             r#"
893 macro_rules! m {
894     ($i:ident) => { $i }
895 }
896 fn f() {
897     let xyz$0 = 0;
898     m!(xyz); // replacing it would break the macro
899 }
900 "#,
901         );
902     }
903
904     #[test]
905     fn test_not_inline_selection_too_broad() {
906         cov_mark::check!(test_not_inline_selection_too_broad);
907         check_assist_not_applicable(
908             inline_local_variable,
909             r#"
910 fn f() {
911     let foo = 0;
912     let bar = 0;
913     $0foo + bar$0;
914 }
915 "#,
916         );
917     }
918
919     #[test]
920     fn test_inline_ref_in_let() {
921         check_assist(
922             inline_local_variable,
923             r#"
924 fn f() {
925     let x = {
926         let y = 0;
927         y$0
928     };
929 }
930 "#,
931             r#"
932 fn f() {
933     let x = {
934         0
935     };
936 }
937 "#,
938         );
939     }
940
941     #[test]
942     fn test_inline_let_unit_struct() {
943         check_assist_not_applicable(
944             inline_local_variable,
945             r#"
946 struct S;
947 fn f() {
948     let S$0 = S;
949     S;
950 }
951 "#,
952         );
953     }
954 }