]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/unwrap_result_return_type.rs
Rollup merge of #99954 - dingxiangfei2009:break-out-let-else-higher-up, r=oli-obk
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / unwrap_result_return_type.rs
1 use ide_db::{
2     famous_defs::FamousDefs,
3     syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
4 };
5 use itertools::Itertools;
6 use syntax::{
7     ast::{self, Expr},
8     match_ast, AstNode, TextRange, TextSize,
9 };
10
11 use crate::{AssistContext, AssistId, AssistKind, Assists};
12
13 // Assist: unwrap_result_return_type
14 //
15 // Unwrap the function's return type.
16 //
17 // ```
18 // # //- minicore: result
19 // fn foo() -> Result<i32>$0 { Ok(42i32) }
20 // ```
21 // ->
22 // ```
23 // fn foo() -> i32 { 42i32 }
24 // ```
25 pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
26     let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
27     let parent = ret_type.syntax().parent()?;
28     let body = match_ast! {
29         match parent {
30             ast::Fn(func) => func.body()?,
31             ast::ClosureExpr(closure) => match closure.body()? {
32                 Expr::BlockExpr(block) => block,
33                 // closures require a block when a return type is specified
34                 _ => return None,
35             },
36             _ => return None,
37         }
38     };
39
40     let type_ref = &ret_type.ty()?;
41     let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
42     let result_enum =
43         FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?;
44
45     if !matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
46         return None;
47     }
48
49     acc.add(
50         AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite),
51         "Unwrap Result return type",
52         type_ref.syntax().text_range(),
53         |builder| {
54             let body = ast::Expr::BlockExpr(body);
55
56             let mut exprs_to_unwrap = Vec::new();
57             let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_unwrap, e);
58             walk_expr(&body, &mut |expr| {
59                 if let Expr::ReturnExpr(ret_expr) = expr {
60                     if let Some(ret_expr_arg) = &ret_expr.expr() {
61                         for_each_tail_expr(ret_expr_arg, tail_cb);
62                     }
63                 }
64             });
65             for_each_tail_expr(&body, tail_cb);
66
67             let mut is_unit_type = false;
68             if let Some((_, inner_type)) = type_ref.to_string().split_once('<') {
69                 let inner_type = match inner_type.split_once(',') {
70                     Some((success_inner_type, _)) => success_inner_type,
71                     None => inner_type,
72                 };
73                 let new_ret_type = inner_type.strip_suffix('>').unwrap_or(inner_type);
74                 if new_ret_type == "()" {
75                     is_unit_type = true;
76                     let text_range = TextRange::new(
77                         ret_type.syntax().text_range().start(),
78                         ret_type.syntax().text_range().end() + TextSize::from(1u32),
79                     );
80                     builder.delete(text_range)
81                 } else {
82                     builder.replace(
83                         type_ref.syntax().text_range(),
84                         inner_type.strip_suffix('>').unwrap_or(inner_type),
85                     )
86                 }
87             }
88
89             for ret_expr_arg in exprs_to_unwrap {
90                 let ret_expr_str = ret_expr_arg.to_string();
91                 if ret_expr_str.starts_with("Ok(") || ret_expr_str.starts_with("Err(") {
92                     let arg_list = ret_expr_arg.syntax().children().find_map(ast::ArgList::cast);
93                     if let Some(arg_list) = arg_list {
94                         if is_unit_type {
95                             match ret_expr_arg.syntax().prev_sibling_or_token() {
96                                 // Useful to delete the entire line without leaving trailing whitespaces
97                                 Some(whitespace) => {
98                                     let new_range = TextRange::new(
99                                         whitespace.text_range().start(),
100                                         ret_expr_arg.syntax().text_range().end(),
101                                     );
102                                     builder.delete(new_range);
103                                 }
104                                 None => {
105                                     builder.delete(ret_expr_arg.syntax().text_range());
106                                 }
107                             }
108                         } else {
109                             builder.replace(
110                                 ret_expr_arg.syntax().text_range(),
111                                 arg_list.args().join(", "),
112                             );
113                         }
114                     }
115                 }
116             }
117         },
118     )
119 }
120
121 fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
122     match e {
123         Expr::BreakExpr(break_expr) => {
124             if let Some(break_expr_arg) = break_expr.expr() {
125                 for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e))
126             }
127         }
128         Expr::ReturnExpr(ret_expr) => {
129             if let Some(ret_expr_arg) = &ret_expr.expr() {
130                 for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e));
131             }
132         }
133         e => acc.push(e.clone()),
134     }
135 }
136
137 #[cfg(test)]
138 mod tests {
139     use crate::tests::{check_assist, check_assist_not_applicable};
140
141     use super::*;
142
143     #[test]
144     fn unwrap_result_return_type_simple() {
145         check_assist(
146             unwrap_result_return_type,
147             r#"
148 //- minicore: result
149 fn foo() -> Result<i3$02> {
150     let test = "test";
151     return Ok(42i32);
152 }
153 "#,
154             r#"
155 fn foo() -> i32 {
156     let test = "test";
157     return 42i32;
158 }
159 "#,
160         );
161     }
162
163     #[test]
164     fn unwrap_result_return_type_unit_type() {
165         check_assist(
166             unwrap_result_return_type,
167             r#"
168 //- minicore: result
169 fn foo() -> Result<(), Box<dyn Error$0>> {
170     Ok(())
171 }
172 "#,
173             r#"
174 fn foo() {
175 }
176 "#,
177         );
178     }
179
180     #[test]
181     fn unwrap_result_return_type_ending_with_parent() {
182         check_assist(
183             unwrap_result_return_type,
184             r#"
185 //- minicore: result
186 fn foo() -> Result<i32, Box<dyn Error$0>> {
187     if true {
188         Ok(42)
189     } else {
190         foo()
191     }
192 }
193 "#,
194             r#"
195 fn foo() -> i32 {
196     if true {
197         42
198     } else {
199         foo()
200     }
201 }
202 "#,
203         );
204     }
205
206     #[test]
207     fn unwrap_return_type_break_split_tail() {
208         check_assist(
209             unwrap_result_return_type,
210             r#"
211 //- minicore: result
212 fn foo() -> Result<i3$02, String> {
213     loop {
214         break if true {
215             Ok(1)
216         } else {
217             Ok(0)
218         };
219     }
220 }
221 "#,
222             r#"
223 fn foo() -> i32 {
224     loop {
225         break if true {
226             1
227         } else {
228             0
229         };
230     }
231 }
232 "#,
233         );
234     }
235
236     #[test]
237     fn unwrap_result_return_type_simple_closure() {
238         check_assist(
239             unwrap_result_return_type,
240             r#"
241 //- minicore: result
242 fn foo() {
243     || -> Result<i32$0> {
244         let test = "test";
245         return Ok(42i32);
246     };
247 }
248 "#,
249             r#"
250 fn foo() {
251     || -> i32 {
252         let test = "test";
253         return 42i32;
254     };
255 }
256 "#,
257         );
258     }
259
260     #[test]
261     fn unwrap_result_return_type_simple_return_type_bad_cursor() {
262         check_assist_not_applicable(
263             unwrap_result_return_type,
264             r#"
265 //- minicore: result
266 fn foo() -> i32 {
267     let test = "test";$0
268     return 42i32;
269 }
270 "#,
271         );
272     }
273
274     #[test]
275     fn unwrap_result_return_type_simple_return_type_bad_cursor_closure() {
276         check_assist_not_applicable(
277             unwrap_result_return_type,
278             r#"
279 //- minicore: result
280 fn foo() {
281     || -> i32 {
282         let test = "test";$0
283         return 42i32;
284     };
285 }
286 "#,
287         );
288     }
289
290     #[test]
291     fn unwrap_result_return_type_closure_non_block() {
292         check_assist_not_applicable(
293             unwrap_result_return_type,
294             r#"
295 //- minicore: result
296 fn foo() { || -> i$032 3; }
297 "#,
298         );
299     }
300
301     #[test]
302     fn unwrap_result_return_type_simple_return_type_already_not_result_std() {
303         check_assist_not_applicable(
304             unwrap_result_return_type,
305             r#"
306 //- minicore: result
307 fn foo() -> i32$0 {
308     let test = "test";
309     return 42i32;
310 }
311 "#,
312         );
313     }
314
315     #[test]
316     fn unwrap_result_return_type_simple_return_type_already_not_result_closure() {
317         check_assist_not_applicable(
318             unwrap_result_return_type,
319             r#"
320 //- minicore: result
321 fn foo() {
322     || -> i32$0 {
323         let test = "test";
324         return 42i32;
325     };
326 }
327 "#,
328         );
329     }
330
331     #[test]
332     fn unwrap_result_return_type_simple_with_tail() {
333         check_assist(
334             unwrap_result_return_type,
335             r#"
336 //- minicore: result
337 fn foo() ->$0 Result<i32> {
338     let test = "test";
339     Ok(42i32)
340 }
341 "#,
342             r#"
343 fn foo() -> i32 {
344     let test = "test";
345     42i32
346 }
347 "#,
348         );
349     }
350
351     #[test]
352     fn unwrap_result_return_type_simple_with_tail_closure() {
353         check_assist(
354             unwrap_result_return_type,
355             r#"
356 //- minicore: result
357 fn foo() {
358     || ->$0 Result<i32, String> {
359         let test = "test";
360         Ok(42i32)
361     };
362 }
363 "#,
364             r#"
365 fn foo() {
366     || -> i32 {
367         let test = "test";
368         42i32
369     };
370 }
371 "#,
372         );
373     }
374
375     #[test]
376     fn unwrap_result_return_type_simple_with_tail_only() {
377         check_assist(
378             unwrap_result_return_type,
379             r#"
380 //- minicore: result
381 fn foo() -> Result<i32$0> { Ok(42i32) }
382 "#,
383             r#"
384 fn foo() -> i32 { 42i32 }
385 "#,
386         );
387     }
388
389     #[test]
390     fn unwrap_result_return_type_simple_with_tail_block_like() {
391         check_assist(
392             unwrap_result_return_type,
393             r#"
394 //- minicore: result
395 fn foo() -> Result<i32>$0 {
396     if true {
397         Ok(42i32)
398     } else {
399         Ok(24i32)
400     }
401 }
402 "#,
403             r#"
404 fn foo() -> i32 {
405     if true {
406         42i32
407     } else {
408         24i32
409     }
410 }
411 "#,
412         );
413     }
414
415     #[test]
416     fn unwrap_result_return_type_simple_without_block_closure() {
417         check_assist(
418             unwrap_result_return_type,
419             r#"
420 //- minicore: result
421 fn foo() {
422     || -> Result<i32, String>$0 {
423         if true {
424             Ok(42i32)
425         } else {
426             Ok(24i32)
427         }
428     };
429 }
430 "#,
431             r#"
432 fn foo() {
433     || -> i32 {
434         if true {
435             42i32
436         } else {
437             24i32
438         }
439     };
440 }
441 "#,
442         );
443     }
444
445     #[test]
446     fn unwrap_result_return_type_simple_with_nested_if() {
447         check_assist(
448             unwrap_result_return_type,
449             r#"
450 //- minicore: result
451 fn foo() -> Result<i32>$0 {
452     if true {
453         if false {
454             Ok(1)
455         } else {
456             Ok(2)
457         }
458     } else {
459         Ok(24i32)
460     }
461 }
462 "#,
463             r#"
464 fn foo() -> i32 {
465     if true {
466         if false {
467             1
468         } else {
469             2
470         }
471     } else {
472         24i32
473     }
474 }
475 "#,
476         );
477     }
478
479     #[test]
480     fn unwrap_result_return_type_simple_with_await() {
481         check_assist(
482             unwrap_result_return_type,
483             r#"
484 //- minicore: result
485 async fn foo() -> Result<i$032> {
486     if true {
487         if false {
488             Ok(1.await)
489         } else {
490             Ok(2.await)
491         }
492     } else {
493         Ok(24i32.await)
494     }
495 }
496 "#,
497             r#"
498 async fn foo() -> i32 {
499     if true {
500         if false {
501             1.await
502         } else {
503             2.await
504         }
505     } else {
506         24i32.await
507     }
508 }
509 "#,
510         );
511     }
512
513     #[test]
514     fn unwrap_result_return_type_simple_with_array() {
515         check_assist(
516             unwrap_result_return_type,
517             r#"
518 //- minicore: result
519 fn foo() -> Result<[i32; 3]$0> { Ok([1, 2, 3]) }
520 "#,
521             r#"
522 fn foo() -> [i32; 3] { [1, 2, 3] }
523 "#,
524         );
525     }
526
527     #[test]
528     fn unwrap_result_return_type_simple_with_cast() {
529         check_assist(
530             unwrap_result_return_type,
531             r#"
532 //- minicore: result
533 fn foo() -$0> Result<i32> {
534     if true {
535         if false {
536             Ok(1 as i32)
537         } else {
538             Ok(2 as i32)
539         }
540     } else {
541         Ok(24 as i32)
542     }
543 }
544 "#,
545             r#"
546 fn foo() -> i32 {
547     if true {
548         if false {
549             1 as i32
550         } else {
551             2 as i32
552         }
553     } else {
554         24 as i32
555     }
556 }
557 "#,
558         );
559     }
560
561     #[test]
562     fn unwrap_result_return_type_simple_with_tail_block_like_match() {
563         check_assist(
564             unwrap_result_return_type,
565             r#"
566 //- minicore: result
567 fn foo() -> Result<i32$0> {
568     let my_var = 5;
569     match my_var {
570         5 => Ok(42i32),
571         _ => Ok(24i32),
572     }
573 }
574 "#,
575             r#"
576 fn foo() -> i32 {
577     let my_var = 5;
578     match my_var {
579         5 => 42i32,
580         _ => 24i32,
581     }
582 }
583 "#,
584         );
585     }
586
587     #[test]
588     fn unwrap_result_return_type_simple_with_loop_with_tail() {
589         check_assist(
590             unwrap_result_return_type,
591             r#"
592 //- minicore: result
593 fn foo() -> Result<i32$0> {
594     let my_var = 5;
595     loop {
596         println!("test");
597         5
598     }
599     Ok(my_var)
600 }
601 "#,
602             r#"
603 fn foo() -> i32 {
604     let my_var = 5;
605     loop {
606         println!("test");
607         5
608     }
609     my_var
610 }
611 "#,
612         );
613     }
614
615     #[test]
616     fn unwrap_result_return_type_simple_with_loop_in_let_stmt() {
617         check_assist(
618             unwrap_result_return_type,
619             r#"
620 //- minicore: result
621 fn foo() -> Result<i32$0> {
622     let my_var = let x = loop {
623         break 1;
624     };
625     Ok(my_var)
626 }
627 "#,
628             r#"
629 fn foo() -> i32 {
630     let my_var = let x = loop {
631         break 1;
632     };
633     my_var
634 }
635 "#,
636         );
637     }
638
639     #[test]
640     fn unwrap_result_return_type_simple_with_tail_block_like_match_return_expr() {
641         check_assist(
642             unwrap_result_return_type,
643             r#"
644 //- minicore: result
645 fn foo() -> Result<i32>$0 {
646     let my_var = 5;
647     let res = match my_var {
648         5 => 42i32,
649         _ => return Ok(24i32),
650     };
651     Ok(res)
652 }
653 "#,
654             r#"
655 fn foo() -> i32 {
656     let my_var = 5;
657     let res = match my_var {
658         5 => 42i32,
659         _ => return 24i32,
660     };
661     res
662 }
663 "#,
664         );
665
666         check_assist(
667             unwrap_result_return_type,
668             r#"
669 //- minicore: result
670 fn foo() -> Result<i32$0> {
671     let my_var = 5;
672     let res = if my_var == 5 {
673         42i32
674     } else {
675         return Ok(24i32);
676     };
677     Ok(res)
678 }
679 "#,
680             r#"
681 fn foo() -> i32 {
682     let my_var = 5;
683     let res = if my_var == 5 {
684         42i32
685     } else {
686         return 24i32;
687     };
688     res
689 }
690 "#,
691         );
692     }
693
694     #[test]
695     fn unwrap_result_return_type_simple_with_tail_block_like_match_deeper() {
696         check_assist(
697             unwrap_result_return_type,
698             r#"
699 //- minicore: result
700 fn foo() -> Result<i32$0> {
701     let my_var = 5;
702     match my_var {
703         5 => {
704             if true {
705                 Ok(42i32)
706             } else {
707                 Ok(25i32)
708             }
709         },
710         _ => {
711             let test = "test";
712             if test == "test" {
713                 return Ok(bar());
714             }
715             Ok(53i32)
716         },
717     }
718 }
719 "#,
720             r#"
721 fn foo() -> i32 {
722     let my_var = 5;
723     match my_var {
724         5 => {
725             if true {
726                 42i32
727             } else {
728                 25i32
729             }
730         },
731         _ => {
732             let test = "test";
733             if test == "test" {
734                 return bar();
735             }
736             53i32
737         },
738     }
739 }
740 "#,
741         );
742     }
743
744     #[test]
745     fn unwrap_result_return_type_simple_with_tail_block_like_early_return() {
746         check_assist(
747             unwrap_result_return_type,
748             r#"
749 //- minicore: result
750 fn foo() -> Result<i32$0> {
751     let test = "test";
752     if test == "test" {
753         return Ok(24i32);
754     }
755     Ok(53i32)
756 }
757 "#,
758             r#"
759 fn foo() -> i32 {
760     let test = "test";
761     if test == "test" {
762         return 24i32;
763     }
764     53i32
765 }
766 "#,
767         );
768     }
769
770     #[test]
771     fn unwrap_result_return_type_simple_with_closure() {
772         check_assist(
773             unwrap_result_return_type,
774             r#"
775 //- minicore: result
776 fn foo(the_field: u32) -> Result<u32$0> {
777     let true_closure = || { return true; };
778     if the_field < 5 {
779         let mut i = 0;
780         if true_closure() {
781             return Ok(99);
782         } else {
783             return Ok(0);
784         }
785     }
786     Ok(the_field)
787 }
788 "#,
789             r#"
790 fn foo(the_field: u32) -> u32 {
791     let true_closure = || { return true; };
792     if the_field < 5 {
793         let mut i = 0;
794         if true_closure() {
795             return 99;
796         } else {
797             return 0;
798         }
799     }
800     the_field
801 }
802 "#,
803         );
804
805         check_assist(
806             unwrap_result_return_type,
807             r#"
808 //- minicore: result
809 fn foo(the_field: u32) -> Result<u32$0> {
810     let true_closure = || {
811         return true;
812     };
813     if the_field < 5 {
814         let mut i = 0;
815
816
817         if true_closure() {
818             return Ok(99);
819         } else {
820             return Ok(0);
821         }
822     }
823     let t = None;
824
825     Ok(t.unwrap_or_else(|| the_field))
826 }
827 "#,
828             r#"
829 fn foo(the_field: u32) -> u32 {
830     let true_closure = || {
831         return true;
832     };
833     if the_field < 5 {
834         let mut i = 0;
835
836
837         if true_closure() {
838             return 99;
839         } else {
840             return 0;
841         }
842     }
843     let t = None;
844
845     t.unwrap_or_else(|| the_field)
846 }
847 "#,
848         );
849     }
850
851     #[test]
852     fn unwrap_result_return_type_simple_with_weird_forms() {
853         check_assist(
854             unwrap_result_return_type,
855             r#"
856 //- minicore: result
857 fn foo() -> Result<i32$0> {
858     let test = "test";
859     if test == "test" {
860         return Ok(24i32);
861     }
862     let mut i = 0;
863     loop {
864         if i == 1 {
865             break Ok(55);
866         }
867         i += 1;
868     }
869 }
870 "#,
871             r#"
872 fn foo() -> i32 {
873     let test = "test";
874     if test == "test" {
875         return 24i32;
876     }
877     let mut i = 0;
878     loop {
879         if i == 1 {
880             break 55;
881         }
882         i += 1;
883     }
884 }
885 "#,
886         );
887
888         check_assist(
889             unwrap_result_return_type,
890             r#"
891 //- minicore: result
892 fn foo(the_field: u32) -> Result<u32$0> {
893     if the_field < 5 {
894         let mut i = 0;
895         loop {
896             if i > 5 {
897                 return Ok(55u32);
898             }
899             i += 3;
900         }
901         match i {
902             5 => return Ok(99),
903             _ => return Ok(0),
904         };
905     }
906     Ok(the_field)
907 }
908 "#,
909             r#"
910 fn foo(the_field: u32) -> u32 {
911     if the_field < 5 {
912         let mut i = 0;
913         loop {
914             if i > 5 {
915                 return 55u32;
916             }
917             i += 3;
918         }
919         match i {
920             5 => return 99,
921             _ => return 0,
922         };
923     }
924     the_field
925 }
926 "#,
927         );
928
929         check_assist(
930             unwrap_result_return_type,
931             r#"
932 //- minicore: result
933 fn foo(the_field: u32) -> Result<u32$0> {
934     if the_field < 5 {
935         let mut i = 0;
936         match i {
937             5 => return Ok(99),
938             _ => return Ok(0),
939         }
940     }
941     Ok(the_field)
942 }
943 "#,
944             r#"
945 fn foo(the_field: u32) -> u32 {
946     if the_field < 5 {
947         let mut i = 0;
948         match i {
949             5 => return 99,
950             _ => return 0,
951         }
952     }
953     the_field
954 }
955 "#,
956         );
957
958         check_assist(
959             unwrap_result_return_type,
960             r#"
961 //- minicore: result
962 fn foo(the_field: u32) -> Result<u32$0> {
963     if the_field < 5 {
964         let mut i = 0;
965         if i == 5 {
966             return Ok(99)
967         } else {
968             return Ok(0)
969         }
970     }
971     Ok(the_field)
972 }
973 "#,
974             r#"
975 fn foo(the_field: u32) -> u32 {
976     if the_field < 5 {
977         let mut i = 0;
978         if i == 5 {
979             return 99
980         } else {
981             return 0
982         }
983     }
984     the_field
985 }
986 "#,
987         );
988
989         check_assist(
990             unwrap_result_return_type,
991             r#"
992 //- minicore: result
993 fn foo(the_field: u32) -> Result<u3$02> {
994     if the_field < 5 {
995         let mut i = 0;
996         if i == 5 {
997             return Ok(99);
998         } else {
999             return Ok(0);
1000         }
1001     }
1002     Ok(the_field)
1003 }
1004 "#,
1005             r#"
1006 fn foo(the_field: u32) -> u32 {
1007     if the_field < 5 {
1008         let mut i = 0;
1009         if i == 5 {
1010             return 99;
1011         } else {
1012             return 0;
1013         }
1014     }
1015     the_field
1016 }
1017 "#,
1018         );
1019     }
1020 }