]> git.lizzy.rs Git - rust.git/blob - src/tools/rust-analyzer/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs
Rollup merge of #99954 - dingxiangfei2009:break-out-let-else-higher-up, r=oli-obk
[rust.git] / src / tools / rust-analyzer / crates / ide-assists / src / handlers / wrap_return_type_in_result.rs
1 use std::iter;
2
3 use ide_db::{
4     famous_defs::FamousDefs,
5     syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
6 };
7 use syntax::{
8     ast::{self, make, Expr},
9     match_ast, AstNode,
10 };
11
12 use crate::{AssistContext, AssistId, AssistKind, Assists};
13
14 // Assist: wrap_return_type_in_result
15 //
16 // Wrap the function's return type into Result.
17 //
18 // ```
19 // # //- minicore: result
20 // fn foo() -> i32$0 { 42i32 }
21 // ```
22 // ->
23 // ```
24 // fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
25 // ```
26 pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
27     let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
28     let parent = ret_type.syntax().parent()?;
29     let body = match_ast! {
30         match parent {
31             ast::Fn(func) => func.body()?,
32             ast::ClosureExpr(closure) => match closure.body()? {
33                 Expr::BlockExpr(block) => block,
34                 // closures require a block when a return type is specified
35                 _ => return None,
36             },
37             _ => return None,
38         }
39     };
40
41     let type_ref = &ret_type.ty()?;
42     let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
43     let result_enum =
44         FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?;
45
46     if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
47         cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
48         return None;
49     }
50
51     acc.add(
52         AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite),
53         "Wrap return type in Result",
54         type_ref.syntax().text_range(),
55         |builder| {
56             let body = ast::Expr::BlockExpr(body);
57
58             let mut exprs_to_wrap = Vec::new();
59             let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e);
60             walk_expr(&body, &mut |expr| {
61                 if let Expr::ReturnExpr(ret_expr) = expr {
62                     if let Some(ret_expr_arg) = &ret_expr.expr() {
63                         for_each_tail_expr(ret_expr_arg, tail_cb);
64                     }
65                 }
66             });
67             for_each_tail_expr(&body, tail_cb);
68
69             for ret_expr_arg in exprs_to_wrap {
70                 let ok_wrapped = make::expr_call(
71                     make::expr_path(make::ext::ident_path("Ok")),
72                     make::arg_list(iter::once(ret_expr_arg.clone())),
73                 );
74                 builder.replace_ast(ret_expr_arg, ok_wrapped);
75             }
76
77             match ctx.config.snippet_cap {
78                 Some(cap) => {
79                     let snippet = format!("Result<{}, ${{0:_}}>", type_ref);
80                     builder.replace_snippet(cap, type_ref.syntax().text_range(), snippet)
81                 }
82                 None => builder
83                     .replace(type_ref.syntax().text_range(), format!("Result<{}, _>", type_ref)),
84             }
85         },
86     )
87 }
88
89 fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
90     match e {
91         Expr::BreakExpr(break_expr) => {
92             if let Some(break_expr_arg) = break_expr.expr() {
93                 for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e))
94             }
95         }
96         Expr::ReturnExpr(ret_expr) => {
97             if let Some(ret_expr_arg) = &ret_expr.expr() {
98                 for_each_tail_expr(ret_expr_arg, &mut |e| tail_cb_impl(acc, e));
99             }
100         }
101         e => acc.push(e.clone()),
102     }
103 }
104
105 #[cfg(test)]
106 mod tests {
107     use crate::tests::{check_assist, check_assist_not_applicable};
108
109     use super::*;
110
111     #[test]
112     fn wrap_return_type_in_result_simple() {
113         check_assist(
114             wrap_return_type_in_result,
115             r#"
116 //- minicore: result
117 fn foo() -> i3$02 {
118     let test = "test";
119     return 42i32;
120 }
121 "#,
122             r#"
123 fn foo() -> Result<i32, ${0:_}> {
124     let test = "test";
125     return Ok(42i32);
126 }
127 "#,
128         );
129     }
130
131     #[test]
132     fn wrap_return_type_break_split_tail() {
133         check_assist(
134             wrap_return_type_in_result,
135             r#"
136 //- minicore: result
137 fn foo() -> i3$02 {
138     loop {
139         break if true {
140             1
141         } else {
142             0
143         };
144     }
145 }
146 "#,
147             r#"
148 fn foo() -> Result<i32, ${0:_}> {
149     loop {
150         break if true {
151             Ok(1)
152         } else {
153             Ok(0)
154         };
155     }
156 }
157 "#,
158         );
159     }
160
161     #[test]
162     fn wrap_return_type_in_result_simple_closure() {
163         check_assist(
164             wrap_return_type_in_result,
165             r#"
166 //- minicore: result
167 fn foo() {
168     || -> i32$0 {
169         let test = "test";
170         return 42i32;
171     };
172 }
173 "#,
174             r#"
175 fn foo() {
176     || -> Result<i32, ${0:_}> {
177         let test = "test";
178         return Ok(42i32);
179     };
180 }
181 "#,
182         );
183     }
184
185     #[test]
186     fn wrap_return_type_in_result_simple_return_type_bad_cursor() {
187         check_assist_not_applicable(
188             wrap_return_type_in_result,
189             r#"
190 //- minicore: result
191 fn foo() -> i32 {
192     let test = "test";$0
193     return 42i32;
194 }
195 "#,
196         );
197     }
198
199     #[test]
200     fn wrap_return_type_in_result_simple_return_type_bad_cursor_closure() {
201         check_assist_not_applicable(
202             wrap_return_type_in_result,
203             r#"
204 //- minicore: result
205 fn foo() {
206     || -> i32 {
207         let test = "test";$0
208         return 42i32;
209     };
210 }
211 "#,
212         );
213     }
214
215     #[test]
216     fn wrap_return_type_in_result_closure_non_block() {
217         check_assist_not_applicable(
218             wrap_return_type_in_result,
219             r#"
220 //- minicore: result
221 fn foo() { || -> i$032 3; }
222 "#,
223         );
224     }
225
226     #[test]
227     fn wrap_return_type_in_result_simple_return_type_already_result_std() {
228         check_assist_not_applicable(
229             wrap_return_type_in_result,
230             r#"
231 //- minicore: result
232 fn foo() -> core::result::Result<i32$0, String> {
233     let test = "test";
234     return 42i32;
235 }
236 "#,
237         );
238     }
239
240     #[test]
241     fn wrap_return_type_in_result_simple_return_type_already_result() {
242         cov_mark::check!(wrap_return_type_in_result_simple_return_type_already_result);
243         check_assist_not_applicable(
244             wrap_return_type_in_result,
245             r#"
246 //- minicore: result
247 fn foo() -> Result<i32$0, String> {
248     let test = "test";
249     return 42i32;
250 }
251 "#,
252         );
253     }
254
255     #[test]
256     fn wrap_return_type_in_result_simple_return_type_already_result_closure() {
257         check_assist_not_applicable(
258             wrap_return_type_in_result,
259             r#"
260 //- minicore: result
261 fn foo() {
262     || -> Result<i32$0, String> {
263         let test = "test";
264         return 42i32;
265     };
266 }
267 "#,
268         );
269     }
270
271     #[test]
272     fn wrap_return_type_in_result_simple_with_cursor() {
273         check_assist(
274             wrap_return_type_in_result,
275             r#"
276 //- minicore: result
277 fn foo() -> $0i32 {
278     let test = "test";
279     return 42i32;
280 }
281 "#,
282             r#"
283 fn foo() -> Result<i32, ${0:_}> {
284     let test = "test";
285     return Ok(42i32);
286 }
287 "#,
288         );
289     }
290
291     #[test]
292     fn wrap_return_type_in_result_simple_with_tail() {
293         check_assist(
294             wrap_return_type_in_result,
295             r#"
296 //- minicore: result
297 fn foo() ->$0 i32 {
298     let test = "test";
299     42i32
300 }
301 "#,
302             r#"
303 fn foo() -> Result<i32, ${0:_}> {
304     let test = "test";
305     Ok(42i32)
306 }
307 "#,
308         );
309     }
310
311     #[test]
312     fn wrap_return_type_in_result_simple_with_tail_closure() {
313         check_assist(
314             wrap_return_type_in_result,
315             r#"
316 //- minicore: result
317 fn foo() {
318     || ->$0 i32 {
319         let test = "test";
320         42i32
321     };
322 }
323 "#,
324             r#"
325 fn foo() {
326     || -> Result<i32, ${0:_}> {
327         let test = "test";
328         Ok(42i32)
329     };
330 }
331 "#,
332         );
333     }
334
335     #[test]
336     fn wrap_return_type_in_result_simple_with_tail_only() {
337         check_assist(
338             wrap_return_type_in_result,
339             r#"
340 //- minicore: result
341 fn foo() -> i32$0 { 42i32 }
342 "#,
343             r#"
344 fn foo() -> Result<i32, ${0:_}> { Ok(42i32) }
345 "#,
346         );
347     }
348
349     #[test]
350     fn wrap_return_type_in_result_simple_with_tail_block_like() {
351         check_assist(
352             wrap_return_type_in_result,
353             r#"
354 //- minicore: result
355 fn foo() -> i32$0 {
356     if true {
357         42i32
358     } else {
359         24i32
360     }
361 }
362 "#,
363             r#"
364 fn foo() -> Result<i32, ${0:_}> {
365     if true {
366         Ok(42i32)
367     } else {
368         Ok(24i32)
369     }
370 }
371 "#,
372         );
373     }
374
375     #[test]
376     fn wrap_return_type_in_result_simple_without_block_closure() {
377         check_assist(
378             wrap_return_type_in_result,
379             r#"
380 //- minicore: result
381 fn foo() {
382     || -> i32$0 {
383         if true {
384             42i32
385         } else {
386             24i32
387         }
388     };
389 }
390 "#,
391             r#"
392 fn foo() {
393     || -> Result<i32, ${0:_}> {
394         if true {
395             Ok(42i32)
396         } else {
397             Ok(24i32)
398         }
399     };
400 }
401 "#,
402         );
403     }
404
405     #[test]
406     fn wrap_return_type_in_result_simple_with_nested_if() {
407         check_assist(
408             wrap_return_type_in_result,
409             r#"
410 //- minicore: result
411 fn foo() -> i32$0 {
412     if true {
413         if false {
414             1
415         } else {
416             2
417         }
418     } else {
419         24i32
420     }
421 }
422 "#,
423             r#"
424 fn foo() -> Result<i32, ${0:_}> {
425     if true {
426         if false {
427             Ok(1)
428         } else {
429             Ok(2)
430         }
431     } else {
432         Ok(24i32)
433     }
434 }
435 "#,
436         );
437     }
438
439     #[test]
440     fn wrap_return_type_in_result_simple_with_await() {
441         check_assist(
442             wrap_return_type_in_result,
443             r#"
444 //- minicore: result
445 async fn foo() -> i$032 {
446     if true {
447         if false {
448             1.await
449         } else {
450             2.await
451         }
452     } else {
453         24i32.await
454     }
455 }
456 "#,
457             r#"
458 async fn foo() -> Result<i32, ${0:_}> {
459     if true {
460         if false {
461             Ok(1.await)
462         } else {
463             Ok(2.await)
464         }
465     } else {
466         Ok(24i32.await)
467     }
468 }
469 "#,
470         );
471     }
472
473     #[test]
474     fn wrap_return_type_in_result_simple_with_array() {
475         check_assist(
476             wrap_return_type_in_result,
477             r#"
478 //- minicore: result
479 fn foo() -> [i32;$0 3] { [1, 2, 3] }
480 "#,
481             r#"
482 fn foo() -> Result<[i32; 3], ${0:_}> { Ok([1, 2, 3]) }
483 "#,
484         );
485     }
486
487     #[test]
488     fn wrap_return_type_in_result_simple_with_cast() {
489         check_assist(
490             wrap_return_type_in_result,
491             r#"
492 //- minicore: result
493 fn foo() -$0> i32 {
494     if true {
495         if false {
496             1 as i32
497         } else {
498             2 as i32
499         }
500     } else {
501         24 as i32
502     }
503 }
504 "#,
505             r#"
506 fn foo() -> Result<i32, ${0:_}> {
507     if true {
508         if false {
509             Ok(1 as i32)
510         } else {
511             Ok(2 as i32)
512         }
513     } else {
514         Ok(24 as i32)
515     }
516 }
517 "#,
518         );
519     }
520
521     #[test]
522     fn wrap_return_type_in_result_simple_with_tail_block_like_match() {
523         check_assist(
524             wrap_return_type_in_result,
525             r#"
526 //- minicore: result
527 fn foo() -> i32$0 {
528     let my_var = 5;
529     match my_var {
530         5 => 42i32,
531         _ => 24i32,
532     }
533 }
534 "#,
535             r#"
536 fn foo() -> Result<i32, ${0:_}> {
537     let my_var = 5;
538     match my_var {
539         5 => Ok(42i32),
540         _ => Ok(24i32),
541     }
542 }
543 "#,
544         );
545     }
546
547     #[test]
548     fn wrap_return_type_in_result_simple_with_loop_with_tail() {
549         check_assist(
550             wrap_return_type_in_result,
551             r#"
552 //- minicore: result
553 fn foo() -> i32$0 {
554     let my_var = 5;
555     loop {
556         println!("test");
557         5
558     }
559     my_var
560 }
561 "#,
562             r#"
563 fn foo() -> Result<i32, ${0:_}> {
564     let my_var = 5;
565     loop {
566         println!("test");
567         5
568     }
569     Ok(my_var)
570 }
571 "#,
572         );
573     }
574
575     #[test]
576     fn wrap_return_type_in_result_simple_with_loop_in_let_stmt() {
577         check_assist(
578             wrap_return_type_in_result,
579             r#"
580 //- minicore: result
581 fn foo() -> i32$0 {
582     let my_var = let x = loop {
583         break 1;
584     };
585     my_var
586 }
587 "#,
588             r#"
589 fn foo() -> Result<i32, ${0:_}> {
590     let my_var = let x = loop {
591         break 1;
592     };
593     Ok(my_var)
594 }
595 "#,
596         );
597     }
598
599     #[test]
600     fn wrap_return_type_in_result_simple_with_tail_block_like_match_return_expr() {
601         check_assist(
602             wrap_return_type_in_result,
603             r#"
604 //- minicore: result
605 fn foo() -> i32$0 {
606     let my_var = 5;
607     let res = match my_var {
608         5 => 42i32,
609         _ => return 24i32,
610     };
611     res
612 }
613 "#,
614             r#"
615 fn foo() -> Result<i32, ${0:_}> {
616     let my_var = 5;
617     let res = match my_var {
618         5 => 42i32,
619         _ => return Ok(24i32),
620     };
621     Ok(res)
622 }
623 "#,
624         );
625
626         check_assist(
627             wrap_return_type_in_result,
628             r#"
629 //- minicore: result
630 fn foo() -> i32$0 {
631     let my_var = 5;
632     let res = if my_var == 5 {
633         42i32
634     } else {
635         return 24i32;
636     };
637     res
638 }
639 "#,
640             r#"
641 fn foo() -> Result<i32, ${0:_}> {
642     let my_var = 5;
643     let res = if my_var == 5 {
644         42i32
645     } else {
646         return Ok(24i32);
647     };
648     Ok(res)
649 }
650 "#,
651         );
652     }
653
654     #[test]
655     fn wrap_return_type_in_result_simple_with_tail_block_like_match_deeper() {
656         check_assist(
657             wrap_return_type_in_result,
658             r#"
659 //- minicore: result
660 fn foo() -> i32$0 {
661     let my_var = 5;
662     match my_var {
663         5 => {
664             if true {
665                 42i32
666             } else {
667                 25i32
668             }
669         },
670         _ => {
671             let test = "test";
672             if test == "test" {
673                 return bar();
674             }
675             53i32
676         },
677     }
678 }
679 "#,
680             r#"
681 fn foo() -> Result<i32, ${0:_}> {
682     let my_var = 5;
683     match my_var {
684         5 => {
685             if true {
686                 Ok(42i32)
687             } else {
688                 Ok(25i32)
689             }
690         },
691         _ => {
692             let test = "test";
693             if test == "test" {
694                 return Ok(bar());
695             }
696             Ok(53i32)
697         },
698     }
699 }
700 "#,
701         );
702     }
703
704     #[test]
705     fn wrap_return_type_in_result_simple_with_tail_block_like_early_return() {
706         check_assist(
707             wrap_return_type_in_result,
708             r#"
709 //- minicore: result
710 fn foo() -> i$032 {
711     let test = "test";
712     if test == "test" {
713         return 24i32;
714     }
715     53i32
716 }
717 "#,
718             r#"
719 fn foo() -> Result<i32, ${0:_}> {
720     let test = "test";
721     if test == "test" {
722         return Ok(24i32);
723     }
724     Ok(53i32)
725 }
726 "#,
727         );
728     }
729
730     #[test]
731     fn wrap_return_type_in_result_simple_with_closure() {
732         check_assist(
733             wrap_return_type_in_result,
734             r#"
735 //- minicore: result
736 fn foo(the_field: u32) ->$0 u32 {
737     let true_closure = || { return true; };
738     if the_field < 5 {
739         let mut i = 0;
740         if true_closure() {
741             return 99;
742         } else {
743             return 0;
744         }
745     }
746     the_field
747 }
748 "#,
749             r#"
750 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
751     let true_closure = || { return true; };
752     if the_field < 5 {
753         let mut i = 0;
754         if true_closure() {
755             return Ok(99);
756         } else {
757             return Ok(0);
758         }
759     }
760     Ok(the_field)
761 }
762 "#,
763         );
764
765         check_assist(
766             wrap_return_type_in_result,
767             r#"
768 //- minicore: result
769 fn foo(the_field: u32) -> u32$0 {
770     let true_closure = || {
771         return true;
772     };
773     if the_field < 5 {
774         let mut i = 0;
775
776
777         if true_closure() {
778             return 99;
779         } else {
780             return 0;
781         }
782     }
783     let t = None;
784
785     t.unwrap_or_else(|| the_field)
786 }
787 "#,
788             r#"
789 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
790     let true_closure = || {
791         return true;
792     };
793     if the_field < 5 {
794         let mut i = 0;
795
796
797         if true_closure() {
798             return Ok(99);
799         } else {
800             return Ok(0);
801         }
802     }
803     let t = None;
804
805     Ok(t.unwrap_or_else(|| the_field))
806 }
807 "#,
808         );
809     }
810
811     #[test]
812     fn wrap_return_type_in_result_simple_with_weird_forms() {
813         check_assist(
814             wrap_return_type_in_result,
815             r#"
816 //- minicore: result
817 fn foo() -> i32$0 {
818     let test = "test";
819     if test == "test" {
820         return 24i32;
821     }
822     let mut i = 0;
823     loop {
824         if i == 1 {
825             break 55;
826         }
827         i += 1;
828     }
829 }
830 "#,
831             r#"
832 fn foo() -> Result<i32, ${0:_}> {
833     let test = "test";
834     if test == "test" {
835         return Ok(24i32);
836     }
837     let mut i = 0;
838     loop {
839         if i == 1 {
840             break Ok(55);
841         }
842         i += 1;
843     }
844 }
845 "#,
846         );
847
848         check_assist(
849             wrap_return_type_in_result,
850             r#"
851 //- minicore: result
852 fn foo(the_field: u32) -> u32$0 {
853     if the_field < 5 {
854         let mut i = 0;
855         loop {
856             if i > 5 {
857                 return 55u32;
858             }
859             i += 3;
860         }
861         match i {
862             5 => return 99,
863             _ => return 0,
864         };
865     }
866     the_field
867 }
868 "#,
869             r#"
870 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
871     if the_field < 5 {
872         let mut i = 0;
873         loop {
874             if i > 5 {
875                 return Ok(55u32);
876             }
877             i += 3;
878         }
879         match i {
880             5 => return Ok(99),
881             _ => return Ok(0),
882         };
883     }
884     Ok(the_field)
885 }
886 "#,
887         );
888
889         check_assist(
890             wrap_return_type_in_result,
891             r#"
892 //- minicore: result
893 fn foo(the_field: u32) -> u3$02 {
894     if the_field < 5 {
895         let mut i = 0;
896         match i {
897             5 => return 99,
898             _ => return 0,
899         }
900     }
901     the_field
902 }
903 "#,
904             r#"
905 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
906     if the_field < 5 {
907         let mut i = 0;
908         match i {
909             5 => return Ok(99),
910             _ => return Ok(0),
911         }
912     }
913     Ok(the_field)
914 }
915 "#,
916         );
917
918         check_assist(
919             wrap_return_type_in_result,
920             r#"
921 //- minicore: result
922 fn foo(the_field: u32) -> u32$0 {
923     if the_field < 5 {
924         let mut i = 0;
925         if i == 5 {
926             return 99
927         } else {
928             return 0
929         }
930     }
931     the_field
932 }
933 "#,
934             r#"
935 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
936     if the_field < 5 {
937         let mut i = 0;
938         if i == 5 {
939             return Ok(99)
940         } else {
941             return Ok(0)
942         }
943     }
944     Ok(the_field)
945 }
946 "#,
947         );
948
949         check_assist(
950             wrap_return_type_in_result,
951             r#"
952 //- minicore: result
953 fn foo(the_field: u32) -> $0u32 {
954     if the_field < 5 {
955         let mut i = 0;
956         if i == 5 {
957             return 99;
958         } else {
959             return 0;
960         }
961     }
962     the_field
963 }
964 "#,
965             r#"
966 fn foo(the_field: u32) -> Result<u32, ${0:_}> {
967     if the_field < 5 {
968         let mut i = 0;
969         if i == 5 {
970             return Ok(99);
971         } else {
972             return Ok(0);
973         }
974     }
975     Ok(the_field)
976 }
977 "#,
978         );
979     }
980 }